diff --git a/.gitattributes b/.gitattributes index a4062ffa5f9dc16e30bdb8b07bc0d922f7f249f3..018b81829131af87e6ecbbcceafc1fb0c1266a0c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -37,3 +37,10 @@ phivenv/Lib/site-packages/numpy/random/_common.cp39-win_amd64.pyd filter=lfs dif phivenv/Lib/site-packages/numpy/random/_bounded_integers.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/numpy/random/lib/npyrandom.lib filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/numpy/random/_generator.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/_multiarray_umath.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/_simd.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/lib/npymath.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_numeric.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_multiarray.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_ufunc.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_regression.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/numpy/_core/_multiarray_umath.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/numpy/_core/_multiarray_umath.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..b64e263362111751c71b15e6cab9358966a1abfe --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/_multiarray_umath.cp39-win_amd64.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba5feeec2aa7bffd9f88c597091942598b3d4f3f3a5f9d2f6b16d7b49996434d +size 4056064 diff --git a/phivenv/Lib/site-packages/numpy/_core/_simd.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/numpy/_core/_simd.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..f2bcbeefb08e7fe6ba389b4cc4f8316ee8c6f06e --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/_simd.cp39-win_amd64.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cade7e0ccfb92a09158fef11cf77583e3e843bf6dc8596430880594a293eeced +size 2243072 diff --git a/phivenv/Lib/site-packages/numpy/_core/lib/npymath.lib b/phivenv/Lib/site-packages/numpy/_core/lib/npymath.lib new file mode 100644 index 0000000000000000000000000000000000000000..72408a036b9b6cefe38322966db6734cbd4173c3 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/lib/npymath.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93eb50206a762449ce1b7218c757b4feb48bf1c2f0cacc79d439e7ebdf40dea7 +size 150974 diff --git a/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_multiarray.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_multiarray.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a1a4c0b7018a792978c2a2460022193a1c520f --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_multiarray.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee150b1a6e9a95357d8cdf59d57871f2369a4e50dbfb0b825c1edf20cfe292df +size 339072 diff --git a/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_numeric.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_numeric.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcb15f602613a4f357e122118c732bbb0c1d2a63 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_numeric.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:872f528e011fd37ce55f998aa1dc1305f935b5f965c047ab20b9bb38ac2e5422 +size 130208 diff --git a/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_regression.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_regression.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0161cfd09213e7e13a5ca6e8bf475bd82808d4a5 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_regression.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ecee8b0257be1d3cd6569d2e1bc2b94053c23a531633cdedca80bbf93c2954bb +size 100896 diff --git a/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_ufunc.cpython-39.pyc b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_ufunc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21f2f7d75f3bd233944f77d50741af7fec31dff9 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_ufunc.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8540bfeec02898b6273fdb0e4aa5ca5adf0f88e5fa65413ad1ef22b55ee56946 +size 103080 diff --git a/phivenv/Lib/site-packages/sympy/polys/__init__.py b/phivenv/Lib/site-packages/sympy/polys/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8055ed12d213de3ebc7a1f17100607fb1e3b89b8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/__init__.py @@ -0,0 +1,130 @@ +"""Polynomial manipulation algorithms and algebraic objects. """ + +__all__ = [ + 'Poly', 'PurePoly', 'poly_from_expr', 'parallel_poly_from_expr', 'degree', + 'total_degree', 'degree_list', 'LC', 'LM', 'LT', 'pdiv', 'prem', 'pquo', + 'pexquo', 'div', 'rem', 'quo', 'exquo', 'half_gcdex', 'gcdex', 'invert', + 'subresultants', 'resultant', 'discriminant', 'cofactors', 'gcd_list', + 'gcd', 'lcm_list', 'lcm', 'terms_gcd', 'trunc', 'monic', 'content', + 'primitive', 'compose', 'decompose', 'sturm', 'gff_list', 'gff', + 'sqf_norm', 'sqf_part', 'sqf_list', 'sqf', 'factor_list', 'factor', + 'intervals', 'refine_root', 'count_roots', 'all_roots', 'real_roots', + 'nroots', 'ground_roots', 'nth_power_roots_poly', 'cancel', 'reduced', + 'groebner', 'is_zero_dimensional', 'GroebnerBasis', 'poly', + + 'symmetrize', 'horner', 'interpolate', 'rational_interpolate', 'viete', + + 'together', + + 'BasePolynomialError', 'ExactQuotientFailed', 'PolynomialDivisionFailed', + 'OperationNotSupported', 'HeuristicGCDFailed', 'HomomorphismFailed', + 'IsomorphismFailed', 'ExtraneousFactors', 'EvaluationFailed', + 'RefinementFailed', 'CoercionFailed', 'NotInvertible', 'NotReversible', + 'NotAlgebraic', 'DomainError', 'PolynomialError', 'UnificationFailed', + 'GeneratorsError', 'GeneratorsNeeded', 'ComputationFailed', + 'UnivariatePolynomialError', 'MultivariatePolynomialError', + 'PolificationFailed', 'OptionError', 'FlagError', + + 'minpoly', 'minimal_polynomial', 'primitive_element', 'field_isomorphism', + 'to_number_field', 'isolate', 'round_two', 'prime_decomp', + 'prime_valuation', 'galois_group', + + 'itermonomials', 'Monomial', + + 'lex', 'grlex', 'grevlex', 'ilex', 'igrlex', 'igrevlex', + + 'CRootOf', 'rootof', 'RootOf', 'ComplexRootOf', 'RootSum', + + 'roots', + + 'Domain', 'FiniteField', 'IntegerRing', 'RationalField', 'RealField', + 'ComplexField', 'PythonFiniteField', 'GMPYFiniteField', + 'PythonIntegerRing', 'GMPYIntegerRing', 'PythonRational', + 'GMPYRationalField', 'AlgebraicField', 'PolynomialRing', 'FractionField', + 'ExpressionDomain', 'FF_python', 'FF_gmpy', 'ZZ_python', 'ZZ_gmpy', + 'QQ_python', 'QQ_gmpy', 'GF', 'FF', 'ZZ', 'QQ', 'ZZ_I', 'QQ_I', 'RR', + 'CC', 'EX', 'EXRAW', + + 'construct_domain', + + 'swinnerton_dyer_poly', 'cyclotomic_poly', 'symmetric_poly', + 'random_poly', 'interpolating_poly', + + 'jacobi_poly', 'chebyshevt_poly', 'chebyshevu_poly', 'hermite_poly', + 'hermite_prob_poly', 'legendre_poly', 'laguerre_poly', + + 'bernoulli_poly', 'bernoulli_c_poly', 'genocchi_poly', 'euler_poly', + 'andre_poly', + + 'apart', 'apart_list', 'assemble_partfrac_list', + + 'Options', + + 'ring', 'xring', 'vring', 'sring', + + 'field', 'xfield', 'vfield', 'sfield' +] + +from .polytools import (Poly, PurePoly, poly_from_expr, + parallel_poly_from_expr, degree, total_degree, degree_list, LC, LM, + LT, pdiv, prem, pquo, pexquo, div, rem, quo, exquo, half_gcdex, gcdex, + invert, subresultants, resultant, discriminant, cofactors, gcd_list, + gcd, lcm_list, lcm, terms_gcd, trunc, monic, content, primitive, + compose, decompose, sturm, gff_list, gff, sqf_norm, sqf_part, + sqf_list, sqf, factor_list, factor, intervals, refine_root, + count_roots, all_roots, real_roots, nroots, ground_roots, + nth_power_roots_poly, cancel, reduced, groebner, is_zero_dimensional, + GroebnerBasis, poly) + +from .polyfuncs import (symmetrize, horner, interpolate, + rational_interpolate, viete) + +from .rationaltools import together + +from .polyerrors import (BasePolynomialError, ExactQuotientFailed, + PolynomialDivisionFailed, OperationNotSupported, HeuristicGCDFailed, + HomomorphismFailed, IsomorphismFailed, ExtraneousFactors, + EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible, + NotReversible, NotAlgebraic, DomainError, PolynomialError, + UnificationFailed, GeneratorsError, GeneratorsNeeded, + ComputationFailed, UnivariatePolynomialError, + MultivariatePolynomialError, PolificationFailed, OptionError, + FlagError) + +from .numberfields import (minpoly, minimal_polynomial, primitive_element, + field_isomorphism, to_number_field, isolate, round_two, prime_decomp, + prime_valuation, galois_group) + +from .monomials import itermonomials, Monomial + +from .orderings import lex, grlex, grevlex, ilex, igrlex, igrevlex + +from .rootoftools import CRootOf, rootof, RootOf, ComplexRootOf, RootSum + +from .polyroots import roots + +from .domains import (Domain, FiniteField, IntegerRing, RationalField, + RealField, ComplexField, PythonFiniteField, GMPYFiniteField, + PythonIntegerRing, GMPYIntegerRing, PythonRational, GMPYRationalField, + AlgebraicField, PolynomialRing, FractionField, ExpressionDomain, + FF_python, FF_gmpy, ZZ_python, ZZ_gmpy, QQ_python, QQ_gmpy, GF, FF, + ZZ, QQ, ZZ_I, QQ_I, RR, CC, EX, EXRAW) + +from .constructor import construct_domain + +from .specialpolys import (swinnerton_dyer_poly, cyclotomic_poly, + symmetric_poly, random_poly, interpolating_poly) + +from .orthopolys import (jacobi_poly, chebyshevt_poly, chebyshevu_poly, + hermite_poly, hermite_prob_poly, legendre_poly, laguerre_poly) + +from .appellseqs import (bernoulli_poly, bernoulli_c_poly, genocchi_poly, + euler_poly, andre_poly) + +from .partfrac import apart, apart_list, assemble_partfrac_list + +from .polyoptions import Options + +from .rings import ring, xring, vring, sring + +from .fields import field, xfield, vfield, sfield diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f72aa349046664c85ece97b03a1ec01006a9c647 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/appellseqs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/appellseqs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d124670e7b51ea27ed60f48e5d38e268f0659ad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/appellseqs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/compatibility.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/compatibility.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6c77eae762b4b7884bdb2792b7d2160709a4f23 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/compatibility.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/constructor.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/constructor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41f3033407f45657393d9335b2c5ac50f83941ff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/constructor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/densearith.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/densearith.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8449f943e10363abc2e6c01199cc94cd29f0bcfd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/densearith.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/densebasic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/densebasic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe33e79063e22c5df37f10df2c8166d9a06c7ce Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/densebasic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/densetools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/densetools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2abe363eec8b1ef6599c5bf7e7a16f173cb390 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/densetools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/dispersion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/dispersion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9119cc218c53cd3ac69e756963eeffa030d7524e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/dispersion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/distributedmodules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/distributedmodules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f700417181a7989ca390dec0332583c5d759450e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/distributedmodules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/domainmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/domainmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2cf6a4aa9aa62458135ddae1f89e83e2651169f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/domainmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/euclidtools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/euclidtools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6be1ca0967fe094a6d1111d50b238e951d67dc89 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/euclidtools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/factortools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/factortools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..593409afd006ba5e87f536ed7c3065699e10c306 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/factortools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/fglmtools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/fglmtools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28592cf0111693f2f73a046bcb4eb1970531d13d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/fglmtools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/fields.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/fields.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f351c9d91dcf21db7c99ffbc1fed59e162053ab Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/fields.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/galoistools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/galoistools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de852aaeee00d64179c306b1eff3ad48a4fac951 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/galoistools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/groebnertools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/groebnertools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4c545fadd7bc09206896c653eb7b6115ead68e8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/groebnertools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/heuristicgcd.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/heuristicgcd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3e2fd1fda7e269e9e13286a3f5b279cd641b123 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/heuristicgcd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/modulargcd.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/modulargcd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bcb31ff2b18dae0e3cf37b2d9f906f4e2f061c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/modulargcd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/monomials.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/monomials.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eed94068ac0bedd41eb60044af8f06ed2f02b49 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/monomials.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/multivariate_resultants.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/multivariate_resultants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..348f224be889043a7d18034e4a0c5990bb9160c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/multivariate_resultants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/orderings.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/orderings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b95e2f393385e74958443923e519cb387a4b1226 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/orderings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/orthopolys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/orthopolys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0981dacd8e2c255a91da557f3cd7b2dd5f715c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/orthopolys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/partfrac.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/partfrac.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fce633e801b9e0cab12f95b6b49f0cd1a77d29a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/partfrac.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyconfig.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyconfig.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a39e8b2cffc4264c3bd7e2a0dfd5c67ff55e7559 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyconfig.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyerrors.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyerrors.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e76755c2994f5106d097316e6f439744684603b9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyerrors.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyfuncs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyfuncs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa69ed012a0860583769d9e70c88846985921b05 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyfuncs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polymatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polymatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3b4e1baa9b19426e412fa00326eecf9f8518b58 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polymatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyoptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyoptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..034f5379b5ea5991bcb05e8980d5ec0a26d11baa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyoptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyroots.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyroots.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4140057d6f3e6fe608b40efbbbc8452e03b2adfa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyroots.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..310aa58060549afc2e6e0b977e4aba394b51839f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/polyutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/puiseux.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/puiseux.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3ec48cba2e68ac9a8f7d28d3e77d36ef5e7b40 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/puiseux.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/rationaltools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rationaltools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20de85f147733839016d8a907977a178ad41dbf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rationaltools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/ring_series.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/ring_series.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..481714f4e8856c45e74b7d1b5347889f2b871306 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/ring_series.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/rings.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd23df93f76f818d82e2d6d724db551e2ff4266a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/rootisolation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rootisolation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f78b7f3a7ea027f16ebf58af3dbaa4fb1083928a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rootisolation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/rootoftools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rootoftools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac03e06fd859d507069c484241f645f669ec90c8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/rootoftools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/solvers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/solvers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68aa8b2de4a7db6135a1d79f364f54430996cc63 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/solvers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/specialpolys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/specialpolys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8501a358e0ff19d4ec9ab7b8dadf4dbf92955bcf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/specialpolys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/sqfreetools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/sqfreetools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b0fbc6e123b0bbf8a85586e7f1fe35dd6dca3d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/sqfreetools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/__pycache__/subresultants_qq_zz.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/__pycache__/subresultants_qq_zz.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bab67f66e25780fb59df9d56425329e11d7116b1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/__pycache__/subresultants_qq_zz.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/__init__.py b/phivenv/Lib/site-packages/sympy/polys/agca/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43486eec86f7453ec470a22b8f8decaa316cbe70 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/__init__.py @@ -0,0 +1,5 @@ +"""Module for algebraic geometry and commutative algebra.""" + +from .homomorphisms import homomorphism + +__all__ = ['homomorphism'] diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39094afbe7c0584b0735b15e014e603f682a8117 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/extensions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/extensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554b549851d78c5279ca28bf8cdbf7517ce23d31 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/extensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/homomorphisms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/homomorphisms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a707732c293e99c66239f91465adb2422e7faa48 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/homomorphisms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/ideals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/ideals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4c34ef95e4187b1b0e4c25fecd3c558b863129b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/ideals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/modules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00160c75157a40a5493f5f3ad28caf0a7f328eca Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/__pycache__/modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/extensions.py b/phivenv/Lib/site-packages/sympy/polys/agca/extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..2668f792b5721db877f275e57ed54961b2e4df93 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/extensions.py @@ -0,0 +1,356 @@ +"""Finite extensions of ring domains.""" + +from sympy.polys.domains.domain import Domain +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.polyerrors import (CoercionFailed, NotInvertible, + GeneratorsError) +from sympy.polys.polytools import Poly +from sympy.printing.defaults import DefaultPrinting + + +class ExtensionElement(DomainElement, DefaultPrinting): + """ + Element of a finite extension. + + A class of univariate polynomials modulo the ``modulus`` + of the extension ``ext``. It is represented by the + unique polynomial ``rep`` of lowest degree. Both + ``rep`` and the representation ``mod`` of ``modulus`` + are of class DMP. + + """ + __slots__ = ('rep', 'ext') + + def __init__(self, rep, ext): + self.rep = rep + self.ext = ext + + def parent(f): + return f.ext + + def as_expr(f): + return f.ext.to_sympy(f) + + def __bool__(f): + return bool(f.rep) + + def __pos__(f): + return f + + def __neg__(f): + return ExtElem(-f.rep, f.ext) + + def _get_rep(f, g): + if isinstance(g, ExtElem): + if g.ext == f.ext: + return g.rep + else: + return None + else: + try: + g = f.ext.convert(g) + return g.rep + except CoercionFailed: + return None + + def __add__(f, g): + rep = f._get_rep(g) + if rep is not None: + return ExtElem(f.rep + rep, f.ext) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(f, g): + rep = f._get_rep(g) + if rep is not None: + return ExtElem(f.rep - rep, f.ext) + else: + return NotImplemented + + def __rsub__(f, g): + rep = f._get_rep(g) + if rep is not None: + return ExtElem(rep - f.rep, f.ext) + else: + return NotImplemented + + def __mul__(f, g): + rep = f._get_rep(g) + if rep is not None: + return ExtElem((f.rep * rep) % f.ext.mod, f.ext) + else: + return NotImplemented + + __rmul__ = __mul__ + + def _divcheck(f): + """Raise if division is not implemented for this divisor""" + if not f: + raise NotInvertible('Zero divisor') + elif f.ext.is_Field: + return True + elif f.rep.is_ground and f.ext.domain.is_unit(f.rep.LC()): + return True + else: + # Some cases like (2*x + 2)/2 over ZZ will fail here. It is + # unclear how to implement division in general if the ground + # domain is not a field so for now it was decided to restrict the + # implementation to division by invertible constants. + msg = (f"Can not invert {f} in {f.ext}. " + "Only division by invertible constants is implemented.") + raise NotImplementedError(msg) + + def inverse(f): + """Multiplicative inverse. + + Raises + ====== + + NotInvertible + If the element is a zero divisor. + + """ + f._divcheck() + + if f.ext.is_Field: + invrep = f.rep.invert(f.ext.mod) + else: + R = f.ext.ring + invrep = R.exquo(R.one, f.rep) + + return ExtElem(invrep, f.ext) + + def __truediv__(f, g): + rep = f._get_rep(g) + if rep is None: + return NotImplemented + g = ExtElem(rep, f.ext) + + try: + ginv = g.inverse() + except NotInvertible: + raise ZeroDivisionError(f"{f} / {g}") + + return f * ginv + + __floordiv__ = __truediv__ + + def __rtruediv__(f, g): + try: + g = f.ext.convert(g) + except CoercionFailed: + return NotImplemented + return g / f + + __rfloordiv__ = __rtruediv__ + + def __mod__(f, g): + rep = f._get_rep(g) + if rep is None: + return NotImplemented + g = ExtElem(rep, f.ext) + + try: + g._divcheck() + except NotInvertible: + raise ZeroDivisionError(f"{f} % {g}") + + # Division where defined is always exact so there is no remainder + return f.ext.zero + + def __rmod__(f, g): + try: + g = f.ext.convert(g) + except CoercionFailed: + return NotImplemented + return g % f + + def __pow__(f, n): + if not isinstance(n, int): + raise TypeError("exponent of type 'int' expected") + if n < 0: + try: + f, n = f.inverse(), -n + except NotImplementedError: + raise ValueError("negative powers are not defined") + + b = f.rep + m = f.ext.mod + r = f.ext.one.rep + while n > 0: + if n % 2: + r = (r*b) % m + b = (b*b) % m + n //= 2 + + return ExtElem(r, f.ext) + + def __eq__(f, g): + if isinstance(g, ExtElem): + return f.rep == g.rep and f.ext == g.ext + else: + return NotImplemented + + def __ne__(f, g): + return not f == g + + def __hash__(f): + return hash((f.rep, f.ext)) + + def __str__(f): + from sympy.printing.str import sstr + return sstr(f.as_expr()) + + __repr__ = __str__ + + @property + def is_ground(f): + return f.rep.is_ground + + def to_ground(f): + [c] = f.rep.to_list() + return c + +ExtElem = ExtensionElement + + +class MonogenicFiniteExtension(Domain): + r""" + Finite extension generated by an integral element. + + The generator is defined by a monic univariate + polynomial derived from the argument ``mod``. + + A shorter alias is ``FiniteExtension``. + + Examples + ======== + + Quadratic integer ring $\mathbb{Z}[\sqrt2]$: + + >>> from sympy import Symbol, Poly + >>> from sympy.polys.agca.extensions import FiniteExtension + >>> x = Symbol('x') + >>> R = FiniteExtension(Poly(x**2 - 2)); R + ZZ[x]/(x**2 - 2) + >>> R.rank + 2 + >>> R(1 + x)*(3 - 2*x) + x - 1 + + Finite field $GF(5^3)$ defined by the primitive + polynomial $x^3 + x^2 + 2$ (over $\mathbb{Z}_5$). + + >>> F = FiniteExtension(Poly(x**3 + x**2 + 2, modulus=5)); F + GF(5)[x]/(x**3 + x**2 + 2) + >>> F.basis + (1, x, x**2) + >>> F(x + 3)/(x**2 + 2) + -2*x**2 + x + 2 + + Function field of an elliptic curve: + + >>> t = Symbol('t') + >>> FiniteExtension(Poly(t**2 - x**3 - x + 1, t, field=True)) + ZZ(x)[t]/(t**2 - x**3 - x + 1) + + """ + is_FiniteExtension = True + + dtype = ExtensionElement + + def __init__(self, mod): + if not (isinstance(mod, Poly) and mod.is_univariate): + raise TypeError("modulus must be a univariate Poly") + + # Using auto=True (default) potentially changes the ground domain to a + # field whereas auto=False raises if division is not exact. We'll let + # the caller decide whether or not they want to put the ground domain + # over a field. In most uses mod is already monic. + mod = mod.monic(auto=False) + + self.rank = mod.degree() + self.modulus = mod + self.mod = mod.rep # DMP representation + + self.domain = dom = mod.domain + self.ring = dom.old_poly_ring(*mod.gens) + + self.zero = self.convert(self.ring.zero) + self.one = self.convert(self.ring.one) + + gen = self.ring.gens[0] + self.symbol = self.ring.symbols[0] + self.generator = self.convert(gen) + self.basis = tuple(self.convert(gen**i) for i in range(self.rank)) + + # XXX: It might be necessary to check mod.is_irreducible here + self.is_Field = self.domain.is_Field + + def new(self, arg): + rep = self.ring.convert(arg) + return ExtElem(rep % self.mod, self) + + def __eq__(self, other): + if not isinstance(other, FiniteExtension): + return False + return self.modulus == other.modulus + + def __hash__(self): + return hash((self.__class__.__name__, self.modulus)) + + def __str__(self): + return "%s/(%s)" % (self.ring, self.modulus.as_expr()) + + __repr__ = __str__ + + @property + def has_CharacteristicZero(self): + return self.domain.has_CharacteristicZero + + def characteristic(self): + return self.domain.characteristic() + + def convert(self, f, base=None): + rep = self.ring.convert(f, base) + return ExtElem(rep % self.mod, self) + + def convert_from(self, f, base): + rep = self.ring.convert(f, base) + return ExtElem(rep % self.mod, self) + + def to_sympy(self, f): + return self.ring.to_sympy(f.rep) + + def from_sympy(self, f): + return self.convert(f) + + def set_domain(self, K): + mod = self.modulus.set_domain(K) + return self.__class__(mod) + + def drop(self, *symbols): + if self.symbol in symbols: + raise GeneratorsError('Can not drop generator from FiniteExtension') + K = self.domain.drop(*symbols) + return self.set_domain(K) + + def quo(self, f, g): + return self.exquo(f, g) + + def exquo(self, f, g): + rep = self.ring.exquo(f.rep, g.rep) + return ExtElem(rep % self.mod, self) + + def is_negative(self, a): + return False + + def is_unit(self, a): + if self.is_Field: + return bool(a) + elif a.is_ground: + return self.domain.is_unit(a.to_ground()) + +FiniteExtension = MonogenicFiniteExtension diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/homomorphisms.py b/phivenv/Lib/site-packages/sympy/polys/agca/homomorphisms.py new file mode 100644 index 0000000000000000000000000000000000000000..45e9549980a8848eee944000d321922576961a00 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/homomorphisms.py @@ -0,0 +1,691 @@ +""" +Computations with homomorphisms of modules and rings. + +This module implements classes for representing homomorphisms of rings and +their modules. Instead of instantiating the classes directly, you should use +the function ``homomorphism(from, to, matrix)`` to create homomorphism objects. +""" + + +from sympy.polys.agca.modules import (Module, FreeModule, QuotientModule, + SubModule, SubQuotientModule) +from sympy.polys.polyerrors import CoercionFailed + +# The main computational task for module homomorphisms is kernels. +# For this reason, the concrete classes are organised by domain module type. + + +class ModuleHomomorphism: + """ + Abstract base class for module homomoprhisms. Do not instantiate. + + Instead, use the ``homomorphism`` function: + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> homomorphism(F, F, [[1, 0], [0, 1]]) + Matrix([ + [1, 0], : QQ[x]**2 -> QQ[x]**2 + [0, 1]]) + + Attributes: + + - ring - the ring over which we are considering modules + - domain - the domain module + - codomain - the codomain module + - _ker - cached kernel + - _img - cached image + + Non-implemented methods: + + - _kernel + - _image + - _restrict_domain + - _restrict_codomain + - _quotient_domain + - _quotient_codomain + - _apply + - _mul_scalar + - _compose + - _add + """ + + def __init__(self, domain, codomain): + if not isinstance(domain, Module): + raise TypeError('Source must be a module, got %s' % domain) + if not isinstance(codomain, Module): + raise TypeError('Target must be a module, got %s' % codomain) + if domain.ring != codomain.ring: + raise ValueError('Source and codomain must be over same ring, ' + 'got %s != %s' % (domain, codomain)) + self.domain = domain + self.codomain = codomain + self.ring = domain.ring + self._ker = None + self._img = None + + def kernel(self): + r""" + Compute the kernel of ``self``. + + That is, if ``self`` is the homomorphism `\phi: M \to N`, then compute + `ker(\phi) = \{x \in M | \phi(x) = 0\}`. This is a submodule of `M`. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> homomorphism(F, F, [[1, 0], [x, 0]]).kernel() + <[x, -1]> + """ + if self._ker is None: + self._ker = self._kernel() + return self._ker + + def image(self): + r""" + Compute the image of ``self``. + + That is, if ``self`` is the homomorphism `\phi: M \to N`, then compute + `im(\phi) = \{\phi(x) | x \in M \}`. This is a submodule of `N`. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> homomorphism(F, F, [[1, 0], [x, 0]]).image() == F.submodule([1, 0]) + True + """ + if self._img is None: + self._img = self._image() + return self._img + + def _kernel(self): + """Compute the kernel of ``self``.""" + raise NotImplementedError + + def _image(self): + """Compute the image of ``self``.""" + raise NotImplementedError + + def _restrict_domain(self, sm): + """Implementation of domain restriction.""" + raise NotImplementedError + + def _restrict_codomain(self, sm): + """Implementation of codomain restriction.""" + raise NotImplementedError + + def _quotient_domain(self, sm): + """Implementation of domain quotient.""" + raise NotImplementedError + + def _quotient_codomain(self, sm): + """Implementation of codomain quotient.""" + raise NotImplementedError + + def restrict_domain(self, sm): + """ + Return ``self``, with the domain restricted to ``sm``. + + Here ``sm`` has to be a submodule of ``self.domain``. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h + Matrix([ + [1, x], : QQ[x]**2 -> QQ[x]**2 + [0, 0]]) + >>> h.restrict_domain(F.submodule([1, 0])) + Matrix([ + [1, x], : <[1, 0]> -> QQ[x]**2 + [0, 0]]) + + This is the same as just composing on the right with the submodule + inclusion: + + >>> h * F.submodule([1, 0]).inclusion_hom() + Matrix([ + [1, x], : <[1, 0]> -> QQ[x]**2 + [0, 0]]) + """ + if not self.domain.is_submodule(sm): + raise ValueError('sm must be a submodule of %s, got %s' + % (self.domain, sm)) + if sm == self.domain: + return self + return self._restrict_domain(sm) + + def restrict_codomain(self, sm): + """ + Return ``self``, with codomain restricted to to ``sm``. + + Here ``sm`` has to be a submodule of ``self.codomain`` containing the + image. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h + Matrix([ + [1, x], : QQ[x]**2 -> QQ[x]**2 + [0, 0]]) + >>> h.restrict_codomain(F.submodule([1, 0])) + Matrix([ + [1, x], : QQ[x]**2 -> <[1, 0]> + [0, 0]]) + """ + if not sm.is_submodule(self.image()): + raise ValueError('the image %s must contain sm, got %s' + % (self.image(), sm)) + if sm == self.codomain: + return self + return self._restrict_codomain(sm) + + def quotient_domain(self, sm): + """ + Return ``self`` with domain replaced by ``domain/sm``. + + Here ``sm`` must be a submodule of ``self.kernel()``. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h + Matrix([ + [1, x], : QQ[x]**2 -> QQ[x]**2 + [0, 0]]) + >>> h.quotient_domain(F.submodule([-x, 1])) + Matrix([ + [1, x], : QQ[x]**2/<[-x, 1]> -> QQ[x]**2 + [0, 0]]) + """ + if not self.kernel().is_submodule(sm): + raise ValueError('kernel %s must contain sm, got %s' % + (self.kernel(), sm)) + if sm.is_zero(): + return self + return self._quotient_domain(sm) + + def quotient_codomain(self, sm): + """ + Return ``self`` with codomain replaced by ``codomain/sm``. + + Here ``sm`` must be a submodule of ``self.codomain``. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h + Matrix([ + [1, x], : QQ[x]**2 -> QQ[x]**2 + [0, 0]]) + >>> h.quotient_codomain(F.submodule([1, 1])) + Matrix([ + [1, x], : QQ[x]**2 -> QQ[x]**2/<[1, 1]> + [0, 0]]) + + This is the same as composing with the quotient map on the left: + + >>> (F/[(1, 1)]).quotient_hom() * h + Matrix([ + [1, x], : QQ[x]**2 -> QQ[x]**2/<[1, 1]> + [0, 0]]) + """ + if not self.codomain.is_submodule(sm): + raise ValueError('sm must be a submodule of codomain %s, got %s' + % (self.codomain, sm)) + if sm.is_zero(): + return self + return self._quotient_codomain(sm) + + def _apply(self, elem): + """Apply ``self`` to ``elem``.""" + raise NotImplementedError + + def __call__(self, elem): + return self.codomain.convert(self._apply(self.domain.convert(elem))) + + def _compose(self, oth): + """ + Compose ``self`` with ``oth``, that is, return the homomorphism + obtained by first applying then ``self``, then ``oth``. + + (This method is private since in this syntax, it is non-obvious which + homomorphism is executed first.) + """ + raise NotImplementedError + + def _mul_scalar(self, c): + """Scalar multiplication. ``c`` is guaranteed in self.ring.""" + raise NotImplementedError + + def _add(self, oth): + """ + Homomorphism addition. + ``oth`` is guaranteed to be a homomorphism with same domain/codomain. + """ + raise NotImplementedError + + def _check_hom(self, oth): + """Helper to check that oth is a homomorphism with same domain/codomain.""" + if not isinstance(oth, ModuleHomomorphism): + return False + return oth.domain == self.domain and oth.codomain == self.codomain + + def __mul__(self, oth): + if isinstance(oth, ModuleHomomorphism) and self.domain == oth.codomain: + return oth._compose(self) + try: + return self._mul_scalar(self.ring.convert(oth)) + except CoercionFailed: + return NotImplemented + + # NOTE: _compose will never be called from rmul + __rmul__ = __mul__ + + def __truediv__(self, oth): + try: + return self._mul_scalar(1/self.ring.convert(oth)) + except CoercionFailed: + return NotImplemented + + def __add__(self, oth): + if self._check_hom(oth): + return self._add(oth) + return NotImplemented + + def __sub__(self, oth): + if self._check_hom(oth): + return self._add(oth._mul_scalar(self.ring.convert(-1))) + return NotImplemented + + def is_injective(self): + """ + Return True if ``self`` is injective. + + That is, check if the elements of the domain are mapped to the same + codomain element. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h.is_injective() + False + >>> h.quotient_domain(h.kernel()).is_injective() + True + """ + return self.kernel().is_zero() + + def is_surjective(self): + """ + Return True if ``self`` is surjective. + + That is, check if every element of the codomain has at least one + preimage. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h.is_surjective() + False + >>> h.restrict_codomain(h.image()).is_surjective() + True + """ + return self.image() == self.codomain + + def is_isomorphism(self): + """ + Return True if ``self`` is an isomorphism. + + That is, check if every element of the codomain has precisely one + preimage. Equivalently, ``self`` is both injective and surjective. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h = h.restrict_codomain(h.image()) + >>> h.is_isomorphism() + False + >>> h.quotient_domain(h.kernel()).is_isomorphism() + True + """ + return self.is_injective() and self.is_surjective() + + def is_zero(self): + """ + Return True if ``self`` is a zero morphism. + + That is, check if every element of the domain is mapped to zero + under self. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> h = homomorphism(F, F, [[1, 0], [x, 0]]) + >>> h.is_zero() + False + >>> h.restrict_domain(F.submodule()).is_zero() + True + >>> h.quotient_codomain(h.image()).is_zero() + True + """ + return self.image().is_zero() + + def __eq__(self, oth): + try: + return (self - oth).is_zero() + except TypeError: + return False + + def __ne__(self, oth): + return not (self == oth) + + +class MatrixHomomorphism(ModuleHomomorphism): + r""" + Helper class for all homomoprhisms which are expressed via a matrix. + + That is, for such homomorphisms ``domain`` is contained in a module + generated by finitely many elements `e_1, \ldots, e_n`, so that the + homomorphism is determined uniquely by its action on the `e_i`. It + can thus be represented as a vector of elements of the codomain module, + or potentially a supermodule of the codomain module + (and hence conventionally as a matrix, if there is a similar interpretation + for elements of the codomain module). + + Note that this class does *not* assume that the `e_i` freely generate a + submodule, nor that ``domain`` is even all of this submodule. It exists + only to unify the interface. + + Do not instantiate. + + Attributes: + + - matrix - the list of images determining the homomorphism. + NOTE: the elements of matrix belong to either self.codomain or + self.codomain.container + + Still non-implemented methods: + + - kernel + - _apply + """ + + def __init__(self, domain, codomain, matrix): + ModuleHomomorphism.__init__(self, domain, codomain) + if len(matrix) != domain.rank: + raise ValueError('Need to provide %s elements, got %s' + % (domain.rank, len(matrix))) + + converter = self.codomain.convert + if isinstance(self.codomain, (SubModule, SubQuotientModule)): + converter = self.codomain.container.convert + self.matrix = tuple(converter(x) for x in matrix) + + def _sympy_matrix(self): + """Helper function which returns a SymPy matrix ``self.matrix``.""" + from sympy.matrices import Matrix + c = lambda x: x + if isinstance(self.codomain, (QuotientModule, SubQuotientModule)): + c = lambda x: x.data + return Matrix([[self.ring.to_sympy(y) for y in c(x)] for x in self.matrix]).T + + def __repr__(self): + lines = repr(self._sympy_matrix()).split('\n') + t = " : %s -> %s" % (self.domain, self.codomain) + s = ' '*len(t) + n = len(lines) + for i in range(n // 2): + lines[i] += s + lines[n // 2] += t + for i in range(n//2 + 1, n): + lines[i] += s + return '\n'.join(lines) + + def _restrict_domain(self, sm): + """Implementation of domain restriction.""" + return SubModuleHomomorphism(sm, self.codomain, self.matrix) + + def _restrict_codomain(self, sm): + """Implementation of codomain restriction.""" + return self.__class__(self.domain, sm, self.matrix) + + def _quotient_domain(self, sm): + """Implementation of domain quotient.""" + return self.__class__(self.domain/sm, self.codomain, self.matrix) + + def _quotient_codomain(self, sm): + """Implementation of codomain quotient.""" + Q = self.codomain/sm + converter = Q.convert + if isinstance(self.codomain, SubModule): + converter = Q.container.convert + return self.__class__(self.domain, self.codomain/sm, + [converter(x) for x in self.matrix]) + + def _add(self, oth): + return self.__class__(self.domain, self.codomain, + [x + y for x, y in zip(self.matrix, oth.matrix)]) + + def _mul_scalar(self, c): + return self.__class__(self.domain, self.codomain, [c*x for x in self.matrix]) + + def _compose(self, oth): + return self.__class__(self.domain, oth.codomain, [oth(x) for x in self.matrix]) + + +class FreeModuleHomomorphism(MatrixHomomorphism): + """ + Concrete class for homomorphisms with domain a free module or a quotient + thereof. + + Do not instantiate; the constructor does not check that your data is well + defined. Use the ``homomorphism`` function instead: + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> homomorphism(F, F, [[1, 0], [0, 1]]) + Matrix([ + [1, 0], : QQ[x]**2 -> QQ[x]**2 + [0, 1]]) + """ + + def _apply(self, elem): + if isinstance(self.domain, QuotientModule): + elem = elem.data + return sum(x * e for x, e in zip(elem, self.matrix)) + + def _image(self): + return self.codomain.submodule(*self.matrix) + + def _kernel(self): + # The domain is either a free module or a quotient thereof. + # It does not matter if it is a quotient, because that won't increase + # the kernel. + # Our generators {e_i} are sent to the matrix entries {b_i}. + # The kernel is essentially the syzygy module of these {b_i}. + syz = self.image().syzygy_module() + return self.domain.submodule(*syz.gens) + + +class SubModuleHomomorphism(MatrixHomomorphism): + """ + Concrete class for homomorphism with domain a submodule of a free module + or a quotient thereof. + + Do not instantiate; the constructor does not check that your data is well + defined. Use the ``homomorphism`` function instead: + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> M = QQ.old_poly_ring(x).free_module(2)*x + >>> homomorphism(M, M, [[1, 0], [0, 1]]) + Matrix([ + [1, 0], : <[x, 0], [0, x]> -> <[x, 0], [0, x]> + [0, 1]]) + """ + + def _apply(self, elem): + if isinstance(self.domain, SubQuotientModule): + elem = elem.data + return sum(x * e for x, e in zip(elem, self.matrix)) + + def _image(self): + return self.codomain.submodule(*[self(x) for x in self.domain.gens]) + + def _kernel(self): + syz = self.image().syzygy_module() + return self.domain.submodule( + *[sum(xi*gi for xi, gi in zip(s, self.domain.gens)) + for s in syz.gens]) + + +def homomorphism(domain, codomain, matrix): + r""" + Create a homomorphism object. + + This function tries to build a homomorphism from ``domain`` to ``codomain`` + via the matrix ``matrix``. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> from sympy.polys.agca import homomorphism + + >>> R = QQ.old_poly_ring(x) + >>> T = R.free_module(2) + + If ``domain`` is a free module generated by `e_1, \ldots, e_n`, then + ``matrix`` should be an n-element iterable `(b_1, \ldots, b_n)` where + the `b_i` are elements of ``codomain``. The constructed homomorphism is the + unique homomorphism sending `e_i` to `b_i`. + + >>> F = R.free_module(2) + >>> h = homomorphism(F, T, [[1, x], [x**2, 0]]) + >>> h + Matrix([ + [1, x**2], : QQ[x]**2 -> QQ[x]**2 + [x, 0]]) + >>> h([1, 0]) + [1, x] + >>> h([0, 1]) + [x**2, 0] + >>> h([1, 1]) + [x**2 + 1, x] + + If ``domain`` is a submodule of a free module, them ``matrix`` determines + a homomoprhism from the containing free module to ``codomain``, and the + homomorphism returned is obtained by restriction to ``domain``. + + >>> S = F.submodule([1, 0], [0, x]) + >>> homomorphism(S, T, [[1, x], [x**2, 0]]) + Matrix([ + [1, x**2], : <[1, 0], [0, x]> -> QQ[x]**2 + [x, 0]]) + + If ``domain`` is a (sub)quotient `N/K`, then ``matrix`` determines a + homomorphism from `N` to ``codomain``. If the kernel contains `K`, this + homomorphism descends to ``domain`` and is returned; otherwise an exception + is raised. + + >>> homomorphism(S/[(1, 0)], T, [0, [x**2, 0]]) + Matrix([ + [0, x**2], : <[1, 0] + <[1, 0]>, [0, x] + <[1, 0]>, [1, 0] + <[1, 0]>> -> QQ[x]**2 + [0, 0]]) + >>> homomorphism(S/[(0, x)], T, [0, [x**2, 0]]) + Traceback (most recent call last): + ... + ValueError: kernel <[1, 0], [0, 0]> must contain sm, got <[0,x]> + + """ + def freepres(module): + """ + Return a tuple ``(F, S, Q, c)`` where ``F`` is a free module, ``S`` is a + submodule of ``F``, and ``Q`` a submodule of ``S``, such that + ``module = S/Q``, and ``c`` is a conversion function. + """ + if isinstance(module, FreeModule): + return module, module, module.submodule(), lambda x: module.convert(x) + if isinstance(module, QuotientModule): + return (module.base, module.base, module.killed_module, + lambda x: module.convert(x).data) + if isinstance(module, SubQuotientModule): + return (module.base.container, module.base, module.killed_module, + lambda x: module.container.convert(x).data) + # an ordinary submodule + return (module.container, module, module.submodule(), + lambda x: module.container.convert(x)) + + SF, SS, SQ, _ = freepres(domain) + TF, TS, TQ, c = freepres(codomain) + # NOTE this is probably a bit inefficient (redundant checks) + return FreeModuleHomomorphism(SF, TF, [c(x) for x in matrix] + ).restrict_domain(SS).restrict_codomain(TS + ).quotient_codomain(TQ).quotient_domain(SQ) diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/ideals.py b/phivenv/Lib/site-packages/sympy/polys/agca/ideals.py new file mode 100644 index 0000000000000000000000000000000000000000..1969554a1d674bc36ded1a3e312d587c66104086 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/ideals.py @@ -0,0 +1,395 @@ +"""Computations with ideals of polynomial rings.""" + +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polyutils import IntegerPowerable + + +class Ideal(IntegerPowerable): + """ + Abstract base class for ideals. + + Do not instantiate - use explicit constructors in the ring class instead: + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> QQ.old_poly_ring(x).ideal(x+1) + + + Attributes + + - ring - the ring this ideal belongs to + + Non-implemented methods: + + - _contains_elem + - _contains_ideal + - _quotient + - _intersect + - _union + - _product + - is_whole_ring + - is_zero + - is_prime, is_maximal, is_primary, is_radical + - is_principal + - height, depth + - radical + + Methods that likely should be overridden in subclasses: + + - reduce_element + """ + + def _contains_elem(self, x): + """Implementation of element containment.""" + raise NotImplementedError + + def _contains_ideal(self, I): + """Implementation of ideal containment.""" + raise NotImplementedError + + def _quotient(self, J): + """Implementation of ideal quotient.""" + raise NotImplementedError + + def _intersect(self, J): + """Implementation of ideal intersection.""" + raise NotImplementedError + + def is_whole_ring(self): + """Return True if ``self`` is the whole ring.""" + raise NotImplementedError + + def is_zero(self): + """Return True if ``self`` is the zero ideal.""" + raise NotImplementedError + + def _equals(self, J): + """Implementation of ideal equality.""" + return self._contains_ideal(J) and J._contains_ideal(self) + + def is_prime(self): + """Return True if ``self`` is a prime ideal.""" + raise NotImplementedError + + def is_maximal(self): + """Return True if ``self`` is a maximal ideal.""" + raise NotImplementedError + + def is_radical(self): + """Return True if ``self`` is a radical ideal.""" + raise NotImplementedError + + def is_primary(self): + """Return True if ``self`` is a primary ideal.""" + raise NotImplementedError + + def is_principal(self): + """Return True if ``self`` is a principal ideal.""" + raise NotImplementedError + + def radical(self): + """Compute the radical of ``self``.""" + raise NotImplementedError + + def depth(self): + """Compute the depth of ``self``.""" + raise NotImplementedError + + def height(self): + """Compute the height of ``self``.""" + raise NotImplementedError + + # TODO more + + # non-implemented methods end here + + def __init__(self, ring): + self.ring = ring + + def _check_ideal(self, J): + """Helper to check ``J`` is an ideal of our ring.""" + if not isinstance(J, Ideal) or J.ring != self.ring: + raise ValueError( + 'J must be an ideal of %s, got %s' % (self.ring, J)) + + def contains(self, elem): + """ + Return True if ``elem`` is an element of this ideal. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).ideal(x+1, x-1).contains(3) + True + >>> QQ.old_poly_ring(x).ideal(x**2, x**3).contains(x) + False + """ + return self._contains_elem(self.ring.convert(elem)) + + def subset(self, other): + """ + Returns True if ``other`` is is a subset of ``self``. + + Here ``other`` may be an ideal. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> I = QQ.old_poly_ring(x).ideal(x+1) + >>> I.subset([x**2 - 1, x**2 + 2*x + 1]) + True + >>> I.subset([x**2 + 1, x + 1]) + False + >>> I.subset(QQ.old_poly_ring(x).ideal(x**2 - 1)) + True + """ + if isinstance(other, Ideal): + return self._contains_ideal(other) + return all(self._contains_elem(x) for x in other) + + def quotient(self, J, **opts): + r""" + Compute the ideal quotient of ``self`` by ``J``. + + That is, if ``self`` is the ideal `I`, compute the set + `I : J = \{x \in R | xJ \subset I \}`. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> R = QQ.old_poly_ring(x, y) + >>> R.ideal(x*y).quotient(R.ideal(x)) + + """ + self._check_ideal(J) + return self._quotient(J, **opts) + + def intersect(self, J): + """ + Compute the intersection of self with ideal J. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> R = QQ.old_poly_ring(x, y) + >>> R.ideal(x).intersect(R.ideal(y)) + + """ + self._check_ideal(J) + return self._intersect(J) + + def saturate(self, J): + r""" + Compute the ideal saturation of ``self`` by ``J``. + + That is, if ``self`` is the ideal `I`, compute the set + `I : J^\infty = \{x \in R | xJ^n \subset I \text{ for some } n\}`. + """ + raise NotImplementedError + # Note this can be implemented using repeated quotient + + def union(self, J): + """ + Compute the ideal generated by the union of ``self`` and ``J``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).ideal(x**2 - 1).union(QQ.old_poly_ring(x).ideal((x+1)**2)) == QQ.old_poly_ring(x).ideal(x+1) + True + """ + self._check_ideal(J) + return self._union(J) + + def product(self, J): + r""" + Compute the ideal product of ``self`` and ``J``. + + That is, compute the ideal generated by products `xy`, for `x` an element + of ``self`` and `y \in J`. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> QQ.old_poly_ring(x, y).ideal(x).product(QQ.old_poly_ring(x, y).ideal(y)) + + """ + self._check_ideal(J) + return self._product(J) + + def reduce_element(self, x): + """ + Reduce the element ``x`` of our ring modulo the ideal ``self``. + + Here "reduce" has no specific meaning: it could return a unique normal + form, simplify the expression a bit, or just do nothing. + """ + return x + + def __add__(self, e): + if not isinstance(e, Ideal): + R = self.ring.quotient_ring(self) + if isinstance(e, R.dtype): + return e + if isinstance(e, R.ring.dtype): + return R(e) + return R.convert(e) + self._check_ideal(e) + return self.union(e) + + __radd__ = __add__ + + def __mul__(self, e): + if not isinstance(e, Ideal): + try: + e = self.ring.ideal(e) + except CoercionFailed: + return NotImplemented + self._check_ideal(e) + return self.product(e) + + __rmul__ = __mul__ + + def _zeroth_power(self): + return self.ring.ideal(1) + + def _first_power(self): + # Raising to any power but 1 returns a new instance. So we mult by 1 + # here so that the first power is no exception. + return self * 1 + + def __eq__(self, e): + if not isinstance(e, Ideal) or e.ring != self.ring: + return False + return self._equals(e) + + def __ne__(self, e): + return not (self == e) + + +class ModuleImplementedIdeal(Ideal): + """ + Ideal implementation relying on the modules code. + + Attributes: + + - _module - the underlying module + """ + + def __init__(self, ring, module): + Ideal.__init__(self, ring) + self._module = module + + def _contains_elem(self, x): + return self._module.contains([x]) + + def _contains_ideal(self, J): + if not isinstance(J, ModuleImplementedIdeal): + raise NotImplementedError + return self._module.is_submodule(J._module) + + def _intersect(self, J): + if not isinstance(J, ModuleImplementedIdeal): + raise NotImplementedError + return self.__class__(self.ring, self._module.intersect(J._module)) + + def _quotient(self, J, **opts): + if not isinstance(J, ModuleImplementedIdeal): + raise NotImplementedError + return self._module.module_quotient(J._module, **opts) + + def _union(self, J): + if not isinstance(J, ModuleImplementedIdeal): + raise NotImplementedError + return self.__class__(self.ring, self._module.union(J._module)) + + @property + def gens(self): + """ + Return generators for ``self``. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x, y + >>> list(QQ.old_poly_ring(x, y).ideal(x, y, x**2 + y).gens) + [DMP_Python([[1], []], QQ), DMP_Python([[1, 0]], QQ), DMP_Python([[1], [], [1, 0]], QQ)] + """ + return (x[0] for x in self._module.gens) + + def is_zero(self): + """ + Return True if ``self`` is the zero ideal. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).ideal(x).is_zero() + False + >>> QQ.old_poly_ring(x).ideal().is_zero() + True + """ + return self._module.is_zero() + + def is_whole_ring(self): + """ + Return True if ``self`` is the whole ring, i.e. one generator is a unit. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ, ilex + >>> QQ.old_poly_ring(x).ideal(x).is_whole_ring() + False + >>> QQ.old_poly_ring(x).ideal(3).is_whole_ring() + True + >>> QQ.old_poly_ring(x, order=ilex).ideal(2 + x).is_whole_ring() + True + """ + return self._module.is_full_module() + + def __repr__(self): + from sympy.printing.str import sstr + gens = [self.ring.to_sympy(x) for [x] in self._module.gens] + return '<' + ','.join(sstr(g) for g in gens) + '>' + + # NOTE this is the only method using the fact that the module is a SubModule + def _product(self, J): + if not isinstance(J, ModuleImplementedIdeal): + raise NotImplementedError + return self.__class__(self.ring, self._module.submodule( + *[[x*y] for [x] in self._module.gens for [y] in J._module.gens])) + + def in_terms_of_generators(self, e): + """ + Express ``e`` in terms of the generators of ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> I = QQ.old_poly_ring(x).ideal(x**2 + 1, x) + >>> I.in_terms_of_generators(1) # doctest: +SKIP + [DMP_Python([1], QQ), DMP_Python([-1, 0], QQ)] + """ + return self._module.in_terms_of_generators([e]) + + def reduce_element(self, x, **options): + return self._module.reduce_element([x], **options)[0] diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/modules.py b/phivenv/Lib/site-packages/sympy/polys/agca/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2e2ed814f4143b4b49f8b1f10c2a07cb32d06a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/modules.py @@ -0,0 +1,1488 @@ +""" +Computations with modules over polynomial rings. + +This module implements various classes that encapsulate groebner basis +computations for modules. Most of them should not be instantiated by hand. +Instead, use the constructing routines on objects you already have. + +For example, to construct a free module over ``QQ[x, y]``, call +``QQ[x, y].free_module(rank)`` instead of the ``FreeModule`` constructor. +In fact ``FreeModule`` is an abstract base class that should not be +instantiated, the ``free_module`` method instead returns the implementing class +``FreeModulePolyRing``. + +In general, the abstract base classes implement most functionality in terms of +a few non-implemented methods. The concrete base classes supply only these +non-implemented methods. They may also supply new implementations of the +convenience methods, for example if there are faster algorithms available. +""" + + +from copy import copy +from functools import reduce + +from sympy.polys.agca.ideals import Ideal +from sympy.polys.domains.field import Field +from sympy.polys.orderings import ProductOrder, monomial_key +from sympy.polys.polyclasses import DMP +from sympy.polys.polyerrors import CoercionFailed +from sympy.core.basic import _aresame +from sympy.utilities.iterables import iterable + +# TODO +# - module saturation +# - module quotient/intersection for quotient rings +# - free resoltutions / syzygies +# - finding small/minimal generating sets +# - ... + +########################################################################## +## Abstract base classes ################################################# +########################################################################## + + +class Module: + """ + Abstract base class for modules. + + Do not instantiate - use ring explicit constructors instead: + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> QQ.old_poly_ring(x).free_module(2) + QQ[x]**2 + + Attributes: + + - dtype - type of elements + - ring - containing ring + + Non-implemented methods: + + - submodule + - quotient_module + - is_zero + - is_submodule + - multiply_ideal + + The method convert likely needs to be changed in subclasses. + """ + + def __init__(self, ring): + self.ring = ring + + def convert(self, elem, M=None): + """ + Convert ``elem`` into internal representation of this module. + + If ``M`` is not None, it should be a module containing it. + """ + if not isinstance(elem, self.dtype): + raise CoercionFailed + return elem + + def submodule(self, *gens): + """Generate a submodule.""" + raise NotImplementedError + + def quotient_module(self, other): + """Generate a quotient module.""" + raise NotImplementedError + + def __truediv__(self, e): + if not isinstance(e, Module): + e = self.submodule(*e) + return self.quotient_module(e) + + def contains(self, elem): + """Return True if ``elem`` is an element of this module.""" + try: + self.convert(elem) + return True + except CoercionFailed: + return False + + def __contains__(self, elem): + return self.contains(elem) + + def subset(self, other): + """ + Returns True if ``other`` is is a subset of ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> F.subset([(1, x), (x, 2)]) + True + >>> F.subset([(1/x, x), (x, 2)]) + False + """ + return all(self.contains(x) for x in other) + + def __eq__(self, other): + return self.is_submodule(other) and other.is_submodule(self) + + def __ne__(self, other): + return not (self == other) + + def is_zero(self): + """Returns True if ``self`` is a zero module.""" + raise NotImplementedError + + def is_submodule(self, other): + """Returns True if ``other`` is a submodule of ``self``.""" + raise NotImplementedError + + def multiply_ideal(self, other): + """ + Multiply ``self`` by the ideal ``other``. + """ + raise NotImplementedError + + def __mul__(self, e): + if not isinstance(e, Ideal): + try: + e = self.ring.ideal(e) + except (CoercionFailed, NotImplementedError): + return NotImplemented + return self.multiply_ideal(e) + + __rmul__ = __mul__ + + def identity_hom(self): + """Return the identity homomorphism on ``self``.""" + raise NotImplementedError + + +class ModuleElement: + """ + Base class for module element wrappers. + + Use this class to wrap primitive data types as module elements. It stores + a reference to the containing module, and implements all the arithmetic + operators. + + Attributes: + + - module - containing module + - data - internal data + + Methods that likely need change in subclasses: + + - add + - mul + - div + - eq + """ + + def __init__(self, module, data): + self.module = module + self.data = data + + def add(self, d1, d2): + """Add data ``d1`` and ``d2``.""" + return d1 + d2 + + def mul(self, m, d): + """Multiply module data ``m`` by coefficient d.""" + return m * d + + def div(self, m, d): + """Divide module data ``m`` by coefficient d.""" + return m / d + + def eq(self, d1, d2): + """Return true if d1 and d2 represent the same element.""" + return d1 == d2 + + def __add__(self, om): + if not isinstance(om, self.__class__) or om.module != self.module: + try: + om = self.module.convert(om) + except CoercionFailed: + return NotImplemented + return self.__class__(self.module, self.add(self.data, om.data)) + + __radd__ = __add__ + + def __neg__(self): + return self.__class__(self.module, self.mul(self.data, + self.module.ring.convert(-1))) + + def __sub__(self, om): + if not isinstance(om, self.__class__) or om.module != self.module: + try: + om = self.module.convert(om) + except CoercionFailed: + return NotImplemented + return self.__add__(-om) + + def __rsub__(self, om): + return (-self).__add__(om) + + def __mul__(self, o): + if not isinstance(o, self.module.ring.dtype): + try: + o = self.module.ring.convert(o) + except CoercionFailed: + return NotImplemented + return self.__class__(self.module, self.mul(self.data, o)) + + __rmul__ = __mul__ + + def __truediv__(self, o): + if not isinstance(o, self.module.ring.dtype): + try: + o = self.module.ring.convert(o) + except CoercionFailed: + return NotImplemented + return self.__class__(self.module, self.div(self.data, o)) + + def __eq__(self, om): + if not isinstance(om, self.__class__) or om.module != self.module: + try: + om = self.module.convert(om) + except CoercionFailed: + return False + return self.eq(self.data, om.data) + + def __ne__(self, om): + return not self == om + +########################################################################## +## Free Modules ########################################################## +########################################################################## + + +class FreeModuleElement(ModuleElement): + """Element of a free module. Data stored as a tuple.""" + + def add(self, d1, d2): + return tuple(x + y for x, y in zip(d1, d2)) + + def mul(self, d, p): + return tuple(x * p for x in d) + + def div(self, d, p): + return tuple(x / p for x in d) + + def __repr__(self): + from sympy.printing.str import sstr + data = self.data + if any(isinstance(x, DMP) for x in data): + data = [self.module.ring.to_sympy(x) for x in data] + return '[' + ', '.join(sstr(x) for x in data) + ']' + + def __iter__(self): + return self.data.__iter__() + + def __getitem__(self, idx): + return self.data[idx] + + +class FreeModule(Module): + """ + Abstract base class for free modules. + + Additional attributes: + + - rank - rank of the free module + + Non-implemented methods: + + - submodule + """ + + dtype = FreeModuleElement + + def __init__(self, ring, rank): + Module.__init__(self, ring) + self.rank = rank + + def __repr__(self): + return repr(self.ring) + "**" + repr(self.rank) + + def is_submodule(self, other): + """ + Returns True if ``other`` is a submodule of ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> M = F.submodule([2, x]) + >>> F.is_submodule(F) + True + >>> F.is_submodule(M) + True + >>> M.is_submodule(F) + False + """ + if isinstance(other, SubModule): + return other.container == self + if isinstance(other, FreeModule): + return other.ring == self.ring and other.rank == self.rank + return False + + def convert(self, elem, M=None): + """ + Convert ``elem`` into the internal representation. + + This method is called implicitly whenever computations involve elements + not in the internal representation. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> F.convert([1, 0]) + [1, 0] + """ + if isinstance(elem, FreeModuleElement): + if elem.module is self: + return elem + if elem.module.rank != self.rank: + raise CoercionFailed + return FreeModuleElement(self, + tuple(self.ring.convert(x, elem.module.ring) for x in elem.data)) + elif iterable(elem): + tpl = tuple(self.ring.convert(x) for x in elem) + if len(tpl) != self.rank: + raise CoercionFailed + return FreeModuleElement(self, tpl) + elif _aresame(elem, 0): + return FreeModuleElement(self, (self.ring.convert(0),)*self.rank) + else: + raise CoercionFailed + + def is_zero(self): + """ + Returns True if ``self`` is a zero module. + + (If, as this implementation assumes, the coefficient ring is not the + zero ring, then this is equivalent to the rank being zero.) + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(0).is_zero() + True + >>> QQ.old_poly_ring(x).free_module(1).is_zero() + False + """ + return self.rank == 0 + + def basis(self): + """ + Return a set of basis elements. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(3).basis() + ([1, 0, 0], [0, 1, 0], [0, 0, 1]) + """ + from sympy.matrices import eye + M = eye(self.rank) + return tuple(self.convert(M.row(i)) for i in range(self.rank)) + + def quotient_module(self, submodule): + """ + Return a quotient module. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> M = QQ.old_poly_ring(x).free_module(2) + >>> M.quotient_module(M.submodule([1, x], [x, 2])) + QQ[x]**2/<[1, x], [x, 2]> + + Or more conicisely, using the overloaded division operator: + + >>> QQ.old_poly_ring(x).free_module(2) / [[1, x], [x, 2]] + QQ[x]**2/<[1, x], [x, 2]> + """ + return QuotientModule(self.ring, self, submodule) + + def multiply_ideal(self, other): + """ + Multiply ``self`` by the ideal ``other``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> I = QQ.old_poly_ring(x).ideal(x) + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> F.multiply_ideal(I) + <[x, 0], [0, x]> + """ + return self.submodule(*self.basis()).multiply_ideal(other) + + def identity_hom(self): + """ + Return the identity homomorphism on ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2).identity_hom() + Matrix([ + [1, 0], : QQ[x]**2 -> QQ[x]**2 + [0, 1]]) + """ + from sympy.polys.agca.homomorphisms import homomorphism + return homomorphism(self, self, self.basis()) + + +class FreeModulePolyRing(FreeModule): + """ + Free module over a generalized polynomial ring. + + Do not instantiate this, use the constructor method of the ring instead: + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(3) + >>> F + QQ[x]**3 + >>> F.contains([x, 1, 0]) + True + >>> F.contains([1/x, 0, 1]) + False + """ + + def __init__(self, ring, rank): + from sympy.polys.domains.old_polynomialring import PolynomialRingBase + FreeModule.__init__(self, ring, rank) + if not isinstance(ring, PolynomialRingBase): + raise NotImplementedError('This implementation only works over ' + + 'polynomial rings, got %s' % ring) + if not isinstance(ring.dom, Field): + raise NotImplementedError('Ground domain must be a field, ' + + 'got %s' % ring.dom) + + def submodule(self, *gens, **opts): + """ + Generate a submodule. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> M = QQ.old_poly_ring(x, y).free_module(2).submodule([x, x + y]) + >>> M + <[x, x + y]> + >>> M.contains([2*x, 2*x + 2*y]) + True + >>> M.contains([x, y]) + False + """ + return SubModulePolyRing(gens, self, **opts) + + +class FreeModuleQuotientRing(FreeModule): + """ + Free module over a quotient ring. + + Do not instantiate this, use the constructor method of the ring instead: + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = (QQ.old_poly_ring(x)/[x**2 + 1]).free_module(3) + >>> F + (QQ[x]/)**3 + + Attributes + + - quot - the quotient module `R^n / IR^n`, where `R/I` is our ring + """ + + def __init__(self, ring, rank): + from sympy.polys.domains.quotientring import QuotientRing + FreeModule.__init__(self, ring, rank) + if not isinstance(ring, QuotientRing): + raise NotImplementedError('This implementation only works over ' + + 'quotient rings, got %s' % ring) + F = self.ring.ring.free_module(self.rank) + self.quot = F / (self.ring.base_ideal*F) + + def __repr__(self): + return "(" + repr(self.ring) + ")" + "**" + repr(self.rank) + + def submodule(self, *gens, **opts): + """ + Generate a submodule. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> M = (QQ.old_poly_ring(x, y)/[x**2 - y**2]).free_module(2).submodule([x, x + y]) + >>> M + <[x + , x + y + ]> + >>> M.contains([y**2, x**2 + x*y]) + True + >>> M.contains([x, y]) + False + """ + return SubModuleQuotientRing(gens, self, **opts) + + def lift(self, elem): + """ + Lift the element ``elem`` of self to the module self.quot. + + Note that self.quot is the same set as self, just as an R-module + and not as an R/I-module, so this makes sense. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = (QQ.old_poly_ring(x)/[x**2 + 1]).free_module(2) + >>> e = F.convert([1, 0]) + >>> e + [1 + , 0 + ] + >>> L = F.quot + >>> l = F.lift(e) + >>> l + [1, 0] + <[x**2 + 1, 0], [0, x**2 + 1]> + >>> L.contains(l) + True + """ + return self.quot.convert([x.data for x in elem]) + + def unlift(self, elem): + """ + Push down an element of self.quot to self. + + This undoes ``lift``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = (QQ.old_poly_ring(x)/[x**2 + 1]).free_module(2) + >>> e = F.convert([1, 0]) + >>> l = F.lift(e) + >>> e == l + False + >>> e == F.unlift(l) + True + """ + return self.convert(elem.data) + +########################################################################## +## Submodules and subquotients ########################################### +########################################################################## + + +class SubModule(Module): + """ + Base class for submodules. + + Attributes: + + - container - containing module + - gens - generators (subset of containing module) + - rank - rank of containing module + + Non-implemented methods: + + - _contains + - _syzygies + - _in_terms_of_generators + - _intersect + - _module_quotient + + Methods that likely need change in subclasses: + + - reduce_element + """ + + def __init__(self, gens, container): + Module.__init__(self, container.ring) + self.gens = tuple(container.convert(x) for x in gens) + self.container = container + self.rank = container.rank + self.ring = container.ring + self.dtype = container.dtype + + def __repr__(self): + return "<" + ", ".join(repr(x) for x in self.gens) + ">" + + def _contains(self, other): + """Implementation of containment. + Other is guaranteed to be FreeModuleElement.""" + raise NotImplementedError + + def _syzygies(self): + """Implementation of syzygy computation wrt self generators.""" + raise NotImplementedError + + def _in_terms_of_generators(self, e): + """Implementation of expression in terms of generators.""" + raise NotImplementedError + + def convert(self, elem, M=None): + """ + Convert ``elem`` into the internal represantition. + + Mostly called implicitly. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> M = QQ.old_poly_ring(x).free_module(2).submodule([1, x]) + >>> M.convert([2, 2*x]) + [2, 2*x] + """ + if isinstance(elem, self.container.dtype) and elem.module is self: + return elem + r = copy(self.container.convert(elem, M)) + r.module = self + if not self._contains(r): + raise CoercionFailed + return r + + def _intersect(self, other): + """Implementation of intersection. + Other is guaranteed to be a submodule of same free module.""" + raise NotImplementedError + + def _module_quotient(self, other): + """Implementation of quotient. + Other is guaranteed to be a submodule of same free module.""" + raise NotImplementedError + + def intersect(self, other, **options): + """ + Returns the intersection of ``self`` with submodule ``other``. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x, y).free_module(2) + >>> F.submodule([x, x]).intersect(F.submodule([y, y])) + <[x*y, x*y]> + + Some implementation allow further options to be passed. Currently, to + only one implemented is ``relations=True``, in which case the function + will return a triple ``(res, rela, relb)``, where ``res`` is the + intersection module, and ``rela`` and ``relb`` are lists of coefficient + vectors, expressing the generators of ``res`` in terms of the + generators of ``self`` (``rela``) and ``other`` (``relb``). + + >>> F.submodule([x, x]).intersect(F.submodule([y, y]), relations=True) + (<[x*y, x*y]>, [(DMP_Python([[1, 0]], QQ),)], [(DMP_Python([[1], []], QQ),)]) + + The above result says: the intersection module is generated by the + single element `(-xy, -xy) = -y (x, x) = -x (y, y)`, where + `(x, x)` and `(y, y)` respectively are the unique generators of + the two modules being intersected. + """ + if not isinstance(other, SubModule): + raise TypeError('%s is not a SubModule' % other) + if other.container != self.container: + raise ValueError( + '%s is contained in a different free module' % other) + return self._intersect(other, **options) + + def module_quotient(self, other, **options): + r""" + Returns the module quotient of ``self`` by submodule ``other``. + + That is, if ``self`` is the module `M` and ``other`` is `N`, then + return the ideal `\{f \in R | fN \subset M\}`. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.abc import x, y + >>> F = QQ.old_poly_ring(x, y).free_module(2) + >>> S = F.submodule([x*y, x*y]) + >>> T = F.submodule([x, x]) + >>> S.module_quotient(T) + + + Some implementations allow further options to be passed. Currently, the + only one implemented is ``relations=True``, which may only be passed + if ``other`` is principal. In this case the function + will return a pair ``(res, rel)`` where ``res`` is the ideal, and + ``rel`` is a list of coefficient vectors, expressing the generators of + the ideal, multiplied by the generator of ``other`` in terms of + generators of ``self``. + + >>> S.module_quotient(T, relations=True) + (, [[DMP_Python([[1]], QQ)]]) + + This means that the quotient ideal is generated by the single element + `y`, and that `y (x, x) = 1 (xy, xy)`, `(x, x)` and `(xy, xy)` being + the generators of `T` and `S`, respectively. + """ + if not isinstance(other, SubModule): + raise TypeError('%s is not a SubModule' % other) + if other.container != self.container: + raise ValueError( + '%s is contained in a different free module' % other) + return self._module_quotient(other, **options) + + def union(self, other): + """ + Returns the module generated by the union of ``self`` and ``other``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(1) + >>> M = F.submodule([x**2 + x]) # + >>> N = F.submodule([x**2 - 1]) # <(x-1)(x+1)> + >>> M.union(N) == F.submodule([x+1]) + True + """ + if not isinstance(other, SubModule): + raise TypeError('%s is not a SubModule' % other) + if other.container != self.container: + raise ValueError( + '%s is contained in a different free module' % other) + return self.__class__(self.gens + other.gens, self.container) + + def is_zero(self): + """ + Return True if ``self`` is a zero module. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> F.submodule([x, 1]).is_zero() + False + >>> F.submodule([0, 0]).is_zero() + True + """ + return all(x == 0 for x in self.gens) + + def submodule(self, *gens): + """ + Generate a submodule. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> M = QQ.old_poly_ring(x).free_module(2).submodule([x, 1]) + >>> M.submodule([x**2, x]) + <[x**2, x]> + """ + if not self.subset(gens): + raise ValueError('%s not a subset of %s' % (gens, self)) + return self.__class__(gens, self.container) + + def is_full_module(self): + """ + Return True if ``self`` is the entire free module. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> F.submodule([x, 1]).is_full_module() + False + >>> F.submodule([1, 1], [1, 2]).is_full_module() + True + """ + return all(self.contains(x) for x in self.container.basis()) + + def is_submodule(self, other): + """ + Returns True if ``other`` is a submodule of ``self``. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> M = F.submodule([2, x]) + >>> N = M.submodule([2*x, x**2]) + >>> M.is_submodule(M) + True + >>> M.is_submodule(N) + True + >>> N.is_submodule(M) + False + """ + if isinstance(other, SubModule): + return self.container == other.container and \ + all(self.contains(x) for x in other.gens) + if isinstance(other, (FreeModule, QuotientModule)): + return self.container == other and self.is_full_module() + return False + + def syzygy_module(self, **opts): + r""" + Compute the syzygy module of the generators of ``self``. + + Suppose `M` is generated by `f_1, \ldots, f_n` over the ring + `R`. Consider the homomorphism `\phi: R^n \to M`, given by + sending `(r_1, \ldots, r_n) \to r_1 f_1 + \cdots + r_n f_n`. + The syzygy module is defined to be the kernel of `\phi`. + + Examples + ======== + + The syzygy module is zero iff the generators generate freely a free + submodule: + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2).submodule([1, 0], [1, 1]).syzygy_module().is_zero() + True + + A slightly more interesting example: + + >>> M = QQ.old_poly_ring(x, y).free_module(2).submodule([x, 2*x], [y, 2*y]) + >>> S = QQ.old_poly_ring(x, y).free_module(2).submodule([y, -x]) + >>> M.syzygy_module() == S + True + """ + F = self.ring.free_module(len(self.gens)) + # NOTE we filter out zero syzygies. This is for convenience of the + # _syzygies function and not meant to replace any real "generating set + # reduction" algorithm + return F.submodule(*[x for x in self._syzygies() if F.convert(x) != 0], + **opts) + + def in_terms_of_generators(self, e): + """ + Express element ``e`` of ``self`` in terms of the generators. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> M = F.submodule([1, 0], [1, 1]) + >>> M.in_terms_of_generators([x, x**2]) # doctest: +SKIP + [DMP_Python([-1, 1, 0], QQ), DMP_Python([1, 0, 0], QQ)] + """ + try: + e = self.convert(e) + except CoercionFailed: + raise ValueError('%s is not an element of %s' % (e, self)) + return self._in_terms_of_generators(e) + + def reduce_element(self, x): + """ + Reduce the element ``x`` of our ring modulo the ideal ``self``. + + Here "reduce" has no specific meaning, it could return a unique normal + form, simplify the expression a bit, or just do nothing. + """ + return x + + def quotient_module(self, other, **opts): + """ + Return a quotient module. + + This is the same as taking a submodule of a quotient of the containing + module. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> S1 = F.submodule([x, 1]) + >>> S2 = F.submodule([x**2, x]) + >>> S1.quotient_module(S2) + <[x, 1] + <[x**2, x]>> + + Or more coincisely, using the overloaded division operator: + + >>> F.submodule([x, 1]) / [(x**2, x)] + <[x, 1] + <[x**2, x]>> + """ + if not self.is_submodule(other): + raise ValueError('%s not a submodule of %s' % (other, self)) + return SubQuotientModule(self.gens, + self.container.quotient_module(other), **opts) + + def __add__(self, oth): + return self.container.quotient_module(self).convert(oth) + + __radd__ = __add__ + + def multiply_ideal(self, I): + """ + Multiply ``self`` by the ideal ``I``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> I = QQ.old_poly_ring(x).ideal(x**2) + >>> M = QQ.old_poly_ring(x).free_module(2).submodule([1, 1]) + >>> I*M + <[x**2, x**2]> + """ + return self.submodule(*[x*g for [x] in I._module.gens for g in self.gens]) + + def inclusion_hom(self): + """ + Return a homomorphism representing the inclusion map of ``self``. + + That is, the natural map from ``self`` to ``self.container``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2).submodule([x, x]).inclusion_hom() + Matrix([ + [1, 0], : <[x, x]> -> QQ[x]**2 + [0, 1]]) + """ + return self.container.identity_hom().restrict_domain(self) + + def identity_hom(self): + """ + Return the identity homomorphism on ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2).submodule([x, x]).identity_hom() + Matrix([ + [1, 0], : <[x, x]> -> <[x, x]> + [0, 1]]) + """ + return self.container.identity_hom().restrict_domain( + self).restrict_codomain(self) + + +class SubQuotientModule(SubModule): + """ + Submodule of a quotient module. + + Equivalently, quotient module of a submodule. + + Do not instantiate this, instead use the submodule or quotient_module + constructing methods: + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> S = F.submodule([1, 0], [1, x]) + >>> Q = F/[(1, 0)] + >>> S/[(1, 0)] == Q.submodule([5, x]) + True + + Attributes: + + - base - base module we are quotient of + - killed_module - submodule used to form the quotient + """ + def __init__(self, gens, container, **opts): + SubModule.__init__(self, gens, container) + self.killed_module = self.container.killed_module + # XXX it is important for some code below that the generators of base + # are in this particular order! + self.base = self.container.base.submodule( + *[x.data for x in self.gens], **opts).union(self.killed_module) + + def _contains(self, elem): + return self.base.contains(elem.data) + + def _syzygies(self): + # let N = self.killed_module be generated by e_1, ..., e_r + # let F = self.base be generated by f_1, ..., f_s and e_1, ..., e_r + # Then self = F/N. + # Let phi: R**s --> self be the evident surjection. + # Similarly psi: R**(s + r) --> F. + # We need to find generators for ker(phi). Let chi: R**s --> F be the + # evident lift of phi. For X in R**s, phi(X) = 0 iff chi(X) is + # contained in N, iff there exists Y in R**r such that + # psi(X, Y) = 0. + # Hence if alpha: R**(s + r) --> R**s is the projection map, then + # ker(phi) = alpha ker(psi). + return [X[:len(self.gens)] for X in self.base._syzygies()] + + def _in_terms_of_generators(self, e): + return self.base._in_terms_of_generators(e.data)[:len(self.gens)] + + def is_full_module(self): + """ + Return True if ``self`` is the entire free module. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> F.submodule([x, 1]).is_full_module() + False + >>> F.submodule([1, 1], [1, 2]).is_full_module() + True + """ + return self.base.is_full_module() + + def quotient_hom(self): + """ + Return the quotient homomorphism to self. + + That is, return the natural map from ``self.base`` to ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> M = (QQ.old_poly_ring(x).free_module(2) / [(1, x)]).submodule([1, 0]) + >>> M.quotient_hom() + Matrix([ + [1, 0], : <[1, 0], [1, x]> -> <[1, 0] + <[1, x]>, [1, x] + <[1, x]>> + [0, 1]]) + """ + return self.base.identity_hom().quotient_codomain(self.killed_module) + + +_subs0 = lambda x: x[0] +_subs1 = lambda x: x[1:] + + +class ModuleOrder(ProductOrder): + """A product monomial order with a zeroth term as module index.""" + + def __init__(self, o1, o2, TOP): + if TOP: + ProductOrder.__init__(self, (o2, _subs1), (o1, _subs0)) + else: + ProductOrder.__init__(self, (o1, _subs0), (o2, _subs1)) + + +class SubModulePolyRing(SubModule): + """ + Submodule of a free module over a generalized polynomial ring. + + Do not instantiate this, use the constructor method of FreeModule instead: + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x, y).free_module(2) + >>> F.submodule([x, y], [1, 0]) + <[x, y], [1, 0]> + + Attributes: + + - order - monomial order used + """ + + #self._gb - cached groebner basis + #self._gbe - cached groebner basis relations + + def __init__(self, gens, container, order="lex", TOP=True): + SubModule.__init__(self, gens, container) + if not isinstance(container, FreeModulePolyRing): + raise NotImplementedError('This implementation is for submodules of ' + + 'FreeModulePolyRing, got %s' % container) + self.order = ModuleOrder(monomial_key(order), self.ring.order, TOP) + self._gb = None + self._gbe = None + + def __eq__(self, other): + if isinstance(other, SubModulePolyRing) and self.order != other.order: + return False + return SubModule.__eq__(self, other) + + def _groebner(self, extended=False): + """Returns a standard basis in sdm form.""" + from sympy.polys.distributedmodules import sdm_groebner, sdm_nf_mora + if self._gbe is None and extended: + gb, gbe = sdm_groebner( + [self.ring._vector_to_sdm(x, self.order) for x in self.gens], + sdm_nf_mora, self.order, self.ring.dom, extended=True) + self._gb, self._gbe = tuple(gb), tuple(gbe) + if self._gb is None: + self._gb = tuple(sdm_groebner( + [self.ring._vector_to_sdm(x, self.order) for x in self.gens], + sdm_nf_mora, self.order, self.ring.dom)) + if extended: + return self._gb, self._gbe + else: + return self._gb + + def _groebner_vec(self, extended=False): + """Returns a standard basis in element form.""" + if not extended: + return [FreeModuleElement(self, + tuple(self.ring._sdm_to_vector(x, self.rank))) + for x in self._groebner()] + gb, gbe = self._groebner(extended=True) + return ([self.convert(self.ring._sdm_to_vector(x, self.rank)) + for x in gb], + [self.ring._sdm_to_vector(x, len(self.gens)) for x in gbe]) + + def _contains(self, x): + from sympy.polys.distributedmodules import sdm_zero, sdm_nf_mora + return sdm_nf_mora(self.ring._vector_to_sdm(x, self.order), + self._groebner(), self.order, self.ring.dom) == \ + sdm_zero() + + def _syzygies(self): + """Compute syzygies. See [SCA, algorithm 2.5.4].""" + # NOTE if self.gens is a standard basis, this can be done more + # efficiently using Schreyer's theorem + + # First bullet point + k = len(self.gens) + r = self.rank + zero = self.ring.convert(0) + one = self.ring.convert(1) + Rkr = self.ring.free_module(r + k) + newgens = [] + for j, f in enumerate(self.gens): + m = [0]*(r + k) + for i, v in enumerate(f): + m[i] = v + for i in range(k): + m[r + i] = one if j == i else zero + m = FreeModuleElement(Rkr, tuple(m)) + newgens.append(m) + # Note: we need *descending* order on module index, and TOP=False to + # get an elimination order + F = Rkr.submodule(*newgens, order='ilex', TOP=False) + + # Second bullet point: standard basis of F + G = F._groebner_vec() + + # Third bullet point: G0 = G intersect the new k components + G0 = [x[r:] for x in G if all(y == zero for y in x[:r])] + + # Fourth and fifth bullet points: we are done + return G0 + + def _in_terms_of_generators(self, e): + """Expression in terms of generators. See [SCA, 2.8.1].""" + # NOTE: if gens is a standard basis, this can be done more efficiently + M = self.ring.free_module(self.rank).submodule(*((e,) + self.gens)) + S = M.syzygy_module( + order="ilex", TOP=False) # We want decreasing order! + G = S._groebner_vec() + # This list cannot not be empty since e is an element + e = [x for x in G if self.ring.is_unit(x[0])][0] + return [-x/e[0] for x in e[1:]] + + def reduce_element(self, x, NF=None): + """ + Reduce the element ``x`` of our container modulo ``self``. + + This applies the normal form ``NF`` to ``x``. If ``NF`` is passed + as none, the default Mora normal form is used (which is not unique!). + """ + from sympy.polys.distributedmodules import sdm_nf_mora + if NF is None: + NF = sdm_nf_mora + return self.container.convert(self.ring._sdm_to_vector(NF( + self.ring._vector_to_sdm(x, self.order), self._groebner(), + self.order, self.ring.dom), + self.rank)) + + def _intersect(self, other, relations=False): + # See: [SCA, section 2.8.2] + fi = self.gens + hi = other.gens + r = self.rank + ci = [[0]*(2*r) for _ in range(r)] + for k in range(r): + ci[k][k] = 1 + ci[k][r + k] = 1 + di = [list(f) + [0]*r for f in fi] + ei = [[0]*r + list(h) for h in hi] + syz = self.ring.free_module(2*r).submodule(*(ci + di + ei))._syzygies() + nonzero = [x for x in syz if any(y != self.ring.zero for y in x[:r])] + res = self.container.submodule(*([-y for y in x[:r]] for x in nonzero)) + reln1 = [x[r:r + len(fi)] for x in nonzero] + reln2 = [x[r + len(fi):] for x in nonzero] + if relations: + return res, reln1, reln2 + return res + + def _module_quotient(self, other, relations=False): + # See: [SCA, section 2.8.4] + if relations and len(other.gens) != 1: + raise NotImplementedError + if len(other.gens) == 0: + return self.ring.ideal(1) + elif len(other.gens) == 1: + # We do some trickery. Let f be the (vector!) generating ``other`` + # and f1, .., fn be the (vectors) generating self. + # Consider the submodule of R^{r+1} generated by (f, 1) and + # {(fi, 0) | i}. Then the intersection with the last module + # component yields the quotient. + g1 = list(other.gens[0]) + [1] + gi = [list(x) + [0] for x in self.gens] + # NOTE: We *need* to use an elimination order + M = self.ring.free_module(self.rank + 1).submodule(*([g1] + gi), + order='ilex', TOP=False) + if not relations: + return self.ring.ideal(*[x[-1] for x in M._groebner_vec() if + all(y == self.ring.zero for y in x[:-1])]) + else: + G, R = M._groebner_vec(extended=True) + indices = [i for i, x in enumerate(G) if + all(y == self.ring.zero for y in x[:-1])] + return (self.ring.ideal(*[G[i][-1] for i in indices]), + [[-x for x in R[i][1:]] for i in indices]) + # For more generators, we use I : = intersection of + # {I : | i} + # TODO this can be done more efficiently + return reduce(lambda x, y: x.intersect(y), + (self._module_quotient(self.container.submodule(x)) for x in other.gens)) + + +class SubModuleQuotientRing(SubModule): + """ + Class for submodules of free modules over quotient rings. + + Do not instantiate this. Instead use the submodule methods. + + >>> from sympy.abc import x, y + >>> from sympy import QQ + >>> M = (QQ.old_poly_ring(x, y)/[x**2 - y**2]).free_module(2).submodule([x, x + y]) + >>> M + <[x + , x + y + ]> + >>> M.contains([y**2, x**2 + x*y]) + True + >>> M.contains([x, y]) + False + + Attributes: + + - quot - the subquotient of `R^n/IR^n` generated by lifts of our generators + """ + + def __init__(self, gens, container): + SubModule.__init__(self, gens, container) + self.quot = self.container.quot.submodule( + *[self.container.lift(x) for x in self.gens]) + + def _contains(self, elem): + return self.quot._contains(self.container.lift(elem)) + + def _syzygies(self): + return [tuple(self.ring.convert(y, self.quot.ring) for y in x) + for x in self.quot._syzygies()] + + def _in_terms_of_generators(self, elem): + return [self.ring.convert(x, self.quot.ring) for x in + self.quot._in_terms_of_generators(self.container.lift(elem))] + +########################################################################## +## Quotient Modules ###################################################### +########################################################################## + + +class QuotientModuleElement(ModuleElement): + """Element of a quotient module.""" + + def eq(self, d1, d2): + """Equality comparison.""" + return self.module.killed_module.contains(d1 - d2) + + def __repr__(self): + return repr(self.data) + " + " + repr(self.module.killed_module) + + +class QuotientModule(Module): + """ + Class for quotient modules. + + Do not instantiate this directly. For subquotients, see the + SubQuotientModule class. + + Attributes: + + - base - the base module we are a quotient of + - killed_module - the submodule used to form the quotient + - rank of the base + """ + + dtype = QuotientModuleElement + + def __init__(self, ring, base, submodule): + Module.__init__(self, ring) + if not base.is_submodule(submodule): + raise ValueError('%s is not a submodule of %s' % (submodule, base)) + self.base = base + self.killed_module = submodule + self.rank = base.rank + + def __repr__(self): + return repr(self.base) + "/" + repr(self.killed_module) + + def is_zero(self): + """ + Return True if ``self`` is a zero module. + + This happens if and only if the base module is the same as the + submodule being killed. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) + >>> (F/[(1, 0)]).is_zero() + False + >>> (F/[(1, 0), (0, 1)]).is_zero() + True + """ + return self.base == self.killed_module + + def is_submodule(self, other): + """ + Return True if ``other`` is a submodule of ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> Q = QQ.old_poly_ring(x).free_module(2) / [(x, x)] + >>> S = Q.submodule([1, 0]) + >>> Q.is_submodule(S) + True + >>> S.is_submodule(Q) + False + """ + if isinstance(other, QuotientModule): + return self.killed_module == other.killed_module and \ + self.base.is_submodule(other.base) + if isinstance(other, SubQuotientModule): + return other.container == self + return False + + def submodule(self, *gens, **opts): + """ + Generate a submodule. + + This is the same as taking a quotient of a submodule of the base + module. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> Q = QQ.old_poly_ring(x).free_module(2) / [(x, x)] + >>> Q.submodule([x, 0]) + <[x, 0] + <[x, x]>> + """ + return SubQuotientModule(gens, self, **opts) + + def convert(self, elem, M=None): + """ + Convert ``elem`` into the internal representation. + + This method is called implicitly whenever computations involve elements + not in the internal representation. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> F = QQ.old_poly_ring(x).free_module(2) / [(1, 2), (1, x)] + >>> F.convert([1, 0]) + [1, 0] + <[1, 2], [1, x]> + """ + if isinstance(elem, QuotientModuleElement): + if elem.module is self: + return elem + if self.killed_module.is_submodule(elem.module.killed_module): + return QuotientModuleElement(self, self.base.convert(elem.data)) + raise CoercionFailed + return QuotientModuleElement(self, self.base.convert(elem)) + + def identity_hom(self): + """ + Return the identity homomorphism on ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> M = QQ.old_poly_ring(x).free_module(2) / [(1, 2), (1, x)] + >>> M.identity_hom() + Matrix([ + [1, 0], : QQ[x]**2/<[1, 2], [1, x]> -> QQ[x]**2/<[1, 2], [1, x]> + [0, 1]]) + """ + return self.base.identity_hom().quotient_codomain( + self.killed_module).quotient_domain(self.killed_module) + + def quotient_hom(self): + """ + Return the quotient homomorphism to ``self``. + + That is, return a homomorphism representing the natural map from + ``self.base`` to ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> M = QQ.old_poly_ring(x).free_module(2) / [(1, 2), (1, x)] + >>> M.quotient_hom() + Matrix([ + [1, 0], : QQ[x]**2 -> QQ[x]**2/<[1, 2], [1, x]> + [0, 1]]) + """ + return self.base.identity_hom().quotient_codomain( + self.killed_module) diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/__init__.py b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9effa3d0009319395ec1591e1f378fc791a316b5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_extensions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_extensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e87944701e8e3bc1bf90c0da48ff4773f069cfbc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_extensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_homomorphisms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_homomorphisms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b82ca784bcb35a1e3624edaf014f43ed6d415328 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_homomorphisms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_ideals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_ideals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..702f00c54047b3bf6f5b35a48fe6f75c8ee3b539 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_ideals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_modules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dc0d47f22c19de7771ac1334e9316943d760c0b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/agca/tests/__pycache__/test_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_extensions.py b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..4becf4fd800a7a34c16989adaaf97e312c18f01c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_extensions.py @@ -0,0 +1,196 @@ +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys import QQ, ZZ +from sympy.polys.polytools import Poly +from sympy.polys.polyerrors import NotInvertible +from sympy.polys.agca.extensions import FiniteExtension +from sympy.polys.domainmatrix import DomainMatrix + +from sympy.testing.pytest import raises + +from sympy.abc import x, y, t + + +def test_FiniteExtension(): + # Gaussian integers + A = FiniteExtension(Poly(x**2 + 1, x)) + assert A.rank == 2 + assert str(A) == 'ZZ[x]/(x**2 + 1)' + i = A.generator + assert i.parent() is A + + assert i*i == A(-1) + raises(TypeError, lambda: i*()) + + assert A.basis == (A.one, i) + assert A(1) == A.one + assert i**2 == A(-1) + assert i**2 != -1 # no coercion + assert (2 + i)*(1 - i) == 3 - i + assert (1 + i)**8 == A(16) + assert A(1).inverse() == A(1) + raises(NotImplementedError, lambda: A(2).inverse()) + + # Finite field of order 27 + F = FiniteExtension(Poly(x**3 - x + 1, x, modulus=3)) + assert F.rank == 3 + a = F.generator # also generates the cyclic group F - {0} + assert F.basis == (F(1), a, a**2) + assert a**27 == a + assert a**26 == F(1) + assert a**13 == F(-1) + assert a**9 == a + 1 + assert a**3 == a - 1 + assert a**6 == a**2 + a + 1 + assert F(x**2 + x).inverse() == 1 - a + assert F(x + 2)**(-1) == F(x + 2).inverse() + assert a**19 * a**(-19) == F(1) + assert (a - 1) / (2*a**2 - 1) == a**2 + 1 + assert (a - 1) // (2*a**2 - 1) == a**2 + 1 + assert 2/(a**2 + 1) == a**2 - a + 1 + assert (a**2 + 1)/2 == -a**2 - 1 + raises(NotInvertible, lambda: F(0).inverse()) + + # Function field of an elliptic curve + K = FiniteExtension(Poly(t**2 - x**3 - x + 1, t, field=True)) + assert K.rank == 2 + assert str(K) == 'ZZ(x)[t]/(t**2 - x**3 - x + 1)' + y = K.generator + c = 1/(x**3 - x**2 + x - 1) + assert ((y + x)*(y - x)).inverse() == K(c) + assert (y + x)*(y - x)*c == K(1) # explicit inverse of y + x + + +def test_FiniteExtension_eq_hash(): + # Test eq and hash + p1 = Poly(x**2 - 2, x, domain=ZZ) + p2 = Poly(x**2 - 2, x, domain=QQ) + K1 = FiniteExtension(p1) + K2 = FiniteExtension(p2) + assert K1 == FiniteExtension(Poly(x**2 - 2)) + assert K2 != FiniteExtension(Poly(x**2 - 2)) + assert len({K1, K2, FiniteExtension(p1)}) == 2 + + +def test_FiniteExtension_mod(): + # Test mod + K = FiniteExtension(Poly(x**3 + 1, x, domain=QQ)) + xf = K(x) + assert (xf**2 - 1) % 1 == K.zero + assert 1 % (xf**2 - 1) == K.zero + assert (xf**2 - 1) / (xf - 1) == xf + 1 + assert (xf**2 - 1) // (xf - 1) == xf + 1 + assert (xf**2 - 1) % (xf - 1) == K.zero + raises(ZeroDivisionError, lambda: (xf**2 - 1) % 0) + raises(TypeError, lambda: xf % []) + raises(TypeError, lambda: [] % xf) + + # Test mod over ring + K = FiniteExtension(Poly(x**3 + 1, x, domain=ZZ)) + xf = K(x) + assert (xf**2 - 1) % 1 == K.zero + raises(NotImplementedError, lambda: (xf**2 - 1) % (xf - 1)) + + +def test_FiniteExtension_from_sympy(): + # Test to_sympy/from_sympy + K = FiniteExtension(Poly(x**3 + 1, x, domain=ZZ)) + xf = K(x) + assert K.from_sympy(x) == xf + assert K.to_sympy(xf) == x + + +def test_FiniteExtension_set_domain(): + KZ = FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')) + KQ = FiniteExtension(Poly(x**2 + 1, x, domain='QQ')) + assert KZ.set_domain(QQ) == KQ + + +def test_FiniteExtension_exquo(): + # Test exquo + K = FiniteExtension(Poly(x**4 + 1)) + xf = K(x) + assert K.exquo(xf**2 - 1, xf - 1) == xf + 1 + + +def test_FiniteExtension_convert(): + # Test from_MonogenicFiniteExtension + K1 = FiniteExtension(Poly(x**2 + 1)) + K2 = QQ[x] + x1, x2 = K1(x), K2(x) + assert K1.convert(x2) == x1 + assert K2.convert(x1) == x2 + + K = FiniteExtension(Poly(x**2 - 1, domain=QQ)) + assert K.convert_from(QQ(1, 2), QQ) == K.one/2 + + +def test_FiniteExtension_division_ring(): + # Test division in FiniteExtension over a ring + KQ = FiniteExtension(Poly(x**2 - 1, x, domain=QQ)) + KZ = FiniteExtension(Poly(x**2 - 1, x, domain=ZZ)) + KQt = FiniteExtension(Poly(x**2 - 1, x, domain=QQ[t])) + KQtf = FiniteExtension(Poly(x**2 - 1, x, domain=QQ.frac_field(t))) + assert KQ.is_Field is True + assert KZ.is_Field is False + assert KQt.is_Field is False + assert KQtf.is_Field is True + for K in KQ, KZ, KQt, KQtf: + xK = K.convert(x) + assert xK / K.one == xK + assert xK // K.one == xK + assert xK % K.one == K.zero + raises(ZeroDivisionError, lambda: xK / K.zero) + raises(ZeroDivisionError, lambda: xK // K.zero) + raises(ZeroDivisionError, lambda: xK % K.zero) + if K.is_Field: + assert xK / xK == K.one + assert xK // xK == K.one + assert xK % xK == K.zero + else: + raises(NotImplementedError, lambda: xK / xK) + raises(NotImplementedError, lambda: xK // xK) + raises(NotImplementedError, lambda: xK % xK) + + +def test_FiniteExtension_Poly(): + K = FiniteExtension(Poly(x**2 - 2)) + p = Poly(x, y, domain=K) + assert p.domain == K + assert p.as_expr() == x + assert (p**2).as_expr() == 2 + + K = FiniteExtension(Poly(x**2 - 2, x, domain=QQ)) + K2 = FiniteExtension(Poly(t**2 - 2, t, domain=K)) + assert str(K2) == 'QQ[x]/(x**2 - 2)[t]/(t**2 - 2)' + + eK = K2.convert(x + t) + assert K2.to_sympy(eK) == x + t + assert K2.to_sympy(eK ** 2) == 4 + 2*x*t + p = Poly(x + t, y, domain=K2) + assert p**2 == Poly(4 + 2*x*t, y, domain=K2) + + +def test_FiniteExtension_sincos_jacobian(): + # Use FiniteExtensino to compute the Jacobian of a matrix involving sin + # and cos of different symbols. + r, p, t = symbols('rho, phi, theta') + elements = [ + [sin(p)*cos(t), r*cos(p)*cos(t), -r*sin(p)*sin(t)], + [sin(p)*sin(t), r*cos(p)*sin(t), r*sin(p)*cos(t)], + [ cos(p), -r*sin(p), 0], + ] + + def make_extension(K): + K = FiniteExtension(Poly(sin(p)**2+cos(p)**2-1, sin(p), domain=K[cos(p)])) + K = FiniteExtension(Poly(sin(t)**2+cos(t)**2-1, sin(t), domain=K[cos(t)])) + return K + + Ksc1 = make_extension(ZZ[r]) + Ksc2 = make_extension(ZZ)[r] + + for K in [Ksc1, Ksc2]: + elements_K = [[K.convert(e) for e in row] for row in elements] + J = DomainMatrix(elements_K, (3, 3), K) + det = J.charpoly()[-1] * (-K.one)**3 + assert det == K.convert(r**2*sin(p)) diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_homomorphisms.py b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_homomorphisms.py new file mode 100644 index 0000000000000000000000000000000000000000..2e63838e09ed9b9436a58a7d8041175e731bc4ef --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_homomorphisms.py @@ -0,0 +1,113 @@ +"""Tests for homomorphisms.""" + +from sympy.core.singleton import S +from sympy.polys.domains.rationalfield import QQ +from sympy.abc import x, y +from sympy.polys.agca import homomorphism +from sympy.testing.pytest import raises + + +def test_printing(): + R = QQ.old_poly_ring(x) + + assert str(homomorphism(R.free_module(1), R.free_module(1), [0])) == \ + 'Matrix([[0]]) : QQ[x]**1 -> QQ[x]**1' + assert str(homomorphism(R.free_module(2), R.free_module(2), [0, 0])) == \ + 'Matrix([ \n[0, 0], : QQ[x]**2 -> QQ[x]**2\n[0, 0]]) ' + assert str(homomorphism(R.free_module(1), R.free_module(1) / [[x]], [0])) == \ + 'Matrix([[0]]) : QQ[x]**1 -> QQ[x]**1/<[x]>' + assert str(R.free_module(0).identity_hom()) == 'Matrix(0, 0, []) : QQ[x]**0 -> QQ[x]**0' + +def test_operations(): + F = QQ.old_poly_ring(x).free_module(2) + G = QQ.old_poly_ring(x).free_module(3) + f = F.identity_hom() + g = homomorphism(F, F, [0, [1, x]]) + h = homomorphism(F, F, [[1, 0], 0]) + i = homomorphism(F, G, [[1, 0, 0], [0, 1, 0]]) + + assert f == f + assert f != g + assert f != i + assert (f != F.identity_hom()) is False + assert 2*f == f*2 == homomorphism(F, F, [[2, 0], [0, 2]]) + assert f/2 == homomorphism(F, F, [[S.Half, 0], [0, S.Half]]) + assert f + g == homomorphism(F, F, [[1, 0], [1, x + 1]]) + assert f - g == homomorphism(F, F, [[1, 0], [-1, 1 - x]]) + assert f*g == g == g*f + assert h*g == homomorphism(F, F, [0, [1, 0]]) + assert g*h == homomorphism(F, F, [0, 0]) + assert i*f == i + assert f([1, 2]) == [1, 2] + assert g([1, 2]) == [2, 2*x] + + assert i.restrict_domain(F.submodule([x, x]))([x, x]) == i([x, x]) + h1 = h.quotient_domain(F.submodule([0, 1])) + assert h1([1, 0]) == h([1, 0]) + assert h1.restrict_domain(h1.domain.submodule([x, 0]))([x, 0]) == h([x, 0]) + + raises(TypeError, lambda: f/g) + raises(TypeError, lambda: f + 1) + raises(TypeError, lambda: f + i) + raises(TypeError, lambda: f - 1) + raises(TypeError, lambda: f*i) + + +def test_creation(): + F = QQ.old_poly_ring(x).free_module(3) + G = QQ.old_poly_ring(x).free_module(2) + SM = F.submodule([1, 1, 1]) + Q = F / SM + SQ = Q.submodule([1, 0, 0]) + + matrix = [[1, 0], [0, 1], [-1, -1]] + h = homomorphism(F, G, matrix) + h2 = homomorphism(Q, G, matrix) + assert h.quotient_domain(SM) == h2 + raises(ValueError, lambda: h.quotient_domain(F.submodule([1, 0, 0]))) + assert h2.restrict_domain(SQ) == homomorphism(SQ, G, matrix) + raises(ValueError, lambda: h.restrict_domain(G)) + raises(ValueError, lambda: h.restrict_codomain(G.submodule([1, 0]))) + raises(ValueError, lambda: h.quotient_codomain(F)) + + im = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + for M in [F, SM, Q, SQ]: + assert M.identity_hom() == homomorphism(M, M, im) + assert SM.inclusion_hom() == homomorphism(SM, F, im) + assert SQ.inclusion_hom() == homomorphism(SQ, Q, im) + assert Q.quotient_hom() == homomorphism(F, Q, im) + assert SQ.quotient_hom() == homomorphism(SQ.base, SQ, im) + + class conv: + def convert(x, y=None): + return x + + class dummy: + container = conv() + + def submodule(*args): + return None + raises(TypeError, lambda: homomorphism(dummy(), G, matrix)) + raises(TypeError, lambda: homomorphism(F, dummy(), matrix)) + raises( + ValueError, lambda: homomorphism(QQ.old_poly_ring(x, y).free_module(3), G, matrix)) + raises(ValueError, lambda: homomorphism(F, G, [0, 0])) + + +def test_properties(): + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + h = homomorphism(F, F, [[x, 0], [y, 0]]) + assert h.kernel() == F.submodule([-y, x]) + assert h.image() == F.submodule([x, 0], [y, 0]) + assert not h.is_injective() + assert not h.is_surjective() + assert h.restrict_codomain(h.image()).is_surjective() + assert h.restrict_domain(F.submodule([1, 0])).is_injective() + assert h.quotient_domain( + h.kernel()).restrict_codomain(h.image()).is_isomorphism() + + R2 = QQ.old_poly_ring(x, y, order=(("lex", x), ("ilex", y))) / [x**2 + 1] + F = R2.free_module(2) + h = homomorphism(F, F, [[x, 0], [y, y + 1]]) + assert h.is_isomorphism() diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_ideals.py b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_ideals.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fff0674b54a22e2a5acba5110d62d96a877074 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_ideals.py @@ -0,0 +1,131 @@ +"""Test ideals.py code.""" + +from sympy.polys import QQ, ilex +from sympy.abc import x, y, z +from sympy.testing.pytest import raises + + +def test_ideal_operations(): + R = QQ.old_poly_ring(x, y) + I = R.ideal(x) + J = R.ideal(y) + S = R.ideal(x*y) + T = R.ideal(x, y) + + assert not (I == J) + assert I == I + + assert I.union(J) == T + assert I + J == T + assert I + T == T + + assert not I.subset(T) + assert T.subset(I) + + assert I.product(J) == S + assert I*J == S + assert x*J == S + assert I*y == S + assert R.convert(x)*J == S + assert I*R.convert(y) == S + + assert not I.is_zero() + assert not J.is_whole_ring() + + assert R.ideal(x**2 + 1, x).is_whole_ring() + assert R.ideal() == R.ideal(0) + assert R.ideal().is_zero() + + assert T.contains(x*y) + assert T.subset([x, y]) + + assert T.in_terms_of_generators(x) == [R(1), R(0)] + + assert T**0 == R.ideal(1) + assert T**1 == T + assert T**2 == R.ideal(x**2, y**2, x*y) + assert I**5 == R.ideal(x**5) + + +def test_exceptions(): + I = QQ.old_poly_ring(x).ideal(x) + J = QQ.old_poly_ring(y).ideal(1) + raises(ValueError, lambda: I.union(x)) + raises(ValueError, lambda: I + J) + raises(ValueError, lambda: I * J) + raises(ValueError, lambda: I.union(J)) + assert (I == J) is False + assert I != J + + +def test_nontriv_global(): + R = QQ.old_poly_ring(x, y, z) + + def contains(I, f): + return R.ideal(*I).contains(f) + + assert contains([x, y], x) + assert contains([x, y], x + y) + assert not contains([x, y], 1) + assert not contains([x, y], z) + assert contains([x**2 + y, x**2 + x], x - y) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x**2) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**3) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**4) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x*y**2) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**4 + y**3 + 2*z*y*x) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x*y*z) + assert contains([x, 1 + x + y, 5 - 7*y], 1) + assert contains( + [x**3 + y**3, y**3 + z**3, z**3 + x**3, x**2*y + x**2*z + y**2*z], + x**3) + assert not contains( + [x**3 + y**3, y**3 + z**3, z**3 + x**3, x**2*y + x**2*z + y**2*z], + x**2 + y**2) + + # compare local order + assert not contains([x*(1 + x + y), y*(1 + z)], x) + assert not contains([x*(1 + x + y), y*(1 + z)], x + y) + + +def test_nontriv_local(): + R = QQ.old_poly_ring(x, y, z, order=ilex) + + def contains(I, f): + return R.ideal(*I).contains(f) + + assert contains([x, y], x) + assert contains([x, y], x + y) + assert not contains([x, y], 1) + assert not contains([x, y], z) + assert contains([x**2 + y, x**2 + x], x - y) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x**2) + assert contains([x*(1 + x + y), y*(1 + z)], x) + assert contains([x*(1 + x + y), y*(1 + z)], x + y) + + +def test_intersection(): + R = QQ.old_poly_ring(x, y, z) + # SCA, example 1.8.11 + assert R.ideal(x, y).intersect(R.ideal(y**2, z)) == R.ideal(y**2, y*z, x*z) + + assert R.ideal(x, y).intersect(R.ideal()).is_zero() + + R = QQ.old_poly_ring(x, y, z, order="ilex") + assert R.ideal(x, y).intersect(R.ideal(y**2 + y**2*z, z + z*x**3*y)) == \ + R.ideal(y**2, y*z, x*z) + + +def test_quotient(): + # SCA, example 1.8.13 + R = QQ.old_poly_ring(x, y, z) + assert R.ideal(x, y).quotient(R.ideal(y**2, z)) == R.ideal(x, y) + + +def test_reduction(): + from sympy.polys.distributedmodules import sdm_nf_buchberger_reduced + R = QQ.old_poly_ring(x, y) + I = R.ideal(x**5, y) + e = R.convert(x**3 + y**2) + assert I.reduce_element(e) == e + assert I.reduce_element(e, NF=sdm_nf_buchberger_reduced) == R.convert(x**3) diff --git a/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_modules.py b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..29c2d4ce45f452f6f61420654be64a67d13b396b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/agca/tests/test_modules.py @@ -0,0 +1,408 @@ +"""Test modules.py code.""" + +from sympy.polys.agca.modules import FreeModule, ModuleOrder, FreeModulePolyRing +from sympy.polys import CoercionFailed, QQ, lex, grlex, ilex, ZZ +from sympy.abc import x, y, z +from sympy.testing.pytest import raises +from sympy.core.numbers import Rational + + +def test_FreeModuleElement(): + M = QQ.old_poly_ring(x).free_module(3) + e = M.convert([1, x, x**2]) + f = [QQ.old_poly_ring(x).convert(1), QQ.old_poly_ring(x).convert(x), QQ.old_poly_ring(x).convert(x**2)] + assert list(e) == f + assert f[0] == e[0] + assert f[1] == e[1] + assert f[2] == e[2] + raises(IndexError, lambda: e[3]) + + g = M.convert([x, 0, 0]) + assert e + g == M.convert([x + 1, x, x**2]) + assert f + g == M.convert([x + 1, x, x**2]) + assert -e == M.convert([-1, -x, -x**2]) + assert e - g == M.convert([1 - x, x, x**2]) + assert e != g + + assert M.convert([x, x, x]) / QQ.old_poly_ring(x).convert(x) == [1, 1, 1] + R = QQ.old_poly_ring(x, order="ilex") + assert R.free_module(1).convert([x]) / R.convert(x) == [1] + + +def test_FreeModule(): + M1 = FreeModule(QQ.old_poly_ring(x), 2) + assert M1 == FreeModule(QQ.old_poly_ring(x), 2) + assert M1 != FreeModule(QQ.old_poly_ring(y), 2) + assert M1 != FreeModule(QQ.old_poly_ring(x), 3) + M2 = FreeModule(QQ.old_poly_ring(x, order="ilex"), 2) + + assert [x, 1] in M1 + assert [x] not in M1 + assert [2, y] not in M1 + assert [1/(x + 1), 2] not in M1 + + e = M1.convert([x, x**2 + 1]) + X = QQ.old_poly_ring(x).convert(x) + assert e == [X, X**2 + 1] + assert e == [x, x**2 + 1] + assert 2*e == [2*x, 2*x**2 + 2] + assert e*2 == [2*x, 2*x**2 + 2] + assert e/2 == [x/2, (x**2 + 1)/2] + assert x*e == [x**2, x**3 + x] + assert e*x == [x**2, x**3 + x] + assert X*e == [x**2, x**3 + x] + assert e*X == [x**2, x**3 + x] + + assert [x, 1] in M2 + assert [x] not in M2 + assert [2, y] not in M2 + assert [1/(x + 1), 2] in M2 + + e = M2.convert([x, x**2 + 1]) + X = QQ.old_poly_ring(x, order="ilex").convert(x) + assert e == [X, X**2 + 1] + assert e == [x, x**2 + 1] + assert 2*e == [2*x, 2*x**2 + 2] + assert e*2 == [2*x, 2*x**2 + 2] + assert e/2 == [x/2, (x**2 + 1)/2] + assert x*e == [x**2, x**3 + x] + assert e*x == [x**2, x**3 + x] + assert e/(1 + x) == [x/(1 + x), (x**2 + 1)/(1 + x)] + assert X*e == [x**2, x**3 + x] + assert e*X == [x**2, x**3 + x] + + M3 = FreeModule(QQ.old_poly_ring(x, y), 2) + assert M3.convert(e) == M3.convert([x, x**2 + 1]) + + assert not M3.is_submodule(0) + assert not M3.is_zero() + + raises(NotImplementedError, lambda: ZZ.old_poly_ring(x).free_module(2)) + raises(NotImplementedError, lambda: FreeModulePolyRing(ZZ, 2)) + raises(CoercionFailed, lambda: M1.convert(QQ.old_poly_ring(x).free_module(3) + .convert([1, 2, 3]))) + raises(CoercionFailed, lambda: M3.convert(1)) + + +def test_ModuleOrder(): + o1 = ModuleOrder(lex, grlex, False) + o2 = ModuleOrder(ilex, lex, False) + + assert o1 == ModuleOrder(lex, grlex, False) + assert (o1 != ModuleOrder(lex, grlex, False)) is False + assert o1 != o2 + + assert o1((1, 2, 3)) == (1, (5, (2, 3))) + assert o2((1, 2, 3)) == (-1, (2, 3)) + + +def test_SubModulePolyRing_global(): + R = QQ.old_poly_ring(x, y) + F = R.free_module(3) + Fd = F.submodule([1, 0, 0], [1, 2, 0], [1, 2, 3]) + M = F.submodule([x**2 + y**2, 1, 0], [x, y, 1]) + + assert F == Fd + assert Fd == F + assert F != M + assert M != F + assert Fd != M + assert M != Fd + assert Fd == F.submodule(*F.basis()) + + assert Fd.is_full_module() + assert not M.is_full_module() + assert not Fd.is_zero() + assert not M.is_zero() + assert Fd.submodule().is_zero() + + assert M.contains([x**2 + y**2 + x, 1 + y, 1]) + assert not M.contains([x**2 + y**2 + x, 1 + y, 2]) + assert M.contains([y**2, 1 - x*y, -x]) + + assert not F.submodule([1 + x, 0, 0]) == F.submodule([1, 0, 0]) + assert F.submodule([1, 0, 0], [0, 1, 0]).union(F.submodule([0, 0, 1])) == F + assert not M.is_submodule(0) + + m = F.convert([x**2 + y**2, 1, 0]) + n = M.convert(m) + assert m.module is F + assert n.module is M + + raises(ValueError, lambda: M.submodule([1, 0, 0])) + raises(TypeError, lambda: M.union(1)) + raises(ValueError, lambda: M.union(R.free_module(1).submodule([x]))) + + assert F.submodule([x, x, x]) != F.submodule([x, x, x], order="ilex") + + +def test_SubModulePolyRing_local(): + R = QQ.old_poly_ring(x, y, order=ilex) + F = R.free_module(3) + Fd = F.submodule([1 + x, 0, 0], [1 + y, 2 + 2*y, 0], [1, 2, 3]) + M = F.submodule([x**2 + y**2, 1, 0], [x, y, 1]) + + assert F == Fd + assert Fd == F + assert F != M + assert M != F + assert Fd != M + assert M != Fd + assert Fd == F.submodule(*F.basis()) + + assert Fd.is_full_module() + assert not M.is_full_module() + assert not Fd.is_zero() + assert not M.is_zero() + assert Fd.submodule().is_zero() + + assert M.contains([x**2 + y**2 + x, 1 + y, 1]) + assert not M.contains([x**2 + y**2 + x, 1 + y, 2]) + assert M.contains([y**2, 1 - x*y, -x]) + + assert F.submodule([1 + x, 0, 0]) == F.submodule([1, 0, 0]) + assert F.submodule( + [1, 0, 0], [0, 1, 0]).union(F.submodule([0, 0, 1 + x*y])) == F + + raises(ValueError, lambda: M.submodule([1, 0, 0])) + + +def test_SubModulePolyRing_nontriv_global(): + R = QQ.old_poly_ring(x, y, z) + F = R.free_module(1) + + def contains(I, f): + return F.submodule(*[[g] for g in I]).contains([f]) + + assert contains([x, y], x) + assert contains([x, y], x + y) + assert not contains([x, y], 1) + assert not contains([x, y], z) + assert contains([x**2 + y, x**2 + x], x - y) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x**2) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**3) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**4) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x*y**2) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**4 + y**3 + 2*z*y*x) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x*y*z) + assert contains([x, 1 + x + y, 5 - 7*y], 1) + assert contains( + [x**3 + y**3, y**3 + z**3, z**3 + x**3, x**2*y + x**2*z + y**2*z], + x**3) + assert not contains( + [x**3 + y**3, y**3 + z**3, z**3 + x**3, x**2*y + x**2*z + y**2*z], + x**2 + y**2) + + # compare local order + assert not contains([x*(1 + x + y), y*(1 + z)], x) + assert not contains([x*(1 + x + y), y*(1 + z)], x + y) + + +def test_SubModulePolyRing_nontriv_local(): + R = QQ.old_poly_ring(x, y, z, order=ilex) + F = R.free_module(1) + + def contains(I, f): + return F.submodule(*[[g] for g in I]).contains([f]) + + assert contains([x, y], x) + assert contains([x, y], x + y) + assert not contains([x, y], 1) + assert not contains([x, y], z) + assert contains([x**2 + y, x**2 + x], x - y) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x**2) + assert contains([x*(1 + x + y), y*(1 + z)], x) + assert contains([x*(1 + x + y), y*(1 + z)], x + y) + + +def test_syzygy(): + R = QQ.old_poly_ring(x, y, z) + M = R.free_module(1).submodule([x*y], [y*z], [x*z]) + S = R.free_module(3).submodule([0, x, -y], [z, -x, 0]) + assert M.syzygy_module() == S + + M2 = M / ([x*y*z],) + S2 = R.free_module(3).submodule([z, 0, 0], [0, x, 0], [0, 0, y]) + assert M2.syzygy_module() == S2 + + F = R.free_module(3) + assert F.submodule(*F.basis()).syzygy_module() == F.submodule() + + R2 = QQ.old_poly_ring(x, y, z) / [x*y*z] + M3 = R2.free_module(1).submodule([x*y], [y*z], [x*z]) + S3 = R2.free_module(3).submodule([z, 0, 0], [0, x, 0], [0, 0, y]) + assert M3.syzygy_module() == S3 + + +def test_in_terms_of_generators(): + R = QQ.old_poly_ring(x, order="ilex") + M = R.free_module(2).submodule([2*x, 0], [1, 2]) + assert M.in_terms_of_generators( + [x, x]) == [R.convert(Rational(1, 4)), R.convert(x/2)] + raises(ValueError, lambda: M.in_terms_of_generators([1, 0])) + + M = R.free_module(2) / ([x, 0], [1, 1]) + SM = M.submodule([1, x]) + assert SM.in_terms_of_generators([2, 0]) == [R.convert(-2/(x - 1))] + + R = QQ.old_poly_ring(x, y) / [x**2 - y**2] + M = R.free_module(2) + SM = M.submodule([x, 0], [0, y]) + assert SM.in_terms_of_generators( + [x**2, x**2]) == [R.convert(x), R.convert(y)] + + +def test_QuotientModuleElement(): + R = QQ.old_poly_ring(x) + F = R.free_module(3) + N = F.submodule([1, x, x**2]) + M = F/N + e = M.convert([x**2, 2, 0]) + + assert M.convert([x + 1, x**2 + x, x**3 + x**2]) == 0 + assert e == [x**2, 2, 0] + N == F.convert([x**2, 2, 0]) + N == \ + M.convert(F.convert([x**2, 2, 0])) + + assert M.convert([x**2 + 1, 2*x + 2, x**2]) == e + [0, x, 0] == \ + e + M.convert([0, x, 0]) == e + F.convert([0, x, 0]) + assert M.convert([x**2 + 1, 2, x**2]) == e - [0, x, 0] == \ + e - M.convert([0, x, 0]) == e - F.convert([0, x, 0]) + assert M.convert([0, 2, 0]) == M.convert([x**2, 4, 0]) - e == \ + [x**2, 4, 0] - e == F.convert([x**2, 4, 0]) - e + assert M.convert([x**3 + x**2, 2*x + 2, 0]) == (1 + x)*e == \ + R.convert(1 + x)*e == e*(1 + x) == e*R.convert(1 + x) + assert -e == [-x**2, -2, 0] + + f = [x, x, 0] + N + assert M.convert([1, 1, 0]) == f / x == f / R.convert(x) + + M2 = F/[(2, 2*x, 2*x**2), (0, 0, 1)] + G = R.free_module(2) + M3 = G/[[1, x]] + M4 = F.submodule([1, x, x**2], [1, 0, 0]) / N + raises(CoercionFailed, lambda: M.convert(G.convert([1, x]))) + raises(CoercionFailed, lambda: M.convert(M3.convert([1, x]))) + raises(CoercionFailed, lambda: M.convert(M2.convert([1, x, x]))) + assert M2.convert(M.convert([2, x, x**2])) == [2, x, 0] + assert M.convert(M4.convert([2, 0, 0])) == [2, 0, 0] + + +def test_QuotientModule(): + R = QQ.old_poly_ring(x) + F = R.free_module(3) + N = F.submodule([1, x, x**2]) + M = F/N + + assert M != F + assert M != N + assert M == F / [(1, x, x**2)] + assert not M.is_zero() + assert (F / F.basis()).is_zero() + + SQ = F.submodule([1, x, x**2], [2, 0, 0]) / N + assert SQ == M.submodule([2, x, x**2]) + assert SQ != M.submodule([2, 1, 0]) + assert SQ != M + assert M.is_submodule(SQ) + assert not SQ.is_full_module() + + raises(ValueError, lambda: N/F) + raises(ValueError, lambda: F.submodule([2, 0, 0]) / N) + raises(ValueError, lambda: R.free_module(2)/F) + raises(CoercionFailed, lambda: F.convert(M.convert([1, x, x**2]))) + + M1 = F / [[1, 1, 1]] + M2 = M1.submodule([1, 0, 0], [0, 1, 0]) + assert M1 == M2 + + +def test_ModulesQuotientRing(): + R = QQ.old_poly_ring(x, y, order=(("lex", x), ("ilex", y))) / [x**2 + 1] + M1 = R.free_module(2) + assert M1 == R.free_module(2) + assert M1 != QQ.old_poly_ring(x).free_module(2) + assert M1 != R.free_module(3) + + assert [x, 1] in M1 + assert [x] not in M1 + assert [1/(R.convert(x) + 1), 2] in M1 + assert [1, 2/(1 + y)] in M1 + assert [1, 2/y] not in M1 + + assert M1.convert([x**2, y]) == [-1, y] + + F = R.free_module(3) + Fd = F.submodule([x**2, 0, 0], [1, 2, 0], [1, 2, 3]) + M = F.submodule([x**2 + y**2, 1, 0], [x, y, 1]) + + assert F == Fd + assert Fd == F + assert F != M + assert M != F + assert Fd != M + assert M != Fd + assert Fd == F.submodule(*F.basis()) + + assert Fd.is_full_module() + assert not M.is_full_module() + assert not Fd.is_zero() + assert not M.is_zero() + assert Fd.submodule().is_zero() + + assert M.contains([x**2 + y**2 + x, -x**2 + y, 1]) + assert not M.contains([x**2 + y**2 + x, 1 + y, 2]) + assert M.contains([y**2, 1 - x*y, -x]) + + assert F.submodule([x, 0, 0]) == F.submodule([1, 0, 0]) + assert not F.submodule([y, 0, 0]) == F.submodule([1, 0, 0]) + assert F.submodule([1, 0, 0], [0, 1, 0]).union(F.submodule([0, 0, 1])) == F + assert not M.is_submodule(0) + + +def test_module_mul(): + R = QQ.old_poly_ring(x) + M = R.free_module(2) + S1 = M.submodule([x, 0], [0, x]) + S2 = M.submodule([x**2, 0], [0, x**2]) + I = R.ideal(x) + + assert I*M == M*I == S1 == x*M == M*x + assert I*S1 == S2 == x*S1 + + +def test_intersection(): + # SCA, example 2.8.5 + F = QQ.old_poly_ring(x, y).free_module(2) + M1 = F.submodule([x, y], [y, 1]) + M2 = F.submodule([0, y - 1], [x, 1], [y, x]) + I = F.submodule([x, y], [y**2 - y, y - 1], [x*y + y, x + 1]) + I1, rel1, rel2 = M1.intersect(M2, relations=True) + assert I1 == M2.intersect(M1) == I + for i, g in enumerate(I1.gens): + assert g == sum(c*x for c, x in zip(rel1[i], M1.gens)) \ + == sum(d*y for d, y in zip(rel2[i], M2.gens)) + + assert F.submodule([x, y]).intersect(F.submodule([y, x])).is_zero() + + +def test_quotient(): + # SCA, example 2.8.6 + R = QQ.old_poly_ring(x, y, z) + F = R.free_module(2) + assert F.submodule([x*y, x*z], [y*z, x*y]).module_quotient( + F.submodule([y, z], [z, y])) == QQ.old_poly_ring(x, y, z).ideal(x**2*y**2 - x*y*z**2) + assert F.submodule([x, y]).module_quotient(F.submodule()).is_whole_ring() + + M = F.submodule([x**2, x**2], [y**2, y**2]) + N = F.submodule([x + y, x + y]) + q, rel = M.module_quotient(N, relations=True) + assert q == R.ideal(y**2, x - y) + for i, g in enumerate(q.gens): + assert g*N.gens[0] == sum(c*x for c, x in zip(rel[i], M.gens)) + + +def test_groebner_extendend(): + M = QQ.old_poly_ring(x, y, z).free_module(3).submodule([x + 1, y, 1], [x*y, z, z**2]) + G, R = M._groebner_vec(extended=True) + for i, g in enumerate(G): + assert g == sum(c*gen for c, gen in zip(R[i], M.gens)) diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/__init__.py b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aeb75668610a8a0bcd32851524b37b6ba9aee6c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/bench_galoispolys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/bench_galoispolys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40a43908991c263746a84b4bb96fd1f09d49f5fa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/bench_galoispolys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/bench_groebnertools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/bench_groebnertools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69ad826f38699481c5d6f476cf5ded9eb3b88459 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/benchmarks/__pycache__/bench_groebnertools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_galoispolys.py b/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_galoispolys.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2a0329a0cf96be2e8359a3741d8e2de13fa37a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_galoispolys.py @@ -0,0 +1,66 @@ +"""Benchmarks for polynomials over Galois fields. """ + + +from sympy.polys.galoistools import gf_from_dict, gf_factor_sqf +from sympy.polys.domains import ZZ +from sympy.core.numbers import pi +from sympy.ntheory.generate import nextprime + + +def gathen_poly(n, p, K): + return gf_from_dict({n: K.one, 1: K.one, 0: K.one}, p, K) + + +def shoup_poly(n, p, K): + f = [K.one] * (n + 1) + for i in range(1, n + 1): + f[i] = (f[i - 1]**2 + K.one) % p + return f + + +def genprime(n, K): + return K(nextprime(int((2**n * pi).evalf()))) + +p_10 = genprime(10, ZZ) +f_10 = gathen_poly(10, p_10, ZZ) + +p_20 = genprime(20, ZZ) +f_20 = gathen_poly(20, p_20, ZZ) + + +def timeit_gathen_poly_f10_zassenhaus(): + gf_factor_sqf(f_10, p_10, ZZ, method='zassenhaus') + + +def timeit_gathen_poly_f10_shoup(): + gf_factor_sqf(f_10, p_10, ZZ, method='shoup') + + +def timeit_gathen_poly_f20_zassenhaus(): + gf_factor_sqf(f_20, p_20, ZZ, method='zassenhaus') + + +def timeit_gathen_poly_f20_shoup(): + gf_factor_sqf(f_20, p_20, ZZ, method='shoup') + +P_08 = genprime(8, ZZ) +F_10 = shoup_poly(10, P_08, ZZ) + +P_18 = genprime(18, ZZ) +F_20 = shoup_poly(20, P_18, ZZ) + + +def timeit_shoup_poly_F10_zassenhaus(): + gf_factor_sqf(F_10, P_08, ZZ, method='zassenhaus') + + +def timeit_shoup_poly_F10_shoup(): + gf_factor_sqf(F_10, P_08, ZZ, method='shoup') + + +def timeit_shoup_poly_F20_zassenhaus(): + gf_factor_sqf(F_20, P_18, ZZ, method='zassenhaus') + + +def timeit_shoup_poly_F20_shoup(): + gf_factor_sqf(F_20, P_18, ZZ, method='shoup') diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_groebnertools.py b/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_groebnertools.py new file mode 100644 index 0000000000000000000000000000000000000000..e709f4f6d2cb42c0980d2e49725e01a7a2aa2b87 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_groebnertools.py @@ -0,0 +1,25 @@ +"""Benchmark of the Groebner bases algorithms. """ + + +from sympy.polys.rings import ring +from sympy.polys.domains import QQ +from sympy.polys.groebnertools import groebner + +R, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = ring("x1:13", QQ) + +V = R.gens +E = [(x1, x2), (x2, x3), (x1, x4), (x1, x6), (x1, x12), (x2, x5), (x2, x7), (x3, x8), + (x3, x10), (x4, x11), (x4, x9), (x5, x6), (x6, x7), (x7, x8), (x8, x9), (x9, x10), + (x10, x11), (x11, x12), (x5, x12), (x5, x9), (x6, x10), (x7, x11), (x8, x12)] + +F3 = [ x**3 - 1 for x in V ] +Fg = [ x**2 + x*y + y**2 for x, y in E ] + +F_1 = F3 + Fg +F_2 = F3 + Fg + [x3**2 + x3*x4 + x4**2] + +def time_vertex_color_12_vertices_23_edges(): + assert groebner(F_1, R) != [1] + +def time_vertex_color_12_vertices_24_edges(): + assert groebner(F_2, R) == [1] diff --git a/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_solvers.py b/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..ed3ce5e246db2f5589e6a5dba9f18b7388c179c4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/benchmarks/bench_solvers.py @@ -0,0 +1,543 @@ +from sympy.polys.rings import ring +from sympy.polys.fields import field +from sympy.polys.domains import ZZ, QQ +from sympy.polys.solvers import solve_lin_sys + +# Expected times on 3.4 GHz i7: + +# In [1]: %timeit time_solve_lin_sys_189x49() +# 1 loops, best of 3: 864 ms per loop +# In [2]: %timeit time_solve_lin_sys_165x165() +# 1 loops, best of 3: 1.83 s per loop +# In [3]: %timeit time_solve_lin_sys_10x8() +# 1 loops, best of 3: 2.31 s per loop + +# Benchmark R_165: shows how fast are arithmetics in QQ. + +R_165, uk_0, uk_1, uk_2, uk_3, uk_4, uk_5, uk_6, uk_7, uk_8, uk_9, uk_10, uk_11, uk_12, uk_13, uk_14, uk_15, uk_16, uk_17, uk_18, uk_19, uk_20, uk_21, uk_22, uk_23, uk_24, uk_25, uk_26, uk_27, uk_28, uk_29, uk_30, uk_31, uk_32, uk_33, uk_34, uk_35, uk_36, uk_37, uk_38, uk_39, uk_40, uk_41, uk_42, uk_43, uk_44, uk_45, uk_46, uk_47, uk_48, uk_49, uk_50, uk_51, uk_52, uk_53, uk_54, uk_55, uk_56, uk_57, uk_58, uk_59, uk_60, uk_61, uk_62, uk_63, uk_64, uk_65, uk_66, uk_67, uk_68, uk_69, uk_70, uk_71, uk_72, uk_73, uk_74, uk_75, uk_76, uk_77, uk_78, uk_79, uk_80, uk_81, uk_82, uk_83, uk_84, uk_85, uk_86, uk_87, uk_88, uk_89, uk_90, uk_91, uk_92, uk_93, uk_94, uk_95, uk_96, uk_97, uk_98, uk_99, uk_100, uk_101, uk_102, uk_103, uk_104, uk_105, uk_106, uk_107, uk_108, uk_109, uk_110, uk_111, uk_112, uk_113, uk_114, uk_115, uk_116, uk_117, uk_118, uk_119, uk_120, uk_121, uk_122, uk_123, uk_124, uk_125, uk_126, uk_127, uk_128, uk_129, uk_130, uk_131, uk_132, uk_133, uk_134, uk_135, uk_136, uk_137, uk_138, uk_139, uk_140, uk_141, uk_142, uk_143, uk_144, uk_145, uk_146, uk_147, uk_148, uk_149, uk_150, uk_151, uk_152, uk_153, uk_154, uk_155, uk_156, uk_157, uk_158, uk_159, uk_160, uk_161, uk_162, uk_163, uk_164 = ring("uk_:165", QQ) + +def eqs_165x165(): + return [ + uk_0 + 50719*uk_1 + 2789545*uk_10 + 411400*uk_100 + 1683000*uk_101 + 166375*uk_103 + 680625*uk_104 + 2784375*uk_106 + 729*uk_109 + 456471*uk_11 + 4131*uk_110 + 11016*uk_111 + 4455*uk_112 + 18225*uk_113 + 23409*uk_115 + 62424*uk_116 + 25245*uk_117 + 103275*uk_118 + 2586669*uk_12 + 166464*uk_120 + 67320*uk_121 + 275400*uk_122 + 27225*uk_124 + 111375*uk_125 + 455625*uk_127 + 6897784*uk_13 + 132651*uk_130 + 353736*uk_131 + 143055*uk_132 + 585225*uk_133 + 943296*uk_135 + 381480*uk_136 + 1560600*uk_137 + 154275*uk_139 + 2789545*uk_14 + 631125*uk_140 + 2581875*uk_142 + 2515456*uk_145 + 1017280*uk_146 + 4161600*uk_147 + 411400*uk_149 + 11411775*uk_15 + 1683000*uk_150 + 6885000*uk_152 + 166375*uk_155 + 680625*uk_156 + 2784375*uk_158 + 11390625*uk_161 + 3025*uk_17 + 495*uk_18 + 2805*uk_19 + 55*uk_2 + 7480*uk_20 + 3025*uk_21 + 12375*uk_22 + 81*uk_24 + 459*uk_25 + 1224*uk_26 + 495*uk_27 + 2025*uk_28 + 9*uk_3 + 2601*uk_30 + 6936*uk_31 + 2805*uk_32 + 11475*uk_33 + 18496*uk_35 + 7480*uk_36 + 30600*uk_37 + 3025*uk_39 + 51*uk_4 + 12375*uk_40 + 50625*uk_42 + 130470415844959*uk_45 + 141482932855*uk_46 + 23151752649*uk_47 + 131193265011*uk_48 + 349848706696*uk_49 + 136*uk_5 + 141482932855*uk_50 + 578793816225*uk_51 + 153424975*uk_53 + 25105905*uk_54 + 142266795*uk_55 + 379378120*uk_56 + 153424975*uk_57 + 627647625*uk_58 + 55*uk_6 + 4108239*uk_60 + 23280021*uk_61 + 62080056*uk_62 + 25105905*uk_63 + 102705975*uk_64 + 131920119*uk_66 + 351786984*uk_67 + 142266795*uk_68 + 582000525*uk_69 + 225*uk_7 + 938098624*uk_71 + 379378120*uk_72 + 1552001400*uk_73 + 153424975*uk_75 + 627647625*uk_76 + 2567649375*uk_78 + 166375*uk_81 + 27225*uk_82 + 154275*uk_83 + 411400*uk_84 + 166375*uk_85 + 680625*uk_86 + 4455*uk_88 + 25245*uk_89 + 2572416961*uk_9 + 67320*uk_90 + 27225*uk_91 + 111375*uk_92 + 143055*uk_94 + 381480*uk_95 + 154275*uk_96 + 631125*uk_97 + 1017280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 413820*uk_100 + 1633500*uk_101 + 65340*uk_102 + 178695*uk_103 + 705375*uk_104 + 28215*uk_105 + 2784375*uk_106 + 111375*uk_107 + 4455*uk_108 + 97336*uk_109 + 2333074*uk_11 + 19044*uk_110 + 279312*uk_111 + 120612*uk_112 + 476100*uk_113 + 19044*uk_114 + 3726*uk_115 + 54648*uk_116 + 23598*uk_117 + 93150*uk_118 + 3726*uk_119 + 456471*uk_12 + 801504*uk_120 + 346104*uk_121 + 1366200*uk_122 + 54648*uk_123 + 149454*uk_124 + 589950*uk_125 + 23598*uk_126 + 2328750*uk_127 + 93150*uk_128 + 3726*uk_129 + 6694908*uk_13 + 729*uk_130 + 10692*uk_131 + 4617*uk_132 + 18225*uk_133 + 729*uk_134 + 156816*uk_135 + 67716*uk_136 + 267300*uk_137 + 10692*uk_138 + 29241*uk_139 + 2890983*uk_14 + 115425*uk_140 + 4617*uk_141 + 455625*uk_142 + 18225*uk_143 + 729*uk_144 + 2299968*uk_145 + 993168*uk_146 + 3920400*uk_147 + 156816*uk_148 + 428868*uk_149 + 11411775*uk_15 + 1692900*uk_150 + 67716*uk_151 + 6682500*uk_152 + 267300*uk_153 + 10692*uk_154 + 185193*uk_155 + 731025*uk_156 + 29241*uk_157 + 2885625*uk_158 + 115425*uk_159 + 456471*uk_16 + 4617*uk_160 + 11390625*uk_161 + 455625*uk_162 + 18225*uk_163 + 729*uk_164 + 3025*uk_17 + 2530*uk_18 + 495*uk_19 + 55*uk_2 + 7260*uk_20 + 3135*uk_21 + 12375*uk_22 + 495*uk_23 + 2116*uk_24 + 414*uk_25 + 6072*uk_26 + 2622*uk_27 + 10350*uk_28 + 414*uk_29 + 46*uk_3 + 81*uk_30 + 1188*uk_31 + 513*uk_32 + 2025*uk_33 + 81*uk_34 + 17424*uk_35 + 7524*uk_36 + 29700*uk_37 + 1188*uk_38 + 3249*uk_39 + 9*uk_4 + 12825*uk_40 + 513*uk_41 + 50625*uk_42 + 2025*uk_43 + 81*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 118331180206*uk_47 + 23151752649*uk_48 + 339559038852*uk_49 + 132*uk_5 + 146627766777*uk_50 + 578793816225*uk_51 + 23151752649*uk_52 + 153424975*uk_53 + 128319070*uk_54 + 25105905*uk_55 + 368219940*uk_56 + 159004065*uk_57 + 627647625*uk_58 + 25105905*uk_59 + 57*uk_6 + 107321404*uk_60 + 20997666*uk_61 + 307965768*uk_62 + 132985218*uk_63 + 524941650*uk_64 + 20997666*uk_65 + 4108239*uk_66 + 60254172*uk_67 + 26018847*uk_68 + 102705975*uk_69 + 225*uk_7 + 4108239*uk_70 + 883727856*uk_71 + 381609756*uk_72 + 1506354300*uk_73 + 60254172*uk_74 + 164786031*uk_75 + 650471175*uk_76 + 26018847*uk_77 + 2567649375*uk_78 + 102705975*uk_79 + 9*uk_8 + 4108239*uk_80 + 166375*uk_81 + 139150*uk_82 + 27225*uk_83 + 399300*uk_84 + 172425*uk_85 + 680625*uk_86 + 27225*uk_87 + 116380*uk_88 + 22770*uk_89 + 2572416961*uk_9 + 333960*uk_90 + 144210*uk_91 + 569250*uk_92 + 22770*uk_93 + 4455*uk_94 + 65340*uk_95 + 28215*uk_96 + 111375*uk_97 + 4455*uk_98 + 958320*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 402380*uk_100 + 1534500*uk_101 + 313720*uk_102 + 191455*uk_103 + 730125*uk_104 + 149270*uk_105 + 2784375*uk_106 + 569250*uk_107 + 116380*uk_108 + 912673*uk_109 + 4919743*uk_11 + 432814*uk_110 + 1166716*uk_111 + 555131*uk_112 + 2117025*uk_113 + 432814*uk_114 + 205252*uk_115 + 553288*uk_116 + 263258*uk_117 + 1003950*uk_118 + 205252*uk_119 + 2333074*uk_12 + 1491472*uk_120 + 709652*uk_121 + 2706300*uk_122 + 553288*uk_123 + 337657*uk_124 + 1287675*uk_125 + 263258*uk_126 + 4910625*uk_127 + 1003950*uk_128 + 205252*uk_129 + 6289156*uk_13 + 97336*uk_130 + 262384*uk_131 + 124844*uk_132 + 476100*uk_133 + 97336*uk_134 + 707296*uk_135 + 336536*uk_136 + 1283400*uk_137 + 262384*uk_138 + 160126*uk_139 + 2992421*uk_14 + 610650*uk_140 + 124844*uk_141 + 2328750*uk_142 + 476100*uk_143 + 97336*uk_144 + 1906624*uk_145 + 907184*uk_146 + 3459600*uk_147 + 707296*uk_148 + 431644*uk_149 + 11411775*uk_15 + 1646100*uk_150 + 336536*uk_151 + 6277500*uk_152 + 1283400*uk_153 + 262384*uk_154 + 205379*uk_155 + 783225*uk_156 + 160126*uk_157 + 2986875*uk_158 + 610650*uk_159 + 2333074*uk_16 + 124844*uk_160 + 11390625*uk_161 + 2328750*uk_162 + 476100*uk_163 + 97336*uk_164 + 3025*uk_17 + 5335*uk_18 + 2530*uk_19 + 55*uk_2 + 6820*uk_20 + 3245*uk_21 + 12375*uk_22 + 2530*uk_23 + 9409*uk_24 + 4462*uk_25 + 12028*uk_26 + 5723*uk_27 + 21825*uk_28 + 4462*uk_29 + 97*uk_3 + 2116*uk_30 + 5704*uk_31 + 2714*uk_32 + 10350*uk_33 + 2116*uk_34 + 15376*uk_35 + 7316*uk_36 + 27900*uk_37 + 5704*uk_38 + 3481*uk_39 + 46*uk_4 + 13275*uk_40 + 2714*uk_41 + 50625*uk_42 + 10350*uk_43 + 2116*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 249524445217*uk_47 + 118331180206*uk_48 + 318979703164*uk_49 + 124*uk_5 + 151772600699*uk_50 + 578793816225*uk_51 + 118331180206*uk_52 + 153424975*uk_53 + 270585865*uk_54 + 128319070*uk_55 + 345903580*uk_56 + 164583155*uk_57 + 627647625*uk_58 + 128319070*uk_59 + 59*uk_6 + 477215071*uk_60 + 226308178*uk_61 + 610048132*uk_62 + 290264837*uk_63 + 1106942175*uk_64 + 226308178*uk_65 + 107321404*uk_66 + 289301176*uk_67 + 137651366*uk_68 + 524941650*uk_69 + 225*uk_7 + 107321404*uk_70 + 779855344*uk_71 + 371060204*uk_72 + 1415060100*uk_73 + 289301176*uk_74 + 176552839*uk_75 + 673294725*uk_76 + 137651366*uk_77 + 2567649375*uk_78 + 524941650*uk_79 + 46*uk_8 + 107321404*uk_80 + 166375*uk_81 + 293425*uk_82 + 139150*uk_83 + 375100*uk_84 + 178475*uk_85 + 680625*uk_86 + 139150*uk_87 + 517495*uk_88 + 245410*uk_89 + 2572416961*uk_9 + 661540*uk_90 + 314765*uk_91 + 1200375*uk_92 + 245410*uk_93 + 116380*uk_94 + 313720*uk_95 + 149270*uk_96 + 569250*uk_97 + 116380*uk_98 + 845680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 389180*uk_100 + 1435500*uk_101 + 618860*uk_102 + 204655*uk_103 + 754875*uk_104 + 325435*uk_105 + 2784375*uk_106 + 1200375*uk_107 + 517495*uk_108 + 3375000*uk_109 + 7607850*uk_11 + 2182500*uk_110 + 2610000*uk_111 + 1372500*uk_112 + 5062500*uk_113 + 2182500*uk_114 + 1411350*uk_115 + 1687800*uk_116 + 887550*uk_117 + 3273750*uk_118 + 1411350*uk_119 + 4919743*uk_12 + 2018400*uk_120 + 1061400*uk_121 + 3915000*uk_122 + 1687800*uk_123 + 558150*uk_124 + 2058750*uk_125 + 887550*uk_126 + 7593750*uk_127 + 3273750*uk_128 + 1411350*uk_129 + 5883404*uk_13 + 912673*uk_130 + 1091444*uk_131 + 573949*uk_132 + 2117025*uk_133 + 912673*uk_134 + 1305232*uk_135 + 686372*uk_136 + 2531700*uk_137 + 1091444*uk_138 + 360937*uk_139 + 3093859*uk_14 + 1331325*uk_140 + 573949*uk_141 + 4910625*uk_142 + 2117025*uk_143 + 912673*uk_144 + 1560896*uk_145 + 820816*uk_146 + 3027600*uk_147 + 1305232*uk_148 + 431636*uk_149 + 11411775*uk_15 + 1592100*uk_150 + 686372*uk_151 + 5872500*uk_152 + 2531700*uk_153 + 1091444*uk_154 + 226981*uk_155 + 837225*uk_156 + 360937*uk_157 + 3088125*uk_158 + 1331325*uk_159 + 4919743*uk_16 + 573949*uk_160 + 11390625*uk_161 + 4910625*uk_162 + 2117025*uk_163 + 912673*uk_164 + 3025*uk_17 + 8250*uk_18 + 5335*uk_19 + 55*uk_2 + 6380*uk_20 + 3355*uk_21 + 12375*uk_22 + 5335*uk_23 + 22500*uk_24 + 14550*uk_25 + 17400*uk_26 + 9150*uk_27 + 33750*uk_28 + 14550*uk_29 + 150*uk_3 + 9409*uk_30 + 11252*uk_31 + 5917*uk_32 + 21825*uk_33 + 9409*uk_34 + 13456*uk_35 + 7076*uk_36 + 26100*uk_37 + 11252*uk_38 + 3721*uk_39 + 97*uk_4 + 13725*uk_40 + 5917*uk_41 + 50625*uk_42 + 21825*uk_43 + 9409*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 385862544150*uk_47 + 249524445217*uk_48 + 298400367476*uk_49 + 116*uk_5 + 156917434621*uk_50 + 578793816225*uk_51 + 249524445217*uk_52 + 153424975*uk_53 + 418431750*uk_54 + 270585865*uk_55 + 323587220*uk_56 + 170162245*uk_57 + 627647625*uk_58 + 270585865*uk_59 + 61*uk_6 + 1141177500*uk_60 + 737961450*uk_61 + 882510600*uk_62 + 464078850*uk_63 + 1711766250*uk_64 + 737961450*uk_65 + 477215071*uk_66 + 570690188*uk_67 + 300104323*uk_68 + 1106942175*uk_69 + 225*uk_7 + 477215071*uk_70 + 682474864*uk_71 + 358887644*uk_72 + 1323765900*uk_73 + 570690188*uk_74 + 188725399*uk_75 + 696118275*uk_76 + 300104323*uk_77 + 2567649375*uk_78 + 1106942175*uk_79 + 97*uk_8 + 477215071*uk_80 + 166375*uk_81 + 453750*uk_82 + 293425*uk_83 + 350900*uk_84 + 184525*uk_85 + 680625*uk_86 + 293425*uk_87 + 1237500*uk_88 + 800250*uk_89 + 2572416961*uk_9 + 957000*uk_90 + 503250*uk_91 + 1856250*uk_92 + 800250*uk_93 + 517495*uk_94 + 618860*uk_95 + 325435*uk_96 + 1200375*uk_97 + 517495*uk_98 + 740080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 374220*uk_100 + 1336500*uk_101 + 891000*uk_102 + 218295*uk_103 + 779625*uk_104 + 519750*uk_105 + 2784375*uk_106 + 1856250*uk_107 + 1237500*uk_108 + 7189057*uk_109 + 9788767*uk_11 + 5587350*uk_110 + 4022892*uk_111 + 2346687*uk_112 + 8381025*uk_113 + 5587350*uk_114 + 4342500*uk_115 + 3126600*uk_116 + 1823850*uk_117 + 6513750*uk_118 + 4342500*uk_119 + 7607850*uk_12 + 2251152*uk_120 + 1313172*uk_121 + 4689900*uk_122 + 3126600*uk_123 + 766017*uk_124 + 2735775*uk_125 + 1823850*uk_126 + 9770625*uk_127 + 6513750*uk_128 + 4342500*uk_129 + 5477652*uk_13 + 3375000*uk_130 + 2430000*uk_131 + 1417500*uk_132 + 5062500*uk_133 + 3375000*uk_134 + 1749600*uk_135 + 1020600*uk_136 + 3645000*uk_137 + 2430000*uk_138 + 595350*uk_139 + 3195297*uk_14 + 2126250*uk_140 + 1417500*uk_141 + 7593750*uk_142 + 5062500*uk_143 + 3375000*uk_144 + 1259712*uk_145 + 734832*uk_146 + 2624400*uk_147 + 1749600*uk_148 + 428652*uk_149 + 11411775*uk_15 + 1530900*uk_150 + 1020600*uk_151 + 5467500*uk_152 + 3645000*uk_153 + 2430000*uk_154 + 250047*uk_155 + 893025*uk_156 + 595350*uk_157 + 3189375*uk_158 + 2126250*uk_159 + 7607850*uk_16 + 1417500*uk_160 + 11390625*uk_161 + 7593750*uk_162 + 5062500*uk_163 + 3375000*uk_164 + 3025*uk_17 + 10615*uk_18 + 8250*uk_19 + 55*uk_2 + 5940*uk_20 + 3465*uk_21 + 12375*uk_22 + 8250*uk_23 + 37249*uk_24 + 28950*uk_25 + 20844*uk_26 + 12159*uk_27 + 43425*uk_28 + 28950*uk_29 + 193*uk_3 + 22500*uk_30 + 16200*uk_31 + 9450*uk_32 + 33750*uk_33 + 22500*uk_34 + 11664*uk_35 + 6804*uk_36 + 24300*uk_37 + 16200*uk_38 + 3969*uk_39 + 150*uk_4 + 14175*uk_40 + 9450*uk_41 + 50625*uk_42 + 33750*uk_43 + 22500*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 496476473473*uk_47 + 385862544150*uk_48 + 277821031788*uk_49 + 108*uk_5 + 162062268543*uk_50 + 578793816225*uk_51 + 385862544150*uk_52 + 153424975*uk_53 + 538382185*uk_54 + 418431750*uk_55 + 301270860*uk_56 + 175741335*uk_57 + 627647625*uk_58 + 418431750*uk_59 + 63*uk_6 + 1889232031*uk_60 + 1468315050*uk_61 + 1057186836*uk_62 + 616692321*uk_63 + 2202472575*uk_64 + 1468315050*uk_65 + 1141177500*uk_66 + 821647800*uk_67 + 479294550*uk_68 + 1711766250*uk_69 + 225*uk_7 + 1141177500*uk_70 + 591586416*uk_71 + 345092076*uk_72 + 1232471700*uk_73 + 821647800*uk_74 + 201303711*uk_75 + 718941825*uk_76 + 479294550*uk_77 + 2567649375*uk_78 + 1711766250*uk_79 + 150*uk_8 + 1141177500*uk_80 + 166375*uk_81 + 583825*uk_82 + 453750*uk_83 + 326700*uk_84 + 190575*uk_85 + 680625*uk_86 + 453750*uk_87 + 2048695*uk_88 + 1592250*uk_89 + 2572416961*uk_9 + 1146420*uk_90 + 668745*uk_91 + 2388375*uk_92 + 1592250*uk_93 + 1237500*uk_94 + 891000*uk_95 + 519750*uk_96 + 1856250*uk_97 + 1237500*uk_98 + 641520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 357500*uk_100 + 1237500*uk_101 + 1061500*uk_102 + 232375*uk_103 + 804375*uk_104 + 689975*uk_105 + 2784375*uk_106 + 2388375*uk_107 + 2048695*uk_108 + 9800344*uk_109 + 10853866*uk_11 + 8838628*uk_110 + 4579600*uk_111 + 2976740*uk_112 + 10304100*uk_113 + 8838628*uk_114 + 7971286*uk_115 + 4130200*uk_116 + 2684630*uk_117 + 9292950*uk_118 + 7971286*uk_119 + 9788767*uk_12 + 2140000*uk_120 + 1391000*uk_121 + 4815000*uk_122 + 4130200*uk_123 + 904150*uk_124 + 3129750*uk_125 + 2684630*uk_126 + 10833750*uk_127 + 9292950*uk_128 + 7971286*uk_129 + 5071900*uk_13 + 7189057*uk_130 + 3724900*uk_131 + 2421185*uk_132 + 8381025*uk_133 + 7189057*uk_134 + 1930000*uk_135 + 1254500*uk_136 + 4342500*uk_137 + 3724900*uk_138 + 815425*uk_139 + 3296735*uk_14 + 2822625*uk_140 + 2421185*uk_141 + 9770625*uk_142 + 8381025*uk_143 + 7189057*uk_144 + 1000000*uk_145 + 650000*uk_146 + 2250000*uk_147 + 1930000*uk_148 + 422500*uk_149 + 11411775*uk_15 + 1462500*uk_150 + 1254500*uk_151 + 5062500*uk_152 + 4342500*uk_153 + 3724900*uk_154 + 274625*uk_155 + 950625*uk_156 + 815425*uk_157 + 3290625*uk_158 + 2822625*uk_159 + 9788767*uk_16 + 2421185*uk_160 + 11390625*uk_161 + 9770625*uk_162 + 8381025*uk_163 + 7189057*uk_164 + 3025*uk_17 + 11770*uk_18 + 10615*uk_19 + 55*uk_2 + 5500*uk_20 + 3575*uk_21 + 12375*uk_22 + 10615*uk_23 + 45796*uk_24 + 41302*uk_25 + 21400*uk_26 + 13910*uk_27 + 48150*uk_28 + 41302*uk_29 + 214*uk_3 + 37249*uk_30 + 19300*uk_31 + 12545*uk_32 + 43425*uk_33 + 37249*uk_34 + 10000*uk_35 + 6500*uk_36 + 22500*uk_37 + 19300*uk_38 + 4225*uk_39 + 193*uk_4 + 14625*uk_40 + 12545*uk_41 + 50625*uk_42 + 43425*uk_43 + 37249*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 550497229654*uk_47 + 496476473473*uk_48 + 257241696100*uk_49 + 100*uk_5 + 167207102465*uk_50 + 578793816225*uk_51 + 496476473473*uk_52 + 153424975*uk_53 + 596962630*uk_54 + 538382185*uk_55 + 278954500*uk_56 + 181320425*uk_57 + 627647625*uk_58 + 538382185*uk_59 + 65*uk_6 + 2322727324*uk_60 + 2094796138*uk_61 + 1085386600*uk_62 + 705501290*uk_63 + 2442119850*uk_64 + 2094796138*uk_65 + 1889232031*uk_66 + 978876700*uk_67 + 636269855*uk_68 + 2202472575*uk_69 + 225*uk_7 + 1889232031*uk_70 + 507190000*uk_71 + 329673500*uk_72 + 1141177500*uk_73 + 978876700*uk_74 + 214287775*uk_75 + 741765375*uk_76 + 636269855*uk_77 + 2567649375*uk_78 + 2202472575*uk_79 + 193*uk_8 + 1889232031*uk_80 + 166375*uk_81 + 647350*uk_82 + 583825*uk_83 + 302500*uk_84 + 196625*uk_85 + 680625*uk_86 + 583825*uk_87 + 2518780*uk_88 + 2271610*uk_89 + 2572416961*uk_9 + 1177000*uk_90 + 765050*uk_91 + 2648250*uk_92 + 2271610*uk_93 + 2048695*uk_94 + 1061500*uk_95 + 689975*uk_96 + 2388375*uk_97 + 2048695*uk_98 + 550000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 339020*uk_100 + 1138500*uk_101 + 1082840*uk_102 + 246895*uk_103 + 829125*uk_104 + 788590*uk_105 + 2784375*uk_106 + 2648250*uk_107 + 2518780*uk_108 + 8120601*uk_109 + 10194519*uk_11 + 8645814*uk_110 + 3716892*uk_111 + 2706867*uk_112 + 9090225*uk_113 + 8645814*uk_114 + 9204996*uk_115 + 3957288*uk_116 + 2881938*uk_117 + 9678150*uk_118 + 9204996*uk_119 + 10853866*uk_12 + 1701264*uk_120 + 1238964*uk_121 + 4160700*uk_122 + 3957288*uk_123 + 902289*uk_124 + 3030075*uk_125 + 2881938*uk_126 + 10175625*uk_127 + 9678150*uk_128 + 9204996*uk_129 + 4666148*uk_13 + 9800344*uk_130 + 4213232*uk_131 + 3068332*uk_132 + 10304100*uk_133 + 9800344*uk_134 + 1811296*uk_135 + 1319096*uk_136 + 4429800*uk_137 + 4213232*uk_138 + 960646*uk_139 + 3398173*uk_14 + 3226050*uk_140 + 3068332*uk_141 + 10833750*uk_142 + 10304100*uk_143 + 9800344*uk_144 + 778688*uk_145 + 567088*uk_146 + 1904400*uk_147 + 1811296*uk_148 + 412988*uk_149 + 11411775*uk_15 + 1386900*uk_150 + 1319096*uk_151 + 4657500*uk_152 + 4429800*uk_153 + 4213232*uk_154 + 300763*uk_155 + 1010025*uk_156 + 960646*uk_157 + 3391875*uk_158 + 3226050*uk_159 + 10853866*uk_16 + 3068332*uk_160 + 11390625*uk_161 + 10833750*uk_162 + 10304100*uk_163 + 9800344*uk_164 + 3025*uk_17 + 11055*uk_18 + 11770*uk_19 + 55*uk_2 + 5060*uk_20 + 3685*uk_21 + 12375*uk_22 + 11770*uk_23 + 40401*uk_24 + 43014*uk_25 + 18492*uk_26 + 13467*uk_27 + 45225*uk_28 + 43014*uk_29 + 201*uk_3 + 45796*uk_30 + 19688*uk_31 + 14338*uk_32 + 48150*uk_33 + 45796*uk_34 + 8464*uk_35 + 6164*uk_36 + 20700*uk_37 + 19688*uk_38 + 4489*uk_39 + 214*uk_4 + 15075*uk_40 + 14338*uk_41 + 50625*uk_42 + 48150*uk_43 + 45796*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 517055809161*uk_47 + 550497229654*uk_48 + 236662360412*uk_49 + 92*uk_5 + 172351936387*uk_50 + 578793816225*uk_51 + 550497229654*uk_52 + 153424975*uk_53 + 560698545*uk_54 + 596962630*uk_55 + 256638140*uk_56 + 186899515*uk_57 + 627647625*uk_58 + 596962630*uk_59 + 67*uk_6 + 2049098319*uk_60 + 2181627066*uk_61 + 937895748*uk_62 + 683032773*uk_63 + 2293766775*uk_64 + 2181627066*uk_65 + 2322727324*uk_66 + 998555672*uk_67 + 727209022*uk_68 + 2442119850*uk_69 + 225*uk_7 + 2322727324*uk_70 + 429285616*uk_71 + 312631916*uk_72 + 1049883300*uk_73 + 998555672*uk_74 + 227677591*uk_75 + 764588925*uk_76 + 727209022*uk_77 + 2567649375*uk_78 + 2442119850*uk_79 + 214*uk_8 + 2322727324*uk_80 + 166375*uk_81 + 608025*uk_82 + 647350*uk_83 + 278300*uk_84 + 202675*uk_85 + 680625*uk_86 + 647350*uk_87 + 2222055*uk_88 + 2365770*uk_89 + 2572416961*uk_9 + 1017060*uk_90 + 740685*uk_91 + 2487375*uk_92 + 2365770*uk_93 + 2518780*uk_94 + 1082840*uk_95 + 788590*uk_96 + 2648250*uk_97 + 2518780*uk_98 + 465520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 318780*uk_100 + 1039500*uk_101 + 928620*uk_102 + 261855*uk_103 + 853875*uk_104 + 762795*uk_105 + 2784375*uk_106 + 2487375*uk_107 + 2222055*uk_108 + 2863288*uk_109 + 7202098*uk_11 + 4052964*uk_110 + 1693776*uk_111 + 1391316*uk_112 + 4536900*uk_113 + 4052964*uk_114 + 5736942*uk_115 + 2397528*uk_116 + 1969398*uk_117 + 6421950*uk_118 + 5736942*uk_119 + 10194519*uk_12 + 1001952*uk_120 + 823032*uk_121 + 2683800*uk_122 + 2397528*uk_123 + 676062*uk_124 + 2204550*uk_125 + 1969398*uk_126 + 7188750*uk_127 + 6421950*uk_128 + 5736942*uk_129 + 4260396*uk_13 + 8120601*uk_130 + 3393684*uk_131 + 2787669*uk_132 + 9090225*uk_133 + 8120601*uk_134 + 1418256*uk_135 + 1164996*uk_136 + 3798900*uk_137 + 3393684*uk_138 + 956961*uk_139 + 3499611*uk_14 + 3120525*uk_140 + 2787669*uk_141 + 10175625*uk_142 + 9090225*uk_143 + 8120601*uk_144 + 592704*uk_145 + 486864*uk_146 + 1587600*uk_147 + 1418256*uk_148 + 399924*uk_149 + 11411775*uk_15 + 1304100*uk_150 + 1164996*uk_151 + 4252500*uk_152 + 3798900*uk_153 + 3393684*uk_154 + 328509*uk_155 + 1071225*uk_156 + 956961*uk_157 + 3493125*uk_158 + 3120525*uk_159 + 10194519*uk_16 + 2787669*uk_160 + 11390625*uk_161 + 10175625*uk_162 + 9090225*uk_163 + 8120601*uk_164 + 3025*uk_17 + 7810*uk_18 + 11055*uk_19 + 55*uk_2 + 4620*uk_20 + 3795*uk_21 + 12375*uk_22 + 11055*uk_23 + 20164*uk_24 + 28542*uk_25 + 11928*uk_26 + 9798*uk_27 + 31950*uk_28 + 28542*uk_29 + 142*uk_3 + 40401*uk_30 + 16884*uk_31 + 13869*uk_32 + 45225*uk_33 + 40401*uk_34 + 7056*uk_35 + 5796*uk_36 + 18900*uk_37 + 16884*uk_38 + 4761*uk_39 + 201*uk_4 + 15525*uk_40 + 13869*uk_41 + 50625*uk_42 + 45225*uk_43 + 40401*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 365283208462*uk_47 + 517055809161*uk_48 + 216083024724*uk_49 + 84*uk_5 + 177496770309*uk_50 + 578793816225*uk_51 + 517055809161*uk_52 + 153424975*uk_53 + 396115390*uk_54 + 560698545*uk_55 + 234321780*uk_56 + 192478605*uk_57 + 627647625*uk_58 + 560698545*uk_59 + 69*uk_6 + 1022697916*uk_60 + 1447621698*uk_61 + 604976232*uk_62 + 496944762*uk_63 + 1620472050*uk_64 + 1447621698*uk_65 + 2049098319*uk_66 + 856339596*uk_67 + 703421811*uk_68 + 2293766775*uk_69 + 225*uk_7 + 2049098319*uk_70 + 357873264*uk_71 + 293967324*uk_72 + 958589100*uk_73 + 856339596*uk_74 + 241473159*uk_75 + 787412475*uk_76 + 703421811*uk_77 + 2567649375*uk_78 + 2293766775*uk_79 + 201*uk_8 + 2049098319*uk_80 + 166375*uk_81 + 429550*uk_82 + 608025*uk_83 + 254100*uk_84 + 208725*uk_85 + 680625*uk_86 + 608025*uk_87 + 1109020*uk_88 + 1569810*uk_89 + 2572416961*uk_9 + 656040*uk_90 + 538890*uk_91 + 1757250*uk_92 + 1569810*uk_93 + 2222055*uk_94 + 928620*uk_95 + 762795*uk_96 + 2487375*uk_97 + 2222055*uk_98 + 388080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 296780*uk_100 + 940500*uk_101 + 593560*uk_102 + 277255*uk_103 + 878625*uk_104 + 554510*uk_105 + 2784375*uk_106 + 1757250*uk_107 + 1109020*uk_108 + 15625*uk_109 + 1267975*uk_11 + 88750*uk_110 + 47500*uk_111 + 44375*uk_112 + 140625*uk_113 + 88750*uk_114 + 504100*uk_115 + 269800*uk_116 + 252050*uk_117 + 798750*uk_118 + 504100*uk_119 + 7202098*uk_12 + 144400*uk_120 + 134900*uk_121 + 427500*uk_122 + 269800*uk_123 + 126025*uk_124 + 399375*uk_125 + 252050*uk_126 + 1265625*uk_127 + 798750*uk_128 + 504100*uk_129 + 3854644*uk_13 + 2863288*uk_130 + 1532464*uk_131 + 1431644*uk_132 + 4536900*uk_133 + 2863288*uk_134 + 820192*uk_135 + 766232*uk_136 + 2428200*uk_137 + 1532464*uk_138 + 715822*uk_139 + 3601049*uk_14 + 2268450*uk_140 + 1431644*uk_141 + 7188750*uk_142 + 4536900*uk_143 + 2863288*uk_144 + 438976*uk_145 + 410096*uk_146 + 1299600*uk_147 + 820192*uk_148 + 383116*uk_149 + 11411775*uk_15 + 1214100*uk_150 + 766232*uk_151 + 3847500*uk_152 + 2428200*uk_153 + 1532464*uk_154 + 357911*uk_155 + 1134225*uk_156 + 715822*uk_157 + 3594375*uk_158 + 2268450*uk_159 + 7202098*uk_16 + 1431644*uk_160 + 11390625*uk_161 + 7188750*uk_162 + 4536900*uk_163 + 2863288*uk_164 + 3025*uk_17 + 1375*uk_18 + 7810*uk_19 + 55*uk_2 + 4180*uk_20 + 3905*uk_21 + 12375*uk_22 + 7810*uk_23 + 625*uk_24 + 3550*uk_25 + 1900*uk_26 + 1775*uk_27 + 5625*uk_28 + 3550*uk_29 + 25*uk_3 + 20164*uk_30 + 10792*uk_31 + 10082*uk_32 + 31950*uk_33 + 20164*uk_34 + 5776*uk_35 + 5396*uk_36 + 17100*uk_37 + 10792*uk_38 + 5041*uk_39 + 142*uk_4 + 15975*uk_40 + 10082*uk_41 + 50625*uk_42 + 31950*uk_43 + 20164*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 64310424025*uk_47 + 365283208462*uk_48 + 195503689036*uk_49 + 76*uk_5 + 182641604231*uk_50 + 578793816225*uk_51 + 365283208462*uk_52 + 153424975*uk_53 + 69738625*uk_54 + 396115390*uk_55 + 212005420*uk_56 + 198057695*uk_57 + 627647625*uk_58 + 396115390*uk_59 + 71*uk_6 + 31699375*uk_60 + 180052450*uk_61 + 96366100*uk_62 + 90026225*uk_63 + 285294375*uk_64 + 180052450*uk_65 + 1022697916*uk_66 + 547359448*uk_67 + 511348958*uk_68 + 1620472050*uk_69 + 225*uk_7 + 1022697916*uk_70 + 292952944*uk_71 + 273679724*uk_72 + 867294900*uk_73 + 547359448*uk_74 + 255674479*uk_75 + 810236025*uk_76 + 511348958*uk_77 + 2567649375*uk_78 + 1620472050*uk_79 + 142*uk_8 + 1022697916*uk_80 + 166375*uk_81 + 75625*uk_82 + 429550*uk_83 + 229900*uk_84 + 214775*uk_85 + 680625*uk_86 + 429550*uk_87 + 34375*uk_88 + 195250*uk_89 + 2572416961*uk_9 + 104500*uk_90 + 97625*uk_91 + 309375*uk_92 + 195250*uk_93 + 1109020*uk_94 + 593560*uk_95 + 554510*uk_96 + 1757250*uk_97 + 1109020*uk_98 + 317680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 321200*uk_100 + 990000*uk_101 + 110000*uk_102 + 293095*uk_103 + 903375*uk_104 + 100375*uk_105 + 2784375*uk_106 + 309375*uk_107 + 34375*uk_108 + 185193*uk_109 + 2890983*uk_11 + 81225*uk_110 + 259920*uk_111 + 237177*uk_112 + 731025*uk_113 + 81225*uk_114 + 35625*uk_115 + 114000*uk_116 + 104025*uk_117 + 320625*uk_118 + 35625*uk_119 + 1267975*uk_12 + 364800*uk_120 + 332880*uk_121 + 1026000*uk_122 + 114000*uk_123 + 303753*uk_124 + 936225*uk_125 + 104025*uk_126 + 2885625*uk_127 + 320625*uk_128 + 35625*uk_129 + 4057520*uk_13 + 15625*uk_130 + 50000*uk_131 + 45625*uk_132 + 140625*uk_133 + 15625*uk_134 + 160000*uk_135 + 146000*uk_136 + 450000*uk_137 + 50000*uk_138 + 133225*uk_139 + 3702487*uk_14 + 410625*uk_140 + 45625*uk_141 + 1265625*uk_142 + 140625*uk_143 + 15625*uk_144 + 512000*uk_145 + 467200*uk_146 + 1440000*uk_147 + 160000*uk_148 + 426320*uk_149 + 11411775*uk_15 + 1314000*uk_150 + 146000*uk_151 + 4050000*uk_152 + 450000*uk_153 + 50000*uk_154 + 389017*uk_155 + 1199025*uk_156 + 133225*uk_157 + 3695625*uk_158 + 410625*uk_159 + 1267975*uk_16 + 45625*uk_160 + 11390625*uk_161 + 1265625*uk_162 + 140625*uk_163 + 15625*uk_164 + 3025*uk_17 + 3135*uk_18 + 1375*uk_19 + 55*uk_2 + 4400*uk_20 + 4015*uk_21 + 12375*uk_22 + 1375*uk_23 + 3249*uk_24 + 1425*uk_25 + 4560*uk_26 + 4161*uk_27 + 12825*uk_28 + 1425*uk_29 + 57*uk_3 + 625*uk_30 + 2000*uk_31 + 1825*uk_32 + 5625*uk_33 + 625*uk_34 + 6400*uk_35 + 5840*uk_36 + 18000*uk_37 + 2000*uk_38 + 5329*uk_39 + 25*uk_4 + 16425*uk_40 + 1825*uk_41 + 50625*uk_42 + 5625*uk_43 + 625*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 146627766777*uk_47 + 64310424025*uk_48 + 205793356880*uk_49 + 80*uk_5 + 187786438153*uk_50 + 578793816225*uk_51 + 64310424025*uk_52 + 153424975*uk_53 + 159004065*uk_54 + 69738625*uk_55 + 223163600*uk_56 + 203636785*uk_57 + 627647625*uk_58 + 69738625*uk_59 + 73*uk_6 + 164786031*uk_60 + 72274575*uk_61 + 231278640*uk_62 + 211041759*uk_63 + 650471175*uk_64 + 72274575*uk_65 + 31699375*uk_66 + 101438000*uk_67 + 92562175*uk_68 + 285294375*uk_69 + 225*uk_7 + 31699375*uk_70 + 324601600*uk_71 + 296198960*uk_72 + 912942000*uk_73 + 101438000*uk_74 + 270281551*uk_75 + 833059575*uk_76 + 92562175*uk_77 + 2567649375*uk_78 + 285294375*uk_79 + 25*uk_8 + 31699375*uk_80 + 166375*uk_81 + 172425*uk_82 + 75625*uk_83 + 242000*uk_84 + 220825*uk_85 + 680625*uk_86 + 75625*uk_87 + 178695*uk_88 + 78375*uk_89 + 2572416961*uk_9 + 250800*uk_90 + 228855*uk_91 + 705375*uk_92 + 78375*uk_93 + 34375*uk_94 + 110000*uk_95 + 100375*uk_96 + 309375*uk_97 + 34375*uk_98 + 352000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 297000*uk_100 + 891000*uk_101 + 225720*uk_102 + 309375*uk_103 + 928125*uk_104 + 235125*uk_105 + 2784375*uk_106 + 705375*uk_107 + 178695*uk_108 + 6859*uk_109 + 963661*uk_11 + 20577*uk_110 + 25992*uk_111 + 27075*uk_112 + 81225*uk_113 + 20577*uk_114 + 61731*uk_115 + 77976*uk_116 + 81225*uk_117 + 243675*uk_118 + 61731*uk_119 + 2890983*uk_12 + 98496*uk_120 + 102600*uk_121 + 307800*uk_122 + 77976*uk_123 + 106875*uk_124 + 320625*uk_125 + 81225*uk_126 + 961875*uk_127 + 243675*uk_128 + 61731*uk_129 + 3651768*uk_13 + 185193*uk_130 + 233928*uk_131 + 243675*uk_132 + 731025*uk_133 + 185193*uk_134 + 295488*uk_135 + 307800*uk_136 + 923400*uk_137 + 233928*uk_138 + 320625*uk_139 + 3803925*uk_14 + 961875*uk_140 + 243675*uk_141 + 2885625*uk_142 + 731025*uk_143 + 185193*uk_144 + 373248*uk_145 + 388800*uk_146 + 1166400*uk_147 + 295488*uk_148 + 405000*uk_149 + 11411775*uk_15 + 1215000*uk_150 + 307800*uk_151 + 3645000*uk_152 + 923400*uk_153 + 233928*uk_154 + 421875*uk_155 + 1265625*uk_156 + 320625*uk_157 + 3796875*uk_158 + 961875*uk_159 + 2890983*uk_16 + 243675*uk_160 + 11390625*uk_161 + 2885625*uk_162 + 731025*uk_163 + 185193*uk_164 + 3025*uk_17 + 1045*uk_18 + 3135*uk_19 + 55*uk_2 + 3960*uk_20 + 4125*uk_21 + 12375*uk_22 + 3135*uk_23 + 361*uk_24 + 1083*uk_25 + 1368*uk_26 + 1425*uk_27 + 4275*uk_28 + 1083*uk_29 + 19*uk_3 + 3249*uk_30 + 4104*uk_31 + 4275*uk_32 + 12825*uk_33 + 3249*uk_34 + 5184*uk_35 + 5400*uk_36 + 16200*uk_37 + 4104*uk_38 + 5625*uk_39 + 57*uk_4 + 16875*uk_40 + 4275*uk_41 + 50625*uk_42 + 12825*uk_43 + 3249*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 48875922259*uk_47 + 146627766777*uk_48 + 185214021192*uk_49 + 72*uk_5 + 192931272075*uk_50 + 578793816225*uk_51 + 146627766777*uk_52 + 153424975*uk_53 + 53001355*uk_54 + 159004065*uk_55 + 200847240*uk_56 + 209215875*uk_57 + 627647625*uk_58 + 159004065*uk_59 + 75*uk_6 + 18309559*uk_60 + 54928677*uk_61 + 69383592*uk_62 + 72274575*uk_63 + 216823725*uk_64 + 54928677*uk_65 + 164786031*uk_66 + 208150776*uk_67 + 216823725*uk_68 + 650471175*uk_69 + 225*uk_7 + 164786031*uk_70 + 262927296*uk_71 + 273882600*uk_72 + 821647800*uk_73 + 208150776*uk_74 + 285294375*uk_75 + 855883125*uk_76 + 216823725*uk_77 + 2567649375*uk_78 + 650471175*uk_79 + 57*uk_8 + 164786031*uk_80 + 166375*uk_81 + 57475*uk_82 + 172425*uk_83 + 217800*uk_84 + 226875*uk_85 + 680625*uk_86 + 172425*uk_87 + 19855*uk_88 + 59565*uk_89 + 2572416961*uk_9 + 75240*uk_90 + 78375*uk_91 + 235125*uk_92 + 59565*uk_93 + 178695*uk_94 + 225720*uk_95 + 235125*uk_96 + 705375*uk_97 + 178695*uk_98 + 285120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 304920*uk_100 + 891000*uk_101 + 75240*uk_102 + 326095*uk_103 + 952875*uk_104 + 80465*uk_105 + 2784375*uk_106 + 235125*uk_107 + 19855*uk_108 + 148877*uk_109 + 2688107*uk_11 + 53371*uk_110 + 202248*uk_111 + 216293*uk_112 + 632025*uk_113 + 53371*uk_114 + 19133*uk_115 + 72504*uk_116 + 77539*uk_117 + 226575*uk_118 + 19133*uk_119 + 963661*uk_12 + 274752*uk_120 + 293832*uk_121 + 858600*uk_122 + 72504*uk_123 + 314237*uk_124 + 918225*uk_125 + 77539*uk_126 + 2683125*uk_127 + 226575*uk_128 + 19133*uk_129 + 3651768*uk_13 + 6859*uk_130 + 25992*uk_131 + 27797*uk_132 + 81225*uk_133 + 6859*uk_134 + 98496*uk_135 + 105336*uk_136 + 307800*uk_137 + 25992*uk_138 + 112651*uk_139 + 3905363*uk_14 + 329175*uk_140 + 27797*uk_141 + 961875*uk_142 + 81225*uk_143 + 6859*uk_144 + 373248*uk_145 + 399168*uk_146 + 1166400*uk_147 + 98496*uk_148 + 426888*uk_149 + 11411775*uk_15 + 1247400*uk_150 + 105336*uk_151 + 3645000*uk_152 + 307800*uk_153 + 25992*uk_154 + 456533*uk_155 + 1334025*uk_156 + 112651*uk_157 + 3898125*uk_158 + 329175*uk_159 + 963661*uk_16 + 27797*uk_160 + 11390625*uk_161 + 961875*uk_162 + 81225*uk_163 + 6859*uk_164 + 3025*uk_17 + 2915*uk_18 + 1045*uk_19 + 55*uk_2 + 3960*uk_20 + 4235*uk_21 + 12375*uk_22 + 1045*uk_23 + 2809*uk_24 + 1007*uk_25 + 3816*uk_26 + 4081*uk_27 + 11925*uk_28 + 1007*uk_29 + 53*uk_3 + 361*uk_30 + 1368*uk_31 + 1463*uk_32 + 4275*uk_33 + 361*uk_34 + 5184*uk_35 + 5544*uk_36 + 16200*uk_37 + 1368*uk_38 + 5929*uk_39 + 19*uk_4 + 17325*uk_40 + 1463*uk_41 + 50625*uk_42 + 4275*uk_43 + 361*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 136338098933*uk_47 + 48875922259*uk_48 + 185214021192*uk_49 + 72*uk_5 + 198076105997*uk_50 + 578793816225*uk_51 + 48875922259*uk_52 + 153424975*uk_53 + 147845885*uk_54 + 53001355*uk_55 + 200847240*uk_56 + 214794965*uk_57 + 627647625*uk_58 + 53001355*uk_59 + 77*uk_6 + 142469671*uk_60 + 51074033*uk_61 + 193543704*uk_62 + 206984239*uk_63 + 604824075*uk_64 + 51074033*uk_65 + 18309559*uk_66 + 69383592*uk_67 + 74201897*uk_68 + 216823725*uk_69 + 225*uk_7 + 18309559*uk_70 + 262927296*uk_71 + 281186136*uk_72 + 821647800*uk_73 + 69383592*uk_74 + 300712951*uk_75 + 878706675*uk_76 + 74201897*uk_77 + 2567649375*uk_78 + 216823725*uk_79 + 19*uk_8 + 18309559*uk_80 + 166375*uk_81 + 160325*uk_82 + 57475*uk_83 + 217800*uk_84 + 232925*uk_85 + 680625*uk_86 + 57475*uk_87 + 154495*uk_88 + 55385*uk_89 + 2572416961*uk_9 + 209880*uk_90 + 224455*uk_91 + 655875*uk_92 + 55385*uk_93 + 19855*uk_94 + 75240*uk_95 + 80465*uk_96 + 235125*uk_97 + 19855*uk_98 + 285120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 278080*uk_100 + 792000*uk_101 + 186560*uk_102 + 343255*uk_103 + 977625*uk_104 + 230285*uk_105 + 2784375*uk_106 + 655875*uk_107 + 154495*uk_108 + uk_109 + 50719*uk_11 + 53*uk_110 + 64*uk_111 + 79*uk_112 + 225*uk_113 + 53*uk_114 + 2809*uk_115 + 3392*uk_116 + 4187*uk_117 + 11925*uk_118 + 2809*uk_119 + 2688107*uk_12 + 4096*uk_120 + 5056*uk_121 + 14400*uk_122 + 3392*uk_123 + 6241*uk_124 + 17775*uk_125 + 4187*uk_126 + 50625*uk_127 + 11925*uk_128 + 2809*uk_129 + 3246016*uk_13 + 148877*uk_130 + 179776*uk_131 + 221911*uk_132 + 632025*uk_133 + 148877*uk_134 + 217088*uk_135 + 267968*uk_136 + 763200*uk_137 + 179776*uk_138 + 330773*uk_139 + 4006801*uk_14 + 942075*uk_140 + 221911*uk_141 + 2683125*uk_142 + 632025*uk_143 + 148877*uk_144 + 262144*uk_145 + 323584*uk_146 + 921600*uk_147 + 217088*uk_148 + 399424*uk_149 + 11411775*uk_15 + 1137600*uk_150 + 267968*uk_151 + 3240000*uk_152 + 763200*uk_153 + 179776*uk_154 + 493039*uk_155 + 1404225*uk_156 + 330773*uk_157 + 3999375*uk_158 + 942075*uk_159 + 2688107*uk_16 + 221911*uk_160 + 11390625*uk_161 + 2683125*uk_162 + 632025*uk_163 + 148877*uk_164 + 3025*uk_17 + 55*uk_18 + 2915*uk_19 + 55*uk_2 + 3520*uk_20 + 4345*uk_21 + 12375*uk_22 + 2915*uk_23 + uk_24 + 53*uk_25 + 64*uk_26 + 79*uk_27 + 225*uk_28 + 53*uk_29 + uk_3 + 2809*uk_30 + 3392*uk_31 + 4187*uk_32 + 11925*uk_33 + 2809*uk_34 + 4096*uk_35 + 5056*uk_36 + 14400*uk_37 + 3392*uk_38 + 6241*uk_39 + 53*uk_4 + 17775*uk_40 + 4187*uk_41 + 50625*uk_42 + 11925*uk_43 + 2809*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 2572416961*uk_47 + 136338098933*uk_48 + 164634685504*uk_49 + 64*uk_5 + 203220939919*uk_50 + 578793816225*uk_51 + 136338098933*uk_52 + 153424975*uk_53 + 2789545*uk_54 + 147845885*uk_55 + 178530880*uk_56 + 220374055*uk_57 + 627647625*uk_58 + 147845885*uk_59 + 79*uk_6 + 50719*uk_60 + 2688107*uk_61 + 3246016*uk_62 + 4006801*uk_63 + 11411775*uk_64 + 2688107*uk_65 + 142469671*uk_66 + 172038848*uk_67 + 212360453*uk_68 + 604824075*uk_69 + 225*uk_7 + 142469671*uk_70 + 207745024*uk_71 + 256435264*uk_72 + 730353600*uk_73 + 172038848*uk_74 + 316537279*uk_75 + 901530225*uk_76 + 212360453*uk_77 + 2567649375*uk_78 + 604824075*uk_79 + 53*uk_8 + 142469671*uk_80 + 166375*uk_81 + 3025*uk_82 + 160325*uk_83 + 193600*uk_84 + 238975*uk_85 + 680625*uk_86 + 160325*uk_87 + 55*uk_88 + 2915*uk_89 + 2572416961*uk_9 + 3520*uk_90 + 4345*uk_91 + 12375*uk_92 + 2915*uk_93 + 154495*uk_94 + 186560*uk_95 + 230285*uk_96 + 655875*uk_97 + 154495*uk_98 + 225280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 285120*uk_100 + 792000*uk_101 + 3520*uk_102 + 360855*uk_103 + 1002375*uk_104 + 4455*uk_105 + 2784375*uk_106 + 12375*uk_107 + 55*uk_108 + 2197*uk_109 + 659347*uk_11 + 169*uk_110 + 10816*uk_111 + 13689*uk_112 + 38025*uk_113 + 169*uk_114 + 13*uk_115 + 832*uk_116 + 1053*uk_117 + 2925*uk_118 + 13*uk_119 + 50719*uk_12 + 53248*uk_120 + 67392*uk_121 + 187200*uk_122 + 832*uk_123 + 85293*uk_124 + 236925*uk_125 + 1053*uk_126 + 658125*uk_127 + 2925*uk_128 + 13*uk_129 + 3246016*uk_13 + uk_130 + 64*uk_131 + 81*uk_132 + 225*uk_133 + uk_134 + 4096*uk_135 + 5184*uk_136 + 14400*uk_137 + 64*uk_138 + 6561*uk_139 + 4108239*uk_14 + 18225*uk_140 + 81*uk_141 + 50625*uk_142 + 225*uk_143 + uk_144 + 262144*uk_145 + 331776*uk_146 + 921600*uk_147 + 4096*uk_148 + 419904*uk_149 + 11411775*uk_15 + 1166400*uk_150 + 5184*uk_151 + 3240000*uk_152 + 14400*uk_153 + 64*uk_154 + 531441*uk_155 + 1476225*uk_156 + 6561*uk_157 + 4100625*uk_158 + 18225*uk_159 + 50719*uk_16 + 81*uk_160 + 11390625*uk_161 + 50625*uk_162 + 225*uk_163 + uk_164 + 3025*uk_17 + 715*uk_18 + 55*uk_19 + 55*uk_2 + 3520*uk_20 + 4455*uk_21 + 12375*uk_22 + 55*uk_23 + 169*uk_24 + 13*uk_25 + 832*uk_26 + 1053*uk_27 + 2925*uk_28 + 13*uk_29 + 13*uk_3 + uk_30 + 64*uk_31 + 81*uk_32 + 225*uk_33 + uk_34 + 4096*uk_35 + 5184*uk_36 + 14400*uk_37 + 64*uk_38 + 6561*uk_39 + uk_4 + 18225*uk_40 + 81*uk_41 + 50625*uk_42 + 225*uk_43 + uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 33441420493*uk_47 + 2572416961*uk_48 + 164634685504*uk_49 + 64*uk_5 + 208365773841*uk_50 + 578793816225*uk_51 + 2572416961*uk_52 + 153424975*uk_53 + 36264085*uk_54 + 2789545*uk_55 + 178530880*uk_56 + 225953145*uk_57 + 627647625*uk_58 + 2789545*uk_59 + 81*uk_6 + 8571511*uk_60 + 659347*uk_61 + 42198208*uk_62 + 53407107*uk_63 + 148353075*uk_64 + 659347*uk_65 + 50719*uk_66 + 3246016*uk_67 + 4108239*uk_68 + 11411775*uk_69 + 225*uk_7 + 50719*uk_70 + 207745024*uk_71 + 262927296*uk_72 + 730353600*uk_73 + 3246016*uk_74 + 332767359*uk_75 + 924353775*uk_76 + 4108239*uk_77 + 2567649375*uk_78 + 11411775*uk_79 + uk_8 + 50719*uk_80 + 166375*uk_81 + 39325*uk_82 + 3025*uk_83 + 193600*uk_84 + 245025*uk_85 + 680625*uk_86 + 3025*uk_87 + 9295*uk_88 + 715*uk_89 + 2572416961*uk_9 + 45760*uk_90 + 57915*uk_91 + 160875*uk_92 + 715*uk_93 + 55*uk_94 + 3520*uk_95 + 4455*uk_96 + 12375*uk_97 + 55*uk_98 + 225280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 273900*uk_100 + 742500*uk_101 + 42900*uk_102 + 378895*uk_103 + 1027125*uk_104 + 59345*uk_105 + 2784375*uk_106 + 160875*uk_107 + 9295*uk_108 + 216*uk_109 + 304314*uk_11 + 468*uk_110 + 2160*uk_111 + 2988*uk_112 + 8100*uk_113 + 468*uk_114 + 1014*uk_115 + 4680*uk_116 + 6474*uk_117 + 17550*uk_118 + 1014*uk_119 + 659347*uk_12 + 21600*uk_120 + 29880*uk_121 + 81000*uk_122 + 4680*uk_123 + 41334*uk_124 + 112050*uk_125 + 6474*uk_126 + 303750*uk_127 + 17550*uk_128 + 1014*uk_129 + 3043140*uk_13 + 2197*uk_130 + 10140*uk_131 + 14027*uk_132 + 38025*uk_133 + 2197*uk_134 + 46800*uk_135 + 64740*uk_136 + 175500*uk_137 + 10140*uk_138 + 89557*uk_139 + 4209677*uk_14 + 242775*uk_140 + 14027*uk_141 + 658125*uk_142 + 38025*uk_143 + 2197*uk_144 + 216000*uk_145 + 298800*uk_146 + 810000*uk_147 + 46800*uk_148 + 413340*uk_149 + 11411775*uk_15 + 1120500*uk_150 + 64740*uk_151 + 3037500*uk_152 + 175500*uk_153 + 10140*uk_154 + 571787*uk_155 + 1550025*uk_156 + 89557*uk_157 + 4201875*uk_158 + 242775*uk_159 + 659347*uk_16 + 14027*uk_160 + 11390625*uk_161 + 658125*uk_162 + 38025*uk_163 + 2197*uk_164 + 3025*uk_17 + 330*uk_18 + 715*uk_19 + 55*uk_2 + 3300*uk_20 + 4565*uk_21 + 12375*uk_22 + 715*uk_23 + 36*uk_24 + 78*uk_25 + 360*uk_26 + 498*uk_27 + 1350*uk_28 + 78*uk_29 + 6*uk_3 + 169*uk_30 + 780*uk_31 + 1079*uk_32 + 2925*uk_33 + 169*uk_34 + 3600*uk_35 + 4980*uk_36 + 13500*uk_37 + 780*uk_38 + 6889*uk_39 + 13*uk_4 + 18675*uk_40 + 1079*uk_41 + 50625*uk_42 + 2925*uk_43 + 169*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 15434501766*uk_47 + 33441420493*uk_48 + 154345017660*uk_49 + 60*uk_5 + 213510607763*uk_50 + 578793816225*uk_51 + 33441420493*uk_52 + 153424975*uk_53 + 16737270*uk_54 + 36264085*uk_55 + 167372700*uk_56 + 231532235*uk_57 + 627647625*uk_58 + 36264085*uk_59 + 83*uk_6 + 1825884*uk_60 + 3956082*uk_61 + 18258840*uk_62 + 25258062*uk_63 + 68470650*uk_64 + 3956082*uk_65 + 8571511*uk_66 + 39560820*uk_67 + 54725801*uk_68 + 148353075*uk_69 + 225*uk_7 + 8571511*uk_70 + 182588400*uk_71 + 252580620*uk_72 + 684706500*uk_73 + 39560820*uk_74 + 349403191*uk_75 + 947177325*uk_76 + 54725801*uk_77 + 2567649375*uk_78 + 148353075*uk_79 + 13*uk_8 + 8571511*uk_80 + 166375*uk_81 + 18150*uk_82 + 39325*uk_83 + 181500*uk_84 + 251075*uk_85 + 680625*uk_86 + 39325*uk_87 + 1980*uk_88 + 4290*uk_89 + 2572416961*uk_9 + 19800*uk_90 + 27390*uk_91 + 74250*uk_92 + 4290*uk_93 + 9295*uk_94 + 42900*uk_95 + 59345*uk_96 + 160875*uk_97 + 9295*uk_98 + 198000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 280500*uk_100 + 742500*uk_101 + 19800*uk_102 + 397375*uk_103 + 1051875*uk_104 + 28050*uk_105 + 2784375*uk_106 + 74250*uk_107 + 1980*uk_108 + 205379*uk_109 + 2992421*uk_11 + 20886*uk_110 + 208860*uk_111 + 295885*uk_112 + 783225*uk_113 + 20886*uk_114 + 2124*uk_115 + 21240*uk_116 + 30090*uk_117 + 79650*uk_118 + 2124*uk_119 + 304314*uk_12 + 212400*uk_120 + 300900*uk_121 + 796500*uk_122 + 21240*uk_123 + 426275*uk_124 + 1128375*uk_125 + 30090*uk_126 + 2986875*uk_127 + 79650*uk_128 + 2124*uk_129 + 3043140*uk_13 + 216*uk_130 + 2160*uk_131 + 3060*uk_132 + 8100*uk_133 + 216*uk_134 + 21600*uk_135 + 30600*uk_136 + 81000*uk_137 + 2160*uk_138 + 43350*uk_139 + 4311115*uk_14 + 114750*uk_140 + 3060*uk_141 + 303750*uk_142 + 8100*uk_143 + 216*uk_144 + 216000*uk_145 + 306000*uk_146 + 810000*uk_147 + 21600*uk_148 + 433500*uk_149 + 11411775*uk_15 + 1147500*uk_150 + 30600*uk_151 + 3037500*uk_152 + 81000*uk_153 + 2160*uk_154 + 614125*uk_155 + 1625625*uk_156 + 43350*uk_157 + 4303125*uk_158 + 114750*uk_159 + 304314*uk_16 + 3060*uk_160 + 11390625*uk_161 + 303750*uk_162 + 8100*uk_163 + 216*uk_164 + 3025*uk_17 + 3245*uk_18 + 330*uk_19 + 55*uk_2 + 3300*uk_20 + 4675*uk_21 + 12375*uk_22 + 330*uk_23 + 3481*uk_24 + 354*uk_25 + 3540*uk_26 + 5015*uk_27 + 13275*uk_28 + 354*uk_29 + 59*uk_3 + 36*uk_30 + 360*uk_31 + 510*uk_32 + 1350*uk_33 + 36*uk_34 + 3600*uk_35 + 5100*uk_36 + 13500*uk_37 + 360*uk_38 + 7225*uk_39 + 6*uk_4 + 19125*uk_40 + 510*uk_41 + 50625*uk_42 + 1350*uk_43 + 36*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 151772600699*uk_47 + 15434501766*uk_48 + 154345017660*uk_49 + 60*uk_5 + 218655441685*uk_50 + 578793816225*uk_51 + 15434501766*uk_52 + 153424975*uk_53 + 164583155*uk_54 + 16737270*uk_55 + 167372700*uk_56 + 237111325*uk_57 + 627647625*uk_58 + 16737270*uk_59 + 85*uk_6 + 176552839*uk_60 + 17954526*uk_61 + 179545260*uk_62 + 254355785*uk_63 + 673294725*uk_64 + 17954526*uk_65 + 1825884*uk_66 + 18258840*uk_67 + 25866690*uk_68 + 68470650*uk_69 + 225*uk_7 + 1825884*uk_70 + 182588400*uk_71 + 258666900*uk_72 + 684706500*uk_73 + 18258840*uk_74 + 366444775*uk_75 + 970000875*uk_76 + 25866690*uk_77 + 2567649375*uk_78 + 68470650*uk_79 + 6*uk_8 + 1825884*uk_80 + 166375*uk_81 + 178475*uk_82 + 18150*uk_83 + 181500*uk_84 + 257125*uk_85 + 680625*uk_86 + 18150*uk_87 + 191455*uk_88 + 19470*uk_89 + 2572416961*uk_9 + 194700*uk_90 + 275825*uk_91 + 730125*uk_92 + 19470*uk_93 + 1980*uk_94 + 19800*uk_95 + 28050*uk_96 + 74250*uk_97 + 1980*uk_98 + 198000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 267960*uk_100 + 693000*uk_101 + 181720*uk_102 + 416295*uk_103 + 1076625*uk_104 + 282315*uk_105 + 2784375*uk_106 + 730125*uk_107 + 191455*uk_108 + 614125*uk_109 + 4311115*uk_11 + 426275*uk_110 + 404600*uk_111 + 628575*uk_112 + 1625625*uk_113 + 426275*uk_114 + 295885*uk_115 + 280840*uk_116 + 436305*uk_117 + 1128375*uk_118 + 295885*uk_119 + 2992421*uk_12 + 266560*uk_120 + 414120*uk_121 + 1071000*uk_122 + 280840*uk_123 + 643365*uk_124 + 1663875*uk_125 + 436305*uk_126 + 4303125*uk_127 + 1128375*uk_128 + 295885*uk_129 + 2840264*uk_13 + 205379*uk_130 + 194936*uk_131 + 302847*uk_132 + 783225*uk_133 + 205379*uk_134 + 185024*uk_135 + 287448*uk_136 + 743400*uk_137 + 194936*uk_138 + 446571*uk_139 + 4412553*uk_14 + 1154925*uk_140 + 302847*uk_141 + 2986875*uk_142 + 783225*uk_143 + 205379*uk_144 + 175616*uk_145 + 272832*uk_146 + 705600*uk_147 + 185024*uk_148 + 423864*uk_149 + 11411775*uk_15 + 1096200*uk_150 + 287448*uk_151 + 2835000*uk_152 + 743400*uk_153 + 194936*uk_154 + 658503*uk_155 + 1703025*uk_156 + 446571*uk_157 + 4404375*uk_158 + 1154925*uk_159 + 2992421*uk_16 + 302847*uk_160 + 11390625*uk_161 + 2986875*uk_162 + 783225*uk_163 + 205379*uk_164 + 3025*uk_17 + 4675*uk_18 + 3245*uk_19 + 55*uk_2 + 3080*uk_20 + 4785*uk_21 + 12375*uk_22 + 3245*uk_23 + 7225*uk_24 + 5015*uk_25 + 4760*uk_26 + 7395*uk_27 + 19125*uk_28 + 5015*uk_29 + 85*uk_3 + 3481*uk_30 + 3304*uk_31 + 5133*uk_32 + 13275*uk_33 + 3481*uk_34 + 3136*uk_35 + 4872*uk_36 + 12600*uk_37 + 3304*uk_38 + 7569*uk_39 + 59*uk_4 + 19575*uk_40 + 5133*uk_41 + 50625*uk_42 + 13275*uk_43 + 3481*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 218655441685*uk_47 + 151772600699*uk_48 + 144055349816*uk_49 + 56*uk_5 + 223800275607*uk_50 + 578793816225*uk_51 + 151772600699*uk_52 + 153424975*uk_53 + 237111325*uk_54 + 164583155*uk_55 + 156214520*uk_56 + 242690415*uk_57 + 627647625*uk_58 + 164583155*uk_59 + 87*uk_6 + 366444775*uk_60 + 254355785*uk_61 + 241422440*uk_62 + 375067005*uk_63 + 970000875*uk_64 + 254355785*uk_65 + 176552839*uk_66 + 167575576*uk_67 + 260340627*uk_68 + 673294725*uk_69 + 225*uk_7 + 176552839*uk_70 + 159054784*uk_71 + 247102968*uk_72 + 639059400*uk_73 + 167575576*uk_74 + 383892111*uk_75 + 992824425*uk_76 + 260340627*uk_77 + 2567649375*uk_78 + 673294725*uk_79 + 59*uk_8 + 176552839*uk_80 + 166375*uk_81 + 257125*uk_82 + 178475*uk_83 + 169400*uk_84 + 263175*uk_85 + 680625*uk_86 + 178475*uk_87 + 397375*uk_88 + 275825*uk_89 + 2572416961*uk_9 + 261800*uk_90 + 406725*uk_91 + 1051875*uk_92 + 275825*uk_93 + 191455*uk_94 + 181720*uk_95 + 282315*uk_96 + 730125*uk_97 + 191455*uk_98 + 172480*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 254540*uk_100 + 643500*uk_101 + 243100*uk_102 + 435655*uk_103 + 1101375*uk_104 + 416075*uk_105 + 2784375*uk_106 + 1051875*uk_107 + 397375*uk_108 + 474552*uk_109 + 3956082*uk_11 + 517140*uk_110 + 316368*uk_111 + 541476*uk_112 + 1368900*uk_113 + 517140*uk_114 + 563550*uk_115 + 344760*uk_116 + 590070*uk_117 + 1491750*uk_118 + 563550*uk_119 + 4311115*uk_12 + 210912*uk_120 + 360984*uk_121 + 912600*uk_122 + 344760*uk_123 + 617838*uk_124 + 1561950*uk_125 + 590070*uk_126 + 3948750*uk_127 + 1491750*uk_128 + 563550*uk_129 + 2637388*uk_13 + 614125*uk_130 + 375700*uk_131 + 643025*uk_132 + 1625625*uk_133 + 614125*uk_134 + 229840*uk_135 + 393380*uk_136 + 994500*uk_137 + 375700*uk_138 + 673285*uk_139 + 4513991*uk_14 + 1702125*uk_140 + 643025*uk_141 + 4303125*uk_142 + 1625625*uk_143 + 614125*uk_144 + 140608*uk_145 + 240656*uk_146 + 608400*uk_147 + 229840*uk_148 + 411892*uk_149 + 11411775*uk_15 + 1041300*uk_150 + 393380*uk_151 + 2632500*uk_152 + 994500*uk_153 + 375700*uk_154 + 704969*uk_155 + 1782225*uk_156 + 673285*uk_157 + 4505625*uk_158 + 1702125*uk_159 + 4311115*uk_16 + 643025*uk_160 + 11390625*uk_161 + 4303125*uk_162 + 1625625*uk_163 + 614125*uk_164 + 3025*uk_17 + 4290*uk_18 + 4675*uk_19 + 55*uk_2 + 2860*uk_20 + 4895*uk_21 + 12375*uk_22 + 4675*uk_23 + 6084*uk_24 + 6630*uk_25 + 4056*uk_26 + 6942*uk_27 + 17550*uk_28 + 6630*uk_29 + 78*uk_3 + 7225*uk_30 + 4420*uk_31 + 7565*uk_32 + 19125*uk_33 + 7225*uk_34 + 2704*uk_35 + 4628*uk_36 + 11700*uk_37 + 4420*uk_38 + 7921*uk_39 + 85*uk_4 + 20025*uk_40 + 7565*uk_41 + 50625*uk_42 + 19125*uk_43 + 7225*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 200648522958*uk_47 + 218655441685*uk_48 + 133765681972*uk_49 + 52*uk_5 + 228945109529*uk_50 + 578793816225*uk_51 + 218655441685*uk_52 + 153424975*uk_53 + 217584510*uk_54 + 237111325*uk_55 + 145056340*uk_56 + 248269505*uk_57 + 627647625*uk_58 + 237111325*uk_59 + 89*uk_6 + 308574396*uk_60 + 336266970*uk_61 + 205716264*uk_62 + 352091298*uk_63 + 890118450*uk_64 + 336266970*uk_65 + 366444775*uk_66 + 224177980*uk_67 + 383689235*uk_68 + 970000875*uk_69 + 225*uk_7 + 366444775*uk_70 + 137144176*uk_71 + 234727532*uk_72 + 593412300*uk_73 + 224177980*uk_74 + 401745199*uk_75 + 1015647975*uk_76 + 383689235*uk_77 + 2567649375*uk_78 + 970000875*uk_79 + 85*uk_8 + 366444775*uk_80 + 166375*uk_81 + 235950*uk_82 + 257125*uk_83 + 157300*uk_84 + 269225*uk_85 + 680625*uk_86 + 257125*uk_87 + 334620*uk_88 + 364650*uk_89 + 2572416961*uk_9 + 223080*uk_90 + 381810*uk_91 + 965250*uk_92 + 364650*uk_93 + 397375*uk_94 + 243100*uk_95 + 416075*uk_96 + 1051875*uk_97 + 397375*uk_98 + 148720*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 240240*uk_100 + 594000*uk_101 + 205920*uk_102 + 455455*uk_103 + 1126125*uk_104 + 390390*uk_105 + 2784375*uk_106 + 965250*uk_107 + 334620*uk_108 + 32768*uk_109 + 1623008*uk_11 + 79872*uk_110 + 49152*uk_111 + 93184*uk_112 + 230400*uk_113 + 79872*uk_114 + 194688*uk_115 + 119808*uk_116 + 227136*uk_117 + 561600*uk_118 + 194688*uk_119 + 3956082*uk_12 + 73728*uk_120 + 139776*uk_121 + 345600*uk_122 + 119808*uk_123 + 264992*uk_124 + 655200*uk_125 + 227136*uk_126 + 1620000*uk_127 + 561600*uk_128 + 194688*uk_129 + 2434512*uk_13 + 474552*uk_130 + 292032*uk_131 + 553644*uk_132 + 1368900*uk_133 + 474552*uk_134 + 179712*uk_135 + 340704*uk_136 + 842400*uk_137 + 292032*uk_138 + 645918*uk_139 + 4615429*uk_14 + 1597050*uk_140 + 553644*uk_141 + 3948750*uk_142 + 1368900*uk_143 + 474552*uk_144 + 110592*uk_145 + 209664*uk_146 + 518400*uk_147 + 179712*uk_148 + 397488*uk_149 + 11411775*uk_15 + 982800*uk_150 + 340704*uk_151 + 2430000*uk_152 + 842400*uk_153 + 292032*uk_154 + 753571*uk_155 + 1863225*uk_156 + 645918*uk_157 + 4606875*uk_158 + 1597050*uk_159 + 3956082*uk_16 + 553644*uk_160 + 11390625*uk_161 + 3948750*uk_162 + 1368900*uk_163 + 474552*uk_164 + 3025*uk_17 + 1760*uk_18 + 4290*uk_19 + 55*uk_2 + 2640*uk_20 + 5005*uk_21 + 12375*uk_22 + 4290*uk_23 + 1024*uk_24 + 2496*uk_25 + 1536*uk_26 + 2912*uk_27 + 7200*uk_28 + 2496*uk_29 + 32*uk_3 + 6084*uk_30 + 3744*uk_31 + 7098*uk_32 + 17550*uk_33 + 6084*uk_34 + 2304*uk_35 + 4368*uk_36 + 10800*uk_37 + 3744*uk_38 + 8281*uk_39 + 78*uk_4 + 20475*uk_40 + 7098*uk_41 + 50625*uk_42 + 17550*uk_43 + 6084*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 82317342752*uk_47 + 200648522958*uk_48 + 123476014128*uk_49 + 48*uk_5 + 234089943451*uk_50 + 578793816225*uk_51 + 200648522958*uk_52 + 153424975*uk_53 + 89265440*uk_54 + 217584510*uk_55 + 133898160*uk_56 + 253848595*uk_57 + 627647625*uk_58 + 217584510*uk_59 + 91*uk_6 + 51936256*uk_60 + 126594624*uk_61 + 77904384*uk_62 + 147693728*uk_63 + 365176800*uk_64 + 126594624*uk_65 + 308574396*uk_66 + 189891936*uk_67 + 360003462*uk_68 + 890118450*uk_69 + 225*uk_7 + 308574396*uk_70 + 116856576*uk_71 + 221540592*uk_72 + 547765200*uk_73 + 189891936*uk_74 + 420004039*uk_75 + 1038471525*uk_76 + 360003462*uk_77 + 2567649375*uk_78 + 890118450*uk_79 + 78*uk_8 + 308574396*uk_80 + 166375*uk_81 + 96800*uk_82 + 235950*uk_83 + 145200*uk_84 + 275275*uk_85 + 680625*uk_86 + 235950*uk_87 + 56320*uk_88 + 137280*uk_89 + 2572416961*uk_9 + 84480*uk_90 + 160160*uk_91 + 396000*uk_92 + 137280*uk_93 + 334620*uk_94 + 205920*uk_95 + 390390*uk_96 + 965250*uk_97 + 334620*uk_98 + 126720*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 245520*uk_100 + 594000*uk_101 + 84480*uk_102 + 475695*uk_103 + 1150875*uk_104 + 163680*uk_105 + 2784375*uk_106 + 396000*uk_107 + 56320*uk_108 + 39304*uk_109 + 1724446*uk_11 + 36992*uk_110 + 55488*uk_111 + 107508*uk_112 + 260100*uk_113 + 36992*uk_114 + 34816*uk_115 + 52224*uk_116 + 101184*uk_117 + 244800*uk_118 + 34816*uk_119 + 1623008*uk_12 + 78336*uk_120 + 151776*uk_121 + 367200*uk_122 + 52224*uk_123 + 294066*uk_124 + 711450*uk_125 + 101184*uk_126 + 1721250*uk_127 + 244800*uk_128 + 34816*uk_129 + 2434512*uk_13 + 32768*uk_130 + 49152*uk_131 + 95232*uk_132 + 230400*uk_133 + 32768*uk_134 + 73728*uk_135 + 142848*uk_136 + 345600*uk_137 + 49152*uk_138 + 276768*uk_139 + 4716867*uk_14 + 669600*uk_140 + 95232*uk_141 + 1620000*uk_142 + 230400*uk_143 + 32768*uk_144 + 110592*uk_145 + 214272*uk_146 + 518400*uk_147 + 73728*uk_148 + 415152*uk_149 + 11411775*uk_15 + 1004400*uk_150 + 142848*uk_151 + 2430000*uk_152 + 345600*uk_153 + 49152*uk_154 + 804357*uk_155 + 1946025*uk_156 + 276768*uk_157 + 4708125*uk_158 + 669600*uk_159 + 1623008*uk_16 + 95232*uk_160 + 11390625*uk_161 + 1620000*uk_162 + 230400*uk_163 + 32768*uk_164 + 3025*uk_17 + 1870*uk_18 + 1760*uk_19 + 55*uk_2 + 2640*uk_20 + 5115*uk_21 + 12375*uk_22 + 1760*uk_23 + 1156*uk_24 + 1088*uk_25 + 1632*uk_26 + 3162*uk_27 + 7650*uk_28 + 1088*uk_29 + 34*uk_3 + 1024*uk_30 + 1536*uk_31 + 2976*uk_32 + 7200*uk_33 + 1024*uk_34 + 2304*uk_35 + 4464*uk_36 + 10800*uk_37 + 1536*uk_38 + 8649*uk_39 + 32*uk_4 + 20925*uk_40 + 2976*uk_41 + 50625*uk_42 + 7200*uk_43 + 1024*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 87462176674*uk_47 + 82317342752*uk_48 + 123476014128*uk_49 + 48*uk_5 + 239234777373*uk_50 + 578793816225*uk_51 + 82317342752*uk_52 + 153424975*uk_53 + 94844530*uk_54 + 89265440*uk_55 + 133898160*uk_56 + 259427685*uk_57 + 627647625*uk_58 + 89265440*uk_59 + 93*uk_6 + 58631164*uk_60 + 55182272*uk_61 + 82773408*uk_62 + 160373478*uk_63 + 388000350*uk_64 + 55182272*uk_65 + 51936256*uk_66 + 77904384*uk_67 + 150939744*uk_68 + 365176800*uk_69 + 225*uk_7 + 51936256*uk_70 + 116856576*uk_71 + 226409616*uk_72 + 547765200*uk_73 + 77904384*uk_74 + 438668631*uk_75 + 1061295075*uk_76 + 150939744*uk_77 + 2567649375*uk_78 + 365176800*uk_79 + 32*uk_8 + 51936256*uk_80 + 166375*uk_81 + 102850*uk_82 + 96800*uk_83 + 145200*uk_84 + 281325*uk_85 + 680625*uk_86 + 96800*uk_87 + 63580*uk_88 + 59840*uk_89 + 2572416961*uk_9 + 89760*uk_90 + 173910*uk_91 + 420750*uk_92 + 59840*uk_93 + 56320*uk_94 + 84480*uk_95 + 163680*uk_96 + 396000*uk_97 + 56320*uk_98 + 126720*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 250800*uk_100 + 594000*uk_101 + 89760*uk_102 + 496375*uk_103 + 1175625*uk_104 + 177650*uk_105 + 2784375*uk_106 + 420750*uk_107 + 63580*uk_108 + 592704*uk_109 + 4260396*uk_11 + 239904*uk_110 + 338688*uk_111 + 670320*uk_112 + 1587600*uk_113 + 239904*uk_114 + 97104*uk_115 + 137088*uk_116 + 271320*uk_117 + 642600*uk_118 + 97104*uk_119 + 1724446*uk_12 + 193536*uk_120 + 383040*uk_121 + 907200*uk_122 + 137088*uk_123 + 758100*uk_124 + 1795500*uk_125 + 271320*uk_126 + 4252500*uk_127 + 642600*uk_128 + 97104*uk_129 + 2434512*uk_13 + 39304*uk_130 + 55488*uk_131 + 109820*uk_132 + 260100*uk_133 + 39304*uk_134 + 78336*uk_135 + 155040*uk_136 + 367200*uk_137 + 55488*uk_138 + 306850*uk_139 + 4818305*uk_14 + 726750*uk_140 + 109820*uk_141 + 1721250*uk_142 + 260100*uk_143 + 39304*uk_144 + 110592*uk_145 + 218880*uk_146 + 518400*uk_147 + 78336*uk_148 + 433200*uk_149 + 11411775*uk_15 + 1026000*uk_150 + 155040*uk_151 + 2430000*uk_152 + 367200*uk_153 + 55488*uk_154 + 857375*uk_155 + 2030625*uk_156 + 306850*uk_157 + 4809375*uk_158 + 726750*uk_159 + 1724446*uk_16 + 109820*uk_160 + 11390625*uk_161 + 1721250*uk_162 + 260100*uk_163 + 39304*uk_164 + 3025*uk_17 + 4620*uk_18 + 1870*uk_19 + 55*uk_2 + 2640*uk_20 + 5225*uk_21 + 12375*uk_22 + 1870*uk_23 + 7056*uk_24 + 2856*uk_25 + 4032*uk_26 + 7980*uk_27 + 18900*uk_28 + 2856*uk_29 + 84*uk_3 + 1156*uk_30 + 1632*uk_31 + 3230*uk_32 + 7650*uk_33 + 1156*uk_34 + 2304*uk_35 + 4560*uk_36 + 10800*uk_37 + 1632*uk_38 + 9025*uk_39 + 34*uk_4 + 21375*uk_40 + 3230*uk_41 + 50625*uk_42 + 7650*uk_43 + 1156*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 216083024724*uk_47 + 87462176674*uk_48 + 123476014128*uk_49 + 48*uk_5 + 244379611295*uk_50 + 578793816225*uk_51 + 87462176674*uk_52 + 153424975*uk_53 + 234321780*uk_54 + 94844530*uk_55 + 133898160*uk_56 + 265006775*uk_57 + 627647625*uk_58 + 94844530*uk_59 + 95*uk_6 + 357873264*uk_60 + 144853464*uk_61 + 204499008*uk_62 + 404737620*uk_63 + 958589100*uk_64 + 144853464*uk_65 + 58631164*uk_66 + 82773408*uk_67 + 163822370*uk_68 + 388000350*uk_69 + 225*uk_7 + 58631164*uk_70 + 116856576*uk_71 + 231278640*uk_72 + 547765200*uk_73 + 82773408*uk_74 + 457738975*uk_75 + 1084118625*uk_76 + 163822370*uk_77 + 2567649375*uk_78 + 388000350*uk_79 + 34*uk_8 + 58631164*uk_80 + 166375*uk_81 + 254100*uk_82 + 102850*uk_83 + 145200*uk_84 + 287375*uk_85 + 680625*uk_86 + 102850*uk_87 + 388080*uk_88 + 157080*uk_89 + 2572416961*uk_9 + 221760*uk_90 + 438900*uk_91 + 1039500*uk_92 + 157080*uk_93 + 63580*uk_94 + 89760*uk_95 + 177650*uk_96 + 420750*uk_97 + 63580*uk_98 + 126720*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 234740*uk_100 + 544500*uk_101 + 203280*uk_102 + 517495*uk_103 + 1200375*uk_104 + 448140*uk_105 + 2784375*uk_106 + 1039500*uk_107 + 388080*uk_108 + 614125*uk_109 + 4311115*uk_11 + 606900*uk_110 + 317900*uk_111 + 700825*uk_112 + 1625625*uk_113 + 606900*uk_114 + 599760*uk_115 + 314160*uk_116 + 692580*uk_117 + 1606500*uk_118 + 599760*uk_119 + 4260396*uk_12 + 164560*uk_120 + 362780*uk_121 + 841500*uk_122 + 314160*uk_123 + 799765*uk_124 + 1855125*uk_125 + 692580*uk_126 + 4303125*uk_127 + 1606500*uk_128 + 599760*uk_129 + 2231636*uk_13 + 592704*uk_130 + 310464*uk_131 + 684432*uk_132 + 1587600*uk_133 + 592704*uk_134 + 162624*uk_135 + 358512*uk_136 + 831600*uk_137 + 310464*uk_138 + 790356*uk_139 + 4919743*uk_14 + 1833300*uk_140 + 684432*uk_141 + 4252500*uk_142 + 1587600*uk_143 + 592704*uk_144 + 85184*uk_145 + 187792*uk_146 + 435600*uk_147 + 162624*uk_148 + 413996*uk_149 + 11411775*uk_15 + 960300*uk_150 + 358512*uk_151 + 2227500*uk_152 + 831600*uk_153 + 310464*uk_154 + 912673*uk_155 + 2117025*uk_156 + 790356*uk_157 + 4910625*uk_158 + 1833300*uk_159 + 4260396*uk_16 + 684432*uk_160 + 11390625*uk_161 + 4252500*uk_162 + 1587600*uk_163 + 592704*uk_164 + 3025*uk_17 + 4675*uk_18 + 4620*uk_19 + 55*uk_2 + 2420*uk_20 + 5335*uk_21 + 12375*uk_22 + 4620*uk_23 + 7225*uk_24 + 7140*uk_25 + 3740*uk_26 + 8245*uk_27 + 19125*uk_28 + 7140*uk_29 + 85*uk_3 + 7056*uk_30 + 3696*uk_31 + 8148*uk_32 + 18900*uk_33 + 7056*uk_34 + 1936*uk_35 + 4268*uk_36 + 9900*uk_37 + 3696*uk_38 + 9409*uk_39 + 84*uk_4 + 21825*uk_40 + 8148*uk_41 + 50625*uk_42 + 18900*uk_43 + 7056*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 218655441685*uk_47 + 216083024724*uk_48 + 113186346284*uk_49 + 44*uk_5 + 249524445217*uk_50 + 578793816225*uk_51 + 216083024724*uk_52 + 153424975*uk_53 + 237111325*uk_54 + 234321780*uk_55 + 122739980*uk_56 + 270585865*uk_57 + 627647625*uk_58 + 234321780*uk_59 + 97*uk_6 + 366444775*uk_60 + 362133660*uk_61 + 189689060*uk_62 + 418178155*uk_63 + 970000875*uk_64 + 362133660*uk_65 + 357873264*uk_66 + 187457424*uk_67 + 413258412*uk_68 + 958589100*uk_69 + 225*uk_7 + 357873264*uk_70 + 98191984*uk_71 + 216468692*uk_72 + 502118100*uk_73 + 187457424*uk_74 + 477215071*uk_75 + 1106942175*uk_76 + 413258412*uk_77 + 2567649375*uk_78 + 958589100*uk_79 + 84*uk_8 + 357873264*uk_80 + 166375*uk_81 + 257125*uk_82 + 254100*uk_83 + 133100*uk_84 + 293425*uk_85 + 680625*uk_86 + 254100*uk_87 + 397375*uk_88 + 392700*uk_89 + 2572416961*uk_9 + 205700*uk_90 + 453475*uk_91 + 1051875*uk_92 + 392700*uk_93 + 388080*uk_94 + 203280*uk_95 + 448140*uk_96 + 1039500*uk_97 + 388080*uk_98 + 106480*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 217800*uk_100 + 495000*uk_101 + 187000*uk_102 + 539055*uk_103 + 1225125*uk_104 + 462825*uk_105 + 2784375*uk_106 + 1051875*uk_107 + 397375*uk_108 + 29791*uk_109 + 1572289*uk_11 + 81685*uk_110 + 38440*uk_111 + 95139*uk_112 + 216225*uk_113 + 81685*uk_114 + 223975*uk_115 + 105400*uk_116 + 260865*uk_117 + 592875*uk_118 + 223975*uk_119 + 4311115*uk_12 + 49600*uk_120 + 122760*uk_121 + 279000*uk_122 + 105400*uk_123 + 303831*uk_124 + 690525*uk_125 + 260865*uk_126 + 1569375*uk_127 + 592875*uk_128 + 223975*uk_129 + 2028760*uk_13 + 614125*uk_130 + 289000*uk_131 + 715275*uk_132 + 1625625*uk_133 + 614125*uk_134 + 136000*uk_135 + 336600*uk_136 + 765000*uk_137 + 289000*uk_138 + 833085*uk_139 + 5021181*uk_14 + 1893375*uk_140 + 715275*uk_141 + 4303125*uk_142 + 1625625*uk_143 + 614125*uk_144 + 64000*uk_145 + 158400*uk_146 + 360000*uk_147 + 136000*uk_148 + 392040*uk_149 + 11411775*uk_15 + 891000*uk_150 + 336600*uk_151 + 2025000*uk_152 + 765000*uk_153 + 289000*uk_154 + 970299*uk_155 + 2205225*uk_156 + 833085*uk_157 + 5011875*uk_158 + 1893375*uk_159 + 4311115*uk_16 + 715275*uk_160 + 11390625*uk_161 + 4303125*uk_162 + 1625625*uk_163 + 614125*uk_164 + 3025*uk_17 + 1705*uk_18 + 4675*uk_19 + 55*uk_2 + 2200*uk_20 + 5445*uk_21 + 12375*uk_22 + 4675*uk_23 + 961*uk_24 + 2635*uk_25 + 1240*uk_26 + 3069*uk_27 + 6975*uk_28 + 2635*uk_29 + 31*uk_3 + 7225*uk_30 + 3400*uk_31 + 8415*uk_32 + 19125*uk_33 + 7225*uk_34 + 1600*uk_35 + 3960*uk_36 + 9000*uk_37 + 3400*uk_38 + 9801*uk_39 + 85*uk_4 + 22275*uk_40 + 8415*uk_41 + 50625*uk_42 + 19125*uk_43 + 7225*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 79744925791*uk_47 + 218655441685*uk_48 + 102896678440*uk_49 + 40*uk_5 + 254669279139*uk_50 + 578793816225*uk_51 + 218655441685*uk_52 + 153424975*uk_53 + 86475895*uk_54 + 237111325*uk_55 + 111581800*uk_56 + 276164955*uk_57 + 627647625*uk_58 + 237111325*uk_59 + 99*uk_6 + 48740959*uk_60 + 133644565*uk_61 + 62891560*uk_62 + 155656611*uk_63 + 353765025*uk_64 + 133644565*uk_65 + 366444775*uk_66 + 172444600*uk_67 + 426800385*uk_68 + 970000875*uk_69 + 225*uk_7 + 366444775*uk_70 + 81150400*uk_71 + 200847240*uk_72 + 456471000*uk_73 + 172444600*uk_74 + 497096919*uk_75 + 1129765725*uk_76 + 426800385*uk_77 + 2567649375*uk_78 + 970000875*uk_79 + 85*uk_8 + 366444775*uk_80 + 166375*uk_81 + 93775*uk_82 + 257125*uk_83 + 121000*uk_84 + 299475*uk_85 + 680625*uk_86 + 257125*uk_87 + 52855*uk_88 + 144925*uk_89 + 2572416961*uk_9 + 68200*uk_90 + 168795*uk_91 + 383625*uk_92 + 144925*uk_93 + 397375*uk_94 + 187000*uk_95 + 462825*uk_96 + 1051875*uk_97 + 397375*uk_98 + 88000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 222200*uk_100 + 495000*uk_101 + 68200*uk_102 + 561055*uk_103 + 1249875*uk_104 + 172205*uk_105 + 2784375*uk_106 + 383625*uk_107 + 52855*uk_108 + 4913*uk_109 + 862223*uk_11 + 8959*uk_110 + 11560*uk_111 + 29189*uk_112 + 65025*uk_113 + 8959*uk_114 + 16337*uk_115 + 21080*uk_116 + 53227*uk_117 + 118575*uk_118 + 16337*uk_119 + 1572289*uk_12 + 27200*uk_120 + 68680*uk_121 + 153000*uk_122 + 21080*uk_123 + 173417*uk_124 + 386325*uk_125 + 53227*uk_126 + 860625*uk_127 + 118575*uk_128 + 16337*uk_129 + 2028760*uk_13 + 29791*uk_130 + 38440*uk_131 + 97061*uk_132 + 216225*uk_133 + 29791*uk_134 + 49600*uk_135 + 125240*uk_136 + 279000*uk_137 + 38440*uk_138 + 316231*uk_139 + 5122619*uk_14 + 704475*uk_140 + 97061*uk_141 + 1569375*uk_142 + 216225*uk_143 + 29791*uk_144 + 64000*uk_145 + 161600*uk_146 + 360000*uk_147 + 49600*uk_148 + 408040*uk_149 + 11411775*uk_15 + 909000*uk_150 + 125240*uk_151 + 2025000*uk_152 + 279000*uk_153 + 38440*uk_154 + 1030301*uk_155 + 2295225*uk_156 + 316231*uk_157 + 5113125*uk_158 + 704475*uk_159 + 1572289*uk_16 + 97061*uk_160 + 11390625*uk_161 + 1569375*uk_162 + 216225*uk_163 + 29791*uk_164 + 3025*uk_17 + 935*uk_18 + 1705*uk_19 + 55*uk_2 + 2200*uk_20 + 5555*uk_21 + 12375*uk_22 + 1705*uk_23 + 289*uk_24 + 527*uk_25 + 680*uk_26 + 1717*uk_27 + 3825*uk_28 + 527*uk_29 + 17*uk_3 + 961*uk_30 + 1240*uk_31 + 3131*uk_32 + 6975*uk_33 + 961*uk_34 + 1600*uk_35 + 4040*uk_36 + 9000*uk_37 + 1240*uk_38 + 10201*uk_39 + 31*uk_4 + 22725*uk_40 + 3131*uk_41 + 50625*uk_42 + 6975*uk_43 + 961*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 43731088337*uk_47 + 79744925791*uk_48 + 102896678440*uk_49 + 40*uk_5 + 259814113061*uk_50 + 578793816225*uk_51 + 79744925791*uk_52 + 153424975*uk_53 + 47422265*uk_54 + 86475895*uk_55 + 111581800*uk_56 + 281744045*uk_57 + 627647625*uk_58 + 86475895*uk_59 + 101*uk_6 + 14657791*uk_60 + 26728913*uk_61 + 34488920*uk_62 + 87084523*uk_63 + 194000175*uk_64 + 26728913*uk_65 + 48740959*uk_66 + 62891560*uk_67 + 158801189*uk_68 + 353765025*uk_69 + 225*uk_7 + 48740959*uk_70 + 81150400*uk_71 + 204904760*uk_72 + 456471000*uk_73 + 62891560*uk_74 + 517384519*uk_75 + 1152589275*uk_76 + 158801189*uk_77 + 2567649375*uk_78 + 353765025*uk_79 + 31*uk_8 + 48740959*uk_80 + 166375*uk_81 + 51425*uk_82 + 93775*uk_83 + 121000*uk_84 + 305525*uk_85 + 680625*uk_86 + 93775*uk_87 + 15895*uk_88 + 28985*uk_89 + 2572416961*uk_9 + 37400*uk_90 + 94435*uk_91 + 210375*uk_92 + 28985*uk_93 + 52855*uk_94 + 68200*uk_95 + 172205*uk_96 + 383625*uk_97 + 52855*uk_98 + 88000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 226600*uk_100 + 495000*uk_101 + 37400*uk_102 + 583495*uk_103 + 1274625*uk_104 + 96305*uk_105 + 2784375*uk_106 + 210375*uk_107 + 15895*uk_108 + 79507*uk_109 + 2180917*uk_11 + 31433*uk_110 + 73960*uk_111 + 190447*uk_112 + 416025*uk_113 + 31433*uk_114 + 12427*uk_115 + 29240*uk_116 + 75293*uk_117 + 164475*uk_118 + 12427*uk_119 + 862223*uk_12 + 68800*uk_120 + 177160*uk_121 + 387000*uk_122 + 29240*uk_123 + 456187*uk_124 + 996525*uk_125 + 75293*uk_126 + 2176875*uk_127 + 164475*uk_128 + 12427*uk_129 + 2028760*uk_13 + 4913*uk_130 + 11560*uk_131 + 29767*uk_132 + 65025*uk_133 + 4913*uk_134 + 27200*uk_135 + 70040*uk_136 + 153000*uk_137 + 11560*uk_138 + 180353*uk_139 + 5224057*uk_14 + 393975*uk_140 + 29767*uk_141 + 860625*uk_142 + 65025*uk_143 + 4913*uk_144 + 64000*uk_145 + 164800*uk_146 + 360000*uk_147 + 27200*uk_148 + 424360*uk_149 + 11411775*uk_15 + 927000*uk_150 + 70040*uk_151 + 2025000*uk_152 + 153000*uk_153 + 11560*uk_154 + 1092727*uk_155 + 2387025*uk_156 + 180353*uk_157 + 5214375*uk_158 + 393975*uk_159 + 862223*uk_16 + 29767*uk_160 + 11390625*uk_161 + 860625*uk_162 + 65025*uk_163 + 4913*uk_164 + 3025*uk_17 + 2365*uk_18 + 935*uk_19 + 55*uk_2 + 2200*uk_20 + 5665*uk_21 + 12375*uk_22 + 935*uk_23 + 1849*uk_24 + 731*uk_25 + 1720*uk_26 + 4429*uk_27 + 9675*uk_28 + 731*uk_29 + 43*uk_3 + 289*uk_30 + 680*uk_31 + 1751*uk_32 + 3825*uk_33 + 289*uk_34 + 1600*uk_35 + 4120*uk_36 + 9000*uk_37 + 680*uk_38 + 10609*uk_39 + 17*uk_4 + 23175*uk_40 + 1751*uk_41 + 50625*uk_42 + 3825*uk_43 + 289*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 110613929323*uk_47 + 43731088337*uk_48 + 102896678440*uk_49 + 40*uk_5 + 264958946983*uk_50 + 578793816225*uk_51 + 43731088337*uk_52 + 153424975*uk_53 + 119950435*uk_54 + 47422265*uk_55 + 111581800*uk_56 + 287323135*uk_57 + 627647625*uk_58 + 47422265*uk_59 + 103*uk_6 + 93779431*uk_60 + 37075589*uk_61 + 87236680*uk_62 + 224634451*uk_63 + 490706325*uk_64 + 37075589*uk_65 + 14657791*uk_66 + 34488920*uk_67 + 88808969*uk_68 + 194000175*uk_69 + 225*uk_7 + 14657791*uk_70 + 81150400*uk_71 + 208962280*uk_72 + 456471000*uk_73 + 34488920*uk_74 + 538077871*uk_75 + 1175412825*uk_76 + 88808969*uk_77 + 2567649375*uk_78 + 194000175*uk_79 + 17*uk_8 + 14657791*uk_80 + 166375*uk_81 + 130075*uk_82 + 51425*uk_83 + 121000*uk_84 + 311575*uk_85 + 680625*uk_86 + 51425*uk_87 + 101695*uk_88 + 40205*uk_89 + 2572416961*uk_9 + 94600*uk_90 + 243595*uk_91 + 532125*uk_92 + 40205*uk_93 + 15895*uk_94 + 37400*uk_95 + 96305*uk_96 + 210375*uk_97 + 15895*uk_98 + 88000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 207900*uk_100 + 445500*uk_101 + 85140*uk_102 + 606375*uk_103 + 1299375*uk_104 + 248325*uk_105 + 2784375*uk_106 + 532125*uk_107 + 101695*uk_108 + 64*uk_109 + 202876*uk_11 + 688*uk_110 + 576*uk_111 + 1680*uk_112 + 3600*uk_113 + 688*uk_114 + 7396*uk_115 + 6192*uk_116 + 18060*uk_117 + 38700*uk_118 + 7396*uk_119 + 2180917*uk_12 + 5184*uk_120 + 15120*uk_121 + 32400*uk_122 + 6192*uk_123 + 44100*uk_124 + 94500*uk_125 + 18060*uk_126 + 202500*uk_127 + 38700*uk_128 + 7396*uk_129 + 1825884*uk_13 + 79507*uk_130 + 66564*uk_131 + 194145*uk_132 + 416025*uk_133 + 79507*uk_134 + 55728*uk_135 + 162540*uk_136 + 348300*uk_137 + 66564*uk_138 + 474075*uk_139 + 5325495*uk_14 + 1015875*uk_140 + 194145*uk_141 + 2176875*uk_142 + 416025*uk_143 + 79507*uk_144 + 46656*uk_145 + 136080*uk_146 + 291600*uk_147 + 55728*uk_148 + 396900*uk_149 + 11411775*uk_15 + 850500*uk_150 + 162540*uk_151 + 1822500*uk_152 + 348300*uk_153 + 66564*uk_154 + 1157625*uk_155 + 2480625*uk_156 + 474075*uk_157 + 5315625*uk_158 + 1015875*uk_159 + 2180917*uk_16 + 194145*uk_160 + 11390625*uk_161 + 2176875*uk_162 + 416025*uk_163 + 79507*uk_164 + 3025*uk_17 + 220*uk_18 + 2365*uk_19 + 55*uk_2 + 1980*uk_20 + 5775*uk_21 + 12375*uk_22 + 2365*uk_23 + 16*uk_24 + 172*uk_25 + 144*uk_26 + 420*uk_27 + 900*uk_28 + 172*uk_29 + 4*uk_3 + 1849*uk_30 + 1548*uk_31 + 4515*uk_32 + 9675*uk_33 + 1849*uk_34 + 1296*uk_35 + 3780*uk_36 + 8100*uk_37 + 1548*uk_38 + 11025*uk_39 + 43*uk_4 + 23625*uk_40 + 4515*uk_41 + 50625*uk_42 + 9675*uk_43 + 1849*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 10289667844*uk_47 + 110613929323*uk_48 + 92607010596*uk_49 + 36*uk_5 + 270103780905*uk_50 + 578793816225*uk_51 + 110613929323*uk_52 + 153424975*uk_53 + 11158180*uk_54 + 119950435*uk_55 + 100423620*uk_56 + 292902225*uk_57 + 627647625*uk_58 + 119950435*uk_59 + 105*uk_6 + 811504*uk_60 + 8723668*uk_61 + 7303536*uk_62 + 21301980*uk_63 + 45647100*uk_64 + 8723668*uk_65 + 93779431*uk_66 + 78513012*uk_67 + 228996285*uk_68 + 490706325*uk_69 + 225*uk_7 + 93779431*uk_70 + 65731824*uk_71 + 191717820*uk_72 + 410823900*uk_73 + 78513012*uk_74 + 559176975*uk_75 + 1198236375*uk_76 + 228996285*uk_77 + 2567649375*uk_78 + 490706325*uk_79 + 43*uk_8 + 93779431*uk_80 + 166375*uk_81 + 12100*uk_82 + 130075*uk_83 + 108900*uk_84 + 317625*uk_85 + 680625*uk_86 + 130075*uk_87 + 880*uk_88 + 9460*uk_89 + 2572416961*uk_9 + 7920*uk_90 + 23100*uk_91 + 49500*uk_92 + 9460*uk_93 + 101695*uk_94 + 85140*uk_95 + 248325*uk_96 + 532125*uk_97 + 101695*uk_98 + 71280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 211860*uk_100 + 445500*uk_101 + 7920*uk_102 + 629695*uk_103 + 1324125*uk_104 + 23540*uk_105 + 2784375*uk_106 + 49500*uk_107 + 880*uk_108 + uk_109 + 50719*uk_11 + 4*uk_110 + 36*uk_111 + 107*uk_112 + 225*uk_113 + 4*uk_114 + 16*uk_115 + 144*uk_116 + 428*uk_117 + 900*uk_118 + 16*uk_119 + 202876*uk_12 + 1296*uk_120 + 3852*uk_121 + 8100*uk_122 + 144*uk_123 + 11449*uk_124 + 24075*uk_125 + 428*uk_126 + 50625*uk_127 + 900*uk_128 + 16*uk_129 + 1825884*uk_13 + 64*uk_130 + 576*uk_131 + 1712*uk_132 + 3600*uk_133 + 64*uk_134 + 5184*uk_135 + 15408*uk_136 + 32400*uk_137 + 576*uk_138 + 45796*uk_139 + 5426933*uk_14 + 96300*uk_140 + 1712*uk_141 + 202500*uk_142 + 3600*uk_143 + 64*uk_144 + 46656*uk_145 + 138672*uk_146 + 291600*uk_147 + 5184*uk_148 + 412164*uk_149 + 11411775*uk_15 + 866700*uk_150 + 15408*uk_151 + 1822500*uk_152 + 32400*uk_153 + 576*uk_154 + 1225043*uk_155 + 2576025*uk_156 + 45796*uk_157 + 5416875*uk_158 + 96300*uk_159 + 202876*uk_16 + 1712*uk_160 + 11390625*uk_161 + 202500*uk_162 + 3600*uk_163 + 64*uk_164 + 3025*uk_17 + 55*uk_18 + 220*uk_19 + 55*uk_2 + 1980*uk_20 + 5885*uk_21 + 12375*uk_22 + 220*uk_23 + uk_24 + 4*uk_25 + 36*uk_26 + 107*uk_27 + 225*uk_28 + 4*uk_29 + uk_3 + 16*uk_30 + 144*uk_31 + 428*uk_32 + 900*uk_33 + 16*uk_34 + 1296*uk_35 + 3852*uk_36 + 8100*uk_37 + 144*uk_38 + 11449*uk_39 + 4*uk_4 + 24075*uk_40 + 428*uk_41 + 50625*uk_42 + 900*uk_43 + 16*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 2572416961*uk_47 + 10289667844*uk_48 + 92607010596*uk_49 + 36*uk_5 + 275248614827*uk_50 + 578793816225*uk_51 + 10289667844*uk_52 + 153424975*uk_53 + 2789545*uk_54 + 11158180*uk_55 + 100423620*uk_56 + 298481315*uk_57 + 627647625*uk_58 + 11158180*uk_59 + 107*uk_6 + 50719*uk_60 + 202876*uk_61 + 1825884*uk_62 + 5426933*uk_63 + 11411775*uk_64 + 202876*uk_65 + 811504*uk_66 + 7303536*uk_67 + 21707732*uk_68 + 45647100*uk_69 + 225*uk_7 + 811504*uk_70 + 65731824*uk_71 + 195369588*uk_72 + 410823900*uk_73 + 7303536*uk_74 + 580681831*uk_75 + 1221059925*uk_76 + 21707732*uk_77 + 2567649375*uk_78 + 45647100*uk_79 + 4*uk_8 + 811504*uk_80 + 166375*uk_81 + 3025*uk_82 + 12100*uk_83 + 108900*uk_84 + 323675*uk_85 + 680625*uk_86 + 12100*uk_87 + 55*uk_88 + 220*uk_89 + 2572416961*uk_9 + 1980*uk_90 + 5885*uk_91 + 12375*uk_92 + 220*uk_93 + 880*uk_94 + 7920*uk_95 + 23540*uk_96 + 49500*uk_97 + 880*uk_98 + 71280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 215820*uk_100 + 445500*uk_101 + 1980*uk_102 + 653455*uk_103 + 1348875*uk_104 + 5995*uk_105 + 2784375*uk_106 + 12375*uk_107 + 55*uk_108 + 39304*uk_109 + 1724446*uk_11 + 1156*uk_110 + 41616*uk_111 + 126004*uk_112 + 260100*uk_113 + 1156*uk_114 + 34*uk_115 + 1224*uk_116 + 3706*uk_117 + 7650*uk_118 + 34*uk_119 + 50719*uk_12 + 44064*uk_120 + 133416*uk_121 + 275400*uk_122 + 1224*uk_123 + 403954*uk_124 + 833850*uk_125 + 3706*uk_126 + 1721250*uk_127 + 7650*uk_128 + 34*uk_129 + 1825884*uk_13 + uk_130 + 36*uk_131 + 109*uk_132 + 225*uk_133 + uk_134 + 1296*uk_135 + 3924*uk_136 + 8100*uk_137 + 36*uk_138 + 11881*uk_139 + 5528371*uk_14 + 24525*uk_140 + 109*uk_141 + 50625*uk_142 + 225*uk_143 + uk_144 + 46656*uk_145 + 141264*uk_146 + 291600*uk_147 + 1296*uk_148 + 427716*uk_149 + 11411775*uk_15 + 882900*uk_150 + 3924*uk_151 + 1822500*uk_152 + 8100*uk_153 + 36*uk_154 + 1295029*uk_155 + 2673225*uk_156 + 11881*uk_157 + 5518125*uk_158 + 24525*uk_159 + 50719*uk_16 + 109*uk_160 + 11390625*uk_161 + 50625*uk_162 + 225*uk_163 + uk_164 + 3025*uk_17 + 1870*uk_18 + 55*uk_19 + 55*uk_2 + 1980*uk_20 + 5995*uk_21 + 12375*uk_22 + 55*uk_23 + 1156*uk_24 + 34*uk_25 + 1224*uk_26 + 3706*uk_27 + 7650*uk_28 + 34*uk_29 + 34*uk_3 + uk_30 + 36*uk_31 + 109*uk_32 + 225*uk_33 + uk_34 + 1296*uk_35 + 3924*uk_36 + 8100*uk_37 + 36*uk_38 + 11881*uk_39 + uk_4 + 24525*uk_40 + 109*uk_41 + 50625*uk_42 + 225*uk_43 + uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 87462176674*uk_47 + 2572416961*uk_48 + 92607010596*uk_49 + 36*uk_5 + 280393448749*uk_50 + 578793816225*uk_51 + 2572416961*uk_52 + 153424975*uk_53 + 94844530*uk_54 + 2789545*uk_55 + 100423620*uk_56 + 304060405*uk_57 + 627647625*uk_58 + 2789545*uk_59 + 109*uk_6 + 58631164*uk_60 + 1724446*uk_61 + 62080056*uk_62 + 187964614*uk_63 + 388000350*uk_64 + 1724446*uk_65 + 50719*uk_66 + 1825884*uk_67 + 5528371*uk_68 + 11411775*uk_69 + 225*uk_7 + 50719*uk_70 + 65731824*uk_71 + 199021356*uk_72 + 410823900*uk_73 + 1825884*uk_74 + 602592439*uk_75 + 1243883475*uk_76 + 5528371*uk_77 + 2567649375*uk_78 + 11411775*uk_79 + uk_8 + 50719*uk_80 + 166375*uk_81 + 102850*uk_82 + 3025*uk_83 + 108900*uk_84 + 329725*uk_85 + 680625*uk_86 + 3025*uk_87 + 63580*uk_88 + 1870*uk_89 + 2572416961*uk_9 + 67320*uk_90 + 203830*uk_91 + 420750*uk_92 + 1870*uk_93 + 55*uk_94 + 1980*uk_95 + 5995*uk_96 + 12375*uk_97 + 55*uk_98 + 71280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 219780*uk_100 + 445500*uk_101 + 67320*uk_102 + 677655*uk_103 + 1373625*uk_104 + 207570*uk_105 + 2784375*uk_106 + 420750*uk_107 + 63580*uk_108 + 1092727*uk_109 + 5224057*uk_11 + 360706*uk_110 + 381924*uk_111 + 1177599*uk_112 + 2387025*uk_113 + 360706*uk_114 + 119068*uk_115 + 126072*uk_116 + 388722*uk_117 + 787950*uk_118 + 119068*uk_119 + 1724446*uk_12 + 133488*uk_120 + 411588*uk_121 + 834300*uk_122 + 126072*uk_123 + 1269063*uk_124 + 2572425*uk_125 + 388722*uk_126 + 5214375*uk_127 + 787950*uk_128 + 119068*uk_129 + 1825884*uk_13 + 39304*uk_130 + 41616*uk_131 + 128316*uk_132 + 260100*uk_133 + 39304*uk_134 + 44064*uk_135 + 135864*uk_136 + 275400*uk_137 + 41616*uk_138 + 418914*uk_139 + 5629809*uk_14 + 849150*uk_140 + 128316*uk_141 + 1721250*uk_142 + 260100*uk_143 + 39304*uk_144 + 46656*uk_145 + 143856*uk_146 + 291600*uk_147 + 44064*uk_148 + 443556*uk_149 + 11411775*uk_15 + 899100*uk_150 + 135864*uk_151 + 1822500*uk_152 + 275400*uk_153 + 41616*uk_154 + 1367631*uk_155 + 2772225*uk_156 + 418914*uk_157 + 5619375*uk_158 + 849150*uk_159 + 1724446*uk_16 + 128316*uk_160 + 11390625*uk_161 + 1721250*uk_162 + 260100*uk_163 + 39304*uk_164 + 3025*uk_17 + 5665*uk_18 + 1870*uk_19 + 55*uk_2 + 1980*uk_20 + 6105*uk_21 + 12375*uk_22 + 1870*uk_23 + 10609*uk_24 + 3502*uk_25 + 3708*uk_26 + 11433*uk_27 + 23175*uk_28 + 3502*uk_29 + 103*uk_3 + 1156*uk_30 + 1224*uk_31 + 3774*uk_32 + 7650*uk_33 + 1156*uk_34 + 1296*uk_35 + 3996*uk_36 + 8100*uk_37 + 1224*uk_38 + 12321*uk_39 + 34*uk_4 + 24975*uk_40 + 3774*uk_41 + 50625*uk_42 + 7650*uk_43 + 1156*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 264958946983*uk_47 + 87462176674*uk_48 + 92607010596*uk_49 + 36*uk_5 + 285538282671*uk_50 + 578793816225*uk_51 + 87462176674*uk_52 + 153424975*uk_53 + 287323135*uk_54 + 94844530*uk_55 + 100423620*uk_56 + 309639495*uk_57 + 627647625*uk_58 + 94844530*uk_59 + 111*uk_6 + 538077871*uk_60 + 177617938*uk_61 + 188066052*uk_62 + 579870327*uk_63 + 1175412825*uk_64 + 177617938*uk_65 + 58631164*uk_66 + 62080056*uk_67 + 191413506*uk_68 + 388000350*uk_69 + 225*uk_7 + 58631164*uk_70 + 65731824*uk_71 + 202673124*uk_72 + 410823900*uk_73 + 62080056*uk_74 + 624908799*uk_75 + 1266707025*uk_76 + 191413506*uk_77 + 2567649375*uk_78 + 388000350*uk_79 + 34*uk_8 + 58631164*uk_80 + 166375*uk_81 + 311575*uk_82 + 102850*uk_83 + 108900*uk_84 + 335775*uk_85 + 680625*uk_86 + 102850*uk_87 + 583495*uk_88 + 192610*uk_89 + 2572416961*uk_9 + 203940*uk_90 + 628815*uk_91 + 1274625*uk_92 + 192610*uk_93 + 63580*uk_94 + 67320*uk_95 + 207570*uk_96 + 420750*uk_97 + 63580*uk_98 + 71280*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 198880*uk_100 + 396000*uk_101 + 181280*uk_102 + 702295*uk_103 + 1398375*uk_104 + 640145*uk_105 + 2784375*uk_106 + 1274625*uk_107 + 583495*uk_108 + 857375*uk_109 + 4818305*uk_11 + 929575*uk_110 + 288800*uk_111 + 1019825*uk_112 + 2030625*uk_113 + 929575*uk_114 + 1007855*uk_115 + 313120*uk_116 + 1105705*uk_117 + 2201625*uk_118 + 1007855*uk_119 + 5224057*uk_12 + 97280*uk_120 + 343520*uk_121 + 684000*uk_122 + 313120*uk_123 + 1213055*uk_124 + 2415375*uk_125 + 1105705*uk_126 + 4809375*uk_127 + 2201625*uk_128 + 1007855*uk_129 + 1623008*uk_13 + 1092727*uk_130 + 339488*uk_131 + 1198817*uk_132 + 2387025*uk_133 + 1092727*uk_134 + 105472*uk_135 + 372448*uk_136 + 741600*uk_137 + 339488*uk_138 + 1315207*uk_139 + 5731247*uk_14 + 2618775*uk_140 + 1198817*uk_141 + 5214375*uk_142 + 2387025*uk_143 + 1092727*uk_144 + 32768*uk_145 + 115712*uk_146 + 230400*uk_147 + 105472*uk_148 + 408608*uk_149 + 11411775*uk_15 + 813600*uk_150 + 372448*uk_151 + 1620000*uk_152 + 741600*uk_153 + 339488*uk_154 + 1442897*uk_155 + 2873025*uk_156 + 1315207*uk_157 + 5720625*uk_158 + 2618775*uk_159 + 5224057*uk_16 + 1198817*uk_160 + 11390625*uk_161 + 5214375*uk_162 + 2387025*uk_163 + 1092727*uk_164 + 3025*uk_17 + 5225*uk_18 + 5665*uk_19 + 55*uk_2 + 1760*uk_20 + 6215*uk_21 + 12375*uk_22 + 5665*uk_23 + 9025*uk_24 + 9785*uk_25 + 3040*uk_26 + 10735*uk_27 + 21375*uk_28 + 9785*uk_29 + 95*uk_3 + 10609*uk_30 + 3296*uk_31 + 11639*uk_32 + 23175*uk_33 + 10609*uk_34 + 1024*uk_35 + 3616*uk_36 + 7200*uk_37 + 3296*uk_38 + 12769*uk_39 + 103*uk_4 + 25425*uk_40 + 11639*uk_41 + 50625*uk_42 + 23175*uk_43 + 10609*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 244379611295*uk_47 + 264958946983*uk_48 + 82317342752*uk_49 + 32*uk_5 + 290683116593*uk_50 + 578793816225*uk_51 + 264958946983*uk_52 + 153424975*uk_53 + 265006775*uk_54 + 287323135*uk_55 + 89265440*uk_56 + 315218585*uk_57 + 627647625*uk_58 + 287323135*uk_59 + 113*uk_6 + 457738975*uk_60 + 496285415*uk_61 + 154185760*uk_62 + 544468465*uk_63 + 1084118625*uk_64 + 496285415*uk_65 + 538077871*uk_66 + 167169824*uk_67 + 590318441*uk_68 + 1175412825*uk_69 + 225*uk_7 + 538077871*uk_70 + 51936256*uk_71 + 183399904*uk_72 + 365176800*uk_73 + 167169824*uk_74 + 647630911*uk_75 + 1289530575*uk_76 + 590318441*uk_77 + 2567649375*uk_78 + 1175412825*uk_79 + 103*uk_8 + 538077871*uk_80 + 166375*uk_81 + 287375*uk_82 + 311575*uk_83 + 96800*uk_84 + 341825*uk_85 + 680625*uk_86 + 311575*uk_87 + 496375*uk_88 + 538175*uk_89 + 2572416961*uk_9 + 167200*uk_90 + 590425*uk_91 + 1175625*uk_92 + 538175*uk_93 + 583495*uk_94 + 181280*uk_95 + 640145*uk_96 + 1274625*uk_97 + 583495*uk_98 + 56320*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 177100*uk_100 + 346500*uk_101 + 146300*uk_102 + 727375*uk_103 + 1423125*uk_104 + 600875*uk_105 + 2784375*uk_106 + 1175625*uk_107 + 496375*uk_108 + 64*uk_109 + 202876*uk_11 + 1520*uk_110 + 448*uk_111 + 1840*uk_112 + 3600*uk_113 + 1520*uk_114 + 36100*uk_115 + 10640*uk_116 + 43700*uk_117 + 85500*uk_118 + 36100*uk_119 + 4818305*uk_12 + 3136*uk_120 + 12880*uk_121 + 25200*uk_122 + 10640*uk_123 + 52900*uk_124 + 103500*uk_125 + 43700*uk_126 + 202500*uk_127 + 85500*uk_128 + 36100*uk_129 + 1420132*uk_13 + 857375*uk_130 + 252700*uk_131 + 1037875*uk_132 + 2030625*uk_133 + 857375*uk_134 + 74480*uk_135 + 305900*uk_136 + 598500*uk_137 + 252700*uk_138 + 1256375*uk_139 + 5832685*uk_14 + 2458125*uk_140 + 1037875*uk_141 + 4809375*uk_142 + 2030625*uk_143 + 857375*uk_144 + 21952*uk_145 + 90160*uk_146 + 176400*uk_147 + 74480*uk_148 + 370300*uk_149 + 11411775*uk_15 + 724500*uk_150 + 305900*uk_151 + 1417500*uk_152 + 598500*uk_153 + 252700*uk_154 + 1520875*uk_155 + 2975625*uk_156 + 1256375*uk_157 + 5821875*uk_158 + 2458125*uk_159 + 4818305*uk_16 + 1037875*uk_160 + 11390625*uk_161 + 4809375*uk_162 + 2030625*uk_163 + 857375*uk_164 + 3025*uk_17 + 220*uk_18 + 5225*uk_19 + 55*uk_2 + 1540*uk_20 + 6325*uk_21 + 12375*uk_22 + 5225*uk_23 + 16*uk_24 + 380*uk_25 + 112*uk_26 + 460*uk_27 + 900*uk_28 + 380*uk_29 + 4*uk_3 + 9025*uk_30 + 2660*uk_31 + 10925*uk_32 + 21375*uk_33 + 9025*uk_34 + 784*uk_35 + 3220*uk_36 + 6300*uk_37 + 2660*uk_38 + 13225*uk_39 + 95*uk_4 + 25875*uk_40 + 10925*uk_41 + 50625*uk_42 + 21375*uk_43 + 9025*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 10289667844*uk_47 + 244379611295*uk_48 + 72027674908*uk_49 + 28*uk_5 + 295827950515*uk_50 + 578793816225*uk_51 + 244379611295*uk_52 + 153424975*uk_53 + 11158180*uk_54 + 265006775*uk_55 + 78107260*uk_56 + 320797675*uk_57 + 627647625*uk_58 + 265006775*uk_59 + 115*uk_6 + 811504*uk_60 + 19273220*uk_61 + 5680528*uk_62 + 23330740*uk_63 + 45647100*uk_64 + 19273220*uk_65 + 457738975*uk_66 + 134912540*uk_67 + 554105075*uk_68 + 1084118625*uk_69 + 225*uk_7 + 457738975*uk_70 + 39763696*uk_71 + 163315180*uk_72 + 319529700*uk_73 + 134912540*uk_74 + 670758775*uk_75 + 1312354125*uk_76 + 554105075*uk_77 + 2567649375*uk_78 + 1084118625*uk_79 + 95*uk_8 + 457738975*uk_80 + 166375*uk_81 + 12100*uk_82 + 287375*uk_83 + 84700*uk_84 + 347875*uk_85 + 680625*uk_86 + 287375*uk_87 + 880*uk_88 + 20900*uk_89 + 2572416961*uk_9 + 6160*uk_90 + 25300*uk_91 + 49500*uk_92 + 20900*uk_93 + 496375*uk_94 + 146300*uk_95 + 600875*uk_96 + 1175625*uk_97 + 496375*uk_98 + 43120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 205920*uk_100 + 396000*uk_101 + 7040*uk_102 + 752895*uk_103 + 1447875*uk_104 + 25740*uk_105 + 2784375*uk_106 + 49500*uk_107 + 880*uk_108 + 195112*uk_109 + 2941702*uk_11 + 13456*uk_110 + 107648*uk_111 + 393588*uk_112 + 756900*uk_113 + 13456*uk_114 + 928*uk_115 + 7424*uk_116 + 27144*uk_117 + 52200*uk_118 + 928*uk_119 + 202876*uk_12 + 59392*uk_120 + 217152*uk_121 + 417600*uk_122 + 7424*uk_123 + 793962*uk_124 + 1526850*uk_125 + 27144*uk_126 + 2936250*uk_127 + 52200*uk_128 + 928*uk_129 + 1623008*uk_13 + 64*uk_130 + 512*uk_131 + 1872*uk_132 + 3600*uk_133 + 64*uk_134 + 4096*uk_135 + 14976*uk_136 + 28800*uk_137 + 512*uk_138 + 54756*uk_139 + 5934123*uk_14 + 105300*uk_140 + 1872*uk_141 + 202500*uk_142 + 3600*uk_143 + 64*uk_144 + 32768*uk_145 + 119808*uk_146 + 230400*uk_147 + 4096*uk_148 + 438048*uk_149 + 11411775*uk_15 + 842400*uk_150 + 14976*uk_151 + 1620000*uk_152 + 28800*uk_153 + 512*uk_154 + 1601613*uk_155 + 3080025*uk_156 + 54756*uk_157 + 5923125*uk_158 + 105300*uk_159 + 202876*uk_16 + 1872*uk_160 + 11390625*uk_161 + 202500*uk_162 + 3600*uk_163 + 64*uk_164 + 3025*uk_17 + 3190*uk_18 + 220*uk_19 + 55*uk_2 + 1760*uk_20 + 6435*uk_21 + 12375*uk_22 + 220*uk_23 + 3364*uk_24 + 232*uk_25 + 1856*uk_26 + 6786*uk_27 + 13050*uk_28 + 232*uk_29 + 58*uk_3 + 16*uk_30 + 128*uk_31 + 468*uk_32 + 900*uk_33 + 16*uk_34 + 1024*uk_35 + 3744*uk_36 + 7200*uk_37 + 128*uk_38 + 13689*uk_39 + 4*uk_4 + 26325*uk_40 + 468*uk_41 + 50625*uk_42 + 900*uk_43 + 16*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 149200183738*uk_47 + 10289667844*uk_48 + 82317342752*uk_49 + 32*uk_5 + 300972784437*uk_50 + 578793816225*uk_51 + 10289667844*uk_52 + 153424975*uk_53 + 161793610*uk_54 + 11158180*uk_55 + 89265440*uk_56 + 326376765*uk_57 + 627647625*uk_58 + 11158180*uk_59 + 117*uk_6 + 170618716*uk_60 + 11766808*uk_61 + 94134464*uk_62 + 344179134*uk_63 + 661882950*uk_64 + 11766808*uk_65 + 811504*uk_66 + 6492032*uk_67 + 23736492*uk_68 + 45647100*uk_69 + 225*uk_7 + 811504*uk_70 + 51936256*uk_71 + 189891936*uk_72 + 365176800*uk_73 + 6492032*uk_74 + 694292391*uk_75 + 1335177675*uk_76 + 23736492*uk_77 + 2567649375*uk_78 + 45647100*uk_79 + 4*uk_8 + 811504*uk_80 + 166375*uk_81 + 175450*uk_82 + 12100*uk_83 + 96800*uk_84 + 353925*uk_85 + 680625*uk_86 + 12100*uk_87 + 185020*uk_88 + 12760*uk_89 + 2572416961*uk_9 + 102080*uk_90 + 373230*uk_91 + 717750*uk_92 + 12760*uk_93 + 880*uk_94 + 7040*uk_95 + 25740*uk_96 + 49500*uk_97 + 880*uk_98 + 56320*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 183260*uk_100 + 346500*uk_101 + 89320*uk_102 + 778855*uk_103 + 1472625*uk_104 + 379610*uk_105 + 2784375*uk_106 + 717750*uk_107 + 185020*uk_108 + 15625*uk_109 + 1267975*uk_11 + 36250*uk_110 + 17500*uk_111 + 74375*uk_112 + 140625*uk_113 + 36250*uk_114 + 84100*uk_115 + 40600*uk_116 + 172550*uk_117 + 326250*uk_118 + 84100*uk_119 + 2941702*uk_12 + 19600*uk_120 + 83300*uk_121 + 157500*uk_122 + 40600*uk_123 + 354025*uk_124 + 669375*uk_125 + 172550*uk_126 + 1265625*uk_127 + 326250*uk_128 + 84100*uk_129 + 1420132*uk_13 + 195112*uk_130 + 94192*uk_131 + 400316*uk_132 + 756900*uk_133 + 195112*uk_134 + 45472*uk_135 + 193256*uk_136 + 365400*uk_137 + 94192*uk_138 + 821338*uk_139 + 6035561*uk_14 + 1552950*uk_140 + 400316*uk_141 + 2936250*uk_142 + 756900*uk_143 + 195112*uk_144 + 21952*uk_145 + 93296*uk_146 + 176400*uk_147 + 45472*uk_148 + 396508*uk_149 + 11411775*uk_15 + 749700*uk_150 + 193256*uk_151 + 1417500*uk_152 + 365400*uk_153 + 94192*uk_154 + 1685159*uk_155 + 3186225*uk_156 + 821338*uk_157 + 6024375*uk_158 + 1552950*uk_159 + 2941702*uk_16 + 400316*uk_160 + 11390625*uk_161 + 2936250*uk_162 + 756900*uk_163 + 195112*uk_164 + 3025*uk_17 + 1375*uk_18 + 3190*uk_19 + 55*uk_2 + 1540*uk_20 + 6545*uk_21 + 12375*uk_22 + 3190*uk_23 + 625*uk_24 + 1450*uk_25 + 700*uk_26 + 2975*uk_27 + 5625*uk_28 + 1450*uk_29 + 25*uk_3 + 3364*uk_30 + 1624*uk_31 + 6902*uk_32 + 13050*uk_33 + 3364*uk_34 + 784*uk_35 + 3332*uk_36 + 6300*uk_37 + 1624*uk_38 + 14161*uk_39 + 58*uk_4 + 26775*uk_40 + 6902*uk_41 + 50625*uk_42 + 13050*uk_43 + 3364*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 64310424025*uk_47 + 149200183738*uk_48 + 72027674908*uk_49 + 28*uk_5 + 306117618359*uk_50 + 578793816225*uk_51 + 149200183738*uk_52 + 153424975*uk_53 + 69738625*uk_54 + 161793610*uk_55 + 78107260*uk_56 + 331955855*uk_57 + 627647625*uk_58 + 161793610*uk_59 + 119*uk_6 + 31699375*uk_60 + 73542550*uk_61 + 35503300*uk_62 + 150889025*uk_63 + 285294375*uk_64 + 73542550*uk_65 + 170618716*uk_66 + 82367656*uk_67 + 350062538*uk_68 + 661882950*uk_69 + 225*uk_7 + 170618716*uk_70 + 39763696*uk_71 + 168995708*uk_72 + 319529700*uk_73 + 82367656*uk_74 + 718231759*uk_75 + 1358001225*uk_76 + 350062538*uk_77 + 2567649375*uk_78 + 661882950*uk_79 + 58*uk_8 + 170618716*uk_80 + 166375*uk_81 + 75625*uk_82 + 175450*uk_83 + 84700*uk_84 + 359975*uk_85 + 680625*uk_86 + 175450*uk_87 + 34375*uk_88 + 79750*uk_89 + 2572416961*uk_9 + 38500*uk_90 + 163625*uk_91 + 309375*uk_92 + 79750*uk_93 + 185020*uk_94 + 89320*uk_95 + 379610*uk_96 + 717750*uk_97 + 185020*uk_98 + 43120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 186340*uk_100 + 346500*uk_101 + 38500*uk_102 + 805255*uk_103 + 1497375*uk_104 + 166375*uk_105 + 2784375*uk_106 + 309375*uk_107 + 34375*uk_108 + 8000*uk_109 + 1014380*uk_11 + 10000*uk_110 + 11200*uk_111 + 48400*uk_112 + 90000*uk_113 + 10000*uk_114 + 12500*uk_115 + 14000*uk_116 + 60500*uk_117 + 112500*uk_118 + 12500*uk_119 + 1267975*uk_12 + 15680*uk_120 + 67760*uk_121 + 126000*uk_122 + 14000*uk_123 + 292820*uk_124 + 544500*uk_125 + 60500*uk_126 + 1012500*uk_127 + 112500*uk_128 + 12500*uk_129 + 1420132*uk_13 + 15625*uk_130 + 17500*uk_131 + 75625*uk_132 + 140625*uk_133 + 15625*uk_134 + 19600*uk_135 + 84700*uk_136 + 157500*uk_137 + 17500*uk_138 + 366025*uk_139 + 6136999*uk_14 + 680625*uk_140 + 75625*uk_141 + 1265625*uk_142 + 140625*uk_143 + 15625*uk_144 + 21952*uk_145 + 94864*uk_146 + 176400*uk_147 + 19600*uk_148 + 409948*uk_149 + 11411775*uk_15 + 762300*uk_150 + 84700*uk_151 + 1417500*uk_152 + 157500*uk_153 + 17500*uk_154 + 1771561*uk_155 + 3294225*uk_156 + 366025*uk_157 + 6125625*uk_158 + 680625*uk_159 + 1267975*uk_16 + 75625*uk_160 + 11390625*uk_161 + 1265625*uk_162 + 140625*uk_163 + 15625*uk_164 + 3025*uk_17 + 1100*uk_18 + 1375*uk_19 + 55*uk_2 + 1540*uk_20 + 6655*uk_21 + 12375*uk_22 + 1375*uk_23 + 400*uk_24 + 500*uk_25 + 560*uk_26 + 2420*uk_27 + 4500*uk_28 + 500*uk_29 + 20*uk_3 + 625*uk_30 + 700*uk_31 + 3025*uk_32 + 5625*uk_33 + 625*uk_34 + 784*uk_35 + 3388*uk_36 + 6300*uk_37 + 700*uk_38 + 14641*uk_39 + 25*uk_4 + 27225*uk_40 + 3025*uk_41 + 50625*uk_42 + 5625*uk_43 + 625*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 51448339220*uk_47 + 64310424025*uk_48 + 72027674908*uk_49 + 28*uk_5 + 311262452281*uk_50 + 578793816225*uk_51 + 64310424025*uk_52 + 153424975*uk_53 + 55790900*uk_54 + 69738625*uk_55 + 78107260*uk_56 + 337534945*uk_57 + 627647625*uk_58 + 69738625*uk_59 + 121*uk_6 + 20287600*uk_60 + 25359500*uk_61 + 28402640*uk_62 + 122739980*uk_63 + 228235500*uk_64 + 25359500*uk_65 + 31699375*uk_66 + 35503300*uk_67 + 153424975*uk_68 + 285294375*uk_69 + 225*uk_7 + 31699375*uk_70 + 39763696*uk_71 + 171835972*uk_72 + 319529700*uk_73 + 35503300*uk_74 + 742576879*uk_75 + 1380824775*uk_76 + 153424975*uk_77 + 2567649375*uk_78 + 285294375*uk_79 + 25*uk_8 + 31699375*uk_80 + 166375*uk_81 + 60500*uk_82 + 75625*uk_83 + 84700*uk_84 + 366025*uk_85 + 680625*uk_86 + 75625*uk_87 + 22000*uk_88 + 27500*uk_89 + 2572416961*uk_9 + 30800*uk_90 + 133100*uk_91 + 247500*uk_92 + 27500*uk_93 + 34375*uk_94 + 38500*uk_95 + 166375*uk_96 + 309375*uk_97 + 34375*uk_98 + 43120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 189420*uk_100 + 346500*uk_101 + 30800*uk_102 + 832095*uk_103 + 1522125*uk_104 + 135300*uk_105 + 2784375*uk_106 + 247500*uk_107 + 22000*uk_108 + 79507*uk_109 + 2180917*uk_11 + 36980*uk_110 + 51772*uk_111 + 227427*uk_112 + 416025*uk_113 + 36980*uk_114 + 17200*uk_115 + 24080*uk_116 + 105780*uk_117 + 193500*uk_118 + 17200*uk_119 + 1014380*uk_12 + 33712*uk_120 + 148092*uk_121 + 270900*uk_122 + 24080*uk_123 + 650547*uk_124 + 1190025*uk_125 + 105780*uk_126 + 2176875*uk_127 + 193500*uk_128 + 17200*uk_129 + 1420132*uk_13 + 8000*uk_130 + 11200*uk_131 + 49200*uk_132 + 90000*uk_133 + 8000*uk_134 + 15680*uk_135 + 68880*uk_136 + 126000*uk_137 + 11200*uk_138 + 302580*uk_139 + 6238437*uk_14 + 553500*uk_140 + 49200*uk_141 + 1012500*uk_142 + 90000*uk_143 + 8000*uk_144 + 21952*uk_145 + 96432*uk_146 + 176400*uk_147 + 15680*uk_148 + 423612*uk_149 + 11411775*uk_15 + 774900*uk_150 + 68880*uk_151 + 1417500*uk_152 + 126000*uk_153 + 11200*uk_154 + 1860867*uk_155 + 3404025*uk_156 + 302580*uk_157 + 6226875*uk_158 + 553500*uk_159 + 1014380*uk_16 + 49200*uk_160 + 11390625*uk_161 + 1012500*uk_162 + 90000*uk_163 + 8000*uk_164 + 3025*uk_17 + 2365*uk_18 + 1100*uk_19 + 55*uk_2 + 1540*uk_20 + 6765*uk_21 + 12375*uk_22 + 1100*uk_23 + 1849*uk_24 + 860*uk_25 + 1204*uk_26 + 5289*uk_27 + 9675*uk_28 + 860*uk_29 + 43*uk_3 + 400*uk_30 + 560*uk_31 + 2460*uk_32 + 4500*uk_33 + 400*uk_34 + 784*uk_35 + 3444*uk_36 + 6300*uk_37 + 560*uk_38 + 15129*uk_39 + 20*uk_4 + 27675*uk_40 + 2460*uk_41 + 50625*uk_42 + 4500*uk_43 + 400*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 110613929323*uk_47 + 51448339220*uk_48 + 72027674908*uk_49 + 28*uk_5 + 316407286203*uk_50 + 578793816225*uk_51 + 51448339220*uk_52 + 153424975*uk_53 + 119950435*uk_54 + 55790900*uk_55 + 78107260*uk_56 + 343114035*uk_57 + 627647625*uk_58 + 55790900*uk_59 + 123*uk_6 + 93779431*uk_60 + 43618340*uk_61 + 61065676*uk_62 + 268252791*uk_63 + 490706325*uk_64 + 43618340*uk_65 + 20287600*uk_66 + 28402640*uk_67 + 124768740*uk_68 + 228235500*uk_69 + 225*uk_7 + 20287600*uk_70 + 39763696*uk_71 + 174676236*uk_72 + 319529700*uk_73 + 28402640*uk_74 + 767327751*uk_75 + 1403648325*uk_76 + 124768740*uk_77 + 2567649375*uk_78 + 228235500*uk_79 + 20*uk_8 + 20287600*uk_80 + 166375*uk_81 + 130075*uk_82 + 60500*uk_83 + 84700*uk_84 + 372075*uk_85 + 680625*uk_86 + 60500*uk_87 + 101695*uk_88 + 47300*uk_89 + 2572416961*uk_9 + 66220*uk_90 + 290895*uk_91 + 532125*uk_92 + 47300*uk_93 + 22000*uk_94 + 30800*uk_95 + 135300*uk_96 + 247500*uk_97 + 22000*uk_98 + 43120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 192500*uk_100 + 346500*uk_101 + 66220*uk_102 + 859375*uk_103 + 1546875*uk_104 + 295625*uk_105 + 2784375*uk_106 + 532125*uk_107 + 101695*uk_108 + 830584*uk_109 + 4767586*uk_11 + 379948*uk_110 + 247408*uk_111 + 1104500*uk_112 + 1988100*uk_113 + 379948*uk_114 + 173806*uk_115 + 113176*uk_116 + 505250*uk_117 + 909450*uk_118 + 173806*uk_119 + 2180917*uk_12 + 73696*uk_120 + 329000*uk_121 + 592200*uk_122 + 113176*uk_123 + 1468750*uk_124 + 2643750*uk_125 + 505250*uk_126 + 4758750*uk_127 + 909450*uk_128 + 173806*uk_129 + 1420132*uk_13 + 79507*uk_130 + 51772*uk_131 + 231125*uk_132 + 416025*uk_133 + 79507*uk_134 + 33712*uk_135 + 150500*uk_136 + 270900*uk_137 + 51772*uk_138 + 671875*uk_139 + 6339875*uk_14 + 1209375*uk_140 + 231125*uk_141 + 2176875*uk_142 + 416025*uk_143 + 79507*uk_144 + 21952*uk_145 + 98000*uk_146 + 176400*uk_147 + 33712*uk_148 + 437500*uk_149 + 11411775*uk_15 + 787500*uk_150 + 150500*uk_151 + 1417500*uk_152 + 270900*uk_153 + 51772*uk_154 + 1953125*uk_155 + 3515625*uk_156 + 671875*uk_157 + 6328125*uk_158 + 1209375*uk_159 + 2180917*uk_16 + 231125*uk_160 + 11390625*uk_161 + 2176875*uk_162 + 416025*uk_163 + 79507*uk_164 + 3025*uk_17 + 5170*uk_18 + 2365*uk_19 + 55*uk_2 + 1540*uk_20 + 6875*uk_21 + 12375*uk_22 + 2365*uk_23 + 8836*uk_24 + 4042*uk_25 + 2632*uk_26 + 11750*uk_27 + 21150*uk_28 + 4042*uk_29 + 94*uk_3 + 1849*uk_30 + 1204*uk_31 + 5375*uk_32 + 9675*uk_33 + 1849*uk_34 + 784*uk_35 + 3500*uk_36 + 6300*uk_37 + 1204*uk_38 + 15625*uk_39 + 43*uk_4 + 28125*uk_40 + 5375*uk_41 + 50625*uk_42 + 9675*uk_43 + 1849*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 241807194334*uk_47 + 110613929323*uk_48 + 72027674908*uk_49 + 28*uk_5 + 321552120125*uk_50 + 578793816225*uk_51 + 110613929323*uk_52 + 153424975*uk_53 + 262217230*uk_54 + 119950435*uk_55 + 78107260*uk_56 + 348693125*uk_57 + 627647625*uk_58 + 119950435*uk_59 + 125*uk_6 + 448153084*uk_60 + 205006198*uk_61 + 133492408*uk_62 + 595948250*uk_63 + 1072706850*uk_64 + 205006198*uk_65 + 93779431*uk_66 + 61065676*uk_67 + 272614625*uk_68 + 490706325*uk_69 + 225*uk_7 + 93779431*uk_70 + 39763696*uk_71 + 177516500*uk_72 + 319529700*uk_73 + 61065676*uk_74 + 792484375*uk_75 + 1426471875*uk_76 + 272614625*uk_77 + 2567649375*uk_78 + 490706325*uk_79 + 43*uk_8 + 93779431*uk_80 + 166375*uk_81 + 284350*uk_82 + 130075*uk_83 + 84700*uk_84 + 378125*uk_85 + 680625*uk_86 + 130075*uk_87 + 485980*uk_88 + 222310*uk_89 + 2572416961*uk_9 + 144760*uk_90 + 646250*uk_91 + 1163250*uk_92 + 222310*uk_93 + 101695*uk_94 + 66220*uk_95 + 295625*uk_96 + 532125*uk_97 + 101695*uk_98 + 43120*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 167640*uk_100 + 297000*uk_101 + 124080*uk_102 + 887095*uk_103 + 1571625*uk_104 + 656590*uk_105 + 2784375*uk_106 + 1163250*uk_107 + 485980*uk_108 + 97336*uk_109 + 2333074*uk_11 + 198904*uk_110 + 50784*uk_111 + 268732*uk_112 + 476100*uk_113 + 198904*uk_114 + 406456*uk_115 + 103776*uk_116 + 549148*uk_117 + 972900*uk_118 + 406456*uk_119 + 4767586*uk_12 + 26496*uk_120 + 140208*uk_121 + 248400*uk_122 + 103776*uk_123 + 741934*uk_124 + 1314450*uk_125 + 549148*uk_126 + 2328750*uk_127 + 972900*uk_128 + 406456*uk_129 + 1217256*uk_13 + 830584*uk_130 + 212064*uk_131 + 1122172*uk_132 + 1988100*uk_133 + 830584*uk_134 + 54144*uk_135 + 286512*uk_136 + 507600*uk_137 + 212064*uk_138 + 1516126*uk_139 + 6441313*uk_14 + 2686050*uk_140 + 1122172*uk_141 + 4758750*uk_142 + 1988100*uk_143 + 830584*uk_144 + 13824*uk_145 + 73152*uk_146 + 129600*uk_147 + 54144*uk_148 + 387096*uk_149 + 11411775*uk_15 + 685800*uk_150 + 286512*uk_151 + 1215000*uk_152 + 507600*uk_153 + 212064*uk_154 + 2048383*uk_155 + 3629025*uk_156 + 1516126*uk_157 + 6429375*uk_158 + 2686050*uk_159 + 4767586*uk_16 + 1122172*uk_160 + 11390625*uk_161 + 4758750*uk_162 + 1988100*uk_163 + 830584*uk_164 + 3025*uk_17 + 2530*uk_18 + 5170*uk_19 + 55*uk_2 + 1320*uk_20 + 6985*uk_21 + 12375*uk_22 + 5170*uk_23 + 2116*uk_24 + 4324*uk_25 + 1104*uk_26 + 5842*uk_27 + 10350*uk_28 + 4324*uk_29 + 46*uk_3 + 8836*uk_30 + 2256*uk_31 + 11938*uk_32 + 21150*uk_33 + 8836*uk_34 + 576*uk_35 + 3048*uk_36 + 5400*uk_37 + 2256*uk_38 + 16129*uk_39 + 94*uk_4 + 28575*uk_40 + 11938*uk_41 + 50625*uk_42 + 21150*uk_43 + 8836*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 118331180206*uk_47 + 241807194334*uk_48 + 61738007064*uk_49 + 24*uk_5 + 326696954047*uk_50 + 578793816225*uk_51 + 241807194334*uk_52 + 153424975*uk_53 + 128319070*uk_54 + 262217230*uk_55 + 66949080*uk_56 + 354272215*uk_57 + 627647625*uk_58 + 262217230*uk_59 + 127*uk_6 + 107321404*uk_60 + 219308956*uk_61 + 55993776*uk_62 + 296300398*uk_63 + 524941650*uk_64 + 219308956*uk_65 + 448153084*uk_66 + 114422064*uk_67 + 605483422*uk_68 + 1072706850*uk_69 + 225*uk_7 + 448153084*uk_70 + 29214144*uk_71 + 154591512*uk_72 + 273882600*uk_73 + 114422064*uk_74 + 818046751*uk_75 + 1449295425*uk_76 + 605483422*uk_77 + 2567649375*uk_78 + 1072706850*uk_79 + 94*uk_8 + 448153084*uk_80 + 166375*uk_81 + 139150*uk_82 + 284350*uk_83 + 72600*uk_84 + 384175*uk_85 + 680625*uk_86 + 284350*uk_87 + 116380*uk_88 + 237820*uk_89 + 2572416961*uk_9 + 60720*uk_90 + 321310*uk_91 + 569250*uk_92 + 237820*uk_93 + 485980*uk_94 + 124080*uk_95 + 656590*uk_96 + 1163250*uk_97 + 485980*uk_98 + 31680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 170280*uk_100 + 297000*uk_101 + 60720*uk_102 + 915255*uk_103 + 1596375*uk_104 + 326370*uk_105 + 2784375*uk_106 + 569250*uk_107 + 116380*uk_108 + 10648*uk_109 + 1115818*uk_11 + 22264*uk_110 + 11616*uk_111 + 62436*uk_112 + 108900*uk_113 + 22264*uk_114 + 46552*uk_115 + 24288*uk_116 + 130548*uk_117 + 227700*uk_118 + 46552*uk_119 + 2333074*uk_12 + 12672*uk_120 + 68112*uk_121 + 118800*uk_122 + 24288*uk_123 + 366102*uk_124 + 638550*uk_125 + 130548*uk_126 + 1113750*uk_127 + 227700*uk_128 + 46552*uk_129 + 1217256*uk_13 + 97336*uk_130 + 50784*uk_131 + 272964*uk_132 + 476100*uk_133 + 97336*uk_134 + 26496*uk_135 + 142416*uk_136 + 248400*uk_137 + 50784*uk_138 + 765486*uk_139 + 6542751*uk_14 + 1335150*uk_140 + 272964*uk_141 + 2328750*uk_142 + 476100*uk_143 + 97336*uk_144 + 13824*uk_145 + 74304*uk_146 + 129600*uk_147 + 26496*uk_148 + 399384*uk_149 + 11411775*uk_15 + 696600*uk_150 + 142416*uk_151 + 1215000*uk_152 + 248400*uk_153 + 50784*uk_154 + 2146689*uk_155 + 3744225*uk_156 + 765486*uk_157 + 6530625*uk_158 + 1335150*uk_159 + 2333074*uk_16 + 272964*uk_160 + 11390625*uk_161 + 2328750*uk_162 + 476100*uk_163 + 97336*uk_164 + 3025*uk_17 + 1210*uk_18 + 2530*uk_19 + 55*uk_2 + 1320*uk_20 + 7095*uk_21 + 12375*uk_22 + 2530*uk_23 + 484*uk_24 + 1012*uk_25 + 528*uk_26 + 2838*uk_27 + 4950*uk_28 + 1012*uk_29 + 22*uk_3 + 2116*uk_30 + 1104*uk_31 + 5934*uk_32 + 10350*uk_33 + 2116*uk_34 + 576*uk_35 + 3096*uk_36 + 5400*uk_37 + 1104*uk_38 + 16641*uk_39 + 46*uk_4 + 29025*uk_40 + 5934*uk_41 + 50625*uk_42 + 10350*uk_43 + 2116*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 56593173142*uk_47 + 118331180206*uk_48 + 61738007064*uk_49 + 24*uk_5 + 331841787969*uk_50 + 578793816225*uk_51 + 118331180206*uk_52 + 153424975*uk_53 + 61369990*uk_54 + 128319070*uk_55 + 66949080*uk_56 + 359851305*uk_57 + 627647625*uk_58 + 128319070*uk_59 + 129*uk_6 + 24547996*uk_60 + 51327628*uk_61 + 26779632*uk_62 + 143940522*uk_63 + 251059050*uk_64 + 51327628*uk_65 + 107321404*uk_66 + 55993776*uk_67 + 300966546*uk_68 + 524941650*uk_69 + 225*uk_7 + 107321404*uk_70 + 29214144*uk_71 + 157026024*uk_72 + 273882600*uk_73 + 55993776*uk_74 + 844014879*uk_75 + 1472118975*uk_76 + 300966546*uk_77 + 2567649375*uk_78 + 524941650*uk_79 + 46*uk_8 + 107321404*uk_80 + 166375*uk_81 + 66550*uk_82 + 139150*uk_83 + 72600*uk_84 + 390225*uk_85 + 680625*uk_86 + 139150*uk_87 + 26620*uk_88 + 55660*uk_89 + 2572416961*uk_9 + 29040*uk_90 + 156090*uk_91 + 272250*uk_92 + 55660*uk_93 + 116380*uk_94 + 60720*uk_95 + 326370*uk_96 + 569250*uk_97 + 116380*uk_98 + 31680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 172920*uk_100 + 297000*uk_101 + 29040*uk_102 + 943855*uk_103 + 1621125*uk_104 + 158510*uk_105 + 2784375*uk_106 + 272250*uk_107 + 26620*uk_108 + 10648*uk_109 + 1115818*uk_11 + 10648*uk_110 + 11616*uk_111 + 63404*uk_112 + 108900*uk_113 + 10648*uk_114 + 10648*uk_115 + 11616*uk_116 + 63404*uk_117 + 108900*uk_118 + 10648*uk_119 + 1115818*uk_12 + 12672*uk_120 + 69168*uk_121 + 118800*uk_122 + 11616*uk_123 + 377542*uk_124 + 648450*uk_125 + 63404*uk_126 + 1113750*uk_127 + 108900*uk_128 + 10648*uk_129 + 1217256*uk_13 + 10648*uk_130 + 11616*uk_131 + 63404*uk_132 + 108900*uk_133 + 10648*uk_134 + 12672*uk_135 + 69168*uk_136 + 118800*uk_137 + 11616*uk_138 + 377542*uk_139 + 6644189*uk_14 + 648450*uk_140 + 63404*uk_141 + 1113750*uk_142 + 108900*uk_143 + 10648*uk_144 + 13824*uk_145 + 75456*uk_146 + 129600*uk_147 + 12672*uk_148 + 411864*uk_149 + 11411775*uk_15 + 707400*uk_150 + 69168*uk_151 + 1215000*uk_152 + 118800*uk_153 + 11616*uk_154 + 2248091*uk_155 + 3861225*uk_156 + 377542*uk_157 + 6631875*uk_158 + 648450*uk_159 + 1115818*uk_16 + 63404*uk_160 + 11390625*uk_161 + 1113750*uk_162 + 108900*uk_163 + 10648*uk_164 + 3025*uk_17 + 1210*uk_18 + 1210*uk_19 + 55*uk_2 + 1320*uk_20 + 7205*uk_21 + 12375*uk_22 + 1210*uk_23 + 484*uk_24 + 484*uk_25 + 528*uk_26 + 2882*uk_27 + 4950*uk_28 + 484*uk_29 + 22*uk_3 + 484*uk_30 + 528*uk_31 + 2882*uk_32 + 4950*uk_33 + 484*uk_34 + 576*uk_35 + 3144*uk_36 + 5400*uk_37 + 528*uk_38 + 17161*uk_39 + 22*uk_4 + 29475*uk_40 + 2882*uk_41 + 50625*uk_42 + 4950*uk_43 + 484*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 56593173142*uk_47 + 56593173142*uk_48 + 61738007064*uk_49 + 24*uk_5 + 336986621891*uk_50 + 578793816225*uk_51 + 56593173142*uk_52 + 153424975*uk_53 + 61369990*uk_54 + 61369990*uk_55 + 66949080*uk_56 + 365430395*uk_57 + 627647625*uk_58 + 61369990*uk_59 + 131*uk_6 + 24547996*uk_60 + 24547996*uk_61 + 26779632*uk_62 + 146172158*uk_63 + 251059050*uk_64 + 24547996*uk_65 + 24547996*uk_66 + 26779632*uk_67 + 146172158*uk_68 + 251059050*uk_69 + 225*uk_7 + 24547996*uk_70 + 29214144*uk_71 + 159460536*uk_72 + 273882600*uk_73 + 26779632*uk_74 + 870388759*uk_75 + 1494942525*uk_76 + 146172158*uk_77 + 2567649375*uk_78 + 251059050*uk_79 + 22*uk_8 + 24547996*uk_80 + 166375*uk_81 + 66550*uk_82 + 66550*uk_83 + 72600*uk_84 + 396275*uk_85 + 680625*uk_86 + 66550*uk_87 + 26620*uk_88 + 26620*uk_89 + 2572416961*uk_9 + 29040*uk_90 + 158510*uk_91 + 272250*uk_92 + 26620*uk_93 + 26620*uk_94 + 29040*uk_95 + 158510*uk_96 + 272250*uk_97 + 26620*uk_98 + 31680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 175560*uk_100 + 297000*uk_101 + 29040*uk_102 + 972895*uk_103 + 1645875*uk_104 + 160930*uk_105 + 2784375*uk_106 + 272250*uk_107 + 26620*uk_108 + 97336*uk_109 + 2333074*uk_11 + 46552*uk_110 + 50784*uk_111 + 281428*uk_112 + 476100*uk_113 + 46552*uk_114 + 22264*uk_115 + 24288*uk_116 + 134596*uk_117 + 227700*uk_118 + 22264*uk_119 + 1115818*uk_12 + 26496*uk_120 + 146832*uk_121 + 248400*uk_122 + 24288*uk_123 + 813694*uk_124 + 1376550*uk_125 + 134596*uk_126 + 2328750*uk_127 + 227700*uk_128 + 22264*uk_129 + 1217256*uk_13 + 10648*uk_130 + 11616*uk_131 + 64372*uk_132 + 108900*uk_133 + 10648*uk_134 + 12672*uk_135 + 70224*uk_136 + 118800*uk_137 + 11616*uk_138 + 389158*uk_139 + 6745627*uk_14 + 658350*uk_140 + 64372*uk_141 + 1113750*uk_142 + 108900*uk_143 + 10648*uk_144 + 13824*uk_145 + 76608*uk_146 + 129600*uk_147 + 12672*uk_148 + 424536*uk_149 + 11411775*uk_15 + 718200*uk_150 + 70224*uk_151 + 1215000*uk_152 + 118800*uk_153 + 11616*uk_154 + 2352637*uk_155 + 3980025*uk_156 + 389158*uk_157 + 6733125*uk_158 + 658350*uk_159 + 1115818*uk_16 + 64372*uk_160 + 11390625*uk_161 + 1113750*uk_162 + 108900*uk_163 + 10648*uk_164 + 3025*uk_17 + 2530*uk_18 + 1210*uk_19 + 55*uk_2 + 1320*uk_20 + 7315*uk_21 + 12375*uk_22 + 1210*uk_23 + 2116*uk_24 + 1012*uk_25 + 1104*uk_26 + 6118*uk_27 + 10350*uk_28 + 1012*uk_29 + 46*uk_3 + 484*uk_30 + 528*uk_31 + 2926*uk_32 + 4950*uk_33 + 484*uk_34 + 576*uk_35 + 3192*uk_36 + 5400*uk_37 + 528*uk_38 + 17689*uk_39 + 22*uk_4 + 29925*uk_40 + 2926*uk_41 + 50625*uk_42 + 4950*uk_43 + 484*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 118331180206*uk_47 + 56593173142*uk_48 + 61738007064*uk_49 + 24*uk_5 + 342131455813*uk_50 + 578793816225*uk_51 + 56593173142*uk_52 + 153424975*uk_53 + 128319070*uk_54 + 61369990*uk_55 + 66949080*uk_56 + 371009485*uk_57 + 627647625*uk_58 + 61369990*uk_59 + 133*uk_6 + 107321404*uk_60 + 51327628*uk_61 + 55993776*uk_62 + 310298842*uk_63 + 524941650*uk_64 + 51327628*uk_65 + 24547996*uk_66 + 26779632*uk_67 + 148403794*uk_68 + 251059050*uk_69 + 225*uk_7 + 24547996*uk_70 + 29214144*uk_71 + 161895048*uk_72 + 273882600*uk_73 + 26779632*uk_74 + 897168391*uk_75 + 1517766075*uk_76 + 148403794*uk_77 + 2567649375*uk_78 + 251059050*uk_79 + 22*uk_8 + 24547996*uk_80 + 166375*uk_81 + 139150*uk_82 + 66550*uk_83 + 72600*uk_84 + 402325*uk_85 + 680625*uk_86 + 66550*uk_87 + 116380*uk_88 + 55660*uk_89 + 2572416961*uk_9 + 60720*uk_90 + 336490*uk_91 + 569250*uk_92 + 55660*uk_93 + 26620*uk_94 + 29040*uk_95 + 160930*uk_96 + 272250*uk_97 + 26620*uk_98 + 31680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 178200*uk_100 + 297000*uk_101 + 60720*uk_102 + 1002375*uk_103 + 1670625*uk_104 + 341550*uk_105 + 2784375*uk_106 + 569250*uk_107 + 116380*uk_108 + 830584*uk_109 + 4767586*uk_11 + 406456*uk_110 + 212064*uk_111 + 1192860*uk_112 + 1988100*uk_113 + 406456*uk_114 + 198904*uk_115 + 103776*uk_116 + 583740*uk_117 + 972900*uk_118 + 198904*uk_119 + 2333074*uk_12 + 54144*uk_120 + 304560*uk_121 + 507600*uk_122 + 103776*uk_123 + 1713150*uk_124 + 2855250*uk_125 + 583740*uk_126 + 4758750*uk_127 + 972900*uk_128 + 198904*uk_129 + 1217256*uk_13 + 97336*uk_130 + 50784*uk_131 + 285660*uk_132 + 476100*uk_133 + 97336*uk_134 + 26496*uk_135 + 149040*uk_136 + 248400*uk_137 + 50784*uk_138 + 838350*uk_139 + 6847065*uk_14 + 1397250*uk_140 + 285660*uk_141 + 2328750*uk_142 + 476100*uk_143 + 97336*uk_144 + 13824*uk_145 + 77760*uk_146 + 129600*uk_147 + 26496*uk_148 + 437400*uk_149 + 11411775*uk_15 + 729000*uk_150 + 149040*uk_151 + 1215000*uk_152 + 248400*uk_153 + 50784*uk_154 + 2460375*uk_155 + 4100625*uk_156 + 838350*uk_157 + 6834375*uk_158 + 1397250*uk_159 + 2333074*uk_16 + 285660*uk_160 + 11390625*uk_161 + 2328750*uk_162 + 476100*uk_163 + 97336*uk_164 + 3025*uk_17 + 5170*uk_18 + 2530*uk_19 + 55*uk_2 + 1320*uk_20 + 7425*uk_21 + 12375*uk_22 + 2530*uk_23 + 8836*uk_24 + 4324*uk_25 + 2256*uk_26 + 12690*uk_27 + 21150*uk_28 + 4324*uk_29 + 94*uk_3 + 2116*uk_30 + 1104*uk_31 + 6210*uk_32 + 10350*uk_33 + 2116*uk_34 + 576*uk_35 + 3240*uk_36 + 5400*uk_37 + 1104*uk_38 + 18225*uk_39 + 46*uk_4 + 30375*uk_40 + 6210*uk_41 + 50625*uk_42 + 10350*uk_43 + 2116*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 241807194334*uk_47 + 118331180206*uk_48 + 61738007064*uk_49 + 24*uk_5 + 347276289735*uk_50 + 578793816225*uk_51 + 118331180206*uk_52 + 153424975*uk_53 + 262217230*uk_54 + 128319070*uk_55 + 66949080*uk_56 + 376588575*uk_57 + 627647625*uk_58 + 128319070*uk_59 + 135*uk_6 + 448153084*uk_60 + 219308956*uk_61 + 114422064*uk_62 + 643624110*uk_63 + 1072706850*uk_64 + 219308956*uk_65 + 107321404*uk_66 + 55993776*uk_67 + 314964990*uk_68 + 524941650*uk_69 + 225*uk_7 + 107321404*uk_70 + 29214144*uk_71 + 164329560*uk_72 + 273882600*uk_73 + 55993776*uk_74 + 924353775*uk_75 + 1540589625*uk_76 + 314964990*uk_77 + 2567649375*uk_78 + 524941650*uk_79 + 46*uk_8 + 107321404*uk_80 + 166375*uk_81 + 284350*uk_82 + 139150*uk_83 + 72600*uk_84 + 408375*uk_85 + 680625*uk_86 + 139150*uk_87 + 485980*uk_88 + 237820*uk_89 + 2572416961*uk_9 + 124080*uk_90 + 697950*uk_91 + 1163250*uk_92 + 237820*uk_93 + 116380*uk_94 + 60720*uk_95 + 341550*uk_96 + 569250*uk_97 + 116380*uk_98 + 31680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 150700*uk_100 + 247500*uk_101 + 103400*uk_102 + 1032295*uk_103 + 1695375*uk_104 + 708290*uk_105 + 2784375*uk_106 + 1163250*uk_107 + 485980*uk_108 + 24389*uk_109 + 1470851*uk_11 + 79054*uk_110 + 16820*uk_111 + 115217*uk_112 + 189225*uk_113 + 79054*uk_114 + 256244*uk_115 + 54520*uk_116 + 373462*uk_117 + 613350*uk_118 + 256244*uk_119 + 4767586*uk_12 + 11600*uk_120 + 79460*uk_121 + 130500*uk_122 + 54520*uk_123 + 544301*uk_124 + 893925*uk_125 + 373462*uk_126 + 1468125*uk_127 + 613350*uk_128 + 256244*uk_129 + 1014380*uk_13 + 830584*uk_130 + 176720*uk_131 + 1210532*uk_132 + 1988100*uk_133 + 830584*uk_134 + 37600*uk_135 + 257560*uk_136 + 423000*uk_137 + 176720*uk_138 + 1764286*uk_139 + 6948503*uk_14 + 2897550*uk_140 + 1210532*uk_141 + 4758750*uk_142 + 1988100*uk_143 + 830584*uk_144 + 8000*uk_145 + 54800*uk_146 + 90000*uk_147 + 37600*uk_148 + 375380*uk_149 + 11411775*uk_15 + 616500*uk_150 + 257560*uk_151 + 1012500*uk_152 + 423000*uk_153 + 176720*uk_154 + 2571353*uk_155 + 4223025*uk_156 + 1764286*uk_157 + 6935625*uk_158 + 2897550*uk_159 + 4767586*uk_16 + 1210532*uk_160 + 11390625*uk_161 + 4758750*uk_162 + 1988100*uk_163 + 830584*uk_164 + 3025*uk_17 + 1595*uk_18 + 5170*uk_19 + 55*uk_2 + 1100*uk_20 + 7535*uk_21 + 12375*uk_22 + 5170*uk_23 + 841*uk_24 + 2726*uk_25 + 580*uk_26 + 3973*uk_27 + 6525*uk_28 + 2726*uk_29 + 29*uk_3 + 8836*uk_30 + 1880*uk_31 + 12878*uk_32 + 21150*uk_33 + 8836*uk_34 + 400*uk_35 + 2740*uk_36 + 4500*uk_37 + 1880*uk_38 + 18769*uk_39 + 94*uk_4 + 30825*uk_40 + 12878*uk_41 + 50625*uk_42 + 21150*uk_43 + 8836*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 74600091869*uk_47 + 241807194334*uk_48 + 51448339220*uk_49 + 20*uk_5 + 352421123657*uk_50 + 578793816225*uk_51 + 241807194334*uk_52 + 153424975*uk_53 + 80896805*uk_54 + 262217230*uk_55 + 55790900*uk_56 + 382167665*uk_57 + 627647625*uk_58 + 262217230*uk_59 + 137*uk_6 + 42654679*uk_60 + 138259994*uk_61 + 29417020*uk_62 + 201506587*uk_63 + 330941475*uk_64 + 138259994*uk_65 + 448153084*uk_66 + 95351720*uk_67 + 653159282*uk_68 + 1072706850*uk_69 + 225*uk_7 + 448153084*uk_70 + 20287600*uk_71 + 138970060*uk_72 + 228235500*uk_73 + 95351720*uk_74 + 951944911*uk_75 + 1563413175*uk_76 + 653159282*uk_77 + 2567649375*uk_78 + 1072706850*uk_79 + 94*uk_8 + 448153084*uk_80 + 166375*uk_81 + 87725*uk_82 + 284350*uk_83 + 60500*uk_84 + 414425*uk_85 + 680625*uk_86 + 284350*uk_87 + 46255*uk_88 + 149930*uk_89 + 2572416961*uk_9 + 31900*uk_90 + 218515*uk_91 + 358875*uk_92 + 149930*uk_93 + 485980*uk_94 + 103400*uk_95 + 708290*uk_96 + 1163250*uk_97 + 485980*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 183480*uk_100 + 297000*uk_101 + 38280*uk_102 + 1062655*uk_103 + 1720125*uk_104 + 221705*uk_105 + 2784375*uk_106 + 358875*uk_107 + 46255*uk_108 + 1860867*uk_109 + 6238437*uk_11 + 438741*uk_110 + 363096*uk_111 + 2102931*uk_112 + 3404025*uk_113 + 438741*uk_114 + 103443*uk_115 + 85608*uk_116 + 495813*uk_117 + 802575*uk_118 + 103443*uk_119 + 1470851*uk_12 + 70848*uk_120 + 410328*uk_121 + 664200*uk_122 + 85608*uk_123 + 2376483*uk_124 + 3846825*uk_125 + 495813*uk_126 + 6226875*uk_127 + 802575*uk_128 + 103443*uk_129 + 1217256*uk_13 + 24389*uk_130 + 20184*uk_131 + 116899*uk_132 + 189225*uk_133 + 24389*uk_134 + 16704*uk_135 + 96744*uk_136 + 156600*uk_137 + 20184*uk_138 + 560309*uk_139 + 7049941*uk_14 + 906975*uk_140 + 116899*uk_141 + 1468125*uk_142 + 189225*uk_143 + 24389*uk_144 + 13824*uk_145 + 80064*uk_146 + 129600*uk_147 + 16704*uk_148 + 463704*uk_149 + 11411775*uk_15 + 750600*uk_150 + 96744*uk_151 + 1215000*uk_152 + 156600*uk_153 + 20184*uk_154 + 2685619*uk_155 + 4347225*uk_156 + 560309*uk_157 + 7036875*uk_158 + 906975*uk_159 + 1470851*uk_16 + 116899*uk_160 + 11390625*uk_161 + 1468125*uk_162 + 189225*uk_163 + 24389*uk_164 + 3025*uk_17 + 6765*uk_18 + 1595*uk_19 + 55*uk_2 + 1320*uk_20 + 7645*uk_21 + 12375*uk_22 + 1595*uk_23 + 15129*uk_24 + 3567*uk_25 + 2952*uk_26 + 17097*uk_27 + 27675*uk_28 + 3567*uk_29 + 123*uk_3 + 841*uk_30 + 696*uk_31 + 4031*uk_32 + 6525*uk_33 + 841*uk_34 + 576*uk_35 + 3336*uk_36 + 5400*uk_37 + 696*uk_38 + 19321*uk_39 + 29*uk_4 + 31275*uk_40 + 4031*uk_41 + 50625*uk_42 + 6525*uk_43 + 841*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 316407286203*uk_47 + 74600091869*uk_48 + 61738007064*uk_49 + 24*uk_5 + 357565957579*uk_50 + 578793816225*uk_51 + 74600091869*uk_52 + 153424975*uk_53 + 343114035*uk_54 + 80896805*uk_55 + 66949080*uk_56 + 387746755*uk_57 + 627647625*uk_58 + 80896805*uk_59 + 139*uk_6 + 767327751*uk_60 + 180914673*uk_61 + 149722488*uk_62 + 867142743*uk_63 + 1403648325*uk_64 + 180914673*uk_65 + 42654679*uk_66 + 35300424*uk_67 + 204448289*uk_68 + 330941475*uk_69 + 225*uk_7 + 42654679*uk_70 + 29214144*uk_71 + 169198584*uk_72 + 273882600*uk_73 + 35300424*uk_74 + 979941799*uk_75 + 1586236725*uk_76 + 204448289*uk_77 + 2567649375*uk_78 + 330941475*uk_79 + 29*uk_8 + 42654679*uk_80 + 166375*uk_81 + 372075*uk_82 + 87725*uk_83 + 72600*uk_84 + 420475*uk_85 + 680625*uk_86 + 87725*uk_87 + 832095*uk_88 + 196185*uk_89 + 2572416961*uk_9 + 162360*uk_90 + 940335*uk_91 + 1522125*uk_92 + 196185*uk_93 + 46255*uk_94 + 38280*uk_95 + 221705*uk_96 + 358875*uk_97 + 46255*uk_98 + 31680*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 155100*uk_100 + 247500*uk_101 + 135300*uk_102 + 1093455*uk_103 + 1744875*uk_104 + 953865*uk_105 + 2784375*uk_106 + 1522125*uk_107 + 832095*uk_108 + 1000000*uk_109 + 5071900*uk_11 + 1230000*uk_110 + 200000*uk_111 + 1410000*uk_112 + 2250000*uk_113 + 1230000*uk_114 + 1512900*uk_115 + 246000*uk_116 + 1734300*uk_117 + 2767500*uk_118 + 1512900*uk_119 + 6238437*uk_12 + 40000*uk_120 + 282000*uk_121 + 450000*uk_122 + 246000*uk_123 + 1988100*uk_124 + 3172500*uk_125 + 1734300*uk_126 + 5062500*uk_127 + 2767500*uk_128 + 1512900*uk_129 + 1014380*uk_13 + 1860867*uk_130 + 302580*uk_131 + 2133189*uk_132 + 3404025*uk_133 + 1860867*uk_134 + 49200*uk_135 + 346860*uk_136 + 553500*uk_137 + 302580*uk_138 + 2445363*uk_139 + 7151379*uk_14 + 3902175*uk_140 + 2133189*uk_141 + 6226875*uk_142 + 3404025*uk_143 + 1860867*uk_144 + 8000*uk_145 + 56400*uk_146 + 90000*uk_147 + 49200*uk_148 + 397620*uk_149 + 11411775*uk_15 + 634500*uk_150 + 346860*uk_151 + 1012500*uk_152 + 553500*uk_153 + 302580*uk_154 + 2803221*uk_155 + 4473225*uk_156 + 2445363*uk_157 + 7138125*uk_158 + 3902175*uk_159 + 6238437*uk_16 + 2133189*uk_160 + 11390625*uk_161 + 6226875*uk_162 + 3404025*uk_163 + 1860867*uk_164 + 3025*uk_17 + 5500*uk_18 + 6765*uk_19 + 55*uk_2 + 1100*uk_20 + 7755*uk_21 + 12375*uk_22 + 6765*uk_23 + 10000*uk_24 + 12300*uk_25 + 2000*uk_26 + 14100*uk_27 + 22500*uk_28 + 12300*uk_29 + 100*uk_3 + 15129*uk_30 + 2460*uk_31 + 17343*uk_32 + 27675*uk_33 + 15129*uk_34 + 400*uk_35 + 2820*uk_36 + 4500*uk_37 + 2460*uk_38 + 19881*uk_39 + 123*uk_4 + 31725*uk_40 + 17343*uk_41 + 50625*uk_42 + 27675*uk_43 + 15129*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 257241696100*uk_47 + 316407286203*uk_48 + 51448339220*uk_49 + 20*uk_5 + 362710791501*uk_50 + 578793816225*uk_51 + 316407286203*uk_52 + 153424975*uk_53 + 278954500*uk_54 + 343114035*uk_55 + 55790900*uk_56 + 393325845*uk_57 + 627647625*uk_58 + 343114035*uk_59 + 141*uk_6 + 507190000*uk_60 + 623843700*uk_61 + 101438000*uk_62 + 715137900*uk_63 + 1141177500*uk_64 + 623843700*uk_65 + 767327751*uk_66 + 124768740*uk_67 + 879619617*uk_68 + 1403648325*uk_69 + 225*uk_7 + 767327751*uk_70 + 20287600*uk_71 + 143027580*uk_72 + 228235500*uk_73 + 124768740*uk_74 + 1008344439*uk_75 + 1609060275*uk_76 + 879619617*uk_77 + 2567649375*uk_78 + 1403648325*uk_79 + 123*uk_8 + 767327751*uk_80 + 166375*uk_81 + 302500*uk_82 + 372075*uk_83 + 60500*uk_84 + 426525*uk_85 + 680625*uk_86 + 372075*uk_87 + 550000*uk_88 + 676500*uk_89 + 2572416961*uk_9 + 110000*uk_90 + 775500*uk_91 + 1237500*uk_92 + 676500*uk_93 + 832095*uk_94 + 135300*uk_95 + 953865*uk_96 + 1522125*uk_97 + 832095*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 157300*uk_100 + 247500*uk_101 + 110000*uk_102 + 1124695*uk_103 + 1769625*uk_104 + 786500*uk_105 + 2784375*uk_106 + 1237500*uk_107 + 550000*uk_108 + 912673*uk_109 + 4919743*uk_11 + 940900*uk_110 + 188180*uk_111 + 1345487*uk_112 + 2117025*uk_113 + 940900*uk_114 + 970000*uk_115 + 194000*uk_116 + 1387100*uk_117 + 2182500*uk_118 + 970000*uk_119 + 5071900*uk_12 + 38800*uk_120 + 277420*uk_121 + 436500*uk_122 + 194000*uk_123 + 1983553*uk_124 + 3120975*uk_125 + 1387100*uk_126 + 4910625*uk_127 + 2182500*uk_128 + 970000*uk_129 + 1014380*uk_13 + 1000000*uk_130 + 200000*uk_131 + 1430000*uk_132 + 2250000*uk_133 + 1000000*uk_134 + 40000*uk_135 + 286000*uk_136 + 450000*uk_137 + 200000*uk_138 + 2044900*uk_139 + 7252817*uk_14 + 3217500*uk_140 + 1430000*uk_141 + 5062500*uk_142 + 2250000*uk_143 + 1000000*uk_144 + 8000*uk_145 + 57200*uk_146 + 90000*uk_147 + 40000*uk_148 + 408980*uk_149 + 11411775*uk_15 + 643500*uk_150 + 286000*uk_151 + 1012500*uk_152 + 450000*uk_153 + 200000*uk_154 + 2924207*uk_155 + 4601025*uk_156 + 2044900*uk_157 + 7239375*uk_158 + 3217500*uk_159 + 5071900*uk_16 + 1430000*uk_160 + 11390625*uk_161 + 5062500*uk_162 + 2250000*uk_163 + 1000000*uk_164 + 3025*uk_17 + 5335*uk_18 + 5500*uk_19 + 55*uk_2 + 1100*uk_20 + 7865*uk_21 + 12375*uk_22 + 5500*uk_23 + 9409*uk_24 + 9700*uk_25 + 1940*uk_26 + 13871*uk_27 + 21825*uk_28 + 9700*uk_29 + 97*uk_3 + 10000*uk_30 + 2000*uk_31 + 14300*uk_32 + 22500*uk_33 + 10000*uk_34 + 400*uk_35 + 2860*uk_36 + 4500*uk_37 + 2000*uk_38 + 20449*uk_39 + 100*uk_4 + 32175*uk_40 + 14300*uk_41 + 50625*uk_42 + 22500*uk_43 + 10000*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 249524445217*uk_47 + 257241696100*uk_48 + 51448339220*uk_49 + 20*uk_5 + 367855625423*uk_50 + 578793816225*uk_51 + 257241696100*uk_52 + 153424975*uk_53 + 270585865*uk_54 + 278954500*uk_55 + 55790900*uk_56 + 398904935*uk_57 + 627647625*uk_58 + 278954500*uk_59 + 143*uk_6 + 477215071*uk_60 + 491974300*uk_61 + 98394860*uk_62 + 703523249*uk_63 + 1106942175*uk_64 + 491974300*uk_65 + 507190000*uk_66 + 101438000*uk_67 + 725281700*uk_68 + 1141177500*uk_69 + 225*uk_7 + 507190000*uk_70 + 20287600*uk_71 + 145056340*uk_72 + 228235500*uk_73 + 101438000*uk_74 + 1037152831*uk_75 + 1631883825*uk_76 + 725281700*uk_77 + 2567649375*uk_78 + 1141177500*uk_79 + 100*uk_8 + 507190000*uk_80 + 166375*uk_81 + 293425*uk_82 + 302500*uk_83 + 60500*uk_84 + 432575*uk_85 + 680625*uk_86 + 302500*uk_87 + 517495*uk_88 + 533500*uk_89 + 2572416961*uk_9 + 106700*uk_90 + 762905*uk_91 + 1200375*uk_92 + 533500*uk_93 + 550000*uk_94 + 110000*uk_95 + 786500*uk_96 + 1237500*uk_97 + 550000*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 159500*uk_100 + 247500*uk_101 + 106700*uk_102 + 1156375*uk_103 + 1794375*uk_104 + 773575*uk_105 + 2784375*uk_106 + 1200375*uk_107 + 517495*uk_108 + 1481544*uk_109 + 5781966*uk_11 + 1260612*uk_110 + 259920*uk_111 + 1884420*uk_112 + 2924100*uk_113 + 1260612*uk_114 + 1072626*uk_115 + 221160*uk_116 + 1603410*uk_117 + 2488050*uk_118 + 1072626*uk_119 + 4919743*uk_12 + 45600*uk_120 + 330600*uk_121 + 513000*uk_122 + 221160*uk_123 + 2396850*uk_124 + 3719250*uk_125 + 1603410*uk_126 + 5771250*uk_127 + 2488050*uk_128 + 1072626*uk_129 + 1014380*uk_13 + 912673*uk_130 + 188180*uk_131 + 1364305*uk_132 + 2117025*uk_133 + 912673*uk_134 + 38800*uk_135 + 281300*uk_136 + 436500*uk_137 + 188180*uk_138 + 2039425*uk_139 + 7354255*uk_14 + 3164625*uk_140 + 1364305*uk_141 + 4910625*uk_142 + 2117025*uk_143 + 912673*uk_144 + 8000*uk_145 + 58000*uk_146 + 90000*uk_147 + 38800*uk_148 + 420500*uk_149 + 11411775*uk_15 + 652500*uk_150 + 281300*uk_151 + 1012500*uk_152 + 436500*uk_153 + 188180*uk_154 + 3048625*uk_155 + 4730625*uk_156 + 2039425*uk_157 + 7340625*uk_158 + 3164625*uk_159 + 4919743*uk_16 + 1364305*uk_160 + 11390625*uk_161 + 4910625*uk_162 + 2117025*uk_163 + 912673*uk_164 + 3025*uk_17 + 6270*uk_18 + 5335*uk_19 + 55*uk_2 + 1100*uk_20 + 7975*uk_21 + 12375*uk_22 + 5335*uk_23 + 12996*uk_24 + 11058*uk_25 + 2280*uk_26 + 16530*uk_27 + 25650*uk_28 + 11058*uk_29 + 114*uk_3 + 9409*uk_30 + 1940*uk_31 + 14065*uk_32 + 21825*uk_33 + 9409*uk_34 + 400*uk_35 + 2900*uk_36 + 4500*uk_37 + 1940*uk_38 + 21025*uk_39 + 97*uk_4 + 32625*uk_40 + 14065*uk_41 + 50625*uk_42 + 21825*uk_43 + 9409*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 293255533554*uk_47 + 249524445217*uk_48 + 51448339220*uk_49 + 20*uk_5 + 373000459345*uk_50 + 578793816225*uk_51 + 249524445217*uk_52 + 153424975*uk_53 + 318008130*uk_54 + 270585865*uk_55 + 55790900*uk_56 + 404484025*uk_57 + 627647625*uk_58 + 270585865*uk_59 + 145*uk_6 + 659144124*uk_60 + 560850702*uk_61 + 115639320*uk_62 + 838385070*uk_63 + 1300942350*uk_64 + 560850702*uk_65 + 477215071*uk_66 + 98394860*uk_67 + 713362735*uk_68 + 1106942175*uk_69 + 225*uk_7 + 477215071*uk_70 + 20287600*uk_71 + 147085100*uk_72 + 228235500*uk_73 + 98394860*uk_74 + 1066366975*uk_75 + 1654707375*uk_76 + 713362735*uk_77 + 2567649375*uk_78 + 1106942175*uk_79 + 97*uk_8 + 477215071*uk_80 + 166375*uk_81 + 344850*uk_82 + 293425*uk_83 + 60500*uk_84 + 438625*uk_85 + 680625*uk_86 + 293425*uk_87 + 714780*uk_88 + 608190*uk_89 + 2572416961*uk_9 + 125400*uk_90 + 909150*uk_91 + 1410750*uk_92 + 608190*uk_93 + 517495*uk_94 + 106700*uk_95 + 773575*uk_96 + 1200375*uk_97 + 517495*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 129360*uk_100 + 198000*uk_101 + 100320*uk_102 + 1188495*uk_103 + 1819125*uk_104 + 921690*uk_105 + 2784375*uk_106 + 1410750*uk_107 + 714780*uk_108 + 64*uk_109 + 202876*uk_11 + 1824*uk_110 + 256*uk_111 + 2352*uk_112 + 3600*uk_113 + 1824*uk_114 + 51984*uk_115 + 7296*uk_116 + 67032*uk_117 + 102600*uk_118 + 51984*uk_119 + 5781966*uk_12 + 1024*uk_120 + 9408*uk_121 + 14400*uk_122 + 7296*uk_123 + 86436*uk_124 + 132300*uk_125 + 67032*uk_126 + 202500*uk_127 + 102600*uk_128 + 51984*uk_129 + 811504*uk_13 + 1481544*uk_130 + 207936*uk_131 + 1910412*uk_132 + 2924100*uk_133 + 1481544*uk_134 + 29184*uk_135 + 268128*uk_136 + 410400*uk_137 + 207936*uk_138 + 2463426*uk_139 + 7455693*uk_14 + 3770550*uk_140 + 1910412*uk_141 + 5771250*uk_142 + 2924100*uk_143 + 1481544*uk_144 + 4096*uk_145 + 37632*uk_146 + 57600*uk_147 + 29184*uk_148 + 345744*uk_149 + 11411775*uk_15 + 529200*uk_150 + 268128*uk_151 + 810000*uk_152 + 410400*uk_153 + 207936*uk_154 + 3176523*uk_155 + 4862025*uk_156 + 2463426*uk_157 + 7441875*uk_158 + 3770550*uk_159 + 5781966*uk_16 + 1910412*uk_160 + 11390625*uk_161 + 5771250*uk_162 + 2924100*uk_163 + 1481544*uk_164 + 3025*uk_17 + 220*uk_18 + 6270*uk_19 + 55*uk_2 + 880*uk_20 + 8085*uk_21 + 12375*uk_22 + 6270*uk_23 + 16*uk_24 + 456*uk_25 + 64*uk_26 + 588*uk_27 + 900*uk_28 + 456*uk_29 + 4*uk_3 + 12996*uk_30 + 1824*uk_31 + 16758*uk_32 + 25650*uk_33 + 12996*uk_34 + 256*uk_35 + 2352*uk_36 + 3600*uk_37 + 1824*uk_38 + 21609*uk_39 + 114*uk_4 + 33075*uk_40 + 16758*uk_41 + 50625*uk_42 + 25650*uk_43 + 12996*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 10289667844*uk_47 + 293255533554*uk_48 + 41158671376*uk_49 + 16*uk_5 + 378145293267*uk_50 + 578793816225*uk_51 + 293255533554*uk_52 + 153424975*uk_53 + 11158180*uk_54 + 318008130*uk_55 + 44632720*uk_56 + 410063115*uk_57 + 627647625*uk_58 + 318008130*uk_59 + 147*uk_6 + 811504*uk_60 + 23127864*uk_61 + 3246016*uk_62 + 29822772*uk_63 + 45647100*uk_64 + 23127864*uk_65 + 659144124*uk_66 + 92511456*uk_67 + 849949002*uk_68 + 1300942350*uk_69 + 225*uk_7 + 659144124*uk_70 + 12984064*uk_71 + 119291088*uk_72 + 182588400*uk_73 + 92511456*uk_74 + 1095986871*uk_75 + 1677530925*uk_76 + 849949002*uk_77 + 2567649375*uk_78 + 1300942350*uk_79 + 114*uk_8 + 659144124*uk_80 + 166375*uk_81 + 12100*uk_82 + 344850*uk_83 + 48400*uk_84 + 444675*uk_85 + 680625*uk_86 + 344850*uk_87 + 880*uk_88 + 25080*uk_89 + 2572416961*uk_9 + 3520*uk_90 + 32340*uk_91 + 49500*uk_92 + 25080*uk_93 + 714780*uk_94 + 100320*uk_95 + 921690*uk_96 + 1410750*uk_97 + 714780*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 163900*uk_100 + 247500*uk_101 + 4400*uk_102 + 1221055*uk_103 + 1843875*uk_104 + 32780*uk_105 + 2784375*uk_106 + 49500*uk_107 + 880*uk_108 + 205379*uk_109 + 2992421*uk_11 + 13924*uk_110 + 69620*uk_111 + 518669*uk_112 + 783225*uk_113 + 13924*uk_114 + 944*uk_115 + 4720*uk_116 + 35164*uk_117 + 53100*uk_118 + 944*uk_119 + 202876*uk_12 + 23600*uk_120 + 175820*uk_121 + 265500*uk_122 + 4720*uk_123 + 1309859*uk_124 + 1977975*uk_125 + 35164*uk_126 + 2986875*uk_127 + 53100*uk_128 + 944*uk_129 + 1014380*uk_13 + 64*uk_130 + 320*uk_131 + 2384*uk_132 + 3600*uk_133 + 64*uk_134 + 1600*uk_135 + 11920*uk_136 + 18000*uk_137 + 320*uk_138 + 88804*uk_139 + 7557131*uk_14 + 134100*uk_140 + 2384*uk_141 + 202500*uk_142 + 3600*uk_143 + 64*uk_144 + 8000*uk_145 + 59600*uk_146 + 90000*uk_147 + 1600*uk_148 + 444020*uk_149 + 11411775*uk_15 + 670500*uk_150 + 11920*uk_151 + 1012500*uk_152 + 18000*uk_153 + 320*uk_154 + 3307949*uk_155 + 4995225*uk_156 + 88804*uk_157 + 7543125*uk_158 + 134100*uk_159 + 202876*uk_16 + 2384*uk_160 + 11390625*uk_161 + 202500*uk_162 + 3600*uk_163 + 64*uk_164 + 3025*uk_17 + 3245*uk_18 + 220*uk_19 + 55*uk_2 + 1100*uk_20 + 8195*uk_21 + 12375*uk_22 + 220*uk_23 + 3481*uk_24 + 236*uk_25 + 1180*uk_26 + 8791*uk_27 + 13275*uk_28 + 236*uk_29 + 59*uk_3 + 16*uk_30 + 80*uk_31 + 596*uk_32 + 900*uk_33 + 16*uk_34 + 400*uk_35 + 2980*uk_36 + 4500*uk_37 + 80*uk_38 + 22201*uk_39 + 4*uk_4 + 33525*uk_40 + 596*uk_41 + 50625*uk_42 + 900*uk_43 + 16*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 151772600699*uk_47 + 10289667844*uk_48 + 51448339220*uk_49 + 20*uk_5 + 383290127189*uk_50 + 578793816225*uk_51 + 10289667844*uk_52 + 153424975*uk_53 + 164583155*uk_54 + 11158180*uk_55 + 55790900*uk_56 + 415642205*uk_57 + 627647625*uk_58 + 11158180*uk_59 + 149*uk_6 + 176552839*uk_60 + 11969684*uk_61 + 59848420*uk_62 + 445870729*uk_63 + 673294725*uk_64 + 11969684*uk_65 + 811504*uk_66 + 4057520*uk_67 + 30228524*uk_68 + 45647100*uk_69 + 225*uk_7 + 811504*uk_70 + 20287600*uk_71 + 151142620*uk_72 + 228235500*uk_73 + 4057520*uk_74 + 1126012519*uk_75 + 1700354475*uk_76 + 30228524*uk_77 + 2567649375*uk_78 + 45647100*uk_79 + 4*uk_8 + 811504*uk_80 + 166375*uk_81 + 178475*uk_82 + 12100*uk_83 + 60500*uk_84 + 450725*uk_85 + 680625*uk_86 + 12100*uk_87 + 191455*uk_88 + 12980*uk_89 + 2572416961*uk_9 + 64900*uk_90 + 483505*uk_91 + 730125*uk_92 + 12980*uk_93 + 880*uk_94 + 4400*uk_95 + 32780*uk_96 + 49500*uk_97 + 880*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 166100*uk_100 + 247500*uk_101 + 64900*uk_102 + 1254055*uk_103 + 1868625*uk_104 + 489995*uk_105 + 2784375*uk_106 + 730125*uk_107 + 191455*uk_108 + 2406104*uk_109 + 6796346*uk_11 + 1059404*uk_110 + 359120*uk_111 + 2711356*uk_112 + 4040100*uk_113 + 1059404*uk_114 + 466454*uk_115 + 158120*uk_116 + 1193806*uk_117 + 1778850*uk_118 + 466454*uk_119 + 2992421*uk_12 + 53600*uk_120 + 404680*uk_121 + 603000*uk_122 + 158120*uk_123 + 3055334*uk_124 + 4552650*uk_125 + 1193806*uk_126 + 6783750*uk_127 + 1778850*uk_128 + 466454*uk_129 + 1014380*uk_13 + 205379*uk_130 + 69620*uk_131 + 525631*uk_132 + 783225*uk_133 + 205379*uk_134 + 23600*uk_135 + 178180*uk_136 + 265500*uk_137 + 69620*uk_138 + 1345259*uk_139 + 7658569*uk_14 + 2004525*uk_140 + 525631*uk_141 + 2986875*uk_142 + 783225*uk_143 + 205379*uk_144 + 8000*uk_145 + 60400*uk_146 + 90000*uk_147 + 23600*uk_148 + 456020*uk_149 + 11411775*uk_15 + 679500*uk_150 + 178180*uk_151 + 1012500*uk_152 + 265500*uk_153 + 69620*uk_154 + 3442951*uk_155 + 5130225*uk_156 + 1345259*uk_157 + 7644375*uk_158 + 2004525*uk_159 + 2992421*uk_16 + 525631*uk_160 + 11390625*uk_161 + 2986875*uk_162 + 783225*uk_163 + 205379*uk_164 + 3025*uk_17 + 7370*uk_18 + 3245*uk_19 + 55*uk_2 + 1100*uk_20 + 8305*uk_21 + 12375*uk_22 + 3245*uk_23 + 17956*uk_24 + 7906*uk_25 + 2680*uk_26 + 20234*uk_27 + 30150*uk_28 + 7906*uk_29 + 134*uk_3 + 3481*uk_30 + 1180*uk_31 + 8909*uk_32 + 13275*uk_33 + 3481*uk_34 + 400*uk_35 + 3020*uk_36 + 4500*uk_37 + 1180*uk_38 + 22801*uk_39 + 59*uk_4 + 33975*uk_40 + 8909*uk_41 + 50625*uk_42 + 13275*uk_43 + 3481*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 344703872774*uk_47 + 151772600699*uk_48 + 51448339220*uk_49 + 20*uk_5 + 388434961111*uk_50 + 578793816225*uk_51 + 151772600699*uk_52 + 153424975*uk_53 + 373799030*uk_54 + 164583155*uk_55 + 55790900*uk_56 + 421221295*uk_57 + 627647625*uk_58 + 164583155*uk_59 + 151*uk_6 + 910710364*uk_60 + 400984414*uk_61 + 135926920*uk_62 + 1026248246*uk_63 + 1529177850*uk_64 + 400984414*uk_65 + 176552839*uk_66 + 59848420*uk_67 + 451855571*uk_68 + 673294725*uk_69 + 225*uk_7 + 176552839*uk_70 + 20287600*uk_71 + 153171380*uk_72 + 228235500*uk_73 + 59848420*uk_74 + 1156443919*uk_75 + 1723178025*uk_76 + 451855571*uk_77 + 2567649375*uk_78 + 673294725*uk_79 + 59*uk_8 + 176552839*uk_80 + 166375*uk_81 + 405350*uk_82 + 178475*uk_83 + 60500*uk_84 + 456775*uk_85 + 680625*uk_86 + 178475*uk_87 + 987580*uk_88 + 434830*uk_89 + 2572416961*uk_9 + 147400*uk_90 + 1112870*uk_91 + 1658250*uk_92 + 434830*uk_93 + 191455*uk_94 + 64900*uk_95 + 489995*uk_96 + 730125*uk_97 + 191455*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 134640*uk_100 + 198000*uk_101 + 117920*uk_102 + 1287495*uk_103 + 1893375*uk_104 + 1127610*uk_105 + 2784375*uk_106 + 1658250*uk_107 + 987580*uk_108 + 438976*uk_109 + 3854644*uk_11 + 773984*uk_110 + 92416*uk_111 + 883728*uk_112 + 1299600*uk_113 + 773984*uk_114 + 1364656*uk_115 + 162944*uk_116 + 1558152*uk_117 + 2291400*uk_118 + 1364656*uk_119 + 6796346*uk_12 + 19456*uk_120 + 186048*uk_121 + 273600*uk_122 + 162944*uk_123 + 1779084*uk_124 + 2616300*uk_125 + 1558152*uk_126 + 3847500*uk_127 + 2291400*uk_128 + 1364656*uk_129 + 811504*uk_13 + 2406104*uk_130 + 287296*uk_131 + 2747268*uk_132 + 4040100*uk_133 + 2406104*uk_134 + 34304*uk_135 + 328032*uk_136 + 482400*uk_137 + 287296*uk_138 + 3136806*uk_139 + 7760007*uk_14 + 4612950*uk_140 + 2747268*uk_141 + 6783750*uk_142 + 4040100*uk_143 + 2406104*uk_144 + 4096*uk_145 + 39168*uk_146 + 57600*uk_147 + 34304*uk_148 + 374544*uk_149 + 11411775*uk_15 + 550800*uk_150 + 328032*uk_151 + 810000*uk_152 + 482400*uk_153 + 287296*uk_154 + 3581577*uk_155 + 5267025*uk_156 + 3136806*uk_157 + 7745625*uk_158 + 4612950*uk_159 + 6796346*uk_16 + 2747268*uk_160 + 11390625*uk_161 + 6783750*uk_162 + 4040100*uk_163 + 2406104*uk_164 + 3025*uk_17 + 4180*uk_18 + 7370*uk_19 + 55*uk_2 + 880*uk_20 + 8415*uk_21 + 12375*uk_22 + 7370*uk_23 + 5776*uk_24 + 10184*uk_25 + 1216*uk_26 + 11628*uk_27 + 17100*uk_28 + 10184*uk_29 + 76*uk_3 + 17956*uk_30 + 2144*uk_31 + 20502*uk_32 + 30150*uk_33 + 17956*uk_34 + 256*uk_35 + 2448*uk_36 + 3600*uk_37 + 2144*uk_38 + 23409*uk_39 + 134*uk_4 + 34425*uk_40 + 20502*uk_41 + 50625*uk_42 + 30150*uk_43 + 17956*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 195503689036*uk_47 + 344703872774*uk_48 + 41158671376*uk_49 + 16*uk_5 + 393579795033*uk_50 + 578793816225*uk_51 + 344703872774*uk_52 + 153424975*uk_53 + 212005420*uk_54 + 373799030*uk_55 + 44632720*uk_56 + 426800385*uk_57 + 627647625*uk_58 + 373799030*uk_59 + 153*uk_6 + 292952944*uk_60 + 516522296*uk_61 + 61674304*uk_62 + 589760532*uk_63 + 867294900*uk_64 + 516522296*uk_65 + 910710364*uk_66 + 108741536*uk_67 + 1039840938*uk_68 + 1529177850*uk_69 + 225*uk_7 + 910710364*uk_70 + 12984064*uk_71 + 124160112*uk_72 + 182588400*uk_73 + 108741536*uk_74 + 1187281071*uk_75 + 1746001575*uk_76 + 1039840938*uk_77 + 2567649375*uk_78 + 1529177850*uk_79 + 134*uk_8 + 910710364*uk_80 + 166375*uk_81 + 229900*uk_82 + 405350*uk_83 + 48400*uk_84 + 462825*uk_85 + 680625*uk_86 + 405350*uk_87 + 317680*uk_88 + 560120*uk_89 + 2572416961*uk_9 + 66880*uk_90 + 639540*uk_91 + 940500*uk_92 + 560120*uk_93 + 987580*uk_94 + 117920*uk_95 + 1127610*uk_96 + 1658250*uk_97 + 987580*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 136400*uk_100 + 198000*uk_101 + 66880*uk_102 + 1321375*uk_103 + 1918125*uk_104 + 647900*uk_105 + 2784375*uk_106 + 940500*uk_107 + 317680*uk_108 + 39304*uk_109 + 1724446*uk_11 + 87856*uk_110 + 18496*uk_111 + 179180*uk_112 + 260100*uk_113 + 87856*uk_114 + 196384*uk_115 + 41344*uk_116 + 400520*uk_117 + 581400*uk_118 + 196384*uk_119 + 3854644*uk_12 + 8704*uk_120 + 84320*uk_121 + 122400*uk_122 + 41344*uk_123 + 816850*uk_124 + 1185750*uk_125 + 400520*uk_126 + 1721250*uk_127 + 581400*uk_128 + 196384*uk_129 + 811504*uk_13 + 438976*uk_130 + 92416*uk_131 + 895280*uk_132 + 1299600*uk_133 + 438976*uk_134 + 19456*uk_135 + 188480*uk_136 + 273600*uk_137 + 92416*uk_138 + 1825900*uk_139 + 7861445*uk_14 + 2650500*uk_140 + 895280*uk_141 + 3847500*uk_142 + 1299600*uk_143 + 438976*uk_144 + 4096*uk_145 + 39680*uk_146 + 57600*uk_147 + 19456*uk_148 + 384400*uk_149 + 11411775*uk_15 + 558000*uk_150 + 188480*uk_151 + 810000*uk_152 + 273600*uk_153 + 92416*uk_154 + 3723875*uk_155 + 5405625*uk_156 + 1825900*uk_157 + 7846875*uk_158 + 2650500*uk_159 + 3854644*uk_16 + 895280*uk_160 + 11390625*uk_161 + 3847500*uk_162 + 1299600*uk_163 + 438976*uk_164 + 3025*uk_17 + 1870*uk_18 + 4180*uk_19 + 55*uk_2 + 880*uk_20 + 8525*uk_21 + 12375*uk_22 + 4180*uk_23 + 1156*uk_24 + 2584*uk_25 + 544*uk_26 + 5270*uk_27 + 7650*uk_28 + 2584*uk_29 + 34*uk_3 + 5776*uk_30 + 1216*uk_31 + 11780*uk_32 + 17100*uk_33 + 5776*uk_34 + 256*uk_35 + 2480*uk_36 + 3600*uk_37 + 1216*uk_38 + 24025*uk_39 + 76*uk_4 + 34875*uk_40 + 11780*uk_41 + 50625*uk_42 + 17100*uk_43 + 5776*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 87462176674*uk_47 + 195503689036*uk_48 + 41158671376*uk_49 + 16*uk_5 + 398724628955*uk_50 + 578793816225*uk_51 + 195503689036*uk_52 + 153424975*uk_53 + 94844530*uk_54 + 212005420*uk_55 + 44632720*uk_56 + 432379475*uk_57 + 627647625*uk_58 + 212005420*uk_59 + 155*uk_6 + 58631164*uk_60 + 131057896*uk_61 + 27591136*uk_62 + 267289130*uk_63 + 388000350*uk_64 + 131057896*uk_65 + 292952944*uk_66 + 61674304*uk_67 + 597469820*uk_68 + 867294900*uk_69 + 225*uk_7 + 292952944*uk_70 + 12984064*uk_71 + 125783120*uk_72 + 182588400*uk_73 + 61674304*uk_74 + 1218523975*uk_75 + 1768825125*uk_76 + 597469820*uk_77 + 2567649375*uk_78 + 867294900*uk_79 + 76*uk_8 + 292952944*uk_80 + 166375*uk_81 + 102850*uk_82 + 229900*uk_83 + 48400*uk_84 + 468875*uk_85 + 680625*uk_86 + 229900*uk_87 + 63580*uk_88 + 142120*uk_89 + 2572416961*uk_9 + 29920*uk_90 + 289850*uk_91 + 420750*uk_92 + 142120*uk_93 + 317680*uk_94 + 66880*uk_95 + 647900*uk_96 + 940500*uk_97 + 317680*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 138160*uk_100 + 198000*uk_101 + 29920*uk_102 + 1355695*uk_103 + 1942875*uk_104 + 293590*uk_105 + 2784375*uk_106 + 420750*uk_107 + 63580*uk_108 + 512*uk_109 + 405752*uk_11 + 2176*uk_110 + 1024*uk_111 + 10048*uk_112 + 14400*uk_113 + 2176*uk_114 + 9248*uk_115 + 4352*uk_116 + 42704*uk_117 + 61200*uk_118 + 9248*uk_119 + 1724446*uk_12 + 2048*uk_120 + 20096*uk_121 + 28800*uk_122 + 4352*uk_123 + 197192*uk_124 + 282600*uk_125 + 42704*uk_126 + 405000*uk_127 + 61200*uk_128 + 9248*uk_129 + 811504*uk_13 + 39304*uk_130 + 18496*uk_131 + 181492*uk_132 + 260100*uk_133 + 39304*uk_134 + 8704*uk_135 + 85408*uk_136 + 122400*uk_137 + 18496*uk_138 + 838066*uk_139 + 7962883*uk_14 + 1201050*uk_140 + 181492*uk_141 + 1721250*uk_142 + 260100*uk_143 + 39304*uk_144 + 4096*uk_145 + 40192*uk_146 + 57600*uk_147 + 8704*uk_148 + 394384*uk_149 + 11411775*uk_15 + 565200*uk_150 + 85408*uk_151 + 810000*uk_152 + 122400*uk_153 + 18496*uk_154 + 3869893*uk_155 + 5546025*uk_156 + 838066*uk_157 + 7948125*uk_158 + 1201050*uk_159 + 1724446*uk_16 + 181492*uk_160 + 11390625*uk_161 + 1721250*uk_162 + 260100*uk_163 + 39304*uk_164 + 3025*uk_17 + 440*uk_18 + 1870*uk_19 + 55*uk_2 + 880*uk_20 + 8635*uk_21 + 12375*uk_22 + 1870*uk_23 + 64*uk_24 + 272*uk_25 + 128*uk_26 + 1256*uk_27 + 1800*uk_28 + 272*uk_29 + 8*uk_3 + 1156*uk_30 + 544*uk_31 + 5338*uk_32 + 7650*uk_33 + 1156*uk_34 + 256*uk_35 + 2512*uk_36 + 3600*uk_37 + 544*uk_38 + 24649*uk_39 + 34*uk_4 + 35325*uk_40 + 5338*uk_41 + 50625*uk_42 + 7650*uk_43 + 1156*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 20579335688*uk_47 + 87462176674*uk_48 + 41158671376*uk_49 + 16*uk_5 + 403869462877*uk_50 + 578793816225*uk_51 + 87462176674*uk_52 + 153424975*uk_53 + 22316360*uk_54 + 94844530*uk_55 + 44632720*uk_56 + 437958565*uk_57 + 627647625*uk_58 + 94844530*uk_59 + 157*uk_6 + 3246016*uk_60 + 13795568*uk_61 + 6492032*uk_62 + 63703064*uk_63 + 91294200*uk_64 + 13795568*uk_65 + 58631164*uk_66 + 27591136*uk_67 + 270738022*uk_68 + 388000350*uk_69 + 225*uk_7 + 58631164*uk_70 + 12984064*uk_71 + 127406128*uk_72 + 182588400*uk_73 + 27591136*uk_74 + 1250172631*uk_75 + 1791648675*uk_76 + 270738022*uk_77 + 2567649375*uk_78 + 388000350*uk_79 + 34*uk_8 + 58631164*uk_80 + 166375*uk_81 + 24200*uk_82 + 102850*uk_83 + 48400*uk_84 + 474925*uk_85 + 680625*uk_86 + 102850*uk_87 + 3520*uk_88 + 14960*uk_89 + 2572416961*uk_9 + 7040*uk_90 + 69080*uk_91 + 99000*uk_92 + 14960*uk_93 + 63580*uk_94 + 29920*uk_95 + 293590*uk_96 + 420750*uk_97 + 63580*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 174900*uk_100 + 247500*uk_101 + 8800*uk_102 + 1390455*uk_103 + 1967625*uk_104 + 69960*uk_105 + 2784375*uk_106 + 99000*uk_107 + 3520*uk_108 + 3869893*uk_109 + 7962883*uk_11 + 197192*uk_110 + 492980*uk_111 + 3919191*uk_112 + 5546025*uk_113 + 197192*uk_114 + 10048*uk_115 + 25120*uk_116 + 199704*uk_117 + 282600*uk_118 + 10048*uk_119 + 405752*uk_12 + 62800*uk_120 + 499260*uk_121 + 706500*uk_122 + 25120*uk_123 + 3969117*uk_124 + 5616675*uk_125 + 199704*uk_126 + 7948125*uk_127 + 282600*uk_128 + 10048*uk_129 + 1014380*uk_13 + 512*uk_130 + 1280*uk_131 + 10176*uk_132 + 14400*uk_133 + 512*uk_134 + 3200*uk_135 + 25440*uk_136 + 36000*uk_137 + 1280*uk_138 + 202248*uk_139 + 8064321*uk_14 + 286200*uk_140 + 10176*uk_141 + 405000*uk_142 + 14400*uk_143 + 512*uk_144 + 8000*uk_145 + 63600*uk_146 + 90000*uk_147 + 3200*uk_148 + 505620*uk_149 + 11411775*uk_15 + 715500*uk_150 + 25440*uk_151 + 1012500*uk_152 + 36000*uk_153 + 1280*uk_154 + 4019679*uk_155 + 5688225*uk_156 + 202248*uk_157 + 8049375*uk_158 + 286200*uk_159 + 405752*uk_16 + 10176*uk_160 + 11390625*uk_161 + 405000*uk_162 + 14400*uk_163 + 512*uk_164 + 3025*uk_17 + 8635*uk_18 + 440*uk_19 + 55*uk_2 + 1100*uk_20 + 8745*uk_21 + 12375*uk_22 + 440*uk_23 + 24649*uk_24 + 1256*uk_25 + 3140*uk_26 + 24963*uk_27 + 35325*uk_28 + 1256*uk_29 + 157*uk_3 + 64*uk_30 + 160*uk_31 + 1272*uk_32 + 1800*uk_33 + 64*uk_34 + 400*uk_35 + 3180*uk_36 + 4500*uk_37 + 160*uk_38 + 25281*uk_39 + 8*uk_4 + 35775*uk_40 + 1272*uk_41 + 50625*uk_42 + 1800*uk_43 + 64*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 403869462877*uk_47 + 20579335688*uk_48 + 51448339220*uk_49 + 20*uk_5 + 409014296799*uk_50 + 578793816225*uk_51 + 20579335688*uk_52 + 153424975*uk_53 + 437958565*uk_54 + 22316360*uk_55 + 55790900*uk_56 + 443537655*uk_57 + 627647625*uk_58 + 22316360*uk_59 + 159*uk_6 + 1250172631*uk_60 + 63703064*uk_61 + 159257660*uk_62 + 1266098397*uk_63 + 1791648675*uk_64 + 63703064*uk_65 + 3246016*uk_66 + 8115040*uk_67 + 64514568*uk_68 + 91294200*uk_69 + 225*uk_7 + 3246016*uk_70 + 20287600*uk_71 + 161286420*uk_72 + 228235500*uk_73 + 8115040*uk_74 + 1282227039*uk_75 + 1814472225*uk_76 + 64514568*uk_77 + 2567649375*uk_78 + 91294200*uk_79 + 8*uk_8 + 3246016*uk_80 + 166375*uk_81 + 474925*uk_82 + 24200*uk_83 + 60500*uk_84 + 480975*uk_85 + 680625*uk_86 + 24200*uk_87 + 1355695*uk_88 + 69080*uk_89 + 2572416961*uk_9 + 172700*uk_90 + 1372965*uk_91 + 1942875*uk_92 + 69080*uk_93 + 3520*uk_94 + 8800*uk_95 + 69960*uk_96 + 99000*uk_97 + 3520*uk_98 + 22000*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 106260*uk_100 + 148500*uk_101 + 103620*uk_102 + 1425655*uk_103 + 1992375*uk_104 + 1390235*uk_105 + 2784375*uk_106 + 1942875*uk_107 + 1355695*uk_108 + 64*uk_109 + 202876*uk_11 + 2512*uk_110 + 192*uk_111 + 2576*uk_112 + 3600*uk_113 + 2512*uk_114 + 98596*uk_115 + 7536*uk_116 + 101108*uk_117 + 141300*uk_118 + 98596*uk_119 + 7962883*uk_12 + 576*uk_120 + 7728*uk_121 + 10800*uk_122 + 7536*uk_123 + 103684*uk_124 + 144900*uk_125 + 101108*uk_126 + 202500*uk_127 + 141300*uk_128 + 98596*uk_129 + 608628*uk_13 + 3869893*uk_130 + 295788*uk_131 + 3968489*uk_132 + 5546025*uk_133 + 3869893*uk_134 + 22608*uk_135 + 303324*uk_136 + 423900*uk_137 + 295788*uk_138 + 4069597*uk_139 + 8165759*uk_14 + 5687325*uk_140 + 3968489*uk_141 + 7948125*uk_142 + 5546025*uk_143 + 3869893*uk_144 + 1728*uk_145 + 23184*uk_146 + 32400*uk_147 + 22608*uk_148 + 311052*uk_149 + 11411775*uk_15 + 434700*uk_150 + 303324*uk_151 + 607500*uk_152 + 423900*uk_153 + 295788*uk_154 + 4173281*uk_155 + 5832225*uk_156 + 4069597*uk_157 + 8150625*uk_158 + 5687325*uk_159 + 7962883*uk_16 + 3968489*uk_160 + 11390625*uk_161 + 7948125*uk_162 + 5546025*uk_163 + 3869893*uk_164 + 3025*uk_17 + 220*uk_18 + 8635*uk_19 + 55*uk_2 + 660*uk_20 + 8855*uk_21 + 12375*uk_22 + 8635*uk_23 + 16*uk_24 + 628*uk_25 + 48*uk_26 + 644*uk_27 + 900*uk_28 + 628*uk_29 + 4*uk_3 + 24649*uk_30 + 1884*uk_31 + 25277*uk_32 + 35325*uk_33 + 24649*uk_34 + 144*uk_35 + 1932*uk_36 + 2700*uk_37 + 1884*uk_38 + 25921*uk_39 + 157*uk_4 + 36225*uk_40 + 25277*uk_41 + 50625*uk_42 + 35325*uk_43 + 24649*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 10289667844*uk_47 + 403869462877*uk_48 + 30869003532*uk_49 + 12*uk_5 + 414159130721*uk_50 + 578793816225*uk_51 + 403869462877*uk_52 + 153424975*uk_53 + 11158180*uk_54 + 437958565*uk_55 + 33474540*uk_56 + 449116745*uk_57 + 627647625*uk_58 + 437958565*uk_59 + 161*uk_6 + 811504*uk_60 + 31851532*uk_61 + 2434512*uk_62 + 32663036*uk_63 + 45647100*uk_64 + 31851532*uk_65 + 1250172631*uk_66 + 95554596*uk_67 + 1282024163*uk_68 + 1791648675*uk_69 + 225*uk_7 + 1250172631*uk_70 + 7303536*uk_71 + 97989108*uk_72 + 136941300*uk_73 + 95554596*uk_74 + 1314687199*uk_75 + 1837295775*uk_76 + 1282024163*uk_77 + 2567649375*uk_78 + 1791648675*uk_79 + 157*uk_8 + 1250172631*uk_80 + 166375*uk_81 + 12100*uk_82 + 474925*uk_83 + 36300*uk_84 + 487025*uk_85 + 680625*uk_86 + 474925*uk_87 + 880*uk_88 + 34540*uk_89 + 2572416961*uk_9 + 2640*uk_90 + 35420*uk_91 + 49500*uk_92 + 34540*uk_93 + 1355695*uk_94 + 103620*uk_95 + 1390235*uk_96 + 1942875*uk_97 + 1355695*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 143440*uk_100 + 198000*uk_101 + 3520*uk_102 + 1461295*uk_103 + 2017125*uk_104 + 35860*uk_105 + 2784375*uk_106 + 49500*uk_107 + 880*uk_108 + 17576*uk_109 + 1318694*uk_11 + 2704*uk_110 + 10816*uk_111 + 110188*uk_112 + 152100*uk_113 + 2704*uk_114 + 416*uk_115 + 1664*uk_116 + 16952*uk_117 + 23400*uk_118 + 416*uk_119 + 202876*uk_12 + 6656*uk_120 + 67808*uk_121 + 93600*uk_122 + 1664*uk_123 + 690794*uk_124 + 953550*uk_125 + 16952*uk_126 + 1316250*uk_127 + 23400*uk_128 + 416*uk_129 + 811504*uk_13 + 64*uk_130 + 256*uk_131 + 2608*uk_132 + 3600*uk_133 + 64*uk_134 + 1024*uk_135 + 10432*uk_136 + 14400*uk_137 + 256*uk_138 + 106276*uk_139 + 8267197*uk_14 + 146700*uk_140 + 2608*uk_141 + 202500*uk_142 + 3600*uk_143 + 64*uk_144 + 4096*uk_145 + 41728*uk_146 + 57600*uk_147 + 1024*uk_148 + 425104*uk_149 + 11411775*uk_15 + 586800*uk_150 + 10432*uk_151 + 810000*uk_152 + 14400*uk_153 + 256*uk_154 + 4330747*uk_155 + 5978025*uk_156 + 106276*uk_157 + 8251875*uk_158 + 146700*uk_159 + 202876*uk_16 + 2608*uk_160 + 11390625*uk_161 + 202500*uk_162 + 3600*uk_163 + 64*uk_164 + 3025*uk_17 + 1430*uk_18 + 220*uk_19 + 55*uk_2 + 880*uk_20 + 8965*uk_21 + 12375*uk_22 + 220*uk_23 + 676*uk_24 + 104*uk_25 + 416*uk_26 + 4238*uk_27 + 5850*uk_28 + 104*uk_29 + 26*uk_3 + 16*uk_30 + 64*uk_31 + 652*uk_32 + 900*uk_33 + 16*uk_34 + 256*uk_35 + 2608*uk_36 + 3600*uk_37 + 64*uk_38 + 26569*uk_39 + 4*uk_4 + 36675*uk_40 + 652*uk_41 + 50625*uk_42 + 900*uk_43 + 16*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 66882840986*uk_47 + 10289667844*uk_48 + 41158671376*uk_49 + 16*uk_5 + 419303964643*uk_50 + 578793816225*uk_51 + 10289667844*uk_52 + 153424975*uk_53 + 72528170*uk_54 + 11158180*uk_55 + 44632720*uk_56 + 454695835*uk_57 + 627647625*uk_58 + 11158180*uk_59 + 163*uk_6 + 34286044*uk_60 + 5274776*uk_61 + 21099104*uk_62 + 214947122*uk_63 + 296706150*uk_64 + 5274776*uk_65 + 811504*uk_66 + 3246016*uk_67 + 33068788*uk_68 + 45647100*uk_69 + 225*uk_7 + 811504*uk_70 + 12984064*uk_71 + 132275152*uk_72 + 182588400*uk_73 + 3246016*uk_74 + 1347553111*uk_75 + 1860119325*uk_76 + 33068788*uk_77 + 2567649375*uk_78 + 45647100*uk_79 + 4*uk_8 + 811504*uk_80 + 166375*uk_81 + 78650*uk_82 + 12100*uk_83 + 48400*uk_84 + 493075*uk_85 + 680625*uk_86 + 12100*uk_87 + 37180*uk_88 + 5720*uk_89 + 2572416961*uk_9 + 22880*uk_90 + 233090*uk_91 + 321750*uk_92 + 5720*uk_93 + 880*uk_94 + 3520*uk_95 + 35860*uk_96 + 49500*uk_97 + 880*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 145200*uk_100 + 198000*uk_101 + 22880*uk_102 + 1497375*uk_103 + 2041875*uk_104 + 235950*uk_105 + 2784375*uk_106 + 321750*uk_107 + 37180*uk_108 + 262144*uk_109 + 3246016*uk_11 + 106496*uk_110 + 65536*uk_111 + 675840*uk_112 + 921600*uk_113 + 106496*uk_114 + 43264*uk_115 + 26624*uk_116 + 274560*uk_117 + 374400*uk_118 + 43264*uk_119 + 1318694*uk_12 + 16384*uk_120 + 168960*uk_121 + 230400*uk_122 + 26624*uk_123 + 1742400*uk_124 + 2376000*uk_125 + 274560*uk_126 + 3240000*uk_127 + 374400*uk_128 + 43264*uk_129 + 811504*uk_13 + 17576*uk_130 + 10816*uk_131 + 111540*uk_132 + 152100*uk_133 + 17576*uk_134 + 6656*uk_135 + 68640*uk_136 + 93600*uk_137 + 10816*uk_138 + 707850*uk_139 + 8368635*uk_14 + 965250*uk_140 + 111540*uk_141 + 1316250*uk_142 + 152100*uk_143 + 17576*uk_144 + 4096*uk_145 + 42240*uk_146 + 57600*uk_147 + 6656*uk_148 + 435600*uk_149 + 11411775*uk_15 + 594000*uk_150 + 68640*uk_151 + 810000*uk_152 + 93600*uk_153 + 10816*uk_154 + 4492125*uk_155 + 6125625*uk_156 + 707850*uk_157 + 8353125*uk_158 + 965250*uk_159 + 1318694*uk_16 + 111540*uk_160 + 11390625*uk_161 + 1316250*uk_162 + 152100*uk_163 + 17576*uk_164 + 3025*uk_17 + 3520*uk_18 + 1430*uk_19 + 55*uk_2 + 880*uk_20 + 9075*uk_21 + 12375*uk_22 + 1430*uk_23 + 4096*uk_24 + 1664*uk_25 + 1024*uk_26 + 10560*uk_27 + 14400*uk_28 + 1664*uk_29 + 64*uk_3 + 676*uk_30 + 416*uk_31 + 4290*uk_32 + 5850*uk_33 + 676*uk_34 + 256*uk_35 + 2640*uk_36 + 3600*uk_37 + 416*uk_38 + 27225*uk_39 + 26*uk_4 + 37125*uk_40 + 4290*uk_41 + 50625*uk_42 + 5850*uk_43 + 676*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 164634685504*uk_47 + 66882840986*uk_48 + 41158671376*uk_49 + 16*uk_5 + 424448798565*uk_50 + 578793816225*uk_51 + 66882840986*uk_52 + 153424975*uk_53 + 178530880*uk_54 + 72528170*uk_55 + 44632720*uk_56 + 460274925*uk_57 + 627647625*uk_58 + 72528170*uk_59 + 165*uk_6 + 207745024*uk_60 + 84396416*uk_61 + 51936256*uk_62 + 535592640*uk_63 + 730353600*uk_64 + 84396416*uk_65 + 34286044*uk_66 + 21099104*uk_67 + 217584510*uk_68 + 296706150*uk_69 + 225*uk_7 + 34286044*uk_70 + 12984064*uk_71 + 133898160*uk_72 + 182588400*uk_73 + 21099104*uk_74 + 1380824775*uk_75 + 1882942875*uk_76 + 217584510*uk_77 + 2567649375*uk_78 + 296706150*uk_79 + 26*uk_8 + 34286044*uk_80 + 166375*uk_81 + 193600*uk_82 + 78650*uk_83 + 48400*uk_84 + 499125*uk_85 + 680625*uk_86 + 78650*uk_87 + 225280*uk_88 + 91520*uk_89 + 2572416961*uk_9 + 56320*uk_90 + 580800*uk_91 + 792000*uk_92 + 91520*uk_93 + 37180*uk_94 + 22880*uk_95 + 235950*uk_96 + 321750*uk_97 + 37180*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 146960*uk_100 + 198000*uk_101 + 56320*uk_102 + 1533895*uk_103 + 2066625*uk_104 + 587840*uk_105 + 2784375*uk_106 + 792000*uk_107 + 225280*uk_108 + 1643032*uk_109 + 5984842*uk_11 + 891136*uk_110 + 222784*uk_111 + 2325308*uk_112 + 3132900*uk_113 + 891136*uk_114 + 483328*uk_115 + 120832*uk_116 + 1261184*uk_117 + 1699200*uk_118 + 483328*uk_119 + 3246016*uk_12 + 30208*uk_120 + 315296*uk_121 + 424800*uk_122 + 120832*uk_123 + 3290902*uk_124 + 4433850*uk_125 + 1261184*uk_126 + 5973750*uk_127 + 1699200*uk_128 + 483328*uk_129 + 811504*uk_13 + 262144*uk_130 + 65536*uk_131 + 684032*uk_132 + 921600*uk_133 + 262144*uk_134 + 16384*uk_135 + 171008*uk_136 + 230400*uk_137 + 65536*uk_138 + 1784896*uk_139 + 8470073*uk_14 + 2404800*uk_140 + 684032*uk_141 + 3240000*uk_142 + 921600*uk_143 + 262144*uk_144 + 4096*uk_145 + 42752*uk_146 + 57600*uk_147 + 16384*uk_148 + 446224*uk_149 + 11411775*uk_15 + 601200*uk_150 + 171008*uk_151 + 810000*uk_152 + 230400*uk_153 + 65536*uk_154 + 4657463*uk_155 + 6275025*uk_156 + 1784896*uk_157 + 8454375*uk_158 + 2404800*uk_159 + 3246016*uk_16 + 684032*uk_160 + 11390625*uk_161 + 3240000*uk_162 + 921600*uk_163 + 262144*uk_164 + 3025*uk_17 + 6490*uk_18 + 3520*uk_19 + 55*uk_2 + 880*uk_20 + 9185*uk_21 + 12375*uk_22 + 3520*uk_23 + 13924*uk_24 + 7552*uk_25 + 1888*uk_26 + 19706*uk_27 + 26550*uk_28 + 7552*uk_29 + 118*uk_3 + 4096*uk_30 + 1024*uk_31 + 10688*uk_32 + 14400*uk_33 + 4096*uk_34 + 256*uk_35 + 2672*uk_36 + 3600*uk_37 + 1024*uk_38 + 27889*uk_39 + 64*uk_4 + 37575*uk_40 + 10688*uk_41 + 50625*uk_42 + 14400*uk_43 + 4096*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 303545201398*uk_47 + 164634685504*uk_48 + 41158671376*uk_49 + 16*uk_5 + 429593632487*uk_50 + 578793816225*uk_51 + 164634685504*uk_52 + 153424975*uk_53 + 329166310*uk_54 + 178530880*uk_55 + 44632720*uk_56 + 465854015*uk_57 + 627647625*uk_58 + 178530880*uk_59 + 167*uk_6 + 706211356*uk_60 + 383029888*uk_61 + 95757472*uk_62 + 999468614*uk_63 + 1346589450*uk_64 + 383029888*uk_65 + 207745024*uk_66 + 51936256*uk_67 + 542084672*uk_68 + 730353600*uk_69 + 225*uk_7 + 207745024*uk_70 + 12984064*uk_71 + 135521168*uk_72 + 182588400*uk_73 + 51936256*uk_74 + 1414502191*uk_75 + 1905766425*uk_76 + 542084672*uk_77 + 2567649375*uk_78 + 730353600*uk_79 + 64*uk_8 + 207745024*uk_80 + 166375*uk_81 + 356950*uk_82 + 193600*uk_83 + 48400*uk_84 + 505175*uk_85 + 680625*uk_86 + 193600*uk_87 + 765820*uk_88 + 415360*uk_89 + 2572416961*uk_9 + 103840*uk_90 + 1083830*uk_91 + 1460250*uk_92 + 415360*uk_93 + 225280*uk_94 + 56320*uk_95 + 587840*uk_96 + 792000*uk_97 + 225280*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 111540*uk_100 + 148500*uk_101 + 77880*uk_102 + 1570855*uk_103 + 2091375*uk_104 + 1096810*uk_105 + 2784375*uk_106 + 1460250*uk_107 + 765820*uk_108 + 6859*uk_109 + 963661*uk_11 + 42598*uk_110 + 4332*uk_111 + 61009*uk_112 + 81225*uk_113 + 42598*uk_114 + 264556*uk_115 + 26904*uk_116 + 378898*uk_117 + 504450*uk_118 + 264556*uk_119 + 5984842*uk_12 + 2736*uk_120 + 38532*uk_121 + 51300*uk_122 + 26904*uk_123 + 542659*uk_124 + 722475*uk_125 + 378898*uk_126 + 961875*uk_127 + 504450*uk_128 + 264556*uk_129 + 608628*uk_13 + 1643032*uk_130 + 167088*uk_131 + 2353156*uk_132 + 3132900*uk_133 + 1643032*uk_134 + 16992*uk_135 + 239304*uk_136 + 318600*uk_137 + 167088*uk_138 + 3370198*uk_139 + 8571511*uk_14 + 4486950*uk_140 + 2353156*uk_141 + 5973750*uk_142 + 3132900*uk_143 + 1643032*uk_144 + 1728*uk_145 + 24336*uk_146 + 32400*uk_147 + 16992*uk_148 + 342732*uk_149 + 11411775*uk_15 + 456300*uk_150 + 239304*uk_151 + 607500*uk_152 + 318600*uk_153 + 167088*uk_154 + 4826809*uk_155 + 6426225*uk_156 + 3370198*uk_157 + 8555625*uk_158 + 4486950*uk_159 + 5984842*uk_16 + 2353156*uk_160 + 11390625*uk_161 + 5973750*uk_162 + 3132900*uk_163 + 1643032*uk_164 + 3025*uk_17 + 1045*uk_18 + 6490*uk_19 + 55*uk_2 + 660*uk_20 + 9295*uk_21 + 12375*uk_22 + 6490*uk_23 + 361*uk_24 + 2242*uk_25 + 228*uk_26 + 3211*uk_27 + 4275*uk_28 + 2242*uk_29 + 19*uk_3 + 13924*uk_30 + 1416*uk_31 + 19942*uk_32 + 26550*uk_33 + 13924*uk_34 + 144*uk_35 + 2028*uk_36 + 2700*uk_37 + 1416*uk_38 + 28561*uk_39 + 118*uk_4 + 38025*uk_40 + 19942*uk_41 + 50625*uk_42 + 26550*uk_43 + 13924*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 48875922259*uk_47 + 303545201398*uk_48 + 30869003532*uk_49 + 12*uk_5 + 434738466409*uk_50 + 578793816225*uk_51 + 303545201398*uk_52 + 153424975*uk_53 + 53001355*uk_54 + 329166310*uk_55 + 33474540*uk_56 + 471433105*uk_57 + 627647625*uk_58 + 329166310*uk_59 + 169*uk_6 + 18309559*uk_60 + 113711998*uk_61 + 11563932*uk_62 + 162858709*uk_63 + 216823725*uk_64 + 113711998*uk_65 + 706211356*uk_66 + 71818104*uk_67 + 1011438298*uk_68 + 1346589450*uk_69 + 225*uk_7 + 706211356*uk_70 + 7303536*uk_71 + 102858132*uk_72 + 136941300*uk_73 + 71818104*uk_74 + 1448585359*uk_75 + 1928589975*uk_76 + 1011438298*uk_77 + 2567649375*uk_78 + 1346589450*uk_79 + 118*uk_8 + 706211356*uk_80 + 166375*uk_81 + 57475*uk_82 + 356950*uk_83 + 36300*uk_84 + 511225*uk_85 + 680625*uk_86 + 356950*uk_87 + 19855*uk_88 + 123310*uk_89 + 2572416961*uk_9 + 12540*uk_90 + 176605*uk_91 + 235125*uk_92 + 123310*uk_93 + 765820*uk_94 + 77880*uk_95 + 1096810*uk_96 + 1460250*uk_97 + 765820*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 150480*uk_100 + 198000*uk_101 + 16720*uk_102 + 1608255*uk_103 + 2116125*uk_104 + 178695*uk_105 + 2784375*uk_106 + 235125*uk_107 + 19855*uk_108 + 1092727*uk_109 + 5224057*uk_11 + 201571*uk_110 + 169744*uk_111 + 1814139*uk_112 + 2387025*uk_113 + 201571*uk_114 + 37183*uk_115 + 31312*uk_116 + 334647*uk_117 + 440325*uk_118 + 37183*uk_119 + 963661*uk_12 + 26368*uk_120 + 281808*uk_121 + 370800*uk_122 + 31312*uk_123 + 3011823*uk_124 + 3962925*uk_125 + 334647*uk_126 + 5214375*uk_127 + 440325*uk_128 + 37183*uk_129 + 811504*uk_13 + 6859*uk_130 + 5776*uk_131 + 61731*uk_132 + 81225*uk_133 + 6859*uk_134 + 4864*uk_135 + 51984*uk_136 + 68400*uk_137 + 5776*uk_138 + 555579*uk_139 + 8672949*uk_14 + 731025*uk_140 + 61731*uk_141 + 961875*uk_142 + 81225*uk_143 + 6859*uk_144 + 4096*uk_145 + 43776*uk_146 + 57600*uk_147 + 4864*uk_148 + 467856*uk_149 + 11411775*uk_15 + 615600*uk_150 + 51984*uk_151 + 810000*uk_152 + 68400*uk_153 + 5776*uk_154 + 5000211*uk_155 + 6579225*uk_156 + 555579*uk_157 + 8656875*uk_158 + 731025*uk_159 + 963661*uk_16 + 61731*uk_160 + 11390625*uk_161 + 961875*uk_162 + 81225*uk_163 + 6859*uk_164 + 3025*uk_17 + 5665*uk_18 + 1045*uk_19 + 55*uk_2 + 880*uk_20 + 9405*uk_21 + 12375*uk_22 + 1045*uk_23 + 10609*uk_24 + 1957*uk_25 + 1648*uk_26 + 17613*uk_27 + 23175*uk_28 + 1957*uk_29 + 103*uk_3 + 361*uk_30 + 304*uk_31 + 3249*uk_32 + 4275*uk_33 + 361*uk_34 + 256*uk_35 + 2736*uk_36 + 3600*uk_37 + 304*uk_38 + 29241*uk_39 + 19*uk_4 + 38475*uk_40 + 3249*uk_41 + 50625*uk_42 + 4275*uk_43 + 361*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 264958946983*uk_47 + 48875922259*uk_48 + 41158671376*uk_49 + 16*uk_5 + 439883300331*uk_50 + 578793816225*uk_51 + 48875922259*uk_52 + 153424975*uk_53 + 287323135*uk_54 + 53001355*uk_55 + 44632720*uk_56 + 477012195*uk_57 + 627647625*uk_58 + 53001355*uk_59 + 171*uk_6 + 538077871*uk_60 + 99257083*uk_61 + 83584912*uk_62 + 893313747*uk_63 + 1175412825*uk_64 + 99257083*uk_65 + 18309559*uk_66 + 15418576*uk_67 + 164786031*uk_68 + 216823725*uk_69 + 225*uk_7 + 18309559*uk_70 + 12984064*uk_71 + 138767184*uk_72 + 182588400*uk_73 + 15418576*uk_74 + 1483074279*uk_75 + 1951413525*uk_76 + 164786031*uk_77 + 2567649375*uk_78 + 216823725*uk_79 + 19*uk_8 + 18309559*uk_80 + 166375*uk_81 + 311575*uk_82 + 57475*uk_83 + 48400*uk_84 + 517275*uk_85 + 680625*uk_86 + 57475*uk_87 + 583495*uk_88 + 107635*uk_89 + 2572416961*uk_9 + 90640*uk_90 + 968715*uk_91 + 1274625*uk_92 + 107635*uk_93 + 19855*uk_94 + 16720*uk_95 + 178695*uk_96 + 235125*uk_97 + 19855*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 114180*uk_100 + 148500*uk_101 + 67980*uk_102 + 1646095*uk_103 + 2140875*uk_104 + 980045*uk_105 + 2784375*uk_106 + 1274625*uk_107 + 583495*uk_108 + 27000*uk_109 + 1521570*uk_11 + 92700*uk_110 + 10800*uk_111 + 155700*uk_112 + 202500*uk_113 + 92700*uk_114 + 318270*uk_115 + 37080*uk_116 + 534570*uk_117 + 695250*uk_118 + 318270*uk_119 + 5224057*uk_12 + 4320*uk_120 + 62280*uk_121 + 81000*uk_122 + 37080*uk_123 + 897870*uk_124 + 1167750*uk_125 + 534570*uk_126 + 1518750*uk_127 + 695250*uk_128 + 318270*uk_129 + 608628*uk_13 + 1092727*uk_130 + 127308*uk_131 + 1835357*uk_132 + 2387025*uk_133 + 1092727*uk_134 + 14832*uk_135 + 213828*uk_136 + 278100*uk_137 + 127308*uk_138 + 3082687*uk_139 + 8774387*uk_14 + 4009275*uk_140 + 1835357*uk_141 + 5214375*uk_142 + 2387025*uk_143 + 1092727*uk_144 + 1728*uk_145 + 24912*uk_146 + 32400*uk_147 + 14832*uk_148 + 359148*uk_149 + 11411775*uk_15 + 467100*uk_150 + 213828*uk_151 + 607500*uk_152 + 278100*uk_153 + 127308*uk_154 + 5177717*uk_155 + 6734025*uk_156 + 3082687*uk_157 + 8758125*uk_158 + 4009275*uk_159 + 5224057*uk_16 + 1835357*uk_160 + 11390625*uk_161 + 5214375*uk_162 + 2387025*uk_163 + 1092727*uk_164 + 3025*uk_17 + 1650*uk_18 + 5665*uk_19 + 55*uk_2 + 660*uk_20 + 9515*uk_21 + 12375*uk_22 + 5665*uk_23 + 900*uk_24 + 3090*uk_25 + 360*uk_26 + 5190*uk_27 + 6750*uk_28 + 3090*uk_29 + 30*uk_3 + 10609*uk_30 + 1236*uk_31 + 17819*uk_32 + 23175*uk_33 + 10609*uk_34 + 144*uk_35 + 2076*uk_36 + 2700*uk_37 + 1236*uk_38 + 29929*uk_39 + 103*uk_4 + 38925*uk_40 + 17819*uk_41 + 50625*uk_42 + 23175*uk_43 + 10609*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 77172508830*uk_47 + 264958946983*uk_48 + 30869003532*uk_49 + 12*uk_5 + 445028134253*uk_50 + 578793816225*uk_51 + 264958946983*uk_52 + 153424975*uk_53 + 83686350*uk_54 + 287323135*uk_55 + 33474540*uk_56 + 482591285*uk_57 + 627647625*uk_58 + 287323135*uk_59 + 173*uk_6 + 45647100*uk_60 + 156721710*uk_61 + 18258840*uk_62 + 263231610*uk_63 + 342353250*uk_64 + 156721710*uk_65 + 538077871*uk_66 + 62688684*uk_67 + 903761861*uk_68 + 1175412825*uk_69 + 225*uk_7 + 538077871*uk_70 + 7303536*uk_71 + 105292644*uk_72 + 136941300*uk_73 + 62688684*uk_74 + 1517968951*uk_75 + 1974237075*uk_76 + 903761861*uk_77 + 2567649375*uk_78 + 1175412825*uk_79 + 103*uk_8 + 538077871*uk_80 + 166375*uk_81 + 90750*uk_82 + 311575*uk_83 + 36300*uk_84 + 523325*uk_85 + 680625*uk_86 + 311575*uk_87 + 49500*uk_88 + 169950*uk_89 + 2572416961*uk_9 + 19800*uk_90 + 285450*uk_91 + 371250*uk_92 + 169950*uk_93 + 583495*uk_94 + 67980*uk_95 + 980045*uk_96 + 1274625*uk_97 + 583495*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 154000*uk_100 + 198000*uk_101 + 26400*uk_102 + 1684375*uk_103 + 2165625*uk_104 + 288750*uk_105 + 2784375*uk_106 + 371250*uk_107 + 49500*uk_108 + 2985984*uk_109 + 7303536*uk_11 + 622080*uk_110 + 331776*uk_111 + 3628800*uk_112 + 4665600*uk_113 + 622080*uk_114 + 129600*uk_115 + 69120*uk_116 + 756000*uk_117 + 972000*uk_118 + 129600*uk_119 + 1521570*uk_12 + 36864*uk_120 + 403200*uk_121 + 518400*uk_122 + 69120*uk_123 + 4410000*uk_124 + 5670000*uk_125 + 756000*uk_126 + 7290000*uk_127 + 972000*uk_128 + 129600*uk_129 + 811504*uk_13 + 27000*uk_130 + 14400*uk_131 + 157500*uk_132 + 202500*uk_133 + 27000*uk_134 + 7680*uk_135 + 84000*uk_136 + 108000*uk_137 + 14400*uk_138 + 918750*uk_139 + 8875825*uk_14 + 1181250*uk_140 + 157500*uk_141 + 1518750*uk_142 + 202500*uk_143 + 27000*uk_144 + 4096*uk_145 + 44800*uk_146 + 57600*uk_147 + 7680*uk_148 + 490000*uk_149 + 11411775*uk_15 + 630000*uk_150 + 84000*uk_151 + 810000*uk_152 + 108000*uk_153 + 14400*uk_154 + 5359375*uk_155 + 6890625*uk_156 + 918750*uk_157 + 8859375*uk_158 + 1181250*uk_159 + 1521570*uk_16 + 157500*uk_160 + 11390625*uk_161 + 1518750*uk_162 + 202500*uk_163 + 27000*uk_164 + 3025*uk_17 + 7920*uk_18 + 1650*uk_19 + 55*uk_2 + 880*uk_20 + 9625*uk_21 + 12375*uk_22 + 1650*uk_23 + 20736*uk_24 + 4320*uk_25 + 2304*uk_26 + 25200*uk_27 + 32400*uk_28 + 4320*uk_29 + 144*uk_3 + 900*uk_30 + 480*uk_31 + 5250*uk_32 + 6750*uk_33 + 900*uk_34 + 256*uk_35 + 2800*uk_36 + 3600*uk_37 + 480*uk_38 + 30625*uk_39 + 30*uk_4 + 39375*uk_40 + 5250*uk_41 + 50625*uk_42 + 6750*uk_43 + 900*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 370428042384*uk_47 + 77172508830*uk_48 + 41158671376*uk_49 + 16*uk_5 + 450172968175*uk_50 + 578793816225*uk_51 + 77172508830*uk_52 + 153424975*uk_53 + 401694480*uk_54 + 83686350*uk_55 + 44632720*uk_56 + 488170375*uk_57 + 627647625*uk_58 + 83686350*uk_59 + 175*uk_6 + 1051709184*uk_60 + 219106080*uk_61 + 116856576*uk_62 + 1278118800*uk_63 + 1643295600*uk_64 + 219106080*uk_65 + 45647100*uk_66 + 24345120*uk_67 + 266274750*uk_68 + 342353250*uk_69 + 225*uk_7 + 45647100*uk_70 + 12984064*uk_71 + 142013200*uk_72 + 182588400*uk_73 + 24345120*uk_74 + 1553269375*uk_75 + 1997060625*uk_76 + 266274750*uk_77 + 2567649375*uk_78 + 342353250*uk_79 + 30*uk_8 + 45647100*uk_80 + 166375*uk_81 + 435600*uk_82 + 90750*uk_83 + 48400*uk_84 + 529375*uk_85 + 680625*uk_86 + 90750*uk_87 + 1140480*uk_88 + 237600*uk_89 + 2572416961*uk_9 + 126720*uk_90 + 1386000*uk_91 + 1782000*uk_92 + 237600*uk_93 + 49500*uk_94 + 26400*uk_95 + 288750*uk_96 + 371250*uk_97 + 49500*uk_98 + 14080*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 116820*uk_100 + 148500*uk_101 + 95040*uk_102 + 1723095*uk_103 + 2190375*uk_104 + 1401840*uk_105 + 2784375*uk_106 + 1782000*uk_107 + 1140480*uk_108 + 912673*uk_109 + 4919743*uk_11 + 1354896*uk_110 + 112908*uk_111 + 1665393*uk_112 + 2117025*uk_113 + 1354896*uk_114 + 2011392*uk_115 + 167616*uk_116 + 2472336*uk_117 + 3142800*uk_118 + 2011392*uk_119 + 7303536*uk_12 + 13968*uk_120 + 206028*uk_121 + 261900*uk_122 + 167616*uk_123 + 3038913*uk_124 + 3863025*uk_125 + 2472336*uk_126 + 4910625*uk_127 + 3142800*uk_128 + 2011392*uk_129 + 608628*uk_13 + 2985984*uk_130 + 248832*uk_131 + 3670272*uk_132 + 4665600*uk_133 + 2985984*uk_134 + 20736*uk_135 + 305856*uk_136 + 388800*uk_137 + 248832*uk_138 + 4511376*uk_139 + 8977263*uk_14 + 5734800*uk_140 + 3670272*uk_141 + 7290000*uk_142 + 4665600*uk_143 + 2985984*uk_144 + 1728*uk_145 + 25488*uk_146 + 32400*uk_147 + 20736*uk_148 + 375948*uk_149 + 11411775*uk_15 + 477900*uk_150 + 305856*uk_151 + 607500*uk_152 + 388800*uk_153 + 248832*uk_154 + 5545233*uk_155 + 7049025*uk_156 + 4511376*uk_157 + 8960625*uk_158 + 5734800*uk_159 + 7303536*uk_16 + 3670272*uk_160 + 11390625*uk_161 + 7290000*uk_162 + 4665600*uk_163 + 2985984*uk_164 + 3025*uk_17 + 5335*uk_18 + 7920*uk_19 + 55*uk_2 + 660*uk_20 + 9735*uk_21 + 12375*uk_22 + 7920*uk_23 + 9409*uk_24 + 13968*uk_25 + 1164*uk_26 + 17169*uk_27 + 21825*uk_28 + 13968*uk_29 + 97*uk_3 + 20736*uk_30 + 1728*uk_31 + 25488*uk_32 + 32400*uk_33 + 20736*uk_34 + 144*uk_35 + 2124*uk_36 + 2700*uk_37 + 1728*uk_38 + 31329*uk_39 + 144*uk_4 + 39825*uk_40 + 25488*uk_41 + 50625*uk_42 + 32400*uk_43 + 20736*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 249524445217*uk_47 + 370428042384*uk_48 + 30869003532*uk_49 + 12*uk_5 + 455317802097*uk_50 + 578793816225*uk_51 + 370428042384*uk_52 + 153424975*uk_53 + 270585865*uk_54 + 401694480*uk_55 + 33474540*uk_56 + 493749465*uk_57 + 627647625*uk_58 + 401694480*uk_59 + 177*uk_6 + 477215071*uk_60 + 708442992*uk_61 + 59036916*uk_62 + 870794511*uk_63 + 1106942175*uk_64 + 708442992*uk_65 + 1051709184*uk_66 + 87642432*uk_67 + 1292725872*uk_68 + 1643295600*uk_69 + 225*uk_7 + 1051709184*uk_70 + 7303536*uk_71 + 107727156*uk_72 + 136941300*uk_73 + 87642432*uk_74 + 1588975551*uk_75 + 2019884175*uk_76 + 1292725872*uk_77 + 2567649375*uk_78 + 1643295600*uk_79 + 144*uk_8 + 1051709184*uk_80 + 166375*uk_81 + 293425*uk_82 + 435600*uk_83 + 36300*uk_84 + 535425*uk_85 + 680625*uk_86 + 435600*uk_87 + 517495*uk_88 + 768240*uk_89 + 2572416961*uk_9 + 64020*uk_90 + 944295*uk_91 + 1200375*uk_92 + 768240*uk_93 + 1140480*uk_94 + 95040*uk_95 + 1401840*uk_96 + 1782000*uk_97 + 1140480*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 118140*uk_100 + 148500*uk_101 + 64020*uk_102 + 1762255*uk_103 + 2215125*uk_104 + 954965*uk_105 + 2784375*uk_106 + 1200375*uk_107 + 517495*uk_108 + 238328*uk_109 + 3144578*uk_11 + 372868*uk_110 + 46128*uk_111 + 688076*uk_112 + 864900*uk_113 + 372868*uk_114 + 583358*uk_115 + 72168*uk_116 + 1076506*uk_117 + 1353150*uk_118 + 583358*uk_119 + 4919743*uk_12 + 8928*uk_120 + 133176*uk_121 + 167400*uk_122 + 72168*uk_123 + 1986542*uk_124 + 2497050*uk_125 + 1076506*uk_126 + 3138750*uk_127 + 1353150*uk_128 + 583358*uk_129 + 608628*uk_13 + 912673*uk_130 + 112908*uk_131 + 1684211*uk_132 + 2117025*uk_133 + 912673*uk_134 + 13968*uk_135 + 208356*uk_136 + 261900*uk_137 + 112908*uk_138 + 3107977*uk_139 + 9078701*uk_14 + 3906675*uk_140 + 1684211*uk_141 + 4910625*uk_142 + 2117025*uk_143 + 912673*uk_144 + 1728*uk_145 + 25776*uk_146 + 32400*uk_147 + 13968*uk_148 + 384492*uk_149 + 11411775*uk_15 + 483300*uk_150 + 208356*uk_151 + 607500*uk_152 + 261900*uk_153 + 112908*uk_154 + 5735339*uk_155 + 7209225*uk_156 + 3107977*uk_157 + 9061875*uk_158 + 3906675*uk_159 + 4919743*uk_16 + 1684211*uk_160 + 11390625*uk_161 + 4910625*uk_162 + 2117025*uk_163 + 912673*uk_164 + 3025*uk_17 + 3410*uk_18 + 5335*uk_19 + 55*uk_2 + 660*uk_20 + 9845*uk_21 + 12375*uk_22 + 5335*uk_23 + 3844*uk_24 + 6014*uk_25 + 744*uk_26 + 11098*uk_27 + 13950*uk_28 + 6014*uk_29 + 62*uk_3 + 9409*uk_30 + 1164*uk_31 + 17363*uk_32 + 21825*uk_33 + 9409*uk_34 + 144*uk_35 + 2148*uk_36 + 2700*uk_37 + 1164*uk_38 + 32041*uk_39 + 97*uk_4 + 40275*uk_40 + 17363*uk_41 + 50625*uk_42 + 21825*uk_43 + 9409*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 159489851582*uk_47 + 249524445217*uk_48 + 30869003532*uk_49 + 12*uk_5 + 460462636019*uk_50 + 578793816225*uk_51 + 249524445217*uk_52 + 153424975*uk_53 + 172951790*uk_54 + 270585865*uk_55 + 33474540*uk_56 + 499328555*uk_57 + 627647625*uk_58 + 270585865*uk_59 + 179*uk_6 + 194963836*uk_60 + 305024066*uk_61 + 37734936*uk_62 + 562879462*uk_63 + 707530050*uk_64 + 305024066*uk_65 + 477215071*uk_66 + 59036916*uk_67 + 880633997*uk_68 + 1106942175*uk_69 + 225*uk_7 + 477215071*uk_70 + 7303536*uk_71 + 108944412*uk_72 + 136941300*uk_73 + 59036916*uk_74 + 1625087479*uk_75 + 2042707725*uk_76 + 880633997*uk_77 + 2567649375*uk_78 + 1106942175*uk_79 + 97*uk_8 + 477215071*uk_80 + 166375*uk_81 + 187550*uk_82 + 293425*uk_83 + 36300*uk_84 + 541475*uk_85 + 680625*uk_86 + 293425*uk_87 + 211420*uk_88 + 330770*uk_89 + 2572416961*uk_9 + 40920*uk_90 + 610390*uk_91 + 767250*uk_92 + 330770*uk_93 + 517495*uk_94 + 64020*uk_95 + 954965*uk_96 + 1200375*uk_97 + 517495*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 119460*uk_100 + 148500*uk_101 + 40920*uk_102 + 1801855*uk_103 + 2239875*uk_104 + 617210*uk_105 + 2784375*uk_106 + 767250*uk_107 + 211420*uk_108 + 59319*uk_109 + 1978041*uk_11 + 94302*uk_110 + 18252*uk_111 + 275301*uk_112 + 342225*uk_113 + 94302*uk_114 + 149916*uk_115 + 29016*uk_116 + 437658*uk_117 + 544050*uk_118 + 149916*uk_119 + 3144578*uk_12 + 5616*uk_120 + 84708*uk_121 + 105300*uk_122 + 29016*uk_123 + 1277679*uk_124 + 1588275*uk_125 + 437658*uk_126 + 1974375*uk_127 + 544050*uk_128 + 149916*uk_129 + 608628*uk_13 + 238328*uk_130 + 46128*uk_131 + 695764*uk_132 + 864900*uk_133 + 238328*uk_134 + 8928*uk_135 + 134664*uk_136 + 167400*uk_137 + 46128*uk_138 + 2031182*uk_139 + 9180139*uk_14 + 2524950*uk_140 + 695764*uk_141 + 3138750*uk_142 + 864900*uk_143 + 238328*uk_144 + 1728*uk_145 + 26064*uk_146 + 32400*uk_147 + 8928*uk_148 + 393132*uk_149 + 11411775*uk_15 + 488700*uk_150 + 134664*uk_151 + 607500*uk_152 + 167400*uk_153 + 46128*uk_154 + 5929741*uk_155 + 7371225*uk_156 + 2031182*uk_157 + 9163125*uk_158 + 2524950*uk_159 + 3144578*uk_16 + 695764*uk_160 + 11390625*uk_161 + 3138750*uk_162 + 864900*uk_163 + 238328*uk_164 + 3025*uk_17 + 2145*uk_18 + 3410*uk_19 + 55*uk_2 + 660*uk_20 + 9955*uk_21 + 12375*uk_22 + 3410*uk_23 + 1521*uk_24 + 2418*uk_25 + 468*uk_26 + 7059*uk_27 + 8775*uk_28 + 2418*uk_29 + 39*uk_3 + 3844*uk_30 + 744*uk_31 + 11222*uk_32 + 13950*uk_33 + 3844*uk_34 + 144*uk_35 + 2172*uk_36 + 2700*uk_37 + 744*uk_38 + 32761*uk_39 + 62*uk_4 + 40725*uk_40 + 11222*uk_41 + 50625*uk_42 + 13950*uk_43 + 3844*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 100324261479*uk_47 + 159489851582*uk_48 + 30869003532*uk_49 + 12*uk_5 + 465607469941*uk_50 + 578793816225*uk_51 + 159489851582*uk_52 + 153424975*uk_53 + 108792255*uk_54 + 172951790*uk_55 + 33474540*uk_56 + 504907645*uk_57 + 627647625*uk_58 + 172951790*uk_59 + 181*uk_6 + 77143599*uk_60 + 122638542*uk_61 + 23736492*uk_62 + 358025421*uk_63 + 445059225*uk_64 + 122638542*uk_65 + 194963836*uk_66 + 37734936*uk_67 + 569168618*uk_68 + 707530050*uk_69 + 225*uk_7 + 194963836*uk_70 + 7303536*uk_71 + 110161668*uk_72 + 136941300*uk_73 + 37734936*uk_74 + 1661605159*uk_75 + 2065531275*uk_76 + 569168618*uk_77 + 2567649375*uk_78 + 707530050*uk_79 + 62*uk_8 + 194963836*uk_80 + 166375*uk_81 + 117975*uk_82 + 187550*uk_83 + 36300*uk_84 + 547525*uk_85 + 680625*uk_86 + 187550*uk_87 + 83655*uk_88 + 132990*uk_89 + 2572416961*uk_9 + 25740*uk_90 + 388245*uk_91 + 482625*uk_92 + 132990*uk_93 + 211420*uk_94 + 40920*uk_95 + 617210*uk_96 + 767250*uk_97 + 211420*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 120780*uk_100 + 148500*uk_101 + 25740*uk_102 + 1841895*uk_103 + 2264625*uk_104 + 392535*uk_105 + 2784375*uk_106 + 482625*uk_107 + 83655*uk_108 + 21952*uk_109 + 1420132*uk_11 + 30576*uk_110 + 9408*uk_111 + 143472*uk_112 + 176400*uk_113 + 30576*uk_114 + 42588*uk_115 + 13104*uk_116 + 199836*uk_117 + 245700*uk_118 + 42588*uk_119 + 1978041*uk_12 + 4032*uk_120 + 61488*uk_121 + 75600*uk_122 + 13104*uk_123 + 937692*uk_124 + 1152900*uk_125 + 199836*uk_126 + 1417500*uk_127 + 245700*uk_128 + 42588*uk_129 + 608628*uk_13 + 59319*uk_130 + 18252*uk_131 + 278343*uk_132 + 342225*uk_133 + 59319*uk_134 + 5616*uk_135 + 85644*uk_136 + 105300*uk_137 + 18252*uk_138 + 1306071*uk_139 + 9281577*uk_14 + 1605825*uk_140 + 278343*uk_141 + 1974375*uk_142 + 342225*uk_143 + 59319*uk_144 + 1728*uk_145 + 26352*uk_146 + 32400*uk_147 + 5616*uk_148 + 401868*uk_149 + 11411775*uk_15 + 494100*uk_150 + 85644*uk_151 + 607500*uk_152 + 105300*uk_153 + 18252*uk_154 + 6128487*uk_155 + 7535025*uk_156 + 1306071*uk_157 + 9264375*uk_158 + 1605825*uk_159 + 1978041*uk_16 + 278343*uk_160 + 11390625*uk_161 + 1974375*uk_162 + 342225*uk_163 + 59319*uk_164 + 3025*uk_17 + 1540*uk_18 + 2145*uk_19 + 55*uk_2 + 660*uk_20 + 10065*uk_21 + 12375*uk_22 + 2145*uk_23 + 784*uk_24 + 1092*uk_25 + 336*uk_26 + 5124*uk_27 + 6300*uk_28 + 1092*uk_29 + 28*uk_3 + 1521*uk_30 + 468*uk_31 + 7137*uk_32 + 8775*uk_33 + 1521*uk_34 + 144*uk_35 + 2196*uk_36 + 2700*uk_37 + 468*uk_38 + 33489*uk_39 + 39*uk_4 + 41175*uk_40 + 7137*uk_41 + 50625*uk_42 + 8775*uk_43 + 1521*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 72027674908*uk_47 + 100324261479*uk_48 + 30869003532*uk_49 + 12*uk_5 + 470752303863*uk_50 + 578793816225*uk_51 + 100324261479*uk_52 + 153424975*uk_53 + 78107260*uk_54 + 108792255*uk_55 + 33474540*uk_56 + 510486735*uk_57 + 627647625*uk_58 + 108792255*uk_59 + 183*uk_6 + 39763696*uk_60 + 55385148*uk_61 + 17041584*uk_62 + 259884156*uk_63 + 319529700*uk_64 + 55385148*uk_65 + 77143599*uk_66 + 23736492*uk_67 + 361981503*uk_68 + 445059225*uk_69 + 225*uk_7 + 77143599*uk_70 + 7303536*uk_71 + 111378924*uk_72 + 136941300*uk_73 + 23736492*uk_74 + 1698528591*uk_75 + 2088354825*uk_76 + 361981503*uk_77 + 2567649375*uk_78 + 445059225*uk_79 + 39*uk_8 + 77143599*uk_80 + 166375*uk_81 + 84700*uk_82 + 117975*uk_83 + 36300*uk_84 + 553575*uk_85 + 680625*uk_86 + 117975*uk_87 + 43120*uk_88 + 60060*uk_89 + 2572416961*uk_9 + 18480*uk_90 + 281820*uk_91 + 346500*uk_92 + 60060*uk_93 + 83655*uk_94 + 25740*uk_95 + 392535*uk_96 + 482625*uk_97 + 83655*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 122100*uk_100 + 148500*uk_101 + 18480*uk_102 + 1882375*uk_103 + 2289375*uk_104 + 284900*uk_105 + 2784375*uk_106 + 346500*uk_107 + 43120*uk_108 + 24389*uk_109 + 1470851*uk_11 + 23548*uk_110 + 10092*uk_111 + 155585*uk_112 + 189225*uk_113 + 23548*uk_114 + 22736*uk_115 + 9744*uk_116 + 150220*uk_117 + 182700*uk_118 + 22736*uk_119 + 1420132*uk_12 + 4176*uk_120 + 64380*uk_121 + 78300*uk_122 + 9744*uk_123 + 992525*uk_124 + 1207125*uk_125 + 150220*uk_126 + 1468125*uk_127 + 182700*uk_128 + 22736*uk_129 + 608628*uk_13 + 21952*uk_130 + 9408*uk_131 + 145040*uk_132 + 176400*uk_133 + 21952*uk_134 + 4032*uk_135 + 62160*uk_136 + 75600*uk_137 + 9408*uk_138 + 958300*uk_139 + 9383015*uk_14 + 1165500*uk_140 + 145040*uk_141 + 1417500*uk_142 + 176400*uk_143 + 21952*uk_144 + 1728*uk_145 + 26640*uk_146 + 32400*uk_147 + 4032*uk_148 + 410700*uk_149 + 11411775*uk_15 + 499500*uk_150 + 62160*uk_151 + 607500*uk_152 + 75600*uk_153 + 9408*uk_154 + 6331625*uk_155 + 7700625*uk_156 + 958300*uk_157 + 9365625*uk_158 + 1165500*uk_159 + 1420132*uk_16 + 145040*uk_160 + 11390625*uk_161 + 1417500*uk_162 + 176400*uk_163 + 21952*uk_164 + 3025*uk_17 + 1595*uk_18 + 1540*uk_19 + 55*uk_2 + 660*uk_20 + 10175*uk_21 + 12375*uk_22 + 1540*uk_23 + 841*uk_24 + 812*uk_25 + 348*uk_26 + 5365*uk_27 + 6525*uk_28 + 812*uk_29 + 29*uk_3 + 784*uk_30 + 336*uk_31 + 5180*uk_32 + 6300*uk_33 + 784*uk_34 + 144*uk_35 + 2220*uk_36 + 2700*uk_37 + 336*uk_38 + 34225*uk_39 + 28*uk_4 + 41625*uk_40 + 5180*uk_41 + 50625*uk_42 + 6300*uk_43 + 784*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 74600091869*uk_47 + 72027674908*uk_48 + 30869003532*uk_49 + 12*uk_5 + 475897137785*uk_50 + 578793816225*uk_51 + 72027674908*uk_52 + 153424975*uk_53 + 80896805*uk_54 + 78107260*uk_55 + 33474540*uk_56 + 516065825*uk_57 + 627647625*uk_58 + 78107260*uk_59 + 185*uk_6 + 42654679*uk_60 + 41183828*uk_61 + 17650212*uk_62 + 272107435*uk_63 + 330941475*uk_64 + 41183828*uk_65 + 39763696*uk_66 + 17041584*uk_67 + 262724420*uk_68 + 319529700*uk_69 + 225*uk_7 + 39763696*uk_70 + 7303536*uk_71 + 112596180*uk_72 + 136941300*uk_73 + 17041584*uk_74 + 1735857775*uk_75 + 2111178375*uk_76 + 262724420*uk_77 + 2567649375*uk_78 + 319529700*uk_79 + 28*uk_8 + 39763696*uk_80 + 166375*uk_81 + 87725*uk_82 + 84700*uk_83 + 36300*uk_84 + 559625*uk_85 + 680625*uk_86 + 84700*uk_87 + 46255*uk_88 + 44660*uk_89 + 2572416961*uk_9 + 19140*uk_90 + 295075*uk_91 + 358875*uk_92 + 44660*uk_93 + 43120*uk_94 + 18480*uk_95 + 284900*uk_96 + 346500*uk_97 + 43120*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 123420*uk_100 + 148500*uk_101 + 19140*uk_102 + 1923295*uk_103 + 2314125*uk_104 + 298265*uk_105 + 2784375*uk_106 + 358875*uk_107 + 46255*uk_108 + 74088*uk_109 + 2130198*uk_11 + 51156*uk_110 + 21168*uk_111 + 329868*uk_112 + 396900*uk_113 + 51156*uk_114 + 35322*uk_115 + 14616*uk_116 + 227766*uk_117 + 274050*uk_118 + 35322*uk_119 + 1470851*uk_12 + 6048*uk_120 + 94248*uk_121 + 113400*uk_122 + 14616*uk_123 + 1468698*uk_124 + 1767150*uk_125 + 227766*uk_126 + 2126250*uk_127 + 274050*uk_128 + 35322*uk_129 + 608628*uk_13 + 24389*uk_130 + 10092*uk_131 + 157267*uk_132 + 189225*uk_133 + 24389*uk_134 + 4176*uk_135 + 65076*uk_136 + 78300*uk_137 + 10092*uk_138 + 1014101*uk_139 + 9484453*uk_14 + 1220175*uk_140 + 157267*uk_141 + 1468125*uk_142 + 189225*uk_143 + 24389*uk_144 + 1728*uk_145 + 26928*uk_146 + 32400*uk_147 + 4176*uk_148 + 419628*uk_149 + 11411775*uk_15 + 504900*uk_150 + 65076*uk_151 + 607500*uk_152 + 78300*uk_153 + 10092*uk_154 + 6539203*uk_155 + 7868025*uk_156 + 1014101*uk_157 + 9466875*uk_158 + 1220175*uk_159 + 1470851*uk_16 + 157267*uk_160 + 11390625*uk_161 + 1468125*uk_162 + 189225*uk_163 + 24389*uk_164 + 3025*uk_17 + 2310*uk_18 + 1595*uk_19 + 55*uk_2 + 660*uk_20 + 10285*uk_21 + 12375*uk_22 + 1595*uk_23 + 1764*uk_24 + 1218*uk_25 + 504*uk_26 + 7854*uk_27 + 9450*uk_28 + 1218*uk_29 + 42*uk_3 + 841*uk_30 + 348*uk_31 + 5423*uk_32 + 6525*uk_33 + 841*uk_34 + 144*uk_35 + 2244*uk_36 + 2700*uk_37 + 348*uk_38 + 34969*uk_39 + 29*uk_4 + 42075*uk_40 + 5423*uk_41 + 50625*uk_42 + 6525*uk_43 + 841*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 108041512362*uk_47 + 74600091869*uk_48 + 30869003532*uk_49 + 12*uk_5 + 481041971707*uk_50 + 578793816225*uk_51 + 74600091869*uk_52 + 153424975*uk_53 + 117160890*uk_54 + 80896805*uk_55 + 33474540*uk_56 + 521644915*uk_57 + 627647625*uk_58 + 80896805*uk_59 + 187*uk_6 + 89468316*uk_60 + 61775742*uk_61 + 25562376*uk_62 + 398347026*uk_63 + 479294550*uk_64 + 61775742*uk_65 + 42654679*uk_66 + 17650212*uk_67 + 275049137*uk_68 + 330941475*uk_69 + 225*uk_7 + 42654679*uk_70 + 7303536*uk_71 + 113813436*uk_72 + 136941300*uk_73 + 17650212*uk_74 + 1773592711*uk_75 + 2134001925*uk_76 + 275049137*uk_77 + 2567649375*uk_78 + 330941475*uk_79 + 29*uk_8 + 42654679*uk_80 + 166375*uk_81 + 127050*uk_82 + 87725*uk_83 + 36300*uk_84 + 565675*uk_85 + 680625*uk_86 + 87725*uk_87 + 97020*uk_88 + 66990*uk_89 + 2572416961*uk_9 + 27720*uk_90 + 431970*uk_91 + 519750*uk_92 + 66990*uk_93 + 46255*uk_94 + 19140*uk_95 + 298265*uk_96 + 358875*uk_97 + 46255*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 124740*uk_100 + 148500*uk_101 + 27720*uk_102 + 1964655*uk_103 + 2338875*uk_104 + 436590*uk_105 + 2784375*uk_106 + 519750*uk_107 + 97020*uk_108 + 300763*uk_109 + 3398173*uk_11 + 188538*uk_110 + 53868*uk_111 + 848421*uk_112 + 1010025*uk_113 + 188538*uk_114 + 118188*uk_115 + 33768*uk_116 + 531846*uk_117 + 633150*uk_118 + 118188*uk_119 + 2130198*uk_12 + 9648*uk_120 + 151956*uk_121 + 180900*uk_122 + 33768*uk_123 + 2393307*uk_124 + 2849175*uk_125 + 531846*uk_126 + 3391875*uk_127 + 633150*uk_128 + 118188*uk_129 + 608628*uk_13 + 74088*uk_130 + 21168*uk_131 + 333396*uk_132 + 396900*uk_133 + 74088*uk_134 + 6048*uk_135 + 95256*uk_136 + 113400*uk_137 + 21168*uk_138 + 1500282*uk_139 + 9585891*uk_14 + 1786050*uk_140 + 333396*uk_141 + 2126250*uk_142 + 396900*uk_143 + 74088*uk_144 + 1728*uk_145 + 27216*uk_146 + 32400*uk_147 + 6048*uk_148 + 428652*uk_149 + 11411775*uk_15 + 510300*uk_150 + 95256*uk_151 + 607500*uk_152 + 113400*uk_153 + 21168*uk_154 + 6751269*uk_155 + 8037225*uk_156 + 1500282*uk_157 + 9568125*uk_158 + 1786050*uk_159 + 2130198*uk_16 + 333396*uk_160 + 11390625*uk_161 + 2126250*uk_162 + 396900*uk_163 + 74088*uk_164 + 3025*uk_17 + 3685*uk_18 + 2310*uk_19 + 55*uk_2 + 660*uk_20 + 10395*uk_21 + 12375*uk_22 + 2310*uk_23 + 4489*uk_24 + 2814*uk_25 + 804*uk_26 + 12663*uk_27 + 15075*uk_28 + 2814*uk_29 + 67*uk_3 + 1764*uk_30 + 504*uk_31 + 7938*uk_32 + 9450*uk_33 + 1764*uk_34 + 144*uk_35 + 2268*uk_36 + 2700*uk_37 + 504*uk_38 + 35721*uk_39 + 42*uk_4 + 42525*uk_40 + 7938*uk_41 + 50625*uk_42 + 9450*uk_43 + 1764*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 172351936387*uk_47 + 108041512362*uk_48 + 30869003532*uk_49 + 12*uk_5 + 486186805629*uk_50 + 578793816225*uk_51 + 108041512362*uk_52 + 153424975*uk_53 + 186899515*uk_54 + 117160890*uk_55 + 33474540*uk_56 + 527224005*uk_57 + 627647625*uk_58 + 117160890*uk_59 + 189*uk_6 + 227677591*uk_60 + 142723266*uk_61 + 40778076*uk_62 + 642254697*uk_63 + 764588925*uk_64 + 142723266*uk_65 + 89468316*uk_66 + 25562376*uk_67 + 402607422*uk_68 + 479294550*uk_69 + 225*uk_7 + 89468316*uk_70 + 7303536*uk_71 + 115030692*uk_72 + 136941300*uk_73 + 25562376*uk_74 + 1811733399*uk_75 + 2156825475*uk_76 + 402607422*uk_77 + 2567649375*uk_78 + 479294550*uk_79 + 42*uk_8 + 89468316*uk_80 + 166375*uk_81 + 202675*uk_82 + 127050*uk_83 + 36300*uk_84 + 571725*uk_85 + 680625*uk_86 + 127050*uk_87 + 246895*uk_88 + 154770*uk_89 + 2572416961*uk_9 + 44220*uk_90 + 696465*uk_91 + 829125*uk_92 + 154770*uk_93 + 97020*uk_94 + 27720*uk_95 + 436590*uk_96 + 519750*uk_97 + 97020*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 126060*uk_100 + 148500*uk_101 + 44220*uk_102 + 2006455*uk_103 + 2363625*uk_104 + 703835*uk_105 + 2784375*uk_106 + 829125*uk_107 + 246895*uk_108 + 1124864*uk_109 + 5274776*uk_11 + 724672*uk_110 + 129792*uk_111 + 2065856*uk_112 + 2433600*uk_113 + 724672*uk_114 + 466856*uk_115 + 83616*uk_116 + 1330888*uk_117 + 1567800*uk_118 + 466856*uk_119 + 3398173*uk_12 + 14976*uk_120 + 238368*uk_121 + 280800*uk_122 + 83616*uk_123 + 3794024*uk_124 + 4469400*uk_125 + 1330888*uk_126 + 5265000*uk_127 + 1567800*uk_128 + 466856*uk_129 + 608628*uk_13 + 300763*uk_130 + 53868*uk_131 + 857399*uk_132 + 1010025*uk_133 + 300763*uk_134 + 9648*uk_135 + 153564*uk_136 + 180900*uk_137 + 53868*uk_138 + 2444227*uk_139 + 9687329*uk_14 + 2879325*uk_140 + 857399*uk_141 + 3391875*uk_142 + 1010025*uk_143 + 300763*uk_144 + 1728*uk_145 + 27504*uk_146 + 32400*uk_147 + 9648*uk_148 + 437772*uk_149 + 11411775*uk_15 + 515700*uk_150 + 153564*uk_151 + 607500*uk_152 + 180900*uk_153 + 53868*uk_154 + 6967871*uk_155 + 8208225*uk_156 + 2444227*uk_157 + 9669375*uk_158 + 2879325*uk_159 + 3398173*uk_16 + 857399*uk_160 + 11390625*uk_161 + 3391875*uk_162 + 1010025*uk_163 + 300763*uk_164 + 3025*uk_17 + 5720*uk_18 + 3685*uk_19 + 55*uk_2 + 660*uk_20 + 10505*uk_21 + 12375*uk_22 + 3685*uk_23 + 10816*uk_24 + 6968*uk_25 + 1248*uk_26 + 19864*uk_27 + 23400*uk_28 + 6968*uk_29 + 104*uk_3 + 4489*uk_30 + 804*uk_31 + 12797*uk_32 + 15075*uk_33 + 4489*uk_34 + 144*uk_35 + 2292*uk_36 + 2700*uk_37 + 804*uk_38 + 36481*uk_39 + 67*uk_4 + 42975*uk_40 + 12797*uk_41 + 50625*uk_42 + 15075*uk_43 + 4489*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 267531363944*uk_47 + 172351936387*uk_48 + 30869003532*uk_49 + 12*uk_5 + 491331639551*uk_50 + 578793816225*uk_51 + 172351936387*uk_52 + 153424975*uk_53 + 290112680*uk_54 + 186899515*uk_55 + 33474540*uk_56 + 532803095*uk_57 + 627647625*uk_58 + 186899515*uk_59 + 191*uk_6 + 548576704*uk_60 + 353409992*uk_61 + 63297312*uk_62 + 1007482216*uk_63 + 1186824600*uk_64 + 353409992*uk_65 + 227677591*uk_66 + 40778076*uk_67 + 649051043*uk_68 + 764588925*uk_69 + 225*uk_7 + 227677591*uk_70 + 7303536*uk_71 + 116247948*uk_72 + 136941300*uk_73 + 40778076*uk_74 + 1850279839*uk_75 + 2179649025*uk_76 + 649051043*uk_77 + 2567649375*uk_78 + 764588925*uk_79 + 67*uk_8 + 227677591*uk_80 + 166375*uk_81 + 314600*uk_82 + 202675*uk_83 + 36300*uk_84 + 577775*uk_85 + 680625*uk_86 + 202675*uk_87 + 594880*uk_88 + 383240*uk_89 + 2572416961*uk_9 + 68640*uk_90 + 1092520*uk_91 + 1287000*uk_92 + 383240*uk_93 + 246895*uk_94 + 44220*uk_95 + 703835*uk_96 + 829125*uk_97 + 246895*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 127380*uk_100 + 148500*uk_101 + 68640*uk_102 + 2048695*uk_103 + 2388375*uk_104 + 1103960*uk_105 + 2784375*uk_106 + 1287000*uk_107 + 594880*uk_108 + 3581577*uk_109 + 7760007*uk_11 + 2434536*uk_110 + 280908*uk_111 + 4517937*uk_112 + 5267025*uk_113 + 2434536*uk_114 + 1654848*uk_115 + 190944*uk_116 + 3071016*uk_117 + 3580200*uk_118 + 1654848*uk_119 + 5274776*uk_12 + 22032*uk_120 + 354348*uk_121 + 413100*uk_122 + 190944*uk_123 + 5699097*uk_124 + 6644025*uk_125 + 3071016*uk_126 + 7745625*uk_127 + 3580200*uk_128 + 1654848*uk_129 + 608628*uk_13 + 1124864*uk_130 + 129792*uk_131 + 2087488*uk_132 + 2433600*uk_133 + 1124864*uk_134 + 14976*uk_135 + 240864*uk_136 + 280800*uk_137 + 129792*uk_138 + 3873896*uk_139 + 9788767*uk_14 + 4516200*uk_140 + 2087488*uk_141 + 5265000*uk_142 + 2433600*uk_143 + 1124864*uk_144 + 1728*uk_145 + 27792*uk_146 + 32400*uk_147 + 14976*uk_148 + 446988*uk_149 + 11411775*uk_15 + 521100*uk_150 + 240864*uk_151 + 607500*uk_152 + 280800*uk_153 + 129792*uk_154 + 7189057*uk_155 + 8381025*uk_156 + 3873896*uk_157 + 9770625*uk_158 + 4516200*uk_159 + 5274776*uk_16 + 2087488*uk_160 + 11390625*uk_161 + 5265000*uk_162 + 2433600*uk_163 + 1124864*uk_164 + 3025*uk_17 + 8415*uk_18 + 5720*uk_19 + 55*uk_2 + 660*uk_20 + 10615*uk_21 + 12375*uk_22 + 5720*uk_23 + 23409*uk_24 + 15912*uk_25 + 1836*uk_26 + 29529*uk_27 + 34425*uk_28 + 15912*uk_29 + 153*uk_3 + 10816*uk_30 + 1248*uk_31 + 20072*uk_32 + 23400*uk_33 + 10816*uk_34 + 144*uk_35 + 2316*uk_36 + 2700*uk_37 + 1248*uk_38 + 37249*uk_39 + 104*uk_4 + 43425*uk_40 + 20072*uk_41 + 50625*uk_42 + 23400*uk_43 + 10816*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 393579795033*uk_47 + 267531363944*uk_48 + 30869003532*uk_49 + 12*uk_5 + 496476473473*uk_50 + 578793816225*uk_51 + 267531363944*uk_52 + 153424975*uk_53 + 426800385*uk_54 + 290112680*uk_55 + 33474540*uk_56 + 538382185*uk_57 + 627647625*uk_58 + 290112680*uk_59 + 193*uk_6 + 1187281071*uk_60 + 807040728*uk_61 + 93120084*uk_62 + 1497681351*uk_63 + 1746001575*uk_64 + 807040728*uk_65 + 548576704*uk_66 + 63297312*uk_67 + 1018031768*uk_68 + 1186824600*uk_69 + 225*uk_7 + 548576704*uk_70 + 7303536*uk_71 + 117465204*uk_72 + 136941300*uk_73 + 63297312*uk_74 + 1889232031*uk_75 + 2202472575*uk_76 + 1018031768*uk_77 + 2567649375*uk_78 + 1186824600*uk_79 + 104*uk_8 + 548576704*uk_80 + 166375*uk_81 + 462825*uk_82 + 314600*uk_83 + 36300*uk_84 + 583825*uk_85 + 680625*uk_86 + 314600*uk_87 + 1287495*uk_88 + 875160*uk_89 + 2572416961*uk_9 + 100980*uk_90 + 1624095*uk_91 + 1893375*uk_92 + 875160*uk_93 + 594880*uk_94 + 68640*uk_95 + 1103960*uk_96 + 1287000*uk_97 + 594880*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 85800*uk_100 + 99000*uk_101 + 67320*uk_102 + 2091375*uk_103 + 2413125*uk_104 + 1640925*uk_105 + 2784375*uk_106 + 1893375*uk_107 + 1287495*uk_108 + 6859*uk_109 + 963661*uk_11 + 55233*uk_110 + 2888*uk_111 + 70395*uk_112 + 81225*uk_113 + 55233*uk_114 + 444771*uk_115 + 23256*uk_116 + 566865*uk_117 + 654075*uk_118 + 444771*uk_119 + 7760007*uk_12 + 1216*uk_120 + 29640*uk_121 + 34200*uk_122 + 23256*uk_123 + 722475*uk_124 + 833625*uk_125 + 566865*uk_126 + 961875*uk_127 + 654075*uk_128 + 444771*uk_129 + 405752*uk_13 + 3581577*uk_130 + 187272*uk_131 + 4564755*uk_132 + 5267025*uk_133 + 3581577*uk_134 + 9792*uk_135 + 238680*uk_136 + 275400*uk_137 + 187272*uk_138 + 5817825*uk_139 + 9890205*uk_14 + 6712875*uk_140 + 4564755*uk_141 + 7745625*uk_142 + 5267025*uk_143 + 3581577*uk_144 + 512*uk_145 + 12480*uk_146 + 14400*uk_147 + 9792*uk_148 + 304200*uk_149 + 11411775*uk_15 + 351000*uk_150 + 238680*uk_151 + 405000*uk_152 + 275400*uk_153 + 187272*uk_154 + 7414875*uk_155 + 8555625*uk_156 + 5817825*uk_157 + 9871875*uk_158 + 6712875*uk_159 + 7760007*uk_16 + 4564755*uk_160 + 11390625*uk_161 + 7745625*uk_162 + 5267025*uk_163 + 3581577*uk_164 + 3025*uk_17 + 1045*uk_18 + 8415*uk_19 + 55*uk_2 + 440*uk_20 + 10725*uk_21 + 12375*uk_22 + 8415*uk_23 + 361*uk_24 + 2907*uk_25 + 152*uk_26 + 3705*uk_27 + 4275*uk_28 + 2907*uk_29 + 19*uk_3 + 23409*uk_30 + 1224*uk_31 + 29835*uk_32 + 34425*uk_33 + 23409*uk_34 + 64*uk_35 + 1560*uk_36 + 1800*uk_37 + 1224*uk_38 + 38025*uk_39 + 153*uk_4 + 43875*uk_40 + 29835*uk_41 + 50625*uk_42 + 34425*uk_43 + 23409*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 48875922259*uk_47 + 393579795033*uk_48 + 20579335688*uk_49 + 8*uk_5 + 501621307395*uk_50 + 578793816225*uk_51 + 393579795033*uk_52 + 153424975*uk_53 + 53001355*uk_54 + 426800385*uk_55 + 22316360*uk_56 + 543961275*uk_57 + 627647625*uk_58 + 426800385*uk_59 + 195*uk_6 + 18309559*uk_60 + 147440133*uk_61 + 7709288*uk_62 + 187913895*uk_63 + 216823725*uk_64 + 147440133*uk_65 + 1187281071*uk_66 + 62080056*uk_67 + 1513201365*uk_68 + 1746001575*uk_69 + 225*uk_7 + 1187281071*uk_70 + 3246016*uk_71 + 79121640*uk_72 + 91294200*uk_73 + 62080056*uk_74 + 1928589975*uk_75 + 2225296125*uk_76 + 1513201365*uk_77 + 2567649375*uk_78 + 1746001575*uk_79 + 153*uk_8 + 1187281071*uk_80 + 166375*uk_81 + 57475*uk_82 + 462825*uk_83 + 24200*uk_84 + 589875*uk_85 + 680625*uk_86 + 462825*uk_87 + 19855*uk_88 + 159885*uk_89 + 2572416961*uk_9 + 8360*uk_90 + 203775*uk_91 + 235125*uk_92 + 159885*uk_93 + 1287495*uk_94 + 67320*uk_95 + 1640925*uk_96 + 1893375*uk_97 + 1287495*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 130020*uk_100 + 148500*uk_101 + 12540*uk_102 + 2134495*uk_103 + 2437875*uk_104 + 205865*uk_105 + 2784375*uk_106 + 235125*uk_107 + 19855*uk_108 + 729000*uk_109 + 4564710*uk_11 + 153900*uk_110 + 97200*uk_111 + 1595700*uk_112 + 1822500*uk_113 + 153900*uk_114 + 32490*uk_115 + 20520*uk_116 + 336870*uk_117 + 384750*uk_118 + 32490*uk_119 + 963661*uk_12 + 12960*uk_120 + 212760*uk_121 + 243000*uk_122 + 20520*uk_123 + 3492810*uk_124 + 3989250*uk_125 + 336870*uk_126 + 4556250*uk_127 + 384750*uk_128 + 32490*uk_129 + 608628*uk_13 + 6859*uk_130 + 4332*uk_131 + 71117*uk_132 + 81225*uk_133 + 6859*uk_134 + 2736*uk_135 + 44916*uk_136 + 51300*uk_137 + 4332*uk_138 + 737371*uk_139 + 9991643*uk_14 + 842175*uk_140 + 71117*uk_141 + 961875*uk_142 + 81225*uk_143 + 6859*uk_144 + 1728*uk_145 + 28368*uk_146 + 32400*uk_147 + 2736*uk_148 + 465708*uk_149 + 11411775*uk_15 + 531900*uk_150 + 44916*uk_151 + 607500*uk_152 + 51300*uk_153 + 4332*uk_154 + 7645373*uk_155 + 8732025*uk_156 + 737371*uk_157 + 9973125*uk_158 + 842175*uk_159 + 963661*uk_16 + 71117*uk_160 + 11390625*uk_161 + 961875*uk_162 + 81225*uk_163 + 6859*uk_164 + 3025*uk_17 + 4950*uk_18 + 1045*uk_19 + 55*uk_2 + 660*uk_20 + 10835*uk_21 + 12375*uk_22 + 1045*uk_23 + 8100*uk_24 + 1710*uk_25 + 1080*uk_26 + 17730*uk_27 + 20250*uk_28 + 1710*uk_29 + 90*uk_3 + 361*uk_30 + 228*uk_31 + 3743*uk_32 + 4275*uk_33 + 361*uk_34 + 144*uk_35 + 2364*uk_36 + 2700*uk_37 + 228*uk_38 + 38809*uk_39 + 19*uk_4 + 44325*uk_40 + 3743*uk_41 + 50625*uk_42 + 4275*uk_43 + 361*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 231517526490*uk_47 + 48875922259*uk_48 + 30869003532*uk_49 + 12*uk_5 + 506766141317*uk_50 + 578793816225*uk_51 + 48875922259*uk_52 + 153424975*uk_53 + 251059050*uk_54 + 53001355*uk_55 + 33474540*uk_56 + 549540365*uk_57 + 627647625*uk_58 + 53001355*uk_59 + 197*uk_6 + 410823900*uk_60 + 86729490*uk_61 + 54776520*uk_62 + 899247870*uk_63 + 1027059750*uk_64 + 86729490*uk_65 + 18309559*uk_66 + 11563932*uk_67 + 189841217*uk_68 + 216823725*uk_69 + 225*uk_7 + 18309559*uk_70 + 7303536*uk_71 + 119899716*uk_72 + 136941300*uk_73 + 11563932*uk_74 + 1968353671*uk_75 + 2248119675*uk_76 + 189841217*uk_77 + 2567649375*uk_78 + 216823725*uk_79 + 19*uk_8 + 18309559*uk_80 + 166375*uk_81 + 272250*uk_82 + 57475*uk_83 + 36300*uk_84 + 595925*uk_85 + 680625*uk_86 + 57475*uk_87 + 445500*uk_88 + 94050*uk_89 + 2572416961*uk_9 + 59400*uk_90 + 975150*uk_91 + 1113750*uk_92 + 94050*uk_93 + 19855*uk_94 + 12540*uk_95 + 205865*uk_96 + 235125*uk_97 + 19855*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 131340*uk_100 + 148500*uk_101 + 59400*uk_102 + 2178055*uk_103 + 2462625*uk_104 + 985050*uk_105 + 2784375*uk_106 + 1113750*uk_107 + 445500*uk_108 + 5177717*uk_109 + 8774387*uk_11 + 2693610*uk_110 + 359148*uk_111 + 5955871*uk_112 + 6734025*uk_113 + 2693610*uk_114 + 1401300*uk_115 + 186840*uk_116 + 3098430*uk_117 + 3503250*uk_118 + 1401300*uk_119 + 4564710*uk_12 + 24912*uk_120 + 413124*uk_121 + 467100*uk_122 + 186840*uk_123 + 6850973*uk_124 + 7746075*uk_125 + 3098430*uk_126 + 8758125*uk_127 + 3503250*uk_128 + 1401300*uk_129 + 608628*uk_13 + 729000*uk_130 + 97200*uk_131 + 1611900*uk_132 + 1822500*uk_133 + 729000*uk_134 + 12960*uk_135 + 214920*uk_136 + 243000*uk_137 + 97200*uk_138 + 3564090*uk_139 + 10093081*uk_14 + 4029750*uk_140 + 1611900*uk_141 + 4556250*uk_142 + 1822500*uk_143 + 729000*uk_144 + 1728*uk_145 + 28656*uk_146 + 32400*uk_147 + 12960*uk_148 + 475212*uk_149 + 11411775*uk_15 + 537300*uk_150 + 214920*uk_151 + 607500*uk_152 + 243000*uk_153 + 97200*uk_154 + 7880599*uk_155 + 8910225*uk_156 + 3564090*uk_157 + 10074375*uk_158 + 4029750*uk_159 + 4564710*uk_16 + 1611900*uk_160 + 11390625*uk_161 + 4556250*uk_162 + 1822500*uk_163 + 729000*uk_164 + 3025*uk_17 + 9515*uk_18 + 4950*uk_19 + 55*uk_2 + 660*uk_20 + 10945*uk_21 + 12375*uk_22 + 4950*uk_23 + 29929*uk_24 + 15570*uk_25 + 2076*uk_26 + 34427*uk_27 + 38925*uk_28 + 15570*uk_29 + 173*uk_3 + 8100*uk_30 + 1080*uk_31 + 17910*uk_32 + 20250*uk_33 + 8100*uk_34 + 144*uk_35 + 2388*uk_36 + 2700*uk_37 + 1080*uk_38 + 39601*uk_39 + 90*uk_4 + 44775*uk_40 + 17910*uk_41 + 50625*uk_42 + 20250*uk_43 + 8100*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 445028134253*uk_47 + 231517526490*uk_48 + 30869003532*uk_49 + 12*uk_5 + 511910975239*uk_50 + 578793816225*uk_51 + 231517526490*uk_52 + 153424975*uk_53 + 482591285*uk_54 + 251059050*uk_55 + 33474540*uk_56 + 555119455*uk_57 + 627647625*uk_58 + 251059050*uk_59 + 199*uk_6 + 1517968951*uk_60 + 789694830*uk_61 + 105292644*uk_62 + 1746103013*uk_63 + 1974237075*uk_64 + 789694830*uk_65 + 410823900*uk_66 + 54776520*uk_67 + 908377290*uk_68 + 1027059750*uk_69 + 225*uk_7 + 410823900*uk_70 + 7303536*uk_71 + 121116972*uk_72 + 136941300*uk_73 + 54776520*uk_74 + 2008523119*uk_75 + 2270943225*uk_76 + 908377290*uk_77 + 2567649375*uk_78 + 1027059750*uk_79 + 90*uk_8 + 410823900*uk_80 + 166375*uk_81 + 523325*uk_82 + 272250*uk_83 + 36300*uk_84 + 601975*uk_85 + 680625*uk_86 + 272250*uk_87 + 1646095*uk_88 + 856350*uk_89 + 2572416961*uk_9 + 114180*uk_90 + 1893485*uk_91 + 2140875*uk_92 + 856350*uk_93 + 445500*uk_94 + 59400*uk_95 + 985050*uk_96 + 1113750*uk_97 + 445500*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 88440*uk_100 + 99000*uk_101 + 76120*uk_102 + 2222055*uk_103 + 2487375*uk_104 + 1912515*uk_105 + 2784375*uk_106 + 2140875*uk_107 + 1646095*uk_108 + 300763*uk_109 + 3398173*uk_11 + 776597*uk_110 + 35912*uk_111 + 902289*uk_112 + 1010025*uk_113 + 776597*uk_114 + 2005243*uk_115 + 92728*uk_116 + 2329791*uk_117 + 2607975*uk_118 + 2005243*uk_119 + 8774387*uk_12 + 4288*uk_120 + 107736*uk_121 + 120600*uk_122 + 92728*uk_123 + 2706867*uk_124 + 3030075*uk_125 + 2329791*uk_126 + 3391875*uk_127 + 2607975*uk_128 + 2005243*uk_129 + 405752*uk_13 + 5177717*uk_130 + 239432*uk_131 + 6015729*uk_132 + 6734025*uk_133 + 5177717*uk_134 + 11072*uk_135 + 278184*uk_136 + 311400*uk_137 + 239432*uk_138 + 6989373*uk_139 + 10194519*uk_14 + 7823925*uk_140 + 6015729*uk_141 + 8758125*uk_142 + 6734025*uk_143 + 5177717*uk_144 + 512*uk_145 + 12864*uk_146 + 14400*uk_147 + 11072*uk_148 + 323208*uk_149 + 11411775*uk_15 + 361800*uk_150 + 278184*uk_151 + 405000*uk_152 + 311400*uk_153 + 239432*uk_154 + 8120601*uk_155 + 9090225*uk_156 + 6989373*uk_157 + 10175625*uk_158 + 7823925*uk_159 + 8774387*uk_16 + 6015729*uk_160 + 11390625*uk_161 + 8758125*uk_162 + 6734025*uk_163 + 5177717*uk_164 + 3025*uk_17 + 3685*uk_18 + 9515*uk_19 + 55*uk_2 + 440*uk_20 + 11055*uk_21 + 12375*uk_22 + 9515*uk_23 + 4489*uk_24 + 11591*uk_25 + 536*uk_26 + 13467*uk_27 + 15075*uk_28 + 11591*uk_29 + 67*uk_3 + 29929*uk_30 + 1384*uk_31 + 34773*uk_32 + 38925*uk_33 + 29929*uk_34 + 64*uk_35 + 1608*uk_36 + 1800*uk_37 + 1384*uk_38 + 40401*uk_39 + 173*uk_4 + 45225*uk_40 + 34773*uk_41 + 50625*uk_42 + 38925*uk_43 + 29929*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 172351936387*uk_47 + 445028134253*uk_48 + 20579335688*uk_49 + 8*uk_5 + 517055809161*uk_50 + 578793816225*uk_51 + 445028134253*uk_52 + 153424975*uk_53 + 186899515*uk_54 + 482591285*uk_55 + 22316360*uk_56 + 560698545*uk_57 + 627647625*uk_58 + 482591285*uk_59 + 201*uk_6 + 227677591*uk_60 + 587883929*uk_61 + 27185384*uk_62 + 683032773*uk_63 + 764588925*uk_64 + 587883929*uk_65 + 1517968951*uk_66 + 70195096*uk_67 + 1763651787*uk_68 + 1974237075*uk_69 + 225*uk_7 + 1517968951*uk_70 + 3246016*uk_71 + 81556152*uk_72 + 91294200*uk_73 + 70195096*uk_74 + 2049098319*uk_75 + 2293766775*uk_76 + 1763651787*uk_77 + 2567649375*uk_78 + 1974237075*uk_79 + 173*uk_8 + 1517968951*uk_80 + 166375*uk_81 + 202675*uk_82 + 523325*uk_83 + 24200*uk_84 + 608025*uk_85 + 680625*uk_86 + 523325*uk_87 + 246895*uk_88 + 637505*uk_89 + 2572416961*uk_9 + 29480*uk_90 + 740685*uk_91 + 829125*uk_92 + 637505*uk_93 + 1646095*uk_94 + 76120*uk_95 + 1912515*uk_96 + 2140875*uk_97 + 1646095*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 133980*uk_100 + 148500*uk_101 + 44220*uk_102 + 2266495*uk_103 + 2512125*uk_104 + 748055*uk_105 + 2784375*uk_106 + 829125*uk_107 + 246895*uk_108 + 5088448*uk_109 + 8723668*uk_11 + 1982128*uk_110 + 355008*uk_111 + 6005552*uk_112 + 6656400*uk_113 + 1982128*uk_114 + 772108*uk_115 + 138288*uk_116 + 2339372*uk_117 + 2592900*uk_118 + 772108*uk_119 + 3398173*uk_12 + 24768*uk_120 + 418992*uk_121 + 464400*uk_122 + 138288*uk_123 + 7087948*uk_124 + 7856100*uk_125 + 2339372*uk_126 + 8707500*uk_127 + 2592900*uk_128 + 772108*uk_129 + 608628*uk_13 + 300763*uk_130 + 53868*uk_131 + 911267*uk_132 + 1010025*uk_133 + 300763*uk_134 + 9648*uk_135 + 163212*uk_136 + 180900*uk_137 + 53868*uk_138 + 2761003*uk_139 + 10295957*uk_14 + 3060225*uk_140 + 911267*uk_141 + 3391875*uk_142 + 1010025*uk_143 + 300763*uk_144 + 1728*uk_145 + 29232*uk_146 + 32400*uk_147 + 9648*uk_148 + 494508*uk_149 + 11411775*uk_15 + 548100*uk_150 + 163212*uk_151 + 607500*uk_152 + 180900*uk_153 + 53868*uk_154 + 8365427*uk_155 + 9272025*uk_156 + 2761003*uk_157 + 10276875*uk_158 + 3060225*uk_159 + 3398173*uk_16 + 911267*uk_160 + 11390625*uk_161 + 3391875*uk_162 + 1010025*uk_163 + 300763*uk_164 + 3025*uk_17 + 9460*uk_18 + 3685*uk_19 + 55*uk_2 + 660*uk_20 + 11165*uk_21 + 12375*uk_22 + 3685*uk_23 + 29584*uk_24 + 11524*uk_25 + 2064*uk_26 + 34916*uk_27 + 38700*uk_28 + 11524*uk_29 + 172*uk_3 + 4489*uk_30 + 804*uk_31 + 13601*uk_32 + 15075*uk_33 + 4489*uk_34 + 144*uk_35 + 2436*uk_36 + 2700*uk_37 + 804*uk_38 + 41209*uk_39 + 67*uk_4 + 45675*uk_40 + 13601*uk_41 + 50625*uk_42 + 15075*uk_43 + 4489*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 442455717292*uk_47 + 172351936387*uk_48 + 30869003532*uk_49 + 12*uk_5 + 522200643083*uk_50 + 578793816225*uk_51 + 172351936387*uk_52 + 153424975*uk_53 + 479801740*uk_54 + 186899515*uk_55 + 33474540*uk_56 + 566277635*uk_57 + 627647625*uk_58 + 186899515*uk_59 + 203*uk_6 + 1500470896*uk_60 + 584485756*uk_61 + 104684016*uk_62 + 1770904604*uk_63 + 1962825300*uk_64 + 584485756*uk_65 + 227677591*uk_66 + 40778076*uk_67 + 689829119*uk_68 + 764588925*uk_69 + 225*uk_7 + 227677591*uk_70 + 7303536*uk_71 + 123551484*uk_72 + 136941300*uk_73 + 40778076*uk_74 + 2090079271*uk_75 + 2316590325*uk_76 + 689829119*uk_77 + 2567649375*uk_78 + 764588925*uk_79 + 67*uk_8 + 227677591*uk_80 + 166375*uk_81 + 520300*uk_82 + 202675*uk_83 + 36300*uk_84 + 614075*uk_85 + 680625*uk_86 + 202675*uk_87 + 1627120*uk_88 + 633820*uk_89 + 2572416961*uk_9 + 113520*uk_90 + 1920380*uk_91 + 2128500*uk_92 + 633820*uk_93 + 246895*uk_94 + 44220*uk_95 + 748055*uk_96 + 829125*uk_97 + 246895*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 90200*uk_100 + 99000*uk_101 + 75680*uk_102 + 2311375*uk_103 + 2536875*uk_104 + 1939300*uk_105 + 2784375*uk_106 + 2128500*uk_107 + 1627120*uk_108 + 592704*uk_109 + 4260396*uk_11 + 1213632*uk_110 + 56448*uk_111 + 1446480*uk_112 + 1587600*uk_113 + 1213632*uk_114 + 2485056*uk_115 + 115584*uk_116 + 2961840*uk_117 + 3250800*uk_118 + 2485056*uk_119 + 8723668*uk_12 + 5376*uk_120 + 137760*uk_121 + 151200*uk_122 + 115584*uk_123 + 3530100*uk_124 + 3874500*uk_125 + 2961840*uk_126 + 4252500*uk_127 + 3250800*uk_128 + 2485056*uk_129 + 405752*uk_13 + 5088448*uk_130 + 236672*uk_131 + 6064720*uk_132 + 6656400*uk_133 + 5088448*uk_134 + 11008*uk_135 + 282080*uk_136 + 309600*uk_137 + 236672*uk_138 + 7228300*uk_139 + 10397395*uk_14 + 7933500*uk_140 + 6064720*uk_141 + 8707500*uk_142 + 6656400*uk_143 + 5088448*uk_144 + 512*uk_145 + 13120*uk_146 + 14400*uk_147 + 11008*uk_148 + 336200*uk_149 + 11411775*uk_15 + 369000*uk_150 + 282080*uk_151 + 405000*uk_152 + 309600*uk_153 + 236672*uk_154 + 8615125*uk_155 + 9455625*uk_156 + 7228300*uk_157 + 10378125*uk_158 + 7933500*uk_159 + 8723668*uk_16 + 6064720*uk_160 + 11390625*uk_161 + 8707500*uk_162 + 6656400*uk_163 + 5088448*uk_164 + 3025*uk_17 + 4620*uk_18 + 9460*uk_19 + 55*uk_2 + 440*uk_20 + 11275*uk_21 + 12375*uk_22 + 9460*uk_23 + 7056*uk_24 + 14448*uk_25 + 672*uk_26 + 17220*uk_27 + 18900*uk_28 + 14448*uk_29 + 84*uk_3 + 29584*uk_30 + 1376*uk_31 + 35260*uk_32 + 38700*uk_33 + 29584*uk_34 + 64*uk_35 + 1640*uk_36 + 1800*uk_37 + 1376*uk_38 + 42025*uk_39 + 172*uk_4 + 46125*uk_40 + 35260*uk_41 + 50625*uk_42 + 38700*uk_43 + 29584*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 216083024724*uk_47 + 442455717292*uk_48 + 20579335688*uk_49 + 8*uk_5 + 527345477005*uk_50 + 578793816225*uk_51 + 442455717292*uk_52 + 153424975*uk_53 + 234321780*uk_54 + 479801740*uk_55 + 22316360*uk_56 + 571856725*uk_57 + 627647625*uk_58 + 479801740*uk_59 + 205*uk_6 + 357873264*uk_60 + 732788112*uk_61 + 34083168*uk_62 + 873381180*uk_63 + 958589100*uk_64 + 732788112*uk_65 + 1500470896*uk_66 + 69789344*uk_67 + 1788351940*uk_68 + 1962825300*uk_69 + 225*uk_7 + 1500470896*uk_70 + 3246016*uk_71 + 83179160*uk_72 + 91294200*uk_73 + 69789344*uk_74 + 2131465975*uk_75 + 2339413875*uk_76 + 1788351940*uk_77 + 2567649375*uk_78 + 1962825300*uk_79 + 172*uk_8 + 1500470896*uk_80 + 166375*uk_81 + 254100*uk_82 + 520300*uk_83 + 24200*uk_84 + 620125*uk_85 + 680625*uk_86 + 520300*uk_87 + 388080*uk_88 + 794640*uk_89 + 2572416961*uk_9 + 36960*uk_90 + 947100*uk_91 + 1039500*uk_92 + 794640*uk_93 + 1627120*uk_94 + 75680*uk_95 + 1939300*uk_96 + 2128500*uk_97 + 1627120*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 91080*uk_100 + 99000*uk_101 + 36960*uk_102 + 2356695*uk_103 + 2561625*uk_104 + 956340*uk_105 + 2784375*uk_106 + 1039500*uk_107 + 388080*uk_108 + 64*uk_109 + 202876*uk_11 + 1344*uk_110 + 128*uk_111 + 3312*uk_112 + 3600*uk_113 + 1344*uk_114 + 28224*uk_115 + 2688*uk_116 + 69552*uk_117 + 75600*uk_118 + 28224*uk_119 + 4260396*uk_12 + 256*uk_120 + 6624*uk_121 + 7200*uk_122 + 2688*uk_123 + 171396*uk_124 + 186300*uk_125 + 69552*uk_126 + 202500*uk_127 + 75600*uk_128 + 28224*uk_129 + 405752*uk_13 + 592704*uk_130 + 56448*uk_131 + 1460592*uk_132 + 1587600*uk_133 + 592704*uk_134 + 5376*uk_135 + 139104*uk_136 + 151200*uk_137 + 56448*uk_138 + 3599316*uk_139 + 10498833*uk_14 + 3912300*uk_140 + 1460592*uk_141 + 4252500*uk_142 + 1587600*uk_143 + 592704*uk_144 + 512*uk_145 + 13248*uk_146 + 14400*uk_147 + 5376*uk_148 + 342792*uk_149 + 11411775*uk_15 + 372600*uk_150 + 139104*uk_151 + 405000*uk_152 + 151200*uk_153 + 56448*uk_154 + 8869743*uk_155 + 9641025*uk_156 + 3599316*uk_157 + 10479375*uk_158 + 3912300*uk_159 + 4260396*uk_16 + 1460592*uk_160 + 11390625*uk_161 + 4252500*uk_162 + 1587600*uk_163 + 592704*uk_164 + 3025*uk_17 + 220*uk_18 + 4620*uk_19 + 55*uk_2 + 440*uk_20 + 11385*uk_21 + 12375*uk_22 + 4620*uk_23 + 16*uk_24 + 336*uk_25 + 32*uk_26 + 828*uk_27 + 900*uk_28 + 336*uk_29 + 4*uk_3 + 7056*uk_30 + 672*uk_31 + 17388*uk_32 + 18900*uk_33 + 7056*uk_34 + 64*uk_35 + 1656*uk_36 + 1800*uk_37 + 672*uk_38 + 42849*uk_39 + 84*uk_4 + 46575*uk_40 + 17388*uk_41 + 50625*uk_42 + 18900*uk_43 + 7056*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 10289667844*uk_47 + 216083024724*uk_48 + 20579335688*uk_49 + 8*uk_5 + 532490310927*uk_50 + 578793816225*uk_51 + 216083024724*uk_52 + 153424975*uk_53 + 11158180*uk_54 + 234321780*uk_55 + 22316360*uk_56 + 577435815*uk_57 + 627647625*uk_58 + 234321780*uk_59 + 207*uk_6 + 811504*uk_60 + 17041584*uk_61 + 1623008*uk_62 + 41995332*uk_63 + 45647100*uk_64 + 17041584*uk_65 + 357873264*uk_66 + 34083168*uk_67 + 881901972*uk_68 + 958589100*uk_69 + 225*uk_7 + 357873264*uk_70 + 3246016*uk_71 + 83990664*uk_72 + 91294200*uk_73 + 34083168*uk_74 + 2173258431*uk_75 + 2362237425*uk_76 + 881901972*uk_77 + 2567649375*uk_78 + 958589100*uk_79 + 84*uk_8 + 357873264*uk_80 + 166375*uk_81 + 12100*uk_82 + 254100*uk_83 + 24200*uk_84 + 626175*uk_85 + 680625*uk_86 + 254100*uk_87 + 880*uk_88 + 18480*uk_89 + 2572416961*uk_9 + 1760*uk_90 + 45540*uk_91 + 49500*uk_92 + 18480*uk_93 + 388080*uk_94 + 36960*uk_95 + 956340*uk_96 + 1039500*uk_97 + 388080*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 137940*uk_100 + 148500*uk_101 + 2640*uk_102 + 2402455*uk_103 + 2586375*uk_104 + 45980*uk_105 + 2784375*uk_106 + 49500*uk_107 + 880*uk_108 + 2803221*uk_109 + 7151379*uk_11 + 79524*uk_110 + 238572*uk_111 + 4155129*uk_112 + 4473225*uk_113 + 79524*uk_114 + 2256*uk_115 + 6768*uk_116 + 117876*uk_117 + 126900*uk_118 + 2256*uk_119 + 202876*uk_12 + 20304*uk_120 + 353628*uk_121 + 380700*uk_122 + 6768*uk_123 + 6159021*uk_124 + 6630525*uk_125 + 117876*uk_126 + 7138125*uk_127 + 126900*uk_128 + 2256*uk_129 + 608628*uk_13 + 64*uk_130 + 192*uk_131 + 3344*uk_132 + 3600*uk_133 + 64*uk_134 + 576*uk_135 + 10032*uk_136 + 10800*uk_137 + 192*uk_138 + 174724*uk_139 + 10600271*uk_14 + 188100*uk_140 + 3344*uk_141 + 202500*uk_142 + 3600*uk_143 + 64*uk_144 + 1728*uk_145 + 30096*uk_146 + 32400*uk_147 + 576*uk_148 + 524172*uk_149 + 11411775*uk_15 + 564300*uk_150 + 10032*uk_151 + 607500*uk_152 + 10800*uk_153 + 192*uk_154 + 9129329*uk_155 + 9828225*uk_156 + 174724*uk_157 + 10580625*uk_158 + 188100*uk_159 + 202876*uk_16 + 3344*uk_160 + 11390625*uk_161 + 202500*uk_162 + 3600*uk_163 + 64*uk_164 + 3025*uk_17 + 7755*uk_18 + 220*uk_19 + 55*uk_2 + 660*uk_20 + 11495*uk_21 + 12375*uk_22 + 220*uk_23 + 19881*uk_24 + 564*uk_25 + 1692*uk_26 + 29469*uk_27 + 31725*uk_28 + 564*uk_29 + 141*uk_3 + 16*uk_30 + 48*uk_31 + 836*uk_32 + 900*uk_33 + 16*uk_34 + 144*uk_35 + 2508*uk_36 + 2700*uk_37 + 48*uk_38 + 43681*uk_39 + 4*uk_4 + 47025*uk_40 + 836*uk_41 + 50625*uk_42 + 900*uk_43 + 16*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 362710791501*uk_47 + 10289667844*uk_48 + 30869003532*uk_49 + 12*uk_5 + 537635144849*uk_50 + 578793816225*uk_51 + 10289667844*uk_52 + 153424975*uk_53 + 393325845*uk_54 + 11158180*uk_55 + 33474540*uk_56 + 583014905*uk_57 + 627647625*uk_58 + 11158180*uk_59 + 209*uk_6 + 1008344439*uk_60 + 28605516*uk_61 + 85816548*uk_62 + 1494638211*uk_63 + 1609060275*uk_64 + 28605516*uk_65 + 811504*uk_66 + 2434512*uk_67 + 42401084*uk_68 + 45647100*uk_69 + 225*uk_7 + 811504*uk_70 + 7303536*uk_71 + 127203252*uk_72 + 136941300*uk_73 + 2434512*uk_74 + 2215456639*uk_75 + 2385060975*uk_76 + 42401084*uk_77 + 2567649375*uk_78 + 45647100*uk_79 + 4*uk_8 + 811504*uk_80 + 166375*uk_81 + 426525*uk_82 + 12100*uk_83 + 36300*uk_84 + 632225*uk_85 + 680625*uk_86 + 12100*uk_87 + 1093455*uk_88 + 31020*uk_89 + 2572416961*uk_9 + 93060*uk_90 + 1620795*uk_91 + 1744875*uk_92 + 31020*uk_93 + 880*uk_94 + 2640*uk_95 + 45980*uk_96 + 49500*uk_97 + 880*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 92840*uk_100 + 99000*uk_101 + 62040*uk_102 + 2448655*uk_103 + 2611125*uk_104 + 1636305*uk_105 + 2784375*uk_106 + 1744875*uk_107 + 1093455*uk_108 + 493039*uk_109 + 4006801*uk_11 + 879981*uk_110 + 49928*uk_111 + 1316851*uk_112 + 1404225*uk_113 + 879981*uk_114 + 1570599*uk_115 + 89112*uk_116 + 2350329*uk_117 + 2506275*uk_118 + 1570599*uk_119 + 7151379*uk_12 + 5056*uk_120 + 133352*uk_121 + 142200*uk_122 + 89112*uk_123 + 3517159*uk_124 + 3750525*uk_125 + 2350329*uk_126 + 3999375*uk_127 + 2506275*uk_128 + 1570599*uk_129 + 405752*uk_13 + 2803221*uk_130 + 159048*uk_131 + 4194891*uk_132 + 4473225*uk_133 + 2803221*uk_134 + 9024*uk_135 + 238008*uk_136 + 253800*uk_137 + 159048*uk_138 + 6277461*uk_139 + 10701709*uk_14 + 6693975*uk_140 + 4194891*uk_141 + 7138125*uk_142 + 4473225*uk_143 + 2803221*uk_144 + 512*uk_145 + 13504*uk_146 + 14400*uk_147 + 9024*uk_148 + 356168*uk_149 + 11411775*uk_15 + 379800*uk_150 + 238008*uk_151 + 405000*uk_152 + 253800*uk_153 + 159048*uk_154 + 9393931*uk_155 + 10017225*uk_156 + 6277461*uk_157 + 10681875*uk_158 + 6693975*uk_159 + 7151379*uk_16 + 4194891*uk_160 + 11390625*uk_161 + 7138125*uk_162 + 4473225*uk_163 + 2803221*uk_164 + 3025*uk_17 + 4345*uk_18 + 7755*uk_19 + 55*uk_2 + 440*uk_20 + 11605*uk_21 + 12375*uk_22 + 7755*uk_23 + 6241*uk_24 + 11139*uk_25 + 632*uk_26 + 16669*uk_27 + 17775*uk_28 + 11139*uk_29 + 79*uk_3 + 19881*uk_30 + 1128*uk_31 + 29751*uk_32 + 31725*uk_33 + 19881*uk_34 + 64*uk_35 + 1688*uk_36 + 1800*uk_37 + 1128*uk_38 + 44521*uk_39 + 141*uk_4 + 47475*uk_40 + 29751*uk_41 + 50625*uk_42 + 31725*uk_43 + 19881*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 203220939919*uk_47 + 362710791501*uk_48 + 20579335688*uk_49 + 8*uk_5 + 542779978771*uk_50 + 578793816225*uk_51 + 362710791501*uk_52 + 153424975*uk_53 + 220374055*uk_54 + 393325845*uk_55 + 22316360*uk_56 + 588593995*uk_57 + 627647625*uk_58 + 393325845*uk_59 + 211*uk_6 + 316537279*uk_60 + 564958941*uk_61 + 32054408*uk_62 + 845435011*uk_63 + 901530225*uk_64 + 564958941*uk_65 + 1008344439*uk_66 + 57211032*uk_67 + 1508940969*uk_68 + 1609060275*uk_69 + 225*uk_7 + 1008344439*uk_70 + 3246016*uk_71 + 85613672*uk_72 + 91294200*uk_73 + 57211032*uk_74 + 2258060599*uk_75 + 2407884525*uk_76 + 1508940969*uk_77 + 2567649375*uk_78 + 1609060275*uk_79 + 141*uk_8 + 1008344439*uk_80 + 166375*uk_81 + 238975*uk_82 + 426525*uk_83 + 24200*uk_84 + 638275*uk_85 + 680625*uk_86 + 426525*uk_87 + 343255*uk_88 + 612645*uk_89 + 2572416961*uk_9 + 34760*uk_90 + 916795*uk_91 + 977625*uk_92 + 612645*uk_93 + 1093455*uk_94 + 62040*uk_95 + 1636305*uk_96 + 1744875*uk_97 + 1093455*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 93720*uk_100 + 99000*uk_101 + 34760*uk_102 + 2495295*uk_103 + 2635875*uk_104 + 925485*uk_105 + 2784375*uk_106 + 977625*uk_107 + 343255*uk_108 + 15625*uk_109 + 1267975*uk_11 + 49375*uk_110 + 5000*uk_111 + 133125*uk_112 + 140625*uk_113 + 49375*uk_114 + 156025*uk_115 + 15800*uk_116 + 420675*uk_117 + 444375*uk_118 + 156025*uk_119 + 4006801*uk_12 + 1600*uk_120 + 42600*uk_121 + 45000*uk_122 + 15800*uk_123 + 1134225*uk_124 + 1198125*uk_125 + 420675*uk_126 + 1265625*uk_127 + 444375*uk_128 + 156025*uk_129 + 405752*uk_13 + 493039*uk_130 + 49928*uk_131 + 1329333*uk_132 + 1404225*uk_133 + 493039*uk_134 + 5056*uk_135 + 134616*uk_136 + 142200*uk_137 + 49928*uk_138 + 3584151*uk_139 + 10803147*uk_14 + 3786075*uk_140 + 1329333*uk_141 + 3999375*uk_142 + 1404225*uk_143 + 493039*uk_144 + 512*uk_145 + 13632*uk_146 + 14400*uk_147 + 5056*uk_148 + 362952*uk_149 + 11411775*uk_15 + 383400*uk_150 + 134616*uk_151 + 405000*uk_152 + 142200*uk_153 + 49928*uk_154 + 9663597*uk_155 + 10208025*uk_156 + 3584151*uk_157 + 10783125*uk_158 + 3786075*uk_159 + 4006801*uk_16 + 1329333*uk_160 + 11390625*uk_161 + 3999375*uk_162 + 1404225*uk_163 + 493039*uk_164 + 3025*uk_17 + 1375*uk_18 + 4345*uk_19 + 55*uk_2 + 440*uk_20 + 11715*uk_21 + 12375*uk_22 + 4345*uk_23 + 625*uk_24 + 1975*uk_25 + 200*uk_26 + 5325*uk_27 + 5625*uk_28 + 1975*uk_29 + 25*uk_3 + 6241*uk_30 + 632*uk_31 + 16827*uk_32 + 17775*uk_33 + 6241*uk_34 + 64*uk_35 + 1704*uk_36 + 1800*uk_37 + 632*uk_38 + 45369*uk_39 + 79*uk_4 + 47925*uk_40 + 16827*uk_41 + 50625*uk_42 + 17775*uk_43 + 6241*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 64310424025*uk_47 + 203220939919*uk_48 + 20579335688*uk_49 + 8*uk_5 + 547924812693*uk_50 + 578793816225*uk_51 + 203220939919*uk_52 + 153424975*uk_53 + 69738625*uk_54 + 220374055*uk_55 + 22316360*uk_56 + 594173085*uk_57 + 627647625*uk_58 + 220374055*uk_59 + 213*uk_6 + 31699375*uk_60 + 100170025*uk_61 + 10143800*uk_62 + 270078675*uk_63 + 285294375*uk_64 + 100170025*uk_65 + 316537279*uk_66 + 32054408*uk_67 + 853448613*uk_68 + 901530225*uk_69 + 225*uk_7 + 316537279*uk_70 + 3246016*uk_71 + 86425176*uk_72 + 91294200*uk_73 + 32054408*uk_74 + 2301070311*uk_75 + 2430708075*uk_76 + 853448613*uk_77 + 2567649375*uk_78 + 901530225*uk_79 + 79*uk_8 + 316537279*uk_80 + 166375*uk_81 + 75625*uk_82 + 238975*uk_83 + 24200*uk_84 + 644325*uk_85 + 680625*uk_86 + 238975*uk_87 + 34375*uk_88 + 108625*uk_89 + 2572416961*uk_9 + 11000*uk_90 + 292875*uk_91 + 309375*uk_92 + 108625*uk_93 + 343255*uk_94 + 34760*uk_95 + 925485*uk_96 + 977625*uk_97 + 343255*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 141900*uk_100 + 148500*uk_101 + 16500*uk_102 + 2542375*uk_103 + 2660625*uk_104 + 295625*uk_105 + 2784375*uk_106 + 309375*uk_107 + 34375*uk_108 + 7301384*uk_109 + 9839486*uk_11 + 940900*uk_110 + 451632*uk_111 + 8091740*uk_112 + 8468100*uk_113 + 940900*uk_114 + 121250*uk_115 + 58200*uk_116 + 1042750*uk_117 + 1091250*uk_118 + 121250*uk_119 + 1267975*uk_12 + 27936*uk_120 + 500520*uk_121 + 523800*uk_122 + 58200*uk_123 + 8967650*uk_124 + 9384750*uk_125 + 1042750*uk_126 + 9821250*uk_127 + 1091250*uk_128 + 121250*uk_129 + 608628*uk_13 + 15625*uk_130 + 7500*uk_131 + 134375*uk_132 + 140625*uk_133 + 15625*uk_134 + 3600*uk_135 + 64500*uk_136 + 67500*uk_137 + 7500*uk_138 + 1155625*uk_139 + 10904585*uk_14 + 1209375*uk_140 + 134375*uk_141 + 1265625*uk_142 + 140625*uk_143 + 15625*uk_144 + 1728*uk_145 + 30960*uk_146 + 32400*uk_147 + 3600*uk_148 + 554700*uk_149 + 11411775*uk_15 + 580500*uk_150 + 64500*uk_151 + 607500*uk_152 + 67500*uk_153 + 7500*uk_154 + 9938375*uk_155 + 10400625*uk_156 + 1155625*uk_157 + 10884375*uk_158 + 1209375*uk_159 + 1267975*uk_16 + 134375*uk_160 + 11390625*uk_161 + 1265625*uk_162 + 140625*uk_163 + 15625*uk_164 + 3025*uk_17 + 10670*uk_18 + 1375*uk_19 + 55*uk_2 + 660*uk_20 + 11825*uk_21 + 12375*uk_22 + 1375*uk_23 + 37636*uk_24 + 4850*uk_25 + 2328*uk_26 + 41710*uk_27 + 43650*uk_28 + 4850*uk_29 + 194*uk_3 + 625*uk_30 + 300*uk_31 + 5375*uk_32 + 5625*uk_33 + 625*uk_34 + 144*uk_35 + 2580*uk_36 + 2700*uk_37 + 300*uk_38 + 46225*uk_39 + 25*uk_4 + 48375*uk_40 + 5375*uk_41 + 50625*uk_42 + 5625*uk_43 + 625*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 499048890434*uk_47 + 64310424025*uk_48 + 30869003532*uk_49 + 12*uk_5 + 553069646615*uk_50 + 578793816225*uk_51 + 64310424025*uk_52 + 153424975*uk_53 + 541171730*uk_54 + 69738625*uk_55 + 33474540*uk_56 + 599752175*uk_57 + 627647625*uk_58 + 69738625*uk_59 + 215*uk_6 + 1908860284*uk_60 + 245987150*uk_61 + 118073832*uk_62 + 2115489490*uk_63 + 2213884350*uk_64 + 245987150*uk_65 + 31699375*uk_66 + 15215700*uk_67 + 272614625*uk_68 + 285294375*uk_69 + 225*uk_7 + 31699375*uk_70 + 7303536*uk_71 + 130855020*uk_72 + 136941300*uk_73 + 15215700*uk_74 + 2344485775*uk_75 + 2453531625*uk_76 + 272614625*uk_77 + 2567649375*uk_78 + 285294375*uk_79 + 25*uk_8 + 31699375*uk_80 + 166375*uk_81 + 586850*uk_82 + 75625*uk_83 + 36300*uk_84 + 650375*uk_85 + 680625*uk_86 + 75625*uk_87 + 2069980*uk_88 + 266750*uk_89 + 2572416961*uk_9 + 128040*uk_90 + 2294050*uk_91 + 2400750*uk_92 + 266750*uk_93 + 34375*uk_94 + 16500*uk_95 + 295625*uk_96 + 309375*uk_97 + 34375*uk_98 + 7920*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 95480*uk_100 + 99000*uk_101 + 85360*uk_102 + 2589895*uk_103 + 2685375*uk_104 + 2315390*uk_105 + 2784375*uk_106 + 2400750*uk_107 + 2069980*uk_108 + 3944312*uk_109 + 8013602*uk_11 + 4843016*uk_110 + 199712*uk_111 + 5417188*uk_112 + 5616900*uk_113 + 4843016*uk_114 + 5946488*uk_115 + 245216*uk_116 + 6651484*uk_117 + 6896700*uk_118 + 5946488*uk_119 + 9839486*uk_12 + 10112*uk_120 + 274288*uk_121 + 284400*uk_122 + 245216*uk_123 + 7440062*uk_124 + 7714350*uk_125 + 6651484*uk_126 + 7998750*uk_127 + 6896700*uk_128 + 5946488*uk_129 + 405752*uk_13 + 7301384*uk_130 + 301088*uk_131 + 8167012*uk_132 + 8468100*uk_133 + 7301384*uk_134 + 12416*uk_135 + 336784*uk_136 + 349200*uk_137 + 301088*uk_138 + 9135266*uk_139 + 11006023*uk_14 + 9472050*uk_140 + 8167012*uk_141 + 9821250*uk_142 + 8468100*uk_143 + 7301384*uk_144 + 512*uk_145 + 13888*uk_146 + 14400*uk_147 + 12416*uk_148 + 376712*uk_149 + 11411775*uk_15 + 390600*uk_150 + 336784*uk_151 + 405000*uk_152 + 349200*uk_153 + 301088*uk_154 + 10218313*uk_155 + 10595025*uk_156 + 9135266*uk_157 + 10985625*uk_158 + 9472050*uk_159 + 9839486*uk_16 + 8167012*uk_160 + 11390625*uk_161 + 9821250*uk_162 + 8468100*uk_163 + 7301384*uk_164 + 3025*uk_17 + 8690*uk_18 + 10670*uk_19 + 55*uk_2 + 440*uk_20 + 11935*uk_21 + 12375*uk_22 + 10670*uk_23 + 24964*uk_24 + 30652*uk_25 + 1264*uk_26 + 34286*uk_27 + 35550*uk_28 + 30652*uk_29 + 158*uk_3 + 37636*uk_30 + 1552*uk_31 + 42098*uk_32 + 43650*uk_33 + 37636*uk_34 + 64*uk_35 + 1736*uk_36 + 1800*uk_37 + 1552*uk_38 + 47089*uk_39 + 194*uk_4 + 48825*uk_40 + 42098*uk_41 + 50625*uk_42 + 43650*uk_43 + 37636*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 406441879838*uk_47 + 499048890434*uk_48 + 20579335688*uk_49 + 8*uk_5 + 558214480537*uk_50 + 578793816225*uk_51 + 499048890434*uk_52 + 153424975*uk_53 + 440748110*uk_54 + 541171730*uk_55 + 22316360*uk_56 + 605331265*uk_57 + 627647625*uk_58 + 541171730*uk_59 + 217*uk_6 + 1266149116*uk_60 + 1554638788*uk_61 + 64108816*uk_62 + 1738951634*uk_63 + 1803060450*uk_64 + 1554638788*uk_65 + 1908860284*uk_66 + 78715888*uk_67 + 2135168462*uk_68 + 2213884350*uk_69 + 225*uk_7 + 1908860284*uk_70 + 3246016*uk_71 + 88048184*uk_72 + 91294200*uk_73 + 78715888*uk_74 + 2388306991*uk_75 + 2476355175*uk_76 + 2135168462*uk_77 + 2567649375*uk_78 + 2213884350*uk_79 + 194*uk_8 + 1908860284*uk_80 + 166375*uk_81 + 477950*uk_82 + 586850*uk_83 + 24200*uk_84 + 656425*uk_85 + 680625*uk_86 + 586850*uk_87 + 1373020*uk_88 + 1685860*uk_89 + 2572416961*uk_9 + 69520*uk_90 + 1885730*uk_91 + 1955250*uk_92 + 1685860*uk_93 + 2069980*uk_94 + 85360*uk_95 + 2315390*uk_96 + 2400750*uk_97 + 2069980*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 96360*uk_100 + 99000*uk_101 + 69520*uk_102 + 2637855*uk_103 + 2710125*uk_104 + 1903110*uk_105 + 2784375*uk_106 + 1955250*uk_107 + 1373020*uk_108 + 2197000*uk_109 + 6593470*uk_11 + 2670200*uk_110 + 135200*uk_111 + 3701100*uk_112 + 3802500*uk_113 + 2670200*uk_114 + 3245320*uk_115 + 164320*uk_116 + 4498260*uk_117 + 4621500*uk_118 + 3245320*uk_119 + 8013602*uk_12 + 8320*uk_120 + 227760*uk_121 + 234000*uk_122 + 164320*uk_123 + 6234930*uk_124 + 6405750*uk_125 + 4498260*uk_126 + 6581250*uk_127 + 4621500*uk_128 + 3245320*uk_129 + 405752*uk_13 + 3944312*uk_130 + 199712*uk_131 + 5467116*uk_132 + 5616900*uk_133 + 3944312*uk_134 + 10112*uk_135 + 276816*uk_136 + 284400*uk_137 + 199712*uk_138 + 7577838*uk_139 + 11107461*uk_14 + 7785450*uk_140 + 5467116*uk_141 + 7998750*uk_142 + 5616900*uk_143 + 3944312*uk_144 + 512*uk_145 + 14016*uk_146 + 14400*uk_147 + 10112*uk_148 + 383688*uk_149 + 11411775*uk_15 + 394200*uk_150 + 276816*uk_151 + 405000*uk_152 + 284400*uk_153 + 199712*uk_154 + 10503459*uk_155 + 10791225*uk_156 + 7577838*uk_157 + 11086875*uk_158 + 7785450*uk_159 + 8013602*uk_16 + 5467116*uk_160 + 11390625*uk_161 + 7998750*uk_162 + 5616900*uk_163 + 3944312*uk_164 + 3025*uk_17 + 7150*uk_18 + 8690*uk_19 + 55*uk_2 + 440*uk_20 + 12045*uk_21 + 12375*uk_22 + 8690*uk_23 + 16900*uk_24 + 20540*uk_25 + 1040*uk_26 + 28470*uk_27 + 29250*uk_28 + 20540*uk_29 + 130*uk_3 + 24964*uk_30 + 1264*uk_31 + 34602*uk_32 + 35550*uk_33 + 24964*uk_34 + 64*uk_35 + 1752*uk_36 + 1800*uk_37 + 1264*uk_38 + 47961*uk_39 + 158*uk_4 + 49275*uk_40 + 34602*uk_41 + 50625*uk_42 + 35550*uk_43 + 24964*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 334414204930*uk_47 + 406441879838*uk_48 + 20579335688*uk_49 + 8*uk_5 + 563359314459*uk_50 + 578793816225*uk_51 + 406441879838*uk_52 + 153424975*uk_53 + 362640850*uk_54 + 440748110*uk_55 + 22316360*uk_56 + 610910355*uk_57 + 627647625*uk_58 + 440748110*uk_59 + 219*uk_6 + 857151100*uk_60 + 1041768260*uk_61 + 52747760*uk_62 + 1443969930*uk_63 + 1483530750*uk_64 + 1041768260*uk_65 + 1266149116*uk_66 + 64108816*uk_67 + 1754978838*uk_68 + 1803060450*uk_69 + 225*uk_7 + 1266149116*uk_70 + 3246016*uk_71 + 88859688*uk_72 + 91294200*uk_73 + 64108816*uk_74 + 2432533959*uk_75 + 2499178725*uk_76 + 1754978838*uk_77 + 2567649375*uk_78 + 1803060450*uk_79 + 158*uk_8 + 1266149116*uk_80 + 166375*uk_81 + 393250*uk_82 + 477950*uk_83 + 24200*uk_84 + 662475*uk_85 + 680625*uk_86 + 477950*uk_87 + 929500*uk_88 + 1129700*uk_89 + 2572416961*uk_9 + 57200*uk_90 + 1565850*uk_91 + 1608750*uk_92 + 1129700*uk_93 + 1373020*uk_94 + 69520*uk_95 + 1903110*uk_96 + 1955250*uk_97 + 1373020*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 97240*uk_100 + 99000*uk_101 + 57200*uk_102 + 2686255*uk_103 + 2734875*uk_104 + 1580150*uk_105 + 2784375*uk_106 + 1608750*uk_107 + 929500*uk_108 + 1331000*uk_109 + 5579090*uk_11 + 1573000*uk_110 + 96800*uk_111 + 2674100*uk_112 + 2722500*uk_113 + 1573000*uk_114 + 1859000*uk_115 + 114400*uk_116 + 3160300*uk_117 + 3217500*uk_118 + 1859000*uk_119 + 6593470*uk_12 + 7040*uk_120 + 194480*uk_121 + 198000*uk_122 + 114400*uk_123 + 5372510*uk_124 + 5469750*uk_125 + 3160300*uk_126 + 5568750*uk_127 + 3217500*uk_128 + 1859000*uk_129 + 405752*uk_13 + 2197000*uk_130 + 135200*uk_131 + 3734900*uk_132 + 3802500*uk_133 + 2197000*uk_134 + 8320*uk_135 + 229840*uk_136 + 234000*uk_137 + 135200*uk_138 + 6349330*uk_139 + 11208899*uk_14 + 6464250*uk_140 + 3734900*uk_141 + 6581250*uk_142 + 3802500*uk_143 + 2197000*uk_144 + 512*uk_145 + 14144*uk_146 + 14400*uk_147 + 8320*uk_148 + 390728*uk_149 + 11411775*uk_15 + 397800*uk_150 + 229840*uk_151 + 405000*uk_152 + 234000*uk_153 + 135200*uk_154 + 10793861*uk_155 + 10989225*uk_156 + 6349330*uk_157 + 11188125*uk_158 + 6464250*uk_159 + 6593470*uk_16 + 3734900*uk_160 + 11390625*uk_161 + 6581250*uk_162 + 3802500*uk_163 + 2197000*uk_164 + 3025*uk_17 + 6050*uk_18 + 7150*uk_19 + 55*uk_2 + 440*uk_20 + 12155*uk_21 + 12375*uk_22 + 7150*uk_23 + 12100*uk_24 + 14300*uk_25 + 880*uk_26 + 24310*uk_27 + 24750*uk_28 + 14300*uk_29 + 110*uk_3 + 16900*uk_30 + 1040*uk_31 + 28730*uk_32 + 29250*uk_33 + 16900*uk_34 + 64*uk_35 + 1768*uk_36 + 1800*uk_37 + 1040*uk_38 + 48841*uk_39 + 130*uk_4 + 49725*uk_40 + 28730*uk_41 + 50625*uk_42 + 29250*uk_43 + 16900*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 282965865710*uk_47 + 334414204930*uk_48 + 20579335688*uk_49 + 8*uk_5 + 568504148381*uk_50 + 578793816225*uk_51 + 334414204930*uk_52 + 153424975*uk_53 + 306849950*uk_54 + 362640850*uk_55 + 22316360*uk_56 + 616489445*uk_57 + 627647625*uk_58 + 362640850*uk_59 + 221*uk_6 + 613699900*uk_60 + 725281700*uk_61 + 44632720*uk_62 + 1232978890*uk_63 + 1255295250*uk_64 + 725281700*uk_65 + 857151100*uk_66 + 52747760*uk_67 + 1457156870*uk_68 + 1483530750*uk_69 + 225*uk_7 + 857151100*uk_70 + 3246016*uk_71 + 89671192*uk_72 + 91294200*uk_73 + 52747760*uk_74 + 2477166679*uk_75 + 2522002275*uk_76 + 1457156870*uk_77 + 2567649375*uk_78 + 1483530750*uk_79 + 130*uk_8 + 857151100*uk_80 + 166375*uk_81 + 332750*uk_82 + 393250*uk_83 + 24200*uk_84 + 668525*uk_85 + 680625*uk_86 + 393250*uk_87 + 665500*uk_88 + 786500*uk_89 + 2572416961*uk_9 + 48400*uk_90 + 1337050*uk_91 + 1361250*uk_92 + 786500*uk_93 + 929500*uk_94 + 57200*uk_95 + 1580150*uk_96 + 1608750*uk_97 + 929500*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 98120*uk_100 + 99000*uk_101 + 48400*uk_102 + 2735095*uk_103 + 2759625*uk_104 + 1349150*uk_105 + 2784375*uk_106 + 1361250*uk_107 + 665500*uk_108 + 941192*uk_109 + 4970462*uk_11 + 1056440*uk_110 + 76832*uk_111 + 2141692*uk_112 + 2160900*uk_113 + 1056440*uk_114 + 1185800*uk_115 + 86240*uk_116 + 2403940*uk_117 + 2425500*uk_118 + 1185800*uk_119 + 5579090*uk_12 + 6272*uk_120 + 174832*uk_121 + 176400*uk_122 + 86240*uk_123 + 4873442*uk_124 + 4917150*uk_125 + 2403940*uk_126 + 4961250*uk_127 + 2425500*uk_128 + 1185800*uk_129 + 405752*uk_13 + 1331000*uk_130 + 96800*uk_131 + 2698300*uk_132 + 2722500*uk_133 + 1331000*uk_134 + 7040*uk_135 + 196240*uk_136 + 198000*uk_137 + 96800*uk_138 + 5470190*uk_139 + 11310337*uk_14 + 5519250*uk_140 + 2698300*uk_141 + 5568750*uk_142 + 2722500*uk_143 + 1331000*uk_144 + 512*uk_145 + 14272*uk_146 + 14400*uk_147 + 7040*uk_148 + 397832*uk_149 + 11411775*uk_15 + 401400*uk_150 + 196240*uk_151 + 405000*uk_152 + 198000*uk_153 + 96800*uk_154 + 11089567*uk_155 + 11189025*uk_156 + 5470190*uk_157 + 11289375*uk_158 + 5519250*uk_159 + 5579090*uk_16 + 2698300*uk_160 + 11390625*uk_161 + 5568750*uk_162 + 2722500*uk_163 + 1331000*uk_164 + 3025*uk_17 + 5390*uk_18 + 6050*uk_19 + 55*uk_2 + 440*uk_20 + 12265*uk_21 + 12375*uk_22 + 6050*uk_23 + 9604*uk_24 + 10780*uk_25 + 784*uk_26 + 21854*uk_27 + 22050*uk_28 + 10780*uk_29 + 98*uk_3 + 12100*uk_30 + 880*uk_31 + 24530*uk_32 + 24750*uk_33 + 12100*uk_34 + 64*uk_35 + 1784*uk_36 + 1800*uk_37 + 880*uk_38 + 49729*uk_39 + 110*uk_4 + 50175*uk_40 + 24530*uk_41 + 50625*uk_42 + 24750*uk_43 + 12100*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 252096862178*uk_47 + 282965865710*uk_48 + 20579335688*uk_49 + 8*uk_5 + 573648982303*uk_50 + 578793816225*uk_51 + 282965865710*uk_52 + 153424975*uk_53 + 273375410*uk_54 + 306849950*uk_55 + 22316360*uk_56 + 622068535*uk_57 + 627647625*uk_58 + 306849950*uk_59 + 223*uk_6 + 487105276*uk_60 + 546750820*uk_61 + 39763696*uk_62 + 1108413026*uk_63 + 1118353950*uk_64 + 546750820*uk_65 + 613699900*uk_66 + 44632720*uk_67 + 1244137070*uk_68 + 1255295250*uk_69 + 225*uk_7 + 613699900*uk_70 + 3246016*uk_71 + 90482696*uk_72 + 91294200*uk_73 + 44632720*uk_74 + 2522205151*uk_75 + 2544825825*uk_76 + 1244137070*uk_77 + 2567649375*uk_78 + 1255295250*uk_79 + 110*uk_8 + 613699900*uk_80 + 166375*uk_81 + 296450*uk_82 + 332750*uk_83 + 24200*uk_84 + 674575*uk_85 + 680625*uk_86 + 332750*uk_87 + 528220*uk_88 + 592900*uk_89 + 2572416961*uk_9 + 43120*uk_90 + 1201970*uk_91 + 1212750*uk_92 + 592900*uk_93 + 665500*uk_94 + 48400*uk_95 + 1349150*uk_96 + 1361250*uk_97 + 665500*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 99000*uk_100 + 99000*uk_101 + 43120*uk_102 + 2784375*uk_103 + 2784375*uk_104 + 1212750*uk_105 + 2784375*uk_106 + 1212750*uk_107 + 528220*uk_108 + 830584*uk_109 + 4767586*uk_11 + 865928*uk_110 + 70688*uk_111 + 1988100*uk_112 + 1988100*uk_113 + 865928*uk_114 + 902776*uk_115 + 73696*uk_116 + 2072700*uk_117 + 2072700*uk_118 + 902776*uk_119 + 4970462*uk_12 + 6016*uk_120 + 169200*uk_121 + 169200*uk_122 + 73696*uk_123 + 4758750*uk_124 + 4758750*uk_125 + 2072700*uk_126 + 4758750*uk_127 + 2072700*uk_128 + 902776*uk_129 + 405752*uk_13 + 941192*uk_130 + 76832*uk_131 + 2160900*uk_132 + 2160900*uk_133 + 941192*uk_134 + 6272*uk_135 + 176400*uk_136 + 176400*uk_137 + 76832*uk_138 + 4961250*uk_139 + 11411775*uk_14 + 4961250*uk_140 + 2160900*uk_141 + 4961250*uk_142 + 2160900*uk_143 + 941192*uk_144 + 512*uk_145 + 14400*uk_146 + 14400*uk_147 + 6272*uk_148 + 405000*uk_149 + 11411775*uk_15 + 405000*uk_150 + 176400*uk_151 + 405000*uk_152 + 176400*uk_153 + 76832*uk_154 + 11390625*uk_155 + 11390625*uk_156 + 4961250*uk_157 + 11390625*uk_158 + 4961250*uk_159 + 4970462*uk_16 + 2160900*uk_160 + 11390625*uk_161 + 4961250*uk_162 + 2160900*uk_163 + 941192*uk_164 + 3025*uk_17 + 5170*uk_18 + 5390*uk_19 + 55*uk_2 + 440*uk_20 + 12375*uk_21 + 12375*uk_22 + 5390*uk_23 + 8836*uk_24 + 9212*uk_25 + 752*uk_26 + 21150*uk_27 + 21150*uk_28 + 9212*uk_29 + 94*uk_3 + 9604*uk_30 + 784*uk_31 + 22050*uk_32 + 22050*uk_33 + 9604*uk_34 + 64*uk_35 + 1800*uk_36 + 1800*uk_37 + 784*uk_38 + 50625*uk_39 + 98*uk_4 + 50625*uk_40 + 22050*uk_41 + 50625*uk_42 + 22050*uk_43 + 9604*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 241807194334*uk_47 + 252096862178*uk_48 + 20579335688*uk_49 + 8*uk_5 + 578793816225*uk_50 + 578793816225*uk_51 + 252096862178*uk_52 + 153424975*uk_53 + 262217230*uk_54 + 273375410*uk_55 + 22316360*uk_56 + 627647625*uk_57 + 627647625*uk_58 + 273375410*uk_59 + 225*uk_6 + 448153084*uk_60 + 467223428*uk_61 + 38140688*uk_62 + 1072706850*uk_63 + 1072706850*uk_64 + 467223428*uk_65 + 487105276*uk_66 + 39763696*uk_67 + 1118353950*uk_68 + 1118353950*uk_69 + 225*uk_7 + 487105276*uk_70 + 3246016*uk_71 + 91294200*uk_72 + 91294200*uk_73 + 39763696*uk_74 + 2567649375*uk_75 + 2567649375*uk_76 + 1118353950*uk_77 + 2567649375*uk_78 + 1118353950*uk_79 + 98*uk_8 + 487105276*uk_80 + 166375*uk_81 + 284350*uk_82 + 296450*uk_83 + 24200*uk_84 + 680625*uk_85 + 680625*uk_86 + 296450*uk_87 + 485980*uk_88 + 506660*uk_89 + 2572416961*uk_9 + 41360*uk_90 + 1163250*uk_91 + 1163250*uk_92 + 506660*uk_93 + 528220*uk_94 + 43120*uk_95 + 1212750*uk_96 + 1212750*uk_97 + 528220*uk_98 + 3520*uk_99, + uk_0 + 50719*uk_1 + 2789545*uk_10 + 99880*uk_100 + 99000*uk_101 + 41360*uk_102 + 2834095*uk_103 + 2809125*uk_104 + 1173590*uk_105 + 2784375*uk_106 + 1163250*uk_107 + 485980*uk_108 + 941192*uk_109 + 4970462*uk_11 + 902776*uk_110 + 76832*uk_111 + 2180108*uk_112 + 2160900*uk_113 + 902776*uk_114 + 865928*uk_115 + 73696*uk_116 + 2091124*uk_117 + 2072700*uk_118 + 865928*uk_119 + 4767586*uk_12 + 6272*uk_120 + 177968*uk_121 + 176400*uk_122 + 73696*uk_123 + 5049842*uk_124 + 5005350*uk_125 + 2091124*uk_126 + 4961250*uk_127 + 2072700*uk_128 + 865928*uk_129 + 405752*uk_13 + 830584*uk_130 + 70688*uk_131 + 2005772*uk_132 + 1988100*uk_133 + 830584*uk_134 + 6016*uk_135 + 170704*uk_136 + 169200*uk_137 + 70688*uk_138 + 4843726*uk_139 + 11513213*uk_14 + 4801050*uk_140 + 2005772*uk_141 + 4758750*uk_142 + 1988100*uk_143 + 830584*uk_144 + 512*uk_145 + 14528*uk_146 + 14400*uk_147 + 6016*uk_148 + 412232*uk_149 + 11411775*uk_15 + 408600*uk_150 + 170704*uk_151 + 405000*uk_152 + 169200*uk_153 + 70688*uk_154 + 11697083*uk_155 + 11594025*uk_156 + 4843726*uk_157 + 11491875*uk_158 + 4801050*uk_159 + 4767586*uk_16 + 2005772*uk_160 + 11390625*uk_161 + 4758750*uk_162 + 1988100*uk_163 + 830584*uk_164 + 3025*uk_17 + 5390*uk_18 + 5170*uk_19 + 55*uk_2 + 440*uk_20 + 12485*uk_21 + 12375*uk_22 + 5170*uk_23 + 9604*uk_24 + 9212*uk_25 + 784*uk_26 + 22246*uk_27 + 22050*uk_28 + 9212*uk_29 + 98*uk_3 + 8836*uk_30 + 752*uk_31 + 21338*uk_32 + 21150*uk_33 + 8836*uk_34 + 64*uk_35 + 1816*uk_36 + 1800*uk_37 + 752*uk_38 + 51529*uk_39 + 94*uk_4 + 51075*uk_40 + 21338*uk_41 + 50625*uk_42 + 21150*uk_43 + 8836*uk_44 + 130470415844959*uk_45 + 141482932855*uk_46 + 252096862178*uk_47 + 241807194334*uk_48 + 20579335688*uk_49 + 8*uk_5 + 583938650147*uk_50 + 578793816225*uk_51 + 241807194334*uk_52 + 153424975*uk_53 + 273375410*uk_54 + 262217230*uk_55 + 22316360*uk_56 + 633226715*uk_57 + 627647625*uk_58 + 262217230*uk_59 + 227*uk_6 + 487105276*uk_60 + 467223428*uk_61 + 39763696*uk_62 + 1128294874*uk_63 + 1118353950*uk_64 + 467223428*uk_65 + 448153084*uk_66 + 38140688*uk_67 + 1082242022*uk_68 + 1072706850*uk_69 + 225*uk_7 + 448153084*uk_70 + 3246016*uk_71 + 92105704*uk_72 + 91294200*uk_73 + 38140688*uk_74 + 2613499351*uk_75 + 2590472925*uk_76 + 1082242022*uk_77 + 2567649375*uk_78 + 1072706850*uk_79 + 94*uk_8 + 448153084*uk_80 + 166375*uk_81 + 296450*uk_82 + 284350*uk_83 + 24200*uk_84 + 686675*uk_85 + 680625*uk_86 + 284350*uk_87 + 528220*uk_88 + 506660*uk_89 + 2572416961*uk_9 + 43120*uk_90 + 1223530*uk_91 + 1212750*uk_92 + 506660*uk_93 + 485980*uk_94 + 41360*uk_95 + 1173590*uk_96 + 1163250*uk_97 + 485980*uk_98 + 3520*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 396900*uk_100 + 1367100*uk_101 + 250047*uk_103 + 861273*uk_104 + 2966607*uk_106 + 64000*uk_109 + 1894120*uk_11 + 27200*uk_110 + 160000*uk_111 + 100800*uk_112 + 347200*uk_113 + 11560*uk_115 + 68000*uk_116 + 42840*uk_117 + 147560*uk_118 + 805001*uk_12 + 400000*uk_120 + 252000*uk_121 + 868000*uk_122 + 158760*uk_124 + 546840*uk_125 + 1883560*uk_127 + 4735300*uk_13 + 4913*uk_130 + 28900*uk_131 + 18207*uk_132 + 62713*uk_133 + 170000*uk_135 + 107100*uk_136 + 368900*uk_137 + 67473*uk_139 + 2983239*uk_14 + 232407*uk_140 + 800513*uk_142 + 1000000*uk_145 + 630000*uk_146 + 2170000*uk_147 + 396900*uk_149 + 10275601*uk_15 + 1367100*uk_150 + 4708900*uk_152 + 250047*uk_155 + 861273*uk_156 + 2966607*uk_158 + 10218313*uk_161 + 3969*uk_17 + 2520*uk_18 + 1071*uk_19 + 63*uk_2 + 6300*uk_20 + 3969*uk_21 + 13671*uk_22 + 1600*uk_24 + 680*uk_25 + 4000*uk_26 + 2520*uk_27 + 8680*uk_28 + 40*uk_3 + 289*uk_30 + 1700*uk_31 + 1071*uk_32 + 3689*uk_33 + 10000*uk_35 + 6300*uk_36 + 21700*uk_37 + 3969*uk_39 + 17*uk_4 + 13671*uk_40 + 47089*uk_42 + 106179944855977*uk_45 + 141265316367*uk_46 + 89692264360*uk_47 + 38119212353*uk_48 + 224230660900*uk_49 + 100*uk_5 + 141265316367*uk_50 + 486580534153*uk_51 + 187944057*uk_53 + 119329560*uk_54 + 50715063*uk_55 + 298323900*uk_56 + 187944057*uk_57 + 647362863*uk_58 + 63*uk_6 + 75764800*uk_60 + 32200040*uk_61 + 189412000*uk_62 + 119329560*uk_63 + 411024040*uk_64 + 13685017*uk_66 + 80500100*uk_67 + 50715063*uk_68 + 174685217*uk_69 + 217*uk_7 + 473530000*uk_71 + 298323900*uk_72 + 1027560100*uk_73 + 187944057*uk_75 + 647362863*uk_76 + 2229805417*uk_78 + 250047*uk_81 + 158760*uk_82 + 67473*uk_83 + 396900*uk_84 + 250047*uk_85 + 861273*uk_86 + 100800*uk_88 + 42840*uk_89 + 2242306609*uk_9 + 252000*uk_90 + 158760*uk_91 + 546840*uk_92 + 18207*uk_94 + 107100*uk_95 + 67473*uk_96 + 232407*uk_97 + 630000*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 376740*uk_100 + 1257732*uk_101 + 231840*uk_102 + 266175*uk_103 + 888615*uk_104 + 163800*uk_105 + 2966607*uk_106 + 546840*uk_107 + 100800*uk_108 + 35937*uk_109 + 1562649*uk_11 + 43560*uk_110 + 100188*uk_111 + 70785*uk_112 + 236313*uk_113 + 43560*uk_114 + 52800*uk_115 + 121440*uk_116 + 85800*uk_117 + 286440*uk_118 + 52800*uk_119 + 1894120*uk_12 + 279312*uk_120 + 197340*uk_121 + 658812*uk_122 + 121440*uk_123 + 139425*uk_124 + 465465*uk_125 + 85800*uk_126 + 1553937*uk_127 + 286440*uk_128 + 52800*uk_129 + 4356476*uk_13 + 64000*uk_130 + 147200*uk_131 + 104000*uk_132 + 347200*uk_133 + 64000*uk_134 + 338560*uk_135 + 239200*uk_136 + 798560*uk_137 + 147200*uk_138 + 169000*uk_139 + 3077945*uk_14 + 564200*uk_140 + 104000*uk_141 + 1883560*uk_142 + 347200*uk_143 + 64000*uk_144 + 778688*uk_145 + 550160*uk_146 + 1836688*uk_147 + 338560*uk_148 + 388700*uk_149 + 10275601*uk_15 + 1297660*uk_150 + 239200*uk_151 + 4332188*uk_152 + 798560*uk_153 + 147200*uk_154 + 274625*uk_155 + 916825*uk_156 + 169000*uk_157 + 3060785*uk_158 + 564200*uk_159 + 1894120*uk_16 + 104000*uk_160 + 10218313*uk_161 + 1883560*uk_162 + 347200*uk_163 + 64000*uk_164 + 3969*uk_17 + 2079*uk_18 + 2520*uk_19 + 63*uk_2 + 5796*uk_20 + 4095*uk_21 + 13671*uk_22 + 2520*uk_23 + 1089*uk_24 + 1320*uk_25 + 3036*uk_26 + 2145*uk_27 + 7161*uk_28 + 1320*uk_29 + 33*uk_3 + 1600*uk_30 + 3680*uk_31 + 2600*uk_32 + 8680*uk_33 + 1600*uk_34 + 8464*uk_35 + 5980*uk_36 + 19964*uk_37 + 3680*uk_38 + 4225*uk_39 + 40*uk_4 + 14105*uk_40 + 2600*uk_41 + 47089*uk_42 + 8680*uk_43 + 1600*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 73996118097*uk_47 + 89692264360*uk_48 + 206292208028*uk_49 + 92*uk_5 + 145749929585*uk_50 + 486580534153*uk_51 + 89692264360*uk_52 + 187944057*uk_53 + 98446887*uk_54 + 119329560*uk_55 + 274457988*uk_56 + 193910535*uk_57 + 647362863*uk_58 + 119329560*uk_59 + 65*uk_6 + 51567417*uk_60 + 62505960*uk_61 + 143763708*uk_62 + 101572185*uk_63 + 339094833*uk_64 + 62505960*uk_65 + 75764800*uk_66 + 174259040*uk_67 + 123117800*uk_68 + 411024040*uk_69 + 217*uk_7 + 75764800*uk_70 + 400795792*uk_71 + 283170940*uk_72 + 945355292*uk_73 + 174259040*uk_74 + 200066425*uk_75 + 667914065*uk_76 + 123117800*uk_77 + 2229805417*uk_78 + 411024040*uk_79 + 40*uk_8 + 75764800*uk_80 + 250047*uk_81 + 130977*uk_82 + 158760*uk_83 + 365148*uk_84 + 257985*uk_85 + 861273*uk_86 + 158760*uk_87 + 68607*uk_88 + 83160*uk_89 + 2242306609*uk_9 + 191268*uk_90 + 135135*uk_91 + 451143*uk_92 + 83160*uk_93 + 100800*uk_94 + 231840*uk_95 + 163800*uk_96 + 546840*uk_97 + 100800*uk_98 + 533232*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 371448*uk_100 + 1203048*uk_101 + 182952*uk_102 + 282807*uk_103 + 915957*uk_104 + 139293*uk_105 + 2966607*uk_106 + 451143*uk_107 + 68607*uk_108 + 132651*uk_109 + 2415003*uk_11 + 85833*uk_110 + 228888*uk_111 + 174267*uk_112 + 564417*uk_113 + 85833*uk_114 + 55539*uk_115 + 148104*uk_116 + 112761*uk_117 + 365211*uk_118 + 55539*uk_119 + 1562649*uk_12 + 394944*uk_120 + 300696*uk_121 + 973896*uk_122 + 148104*uk_123 + 228939*uk_124 + 741489*uk_125 + 112761*uk_126 + 2401539*uk_127 + 365211*uk_128 + 55539*uk_129 + 4167064*uk_13 + 35937*uk_130 + 95832*uk_131 + 72963*uk_132 + 236313*uk_133 + 35937*uk_134 + 255552*uk_135 + 194568*uk_136 + 630168*uk_137 + 95832*uk_138 + 148137*uk_139 + 3172651*uk_14 + 479787*uk_140 + 72963*uk_141 + 1553937*uk_142 + 236313*uk_143 + 35937*uk_144 + 681472*uk_145 + 518848*uk_146 + 1680448*uk_147 + 255552*uk_148 + 395032*uk_149 + 10275601*uk_15 + 1279432*uk_150 + 194568*uk_151 + 4143832*uk_152 + 630168*uk_153 + 95832*uk_154 + 300763*uk_155 + 974113*uk_156 + 148137*uk_157 + 3154963*uk_158 + 479787*uk_159 + 1562649*uk_16 + 72963*uk_160 + 10218313*uk_161 + 1553937*uk_162 + 236313*uk_163 + 35937*uk_164 + 3969*uk_17 + 3213*uk_18 + 2079*uk_19 + 63*uk_2 + 5544*uk_20 + 4221*uk_21 + 13671*uk_22 + 2079*uk_23 + 2601*uk_24 + 1683*uk_25 + 4488*uk_26 + 3417*uk_27 + 11067*uk_28 + 1683*uk_29 + 51*uk_3 + 1089*uk_30 + 2904*uk_31 + 2211*uk_32 + 7161*uk_33 + 1089*uk_34 + 7744*uk_35 + 5896*uk_36 + 19096*uk_37 + 2904*uk_38 + 4489*uk_39 + 33*uk_4 + 14539*uk_40 + 2211*uk_41 + 47089*uk_42 + 7161*uk_43 + 1089*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 114357637059*uk_47 + 73996118097*uk_48 + 197322981592*uk_49 + 88*uk_5 + 150234542803*uk_50 + 486580534153*uk_51 + 73996118097*uk_52 + 187944057*uk_53 + 152145189*uk_54 + 98446887*uk_55 + 262525032*uk_56 + 199877013*uk_57 + 647362863*uk_58 + 98446887*uk_59 + 67*uk_6 + 123165153*uk_60 + 79695099*uk_61 + 212520264*uk_62 + 161805201*uk_63 + 524055651*uk_64 + 79695099*uk_65 + 51567417*uk_66 + 137513112*uk_67 + 104697483*uk_68 + 339094833*uk_69 + 217*uk_7 + 51567417*uk_70 + 366701632*uk_71 + 279193288*uk_72 + 904252888*uk_73 + 137513112*uk_74 + 212567617*uk_75 + 688465267*uk_76 + 104697483*uk_77 + 2229805417*uk_78 + 339094833*uk_79 + 33*uk_8 + 51567417*uk_80 + 250047*uk_81 + 202419*uk_82 + 130977*uk_83 + 349272*uk_84 + 265923*uk_85 + 861273*uk_86 + 130977*uk_87 + 163863*uk_88 + 106029*uk_89 + 2242306609*uk_9 + 282744*uk_90 + 215271*uk_91 + 697221*uk_92 + 106029*uk_93 + 68607*uk_94 + 182952*uk_95 + 139293*uk_96 + 451143*uk_97 + 68607*uk_98 + 487872*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 347760*uk_100 + 1093680*uk_101 + 257040*uk_102 + 299943*uk_103 + 943299*uk_104 + 221697*uk_105 + 2966607*uk_106 + 697221*uk_107 + 163863*uk_108 + 6859*uk_109 + 899707*uk_11 + 18411*uk_110 + 28880*uk_111 + 24909*uk_112 + 78337*uk_113 + 18411*uk_114 + 49419*uk_115 + 77520*uk_116 + 66861*uk_117 + 210273*uk_118 + 49419*uk_119 + 2415003*uk_12 + 121600*uk_120 + 104880*uk_121 + 329840*uk_122 + 77520*uk_123 + 90459*uk_124 + 284487*uk_125 + 66861*uk_126 + 894691*uk_127 + 210273*uk_128 + 49419*uk_129 + 3788240*uk_13 + 132651*uk_130 + 208080*uk_131 + 179469*uk_132 + 564417*uk_133 + 132651*uk_134 + 326400*uk_135 + 281520*uk_136 + 885360*uk_137 + 208080*uk_138 + 242811*uk_139 + 3267357*uk_14 + 763623*uk_140 + 179469*uk_141 + 2401539*uk_142 + 564417*uk_143 + 132651*uk_144 + 512000*uk_145 + 441600*uk_146 + 1388800*uk_147 + 326400*uk_148 + 380880*uk_149 + 10275601*uk_15 + 1197840*uk_150 + 281520*uk_151 + 3767120*uk_152 + 885360*uk_153 + 208080*uk_154 + 328509*uk_155 + 1033137*uk_156 + 242811*uk_157 + 3249141*uk_158 + 763623*uk_159 + 2415003*uk_16 + 179469*uk_160 + 10218313*uk_161 + 2401539*uk_162 + 564417*uk_163 + 132651*uk_164 + 3969*uk_17 + 1197*uk_18 + 3213*uk_19 + 63*uk_2 + 5040*uk_20 + 4347*uk_21 + 13671*uk_22 + 3213*uk_23 + 361*uk_24 + 969*uk_25 + 1520*uk_26 + 1311*uk_27 + 4123*uk_28 + 969*uk_29 + 19*uk_3 + 2601*uk_30 + 4080*uk_31 + 3519*uk_32 + 11067*uk_33 + 2601*uk_34 + 6400*uk_35 + 5520*uk_36 + 17360*uk_37 + 4080*uk_38 + 4761*uk_39 + 51*uk_4 + 14973*uk_40 + 3519*uk_41 + 47089*uk_42 + 11067*uk_43 + 2601*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 42603825571*uk_47 + 114357637059*uk_48 + 179384528720*uk_49 + 80*uk_5 + 154719156021*uk_50 + 486580534153*uk_51 + 114357637059*uk_52 + 187944057*uk_53 + 56681541*uk_54 + 152145189*uk_55 + 238659120*uk_56 + 205843491*uk_57 + 647362863*uk_58 + 152145189*uk_59 + 69*uk_6 + 17094433*uk_60 + 45885057*uk_61 + 71976560*uk_62 + 62079783*uk_63 + 195236419*uk_64 + 45885057*uk_65 + 123165153*uk_66 + 193200240*uk_67 + 166635207*uk_68 + 524055651*uk_69 + 217*uk_7 + 123165153*uk_70 + 303059200*uk_71 + 261388560*uk_72 + 822048080*uk_73 + 193200240*uk_74 + 225447633*uk_75 + 709016469*uk_76 + 166635207*uk_77 + 2229805417*uk_78 + 524055651*uk_79 + 51*uk_8 + 123165153*uk_80 + 250047*uk_81 + 75411*uk_82 + 202419*uk_83 + 317520*uk_84 + 273861*uk_85 + 861273*uk_86 + 202419*uk_87 + 22743*uk_88 + 61047*uk_89 + 2242306609*uk_9 + 95760*uk_90 + 82593*uk_91 + 259749*uk_92 + 61047*uk_93 + 163863*uk_94 + 257040*uk_95 + 221697*uk_96 + 697221*uk_97 + 163863*uk_98 + 403200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 357840*uk_100 + 1093680*uk_101 + 95760*uk_102 + 317583*uk_103 + 970641*uk_104 + 84987*uk_105 + 2966607*uk_106 + 259749*uk_107 + 22743*uk_108 + 300763*uk_109 + 3172651*uk_11 + 85291*uk_110 + 359120*uk_111 + 318719*uk_112 + 974113*uk_113 + 85291*uk_114 + 24187*uk_115 + 101840*uk_116 + 90383*uk_117 + 276241*uk_118 + 24187*uk_119 + 899707*uk_12 + 428800*uk_120 + 380560*uk_121 + 1163120*uk_122 + 101840*uk_123 + 337747*uk_124 + 1032269*uk_125 + 90383*uk_126 + 3154963*uk_127 + 276241*uk_128 + 24187*uk_129 + 3788240*uk_13 + 6859*uk_130 + 28880*uk_131 + 25631*uk_132 + 78337*uk_133 + 6859*uk_134 + 121600*uk_135 + 107920*uk_136 + 329840*uk_137 + 28880*uk_138 + 95779*uk_139 + 3362063*uk_14 + 292733*uk_140 + 25631*uk_141 + 894691*uk_142 + 78337*uk_143 + 6859*uk_144 + 512000*uk_145 + 454400*uk_146 + 1388800*uk_147 + 121600*uk_148 + 403280*uk_149 + 10275601*uk_15 + 1232560*uk_150 + 107920*uk_151 + 3767120*uk_152 + 329840*uk_153 + 28880*uk_154 + 357911*uk_155 + 1093897*uk_156 + 95779*uk_157 + 3343319*uk_158 + 292733*uk_159 + 899707*uk_16 + 25631*uk_160 + 10218313*uk_161 + 894691*uk_162 + 78337*uk_163 + 6859*uk_164 + 3969*uk_17 + 4221*uk_18 + 1197*uk_19 + 63*uk_2 + 5040*uk_20 + 4473*uk_21 + 13671*uk_22 + 1197*uk_23 + 4489*uk_24 + 1273*uk_25 + 5360*uk_26 + 4757*uk_27 + 14539*uk_28 + 1273*uk_29 + 67*uk_3 + 361*uk_30 + 1520*uk_31 + 1349*uk_32 + 4123*uk_33 + 361*uk_34 + 6400*uk_35 + 5680*uk_36 + 17360*uk_37 + 1520*uk_38 + 5041*uk_39 + 19*uk_4 + 15407*uk_40 + 1349*uk_41 + 47089*uk_42 + 4123*uk_43 + 361*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 150234542803*uk_47 + 42603825571*uk_48 + 179384528720*uk_49 + 80*uk_5 + 159203769239*uk_50 + 486580534153*uk_51 + 42603825571*uk_52 + 187944057*uk_53 + 199877013*uk_54 + 56681541*uk_55 + 238659120*uk_56 + 211809969*uk_57 + 647362863*uk_58 + 56681541*uk_59 + 71*uk_6 + 212567617*uk_60 + 60280369*uk_61 + 253812080*uk_62 + 225258221*uk_63 + 688465267*uk_64 + 60280369*uk_65 + 17094433*uk_66 + 71976560*uk_67 + 63879197*uk_68 + 195236419*uk_69 + 217*uk_7 + 17094433*uk_70 + 303059200*uk_71 + 268965040*uk_72 + 822048080*uk_73 + 71976560*uk_74 + 238706473*uk_75 + 729567671*uk_76 + 63879197*uk_77 + 2229805417*uk_78 + 195236419*uk_79 + 19*uk_8 + 17094433*uk_80 + 250047*uk_81 + 265923*uk_82 + 75411*uk_83 + 317520*uk_84 + 281799*uk_85 + 861273*uk_86 + 75411*uk_87 + 282807*uk_88 + 80199*uk_89 + 2242306609*uk_9 + 337680*uk_90 + 299691*uk_91 + 915957*uk_92 + 80199*uk_93 + 22743*uk_94 + 95760*uk_95 + 84987*uk_96 + 259749*uk_97 + 22743*uk_98 + 403200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 331128*uk_100 + 984312*uk_101 + 303912*uk_102 + 335727*uk_103 + 997983*uk_104 + 308133*uk_105 + 2966607*uk_106 + 915957*uk_107 + 282807*uk_108 + 117649*uk_109 + 2320297*uk_11 + 160867*uk_110 + 172872*uk_111 + 175273*uk_112 + 521017*uk_113 + 160867*uk_114 + 219961*uk_115 + 236376*uk_116 + 239659*uk_117 + 712411*uk_118 + 219961*uk_119 + 3172651*uk_12 + 254016*uk_120 + 257544*uk_121 + 765576*uk_122 + 236376*uk_123 + 261121*uk_124 + 776209*uk_125 + 239659*uk_126 + 2307361*uk_127 + 712411*uk_128 + 219961*uk_129 + 3409416*uk_13 + 300763*uk_130 + 323208*uk_131 + 327697*uk_132 + 974113*uk_133 + 300763*uk_134 + 347328*uk_135 + 352152*uk_136 + 1046808*uk_137 + 323208*uk_138 + 357043*uk_139 + 3456769*uk_14 + 1061347*uk_140 + 327697*uk_141 + 3154963*uk_142 + 974113*uk_143 + 300763*uk_144 + 373248*uk_145 + 378432*uk_146 + 1124928*uk_147 + 347328*uk_148 + 383688*uk_149 + 10275601*uk_15 + 1140552*uk_150 + 352152*uk_151 + 3390408*uk_152 + 1046808*uk_153 + 323208*uk_154 + 389017*uk_155 + 1156393*uk_156 + 357043*uk_157 + 3437497*uk_158 + 1061347*uk_159 + 3172651*uk_16 + 327697*uk_160 + 10218313*uk_161 + 3154963*uk_162 + 974113*uk_163 + 300763*uk_164 + 3969*uk_17 + 3087*uk_18 + 4221*uk_19 + 63*uk_2 + 4536*uk_20 + 4599*uk_21 + 13671*uk_22 + 4221*uk_23 + 2401*uk_24 + 3283*uk_25 + 3528*uk_26 + 3577*uk_27 + 10633*uk_28 + 3283*uk_29 + 49*uk_3 + 4489*uk_30 + 4824*uk_31 + 4891*uk_32 + 14539*uk_33 + 4489*uk_34 + 5184*uk_35 + 5256*uk_36 + 15624*uk_37 + 4824*uk_38 + 5329*uk_39 + 67*uk_4 + 15841*uk_40 + 4891*uk_41 + 47089*uk_42 + 14539*uk_43 + 4489*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 109873023841*uk_47 + 150234542803*uk_48 + 161446075848*uk_49 + 72*uk_5 + 163688382457*uk_50 + 486580534153*uk_51 + 150234542803*uk_52 + 187944057*uk_53 + 146178711*uk_54 + 199877013*uk_55 + 214793208*uk_56 + 217776447*uk_57 + 647362863*uk_58 + 199877013*uk_59 + 73*uk_6 + 113694553*uk_60 + 155459899*uk_61 + 167061384*uk_62 + 169381681*uk_63 + 503504449*uk_64 + 155459899*uk_65 + 212567617*uk_66 + 228430872*uk_67 + 231603523*uk_68 + 688465267*uk_69 + 217*uk_7 + 212567617*uk_70 + 245477952*uk_71 + 248887368*uk_72 + 739843272*uk_73 + 228430872*uk_74 + 252344137*uk_75 + 750118873*uk_76 + 231603523*uk_77 + 2229805417*uk_78 + 688465267*uk_79 + 67*uk_8 + 212567617*uk_80 + 250047*uk_81 + 194481*uk_82 + 265923*uk_83 + 285768*uk_84 + 289737*uk_85 + 861273*uk_86 + 265923*uk_87 + 151263*uk_88 + 206829*uk_89 + 2242306609*uk_9 + 222264*uk_90 + 225351*uk_91 + 669879*uk_92 + 206829*uk_93 + 282807*uk_94 + 303912*uk_95 + 308133*uk_96 + 915957*uk_97 + 282807*uk_98 + 326592*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 321300*uk_100 + 929628*uk_101 + 209916*uk_102 + 354375*uk_103 + 1025325*uk_104 + 231525*uk_105 + 2966607*uk_106 + 669879*uk_107 + 151263*uk_108 + 21952*uk_109 + 1325884*uk_11 + 38416*uk_110 + 53312*uk_111 + 58800*uk_112 + 170128*uk_113 + 38416*uk_114 + 67228*uk_115 + 93296*uk_116 + 102900*uk_117 + 297724*uk_118 + 67228*uk_119 + 2320297*uk_12 + 129472*uk_120 + 142800*uk_121 + 413168*uk_122 + 93296*uk_123 + 157500*uk_124 + 455700*uk_125 + 102900*uk_126 + 1318492*uk_127 + 297724*uk_128 + 67228*uk_129 + 3220004*uk_13 + 117649*uk_130 + 163268*uk_131 + 180075*uk_132 + 521017*uk_133 + 117649*uk_134 + 226576*uk_135 + 249900*uk_136 + 723044*uk_137 + 163268*uk_138 + 275625*uk_139 + 3551475*uk_14 + 797475*uk_140 + 180075*uk_141 + 2307361*uk_142 + 521017*uk_143 + 117649*uk_144 + 314432*uk_145 + 346800*uk_146 + 1003408*uk_147 + 226576*uk_148 + 382500*uk_149 + 10275601*uk_15 + 1106700*uk_150 + 249900*uk_151 + 3202052*uk_152 + 723044*uk_153 + 163268*uk_154 + 421875*uk_155 + 1220625*uk_156 + 275625*uk_157 + 3531675*uk_158 + 797475*uk_159 + 2320297*uk_16 + 180075*uk_160 + 10218313*uk_161 + 2307361*uk_162 + 521017*uk_163 + 117649*uk_164 + 3969*uk_17 + 1764*uk_18 + 3087*uk_19 + 63*uk_2 + 4284*uk_20 + 4725*uk_21 + 13671*uk_22 + 3087*uk_23 + 784*uk_24 + 1372*uk_25 + 1904*uk_26 + 2100*uk_27 + 6076*uk_28 + 1372*uk_29 + 28*uk_3 + 2401*uk_30 + 3332*uk_31 + 3675*uk_32 + 10633*uk_33 + 2401*uk_34 + 4624*uk_35 + 5100*uk_36 + 14756*uk_37 + 3332*uk_38 + 5625*uk_39 + 49*uk_4 + 16275*uk_40 + 3675*uk_41 + 47089*uk_42 + 10633*uk_43 + 2401*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 62784585052*uk_47 + 109873023841*uk_48 + 152476849412*uk_49 + 68*uk_5 + 168172995675*uk_50 + 486580534153*uk_51 + 109873023841*uk_52 + 187944057*uk_53 + 83530692*uk_54 + 146178711*uk_55 + 202860252*uk_56 + 223742925*uk_57 + 647362863*uk_58 + 146178711*uk_59 + 75*uk_6 + 37124752*uk_60 + 64968316*uk_61 + 90160112*uk_62 + 99441300*uk_63 + 287716828*uk_64 + 64968316*uk_65 + 113694553*uk_66 + 157780196*uk_67 + 174022275*uk_68 + 503504449*uk_69 + 217*uk_7 + 113694553*uk_70 + 218960272*uk_71 + 241500300*uk_72 + 698740868*uk_73 + 157780196*uk_74 + 266360625*uk_75 + 770670075*uk_76 + 174022275*uk_77 + 2229805417*uk_78 + 503504449*uk_79 + 49*uk_8 + 113694553*uk_80 + 250047*uk_81 + 111132*uk_82 + 194481*uk_83 + 269892*uk_84 + 297675*uk_85 + 861273*uk_86 + 194481*uk_87 + 49392*uk_88 + 86436*uk_89 + 2242306609*uk_9 + 119952*uk_90 + 132300*uk_91 + 382788*uk_92 + 86436*uk_93 + 151263*uk_94 + 209916*uk_95 + 231525*uk_96 + 669879*uk_97 + 151263*uk_98 + 291312*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 329868*uk_100 + 929628*uk_101 + 119952*uk_102 + 373527*uk_103 + 1052667*uk_104 + 135828*uk_105 + 2966607*uk_106 + 382788*uk_107 + 49392*uk_108 + 421875*uk_109 + 3551475*uk_11 + 157500*uk_110 + 382500*uk_111 + 433125*uk_112 + 1220625*uk_113 + 157500*uk_114 + 58800*uk_115 + 142800*uk_116 + 161700*uk_117 + 455700*uk_118 + 58800*uk_119 + 1325884*uk_12 + 346800*uk_120 + 392700*uk_121 + 1106700*uk_122 + 142800*uk_123 + 444675*uk_124 + 1253175*uk_125 + 161700*uk_126 + 3531675*uk_127 + 455700*uk_128 + 58800*uk_129 + 3220004*uk_13 + 21952*uk_130 + 53312*uk_131 + 60368*uk_132 + 170128*uk_133 + 21952*uk_134 + 129472*uk_135 + 146608*uk_136 + 413168*uk_137 + 53312*uk_138 + 166012*uk_139 + 3646181*uk_14 + 467852*uk_140 + 60368*uk_141 + 1318492*uk_142 + 170128*uk_143 + 21952*uk_144 + 314432*uk_145 + 356048*uk_146 + 1003408*uk_147 + 129472*uk_148 + 403172*uk_149 + 10275601*uk_15 + 1136212*uk_150 + 146608*uk_151 + 3202052*uk_152 + 413168*uk_153 + 53312*uk_154 + 456533*uk_155 + 1286593*uk_156 + 166012*uk_157 + 3625853*uk_158 + 467852*uk_159 + 1325884*uk_16 + 60368*uk_160 + 10218313*uk_161 + 1318492*uk_162 + 170128*uk_163 + 21952*uk_164 + 3969*uk_17 + 4725*uk_18 + 1764*uk_19 + 63*uk_2 + 4284*uk_20 + 4851*uk_21 + 13671*uk_22 + 1764*uk_23 + 5625*uk_24 + 2100*uk_25 + 5100*uk_26 + 5775*uk_27 + 16275*uk_28 + 2100*uk_29 + 75*uk_3 + 784*uk_30 + 1904*uk_31 + 2156*uk_32 + 6076*uk_33 + 784*uk_34 + 4624*uk_35 + 5236*uk_36 + 14756*uk_37 + 1904*uk_38 + 5929*uk_39 + 28*uk_4 + 16709*uk_40 + 2156*uk_41 + 47089*uk_42 + 6076*uk_43 + 784*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 168172995675*uk_47 + 62784585052*uk_48 + 152476849412*uk_49 + 68*uk_5 + 172657608893*uk_50 + 486580534153*uk_51 + 62784585052*uk_52 + 187944057*uk_53 + 223742925*uk_54 + 83530692*uk_55 + 202860252*uk_56 + 229709403*uk_57 + 647362863*uk_58 + 83530692*uk_59 + 77*uk_6 + 266360625*uk_60 + 99441300*uk_61 + 241500300*uk_62 + 273463575*uk_63 + 770670075*uk_64 + 99441300*uk_65 + 37124752*uk_66 + 90160112*uk_67 + 102093068*uk_68 + 287716828*uk_69 + 217*uk_7 + 37124752*uk_70 + 218960272*uk_71 + 247940308*uk_72 + 698740868*uk_73 + 90160112*uk_74 + 280755937*uk_75 + 791221277*uk_76 + 102093068*uk_77 + 2229805417*uk_78 + 287716828*uk_79 + 28*uk_8 + 37124752*uk_80 + 250047*uk_81 + 297675*uk_82 + 111132*uk_83 + 269892*uk_84 + 305613*uk_85 + 861273*uk_86 + 111132*uk_87 + 354375*uk_88 + 132300*uk_89 + 2242306609*uk_9 + 321300*uk_90 + 363825*uk_91 + 1025325*uk_92 + 132300*uk_93 + 49392*uk_94 + 119952*uk_95 + 135828*uk_96 + 382788*uk_97 + 49392*uk_98 + 291312*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 298620*uk_100 + 820260*uk_101 + 283500*uk_102 + 393183*uk_103 + 1080009*uk_104 + 373275*uk_105 + 2966607*uk_106 + 1025325*uk_107 + 354375*uk_108 + 32768*uk_109 + 1515296*uk_11 + 76800*uk_110 + 61440*uk_111 + 80896*uk_112 + 222208*uk_113 + 76800*uk_114 + 180000*uk_115 + 144000*uk_116 + 189600*uk_117 + 520800*uk_118 + 180000*uk_119 + 3551475*uk_12 + 115200*uk_120 + 151680*uk_121 + 416640*uk_122 + 144000*uk_123 + 199712*uk_124 + 548576*uk_125 + 189600*uk_126 + 1506848*uk_127 + 520800*uk_128 + 180000*uk_129 + 2841180*uk_13 + 421875*uk_130 + 337500*uk_131 + 444375*uk_132 + 1220625*uk_133 + 421875*uk_134 + 270000*uk_135 + 355500*uk_136 + 976500*uk_137 + 337500*uk_138 + 468075*uk_139 + 3740887*uk_14 + 1285725*uk_140 + 444375*uk_141 + 3531675*uk_142 + 1220625*uk_143 + 421875*uk_144 + 216000*uk_145 + 284400*uk_146 + 781200*uk_147 + 270000*uk_148 + 374460*uk_149 + 10275601*uk_15 + 1028580*uk_150 + 355500*uk_151 + 2825340*uk_152 + 976500*uk_153 + 337500*uk_154 + 493039*uk_155 + 1354297*uk_156 + 468075*uk_157 + 3720031*uk_158 + 1285725*uk_159 + 3551475*uk_16 + 444375*uk_160 + 10218313*uk_161 + 3531675*uk_162 + 1220625*uk_163 + 421875*uk_164 + 3969*uk_17 + 2016*uk_18 + 4725*uk_19 + 63*uk_2 + 3780*uk_20 + 4977*uk_21 + 13671*uk_22 + 4725*uk_23 + 1024*uk_24 + 2400*uk_25 + 1920*uk_26 + 2528*uk_27 + 6944*uk_28 + 2400*uk_29 + 32*uk_3 + 5625*uk_30 + 4500*uk_31 + 5925*uk_32 + 16275*uk_33 + 5625*uk_34 + 3600*uk_35 + 4740*uk_36 + 13020*uk_37 + 4500*uk_38 + 6241*uk_39 + 75*uk_4 + 17143*uk_40 + 5925*uk_41 + 47089*uk_42 + 16275*uk_43 + 5625*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 71753811488*uk_47 + 168172995675*uk_48 + 134538396540*uk_49 + 60*uk_5 + 177142222111*uk_50 + 486580534153*uk_51 + 168172995675*uk_52 + 187944057*uk_53 + 95463648*uk_54 + 223742925*uk_55 + 178994340*uk_56 + 235675881*uk_57 + 647362863*uk_58 + 223742925*uk_59 + 79*uk_6 + 48489472*uk_60 + 113647200*uk_61 + 90917760*uk_62 + 119708384*uk_63 + 328819232*uk_64 + 113647200*uk_65 + 266360625*uk_66 + 213088500*uk_67 + 280566525*uk_68 + 770670075*uk_69 + 217*uk_7 + 266360625*uk_70 + 170470800*uk_71 + 224453220*uk_72 + 616536060*uk_73 + 213088500*uk_74 + 295530073*uk_75 + 811772479*uk_76 + 280566525*uk_77 + 2229805417*uk_78 + 770670075*uk_79 + 75*uk_8 + 266360625*uk_80 + 250047*uk_81 + 127008*uk_82 + 297675*uk_83 + 238140*uk_84 + 313551*uk_85 + 861273*uk_86 + 297675*uk_87 + 64512*uk_88 + 151200*uk_89 + 2242306609*uk_9 + 120960*uk_90 + 159264*uk_91 + 437472*uk_92 + 151200*uk_93 + 354375*uk_94 + 283500*uk_95 + 373275*uk_96 + 1025325*uk_97 + 354375*uk_98 + 226800*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 306180*uk_100 + 820260*uk_101 + 120960*uk_102 + 413343*uk_103 + 1107351*uk_104 + 163296*uk_105 + 2966607*uk_106 + 437472*uk_107 + 64512*uk_108 + 117649*uk_109 + 2320297*uk_11 + 76832*uk_110 + 144060*uk_111 + 194481*uk_112 + 521017*uk_113 + 76832*uk_114 + 50176*uk_115 + 94080*uk_116 + 127008*uk_117 + 340256*uk_118 + 50176*uk_119 + 1515296*uk_12 + 176400*uk_120 + 238140*uk_121 + 637980*uk_122 + 94080*uk_123 + 321489*uk_124 + 861273*uk_125 + 127008*uk_126 + 2307361*uk_127 + 340256*uk_128 + 50176*uk_129 + 2841180*uk_13 + 32768*uk_130 + 61440*uk_131 + 82944*uk_132 + 222208*uk_133 + 32768*uk_134 + 115200*uk_135 + 155520*uk_136 + 416640*uk_137 + 61440*uk_138 + 209952*uk_139 + 3835593*uk_14 + 562464*uk_140 + 82944*uk_141 + 1506848*uk_142 + 222208*uk_143 + 32768*uk_144 + 216000*uk_145 + 291600*uk_146 + 781200*uk_147 + 115200*uk_148 + 393660*uk_149 + 10275601*uk_15 + 1054620*uk_150 + 155520*uk_151 + 2825340*uk_152 + 416640*uk_153 + 61440*uk_154 + 531441*uk_155 + 1423737*uk_156 + 209952*uk_157 + 3814209*uk_158 + 562464*uk_159 + 1515296*uk_16 + 82944*uk_160 + 10218313*uk_161 + 1506848*uk_162 + 222208*uk_163 + 32768*uk_164 + 3969*uk_17 + 3087*uk_18 + 2016*uk_19 + 63*uk_2 + 3780*uk_20 + 5103*uk_21 + 13671*uk_22 + 2016*uk_23 + 2401*uk_24 + 1568*uk_25 + 2940*uk_26 + 3969*uk_27 + 10633*uk_28 + 1568*uk_29 + 49*uk_3 + 1024*uk_30 + 1920*uk_31 + 2592*uk_32 + 6944*uk_33 + 1024*uk_34 + 3600*uk_35 + 4860*uk_36 + 13020*uk_37 + 1920*uk_38 + 6561*uk_39 + 32*uk_4 + 17577*uk_40 + 2592*uk_41 + 47089*uk_42 + 6944*uk_43 + 1024*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 109873023841*uk_47 + 71753811488*uk_48 + 134538396540*uk_49 + 60*uk_5 + 181626835329*uk_50 + 486580534153*uk_51 + 71753811488*uk_52 + 187944057*uk_53 + 146178711*uk_54 + 95463648*uk_55 + 178994340*uk_56 + 241642359*uk_57 + 647362863*uk_58 + 95463648*uk_59 + 81*uk_6 + 113694553*uk_60 + 74249504*uk_61 + 139217820*uk_62 + 187944057*uk_63 + 503504449*uk_64 + 74249504*uk_65 + 48489472*uk_66 + 90917760*uk_67 + 122738976*uk_68 + 328819232*uk_69 + 217*uk_7 + 48489472*uk_70 + 170470800*uk_71 + 230135580*uk_72 + 616536060*uk_73 + 90917760*uk_74 + 310683033*uk_75 + 832323681*uk_76 + 122738976*uk_77 + 2229805417*uk_78 + 328819232*uk_79 + 32*uk_8 + 48489472*uk_80 + 250047*uk_81 + 194481*uk_82 + 127008*uk_83 + 238140*uk_84 + 321489*uk_85 + 861273*uk_86 + 127008*uk_87 + 151263*uk_88 + 98784*uk_89 + 2242306609*uk_9 + 185220*uk_90 + 250047*uk_91 + 669879*uk_92 + 98784*uk_93 + 64512*uk_94 + 120960*uk_95 + 163296*uk_96 + 437472*uk_97 + 64512*uk_98 + 226800*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 292824*uk_100 + 765576*uk_101 + 172872*uk_102 + 434007*uk_103 + 1134693*uk_104 + 256221*uk_105 + 2966607*uk_106 + 669879*uk_107 + 151263*uk_108 + 79507*uk_109 + 2036179*uk_11 + 90601*uk_110 + 103544*uk_111 + 153467*uk_112 + 401233*uk_113 + 90601*uk_114 + 103243*uk_115 + 117992*uk_116 + 174881*uk_117 + 457219*uk_118 + 103243*uk_119 + 2320297*uk_12 + 134848*uk_120 + 199864*uk_121 + 522536*uk_122 + 117992*uk_123 + 296227*uk_124 + 774473*uk_125 + 174881*uk_126 + 2024827*uk_127 + 457219*uk_128 + 103243*uk_129 + 2651768*uk_13 + 117649*uk_130 + 134456*uk_131 + 199283*uk_132 + 521017*uk_133 + 117649*uk_134 + 153664*uk_135 + 227752*uk_136 + 595448*uk_137 + 134456*uk_138 + 337561*uk_139 + 3930299*uk_14 + 882539*uk_140 + 199283*uk_141 + 2307361*uk_142 + 521017*uk_143 + 117649*uk_144 + 175616*uk_145 + 260288*uk_146 + 680512*uk_147 + 153664*uk_148 + 385784*uk_149 + 10275601*uk_15 + 1008616*uk_150 + 227752*uk_151 + 2636984*uk_152 + 595448*uk_153 + 134456*uk_154 + 571787*uk_155 + 1494913*uk_156 + 337561*uk_157 + 3908387*uk_158 + 882539*uk_159 + 2320297*uk_16 + 199283*uk_160 + 10218313*uk_161 + 2307361*uk_162 + 521017*uk_163 + 117649*uk_164 + 3969*uk_17 + 2709*uk_18 + 3087*uk_19 + 63*uk_2 + 3528*uk_20 + 5229*uk_21 + 13671*uk_22 + 3087*uk_23 + 1849*uk_24 + 2107*uk_25 + 2408*uk_26 + 3569*uk_27 + 9331*uk_28 + 2107*uk_29 + 43*uk_3 + 2401*uk_30 + 2744*uk_31 + 4067*uk_32 + 10633*uk_33 + 2401*uk_34 + 3136*uk_35 + 4648*uk_36 + 12152*uk_37 + 2744*uk_38 + 6889*uk_39 + 49*uk_4 + 18011*uk_40 + 4067*uk_41 + 47089*uk_42 + 10633*uk_43 + 2401*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 96419184187*uk_47 + 109873023841*uk_48 + 125569170104*uk_49 + 56*uk_5 + 186111448547*uk_50 + 486580534153*uk_51 + 109873023841*uk_52 + 187944057*uk_53 + 128279277*uk_54 + 146178711*uk_55 + 167061384*uk_56 + 247608837*uk_57 + 647362863*uk_58 + 146178711*uk_59 + 83*uk_6 + 87555697*uk_60 + 99772771*uk_61 + 114026024*uk_62 + 169002857*uk_63 + 441850843*uk_64 + 99772771*uk_65 + 113694553*uk_66 + 129936632*uk_67 + 192584651*uk_68 + 503504449*uk_69 + 217*uk_7 + 113694553*uk_70 + 148499008*uk_71 + 220096744*uk_72 + 575433656*uk_73 + 129936632*uk_74 + 326214817*uk_75 + 852874883*uk_76 + 192584651*uk_77 + 2229805417*uk_78 + 503504449*uk_79 + 49*uk_8 + 113694553*uk_80 + 250047*uk_81 + 170667*uk_82 + 194481*uk_83 + 222264*uk_84 + 329427*uk_85 + 861273*uk_86 + 194481*uk_87 + 116487*uk_88 + 132741*uk_89 + 2242306609*uk_9 + 151704*uk_90 + 224847*uk_91 + 587853*uk_92 + 132741*uk_93 + 151263*uk_94 + 172872*uk_95 + 256221*uk_96 + 669879*uk_97 + 151263*uk_98 + 197568*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 278460*uk_100 + 710892*uk_101 + 140868*uk_102 + 455175*uk_103 + 1162035*uk_104 + 230265*uk_105 + 2966607*uk_106 + 587853*uk_107 + 116487*uk_108 + 512*uk_109 + 378824*uk_11 + 2752*uk_110 + 3328*uk_111 + 5440*uk_112 + 13888*uk_113 + 2752*uk_114 + 14792*uk_115 + 17888*uk_116 + 29240*uk_117 + 74648*uk_118 + 14792*uk_119 + 2036179*uk_12 + 21632*uk_120 + 35360*uk_121 + 90272*uk_122 + 17888*uk_123 + 57800*uk_124 + 147560*uk_125 + 29240*uk_126 + 376712*uk_127 + 74648*uk_128 + 14792*uk_129 + 2462356*uk_13 + 79507*uk_130 + 96148*uk_131 + 157165*uk_132 + 401233*uk_133 + 79507*uk_134 + 116272*uk_135 + 190060*uk_136 + 485212*uk_137 + 96148*uk_138 + 310675*uk_139 + 4025005*uk_14 + 793135*uk_140 + 157165*uk_141 + 2024827*uk_142 + 401233*uk_143 + 79507*uk_144 + 140608*uk_145 + 229840*uk_146 + 586768*uk_147 + 116272*uk_148 + 375700*uk_149 + 10275601*uk_15 + 959140*uk_150 + 190060*uk_151 + 2448628*uk_152 + 485212*uk_153 + 96148*uk_154 + 614125*uk_155 + 1567825*uk_156 + 310675*uk_157 + 4002565*uk_158 + 793135*uk_159 + 2036179*uk_16 + 157165*uk_160 + 10218313*uk_161 + 2024827*uk_162 + 401233*uk_163 + 79507*uk_164 + 3969*uk_17 + 504*uk_18 + 2709*uk_19 + 63*uk_2 + 3276*uk_20 + 5355*uk_21 + 13671*uk_22 + 2709*uk_23 + 64*uk_24 + 344*uk_25 + 416*uk_26 + 680*uk_27 + 1736*uk_28 + 344*uk_29 + 8*uk_3 + 1849*uk_30 + 2236*uk_31 + 3655*uk_32 + 9331*uk_33 + 1849*uk_34 + 2704*uk_35 + 4420*uk_36 + 11284*uk_37 + 2236*uk_38 + 7225*uk_39 + 43*uk_4 + 18445*uk_40 + 3655*uk_41 + 47089*uk_42 + 9331*uk_43 + 1849*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 17938452872*uk_47 + 96419184187*uk_48 + 116599943668*uk_49 + 52*uk_5 + 190596061765*uk_50 + 486580534153*uk_51 + 96419184187*uk_52 + 187944057*uk_53 + 23865912*uk_54 + 128279277*uk_55 + 155128428*uk_56 + 253575315*uk_57 + 647362863*uk_58 + 128279277*uk_59 + 85*uk_6 + 3030592*uk_60 + 16289432*uk_61 + 19698848*uk_62 + 32200040*uk_63 + 82204808*uk_64 + 16289432*uk_65 + 87555697*uk_66 + 105881308*uk_67 + 173075215*uk_68 + 441850843*uk_69 + 217*uk_7 + 87555697*uk_70 + 128042512*uk_71 + 209300260*uk_72 + 534331252*uk_73 + 105881308*uk_74 + 342125425*uk_75 + 873426085*uk_76 + 173075215*uk_77 + 2229805417*uk_78 + 441850843*uk_79 + 43*uk_8 + 87555697*uk_80 + 250047*uk_81 + 31752*uk_82 + 170667*uk_83 + 206388*uk_84 + 337365*uk_85 + 861273*uk_86 + 170667*uk_87 + 4032*uk_88 + 21672*uk_89 + 2242306609*uk_9 + 26208*uk_90 + 42840*uk_91 + 109368*uk_92 + 21672*uk_93 + 116487*uk_94 + 140868*uk_95 + 230265*uk_96 + 587853*uk_97 + 116487*uk_98 + 170352*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 285012*uk_100 + 710892*uk_101 + 26208*uk_102 + 476847*uk_103 + 1189377*uk_104 + 43848*uk_105 + 2966607*uk_106 + 109368*uk_107 + 4032*uk_108 + 15625*uk_109 + 1183825*uk_11 + 5000*uk_110 + 32500*uk_111 + 54375*uk_112 + 135625*uk_113 + 5000*uk_114 + 1600*uk_115 + 10400*uk_116 + 17400*uk_117 + 43400*uk_118 + 1600*uk_119 + 378824*uk_12 + 67600*uk_120 + 113100*uk_121 + 282100*uk_122 + 10400*uk_123 + 189225*uk_124 + 471975*uk_125 + 17400*uk_126 + 1177225*uk_127 + 43400*uk_128 + 1600*uk_129 + 2462356*uk_13 + 512*uk_130 + 3328*uk_131 + 5568*uk_132 + 13888*uk_133 + 512*uk_134 + 21632*uk_135 + 36192*uk_136 + 90272*uk_137 + 3328*uk_138 + 60552*uk_139 + 4119711*uk_14 + 151032*uk_140 + 5568*uk_141 + 376712*uk_142 + 13888*uk_143 + 512*uk_144 + 140608*uk_145 + 235248*uk_146 + 586768*uk_147 + 21632*uk_148 + 393588*uk_149 + 10275601*uk_15 + 981708*uk_150 + 36192*uk_151 + 2448628*uk_152 + 90272*uk_153 + 3328*uk_154 + 658503*uk_155 + 1642473*uk_156 + 60552*uk_157 + 4096743*uk_158 + 151032*uk_159 + 378824*uk_16 + 5568*uk_160 + 10218313*uk_161 + 376712*uk_162 + 13888*uk_163 + 512*uk_164 + 3969*uk_17 + 1575*uk_18 + 504*uk_19 + 63*uk_2 + 3276*uk_20 + 5481*uk_21 + 13671*uk_22 + 504*uk_23 + 625*uk_24 + 200*uk_25 + 1300*uk_26 + 2175*uk_27 + 5425*uk_28 + 200*uk_29 + 25*uk_3 + 64*uk_30 + 416*uk_31 + 696*uk_32 + 1736*uk_33 + 64*uk_34 + 2704*uk_35 + 4524*uk_36 + 11284*uk_37 + 416*uk_38 + 7569*uk_39 + 8*uk_4 + 18879*uk_40 + 696*uk_41 + 47089*uk_42 + 1736*uk_43 + 64*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 56057665225*uk_47 + 17938452872*uk_48 + 116599943668*uk_49 + 52*uk_5 + 195080674983*uk_50 + 486580534153*uk_51 + 17938452872*uk_52 + 187944057*uk_53 + 74580975*uk_54 + 23865912*uk_55 + 155128428*uk_56 + 259541793*uk_57 + 647362863*uk_58 + 23865912*uk_59 + 87*uk_6 + 29595625*uk_60 + 9470600*uk_61 + 61558900*uk_62 + 102992775*uk_63 + 256890025*uk_64 + 9470600*uk_65 + 3030592*uk_66 + 19698848*uk_67 + 32957688*uk_68 + 82204808*uk_69 + 217*uk_7 + 3030592*uk_70 + 128042512*uk_71 + 214224972*uk_72 + 534331252*uk_73 + 19698848*uk_74 + 358414857*uk_75 + 893977287*uk_76 + 32957688*uk_77 + 2229805417*uk_78 + 82204808*uk_79 + 8*uk_8 + 3030592*uk_80 + 250047*uk_81 + 99225*uk_82 + 31752*uk_83 + 206388*uk_84 + 345303*uk_85 + 861273*uk_86 + 31752*uk_87 + 39375*uk_88 + 12600*uk_89 + 2242306609*uk_9 + 81900*uk_90 + 137025*uk_91 + 341775*uk_92 + 12600*uk_93 + 4032*uk_94 + 26208*uk_95 + 43848*uk_96 + 109368*uk_97 + 4032*uk_98 + 170352*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 269136*uk_100 + 656208*uk_101 + 75600*uk_102 + 499023*uk_103 + 1216719*uk_104 + 140175*uk_105 + 2966607*uk_106 + 341775*uk_107 + 39375*uk_108 + 125*uk_109 + 236765*uk_11 + 625*uk_110 + 1200*uk_111 + 2225*uk_112 + 5425*uk_113 + 625*uk_114 + 3125*uk_115 + 6000*uk_116 + 11125*uk_117 + 27125*uk_118 + 3125*uk_119 + 1183825*uk_12 + 11520*uk_120 + 21360*uk_121 + 52080*uk_122 + 6000*uk_123 + 39605*uk_124 + 96565*uk_125 + 11125*uk_126 + 235445*uk_127 + 27125*uk_128 + 3125*uk_129 + 2272944*uk_13 + 15625*uk_130 + 30000*uk_131 + 55625*uk_132 + 135625*uk_133 + 15625*uk_134 + 57600*uk_135 + 106800*uk_136 + 260400*uk_137 + 30000*uk_138 + 198025*uk_139 + 4214417*uk_14 + 482825*uk_140 + 55625*uk_141 + 1177225*uk_142 + 135625*uk_143 + 15625*uk_144 + 110592*uk_145 + 205056*uk_146 + 499968*uk_147 + 57600*uk_148 + 380208*uk_149 + 10275601*uk_15 + 927024*uk_150 + 106800*uk_151 + 2260272*uk_152 + 260400*uk_153 + 30000*uk_154 + 704969*uk_155 + 1718857*uk_156 + 198025*uk_157 + 4190921*uk_158 + 482825*uk_159 + 1183825*uk_16 + 55625*uk_160 + 10218313*uk_161 + 1177225*uk_162 + 135625*uk_163 + 15625*uk_164 + 3969*uk_17 + 315*uk_18 + 1575*uk_19 + 63*uk_2 + 3024*uk_20 + 5607*uk_21 + 13671*uk_22 + 1575*uk_23 + 25*uk_24 + 125*uk_25 + 240*uk_26 + 445*uk_27 + 1085*uk_28 + 125*uk_29 + 5*uk_3 + 625*uk_30 + 1200*uk_31 + 2225*uk_32 + 5425*uk_33 + 625*uk_34 + 2304*uk_35 + 4272*uk_36 + 10416*uk_37 + 1200*uk_38 + 7921*uk_39 + 25*uk_4 + 19313*uk_40 + 2225*uk_41 + 47089*uk_42 + 5425*uk_43 + 625*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 11211533045*uk_47 + 56057665225*uk_48 + 107630717232*uk_49 + 48*uk_5 + 199565288201*uk_50 + 486580534153*uk_51 + 56057665225*uk_52 + 187944057*uk_53 + 14916195*uk_54 + 74580975*uk_55 + 143195472*uk_56 + 265508271*uk_57 + 647362863*uk_58 + 74580975*uk_59 + 89*uk_6 + 1183825*uk_60 + 5919125*uk_61 + 11364720*uk_62 + 21072085*uk_63 + 51378005*uk_64 + 5919125*uk_65 + 29595625*uk_66 + 56823600*uk_67 + 105360425*uk_68 + 256890025*uk_69 + 217*uk_7 + 29595625*uk_70 + 109101312*uk_71 + 202292016*uk_72 + 493228848*uk_73 + 56823600*uk_74 + 375083113*uk_75 + 914528489*uk_76 + 105360425*uk_77 + 2229805417*uk_78 + 256890025*uk_79 + 25*uk_8 + 29595625*uk_80 + 250047*uk_81 + 19845*uk_82 + 99225*uk_83 + 190512*uk_84 + 353241*uk_85 + 861273*uk_86 + 99225*uk_87 + 1575*uk_88 + 7875*uk_89 + 2242306609*uk_9 + 15120*uk_90 + 28035*uk_91 + 68355*uk_92 + 7875*uk_93 + 39375*uk_94 + 75600*uk_95 + 140175*uk_96 + 341775*uk_97 + 39375*uk_98 + 145152*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 275184*uk_100 + 656208*uk_101 + 15120*uk_102 + 521703*uk_103 + 1244061*uk_104 + 28665*uk_105 + 2966607*uk_106 + 68355*uk_107 + 1575*uk_108 + 35937*uk_109 + 1562649*uk_11 + 5445*uk_110 + 52272*uk_111 + 99099*uk_112 + 236313*uk_113 + 5445*uk_114 + 825*uk_115 + 7920*uk_116 + 15015*uk_117 + 35805*uk_118 + 825*uk_119 + 236765*uk_12 + 76032*uk_120 + 144144*uk_121 + 343728*uk_122 + 7920*uk_123 + 273273*uk_124 + 651651*uk_125 + 15015*uk_126 + 1553937*uk_127 + 35805*uk_128 + 825*uk_129 + 2272944*uk_13 + 125*uk_130 + 1200*uk_131 + 2275*uk_132 + 5425*uk_133 + 125*uk_134 + 11520*uk_135 + 21840*uk_136 + 52080*uk_137 + 1200*uk_138 + 41405*uk_139 + 4309123*uk_14 + 98735*uk_140 + 2275*uk_141 + 235445*uk_142 + 5425*uk_143 + 125*uk_144 + 110592*uk_145 + 209664*uk_146 + 499968*uk_147 + 11520*uk_148 + 397488*uk_149 + 10275601*uk_15 + 947856*uk_150 + 21840*uk_151 + 2260272*uk_152 + 52080*uk_153 + 1200*uk_154 + 753571*uk_155 + 1796977*uk_156 + 41405*uk_157 + 4285099*uk_158 + 98735*uk_159 + 236765*uk_16 + 2275*uk_160 + 10218313*uk_161 + 235445*uk_162 + 5425*uk_163 + 125*uk_164 + 3969*uk_17 + 2079*uk_18 + 315*uk_19 + 63*uk_2 + 3024*uk_20 + 5733*uk_21 + 13671*uk_22 + 315*uk_23 + 1089*uk_24 + 165*uk_25 + 1584*uk_26 + 3003*uk_27 + 7161*uk_28 + 165*uk_29 + 33*uk_3 + 25*uk_30 + 240*uk_31 + 455*uk_32 + 1085*uk_33 + 25*uk_34 + 2304*uk_35 + 4368*uk_36 + 10416*uk_37 + 240*uk_38 + 8281*uk_39 + 5*uk_4 + 19747*uk_40 + 455*uk_41 + 47089*uk_42 + 1085*uk_43 + 25*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 73996118097*uk_47 + 11211533045*uk_48 + 107630717232*uk_49 + 48*uk_5 + 204049901419*uk_50 + 486580534153*uk_51 + 11211533045*uk_52 + 187944057*uk_53 + 98446887*uk_54 + 14916195*uk_55 + 143195472*uk_56 + 271474749*uk_57 + 647362863*uk_58 + 14916195*uk_59 + 91*uk_6 + 51567417*uk_60 + 7813245*uk_61 + 75007152*uk_62 + 142201059*uk_63 + 339094833*uk_64 + 7813245*uk_65 + 1183825*uk_66 + 11364720*uk_67 + 21545615*uk_68 + 51378005*uk_69 + 217*uk_7 + 1183825*uk_70 + 109101312*uk_71 + 206837904*uk_72 + 493228848*uk_73 + 11364720*uk_74 + 392130193*uk_75 + 935079691*uk_76 + 21545615*uk_77 + 2229805417*uk_78 + 51378005*uk_79 + 5*uk_8 + 1183825*uk_80 + 250047*uk_81 + 130977*uk_82 + 19845*uk_83 + 190512*uk_84 + 361179*uk_85 + 861273*uk_86 + 19845*uk_87 + 68607*uk_88 + 10395*uk_89 + 2242306609*uk_9 + 99792*uk_90 + 189189*uk_91 + 451143*uk_92 + 10395*uk_93 + 1575*uk_94 + 15120*uk_95 + 28665*uk_96 + 68355*uk_97 + 1575*uk_98 + 145152*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 257796*uk_100 + 601524*uk_101 + 91476*uk_102 + 544887*uk_103 + 1271403*uk_104 + 193347*uk_105 + 2966607*uk_106 + 451143*uk_107 + 68607*uk_108 + 4096*uk_109 + 757648*uk_11 + 8448*uk_110 + 11264*uk_111 + 23808*uk_112 + 55552*uk_113 + 8448*uk_114 + 17424*uk_115 + 23232*uk_116 + 49104*uk_117 + 114576*uk_118 + 17424*uk_119 + 1562649*uk_12 + 30976*uk_120 + 65472*uk_121 + 152768*uk_122 + 23232*uk_123 + 138384*uk_124 + 322896*uk_125 + 49104*uk_126 + 753424*uk_127 + 114576*uk_128 + 17424*uk_129 + 2083532*uk_13 + 35937*uk_130 + 47916*uk_131 + 101277*uk_132 + 236313*uk_133 + 35937*uk_134 + 63888*uk_135 + 135036*uk_136 + 315084*uk_137 + 47916*uk_138 + 285417*uk_139 + 4403829*uk_14 + 665973*uk_140 + 101277*uk_141 + 1553937*uk_142 + 236313*uk_143 + 35937*uk_144 + 85184*uk_145 + 180048*uk_146 + 420112*uk_147 + 63888*uk_148 + 380556*uk_149 + 10275601*uk_15 + 887964*uk_150 + 135036*uk_151 + 2071916*uk_152 + 315084*uk_153 + 47916*uk_154 + 804357*uk_155 + 1876833*uk_156 + 285417*uk_157 + 4379277*uk_158 + 665973*uk_159 + 1562649*uk_16 + 101277*uk_160 + 10218313*uk_161 + 1553937*uk_162 + 236313*uk_163 + 35937*uk_164 + 3969*uk_17 + 1008*uk_18 + 2079*uk_19 + 63*uk_2 + 2772*uk_20 + 5859*uk_21 + 13671*uk_22 + 2079*uk_23 + 256*uk_24 + 528*uk_25 + 704*uk_26 + 1488*uk_27 + 3472*uk_28 + 528*uk_29 + 16*uk_3 + 1089*uk_30 + 1452*uk_31 + 3069*uk_32 + 7161*uk_33 + 1089*uk_34 + 1936*uk_35 + 4092*uk_36 + 9548*uk_37 + 1452*uk_38 + 8649*uk_39 + 33*uk_4 + 20181*uk_40 + 3069*uk_41 + 47089*uk_42 + 7161*uk_43 + 1089*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 35876905744*uk_47 + 73996118097*uk_48 + 98661490796*uk_49 + 44*uk_5 + 208534514637*uk_50 + 486580534153*uk_51 + 73996118097*uk_52 + 187944057*uk_53 + 47731824*uk_54 + 98446887*uk_55 + 131262516*uk_56 + 277441227*uk_57 + 647362863*uk_58 + 98446887*uk_59 + 93*uk_6 + 12122368*uk_60 + 25002384*uk_61 + 33336512*uk_62 + 70461264*uk_63 + 164409616*uk_64 + 25002384*uk_65 + 51567417*uk_66 + 68756556*uk_67 + 145326357*uk_68 + 339094833*uk_69 + 217*uk_7 + 51567417*uk_70 + 91675408*uk_71 + 193768476*uk_72 + 452126444*uk_73 + 68756556*uk_74 + 409556097*uk_75 + 955630893*uk_76 + 145326357*uk_77 + 2229805417*uk_78 + 339094833*uk_79 + 33*uk_8 + 51567417*uk_80 + 250047*uk_81 + 63504*uk_82 + 130977*uk_83 + 174636*uk_84 + 369117*uk_85 + 861273*uk_86 + 130977*uk_87 + 16128*uk_88 + 33264*uk_89 + 2242306609*uk_9 + 44352*uk_90 + 93744*uk_91 + 218736*uk_92 + 33264*uk_93 + 68607*uk_94 + 91476*uk_95 + 193347*uk_96 + 451143*uk_97 + 68607*uk_98 + 121968*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 263340*uk_100 + 601524*uk_101 + 44352*uk_102 + 568575*uk_103 + 1298745*uk_104 + 95760*uk_105 + 2966607*uk_106 + 218736*uk_107 + 16128*uk_108 + 79507*uk_109 + 2036179*uk_11 + 29584*uk_110 + 81356*uk_111 + 175655*uk_112 + 401233*uk_113 + 29584*uk_114 + 11008*uk_115 + 30272*uk_116 + 65360*uk_117 + 149296*uk_118 + 11008*uk_119 + 757648*uk_12 + 83248*uk_120 + 179740*uk_121 + 410564*uk_122 + 30272*uk_123 + 388075*uk_124 + 886445*uk_125 + 65360*uk_126 + 2024827*uk_127 + 149296*uk_128 + 11008*uk_129 + 2083532*uk_13 + 4096*uk_130 + 11264*uk_131 + 24320*uk_132 + 55552*uk_133 + 4096*uk_134 + 30976*uk_135 + 66880*uk_136 + 152768*uk_137 + 11264*uk_138 + 144400*uk_139 + 4498535*uk_14 + 329840*uk_140 + 24320*uk_141 + 753424*uk_142 + 55552*uk_143 + 4096*uk_144 + 85184*uk_145 + 183920*uk_146 + 420112*uk_147 + 30976*uk_148 + 397100*uk_149 + 10275601*uk_15 + 907060*uk_150 + 66880*uk_151 + 2071916*uk_152 + 152768*uk_153 + 11264*uk_154 + 857375*uk_155 + 1958425*uk_156 + 144400*uk_157 + 4473455*uk_158 + 329840*uk_159 + 757648*uk_16 + 24320*uk_160 + 10218313*uk_161 + 753424*uk_162 + 55552*uk_163 + 4096*uk_164 + 3969*uk_17 + 2709*uk_18 + 1008*uk_19 + 63*uk_2 + 2772*uk_20 + 5985*uk_21 + 13671*uk_22 + 1008*uk_23 + 1849*uk_24 + 688*uk_25 + 1892*uk_26 + 4085*uk_27 + 9331*uk_28 + 688*uk_29 + 43*uk_3 + 256*uk_30 + 704*uk_31 + 1520*uk_32 + 3472*uk_33 + 256*uk_34 + 1936*uk_35 + 4180*uk_36 + 9548*uk_37 + 704*uk_38 + 9025*uk_39 + 16*uk_4 + 20615*uk_40 + 1520*uk_41 + 47089*uk_42 + 3472*uk_43 + 256*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 96419184187*uk_47 + 35876905744*uk_48 + 98661490796*uk_49 + 44*uk_5 + 213019127855*uk_50 + 486580534153*uk_51 + 35876905744*uk_52 + 187944057*uk_53 + 128279277*uk_54 + 47731824*uk_55 + 131262516*uk_56 + 283407705*uk_57 + 647362863*uk_58 + 47731824*uk_59 + 95*uk_6 + 87555697*uk_60 + 32578864*uk_61 + 89591876*uk_62 + 193437005*uk_63 + 441850843*uk_64 + 32578864*uk_65 + 12122368*uk_66 + 33336512*uk_67 + 71976560*uk_68 + 164409616*uk_69 + 217*uk_7 + 12122368*uk_70 + 91675408*uk_71 + 197935540*uk_72 + 452126444*uk_73 + 33336512*uk_74 + 427360825*uk_75 + 976182095*uk_76 + 71976560*uk_77 + 2229805417*uk_78 + 164409616*uk_79 + 16*uk_8 + 12122368*uk_80 + 250047*uk_81 + 170667*uk_82 + 63504*uk_83 + 174636*uk_84 + 377055*uk_85 + 861273*uk_86 + 63504*uk_87 + 116487*uk_88 + 43344*uk_89 + 2242306609*uk_9 + 119196*uk_90 + 257355*uk_91 + 587853*uk_92 + 43344*uk_93 + 16128*uk_94 + 44352*uk_95 + 95760*uk_96 + 218736*uk_97 + 16128*uk_98 + 121968*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 244440*uk_100 + 546840*uk_101 + 108360*uk_102 + 592767*uk_103 + 1326087*uk_104 + 262773*uk_105 + 2966607*uk_106 + 587853*uk_107 + 116487*uk_108 + 4913*uk_109 + 805001*uk_11 + 12427*uk_110 + 11560*uk_111 + 28033*uk_112 + 62713*uk_113 + 12427*uk_114 + 31433*uk_115 + 29240*uk_116 + 70907*uk_117 + 158627*uk_118 + 31433*uk_119 + 2036179*uk_12 + 27200*uk_120 + 65960*uk_121 + 147560*uk_122 + 29240*uk_123 + 159953*uk_124 + 357833*uk_125 + 70907*uk_126 + 800513*uk_127 + 158627*uk_128 + 31433*uk_129 + 1894120*uk_13 + 79507*uk_130 + 73960*uk_131 + 179353*uk_132 + 401233*uk_133 + 79507*uk_134 + 68800*uk_135 + 166840*uk_136 + 373240*uk_137 + 73960*uk_138 + 404587*uk_139 + 4593241*uk_14 + 905107*uk_140 + 179353*uk_141 + 2024827*uk_142 + 401233*uk_143 + 79507*uk_144 + 64000*uk_145 + 155200*uk_146 + 347200*uk_147 + 68800*uk_148 + 376360*uk_149 + 10275601*uk_15 + 841960*uk_150 + 166840*uk_151 + 1883560*uk_152 + 373240*uk_153 + 73960*uk_154 + 912673*uk_155 + 2041753*uk_156 + 404587*uk_157 + 4567633*uk_158 + 905107*uk_159 + 2036179*uk_16 + 179353*uk_160 + 10218313*uk_161 + 2024827*uk_162 + 401233*uk_163 + 79507*uk_164 + 3969*uk_17 + 1071*uk_18 + 2709*uk_19 + 63*uk_2 + 2520*uk_20 + 6111*uk_21 + 13671*uk_22 + 2709*uk_23 + 289*uk_24 + 731*uk_25 + 680*uk_26 + 1649*uk_27 + 3689*uk_28 + 731*uk_29 + 17*uk_3 + 1849*uk_30 + 1720*uk_31 + 4171*uk_32 + 9331*uk_33 + 1849*uk_34 + 1600*uk_35 + 3880*uk_36 + 8680*uk_37 + 1720*uk_38 + 9409*uk_39 + 43*uk_4 + 21049*uk_40 + 4171*uk_41 + 47089*uk_42 + 9331*uk_43 + 1849*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 38119212353*uk_47 + 96419184187*uk_48 + 89692264360*uk_49 + 40*uk_5 + 217503741073*uk_50 + 486580534153*uk_51 + 96419184187*uk_52 + 187944057*uk_53 + 50715063*uk_54 + 128279277*uk_55 + 119329560*uk_56 + 289374183*uk_57 + 647362863*uk_58 + 128279277*uk_59 + 97*uk_6 + 13685017*uk_60 + 34615043*uk_61 + 32200040*uk_62 + 78085097*uk_63 + 174685217*uk_64 + 34615043*uk_65 + 87555697*uk_66 + 81447160*uk_67 + 197509363*uk_68 + 441850843*uk_69 + 217*uk_7 + 87555697*uk_70 + 75764800*uk_71 + 183729640*uk_72 + 411024040*uk_73 + 81447160*uk_74 + 445544377*uk_75 + 996733297*uk_76 + 197509363*uk_77 + 2229805417*uk_78 + 441850843*uk_79 + 43*uk_8 + 87555697*uk_80 + 250047*uk_81 + 67473*uk_82 + 170667*uk_83 + 158760*uk_84 + 384993*uk_85 + 861273*uk_86 + 170667*uk_87 + 18207*uk_88 + 46053*uk_89 + 2242306609*uk_9 + 42840*uk_90 + 103887*uk_91 + 232407*uk_92 + 46053*uk_93 + 116487*uk_94 + 108360*uk_95 + 262773*uk_96 + 587853*uk_97 + 116487*uk_98 + 100800*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 249480*uk_100 + 546840*uk_101 + 42840*uk_102 + 617463*uk_103 + 1353429*uk_104 + 106029*uk_105 + 2966607*uk_106 + 232407*uk_107 + 18207*uk_108 + 29791*uk_109 + 1467943*uk_11 + 16337*uk_110 + 38440*uk_111 + 95139*uk_112 + 208537*uk_113 + 16337*uk_114 + 8959*uk_115 + 21080*uk_116 + 52173*uk_117 + 114359*uk_118 + 8959*uk_119 + 805001*uk_12 + 49600*uk_120 + 122760*uk_121 + 269080*uk_122 + 21080*uk_123 + 303831*uk_124 + 665973*uk_125 + 52173*uk_126 + 1459759*uk_127 + 114359*uk_128 + 8959*uk_129 + 1894120*uk_13 + 4913*uk_130 + 11560*uk_131 + 28611*uk_132 + 62713*uk_133 + 4913*uk_134 + 27200*uk_135 + 67320*uk_136 + 147560*uk_137 + 11560*uk_138 + 166617*uk_139 + 4687947*uk_14 + 365211*uk_140 + 28611*uk_141 + 800513*uk_142 + 62713*uk_143 + 4913*uk_144 + 64000*uk_145 + 158400*uk_146 + 347200*uk_147 + 27200*uk_148 + 392040*uk_149 + 10275601*uk_15 + 859320*uk_150 + 67320*uk_151 + 1883560*uk_152 + 147560*uk_153 + 11560*uk_154 + 970299*uk_155 + 2126817*uk_156 + 166617*uk_157 + 4661811*uk_158 + 365211*uk_159 + 805001*uk_16 + 28611*uk_160 + 10218313*uk_161 + 800513*uk_162 + 62713*uk_163 + 4913*uk_164 + 3969*uk_17 + 1953*uk_18 + 1071*uk_19 + 63*uk_2 + 2520*uk_20 + 6237*uk_21 + 13671*uk_22 + 1071*uk_23 + 961*uk_24 + 527*uk_25 + 1240*uk_26 + 3069*uk_27 + 6727*uk_28 + 527*uk_29 + 31*uk_3 + 289*uk_30 + 680*uk_31 + 1683*uk_32 + 3689*uk_33 + 289*uk_34 + 1600*uk_35 + 3960*uk_36 + 8680*uk_37 + 680*uk_38 + 9801*uk_39 + 17*uk_4 + 21483*uk_40 + 1683*uk_41 + 47089*uk_42 + 3689*uk_43 + 289*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 69511504879*uk_47 + 38119212353*uk_48 + 89692264360*uk_49 + 40*uk_5 + 221988354291*uk_50 + 486580534153*uk_51 + 38119212353*uk_52 + 187944057*uk_53 + 92480409*uk_54 + 50715063*uk_55 + 119329560*uk_56 + 295340661*uk_57 + 647362863*uk_58 + 50715063*uk_59 + 99*uk_6 + 45506233*uk_60 + 24955031*uk_61 + 58717720*uk_62 + 145326357*uk_63 + 318543631*uk_64 + 24955031*uk_65 + 13685017*uk_66 + 32200040*uk_67 + 79695099*uk_68 + 174685217*uk_69 + 217*uk_7 + 13685017*uk_70 + 75764800*uk_71 + 187517880*uk_72 + 411024040*uk_73 + 32200040*uk_74 + 464106753*uk_75 + 1017284499*uk_76 + 79695099*uk_77 + 2229805417*uk_78 + 174685217*uk_79 + 17*uk_8 + 13685017*uk_80 + 250047*uk_81 + 123039*uk_82 + 67473*uk_83 + 158760*uk_84 + 392931*uk_85 + 861273*uk_86 + 67473*uk_87 + 60543*uk_88 + 33201*uk_89 + 2242306609*uk_9 + 78120*uk_90 + 193347*uk_91 + 423801*uk_92 + 33201*uk_93 + 18207*uk_94 + 42840*uk_95 + 106029*uk_96 + 232407*uk_97 + 18207*uk_98 + 100800*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 254520*uk_100 + 546840*uk_101 + 78120*uk_102 + 642663*uk_103 + 1380771*uk_104 + 197253*uk_105 + 2966607*uk_106 + 423801*uk_107 + 60543*uk_108 + 614125*uk_109 + 4025005*uk_11 + 223975*uk_110 + 289000*uk_111 + 729725*uk_112 + 1567825*uk_113 + 223975*uk_114 + 81685*uk_115 + 105400*uk_116 + 266135*uk_117 + 571795*uk_118 + 81685*uk_119 + 1467943*uk_12 + 136000*uk_120 + 343400*uk_121 + 737800*uk_122 + 105400*uk_123 + 867085*uk_124 + 1862945*uk_125 + 266135*uk_126 + 4002565*uk_127 + 571795*uk_128 + 81685*uk_129 + 1894120*uk_13 + 29791*uk_130 + 38440*uk_131 + 97061*uk_132 + 208537*uk_133 + 29791*uk_134 + 49600*uk_135 + 125240*uk_136 + 269080*uk_137 + 38440*uk_138 + 316231*uk_139 + 4782653*uk_14 + 679427*uk_140 + 97061*uk_141 + 1459759*uk_142 + 208537*uk_143 + 29791*uk_144 + 64000*uk_145 + 161600*uk_146 + 347200*uk_147 + 49600*uk_148 + 408040*uk_149 + 10275601*uk_15 + 876680*uk_150 + 125240*uk_151 + 1883560*uk_152 + 269080*uk_153 + 38440*uk_154 + 1030301*uk_155 + 2213617*uk_156 + 316231*uk_157 + 4755989*uk_158 + 679427*uk_159 + 1467943*uk_16 + 97061*uk_160 + 10218313*uk_161 + 1459759*uk_162 + 208537*uk_163 + 29791*uk_164 + 3969*uk_17 + 5355*uk_18 + 1953*uk_19 + 63*uk_2 + 2520*uk_20 + 6363*uk_21 + 13671*uk_22 + 1953*uk_23 + 7225*uk_24 + 2635*uk_25 + 3400*uk_26 + 8585*uk_27 + 18445*uk_28 + 2635*uk_29 + 85*uk_3 + 961*uk_30 + 1240*uk_31 + 3131*uk_32 + 6727*uk_33 + 961*uk_34 + 1600*uk_35 + 4040*uk_36 + 8680*uk_37 + 1240*uk_38 + 10201*uk_39 + 31*uk_4 + 21917*uk_40 + 3131*uk_41 + 47089*uk_42 + 6727*uk_43 + 961*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 190596061765*uk_47 + 69511504879*uk_48 + 89692264360*uk_49 + 40*uk_5 + 226472967509*uk_50 + 486580534153*uk_51 + 69511504879*uk_52 + 187944057*uk_53 + 253575315*uk_54 + 92480409*uk_55 + 119329560*uk_56 + 301307139*uk_57 + 647362863*uk_58 + 92480409*uk_59 + 101*uk_6 + 342125425*uk_60 + 124775155*uk_61 + 161000200*uk_62 + 406525505*uk_63 + 873426085*uk_64 + 124775155*uk_65 + 45506233*uk_66 + 58717720*uk_67 + 148262243*uk_68 + 318543631*uk_69 + 217*uk_7 + 45506233*uk_70 + 75764800*uk_71 + 191306120*uk_72 + 411024040*uk_73 + 58717720*uk_74 + 483047953*uk_75 + 1037835701*uk_76 + 148262243*uk_77 + 2229805417*uk_78 + 318543631*uk_79 + 31*uk_8 + 45506233*uk_80 + 250047*uk_81 + 337365*uk_82 + 123039*uk_83 + 158760*uk_84 + 400869*uk_85 + 861273*uk_86 + 123039*uk_87 + 455175*uk_88 + 166005*uk_89 + 2242306609*uk_9 + 214200*uk_90 + 540855*uk_91 + 1162035*uk_92 + 166005*uk_93 + 60543*uk_94 + 78120*uk_95 + 197253*uk_96 + 423801*uk_97 + 60543*uk_98 + 100800*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 233604*uk_100 + 492156*uk_101 + 192780*uk_102 + 668367*uk_103 + 1408113*uk_104 + 551565*uk_105 + 2966607*uk_106 + 1162035*uk_107 + 455175*uk_108 + 438976*uk_109 + 3598828*uk_11 + 490960*uk_110 + 207936*uk_111 + 594928*uk_112 + 1253392*uk_113 + 490960*uk_114 + 549100*uk_115 + 232560*uk_116 + 665380*uk_117 + 1401820*uk_118 + 549100*uk_119 + 4025005*uk_12 + 98496*uk_120 + 281808*uk_121 + 593712*uk_122 + 232560*uk_123 + 806284*uk_124 + 1698676*uk_125 + 665380*uk_126 + 3578764*uk_127 + 1401820*uk_128 + 549100*uk_129 + 1704708*uk_13 + 614125*uk_130 + 260100*uk_131 + 744175*uk_132 + 1567825*uk_133 + 614125*uk_134 + 110160*uk_135 + 315180*uk_136 + 664020*uk_137 + 260100*uk_138 + 901765*uk_139 + 4877359*uk_14 + 1899835*uk_140 + 744175*uk_141 + 4002565*uk_142 + 1567825*uk_143 + 614125*uk_144 + 46656*uk_145 + 133488*uk_146 + 281232*uk_147 + 110160*uk_148 + 381924*uk_149 + 10275601*uk_15 + 804636*uk_150 + 315180*uk_151 + 1695204*uk_152 + 664020*uk_153 + 260100*uk_154 + 1092727*uk_155 + 2302153*uk_156 + 901765*uk_157 + 4850167*uk_158 + 1899835*uk_159 + 4025005*uk_16 + 744175*uk_160 + 10218313*uk_161 + 4002565*uk_162 + 1567825*uk_163 + 614125*uk_164 + 3969*uk_17 + 4788*uk_18 + 5355*uk_19 + 63*uk_2 + 2268*uk_20 + 6489*uk_21 + 13671*uk_22 + 5355*uk_23 + 5776*uk_24 + 6460*uk_25 + 2736*uk_26 + 7828*uk_27 + 16492*uk_28 + 6460*uk_29 + 76*uk_3 + 7225*uk_30 + 3060*uk_31 + 8755*uk_32 + 18445*uk_33 + 7225*uk_34 + 1296*uk_35 + 3708*uk_36 + 7812*uk_37 + 3060*uk_38 + 10609*uk_39 + 85*uk_4 + 22351*uk_40 + 8755*uk_41 + 47089*uk_42 + 18445*uk_43 + 7225*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 170415302284*uk_47 + 190596061765*uk_48 + 80723037924*uk_49 + 36*uk_5 + 230957580727*uk_50 + 486580534153*uk_51 + 190596061765*uk_52 + 187944057*uk_53 + 226726164*uk_54 + 253575315*uk_55 + 107396604*uk_56 + 307273617*uk_57 + 647362863*uk_58 + 253575315*uk_59 + 103*uk_6 + 273510928*uk_60 + 305900380*uk_61 + 129557808*uk_62 + 370679284*uk_63 + 780945676*uk_64 + 305900380*uk_65 + 342125425*uk_66 + 144900180*uk_67 + 414575515*uk_68 + 873426085*uk_69 + 217*uk_7 + 342125425*uk_70 + 61369488*uk_71 + 175584924*uk_72 + 369921636*uk_73 + 144900180*uk_74 + 502367977*uk_75 + 1058386903*uk_76 + 414575515*uk_77 + 2229805417*uk_78 + 873426085*uk_79 + 85*uk_8 + 342125425*uk_80 + 250047*uk_81 + 301644*uk_82 + 337365*uk_83 + 142884*uk_84 + 408807*uk_85 + 861273*uk_86 + 337365*uk_87 + 363888*uk_88 + 406980*uk_89 + 2242306609*uk_9 + 172368*uk_90 + 493164*uk_91 + 1038996*uk_92 + 406980*uk_93 + 455175*uk_94 + 192780*uk_95 + 551565*uk_96 + 1162035*uk_97 + 455175*uk_98 + 81648*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 238140*uk_100 + 492156*uk_101 + 172368*uk_102 + 694575*uk_103 + 1435455*uk_104 + 502740*uk_105 + 2966607*uk_106 + 1038996*uk_107 + 363888*uk_108 + 1092727*uk_109 + 4877359*uk_11 + 806284*uk_110 + 381924*uk_111 + 1113945*uk_112 + 2302153*uk_113 + 806284*uk_114 + 594928*uk_115 + 281808*uk_116 + 821940*uk_117 + 1698676*uk_118 + 594928*uk_119 + 3598828*uk_12 + 133488*uk_120 + 389340*uk_121 + 804636*uk_122 + 281808*uk_123 + 1135575*uk_124 + 2346855*uk_125 + 821940*uk_126 + 4850167*uk_127 + 1698676*uk_128 + 594928*uk_129 + 1704708*uk_13 + 438976*uk_130 + 207936*uk_131 + 606480*uk_132 + 1253392*uk_133 + 438976*uk_134 + 98496*uk_135 + 287280*uk_136 + 593712*uk_137 + 207936*uk_138 + 837900*uk_139 + 4972065*uk_14 + 1731660*uk_140 + 606480*uk_141 + 3578764*uk_142 + 1253392*uk_143 + 438976*uk_144 + 46656*uk_145 + 136080*uk_146 + 281232*uk_147 + 98496*uk_148 + 396900*uk_149 + 10275601*uk_15 + 820260*uk_150 + 287280*uk_151 + 1695204*uk_152 + 593712*uk_153 + 207936*uk_154 + 1157625*uk_155 + 2392425*uk_156 + 837900*uk_157 + 4944345*uk_158 + 1731660*uk_159 + 3598828*uk_16 + 606480*uk_160 + 10218313*uk_161 + 3578764*uk_162 + 1253392*uk_163 + 438976*uk_164 + 3969*uk_17 + 6489*uk_18 + 4788*uk_19 + 63*uk_2 + 2268*uk_20 + 6615*uk_21 + 13671*uk_22 + 4788*uk_23 + 10609*uk_24 + 7828*uk_25 + 3708*uk_26 + 10815*uk_27 + 22351*uk_28 + 7828*uk_29 + 103*uk_3 + 5776*uk_30 + 2736*uk_31 + 7980*uk_32 + 16492*uk_33 + 5776*uk_34 + 1296*uk_35 + 3780*uk_36 + 7812*uk_37 + 2736*uk_38 + 11025*uk_39 + 76*uk_4 + 22785*uk_40 + 7980*uk_41 + 47089*uk_42 + 16492*uk_43 + 5776*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 230957580727*uk_47 + 170415302284*uk_48 + 80723037924*uk_49 + 36*uk_5 + 235442193945*uk_50 + 486580534153*uk_51 + 170415302284*uk_52 + 187944057*uk_53 + 307273617*uk_54 + 226726164*uk_55 + 107396604*uk_56 + 313240095*uk_57 + 647362863*uk_58 + 226726164*uk_59 + 105*uk_6 + 502367977*uk_60 + 370679284*uk_61 + 175584924*uk_62 + 512122695*uk_63 + 1058386903*uk_64 + 370679284*uk_65 + 273510928*uk_66 + 129557808*uk_67 + 377876940*uk_68 + 780945676*uk_69 + 217*uk_7 + 273510928*uk_70 + 61369488*uk_71 + 178994340*uk_72 + 369921636*uk_73 + 129557808*uk_74 + 522066825*uk_75 + 1078938105*uk_76 + 377876940*uk_77 + 2229805417*uk_78 + 780945676*uk_79 + 76*uk_8 + 273510928*uk_80 + 250047*uk_81 + 408807*uk_82 + 301644*uk_83 + 142884*uk_84 + 416745*uk_85 + 861273*uk_86 + 301644*uk_87 + 668367*uk_88 + 493164*uk_89 + 2242306609*uk_9 + 233604*uk_90 + 681345*uk_91 + 1408113*uk_92 + 493164*uk_93 + 363888*uk_94 + 172368*uk_95 + 502740*uk_96 + 1038996*uk_97 + 363888*uk_98 + 81648*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 215712*uk_100 + 437472*uk_101 + 207648*uk_102 + 721287*uk_103 + 1462797*uk_104 + 694323*uk_105 + 2966607*uk_106 + 1408113*uk_107 + 668367*uk_108 + 205379*uk_109 + 2793827*uk_11 + 358543*uk_110 + 111392*uk_111 + 372467*uk_112 + 755377*uk_113 + 358543*uk_114 + 625931*uk_115 + 194464*uk_116 + 650239*uk_117 + 1318709*uk_118 + 625931*uk_119 + 4877359*uk_12 + 60416*uk_120 + 202016*uk_121 + 409696*uk_122 + 194464*uk_123 + 675491*uk_124 + 1369921*uk_125 + 650239*uk_126 + 2778251*uk_127 + 1318709*uk_128 + 625931*uk_129 + 1515296*uk_13 + 1092727*uk_130 + 339488*uk_131 + 1135163*uk_132 + 2302153*uk_133 + 1092727*uk_134 + 105472*uk_135 + 352672*uk_136 + 715232*uk_137 + 339488*uk_138 + 1179247*uk_139 + 5066771*uk_14 + 2391557*uk_140 + 1135163*uk_141 + 4850167*uk_142 + 2302153*uk_143 + 1092727*uk_144 + 32768*uk_145 + 109568*uk_146 + 222208*uk_147 + 105472*uk_148 + 366368*uk_149 + 10275601*uk_15 + 743008*uk_150 + 352672*uk_151 + 1506848*uk_152 + 715232*uk_153 + 339488*uk_154 + 1225043*uk_155 + 2484433*uk_156 + 1179247*uk_157 + 5038523*uk_158 + 2391557*uk_159 + 4877359*uk_16 + 1135163*uk_160 + 10218313*uk_161 + 4850167*uk_162 + 2302153*uk_163 + 1092727*uk_164 + 3969*uk_17 + 3717*uk_18 + 6489*uk_19 + 63*uk_2 + 2016*uk_20 + 6741*uk_21 + 13671*uk_22 + 6489*uk_23 + 3481*uk_24 + 6077*uk_25 + 1888*uk_26 + 6313*uk_27 + 12803*uk_28 + 6077*uk_29 + 59*uk_3 + 10609*uk_30 + 3296*uk_31 + 11021*uk_32 + 22351*uk_33 + 10609*uk_34 + 1024*uk_35 + 3424*uk_36 + 6944*uk_37 + 3296*uk_38 + 11449*uk_39 + 103*uk_4 + 23219*uk_40 + 11021*uk_41 + 47089*uk_42 + 22351*uk_43 + 10609*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 132296089931*uk_47 + 230957580727*uk_48 + 71753811488*uk_49 + 32*uk_5 + 239926807163*uk_50 + 486580534153*uk_51 + 230957580727*uk_52 + 187944057*uk_53 + 176011101*uk_54 + 307273617*uk_55 + 95463648*uk_56 + 319206573*uk_57 + 647362863*uk_58 + 307273617*uk_59 + 107*uk_6 + 164835793*uk_60 + 287764181*uk_61 + 89402464*uk_62 + 298939489*uk_63 + 606260459*uk_64 + 287764181*uk_65 + 502367977*uk_66 + 156075488*uk_67 + 521877413*uk_68 + 1058386903*uk_69 + 217*uk_7 + 502367977*uk_70 + 48489472*uk_71 + 162136672*uk_72 + 328819232*uk_73 + 156075488*uk_74 + 542144497*uk_75 + 1099489307*uk_76 + 521877413*uk_77 + 2229805417*uk_78 + 1058386903*uk_79 + 103*uk_8 + 502367977*uk_80 + 250047*uk_81 + 234171*uk_82 + 408807*uk_83 + 127008*uk_84 + 424683*uk_85 + 861273*uk_86 + 408807*uk_87 + 219303*uk_88 + 382851*uk_89 + 2242306609*uk_9 + 118944*uk_90 + 397719*uk_91 + 806589*uk_92 + 382851*uk_93 + 668367*uk_94 + 207648*uk_95 + 694323*uk_96 + 1408113*uk_97 + 668367*uk_98 + 64512*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 219744*uk_100 + 437472*uk_101 + 118944*uk_102 + 748503*uk_103 + 1490139*uk_104 + 405153*uk_105 + 2966607*uk_106 + 806589*uk_107 + 219303*uk_108 + 103823*uk_109 + 2225591*uk_11 + 130331*uk_110 + 70688*uk_111 + 240781*uk_112 + 479353*uk_113 + 130331*uk_114 + 163607*uk_115 + 88736*uk_116 + 302257*uk_117 + 601741*uk_118 + 163607*uk_119 + 2793827*uk_12 + 48128*uk_120 + 163936*uk_121 + 326368*uk_122 + 88736*uk_123 + 558407*uk_124 + 1111691*uk_125 + 302257*uk_126 + 2213183*uk_127 + 601741*uk_128 + 163607*uk_129 + 1515296*uk_13 + 205379*uk_130 + 111392*uk_131 + 379429*uk_132 + 755377*uk_133 + 205379*uk_134 + 60416*uk_135 + 205792*uk_136 + 409696*uk_137 + 111392*uk_138 + 700979*uk_139 + 5161477*uk_14 + 1395527*uk_140 + 379429*uk_141 + 2778251*uk_142 + 755377*uk_143 + 205379*uk_144 + 32768*uk_145 + 111616*uk_146 + 222208*uk_147 + 60416*uk_148 + 380192*uk_149 + 10275601*uk_15 + 756896*uk_150 + 205792*uk_151 + 1506848*uk_152 + 409696*uk_153 + 111392*uk_154 + 1295029*uk_155 + 2578177*uk_156 + 700979*uk_157 + 5132701*uk_158 + 1395527*uk_159 + 2793827*uk_16 + 379429*uk_160 + 10218313*uk_161 + 2778251*uk_162 + 755377*uk_163 + 205379*uk_164 + 3969*uk_17 + 2961*uk_18 + 3717*uk_19 + 63*uk_2 + 2016*uk_20 + 6867*uk_21 + 13671*uk_22 + 3717*uk_23 + 2209*uk_24 + 2773*uk_25 + 1504*uk_26 + 5123*uk_27 + 10199*uk_28 + 2773*uk_29 + 47*uk_3 + 3481*uk_30 + 1888*uk_31 + 6431*uk_32 + 12803*uk_33 + 3481*uk_34 + 1024*uk_35 + 3488*uk_36 + 6944*uk_37 + 1888*uk_38 + 11881*uk_39 + 59*uk_4 + 23653*uk_40 + 6431*uk_41 + 47089*uk_42 + 12803*uk_43 + 3481*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 105388410623*uk_47 + 132296089931*uk_48 + 71753811488*uk_49 + 32*uk_5 + 244411420381*uk_50 + 486580534153*uk_51 + 132296089931*uk_52 + 187944057*uk_53 + 140212233*uk_54 + 176011101*uk_55 + 95463648*uk_56 + 325173051*uk_57 + 647362863*uk_58 + 176011101*uk_59 + 109*uk_6 + 104602777*uk_60 + 131309869*uk_61 + 71218912*uk_62 + 242589419*uk_63 + 482953247*uk_64 + 131309869*uk_65 + 164835793*uk_66 + 89402464*uk_67 + 304527143*uk_68 + 606260459*uk_69 + 217*uk_7 + 164835793*uk_70 + 48489472*uk_71 + 165167264*uk_72 + 328819232*uk_73 + 89402464*uk_74 + 562600993*uk_75 + 1120040509*uk_76 + 304527143*uk_77 + 2229805417*uk_78 + 606260459*uk_79 + 59*uk_8 + 164835793*uk_80 + 250047*uk_81 + 186543*uk_82 + 234171*uk_83 + 127008*uk_84 + 432621*uk_85 + 861273*uk_86 + 234171*uk_87 + 139167*uk_88 + 174699*uk_89 + 2242306609*uk_9 + 94752*uk_90 + 322749*uk_91 + 642537*uk_92 + 174699*uk_93 + 219303*uk_94 + 118944*uk_95 + 405153*uk_96 + 806589*uk_97 + 219303*uk_98 + 64512*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 223776*uk_100 + 437472*uk_101 + 94752*uk_102 + 776223*uk_103 + 1517481*uk_104 + 328671*uk_105 + 2966607*uk_106 + 642537*uk_107 + 139167*uk_108 + 300763*uk_109 + 3172651*uk_11 + 210983*uk_110 + 143648*uk_111 + 498279*uk_112 + 974113*uk_113 + 210983*uk_114 + 148003*uk_115 + 100768*uk_116 + 349539*uk_117 + 683333*uk_118 + 148003*uk_119 + 2225591*uk_12 + 68608*uk_120 + 237984*uk_121 + 465248*uk_122 + 100768*uk_123 + 825507*uk_124 + 1613829*uk_125 + 349539*uk_126 + 3154963*uk_127 + 683333*uk_128 + 148003*uk_129 + 1515296*uk_13 + 103823*uk_130 + 70688*uk_131 + 245199*uk_132 + 479353*uk_133 + 103823*uk_134 + 48128*uk_135 + 166944*uk_136 + 326368*uk_137 + 70688*uk_138 + 579087*uk_139 + 5256183*uk_14 + 1132089*uk_140 + 245199*uk_141 + 2213183*uk_142 + 479353*uk_143 + 103823*uk_144 + 32768*uk_145 + 113664*uk_146 + 222208*uk_147 + 48128*uk_148 + 394272*uk_149 + 10275601*uk_15 + 770784*uk_150 + 166944*uk_151 + 1506848*uk_152 + 326368*uk_153 + 70688*uk_154 + 1367631*uk_155 + 2673657*uk_156 + 579087*uk_157 + 5226879*uk_158 + 1132089*uk_159 + 2225591*uk_16 + 245199*uk_160 + 10218313*uk_161 + 2213183*uk_162 + 479353*uk_163 + 103823*uk_164 + 3969*uk_17 + 4221*uk_18 + 2961*uk_19 + 63*uk_2 + 2016*uk_20 + 6993*uk_21 + 13671*uk_22 + 2961*uk_23 + 4489*uk_24 + 3149*uk_25 + 2144*uk_26 + 7437*uk_27 + 14539*uk_28 + 3149*uk_29 + 67*uk_3 + 2209*uk_30 + 1504*uk_31 + 5217*uk_32 + 10199*uk_33 + 2209*uk_34 + 1024*uk_35 + 3552*uk_36 + 6944*uk_37 + 1504*uk_38 + 12321*uk_39 + 47*uk_4 + 24087*uk_40 + 5217*uk_41 + 47089*uk_42 + 10199*uk_43 + 2209*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 150234542803*uk_47 + 105388410623*uk_48 + 71753811488*uk_49 + 32*uk_5 + 248896033599*uk_50 + 486580534153*uk_51 + 105388410623*uk_52 + 187944057*uk_53 + 199877013*uk_54 + 140212233*uk_55 + 95463648*uk_56 + 331139529*uk_57 + 647362863*uk_58 + 140212233*uk_59 + 111*uk_6 + 212567617*uk_60 + 149114597*uk_61 + 101524832*uk_62 + 352164261*uk_63 + 688465267*uk_64 + 149114597*uk_65 + 104602777*uk_66 + 71218912*uk_67 + 247040601*uk_68 + 482953247*uk_69 + 217*uk_7 + 104602777*uk_70 + 48489472*uk_71 + 168197856*uk_72 + 328819232*uk_73 + 71218912*uk_74 + 583436313*uk_75 + 1140591711*uk_76 + 247040601*uk_77 + 2229805417*uk_78 + 482953247*uk_79 + 47*uk_8 + 104602777*uk_80 + 250047*uk_81 + 265923*uk_82 + 186543*uk_83 + 127008*uk_84 + 440559*uk_85 + 861273*uk_86 + 186543*uk_87 + 282807*uk_88 + 198387*uk_89 + 2242306609*uk_9 + 135072*uk_90 + 468531*uk_91 + 915957*uk_92 + 198387*uk_93 + 139167*uk_94 + 94752*uk_95 + 328671*uk_96 + 642537*uk_97 + 139167*uk_98 + 64512*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 199332*uk_100 + 382788*uk_101 + 118188*uk_102 + 804447*uk_103 + 1544823*uk_104 + 476973*uk_105 + 2966607*uk_106 + 915957*uk_107 + 282807*uk_108 + 216*uk_109 + 284118*uk_11 + 2412*uk_110 + 1008*uk_111 + 4068*uk_112 + 7812*uk_113 + 2412*uk_114 + 26934*uk_115 + 11256*uk_116 + 45426*uk_117 + 87234*uk_118 + 26934*uk_119 + 3172651*uk_12 + 4704*uk_120 + 18984*uk_121 + 36456*uk_122 + 11256*uk_123 + 76614*uk_124 + 147126*uk_125 + 45426*uk_126 + 282534*uk_127 + 87234*uk_128 + 26934*uk_129 + 1325884*uk_13 + 300763*uk_130 + 125692*uk_131 + 507257*uk_132 + 974113*uk_133 + 300763*uk_134 + 52528*uk_135 + 211988*uk_136 + 407092*uk_137 + 125692*uk_138 + 855523*uk_139 + 5350889*uk_14 + 1642907*uk_140 + 507257*uk_141 + 3154963*uk_142 + 974113*uk_143 + 300763*uk_144 + 21952*uk_145 + 88592*uk_146 + 170128*uk_147 + 52528*uk_148 + 357532*uk_149 + 10275601*uk_15 + 686588*uk_150 + 211988*uk_151 + 1318492*uk_152 + 407092*uk_153 + 125692*uk_154 + 1442897*uk_155 + 2770873*uk_156 + 855523*uk_157 + 5321057*uk_158 + 1642907*uk_159 + 3172651*uk_16 + 507257*uk_160 + 10218313*uk_161 + 3154963*uk_162 + 974113*uk_163 + 300763*uk_164 + 3969*uk_17 + 378*uk_18 + 4221*uk_19 + 63*uk_2 + 1764*uk_20 + 7119*uk_21 + 13671*uk_22 + 4221*uk_23 + 36*uk_24 + 402*uk_25 + 168*uk_26 + 678*uk_27 + 1302*uk_28 + 402*uk_29 + 6*uk_3 + 4489*uk_30 + 1876*uk_31 + 7571*uk_32 + 14539*uk_33 + 4489*uk_34 + 784*uk_35 + 3164*uk_36 + 6076*uk_37 + 1876*uk_38 + 12769*uk_39 + 67*uk_4 + 24521*uk_40 + 7571*uk_41 + 47089*uk_42 + 14539*uk_43 + 4489*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 13453839654*uk_47 + 150234542803*uk_48 + 62784585052*uk_49 + 28*uk_5 + 253380646817*uk_50 + 486580534153*uk_51 + 150234542803*uk_52 + 187944057*uk_53 + 17899434*uk_54 + 199877013*uk_55 + 83530692*uk_56 + 337106007*uk_57 + 647362863*uk_58 + 199877013*uk_59 + 113*uk_6 + 1704708*uk_60 + 19035906*uk_61 + 7955304*uk_62 + 32105334*uk_63 + 61653606*uk_64 + 19035906*uk_65 + 212567617*uk_66 + 88834228*uk_67 + 358509563*uk_68 + 688465267*uk_69 + 217*uk_7 + 212567617*uk_70 + 37124752*uk_71 + 149824892*uk_72 + 287716828*uk_73 + 88834228*uk_74 + 604650457*uk_75 + 1161142913*uk_76 + 358509563*uk_77 + 2229805417*uk_78 + 688465267*uk_79 + 67*uk_8 + 212567617*uk_80 + 250047*uk_81 + 23814*uk_82 + 265923*uk_83 + 111132*uk_84 + 448497*uk_85 + 861273*uk_86 + 265923*uk_87 + 2268*uk_88 + 25326*uk_89 + 2242306609*uk_9 + 10584*uk_90 + 42714*uk_91 + 82026*uk_92 + 25326*uk_93 + 282807*uk_94 + 118188*uk_95 + 476973*uk_96 + 915957*uk_97 + 282807*uk_98 + 49392*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 231840*uk_100 + 437472*uk_101 + 12096*uk_102 + 833175*uk_103 + 1572165*uk_104 + 43470*uk_105 + 2966607*uk_106 + 82026*uk_107 + 2268*uk_108 + 681472*uk_109 + 4167064*uk_11 + 46464*uk_110 + 247808*uk_111 + 890560*uk_112 + 1680448*uk_113 + 46464*uk_114 + 3168*uk_115 + 16896*uk_116 + 60720*uk_117 + 114576*uk_118 + 3168*uk_119 + 284118*uk_12 + 90112*uk_120 + 323840*uk_121 + 611072*uk_122 + 16896*uk_123 + 1163800*uk_124 + 2196040*uk_125 + 60720*uk_126 + 4143832*uk_127 + 114576*uk_128 + 3168*uk_129 + 1515296*uk_13 + 216*uk_130 + 1152*uk_131 + 4140*uk_132 + 7812*uk_133 + 216*uk_134 + 6144*uk_135 + 22080*uk_136 + 41664*uk_137 + 1152*uk_138 + 79350*uk_139 + 5445595*uk_14 + 149730*uk_140 + 4140*uk_141 + 282534*uk_142 + 7812*uk_143 + 216*uk_144 + 32768*uk_145 + 117760*uk_146 + 222208*uk_147 + 6144*uk_148 + 423200*uk_149 + 10275601*uk_15 + 798560*uk_150 + 22080*uk_151 + 1506848*uk_152 + 41664*uk_153 + 1152*uk_154 + 1520875*uk_155 + 2869825*uk_156 + 79350*uk_157 + 5415235*uk_158 + 149730*uk_159 + 284118*uk_16 + 4140*uk_160 + 10218313*uk_161 + 282534*uk_162 + 7812*uk_163 + 216*uk_164 + 3969*uk_17 + 5544*uk_18 + 378*uk_19 + 63*uk_2 + 2016*uk_20 + 7245*uk_21 + 13671*uk_22 + 378*uk_23 + 7744*uk_24 + 528*uk_25 + 2816*uk_26 + 10120*uk_27 + 19096*uk_28 + 528*uk_29 + 88*uk_3 + 36*uk_30 + 192*uk_31 + 690*uk_32 + 1302*uk_33 + 36*uk_34 + 1024*uk_35 + 3680*uk_36 + 6944*uk_37 + 192*uk_38 + 13225*uk_39 + 6*uk_4 + 24955*uk_40 + 690*uk_41 + 47089*uk_42 + 1302*uk_43 + 36*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 197322981592*uk_47 + 13453839654*uk_48 + 71753811488*uk_49 + 32*uk_5 + 257865260035*uk_50 + 486580534153*uk_51 + 13453839654*uk_52 + 187944057*uk_53 + 262525032*uk_54 + 17899434*uk_55 + 95463648*uk_56 + 343072485*uk_57 + 647362863*uk_58 + 17899434*uk_59 + 115*uk_6 + 366701632*uk_60 + 25002384*uk_61 + 133346048*uk_62 + 479212360*uk_63 + 904252888*uk_64 + 25002384*uk_65 + 1704708*uk_66 + 9091776*uk_67 + 32673570*uk_68 + 61653606*uk_69 + 217*uk_7 + 1704708*uk_70 + 48489472*uk_71 + 174259040*uk_72 + 328819232*uk_73 + 9091776*uk_74 + 626243425*uk_75 + 1181694115*uk_76 + 32673570*uk_77 + 2229805417*uk_78 + 61653606*uk_79 + 6*uk_8 + 1704708*uk_80 + 250047*uk_81 + 349272*uk_82 + 23814*uk_83 + 127008*uk_84 + 456435*uk_85 + 861273*uk_86 + 23814*uk_87 + 487872*uk_88 + 33264*uk_89 + 2242306609*uk_9 + 177408*uk_90 + 637560*uk_91 + 1203048*uk_92 + 33264*uk_93 + 2268*uk_94 + 12096*uk_95 + 43470*uk_96 + 82026*uk_97 + 2268*uk_98 + 64512*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 206388*uk_100 + 382788*uk_101 + 155232*uk_102 + 862407*uk_103 + 1599507*uk_104 + 648648*uk_105 + 2966607*uk_106 + 1203048*uk_107 + 487872*uk_108 + 614125*uk_109 + 4025005*uk_11 + 635800*uk_110 + 202300*uk_111 + 845325*uk_112 + 1567825*uk_113 + 635800*uk_114 + 658240*uk_115 + 209440*uk_116 + 875160*uk_117 + 1623160*uk_118 + 658240*uk_119 + 4167064*uk_12 + 66640*uk_120 + 278460*uk_121 + 516460*uk_122 + 209440*uk_123 + 1163565*uk_124 + 2158065*uk_125 + 875160*uk_126 + 4002565*uk_127 + 1623160*uk_128 + 658240*uk_129 + 1325884*uk_13 + 681472*uk_130 + 216832*uk_131 + 906048*uk_132 + 1680448*uk_133 + 681472*uk_134 + 68992*uk_135 + 288288*uk_136 + 534688*uk_137 + 216832*uk_138 + 1204632*uk_139 + 5540301*uk_14 + 2234232*uk_140 + 906048*uk_141 + 4143832*uk_142 + 1680448*uk_143 + 681472*uk_144 + 21952*uk_145 + 91728*uk_146 + 170128*uk_147 + 68992*uk_148 + 383292*uk_149 + 10275601*uk_15 + 710892*uk_150 + 288288*uk_151 + 1318492*uk_152 + 534688*uk_153 + 216832*uk_154 + 1601613*uk_155 + 2970513*uk_156 + 1204632*uk_157 + 5509413*uk_158 + 2234232*uk_159 + 4167064*uk_16 + 906048*uk_160 + 10218313*uk_161 + 4143832*uk_162 + 1680448*uk_163 + 681472*uk_164 + 3969*uk_17 + 5355*uk_18 + 5544*uk_19 + 63*uk_2 + 1764*uk_20 + 7371*uk_21 + 13671*uk_22 + 5544*uk_23 + 7225*uk_24 + 7480*uk_25 + 2380*uk_26 + 9945*uk_27 + 18445*uk_28 + 7480*uk_29 + 85*uk_3 + 7744*uk_30 + 2464*uk_31 + 10296*uk_32 + 19096*uk_33 + 7744*uk_34 + 784*uk_35 + 3276*uk_36 + 6076*uk_37 + 2464*uk_38 + 13689*uk_39 + 88*uk_4 + 25389*uk_40 + 10296*uk_41 + 47089*uk_42 + 19096*uk_43 + 7744*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 190596061765*uk_47 + 197322981592*uk_48 + 62784585052*uk_49 + 28*uk_5 + 262349873253*uk_50 + 486580534153*uk_51 + 197322981592*uk_52 + 187944057*uk_53 + 253575315*uk_54 + 262525032*uk_55 + 83530692*uk_56 + 349038963*uk_57 + 647362863*uk_58 + 262525032*uk_59 + 117*uk_6 + 342125425*uk_60 + 354200440*uk_61 + 112700140*uk_62 + 470925585*uk_63 + 873426085*uk_64 + 354200440*uk_65 + 366701632*uk_66 + 116677792*uk_67 + 487546488*uk_68 + 904252888*uk_69 + 217*uk_7 + 366701632*uk_70 + 37124752*uk_71 + 155128428*uk_72 + 287716828*uk_73 + 116677792*uk_74 + 648215217*uk_75 + 1202245317*uk_76 + 487546488*uk_77 + 2229805417*uk_78 + 904252888*uk_79 + 88*uk_8 + 366701632*uk_80 + 250047*uk_81 + 337365*uk_82 + 349272*uk_83 + 111132*uk_84 + 464373*uk_85 + 861273*uk_86 + 349272*uk_87 + 455175*uk_88 + 471240*uk_89 + 2242306609*uk_9 + 149940*uk_90 + 626535*uk_91 + 1162035*uk_92 + 471240*uk_93 + 487872*uk_94 + 155232*uk_95 + 648648*uk_96 + 1203048*uk_97 + 487872*uk_98 + 49392*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 209916*uk_100 + 382788*uk_101 + 149940*uk_102 + 892143*uk_103 + 1626849*uk_104 + 637245*uk_105 + 2966607*uk_106 + 1162035*uk_107 + 455175*uk_108 + 1331000*uk_109 + 5208830*uk_11 + 1028500*uk_110 + 338800*uk_111 + 1439900*uk_112 + 2625700*uk_113 + 1028500*uk_114 + 794750*uk_115 + 261800*uk_116 + 1112650*uk_117 + 2028950*uk_118 + 794750*uk_119 + 4025005*uk_12 + 86240*uk_120 + 366520*uk_121 + 668360*uk_122 + 261800*uk_123 + 1557710*uk_124 + 2840530*uk_125 + 1112650*uk_126 + 5179790*uk_127 + 2028950*uk_128 + 794750*uk_129 + 1325884*uk_13 + 614125*uk_130 + 202300*uk_131 + 859775*uk_132 + 1567825*uk_133 + 614125*uk_134 + 66640*uk_135 + 283220*uk_136 + 516460*uk_137 + 202300*uk_138 + 1203685*uk_139 + 5635007*uk_14 + 2194955*uk_140 + 859775*uk_141 + 4002565*uk_142 + 1567825*uk_143 + 614125*uk_144 + 21952*uk_145 + 93296*uk_146 + 170128*uk_147 + 66640*uk_148 + 396508*uk_149 + 10275601*uk_15 + 723044*uk_150 + 283220*uk_151 + 1318492*uk_152 + 516460*uk_153 + 202300*uk_154 + 1685159*uk_155 + 3072937*uk_156 + 1203685*uk_157 + 5603591*uk_158 + 2194955*uk_159 + 4025005*uk_16 + 859775*uk_160 + 10218313*uk_161 + 4002565*uk_162 + 1567825*uk_163 + 614125*uk_164 + 3969*uk_17 + 6930*uk_18 + 5355*uk_19 + 63*uk_2 + 1764*uk_20 + 7497*uk_21 + 13671*uk_22 + 5355*uk_23 + 12100*uk_24 + 9350*uk_25 + 3080*uk_26 + 13090*uk_27 + 23870*uk_28 + 9350*uk_29 + 110*uk_3 + 7225*uk_30 + 2380*uk_31 + 10115*uk_32 + 18445*uk_33 + 7225*uk_34 + 784*uk_35 + 3332*uk_36 + 6076*uk_37 + 2380*uk_38 + 14161*uk_39 + 85*uk_4 + 25823*uk_40 + 10115*uk_41 + 47089*uk_42 + 18445*uk_43 + 7225*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 246653726990*uk_47 + 190596061765*uk_48 + 62784585052*uk_49 + 28*uk_5 + 266834486471*uk_50 + 486580534153*uk_51 + 190596061765*uk_52 + 187944057*uk_53 + 328156290*uk_54 + 253575315*uk_55 + 83530692*uk_56 + 355005441*uk_57 + 647362863*uk_58 + 253575315*uk_59 + 119*uk_6 + 572971300*uk_60 + 442750550*uk_61 + 145847240*uk_62 + 619850770*uk_63 + 1130316110*uk_64 + 442750550*uk_65 + 342125425*uk_66 + 112700140*uk_67 + 478975595*uk_68 + 873426085*uk_69 + 217*uk_7 + 342125425*uk_70 + 37124752*uk_71 + 157780196*uk_72 + 287716828*uk_73 + 112700140*uk_74 + 670565833*uk_75 + 1222796519*uk_76 + 478975595*uk_77 + 2229805417*uk_78 + 873426085*uk_79 + 85*uk_8 + 342125425*uk_80 + 250047*uk_81 + 436590*uk_82 + 337365*uk_83 + 111132*uk_84 + 472311*uk_85 + 861273*uk_86 + 337365*uk_87 + 762300*uk_88 + 589050*uk_89 + 2242306609*uk_9 + 194040*uk_90 + 824670*uk_91 + 1503810*uk_92 + 589050*uk_93 + 455175*uk_94 + 149940*uk_95 + 637245*uk_96 + 1162035*uk_97 + 455175*uk_98 + 49392*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 182952*uk_100 + 328104*uk_101 + 166320*uk_102 + 922383*uk_103 + 1654191*uk_104 + 838530*uk_105 + 2966607*uk_106 + 1503810*uk_107 + 762300*uk_108 + 74088*uk_109 + 1988826*uk_11 + 194040*uk_110 + 42336*uk_111 + 213444*uk_112 + 382788*uk_113 + 194040*uk_114 + 508200*uk_115 + 110880*uk_116 + 559020*uk_117 + 1002540*uk_118 + 508200*uk_119 + 5208830*uk_12 + 24192*uk_120 + 121968*uk_121 + 218736*uk_122 + 110880*uk_123 + 614922*uk_124 + 1102794*uk_125 + 559020*uk_126 + 1977738*uk_127 + 1002540*uk_128 + 508200*uk_129 + 1136472*uk_13 + 1331000*uk_130 + 290400*uk_131 + 1464100*uk_132 + 2625700*uk_133 + 1331000*uk_134 + 63360*uk_135 + 319440*uk_136 + 572880*uk_137 + 290400*uk_138 + 1610510*uk_139 + 5729713*uk_14 + 2888270*uk_140 + 1464100*uk_141 + 5179790*uk_142 + 2625700*uk_143 + 1331000*uk_144 + 13824*uk_145 + 69696*uk_146 + 124992*uk_147 + 63360*uk_148 + 351384*uk_149 + 10275601*uk_15 + 630168*uk_150 + 319440*uk_151 + 1130136*uk_152 + 572880*uk_153 + 290400*uk_154 + 1771561*uk_155 + 3177097*uk_156 + 1610510*uk_157 + 5697769*uk_158 + 2888270*uk_159 + 5208830*uk_16 + 1464100*uk_160 + 10218313*uk_161 + 5179790*uk_162 + 2625700*uk_163 + 1331000*uk_164 + 3969*uk_17 + 2646*uk_18 + 6930*uk_19 + 63*uk_2 + 1512*uk_20 + 7623*uk_21 + 13671*uk_22 + 6930*uk_23 + 1764*uk_24 + 4620*uk_25 + 1008*uk_26 + 5082*uk_27 + 9114*uk_28 + 4620*uk_29 + 42*uk_3 + 12100*uk_30 + 2640*uk_31 + 13310*uk_32 + 23870*uk_33 + 12100*uk_34 + 576*uk_35 + 2904*uk_36 + 5208*uk_37 + 2640*uk_38 + 14641*uk_39 + 110*uk_4 + 26257*uk_40 + 13310*uk_41 + 47089*uk_42 + 23870*uk_43 + 12100*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 94176877578*uk_47 + 246653726990*uk_48 + 53815358616*uk_49 + 24*uk_5 + 271319099689*uk_50 + 486580534153*uk_51 + 246653726990*uk_52 + 187944057*uk_53 + 125296038*uk_54 + 328156290*uk_55 + 71597736*uk_56 + 360971919*uk_57 + 647362863*uk_58 + 328156290*uk_59 + 121*uk_6 + 83530692*uk_60 + 218770860*uk_61 + 47731824*uk_62 + 240647946*uk_63 + 431575242*uk_64 + 218770860*uk_65 + 572971300*uk_66 + 125011920*uk_67 + 630268430*uk_68 + 1130316110*uk_69 + 217*uk_7 + 572971300*uk_70 + 27275328*uk_71 + 137513112*uk_72 + 246614424*uk_73 + 125011920*uk_74 + 693295273*uk_75 + 1243347721*uk_76 + 630268430*uk_77 + 2229805417*uk_78 + 1130316110*uk_79 + 110*uk_8 + 572971300*uk_80 + 250047*uk_81 + 166698*uk_82 + 436590*uk_83 + 95256*uk_84 + 480249*uk_85 + 861273*uk_86 + 436590*uk_87 + 111132*uk_88 + 291060*uk_89 + 2242306609*uk_9 + 63504*uk_90 + 320166*uk_91 + 574182*uk_92 + 291060*uk_93 + 762300*uk_94 + 166320*uk_95 + 838530*uk_96 + 1503810*uk_97 + 762300*uk_98 + 36288*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 216972*uk_100 + 382788*uk_101 + 74088*uk_102 + 953127*uk_103 + 1681533*uk_104 + 325458*uk_105 + 2966607*uk_106 + 574182*uk_107 + 111132*uk_108 + 1771561*uk_109 + 5729713*uk_11 + 614922*uk_110 + 409948*uk_111 + 1800843*uk_112 + 3177097*uk_113 + 614922*uk_114 + 213444*uk_115 + 142296*uk_116 + 625086*uk_117 + 1102794*uk_118 + 213444*uk_119 + 1988826*uk_12 + 94864*uk_120 + 416724*uk_121 + 735196*uk_122 + 142296*uk_123 + 1830609*uk_124 + 3229611*uk_125 + 625086*uk_126 + 5697769*uk_127 + 1102794*uk_128 + 213444*uk_129 + 1325884*uk_13 + 74088*uk_130 + 49392*uk_131 + 216972*uk_132 + 382788*uk_133 + 74088*uk_134 + 32928*uk_135 + 144648*uk_136 + 255192*uk_137 + 49392*uk_138 + 635418*uk_139 + 5824419*uk_14 + 1121022*uk_140 + 216972*uk_141 + 1977738*uk_142 + 382788*uk_143 + 74088*uk_144 + 21952*uk_145 + 96432*uk_146 + 170128*uk_147 + 32928*uk_148 + 423612*uk_149 + 10275601*uk_15 + 747348*uk_150 + 144648*uk_151 + 1318492*uk_152 + 255192*uk_153 + 49392*uk_154 + 1860867*uk_155 + 3282993*uk_156 + 635418*uk_157 + 5791947*uk_158 + 1121022*uk_159 + 1988826*uk_16 + 216972*uk_160 + 10218313*uk_161 + 1977738*uk_162 + 382788*uk_163 + 74088*uk_164 + 3969*uk_17 + 7623*uk_18 + 2646*uk_19 + 63*uk_2 + 1764*uk_20 + 7749*uk_21 + 13671*uk_22 + 2646*uk_23 + 14641*uk_24 + 5082*uk_25 + 3388*uk_26 + 14883*uk_27 + 26257*uk_28 + 5082*uk_29 + 121*uk_3 + 1764*uk_30 + 1176*uk_31 + 5166*uk_32 + 9114*uk_33 + 1764*uk_34 + 784*uk_35 + 3444*uk_36 + 6076*uk_37 + 1176*uk_38 + 15129*uk_39 + 42*uk_4 + 26691*uk_40 + 5166*uk_41 + 47089*uk_42 + 9114*uk_43 + 1764*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 271319099689*uk_47 + 94176877578*uk_48 + 62784585052*uk_49 + 28*uk_5 + 275803712907*uk_50 + 486580534153*uk_51 + 94176877578*uk_52 + 187944057*uk_53 + 360971919*uk_54 + 125296038*uk_55 + 83530692*uk_56 + 366938397*uk_57 + 647362863*uk_58 + 125296038*uk_59 + 123*uk_6 + 693295273*uk_60 + 240647946*uk_61 + 160431964*uk_62 + 704754699*uk_63 + 1243347721*uk_64 + 240647946*uk_65 + 83530692*uk_66 + 55687128*uk_67 + 244625598*uk_68 + 431575242*uk_69 + 217*uk_7 + 83530692*uk_70 + 37124752*uk_71 + 163083732*uk_72 + 287716828*uk_73 + 55687128*uk_74 + 716403537*uk_75 + 1263898923*uk_76 + 244625598*uk_77 + 2229805417*uk_78 + 431575242*uk_79 + 42*uk_8 + 83530692*uk_80 + 250047*uk_81 + 480249*uk_82 + 166698*uk_83 + 111132*uk_84 + 488187*uk_85 + 861273*uk_86 + 166698*uk_87 + 922383*uk_88 + 320166*uk_89 + 2242306609*uk_9 + 213444*uk_90 + 937629*uk_91 + 1654191*uk_92 + 320166*uk_93 + 111132*uk_94 + 74088*uk_95 + 325458*uk_96 + 574182*uk_97 + 111132*uk_98 + 49392*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 189000*uk_100 + 328104*uk_101 + 182952*uk_102 + 984375*uk_103 + 1708875*uk_104 + 952875*uk_105 + 2966607*uk_106 + 1654191*uk_107 + 922383*uk_108 + 1092727*uk_109 + 4877359*uk_11 + 1283689*uk_110 + 254616*uk_111 + 1326125*uk_112 + 2302153*uk_113 + 1283689*uk_114 + 1508023*uk_115 + 299112*uk_116 + 1557875*uk_117 + 2704471*uk_118 + 1508023*uk_119 + 5729713*uk_12 + 59328*uk_120 + 309000*uk_121 + 536424*uk_122 + 299112*uk_123 + 1609375*uk_124 + 2793875*uk_125 + 1557875*uk_126 + 4850167*uk_127 + 2704471*uk_128 + 1508023*uk_129 + 1136472*uk_13 + 1771561*uk_130 + 351384*uk_131 + 1830125*uk_132 + 3177097*uk_133 + 1771561*uk_134 + 69696*uk_135 + 363000*uk_136 + 630168*uk_137 + 351384*uk_138 + 1890625*uk_139 + 5919125*uk_14 + 3282125*uk_140 + 1830125*uk_141 + 5697769*uk_142 + 3177097*uk_143 + 1771561*uk_144 + 13824*uk_145 + 72000*uk_146 + 124992*uk_147 + 69696*uk_148 + 375000*uk_149 + 10275601*uk_15 + 651000*uk_150 + 363000*uk_151 + 1130136*uk_152 + 630168*uk_153 + 351384*uk_154 + 1953125*uk_155 + 3390625*uk_156 + 1890625*uk_157 + 5886125*uk_158 + 3282125*uk_159 + 5729713*uk_16 + 1830125*uk_160 + 10218313*uk_161 + 5697769*uk_162 + 3177097*uk_163 + 1771561*uk_164 + 3969*uk_17 + 6489*uk_18 + 7623*uk_19 + 63*uk_2 + 1512*uk_20 + 7875*uk_21 + 13671*uk_22 + 7623*uk_23 + 10609*uk_24 + 12463*uk_25 + 2472*uk_26 + 12875*uk_27 + 22351*uk_28 + 12463*uk_29 + 103*uk_3 + 14641*uk_30 + 2904*uk_31 + 15125*uk_32 + 26257*uk_33 + 14641*uk_34 + 576*uk_35 + 3000*uk_36 + 5208*uk_37 + 2904*uk_38 + 15625*uk_39 + 121*uk_4 + 27125*uk_40 + 15125*uk_41 + 47089*uk_42 + 26257*uk_43 + 14641*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 230957580727*uk_47 + 271319099689*uk_48 + 53815358616*uk_49 + 24*uk_5 + 280288326125*uk_50 + 486580534153*uk_51 + 271319099689*uk_52 + 187944057*uk_53 + 307273617*uk_54 + 360971919*uk_55 + 71597736*uk_56 + 372904875*uk_57 + 647362863*uk_58 + 360971919*uk_59 + 125*uk_6 + 502367977*uk_60 + 590160439*uk_61 + 117056616*uk_62 + 609669875*uk_63 + 1058386903*uk_64 + 590160439*uk_65 + 693295273*uk_66 + 137513112*uk_67 + 716214125*uk_68 + 1243347721*uk_69 + 217*uk_7 + 693295273*uk_70 + 27275328*uk_71 + 142059000*uk_72 + 246614424*uk_73 + 137513112*uk_74 + 739890625*uk_75 + 1284450125*uk_76 + 716214125*uk_77 + 2229805417*uk_78 + 1243347721*uk_79 + 121*uk_8 + 693295273*uk_80 + 250047*uk_81 + 408807*uk_82 + 480249*uk_83 + 95256*uk_84 + 496125*uk_85 + 861273*uk_86 + 480249*uk_87 + 668367*uk_88 + 785169*uk_89 + 2242306609*uk_9 + 155736*uk_90 + 811125*uk_91 + 1408113*uk_92 + 785169*uk_93 + 922383*uk_94 + 182952*uk_95 + 952875*uk_96 + 1654191*uk_97 + 922383*uk_98 + 36288*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 192024*uk_100 + 328104*uk_101 + 155736*uk_102 + 1016127*uk_103 + 1736217*uk_104 + 824103*uk_105 + 2966607*uk_106 + 1408113*uk_107 + 668367*uk_108 + 1295029*uk_109 + 5161477*uk_11 + 1223743*uk_110 + 285144*uk_111 + 1508887*uk_112 + 2578177*uk_113 + 1223743*uk_114 + 1156381*uk_115 + 269448*uk_116 + 1425829*uk_117 + 2436259*uk_118 + 1156381*uk_119 + 4877359*uk_12 + 62784*uk_120 + 332232*uk_121 + 567672*uk_122 + 269448*uk_123 + 1758061*uk_124 + 3003931*uk_125 + 1425829*uk_126 + 5132701*uk_127 + 2436259*uk_128 + 1156381*uk_129 + 1136472*uk_13 + 1092727*uk_130 + 254616*uk_131 + 1347343*uk_132 + 2302153*uk_133 + 1092727*uk_134 + 59328*uk_135 + 313944*uk_136 + 536424*uk_137 + 254616*uk_138 + 1661287*uk_139 + 6013831*uk_14 + 2838577*uk_140 + 1347343*uk_141 + 4850167*uk_142 + 2302153*uk_143 + 1092727*uk_144 + 13824*uk_145 + 73152*uk_146 + 124992*uk_147 + 59328*uk_148 + 387096*uk_149 + 10275601*uk_15 + 661416*uk_150 + 313944*uk_151 + 1130136*uk_152 + 536424*uk_153 + 254616*uk_154 + 2048383*uk_155 + 3499993*uk_156 + 1661287*uk_157 + 5980303*uk_158 + 2838577*uk_159 + 4877359*uk_16 + 1347343*uk_160 + 10218313*uk_161 + 4850167*uk_162 + 2302153*uk_163 + 1092727*uk_164 + 3969*uk_17 + 6867*uk_18 + 6489*uk_19 + 63*uk_2 + 1512*uk_20 + 8001*uk_21 + 13671*uk_22 + 6489*uk_23 + 11881*uk_24 + 11227*uk_25 + 2616*uk_26 + 13843*uk_27 + 23653*uk_28 + 11227*uk_29 + 109*uk_3 + 10609*uk_30 + 2472*uk_31 + 13081*uk_32 + 22351*uk_33 + 10609*uk_34 + 576*uk_35 + 3048*uk_36 + 5208*uk_37 + 2472*uk_38 + 16129*uk_39 + 103*uk_4 + 27559*uk_40 + 13081*uk_41 + 47089*uk_42 + 22351*uk_43 + 10609*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 244411420381*uk_47 + 230957580727*uk_48 + 53815358616*uk_49 + 24*uk_5 + 284772939343*uk_50 + 486580534153*uk_51 + 230957580727*uk_52 + 187944057*uk_53 + 325173051*uk_54 + 307273617*uk_55 + 71597736*uk_56 + 378871353*uk_57 + 647362863*uk_58 + 307273617*uk_59 + 127*uk_6 + 562600993*uk_60 + 531632131*uk_61 + 123875448*uk_62 + 655507579*uk_63 + 1120040509*uk_64 + 531632131*uk_65 + 502367977*uk_66 + 117056616*uk_67 + 619424593*uk_68 + 1058386903*uk_69 + 217*uk_7 + 502367977*uk_70 + 27275328*uk_71 + 144331944*uk_72 + 246614424*uk_73 + 117056616*uk_74 + 763756537*uk_75 + 1305001327*uk_76 + 619424593*uk_77 + 2229805417*uk_78 + 1058386903*uk_79 + 103*uk_8 + 502367977*uk_80 + 250047*uk_81 + 432621*uk_82 + 408807*uk_83 + 95256*uk_84 + 504063*uk_85 + 861273*uk_86 + 408807*uk_87 + 748503*uk_88 + 707301*uk_89 + 2242306609*uk_9 + 164808*uk_90 + 872109*uk_91 + 1490139*uk_92 + 707301*uk_93 + 668367*uk_94 + 155736*uk_95 + 824103*uk_96 + 1408113*uk_97 + 668367*uk_98 + 36288*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 162540*uk_100 + 273420*uk_101 + 137340*uk_102 + 1048383*uk_103 + 1763559*uk_104 + 885843*uk_105 + 2966607*uk_106 + 1490139*uk_107 + 748503*uk_108 + 1000*uk_109 + 473530*uk_11 + 10900*uk_110 + 2000*uk_111 + 12900*uk_112 + 21700*uk_113 + 10900*uk_114 + 118810*uk_115 + 21800*uk_116 + 140610*uk_117 + 236530*uk_118 + 118810*uk_119 + 5161477*uk_12 + 4000*uk_120 + 25800*uk_121 + 43400*uk_122 + 21800*uk_123 + 166410*uk_124 + 279930*uk_125 + 140610*uk_126 + 470890*uk_127 + 236530*uk_128 + 118810*uk_129 + 947060*uk_13 + 1295029*uk_130 + 237620*uk_131 + 1532649*uk_132 + 2578177*uk_133 + 1295029*uk_134 + 43600*uk_135 + 281220*uk_136 + 473060*uk_137 + 237620*uk_138 + 1813869*uk_139 + 6108537*uk_14 + 3051237*uk_140 + 1532649*uk_141 + 5132701*uk_142 + 2578177*uk_143 + 1295029*uk_144 + 8000*uk_145 + 51600*uk_146 + 86800*uk_147 + 43600*uk_148 + 332820*uk_149 + 10275601*uk_15 + 559860*uk_150 + 281220*uk_151 + 941780*uk_152 + 473060*uk_153 + 237620*uk_154 + 2146689*uk_155 + 3611097*uk_156 + 1813869*uk_157 + 6074481*uk_158 + 3051237*uk_159 + 5161477*uk_16 + 1532649*uk_160 + 10218313*uk_161 + 5132701*uk_162 + 2578177*uk_163 + 1295029*uk_164 + 3969*uk_17 + 630*uk_18 + 6867*uk_19 + 63*uk_2 + 1260*uk_20 + 8127*uk_21 + 13671*uk_22 + 6867*uk_23 + 100*uk_24 + 1090*uk_25 + 200*uk_26 + 1290*uk_27 + 2170*uk_28 + 1090*uk_29 + 10*uk_3 + 11881*uk_30 + 2180*uk_31 + 14061*uk_32 + 23653*uk_33 + 11881*uk_34 + 400*uk_35 + 2580*uk_36 + 4340*uk_37 + 2180*uk_38 + 16641*uk_39 + 109*uk_4 + 27993*uk_40 + 14061*uk_41 + 47089*uk_42 + 23653*uk_43 + 11881*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 22423066090*uk_47 + 244411420381*uk_48 + 44846132180*uk_49 + 20*uk_5 + 289257552561*uk_50 + 486580534153*uk_51 + 244411420381*uk_52 + 187944057*uk_53 + 29832390*uk_54 + 325173051*uk_55 + 59664780*uk_56 + 384837831*uk_57 + 647362863*uk_58 + 325173051*uk_59 + 129*uk_6 + 4735300*uk_60 + 51614770*uk_61 + 9470600*uk_62 + 61085370*uk_63 + 102756010*uk_64 + 51614770*uk_65 + 562600993*uk_66 + 103229540*uk_67 + 665830533*uk_68 + 1120040509*uk_69 + 217*uk_7 + 562600993*uk_70 + 18941200*uk_71 + 122170740*uk_72 + 205512020*uk_73 + 103229540*uk_74 + 788001273*uk_75 + 1325552529*uk_76 + 665830533*uk_77 + 2229805417*uk_78 + 1120040509*uk_79 + 109*uk_8 + 562600993*uk_80 + 250047*uk_81 + 39690*uk_82 + 432621*uk_83 + 79380*uk_84 + 512001*uk_85 + 861273*uk_86 + 432621*uk_87 + 6300*uk_88 + 68670*uk_89 + 2242306609*uk_9 + 12600*uk_90 + 81270*uk_91 + 136710*uk_92 + 68670*uk_93 + 748503*uk_94 + 137340*uk_95 + 885843*uk_96 + 1490139*uk_97 + 748503*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 198072*uk_100 + 328104*uk_101 + 15120*uk_102 + 1081143*uk_103 + 1790901*uk_104 + 82530*uk_105 + 2966607*uk_106 + 136710*uk_107 + 6300*uk_108 + 238328*uk_109 + 2935886*uk_11 + 38440*uk_110 + 92256*uk_111 + 503564*uk_112 + 834148*uk_113 + 38440*uk_114 + 6200*uk_115 + 14880*uk_116 + 81220*uk_117 + 134540*uk_118 + 6200*uk_119 + 473530*uk_12 + 35712*uk_120 + 194928*uk_121 + 322896*uk_122 + 14880*uk_123 + 1063982*uk_124 + 1762474*uk_125 + 81220*uk_126 + 2919518*uk_127 + 134540*uk_128 + 6200*uk_129 + 1136472*uk_13 + 1000*uk_130 + 2400*uk_131 + 13100*uk_132 + 21700*uk_133 + 1000*uk_134 + 5760*uk_135 + 31440*uk_136 + 52080*uk_137 + 2400*uk_138 + 171610*uk_139 + 6203243*uk_14 + 284270*uk_140 + 13100*uk_141 + 470890*uk_142 + 21700*uk_143 + 1000*uk_144 + 13824*uk_145 + 75456*uk_146 + 124992*uk_147 + 5760*uk_148 + 411864*uk_149 + 10275601*uk_15 + 682248*uk_150 + 31440*uk_151 + 1130136*uk_152 + 52080*uk_153 + 2400*uk_154 + 2248091*uk_155 + 3723937*uk_156 + 171610*uk_157 + 6168659*uk_158 + 284270*uk_159 + 473530*uk_16 + 13100*uk_160 + 10218313*uk_161 + 470890*uk_162 + 21700*uk_163 + 1000*uk_164 + 3969*uk_17 + 3906*uk_18 + 630*uk_19 + 63*uk_2 + 1512*uk_20 + 8253*uk_21 + 13671*uk_22 + 630*uk_23 + 3844*uk_24 + 620*uk_25 + 1488*uk_26 + 8122*uk_27 + 13454*uk_28 + 620*uk_29 + 62*uk_3 + 100*uk_30 + 240*uk_31 + 1310*uk_32 + 2170*uk_33 + 100*uk_34 + 576*uk_35 + 3144*uk_36 + 5208*uk_37 + 240*uk_38 + 17161*uk_39 + 10*uk_4 + 28427*uk_40 + 1310*uk_41 + 47089*uk_42 + 2170*uk_43 + 100*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 139023009758*uk_47 + 22423066090*uk_48 + 53815358616*uk_49 + 24*uk_5 + 293742165779*uk_50 + 486580534153*uk_51 + 22423066090*uk_52 + 187944057*uk_53 + 184960818*uk_54 + 29832390*uk_55 + 71597736*uk_56 + 390804309*uk_57 + 647362863*uk_58 + 29832390*uk_59 + 131*uk_6 + 182024932*uk_60 + 29358860*uk_61 + 70461264*uk_62 + 384601066*uk_63 + 637087262*uk_64 + 29358860*uk_65 + 4735300*uk_66 + 11364720*uk_67 + 62032430*uk_68 + 102756010*uk_69 + 217*uk_7 + 4735300*uk_70 + 27275328*uk_71 + 148877832*uk_72 + 246614424*uk_73 + 11364720*uk_74 + 812624833*uk_75 + 1346103731*uk_76 + 62032430*uk_77 + 2229805417*uk_78 + 102756010*uk_79 + 10*uk_8 + 4735300*uk_80 + 250047*uk_81 + 246078*uk_82 + 39690*uk_83 + 95256*uk_84 + 519939*uk_85 + 861273*uk_86 + 39690*uk_87 + 242172*uk_88 + 39060*uk_89 + 2242306609*uk_9 + 93744*uk_90 + 511686*uk_91 + 847602*uk_92 + 39060*uk_93 + 6300*uk_94 + 15120*uk_95 + 82530*uk_96 + 136710*uk_97 + 6300*uk_98 + 36288*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 167580*uk_100 + 273420*uk_101 + 78120*uk_102 + 1114407*uk_103 + 1818243*uk_104 + 519498*uk_105 + 2966607*uk_106 + 847602*uk_107 + 242172*uk_108 + 125*uk_109 + 236765*uk_11 + 1550*uk_110 + 500*uk_111 + 3325*uk_112 + 5425*uk_113 + 1550*uk_114 + 19220*uk_115 + 6200*uk_116 + 41230*uk_117 + 67270*uk_118 + 19220*uk_119 + 2935886*uk_12 + 2000*uk_120 + 13300*uk_121 + 21700*uk_122 + 6200*uk_123 + 88445*uk_124 + 144305*uk_125 + 41230*uk_126 + 235445*uk_127 + 67270*uk_128 + 19220*uk_129 + 947060*uk_13 + 238328*uk_130 + 76880*uk_131 + 511252*uk_132 + 834148*uk_133 + 238328*uk_134 + 24800*uk_135 + 164920*uk_136 + 269080*uk_137 + 76880*uk_138 + 1096718*uk_139 + 6297949*uk_14 + 1789382*uk_140 + 511252*uk_141 + 2919518*uk_142 + 834148*uk_143 + 238328*uk_144 + 8000*uk_145 + 53200*uk_146 + 86800*uk_147 + 24800*uk_148 + 353780*uk_149 + 10275601*uk_15 + 577220*uk_150 + 164920*uk_151 + 941780*uk_152 + 269080*uk_153 + 76880*uk_154 + 2352637*uk_155 + 3838513*uk_156 + 1096718*uk_157 + 6262837*uk_158 + 1789382*uk_159 + 2935886*uk_16 + 511252*uk_160 + 10218313*uk_161 + 2919518*uk_162 + 834148*uk_163 + 238328*uk_164 + 3969*uk_17 + 315*uk_18 + 3906*uk_19 + 63*uk_2 + 1260*uk_20 + 8379*uk_21 + 13671*uk_22 + 3906*uk_23 + 25*uk_24 + 310*uk_25 + 100*uk_26 + 665*uk_27 + 1085*uk_28 + 310*uk_29 + 5*uk_3 + 3844*uk_30 + 1240*uk_31 + 8246*uk_32 + 13454*uk_33 + 3844*uk_34 + 400*uk_35 + 2660*uk_36 + 4340*uk_37 + 1240*uk_38 + 17689*uk_39 + 62*uk_4 + 28861*uk_40 + 8246*uk_41 + 47089*uk_42 + 13454*uk_43 + 3844*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 11211533045*uk_47 + 139023009758*uk_48 + 44846132180*uk_49 + 20*uk_5 + 298226778997*uk_50 + 486580534153*uk_51 + 139023009758*uk_52 + 187944057*uk_53 + 14916195*uk_54 + 184960818*uk_55 + 59664780*uk_56 + 396770787*uk_57 + 647362863*uk_58 + 184960818*uk_59 + 133*uk_6 + 1183825*uk_60 + 14679430*uk_61 + 4735300*uk_62 + 31489745*uk_63 + 51378005*uk_64 + 14679430*uk_65 + 182024932*uk_66 + 58717720*uk_67 + 390472838*uk_68 + 637087262*uk_69 + 217*uk_7 + 182024932*uk_70 + 18941200*uk_71 + 125958980*uk_72 + 205512020*uk_73 + 58717720*uk_74 + 837627217*uk_75 + 1366654933*uk_76 + 390472838*uk_77 + 2229805417*uk_78 + 637087262*uk_79 + 62*uk_8 + 182024932*uk_80 + 250047*uk_81 + 19845*uk_82 + 246078*uk_83 + 79380*uk_84 + 527877*uk_85 + 861273*uk_86 + 246078*uk_87 + 1575*uk_88 + 19530*uk_89 + 2242306609*uk_9 + 6300*uk_90 + 41895*uk_91 + 68355*uk_92 + 19530*uk_93 + 242172*uk_94 + 78120*uk_95 + 519498*uk_96 + 847602*uk_97 + 242172*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 204120*uk_100 + 328104*uk_101 + 7560*uk_102 + 1148175*uk_103 + 1845585*uk_104 + 42525*uk_105 + 2966607*uk_106 + 68355*uk_107 + 1575*uk_108 + 1092727*uk_109 + 4877359*uk_11 + 53045*uk_110 + 254616*uk_111 + 1432215*uk_112 + 2302153*uk_113 + 53045*uk_114 + 2575*uk_115 + 12360*uk_116 + 69525*uk_117 + 111755*uk_118 + 2575*uk_119 + 236765*uk_12 + 59328*uk_120 + 333720*uk_121 + 536424*uk_122 + 12360*uk_123 + 1877175*uk_124 + 3017385*uk_125 + 69525*uk_126 + 4850167*uk_127 + 111755*uk_128 + 2575*uk_129 + 1136472*uk_13 + 125*uk_130 + 600*uk_131 + 3375*uk_132 + 5425*uk_133 + 125*uk_134 + 2880*uk_135 + 16200*uk_136 + 26040*uk_137 + 600*uk_138 + 91125*uk_139 + 6392655*uk_14 + 146475*uk_140 + 3375*uk_141 + 235445*uk_142 + 5425*uk_143 + 125*uk_144 + 13824*uk_145 + 77760*uk_146 + 124992*uk_147 + 2880*uk_148 + 437400*uk_149 + 10275601*uk_15 + 703080*uk_150 + 16200*uk_151 + 1130136*uk_152 + 26040*uk_153 + 600*uk_154 + 2460375*uk_155 + 3954825*uk_156 + 91125*uk_157 + 6357015*uk_158 + 146475*uk_159 + 236765*uk_16 + 3375*uk_160 + 10218313*uk_161 + 235445*uk_162 + 5425*uk_163 + 125*uk_164 + 3969*uk_17 + 6489*uk_18 + 315*uk_19 + 63*uk_2 + 1512*uk_20 + 8505*uk_21 + 13671*uk_22 + 315*uk_23 + 10609*uk_24 + 515*uk_25 + 2472*uk_26 + 13905*uk_27 + 22351*uk_28 + 515*uk_29 + 103*uk_3 + 25*uk_30 + 120*uk_31 + 675*uk_32 + 1085*uk_33 + 25*uk_34 + 576*uk_35 + 3240*uk_36 + 5208*uk_37 + 120*uk_38 + 18225*uk_39 + 5*uk_4 + 29295*uk_40 + 675*uk_41 + 47089*uk_42 + 1085*uk_43 + 25*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 230957580727*uk_47 + 11211533045*uk_48 + 53815358616*uk_49 + 24*uk_5 + 302711392215*uk_50 + 486580534153*uk_51 + 11211533045*uk_52 + 187944057*uk_53 + 307273617*uk_54 + 14916195*uk_55 + 71597736*uk_56 + 402737265*uk_57 + 647362863*uk_58 + 14916195*uk_59 + 135*uk_6 + 502367977*uk_60 + 24386795*uk_61 + 117056616*uk_62 + 658443465*uk_63 + 1058386903*uk_64 + 24386795*uk_65 + 1183825*uk_66 + 5682360*uk_67 + 31963275*uk_68 + 51378005*uk_69 + 217*uk_7 + 1183825*uk_70 + 27275328*uk_71 + 153423720*uk_72 + 246614424*uk_73 + 5682360*uk_74 + 863008425*uk_75 + 1387206135*uk_76 + 31963275*uk_77 + 2229805417*uk_78 + 51378005*uk_79 + 5*uk_8 + 1183825*uk_80 + 250047*uk_81 + 408807*uk_82 + 19845*uk_83 + 95256*uk_84 + 535815*uk_85 + 861273*uk_86 + 19845*uk_87 + 668367*uk_88 + 32445*uk_89 + 2242306609*uk_9 + 155736*uk_90 + 876015*uk_91 + 1408113*uk_92 + 32445*uk_93 + 1575*uk_94 + 7560*uk_95 + 42525*uk_96 + 68355*uk_97 + 1575*uk_98 + 36288*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 172620*uk_100 + 273420*uk_101 + 129780*uk_102 + 1182447*uk_103 + 1872927*uk_104 + 888993*uk_105 + 2966607*uk_106 + 1408113*uk_107 + 668367*uk_108 + 681472*uk_109 + 4167064*uk_11 + 797632*uk_110 + 154880*uk_111 + 1060928*uk_112 + 1680448*uk_113 + 797632*uk_114 + 933592*uk_115 + 181280*uk_116 + 1241768*uk_117 + 1966888*uk_118 + 933592*uk_119 + 4877359*uk_12 + 35200*uk_120 + 241120*uk_121 + 381920*uk_122 + 181280*uk_123 + 1651672*uk_124 + 2616152*uk_125 + 1241768*uk_126 + 4143832*uk_127 + 1966888*uk_128 + 933592*uk_129 + 947060*uk_13 + 1092727*uk_130 + 212180*uk_131 + 1453433*uk_132 + 2302153*uk_133 + 1092727*uk_134 + 41200*uk_135 + 282220*uk_136 + 447020*uk_137 + 212180*uk_138 + 1933207*uk_139 + 6487361*uk_14 + 3062087*uk_140 + 1453433*uk_141 + 4850167*uk_142 + 2302153*uk_143 + 1092727*uk_144 + 8000*uk_145 + 54800*uk_146 + 86800*uk_147 + 41200*uk_148 + 375380*uk_149 + 10275601*uk_15 + 594580*uk_150 + 282220*uk_151 + 941780*uk_152 + 447020*uk_153 + 212180*uk_154 + 2571353*uk_155 + 4072873*uk_156 + 1933207*uk_157 + 6451193*uk_158 + 3062087*uk_159 + 4877359*uk_16 + 1453433*uk_160 + 10218313*uk_161 + 4850167*uk_162 + 2302153*uk_163 + 1092727*uk_164 + 3969*uk_17 + 5544*uk_18 + 6489*uk_19 + 63*uk_2 + 1260*uk_20 + 8631*uk_21 + 13671*uk_22 + 6489*uk_23 + 7744*uk_24 + 9064*uk_25 + 1760*uk_26 + 12056*uk_27 + 19096*uk_28 + 9064*uk_29 + 88*uk_3 + 10609*uk_30 + 2060*uk_31 + 14111*uk_32 + 22351*uk_33 + 10609*uk_34 + 400*uk_35 + 2740*uk_36 + 4340*uk_37 + 2060*uk_38 + 18769*uk_39 + 103*uk_4 + 29729*uk_40 + 14111*uk_41 + 47089*uk_42 + 22351*uk_43 + 10609*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 197322981592*uk_47 + 230957580727*uk_48 + 44846132180*uk_49 + 20*uk_5 + 307196005433*uk_50 + 486580534153*uk_51 + 230957580727*uk_52 + 187944057*uk_53 + 262525032*uk_54 + 307273617*uk_55 + 59664780*uk_56 + 408703743*uk_57 + 647362863*uk_58 + 307273617*uk_59 + 137*uk_6 + 366701632*uk_60 + 429207592*uk_61 + 83341280*uk_62 + 570887768*uk_63 + 904252888*uk_64 + 429207592*uk_65 + 502367977*uk_66 + 97547180*uk_67 + 668198183*uk_68 + 1058386903*uk_69 + 217*uk_7 + 502367977*uk_70 + 18941200*uk_71 + 129747220*uk_72 + 205512020*uk_73 + 97547180*uk_74 + 888768457*uk_75 + 1407757337*uk_76 + 668198183*uk_77 + 2229805417*uk_78 + 1058386903*uk_79 + 103*uk_8 + 502367977*uk_80 + 250047*uk_81 + 349272*uk_82 + 408807*uk_83 + 79380*uk_84 + 543753*uk_85 + 861273*uk_86 + 408807*uk_87 + 487872*uk_88 + 571032*uk_89 + 2242306609*uk_9 + 110880*uk_90 + 759528*uk_91 + 1203048*uk_92 + 571032*uk_93 + 668367*uk_94 + 129780*uk_95 + 888993*uk_96 + 1408113*uk_97 + 668367*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 175140*uk_100 + 273420*uk_101 + 110880*uk_102 + 1217223*uk_103 + 1900269*uk_104 + 770616*uk_105 + 2966607*uk_106 + 1203048*uk_107 + 487872*uk_108 + 804357*uk_109 + 4403829*uk_11 + 761112*uk_110 + 172980*uk_111 + 1202211*uk_112 + 1876833*uk_113 + 761112*uk_114 + 720192*uk_115 + 163680*uk_116 + 1137576*uk_117 + 1775928*uk_118 + 720192*uk_119 + 4167064*uk_12 + 37200*uk_120 + 258540*uk_121 + 403620*uk_122 + 163680*uk_123 + 1796853*uk_124 + 2805159*uk_125 + 1137576*uk_126 + 4379277*uk_127 + 1775928*uk_128 + 720192*uk_129 + 947060*uk_13 + 681472*uk_130 + 154880*uk_131 + 1076416*uk_132 + 1680448*uk_133 + 681472*uk_134 + 35200*uk_135 + 244640*uk_136 + 381920*uk_137 + 154880*uk_138 + 1700248*uk_139 + 6582067*uk_14 + 2654344*uk_140 + 1076416*uk_141 + 4143832*uk_142 + 1680448*uk_143 + 681472*uk_144 + 8000*uk_145 + 55600*uk_146 + 86800*uk_147 + 35200*uk_148 + 386420*uk_149 + 10275601*uk_15 + 603260*uk_150 + 244640*uk_151 + 941780*uk_152 + 381920*uk_153 + 154880*uk_154 + 2685619*uk_155 + 4192657*uk_156 + 1700248*uk_157 + 6545371*uk_158 + 2654344*uk_159 + 4167064*uk_16 + 1076416*uk_160 + 10218313*uk_161 + 4143832*uk_162 + 1680448*uk_163 + 681472*uk_164 + 3969*uk_17 + 5859*uk_18 + 5544*uk_19 + 63*uk_2 + 1260*uk_20 + 8757*uk_21 + 13671*uk_22 + 5544*uk_23 + 8649*uk_24 + 8184*uk_25 + 1860*uk_26 + 12927*uk_27 + 20181*uk_28 + 8184*uk_29 + 93*uk_3 + 7744*uk_30 + 1760*uk_31 + 12232*uk_32 + 19096*uk_33 + 7744*uk_34 + 400*uk_35 + 2780*uk_36 + 4340*uk_37 + 1760*uk_38 + 19321*uk_39 + 88*uk_4 + 30163*uk_40 + 12232*uk_41 + 47089*uk_42 + 19096*uk_43 + 7744*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 208534514637*uk_47 + 197322981592*uk_48 + 44846132180*uk_49 + 20*uk_5 + 311680618651*uk_50 + 486580534153*uk_51 + 197322981592*uk_52 + 187944057*uk_53 + 277441227*uk_54 + 262525032*uk_55 + 59664780*uk_56 + 414670221*uk_57 + 647362863*uk_58 + 262525032*uk_59 + 139*uk_6 + 409556097*uk_60 + 387536952*uk_61 + 88076580*uk_62 + 612132231*uk_63 + 955630893*uk_64 + 387536952*uk_65 + 366701632*uk_66 + 83341280*uk_67 + 579221896*uk_68 + 904252888*uk_69 + 217*uk_7 + 366701632*uk_70 + 18941200*uk_71 + 131641340*uk_72 + 205512020*uk_73 + 83341280*uk_74 + 914907313*uk_75 + 1428308539*uk_76 + 579221896*uk_77 + 2229805417*uk_78 + 904252888*uk_79 + 88*uk_8 + 366701632*uk_80 + 250047*uk_81 + 369117*uk_82 + 349272*uk_83 + 79380*uk_84 + 551691*uk_85 + 861273*uk_86 + 349272*uk_87 + 544887*uk_88 + 515592*uk_89 + 2242306609*uk_9 + 117180*uk_90 + 814401*uk_91 + 1271403*uk_92 + 515592*uk_93 + 487872*uk_94 + 110880*uk_95 + 770616*uk_96 + 1203048*uk_97 + 487872*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 177660*uk_100 + 273420*uk_101 + 117180*uk_102 + 1252503*uk_103 + 1927611*uk_104 + 826119*uk_105 + 2966607*uk_106 + 1271403*uk_107 + 544887*uk_108 + 1643032*uk_109 + 5587654*uk_11 + 1294932*uk_110 + 278480*uk_111 + 1963284*uk_112 + 3021508*uk_113 + 1294932*uk_114 + 1020582*uk_115 + 219480*uk_116 + 1547334*uk_117 + 2381358*uk_118 + 1020582*uk_119 + 4403829*uk_12 + 47200*uk_120 + 332760*uk_121 + 512120*uk_122 + 219480*uk_123 + 2345958*uk_124 + 3610446*uk_125 + 1547334*uk_126 + 5556502*uk_127 + 2381358*uk_128 + 1020582*uk_129 + 947060*uk_13 + 804357*uk_130 + 172980*uk_131 + 1219509*uk_132 + 1876833*uk_133 + 804357*uk_134 + 37200*uk_135 + 262260*uk_136 + 403620*uk_137 + 172980*uk_138 + 1848933*uk_139 + 6676773*uk_14 + 2845521*uk_140 + 1219509*uk_141 + 4379277*uk_142 + 1876833*uk_143 + 804357*uk_144 + 8000*uk_145 + 56400*uk_146 + 86800*uk_147 + 37200*uk_148 + 397620*uk_149 + 10275601*uk_15 + 611940*uk_150 + 262260*uk_151 + 941780*uk_152 + 403620*uk_153 + 172980*uk_154 + 2803221*uk_155 + 4314177*uk_156 + 1848933*uk_157 + 6639549*uk_158 + 2845521*uk_159 + 4403829*uk_16 + 1219509*uk_160 + 10218313*uk_161 + 4379277*uk_162 + 1876833*uk_163 + 804357*uk_164 + 3969*uk_17 + 7434*uk_18 + 5859*uk_19 + 63*uk_2 + 1260*uk_20 + 8883*uk_21 + 13671*uk_22 + 5859*uk_23 + 13924*uk_24 + 10974*uk_25 + 2360*uk_26 + 16638*uk_27 + 25606*uk_28 + 10974*uk_29 + 118*uk_3 + 8649*uk_30 + 1860*uk_31 + 13113*uk_32 + 20181*uk_33 + 8649*uk_34 + 400*uk_35 + 2820*uk_36 + 4340*uk_37 + 1860*uk_38 + 19881*uk_39 + 93*uk_4 + 30597*uk_40 + 13113*uk_41 + 47089*uk_42 + 20181*uk_43 + 8649*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 264592179862*uk_47 + 208534514637*uk_48 + 44846132180*uk_49 + 20*uk_5 + 316165231869*uk_50 + 486580534153*uk_51 + 208534514637*uk_52 + 187944057*uk_53 + 352022202*uk_54 + 277441227*uk_55 + 59664780*uk_56 + 420636699*uk_57 + 647362863*uk_58 + 277441227*uk_59 + 141*uk_6 + 659343172*uk_60 + 519651822*uk_61 + 111753080*uk_62 + 787859214*uk_63 + 1212520918*uk_64 + 519651822*uk_65 + 409556097*uk_66 + 88076580*uk_67 + 620939889*uk_68 + 955630893*uk_69 + 217*uk_7 + 409556097*uk_70 + 18941200*uk_71 + 133535460*uk_72 + 205512020*uk_73 + 88076580*uk_74 + 941424993*uk_75 + 1448859741*uk_76 + 620939889*uk_77 + 2229805417*uk_78 + 955630893*uk_79 + 93*uk_8 + 409556097*uk_80 + 250047*uk_81 + 468342*uk_82 + 369117*uk_83 + 79380*uk_84 + 559629*uk_85 + 861273*uk_86 + 369117*uk_87 + 877212*uk_88 + 691362*uk_89 + 2242306609*uk_9 + 148680*uk_90 + 1048194*uk_91 + 1613178*uk_92 + 691362*uk_93 + 544887*uk_94 + 117180*uk_95 + 826119*uk_96 + 1271403*uk_97 + 544887*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 144144*uk_100 + 218736*uk_101 + 118944*uk_102 + 1288287*uk_103 + 1954953*uk_104 + 1063062*uk_105 + 2966607*uk_106 + 1613178*uk_107 + 877212*uk_108 + 8000*uk_109 + 947060*uk_11 + 47200*uk_110 + 6400*uk_111 + 57200*uk_112 + 86800*uk_113 + 47200*uk_114 + 278480*uk_115 + 37760*uk_116 + 337480*uk_117 + 512120*uk_118 + 278480*uk_119 + 5587654*uk_12 + 5120*uk_120 + 45760*uk_121 + 69440*uk_122 + 37760*uk_123 + 408980*uk_124 + 620620*uk_125 + 337480*uk_126 + 941780*uk_127 + 512120*uk_128 + 278480*uk_129 + 757648*uk_13 + 1643032*uk_130 + 222784*uk_131 + 1991132*uk_132 + 3021508*uk_133 + 1643032*uk_134 + 30208*uk_135 + 269984*uk_136 + 409696*uk_137 + 222784*uk_138 + 2412982*uk_139 + 6771479*uk_14 + 3661658*uk_140 + 1991132*uk_141 + 5556502*uk_142 + 3021508*uk_143 + 1643032*uk_144 + 4096*uk_145 + 36608*uk_146 + 55552*uk_147 + 30208*uk_148 + 327184*uk_149 + 10275601*uk_15 + 496496*uk_150 + 269984*uk_151 + 753424*uk_152 + 409696*uk_153 + 222784*uk_154 + 2924207*uk_155 + 4437433*uk_156 + 2412982*uk_157 + 6733727*uk_158 + 3661658*uk_159 + 5587654*uk_16 + 1991132*uk_160 + 10218313*uk_161 + 5556502*uk_162 + 3021508*uk_163 + 1643032*uk_164 + 3969*uk_17 + 1260*uk_18 + 7434*uk_19 + 63*uk_2 + 1008*uk_20 + 9009*uk_21 + 13671*uk_22 + 7434*uk_23 + 400*uk_24 + 2360*uk_25 + 320*uk_26 + 2860*uk_27 + 4340*uk_28 + 2360*uk_29 + 20*uk_3 + 13924*uk_30 + 1888*uk_31 + 16874*uk_32 + 25606*uk_33 + 13924*uk_34 + 256*uk_35 + 2288*uk_36 + 3472*uk_37 + 1888*uk_38 + 20449*uk_39 + 118*uk_4 + 31031*uk_40 + 16874*uk_41 + 47089*uk_42 + 25606*uk_43 + 13924*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 44846132180*uk_47 + 264592179862*uk_48 + 35876905744*uk_49 + 16*uk_5 + 320649845087*uk_50 + 486580534153*uk_51 + 264592179862*uk_52 + 187944057*uk_53 + 59664780*uk_54 + 352022202*uk_55 + 47731824*uk_56 + 426603177*uk_57 + 647362863*uk_58 + 352022202*uk_59 + 143*uk_6 + 18941200*uk_60 + 111753080*uk_61 + 15152960*uk_62 + 135429580*uk_63 + 205512020*uk_64 + 111753080*uk_65 + 659343172*uk_66 + 89402464*uk_67 + 799034522*uk_68 + 1212520918*uk_69 + 217*uk_7 + 659343172*uk_70 + 12122368*uk_71 + 108343664*uk_72 + 164409616*uk_73 + 89402464*uk_74 + 968321497*uk_75 + 1469410943*uk_76 + 799034522*uk_77 + 2229805417*uk_78 + 1212520918*uk_79 + 118*uk_8 + 659343172*uk_80 + 250047*uk_81 + 79380*uk_82 + 468342*uk_83 + 63504*uk_84 + 567567*uk_85 + 861273*uk_86 + 468342*uk_87 + 25200*uk_88 + 148680*uk_89 + 2242306609*uk_9 + 20160*uk_90 + 180180*uk_91 + 273420*uk_92 + 148680*uk_93 + 877212*uk_94 + 118944*uk_95 + 1063062*uk_96 + 1613178*uk_97 + 877212*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 182700*uk_100 + 273420*uk_101 + 25200*uk_102 + 1324575*uk_103 + 1982295*uk_104 + 182700*uk_105 + 2966607*uk_106 + 273420*uk_107 + 25200*uk_108 + 571787*uk_109 + 3930299*uk_11 + 137780*uk_110 + 137780*uk_111 + 998905*uk_112 + 1494913*uk_113 + 137780*uk_114 + 33200*uk_115 + 33200*uk_116 + 240700*uk_117 + 360220*uk_118 + 33200*uk_119 + 947060*uk_12 + 33200*uk_120 + 240700*uk_121 + 360220*uk_122 + 33200*uk_123 + 1745075*uk_124 + 2611595*uk_125 + 240700*uk_126 + 3908387*uk_127 + 360220*uk_128 + 33200*uk_129 + 947060*uk_13 + 8000*uk_130 + 8000*uk_131 + 58000*uk_132 + 86800*uk_133 + 8000*uk_134 + 8000*uk_135 + 58000*uk_136 + 86800*uk_137 + 8000*uk_138 + 420500*uk_139 + 6866185*uk_14 + 629300*uk_140 + 58000*uk_141 + 941780*uk_142 + 86800*uk_143 + 8000*uk_144 + 8000*uk_145 + 58000*uk_146 + 86800*uk_147 + 8000*uk_148 + 420500*uk_149 + 10275601*uk_15 + 629300*uk_150 + 58000*uk_151 + 941780*uk_152 + 86800*uk_153 + 8000*uk_154 + 3048625*uk_155 + 4562425*uk_156 + 420500*uk_157 + 6827905*uk_158 + 629300*uk_159 + 947060*uk_16 + 58000*uk_160 + 10218313*uk_161 + 941780*uk_162 + 86800*uk_163 + 8000*uk_164 + 3969*uk_17 + 5229*uk_18 + 1260*uk_19 + 63*uk_2 + 1260*uk_20 + 9135*uk_21 + 13671*uk_22 + 1260*uk_23 + 6889*uk_24 + 1660*uk_25 + 1660*uk_26 + 12035*uk_27 + 18011*uk_28 + 1660*uk_29 + 83*uk_3 + 400*uk_30 + 400*uk_31 + 2900*uk_32 + 4340*uk_33 + 400*uk_34 + 400*uk_35 + 2900*uk_36 + 4340*uk_37 + 400*uk_38 + 21025*uk_39 + 20*uk_4 + 31465*uk_40 + 2900*uk_41 + 47089*uk_42 + 4340*uk_43 + 400*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 186111448547*uk_47 + 44846132180*uk_48 + 44846132180*uk_49 + 20*uk_5 + 325134458305*uk_50 + 486580534153*uk_51 + 44846132180*uk_52 + 187944057*uk_53 + 247608837*uk_54 + 59664780*uk_55 + 59664780*uk_56 + 432569655*uk_57 + 647362863*uk_58 + 59664780*uk_59 + 145*uk_6 + 326214817*uk_60 + 78605980*uk_61 + 78605980*uk_62 + 569893355*uk_63 + 852874883*uk_64 + 78605980*uk_65 + 18941200*uk_66 + 18941200*uk_67 + 137323700*uk_68 + 205512020*uk_69 + 217*uk_7 + 18941200*uk_70 + 18941200*uk_71 + 137323700*uk_72 + 205512020*uk_73 + 18941200*uk_74 + 995596825*uk_75 + 1489962145*uk_76 + 137323700*uk_77 + 2229805417*uk_78 + 205512020*uk_79 + 20*uk_8 + 18941200*uk_80 + 250047*uk_81 + 329427*uk_82 + 79380*uk_83 + 79380*uk_84 + 575505*uk_85 + 861273*uk_86 + 79380*uk_87 + 434007*uk_88 + 104580*uk_89 + 2242306609*uk_9 + 104580*uk_90 + 758205*uk_91 + 1134693*uk_92 + 104580*uk_93 + 25200*uk_94 + 25200*uk_95 + 182700*uk_96 + 273420*uk_97 + 25200*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 148176*uk_100 + 218736*uk_101 + 83664*uk_102 + 1361367*uk_103 + 2009637*uk_104 + 768663*uk_105 + 2966607*uk_106 + 1134693*uk_107 + 434007*uk_108 + 6859*uk_109 + 899707*uk_11 + 29963*uk_110 + 5776*uk_111 + 53067*uk_112 + 78337*uk_113 + 29963*uk_114 + 130891*uk_115 + 25232*uk_116 + 231819*uk_117 + 342209*uk_118 + 130891*uk_119 + 3930299*uk_12 + 4864*uk_120 + 44688*uk_121 + 65968*uk_122 + 25232*uk_123 + 410571*uk_124 + 606081*uk_125 + 231819*uk_126 + 894691*uk_127 + 342209*uk_128 + 130891*uk_129 + 757648*uk_13 + 571787*uk_130 + 110224*uk_131 + 1012683*uk_132 + 1494913*uk_133 + 571787*uk_134 + 21248*uk_135 + 195216*uk_136 + 288176*uk_137 + 110224*uk_138 + 1793547*uk_139 + 6960891*uk_14 + 2647617*uk_140 + 1012683*uk_141 + 3908387*uk_142 + 1494913*uk_143 + 571787*uk_144 + 4096*uk_145 + 37632*uk_146 + 55552*uk_147 + 21248*uk_148 + 345744*uk_149 + 10275601*uk_15 + 510384*uk_150 + 195216*uk_151 + 753424*uk_152 + 288176*uk_153 + 110224*uk_154 + 3176523*uk_155 + 4689153*uk_156 + 1793547*uk_157 + 6922083*uk_158 + 2647617*uk_159 + 3930299*uk_16 + 1012683*uk_160 + 10218313*uk_161 + 3908387*uk_162 + 1494913*uk_163 + 571787*uk_164 + 3969*uk_17 + 1197*uk_18 + 5229*uk_19 + 63*uk_2 + 1008*uk_20 + 9261*uk_21 + 13671*uk_22 + 5229*uk_23 + 361*uk_24 + 1577*uk_25 + 304*uk_26 + 2793*uk_27 + 4123*uk_28 + 1577*uk_29 + 19*uk_3 + 6889*uk_30 + 1328*uk_31 + 12201*uk_32 + 18011*uk_33 + 6889*uk_34 + 256*uk_35 + 2352*uk_36 + 3472*uk_37 + 1328*uk_38 + 21609*uk_39 + 83*uk_4 + 31899*uk_40 + 12201*uk_41 + 47089*uk_42 + 18011*uk_43 + 6889*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 42603825571*uk_47 + 186111448547*uk_48 + 35876905744*uk_49 + 16*uk_5 + 329619071523*uk_50 + 486580534153*uk_51 + 186111448547*uk_52 + 187944057*uk_53 + 56681541*uk_54 + 247608837*uk_55 + 47731824*uk_56 + 438536133*uk_57 + 647362863*uk_58 + 247608837*uk_59 + 147*uk_6 + 17094433*uk_60 + 74675681*uk_61 + 14395312*uk_62 + 132256929*uk_63 + 195236419*uk_64 + 74675681*uk_65 + 326214817*uk_66 + 62884784*uk_67 + 577753953*uk_68 + 852874883*uk_69 + 217*uk_7 + 326214817*uk_70 + 12122368*uk_71 + 111374256*uk_72 + 164409616*uk_73 + 62884784*uk_74 + 1023250977*uk_75 + 1510513347*uk_76 + 577753953*uk_77 + 2229805417*uk_78 + 852874883*uk_79 + 83*uk_8 + 326214817*uk_80 + 250047*uk_81 + 75411*uk_82 + 329427*uk_83 + 63504*uk_84 + 583443*uk_85 + 861273*uk_86 + 329427*uk_87 + 22743*uk_88 + 99351*uk_89 + 2242306609*uk_9 + 19152*uk_90 + 175959*uk_91 + 259749*uk_92 + 99351*uk_93 + 434007*uk_94 + 83664*uk_95 + 768663*uk_96 + 1134693*uk_97 + 434007*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 187740*uk_100 + 273420*uk_101 + 23940*uk_102 + 1398663*uk_103 + 2036979*uk_104 + 178353*uk_105 + 2966607*uk_106 + 259749*uk_107 + 22743*uk_108 + 1728000*uk_109 + 5682360*uk_11 + 273600*uk_110 + 288000*uk_111 + 2145600*uk_112 + 3124800*uk_113 + 273600*uk_114 + 43320*uk_115 + 45600*uk_116 + 339720*uk_117 + 494760*uk_118 + 43320*uk_119 + 899707*uk_12 + 48000*uk_120 + 357600*uk_121 + 520800*uk_122 + 45600*uk_123 + 2664120*uk_124 + 3879960*uk_125 + 339720*uk_126 + 5650680*uk_127 + 494760*uk_128 + 43320*uk_129 + 947060*uk_13 + 6859*uk_130 + 7220*uk_131 + 53789*uk_132 + 78337*uk_133 + 6859*uk_134 + 7600*uk_135 + 56620*uk_136 + 82460*uk_137 + 7220*uk_138 + 421819*uk_139 + 7055597*uk_14 + 614327*uk_140 + 53789*uk_141 + 894691*uk_142 + 78337*uk_143 + 6859*uk_144 + 8000*uk_145 + 59600*uk_146 + 86800*uk_147 + 7600*uk_148 + 444020*uk_149 + 10275601*uk_15 + 646660*uk_150 + 56620*uk_151 + 941780*uk_152 + 82460*uk_153 + 7220*uk_154 + 3307949*uk_155 + 4817617*uk_156 + 421819*uk_157 + 7016261*uk_158 + 614327*uk_159 + 899707*uk_16 + 53789*uk_160 + 10218313*uk_161 + 894691*uk_162 + 78337*uk_163 + 6859*uk_164 + 3969*uk_17 + 7560*uk_18 + 1197*uk_19 + 63*uk_2 + 1260*uk_20 + 9387*uk_21 + 13671*uk_22 + 1197*uk_23 + 14400*uk_24 + 2280*uk_25 + 2400*uk_26 + 17880*uk_27 + 26040*uk_28 + 2280*uk_29 + 120*uk_3 + 361*uk_30 + 380*uk_31 + 2831*uk_32 + 4123*uk_33 + 361*uk_34 + 400*uk_35 + 2980*uk_36 + 4340*uk_37 + 380*uk_38 + 22201*uk_39 + 19*uk_4 + 32333*uk_40 + 2831*uk_41 + 47089*uk_42 + 4123*uk_43 + 361*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 269076793080*uk_47 + 42603825571*uk_48 + 44846132180*uk_49 + 20*uk_5 + 334103684741*uk_50 + 486580534153*uk_51 + 42603825571*uk_52 + 187944057*uk_53 + 357988680*uk_54 + 56681541*uk_55 + 59664780*uk_56 + 444502611*uk_57 + 647362863*uk_58 + 56681541*uk_59 + 149*uk_6 + 681883200*uk_60 + 107964840*uk_61 + 113647200*uk_62 + 846671640*uk_63 + 1233072120*uk_64 + 107964840*uk_65 + 17094433*uk_66 + 17994140*uk_67 + 134056343*uk_68 + 195236419*uk_69 + 217*uk_7 + 17094433*uk_70 + 18941200*uk_71 + 141111940*uk_72 + 205512020*uk_73 + 17994140*uk_74 + 1051283953*uk_75 + 1531064549*uk_76 + 134056343*uk_77 + 2229805417*uk_78 + 195236419*uk_79 + 19*uk_8 + 17094433*uk_80 + 250047*uk_81 + 476280*uk_82 + 75411*uk_83 + 79380*uk_84 + 591381*uk_85 + 861273*uk_86 + 75411*uk_87 + 907200*uk_88 + 143640*uk_89 + 2242306609*uk_9 + 151200*uk_90 + 1126440*uk_91 + 1640520*uk_92 + 143640*uk_93 + 22743*uk_94 + 23940*uk_95 + 178353*uk_96 + 259749*uk_97 + 22743*uk_98 + 25200*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 152208*uk_100 + 218736*uk_101 + 120960*uk_102 + 1436463*uk_103 + 2064321*uk_104 + 1141560*uk_105 + 2966607*uk_106 + 1640520*uk_107 + 907200*uk_108 + 729000*uk_109 + 4261770*uk_11 + 972000*uk_110 + 129600*uk_111 + 1223100*uk_112 + 1757700*uk_113 + 972000*uk_114 + 1296000*uk_115 + 172800*uk_116 + 1630800*uk_117 + 2343600*uk_118 + 1296000*uk_119 + 5682360*uk_12 + 23040*uk_120 + 217440*uk_121 + 312480*uk_122 + 172800*uk_123 + 2052090*uk_124 + 2949030*uk_125 + 1630800*uk_126 + 4238010*uk_127 + 2343600*uk_128 + 1296000*uk_129 + 757648*uk_13 + 1728000*uk_130 + 230400*uk_131 + 2174400*uk_132 + 3124800*uk_133 + 1728000*uk_134 + 30720*uk_135 + 289920*uk_136 + 416640*uk_137 + 230400*uk_138 + 2736120*uk_139 + 7150303*uk_14 + 3932040*uk_140 + 2174400*uk_141 + 5650680*uk_142 + 3124800*uk_143 + 1728000*uk_144 + 4096*uk_145 + 38656*uk_146 + 55552*uk_147 + 30720*uk_148 + 364816*uk_149 + 10275601*uk_15 + 524272*uk_150 + 289920*uk_151 + 753424*uk_152 + 416640*uk_153 + 230400*uk_154 + 3442951*uk_155 + 4947817*uk_156 + 2736120*uk_157 + 7110439*uk_158 + 3932040*uk_159 + 5682360*uk_16 + 2174400*uk_160 + 10218313*uk_161 + 5650680*uk_162 + 3124800*uk_163 + 1728000*uk_164 + 3969*uk_17 + 5670*uk_18 + 7560*uk_19 + 63*uk_2 + 1008*uk_20 + 9513*uk_21 + 13671*uk_22 + 7560*uk_23 + 8100*uk_24 + 10800*uk_25 + 1440*uk_26 + 13590*uk_27 + 19530*uk_28 + 10800*uk_29 + 90*uk_3 + 14400*uk_30 + 1920*uk_31 + 18120*uk_32 + 26040*uk_33 + 14400*uk_34 + 256*uk_35 + 2416*uk_36 + 3472*uk_37 + 1920*uk_38 + 22801*uk_39 + 120*uk_4 + 32767*uk_40 + 18120*uk_41 + 47089*uk_42 + 26040*uk_43 + 14400*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 201807594810*uk_47 + 269076793080*uk_48 + 35876905744*uk_49 + 16*uk_5 + 338588297959*uk_50 + 486580534153*uk_51 + 269076793080*uk_52 + 187944057*uk_53 + 268491510*uk_54 + 357988680*uk_55 + 47731824*uk_56 + 450469089*uk_57 + 647362863*uk_58 + 357988680*uk_59 + 151*uk_6 + 383559300*uk_60 + 511412400*uk_61 + 68188320*uk_62 + 643527270*uk_63 + 924804090*uk_64 + 511412400*uk_65 + 681883200*uk_66 + 90917760*uk_67 + 858036360*uk_68 + 1233072120*uk_69 + 217*uk_7 + 681883200*uk_70 + 12122368*uk_71 + 114404848*uk_72 + 164409616*uk_73 + 90917760*uk_74 + 1079695753*uk_75 + 1551615751*uk_76 + 858036360*uk_77 + 2229805417*uk_78 + 1233072120*uk_79 + 120*uk_8 + 681883200*uk_80 + 250047*uk_81 + 357210*uk_82 + 476280*uk_83 + 63504*uk_84 + 599319*uk_85 + 861273*uk_86 + 476280*uk_87 + 510300*uk_88 + 680400*uk_89 + 2242306609*uk_9 + 90720*uk_90 + 856170*uk_91 + 1230390*uk_92 + 680400*uk_93 + 907200*uk_94 + 120960*uk_95 + 1141560*uk_96 + 1640520*uk_97 + 907200*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 154224*uk_100 + 218736*uk_101 + 90720*uk_102 + 1474767*uk_103 + 2091663*uk_104 + 867510*uk_105 + 2966607*uk_106 + 1230390*uk_107 + 510300*uk_108 + 438976*uk_109 + 3598828*uk_11 + 519840*uk_110 + 92416*uk_111 + 883728*uk_112 + 1253392*uk_113 + 519840*uk_114 + 615600*uk_115 + 109440*uk_116 + 1046520*uk_117 + 1484280*uk_118 + 615600*uk_119 + 4261770*uk_12 + 19456*uk_120 + 186048*uk_121 + 263872*uk_122 + 109440*uk_123 + 1779084*uk_124 + 2523276*uk_125 + 1046520*uk_126 + 3578764*uk_127 + 1484280*uk_128 + 615600*uk_129 + 757648*uk_13 + 729000*uk_130 + 129600*uk_131 + 1239300*uk_132 + 1757700*uk_133 + 729000*uk_134 + 23040*uk_135 + 220320*uk_136 + 312480*uk_137 + 129600*uk_138 + 2106810*uk_139 + 7245009*uk_14 + 2988090*uk_140 + 1239300*uk_141 + 4238010*uk_142 + 1757700*uk_143 + 729000*uk_144 + 4096*uk_145 + 39168*uk_146 + 55552*uk_147 + 23040*uk_148 + 374544*uk_149 + 10275601*uk_15 + 531216*uk_150 + 220320*uk_151 + 753424*uk_152 + 312480*uk_153 + 129600*uk_154 + 3581577*uk_155 + 5079753*uk_156 + 2106810*uk_157 + 7204617*uk_158 + 2988090*uk_159 + 4261770*uk_16 + 1239300*uk_160 + 10218313*uk_161 + 4238010*uk_162 + 1757700*uk_163 + 729000*uk_164 + 3969*uk_17 + 4788*uk_18 + 5670*uk_19 + 63*uk_2 + 1008*uk_20 + 9639*uk_21 + 13671*uk_22 + 5670*uk_23 + 5776*uk_24 + 6840*uk_25 + 1216*uk_26 + 11628*uk_27 + 16492*uk_28 + 6840*uk_29 + 76*uk_3 + 8100*uk_30 + 1440*uk_31 + 13770*uk_32 + 19530*uk_33 + 8100*uk_34 + 256*uk_35 + 2448*uk_36 + 3472*uk_37 + 1440*uk_38 + 23409*uk_39 + 90*uk_4 + 33201*uk_40 + 13770*uk_41 + 47089*uk_42 + 19530*uk_43 + 8100*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 170415302284*uk_47 + 201807594810*uk_48 + 35876905744*uk_49 + 16*uk_5 + 343072911177*uk_50 + 486580534153*uk_51 + 201807594810*uk_52 + 187944057*uk_53 + 226726164*uk_54 + 268491510*uk_55 + 47731824*uk_56 + 456435567*uk_57 + 647362863*uk_58 + 268491510*uk_59 + 153*uk_6 + 273510928*uk_60 + 323894520*uk_61 + 57581248*uk_62 + 550620684*uk_63 + 780945676*uk_64 + 323894520*uk_65 + 383559300*uk_66 + 68188320*uk_67 + 652050810*uk_68 + 924804090*uk_69 + 217*uk_7 + 383559300*uk_70 + 12122368*uk_71 + 115920144*uk_72 + 164409616*uk_73 + 68188320*uk_74 + 1108486377*uk_75 + 1572166953*uk_76 + 652050810*uk_77 + 2229805417*uk_78 + 924804090*uk_79 + 90*uk_8 + 383559300*uk_80 + 250047*uk_81 + 301644*uk_82 + 357210*uk_83 + 63504*uk_84 + 607257*uk_85 + 861273*uk_86 + 357210*uk_87 + 363888*uk_88 + 430920*uk_89 + 2242306609*uk_9 + 76608*uk_90 + 732564*uk_91 + 1038996*uk_92 + 430920*uk_93 + 510300*uk_94 + 90720*uk_95 + 867510*uk_96 + 1230390*uk_97 + 510300*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 156240*uk_100 + 218736*uk_101 + 76608*uk_102 + 1513575*uk_103 + 2119005*uk_104 + 742140*uk_105 + 2966607*uk_106 + 1038996*uk_107 + 363888*uk_108 + 474552*uk_109 + 3693534*uk_11 + 462384*uk_110 + 97344*uk_111 + 943020*uk_112 + 1320228*uk_113 + 462384*uk_114 + 450528*uk_115 + 94848*uk_116 + 918840*uk_117 + 1286376*uk_118 + 450528*uk_119 + 3598828*uk_12 + 19968*uk_120 + 193440*uk_121 + 270816*uk_122 + 94848*uk_123 + 1873950*uk_124 + 2623530*uk_125 + 918840*uk_126 + 3672942*uk_127 + 1286376*uk_128 + 450528*uk_129 + 757648*uk_13 + 438976*uk_130 + 92416*uk_131 + 895280*uk_132 + 1253392*uk_133 + 438976*uk_134 + 19456*uk_135 + 188480*uk_136 + 263872*uk_137 + 92416*uk_138 + 1825900*uk_139 + 7339715*uk_14 + 2556260*uk_140 + 895280*uk_141 + 3578764*uk_142 + 1253392*uk_143 + 438976*uk_144 + 4096*uk_145 + 39680*uk_146 + 55552*uk_147 + 19456*uk_148 + 384400*uk_149 + 10275601*uk_15 + 538160*uk_150 + 188480*uk_151 + 753424*uk_152 + 263872*uk_153 + 92416*uk_154 + 3723875*uk_155 + 5213425*uk_156 + 1825900*uk_157 + 7298795*uk_158 + 2556260*uk_159 + 3598828*uk_16 + 895280*uk_160 + 10218313*uk_161 + 3578764*uk_162 + 1253392*uk_163 + 438976*uk_164 + 3969*uk_17 + 4914*uk_18 + 4788*uk_19 + 63*uk_2 + 1008*uk_20 + 9765*uk_21 + 13671*uk_22 + 4788*uk_23 + 6084*uk_24 + 5928*uk_25 + 1248*uk_26 + 12090*uk_27 + 16926*uk_28 + 5928*uk_29 + 78*uk_3 + 5776*uk_30 + 1216*uk_31 + 11780*uk_32 + 16492*uk_33 + 5776*uk_34 + 256*uk_35 + 2480*uk_36 + 3472*uk_37 + 1216*uk_38 + 24025*uk_39 + 76*uk_4 + 33635*uk_40 + 11780*uk_41 + 47089*uk_42 + 16492*uk_43 + 5776*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 174899915502*uk_47 + 170415302284*uk_48 + 35876905744*uk_49 + 16*uk_5 + 347557524395*uk_50 + 486580534153*uk_51 + 170415302284*uk_52 + 187944057*uk_53 + 232692642*uk_54 + 226726164*uk_55 + 47731824*uk_56 + 462402045*uk_57 + 647362863*uk_58 + 226726164*uk_59 + 155*uk_6 + 288095652*uk_60 + 280708584*uk_61 + 59096544*uk_62 + 572497770*uk_63 + 801496878*uk_64 + 280708584*uk_65 + 273510928*uk_66 + 57581248*uk_67 + 557818340*uk_68 + 780945676*uk_69 + 217*uk_7 + 273510928*uk_70 + 12122368*uk_71 + 117435440*uk_72 + 164409616*uk_73 + 57581248*uk_74 + 1137655825*uk_75 + 1592718155*uk_76 + 557818340*uk_77 + 2229805417*uk_78 + 780945676*uk_79 + 76*uk_8 + 273510928*uk_80 + 250047*uk_81 + 309582*uk_82 + 301644*uk_83 + 63504*uk_84 + 615195*uk_85 + 861273*uk_86 + 301644*uk_87 + 383292*uk_88 + 373464*uk_89 + 2242306609*uk_9 + 78624*uk_90 + 761670*uk_91 + 1066338*uk_92 + 373464*uk_93 + 363888*uk_94 + 76608*uk_95 + 742140*uk_96 + 1038996*uk_97 + 363888*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 158256*uk_100 + 218736*uk_101 + 78624*uk_102 + 1552887*uk_103 + 2146347*uk_104 + 771498*uk_105 + 2966607*uk_106 + 1066338*uk_107 + 383292*uk_108 + 884736*uk_109 + 4545888*uk_11 + 718848*uk_110 + 147456*uk_111 + 1446912*uk_112 + 1999872*uk_113 + 718848*uk_114 + 584064*uk_115 + 119808*uk_116 + 1175616*uk_117 + 1624896*uk_118 + 584064*uk_119 + 3693534*uk_12 + 24576*uk_120 + 241152*uk_121 + 333312*uk_122 + 119808*uk_123 + 2366304*uk_124 + 3270624*uk_125 + 1175616*uk_126 + 4520544*uk_127 + 1624896*uk_128 + 584064*uk_129 + 757648*uk_13 + 474552*uk_130 + 97344*uk_131 + 955188*uk_132 + 1320228*uk_133 + 474552*uk_134 + 19968*uk_135 + 195936*uk_136 + 270816*uk_137 + 97344*uk_138 + 1922622*uk_139 + 7434421*uk_14 + 2657382*uk_140 + 955188*uk_141 + 3672942*uk_142 + 1320228*uk_143 + 474552*uk_144 + 4096*uk_145 + 40192*uk_146 + 55552*uk_147 + 19968*uk_148 + 394384*uk_149 + 10275601*uk_15 + 545104*uk_150 + 195936*uk_151 + 753424*uk_152 + 270816*uk_153 + 97344*uk_154 + 3869893*uk_155 + 5348833*uk_156 + 1922622*uk_157 + 7392973*uk_158 + 2657382*uk_159 + 3693534*uk_16 + 955188*uk_160 + 10218313*uk_161 + 3672942*uk_162 + 1320228*uk_163 + 474552*uk_164 + 3969*uk_17 + 6048*uk_18 + 4914*uk_19 + 63*uk_2 + 1008*uk_20 + 9891*uk_21 + 13671*uk_22 + 4914*uk_23 + 9216*uk_24 + 7488*uk_25 + 1536*uk_26 + 15072*uk_27 + 20832*uk_28 + 7488*uk_29 + 96*uk_3 + 6084*uk_30 + 1248*uk_31 + 12246*uk_32 + 16926*uk_33 + 6084*uk_34 + 256*uk_35 + 2512*uk_36 + 3472*uk_37 + 1248*uk_38 + 24649*uk_39 + 78*uk_4 + 34069*uk_40 + 12246*uk_41 + 47089*uk_42 + 16926*uk_43 + 6084*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 215261434464*uk_47 + 174899915502*uk_48 + 35876905744*uk_49 + 16*uk_5 + 352042137613*uk_50 + 486580534153*uk_51 + 174899915502*uk_52 + 187944057*uk_53 + 286390944*uk_54 + 232692642*uk_55 + 47731824*uk_56 + 468368523*uk_57 + 647362863*uk_58 + 232692642*uk_59 + 157*uk_6 + 436405248*uk_60 + 354579264*uk_61 + 72734208*uk_62 + 713704416*uk_63 + 986457696*uk_64 + 354579264*uk_65 + 288095652*uk_66 + 59096544*uk_67 + 579884838*uk_68 + 801496878*uk_69 + 217*uk_7 + 288095652*uk_70 + 12122368*uk_71 + 118950736*uk_72 + 164409616*uk_73 + 59096544*uk_74 + 1167204097*uk_75 + 1613269357*uk_76 + 579884838*uk_77 + 2229805417*uk_78 + 801496878*uk_79 + 78*uk_8 + 288095652*uk_80 + 250047*uk_81 + 381024*uk_82 + 309582*uk_83 + 63504*uk_84 + 623133*uk_85 + 861273*uk_86 + 309582*uk_87 + 580608*uk_88 + 471744*uk_89 + 2242306609*uk_9 + 96768*uk_90 + 949536*uk_91 + 1312416*uk_92 + 471744*uk_93 + 383292*uk_94 + 78624*uk_95 + 771498*uk_96 + 1066338*uk_97 + 383292*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 160272*uk_100 + 218736*uk_101 + 96768*uk_102 + 1592703*uk_103 + 2173689*uk_104 + 961632*uk_105 + 2966607*uk_106 + 1312416*uk_107 + 580608*uk_108 + 2197000*uk_109 + 6155890*uk_11 + 1622400*uk_110 + 270400*uk_111 + 2687100*uk_112 + 3667300*uk_113 + 1622400*uk_114 + 1198080*uk_115 + 199680*uk_116 + 1984320*uk_117 + 2708160*uk_118 + 1198080*uk_119 + 4545888*uk_12 + 33280*uk_120 + 330720*uk_121 + 451360*uk_122 + 199680*uk_123 + 3286530*uk_124 + 4485390*uk_125 + 1984320*uk_126 + 6121570*uk_127 + 2708160*uk_128 + 1198080*uk_129 + 757648*uk_13 + 884736*uk_130 + 147456*uk_131 + 1465344*uk_132 + 1999872*uk_133 + 884736*uk_134 + 24576*uk_135 + 244224*uk_136 + 333312*uk_137 + 147456*uk_138 + 2426976*uk_139 + 7529127*uk_14 + 3312288*uk_140 + 1465344*uk_141 + 4520544*uk_142 + 1999872*uk_143 + 884736*uk_144 + 4096*uk_145 + 40704*uk_146 + 55552*uk_147 + 24576*uk_148 + 404496*uk_149 + 10275601*uk_15 + 552048*uk_150 + 244224*uk_151 + 753424*uk_152 + 333312*uk_153 + 147456*uk_154 + 4019679*uk_155 + 5485977*uk_156 + 2426976*uk_157 + 7487151*uk_158 + 3312288*uk_159 + 4545888*uk_16 + 1465344*uk_160 + 10218313*uk_161 + 4520544*uk_162 + 1999872*uk_163 + 884736*uk_164 + 3969*uk_17 + 8190*uk_18 + 6048*uk_19 + 63*uk_2 + 1008*uk_20 + 10017*uk_21 + 13671*uk_22 + 6048*uk_23 + 16900*uk_24 + 12480*uk_25 + 2080*uk_26 + 20670*uk_27 + 28210*uk_28 + 12480*uk_29 + 130*uk_3 + 9216*uk_30 + 1536*uk_31 + 15264*uk_32 + 20832*uk_33 + 9216*uk_34 + 256*uk_35 + 2544*uk_36 + 3472*uk_37 + 1536*uk_38 + 25281*uk_39 + 96*uk_4 + 34503*uk_40 + 15264*uk_41 + 47089*uk_42 + 20832*uk_43 + 9216*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 291499859170*uk_47 + 215261434464*uk_48 + 35876905744*uk_49 + 16*uk_5 + 356526750831*uk_50 + 486580534153*uk_51 + 215261434464*uk_52 + 187944057*uk_53 + 387821070*uk_54 + 286390944*uk_55 + 47731824*uk_56 + 474335001*uk_57 + 647362863*uk_58 + 286390944*uk_59 + 159*uk_6 + 800265700*uk_60 + 590965440*uk_61 + 98494240*uk_62 + 978786510*uk_63 + 1335828130*uk_64 + 590965440*uk_65 + 436405248*uk_66 + 72734208*uk_67 + 722796192*uk_68 + 986457696*uk_69 + 217*uk_7 + 436405248*uk_70 + 12122368*uk_71 + 120466032*uk_72 + 164409616*uk_73 + 72734208*uk_74 + 1197131193*uk_75 + 1633820559*uk_76 + 722796192*uk_77 + 2229805417*uk_78 + 986457696*uk_79 + 96*uk_8 + 436405248*uk_80 + 250047*uk_81 + 515970*uk_82 + 381024*uk_83 + 63504*uk_84 + 631071*uk_85 + 861273*uk_86 + 381024*uk_87 + 1064700*uk_88 + 786240*uk_89 + 2242306609*uk_9 + 131040*uk_90 + 1302210*uk_91 + 1777230*uk_92 + 786240*uk_93 + 580608*uk_94 + 96768*uk_95 + 961632*uk_96 + 1312416*uk_97 + 580608*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 121716*uk_100 + 164052*uk_101 + 98280*uk_102 + 1633023*uk_103 + 2201031*uk_104 + 1318590*uk_105 + 2966607*uk_106 + 1777230*uk_107 + 1064700*uk_108 + 6859*uk_109 + 899707*uk_11 + 46930*uk_110 + 4332*uk_111 + 58121*uk_112 + 78337*uk_113 + 46930*uk_114 + 321100*uk_115 + 29640*uk_116 + 397670*uk_117 + 535990*uk_118 + 321100*uk_119 + 6155890*uk_12 + 2736*uk_120 + 36708*uk_121 + 49476*uk_122 + 29640*uk_123 + 492499*uk_124 + 663803*uk_125 + 397670*uk_126 + 894691*uk_127 + 535990*uk_128 + 321100*uk_129 + 568236*uk_13 + 2197000*uk_130 + 202800*uk_131 + 2720900*uk_132 + 3667300*uk_133 + 2197000*uk_134 + 18720*uk_135 + 251160*uk_136 + 338520*uk_137 + 202800*uk_138 + 3369730*uk_139 + 7623833*uk_14 + 4541810*uk_140 + 2720900*uk_141 + 6121570*uk_142 + 3667300*uk_143 + 2197000*uk_144 + 1728*uk_145 + 23184*uk_146 + 31248*uk_147 + 18720*uk_148 + 311052*uk_149 + 10275601*uk_15 + 419244*uk_150 + 251160*uk_151 + 565068*uk_152 + 338520*uk_153 + 202800*uk_154 + 4173281*uk_155 + 5624857*uk_156 + 3369730*uk_157 + 7581329*uk_158 + 4541810*uk_159 + 6155890*uk_16 + 2720900*uk_160 + 10218313*uk_161 + 6121570*uk_162 + 3667300*uk_163 + 2197000*uk_164 + 3969*uk_17 + 1197*uk_18 + 8190*uk_19 + 63*uk_2 + 756*uk_20 + 10143*uk_21 + 13671*uk_22 + 8190*uk_23 + 361*uk_24 + 2470*uk_25 + 228*uk_26 + 3059*uk_27 + 4123*uk_28 + 2470*uk_29 + 19*uk_3 + 16900*uk_30 + 1560*uk_31 + 20930*uk_32 + 28210*uk_33 + 16900*uk_34 + 144*uk_35 + 1932*uk_36 + 2604*uk_37 + 1560*uk_38 + 25921*uk_39 + 130*uk_4 + 34937*uk_40 + 20930*uk_41 + 47089*uk_42 + 28210*uk_43 + 16900*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 42603825571*uk_47 + 291499859170*uk_48 + 26907679308*uk_49 + 12*uk_5 + 361011364049*uk_50 + 486580534153*uk_51 + 291499859170*uk_52 + 187944057*uk_53 + 56681541*uk_54 + 387821070*uk_55 + 35798868*uk_56 + 480301479*uk_57 + 647362863*uk_58 + 387821070*uk_59 + 161*uk_6 + 17094433*uk_60 + 116961910*uk_61 + 10796484*uk_62 + 144852827*uk_63 + 195236419*uk_64 + 116961910*uk_65 + 800265700*uk_66 + 73870680*uk_67 + 991098290*uk_68 + 1335828130*uk_69 + 217*uk_7 + 800265700*uk_70 + 6818832*uk_71 + 91485996*uk_72 + 123307212*uk_73 + 73870680*uk_74 + 1227437113*uk_75 + 1654371761*uk_76 + 991098290*uk_77 + 2229805417*uk_78 + 1335828130*uk_79 + 130*uk_8 + 800265700*uk_80 + 250047*uk_81 + 75411*uk_82 + 515970*uk_83 + 47628*uk_84 + 639009*uk_85 + 861273*uk_86 + 515970*uk_87 + 22743*uk_88 + 155610*uk_89 + 2242306609*uk_9 + 14364*uk_90 + 192717*uk_91 + 259749*uk_92 + 155610*uk_93 + 1064700*uk_94 + 98280*uk_95 + 1318590*uk_96 + 1777230*uk_97 + 1064700*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 164304*uk_100 + 218736*uk_101 + 19152*uk_102 + 1673847*uk_103 + 2228373*uk_104 + 195111*uk_105 + 2966607*uk_106 + 259749*uk_107 + 22743*uk_108 + 571787*uk_109 + 3930299*uk_11 + 130891*uk_110 + 110224*uk_111 + 1122907*uk_112 + 1494913*uk_113 + 130891*uk_114 + 29963*uk_115 + 25232*uk_116 + 257051*uk_117 + 342209*uk_118 + 29963*uk_119 + 899707*uk_12 + 21248*uk_120 + 216464*uk_121 + 288176*uk_122 + 25232*uk_123 + 2205227*uk_124 + 2935793*uk_125 + 257051*uk_126 + 3908387*uk_127 + 342209*uk_128 + 29963*uk_129 + 757648*uk_13 + 6859*uk_130 + 5776*uk_131 + 58843*uk_132 + 78337*uk_133 + 6859*uk_134 + 4864*uk_135 + 49552*uk_136 + 65968*uk_137 + 5776*uk_138 + 504811*uk_139 + 7718539*uk_14 + 672049*uk_140 + 58843*uk_141 + 894691*uk_142 + 78337*uk_143 + 6859*uk_144 + 4096*uk_145 + 41728*uk_146 + 55552*uk_147 + 4864*uk_148 + 425104*uk_149 + 10275601*uk_15 + 565936*uk_150 + 49552*uk_151 + 753424*uk_152 + 65968*uk_153 + 5776*uk_154 + 4330747*uk_155 + 5765473*uk_156 + 504811*uk_157 + 7675507*uk_158 + 672049*uk_159 + 899707*uk_16 + 58843*uk_160 + 10218313*uk_161 + 894691*uk_162 + 78337*uk_163 + 6859*uk_164 + 3969*uk_17 + 5229*uk_18 + 1197*uk_19 + 63*uk_2 + 1008*uk_20 + 10269*uk_21 + 13671*uk_22 + 1197*uk_23 + 6889*uk_24 + 1577*uk_25 + 1328*uk_26 + 13529*uk_27 + 18011*uk_28 + 1577*uk_29 + 83*uk_3 + 361*uk_30 + 304*uk_31 + 3097*uk_32 + 4123*uk_33 + 361*uk_34 + 256*uk_35 + 2608*uk_36 + 3472*uk_37 + 304*uk_38 + 26569*uk_39 + 19*uk_4 + 35371*uk_40 + 3097*uk_41 + 47089*uk_42 + 4123*uk_43 + 361*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 186111448547*uk_47 + 42603825571*uk_48 + 35876905744*uk_49 + 16*uk_5 + 365495977267*uk_50 + 486580534153*uk_51 + 42603825571*uk_52 + 187944057*uk_53 + 247608837*uk_54 + 56681541*uk_55 + 47731824*uk_56 + 486267957*uk_57 + 647362863*uk_58 + 56681541*uk_59 + 163*uk_6 + 326214817*uk_60 + 74675681*uk_61 + 62884784*uk_62 + 640638737*uk_63 + 852874883*uk_64 + 74675681*uk_65 + 17094433*uk_66 + 14395312*uk_67 + 146652241*uk_68 + 195236419*uk_69 + 217*uk_7 + 17094433*uk_70 + 12122368*uk_71 + 123496624*uk_72 + 164409616*uk_73 + 14395312*uk_74 + 1258121857*uk_75 + 1674922963*uk_76 + 146652241*uk_77 + 2229805417*uk_78 + 195236419*uk_79 + 19*uk_8 + 17094433*uk_80 + 250047*uk_81 + 329427*uk_82 + 75411*uk_83 + 63504*uk_84 + 646947*uk_85 + 861273*uk_86 + 75411*uk_87 + 434007*uk_88 + 99351*uk_89 + 2242306609*uk_9 + 83664*uk_90 + 852327*uk_91 + 1134693*uk_92 + 99351*uk_93 + 22743*uk_94 + 19152*uk_95 + 195111*uk_96 + 259749*uk_97 + 22743*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 166320*uk_100 + 218736*uk_101 + 83664*uk_102 + 1715175*uk_103 + 2255715*uk_104 + 862785*uk_105 + 2966607*uk_106 + 1134693*uk_107 + 434007*uk_108 + 4330747*uk_109 + 7718539*uk_11 + 2205227*uk_110 + 425104*uk_111 + 4383885*uk_112 + 5765473*uk_113 + 2205227*uk_114 + 1122907*uk_115 + 216464*uk_116 + 2232285*uk_117 + 2935793*uk_118 + 1122907*uk_119 + 3930299*uk_12 + 41728*uk_120 + 430320*uk_121 + 565936*uk_122 + 216464*uk_123 + 4437675*uk_124 + 5836215*uk_125 + 2232285*uk_126 + 7675507*uk_127 + 2935793*uk_128 + 1122907*uk_129 + 757648*uk_13 + 571787*uk_130 + 110224*uk_131 + 1136685*uk_132 + 1494913*uk_133 + 571787*uk_134 + 21248*uk_135 + 219120*uk_136 + 288176*uk_137 + 110224*uk_138 + 2259675*uk_139 + 7813245*uk_14 + 2971815*uk_140 + 1136685*uk_141 + 3908387*uk_142 + 1494913*uk_143 + 571787*uk_144 + 4096*uk_145 + 42240*uk_146 + 55552*uk_147 + 21248*uk_148 + 435600*uk_149 + 10275601*uk_15 + 572880*uk_150 + 219120*uk_151 + 753424*uk_152 + 288176*uk_153 + 110224*uk_154 + 4492125*uk_155 + 5907825*uk_156 + 2259675*uk_157 + 7769685*uk_158 + 2971815*uk_159 + 3930299*uk_16 + 1136685*uk_160 + 10218313*uk_161 + 3908387*uk_162 + 1494913*uk_163 + 571787*uk_164 + 3969*uk_17 + 10269*uk_18 + 5229*uk_19 + 63*uk_2 + 1008*uk_20 + 10395*uk_21 + 13671*uk_22 + 5229*uk_23 + 26569*uk_24 + 13529*uk_25 + 2608*uk_26 + 26895*uk_27 + 35371*uk_28 + 13529*uk_29 + 163*uk_3 + 6889*uk_30 + 1328*uk_31 + 13695*uk_32 + 18011*uk_33 + 6889*uk_34 + 256*uk_35 + 2640*uk_36 + 3472*uk_37 + 1328*uk_38 + 27225*uk_39 + 83*uk_4 + 35805*uk_40 + 13695*uk_41 + 47089*uk_42 + 18011*uk_43 + 6889*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 365495977267*uk_47 + 186111448547*uk_48 + 35876905744*uk_49 + 16*uk_5 + 369980590485*uk_50 + 486580534153*uk_51 + 186111448547*uk_52 + 187944057*uk_53 + 486267957*uk_54 + 247608837*uk_55 + 47731824*uk_56 + 492234435*uk_57 + 647362863*uk_58 + 247608837*uk_59 + 165*uk_6 + 1258121857*uk_60 + 640638737*uk_61 + 123496624*uk_62 + 1273558935*uk_63 + 1674922963*uk_64 + 640638737*uk_65 + 326214817*uk_66 + 62884784*uk_67 + 648499335*uk_68 + 852874883*uk_69 + 217*uk_7 + 326214817*uk_70 + 12122368*uk_71 + 125011920*uk_72 + 164409616*uk_73 + 62884784*uk_74 + 1289185425*uk_75 + 1695474165*uk_76 + 648499335*uk_77 + 2229805417*uk_78 + 852874883*uk_79 + 83*uk_8 + 326214817*uk_80 + 250047*uk_81 + 646947*uk_82 + 329427*uk_83 + 63504*uk_84 + 654885*uk_85 + 861273*uk_86 + 329427*uk_87 + 1673847*uk_88 + 852327*uk_89 + 2242306609*uk_9 + 164304*uk_90 + 1694385*uk_91 + 2228373*uk_92 + 852327*uk_93 + 434007*uk_94 + 83664*uk_95 + 862785*uk_96 + 1134693*uk_97 + 434007*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 126252*uk_100 + 164052*uk_101 + 123228*uk_102 + 1757007*uk_103 + 2283057*uk_104 + 1714923*uk_105 + 2966607*uk_106 + 2228373*uk_107 + 1673847*uk_108 + 778688*uk_109 + 4356476*uk_11 + 1379632*uk_110 + 101568*uk_111 + 1413488*uk_112 + 1836688*uk_113 + 1379632*uk_114 + 2444348*uk_115 + 179952*uk_116 + 2504332*uk_117 + 3254132*uk_118 + 2444348*uk_119 + 7718539*uk_12 + 13248*uk_120 + 184368*uk_121 + 239568*uk_122 + 179952*uk_123 + 2565788*uk_124 + 3333988*uk_125 + 2504332*uk_126 + 4332188*uk_127 + 3254132*uk_128 + 2444348*uk_129 + 568236*uk_13 + 4330747*uk_130 + 318828*uk_131 + 4437023*uk_132 + 5765473*uk_133 + 4330747*uk_134 + 23472*uk_135 + 326652*uk_136 + 424452*uk_137 + 318828*uk_138 + 4545907*uk_139 + 7907951*uk_14 + 5906957*uk_140 + 4437023*uk_141 + 7675507*uk_142 + 5765473*uk_143 + 4330747*uk_144 + 1728*uk_145 + 24048*uk_146 + 31248*uk_147 + 23472*uk_148 + 334668*uk_149 + 10275601*uk_15 + 434868*uk_150 + 326652*uk_151 + 565068*uk_152 + 424452*uk_153 + 318828*uk_154 + 4657463*uk_155 + 6051913*uk_156 + 4545907*uk_157 + 7863863*uk_158 + 5906957*uk_159 + 7718539*uk_16 + 4437023*uk_160 + 10218313*uk_161 + 7675507*uk_162 + 5765473*uk_163 + 4330747*uk_164 + 3969*uk_17 + 5796*uk_18 + 10269*uk_19 + 63*uk_2 + 756*uk_20 + 10521*uk_21 + 13671*uk_22 + 10269*uk_23 + 8464*uk_24 + 14996*uk_25 + 1104*uk_26 + 15364*uk_27 + 19964*uk_28 + 14996*uk_29 + 92*uk_3 + 26569*uk_30 + 1956*uk_31 + 27221*uk_32 + 35371*uk_33 + 26569*uk_34 + 144*uk_35 + 2004*uk_36 + 2604*uk_37 + 1956*uk_38 + 27889*uk_39 + 163*uk_4 + 36239*uk_40 + 27221*uk_41 + 47089*uk_42 + 35371*uk_43 + 26569*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 206292208028*uk_47 + 365495977267*uk_48 + 26907679308*uk_49 + 12*uk_5 + 374465203703*uk_50 + 486580534153*uk_51 + 365495977267*uk_52 + 187944057*uk_53 + 274457988*uk_54 + 486267957*uk_55 + 35798868*uk_56 + 498200913*uk_57 + 647362863*uk_58 + 486267957*uk_59 + 167*uk_6 + 400795792*uk_60 + 710105588*uk_61 + 52277712*uk_62 + 727531492*uk_63 + 945355292*uk_64 + 710105588*uk_65 + 1258121857*uk_66 + 92622468*uk_67 + 1288996013*uk_68 + 1674922963*uk_69 + 217*uk_7 + 1258121857*uk_70 + 6818832*uk_71 + 94895412*uk_72 + 123307212*uk_73 + 92622468*uk_74 + 1320627817*uk_75 + 1716025367*uk_76 + 1288996013*uk_77 + 2229805417*uk_78 + 1674922963*uk_79 + 163*uk_8 + 1258121857*uk_80 + 250047*uk_81 + 365148*uk_82 + 646947*uk_83 + 47628*uk_84 + 662823*uk_85 + 861273*uk_86 + 646947*uk_87 + 533232*uk_88 + 944748*uk_89 + 2242306609*uk_9 + 69552*uk_90 + 967932*uk_91 + 1257732*uk_92 + 944748*uk_93 + 1673847*uk_94 + 123228*uk_95 + 1714923*uk_96 + 2228373*uk_97 + 1673847*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 127764*uk_100 + 164052*uk_101 + 69552*uk_102 + 1799343*uk_103 + 2310399*uk_104 + 979524*uk_105 + 2966607*uk_106 + 1257732*uk_107 + 533232*uk_108 + 35937*uk_109 + 1562649*uk_11 + 100188*uk_110 + 13068*uk_111 + 184041*uk_112 + 236313*uk_113 + 100188*uk_114 + 279312*uk_115 + 36432*uk_116 + 513084*uk_117 + 658812*uk_118 + 279312*uk_119 + 4356476*uk_12 + 4752*uk_120 + 66924*uk_121 + 85932*uk_122 + 36432*uk_123 + 942513*uk_124 + 1210209*uk_125 + 513084*uk_126 + 1553937*uk_127 + 658812*uk_128 + 279312*uk_129 + 568236*uk_13 + 778688*uk_130 + 101568*uk_131 + 1430416*uk_132 + 1836688*uk_133 + 778688*uk_134 + 13248*uk_135 + 186576*uk_136 + 239568*uk_137 + 101568*uk_138 + 2627612*uk_139 + 8002657*uk_14 + 3373916*uk_140 + 1430416*uk_141 + 4332188*uk_142 + 1836688*uk_143 + 778688*uk_144 + 1728*uk_145 + 24336*uk_146 + 31248*uk_147 + 13248*uk_148 + 342732*uk_149 + 10275601*uk_15 + 440076*uk_150 + 186576*uk_151 + 565068*uk_152 + 239568*uk_153 + 101568*uk_154 + 4826809*uk_155 + 6197737*uk_156 + 2627612*uk_157 + 7958041*uk_158 + 3373916*uk_159 + 4356476*uk_16 + 1430416*uk_160 + 10218313*uk_161 + 4332188*uk_162 + 1836688*uk_163 + 778688*uk_164 + 3969*uk_17 + 2079*uk_18 + 5796*uk_19 + 63*uk_2 + 756*uk_20 + 10647*uk_21 + 13671*uk_22 + 5796*uk_23 + 1089*uk_24 + 3036*uk_25 + 396*uk_26 + 5577*uk_27 + 7161*uk_28 + 3036*uk_29 + 33*uk_3 + 8464*uk_30 + 1104*uk_31 + 15548*uk_32 + 19964*uk_33 + 8464*uk_34 + 144*uk_35 + 2028*uk_36 + 2604*uk_37 + 1104*uk_38 + 28561*uk_39 + 92*uk_4 + 36673*uk_40 + 15548*uk_41 + 47089*uk_42 + 19964*uk_43 + 8464*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 73996118097*uk_47 + 206292208028*uk_48 + 26907679308*uk_49 + 12*uk_5 + 378949816921*uk_50 + 486580534153*uk_51 + 206292208028*uk_52 + 187944057*uk_53 + 98446887*uk_54 + 274457988*uk_55 + 35798868*uk_56 + 504167391*uk_57 + 647362863*uk_58 + 274457988*uk_59 + 169*uk_6 + 51567417*uk_60 + 143763708*uk_61 + 18751788*uk_62 + 264087681*uk_63 + 339094833*uk_64 + 143763708*uk_65 + 400795792*uk_66 + 52277712*uk_67 + 736244444*uk_68 + 945355292*uk_69 + 217*uk_7 + 400795792*uk_70 + 6818832*uk_71 + 96031884*uk_72 + 123307212*uk_73 + 52277712*uk_74 + 1352449033*uk_75 + 1736576569*uk_76 + 736244444*uk_77 + 2229805417*uk_78 + 945355292*uk_79 + 92*uk_8 + 400795792*uk_80 + 250047*uk_81 + 130977*uk_82 + 365148*uk_83 + 47628*uk_84 + 670761*uk_85 + 861273*uk_86 + 365148*uk_87 + 68607*uk_88 + 191268*uk_89 + 2242306609*uk_9 + 24948*uk_90 + 351351*uk_91 + 451143*uk_92 + 191268*uk_93 + 533232*uk_94 + 69552*uk_95 + 979524*uk_96 + 1257732*uk_97 + 533232*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 172368*uk_100 + 218736*uk_101 + 33264*uk_102 + 1842183*uk_103 + 2337741*uk_104 + 355509*uk_105 + 2966607*uk_106 + 451143*uk_107 + 68607*uk_108 + 3869893*uk_109 + 7434421*uk_11 + 813417*uk_110 + 394384*uk_111 + 4214979*uk_112 + 5348833*uk_113 + 813417*uk_114 + 170973*uk_115 + 82896*uk_116 + 885951*uk_117 + 1124277*uk_118 + 170973*uk_119 + 1562649*uk_12 + 40192*uk_120 + 429552*uk_121 + 545104*uk_122 + 82896*uk_123 + 4590837*uk_124 + 5825799*uk_125 + 885951*uk_126 + 7392973*uk_127 + 1124277*uk_128 + 170973*uk_129 + 757648*uk_13 + 35937*uk_130 + 17424*uk_131 + 186219*uk_132 + 236313*uk_133 + 35937*uk_134 + 8448*uk_135 + 90288*uk_136 + 114576*uk_137 + 17424*uk_138 + 964953*uk_139 + 8097363*uk_14 + 1224531*uk_140 + 186219*uk_141 + 1553937*uk_142 + 236313*uk_143 + 35937*uk_144 + 4096*uk_145 + 43776*uk_146 + 55552*uk_147 + 8448*uk_148 + 467856*uk_149 + 10275601*uk_15 + 593712*uk_150 + 90288*uk_151 + 753424*uk_152 + 114576*uk_153 + 17424*uk_154 + 5000211*uk_155 + 6345297*uk_156 + 964953*uk_157 + 8052219*uk_158 + 1224531*uk_159 + 1562649*uk_16 + 186219*uk_160 + 10218313*uk_161 + 1553937*uk_162 + 236313*uk_163 + 35937*uk_164 + 3969*uk_17 + 9891*uk_18 + 2079*uk_19 + 63*uk_2 + 1008*uk_20 + 10773*uk_21 + 13671*uk_22 + 2079*uk_23 + 24649*uk_24 + 5181*uk_25 + 2512*uk_26 + 26847*uk_27 + 34069*uk_28 + 5181*uk_29 + 157*uk_3 + 1089*uk_30 + 528*uk_31 + 5643*uk_32 + 7161*uk_33 + 1089*uk_34 + 256*uk_35 + 2736*uk_36 + 3472*uk_37 + 528*uk_38 + 29241*uk_39 + 33*uk_4 + 37107*uk_40 + 5643*uk_41 + 47089*uk_42 + 7161*uk_43 + 1089*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 352042137613*uk_47 + 73996118097*uk_48 + 35876905744*uk_49 + 16*uk_5 + 383434430139*uk_50 + 486580534153*uk_51 + 73996118097*uk_52 + 187944057*uk_53 + 468368523*uk_54 + 98446887*uk_55 + 47731824*uk_56 + 510133869*uk_57 + 647362863*uk_58 + 98446887*uk_59 + 171*uk_6 + 1167204097*uk_60 + 245335893*uk_61 + 118950736*uk_62 + 1271285991*uk_63 + 1613269357*uk_64 + 245335893*uk_65 + 51567417*uk_66 + 25002384*uk_67 + 267212979*uk_68 + 339094833*uk_69 + 217*uk_7 + 51567417*uk_70 + 12122368*uk_71 + 129557808*uk_72 + 164409616*uk_73 + 25002384*uk_74 + 1384649073*uk_75 + 1757127771*uk_76 + 267212979*uk_77 + 2229805417*uk_78 + 339094833*uk_79 + 33*uk_8 + 51567417*uk_80 + 250047*uk_81 + 623133*uk_82 + 130977*uk_83 + 63504*uk_84 + 678699*uk_85 + 861273*uk_86 + 130977*uk_87 + 1552887*uk_88 + 326403*uk_89 + 2242306609*uk_9 + 158256*uk_90 + 1691361*uk_91 + 2146347*uk_92 + 326403*uk_93 + 68607*uk_94 + 33264*uk_95 + 355509*uk_96 + 451143*uk_97 + 68607*uk_98 + 16128*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 130788*uk_100 + 164052*uk_101 + 118692*uk_102 + 1885527*uk_103 + 2365083*uk_104 + 1711143*uk_105 + 2966607*uk_106 + 2146347*uk_107 + 1552887*uk_108 + 1906624*uk_109 + 5871772*uk_11 + 2414032*uk_110 + 184512*uk_111 + 2660048*uk_112 + 3336592*uk_113 + 2414032*uk_114 + 3056476*uk_115 + 233616*uk_116 + 3367964*uk_117 + 4224556*uk_118 + 3056476*uk_119 + 7434421*uk_12 + 17856*uk_120 + 257424*uk_121 + 322896*uk_122 + 233616*uk_123 + 3711196*uk_124 + 4655084*uk_125 + 3367964*uk_126 + 5839036*uk_127 + 4224556*uk_128 + 3056476*uk_129 + 568236*uk_13 + 3869893*uk_130 + 295788*uk_131 + 4264277*uk_132 + 5348833*uk_133 + 3869893*uk_134 + 22608*uk_135 + 325932*uk_136 + 408828*uk_137 + 295788*uk_138 + 4698853*uk_139 + 8192069*uk_14 + 5893937*uk_140 + 4264277*uk_141 + 7392973*uk_142 + 5348833*uk_143 + 3869893*uk_144 + 1728*uk_145 + 24912*uk_146 + 31248*uk_147 + 22608*uk_148 + 359148*uk_149 + 10275601*uk_15 + 450492*uk_150 + 325932*uk_151 + 565068*uk_152 + 408828*uk_153 + 295788*uk_154 + 5177717*uk_155 + 6494593*uk_156 + 4698853*uk_157 + 8146397*uk_158 + 5893937*uk_159 + 7434421*uk_16 + 4264277*uk_160 + 10218313*uk_161 + 7392973*uk_162 + 5348833*uk_163 + 3869893*uk_164 + 3969*uk_17 + 7812*uk_18 + 9891*uk_19 + 63*uk_2 + 756*uk_20 + 10899*uk_21 + 13671*uk_22 + 9891*uk_23 + 15376*uk_24 + 19468*uk_25 + 1488*uk_26 + 21452*uk_27 + 26908*uk_28 + 19468*uk_29 + 124*uk_3 + 24649*uk_30 + 1884*uk_31 + 27161*uk_32 + 34069*uk_33 + 24649*uk_34 + 144*uk_35 + 2076*uk_36 + 2604*uk_37 + 1884*uk_38 + 29929*uk_39 + 157*uk_4 + 37541*uk_40 + 27161*uk_41 + 47089*uk_42 + 34069*uk_43 + 24649*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 278046019516*uk_47 + 352042137613*uk_48 + 26907679308*uk_49 + 12*uk_5 + 387919043357*uk_50 + 486580534153*uk_51 + 352042137613*uk_52 + 187944057*uk_53 + 369921636*uk_54 + 468368523*uk_55 + 35798868*uk_56 + 516100347*uk_57 + 647362863*uk_58 + 468368523*uk_59 + 173*uk_6 + 728099728*uk_60 + 921868204*uk_61 + 70461264*uk_62 + 1015816556*uk_63 + 1274174524*uk_64 + 921868204*uk_65 + 1167204097*uk_66 + 89213052*uk_67 + 1286154833*uk_68 + 1613269357*uk_69 + 217*uk_7 + 1167204097*uk_70 + 6818832*uk_71 + 98304828*uk_72 + 123307212*uk_73 + 89213052*uk_74 + 1417227937*uk_75 + 1777678973*uk_76 + 1286154833*uk_77 + 2229805417*uk_78 + 1613269357*uk_79 + 157*uk_8 + 1167204097*uk_80 + 250047*uk_81 + 492156*uk_82 + 623133*uk_83 + 47628*uk_84 + 686637*uk_85 + 861273*uk_86 + 623133*uk_87 + 968688*uk_88 + 1226484*uk_89 + 2242306609*uk_9 + 93744*uk_90 + 1351476*uk_91 + 1695204*uk_92 + 1226484*uk_93 + 1552887*uk_94 + 118692*uk_95 + 1711143*uk_96 + 2146347*uk_97 + 1552887*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 132300*uk_100 + 164052*uk_101 + 93744*uk_102 + 1929375*uk_103 + 2392425*uk_104 + 1367100*uk_105 + 2966607*uk_106 + 1695204*uk_107 + 968688*uk_108 + 1092727*uk_109 + 4877359*uk_11 + 1315516*uk_110 + 127308*uk_111 + 1856575*uk_112 + 2302153*uk_113 + 1315516*uk_114 + 1583728*uk_115 + 153264*uk_116 + 2235100*uk_117 + 2771524*uk_118 + 1583728*uk_119 + 5871772*uk_12 + 14832*uk_120 + 216300*uk_121 + 268212*uk_122 + 153264*uk_123 + 3154375*uk_124 + 3911425*uk_125 + 2235100*uk_126 + 4850167*uk_127 + 2771524*uk_128 + 1583728*uk_129 + 568236*uk_13 + 1906624*uk_130 + 184512*uk_131 + 2690800*uk_132 + 3336592*uk_133 + 1906624*uk_134 + 17856*uk_135 + 260400*uk_136 + 322896*uk_137 + 184512*uk_138 + 3797500*uk_139 + 8286775*uk_14 + 4708900*uk_140 + 2690800*uk_141 + 5839036*uk_142 + 3336592*uk_143 + 1906624*uk_144 + 1728*uk_145 + 25200*uk_146 + 31248*uk_147 + 17856*uk_148 + 367500*uk_149 + 10275601*uk_15 + 455700*uk_150 + 260400*uk_151 + 565068*uk_152 + 322896*uk_153 + 184512*uk_154 + 5359375*uk_155 + 6645625*uk_156 + 3797500*uk_157 + 8240575*uk_158 + 4708900*uk_159 + 5871772*uk_16 + 2690800*uk_160 + 10218313*uk_161 + 5839036*uk_162 + 3336592*uk_163 + 1906624*uk_164 + 3969*uk_17 + 6489*uk_18 + 7812*uk_19 + 63*uk_2 + 756*uk_20 + 11025*uk_21 + 13671*uk_22 + 7812*uk_23 + 10609*uk_24 + 12772*uk_25 + 1236*uk_26 + 18025*uk_27 + 22351*uk_28 + 12772*uk_29 + 103*uk_3 + 15376*uk_30 + 1488*uk_31 + 21700*uk_32 + 26908*uk_33 + 15376*uk_34 + 144*uk_35 + 2100*uk_36 + 2604*uk_37 + 1488*uk_38 + 30625*uk_39 + 124*uk_4 + 37975*uk_40 + 21700*uk_41 + 47089*uk_42 + 26908*uk_43 + 15376*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 230957580727*uk_47 + 278046019516*uk_48 + 26907679308*uk_49 + 12*uk_5 + 392403656575*uk_50 + 486580534153*uk_51 + 278046019516*uk_52 + 187944057*uk_53 + 307273617*uk_54 + 369921636*uk_55 + 35798868*uk_56 + 522066825*uk_57 + 647362863*uk_58 + 369921636*uk_59 + 175*uk_6 + 502367977*uk_60 + 604792516*uk_61 + 58528308*uk_62 + 853537825*uk_63 + 1058386903*uk_64 + 604792516*uk_65 + 728099728*uk_66 + 70461264*uk_67 + 1027560100*uk_68 + 1274174524*uk_69 + 217*uk_7 + 728099728*uk_70 + 6818832*uk_71 + 99441300*uk_72 + 123307212*uk_73 + 70461264*uk_74 + 1450185625*uk_75 + 1798230175*uk_76 + 1027560100*uk_77 + 2229805417*uk_78 + 1274174524*uk_79 + 124*uk_8 + 728099728*uk_80 + 250047*uk_81 + 408807*uk_82 + 492156*uk_83 + 47628*uk_84 + 694575*uk_85 + 861273*uk_86 + 492156*uk_87 + 668367*uk_88 + 804636*uk_89 + 2242306609*uk_9 + 77868*uk_90 + 1135575*uk_91 + 1408113*uk_92 + 804636*uk_93 + 968688*uk_94 + 93744*uk_95 + 1367100*uk_96 + 1695204*uk_97 + 968688*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 133812*uk_100 + 164052*uk_101 + 77868*uk_102 + 1973727*uk_103 + 2419767*uk_104 + 1148553*uk_105 + 2966607*uk_106 + 1408113*uk_107 + 668367*uk_108 + 830584*uk_109 + 4451182*uk_11 + 910108*uk_110 + 106032*uk_111 + 1563972*uk_112 + 1917412*uk_113 + 910108*uk_114 + 997246*uk_115 + 116184*uk_116 + 1713714*uk_117 + 2100994*uk_118 + 997246*uk_119 + 4877359*uk_12 + 13536*uk_120 + 199656*uk_121 + 244776*uk_122 + 116184*uk_123 + 2944926*uk_124 + 3610446*uk_125 + 1713714*uk_126 + 4426366*uk_127 + 2100994*uk_128 + 997246*uk_129 + 568236*uk_13 + 1092727*uk_130 + 127308*uk_131 + 1877793*uk_132 + 2302153*uk_133 + 1092727*uk_134 + 14832*uk_135 + 218772*uk_136 + 268212*uk_137 + 127308*uk_138 + 3226887*uk_139 + 8381481*uk_14 + 3956127*uk_140 + 1877793*uk_141 + 4850167*uk_142 + 2302153*uk_143 + 1092727*uk_144 + 1728*uk_145 + 25488*uk_146 + 31248*uk_147 + 14832*uk_148 + 375948*uk_149 + 10275601*uk_15 + 460908*uk_150 + 218772*uk_151 + 565068*uk_152 + 268212*uk_153 + 127308*uk_154 + 5545233*uk_155 + 6798393*uk_156 + 3226887*uk_157 + 8334753*uk_158 + 3956127*uk_159 + 4877359*uk_16 + 1877793*uk_160 + 10218313*uk_161 + 4850167*uk_162 + 2302153*uk_163 + 1092727*uk_164 + 3969*uk_17 + 5922*uk_18 + 6489*uk_19 + 63*uk_2 + 756*uk_20 + 11151*uk_21 + 13671*uk_22 + 6489*uk_23 + 8836*uk_24 + 9682*uk_25 + 1128*uk_26 + 16638*uk_27 + 20398*uk_28 + 9682*uk_29 + 94*uk_3 + 10609*uk_30 + 1236*uk_31 + 18231*uk_32 + 22351*uk_33 + 10609*uk_34 + 144*uk_35 + 2124*uk_36 + 2604*uk_37 + 1236*uk_38 + 31329*uk_39 + 103*uk_4 + 38409*uk_40 + 18231*uk_41 + 47089*uk_42 + 22351*uk_43 + 10609*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 210776821246*uk_47 + 230957580727*uk_48 + 26907679308*uk_49 + 12*uk_5 + 396888269793*uk_50 + 486580534153*uk_51 + 230957580727*uk_52 + 187944057*uk_53 + 280424466*uk_54 + 307273617*uk_55 + 35798868*uk_56 + 528033303*uk_57 + 647362863*uk_58 + 307273617*uk_59 + 177*uk_6 + 418411108*uk_60 + 458471746*uk_61 + 53414184*uk_62 + 787859214*uk_63 + 965906494*uk_64 + 458471746*uk_65 + 502367977*uk_66 + 58528308*uk_67 + 863292543*uk_68 + 1058386903*uk_69 + 217*uk_7 + 502367977*uk_70 + 6818832*uk_71 + 100577772*uk_72 + 123307212*uk_73 + 58528308*uk_74 + 1483522137*uk_75 + 1818781377*uk_76 + 863292543*uk_77 + 2229805417*uk_78 + 1058386903*uk_79 + 103*uk_8 + 502367977*uk_80 + 250047*uk_81 + 373086*uk_82 + 408807*uk_83 + 47628*uk_84 + 702513*uk_85 + 861273*uk_86 + 408807*uk_87 + 556668*uk_88 + 609966*uk_89 + 2242306609*uk_9 + 71064*uk_90 + 1048194*uk_91 + 1285074*uk_92 + 609966*uk_93 + 668367*uk_94 + 77868*uk_95 + 1148553*uk_96 + 1408113*uk_97 + 668367*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 135324*uk_100 + 164052*uk_101 + 71064*uk_102 + 2018583*uk_103 + 2447109*uk_104 + 1060038*uk_105 + 2966607*uk_106 + 1285074*uk_107 + 556668*uk_108 + 912673*uk_109 + 4593241*uk_11 + 884446*uk_110 + 112908*uk_111 + 1684211*uk_112 + 2041753*uk_113 + 884446*uk_114 + 857092*uk_115 + 109416*uk_116 + 1632122*uk_117 + 1978606*uk_118 + 857092*uk_119 + 4451182*uk_12 + 13968*uk_120 + 208356*uk_121 + 252588*uk_122 + 109416*uk_123 + 3107977*uk_124 + 3767771*uk_125 + 1632122*uk_126 + 4567633*uk_127 + 1978606*uk_128 + 857092*uk_129 + 568236*uk_13 + 830584*uk_130 + 106032*uk_131 + 1581644*uk_132 + 1917412*uk_133 + 830584*uk_134 + 13536*uk_135 + 201912*uk_136 + 244776*uk_137 + 106032*uk_138 + 3011854*uk_139 + 8476187*uk_14 + 3651242*uk_140 + 1581644*uk_141 + 4426366*uk_142 + 1917412*uk_143 + 830584*uk_144 + 1728*uk_145 + 25776*uk_146 + 31248*uk_147 + 13536*uk_148 + 384492*uk_149 + 10275601*uk_15 + 466116*uk_150 + 201912*uk_151 + 565068*uk_152 + 244776*uk_153 + 106032*uk_154 + 5735339*uk_155 + 6952897*uk_156 + 3011854*uk_157 + 8428931*uk_158 + 3651242*uk_159 + 4451182*uk_16 + 1581644*uk_160 + 10218313*uk_161 + 4426366*uk_162 + 1917412*uk_163 + 830584*uk_164 + 3969*uk_17 + 6111*uk_18 + 5922*uk_19 + 63*uk_2 + 756*uk_20 + 11277*uk_21 + 13671*uk_22 + 5922*uk_23 + 9409*uk_24 + 9118*uk_25 + 1164*uk_26 + 17363*uk_27 + 21049*uk_28 + 9118*uk_29 + 97*uk_3 + 8836*uk_30 + 1128*uk_31 + 16826*uk_32 + 20398*uk_33 + 8836*uk_34 + 144*uk_35 + 2148*uk_36 + 2604*uk_37 + 1128*uk_38 + 32041*uk_39 + 94*uk_4 + 38843*uk_40 + 16826*uk_41 + 47089*uk_42 + 20398*uk_43 + 8836*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 217503741073*uk_47 + 210776821246*uk_48 + 26907679308*uk_49 + 12*uk_5 + 401372883011*uk_50 + 486580534153*uk_51 + 210776821246*uk_52 + 187944057*uk_53 + 289374183*uk_54 + 280424466*uk_55 + 35798868*uk_56 + 533999781*uk_57 + 647362863*uk_58 + 280424466*uk_59 + 179*uk_6 + 445544377*uk_60 + 431764654*uk_61 + 55118892*uk_62 + 822190139*uk_63 + 996733297*uk_64 + 431764654*uk_65 + 418411108*uk_66 + 53414184*uk_67 + 796761578*uk_68 + 965906494*uk_69 + 217*uk_7 + 418411108*uk_70 + 6818832*uk_71 + 101714244*uk_72 + 123307212*uk_73 + 53414184*uk_74 + 1517237473*uk_75 + 1839332579*uk_76 + 796761578*uk_77 + 2229805417*uk_78 + 965906494*uk_79 + 94*uk_8 + 418411108*uk_80 + 250047*uk_81 + 384993*uk_82 + 373086*uk_83 + 47628*uk_84 + 710451*uk_85 + 861273*uk_86 + 373086*uk_87 + 592767*uk_88 + 574434*uk_89 + 2242306609*uk_9 + 73332*uk_90 + 1093869*uk_91 + 1326087*uk_92 + 574434*uk_93 + 556668*uk_94 + 71064*uk_95 + 1060038*uk_96 + 1285074*uk_97 + 556668*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 136836*uk_100 + 164052*uk_101 + 73332*uk_102 + 2063943*uk_103 + 2474451*uk_104 + 1106091*uk_105 + 2966607*uk_106 + 1326087*uk_107 + 592767*uk_108 + 1404928*uk_109 + 5303536*uk_11 + 1216768*uk_110 + 150528*uk_111 + 2270464*uk_112 + 2722048*uk_113 + 1216768*uk_114 + 1053808*uk_115 + 130368*uk_116 + 1966384*uk_117 + 2357488*uk_118 + 1053808*uk_119 + 4593241*uk_12 + 16128*uk_120 + 243264*uk_121 + 291648*uk_122 + 130368*uk_123 + 3669232*uk_124 + 4399024*uk_125 + 1966384*uk_126 + 5273968*uk_127 + 2357488*uk_128 + 1053808*uk_129 + 568236*uk_13 + 912673*uk_130 + 112908*uk_131 + 1703029*uk_132 + 2041753*uk_133 + 912673*uk_134 + 13968*uk_135 + 210684*uk_136 + 252588*uk_137 + 112908*uk_138 + 3177817*uk_139 + 8570893*uk_14 + 3809869*uk_140 + 1703029*uk_141 + 4567633*uk_142 + 2041753*uk_143 + 912673*uk_144 + 1728*uk_145 + 26064*uk_146 + 31248*uk_147 + 13968*uk_148 + 393132*uk_149 + 10275601*uk_15 + 471324*uk_150 + 210684*uk_151 + 565068*uk_152 + 252588*uk_153 + 112908*uk_154 + 5929741*uk_155 + 7109137*uk_156 + 3177817*uk_157 + 8523109*uk_158 + 3809869*uk_159 + 4593241*uk_16 + 1703029*uk_160 + 10218313*uk_161 + 4567633*uk_162 + 2041753*uk_163 + 912673*uk_164 + 3969*uk_17 + 7056*uk_18 + 6111*uk_19 + 63*uk_2 + 756*uk_20 + 11403*uk_21 + 13671*uk_22 + 6111*uk_23 + 12544*uk_24 + 10864*uk_25 + 1344*uk_26 + 20272*uk_27 + 24304*uk_28 + 10864*uk_29 + 112*uk_3 + 9409*uk_30 + 1164*uk_31 + 17557*uk_32 + 21049*uk_33 + 9409*uk_34 + 144*uk_35 + 2172*uk_36 + 2604*uk_37 + 1164*uk_38 + 32761*uk_39 + 97*uk_4 + 39277*uk_40 + 17557*uk_41 + 47089*uk_42 + 21049*uk_43 + 9409*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 251138340208*uk_47 + 217503741073*uk_48 + 26907679308*uk_49 + 12*uk_5 + 405857496229*uk_50 + 486580534153*uk_51 + 217503741073*uk_52 + 187944057*uk_53 + 334122768*uk_54 + 289374183*uk_55 + 35798868*uk_56 + 539966259*uk_57 + 647362863*uk_58 + 289374183*uk_59 + 181*uk_6 + 593996032*uk_60 + 514442992*uk_61 + 63642432*uk_62 + 959940016*uk_63 + 1150867312*uk_64 + 514442992*uk_65 + 445544377*uk_66 + 55118892*uk_67 + 831376621*uk_68 + 996733297*uk_69 + 217*uk_7 + 445544377*uk_70 + 6818832*uk_71 + 102850716*uk_72 + 123307212*uk_73 + 55118892*uk_74 + 1551331633*uk_75 + 1859883781*uk_76 + 831376621*uk_77 + 2229805417*uk_78 + 996733297*uk_79 + 97*uk_8 + 445544377*uk_80 + 250047*uk_81 + 444528*uk_82 + 384993*uk_83 + 47628*uk_84 + 718389*uk_85 + 861273*uk_86 + 384993*uk_87 + 790272*uk_88 + 684432*uk_89 + 2242306609*uk_9 + 84672*uk_90 + 1277136*uk_91 + 1531152*uk_92 + 684432*uk_93 + 592767*uk_94 + 73332*uk_95 + 1106091*uk_96 + 1326087*uk_97 + 592767*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 138348*uk_100 + 164052*uk_101 + 84672*uk_102 + 2109807*uk_103 + 2501793*uk_104 + 1291248*uk_105 + 2966607*uk_106 + 1531152*uk_107 + 790272*uk_108 + 2685619*uk_109 + 6582067*uk_11 + 2163952*uk_110 + 231852*uk_111 + 3535743*uk_112 + 4192657*uk_113 + 2163952*uk_114 + 1743616*uk_115 + 186816*uk_116 + 2848944*uk_117 + 3378256*uk_118 + 1743616*uk_119 + 5303536*uk_12 + 20016*uk_120 + 305244*uk_121 + 361956*uk_122 + 186816*uk_123 + 4654971*uk_124 + 5519829*uk_125 + 2848944*uk_126 + 6545371*uk_127 + 3378256*uk_128 + 1743616*uk_129 + 568236*uk_13 + 1404928*uk_130 + 150528*uk_131 + 2295552*uk_132 + 2722048*uk_133 + 1404928*uk_134 + 16128*uk_135 + 245952*uk_136 + 291648*uk_137 + 150528*uk_138 + 3750768*uk_139 + 8665599*uk_14 + 4447632*uk_140 + 2295552*uk_141 + 5273968*uk_142 + 2722048*uk_143 + 1404928*uk_144 + 1728*uk_145 + 26352*uk_146 + 31248*uk_147 + 16128*uk_148 + 401868*uk_149 + 10275601*uk_15 + 476532*uk_150 + 245952*uk_151 + 565068*uk_152 + 291648*uk_153 + 150528*uk_154 + 6128487*uk_155 + 7267113*uk_156 + 3750768*uk_157 + 8617287*uk_158 + 4447632*uk_159 + 5303536*uk_16 + 2295552*uk_160 + 10218313*uk_161 + 5273968*uk_162 + 2722048*uk_163 + 1404928*uk_164 + 3969*uk_17 + 8757*uk_18 + 7056*uk_19 + 63*uk_2 + 756*uk_20 + 11529*uk_21 + 13671*uk_22 + 7056*uk_23 + 19321*uk_24 + 15568*uk_25 + 1668*uk_26 + 25437*uk_27 + 30163*uk_28 + 15568*uk_29 + 139*uk_3 + 12544*uk_30 + 1344*uk_31 + 20496*uk_32 + 24304*uk_33 + 12544*uk_34 + 144*uk_35 + 2196*uk_36 + 2604*uk_37 + 1344*uk_38 + 33489*uk_39 + 112*uk_4 + 39711*uk_40 + 20496*uk_41 + 47089*uk_42 + 24304*uk_43 + 12544*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 311680618651*uk_47 + 251138340208*uk_48 + 26907679308*uk_49 + 12*uk_5 + 410342109447*uk_50 + 486580534153*uk_51 + 251138340208*uk_52 + 187944057*uk_53 + 414670221*uk_54 + 334122768*uk_55 + 35798868*uk_56 + 545932737*uk_57 + 647362863*uk_58 + 334122768*uk_59 + 183*uk_6 + 914907313*uk_60 + 737191504*uk_61 + 78984804*uk_62 + 1204518261*uk_63 + 1428308539*uk_64 + 737191504*uk_65 + 593996032*uk_66 + 63642432*uk_67 + 970547088*uk_68 + 1150867312*uk_69 + 217*uk_7 + 593996032*uk_70 + 6818832*uk_71 + 103987188*uk_72 + 123307212*uk_73 + 63642432*uk_74 + 1585804617*uk_75 + 1880434983*uk_76 + 970547088*uk_77 + 2229805417*uk_78 + 1150867312*uk_79 + 112*uk_8 + 593996032*uk_80 + 250047*uk_81 + 551691*uk_82 + 444528*uk_83 + 47628*uk_84 + 726327*uk_85 + 861273*uk_86 + 444528*uk_87 + 1217223*uk_88 + 980784*uk_89 + 2242306609*uk_9 + 105084*uk_90 + 1602531*uk_91 + 1900269*uk_92 + 980784*uk_93 + 790272*uk_94 + 84672*uk_95 + 1291248*uk_96 + 1531152*uk_97 + 790272*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 139860*uk_100 + 164052*uk_101 + 105084*uk_102 + 2156175*uk_103 + 2529135*uk_104 + 1620045*uk_105 + 2966607*uk_106 + 1900269*uk_107 + 1217223*uk_108 + 5639752*uk_109 + 8428834*uk_11 + 4404076*uk_110 + 380208*uk_111 + 5861540*uk_112 + 6875428*uk_113 + 4404076*uk_114 + 3439138*uk_115 + 296904*uk_116 + 4577270*uk_117 + 5369014*uk_118 + 3439138*uk_119 + 6582067*uk_12 + 25632*uk_120 + 395160*uk_121 + 463512*uk_122 + 296904*uk_123 + 6092050*uk_124 + 7145810*uk_125 + 4577270*uk_126 + 8381842*uk_127 + 5369014*uk_128 + 3439138*uk_129 + 568236*uk_13 + 2685619*uk_130 + 231852*uk_131 + 3574385*uk_132 + 4192657*uk_133 + 2685619*uk_134 + 20016*uk_135 + 308580*uk_136 + 361956*uk_137 + 231852*uk_138 + 4757275*uk_139 + 8760305*uk_14 + 5580155*uk_140 + 3574385*uk_141 + 6545371*uk_142 + 4192657*uk_143 + 2685619*uk_144 + 1728*uk_145 + 26640*uk_146 + 31248*uk_147 + 20016*uk_148 + 410700*uk_149 + 10275601*uk_15 + 481740*uk_150 + 308580*uk_151 + 565068*uk_152 + 361956*uk_153 + 231852*uk_154 + 6331625*uk_155 + 7426825*uk_156 + 4757275*uk_157 + 8711465*uk_158 + 5580155*uk_159 + 6582067*uk_16 + 3574385*uk_160 + 10218313*uk_161 + 6545371*uk_162 + 4192657*uk_163 + 2685619*uk_164 + 3969*uk_17 + 11214*uk_18 + 8757*uk_19 + 63*uk_2 + 756*uk_20 + 11655*uk_21 + 13671*uk_22 + 8757*uk_23 + 31684*uk_24 + 24742*uk_25 + 2136*uk_26 + 32930*uk_27 + 38626*uk_28 + 24742*uk_29 + 178*uk_3 + 19321*uk_30 + 1668*uk_31 + 25715*uk_32 + 30163*uk_33 + 19321*uk_34 + 144*uk_35 + 2220*uk_36 + 2604*uk_37 + 1668*uk_38 + 34225*uk_39 + 139*uk_4 + 40145*uk_40 + 25715*uk_41 + 47089*uk_42 + 30163*uk_43 + 19321*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 399130576402*uk_47 + 311680618651*uk_48 + 26907679308*uk_49 + 12*uk_5 + 414826722665*uk_50 + 486580534153*uk_51 + 311680618651*uk_52 + 187944057*uk_53 + 531016542*uk_54 + 414670221*uk_55 + 35798868*uk_56 + 551899215*uk_57 + 647362863*uk_58 + 414670221*uk_59 + 185*uk_6 + 1500332452*uk_60 + 1171607926*uk_61 + 101146008*uk_62 + 1559334290*uk_63 + 1829056978*uk_64 + 1171607926*uk_65 + 914907313*uk_66 + 78984804*uk_67 + 1217682395*uk_68 + 1428308539*uk_69 + 217*uk_7 + 914907313*uk_70 + 6818832*uk_71 + 105123660*uk_72 + 123307212*uk_73 + 78984804*uk_74 + 1620656425*uk_75 + 1900986185*uk_76 + 1217682395*uk_77 + 2229805417*uk_78 + 1428308539*uk_79 + 139*uk_8 + 914907313*uk_80 + 250047*uk_81 + 706482*uk_82 + 551691*uk_83 + 47628*uk_84 + 734265*uk_85 + 861273*uk_86 + 551691*uk_87 + 1996092*uk_88 + 1558746*uk_89 + 2242306609*uk_9 + 134568*uk_90 + 2074590*uk_91 + 2433438*uk_92 + 1558746*uk_93 + 1217223*uk_94 + 105084*uk_95 + 1620045*uk_96 + 1900269*uk_97 + 1217223*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 94248*uk_100 + 109368*uk_101 + 89712*uk_102 + 2203047*uk_103 + 2556477*uk_104 + 2097018*uk_105 + 2966607*uk_106 + 2433438*uk_107 + 1996092*uk_108 + 74088*uk_109 + 1988826*uk_11 + 313992*uk_110 + 14112*uk_111 + 329868*uk_112 + 382788*uk_113 + 313992*uk_114 + 1330728*uk_115 + 59808*uk_116 + 1398012*uk_117 + 1622292*uk_118 + 1330728*uk_119 + 8428834*uk_12 + 2688*uk_120 + 62832*uk_121 + 72912*uk_122 + 59808*uk_123 + 1468698*uk_124 + 1704318*uk_125 + 1398012*uk_126 + 1977738*uk_127 + 1622292*uk_128 + 1330728*uk_129 + 378824*uk_13 + 5639752*uk_130 + 253472*uk_131 + 5924908*uk_132 + 6875428*uk_133 + 5639752*uk_134 + 11392*uk_135 + 266288*uk_136 + 309008*uk_137 + 253472*uk_138 + 6224482*uk_139 + 8855011*uk_14 + 7223062*uk_140 + 5924908*uk_141 + 8381842*uk_142 + 6875428*uk_143 + 5639752*uk_144 + 512*uk_145 + 11968*uk_146 + 13888*uk_147 + 11392*uk_148 + 279752*uk_149 + 10275601*uk_15 + 324632*uk_150 + 266288*uk_151 + 376712*uk_152 + 309008*uk_153 + 253472*uk_154 + 6539203*uk_155 + 7588273*uk_156 + 6224482*uk_157 + 8805643*uk_158 + 7223062*uk_159 + 8428834*uk_16 + 5924908*uk_160 + 10218313*uk_161 + 8381842*uk_162 + 6875428*uk_163 + 5639752*uk_164 + 3969*uk_17 + 2646*uk_18 + 11214*uk_19 + 63*uk_2 + 504*uk_20 + 11781*uk_21 + 13671*uk_22 + 11214*uk_23 + 1764*uk_24 + 7476*uk_25 + 336*uk_26 + 7854*uk_27 + 9114*uk_28 + 7476*uk_29 + 42*uk_3 + 31684*uk_30 + 1424*uk_31 + 33286*uk_32 + 38626*uk_33 + 31684*uk_34 + 64*uk_35 + 1496*uk_36 + 1736*uk_37 + 1424*uk_38 + 34969*uk_39 + 178*uk_4 + 40579*uk_40 + 33286*uk_41 + 47089*uk_42 + 38626*uk_43 + 31684*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 94176877578*uk_47 + 399130576402*uk_48 + 17938452872*uk_49 + 8*uk_5 + 419311335883*uk_50 + 486580534153*uk_51 + 399130576402*uk_52 + 187944057*uk_53 + 125296038*uk_54 + 531016542*uk_55 + 23865912*uk_56 + 557865693*uk_57 + 647362863*uk_58 + 531016542*uk_59 + 187*uk_6 + 83530692*uk_60 + 354011028*uk_61 + 15910608*uk_62 + 371910462*uk_63 + 431575242*uk_64 + 354011028*uk_65 + 1500332452*uk_66 + 67430672*uk_67 + 1576191958*uk_68 + 1829056978*uk_69 + 217*uk_7 + 1500332452*uk_70 + 3030592*uk_71 + 70840088*uk_72 + 82204808*uk_73 + 67430672*uk_74 + 1655887057*uk_75 + 1921537387*uk_76 + 1576191958*uk_77 + 2229805417*uk_78 + 1829056978*uk_79 + 178*uk_8 + 1500332452*uk_80 + 250047*uk_81 + 166698*uk_82 + 706482*uk_83 + 31752*uk_84 + 742203*uk_85 + 861273*uk_86 + 706482*uk_87 + 111132*uk_88 + 470988*uk_89 + 2242306609*uk_9 + 21168*uk_90 + 494802*uk_91 + 574182*uk_92 + 470988*uk_93 + 1996092*uk_94 + 89712*uk_95 + 2097018*uk_96 + 2433438*uk_97 + 1996092*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 142884*uk_100 + 164052*uk_101 + 31752*uk_102 + 2250423*uk_103 + 2583819*uk_104 + 500094*uk_105 + 2966607*uk_106 + 574182*uk_107 + 111132*uk_108 + 1092727*uk_109 + 4877359*uk_11 + 445578*uk_110 + 127308*uk_111 + 2005101*uk_112 + 2302153*uk_113 + 445578*uk_114 + 181692*uk_115 + 51912*uk_116 + 817614*uk_117 + 938742*uk_118 + 181692*uk_119 + 1988826*uk_12 + 14832*uk_120 + 233604*uk_121 + 268212*uk_122 + 51912*uk_123 + 3679263*uk_124 + 4224339*uk_125 + 817614*uk_126 + 4850167*uk_127 + 938742*uk_128 + 181692*uk_129 + 568236*uk_13 + 74088*uk_130 + 21168*uk_131 + 333396*uk_132 + 382788*uk_133 + 74088*uk_134 + 6048*uk_135 + 95256*uk_136 + 109368*uk_137 + 21168*uk_138 + 1500282*uk_139 + 8949717*uk_14 + 1722546*uk_140 + 333396*uk_141 + 1977738*uk_142 + 382788*uk_143 + 74088*uk_144 + 1728*uk_145 + 27216*uk_146 + 31248*uk_147 + 6048*uk_148 + 428652*uk_149 + 10275601*uk_15 + 492156*uk_150 + 95256*uk_151 + 565068*uk_152 + 109368*uk_153 + 21168*uk_154 + 6751269*uk_155 + 7751457*uk_156 + 1500282*uk_157 + 8899821*uk_158 + 1722546*uk_159 + 1988826*uk_16 + 333396*uk_160 + 10218313*uk_161 + 1977738*uk_162 + 382788*uk_163 + 74088*uk_164 + 3969*uk_17 + 6489*uk_18 + 2646*uk_19 + 63*uk_2 + 756*uk_20 + 11907*uk_21 + 13671*uk_22 + 2646*uk_23 + 10609*uk_24 + 4326*uk_25 + 1236*uk_26 + 19467*uk_27 + 22351*uk_28 + 4326*uk_29 + 103*uk_3 + 1764*uk_30 + 504*uk_31 + 7938*uk_32 + 9114*uk_33 + 1764*uk_34 + 144*uk_35 + 2268*uk_36 + 2604*uk_37 + 504*uk_38 + 35721*uk_39 + 42*uk_4 + 41013*uk_40 + 7938*uk_41 + 47089*uk_42 + 9114*uk_43 + 1764*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 230957580727*uk_47 + 94176877578*uk_48 + 26907679308*uk_49 + 12*uk_5 + 423795949101*uk_50 + 486580534153*uk_51 + 94176877578*uk_52 + 187944057*uk_53 + 307273617*uk_54 + 125296038*uk_55 + 35798868*uk_56 + 563832171*uk_57 + 647362863*uk_58 + 125296038*uk_59 + 189*uk_6 + 502367977*uk_60 + 204849078*uk_61 + 58528308*uk_62 + 921820851*uk_63 + 1058386903*uk_64 + 204849078*uk_65 + 83530692*uk_66 + 23865912*uk_67 + 375888114*uk_68 + 431575242*uk_69 + 217*uk_7 + 83530692*uk_70 + 6818832*uk_71 + 107396604*uk_72 + 123307212*uk_73 + 23865912*uk_74 + 1691496513*uk_75 + 1942088589*uk_76 + 375888114*uk_77 + 2229805417*uk_78 + 431575242*uk_79 + 42*uk_8 + 83530692*uk_80 + 250047*uk_81 + 408807*uk_82 + 166698*uk_83 + 47628*uk_84 + 750141*uk_85 + 861273*uk_86 + 166698*uk_87 + 668367*uk_88 + 272538*uk_89 + 2242306609*uk_9 + 77868*uk_90 + 1226421*uk_91 + 1408113*uk_92 + 272538*uk_93 + 111132*uk_94 + 31752*uk_95 + 500094*uk_96 + 574182*uk_97 + 111132*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 144396*uk_100 + 164052*uk_101 + 77868*uk_102 + 2298303*uk_103 + 2611161*uk_104 + 1239399*uk_105 + 2966607*uk_106 + 1408113*uk_107 + 668367*uk_108 + 5451776*uk_109 + 8334128*uk_11 + 3190528*uk_110 + 371712*uk_111 + 5916416*uk_112 + 6721792*uk_113 + 3190528*uk_114 + 1867184*uk_115 + 217536*uk_116 + 3462448*uk_117 + 3933776*uk_118 + 1867184*uk_119 + 4877359*uk_12 + 25344*uk_120 + 403392*uk_121 + 458304*uk_122 + 217536*uk_123 + 6420656*uk_124 + 7294672*uk_125 + 3462448*uk_126 + 8287664*uk_127 + 3933776*uk_128 + 1867184*uk_129 + 568236*uk_13 + 1092727*uk_130 + 127308*uk_131 + 2026319*uk_132 + 2302153*uk_133 + 1092727*uk_134 + 14832*uk_135 + 236076*uk_136 + 268212*uk_137 + 127308*uk_138 + 3757543*uk_139 + 9044423*uk_14 + 4269041*uk_140 + 2026319*uk_141 + 4850167*uk_142 + 2302153*uk_143 + 1092727*uk_144 + 1728*uk_145 + 27504*uk_146 + 31248*uk_147 + 14832*uk_148 + 437772*uk_149 + 10275601*uk_15 + 497364*uk_150 + 236076*uk_151 + 565068*uk_152 + 268212*uk_153 + 127308*uk_154 + 6967871*uk_155 + 7916377*uk_156 + 3757543*uk_157 + 8993999*uk_158 + 4269041*uk_159 + 4877359*uk_16 + 2026319*uk_160 + 10218313*uk_161 + 4850167*uk_162 + 2302153*uk_163 + 1092727*uk_164 + 3969*uk_17 + 11088*uk_18 + 6489*uk_19 + 63*uk_2 + 756*uk_20 + 12033*uk_21 + 13671*uk_22 + 6489*uk_23 + 30976*uk_24 + 18128*uk_25 + 2112*uk_26 + 33616*uk_27 + 38192*uk_28 + 18128*uk_29 + 176*uk_3 + 10609*uk_30 + 1236*uk_31 + 19673*uk_32 + 22351*uk_33 + 10609*uk_34 + 144*uk_35 + 2292*uk_36 + 2604*uk_37 + 1236*uk_38 + 36481*uk_39 + 103*uk_4 + 41447*uk_40 + 19673*uk_41 + 47089*uk_42 + 22351*uk_43 + 10609*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 394645963184*uk_47 + 230957580727*uk_48 + 26907679308*uk_49 + 12*uk_5 + 428280562319*uk_50 + 486580534153*uk_51 + 230957580727*uk_52 + 187944057*uk_53 + 525050064*uk_54 + 307273617*uk_55 + 35798868*uk_56 + 569798649*uk_57 + 647362863*uk_58 + 307273617*uk_59 + 191*uk_6 + 1466806528*uk_60 + 858415184*uk_61 + 100009536*uk_62 + 1591818448*uk_63 + 1808505776*uk_64 + 858415184*uk_65 + 502367977*uk_66 + 58528308*uk_67 + 931575569*uk_68 + 1058386903*uk_69 + 217*uk_7 + 502367977*uk_70 + 6818832*uk_71 + 108533076*uk_72 + 123307212*uk_73 + 58528308*uk_74 + 1727484793*uk_75 + 1962639791*uk_76 + 931575569*uk_77 + 2229805417*uk_78 + 1058386903*uk_79 + 103*uk_8 + 502367977*uk_80 + 250047*uk_81 + 698544*uk_82 + 408807*uk_83 + 47628*uk_84 + 758079*uk_85 + 861273*uk_86 + 408807*uk_87 + 1951488*uk_88 + 1142064*uk_89 + 2242306609*uk_9 + 133056*uk_90 + 2117808*uk_91 + 2406096*uk_92 + 1142064*uk_93 + 668367*uk_94 + 77868*uk_95 + 1239399*uk_96 + 1408113*uk_97 + 668367*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 97272*uk_100 + 109368*uk_101 + 88704*uk_102 + 2346687*uk_103 + 2638503*uk_104 + 2139984*uk_105 + 2966607*uk_106 + 2406096*uk_107 + 1951488*uk_108 + 314432*uk_109 + 3220004*uk_11 + 813824*uk_110 + 36992*uk_111 + 892432*uk_112 + 1003408*uk_113 + 813824*uk_114 + 2106368*uk_115 + 95744*uk_116 + 2309824*uk_117 + 2597056*uk_118 + 2106368*uk_119 + 8334128*uk_12 + 4352*uk_120 + 104992*uk_121 + 118048*uk_122 + 95744*uk_123 + 2532932*uk_124 + 2847908*uk_125 + 2309824*uk_126 + 3202052*uk_127 + 2597056*uk_128 + 2106368*uk_129 + 378824*uk_13 + 5451776*uk_130 + 247808*uk_131 + 5978368*uk_132 + 6721792*uk_133 + 5451776*uk_134 + 11264*uk_135 + 271744*uk_136 + 305536*uk_137 + 247808*uk_138 + 6555824*uk_139 + 9139129*uk_14 + 7371056*uk_140 + 5978368*uk_141 + 8287664*uk_142 + 6721792*uk_143 + 5451776*uk_144 + 512*uk_145 + 12352*uk_146 + 13888*uk_147 + 11264*uk_148 + 297992*uk_149 + 10275601*uk_15 + 335048*uk_150 + 271744*uk_151 + 376712*uk_152 + 305536*uk_153 + 247808*uk_154 + 7189057*uk_155 + 8083033*uk_156 + 6555824*uk_157 + 9088177*uk_158 + 7371056*uk_159 + 8334128*uk_16 + 5978368*uk_160 + 10218313*uk_161 + 8287664*uk_162 + 6721792*uk_163 + 5451776*uk_164 + 3969*uk_17 + 4284*uk_18 + 11088*uk_19 + 63*uk_2 + 504*uk_20 + 12159*uk_21 + 13671*uk_22 + 11088*uk_23 + 4624*uk_24 + 11968*uk_25 + 544*uk_26 + 13124*uk_27 + 14756*uk_28 + 11968*uk_29 + 68*uk_3 + 30976*uk_30 + 1408*uk_31 + 33968*uk_32 + 38192*uk_33 + 30976*uk_34 + 64*uk_35 + 1544*uk_36 + 1736*uk_37 + 1408*uk_38 + 37249*uk_39 + 176*uk_4 + 41881*uk_40 + 33968*uk_41 + 47089*uk_42 + 38192*uk_43 + 30976*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 152476849412*uk_47 + 394645963184*uk_48 + 17938452872*uk_49 + 8*uk_5 + 432765175537*uk_50 + 486580534153*uk_51 + 394645963184*uk_52 + 187944057*uk_53 + 202860252*uk_54 + 525050064*uk_55 + 23865912*uk_56 + 575765127*uk_57 + 647362863*uk_58 + 525050064*uk_59 + 193*uk_6 + 218960272*uk_60 + 566720704*uk_61 + 25760032*uk_62 + 621460772*uk_63 + 698740868*uk_64 + 566720704*uk_65 + 1466806528*uk_66 + 66673024*uk_67 + 1608486704*uk_68 + 1808505776*uk_69 + 217*uk_7 + 1466806528*uk_70 + 3030592*uk_71 + 73113032*uk_72 + 82204808*uk_73 + 66673024*uk_74 + 1763851897*uk_75 + 1983190993*uk_76 + 1608486704*uk_77 + 2229805417*uk_78 + 1808505776*uk_79 + 176*uk_8 + 1466806528*uk_80 + 250047*uk_81 + 269892*uk_82 + 698544*uk_83 + 31752*uk_84 + 766017*uk_85 + 861273*uk_86 + 698544*uk_87 + 291312*uk_88 + 753984*uk_89 + 2242306609*uk_9 + 34272*uk_90 + 826812*uk_91 + 929628*uk_92 + 753984*uk_93 + 1951488*uk_94 + 88704*uk_95 + 2139984*uk_96 + 2406096*uk_97 + 1951488*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 147420*uk_100 + 164052*uk_101 + 51408*uk_102 + 2395575*uk_103 + 2665845*uk_104 + 835380*uk_105 + 2966607*uk_106 + 929628*uk_107 + 291312*uk_108 + 4330747*uk_109 + 7718539*uk_11 + 1806692*uk_110 + 318828*uk_111 + 5180955*uk_112 + 5765473*uk_113 + 1806692*uk_114 + 753712*uk_115 + 133008*uk_116 + 2161380*uk_117 + 2405228*uk_118 + 753712*uk_119 + 3220004*uk_12 + 23472*uk_120 + 381420*uk_121 + 424452*uk_122 + 133008*uk_123 + 6198075*uk_124 + 6897345*uk_125 + 2161380*uk_126 + 7675507*uk_127 + 2405228*uk_128 + 753712*uk_129 + 568236*uk_13 + 314432*uk_130 + 55488*uk_131 + 901680*uk_132 + 1003408*uk_133 + 314432*uk_134 + 9792*uk_135 + 159120*uk_136 + 177072*uk_137 + 55488*uk_138 + 2585700*uk_139 + 9233835*uk_14 + 2877420*uk_140 + 901680*uk_141 + 3202052*uk_142 + 1003408*uk_143 + 314432*uk_144 + 1728*uk_145 + 28080*uk_146 + 31248*uk_147 + 9792*uk_148 + 456300*uk_149 + 10275601*uk_15 + 507780*uk_150 + 159120*uk_151 + 565068*uk_152 + 177072*uk_153 + 55488*uk_154 + 7414875*uk_155 + 8251425*uk_156 + 2585700*uk_157 + 9182355*uk_158 + 2877420*uk_159 + 3220004*uk_16 + 901680*uk_160 + 10218313*uk_161 + 3202052*uk_162 + 1003408*uk_163 + 314432*uk_164 + 3969*uk_17 + 10269*uk_18 + 4284*uk_19 + 63*uk_2 + 756*uk_20 + 12285*uk_21 + 13671*uk_22 + 4284*uk_23 + 26569*uk_24 + 11084*uk_25 + 1956*uk_26 + 31785*uk_27 + 35371*uk_28 + 11084*uk_29 + 163*uk_3 + 4624*uk_30 + 816*uk_31 + 13260*uk_32 + 14756*uk_33 + 4624*uk_34 + 144*uk_35 + 2340*uk_36 + 2604*uk_37 + 816*uk_38 + 38025*uk_39 + 68*uk_4 + 42315*uk_40 + 13260*uk_41 + 47089*uk_42 + 14756*uk_43 + 4624*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 365495977267*uk_47 + 152476849412*uk_48 + 26907679308*uk_49 + 12*uk_5 + 437249788755*uk_50 + 486580534153*uk_51 + 152476849412*uk_52 + 187944057*uk_53 + 486267957*uk_54 + 202860252*uk_55 + 35798868*uk_56 + 581731605*uk_57 + 647362863*uk_58 + 202860252*uk_59 + 195*uk_6 + 1258121857*uk_60 + 524860652*uk_61 + 92622468*uk_62 + 1505115105*uk_63 + 1674922963*uk_64 + 524860652*uk_65 + 218960272*uk_66 + 38640048*uk_67 + 627900780*uk_68 + 698740868*uk_69 + 217*uk_7 + 218960272*uk_70 + 6818832*uk_71 + 110806020*uk_72 + 123307212*uk_73 + 38640048*uk_74 + 1800597825*uk_75 + 2003742195*uk_76 + 627900780*uk_77 + 2229805417*uk_78 + 698740868*uk_79 + 68*uk_8 + 218960272*uk_80 + 250047*uk_81 + 646947*uk_82 + 269892*uk_83 + 47628*uk_84 + 773955*uk_85 + 861273*uk_86 + 269892*uk_87 + 1673847*uk_88 + 698292*uk_89 + 2242306609*uk_9 + 123228*uk_90 + 2002455*uk_91 + 2228373*uk_92 + 698292*uk_93 + 291312*uk_94 + 51408*uk_95 + 835380*uk_96 + 929628*uk_97 + 291312*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 99288*uk_100 + 109368*uk_101 + 82152*uk_102 + 2444967*uk_103 + 2693187*uk_104 + 2022993*uk_105 + 2966607*uk_106 + 2228373*uk_107 + 1673847*uk_108 + 389017*uk_109 + 3456769*uk_11 + 868627*uk_110 + 42632*uk_111 + 1049813*uk_112 + 1156393*uk_113 + 868627*uk_114 + 1939537*uk_115 + 95192*uk_116 + 2344103*uk_117 + 2582083*uk_118 + 1939537*uk_119 + 7718539*uk_12 + 4672*uk_120 + 115048*uk_121 + 126728*uk_122 + 95192*uk_123 + 2833057*uk_124 + 3120677*uk_125 + 2344103*uk_126 + 3437497*uk_127 + 2582083*uk_128 + 1939537*uk_129 + 378824*uk_13 + 4330747*uk_130 + 212552*uk_131 + 5234093*uk_132 + 5765473*uk_133 + 4330747*uk_134 + 10432*uk_135 + 256888*uk_136 + 282968*uk_137 + 212552*uk_138 + 6325867*uk_139 + 9328541*uk_14 + 6968087*uk_140 + 5234093*uk_141 + 7675507*uk_142 + 5765473*uk_143 + 4330747*uk_144 + 512*uk_145 + 12608*uk_146 + 13888*uk_147 + 10432*uk_148 + 310472*uk_149 + 10275601*uk_15 + 341992*uk_150 + 256888*uk_151 + 376712*uk_152 + 282968*uk_153 + 212552*uk_154 + 7645373*uk_155 + 8421553*uk_156 + 6325867*uk_157 + 9276533*uk_158 + 6968087*uk_159 + 7718539*uk_16 + 5234093*uk_160 + 10218313*uk_161 + 7675507*uk_162 + 5765473*uk_163 + 4330747*uk_164 + 3969*uk_17 + 4599*uk_18 + 10269*uk_19 + 63*uk_2 + 504*uk_20 + 12411*uk_21 + 13671*uk_22 + 10269*uk_23 + 5329*uk_24 + 11899*uk_25 + 584*uk_26 + 14381*uk_27 + 15841*uk_28 + 11899*uk_29 + 73*uk_3 + 26569*uk_30 + 1304*uk_31 + 32111*uk_32 + 35371*uk_33 + 26569*uk_34 + 64*uk_35 + 1576*uk_36 + 1736*uk_37 + 1304*uk_38 + 38809*uk_39 + 163*uk_4 + 42749*uk_40 + 32111*uk_41 + 47089*uk_42 + 35371*uk_43 + 26569*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 163688382457*uk_47 + 365495977267*uk_48 + 17938452872*uk_49 + 8*uk_5 + 441734401973*uk_50 + 486580534153*uk_51 + 365495977267*uk_52 + 187944057*uk_53 + 217776447*uk_54 + 486267957*uk_55 + 23865912*uk_56 + 587698083*uk_57 + 647362863*uk_58 + 486267957*uk_59 + 197*uk_6 + 252344137*uk_60 + 563453347*uk_61 + 27654152*uk_62 + 680983493*uk_63 + 750118873*uk_64 + 563453347*uk_65 + 1258121857*uk_66 + 61748312*uk_67 + 1520552183*uk_68 + 1674922963*uk_69 + 217*uk_7 + 1258121857*uk_70 + 3030592*uk_71 + 74628328*uk_72 + 82204808*uk_73 + 61748312*uk_74 + 1837722577*uk_75 + 2024293397*uk_76 + 1520552183*uk_77 + 2229805417*uk_78 + 1674922963*uk_79 + 163*uk_8 + 1258121857*uk_80 + 250047*uk_81 + 289737*uk_82 + 646947*uk_83 + 31752*uk_84 + 781893*uk_85 + 861273*uk_86 + 646947*uk_87 + 335727*uk_88 + 749637*uk_89 + 2242306609*uk_9 + 36792*uk_90 + 906003*uk_91 + 997983*uk_92 + 749637*uk_93 + 1673847*uk_94 + 82152*uk_95 + 2022993*uk_96 + 2228373*uk_97 + 1673847*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 150444*uk_100 + 164052*uk_101 + 55188*uk_102 + 2494863*uk_103 + 2720529*uk_104 + 915201*uk_105 + 2966607*uk_106 + 997983*uk_107 + 335727*uk_108 + 6859000*uk_109 + 8997070*uk_11 + 2635300*uk_110 + 433200*uk_111 + 7183900*uk_112 + 7833700*uk_113 + 2635300*uk_114 + 1012510*uk_115 + 166440*uk_116 + 2760130*uk_117 + 3009790*uk_118 + 1012510*uk_119 + 3456769*uk_12 + 27360*uk_120 + 453720*uk_121 + 494760*uk_122 + 166440*uk_123 + 7524190*uk_124 + 8204770*uk_125 + 2760130*uk_126 + 8946910*uk_127 + 3009790*uk_128 + 1012510*uk_129 + 568236*uk_13 + 389017*uk_130 + 63948*uk_131 + 1060471*uk_132 + 1156393*uk_133 + 389017*uk_134 + 10512*uk_135 + 174324*uk_136 + 190092*uk_137 + 63948*uk_138 + 2890873*uk_139 + 9423247*uk_14 + 3152359*uk_140 + 1060471*uk_141 + 3437497*uk_142 + 1156393*uk_143 + 389017*uk_144 + 1728*uk_145 + 28656*uk_146 + 31248*uk_147 + 10512*uk_148 + 475212*uk_149 + 10275601*uk_15 + 518196*uk_150 + 174324*uk_151 + 565068*uk_152 + 190092*uk_153 + 63948*uk_154 + 7880599*uk_155 + 8593417*uk_156 + 2890873*uk_157 + 9370711*uk_158 + 3152359*uk_159 + 3456769*uk_16 + 1060471*uk_160 + 10218313*uk_161 + 3437497*uk_162 + 1156393*uk_163 + 389017*uk_164 + 3969*uk_17 + 11970*uk_18 + 4599*uk_19 + 63*uk_2 + 756*uk_20 + 12537*uk_21 + 13671*uk_22 + 4599*uk_23 + 36100*uk_24 + 13870*uk_25 + 2280*uk_26 + 37810*uk_27 + 41230*uk_28 + 13870*uk_29 + 190*uk_3 + 5329*uk_30 + 876*uk_31 + 14527*uk_32 + 15841*uk_33 + 5329*uk_34 + 144*uk_35 + 2388*uk_36 + 2604*uk_37 + 876*uk_38 + 39601*uk_39 + 73*uk_4 + 43183*uk_40 + 14527*uk_41 + 47089*uk_42 + 15841*uk_43 + 5329*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 426038255710*uk_47 + 163688382457*uk_48 + 26907679308*uk_49 + 12*uk_5 + 446219015191*uk_50 + 486580534153*uk_51 + 163688382457*uk_52 + 187944057*uk_53 + 566815410*uk_54 + 217776447*uk_55 + 35798868*uk_56 + 593664561*uk_57 + 647362863*uk_58 + 217776447*uk_59 + 199*uk_6 + 1709443300*uk_60 + 656786110*uk_61 + 107964840*uk_62 + 1790416930*uk_63 + 1952364190*uk_64 + 656786110*uk_65 + 252344137*uk_66 + 41481228*uk_67 + 687897031*uk_68 + 750118873*uk_69 + 217*uk_7 + 252344137*uk_70 + 6818832*uk_71 + 113078964*uk_72 + 123307212*uk_73 + 41481228*uk_74 + 1875226153*uk_75 + 2044844599*uk_76 + 687897031*uk_77 + 2229805417*uk_78 + 750118873*uk_79 + 73*uk_8 + 252344137*uk_80 + 250047*uk_81 + 754110*uk_82 + 289737*uk_83 + 47628*uk_84 + 789831*uk_85 + 861273*uk_86 + 289737*uk_87 + 2274300*uk_88 + 873810*uk_89 + 2242306609*uk_9 + 143640*uk_90 + 2382030*uk_91 + 2597490*uk_92 + 873810*uk_93 + 335727*uk_94 + 55188*uk_95 + 915201*uk_96 + 997983*uk_97 + 335727*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 101304*uk_100 + 109368*uk_101 + 95760*uk_102 + 2545263*uk_103 + 2747871*uk_104 + 2405970*uk_105 + 2966607*uk_106 + 2597490*uk_107 + 2274300*uk_108 + 1643032*uk_109 + 5587654*uk_11 + 2645560*uk_110 + 111392*uk_111 + 2798724*uk_112 + 3021508*uk_113 + 2645560*uk_114 + 4259800*uk_115 + 179360*uk_116 + 4506420*uk_117 + 4865140*uk_118 + 4259800*uk_119 + 8997070*uk_12 + 7552*uk_120 + 189744*uk_121 + 204848*uk_122 + 179360*uk_123 + 4767318*uk_124 + 5146806*uk_125 + 4506420*uk_126 + 5556502*uk_127 + 4865140*uk_128 + 4259800*uk_129 + 378824*uk_13 + 6859000*uk_130 + 288800*uk_131 + 7256100*uk_132 + 7833700*uk_133 + 6859000*uk_134 + 12160*uk_135 + 305520*uk_136 + 329840*uk_137 + 288800*uk_138 + 7676190*uk_139 + 9517953*uk_14 + 8287230*uk_140 + 7256100*uk_141 + 8946910*uk_142 + 7833700*uk_143 + 6859000*uk_144 + 512*uk_145 + 12864*uk_146 + 13888*uk_147 + 12160*uk_148 + 323208*uk_149 + 10275601*uk_15 + 348936*uk_150 + 305520*uk_151 + 376712*uk_152 + 329840*uk_153 + 288800*uk_154 + 8120601*uk_155 + 8767017*uk_156 + 7676190*uk_157 + 9464889*uk_158 + 8287230*uk_159 + 8997070*uk_16 + 7256100*uk_160 + 10218313*uk_161 + 8946910*uk_162 + 7833700*uk_163 + 6859000*uk_164 + 3969*uk_17 + 7434*uk_18 + 11970*uk_19 + 63*uk_2 + 504*uk_20 + 12663*uk_21 + 13671*uk_22 + 11970*uk_23 + 13924*uk_24 + 22420*uk_25 + 944*uk_26 + 23718*uk_27 + 25606*uk_28 + 22420*uk_29 + 118*uk_3 + 36100*uk_30 + 1520*uk_31 + 38190*uk_32 + 41230*uk_33 + 36100*uk_34 + 64*uk_35 + 1608*uk_36 + 1736*uk_37 + 1520*uk_38 + 40401*uk_39 + 190*uk_4 + 43617*uk_40 + 38190*uk_41 + 47089*uk_42 + 41230*uk_43 + 36100*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 264592179862*uk_47 + 426038255710*uk_48 + 17938452872*uk_49 + 8*uk_5 + 450703628409*uk_50 + 486580534153*uk_51 + 426038255710*uk_52 + 187944057*uk_53 + 352022202*uk_54 + 566815410*uk_55 + 23865912*uk_56 + 599631039*uk_57 + 647362863*uk_58 + 566815410*uk_59 + 201*uk_6 + 659343172*uk_60 + 1061654260*uk_61 + 44701232*uk_62 + 1123118454*uk_63 + 1212520918*uk_64 + 1061654260*uk_65 + 1709443300*uk_66 + 71976560*uk_67 + 1808411070*uk_68 + 1952364190*uk_69 + 217*uk_7 + 1709443300*uk_70 + 3030592*uk_71 + 76143624*uk_72 + 82204808*uk_73 + 71976560*uk_74 + 1913108553*uk_75 + 2065395801*uk_76 + 1808411070*uk_77 + 2229805417*uk_78 + 1952364190*uk_79 + 190*uk_8 + 1709443300*uk_80 + 250047*uk_81 + 468342*uk_82 + 754110*uk_83 + 31752*uk_84 + 797769*uk_85 + 861273*uk_86 + 754110*uk_87 + 877212*uk_88 + 1412460*uk_89 + 2242306609*uk_9 + 59472*uk_90 + 1494234*uk_91 + 1613178*uk_92 + 1412460*uk_93 + 2274300*uk_94 + 95760*uk_95 + 2405970*uk_96 + 2597490*uk_97 + 2274300*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 102312*uk_100 + 109368*uk_101 + 59472*uk_102 + 2596167*uk_103 + 2775213*uk_104 + 1509102*uk_105 + 2966607*uk_106 + 1613178*uk_107 + 877212*uk_108 + 157464*uk_109 + 2557062*uk_11 + 344088*uk_110 + 23328*uk_111 + 591948*uk_112 + 632772*uk_113 + 344088*uk_114 + 751896*uk_115 + 50976*uk_116 + 1293516*uk_117 + 1382724*uk_118 + 751896*uk_119 + 5587654*uk_12 + 3456*uk_120 + 87696*uk_121 + 93744*uk_122 + 50976*uk_123 + 2225286*uk_124 + 2378754*uk_125 + 1293516*uk_126 + 2542806*uk_127 + 1382724*uk_128 + 751896*uk_129 + 378824*uk_13 + 1643032*uk_130 + 111392*uk_131 + 2826572*uk_132 + 3021508*uk_133 + 1643032*uk_134 + 7552*uk_135 + 191632*uk_136 + 204848*uk_137 + 111392*uk_138 + 4862662*uk_139 + 9612659*uk_14 + 5198018*uk_140 + 2826572*uk_141 + 5556502*uk_142 + 3021508*uk_143 + 1643032*uk_144 + 512*uk_145 + 12992*uk_146 + 13888*uk_147 + 7552*uk_148 + 329672*uk_149 + 10275601*uk_15 + 352408*uk_150 + 191632*uk_151 + 376712*uk_152 + 204848*uk_153 + 111392*uk_154 + 8365427*uk_155 + 8942353*uk_156 + 4862662*uk_157 + 9559067*uk_158 + 5198018*uk_159 + 5587654*uk_16 + 2826572*uk_160 + 10218313*uk_161 + 5556502*uk_162 + 3021508*uk_163 + 1643032*uk_164 + 3969*uk_17 + 3402*uk_18 + 7434*uk_19 + 63*uk_2 + 504*uk_20 + 12789*uk_21 + 13671*uk_22 + 7434*uk_23 + 2916*uk_24 + 6372*uk_25 + 432*uk_26 + 10962*uk_27 + 11718*uk_28 + 6372*uk_29 + 54*uk_3 + 13924*uk_30 + 944*uk_31 + 23954*uk_32 + 25606*uk_33 + 13924*uk_34 + 64*uk_35 + 1624*uk_36 + 1736*uk_37 + 944*uk_38 + 41209*uk_39 + 118*uk_4 + 44051*uk_40 + 23954*uk_41 + 47089*uk_42 + 25606*uk_43 + 13924*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 121084556886*uk_47 + 264592179862*uk_48 + 17938452872*uk_49 + 8*uk_5 + 455188241627*uk_50 + 486580534153*uk_51 + 264592179862*uk_52 + 187944057*uk_53 + 161094906*uk_54 + 352022202*uk_55 + 23865912*uk_56 + 605597517*uk_57 + 647362863*uk_58 + 352022202*uk_59 + 203*uk_6 + 138081348*uk_60 + 301733316*uk_61 + 20456496*uk_62 + 519083586*uk_63 + 554882454*uk_64 + 301733316*uk_65 + 659343172*uk_66 + 44701232*uk_67 + 1134293762*uk_68 + 1212520918*uk_69 + 217*uk_7 + 659343172*uk_70 + 3030592*uk_71 + 76901272*uk_72 + 82204808*uk_73 + 44701232*uk_74 + 1951369777*uk_75 + 2085947003*uk_76 + 1134293762*uk_77 + 2229805417*uk_78 + 1212520918*uk_79 + 118*uk_8 + 659343172*uk_80 + 250047*uk_81 + 214326*uk_82 + 468342*uk_83 + 31752*uk_84 + 805707*uk_85 + 861273*uk_86 + 468342*uk_87 + 183708*uk_88 + 401436*uk_89 + 2242306609*uk_9 + 27216*uk_90 + 690606*uk_91 + 738234*uk_92 + 401436*uk_93 + 877212*uk_94 + 59472*uk_95 + 1509102*uk_96 + 1613178*uk_97 + 877212*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 154980*uk_100 + 164052*uk_101 + 40824*uk_102 + 2647575*uk_103 + 2802555*uk_104 + 697410*uk_105 + 2966607*uk_106 + 738234*uk_107 + 183708*uk_108 + 8365427*uk_109 + 9612659*uk_11 + 2225286*uk_110 + 494508*uk_111 + 8447845*uk_112 + 8942353*uk_113 + 2225286*uk_114 + 591948*uk_115 + 131544*uk_116 + 2247210*uk_117 + 2378754*uk_118 + 591948*uk_119 + 2557062*uk_12 + 29232*uk_120 + 499380*uk_121 + 528612*uk_122 + 131544*uk_123 + 8531075*uk_124 + 9030455*uk_125 + 2247210*uk_126 + 9559067*uk_127 + 2378754*uk_128 + 591948*uk_129 + 568236*uk_13 + 157464*uk_130 + 34992*uk_131 + 597780*uk_132 + 632772*uk_133 + 157464*uk_134 + 7776*uk_135 + 132840*uk_136 + 140616*uk_137 + 34992*uk_138 + 2269350*uk_139 + 9707365*uk_14 + 2402190*uk_140 + 597780*uk_141 + 2542806*uk_142 + 632772*uk_143 + 157464*uk_144 + 1728*uk_145 + 29520*uk_146 + 31248*uk_147 + 7776*uk_148 + 504300*uk_149 + 10275601*uk_15 + 533820*uk_150 + 132840*uk_151 + 565068*uk_152 + 140616*uk_153 + 34992*uk_154 + 8615125*uk_155 + 9119425*uk_156 + 2269350*uk_157 + 9653245*uk_158 + 2402190*uk_159 + 2557062*uk_16 + 597780*uk_160 + 10218313*uk_161 + 2542806*uk_162 + 632772*uk_163 + 157464*uk_164 + 3969*uk_17 + 12789*uk_18 + 3402*uk_19 + 63*uk_2 + 756*uk_20 + 12915*uk_21 + 13671*uk_22 + 3402*uk_23 + 41209*uk_24 + 10962*uk_25 + 2436*uk_26 + 41615*uk_27 + 44051*uk_28 + 10962*uk_29 + 203*uk_3 + 2916*uk_30 + 648*uk_31 + 11070*uk_32 + 11718*uk_33 + 2916*uk_34 + 144*uk_35 + 2460*uk_36 + 2604*uk_37 + 648*uk_38 + 42025*uk_39 + 54*uk_4 + 44485*uk_40 + 11070*uk_41 + 47089*uk_42 + 11718*uk_43 + 2916*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 455188241627*uk_47 + 121084556886*uk_48 + 26907679308*uk_49 + 12*uk_5 + 459672854845*uk_50 + 486580534153*uk_51 + 121084556886*uk_52 + 187944057*uk_53 + 605597517*uk_54 + 161094906*uk_55 + 35798868*uk_56 + 611563995*uk_57 + 647362863*uk_58 + 161094906*uk_59 + 205*uk_6 + 1951369777*uk_60 + 519083586*uk_61 + 115351908*uk_62 + 1970595095*uk_63 + 2085947003*uk_64 + 519083586*uk_65 + 138081348*uk_66 + 30684744*uk_67 + 524197710*uk_68 + 554882454*uk_69 + 217*uk_7 + 138081348*uk_70 + 6818832*uk_71 + 116488380*uk_72 + 123307212*uk_73 + 30684744*uk_74 + 1990009825*uk_75 + 2106498205*uk_76 + 524197710*uk_77 + 2229805417*uk_78 + 554882454*uk_79 + 54*uk_8 + 138081348*uk_80 + 250047*uk_81 + 805707*uk_82 + 214326*uk_83 + 47628*uk_84 + 813645*uk_85 + 861273*uk_86 + 214326*uk_87 + 2596167*uk_88 + 690606*uk_89 + 2242306609*uk_9 + 153468*uk_90 + 2621745*uk_91 + 2775213*uk_92 + 690606*uk_93 + 183708*uk_94 + 40824*uk_95 + 697410*uk_96 + 738234*uk_97 + 183708*uk_98 + 9072*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 104328*uk_100 + 109368*uk_101 + 102312*uk_102 + 2699487*uk_103 + 2829897*uk_104 + 2647323*uk_105 + 2966607*uk_106 + 2775213*uk_107 + 2596167*uk_108 + 3869893*uk_109 + 7434421*uk_11 + 5003747*uk_110 + 197192*uk_111 + 5102343*uk_112 + 5348833*uk_113 + 5003747*uk_114 + 6469813*uk_115 + 254968*uk_116 + 6597297*uk_117 + 6916007*uk_118 + 6469813*uk_119 + 9612659*uk_12 + 10048*uk_120 + 259992*uk_121 + 272552*uk_122 + 254968*uk_123 + 6727293*uk_124 + 7052283*uk_125 + 6597297*uk_126 + 7392973*uk_127 + 6916007*uk_128 + 6469813*uk_129 + 378824*uk_13 + 8365427*uk_130 + 329672*uk_131 + 8530263*uk_132 + 8942353*uk_133 + 8365427*uk_134 + 12992*uk_135 + 336168*uk_136 + 352408*uk_137 + 329672*uk_138 + 8698347*uk_139 + 9802071*uk_14 + 9118557*uk_140 + 8530263*uk_141 + 9559067*uk_142 + 8942353*uk_143 + 8365427*uk_144 + 512*uk_145 + 13248*uk_146 + 13888*uk_147 + 12992*uk_148 + 342792*uk_149 + 10275601*uk_15 + 359352*uk_150 + 336168*uk_151 + 376712*uk_152 + 352408*uk_153 + 329672*uk_154 + 8869743*uk_155 + 9298233*uk_156 + 8698347*uk_157 + 9747423*uk_158 + 9118557*uk_159 + 9612659*uk_16 + 8530263*uk_160 + 10218313*uk_161 + 9559067*uk_162 + 8942353*uk_163 + 8365427*uk_164 + 3969*uk_17 + 9891*uk_18 + 12789*uk_19 + 63*uk_2 + 504*uk_20 + 13041*uk_21 + 13671*uk_22 + 12789*uk_23 + 24649*uk_24 + 31871*uk_25 + 1256*uk_26 + 32499*uk_27 + 34069*uk_28 + 31871*uk_29 + 157*uk_3 + 41209*uk_30 + 1624*uk_31 + 42021*uk_32 + 44051*uk_33 + 41209*uk_34 + 64*uk_35 + 1656*uk_36 + 1736*uk_37 + 1624*uk_38 + 42849*uk_39 + 203*uk_4 + 44919*uk_40 + 42021*uk_41 + 47089*uk_42 + 44051*uk_43 + 41209*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 352042137613*uk_47 + 455188241627*uk_48 + 17938452872*uk_49 + 8*uk_5 + 464157468063*uk_50 + 486580534153*uk_51 + 455188241627*uk_52 + 187944057*uk_53 + 468368523*uk_54 + 605597517*uk_55 + 23865912*uk_56 + 617530473*uk_57 + 647362863*uk_58 + 605597517*uk_59 + 207*uk_6 + 1167204097*uk_60 + 1509187463*uk_61 + 59475368*uk_62 + 1538925147*uk_63 + 1613269357*uk_64 + 1509187463*uk_65 + 1951369777*uk_66 + 76901272*uk_67 + 1989820413*uk_68 + 2085947003*uk_69 + 217*uk_7 + 1951369777*uk_70 + 3030592*uk_71 + 78416568*uk_72 + 82204808*uk_73 + 76901272*uk_74 + 2029028697*uk_75 + 2127049407*uk_76 + 1989820413*uk_77 + 2229805417*uk_78 + 2085947003*uk_79 + 203*uk_8 + 1951369777*uk_80 + 250047*uk_81 + 623133*uk_82 + 805707*uk_83 + 31752*uk_84 + 821583*uk_85 + 861273*uk_86 + 805707*uk_87 + 1552887*uk_88 + 2007873*uk_89 + 2242306609*uk_9 + 79128*uk_90 + 2047437*uk_91 + 2146347*uk_92 + 2007873*uk_93 + 2596167*uk_94 + 102312*uk_95 + 2647323*uk_96 + 2775213*uk_97 + 2596167*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 105336*uk_100 + 109368*uk_101 + 79128*uk_102 + 2751903*uk_103 + 2857239*uk_104 + 2067219*uk_105 + 2966607*uk_106 + 2146347*uk_107 + 1552887*uk_108 + 1685159*uk_109 + 5635007*uk_11 + 2223277*uk_110 + 113288*uk_111 + 2959649*uk_112 + 3072937*uk_113 + 2223277*uk_114 + 2933231*uk_115 + 149464*uk_116 + 3904747*uk_117 + 4054211*uk_118 + 2933231*uk_119 + 7434421*uk_12 + 7616*uk_120 + 198968*uk_121 + 206584*uk_122 + 149464*uk_123 + 5198039*uk_124 + 5397007*uk_125 + 3904747*uk_126 + 5603591*uk_127 + 4054211*uk_128 + 2933231*uk_129 + 378824*uk_13 + 3869893*uk_130 + 197192*uk_131 + 5151641*uk_132 + 5348833*uk_133 + 3869893*uk_134 + 10048*uk_135 + 262504*uk_136 + 272552*uk_137 + 197192*uk_138 + 6857917*uk_139 + 9896777*uk_14 + 7120421*uk_140 + 5151641*uk_141 + 7392973*uk_142 + 5348833*uk_143 + 3869893*uk_144 + 512*uk_145 + 13376*uk_146 + 13888*uk_147 + 10048*uk_148 + 349448*uk_149 + 10275601*uk_15 + 362824*uk_150 + 262504*uk_151 + 376712*uk_152 + 272552*uk_153 + 197192*uk_154 + 9129329*uk_155 + 9478777*uk_156 + 6857917*uk_157 + 9841601*uk_158 + 7120421*uk_159 + 7434421*uk_16 + 5151641*uk_160 + 10218313*uk_161 + 7392973*uk_162 + 5348833*uk_163 + 3869893*uk_164 + 3969*uk_17 + 7497*uk_18 + 9891*uk_19 + 63*uk_2 + 504*uk_20 + 13167*uk_21 + 13671*uk_22 + 9891*uk_23 + 14161*uk_24 + 18683*uk_25 + 952*uk_26 + 24871*uk_27 + 25823*uk_28 + 18683*uk_29 + 119*uk_3 + 24649*uk_30 + 1256*uk_31 + 32813*uk_32 + 34069*uk_33 + 24649*uk_34 + 64*uk_35 + 1672*uk_36 + 1736*uk_37 + 1256*uk_38 + 43681*uk_39 + 157*uk_4 + 45353*uk_40 + 32813*uk_41 + 47089*uk_42 + 34069*uk_43 + 24649*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 266834486471*uk_47 + 352042137613*uk_48 + 17938452872*uk_49 + 8*uk_5 + 468642081281*uk_50 + 486580534153*uk_51 + 352042137613*uk_52 + 187944057*uk_53 + 355005441*uk_54 + 468368523*uk_55 + 23865912*uk_56 + 623496951*uk_57 + 647362863*uk_58 + 468368523*uk_59 + 209*uk_6 + 670565833*uk_60 + 884696099*uk_61 + 45080056*uk_62 + 1177716463*uk_63 + 1222796519*uk_64 + 884696099*uk_65 + 1167204097*uk_66 + 59475368*uk_67 + 1553793989*uk_68 + 1613269357*uk_69 + 217*uk_7 + 1167204097*uk_70 + 3030592*uk_71 + 79174216*uk_72 + 82204808*uk_73 + 59475368*uk_74 + 2068426393*uk_75 + 2147600609*uk_76 + 1553793989*uk_77 + 2229805417*uk_78 + 1613269357*uk_79 + 157*uk_8 + 1167204097*uk_80 + 250047*uk_81 + 472311*uk_82 + 623133*uk_83 + 31752*uk_84 + 829521*uk_85 + 861273*uk_86 + 623133*uk_87 + 892143*uk_88 + 1177029*uk_89 + 2242306609*uk_9 + 59976*uk_90 + 1566873*uk_91 + 1626849*uk_92 + 1177029*uk_93 + 1552887*uk_94 + 79128*uk_95 + 2067219*uk_96 + 2146347*uk_97 + 1552887*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 106344*uk_100 + 109368*uk_101 + 59976*uk_102 + 2804823*uk_103 + 2884581*uk_104 + 1581867*uk_105 + 2966607*uk_106 + 1626849*uk_107 + 892143*uk_108 + 704969*uk_109 + 4214417*uk_11 + 942599*uk_110 + 63368*uk_111 + 1671331*uk_112 + 1718857*uk_113 + 942599*uk_114 + 1260329*uk_115 + 84728*uk_116 + 2234701*uk_117 + 2298247*uk_118 + 1260329*uk_119 + 5635007*uk_12 + 5696*uk_120 + 150232*uk_121 + 154504*uk_122 + 84728*uk_123 + 3962369*uk_124 + 4075043*uk_125 + 2234701*uk_126 + 4190921*uk_127 + 2298247*uk_128 + 1260329*uk_129 + 378824*uk_13 + 1685159*uk_130 + 113288*uk_131 + 2987971*uk_132 + 3072937*uk_133 + 1685159*uk_134 + 7616*uk_135 + 200872*uk_136 + 206584*uk_137 + 113288*uk_138 + 5297999*uk_139 + 9991483*uk_14 + 5448653*uk_140 + 2987971*uk_141 + 5603591*uk_142 + 3072937*uk_143 + 1685159*uk_144 + 512*uk_145 + 13504*uk_146 + 13888*uk_147 + 7616*uk_148 + 356168*uk_149 + 10275601*uk_15 + 366296*uk_150 + 200872*uk_151 + 376712*uk_152 + 206584*uk_153 + 113288*uk_154 + 9393931*uk_155 + 9661057*uk_156 + 5297999*uk_157 + 9935779*uk_158 + 5448653*uk_159 + 5635007*uk_16 + 2987971*uk_160 + 10218313*uk_161 + 5603591*uk_162 + 3072937*uk_163 + 1685159*uk_164 + 3969*uk_17 + 5607*uk_18 + 7497*uk_19 + 63*uk_2 + 504*uk_20 + 13293*uk_21 + 13671*uk_22 + 7497*uk_23 + 7921*uk_24 + 10591*uk_25 + 712*uk_26 + 18779*uk_27 + 19313*uk_28 + 10591*uk_29 + 89*uk_3 + 14161*uk_30 + 952*uk_31 + 25109*uk_32 + 25823*uk_33 + 14161*uk_34 + 64*uk_35 + 1688*uk_36 + 1736*uk_37 + 952*uk_38 + 44521*uk_39 + 119*uk_4 + 45787*uk_40 + 25109*uk_41 + 47089*uk_42 + 25823*uk_43 + 14161*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 199565288201*uk_47 + 266834486471*uk_48 + 17938452872*uk_49 + 8*uk_5 + 473126694499*uk_50 + 486580534153*uk_51 + 266834486471*uk_52 + 187944057*uk_53 + 265508271*uk_54 + 355005441*uk_55 + 23865912*uk_56 + 629463429*uk_57 + 647362863*uk_58 + 355005441*uk_59 + 211*uk_6 + 375083113*uk_60 + 501515623*uk_61 + 33715336*uk_62 + 889241987*uk_63 + 914528489*uk_64 + 501515623*uk_65 + 670565833*uk_66 + 45080056*uk_67 + 1188986477*uk_68 + 1222796519*uk_69 + 217*uk_7 + 670565833*uk_70 + 3030592*uk_71 + 79931864*uk_72 + 82204808*uk_73 + 45080056*uk_74 + 2108202913*uk_75 + 2168151811*uk_76 + 1188986477*uk_77 + 2229805417*uk_78 + 1222796519*uk_79 + 119*uk_8 + 670565833*uk_80 + 250047*uk_81 + 353241*uk_82 + 472311*uk_83 + 31752*uk_84 + 837459*uk_85 + 861273*uk_86 + 472311*uk_87 + 499023*uk_88 + 667233*uk_89 + 2242306609*uk_9 + 44856*uk_90 + 1183077*uk_91 + 1216719*uk_92 + 667233*uk_93 + 892143*uk_94 + 59976*uk_95 + 1581867*uk_96 + 1626849*uk_97 + 892143*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 107352*uk_100 + 109368*uk_101 + 44856*uk_102 + 2858247*uk_103 + 2911923*uk_104 + 1194291*uk_105 + 2966607*uk_106 + 1216719*uk_107 + 499023*uk_108 + 300763*uk_109 + 3172651*uk_11 + 399521*uk_110 + 35912*uk_111 + 956157*uk_112 + 974113*uk_113 + 399521*uk_114 + 530707*uk_115 + 47704*uk_116 + 1270119*uk_117 + 1293971*uk_118 + 530707*uk_119 + 4214417*uk_12 + 4288*uk_120 + 114168*uk_121 + 116312*uk_122 + 47704*uk_123 + 3039723*uk_124 + 3096807*uk_125 + 1270119*uk_126 + 3154963*uk_127 + 1293971*uk_128 + 530707*uk_129 + 378824*uk_13 + 704969*uk_130 + 63368*uk_131 + 1687173*uk_132 + 1718857*uk_133 + 704969*uk_134 + 5696*uk_135 + 151656*uk_136 + 154504*uk_137 + 63368*uk_138 + 4037841*uk_139 + 10086189*uk_14 + 4113669*uk_140 + 1687173*uk_141 + 4190921*uk_142 + 1718857*uk_143 + 704969*uk_144 + 512*uk_145 + 13632*uk_146 + 13888*uk_147 + 5696*uk_148 + 362952*uk_149 + 10275601*uk_15 + 369768*uk_150 + 151656*uk_151 + 376712*uk_152 + 154504*uk_153 + 63368*uk_154 + 9663597*uk_155 + 9845073*uk_156 + 4037841*uk_157 + 10029957*uk_158 + 4113669*uk_159 + 4214417*uk_16 + 1687173*uk_160 + 10218313*uk_161 + 4190921*uk_162 + 1718857*uk_163 + 704969*uk_164 + 3969*uk_17 + 4221*uk_18 + 5607*uk_19 + 63*uk_2 + 504*uk_20 + 13419*uk_21 + 13671*uk_22 + 5607*uk_23 + 4489*uk_24 + 5963*uk_25 + 536*uk_26 + 14271*uk_27 + 14539*uk_28 + 5963*uk_29 + 67*uk_3 + 7921*uk_30 + 712*uk_31 + 18957*uk_32 + 19313*uk_33 + 7921*uk_34 + 64*uk_35 + 1704*uk_36 + 1736*uk_37 + 712*uk_38 + 45369*uk_39 + 89*uk_4 + 46221*uk_40 + 18957*uk_41 + 47089*uk_42 + 19313*uk_43 + 7921*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 150234542803*uk_47 + 199565288201*uk_48 + 17938452872*uk_49 + 8*uk_5 + 477611307717*uk_50 + 486580534153*uk_51 + 199565288201*uk_52 + 187944057*uk_53 + 199877013*uk_54 + 265508271*uk_55 + 23865912*uk_56 + 635429907*uk_57 + 647362863*uk_58 + 265508271*uk_59 + 213*uk_6 + 212567617*uk_60 + 282365939*uk_61 + 25381208*uk_62 + 675774663*uk_63 + 688465267*uk_64 + 282365939*uk_65 + 375083113*uk_66 + 33715336*uk_67 + 897670821*uk_68 + 914528489*uk_69 + 217*uk_7 + 375083113*uk_70 + 3030592*uk_71 + 80689512*uk_72 + 82204808*uk_73 + 33715336*uk_74 + 2148358257*uk_75 + 2188703013*uk_76 + 897670821*uk_77 + 2229805417*uk_78 + 914528489*uk_79 + 89*uk_8 + 375083113*uk_80 + 250047*uk_81 + 265923*uk_82 + 353241*uk_83 + 31752*uk_84 + 845397*uk_85 + 861273*uk_86 + 353241*uk_87 + 282807*uk_88 + 375669*uk_89 + 2242306609*uk_9 + 33768*uk_90 + 899073*uk_91 + 915957*uk_92 + 375669*uk_93 + 499023*uk_94 + 44856*uk_95 + 1194291*uk_96 + 1216719*uk_97 + 499023*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 108360*uk_100 + 109368*uk_101 + 33768*uk_102 + 2912175*uk_103 + 2939265*uk_104 + 907515*uk_105 + 2966607*uk_106 + 915957*uk_107 + 282807*uk_108 + 148877*uk_109 + 2509709*uk_11 + 188203*uk_110 + 22472*uk_111 + 603935*uk_112 + 609553*uk_113 + 188203*uk_114 + 237917*uk_115 + 28408*uk_116 + 763465*uk_117 + 770567*uk_118 + 237917*uk_119 + 3172651*uk_12 + 3392*uk_120 + 91160*uk_121 + 92008*uk_122 + 28408*uk_123 + 2449925*uk_124 + 2472715*uk_125 + 763465*uk_126 + 2495717*uk_127 + 770567*uk_128 + 237917*uk_129 + 378824*uk_13 + 300763*uk_130 + 35912*uk_131 + 965135*uk_132 + 974113*uk_133 + 300763*uk_134 + 4288*uk_135 + 115240*uk_136 + 116312*uk_137 + 35912*uk_138 + 3097075*uk_139 + 10180895*uk_14 + 3125885*uk_140 + 965135*uk_141 + 3154963*uk_142 + 974113*uk_143 + 300763*uk_144 + 512*uk_145 + 13760*uk_146 + 13888*uk_147 + 4288*uk_148 + 369800*uk_149 + 10275601*uk_15 + 373240*uk_150 + 115240*uk_151 + 376712*uk_152 + 116312*uk_153 + 35912*uk_154 + 9938375*uk_155 + 10030825*uk_156 + 3097075*uk_157 + 10124135*uk_158 + 3125885*uk_159 + 3172651*uk_16 + 965135*uk_160 + 10218313*uk_161 + 3154963*uk_162 + 974113*uk_163 + 300763*uk_164 + 3969*uk_17 + 3339*uk_18 + 4221*uk_19 + 63*uk_2 + 504*uk_20 + 13545*uk_21 + 13671*uk_22 + 4221*uk_23 + 2809*uk_24 + 3551*uk_25 + 424*uk_26 + 11395*uk_27 + 11501*uk_28 + 3551*uk_29 + 53*uk_3 + 4489*uk_30 + 536*uk_31 + 14405*uk_32 + 14539*uk_33 + 4489*uk_34 + 64*uk_35 + 1720*uk_36 + 1736*uk_37 + 536*uk_38 + 46225*uk_39 + 67*uk_4 + 46655*uk_40 + 14405*uk_41 + 47089*uk_42 + 14539*uk_43 + 4489*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 118842250277*uk_47 + 150234542803*uk_48 + 17938452872*uk_49 + 8*uk_5 + 482095920935*uk_50 + 486580534153*uk_51 + 150234542803*uk_52 + 187944057*uk_53 + 158111667*uk_54 + 199877013*uk_55 + 23865912*uk_56 + 641396385*uk_57 + 647362863*uk_58 + 199877013*uk_59 + 215*uk_6 + 133014577*uk_60 + 168150503*uk_61 + 20077672*uk_62 + 539587435*uk_63 + 544606853*uk_64 + 168150503*uk_65 + 212567617*uk_66 + 25381208*uk_67 + 682119965*uk_68 + 688465267*uk_69 + 217*uk_7 + 212567617*uk_70 + 3030592*uk_71 + 81447160*uk_72 + 82204808*uk_73 + 25381208*uk_74 + 2188892425*uk_75 + 2209254215*uk_76 + 682119965*uk_77 + 2229805417*uk_78 + 688465267*uk_79 + 67*uk_8 + 212567617*uk_80 + 250047*uk_81 + 210357*uk_82 + 265923*uk_83 + 31752*uk_84 + 853335*uk_85 + 861273*uk_86 + 265923*uk_87 + 176967*uk_88 + 223713*uk_89 + 2242306609*uk_9 + 26712*uk_90 + 717885*uk_91 + 724563*uk_92 + 223713*uk_93 + 282807*uk_94 + 33768*uk_95 + 907515*uk_96 + 915957*uk_97 + 282807*uk_98 + 4032*uk_99, + uk_0 + 47353*uk_1 + 2983239*uk_10 + 109368*uk_100 + 109368*uk_101 + 26712*uk_102 + 2966607*uk_103 + 2966607*uk_104 + 724563*uk_105 + 2966607*uk_106 + 724563*uk_107 + 176967*uk_108 + 103823*uk_109 + 2225591*uk_11 + 117077*uk_110 + 17672*uk_111 + 479353*uk_112 + 479353*uk_113 + 117077*uk_114 + 132023*uk_115 + 19928*uk_116 + 540547*uk_117 + 540547*uk_118 + 132023*uk_119 + 2509709*uk_12 + 3008*uk_120 + 81592*uk_121 + 81592*uk_122 + 19928*uk_123 + 2213183*uk_124 + 2213183*uk_125 + 540547*uk_126 + 2213183*uk_127 + 540547*uk_128 + 132023*uk_129 + 378824*uk_13 + 148877*uk_130 + 22472*uk_131 + 609553*uk_132 + 609553*uk_133 + 148877*uk_134 + 3392*uk_135 + 92008*uk_136 + 92008*uk_137 + 22472*uk_138 + 2495717*uk_139 + 10275601*uk_14 + 2495717*uk_140 + 609553*uk_141 + 2495717*uk_142 + 609553*uk_143 + 148877*uk_144 + 512*uk_145 + 13888*uk_146 + 13888*uk_147 + 3392*uk_148 + 376712*uk_149 + 10275601*uk_15 + 376712*uk_150 + 92008*uk_151 + 376712*uk_152 + 92008*uk_153 + 22472*uk_154 + 10218313*uk_155 + 10218313*uk_156 + 2495717*uk_157 + 10218313*uk_158 + 2495717*uk_159 + 2509709*uk_16 + 609553*uk_160 + 10218313*uk_161 + 2495717*uk_162 + 609553*uk_163 + 148877*uk_164 + 3969*uk_17 + 2961*uk_18 + 3339*uk_19 + 63*uk_2 + 504*uk_20 + 13671*uk_21 + 13671*uk_22 + 3339*uk_23 + 2209*uk_24 + 2491*uk_25 + 376*uk_26 + 10199*uk_27 + 10199*uk_28 + 2491*uk_29 + 47*uk_3 + 2809*uk_30 + 424*uk_31 + 11501*uk_32 + 11501*uk_33 + 2809*uk_34 + 64*uk_35 + 1736*uk_36 + 1736*uk_37 + 424*uk_38 + 47089*uk_39 + 53*uk_4 + 47089*uk_40 + 11501*uk_41 + 47089*uk_42 + 11501*uk_43 + 2809*uk_44 + 106179944855977*uk_45 + 141265316367*uk_46 + 105388410623*uk_47 + 118842250277*uk_48 + 17938452872*uk_49 + 8*uk_5 + 486580534153*uk_50 + 486580534153*uk_51 + 118842250277*uk_52 + 187944057*uk_53 + 140212233*uk_54 + 158111667*uk_55 + 23865912*uk_56 + 647362863*uk_57 + 647362863*uk_58 + 158111667*uk_59 + 217*uk_6 + 104602777*uk_60 + 117956323*uk_61 + 17804728*uk_62 + 482953247*uk_63 + 482953247*uk_64 + 117956323*uk_65 + 133014577*uk_66 + 20077672*uk_67 + 544606853*uk_68 + 544606853*uk_69 + 217*uk_7 + 133014577*uk_70 + 3030592*uk_71 + 82204808*uk_72 + 82204808*uk_73 + 20077672*uk_74 + 2229805417*uk_75 + 2229805417*uk_76 + 544606853*uk_77 + 2229805417*uk_78 + 544606853*uk_79 + 53*uk_8 + 133014577*uk_80 + 250047*uk_81 + 186543*uk_82 + 210357*uk_83 + 31752*uk_84 + 861273*uk_85 + 861273*uk_86 + 210357*uk_87 + 139167*uk_88 + 156933*uk_89 + 2242306609*uk_9 + 23688*uk_90 + 642537*uk_91 + 642537*uk_92 + 156933*uk_93 + 176967*uk_94 + 26712*uk_95 + 724563*uk_96 + 724563*uk_97 + 176967*uk_98 + 4032*uk_99, + ] + +def sol_165x165(): + return { + uk_0: -QQ(295441,1683)*uk_2 - QQ(175799,1683)*uk_7 + QQ(2401696807,1)*uk_9 - QQ(9606787228,1683)*uk_10 + QQ(9606787228,1683)*uk_15 - QQ(29030443,1683)*uk_17 - QQ(5965893,187)*uk_22 + QQ(262901,99)*uk_42 + QQ(235539209256104,1)*uk_45 - QQ(232597130667529,1683)*uk_46 + QQ(1364372733998209,1683)*uk_51 - QQ(1133600892904,1683)*uk_53 - QQ(172922170104,187)*uk_58 + QQ(249776467928,99)*uk_78 - QQ(2401889209,1683)*uk_81 - QQ(636292759,187)*uk_86 - QQ(1034157281,187)*uk_106 + QQ(10558824289,1683)*uk_161, + uk_1: QQ(4,1683)*uk_2 - QQ(4,1683)*uk_7 - QQ(98072,1)*uk_9 + QQ(96847,1683)*uk_10 - QQ(568087,1683)*uk_15 + QQ(472,1683)*uk_17 + QQ(72,187)*uk_22 - QQ(104,99)*uk_42 - QQ(7216420377,1)*uk_45 - QQ(108808244,1683)*uk_46 - QQ(46106641036,1683)*uk_51 + QQ(17259541,1683)*uk_53 + QQ(1095291,187)*uk_58 - QQ(9936587,99)*uk_78 + QQ(41836,1683)*uk_81 + QQ(10036,187)*uk_86 + QQ(10124,187)*uk_106 - QQ(8,1)*uk_149 - QQ(586156,1683)*uk_161, + uk_3: -QQ(295441,1683)*uk_18 - QQ(175799,1683)*uk_28 + QQ(2401696807,1)*uk_47 - QQ(9606787228,1683)*uk_54 + QQ(9606787228,1683)*uk_64 - QQ(29030443,1683)*uk_82 - QQ(5965893,187)*uk_92 + QQ(262901,99)*uk_127 + QQ(8,1)*uk_149, + uk_4: -QQ(295441,1683)*uk_19 + QQ(1602583,3366)*uk_29 - QQ(175799,1683)*uk_33 - QQ(45670,99)*uk_34 - QQ(76006,187)*uk_38 + QQ(295441,1683)*uk_41 - QQ(45670,99)*uk_44 + QQ(2401696807,1)*uk_48 - QQ(9606787228,1683)*uk_55 + QQ(74452601017,3366)*uk_65 + QQ(9606787228,1683)*uk_69 - QQ(2401696807,99)*uk_70 - QQ(4803393614,187)*uk_74 + QQ(9606787228,1683)*uk_77 - QQ(2401696807,99)*uk_80 - QQ(29030443,1683)*uk_83 + QQ(11596905,374)*uk_93 - QQ(5965893,187)*uk_97 - QQ(769658,33)*uk_98 - QQ(17335370,1683)*uk_102 + QQ(29030443,1683)*uk_105 - QQ(769658,33)*uk_108 + QQ(77314807,3366)*uk_114 + QQ(750229,198)*uk_119 + QQ(72457964,1683)*uk_123 + QQ(11596905,374)*uk_126 + QQ(31304645,306)*uk_128 + QQ(750229,198)*uk_129 - QQ(3191393,99)*uk_134 - QQ(647642,9)*uk_138 - QQ(769658,33)*uk_141 + QQ(262901,99)*uk_142 - QQ(10478626,99)*uk_143 - QQ(3191393,99)*uk_144 - QQ(20480616,187)*uk_148 - QQ(17335370,1683)*uk_151 - QQ(174199750,1683)*uk_153 - QQ(647642,9)*uk_154 + QQ(29030443,1683)*uk_157 + QQ(5965893,187)*uk_159 - QQ(769658,33)*uk_160 - QQ(10478626,99)*uk_163 - QQ(3191393,99)*uk_164, + uk_5: -QQ(295441,1683)*uk_20 - QQ(175799,1683)*uk_37 + QQ(2401696807,1)*uk_49 - QQ(9606787228,1683)*uk_56 + QQ(9606787228,1683)*uk_73 - QQ(29030443,1683)*uk_84 - QQ(5965893,187)*uk_101 + QQ(262901,99)*uk_152, + uk_6: -QQ(295441,1683)*uk_21 - QQ(175799,1683)*uk_40 + QQ(2401696807,1)*uk_50 - QQ(9606787228,1683)*uk_57 + QQ(9606787228,1683)*uk_76 - QQ(29030443,1683)*uk_85 - QQ(5965893,187)*uk_104 + QQ(262901,99)*uk_158, + uk_8: -QQ(295441,1683)*uk_23 - QQ(1602583,3366)*uk_29 + QQ(45670,99)*uk_34 + QQ(76006,187)*uk_38 - QQ(295441,1683)*uk_41 - QQ(175799,1683)*uk_43 + QQ(45670,99)*uk_44 + QQ(2401696807,1)*uk_52 - QQ(9606787228,1683)*uk_59 - QQ(74452601017,3366)*uk_65 + QQ(2401696807,99)*uk_70 + QQ(4803393614,187)*uk_74 - QQ(9606787228,1683)*uk_77 + QQ(9606787228,1683)*uk_79 + QQ(2401696807,99)*uk_80 - QQ(29030443,1683)*uk_87 - QQ(11596905,374)*uk_93 + QQ(769658,33)*uk_98 + QQ(17335370,1683)*uk_102 - QQ(29030443,1683)*uk_105 - QQ(5965893,187)*uk_107 + QQ(769658,33)*uk_108 - QQ(77314807,3366)*uk_114 - QQ(750229,198)*uk_119 - QQ(72457964,1683)*uk_123 - QQ(11596905,374)*uk_126 - QQ(31304645,306)*uk_128 - QQ(750229,198)*uk_129 + QQ(3191393,99)*uk_134 + QQ(647642,9)*uk_138 + QQ(769658,33)*uk_141 + QQ(10478626,99)*uk_143 + QQ(3191393,99)*uk_144 + QQ(20480616,187)*uk_148 + QQ(17335370,1683)*uk_151 + QQ(174199750,1683)*uk_153 + QQ(647642,9)*uk_154 - QQ(29030443,1683)*uk_157 - QQ(5965893,187)*uk_159 + QQ(769658,33)*uk_160 + QQ(262901,99)*uk_162 + QQ(10478626,99)*uk_163 + QQ(3191393,99)*uk_164, + uk_11: QQ(4,1683)*uk_18 - QQ(4,1683)*uk_28 - QQ(98072,1)*uk_47 + QQ(96847,1683)*uk_54 - QQ(568087,1683)*uk_64 + QQ(472,1683)*uk_82 + QQ(72,187)*uk_92 - QQ(104,99)*uk_127, + uk_12: QQ(4,1683)*uk_19 - QQ(31,3366)*uk_29 - QQ(4,1683)*uk_33 + QQ(1,99)*uk_34 + QQ(2,187)*uk_38 - QQ(4,1683)*uk_41 + QQ(1,99)*uk_44 - QQ(98072,1)*uk_48 + QQ(96847,1683)*uk_55 - QQ(1437649,3366)*uk_65 - QQ(568087,1683)*uk_69 + QQ(52402,99)*uk_70 + QQ(120138,187)*uk_74 - QQ(96847,1683)*uk_77 + QQ(52402,99)*uk_80 + QQ(472,1683)*uk_83 - QQ(225,374)*uk_93 + QQ(72,187)*uk_97 + QQ(17,33)*uk_98 + QQ(590,1683)*uk_102 - QQ(472,1683)*uk_105 + QQ(17,33)*uk_108 - QQ(1519,3366)*uk_114 - QQ(13,198)*uk_119 - QQ(1388,1683)*uk_123 - QQ(225,374)*uk_126 - QQ(605,306)*uk_128 - QQ(13,198)*uk_129 + QQ(68,99)*uk_134 + QQ(14,9)*uk_138 + QQ(17,33)*uk_141 - QQ(104,99)*uk_142 + QQ(229,99)*uk_143 + QQ(68,99)*uk_144 + QQ(472,187)*uk_148 + QQ(590,1683)*uk_151 + QQ(4450,1683)*uk_153 + QQ(14,9)*uk_154 - QQ(472,1683)*uk_157 - QQ(72,187)*uk_159 + QQ(17,33)*uk_160 + QQ(229,99)*uk_163 + QQ(68,99)*uk_164, + uk_13: QQ(4,1683)*uk_20 - QQ(4,1683)*uk_37 - QQ(98072,1)*uk_49 + QQ(96847,1683)*uk_56 - QQ(568087,1683)*uk_73 + QQ(472,1683)*uk_84 + QQ(72,187)*uk_101 - QQ(104,99)*uk_152, + uk_14: QQ(4,1683)*uk_21 - QQ(4,1683)*uk_40 - QQ(98072,1)*uk_50 + QQ(96847,1683)*uk_57 - QQ(568087,1683)*uk_76 + QQ(472,1683)*uk_85 + QQ(72,187)*uk_104 - QQ(104,99)*uk_158, + uk_16: QQ(4,1683)*uk_23 + QQ(31,3366)*uk_29 - QQ(1,99)*uk_34 - QQ(2,187)*uk_38 + QQ(4,1683)*uk_41 - QQ(4,1683)*uk_43 - QQ(1,99)*uk_44 - QQ(98072,1)*uk_52 + QQ(96847,1683)*uk_59 + QQ(1437649,3366)*uk_65 - QQ(52402,99)*uk_70 - QQ(120138,187)*uk_74 + QQ(96847,1683)*uk_77 - QQ(568087,1683)*uk_79 - QQ(52402,99)*uk_80 + QQ(472,1683)*uk_87 + QQ(225,374)*uk_93 - QQ(17,33)*uk_98 - QQ(590,1683)*uk_102 + QQ(472,1683)*uk_105 + QQ(72,187)*uk_107 - QQ(17,33)*uk_108 + QQ(1519,3366)*uk_114 + QQ(13,198)*uk_119 + QQ(1388,1683)*uk_123 + QQ(225,374)*uk_126 + QQ(605,306)*uk_128 + QQ(13,198)*uk_129 - QQ(68,99)*uk_134 - QQ(14,9)*uk_138 - QQ(17,33)*uk_141 - QQ(229,99)*uk_143 - QQ(68,99)*uk_144 - QQ(472,187)*uk_148 - QQ(590,1683)*uk_151 - QQ(4450,1683)*uk_153 - QQ(14,9)*uk_154 + QQ(472,1683)*uk_157 + QQ(72,187)*uk_159 - QQ(17,33)*uk_160 - QQ(104,99)*uk_162 - QQ(229,99)*uk_163 - QQ(68,99)*uk_164, + uk_24: -QQ(295441,1683)*uk_88 - QQ(175799,1683)*uk_113, + uk_26: -QQ(295441,1683)*uk_90 - QQ(175799,1683)*uk_122, uk_25: -uk_29 - QQ(295441,1683)*uk_89 - QQ(295441,1683)*uk_93 - QQ(175799,1683)*uk_118 - QQ(175799,1683)*uk_128, + uk_27: -QQ(295441,1683)*uk_91 - QQ(175799,1683)*uk_125 - QQ(4,1)*uk_149, + uk_30: -uk_34 - uk_44 - QQ(295441,1683)*uk_94 - QQ(295441,1683)*uk_98 - QQ(295441,1683)*uk_108 - QQ(175799,1683)*uk_133 - QQ(175799,1683)*uk_143 - QQ(175799,1683)*uk_163, + uk_31: -uk_38 - QQ(295441,1683)*uk_95 - QQ(295441,1683)*uk_102 - QQ(175799,1683)*uk_137 - QQ(175799,1683)*uk_153, + uk_32: -uk_41 - QQ(295441,1683)*uk_96 - QQ(295441,1683)*uk_105 - QQ(175799,1683)*uk_140 + QQ(4,1)*uk_149 - QQ(175799,1683)*uk_159, + uk_35: -QQ(295441,1683)*uk_99 - QQ(175799,1683)*uk_147, + uk_36: -QQ(295441,1683)*uk_100 - QQ(2,1)*uk_149 - QQ(175799,1683)*uk_150, + uk_39: -QQ(295441,1683)*uk_103 - QQ(175799,1683)*uk_156, + uk_60: QQ(4,1683)*uk_88 - QQ(4,1683)*uk_113, + uk_61: -uk_65 + QQ(4,1683)*uk_89 + QQ(4,1683)*uk_93 - QQ(4,1683)*uk_118 - QQ(4,1683)*uk_128, + uk_62: QQ(4,1683)*uk_90 - QQ(4,1683)*uk_122, + uk_63: QQ(4,1683)*uk_91 - QQ(4,1683)*uk_125, + uk_66: -uk_70 - uk_80 + QQ(4,1683)*uk_94 + QQ(4,1683)*uk_98 + QQ(4,1683)*uk_108 - QQ(4,1683)*uk_133 - QQ(4,1683)*uk_143 - QQ(4,1683)*uk_163, + uk_67: -uk_74 + QQ(4,1683)*uk_95 + QQ(4,1683)*uk_102 - QQ(4,1683)*uk_137 - QQ(4,1683)*uk_153, + uk_68: -uk_77 + QQ(4,1683)*uk_96 + QQ(4,1683)*uk_105 - QQ(4,1683)*uk_140 - QQ(4,1683)*uk_159, + uk_71: QQ(4,1683)*uk_99 - QQ(4,1683)*uk_147, + uk_72: QQ(4,1683)*uk_100 - QQ(4,1683)*uk_150, + uk_75: QQ(4,1683)*uk_103 - QQ(4,1683)*uk_156, + uk_109: 0, + uk_110: -uk_114, + uk_111: 0, + uk_112: 0, + uk_115: -uk_119 - uk_129, + uk_116: -uk_123, + uk_117: -uk_126, + uk_120: 0, + uk_121: 0, + uk_124: 0, + uk_130: -uk_134 - uk_144 - uk_164, + uk_131: -uk_138 - uk_154, + uk_132: -uk_141 - uk_160, + uk_135: -uk_148, + uk_136: -uk_151, + uk_139: -uk_157, + uk_145: 0, + uk_146: 0, + uk_155: 0, + } + +def time_eqs_165x165(): + if len(eqs_165x165()) != 165: + raise ValueError("length should be 165") + +def time_solve_lin_sys_165x165(): + eqs = eqs_165x165() + sol = solve_lin_sys(eqs, R_165) + if sol != sol_165x165(): + raise ValueError("Value should be equal") + +def time_verify_sol_165x165(): + eqs = eqs_165x165() + sol = sol_165x165() + zeros = [ eq.compose(sol) for eq in eqs ] + if not all(zero == 0 for zero in zeros): + raise ValueError("All should be 0") + +def time_to_expr_eqs_165x165(): + eqs = eqs_165x165() + assert [ R_165.from_expr(eq.as_expr()) for eq in eqs ] == eqs + +# Benchmark R_49: shows how fast are arithmetics in rational function fields. +F_abc, a, b, c = field("a,b,c", ZZ) +R_49, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, k17, k18, k19, k20, k21, k22, k23, k24, k25, k26, k27, k28, k29, k30, k31, k32, k33, k34, k35, k36, k37, k38, k39, k40, k41, k42, k43, k44, k45, k46, k47, k48, k49 = ring("k1:50", F_abc) + +def eqs_189x49(): + return [ + -b*k8/a+c*k8/a, + -b*k11/a+c*k11/a, + -b*k10/a+c*k10/a+k2, + -k3-b*k9/a+c*k9/a, + -b*k14/a+c*k14/a, + -b*k15/a+c*k15/a, + -b*k18/a+c*k18/a-k2, + -b*k17/a+c*k17/a, + -b*k16/a+c*k16/a+k4, + -b*k13/a+c*k13/a-b*k21/a+c*k21/a+b*k5/a-c*k5/a, + b*k44/a-c*k44/a, + -b*k45/a+c*k45/a, + -b*k20/a+c*k20/a, + -b*k44/a+c*k44/a, + b*k46/a-c*k46/a, + b**2*k47/a**2-2*b*c*k47/a**2+c**2*k47/a**2, + k3, + -k4, + -b*k12/a+c*k12/a-a*k6/b+c*k6/b, + -b*k19/a+c*k19/a+a*k7/c-b*k7/c, + b*k45/a-c*k45/a, + -b*k46/a+c*k46/a, + -k48+c*k48/a+c*k48/b-c**2*k48/(a*b), + -k49+b*k49/a+b*k49/c-b**2*k49/(a*c), + a*k1/b-c*k1/b, + a*k4/b-c*k4/b, + a*k3/b-c*k3/b+k9, + -k10+a*k2/b-c*k2/b, + a*k7/b-c*k7/b, + -k9, + k11, + b*k12/a-c*k12/a+a*k6/b-c*k6/b, + a*k15/b-c*k15/b, + k10+a*k18/b-c*k18/b, + -k11+a*k17/b-c*k17/b, + a*k16/b-c*k16/b, + -a*k13/b+c*k13/b+a*k21/b-c*k21/b+a*k5/b-c*k5/b, + -a*k44/b+c*k44/b, + a*k45/b-c*k45/b, + a*k14/c-b*k14/c+a*k20/b-c*k20/b, + a*k44/b-c*k44/b, + -a*k46/b+c*k46/b, + -k47+c*k47/a+c*k47/b-c**2*k47/(a*b), + a*k19/b-c*k19/b, + -a*k45/b+c*k45/b, + a*k46/b-c*k46/b, + a**2*k48/b**2-2*a*c*k48/b**2+c**2*k48/b**2, + -k49+a*k49/b+a*k49/c-a**2*k49/(b*c), + k16, + -k17, + -a*k1/c+b*k1/c, + -k16-a*k4/c+b*k4/c, + -a*k3/c+b*k3/c, + k18-a*k2/c+b*k2/c, + b*k19/a-c*k19/a-a*k7/c+b*k7/c, + -a*k6/c+b*k6/c, + -a*k8/c+b*k8/c, + -a*k11/c+b*k11/c+k17, + -a*k10/c+b*k10/c-k18, + -a*k9/c+b*k9/c, + -a*k14/c+b*k14/c-a*k20/b+c*k20/b, + -a*k13/c+b*k13/c+a*k21/c-b*k21/c-a*k5/c+b*k5/c, + a*k44/c-b*k44/c, + -a*k45/c+b*k45/c, + -a*k44/c+b*k44/c, + a*k46/c-b*k46/c, + -k47+b*k47/a+b*k47/c-b**2*k47/(a*c), + -a*k12/c+b*k12/c, + a*k45/c-b*k45/c, + -a*k46/c+b*k46/c, + -k48+a*k48/b+a*k48/c-a**2*k48/(b*c), + a**2*k49/c**2-2*a*b*k49/c**2+b**2*k49/c**2, + k8, + k11, + -k15, + k10-k18, + -k17, + k9, + -k16, + -k29, + k14-k32, + -k21+k23-k31, + -k24-k30, + -k35, + k44, + -k45, + k36, + k13-k23+k39, + -k20+k38, + k25+k37, + b*k26/a-c*k26/a-k34+k42, + -2*k44, + k45, + k46, + b*k47/a-c*k47/a, + k41, + k44, + -k46, + -b*k47/a+c*k47/a, + k12+k24, + -k19-k25, + -a*k27/b+c*k27/b-k33, + k45, + -k46, + -a*k48/b+c*k48/b, + a*k28/c-b*k28/c+k40, + -k45, + k46, + a*k48/b-c*k48/b, + a*k49/c-b*k49/c, + -a*k49/c+b*k49/c, + -k1, + -k4, + -k3, + k15, + k18-k2, + k17, + k16, + k22, + k25-k7, + k24+k30, + k21+k23-k31, + k28, + -k44, + k45, + -k30-k6, + k20+k32, + k27+b*k33/a-c*k33/a, + k44, + -k46, + -b*k47/a+c*k47/a, + -k36, + k31-k39-k5, + -k32-k38, + k19-k37, + k26-a*k34/b+c*k34/b-k42, + k44, + -2*k45, + k46, + a*k48/b-c*k48/b, + a*k35/c-b*k35/c-k41, + -k44, + k46, + b*k47/a-c*k47/a, + -a*k49/c+b*k49/c, + -k40, + k45, + -k46, + -a*k48/b+c*k48/b, + a*k49/c-b*k49/c, + k1, + k4, + k3, + -k8, + -k11, + -k10+k2, + -k9, + k37+k7, + -k14-k38, + -k22, + -k25-k37, + -k24+k6, + -k13-k23+k39, + -k28+b*k40/a-c*k40/a, + k44, + -k45, + -k27, + -k44, + k46, + b*k47/a-c*k47/a, + k29, + k32+k38, + k31-k39+k5, + -k12+k30, + k35-a*k41/b+c*k41/b, + -k44, + k45, + -k26+k34+a*k42/c-b*k42/c, + k44, + k45, + -2*k46, + -b*k47/a+c*k47/a, + -a*k48/b+c*k48/b, + a*k49/c-b*k49/c, + k33, + -k45, + k46, + a*k48/b-c*k48/b, + -a*k49/c+b*k49/c, + ] + +def sol_189x49(): + return { + k49: 0, k48: 0, k47: 0, k46: 0, k45: 0, k44: 0, k41: 0, k40: 0, + k38: 0, k37: 0, k36: 0, k35: 0, k33: 0, k32: 0, k30: 0, k29: 0, + k28: 0, k27: 0, k25: 0, k24: 0, k22: 0, k21: 0, k20: 0, k19: 0, + k18: 0, k17: 0, k16: 0, k15: 0, k14: 0, k13: 0, k12: 0, k11: 0, + k10: 0, k9: 0, k8: 0, k7: 0, k6: 0, k5: 0, k4: 0, k3: 0, + k2: 0, k1: 0, + k34: b/c*k42, + k31: k39, + k26: a/c*k42, + k23: k39, + } + +def time_eqs_189x49(): + if len(eqs_189x49()) != 189: + raise ValueError("Length should be equal to 189") + +def time_solve_lin_sys_189x49(): + eqs = eqs_189x49() + sol = solve_lin_sys(eqs, R_49) + if sol != sol_189x49(): + raise ValueError("Values should be equal") + +def time_verify_sol_189x49(): + eqs = eqs_189x49() + sol = sol_189x49() + zeros = [ eq.compose(sol) for eq in eqs ] + assert all(zero == 0 for zero in zeros) + +def time_to_expr_eqs_189x49(): + eqs = eqs_189x49() + assert [ R_49.from_expr(eq.as_expr()) for eq in eqs ] == eqs + +# Benchmark R_8: shows how fast polynomial GCDs are computed. + +F_a5_5, a_11, a_12, a_13, a_14, a_21, a_22, a_23, a_24, a_31, a_32, a_33, a_34, a_41, a_42, a_43, a_44 = field("a_(1:5)(1:5)", ZZ) +R_8, x0, x1, x2, x3, x4, x5, x6, x7 = ring("x:8", F_a5_5) + +def eqs_10x8(): + return [ + (a_33*a_34 + a_33*a_44 + a_43*a_44)*x3 + (a_33*a_34 + a_33*a_44 + a_43*a_44)*x4 + (a_12*a_34 + a_12*a_44 + a_22*a_34 + a_22*a_44)*x5 + (a_12*a_44 + a_22*a_44)*x6 + (a_12*a_33 + a_22*a_33)*x7 - a_12*a_33 - a_12*a_43 - a_22*a_33 - a_22*a_43, + (a_33 + a_34 + a_43 + a_44)*x3 + (a_33 + a_34 + a_43 + a_44)*x4 + (a_12 + a_22 + a_34 + a_44)*x5 + (a_12 + a_22 + a_44)*x6 + (a_12 + a_22 + a_33)*x7 - a_12 - a_22 - a_33 - a_43, + x3 + x4 + x5 + x6 + x7 - 1, + (a_12*a_33*a_34 + a_12*a_33*a_44 + a_12*a_43*a_44 + a_22*a_33*a_34 + a_22*a_33*a_44 + a_22*a_43*a_44)*x0 + (a_22*a_33*a_34 + a_22*a_33*a_44 + a_22*a_43*a_44)*x1 + (a_12*a_33*a_34 + a_12*a_33*a_44 + a_12*a_43*a_44 + a_22*a_33*a_34 + a_22*a_33*a_44 + a_22*a_43*a_44)*x2 + (a_11*a_33*a_34 + a_11*a_33*a_44 + a_11*a_43*a_44 + a_31*a_33*a_34 + a_31*a_33*a_44 + a_31*a_43*a_44)*x3 + (a_11*a_33*a_34 + a_11*a_33*a_44 + a_11*a_43*a_44 + a_21*a_33*a_34 + a_21*a_33*a_44 + a_21*a_43*a_44 + a_31*a_33*a_34 + a_31*a_33*a_44 + a_31*a_43*a_44)*x4 + (a_11*a_12*a_34 + a_11*a_12*a_44 + a_11*a_22*a_34 + a_11*a_22*a_44 + a_12*a_31*a_34 + a_12*a_31*a_44 + a_21*a_22*a_34 + a_21*a_22*a_44 + a_22*a_31*a_34 + a_22*a_31*a_44)*x5 + (a_11*a_12*a_44 + a_11*a_22*a_44 + a_12*a_31*a_44 + a_21*a_22*a_44 + a_22*a_31*a_44)*x6 + (a_11*a_12*a_33 + a_11*a_22*a_33 + a_12*a_31*a_33 + a_21*a_22*a_33 + a_22*a_31*a_33)*x7 - a_11*a_12*a_33 - a_11*a_12*a_43 - a_11*a_22*a_33 - a_11*a_22*a_43 - a_12*a_31*a_33 - a_12*a_31*a_43 - a_21*a_22*a_33 - a_21*a_22*a_43 - a_22*a_31*a_33 - a_22*a_31*a_43, + (a_12*a_33 + a_12*a_34 + a_12*a_43 + a_12*a_44 + a_22*a_33 + a_22*a_34 + a_22*a_43 + a_22*a_44 + a_33*a_34 + a_33*a_44 + a_43*a_44)*x0 + (a_22*a_33 + a_22*a_34 + a_22*a_43 + a_22*a_44 + a_33*a_34 + a_33*a_44 + a_43*a_44)*x1 + (a_12*a_33 + a_12*a_34 + a_12*a_43 + a_12*a_44 + a_22*a_33 + a_22*a_34 + a_22*a_43 + a_22*a_44 + a_33*a_34 + a_33*a_44 + a_43*a_44)*x2 + (a_11*a_33 + a_11*a_34 + a_11*a_43 + a_11*a_44 + a_31*a_33 + a_31*a_34 + a_31*a_43 + a_31*a_44 + a_33*a_34 + a_33*a_44 + a_43*a_44)*x3 + (a_11*a_33 + a_11*a_34 + a_11*a_43 + a_11*a_44 + a_21*a_33 + a_21*a_34 + a_21*a_43 + a_21*a_44 + a_31*a_33 + a_31*a_34 + a_31*a_43 + a_31*a_44 + a_33*a_34 + a_33*a_44 + a_43*a_44)*x4 + (a_11*a_12 + a_11*a_22 + a_11*a_34 + a_11*a_44 + a_12*a_31 + a_12*a_34 + a_12*a_44 + a_21*a_22 + a_21*a_34 + a_21*a_44 + a_22*a_31 + a_22*a_34 + a_22*a_44 + a_31*a_34 + a_31*a_44)*x5 + (a_11*a_12 + a_11*a_22 + a_11*a_44 + a_12*a_31 + a_12*a_44 + a_21*a_22 + a_21*a_44 + a_22*a_31 + a_22*a_44 + a_31*a_44)*x6 + (a_11*a_12 + a_11*a_22 + a_11*a_33 + a_12*a_31 + a_12*a_33 + a_21*a_22 + a_21*a_33 + a_22*a_31 + a_22*a_33 + a_31*a_33)*x7 - a_11*a_12 - a_11*a_22 - a_11*a_33 - a_11*a_43 - a_12*a_31 - a_12*a_33 - a_12*a_43 - a_21*a_22 - a_21*a_33 - a_21*a_43 - a_22*a_31 - a_22*a_33 - a_22*a_43 - a_31*a_33 - a_31*a_43, + (a_12 + a_22 + a_33 + a_34 + a_43 + a_44)*x0 + (a_22 + a_33 + a_34 + a_43 + a_44)*x1 + (a_12 + a_22 + a_33 + a_34 + a_43 + a_44)*x2 + (a_11 + a_31 + a_33 + a_34 + a_43 + a_44)*x3 + (a_11 + a_21 + a_31 + a_33 + a_34 + a_43 + a_44)*x4 + (a_11 + a_12 + a_21 + a_22 + a_31 + a_34 + a_44)*x5 + (a_11 + a_12 + a_21 + a_22 + a_31 + a_44)*x6 + (a_11 + a_12 + a_21 + a_22 + a_31 + a_33)*x7 - a_11 - a_12 - a_21 - a_22 - a_31 - a_33 - a_43, + x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 - 1, + (a_12*a_34 + a_12*a_44 + a_22*a_34 + a_22*a_44)*x2 + (a_31*a_34 + a_31*a_44)*x3 + (a_31*a_34 + a_31*a_44)*x4 + (a_12*a_31 + a_22*a_31)*x7 - a_12*a_31 - a_22*a_31, + (a_12 + a_22 + a_34 + a_44)*x2 + a_31*x3 + a_31*x4 + a_31*x7 - a_31, + x2, + ] + +def sol_10x8(): + return { + x0: -a_21/a_12*x4, + x1: a_21/a_12*x4, + x2: 0, + x3: -x4, + x5: a_43/a_34, + x6: -a_43/a_34, + x7: 1, + } + +def time_eqs_10x8(): + if len(eqs_10x8()) != 10: + raise ValueError("Value should be equal to 10") + +def time_solve_lin_sys_10x8(): + eqs = eqs_10x8() + sol = solve_lin_sys(eqs, R_8) + if sol != sol_10x8(): + raise ValueError("Values should be equal") + +def time_verify_sol_10x8(): + eqs = eqs_10x8() + sol = sol_10x8() + zeros = [ eq.compose(sol) for eq in eqs ] + if not all(zero == 0 for zero in zeros): + raise ValueError("All values in zero should be 0") + +def time_to_expr_eqs_10x8(): + eqs = eqs_10x8() + assert [ R_8.from_expr(eq.as_expr()) for eq in eqs ] == eqs diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__init__.py b/phivenv/Lib/site-packages/sympy/polys/domains/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6839b4494afd0ee0c0ecd9ddee65d1afbdc6b53 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/__init__.py @@ -0,0 +1,57 @@ +"""Implementation of mathematical domains. """ + +__all__ = [ + 'Domain', 'FiniteField', 'IntegerRing', 'RationalField', 'RealField', + 'ComplexField', 'AlgebraicField', 'PolynomialRing', 'FractionField', + 'ExpressionDomain', 'PythonRational', + + 'GF', 'FF', 'ZZ', 'QQ', 'ZZ_I', 'QQ_I', 'RR', 'CC', 'EX', 'EXRAW', +] + +from .domain import Domain +from .finitefield import FiniteField, FF, GF +from .integerring import IntegerRing, ZZ +from .rationalfield import RationalField, QQ +from .algebraicfield import AlgebraicField +from .gaussiandomains import ZZ_I, QQ_I +from .realfield import RealField, RR +from .complexfield import ComplexField, CC +from .polynomialring import PolynomialRing +from .fractionfield import FractionField +from .expressiondomain import ExpressionDomain, EX +from .expressionrawdomain import EXRAW +from .pythonrational import PythonRational + + +# This is imported purely for backwards compatibility because some parts of +# the codebase used to import this from here and it's possible that downstream +# does as well: +from sympy.external.gmpy import GROUND_TYPES # noqa: F401 + +# +# The rest of these are obsolete and provided only for backwards +# compatibility: +# + +from .pythonfinitefield import PythonFiniteField +from .gmpyfinitefield import GMPYFiniteField +from .pythonintegerring import PythonIntegerRing +from .gmpyintegerring import GMPYIntegerRing +from .pythonrationalfield import PythonRationalField +from .gmpyrationalfield import GMPYRationalField + +FF_python = PythonFiniteField +FF_gmpy = GMPYFiniteField + +ZZ_python = PythonIntegerRing +ZZ_gmpy = GMPYIntegerRing + +QQ_python = PythonRationalField +QQ_gmpy = GMPYRationalField + +__all__.extend(( + 'PythonFiniteField', 'GMPYFiniteField', 'PythonIntegerRing', + 'GMPYIntegerRing', 'PythonRational', 'GMPYRationalField', + + 'FF_python', 'FF_gmpy', 'ZZ_python', 'ZZ_gmpy', 'QQ_python', 'QQ_gmpy', +)) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f9f3b76b1d2ed4fb827509687311ac578ecacab Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/algebraicfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/algebraicfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df05b84ca16e6426df1624963ab81f36c5f90e22 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/algebraicfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/characteristiczero.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/characteristiczero.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cbf096fe046b05a60826c865aaa3d5f50b6ce5f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/characteristiczero.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/complexfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/complexfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..803679055b47d84d4616484c00772384db08a778 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/complexfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/compositedomain.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/compositedomain.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cace02ee523330f6d4ac0ba3c1e297c0ea6f52a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/compositedomain.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/domain.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/domain.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b2c10dfa26df5d81627399dd4da5f192c6a077f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/domain.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/domainelement.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/domainelement.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec47682104f5c8c9008dc1390d45c38dbd0fec54 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/domainelement.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/expressiondomain.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/expressiondomain.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d71e5c5eba80b76c8632066bb9c8a5864df6db22 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/expressiondomain.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/expressionrawdomain.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/expressionrawdomain.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58310795de42ff72533786026be6c2a523e68e49 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/expressionrawdomain.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/field.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/field.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26c716f657b956f8d5eb137fb0f2b491d9f90538 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/field.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/finitefield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/finitefield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f39a828e421159a13b6e0f9cb2cd87b0da84185 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/finitefield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/fractionfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/fractionfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e22c29b016ef43a84e669612b032fca66bf4781 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/fractionfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gaussiandomains.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gaussiandomains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84ef7b8b51a4866082e231654995f8fc0041f8f7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gaussiandomains.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyfinitefield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyfinitefield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb426377e8cb93c4f3f55ef58b53995667a4d6a0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyfinitefield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyintegerring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyintegerring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86106df9350523eb8fe92f75feaa6d7d26f42852 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyintegerring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da873bb276da3e1101a2f322beefe80db80008a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/groundtypes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/groundtypes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a2533dd91894d139aed804fd122393b345e1e0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/groundtypes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/integerring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/integerring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e5b37c4d0bbdbedc2ef17fc537c328d278c9426 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/integerring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/modularinteger.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/modularinteger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e1659be5d9dbb0b116461d58a1b7746f41bece2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/modularinteger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/mpelements.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/mpelements.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f479a0529c4ac63a3b40707b106d2cac73ddd56d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/mpelements.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/old_fractionfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/old_fractionfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce74aa80cc06303dbee6adb6643562c695b1f562 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/old_fractionfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/old_polynomialring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/old_polynomialring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66a8b49724b1233dfb0101d2610658f99c022739 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/old_polynomialring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/polynomialring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/polynomialring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c13375df024248f637fe8c2727ad994738fae93b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/polynomialring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonfinitefield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonfinitefield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af07cc9785648cf0127481c826780381282763f9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonfinitefield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonintegerring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonintegerring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a63da81c001a07abcfeaad29479abf70bbe5f26 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonintegerring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonrational.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonrational.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc4f6f6ada6f338be67bee256de646147eb1b149 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonrational.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonrationalfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonrationalfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..882ade65b88721a12da1762caa621823bf6b942a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/pythonrationalfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/quotientring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/quotientring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f1fe499dfae78fcd0001b9f2021492ce589bc8a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/quotientring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/rationalfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/rationalfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..394b4167a050786bff58e903ffcb90b0af143bbd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/rationalfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/realfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/realfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e5bd62efcbc1311434fce87feeb8ea7b5b2f4cf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/realfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/ring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/ring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ac637fc0f390832698d278937b45ad78b987fdf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/ring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/simpledomain.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/simpledomain.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a234ebbedb396afe4f59a644da79cd363cd8af51 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/__pycache__/simpledomain.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/algebraicfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/algebraicfield.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee3f10d90fc4a3331471ea9a24589d65654d1cd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/algebraicfield.py @@ -0,0 +1,638 @@ +"""Implementation of :class:`AlgebraicField` class. """ + + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, symbols +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyclasses import ANP +from sympy.polys.polyerrors import CoercionFailed, DomainError, NotAlgebraic, IsomorphismFailed +from sympy.utilities import public + +@public +class AlgebraicField(Field, CharacteristicZero, SimpleDomain): + r"""Algebraic number field :ref:`QQ(a)` + + A :ref:`QQ(a)` domain represents an `algebraic number field`_ + `\mathbb{Q}(a)` as a :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + A :py:class:`~.Poly` created from an expression involving `algebraic + numbers`_ will treat the algebraic numbers as generators if the generators + argument is not specified. + + >>> from sympy import Poly, Symbol, sqrt + >>> x = Symbol('x') + >>> Poly(x**2 + sqrt(2)) + Poly(x**2 + (sqrt(2)), x, sqrt(2), domain='ZZ') + + That is a multivariate polynomial with ``sqrt(2)`` treated as one of the + generators (variables). If the generators are explicitly specified then + ``sqrt(2)`` will be considered to be a coefficient but by default the + :ref:`EX` domain is used. To make a :py:class:`~.Poly` with a :ref:`QQ(a)` + domain the argument ``extension=True`` can be given. + + >>> Poly(x**2 + sqrt(2), x) + Poly(x**2 + sqrt(2), x, domain='EX') + >>> Poly(x**2 + sqrt(2), x, extension=True) + Poly(x**2 + sqrt(2), x, domain='QQ') + + A generator of the algebraic field extension can also be specified + explicitly which is particularly useful if the coefficients are all + rational but an extension field is needed (e.g. to factor the + polynomial). + + >>> Poly(x**2 + 1) + Poly(x**2 + 1, x, domain='ZZ') + >>> Poly(x**2 + 1, extension=sqrt(2)) + Poly(x**2 + 1, x, domain='QQ') + + It is possible to factorise a polynomial over a :ref:`QQ(a)` domain using + the ``extension`` argument to :py:func:`~.factor` or by specifying the domain + explicitly. + + >>> from sympy import factor, QQ + >>> factor(x**2 - 2) + x**2 - 2 + >>> factor(x**2 - 2, extension=sqrt(2)) + (x - sqrt(2))*(x + sqrt(2)) + >>> factor(x**2 - 2, domain='QQ') + (x - sqrt(2))*(x + sqrt(2)) + >>> factor(x**2 - 2, domain=QQ.algebraic_field(sqrt(2))) + (x - sqrt(2))*(x + sqrt(2)) + + The ``extension=True`` argument can be used but will only create an + extension that contains the coefficients which is usually not enough to + factorise the polynomial. + + >>> p = x**3 + sqrt(2)*x**2 - 2*x - 2*sqrt(2) + >>> factor(p) # treats sqrt(2) as a symbol + (x + sqrt(2))*(x**2 - 2) + >>> factor(p, extension=True) + (x - sqrt(2))*(x + sqrt(2))**2 + >>> factor(x**2 - 2, extension=True) # all rational coefficients + x**2 - 2 + + It is also possible to use :ref:`QQ(a)` with the :py:func:`~.cancel` + and :py:func:`~.gcd` functions. + + >>> from sympy import cancel, gcd + >>> cancel((x**2 - 2)/(x - sqrt(2))) + (x**2 - 2)/(x - sqrt(2)) + >>> cancel((x**2 - 2)/(x - sqrt(2)), extension=sqrt(2)) + x + sqrt(2) + >>> gcd(x**2 - 2, x - sqrt(2)) + 1 + >>> gcd(x**2 - 2, x - sqrt(2), extension=sqrt(2)) + x - sqrt(2) + + When using the domain directly :ref:`QQ(a)` can be used as a constructor + to create instances which then support the operations ``+,-,*,**,/``. The + :py:meth:`~.Domain.algebraic_field` method is used to construct a + particular :ref:`QQ(a)` domain. The :py:meth:`~.Domain.from_sympy` method + can be used to create domain elements from normal SymPy expressions. + + >>> K = QQ.algebraic_field(sqrt(2)) + >>> K + QQ + >>> xk = K.from_sympy(3 + 4*sqrt(2)) + >>> xk # doctest: +SKIP + ANP([4, 3], [1, 0, -2], QQ) + + Elements of :ref:`QQ(a)` are instances of :py:class:`~.ANP` which have + limited printing support. The raw display shows the internal + representation of the element as the list ``[4, 3]`` representing the + coefficients of ``1`` and ``sqrt(2)`` for this element in the form + ``a * sqrt(2) + b * 1`` where ``a`` and ``b`` are elements of :ref:`QQ`. + The minimal polynomial for the generator ``(x**2 - 2)`` is also shown in + the :ref:`dup-representation` as the list ``[1, 0, -2]``. We can use + :py:meth:`~.Domain.to_sympy` to get a better printed form for the + elements and to see the results of operations. + + >>> xk = K.from_sympy(3 + 4*sqrt(2)) + >>> yk = K.from_sympy(2 + 3*sqrt(2)) + >>> xk * yk # doctest: +SKIP + ANP([17, 30], [1, 0, -2], QQ) + >>> K.to_sympy(xk * yk) + 17*sqrt(2) + 30 + >>> K.to_sympy(xk + yk) + 5 + 7*sqrt(2) + >>> K.to_sympy(xk ** 2) + 24*sqrt(2) + 41 + >>> K.to_sympy(xk / yk) + sqrt(2)/14 + 9/7 + + Any expression representing an algebraic number can be used to generate + a :ref:`QQ(a)` domain provided its `minimal polynomial`_ can be computed. + The function :py:func:`~.minpoly` function is used for this. + + >>> from sympy import exp, I, pi, minpoly + >>> g = exp(2*I*pi/3) + >>> g + exp(2*I*pi/3) + >>> g.is_algebraic + True + >>> minpoly(g, x) + x**2 + x + 1 + >>> factor(x**3 - 1, extension=g) + (x - 1)*(x - exp(2*I*pi/3))*(x + 1 + exp(2*I*pi/3)) + + It is also possible to make an algebraic field from multiple extension + elements. + + >>> K = QQ.algebraic_field(sqrt(2), sqrt(3)) + >>> K + QQ + >>> p = x**4 - 5*x**2 + 6 + >>> factor(p) + (x**2 - 3)*(x**2 - 2) + >>> factor(p, domain=K) + (x - sqrt(2))*(x + sqrt(2))*(x - sqrt(3))*(x + sqrt(3)) + >>> factor(p, extension=[sqrt(2), sqrt(3)]) + (x - sqrt(2))*(x + sqrt(2))*(x - sqrt(3))*(x + sqrt(3)) + + Multiple extension elements are always combined together to make a single + `primitive element`_. In the case of ``[sqrt(2), sqrt(3)]`` the primitive + element chosen is ``sqrt(2) + sqrt(3)`` which is why the domain displays + as ``QQ``. The minimal polynomial for the primitive + element is computed using the :py:func:`~.primitive_element` function. + + >>> from sympy import primitive_element + >>> primitive_element([sqrt(2), sqrt(3)], x) + (x**4 - 10*x**2 + 1, [1, 1]) + >>> minpoly(sqrt(2) + sqrt(3), x) + x**4 - 10*x**2 + 1 + + The extension elements that generate the domain can be accessed from the + domain using the :py:attr:`~.ext` and :py:attr:`~.orig_ext` attributes as + instances of :py:class:`~.AlgebraicNumber`. The minimal polynomial for + the primitive element as a :py:class:`~.DMP` instance is available as + :py:attr:`~.mod`. + + >>> K = QQ.algebraic_field(sqrt(2), sqrt(3)) + >>> K + QQ + >>> K.ext + sqrt(2) + sqrt(3) + >>> K.orig_ext + (sqrt(2), sqrt(3)) + >>> K.mod # doctest: +SKIP + DMP_Python([1, 0, -10, 0, 1], QQ) + + The `discriminant`_ of the field can be obtained from the + :py:meth:`~.discriminant` method, and an `integral basis`_ from the + :py:meth:`~.integral_basis` method. The latter returns a list of + :py:class:`~.ANP` instances by default, but can be made to return instances + of :py:class:`~.Expr` or :py:class:`~.AlgebraicNumber` by passing a ``fmt`` + argument. The maximal order, or ring of integers, of the field can also be + obtained from the :py:meth:`~.maximal_order` method, as a + :py:class:`~sympy.polys.numberfields.modules.Submodule`. + + >>> zeta5 = exp(2*I*pi/5) + >>> K = QQ.algebraic_field(zeta5) + >>> K + QQ + >>> K.discriminant() + 125 + >>> K = QQ.algebraic_field(sqrt(5)) + >>> K + QQ + >>> K.integral_basis(fmt='sympy') + [1, 1/2 + sqrt(5)/2] + >>> K.maximal_order() + Submodule[[2, 0], [1, 1]]/2 + + The factorization of a rational prime into prime ideals of the field is + computed by the :py:meth:`~.primes_above` method, which returns a list + of :py:class:`~sympy.polys.numberfields.primes.PrimeIdeal` instances. + + >>> zeta7 = exp(2*I*pi/7) + >>> K = QQ.algebraic_field(zeta7) + >>> K + QQ + >>> K.primes_above(11) + [(11, _x**3 + 5*_x**2 + 4*_x - 1), (11, _x**3 - 4*_x**2 - 5*_x - 1)] + + The Galois group of the Galois closure of the field can be computed (when + the minimal polynomial of the field is of sufficiently small degree). + + >>> K.galois_group(by_name=True)[0] + S6TransitiveSubgroups.C6 + + Notes + ===== + + It is not currently possible to generate an algebraic extension over any + domain other than :ref:`QQ`. Ideally it would be possible to generate + extensions like ``QQ(x)(sqrt(x**2 - 2))``. This is equivalent to the + quotient ring ``QQ(x)[y]/(y**2 - x**2 + 2)`` and there are two + implementations of this kind of quotient ring/extension in the + :py:class:`~.QuotientRing` and :py:class:`~.MonogenicFiniteExtension` + classes. Each of those implementations needs some work to make them fully + usable though. + + .. _algebraic number field: https://en.wikipedia.org/wiki/Algebraic_number_field + .. _algebraic numbers: https://en.wikipedia.org/wiki/Algebraic_number + .. _discriminant: https://en.wikipedia.org/wiki/Discriminant_of_an_algebraic_number_field + .. _integral basis: https://en.wikipedia.org/wiki/Algebraic_number_field#Integral_basis + .. _minimal polynomial: https://en.wikipedia.org/wiki/Minimal_polynomial_(field_theory) + .. _primitive element: https://en.wikipedia.org/wiki/Primitive_element_theorem + """ + + dtype = ANP + + is_AlgebraicField = is_Algebraic = True + is_Numerical = True + + has_assoc_Ring = False + has_assoc_Field = True + + def __init__(self, dom, *ext, alias=None): + r""" + Parameters + ========== + + dom : :py:class:`~.Domain` + The base field over which this is an extension field. + Currently only :ref:`QQ` is accepted. + + *ext : One or more :py:class:`~.Expr` + Generators of the extension. These should be expressions that are + algebraic over `\mathbb{Q}`. + + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the + primitive element of the :py:class:`~.AlgebraicField`. + If ``None``, while ``ext`` consists of exactly one + :py:class:`~.AlgebraicNumber`, its alias (if any) will be used. + """ + if not dom.is_QQ: + raise DomainError("ground domain must be a rational field") + + from sympy.polys.numberfields import to_number_field + if len(ext) == 1 and isinstance(ext[0], tuple): + orig_ext = ext[0][1:] + else: + orig_ext = ext + + if alias is None and len(ext) == 1: + alias = getattr(ext[0], 'alias', None) + + self.orig_ext = orig_ext + """ + Original elements given to generate the extension. + + >>> from sympy import QQ, sqrt + >>> K = QQ.algebraic_field(sqrt(2), sqrt(3)) + >>> K.orig_ext + (sqrt(2), sqrt(3)) + """ + + self.ext = to_number_field(ext, alias=alias) + """ + Primitive element used for the extension. + + >>> from sympy import QQ, sqrt + >>> K = QQ.algebraic_field(sqrt(2), sqrt(3)) + >>> K.ext + sqrt(2) + sqrt(3) + """ + + self.mod = self.ext.minpoly.rep + """ + Minimal polynomial for the primitive element of the extension. + + >>> from sympy import QQ, sqrt + >>> K = QQ.algebraic_field(sqrt(2)) + >>> K.mod + DMP([1, 0, -2], QQ) + """ + + self.domain = self.dom = dom + + self.ngens = 1 + self.symbols = self.gens = (self.ext,) + self.unit = self([dom(1), dom(0)]) + + self.zero = self.dtype.zero(self.mod.to_list(), dom) + self.one = self.dtype.one(self.mod.to_list(), dom) + + self._maximal_order = None + self._discriminant = None + self._nilradicals_mod_p = {} + + def new(self, element): + return self.dtype(element, self.mod.to_list(), self.dom) + + def __str__(self): + return str(self.dom) + '<' + str(self.ext) + '>' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.dom, self.ext)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, AlgebraicField): + return self.dtype == other.dtype and self.ext == other.ext + else: + return NotImplemented + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `\mathbb{Q}(\alpha, \ldots)`. """ + return AlgebraicField(self.dom, *((self.ext,) + extension), alias=alias) + + def to_alg_num(self, a): + """Convert ``a`` of ``dtype`` to an :py:class:`~.AlgebraicNumber`. """ + return self.ext.field_element(a) + + def to_sympy(self, a): + """Convert ``a`` of ``dtype`` to a SymPy object. """ + # Precompute a converter to be reused: + if not hasattr(self, '_converter'): + self._converter = _make_converter(self) + + return self._converter(a) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + try: + return self([self.dom.from_sympy(a)]) + except CoercionFailed: + pass + + from sympy.polys.numberfields import to_number_field + + try: + return self(to_number_field(a, self.ext).native_coeffs()) + except (NotAlgebraic, IsomorphismFailed): + raise CoercionFailed( + "%s is not a valid algebraic number in %s" % (a, self)) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError('there is no ring associated with %s' % self) + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return self.dom.is_positive(a.LC()) + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return self.dom.is_negative(a.LC()) + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return self.dom.is_nonpositive(a.LC()) + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return self.dom.is_nonnegative(a.LC()) + + def numer(self, a): + """Returns numerator of ``a``. """ + return a + + def denom(self, a): + """Returns denominator of ``a``. """ + return self.one + + def from_AlgebraicField(K1, a, K0): + """Convert AlgebraicField element 'a' to another AlgebraicField """ + return K1.from_sympy(K0.to_sympy(a)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a GaussianInteger element 'a' to ``dtype``. """ + return K1.from_sympy(K0.to_sympy(a)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a GaussianRational element 'a' to ``dtype``. """ + return K1.from_sympy(K0.to_sympy(a)) + + def _do_round_two(self): + from sympy.polys.numberfields.basis import round_two + ZK, dK = round_two(self, radicals=self._nilradicals_mod_p) + self._maximal_order = ZK + self._discriminant = dK + + def maximal_order(self): + """ + Compute the maximal order, or ring of integers, of the field. + + Returns + ======= + + :py:class:`~sympy.polys.numberfields.modules.Submodule`. + + See Also + ======== + + integral_basis + + """ + if self._maximal_order is None: + self._do_round_two() + return self._maximal_order + + def integral_basis(self, fmt=None): + r""" + Get an integral basis for the field. + + Parameters + ========== + + fmt : str, None, optional (default=None) + If ``None``, return a list of :py:class:`~.ANP` instances. + If ``"sympy"``, convert each element of the list to an + :py:class:`~.Expr`, using ``self.to_sympy()``. + If ``"alg"``, convert each element of the list to an + :py:class:`~.AlgebraicNumber`, using ``self.to_alg_num()``. + + Examples + ======== + + >>> from sympy import QQ, AlgebraicNumber, sqrt + >>> alpha = AlgebraicNumber(sqrt(5), alias='alpha') + >>> k = QQ.algebraic_field(alpha) + >>> B0 = k.integral_basis() + >>> B1 = k.integral_basis(fmt='sympy') + >>> B2 = k.integral_basis(fmt='alg') + >>> print(B0[1]) # doctest: +SKIP + ANP([mpq(1,2), mpq(1,2)], [mpq(1,1), mpq(0,1), mpq(-5,1)], QQ) + >>> print(B1[1]) + 1/2 + alpha/2 + >>> print(B2[1]) + alpha/2 + 1/2 + + In the last two cases we get legible expressions, which print somewhat + differently because of the different types involved: + + >>> print(type(B1[1])) + + >>> print(type(B2[1])) + + + See Also + ======== + + to_sympy + to_alg_num + maximal_order + """ + ZK = self.maximal_order() + M = ZK.QQ_matrix + n = M.shape[1] + B = [self.new(list(reversed(M[:, j].flat()))) for j in range(n)] + if fmt == 'sympy': + return [self.to_sympy(b) for b in B] + elif fmt == 'alg': + return [self.to_alg_num(b) for b in B] + return B + + def discriminant(self): + """Get the discriminant of the field.""" + if self._discriminant is None: + self._do_round_two() + return self._discriminant + + def primes_above(self, p): + """Compute the prime ideals lying above a given rational prime *p*.""" + from sympy.polys.numberfields.primes import prime_decomp + ZK = self.maximal_order() + dK = self.discriminant() + rad = self._nilradicals_mod_p.get(p) + return prime_decomp(p, ZK=ZK, dK=dK, radical=rad) + + def galois_group(self, by_name=False, max_tries=30, randomize=False): + """ + Compute the Galois group of the Galois closure of this field. + + Examples + ======== + + If the field is Galois, the order of the group will equal the degree + of the field: + + >>> from sympy import QQ + >>> from sympy.abc import x + >>> k = QQ.alg_field_from_poly(x**4 + 1) + >>> G, _ = k.galois_group() + >>> G.order() + 4 + + If the field is not Galois, then its Galois closure is a proper + extension, and the order of the Galois group will be greater than the + degree of the field: + + >>> k = QQ.alg_field_from_poly(x**4 - 2) + >>> G, _ = k.galois_group() + >>> G.order() + 8 + + See Also + ======== + + sympy.polys.numberfields.galoisgroups.galois_group + + """ + return self.ext.minpoly_of_element().galois_group( + by_name=by_name, max_tries=max_tries, randomize=randomize) + + +def _make_converter(K): + """Construct the converter to convert back to Expr""" + # Precompute the effect of converting to SymPy and expanding expressions + # like (sqrt(2) + sqrt(3))**2. Asking Expr to do the expansion on every + # conversion from K to Expr is slow. Here we compute the expansions for + # each power of the generator and collect together the resulting algebraic + # terms and the rational coefficients into a matrix. + + ext = K.ext.as_expr() + todom = K.dom.from_sympy + toexpr = K.dom.to_sympy + + if not ext.is_Add: + powers = [ext**n for n in range(K.mod.degree())] + else: + # primitive_element generates a QQ-linear combination of lower degree + # algebraic numbers to generate the higher degree extension e.g. + # QQ That means that we end up having high powers of low + # degree algebraic numbers that can be reduced. Here we will use the + # minimal polynomials of the algebraic numbers to reduce those powers + # before converting to Expr. + from sympy.polys.numberfields.minpoly import minpoly + + # Decompose ext as a linear combination of gens and make a symbol for + # each gen. + gens, coeffs = zip(*ext.as_coefficients_dict().items()) + syms = symbols(f'a:{len(gens)}', cls=Dummy) + sym2gen = dict(zip(syms, gens)) + + # Make a polynomial ring that can express ext and minpolys of all gens + # in terms of syms. + R = K.dom[syms] + monoms = [R.ring.monomial_basis(i) for i in range(R.ngens)] + ext_dict = {m: todom(c) for m, c in zip(monoms, coeffs)} + ext_poly = R.ring.from_dict(ext_dict) + minpolys = [R.from_sympy(minpoly(g, s)) for s, g in sym2gen.items()] + + # Compute all powers of ext_poly reduced modulo minpolys + powers = [R.one, ext_poly] + for n in range(2, K.mod.degree()): + ext_poly_n = (powers[-1] * ext_poly).rem(minpolys) + powers.append(ext_poly_n) + + # Convert the powers back to Expr. This will recombine some things like + # sqrt(2)*sqrt(3) -> sqrt(6). + powers = [p.as_expr().xreplace(sym2gen) for p in powers] + + # This also expands some rational powers + powers = [p.expand() for p in powers] + + # Collect the rational coefficients and algebraic Expr that can + # map the ANP coefficients into an expanded SymPy expression + terms = [dict(t.as_coeff_Mul()[::-1] for t in Add.make_args(p)) for p in powers] + algebraics = set().union(*terms) + matrix = [[todom(t.get(a, S.Zero)) for t in terms] for a in algebraics] + + # Create a function to do the conversion efficiently: + + def converter(a): + """Convert a to Expr using converter""" + ai = a.to_list()[::-1] + coeffs_dom = [sum(mij*aj for mij, aj in zip(mi, ai)) for mi in matrix] + coeffs_sympy = [toexpr(c) for c in coeffs_dom] + res = Add(*(Mul(c, a) for c, a in zip(coeffs_sympy, algebraics))) + return res + + return converter diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/characteristiczero.py b/phivenv/Lib/site-packages/sympy/polys/domains/characteristiczero.py new file mode 100644 index 0000000000000000000000000000000000000000..755a354bea9594b9e8f73256c448b3debae037b2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/characteristiczero.py @@ -0,0 +1,15 @@ +"""Implementation of :class:`CharacteristicZero` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.utilities import public + +@public +class CharacteristicZero(Domain): + """Domain that has infinite number of elements. """ + + has_CharacteristicZero = True + + def characteristic(self): + """Return the characteristic of this domain. """ + return 0 diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/complexfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/complexfield.py new file mode 100644 index 0000000000000000000000000000000000000000..69f0bff2c1b311a150add88d5a1f146ea7b1726a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/complexfield.py @@ -0,0 +1,198 @@ +"""Implementation of :class:`ComplexField` class. """ + + +from sympy.external.gmpy import SYMPY_INTS +from sympy.core.numbers import Float, I +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.gaussiandomains import QQ_I +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import DomainError, CoercionFailed +from sympy.utilities import public + +from mpmath import MPContext + + +@public +class ComplexField(Field, CharacteristicZero, SimpleDomain): + """Complex numbers up to the given precision. """ + + rep = 'CC' + + is_ComplexField = is_CC = True + + is_Exact = False + is_Numerical = True + + has_assoc_Ring = False + has_assoc_Field = True + + _default_precision = 53 + + @property + def has_default_precision(self): + return self.precision == self._default_precision + + @property + def precision(self): + return self._context.prec + + @property + def dps(self): + return self._context.dps + + @property + def tolerance(self): + return self._tolerance + + def __init__(self, prec=None, dps=None, tol=None): + # XXX: The tolerance parameter is ignored but is kept for backward + # compatibility for now. + + context = MPContext() + + if prec is None and dps is None: + context.prec = self._default_precision + elif dps is None: + context.prec = prec + elif prec is None: + context.dps = dps + else: + raise TypeError("Cannot set both prec and dps") + + self._context = context + + self._dtype = context.mpc + self.zero = self.dtype(0) + self.one = self.dtype(1) + + # XXX: Neither of these is actually used anywhere. + self._max_denom = max(2**context.prec // 200, 99) + self._tolerance = self.one / self._max_denom + + @property + def tp(self): + # XXX: Domain treats tp as an alias of dtype. Here we need two separate + # things: dtype is a callable to make/convert instances. We use tp with + # isinstance to check if an object is an instance of the domain + # already. + return self._dtype + + def dtype(self, x, y=0): + # XXX: This is needed because mpmath does not recognise fmpz. + # It might be better to add conversion routines to mpmath and if that + # happens then this can be removed. + if isinstance(x, SYMPY_INTS): + x = int(x) + if isinstance(y, SYMPY_INTS): + y = int(y) + return self._dtype(x, y) + + def __eq__(self, other): + return isinstance(other, ComplexField) and self.precision == other.precision + + def __hash__(self): + return hash((self.__class__.__name__, self._dtype, self.precision)) + + def to_sympy(self, element): + """Convert ``element`` to SymPy number. """ + return Float(element.real, self.dps) + I*Float(element.imag, self.dps) + + def from_sympy(self, expr): + """Convert SymPy's number to ``dtype``. """ + number = expr.evalf(n=self.dps) + real, imag = number.as_real_imag() + + if real.is_Number and imag.is_Number: + return self.dtype(real, imag) + else: + raise CoercionFailed("expected complex number, got %s" % expr) + + def from_ZZ(self, element, base): + return self.dtype(element) + + def from_ZZ_gmpy(self, element, base): + return self.dtype(int(element)) + + def from_ZZ_python(self, element, base): + return self.dtype(element) + + def from_QQ(self, element, base): + return self.dtype(int(element.numerator)) / int(element.denominator) + + def from_QQ_python(self, element, base): + return self.dtype(element.numerator) / element.denominator + + def from_QQ_gmpy(self, element, base): + return self.dtype(int(element.numerator)) / int(element.denominator) + + def from_GaussianIntegerRing(self, element, base): + return self.dtype(int(element.x), int(element.y)) + + def from_GaussianRationalField(self, element, base): + x = element.x + y = element.y + return (self.dtype(int(x.numerator)) / int(x.denominator) + + self.dtype(0, int(y.numerator)) / int(y.denominator)) + + def from_AlgebraicField(self, element, base): + return self.from_sympy(base.to_sympy(element).evalf(self.dps)) + + def from_RealField(self, element, base): + return self.dtype(element) + + def from_ComplexField(self, element, base): + return self.dtype(element) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError("there is no ring associated with %s" % self) + + def get_exact(self): + """Returns an exact domain associated with ``self``. """ + return QQ_I + + def is_negative(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def is_positive(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def is_nonnegative(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def is_nonpositive(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + return self.one + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + return a*b + + def almosteq(self, a, b, tolerance=None): + """Check if ``a`` and ``b`` are almost equal. """ + return self._context.almosteq(a, b, tolerance) + + def is_square(self, a): + """Returns ``True``. Every complex number has a complex square root.""" + return True + + def exsqrt(self, a): + r"""Returns the principal complex square root of ``a``. + + Explanation + =========== + The argument of the principal square root is always within + $(-\frac{\pi}{2}, \frac{\pi}{2}]$. The square root may be + slightly inaccurate due to floating point rounding error. + """ + return a ** 0.5 + +CC = ComplexField() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/compositedomain.py b/phivenv/Lib/site-packages/sympy/polys/domains/compositedomain.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f63ba7bb86b1d69493b77bfa8c7f33652adbbf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/compositedomain.py @@ -0,0 +1,52 @@ +"""Implementation of :class:`CompositeDomain` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.polys.polyerrors import GeneratorsError + +from sympy.utilities import public + +@public +class CompositeDomain(Domain): + """Base class for composite domains, e.g. ZZ[x], ZZ(X). """ + + is_Composite = True + + gens, ngens, symbols, domain = [None]*4 + + def inject(self, *symbols): + """Inject generators into this domain. """ + if not (set(self.symbols) & set(symbols)): + return self.__class__(self.domain, self.symbols + symbols, self.order) + else: + raise GeneratorsError("common generators in %s and %s" % (self.symbols, symbols)) + + def drop(self, *symbols): + """Drop generators from this domain. """ + symset = set(symbols) + newsyms = tuple(s for s in self.symbols if s not in symset) + domain = self.domain.drop(*symbols) + if not newsyms: + return domain + else: + return self.__class__(domain, newsyms, self.order) + + def set_domain(self, domain): + """Set the ground domain of this domain. """ + return self.__class__(domain, self.symbols, self.order) + + @property + def is_Exact(self): + """Returns ``True`` if this domain is exact. """ + return self.domain.is_Exact + + def get_exact(self): + """Returns an exact version of this domain. """ + return self.set_domain(self.domain.get_exact()) + + @property + def has_CharacteristicZero(self): + return self.domain.has_CharacteristicZero + + def characteristic(self): + return self.domain.characteristic() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/domain.py b/phivenv/Lib/site-packages/sympy/polys/domains/domain.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7fc1eac6184601c199fb6724a11e92346789f1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/domain.py @@ -0,0 +1,1382 @@ +"""Implementation of :class:`Domain` class. """ + +from __future__ import annotations +from typing import Any + +from sympy.core.numbers import AlgebraicNumber +from sympy.core import Basic, sympify +from sympy.core.sorting import ordered +from sympy.external.gmpy import GROUND_TYPES +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.orderings import lex +from sympy.polys.polyerrors import UnificationFailed, CoercionFailed, DomainError +from sympy.polys.polyutils import _unify_gens, _not_a_coeff +from sympy.utilities import public +from sympy.utilities.iterables import is_sequence + + +@public +class Domain: + """Superclass for all domains in the polys domains system. + + See :ref:`polys-domainsintro` for an introductory explanation of the + domains system. + + The :py:class:`~.Domain` class is an abstract base class for all of the + concrete domain types. There are many different :py:class:`~.Domain` + subclasses each of which has an associated ``dtype`` which is a class + representing the elements of the domain. The coefficients of a + :py:class:`~.Poly` are elements of a domain which must be a subclass of + :py:class:`~.Domain`. + + Examples + ======== + + The most common example domains are the integers :ref:`ZZ` and the + rationals :ref:`QQ`. + + >>> from sympy import Poly, symbols, Domain + >>> x, y = symbols('x, y') + >>> p = Poly(x**2 + y) + >>> p + Poly(x**2 + y, x, y, domain='ZZ') + >>> p.domain + ZZ + >>> isinstance(p.domain, Domain) + True + >>> Poly(x**2 + y/2) + Poly(x**2 + 1/2*y, x, y, domain='QQ') + + The domains can be used directly in which case the domain object e.g. + (:ref:`ZZ` or :ref:`QQ`) can be used as a constructor for elements of + ``dtype``. + + >>> from sympy import ZZ, QQ + >>> ZZ(2) + 2 + >>> ZZ.dtype # doctest: +SKIP + + >>> type(ZZ(2)) # doctest: +SKIP + + >>> QQ(1, 2) + 1/2 + >>> type(QQ(1, 2)) # doctest: +SKIP + + + The corresponding domain elements can be used with the arithmetic + operations ``+,-,*,**`` and depending on the domain some combination of + ``/,//,%`` might be usable. For example in :ref:`ZZ` both ``//`` (floor + division) and ``%`` (modulo division) can be used but ``/`` (true + division) cannot. Since :ref:`QQ` is a :py:class:`~.Field` its elements + can be used with ``/`` but ``//`` and ``%`` should not be used. Some + domains have a :py:meth:`~.Domain.gcd` method. + + >>> ZZ(2) + ZZ(3) + 5 + >>> ZZ(5) // ZZ(2) + 2 + >>> ZZ(5) % ZZ(2) + 1 + >>> QQ(1, 2) / QQ(2, 3) + 3/4 + >>> ZZ.gcd(ZZ(4), ZZ(2)) + 2 + >>> QQ.gcd(QQ(2,7), QQ(5,3)) + 1/21 + >>> ZZ.is_Field + False + >>> QQ.is_Field + True + + There are also many other domains including: + + 1. :ref:`GF(p)` for finite fields of prime order. + 2. :ref:`RR` for real (floating point) numbers. + 3. :ref:`CC` for complex (floating point) numbers. + 4. :ref:`QQ(a)` for algebraic number fields. + 5. :ref:`K[x]` for polynomial rings. + 6. :ref:`K(x)` for rational function fields. + 7. :ref:`EX` for arbitrary expressions. + + Each domain is represented by a domain object and also an implementation + class (``dtype``) for the elements of the domain. For example the + :ref:`K[x]` domains are represented by a domain object which is an + instance of :py:class:`~.PolynomialRing` and the elements are always + instances of :py:class:`~.PolyElement`. The implementation class + represents particular types of mathematical expressions in a way that is + more efficient than a normal SymPy expression which is of type + :py:class:`~.Expr`. The domain methods :py:meth:`~.Domain.from_sympy` and + :py:meth:`~.Domain.to_sympy` are used to convert from :py:class:`~.Expr` + to a domain element and vice versa. + + >>> from sympy import Symbol, ZZ, Expr + >>> x = Symbol('x') + >>> K = ZZ[x] # polynomial ring domain + >>> K + ZZ[x] + >>> type(K) # class of the domain + + >>> K.dtype # doctest: +SKIP + + >>> p_expr = x**2 + 1 # Expr + >>> p_expr + x**2 + 1 + >>> type(p_expr) + + >>> isinstance(p_expr, Expr) + True + >>> p_domain = K.from_sympy(p_expr) + >>> p_domain # domain element + x**2 + 1 + >>> type(p_domain) + + >>> K.to_sympy(p_domain) == p_expr + True + + The :py:meth:`~.Domain.convert_from` method is used to convert domain + elements from one domain to another. + + >>> from sympy import ZZ, QQ + >>> ez = ZZ(2) + >>> eq = QQ.convert_from(ez, ZZ) + >>> type(ez) # doctest: +SKIP + + >>> type(eq) # doctest: +SKIP + + + Elements from different domains should not be mixed in arithmetic or other + operations: they should be converted to a common domain first. The domain + method :py:meth:`~.Domain.unify` is used to find a domain that can + represent all the elements of two given domains. + + >>> from sympy import ZZ, QQ, symbols + >>> x, y = symbols('x, y') + >>> ZZ.unify(QQ) + QQ + >>> ZZ[x].unify(QQ) + QQ[x] + >>> ZZ[x].unify(QQ[y]) + QQ[x,y] + + If a domain is a :py:class:`~.Ring` then is might have an associated + :py:class:`~.Field` and vice versa. The :py:meth:`~.Domain.get_field` and + :py:meth:`~.Domain.get_ring` methods will find or create the associated + domain. + + >>> from sympy import ZZ, QQ, Symbol + >>> x = Symbol('x') + >>> ZZ.has_assoc_Field + True + >>> ZZ.get_field() + QQ + >>> QQ.has_assoc_Ring + True + >>> QQ.get_ring() + ZZ + >>> K = QQ[x] + >>> K + QQ[x] + >>> K.get_field() + QQ(x) + + See also + ======== + + DomainElement: abstract base class for domain elements + construct_domain: construct a minimal domain for some expressions + + """ + + dtype: type | None = None + """The type (class) of the elements of this :py:class:`~.Domain`: + + >>> from sympy import ZZ, QQ, Symbol + >>> ZZ.dtype + + >>> z = ZZ(2) + >>> z + 2 + >>> type(z) + + >>> type(z) == ZZ.dtype + True + + Every domain has an associated **dtype** ("datatype") which is the + class of the associated domain elements. + + See also + ======== + + of_type + """ + + zero: Any = None + """The zero element of the :py:class:`~.Domain`: + + >>> from sympy import QQ + >>> QQ.zero + 0 + >>> QQ.of_type(QQ.zero) + True + + See also + ======== + + of_type + one + """ + + one: Any = None + """The one element of the :py:class:`~.Domain`: + + >>> from sympy import QQ + >>> QQ.one + 1 + >>> QQ.of_type(QQ.one) + True + + See also + ======== + + of_type + zero + """ + + is_Ring = False + """Boolean flag indicating if the domain is a :py:class:`~.Ring`. + + >>> from sympy import ZZ + >>> ZZ.is_Ring + True + + Basically every :py:class:`~.Domain` represents a ring so this flag is + not that useful. + + See also + ======== + + is_PID + is_Field + get_ring + has_assoc_Ring + """ + + is_Field = False + """Boolean flag indicating if the domain is a :py:class:`~.Field`. + + >>> from sympy import ZZ, QQ + >>> ZZ.is_Field + False + >>> QQ.is_Field + True + + See also + ======== + + is_PID + is_Ring + get_field + has_assoc_Field + """ + + has_assoc_Ring = False + """Boolean flag indicating if the domain has an associated + :py:class:`~.Ring`. + + >>> from sympy import QQ + >>> QQ.has_assoc_Ring + True + >>> QQ.get_ring() + ZZ + + See also + ======== + + is_Field + get_ring + """ + + has_assoc_Field = False + """Boolean flag indicating if the domain has an associated + :py:class:`~.Field`. + + >>> from sympy import ZZ + >>> ZZ.has_assoc_Field + True + >>> ZZ.get_field() + QQ + + See also + ======== + + is_Field + get_field + """ + + is_FiniteField = is_FF = False + is_IntegerRing = is_ZZ = False + is_RationalField = is_QQ = False + is_GaussianRing = is_ZZ_I = False + is_GaussianField = is_QQ_I = False + is_RealField = is_RR = False + is_ComplexField = is_CC = False + is_AlgebraicField = is_Algebraic = False + is_PolynomialRing = is_Poly = False + is_FractionField = is_Frac = False + is_SymbolicDomain = is_EX = False + is_SymbolicRawDomain = is_EXRAW = False + is_FiniteExtension = False + + is_Exact = True + is_Numerical = False + + is_Simple = False + is_Composite = False + + is_PID = False + """Boolean flag indicating if the domain is a `principal ideal domain`_. + + >>> from sympy import ZZ + >>> ZZ.has_assoc_Field + True + >>> ZZ.get_field() + QQ + + .. _principal ideal domain: https://en.wikipedia.org/wiki/Principal_ideal_domain + + See also + ======== + + is_Field + get_field + """ + + has_CharacteristicZero = False + + rep: str | None = None + alias: str | None = None + + def __init__(self): + raise NotImplementedError + + def __str__(self): + return self.rep + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype)) + + def new(self, *args): + return self.dtype(*args) + + @property + def tp(self): + """Alias for :py:attr:`~.Domain.dtype`""" + return self.dtype + + def __call__(self, *args): + """Construct an element of ``self`` domain from ``args``. """ + return self.new(*args) + + def normal(self, *args): + return self.dtype(*args) + + def convert_from(self, element, base): + """Convert ``element`` to ``self.dtype`` given the base domain. """ + if base.alias is not None: + method = "from_" + base.alias + else: + method = "from_" + base.__class__.__name__ + + _convert = getattr(self, method) + + if _convert is not None: + result = _convert(element, base) + + if result is not None: + return result + + raise CoercionFailed("Cannot convert %s of type %s from %s to %s" % (element, type(element), base, self)) + + def convert(self, element, base=None): + """Convert ``element`` to ``self.dtype``. """ + + if base is not None: + if _not_a_coeff(element): + raise CoercionFailed('%s is not in any domain' % element) + return self.convert_from(element, base) + + if self.of_type(element): + return element + + if _not_a_coeff(element): + raise CoercionFailed('%s is not in any domain' % element) + + from sympy.polys.domains import ZZ, QQ, RealField, ComplexField + + if ZZ.of_type(element): + return self.convert_from(element, ZZ) + + if isinstance(element, int): + return self.convert_from(ZZ(element), ZZ) + + if GROUND_TYPES != 'python': + if isinstance(element, ZZ.tp): + return self.convert_from(element, ZZ) + if isinstance(element, QQ.tp): + return self.convert_from(element, QQ) + + if isinstance(element, float): + parent = RealField() + return self.convert_from(parent(element), parent) + + if isinstance(element, complex): + parent = ComplexField() + return self.convert_from(parent(element), parent) + + if type(element).__name__ == 'mpf': + parent = RealField() + return self.convert_from(parent(element), parent) + + if type(element).__name__ == 'mpc': + parent = ComplexField() + return self.convert_from(parent(element), parent) + + if isinstance(element, DomainElement): + return self.convert_from(element, element.parent()) + + # TODO: implement this in from_ methods + if self.is_Numerical and getattr(element, 'is_ground', False): + return self.convert(element.LC()) + + if isinstance(element, Basic): + try: + return self.from_sympy(element) + except (TypeError, ValueError): + pass + else: # TODO: remove this branch + if not is_sequence(element): + try: + element = sympify(element, strict=True) + if isinstance(element, Basic): + return self.from_sympy(element) + except (TypeError, ValueError): + pass + + raise CoercionFailed("Cannot convert %s of type %s to %s" % (element, type(element), self)) + + def of_type(self, element): + """Check if ``a`` is of type ``dtype``. """ + return isinstance(element, self.tp) + + def __contains__(self, a): + """Check if ``a`` belongs to this domain. """ + try: + if _not_a_coeff(a): + raise CoercionFailed + self.convert(a) # this might raise, too + except CoercionFailed: + return False + + return True + + def to_sympy(self, a): + """Convert domain element *a* to a SymPy expression (Expr). + + Explanation + =========== + + Convert a :py:class:`~.Domain` element *a* to :py:class:`~.Expr`. Most + public SymPy functions work with objects of type :py:class:`~.Expr`. + The elements of a :py:class:`~.Domain` have a different internal + representation. It is not possible to mix domain elements with + :py:class:`~.Expr` so each domain has :py:meth:`~.Domain.to_sympy` and + :py:meth:`~.Domain.from_sympy` methods to convert its domain elements + to and from :py:class:`~.Expr`. + + Parameters + ========== + + a: domain element + An element of this :py:class:`~.Domain`. + + Returns + ======= + + expr: Expr + A normal SymPy expression of type :py:class:`~.Expr`. + + Examples + ======== + + Construct an element of the :ref:`QQ` domain and then convert it to + :py:class:`~.Expr`. + + >>> from sympy import QQ, Expr + >>> q_domain = QQ(2) + >>> q_domain + 2 + >>> q_expr = QQ.to_sympy(q_domain) + >>> q_expr + 2 + + Although the printed forms look similar these objects are not of the + same type. + + >>> isinstance(q_domain, Expr) + False + >>> isinstance(q_expr, Expr) + True + + Construct an element of :ref:`K[x]` and convert to + :py:class:`~.Expr`. + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> K = QQ[x] + >>> x_domain = K.gens[0] # generator x as a domain element + >>> p_domain = x_domain**2/3 + 1 + >>> p_domain + 1/3*x**2 + 1 + >>> p_expr = K.to_sympy(p_domain) + >>> p_expr + x**2/3 + 1 + + The :py:meth:`~.Domain.from_sympy` method is used for the opposite + conversion from a normal SymPy expression to a domain element. + + >>> p_domain == p_expr + False + >>> K.from_sympy(p_expr) == p_domain + True + >>> K.to_sympy(p_domain) == p_expr + True + >>> K.from_sympy(K.to_sympy(p_domain)) == p_domain + True + >>> K.to_sympy(K.from_sympy(p_expr)) == p_expr + True + + The :py:meth:`~.Domain.from_sympy` method makes it easier to construct + domain elements interactively. + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> K = QQ[x] + >>> K.from_sympy(x**2/3 + 1) + 1/3*x**2 + 1 + + See also + ======== + + from_sympy + convert_from + """ + raise NotImplementedError + + def from_sympy(self, a): + """Convert a SymPy expression to an element of this domain. + + Explanation + =========== + + See :py:meth:`~.Domain.to_sympy` for explanation and examples. + + Parameters + ========== + + expr: Expr + A normal SymPy expression of type :py:class:`~.Expr`. + + Returns + ======= + + a: domain element + An element of this :py:class:`~.Domain`. + + See also + ======== + + to_sympy + convert_from + """ + raise NotImplementedError + + def sum(self, args): + return sum(args, start=self.zero) + + def from_FF(K1, a, K0): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return None + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return None + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return None + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return None + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to ``dtype``. """ + return None + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return None + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return None + + def from_RealField(K1, a, K0): + """Convert a real element object to ``dtype``. """ + return None + + def from_ComplexField(K1, a, K0): + """Convert a complex element to ``dtype``. """ + return None + + def from_AlgebraicField(K1, a, K0): + """Convert an algebraic number to ``dtype``. """ + return None + + def from_PolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + if a.is_ground: + return K1.convert(a.LC, K0.dom) + + def from_FractionField(K1, a, K0): + """Convert a rational function to ``dtype``. """ + return None + + def from_MonogenicFiniteExtension(K1, a, K0): + """Convert an ``ExtensionElement`` to ``dtype``. """ + return K1.convert_from(a.rep, K0.ring) + + def from_ExpressionDomain(K1, a, K0): + """Convert a ``EX`` object to ``dtype``. """ + return K1.from_sympy(a.ex) + + def from_ExpressionRawDomain(K1, a, K0): + """Convert a ``EX`` object to ``dtype``. """ + return K1.from_sympy(a) + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + if a.degree() <= 0: + return K1.convert(a.LC(), K0.dom) + + def from_GeneralizedPolynomialRing(K1, a, K0): + return K1.from_FractionField(a, K0) + + def unify_with_symbols(K0, K1, symbols): + if (K0.is_Composite and (set(K0.symbols) & set(symbols))) or (K1.is_Composite and (set(K1.symbols) & set(symbols))): + raise UnificationFailed("Cannot unify %s with %s, given %s generators" % (K0, K1, tuple(symbols))) + + return K0.unify(K1) + + def unify_composite(K0, K1): + """Unify two domains where at least one is composite.""" + K0_ground = K0.dom if K0.is_Composite else K0 + K1_ground = K1.dom if K1.is_Composite else K1 + + K0_symbols = K0.symbols if K0.is_Composite else () + K1_symbols = K1.symbols if K1.is_Composite else () + + domain = K0_ground.unify(K1_ground) + symbols = _unify_gens(K0_symbols, K1_symbols) + order = K0.order if K0.is_Composite else K1.order + + # E.g. ZZ[x].unify(QQ.frac_field(x)) -> ZZ.frac_field(x) + if ((K0.is_FractionField and K1.is_PolynomialRing or + K1.is_FractionField and K0.is_PolynomialRing) and + (not K0_ground.is_Field or not K1_ground.is_Field) and domain.is_Field + and domain.has_assoc_Ring): + domain = domain.get_ring() + + if K0.is_Composite and (not K1.is_Composite or K0.is_FractionField or K1.is_PolynomialRing): + cls = K0.__class__ + else: + cls = K1.__class__ + + # Here cls might be PolynomialRing, FractionField, GlobalPolynomialRing + # (dense/old Polynomialring) or dense/old FractionField. + + from sympy.polys.domains.old_polynomialring import GlobalPolynomialRing + if cls == GlobalPolynomialRing: + return cls(domain, symbols) + + return cls(domain, symbols, order) + + def unify(K0, K1, symbols=None): + """ + Construct a minimal domain that contains elements of ``K0`` and ``K1``. + + Known domains (from smallest to largest): + + - ``GF(p)`` + - ``ZZ`` + - ``QQ`` + - ``RR(prec, tol)`` + - ``CC(prec, tol)`` + - ``ALG(a, b, c)`` + - ``K[x, y, z]`` + - ``K(x, y, z)`` + - ``EX`` + + """ + if symbols is not None: + return K0.unify_with_symbols(K1, symbols) + + if K0 == K1: + return K0 + + if not (K0.has_CharacteristicZero and K1.has_CharacteristicZero): + # Reject unification of domains with different characteristics. + if K0.characteristic() != K1.characteristic(): + raise UnificationFailed("Cannot unify %s with %s" % (K0, K1)) + + # We do not get here if K0 == K1. The two domains have the same + # characteristic but are unequal so at least one is composite and + # we are unifying something like GF(3).unify(GF(3)[x]). + return K0.unify_composite(K1) + + # From here we know both domains have characteristic zero and it can be + # acceptable to fall back on EX. + + if K0.is_EXRAW: + return K0 + if K1.is_EXRAW: + return K1 + + if K0.is_EX: + return K0 + if K1.is_EX: + return K1 + + if K0.is_FiniteExtension or K1.is_FiniteExtension: + if K1.is_FiniteExtension: + K0, K1 = K1, K0 + if K1.is_FiniteExtension: + # Unifying two extensions. + # Try to ensure that K0.unify(K1) == K1.unify(K0) + if list(ordered([K0.modulus, K1.modulus]))[1] == K0.modulus: + K0, K1 = K1, K0 + return K1.set_domain(K0) + else: + # Drop the generator from other and unify with the base domain + K1 = K1.drop(K0.symbol) + K1 = K0.domain.unify(K1) + return K0.set_domain(K1) + + if K0.is_Composite or K1.is_Composite: + return K0.unify_composite(K1) + + if K1.is_ComplexField: + K0, K1 = K1, K0 + if K0.is_ComplexField: + if K1.is_ComplexField or K1.is_RealField: + if K0.precision >= K1.precision: + return K0 + else: + from sympy.polys.domains.complexfield import ComplexField + return ComplexField(prec=K1.precision) + else: + return K0 + + if K1.is_RealField: + K0, K1 = K1, K0 + if K0.is_RealField: + if K1.is_RealField: + if K0.precision >= K1.precision: + return K0 + else: + return K1 + elif K1.is_GaussianRing or K1.is_GaussianField: + from sympy.polys.domains.complexfield import ComplexField + return ComplexField(prec=K0.precision) + else: + return K0 + + if K1.is_AlgebraicField: + K0, K1 = K1, K0 + if K0.is_AlgebraicField: + if K1.is_GaussianRing: + K1 = K1.get_field() + if K1.is_GaussianField: + K1 = K1.as_AlgebraicField() + if K1.is_AlgebraicField: + return K0.__class__(K0.dom.unify(K1.dom), *_unify_gens(K0.orig_ext, K1.orig_ext)) + else: + return K0 + + if K0.is_GaussianField: + return K0 + if K1.is_GaussianField: + return K1 + + if K0.is_GaussianRing: + if K1.is_RationalField: + K0 = K0.get_field() + return K0 + if K1.is_GaussianRing: + if K0.is_RationalField: + K1 = K1.get_field() + return K1 + + if K0.is_RationalField: + return K0 + if K1.is_RationalField: + return K1 + + if K0.is_IntegerRing: + return K0 + if K1.is_IntegerRing: + return K1 + + from sympy.polys.domains import EX + return EX + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + # XXX: Remove this. + return isinstance(other, Domain) and self.dtype == other.dtype + + def __ne__(self, other): + """Returns ``False`` if two domains are equivalent. """ + return not self == other + + def map(self, seq): + """Rersively apply ``self`` to all elements of ``seq``. """ + result = [] + + for elt in seq: + if isinstance(elt, list): + result.append(self.map(elt)) + else: + result.append(self(elt)) + + return result + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError('there is no ring associated with %s' % self) + + def get_field(self): + """Returns a field associated with ``self``. """ + raise DomainError('there is no field associated with %s' % self) + + def get_exact(self): + """Returns an exact domain associated with ``self``. """ + return self + + def __getitem__(self, symbols): + """The mathematical way to make a polynomial ring. """ + if hasattr(symbols, '__iter__'): + return self.poly_ring(*symbols) + else: + return self.poly_ring(symbols) + + def poly_ring(self, *symbols, order=lex): + """Returns a polynomial ring, i.e. `K[X]`. """ + from sympy.polys.domains.polynomialring import PolynomialRing + return PolynomialRing(self, symbols, order) + + def frac_field(self, *symbols, order=lex): + """Returns a fraction field, i.e. `K(X)`. """ + from sympy.polys.domains.fractionfield import FractionField + return FractionField(self, symbols, order) + + def old_poly_ring(self, *symbols, **kwargs): + """Returns a polynomial ring, i.e. `K[X]`. """ + from sympy.polys.domains.old_polynomialring import PolynomialRing + return PolynomialRing(self, *symbols, **kwargs) + + def old_frac_field(self, *symbols, **kwargs): + """Returns a fraction field, i.e. `K(X)`. """ + from sympy.polys.domains.old_fractionfield import FractionField + return FractionField(self, *symbols, **kwargs) + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `K(\alpha, \ldots)`. """ + raise DomainError("Cannot create algebraic field over %s" % self) + + def alg_field_from_poly(self, poly, alias=None, root_index=-1): + r""" + Convenience method to construct an algebraic extension on a root of a + polynomial, chosen by root index. + + Parameters + ========== + + poly : :py:class:`~.Poly` + The polynomial whose root generates the extension. + alias : str, optional (default=None) + Symbol name for the generator of the extension. + E.g. "alpha" or "theta". + root_index : int, optional (default=-1) + Specifies which root of the polynomial is desired. The ordering is + as defined by the :py:class:`~.ComplexRootOf` class. The default of + ``-1`` selects the most natural choice in the common cases of + quadratic and cyclotomic fields (the square root on the positive + real or imaginary axis, resp. $\mathrm{e}^{2\pi i/n}$). + + Examples + ======== + + >>> from sympy import QQ, Poly + >>> from sympy.abc import x + >>> f = Poly(x**2 - 2) + >>> K = QQ.alg_field_from_poly(f) + >>> K.ext.minpoly == f + True + >>> g = Poly(8*x**3 - 6*x - 1) + >>> L = QQ.alg_field_from_poly(g, "alpha") + >>> L.ext.minpoly == g + True + >>> L.to_sympy(L([1, 1, 1])) + alpha**2 + alpha + 1 + + """ + from sympy.polys.rootoftools import CRootOf + root = CRootOf(poly, root_index) + alpha = AlgebraicNumber(root, alias=alias) + return self.algebraic_field(alpha, alias=alias) + + def cyclotomic_field(self, n, ss=False, alias="zeta", gen=None, root_index=-1): + r""" + Convenience method to construct a cyclotomic field. + + Parameters + ========== + + n : int + Construct the nth cyclotomic field. + ss : boolean, optional (default=False) + If True, append *n* as a subscript on the alias string. + alias : str, optional (default="zeta") + Symbol name for the generator. + gen : :py:class:`~.Symbol`, optional (default=None) + Desired variable for the cyclotomic polynomial that defines the + field. If ``None``, a dummy variable will be used. + root_index : int, optional (default=-1) + Specifies which root of the polynomial is desired. The ordering is + as defined by the :py:class:`~.ComplexRootOf` class. The default of + ``-1`` selects the root $\mathrm{e}^{2\pi i/n}$. + + Examples + ======== + + >>> from sympy import QQ, latex + >>> K = QQ.cyclotomic_field(5) + >>> K.to_sympy(K([-1, 1])) + 1 - zeta + >>> L = QQ.cyclotomic_field(7, True) + >>> a = L.to_sympy(L([-1, 1])) + >>> print(a) + 1 - zeta7 + >>> print(latex(a)) + 1 - \zeta_{7} + + """ + from sympy.polys.specialpolys import cyclotomic_poly + if ss: + alias += str(n) + return self.alg_field_from_poly(cyclotomic_poly(n, gen), alias=alias, + root_index=root_index) + + def inject(self, *symbols): + """Inject generators into this domain. """ + raise NotImplementedError + + def drop(self, *symbols): + """Drop generators from this domain. """ + if self.is_Simple: + return self + raise NotImplementedError # pragma: no cover + + def is_zero(self, a): + """Returns True if ``a`` is zero. """ + return not a + + def is_one(self, a): + """Returns True if ``a`` is one. """ + return a == self.one + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return a > 0 + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return a < 0 + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return a <= 0 + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return a >= 0 + + def canonical_unit(self, a): + if self.is_negative(a): + return -self.one + else: + return self.one + + def abs(self, a): + """Absolute value of ``a``, implies ``__abs__``. """ + return abs(a) + + def neg(self, a): + """Returns ``a`` negated, implies ``__neg__``. """ + return -a + + def pos(self, a): + """Returns ``a`` positive, implies ``__pos__``. """ + return +a + + def add(self, a, b): + """Sum of ``a`` and ``b``, implies ``__add__``. """ + return a + b + + def sub(self, a, b): + """Difference of ``a`` and ``b``, implies ``__sub__``. """ + return a - b + + def mul(self, a, b): + """Product of ``a`` and ``b``, implies ``__mul__``. """ + return a * b + + def pow(self, a, b): + """Raise ``a`` to power ``b``, implies ``__pow__``. """ + return a ** b + + def exquo(self, a, b): + """Exact quotient of *a* and *b*. Analogue of ``a / b``. + + Explanation + =========== + + This is essentially the same as ``a / b`` except that an error will be + raised if the division is inexact (if there is any remainder) and the + result will always be a domain element. When working in a + :py:class:`~.Domain` that is not a :py:class:`~.Field` (e.g. :ref:`ZZ` + or :ref:`K[x]`) ``exquo`` should be used instead of ``/``. + + The key invariant is that if ``q = K.exquo(a, b)`` (and ``exquo`` does + not raise an exception) then ``a == b*q``. + + Examples + ======== + + We can use ``K.exquo`` instead of ``/`` for exact division. + + >>> from sympy import ZZ + >>> ZZ.exquo(ZZ(4), ZZ(2)) + 2 + >>> ZZ.exquo(ZZ(5), ZZ(2)) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2 does not divide 5 in ZZ + + Over a :py:class:`~.Field` such as :ref:`QQ`, division (with nonzero + divisor) is always exact so in that case ``/`` can be used instead of + :py:meth:`~.Domain.exquo`. + + >>> from sympy import QQ + >>> QQ.exquo(QQ(5), QQ(2)) + 5/2 + >>> QQ(5) / QQ(2) + 5/2 + + Parameters + ========== + + a: domain element + The dividend + b: domain element + The divisor + + Returns + ======= + + q: domain element + The exact quotient + + Raises + ====== + + ExactQuotientFailed: if exact division is not possible. + ZeroDivisionError: when the divisor is zero. + + See also + ======== + + quo: Analogue of ``a // b`` + rem: Analogue of ``a % b`` + div: Analogue of ``divmod(a, b)`` + + Notes + ===== + + Since the default :py:attr:`~.Domain.dtype` for :ref:`ZZ` is ``int`` + (or ``mpz``) division as ``a / b`` should not be used as it would give + a ``float`` which is not a domain element. + + >>> ZZ(4) / ZZ(2) # doctest: +SKIP + 2.0 + >>> ZZ(5) / ZZ(2) # doctest: +SKIP + 2.5 + + On the other hand with `SYMPY_GROUND_TYPES=flint` elements of :ref:`ZZ` + are ``flint.fmpz`` and division would raise an exception: + + >>> ZZ(4) / ZZ(2) # doctest: +SKIP + Traceback (most recent call last): + ... + TypeError: unsupported operand type(s) for /: 'fmpz' and 'fmpz' + + Using ``/`` with :ref:`ZZ` will lead to incorrect results so + :py:meth:`~.Domain.exquo` should be used instead. + + """ + raise NotImplementedError + + def quo(self, a, b): + """Quotient of *a* and *b*. Analogue of ``a // b``. + + ``K.quo(a, b)`` is equivalent to ``K.div(a, b)[0]``. See + :py:meth:`~.Domain.div` for more explanation. + + See also + ======== + + rem: Analogue of ``a % b`` + div: Analogue of ``divmod(a, b)`` + exquo: Analogue of ``a / b`` + """ + raise NotImplementedError + + def rem(self, a, b): + """Modulo division of *a* and *b*. Analogue of ``a % b``. + + ``K.rem(a, b)`` is equivalent to ``K.div(a, b)[1]``. See + :py:meth:`~.Domain.div` for more explanation. + + See also + ======== + + quo: Analogue of ``a // b`` + div: Analogue of ``divmod(a, b)`` + exquo: Analogue of ``a / b`` + """ + raise NotImplementedError + + def div(self, a, b): + """Quotient and remainder for *a* and *b*. Analogue of ``divmod(a, b)`` + + Explanation + =========== + + This is essentially the same as ``divmod(a, b)`` except that is more + consistent when working over some :py:class:`~.Field` domains such as + :ref:`QQ`. When working over an arbitrary :py:class:`~.Domain` the + :py:meth:`~.Domain.div` method should be used instead of ``divmod``. + + The key invariant is that if ``q, r = K.div(a, b)`` then + ``a == b*q + r``. + + The result of ``K.div(a, b)`` is the same as the tuple + ``(K.quo(a, b), K.rem(a, b))`` except that if both quotient and + remainder are needed then it is more efficient to use + :py:meth:`~.Domain.div`. + + Examples + ======== + + We can use ``K.div`` instead of ``divmod`` for floor division and + remainder. + + >>> from sympy import ZZ, QQ + >>> ZZ.div(ZZ(5), ZZ(2)) + (2, 1) + + If ``K`` is a :py:class:`~.Field` then the division is always exact + with a remainder of :py:attr:`~.Domain.zero`. + + >>> QQ.div(QQ(5), QQ(2)) + (5/2, 0) + + Parameters + ========== + + a: domain element + The dividend + b: domain element + The divisor + + Returns + ======= + + (q, r): tuple of domain elements + The quotient and remainder + + Raises + ====== + + ZeroDivisionError: when the divisor is zero. + + See also + ======== + + quo: Analogue of ``a // b`` + rem: Analogue of ``a % b`` + exquo: Analogue of ``a / b`` + + Notes + ===== + + If ``gmpy`` is installed then the ``gmpy.mpq`` type will be used as + the :py:attr:`~.Domain.dtype` for :ref:`QQ`. The ``gmpy.mpq`` type + defines ``divmod`` in a way that is undesirable so + :py:meth:`~.Domain.div` should be used instead of ``divmod``. + + >>> a = QQ(1) + >>> b = QQ(3, 2) + >>> a # doctest: +SKIP + mpq(1,1) + >>> b # doctest: +SKIP + mpq(3,2) + >>> divmod(a, b) # doctest: +SKIP + (mpz(0), mpq(1,1)) + >>> QQ.div(a, b) # doctest: +SKIP + (mpq(2,3), mpq(0,1)) + + Using ``//`` or ``%`` with :ref:`QQ` will lead to incorrect results so + :py:meth:`~.Domain.div` should be used instead. + + """ + raise NotImplementedError + + def invert(self, a, b): + """Returns inversion of ``a mod b``, implies something. """ + raise NotImplementedError + + def revert(self, a): + """Returns ``a**(-1)`` if possible. """ + raise NotImplementedError + + def numer(self, a): + """Returns numerator of ``a``. """ + raise NotImplementedError + + def denom(self, a): + """Returns denominator of ``a``. """ + raise NotImplementedError + + def half_gcdex(self, a, b): + """Half extended GCD of ``a`` and ``b``. """ + s, t, h = self.gcdex(a, b) + return s, h + + def gcdex(self, a, b): + """Extended GCD of ``a`` and ``b``. """ + raise NotImplementedError + + def cofactors(self, a, b): + """Returns GCD and cofactors of ``a`` and ``b``. """ + gcd = self.gcd(a, b) + cfa = self.quo(a, gcd) + cfb = self.quo(b, gcd) + return gcd, cfa, cfb + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + raise NotImplementedError + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + raise NotImplementedError + + def log(self, a, b): + """Returns b-base logarithm of ``a``. """ + raise NotImplementedError + + def sqrt(self, a): + """Returns a (possibly inexact) square root of ``a``. + + Explanation + =========== + There is no universal definition of "inexact square root" for all + domains. It is not recommended to implement this method for domains + other then :ref:`ZZ`. + + See also + ======== + exsqrt + """ + raise NotImplementedError + + def is_square(self, a): + """Returns whether ``a`` is a square in the domain. + + Explanation + =========== + Returns ``True`` if there is an element ``b`` in the domain such that + ``b * b == a``, otherwise returns ``False``. For inexact domains like + :ref:`RR` and :ref:`CC`, a tiny difference in this equality can be + tolerated. + + See also + ======== + exsqrt + """ + raise NotImplementedError + + def exsqrt(self, a): + """Principal square root of a within the domain if ``a`` is square. + + Explanation + =========== + The implementation of this method should return an element ``b`` in the + domain such that ``b * b == a``, or ``None`` if there is no such ``b``. + For inexact domains like :ref:`RR` and :ref:`CC`, a tiny difference in + this equality can be tolerated. The choice of a "principal" square root + should follow a consistent rule whenever possible. + + See also + ======== + sqrt, is_square + """ + raise NotImplementedError + + def evalf(self, a, prec=None, **options): + """Returns numerical approximation of ``a``. """ + return self.to_sympy(a).evalf(prec, **options) + + n = evalf + + def real(self, a): + return a + + def imag(self, a): + return self.zero + + def almosteq(self, a, b, tolerance=None): + """Check if ``a`` and ``b`` are almost equal. """ + return a == b + + def characteristic(self): + """Return the characteristic of this domain. """ + raise NotImplementedError('characteristic()') + + +__all__ = ['Domain'] diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/domainelement.py b/phivenv/Lib/site-packages/sympy/polys/domains/domainelement.py new file mode 100644 index 0000000000000000000000000000000000000000..b1033e86a7edcbffa633efd65ca7ced48f3b1f1a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/domainelement.py @@ -0,0 +1,38 @@ +"""Trait for implementing domain elements. """ + + +from sympy.utilities import public + +@public +class DomainElement: + """ + Represents an element of a domain. + + Mix in this trait into a class whose instances should be recognized as + elements of a domain. Method ``parent()`` gives that domain. + """ + + __slots__ = () + + def parent(self): + """Get the domain associated with ``self`` + + Examples + ======== + + >>> from sympy import ZZ, symbols + >>> x, y = symbols('x, y') + >>> K = ZZ[x,y] + >>> p = K(x)**2 + K(y)**2 + >>> p + x**2 + y**2 + >>> p.parent() + ZZ[x,y] + + Notes + ===== + + This is used by :py:meth:`~.Domain.convert` to identify the domain + associated with a domain element. + """ + raise NotImplementedError("abstract method") diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/expressiondomain.py b/phivenv/Lib/site-packages/sympy/polys/domains/expressiondomain.py new file mode 100644 index 0000000000000000000000000000000000000000..26cd5aa5bf34985f885093be227df6aa9b35d36c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/expressiondomain.py @@ -0,0 +1,278 @@ +"""Implementation of :class:`ExpressionDomain` class. """ + + +from sympy.core import sympify, SympifyError +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyutils import PicklableWithSlots +from sympy.utilities import public + +eflags = {"deep": False, "mul": True, "power_exp": False, "power_base": False, + "basic": False, "multinomial": False, "log": False} + +@public +class ExpressionDomain(Field, CharacteristicZero, SimpleDomain): + """A class for arbitrary expressions. """ + + is_SymbolicDomain = is_EX = True + + class Expression(DomainElement, PicklableWithSlots): + """An arbitrary expression. """ + + __slots__ = ('ex',) + + def __init__(self, ex): + if not isinstance(ex, self.__class__): + self.ex = sympify(ex) + else: + self.ex = ex.ex + + def __repr__(f): + return 'EX(%s)' % repr(f.ex) + + def __str__(f): + return 'EX(%s)' % str(f.ex) + + def __hash__(self): + return hash((self.__class__.__name__, self.ex)) + + def parent(self): + return EX + + def as_expr(f): + return f.ex + + def numer(f): + return f.__class__(f.ex.as_numer_denom()[0]) + + def denom(f): + return f.__class__(f.ex.as_numer_denom()[1]) + + def simplify(f, ex): + return f.__class__(ex.cancel().expand(**eflags)) + + def __abs__(f): + return f.__class__(abs(f.ex)) + + def __neg__(f): + return f.__class__(-f.ex) + + def _to_ex(f, g): + try: + return f.__class__(g) + except SympifyError: + return None + + def __lt__(f, g): + return f.ex.sort_key() < g.ex.sort_key() + + def __add__(f, g): + g = f._to_ex(g) + + if g is None: + return NotImplemented + elif g == EX.zero: + return f + elif f == EX.zero: + return g + else: + return f.simplify(f.ex + g.ex) + + def __radd__(f, g): + return f.simplify(f.__class__(g).ex + f.ex) + + def __sub__(f, g): + g = f._to_ex(g) + + if g is None: + return NotImplemented + elif g == EX.zero: + return f + elif f == EX.zero: + return -g + else: + return f.simplify(f.ex - g.ex) + + def __rsub__(f, g): + return f.simplify(f.__class__(g).ex - f.ex) + + def __mul__(f, g): + g = f._to_ex(g) + + if g is None: + return NotImplemented + + if EX.zero in (f, g): + return EX.zero + elif f.ex.is_Number and g.ex.is_Number: + return f.__class__(f.ex*g.ex) + + return f.simplify(f.ex*g.ex) + + def __rmul__(f, g): + return f.simplify(f.__class__(g).ex*f.ex) + + def __pow__(f, n): + n = f._to_ex(n) + + if n is not None: + return f.simplify(f.ex**n.ex) + else: + return NotImplemented + + def __truediv__(f, g): + g = f._to_ex(g) + + if g is not None: + return f.simplify(f.ex/g.ex) + else: + return NotImplemented + + def __rtruediv__(f, g): + return f.simplify(f.__class__(g).ex/f.ex) + + def __eq__(f, g): + return f.ex == f.__class__(g).ex + + def __ne__(f, g): + return not f == g + + def __bool__(f): + return not f.ex.is_zero + + def gcd(f, g): + from sympy.polys import gcd + return f.__class__(gcd(f.ex, f.__class__(g).ex)) + + def lcm(f, g): + from sympy.polys import lcm + return f.__class__(lcm(f.ex, f.__class__(g).ex)) + + dtype = Expression + + zero = Expression(0) + one = Expression(1) + + rep = 'EX' + + has_assoc_Ring = False + has_assoc_Field = True + + def __init__(self): + pass + + def __eq__(self, other): + if isinstance(other, ExpressionDomain): + return True + else: + return NotImplemented + + def __hash__(self): + return hash("EX") + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return a.as_expr() + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + return self.dtype(a) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ``GaussianRational`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianRational`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_AlgebraicField(K1, a, K0): + """Convert an ``ANP`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ComplexField(K1, a, K0): + """Convert a mpmath ``mpc`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_PolynomialRing(K1, a, K0): + """Convert a ``DMP`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_FractionField(K1, a, K0): + """Convert a ``DMF`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ExpressionDomain(K1, a, K0): + """Convert a ``EX`` object to ``dtype``. """ + return a + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self # XXX: EX is not a ring but we don't have much choice here. + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return a.ex.as_coeff_mul()[0].is_positive + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return a.ex.could_extract_minus_sign() + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return a.ex.as_coeff_mul()[0].is_nonpositive + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return a.ex.as_coeff_mul()[0].is_nonnegative + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numer() + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denom() + + def gcd(self, a, b): + return self(1) + + def lcm(self, a, b): + return a.lcm(b) + + +EX = ExpressionDomain() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/expressionrawdomain.py b/phivenv/Lib/site-packages/sympy/polys/domains/expressionrawdomain.py new file mode 100644 index 0000000000000000000000000000000000000000..9811ca26c965197a13f56ab8266ad744e4571560 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/expressionrawdomain.py @@ -0,0 +1,57 @@ +"""Implementation of :class:`ExpressionRawDomain` class. """ + + +from sympy.core import Expr, S, sympify, Add +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + + +@public +class ExpressionRawDomain(Field, CharacteristicZero, SimpleDomain): + """A class for arbitrary expressions but without automatic simplification. """ + + is_SymbolicRawDomain = is_EXRAW = True + + dtype = Expr + + zero = S.Zero + one = S.One + + rep = 'EXRAW' + + has_assoc_Ring = False + has_assoc_Field = True + + def __init__(self): + pass + + @classmethod + def new(self, a): + return sympify(a) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return a + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + if not isinstance(a, Expr): + raise CoercionFailed(f"Expecting an Expr instance but found: {type(a).__name__}") + return a + + def convert_from(self, a, K): + """Convert a domain element from another domain to EXRAW""" + return K.to_sympy(a) + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def sum(self, items): + return Add(*items) + + +EXRAW = ExpressionRawDomain() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/field.py b/phivenv/Lib/site-packages/sympy/polys/domains/field.py new file mode 100644 index 0000000000000000000000000000000000000000..a6370294365a38dee1b2eda9942a66aeef8fdae9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/field.py @@ -0,0 +1,118 @@ +"""Implementation of :class:`Field` class. """ + + +from sympy.polys.domains.ring import Ring +from sympy.polys.polyerrors import NotReversible, DomainError +from sympy.utilities import public + +@public +class Field(Ring): + """Represents a field domain. """ + + is_Field = True + is_PID = True + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError('there is no ring associated with %s' % self) + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return a / b + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return a / b + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies nothing. """ + return self.zero + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__truediv__``. """ + return a / b, self.zero + + def gcd(self, a, b): + """ + Returns GCD of ``a`` and ``b``. + + This definition of GCD over fields allows to clear denominators + in `primitive()`. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy import S, gcd, primitive + >>> from sympy.abc import x + + >>> QQ.gcd(QQ(2, 3), QQ(4, 9)) + 2/9 + >>> gcd(S(2)/3, S(4)/9) + 2/9 + >>> primitive(2*x/3 + S(4)/9) + (2/9, 3*x + 2) + + """ + try: + ring = self.get_ring() + except DomainError: + return self.one + + p = ring.gcd(self.numer(a), self.numer(b)) + q = ring.lcm(self.denom(a), self.denom(b)) + + return self.convert(p, ring)/q + + def gcdex(self, a, b): + """ + Returns x, y, g such that a * x + b * y == g == gcd(a, b) + """ + d = self.gcd(a, b) + + if a == self.zero: + if b == self.zero: + return self.zero, self.one, self.zero + else: + return self.zero, d/b, d + else: + return d/a, self.zero, d + + def lcm(self, a, b): + """ + Returns LCM of ``a`` and ``b``. + + >>> from sympy.polys.domains import QQ + >>> from sympy import S, lcm + + >>> QQ.lcm(QQ(2, 3), QQ(4, 9)) + 4/3 + >>> lcm(S(2)/3, S(4)/9) + 4/3 + + """ + + try: + ring = self.get_ring() + except DomainError: + return a*b + + p = ring.lcm(self.numer(a), self.numer(b)) + q = ring.gcd(self.denom(a), self.denom(b)) + + return self.convert(p, ring)/q + + def revert(self, a): + """Returns ``a**(-1)`` if possible. """ + if a: + return 1/a + else: + raise NotReversible('zero is not reversible') + + def is_unit(self, a): + """Return true if ``a`` is a invertible""" + return bool(a) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/finitefield.py b/phivenv/Lib/site-packages/sympy/polys/domains/finitefield.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c48ac07f63aefb9a58c83bb95c5261e67e6a9e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/finitefield.py @@ -0,0 +1,368 @@ +"""Implementation of :class:`FiniteField` class. """ + +import operator + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on + +from sympy.core.numbers import int_valued +from sympy.polys.domains.field import Field + +from sympy.polys.domains.modularinteger import ModularIntegerFactory +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.galoistools import gf_zassenhaus, gf_irred_p_rabin +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public +from sympy.polys.domains.groundtypes import SymPyInteger + + +if GROUND_TYPES == 'flint': + __doctest_skip__ = ['FiniteField'] + + +if GROUND_TYPES == 'flint': + import flint + # Don't use python-flint < 0.5.0 because nmod was missing some features in + # previous versions of python-flint and fmpz_mod was not yet added. + _major, _minor, *_ = flint.__version__.split('.') + if (int(_major), int(_minor)) < (0, 5): + flint = None +else: + flint = None + + +def _modular_int_factory_nmod(mod): + # nmod only recognises int + index = operator.index + mod = index(mod) + nmod = flint.nmod + nmod_poly = flint.nmod_poly + + # flint's nmod is only for moduli up to 2^64-1 (on a 64-bit machine) + try: + nmod(0, mod) + except OverflowError: + return None, None + + def ctx(x): + try: + return nmod(x, mod) + except TypeError: + return nmod(index(x), mod) + + def poly_ctx(cs): + return nmod_poly(cs, mod) + + return ctx, poly_ctx + + +def _modular_int_factory_fmpz_mod(mod): + index = operator.index + fctx = flint.fmpz_mod_ctx(mod) + fctx_poly = flint.fmpz_mod_poly_ctx(mod) + fmpz_mod_poly = flint.fmpz_mod_poly + + def ctx(x): + try: + return fctx(x) + except TypeError: + # x might be Integer + return fctx(index(x)) + + def poly_ctx(cs): + return fmpz_mod_poly(cs, fctx_poly) + + return ctx, poly_ctx + + +def _modular_int_factory(mod, dom, symmetric, self): + # Convert the modulus to ZZ + try: + mod = dom.convert(mod) + except CoercionFailed: + raise ValueError('modulus must be an integer, got %s' % mod) + + ctx, poly_ctx, is_flint = None, None, False + + # Don't use flint if the modulus is not prime as it often crashes. + if flint is not None and mod.is_prime(): + + is_flint = True + + # Try to use flint's nmod first + ctx, poly_ctx = _modular_int_factory_nmod(mod) + + if ctx is None: + # Use fmpz_mod for larger moduli + ctx, poly_ctx = _modular_int_factory_fmpz_mod(mod) + + if ctx is None: + # Use the Python implementation if flint is not available or the + # modulus is not prime. + ctx = ModularIntegerFactory(mod, dom, symmetric, self) + poly_ctx = None # not used + + return ctx, poly_ctx, is_flint + + +@public +@doctest_depends_on(modules=['python', 'gmpy']) +class FiniteField(Field, SimpleDomain): + r"""Finite field of prime order :ref:`GF(p)` + + A :ref:`GF(p)` domain represents a `finite field`_ `\mathbb{F}_p` of prime + order as :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + A :py:class:`~.Poly` created from an expression with integer + coefficients will have the domain :ref:`ZZ`. However, if the ``modulus=p`` + option is given then the domain will be a finite field instead. + + >>> from sympy import Poly, Symbol + >>> x = Symbol('x') + >>> p = Poly(x**2 + 1) + >>> p + Poly(x**2 + 1, x, domain='ZZ') + >>> p.domain + ZZ + >>> p2 = Poly(x**2 + 1, modulus=2) + >>> p2 + Poly(x**2 + 1, x, modulus=2) + >>> p2.domain + GF(2) + + It is possible to factorise a polynomial over :ref:`GF(p)` using the + modulus argument to :py:func:`~.factor` or by specifying the domain + explicitly. The domain can also be given as a string. + + >>> from sympy import factor, GF + >>> factor(x**2 + 1) + x**2 + 1 + >>> factor(x**2 + 1, modulus=2) + (x + 1)**2 + >>> factor(x**2 + 1, domain=GF(2)) + (x + 1)**2 + >>> factor(x**2 + 1, domain='GF(2)') + (x + 1)**2 + + It is also possible to use :ref:`GF(p)` with the :py:func:`~.cancel` + and :py:func:`~.gcd` functions. + + >>> from sympy import cancel, gcd + >>> cancel((x**2 + 1)/(x + 1)) + (x**2 + 1)/(x + 1) + >>> cancel((x**2 + 1)/(x + 1), domain=GF(2)) + x + 1 + >>> gcd(x**2 + 1, x + 1) + 1 + >>> gcd(x**2 + 1, x + 1, domain=GF(2)) + x + 1 + + When using the domain directly :ref:`GF(p)` can be used as a constructor + to create instances which then support the operations ``+,-,*,**,/`` + + >>> from sympy import GF + >>> K = GF(5) + >>> K + GF(5) + >>> x = K(3) + >>> y = K(2) + >>> x + 3 mod 5 + >>> y + 2 mod 5 + >>> x * y + 1 mod 5 + >>> x / y + 4 mod 5 + + Notes + ===== + + It is also possible to create a :ref:`GF(p)` domain of **non-prime** + order but the resulting ring is **not** a field: it is just the ring of + the integers modulo ``n``. + + >>> K = GF(9) + >>> z = K(3) + >>> z + 3 mod 9 + >>> z**2 + 0 mod 9 + + It would be good to have a proper implementation of prime power fields + (``GF(p**n)``) but these are not yet implemented in SymPY. + + .. _finite field: https://en.wikipedia.org/wiki/Finite_field + """ + + rep = 'FF' + alias = 'FF' + + is_FiniteField = is_FF = True + is_Numerical = True + + has_assoc_Ring = False + has_assoc_Field = True + + dom = None + mod = None + + def __init__(self, mod, symmetric=True): + from sympy.polys.domains import ZZ + dom = ZZ + + if mod <= 0: + raise ValueError('modulus must be a positive integer, got %s' % mod) + + ctx, poly_ctx, is_flint = _modular_int_factory(mod, dom, symmetric, self) + + self.dtype = ctx + self._poly_ctx = poly_ctx + self._is_flint = is_flint + + self.zero = self.dtype(0) + self.one = self.dtype(1) + self.dom = dom + self.mod = mod + self.sym = symmetric + self._tp = type(self.zero) + + @property + def tp(self): + return self._tp + + @property + def is_Field(self): + is_field = getattr(self, '_is_field', None) + if is_field is None: + from sympy.ntheory.primetest import isprime + self._is_field = is_field = isprime(self.mod) + return is_field + + def __str__(self): + return 'GF(%s)' % self.mod + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.mod, self.dom)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, FiniteField) and \ + self.mod == other.mod and self.dom == other.dom + + def characteristic(self): + """Return the characteristic of this domain. """ + return self.mod + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(self.to_int(a)) + + def from_sympy(self, a): + """Convert SymPy's Integer to SymPy's ``Integer``. """ + if a.is_Integer: + return self.dtype(self.dom.dtype(int(a))) + elif int_valued(a): + return self.dtype(self.dom.dtype(int(a))) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def to_int(self, a): + """Convert ``val`` to a Python ``int`` object. """ + aval = int(a) + if self.sym and aval > self.mod // 2: + aval -= self.mod + return aval + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return bool(a) + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return True + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return False + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return not a + + def from_FF(K1, a, K0=None): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ(int(a), K0.dom)) + + def from_FF_python(K1, a, K0=None): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_python(int(a), K0.dom)) + + def from_ZZ(K1, a, K0=None): + """Convert Python's ``int`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_python(a, K0)) + + def from_ZZ_python(K1, a, K0=None): + """Convert Python's ``int`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_python(a, K0)) + + def from_QQ(K1, a, K0=None): + """Convert Python's ``Fraction`` to ``dtype``. """ + if a.denominator == 1: + return K1.from_ZZ_python(a.numerator) + + def from_QQ_python(K1, a, K0=None): + """Convert Python's ``Fraction`` to ``dtype``. """ + if a.denominator == 1: + return K1.from_ZZ_python(a.numerator) + + def from_FF_gmpy(K1, a, K0=None): + """Convert ``ModularInteger(mpz)`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_gmpy(a.val, K0.dom)) + + def from_ZZ_gmpy(K1, a, K0=None): + """Convert GMPY's ``mpz`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_gmpy(a, K0)) + + def from_QQ_gmpy(K1, a, K0=None): + """Convert GMPY's ``mpq`` to ``dtype``. """ + if a.denominator == 1: + return K1.from_ZZ_gmpy(a.numerator) + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to ``dtype``. """ + p, q = K0.to_rational(a) + + if q == 1: + return K1.dtype(K1.dom.dtype(p)) + + def is_square(self, a): + """Returns True if ``a`` is a quadratic residue modulo p. """ + # a is not a square <=> x**2-a is irreducible + poly = [int(x) for x in [self.one, self.zero, -a]] + return not gf_irred_p_rabin(poly, self.mod, self.dom) + + def exsqrt(self, a): + """Square root modulo p of ``a`` if it is a quadratic residue. + + Explanation + =========== + Always returns the square root that is no larger than ``p // 2``. + """ + # x**2-a is not square-free if a=0 or the field is characteristic 2 + if self.mod == 2 or a == 0: + return a + # Otherwise, use square-free factorization routine to factorize x**2-a + poly = [int(x) for x in [self.one, self.zero, -a]] + for factor in gf_zassenhaus(poly, self.mod, self.dom): + if len(factor) == 2 and factor[1] <= self.mod // 2: + return self.dtype(factor[1]) + return None + + +FF = GF = FiniteField diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/fractionfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/fractionfield.py new file mode 100644 index 0000000000000000000000000000000000000000..78f5054ddd5480fe6f77442f7a25f22603a4d90d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/fractionfield.py @@ -0,0 +1,181 @@ +"""Implementation of :class:`FractionField` class. """ + + +from sympy.polys.domains.compositedomain import CompositeDomain +from sympy.polys.domains.field import Field +from sympy.polys.polyerrors import CoercionFailed, GeneratorsError +from sympy.utilities import public + +@public +class FractionField(Field, CompositeDomain): + """A class for representing multivariate rational function fields. """ + + is_FractionField = is_Frac = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self, domain_or_field, symbols=None, order=None): + from sympy.polys.fields import FracField + + if isinstance(domain_or_field, FracField) and symbols is None and order is None: + field = domain_or_field + else: + field = FracField(symbols, domain_or_field, order) + + self.field = field + self.dtype = field.dtype + + self.gens = field.gens + self.ngens = field.ngens + self.symbols = field.symbols + self.domain = field.domain + + # TODO: remove this + self.dom = self.domain + + def new(self, element): + return self.field.field_new(element) + + def of_type(self, element): + """Check if ``a`` is of type ``dtype``. """ + return self.field.is_element(element) + + @property + def zero(self): + return self.field.zero + + @property + def one(self): + return self.field.one + + @property + def order(self): + return self.field.order + + def __str__(self): + return str(self.domain) + '(' + ','.join(map(str, self.symbols)) + ')' + + def __hash__(self): + return hash((self.__class__.__name__, self.field, self.domain, self.symbols)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if not isinstance(other, FractionField): + return NotImplemented + return self.field == other.field + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return a.as_expr() + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + return self.field.from_expr(a) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + dom = K1.domain + conv = dom.convert_from + if dom.is_ZZ: + return K1(conv(K0.numer(a), K0)) / K1(conv(K0.denom(a), K0)) + else: + return K1(conv(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianRational`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ``GaussianInteger`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_ComplexField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_AlgebraicField(K1, a, K0): + """Convert an algebraic number to ``dtype``. """ + if K1.domain != K0: + a = K1.domain.convert_from(a, K0) + if a is not None: + return K1.new(a) + + def from_PolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + if a.is_ground: + return K1.convert_from(a.coeff(1), K0.domain) + try: + return K1.new(a.set_ring(K1.field.ring)) + except (CoercionFailed, GeneratorsError): + # XXX: We get here if K1=ZZ(x,y) and K0=QQ[x,y] + # and the poly a in K0 has non-integer coefficients. + # It seems that K1.new can handle this but K1.new doesn't work + # when K0.domain is an algebraic field... + try: + return K1.new(a) + except (CoercionFailed, GeneratorsError): + return None + + def from_FractionField(K1, a, K0): + """Convert a rational function to ``dtype``. """ + try: + return a.set_field(K1.field) + except (CoercionFailed, GeneratorsError): + return None + + def get_ring(self): + """Returns a field associated with ``self``. """ + return self.field.to_ring().to_domain() + + def is_positive(self, a): + """Returns True if ``LC(a)`` is positive. """ + return self.domain.is_positive(a.numer.LC) + + def is_negative(self, a): + """Returns True if ``LC(a)`` is negative. """ + return self.domain.is_negative(a.numer.LC) + + def is_nonpositive(self, a): + """Returns True if ``LC(a)`` is non-positive. """ + return self.domain.is_nonpositive(a.numer.LC) + + def is_nonnegative(self, a): + """Returns True if ``LC(a)`` is non-negative. """ + return self.domain.is_nonnegative(a.numer.LC) + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numer + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denom + + def factorial(self, a): + """Returns factorial of ``a``. """ + return self.dtype(self.domain.factorial(a)) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/gaussiandomains.py b/phivenv/Lib/site-packages/sympy/polys/domains/gaussiandomains.py new file mode 100644 index 0000000000000000000000000000000000000000..a96bed78e29445c90c53605a85faa4df16bf807c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/gaussiandomains.py @@ -0,0 +1,706 @@ +"""Domains of Gaussian type.""" + +from __future__ import annotations +from sympy.core.numbers import I +from sympy.polys.polyclasses import DMP +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.domain import Domain +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.domains.field import Field +from sympy.polys.domains.ring import Ring + + +class GaussianElement(DomainElement): + """Base class for elements of Gaussian type domains.""" + base: Domain + _parent: Domain + + __slots__ = ('x', 'y') + + def __new__(cls, x, y=0): + conv = cls.base.convert + return cls.new(conv(x), conv(y)) + + @classmethod + def new(cls, x, y): + """Create a new GaussianElement of the same domain.""" + obj = super().__new__(cls) + obj.x = x + obj.y = y + return obj + + def parent(self): + """The domain that this is an element of (ZZ_I or QQ_I)""" + return self._parent + + def __hash__(self): + return hash((self.x, self.y)) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.x == other.x and self.y == other.y + else: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, GaussianElement): + return NotImplemented + return [self.y, self.x] < [other.y, other.x] + + def __pos__(self): + return self + + def __neg__(self): + return self.new(-self.x, -self.y) + + def __repr__(self): + return "%s(%s, %s)" % (self._parent.rep, self.x, self.y) + + def __str__(self): + return str(self._parent.to_sympy(self)) + + @classmethod + def _get_xy(cls, other): + if not isinstance(other, cls): + try: + other = cls._parent.convert(other) + except CoercionFailed: + return None, None + return other.x, other.y + + def __add__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(self.x + x, self.y + y) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(self.x - x, self.y - y) + else: + return NotImplemented + + def __rsub__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(x - self.x, y - self.y) + else: + return NotImplemented + + def __mul__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(self.x*x - self.y*y, self.x*y + self.y*x) + else: + return NotImplemented + + __rmul__ = __mul__ + + def __pow__(self, exp): + if exp == 0: + return self.new(1, 0) + if exp < 0: + self, exp = 1/self, -exp + if exp == 1: + return self + pow2 = self + prod = self if exp % 2 else self._parent.one + exp //= 2 + while exp: + pow2 *= pow2 + if exp % 2: + prod *= pow2 + exp //= 2 + return prod + + def __bool__(self): + return bool(self.x) or bool(self.y) + + def quadrant(self): + """Return quadrant index 0-3. + + 0 is included in quadrant 0. + """ + if self.y > 0: + return 0 if self.x > 0 else 1 + elif self.y < 0: + return 2 if self.x < 0 else 3 + else: + return 0 if self.x >= 0 else 2 + + def __rdivmod__(self, other): + try: + other = self._parent.convert(other) + except CoercionFailed: + return NotImplemented + else: + return other.__divmod__(self) + + def __rtruediv__(self, other): + try: + other = QQ_I.convert(other) + except CoercionFailed: + return NotImplemented + else: + return other.__truediv__(self) + + def __floordiv__(self, other): + qr = self.__divmod__(other) + return qr if qr is NotImplemented else qr[0] + + def __rfloordiv__(self, other): + qr = self.__rdivmod__(other) + return qr if qr is NotImplemented else qr[0] + + def __mod__(self, other): + qr = self.__divmod__(other) + return qr if qr is NotImplemented else qr[1] + + def __rmod__(self, other): + qr = self.__rdivmod__(other) + return qr if qr is NotImplemented else qr[1] + + +class GaussianInteger(GaussianElement): + """Gaussian integer: domain element for :ref:`ZZ_I` + + >>> from sympy import ZZ_I + >>> z = ZZ_I(2, 3) + >>> z + (2 + 3*I) + >>> type(z) + + """ + base = ZZ + + def __truediv__(self, other): + """Return a Gaussian rational.""" + return QQ_I.convert(self)/other + + def __divmod__(self, other): + if not other: + raise ZeroDivisionError('divmod({}, 0)'.format(self)) + x, y = self._get_xy(other) + if x is None: + return NotImplemented + + # multiply self and other by x - I*y + # self/other == (a + I*b)/c + a, b = self.x*x + self.y*y, -self.x*y + self.y*x + c = x*x + y*y + + # find integers qx and qy such that + # |a - qx*c| <= c/2 and |b - qy*c| <= c/2 + qx = (2*a + c) // (2*c) # -c <= 2*a - qx*2*c < c + qy = (2*b + c) // (2*c) + + q = GaussianInteger(qx, qy) + # |self/other - q| < 1 since + # |a/c - qx|**2 + |b/c - qy|**2 <= 1/4 + 1/4 < 1 + + return q, self - q*other # |r| < |other| + + +class GaussianRational(GaussianElement): + """Gaussian rational: domain element for :ref:`QQ_I` + + >>> from sympy import QQ_I, QQ + >>> z = QQ_I(QQ(2, 3), QQ(4, 5)) + >>> z + (2/3 + 4/5*I) + >>> type(z) + + """ + base = QQ + + def __truediv__(self, other): + """Return a Gaussian rational.""" + if not other: + raise ZeroDivisionError('{} / 0'.format(self)) + x, y = self._get_xy(other) + if x is None: + return NotImplemented + c = x*x + y*y + + return GaussianRational((self.x*x + self.y*y)/c, + (-self.x*y + self.y*x)/c) + + def __divmod__(self, other): + try: + other = self._parent.convert(other) + except CoercionFailed: + return NotImplemented + if not other: + raise ZeroDivisionError('{} % 0'.format(self)) + else: + return self/other, QQ_I.zero + + +class GaussianDomain(): + """Base class for Gaussian domains.""" + dom: Domain + + is_Numerical = True + is_Exact = True + + has_assoc_Ring = True + has_assoc_Field = True + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + conv = self.dom.to_sympy + return conv(a.x) + I*conv(a.y) + + def from_sympy(self, a): + """Convert a SymPy object to ``self.dtype``.""" + r, b = a.as_coeff_Add() + x = self.dom.from_sympy(r) # may raise CoercionFailed + if not b: + return self.new(x, 0) + r, b = b.as_coeff_Mul() + y = self.dom.from_sympy(r) + if b is I: + return self.new(x, y) + else: + raise CoercionFailed("{} is not Gaussian".format(a)) + + def inject(self, *gens): + """Inject generators into this domain. """ + return self.poly_ring(*gens) + + def canonical_unit(self, d): + unit = self.units[-d.quadrant()] # - for inverse power + return unit + + def is_negative(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def is_positive(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def is_nonnegative(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def is_nonpositive(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY mpz to ``self.dtype``.""" + return K1(a) + + def from_ZZ(K1, a, K0): + """Convert a ZZ_python element to ``self.dtype``.""" + return K1(a) + + def from_ZZ_python(K1, a, K0): + """Convert a ZZ_python element to ``self.dtype``.""" + return K1(a) + + def from_QQ(K1, a, K0): + """Convert a GMPY mpq to ``self.dtype``.""" + return K1(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY mpq to ``self.dtype``.""" + return K1(a) + + def from_QQ_python(K1, a, K0): + """Convert a QQ_python element to ``self.dtype``.""" + return K1(a) + + def from_AlgebraicField(K1, a, K0): + """Convert an element from ZZ or QQ to ``self.dtype``.""" + if K0.ext.args[0] == I: + return K1.from_sympy(K0.to_sympy(a)) + + +class GaussianIntegerRing(GaussianDomain, Ring): + r"""Ring of Gaussian integers ``ZZ_I`` + + The :ref:`ZZ_I` domain represents the `Gaussian integers`_ `\mathbb{Z}[i]` + as a :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + By default a :py:class:`~.Poly` created from an expression with + coefficients that are combinations of integers and ``I`` (`\sqrt{-1}`) + will have the domain :ref:`ZZ_I`. + + >>> from sympy import Poly, Symbol, I + >>> x = Symbol('x') + >>> p = Poly(x**2 + I) + >>> p + Poly(x**2 + I, x, domain='ZZ_I') + >>> p.domain + ZZ_I + + The :ref:`ZZ_I` domain can be used to factorise polynomials that are + reducible over the Gaussian integers. + + >>> from sympy import factor + >>> factor(x**2 + 1) + x**2 + 1 + >>> factor(x**2 + 1, domain='ZZ_I') + (x - I)*(x + I) + + The corresponding `field of fractions`_ is the domain of the Gaussian + rationals :ref:`QQ_I`. Conversely :ref:`ZZ_I` is the `ring of integers`_ + of :ref:`QQ_I`. + + >>> from sympy import ZZ_I, QQ_I + >>> ZZ_I.get_field() + QQ_I + >>> QQ_I.get_ring() + ZZ_I + + When using the domain directly :ref:`ZZ_I` can be used as a constructor. + + >>> ZZ_I(3, 4) + (3 + 4*I) + >>> ZZ_I(5) + (5 + 0*I) + + The domain elements of :ref:`ZZ_I` are instances of + :py:class:`~.GaussianInteger` which support the rings operations + ``+,-,*,**``. + + >>> z1 = ZZ_I(5, 1) + >>> z2 = ZZ_I(2, 3) + >>> z1 + (5 + 1*I) + >>> z2 + (2 + 3*I) + >>> z1 + z2 + (7 + 4*I) + >>> z1 * z2 + (7 + 17*I) + >>> z1 ** 2 + (24 + 10*I) + + Both floor (``//``) and modulo (``%``) division work with + :py:class:`~.GaussianInteger` (see the :py:meth:`~.Domain.div` method). + + >>> z3, z4 = ZZ_I(5), ZZ_I(1, 3) + >>> z3 // z4 # floor division + (1 + -1*I) + >>> z3 % z4 # modulo division (remainder) + (1 + -2*I) + >>> (z3//z4)*z4 + z3%z4 == z3 + True + + True division (``/``) in :ref:`ZZ_I` gives an element of :ref:`QQ_I`. The + :py:meth:`~.Domain.exquo` method can be used to divide in :ref:`ZZ_I` when + exact division is possible. + + >>> z1 / z2 + (1 + -1*I) + >>> ZZ_I.exquo(z1, z2) + (1 + -1*I) + >>> z3 / z4 + (1/2 + -3/2*I) + >>> ZZ_I.exquo(z3, z4) + Traceback (most recent call last): + ... + ExactQuotientFailed: (1 + 3*I) does not divide (5 + 0*I) in ZZ_I + + The :py:meth:`~.Domain.gcd` method can be used to compute the `gcd`_ of any + two elements. + + >>> ZZ_I.gcd(ZZ_I(10), ZZ_I(2)) + (2 + 0*I) + >>> ZZ_I.gcd(ZZ_I(5), ZZ_I(2, 1)) + (2 + 1*I) + + .. _Gaussian integers: https://en.wikipedia.org/wiki/Gaussian_integer + .. _gcd: https://en.wikipedia.org/wiki/Greatest_common_divisor + + """ + dom = ZZ + mod = DMP([ZZ.one, ZZ.zero, ZZ.one], ZZ) + dtype = GaussianInteger + zero = dtype(ZZ(0), ZZ(0)) + one = dtype(ZZ(1), ZZ(0)) + imag_unit = dtype(ZZ(0), ZZ(1)) + units = (one, imag_unit, -one, -imag_unit) # powers of i + + rep = 'ZZ_I' + + is_GaussianRing = True + is_ZZ_I = True + is_PID = True + + def __init__(self): # override Domain.__init__ + """For constructing ZZ_I.""" + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, GaussianIntegerRing): + return True + else: + return NotImplemented + + def __hash__(self): + """Compute hash code of ``self``. """ + return hash('ZZ_I') + + @property + def has_CharacteristicZero(self): + return True + + def characteristic(self): + return 0 + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self + + def get_field(self): + """Returns a field associated with ``self``. """ + return QQ_I + + def normalize(self, d, *args): + """Return first quadrant element associated with ``d``. + + Also multiply the other arguments by the same power of i. + """ + unit = self.canonical_unit(d) + d *= unit + args = tuple(a*unit for a in args) + return (d,) + args if args else d + + def gcd(self, a, b): + """Greatest common divisor of a and b over ZZ_I.""" + while b: + a, b = b, a % b + return self.normalize(a) + + def gcdex(self, a, b): + """Return x, y, g such that x * a + y * b = g = gcd(a, b)""" + x_a = self.one + x_b = self.zero + y_a = self.zero + y_b = self.one + while b: + q = a // b + a, b = b, a - q * b + x_a, x_b = x_b, x_a - q * x_b + y_a, y_b = y_b, y_a - q * y_b + + a, x_a, y_a = self.normalize(a, x_a, y_a) + return x_a, y_a, a + + def lcm(self, a, b): + """Least common multiple of a and b over ZZ_I.""" + return (a * b) // self.gcd(a, b) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ZZ_I element to ZZ_I.""" + return a + + def from_GaussianRationalField(K1, a, K0): + """Convert a QQ_I element to ZZ_I.""" + return K1.new(ZZ.convert(a.x), ZZ.convert(a.y)) + +ZZ_I = GaussianInteger._parent = GaussianIntegerRing() + + +class GaussianRationalField(GaussianDomain, Field): + r"""Field of Gaussian rationals ``QQ_I`` + + The :ref:`QQ_I` domain represents the `Gaussian rationals`_ `\mathbb{Q}(i)` + as a :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + By default a :py:class:`~.Poly` created from an expression with + coefficients that are combinations of rationals and ``I`` (`\sqrt{-1}`) + will have the domain :ref:`QQ_I`. + + >>> from sympy import Poly, Symbol, I + >>> x = Symbol('x') + >>> p = Poly(x**2 + I/2) + >>> p + Poly(x**2 + I/2, x, domain='QQ_I') + >>> p.domain + QQ_I + + The polys option ``gaussian=True`` can be used to specify that the domain + should be :ref:`QQ_I` even if the coefficients do not contain ``I`` or are + all integers. + + >>> Poly(x**2) + Poly(x**2, x, domain='ZZ') + >>> Poly(x**2 + I) + Poly(x**2 + I, x, domain='ZZ_I') + >>> Poly(x**2/2) + Poly(1/2*x**2, x, domain='QQ') + >>> Poly(x**2, gaussian=True) + Poly(x**2, x, domain='QQ_I') + >>> Poly(x**2 + I, gaussian=True) + Poly(x**2 + I, x, domain='QQ_I') + >>> Poly(x**2/2, gaussian=True) + Poly(1/2*x**2, x, domain='QQ_I') + + The :ref:`QQ_I` domain can be used to factorise polynomials that are + reducible over the Gaussian rationals. + + >>> from sympy import factor, QQ_I + >>> factor(x**2/4 + 1) + (x**2 + 4)/4 + >>> factor(x**2/4 + 1, domain='QQ_I') + (x - 2*I)*(x + 2*I)/4 + >>> factor(x**2/4 + 1, domain=QQ_I) + (x - 2*I)*(x + 2*I)/4 + + It is also possible to specify the :ref:`QQ_I` domain explicitly with + polys functions like :py:func:`~.apart`. + + >>> from sympy import apart + >>> apart(1/(1 + x**2)) + 1/(x**2 + 1) + >>> apart(1/(1 + x**2), domain=QQ_I) + I/(2*(x + I)) - I/(2*(x - I)) + + The corresponding `ring of integers`_ is the domain of the Gaussian + integers :ref:`ZZ_I`. Conversely :ref:`QQ_I` is the `field of fractions`_ + of :ref:`ZZ_I`. + + >>> from sympy import ZZ_I, QQ_I, QQ + >>> ZZ_I.get_field() + QQ_I + >>> QQ_I.get_ring() + ZZ_I + + When using the domain directly :ref:`QQ_I` can be used as a constructor. + + >>> QQ_I(3, 4) + (3 + 4*I) + >>> QQ_I(5) + (5 + 0*I) + >>> QQ_I(QQ(2, 3), QQ(4, 5)) + (2/3 + 4/5*I) + + The domain elements of :ref:`QQ_I` are instances of + :py:class:`~.GaussianRational` which support the field operations + ``+,-,*,**,/``. + + >>> z1 = QQ_I(5, 1) + >>> z2 = QQ_I(2, QQ(1, 2)) + >>> z1 + (5 + 1*I) + >>> z2 + (2 + 1/2*I) + >>> z1 + z2 + (7 + 3/2*I) + >>> z1 * z2 + (19/2 + 9/2*I) + >>> z2 ** 2 + (15/4 + 2*I) + + True division (``/``) in :ref:`QQ_I` gives an element of :ref:`QQ_I` and + is always exact. + + >>> z1 / z2 + (42/17 + -2/17*I) + >>> QQ_I.exquo(z1, z2) + (42/17 + -2/17*I) + >>> z1 == (z1/z2)*z2 + True + + Both floor (``//``) and modulo (``%``) division can be used with + :py:class:`~.GaussianRational` (see :py:meth:`~.Domain.div`) + but division is always exact so there is no remainder. + + >>> z1 // z2 + (42/17 + -2/17*I) + >>> z1 % z2 + (0 + 0*I) + >>> QQ_I.div(z1, z2) + ((42/17 + -2/17*I), (0 + 0*I)) + >>> (z1//z2)*z2 + z1%z2 == z1 + True + + .. _Gaussian rationals: https://en.wikipedia.org/wiki/Gaussian_rational + """ + dom = QQ + mod = DMP([QQ.one, QQ.zero, QQ.one], QQ) + dtype = GaussianRational + zero = dtype(QQ(0), QQ(0)) + one = dtype(QQ(1), QQ(0)) + imag_unit = dtype(QQ(0), QQ(1)) + units = (one, imag_unit, -one, -imag_unit) # powers of i + + rep = 'QQ_I' + + is_GaussianField = True + is_QQ_I = True + + def __init__(self): # override Domain.__init__ + """For constructing QQ_I.""" + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, GaussianRationalField): + return True + else: + return NotImplemented + + def __hash__(self): + """Compute hash code of ``self``. """ + return hash('QQ_I') + + @property + def has_CharacteristicZero(self): + return True + + def characteristic(self): + return 0 + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return ZZ_I + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def as_AlgebraicField(self): + """Get equivalent domain as an ``AlgebraicField``. """ + return AlgebraicField(self.dom, I) + + def numer(self, a): + """Get the numerator of ``a``.""" + ZZ_I = self.get_ring() + return ZZ_I.convert(a * self.denom(a)) + + def denom(self, a): + """Get the denominator of ``a``.""" + ZZ = self.dom.get_ring() + QQ = self.dom + ZZ_I = self.get_ring() + denom_ZZ = ZZ.lcm(QQ.denom(a.x), QQ.denom(a.y)) + return ZZ_I(denom_ZZ, ZZ.zero) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ZZ_I element to QQ_I.""" + return K1.new(a.x, a.y) + + def from_GaussianRationalField(K1, a, K0): + """Convert a QQ_I element to QQ_I.""" + return a + + def from_ComplexField(K1, a, K0): + """Convert a ComplexField element to QQ_I.""" + return K1.new(QQ.convert(a.real), QQ.convert(a.imag)) + + +QQ_I = GaussianRational._parent = GaussianRationalField() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/gmpyfinitefield.py b/phivenv/Lib/site-packages/sympy/polys/domains/gmpyfinitefield.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8315a29eca8160102d66b83d953caf998b0fd7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/gmpyfinitefield.py @@ -0,0 +1,16 @@ +"""Implementation of :class:`GMPYFiniteField` class. """ + + +from sympy.polys.domains.finitefield import FiniteField +from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing + +from sympy.utilities import public + +@public +class GMPYFiniteField(FiniteField): + """Finite field based on GMPY integers. """ + + alias = 'FF_gmpy' + + def __init__(self, mod, symmetric=True): + super().__init__(mod, GMPYIntegerRing(), symmetric) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/gmpyintegerring.py b/phivenv/Lib/site-packages/sympy/polys/domains/gmpyintegerring.py new file mode 100644 index 0000000000000000000000000000000000000000..f132bbe5aff7a4164a09b9b90f00ae5f140cbd03 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/gmpyintegerring.py @@ -0,0 +1,105 @@ +"""Implementation of :class:`GMPYIntegerRing` class. """ + + +from sympy.polys.domains.groundtypes import ( + GMPYInteger, SymPyInteger, + factorial as gmpy_factorial, + gmpy_gcdex, gmpy_gcd, gmpy_lcm, sqrt as gmpy_sqrt, +) +from sympy.core.numbers import int_valued +from sympy.polys.domains.integerring import IntegerRing +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class GMPYIntegerRing(IntegerRing): + """Integer ring based on GMPY's ``mpz`` type. + + This will be the implementation of :ref:`ZZ` if ``gmpy`` or ``gmpy2`` is + installed. Elements will be of type ``gmpy.mpz``. + """ + + dtype = GMPYInteger + zero = dtype(0) + one = dtype(1) + tp = type(one) + alias = 'ZZ_gmpy' + + def __init__(self): + """Allow instantiation of this domain. """ + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(int(a)) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Integer: + return GMPYInteger(a.p) + elif int_valued(a): + return GMPYInteger(int(a)) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to GMPY's ``mpz``. """ + return K0.to_int(a) + + def from_ZZ_python(K1, a, K0): + """Convert Python's ``int`` to GMPY's ``mpz``. """ + return GMPYInteger(a) + + def from_QQ(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return GMPYInteger(a.numerator) + + def from_QQ_python(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return GMPYInteger(a.numerator) + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to GMPY's ``mpz``. """ + return K0.to_int(a) + + def from_ZZ_gmpy(K1, a, K0): + """Convert GMPY's ``mpz`` to GMPY's ``mpz``. """ + return a + + def from_QQ_gmpy(K1, a, K0): + """Convert GMPY ``mpq`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return a.numerator + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to GMPY's ``mpz``. """ + p, q = K0.to_rational(a) + + if q == 1: + return GMPYInteger(p) + + def from_GaussianIntegerRing(K1, a, K0): + if a.y == 0: + return a.x + + def gcdex(self, a, b): + """Compute extended GCD of ``a`` and ``b``. """ + h, s, t = gmpy_gcdex(a, b) + return s, t, h + + def gcd(self, a, b): + """Compute GCD of ``a`` and ``b``. """ + return gmpy_gcd(a, b) + + def lcm(self, a, b): + """Compute LCM of ``a`` and ``b``. """ + return gmpy_lcm(a, b) + + def sqrt(self, a): + """Compute square root of ``a``. """ + return gmpy_sqrt(a) + + def factorial(self, a): + """Compute factorial of ``a``. """ + return gmpy_factorial(a) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/gmpyrationalfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/gmpyrationalfield.py new file mode 100644 index 0000000000000000000000000000000000000000..10bae5b2b7b476f96ba06f637c549ee4afff4c6d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/gmpyrationalfield.py @@ -0,0 +1,100 @@ +"""Implementation of :class:`GMPYRationalField` class. """ + + +from sympy.polys.domains.groundtypes import ( + GMPYRational, SymPyRational, + gmpy_numer, gmpy_denom, factorial as gmpy_factorial, +) +from sympy.polys.domains.rationalfield import RationalField +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class GMPYRationalField(RationalField): + """Rational field based on GMPY's ``mpq`` type. + + This will be the implementation of :ref:`QQ` if ``gmpy`` or ``gmpy2`` is + installed. Elements will be of type ``gmpy.mpq``. + """ + + dtype = GMPYRational + zero = dtype(0) + one = dtype(1) + tp = type(one) + alias = 'QQ_gmpy' + + def __init__(self): + pass + + def get_ring(self): + """Returns ring associated with ``self``. """ + from sympy.polys.domains import GMPYIntegerRing + return GMPYIntegerRing() + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyRational(int(gmpy_numer(a)), + int(gmpy_denom(a))) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Rational: + return GMPYRational(a.p, a.q) + elif a.is_Float: + from sympy.polys.domains import RR + return GMPYRational(*map(int, RR.to_rational(a))) + else: + raise CoercionFailed("expected ``Rational`` object, got %s" % a) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return GMPYRational(a) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return GMPYRational(a.numerator, a.denominator) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return GMPYRational(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return a + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianElement`` object to ``dtype``. """ + if a.y == 0: + return GMPYRational(a.x) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return GMPYRational(*map(int, K0.to_rational(a))) + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return GMPYRational(a) / GMPYRational(b) + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return GMPYRational(a) / GMPYRational(b) + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies nothing. """ + return self.zero + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__truediv__``. """ + return GMPYRational(a) / GMPYRational(b), self.zero + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numerator + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denominator + + def factorial(self, a): + """Returns factorial of ``a``. """ + return GMPYRational(gmpy_factorial(int(a))) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/groundtypes.py b/phivenv/Lib/site-packages/sympy/polys/domains/groundtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..1d50cf912a998767c4a52c5a2f3aab825e072aec --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/groundtypes.py @@ -0,0 +1,99 @@ +"""Ground types for various mathematical domains in SymPy. """ + +import builtins +from sympy.external.gmpy import GROUND_TYPES, factorial, sqrt, is_square, sqrtrem + +PythonInteger = builtins.int +PythonReal = builtins.float +PythonComplex = builtins.complex + +from .pythonrational import PythonRational + +from sympy.core.intfunc import ( + igcdex as python_gcdex, + igcd2 as python_gcd, + ilcm as python_lcm, +) + +from sympy.core.numbers import (Float as SymPyReal, Integer as SymPyInteger, Rational as SymPyRational) + + +class _GMPYInteger: + def __init__(self, obj): + pass + +class _GMPYRational: + def __init__(self, obj): + pass + + +if GROUND_TYPES == 'gmpy': + + from gmpy2 import ( + mpz as GMPYInteger, + mpq as GMPYRational, + numer as gmpy_numer, + denom as gmpy_denom, + gcdext as gmpy_gcdex, + gcd as gmpy_gcd, + lcm as gmpy_lcm, + qdiv as gmpy_qdiv, + ) + gcdex = gmpy_gcdex + gcd = gmpy_gcd + lcm = gmpy_lcm + +elif GROUND_TYPES == 'flint': + + from flint import fmpz as _fmpz + + GMPYInteger = _GMPYInteger + GMPYRational = _GMPYRational + gmpy_numer = None + gmpy_denom = None + gmpy_gcdex = None + gmpy_gcd = None + gmpy_lcm = None + gmpy_qdiv = None + + def gcd(a, b): + return a.gcd(b) + + def gcdex(a, b): + x, y, g = python_gcdex(a, b) + return _fmpz(x), _fmpz(y), _fmpz(g) + + def lcm(a, b): + return a.lcm(b) + +else: + GMPYInteger = _GMPYInteger + GMPYRational = _GMPYRational + gmpy_numer = None + gmpy_denom = None + gmpy_gcdex = None + gmpy_gcd = None + gmpy_lcm = None + gmpy_qdiv = None + gcdex = python_gcdex + gcd = python_gcd + lcm = python_lcm + + +__all__ = [ + 'PythonInteger', 'PythonReal', 'PythonComplex', + + 'PythonRational', + + 'python_gcdex', 'python_gcd', 'python_lcm', + + 'SymPyReal', 'SymPyInteger', 'SymPyRational', + + 'GMPYInteger', 'GMPYRational', 'gmpy_numer', + 'gmpy_denom', 'gmpy_gcdex', 'gmpy_gcd', 'gmpy_lcm', + 'gmpy_qdiv', + + 'factorial', 'sqrt', 'is_square', 'sqrtrem', + + 'GMPYInteger', 'GMPYRational', +] diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/integerring.py b/phivenv/Lib/site-packages/sympy/polys/domains/integerring.py new file mode 100644 index 0000000000000000000000000000000000000000..65eaa9631cfdf138997a4ebdb362c4233fb098fb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/integerring.py @@ -0,0 +1,276 @@ +"""Implementation of :class:`IntegerRing` class. """ + +from sympy.external.gmpy import MPZ, GROUND_TYPES + +from sympy.core.numbers import int_valued +from sympy.polys.domains.groundtypes import ( + SymPyInteger, + factorial, + gcdex, gcd, lcm, sqrt, is_square, sqrtrem, +) + +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.ring import Ring +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +import math + +@public +class IntegerRing(Ring, CharacteristicZero, SimpleDomain): + r"""The domain ``ZZ`` representing the integers `\mathbb{Z}`. + + The :py:class:`IntegerRing` class represents the ring of integers as a + :py:class:`~.Domain` in the domain system. :py:class:`IntegerRing` is a + super class of :py:class:`PythonIntegerRing` and + :py:class:`GMPYIntegerRing` one of which will be the implementation for + :ref:`ZZ` depending on whether or not ``gmpy`` or ``gmpy2`` is installed. + + See also + ======== + + Domain + """ + + rep = 'ZZ' + alias = 'ZZ' + dtype = MPZ + zero = dtype(0) + one = dtype(1) + tp = type(one) + + + is_IntegerRing = is_ZZ = True + is_Numerical = True + is_PID = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self): + """Allow instantiation of this domain. """ + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, IntegerRing): + return True + else: + return NotImplemented + + def __hash__(self): + """Compute a hash value for this domain. """ + return hash('ZZ') + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(int(a)) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Integer: + return MPZ(a.p) + elif int_valued(a): + return MPZ(int(a)) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def get_field(self): + r"""Return the associated field of fractions :ref:`QQ` + + Returns + ======= + + :ref:`QQ`: + The associated field of fractions :ref:`QQ`, a + :py:class:`~.Domain` representing the rational numbers + `\mathbb{Q}`. + + Examples + ======== + + >>> from sympy import ZZ + >>> ZZ.get_field() + QQ + """ + from sympy.polys.domains import QQ + return QQ + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `\mathbb{Q}(\alpha, \ldots)`. + + Parameters + ========== + + *extension : One or more :py:class:`~.Expr`. + Generators of the extension. These should be expressions that are + algebraic over `\mathbb{Q}`. + + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the + primitive element of the returned :py:class:`~.AlgebraicField`. + + Returns + ======= + + :py:class:`~.AlgebraicField` + A :py:class:`~.Domain` representing the algebraic field extension. + + Examples + ======== + + >>> from sympy import ZZ, sqrt + >>> ZZ.algebraic_field(sqrt(2)) + QQ + """ + return self.get_field().algebraic_field(*extension, alias=alias) + + def from_AlgebraicField(K1, a, K0): + """Convert a :py:class:`~.ANP` object to :ref:`ZZ`. + + See :py:meth:`~.Domain.convert`. + """ + if a.is_ground: + return K1.convert(a.LC(), K0.dom) + + def log(self, a, b): + r"""Logarithm of *a* to the base *b*. + + Parameters + ========== + + a: number + b: number + + Returns + ======= + + $\\lfloor\log(a, b)\\rfloor$: + Floor of the logarithm of *a* to the base *b* + + Examples + ======== + + >>> from sympy import ZZ + >>> ZZ.log(ZZ(8), ZZ(2)) + 3 + >>> ZZ.log(ZZ(9), ZZ(2)) + 3 + + Notes + ===== + + This function uses ``math.log`` which is based on ``float`` so it will + fail for large integer arguments. + """ + return self.dtype(int(math.log(int(a), b))) + + def from_FF(K1, a, K0): + """Convert ``ModularInteger(int)`` to GMPY's ``mpz``. """ + return MPZ(K0.to_int(a)) + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to GMPY's ``mpz``. """ + return MPZ(K0.to_int(a)) + + def from_ZZ(K1, a, K0): + """Convert Python's ``int`` to GMPY's ``mpz``. """ + return MPZ(a) + + def from_ZZ_python(K1, a, K0): + """Convert Python's ``int`` to GMPY's ``mpz``. """ + return MPZ(a) + + def from_QQ(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return MPZ(a.numerator) + + def from_QQ_python(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return MPZ(a.numerator) + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to GMPY's ``mpz``. """ + return MPZ(K0.to_int(a)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert GMPY's ``mpz`` to GMPY's ``mpz``. """ + return a + + def from_QQ_gmpy(K1, a, K0): + """Convert GMPY ``mpq`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return a.numerator + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to GMPY's ``mpz``. """ + p, q = K0.to_rational(a) + + if q == 1: + # XXX: If MPZ is flint.fmpz and p is a gmpy2.mpz, then we need + # to convert via int because fmpz and mpz do not know about each + # other. + return MPZ(int(p)) + + def from_GaussianIntegerRing(K1, a, K0): + if a.y == 0: + return a.x + + def from_EX(K1, a, K0): + """Convert ``Expression`` to GMPY's ``mpz``. """ + if a.is_Integer: + return K1.from_sympy(a) + + def gcdex(self, a, b): + """Compute extended GCD of ``a`` and ``b``. """ + h, s, t = gcdex(a, b) + # XXX: This conditional logic should be handled somewhere else. + if GROUND_TYPES == 'gmpy': + return s, t, h + else: + return h, s, t + + def gcd(self, a, b): + """Compute GCD of ``a`` and ``b``. """ + return gcd(a, b) + + def lcm(self, a, b): + """Compute LCM of ``a`` and ``b``. """ + return lcm(a, b) + + def sqrt(self, a): + """Compute square root of ``a``. """ + return sqrt(a) + + def is_square(self, a): + """Return ``True`` if ``a`` is a square. + + Explanation + =========== + An integer is a square if and only if there exists an integer + ``b`` such that ``b * b == a``. + """ + return is_square(a) + + def exsqrt(self, a): + """Non-negative square root of ``a`` if ``a`` is a square. + + See also + ======== + is_square + """ + if a < 0: + return None + root, rem = sqrtrem(a) + if rem != 0: + return None + return root + + def factorial(self, a): + """Compute factorial of ``a``. """ + return factorial(a) + + +ZZ = IntegerRing() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/modularinteger.py b/phivenv/Lib/site-packages/sympy/polys/domains/modularinteger.py new file mode 100644 index 0000000000000000000000000000000000000000..39a0237563c69a77e4736466d1ebcaa7ca39485f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/modularinteger.py @@ -0,0 +1,237 @@ +"""Implementation of :class:`ModularInteger` class. """ + +from __future__ import annotations +from typing import Any + +import operator + +from sympy.polys.polyutils import PicklableWithSlots +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.domains.domainelement import DomainElement + +from sympy.utilities import public +from sympy.utilities.exceptions import sympy_deprecation_warning + +@public +class ModularInteger(PicklableWithSlots, DomainElement): + """A class representing a modular integer. """ + + mod, dom, sym, _parent = None, None, None, None + + __slots__ = ('val',) + + def parent(self): + return self._parent + + def __init__(self, val): + if isinstance(val, self.__class__): + self.val = val.val % self.mod + else: + self.val = self.dom.convert(val) % self.mod + + def modulus(self): + return self.mod + + def __hash__(self): + return hash((self.val, self.mod)) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.val) + + def __str__(self): + return "%s mod %s" % (self.val, self.mod) + + def __int__(self): + return int(self.val) + + def to_int(self): + + sympy_deprecation_warning( + """ModularInteger.to_int() is deprecated. + + Use int(a) or K = GF(p) and K.to_int(a) instead of a.to_int(). + """, + deprecated_since_version="1.13", + active_deprecations_target="modularinteger-to-int", + ) + + if self.sym: + if self.val <= self.mod // 2: + return self.val + else: + return self.val - self.mod + else: + return self.val + + def __pos__(self): + return self + + def __neg__(self): + return self.__class__(-self.val) + + @classmethod + def _get_val(cls, other): + if isinstance(other, cls): + return other.val + else: + try: + return cls.dom.convert(other) + except CoercionFailed: + return None + + def __add__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val + val) + else: + return NotImplemented + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val - val) + else: + return NotImplemented + + def __rsub__(self, other): + return (-self).__add__(other) + + def __mul__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val * val) + else: + return NotImplemented + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val * self._invert(val)) + else: + return NotImplemented + + def __rtruediv__(self, other): + return self.invert().__mul__(other) + + def __mod__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val % val) + else: + return NotImplemented + + def __rmod__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(val % self.val) + else: + return NotImplemented + + def __pow__(self, exp): + if not exp: + return self.__class__(self.dom.one) + + if exp < 0: + val, exp = self.invert().val, -exp + else: + val = self.val + + return self.__class__(pow(val, int(exp), self.mod)) + + def _compare(self, other, op): + val = self._get_val(other) + + if val is None: + return NotImplemented + + return op(self.val, val % self.mod) + + def _compare_deprecated(self, other, op): + val = self._get_val(other) + + if val is None: + return NotImplemented + + sympy_deprecation_warning( + """Ordered comparisons with modular integers are deprecated. + + Use e.g. int(a) < int(b) instead of a < b. + """, + deprecated_since_version="1.13", + active_deprecations_target="modularinteger-compare", + stacklevel=4, + ) + + return op(self.val, val % self.mod) + + def __eq__(self, other): + return self._compare(other, operator.eq) + + def __ne__(self, other): + return self._compare(other, operator.ne) + + def __lt__(self, other): + return self._compare_deprecated(other, operator.lt) + + def __le__(self, other): + return self._compare_deprecated(other, operator.le) + + def __gt__(self, other): + return self._compare_deprecated(other, operator.gt) + + def __ge__(self, other): + return self._compare_deprecated(other, operator.ge) + + def __bool__(self): + return bool(self.val) + + @classmethod + def _invert(cls, value): + return cls.dom.invert(value, cls.mod) + + def invert(self): + return self.__class__(self._invert(self.val)) + +_modular_integer_cache: dict[tuple[Any, Any, Any], type[ModularInteger]] = {} + +def ModularIntegerFactory(_mod, _dom, _sym, parent): + """Create custom class for specific integer modulus.""" + try: + _mod = _dom.convert(_mod) + except CoercionFailed: + ok = False + else: + ok = True + + if not ok or _mod < 1: + raise ValueError("modulus must be a positive integer, got %s" % _mod) + + key = _mod, _dom, _sym + + try: + cls = _modular_integer_cache[key] + except KeyError: + class cls(ModularInteger): + mod, dom, sym = _mod, _dom, _sym + _parent = parent + + if _sym: + cls.__name__ = "SymmetricModularIntegerMod%s" % _mod + else: + cls.__name__ = "ModularIntegerMod%s" % _mod + + _modular_integer_cache[key] = cls + + return cls diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/mpelements.py b/phivenv/Lib/site-packages/sympy/polys/domains/mpelements.py new file mode 100644 index 0000000000000000000000000000000000000000..04ae8eaddcbb7fd8fae684374d9d2c05e79f6c7a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/mpelements.py @@ -0,0 +1,181 @@ +# +# This module is deprecated and should not be used any more. The actual +# implementation of RR and CC now uses mpmath's mpf and mpc types directly. +# +"""Real and complex elements. """ + + +from sympy.external.gmpy import MPQ +from sympy.polys.domains.domainelement import DomainElement +from sympy.utilities import public + +from mpmath.ctx_mp_python import PythonMPContext, _mpf, _mpc, _constant +from mpmath.libmp import (MPZ_ONE, fzero, fone, finf, fninf, fnan, + round_nearest, mpf_mul, repr_dps, int_types, + from_int, from_float, from_str, to_rational) + + +@public +class RealElement(_mpf, DomainElement): + """An element of a real domain. """ + + __slots__ = ('__mpf__',) + + def _set_mpf(self, val): + self.__mpf__ = val + + _mpf_ = property(lambda self: self.__mpf__, _set_mpf) + + def parent(self): + return self.context._parent + +@public +class ComplexElement(_mpc, DomainElement): + """An element of a complex domain. """ + + __slots__ = ('__mpc__',) + + def _set_mpc(self, val): + self.__mpc__ = val + + _mpc_ = property(lambda self: self.__mpc__, _set_mpc) + + def parent(self): + return self.context._parent + +new = object.__new__ + +@public +class MPContext(PythonMPContext): + + def __init__(ctx, prec=53, dps=None, tol=None, real=False): + ctx._prec_rounding = [prec, round_nearest] + + if dps is None: + ctx._set_prec(prec) + else: + ctx._set_dps(dps) + + ctx.mpf = RealElement + ctx.mpc = ComplexElement + ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding] + ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding] + + if real: + ctx.mpf.context = ctx + else: + ctx.mpc.context = ctx + + ctx.constant = _constant + ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding] + ctx.constant.context = ctx + + ctx.types = [ctx.mpf, ctx.mpc, ctx.constant] + ctx.trap_complex = True + ctx.pretty = True + + if tol is None: + ctx.tol = ctx._make_tol() + elif tol is False: + ctx.tol = fzero + else: + ctx.tol = ctx._convert_tol(tol) + + ctx.tolerance = ctx.make_mpf(ctx.tol) + + if not ctx.tolerance: + ctx.max_denom = 1000000 + else: + ctx.max_denom = int(1/ctx.tolerance) + + ctx.zero = ctx.make_mpf(fzero) + ctx.one = ctx.make_mpf(fone) + ctx.j = ctx.make_mpc((fzero, fone)) + ctx.inf = ctx.make_mpf(finf) + ctx.ninf = ctx.make_mpf(fninf) + ctx.nan = ctx.make_mpf(fnan) + + def _make_tol(ctx): + hundred = (0, 25, 2, 5) + eps = (0, MPZ_ONE, 1-ctx.prec, 1) + return mpf_mul(hundred, eps) + + def make_tol(ctx): + return ctx.make_mpf(ctx._make_tol()) + + def _convert_tol(ctx, tol): + if isinstance(tol, int_types): + return from_int(tol) + if isinstance(tol, float): + return from_float(tol) + if hasattr(tol, "_mpf_"): + return tol._mpf_ + prec, rounding = ctx._prec_rounding + if isinstance(tol, str): + return from_str(tol, prec, rounding) + raise ValueError("expected a real number, got %s" % tol) + + def _convert_fallback(ctx, x, strings): + raise TypeError("cannot create mpf from " + repr(x)) + + @property + def _repr_digits(ctx): + return repr_dps(ctx._prec) + + @property + def _str_digits(ctx): + return ctx._dps + + def to_rational(ctx, s, limit=True): + p, q = to_rational(s._mpf_) + + # Needed for GROUND_TYPES=flint if gmpy2 is installed because mpmath's + # to_rational() function returns a gmpy2.mpz instance and if MPQ is + # flint.fmpq then MPQ(p, q) will fail. + p = int(p) + + if not limit or q <= ctx.max_denom: + return p, q + + p0, q0, p1, q1 = 0, 1, 1, 0 + n, d = p, q + + while True: + a = n//d + q2 = q0 + a*q1 + if q2 > ctx.max_denom: + break + p0, q0, p1, q1 = p1, q1, p0 + a*p1, q2 + n, d = d, n - a*d + + k = (ctx.max_denom - q0)//q1 + + number = MPQ(p, q) + bound1 = MPQ(p0 + k*p1, q0 + k*q1) + bound2 = MPQ(p1, q1) + + if not bound2 or not bound1: + return p, q + elif abs(bound2 - number) <= abs(bound1 - number): + return bound2.numerator, bound2.denominator + else: + return bound1.numerator, bound1.denominator + + def almosteq(ctx, s, t, rel_eps=None, abs_eps=None): + t = ctx.convert(t) + if abs_eps is None and rel_eps is None: + rel_eps = abs_eps = ctx.tolerance or ctx.make_tol() + if abs_eps is None: + abs_eps = ctx.convert(rel_eps) + elif rel_eps is None: + rel_eps = ctx.convert(abs_eps) + diff = abs(s-t) + if diff <= abs_eps: + return True + abss = abs(s) + abst = abs(t) + if abss < abst: + err = diff/abst + else: + err = diff/abss + return err <= rel_eps diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/old_fractionfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/old_fractionfield.py new file mode 100644 index 0000000000000000000000000000000000000000..25d849c39e45259728479ab0305d4956053ae743 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/old_fractionfield.py @@ -0,0 +1,188 @@ +"""Implementation of :class:`FractionField` class. """ + + +from sympy.polys.domains.field import Field +from sympy.polys.domains.compositedomain import CompositeDomain +from sympy.polys.polyclasses import DMF +from sympy.polys.polyerrors import GeneratorsNeeded +from sympy.polys.polyutils import dict_from_basic, basic_from_dict, _dict_reorder +from sympy.utilities import public + +@public +class FractionField(Field, CompositeDomain): + """A class for representing rational function fields. """ + + dtype = DMF + is_FractionField = is_Frac = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self, dom, *gens): + if not gens: + raise GeneratorsNeeded("generators not specified") + + lev = len(gens) - 1 + self.ngens = len(gens) + + self.zero = self.dtype.zero(lev, dom) + self.one = self.dtype.one(lev, dom) + + self.domain = self.dom = dom + self.symbols = self.gens = gens + + def set_domain(self, dom): + """Make a new fraction field with given domain. """ + return self.__class__(dom, *self.gens) + + def new(self, element): + return self.dtype(element, self.dom, len(self.gens) - 1) + + def __str__(self): + return str(self.dom) + '(' + ','.join(map(str, self.gens)) + ')' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.dom, self.gens)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, FractionField) and \ + self.dtype == other.dtype and self.dom == other.dom and self.gens == other.gens + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return (basic_from_dict(a.numer().to_sympy_dict(), *self.gens) / + basic_from_dict(a.denom().to_sympy_dict(), *self.gens)) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + p, q = a.as_numer_denom() + + num, _ = dict_from_basic(p, gens=self.gens) + den, _ = dict_from_basic(q, gens=self.gens) + + for k, v in num.items(): + num[k] = self.dom.from_sympy(v) + + for k, v in den.items(): + den[k] = self.dom.from_sympy(v) + + return self((num, den)).cancel() + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert a ``DMF`` object to ``dtype``. """ + if K1.gens == K0.gens: + if K1.dom == K0.dom: + return K1(a.to_list()) + else: + return K1(a.convert(K1.dom).to_list()) + else: + monoms, coeffs = _dict_reorder(a.to_dict(), K0.gens, K1.gens) + + if K1.dom != K0.dom: + coeffs = [ K1.dom.convert(c, K0.dom) for c in coeffs ] + + return K1(dict(zip(monoms, coeffs))) + + def from_FractionField(K1, a, K0): + """ + Convert a fraction field element to another fraction field. + + Examples + ======== + + >>> from sympy.polys.polyclasses import DMF + >>> from sympy.polys.domains import ZZ, QQ + >>> from sympy.abc import x + + >>> f = DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(1)]), ZZ) + + >>> QQx = QQ.old_frac_field(x) + >>> ZZx = ZZ.old_frac_field(x) + + >>> QQx.from_FractionField(f, ZZx) + DMF([1, 2], [1, 1], QQ) + + """ + if K1.gens == K0.gens: + if K1.dom == K0.dom: + return a + else: + return K1((a.numer().convert(K1.dom).to_list(), + a.denom().convert(K1.dom).to_list())) + elif set(K0.gens).issubset(K1.gens): + nmonoms, ncoeffs = _dict_reorder( + a.numer().to_dict(), K0.gens, K1.gens) + dmonoms, dcoeffs = _dict_reorder( + a.denom().to_dict(), K0.gens, K1.gens) + + if K1.dom != K0.dom: + ncoeffs = [ K1.dom.convert(c, K0.dom) for c in ncoeffs ] + dcoeffs = [ K1.dom.convert(c, K0.dom) for c in dcoeffs ] + + return K1((dict(zip(nmonoms, ncoeffs)), dict(zip(dmonoms, dcoeffs)))) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + from sympy.polys.domains import PolynomialRing + return PolynomialRing(self.dom, *self.gens) + + def poly_ring(self, *gens): + """Returns a polynomial ring, i.e. `K[X]`. """ + raise NotImplementedError('nested domains not allowed') + + def frac_field(self, *gens): + """Returns a fraction field, i.e. `K(X)`. """ + raise NotImplementedError('nested domains not allowed') + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return self.dom.is_positive(a.numer().LC()) + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return self.dom.is_negative(a.numer().LC()) + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return self.dom.is_nonpositive(a.numer().LC()) + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return self.dom.is_nonnegative(a.numer().LC()) + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numer() + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denom() + + def factorial(self, a): + """Returns factorial of ``a``. """ + return self.dtype(self.dom.factorial(a)) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/old_polynomialring.py b/phivenv/Lib/site-packages/sympy/polys/domains/old_polynomialring.py new file mode 100644 index 0000000000000000000000000000000000000000..c29a4529aac3c64b29d8c670ac45b6c100294ced --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/old_polynomialring.py @@ -0,0 +1,490 @@ +"""Implementation of :class:`PolynomialRing` class. """ + + +from sympy.polys.agca.modules import FreeModulePolyRing +from sympy.polys.domains.compositedomain import CompositeDomain +from sympy.polys.domains.old_fractionfield import FractionField +from sympy.polys.domains.ring import Ring +from sympy.polys.orderings import monomial_key, build_product_order +from sympy.polys.polyclasses import DMP, DMF +from sympy.polys.polyerrors import (GeneratorsNeeded, PolynomialError, + CoercionFailed, ExactQuotientFailed, NotReversible) +from sympy.polys.polyutils import dict_from_basic, basic_from_dict, _dict_reorder +from sympy.utilities import public +from sympy.utilities.iterables import iterable + + +@public +class PolynomialRingBase(Ring, CompositeDomain): + """ + Base class for generalized polynomial rings. + + This base class should be used for uniform access to generalized polynomial + rings. Subclasses only supply information about the element storage etc. + + Do not instantiate. + """ + + has_assoc_Ring = True + has_assoc_Field = True + + default_order = "grevlex" + + def __init__(self, dom, *gens, **opts): + if not gens: + raise GeneratorsNeeded("generators not specified") + + lev = len(gens) - 1 + self.ngens = len(gens) + + self.zero = self.dtype.zero(lev, dom) + self.one = self.dtype.one(lev, dom) + + self.domain = self.dom = dom + self.symbols = self.gens = gens + # NOTE 'order' may not be set if inject was called through CompositeDomain + self.order = opts.get('order', monomial_key(self.default_order)) + + def set_domain(self, dom): + """Return a new polynomial ring with given domain. """ + return self.__class__(dom, *self.gens, order=self.order) + + def new(self, element): + return self.dtype(element, self.dom, len(self.gens) - 1) + + def _ground_new(self, element): + return self.one.ground_new(element) + + def _from_dict(self, element): + return DMP.from_dict(element, len(self.gens) - 1, self.dom) + + def __str__(self): + s_order = str(self.order) + orderstr = ( + " order=" + s_order) if s_order != self.default_order else "" + return str(self.dom) + '[' + ','.join(map(str, self.gens)) + orderstr + ']' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.dom, + self.gens, self.order)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, PolynomialRingBase) and \ + self.dtype == other.dtype and self.dom == other.dom and \ + self.gens == other.gens and self.order == other.order + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_AlgebraicField(K1, a, K0): + """Convert a ``ANP`` object to ``dtype``. """ + if K1.dom == K0: + return K1._ground_new(a) + + def from_PolynomialRing(K1, a, K0): + """Convert a ``PolyElement`` object to ``dtype``. """ + if K1.gens == K0.symbols: + if K1.dom == K0.dom: + return K1(dict(a)) # set the correct ring + else: + convert_dom = lambda c: K1.dom.convert_from(c, K0.dom) + return K1._from_dict({m: convert_dom(c) for m, c in a.items()}) + else: + monoms, coeffs = _dict_reorder(a.to_dict(), K0.symbols, K1.gens) + + if K1.dom != K0.dom: + coeffs = [ K1.dom.convert(c, K0.dom) for c in coeffs ] + + return K1._from_dict(dict(zip(monoms, coeffs))) + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert a ``DMP`` object to ``dtype``. """ + if K1.gens == K0.gens: + if K1.dom != K0.dom: + a = a.convert(K1.dom) + return K1(a.to_list()) + else: + monoms, coeffs = _dict_reorder(a.to_dict(), K0.gens, K1.gens) + + if K1.dom != K0.dom: + coeffs = [ K1.dom.convert(c, K0.dom) for c in coeffs ] + + return K1(dict(zip(monoms, coeffs))) + + def get_field(self): + """Returns a field associated with ``self``. """ + return FractionField(self.dom, *self.gens) + + def poly_ring(self, *gens): + """Returns a polynomial ring, i.e. ``K[X]``. """ + raise NotImplementedError('nested domains not allowed') + + def frac_field(self, *gens): + """Returns a fraction field, i.e. ``K(X)``. """ + raise NotImplementedError('nested domains not allowed') + + def revert(self, a): + try: + return self.exquo(self.one, a) + except (ExactQuotientFailed, ZeroDivisionError): + raise NotReversible('%s is not a unit' % a) + + def gcdex(self, a, b): + """Extended GCD of ``a`` and ``b``. """ + return a.gcdex(b) + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + return a.gcd(b) + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + return a.lcm(b) + + def factorial(self, a): + """Returns factorial of ``a``. """ + return self.dtype(self.dom.factorial(a)) + + def _vector_to_sdm(self, v, order): + """ + For internal use by the modules class. + + Convert an iterable of elements of this ring into a sparse distributed + module element. + """ + raise NotImplementedError + + def _sdm_to_dics(self, s, n): + """Helper for _sdm_to_vector.""" + from sympy.polys.distributedmodules import sdm_to_dict + dic = sdm_to_dict(s) + res = [{} for _ in range(n)] + for k, v in dic.items(): + res[k[0]][k[1:]] = v + return res + + def _sdm_to_vector(self, s, n): + """ + For internal use by the modules class. + + Convert a sparse distributed module into a list of length ``n``. + + Examples + ======== + + >>> from sympy import QQ, ilex + >>> from sympy.abc import x, y + >>> R = QQ.old_poly_ring(x, y, order=ilex) + >>> L = [((1, 1, 1), QQ(1)), ((0, 1, 0), QQ(1)), ((0, 0, 1), QQ(2))] + >>> R._sdm_to_vector(L, 2) + [DMF([[1], [2, 0]], [[1]], QQ), DMF([[1, 0], []], [[1]], QQ)] + """ + dics = self._sdm_to_dics(s, n) + # NOTE this works for global and local rings! + return [self(x) for x in dics] + + def free_module(self, rank): + """ + Generate a free module of rank ``rank`` over ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2) + QQ[x]**2 + """ + return FreeModulePolyRing(self, rank) + + +def _vector_to_sdm_helper(v, order): + """Helper method for common code in Global and Local poly rings.""" + from sympy.polys.distributedmodules import sdm_from_dict + d = {} + for i, e in enumerate(v): + for key, value in e.to_dict().items(): + d[(i,) + key] = value + return sdm_from_dict(d, order) + + +@public +class GlobalPolynomialRing(PolynomialRingBase): + """A true polynomial ring, with objects DMP. """ + + is_PolynomialRing = is_Poly = True + dtype = DMP + + def new(self, element): + if isinstance(element, dict): + return DMP.from_dict(element, len(self.gens) - 1, self.dom) + elif element in self.dom: + return self._ground_new(self.dom.convert(element)) + else: + return self.dtype(element, self.dom, len(self.gens) - 1) + + def from_FractionField(K1, a, K0): + """ + Convert a ``DMF`` object to ``DMP``. + + Examples + ======== + + >>> from sympy.polys.polyclasses import DMP, DMF + >>> from sympy.polys.domains import ZZ + >>> from sympy.abc import x + + >>> f = DMF(([ZZ(1), ZZ(1)], [ZZ(1)]), ZZ) + >>> K = ZZ.old_frac_field(x) + + >>> F = ZZ.old_poly_ring(x).from_FractionField(f, K) + + >>> F == DMP([ZZ(1), ZZ(1)], ZZ) + True + >>> type(F) # doctest: +SKIP + + + """ + if a.denom().is_one: + return K1.from_GlobalPolynomialRing(a.numer(), K0) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return basic_from_dict(a.to_sympy_dict(), *self.gens) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + try: + rep, _ = dict_from_basic(a, gens=self.gens) + except PolynomialError: + raise CoercionFailed("Cannot convert %s to type %s" % (a, self)) + + for k, v in rep.items(): + rep[k] = self.dom.from_sympy(v) + + return DMP.from_dict(rep, self.ngens - 1, self.dom) + + def is_positive(self, a): + """Returns True if ``LC(a)`` is positive. """ + return self.dom.is_positive(a.LC()) + + def is_negative(self, a): + """Returns True if ``LC(a)`` is negative. """ + return self.dom.is_negative(a.LC()) + + def is_nonpositive(self, a): + """Returns True if ``LC(a)`` is non-positive. """ + return self.dom.is_nonpositive(a.LC()) + + def is_nonnegative(self, a): + """Returns True if ``LC(a)`` is non-negative. """ + return self.dom.is_nonnegative(a.LC()) + + def _vector_to_sdm(self, v, order): + """ + Examples + ======== + + >>> from sympy import lex, QQ + >>> from sympy.abc import x, y + >>> R = QQ.old_poly_ring(x, y) + >>> f = R.convert(x + 2*y) + >>> g = R.convert(x * y) + >>> R._vector_to_sdm([f, g], lex) + [((1, 1, 1), 1), ((0, 1, 0), 1), ((0, 0, 1), 2)] + """ + return _vector_to_sdm_helper(v, order) + + +class GeneralizedPolynomialRing(PolynomialRingBase): + """A generalized polynomial ring, with objects DMF. """ + + dtype = DMF + + def new(self, a): + """Construct an element of ``self`` domain from ``a``. """ + res = self.dtype(a, self.dom, len(self.gens) - 1) + + # make sure res is actually in our ring + if res.denom().terms(order=self.order)[0][0] != (0,)*len(self.gens): + from sympy.printing.str import sstr + raise CoercionFailed("denominator %s not allowed in %s" + % (sstr(res), self)) + return res + + def __contains__(self, a): + try: + a = self.convert(a) + except CoercionFailed: + return False + return a.denom().terms(order=self.order)[0][0] == (0,)*len(self.gens) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return (basic_from_dict(a.numer().to_sympy_dict(), *self.gens) / + basic_from_dict(a.denom().to_sympy_dict(), *self.gens)) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + p, q = a.as_numer_denom() + + num, _ = dict_from_basic(p, gens=self.gens) + den, _ = dict_from_basic(q, gens=self.gens) + + for k, v in num.items(): + num[k] = self.dom.from_sympy(v) + + for k, v in den.items(): + den[k] = self.dom.from_sympy(v) + + return self((num, den)).cancel() + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``. """ + # Elements are DMF that will always divide (except 0). The result is + # not guaranteed to be in this ring, so we have to check that. + r = a / b + + try: + r = self.new((r.num, r.den)) + except CoercionFailed: + raise ExactQuotientFailed(a, b, self) + + return r + + def from_FractionField(K1, a, K0): + dmf = K1.get_field().from_FractionField(a, K0) + return K1((dmf.num, dmf.den)) + + def _vector_to_sdm(self, v, order): + """ + Turn an iterable into a sparse distributed module. + + Note that the vector is multiplied by a unit first to make all entries + polynomials. + + Examples + ======== + + >>> from sympy import ilex, QQ + >>> from sympy.abc import x, y + >>> R = QQ.old_poly_ring(x, y, order=ilex) + >>> f = R.convert((x + 2*y) / (1 + x)) + >>> g = R.convert(x * y) + >>> R._vector_to_sdm([f, g], ilex) + [((0, 0, 1), 2), ((0, 1, 0), 1), ((1, 1, 1), 1), ((1, + 2, 1), 1)] + """ + # NOTE this is quite inefficient... + u = self.one.numer() + for x in v: + u *= x.denom() + return _vector_to_sdm_helper([x.numer()*u/x.denom() for x in v], order) + + +@public +def PolynomialRing(dom, *gens, **opts): + r""" + Create a generalized multivariate polynomial ring. + + A generalized polynomial ring is defined by a ground field `K`, a set + of generators (typically `x_1, \ldots, x_n`) and a monomial order `<`. + The monomial order can be global, local or mixed. In any case it induces + a total ordering on the monomials, and there exists for every (non-zero) + polynomial `f \in K[x_1, \ldots, x_n]` a well-defined "leading monomial" + `LM(f) = LM(f, >)`. One can then define a multiplicative subset + `S = S_> = \{f \in K[x_1, \ldots, x_n] | LM(f) = 1\}`. The generalized + polynomial ring corresponding to the monomial order is + `R = S^{-1}K[x_1, \ldots, x_n]`. + + If `>` is a so-called global order, that is `1` is the smallest monomial, + then we just have `S = K` and `R = K[x_1, \ldots, x_n]`. + + Examples + ======== + + A few examples may make this clearer. + + >>> from sympy.abc import x, y + >>> from sympy import QQ + + Our first ring uses global lexicographic order. + + >>> R1 = QQ.old_poly_ring(x, y, order=(("lex", x, y),)) + + The second ring uses local lexicographic order. Note that when using a + single (non-product) order, you can just specify the name and omit the + variables: + + >>> R2 = QQ.old_poly_ring(x, y, order="ilex") + + The third and fourth rings use a mixed orders: + + >>> o1 = (("ilex", x), ("lex", y)) + >>> o2 = (("lex", x), ("ilex", y)) + >>> R3 = QQ.old_poly_ring(x, y, order=o1) + >>> R4 = QQ.old_poly_ring(x, y, order=o2) + + We will investigate what elements of `K(x, y)` are contained in the various + rings. + + >>> L = [x, 1/x, y/(1 + x), 1/(1 + y), 1/(1 + x*y)] + >>> test = lambda R: [f in R for f in L] + + The first ring is just `K[x, y]`: + + >>> test(R1) + [True, False, False, False, False] + + The second ring is R1 localised at the maximal ideal (x, y): + + >>> test(R2) + [True, False, True, True, True] + + The third ring is R1 localised at the prime ideal (x): + + >>> test(R3) + [True, False, True, False, True] + + Finally the fourth ring is R1 localised at `S = K[x, y] \setminus yK[y]`: + + >>> test(R4) + [True, False, False, True, False] + """ + + order = opts.get("order", GeneralizedPolynomialRing.default_order) + if iterable(order): + order = build_product_order(order, gens) + order = monomial_key(order) + opts['order'] = order + + if order.is_global: + return GlobalPolynomialRing(dom, *gens, **opts) + else: + return GeneralizedPolynomialRing(dom, *gens, **opts) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/polynomialring.py b/phivenv/Lib/site-packages/sympy/polys/domains/polynomialring.py new file mode 100644 index 0000000000000000000000000000000000000000..daccdcdede4d409e995a79540b0c3f9e8017d2d9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/polynomialring.py @@ -0,0 +1,203 @@ +"""Implementation of :class:`PolynomialRing` class. """ + + +from sympy.polys.domains.ring import Ring +from sympy.polys.domains.compositedomain import CompositeDomain + +from sympy.polys.polyerrors import CoercionFailed, GeneratorsError +from sympy.utilities import public + +@public +class PolynomialRing(Ring, CompositeDomain): + """A class for representing multivariate polynomial rings. """ + + is_PolynomialRing = is_Poly = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self, domain_or_ring, symbols=None, order=None): + from sympy.polys.rings import PolyRing + + if isinstance(domain_or_ring, PolyRing) and symbols is None and order is None: + ring = domain_or_ring + else: + ring = PolyRing(symbols, domain_or_ring, order) + + self.ring = ring + self.dtype = ring.dtype + + self.gens = ring.gens + self.ngens = ring.ngens + self.symbols = ring.symbols + self.domain = ring.domain + + + if symbols: + if ring.domain.is_Field and ring.domain.is_Exact and len(symbols)==1: + self.is_PID = True + + # TODO: remove this + self.dom = self.domain + + def new(self, element): + return self.ring.ring_new(element) + + def of_type(self, element): + """Check if ``a`` is of type ``dtype``. """ + return self.ring.is_element(element) + + @property + def zero(self): + return self.ring.zero + + @property + def one(self): + return self.ring.one + + @property + def order(self): + return self.ring.order + + def __str__(self): + return str(self.domain) + '[' + ','.join(map(str, self.symbols)) + ']' + + def __hash__(self): + return hash((self.__class__.__name__, self.ring, self.domain, self.symbols)) + + def __eq__(self, other): + """Returns `True` if two domains are equivalent. """ + if not isinstance(other, PolynomialRing): + return NotImplemented + return self.ring == other.ring + + def is_unit(self, a): + """Returns ``True`` if ``a`` is a unit of ``self``""" + if not a.is_ground: + return False + K = self.domain + return K.is_unit(K.convert_from(a, self)) + + def canonical_unit(self, a): + u = self.domain.canonical_unit(a.LC) + return self.ring.ground_new(u) + + def to_sympy(self, a): + """Convert `a` to a SymPy object. """ + return a.as_expr() + + def from_sympy(self, a): + """Convert SymPy's expression to `dtype`. """ + return self.ring.from_expr(a) + + def from_ZZ(K1, a, K0): + """Convert a Python `int` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python `int` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python `Fraction` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python `Fraction` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY `mpz` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY `mpq` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a `GaussianInteger` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a `GaussianRational` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath `mpf` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_ComplexField(K1, a, K0): + """Convert a mpmath `mpf` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_AlgebraicField(K1, a, K0): + """Convert an algebraic number to ``dtype``. """ + if K1.domain != K0: + a = K1.domain.convert_from(a, K0) + if a is not None: + return K1.new(a) + + def from_PolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + try: + return a.set_ring(K1.ring) + except (CoercionFailed, GeneratorsError): + return None + + def from_FractionField(K1, a, K0): + """Convert a rational function to ``dtype``. """ + if K1.domain == K0: + return K1.ring.from_list([a]) + + q, r = K0.numer(a).div(K0.denom(a)) + + if r.is_zero: + return K1.from_PolynomialRing(q, K0.field.ring.to_domain()) + else: + return None + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert from old poly ring to ``dtype``. """ + if K1.symbols == K0.gens: + ad = a.to_dict() + if K1.domain != K0.domain: + ad = {m: K1.domain.convert(c) for m, c in ad.items()} + return K1(ad) + elif a.is_ground and K0.domain == K1: + return K1.convert_from(a.to_list()[0], K0.domain) + + def get_field(self): + """Returns a field associated with `self`. """ + return self.ring.to_field().to_domain() + + def is_positive(self, a): + """Returns True if `LC(a)` is positive. """ + return self.domain.is_positive(a.LC) + + def is_negative(self, a): + """Returns True if `LC(a)` is negative. """ + return self.domain.is_negative(a.LC) + + def is_nonpositive(self, a): + """Returns True if `LC(a)` is non-positive. """ + return self.domain.is_nonpositive(a.LC) + + def is_nonnegative(self, a): + """Returns True if `LC(a)` is non-negative. """ + return self.domain.is_nonnegative(a.LC) + + def gcdex(self, a, b): + """Extended GCD of `a` and `b`. """ + return a.gcdex(b) + + def gcd(self, a, b): + """Returns GCD of `a` and `b`. """ + return a.gcd(b) + + def lcm(self, a, b): + """Returns LCM of `a` and `b`. """ + return a.lcm(b) + + def factorial(self, a): + """Returns factorial of `a`. """ + return self.dtype(self.domain.factorial(a)) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/pythonfinitefield.py b/phivenv/Lib/site-packages/sympy/polys/domains/pythonfinitefield.py new file mode 100644 index 0000000000000000000000000000000000000000..44baa4f6d1b43317283041206eaa43e06a5cc8db --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/pythonfinitefield.py @@ -0,0 +1,16 @@ +"""Implementation of :class:`PythonFiniteField` class. """ + + +from sympy.polys.domains.finitefield import FiniteField +from sympy.polys.domains.pythonintegerring import PythonIntegerRing + +from sympy.utilities import public + +@public +class PythonFiniteField(FiniteField): + """Finite field based on Python's integers. """ + + alias = 'FF_python' + + def __init__(self, mod, symmetric=True): + super().__init__(mod, PythonIntegerRing(), symmetric) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/pythonintegerring.py b/phivenv/Lib/site-packages/sympy/polys/domains/pythonintegerring.py new file mode 100644 index 0000000000000000000000000000000000000000..81ee9637a4ebcfaf3c5f11d12c18265305984c25 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/pythonintegerring.py @@ -0,0 +1,98 @@ +"""Implementation of :class:`PythonIntegerRing` class. """ + + +from sympy.core.numbers import int_valued +from sympy.polys.domains.groundtypes import ( + PythonInteger, SymPyInteger, sqrt as python_sqrt, + factorial as python_factorial, python_gcdex, python_gcd, python_lcm, +) +from sympy.polys.domains.integerring import IntegerRing +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class PythonIntegerRing(IntegerRing): + """Integer ring based on Python's ``int`` type. + + This will be used as :ref:`ZZ` if ``gmpy`` and ``gmpy2`` are not + installed. Elements are instances of the standard Python ``int`` type. + """ + + dtype = PythonInteger + zero = dtype(0) + one = dtype(1) + alias = 'ZZ_python' + + def __init__(self): + """Allow instantiation of this domain. """ + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(a) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Integer: + return PythonInteger(a.p) + elif int_valued(a): + return PythonInteger(int(a)) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to Python's ``int``. """ + return K0.to_int(a) + + def from_ZZ_python(K1, a, K0): + """Convert Python's ``int`` to Python's ``int``. """ + return a + + def from_QQ(K1, a, K0): + """Convert Python's ``Fraction`` to Python's ``int``. """ + if a.denominator == 1: + return a.numerator + + def from_QQ_python(K1, a, K0): + """Convert Python's ``Fraction`` to Python's ``int``. """ + if a.denominator == 1: + return a.numerator + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to Python's ``int``. """ + return PythonInteger(K0.to_int(a)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert GMPY's ``mpz`` to Python's ``int``. """ + return PythonInteger(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert GMPY's ``mpq`` to Python's ``int``. """ + if a.denom() == 1: + return PythonInteger(a.numer()) + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to Python's ``int``. """ + p, q = K0.to_rational(a) + + if q == 1: + return PythonInteger(p) + + def gcdex(self, a, b): + """Compute extended GCD of ``a`` and ``b``. """ + return python_gcdex(a, b) + + def gcd(self, a, b): + """Compute GCD of ``a`` and ``b``. """ + return python_gcd(a, b) + + def lcm(self, a, b): + """Compute LCM of ``a`` and ``b``. """ + return python_lcm(a, b) + + def sqrt(self, a): + """Compute square root of ``a``. """ + return python_sqrt(a) + + def factorial(self, a): + """Compute factorial of ``a``. """ + return python_factorial(a) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/pythonrational.py b/phivenv/Lib/site-packages/sympy/polys/domains/pythonrational.py new file mode 100644 index 0000000000000000000000000000000000000000..87b56d6c929c3ce3ce153dce7b3c210821d706a0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/pythonrational.py @@ -0,0 +1,22 @@ +""" +Rational number type based on Python integers. + +The PythonRational class from here has been moved to +sympy.external.pythonmpq + +This module is just left here for backwards compatibility. +""" + + +from sympy.core.numbers import Rational +from sympy.core.sympify import _sympy_converter +from sympy.utilities import public +from sympy.external.pythonmpq import PythonMPQ + + +PythonRational = public(PythonMPQ) + + +def sympify_pythonrational(arg): + return Rational(arg.numerator, arg.denominator) +_sympy_converter[PythonRational] = sympify_pythonrational diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/pythonrationalfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/pythonrationalfield.py new file mode 100644 index 0000000000000000000000000000000000000000..51afaef636f000855d51a69fb93eb416ae1e5347 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/pythonrationalfield.py @@ -0,0 +1,73 @@ +"""Implementation of :class:`PythonRationalField` class. """ + + +from sympy.polys.domains.groundtypes import PythonInteger, PythonRational, SymPyRational +from sympy.polys.domains.rationalfield import RationalField +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class PythonRationalField(RationalField): + """Rational field based on :ref:`MPQ`. + + This will be used as :ref:`QQ` if ``gmpy`` and ``gmpy2`` are not + installed. Elements are instances of :ref:`MPQ`. + """ + + dtype = PythonRational + zero = dtype(0) + one = dtype(1) + alias = 'QQ_python' + + def __init__(self): + pass + + def get_ring(self): + """Returns ring associated with ``self``. """ + from sympy.polys.domains import PythonIntegerRing + return PythonIntegerRing() + + def to_sympy(self, a): + """Convert `a` to a SymPy object. """ + return SymPyRational(a.numerator, a.denominator) + + def from_sympy(self, a): + """Convert SymPy's Rational to `dtype`. """ + if a.is_Rational: + return PythonRational(a.p, a.q) + elif a.is_Float: + from sympy.polys.domains import RR + p, q = RR.to_rational(a) + return PythonRational(int(p), int(q)) + else: + raise CoercionFailed("expected `Rational` object, got %s" % a) + + def from_ZZ_python(K1, a, K0): + """Convert a Python `int` object to `dtype`. """ + return PythonRational(a) + + def from_QQ_python(K1, a, K0): + """Convert a Python `Fraction` object to `dtype`. """ + return a + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY `mpz` object to `dtype`. """ + return PythonRational(PythonInteger(a)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY `mpq` object to `dtype`. """ + return PythonRational(PythonInteger(a.numer()), + PythonInteger(a.denom())) + + def from_RealField(K1, a, K0): + """Convert a mpmath `mpf` object to `dtype`. """ + p, q = K0.to_rational(a) + return PythonRational(int(p), int(q)) + + def numer(self, a): + """Returns numerator of `a`. """ + return a.numerator + + def denom(self, a): + """Returns denominator of `a`. """ + return a.denominator diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/quotientring.py b/phivenv/Lib/site-packages/sympy/polys/domains/quotientring.py new file mode 100644 index 0000000000000000000000000000000000000000..7e8abf6b210a5627c9c139e41248637c9b88931f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/quotientring.py @@ -0,0 +1,202 @@ +"""Implementation of :class:`QuotientRing` class.""" + + +from sympy.polys.agca.modules import FreeModuleQuotientRing +from sympy.polys.domains.ring import Ring +from sympy.polys.polyerrors import NotReversible, CoercionFailed +from sympy.utilities import public + +# TODO +# - successive quotients (when quotient ideals are implemented) +# - poly rings over quotients? +# - division by non-units in integral domains? + +@public +class QuotientRingElement: + """ + Class representing elements of (commutative) quotient rings. + + Attributes: + + - ring - containing ring + - data - element of ring.ring (i.e. base ring) representing self + """ + + def __init__(self, ring, data): + self.ring = ring + self.data = data + + def __str__(self): + from sympy.printing.str import sstr + data = self.ring.ring.to_sympy(self.data) + return sstr(data) + " + " + str(self.ring.base_ideal) + + __repr__ = __str__ + + def __bool__(self): + return not self.ring.is_zero(self) + + def __add__(self, om): + if not isinstance(om, self.__class__) or om.ring != self.ring: + try: + om = self.ring.convert(om) + except (NotImplementedError, CoercionFailed): + return NotImplemented + return self.ring(self.data + om.data) + + __radd__ = __add__ + + def __neg__(self): + return self.ring(self.data*self.ring.ring.convert(-1)) + + def __sub__(self, om): + return self.__add__(-om) + + def __rsub__(self, om): + return (-self).__add__(om) + + def __mul__(self, o): + if not isinstance(o, self.__class__): + try: + o = self.ring.convert(o) + except (NotImplementedError, CoercionFailed): + return NotImplemented + return self.ring(self.data*o.data) + + __rmul__ = __mul__ + + def __rtruediv__(self, o): + return self.ring.revert(self)*o + + def __truediv__(self, o): + if not isinstance(o, self.__class__): + try: + o = self.ring.convert(o) + except (NotImplementedError, CoercionFailed): + return NotImplemented + return self.ring.revert(o)*self + + def __pow__(self, oth): + if oth < 0: + return self.ring.revert(self) ** -oth + return self.ring(self.data ** oth) + + def __eq__(self, om): + if not isinstance(om, self.__class__) or om.ring != self.ring: + return False + return self.ring.is_zero(self - om) + + def __ne__(self, om): + return not self == om + + +class QuotientRing(Ring): + """ + Class representing (commutative) quotient rings. + + You should not usually instantiate this by hand, instead use the constructor + from the base ring in the construction. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> I = QQ.old_poly_ring(x).ideal(x**3 + 1) + >>> QQ.old_poly_ring(x).quotient_ring(I) + QQ[x]/ + + Shorter versions are possible: + + >>> QQ.old_poly_ring(x)/I + QQ[x]/ + + >>> QQ.old_poly_ring(x)/[x**3 + 1] + QQ[x]/ + + Attributes: + + - ring - the base ring + - base_ideal - the ideal used to form the quotient + """ + + has_assoc_Ring = True + has_assoc_Field = False + dtype = QuotientRingElement + + def __init__(self, ring, ideal): + if not ideal.ring == ring: + raise ValueError('Ideal must belong to %s, got %s' % (ring, ideal)) + self.ring = ring + self.base_ideal = ideal + self.zero = self(self.ring.zero) + self.one = self(self.ring.one) + + def __str__(self): + return str(self.ring) + "/" + str(self.base_ideal) + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.ring, self.base_ideal)) + + def new(self, a): + """Construct an element of ``self`` domain from ``a``. """ + if not isinstance(a, self.ring.dtype): + a = self.ring(a) + # TODO optionally disable reduction? + return self.dtype(self, self.base_ideal.reduce_element(a)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, QuotientRing) and \ + self.ring == other.ring and self.base_ideal == other.base_ideal + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.ring.convert(a, K0)) + + from_ZZ_python = from_ZZ + from_QQ_python = from_ZZ_python + from_ZZ_gmpy = from_ZZ_python + from_QQ_gmpy = from_ZZ_python + from_RealField = from_ZZ_python + from_GlobalPolynomialRing = from_ZZ_python + from_FractionField = from_ZZ_python + + def from_sympy(self, a): + return self(self.ring.from_sympy(a)) + + def to_sympy(self, a): + return self.ring.to_sympy(a.data) + + def from_QuotientRing(self, a, K0): + if K0 == self: + return a + + def poly_ring(self, *gens): + """Returns a polynomial ring, i.e. ``K[X]``. """ + raise NotImplementedError('nested domains not allowed') + + def frac_field(self, *gens): + """Returns a fraction field, i.e. ``K(X)``. """ + raise NotImplementedError('nested domains not allowed') + + def revert(self, a): + """ + Compute a**(-1), if possible. + """ + I = self.ring.ideal(a.data) + self.base_ideal + try: + return self(I.in_terms_of_generators(1)[0]) + except ValueError: # 1 not in I + raise NotReversible('%s not a unit in %r' % (a, self)) + + def is_zero(self, a): + return self.base_ideal.contains(a.data) + + def free_module(self, rank): + """ + Generate a free module of rank ``rank`` over ``self``. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> (QQ.old_poly_ring(x)/[x**2 + 1]).free_module(2) + (QQ[x]/)**2 + """ + return FreeModuleQuotientRing(self, rank) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/rationalfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/rationalfield.py new file mode 100644 index 0000000000000000000000000000000000000000..6da570332de8a6d39a21bb3d57447670c7a98441 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/rationalfield.py @@ -0,0 +1,200 @@ +"""Implementation of :class:`RationalField` class. """ + + +from sympy.external.gmpy import MPQ + +from sympy.polys.domains.groundtypes import SymPyRational, is_square, sqrtrem + +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class RationalField(Field, CharacteristicZero, SimpleDomain): + r"""Abstract base class for the domain :ref:`QQ`. + + The :py:class:`RationalField` class represents the field of rational + numbers $\mathbb{Q}$ as a :py:class:`~.Domain` in the domain system. + :py:class:`RationalField` is a superclass of + :py:class:`PythonRationalField` and :py:class:`GMPYRationalField` one of + which will be the implementation for :ref:`QQ` depending on whether either + of ``gmpy`` or ``gmpy2`` is installed or not. + + See also + ======== + + Domain + """ + + rep = 'QQ' + alias = 'QQ' + + is_RationalField = is_QQ = True + is_Numerical = True + + has_assoc_Ring = True + has_assoc_Field = True + + dtype = MPQ + zero = dtype(0) + one = dtype(1) + tp = type(one) + + def __init__(self): + pass + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, RationalField): + return True + else: + return NotImplemented + + def __hash__(self): + """Returns hash code of ``self``. """ + return hash('QQ') + + def get_ring(self): + """Returns ring associated with ``self``. """ + from sympy.polys.domains import ZZ + return ZZ + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyRational(int(a.numerator), int(a.denominator)) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Rational: + return MPQ(a.p, a.q) + elif a.is_Float: + from sympy.polys.domains import RR + return MPQ(*map(int, RR.to_rational(a))) + else: + raise CoercionFailed("expected `Rational` object, got %s" % a) + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `\mathbb{Q}(\alpha, \ldots)`. + + Parameters + ========== + + *extension : One or more :py:class:`~.Expr` + Generators of the extension. These should be expressions that are + algebraic over `\mathbb{Q}`. + + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the + primitive element of the returned :py:class:`~.AlgebraicField`. + + Returns + ======= + + :py:class:`~.AlgebraicField` + A :py:class:`~.Domain` representing the algebraic field extension. + + Examples + ======== + + >>> from sympy import QQ, sqrt + >>> QQ.algebraic_field(sqrt(2)) + QQ + """ + from sympy.polys.domains import AlgebraicField + return AlgebraicField(self, *extension, alias=alias) + + def from_AlgebraicField(K1, a, K0): + """Convert a :py:class:`~.ANP` object to :ref:`QQ`. + + See :py:meth:`~.Domain.convert` + """ + if a.is_ground: + return K1.convert(a.LC(), K0.dom) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return MPQ(a) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return MPQ(a) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return MPQ(a.numerator, a.denominator) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return MPQ(a.numerator, a.denominator) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return MPQ(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return a + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianElement`` object to ``dtype``. """ + if a.y == 0: + return MPQ(a.x) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return MPQ(*map(int, K0.to_rational(a))) + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return MPQ(a) / MPQ(b) + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return MPQ(a) / MPQ(b) + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies nothing. """ + return self.zero + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__truediv__``. """ + return MPQ(a) / MPQ(b), self.zero + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numerator + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denominator + + def is_square(self, a): + """Return ``True`` if ``a`` is a square. + + Explanation + =========== + A rational number is a square if and only if there exists + a rational number ``b`` such that ``b * b == a``. + """ + return is_square(a.numerator) and is_square(a.denominator) + + def exsqrt(self, a): + """Non-negative square root of ``a`` if ``a`` is a square. + + See also + ======== + is_square + """ + if a.numerator < 0: # denominator is always positive + return None + p_sqrt, p_rem = sqrtrem(a.numerator) + if p_rem != 0: + return None + q_sqrt, q_rem = sqrtrem(a.denominator) + if q_rem != 0: + return None + return MPQ(p_sqrt, q_sqrt) + +QQ = RationalField() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/realfield.py b/phivenv/Lib/site-packages/sympy/polys/domains/realfield.py new file mode 100644 index 0000000000000000000000000000000000000000..12f543b2619aa238969ecbe20215d6fd59792904 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/realfield.py @@ -0,0 +1,220 @@ +"""Implementation of :class:`RealField` class. """ + + +from sympy.external.gmpy import SYMPY_INTS, MPQ +from sympy.core.numbers import Float +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +from mpmath import MPContext +from mpmath.libmp import to_rational as _mpmath_to_rational + + +def to_rational(s, max_denom, limit=True): + + p, q = _mpmath_to_rational(s._mpf_) + + # Needed for GROUND_TYPES=flint if gmpy2 is installed because mpmath's + # to_rational() function returns a gmpy2.mpz instance and if MPQ is + # flint.fmpq then MPQ(p, q) will fail. + p = int(p) + q = int(q) + + if not limit or q <= max_denom: + return p, q + + p0, q0, p1, q1 = 0, 1, 1, 0 + n, d = p, q + + while True: + a = n//d + q2 = q0 + a*q1 + if q2 > max_denom: + break + p0, q0, p1, q1 = p1, q1, p0 + a*p1, q2 + n, d = d, n - a*d + + k = (max_denom - q0)//q1 + + number = MPQ(p, q) + bound1 = MPQ(p0 + k*p1, q0 + k*q1) + bound2 = MPQ(p1, q1) + + if not bound2 or not bound1: + return p, q + elif abs(bound2 - number) <= abs(bound1 - number): + return bound2.numerator, bound2.denominator + else: + return bound1.numerator, bound1.denominator + + +@public +class RealField(Field, CharacteristicZero, SimpleDomain): + """Real numbers up to the given precision. """ + + rep = 'RR' + + is_RealField = is_RR = True + + is_Exact = False + is_Numerical = True + is_PID = False + + has_assoc_Ring = False + has_assoc_Field = True + + _default_precision = 53 + + @property + def has_default_precision(self): + return self.precision == self._default_precision + + @property + def precision(self): + return self._context.prec + + @property + def dps(self): + return self._context.dps + + @property + def tolerance(self): + return self._tolerance + + def __init__(self, prec=None, dps=None, tol=None): + # XXX: The tol parameter is ignored but is kept for now for backwards + # compatibility. + + context = MPContext() + + if prec is None and dps is None: + context.prec = self._default_precision + elif dps is None: + context.prec = prec + elif prec is None: + context.dps = dps + else: + raise TypeError("Cannot set both prec and dps") + + self._context = context + + self._dtype = context.mpf + self.zero = self.dtype(0) + self.one = self.dtype(1) + + # Only max_denom here is used for anything and is only used for + # to_rational. + self._max_denom = max(2**context.prec // 200, 99) + self._tolerance = self.one / self._max_denom + + @property + def tp(self): + # XXX: Domain treats tp as an alias of dtype. Here we need to two + # separate things: dtype is a callable to make/convert instances. + # We use tp with isinstance to check if an object is an instance + # of the domain already. + return self._dtype + + def dtype(self, arg): + # XXX: This is needed because mpmath does not recognise fmpz. + # It might be better to add conversion routines to mpmath and if that + # happens then this can be removed. + if isinstance(arg, SYMPY_INTS): + arg = int(arg) + return self._dtype(arg) + + def __eq__(self, other): + return isinstance(other, RealField) and self.precision == other.precision + + def __hash__(self): + return hash((self.__class__.__name__, self._dtype, self.precision)) + + def to_sympy(self, element): + """Convert ``element`` to SymPy number. """ + return Float(element, self.dps) + + def from_sympy(self, expr): + """Convert SymPy's number to ``dtype``. """ + number = expr.evalf(n=self.dps) + + if number.is_Number: + return self.dtype(number) + else: + raise CoercionFailed("expected real number, got %s" % expr) + + def from_ZZ(self, element, base): + return self.dtype(element) + + def from_ZZ_python(self, element, base): + return self.dtype(element) + + def from_ZZ_gmpy(self, element, base): + return self.dtype(int(element)) + + # XXX: We need to convert the denominators to int here because mpmath does + # not recognise mpz. Ideally mpmath would handle this and if it changed to + # do so then the calls to int here could be removed. + + def from_QQ(self, element, base): + return self.dtype(element.numerator) / int(element.denominator) + + def from_QQ_python(self, element, base): + return self.dtype(element.numerator) / int(element.denominator) + + def from_QQ_gmpy(self, element, base): + return self.dtype(int(element.numerator)) / int(element.denominator) + + def from_AlgebraicField(self, element, base): + return self.from_sympy(base.to_sympy(element).evalf(self.dps)) + + def from_RealField(self, element, base): + return self.dtype(element) + + def from_ComplexField(self, element, base): + if not element.imag: + return self.dtype(element.real) + + def to_rational(self, element, limit=True): + """Convert a real number to rational number. """ + return to_rational(element, self._max_denom, limit=limit) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self + + def get_exact(self): + """Returns an exact domain associated with ``self``. """ + from sympy.polys.domains import QQ + return QQ + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + return self.one + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + return a*b + + def almosteq(self, a, b, tolerance=None): + """Check if ``a`` and ``b`` are almost equal. """ + return self._context.almosteq(a, b, tolerance) + + def is_square(self, a): + """Returns ``True`` if ``a >= 0`` and ``False`` otherwise. """ + return a >= 0 + + def exsqrt(self, a): + """Non-negative square root for ``a >= 0`` and ``None`` otherwise. + + Explanation + =========== + The square root may be slightly inaccurate due to floating point + rounding error. + """ + return a ** 0.5 if a >= 0 else None + + +RR = RealField() diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/ring.py b/phivenv/Lib/site-packages/sympy/polys/domains/ring.py new file mode 100644 index 0000000000000000000000000000000000000000..c69e6944d8f51e4b319609368a476e6e847ae126 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/ring.py @@ -0,0 +1,118 @@ +"""Implementation of :class:`Ring` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.polys.polyerrors import ExactQuotientFailed, NotInvertible, NotReversible + +from sympy.utilities import public + +@public +class Ring(Domain): + """Represents a ring domain. """ + + is_Ring = True + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__floordiv__``. """ + if a % b: + raise ExactQuotientFailed(a, b, self) + else: + return a // b + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__floordiv__``. """ + return a // b + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies ``__mod__``. """ + return a % b + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__divmod__``. """ + return divmod(a, b) + + def invert(self, a, b): + """Returns inversion of ``a mod b``. """ + s, t, h = self.gcdex(a, b) + + if self.is_one(h): + return s % b + else: + raise NotInvertible("zero divisor") + + def revert(self, a): + """Returns ``a**(-1)`` if possible. """ + if self.is_one(a) or self.is_one(-a): + return a + else: + raise NotReversible('only units are reversible in a ring') + + def is_unit(self, a): + try: + self.revert(a) + return True + except NotReversible: + return False + + def numer(self, a): + """Returns numerator of ``a``. """ + return a + + def denom(self, a): + """Returns denominator of `a`. """ + return self.one + + def free_module(self, rank): + """ + Generate a free module of rank ``rank`` over self. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2) + QQ[x]**2 + """ + raise NotImplementedError + + def ideal(self, *gens): + """ + Generate an ideal of ``self``. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).ideal(x**2) + + """ + from sympy.polys.agca.ideals import ModuleImplementedIdeal + return ModuleImplementedIdeal(self, self.free_module(1).submodule( + *[[x] for x in gens])) + + def quotient_ring(self, e): + """ + Form a quotient ring of ``self``. + + Here ``e`` can be an ideal or an iterable. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).quotient_ring(QQ.old_poly_ring(x).ideal(x**2)) + QQ[x]/ + >>> QQ.old_poly_ring(x).quotient_ring([x**2]) + QQ[x]/ + + The division operator has been overloaded for this: + + >>> QQ.old_poly_ring(x)/[x**2] + QQ[x]/ + """ + from sympy.polys.agca.ideals import Ideal + from sympy.polys.domains.quotientring import QuotientRing + if not isinstance(e, Ideal): + e = self.ideal(*e) + return QuotientRing(self, e) + + def __truediv__(self, e): + return self.quotient_ring(e) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/simpledomain.py b/phivenv/Lib/site-packages/sympy/polys/domains/simpledomain.py new file mode 100644 index 0000000000000000000000000000000000000000..88cf634555d8bd9229d7fc511af3cf96fececbb8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/simpledomain.py @@ -0,0 +1,15 @@ +"""Implementation of :class:`SimpleDomain` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.utilities import public + +@public +class SimpleDomain(Domain): + """Base class for simple domains, e.g. ZZ, QQ. """ + + is_Simple = True + + def inject(self, *gens): + """Inject generators into this domain. """ + return self.poly_ring(*gens) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/__init__.py b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705e6fc37bcfefcf10a0f5b30209e7506871fc53 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_domains.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_domains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d17a132a25c07bc0035f98b223467f4b6753de5a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_domains.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_polynomialring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_polynomialring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b27671a9149c40a2415f60fa6962b43871725140 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_polynomialring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_quotientring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_quotientring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3851f623b9c1612cb4dac75f3030ada8bb376d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/domains/tests/__pycache__/test_quotientring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_domains.py b/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..403cb37a4f093517183345f0b53fc5253f6756bd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_domains.py @@ -0,0 +1,1434 @@ +"""Tests for classes defining properties of ground domains, e.g. ZZ, QQ, ZZ[x] ... """ + +from sympy.external.gmpy import GROUND_TYPES + +from sympy.core.numbers import (AlgebraicNumber, E, Float, I, Integer, + Rational, oo, pi, _illegal) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import Poly +from sympy.abc import x, y, z + +from sympy.polys.domains import (ZZ, QQ, RR, CC, FF, GF, EX, EXRAW, ZZ_gmpy, + ZZ_python, QQ_gmpy, QQ_python) +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.gaussiandomains import ZZ_I, QQ_I +from sympy.polys.domains.polynomialring import PolynomialRing +from sympy.polys.domains.realfield import RealField + +from sympy.polys.numberfields.subfield import field_isomorphism +from sympy.polys.rings import ring, PolyElement +from sympy.polys.specialpolys import cyclotomic_poly +from sympy.polys.fields import field, FracElement + +from sympy.polys.agca.extensions import FiniteExtension + +from sympy.polys.polyerrors import ( + UnificationFailed, + GeneratorsError, + CoercionFailed, + NotInvertible, + DomainError) + +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from itertools import product + +ALG = QQ.algebraic_field(sqrt(2), sqrt(3)) + +def unify(K0, K1): + return K0.unify(K1) + +def test_Domain_unify(): + F3 = GF(3) + F5 = GF(5) + + assert unify(F3, F3) == F3 + raises(UnificationFailed, lambda: unify(F3, ZZ)) + raises(UnificationFailed, lambda: unify(F3, QQ)) + raises(UnificationFailed, lambda: unify(F3, ZZ_I)) + raises(UnificationFailed, lambda: unify(F3, QQ_I)) + raises(UnificationFailed, lambda: unify(F3, ALG)) + raises(UnificationFailed, lambda: unify(F3, RR)) + raises(UnificationFailed, lambda: unify(F3, CC)) + raises(UnificationFailed, lambda: unify(F3, ZZ[x])) + raises(UnificationFailed, lambda: unify(F3, ZZ.frac_field(x))) + raises(UnificationFailed, lambda: unify(F3, EX)) + + assert unify(F5, F5) == F5 + raises(UnificationFailed, lambda: unify(F5, F3)) + raises(UnificationFailed, lambda: unify(F5, F3[x])) + raises(UnificationFailed, lambda: unify(F5, F3.frac_field(x))) + + raises(UnificationFailed, lambda: unify(ZZ, F3)) + assert unify(ZZ, ZZ) == ZZ + assert unify(ZZ, QQ) == QQ + assert unify(ZZ, ALG) == ALG + assert unify(ZZ, RR) == RR + assert unify(ZZ, CC) == CC + assert unify(ZZ, ZZ[x]) == ZZ[x] + assert unify(ZZ, ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ, EX) == EX + + raises(UnificationFailed, lambda: unify(QQ, F3)) + assert unify(QQ, ZZ) == QQ + assert unify(QQ, QQ) == QQ + assert unify(QQ, ALG) == ALG + assert unify(QQ, RR) == RR + assert unify(QQ, CC) == CC + assert unify(QQ, ZZ[x]) == QQ[x] + assert unify(QQ, ZZ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ, EX) == EX + + raises(UnificationFailed, lambda: unify(ZZ_I, F3)) + assert unify(ZZ_I, ZZ) == ZZ_I + assert unify(ZZ_I, ZZ_I) == ZZ_I + assert unify(ZZ_I, QQ) == QQ_I + assert unify(ZZ_I, ALG) == QQ.algebraic_field(I, sqrt(2), sqrt(3)) + assert unify(ZZ_I, RR) == CC + assert unify(ZZ_I, CC) == CC + assert unify(ZZ_I, ZZ[x]) == ZZ_I[x] + assert unify(ZZ_I, ZZ_I[x]) == ZZ_I[x] + assert unify(ZZ_I, ZZ.frac_field(x)) == ZZ_I.frac_field(x) + assert unify(ZZ_I, ZZ_I.frac_field(x)) == ZZ_I.frac_field(x) + assert unify(ZZ_I, EX) == EX + + raises(UnificationFailed, lambda: unify(QQ_I, F3)) + assert unify(QQ_I, ZZ) == QQ_I + assert unify(QQ_I, ZZ_I) == QQ_I + assert unify(QQ_I, QQ) == QQ_I + assert unify(QQ_I, ALG) == QQ.algebraic_field(I, sqrt(2), sqrt(3)) + assert unify(QQ_I, RR) == CC + assert unify(QQ_I, CC) == CC + assert unify(QQ_I, ZZ[x]) == QQ_I[x] + assert unify(QQ_I, ZZ_I[x]) == QQ_I[x] + assert unify(QQ_I, QQ[x]) == QQ_I[x] + assert unify(QQ_I, QQ_I[x]) == QQ_I[x] + assert unify(QQ_I, ZZ.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, ZZ_I.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, QQ.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, QQ_I.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, EX) == EX + + raises(UnificationFailed, lambda: unify(RR, F3)) + assert unify(RR, ZZ) == RR + assert unify(RR, QQ) == RR + assert unify(RR, ALG) == RR + assert unify(RR, RR) == RR + assert unify(RR, CC) == CC + assert unify(RR, ZZ[x]) == RR[x] + assert unify(RR, ZZ.frac_field(x)) == RR.frac_field(x) + assert unify(RR, EX) == EX + assert RR[x].unify(ZZ.frac_field(y)) == RR.frac_field(x, y) + + raises(UnificationFailed, lambda: unify(CC, F3)) + assert unify(CC, ZZ) == CC + assert unify(CC, QQ) == CC + assert unify(CC, ALG) == CC + assert unify(CC, RR) == CC + assert unify(CC, CC) == CC + assert unify(CC, ZZ[x]) == CC[x] + assert unify(CC, ZZ.frac_field(x)) == CC.frac_field(x) + assert unify(CC, EX) == EX + + raises(UnificationFailed, lambda: unify(ZZ[x], F3)) + assert unify(ZZ[x], ZZ) == ZZ[x] + assert unify(ZZ[x], QQ) == QQ[x] + assert unify(ZZ[x], ALG) == ALG[x] + assert unify(ZZ[x], RR) == RR[x] + assert unify(ZZ[x], CC) == CC[x] + assert unify(ZZ[x], ZZ[x]) == ZZ[x] + assert unify(ZZ[x], ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ[x], EX) == EX + + raises(UnificationFailed, lambda: unify(ZZ.frac_field(x), F3)) + assert unify(ZZ.frac_field(x), ZZ) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ) == QQ.frac_field(x) + assert unify(ZZ.frac_field(x), ALG) == ALG.frac_field(x) + assert unify(ZZ.frac_field(x), RR) == RR.frac_field(x) + assert unify(ZZ.frac_field(x), CC) == CC.frac_field(x) + assert unify(ZZ.frac_field(x), ZZ[x]) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), EX) == EX + + raises(UnificationFailed, lambda: unify(EX, F3)) + assert unify(EX, ZZ) == EX + assert unify(EX, QQ) == EX + assert unify(EX, ALG) == EX + assert unify(EX, RR) == EX + assert unify(EX, CC) == EX + assert unify(EX, ZZ[x]) == EX + assert unify(EX, ZZ.frac_field(x)) == EX + assert unify(EX, EX) == EX + +def test_Domain_unify_composite(): + assert unify(ZZ.poly_ring(x), ZZ) == ZZ.poly_ring(x) + assert unify(ZZ.poly_ring(x), QQ) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), ZZ) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), QQ) == QQ.poly_ring(x) + + assert unify(ZZ, ZZ.poly_ring(x)) == ZZ.poly_ring(x) + assert unify(QQ, ZZ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(ZZ, QQ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(QQ, QQ.poly_ring(x)) == QQ.poly_ring(x) + + assert unify(ZZ.poly_ring(x, y), ZZ) == ZZ.poly_ring(x, y) + assert unify(ZZ.poly_ring(x, y), QQ) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), ZZ) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), QQ) == QQ.poly_ring(x, y) + + assert unify(ZZ, ZZ.poly_ring(x, y)) == ZZ.poly_ring(x, y) + assert unify(QQ, ZZ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(ZZ, QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(QQ, QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + + assert unify(ZZ.frac_field(x), ZZ) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), ZZ) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), QQ) == QQ.frac_field(x) + + assert unify(ZZ, ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(QQ, ZZ.frac_field(x)) == QQ.frac_field(x) + assert unify(ZZ, QQ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ, QQ.frac_field(x)) == QQ.frac_field(x) + + assert unify(ZZ.frac_field(x, y), ZZ) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x, y), QQ) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), ZZ) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), QQ) == QQ.frac_field(x, y) + + assert unify(ZZ, ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ, ZZ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(ZZ, QQ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(QQ, QQ.frac_field(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.poly_ring(x), ZZ.poly_ring(x)) == ZZ.poly_ring(x) + assert unify(ZZ.poly_ring(x), QQ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), ZZ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), QQ.poly_ring(x)) == QQ.poly_ring(x) + + assert unify(ZZ.poly_ring(x, y), ZZ.poly_ring(x)) == ZZ.poly_ring(x, y) + assert unify(ZZ.poly_ring(x, y), QQ.poly_ring(x)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), ZZ.poly_ring(x)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), QQ.poly_ring(x)) == QQ.poly_ring(x, y) + + assert unify(ZZ.poly_ring(x), ZZ.poly_ring(x, y)) == ZZ.poly_ring(x, y) + assert unify(ZZ.poly_ring(x), QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x), ZZ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x), QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + + assert unify(ZZ.poly_ring(x, y), ZZ.poly_ring(x, z)) == ZZ.poly_ring(x, y, z) + assert unify(ZZ.poly_ring(x, y), QQ.poly_ring(x, z)) == QQ.poly_ring(x, y, z) + assert unify(QQ.poly_ring(x, y), ZZ.poly_ring(x, z)) == QQ.poly_ring(x, y, z) + assert unify(QQ.poly_ring(x, y), QQ.poly_ring(x, z)) == QQ.poly_ring(x, y, z) + + assert unify(ZZ.frac_field(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), ZZ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), QQ.frac_field(x)) == QQ.frac_field(x) + + assert unify(ZZ.frac_field(x, y), ZZ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x, y), QQ.frac_field(x)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), ZZ.frac_field(x)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), QQ.frac_field(x)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x), ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x), QQ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x), ZZ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x), QQ.frac_field(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x, y), ZZ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(ZZ.frac_field(x, y), QQ.frac_field(x, z)) == QQ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), ZZ.frac_field(x, z)) == QQ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), QQ.frac_field(x, z)) == QQ.frac_field(x, y, z) + + assert unify(ZZ.poly_ring(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ.poly_ring(x), QQ.frac_field(x)) == ZZ.frac_field(x) + assert unify(QQ.poly_ring(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(QQ.poly_ring(x), QQ.frac_field(x)) == QQ.frac_field(x) + + assert unify(ZZ.poly_ring(x, y), ZZ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(ZZ.poly_ring(x, y), QQ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x, y), ZZ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x, y), QQ.frac_field(x)) == QQ.frac_field(x, y) + + assert unify(ZZ.poly_ring(x), ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(ZZ.poly_ring(x), QQ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x), ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x), QQ.frac_field(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.poly_ring(x, y), ZZ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(ZZ.poly_ring(x, y), QQ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.poly_ring(x, y), ZZ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.poly_ring(x, y), QQ.frac_field(x, z)) == QQ.frac_field(x, y, z) + + assert unify(ZZ.frac_field(x), ZZ.poly_ring(x)) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ.poly_ring(x)) == ZZ.frac_field(x) + assert unify(QQ.frac_field(x), ZZ.poly_ring(x)) == ZZ.frac_field(x) + assert unify(QQ.frac_field(x), QQ.poly_ring(x)) == QQ.frac_field(x) + + assert unify(ZZ.frac_field(x, y), ZZ.poly_ring(x)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x, y), QQ.poly_ring(x)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), ZZ.poly_ring(x)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), QQ.poly_ring(x)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x), ZZ.poly_ring(x, y)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x), QQ.poly_ring(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x), ZZ.poly_ring(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x), QQ.poly_ring(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x, y), ZZ.poly_ring(x, z)) == ZZ.frac_field(x, y, z) + assert unify(ZZ.frac_field(x, y), QQ.poly_ring(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), ZZ.poly_ring(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), QQ.poly_ring(x, z)) == QQ.frac_field(x, y, z) + +def test_Domain_unify_algebraic(): + sqrt5 = QQ.algebraic_field(sqrt(5)) + sqrt7 = QQ.algebraic_field(sqrt(7)) + sqrt57 = QQ.algebraic_field(sqrt(5), sqrt(7)) + + assert sqrt5.unify(sqrt7) == sqrt57 + + assert sqrt5.unify(sqrt5[x, y]) == sqrt5[x, y] + assert sqrt5[x, y].unify(sqrt5) == sqrt5[x, y] + + assert sqrt5.unify(sqrt5.frac_field(x, y)) == sqrt5.frac_field(x, y) + assert sqrt5.frac_field(x, y).unify(sqrt5) == sqrt5.frac_field(x, y) + + assert sqrt5.unify(sqrt7[x, y]) == sqrt57[x, y] + assert sqrt5[x, y].unify(sqrt7) == sqrt57[x, y] + + assert sqrt5.unify(sqrt7.frac_field(x, y)) == sqrt57.frac_field(x, y) + assert sqrt5.frac_field(x, y).unify(sqrt7) == sqrt57.frac_field(x, y) + +def test_Domain_unify_FiniteExtension(): + KxZZ = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ)) + KxQQ = FiniteExtension(Poly(x**2 - 2, x, domain=QQ)) + KxZZy = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ[y])) + KxQQy = FiniteExtension(Poly(x**2 - 2, x, domain=QQ[y])) + + assert KxZZ.unify(KxZZ) == KxZZ + assert KxQQ.unify(KxQQ) == KxQQ + assert KxZZy.unify(KxZZy) == KxZZy + assert KxQQy.unify(KxQQy) == KxQQy + + assert KxZZ.unify(ZZ) == KxZZ + assert KxZZ.unify(QQ) == KxQQ + assert KxQQ.unify(ZZ) == KxQQ + assert KxQQ.unify(QQ) == KxQQ + + assert KxZZ.unify(ZZ[y]) == KxZZy + assert KxZZ.unify(QQ[y]) == KxQQy + assert KxQQ.unify(ZZ[y]) == KxQQy + assert KxQQ.unify(QQ[y]) == KxQQy + + assert KxZZy.unify(ZZ) == KxZZy + assert KxZZy.unify(QQ) == KxQQy + assert KxQQy.unify(ZZ) == KxQQy + assert KxQQy.unify(QQ) == KxQQy + + assert KxZZy.unify(ZZ[y]) == KxZZy + assert KxZZy.unify(QQ[y]) == KxQQy + assert KxQQy.unify(ZZ[y]) == KxQQy + assert KxQQy.unify(QQ[y]) == KxQQy + + K = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ[y])) + assert K.unify(ZZ) == K + assert K.unify(ZZ[x]) == K + assert K.unify(ZZ[y]) == K + assert K.unify(ZZ[x, y]) == K + + Kz = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ[y, z])) + assert K.unify(ZZ[z]) == Kz + assert K.unify(ZZ[x, z]) == Kz + assert K.unify(ZZ[y, z]) == Kz + assert K.unify(ZZ[x, y, z]) == Kz + + Kx = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ)) + Ky = FiniteExtension(Poly(y**2 - 2, y, domain=ZZ)) + Kxy = FiniteExtension(Poly(y**2 - 2, y, domain=Kx)) + assert Kx.unify(Kx) == Kx + assert Ky.unify(Ky) == Ky + assert Kx.unify(Ky) == Kxy + assert Ky.unify(Kx) == Kxy + +def test_Domain_unify_with_symbols(): + raises(UnificationFailed, lambda: ZZ[x, y].unify_with_symbols(ZZ, (y, z))) + raises(UnificationFailed, lambda: ZZ.unify_with_symbols(ZZ[x, y], (y, z))) + +def test_Domain__contains__(): + assert (0 in EX) is True + assert (0 in ZZ) is True + assert (0 in QQ) is True + assert (0 in RR) is True + assert (0 in CC) is True + assert (0 in ALG) is True + assert (0 in ZZ[x, y]) is True + assert (0 in QQ[x, y]) is True + assert (0 in RR[x, y]) is True + + assert (-7 in EX) is True + assert (-7 in ZZ) is True + assert (-7 in QQ) is True + assert (-7 in RR) is True + assert (-7 in CC) is True + assert (-7 in ALG) is True + assert (-7 in ZZ[x, y]) is True + assert (-7 in QQ[x, y]) is True + assert (-7 in RR[x, y]) is True + + assert (17 in EX) is True + assert (17 in ZZ) is True + assert (17 in QQ) is True + assert (17 in RR) is True + assert (17 in CC) is True + assert (17 in ALG) is True + assert (17 in ZZ[x, y]) is True + assert (17 in QQ[x, y]) is True + assert (17 in RR[x, y]) is True + + assert (Rational(-1, 7) in EX) is True + assert (Rational(-1, 7) in ZZ) is False + assert (Rational(-1, 7) in QQ) is True + assert (Rational(-1, 7) in RR) is True + assert (Rational(-1, 7) in CC) is True + assert (Rational(-1, 7) in ALG) is True + assert (Rational(-1, 7) in ZZ[x, y]) is False + assert (Rational(-1, 7) in QQ[x, y]) is True + assert (Rational(-1, 7) in RR[x, y]) is True + + assert (Rational(3, 5) in EX) is True + assert (Rational(3, 5) in ZZ) is False + assert (Rational(3, 5) in QQ) is True + assert (Rational(3, 5) in RR) is True + assert (Rational(3, 5) in CC) is True + assert (Rational(3, 5) in ALG) is True + assert (Rational(3, 5) in ZZ[x, y]) is False + assert (Rational(3, 5) in QQ[x, y]) is True + assert (Rational(3, 5) in RR[x, y]) is True + + assert (3.0 in EX) is True + assert (3.0 in ZZ) is True + assert (3.0 in QQ) is True + assert (3.0 in RR) is True + assert (3.0 in CC) is True + assert (3.0 in ALG) is True + assert (3.0 in ZZ[x, y]) is True + assert (3.0 in QQ[x, y]) is True + assert (3.0 in RR[x, y]) is True + + assert (3.14 in EX) is True + assert (3.14 in ZZ) is False + assert (3.14 in QQ) is True + assert (3.14 in RR) is True + assert (3.14 in CC) is True + assert (3.14 in ALG) is True + assert (3.14 in ZZ[x, y]) is False + assert (3.14 in QQ[x, y]) is True + assert (3.14 in RR[x, y]) is True + + assert (oo in ALG) is False + assert (oo in ZZ[x, y]) is False + assert (oo in QQ[x, y]) is False + + assert (-oo in ZZ) is False + assert (-oo in QQ) is False + assert (-oo in ALG) is False + assert (-oo in ZZ[x, y]) is False + assert (-oo in QQ[x, y]) is False + + assert (sqrt(7) in EX) is True + assert (sqrt(7) in ZZ) is False + assert (sqrt(7) in QQ) is False + assert (sqrt(7) in RR) is True + assert (sqrt(7) in CC) is True + assert (sqrt(7) in ALG) is False + assert (sqrt(7) in ZZ[x, y]) is False + assert (sqrt(7) in QQ[x, y]) is False + assert (sqrt(7) in RR[x, y]) is True + + assert (2*sqrt(3) + 1 in EX) is True + assert (2*sqrt(3) + 1 in ZZ) is False + assert (2*sqrt(3) + 1 in QQ) is False + assert (2*sqrt(3) + 1 in RR) is True + assert (2*sqrt(3) + 1 in CC) is True + assert (2*sqrt(3) + 1 in ALG) is True + assert (2*sqrt(3) + 1 in ZZ[x, y]) is False + assert (2*sqrt(3) + 1 in QQ[x, y]) is False + assert (2*sqrt(3) + 1 in RR[x, y]) is True + + assert (sin(1) in EX) is True + assert (sin(1) in ZZ) is False + assert (sin(1) in QQ) is False + assert (sin(1) in RR) is True + assert (sin(1) in CC) is True + assert (sin(1) in ALG) is False + assert (sin(1) in ZZ[x, y]) is False + assert (sin(1) in QQ[x, y]) is False + assert (sin(1) in RR[x, y]) is True + + assert (x**2 + 1 in EX) is True + assert (x**2 + 1 in ZZ) is False + assert (x**2 + 1 in QQ) is False + assert (x**2 + 1 in RR) is False + assert (x**2 + 1 in CC) is False + assert (x**2 + 1 in ALG) is False + assert (x**2 + 1 in ZZ[x]) is True + assert (x**2 + 1 in QQ[x]) is True + assert (x**2 + 1 in RR[x]) is True + assert (x**2 + 1 in ZZ[x, y]) is True + assert (x**2 + 1 in QQ[x, y]) is True + assert (x**2 + 1 in RR[x, y]) is True + + assert (x**2 + y**2 in EX) is True + assert (x**2 + y**2 in ZZ) is False + assert (x**2 + y**2 in QQ) is False + assert (x**2 + y**2 in RR) is False + assert (x**2 + y**2 in CC) is False + assert (x**2 + y**2 in ALG) is False + assert (x**2 + y**2 in ZZ[x]) is False + assert (x**2 + y**2 in QQ[x]) is False + assert (x**2 + y**2 in RR[x]) is False + assert (x**2 + y**2 in ZZ[x, y]) is True + assert (x**2 + y**2 in QQ[x, y]) is True + assert (x**2 + y**2 in RR[x, y]) is True + + assert (Rational(3, 2)*x/(y + 1) - z in QQ[x, y, z]) is False + + +def test_issue_14433(): + assert (Rational(2, 3)*x in QQ.frac_field(1/x)) is True + assert (1/x in QQ.frac_field(x)) is True + assert ((x**2 + y**2) in QQ.frac_field(1/x, 1/y)) is True + assert ((x + y) in QQ.frac_field(1/x, y)) is True + assert ((x - y) in QQ.frac_field(x, 1/y)) is True + + +def test_Domain_is_field(): + assert ZZ.is_Field is False + assert GF(5).is_Field is True + assert GF(6).is_Field is False + assert QQ.is_Field is True + assert RR.is_Field is True + assert CC.is_Field is True + assert EX.is_Field is True + assert ALG.is_Field is True + assert QQ[x].is_Field is False + assert ZZ.frac_field(x).is_Field is True + + +def test_Domain_get_ring(): + assert ZZ.has_assoc_Ring is True + assert QQ.has_assoc_Ring is True + assert ZZ[x].has_assoc_Ring is True + assert QQ[x].has_assoc_Ring is True + assert ZZ[x, y].has_assoc_Ring is True + assert QQ[x, y].has_assoc_Ring is True + assert ZZ.frac_field(x).has_assoc_Ring is True + assert QQ.frac_field(x).has_assoc_Ring is True + assert ZZ.frac_field(x, y).has_assoc_Ring is True + assert QQ.frac_field(x, y).has_assoc_Ring is True + + assert EX.has_assoc_Ring is False + assert RR.has_assoc_Ring is False + assert ALG.has_assoc_Ring is False + + assert ZZ.get_ring() == ZZ + assert QQ.get_ring() == ZZ + assert ZZ[x].get_ring() == ZZ[x] + assert QQ[x].get_ring() == QQ[x] + assert ZZ[x, y].get_ring() == ZZ[x, y] + assert QQ[x, y].get_ring() == QQ[x, y] + assert ZZ.frac_field(x).get_ring() == ZZ[x] + assert QQ.frac_field(x).get_ring() == QQ[x] + assert ZZ.frac_field(x, y).get_ring() == ZZ[x, y] + assert QQ.frac_field(x, y).get_ring() == QQ[x, y] + + assert EX.get_ring() == EX + + assert RR.get_ring() == RR + # XXX: This should also be like RR + raises(DomainError, lambda: ALG.get_ring()) + + +def test_Domain_get_field(): + assert EX.has_assoc_Field is True + assert ZZ.has_assoc_Field is True + assert QQ.has_assoc_Field is True + assert RR.has_assoc_Field is True + assert ALG.has_assoc_Field is True + assert ZZ[x].has_assoc_Field is True + assert QQ[x].has_assoc_Field is True + assert ZZ[x, y].has_assoc_Field is True + assert QQ[x, y].has_assoc_Field is True + + assert EX.get_field() == EX + assert ZZ.get_field() == QQ + assert QQ.get_field() == QQ + assert RR.get_field() == RR + assert ALG.get_field() == ALG + assert ZZ[x].get_field() == ZZ.frac_field(x) + assert QQ[x].get_field() == QQ.frac_field(x) + assert ZZ[x, y].get_field() == ZZ.frac_field(x, y) + assert QQ[x, y].get_field() == QQ.frac_field(x, y) + + +def test_Domain_set_domain(): + doms = [GF(5), ZZ, QQ, ALG, RR, CC, EX, ZZ[z], QQ[z], RR[z], CC[z], EX[z]] + for D1 in doms: + for D2 in doms: + assert D1[x].set_domain(D2) == D2[x] + assert D1[x, y].set_domain(D2) == D2[x, y] + assert D1.frac_field(x).set_domain(D2) == D2.frac_field(x) + assert D1.frac_field(x, y).set_domain(D2) == D2.frac_field(x, y) + assert D1.old_poly_ring(x).set_domain(D2) == D2.old_poly_ring(x) + assert D1.old_poly_ring(x, y).set_domain(D2) == D2.old_poly_ring(x, y) + assert D1.old_frac_field(x).set_domain(D2) == D2.old_frac_field(x) + assert D1.old_frac_field(x, y).set_domain(D2) == D2.old_frac_field(x, y) + + +def test_Domain_is_Exact(): + exact = [GF(5), ZZ, QQ, ALG, EX] + inexact = [RR, CC] + for D in exact + inexact: + for R in D, D[x], D.frac_field(x), D.old_poly_ring(x), D.old_frac_field(x): + if D in exact: + assert R.is_Exact is True + else: + assert R.is_Exact is False + + +def test_Domain_get_exact(): + assert EX.get_exact() == EX + assert ZZ.get_exact() == ZZ + assert QQ.get_exact() == QQ + assert RR.get_exact() == QQ + assert CC.get_exact() == QQ_I + assert ALG.get_exact() == ALG + assert ZZ[x].get_exact() == ZZ[x] + assert QQ[x].get_exact() == QQ[x] + assert RR[x].get_exact() == QQ[x] + assert CC[x].get_exact() == QQ_I[x] + assert ZZ[x, y].get_exact() == ZZ[x, y] + assert QQ[x, y].get_exact() == QQ[x, y] + assert RR[x, y].get_exact() == QQ[x, y] + assert CC[x, y].get_exact() == QQ_I[x, y] + assert ZZ.frac_field(x).get_exact() == ZZ.frac_field(x) + assert QQ.frac_field(x).get_exact() == QQ.frac_field(x) + assert RR.frac_field(x).get_exact() == QQ.frac_field(x) + assert CC.frac_field(x).get_exact() == QQ_I.frac_field(x) + assert ZZ.frac_field(x, y).get_exact() == ZZ.frac_field(x, y) + assert QQ.frac_field(x, y).get_exact() == QQ.frac_field(x, y) + assert RR.frac_field(x, y).get_exact() == QQ.frac_field(x, y) + assert CC.frac_field(x, y).get_exact() == QQ_I.frac_field(x, y) + assert ZZ.old_poly_ring(x).get_exact() == ZZ.old_poly_ring(x) + assert QQ.old_poly_ring(x).get_exact() == QQ.old_poly_ring(x) + assert RR.old_poly_ring(x).get_exact() == QQ.old_poly_ring(x) + assert CC.old_poly_ring(x).get_exact() == QQ_I.old_poly_ring(x) + assert ZZ.old_poly_ring(x, y).get_exact() == ZZ.old_poly_ring(x, y) + assert QQ.old_poly_ring(x, y).get_exact() == QQ.old_poly_ring(x, y) + assert RR.old_poly_ring(x, y).get_exact() == QQ.old_poly_ring(x, y) + assert CC.old_poly_ring(x, y).get_exact() == QQ_I.old_poly_ring(x, y) + assert ZZ.old_frac_field(x).get_exact() == ZZ.old_frac_field(x) + assert QQ.old_frac_field(x).get_exact() == QQ.old_frac_field(x) + assert RR.old_frac_field(x).get_exact() == QQ.old_frac_field(x) + assert CC.old_frac_field(x).get_exact() == QQ_I.old_frac_field(x) + assert ZZ.old_frac_field(x, y).get_exact() == ZZ.old_frac_field(x, y) + assert QQ.old_frac_field(x, y).get_exact() == QQ.old_frac_field(x, y) + assert RR.old_frac_field(x, y).get_exact() == QQ.old_frac_field(x, y) + assert CC.old_frac_field(x, y).get_exact() == QQ_I.old_frac_field(x, y) + + +def test_Domain_characteristic(): + for F, c in [(FF(3), 3), (FF(5), 5), (FF(7), 7)]: + for R in F, F[x], F.frac_field(x), F.old_poly_ring(x), F.old_frac_field(x): + assert R.has_CharacteristicZero is False + assert R.characteristic() == c + for D in ZZ, QQ, ZZ_I, QQ_I, ALG: + for R in D, D[x], D.frac_field(x), D.old_poly_ring(x), D.old_frac_field(x): + assert R.has_CharacteristicZero is True + assert R.characteristic() == 0 + + +def test_Domain_is_unit(): + nums = [-2, -1, 0, 1, 2] + invring = [False, True, False, True, False] + invfield = [True, True, False, True, True] + ZZx, QQx, QQxf = ZZ[x], QQ[x], QQ.frac_field(x) + assert [ZZ.is_unit(ZZ(n)) for n in nums] == invring + assert [QQ.is_unit(QQ(n)) for n in nums] == invfield + assert [ZZx.is_unit(ZZx(n)) for n in nums] == invring + assert [QQx.is_unit(QQx(n)) for n in nums] == invfield + assert [QQxf.is_unit(QQxf(n)) for n in nums] == invfield + assert ZZx.is_unit(ZZx(x)) is False + assert QQx.is_unit(QQx(x)) is False + assert QQxf.is_unit(QQxf(x)) is True + + +def test_Domain_convert(): + + def check_element(e1, e2, K1, K2, K3): + if isinstance(e1, PolyElement): + assert isinstance(e2, PolyElement) and e1.ring == e2.ring + elif isinstance(e1, FracElement): + assert isinstance(e2, FracElement) and e1.field == e2.field + else: + assert type(e1) is type(e2), '%s, %s: %s %s -> %s' % (e1, e2, K1, K2, K3) + assert e1 == e2, '%s, %s: %s %s -> %s' % (e1, e2, K1, K2, K3) + + def check_domains(K1, K2): + K3 = K1.unify(K2) + check_element(K3.convert_from(K1.one, K1), K3.one, K1, K2, K3) + check_element(K3.convert_from(K2.one, K2), K3.one, K1, K2, K3) + check_element(K3.convert_from(K1.zero, K1), K3.zero, K1, K2, K3) + check_element(K3.convert_from(K2.zero, K2), K3.zero, K1, K2, K3) + + def composite_domains(K): + domains = [ + K, + K[y], K[z], K[y, z], + K.frac_field(y), K.frac_field(z), K.frac_field(y, z), + # XXX: These should be tested and made to work... + # K.old_poly_ring(y), K.old_frac_field(y), + ] + return domains + + QQ2 = QQ.algebraic_field(sqrt(2)) + QQ3 = QQ.algebraic_field(sqrt(3)) + doms = [ZZ, QQ, QQ2, QQ3, QQ_I, ZZ_I, RR, CC] + + for i, K1 in enumerate(doms): + for K2 in doms[i:]: + for K3 in composite_domains(K1): + for K4 in composite_domains(K2): + check_domains(K3, K4) + + assert QQ.convert(10e-52) == QQ(1684996666696915, 1684996666696914987166688442938726917102321526408785780068975640576) + + R, xr = ring("x", ZZ) + assert ZZ.convert(xr - xr) == 0 + assert ZZ.convert(xr - xr, R.to_domain()) == 0 + + assert CC.convert(ZZ_I(1, 2)) == CC(1, 2) + assert CC.convert(QQ_I(1, 2)) == CC(1, 2) + + assert QQ.convert_from(RR(0.5), RR) == QQ(1, 2) + assert RR.convert_from(QQ(1, 2), QQ) == RR(0.5) + assert QQ_I.convert_from(CC(0.5, 0.75), CC) == QQ_I(QQ(1, 2), QQ(3, 4)) + assert CC.convert_from(QQ_I(QQ(1, 2), QQ(3, 4)), QQ_I) == CC(0.5, 0.75) + + K1 = QQ.frac_field(x) + K2 = ZZ.frac_field(x) + K3 = QQ[x] + K4 = ZZ[x] + Ks = [K1, K2, K3, K4] + for Ka, Kb in product(Ks, Ks): + assert Ka.convert_from(Kb.from_sympy(x), Kb) == Ka.from_sympy(x) + + assert K2.convert_from(QQ(1, 2), QQ) == K2(QQ(1, 2)) + + +def test_EX_convert(): + + elements = [ + (ZZ, ZZ(3)), + (QQ, QQ(1,2)), + (ZZ_I, ZZ_I(1,2)), + (QQ_I, QQ_I(1,2)), + (RR, RR(3)), + (CC, CC(1,2)), + (EX, EX(3)), + (EXRAW, EXRAW(3)), + (ALG, ALG.from_sympy(sqrt(2))), + ] + + for R, e in elements: + for EE in EX, EXRAW: + elem = EE.from_sympy(R.to_sympy(e)) + assert EE.convert_from(e, R) == elem + assert R.convert_from(elem, EE) == e + + +def test_GlobalPolynomialRing_convert(): + K1 = QQ.old_poly_ring(x) + K2 = QQ[x] + assert K1.convert(x) == K1.convert(K2.convert(x), K2) + assert K2.convert(x) == K2.convert(K1.convert(x), K1) + + K1 = QQ.old_poly_ring(x, y) + K2 = QQ[x] + assert K1.convert(x) == K1.convert(K2.convert(x), K2) + #assert K2.convert(x) == K2.convert(K1.convert(x), K1) + + K1 = ZZ.old_poly_ring(x, y) + K2 = QQ[x] + assert K1.convert(x) == K1.convert(K2.convert(x), K2) + #assert K2.convert(x) == K2.convert(K1.convert(x), K1) + + +def test_PolynomialRing__init(): + R, = ring("", ZZ) + assert ZZ.poly_ring() == R.to_domain() + + +def test_FractionField__init(): + F, = field("", ZZ) + assert ZZ.frac_field() == F.to_domain() + + +def test_FractionField_convert(): + K = QQ.frac_field(x) + assert K.convert(QQ(2, 3), QQ) == K.from_sympy(Rational(2, 3)) + K = QQ.frac_field(x) + assert K.convert(ZZ(2), ZZ) == K.from_sympy(Integer(2)) + + +def test_inject(): + assert ZZ.inject(x, y, z) == ZZ[x, y, z] + assert ZZ[x].inject(y, z) == ZZ[x, y, z] + assert ZZ.frac_field(x).inject(y, z) == ZZ.frac_field(x, y, z) + raises(GeneratorsError, lambda: ZZ[x].inject(x)) + + +def test_drop(): + assert ZZ.drop(x) == ZZ + assert ZZ[x].drop(x) == ZZ + assert ZZ[x, y].drop(x) == ZZ[y] + assert ZZ.frac_field(x).drop(x) == ZZ + assert ZZ.frac_field(x, y).drop(x) == ZZ.frac_field(y) + assert ZZ[x][y].drop(y) == ZZ[x] + assert ZZ[x][y].drop(x) == ZZ[y] + assert ZZ.frac_field(x)[y].drop(x) == ZZ[y] + assert ZZ.frac_field(x)[y].drop(y) == ZZ.frac_field(x) + Ky = FiniteExtension(Poly(x**2-1, x, domain=ZZ[y])) + K = FiniteExtension(Poly(x**2-1, x, domain=ZZ)) + assert Ky.drop(y) == K + raises(GeneratorsError, lambda: Ky.drop(x)) + + +def test_Domain_map(): + seq = ZZ.map([1, 2, 3, 4]) + + assert all(ZZ.of_type(elt) for elt in seq) + + seq = ZZ.map([[1, 2, 3, 4]]) + + assert all(ZZ.of_type(elt) for elt in seq[0]) and len(seq) == 1 + + +def test_Domain___eq__(): + assert (ZZ[x, y] == ZZ[x, y]) is True + assert (QQ[x, y] == QQ[x, y]) is True + + assert (ZZ[x, y] == QQ[x, y]) is False + assert (QQ[x, y] == ZZ[x, y]) is False + + assert (ZZ.frac_field(x, y) == ZZ.frac_field(x, y)) is True + assert (QQ.frac_field(x, y) == QQ.frac_field(x, y)) is True + + assert (ZZ.frac_field(x, y) == QQ.frac_field(x, y)) is False + assert (QQ.frac_field(x, y) == ZZ.frac_field(x, y)) is False + + assert RealField()[x] == RR[x] + + +def test_Domain__algebraic_field(): + alg = ZZ.algebraic_field(sqrt(2)) + assert alg.ext.minpoly == Poly(x**2 - 2) + assert alg.dom == QQ + + alg = QQ.algebraic_field(sqrt(2)) + assert alg.ext.minpoly == Poly(x**2 - 2) + assert alg.dom == QQ + + alg = alg.algebraic_field(sqrt(3)) + assert alg.ext.minpoly == Poly(x**4 - 10*x**2 + 1) + assert alg.dom == QQ + + +def test_Domain_alg_field_from_poly(): + f = Poly(x**2 - 2) + g = Poly(x**2 - 3) + h = Poly(x**4 - 10*x**2 + 1) + + alg = ZZ.alg_field_from_poly(f) + assert alg.ext.minpoly == f + assert alg.dom == QQ + + alg = QQ.alg_field_from_poly(f) + assert alg.ext.minpoly == f + assert alg.dom == QQ + + alg = alg.alg_field_from_poly(g) + assert alg.ext.minpoly == h + assert alg.dom == QQ + + +def test_Domain_cyclotomic_field(): + K = ZZ.cyclotomic_field(12) + assert K.ext.minpoly == Poly(cyclotomic_poly(12)) + assert K.dom == QQ + + F = QQ.cyclotomic_field(3) + assert F.ext.minpoly == Poly(cyclotomic_poly(3)) + assert F.dom == QQ + + E = F.cyclotomic_field(4) + assert field_isomorphism(E.ext, K.ext) is not None + assert E.dom == QQ + + +def test_PolynomialRing_from_FractionField(): + F, x,y = field("x,y", ZZ) + R, X,Y = ring("x,y", ZZ) + + f = (x**2 + y**2)/(x + 1) + g = (x**2 + y**2)/4 + h = x**2 + y**2 + + assert R.to_domain().from_FractionField(f, F.to_domain()) is None + assert R.to_domain().from_FractionField(g, F.to_domain()) == X**2/4 + Y**2/4 + assert R.to_domain().from_FractionField(h, F.to_domain()) == X**2 + Y**2 + + F, x,y = field("x,y", QQ) + R, X,Y = ring("x,y", QQ) + + f = (x**2 + y**2)/(x + 1) + g = (x**2 + y**2)/4 + h = x**2 + y**2 + + assert R.to_domain().from_FractionField(f, F.to_domain()) is None + assert R.to_domain().from_FractionField(g, F.to_domain()) == X**2/4 + Y**2/4 + assert R.to_domain().from_FractionField(h, F.to_domain()) == X**2 + Y**2 + + +def test_FractionField_from_PolynomialRing(): + R, x,y = ring("x,y", QQ) + F, X,Y = field("x,y", ZZ) + + f = 3*x**2 + 5*y**2 + g = x**2/3 + y**2/5 + + assert F.to_domain().from_PolynomialRing(f, R.to_domain()) == 3*X**2 + 5*Y**2 + assert F.to_domain().from_PolynomialRing(g, R.to_domain()) == (5*X**2 + 3*Y**2)/15 + + +def test_FF_of_type(): + # XXX: of_type is not very useful here because in the case of ground types + # = flint all elements are of type nmod. + assert FF(3).of_type(FF(3)(1)) is True + assert FF(5).of_type(FF(5)(3)) is True + + +def test___eq__(): + assert not QQ[x] == ZZ[x] + assert not QQ.frac_field(x) == ZZ.frac_field(x) + + +def test_RealField_from_sympy(): + assert RR.convert(S.Zero) == RR.dtype(0) + assert RR.convert(S(0.0)) == RR.dtype(0.0) + assert RR.convert(S.One) == RR.dtype(1) + assert RR.convert(S(1.0)) == RR.dtype(1.0) + assert RR.convert(sin(1)) == RR.dtype(sin(1).evalf()) + + +def test_not_in_any_domain(): + check = list(_illegal) + [x] + [ + float(i) for i in _illegal[:3]] + for dom in (ZZ, QQ, RR, CC, EX): + for i in check: + if i == x and dom == EX: + continue + assert i not in dom, (i, dom) + raises(CoercionFailed, lambda: dom.convert(i)) + + +def test_ModularInteger(): + F3 = FF(3) + + a = F3(0) + assert F3.of_type(a) and a == 0 + a = F3(1) + assert F3.of_type(a) and a == 1 + a = F3(2) + assert F3.of_type(a) and a == 2 + a = F3(3) + assert F3.of_type(a) and a == 0 + a = F3(4) + assert F3.of_type(a) and a == 1 + + a = F3(F3(0)) + assert F3.of_type(a) and a == 0 + a = F3(F3(1)) + assert F3.of_type(a) and a == 1 + a = F3(F3(2)) + assert F3.of_type(a) and a == 2 + a = F3(F3(3)) + assert F3.of_type(a) and a == 0 + a = F3(F3(4)) + assert F3.of_type(a) and a == 1 + + a = -F3(1) + assert F3.of_type(a) and a == 2 + a = -F3(2) + assert F3.of_type(a) and a == 1 + + a = 2 + F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2) + 2 + assert F3.of_type(a) and a == 1 + a = F3(2) + F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2) + F3(2) + assert F3.of_type(a) and a == 1 + + a = 3 - F3(2) + assert F3.of_type(a) and a == 1 + a = F3(3) - 2 + assert F3.of_type(a) and a == 1 + a = F3(3) - F3(2) + assert F3.of_type(a) and a == 1 + a = F3(3) - F3(2) + assert F3.of_type(a) and a == 1 + + a = 2*F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)*2 + assert F3.of_type(a) and a == 1 + a = F3(2)*F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)*F3(2) + assert F3.of_type(a) and a == 1 + + a = 2/F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)/2 + assert F3.of_type(a) and a == 1 + a = F3(2)/F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)/F3(2) + assert F3.of_type(a) and a == 1 + + a = F3(2)**0 + assert F3.of_type(a) and a == 1 + a = F3(2)**1 + assert F3.of_type(a) and a == 2 + a = F3(2)**2 + assert F3.of_type(a) and a == 1 + + F7 = FF(7) + + a = F7(3)**100000000000 + assert F7.of_type(a) and a == 4 + a = F7(3)**-100000000000 + assert F7.of_type(a) and a == 2 + + assert bool(F3(3)) is False + assert bool(F3(4)) is True + + F5 = FF(5) + + a = F5(1)**(-1) + assert F5.of_type(a) and a == 1 + a = F5(2)**(-1) + assert F5.of_type(a) and a == 3 + a = F5(3)**(-1) + assert F5.of_type(a) and a == 2 + a = F5(4)**(-1) + assert F5.of_type(a) and a == 4 + + if GROUND_TYPES != 'flint': + # XXX: This gives a core dump with python-flint... + raises(NotInvertible, lambda: F5(0)**(-1)) + raises(NotInvertible, lambda: F5(5)**(-1)) + + raises(ValueError, lambda: FF(0)) + raises(ValueError, lambda: FF(2.1)) + + for n1 in range(5): + for n2 in range(5): + if GROUND_TYPES != 'flint': + with warns_deprecated_sympy(): + assert (F5(n1) < F5(n2)) is (n1 < n2) + with warns_deprecated_sympy(): + assert (F5(n1) <= F5(n2)) is (n1 <= n2) + with warns_deprecated_sympy(): + assert (F5(n1) > F5(n2)) is (n1 > n2) + with warns_deprecated_sympy(): + assert (F5(n1) >= F5(n2)) is (n1 >= n2) + else: + raises(TypeError, lambda: F5(n1) < F5(n2)) + raises(TypeError, lambda: F5(n1) <= F5(n2)) + raises(TypeError, lambda: F5(n1) > F5(n2)) + raises(TypeError, lambda: F5(n1) >= F5(n2)) + + # https://github.com/sympy/sympy/issues/26789 + assert GF(Integer(5)) == F5 + assert F5(Integer(3)) == F5(3) + + +def test_QQ_int(): + assert int(QQ(2**2000, 3**1250)) == 455431 + assert int(QQ(2**100, 3)) == 422550200076076467165567735125 + + +def test_RR_double(): + assert RR(3.14) > 1e-50 + assert RR(1e-13) > 1e-50 + assert RR(1e-14) > 1e-50 + assert RR(1e-15) > 1e-50 + assert RR(1e-20) > 1e-50 + assert RR(1e-40) > 1e-50 + + +def test_RR_Float(): + f1 = Float("1.01") + f2 = Float("1.0000000000000000000001") + assert f1._prec == 53 + assert f2._prec == 80 + assert RR(f1)-1 > 1e-50 + assert RR(f2)-1 < 1e-50 # RR's precision is lower than f2's + + RR2 = RealField(prec=f2._prec) + assert RR2(f1)-1 > 1e-50 + assert RR2(f2)-1 > 1e-50 # RR's precision is equal to f2's + + +def test_CC_double(): + assert CC(3.14).real > 1e-50 + assert CC(1e-13).real > 1e-50 + assert CC(1e-14).real > 1e-50 + assert CC(1e-15).real > 1e-50 + assert CC(1e-20).real > 1e-50 + assert CC(1e-40).real > 1e-50 + + assert CC(3.14j).imag > 1e-50 + assert CC(1e-13j).imag > 1e-50 + assert CC(1e-14j).imag > 1e-50 + assert CC(1e-15j).imag > 1e-50 + assert CC(1e-20j).imag > 1e-50 + assert CC(1e-40j).imag > 1e-50 + + +def test_gaussian_domains(): + I = S.ImaginaryUnit + a, b, c, d = [ZZ_I.convert(x) for x in (5, 2 + I, 3 - I, 5 - 5*I)] + assert ZZ_I.gcd(a, b) == b + assert ZZ_I.gcd(a, c) == b + assert ZZ_I.lcm(a, b) == a + assert ZZ_I.lcm(a, c) == d + assert ZZ_I(3, 4) != QQ_I(3, 4) # XXX is this right or should QQ->ZZ if possible? + assert ZZ_I(3, 0) != 3 # and should this go to Integer? + assert QQ_I(S(3)/4, 0) != S(3)/4 # and this to Rational? + assert ZZ_I(0, 0).quadrant() == 0 + assert ZZ_I(-1, 0).quadrant() == 2 + + assert QQ_I.convert(QQ(3, 2)) == QQ_I(QQ(3, 2), QQ(0)) + assert QQ_I.convert(QQ(3, 2), QQ) == QQ_I(QQ(3, 2), QQ(0)) + + for G in (QQ_I, ZZ_I): + + q = G(3, 4) + assert str(q) == '3 + 4*I' + assert q.parent() == G + assert q._get_xy(pi) == (None, None) + assert q._get_xy(2) == (2, 0) + assert q._get_xy(2*I) == (0, 2) + + assert hash(q) == hash((3, 4)) + assert G(1, 2) == G(1, 2) + assert G(1, 2) != G(1, 3) + assert G(3, 0) == G(3) + + assert q + q == G(6, 8) + assert q - q == G(0, 0) + assert 3 - q == -q + 3 == G(0, -4) + assert 3 + q == q + 3 == G(6, 4) + assert q * q == G(-7, 24) + assert 3 * q == q * 3 == G(9, 12) + assert q ** 0 == G(1, 0) + assert q ** 1 == q + assert q ** 2 == q * q == G(-7, 24) + assert q ** 3 == q * q * q == G(-117, 44) + assert 1 / q == q ** -1 == QQ_I(S(3)/25, - S(4)/25) + assert q / 1 == QQ_I(3, 4) + assert q / 2 == QQ_I(S(3)/2, 2) + assert q/3 == QQ_I(1, S(4)/3) + assert 3/q == QQ_I(S(9)/25, -S(12)/25) + i, r = divmod(q, 2) + assert 2*i + r == q + i, r = divmod(2, q) + assert q*i + r == G(2, 0) + + a, b = G(2, 0), G(1, -1) + c, d, g = G.gcdex(a, b) + assert g == G.gcd(a, b) + assert c * a + d * b == g + + raises(ZeroDivisionError, lambda: q % 0) + raises(ZeroDivisionError, lambda: q / 0) + raises(ZeroDivisionError, lambda: q // 0) + raises(ZeroDivisionError, lambda: divmod(q, 0)) + raises(ZeroDivisionError, lambda: divmod(q, 0)) + raises(TypeError, lambda: q + x) + raises(TypeError, lambda: q - x) + raises(TypeError, lambda: x + q) + raises(TypeError, lambda: x - q) + raises(TypeError, lambda: q * x) + raises(TypeError, lambda: x * q) + raises(TypeError, lambda: q / x) + raises(TypeError, lambda: x / q) + raises(TypeError, lambda: q // x) + raises(TypeError, lambda: x // q) + + assert G.from_sympy(S(2)) == G(2, 0) + assert G.to_sympy(G(2, 0)) == S(2) + raises(CoercionFailed, lambda: G.from_sympy(pi)) + + PR = G.inject(x) + assert isinstance(PR, PolynomialRing) + assert PR.domain == G + assert len(PR.gens) == 1 and PR.gens[0].as_expr() == x + + if G is QQ_I: + AF = G.as_AlgebraicField() + assert isinstance(AF, AlgebraicField) + assert AF.domain == QQ + assert AF.ext.args[0] == I + + for qi in [G(-1, 0), G(1, 0), G(0, -1), G(0, 1)]: + assert G.is_negative(qi) is False + assert G.is_positive(qi) is False + assert G.is_nonnegative(qi) is False + assert G.is_nonpositive(qi) is False + + domains = [ZZ, QQ, AlgebraicField(QQ, I)] + + # XXX: These domains are all obsolete because ZZ/QQ with MPZ/MPQ + # already use either gmpy, flint or python depending on the + # availability of these libraries. We can keep these tests for now but + # ideally we should remove these alternate domains entirely. + domains += [ZZ_python(), QQ_python()] + if GROUND_TYPES == 'gmpy': + domains += [ZZ_gmpy(), QQ_gmpy()] + + for K in domains: + assert G.convert(K(2)) == G(2, 0) + assert G.convert(K(2), K) == G(2, 0) + + for K in ZZ_I, QQ_I: + assert G.convert(K(1, 1)) == G(1, 1) + assert G.convert(K(1, 1), K) == G(1, 1) + + if G == ZZ_I: + assert repr(q) == 'ZZ_I(3, 4)' + assert q//3 == G(1, 1) + assert 12//q == G(1, -2) + assert 12 % q == G(1, 2) + assert q % 2 == G(-1, 0) + assert i == G(0, 0) + assert r == G(2, 0) + assert G.get_ring() == G + assert G.get_field() == QQ_I + else: + assert repr(q) == 'QQ_I(3, 4)' + assert G.get_ring() == ZZ_I + assert G.get_field() == G + assert q//3 == G(1, S(4)/3) + assert 12//q == G(S(36)/25, -S(48)/25) + assert 12 % q == G(0, 0) + assert q % 2 == G(0, 0) + assert i == G(S(6)/25, -S(8)/25), (G,i) + assert r == G(0, 0) + q2 = G(S(3)/2, S(5)/3) + assert G.numer(q2) == ZZ_I(9, 10) + assert G.denom(q2) == ZZ_I(6) + + +def test_EX_EXRAW(): + assert EXRAW.zero is S.Zero + assert EXRAW.one is S.One + + assert EX(1) == EX.Expression(1) + assert EX(1).ex is S.One + assert EXRAW(1) is S.One + + # EX has cancelling but EXRAW does not + assert 2*EX((x + y*x)/x) == EX(2 + 2*y) != 2*((x + y*x)/x) + assert 2*EXRAW((x + y*x)/x) == 2*((x + y*x)/x) != (1 + y) + + assert EXRAW.convert_from(EX(1), EX) is EXRAW.one + assert EX.convert_from(EXRAW(1), EXRAW) == EX.one + + assert EXRAW.from_sympy(S.One) is S.One + assert EXRAW.to_sympy(EXRAW.one) is S.One + raises(CoercionFailed, lambda: EXRAW.from_sympy([])) + + assert EXRAW.get_field() == EXRAW + + assert EXRAW.unify(EX) == EXRAW + assert EX.unify(EXRAW) == EXRAW + + +def test_EX_ordering(): + elements = [EX(1), EX(x), EX(3)] + assert sorted(elements) == [EX(1), EX(3), EX(x)] + + +def test_canonical_unit(): + + for K in [ZZ, QQ, RR]: # CC? + assert K.canonical_unit(K(2)) == K(1) + assert K.canonical_unit(K(-2)) == K(-1) + + for K in [ZZ_I, QQ_I]: + i = K.from_sympy(I) + assert K.canonical_unit(K(2)) == K(1) + assert K.canonical_unit(K(2)*i) == -i + assert K.canonical_unit(-K(2)) == K(-1) + assert K.canonical_unit(-K(2)*i) == i + + K = ZZ[x] + assert K.canonical_unit(K(x + 1)) == K(1) + assert K.canonical_unit(K(-x + 1)) == K(-1) + + K = ZZ_I[x] + assert K.canonical_unit(K.from_sympy(I*x)) == ZZ_I(0, -1) + + K = ZZ_I.frac_field(x, y) + i = K.from_sympy(I) + assert i / i == K.one + assert (K.one + i)/(i - K.one) == -i + + +def test_Domain_is_negative(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_negative(a) == False + assert CC.is_negative(b) == False + + +def test_Domain_is_positive(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_positive(a) == False + assert CC.is_positive(b) == False + + +def test_Domain_is_nonnegative(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_nonnegative(a) == False + assert CC.is_nonnegative(b) == False + + +def test_Domain_is_nonpositive(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_nonpositive(a) == False + assert CC.is_nonpositive(b) == False + + +def test_exponential_domain(): + K = ZZ[E] + eK = K.from_sympy(E) + assert K.from_sympy(exp(3)) == eK ** 3 + assert K.convert(exp(3)) == eK ** 3 + + +def test_AlgebraicField_alias(): + # No default alias: + k = QQ.algebraic_field(sqrt(2)) + assert k.ext.alias is None + + # For a single extension, its alias is used: + alpha = AlgebraicNumber(sqrt(2), alias='alpha') + k = QQ.algebraic_field(alpha) + assert k.ext.alias.name == 'alpha' + + # Can override the alias of a single extension: + k = QQ.algebraic_field(alpha, alias='theta') + assert k.ext.alias.name == 'theta' + + # With multiple extensions, no default alias: + k = QQ.algebraic_field(sqrt(2), sqrt(3)) + assert k.ext.alias is None + + # With multiple extensions, no default alias, even if one of + # the extensions has one: + k = QQ.algebraic_field(alpha, sqrt(3)) + assert k.ext.alias is None + + # With multiple extensions, may set an alias: + k = QQ.algebraic_field(sqrt(2), sqrt(3), alias='theta') + assert k.ext.alias.name == 'theta' + + # Alias is passed to constructed field elements: + k = QQ.algebraic_field(alpha) + beta = k.to_alg_num(k([1, 2, 3])) + assert beta.alias is alpha.alias + + +def test_exsqrt(): + assert ZZ.is_square(ZZ(4)) is True + assert ZZ.exsqrt(ZZ(4)) == ZZ(2) + assert ZZ.is_square(ZZ(42)) is False + assert ZZ.exsqrt(ZZ(42)) is None + assert ZZ.is_square(ZZ(0)) is True + assert ZZ.exsqrt(ZZ(0)) == ZZ(0) + assert ZZ.is_square(ZZ(-1)) is False + assert ZZ.exsqrt(ZZ(-1)) is None + + assert QQ.is_square(QQ(9, 4)) is True + assert QQ.exsqrt(QQ(9, 4)) == QQ(3, 2) + assert QQ.is_square(QQ(18, 8)) is True + assert QQ.exsqrt(QQ(18, 8)) == QQ(3, 2) + assert QQ.is_square(QQ(-9, -4)) is True + assert QQ.exsqrt(QQ(-9, -4)) == QQ(3, 2) + assert QQ.is_square(QQ(11, 4)) is False + assert QQ.exsqrt(QQ(11, 4)) is None + assert QQ.is_square(QQ(9, 5)) is False + assert QQ.exsqrt(QQ(9, 5)) is None + assert QQ.is_square(QQ(4)) is True + assert QQ.exsqrt(QQ(4)) == QQ(2) + assert QQ.is_square(QQ(0)) is True + assert QQ.exsqrt(QQ(0)) == QQ(0) + assert QQ.is_square(QQ(-16, 9)) is False + assert QQ.exsqrt(QQ(-16, 9)) is None + + assert RR.is_square(RR(6.25)) is True + assert RR.exsqrt(RR(6.25)) == RR(2.5) + assert RR.is_square(RR(2)) is True + assert RR.almosteq(RR.exsqrt(RR(2)), RR(1.4142135623730951), tolerance=1e-15) + assert RR.is_square(RR(0)) is True + assert RR.exsqrt(RR(0)) == RR(0) + assert RR.is_square(RR(-1)) is False + assert RR.exsqrt(RR(-1)) is None + + assert CC.is_square(CC(2)) is True + assert CC.almosteq(CC.exsqrt(CC(2)), CC(1.4142135623730951), tolerance=1e-15) + assert CC.is_square(CC(0)) is True + assert CC.exsqrt(CC(0)) == CC(0) + assert CC.is_square(CC(-1)) is True + assert CC.exsqrt(CC(-1)) == CC(0, 1) + assert CC.is_square(CC(0, 2)) is True + assert CC.exsqrt(CC(0, 2)) == CC(1, 1) + assert CC.is_square(CC(-3, -4)) is True + assert CC.exsqrt(CC(-3, -4)) == CC(1, -2) + + F2 = FF(2) + assert F2.is_square(F2(1)) is True + assert F2.exsqrt(F2(1)) == F2(1) + assert F2.is_square(F2(0)) is True + assert F2.exsqrt(F2(0)) == F2(0) + + F7 = FF(7) + assert F7.is_square(F7(2)) is True + assert F7.exsqrt(F7(2)) == F7(3) + assert F7.is_square(F7(3)) is False + assert F7.exsqrt(F7(3)) is None + assert F7.is_square(F7(0)) is True + assert F7.exsqrt(F7(0)) == F7(0) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_polynomialring.py b/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_polynomialring.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb1fdf3f9f9250518289019b0bb108047e8cb6c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_polynomialring.py @@ -0,0 +1,93 @@ +"""Tests for the PolynomialRing classes. """ + +from sympy.polys.domains import QQ, ZZ +from sympy.polys.polyerrors import ExactQuotientFailed, CoercionFailed, NotReversible + +from sympy.abc import x, y + +from sympy.testing.pytest import raises + + +def test_build_order(): + R = QQ.old_poly_ring(x, y, order=(("lex", x), ("ilex", y))) + assert R.order((1, 5)) == ((1,), (-5,)) + + +def test_globalring(): + Qxy = QQ.old_frac_field(x, y) + R = QQ.old_poly_ring(x, y) + X = R.convert(x) + Y = R.convert(y) + + assert x in R + assert 1/x not in R + assert 1/(1 + x) not in R + assert Y in R + assert X * (Y**2 + 1) == R.convert(x * (y**2 + 1)) + assert X + 1 == R.convert(x + 1) + raises(ExactQuotientFailed, lambda: X/Y) + raises(TypeError, lambda: x/Y) + raises(TypeError, lambda: X/y) + assert X**2 / X == X + + assert R.from_GlobalPolynomialRing(ZZ.old_poly_ring(x, y).convert(x), ZZ.old_poly_ring(x, y)) == X + assert R.from_FractionField(Qxy.convert(x), Qxy) == X + assert R.from_FractionField(Qxy.convert(x/y), Qxy) is None + + assert R._sdm_to_vector(R._vector_to_sdm([X, Y], R.order), 2) == [X, Y] + + +def test_localring(): + Qxy = QQ.old_frac_field(x, y) + R = QQ.old_poly_ring(x, y, order="ilex") + X = R.convert(x) + Y = R.convert(y) + + assert x in R + assert 1/x not in R + assert 1/(1 + x) in R + assert Y in R + assert X*(Y**2 + 1)/(1 + X) == R.convert(x*(y**2 + 1)/(1 + x)) + raises(TypeError, lambda: x/Y) + raises(TypeError, lambda: X/y) + assert X + 1 == R.convert(x + 1) + assert X**2 / X == X + + assert R.from_GlobalPolynomialRing(ZZ.old_poly_ring(x, y).convert(x), ZZ.old_poly_ring(x, y)) == X + assert R.from_FractionField(Qxy.convert(x), Qxy) == X + raises(CoercionFailed, lambda: R.from_FractionField(Qxy.convert(x/y), Qxy)) + raises(ExactQuotientFailed, lambda: R.exquo(X, Y)) + raises(NotReversible, lambda: R.revert(X)) + + assert R._sdm_to_vector( + R._vector_to_sdm([X/(X + 1), Y/(1 + X*Y)], R.order), 2) == \ + [X*(1 + X*Y), Y*(1 + X)] + + +def test_conversion(): + L = QQ.old_poly_ring(x, y, order="ilex") + G = QQ.old_poly_ring(x, y) + + assert L.convert(x) == L.convert(G.convert(x), G) + assert G.convert(x) == G.convert(L.convert(x), L) + raises(CoercionFailed, lambda: G.convert(L.convert(1/(1 + x)), L)) + + +def test_units(): + R = QQ.old_poly_ring(x) + assert R.is_unit(R.convert(1)) + assert R.is_unit(R.convert(2)) + assert not R.is_unit(R.convert(x)) + assert not R.is_unit(R.convert(1 + x)) + + R = QQ.old_poly_ring(x, order='ilex') + assert R.is_unit(R.convert(1)) + assert R.is_unit(R.convert(2)) + assert not R.is_unit(R.convert(x)) + assert R.is_unit(R.convert(1 + x)) + + R = ZZ.old_poly_ring(x) + assert R.is_unit(R.convert(1)) + assert not R.is_unit(R.convert(2)) + assert not R.is_unit(R.convert(x)) + assert not R.is_unit(R.convert(1 + x)) diff --git a/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_quotientring.py b/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_quotientring.py new file mode 100644 index 0000000000000000000000000000000000000000..aff167bdd72dc4400785efefef7b3e9057fd0727 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domains/tests/test_quotientring.py @@ -0,0 +1,52 @@ +"""Tests for quotient rings.""" + +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.abc import x, y + +from sympy.polys.polyerrors import NotReversible + +from sympy.testing.pytest import raises + + +def test_QuotientRingElement(): + R = QQ.old_poly_ring(x)/[x**10] + X = R.convert(x) + + assert X*(X + 1) == R.convert(x**2 + x) + assert X*x == R.convert(x**2) + assert x*X == R.convert(x**2) + assert X + x == R.convert(2*x) + assert x + X == 2*X + assert X**2 == R.convert(x**2) + assert 1/(1 - X) == R.convert(sum(x**i for i in range(10))) + assert X**10 == R.zero + assert X != x + + raises(NotReversible, lambda: 1/X) + + +def test_QuotientRing(): + I = QQ.old_poly_ring(x).ideal(x**2 + 1) + R = QQ.old_poly_ring(x)/I + + assert R == QQ.old_poly_ring(x)/[x**2 + 1] + assert R == QQ.old_poly_ring(x)/QQ.old_poly_ring(x).ideal(x**2 + 1) + assert R != QQ.old_poly_ring(x) + + assert R.convert(1)/x == -x + I + assert -1 + I == x**2 + I + assert R.convert(ZZ(1), ZZ) == 1 + I + assert R.convert(R.convert(x), R) == R.convert(x) + + X = R.convert(x) + Y = QQ.old_poly_ring(x).convert(x) + assert -1 + I == X**2 + I + assert -1 + I == Y**2 + I + assert R.to_sympy(X) == x + + raises(ValueError, lambda: QQ.old_poly_ring(x)/QQ.old_poly_ring(x, y).ideal(x)) + + R = QQ.old_poly_ring(x, order="ilex") + I = R.ideal(x) + assert R.convert(1) + I == (R/I).convert(1) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__init__.py b/phivenv/Lib/site-packages/sympy/polys/matrices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ebc3d71ba3dac9ccc695d046d6b3d2ad940fa1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/__init__.py @@ -0,0 +1,15 @@ +""" + +sympy.polys.matrices package. + +The main export from this package is the DomainMatrix class which is a +lower-level implementation of matrices based on the polys Domains. This +implementation is typically a lot faster than SymPy's standard Matrix class +but is a work in progress and is still experimental. + +""" +from .domainmatrix import DomainMatrix, DM + +__all__ = [ + 'DomainMatrix', 'DM', +] diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff055653a265ab7825450c59a32ab58979559d3d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/_dfm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/_dfm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd0ed2b6998d8b41801685119be159845ac13ddd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/_dfm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/_typing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/_typing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14be28b9347ec2681ab43526d662287197a10bd2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/_typing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/ddm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/ddm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..200004aa8e8e507c284bfd045a8528bc267c5bff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/ddm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/dense.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/dense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..789cf51f8f952293d6a1221f079ef6defe6cc5ef Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/dense.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/dfm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/dfm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e873bf96e480d5a3ed262a75d52150a539b820bc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/dfm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/domainscalar.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/domainscalar.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24be9b8272a57a41c6a3113ad1bb03af10e150f0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/domainscalar.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/eigen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/eigen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be3c00bd02b99cd869cfa0441642cd2ed526c9ee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/eigen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/exceptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/exceptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4928b658ce31c8c2a808749b11d02872e57d2c8a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/exceptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/linsolve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/linsolve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc791072369c0470c8e399b68903f6b68a0d98f4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/linsolve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/lll.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/lll.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c9ea5c81f85fc7c9c16d11e108c04a6a33b1ae2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/lll.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/normalforms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/normalforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f92cb95d167d9fa21fa2462e92bad53599f0b02a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/normalforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/rref.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/rref.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04838ececf1436027c4108db7c0f501de044c0ea Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/rref.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/sdm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/sdm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95540aea651b2d57c9948fddeb4221f26094c72b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/__pycache__/sdm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/_dfm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/_dfm.py new file mode 100644 index 0000000000000000000000000000000000000000..1d02076014168ed4966fecd07f3d7a1d4828ae63 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/_dfm.py @@ -0,0 +1,951 @@ +# +# sympy.polys.matrices.dfm +# +# This modules defines the DFM class which is a wrapper for dense flint +# matrices as found in python-flint. +# +# As of python-flint 0.4.1 matrices over the following domains can be supported +# by python-flint: +# +# ZZ: flint.fmpz_mat +# QQ: flint.fmpq_mat +# GF(p): flint.nmod_mat (p prime and p < ~2**62) +# +# The underlying flint library has many more domains, but these are not yet +# supported by python-flint. +# +# The DFM class is a wrapper for the flint matrices and provides a common +# interface for all supported domains that is interchangeable with the DDM +# and SDM classes so that DomainMatrix can be used with any as its internal +# matrix representation. +# + +# TODO: +# +# Implement the following methods that are provided by python-flint: +# +# - hnf (Hermite normal form) +# - snf (Smith normal form) +# - minpoly +# - is_hnf +# - is_snf +# - rank +# +# The other types DDM and SDM do not have these methods and the algorithms +# for hnf, snf and rank are already implemented. Algorithms for minpoly, +# is_hnf and is_snf would need to be added. +# +# Add more methods to python-flint to expose more of Flint's functionality +# and also to make some of the above methods simpler or more efficient e.g. +# slicing, fancy indexing etc. + +from sympy.external.gmpy import GROUND_TYPES +from sympy.external.importtools import import_module +from sympy.utilities.decorator import doctest_depends_on + +from sympy.polys.domains import ZZ, QQ + +from .exceptions import ( + DMBadInputError, + DMDomainError, + DMNonSquareMatrixError, + DMNonInvertibleMatrixError, + DMRankError, + DMShapeError, + DMValueError, +) + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['*'] + + +flint = import_module('flint') + + +__all__ = ['DFM'] + + +@doctest_depends_on(ground_types=['flint']) +class DFM: + """ + Dense FLINT matrix. This class is a wrapper for matrices from python-flint. + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.matrices.dfm import DFM + >>> dfm = DFM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.rep + [1, 2] + [3, 4] + >>> type(dfm.rep) # doctest: +SKIP + + + Usually, the DFM class is not instantiated directly, but is created as the + internal representation of :class:`~.DomainMatrix`. When + `SYMPY_GROUND_TYPES` is set to `flint` and `python-flint` is installed, the + :class:`DFM` class is used automatically as the internal representation of + :class:`~.DomainMatrix` in dense format if the domain is supported by + python-flint. + + >>> from sympy.polys.matrices.domainmatrix import DM + >>> dM = DM([[1, 2], [3, 4]], ZZ) + >>> dM.rep + [[1, 2], [3, 4]] + + A :class:`~.DomainMatrix` can be converted to :class:`DFM` by calling the + :meth:`to_dfm` method: + + >>> dM.to_dfm() + [[1, 2], [3, 4]] + + """ + + fmt = 'dense' + is_DFM = True + is_DDM = False + + def __new__(cls, rowslist, shape, domain): + """Construct from a nested list.""" + flint_mat = cls._get_flint_func(domain) + + if 0 not in shape: + try: + rep = flint_mat(rowslist) + except (ValueError, TypeError): + raise DMBadInputError(f"Input should be a list of list of {domain}") + else: + rep = flint_mat(*shape) + + return cls._new(rep, shape, domain) + + @classmethod + def _new(cls, rep, shape, domain): + """Internal constructor from a flint matrix.""" + cls._check(rep, shape, domain) + obj = object.__new__(cls) + obj.rep = rep + obj.shape = obj.rows, obj.cols = shape + obj.domain = domain + return obj + + def _new_rep(self, rep): + """Create a new DFM with the same shape and domain but a new rep.""" + return self._new(rep, self.shape, self.domain) + + @classmethod + def _check(cls, rep, shape, domain): + repshape = (rep.nrows(), rep.ncols()) + if repshape != shape: + raise DMBadInputError("Shape of rep does not match shape of DFM") + if domain == ZZ and not isinstance(rep, flint.fmpz_mat): + raise RuntimeError("Rep is not a flint.fmpz_mat") + elif domain == QQ and not isinstance(rep, flint.fmpq_mat): + raise RuntimeError("Rep is not a flint.fmpq_mat") + elif domain.is_FF and not isinstance(rep, (flint.fmpz_mod_mat, flint.nmod_mat)): + raise RuntimeError("Rep is not a flint.fmpz_mod_mat or flint.nmod_mat") + elif domain not in (ZZ, QQ) and not domain.is_FF: + raise NotImplementedError("Only ZZ and QQ are supported by DFM") + + @classmethod + def _supports_domain(cls, domain): + """Return True if the given domain is supported by DFM.""" + return domain in (ZZ, QQ) or domain.is_FF and domain._is_flint + + @classmethod + def _get_flint_func(cls, domain): + """Return the flint matrix class for the given domain.""" + if domain == ZZ: + return flint.fmpz_mat + elif domain == QQ: + return flint.fmpq_mat + elif domain.is_FF: + c = domain.characteristic() + if isinstance(domain.one, flint.nmod): + _cls = flint.nmod_mat + def _func(*e): + if len(e) == 1 and isinstance(e[0], flint.nmod_mat): + return _cls(e[0]) + else: + return _cls(*e, c) + else: + m = flint.fmpz_mod_ctx(c) + _func = lambda *e: flint.fmpz_mod_mat(*e, m) + return _func + else: + raise NotImplementedError("Only ZZ and QQ are supported by DFM") + + @property + def _func(self): + """Callable to create a flint matrix of the same domain.""" + return self._get_flint_func(self.domain) + + def __str__(self): + """Return ``str(self)``.""" + return str(self.to_ddm()) + + def __repr__(self): + """Return ``repr(self)``.""" + return f'DFM{repr(self.to_ddm())[3:]}' + + def __eq__(self, other): + """Return ``self == other``.""" + if not isinstance(other, DFM): + return NotImplemented + # Compare domains first because we do *not* want matrices with + # different domains to be equal but e.g. a flint fmpz_mat and fmpq_mat + # with the same entries will compare equal. + return self.domain == other.domain and self.rep == other.rep + + @classmethod + def from_list(cls, rowslist, shape, domain): + """Construct from a nested list.""" + return cls(rowslist, shape, domain) + + def to_list(self): + """Convert to a nested list.""" + return self.rep.tolist() + + def copy(self): + """Return a copy of self.""" + return self._new_rep(self._func(self.rep)) + + def to_ddm(self): + """Convert to a DDM.""" + return DDM.from_list(self.to_list(), self.shape, self.domain) + + def to_sdm(self): + """Convert to a SDM.""" + return SDM.from_list(self.to_list(), self.shape, self.domain) + + def to_dfm(self): + """Return self.""" + return self + + def to_dfm_or_ddm(self): + """ + Convert to a :class:`DFM`. + + This :class:`DFM` method exists to parallel the :class:`~.DDM` and + :class:`~.SDM` methods. For :class:`DFM` it will always return self. + + See Also + ======== + + to_ddm + to_sdm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm_or_ddm + """ + return self + + @classmethod + def from_ddm(cls, ddm): + """Convert from a DDM.""" + return cls.from_list(ddm.to_list(), ddm.shape, ddm.domain) + + @classmethod + def from_list_flat(cls, elements, shape, domain): + """Inverse of :meth:`to_list_flat`.""" + func = cls._get_flint_func(domain) + try: + rep = func(*shape, elements) + except ValueError: + raise DMBadInputError(f"Incorrect number of elements for shape {shape}") + except TypeError: + raise DMBadInputError(f"Input should be a list of {domain}") + return cls(rep, shape, domain) + + def to_list_flat(self): + """Convert to a flat list.""" + return self.rep.entries() + + def to_flat_nz(self): + """Convert to a flat list of non-zeros.""" + return self.to_ddm().to_flat_nz() + + @classmethod + def from_flat_nz(cls, elements, data, domain): + """Inverse of :meth:`to_flat_nz`.""" + return DDM.from_flat_nz(elements, data, domain).to_dfm() + + def to_dod(self): + """Convert to a DOD.""" + return self.to_ddm().to_dod() + + @classmethod + def from_dod(cls, dod, shape, domain): + """Inverse of :meth:`to_dod`.""" + return DDM.from_dod(dod, shape, domain).to_dfm() + + def to_dok(self): + """Convert to a DOK.""" + return self.to_ddm().to_dok() + + @classmethod + def from_dok(cls, dok, shape, domain): + """Inverse of :math:`to_dod`.""" + return DDM.from_dok(dok, shape, domain).to_dfm() + + def iter_values(self): + """Iterate over the non-zero values of the matrix.""" + m, n = self.shape + rep = self.rep + for i in range(m): + for j in range(n): + repij = rep[i, j] + if repij: + yield rep[i, j] + + def iter_items(self): + """Iterate over indices and values of nonzero elements of the matrix.""" + m, n = self.shape + rep = self.rep + for i in range(m): + for j in range(n): + repij = rep[i, j] + if repij: + yield ((i, j), repij) + + def convert_to(self, domain): + """Convert to a new domain.""" + if domain == self.domain: + return self.copy() + elif domain == QQ and self.domain == ZZ: + return self._new(flint.fmpq_mat(self.rep), self.shape, domain) + elif self._supports_domain(domain): + # XXX: Use more efficient conversions when possible. + return self.to_ddm().convert_to(domain).to_dfm() + else: + # It is the callers responsibility to convert to DDM before calling + # this method if the domain is not supported by DFM. + raise NotImplementedError("Only ZZ and QQ are supported by DFM") + + def getitem(self, i, j): + """Get the ``(i, j)``-th entry.""" + # XXX: flint matrices do not support negative indices + # XXX: They also raise ValueError instead of IndexError + m, n = self.shape + if i < 0: + i += m + if j < 0: + j += n + try: + return self.rep[i, j] + except ValueError: + raise IndexError(f"Invalid indices ({i}, {j}) for Matrix of shape {self.shape}") + + def setitem(self, i, j, value): + """Set the ``(i, j)``-th entry.""" + # XXX: flint matrices do not support negative indices + # XXX: They also raise ValueError instead of IndexError + m, n = self.shape + if i < 0: + i += m + if j < 0: + j += n + try: + self.rep[i, j] = value + except ValueError: + raise IndexError(f"Invalid indices ({i}, {j}) for Matrix of shape {self.shape}") + + def _extract(self, i_indices, j_indices): + """Extract a submatrix with no checking.""" + # Indices must be positive and in range. + M = self.rep + lol = [[M[i, j] for j in j_indices] for i in i_indices] + shape = (len(i_indices), len(j_indices)) + return self.from_list(lol, shape, self.domain) + + def extract(self, rowslist, colslist): + """Extract a submatrix.""" + # XXX: flint matrices do not support fancy indexing or negative indices + # + # Check and convert negative indices before calling _extract. + m, n = self.shape + + new_rows = [] + new_cols = [] + + for i in rowslist: + if i < 0: + i_pos = i + m + else: + i_pos = i + if not 0 <= i_pos < m: + raise IndexError(f"Invalid row index {i} for Matrix of shape {self.shape}") + new_rows.append(i_pos) + + for j in colslist: + if j < 0: + j_pos = j + n + else: + j_pos = j + if not 0 <= j_pos < n: + raise IndexError(f"Invalid column index {j} for Matrix of shape {self.shape}") + new_cols.append(j_pos) + + return self._extract(new_rows, new_cols) + + def extract_slice(self, rowslice, colslice): + """Slice a DFM.""" + # XXX: flint matrices do not support slicing + m, n = self.shape + i_indices = range(m)[rowslice] + j_indices = range(n)[colslice] + return self._extract(i_indices, j_indices) + + def neg(self): + """Negate a DFM matrix.""" + return self._new_rep(-self.rep) + + def add(self, other): + """Add two DFM matrices.""" + return self._new_rep(self.rep + other.rep) + + def sub(self, other): + """Subtract two DFM matrices.""" + return self._new_rep(self.rep - other.rep) + + def mul(self, other): + """Multiply a DFM matrix from the right by a scalar.""" + return self._new_rep(self.rep * other) + + def rmul(self, other): + """Multiply a DFM matrix from the left by a scalar.""" + return self._new_rep(other * self.rep) + + def mul_elementwise(self, other): + """Elementwise multiplication of two DFM matrices.""" + # XXX: flint matrices do not support elementwise multiplication + return self.to_ddm().mul_elementwise(other.to_ddm()).to_dfm() + + def matmul(self, other): + """Multiply two DFM matrices.""" + shape = (self.rows, other.cols) + return self._new(self.rep * other.rep, shape, self.domain) + + # XXX: For the most part DomainMatrix does not expect DDM, SDM, or DFM to + # have arithmetic operators defined. The only exception is negation. + # Perhaps that should be removed. + + def __neg__(self): + """Negate a DFM matrix.""" + return self.neg() + + @classmethod + def zeros(cls, shape, domain): + """Return a zero DFM matrix.""" + func = cls._get_flint_func(domain) + return cls._new(func(*shape), shape, domain) + + # XXX: flint matrices do not have anything like ones or eye + # In the methods below we convert to DDM and then back to DFM which is + # probably about as efficient as implementing these methods directly. + + @classmethod + def ones(cls, shape, domain): + """Return a one DFM matrix.""" + # XXX: flint matrices do not have anything like ones + return DDM.ones(shape, domain).to_dfm() + + @classmethod + def eye(cls, n, domain): + """Return the identity matrix of size n.""" + # XXX: flint matrices do not have anything like eye + return DDM.eye(n, domain).to_dfm() + + @classmethod + def diag(cls, elements, domain): + """Return a diagonal matrix.""" + return DDM.diag(elements, domain).to_dfm() + + def applyfunc(self, func, domain): + """Apply a function to each entry of a DFM matrix.""" + return self.to_ddm().applyfunc(func, domain).to_dfm() + + def transpose(self): + """Transpose a DFM matrix.""" + return self._new(self.rep.transpose(), (self.cols, self.rows), self.domain) + + def hstack(self, *others): + """Horizontally stack matrices.""" + return self.to_ddm().hstack(*[o.to_ddm() for o in others]).to_dfm() + + def vstack(self, *others): + """Vertically stack matrices.""" + return self.to_ddm().vstack(*[o.to_ddm() for o in others]).to_dfm() + + def diagonal(self): + """Return the diagonal of a DFM matrix.""" + M = self.rep + m, n = self.shape + return [M[i, i] for i in range(min(m, n))] + + def is_upper(self): + """Return ``True`` if the matrix is upper triangular.""" + M = self.rep + for i in range(self.rows): + for j in range(min(i, self.cols)): + if M[i, j]: + return False + return True + + def is_lower(self): + """Return ``True`` if the matrix is lower triangular.""" + M = self.rep + for i in range(self.rows): + for j in range(i + 1, self.cols): + if M[i, j]: + return False + return True + + def is_diagonal(self): + """Return ``True`` if the matrix is diagonal.""" + return self.is_upper() and self.is_lower() + + def is_zero_matrix(self): + """Return ``True`` if the matrix is the zero matrix.""" + M = self.rep + for i in range(self.rows): + for j in range(self.cols): + if M[i, j]: + return False + return True + + def nnz(self): + """Return the number of non-zero elements in the matrix.""" + return self.to_ddm().nnz() + + def scc(self): + """Return the strongly connected components of the matrix.""" + return self.to_ddm().scc() + + @doctest_depends_on(ground_types='flint') + def det(self): + """ + Compute the determinant of the matrix using FLINT. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm() + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.det() + -2 + + Notes + ===== + + Calls the ``.det()`` method of the underlying FLINT matrix. + + For :ref:`ZZ` or :ref:`QQ` this calls ``fmpz_mat_det`` or + ``fmpq_mat_det`` respectively. + + At the time of writing the implementation of ``fmpz_mat_det`` uses one + of several algorithms depending on the size of the matrix and bit size + of the entries. The algorithms used are: + + - Cofactor for very small (up to 4x4) matrices. + - Bareiss for small (up to 25x25) matrices. + - Modular algorithms for larger matrices (up to 60x60) or for larger + matrices with large bit sizes. + - Modular "accelerated" for larger matrices (60x60 upwards) if the bit + size is smaller than the dimensions of the matrix. + + The implementation of ``fmpq_mat_det`` clears denominators from each + row (not the whole matrix) and then calls ``fmpz_mat_det`` and divides + by the product of the denominators. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.det + Higher level interface to compute the determinant of a matrix. + """ + # XXX: At least the first three algorithms described above should also + # be implemented in the pure Python DDM and SDM classes which at the + # time of writng just use Bareiss for all matrices and domains. + # Probably in Python the thresholds would be different though. + return self.rep.det() + + @doctest_depends_on(ground_types='flint') + def charpoly(self): + """ + Compute the characteristic polynomial of the matrix using FLINT. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm() # need ground types = 'flint' + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.charpoly() + [1, -5, -2] + + Notes + ===== + + Calls the ``.charpoly()`` method of the underlying FLINT matrix. + + For :ref:`ZZ` or :ref:`QQ` this calls ``fmpz_mat_charpoly`` or + ``fmpq_mat_charpoly`` respectively. + + At the time of writing the implementation of ``fmpq_mat_charpoly`` + clears a denominator from the whole matrix and then calls + ``fmpz_mat_charpoly``. The coefficients of the characteristic + polynomial are then multiplied by powers of the denominator. + + The ``fmpz_mat_charpoly`` method uses a modular algorithm with CRT + reconstruction. The modular algorithm uses ``nmod_mat_charpoly`` which + uses Berkowitz for small matrices and non-prime moduli or otherwise + the Danilevsky method. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.charpoly + Higher level interface to compute the characteristic polynomial of + a matrix. + """ + # FLINT polynomial coefficients are in reverse order compared to SymPy. + return self.rep.charpoly().coeffs()[::-1] + + @doctest_depends_on(ground_types='flint') + def inv(self): + """ + Compute the inverse of a matrix using FLINT. + + Examples + ======== + + >>> from sympy import Matrix, QQ + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm().convert_to(QQ) + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.inv() + [[-2, 1], [3/2, -1/2]] + >>> dfm.matmul(dfm.inv()) + [[1, 0], [0, 1]] + + Notes + ===== + + Calls the ``.inv()`` method of the underlying FLINT matrix. + + For now this will raise an error if the domain is :ref:`ZZ` but will + use the FLINT method for :ref:`QQ`. + + The FLINT methods for :ref:`ZZ` and :ref:`QQ` are ``fmpz_mat_inv`` and + ``fmpq_mat_inv`` respectively. The ``fmpz_mat_inv`` method computes an + inverse with denominator. This is implemented by calling + ``fmpz_mat_solve`` (see notes in :meth:`lu_solve` about the algorithm). + + The ``fmpq_mat_inv`` method clears denominators from each row and then + multiplies those into the rhs identity matrix before calling + ``fmpz_mat_solve``. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.inv + Higher level method for computing the inverse of a matrix. + """ + # TODO: Implement similar algorithms for DDM and SDM. + # + # XXX: The flint fmpz_mat and fmpq_mat inv methods both return fmpq_mat + # by default. The fmpz_mat method has an optional argument to return + # fmpz_mat instead for unimodular matrices. + # + # The convention in DomainMatrix is to raise an error if the matrix is + # not over a field regardless of whether the matrix is invertible over + # its domain or over any associated field. Maybe DomainMatrix.inv + # should be changed to always return a matrix over an associated field + # except with a unimodular argument for returning an inverse over a + # ring if possible. + # + # For now we follow the existing DomainMatrix convention... + K = self.domain + m, n = self.shape + + if m != n: + raise DMNonSquareMatrixError("cannot invert a non-square matrix") + + if K == ZZ: + raise DMDomainError("field expected, got %s" % K) + elif K == QQ or K.is_FF: + try: + return self._new_rep(self.rep.inv()) + except ZeroDivisionError: + raise DMNonInvertibleMatrixError("matrix is not invertible") + else: + # If more domains are added for DFM then we will need to consider + # what happens here. + raise NotImplementedError("DFM.inv() is not implemented for %s" % K) + + def lu(self): + """Return the LU decomposition of the matrix.""" + L, U, swaps = self.to_ddm().lu() + return L.to_dfm(), U.to_dfm(), swaps + + def qr(self): + """Return the QR decomposition of the matrix.""" + Q, R = self.to_ddm().qr() + return Q.to_dfm(), R.to_dfm() + + # XXX: The lu_solve function should be renamed to solve. Whether or not it + # uses an LU decomposition is an implementation detail. A method called + # lu_solve would make sense for a situation in which an LU decomposition is + # reused several times to solve with different rhs but that would imply a + # different call signature. + # + # The underlying python-flint method has an algorithm= argument so we could + # use that and have e.g. solve_lu and solve_modular or perhaps also a + # method= argument to choose between the two. Flint itself has more + # possible algorithms to choose from than are exposed by python-flint. + + @doctest_depends_on(ground_types='flint') + def lu_solve(self, rhs): + """ + Solve a matrix equation using FLINT. + + Examples + ======== + + >>> from sympy import Matrix, QQ + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm().convert_to(QQ) + >>> dfm + [[1, 2], [3, 4]] + >>> rhs = Matrix([1, 2]).to_DM().to_dfm().convert_to(QQ) + >>> dfm.lu_solve(rhs) + [[0], [1/2]] + + Notes + ===== + + Calls the ``.solve()`` method of the underlying FLINT matrix. + + For now this will raise an error if the domain is :ref:`ZZ` but will + use the FLINT method for :ref:`QQ`. + + The FLINT methods for :ref:`ZZ` and :ref:`QQ` are ``fmpz_mat_solve`` + and ``fmpq_mat_solve`` respectively. The ``fmpq_mat_solve`` method + uses one of two algorithms: + + - For small matrices (<25 rows) it clears denominators between the + matrix and rhs and uses ``fmpz_mat_solve``. + - For larger matrices it uses ``fmpq_mat_solve_dixon`` which is a + modular approach with CRT reconstruction over :ref:`QQ`. + + The ``fmpz_mat_solve`` method uses one of four algorithms: + + - For very small (<= 3x3) matrices it uses a Cramer's rule. + - For small (<= 15x15) matrices it uses a fraction-free LU solve. + - Otherwise it uses either Dixon or another multimodular approach. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.lu_solve + Higher level interface to solve a matrix equation. + """ + if not self.domain == rhs.domain: + raise DMDomainError("Domains must match: %s != %s" % (self.domain, rhs.domain)) + + # XXX: As for inv we should consider whether to return a matrix over + # over an associated field or attempt to find a solution in the ring. + # For now we follow the existing DomainMatrix convention... + if not self.domain.is_Field: + raise DMDomainError("Field expected, got %s" % self.domain) + + m, n = self.shape + j, k = rhs.shape + if m != j: + raise DMShapeError("Matrix size mismatch: %s * %s vs %s * %s" % (m, n, j, k)) + sol_shape = (n, k) + + # XXX: The Flint solve method only handles square matrices. Probably + # Flint has functions that could be used to solve non-square systems + # but they are not exposed in python-flint yet. Alternatively we could + # put something here using the features that are available like rref. + if m != n: + return self.to_ddm().lu_solve(rhs.to_ddm()).to_dfm() + + try: + sol = self.rep.solve(rhs.rep) + except ZeroDivisionError: + raise DMNonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return self._new(sol, sol_shape, self.domain) + + def fflu(self): + """ + Fraction-free LU decomposition of DFM. + + Explanation + =========== + + Uses `python-flint` if possible for a matrix of + integers otherwise uses the DDM method. + + See Also + ======== + + sympy.polys.matrices.ddm.DDM.fflu + """ + if self.domain == ZZ: + fflu = getattr(self.rep, 'fflu', None) + if fflu is not None: + P, L, D, U = self.rep.fflu() + m, n = self.shape + return ( + self._new(P, (m, m), self.domain), + self._new(L, (m, m), self.domain), + self._new(D, (m, m), self.domain), + self._new(U, self.shape, self.domain) + ) + ddm_p, ddm_l, ddm_d, ddm_u = self.to_ddm().fflu() + P = ddm_p.to_dfm() + L = ddm_l.to_dfm() + D = ddm_d.to_dfm() + U = ddm_u.to_dfm() + return P, L, D, U + + def nullspace(self): + """Return a basis for the nullspace of the matrix.""" + # Code to compute nullspace using flint: + # + # V, nullity = self.rep.nullspace() + # V_dfm = self._new_rep(V)._extract(range(self.rows), range(nullity)) + # + # XXX: That gives the nullspace but does not give us nonpivots. So we + # use the slower DDM method anyway. It would be better to change the + # signature of the nullspace method to not return nonpivots. + # + # XXX: Also python-flint exposes a nullspace method for fmpz_mat but + # not for fmpq_mat. This is the reverse of the situation for DDM etc + # which only allow nullspace over a field. The nullspace method for + # DDM, SDM etc should be changed to allow nullspace over ZZ as well. + # The DomainMatrix nullspace method does allow the domain to be a ring + # but does not directly call the lower-level nullspace methods and uses + # rref_den instead. Nullspace methods should also be added to all + # matrix types in python-flint. + ddm, nonpivots = self.to_ddm().nullspace() + return ddm.to_dfm(), nonpivots + + def nullspace_from_rref(self, pivots=None): + """Return a basis for the nullspace of the matrix.""" + # XXX: Use the flint nullspace method!!! + sdm, nonpivots = self.to_sdm().nullspace_from_rref(pivots=pivots) + return sdm.to_dfm(), nonpivots + + def particular(self): + """Return a particular solution to the system.""" + return self.to_ddm().particular().to_dfm() + + def _lll(self, transform=False, delta=0.99, eta=0.51, rep='zbasis', gram='approx'): + """Call the fmpz_mat.lll() method but check rank to avoid segfaults.""" + + # XXX: There are tests that pass e.g. QQ(5,6) for delta. That fails + # with a TypeError in flint because if QQ is fmpq then conversion with + # float fails. We handle that here but there are two better fixes: + # + # - Make python-flint's fmpq convert with float(x) + # - Change the tests because delta should just be a float. + + def to_float(x): + if QQ.of_type(x): + return float(x.numerator) / float(x.denominator) + else: + return float(x) + + delta = to_float(delta) + eta = to_float(eta) + + if not 0.25 < delta < 1: + raise DMValueError("delta must be between 0.25 and 1") + + # XXX: The flint lll method segfaults if the matrix is not full rank. + m, n = self.shape + if self.rep.rank() != m: + raise DMRankError("Matrix must have full row rank for Flint LLL.") + + # Actually call the flint method. + return self.rep.lll(transform=transform, delta=delta, eta=eta, rep=rep, gram=gram) + + @doctest_depends_on(ground_types='flint') + def lll(self, delta=0.75): + """Compute LLL-reduced basis using FLINT. + + See :meth:`lll_transform` for more information. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> M.to_DM().to_dfm().lll() + [[2, 1, 0], [-1, 1, 3]] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.lll + Higher level interface to compute LLL-reduced basis. + lll_transform + Compute LLL-reduced basis and transform matrix. + """ + if self.domain != ZZ: + raise DMDomainError("ZZ expected, got %s" % self.domain) + elif self.rows > self.cols: + raise DMShapeError("Matrix must not have more rows than columns.") + + rep = self._lll(delta=delta) + return self._new_rep(rep) + + @doctest_depends_on(ground_types='flint') + def lll_transform(self, delta=0.75): + """Compute LLL-reduced basis and transform using FLINT. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6]]).to_DM().to_dfm() + >>> M_lll, T = M.lll_transform() + >>> M_lll + [[2, 1, 0], [-1, 1, 3]] + >>> T + [[-2, 1], [3, -1]] + >>> T.matmul(M) == M_lll + True + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.lll + Higher level interface to compute LLL-reduced basis. + lll + Compute LLL-reduced basis without transform matrix. + """ + if self.domain != ZZ: + raise DMDomainError("ZZ expected, got %s" % self.domain) + elif self.rows > self.cols: + raise DMShapeError("Matrix must not have more rows than columns.") + + rep, T = self._lll(transform=True, delta=delta) + basis = self._new_rep(rep) + T_dfm = self._new(T, (self.rows, self.rows), self.domain) + return basis, T_dfm + + +# Avoid circular imports +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.ddm import SDM diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/_typing.py b/phivenv/Lib/site-packages/sympy/polys/matrices/_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7c3b601fe85d591ddf853acbf33f5bba64b11c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/_typing.py @@ -0,0 +1,16 @@ +from typing import TypeVar, Protocol + + +T = TypeVar('T') + + +class RingElement(Protocol): + """A ring element. + + Must support ``+``, ``-``, ``*``, ``**`` and ``-``. + """ + def __add__(self: T, other: T, /) -> T: ... + def __sub__(self: T, other: T, /) -> T: ... + def __mul__(self: T, other: T, /) -> T: ... + def __pow__(self: T, other: int, /) -> T: ... + def __neg__(self: T, /) -> T: ... diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/ddm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/ddm.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7836ef298fe27a1c02ed069f33711a632d6ed8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/ddm.py @@ -0,0 +1,1176 @@ +""" + +Module for the DDM class. + +The DDM class is an internal representation used by DomainMatrix. The letters +DDM stand for Dense Domain Matrix. A DDM instance represents a matrix using +elements from a polynomial Domain (e.g. ZZ, QQ, ...) in a dense-matrix +representation. + +Basic usage: + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + >>> A.shape + (2, 2) + >>> A + [[0, 1], [-1, 0]] + >>> type(A) + + >>> A @ A + [[-1, 0], [0, -1]] + +The ddm_* functions are designed to operate on DDM as well as on an ordinary +list of lists: + + >>> from sympy.polys.matrices.dense import ddm_idet + >>> ddm_idet(A, QQ) + 1 + >>> ddm_idet([[0, 1], [-1, 0]], QQ) + 1 + >>> A + [[-1, 0], [0, -1]] + +Note that ddm_idet modifies the input matrix in-place. It is recommended to +use the DDM.det method as a friendlier interface to this instead which takes +care of copying the matrix: + + >>> B = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + >>> B.det() + 1 + +Normally DDM would not be used directly and is just part of the internal +representation of DomainMatrix which adds further functionality including e.g. +unifying domains. + +The dense format used by DDM is a list of lists of elements e.g. the 2x2 +identity matrix is like [[1, 0], [0, 1]]. The DDM class itself is a subclass +of list and its list items are plain lists. Elements are accessed as e.g. +ddm[i][j] where ddm[i] gives the ith row and ddm[i][j] gets the element in the +jth column of that row. Subclassing list makes e.g. iteration and indexing +very efficient. We do not override __getitem__ because it would lose that +benefit. + +The core routines are implemented by the ddm_* functions defined in dense.py. +Those functions are intended to be able to operate on a raw list-of-lists +representation of matrices with most functions operating in-place. The DDM +class takes care of copying etc and also stores a Domain object associated +with its elements. This makes it possible to implement things like A + B with +domain checking and also shape checking so that the list of lists +representation is friendlier. + +""" +from itertools import chain + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on + +from .exceptions import ( + DMBadInputError, + DMDomainError, + DMNonSquareMatrixError, + DMShapeError, +) + +from sympy.polys.domains import QQ + +from .dense import ( + ddm_transpose, + ddm_iadd, + ddm_isub, + ddm_ineg, + ddm_imul, + ddm_irmul, + ddm_imatmul, + ddm_irref, + ddm_irref_den, + ddm_idet, + ddm_iinv, + ddm_ilu_split, + ddm_ilu_solve, + ddm_berk, + ) + +from .lll import ddm_lll, ddm_lll_transform + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['DDM.to_dfm', 'DDM.to_dfm_or_ddm'] + + +class DDM(list): + """Dense matrix based on polys domain elements + + This is a list subclass and is a wrapper for a list of lists that supports + basic matrix arithmetic +, -, *, **. + """ + + fmt = 'dense' + is_DFM = False + is_DDM = True + + def __init__(self, rowslist, shape, domain): + if not (isinstance(rowslist, list) and all(type(row) is list for row in rowslist)): + raise DMBadInputError("rowslist must be a list of lists") + m, n = shape + if len(rowslist) != m or any(len(row) != n for row in rowslist): + raise DMBadInputError("Inconsistent row-list/shape") + + super().__init__([i.copy() for i in rowslist]) + self.shape = (m, n) + self.rows = m + self.cols = n + self.domain = domain + + def getitem(self, i, j): + return self[i][j] + + def setitem(self, i, j, value): + self[i][j] = value + + def extract_slice(self, slice1, slice2): + ddm = [row[slice2] for row in self[slice1]] + rows = len(ddm) + cols = len(ddm[0]) if ddm else len(range(self.shape[1])[slice2]) + return DDM(ddm, (rows, cols), self.domain) + + def extract(self, rows, cols): + ddm = [] + for i in rows: + rowi = self[i] + ddm.append([rowi[j] for j in cols]) + return DDM(ddm, (len(rows), len(cols)), self.domain) + + @classmethod + def from_list(cls, rowslist, shape, domain): + """ + Create a :class:`DDM` from a list of lists. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM.from_list([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + >>> A + [[0, 1], [-1, 0]] + >>> A == DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + True + + See Also + ======== + + from_list_flat + """ + return cls(rowslist, shape, domain) + + @classmethod + def from_ddm(cls, other): + return other.copy() + + def to_list(self): + """ + Convert to a list of lists. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_list() + [[1, 2], [3, 4]] + + See Also + ======== + + to_list_flat + sympy.polys.matrices.domainmatrix.DomainMatrix.to_list + """ + return [row[:] for row in self] + + def to_list_flat(self): + """ + Convert to a flat list of elements. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_list_flat() + [1, 2, 3, 4] + >>> A == DDM.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.to_list_flat + """ + flat = [] + for row in self: + flat.extend(row) + return flat + + @classmethod + def from_list_flat(cls, flat, shape, domain): + """ + Create a :class:`DDM` from a flat list of elements. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM.from_list_flat([1, 2, 3, 4], (2, 2), QQ) + >>> A + [[1, 2], [3, 4]] + >>> A == DDM.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + to_list_flat + sympy.polys.matrices.domainmatrix.DomainMatrix.from_list_flat + """ + assert type(flat) is list + rows, cols = shape + if not (len(flat) == rows*cols): + raise DMBadInputError("Inconsistent flat-list shape") + lol = [flat[i*cols:(i+1)*cols] for i in range(rows)] + return cls(lol, shape, domain) + + def flatiter(self): + return chain.from_iterable(self) + + def flat(self): + items = [] + for row in self: + items.extend(row) + return items + + def to_flat_nz(self): + """ + Convert to a flat list of nonzero elements and data. + + Explanation + =========== + + This is used to operate on a list of the elements of a matrix and then + reconstruct a matrix using :meth:`from_flat_nz`. Zero elements are + included in the list but that may change in the future. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> elements, data = A.to_flat_nz() + >>> elements + [1, 2, 3, 4] + >>> A == DDM.from_flat_nz(elements, data, A.domain) + True + + See Also + ======== + + from_flat_nz + sympy.polys.matrices.sdm.SDM.to_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.to_flat_nz + """ + return self.to_sdm().to_flat_nz() + + @classmethod + def from_flat_nz(cls, elements, data, domain): + """ + Reconstruct a :class:`DDM` after calling :meth:`to_flat_nz`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> elements, data = A.to_flat_nz() + >>> elements + [1, 2, 3, 4] + >>> A == DDM.from_flat_nz(elements, data, A.domain) + True + + See Also + ======== + + to_flat_nz + sympy.polys.matrices.sdm.SDM.from_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.from_flat_nz + """ + return SDM.from_flat_nz(elements, data, domain).to_ddm() + + def to_dod(self): + """ + Convert to a dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dod() + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + + See Also + ======== + + from_dod + sympy.polys.matrices.sdm.SDM.to_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dod + """ + dod = {} + for i, row in enumerate(self): + row = {j:e for j, e in enumerate(row) if e} + if row: + dod[i] = row + return dod + + @classmethod + def from_dod(cls, dod, shape, domain): + """ + Create a :class:`DDM` from a dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> dod = {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + >>> A = DDM.from_dod(dod, (2, 2), QQ) + >>> A + [[1, 2], [3, 4]] + + See Also + ======== + + to_dod + sympy.polys.matrices.sdm.SDM.from_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.from_dod + """ + rows, cols = shape + lol = [[domain.zero] * cols for _ in range(rows)] + for i, row in dod.items(): + for j, element in row.items(): + lol[i][j] = element + return DDM(lol, shape, domain) + + def to_dok(self): + """ + Convert :class:`DDM` to dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dok() + {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4} + + See Also + ======== + + from_dok + sympy.polys.matrices.sdm.SDM.to_dok + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dok + """ + dok = {} + for i, row in enumerate(self): + for j, element in enumerate(row): + if element: + dok[i, j] = element + return dok + + @classmethod + def from_dok(cls, dok, shape, domain): + """ + Create a :class:`DDM` from a dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> dok = {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4} + >>> A = DDM.from_dok(dok, (2, 2), QQ) + >>> A + [[1, 2], [3, 4]] + + See Also + ======== + + to_dok + sympy.polys.matrices.sdm.SDM.from_dok + sympy.polys.matrices.domainmatrix.DomainMatrix.from_dok + """ + rows, cols = shape + lol = [[domain.zero] * cols for _ in range(rows)] + for (i, j), element in dok.items(): + lol[i][j] = element + return DDM(lol, shape, domain) + + def iter_values(self): + """ + Iterate over the non-zero values of the matrix. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[QQ(1), QQ(0)], [QQ(3), QQ(4)]], (2, 2), QQ) + >>> list(A.iter_values()) + [1, 3, 4] + + See Also + ======== + + iter_items + to_list_flat + sympy.polys.matrices.domainmatrix.DomainMatrix.iter_values + """ + for row in self: + yield from filter(None, row) + + def iter_items(self): + """ + Iterate over indices and values of nonzero elements of the matrix. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[QQ(1), QQ(0)], [QQ(3), QQ(4)]], (2, 2), QQ) + >>> list(A.iter_items()) + [((0, 0), 1), ((1, 0), 3), ((1, 1), 4)] + + See Also + ======== + + iter_values + to_dok + sympy.polys.matrices.domainmatrix.DomainMatrix.iter_items + """ + for i, row in enumerate(self): + for j, element in enumerate(row): + if element: + yield (i, j), element + + def to_ddm(self): + """ + Convert to a :class:`DDM`. + + This just returns ``self`` but exists to parallel the corresponding + method in other matrix types like :class:`~.SDM`. + + See Also + ======== + + to_sdm + to_dfm + to_dfm_or_ddm + sympy.polys.matrices.sdm.SDM.to_ddm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_ddm + """ + return self + + def to_sdm(self): + """ + Convert to a :class:`~.SDM`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_sdm() + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + >>> type(A.to_sdm()) + + + See Also + ======== + + SDM + sympy.polys.matrices.sdm.SDM.to_ddm + """ + return SDM.from_list(self, self.shape, self.domain) + + @doctest_depends_on(ground_types=['flint']) + def to_dfm(self): + """ + Convert to :class:`~.DDM` to :class:`~.DFM`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dfm() + [[1, 2], [3, 4]] + >>> type(A.to_dfm()) + + + See Also + ======== + + DFM + sympy.polys.matrices._dfm.DFM.to_ddm + """ + return DFM(list(self), self.shape, self.domain) + + @doctest_depends_on(ground_types=['flint']) + def to_dfm_or_ddm(self): + """ + Convert to :class:`~.DFM` if possible or otherwise return self. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dfm_or_ddm() + [[1, 2], [3, 4]] + >>> type(A.to_dfm_or_ddm()) + + + See Also + ======== + + to_dfm + to_ddm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm_or_ddm + """ + if DFM._supports_domain(self.domain): + return self.to_dfm() + return self + + def convert_to(self, K): + Kold = self.domain + if K == Kold: + return self.copy() + rows = [[K.convert_from(e, Kold) for e in row] for row in self] + return DDM(rows, self.shape, K) + + def __str__(self): + rowsstr = ['[%s]' % ', '.join(map(str, row)) for row in self] + return '[%s]' % ', '.join(rowsstr) + + def __repr__(self): + cls = type(self).__name__ + rows = list.__repr__(self) + return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain) + + def __eq__(self, other): + if not isinstance(other, DDM): + return False + return (super().__eq__(other) and self.domain == other.domain) + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def zeros(cls, shape, domain): + z = domain.zero + m, n = shape + rowslist = [[z] * n for _ in range(m)] + return DDM(rowslist, shape, domain) + + @classmethod + def ones(cls, shape, domain): + one = domain.one + m, n = shape + rowlist = [[one] * n for _ in range(m)] + return DDM(rowlist, shape, domain) + + @classmethod + def eye(cls, size, domain): + if isinstance(size, tuple): + m, n = size + elif isinstance(size, int): + m = n = size + one = domain.one + ddm = cls.zeros((m, n), domain) + for i in range(min(m, n)): + ddm[i][i] = one + return ddm + + def copy(self): + copyrows = [row[:] for row in self] + return DDM(copyrows, self.shape, self.domain) + + def transpose(self): + rows, cols = self.shape + if rows: + ddmT = ddm_transpose(self) + else: + ddmT = [[]] * cols + return DDM(ddmT, (cols, rows), self.domain) + + def __add__(a, b): + if not isinstance(b, DDM): + return NotImplemented + return a.add(b) + + def __sub__(a, b): + if not isinstance(b, DDM): + return NotImplemented + return a.sub(b) + + def __neg__(a): + return a.neg() + + def __mul__(a, b): + if b in a.domain: + return a.mul(b) + else: + return NotImplemented + + def __rmul__(a, b): + if b in a.domain: + return a.mul(b) + else: + return NotImplemented + + def __matmul__(a, b): + if isinstance(b, DDM): + return a.matmul(b) + else: + return NotImplemented + + @classmethod + def _check(cls, a, op, b, ashape, bshape): + if a.domain != b.domain: + msg = "Domain mismatch: %s %s %s" % (a.domain, op, b.domain) + raise DMDomainError(msg) + if ashape != bshape: + msg = "Shape mismatch: %s %s %s" % (a.shape, op, b.shape) + raise DMShapeError(msg) + + def add(a, b): + """a + b""" + a._check(a, '+', b, a.shape, b.shape) + c = a.copy() + ddm_iadd(c, b) + return c + + def sub(a, b): + """a - b""" + a._check(a, '-', b, a.shape, b.shape) + c = a.copy() + ddm_isub(c, b) + return c + + def neg(a): + """-a""" + b = a.copy() + ddm_ineg(b) + return b + + def mul(a, b): + c = a.copy() + ddm_imul(c, b) + return c + + def rmul(a, b): + c = a.copy() + ddm_irmul(c, b) + return c + + def matmul(a, b): + """a @ b (matrix product)""" + m, o = a.shape + o2, n = b.shape + a._check(a, '*', b, o, o2) + c = a.zeros((m, n), a.domain) + ddm_imatmul(c, a, b) + return c + + def mul_elementwise(a, b): + assert a.shape == b.shape + assert a.domain == b.domain + c = [[aij * bij for aij, bij in zip(ai, bi)] for ai, bi in zip(a, b)] + return DDM(c, a.shape, a.domain) + + def hstack(A, *B): + """Horizontally stacks :py:class:`~.DDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + + >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.hstack(B) + [[1, 2, 5, 6], [3, 4, 7, 8]] + + >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.hstack(B, C) + [[1, 2, 5, 6, 9, 10], [3, 4, 7, 8, 11, 12]] + """ + Anew = list(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkrows == rows + assert Bk.domain == domain + + cols += Bkcols + + for i, Bki in enumerate(Bk): + Anew[i].extend(Bki) + + return DDM(Anew, (rows, cols), A.domain) + + def vstack(A, *B): + """Vertically stacks :py:class:`~.DDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + + >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.vstack(B) + [[1, 2], [3, 4], [5, 6], [7, 8]] + + >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.vstack(B, C) + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] + """ + Anew = list(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkcols == cols + assert Bk.domain == domain + + rows += Bkrows + + Anew.extend(Bk.copy()) + + return DDM(Anew, (rows, cols), A.domain) + + def applyfunc(self, func, domain): + elements = [list(map(func, row)) for row in self] + return DDM(elements, self.shape, domain) + + def nnz(a): + """Number of non-zero entries in :py:class:`~.DDM` matrix. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nnz + """ + return sum(sum(map(bool, row)) for row in a) + + def scc(a): + """Strongly connected components of a square matrix *a*. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + >>> A = DDM([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(1)]], (2, 2), ZZ) + >>> A.scc() + [[0], [1]] + + See also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.scc + + """ + return a.to_sdm().scc() + + @classmethod + def diag(cls, values, domain): + """Returns a square diagonal matrix with *values* on the diagonal. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + >>> DDM.diag([ZZ(1), ZZ(2), ZZ(3)], ZZ) + [[1, 0, 0], [0, 2, 0], [0, 0, 3]] + + See also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.diag + """ + return SDM.diag(values, domain).to_ddm() + + def rref(a): + """Reduced-row echelon form of a and list of pivots. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + Higher level interface to this function. + sympy.polys.matrices.dense.ddm_irref + The underlying algorithm. + """ + b = a.copy() + K = a.domain + partial_pivot = K.is_RealField or K.is_ComplexField + pivots = ddm_irref(b, _partial_pivot=partial_pivot) + return b, pivots + + def rref_den(a): + """Reduced-row echelon form of a with denominator and list of pivots + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + Higher level interface to this function. + sympy.polys.matrices.dense.ddm_irref_den + The underlying algorithm. + """ + b = a.copy() + K = a.domain + denom, pivots = ddm_irref_den(b, K) + return b, denom, pivots + + def nullspace(a): + """Returns a basis for the nullspace of a. + + The domain of the matrix must be a field. + + See Also + ======== + + rref + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + """ + rref, pivots = a.rref() + return rref.nullspace_from_rref(pivots) + + def nullspace_from_rref(a, pivots=None): + """Compute the nullspace of a matrix from its rref. + + The domain of the matrix can be any domain. + + Returns a tuple (basis, nonpivots). + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + The higher level interface to this function. + """ + m, n = a.shape + K = a.domain + + if pivots is None: + pivots = [] + last_pivot = -1 + for i in range(m): + ai = a[i] + for j in range(last_pivot+1, n): + if ai[j]: + last_pivot = j + pivots.append(j) + break + + if not pivots: + return (a.eye(n, K), list(range(n))) + + # After rref the pivots are all one but after rref_den they may not be. + pivot_val = a[0][pivots[0]] + + basis = [] + nonpivots = [] + for i in range(n): + if i in pivots: + continue + nonpivots.append(i) + vec = [pivot_val if i == j else K.zero for j in range(n)] + for ii, jj in enumerate(pivots): + vec[jj] -= a[ii][i] + basis.append(vec) + + basis_ddm = DDM(basis, (len(basis), n), K) + + return (basis_ddm, nonpivots) + + def particular(a): + return a.to_sdm().particular().to_ddm() + + def det(a): + """Determinant of a""" + m, n = a.shape + if m != n: + raise DMNonSquareMatrixError("Determinant of non-square matrix") + b = a.copy() + K = b.domain + deta = ddm_idet(b, K) + return deta + + def inv(a): + """Inverse of a""" + m, n = a.shape + if m != n: + raise DMNonSquareMatrixError("Determinant of non-square matrix") + ainv = a.copy() + K = a.domain + ddm_iinv(ainv, a, K) + return ainv + + def lu(a): + """L, U decomposition of a""" + m, n = a.shape + K = a.domain + + U = a.copy() + L = a.eye(m, K) + swaps = ddm_ilu_split(L, U, K) + + return L, U, swaps + + def _fflu(self): + """ + Private method for Phase 1 of fraction-free LU decomposition. + Performs row operations and elimination to compute U and permutation indices. + + Returns: + LU : decomposition as a single matrix. + perm (list): Permutation indices for row swaps. + """ + rows, cols = self.shape + K = self.domain + + LU = self.copy() + perm = list(range(rows)) + rank = 0 + + for j in range(min(rows, cols)): + # Skip columns where all entries are zero + if all(LU[i][j] == K.zero for i in range(rows)): + continue + + # Find the first non-zero pivot in the current column + pivot_row = -1 + for i in range(rank, rows): + if LU[i][j] != K.zero: + pivot_row = i + break + + # If no pivot is found, skip column + if pivot_row == -1: + continue + + # Swap rows to bring the pivot to the current rank + if pivot_row != rank: + LU[rank], LU[pivot_row] = LU[pivot_row], LU[rank] + perm[rank], perm[pivot_row] = perm[pivot_row], perm[rank] + + # Found pivot - (Gauss-Bareiss elimination) + pivot = LU[rank][j] + for i in range(rank + 1, rows): + multiplier = LU[i][j] + # Denominator is previous pivot or 1 + denominator = LU[rank - 1][rank - 1] if rank > 0 else K.one + for k in range(j + 1, cols): + LU[i][k] = K.exquo(pivot * LU[i][k] - LU[rank][k] * multiplier, denominator) + # Keep the multiplier for L matrix + LU[i][j] = multiplier + rank += 1 + + return LU, perm + + def fflu(self): + """ + Fraction-free LU decomposition of DDM. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.fflu + The higher-level interface to this function. + """ + rows, cols = self.shape + K = self.domain + + # Phase 1: Perform row operations and get permutation + U, perm = self._fflu() + + # Phase 2: Construct P, L, D matrices + # Create P from permutation + P = self.zeros((rows, rows), K) + for i, pi in enumerate(perm): + P[i][pi] = K.one + + # Create L matrix + L = self.zeros((rows, rows), K) + i = j = 0 + while i < rows and j < cols: + if U[i][j] != K.zero: + # Found non-zero pivot + # Diagonal entry is the pivot + L[i][i] = U[i][j] + for l in range(i + 1, rows): + # Off-diagonal entries are the multipliers + L[l][i] = U[l][j] + # zero out the entries in U + U[l][j] = K.zero + i += 1 + j += 1 + + # Fill remaining diagonal of L with ones + for i in range(i, rows): + L[i][i] = K.one + + # Create D matrix - using FLINT's approach with accumulator + D = self.zeros((rows, rows), K) + if rows >= 1: + D[0][0] = L[0][0] + di = K.one + for i in range(1, rows): + # Accumulate product of pivots + di = L[i - 1][i - 1] * L[i][i] + D[i][i] = di + + return P, L, D, U + + def qr(self): + """ + QR decomposition for DDM. + + Returns: + - Q: Orthogonal matrix as a DDM. + - R: Upper triangular matrix as a DDM. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.qr + The higher-level interface to this function. + """ + rows, cols = self.shape + K = self.domain + Q = self.copy() + R = self.zeros((min(rows, cols), cols), K) + + # Check that the domain is a field + if not K.is_Field: + raise DMDomainError("QR decomposition requires a field (e.g. QQ).") + + dot_cols = lambda i, j: K.sum(Q[k][i] * Q[k][j] for k in range(rows)) + + for j in range(cols): + for i in range(min(j, rows)): + dot_ii = dot_cols(i, i) + if dot_ii != K.zero: + R[i][j] = dot_cols(i, j) / dot_ii + for k in range(rows): + Q[k][j] -= R[i][j] * Q[k][i] + + if j < rows: + dot_jj = dot_cols(j, j) + if dot_jj != K.zero: + R[j][j] = K.one + + Q = Q.extract(range(rows), range(min(rows, cols))) + + return Q, R + + def lu_solve(a, b): + """x where a*x = b""" + m, n = a.shape + m2, o = b.shape + a._check(a, 'lu_solve', b, m, m2) + if not a.domain.is_Field: + raise DMDomainError("lu_solve requires a field") + + L, U, swaps = a.lu() + x = a.zeros((n, o), a.domain) + ddm_ilu_solve(x, L, U, swaps, b) + return x + + def charpoly(a): + """Coefficients of characteristic polynomial of a""" + K = a.domain + m, n = a.shape + if m != n: + raise DMNonSquareMatrixError("Charpoly of non-square matrix") + vec = ddm_berk(a, K) + coeffs = [vec[i][0] for i in range(n+1)] + return coeffs + + def is_zero_matrix(self): + """ + Says whether this matrix has all zero entries. + """ + zero = self.domain.zero + return all(Mij == zero for Mij in self.flatiter()) + + def is_upper(self): + """ + Says whether this matrix is upper-triangular. True can be returned + even if the matrix is not square. + """ + zero = self.domain.zero + return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[:i]) + + def is_lower(self): + """ + Says whether this matrix is lower-triangular. True can be returned + even if the matrix is not square. + """ + zero = self.domain.zero + return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[i+1:]) + + def is_diagonal(self): + """ + Says whether this matrix is diagonal. True can be returned even if + the matrix is not square. + """ + return self.is_upper() and self.is_lower() + + def diagonal(self): + """ + Returns a list of the elements from the diagonal of the matrix. + """ + m, n = self.shape + return [self[i][i] for i in range(min(m, n))] + + def lll(A, delta=QQ(3, 4)): + return ddm_lll(A, delta=delta) + + def lll_transform(A, delta=QQ(3, 4)): + return ddm_lll_transform(A, delta=delta) + + +from .sdm import SDM +from .dfm import DFM diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/dense.py b/phivenv/Lib/site-packages/sympy/polys/matrices/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..47ab2d6897c6d9f3781af23ccb68f96f15c7e859 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/dense.py @@ -0,0 +1,824 @@ +""" + +Module for the ddm_* routines for operating on a matrix in list of lists +matrix representation. + +These routines are used internally by the DDM class which also provides a +friendlier interface for them. The idea here is to implement core matrix +routines in a way that can be applied to any simple list representation +without the need to use any particular matrix class. For example we can +compute the RREF of a matrix like: + + >>> from sympy.polys.matrices.dense import ddm_irref + >>> M = [[1, 2, 3], [4, 5, 6]] + >>> pivots = ddm_irref(M) + >>> M + [[1.0, 0.0, -1.0], [0, 1.0, 2.0]] + +These are lower-level routines that work mostly in place.The routines at this +level should not need to know what the domain of the elements is but should +ideally document what operations they will use and what functions they need to +be provided with. + +The next-level up is the DDM class which uses these routines but wraps them up +with an interface that handles copying etc and keeps track of the Domain of +the elements of the matrix: + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> M = DDM([[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]], (2, 3), QQ) + >>> M + [[1, 2, 3], [4, 5, 6]] + >>> Mrref, pivots = M.rref() + >>> Mrref + [[1, 0, -1], [0, 1, 2]] + +""" +from __future__ import annotations +from operator import mul +from .exceptions import ( + DMShapeError, + DMDomainError, + DMNonInvertibleMatrixError, + DMNonSquareMatrixError, +) +from typing import Sequence, TypeVar +from sympy.polys.matrices._typing import RingElement + + +#: Type variable for the elements of the matrix +T = TypeVar('T') + +#: Type variable for the elements of the matrix that are in a ring +R = TypeVar('R', bound=RingElement) + + +def ddm_transpose(matrix: Sequence[Sequence[T]]) -> list[list[T]]: + """matrix transpose""" + return list(map(list, zip(*matrix))) + + +def ddm_iadd(a: list[list[R]], b: Sequence[Sequence[R]]) -> None: + """a += b""" + for ai, bi in zip(a, b): + for j, bij in enumerate(bi): + ai[j] += bij + + +def ddm_isub(a: list[list[R]], b: Sequence[Sequence[R]]) -> None: + """a -= b""" + for ai, bi in zip(a, b): + for j, bij in enumerate(bi): + ai[j] -= bij + + +def ddm_ineg(a: list[list[R]]) -> None: + """a <-- -a""" + for ai in a: + for j, aij in enumerate(ai): + ai[j] = -aij + + +def ddm_imul(a: list[list[R]], b: R) -> None: + """a <-- a*b""" + for ai in a: + for j, aij in enumerate(ai): + ai[j] = aij * b + + +def ddm_irmul(a: list[list[R]], b: R) -> None: + """a <-- b*a""" + for ai in a: + for j, aij in enumerate(ai): + ai[j] = b * aij + + +def ddm_imatmul( + a: list[list[R]], b: Sequence[Sequence[R]], c: Sequence[Sequence[R]] +) -> None: + """a += b @ c""" + cT = list(zip(*c)) + + for bi, ai in zip(b, a): + for j, cTj in enumerate(cT): + ai[j] = sum(map(mul, bi, cTj), ai[j]) + + +def ddm_irref(a, _partial_pivot=False): + """In-place reduced row echelon form of a matrix. + + Compute the reduced row echelon form of $a$. Modifies $a$ in place and + returns a list of the pivot columns. + + Uses naive Gauss-Jordan elimination in the ground domain which must be a + field. + + This routine is only really suitable for use with simple field domains like + :ref:`GF(p)`, :ref:`QQ` and :ref:`QQ(a)` although even for :ref:`QQ` with + larger matrices it is possibly more efficient to use fraction free + approaches. + + This method is not suitable for use with rational function fields + (:ref:`K(x)`) because the elements will blowup leading to costly gcd + operations. In this case clearing denominators and using fraction free + approaches is likely to be more efficient. + + For inexact numeric domains like :ref:`RR` and :ref:`CC` pass + ``_partial_pivot=True`` to use partial pivoting to control rounding errors. + + Examples + ======== + + >>> from sympy.polys.matrices.dense import ddm_irref + >>> from sympy import QQ + >>> M = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + >>> pivots = ddm_irref(M) + >>> M + [[1, 0, -1], [0, 1, 2]] + >>> pivots + [0, 1] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + Higher level interface to this routine. + ddm_irref_den + The fraction free version of this routine. + sdm_irref + A sparse version of this routine. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Row_echelon_form#Reduced_row_echelon_form + """ + # We compute aij**-1 below and then use multiplication instead of division + # in the innermost loop. The domain here is a field so either operation is + # defined. There are significant performance differences for some domains + # though. In the case of e.g. QQ or QQ(x) inversion is free but + # multiplication and division have the same cost so it makes no difference. + # In cases like GF(p), QQ, RR or CC though multiplication is + # faster than division so reusing a precomputed inverse for many + # multiplications can be a lot faster. The biggest win is QQ when + # deg(minpoly(a)) is large. + # + # With domains like QQ(x) this can perform badly for other reasons. + # Typically the initial matrix has simple denominators and the + # fraction-free approach with exquo (ddm_irref_den) will preserve that + # property throughout. The method here causes denominator blowup leading to + # expensive gcd reductions in the intermediate expressions. With many + # generators like QQ(x,y,z,...) this is extremely bad. + # + # TODO: Use a nontrivial pivoting strategy to control intermediate + # expression growth. Rearranging rows and/or columns could defer the most + # complicated elements until the end. If the first pivot is a + # complicated/large element then the first round of reduction will + # immediately introduce expression blowup across the whole matrix. + + # a is (m x n) + m = len(a) + if not m: + return [] + n = len(a[0]) + + i = 0 + pivots = [] + + for j in range(n): + # Proper pivoting should be used for all domains for performance + # reasons but it is only strictly needed for RR and CC (and possibly + # other domains like RR(x)). This path is used by DDM.rref() if the + # domain is RR or CC. It uses partial (row) pivoting based on the + # absolute value of the pivot candidates. + if _partial_pivot: + ip = max(range(i, m), key=lambda ip: abs(a[ip][j])) + a[i], a[ip] = a[ip], a[i] + + # pivot + aij = a[i][j] + + # zero-pivot + if not aij: + for ip in range(i+1, m): + aij = a[ip][j] + # row-swap + if aij: + a[i], a[ip] = a[ip], a[i] + break + else: + # next column + continue + + # normalise row + ai = a[i] + aijinv = aij**-1 + for l in range(j, n): + ai[l] *= aijinv # ai[j] = one + + # eliminate above and below to the right + for k, ak in enumerate(a): + if k == i or not ak[j]: + continue + akj = ak[j] + ak[j] -= akj # ak[j] = zero + for l in range(j+1, n): + ak[l] -= akj * ai[l] + + # next row + pivots.append(j) + i += 1 + + # no more rows? + if i >= m: + break + + return pivots + + +def ddm_irref_den(a, K): + """a <-- rref(a); return (den, pivots) + + Compute the fraction-free reduced row echelon form (RREF) of $a$. Modifies + $a$ in place and returns a tuple containing the denominator of the RREF and + a list of the pivot columns. + + Explanation + =========== + + The algorithm used is the fraction-free version of Gauss-Jordan elimination + described as FFGJ in [1]_. Here it is modified to handle zero or missing + pivots and to avoid redundant arithmetic. + + The domain $K$ must support exact division (``K.exquo``) but does not need + to be a field. This method is suitable for most exact rings and fields like + :ref:`ZZ`, :ref:`QQ` and :ref:`QQ(a)`. In the case of :ref:`QQ` or + :ref:`K(x)` it might be more efficient to clear denominators and use + :ref:`ZZ` or :ref:`K[x]` instead. + + For inexact domains like :ref:`RR` and :ref:`CC` use ``ddm_irref`` instead. + + Examples + ======== + + >>> from sympy.polys.matrices.dense import ddm_irref_den + >>> from sympy import ZZ, Matrix + >>> M = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)]] + >>> den, pivots = ddm_irref_den(M, ZZ) + >>> M + [[-3, 0, 3], [0, -3, -6]] + >>> den + -3 + >>> pivots + [0, 1] + >>> Matrix(M).rref()[0] + Matrix([ + [1, 0, -1], + [0, 1, 2]]) + + See Also + ======== + + ddm_irref + A version of this routine that uses field division. + sdm_irref + A sparse version of :func:`ddm_irref`. + sdm_rref_den + A sparse version of :func:`ddm_irref_den`. + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + Higher level interface. + + References + ========== + + .. [1] Fraction-free algorithms for linear and polynomial equations. + George C. Nakos , Peter R. Turner , Robert M. Williams. + https://dl.acm.org/doi/10.1145/271130.271133 + """ + # + # A simpler presentation of this algorithm is given in [1]: + # + # Given an n x n matrix A and n x 1 matrix b: + # + # for i in range(n): + # if i != 0: + # d = a[i-1][i-1] + # for j in range(n): + # if j == i: + # continue + # b[j] = a[i][i]*b[j] - a[j][i]*b[i] + # for k in range(n): + # a[j][k] = a[i][i]*a[j][k] - a[j][i]*a[i][k] + # if i != 0: + # a[j][k] /= d + # + # Our version here is a bit more complicated because: + # + # 1. We use row-swaps to avoid zero pivots. + # 2. We allow for some columns to be missing pivots. + # 3. We avoid a lot of redundant arithmetic. + # + # TODO: Use a non-trivial pivoting strategy. Even just row swapping makes a + # big difference to performance if e.g. the upper-left entry of the matrix + # is a huge polynomial. + + # a is (m x n) + m = len(a) + if not m: + return K.one, [] + n = len(a[0]) + + d = None + pivots = [] + no_pivots = [] + + # i, j will be the row and column indices of the current pivot + i = 0 + for j in range(n): + # next pivot? + aij = a[i][j] + + # swap rows if zero + if not aij: + for ip in range(i+1, m): + aij = a[ip][j] + # row-swap + if aij: + a[i], a[ip] = a[ip], a[i] + break + else: + # go to next column + no_pivots.append(j) + continue + + # Now aij is the pivot and i,j are the row and column. We need to clear + # the column above and below but we also need to keep track of the + # denominator of the RREF which means also multiplying everything above + # and to the left by the current pivot aij and dividing by d (which we + # multiplied everything by in the previous iteration so this is an + # exact division). + # + # First handle the upper left corner which is usually already diagonal + # with all diagonal entries equal to the current denominator but there + # can be other non-zero entries in any column that has no pivot. + + # Update previous pivots in the matrix + if pivots: + pivot_val = aij * a[0][pivots[0]] + # Divide out the common factor + if d is not None: + pivot_val = K.exquo(pivot_val, d) + + # Could defer this until the end but it is pretty cheap and + # helps when debugging. + for ip, jp in enumerate(pivots): + a[ip][jp] = pivot_val + + # Update columns without pivots + for jnp in no_pivots: + for ip in range(i): + aijp = a[ip][jnp] + if aijp: + aijp *= aij + if d is not None: + aijp = K.exquo(aijp, d) + a[ip][jnp] = aijp + + # Eliminate above, below and to the right as in ordinary division free + # Gauss-Jordan elmination except also dividing out d from every entry. + + for jp, aj in enumerate(a): + + # Skip the current row + if jp == i: + continue + + # Eliminate to the right in all rows + for kp in range(j+1, n): + ajk = aij * aj[kp] - aj[j] * a[i][kp] + if d is not None: + ajk = K.exquo(ajk, d) + aj[kp] = ajk + + # Set to zero above and below the pivot + aj[j] = K.zero + + # next row + pivots.append(j) + i += 1 + + # no more rows left? + if i >= m: + break + + if not K.is_one(aij): + d = aij + else: + d = None + + if not pivots: + denom = K.one + else: + denom = a[0][pivots[0]] + + return denom, pivots + + +def ddm_idet(a, K): + """a <-- echelon(a); return det + + Explanation + =========== + + Compute the determinant of $a$ using the Bareiss fraction-free algorithm. + The matrix $a$ is modified in place. Its diagonal elements are the + determinants of the leading principal minors. The determinant of $a$ is + returned. + + The domain $K$ must support exact division (``K.exquo``). This method is + suitable for most exact rings and fields like :ref:`ZZ`, :ref:`QQ` and + :ref:`QQ(a)` but not for inexact domains like :ref:`RR` and :ref:`CC`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.ddm import ddm_idet + >>> a = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)], [ZZ(7), ZZ(8), ZZ(9)]] + >>> a + [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + >>> ddm_idet(a, ZZ) + 0 + >>> a + [[1, 2, 3], [4, -3, -6], [7, -6, 0]] + >>> [a[i][i] for i in range(len(a))] + [1, -3, 0] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.det + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bareiss_algorithm + .. [2] https://www.math.usm.edu/perry/Research/Thesis_DRL.pdf + """ + # Bareiss algorithm + # https://www.math.usm.edu/perry/Research/Thesis_DRL.pdf + + # a is (m x n) + m = len(a) + if not m: + return K.one + n = len(a[0]) + + exquo = K.exquo + # uf keeps track of the sign change from row swaps + uf = K.one + + for k in range(n-1): + if not a[k][k]: + for i in range(k+1, n): + if a[i][k]: + a[k], a[i] = a[i], a[k] + uf = -uf + break + else: + return K.zero + + akkm1 = a[k-1][k-1] if k else K.one + + for i in range(k+1, n): + for j in range(k+1, n): + a[i][j] = exquo(a[i][j]*a[k][k] - a[i][k]*a[k][j], akkm1) + + return uf * a[-1][-1] + + +def ddm_iinv(ainv, a, K): + """ainv <-- inv(a) + + Compute the inverse of a matrix $a$ over a field $K$ using Gauss-Jordan + elimination. The result is stored in $ainv$. + + Uses division in the ground domain which should be an exact field. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import ddm_iinv, ddm_imatmul + >>> from sympy import QQ + >>> a = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + >>> ainv = [[None, None], [None, None]] + >>> ddm_iinv(ainv, a, QQ) + >>> ainv + [[-2, 1], [3/2, -1/2]] + >>> result = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + >>> ddm_imatmul(result, a, ainv) + >>> result + [[1, 0], [0, 1]] + + See Also + ======== + + ddm_irref: the underlying routine. + """ + if not K.is_Field: + raise DMDomainError('Not a field') + + # a is (m x n) + m = len(a) + if not m: + return + n = len(a[0]) + if m != n: + raise DMNonSquareMatrixError + + eye = [[K.one if i==j else K.zero for j in range(n)] for i in range(n)] + Aaug = [row + eyerow for row, eyerow in zip(a, eye)] + pivots = ddm_irref(Aaug) + if pivots != list(range(n)): + raise DMNonInvertibleMatrixError('Matrix det == 0; not invertible.') + ainv[:] = [row[n:] for row in Aaug] + + +def ddm_ilu_split(L, U, K): + """L, U <-- LU(U) + + Compute the LU decomposition of a matrix $L$ in place and store the lower + and upper triangular matrices in $L$ and $U$, respectively. Returns a list + of row swaps that were performed. + + Uses division in the ground domain which should be an exact field. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import ddm_ilu_split + >>> from sympy import QQ + >>> L = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + >>> U = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + >>> swaps = ddm_ilu_split(L, U, QQ) + >>> swaps + [] + >>> L + [[0, 0], [3, 0]] + >>> U + [[1, 2], [0, -2]] + + See Also + ======== + + ddm_ilu + ddm_ilu_solve + """ + m = len(U) + if not m: + return [] + n = len(U[0]) + + swaps = ddm_ilu(U) + + zeros = [K.zero] * min(m, n) + for i in range(1, m): + j = min(i, n) + L[i][:j] = U[i][:j] + U[i][:j] = zeros[:j] + + return swaps + + +def ddm_ilu(a): + """a <-- LU(a) + + Computes the LU decomposition of a matrix in place. Returns a list of + row swaps that were performed. + + Uses division in the ground domain which should be an exact field. + + This is only suitable for domains like :ref:`GF(p)`, :ref:`QQ`, :ref:`QQ_I` + and :ref:`QQ(a)`. With a rational function field like :ref:`K(x)` it is + better to clear denominators and use division-free algorithms. Pivoting is + used to avoid exact zeros but not for floating point accuracy so :ref:`RR` + and :ref:`CC` are not suitable (use :func:`ddm_irref` instead). + + Examples + ======== + + >>> from sympy.polys.matrices.dense import ddm_ilu + >>> from sympy import QQ + >>> a = [[QQ(1, 2), QQ(1, 3)], [QQ(1, 4), QQ(1, 5)]] + >>> swaps = ddm_ilu(a) + >>> swaps + [] + >>> a + [[1/2, 1/3], [1/2, 1/30]] + + The same example using ``Matrix``: + + >>> from sympy import Matrix, S + >>> M = Matrix([[S(1)/2, S(1)/3], [S(1)/4, S(1)/5]]) + >>> L, U, swaps = M.LUdecomposition() + >>> L + Matrix([ + [ 1, 0], + [1/2, 1]]) + >>> U + Matrix([ + [1/2, 1/3], + [ 0, 1/30]]) + >>> swaps + [] + + See Also + ======== + + ddm_irref + ddm_ilu_solve + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + """ + m = len(a) + if not m: + return [] + n = len(a[0]) + + swaps = [] + + for i in range(min(m, n)): + if not a[i][i]: + for ip in range(i+1, m): + if a[ip][i]: + swaps.append((i, ip)) + a[i], a[ip] = a[ip], a[i] + break + else: + # M = Matrix([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]]) + continue + for j in range(i+1, m): + l_ji = a[j][i] / a[i][i] + a[j][i] = l_ji + for k in range(i+1, n): + a[j][k] -= l_ji * a[i][k] + + return swaps + + +def ddm_ilu_solve(x, L, U, swaps, b): + """x <-- solve(L*U*x = swaps(b)) + + Solve a linear system, $A*x = b$, given an LU factorization of $A$. + + Uses division in the ground domain which must be a field. + + Modifies $x$ in place. + + Examples + ======== + + Compute the LU decomposition of $A$ (in place): + + >>> from sympy import QQ + >>> from sympy.polys.matrices.dense import ddm_ilu, ddm_ilu_solve + >>> A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + >>> swaps = ddm_ilu(A) + >>> A + [[1, 2], [3, -2]] + >>> L = U = A + + Solve the linear system: + + >>> b = [[QQ(5)], [QQ(6)]] + >>> x = [[None], [None]] + >>> ddm_ilu_solve(x, L, U, swaps, b) + >>> x + [[-4], [9/2]] + + See Also + ======== + + ddm_ilu + Compute the LU decomposition of a matrix in place. + ddm_ilu_split + Compute the LU decomposition of a matrix and separate $L$ and $U$. + sympy.polys.matrices.domainmatrix.DomainMatrix.lu_solve + Higher level interface to this function. + """ + m = len(U) + if not m: + return + n = len(U[0]) + + m2 = len(b) + if not m2: + raise DMShapeError("Shape mismtch") + o = len(b[0]) + + if m != m2: + raise DMShapeError("Shape mismtch") + if m < n: + raise NotImplementedError("Underdetermined") + + if swaps: + b = [row[:] for row in b] + for i1, i2 in swaps: + b[i1], b[i2] = b[i2], b[i1] + + # solve Ly = b + y = [[None] * o for _ in range(m)] + for k in range(o): + for i in range(m): + rhs = b[i][k] + for j in range(i): + rhs -= L[i][j] * y[j][k] + y[i][k] = rhs + + if m > n: + for i in range(n, m): + for j in range(o): + if y[i][j]: + raise DMNonInvertibleMatrixError + + # Solve Ux = y + for k in range(o): + for i in reversed(range(n)): + if not U[i][i]: + raise DMNonInvertibleMatrixError + rhs = y[i][k] + for j in range(i+1, n): + rhs -= U[i][j] * x[j][k] + x[i][k] = rhs / U[i][i] + + +def ddm_berk(M, K): + """ + Berkowitz algorithm for computing the characteristic polynomial. + + Explanation + =========== + + The Berkowitz algorithm is a division-free algorithm for computing the + characteristic polynomial of a matrix over any commutative ring using only + arithmetic in the coefficient ring. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.polys.matrices.dense import ddm_berk + >>> from sympy.polys.domains import ZZ + >>> M = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + >>> ddm_berk(M, ZZ) + [[1], [-5], [-2]] + >>> Matrix(M).charpoly() + PurePoly(lambda**2 - 5*lambda - 2, lambda, domain='ZZ') + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.charpoly + The high-level interface to this function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Samuelson%E2%80%93Berkowitz_algorithm + """ + m = len(M) + if not m: + return [[K.one]] + n = len(M[0]) + + if m != n: + raise DMShapeError("Not square") + + if n == 1: + return [[K.one], [-M[0][0]]] + + a = M[0][0] + R = [M[0][1:]] + C = [[row[0]] for row in M[1:]] + A = [row[1:] for row in M[1:]] + + q = ddm_berk(A, K) + + T = [[K.zero] * n for _ in range(n+1)] + for i in range(n): + T[i][i] = K.one + T[i+1][i] = -a + for i in range(2, n+1): + if i == 2: + AnC = C + else: + C = AnC + AnC = [[K.zero] for row in C] + ddm_imatmul(AnC, A, C) + RAnC = [[K.zero]] + ddm_imatmul(RAnC, R, AnC) + for j in range(0, n+1-i): + T[i+j][j] = -RAnC[0][0] + + qout = [[K.zero] for _ in range(n+1)] + ddm_imatmul(qout, T, q) + return qout diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/dfm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/dfm.py new file mode 100644 index 0000000000000000000000000000000000000000..22938b7004654121f74b020bd6649bee84909e1e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/dfm.py @@ -0,0 +1,35 @@ +""" +sympy.polys.matrices.dfm + +Provides the :class:`DFM` class if ``GROUND_TYPES=flint'``. Otherwise, ``DFM`` +is a placeholder class that raises NotImplementedError when instantiated. +""" + +from sympy.external.gmpy import GROUND_TYPES + +if GROUND_TYPES == "flint": # pragma: no cover + # When python-flint is installed we will try to use it for dense matrices + # if the domain is supported by python-flint. + from ._dfm import DFM + +else: # pragma: no cover + # Other code should be able to import this and it should just present as a + # version of DFM that does not support any domains. + class DFM_dummy: + """ + Placeholder class for DFM when python-flint is not installed. + """ + def __init__(*args, **kwargs): + raise NotImplementedError("DFM requires GROUND_TYPES=flint.") + + @classmethod + def _supports_domain(cls, domain): + return False + + @classmethod + def _get_flint_func(cls, domain): + raise NotImplementedError("DFM requires GROUND_TYPES=flint.") + + # mypy really struggles with this kind of conditional type assignment. + # Maybe there is a better way to annotate this rather than type: ignore. + DFM = DFM_dummy # type: ignore diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/domainmatrix.py b/phivenv/Lib/site-packages/sympy/polys/matrices/domainmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..627835eca93b5e70f9aa121f097c9828a709ca78 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/domainmatrix.py @@ -0,0 +1,3983 @@ +""" + +Module for the DomainMatrix class. + +A DomainMatrix represents a matrix with elements that are in a particular +Domain. Each DomainMatrix internally wraps a DDM which is used for the +lower-level operations. The idea is that the DomainMatrix class provides the +convenience routines for converting between Expr and the poly domains as well +as unifying matrices with different domains. + +""" +from __future__ import annotations +from collections import Counter +from functools import reduce + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on + +from sympy.core.sympify import _sympify + +from ..domains import Domain + +from ..constructor import construct_domain + +from .exceptions import ( + DMFormatError, + DMBadInputError, + DMShapeError, + DMDomainError, + DMNotAField, + DMNonSquareMatrixError, + DMNonInvertibleMatrixError +) + +from .domainscalar import DomainScalar + +from sympy.polys.domains import ZZ, EXRAW, QQ + +from sympy.polys.densearith import dup_mul +from sympy.polys.densebasic import dup_convert +from sympy.polys.densetools import ( + dup_mul_ground, + dup_quo_ground, + dup_content, + dup_clear_denoms, + dup_primitive, + dup_transform, +) +from sympy.polys.factortools import dup_factor_list +from sympy.polys.polyutils import _sort_factors + +from .ddm import DDM + +from .sdm import SDM + +from .dfm import DFM + +from .rref import _dm_rref, _dm_rref_den + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['DomainMatrix.to_dfm', 'DomainMatrix.to_dfm_or_ddm'] +else: + __doctest_skip__ = ['DomainMatrix.from_list'] + + +def DM(rows, domain): + """Convenient alias for DomainMatrix.from_list + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> DM([[1, 2], [3, 4]], ZZ) + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + + See Also + ======== + + DomainMatrix.from_list + """ + return DomainMatrix.from_list(rows, domain) + + +class DomainMatrix: + r""" + Associate Matrix with :py:class:`~.Domain` + + Explanation + =========== + + DomainMatrix uses :py:class:`~.Domain` for its internal representation + which makes it faster than the SymPy Matrix class (currently) for many + common operations, but this advantage makes it not entirely compatible + with Matrix. DomainMatrix are analogous to numpy arrays with "dtype". + In the DomainMatrix, each element has a domain such as :ref:`ZZ` + or :ref:`QQ(a)`. + + + Examples + ======== + + Creating a DomainMatrix from the existing Matrix class: + + >>> from sympy import Matrix + >>> from sympy.polys.matrices import DomainMatrix + >>> Matrix1 = Matrix([ + ... [1, 2], + ... [3, 4]]) + >>> A = DomainMatrix.from_Matrix(Matrix1) + >>> A + DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + + Directly forming a DomainMatrix: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> A + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + + See Also + ======== + + DDM + SDM + Domain + Poly + + """ + rep: SDM | DDM | DFM + shape: tuple[int, int] + domain: Domain + + def __new__(cls, rows, shape, domain, *, fmt=None): + """ + Creates a :py:class:`~.DomainMatrix`. + + Parameters + ========== + + rows : Represents elements of DomainMatrix as list of lists + shape : Represents dimension of DomainMatrix + domain : Represents :py:class:`~.Domain` of DomainMatrix + + Raises + ====== + + TypeError + If any of rows, shape and domain are not provided + + """ + if isinstance(rows, (DDM, SDM, DFM)): + raise TypeError("Use from_rep to initialise from SDM/DDM") + elif isinstance(rows, list): + rep = DDM(rows, shape, domain) + elif isinstance(rows, dict): + rep = SDM(rows, shape, domain) + else: + msg = "Input should be list-of-lists or dict-of-dicts" + raise TypeError(msg) + + if fmt is not None: + if fmt == 'sparse': + rep = rep.to_sdm() + elif fmt == 'dense': + rep = rep.to_ddm() + else: + raise ValueError("fmt should be 'sparse' or 'dense'") + + # Use python-flint for dense matrices if possible + if rep.fmt == 'dense' and DFM._supports_domain(domain): + rep = rep.to_dfm() + + return cls.from_rep(rep) + + def __reduce__(self): + rep = self.rep + if rep.fmt == 'dense': + arg = self.to_list() + elif rep.fmt == 'sparse': + arg = dict(rep) + else: + raise RuntimeError # pragma: no cover + args = (arg, rep.shape, rep.domain) + return (self.__class__, args) + + def __getitem__(self, key): + i, j = key + m, n = self.shape + if not (isinstance(i, slice) or isinstance(j, slice)): + return DomainScalar(self.rep.getitem(i, j), self.domain) + + if not isinstance(i, slice): + if not -m <= i < m: + raise IndexError("Row index out of range") + i = i % m + i = slice(i, i+1) + if not isinstance(j, slice): + if not -n <= j < n: + raise IndexError("Column index out of range") + j = j % n + j = slice(j, j+1) + + return self.from_rep(self.rep.extract_slice(i, j)) + + def getitem_sympy(self, i, j): + return self.domain.to_sympy(self.rep.getitem(i, j)) + + def extract(self, rowslist, colslist): + return self.from_rep(self.rep.extract(rowslist, colslist)) + + def __setitem__(self, key, value): + i, j = key + if not self.domain.of_type(value): + raise TypeError + if isinstance(i, int) and isinstance(j, int): + self.rep.setitem(i, j, value) + else: + raise NotImplementedError + + @classmethod + def from_rep(cls, rep): + """Create a new DomainMatrix efficiently from DDM/SDM. + + Examples + ======== + + Create a :py:class:`~.DomainMatrix` with an dense internal + representation as :py:class:`~.DDM`: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.ddm import DDM + >>> drep = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> dM = DomainMatrix.from_rep(drep) + >>> dM + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + + Create a :py:class:`~.DomainMatrix` with a sparse internal + representation as :py:class:`~.SDM`: + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import ZZ + >>> drep = SDM({0:{1:ZZ(1)},1:{0:ZZ(2)}}, (2, 2), ZZ) + >>> dM = DomainMatrix.from_rep(drep) + >>> dM + DomainMatrix({0: {1: 1}, 1: {0: 2}}, (2, 2), ZZ) + + Parameters + ========== + + rep: SDM or DDM + The internal sparse or dense representation of the matrix. + + Returns + ======= + + DomainMatrix + A :py:class:`~.DomainMatrix` wrapping *rep*. + + Notes + ===== + + This takes ownership of rep as its internal representation. If rep is + being mutated elsewhere then a copy should be provided to + ``from_rep``. Only minimal verification or checking is done on *rep* + as this is supposed to be an efficient internal routine. + + """ + if not (isinstance(rep, (DDM, SDM)) or (DFM is not None and isinstance(rep, DFM))): + raise TypeError("rep should be of type DDM or SDM") + self = super().__new__(cls) + self.rep = rep + self.shape = rep.shape + self.domain = rep.domain + return self + + @classmethod + @doctest_depends_on(ground_types=['python', 'gmpy']) + def from_list(cls, rows, domain): + r""" + Convert a list of lists into a DomainMatrix + + Parameters + ========== + + rows: list of lists + Each element of the inner lists should be either the single arg, + or tuple of args, that would be passed to the domain constructor + in order to form an element of the domain. See examples. + + Returns + ======= + + DomainMatrix containing elements defined in rows + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import FF, QQ, ZZ + >>> A = DomainMatrix.from_list([[1, 0, 1], [0, 0, 1]], ZZ) + >>> A + DomainMatrix([[1, 0, 1], [0, 0, 1]], (2, 3), ZZ) + >>> B = DomainMatrix.from_list([[1, 0, 1], [0, 0, 1]], FF(7)) + >>> B + DomainMatrix([[1 mod 7, 0 mod 7, 1 mod 7], [0 mod 7, 0 mod 7, 1 mod 7]], (2, 3), GF(7)) + >>> C = DomainMatrix.from_list([[(1, 2), (3, 1)], [(1, 4), (5, 1)]], QQ) + >>> C + DomainMatrix([[1/2, 3], [1/4, 5]], (2, 2), QQ) + + See Also + ======== + + from_list_sympy + + """ + nrows = len(rows) + ncols = 0 if not nrows else len(rows[0]) + conv = lambda e: domain(*e) if isinstance(e, tuple) else domain(e) + domain_rows = [[conv(e) for e in row] for row in rows] + return DomainMatrix(domain_rows, (nrows, ncols), domain) + + @classmethod + def from_list_sympy(cls, nrows, ncols, rows, **kwargs): + r""" + Convert a list of lists of Expr into a DomainMatrix using construct_domain + + Parameters + ========== + + nrows: number of rows + ncols: number of columns + rows: list of lists + + Returns + ======= + + DomainMatrix containing elements of rows + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.abc import x, y, z + >>> A = DomainMatrix.from_list_sympy(1, 3, [[x, y, z]]) + >>> A + DomainMatrix([[x, y, z]], (1, 3), ZZ[x,y,z]) + + See Also + ======== + + sympy.polys.constructor.construct_domain, from_dict_sympy + + """ + assert len(rows) == nrows + assert all(len(row) == ncols for row in rows) + + items_sympy = [_sympify(item) for row in rows for item in row] + + domain, items_domain = cls.get_domain(items_sympy, **kwargs) + + domain_rows = [[items_domain[ncols*r + c] for c in range(ncols)] for r in range(nrows)] + + return DomainMatrix(domain_rows, (nrows, ncols), domain) + + @classmethod + def from_dict_sympy(cls, nrows, ncols, elemsdict, **kwargs): + """ + + Parameters + ========== + + nrows: number of rows + ncols: number of cols + elemsdict: dict of dicts containing non-zero elements of the DomainMatrix + + Returns + ======= + + DomainMatrix containing elements of elemsdict + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.abc import x,y,z + >>> elemsdict = {0: {0:x}, 1:{1: y}, 2: {2: z}} + >>> A = DomainMatrix.from_dict_sympy(3, 3, elemsdict) + >>> A + DomainMatrix({0: {0: x}, 1: {1: y}, 2: {2: z}}, (3, 3), ZZ[x,y,z]) + + See Also + ======== + + from_list_sympy + + """ + if not all(0 <= r < nrows for r in elemsdict): + raise DMBadInputError("Row out of range") + if not all(0 <= c < ncols for row in elemsdict.values() for c in row): + raise DMBadInputError("Column out of range") + + items_sympy = [_sympify(item) for row in elemsdict.values() for item in row.values()] + domain, items_domain = cls.get_domain(items_sympy, **kwargs) + + idx = 0 + items_dict = {} + for i, row in elemsdict.items(): + items_dict[i] = {} + for j in row: + items_dict[i][j] = items_domain[idx] + idx += 1 + + return DomainMatrix(items_dict, (nrows, ncols), domain) + + @classmethod + def from_Matrix(cls, M, fmt='sparse',**kwargs): + r""" + Convert Matrix to DomainMatrix + + Parameters + ========== + + M: Matrix + + Returns + ======= + + Returns DomainMatrix with identical elements as M + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.polys.matrices import DomainMatrix + >>> M = Matrix([ + ... [1.0, 3.4], + ... [2.4, 1]]) + >>> A = DomainMatrix.from_Matrix(M) + >>> A + DomainMatrix({0: {0: 1.0, 1: 3.4}, 1: {0: 2.4, 1: 1.0}}, (2, 2), RR) + + We can keep internal representation as ddm using fmt='dense' + >>> from sympy import Matrix, QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix.from_Matrix(Matrix([[QQ(1, 2), QQ(3, 4)], [QQ(0, 1), QQ(0, 1)]]), fmt='dense') + >>> A.rep + [[1/2, 3/4], [0, 0]] + + See Also + ======== + + Matrix + + """ + if fmt == 'dense': + return cls.from_list_sympy(*M.shape, M.tolist(), **kwargs) + + return cls.from_dict_sympy(*M.shape, M.todod(), **kwargs) + + @classmethod + def get_domain(cls, items_sympy, **kwargs): + K, items_K = construct_domain(items_sympy, **kwargs) + return K, items_K + + def choose_domain(self, **opts): + """Convert to a domain found by :func:`~.construct_domain`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> M = DM([[1, 2], [3, 4]], ZZ) + >>> M + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + >>> M.choose_domain(field=True) + DomainMatrix([[1, 2], [3, 4]], (2, 2), QQ) + + >>> from sympy.abc import x + >>> M = DM([[1, x], [x**2, x**3]], ZZ[x]) + >>> M.choose_domain(field=True).domain + ZZ(x) + + Keyword arguments are passed to :func:`~.construct_domain`. + + See Also + ======== + + construct_domain + convert_to + """ + elements, data = self.to_sympy().to_flat_nz() + dom, elements_dom = construct_domain(elements, **opts) + return self.from_flat_nz(elements_dom, data, dom) + + def copy(self): + return self.from_rep(self.rep.copy()) + + def convert_to(self, K): + r""" + Change the domain of DomainMatrix to desired domain or field + + Parameters + ========== + + K : Represents the desired domain or field. + Alternatively, ``None`` may be passed, in which case this method + just returns a copy of this DomainMatrix. + + Returns + ======= + + DomainMatrix + DomainMatrix with the desired domain or field + + Examples + ======== + + >>> from sympy import ZZ, ZZ_I + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.convert_to(ZZ_I) + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ_I) + + """ + if K == self.domain: + return self.copy() + + rep = self.rep + + # The DFM, DDM and SDM types do not do any implicit conversions so we + # manage switching between DDM and DFM here. + if rep.is_DFM and not DFM._supports_domain(K): + rep_K = rep.to_ddm().convert_to(K) + elif rep.is_DDM and DFM._supports_domain(K): + rep_K = rep.convert_to(K).to_dfm() + else: + rep_K = rep.convert_to(K) + + return self.from_rep(rep_K) + + def to_sympy(self): + return self.convert_to(EXRAW) + + def to_field(self): + r""" + Returns a DomainMatrix with the appropriate field + + Returns + ======= + + DomainMatrix + DomainMatrix with the appropriate field + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.to_field() + DomainMatrix([[1, 2], [3, 4]], (2, 2), QQ) + + """ + K = self.domain.get_field() + return self.convert_to(K) + + def to_sparse(self): + """ + Return a sparse DomainMatrix representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> A.rep + [[1, 0], [0, 2]] + >>> B = A.to_sparse() + >>> B.rep + {0: {0: 1}, 1: {1: 2}} + """ + if self.rep.fmt == 'sparse': + return self + + return self.from_rep(self.rep.to_sdm()) + + def to_dense(self): + """ + Return a dense DomainMatrix representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix({0: {0: 1}, 1: {1: 2}}, (2, 2), QQ) + >>> A.rep + {0: {0: 1}, 1: {1: 2}} + >>> B = A.to_dense() + >>> B.rep + [[1, 0], [0, 2]] + + """ + rep = self.rep + + if rep.fmt == 'dense': + return self + + return self.from_rep(rep.to_dfm_or_ddm()) + + def to_ddm(self): + """ + Return a :class:`~.DDM` representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix({0: {0: 1}, 1: {1: 2}}, (2, 2), QQ) + >>> ddm = A.to_ddm() + >>> ddm + [[1, 0], [0, 2]] + >>> type(ddm) + + + See Also + ======== + + to_sdm + to_dense + sympy.polys.matrices.ddm.DDM.to_sdm + """ + return self.rep.to_ddm() + + def to_sdm(self): + """ + Return a :class:`~.SDM` representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> sdm = A.to_sdm() + >>> sdm + {0: {0: 1}, 1: {1: 2}} + >>> type(sdm) + + + See Also + ======== + + to_ddm + to_sparse + sympy.polys.matrices.sdm.SDM.to_ddm + """ + return self.rep.to_sdm() + + @doctest_depends_on(ground_types=['flint']) + def to_dfm(self): + """ + Return a :class:`~.DFM` representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> dfm = A.to_dfm() + >>> dfm + [[1, 0], [0, 2]] + >>> type(dfm) + + + See Also + ======== + + to_ddm + to_dense + DFM + """ + return self.rep.to_dfm() + + @doctest_depends_on(ground_types=['flint']) + def to_dfm_or_ddm(self): + """ + Return a :class:`~.DFM` or :class:`~.DDM` representation of *self*. + + Explanation + =========== + + The :class:`~.DFM` representation can only be used if the ground types + are ``flint`` and the ground domain is supported by ``python-flint``. + This method will return a :class:`~.DFM` representation if possible, + but will return a :class:`~.DDM` representation otherwise. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> dfm = A.to_dfm_or_ddm() + >>> dfm + [[1, 0], [0, 2]] + >>> type(dfm) # Depends on the ground domain and ground types + + + See Also + ======== + + to_ddm: Always return a :class:`~.DDM` representation. + to_dfm: Returns a :class:`~.DFM` representation or raise an error. + to_dense: Convert internally to a :class:`~.DFM` or :class:`~.DDM` + DFM: The :class:`~.DFM` dense FLINT matrix representation. + DDM: The Python :class:`~.DDM` dense domain matrix representation. + """ + return self.rep.to_dfm_or_ddm() + + @classmethod + def _unify_domain(cls, *matrices): + """Convert matrices to a common domain""" + domains = {matrix.domain for matrix in matrices} + if len(domains) == 1: + return matrices + domain = reduce(lambda x, y: x.unify(y), domains) + return tuple(matrix.convert_to(domain) for matrix in matrices) + + @classmethod + def _unify_fmt(cls, *matrices, fmt=None): + """Convert matrices to the same format. + + If all matrices have the same format, then return unmodified. + Otherwise convert both to the preferred format given as *fmt* which + should be 'dense' or 'sparse'. + """ + formats = {matrix.rep.fmt for matrix in matrices} + if len(formats) == 1: + return matrices + if fmt == 'sparse': + return tuple(matrix.to_sparse() for matrix in matrices) + elif fmt == 'dense': + return tuple(matrix.to_dense() for matrix in matrices) + else: + raise ValueError("fmt should be 'sparse' or 'dense'") + + def unify(self, *others, fmt=None): + """ + Unifies the domains and the format of self and other + matrices. + + Parameters + ========== + + others : DomainMatrix + + fmt: string 'dense', 'sparse' or `None` (default) + The preferred format to convert to if self and other are not + already in the same format. If `None` or not specified then no + conversion if performed. + + Returns + ======= + + Tuple[DomainMatrix] + Matrices with unified domain and format + + Examples + ======== + + Unify the domain of DomainMatrix that have different domains: + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + >>> B = DomainMatrix([[QQ(1, 2), QQ(2)]], (1, 2), QQ) + >>> Aq, Bq = A.unify(B) + >>> Aq + DomainMatrix([[1, 2]], (1, 2), QQ) + >>> Bq + DomainMatrix([[1/2, 2]], (1, 2), QQ) + + Unify the format (dense or sparse): + + >>> A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + >>> B = DomainMatrix({0:{0: ZZ(1)}}, (2, 2), ZZ) + >>> B.rep + {0: {0: 1}} + + >>> A2, B2 = A.unify(B, fmt='dense') + >>> B2.rep + [[1, 0], [0, 0]] + + See Also + ======== + + convert_to, to_dense, to_sparse + + """ + matrices = (self,) + others + matrices = DomainMatrix._unify_domain(*matrices) + if fmt is not None: + matrices = DomainMatrix._unify_fmt(*matrices, fmt=fmt) + return matrices + + def to_Matrix(self): + r""" + Convert DomainMatrix to Matrix + + Returns + ======= + + Matrix + MutableDenseMatrix for the DomainMatrix + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.to_Matrix() + Matrix([ + [1, 2], + [3, 4]]) + + See Also + ======== + + from_Matrix + + """ + from sympy.matrices.dense import MutableDenseMatrix + + # XXX: If the internal representation of RepMatrix changes then this + # might need to be changed also. + if self.domain in (ZZ, QQ, EXRAW): + if self.rep.fmt == "sparse": + rep = self.copy() + else: + rep = self.to_sparse() + else: + rep = self.convert_to(EXRAW).to_sparse() + + return MutableDenseMatrix._fromrep(rep) + + def to_list(self): + """ + Convert :class:`DomainMatrix` to list of lists. + + See Also + ======== + + from_list + to_list_flat + to_flat_nz + to_dok + """ + return self.rep.to_list() + + def to_list_flat(self): + """ + Convert :class:`DomainMatrix` to flat list. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> A.to_list_flat() + [1, 2, 3, 4] + + See Also + ======== + + from_list_flat + to_list + to_flat_nz + to_dok + """ + return self.rep.to_list_flat() + + @classmethod + def from_list_flat(cls, elements, shape, domain): + """ + Create :class:`DomainMatrix` from flat list. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> element_list = [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + >>> A = DomainMatrix.from_list_flat(element_list, (2, 2), ZZ) + >>> A + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + >>> A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + to_list_flat + """ + ddm = DDM.from_list_flat(elements, shape, domain) + return cls.from_rep(ddm.to_dfm_or_ddm()) + + def to_flat_nz(self): + """ + Convert :class:`DomainMatrix` to list of nonzero elements and data. + + Explanation + =========== + + Returns a tuple ``(elements, data)`` where ``elements`` is a list of + elements of the matrix with zeros possibly excluded. The matrix can be + reconstructed by passing these to :meth:`from_flat_nz`. The idea is to + be able to modify a flat list of the elements and then create a new + matrix of the same shape with the modified elements in the same + positions. + + The format of ``data`` differs depending on whether the underlying + representation is dense or sparse but either way it represents the + positions of the elements in the list in a way that + :meth:`from_flat_nz` can use to reconstruct the matrix. The + :meth:`from_flat_nz` method should be called on the same + :class:`DomainMatrix` that was used to call :meth:`to_flat_nz`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> elements, data = A.to_flat_nz() + >>> elements + [1, 2, 3, 4] + >>> A == A.from_flat_nz(elements, data, A.domain) + True + + Create a matrix with the elements doubled: + + >>> elements_doubled = [2*x for x in elements] + >>> A2 = A.from_flat_nz(elements_doubled, data, A.domain) + >>> A2 == 2*A + True + + See Also + ======== + + from_flat_nz + """ + return self.rep.to_flat_nz() + + def from_flat_nz(self, elements, data, domain): + """ + Reconstruct :class:`DomainMatrix` after calling :meth:`to_flat_nz`. + + See :meth:`to_flat_nz` for explanation. + + See Also + ======== + + to_flat_nz + """ + rep = self.rep.from_flat_nz(elements, data, domain) + return self.from_rep(rep) + + def to_dod(self): + """ + Convert :class:`DomainMatrix` to dictionary of dictionaries (dod) format. + + Explanation + =========== + + Returns a dictionary of dictionaries representing the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2), ZZ(0)], [ZZ(3), ZZ(0), ZZ(4)]], ZZ) + >>> A.to_dod() + {0: {0: 1, 1: 2}, 1: {0: 3, 2: 4}} + >>> A.to_sparse() == A.from_dod(A.to_dod(), A.shape, A.domain) + True + >>> A == A.from_dod_like(A.to_dod()) + True + + See Also + ======== + + from_dod + from_dod_like + to_dok + to_list + to_list_flat + to_flat_nz + sympy.matrices.matrixbase.MatrixBase.todod + """ + return self.rep.to_dod() + + @classmethod + def from_dod(cls, dod, shape, domain): + """ + Create sparse :class:`DomainMatrix` from dict of dict (dod) format. + + See :meth:`to_dod` for explanation. + + See Also + ======== + + to_dod + from_dod_like + """ + return cls.from_rep(SDM.from_dod(dod, shape, domain)) + + def from_dod_like(self, dod, domain=None): + """ + Create :class:`DomainMatrix` like ``self`` from dict of dict (dod) format. + + See :meth:`to_dod` for explanation. + + See Also + ======== + + to_dod + from_dod + """ + if domain is None: + domain = self.domain + return self.from_rep(self.rep.from_dod(dod, self.shape, domain)) + + def to_dok(self): + """ + Convert :class:`DomainMatrix` to dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(0)], + ... [ZZ(0), ZZ(4)]], (2, 2), ZZ) + >>> A.to_dok() + {(0, 0): 1, (1, 1): 4} + + The matrix can be reconstructed by calling :meth:`from_dok` although + the reconstructed matrix will always be in sparse format: + + >>> A.to_sparse() == A.from_dok(A.to_dok(), A.shape, A.domain) + True + + See Also + ======== + + from_dok + to_list + to_list_flat + to_flat_nz + """ + return self.rep.to_dok() + + @classmethod + def from_dok(cls, dok, shape, domain): + """ + Create :class:`DomainMatrix` from dictionary of keys (dok) format. + + See :meth:`to_dok` for explanation. + + See Also + ======== + + to_dok + """ + return cls.from_rep(SDM.from_dok(dok, shape, domain)) + + def iter_values(self): + """ + Iterate over nonzero elements of the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> list(A.iter_values()) + [1, 3, 4] + + See Also + ======== + + iter_items + to_list_flat + sympy.matrices.matrixbase.MatrixBase.iter_values + """ + return self.rep.iter_values() + + def iter_items(self): + """ + Iterate over indices and values of nonzero elements of the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> list(A.iter_items()) + [((0, 0), 1), ((1, 0), 3), ((1, 1), 4)] + + See Also + ======== + + iter_values + to_dok + sympy.matrices.matrixbase.MatrixBase.iter_items + """ + return self.rep.iter_items() + + def nnz(self): + """ + Number of nonzero elements in the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[1, 0], [0, 4]], ZZ) + >>> A.nnz() + 2 + """ + return self.rep.nnz() + + def __repr__(self): + return 'DomainMatrix(%s, %r, %r)' % (str(self.rep), self.shape, self.domain) + + def transpose(self): + """Matrix transpose of ``self``""" + return self.from_rep(self.rep.transpose()) + + def flat(self): + rows, cols = self.shape + return [self[i,j].element for i in range(rows) for j in range(cols)] + + @property + def is_zero_matrix(self): + return self.rep.is_zero_matrix() + + @property + def is_upper(self): + """ + Says whether this matrix is upper-triangular. True can be returned + even if the matrix is not square. + """ + return self.rep.is_upper() + + @property + def is_lower(self): + """ + Says whether this matrix is lower-triangular. True can be returned + even if the matrix is not square. + """ + return self.rep.is_lower() + + @property + def is_diagonal(self): + """ + True if the matrix is diagonal. + + Can return true for non-square matrices. A matrix is diagonal if + ``M[i,j] == 0`` whenever ``i != j``. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> M = DM([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(1)]], ZZ) + >>> M.is_diagonal + True + + See Also + ======== + + is_upper + is_lower + is_square + diagonal + """ + return self.rep.is_diagonal() + + def diagonal(self): + """ + Get the diagonal entries of the matrix as a list. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> M = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> M.diagonal() + [1, 4] + + See Also + ======== + + is_diagonal + diag + """ + return self.rep.diagonal() + + @property + def is_square(self): + """ + True if the matrix is square. + """ + return self.shape[0] == self.shape[1] + + def rank(self): + rref, pivots = self.rref() + return len(pivots) + + def hstack(A, *B): + r"""Horizontally stack the given matrices. + + Parameters + ========== + + B: DomainMatrix + Matrices to stack horizontally. + + Returns + ======= + + DomainMatrix + DomainMatrix by stacking horizontally. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + + >>> A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.hstack(B) + DomainMatrix([[1, 2, 5, 6], [3, 4, 7, 8]], (2, 4), ZZ) + + >>> C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.hstack(B, C) + DomainMatrix([[1, 2, 5, 6, 9, 10], [3, 4, 7, 8, 11, 12]], (2, 6), ZZ) + + See Also + ======== + + unify + """ + A, *B = A.unify(*B, fmt=A.rep.fmt) + return DomainMatrix.from_rep(A.rep.hstack(*(Bk.rep for Bk in B))) + + def vstack(A, *B): + r"""Vertically stack the given matrices. + + Parameters + ========== + + B: DomainMatrix + Matrices to stack vertically. + + Returns + ======= + + DomainMatrix + DomainMatrix by stacking vertically. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + + >>> A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.vstack(B) + DomainMatrix([[1, 2], [3, 4], [5, 6], [7, 8]], (4, 2), ZZ) + + >>> C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.vstack(B, C) + DomainMatrix([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], (6, 2), ZZ) + + See Also + ======== + + unify + """ + A, *B = A.unify(*B, fmt='dense') + return DomainMatrix.from_rep(A.rep.vstack(*(Bk.rep for Bk in B))) + + def applyfunc(self, func, domain=None): + if domain is None: + domain = self.domain + return self.from_rep(self.rep.applyfunc(func, domain)) + + def __add__(A, B): + if not isinstance(B, DomainMatrix): + return NotImplemented + A, B = A.unify(B, fmt='dense') + return A.add(B) + + def __sub__(A, B): + if not isinstance(B, DomainMatrix): + return NotImplemented + A, B = A.unify(B, fmt='dense') + return A.sub(B) + + def __neg__(A): + return A.neg() + + def __mul__(A, B): + """A * B""" + if isinstance(B, DomainMatrix): + A, B = A.unify(B, fmt='dense') + return A.matmul(B) + elif B in A.domain: + return A.scalarmul(B) + elif isinstance(B, DomainScalar): + A, B = A.unify(B) + return A.scalarmul(B.element) + else: + return NotImplemented + + def __rmul__(A, B): + if B in A.domain: + return A.rscalarmul(B) + elif isinstance(B, DomainScalar): + A, B = A.unify(B) + return A.rscalarmul(B.element) + else: + return NotImplemented + + def __pow__(A, n): + """A ** n""" + if not isinstance(n, int): + return NotImplemented + return A.pow(n) + + def _check(a, op, b, ashape, bshape): + if a.domain != b.domain: + msg = "Domain mismatch: %s %s %s" % (a.domain, op, b.domain) + raise DMDomainError(msg) + if ashape != bshape: + msg = "Shape mismatch: %s %s %s" % (a.shape, op, b.shape) + raise DMShapeError(msg) + if a.rep.fmt != b.rep.fmt: + msg = "Format mismatch: %s %s %s" % (a.rep.fmt, op, b.rep.fmt) + raise DMFormatError(msg) + if type(a.rep) != type(b.rep): + msg = "Type mismatch: %s %s %s" % (type(a.rep), op, type(b.rep)) + raise DMFormatError(msg) + + def add(A, B): + r""" + Adds two DomainMatrix matrices of the same Domain + + Parameters + ========== + + A, B: DomainMatrix + matrices to add + + Returns + ======= + + DomainMatrix + DomainMatrix after Addition + + Raises + ====== + + DMShapeError + If the dimensions of the two DomainMatrix are not equal + + ValueError + If the domain of the two DomainMatrix are not same + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(4), ZZ(3)], + ... [ZZ(2), ZZ(1)]], (2, 2), ZZ) + + >>> A.add(B) + DomainMatrix([[5, 5], [5, 5]], (2, 2), ZZ) + + See Also + ======== + + sub, matmul + + """ + A._check('+', B, A.shape, B.shape) + return A.from_rep(A.rep.add(B.rep)) + + + def sub(A, B): + r""" + Subtracts two DomainMatrix matrices of the same Domain + + Parameters + ========== + + A, B: DomainMatrix + matrices to subtract + + Returns + ======= + + DomainMatrix + DomainMatrix after Subtraction + + Raises + ====== + + DMShapeError + If the dimensions of the two DomainMatrix are not equal + + ValueError + If the domain of the two DomainMatrix are not same + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(4), ZZ(3)], + ... [ZZ(2), ZZ(1)]], (2, 2), ZZ) + + >>> A.sub(B) + DomainMatrix([[-3, -1], [1, 3]], (2, 2), ZZ) + + See Also + ======== + + add, matmul + + """ + A._check('-', B, A.shape, B.shape) + return A.from_rep(A.rep.sub(B.rep)) + + def neg(A): + r""" + Returns the negative of DomainMatrix + + Parameters + ========== + + A : Represents a DomainMatrix + + Returns + ======= + + DomainMatrix + DomainMatrix after Negation + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.neg() + DomainMatrix([[-1, -2], [-3, -4]], (2, 2), ZZ) + + """ + return A.from_rep(A.rep.neg()) + + def mul(A, b): + r""" + Performs term by term multiplication for the second DomainMatrix + w.r.t first DomainMatrix. Returns a DomainMatrix whose rows are + list of DomainMatrix matrices created after term by term multiplication. + + Parameters + ========== + + A, B: DomainMatrix + matrices to multiply term-wise + + Returns + ======= + + DomainMatrix + DomainMatrix after term by term multiplication + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> b = ZZ(2) + + >>> A.mul(b) + DomainMatrix([[2, 4], [6, 8]], (2, 2), ZZ) + + See Also + ======== + + matmul + + """ + return A.from_rep(A.rep.mul(b)) + + def rmul(A, b): + return A.from_rep(A.rep.rmul(b)) + + def matmul(A, B): + r""" + Performs matrix multiplication of two DomainMatrix matrices + + Parameters + ========== + + A, B: DomainMatrix + to multiply + + Returns + ======= + + DomainMatrix + DomainMatrix after multiplication + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(1), ZZ(1)], + ... [ZZ(0), ZZ(1)]], (2, 2), ZZ) + + >>> A.matmul(B) + DomainMatrix([[1, 3], [3, 7]], (2, 2), ZZ) + + See Also + ======== + + mul, pow, add, sub + + """ + + A._check('*', B, A.shape[1], B.shape[0]) + return A.from_rep(A.rep.matmul(B.rep)) + + def _scalarmul(A, lamda, reverse): + if lamda == A.domain.zero: + return DomainMatrix.zeros(A.shape, A.domain) + elif lamda == A.domain.one: + return A.copy() + elif reverse: + return A.rmul(lamda) + else: + return A.mul(lamda) + + def scalarmul(A, lamda): + return A._scalarmul(lamda, reverse=False) + + def rscalarmul(A, lamda): + return A._scalarmul(lamda, reverse=True) + + def mul_elementwise(A, B): + assert A.domain == B.domain + return A.from_rep(A.rep.mul_elementwise(B.rep)) + + def __truediv__(A, lamda): + """ Method for Scalar Division""" + if isinstance(lamda, int) or ZZ.of_type(lamda): + lamda = DomainScalar(ZZ(lamda), ZZ) + elif A.domain.is_Field and lamda in A.domain: + K = A.domain + lamda = DomainScalar(K.convert(lamda), K) + + if not isinstance(lamda, DomainScalar): + return NotImplemented + + A, lamda = A.to_field().unify(lamda) + if lamda.element == lamda.domain.zero: + raise ZeroDivisionError + if lamda.element == lamda.domain.one: + return A + + return A.mul(1 / lamda.element) + + def pow(A, n): + r""" + Computes A**n + + Parameters + ========== + + A : DomainMatrix + + n : exponent for A + + Returns + ======= + + DomainMatrix + DomainMatrix on computing A**n + + Raises + ====== + + NotImplementedError + if n is negative. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(1)], + ... [ZZ(0), ZZ(1)]], (2, 2), ZZ) + + >>> A.pow(2) + DomainMatrix([[1, 2], [0, 1]], (2, 2), ZZ) + + See Also + ======== + + matmul + + """ + nrows, ncols = A.shape + if nrows != ncols: + raise DMNonSquareMatrixError('Power of a nonsquare matrix') + if n < 0: + raise NotImplementedError('Negative powers') + elif n == 0: + return A.eye(nrows, A.domain) + elif n == 1: + return A + elif n % 2 == 1: + return A * A**(n - 1) + else: + sqrtAn = A ** (n // 2) + return sqrtAn * sqrtAn + + def scc(self): + """Compute the strongly connected components of a DomainMatrix + + Explanation + =========== + + A square matrix can be considered as the adjacency matrix for a + directed graph where the row and column indices are the vertices. In + this graph if there is an edge from vertex ``i`` to vertex ``j`` if + ``M[i, j]`` is nonzero. This routine computes the strongly connected + components of that graph which are subsets of the rows and columns that + are connected by some nonzero element of the matrix. The strongly + connected components are useful because many operations such as the + determinant can be computed by working with the submatrices + corresponding to each component. + + Examples + ======== + + Find the strongly connected components of a matrix: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> M = DomainMatrix([[ZZ(1), ZZ(0), ZZ(2)], + ... [ZZ(0), ZZ(3), ZZ(0)], + ... [ZZ(4), ZZ(6), ZZ(5)]], (3, 3), ZZ) + >>> M.scc() + [[1], [0, 2]] + + Compute the determinant from the components: + + >>> MM = M.to_Matrix() + >>> MM + Matrix([ + [1, 0, 2], + [0, 3, 0], + [4, 6, 5]]) + >>> MM[[1], [1]] + Matrix([[3]]) + >>> MM[[0, 2], [0, 2]] + Matrix([ + [1, 2], + [4, 5]]) + >>> MM.det() + -9 + >>> MM[[1], [1]].det() * MM[[0, 2], [0, 2]].det() + -9 + + The components are given in reverse topological order and represent a + permutation of the rows and columns that will bring the matrix into + block lower-triangular form: + + >>> MM[[1, 0, 2], [1, 0, 2]] + Matrix([ + [3, 0, 0], + [0, 1, 2], + [6, 4, 5]]) + + Returns + ======= + + List of lists of integers + Each list represents a strongly connected component. + + See also + ======== + + sympy.matrices.matrixbase.MatrixBase.strongly_connected_components + sympy.utilities.iterables.strongly_connected_components + + """ + if not self.is_square: + raise DMNonSquareMatrixError('Matrix must be square for scc') + + return self.rep.scc() + + def clear_denoms(self, convert=False): + """ + Clear denominators, but keep the domain unchanged. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[(1,2), (1,3)], [(1,4), (1,5)]], QQ) + >>> den, Anum = A.clear_denoms() + >>> den.to_sympy() + 60 + >>> Anum.to_Matrix() + Matrix([ + [30, 20], + [15, 12]]) + >>> den * A == Anum + True + + The numerator matrix will be in the same domain as the original matrix + unless ``convert`` is set to ``True``: + + >>> A.clear_denoms()[1].domain + QQ + >>> A.clear_denoms(convert=True)[1].domain + ZZ + + The denominator is always in the associated ring: + + >>> A.clear_denoms()[0].domain + ZZ + >>> A.domain.get_ring() + ZZ + + See Also + ======== + + sympy.polys.polytools.Poly.clear_denoms + clear_denoms_rowwise + """ + elems0, data = self.to_flat_nz() + + K0 = self.domain + K1 = K0.get_ring() if K0.has_assoc_Ring else K0 + + den, elems1 = dup_clear_denoms(elems0, K0, K1, convert=convert) + + if convert: + Kden, Knum = K1, K1 + else: + Kden, Knum = K1, K0 + + den = DomainScalar(den, Kden) + num = self.from_flat_nz(elems1, data, Knum) + + return den, num + + def clear_denoms_rowwise(self, convert=False): + """ + Clear denominators from each row of the matrix. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[(1,2), (1,3), (1,4)], [(1,5), (1,6), (1,7)]], QQ) + >>> den, Anum = A.clear_denoms_rowwise() + >>> den.to_Matrix() + Matrix([ + [12, 0], + [ 0, 210]]) + >>> Anum.to_Matrix() + Matrix([ + [ 6, 4, 3], + [42, 35, 30]]) + + The denominator matrix is a diagonal matrix with the denominators of + each row on the diagonal. The invariants are: + + >>> den * A == Anum + True + >>> A == den.to_field().inv() * Anum + True + + The numerator matrix will be in the same domain as the original matrix + unless ``convert`` is set to ``True``: + + >>> A.clear_denoms_rowwise()[1].domain + QQ + >>> A.clear_denoms_rowwise(convert=True)[1].domain + ZZ + + The domain of the denominator matrix is the associated ring: + + >>> A.clear_denoms_rowwise()[0].domain + ZZ + + See Also + ======== + + sympy.polys.polytools.Poly.clear_denoms + clear_denoms + """ + dod = self.to_dod() + + K0 = self.domain + K1 = K0.get_ring() if K0.has_assoc_Ring else K0 + + diagonals = [K0.one] * self.shape[0] + dod_num = {} + for i, rowi in dod.items(): + indices, elems = zip(*rowi.items()) + den, elems_num = dup_clear_denoms(elems, K0, K1, convert=convert) + rowi_num = dict(zip(indices, elems_num)) + diagonals[i] = den + dod_num[i] = rowi_num + + if convert: + Kden, Knum = K1, K1 + else: + Kden, Knum = K1, K0 + + den = self.diag(diagonals, Kden) + num = self.from_dod_like(dod_num, Knum) + + return den, num + + def cancel_denom(self, denom): + """ + Cancel factors between a matrix and a denominator. + + Returns a matrix and denominator on lowest terms. + + Requires ``gcd`` in the ground domain. + + Methods like :meth:`solve_den`, :meth:`inv_den` and :meth:`rref_den` + return a matrix and denominator but not necessarily on lowest terms. + Reduction to lowest terms without fractions can be performed with + :meth:`cancel_denom`. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 2, 0], + ... [0, 2, 2], + ... [0, 0, 2]], ZZ) + >>> Minv, den = M.inv_den() + >>> Minv.to_Matrix() + Matrix([ + [1, -1, 1], + [0, 1, -1], + [0, 0, 1]]) + >>> den + 2 + >>> Minv_reduced, den_reduced = Minv.cancel_denom(den) + >>> Minv_reduced.to_Matrix() + Matrix([ + [1, -1, 1], + [0, 1, -1], + [0, 0, 1]]) + >>> den_reduced + 2 + >>> Minv_reduced.to_field() / den_reduced == Minv.to_field() / den + True + + The denominator is made canonical with respect to units (e.g. a + negative denominator is made positive): + + >>> M = DM([[2, 2, 0]], ZZ) + >>> den = ZZ(-4) + >>> M.cancel_denom(den) + (DomainMatrix([[-1, -1, 0]], (1, 3), ZZ), 2) + + Any factor common to _all_ elements will be cancelled but there can + still be factors in common between _some_ elements of the matrix and + the denominator. To cancel factors between each element and the + denominator, use :meth:`cancel_denom_elementwise` or otherwise convert + to a field and use division: + + >>> M = DM([[4, 6]], ZZ) + >>> den = ZZ(12) + >>> M.cancel_denom(den) + (DomainMatrix([[2, 3]], (1, 2), ZZ), 6) + >>> numers, denoms = M.cancel_denom_elementwise(den) + >>> numers + DomainMatrix([[1, 1]], (1, 2), ZZ) + >>> denoms + DomainMatrix([[3, 2]], (1, 2), ZZ) + >>> M.to_field() / den + DomainMatrix([[1/3, 1/2]], (1, 2), QQ) + + See Also + ======== + + solve_den + inv_den + rref_den + cancel_denom_elementwise + """ + M = self + K = self.domain + + if K.is_zero(denom): + raise ZeroDivisionError('denominator is zero') + elif K.is_one(denom): + return (M.copy(), denom) + + elements, data = M.to_flat_nz() + + # First canonicalize the denominator (e.g. multiply by -1). + if K.is_negative(denom): + u = -K.one + else: + u = K.canonical_unit(denom) + + # Often after e.g. solve_den the denominator will be much more + # complicated than the elements of the numerator. Hopefully it will be + # quicker to find the gcd of the numerator and if there is no content + # then we do not need to look at the denominator at all. + content = dup_content(elements, K) + common = K.gcd(content, denom) + + if not K.is_one(content): + + common = K.gcd(content, denom) + + if not K.is_one(common): + elements = dup_quo_ground(elements, common, K) + denom = K.quo(denom, common) + + if not K.is_one(u): + elements = dup_mul_ground(elements, u, K) + denom = u * denom + elif K.is_one(common): + return (M.copy(), denom) + + M_cancelled = M.from_flat_nz(elements, data, K) + + return M_cancelled, denom + + def cancel_denom_elementwise(self, denom): + """ + Cancel factors between the elements of a matrix and a denominator. + + Returns a matrix of numerators and matrix of denominators. + + Requires ``gcd`` in the ground domain. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 3], [4, 12]], ZZ) + >>> denom = ZZ(6) + >>> numers, denoms = M.cancel_denom_elementwise(denom) + >>> numers.to_Matrix() + Matrix([ + [1, 1], + [2, 2]]) + >>> denoms.to_Matrix() + Matrix([ + [3, 2], + [3, 1]]) + >>> M_frac = (M.to_field() / denom).to_Matrix() + >>> M_frac + Matrix([ + [1/3, 1/2], + [2/3, 2]]) + >>> denoms_inverted = denoms.to_Matrix().applyfunc(lambda e: 1/e) + >>> numers.to_Matrix().multiply_elementwise(denoms_inverted) == M_frac + True + + Use :meth:`cancel_denom` to cancel factors between the matrix and the + denominator while preserving the form of a matrix with a scalar + denominator. + + See Also + ======== + + cancel_denom + """ + K = self.domain + M = self + + if K.is_zero(denom): + raise ZeroDivisionError('denominator is zero') + elif K.is_one(denom): + M_numers = M.copy() + M_denoms = M.ones(M.shape, M.domain) + return (M_numers, M_denoms) + + elements, data = M.to_flat_nz() + + cofactors = [K.cofactors(numer, denom) for numer in elements] + gcds, numers, denoms = zip(*cofactors) + + M_numers = M.from_flat_nz(list(numers), data, K) + M_denoms = M.from_flat_nz(list(denoms), data, K) + + return (M_numers, M_denoms) + + def content(self): + """ + Return the gcd of the elements of the matrix. + + Requires ``gcd`` in the ground domain. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 4], [4, 12]], ZZ) + >>> M.content() + 2 + + See Also + ======== + + primitive + cancel_denom + """ + K = self.domain + elements, _ = self.to_flat_nz() + return dup_content(elements, K) + + def primitive(self): + """ + Factor out gcd of the elements of a matrix. + + Requires ``gcd`` in the ground domain. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 4], [4, 12]], ZZ) + >>> content, M_primitive = M.primitive() + >>> content + 2 + >>> M_primitive + DomainMatrix([[1, 2], [2, 6]], (2, 2), ZZ) + >>> content * M_primitive == M + True + >>> M_primitive.content() == ZZ(1) + True + + See Also + ======== + + content + cancel_denom + """ + K = self.domain + elements, data = self.to_flat_nz() + content, prims = dup_primitive(elements, K) + M_primitive = self.from_flat_nz(prims, data, K) + return content, M_primitive + + def rref(self, *, method='auto'): + r""" + Returns reduced-row echelon form (RREF) and list of pivots. + + If the domain is not a field then it will be converted to a field. See + :meth:`rref_den` for the fraction-free version of this routine that + returns RREF with denominator instead. + + The domain must either be a field or have an associated fraction field + (see :meth:`to_field`). + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(2), QQ(-1), QQ(0)], + ... [QQ(-1), QQ(2), QQ(-1)], + ... [QQ(0), QQ(0), QQ(2)]], (3, 3), QQ) + + >>> rref_matrix, rref_pivots = A.rref() + >>> rref_matrix + DomainMatrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], (3, 3), QQ) + >>> rref_pivots + (0, 1, 2) + + Parameters + ========== + + method : str, optional (default: 'auto') + The method to use to compute the RREF. The default is ``'auto'``, + which will attempt to choose the fastest method. The other options + are: + + - ``A.rref(method='GJ')`` uses Gauss-Jordan elimination with + division. If the domain is not a field then it will be converted + to a field with :meth:`to_field` first and RREF will be computed + by inverting the pivot elements in each row. This is most + efficient for very sparse matrices or for matrices whose elements + have complex denominators. + + - ``A.rref(method='FF')`` uses fraction-free Gauss-Jordan + elimination. Elimination is performed using exact division + (``exquo``) to control the growth of the coefficients. In this + case the current domain is always used for elimination but if + the domain is not a field then it will be converted to a field + at the end and divided by the denominator. This is most efficient + for dense matrices or for matrices with simple denominators. + + - ``A.rref(method='CD')`` clears the denominators before using + fraction-free Gauss-Jordan elimination in the associated ring. + This is most efficient for dense matrices with very simple + denominators. + + - ``A.rref(method='GJ_dense')``, ``A.rref(method='FF_dense')``, and + ``A.rref(method='CD_dense')`` are the same as the above methods + except that the dense implementations of the algorithms are used. + By default ``A.rref(method='auto')`` will usually choose the + sparse implementations for RREF. + + Regardless of which algorithm is used the returned matrix will + always have the same format (sparse or dense) as the input and its + domain will always be the field of fractions of the input domain. + + Returns + ======= + + (DomainMatrix, list) + reduced-row echelon form and list of pivots for the DomainMatrix + + See Also + ======== + + rref_den + RREF with denominator + sympy.polys.matrices.sdm.sdm_irref + Sparse implementation of ``method='GJ'``. + sympy.polys.matrices.sdm.sdm_rref_den + Sparse implementation of ``method='FF'`` and ``method='CD'``. + sympy.polys.matrices.dense.ddm_irref + Dense implementation of ``method='GJ'``. + sympy.polys.matrices.dense.ddm_irref_den + Dense implementation of ``method='FF'`` and ``method='CD'``. + clear_denoms + Clear denominators from a matrix, used by ``method='CD'`` and + by ``method='GJ'`` when the original domain is not a field. + + """ + return _dm_rref(self, method=method) + + def rref_den(self, *, method='auto', keep_domain=True): + r""" + Returns reduced-row echelon form with denominator and list of pivots. + + Requires exact division in the ground domain (``exquo``). + + Examples + ======== + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(2), ZZ(-1), ZZ(0)], + ... [ZZ(-1), ZZ(2), ZZ(-1)], + ... [ZZ(0), ZZ(0), ZZ(2)]], (3, 3), ZZ) + + >>> A_rref, denom, pivots = A.rref_den() + >>> A_rref + DomainMatrix([[6, 0, 0], [0, 6, 0], [0, 0, 6]], (3, 3), ZZ) + >>> denom + 6 + >>> pivots + (0, 1, 2) + >>> A_rref.to_field() / denom + DomainMatrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], (3, 3), QQ) + >>> A_rref.to_field() / denom == A.convert_to(QQ).rref()[0] + True + + Parameters + ========== + + method : str, optional (default: 'auto') + The method to use to compute the RREF. The default is ``'auto'``, + which will attempt to choose the fastest method. The other options + are: + + - ``A.rref(method='FF')`` uses fraction-free Gauss-Jordan + elimination. Elimination is performed using exact division + (``exquo``) to control the growth of the coefficients. In this + case the current domain is always used for elimination and the + result is always returned as a matrix over the current domain. + This is most efficient for dense matrices or for matrices with + simple denominators. + + - ``A.rref(method='CD')`` clears denominators before using + fraction-free Gauss-Jordan elimination in the associated ring. + The result will be converted back to the original domain unless + ``keep_domain=False`` is passed in which case the result will be + over the ring used for elimination. This is most efficient for + dense matrices with very simple denominators. + + - ``A.rref(method='GJ')`` uses Gauss-Jordan elimination with + division. If the domain is not a field then it will be converted + to a field with :meth:`to_field` first and RREF will be computed + by inverting the pivot elements in each row. The result is + converted back to the original domain by clearing denominators + unless ``keep_domain=False`` is passed in which case the result + will be over the field used for elimination. This is most + efficient for very sparse matrices or for matrices whose elements + have complex denominators. + + - ``A.rref(method='GJ_dense')``, ``A.rref(method='FF_dense')``, and + ``A.rref(method='CD_dense')`` are the same as the above methods + except that the dense implementations of the algorithms are used. + By default ``A.rref(method='auto')`` will usually choose the + sparse implementations for RREF. + + Regardless of which algorithm is used the returned matrix will + always have the same format (sparse or dense) as the input and if + ``keep_domain=True`` its domain will always be the same as the + input. + + keep_domain : bool, optional + If True (the default), the domain of the returned matrix and + denominator are the same as the domain of the input matrix. If + False, the domain of the returned matrix might be changed to an + associated ring or field if the algorithm used a different domain. + This is useful for efficiency if the caller does not need the + result to be in the original domain e.g. it avoids clearing + denominators in the case of ``A.rref(method='GJ')``. + + Returns + ======= + + (DomainMatrix, scalar, list) + Reduced-row echelon form, denominator and list of pivot indices. + + See Also + ======== + + rref + RREF without denominator for field domains. + sympy.polys.matrices.sdm.sdm_irref + Sparse implementation of ``method='GJ'``. + sympy.polys.matrices.sdm.sdm_rref_den + Sparse implementation of ``method='FF'`` and ``method='CD'``. + sympy.polys.matrices.dense.ddm_irref + Dense implementation of ``method='GJ'``. + sympy.polys.matrices.dense.ddm_irref_den + Dense implementation of ``method='FF'`` and ``method='CD'``. + clear_denoms + Clear denominators from a matrix, used by ``method='CD'``. + + """ + return _dm_rref_den(self, method=method, keep_domain=keep_domain) + + def columnspace(self): + r""" + Returns the columnspace for the DomainMatrix + + Returns + ======= + + DomainMatrix + The columns of this matrix form a basis for the columnspace. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(-1)], + ... [QQ(2), QQ(-2)]], (2, 2), QQ) + >>> A.columnspace() + DomainMatrix([[1], [2]], (2, 1), QQ) + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + rref, pivots = self.rref() + rows, cols = self.shape + return self.extract(range(rows), pivots) + + def rowspace(self): + r""" + Returns the rowspace for the DomainMatrix + + Returns + ======= + + DomainMatrix + The rows of this matrix form a basis for the rowspace. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(-1)], + ... [QQ(2), QQ(-2)]], (2, 2), QQ) + >>> A.rowspace() + DomainMatrix([[1, -1]], (1, 2), QQ) + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + rref, pivots = self.rref() + rows, cols = self.shape + return self.extract(range(len(pivots)), range(cols)) + + def nullspace(self, divide_last=False): + r""" + Returns the nullspace for the DomainMatrix + + Returns + ======= + + DomainMatrix + The rows of this matrix form a basis for the nullspace. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([ + ... [QQ(2), QQ(-2)], + ... [QQ(4), QQ(-4)]], QQ) + >>> A.nullspace() + DomainMatrix([[1, 1]], (1, 2), QQ) + + The returned matrix is a basis for the nullspace: + + >>> A_null = A.nullspace().transpose() + >>> A * A_null + DomainMatrix([[0], [0]], (2, 1), QQ) + >>> rows, cols = A.shape + >>> nullity = rows - A.rank() + >>> A_null.shape == (cols, nullity) + True + + Nullspace can also be computed for non-field rings. If the ring is not + a field then division is not used. Setting ``divide_last`` to True will + raise an error in this case: + + >>> from sympy import ZZ + >>> B = DM([[6, -3], + ... [4, -2]], ZZ) + >>> B.nullspace() + DomainMatrix([[3, 6]], (1, 2), ZZ) + >>> B.nullspace(divide_last=True) + Traceback (most recent call last): + ... + DMNotAField: Cannot normalize vectors over a non-field + + Over a ring with ``gcd`` defined the nullspace can potentially be + reduced with :meth:`primitive`: + + >>> B.nullspace().primitive() + (3, DomainMatrix([[1, 2]], (1, 2), ZZ)) + + A matrix over a ring can often be normalized by converting it to a + field but it is often a bad idea to do so: + + >>> from sympy.abc import a, b, c + >>> from sympy import Matrix + >>> M = Matrix([[ a*b, b + c, c], + ... [ a - b, b*c, c**2], + ... [a*b + a - b, b*c + b + c, c**2 + c]]) + >>> M.to_DM().domain + ZZ[a,b,c] + >>> M.to_DM().nullspace().to_Matrix().transpose() + Matrix([ + [ c**3], + [ -a*b*c**2 + a*c - b*c], + [a*b**2*c - a*b - a*c + b**2 + b*c]]) + + The unnormalized form here is nicer than the normalized form that + spreads a large denominator throughout the matrix: + + >>> M.to_DM().to_field().nullspace(divide_last=True).to_Matrix().transpose() + Matrix([ + [ c**3/(a*b**2*c - a*b - a*c + b**2 + b*c)], + [(-a*b*c**2 + a*c - b*c)/(a*b**2*c - a*b - a*c + b**2 + b*c)], + [ 1]]) + + Parameters + ========== + + divide_last : bool, optional + If False (the default), the vectors are not normalized and the RREF + is computed using :meth:`rref_den` and the denominator is + discarded. If True, then each row is divided by its final element; + the domain must be a field in this case. + + See Also + ======== + + nullspace_from_rref + rref + rref_den + rowspace + """ + A = self + K = A.domain + + if divide_last and not K.is_Field: + raise DMNotAField("Cannot normalize vectors over a non-field") + + if divide_last: + A_rref, pivots = A.rref() + else: + A_rref, den, pivots = A.rref_den() + + # Ensure that the sign is canonical before discarding the + # denominator. Then M.nullspace().primitive() is canonical. + u = K.canonical_unit(den) + if u != K.one: + A_rref *= u + + A_null = A_rref.nullspace_from_rref(pivots) + + return A_null + + def nullspace_from_rref(self, pivots=None): + """ + Compute nullspace from rref and pivots. + + The domain of the matrix can be any domain. + + The matrix must be in reduced row echelon form already. Otherwise the + result will be incorrect. Use :meth:`rref` or :meth:`rref_den` first + to get the reduced row echelon form or use :meth:`nullspace` instead. + + See Also + ======== + + nullspace + rref + rref_den + sympy.polys.matrices.sdm.SDM.nullspace_from_rref + sympy.polys.matrices.ddm.DDM.nullspace_from_rref + """ + null_rep, nonpivots = self.rep.nullspace_from_rref(pivots) + return self.from_rep(null_rep) + + def inv(self): + r""" + Finds the inverse of the DomainMatrix if exists + + Returns + ======= + + DomainMatrix + DomainMatrix after inverse + + Raises + ====== + + ValueError + If the domain of DomainMatrix not a Field + + DMNonSquareMatrixError + If the DomainMatrix is not a not Square DomainMatrix + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(2), QQ(-1), QQ(0)], + ... [QQ(-1), QQ(2), QQ(-1)], + ... [QQ(0), QQ(0), QQ(2)]], (3, 3), QQ) + >>> A.inv() + DomainMatrix([[2/3, 1/3, 1/6], [1/3, 2/3, 1/3], [0, 0, 1/2]], (3, 3), QQ) + + See Also + ======== + + neg + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + m, n = self.shape + if m != n: + raise DMNonSquareMatrixError + inv = self.rep.inv() + return self.from_rep(inv) + + def det(self): + r""" + Returns the determinant of a square :class:`DomainMatrix`. + + Returns + ======= + + determinant: DomainElement + Determinant of the matrix. + + Raises + ====== + + ValueError + If the domain of DomainMatrix is not a Field + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.det() + -2 + + """ + m, n = self.shape + if m != n: + raise DMNonSquareMatrixError + return self.rep.det() + + def adj_det(self): + """ + Adjugate and determinant of a square :class:`DomainMatrix`. + + Returns + ======= + + (adjugate, determinant) : (DomainMatrix, DomainScalar) + The adjugate matrix and determinant of this matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], ZZ) + >>> adjA, detA = A.adj_det() + >>> adjA + DomainMatrix([[4, -2], [-3, 1]], (2, 2), ZZ) + >>> detA + -2 + + See Also + ======== + + adjugate + Returns only the adjugate matrix. + det + Returns only the determinant. + inv_den + Returns a matrix/denominator pair representing the inverse matrix + but perhaps differing from the adjugate and determinant by a common + factor. + """ + m, n = self.shape + I_m = self.eye((m, m), self.domain) + adjA, detA = self.solve_den_charpoly(I_m, check=False) + if self.rep.fmt == "dense": + adjA = adjA.to_dense() + return adjA, detA + + def adjugate(self): + """ + Adjugate of a square :class:`DomainMatrix`. + + The adjugate matrix is the transpose of the cofactor matrix and is + related to the inverse by:: + + adj(A) = det(A) * A.inv() + + Unlike the inverse matrix the adjugate matrix can be computed and + expressed without division or fractions in the ground domain. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> A.adjugate() + DomainMatrix([[4, -2], [-3, 1]], (2, 2), ZZ) + + Returns + ======= + + DomainMatrix + The adjugate matrix of this matrix with the same domain. + + See Also + ======== + + adj_det + """ + adjA, detA = self.adj_det() + return adjA + + def inv_den(self, method=None): + """ + Return the inverse as a :class:`DomainMatrix` with denominator. + + Returns + ======= + + (inv, den) : (:class:`DomainMatrix`, :class:`~.DomainElement`) + The inverse matrix and its denominator. + + This is more or less equivalent to :meth:`adj_det` except that ``inv`` + and ``den`` are not guaranteed to be the adjugate and inverse. The + ratio ``inv/den`` is equivalent to ``adj/det`` but some factors + might be cancelled between ``inv`` and ``den``. In simple cases this + might just be a minus sign so that ``(inv, den) == (-adj, -det)`` but + factors more complicated than ``-1`` can also be cancelled. + Cancellation is not guaranteed to be complete so ``inv`` and ``den`` + may not be on lowest terms. The denominator ``den`` will be zero if and + only if the determinant is zero. + + If the actual adjugate and determinant are needed, use :meth:`adj_det` + instead. If the intention is to compute the inverse matrix or solve a + system of equations then :meth:`inv_den` is more efficient. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(2), ZZ(-1), ZZ(0)], + ... [ZZ(-1), ZZ(2), ZZ(-1)], + ... [ZZ(0), ZZ(0), ZZ(2)]], (3, 3), ZZ) + >>> Ainv, den = A.inv_den() + >>> den + 6 + >>> Ainv + DomainMatrix([[4, 2, 1], [2, 4, 2], [0, 0, 3]], (3, 3), ZZ) + >>> A * Ainv == den * A.eye(A.shape, A.domain).to_dense() + True + + Parameters + ========== + + method : str, optional + The method to use to compute the inverse. Can be one of ``None``, + ``'rref'`` or ``'charpoly'``. If ``None`` then the method is + chosen automatically (see :meth:`solve_den` for details). + + See Also + ======== + + inv + det + adj_det + solve_den + """ + I = self.eye(self.shape, self.domain) + return self.solve_den(I, method=method) + + def solve_den(self, b, method=None): + """ + Solve matrix equation $Ax = b$ without fractions in the ground domain. + + Examples + ======== + + Solve a matrix equation over the integers: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> b = DM([[ZZ(5)], [ZZ(6)]], ZZ) + >>> xnum, xden = A.solve_den(b) + >>> xden + -2 + >>> xnum + DomainMatrix([[8], [-9]], (2, 1), ZZ) + >>> A * xnum == xden * b + True + + Solve a matrix equation over a polynomial ring: + + >>> from sympy import ZZ + >>> from sympy.abc import x, y, z, a, b + >>> R = ZZ[x, y, z, a, b] + >>> M = DM([[x*y, x*z], [y*z, x*z]], R) + >>> b = DM([[a], [b]], R) + >>> M.to_Matrix() + Matrix([ + [x*y, x*z], + [y*z, x*z]]) + >>> b.to_Matrix() + Matrix([ + [a], + [b]]) + >>> xnum, xden = M.solve_den(b) + >>> xden + x**2*y*z - x*y*z**2 + >>> xnum.to_Matrix() + Matrix([ + [ a*x*z - b*x*z], + [-a*y*z + b*x*y]]) + >>> M * xnum == xden * b + True + + The solution can be expressed over a fraction field which will cancel + gcds between the denominator and the elements of the numerator: + + >>> xsol = xnum.to_field() / xden + >>> xsol.to_Matrix() + Matrix([ + [ (a - b)/(x*y - y*z)], + [(-a*z + b*x)/(x**2*z - x*z**2)]]) + >>> (M * xsol).to_Matrix() == b.to_Matrix() + True + + When solving a large system of equations this cancellation step might + be a lot slower than :func:`solve_den` itself. The solution can also be + expressed as a ``Matrix`` without attempting any polynomial + cancellation between the numerator and denominator giving a less + simplified result more quickly: + + >>> xsol_uncancelled = xnum.to_Matrix() / xnum.domain.to_sympy(xden) + >>> xsol_uncancelled + Matrix([ + [ (a*x*z - b*x*z)/(x**2*y*z - x*y*z**2)], + [(-a*y*z + b*x*y)/(x**2*y*z - x*y*z**2)]]) + >>> from sympy import cancel + >>> cancel(xsol_uncancelled) == xsol.to_Matrix() + True + + Parameters + ========== + + self : :class:`DomainMatrix` + The ``m x n`` matrix $A$ in the equation $Ax = b$. Underdetermined + systems are not supported so ``m >= n``: $A$ should be square or + have more rows than columns. + b : :class:`DomainMatrix` + The ``n x m`` matrix $b$ for the rhs. + cp : list of :class:`~.DomainElement`, optional + The characteristic polynomial of the matrix $A$. If not given, it + will be computed using :meth:`charpoly`. + method: str, optional + The method to use for solving the system. Can be one of ``None``, + ``'charpoly'`` or ``'rref'``. If ``None`` (the default) then the + method will be chosen automatically. + + The ``charpoly`` method uses :meth:`solve_den_charpoly` and can + only be used if the matrix is square. This method is division free + and can be used with any domain. + + The ``rref`` method is fraction free but requires exact division + in the ground domain (``exquo``). This is also suitable for most + domains. This method can be used with overdetermined systems (more + equations than unknowns) but not underdetermined systems as a + unique solution is sought. + + Returns + ======= + + (xnum, xden) : (DomainMatrix, DomainElement) + The solution of the equation $Ax = b$ as a pair consisting of an + ``n x m`` matrix numerator ``xnum`` and a scalar denominator + ``xden``. + + The solution $x$ is given by ``x = xnum / xden``. The division free + invariant is ``A * xnum == xden * b``. If $A$ is square then the + denominator ``xden`` will be a divisor of the determinant $det(A)$. + + Raises + ====== + + DMNonInvertibleMatrixError + If the system $Ax = b$ does not have a unique solution. + + See Also + ======== + + solve_den_charpoly + solve_den_rref + inv_den + """ + m, n = self.shape + bm, bn = b.shape + + if m != bm: + raise DMShapeError("Matrix equation shape mismatch.") + + if method is None: + method = 'rref' + elif method == 'charpoly' and m != n: + raise DMNonSquareMatrixError("method='charpoly' requires a square matrix.") + + if method == 'charpoly': + xnum, xden = self.solve_den_charpoly(b) + elif method == 'rref': + xnum, xden = self.solve_den_rref(b) + else: + raise DMBadInputError("method should be 'rref' or 'charpoly'") + + return xnum, xden + + def solve_den_rref(self, b): + """ + Solve matrix equation $Ax = b$ using fraction-free RREF + + Solves the matrix equation $Ax = b$ for $x$ and returns the solution + as a numerator/denominator pair. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> b = DM([[ZZ(5)], [ZZ(6)]], ZZ) + >>> xnum, xden = A.solve_den_rref(b) + >>> xden + -2 + >>> xnum + DomainMatrix([[8], [-9]], (2, 1), ZZ) + >>> A * xnum == xden * b + True + + See Also + ======== + + solve_den + solve_den_charpoly + """ + A = self + m, n = A.shape + bm, bn = b.shape + + if m != bm: + raise DMShapeError("Matrix equation shape mismatch.") + + if m < n: + raise DMShapeError("Underdetermined matrix equation.") + + Aaug = A.hstack(b) + Aaug_rref, denom, pivots = Aaug.rref_den() + + # XXX: We check here if there are pivots after the last column. If + # there were than it possibly means that rref_den performed some + # unnecessary elimination. It would be better if rref methods had a + # parameter indicating how many columns should be used for elimination. + if len(pivots) != n or pivots and pivots[-1] >= n: + raise DMNonInvertibleMatrixError("Non-unique solution.") + + xnum = Aaug_rref[:n, n:] + xden = denom + + return xnum, xden + + def solve_den_charpoly(self, b, cp=None, check=True): + """ + Solve matrix equation $Ax = b$ using the characteristic polynomial. + + This method solves the square matrix equation $Ax = b$ for $x$ using + the characteristic polynomial without any division or fractions in the + ground domain. + + Examples + ======== + + Solve a matrix equation over the integers: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> b = DM([[ZZ(5)], [ZZ(6)]], ZZ) + >>> xnum, detA = A.solve_den_charpoly(b) + >>> detA + -2 + >>> xnum + DomainMatrix([[8], [-9]], (2, 1), ZZ) + >>> A * xnum == detA * b + True + + Parameters + ========== + + self : DomainMatrix + The ``n x n`` matrix `A` in the equation `Ax = b`. Must be square + and invertible. + b : DomainMatrix + The ``n x m`` matrix `b` for the rhs. + cp : list, optional + The characteristic polynomial of the matrix `A` if known. If not + given, it will be computed using :meth:`charpoly`. + check : bool, optional + If ``True`` (the default) check that the determinant is not zero + and raise an error if it is. If ``False`` then if the determinant + is zero the return value will be equal to ``(A.adjugate()*b, 0)``. + + Returns + ======= + + (xnum, detA) : (DomainMatrix, DomainElement) + The solution of the equation `Ax = b` as a matrix numerator and + scalar denominator pair. The denominator is equal to the + determinant of `A` and the numerator is ``adj(A)*b``. + + The solution $x$ is given by ``x = xnum / detA``. The division free + invariant is ``A * xnum == detA * b``. + + If ``b`` is the identity matrix, then ``xnum`` is the adjugate matrix + and we have ``A * adj(A) == detA * I``. + + See Also + ======== + + solve_den + Main frontend for solving matrix equations with denominator. + solve_den_rref + Solve matrix equations using fraction-free RREF. + inv_den + Invert a matrix using the characteristic polynomial. + """ + A, b = self.unify(b) + m, n = self.shape + mb, nb = b.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if mb != m: + raise DMShapeError("Matrix and vector must have the same number of rows") + + f, detA = self.adj_poly_det(cp=cp) + + if check and not detA: + raise DMNonInvertibleMatrixError("Matrix is not invertible") + + # Compute adj(A)*b = det(A)*inv(A)*b using Horner's method without + # constructing inv(A) explicitly. + adjA_b = self.eval_poly_mul(f, b) + + return (adjA_b, detA) + + def adj_poly_det(self, cp=None): + """ + Return the polynomial $p$ such that $p(A) = adj(A)$ and also the + determinant of $A$. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], QQ) + >>> p, detA = A.adj_poly_det() + >>> p + [-1, 5] + >>> p_A = A.eval_poly(p) + >>> p_A + DomainMatrix([[4, -2], [-3, 1]], (2, 2), QQ) + >>> p[0]*A**1 + p[1]*A**0 == p_A + True + >>> p_A == A.adjugate() + True + >>> A * A.adjugate() == detA * A.eye(A.shape, A.domain).to_dense() + True + + See Also + ======== + + adjugate + eval_poly + adj_det + """ + + # Cayley-Hamilton says that a matrix satisfies its own minimal + # polynomial + # + # p[0]*A^n + p[1]*A^(n-1) + ... + p[n]*I = 0 + # + # with p[0]=1 and p[n]=(-1)^n*det(A) or + # + # det(A)*I = -(-1)^n*(p[0]*A^(n-1) + p[1]*A^(n-2) + ... + p[n-1]*A). + # + # Define a new polynomial f with f[i] = -(-1)^n*p[i] for i=0..n-1. Then + # + # det(A)*I = f[0]*A^n + f[1]*A^(n-1) + ... + f[n-1]*A. + # + # Multiplying on the right by inv(A) gives + # + # det(A)*inv(A) = f[0]*A^(n-1) + f[1]*A^(n-2) + ... + f[n-1]. + # + # So adj(A) = det(A)*inv(A) = f(A) + + A = self + m, n = self.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if cp is None: + cp = A.charpoly() + + if len(cp) % 2: + # n is even + detA = cp[-1] + f = [-cpi for cpi in cp[:-1]] + else: + # n is odd + detA = -cp[-1] + f = cp[:-1] + + return f, detA + + def eval_poly(self, p): + """ + Evaluate polynomial function of a matrix $p(A)$. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], QQ) + >>> p = [QQ(1), QQ(2), QQ(3)] + >>> p_A = A.eval_poly(p) + >>> p_A + DomainMatrix([[12, 14], [21, 33]], (2, 2), QQ) + >>> p_A == p[0]*A**2 + p[1]*A + p[2]*A**0 + True + + See Also + ======== + + eval_poly_mul + """ + A = self + m, n = A.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if not p: + return self.zeros(self.shape, self.domain) + elif len(p) == 1: + return p[0] * self.eye(self.shape, self.domain) + + # Evaluate p(A) using Horner's method: + # XXX: Use Paterson-Stockmeyer method? + I = A.eye(A.shape, A.domain) + p_A = p[0] * I + for pi in p[1:]: + p_A = A*p_A + pi*I + + return p_A + + def eval_poly_mul(self, p, B): + r""" + Evaluate polynomial matrix product $p(A) \times B$. + + Evaluate the polynomial matrix product $p(A) \times B$ using Horner's + method without creating the matrix $p(A)$ explicitly. If $B$ is a + column matrix then this method will only use matrix-vector multiplies + and no matrix-matrix multiplies are needed. + + If $B$ is square or wide or if $A$ can be represented in a simpler + domain than $B$ then it might be faster to evaluate $p(A)$ explicitly + (see :func:`eval_poly`) and then multiply with $B$. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], QQ) + >>> b = DM([[QQ(5)], [QQ(6)]], QQ) + >>> p = [QQ(1), QQ(2), QQ(3)] + >>> p_A_b = A.eval_poly_mul(p, b) + >>> p_A_b + DomainMatrix([[144], [303]], (2, 1), QQ) + >>> p_A_b == p[0]*A**2*b + p[1]*A*b + p[2]*b + True + >>> A.eval_poly_mul(p, b) == A.eval_poly(p)*b + True + + See Also + ======== + + eval_poly + solve_den_charpoly + """ + A = self + m, n = A.shape + mb, nb = B.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if mb != n: + raise DMShapeError("Matrices are not aligned") + + if A.domain != B.domain: + raise DMDomainError("Matrices must have the same domain") + + # Given a polynomial p(x) = p[0]*x^n + p[1]*x^(n-1) + ... + p[n-1] + # and matrices A and B we want to find + # + # p(A)*B = p[0]*A^n*B + p[1]*A^(n-1)*B + ... + p[n-1]*B + # + # Factoring out A term by term we get + # + # p(A)*B = A*(...A*(A*(A*(p[0]*B) + p[1]*B) + p[2]*B) + ...) + p[n-1]*B + # + # where each pair of brackets represents one iteration of the loop + # below starting from the innermost p[0]*B. If B is a column matrix + # then products like A*(...) are matrix-vector multiplies and products + # like p[i]*B are scalar-vector multiplies so there are no + # matrix-matrix multiplies. + + if not p: + return B.zeros(B.shape, B.domain, fmt=B.rep.fmt) + + p_A_B = p[0]*B + + for p_i in p[1:]: + p_A_B = A*p_A_B + p_i*B + + return p_A_B + + def lu(self): + r""" + Returns Lower and Upper decomposition of the DomainMatrix + + Returns + ======= + + (L, U, exchange) + L, U are Lower and Upper decomposition of the DomainMatrix, + exchange is the list of indices of rows exchanged in the + decomposition. + + Raises + ====== + + ValueError + If the domain of DomainMatrix not a Field + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(-1)], + ... [QQ(2), QQ(-2)]], (2, 2), QQ) + >>> L, U, exchange = A.lu() + >>> L + DomainMatrix([[1, 0], [2, 1]], (2, 2), QQ) + >>> U + DomainMatrix([[1, -1], [0, 0]], (2, 2), QQ) + >>> exchange + [] + + See Also + ======== + + lu_solve + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + L, U, swaps = self.rep.lu() + return self.from_rep(L), self.from_rep(U), swaps + + def qr(self): + r""" + QR decomposition of the DomainMatrix. + + Explanation + =========== + + The QR decomposition expresses a matrix as the product of an orthogonal + matrix (Q) and an upper triangular matrix (R). In this implementation, + Q is not orthonormal: its columns are orthogonal but not normalized to + unit vectors. This avoids unnecessary divisions and is particularly + suited for exact arithmetic domains. + + Note + ==== + + This implementation is valid only for matrices over real domains. For + matrices over complex domains, a proper QR decomposition would require + handling conjugation to ensure orthogonality. + + Returns + ======= + + (Q, R) + Q is the orthogonal matrix, and R is the upper triangular matrix + resulting from the QR decomposition of the DomainMatrix. + + Raises + ====== + + DMDomainError + If the domain of the DomainMatrix is not a field (e.g., QQ). + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[1, 2], [3, 4], [5, 6]], (3, 2), QQ) + >>> Q, R = A.qr() + >>> Q + DomainMatrix([[1, 26/35], [3, 8/35], [5, -2/7]], (3, 2), QQ) + >>> R + DomainMatrix([[1, 44/35], [0, 1]], (2, 2), QQ) + >>> Q * R == A + True + >>> (Q.transpose() * Q).is_diagonal + True + >>> R.is_upper + True + + See Also + ======== + + lu + + """ + ddm_q, ddm_r = self.rep.qr() + Q = self.from_rep(ddm_q) + R = self.from_rep(ddm_r) + return Q, R + + def lu_solve(self, rhs): + r""" + Solver for DomainMatrix x in the A*x = B + + Parameters + ========== + + rhs : DomainMatrix B + + Returns + ======= + + DomainMatrix + x in A*x = B + + Raises + ====== + + DMShapeError + If the DomainMatrix A and rhs have different number of rows + + ValueError + If the domain of DomainMatrix A not a Field + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(2)], + ... [QQ(3), QQ(4)]], (2, 2), QQ) + >>> B = DomainMatrix([ + ... [QQ(1), QQ(1)], + ... [QQ(0), QQ(1)]], (2, 2), QQ) + + >>> A.lu_solve(B) + DomainMatrix([[-2, -1], [3/2, 1]], (2, 2), QQ) + + See Also + ======== + + lu + + """ + if self.shape[0] != rhs.shape[0]: + raise DMShapeError("Shape") + if not self.domain.is_Field: + raise DMNotAField('Not a field') + sol = self.rep.lu_solve(rhs.rep) + return self.from_rep(sol) + + def fflu(self): + """ + Fraction-free LU decomposition of DomainMatrix. + + Explanation + =========== + + This method computes the PLDU decomposition + using Gauss-Bareiss elimination in a fraction-free manner, + it ensures that all intermediate results remain in + the domain of the input matrix. Unlike standard + LU decomposition, which introduces division, this approach + avoids fractions, making it particularly suitable + for exact arithmetic over integers or polynomials. + + This method satisfies the invariant: + + P * A = L * inv(D) * U + + Returns + ======= + + (P, L, D, U) + - P (Permutation matrix) + - L (Lower triangular matrix) + - D (Diagonal matrix) + - U (Upper triangular matrix) + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + >>> P, L, D, U = A.fflu() + >>> P + DomainMatrix([[1, 0], [0, 1]], (2, 2), ZZ) + >>> L + DomainMatrix([[1, 0], [3, -2]], (2, 2), ZZ) + >>> D + DomainMatrix([[1, 0], [0, -2]], (2, 2), ZZ) + >>> U + DomainMatrix([[1, 2], [0, -2]], (2, 2), ZZ) + >>> L.is_lower and U.is_upper and D.is_diagonal + True + >>> L * D.to_field().inv() * U == P * A.to_field() + True + >>> I, d = D.inv_den() + >>> L * I * U == d * P * A + True + + See Also + ======== + + sympy.polys.matrices.ddm.DDM.fflu + + References + ========== + + .. [1] Nakos, G. C., Turner, P. R., & Williams, R. M. (1997). Fraction-free + algorithms for linear and polynomial equations. ACM SIGSAM Bulletin, + 31(3), 11-19. https://doi.org/10.1145/271130.271133 + .. [2] Middeke, J.; Jeffrey, D.J.; Koutschan, C. (2020), "Common Factors + in Fraction-Free Matrix Decompositions", Mathematics in Computer Science, + 15 (4): 589–608, arXiv:2005.12380, doi:10.1007/s11786-020-00495-9 + .. [3] https://en.wikipedia.org/wiki/Bareiss_algorithm + """ + from_rep = self.from_rep + P, L, D, U = self.rep.fflu() + return from_rep(P), from_rep(L), from_rep(D), from_rep(U) + + def _solve(A, b): + # XXX: Not sure about this method or its signature. It is just created + # because it is needed by the holonomic module. + if A.shape[0] != b.shape[0]: + raise DMShapeError("Shape") + if A.domain != b.domain or not A.domain.is_Field: + raise DMNotAField('Not a field') + Aaug = A.hstack(b) + Arref, pivots = Aaug.rref() + particular = Arref.from_rep(Arref.rep.particular()) + nullspace_rep, nonpivots = Arref[:,:-1].rep.nullspace() + nullspace = Arref.from_rep(nullspace_rep) + return particular, nullspace + + def charpoly(self): + r""" + Characteristic polynomial of a square matrix. + + Computes the characteristic polynomial in a fully expanded form using + division free arithmetic. If a factorization of the characteristic + polynomial is needed then it is more efficient to call + :meth:`charpoly_factor_list` than calling :meth:`charpoly` and then + factorizing the result. + + Returns + ======= + + list: list of DomainElement + coefficients of the characteristic polynomial + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.charpoly() + [1, -5, -2] + + See Also + ======== + + charpoly_factor_list + Compute the factorisation of the characteristic polynomial. + charpoly_factor_blocks + A partial factorisation of the characteristic polynomial that can + be computed more efficiently than either the full factorisation or + the fully expanded polynomial. + """ + M = self + K = M.domain + + factors = M.charpoly_factor_blocks() + + cp = [K.one] + + for f, mult in factors: + for _ in range(mult): + cp = dup_mul(cp, f, K) + + return cp + + def charpoly_factor_list(self): + """ + Full factorization of the characteristic polynomial. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[6, -1, 0, 0], + ... [9, 12, 0, 0], + ... [0, 0, 1, 2], + ... [0, 0, 5, 6]], ZZ) + + Compute the factorization of the characteristic polynomial: + + >>> M.charpoly_factor_list() + [([1, -9], 2), ([1, -7, -4], 1)] + + Use :meth:`charpoly` to get the unfactorized characteristic polynomial: + + >>> M.charpoly() + [1, -25, 203, -495, -324] + + The same calculations with ``Matrix``: + + >>> M.to_Matrix().charpoly().as_expr() + lambda**4 - 25*lambda**3 + 203*lambda**2 - 495*lambda - 324 + >>> M.to_Matrix().charpoly().as_expr().factor() + (lambda - 9)**2*(lambda**2 - 7*lambda - 4) + + Returns + ======= + + list: list of pairs (factor, multiplicity) + A full factorization of the characteristic polynomial. + + See Also + ======== + + charpoly + Expanded form of the characteristic polynomial. + charpoly_factor_blocks + A partial factorisation of the characteristic polynomial that can + be computed more efficiently. + """ + M = self + K = M.domain + + # It is more efficient to start from the partial factorization provided + # for free by M.charpoly_factor_blocks than the expanded M.charpoly. + factors = M.charpoly_factor_blocks() + + factors_irreducible = [] + + for factor_i, mult_i in factors: + + _, factors_list = dup_factor_list(factor_i, K) + + for factor_j, mult_j in factors_list: + factors_irreducible.append((factor_j, mult_i * mult_j)) + + return _collect_factors(factors_irreducible) + + def charpoly_factor_blocks(self): + """ + Partial factorisation of the characteristic polynomial. + + This factorisation arises from a block structure of the matrix (if any) + and so the factors are not guaranteed to be irreducible. The + :meth:`charpoly_factor_blocks` method is the most efficient way to get + a representation of the characteristic polynomial but the result is + neither fully expanded nor fully factored. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[6, -1, 0, 0], + ... [9, 12, 0, 0], + ... [0, 0, 1, 2], + ... [0, 0, 5, 6]], ZZ) + + This computes a partial factorization using only the block structure of + the matrix to reveal factors: + + >>> M.charpoly_factor_blocks() + [([1, -18, 81], 1), ([1, -7, -4], 1)] + + These factors correspond to the two diagonal blocks in the matrix: + + >>> DM([[6, -1], [9, 12]], ZZ).charpoly() + [1, -18, 81] + >>> DM([[1, 2], [5, 6]], ZZ).charpoly() + [1, -7, -4] + + Use :meth:`charpoly_factor_list` to get a complete factorization into + irreducibles: + + >>> M.charpoly_factor_list() + [([1, -9], 2), ([1, -7, -4], 1)] + + Use :meth:`charpoly` to get the expanded characteristic polynomial: + + >>> M.charpoly() + [1, -25, 203, -495, -324] + + Returns + ======= + + list: list of pairs (factor, multiplicity) + A partial factorization of the characteristic polynomial. + + See Also + ======== + + charpoly + Compute the fully expanded characteristic polynomial. + charpoly_factor_list + Compute a full factorization of the characteristic polynomial. + """ + M = self + + if not M.is_square: + raise DMNonSquareMatrixError("not square") + + # scc returns indices that permute the matrix into block triangular + # form and can extract the diagonal blocks. M.charpoly() is equal to + # the product of the diagonal block charpolys. + components = M.scc() + + block_factors = [] + + for indices in components: + block = M.extract(indices, indices) + block_factors.append((block.charpoly_base(), 1)) + + return _collect_factors(block_factors) + + def charpoly_base(self): + """ + Base case for :meth:`charpoly_factor_blocks` after block decomposition. + + This method is used internally by :meth:`charpoly_factor_blocks` as the + base case for computing the characteristic polynomial of a block. It is + more efficient to call :meth:`charpoly_factor_blocks`, :meth:`charpoly` + or :meth:`charpoly_factor_list` rather than call this method directly. + + This will use either the dense or the sparse implementation depending + on the sparsity of the matrix and will clear denominators if possible + before calling :meth:`charpoly_berk` to compute the characteristic + polynomial using the Berkowitz algorithm. + + See Also + ======== + + charpoly + charpoly_factor_list + charpoly_factor_blocks + charpoly_berk + """ + M = self + K = M.domain + + # It seems that the sparse implementation is always faster for random + # matrices with fewer than 50% non-zero entries. This does not seem to + # depend on domain, size, bit count etc. + density = self.nnz() / self.shape[0]**2 + if density < 0.5: + M = M.to_sparse() + else: + M = M.to_dense() + + # Clearing denominators is always more efficient if it can be done. + # Doing it here after block decomposition is good because each block + # might have a smaller denominator. However it might be better for + # charpoly and charpoly_factor_list to restore the denominators only at + # the very end so that they can call e.g. dup_factor_list before + # restoring the denominators. The methods would need to be changed to + # return (poly, denom) pairs to make that work though. + clear_denoms = K.is_Field and K.has_assoc_Ring + + if clear_denoms: + clear_denoms = True + d, M = M.clear_denoms(convert=True) + d = d.element + K_f = K + K_r = M.domain + + # Berkowitz algorithm over K_r. + cp = M.charpoly_berk() + + if clear_denoms: + # Restore the denominator in the charpoly over K_f. + # + # If M = N/d then p_M(x) = p_N(x*d)/d^n. + cp = dup_convert(cp, K_r, K_f) + p = [K_f.one, K_f.zero] + q = [K_f.one/d] + cp = dup_transform(cp, p, q, K_f) + + return cp + + def charpoly_berk(self): + """Compute the characteristic polynomial using the Berkowitz algorithm. + + This method directly calls the underlying implementation of the + Berkowitz algorithm (:meth:`sympy.polys.matrices.dense.ddm_berk` or + :meth:`sympy.polys.matrices.sdm.sdm_berk`). + + This is used by :meth:`charpoly` and other methods as the base case for + for computing the characteristic polynomial. However those methods will + apply other optimizations such as block decomposition, clearing + denominators and converting between dense and sparse representations + before calling this method. It is more efficient to call those methods + instead of this one but this method is provided for direct access to + the Berkowitz algorithm. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import QQ + >>> M = DM([[6, -1, 0, 0], + ... [9, 12, 0, 0], + ... [0, 0, 1, 2], + ... [0, 0, 5, 6]], QQ) + >>> M.charpoly_berk() + [1, -25, 203, -495, -324] + + See Also + ======== + + charpoly + charpoly_base + charpoly_factor_list + charpoly_factor_blocks + sympy.polys.matrices.dense.ddm_berk + sympy.polys.matrices.sdm.sdm_berk + """ + return self.rep.charpoly() + + @classmethod + def eye(cls, shape, domain): + r""" + Return identity matrix of size n or shape (m, n). + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> DomainMatrix.eye(3, QQ) + DomainMatrix({0: {0: 1}, 1: {1: 1}, 2: {2: 1}}, (3, 3), QQ) + + """ + if isinstance(shape, int): + shape = (shape, shape) + return cls.from_rep(SDM.eye(shape, domain)) + + @classmethod + def diag(cls, diagonal, domain, shape=None): + r""" + Return diagonal matrix with entries from ``diagonal``. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import ZZ + >>> DomainMatrix.diag([ZZ(5), ZZ(6)], ZZ) + DomainMatrix({0: {0: 5}, 1: {1: 6}}, (2, 2), ZZ) + + """ + if shape is None: + N = len(diagonal) + shape = (N, N) + return cls.from_rep(SDM.diag(diagonal, domain, shape)) + + @classmethod + def zeros(cls, shape, domain, *, fmt='sparse'): + """Returns a zero DomainMatrix of size shape, belonging to the specified domain + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> DomainMatrix.zeros((2, 3), QQ) + DomainMatrix({}, (2, 3), QQ) + + """ + return cls.from_rep(SDM.zeros(shape, domain)) + + @classmethod + def ones(cls, shape, domain): + """Returns a DomainMatrix of 1s, of size shape, belonging to the specified domain + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> DomainMatrix.ones((2,3), QQ) + DomainMatrix([[1, 1, 1], [1, 1, 1]], (2, 3), QQ) + + """ + return cls.from_rep(DDM.ones(shape, domain).to_dfm_or_ddm()) + + def __eq__(A, B): + r""" + Checks for two DomainMatrix matrices to be equal or not + + Parameters + ========== + + A, B: DomainMatrix + to check equality + + Returns + ======= + + Boolean + True for equal, else False + + Raises + ====== + + NotImplementedError + If B is not a DomainMatrix + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(1), ZZ(1)], + ... [ZZ(0), ZZ(1)]], (2, 2), ZZ) + >>> A.__eq__(A) + True + >>> A.__eq__(B) + False + + """ + if not isinstance(A, type(B)): + return NotImplemented + return A.domain == B.domain and A.rep == B.rep + + def unify_eq(A, B): + if A.shape != B.shape: + return False + if A.domain != B.domain: + A, B = A.unify(B) + return A == B + + def lll(A, delta=QQ(3, 4)): + """ + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm. + See [1]_ and [2]_. + + Parameters + ========== + + delta : QQ, optional + The Lovász parameter. Must be in the interval (0.25, 1), with larger + values producing a more reduced basis. The default is 0.75 for + historical reasons. + + Returns + ======= + + The reduced basis as a DomainMatrix over ZZ. + + Throws + ====== + + DMValueError: if delta is not in the range (0.25, 1) + DMShapeError: if the matrix is not of shape (m, n) with m <= n + DMDomainError: if the matrix domain is not ZZ + DMRankError: if the matrix contains linearly dependent rows + + Examples + ======== + + >>> from sympy.polys.domains import ZZ, QQ + >>> from sympy.polys.matrices import DM + >>> x = DM([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]], ZZ) + >>> y = DM([[10, -3, -2, 8, -4], + ... [3, -9, 8, 1, -11], + ... [-3, 13, -9, -3, -9], + ... [-12, -7, -11, 9, -1]], ZZ) + >>> assert x.lll(delta=QQ(5, 6)) == y + + Notes + ===== + + The implementation is derived from the Maple code given in Figures 4.3 + and 4.4 of [3]_ (pp.68-69). It uses the efficient method of only calculating + state updates as they are required. + + See also + ======== + + lll_transform + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lenstra%E2%80%93Lenstra%E2%80%93Lov%C3%A1sz_lattice_basis_reduction_algorithm + .. [2] https://web.archive.org/web/20221029115428/https://web.cs.elte.hu/~lovasz/scans/lll.pdf + .. [3] Murray R. Bremner, "Lattice Basis Reduction: An Introduction to the LLL Algorithm and Its Applications" + + """ + return DomainMatrix.from_rep(A.rep.lll(delta=delta)) + + def lll_transform(A, delta=QQ(3, 4)): + """ + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm + and returns the reduced basis and transformation matrix. + + Explanation + =========== + + Parameters, algorithm and basis are the same as for :meth:`lll` except that + the return value is a tuple `(B, T)` with `B` the reduced basis and + `T` a transformation matrix. The original basis `A` is transformed to + `B` with `T*A == B`. If only `B` is needed then :meth:`lll` should be + used as it is a little faster. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ, QQ + >>> from sympy.polys.matrices import DM + >>> X = DM([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]], ZZ) + >>> B, T = X.lll_transform(delta=QQ(5, 6)) + >>> T * X == B + True + + See also + ======== + + lll + + """ + reduced, transform = A.rep.lll_transform(delta=delta) + return DomainMatrix.from_rep(reduced), DomainMatrix.from_rep(transform) + + +def _collect_factors(factors_list): + """ + Collect repeating factors and sort. + + >>> from sympy.polys.matrices.domainmatrix import _collect_factors + >>> _collect_factors([([1, 2], 2), ([1, 4], 3), ([1, 2], 5)]) + [([1, 4], 3), ([1, 2], 7)] + """ + factors = Counter() + for factor, exponent in factors_list: + factors[tuple(factor)] += exponent + + factors_list = [(list(f), e) for f, e in factors.items()] + + return _sort_factors(factors_list) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/domainscalar.py b/phivenv/Lib/site-packages/sympy/polys/matrices/domainscalar.py new file mode 100644 index 0000000000000000000000000000000000000000..df439a60a0ea0df5f6fac988c06da2a06a4fbac2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/domainscalar.py @@ -0,0 +1,122 @@ +""" + +Module for the DomainScalar class. + +A DomainScalar represents an element which is in a particular +Domain. The idea is that the DomainScalar class provides the +convenience routines for unifying elements with different domains. + +It assists in Scalar Multiplication and getitem for DomainMatrix. + +""" +from ..constructor import construct_domain + +from sympy.polys.domains import Domain, ZZ + + +class DomainScalar: + r""" + docstring + """ + + def __new__(cls, element, domain): + if not isinstance(domain, Domain): + raise TypeError("domain should be of type Domain") + if not domain.of_type(element): + raise TypeError("element %s should be in domain %s" % (element, domain)) + return cls.new(element, domain) + + @classmethod + def new(cls, element, domain): + obj = super().__new__(cls) + obj.element = element + obj.domain = domain + return obj + + def __repr__(self): + return repr(self.element) + + @classmethod + def from_sympy(cls, expr): + [domain, [element]] = construct_domain([expr]) + return cls.new(element, domain) + + def to_sympy(self): + return self.domain.to_sympy(self.element) + + def to_domain(self, domain): + element = domain.convert_from(self.element, self.domain) + return self.new(element, domain) + + def convert_to(self, domain): + return self.to_domain(domain) + + def unify(self, other): + domain = self.domain.unify(other.domain) + return self.to_domain(domain), other.to_domain(domain) + + def __bool__(self): + return bool(self.element) + + def __add__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.element + other.element, self.domain) + + def __sub__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.element - other.element, self.domain) + + def __mul__(self, other): + if not isinstance(other, DomainScalar): + if isinstance(other, int): + other = DomainScalar(ZZ(other), ZZ) + else: + return NotImplemented + + self, other = self.unify(other) + return self.new(self.element * other.element, self.domain) + + def __floordiv__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.domain.quo(self.element, other.element), self.domain) + + def __mod__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.domain.rem(self.element, other.element), self.domain) + + def __divmod__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + q, r = self.domain.div(self.element, other.element) + return (self.new(q, self.domain), self.new(r, self.domain)) + + def __pow__(self, n): + if not isinstance(n, int): + return NotImplemented + return self.new(self.element**n, self.domain) + + def __pos__(self): + return self.new(+self.element, self.domain) + + def __neg__(self): + return self.new(-self.element, self.domain) + + def __eq__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + return self.element == other.element and self.domain == other.domain + + def is_zero(self): + return self.element == self.domain.zero + + def is_one(self): + return self.element == self.domain.one diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/eigen.py b/phivenv/Lib/site-packages/sympy/polys/matrices/eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..17d673c6ea09002e1cfd5357f301c447a7af4341 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/eigen.py @@ -0,0 +1,90 @@ +""" + +Routines for computing eigenvectors with DomainMatrix. + +""" +from sympy.core.symbol import Dummy + +from ..agca.extensions import FiniteExtension +from ..factortools import dup_factor_list +from ..polyroots import roots +from ..polytools import Poly +from ..rootoftools import CRootOf + +from .domainmatrix import DomainMatrix + + +def dom_eigenvects(A, l=Dummy('lambda')): + charpoly = A.charpoly() + rows, cols = A.shape + domain = A.domain + _, factors = dup_factor_list(charpoly, domain) + + rational_eigenvects = [] + algebraic_eigenvects = [] + for base, exp in factors: + if len(base) == 2: + field = domain + eigenval = -base[1] / base[0] + + EE_items = [ + [eigenval if i == j else field.zero for j in range(cols)] + for i in range(rows)] + EE = DomainMatrix(EE_items, (rows, cols), field) + + basis = (A - EE).nullspace(divide_last=True) + rational_eigenvects.append((field, eigenval, exp, basis)) + else: + minpoly = Poly.from_list(base, l, domain=domain) + field = FiniteExtension(minpoly) + eigenval = field(l) + + AA_items = [ + [Poly.from_list([item], l, domain=domain).rep for item in row] + for row in A.rep.to_ddm()] + AA_items = [[field(item) for item in row] for row in AA_items] + AA = DomainMatrix(AA_items, (rows, cols), field) + EE_items = [ + [eigenval if i == j else field.zero for j in range(cols)] + for i in range(rows)] + EE = DomainMatrix(EE_items, (rows, cols), field) + + basis = (AA - EE).nullspace(divide_last=True) + algebraic_eigenvects.append((field, minpoly, exp, basis)) + + return rational_eigenvects, algebraic_eigenvects + + +def dom_eigenvects_to_sympy( + rational_eigenvects, algebraic_eigenvects, + Matrix, **kwargs +): + result = [] + + for field, eigenvalue, multiplicity, eigenvects in rational_eigenvects: + eigenvects = eigenvects.rep.to_ddm() + eigenvalue = field.to_sympy(eigenvalue) + new_eigenvects = [ + Matrix([field.to_sympy(x) for x in vect]) + for vect in eigenvects] + result.append((eigenvalue, multiplicity, new_eigenvects)) + + for field, minpoly, multiplicity, eigenvects in algebraic_eigenvects: + eigenvects = eigenvects.rep.to_ddm() + l = minpoly.gens[0] + + eigenvects = [[field.to_sympy(x) for x in vect] for vect in eigenvects] + + degree = minpoly.degree() + minpoly = minpoly.as_expr() + eigenvals = roots(minpoly, l, **kwargs) + if len(eigenvals) != degree: + eigenvals = [CRootOf(minpoly, l, idx) for idx in range(degree)] + + for eigenvalue in eigenvals: + new_eigenvects = [ + Matrix([x.subs(l, eigenvalue) for x in vect]) + for vect in eigenvects] + result.append((eigenvalue, multiplicity, new_eigenvects)) + + return result diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/exceptions.py b/phivenv/Lib/site-packages/sympy/polys/matrices/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..b1e5a4195c66aceed2d5ac1994381d3dec6a64ba --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/exceptions.py @@ -0,0 +1,67 @@ +""" + +Module to define exceptions to be used in sympy.polys.matrices modules and +classes. + +Ideally all exceptions raised in these modules would be defined and documented +here and not e.g. imported from matrices. Also ideally generic exceptions like +ValueError/TypeError would not be raised anywhere. + +""" + + +class DMError(Exception): + """Base class for errors raised by DomainMatrix""" + pass + + +class DMBadInputError(DMError): + """list of lists is inconsistent with shape""" + pass + + +class DMDomainError(DMError): + """domains do not match""" + pass + + +class DMNotAField(DMDomainError): + """domain is not a field""" + pass + + +class DMFormatError(DMError): + """mixed dense/sparse not supported""" + pass + + +class DMNonInvertibleMatrixError(DMError): + """The matrix in not invertible""" + pass + + +class DMRankError(DMError): + """matrix does not have expected rank""" + pass + + +class DMShapeError(DMError): + """shapes are inconsistent""" + pass + + +class DMNonSquareMatrixError(DMShapeError): + """The matrix is not square""" + pass + + +class DMValueError(DMError): + """The value passed is invalid""" + pass + + +__all__ = [ + 'DMError', 'DMBadInputError', 'DMDomainError', 'DMFormatError', + 'DMRankError', 'DMShapeError', 'DMNotAField', + 'DMNonInvertibleMatrixError', 'DMNonSquareMatrixError', 'DMValueError' +] diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/linsolve.py b/phivenv/Lib/site-packages/sympy/polys/matrices/linsolve.py new file mode 100644 index 0000000000000000000000000000000000000000..af74058d859b744cf8fe1059ddb7c775fece79c7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/linsolve.py @@ -0,0 +1,230 @@ +# +# sympy.polys.matrices.linsolve module +# +# This module defines the _linsolve function which is the internal workhorse +# used by linsolve. This computes the solution of a system of linear equations +# using the SDM sparse matrix implementation in sympy.polys.matrices.sdm. This +# is a replacement for solve_lin_sys in sympy.polys.solvers which is +# inefficient for large sparse systems due to the use of a PolyRing with many +# generators: +# +# https://github.com/sympy/sympy/issues/20857 +# +# The implementation of _linsolve here handles: +# +# - Extracting the coefficients from the Expr/Eq input equations. +# - Constructing a domain and converting the coefficients to +# that domain. +# - Using the SDM.rref, SDM.nullspace etc methods to generate the full +# solution working with arithmetic only in the domain of the coefficients. +# +# The routines here are particularly designed to be efficient for large sparse +# systems of linear equations although as well as dense systems. It is +# possible that for some small dense systems solve_lin_sys which uses the +# dense matrix implementation DDM will be more efficient. With smaller systems +# though the bulk of the time is spent just preprocessing the inputs and the +# relative time spent in rref is too small to be noticeable. +# + +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.singleton import S + +from sympy.polys.constructor import construct_domain +from sympy.polys.solvers import PolyNonlinearError + +from .sdm import ( + SDM, + sdm_irref, + sdm_particular_from_rref, + sdm_nullspace_from_rref +) + +from sympy.utilities.misc import filldedent + + +def _linsolve(eqs, syms): + + """Solve a linear system of equations. + + Examples + ======== + + Solve a linear system with a unique solution: + + >>> from sympy import symbols, Eq + >>> from sympy.polys.matrices.linsolve import _linsolve + >>> x, y = symbols('x, y') + >>> eqs = [Eq(x + y, 1), Eq(x - y, 2)] + >>> _linsolve(eqs, [x, y]) + {x: 3/2, y: -1/2} + + In the case of underdetermined systems the solution will be expressed in + terms of the unknown symbols that are unconstrained: + + >>> _linsolve([Eq(x + y, 0)], [x, y]) + {x: -y, y: y} + + """ + # Number of unknowns (columns in the non-augmented matrix) + nsyms = len(syms) + + # Convert to sparse augmented matrix (len(eqs) x (nsyms+1)) + eqsdict, const = _linear_eq_to_dict(eqs, syms) + Aaug = sympy_dict_to_dm(eqsdict, const, syms) + K = Aaug.domain + + # sdm_irref has issues with float matrices. This uses the ddm_rref() + # function. When sdm_rref() can handle float matrices reasonably this + # should be removed... + if K.is_RealField or K.is_ComplexField: + Aaug = Aaug.to_ddm().rref()[0].to_sdm() + + # Compute reduced-row echelon form (RREF) + Arref, pivots, nzcols = sdm_irref(Aaug) + + # No solution: + if pivots and pivots[-1] == nsyms: + return None + + # Particular solution for non-homogeneous system: + P = sdm_particular_from_rref(Arref, nsyms+1, pivots) + + # Nullspace - general solution to homogeneous system + # Note: using nsyms not nsyms+1 to ignore last column + V, nonpivots = sdm_nullspace_from_rref(Arref, K.one, nsyms, pivots, nzcols) + + # Collect together terms from particular and nullspace: + sol = defaultdict(list) + for i, v in P.items(): + sol[syms[i]].append(K.to_sympy(v)) + for npi, Vi in zip(nonpivots, V): + sym = syms[npi] + for i, v in Vi.items(): + sol[syms[i]].append(sym * K.to_sympy(v)) + + # Use a single call to Add for each term: + sol = {s: Add(*terms) for s, terms in sol.items()} + + # Fill in the zeros: + zero = S.Zero + for s in set(syms) - set(sol): + sol[s] = zero + + # All done! + return sol + + +def sympy_dict_to_dm(eqs_coeffs, eqs_rhs, syms): + """Convert a system of dict equations to a sparse augmented matrix""" + elems = set(eqs_rhs).union(*(e.values() for e in eqs_coeffs)) + K, elems_K = construct_domain(elems, field=True, extension=True) + elem_map = dict(zip(elems, elems_K)) + neqs = len(eqs_coeffs) + nsyms = len(syms) + sym2index = dict(zip(syms, range(nsyms))) + eqsdict = [] + for eq, rhs in zip(eqs_coeffs, eqs_rhs): + eqdict = {sym2index[s]: elem_map[c] for s, c in eq.items()} + if rhs: + eqdict[nsyms] = -elem_map[rhs] + if eqdict: + eqsdict.append(eqdict) + sdm_aug = SDM(enumerate(eqsdict), (neqs, nsyms + 1), K) + return sdm_aug + + +def _linear_eq_to_dict(eqs, syms): + """Convert a system Expr/Eq equations into dict form, returning + the coefficient dictionaries and a list of syms-independent terms + from each expression in ``eqs```. + + Examples + ======== + + >>> from sympy.polys.matrices.linsolve import _linear_eq_to_dict + >>> from sympy.abc import x + >>> _linear_eq_to_dict([2*x + 3], {x}) + ([{x: 2}], [3]) + """ + coeffs = [] + ind = [] + symset = set(syms) + for e in eqs: + if e.is_Equality: + coeff, terms = _lin_eq2dict(e.lhs, symset) + cR, tR = _lin_eq2dict(e.rhs, symset) + # there were no nonlinear errors so now + # cancellation is allowed + coeff -= cR + for k, v in tR.items(): + if k in terms: + terms[k] -= v + else: + terms[k] = -v + # don't store coefficients of 0, however + terms = {k: v for k, v in terms.items() if v} + c, d = coeff, terms + else: + c, d = _lin_eq2dict(e, symset) + coeffs.append(d) + ind.append(c) + return coeffs, ind + + +def _lin_eq2dict(a, symset): + """return (c, d) where c is the sym-independent part of ``a`` and + ``d`` is an efficiently calculated dictionary mapping symbols to + their coefficients. A PolyNonlinearError is raised if non-linearity + is detected. + + The values in the dictionary will be non-zero. + + Examples + ======== + + >>> from sympy.polys.matrices.linsolve import _lin_eq2dict + >>> from sympy.abc import x, y + >>> _lin_eq2dict(x + 2*y + 3, {x, y}) + (3, {x: 1, y: 2}) + """ + if a in symset: + return S.Zero, {a: S.One} + elif a.is_Add: + terms_list = defaultdict(list) + coeff_list = [] + for ai in a.args: + ci, ti = _lin_eq2dict(ai, symset) + coeff_list.append(ci) + for mij, cij in ti.items(): + terms_list[mij].append(cij) + coeff = Add(*coeff_list) + terms = {sym: Add(*coeffs) for sym, coeffs in terms_list.items()} + return coeff, terms + elif a.is_Mul: + terms = terms_coeff = None + coeff_list = [] + for ai in a.args: + ci, ti = _lin_eq2dict(ai, symset) + if not ti: + coeff_list.append(ci) + elif terms is None: + terms = ti + terms_coeff = ci + else: + # since ti is not null and we already have + # a term, this is a cross term + raise PolyNonlinearError(filldedent(''' + nonlinear cross-term: %s''' % a)) + coeff = Mul._from_args(coeff_list) + if terms is None: + return coeff, {} + else: + terms = {sym: coeff * c for sym, c in terms.items()} + return coeff * terms_coeff, terms + elif not a.has_xfree(symset): + return a, {} + else: + raise PolyNonlinearError('nonlinear term: %s' % a) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/lll.py b/phivenv/Lib/site-packages/sympy/polys/matrices/lll.py new file mode 100644 index 0000000000000000000000000000000000000000..f33f91d92c5e20f89f302991e494a6a5b9fa4b2e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/lll.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from math import floor as mfloor + +from sympy.polys.domains import ZZ, QQ +from sympy.polys.matrices.exceptions import DMRankError, DMShapeError, DMValueError, DMDomainError + + +def _ddm_lll(x, delta=QQ(3, 4), return_transform=False): + if QQ(1, 4) >= delta or delta >= QQ(1, 1): + raise DMValueError("delta must lie in range (0.25, 1)") + if x.shape[0] > x.shape[1]: + raise DMShapeError("input matrix must have shape (m, n) with m <= n") + if x.domain != ZZ: + raise DMDomainError("input matrix domain must be ZZ") + m = x.shape[0] + n = x.shape[1] + k = 1 + y = x.copy() + y_star = x.zeros((m, n), QQ) + mu = x.zeros((m, m), QQ) + g_star = [QQ(0, 1) for _ in range(m)] + half = QQ(1, 2) + T = x.eye(m, ZZ) if return_transform else None + linear_dependent_error = "input matrix contains linearly dependent rows" + + def closest_integer(x): + return ZZ(mfloor(x + half)) + + def lovasz_condition(k: int) -> bool: + return g_star[k] >= ((delta - mu[k][k - 1] ** 2) * g_star[k - 1]) + + def mu_small(k: int, j: int) -> bool: + return abs(mu[k][j]) <= half + + def dot_rows(x, y, rows: tuple[int, int]): + return sum(x[rows[0]][z] * y[rows[1]][z] for z in range(x.shape[1])) + + def reduce_row(T, mu, y, rows: tuple[int, int]): + r = closest_integer(mu[rows[0]][rows[1]]) + y[rows[0]] = [y[rows[0]][z] - r * y[rows[1]][z] for z in range(n)] + mu[rows[0]][:rows[1]] = [mu[rows[0]][z] - r * mu[rows[1]][z] for z in range(rows[1])] + mu[rows[0]][rows[1]] -= r + if return_transform: + T[rows[0]] = [T[rows[0]][z] - r * T[rows[1]][z] for z in range(m)] + + for i in range(m): + y_star[i] = [QQ.convert_from(z, ZZ) for z in y[i]] + for j in range(i): + row_dot = dot_rows(y, y_star, (i, j)) + try: + mu[i][j] = row_dot / g_star[j] + except ZeroDivisionError: + raise DMRankError(linear_dependent_error) + y_star[i] = [y_star[i][z] - mu[i][j] * y_star[j][z] for z in range(n)] + g_star[i] = dot_rows(y_star, y_star, (i, i)) + while k < m: + if not mu_small(k, k - 1): + reduce_row(T, mu, y, (k, k - 1)) + if lovasz_condition(k): + for l in range(k - 2, -1, -1): + if not mu_small(k, l): + reduce_row(T, mu, y, (k, l)) + k += 1 + else: + nu = mu[k][k - 1] + alpha = g_star[k] + nu ** 2 * g_star[k - 1] + try: + beta = g_star[k - 1] / alpha + except ZeroDivisionError: + raise DMRankError(linear_dependent_error) + mu[k][k - 1] = nu * beta + g_star[k] = g_star[k] * beta + g_star[k - 1] = alpha + y[k], y[k - 1] = y[k - 1], y[k] + mu[k][:k - 1], mu[k - 1][:k - 1] = mu[k - 1][:k - 1], mu[k][:k - 1] + for i in range(k + 1, m): + xi = mu[i][k] + mu[i][k] = mu[i][k - 1] - nu * xi + mu[i][k - 1] = mu[k][k - 1] * mu[i][k] + xi + if return_transform: + T[k], T[k - 1] = T[k - 1], T[k] + k = max(k - 1, 1) + assert all(lovasz_condition(i) for i in range(1, m)) + assert all(mu_small(i, j) for i in range(m) for j in range(i)) + return y, T + + +def ddm_lll(x, delta=QQ(3, 4)): + return _ddm_lll(x, delta=delta, return_transform=False)[0] + + +def ddm_lll_transform(x, delta=QQ(3, 4)): + return _ddm_lll(x, delta=delta, return_transform=True) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/normalforms.py b/phivenv/Lib/site-packages/sympy/polys/matrices/normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..506a68b6946acbeb235eed7650246104da265b78 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/normalforms.py @@ -0,0 +1,540 @@ +'''Functions returning normal forms of matrices''' + +from collections import defaultdict + +from .domainmatrix import DomainMatrix +from .exceptions import DMDomainError, DMShapeError +from sympy.ntheory.modular import symmetric_residue +from sympy.polys.domains import QQ, ZZ + + +# TODO (future work): +# There are faster algorithms for Smith and Hermite normal forms, which +# we should implement. See e.g. the Kannan-Bachem algorithm: +# + + +def smith_normal_form(m): + ''' + Return the Smith Normal Form of a matrix `m` over the ring `domain`. + This will only work if the ring is a principal ideal domain. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.normalforms import smith_normal_form + >>> m = DomainMatrix([[ZZ(12), ZZ(6), ZZ(4)], + ... [ZZ(3), ZZ(9), ZZ(6)], + ... [ZZ(2), ZZ(16), ZZ(14)]], (3, 3), ZZ) + >>> print(smith_normal_form(m).to_Matrix()) + Matrix([[1, 0, 0], [0, 10, 0], [0, 0, 30]]) + + ''' + invs = invariant_factors(m) + smf = DomainMatrix.diag(invs, m.domain, m.shape) + return smf + + +def is_smith_normal_form(m): + ''' + Checks that the matrix is in Smith Normal Form + ''' + domain = m.domain + shape = m.shape + zero = domain.zero + m = m.to_list() + + for i in range(shape[0]): + for j in range(shape[1]): + if i == j: + continue + if not m[i][j] == zero: + return False + + upper = min(shape[0], shape[1]) + for i in range(1, upper): + if m[i-1][i-1] == zero: + if m[i][i] != zero: + return False + else: + r = domain.div(m[i][i], m[i-1][i-1])[1] + if r != zero: + return False + + return True + + +def add_columns(m, i, j, a, b, c, d): + # replace m[:, i] by a*m[:, i] + b*m[:, j] + # and m[:, j] by c*m[:, i] + d*m[:, j] + for k in range(len(m)): + e = m[k][i] + m[k][i] = a*e + b*m[k][j] + m[k][j] = c*e + d*m[k][j] + + +def invariant_factors(m): + ''' + Return the tuple of abelian invariants for a matrix `m` + (as in the Smith-Normal form) + + References + ========== + + [1] https://en.wikipedia.org/wiki/Smith_normal_form#Algorithm + [2] https://web.archive.org/web/20200331143852/https://sierra.nmsu.edu/morandi/notes/SmithNormalForm.pdf + + ''' + domain = m.domain + shape = m.shape + m = m.to_list() + return _smith_normal_decomp(m, domain, shape=shape, full=False) + + +def smith_normal_decomp(m): + ''' + Return the Smith-Normal form decomposition of matrix `m`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.normalforms import smith_normal_decomp + >>> m = DomainMatrix([[ZZ(12), ZZ(6), ZZ(4)], + ... [ZZ(3), ZZ(9), ZZ(6)], + ... [ZZ(2), ZZ(16), ZZ(14)]], (3, 3), ZZ) + >>> a, s, t = smith_normal_decomp(m) + >>> assert a == s * m * t + ''' + domain = m.domain + rows, cols = shape = m.shape + m = m.to_list() + + invs, s, t = _smith_normal_decomp(m, domain, shape=shape, full=True) + smf = DomainMatrix.diag(invs, domain, shape).to_dense() + + s = DomainMatrix(s, domain=domain, shape=(rows, rows)) + t = DomainMatrix(t, domain=domain, shape=(cols, cols)) + return smf, s, t + + +def _smith_normal_decomp(m, domain, shape, full): + ''' + Return the tuple of abelian invariants for a matrix `m` + (as in the Smith-Normal form). If `full=True` then invertible matrices + ``s, t`` such that the product ``s, m, t`` is the Smith Normal Form + are also returned. + ''' + if not domain.is_PID: + msg = f"The matrix entries must be over a principal ideal domain, but got {domain}" + raise ValueError(msg) + + rows, cols = shape + zero = domain.zero + one = domain.one + + def eye(n): + return [[one if i == j else zero for i in range(n)] for j in range(n)] + + if 0 in shape: + if full: + return (), eye(rows), eye(cols) + else: + return () + + if full: + s = eye(rows) + t = eye(cols) + + def add_rows(m, i, j, a, b, c, d): + # replace m[i, :] by a*m[i, :] + b*m[j, :] + # and m[j, :] by c*m[i, :] + d*m[j, :] + for k in range(len(m[0])): + e = m[i][k] + m[i][k] = a*e + b*m[j][k] + m[j][k] = c*e + d*m[j][k] + + def clear_column(): + # make m[1:, 0] zero by row and column operations + pivot = m[0][0] + for j in range(1, rows): + if m[j][0] == zero: + continue + d, r = domain.div(m[j][0], pivot) + if r == zero: + add_rows(m, 0, j, 1, 0, -d, 1) + if full: + add_rows(s, 0, j, 1, 0, -d, 1) + else: + a, b, g = domain.gcdex(pivot, m[j][0]) + d_0 = domain.exquo(m[j][0], g) + d_j = domain.exquo(pivot, g) + add_rows(m, 0, j, a, b, d_0, -d_j) + if full: + add_rows(s, 0, j, a, b, d_0, -d_j) + pivot = g + + def clear_row(): + # make m[0, 1:] zero by row and column operations + pivot = m[0][0] + for j in range(1, cols): + if m[0][j] == zero: + continue + d, r = domain.div(m[0][j], pivot) + if r == zero: + add_columns(m, 0, j, 1, 0, -d, 1) + if full: + add_columns(t, 0, j, 1, 0, -d, 1) + else: + a, b, g = domain.gcdex(pivot, m[0][j]) + d_0 = domain.exquo(m[0][j], g) + d_j = domain.exquo(pivot, g) + add_columns(m, 0, j, a, b, d_0, -d_j) + if full: + add_columns(t, 0, j, a, b, d_0, -d_j) + pivot = g + + # permute the rows and columns until m[0,0] is non-zero if possible + ind = [i for i in range(rows) if m[i][0] != zero] + if ind and ind[0] != zero: + m[0], m[ind[0]] = m[ind[0]], m[0] + if full: + s[0], s[ind[0]] = s[ind[0]], s[0] + else: + ind = [j for j in range(cols) if m[0][j] != zero] + if ind and ind[0] != zero: + for row in m: + row[0], row[ind[0]] = row[ind[0]], row[0] + if full: + for row in t: + row[0], row[ind[0]] = row[ind[0]], row[0] + + # make the first row and column except m[0,0] zero + while (any(m[0][i] != zero for i in range(1,cols)) or + any(m[i][0] != zero for i in range(1,rows))): + clear_column() + clear_row() + + def to_domain_matrix(m): + return DomainMatrix(m, shape=(len(m), len(m[0])), domain=domain) + + if m[0][0] != 0: + c = domain.canonical_unit(m[0][0]) + if domain.is_Field: + c = 1 / m[0][0] + if c != domain.one: + m[0][0] *= c + if full: + s[0] = [elem * c for elem in s[0]] + + if 1 in shape: + invs = () + else: + lower_right = [r[1:] for r in m[1:]] + ret = _smith_normal_decomp(lower_right, domain, + shape=(rows - 1, cols - 1), full=full) + if full: + invs, s_small, t_small = ret + s2 = [[1] + [0]*(rows-1)] + [[0] + row for row in s_small] + t2 = [[1] + [0]*(cols-1)] + [[0] + row for row in t_small] + s, s2, t, t2 = list(map(to_domain_matrix, [s, s2, t, t2])) + s = s2 * s + t = t * t2 + s = s.to_list() + t = t.to_list() + else: + invs = ret + + if m[0][0]: + result = [m[0][0]] + result.extend(invs) + # in case m[0] doesn't divide the invariants of the rest of the matrix + for i in range(len(result)-1): + a, b = result[i], result[i+1] + if b and domain.div(b, a)[1] != zero: + if full: + x, y, d = domain.gcdex(a, b) + else: + d = domain.gcd(a, b) + + alpha = domain.div(a, d)[0] + if full: + beta = domain.div(b, d)[0] + add_rows(s, i, i + 1, 1, 0, x, 1) + add_columns(t, i, i + 1, 1, y, 0, 1) + add_rows(s, i, i + 1, 1, -alpha, 0, 1) + add_columns(t, i, i + 1, 1, 0, -beta, 1) + add_rows(s, i, i + 1, 0, 1, -1, 0) + + result[i+1] = b * alpha + result[i] = d + else: + break + else: + if full: + if rows > 1: + s = s[1:] + [s[0]] + if cols > 1: + t = [row[1:] + [row[0]] for row in t] + result = invs + (m[0][0],) + + if full: + return tuple(result), s, t + else: + return tuple(result) + + +def _gcdex(a, b): + r""" + This supports the functions that compute Hermite Normal Form. + + Explanation + =========== + + Let x, y be the coefficients returned by the extended Euclidean + Algorithm, so that x*a + y*b = g. In the algorithms for computing HNF, + it is critical that x, y not only satisfy the condition of being small + in magnitude -- namely that |x| <= |b|/g, |y| <- |a|/g -- but also that + y == 0 when a | b. + + """ + x, y, g = ZZ.gcdex(a, b) + if a != 0 and b % a == 0: + y = 0 + x = -1 if a < 0 else 1 + return x, y, g + + +def _hermite_normal_form(A): + r""" + Compute the Hermite Normal Form of DomainMatrix *A* over :ref:`ZZ`. + + Parameters + ========== + + A : :py:class:`~.DomainMatrix` over domain :ref:`ZZ`. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 2.4.5.) + + """ + if not A.domain.is_ZZ: + raise DMDomainError('Matrix must be over domain ZZ.') + # We work one row at a time, starting from the bottom row, and working our + # way up. + m, n = A.shape + A = A.to_ddm().copy() + # Our goal is to put pivot entries in the rightmost columns. + # Invariant: Before processing each row, k should be the index of the + # leftmost column in which we have so far put a pivot. + k = n + for i in range(m - 1, -1, -1): + if k == 0: + # This case can arise when n < m and we've already found n pivots. + # We don't need to consider any more rows, because this is already + # the maximum possible number of pivots. + break + k -= 1 + # k now points to the column in which we want to put a pivot. + # We want zeros in all entries to the left of the pivot column. + for j in range(k - 1, -1, -1): + if A[i][j] != 0: + # Replace cols j, k by lin combs of these cols such that, in row i, + # col j has 0, while col k has the gcd of their row i entries. Note + # that this ensures a nonzero entry in col k. + u, v, d = _gcdex(A[i][k], A[i][j]) + r, s = A[i][k] // d, A[i][j] // d + add_columns(A, k, j, u, v, -s, r) + b = A[i][k] + # Do not want the pivot entry to be negative. + if b < 0: + add_columns(A, k, k, -1, 0, -1, 0) + b = -b + # The pivot entry will be 0 iff the row was 0 from the pivot col all the + # way to the left. In this case, we are still working on the same pivot + # col for the next row. Therefore: + if b == 0: + k += 1 + # If the pivot entry is nonzero, then we want to reduce all entries to its + # right in the sense of the division algorithm, i.e. make them all remainders + # w.r.t. the pivot as divisor. + else: + for j in range(k + 1, n): + q = A[i][j] // b + add_columns(A, j, k, 1, -q, 0, 1) + # Finally, the HNF consists of those columns of A in which we succeeded in making + # a nonzero pivot. + return DomainMatrix.from_rep(A.to_dfm_or_ddm())[:, k:] + + +def _hermite_normal_form_modulo_D(A, D): + r""" + Perform the mod *D* Hermite Normal Form reduction algorithm on + :py:class:`~.DomainMatrix` *A*. + + Explanation + =========== + + If *A* is an $m \times n$ matrix of rank $m$, having Hermite Normal Form + $W$, and if *D* is any positive integer known in advance to be a multiple + of $\det(W)$, then the HNF of *A* can be computed by an algorithm that + works mod *D* in order to prevent coefficient explosion. + + Parameters + ========== + + A : :py:class:`~.DomainMatrix` over :ref:`ZZ` + $m \times n$ matrix, having rank $m$. + D : :ref:`ZZ` + Positive integer, known to be a multiple of the determinant of the + HNF of *A*. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`, or + if *D* is given but is not in :ref:`ZZ`. + + DMShapeError + If the matrix has more rows than columns. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 2.4.8.) + + """ + if not A.domain.is_ZZ: + raise DMDomainError('Matrix must be over domain ZZ.') + if not ZZ.of_type(D) or D < 1: + raise DMDomainError('Modulus D must be positive element of domain ZZ.') + + def add_columns_mod_R(m, R, i, j, a, b, c, d): + # replace m[:, i] by (a*m[:, i] + b*m[:, j]) % R + # and m[:, j] by (c*m[:, i] + d*m[:, j]) % R + for k in range(len(m)): + e = m[k][i] + m[k][i] = symmetric_residue((a * e + b * m[k][j]) % R, R) + m[k][j] = symmetric_residue((c * e + d * m[k][j]) % R, R) + + W = defaultdict(dict) + + m, n = A.shape + if n < m: + raise DMShapeError('Matrix must have at least as many columns as rows.') + A = A.to_list() + k = n + R = D + for i in range(m - 1, -1, -1): + k -= 1 + for j in range(k - 1, -1, -1): + if A[i][j] != 0: + u, v, d = _gcdex(A[i][k], A[i][j]) + r, s = A[i][k] // d, A[i][j] // d + add_columns_mod_R(A, R, k, j, u, v, -s, r) + b = A[i][k] + if b == 0: + A[i][k] = b = R + u, v, d = _gcdex(b, R) + for ii in range(m): + W[ii][i] = u*A[ii][k] % R + if W[i][i] == 0: + W[i][i] = R + for j in range(i + 1, m): + q = W[i][j] // W[i][i] + add_columns(W, j, i, 1, -q, 0, 1) + R //= d + return DomainMatrix(W, (m, m), ZZ).to_dense() + + +def hermite_normal_form(A, *, D=None, check_rank=False): + r""" + Compute the Hermite Normal Form of :py:class:`~.DomainMatrix` *A* over + :ref:`ZZ`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.normalforms import hermite_normal_form + >>> m = DomainMatrix([[ZZ(12), ZZ(6), ZZ(4)], + ... [ZZ(3), ZZ(9), ZZ(6)], + ... [ZZ(2), ZZ(16), ZZ(14)]], (3, 3), ZZ) + >>> print(hermite_normal_form(m).to_Matrix()) + Matrix([[10, 0, 2], [0, 15, 3], [0, 0, 2]]) + + Parameters + ========== + + A : $m \times n$ ``DomainMatrix`` over :ref:`ZZ`. + + D : :ref:`ZZ`, optional + Let $W$ be the HNF of *A*. If known in advance, a positive integer *D* + being any multiple of $\det(W)$ may be provided. In this case, if *A* + also has rank $m$, then we may use an alternative algorithm that works + mod *D* in order to prevent coefficient explosion. + + check_rank : boolean, optional (default=False) + The basic assumption is that, if you pass a value for *D*, then + you already believe that *A* has rank $m$, so we do not waste time + checking it for you. If you do want this to be checked (and the + ordinary, non-modulo *D* algorithm to be used if the check fails), then + set *check_rank* to ``True``. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`, or + if *D* is given but is not in :ref:`ZZ`. + + DMShapeError + If the mod *D* algorithm is used but the matrix has more rows than + columns. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithms 2.4.5 and 2.4.8.) + + """ + if not A.domain.is_ZZ: + raise DMDomainError('Matrix must be over domain ZZ.') + if D is not None and (not check_rank or A.convert_to(QQ).rank() == A.shape[0]): + return _hermite_normal_form_modulo_D(A, D) + else: + return _hermite_normal_form(A) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/rref.py b/phivenv/Lib/site-packages/sympy/polys/matrices/rref.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a71b04971e8dc8ecac5cc2691f98ba68e35d45 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/rref.py @@ -0,0 +1,422 @@ +# Algorithms for computing the reduced row echelon form of a matrix. +# +# We need to choose carefully which algorithms to use depending on the domain, +# shape, and sparsity of the matrix as well as things like the bit count in the +# case of ZZ or QQ. This is important because the algorithms have different +# performance characteristics in the extremes of dense vs sparse. +# +# In all cases we use the sparse implementations but we need to choose between +# Gauss-Jordan elimination with division and fraction-free Gauss-Jordan +# elimination. For very sparse matrices over ZZ with low bit counts it is +# asymptotically faster to use Gauss-Jordan elimination with division. For +# dense matrices with high bit counts it is asymptotically faster to use +# fraction-free Gauss-Jordan. +# +# The most important thing is to get the extreme cases right because it can +# make a big difference. In between the extremes though we have to make a +# choice and here we use empirically determined thresholds based on timings +# with random sparse matrices. +# +# In the case of QQ we have to consider the denominators as well. If the +# denominators are small then it is faster to clear them and use fraction-free +# Gauss-Jordan over ZZ. If the denominators are large then it is faster to use +# Gauss-Jordan elimination with division over QQ. +# +# Timings for the various algorithms can be found at +# +# https://github.com/sympy/sympy/issues/25410 +# https://github.com/sympy/sympy/pull/25443 + +from sympy.polys.domains import ZZ + +from sympy.polys.matrices.sdm import SDM, sdm_irref, sdm_rref_den +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.dense import ddm_irref, ddm_irref_den + + +def _dm_rref(M, *, method='auto'): + """ + Compute the reduced row echelon form of a ``DomainMatrix``. + + This function is the implementation of :meth:`DomainMatrix.rref`. + + Chooses the best algorithm depending on the domain, shape, and sparsity of + the matrix as well as things like the bit count in the case of :ref:`ZZ` or + :ref:`QQ`. The result is returned over the field associated with the domain + of the Matrix. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + The ``DomainMatrix`` method that calls this function. + sympy.polys.matrices.rref._dm_rref_den + Alternative function for computing RREF with denominator. + """ + method, use_fmt = _dm_rref_choose_method(M, method, denominator=False) + + M, old_fmt = _dm_to_fmt(M, use_fmt) + + if method == 'GJ': + # Use Gauss-Jordan with division over the associated field. + Mf = _to_field(M) + M_rref, pivots = _dm_rref_GJ(Mf) + + elif method == 'FF': + # Use fraction-free GJ over the current domain. + M_rref_f, den, pivots = _dm_rref_den_FF(M) + M_rref = _to_field(M_rref_f) / den + + elif method == 'CD': + # Clear denominators and use fraction-free GJ in the associated ring. + _, Mr = M.clear_denoms_rowwise(convert=True) + M_rref_f, den, pivots = _dm_rref_den_FF(Mr) + M_rref = _to_field(M_rref_f) / den + + else: + raise ValueError(f"Unknown method for rref: {method}") + + M_rref, _ = _dm_to_fmt(M_rref, old_fmt) + + # Invariants: + # - M_rref is in the same format (sparse or dense) as the input matrix. + # - M_rref is in the associated field domain and any denominator was + # divided in (so is implicitly 1 now). + + return M_rref, pivots + + +def _dm_rref_den(M, *, keep_domain=True, method='auto'): + """ + Compute the reduced row echelon form of a ``DomainMatrix`` with denominator. + + This function is the implementation of :meth:`DomainMatrix.rref_den`. + + Chooses the best algorithm depending on the domain, shape, and sparsity of + the matrix as well as things like the bit count in the case of :ref:`ZZ` or + :ref:`QQ`. The result is returned over the same domain as the input matrix + unless ``keep_domain=False`` in which case the result might be over an + associated ring or field domain. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + The ``DomainMatrix`` method that calls this function. + sympy.polys.matrices.rref._dm_rref + Alternative function for computing RREF without denominator. + """ + method, use_fmt = _dm_rref_choose_method(M, method, denominator=True) + + M, old_fmt = _dm_to_fmt(M, use_fmt) + + if method == 'FF': + # Use fraction-free GJ over the current domain. + M_rref, den, pivots = _dm_rref_den_FF(M) + + elif method == 'GJ': + # Use Gauss-Jordan with division over the associated field. + M_rref_f, pivots = _dm_rref_GJ(_to_field(M)) + + # Convert back to the ring? + if keep_domain and M_rref_f.domain != M.domain: + _, M_rref = M_rref_f.clear_denoms(convert=True) + + if pivots: + den = M_rref[0, pivots[0]].element + else: + den = M_rref.domain.one + else: + # Possibly an associated field + M_rref = M_rref_f + den = M_rref.domain.one + + elif method == 'CD': + # Clear denominators and use fraction-free GJ in the associated ring. + _, Mr = M.clear_denoms_rowwise(convert=True) + + M_rref_r, den, pivots = _dm_rref_den_FF(Mr) + + if keep_domain and M_rref_r.domain != M.domain: + # Convert back to the field + M_rref = _to_field(M_rref_r) / den + den = M.domain.one + else: + # Possibly an associated ring + M_rref = M_rref_r + + if pivots: + den = M_rref[0, pivots[0]].element + else: + den = M_rref.domain.one + else: + raise ValueError(f"Unknown method for rref: {method}") + + M_rref, _ = _dm_to_fmt(M_rref, old_fmt) + + # Invariants: + # - M_rref is in the same format (sparse or dense) as the input matrix. + # - If keep_domain=True then M_rref and den are in the same domain as the + # input matrix + # - If keep_domain=False then M_rref might be in an associated ring or + # field domain but den is always in the same domain as M_rref. + + return M_rref, den, pivots + + +def _dm_to_fmt(M, fmt): + """Convert a matrix to the given format and return the old format.""" + old_fmt = M.rep.fmt + if old_fmt == fmt: + pass + elif fmt == 'dense': + M = M.to_dense() + elif fmt == 'sparse': + M = M.to_sparse() + else: + raise ValueError(f'Unknown format: {fmt}') # pragma: no cover + return M, old_fmt + + +# These are the four basic implementations that we want to choose between: + + +def _dm_rref_GJ(M): + """Compute RREF using Gauss-Jordan elimination with division.""" + if M.rep.fmt == 'sparse': + return _dm_rref_GJ_sparse(M) + else: + return _dm_rref_GJ_dense(M) + + +def _dm_rref_den_FF(M): + """Compute RREF using fraction-free Gauss-Jordan elimination.""" + if M.rep.fmt == 'sparse': + return _dm_rref_den_FF_sparse(M) + else: + return _dm_rref_den_FF_dense(M) + + +def _dm_rref_GJ_sparse(M): + """Compute RREF using sparse Gauss-Jordan elimination with division.""" + M_rref_d, pivots, _ = sdm_irref(M.rep) + M_rref_sdm = SDM(M_rref_d, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_sdm), pivots + + +def _dm_rref_GJ_dense(M): + """Compute RREF using dense Gauss-Jordan elimination with division.""" + partial_pivot = M.domain.is_RR or M.domain.is_CC + ddm = M.rep.to_ddm().copy() + pivots = ddm_irref(ddm, _partial_pivot=partial_pivot) + M_rref_ddm = DDM(ddm, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_ddm.to_dfm_or_ddm()), pivots + + +def _dm_rref_den_FF_sparse(M): + """Compute RREF using sparse fraction-free Gauss-Jordan elimination.""" + M_rref_d, den, pivots = sdm_rref_den(M.rep, M.domain) + M_rref_sdm = SDM(M_rref_d, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_sdm), den, pivots + + +def _dm_rref_den_FF_dense(M): + """Compute RREF using sparse fraction-free Gauss-Jordan elimination.""" + ddm = M.rep.to_ddm().copy() + den, pivots = ddm_irref_den(ddm, M.domain) + M_rref_ddm = DDM(ddm, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_ddm.to_dfm_or_ddm()), den, pivots + + +def _dm_rref_choose_method(M, method, *, denominator=False): + """Choose the fastest method for computing RREF for M.""" + + if method != 'auto': + if method.endswith('_dense'): + method = method[:-len('_dense')] + use_fmt = 'dense' + else: + use_fmt = 'sparse' + + else: + # The sparse implementations are always faster + use_fmt = 'sparse' + + K = M.domain + + if K.is_ZZ: + method = _dm_rref_choose_method_ZZ(M, denominator=denominator) + elif K.is_QQ: + method = _dm_rref_choose_method_QQ(M, denominator=denominator) + elif K.is_RR or K.is_CC: + # TODO: Add partial pivot support to the sparse implementations. + method = 'GJ' + use_fmt = 'dense' + elif K.is_EX and M.rep.fmt == 'dense' and not denominator: + # Do not switch to the sparse implementation for EX because the + # domain does not have proper canonicalization and the sparse + # implementation gives equivalent but non-identical results over EX + # from performing arithmetic in a different order. Specifically + # test_issue_23718 ends up getting a more complicated expression + # when using the sparse implementation. Probably the best fix for + # this is something else but for now we stick with the dense + # implementation for EX if the matrix is already dense. + method = 'GJ' + use_fmt = 'dense' + else: + # This is definitely suboptimal. More work is needed to determine + # the best method for computing RREF over different domains. + if denominator: + method = 'FF' + else: + method = 'GJ' + + return method, use_fmt + + +def _dm_rref_choose_method_QQ(M, *, denominator=False): + """Choose the fastest method for computing RREF over QQ.""" + # The same sorts of considerations apply here as in the case of ZZ. Here + # though a new more significant consideration is what sort of denominators + # we have and what to do with them so we focus on that. + + # First compute the density. This is the average number of non-zero entries + # per row but only counting rows that have at least one non-zero entry + # since RREF can ignore fully zero rows. + density, _, ncols = _dm_row_density(M) + + # For sparse matrices use Gauss-Jordan elimination over QQ regardless. + if density < min(5, ncols/2): + return 'GJ' + + # Compare the bit-length of the lcm of the denominators to the bit length + # of the numerators. + # + # The threshold here is empirical: we prefer rref over QQ if clearing + # denominators would result in a numerator matrix having 5x the bit size of + # the current numerators. + numers, denoms = _dm_QQ_numers_denoms(M) + numer_bits = max([n.bit_length() for n in numers], default=1) + + denom_lcm = ZZ.one + for d in denoms: + denom_lcm = ZZ.lcm(denom_lcm, d) + if denom_lcm.bit_length() > 5*numer_bits: + return 'GJ' + + # If we get here then the matrix is dense and the lcm of the denominators + # is not too large compared to the numerators. For particularly small + # denominators it is fastest just to clear them and use fraction-free + # Gauss-Jordan over ZZ. With very small denominators this is a little + # faster than using rref_den over QQ but there is an intermediate regime + # where rref_den over QQ is significantly faster. The small denominator + # case is probably very common because small fractions like 1/2 or 1/3 are + # often seen in user inputs. + + if denom_lcm.bit_length() < 50: + return 'CD' + else: + return 'FF' + + +def _dm_rref_choose_method_ZZ(M, *, denominator=False): + """Choose the fastest method for computing RREF over ZZ.""" + # In the extreme of very sparse matrices and low bit counts it is faster to + # use Gauss-Jordan elimination over QQ rather than fraction-free + # Gauss-Jordan over ZZ. In the opposite extreme of dense matrices and high + # bit counts it is faster to use fraction-free Gauss-Jordan over ZZ. These + # two extreme cases need to be handled differently because they lead to + # different asymptotic complexities. In between these two extremes we need + # a threshold for deciding which method to use. This threshold is + # determined empirically by timing the two methods with random matrices. + + # The disadvantage of using empirical timings is that future optimisations + # might change the relative speeds so this can easily become out of date. + # The main thing is to get the asymptotic complexity right for the extreme + # cases though so the precise value of the threshold is hopefully not too + # important. + + # Empirically determined parameter. + PARAM = 10000 + + # First compute the density. This is the average number of non-zero entries + # per row but only counting rows that have at least one non-zero entry + # since RREF can ignore fully zero rows. + density, nrows_nz, ncols = _dm_row_density(M) + + # For small matrices use QQ if more than half the entries are zero. + if nrows_nz < 10: + if density < ncols/2: + return 'GJ' + else: + return 'FF' + + # These are just shortcuts for the formula below. + if density < 5: + return 'GJ' + elif density > 5 + PARAM/nrows_nz: + return 'FF' # pragma: no cover + + # Maximum bitsize of any entry. + elements = _dm_elements(M) + bits = max([e.bit_length() for e in elements], default=1) + + # Wideness parameter. This is 1 for square or tall matrices but >1 for wide + # matrices. + wideness = max(1, 2/3*ncols/nrows_nz) + + max_density = (5 + PARAM/(nrows_nz*bits**2)) * wideness + + if density < max_density: + return 'GJ' + else: + return 'FF' + + +def _dm_row_density(M): + """Density measure for sparse matrices. + + Defines the "density", ``d`` as the average number of non-zero entries per + row except ignoring rows that are fully zero. RREF can ignore fully zero + rows so they are excluded. By definition ``d >= 1`` except that we define + ``d = 0`` for the zero matrix. + + Returns ``(density, nrows_nz, ncols)`` where ``nrows_nz`` counts the number + of nonzero rows and ``ncols`` is the number of columns. + """ + # Uses the SDM dict-of-dicts representation. + ncols = M.shape[1] + rows_nz = M.rep.to_sdm().values() + if not rows_nz: + return 0, 0, ncols + else: + nrows_nz = len(rows_nz) + density = sum(map(len, rows_nz)) / nrows_nz + return density, nrows_nz, ncols + + +def _dm_elements(M): + """Return nonzero elements of a DomainMatrix.""" + elements, _ = M.to_flat_nz() + return elements + + +def _dm_QQ_numers_denoms(Mq): + """Returns the numerators and denominators of a DomainMatrix over QQ.""" + elements = _dm_elements(Mq) + numers = [e.numerator for e in elements] + denoms = [e.denominator for e in elements] + return numers, denoms + + +def _to_field(M): + """Convert a DomainMatrix to a field if possible.""" + K = M.domain + if K.has_assoc_Field: + return M.to_field() + else: + return M diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/sdm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/sdm.py new file mode 100644 index 0000000000000000000000000000000000000000..84558d83b6f58a3a9074d31f1a315ac901cd68da --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/sdm.py @@ -0,0 +1,2197 @@ +""" + +Module for the SDM class. + +""" + +from operator import add, neg, pos, sub, mul +from collections import defaultdict + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import _strongly_connected_components + +from .exceptions import DMBadInputError, DMDomainError, DMShapeError + +from sympy.polys.domains import QQ + +from .ddm import DDM + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['SDM.to_dfm', 'SDM.to_dfm_or_ddm'] + + +class SDM(dict): + r"""Sparse matrix based on polys domain elements + + This is a dict subclass and is a wrapper for a dict of dicts that supports + basic matrix arithmetic +, -, *, **. + + + In order to create a new :py:class:`~.SDM`, a dict + of dicts mapping non-zero elements to their + corresponding row and column in the matrix is needed. + + We also need to specify the shape and :py:class:`~.Domain` + of our :py:class:`~.SDM` object. + + We declare a 2x2 :py:class:`~.SDM` matrix belonging + to QQ domain as shown below. + The 2x2 Matrix in the example is + + .. math:: + A = \left[\begin{array}{ccc} + 0 & \frac{1}{2} \\ + 0 & 0 \end{array} \right] + + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1:QQ(1, 2)}} + >>> A = SDM(elemsdict, (2, 2), QQ) + >>> A + {0: {1: 1/2}} + + We can manipulate :py:class:`~.SDM` the same way + as a Matrix class + + >>> from sympy import ZZ + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ) + >>> A + B + {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}} + + Multiplication + + >>> A*B + {0: {1: 8}, 1: {0: 3}} + >>> A*ZZ(2) + {0: {1: 4}, 1: {0: 2}} + + """ + + fmt = 'sparse' + is_DFM = False + is_DDM = False + + def __init__(self, elemsdict, shape, domain): + super().__init__(elemsdict) + self.shape = self.rows, self.cols = m, n = shape + self.domain = domain + + if not all(0 <= r < m for r in self): + raise DMBadInputError("Row out of range") + if not all(0 <= c < n for row in self.values() for c in row): + raise DMBadInputError("Column out of range") + + def getitem(self, i, j): + try: + return self[i][j] + except KeyError: + m, n = self.shape + if -m <= i < m and -n <= j < n: + try: + return self[i % m][j % n] + except KeyError: + return self.domain.zero + else: + raise IndexError("index out of range") + + def setitem(self, i, j, value): + m, n = self.shape + if not (-m <= i < m and -n <= j < n): + raise IndexError("index out of range") + i, j = i % m, j % n + if value: + try: + self[i][j] = value + except KeyError: + self[i] = {j: value} + else: + rowi = self.get(i, None) + if rowi is not None: + try: + del rowi[j] + except KeyError: + pass + else: + if not rowi: + del self[i] + + def extract_slice(self, slice1, slice2): + m, n = self.shape + ri = range(m)[slice1] + ci = range(n)[slice2] + + sdm = {} + for i, row in self.items(): + if i in ri: + row = {ci.index(j): e for j, e in row.items() if j in ci} + if row: + sdm[ri.index(i)] = row + + return self.new(sdm, (len(ri), len(ci)), self.domain) + + def extract(self, rows, cols): + if not (self and rows and cols): + return self.zeros((len(rows), len(cols)), self.domain) + + m, n = self.shape + if not (-m <= min(rows) <= max(rows) < m): + raise IndexError('Row index out of range') + if not (-n <= min(cols) <= max(cols) < n): + raise IndexError('Column index out of range') + + # rows and cols can contain duplicates e.g. M[[1, 2, 2], [0, 1]] + # Build a map from row/col in self to list of rows/cols in output + rowmap = defaultdict(list) + colmap = defaultdict(list) + for i2, i1 in enumerate(rows): + rowmap[i1 % m].append(i2) + for j2, j1 in enumerate(cols): + colmap[j1 % n].append(j2) + + # Used to efficiently skip zero rows/cols + rowset = set(rowmap) + colset = set(colmap) + + sdm1 = self + sdm2 = {} + for i1 in rowset & sdm1.keys(): + row1 = sdm1[i1] + row2 = {} + for j1 in colset & row1.keys(): + row1_j1 = row1[j1] + for j2 in colmap[j1]: + row2[j2] = row1_j1 + if row2: + for i2 in rowmap[i1]: + sdm2[i2] = row2.copy() + + return self.new(sdm2, (len(rows), len(cols)), self.domain) + + def __str__(self): + rowsstr = [] + for i, row in self.items(): + elemsstr = ', '.join('%s: %s' % (j, elem) for j, elem in row.items()) + rowsstr.append('%s: {%s}' % (i, elemsstr)) + return '{%s}' % ', '.join(rowsstr) + + def __repr__(self): + cls = type(self).__name__ + rows = dict.__repr__(self) + return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain) + + @classmethod + def new(cls, sdm, shape, domain): + """ + + Parameters + ========== + + sdm: A dict of dicts for non-zero elements in SDM + shape: tuple representing dimension of SDM + domain: Represents :py:class:`~.Domain` of SDM + + Returns + ======= + + An :py:class:`~.SDM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1: QQ(2)}} + >>> A = SDM.new(elemsdict, (2, 2), QQ) + >>> A + {0: {1: 2}} + + """ + return cls(sdm, shape, domain) + + def copy(A): + """ + Returns the copy of a :py:class:`~.SDM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1:QQ(2)}, 1:{}} + >>> A = SDM(elemsdict, (2, 2), QQ) + >>> B = A.copy() + >>> B + {0: {1: 2}, 1: {}} + + """ + Ac = {i: Ai.copy() for i, Ai in A.items()} + return A.new(Ac, A.shape, A.domain) + + @classmethod + def from_list(cls, ddm, shape, domain): + """ + Create :py:class:`~.SDM` object from a list of lists. + + Parameters + ========== + + ddm: + list of lists containing domain elements + shape: + Dimensions of :py:class:`~.SDM` matrix + domain: + Represents :py:class:`~.Domain` of :py:class:`~.SDM` object + + Returns + ======= + + :py:class:`~.SDM` containing elements of ddm + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> ddm = [[QQ(1, 2), QQ(0)], [QQ(0), QQ(3, 4)]] + >>> A = SDM.from_list(ddm, (2, 2), QQ) + >>> A + {0: {0: 1/2}, 1: {1: 3/4}} + + See Also + ======== + + to_list + from_list_flat + from_dok + from_ddm + """ + + m, n = shape + if not (len(ddm) == m and all(len(row) == n for row in ddm)): + raise DMBadInputError("Inconsistent row-list/shape") + getrow = lambda i: {j:ddm[i][j] for j in range(n) if ddm[i][j]} + irows = ((i, getrow(i)) for i in range(m)) + sdm = {i: row for i, row in irows if row} + return cls(sdm, shape, domain) + + @classmethod + def from_ddm(cls, ddm): + """ + Create :py:class:`~.SDM` from a :py:class:`~.DDM`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> ddm = DDM( [[QQ(1, 2), 0], [0, QQ(3, 4)]], (2, 2), QQ) + >>> A = SDM.from_ddm(ddm) + >>> A + {0: {0: 1/2}, 1: {1: 3/4}} + >>> SDM.from_ddm(ddm).to_ddm() == ddm + True + + See Also + ======== + + to_ddm + from_list + from_list_flat + from_dok + """ + return cls.from_list(ddm, ddm.shape, ddm.domain) + + def to_list(M): + """ + Convert a :py:class:`~.SDM` object to a list of lists. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1:QQ(2)}, 1:{}} + >>> A = SDM(elemsdict, (2, 2), QQ) + >>> A.to_list() + [[0, 2], [0, 0]] + + + """ + m, n = M.shape + zero = M.domain.zero + ddm = [[zero] * n for _ in range(m)] + for i, row in M.items(): + for j, e in row.items(): + ddm[i][j] = e + return ddm + + def to_list_flat(M): + """ + Convert :py:class:`~.SDM` to a flat list. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{0: QQ(3)}}, (2, 2), QQ) + >>> A.to_list_flat() + [0, 2, 3, 0] + >>> A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + from_list_flat + to_list + to_dok + to_ddm + """ + m, n = M.shape + zero = M.domain.zero + flat = [zero] * (m * n) + for i, row in M.items(): + for j, e in row.items(): + flat[i*n + j] = e + return flat + + @classmethod + def from_list_flat(cls, elements, shape, domain): + """ + Create :py:class:`~.SDM` from a flat list of elements. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM.from_list_flat([QQ(0), QQ(2), QQ(0), QQ(0)], (2, 2), QQ) + >>> A + {0: {1: 2}} + >>> A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + to_list_flat + from_list + from_dok + from_ddm + """ + m, n = shape + if len(elements) != m * n: + raise DMBadInputError("Inconsistent flat-list shape") + sdm = defaultdict(dict) + for inj, element in enumerate(elements): + if element: + i, j = divmod(inj, n) + sdm[i][j] = element + return cls(sdm, shape, domain) + + def to_flat_nz(M): + """ + Convert :class:`SDM` to a flat list of nonzero elements and data. + + Explanation + =========== + + This is used to operate on a list of the elements of a matrix and then + reconstruct a modified matrix with elements in the same positions using + :meth:`from_flat_nz`. Zero elements are omitted from the list. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{0: QQ(3)}}, (2, 2), QQ) + >>> elements, data = A.to_flat_nz() + >>> elements + [2, 3] + >>> A == A.from_flat_nz(elements, data, A.domain) + True + + See Also + ======== + + from_flat_nz + to_list_flat + sympy.polys.matrices.ddm.DDM.to_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.to_flat_nz + """ + dok = M.to_dok() + indices = tuple(dok) + elements = list(dok.values()) + data = (indices, M.shape) + return elements, data + + @classmethod + def from_flat_nz(cls, elements, data, domain): + """ + Reconstruct a :class:`~.SDM` after calling :meth:`to_flat_nz`. + + See :meth:`to_flat_nz` for explanation. + + See Also + ======== + + to_flat_nz + from_list_flat + sympy.polys.matrices.ddm.DDM.from_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.from_flat_nz + """ + indices, shape = data + dok = dict(zip(indices, elements)) + return cls.from_dok(dok, shape, domain) + + def to_dod(M): + """ + Convert to dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> A.to_dod() + {0: {1: 2}, 1: {0: 3}} + + See Also + ======== + + from_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dod + """ + return {i: row.copy() for i, row in M.items()} + + @classmethod + def from_dod(cls, dod, shape, domain): + """ + Create :py:class:`~.SDM` from dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> dod = {0: {1: QQ(2)}, 1: {0: QQ(3)}} + >>> A = SDM.from_dod(dod, (2, 2), QQ) + >>> A + {0: {1: 2}, 1: {0: 3}} + >>> A == SDM.from_dod(A.to_dod(), A.shape, A.domain) + True + + See Also + ======== + + to_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dod + """ + sdm = defaultdict(dict) + for i, row in dod.items(): + for j, e in row.items(): + if e: + sdm[i][j] = e + return cls(sdm, shape, domain) + + def to_dok(M): + """ + Convert to dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> A.to_dok() + {(0, 1): 2, (1, 0): 3} + + See Also + ======== + + from_dok + to_list + to_list_flat + to_ddm + """ + return {(i, j): e for i, row in M.items() for j, e in row.items()} + + @classmethod + def from_dok(cls, dok, shape, domain): + """ + Create :py:class:`~.SDM` from dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> dok = {(0, 1): QQ(2), (1, 0): QQ(3)} + >>> A = SDM.from_dok(dok, (2, 2), QQ) + >>> A + {0: {1: 2}, 1: {0: 3}} + >>> A == SDM.from_dok(A.to_dok(), A.shape, A.domain) + True + + See Also + ======== + + to_dok + from_list + from_list_flat + from_ddm + """ + sdm = defaultdict(dict) + for (i, j), e in dok.items(): + if e: + sdm[i][j] = e + return cls(sdm, shape, domain) + + def iter_values(M): + """ + Iterate over the nonzero values of a :py:class:`~.SDM` matrix. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> list(A.iter_values()) + [2, 3] + + """ + for row in M.values(): + yield from row.values() + + def iter_items(M): + """ + Iterate over indices and values of the nonzero elements. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> list(A.iter_items()) + [((0, 1), 2), ((1, 0), 3)] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.iter_items + """ + for i, row in M.items(): + for j, e in row.items(): + yield (i, j), e + + def to_ddm(M): + """ + Convert a :py:class:`~.SDM` object to a :py:class:`~.DDM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.to_ddm() + [[0, 2], [0, 0]] + + """ + return DDM(M.to_list(), M.shape, M.domain) + + def to_sdm(M): + """ + Convert to :py:class:`~.SDM` format (returns self). + """ + return M + + @doctest_depends_on(ground_types=['flint']) + def to_dfm(M): + """ + Convert a :py:class:`~.SDM` object to a :py:class:`~.DFM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.to_dfm() + [[0, 2], [0, 0]] + + See Also + ======== + + to_ddm + to_dfm_or_ddm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm + """ + return M.to_ddm().to_dfm() + + @doctest_depends_on(ground_types=['flint']) + def to_dfm_or_ddm(M): + """ + Convert to :py:class:`~.DFM` if possible, else :py:class:`~.DDM`. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.to_dfm_or_ddm() + [[0, 2], [0, 0]] + >>> type(A.to_dfm_or_ddm()) # depends on the ground types + + + See Also + ======== + + to_ddm + to_dfm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm_or_ddm + """ + return M.to_ddm().to_dfm_or_ddm() + + @classmethod + def zeros(cls, shape, domain): + r""" + + Returns a :py:class:`~.SDM` of size shape, + belonging to the specified domain + + In the example below we declare a matrix A where, + + .. math:: + A := \left[\begin{array}{ccc} + 0 & 0 & 0 \\ + 0 & 0 & 0 \end{array} \right] + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM.zeros((2, 3), QQ) + >>> A + {} + + """ + return cls({}, shape, domain) + + @classmethod + def ones(cls, shape, domain): + one = domain.one + m, n = shape + row = dict(zip(range(n), [one]*n)) + sdm = {i: row.copy() for i in range(m)} + return cls(sdm, shape, domain) + + @classmethod + def eye(cls, shape, domain): + """ + + Returns a identity :py:class:`~.SDM` matrix of dimensions + size x size, belonging to the specified domain + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> I = SDM.eye((2, 2), QQ) + >>> I + {0: {0: 1}, 1: {1: 1}} + + """ + if isinstance(shape, int): + rows, cols = shape, shape + else: + rows, cols = shape + one = domain.one + sdm = {i: {i: one} for i in range(min(rows, cols))} + return cls(sdm, (rows, cols), domain) + + @classmethod + def diag(cls, diagonal, domain, shape=None): + if shape is None: + shape = (len(diagonal), len(diagonal)) + sdm = {i: {i: v} for i, v in enumerate(diagonal) if v} + return cls(sdm, shape, domain) + + def transpose(M): + """ + + Returns the transpose of a :py:class:`~.SDM` matrix + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.transpose() + {1: {0: 2}} + + """ + MT = sdm_transpose(M) + return M.new(MT, M.shape[::-1], M.domain) + + def __add__(A, B): + if not isinstance(B, SDM): + return NotImplemented + elif A.shape != B.shape: + raise DMShapeError("Matrix size mismatch: %s + %s" % (A.shape, B.shape)) + return A.add(B) + + def __sub__(A, B): + if not isinstance(B, SDM): + return NotImplemented + elif A.shape != B.shape: + raise DMShapeError("Matrix size mismatch: %s - %s" % (A.shape, B.shape)) + return A.sub(B) + + def __neg__(A): + return A.neg() + + def __mul__(A, B): + """A * B""" + if isinstance(B, SDM): + return A.matmul(B) + elif B in A.domain: + return A.mul(B) + else: + return NotImplemented + + def __rmul__(a, b): + if b in a.domain: + return a.rmul(b) + else: + return NotImplemented + + def matmul(A, B): + """ + Performs matrix multiplication of two SDM matrices + + Parameters + ========== + + A, B: SDM to multiply + + Returns + ======= + + SDM + SDM after multiplication + + Raises + ====== + + DomainError + If domain of A does not match + with that of B + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0:ZZ(2), 1:ZZ(3)}, 1:{0:ZZ(4)}}, (2, 2), ZZ) + >>> A.matmul(B) + {0: {0: 8}, 1: {0: 2, 1: 3}} + + """ + if A.domain != B.domain: + raise DMDomainError + m, n = A.shape + n2, o = B.shape + if n != n2: + raise DMShapeError + C = sdm_matmul(A, B, A.domain, m, o) + return A.new(C, (m, o), A.domain) + + def mul(A, b): + """ + Multiplies each element of A with a scalar b + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.mul(ZZ(3)) + {0: {1: 6}, 1: {0: 3}} + + """ + Csdm = unop_dict(A, lambda aij: aij*b) + return A.new(Csdm, A.shape, A.domain) + + def rmul(A, b): + Csdm = unop_dict(A, lambda aij: b*aij) + return A.new(Csdm, A.shape, A.domain) + + def mul_elementwise(A, B): + if A.domain != B.domain: + raise DMDomainError + if A.shape != B.shape: + raise DMShapeError + zero = A.domain.zero + fzero = lambda e: zero + Csdm = binop_dict(A, B, mul, fzero, fzero) + return A.new(Csdm, A.shape, A.domain) + + def add(A, B): + """ + + Adds two :py:class:`~.SDM` matrices + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ) + >>> A.add(B) + {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}} + + """ + Csdm = binop_dict(A, B, add, pos, pos) + return A.new(Csdm, A.shape, A.domain) + + def sub(A, B): + """ + + Subtracts two :py:class:`~.SDM` matrices + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ) + >>> A.sub(B) + {0: {0: -3, 1: 2}, 1: {0: 1, 1: -4}} + + """ + Csdm = binop_dict(A, B, sub, pos, neg) + return A.new(Csdm, A.shape, A.domain) + + def neg(A): + """ + + Returns the negative of a :py:class:`~.SDM` matrix + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.neg() + {0: {1: -2}, 1: {0: -1}} + + """ + Csdm = unop_dict(A, neg) + return A.new(Csdm, A.shape, A.domain) + + def convert_to(A, K): + """ + Converts the :py:class:`~.Domain` of a :py:class:`~.SDM` matrix to K + + Examples + ======== + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.convert_to(QQ) + {0: {1: 2}, 1: {0: 1}} + + """ + Kold = A.domain + if K == Kold: + return A.copy() + Ak = unop_dict(A, lambda e: K.convert_from(e, Kold)) + return A.new(Ak, A.shape, K) + + def nnz(A): + """Number of non-zero elements in the :py:class:`~.SDM` matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.nnz() + 2 + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nnz + """ + return sum(map(len, A.values())) + + def scc(A): + """Strongly connected components of a square matrix *A*. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0: ZZ(2)}, 1:{1:ZZ(1)}}, (2, 2), ZZ) + >>> A.scc() + [[0], [1]] + + See also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.scc + """ + rows, cols = A.shape + assert rows == cols + V = range(rows) + Emap = {v: list(A.get(v, [])) for v in V} + return _strongly_connected_components(V, Emap) + + def rref(A): + """ + + Returns reduced-row echelon form and list of pivots for the :py:class:`~.SDM` + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ) + >>> A.rref() + ({0: {0: 1, 1: 2}}, [0]) + + """ + B, pivots, _ = sdm_irref(A) + return A.new(B, A.shape, A.domain), pivots + + def rref_den(A): + """ + + Returns reduced-row echelon form (RREF) with denominator and pivots. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ) + >>> A.rref_den() + ({0: {0: 1, 1: 2}}, 1, [0]) + + """ + K = A.domain + A_rref_sdm, denom, pivots = sdm_rref_den(A, K) + A_rref = A.new(A_rref_sdm, A.shape, A.domain) + return A_rref, denom, pivots + + def inv(A): + """ + + Returns inverse of a matrix A + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.inv() + {0: {0: -2, 1: 1}, 1: {0: 3/2, 1: -1/2}} + + """ + return A.to_dfm_or_ddm().inv().to_sdm() + + def det(A): + """ + Returns determinant of A + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.det() + -2 + + """ + # It would be better to have a sparse implementation of det for use + # with very sparse matrices. Extremely sparse matrices probably just + # have determinant zero and we could probably detect that very quickly. + # In the meantime, we convert to a dense matrix and use ddm_idet. + # + # If GROUND_TYPES=flint though then we will use Flint's implementation + # if possible (dfm). + return A.to_dfm_or_ddm().det() + + def lu(A): + """ + + Returns LU decomposition for a matrix A + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.lu() + ({0: {0: 1}, 1: {0: 3, 1: 1}}, {0: {0: 1, 1: 2}, 1: {1: -2}}, []) + + """ + L, U, swaps = A.to_ddm().lu() + return A.from_ddm(L), A.from_ddm(U), swaps + + def qr(self): + """ + QR decomposition for SDM (Sparse Domain Matrix). + + Returns: + - Q: Orthogonal matrix as a SDM. + - R: Upper triangular matrix as a SDM. + """ + ddm_q, ddm_r = self.to_ddm().qr() + Q = ddm_q.to_sdm() + R = ddm_r.to_sdm() + return Q, R + + def lu_solve(A, b): + """ + + Uses LU decomposition to solve Ax = b, + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ) + >>> A.lu_solve(b) + {1: {0: 1/2}} + + """ + return A.from_ddm(A.to_ddm().lu_solve(b.to_ddm())) + + def fflu(self): + """ + Fraction free LU decomposition of SDM. + + Uses DDM implementation. + + See Also + ======== + + sympy.polys.matrices.ddm.DDM.fflu + """ + ddm_p, ddm_l, ddm_d, ddm_u = self.to_dfm_or_ddm().fflu() + P = ddm_p.to_sdm() + L = ddm_l.to_sdm() + D = ddm_d.to_sdm() + U = ddm_u.to_sdm() + return P, L, D, U + + def nullspace(A): + """ + Nullspace of a :py:class:`~.SDM` matrix A. + + The domain of the matrix must be a field. + + It is better to use the :meth:`~.DomainMatrix.nullspace` method rather + than this method which is otherwise no longer used. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ) + >>> A.nullspace() + ({0: {0: -2, 1: 1}}, [1]) + + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + The preferred way to get the nullspace of a matrix. + + """ + ncols = A.shape[1] + one = A.domain.one + B, pivots, nzcols = sdm_irref(A) + K, nonpivots = sdm_nullspace_from_rref(B, one, ncols, pivots, nzcols) + K = dict(enumerate(K)) + shape = (len(K), ncols) + return A.new(K, shape, A.domain), nonpivots + + def nullspace_from_rref(A, pivots=None): + """ + Returns nullspace for a :py:class:`~.SDM` matrix ``A`` in RREF. + + The domain of the matrix can be any domain. + + The matrix must already be in reduced row echelon form (RREF). + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ) + >>> A_rref, pivots = A.rref() + >>> A_null, nonpivots = A_rref.nullspace_from_rref(pivots) + >>> A_null + {0: {0: -2, 1: 1}} + >>> pivots + [0] + >>> nonpivots + [1] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + The higher-level function that would usually be called instead of + calling this one directly. + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace_from_rref + The higher-level direct equivalent of this function. + + sympy.polys.matrices.ddm.DDM.nullspace_from_rref + The equivalent function for dense :py:class:`~.DDM` matrices. + + """ + m, n = A.shape + K = A.domain + + if pivots is None: + pivots = sorted(map(min, A.values())) + + if not pivots: + return A.eye((n, n), K), list(range(n)) + elif len(pivots) == n: + return A.zeros((0, n), K), [] + + # In fraction-free RREF the nonzero entry inserted for the pivots is + # not necessarily 1. + pivot_val = A[0][pivots[0]] + assert not K.is_zero(pivot_val) + + pivots_set = set(pivots) + + # Loop once over all nonzero entries making a map from column indices + # to the nonzero entries in that column along with the row index of the + # nonzero entry. This is basically the transpose of the matrix. + nonzero_cols = defaultdict(list) + for i, Ai in A.items(): + for j, Aij in Ai.items(): + nonzero_cols[j].append((i, Aij)) + + # Usually in SDM we want to avoid looping over the dimensions of the + # matrix because it is optimised to support extremely sparse matrices. + # Here in nullspace though every zero column becomes a nonzero column + # so we need to loop once over the columns at least (range(n)) rather + # than just the nonzero entries of the matrix. We can still avoid + # an inner loop over the rows though by using the nonzero_cols map. + basis = [] + nonpivots = [] + for j in range(n): + if j in pivots_set: + continue + nonpivots.append(j) + + vec = {j: pivot_val} + for ip, Aij in nonzero_cols[j]: + vec[pivots[ip]] = -Aij + + basis.append(vec) + + sdm = dict(enumerate(basis)) + A_null = A.new(sdm, (len(basis), n), K) + + return (A_null, nonpivots) + + def particular(A): + ncols = A.shape[1] + B, pivots, nzcols = sdm_irref(A) + P = sdm_particular_from_rref(B, ncols, pivots) + rep = {0:P} if P else {} + return A.new(rep, (1, ncols-1), A.domain) + + def hstack(A, *B): + """Horizontally stacks :py:class:`~.SDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + + >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ) + >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ) + >>> A.hstack(B) + {0: {0: 1, 1: 2, 2: 5, 3: 6}, 1: {0: 3, 1: 4, 2: 7, 3: 8}} + + >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ) + >>> A.hstack(B, C) + {0: {0: 1, 1: 2, 2: 5, 3: 6, 4: 9, 5: 10}, 1: {0: 3, 1: 4, 2: 7, 3: 8, 4: 11, 5: 12}} + """ + Anew = dict(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkrows == rows + assert Bk.domain == domain + + for i, Bki in Bk.items(): + Ai = Anew.get(i, None) + if Ai is None: + Anew[i] = Ai = {} + for j, Bkij in Bki.items(): + Ai[j + cols] = Bkij + cols += Bkcols + + return A.new(Anew, (rows, cols), A.domain) + + def vstack(A, *B): + """Vertically stacks :py:class:`~.SDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + + >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ) + >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ) + >>> A.vstack(B) + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}} + + >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ) + >>> A.vstack(B, C) + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}, 4: {0: 9, 1: 10}, 5: {0: 11, 1: 12}} + """ + Anew = dict(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkcols == cols + assert Bk.domain == domain + + for i, Bki in Bk.items(): + Anew[i + rows] = Bki + rows += Bkrows + + return A.new(Anew, (rows, cols), A.domain) + + def applyfunc(self, func, domain): + sdm = {i: {j: func(e) for j, e in row.items()} for i, row in self.items()} + return self.new(sdm, self.shape, domain) + + def charpoly(A): + """ + Returns the coefficients of the characteristic polynomial + of the :py:class:`~.SDM` matrix. These elements will be domain elements. + The domain of the elements will be same as domain of the :py:class:`~.SDM`. + + Examples + ======== + + >>> from sympy import QQ, Symbol + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy.polys import Poly + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.charpoly() + [1, -5, -2] + + We can create a polynomial using the + coefficients using :py:class:`~.Poly` + + >>> x = Symbol('x') + >>> p = Poly(A.charpoly(), x, domain=A.domain) + >>> p + Poly(x**2 - 5*x - 2, x, domain='QQ') + + """ + K = A.domain + n, _ = A.shape + pdict = sdm_berk(A, n, K) + plist = [K.zero] * (n + 1) + for i, pi in pdict.items(): + plist[i] = pi + return plist + + def is_zero_matrix(self): + """ + Says whether this matrix has all zero entries. + """ + return not self + + def is_upper(self): + """ + Says whether this matrix is upper-triangular. True can be returned + even if the matrix is not square. + """ + return all(i <= j for i, row in self.items() for j in row) + + def is_lower(self): + """ + Says whether this matrix is lower-triangular. True can be returned + even if the matrix is not square. + """ + return all(i >= j for i, row in self.items() for j in row) + + def is_diagonal(self): + """ + Says whether this matrix is diagonal. True can be returned + even if the matrix is not square. + """ + return all(i == j for i, row in self.items() for j in row) + + def diagonal(self): + """ + Returns the diagonal of the matrix as a list. + """ + m, n = self.shape + zero = self.domain.zero + return [row.get(i, zero) for i, row in self.items() if i < n] + + def lll(A, delta=QQ(3, 4)): + """ + Returns the LLL-reduced basis for the :py:class:`~.SDM` matrix. + """ + return A.to_dfm_or_ddm().lll(delta=delta).to_sdm() + + def lll_transform(A, delta=QQ(3, 4)): + """ + Returns the LLL-reduced basis and transformation matrix. + """ + reduced, transform = A.to_dfm_or_ddm().lll_transform(delta=delta) + return reduced.to_sdm(), transform.to_sdm() + + +def binop_dict(A, B, fab, fa, fb): + Anz, Bnz = set(A), set(B) + C = {} + + for i in Anz & Bnz: + Ai, Bi = A[i], B[i] + Ci = {} + Anzi, Bnzi = set(Ai), set(Bi) + for j in Anzi & Bnzi: + Cij = fab(Ai[j], Bi[j]) + if Cij: + Ci[j] = Cij + for j in Anzi - Bnzi: + Cij = fa(Ai[j]) + if Cij: + Ci[j] = Cij + for j in Bnzi - Anzi: + Cij = fb(Bi[j]) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + for i in Anz - Bnz: + Ai = A[i] + Ci = {} + for j, Aij in Ai.items(): + Cij = fa(Aij) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + for i in Bnz - Anz: + Bi = B[i] + Ci = {} + for j, Bij in Bi.items(): + Cij = fb(Bij) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + return C + + +def unop_dict(A, f): + B = {} + for i, Ai in A.items(): + Bi = {} + for j, Aij in Ai.items(): + Bij = f(Aij) + if Bij: + Bi[j] = Bij + if Bi: + B[i] = Bi + return B + + +def sdm_transpose(M): + MT = {} + for i, Mi in M.items(): + for j, Mij in Mi.items(): + try: + MT[j][i] = Mij + except KeyError: + MT[j] = {i: Mij} + return MT + + +def sdm_dotvec(A, B, K): + return K.sum(A[j] * B[j] for j in A.keys() & B.keys()) + + +def sdm_matvecmul(A, B, K): + C = {} + for i, Ai in A.items(): + Ci = sdm_dotvec(Ai, B, K) + if Ci: + C[i] = Ci + return C + + +def sdm_matmul(A, B, K, m, o): + # + # Should be fast if A and B are very sparse. + # Consider e.g. A = B = eye(1000). + # + # The idea here is that we compute C = A*B in terms of the rows of C and + # B since the dict of dicts representation naturally stores the matrix as + # rows. The ith row of C (Ci) is equal to the sum of Aik * Bk where Bk is + # the kth row of B. The algorithm below loops over each nonzero element + # Aik of A and if the corresponding row Bj is nonzero then we do + # Ci += Aik * Bk. + # To make this more efficient we don't need to loop over all elements Aik. + # Instead for each row Ai we compute the intersection of the nonzero + # columns in Ai with the nonzero rows in B. That gives the k such that + # Aik and Bk are both nonzero. In Python the intersection of two sets + # of int can be computed very efficiently. + # + if K.is_EXRAW: + return sdm_matmul_exraw(A, B, K, m, o) + + C = {} + B_knz = set(B) + for i, Ai in A.items(): + Ci = {} + Ai_knz = set(Ai) + for k in Ai_knz & B_knz: + Aik = Ai[k] + for j, Bkj in B[k].items(): + Cij = Ci.get(j, None) + if Cij is not None: + Cij = Cij + Aik * Bkj + if Cij: + Ci[j] = Cij + else: + Ci.pop(j) + else: + Cij = Aik * Bkj + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + return C + + +def sdm_matmul_exraw(A, B, K, m, o): + # + # Like sdm_matmul above except that: + # + # - Handles cases like 0*oo -> nan (sdm_matmul skips multiplication by zero) + # - Uses K.sum (Add(*items)) for efficient addition of Expr + # + zero = K.zero + C = {} + B_knz = set(B) + for i, Ai in A.items(): + Ci_list = defaultdict(list) + Ai_knz = set(Ai) + + # Nonzero row/column pair + for k in Ai_knz & B_knz: + Aik = Ai[k] + if zero * Aik == zero: + # This is the main inner loop: + for j, Bkj in B[k].items(): + Ci_list[j].append(Aik * Bkj) + else: + for j in range(o): + Ci_list[j].append(Aik * B[k].get(j, zero)) + + # Zero row in B, check for infinities in A + for k in Ai_knz - B_knz: + zAik = zero * Ai[k] + if zAik != zero: + for j in range(o): + Ci_list[j].append(zAik) + + # Add terms using K.sum (Add(*terms)) for efficiency + Ci = {} + for j, Cij_list in Ci_list.items(): + Cij = K.sum(Cij_list) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + # Find all infinities in B + for k, Bk in B.items(): + for j, Bkj in Bk.items(): + if zero * Bkj != zero: + for i in range(m): + Aik = A.get(i, {}).get(k, zero) + # If Aik is not zero then this was handled above + if Aik == zero: + Ci = C.get(i, {}) + Cij = Ci.get(j, zero) + Aik * Bkj + if Cij != zero: + Ci[j] = Cij + C[i] = Ci + else: + Ci.pop(j, None) + if Ci: + C[i] = Ci + else: + C.pop(i, None) + + return C + + +def sdm_irref(A): + """RREF and pivots of a sparse matrix *A*. + + Compute the reduced row echelon form (RREF) of the matrix *A* and return a + list of the pivot columns. This routine does not work in place and leaves + the original matrix *A* unmodified. + + The domain of the matrix must be a field. + + Examples + ======== + + This routine works with a dict of dicts sparse representation of a matrix: + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import sdm_irref + >>> A = {0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}} + >>> Arref, pivots, _ = sdm_irref(A) + >>> Arref + {0: {0: 1}, 1: {1: 1}} + >>> pivots + [0, 1] + + The analogous calculation with :py:class:`~.MutableDenseMatrix` would be + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> Mrref, pivots = M.rref() + >>> Mrref + Matrix([ + [1, 0], + [0, 1]]) + >>> pivots + (0, 1) + + Notes + ===== + + The cost of this algorithm is determined purely by the nonzero elements of + the matrix. No part of the cost of any step in this algorithm depends on + the number of rows or columns in the matrix. No step depends even on the + number of nonzero rows apart from the primary loop over those rows. The + implementation is much faster than ddm_rref for sparse matrices. In fact + at the time of writing it is also (slightly) faster than the dense + implementation even if the input is a fully dense matrix so it seems to be + faster in all cases. + + The elements of the matrix should support exact division with ``/``. For + example elements of any domain that is a field (e.g. ``QQ``) should be + fine. No attempt is made to handle inexact arithmetic. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + The higher-level function that would normally be used to call this + routine. + sympy.polys.matrices.dense.ddm_irref + The dense equivalent of this routine. + sdm_rref_den + Fraction-free version of this routine. + """ + # + # Any zeros in the matrix are not stored at all so an element is zero if + # its row dict has no index at that key. A row is entirely zero if its + # row index is not in the outer dict. Since rref reorders the rows and + # removes zero rows we can completely discard the row indices. The first + # step then copies the row dicts into a list sorted by the index of the + # first nonzero column in each row. + # + # The algorithm then processes each row Ai one at a time. Previously seen + # rows are used to cancel their pivot columns from Ai. Then a pivot from + # Ai is chosen and is cancelled from all previously seen rows. At this + # point Ai joins the previously seen rows. Once all rows are seen all + # elimination has occurred and the rows are sorted by pivot column index. + # + # The previously seen rows are stored in two separate groups. The reduced + # group consists of all rows that have been reduced to a single nonzero + # element (the pivot). There is no need to attempt any further reduction + # with these. Rows that still have other nonzeros need to be considered + # when Ai is cancelled from the previously seen rows. + # + # A dict nonzerocolumns is used to map from a column index to a set of + # previously seen rows that still have a nonzero element in that column. + # This means that we can cancel the pivot from Ai into the previously seen + # rows without needing to loop over each row that might have a zero in + # that column. + # + + # Row dicts sorted by index of first nonzero column + # (Maybe sorting is not needed/useful.) + Arows = sorted((Ai.copy() for Ai in A.values()), key=min) + + # Each processed row has an associated pivot column. + # pivot_row_map maps from the pivot column index to the row dict. + # This means that we can represent a set of rows purely as a set of their + # pivot indices. + pivot_row_map = {} + + # Set of pivot indices for rows that are fully reduced to a single nonzero. + reduced_pivots = set() + + # Set of pivot indices for rows not fully reduced + nonreduced_pivots = set() + + # Map from column index to a set of pivot indices representing the rows + # that have a nonzero at that column. + nonzero_columns = defaultdict(set) + + while Arows: + # Select pivot element and row + Ai = Arows.pop() + + # Nonzero columns from fully reduced pivot rows can be removed + Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced_pivots} + + # Others require full row cancellation + for j in nonreduced_pivots & set(Ai): + Aj = pivot_row_map[j] + Aij = Ai[j] + Ainz = set(Ai) + Ajnz = set(Aj) + for k in Ajnz - Ainz: + Ai[k] = - Aij * Aj[k] + Ai.pop(j) + Ainz.remove(j) + for k in Ajnz & Ainz: + Aik = Ai[k] - Aij * Aj[k] + if Aik: + Ai[k] = Aik + else: + Ai.pop(k) + + # We have now cancelled previously seen pivots from Ai. + # If it is zero then discard it. + if not Ai: + continue + + # Choose a pivot from Ai: + j = min(Ai) + Aij = Ai[j] + pivot_row_map[j] = Ai + Ainz = set(Ai) + + # Normalise the pivot row to make the pivot 1. + # + # This approach is slow for some domains. Cross cancellation might be + # better for e.g. QQ(x) with division delayed to the final steps. + Aijinv = Aij**-1 + for l in Ai: + Ai[l] *= Aijinv + + # Use Aij to cancel column j from all previously seen rows + for k in nonzero_columns.pop(j, ()): + Ak = pivot_row_map[k] + Akj = Ak[j] + Aknz = set(Ak) + for l in Ainz - Aknz: + Ak[l] = - Akj * Ai[l] + nonzero_columns[l].add(k) + Ak.pop(j) + Aknz.remove(j) + for l in Ainz & Aknz: + Akl = Ak[l] - Akj * Ai[l] + if Akl: + Ak[l] = Akl + else: + # Drop nonzero elements + Ak.pop(l) + if l != j: + nonzero_columns[l].remove(k) + if len(Ak) == 1: + reduced_pivots.add(k) + nonreduced_pivots.remove(k) + + if len(Ai) == 1: + reduced_pivots.add(j) + else: + nonreduced_pivots.add(j) + for l in Ai: + if l != j: + nonzero_columns[l].add(j) + + # All done! + pivots = sorted(reduced_pivots | nonreduced_pivots) + pivot2row = {p: n for n, p in enumerate(pivots)} + nonzero_columns = {c: {pivot2row[p] for p in s} for c, s in nonzero_columns.items()} + rows = [pivot_row_map[i] for i in pivots] + rref = dict(enumerate(rows)) + return rref, pivots, nonzero_columns + + +def sdm_rref_den(A, K): + """ + Return the reduced row echelon form (RREF) of A with denominator. + + The RREF is computed using fraction-free Gauss-Jordan elimination. + + Explanation + =========== + + The algorithm used is the fraction-free version of Gauss-Jordan elimination + described as FFGJ in [1]_. Here it is modified to handle zero or missing + pivots and to avoid redundant arithmetic. This implementation is also + optimized for sparse matrices. + + The domain $K$ must support exact division (``K.exquo``) but does not need + to be a field. This method is suitable for most exact rings and fields like + :ref:`ZZ`, :ref:`QQ` and :ref:`QQ(a)`. In the case of :ref:`QQ` or + :ref:`K(x)` it might be more efficient to clear denominators and use + :ref:`ZZ` or :ref:`K[x]` instead. + + For inexact domains like :ref:`RR` and :ref:`CC` use ``ddm_irref`` instead. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import sdm_rref_den + >>> from sympy.polys.domains import ZZ + >>> A = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}} + >>> A_rref, den, pivots = sdm_rref_den(A, ZZ) + >>> A_rref + {0: {0: -2}, 1: {1: -2}} + >>> den + -2 + >>> pivots + [0, 1] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + Higher-level interface to ``sdm_rref_den`` that would usually be used + instead of calling this function directly. + sympy.polys.matrices.sdm.sdm_rref_den + The ``SDM`` method that uses this function. + sdm_irref + Computes RREF using field division. + ddm_irref_den + The dense version of this algorithm. + + References + ========== + + .. [1] Fraction-free algorithms for linear and polynomial equations. + George C. Nakos , Peter R. Turner , Robert M. Williams. + https://dl.acm.org/doi/10.1145/271130.271133 + """ + # + # We represent each row of the matrix as a dict mapping column indices to + # nonzero elements. We will build the RREF matrix starting from an empty + # matrix and appending one row at a time. At each step we will have the + # RREF of the rows we have processed so far. + # + # Our representation of the RREF divides it into three parts: + # + # 1. Fully reduced rows having only a single nonzero element (the pivot). + # 2. Partially reduced rows having nonzeros after the pivot. + # 3. The current denominator and divisor. + # + # For example if the incremental RREF might be: + # + # [2, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # [0, 0, 2, 0, 0, 0, 7, 0, 0, 0] + # [0, 0, 0, 0, 0, 2, 0, 0, 0, 0] + # [0, 0, 0, 0, 0, 0, 0, 2, 0, 0] + # [0, 0, 0, 0, 0, 0, 0, 0, 2, 0] + # + # Here the second row is partially reduced and the other rows are fully + # reduced. The denominator would be 2 in this case. We distinguish the + # fully reduced rows because we can handle them more efficiently when + # adding a new row. + # + # When adding a new row we need to multiply it by the current denominator. + # Then we reduce the new row by cross cancellation with the previous rows. + # Then if it is not reduced to zero we take its leading entry as the new + # pivot, cross cancel the new row from the previous rows and update the + # denominator. In the fraction-free version this last step requires + # multiplying and dividing the whole matrix by the new pivot and the + # current divisor. The advantage of building the RREF one row at a time is + # that in the sparse case we only need to work with the relatively sparse + # upper rows of the matrix. The simple version of FFGJ in [1] would + # multiply and divide all the dense lower rows at each step. + + # Handle the trivial cases. + if not A: + return ({}, K.one, []) + elif len(A) == 1: + Ai, = A.values() + j = min(Ai) + Aij = Ai[j] + return ({0: Ai.copy()}, Aij, [j]) + + # For inexact domains like RR[x] we use quo and discard the remainder. + # Maybe it would be better for K.exquo to do this automatically. + if K.is_Exact: + exquo = K.exquo + else: + exquo = K.quo + + # Make sure we have the rows in order to make this deterministic from the + # outset. + _, rows_in_order = zip(*sorted(A.items())) + + col_to_row_reduced = {} + col_to_row_unreduced = {} + reduced = col_to_row_reduced.keys() + unreduced = col_to_row_unreduced.keys() + + # Our representation of the RREF so far. + A_rref_rows = [] + denom = None + divisor = None + + # The rows that remain to be added to the RREF. These are sorted by the + # column index of their leading entry. Note that sorted() is stable so the + # previous sort by unique row index is still needed to make this + # deterministic (there may be multiple rows with the same leading column). + A_rows = sorted(rows_in_order, key=min) + + for Ai in A_rows: + + # All fully reduced columns can be immediately discarded. + Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced} + + # We need to multiply the new row by the current denominator to bring + # it into the same scale as the previous rows and then cross-cancel to + # reduce it wrt the previous unreduced rows. All pivots in the previous + # rows are equal to denom so the coefficients we need to make a linear + # combination of the previous rows to cancel into the new row are just + # the ones that are already in the new row *before* we multiply by + # denom. We compute that linear combination first and then multiply the + # new row by denom before subtraction. + Ai_cancel = {} + + for j in unreduced & Ai.keys(): + # Remove the pivot column from the new row since it would become + # zero anyway. + Aij = Ai.pop(j) + + Aj = A_rref_rows[col_to_row_unreduced[j]] + + for k, Ajk in Aj.items(): + Aik_cancel = Ai_cancel.get(k) + if Aik_cancel is None: + Ai_cancel[k] = Aij * Ajk + else: + Aik_cancel = Aik_cancel + Aij * Ajk + if Aik_cancel: + Ai_cancel[k] = Aik_cancel + else: + Ai_cancel.pop(k) + + # Multiply the new row by the current denominator and subtract. + Ai_nz = set(Ai) + Ai_cancel_nz = set(Ai_cancel) + + d = denom or K.one + + for k in Ai_cancel_nz - Ai_nz: + Ai[k] = -Ai_cancel[k] + + for k in Ai_nz - Ai_cancel_nz: + Ai[k] = Ai[k] * d + + for k in Ai_cancel_nz & Ai_nz: + Aik = Ai[k] * d - Ai_cancel[k] + if Aik: + Ai[k] = Aik + else: + Ai.pop(k) + + # Now Ai has the same scale as the other rows and is reduced wrt the + # unreduced rows. + + # If the row is reduced to zero then discard it. + if not Ai: + continue + + # Choose a pivot for this row. + j = min(Ai) + Aij = Ai.pop(j) + + # Cross cancel the unreduced rows by the new row. + # a[k][l] = (a[i][j]*a[k][l] - a[k][j]*a[i][l]) / divisor + for pk, k in list(col_to_row_unreduced.items()): + + Ak = A_rref_rows[k] + + if j not in Ak: + # This row is already reduced wrt the new row but we need to + # bring it to the same scale as the new denominator. This step + # is not needed in sdm_irref. + for l, Akl in Ak.items(): + Akl = Akl * Aij + if divisor is not None: + Akl = exquo(Akl, divisor) + Ak[l] = Akl + continue + + Akj = Ak.pop(j) + Ai_nz = set(Ai) + Ak_nz = set(Ak) + + for l in Ai_nz - Ak_nz: + Ak[l] = - Akj * Ai[l] + if divisor is not None: + Ak[l] = exquo(Ak[l], divisor) + + # This loop also not needed in sdm_irref. + for l in Ak_nz - Ai_nz: + Ak[l] = Aij * Ak[l] + if divisor is not None: + Ak[l] = exquo(Ak[l], divisor) + + for l in Ai_nz & Ak_nz: + Akl = Aij * Ak[l] - Akj * Ai[l] + if Akl: + if divisor is not None: + Akl = exquo(Akl, divisor) + Ak[l] = Akl + else: + Ak.pop(l) + + if not Ak: + col_to_row_unreduced.pop(pk) + col_to_row_reduced[pk] = k + + i = len(A_rref_rows) + A_rref_rows.append(Ai) + if Ai: + col_to_row_unreduced[j] = i + else: + col_to_row_reduced[j] = i + + # Update the denominator. + if not K.is_one(Aij): + if denom is None: + denom = Aij + else: + denom *= Aij + + if divisor is not None: + denom = exquo(denom, divisor) + + # Update the divisor. + divisor = denom + + if denom is None: + denom = K.one + + # Sort the rows by their leading column index. + col_to_row = {**col_to_row_reduced, **col_to_row_unreduced} + row_to_col = {i: j for j, i in col_to_row.items()} + A_rref_rows_col = [(row_to_col[i], Ai) for i, Ai in enumerate(A_rref_rows)] + pivots, A_rref = zip(*sorted(A_rref_rows_col)) + pivots = list(pivots) + + # Insert the pivot values + for i, Ai in enumerate(A_rref): + Ai[pivots[i]] = denom + + A_rref_sdm = dict(enumerate(A_rref)) + + return A_rref_sdm, denom, pivots + + +def sdm_nullspace_from_rref(A, one, ncols, pivots, nonzero_cols): + """Get nullspace from A which is in RREF""" + nonpivots = sorted(set(range(ncols)) - set(pivots)) + + K = [] + for j in nonpivots: + Kj = {j:one} + for i in nonzero_cols.get(j, ()): + Kj[pivots[i]] = -A[i][j] + K.append(Kj) + + return K, nonpivots + + +def sdm_particular_from_rref(A, ncols, pivots): + """Get a particular solution from A which is in RREF""" + P = {} + for i, j in enumerate(pivots): + Ain = A[i].get(ncols-1, None) + if Ain is not None: + P[j] = Ain / A[i][j] + return P + + +def sdm_berk(M, n, K): + """ + Berkowitz algorithm for computing the characteristic polynomial. + + Explanation + =========== + + The Berkowitz algorithm is a division-free algorithm for computing the + characteristic polynomial of a matrix over any commutative ring using only + arithmetic in the coefficient ring. This implementation is for sparse + matrices represented in a dict-of-dicts format (like :class:`SDM`). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.polys.matrices.sdm import sdm_berk + >>> from sympy.polys.domains import ZZ + >>> M = {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + >>> sdm_berk(M, 2, ZZ) + {0: 1, 1: -5, 2: -2} + >>> Matrix([[1, 2], [3, 4]]).charpoly() + PurePoly(lambda**2 - 5*lambda - 2, lambda, domain='ZZ') + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.charpoly + The high-level interface to this function. + sympy.polys.matrices.dense.ddm_berk + The dense version of this function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Samuelson%E2%80%93Berkowitz_algorithm + """ + zero = K.zero + one = K.one + + if n == 0: + return {0: one} + elif n == 1: + pdict = {0: one} + if M00 := M.get(0, {}).get(0, zero): + pdict[1] = -M00 + + # M = [[a, R], + # [C, A]] + a, R, C, A = K.zero, {}, {}, defaultdict(dict) + for i, Mi in M.items(): + for j, Mij in Mi.items(): + if i and j: + A[i-1][j-1] = Mij + elif i: + C[i-1] = Mij + elif j: + R[j-1] = Mij + else: + a = Mij + + # T = [ 1, 0, 0, 0, 0, ... ] + # [ -a, 1, 0, 0, 0, ... ] + # [ -R*C, -a, 1, 0, 0, ... ] + # [ -R*A*C, -R*C, -a, 1, 0, ... ] + # [-R*A^2*C, -R*A*C, -R*C, -a, 1, ... ] + # [ ... ] + # T is (n+1) x n + # + # In the sparse case we might have A^m*C = 0 for some m making T banded + # rather than triangular so we just compute the nonzero entries of the + # first column rather than constructing the matrix explicitly. + + AnC = C + RC = sdm_dotvec(R, C, K) + + Tvals = [one, -a, -RC] + for i in range(3, n+1): + AnC = sdm_matvecmul(A, AnC, K) + if not AnC: + break + RAnC = sdm_dotvec(R, AnC, K) + Tvals.append(-RAnC) + + # Strip trailing zeros + while Tvals and not Tvals[-1]: + Tvals.pop() + + q = sdm_berk(A, n-1, K) + + # This would be the explicit multiplication T*q but we can do better: + # + # T = {} + # for i in range(n+1): + # Ti = {} + # for j in range(max(0, i-len(Tvals)+1), min(i+1, n)): + # Ti[j] = Tvals[i-j] + # T[i] = Ti + # Tq = sdm_matvecmul(T, q, K) + # + # In the sparse case q might be mostly zero. We know that T[i,j] is nonzero + # for i <= j < i + len(Tvals) so if q does not have a nonzero entry in that + # range then Tq[j] must be zero. We exploit this potential banded + # structure and the potential sparsity of q to compute Tq more efficiently. + + Tvals = Tvals[::-1] + + Tq = {} + + for i in range(min(q), min(max(q)+len(Tvals), n+1)): + Ti = dict(enumerate(Tvals, i-len(Tvals)+1)) + if Tqi := sdm_dotvec(Ti, q, K): + Tq[i] = Tqi + + return Tq diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__init__.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b603167f21529bb8ea5a9112747780815987c7ad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da92d226f10e57ca71675ba6f77e9bfc3b3e3c06 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_dense.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_dense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5313ec04f53a4e3115bcde83eaff6c27e9208cf2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_dense.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a635aaccb9d0e2f6e22f19893e1b84b146f5123 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db795ce690cfd09775439fa1c6356c26c2d0ca55 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5aaa7ea583da05280ac2623c6e57606f35c6a7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_fflu.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_fflu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d0fc2e9141a074e14bf7632fbba44cf5f02bf9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_fflu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eaf6fec8abef0a75848865e9e30887f8de38cc7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6ddbbbd228f1e1afa10e444d070da9906c6e170 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_lll.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_lll.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a2593a5a179412338c4044e7b820b6c490edd2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_lll.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ffa56c619a3325bc1734340eaad191abea5f547 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa017356a3393cd6a62ced6972f1f0fe09b6fb06 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_rref.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_rref.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45023e0a3b355c3a1b45db55799ce4770804b044 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_rref.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705b09bb8045378e6949ddb4f0ad6ea497cbfca6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61d11fad31bc89683563e8a1acfb5198bbceb1c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_ddm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_ddm.py new file mode 100644 index 0000000000000000000000000000000000000000..44c862461e85d503696e621874c10d67d8ee1f1d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_ddm.py @@ -0,0 +1,558 @@ +from sympy.testing.pytest import raises +from sympy.external.gmpy import GROUND_TYPES + +from sympy.polys import ZZ, QQ + +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.exceptions import ( + DMShapeError, DMNonInvertibleMatrixError, DMDomainError, + DMBadInputError) + + +def test_DDM_init(): + items = [[ZZ(0), ZZ(1), ZZ(2)], [ZZ(3), ZZ(4), ZZ(5)]] + shape = (2, 3) + ddm = DDM(items, shape, ZZ) + assert ddm.shape == shape + assert ddm.rows == 2 + assert ddm.cols == 3 + assert ddm.domain == ZZ + + raises(DMBadInputError, lambda: DDM([[ZZ(2), ZZ(3)]], (2, 2), ZZ)) + raises(DMBadInputError, lambda: DDM([[ZZ(1)], [ZZ(2), ZZ(3)]], (2, 2), ZZ)) + + +def test_DDM_getsetitem(): + ddm = DDM([[ZZ(2), ZZ(3)], [ZZ(4), ZZ(5)]], (2, 2), ZZ) + + assert ddm[0][0] == ZZ(2) + assert ddm[0][1] == ZZ(3) + assert ddm[1][0] == ZZ(4) + assert ddm[1][1] == ZZ(5) + + raises(IndexError, lambda: ddm[2][0]) + raises(IndexError, lambda: ddm[0][2]) + + ddm[0][0] = ZZ(-1) + assert ddm[0][0] == ZZ(-1) + + +def test_DDM_str(): + ddm = DDM([[ZZ(0), ZZ(1)], [ZZ(2), ZZ(3)]], (2, 2), ZZ) + if GROUND_TYPES == 'gmpy': # pragma: no cover + assert str(ddm) == '[[0, 1], [2, 3]]' + assert repr(ddm) == 'DDM([[mpz(0), mpz(1)], [mpz(2), mpz(3)]], (2, 2), ZZ)' + else: # pragma: no cover + assert repr(ddm) == 'DDM([[0, 1], [2, 3]], (2, 2), ZZ)' + assert str(ddm) == '[[0, 1], [2, 3]]' + + +def test_DDM_eq(): + items = [[ZZ(0), ZZ(1)], [ZZ(2), ZZ(3)]] + ddm1 = DDM(items, (2, 2), ZZ) + ddm2 = DDM(items, (2, 2), ZZ) + + assert (ddm1 == ddm1) is True + assert (ddm1 == items) is False + assert (items == ddm1) is False + assert (ddm1 == ddm2) is True + assert (ddm2 == ddm1) is True + + assert (ddm1 != ddm1) is False + assert (ddm1 != items) is True + assert (items != ddm1) is True + assert (ddm1 != ddm2) is False + assert (ddm2 != ddm1) is False + + ddm3 = DDM([[ZZ(0), ZZ(1)], [ZZ(3), ZZ(3)]], (2, 2), ZZ) + ddm3 = DDM(items, (2, 2), QQ) + + assert (ddm1 == ddm3) is False + assert (ddm3 == ddm1) is False + assert (ddm1 != ddm3) is True + assert (ddm3 != ddm1) is True + + +def test_DDM_convert_to(): + ddm = DDM([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + assert ddm.convert_to(ZZ) == ddm + ddmq = ddm.convert_to(QQ) + assert ddmq.domain == QQ + + +def test_DDM_zeros(): + ddmz = DDM.zeros((3, 4), QQ) + assert list(ddmz) == [[QQ(0)] * 4] * 3 + assert ddmz.shape == (3, 4) + assert ddmz.domain == QQ + +def test_DDM_ones(): + ddmone = DDM.ones((2, 3), QQ) + assert list(ddmone) == [[QQ(1)] * 3] * 2 + assert ddmone.shape == (2, 3) + assert ddmone.domain == QQ + +def test_DDM_eye(): + ddmz = DDM.eye(3, QQ) + f = lambda i, j: QQ(1) if i == j else QQ(0) + assert list(ddmz) == [[f(i, j) for i in range(3)] for j in range(3)] + assert ddmz.shape == (3, 3) + assert ddmz.domain == QQ + + +def test_DDM_copy(): + ddm1 = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + ddm2 = ddm1.copy() + assert (ddm1 == ddm2) is True + ddm1[0][0] = QQ(-1) + assert (ddm1 == ddm2) is False + ddm2[0][0] = QQ(-1) + assert (ddm1 == ddm2) is True + + +def test_DDM_transpose(): + ddm = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + ddmT = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + assert ddm.transpose() == ddmT + ddm02 = DDM([], (0, 2), QQ) + ddm02T = DDM([[], []], (2, 0), QQ) + assert ddm02.transpose() == ddm02T + assert ddm02T.transpose() == ddm02 + ddm0 = DDM([], (0, 0), QQ) + assert ddm0.transpose() == ddm0 + + +def test_DDM_add(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + B = DDM([[ZZ(3)], [ZZ(4)]], (2, 1), ZZ) + C = DDM([[ZZ(4)], [ZZ(6)]], (2, 1), ZZ) + AQ = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + assert A + B == A.add(B) == C + + raises(DMShapeError, lambda: A + DDM([[ZZ(5)]], (1, 1), ZZ)) + raises(TypeError, lambda: A + ZZ(1)) + raises(TypeError, lambda: ZZ(1) + A) + raises(DMDomainError, lambda: A + AQ) + raises(DMDomainError, lambda: AQ + A) + + +def test_DDM_sub(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + B = DDM([[ZZ(3)], [ZZ(4)]], (2, 1), ZZ) + C = DDM([[ZZ(-2)], [ZZ(-2)]], (2, 1), ZZ) + AQ = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + D = DDM([[ZZ(5)]], (1, 1), ZZ) + assert A - B == A.sub(B) == C + + raises(TypeError, lambda: A - ZZ(1)) + raises(TypeError, lambda: ZZ(1) - A) + raises(DMShapeError, lambda: A - D) + raises(DMShapeError, lambda: D - A) + raises(DMShapeError, lambda: A.sub(D)) + raises(DMShapeError, lambda: D.sub(A)) + raises(DMDomainError, lambda: A - AQ) + raises(DMDomainError, lambda: AQ - A) + raises(DMDomainError, lambda: A.sub(AQ)) + raises(DMDomainError, lambda: AQ.sub(A)) + + +def test_DDM_neg(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + An = DDM([[ZZ(-1)], [ZZ(-2)]], (2, 1), ZZ) + assert -A == A.neg() == An + assert -An == An.neg() == A + + +def test_DDM_mul(): + A = DDM([[ZZ(1)]], (1, 1), ZZ) + A2 = DDM([[ZZ(2)]], (1, 1), ZZ) + assert A * ZZ(2) == A2 + assert ZZ(2) * A == A2 + raises(TypeError, lambda: [[1]] * A) + raises(TypeError, lambda: A * [[1]]) + + +def test_DDM_matmul(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + B = DDM([[ZZ(3), ZZ(4)]], (1, 2), ZZ) + AB = DDM([[ZZ(3), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + BA = DDM([[ZZ(11)]], (1, 1), ZZ) + + assert A @ B == A.matmul(B) == AB + assert B @ A == B.matmul(A) == BA + + raises(TypeError, lambda: A @ 1) + raises(TypeError, lambda: A @ [[3, 4]]) + + Bq = DDM([[QQ(3), QQ(4)]], (1, 2), QQ) + + raises(DMDomainError, lambda: A @ Bq) + raises(DMDomainError, lambda: Bq @ A) + + C = DDM([[ZZ(1)]], (1, 1), ZZ) + + assert A @ C == A.matmul(C) == A + + raises(DMShapeError, lambda: C @ A) + raises(DMShapeError, lambda: C.matmul(A)) + + Z04 = DDM([], (0, 4), ZZ) + Z40 = DDM([[]]*4, (4, 0), ZZ) + Z50 = DDM([[]]*5, (5, 0), ZZ) + Z05 = DDM([], (0, 5), ZZ) + Z45 = DDM([[0] * 5] * 4, (4, 5), ZZ) + Z54 = DDM([[0] * 4] * 5, (5, 4), ZZ) + Z00 = DDM([], (0, 0), ZZ) + + assert Z04 @ Z45 == Z04.matmul(Z45) == Z05 + assert Z45 @ Z50 == Z45.matmul(Z50) == Z40 + assert Z00 @ Z04 == Z00.matmul(Z04) == Z04 + assert Z50 @ Z00 == Z50.matmul(Z00) == Z50 + assert Z00 @ Z00 == Z00.matmul(Z00) == Z00 + assert Z50 @ Z04 == Z50.matmul(Z04) == Z54 + + raises(DMShapeError, lambda: Z05 @ Z40) + raises(DMShapeError, lambda: Z05.matmul(Z40)) + + +def test_DDM_hstack(): + A = DDM([[ZZ(1), ZZ(2), ZZ(3)]], (1, 3), ZZ) + B = DDM([[ZZ(4), ZZ(5)]], (1, 2), ZZ) + C = DDM([[ZZ(6)]], (1, 1), ZZ) + + Ah = A.hstack(B) + assert Ah.shape == (1, 5) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1), ZZ(2), ZZ(3), ZZ(4), ZZ(5)]], (1, 5), ZZ) + + Ah = A.hstack(B, C) + assert Ah.shape == (1, 6) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1), ZZ(2), ZZ(3), ZZ(4), ZZ(5), ZZ(6)]], (1, 6), ZZ) + + +def test_DDM_vstack(): + A = DDM([[ZZ(1)], [ZZ(2)], [ZZ(3)]], (3, 1), ZZ) + B = DDM([[ZZ(4)], [ZZ(5)]], (2, 1), ZZ) + C = DDM([[ZZ(6)]], (1, 1), ZZ) + + Ah = A.vstack(B) + assert Ah.shape == (5, 1) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1)], [ZZ(2)], [ZZ(3)], [ZZ(4)], [ZZ(5)]], (5, 1), ZZ) + + Ah = A.vstack(B, C) + assert Ah.shape == (6, 1) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1)], [ZZ(2)], [ZZ(3)], [ZZ(4)], [ZZ(5)], [ZZ(6)]], (6, 1), ZZ) + + +def test_DDM_applyfunc(): + A = DDM([[ZZ(1), ZZ(2), ZZ(3)]], (1, 3), ZZ) + B = DDM([[ZZ(2), ZZ(4), ZZ(6)]], (1, 3), ZZ) + assert A.applyfunc(lambda x: 2*x, ZZ) == B + +def test_DDM_rref(): + + A = DDM([], (0, 4), QQ) + assert A.rref() == (A, []) + + A = DDM([[QQ(0), QQ(1)], [QQ(1), QQ(1)]], (2, 2), QQ) + Ar = DDM([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(1), QQ(2), QQ(1)], [QQ(3), QQ(4), QQ(1)]], (2, 3), QQ) + Ar = DDM([[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]], (2, 3), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(3), QQ(4), QQ(1)], [QQ(1), QQ(2), QQ(1)]], (2, 3), QQ) + Ar = DDM([[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]], (2, 3), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(1), QQ(0)], [QQ(1), QQ(3)], [QQ(0), QQ(1)]], (3, 2), QQ) + Ar = DDM([[QQ(1), QQ(0)], [QQ(0), QQ(1)], [QQ(0), QQ(0)]], (3, 2), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(1), QQ(0), QQ(1)], [QQ(3), QQ(0), QQ(1)]], (2, 3), QQ) + Ar = DDM([[QQ(1), QQ(0), QQ(0)], [QQ(0), QQ(0), QQ(1)]], (2, 3), QQ) + pivots = [0, 2] + assert A.rref() == (Ar, pivots) + + +def test_DDM_nullspace(): + # more tests are in test_nullspace.py + A = DDM([[QQ(1), QQ(1)], [QQ(1), QQ(1)]], (2, 2), QQ) + Anull = DDM([[QQ(-1), QQ(1)]], (1, 2), QQ) + nonpivots = [1] + assert A.nullspace() == (Anull, nonpivots) + + +def test_DDM_particular(): + A = DDM([[QQ(1), QQ(0)]], (1, 2), QQ) + assert A.particular() == DDM.zeros((1, 1), QQ) + + +def test_DDM_det(): + # 0x0 case + A = DDM([], (0, 0), ZZ) + assert A.det() == ZZ(1) + + # 1x1 case + A = DDM([[ZZ(2)]], (1, 1), ZZ) + assert A.det() == ZZ(2) + + # 2x2 case + A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.det() == ZZ(-2) + + # 3x3 with swap + A = DDM([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(2), ZZ(5)]], (3, 3), ZZ) + assert A.det() == ZZ(0) + + # 2x2 QQ case + A = DDM([[QQ(1, 2), QQ(1, 2)], [QQ(1, 3), QQ(1, 4)]], (2, 2), QQ) + assert A.det() == QQ(-1, 24) + + # Nonsquare error + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMShapeError, lambda: A.det()) + + # Nonsquare error with empty matrix + A = DDM([], (0, 1), ZZ) + raises(DMShapeError, lambda: A.det()) + + +def test_DDM_inv(): + A = DDM([[QQ(1, 1), QQ(2, 1)], [QQ(3, 1), QQ(4, 1)]], (2, 2), QQ) + Ainv = DDM([[QQ(-2, 1), QQ(1, 1)], [QQ(3, 2), QQ(-1, 2)]], (2, 2), QQ) + assert A.inv() == Ainv + + A = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMShapeError, lambda: A.inv()) + + A = DDM([[ZZ(2)]], (1, 1), ZZ) + raises(DMDomainError, lambda: A.inv()) + + A = DDM([], (0, 0), QQ) + assert A.inv() == A + + A = DDM([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +def test_DDM_lu(): + A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + L, U, swaps = A.lu() + assert L == DDM([[QQ(1), QQ(0)], [QQ(3), QQ(1)]], (2, 2), QQ) + assert U == DDM([[QQ(1), QQ(2)], [QQ(0), QQ(-2)]], (2, 2), QQ) + assert swaps == [] + + A = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]] + Lexp = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]] + Uexp = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]] + to_dom = lambda rows, dom: [[dom(e) for e in row] for row in rows] + A = DDM(to_dom(A, QQ), (4, 4), QQ) + Lexp = DDM(to_dom(Lexp, QQ), (4, 4), QQ) + Uexp = DDM(to_dom(Uexp, QQ), (4, 4), QQ) + L, U, swaps = A.lu() + assert L == Lexp + assert U == Uexp + assert swaps == [] + + +def test_DDM_lu_solve(): + # Basic example + A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Example with swaps + A = DDM([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + assert A.lu_solve(b) == x + + # Overdetermined, consistent + A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + b = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + assert A.lu_solve(b) == x + + # Overdetermined, inconsistent + b = DDM([[QQ(1)], [QQ(2)], [QQ(4)]], (3, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Square, noninvertible + A = DDM([[QQ(1), QQ(2)], [QQ(1), QQ(2)]], (2, 2), QQ) + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Underdetermined + A = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + b = DDM([[QQ(3)]], (1, 1), QQ) + raises(NotImplementedError, lambda: A.lu_solve(b)) + + # Domain mismatch + bz = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMDomainError, lambda: A.lu_solve(bz)) + + # Shape mismatch + b3 = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + raises(DMShapeError, lambda: A.lu_solve(b3)) + + +def test_DDM_charpoly(): + A = DDM([], (0, 0), ZZ) + assert A.charpoly() == [ZZ(1)] + + A = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + Avec = [ZZ(1), ZZ(-15), ZZ(-18), ZZ(0)] + assert A.charpoly() == Avec + + A = DDM([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.charpoly()) + + +def test_DDM_getitem(): + dm = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + + assert dm.getitem(1, 1) == ZZ(5) + assert dm.getitem(1, -2) == ZZ(5) + assert dm.getitem(-1, -3) == ZZ(7) + + raises(IndexError, lambda: dm.getitem(3, 3)) + + +def test_DDM_setitem(): + dm = DDM.zeros((3, 3), ZZ) + dm.setitem(0, 0, 1) + dm.setitem(1, -2, 1) + dm.setitem(-1, -1, 1) + assert dm == DDM.eye(3, ZZ) + + raises(IndexError, lambda: dm.setitem(3, 3, 0)) + + +def test_DDM_extract_slice(): + dm = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + + assert dm.extract_slice(slice(0, 3), slice(0, 3)) == dm + assert dm.extract_slice(slice(1, 3), slice(-2)) == DDM([[4], [7]], (2, 1), ZZ) + assert dm.extract_slice(slice(1, 3), slice(-2)) == DDM([[4], [7]], (2, 1), ZZ) + assert dm.extract_slice(slice(2, 3), slice(-2)) == DDM([[ZZ(7)]], (1, 1), ZZ) + assert dm.extract_slice(slice(0, 2), slice(-2)) == DDM([[1], [4]], (2, 1), ZZ) + assert dm.extract_slice(slice(-1), slice(-1)) == DDM([[1, 2], [4, 5]], (2, 2), ZZ) + + assert dm.extract_slice(slice(2), slice(3, 4)) == DDM([[], []], (2, 0), ZZ) + assert dm.extract_slice(slice(3, 4), slice(2)) == DDM([], (0, 2), ZZ) + assert dm.extract_slice(slice(3, 4), slice(3, 4)) == DDM([], (0, 0), ZZ) + + +def test_DDM_extract(): + dm1 = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + dm2 = DDM([ + [ZZ(6), ZZ(4)], + [ZZ(3), ZZ(1)]], (2, 2), ZZ) + assert dm1.extract([1, 0], [2, 0]) == dm2 + assert dm1.extract([-2, 0], [-1, 0]) == dm2 + + assert dm1.extract([], []) == DDM.zeros((0, 0), ZZ) + assert dm1.extract([1], []) == DDM.zeros((1, 0), ZZ) + assert dm1.extract([], [1]) == DDM.zeros((0, 1), ZZ) + + raises(IndexError, lambda: dm2.extract([2], [0])) + raises(IndexError, lambda: dm2.extract([0], [2])) + raises(IndexError, lambda: dm2.extract([-3], [0])) + raises(IndexError, lambda: dm2.extract([0], [-3])) + + +def test_DDM_flat(): + dm = DDM([ + [ZZ(6), ZZ(4)], + [ZZ(3), ZZ(1)]], (2, 2), ZZ) + assert dm.flat() == [ZZ(6), ZZ(4), ZZ(3), ZZ(1)] + + +def test_DDM_is_zero_matrix(): + A = DDM([[QQ(1), QQ(0)], [QQ(0), QQ(0)]], (2, 2), QQ) + Azero = DDM.zeros((1, 2), QQ) + assert A.is_zero_matrix() is False + assert Azero.is_zero_matrix() is True + + +def test_DDM_is_upper(): + # Wide matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(0), QQ(8), QQ(9)] + ], (3, 4), QQ) + B = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(7), QQ(8), QQ(9)] + ], (3, 4), QQ) + assert A.is_upper() is True + assert B.is_upper() is False + + # Tall matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(0)] + ], (4, 3), QQ) + B = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(10)] + ], (4, 3), QQ) + assert A.is_upper() is True + assert B.is_upper() is False + + +def test_DDM_is_lower(): + # Tall matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(0), QQ(8), QQ(9)] + ], (3, 4), QQ).transpose() + B = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(7), QQ(8), QQ(9)] + ], (3, 4), QQ).transpose() + assert A.is_lower() is True + assert B.is_lower() is False + + # Wide matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(0)] + ], (4, 3), QQ).transpose() + B = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(10)] + ], (4, 3), QQ).transpose() + assert A.is_lower() is True + assert B.is_lower() is False diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_dense.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..75315ebf6b2ae7d53b4a5737578d3ac5ed4ea36a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_dense.py @@ -0,0 +1,350 @@ +from sympy.testing.pytest import raises + +from sympy.polys import ZZ, QQ + +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.dense import ( + ddm_transpose, + ddm_iadd, ddm_isub, ddm_ineg, ddm_imatmul, ddm_imul, ddm_irref, + ddm_idet, ddm_iinv, ddm_ilu, ddm_ilu_split, ddm_ilu_solve, ddm_berk) + +from sympy.polys.matrices.exceptions import ( + DMDomainError, + DMNonInvertibleMatrixError, + DMNonSquareMatrixError, + DMShapeError, +) + + +def test_ddm_transpose(): + a = [[1, 2], [3, 4]] + assert ddm_transpose(a) == [[1, 3], [2, 4]] + + +def test_ddm_iadd(): + a = [[1, 2], [3, 4]] + b = [[5, 6], [7, 8]] + ddm_iadd(a, b) + assert a == [[6, 8], [10, 12]] + + +def test_ddm_isub(): + a = [[1, 2], [3, 4]] + b = [[5, 6], [7, 8]] + ddm_isub(a, b) + assert a == [[-4, -4], [-4, -4]] + + +def test_ddm_ineg(): + a = [[1, 2], [3, 4]] + ddm_ineg(a) + assert a == [[-1, -2], [-3, -4]] + + +def test_ddm_matmul(): + a = [[1, 2], [3, 4]] + ddm_imul(a, 2) + assert a == [[2, 4], [6, 8]] + + a = [[1, 2], [3, 4]] + ddm_imul(a, 0) + assert a == [[0, 0], [0, 0]] + + +def test_ddm_imatmul(): + a = [[1, 2, 3], [4, 5, 6]] + b = [[1, 2], [3, 4], [5, 6]] + + c1 = [[0, 0], [0, 0]] + ddm_imatmul(c1, a, b) + assert c1 == [[22, 28], [49, 64]] + + c2 = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ddm_imatmul(c2, b, a) + assert c2 == [[9, 12, 15], [19, 26, 33], [29, 40, 51]] + + b3 = [[1], [2], [3]] + c3 = [[0], [0]] + ddm_imatmul(c3, a, b3) + assert c3 == [[14], [32]] + + +def test_ddm_irref(): + # Empty matrix + A = [] + Ar = [] + pivots = [] + assert ddm_irref(A) == pivots + assert A == Ar + + # Standard square case + A = [[QQ(0), QQ(1)], [QQ(1), QQ(1)]] + Ar = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # m < n case + A = [[QQ(1), QQ(2), QQ(1)], [QQ(3), QQ(4), QQ(1)]] + Ar = [[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # same m < n but reversed + A = [[QQ(3), QQ(4), QQ(1)], [QQ(1), QQ(2), QQ(1)]] + Ar = [[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # m > n case + A = [[QQ(1), QQ(0)], [QQ(1), QQ(3)], [QQ(0), QQ(1)]] + Ar = [[QQ(1), QQ(0)], [QQ(0), QQ(1)], [QQ(0), QQ(0)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # Example with missing pivot + A = [[QQ(1), QQ(0), QQ(1)], [QQ(3), QQ(0), QQ(1)]] + Ar = [[QQ(1), QQ(0), QQ(0)], [QQ(0), QQ(0), QQ(1)]] + pivots = [0, 2] + assert ddm_irref(A) == pivots + assert A == Ar + + # Example with missing pivot and no replacement + A = [[QQ(0), QQ(1)], [QQ(0), QQ(2)], [QQ(1), QQ(0)]] + Ar = [[QQ(1), QQ(0)], [QQ(0), QQ(1)], [QQ(0), QQ(0)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + +def test_ddm_idet(): + A = [] + assert ddm_idet(A, ZZ) == ZZ(1) + + A = [[ZZ(2)]] + assert ddm_idet(A, ZZ) == ZZ(2) + + A = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + assert ddm_idet(A, ZZ) == ZZ(-2) + + A = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(3), ZZ(5)]] + assert ddm_idet(A, ZZ) == ZZ(-1) + + A = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(2), ZZ(5)]] + assert ddm_idet(A, ZZ) == ZZ(0) + + A = [[QQ(1, 2), QQ(1, 2)], [QQ(1, 3), QQ(1, 4)]] + assert ddm_idet(A, QQ) == QQ(-1, 24) + + +def test_ddm_inv(): + A = [] + Ainv = [] + ddm_iinv(Ainv, A, QQ) + assert Ainv == A + + A = [] + Ainv = [] + raises(DMDomainError, lambda: ddm_iinv(Ainv, A, ZZ)) + + A = [[QQ(1), QQ(2)]] + Ainv = [[QQ(0), QQ(0)]] + raises(DMNonSquareMatrixError, lambda: ddm_iinv(Ainv, A, QQ)) + + A = [[QQ(1, 1), QQ(2, 1)], [QQ(3, 1), QQ(4, 1)]] + Ainv = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + Ainv_expected = [[QQ(-2, 1), QQ(1, 1)], [QQ(3, 2), QQ(-1, 2)]] + ddm_iinv(Ainv, A, QQ) + assert Ainv == Ainv_expected + + A = [[QQ(1, 1), QQ(2, 1)], [QQ(2, 1), QQ(4, 1)]] + Ainv = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + raises(DMNonInvertibleMatrixError, lambda: ddm_iinv(Ainv, A, QQ)) + + +def test_ddm_ilu(): + A = [] + Alu = [] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[]] + Alu = [[]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + Alu = [[QQ(1), QQ(2)], [QQ(3), QQ(-2)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(0), QQ(2)], [QQ(3), QQ(4)]] + Alu = [[QQ(3), QQ(4)], [QQ(0), QQ(2)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [(0, 1)] + + A = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)], [QQ(7), QQ(8), QQ(9)]] + Alu = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(-3), QQ(-6)], [QQ(7), QQ(2), QQ(0)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(0), QQ(1), QQ(2)], [QQ(0), QQ(1), QQ(3)], [QQ(1), QQ(1), QQ(2)]] + Alu = [[QQ(1), QQ(1), QQ(2)], [QQ(0), QQ(1), QQ(3)], [QQ(0), QQ(1), QQ(-1)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [(0, 2)] + + A = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + Alu = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(-3), QQ(-6)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]] + Alu = [[QQ(1), QQ(2)], [QQ(3), QQ(-2)], [QQ(5), QQ(2)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + +def test_ddm_ilu_split(): + U = [] + L = [] + Uexp = [] + Lexp = [] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[]] + L = [[QQ(1)]] + Uexp = [[]] + Lexp = [[QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + L = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + Uexp = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)]] + Lexp = [[QQ(1), QQ(0)], [QQ(3), QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + L = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + Uexp = [[QQ(1), QQ(2), QQ(3)], [QQ(0), QQ(-3), QQ(-6)]] + Lexp = [[QQ(1), QQ(0)], [QQ(4), QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]] + L = [[QQ(1), QQ(0), QQ(0)], [QQ(0), QQ(1), QQ(0)], [QQ(0), QQ(0), QQ(1)]] + Uexp = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)], [QQ(0), QQ(0)]] + Lexp = [[QQ(1), QQ(0), QQ(0)], [QQ(3), QQ(1), QQ(0)], [QQ(5), QQ(2), QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + +def test_ddm_ilu_solve(): + # Basic example + # A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + U = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)]] + L = [[QQ(1), QQ(0)], [QQ(3), QQ(1)]] + swaps = [] + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DDM([[QQ(0)], [QQ(0)]], (2, 1), QQ) + xexp = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + ddm_ilu_solve(x, L, U, swaps, b) + assert x == xexp + + # Example with swaps + # A = [[QQ(0), QQ(2)], [QQ(3), QQ(4)]] + U = [[QQ(3), QQ(4)], [QQ(0), QQ(2)]] + L = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + swaps = [(0, 1)] + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DDM([[QQ(0)], [QQ(0)]], (2, 1), QQ) + xexp = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + ddm_ilu_solve(x, L, U, swaps, b) + assert x == xexp + + # Overdetermined, consistent + # A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + U = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)], [QQ(0), QQ(0)]] + L = [[QQ(1), QQ(0), QQ(0)], [QQ(3), QQ(1), QQ(0)], [QQ(5), QQ(2), QQ(1)]] + swaps = [] + b = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + x = DDM([[QQ(0)], [QQ(0)]], (2, 1), QQ) + xexp = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + ddm_ilu_solve(x, L, U, swaps, b) + assert x == xexp + + # Overdetermined, inconsistent + b = DDM([[QQ(1)], [QQ(2)], [QQ(4)]], (3, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Square, noninvertible + # A = DDM([[QQ(1), QQ(2)], [QQ(1), QQ(2)]], (2, 2), QQ) + U = [[QQ(1), QQ(2)], [QQ(0), QQ(0)]] + L = [[QQ(1), QQ(0)], [QQ(1), QQ(1)]] + swaps = [] + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Underdetermined + # A = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + U = [[QQ(1), QQ(2)]] + L = [[QQ(1)]] + swaps = [] + b = DDM([[QQ(3)]], (1, 1), QQ) + raises(NotImplementedError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Shape mismatch + b3 = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + raises(DMShapeError, lambda: ddm_ilu_solve(x, L, U, swaps, b3)) + + # Empty shape mismatch + U = [[QQ(1)]] + L = [[QQ(1)]] + swaps = [] + x = [[QQ(1)]] + b = [] + raises(DMShapeError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Empty system + U = [] + L = [] + swaps = [] + b = [] + x = [] + ddm_ilu_solve(x, L, U, swaps, b) + assert x == [] + + +def test_ddm_charpoly(): + A = [] + assert ddm_berk(A, ZZ) == [[ZZ(1)]] + + A = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)], [ZZ(7), ZZ(8), ZZ(9)]] + Avec = [[ZZ(1)], [ZZ(-15)], [ZZ(-18)], [ZZ(0)]] + assert ddm_berk(A, ZZ) == Avec + + A = DDM([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: ddm_berk(A, ZZ)) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_domainmatrix.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_domainmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..2f45029fb080ca91e98ea04aa4717fa675492052 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_domainmatrix.py @@ -0,0 +1,1383 @@ +from sympy.external.gmpy import GROUND_TYPES + +from sympy import Integer, Rational, S, sqrt, Matrix, symbols +from sympy import FF, ZZ, QQ, QQ_I, EXRAW + +from sympy.polys.matrices.domainmatrix import DomainMatrix, DomainScalar, DM +from sympy.polys.matrices.exceptions import ( + DMBadInputError, DMDomainError, DMShapeError, DMFormatError, DMNotAField, + DMNonSquareMatrixError, DMNonInvertibleMatrixError, +) +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.sdm import SDM + +from sympy.testing.pytest import raises + + +def test_DM(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DM([[1, 2], [3, 4]], ZZ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + +def test_DomainMatrix_init(): + lol = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + dod = {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + ddm = DDM(lol, (2, 2), ZZ) + sdm = SDM(dod, (2, 2), ZZ) + + A = DomainMatrix(lol, (2, 2), ZZ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + A = DomainMatrix(dod, (2, 2), ZZ) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == ZZ + + raises(TypeError, lambda: DomainMatrix(ddm, (2, 2), ZZ)) + raises(TypeError, lambda: DomainMatrix(sdm, (2, 2), ZZ)) + raises(TypeError, lambda: DomainMatrix(Matrix([[1]]), (1, 1), ZZ)) + + for fmt, rep in [('sparse', sdm), ('dense', ddm)]: + if fmt == 'dense' and GROUND_TYPES == 'flint': + rep = rep.to_dfm() + A = DomainMatrix(lol, (2, 2), ZZ, fmt=fmt) + assert A.rep == rep + A = DomainMatrix(dod, (2, 2), ZZ, fmt=fmt) + assert A.rep == rep + + raises(ValueError, lambda: DomainMatrix(lol, (2, 2), ZZ, fmt='invalid')) + + raises(DMBadInputError, lambda: DomainMatrix([[ZZ(1), ZZ(2)]], (2, 2), ZZ)) + + # uses copy + was = [i.copy() for i in lol] + A[0,0] = ZZ(42) + assert was == lol + + +def test_DomainMatrix_from_rep(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DomainMatrix.from_rep(ddm) + # XXX: Should from_rep convert to DFM? + assert A.rep == ddm + assert A.shape == (2, 2) + assert A.domain == ZZ + + sdm = SDM({0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + A = DomainMatrix.from_rep(sdm) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == ZZ + + A = DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + raises(TypeError, lambda: DomainMatrix.from_rep(A)) + + +def test_DomainMatrix_from_list(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DomainMatrix.from_list([[1, 2], [3, 4]], ZZ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + dom = FF(7) + ddm = DDM([[dom(1), dom(2)], [dom(3), dom(4)]], (2, 2), dom) + A = DomainMatrix.from_list([[1, 2], [3, 4]], dom) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == dom + + dom = FF(2**127-1) + ddm = DDM([[dom(1), dom(2)], [dom(3), dom(4)]], (2, 2), dom) + A = DomainMatrix.from_list([[1, 2], [3, 4]], dom) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == dom + + ddm = DDM([[QQ(1, 2), QQ(3, 1)], [QQ(1, 4), QQ(5, 1)]], (2, 2), QQ) + A = DomainMatrix.from_list([[(1, 2), (3, 1)], [(1, 4), (5, 1)]], QQ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == QQ + + +def test_DomainMatrix_from_list_sympy(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DomainMatrix.from_list_sympy(2, 2, [[1, 2], [3, 4]]) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + K = QQ.algebraic_field(sqrt(2)) + ddm = DDM( + [[K.convert(1 + sqrt(2)), K.convert(2 + sqrt(2))], + [K.convert(3 + sqrt(2)), K.convert(4 + sqrt(2))]], + (2, 2), + K + ) + A = DomainMatrix.from_list_sympy( + 2, 2, [[1 + sqrt(2), 2 + sqrt(2)], [3 + sqrt(2), 4 + sqrt(2)]], + extension=True) + assert A.rep == ddm + assert A.shape == (2, 2) + assert A.domain == K + + +def test_DomainMatrix_from_dict_sympy(): + sdm = SDM({0: {0: QQ(1, 2)}, 1: {1: QQ(2, 3)}}, (2, 2), QQ) + sympy_dict = {0: {0: Rational(1, 2)}, 1: {1: Rational(2, 3)}} + A = DomainMatrix.from_dict_sympy(2, 2, sympy_dict) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == QQ + + fds = DomainMatrix.from_dict_sympy + raises(DMBadInputError, lambda: fds(2, 2, {3: {0: Rational(1, 2)}})) + raises(DMBadInputError, lambda: fds(2, 2, {0: {3: Rational(1, 2)}})) + + +def test_DomainMatrix_from_Matrix(): + sdm = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ) + A = DomainMatrix.from_Matrix(Matrix([[1, 2], [3, 4]])) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == ZZ + + K = QQ.algebraic_field(sqrt(2)) + sdm = SDM( + {0: {0: K.convert(1 + sqrt(2)), 1: K.convert(2 + sqrt(2))}, + 1: {0: K.convert(3 + sqrt(2)), 1: K.convert(4 + sqrt(2))}}, + (2, 2), + K + ) + A = DomainMatrix.from_Matrix( + Matrix([[1 + sqrt(2), 2 + sqrt(2)], [3 + sqrt(2), 4 + sqrt(2)]]), + extension=True) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == K + + A = DomainMatrix.from_Matrix(Matrix([[QQ(1, 2), QQ(3, 4)], [QQ(0, 1), QQ(0, 1)]]), fmt='dense') + ddm = DDM([[QQ(1, 2), QQ(3, 4)], [QQ(0, 1), QQ(0, 1)]], (2, 2), QQ) + + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == QQ + + +def test_DomainMatrix_eq(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A == A + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(1)]], (2, 2), ZZ) + assert A != B + C = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + assert A != C + + +def test_DomainMatrix_unify_eq(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B1 = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + B2 = DomainMatrix([[QQ(1), QQ(3)], [QQ(3), QQ(4)]], (2, 2), QQ) + B3 = DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + assert A.unify_eq(B1) is True + assert A.unify_eq(B2) is False + assert A.unify_eq(B3) is False + + +def test_DomainMatrix_get_domain(): + K, items = DomainMatrix.get_domain([1, 2, 3, 4]) + assert items == [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + assert K == ZZ + + K, items = DomainMatrix.get_domain([1, 2, 3, Rational(1, 2)]) + assert items == [QQ(1), QQ(2), QQ(3), QQ(1, 2)] + assert K == QQ + + +def test_DomainMatrix_convert_to(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = A.convert_to(QQ) + assert Aq == DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + + +def test_DomainMatrix_choose_domain(): + A = [[1, 2], [3, 0]] + assert DM(A, QQ).choose_domain() == DM(A, ZZ) + assert DM(A, QQ).choose_domain(field=True) == DM(A, QQ) + assert DM(A, ZZ).choose_domain(field=True) == DM(A, QQ) + + x = symbols('x') + B = [[1, x], [x**2, x**3]] + assert DM(B, QQ[x]).choose_domain(field=True) == DM(B, ZZ.frac_field(x)) + + +def test_DomainMatrix_to_flat_nz(): + Adm = DM([[1, 2], [3, 0]], ZZ) + Addm = Adm.rep.to_ddm() + Asdm = Adm.rep.to_sdm() + for A in [Adm, Addm, Asdm]: + elems, data = A.to_flat_nz() + assert A.from_flat_nz(elems, data, A.domain) == A + elemsq = [QQ(e) for e in elems] + assert A.from_flat_nz(elemsq, data, QQ) == A.convert_to(QQ) + elems2 = [2*e for e in elems] + assert A.from_flat_nz(elems2, data, A.domain) == 2*A + + +def test_DomainMatrix_to_sympy(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_sympy() == A.convert_to(EXRAW) + + +def test_DomainMatrix_to_field(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = A.to_field() + assert Aq == DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + + +def test_DomainMatrix_to_sparse(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A_sparse = A.to_sparse() + assert A_sparse.rep == {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + + +def test_DomainMatrix_to_dense(): + A = DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + A_dense = A.to_dense() + ddm = DDM([[1, 2], [3, 4]], (2, 2), ZZ) + if GROUND_TYPES != 'flint': + assert A_dense.rep == ddm + else: + assert A_dense.rep == ddm.to_dfm() + + +def test_DomainMatrix_unify(): + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + assert Az.unify(Az) == (Az, Az) + assert Az.unify(Aq) == (Aq, Aq) + assert Aq.unify(Az) == (Aq, Aq) + assert Aq.unify(Aq) == (Aq, Aq) + + As = DomainMatrix({0: {1: ZZ(1)}, 1:{0:ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + assert As.unify(As) == (As, As) + assert Ad.unify(Ad) == (Ad, Ad) + + Bs, Bd = As.unify(Ad, fmt='dense') + assert Bs.rep == DDM([[0, 1], [2, 0]], (2, 2), ZZ).to_dfm_or_ddm() + assert Bd.rep == DDM([[1, 2],[3, 4]], (2, 2), ZZ).to_dfm_or_ddm() + + Bs, Bd = As.unify(Ad, fmt='sparse') + assert Bs.rep == SDM({0: {1: 1}, 1: {0: 2}}, (2, 2), ZZ) + assert Bd.rep == SDM({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + + raises(ValueError, lambda: As.unify(Ad, fmt='invalid')) + + +def test_DomainMatrix_to_Matrix(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A_Matrix = Matrix([[1, 2], [3, 4]]) + assert A.to_Matrix() == A_Matrix + assert A.to_sparse().to_Matrix() == A_Matrix + assert A.convert_to(QQ).to_Matrix() == A_Matrix + assert A.convert_to(QQ.algebraic_field(sqrt(2))).to_Matrix() == A_Matrix + + +def test_DomainMatrix_to_list(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_list() == [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + + +def test_DomainMatrix_to_list_flat(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_list_flat() == [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + + +def test_DomainMatrix_flat(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.flat() == [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + + +def test_DomainMatrix_from_list_flat(): + nums = [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + assert DomainMatrix.from_list_flat(nums, (2, 2), ZZ) == A + assert DDM.from_list_flat(nums, (2, 2), ZZ) == A.rep.to_ddm() + assert SDM.from_list_flat(nums, (2, 2), ZZ) == A.rep.to_sdm() + + assert A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + + raises(DMBadInputError, DomainMatrix.from_list_flat, nums, (2, 3), ZZ) + raises(DMBadInputError, DDM.from_list_flat, nums, (2, 3), ZZ) + raises(DMBadInputError, SDM.from_list_flat, nums, (2, 3), ZZ) + + +def test_DomainMatrix_to_dod(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_dod() == {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(4)]], (2, 2), ZZ) + assert A.to_dod() == {0: {0: ZZ(1)}, 1: {1: ZZ(4)}} + + +def test_DomainMatrix_from_dod(): + items = {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + A = DM([[1, 2], [3, 4]], ZZ) + assert DomainMatrix.from_dod(items, (2, 2), ZZ) == A.to_sparse() + assert A.from_dod_like(items) == A + assert A.from_dod_like(items, QQ) == A.convert_to(QQ) + + +def test_DomainMatrix_to_dok(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_dok() == {(0, 0):ZZ(1), (0, 1):ZZ(2), (1, 0):ZZ(3), (1, 1):ZZ(4)} + A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(4)]], (2, 2), ZZ) + dok = {(0, 0):ZZ(1), (1, 1):ZZ(4)} + assert A.to_dok() == dok + assert A.to_dense().to_dok() == dok + assert A.to_sparse().to_dok() == dok + assert A.rep.to_ddm().to_dok() == dok + assert A.rep.to_sdm().to_dok() == dok + + +def test_DomainMatrix_from_dok(): + items = {(0, 0): ZZ(1), (1, 1): ZZ(2)} + A = DM([[1, 0], [0, 2]], ZZ) + assert DomainMatrix.from_dok(items, (2, 2), ZZ) == A.to_sparse() + assert DDM.from_dok(items, (2, 2), ZZ) == A.rep.to_ddm() + assert SDM.from_dok(items, (2, 2), ZZ) == A.rep.to_sdm() + + +def test_DomainMatrix_repr(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert repr(A) == 'DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ)' + + +def test_DomainMatrix_transpose(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + AT = DomainMatrix([[ZZ(1), ZZ(3)], [ZZ(2), ZZ(4)]], (2, 2), ZZ) + assert A.transpose() == AT + + +def test_DomainMatrix_is_zero_matrix(): + A = DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + B = DomainMatrix([[ZZ(0)]], (1, 1), ZZ) + assert A.is_zero_matrix is False + assert B.is_zero_matrix is True + + +def test_DomainMatrix_is_upper(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(0), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.is_upper is True + assert B.is_upper is False + + +def test_DomainMatrix_is_lower(): + A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.is_lower is True + assert B.is_lower is False + + +def test_DomainMatrix_is_diagonal(): + A = DM([[1, 0], [0, 4]], ZZ) + B = DM([[1, 2], [3, 4]], ZZ) + assert A.is_diagonal is A.to_sparse().is_diagonal is True + assert B.is_diagonal is B.to_sparse().is_diagonal is False + + +def test_DomainMatrix_is_square(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)], [ZZ(5), ZZ(6)]], (3, 2), ZZ) + assert A.is_square is True + assert B.is_square is False + + +def test_DomainMatrix_diagonal(): + A = DM([[1, 2], [3, 4]], ZZ) + assert A.diagonal() == A.to_sparse().diagonal() == [ZZ(1), ZZ(4)] + A = DM([[1, 2], [3, 4], [5, 6]], ZZ) + assert A.diagonal() == A.to_sparse().diagonal() == [ZZ(1), ZZ(4)] + A = DM([[1, 2, 3], [4, 5, 6]], ZZ) + assert A.diagonal() == A.to_sparse().diagonal() == [ZZ(1), ZZ(5)] + + +def test_DomainMatrix_rank(): + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(6), QQ(8)]], (3, 2), QQ) + assert A.rank() == 2 + + +def test_DomainMatrix_add(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + assert A + A == A.add(A) == B + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + L = [[2, 3], [3, 4]] + raises(TypeError, lambda: A + L) + raises(TypeError, lambda: L + A) + + A1 = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A1 + A2) + raises(DMShapeError, lambda: A2 + A1) + raises(DMShapeError, lambda: A1.add(A2)) + raises(DMShapeError, lambda: A2.add(A1)) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Asum = DomainMatrix([[QQ(2), QQ(4)], [QQ(6), QQ(8)]], (2, 2), QQ) + assert Az + Aq == Asum + assert Aq + Az == Asum + raises(DMDomainError, lambda: Az.add(Aq)) + raises(DMDomainError, lambda: Aq.add(Az)) + + As = DomainMatrix({0: {1: ZZ(1)}, 1: {0: ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + Asd = As + Ad + Ads = Ad + As + assert Asd == DomainMatrix([[1, 3], [5, 4]], (2, 2), ZZ) + assert Asd.rep == DDM([[1, 3], [5, 4]], (2, 2), ZZ).to_dfm_or_ddm() + assert Ads == DomainMatrix([[1, 3], [5, 4]], (2, 2), ZZ) + assert Ads.rep == DDM([[1, 3], [5, 4]], (2, 2), ZZ).to_dfm_or_ddm() + raises(DMFormatError, lambda: As.add(Ad)) + + +def test_DomainMatrix_sub(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(0), ZZ(0)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + assert A - A == A.sub(A) == B + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + L = [[2, 3], [3, 4]] + raises(TypeError, lambda: A - L) + raises(TypeError, lambda: L - A) + + A1 = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A1 - A2) + raises(DMShapeError, lambda: A2 - A1) + raises(DMShapeError, lambda: A1.sub(A2)) + raises(DMShapeError, lambda: A2.sub(A1)) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Adiff = DomainMatrix([[QQ(0), QQ(0)], [QQ(0), QQ(0)]], (2, 2), QQ) + assert Az - Aq == Adiff + assert Aq - Az == Adiff + raises(DMDomainError, lambda: Az.sub(Aq)) + raises(DMDomainError, lambda: Aq.sub(Az)) + + As = DomainMatrix({0: {1: ZZ(1)}, 1: {0: ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + Asd = As - Ad + Ads = Ad - As + assert Asd == DomainMatrix([[-1, -1], [-1, -4]], (2, 2), ZZ) + assert Asd.rep == DDM([[-1, -1], [-1, -4]], (2, 2), ZZ).to_dfm_or_ddm() + assert Asd == -Ads + assert Asd.rep == -Ads.rep + + +def test_DomainMatrix_neg(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aneg = DomainMatrix([[ZZ(-1), ZZ(-2)], [ZZ(-3), ZZ(-4)]], (2, 2), ZZ) + assert -A == A.neg() == Aneg + + +def test_DomainMatrix_mul(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(7), ZZ(10)], [ZZ(15), ZZ(22)]], (2, 2), ZZ) + assert A*A == A.matmul(A) == A2 + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + L = [[1, 2], [3, 4]] + raises(TypeError, lambda: A * L) + raises(TypeError, lambda: L * A) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Aprod = DomainMatrix([[QQ(7), QQ(10)], [QQ(15), QQ(22)]], (2, 2), QQ) + assert Az * Aq == Aprod + assert Aq * Az == Aprod + raises(DMDomainError, lambda: Az.matmul(Aq)) + raises(DMDomainError, lambda: Aq.matmul(Az)) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + AA = DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + x = ZZ(2) + assert A * x == x * A == A.mul(x) == AA + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + AA = DomainMatrix.zeros((2, 2), ZZ) + x = ZZ(0) + assert A * x == x * A == A.mul(x).to_sparse() == AA + + As = DomainMatrix({0: {1: ZZ(1)}, 1: {0: ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + Asd = As * Ad + Ads = Ad * As + assert Asd == DomainMatrix([[3, 4], [2, 4]], (2, 2), ZZ) + assert Asd.rep == DDM([[3, 4], [2, 4]], (2, 2), ZZ).to_dfm_or_ddm() + assert Ads == DomainMatrix([[4, 1], [8, 3]], (2, 2), ZZ) + assert Ads.rep == DDM([[4, 1], [8, 3]], (2, 2), ZZ).to_dfm_or_ddm() + + +def test_DomainMatrix_mul_elementwise(): + A = DomainMatrix([[ZZ(2), ZZ(2)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(4), ZZ(0)], [ZZ(3), ZZ(0)]], (2, 2), ZZ) + C = DomainMatrix([[ZZ(8), ZZ(0)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + assert A.mul_elementwise(B) == C + assert B.mul_elementwise(A) == C + + +def test_DomainMatrix_pow(): + eye = DomainMatrix.eye(2, ZZ) + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(7), ZZ(10)], [ZZ(15), ZZ(22)]], (2, 2), ZZ) + A3 = DomainMatrix([[ZZ(37), ZZ(54)], [ZZ(81), ZZ(118)]], (2, 2), ZZ) + assert A**0 == A.pow(0) == eye + assert A**1 == A.pow(1) == A + assert A**2 == A.pow(2) == A2 + assert A**3 == A.pow(3) == A3 + + raises(TypeError, lambda: A ** Rational(1, 2)) + raises(NotImplementedError, lambda: A ** -1) + raises(NotImplementedError, lambda: A.pow(-1)) + + A = DomainMatrix.zeros((2, 1), ZZ) + raises(DMNonSquareMatrixError, lambda: A ** 1) + + +def test_DomainMatrix_clear_denoms(): + A = DM([[(1,2),(1,3)],[(1,4),(1,5)]], QQ) + + den_Z = DomainScalar(ZZ(60), ZZ) + Anum_Z = DM([[30, 20], [15, 12]], ZZ) + Anum_Q = Anum_Z.convert_to(QQ) + + assert A.clear_denoms() == (den_Z, Anum_Q) + assert A.clear_denoms(convert=True) == (den_Z, Anum_Z) + assert A * den_Z == Anum_Q + assert A == Anum_Q / den_Z + + +def test_DomainMatrix_clear_denoms_rowwise(): + A = DM([[(1,2),(1,3)],[(1,4),(1,5)]], QQ) + + den_Z = DM([[6, 0], [0, 20]], ZZ).to_sparse() + Anum_Z = DM([[3, 2], [5, 4]], ZZ) + Anum_Q = DM([[3, 2], [5, 4]], QQ) + + assert A.clear_denoms_rowwise() == (den_Z, Anum_Q) + assert A.clear_denoms_rowwise(convert=True) == (den_Z, Anum_Z) + assert den_Z * A == Anum_Q + assert A == den_Z.to_field().inv() * Anum_Q + + A = DM([[(1,2),(1,3),0,0],[0,0,0,0], [(1,4),(1,5),(1,6),(1,7)]], QQ) + den_Z = DM([[6, 0, 0], [0, 1, 0], [0, 0, 420]], ZZ).to_sparse() + Anum_Z = DM([[3, 2, 0, 0], [0, 0, 0, 0], [105, 84, 70, 60]], ZZ) + Anum_Q = Anum_Z.convert_to(QQ) + + assert A.clear_denoms_rowwise() == (den_Z, Anum_Q) + assert A.clear_denoms_rowwise(convert=True) == (den_Z, Anum_Z) + assert den_Z * A == Anum_Q + assert A == den_Z.to_field().inv() * Anum_Q + + +def test_DomainMatrix_cancel_denom(): + A = DM([[2, 4], [6, 8]], ZZ) + assert A.cancel_denom(ZZ(1)) == (DM([[2, 4], [6, 8]], ZZ), ZZ(1)) + assert A.cancel_denom(ZZ(3)) == (DM([[2, 4], [6, 8]], ZZ), ZZ(3)) + assert A.cancel_denom(ZZ(4)) == (DM([[1, 2], [3, 4]], ZZ), ZZ(2)) + + A = DM([[1, 2], [3, 4]], ZZ) + assert A.cancel_denom(ZZ(2)) == (A, ZZ(2)) + assert A.cancel_denom(ZZ(-2)) == (-A, ZZ(2)) + + # Test canonicalization of denominator over Gaussian rationals. + A = DM([[1, 2], [3, 4]], QQ_I) + assert A.cancel_denom(QQ_I(0,2)) == (QQ_I(0,-1)*A, QQ_I(2)) + + raises(ZeroDivisionError, lambda: A.cancel_denom(ZZ(0))) + + +def test_DomainMatrix_cancel_denom_elementwise(): + A = DM([[2, 4], [6, 8]], ZZ) + numers, denoms = A.cancel_denom_elementwise(ZZ(1)) + assert numers == DM([[2, 4], [6, 8]], ZZ) + assert denoms == DM([[1, 1], [1, 1]], ZZ) + numers, denoms = A.cancel_denom_elementwise(ZZ(4)) + assert numers == DM([[1, 1], [3, 2]], ZZ) + assert denoms == DM([[2, 1], [2, 1]], ZZ) + + raises(ZeroDivisionError, lambda: A.cancel_denom_elementwise(ZZ(0))) + + +def test_DomainMatrix_content_primitive(): + A = DM([[2, 4], [6, 8]], ZZ) + A_primitive = DM([[1, 2], [3, 4]], ZZ) + A_content = ZZ(2) + assert A.content() == A_content + assert A.primitive() == (A_content, A_primitive) + + +def test_DomainMatrix_scc(): + Ad = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], + [ZZ(0), ZZ(1), ZZ(0)], + [ZZ(2), ZZ(0), ZZ(4)]], (3, 3), ZZ) + As = Ad.to_sparse() + Addm = Ad.rep + Asdm = As.rep + for A in [Ad, As, Addm, Asdm]: + assert Ad.scc() == [[1], [0, 2]] + + A = DM([[ZZ(1), ZZ(2), ZZ(3)]], ZZ) + raises(DMNonSquareMatrixError, lambda: A.scc()) + + +def test_DomainMatrix_rref(): + # More tests in test_rref.py + A = DomainMatrix([], (0, 1), QQ) + assert A.rref() == (A, ()) + + A = DomainMatrix([[QQ(1)]], (1, 1), QQ) + assert A.rref() == (A, (0,)) + + A = DomainMatrix([[QQ(0)]], (1, 1), QQ) + assert A.rref() == (A, ()) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Ar, pivots = A.rref() + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Ar, pivots = A.rref() + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(0), QQ(4)]], (2, 2), QQ) + Ar, pivots = A.rref() + assert Ar == DomainMatrix([[QQ(0), QQ(1)], [QQ(0), QQ(0)]], (2, 2), QQ) + assert pivots == (1,) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Ar, pivots = Az.rref() + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + methods = ('auto', 'GJ', 'FF', 'CD', 'GJ_dense', 'FF_dense', 'CD_dense') + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + for method in methods: + Ar, pivots = Az.rref(method=method) + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + raises(ValueError, lambda: Az.rref(method='foo')) + raises(ValueError, lambda: Az.rref_den(method='foo')) + + +def test_DomainMatrix_columnspace(): + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ) + Acol = DomainMatrix([[QQ(1), QQ(1)], [QQ(2), QQ(3)]], (2, 2), QQ) + assert A.columnspace() == Acol + + Az = DomainMatrix([[ZZ(1), ZZ(-1), ZZ(1)], [ZZ(2), ZZ(-2), ZZ(3)]], (2, 3), ZZ) + raises(DMNotAField, lambda: Az.columnspace()) + + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ, fmt='sparse') + Acol = DomainMatrix({0: {0: QQ(1), 1: QQ(1)}, 1: {0: QQ(2), 1: QQ(3)}}, (2, 2), QQ) + assert A.columnspace() == Acol + + +def test_DomainMatrix_rowspace(): + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ) + assert A.rowspace() == A + + Az = DomainMatrix([[ZZ(1), ZZ(-1), ZZ(1)], [ZZ(2), ZZ(-2), ZZ(3)]], (2, 3), ZZ) + raises(DMNotAField, lambda: Az.rowspace()) + + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ, fmt='sparse') + assert A.rowspace() == A + + +def test_DomainMatrix_nullspace(): + A = DomainMatrix([[QQ(1), QQ(1)], [QQ(1), QQ(1)]], (2, 2), QQ) + Anull = DomainMatrix([[QQ(-1), QQ(1)]], (1, 2), QQ) + assert A.nullspace() == Anull + + A = DomainMatrix([[ZZ(1), ZZ(1)], [ZZ(1), ZZ(1)]], (2, 2), ZZ) + Anull = DomainMatrix([[ZZ(-1), ZZ(1)]], (1, 2), ZZ) + assert A.nullspace() == Anull + + raises(DMNotAField, lambda: A.nullspace(divide_last=True)) + + A = DomainMatrix([[ZZ(2), ZZ(2)], [ZZ(2), ZZ(2)]], (2, 2), ZZ) + Anull = DomainMatrix([[ZZ(-2), ZZ(2)]], (1, 2), ZZ) + + Arref, den, pivots = A.rref_den() + assert den == ZZ(2) + assert Arref.nullspace_from_rref() == Anull + assert Arref.nullspace_from_rref(pivots) == Anull + assert Arref.to_sparse().nullspace_from_rref() == Anull.to_sparse() + assert Arref.to_sparse().nullspace_from_rref(pivots) == Anull.to_sparse() + + +def test_DomainMatrix_solve(): + # XXX: Maybe the _solve method should be changed... + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + particular = DomainMatrix([[1, 0]], (1, 2), QQ) + nullspace = DomainMatrix([[-2, 1]], (1, 2), QQ) + assert A._solve(b) == (particular, nullspace) + + b3 = DomainMatrix([[QQ(1)], [QQ(1)], [QQ(1)]], (3, 1), QQ) + raises(DMShapeError, lambda: A._solve(b3)) + + bz = DomainMatrix([[ZZ(1)], [ZZ(1)]], (2, 1), ZZ) + raises(DMNotAField, lambda: A._solve(bz)) + + +def test_DomainMatrix_inv(): + A = DomainMatrix([], (0, 0), QQ) + assert A.inv() == A + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Ainv = DomainMatrix([[QQ(-2), QQ(1)], [QQ(3, 2), QQ(-1, 2)]], (2, 2), QQ) + assert A.inv() == Ainv + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + raises(DMNotAField, lambda: Az.inv()) + + Ans = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMNonSquareMatrixError, lambda: Ans.inv()) + + Aninv = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(6)]], (2, 2), QQ) + raises(DMNonInvertibleMatrixError, lambda: Aninv.inv()) + + Z3 = FF(3) + assert DM([[1, 2], [3, 4]], Z3).inv() == DM([[1, 1], [0, 1]], Z3) + + Z6 = FF(6) + raises(DMNotAField, lambda: DM([[1, 2], [3, 4]], Z6).inv()) + + +def test_DomainMatrix_det(): + A = DomainMatrix([], (0, 0), ZZ) + assert A.det() == 1 + + A = DomainMatrix([[1]], (1, 1), ZZ) + assert A.det() == 1 + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.det() == ZZ(-2) + + A = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(3), ZZ(5)]], (3, 3), ZZ) + assert A.det() == ZZ(-1) + + A = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(2), ZZ(5)]], (3, 3), ZZ) + assert A.det() == ZZ(0) + + Ans = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMNonSquareMatrixError, lambda: Ans.det()) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + assert A.det() == QQ(-2) + + +def test_DomainMatrix_eval_poly(): + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + p = [ZZ(1), ZZ(2), ZZ(3)] + result = DomainMatrix([[ZZ(12), ZZ(14)], [ZZ(21), ZZ(33)]], (2, 2), ZZ) + assert dM.eval_poly(p) == result == p[0]*dM**2 + p[1]*dM + p[2]*dM**0 + assert dM.eval_poly([]) == dM.zeros(dM.shape, dM.domain) + assert dM.eval_poly([ZZ(2)]) == 2*dM.eye(2, dM.domain) + + dM2 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMNonSquareMatrixError, lambda: dM2.eval_poly([ZZ(1)])) + + +def test_DomainMatrix_eval_poly_mul(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + p = [ZZ(1), ZZ(2), ZZ(3)] + result = DomainMatrix([[ZZ(40)], [ZZ(87)]], (2, 1), ZZ) + assert A.eval_poly_mul(p, b) == result == p[0]*A**2*b + p[1]*A*b + p[2]*b + + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + dM1 = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMNonSquareMatrixError, lambda: dM1.eval_poly_mul([ZZ(1)], b)) + b1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: dM.eval_poly_mul([ZZ(1)], b1)) + bq = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMDomainError, lambda: dM.eval_poly_mul([ZZ(1)], bq)) + + +def _check_solve_den(A, b, xnum, xden): + # Examples for solve_den, solve_den_charpoly, solve_den_rref should use + # this so that all methods and types are tested. + + case1 = (A, xnum, b) + case2 = (A.to_sparse(), xnum.to_sparse(), b.to_sparse()) + + for Ai, xnum_i, b_i in [case1, case2]: + # The key invariant for solve_den: + assert Ai*xnum_i == xden*b_i + + # solve_den_rref can differ at least by a minus sign + answers = [(xnum_i, xden), (-xnum_i, -xden)] + assert Ai.solve_den(b) in answers + assert Ai.solve_den(b, method='rref') in answers + assert Ai.solve_den_rref(b) in answers + + # charpoly can only be used if A is square and guarantees to return the + # actual determinant as a denominator. + m, n = Ai.shape + if m == n: + assert Ai.solve_den(b_i, method='charpoly') == (xnum_i, xden) + assert Ai.solve_den_charpoly(b_i) == (xnum_i, xden) + else: + raises(DMNonSquareMatrixError, lambda: Ai.solve_den_charpoly(b)) + raises(DMNonSquareMatrixError, lambda: Ai.solve_den(b, method='charpoly')) + + +def test_DomainMatrix_solve_den(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + result = DomainMatrix([[ZZ(0)], [ZZ(-1)]], (2, 1), ZZ) + den = ZZ(-2) + _check_solve_den(A, b, result, den) + + A = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(1), ZZ(2), ZZ(4)], + [ZZ(1), ZZ(3), ZZ(5)]], (3, 3), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)], [ZZ(3)]], (3, 1), ZZ) + result = DomainMatrix([[ZZ(2)], [ZZ(0)], [ZZ(-1)]], (3, 1), ZZ) + den = ZZ(-1) + _check_solve_den(A, b, result, den) + + A = DomainMatrix([[ZZ(2)], [ZZ(2)]], (2, 1), ZZ) + b = DomainMatrix([[ZZ(3)], [ZZ(3)]], (2, 1), ZZ) + result = DomainMatrix([[ZZ(3)]], (1, 1), ZZ) + den = ZZ(2) + _check_solve_den(A, b, result, den) + + +def test_DomainMatrix_solve_den_charpoly(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + A1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMNonSquareMatrixError, lambda: A1.solve_den_charpoly(b)) + b1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.solve_den_charpoly(b1)) + bq = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMDomainError, lambda: A.solve_den_charpoly(bq)) + + +def test_DomainMatrix_solve_den_charpoly_check(): + # Test check + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(2), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(3)]], (2, 1), ZZ) + raises(DMNonInvertibleMatrixError, lambda: A.solve_den_charpoly(b)) + adjAb = DomainMatrix([[ZZ(-2)], [ZZ(1)]], (2, 1), ZZ) + assert A.adjugate() * b == adjAb + assert A.solve_den_charpoly(b, check=False) == (adjAb, ZZ(0)) + + +def test_DomainMatrix_solve_den_errors(): + A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMShapeError, lambda: A.solve_den(b)) + raises(DMShapeError, lambda: A.solve_den_rref(b)) + + A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + b = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.solve_den(b)) + raises(DMShapeError, lambda: A.solve_den_rref(b)) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.solve_den(b1)) + + A = DomainMatrix([[ZZ(2)]], (1, 1), ZZ) + b = DomainMatrix([[ZZ(2)]], (1, 1), ZZ) + raises(DMBadInputError, lambda: A.solve_den(b1, method='invalid')) + + A = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMNonSquareMatrixError, lambda: A.solve_den_charpoly(b)) + + +def test_DomainMatrix_solve_den_rref_underdetermined(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(1), ZZ(2)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(1)]], (2, 1), ZZ) + raises(DMNonInvertibleMatrixError, lambda: A.solve_den(b)) + raises(DMNonInvertibleMatrixError, lambda: A.solve_den_rref(b)) + + +def test_DomainMatrix_adj_poly_det(): + A = DM([[ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], ZZ) + p, detA = A.adj_poly_det() + assert p == [ZZ(1), ZZ(-15), ZZ(-18)] + assert A.adjugate() == p[0]*A**2 + p[1]*A**1 + p[2]*A**0 == A.eval_poly(p) + assert A.det() == detA + + A = DM([[ZZ(1), ZZ(2), ZZ(3)], + [ZZ(7), ZZ(8), ZZ(9)]], ZZ) + raises(DMNonSquareMatrixError, lambda: A.adj_poly_det()) + + +def test_DomainMatrix_inv_den(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + den = ZZ(-2) + result = DomainMatrix([[ZZ(4), ZZ(-2)], [ZZ(-3), ZZ(1)]], (2, 2), ZZ) + assert A.inv_den() == (result, den) + + +def test_DomainMatrix_adjugate(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + result = DomainMatrix([[ZZ(4), ZZ(-2)], [ZZ(-3), ZZ(1)]], (2, 2), ZZ) + assert A.adjugate() == result + + +def test_DomainMatrix_adj_det(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + adjA = DomainMatrix([[ZZ(4), ZZ(-2)], [ZZ(-3), ZZ(1)]], (2, 2), ZZ) + assert A.adj_det() == (adjA, ZZ(-2)) + + +def test_DomainMatrix_lu(): + A = DomainMatrix([], (0, 0), QQ) + assert A.lu() == (A, A, []) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(3), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(1), QQ(2)], [QQ(0), QQ(-2)]], (2, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(3), QQ(4)], [QQ(0), QQ(2)]], (2, 2), QQ) + swaps = [(0, 1)] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(2), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(1), QQ(2)], [QQ(0), QQ(0)]], (2, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(0), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(0), QQ(2)], [QQ(0), QQ(4)]], (2, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]], (2, 3), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(4), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(1), QQ(2), QQ(3)], [QQ(0), QQ(-3), QQ(-6)]], (2, 3), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + L = DomainMatrix([ + [QQ(1), QQ(0), QQ(0)], + [QQ(3), QQ(1), QQ(0)], + [QQ(5), QQ(2), QQ(1)]], (3, 3), QQ) + U = DomainMatrix([[QQ(1), QQ(2)], [QQ(0), QQ(-2)], [QQ(0), QQ(0)]], (3, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]] + L = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]] + U = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]] + to_dom = lambda rows, dom: [[dom(e) for e in row] for row in rows] + A = DomainMatrix(to_dom(A, QQ), (4, 4), QQ) + L = DomainMatrix(to_dom(L, QQ), (4, 4), QQ) + U = DomainMatrix(to_dom(U, QQ), (4, 4), QQ) + assert A.lu() == (L, U, []) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + raises(DMNotAField, lambda: A.lu()) + + +def test_DomainMatrix_lu_solve(): + # Base case + A = b = x = DomainMatrix([], (0, 0), QQ) + assert A.lu_solve(b) == x + + # Basic example + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DomainMatrix([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Example with swaps + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DomainMatrix([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Non-invertible + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Overdetermined, consistent + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + x = DomainMatrix([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Overdetermined, inconsistent + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)], [QQ(4)]], (3, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Underdetermined + A = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + b = DomainMatrix([[QQ(1)]], (1, 1), QQ) + raises(NotImplementedError, lambda: A.lu_solve(b)) + + # Non-field + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMNotAField, lambda: A.lu_solve(b)) + + # Shape mismatch + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMShapeError, lambda: A.lu_solve(b)) + + +def test_DomainMatrix_charpoly(): + A = DomainMatrix([], (0, 0), ZZ) + p = [ZZ(1)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[1]], (1, 1), ZZ) + p = [ZZ(1), ZZ(-1)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + p = [ZZ(1), ZZ(-5), ZZ(-2)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)], [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + p = [ZZ(1), ZZ(-15), ZZ(-18), ZZ(0)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[ZZ(0), ZZ(1), ZZ(0)], + [ZZ(1), ZZ(0), ZZ(1)], + [ZZ(0), ZZ(1), ZZ(0)]], (3, 3), ZZ) + p = [ZZ(1), ZZ(0), ZZ(-2), ZZ(0)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DM([[17, 0, 30, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [69, 0, 0, 0, 0, 86, 0, 0, 0, 0], + [23, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 13, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 32, 0, 0], + [ 0, 0, 0, 0, 37, 67, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ZZ) + p = ZZ.map([1, -17, -2070, 0, -771420, 0, 0, 0, 0, 0, 0]) + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + Ans = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMNonSquareMatrixError, lambda: Ans.charpoly()) + + +def test_DomainMatrix_charpoly_factor_list(): + A = DomainMatrix([], (0, 0), ZZ) + assert A.charpoly_factor_list() == [] + + A = DM([[1]], ZZ) + assert A.charpoly_factor_list() == [ + ([ZZ(1), ZZ(-1)], 1) + ] + + A = DM([[1, 2], [3, 4]], ZZ) + assert A.charpoly_factor_list() == [ + ([ZZ(1), ZZ(-5), ZZ(-2)], 1) + ] + + A = DM([[1, 2, 0], [3, 4, 0], [0, 0, 1]], ZZ) + assert A.charpoly_factor_list() == [ + ([ZZ(1), ZZ(-1)], 1), + ([ZZ(1), ZZ(-5), ZZ(-2)], 1) + ] + + +def test_DomainMatrix_eye(): + A = DomainMatrix.eye(3, QQ) + assert A.rep == SDM.eye((3, 3), QQ) + assert A.shape == (3, 3) + assert A.domain == QQ + + +def test_DomainMatrix_zeros(): + A = DomainMatrix.zeros((1, 2), QQ) + assert A.rep == SDM.zeros((1, 2), QQ) + assert A.shape == (1, 2) + assert A.domain == QQ + + +def test_DomainMatrix_ones(): + A = DomainMatrix.ones((2, 3), QQ) + if GROUND_TYPES != 'flint': + assert A.rep == DDM.ones((2, 3), QQ) + else: + assert A.rep == SDM.ones((2, 3), QQ).to_dfm() + assert A.shape == (2, 3) + assert A.domain == QQ + + +def test_DomainMatrix_diag(): + A = DomainMatrix({0:{0:ZZ(2)}, 1:{1:ZZ(3)}}, (2, 2), ZZ) + assert DomainMatrix.diag([ZZ(2), ZZ(3)], ZZ) == A + + A = DomainMatrix({0:{0:ZZ(2)}, 1:{1:ZZ(3)}}, (3, 4), ZZ) + assert DomainMatrix.diag([ZZ(2), ZZ(3)], ZZ, (3, 4)) == A + + +def test_DomainMatrix_hstack(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + + AB = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(5), ZZ(6)], + [ZZ(3), ZZ(4), ZZ(7), ZZ(8)]], (2, 4), ZZ) + ABC = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(5), ZZ(6), ZZ(9), ZZ(10)], + [ZZ(3), ZZ(4), ZZ(7), ZZ(8), ZZ(11), ZZ(12)]], (2, 6), ZZ) + assert A.hstack(B) == AB + assert A.hstack(B, C) == ABC + + +def test_DomainMatrix_vstack(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + + AB = DomainMatrix([ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8)]], (4, 2), ZZ) + ABC = DomainMatrix([ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8)], + [ZZ(9), ZZ(10)], + [ZZ(11), ZZ(12)]], (6, 2), ZZ) + assert A.vstack(B) == AB + assert A.vstack(B, C) == ABC + + +def test_DomainMatrix_applyfunc(): + A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + B = DomainMatrix([[ZZ(2), ZZ(4)]], (1, 2), ZZ) + assert A.applyfunc(lambda x: 2*x) == B + + +def test_DomainMatrix_scalarmul(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + lamda = DomainScalar(QQ(3)/QQ(2), QQ) + assert A * lamda == DomainMatrix([[QQ(3, 2), QQ(3)], [QQ(9, 2), QQ(6)]], (2, 2), QQ) + assert A * 2 == DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + assert 2 * A == DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + assert A * DomainScalar(ZZ(0), ZZ) == DomainMatrix({}, (2, 2), ZZ) + assert A * DomainScalar(ZZ(1), ZZ) == A + + raises(TypeError, lambda: A * 1.5) + + +def test_DomainMatrix_truediv(): + A = DomainMatrix.from_Matrix(Matrix([[1, 2], [3, 4]])) + lamda = DomainScalar(QQ(3)/QQ(2), QQ) + assert A / lamda == DomainMatrix({0: {0: QQ(2, 3), 1: QQ(4, 3)}, 1: {0: QQ(2), 1: QQ(8, 3)}}, (2, 2), QQ) + b = DomainScalar(ZZ(1), ZZ) + assert A / b == DomainMatrix({0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}, (2, 2), QQ) + + assert A / 1 == DomainMatrix({0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}, (2, 2), QQ) + assert A / 2 == DomainMatrix({0: {0: QQ(1, 2), 1: QQ(1)}, 1: {0: QQ(3, 2), 1: QQ(2)}}, (2, 2), QQ) + + raises(ZeroDivisionError, lambda: A / 0) + raises(TypeError, lambda: A / 1.5) + raises(ZeroDivisionError, lambda: A / DomainScalar(ZZ(0), ZZ)) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_field() / 2 == DomainMatrix([[QQ(1, 2), QQ(1)], [QQ(3, 2), QQ(2)]], (2, 2), QQ) + assert A / 2 == DomainMatrix([[QQ(1, 2), QQ(1)], [QQ(3, 2), QQ(2)]], (2, 2), QQ) + assert A.to_field() / QQ(2,3) == DomainMatrix([[QQ(3, 2), QQ(3)], [QQ(9, 2), QQ(6)]], (2, 2), QQ) + + +def test_DomainMatrix_getitem(): + dM = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + + assert dM[1:,:-2] == DomainMatrix([[ZZ(4)], [ZZ(7)]], (2, 1), ZZ) + assert dM[2,:-2] == DomainMatrix([[ZZ(7)]], (1, 1), ZZ) + assert dM[:-2,:-2] == DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + assert dM[:-1,0:2] == DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(4), ZZ(5)]], (2, 2), ZZ) + assert dM[:, -1] == DomainMatrix([[ZZ(3)], [ZZ(6)], [ZZ(9)]], (3, 1), ZZ) + assert dM[-1, :] == DomainMatrix([[ZZ(7), ZZ(8), ZZ(9)]], (1, 3), ZZ) + assert dM[::-1, :] == DomainMatrix([ + [ZZ(7), ZZ(8), ZZ(9)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(1), ZZ(2), ZZ(3)]], (3, 3), ZZ) + + raises(IndexError, lambda: dM[4, :-2]) + raises(IndexError, lambda: dM[:-2, 4]) + + assert dM[1, 2] == DomainScalar(ZZ(6), ZZ) + assert dM[-2, 2] == DomainScalar(ZZ(6), ZZ) + assert dM[1, -2] == DomainScalar(ZZ(5), ZZ) + assert dM[-1, -3] == DomainScalar(ZZ(7), ZZ) + + raises(IndexError, lambda: dM[3, 3]) + raises(IndexError, lambda: dM[1, 4]) + raises(IndexError, lambda: dM[-1, -4]) + + dM = DomainMatrix({0: {0: ZZ(1)}}, (10, 10), ZZ) + assert dM[5, 5] == DomainScalar(ZZ(0), ZZ) + assert dM[0, 0] == DomainScalar(ZZ(1), ZZ) + + dM = DomainMatrix({1: {0: 1}}, (2,1), ZZ) + assert dM[0:, 0] == DomainMatrix({1: {0: 1}}, (2, 1), ZZ) + raises(IndexError, lambda: dM[3, 0]) + + dM = DomainMatrix({2: {2: ZZ(1)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + assert dM[:2,:2] == DomainMatrix({}, (2, 2), ZZ) + assert dM[2:,2:] == DomainMatrix({0: {0: 1}, 2: {2: 1}}, (3, 3), ZZ) + assert dM[3:,3:] == DomainMatrix({1: {1: 1}}, (2, 2), ZZ) + assert dM[2:, 6:] == DomainMatrix({}, (3, 0), ZZ) + + +def test_DomainMatrix_getitem_sympy(): + dM = DomainMatrix({2: {2: ZZ(2)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + val1 = dM.getitem_sympy(0, 0) + assert val1 is S.Zero + val2 = dM.getitem_sympy(2, 2) + assert val2 == 2 and isinstance(val2, Integer) + + +def test_DomainMatrix_extract(): + dM1 = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + dM2 = DomainMatrix([ + [ZZ(1), ZZ(3)], + [ZZ(7), ZZ(9)]], (2, 2), ZZ) + assert dM1.extract([0, 2], [0, 2]) == dM2 + assert dM1.to_sparse().extract([0, 2], [0, 2]) == dM2.to_sparse() + assert dM1.extract([0, -1], [0, -1]) == dM2 + assert dM1.to_sparse().extract([0, -1], [0, -1]) == dM2.to_sparse() + + dM3 = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(2)], + [ZZ(4), ZZ(5), ZZ(5)], + [ZZ(4), ZZ(5), ZZ(5)]], (3, 3), ZZ) + assert dM1.extract([0, 1, 1], [0, 1, 1]) == dM3 + assert dM1.to_sparse().extract([0, 1, 1], [0, 1, 1]) == dM3.to_sparse() + + empty = [ + ([], [], (0, 0)), + ([1], [], (1, 0)), + ([], [1], (0, 1)), + ] + for rows, cols, size in empty: + assert dM1.extract(rows, cols) == DomainMatrix.zeros(size, ZZ).to_dense() + assert dM1.to_sparse().extract(rows, cols) == DomainMatrix.zeros(size, ZZ) + + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + bad_indices = [([2], [0]), ([0], [2]), ([-3], [0]), ([0], [-3])] + for rows, cols in bad_indices: + raises(IndexError, lambda: dM.extract(rows, cols)) + raises(IndexError, lambda: dM.to_sparse().extract(rows, cols)) + + +def test_DomainMatrix_setitem(): + dM = DomainMatrix({2: {2: ZZ(1)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + dM[2, 2] = ZZ(2) + assert dM == DomainMatrix({2: {2: ZZ(2)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + def setitem(i, j, val): + dM[i, j] = val + raises(TypeError, lambda: setitem(2, 2, QQ(1, 2))) + raises(NotImplementedError, lambda: setitem(slice(1, 2), 2, ZZ(1))) + + +def test_DomainMatrix_pickling(): + import pickle + dM = DomainMatrix({2: {2: ZZ(1)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + assert pickle.loads(pickle.dumps(dM)) == dM + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert pickle.loads(pickle.dumps(dM)) == dM + + +def test_DomainMatrix_fflu(): + A = DM([[1, 2], [3, 4]], ZZ) + P, L, D, U = A.fflu() + assert P.shape == A.shape + assert L.shape == A.shape + assert D.shape == A.shape + assert U.shape == A.shape + assert P == DM([[1, 0], [0, 1]], ZZ) + assert L == DM([[1, 0], [3, -2]], ZZ) + assert D == DM([[1, 0], [0, -2]], ZZ) + assert U == DM([[1, 2], [0, -2]], ZZ) + di, d = D.inv_den() + assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_domainscalar.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_domainscalar.py new file mode 100644 index 0000000000000000000000000000000000000000..8c507caf079cc62ba23ba171a50d0d27c98eb6d9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_domainscalar.py @@ -0,0 +1,153 @@ +from sympy.testing.pytest import raises + +from sympy.core.symbol import S +from sympy.polys import ZZ, QQ +from sympy.polys.matrices.domainscalar import DomainScalar +from sympy.polys.matrices.domainmatrix import DomainMatrix + + +def test_DomainScalar___new__(): + raises(TypeError, lambda: DomainScalar(ZZ(1), QQ)) + raises(TypeError, lambda: DomainScalar(ZZ(1), 1)) + + +def test_DomainScalar_new(): + A = DomainScalar(ZZ(1), ZZ) + B = A.new(ZZ(4), ZZ) + assert B == DomainScalar(ZZ(4), ZZ) + + +def test_DomainScalar_repr(): + A = DomainScalar(ZZ(1), ZZ) + assert repr(A) in {'1', 'mpz(1)'} + + +def test_DomainScalar_from_sympy(): + expr = S(1) + B = DomainScalar.from_sympy(expr) + assert B == DomainScalar(ZZ(1), ZZ) + + +def test_DomainScalar_to_sympy(): + B = DomainScalar(ZZ(1), ZZ) + expr = B.to_sympy() + assert expr.is_Integer and expr == 1 + + +def test_DomainScalar_to_domain(): + A = DomainScalar(ZZ(1), ZZ) + B = A.to_domain(QQ) + assert B == DomainScalar(QQ(1), QQ) + + +def test_DomainScalar_convert_to(): + A = DomainScalar(ZZ(1), ZZ) + B = A.convert_to(QQ) + assert B == DomainScalar(QQ(1), QQ) + + +def test_DomainScalar_unify(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + A, B = A.unify(B) + assert A.domain == B.domain == QQ + + +def test_DomainScalar_add(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A + B == DomainScalar(QQ(3), QQ) + + raises(TypeError, lambda: A + 1.5) + +def test_DomainScalar_sub(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A - B == DomainScalar(QQ(-1), QQ) + + raises(TypeError, lambda: A - 1.5) + +def test_DomainScalar_mul(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + dm = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A * B == DomainScalar(QQ(2), QQ) + assert A * dm == dm + assert B * 2 == DomainScalar(QQ(4), QQ) + + raises(TypeError, lambda: A * 1.5) + + +def test_DomainScalar_floordiv(): + A = DomainScalar(ZZ(-5), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A // B == DomainScalar(QQ(-5, 2), QQ) + C = DomainScalar(ZZ(2), ZZ) + assert A // C == DomainScalar(ZZ(-3), ZZ) + + raises(TypeError, lambda: A // 1.5) + + +def test_DomainScalar_mod(): + A = DomainScalar(ZZ(5), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A % B == DomainScalar(QQ(0), QQ) + C = DomainScalar(ZZ(2), ZZ) + assert A % C == DomainScalar(ZZ(1), ZZ) + + raises(TypeError, lambda: A % 1.5) + + +def test_DomainScalar_divmod(): + A = DomainScalar(ZZ(5), ZZ) + B = DomainScalar(QQ(2), QQ) + assert divmod(A, B) == (DomainScalar(QQ(5, 2), QQ), DomainScalar(QQ(0), QQ)) + C = DomainScalar(ZZ(2), ZZ) + assert divmod(A, C) == (DomainScalar(ZZ(2), ZZ), DomainScalar(ZZ(1), ZZ)) + + raises(TypeError, lambda: divmod(A, 1.5)) + + +def test_DomainScalar_pow(): + A = DomainScalar(ZZ(-5), ZZ) + B = A**(2) + assert B == DomainScalar(ZZ(25), ZZ) + + raises(TypeError, lambda: A**(1.5)) + + +def test_DomainScalar_pos(): + A = DomainScalar(QQ(2), QQ) + B = DomainScalar(QQ(2), QQ) + assert +A == B + + +def test_DomainScalar_neg(): + A = DomainScalar(QQ(2), QQ) + B = DomainScalar(QQ(-2), QQ) + assert -A == B + + +def test_DomainScalar_eq(): + A = DomainScalar(QQ(2), QQ) + assert A == A + B = DomainScalar(ZZ(-5), ZZ) + assert A != B + C = DomainScalar(ZZ(2), ZZ) + assert A != C + D = [1] + assert A != D + + +def test_DomainScalar_isZero(): + A = DomainScalar(ZZ(0), ZZ) + assert A.is_zero() == True + B = DomainScalar(ZZ(1), ZZ) + assert B.is_zero() == False + + +def test_DomainScalar_isOne(): + A = DomainScalar(ZZ(1), ZZ) + assert A.is_one() == True + B = DomainScalar(ZZ(0), ZZ) + assert B.is_one() == False diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_eigen.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..70482eab686d5b4e1c45d552f5eccb5bdaa9e1ed --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_eigen.py @@ -0,0 +1,90 @@ +""" +Tests for the sympy.polys.matrices.eigen module +""" + +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix + +from sympy.polys.agca.extensions import FiniteExtension +from sympy.polys.domains import QQ +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.polys.matrices.domainmatrix import DomainMatrix + +from sympy.polys.matrices.eigen import dom_eigenvects, dom_eigenvects_to_sympy + + +def test_dom_eigenvects_rational(): + # Rational eigenvalues + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(1), QQ(2)]], (2, 2), QQ) + rational_eigenvects = [ + (QQ, QQ(3), 1, DomainMatrix([[QQ(1), QQ(1)]], (1, 2), QQ)), + (QQ, QQ(0), 1, DomainMatrix([[QQ(-2), QQ(1)]], (1, 2), QQ)), + ] + assert dom_eigenvects(A) == (rational_eigenvects, []) + + # Test converting to Expr: + sympy_eigenvects = [ + (S(3), 1, [Matrix([1, 1])]), + (S(0), 1, [Matrix([-2, 1])]), + ] + assert dom_eigenvects_to_sympy(rational_eigenvects, [], Matrix) == sympy_eigenvects + + +def test_dom_eigenvects_algebraic(): + # Algebraic eigenvalues + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Avects = dom_eigenvects(A) + + # Extract the dummy to build the expected result: + lamda = Avects[1][0][1].gens[0] + irreducible = Poly(lamda**2 - 5*lamda - 2, lamda, domain=QQ) + K = FiniteExtension(irreducible) + KK = K.from_sympy + algebraic_eigenvects = [ + (K, irreducible, 1, DomainMatrix([[KK((lamda-4)/3), KK(1)]], (1, 2), K)), + ] + assert Avects == ([], algebraic_eigenvects) + + # Test converting to Expr: + sympy_eigenvects = [ + (S(5)/2 - sqrt(33)/2, 1, [Matrix([[-sqrt(33)/6 - S(1)/2], [1]])]), + (S(5)/2 + sqrt(33)/2, 1, [Matrix([[-S(1)/2 + sqrt(33)/6], [1]])]), + ] + assert dom_eigenvects_to_sympy([], algebraic_eigenvects, Matrix) == sympy_eigenvects + + +def test_dom_eigenvects_rootof(): + # Algebraic eigenvalues + A = DomainMatrix([ + [0, 0, 0, 0, -1], + [1, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0]], (5, 5), QQ) + Avects = dom_eigenvects(A) + + # Extract the dummy to build the expected result: + lamda = Avects[1][0][1].gens[0] + irreducible = Poly(lamda**5 - lamda + 1, lamda, domain=QQ) + K = FiniteExtension(irreducible) + KK = K.from_sympy + algebraic_eigenvects = [ + (K, irreducible, 1, + DomainMatrix([ + [KK(lamda**4-1), KK(lamda**3), KK(lamda**2), KK(lamda), KK(1)] + ], (1, 5), K)), + ] + assert Avects == ([], algebraic_eigenvects) + + # Test converting to Expr (slow): + l0, l1, l2, l3, l4 = [CRootOf(lamda**5 - lamda + 1, i) for i in range(5)] + sympy_eigenvects = [ + (l0, 1, [Matrix([-1 + l0**4, l0**3, l0**2, l0, 1])]), + (l1, 1, [Matrix([-1 + l1**4, l1**3, l1**2, l1, 1])]), + (l2, 1, [Matrix([-1 + l2**4, l2**3, l2**2, l2, 1])]), + (l3, 1, [Matrix([-1 + l3**4, l3**3, l3**2, l3, 1])]), + (l4, 1, [Matrix([-1 + l4**4, l4**3, l4**2, l4, 1])]), + ] + assert dom_eigenvects_to_sympy([], algebraic_eigenvects, Matrix) == sympy_eigenvects diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_fflu.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_fflu.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4676ce0c3ee2d495b7011ddc48db8c8c40648b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_fflu.py @@ -0,0 +1,301 @@ +from sympy.polys.matrices import DomainMatrix, DM +from sympy.polys.domains import ZZ, QQ +from sympy import Matrix +import pytest + + +FFLU_EXAMPLES = [ + ( + 'zz_2x3', + DM([[1, 2, 3], [4, 5, 6]], ZZ), + DM([[1, 0], [0, 1]], ZZ), + DM([[1, 0], [4, -3]], ZZ), + DM([[1, 0], [0, -3]], ZZ), + DM([[1, 2, 3], [0, -3, -6]], ZZ), + ), + + ( + 'zz_2x2', + DM([[4, 3], [6, 3]], ZZ), + DM([[1, 0], [0, 1]], ZZ), + DM([[1, 0], [6, -6]], ZZ), + DM([[4, 0], [0, -3]], ZZ), + DM([[4, 3], [0, -3]], ZZ), + ), + + ( + 'zz_3x2', + DM([[1, 2], [3, 4], [5, 6]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[1, 0, 0], [3, 1, 0], [5, 2, 1]], ZZ), + DM([[1, 0], [0, -2]], ZZ), + DM([[1, 2], [0, -2], [0, 0]], ZZ), + ), + + ( + 'zz_3x3', + DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[1, 0, 0], [4, 1, 0], [7, 2, 1]], ZZ), + DM([[1, 0, 0], [0, -3, 0], [0, 0, 0]], ZZ), + DM([[1, 2, 3], [0, -3, -6], [0, 0, 0]], ZZ), + ), + + ( + 'zz_zero', + DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ), + DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ), + ), + + ( + 'zz_empty', + DM([], ZZ), + DM([], ZZ), + DM([], ZZ), + DM([], ZZ), + DM([], ZZ), + ), + + ( + 'zz_empty_0x2', + DomainMatrix([], (0, 2), ZZ), + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 2), ZZ) + ), + + ( + + 'zz_empty_2x0', + DomainMatrix([[], []], (2, 0), ZZ), + DomainMatrix.eye((2, 2), ZZ), + DomainMatrix.eye((2, 2), ZZ), + DomainMatrix.eye((2, 2), ZZ), + DomainMatrix([[], []], (2, 0), ZZ) + + ), + + ( + 'zz_negative', + DM([[-1, -2], [-3, -4]], ZZ), + DM([[1, 0], [0, 1]], ZZ), + DM([[-1, 0], [-3, -2]], ZZ), + DM([[-1, 0], [0, 2]], ZZ), + DM([[-1, -2], [0, -2]], ZZ), + ), + + ( + 'zz_mixed_signs', + DM([[1, -2], [-3, 4]], ZZ), + DM([[1, 0], [0, 1]], ZZ), + DM([[1, 0], [-3, 1]], ZZ), + DM([[1, 0], [0, -2]], ZZ), + DM([[1, -2], [0, -2]], ZZ), + ), + + ( + 'zz_upper_triangular', + DM([[1, 2, 3], [0, 4, 5], [0, 0, 6]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[1, 0, 0], [0, 4, 0], [0, 0, 24]], ZZ), + DM([[1, 0, 0], [0, 4, 0], [0, 0, 96]], ZZ), + DM([[1, 2, 3], [0, 4, 5], [0, 0, 24]], ZZ), + ), + + ( + 'zz_lower_triangular', + DM([[1, 0, 0], [2, 3, 0], [4, 5, 6]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[1, 0, 0], [2, 3, 0], [4, 5, 18]], ZZ), + DM([[1, 0, 0], [0, 3, 0], [0, 0, 54]], ZZ), + DM([[1, 0, 0], [0, 3, 0], [0, 0, 18]], ZZ), + ), + + ( + 'zz_diagonal', + DM([[2, 0, 0], [0, 3, 0], [0, 0, 4]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[2, 0, 0], [0, 6, 0], [0, 0, 24]], ZZ), + DM([[2, 0, 0], [0, 12, 0], [0, 0, 144]], ZZ), + DM([[2, 0, 0], [0, 6, 0], [0, 0, 24]], ZZ) + + ), + + ( + 'rank_deficient_3x3', + DM([[1, 2, 3], [2, 4, 6], [3, 6, 9]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[1, 0, 0], [2, 1, 0], [3, 0, 1]], ZZ), + DM([[1, 0, 0], [0, 0, 0], [0, 0, 0]], ZZ), + DM([[1, 2, 3], [0, 0, 0], [0, 0, 0]], ZZ), + ), + + ( + 'zz_1x1', + DM([[5]], ZZ), + DM([[1]], ZZ), + DM([[5]], ZZ), + DM([[5]], ZZ), + DM([[5]], ZZ), + ), + + ( + 'zz_nx1_2rows', + DM([[81], [54]], ZZ), + DM([[1, 0], [0, 1]], ZZ), + DM([[81, 0], [54, 81]], ZZ), + DM([[81, 0], [0, 81]], ZZ), + DM([[81], [0]], ZZ), + ), + + ( + 'zz_nx2_3rows', + DM([[2, 7], [7, 45], [25, 84]], ZZ), + DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ZZ), + DM([[2, 0, 0], [7, 82, 0], [25, 41, 41]], ZZ), + DM([[2, 0, 0], [0, 82, 0], [0, 0, 41]], ZZ), + DM([[2, 7], [0, 82], [0, 0]], ZZ), + ), + + ( + + 'zz_1x2', + DM([[0, 28]], ZZ), + DM([[1]], ZZ), + DM([[28]], ZZ), + DM([[28]], ZZ), + DM([[0, 28]], ZZ) + ), + + ( + 'zz_nx3_4rows', + DM([[84, 30, 9], [20, 59, 13], [53, 46, 81], [63, 48, 29]], ZZ), + DM([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], ZZ), + DM([[84, 0, 0, 0], [20, 365904, 0, 0], [53, 303411, 303411, 0], [63, 303411, 303411, 303411]], ZZ), + DM([[84, 0, 0, 0], [0, 365904, 0, 0], [0, 0, 1321658316, 0], [0, 0, 0, 303411]], ZZ), + DM([[84, 30, 9], [0, 365904, 13], [0, 0, 1321658316], [0, 0, 0]], ZZ), + ), + + ( + 'fflu_row_swap', + DM([[0, 1, 2], [3, 4, 5], [6, 7, 8]], ZZ), + DM([[0, 1, 0], [1, 0, 0], [0, 0, 1]], ZZ), + DM([[3, 0, 0], [0, 3, 0], [6, -3, 1]], ZZ), + DM([[3, 0, 0], [0, 9, 0], [0, 0, 3]], ZZ), + DM([[3, 4, 5], [0, 3, 6], [0, 0, 0]], ZZ) + ), +] + + +def _check_fflu(A, P, L, D, U): + P_field = P.to_field().to_dense() + L_field = L.to_field().to_dense() + D_field = D.to_field().to_dense() + U_field = U.to_field().to_dense() + m, n = A.shape + assert P_field.shape == (m, m) + assert L_field.shape == (m, m) + assert D_field.shape == (m, m) + assert U_field.shape == (m, n) + assert L_field.is_lower + assert D_field.is_diagonal + di, d = D.inv_den() + assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U) + assert U_field.is_upper + + +def _to_DM(A, ans): + if isinstance(A, DomainMatrix): + return A + elif isinstance(A, Matrix): + return A.to_DM(ans.domain) + return DomainMatrix(A.to_list(), A.shape, A.domain) + + +def _check_fflu_result(result, A, P_ans, L_ans, D_ans, U_ans): + P, L, D, U = result + P = _to_DM(P, P_ans) + L = _to_DM(L, L_ans) + D = _to_DM(D, D_ans) + U = _to_DM(U, U_ans) + A = _to_DM(A, P_ans) + m, n = A.shape + assert P.shape == (m, m) + assert L.shape == (m, m) + assert D.shape == (m, m) + assert U.shape == (m, n) + assert L.is_lower + assert D.is_diagonal + di, d = D.inv_den() + assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U) + assert U.is_upper + + +@pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES) +def test_dm_dense_fflu(name, A, P_ans, L_ans, D_ans, U_ans): + A = A.to_dense() + _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans) + + +@pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES) +def test_dm_sparse_fflu(name, A, P_ans, L_ans, D_ans, U_ans): + A = A.to_sparse() + _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans) + + +@pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES) +def test_ddm_fflu(name, A, P_ans, L_ans, D_ans, U_ans): + A = A.to_ddm() + _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans) + + +@pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES) +def test_sdm_fflu(name, A, P_ans, L_ans, D_ans, U_ans): + A = A.to_sdm() + _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans) + + +@pytest.mark.parametrize('name, A, P_ans, L_ans, D_ans, U_ans', FFLU_EXAMPLES) +def test_dfm_fflu(name, A, P_ans, L_ans, D_ans, U_ans): + pytest.importorskip('flint') + if A.domain not in (ZZ, QQ) and not A.domain.is_FF: + pytest.skip("Domain not supported by DFM") + A = A.to_dfm() + _check_fflu_result(A.fflu(), A, P_ans, L_ans, D_ans, U_ans) + + +def test_fflu_empty_matrix(): + A = DomainMatrix([], (0, 0), ZZ) + P, L, D, U = A.fflu() + assert P.shape == (0, 0) + assert L.shape == (0, 0) + assert D.shape == (0, 0) + assert U.shape == (0, 0) + + +def test_fflu_properties(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + P, L, D, U = A.fflu() + assert P.shape == (2, 2) + assert L.shape == (2, 2) + assert D.shape == (2, 2) + assert U.shape == (2, 2) + assert L.is_lower + assert U.is_upper + assert D.is_diagonal + di, d = D.inv_den() + assert P.matmul(A).rmul(d) == L.matmul(di).matmul(U) + + +def test_fflu_rank_deficient(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(2), ZZ(4)]], (2, 2), ZZ) + P, L, D, U = A.fflu() + assert P.shape == (2, 2) + assert L.shape == (2, 2) + assert D.shape == (2, 2) + assert U.shape == (2, 2) + assert U.getitem_sympy(1, 1) == 0 diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_inverse.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..47c82799324518bd7d1cc2405ade0aa0a5a4f6e9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_inverse.py @@ -0,0 +1,193 @@ +from sympy import ZZ, Matrix +from sympy.polys.matrices import DM, DomainMatrix +from sympy.polys.matrices.dense import ddm_iinv +from sympy.polys.matrices.exceptions import DMNonInvertibleMatrixError +from sympy.matrices.exceptions import NonInvertibleMatrixError + +import pytest +from sympy.testing.pytest import raises +from sympy.core.numbers import all_close + +from sympy.abc import x + + +# Examples are given as adjugate matrix and determinant adj_det should match +# these exactly but inv_den only matches after cancel_denom. + + +INVERSE_EXAMPLES = [ + + ( + 'zz_1', + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 0), ZZ), + ZZ(1), + ), + + ( + 'zz_2', + DM([[2]], ZZ), + DM([[1]], ZZ), + ZZ(2), + ), + + ( + 'zz_3', + DM([[2, 0], + [0, 2]], ZZ), + DM([[2, 0], + [0, 2]], ZZ), + ZZ(4), + ), + + ( + 'zz_4', + DM([[1, 2], + [3, 4]], ZZ), + DM([[ 4, -2], + [-3, 1]], ZZ), + ZZ(-2), + ), + + ( + 'zz_5', + DM([[2, 2, 0], + [0, 2, 2], + [0, 0, 2]], ZZ), + DM([[4, -4, 4], + [0, 4, -4], + [0, 0, 4]], ZZ), + ZZ(8), + ), + + ( + 'zz_6', + DM([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], ZZ), + DM([[-3, 6, -3], + [ 6, -12, 6], + [-3, 6, -3]], ZZ), + ZZ(0), + ), +] + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_Matrix_inv(name, A, A_inv, den): + + def _check(**kwargs): + if den != 0: + assert A.inv(**kwargs) == A_inv + else: + raises(NonInvertibleMatrixError, lambda: A.inv(**kwargs)) + + K = A.domain + A = A.to_Matrix() + A_inv = A_inv.to_Matrix() / K.to_sympy(den) + _check() + for method in ['GE', 'LU', 'ADJ', 'CH', 'LDL', 'QR']: + _check(method=method) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dm_inv_den(name, A, A_inv, den): + if den != 0: + A_inv_f, den_f = A.inv_den() + assert A_inv_f.cancel_denom(den_f) == A_inv.cancel_denom(den) + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv_den()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dm_inv(name, A, A_inv, den): + A = A.to_field() + if den != 0: + A_inv = A_inv.to_field() / den + assert A.inv() == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_ddm_inv(name, A, A_inv, den): + A = A.to_field().to_ddm() + if den != 0: + A_inv = (A_inv.to_field() / den).to_ddm() + assert A.inv() == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_sdm_inv(name, A, A_inv, den): + A = A.to_field().to_sdm() + if den != 0: + A_inv = (A_inv.to_field() / den).to_sdm() + assert A.inv() == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dense_ddm_iinv(name, A, A_inv, den): + A = A.to_field().to_ddm().copy() + K = A.domain + A_result = A.copy() + if den != 0: + A_inv = (A_inv.to_field() / den).to_ddm() + ddm_iinv(A_result, A, K) + assert A_result == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: ddm_iinv(A_result, A, K)) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_Matrix_adjugate(name, A, A_inv, den): + A = A.to_Matrix() + A_inv = A_inv.to_Matrix() + assert A.adjugate() == A_inv + for method in ["bareiss", "berkowitz", "bird", "laplace", "lu"]: + assert A.adjugate(method=method) == A_inv + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dm_adj_det(name, A, A_inv, den): + assert A.adj_det() == (A_inv, den) + + +def test_inverse_inexact(): + + M = Matrix([[x-0.3, -0.06, -0.22], + [-0.46, x-0.48, -0.41], + [-0.14, -0.39, x-0.64]]) + + Mn = Matrix([[1.0*x**2 - 1.12*x + 0.1473, 0.06*x + 0.0474, 0.22*x - 0.081], + [0.46*x - 0.237, 1.0*x**2 - 0.94*x + 0.1612, 0.41*x - 0.0218], + [0.14*x + 0.1122, 0.39*x - 0.1086, 1.0*x**2 - 0.78*x + 0.1164]]) + + d = 1.0*x**3 - 1.42*x**2 + 0.4249*x - 0.0546540000000002 + + Mi = Mn / d + + M_dm = M.to_DM() + M_dmd = M_dm.to_dense() + M_dm_num, M_dm_den = M_dm.inv_den() + M_dmd_num, M_dmd_den = M_dmd.inv_den() + + # XXX: We don't check M_dm().to_field().inv() which currently uses division + # and produces a more complicate result from gcd cancellation failing. + # DomainMatrix.inv() over RR(x) should be changed to clear denominators and + # use DomainMatrix.inv_den(). + + Minvs = [ + M.inv(), + (M_dm_num.to_field() / M_dm_den).to_Matrix(), + (M_dmd_num.to_field() / M_dmd_den).to_Matrix(), + M_dm_num.to_Matrix() / M_dm_den.as_expr(), + M_dmd_num.to_Matrix() / M_dmd_den.as_expr(), + ] + + for Minv in Minvs: + for Mi1, Mi2 in zip(Minv.flat(), Mi.flat()): + assert all_close(Mi2, Mi1) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_linsolve.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_linsolve.py new file mode 100644 index 0000000000000000000000000000000000000000..25300ef2cb4792e4424c9c15c0bbbc313ce062e6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_linsolve.py @@ -0,0 +1,112 @@ +# +# test_linsolve.py +# +# Test the internal implementation of linsolve. +# + +from sympy.testing.pytest import raises + +from sympy.core.numbers import I +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.abc import x, y, z + +from sympy.polys.matrices.linsolve import _linsolve +from sympy.polys.solvers import PolyNonlinearError + + +def test__linsolve(): + assert _linsolve([], [x]) == {x:x} + assert _linsolve([S.Zero], [x]) == {x:x} + assert _linsolve([x-1,x-2], [x]) is None + assert _linsolve([x-1], [x]) == {x:1} + assert _linsolve([x-1, y], [x, y]) == {x:1, y:S.Zero} + assert _linsolve([2*I], [x]) is None + raises(PolyNonlinearError, lambda: _linsolve([x*(1 + x)], [x])) + + +def test__linsolve_float(): + + # This should give the exact answer: + eqs = [ + y - x, + y - 0.0216 * x + ] + # Should _linsolve return floats here? + sol = {x:0, y:0} + assert _linsolve(eqs, (x, y)) == sol + + # Other cases should be close to eps + + def all_close(sol1, sol2, eps=1e-15): + close = lambda a, b: abs(a - b) < eps + assert sol1.keys() == sol2.keys() + return all(close(sol1[s], sol2[s]) for s in sol1) + + eqs = [ + 0.8*x + 0.8*z + 0.2, + 0.9*x + 0.7*y + 0.2*z + 0.9, + 0.7*x + 0.2*y + 0.2*z + 0.5 + ] + sol_exact = {x:-29/42, y:-11/21, z:37/84} + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + eqs = [ + 0.9*x + 0.3*y + 0.4*z + 0.6, + 0.6*x + 0.9*y + 0.1*z + 0.7, + 0.4*x + 0.6*y + 0.9*z + 0.5 + ] + sol_exact = {x:-88/175, y:-46/105, z:-1/25} + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + eqs = [ + 0.4*x + 0.3*y + 0.6*z + 0.7, + 0.4*x + 0.3*y + 0.9*z + 0.9, + 0.7*x + 0.9*y, + ] + sol_exact = {x:-9/5, y:7/5, z:-2/3} + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + eqs = [ + x*(0.7 + 0.6*I) + y*(0.4 + 0.7*I) + z*(0.9 + 0.1*I) + 0.5, + 0.2*I*x + 0.2*I*y + z*(0.9 + 0.2*I) + 0.1, + x*(0.9 + 0.7*I) + y*(0.9 + 0.7*I) + z*(0.9 + 0.4*I) + 0.4, + ] + sol_exact = { + x:-6157/7995 - 411/5330*I, + y:8519/15990 + 1784/7995*I, + z:-34/533 + 107/1599*I, + } + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + # XXX: This system for x and y over RR(z) is problematic. + # + # eqs = [ + # x*(0.2*z + 0.9) + y*(0.5*z + 0.8) + 0.6, + # 0.1*x*z + y*(0.1*z + 0.6) + 0.9, + # ] + # + # linsolve(eqs, [x, y]) + # The solution for x comes out as + # + # -3.9e-5*z**2 - 3.6e-5*z - 8.67361737988404e-20 + # x = ---------------------------------------------- + # 3.0e-6*z**3 - 1.3e-5*z**2 - 5.4e-5*z + # + # The 8e-20 in the numerator should be zero which would allow z to cancel + # from top and bottom. It should be possible to avoid this somehow because + # the inverse of the matrix only has a quadratic factor (the determinant) + # in the denominator. + + +def test__linsolve_deprecated(): + raises(PolyNonlinearError, lambda: + _linsolve([Eq(x**2, x**2 + y)], [x, y])) + raises(PolyNonlinearError, lambda: + _linsolve([(x + y)**2 - x**2], [x])) + raises(PolyNonlinearError, lambda: + _linsolve([Eq((x + y)**2, x**2)], [x])) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_lll.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_lll.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf91a00703532f02d763656d6117018fbc496cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_lll.py @@ -0,0 +1,145 @@ +from sympy.polys.domains import ZZ, QQ +from sympy.polys.matrices import DM +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.matrices.exceptions import DMRankError, DMValueError, DMShapeError, DMDomainError +from sympy.polys.matrices.lll import _ddm_lll, ddm_lll, ddm_lll_transform +from sympy.testing.pytest import raises + + +def test_lll(): + normal_test_data = [ + ( + DM([[1, 0, 0, 0, -20160], + [0, 1, 0, 0, 33768], + [0, 0, 1, 0, 39578], + [0, 0, 0, 1, 47757]], ZZ), + DM([[10, -3, -2, 8, -4], + [3, -9, 8, 1, -11], + [-3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]], ZZ) + ), + ( + DM([[20, 52, 3456], + [14, 31, -1], + [34, -442, 0]], ZZ), + DM([[14, 31, -1], + [188, -101, -11], + [236, 13, 3443]], ZZ) + ), + ( + DM([[34, -1, -86, 12], + [-54, 34, 55, 678], + [23, 3498, 234, 6783], + [87, 49, 665, 11]], ZZ), + DM([[34, -1, -86, 12], + [291, 43, 149, 83], + [-54, 34, 55, 678], + [-189, 3077, -184, -223]], ZZ) + ) + ] + delta = QQ(5, 6) + for basis_dm, reduced_dm in normal_test_data: + reduced = _ddm_lll(basis_dm.rep.to_ddm(), delta=delta)[0] + assert reduced == reduced_dm.rep.to_ddm() + + reduced = ddm_lll(basis_dm.rep.to_ddm(), delta=delta) + assert reduced == reduced_dm.rep.to_ddm() + + reduced, transform = _ddm_lll(basis_dm.rep.to_ddm(), delta=delta, return_transform=True) + assert reduced == reduced_dm.rep.to_ddm() + assert transform.matmul(basis_dm.rep.to_ddm()) == reduced_dm.rep.to_ddm() + + reduced, transform = ddm_lll_transform(basis_dm.rep.to_ddm(), delta=delta) + assert reduced == reduced_dm.rep.to_ddm() + assert transform.matmul(basis_dm.rep.to_ddm()) == reduced_dm.rep.to_ddm() + + reduced = basis_dm.rep.lll(delta=delta) + assert reduced == reduced_dm.rep + + reduced, transform = basis_dm.rep.lll_transform(delta=delta) + assert reduced == reduced_dm.rep + assert transform.matmul(basis_dm.rep) == reduced_dm.rep + + reduced = basis_dm.rep.to_sdm().lll(delta=delta) + assert reduced == reduced_dm.rep.to_sdm() + + reduced, transform = basis_dm.rep.to_sdm().lll_transform(delta=delta) + assert reduced == reduced_dm.rep.to_sdm() + assert transform.matmul(basis_dm.rep.to_sdm()) == reduced_dm.rep.to_sdm() + + reduced = basis_dm.lll(delta=delta) + assert reduced == reduced_dm + + reduced, transform = basis_dm.lll_transform(delta=delta) + assert reduced == reduced_dm + assert transform.matmul(basis_dm) == reduced_dm + + +def test_lll_linear_dependent(): + linear_dependent_test_data = [ + DM([[0, -1, -2, -3], + [1, 0, -1, -2], + [2, 1, 0, -1], + [3, 2, 1, 0]], ZZ), + DM([[1, 0, 0, 1], + [0, 1, 0, 1], + [0, 0, 1, 1], + [1, 2, 3, 6]], ZZ), + DM([[3, -5, 1], + [4, 6, 0], + [10, -4, 2]], ZZ) + ] + for not_basis in linear_dependent_test_data: + raises(DMRankError, lambda: _ddm_lll(not_basis.rep.to_ddm())) + raises(DMRankError, lambda: ddm_lll(not_basis.rep.to_ddm())) + raises(DMRankError, lambda: not_basis.rep.lll()) + raises(DMRankError, lambda: not_basis.rep.to_sdm().lll()) + raises(DMRankError, lambda: not_basis.lll()) + raises(DMRankError, lambda: _ddm_lll(not_basis.rep.to_ddm(), return_transform=True)) + raises(DMRankError, lambda: ddm_lll_transform(not_basis.rep.to_ddm())) + raises(DMRankError, lambda: not_basis.rep.lll_transform()) + raises(DMRankError, lambda: not_basis.rep.to_sdm().lll_transform()) + raises(DMRankError, lambda: not_basis.lll_transform()) + + +def test_lll_wrong_delta(): + dummy_matrix = DomainMatrix.ones((3, 3), ZZ) + for wrong_delta in [QQ(-1, 4), QQ(0, 1), QQ(1, 4), QQ(1, 1), QQ(100, 1)]: + raises(DMValueError, lambda: _ddm_lll(dummy_matrix.rep, delta=wrong_delta)) + raises(DMValueError, lambda: ddm_lll(dummy_matrix.rep, delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.lll(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.to_sdm().lll(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.lll(delta=wrong_delta)) + raises(DMValueError, lambda: _ddm_lll(dummy_matrix.rep, delta=wrong_delta, return_transform=True)) + raises(DMValueError, lambda: ddm_lll_transform(dummy_matrix.rep, delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.lll_transform(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.to_sdm().lll_transform(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.lll_transform(delta=wrong_delta)) + + +def test_lll_wrong_shape(): + wrong_shape_matrix = DomainMatrix.ones((4, 3), ZZ) + raises(DMShapeError, lambda: _ddm_lll(wrong_shape_matrix.rep)) + raises(DMShapeError, lambda: ddm_lll(wrong_shape_matrix.rep)) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.lll()) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.to_sdm().lll()) + raises(DMShapeError, lambda: wrong_shape_matrix.lll()) + raises(DMShapeError, lambda: _ddm_lll(wrong_shape_matrix.rep, return_transform=True)) + raises(DMShapeError, lambda: ddm_lll_transform(wrong_shape_matrix.rep)) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.lll_transform()) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.to_sdm().lll_transform()) + raises(DMShapeError, lambda: wrong_shape_matrix.lll_transform()) + + +def test_lll_wrong_domain(): + wrong_domain_matrix = DomainMatrix.ones((3, 3), QQ) + raises(DMDomainError, lambda: _ddm_lll(wrong_domain_matrix.rep)) + raises(DMDomainError, lambda: ddm_lll(wrong_domain_matrix.rep)) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.lll()) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.to_sdm().lll()) + raises(DMDomainError, lambda: wrong_domain_matrix.lll()) + raises(DMDomainError, lambda: _ddm_lll(wrong_domain_matrix.rep, return_transform=True)) + raises(DMDomainError, lambda: ddm_lll_transform(wrong_domain_matrix.rep)) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.lll_transform()) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.to_sdm().lll_transform()) + raises(DMDomainError, lambda: wrong_domain_matrix.lll_transform()) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_normalforms.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..542d9064aea204759158578a4bfbbf5acbb06db3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_normalforms.py @@ -0,0 +1,156 @@ +from sympy.testing.pytest import raises + +from sympy.core.symbol import Symbol +from sympy.polys.matrices.normalforms import ( + invariant_factors, + smith_normal_form, + smith_normal_decomp, + is_smith_normal_form, + hermite_normal_form, + _hermite_normal_form, + _hermite_normal_form_modulo_D +) +from sympy.polys.domains import ZZ, QQ +from sympy.polys.matrices import DomainMatrix, DM +from sympy.polys.matrices.exceptions import DMDomainError, DMShapeError + + +def test_is_smith_normal_form(): + + snf_examples = [ + DM([[0, 0], [0, 0]], ZZ), + DM([[1, 0], [0, 0]], ZZ), + DM([[1, 0], [0, 1]], ZZ), + DM([[1, 0], [0, 2]], ZZ), + ] + + non_snf_examples = [ + DM([[0, 1], [0, 0]], ZZ), + DM([[0, 0], [0, 1]], ZZ), + DM([[2, 0], [0, 3]], ZZ), + ] + + for m in snf_examples: + assert is_smith_normal_form(m) is True + + for m in non_snf_examples: + assert is_smith_normal_form(m) is False + + +def test_smith_normal(): + + m = DM([ + [12, 6, 4, 8], + [3, 9, 6, 12], + [2, 16, 14, 28], + [20, 10, 10, 20]], ZZ) + + smf = DM([ + [1, 0, 0, 0], + [0, 10, 0, 0], + [0, 0, 30, 0], + [0, 0, 0, 0]], ZZ) + + s = DM([ + [0, 1, -1, 0], + [1, -4, 0, 0], + [0, -2, 3, 0], + [-2, 2, -1, 1]], ZZ) + + t = DM([ + [1, 1, 10, 0], + [0, -1, -2, 0], + [0, 1, 3, -2], + [0, 0, 0, 1]], ZZ) + + assert smith_normal_form(m).to_dense() == smf + assert smith_normal_decomp(m) == (smf, s, t) + assert is_smith_normal_form(smf) + assert smf == s * m * t + + m00 = DomainMatrix.zeros((0, 0), ZZ).to_dense() + m01 = DomainMatrix.zeros((0, 1), ZZ).to_dense() + m10 = DomainMatrix.zeros((1, 0), ZZ).to_dense() + i11 = DM([[1]], ZZ) + + assert smith_normal_form(m00) == m00.to_sparse() + assert smith_normal_form(m01) == m01.to_sparse() + assert smith_normal_form(m10) == m10.to_sparse() + assert smith_normal_form(i11) == i11.to_sparse() + + assert smith_normal_decomp(m00) == (m00, m00, m00) + assert smith_normal_decomp(m01) == (m01, m00, i11) + assert smith_normal_decomp(m10) == (m10, i11, m00) + assert smith_normal_decomp(i11) == (i11, i11, i11) + + x = Symbol('x') + m = DM([[x-1, 1, -1], + [ 0, x, -1], + [ 0, -1, x]], QQ[x]) + dx = m.domain.gens[0] + assert invariant_factors(m) == (1, dx-1, dx**2-1) + + zr = DomainMatrix([], (0, 2), ZZ) + zc = DomainMatrix([[], []], (2, 0), ZZ) + assert smith_normal_form(zr).to_dense() == zr + assert smith_normal_form(zc).to_dense() == zc + + assert smith_normal_form(DM([[2, 4]], ZZ)).to_dense() == DM([[2, 0]], ZZ) + assert smith_normal_form(DM([[0, -2]], ZZ)).to_dense() == DM([[2, 0]], ZZ) + assert smith_normal_form(DM([[0], [-2]], ZZ)).to_dense() == DM([[2], [0]], ZZ) + + assert smith_normal_decomp(DM([[0, -2]], ZZ)) == ( + DM([[2, 0]], ZZ), DM([[-1]], ZZ), DM([[0, 1], [1, 0]], ZZ) + ) + assert smith_normal_decomp(DM([[0], [-2]], ZZ)) == ( + DM([[2], [0]], ZZ), DM([[0, -1], [1, 0]], ZZ), DM([[1]], ZZ) + ) + + m = DM([[3, 0, 0, 0], [0, 0, 0, 0], [0, 0, 2, 0]], ZZ) + snf = DM([[1, 0, 0, 0], [0, 6, 0, 0], [0, 0, 0, 0]], ZZ) + s = DM([[1, 0, 1], [2, 0, 3], [0, 1, 0]], ZZ) + t = DM([[1, -2, 0, 0], [0, 0, 0, 1], [-1, 3, 0, 0], [0, 0, 1, 0]], ZZ) + + assert smith_normal_form(m).to_dense() == snf + assert smith_normal_decomp(m) == (snf, s, t) + assert is_smith_normal_form(snf) + assert snf == s * m * t + + raises(ValueError, lambda: smith_normal_form(DM([[1]], ZZ[x]))) + + +def test_hermite_normal(): + m = DM([[2, 7, 17, 29, 41], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]], ZZ) + hnf = DM([[1, 0, 0], [0, 2, 1], [0, 0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + assert hermite_normal_form(m, D=ZZ(2)) == hnf + assert hermite_normal_form(m, D=ZZ(2), check_rank=True) == hnf + + m = m.transpose() + hnf = DM([[37, 0, 19], [222, -6, 113], [48, 0, 25], [0, 2, 1], [0, 0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + raises(DMShapeError, lambda: _hermite_normal_form_modulo_D(m, ZZ(96))) + raises(DMDomainError, lambda: _hermite_normal_form_modulo_D(m, QQ(96))) + + m = DM([[8, 28, 68, 116, 164], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]], ZZ) + hnf = DM([[4, 0, 0], [0, 2, 1], [0, 0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + assert hermite_normal_form(m, D=ZZ(8)) == hnf + assert hermite_normal_form(m, D=ZZ(8), check_rank=True) == hnf + + m = DM([[10, 8, 6, 30, 2], [45, 36, 27, 18, 9], [5, 4, 3, 2, 1]], ZZ) + hnf = DM([[26, 2], [0, 9], [0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + + m = DM([[2, 7], [0, 0], [0, 0]], ZZ) + hnf = DM([[1], [0], [0]], ZZ) + assert hermite_normal_form(m) == hnf + + m = DM([[-2, 1], [0, 1]], ZZ) + hnf = DM([[2, 1], [0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + + m = DomainMatrix([[QQ(1)]], (1, 1), QQ) + raises(DMDomainError, lambda: hermite_normal_form(m)) + raises(DMDomainError, lambda: _hermite_normal_form(m)) + raises(DMDomainError, lambda: _hermite_normal_form_modulo_D(m, ZZ(1))) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_nullspace.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_nullspace.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb025b7dc9dff31bc97d86e175147ffede5a7e3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_nullspace.py @@ -0,0 +1,209 @@ +from sympy import ZZ, Matrix +from sympy.polys.matrices import DM, DomainMatrix +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.sdm import SDM + +import pytest + +zeros = lambda shape, K: DomainMatrix.zeros(shape, K).to_dense() +eye = lambda n, K: DomainMatrix.eye(n, K).to_dense() + + +# +# DomainMatrix.nullspace can have a divided answer or can return an undivided +# uncanonical answer. The uncanonical answer is not unique but we can make it +# unique by making it primitive (remove gcd). The tests here all show the +# primitive form. We test two things: +# +# A.nullspace().primitive()[1] == answer. +# A.nullspace(divide_last=True) == _divide_last(answer). +# +# The nullspace as returned by DomainMatrix and related classes is the +# transpose of the nullspace as returned by Matrix. Matrix returns a list of +# of column vectors whereas DomainMatrix returns a matrix whose rows are the +# nullspace vectors. +# + + +NULLSPACE_EXAMPLES = [ + + ( + 'zz_1', + DM([[ 1, 2, 3]], ZZ), + DM([[-2, 1, 0], + [-3, 0, 1]], ZZ), + ), + + ( + 'zz_2', + zeros((0, 0), ZZ), + zeros((0, 0), ZZ), + ), + + ( + 'zz_3', + zeros((2, 0), ZZ), + zeros((0, 0), ZZ), + ), + + ( + 'zz_4', + zeros((0, 2), ZZ), + eye(2, ZZ), + ), + + ( + 'zz_5', + zeros((2, 2), ZZ), + eye(2, ZZ), + ), + + ( + 'zz_6', + DM([[1, 2], + [3, 4]], ZZ), + zeros((0, 2), ZZ), + ), + + ( + 'zz_7', + DM([[1, 1], + [1, 1]], ZZ), + DM([[-1, 1]], ZZ), + ), + + ( + 'zz_8', + DM([[1], + [1]], ZZ), + zeros((0, 1), ZZ), + ), + + ( + 'zz_9', + DM([[1, 1]], ZZ), + DM([[-1, 1]], ZZ), + ), + + ( + 'zz_10', + DM([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1]], ZZ), + DM([[ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [ 0, -1, 0, 0, 0, 0, 0, 1, 0, 0], + [ 0, 0, 0, -1, 0, 0, 0, 0, 1, 0], + [ 0, 0, 0, 0, -1, 0, 0, 0, 0, 1]], ZZ), + ), + +] + + +def _to_DM(A, ans): + """Convert the answer to DomainMatrix.""" + if isinstance(A, DomainMatrix): + return A.to_dense() + elif isinstance(A, DDM): + return DomainMatrix(list(A), A.shape, A.domain).to_dense() + elif isinstance(A, SDM): + return DomainMatrix(dict(A), A.shape, A.domain).to_dense() + else: + assert False # pragma: no cover + + +def _divide_last(null): + """Normalize the nullspace by the rightmost non-zero entry.""" + null = null.to_field() + + if null.is_zero_matrix: + return null + + rows = [] + for i in range(null.shape[0]): + for j in reversed(range(null.shape[1])): + if null[i, j]: + rows.append(null[i, :] / null[i, j]) + break + else: + assert False # pragma: no cover + + return DomainMatrix.vstack(*rows) + + +def _check_primitive(null, null_ans): + """Check that the primitive of the answer matches.""" + null = _to_DM(null, null_ans) + cont, null_prim = null.primitive() + assert null_prim == null_ans + + +def _check_divided(null, null_ans): + """Check the divided answer.""" + null = _to_DM(null, null_ans) + null_ans_norm = _divide_last(null_ans) + assert null == null_ans_norm + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_Matrix_nullspace(name, A, A_null): + A = A.to_Matrix() + + A_null_cols = A.nullspace() + + # We have to patch up the case where the nullspace is empty + if A_null_cols: + A_null_found = Matrix.hstack(*A_null_cols) + else: + A_null_found = Matrix.zeros(A.cols, 0) + + A_null_found = A_null_found.to_DM().to_field().to_dense() + + # The Matrix result is the transpose of DomainMatrix result. + A_null_found = A_null_found.transpose() + + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_dense_nullspace(name, A, A_null): + A = A.to_field().to_dense() + A_null_found = A.nullspace(divide_last=True) + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_sparse_nullspace(name, A, A_null): + A = A.to_field().to_sparse() + A_null_found = A.nullspace(divide_last=True) + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_ddm_nullspace(name, A, A_null): + A = A.to_field().to_ddm() + A_null_found, _ = A.nullspace() + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_sdm_nullspace(name, A, A_null): + A = A.to_field().to_sdm() + A_null_found, _ = A.nullspace() + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_dense_nullspace_fracfree(name, A, A_null): + A = A.to_dense() + A_null_found = A.nullspace() + _check_primitive(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_sparse_nullspace_fracfree(name, A, A_null): + A = A.to_sparse() + A_null_found = A.nullspace() + _check_primitive(A_null_found, A_null) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_rref.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_rref.py new file mode 100644 index 0000000000000000000000000000000000000000..49def18c8132c0537540163a96bf6cf323c5a85c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_rref.py @@ -0,0 +1,737 @@ +from sympy import ZZ, QQ, ZZ_I, EX, Matrix, eye, zeros, symbols +from sympy.polys.matrices import DM, DomainMatrix +from sympy.polys.matrices.dense import ddm_irref_den, ddm_irref +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.sdm import SDM, sdm_irref, sdm_rref_den + +import pytest + + +# +# The dense and sparse implementations of rref_den are ddm_irref_den and +# sdm_irref_den. These can give results that differ by some factor and also +# give different results if the order of the rows is changed. The tests below +# show all results on lowest terms as should be returned by cancel_denom. +# +# The EX domain is also a case where the dense and sparse implementations +# can give results in different forms: the results should be equivalent but +# are not canonical because EX does not have a canonical form. +# + + +a, b, c, d = symbols('a, b, c, d') + + +qq_large_1 = DM([ +[ (1,2), (1,3), (1,5), (1,7), (1,11), (1,13), (1,17), (1,19), (1,23), (1,29), (1,31)], +[ (1,37), (1,41), (1,43), (1,47), (1,53), (1,59), (1,61), (1,67), (1,71), (1,73), (1,79)], +[ (1,83), (1,89), (1,97),(1,101),(1,103),(1,107),(1,109),(1,113),(1,127),(1,131),(1,137)], +[(1,139),(1,149),(1,151),(1,157),(1,163),(1,167),(1,173),(1,179),(1,181),(1,191),(1,193)], +[(1,197),(1,199),(1,211),(1,223),(1,227),(1,229),(1,233),(1,239),(1,241),(1,251),(1,257)], +[(1,263),(1,269),(1,271),(1,277),(1,281),(1,283),(1,293),(1,307),(1,311),(1,313),(1,317)], +[(1,331),(1,337),(1,347),(1,349),(1,353),(1,359),(1,367),(1,373),(1,379),(1,383),(1,389)], +[(1,397),(1,401),(1,409),(1,419),(1,421),(1,431),(1,433),(1,439),(1,443),(1,449),(1,457)], +[(1,461),(1,463),(1,467),(1,479),(1,487),(1,491),(1,499),(1,503),(1,509),(1,521),(1,523)], +[(1,541),(1,547),(1,557),(1,563),(1,569),(1,571),(1,577),(1,587),(1,593),(1,599),(1,601)], +[(1,607),(1,613),(1,617),(1,619),(1,631),(1,641),(1,643),(1,647),(1,653),(1,659),(1,661)]], + QQ) + +qq_large_2 = qq_large_1 + 10**100 * DomainMatrix.eye(11, QQ) + + +RREF_EXAMPLES = [ + ( + 'zz_1', + DM([[1, 2, 3]], ZZ), + DM([[1, 2, 3]], ZZ), + ZZ(1), + ), + + ( + 'zz_2', + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 0), ZZ), + ZZ(1), + ), + + ( + 'zz_3', + DM([[1, 2], + [3, 4]], ZZ), + DM([[1, 0], + [0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_4', + DM([[1, 0], + [3, 4]], ZZ), + DM([[1, 0], + [0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_5', + DM([[0, 2], + [3, 4]], ZZ), + DM([[1, 0], + [0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_6', + DM([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], ZZ), + DM([[1, 0, -1], + [0, 1, 2], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_7', + DM([[0, 0, 0], + [0, 0, 0], + [1, 0, 0]], ZZ), + DM([[1, 0, 0], + [0, 0, 0], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_8', + DM([[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], ZZ), + DM([[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_9', + DM([[1, 1, 0], + [0, 0, 2], + [0, 0, 0]], ZZ), + DM([[1, 1, 0], + [0, 0, 1], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_10', + DM([[2, 2, 0], + [0, 0, 2], + [0, 0, 0]], ZZ), + DM([[1, 1, 0], + [0, 0, 1], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_11', + DM([[2, 2, 0], + [0, 2, 2], + [0, 0, 2]], ZZ), + DM([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_12', + DM([[ 1, 2, 3], + [ 4, 5, 6], + [ 7, 8, 9], + [10, 11, 12]], ZZ), + DM([[1, 0, -1], + [0, 1, 2], + [0, 0, 0], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_13', + DM([[ 1, 2, 3], + [ 4, 5, 6], + [ 7, 8, 9], + [10, 11, 13]], ZZ), + DM([[ 1, 0, 0], + [ 0, 1, 0], + [ 0, 0, 1], + [ 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_14', + DM([[1, 2, 4, 3], + [4, 5, 10, 6], + [7, 8, 16, 9]], ZZ), + DM([[1, 0, 0, -1], + [0, 1, 2, 2], + [0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_15', + DM([[1, 2, 4, 3], + [4, 5, 10, 6], + [7, 8, 17, 9]], ZZ), + DM([[1, 0, 0, -1], + [0, 1, 0, 2], + [0, 0, 1, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_16', + DM([[1, 2, 0, 1], + [1, 1, 9, 0]], ZZ), + DM([[1, 0, 18, -1], + [0, 1, -9, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_17', + DM([[1, 1, 1], + [1, 2, 2]], ZZ), + DM([[1, 0, 0], + [0, 1, 1]], ZZ), + ZZ(1), + ), + + ( + # Here the sparse implementation and dense implementation give very + # different denominators: 4061232 and -1765176. + 'zz_18', + DM([[94, 24, 0, 27, 0], + [79, 0, 0, 0, 0], + [85, 16, 71, 81, 0], + [ 0, 0, 72, 77, 0], + [21, 0, 34, 0, 0]], ZZ), + DM([[ 1, 0, 0, 0, 0], + [ 0, 1, 0, 0, 0], + [ 0, 0, 1, 0, 0], + [ 0, 0, 0, 1, 0], + [ 0, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + # Let's have a denominator that cannot be cancelled. + 'zz_19', + DM([[1, 2, 4], + [4, 5, 6]], ZZ), + DM([[3, 0, -8], + [0, 3, 10]], ZZ), + ZZ(3), + ), + + ( + 'zz_20', + DM([[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 4]], ZZ), + DM([[0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_21', + DM([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1]], ZZ), + DM([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_22', + DM([[1, 1, 1, 0, 1], + [1, 1, 0, 1, 0], + [1, 0, 1, 0, 1], + [1, 1, 0, 1, 0], + [1, 0, 0, 0, 0]], ZZ), + DM([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_large_1', + DM([ +[ 0, 0, 0, 81, 0, 0, 75, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0], +[ 0, 0, 0, 0, 0, 86, 0, 92, 79, 54, 0, 7, 0, 0, 0, 0, 79, 0, 0, 0], +[89, 54, 81, 0, 0, 20, 0, 0, 0, 0, 0, 0, 51, 0, 94, 0, 0, 77, 0, 0], +[ 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 48, 29, 0, 0, 5, 0, 32, 0], +[ 0, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 60, 0, 0, 0, 11], +[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 0, 43, 0, 0], +[ 0, 0, 0, 0, 0, 38, 91, 0, 0, 0, 0, 38, 0, 0, 0, 0, 0, 26, 0, 0], +[69, 0, 0, 0, 0, 0, 94, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55], +[ 0, 13, 18, 49, 49, 88, 0, 0, 35, 54, 0, 0, 51, 0, 0, 0, 0, 0, 0, 87], +[ 0, 0, 0, 0, 31, 0, 40, 0, 0, 0, 0, 0, 0, 50, 0, 0, 0, 0, 88, 0], +[ 0, 0, 0, 0, 0, 0, 0, 0, 98, 0, 0, 0, 15, 53, 0, 92, 0, 0, 0, 0], +[ 0, 0, 0, 95, 0, 0, 0, 36, 0, 0, 0, 0, 0, 72, 0, 0, 0, 0, 73, 19], +[ 0, 65, 14, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 34, 0, 0], +[ 0, 0, 0, 16, 39, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 0], +[ 0, 17, 0, 0, 0, 99, 84, 13, 50, 84, 0, 0, 0, 0, 95, 0, 43, 33, 20, 0], +[79, 0, 17, 52, 99, 12, 69, 0, 98, 0, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[ 0, 0, 0, 82, 0, 44, 0, 0, 0, 97, 0, 0, 0, 0, 0, 10, 0, 0, 31, 0], +[ 0, 0, 21, 0, 67, 0, 0, 0, 0, 0, 4, 0, 50, 0, 0, 0, 33, 0, 0, 0], +[ 0, 0, 0, 0, 9, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8], +[ 0, 77, 0, 0, 0, 0, 0, 0, 0, 0, 34, 93, 0, 0, 0, 0, 47, 0, 0, 0]], + ZZ), + DM([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_large_2', + DM([ +[ 0, 0, 0, 0, 50, 0, 6, 81, 0, 1, 86, 0, 0, 98, 82, 94, 4, 0, 0, 29], +[ 0, 44, 43, 0, 62, 0, 0, 0, 60, 0, 0, 0, 0, 71, 9, 0, 57, 41, 0, 93], +[ 0, 0, 28, 0, 74, 89, 42, 0, 28, 0, 6, 0, 0, 0, 44, 0, 0, 0, 77, 19], +[ 0, 21, 82, 0, 30, 88, 0, 89, 68, 0, 0, 0, 79, 41, 0, 0, 99, 0, 0, 0], +[31, 0, 0, 0, 19, 64, 0, 0, 79, 0, 5, 0, 72, 10, 60, 32, 64, 59, 0, 24], +[ 0, 0, 0, 0, 0, 57, 0, 94, 0, 83, 20, 0, 0, 9, 31, 0, 49, 26, 58, 0], +[ 0, 65, 56, 31, 64, 0, 0, 0, 0, 0, 0, 52, 85, 0, 0, 0, 0, 51, 0, 0], +[ 0, 35, 0, 0, 0, 69, 0, 0, 64, 0, 0, 0, 0, 70, 0, 0, 90, 0, 75, 76], +[69, 7, 0, 90, 0, 0, 84, 0, 47, 69, 19, 20, 42, 0, 0, 32, 71, 35, 0, 0], +[39, 0, 90, 0, 0, 4, 85, 0, 0, 55, 0, 0, 0, 35, 67, 40, 0, 40, 0, 77], +[98, 63, 0, 71, 0, 50, 0, 2, 61, 0, 38, 0, 0, 0, 0, 75, 0, 40, 33, 56], +[ 0, 73, 0, 64, 0, 38, 0, 35, 61, 0, 0, 52, 0, 7, 0, 51, 0, 0, 0, 34], +[ 0, 0, 28, 0, 34, 5, 63, 45, 14, 42, 60, 16, 76, 54, 99, 0, 28, 30, 0, 0], +[58, 37, 14, 0, 0, 0, 94, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 8, 90, 53], +[86, 74, 94, 0, 49, 10, 60, 0, 40, 18, 0, 0, 0, 31, 60, 24, 0, 1, 0, 29], +[53, 0, 0, 97, 0, 0, 58, 0, 0, 39, 44, 47, 0, 0, 0, 12, 50, 0, 0, 11], +[ 4, 0, 92, 10, 28, 0, 0, 89, 0, 0, 18, 54, 23, 39, 0, 2, 0, 48, 0, 92], +[ 0, 0, 90, 77, 95, 33, 0, 0, 49, 22, 39, 0, 0, 0, 0, 0, 0, 40, 0, 0], +[96, 0, 0, 0, 0, 38, 86, 0, 22, 76, 0, 0, 0, 0, 83, 88, 95, 65, 72, 0], +[81, 65, 0, 4, 60, 0, 19, 0, 0, 68, 0, 0, 89, 0, 67, 22, 0, 0, 55, 33]], + ZZ), + DM([ +[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], + ZZ), + ZZ(1), + ), + + ( + 'zz_large_3', + DM([ +[62,35,89,58,22,47,30,28,52,72,17,56,80,26,64,21,10,35,24,42,96,32,23,50,92,37,76,94,63,66], +[20,47,96,34,10,98,19,6,29,2,19,92,61,94,38,41,32,9,5,94,31,58,27,41,72,85,61,62,40,46], +[69,26,35,68,25,52,94,13,38,65,81,10,29,15,5,4,13,99,85,0,80,51,60,60,26,77,85,2,87,25], +[99,58,69,15,52,12,18,7,27,56,12,54,21,92,38,95,33,83,28,1,44,8,29,84,92,12,2,25,46,46], +[93,13,55,48,35,87,24,40,23,35,25,32,0,19,0,85,4,79,26,11,46,75,7,96,76,11,7,57,99,75], +[128,85,26,51,161,173,77,78,85,103,123,58,91,147,38,91,161,36,123,81,102,25,75,59,17,150,112,65,77,143], +[15,59,61,82,12,83,34,8,94,71,66,7,91,21,48,69,26,12,64,38,97,87,38,15,51,33,93,43,66,89], +[74,74,53,39,69,90,41,80,32,66,40,83,87,87,61,38,12,80,24,49,37,90,19,33,56,0,46,57,56,60], +[82,11,0,25,56,58,39,49,92,93,80,38,19,62,33,85,19,61,14,30,45,91,97,34,97,53,92,28,33,43], +[83,79,41,16,95,35,53,45,26,4,71,76,61,69,69,72,87,92,59,72,54,11,22,83,8,57,77,55,19,22], +[49,34,13,31,72,77,52,70,46,41,37,6,42,66,35,6,75,33,62,57,30,14,26,31,9,95,89,13,12,90], +[29,3,49,30,51,32,77,41,38,50,16,1,87,81,93,88,58,91,83,0,38,67,29,64,60,84,5,60,23,28], +[79,51,13,20,89,96,25,8,39,62,86,52,49,81,3,85,86,3,61,24,72,11,49,28,8,55,23,52,65,53], +[96,86,73,20,41,20,37,18,10,61,85,24,40,83,69,41,4,92,23,99,64,33,18,36,32,56,60,98,39,24], +[32,62,47,80,51,66,17,1,9,30,65,75,75,88,99,92,64,53,53,86,38,51,41,14,35,18,39,25,26,32], +[39,21,8,16,33,6,35,85,75,62,43,34,18,68,71,28,32,18,12,0,81,53,1,99,3,5,45,99,35,33], +[19,95,89,45,75,94,92,5,84,93,34,17,50,56,79,98,68,82,65,81,51,90,5,95,33,71,46,61,14,7], +[53,92,8,49,67,84,21,79,49,95,66,48,36,14,62,97,26,45,58,31,83,48,11,89,67,72,91,34,56,89], +[56,76,99,92,40,8,0,16,15,48,35,72,91,46,81,14,86,60,51,7,33,12,53,78,48,21,3,89,15,79], +[81,43,33,49,6,49,36,32,57,74,87,91,17,37,31,17,67,1,40,38,69,8,3,48,59,37,64,97,11,3], +[98,48,77,16,2,48,57,38,63,59,79,35,16,71,60,86,71,41,14,76,80,97,77,69,4,58,22,55,26,73], +[80,47,78,44,31,48,47,29,29,62,19,21,17,24,19,3,53,93,97,57,13,54,12,10,77,66,60,75,32,21], +[86,63,2,13,71,38,86,23,18,15,91,65,77,65,9,92,50,0,17,42,99,80,99,27,10,99,92,9,87,84], +[66,27,72,13,13,15,72,75,39,3,14,71,15,68,10,19,49,54,11,29,47,20,63,13,97,47,24,62,16,96], +[42,63,83,60,49,68,9,53,75,87,40,25,12,63,0,12,0,95,46,46,55,25,89,1,51,1,1,96,80,52], +[35,9,97,13,86,39,66,48,41,57,23,38,11,9,35,72,88,13,41,60,10,64,71,23,1,5,23,57,6,19], +[70,61,5,50,72,60,77,13,41,94,1,45,52,22,99,47,27,18,99,42,16,48,26,9,88,77,10,94,11,92], +[55,68,58,2,72,56,81,52,79,37,1,40,21,46,27,60,37,13,97,42,85,98,69,60,76,44,42,46,29,73], +[73,0,43,17,89,97,45,2,68,14,55,60,95,2,74,85,88,68,93,76,38,76,2,51,45,76,50,79,56,18], +[72,58,41,39,24,80,23,79,44,7,98,75,30,6,85,60,20,58,77,71,90,51,38,80,30,15,33,10,82,8]], + ZZ), + Matrix([ + [eye(29) * 2028539767964472550625641331179545072876560857886207583101, + Matrix([ 4260575808093245475167216057435155595594339172099000182569, + 169148395880755256182802335904188369274227936894862744452, + 4915975976683942569102447281579134986891620721539038348914, + 6113916866367364958834844982578214901958429746875633283248, + 5585689617819894460378537031623265659753379011388162534838, + 359776822829880747716695359574308645968094838905181892423, + -2800926112141776386671436511182421432449325232461665113305, + 941642292388230001722444876624818265766384442910688463158, + 3648811843256146649321864698600908938933015862008642023935, + -4104526163246702252932955226754097174212129127510547462419, + -704814955438106792441896903238080197619233342348191408078, + 1640882266829725529929398131287244562048075707575030019335, + -4068330845192910563212155694231438198040299927120544468520, + 136589038308366497790495711534532612862715724187671166593, + 2544937011460702462290799932536905731142196510605191645593, + 755591839174293940486133926192300657264122907519174116472, + -3683838489869297144348089243628436188645897133242795965021, + -522207137101161299969706310062775465103537953077871128403, + -2260451796032703984456606059649402832441331339246756656334, + -6476809325293587953616004856993300606040336446656916663680, + 3521944238996782387785653800944972787867472610035040989081, + 2270762115788407950241944504104975551914297395787473242379, + -3259947194628712441902262570532921252128444706733549251156, + -5624569821491886970999097239695637132075823246850431083557, + -3262698255682055804320585332902837076064075936601504555698, + 5786719943788937667411185880136324396357603606944869545501, + -955257841973865996077323863289453200904051299086000660036, + -1294235552446355326174641248209752679127075717918392702116, + -3718353510747301598130831152458342785269166356215331448279, + ]),], + [zeros(1, 29), zeros(1, 1)], + ]).to_DM().to_dense(), + ZZ(2028539767964472550625641331179545072876560857886207583101), + ), + + + ( + 'qq_1', + DM([[(1,2), 0], [0, 2]], QQ), + DM([[1, 0], [0, 1]], QQ), + QQ(1), + ), + + ( + # Standard square case + 'qq_2', + DM([[0, 1], + [1, 1]], QQ), + DM([[1, 0], + [0, 1]], QQ), + QQ(1), + ), + + ( + # m < n case + 'qq_3', + DM([[1, 2, 1], + [3, 4, 1]], QQ), + DM([[1, 0, -1], + [0, 1, 1]], QQ), + QQ(1), + ), + + ( + # same m < n but reversed + 'qq_4', + DM([[3, 4, 1], + [1, 2, 1]], QQ), + DM([[1, 0, -1], + [0, 1, 1]], QQ), + QQ(1), + ), + + ( + # m > n case + 'qq_5', + DM([[1, 0], + [1, 3], + [0, 1]], QQ), + DM([[1, 0], + [0, 1], + [0, 0]], QQ), + QQ(1), + ), + + ( + # Example with missing pivot + 'qq_6', + DM([[1, 0, 1], + [3, 0, 1]], QQ), + DM([[1, 0, 0], + [0, 0, 1]], QQ), + QQ(1), + ), + + ( + # This is intended to trigger the threshold where we give up on + # clearing denominators. + 'qq_large_1', + qq_large_1, + DomainMatrix.eye(11, QQ).to_dense(), + QQ(1), + ), + + ( + # This is intended to trigger the threshold where we use rref_den over + # QQ. + 'qq_large_2', + qq_large_2, + DomainMatrix.eye(11, QQ).to_dense(), + QQ(1), + ), + + ( + # Example with missing pivot and no replacement + + # This example is just enough to show a different result from the dense + # and sparse versions of the algorithm: + # + # >>> A = Matrix([[0, 1], [0, 2], [1, 0]]) + # >>> A.to_DM().to_sparse().rref_den()[0].to_Matrix() + # Matrix([ + # [1, 0], + # [0, 1], + # [0, 0]]) + # >>> A.to_DM().to_dense().rref_den()[0].to_Matrix() + # Matrix([ + # [2, 0], + # [0, 2], + # [0, 0]]) + # + 'qq_7', + DM([[0, 1], + [0, 2], + [1, 0]], QQ), + DM([[1, 0], + [0, 1], + [0, 0]], QQ), + QQ(1), + ), + + ( + # Gaussian integers + 'zz_i_1', + DM([[(0,1), 1, 1], + [ 1, 1, 1]], ZZ_I), + DM([[1, 0, 0], + [0, 1, 1]], ZZ_I), + ZZ_I(1), + ), + + ( + # EX: test_issue_23718 + 'EX_1', + DM([ + [a, b, 1], + [c, d, 1]], EX), + DM([[a*d - b*c, 0, -b + d], + [ 0, a*d - b*c, a - c]], EX), + EX(a*d - b*c), + ), + +] + + +def _to_DM(A, ans): + """Convert the answer to DomainMatrix.""" + if isinstance(A, DomainMatrix): + return A.to_dense() + elif isinstance(A, Matrix): + return A.to_DM(ans.domain).to_dense() + + if not (hasattr(A, 'shape') and hasattr(A, 'domain')): + shape, domain = ans.shape, ans.domain + else: + shape, domain = A.shape, A.domain + + if isinstance(A, (DDM, list)): + return DomainMatrix(list(A), shape, domain).to_dense() + elif isinstance(A, (SDM, dict)): + return DomainMatrix(dict(A), shape, domain).to_dense() + else: + assert False # pragma: no cover + + +def _pivots(A_rref): + """Return the pivots from the rref of A.""" + return tuple(sorted(map(min, A_rref.to_sdm().values()))) + + +def _check_cancel(result, rref_ans, den_ans): + """Check the cancelled result.""" + rref, den, pivots = result + if isinstance(rref, (DDM, SDM, list, dict)): + assert type(pivots) is list + pivots = tuple(pivots) + rref = _to_DM(rref, rref_ans) + rref2, den2 = rref.cancel_denom(den) + assert rref2 == rref_ans + assert den2 == den_ans + assert pivots == _pivots(rref) + + +def _check_divide(result, rref_ans, den_ans): + """Check the divided result.""" + rref, pivots = result + if isinstance(rref, (DDM, SDM, list, dict)): + assert type(pivots) is list + pivots = tuple(pivots) + rref_ans = rref_ans.to_field() / den_ans + rref = _to_DM(rref, rref_ans) + assert rref == rref_ans + assert _pivots(rref) == pivots + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_Matrix_rref(name, A, A_rref, den): + K = A.domain + A = A.to_Matrix() + A_rref_found, pivots = A.rref() + if K.is_EX: + A_rref_found = A_rref_found.expand() + _check_divide((A_rref_found, pivots), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_dense_rref(name, A, A_rref, den): + A = A.to_field() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_dense_rref_den(name, A, A_rref, den): + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref(name, A, A_rref, den): + A = A.to_field().to_sparse() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den(name, A, A_rref, den): + A = A.to_sparse() + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den_keep_domain(name, A, A_rref, den): + A = A.to_sparse() + A_rref_f, den_f, pivots_f = A.rref_den(keep_domain=False) + A_rref_f = A_rref_f.to_field() / den_f + _check_divide((A_rref_f, pivots_f), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den_keep_domain_CD(name, A, A_rref, den): + A = A.to_sparse() + A_rref_f, den_f, pivots_f = A.rref_den(keep_domain=False, method='CD') + A_rref_f = A_rref_f.to_field() / den_f + _check_divide((A_rref_f, pivots_f), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den_keep_domain_GJ(name, A, A_rref, den): + A = A.to_sparse() + A_rref_f, den_f, pivots_f = A.rref_den(keep_domain=False, method='GJ') + A_rref_f = A_rref_f.to_field() / den_f + _check_divide((A_rref_f, pivots_f), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_rref_den(name, A, A_rref, den): + A = A.to_ddm() + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sdm_rref_den(name, A, A_rref, den): + A = A.to_sdm() + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_rref(name, A, A_rref, den): + A = A.to_field().to_ddm() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sdm_rref(name, A, A_rref, den): + A = A.to_field().to_sdm() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_irref(name, A, A_rref, den): + A = A.to_field().to_ddm().copy() + pivots_found = ddm_irref(A) + _check_divide((A, pivots_found), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_irref_den(name, A, A_rref, den): + A = A.to_ddm().copy() + (den_found, pivots_found) = ddm_irref_den(A, A.domain) + result = (A, den_found, pivots_found) + _check_cancel(result, A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sparse_sdm_rref(name, A, A_rref, den): + A = A.to_field().to_sdm() + _check_divide(sdm_irref(A)[:2], A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sparse_sdm_rref_den(name, A, A_rref, den): + A = A.to_sdm().copy() + K = A.domain + _check_cancel(sdm_rref_den(A, K), A_rref, den) diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_sdm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_sdm.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7e5d460a1b2d44279a2a1772cc901f80ca733e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_sdm.py @@ -0,0 +1,428 @@ +""" +Tests for the basic functionality of the SDM class. +""" + +from itertools import product + +from sympy.core.singleton import S +from sympy.external.gmpy import GROUND_TYPES +from sympy.testing.pytest import raises + +from sympy.polys.domains import QQ, ZZ, EXRAW +from sympy.polys.matrices.sdm import SDM +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.exceptions import (DMBadInputError, DMDomainError, + DMShapeError) + + +def test_SDM(): + A = SDM({0:{0:ZZ(1)}}, (2, 2), ZZ) + assert A.domain == ZZ + assert A.shape == (2, 2) + assert dict(A) == {0:{0:ZZ(1)}} + + raises(DMBadInputError, lambda: SDM({5:{1:ZZ(0)}}, (2, 2), ZZ)) + raises(DMBadInputError, lambda: SDM({0:{5:ZZ(0)}}, (2, 2), ZZ)) + + +def test_DDM_str(): + sdm = SDM({0:{0:ZZ(1)}, 1:{1:ZZ(1)}}, (2, 2), ZZ) + assert str(sdm) == '{0: {0: 1}, 1: {1: 1}}' + if GROUND_TYPES == 'gmpy': # pragma: no cover + assert repr(sdm) == 'SDM({0: {0: mpz(1)}, 1: {1: mpz(1)}}, (2, 2), ZZ)' + else: # pragma: no cover + assert repr(sdm) == 'SDM({0: {0: 1}, 1: {1: 1}}, (2, 2), ZZ)' + + +def test_SDM_new(): + A = SDM({0:{0:ZZ(1)}}, (2, 2), ZZ) + B = A.new({}, (2, 2), ZZ) + assert B == SDM({}, (2, 2), ZZ) + + +def test_SDM_copy(): + A = SDM({0:{0:ZZ(1)}}, (2, 2), ZZ) + B = A.copy() + assert A == B + A[0][0] = ZZ(2) + assert A != B + + +def test_SDM_from_list(): + A = SDM.from_list([[ZZ(0), ZZ(1)], [ZZ(1), ZZ(0)]], (2, 2), ZZ) + assert A == SDM({0:{1:ZZ(1)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + + raises(DMBadInputError, lambda: SDM.from_list([[ZZ(0)], [ZZ(0), ZZ(1)]], (2, 2), ZZ)) + raises(DMBadInputError, lambda: SDM.from_list([[ZZ(0), ZZ(1)]], (2, 2), ZZ)) + + +def test_SDM_to_list(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_list() == [[ZZ(0), ZZ(1)], [ZZ(0), ZZ(0)]] + + A = SDM({}, (0, 2), ZZ) + assert A.to_list() == [] + + A = SDM({}, (2, 0), ZZ) + assert A.to_list() == [[], []] + + +def test_SDM_to_list_flat(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_list_flat() == [ZZ(0), ZZ(1), ZZ(0), ZZ(0)] + + +def test_SDM_to_dok(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_dok() == {(0, 1): ZZ(1)} + + +def test_SDM_from_ddm(): + A = DDM([[ZZ(1), ZZ(0)], [ZZ(1), ZZ(0)]], (2, 2), ZZ) + B = SDM.from_ddm(A) + assert B.domain == ZZ + assert B.shape == (2, 2) + assert dict(B) == {0:{0:ZZ(1)}, 1:{0:ZZ(1)}} + + +def test_SDM_to_ddm(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + B = DDM([[ZZ(0), ZZ(1)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + assert A.to_ddm() == B + + +def test_SDM_to_sdm(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_sdm() == A + + +def test_SDM_getitem(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + assert A.getitem(0, 0) == ZZ.zero + assert A.getitem(0, 1) == ZZ.one + assert A.getitem(1, 0) == ZZ.zero + assert A.getitem(-2, -2) == ZZ.zero + assert A.getitem(-2, -1) == ZZ.one + assert A.getitem(-1, -2) == ZZ.zero + raises(IndexError, lambda: A.getitem(2, 0)) + raises(IndexError, lambda: A.getitem(0, 2)) + + +def test_SDM_setitem(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + A.setitem(0, 0, ZZ(1)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}}, (2, 2), ZZ) + A.setitem(1, 0, ZZ(1)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + A.setitem(1, 0, ZZ(0)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}}, (2, 2), ZZ) + # Repeat the above test so that this time the row is empty + A.setitem(1, 0, ZZ(0)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}}, (2, 2), ZZ) + A.setitem(0, 0, ZZ(0)) + assert A == SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + # This time the row is there but column is empty + A.setitem(0, 0, ZZ(0)) + assert A == SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + raises(IndexError, lambda: A.setitem(2, 0, ZZ(1))) + raises(IndexError, lambda: A.setitem(0, 2, ZZ(1))) + + +def test_SDM_extract_slice(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = A.extract_slice(slice(1, 2), slice(1, 2)) + assert B == SDM({0:{0:ZZ(4)}}, (1, 1), ZZ) + + +def test_SDM_extract(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = A.extract([1], [1]) + assert B == SDM({0:{0:ZZ(4)}}, (1, 1), ZZ) + B = A.extract([1, 0], [1, 0]) + assert B == SDM({0:{0:ZZ(4), 1:ZZ(3)}, 1:{0:ZZ(2), 1:ZZ(1)}}, (2, 2), ZZ) + B = A.extract([1, 1], [1, 1]) + assert B == SDM({0:{0:ZZ(4), 1:ZZ(4)}, 1:{0:ZZ(4), 1:ZZ(4)}}, (2, 2), ZZ) + B = A.extract([-1], [-1]) + assert B == SDM({0:{0:ZZ(4)}}, (1, 1), ZZ) + + A = SDM({}, (2, 2), ZZ) + B = A.extract([0, 1, 0], [0, 0]) + assert B == SDM({}, (3, 2), ZZ) + + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + assert A.extract([], []) == SDM.zeros((0, 0), ZZ) + assert A.extract([1], []) == SDM.zeros((1, 0), ZZ) + assert A.extract([], [1]) == SDM.zeros((0, 1), ZZ) + + raises(IndexError, lambda: A.extract([2], [0])) + raises(IndexError, lambda: A.extract([0], [2])) + raises(IndexError, lambda: A.extract([-3], [0])) + raises(IndexError, lambda: A.extract([0], [-3])) + + +def test_SDM_zeros(): + A = SDM.zeros((2, 2), ZZ) + assert A.domain == ZZ + assert A.shape == (2, 2) + assert dict(A) == {} + +def test_SDM_ones(): + A = SDM.ones((1, 2), QQ) + assert A.domain == QQ + assert A.shape == (1, 2) + assert dict(A) == {0:{0:QQ(1), 1:QQ(1)}} + +def test_SDM_eye(): + A = SDM.eye((2, 2), ZZ) + assert A.domain == ZZ + assert A.shape == (2, 2) + assert dict(A) == {0:{0:ZZ(1)}, 1:{1:ZZ(1)}} + + +def test_SDM_diag(): + A = SDM.diag([ZZ(1), ZZ(2)], ZZ, (2, 3)) + assert A == SDM({0:{0:ZZ(1)}, 1:{1:ZZ(2)}}, (2, 3), ZZ) + + +def test_SDM_transpose(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1), 1:ZZ(3)}, 1:{0:ZZ(2), 1:ZZ(4)}}, (2, 2), ZZ) + assert A.transpose() == B + + A = SDM({0:{1:ZZ(2)}}, (2, 2), ZZ) + B = SDM({1:{0:ZZ(2)}}, (2, 2), ZZ) + assert A.transpose() == B + + A = SDM({0:{1:ZZ(2)}}, (1, 2), ZZ) + B = SDM({1:{0:ZZ(2)}}, (2, 1), ZZ) + assert A.transpose() == B + + +def test_SDM_mul(): + A = SDM({0:{0:ZZ(2)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(4)}}, (2, 2), ZZ) + assert A*ZZ(2) == B + assert ZZ(2)*A == B + + raises(TypeError, lambda: A*QQ(1, 2)) + raises(TypeError, lambda: QQ(1, 2)*A) + + +def test_SDM_mul_elementwise(): + A = SDM({0:{0:ZZ(2), 1:ZZ(2)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(4)}, 1:{0:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(8)}}, (2, 2), ZZ) + assert A.mul_elementwise(B) == C + assert B.mul_elementwise(A) == C + + Aq = A.convert_to(QQ) + A1 = SDM({0:{0:ZZ(1)}}, (1, 1), ZZ) + + raises(DMDomainError, lambda: Aq.mul_elementwise(B)) + raises(DMShapeError, lambda: A1.mul_elementwise(B)) + + +def test_SDM_matmul(): + A = SDM({0:{0:ZZ(2)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(4)}}, (2, 2), ZZ) + assert A.matmul(A) == A*A == B + + C = SDM({0:{0:ZZ(2)}}, (2, 2), QQ) + raises(DMDomainError, lambda: A.matmul(C)) + + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(7), 1:ZZ(10)}, 1:{0:ZZ(15), 1:ZZ(22)}}, (2, 2), ZZ) + assert A.matmul(A) == A*A == B + + A22 = SDM({0:{0:ZZ(4)}}, (2, 2), ZZ) + A32 = SDM({0:{0:ZZ(2)}}, (3, 2), ZZ) + A23 = SDM({0:{0:ZZ(4)}}, (2, 3), ZZ) + A33 = SDM({0:{0:ZZ(8)}}, (3, 3), ZZ) + A22 = SDM({0:{0:ZZ(8)}}, (2, 2), ZZ) + assert A32.matmul(A23) == A33 + assert A23.matmul(A32) == A22 + # XXX: @ not supported by SDM... + #assert A32.matmul(A23) == A32 @ A23 == A33 + #assert A23.matmul(A32) == A23 @ A32 == A22 + #raises(DMShapeError, lambda: A23 @ A22) + raises(DMShapeError, lambda: A23.matmul(A22)) + + A = SDM({0: {0: ZZ(-1), 1: ZZ(1)}}, (1, 2), ZZ) + B = SDM({0: {0: ZZ(-1)}, 1: {0: ZZ(-1)}}, (2, 1), ZZ) + assert A.matmul(B) == A*B == SDM({}, (1, 1), ZZ) + + +def test_matmul_exraw(): + + def dm(d): + result = {} + for i, row in d.items(): + row = {j:val for j, val in row.items() if val} + if row: + result[i] = row + return SDM(result, (2, 2), EXRAW) + + values = [S.NegativeInfinity, S.NegativeOne, S.Zero, S.One, S.Infinity] + for a, b, c, d in product(*[values]*4): + Ad = dm({0: {0:a, 1:b}, 1: {0:c, 1:d}}) + Ad2 = dm({0: {0:a*a + b*c, 1:a*b + b*d}, 1:{0:c*a + d*c, 1: c*b + d*d}}) + assert Ad * Ad == Ad2 + + +def test_SDM_add(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(1), 1:ZZ(1)}, 1:{1:ZZ(6)}}, (2, 2), ZZ) + assert A.add(B) == B.add(A) == A + B == B + A == C + + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(1), 1:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + assert A.add(B) == B.add(A) == A + B == B + A == C + + raises(TypeError, lambda: A + []) + + +def test_SDM_sub(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(-1), 1:ZZ(1)}, 1:{0:ZZ(4)}}, (2, 2), ZZ) + assert A.sub(B) == A - B == C + + raises(TypeError, lambda: A - []) + + +def test_SDM_neg(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{1:ZZ(-1)}, 1:{0:ZZ(-2), 1:ZZ(-3)}}, (2, 2), ZZ) + assert A.neg() == -A == B + + +def test_SDM_convert_to(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{1:QQ(1)}, 1:{0:QQ(2), 1:QQ(3)}}, (2, 2), QQ) + C = A.convert_to(QQ) + assert C == B + assert C.domain == QQ + + D = A.convert_to(ZZ) + assert D == A + assert D.domain == ZZ + + +def test_SDM_hstack(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({1:{1:ZZ(1)}}, (2, 2), ZZ) + AA = SDM({0:{1:ZZ(1), 3:ZZ(1)}}, (2, 4), ZZ) + AB = SDM({0:{1:ZZ(1)}, 1:{3:ZZ(1)}}, (2, 4), ZZ) + assert SDM.hstack(A) == A + assert SDM.hstack(A, A) == AA + assert SDM.hstack(A, B) == AB + + +def test_SDM_vstack(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({1:{1:ZZ(1)}}, (2, 2), ZZ) + AA = SDM({0:{1:ZZ(1)}, 2:{1:ZZ(1)}}, (4, 2), ZZ) + AB = SDM({0:{1:ZZ(1)}, 3:{1:ZZ(1)}}, (4, 2), ZZ) + assert SDM.vstack(A) == A + assert SDM.vstack(A, A) == AA + assert SDM.vstack(A, B) == AB + + +def test_SDM_applyfunc(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({0:{1:ZZ(2)}}, (2, 2), ZZ) + assert A.applyfunc(lambda x: 2*x, ZZ) == B + + +def test_SDM_inv(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + B = SDM({0:{0:QQ(-2), 1:QQ(1)}, 1:{0:QQ(3, 2), 1:QQ(-1, 2)}}, (2, 2), QQ) + assert A.inv() == B + + +def test_SDM_det(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + assert A.det() == QQ(-2) + + +def test_SDM_lu(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + L = SDM({0:{0:QQ(1)}, 1:{0:QQ(3), 1:QQ(1)}}, (2, 2), QQ) + #U = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(-2)}}, (2, 2), QQ) + #swaps = [] + # This doesn't quite work. U has some nonzero elements in the lower part. + #assert A.lu() == (L, U, swaps) + assert A.lu()[0] == L + + +def test_SDM_lu_solve(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ) + x = SDM({1:{0:QQ(1, 2)}}, (2, 1), QQ) + assert A.matmul(x) == b + assert A.lu_solve(b) == x + + +def test_SDM_charpoly(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + assert A.charpoly() == [ZZ(1), ZZ(-5), ZZ(-2)] + + +def test_SDM_nullspace(): + # More tests are in test_nullspace.py + A = SDM({0:{0:QQ(1), 1:QQ(1)}}, (2, 2), QQ) + assert A.nullspace()[0] == SDM({0:{0:QQ(-1), 1:QQ(1)}}, (1, 2), QQ) + + +def test_SDM_rref(): + # More tests are in test_rref.py + + A = SDM({0:{0:QQ(1), 1:QQ(2)}, + 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + A_rref = SDM({0:{0:QQ(1)}, 1:{1:QQ(1)}}, (2, 2), QQ) + assert A.rref() == (A_rref, [0, 1]) + + A = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(2)}, + 1: {0: QQ(3), 2: QQ(4)}}, (2, 3), ZZ) + A_rref = SDM({0: {0: QQ(1,1), 2: QQ(4,3)}, + 1: {1: QQ(1,1), 2: QQ(1,3)}}, (2, 3), QQ) + assert A.rref() == (A_rref, [0, 1]) + + +def test_SDM_particular(): + A = SDM({0:{0:QQ(1)}}, (2, 2), QQ) + Apart = SDM.zeros((1, 2), QQ) + assert A.particular() == Apart + + +def test_SDM_is_zero_matrix(): + A = SDM({0: {0: QQ(1)}}, (2, 2), QQ) + Azero = SDM.zeros((1, 2), QQ) + assert A.is_zero_matrix() is False + assert Azero.is_zero_matrix() is True + + +def test_SDM_is_upper(): + A = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {2: QQ(8), 3: QQ(9)}}, (3, 4), QQ) + B = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {1: QQ(7), 2: QQ(8), 3: QQ(9)}}, (3, 4), QQ) + assert A.is_upper() is True + assert B.is_upper() is False + + +def test_SDM_is_lower(): + A = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {2: QQ(8), 3: QQ(9)}}, (3, 4), QQ + ).transpose() + B = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {1: QQ(7), 2: QQ(8), 3: QQ(9)}}, (3, 4), QQ + ).transpose() + assert A.is_lower() is True + assert B.is_lower() is False diff --git a/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_xxm.py b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_xxm.py new file mode 100644 index 0000000000000000000000000000000000000000..628d66d15f5db82718231ba8f89bc0dadd393594 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/matrices/tests/test_xxm.py @@ -0,0 +1,1023 @@ +# +# Test basic features of DDM, SDM and DFM. +# +# These three types are supposed to be interchangeable, so we should use the +# same tests for all of them for the most part. +# +# The tests here cover the basic part of the interface that the three types +# should expose and that DomainMatrix should mostly rely on. +# +# More in-depth tests of the heavier algorithms like rref etc should go in +# their own test files. +# +# Any new methods added to the DDM, SDM or DFM classes should be tested here +# and added to all classes. +# + +from sympy.external.gmpy import GROUND_TYPES + +from sympy import ZZ, QQ, GF, ZZ_I, symbols + +from sympy.polys.matrices.exceptions import ( + DMBadInputError, + DMDomainError, + DMNonSquareMatrixError, + DMNonInvertibleMatrixError, + DMShapeError, +) + +from sympy.polys.matrices.domainmatrix import DM, DomainMatrix, DDM, SDM, DFM + +from sympy.testing.pytest import raises, skip +import pytest + + +def test_XXM_constructors(): + """Test the DDM, etc constructors.""" + + lol = [ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6)], + ] + dod = { + 0: {0: ZZ(1), 1: ZZ(2)}, + 1: {0: ZZ(3), 1: ZZ(4)}, + 2: {0: ZZ(5), 1: ZZ(6)}, + } + + lol_0x0 = [] + lol_0x2 = [] + lol_2x0 = [[], []] + dod_0x0 = {} + dod_0x2 = {} + dod_2x0 = {} + + lol_bad = [ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6), ZZ(7)], + ] + dod_bad = { + 0: {0: ZZ(1), 1: ZZ(2)}, + 1: {0: ZZ(3), 1: ZZ(4)}, + 2: {0: ZZ(5), 1: ZZ(6), 2: ZZ(7)}, + } + + XDM_dense = [DDM] + XDM_sparse = [SDM] + + if GROUND_TYPES == 'flint': + XDM_dense.append(DFM) + + for XDM in XDM_dense: + + A = XDM(lol, (3, 2), ZZ) + assert A.rows == 3 + assert A.cols == 2 + assert A.domain == ZZ + assert A.shape == (3, 2) + if XDM is not DFM: + assert ZZ.of_type(A[0][0]) is True + else: + assert ZZ.of_type(A.rep[0, 0]) is True + + Adm = DomainMatrix(lol, (3, 2), ZZ) + if XDM is DFM: + assert Adm.rep == A + assert Adm.rep.to_ddm() != A + elif GROUND_TYPES == 'flint': + assert Adm.rep.to_ddm() == A + assert Adm.rep != A + else: + assert Adm.rep == A + assert Adm.rep.to_ddm() == A + + assert XDM(lol_0x0, (0, 0), ZZ).shape == (0, 0) + assert XDM(lol_0x2, (0, 2), ZZ).shape == (0, 2) + assert XDM(lol_2x0, (2, 0), ZZ).shape == (2, 0) + raises(DMBadInputError, lambda: XDM(lol, (2, 3), ZZ)) + raises(DMBadInputError, lambda: XDM(lol_bad, (3, 2), ZZ)) + raises(DMBadInputError, lambda: XDM(dod, (3, 2), ZZ)) + + for XDM in XDM_sparse: + + A = XDM(dod, (3, 2), ZZ) + assert A.rows == 3 + assert A.cols == 2 + assert A.domain == ZZ + assert A.shape == (3, 2) + assert ZZ.of_type(A[0][0]) is True + + assert DomainMatrix(dod, (3, 2), ZZ).rep == A + + assert XDM(dod_0x0, (0, 0), ZZ).shape == (0, 0) + assert XDM(dod_0x2, (0, 2), ZZ).shape == (0, 2) + assert XDM(dod_2x0, (2, 0), ZZ).shape == (2, 0) + raises(DMBadInputError, lambda: XDM(dod, (2, 3), ZZ)) + raises(DMBadInputError, lambda: XDM(lol, (3, 2), ZZ)) + raises(DMBadInputError, lambda: XDM(dod_bad, (3, 2), ZZ)) + + raises(DMBadInputError, lambda: DomainMatrix(lol, (2, 3), ZZ)) + raises(DMBadInputError, lambda: DomainMatrix(lol_bad, (3, 2), ZZ)) + raises(DMBadInputError, lambda: DomainMatrix(dod_bad, (3, 2), ZZ)) + + +def test_XXM_eq(): + """Test equality for DDM, SDM, DFM and DomainMatrix.""" + + lol1 = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + dod1 = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}} + + lol2 = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(5)]] + dod2 = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(5)}} + + A1_ddm = DDM(lol1, (2, 2), ZZ) + A1_sdm = SDM(dod1, (2, 2), ZZ) + A1_dm_d = DomainMatrix(lol1, (2, 2), ZZ) + A1_dm_s = DomainMatrix(dod1, (2, 2), ZZ) + + A2_ddm = DDM(lol2, (2, 2), ZZ) + A2_sdm = SDM(dod2, (2, 2), ZZ) + A2_dm_d = DomainMatrix(lol2, (2, 2), ZZ) + A2_dm_s = DomainMatrix(dod2, (2, 2), ZZ) + + A1_all = [A1_ddm, A1_sdm, A1_dm_d, A1_dm_s] + A2_all = [A2_ddm, A2_sdm, A2_dm_d, A2_dm_s] + + if GROUND_TYPES == 'flint': + + A1_dfm = DFM([[1, 2], [3, 4]], (2, 2), ZZ) + A2_dfm = DFM([[1, 2], [3, 5]], (2, 2), ZZ) + + A1_all.append(A1_dfm) + A2_all.append(A2_dfm) + + for n, An in enumerate(A1_all): + for m, Am in enumerate(A1_all): + if n == m: + assert (An == Am) is True + assert (An != Am) is False + else: + assert (An == Am) is False + assert (An != Am) is True + + for n, An in enumerate(A2_all): + for m, Am in enumerate(A2_all): + if n == m: + assert (An == Am) is True + assert (An != Am) is False + else: + assert (An == Am) is False + assert (An != Am) is True + + for n, A1 in enumerate(A1_all): + for m, A2 in enumerate(A2_all): + assert (A1 == A2) is False + assert (A1 != A2) is True + + +def test_to_XXM(): + """Test to_ddm etc. for DDM, SDM, DFM and DomainMatrix.""" + + lol = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + dod = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}} + + A_ddm = DDM(lol, (2, 2), ZZ) + A_sdm = SDM(dod, (2, 2), ZZ) + A_dm_d = DomainMatrix(lol, (2, 2), ZZ) + A_dm_s = DomainMatrix(dod, (2, 2), ZZ) + + A_all = [A_ddm, A_sdm, A_dm_d, A_dm_s] + + if GROUND_TYPES == 'flint': + A_dfm = DFM(lol, (2, 2), ZZ) + A_all.append(A_dfm) + + for A in A_all: + assert A.to_ddm() == A_ddm + assert A.to_sdm() == A_sdm + if GROUND_TYPES != 'flint': + raises(NotImplementedError, lambda: A.to_dfm()) + assert A.to_dfm_or_ddm() == A_ddm + + # Add e.g. DDM.to_DM()? + # assert A.to_DM() == A_dm + + if GROUND_TYPES == 'flint': + for A in A_all: + assert A.to_dfm() == A_dfm + for K in [ZZ, QQ, GF(5), ZZ_I]: + if isinstance(A, DFM) and not DFM._supports_domain(K): + raises(NotImplementedError, lambda: A.convert_to(K)) + else: + A_K = A.convert_to(K) + if DFM._supports_domain(K): + A_dfm_K = A_dfm.convert_to(K) + assert A_K.to_dfm() == A_dfm_K + assert A_K.to_dfm_or_ddm() == A_dfm_K + else: + raises(NotImplementedError, lambda: A_K.to_dfm()) + assert A_K.to_dfm_or_ddm() == A_ddm.convert_to(K) + + +def test_DFM_domains(): + """Test which domains are supported by DFM.""" + + x, y = symbols('x, y') + + if GROUND_TYPES in ('python', 'gmpy'): + + supported = [] + flint_funcs = {} + not_supported = [ZZ, QQ, GF(5), QQ[x], QQ[x,y]] + + elif GROUND_TYPES == 'flint': + + import flint + supported = [ZZ, QQ] + flint_funcs = { + ZZ: flint.fmpz_mat, + QQ: flint.fmpq_mat, + GF(5): None, + } + not_supported = [ + # Other domains could be supported but not implemented as matrices + # in python-flint: + QQ[x], + QQ[x,y], + QQ.frac_field(x,y), + # Others would potentially never be supported by python-flint: + ZZ_I, + ] + + else: + assert False, "Unknown GROUND_TYPES: %s" % GROUND_TYPES + + for domain in supported: + assert DFM._supports_domain(domain) is True + if flint_funcs[domain] is not None: + assert DFM._get_flint_func(domain) == flint_funcs[domain] + for domain in not_supported: + assert DFM._supports_domain(domain) is False + raises(NotImplementedError, lambda: DFM._get_flint_func(domain)) + + +def _DM(lol, typ, K): + """Make a DM of type typ over K from lol.""" + A = DM(lol, K) + + if typ == 'DDM': + return A.to_ddm() + elif typ == 'SDM': + return A.to_sdm() + elif typ == 'DFM': + if GROUND_TYPES != 'flint': + skip("DFM not supported in this ground type") + return A.to_dfm() + else: + assert False, "Unknown type %s" % typ + + +def _DMZ(lol, typ): + """Make a DM of type typ over ZZ from lol.""" + return _DM(lol, typ, ZZ) + + +def _DMQ(lol, typ): + """Make a DM of type typ over QQ from lol.""" + return _DM(lol, typ, QQ) + + +def DM_ddm(lol, K): + """Make a DDM over K from lol.""" + return _DM(lol, 'DDM', K) + + +def DM_sdm(lol, K): + """Make a SDM over K from lol.""" + return _DM(lol, 'SDM', K) + + +def DM_dfm(lol, K): + """Make a DFM over K from lol.""" + return _DM(lol, 'DFM', K) + + +def DMZ_ddm(lol): + """Make a DDM from lol.""" + return _DMZ(lol, 'DDM') + + +def DMZ_sdm(lol): + """Make a SDM from lol.""" + return _DMZ(lol, 'SDM') + + +def DMZ_dfm(lol): + """Make a DFM from lol.""" + return _DMZ(lol, 'DFM') + + +def DMQ_ddm(lol): + """Make a DDM from lol.""" + return _DMQ(lol, 'DDM') + + +def DMQ_sdm(lol): + """Make a SDM from lol.""" + return _DMQ(lol, 'SDM') + + +def DMQ_dfm(lol): + """Make a DFM from lol.""" + return _DMQ(lol, 'DFM') + + +DM_all = [DM_ddm, DM_sdm, DM_dfm] +DMZ_all = [DMZ_ddm, DMZ_sdm, DMZ_dfm] +DMQ_all = [DMQ_ddm, DMQ_sdm, DMQ_dfm] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XDM_getitem(DM): + """Test getitem for DDM, etc.""" + + lol = [[0, 1], [2, 0]] + A = DM(lol) + m, n = A.shape + + indices = [-3, -2, -1, 0, 1, 2] + + for i in indices: + for j in indices: + if -2 <= i < m and -2 <= j < n: + assert A.getitem(i, j) == ZZ(lol[i][j]) + else: + raises(IndexError, lambda: A.getitem(i, j)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XDM_setitem(DM): + """Test setitem for DDM, etc.""" + + A = DM([[0, 1, 2], [3, 4, 5]]) + + A.setitem(0, 0, ZZ(6)) + assert A == DM([[6, 1, 2], [3, 4, 5]]) + + A.setitem(0, 1, ZZ(7)) + assert A == DM([[6, 7, 2], [3, 4, 5]]) + + A.setitem(0, 2, ZZ(8)) + assert A == DM([[6, 7, 8], [3, 4, 5]]) + + A.setitem(0, -1, ZZ(9)) + assert A == DM([[6, 7, 9], [3, 4, 5]]) + + A.setitem(0, -2, ZZ(10)) + assert A == DM([[6, 10, 9], [3, 4, 5]]) + + A.setitem(0, -3, ZZ(11)) + assert A == DM([[11, 10, 9], [3, 4, 5]]) + + raises(IndexError, lambda: A.setitem(0, 3, ZZ(12))) + raises(IndexError, lambda: A.setitem(0, -4, ZZ(13))) + + A.setitem(1, 0, ZZ(14)) + assert A == DM([[11, 10, 9], [14, 4, 5]]) + + A.setitem(1, 1, ZZ(15)) + assert A == DM([[11, 10, 9], [14, 15, 5]]) + + A.setitem(-1, 1, ZZ(16)) + assert A == DM([[11, 10, 9], [14, 16, 5]]) + + A.setitem(-2, 1, ZZ(17)) + assert A == DM([[11, 17, 9], [14, 16, 5]]) + + raises(IndexError, lambda: A.setitem(2, 0, ZZ(18))) + raises(IndexError, lambda: A.setitem(-3, 0, ZZ(19))) + + A.setitem(1, 2, ZZ(0)) + assert A == DM([[11, 17, 9], [14, 16, 0]]) + + A.setitem(1, -2, ZZ(0)) + assert A == DM([[11, 17, 9], [14, 0, 0]]) + + A.setitem(1, -3, ZZ(0)) + assert A == DM([[11, 17, 9], [0, 0, 0]]) + + A.setitem(0, 0, ZZ(0)) + assert A == DM([[0, 17, 9], [0, 0, 0]]) + + A.setitem(0, -1, ZZ(0)) + assert A == DM([[0, 17, 0], [0, 0, 0]]) + + A.setitem(0, 0, ZZ(0)) + assert A == DM([[0, 17, 0], [0, 0, 0]]) + + A.setitem(0, -2, ZZ(0)) + assert A == DM([[0, 0, 0], [0, 0, 0]]) + + A.setitem(0, -3, ZZ(1)) + assert A == DM([[1, 0, 0], [0, 0, 0]]) + + +class _Sliced: + def __getitem__(self, item): + return item + + +_slice = _Sliced() + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_extract_slice(DM): + A = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.extract_slice(*_slice[:,:]) == A + assert A.extract_slice(*_slice[1:,:]) == DM([[4, 5, 6], [7, 8, 9]]) + assert A.extract_slice(*_slice[1:,1:]) == DM([[5, 6], [8, 9]]) + assert A.extract_slice(*_slice[1:,:-1]) == DM([[4, 5], [7, 8]]) + assert A.extract_slice(*_slice[1:,:-1:2]) == DM([[4], [7]]) + assert A.extract_slice(*_slice[:,::2]) == DM([[1, 3], [4, 6], [7, 9]]) + assert A.extract_slice(*_slice[::2,:]) == DM([[1, 2, 3], [7, 8, 9]]) + assert A.extract_slice(*_slice[::2,::2]) == DM([[1, 3], [7, 9]]) + assert A.extract_slice(*_slice[::2,::-2]) == DM([[3, 1], [9, 7]]) + assert A.extract_slice(*_slice[::-2,::2]) == DM([[7, 9], [1, 3]]) + assert A.extract_slice(*_slice[::-2,::-2]) == DM([[9, 7], [3, 1]]) + assert A.extract_slice(*_slice[:,::-1]) == DM([[3, 2, 1], [6, 5, 4], [9, 8, 7]]) + assert A.extract_slice(*_slice[::-1,:]) == DM([[7, 8, 9], [4, 5, 6], [1, 2, 3]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_extract(DM): + + A = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + assert A.extract([0, 1, 2], [0, 1, 2]) == A + assert A.extract([1, 2], [1, 2]) == DM([[5, 6], [8, 9]]) + assert A.extract([1, 2], [0, 1]) == DM([[4, 5], [7, 8]]) + assert A.extract([1, 2], [0, 2]) == DM([[4, 6], [7, 9]]) + assert A.extract([1, 2], [0]) == DM([[4], [7]]) + assert A.extract([1, 2], []) == DM([[1]]).zeros((2, 0), ZZ) + assert A.extract([], [0, 1, 2]) == DM([[1]]).zeros((0, 3), ZZ) + + raises(IndexError, lambda: A.extract([1, 2], [0, 3])) + raises(IndexError, lambda: A.extract([1, 2], [0, -4])) + raises(IndexError, lambda: A.extract([3, 1], [0, 1])) + raises(IndexError, lambda: A.extract([-4, 2], [3, 1])) + + B = DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert B.extract([1, 2], [1, 2]) == DM([[0, 0], [0, 0]]) + + +def test_XXM_str(): + + A = DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ) + + assert str(A) == \ + 'DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert str(A.to_ddm()) == \ + '[[1, 2, 3], [4, 5, 6], [7, 8, 9]]' + assert str(A.to_sdm()) == \ + '{0: {0: 1, 1: 2, 2: 3}, 1: {0: 4, 1: 5, 2: 6}, 2: {0: 7, 1: 8, 2: 9}}' + + assert repr(A) == \ + 'DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert repr(A.to_ddm()) == \ + 'DDM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert repr(A.to_sdm()) == \ + 'SDM({0: {0: 1, 1: 2, 2: 3}, 1: {0: 4, 1: 5, 2: 6}, 2: {0: 7, 1: 8, 2: 9}}, (3, 3), ZZ)' + + B = DomainMatrix({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3)}}, (2, 2), ZZ) + + assert str(B) == \ + 'DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)' + assert str(B.to_ddm()) == \ + '[[1, 2], [3, 0]]' + assert str(B.to_sdm()) == \ + '{0: {0: 1, 1: 2}, 1: {0: 3}}' + + assert repr(B) == \ + 'DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)' + + if GROUND_TYPES != 'gmpy': + assert repr(B.to_ddm()) == \ + 'DDM([[1, 2], [3, 0]], (2, 2), ZZ)' + assert repr(B.to_sdm()) == \ + 'SDM({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)' + else: + assert repr(B.to_ddm()) == \ + 'DDM([[mpz(1), mpz(2)], [mpz(3), mpz(0)]], (2, 2), ZZ)' + assert repr(B.to_sdm()) == \ + 'SDM({0: {0: mpz(1), 1: mpz(2)}, 1: {0: mpz(3)}}, (2, 2), ZZ)' + + if GROUND_TYPES == 'flint': + + assert str(A.to_dfm()) == \ + '[[1, 2, 3], [4, 5, 6], [7, 8, 9]]' + assert str(B.to_dfm()) == \ + '[[1, 2], [3, 0]]' + + assert repr(A.to_dfm()) == \ + 'DFM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert repr(B.to_dfm()) == \ + 'DFM([[1, 2], [3, 0]], (2, 2), ZZ)' + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_list(DM): + T = type(DM([[0]])) + + lol = [[1, 2, 4], [4, 5, 6]] + lol_ZZ = [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6)]] + lol_ZZ_bad = [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6), ZZ(7)]] + + assert T.from_list(lol_ZZ, (2, 3), ZZ) == DM(lol) + raises(DMBadInputError, lambda: T.from_list(lol_ZZ_bad, (3, 2), ZZ)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_list(DM): + lol = [[1, 2, 4], [4, 5, 6]] + assert DM(lol).to_list() == [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6)]] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_list_flat(DM): + lol = [[1, 2, 4], [4, 5, 6]] + assert DM(lol).to_list_flat() == [ZZ(1), ZZ(2), ZZ(4), ZZ(4), ZZ(5), ZZ(6)] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_list_flat(DM): + T = type(DM([[0]])) + flat = [ZZ(1), ZZ(2), ZZ(4), ZZ(4), ZZ(5), ZZ(6)] + assert T.from_list_flat(flat, (2, 3), ZZ) == DM([[1, 2, 4], [4, 5, 6]]) + raises(DMBadInputError, lambda: T.from_list_flat(flat, (3, 3), ZZ)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_flat_nz(DM): + M = DM([[1, 2, 0], [0, 0, 0], [0, 0, 3]]) + elements = [ZZ(1), ZZ(2), ZZ(3)] + indices = ((0, 0), (0, 1), (2, 2)) + assert M.to_flat_nz() == (elements, (indices, M.shape)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_flat_nz(DM): + T = type(DM([[0]])) + elements = [ZZ(1), ZZ(2), ZZ(3)] + indices = ((0, 0), (0, 1), (2, 2)) + data = (indices, (3, 3)) + result = DM([[1, 2, 0], [0, 0, 0], [0, 0, 3]]) + assert T.from_flat_nz(elements, data, ZZ) == result + raises(DMBadInputError, lambda: T.from_flat_nz(elements, (indices, (2, 3)), ZZ)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_dod(DM): + dod = {0: {0: ZZ(1), 2: ZZ(4)}, 1: {0: ZZ(4), 1: ZZ(5), 2: ZZ(6)}} + assert DM([[1, 0, 4], [4, 5, 6]]).to_dod() == dod + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_dod(DM): + T = type(DM([[0]])) + dod = {0: {0: ZZ(1), 2: ZZ(4)}, 1: {0: ZZ(4), 1: ZZ(5), 2: ZZ(6)}} + assert T.from_dod(dod, (2, 3), ZZ) == DM([[1, 0, 4], [4, 5, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_dok(DM): + dod = {(0, 0): ZZ(1), (0, 2): ZZ(4), + (1, 0): ZZ(4), (1, 1): ZZ(5), (1, 2): ZZ(6)} + assert DM([[1, 0, 4], [4, 5, 6]]).to_dok() == dod + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_dok(DM): + T = type(DM([[0]])) + dod = {(0, 0): ZZ(1), (0, 2): ZZ(4), + (1, 0): ZZ(4), (1, 1): ZZ(5), (1, 2): ZZ(6)} + assert T.from_dok(dod, (2, 3), ZZ) == DM([[1, 0, 4], [4, 5, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_iter_values(DM): + values = [ZZ(1), ZZ(4), ZZ(4), ZZ(5), ZZ(6)] + assert sorted(DM([[1, 0, 4], [4, 5, 6]]).iter_values()) == values + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_iter_items(DM): + items = [((0, 0), ZZ(1)), ((0, 2), ZZ(4)), + ((1, 0), ZZ(4)), ((1, 1), ZZ(5)), ((1, 2), ZZ(6))] + assert sorted(DM([[1, 0, 4], [4, 5, 6]]).iter_items()) == items + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_ddm(DM): + T = type(DM([[0]])) + ddm = DDM([[1, 2, 4], [4, 5, 6]], (2, 3), ZZ) + assert T.from_ddm(ddm) == DM([[1, 2, 4], [4, 5, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_zeros(DM): + T = type(DM([[0]])) + assert T.zeros((2, 3), ZZ) == DM([[0, 0, 0], [0, 0, 0]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_ones(DM): + T = type(DM([[0]])) + assert T.ones((2, 3), ZZ) == DM([[1, 1, 1], [1, 1, 1]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_eye(DM): + T = type(DM([[0]])) + assert T.eye(3, ZZ) == DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert T.eye((3, 2), ZZ) == DM([[1, 0], [0, 1], [0, 0]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_diag(DM): + T = type(DM([[0]])) + assert T.diag([1, 2, 3], ZZ) == DM([[1, 0, 0], [0, 2, 0], [0, 0, 3]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_transpose(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + assert A.transpose() == DM([[1, 4], [2, 5], [3, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_add(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[2, 4, 6], [8, 10, 12]]) + assert A.add(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_sub(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[0, 0, 0], [0, 0, 0]]) + assert A.sub(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_mul(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + b = ZZ(2) + assert A.mul(b) == DM([[2, 4, 6], [8, 10, 12]]) + assert A.rmul(b) == DM([[2, 4, 6], [8, 10, 12]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_matmul(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2], [3, 4], [5, 6]]) + C = DM([[22, 28], [49, 64]]) + assert A.matmul(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_mul_elementwise(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[1, 4, 9], [16, 25, 36]]) + assert A.mul_elementwise(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_neg(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[-1, -2, -3], [-4, -5, -6]]) + assert A.neg() == C + + +@pytest.mark.parametrize('DM', DM_all) +def test_XXM_convert_to(DM): + A = DM([[1, 2, 3], [4, 5, 6]], ZZ) + B = DM([[1, 2, 3], [4, 5, 6]], QQ) + assert A.convert_to(QQ) == B + assert B.convert_to(ZZ) == A + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_scc(DM): + A = DM([ + [0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 1], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 1]]) + assert A.scc() == [[0, 1], [2], [3, 5], [4]] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_hstack(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[7, 8], [9, 10]]) + C = DM([[1, 2, 3, 7, 8], [4, 5, 6, 9, 10]]) + ABC = DM([[1, 2, 3, 7, 8, 1, 2, 3, 7, 8], + [4, 5, 6, 9, 10, 4, 5, 6, 9, 10]]) + assert A.hstack(B) == C + assert A.hstack(B, C) == ABC + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_vstack(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[7, 8, 9]]) + C = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + ABC = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.vstack(B) == C + assert A.vstack(B, C) == ABC + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_applyfunc(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[2, 4, 6], [8, 10, 12]]) + assert A.applyfunc(lambda x: 2*x, ZZ) == B + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_upper(DM): + assert DM([[1, 2, 3], [0, 5, 6]]).is_upper() is True + assert DM([[1, 2, 3], [4, 5, 6]]).is_upper() is False + assert DM([]).is_upper() is True + assert DM([[], []]).is_upper() is True + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_lower(DM): + assert DM([[1, 0, 0], [4, 5, 0]]).is_lower() is True + assert DM([[1, 2, 3], [4, 5, 6]]).is_lower() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_diagonal(DM): + assert DM([[1, 0, 0], [0, 5, 0]]).is_diagonal() is True + assert DM([[1, 2, 3], [4, 5, 6]]).is_diagonal() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_diagonal(DM): + assert DM([[1, 0, 0], [0, 5, 0]]).diagonal() == [1, 5] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_zero_matrix(DM): + assert DM([[0, 0, 0], [0, 0, 0]]).is_zero_matrix() is True + assert DM([[1, 0, 0], [0, 0, 0]]).is_zero_matrix() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_det_ZZ(DM): + assert DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).det() == 0 + assert DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]).det() == -3 + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_det_QQ(DM): + dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]]) + assert dM1.det() == QQ(-1,10) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_inv_QQ(DM): + dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]]) + dM2 = DM([[(-8,1), (20,3)], [(15,2), (-5,1)]]) + assert dM1.inv() == dM2 + assert dM1.matmul(dM2) == DM([[1, 0], [0, 1]]) + + dM3 = DM([[(1,2), (2,3)], [(1,4), (1,3)]]) + raises(DMNonInvertibleMatrixError, lambda: dM3.inv()) + + dM4 = DM([[(1,2), (2,3), (3,4)], [(1,4), (1,3), (1,2)]]) + raises(DMNonSquareMatrixError, lambda: dM4.inv()) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_inv_ZZ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + # XXX: Maybe this should return a DM over QQ instead? + # XXX: Handle unimodular matrices? + raises(DMDomainError, lambda: dM1.inv()) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_charpoly_ZZ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + assert dM1.charpoly() == [1, -16, -12, 3] + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_charpoly_QQ(DM): + dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]]) + assert dM1.charpoly() == [QQ(1,1), QQ(-13,10), QQ(-1,10)] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_lu_solve_ZZ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + dM2 = DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + raises(DMDomainError, lambda: dM1.lu_solve(dM2)) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_lu_solve_QQ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + dM2 = DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + dM3 = DM([[(-2,3),(-4,3),(1,1)],[(-2,3),(11,3),(-2,1)],[(1,1),(-2,1),(1,1)]]) + assert dM1.lu_solve(dM2) == dM3 == dM1.inv() + + dM4 = DM([[1, 2, 3], [4, 5, 6]]) + dM5 = DM([[1, 0], [0, 1], [0, 0]]) + raises(DMShapeError, lambda: dM4.lu_solve(dM5)) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_nullspace_QQ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + # XXX: Change the signature to just return the nullspace. Possibly + # returning the rank or nullity makes sense but the list of nonpivots is + # not useful. + assert dM1.nullspace() == (DM([[1, -2, 1]]), [2]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_lll(DM): + M = DM([[1, 2, 3], [4, 5, 20]]) + M_lll = DM([[1, 2, 3], [-1, -5, 5]]) + T = DM([[1, 0], [-5, 1]]) + assert M.lll() == M_lll + assert M.lll_transform() == (M_lll, T) + assert T.matmul(M) == M_lll + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_mixed_signs(DM): + lol = [[QQ(1), QQ(-2)], [QQ(-3), QQ(4)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_large_matrix(DM): + lol = [[QQ(i + j) for j in range(10)] for i in range(10)] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_identity_matrix(DM): + T = type(DM([[0]])) + A = T.eye(3, QQ) + Q, R = A.qr() + assert Q == A + assert R == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + assert Q.shape == (3, 3) + assert R.shape == (3, 3) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_square_matrix(DM): + lol = [[QQ(3), QQ(1)], [QQ(4), QQ(3)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_matrix_with_zero_columns(DM): + lol = [[QQ(3), QQ(0)], [QQ(4), QQ(0)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_linearly_dependent_columns(DM): + lol = [[QQ(1), QQ(2)], [QQ(2), QQ(4)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_qr_non_field(DM): + lol = [[ZZ(3), ZZ(1)], [ZZ(4), ZZ(3)]] + A = DM(lol) + with pytest.raises(DMDomainError): + A.qr() + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_field(DM): + lol = [[QQ(3), QQ(1)], [QQ(4), QQ(3)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_tall_matrix(DM): + lol = [[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_wide_matrix(DM): + lol = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + A = DM(lol) + Q, R = A.qr() + assert Q.matmul(R) == A + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_empty_matrix_0x0(DM): + T = type(DM([[0]])) + A = T.zeros((0, 0), QQ) + Q, R = A.qr() + assert Q.matmul(R).shape == A.shape + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + assert Q.shape == (0, 0) + assert R.shape == (0, 0) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_empty_matrix_2x0(DM): + T = type(DM([[0]])) + A = T.zeros((2, 0), QQ) + Q, R = A.qr() + assert Q.matmul(R).shape == A.shape + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + assert Q.shape == (2, 0) + assert R.shape == (0, 0) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_qr_empty_matrix_0x2(DM): + T = type(DM([[0]])) + A = T.zeros((0, 2), QQ) + Q, R = A.qr() + assert Q.matmul(R).shape == A.shape + assert (Q.transpose().matmul(Q)).is_diagonal + assert R.is_upper + assert Q.shape == (0, 0) + assert R.shape == (0, 2) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_fflu(DM): + A = DM([[1, 2], [3, 4]]) + P, L, D, U = A.fflu() + A_field = A.convert_to(QQ) + P_field = P.convert_to(QQ) + L_field = L.convert_to(QQ) + D_field = D.convert_to(QQ) + U_field = U.convert_to(QQ) + assert P.shape == A.shape + assert L.shape == A.shape + assert D.shape == A.shape + assert U.shape == A.shape + assert P == DM([[1, 0], [0, 1]]) + assert L == DM([[1, 0], [3, -2]]) + assert D == DM([[1, 0], [0, -2]]) + assert U == DM([[1, 2], [0, -2]]) + assert L_field.matmul(D_field.inv()).matmul(U_field) == P_field.matmul(A_field) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__init__.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38403fdf80be22d47589a346d1b1878b982c3c93 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/__init__.py @@ -0,0 +1,27 @@ +"""Computational algebraic field theory. """ + +__all__ = [ + 'minpoly', 'minimal_polynomial', + + 'field_isomorphism', 'primitive_element', 'to_number_field', + + 'isolate', + + 'round_two', + + 'prime_decomp', 'prime_valuation', + + 'galois_group', +] + +from .minpoly import minpoly, minimal_polynomial + +from .subfield import field_isomorphism, primitive_element, to_number_field + +from .utilities import isolate + +from .basis import round_two + +from .primes import prime_decomp, prime_valuation + +from .galoisgroups import galois_group diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf5b88638b0a92224907898fb826ba9381809356 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/basis.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/basis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dfd076d211064ac5511be2160a139c55530f18b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/basis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/exceptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/exceptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37690e23a495ec9d409dd2bf163f0b99cd2a0c2a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/exceptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/galois_resolvents.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/galois_resolvents.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e631aec17cbe6566f19e09cf4ed064f3fa2c7d27 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/galois_resolvents.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/galoisgroups.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/galoisgroups.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf6b07e379d6a24773f72cce4e3b8e540a530ebd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/galoisgroups.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/minpoly.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/minpoly.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6396793a3115e0261498c9542d60cf54cf7c6daa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/minpoly.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/modules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7160ff3e5ae108aaaa0670dfeae156bfd912781d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/primes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/primes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0cdc555bfdcf33a4cd5213dd147c2d44465677a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/primes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/resolvent_lookup.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/resolvent_lookup.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da919d1a441fed6c1b6758c27d6b3f4525febc8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/resolvent_lookup.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/subfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/subfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60cf1863bb558a429fa7bc41457efbf9c5632895 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/subfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/utilities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..600bf1b72ca4c6fbf14815b374743692ef569016 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/__pycache__/utilities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/basis.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/basis.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9cb41925973b3a10a80cc6ba1442cf44330971 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/basis.py @@ -0,0 +1,246 @@ +"""Computing integral bases for number fields. """ + +from sympy.polys.polytools import Poly +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.utilities.decorator import public +from .modules import ModuleEndomorphism, ModuleHomomorphism, PowerBasis +from .utilities import extract_fundamental_discriminant + + +def _apply_Dedekind_criterion(T, p): + r""" + Apply the "Dedekind criterion" to test whether the order needs to be + enlarged relative to a given prime *p*. + """ + x = T.gen + T_bar = Poly(T, modulus=p) + lc, fl = T_bar.factor_list() + assert lc == 1 + g_bar = Poly(1, x, modulus=p) + for ti_bar, _ in fl: + g_bar *= ti_bar + h_bar = T_bar // g_bar + g = Poly(g_bar, domain=ZZ) + h = Poly(h_bar, domain=ZZ) + f = (g * h - T) // p + f_bar = Poly(f, modulus=p) + Z_bar = f_bar + for b in [g_bar, h_bar]: + Z_bar = Z_bar.gcd(b) + U_bar = T_bar // Z_bar + m = Z_bar.degree() + return U_bar, m + + +def nilradical_mod_p(H, p, q=None): + r""" + Compute the nilradical mod *p* for a given order *H*, and prime *p*. + + Explanation + =========== + + This is the ideal $I$ in $H/pH$ consisting of all elements some positive + power of which is zero in this quotient ring, i.e. is a multiple of *p*. + + Parameters + ========== + + H : :py:class:`~.Submodule` + The given order. + p : int + The rational prime. + q : int, optional + If known, the smallest power of *p* that is $>=$ the dimension of *H*. + If not provided, we compute it here. + + Returns + ======= + + :py:class:`~.Module` representing the nilradical mod *p* in *H*. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + (See Lemma 6.1.6.) + + """ + n = H.n + if q is None: + q = p + while q < n: + q *= p + phi = ModuleEndomorphism(H, lambda x: x**q) + return phi.kernel(modulus=p) + + +def _second_enlargement(H, p, q): + r""" + Perform the second enlargement in the Round Two algorithm. + """ + Ip = nilradical_mod_p(H, p, q=q) + B = H.parent.submodule_from_matrix(H.matrix * Ip.matrix, denom=H.denom) + C = B + p*H + E = C.endomorphism_ring() + phi = ModuleHomomorphism(H, E, lambda x: E.inner_endomorphism(x)) + gamma = phi.kernel(modulus=p) + G = H.parent.submodule_from_matrix(H.matrix * gamma.matrix, denom=H.denom * p) + H1 = G + H + return H1, Ip + + +@public +def round_two(T, radicals=None): + r""" + Zassenhaus's "Round 2" algorithm. + + Explanation + =========== + + Carry out Zassenhaus's "Round 2" algorithm on an irreducible polynomial + *T* over :ref:`ZZ` or :ref:`QQ`. This computes an integral basis and the + discriminant for the field $K = \mathbb{Q}[x]/(T(x))$. + + Alternatively, you may pass an :py:class:`~.AlgebraicField` instance, in + place of the polynomial *T*, in which case the algorithm is applied to the + minimal polynomial for the field's primitive element. + + Ordinarily this function need not be called directly, as one can instead + access the :py:meth:`~.AlgebraicField.maximal_order`, + :py:meth:`~.AlgebraicField.integral_basis`, and + :py:meth:`~.AlgebraicField.discriminant` methods of an + :py:class:`~.AlgebraicField`. + + Examples + ======== + + Working through an AlgebraicField: + + >>> from sympy import Poly, QQ + >>> from sympy.abc import x + >>> T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + >>> K = QQ.alg_field_from_poly(T, "theta") + >>> print(K.maximal_order()) + Submodule[[2, 0, 0], [0, 2, 0], [0, 1, 1]]/2 + >>> print(K.discriminant()) + -503 + >>> print(K.integral_basis(fmt='sympy')) + [1, theta, theta/2 + theta**2/2] + + Calling directly: + + >>> from sympy import Poly + >>> from sympy.abc import x + >>> from sympy.polys.numberfields.basis import round_two + >>> T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + >>> print(round_two(T)) + (Submodule[[2, 0, 0], [0, 2, 0], [0, 1, 1]]/2, -503) + + The nilradicals mod $p$ that are sometimes computed during the Round Two + algorithm may be useful in further calculations. Pass a dictionary under + `radicals` to receive these: + + >>> T = Poly(x**3 + 3*x**2 + 5) + >>> rad = {} + >>> ZK, dK = round_two(T, radicals=rad) + >>> print(rad) + {3: Submodule[[-1, 1, 0], [-1, 0, 1]]} + + Parameters + ========== + + T : :py:class:`~.Poly`, :py:class:`~.AlgebraicField` + Either (1) the irreducible polynomial over :ref:`ZZ` or :ref:`QQ` + defining the number field, or (2) an :py:class:`~.AlgebraicField` + representing the number field itself. + + radicals : dict, optional + This is a way for any $p$-radicals (if computed) to be returned by + reference. If desired, pass an empty dictionary. If the algorithm + reaches the point where it computes the nilradical mod $p$ of the ring + of integers $Z_K$, then an $\mathbb{F}_p$-basis for this ideal will be + stored in this dictionary under the key ``p``. This can be useful for + other algorithms, such as prime decomposition. + + Returns + ======= + + Pair ``(ZK, dK)``, where: + + ``ZK`` is a :py:class:`~sympy.polys.numberfields.modules.Submodule` + representing the maximal order. + + ``dK`` is the discriminant of the field $K = \mathbb{Q}[x]/(T(x))$. + + See Also + ======== + + .AlgebraicField.maximal_order + .AlgebraicField.integral_basis + .AlgebraicField.discriminant + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + + """ + K = None + if isinstance(T, AlgebraicField): + K, T = T, T.ext.minpoly_of_element() + if ( not T.is_univariate + or not T.is_irreducible + or T.domain not in [ZZ, QQ]): + raise ValueError('Round 2 requires an irreducible univariate polynomial over ZZ or QQ.') + T, _ = T.make_monic_over_integers_by_scaling_roots() + n = T.degree() + D = T.discriminant() + D_modulus = ZZ.from_sympy(abs(D)) + # D must be 0 or 1 mod 4 (see Cohen Sec 4.4), which ensures we can write + # it in the form D = D_0 * F**2, where D_0 is 1 or a fundamental discriminant. + _, F = extract_fundamental_discriminant(D) + Ztheta = PowerBasis(K or T) + H = Ztheta.whole_submodule() + nilrad = None + while F: + # Next prime: + p, e = F.popitem() + U_bar, m = _apply_Dedekind_criterion(T, p) + if m == 0: + continue + # For a given prime p, the first enlargement of the order spanned by + # the current basis can be done in a simple way: + U = Ztheta.element_from_poly(Poly(U_bar, domain=ZZ)) + # TODO: + # Theory says only first m columns of the U//p*H term below are needed. + # Could be slightly more efficient to use only those. Maybe `Submodule` + # class should support a slice operator? + H = H.add(U // p * H, hnf_modulus=D_modulus) + if e <= m: + continue + # A second, and possibly more, enlargements for p will be needed. + # These enlargements require a more involved procedure. + q = p + while q < n: + q *= p + H1, nilrad = _second_enlargement(H, p, q) + while H1 != H: + H = H1 + H1, nilrad = _second_enlargement(H, p, q) + # Note: We do not store all nilradicals mod p, only the very last. This is + # because, unless computed against the entire integral basis, it might not + # be accurate. (In other words, if H was not already equal to ZK when we + # passed it to `_second_enlargement`, then we can't trust the nilradical + # so computed.) Example: if T(x) = x ** 3 + 15 * x ** 2 - 9 * x + 13, then + # F is divisible by 2, 3, and 7, and the nilradical mod 2 as computed above + # will not be accurate for the full, maximal order ZK. + if nilrad is not None and isinstance(radicals, dict): + radicals[p] = nilrad + ZK = H + # Pre-set expensive boolean properties which we already know to be true: + ZK._starts_with_unity = True + ZK._is_sq_maxrank_HNF = True + dK = (D * ZK.matrix.det() ** 2) // ZK.denom ** (2 * n) + return ZK, dK diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/exceptions.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0d1ddc23c39295626fa036cf34974f50e4f53a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/exceptions.py @@ -0,0 +1,54 @@ +"""Special exception classes for numberfields. """ + + +class ClosureFailure(Exception): + r""" + Signals that a :py:class:`ModuleElement` which we tried to represent in a + certain :py:class:`Module` cannot in fact be represented there. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.numberfields.modules import PowerBasis, to_col + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + + Because we are in a cyclotomic field, the power basis ``A`` is an integral + basis, and the submodule ``B`` is just the ideal $(2)$. Therefore ``B`` can + represent an element having all even coefficients over the power basis: + + >>> a1 = A(to_col([2, 4, 6, 8])) + >>> print(B.represent(a1)) + DomainMatrix([[1], [2], [3], [4]], (4, 1), ZZ) + + but ``B`` cannot represent an element with an odd coefficient: + + >>> a2 = A(to_col([1, 2, 2, 2])) + >>> B.represent(a2) + Traceback (most recent call last): + ... + ClosureFailure: Element in QQ-span but not ZZ-span of this basis. + + """ + pass + + +class StructureError(Exception): + r""" + Represents cases in which an algebraic structure was expected to have a + certain property, or be of a certain type, but was not. + """ + pass + + +class MissingUnityError(StructureError): + r"""Structure should contain a unity element but does not.""" + pass + + +__all__ = [ + 'ClosureFailure', 'StructureError', 'MissingUnityError', +] diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/galois_resolvents.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/galois_resolvents.py new file mode 100644 index 0000000000000000000000000000000000000000..5d73b56870a498f09102787da3517e7520edb3db --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/galois_resolvents.py @@ -0,0 +1,676 @@ +r""" +Galois resolvents + +Each of the functions in ``sympy.polys.numberfields.galoisgroups`` that +computes Galois groups for a particular degree $n$ uses resolvents. Given the +polynomial $T$ whose Galois group is to be computed, a resolvent is a +polynomial $R$ whose roots are defined as functions of the roots of $T$. + +One way to compute the coefficients of $R$ is by approximating the roots of $T$ +to sufficient precision. This module defines a :py:class:`~.Resolvent` class +that handles this job, determining the necessary precision, and computing $R$. + +In some cases, the coefficients of $R$ are symmetric in the roots of $T$, +meaning they are equal to fixed functions of the coefficients of $T$. Therefore +another approach is to compute these functions once and for all, and record +them in a lookup table. This module defines code that can compute such tables. +The tables for polynomials $T$ of degrees 4 through 6, produced by this code, +are recorded in the resolvent_lookup.py module. + +""" + +from sympy.core.evalf import ( + evalf, fastlog, _evalf_with_bounded_error, quad_to_mpmath, +) +from sympy.core.symbol import symbols, Dummy +from sympy.polys.densetools import dup_eval +from sympy.polys.domains import ZZ +from sympy.polys.orderings import lex +from sympy.polys.polyroots import preprocess_roots +from sympy.polys.polytools import Poly +from sympy.polys.rings import xring +from sympy.polys.specialpolys import symmetric_poly +from sympy.utilities.lambdify import lambdify + +from mpmath import MPContext +from mpmath.libmp.libmpf import prec_to_dps + + +class GaloisGroupException(Exception): + ... + + +class ResolventException(GaloisGroupException): + ... + + +class Resolvent: + r""" + If $G$ is a subgroup of the symmetric group $S_n$, + $F$ a multivariate polynomial in $\mathbb{Z}[X_1, \ldots, X_n]$, + $H$ the stabilizer of $F$ in $G$ (i.e. the permutations $\sigma$ such that + $F(X_{\sigma(1)}, \ldots, X_{\sigma(n)}) = F(X_1, \ldots, X_n)$), and $s$ + a set of left coset representatives of $H$ in $G$, then the resolvent + polynomial $R(Y)$ is the product over $\sigma \in s$ of + $Y - F(X_{\sigma(1)}, \ldots, X_{\sigma(n)})$. + + For example, consider the resolvent for the form + $$F = X_0 X_2 + X_1 X_3$$ + and the group $G = S_4$. In this case, the stabilizer $H$ is the dihedral + group $D4 = < (0123), (02) >$, and a set of representatives of $G/H$ is + $\{I, (01), (03)\}$. The resolvent can be constructed as follows: + + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy.core.symbol import symbols + >>> from sympy.polys.numberfields.galoisgroups import Resolvent + >>> X = symbols('X0 X1 X2 X3') + >>> F = X[0]*X[2] + X[1]*X[3] + >>> s = [Permutation([0, 1, 2, 3]), Permutation([1, 0, 2, 3]), + ... Permutation([3, 1, 2, 0])] + >>> R = Resolvent(F, X, s) + + This resolvent has three roots, which are the conjugates of ``F`` under the + three permutations in ``s``: + + >>> R.root_lambdas[0](*X) + X0*X2 + X1*X3 + >>> R.root_lambdas[1](*X) + X0*X3 + X1*X2 + >>> R.root_lambdas[2](*X) + X0*X1 + X2*X3 + + Resolvents are useful for computing Galois groups. Given a polynomial $T$ + of degree $n$, we will use a resolvent $R$ where $Gal(T) \leq G \leq S_n$. + We will then want to substitute the roots of $T$ for the variables $X_i$ + in $R$, and study things like the discriminant of $R$, and the way $R$ + factors over $\mathbb{Q}$. + + From the symmetry in $R$'s construction, and since $Gal(T) \leq G$, we know + from Galois theory that the coefficients of $R$ must lie in $\mathbb{Z}$. + This allows us to compute the coefficients of $R$ by approximating the + roots of $T$ to sufficient precision, plugging these values in for the + variables $X_i$ in the coefficient expressions of $R$, and then simply + rounding to the nearest integer. + + In order to determine a sufficient precision for the roots of $T$, this + ``Resolvent`` class imposes certain requirements on the form ``F``. It + could be possible to design a different ``Resolvent`` class, that made + different precision estimates, and different assumptions about ``F``. + + ``F`` must be homogeneous, and all terms must have unit coefficient. + Furthermore, if $r$ is the number of terms in ``F``, and $t$ the total + degree, and if $m$ is the number of conjugates of ``F``, i.e. the number + of permutations in ``s``, then we require that $m < r 2^t$. Again, it is + not impossible to work with forms ``F`` that violate these assumptions, but + this ``Resolvent`` class requires them. + + Since determining the integer coefficients of the resolvent for a given + polynomial $T$ is one of the main problems this class solves, we take some + time to explain the precision bounds it uses. + + The general problem is: + Given a multivariate polynomial $P \in \mathbb{Z}[X_1, \ldots, X_n]$, and a + bound $M \in \mathbb{R}_+$, compute an $\varepsilon > 0$ such that for any + complex numbers $a_1, \ldots, a_n$ with $|a_i| < M$, if the $a_i$ are + approximated to within an accuracy of $\varepsilon$ by $b_i$, that is, + $|a_i - b_i| < \varepsilon$ for $i = 1, \ldots, n$, then + $|P(a_1, \ldots, a_n) - P(b_1, \ldots, b_n)| < 1/2$. In other words, if it + is known that $P(a_1, \ldots, a_n) = c$ for some $c \in \mathbb{Z}$, then + $P(b_1, \ldots, b_n)$ can be rounded to the nearest integer in order to + determine $c$. + + To derive our error bound, consider the monomial $xyz$. Defining + $d_i = b_i - a_i$, our error is + $|(a_1 + d_1)(a_2 + d_2)(a_3 + d_3) - a_1 a_2 a_3|$, which is bounded + above by $|(M + \varepsilon)^3 - M^3|$. Passing to a general monomial of + total degree $t$, this expression is bounded by + $M^{t-1}\varepsilon(t + 2^t\varepsilon/M)$ provided $\varepsilon < M$, + and by $(t+1)M^{t-1}\varepsilon$ provided $\varepsilon < M/2^t$. + But since our goal is to make the error less than $1/2$, we will choose + $\varepsilon < 1/(2(t+1)M^{t-1})$, which implies the condition that + $\varepsilon < M/2^t$, as long as $M \geq 2$. + + Passing from the general monomial to the general polynomial is easy, by + scaling and summing error bounds. + + In our specific case, we are given a homogeneous polynomial $F$ of + $r$ terms and total degree $t$, all of whose coefficients are $\pm 1$. We + are given the $m$ permutations that make the conjugates of $F$, and + we want to bound the error in the coefficients of the monic polynomial + $R(Y)$ having $F$ and its conjugates as roots (i.e. the resolvent). + + For $j$ from $1$ to $m$, the coefficient of $Y^{m-j}$ in $R(Y)$ is the + $j$th elementary symmetric polynomial in the conjugates of $F$. This sums + the products of these conjugates, taken $j$ at a time, in all possible + combinations. There are $\binom{m}{j}$ such combinations, and each product + of $j$ conjugates of $F$ expands to a sum of $r^j$ terms, each of unit + coefficient, and total degree $jt$. An error bound for the $j$th coeff of + $R$ is therefore + $$\binom{m}{j} r^j (jt + 1) M^{jt - 1} \varepsilon$$ + When our goal is to evaluate all the coefficients of $R$, we will want to + use the maximum of these error bounds. It is clear that this bound is + strictly increasing for $j$ up to the ceiling of $m/2$. After that point, + the first factor $\binom{m}{j}$ begins to decrease, while the others + continue to increase. However, the binomial coefficient never falls by more + than a factor of $1/m$ at a time, so our assumptions that $M \geq 2$ and + $m < r 2^t$ are enough to tell us that the constant coefficient of $R$, + i.e. that where $j = m$, has the largest error bound. Therefore we can use + $$r^m (mt + 1) M^{mt - 1} \varepsilon$$ + as our error bound for all the coefficients. + + Note that this bound is also (more than) adequate to determine whether any + of the roots of $R$ is an integer. Each of these roots is a single + conjugate of $F$, which contains less error than the trace, i.e. the + coefficient of $Y^{m - 1}$. By rounding the roots of $R$ to the nearest + integers, we therefore get all the candidates for integer roots of $R$. By + plugging these candidates into $R$, we can check whether any of them + actually is a root. + + Note: We take the definition of resolvent from Cohen, but the error bound + is ours. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + (Def 6.3.2) + + """ + + def __init__(self, F, X, s): + r""" + Parameters + ========== + + F : :py:class:`~.Expr` + polynomial in the symbols in *X* + X : list of :py:class:`~.Symbol` + s : list of :py:class:`~.Permutation` + representing the cosets of the stabilizer of *F* in + some subgroup $G$ of $S_n$, where $n$ is the length of *X*. + """ + self.F = F + self.X = X + self.s = s + + # Number of conjugates: + self.m = len(s) + # Total degree of F (computed below): + self.t = None + # Number of terms in F (computed below): + self.r = 0 + + for monom, coeff in Poly(F).terms(): + if abs(coeff) != 1: + raise ResolventException('Resolvent class expects forms with unit coeffs') + t = sum(monom) + if t != self.t and self.t is not None: + raise ResolventException('Resolvent class expects homogeneous forms') + self.t = t + self.r += 1 + + m, t, r = self.m, self.t, self.r + if not m < r * 2**t: + raise ResolventException('Resolvent class expects m < r*2^t') + M = symbols('M') + # Precision sufficient for computing the coeffs of the resolvent: + self.coeff_prec_func = Poly(r**m*(m*t + 1)*M**(m*t - 1)) + # Precision sufficient for checking whether any of the roots of the + # resolvent are integers: + self.root_prec_func = Poly(r*(t + 1)*M**(t - 1)) + + # The conjugates of F are the roots of the resolvent. + # For evaluating these to required numerical precisions, we need + # lambdified versions. + # Note: for a given permutation sigma, the conjugate (sigma F) is + # equivalent to lambda [sigma^(-1) X]: F. + self.root_lambdas = [ + lambdify((~s[j])(X), F) + for j in range(self.m) + ] + + # For evaluating the coeffs, we'll also need lambdified versions of + # the elementary symmetric functions for degree m. + Y = symbols('Y') + R = symbols(' '.join(f'R{i}' for i in range(m))) + f = 1 + for r in R: + f *= (Y - r) + C = Poly(f, Y).coeffs() + self.esf_lambdas = [lambdify(R, c) for c in C] + + def get_prec(self, M, target='coeffs'): + r""" + For a given upper bound *M* on the magnitude of the complex numbers to + be plugged in for this resolvent's symbols, compute a sufficient + precision for evaluating those complex numbers, such that the + coefficients, or the integer roots, of the resolvent can be determined. + + Parameters + ========== + + M : real number + Upper bound on magnitude of the complex numbers to be plugged in. + + target : str, 'coeffs' or 'roots', default='coeffs' + Name the task for which a sufficient precision is desired. + This is either determining the coefficients of the resolvent + ('coeffs') or determining its possible integer roots ('roots'). + The latter may require significantly lower precision. + + Returns + ======= + + int $m$ + such that $2^{-m}$ is a sufficient upper bound on the + error in approximating the complex numbers to be plugged in. + + """ + # As explained in the docstring for this class, our precision estimates + # require that M be at least 2. + M = max(M, 2) + f = self.coeff_prec_func if target == 'coeffs' else self.root_prec_func + r, _, _, _ = evalf(2*f(M), 1, {}) + return fastlog(r) + 1 + + def approximate_roots_of_poly(self, T, target='coeffs'): + """ + Approximate the roots of a given polynomial *T* to sufficient precision + in order to evaluate this resolvent's coefficients, or determine + whether the resolvent has an integer root. + + Parameters + ========== + + T : :py:class:`~.Poly` + + target : str, 'coeffs' or 'roots', default='coeffs' + Set the approximation precision to be sufficient for the desired + task, which is either determining the coefficients of the resolvent + ('coeffs') or determining its possible integer roots ('roots'). + The latter may require significantly lower precision. + + Returns + ======= + + list of elements of :ref:`CC` + + """ + ctx = MPContext() + # Because sympy.polys.polyroots._integer_basis() is called when a CRootOf + # is formed, we proactively extract the integer basis now. This means that + # when we call T.all_roots(), every root will be a CRootOf, not a Mul + # of Integer*CRootOf. + coeff, T = preprocess_roots(T) + coeff = ctx.mpf(str(coeff)) + + scaled_roots = T.all_roots(radicals=False) + + # Since we're going to be approximating the roots of T anyway, we can + # get a good upper bound on the magnitude of the roots by starting with + # a very low precision approx. + approx0 = [coeff * quad_to_mpmath(_evalf_with_bounded_error(r, m=0)) for r in scaled_roots] + # Here we add 1 to account for the possible error in our initial approximation. + M = max(abs(b) for b in approx0) + 1 + m = self.get_prec(M, target=target) + n = fastlog(M._mpf_) + 1 + p = m + n + 1 + ctx.prec = p + d = prec_to_dps(p) + + approx1 = [r.eval_approx(d, return_mpmath=True) for r in scaled_roots] + approx1 = [coeff*ctx.mpc(r) for r in approx1] + + return approx1 + + @staticmethod + def round_mpf(a): + if isinstance(a, int): + return a + # If we use python's built-in `round()`, we lose precision. + # If we use `ZZ` directly, we may add or subtract 1. + # + # XXX: We have to convert to int before converting to ZZ because + # flint.fmpz cannot convert a mpmath mpf. + return ZZ(int(a.context.nint(a))) + + def round_roots_to_integers_for_poly(self, T): + """ + For a given polynomial *T*, round the roots of this resolvent to the + nearest integers. + + Explanation + =========== + + None of the integers returned by this method is guaranteed to be a + root of the resolvent; however, if the resolvent has any integer roots + (for the given polynomial *T*), then they must be among these. + + If the coefficients of the resolvent are also desired, then this method + should not be used. Instead, use the ``eval_for_poly`` method. This + method may be significantly faster than ``eval_for_poly``. + + Parameters + ========== + + T : :py:class:`~.Poly` + + Returns + ======= + + dict + Keys are the indices of those permutations in ``self.s`` such that + the corresponding root did round to a rational integer. + + Values are :ref:`ZZ`. + + + """ + approx_roots_of_T = self.approximate_roots_of_poly(T, target='roots') + approx_roots_of_self = [r(*approx_roots_of_T) for r in self.root_lambdas] + return { + i: self.round_mpf(r.real) + for i, r in enumerate(approx_roots_of_self) + if self.round_mpf(r.imag) == 0 + } + + def eval_for_poly(self, T, find_integer_root=False): + r""" + Compute the integer values of the coefficients of this resolvent, when + plugging in the roots of a given polynomial. + + Parameters + ========== + + T : :py:class:`~.Poly` + + find_integer_root : ``bool``, default ``False`` + If ``True``, then also determine whether the resolvent has an + integer root, and return the first one found, along with its + index, i.e. the index of the permutation ``self.s[i]`` it + corresponds to. + + Returns + ======= + + Tuple ``(R, a, i)`` + + ``R`` is this resolvent as a dense univariate polynomial over + :ref:`ZZ`, i.e. a list of :ref:`ZZ`. + + If *find_integer_root* was ``True``, then ``a`` and ``i`` are the + first integer root found, and its index, if one exists. + Otherwise ``a`` and ``i`` are both ``None``. + + """ + approx_roots_of_T = self.approximate_roots_of_poly(T, target='coeffs') + approx_roots_of_self = [r(*approx_roots_of_T) for r in self.root_lambdas] + approx_coeffs_of_self = [c(*approx_roots_of_self) for c in self.esf_lambdas] + + R = [] + for c in approx_coeffs_of_self: + if self.round_mpf(c.imag) != 0: + # If precision was enough, this should never happen. + raise ResolventException(f"Got non-integer coeff for resolvent: {c}") + R.append(self.round_mpf(c.real)) + + a0, i0 = None, None + + if find_integer_root: + for i, r in enumerate(approx_roots_of_self): + if self.round_mpf(r.imag) != 0: + continue + if not dup_eval(R, (a := self.round_mpf(r.real)), ZZ): + a0, i0 = a, i + break + + return R, a0, i0 + + +def wrap(text, width=80): + """Line wrap a polynomial expression. """ + out = '' + col = 0 + for c in text: + if c == ' ' and col > width: + c, col = '\n', 0 + else: + col += 1 + out += c + return out + + +def s_vars(n): + """Form the symbols s1, s2, ..., sn to stand for elem. symm. polys. """ + return symbols([f's{i + 1}' for i in range(n)]) + + +def sparse_symmetrize_resolvent_coeffs(F, X, s, verbose=False): + """ + Compute the coefficients of a resolvent as functions of the coefficients of + the associated polynomial. + + F must be a sparse polynomial. + """ + import time, sys + # Roots of resolvent as multivariate forms over vars X: + root_forms = [ + F.compose(list(zip(X, sigma(X)))) + for sigma in s + ] + + # Coeffs of resolvent (besides lead coeff of 1) as symmetric forms over vars X: + Y = [Dummy(f'Y{i}') for i in range(len(s))] + coeff_forms = [] + for i in range(1, len(s) + 1): + if verbose: + print('----') + print(f'Computing symmetric poly of degree {i}...') + sys.stdout.flush() + t0 = time.time() + G = symmetric_poly(i, *Y) + t1 = time.time() + if verbose: + print(f'took {t1 - t0} seconds') + print('lambdifying...') + sys.stdout.flush() + t0 = time.time() + C = lambdify(Y, (-1)**i*G) + t1 = time.time() + if verbose: + print(f'took {t1 - t0} seconds') + sys.stdout.flush() + coeff_forms.append(C) + + coeffs = [] + for i, f in enumerate(coeff_forms): + if verbose: + print('----') + print(f'Plugging root forms into elem symm poly {i+1}...') + sys.stdout.flush() + t0 = time.time() + g = f(*root_forms) + t1 = time.time() + coeffs.append(g) + if verbose: + print(f'took {t1 - t0} seconds') + sys.stdout.flush() + + # Now symmetrize these coeffs. This means recasting them as polynomials in + # the elementary symmetric polys over X. + symmetrized = [] + symmetrization_times = [] + ss = s_vars(len(X)) + for i, A in list(enumerate(coeffs)): + if verbose: + print('-----') + print(f'Coeff {i+1}...') + sys.stdout.flush() + t0 = time.time() + B, rem, _ = A.symmetrize() + t1 = time.time() + if rem != 0: + msg = f"Got nonzero remainder {rem} for resolvent (F, X, s) = ({F}, {X}, {s})" + raise ResolventException(msg) + B_str = str(B.as_expr(*ss)) + symmetrized.append(B_str) + symmetrization_times.append(t1 - t0) + if verbose: + print(wrap(B_str)) + print(f'took {t1 - t0} seconds') + sys.stdout.flush() + + return symmetrized, symmetrization_times + + +def define_resolvents(): + """Define all the resolvents for polys T of degree 4 through 6. """ + from sympy.combinatorics.galois import PGL2F5 + from sympy.combinatorics.permutations import Permutation + + R4, X4 = xring("X0,X1,X2,X3", ZZ, lex) + X = X4 + + # The one resolvent used in `_galois_group_degree_4_lookup()`: + F40 = X[0]*X[1]**2 + X[1]*X[2]**2 + X[2]*X[3]**2 + X[3]*X[0]**2 + s40 = [ + Permutation(3), + Permutation(3)(0, 1), + Permutation(3)(0, 2), + Permutation(3)(0, 3), + Permutation(3)(1, 2), + Permutation(3)(2, 3), + ] + + # First resolvent used in `_galois_group_degree_4_root_approx()`: + F41 = X[0]*X[2] + X[1]*X[3] + s41 = [ + Permutation(3), + Permutation(3)(0, 1), + Permutation(3)(0, 3) + ] + + R5, X5 = xring("X0,X1,X2,X3,X4", ZZ, lex) + X = X5 + + # First resolvent used in `_galois_group_degree_5_hybrid()`, + # and only one used in `_galois_group_degree_5_lookup_ext_factor()`: + F51 = ( X[0]**2*(X[1]*X[4] + X[2]*X[3]) + + X[1]**2*(X[2]*X[0] + X[3]*X[4]) + + X[2]**2*(X[3]*X[1] + X[4]*X[0]) + + X[3]**2*(X[4]*X[2] + X[0]*X[1]) + + X[4]**2*(X[0]*X[3] + X[1]*X[2])) + s51 = [ + Permutation(4), + Permutation(4)(0, 1), + Permutation(4)(0, 2), + Permutation(4)(0, 3), + Permutation(4)(0, 4), + Permutation(4)(1, 4) + ] + + R6, X6 = xring("X0,X1,X2,X3,X4,X5", ZZ, lex) + X = X6 + + # First resolvent used in `_galois_group_degree_6_lookup()`: + H = PGL2F5() + term0 = X[0]**2*X[5]**2*(X[1]*X[4] + X[2]*X[3]) + terms = {term0.compose(list(zip(X, s(X)))) for s in H.elements} + F61 = sum(terms) + s61 = [Permutation(5)] + [Permutation(5)(0, n) for n in range(1, 6)] + + # Second resolvent used in `_galois_group_degree_6_lookup()`: + F62 = X[0]*X[1]*X[2] + X[3]*X[4]*X[5] + s62 = [Permutation(5)] + [ + Permutation(5)(i, j + 3) for i in range(3) for j in range(3) + ] + + return { + (4, 0): (F40, X4, s40), + (4, 1): (F41, X4, s41), + (5, 1): (F51, X5, s51), + (6, 1): (F61, X6, s61), + (6, 2): (F62, X6, s62), + } + + +def generate_lambda_lookup(verbose=False, trial_run=False): + """ + Generate the whole lookup table of coeff lambdas, for all resolvents. + """ + jobs = define_resolvents() + lambda_lists = {} + total_time = 0 + time_for_61 = 0 + time_for_61_last = 0 + for k, (F, X, s) in jobs.items(): + symmetrized, times = sparse_symmetrize_resolvent_coeffs(F, X, s, verbose=verbose) + + total_time += sum(times) + if k == (6, 1): + time_for_61 = sum(times) + time_for_61_last = times[-1] + + sv = s_vars(len(X)) + head = f'lambda {", ".join(str(v) for v in sv)}:' + lambda_lists[k] = ',\n '.join([ + f'{head} ({wrap(f)})' + for f in symmetrized + ]) + + if trial_run: + break + + table = ( + "# This table was generated by a call to\n" + "# `sympy.polys.numberfields.galois_resolvents.generate_lambda_lookup()`.\n" + f"# The entire job took {total_time:.2f}s.\n" + f"# Of this, Case (6, 1) took {time_for_61:.2f}s.\n" + f"# The final polynomial of Case (6, 1) alone took {time_for_61_last:.2f}s.\n" + "resolvent_coeff_lambdas = {\n") + + for k, L in lambda_lists.items(): + table += f" {k}: [\n" + table += " " + L + '\n' + table += " ],\n" + table += "}\n" + return table + + +def get_resolvent_by_lookup(T, number): + """ + Use the lookup table, to return a resolvent (as dup) for a given + polynomial *T*. + + Parameters + ========== + + T : Poly + The polynomial whose resolvent is needed + + number : int + For some degrees, there are multiple resolvents. + Use this to indicate which one you want. + + Returns + ======= + + dup + + """ + from sympy.polys.numberfields.resolvent_lookup import resolvent_coeff_lambdas + degree = T.degree() + L = resolvent_coeff_lambdas[(degree, number)] + T_coeffs = T.rep.to_list()[1:] + return [ZZ(1)] + [c(*T_coeffs) for c in L] + + +# Use +# (.venv) $ python -m sympy.polys.numberfields.galois_resolvents +# to reproduce the table found in resolvent_lookup.py +if __name__ == "__main__": + import sys + verbose = '-v' in sys.argv[1:] + trial_run = '-t' in sys.argv[1:] + table = generate_lambda_lookup(verbose=verbose, trial_run=trial_run) + print(table) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/galoisgroups.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/galoisgroups.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e424bf7554c0cedd926902e7322b9640735a8b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/galoisgroups.py @@ -0,0 +1,623 @@ +""" +Compute Galois groups of polynomials. + +We use algorithms from [1], with some modifications to use lookup tables for +resolvents. + +References +========== + +.. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + +""" + +from collections import defaultdict +import random + +from sympy.core.symbol import Dummy, symbols +from sympy.ntheory.primetest import is_square +from sympy.polys.domains import ZZ +from sympy.polys.densebasic import dup_random +from sympy.polys.densetools import dup_eval +from sympy.polys.euclidtools import dup_discriminant +from sympy.polys.factortools import dup_factor_list, dup_irreducible_p +from sympy.polys.numberfields.galois_resolvents import ( + GaloisGroupException, get_resolvent_by_lookup, define_resolvents, + Resolvent, +) +from sympy.polys.numberfields.utilities import coeff_search +from sympy.polys.polytools import (Poly, poly_from_expr, + PolificationFailed, ComputationFailed) +from sympy.polys.sqfreetools import dup_sqf_p +from sympy.utilities import public + + +class MaxTriesException(GaloisGroupException): + ... + + +def tschirnhausen_transformation(T, max_coeff=10, max_tries=30, history=None, + fixed_order=True): + r""" + Given a univariate, monic, irreducible polynomial over the integers, find + another such polynomial defining the same number field. + + Explanation + =========== + + See Alg 6.3.4 of [1]. + + Parameters + ========== + + T : Poly + The given polynomial + max_coeff : int + When choosing a transformation as part of the process, + keep the coeffs between plus and minus this. + max_tries : int + Consider at most this many transformations. + history : set, None, optional (default=None) + Pass a set of ``Poly.rep``'s in order to prevent any of these + polynomials from being returned as the polynomial ``U`` i.e. the + transformation of the given polynomial *T*. The given poly *T* will + automatically be added to this set, before we try to find a new one. + fixed_order : bool, default True + If ``True``, work through candidate transformations A(x) in a fixed + order, from small coeffs to large, resulting in deterministic behavior. + If ``False``, the A(x) are chosen randomly, while still working our way + up from small coefficients to larger ones. + + Returns + ======= + + Pair ``(A, U)`` + + ``A`` and ``U`` are ``Poly``, ``A`` is the + transformation, and ``U`` is the transformed polynomial that defines + the same number field as *T*. The polynomial ``A`` maps the roots of + *T* to the roots of ``U``. + + Raises + ====== + + MaxTriesException + if could not find a polynomial before exceeding *max_tries*. + + """ + X = Dummy('X') + n = T.degree() + if history is None: + history = set() + history.add(T.rep) + + if fixed_order: + coeff_generators = {} + deg_coeff_sum = 3 + current_degree = 2 + + def get_coeff_generator(degree): + gen = coeff_generators.get(degree, coeff_search(degree, 1)) + coeff_generators[degree] = gen + return gen + + for i in range(max_tries): + + # We never use linear A(x), since applying a fixed linear transformation + # to all roots will only multiply the discriminant of T by a square + # integer. This will change nothing important. In particular, if disc(T) + # was zero before, it will still be zero now, and typically we apply + # the transformation in hopes of replacing T by a squarefree poly. + + if fixed_order: + # If d is degree and c max coeff, we move through the dc-space + # along lines of constant sum. First d + c = 3 with (d, c) = (2, 1). + # Then d + c = 4 with (d, c) = (3, 1), (2, 2). Then d + c = 5 with + # (d, c) = (4, 1), (3, 2), (2, 3), and so forth. For a given (d, c) + # we go though all sets of coeffs where max = c, before moving on. + gen = get_coeff_generator(current_degree) + coeffs = next(gen) + m = max(abs(c) for c in coeffs) + if current_degree + m > deg_coeff_sum: + if current_degree == 2: + deg_coeff_sum += 1 + current_degree = deg_coeff_sum - 1 + else: + current_degree -= 1 + gen = get_coeff_generator(current_degree) + coeffs = next(gen) + a = [ZZ(1)] + [ZZ(c) for c in coeffs] + + else: + # We use a progressive coeff bound, up to the max specified, since it + # is preferable to succeed with smaller coeffs. + # Give each coeff bound five tries, before incrementing. + C = min(i//5 + 1, max_coeff) + d = random.randint(2, n - 1) + a = dup_random(d, -C, C, ZZ) + + A = Poly(a, T.gen) + U = Poly(T.resultant(X - A), X) + if U.rep not in history and dup_sqf_p(U.rep.to_list(), ZZ): + return A, U + raise MaxTriesException + + +def has_square_disc(T): + """Convenience to check if a Poly or dup has square discriminant. """ + d = T.discriminant() if isinstance(T, Poly) else dup_discriminant(T, ZZ) + return is_square(d) + + +def _galois_group_degree_3(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 3. + + Explanation + =========== + + Uses Prop 6.3.5 of [1]. + + """ + from sympy.combinatorics.galois import S3TransitiveSubgroups + return ((S3TransitiveSubgroups.A3, True) if has_square_disc(T) + else (S3TransitiveSubgroups.S3, False)) + + +def _galois_group_degree_4_root_approx(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 4. + + Explanation + =========== + + Follows Alg 6.3.7 of [1], using a pure root approximation approach. + + """ + from sympy.combinatorics.permutations import Permutation + from sympy.combinatorics.galois import S4TransitiveSubgroups + + X = symbols('X0 X1 X2 X3') + # We start by considering the resolvent for the form + # F = X0*X2 + X1*X3 + # and the group G = S4. In this case, the stabilizer H is D4 = < (0123), (02) >, + # and a set of representatives of G/H is {I, (01), (03)} + F1 = X[0]*X[2] + X[1]*X[3] + s1 = [ + Permutation(3), + Permutation(3)(0, 1), + Permutation(3)(0, 3) + ] + R1 = Resolvent(F1, X, s1) + + # In the second half of the algorithm (if we reach it), we use another + # form and set of coset representatives. However, we may need to permute + # them first, so cannot form their resolvent now. + F2_pre = X[0]*X[1]**2 + X[1]*X[2]**2 + X[2]*X[3]**2 + X[3]*X[0]**2 + s2_pre = [ + Permutation(3), + Permutation(3)(0, 2) + ] + + history = set() + for i in range(max_tries): + if i > 0: + # If we're retrying, need a new polynomial T. + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + + R_dup, _, i0 = R1.eval_for_poly(T, find_integer_root=True) + # If R is not squarefree, must retry. + if not dup_sqf_p(R_dup, ZZ): + continue + + # By Prop 6.3.1 of [1], Gal(T) is contained in A4 iff disc(T) is square. + sq_disc = has_square_disc(T) + + if i0 is None: + # By Thm 6.3.3 of [1], Gal(T) is not conjugate to any subgroup of the + # stabilizer H = D4 that we chose. This means Gal(T) is either A4 or S4. + return ((S4TransitiveSubgroups.A4, True) if sq_disc + else (S4TransitiveSubgroups.S4, False)) + + # Gal(T) is conjugate to a subgroup of H = D4, so it is either V, C4 + # or D4 itself. + + if sq_disc: + # Neither C4 nor D4 is contained in A4, so Gal(T) must be V. + return (S4TransitiveSubgroups.V, True) + + # Gal(T) can only be D4 or C4. + # We will now use our second resolvent, with G being that conjugate of D4 that + # Gal(T) is contained in. To determine the right conjugate, we will need + # the permutation corresponding to the integer root we found. + sigma = s1[i0] + # Applying sigma means permuting the args of F, and + # conjugating the set of coset representatives. + F2 = F2_pre.subs(zip(X, sigma(X)), simultaneous=True) + s2 = [sigma*tau*sigma for tau in s2_pre] + R2 = Resolvent(F2, X, s2) + R_dup, _, _ = R2.eval_for_poly(T) + d = dup_discriminant(R_dup, ZZ) + # If d is zero (R has a repeated root), must retry. + if d == 0: + continue + if is_square(d): + return (S4TransitiveSubgroups.C4, False) + else: + return (S4TransitiveSubgroups.D4, False) + + raise MaxTriesException + + +def _galois_group_degree_4_lookup(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 4. + + Explanation + =========== + + Based on Alg 6.3.6 of [1], but uses resolvent coeff lookup. + + """ + from sympy.combinatorics.galois import S4TransitiveSubgroups + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 0) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + # Compute list L of degrees of irreducible factors of R, in increasing order: + fl = dup_factor_list(R_dup, ZZ) + L = sorted(sum([ + [len(r) - 1] * e for r, e in fl[1] + ], [])) + + if L == [6]: + return ((S4TransitiveSubgroups.A4, True) if has_square_disc(T) + else (S4TransitiveSubgroups.S4, False)) + + if L == [1, 1, 4]: + return (S4TransitiveSubgroups.C4, False) + + if L == [2, 2, 2]: + return (S4TransitiveSubgroups.V, True) + + assert L == [2, 4] + return (S4TransitiveSubgroups.D4, False) + + +def _galois_group_degree_5_hybrid(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 5. + + Explanation + =========== + + Based on Alg 6.3.9 of [1], but uses a hybrid approach, combining resolvent + coeff lookup, with root approximation. + + """ + from sympy.combinatorics.galois import S5TransitiveSubgroups + from sympy.combinatorics.permutations import Permutation + + X5 = symbols("X0,X1,X2,X3,X4") + res = define_resolvents() + F51, _, s51 = res[(5, 1)] + F51 = F51.as_expr(*X5) + R51 = Resolvent(F51, X5, s51) + + history = set() + reached_second_stage = False + for i in range(max_tries): + if i > 0: + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + R51_dup = get_resolvent_by_lookup(T, 1) + if not dup_sqf_p(R51_dup, ZZ): + continue + + # First stage + # If we have not yet reached the second stage, then the group still + # might be S5, A5, or M20, so must test for that. + if not reached_second_stage: + sq_disc = has_square_disc(T) + + if dup_irreducible_p(R51_dup, ZZ): + return ((S5TransitiveSubgroups.A5, True) if sq_disc + else (S5TransitiveSubgroups.S5, False)) + + if not sq_disc: + return (S5TransitiveSubgroups.M20, False) + + # Second stage + reached_second_stage = True + # R51 must have an integer root for T. + # To choose our second resolvent, we need to know which conjugate of + # F51 is a root. + rounded_roots = R51.round_roots_to_integers_for_poly(T) + # These are integers, and candidates to be roots of R51. + # We find the first one that actually is a root. + for permutation_index, candidate_root in rounded_roots.items(): + if not dup_eval(R51_dup, candidate_root, ZZ): + break + + X = X5 + F2_pre = X[0]*X[1]**2 + X[1]*X[2]**2 + X[2]*X[3]**2 + X[3]*X[4]**2 + X[4]*X[0]**2 + s2_pre = [ + Permutation(4), + Permutation(4)(0, 1)(2, 4) + ] + + i0 = permutation_index + sigma = s51[i0] + F2 = F2_pre.subs(zip(X, sigma(X)), simultaneous=True) + s2 = [sigma*tau*sigma for tau in s2_pre] + R2 = Resolvent(F2, X, s2) + R_dup, _, _ = R2.eval_for_poly(T) + d = dup_discriminant(R_dup, ZZ) + + if d == 0: + continue + if is_square(d): + return (S5TransitiveSubgroups.C5, True) + else: + return (S5TransitiveSubgroups.D5, True) + + raise MaxTriesException + + +def _galois_group_degree_5_lookup_ext_factor(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 5. + + Explanation + =========== + + Based on Alg 6.3.9 of [1], but uses resolvent coeff lookup, plus + factorization over an algebraic extension. + + """ + from sympy.combinatorics.galois import S5TransitiveSubgroups + + _T = T + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 1) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + sq_disc = has_square_disc(T) + + if dup_irreducible_p(R_dup, ZZ): + return ((S5TransitiveSubgroups.A5, True) if sq_disc + else (S5TransitiveSubgroups.S5, False)) + + if not sq_disc: + return (S5TransitiveSubgroups.M20, False) + + # If we get this far, Gal(T) can only be D5 or C5. + # But for Gal(T) to have order 5, T must already split completely in + # the extension field obtained by adjoining a single one of its roots. + fl = Poly(_T, domain=ZZ.alg_field_from_poly(_T)).factor_list()[1] + if len(fl) == 5: + return (S5TransitiveSubgroups.C5, True) + else: + return (S5TransitiveSubgroups.D5, True) + + +def _galois_group_degree_6_lookup(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 6. + + Explanation + =========== + + Based on Alg 6.3.10 of [1], but uses resolvent coeff lookup. + + """ + from sympy.combinatorics.galois import S6TransitiveSubgroups + + # First resolvent: + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 1) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + fl = dup_factor_list(R_dup, ZZ) + + # Group the factors by degree. + factors_by_deg = defaultdict(list) + for r, _ in fl[1]: + factors_by_deg[len(r) - 1].append(r) + + L = sorted(sum([ + [d] * len(ff) for d, ff in factors_by_deg.items() + ], [])) + + T_has_sq_disc = has_square_disc(T) + + if L == [1, 2, 3]: + f1 = factors_by_deg[3][0] + return ((S6TransitiveSubgroups.C6, False) if has_square_disc(f1) + else (S6TransitiveSubgroups.D6, False)) + + elif L == [3, 3]: + f1, f2 = factors_by_deg[3] + any_square = has_square_disc(f1) or has_square_disc(f2) + return ((S6TransitiveSubgroups.G18, False) if any_square + else (S6TransitiveSubgroups.G36m, False)) + + elif L == [2, 4]: + if T_has_sq_disc: + return (S6TransitiveSubgroups.S4p, True) + else: + f1 = factors_by_deg[4][0] + return ((S6TransitiveSubgroups.A4xC2, False) if has_square_disc(f1) + else (S6TransitiveSubgroups.S4xC2, False)) + + elif L == [1, 1, 4]: + return ((S6TransitiveSubgroups.A4, True) if T_has_sq_disc + else (S6TransitiveSubgroups.S4m, False)) + + elif L == [1, 5]: + return ((S6TransitiveSubgroups.PSL2F5, True) if T_has_sq_disc + else (S6TransitiveSubgroups.PGL2F5, False)) + + elif L == [1, 1, 1, 3]: + return (S6TransitiveSubgroups.S3, False) + + assert L == [6] + + # Second resolvent: + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 2) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + T_has_sq_disc = has_square_disc(T) + + if dup_irreducible_p(R_dup, ZZ): + return ((S6TransitiveSubgroups.A6, True) if T_has_sq_disc + else (S6TransitiveSubgroups.S6, False)) + else: + return ((S6TransitiveSubgroups.G36p, True) if T_has_sq_disc + else (S6TransitiveSubgroups.G72, False)) + + +@public +def galois_group(f, *gens, by_name=False, max_tries=30, randomize=False, **args): + r""" + Compute the Galois group for polynomials *f* up to degree 6. + + Examples + ======== + + >>> from sympy import galois_group + >>> from sympy.abc import x + >>> f = x**4 + 1 + >>> G, alt = galois_group(f) + >>> print(G) + PermutationGroup([ + (0 1)(2 3), + (0 2)(1 3)]) + + The group is returned along with a boolean, indicating whether it is + contained in the alternating group $A_n$, where $n$ is the degree of *T*. + Along with other group properties, this can help determine which group it + is: + + >>> alt + True + >>> G.order() + 4 + + Alternatively, the group can be returned by name: + + >>> G_name, _ = galois_group(f, by_name=True) + >>> print(G_name) + S4TransitiveSubgroups.V + + The group itself can then be obtained by calling the name's + ``get_perm_group()`` method: + + >>> G_name.get_perm_group() + PermutationGroup([ + (0 1)(2 3), + (0 2)(1 3)]) + + Group names are values of the enum classes + :py:class:`sympy.combinatorics.galois.S1TransitiveSubgroups`, + :py:class:`sympy.combinatorics.galois.S2TransitiveSubgroups`, + etc. + + Parameters + ========== + + f : Expr + Irreducible polynomial over :ref:`ZZ` or :ref:`QQ`, whose Galois group + is to be determined. + gens : optional list of symbols + For converting *f* to Poly, and will be passed on to the + :py:func:`~.poly_from_expr` function. + by_name : bool, default False + If ``True``, the Galois group will be returned by name. + Otherwise it will be returned as a :py:class:`~.PermutationGroup`. + max_tries : int, default 30 + Make at most this many attempts in those steps that involve + generating Tschirnhausen transformations. + randomize : bool, default False + If ``True``, then use random coefficients when generating Tschirnhausen + transformations. Otherwise try transformations in a fixed order. Both + approaches start with small coefficients and degrees and work upward. + args : optional + For converting *f* to Poly, and will be passed on to the + :py:func:`~.poly_from_expr` function. + + Returns + ======= + + Pair ``(G, alt)`` + The first element ``G`` indicates the Galois group. It is an instance + of one of the :py:class:`sympy.combinatorics.galois.S1TransitiveSubgroups` + :py:class:`sympy.combinatorics.galois.S2TransitiveSubgroups`, etc. enum + classes if *by_name* was ``True``, and a :py:class:`~.PermutationGroup` + if ``False``. + + The second element is a boolean, saying whether the group is contained + in the alternating group $A_n$ ($n$ the degree of *T*). + + Raises + ====== + + ValueError + if *f* is of an unsupported degree. + + MaxTriesException + if could not complete before exceeding *max_tries* in those steps + that involve generating Tschirnhausen transformations. + + See Also + ======== + + .Poly.galois_group + + """ + gens = gens or [] + args = args or {} + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('galois_group', 1, exc) + + return F.galois_group(by_name=by_name, max_tries=max_tries, + randomize=randomize) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/minpoly.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/minpoly.py new file mode 100644 index 0000000000000000000000000000000000000000..e5f556e6f82a9790aa7c421fc14ac0fb637b7b49 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/minpoly.py @@ -0,0 +1,882 @@ +"""Minimal polynomials for algebraic numbers.""" + +from functools import reduce + +from sympy.core.add import Add +from sympy.core.exprtools import Factors +from sympy.core.function import expand_mul, expand_multinomial, _mexpand +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi, _illegal) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.core.traversal import preorder_traversal +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt, cbrt +from sympy.functions.elementary.trigonometric import cos, sin, tan +from sympy.ntheory.factor_ import divisors +from sympy.utilities.iterables import subsets + +from sympy.polys.domains import ZZ, QQ, FractionField +from sympy.polys.orthopolys import dup_chebyshevt +from sympy.polys.polyerrors import ( + NotAlgebraic, + GeneratorsError, +) +from sympy.polys.polytools import ( + Poly, PurePoly, invert, factor_list, groebner, resultant, + degree, poly_from_expr, parallel_poly_from_expr, lcm +) +from sympy.polys.polyutils import dict_from_expr, expr_from_dict +from sympy.polys.ring_series import rs_compose_add +from sympy.polys.rings import ring +from sympy.polys.rootoftools import CRootOf +from sympy.polys.specialpolys import cyclotomic_poly +from sympy.utilities import ( + numbered_symbols, public, sift +) + + +def _choose_factor(factors, x, v, dom=QQ, prec=200, bound=5): + """ + Return a factor having root ``v`` + It is assumed that one of the factors has root ``v``. + """ + + if isinstance(factors[0], tuple): + factors = [f[0] for f in factors] + if len(factors) == 1: + return factors[0] + + prec1 = 10 + points = {} + symbols = dom.symbols if hasattr(dom, 'symbols') else [] + while prec1 <= prec: + # when dealing with non-Rational numbers we usually evaluate + # with `subs` argument but we only need a ballpark evaluation + fe = [f.as_expr().xreplace({x:v}) for f in factors] + if v.is_number: + fe = [f.n(prec) for f in fe] + + # assign integers [0, n) to symbols (if any) + for n in subsets(range(bound), k=len(symbols), repetition=True): + for s, i in zip(symbols, n): + points[s] = i + + # evaluate the expression at these points + candidates = [(abs(f.subs(points).n(prec1)), i) + for i,f in enumerate(fe)] + + # if we get invalid numbers (e.g. from division by zero) + # we try again + if any(i in _illegal for i, _ in candidates): + continue + + # find the smallest two -- if they differ significantly + # then we assume we have found the factor that becomes + # 0 when v is substituted into it + can = sorted(candidates) + (a, ix), (b, _) = can[:2] + if b > a * 10**6: # XXX what to use? + return factors[ix] + + prec1 *= 2 + + raise NotImplementedError("multiple candidates for the minimal polynomial of %s" % v) + + +def _is_sum_surds(p): + return all(f.is_Rational or f.is_Pow and + f.base.is_Rational and (2*f.exp).is_Integer and f.is_extended_real + for t in Add.make_args(p) for f in Mul.make_args(t)) + + +def _separate_sq(p): + """ + helper function for ``_minimal_polynomial_sq`` + + It selects a rational ``g`` such that the polynomial ``p`` + consists of a sum of terms whose surds squared have gcd equal to ``g`` + and a sum of terms with surds squared prime with ``g``; + then it takes the field norm to eliminate ``sqrt(g)`` + + See simplify.simplify.split_surds and polytools.sqf_norm. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.abc import x + >>> from sympy.polys.numberfields.minpoly import _separate_sq + >>> p= -x + sqrt(2) + sqrt(3) + sqrt(7) + >>> p = _separate_sq(p); p + -x**2 + 2*sqrt(3)*x + 2*sqrt(7)*x - 2*sqrt(21) - 8 + >>> p = _separate_sq(p); p + -x**4 + 4*sqrt(7)*x**3 - 32*x**2 + 8*sqrt(7)*x + 20 + >>> p = _separate_sq(p); p + -x**8 + 48*x**6 - 536*x**4 + 1728*x**2 - 400 + + """ + def is_sqrt(expr): + return expr.is_Pow and expr.exp is S.Half + # p = c1*sqrt(q1) + ... + cn*sqrt(qn) -> a = [(c1, q1), .., (cn, qn)] + a = [] + for y in p.args: + if not y.is_Mul: + if is_sqrt(y): + a.append((S.One, y**2)) + elif y.is_Atom: + a.append((y, S.One)) + elif y.is_Pow and y.exp.is_integer: + a.append((y, S.One)) + else: + raise NotImplementedError + else: + T, F = sift(y.args, is_sqrt, binary=True) + a.append((Mul(*F), Mul(*T)**2)) + a.sort(key=lambda z: z[1]) + if a[-1][1] is S.One: + # there are no surds + return p + surds = [z for y, z in a] + for i in range(len(surds)): + if surds[i] != 1: + break + from sympy.simplify.radsimp import _split_gcd + g, b1, b2 = _split_gcd(*surds[i:]) + a1 = [] + a2 = [] + for y, z in a: + if z in b1: + a1.append(y*z**S.Half) + else: + a2.append(y*z**S.Half) + p1 = Add(*a1) + p2 = Add(*a2) + p = _mexpand(p1**2) - _mexpand(p2**2) + return p + +def _minimal_polynomial_sq(p, n, x): + """ + Returns the minimal polynomial for the ``nth-root`` of a sum of surds + or ``None`` if it fails. + + Parameters + ========== + + p : sum of surds + n : positive integer + x : variable of the returned polynomial + + Examples + ======== + + >>> from sympy.polys.numberfields.minpoly import _minimal_polynomial_sq + >>> from sympy import sqrt + >>> from sympy.abc import x + >>> q = 1 + sqrt(2) + sqrt(3) + >>> _minimal_polynomial_sq(q, 3, x) + x**12 - 4*x**9 - 4*x**6 + 16*x**3 - 8 + + """ + p = sympify(p) + n = sympify(n) + if not n.is_Integer or not n > 0 or not _is_sum_surds(p): + return None + pn = p**Rational(1, n) + # eliminate the square roots + p -= x + while 1: + p1 = _separate_sq(p) + if p1 is p: + p = p1.subs({x:x**n}) + break + else: + p = p1 + + # _separate_sq eliminates field extensions in a minimal way, so that + # if n = 1 then `p = constant*(minimal_polynomial(p))` + # if n > 1 it contains the minimal polynomial as a factor. + if n == 1: + p1 = Poly(p) + if p.coeff(x**p1.degree(x)) < 0: + p = -p + p = p.primitive()[1] + return p + # by construction `p` has root `pn` + # the minimal polynomial is the factor vanishing in x = pn + factors = factor_list(p)[1] + + result = _choose_factor(factors, x, pn) + return result + +def _minpoly_op_algebraic_element(op, ex1, ex2, x, dom, mp1=None, mp2=None): + """ + return the minimal polynomial for ``op(ex1, ex2)`` + + Parameters + ========== + + op : operation ``Add`` or ``Mul`` + ex1, ex2 : expressions for the algebraic elements + x : indeterminate of the polynomials + dom: ground domain + mp1, mp2 : minimal polynomials for ``ex1`` and ``ex2`` or None + + Examples + ======== + + >>> from sympy import sqrt, Add, Mul, QQ + >>> from sympy.polys.numberfields.minpoly import _minpoly_op_algebraic_element + >>> from sympy.abc import x, y + >>> p1 = sqrt(sqrt(2) + 1) + >>> p2 = sqrt(sqrt(2) - 1) + >>> _minpoly_op_algebraic_element(Mul, p1, p2, x, QQ) + x - 1 + >>> q1 = sqrt(y) + >>> q2 = 1 / y + >>> _minpoly_op_algebraic_element(Add, q1, q2, x, QQ.frac_field(y)) + x**2*y**2 - 2*x*y - y**3 + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Resultant + .. [2] I.M. Isaacs, Proc. Amer. Math. Soc. 25 (1970), 638 + "Degrees of sums in a separable field extension". + + """ + y = Dummy(str(x)) + if mp1 is None: + mp1 = _minpoly_compose(ex1, x, dom) + if mp2 is None: + mp2 = _minpoly_compose(ex2, y, dom) + else: + mp2 = mp2.subs({x: y}) + + if op is Add: + # mp1a = mp1.subs({x: x - y}) + if dom == QQ: + R, X = ring('X', QQ) + p1 = R(dict_from_expr(mp1)[0]) + p2 = R(dict_from_expr(mp2)[0]) + else: + (p1, p2), _ = parallel_poly_from_expr((mp1, x - y), x, y) + r = p1.compose(p2) + mp1a = r.as_expr() + + elif op is Mul: + mp1a = _muly(mp1, x, y) + else: + raise NotImplementedError('option not available') + + if op is Mul or dom != QQ: + r = resultant(mp1a, mp2, gens=[y, x]) + else: + r = rs_compose_add(p1, p2) + r = expr_from_dict(r.as_expr_dict(), x) + + deg1 = degree(mp1, x) + deg2 = degree(mp2, y) + if op is Mul and deg1 == 1 or deg2 == 1: + # if deg1 = 1, then mp1 = x - a; mp1a = x - y - a; + # r = mp2(x - a), so that `r` is irreducible + return r + + r = Poly(r, x, domain=dom) + _, factors = r.factor_list() + res = _choose_factor(factors, x, op(ex1, ex2), dom) + return res.as_expr() + + +def _invertx(p, x): + """ + Returns ``expand_mul(x**degree(p, x)*p.subs(x, 1/x))`` + """ + p1 = poly_from_expr(p, x)[0] + + n = degree(p1) + a = [c * x**(n - i) for (i,), c in p1.terms()] + return Add(*a) + + +def _muly(p, x, y): + """ + Returns ``_mexpand(y**deg*p.subs({x:x / y}))`` + """ + p1 = poly_from_expr(p, x)[0] + + n = degree(p1) + a = [c * x**i * y**(n - i) for (i,), c in p1.terms()] + return Add(*a) + + +def _minpoly_pow(ex, pw, x, dom, mp=None): + """ + Returns ``minpoly(ex**pw, x)`` + + Parameters + ========== + + ex : algebraic element + pw : rational number + x : indeterminate of the polynomial + dom: ground domain + mp : minimal polynomial of ``p`` + + Examples + ======== + + >>> from sympy import sqrt, QQ, Rational + >>> from sympy.polys.numberfields.minpoly import _minpoly_pow, minpoly + >>> from sympy.abc import x, y + >>> p = sqrt(1 + sqrt(2)) + >>> _minpoly_pow(p, 2, x, QQ) + x**2 - 2*x - 1 + >>> minpoly(p**2, x) + x**2 - 2*x - 1 + >>> _minpoly_pow(y, Rational(1, 3), x, QQ.frac_field(y)) + x**3 - y + >>> minpoly(y**Rational(1, 3), x) + x**3 - y + + """ + pw = sympify(pw) + if not mp: + mp = _minpoly_compose(ex, x, dom) + if not pw.is_rational: + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + if pw < 0: + if mp == x: + raise ZeroDivisionError('%s is zero' % ex) + mp = _invertx(mp, x) + if pw == -1: + return mp + pw = -pw + ex = 1/ex + + y = Dummy(str(x)) + mp = mp.subs({x: y}) + n, d = pw.as_numer_denom() + res = Poly(resultant(mp, x**d - y**n, gens=[y]), x, domain=dom) + _, factors = res.factor_list() + res = _choose_factor(factors, x, ex**pw, dom) + return res.as_expr() + + +def _minpoly_add(x, dom, *a): + """ + returns ``minpoly(Add(*a), dom, x)`` + """ + mp = _minpoly_op_algebraic_element(Add, a[0], a[1], x, dom) + p = a[0] + a[1] + for px in a[2:]: + mp = _minpoly_op_algebraic_element(Add, p, px, x, dom, mp1=mp) + p = p + px + return mp + + +def _minpoly_mul(x, dom, *a): + """ + returns ``minpoly(Mul(*a), dom, x)`` + """ + mp = _minpoly_op_algebraic_element(Mul, a[0], a[1], x, dom) + p = a[0] * a[1] + for px in a[2:]: + mp = _minpoly_op_algebraic_element(Mul, p, px, x, dom, mp1=mp) + p = p * px + return mp + + +def _minpoly_sin(ex, x): + """ + Returns the minimal polynomial of ``sin(ex)`` + see https://mathworld.wolfram.com/TrigonometryAngles.html + """ + c, a = ex.args[0].as_coeff_Mul() + if a is pi: + if c.is_rational: + n = c.q + q = sympify(n) + if q.is_prime: + # for a = pi*p/q with q odd prime, using chebyshevt + # write sin(q*a) = mp(sin(a))*sin(a); + # the roots of mp(x) are sin(pi*p/q) for p = 1,..., q - 1 + a = dup_chebyshevt(n, ZZ) + return Add(*[x**(n - i - 1)*a[i] for i in range(n)]) + if c.p == 1: + if q == 9: + return 64*x**6 - 96*x**4 + 36*x**2 - 3 + + if n % 2 == 1: + # for a = pi*p/q with q odd, use + # sin(q*a) = 0 to see that the minimal polynomial must be + # a factor of dup_chebyshevt(n, ZZ) + a = dup_chebyshevt(n, ZZ) + a = [x**(n - i)*a[i] for i in range(n + 1)] + r = Add(*a) + _, factors = factor_list(r) + res = _choose_factor(factors, x, ex) + return res + + expr = ((1 - cos(2*c*pi))/2)**S.Half + res = _minpoly_compose(expr, x, QQ) + return res + + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_cos(ex, x): + """ + Returns the minimal polynomial of ``cos(ex)`` + see https://mathworld.wolfram.com/TrigonometryAngles.html + """ + c, a = ex.args[0].as_coeff_Mul() + if a is pi: + if c.is_rational: + if c.p == 1: + if c.q == 7: + return 8*x**3 - 4*x**2 - 4*x + 1 + if c.q == 9: + return 8*x**3 - 6*x - 1 + elif c.p == 2: + q = sympify(c.q) + if q.is_prime: + s = _minpoly_sin(ex, x) + return _mexpand(s.subs({x:sqrt((1 - x)/2)})) + + # for a = pi*p/q, cos(q*a) =T_q(cos(a)) = (-1)**p + n = int(c.q) + a = dup_chebyshevt(n, ZZ) + a = [x**(n - i)*a[i] for i in range(n + 1)] + r = Add(*a) - (-1)**c.p + _, factors = factor_list(r) + res = _choose_factor(factors, x, ex) + return res + + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_tan(ex, x): + """ + Returns the minimal polynomial of ``tan(ex)`` + see https://github.com/sympy/sympy/issues/21430 + """ + c, a = ex.args[0].as_coeff_Mul() + if a is pi: + if c.is_rational: + c = c * 2 + n = int(c.q) + a = n if c.p % 2 == 0 else 1 + terms = [] + for k in range((c.p+1)%2, n+1, 2): + terms.append(a*x**k) + a = -(a*(n-k-1)*(n-k)) // ((k+1)*(k+2)) + + r = Add(*terms) + _, factors = factor_list(r) + res = _choose_factor(factors, x, ex) + return res + + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_exp(ex, x): + """ + Returns the minimal polynomial of ``exp(ex)`` + """ + c, a = ex.args[0].as_coeff_Mul() + if a == I*pi: + if c.is_rational: + q = sympify(c.q) + if c.p == 1 or c.p == -1: + if q == 3: + return x**2 - x + 1 + if q == 4: + return x**4 + 1 + if q == 6: + return x**4 - x**2 + 1 + if q == 8: + return x**8 + 1 + if q == 9: + return x**6 - x**3 + 1 + if q == 10: + return x**8 - x**6 + x**4 - x**2 + 1 + if q.is_prime: + s = 0 + for i in range(q): + s += (-x)**i + return s + + # x**(2*q) = product(factors) + factors = [cyclotomic_poly(i, x) for i in divisors(2*q)] + mp = _choose_factor(factors, x, ex) + return mp + else: + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_rootof(ex, x): + """ + Returns the minimal polynomial of a ``CRootOf`` object. + """ + p = ex.expr + p = p.subs({ex.poly.gens[0]:x}) + _, factors = factor_list(p, x) + result = _choose_factor(factors, x, ex) + return result + + +def _minpoly_compose(ex, x, dom): + """ + Computes the minimal polynomial of an algebraic element + using operations on minimal polynomials + + Examples + ======== + + >>> from sympy import minimal_polynomial, sqrt, Rational + >>> from sympy.abc import x, y + >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True) + x**2 - 2*x - 1 + >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True) + x**2*y**2 - 2*x*y - y**3 + 1 + + """ + if ex.is_Rational: + return ex.q*x - ex.p + if ex is I: + _, factors = factor_list(x**2 + 1, x, domain=dom) + return x**2 + 1 if len(factors) == 1 else x - I + + if ex is S.GoldenRatio: + _, factors = factor_list(x**2 - x - 1, x, domain=dom) + if len(factors) == 1: + return x**2 - x - 1 + else: + return _choose_factor(factors, x, (1 + sqrt(5))/2, dom=dom) + + if ex is S.TribonacciConstant: + _, factors = factor_list(x**3 - x**2 - x - 1, x, domain=dom) + if len(factors) == 1: + return x**3 - x**2 - x - 1 + else: + fac = (1 + cbrt(19 - 3*sqrt(33)) + cbrt(19 + 3*sqrt(33))) / 3 + return _choose_factor(factors, x, fac, dom=dom) + + if hasattr(dom, 'symbols') and ex in dom.symbols: + return x - ex + + if dom.is_QQ and _is_sum_surds(ex): + # eliminate the square roots + v = ex + ex -= x + while 1: + ex1 = _separate_sq(ex) + if ex1 is ex: + return _choose_factor(factor_list(ex)[1], x, v) + else: + ex = ex1 + + if ex.is_Add: + res = _minpoly_add(x, dom, *ex.args) + elif ex.is_Mul: + f = Factors(ex).factors + r = sift(f.items(), lambda itx: itx[0].is_Rational and itx[1].is_Rational) + if r[True] and dom == QQ: + ex1 = Mul(*[bx**ex for bx, ex in r[False] + r[None]]) + r1 = dict(r[True]) + dens = [y.q for y in r1.values()] + lcmdens = reduce(lcm, dens, 1) + neg1 = S.NegativeOne + expn1 = r1.pop(neg1, S.Zero) + nums = [base**(y.p*lcmdens // y.q) for base, y in r1.items()] + ex2 = Mul(*nums) + mp1 = minimal_polynomial(ex1, x) + # use the fact that in SymPy canonicalization products of integers + # raised to rational powers are organized in relatively prime + # bases, and that in ``base**(n/d)`` a perfect power is + # simplified with the root + # Powers of -1 have to be treated separately to preserve sign. + mp2 = ex2.q*x**lcmdens - ex2.p*neg1**(expn1*lcmdens) + ex2 = neg1**expn1 * ex2**Rational(1, lcmdens) + res = _minpoly_op_algebraic_element(Mul, ex1, ex2, x, dom, mp1=mp1, mp2=mp2) + else: + res = _minpoly_mul(x, dom, *ex.args) + elif ex.is_Pow: + res = _minpoly_pow(ex.base, ex.exp, x, dom) + elif ex.__class__ is sin: + res = _minpoly_sin(ex, x) + elif ex.__class__ is cos: + res = _minpoly_cos(ex, x) + elif ex.__class__ is tan: + res = _minpoly_tan(ex, x) + elif ex.__class__ is exp: + res = _minpoly_exp(ex, x) + elif ex.__class__ is CRootOf: + res = _minpoly_rootof(ex, x) + else: + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + return res + + +@public +def minimal_polynomial(ex, x=None, compose=True, polys=False, domain=None): + """ + Computes the minimal polynomial of an algebraic element. + + Parameters + ========== + + ex : Expr + Element or expression whose minimal polynomial is to be calculated. + + x : Symbol, optional + Independent variable of the minimal polynomial + + compose : boolean, optional (default=True) + Method to use for computing minimal polynomial. If ``compose=True`` + (default) then ``_minpoly_compose`` is used, if ``compose=False`` then + groebner bases are used. + + polys : boolean, optional (default=False) + If ``True`` returns a ``Poly`` object else an ``Expr`` object. + + domain : Domain, optional + Ground domain + + Notes + ===== + + By default ``compose=True``, the minimal polynomial of the subexpressions of ``ex`` + are computed, then the arithmetic operations on them are performed using the resultant + and factorization. + If ``compose=False``, a bottom-up algorithm is used with ``groebner``. + The default algorithm stalls less frequently. + + If no ground domain is given, it will be generated automatically from the expression. + + Examples + ======== + + >>> from sympy import minimal_polynomial, sqrt, solve, QQ + >>> from sympy.abc import x, y + + >>> minimal_polynomial(sqrt(2), x) + x**2 - 2 + >>> minimal_polynomial(sqrt(2), x, domain=QQ.algebraic_field(sqrt(2))) + x - sqrt(2) + >>> minimal_polynomial(sqrt(2) + sqrt(3), x) + x**4 - 10*x**2 + 1 + >>> minimal_polynomial(solve(x**3 + x + 3)[0], x) + x**3 + x + 3 + >>> minimal_polynomial(sqrt(y), x) + x**2 - y + + """ + + ex = sympify(ex) + if ex.is_number: + # not sure if it's always needed but try it for numbers (issue 8354) + ex = _mexpand(ex, recursive=True) + for expr in preorder_traversal(ex): + if expr.is_AlgebraicNumber: + compose = False + break + + if x is not None: + x, cls = sympify(x), Poly + else: + x, cls = Dummy('x'), PurePoly + + if not domain: + if ex.free_symbols: + domain = FractionField(QQ, list(ex.free_symbols)) + else: + domain = QQ + if hasattr(domain, 'symbols') and x in domain.symbols: + raise GeneratorsError("the variable %s is an element of the ground " + "domain %s" % (x, domain)) + + if compose: + result = _minpoly_compose(ex, x, domain) + result = result.primitive()[1] + c = result.coeff(x**degree(result, x)) + if c.is_negative: + result = expand_mul(-result) + return cls(result, x, field=True) if polys else result.collect(x) + + if not domain.is_QQ: + raise NotImplementedError("groebner method only works for QQ") + + result = _minpoly_groebner(ex, x, cls) + return cls(result, x, field=True) if polys else result.collect(x) + + +def _minpoly_groebner(ex, x, cls): + """ + Computes the minimal polynomial of an algebraic number + using Groebner bases + + Examples + ======== + + >>> from sympy import minimal_polynomial, sqrt, Rational + >>> from sympy.abc import x + >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=False) + x**2 - 2*x - 1 + + """ + + generator = numbered_symbols('a', cls=Dummy) + mapping, symbols = {}, {} + + def update_mapping(ex, exp, base=None): + a = next(generator) + symbols[ex] = a + + if base is not None: + mapping[ex] = a**exp + base + else: + mapping[ex] = exp.as_expr(a) + + return a + + def bottom_up_scan(ex): + """ + Transform a given algebraic expression *ex* into a multivariate + polynomial, by introducing fresh variables with defining equations. + + Explanation + =========== + + The critical elements of the algebraic expression *ex* are root + extractions, instances of :py:class:`~.AlgebraicNumber`, and negative + powers. + + When we encounter a root extraction or an :py:class:`~.AlgebraicNumber` + we replace this expression with a fresh variable ``a_i``, and record + the defining polynomial for ``a_i``. For example, if ``a_0**(1/3)`` + occurs, we will replace it with ``a_1``, and record the new defining + polynomial ``a_1**3 - a_0``. + + When we encounter a negative power we transform it into a positive + power by algebraically inverting the base. This means computing the + minimal polynomial in ``x`` for the base, inverting ``x`` modulo this + poly (which generates a new polynomial) and then substituting the + original base expression for ``x`` in this last polynomial. + + We return the transformed expression, and we record the defining + equations for new symbols using the ``update_mapping()`` function. + + """ + if ex.is_Atom: + if ex is S.ImaginaryUnit: + if ex not in mapping: + return update_mapping(ex, 2, 1) + else: + return symbols[ex] + elif ex.is_Rational: + return ex + elif ex.is_Add: + return Add(*[ bottom_up_scan(g) for g in ex.args ]) + elif ex.is_Mul: + return Mul(*[ bottom_up_scan(g) for g in ex.args ]) + elif ex.is_Pow: + if ex.exp.is_Rational: + if ex.exp < 0: + minpoly_base = _minpoly_groebner(ex.base, x, cls) + inverse = invert(x, minpoly_base).as_expr() + base_inv = inverse.subs(x, ex.base).expand() + + if ex.exp == -1: + return bottom_up_scan(base_inv) + else: + ex = base_inv**(-ex.exp) + if not ex.exp.is_Integer: + base, exp = ( + ex.base**ex.exp.p).expand(), Rational(1, ex.exp.q) + else: + base, exp = ex.base, ex.exp + base = bottom_up_scan(base) + expr = base**exp + + if expr not in mapping: + if exp.is_Integer: + return expr.expand() + else: + return update_mapping(expr, 1 / exp, -base) + else: + return symbols[expr] + elif ex.is_AlgebraicNumber: + if ex not in mapping: + return update_mapping(ex, ex.minpoly_of_element()) + else: + return symbols[ex] + + raise NotAlgebraic("%s does not seem to be an algebraic number" % ex) + + def simpler_inverse(ex): + """ + Returns True if it is more likely that the minimal polynomial + algorithm works better with the inverse + """ + if ex.is_Pow: + if (1/ex.exp).is_integer and ex.exp < 0: + if ex.base.is_Add: + return True + if ex.is_Mul: + hit = True + for p in ex.args: + if p.is_Add: + return False + if p.is_Pow: + if p.base.is_Add and p.exp > 0: + return False + + if hit: + return True + return False + + inverted = False + ex = expand_multinomial(ex) + if ex.is_AlgebraicNumber: + return ex.minpoly_of_element().as_expr(x) + elif ex.is_Rational: + result = ex.q*x - ex.p + else: + inverted = simpler_inverse(ex) + if inverted: + ex = ex**-1 + res = None + if ex.is_Pow and (1/ex.exp).is_Integer: + n = 1/ex.exp + res = _minimal_polynomial_sq(ex.base, n, x) + + elif _is_sum_surds(ex): + res = _minimal_polynomial_sq(ex, S.One, x) + + if res is not None: + result = res + + if res is None: + bus = bottom_up_scan(ex) + F = [x - bus] + list(mapping.values()) + G = groebner(F, list(symbols.values()) + [x], order='lex') + + _, factors = factor_list(G[-1]) + # by construction G[-1] has root `ex` + result = _choose_factor(factors, x, ex) + if inverted: + result = _invertx(result, x) + if result.coeff(x**degree(result, x)) < 0: + result = expand_mul(-result) + + return result + + +@public +def minpoly(ex, x=None, compose=True, polys=False, domain=None): + """This is a synonym for :py:func:`~.minimal_polynomial`.""" + return minimal_polynomial(ex, x=x, compose=compose, polys=polys, domain=domain) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/modules.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..af2e29bcc9cf73d97def0701712f90db58601b86 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/modules.py @@ -0,0 +1,2114 @@ +r"""Modules in number fields. + +The classes defined here allow us to work with finitely generated, free +modules, whose generators are algebraic numbers. + +There is an abstract base class called :py:class:`~.Module`, which has two +concrete subclasses, :py:class:`~.PowerBasis` and :py:class:`~.Submodule`. + +Every module is defined by its basis, or set of generators: + +* For a :py:class:`~.PowerBasis`, the generators are the first $n$ powers + (starting with the zeroth) of an algebraic integer $\theta$ of degree $n$. + The :py:class:`~.PowerBasis` is constructed by passing either the minimal + polynomial of $\theta$, or an :py:class:`~.AlgebraicField` having $\theta$ + as its primitive element. + +* For a :py:class:`~.Submodule`, the generators are a set of + $\mathbb{Q}$-linear combinations of the generators of another module. That + other module is then the "parent" of the :py:class:`~.Submodule`. The + coefficients of the $\mathbb{Q}$-linear combinations may be given by an + integer matrix, and a positive integer denominator. Each column of the matrix + defines a generator. + +>>> from sympy.polys import Poly, cyclotomic_poly, ZZ +>>> from sympy.abc import x +>>> from sympy.polys.matrices import DomainMatrix, DM +>>> from sympy.polys.numberfields.modules import PowerBasis +>>> T = Poly(cyclotomic_poly(5, x)) +>>> A = PowerBasis(T) +>>> print(A) +PowerBasis(x**4 + x**3 + x**2 + x + 1) +>>> B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ), denom=3) +>>> print(B) +Submodule[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]]/3 +>>> print(B.parent) +PowerBasis(x**4 + x**3 + x**2 + x + 1) + +Thus, every module is either a :py:class:`~.PowerBasis`, +or a :py:class:`~.Submodule`, some ancestor of which is a +:py:class:`~.PowerBasis`. (If ``S`` is a :py:class:`~.Submodule`, then its +ancestors are ``S.parent``, ``S.parent.parent``, and so on). + +The :py:class:`~.ModuleElement` class represents a linear combination of the +generators of any module. Critically, the coefficients of this linear +combination are not restricted to be integers, but may be any rational +numbers. This is necessary so that any and all algebraic integers be +representable, starting from the power basis in a primitive element $\theta$ +for the number field in question. For example, in a quadratic field +$\mathbb{Q}(\sqrt{d})$ where $d \equiv 1 \mod{4}$, a denominator of $2$ is +needed. + +A :py:class:`~.ModuleElement` can be constructed from an integer column vector +and a denominator: + +>>> U = Poly(x**2 - 5) +>>> M = PowerBasis(U) +>>> e = M(DM([[1], [1]], ZZ), denom=2) +>>> print(e) +[1, 1]/2 +>>> print(e.module) +PowerBasis(x**2 - 5) + +The :py:class:`~.PowerBasisElement` class is a subclass of +:py:class:`~.ModuleElement` that represents elements of a +:py:class:`~.PowerBasis`, and adds functionality pertinent to elements +represented directly over powers of the primitive element $\theta$. + + +Arithmetic with module elements +=============================== + +While a :py:class:`~.ModuleElement` represents a linear combination over the +generators of a particular module, recall that every module is either a +:py:class:`~.PowerBasis` or a descendant (along a chain of +:py:class:`~.Submodule` objects) thereof, so that in fact every +:py:class:`~.ModuleElement` represents an algebraic number in some field +$\mathbb{Q}(\theta)$, where $\theta$ is the defining element of some +:py:class:`~.PowerBasis`. It thus makes sense to talk about the number field +to which a given :py:class:`~.ModuleElement` belongs. + +This means that any two :py:class:`~.ModuleElement` instances can be added, +subtracted, multiplied, or divided, provided they belong to the same number +field. Similarly, since $\mathbb{Q}$ is a subfield of every number field, +any :py:class:`~.ModuleElement` may be added, multiplied, etc. by any +rational number. + +>>> from sympy import QQ +>>> from sympy.polys.numberfields.modules import to_col +>>> T = Poly(cyclotomic_poly(5)) +>>> A = PowerBasis(T) +>>> C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) +>>> e = A(to_col([0, 2, 0, 0]), denom=3) +>>> f = A(to_col([0, 0, 0, 7]), denom=5) +>>> g = C(to_col([1, 1, 1, 1])) +>>> e + f +[0, 10, 0, 21]/15 +>>> e - f +[0, 10, 0, -21]/15 +>>> e - g +[-9, -7, -9, -9]/3 +>>> e + QQ(7, 10) +[21, 20, 0, 0]/30 +>>> e * f +[-14, -14, -14, -14]/15 +>>> e ** 2 +[0, 0, 4, 0]/9 +>>> f // g +[7, 7, 7, 7]/15 +>>> f * QQ(2, 3) +[0, 0, 0, 14]/15 + +However, care must be taken with arithmetic operations on +:py:class:`~.ModuleElement`, because the module $C$ to which the result will +belong will be the nearest common ancestor (NCA) of the modules $A$, $B$ to +which the two operands belong, and $C$ may be different from either or both +of $A$ and $B$. + +>>> A = PowerBasis(T) +>>> B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) +>>> C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) +>>> print((B(0) * C(0)).module == A) +True + +Before the arithmetic operation is performed, copies of the two operands are +automatically converted into elements of the NCA (the operands themselves are +not modified). This upward conversion along an ancestor chain is easy: it just +requires the successive multiplication by the defining matrix of each +:py:class:`~.Submodule`. + +Conversely, downward conversion, i.e. representing a given +:py:class:`~.ModuleElement` in a submodule, is also supported -- namely by +the :py:meth:`~sympy.polys.numberfields.modules.Submodule.represent` method +-- but is not guaranteed to succeed in general, since the given element may +not belong to the submodule. The main circumstance in which this issue tends +to arise is with multiplication, since modules, while closed under addition, +need not be closed under multiplication. + + +Multiplication +-------------- + +Generally speaking, a module need not be closed under multiplication, i.e. need +not form a ring. However, many of the modules we work with in the context of +number fields are in fact rings, and our classes do support multiplication. + +Specifically, any :py:class:`~.Module` can attempt to compute its own +multiplication table, but this does not happen unless an attempt is made to +multiply two :py:class:`~.ModuleElement` instances belonging to it. + +>>> A = PowerBasis(T) +>>> print(A._mult_tab is None) +True +>>> a = A(0)*A(1) +>>> print(A._mult_tab is None) +False + +Every :py:class:`~.PowerBasis` is, by its nature, closed under multiplication, +so instances of :py:class:`~.PowerBasis` can always successfully compute their +multiplication table. + +When a :py:class:`~.Submodule` attempts to compute its multiplication table, +it converts each of its own generators into elements of its parent module, +multiplies them there, in every possible pairing, and then tries to +represent the results in itself, i.e. as $\mathbb{Z}$-linear combinations +over its own generators. This will succeed if and only if the submodule is +in fact closed under multiplication. + + +Module Homomorphisms +==================== + +Many important number theoretic algorithms require the calculation of the +kernel of one or more module homomorphisms. Accordingly we have several +lightweight classes, :py:class:`~.ModuleHomomorphism`, +:py:class:`~.ModuleEndomorphism`, :py:class:`~.InnerEndomorphism`, and +:py:class:`~.EndomorphismRing`, which provide the minimal necessary machinery +to support this. + +""" + +from sympy.core.intfunc import igcd, ilcm +from sympy.core.symbol import Dummy +from sympy.polys.polyclasses import ANP +from sympy.polys.polytools import Poly +from sympy.polys.densetools import dup_clear_denoms +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.integerring import ZZ +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.matrices.exceptions import DMBadInputError +from sympy.polys.matrices.normalforms import hermite_normal_form +from sympy.polys.polyerrors import CoercionFailed, UnificationFailed +from sympy.polys.polyutils import IntegerPowerable +from .exceptions import ClosureFailure, MissingUnityError, StructureError +from .utilities import AlgIntPowers, is_rat, get_num_denom + + +def to_col(coeffs): + r"""Transform a list of integer coefficients into a column vector.""" + return DomainMatrix([[ZZ(c) for c in coeffs]], (1, len(coeffs)), ZZ).transpose() + + +class Module: + """ + Generic finitely-generated module. + + This is an abstract base class, and should not be instantiated directly. + The two concrete subclasses are :py:class:`~.PowerBasis` and + :py:class:`~.Submodule`. + + Every :py:class:`~.Submodule` is derived from another module, referenced + by its ``parent`` attribute. If ``S`` is a submodule, then we refer to + ``S.parent``, ``S.parent.parent``, and so on, as the "ancestors" of + ``S``. Thus, every :py:class:`~.Module` is either a + :py:class:`~.PowerBasis` or a :py:class:`~.Submodule`, some ancestor of + which is a :py:class:`~.PowerBasis`. + """ + + @property + def n(self): + """The number of generators of this module.""" + raise NotImplementedError + + def mult_tab(self): + """ + Get the multiplication table for this module (if closed under mult). + + Explanation + =========== + + Computes a dictionary ``M`` of dictionaries of lists, representing the + upper triangular half of the multiplication table. + + In other words, if ``0 <= i <= j < self.n``, then ``M[i][j]`` is the + list ``c`` of coefficients such that + ``g[i] * g[j] == sum(c[k]*g[k], k in range(self.n))``, + where ``g`` is the list of generators of this module. + + If ``j < i`` then ``M[i][j]`` is undefined. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> print(A.mult_tab()) # doctest: +SKIP + {0: {0: [1, 0, 0, 0], 1: [0, 1, 0, 0], 2: [0, 0, 1, 0], 3: [0, 0, 0, 1]}, + 1: {1: [0, 0, 1, 0], 2: [0, 0, 0, 1], 3: [-1, -1, -1, -1]}, + 2: {2: [-1, -1, -1, -1], 3: [1, 0, 0, 0]}, + 3: {3: [0, 1, 0, 0]}} + + Returns + ======= + + dict of dict of lists + + Raises + ====== + + ClosureFailure + If the module is not closed under multiplication. + + """ + raise NotImplementedError + + @property + def parent(self): + """ + The parent module, if any, for this module. + + Explanation + =========== + + For a :py:class:`~.Submodule` this is its ``parent`` attribute; for a + :py:class:`~.PowerBasis` this is ``None``. + + Returns + ======= + + :py:class:`~.Module`, ``None`` + + See Also + ======== + + Module + + """ + return None + + def represent(self, elt): + r""" + Represent a module element as an integer-linear combination over the + generators of this module. + + Explanation + =========== + + In our system, to "represent" always means to write a + :py:class:`~.ModuleElement` as a :ref:`ZZ`-linear combination over the + generators of the present :py:class:`~.Module`. Furthermore, the + incoming :py:class:`~.ModuleElement` must belong to an ancestor of + the present :py:class:`~.Module` (or to the present + :py:class:`~.Module` itself). + + The most common application is to represent a + :py:class:`~.ModuleElement` in a :py:class:`~.Submodule`. For example, + this is involved in computing multiplication tables. + + On the other hand, representing in a :py:class:`~.PowerBasis` is an + odd case, and one which tends not to arise in practice, except for + example when using a :py:class:`~.ModuleEndomorphism` on a + :py:class:`~.PowerBasis`. + + In such a case, (1) the incoming :py:class:`~.ModuleElement` must + belong to the :py:class:`~.PowerBasis` itself (since the latter has no + proper ancestors) and (2) it is "representable" iff it belongs to + $\mathbb{Z}[\theta]$ (although generally a + :py:class:`~.PowerBasisElement` may represent any element of + $\mathbb{Q}(\theta)$, i.e. any algebraic number). + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis, to_col + >>> from sympy.abc import zeta + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> a = A(to_col([2, 4, 6, 8])) + + The :py:class:`~.ModuleElement` ``a`` has all even coefficients. + If we represent ``a`` in the submodule ``B = 2*A``, the coefficients in + the column vector will be halved: + + >>> B = A.submodule_from_gens([2*A(i) for i in range(4)]) + >>> b = B.represent(a) + >>> print(b.transpose()) # doctest: +SKIP + DomainMatrix([[1, 2, 3, 4]], (1, 4), ZZ) + + However, the element of ``B`` so defined still represents the same + algebraic number: + + >>> print(a.poly(zeta).as_expr()) + 8*zeta**3 + 6*zeta**2 + 4*zeta + 2 + >>> print(B(b).over_power_basis().poly(zeta).as_expr()) + 8*zeta**3 + 6*zeta**2 + 4*zeta + 2 + + Parameters + ========== + + elt : :py:class:`~.ModuleElement` + The module element to be represented. Must belong to some ancestor + module of this module (including this module itself). + + Returns + ======= + + :py:class:`~.DomainMatrix` over :ref:`ZZ` + This will be a column vector, representing the coefficients of a + linear combination of this module's generators, which equals the + given element. + + Raises + ====== + + ClosureFailure + If the given element cannot be represented as a :ref:`ZZ`-linear + combination over this module. + + See Also + ======== + + .Submodule.represent + .PowerBasis.represent + + """ + raise NotImplementedError + + def ancestors(self, include_self=False): + """ + Return the list of ancestor modules of this module, from the + foundational :py:class:`~.PowerBasis` downward, optionally including + ``self``. + + See Also + ======== + + Module + + """ + c = self.parent + a = [] if c is None else c.ancestors(include_self=True) + if include_self: + a.append(self) + return a + + def power_basis_ancestor(self): + """ + Return the :py:class:`~.PowerBasis` that is an ancestor of this module. + + See Also + ======== + + Module + + """ + if isinstance(self, PowerBasis): + return self + c = self.parent + if c is not None: + return c.power_basis_ancestor() + return None + + def nearest_common_ancestor(self, other): + """ + Locate the nearest common ancestor of this module and another. + + Returns + ======= + + :py:class:`~.Module`, ``None`` + + See Also + ======== + + Module + + """ + sA = self.ancestors(include_self=True) + oA = other.ancestors(include_self=True) + nca = None + for sa, oa in zip(sA, oA): + if sa == oa: + nca = sa + else: + break + return nca + + @property + def number_field(self): + r""" + Return the associated :py:class:`~.AlgebraicField`, if any. + + Explanation + =========== + + A :py:class:`~.PowerBasis` can be constructed on a :py:class:`~.Poly` + $f$ or on an :py:class:`~.AlgebraicField` $K$. In the latter case, the + :py:class:`~.PowerBasis` and all its descendant modules will return $K$ + as their ``.number_field`` property, while in the former case they will + all return ``None``. + + Returns + ======= + + :py:class:`~.AlgebraicField`, ``None`` + + """ + return self.power_basis_ancestor().number_field + + def is_compat_col(self, col): + """Say whether *col* is a suitable column vector for this module.""" + return isinstance(col, DomainMatrix) and col.shape == (self.n, 1) and col.domain.is_ZZ + + def __call__(self, spec, denom=1): + r""" + Generate a :py:class:`~.ModuleElement` belonging to this module. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis, to_col + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> e = A(to_col([1, 2, 3, 4]), denom=3) + >>> print(e) # doctest: +SKIP + [1, 2, 3, 4]/3 + >>> f = A(2) + >>> print(f) # doctest: +SKIP + [0, 0, 1, 0] + + Parameters + ========== + + spec : :py:class:`~.DomainMatrix`, int + Specifies the numerators of the coefficients of the + :py:class:`~.ModuleElement`. Can be either a column vector over + :ref:`ZZ`, whose length must equal the number $n$ of generators of + this module, or else an integer ``j``, $0 \leq j < n$, which is a + shorthand for column $j$ of $I_n$, the $n \times n$ identity + matrix. + denom : int, optional (default=1) + Denominator for the coefficients of the + :py:class:`~.ModuleElement`. + + Returns + ======= + + :py:class:`~.ModuleElement` + The coefficients are the entries of the *spec* vector, divided by + *denom*. + + """ + if isinstance(spec, int) and 0 <= spec < self.n: + spec = DomainMatrix.eye(self.n, ZZ)[:, spec].to_dense() + if not self.is_compat_col(spec): + raise ValueError('Compatible column vector required.') + return make_mod_elt(self, spec, denom=denom) + + def starts_with_unity(self): + """Say whether the module's first generator equals unity.""" + raise NotImplementedError + + def basis_elements(self): + """ + Get list of :py:class:`~.ModuleElement` being the generators of this + module. + """ + return [self(j) for j in range(self.n)] + + def zero(self): + """Return a :py:class:`~.ModuleElement` representing zero.""" + return self(0) * 0 + + def one(self): + """ + Return a :py:class:`~.ModuleElement` representing unity, + and belonging to the first ancestor of this module (including + itself) that starts with unity. + """ + return self.element_from_rational(1) + + def element_from_rational(self, a): + """ + Return a :py:class:`~.ModuleElement` representing a rational number. + + Explanation + =========== + + The returned :py:class:`~.ModuleElement` will belong to the first + module on this module's ancestor chain (including this module + itself) that starts with unity. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, QQ + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> a = A.element_from_rational(QQ(2, 3)) + >>> print(a) # doctest: +SKIP + [2, 0, 0, 0]/3 + + Parameters + ========== + + a : int, :ref:`ZZ`, :ref:`QQ` + + Returns + ======= + + :py:class:`~.ModuleElement` + + """ + raise NotImplementedError + + def submodule_from_gens(self, gens, hnf=True, hnf_modulus=None): + """ + Form the submodule generated by a list of :py:class:`~.ModuleElement` + belonging to this module. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> gens = [A(0), 2*A(1), 3*A(2), 4*A(3)//5] + >>> B = A.submodule_from_gens(gens) + >>> print(B) # doctest: +SKIP + Submodule[[5, 0, 0, 0], [0, 10, 0, 0], [0, 0, 15, 0], [0, 0, 0, 4]]/5 + + Parameters + ========== + + gens : list of :py:class:`~.ModuleElement` belonging to this module. + hnf : boolean, optional (default=True) + If True, we will reduce the matrix into Hermite Normal Form before + forming the :py:class:`~.Submodule`. + hnf_modulus : int, None, optional (default=None) + Modulus for use in the HNF reduction algorithm. See + :py:func:`~sympy.polys.matrices.normalforms.hermite_normal_form`. + + Returns + ======= + + :py:class:`~.Submodule` + + See Also + ======== + + submodule_from_matrix + + """ + if not all(g.module == self for g in gens): + raise ValueError('Generators must belong to this module.') + n = len(gens) + if n == 0: + raise ValueError('Need at least one generator.') + m = gens[0].n + d = gens[0].denom if n == 1 else ilcm(*[g.denom for g in gens]) + B = DomainMatrix.zeros((m, 0), ZZ).hstack(*[(d // g.denom) * g.col for g in gens]) + if hnf: + B = hermite_normal_form(B, D=hnf_modulus) + return self.submodule_from_matrix(B, denom=d) + + def submodule_from_matrix(self, B, denom=1): + """ + Form the submodule generated by the elements of this module indicated + by the columns of a matrix, with an optional denominator. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, ZZ + >>> from sympy.polys.matrices import DM + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_matrix(DM([ + ... [0, 10, 0, 0], + ... [0, 0, 7, 0], + ... ], ZZ).transpose(), denom=15) + >>> print(B) # doctest: +SKIP + Submodule[[0, 10, 0, 0], [0, 0, 7, 0]]/15 + + Parameters + ========== + + B : :py:class:`~.DomainMatrix` over :ref:`ZZ` + Each column gives the numerators of the coefficients of one + generator of the submodule. Thus, the number of rows of *B* must + equal the number of generators of the present module. + denom : int, optional (default=1) + Common denominator for all generators of the submodule. + + Returns + ======= + + :py:class:`~.Submodule` + + Raises + ====== + + ValueError + If the given matrix *B* is not over :ref:`ZZ` or its number of rows + does not equal the number of generators of the present module. + + See Also + ======== + + submodule_from_gens + + """ + m, n = B.shape + if not B.domain.is_ZZ: + raise ValueError('Matrix must be over ZZ.') + if not m == self.n: + raise ValueError('Matrix row count must match base module.') + return Submodule(self, B, denom=denom) + + def whole_submodule(self): + """ + Return a submodule equal to this entire module. + + Explanation + =========== + + This is useful when you have a :py:class:`~.PowerBasis` and want to + turn it into a :py:class:`~.Submodule` (in order to use methods + belonging to the latter). + + """ + B = DomainMatrix.eye(self.n, ZZ) + return self.submodule_from_matrix(B) + + def endomorphism_ring(self): + """Form the :py:class:`~.EndomorphismRing` for this module.""" + return EndomorphismRing(self) + + +class PowerBasis(Module): + """The module generated by the powers of an algebraic integer.""" + + def __init__(self, T): + """ + Parameters + ========== + + T : :py:class:`~.Poly`, :py:class:`~.AlgebraicField` + Either (1) the monic, irreducible, univariate polynomial over + :ref:`ZZ`, a root of which is the generator of the power basis, + or (2) an :py:class:`~.AlgebraicField` whose primitive element + is the generator of the power basis. + + """ + K = None + if isinstance(T, AlgebraicField): + K, T = T, T.ext.minpoly_of_element() + # Sometimes incoming Polys are formally over QQ, although all their + # coeffs are integral. We want them to be formally over ZZ. + T = T.set_domain(ZZ) + self.K = K + self.T = T + self._n = T.degree() + self._mult_tab = None + + @property + def number_field(self): + return self.K + + def __repr__(self): + return f'PowerBasis({self.T.as_expr()})' + + def __eq__(self, other): + if isinstance(other, PowerBasis): + return self.T == other.T + return NotImplemented + + @property + def n(self): + return self._n + + def mult_tab(self): + if self._mult_tab is None: + self.compute_mult_tab() + return self._mult_tab + + def compute_mult_tab(self): + theta_pow = AlgIntPowers(self.T) + M = {} + n = self.n + for u in range(n): + M[u] = {} + for v in range(u, n): + M[u][v] = theta_pow[u + v] + self._mult_tab = M + + def represent(self, elt): + r""" + Represent a module element as an integer-linear combination over the + generators of this module. + + See Also + ======== + + .Module.represent + .Submodule.represent + + """ + if elt.module == self and elt.denom == 1: + return elt.column() + else: + raise ClosureFailure('Element not representable in ZZ[theta].') + + def starts_with_unity(self): + return True + + def element_from_rational(self, a): + return self(0) * a + + def element_from_poly(self, f): + """ + Produce an element of this module, representing *f* after reduction mod + our defining minimal polynomial. + + Parameters + ========== + + f : :py:class:`~.Poly` over :ref:`ZZ` in same var as our defining poly. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + + """ + n, k = self.n, f.degree() + if k >= n: + f = f % self.T + if f == 0: + return self.zero() + d, c = dup_clear_denoms(f.rep.to_list(), QQ, convert=True) + c = list(reversed(c)) + ell = len(c) + z = [ZZ(0)] * (n - ell) + col = to_col(c + z) + return self(col, denom=d) + + def _element_from_rep_and_mod(self, rep, mod): + """ + Produce a PowerBasisElement representing a given algebraic number. + + Parameters + ========== + + rep : list of coeffs + Represents the number as polynomial in the primitive element of the + field. + + mod : list of coeffs + Represents the minimal polynomial of the primitive element of the + field. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + + """ + if mod != self.T.rep.to_list(): + raise UnificationFailed('Element does not appear to be in the same field.') + return self.element_from_poly(Poly(rep, self.T.gen)) + + def element_from_ANP(self, a): + """Convert an ANP into a PowerBasisElement. """ + return self._element_from_rep_and_mod(a.to_list(), a.mod_to_list()) + + def element_from_alg_num(self, a): + """Convert an AlgebraicNumber into a PowerBasisElement. """ + return self._element_from_rep_and_mod(a.rep.to_list(), a.minpoly.rep.to_list()) + + +class Submodule(Module, IntegerPowerable): + """A submodule of another module.""" + + def __init__(self, parent, matrix, denom=1, mult_tab=None): + """ + Parameters + ========== + + parent : :py:class:`~.Module` + The module from which this one is derived. + matrix : :py:class:`~.DomainMatrix` over :ref:`ZZ` + The matrix whose columns define this submodule's generators as + linear combinations over the parent's generators. + denom : int, optional (default=1) + Denominator for the coefficients given by the matrix. + mult_tab : dict, ``None``, optional + If already known, the multiplication table for this module may be + supplied. + + """ + self._parent = parent + self._matrix = matrix + self._denom = denom + self._mult_tab = mult_tab + self._n = matrix.shape[1] + self._QQ_matrix = None + self._starts_with_unity = None + self._is_sq_maxrank_HNF = None + + def __repr__(self): + r = 'Submodule' + repr(self.matrix.transpose().to_Matrix().tolist()) + if self.denom > 1: + r += f'/{self.denom}' + return r + + def reduced(self): + """ + Produce a reduced version of this submodule. + + Explanation + =========== + + In the reduced version, it is guaranteed that 1 is the only positive + integer dividing both the submodule's denominator, and every entry in + the submodule's matrix. + + Returns + ======= + + :py:class:`~.Submodule` + + """ + if self.denom == 1: + return self + g = igcd(self.denom, *self.coeffs) + if g == 1: + return self + return type(self)(self.parent, (self.matrix / g).convert_to(ZZ), denom=self.denom // g, mult_tab=self._mult_tab) + + def discard_before(self, r): + """ + Produce a new module by discarding all generators before a given + index *r*. + """ + W = self.matrix[:, r:] + s = self.n - r + M = None + mt = self._mult_tab + if mt is not None: + M = {} + for u in range(s): + M[u] = {} + for v in range(u, s): + M[u][v] = mt[r + u][r + v][r:] + return Submodule(self.parent, W, denom=self.denom, mult_tab=M) + + @property + def n(self): + return self._n + + def mult_tab(self): + if self._mult_tab is None: + self.compute_mult_tab() + return self._mult_tab + + def compute_mult_tab(self): + gens = self.basis_element_pullbacks() + M = {} + n = self.n + for u in range(n): + M[u] = {} + for v in range(u, n): + M[u][v] = self.represent(gens[u] * gens[v]).flat() + self._mult_tab = M + + @property + def parent(self): + return self._parent + + @property + def matrix(self): + return self._matrix + + @property + def coeffs(self): + return self.matrix.flat() + + @property + def denom(self): + return self._denom + + @property + def QQ_matrix(self): + """ + :py:class:`~.DomainMatrix` over :ref:`QQ`, equal to + ``self.matrix / self.denom``, and guaranteed to be dense. + + Explanation + =========== + + Depending on how it is formed, a :py:class:`~.DomainMatrix` may have + an internal representation that is sparse or dense. We guarantee a + dense representation here, so that tests for equivalence of submodules + always come out as expected. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, ZZ + >>> from sympy.abc import x + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5, x)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_matrix(3*DomainMatrix.eye(4, ZZ), denom=6) + >>> C = A.submodule_from_matrix(DomainMatrix.eye(4, ZZ), denom=2) + >>> print(B.QQ_matrix == C.QQ_matrix) + True + + Returns + ======= + + :py:class:`~.DomainMatrix` over :ref:`QQ` + + """ + if self._QQ_matrix is None: + self._QQ_matrix = (self.matrix / self.denom).to_dense() + return self._QQ_matrix + + def starts_with_unity(self): + if self._starts_with_unity is None: + self._starts_with_unity = self(0).equiv(1) + return self._starts_with_unity + + def is_sq_maxrank_HNF(self): + if self._is_sq_maxrank_HNF is None: + self._is_sq_maxrank_HNF = is_sq_maxrank_HNF(self._matrix) + return self._is_sq_maxrank_HNF + + def is_power_basis_submodule(self): + return isinstance(self.parent, PowerBasis) + + def element_from_rational(self, a): + if self.starts_with_unity(): + return self(0) * a + else: + return self.parent.element_from_rational(a) + + def basis_element_pullbacks(self): + """ + Return list of this submodule's basis elements as elements of the + submodule's parent module. + """ + return [e.to_parent() for e in self.basis_elements()] + + def represent(self, elt): + """ + Represent a module element as an integer-linear combination over the + generators of this module. + + See Also + ======== + + .Module.represent + .PowerBasis.represent + + """ + if elt.module == self: + return elt.column() + elif elt.module == self.parent: + try: + # The given element should be a ZZ-linear combination over our + # basis vectors; however, due to the presence of denominators, + # we need to solve over QQ. + A = self.QQ_matrix + b = elt.QQ_col + x = A._solve(b)[0].transpose() + x = x.convert_to(ZZ) + except DMBadInputError: + raise ClosureFailure('Element outside QQ-span of this basis.') + except CoercionFailed: + raise ClosureFailure('Element in QQ-span but not ZZ-span of this basis.') + return x + elif isinstance(self.parent, Submodule): + coeffs_in_parent = self.parent.represent(elt) + parent_element = self.parent(coeffs_in_parent) + return self.represent(parent_element) + else: + raise ClosureFailure('Element outside ancestor chain of this module.') + + def is_compat_submodule(self, other): + return isinstance(other, Submodule) and other.parent == self.parent + + def __eq__(self, other): + if self.is_compat_submodule(other): + return other.QQ_matrix == self.QQ_matrix + return NotImplemented + + def add(self, other, hnf=True, hnf_modulus=None): + """ + Add this :py:class:`~.Submodule` to another. + + Explanation + =========== + + This represents the module generated by the union of the two modules' + sets of generators. + + Parameters + ========== + + other : :py:class:`~.Submodule` + hnf : boolean, optional (default=True) + If ``True``, reduce the matrix of the combined module to its + Hermite Normal Form. + hnf_modulus : :ref:`ZZ`, None, optional + If a positive integer is provided, use this as modulus in the + HNF reduction. See + :py:func:`~sympy.polys.matrices.normalforms.hermite_normal_form`. + + Returns + ======= + + :py:class:`~.Submodule` + + """ + d, e = self.denom, other.denom + m = ilcm(d, e) + a, b = m // d, m // e + B = (a * self.matrix).hstack(b * other.matrix) + if hnf: + B = hermite_normal_form(B, D=hnf_modulus) + return self.parent.submodule_from_matrix(B, denom=m) + + def __add__(self, other): + if self.is_compat_submodule(other): + return self.add(other) + return NotImplemented + + __radd__ = __add__ + + def mul(self, other, hnf=True, hnf_modulus=None): + """ + Multiply this :py:class:`~.Submodule` by a rational number, a + :py:class:`~.ModuleElement`, or another :py:class:`~.Submodule`. + + Explanation + =========== + + To multiply by a rational number or :py:class:`~.ModuleElement` means + to form the submodule whose generators are the products of this + quantity with all the generators of the present submodule. + + To multiply by another :py:class:`~.Submodule` means to form the + submodule whose generators are all the products of one generator from + the one submodule, and one generator from the other. + + Parameters + ========== + + other : int, :ref:`ZZ`, :ref:`QQ`, :py:class:`~.ModuleElement`, :py:class:`~.Submodule` + hnf : boolean, optional (default=True) + If ``True``, reduce the matrix of the product module to its + Hermite Normal Form. + hnf_modulus : :ref:`ZZ`, None, optional + If a positive integer is provided, use this as modulus in the + HNF reduction. See + :py:func:`~sympy.polys.matrices.normalforms.hermite_normal_form`. + + Returns + ======= + + :py:class:`~.Submodule` + + """ + if is_rat(other): + a, b = get_num_denom(other) + if a == b == 1: + return self + else: + return Submodule(self.parent, + self.matrix * a, denom=self.denom * b, + mult_tab=None).reduced() + elif isinstance(other, ModuleElement) and other.module == self.parent: + # The submodule is multiplied by an element of the parent module. + # We presume this means we want a new submodule of the parent module. + gens = [other * e for e in self.basis_element_pullbacks()] + return self.parent.submodule_from_gens(gens, hnf=hnf, hnf_modulus=hnf_modulus) + elif self.is_compat_submodule(other): + # This case usually means you're multiplying ideals, and want another + # ideal, i.e. another submodule of the same parent module. + alphas, betas = self.basis_element_pullbacks(), other.basis_element_pullbacks() + gens = [a * b for a in alphas for b in betas] + return self.parent.submodule_from_gens(gens, hnf=hnf, hnf_modulus=hnf_modulus) + return NotImplemented + + def __mul__(self, other): + return self.mul(other) + + __rmul__ = __mul__ + + def _first_power(self): + return self + + def reduce_element(self, elt): + r""" + If this submodule $B$ has defining matrix $W$ in square, maximal-rank + Hermite normal form, then, given an element $x$ of the parent module + $A$, we produce an element $y \in A$ such that $x - y \in B$, and the + $i$th coordinate of $y$ satisfies $0 \leq y_i < w_{i,i}$. This + representative $y$ is unique, in the sense that every element of + the coset $x + B$ reduces to it under this procedure. + + Explanation + =========== + + In the special case where $A$ is a power basis for a number field $K$, + and $B$ is a submodule representing an ideal $I$, this operation + represents one of a few important ways of reducing an element of $K$ + modulo $I$ to obtain a "small" representative. See [Cohen00]_ Section + 1.4.3. + + Examples + ======== + + >>> from sympy import QQ, Poly, symbols + >>> t = symbols('t') + >>> k = QQ.alg_field_from_poly(Poly(t**3 + t**2 - 2*t + 8)) + >>> Zk = k.maximal_order() + >>> A = Zk.parent + >>> B = (A(2) - 3*A(0))*Zk + >>> B.reduce_element(A(2)) + [3, 0, 0] + + Parameters + ========== + + elt : :py:class:`~.ModuleElement` + An element of this submodule's parent module. + + Returns + ======= + + elt : :py:class:`~.ModuleElement` + An element of this submodule's parent module. + + Raises + ====== + + NotImplementedError + If the given :py:class:`~.ModuleElement` does not belong to this + submodule's parent module. + StructureError + If this submodule's defining matrix is not in square, maximal-rank + Hermite normal form. + + References + ========== + + .. [Cohen00] Cohen, H. *Advanced Topics in Computational Number + Theory.* + + """ + if not elt.module == self.parent: + raise NotImplementedError + if not self.is_sq_maxrank_HNF(): + msg = "Reduction not implemented unless matrix square max-rank HNF" + raise StructureError(msg) + B = self.basis_element_pullbacks() + a = elt + for i in range(self.n - 1, -1, -1): + b = B[i] + q = a.coeffs[i]*b.denom // (b.coeffs[i]*a.denom) + a -= q*b + return a + + +def is_sq_maxrank_HNF(dm): + r""" + Say whether a :py:class:`~.DomainMatrix` is in that special case of Hermite + Normal Form, in which the matrix is also square and of maximal rank. + + Explanation + =========== + + We commonly work with :py:class:`~.Submodule` instances whose matrix is in + this form, and it can be useful to be able to check that this condition is + satisfied. + + For example this is the case with the :py:class:`~.Submodule` ``ZK`` + returned by :py:func:`~sympy.polys.numberfields.basis.round_two`, which + represents the maximal order in a number field, and with ideals formed + therefrom, such as ``2 * ZK``. + + """ + if dm.domain.is_ZZ and dm.is_square and dm.is_upper: + n = dm.shape[0] + for i in range(n): + d = dm[i, i].element + if d <= 0: + return False + for j in range(i + 1, n): + if not (0 <= dm[i, j].element < d): + return False + return True + return False + + +def make_mod_elt(module, col, denom=1): + r""" + Factory function which builds a :py:class:`~.ModuleElement`, but ensures + that it is a :py:class:`~.PowerBasisElement` if the module is a + :py:class:`~.PowerBasis`. + """ + if isinstance(module, PowerBasis): + return PowerBasisElement(module, col, denom=denom) + else: + return ModuleElement(module, col, denom=denom) + + +class ModuleElement(IntegerPowerable): + r""" + Represents an element of a :py:class:`~.Module`. + + NOTE: Should not be constructed directly. Use the + :py:meth:`~.Module.__call__` method or the :py:func:`make_mod_elt()` + factory function instead. + """ + + def __init__(self, module, col, denom=1): + """ + Parameters + ========== + + module : :py:class:`~.Module` + The module to which this element belongs. + col : :py:class:`~.DomainMatrix` over :ref:`ZZ` + Column vector giving the numerators of the coefficients of this + element. + denom : int, optional (default=1) + Denominator for the coefficients of this element. + + """ + self.module = module + self.col = col + self.denom = denom + self._QQ_col = None + + def __repr__(self): + r = str([int(c) for c in self.col.flat()]) + if self.denom > 1: + r += f'/{self.denom}' + return r + + def reduced(self): + """ + Produce a reduced version of this ModuleElement, i.e. one in which the + gcd of the denominator together with all numerator coefficients is 1. + """ + if self.denom == 1: + return self + g = igcd(self.denom, *self.coeffs) + if g == 1: + return self + return type(self)(self.module, + (self.col / g).convert_to(ZZ), + denom=self.denom // g) + + def reduced_mod_p(self, p): + """ + Produce a version of this :py:class:`~.ModuleElement` in which all + numerator coefficients have been reduced mod *p*. + """ + return make_mod_elt(self.module, + self.col.convert_to(FF(p)).convert_to(ZZ), + denom=self.denom) + + @classmethod + def from_int_list(cls, module, coeffs, denom=1): + """ + Make a :py:class:`~.ModuleElement` from a list of ints (instead of a + column vector). + """ + col = to_col(coeffs) + return cls(module, col, denom=denom) + + @property + def n(self): + """The length of this element's column.""" + return self.module.n + + def __len__(self): + return self.n + + def column(self, domain=None): + """ + Get a copy of this element's column, optionally converting to a domain. + """ + if domain is None: + return self.col.copy() + else: + return self.col.convert_to(domain) + + @property + def coeffs(self): + return self.col.flat() + + @property + def QQ_col(self): + """ + :py:class:`~.DomainMatrix` over :ref:`QQ`, equal to + ``self.col / self.denom``, and guaranteed to be dense. + + See Also + ======== + + .Submodule.QQ_matrix + + """ + if self._QQ_col is None: + self._QQ_col = (self.col / self.denom).to_dense() + return self._QQ_col + + def to_parent(self): + """ + Transform into a :py:class:`~.ModuleElement` belonging to the parent of + this element's module. + """ + if not isinstance(self.module, Submodule): + raise ValueError('Not an element of a Submodule.') + return make_mod_elt( + self.module.parent, self.module.matrix * self.col, + denom=self.module.denom * self.denom) + + def to_ancestor(self, anc): + """ + Transform into a :py:class:`~.ModuleElement` belonging to a given + ancestor of this element's module. + + Parameters + ========== + + anc : :py:class:`~.Module` + + """ + if anc == self.module: + return self + else: + return self.to_parent().to_ancestor(anc) + + def over_power_basis(self): + """ + Transform into a :py:class:`~.PowerBasisElement` over our + :py:class:`~.PowerBasis` ancestor. + """ + e = self + while not isinstance(e.module, PowerBasis): + e = e.to_parent() + return e + + def is_compat(self, other): + """ + Test whether other is another :py:class:`~.ModuleElement` with same + module. + """ + return isinstance(other, ModuleElement) and other.module == self.module + + def unify(self, other): + """ + Try to make a compatible pair of :py:class:`~.ModuleElement`, one + equivalent to this one, and one equivalent to the other. + + Explanation + =========== + + We search for the nearest common ancestor module for the pair of + elements, and represent each one there. + + Returns + ======= + + Pair ``(e1, e2)`` + Each ``ei`` is a :py:class:`~.ModuleElement`, they belong to the + same :py:class:`~.Module`, ``e1`` is equivalent to ``self``, and + ``e2`` is equivalent to ``other``. + + Raises + ====== + + UnificationFailed + If ``self`` and ``other`` have no common ancestor module. + + """ + if self.module == other.module: + return self, other + nca = self.module.nearest_common_ancestor(other.module) + if nca is not None: + return self.to_ancestor(nca), other.to_ancestor(nca) + raise UnificationFailed(f"Cannot unify {self} with {other}") + + def __eq__(self, other): + if self.is_compat(other): + return self.QQ_col == other.QQ_col + return NotImplemented + + def equiv(self, other): + """ + A :py:class:`~.ModuleElement` may test as equivalent to a rational + number or another :py:class:`~.ModuleElement`, if they represent the + same algebraic number. + + Explanation + =========== + + This method is intended to check equivalence only in those cases in + which it is easy to test; namely, when *other* is either a + :py:class:`~.ModuleElement` that can be unified with this one (i.e. one + which shares a common :py:class:`~.PowerBasis` ancestor), or else a + rational number (which is easy because every :py:class:`~.PowerBasis` + represents every rational number). + + Parameters + ========== + + other : int, :ref:`ZZ`, :ref:`QQ`, :py:class:`~.ModuleElement` + + Returns + ======= + + bool + + Raises + ====== + + UnificationFailed + If ``self`` and ``other`` do not share a common + :py:class:`~.PowerBasis` ancestor. + + """ + if self == other: + return True + elif isinstance(other, ModuleElement): + a, b = self.unify(other) + return a == b + elif is_rat(other): + if isinstance(self, PowerBasisElement): + return self == self.module(0) * other + else: + return self.over_power_basis().equiv(other) + return False + + def __add__(self, other): + """ + A :py:class:`~.ModuleElement` can be added to a rational number, or to + another :py:class:`~.ModuleElement`. + + Explanation + =========== + + When the other summand is a rational number, it will be converted into + a :py:class:`~.ModuleElement` (belonging to the first ancestor of this + module that starts with unity). + + In all cases, the sum belongs to the nearest common ancestor (NCA) of + the modules of the two summands. If the NCA does not exist, we return + ``NotImplemented``. + """ + if self.is_compat(other): + d, e = self.denom, other.denom + m = ilcm(d, e) + u, v = m // d, m // e + col = to_col([u * a + v * b for a, b in zip(self.coeffs, other.coeffs)]) + return type(self)(self.module, col, denom=m).reduced() + elif isinstance(other, ModuleElement): + try: + a, b = self.unify(other) + except UnificationFailed: + return NotImplemented + return a + b + elif is_rat(other): + return self + self.module.element_from_rational(other) + return NotImplemented + + __radd__ = __add__ + + def __neg__(self): + return self * -1 + + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + def __mul__(self, other): + """ + A :py:class:`~.ModuleElement` can be multiplied by a rational number, + or by another :py:class:`~.ModuleElement`. + + Explanation + =========== + + When the multiplier is a rational number, the product is computed by + operating directly on the coefficients of this + :py:class:`~.ModuleElement`. + + When the multiplier is another :py:class:`~.ModuleElement`, the product + will belong to the nearest common ancestor (NCA) of the modules of the + two operands, and that NCA must have a multiplication table. If the NCA + does not exist, we return ``NotImplemented``. If the NCA does not have + a mult. table, ``ClosureFailure`` will be raised. + """ + if self.is_compat(other): + M = self.module.mult_tab() + A, B = self.col.flat(), other.col.flat() + n = self.n + C = [0] * n + for u in range(n): + for v in range(u, n): + c = A[u] * B[v] + if v > u: + c += A[v] * B[u] + if c != 0: + R = M[u][v] + for k in range(n): + C[k] += c * R[k] + d = self.denom * other.denom + return self.from_int_list(self.module, C, denom=d) + elif isinstance(other, ModuleElement): + try: + a, b = self.unify(other) + except UnificationFailed: + return NotImplemented + return a * b + elif is_rat(other): + a, b = get_num_denom(other) + if a == b == 1: + return self + else: + return make_mod_elt(self.module, + self.col * a, denom=self.denom * b).reduced() + return NotImplemented + + __rmul__ = __mul__ + + def _zeroth_power(self): + return self.module.one() + + def _first_power(self): + return self + + def __floordiv__(self, a): + if is_rat(a): + a = QQ(a) + return self * (1/a) + elif isinstance(a, ModuleElement): + return self * (1//a) + return NotImplemented + + def __rfloordiv__(self, a): + return a // self.over_power_basis() + + def __mod__(self, m): + r""" + Reduce this :py:class:`~.ModuleElement` mod a :py:class:`~.Submodule`. + + Parameters + ========== + + m : int, :ref:`ZZ`, :ref:`QQ`, :py:class:`~.Submodule` + If a :py:class:`~.Submodule`, reduce ``self`` relative to this. + If an integer or rational, reduce relative to the + :py:class:`~.Submodule` that is our own module times this constant. + + See Also + ======== + + .Submodule.reduce_element + + """ + if is_rat(m): + m = m * self.module.whole_submodule() + if isinstance(m, Submodule) and m.parent == self.module: + return m.reduce_element(self) + return NotImplemented + + +class PowerBasisElement(ModuleElement): + r""" + Subclass for :py:class:`~.ModuleElement` instances whose module is a + :py:class:`~.PowerBasis`. + """ + + @property + def T(self): + """Access the defining polynomial of the :py:class:`~.PowerBasis`.""" + return self.module.T + + def numerator(self, x=None): + """Obtain the numerator as a polynomial over :ref:`ZZ`.""" + x = x or self.T.gen + return Poly(reversed(self.coeffs), x, domain=ZZ) + + def poly(self, x=None): + """Obtain the number as a polynomial over :ref:`QQ`.""" + return self.numerator(x=x) // self.denom + + @property + def is_rational(self): + """Say whether this element represents a rational number.""" + return self.col[1:, :].is_zero_matrix + + @property + def generator(self): + """ + Return a :py:class:`~.Symbol` to be used when expressing this element + as a polynomial. + + If we have an associated :py:class:`~.AlgebraicField` whose primitive + element has an alias symbol, we use that. Otherwise we use the variable + of the minimal polynomial defining the power basis to which we belong. + """ + K = self.module.number_field + return K.ext.alias if K and K.ext.is_aliased else self.T.gen + + def as_expr(self, x=None): + """Create a Basic expression from ``self``. """ + return self.poly(x or self.generator).as_expr() + + def norm(self, T=None): + """Compute the norm of this number.""" + T = T or self.T + x = T.gen + A = self.numerator(x=x) + return T.resultant(A) // self.denom ** self.n + + def inverse(self): + f = self.poly() + f_inv = f.invert(self.T) + return self.module.element_from_poly(f_inv) + + def __rfloordiv__(self, a): + return self.inverse() * a + + def _negative_power(self, e, modulo=None): + return self.inverse() ** abs(e) + + def to_ANP(self): + """Convert to an equivalent :py:class:`~.ANP`. """ + return ANP(list(reversed(self.QQ_col.flat())), QQ.map(self.T.rep.to_list()), QQ) + + def to_alg_num(self): + """ + Try to convert to an equivalent :py:class:`~.AlgebraicNumber`. + + Explanation + =========== + + In general, the conversion from an :py:class:`~.AlgebraicNumber` to a + :py:class:`~.PowerBasisElement` throws away information, because an + :py:class:`~.AlgebraicNumber` specifies a complex embedding, while a + :py:class:`~.PowerBasisElement` does not. However, in some cases it is + possible to convert a :py:class:`~.PowerBasisElement` back into an + :py:class:`~.AlgebraicNumber`, namely when the associated + :py:class:`~.PowerBasis` has a reference to an + :py:class:`~.AlgebraicField`. + + Returns + ======= + + :py:class:`~.AlgebraicNumber` + + Raises + ====== + + StructureError + If the :py:class:`~.PowerBasis` to which this element belongs does + not have an associated :py:class:`~.AlgebraicField`. + + """ + K = self.module.number_field + if K: + return K.to_alg_num(self.to_ANP()) + raise StructureError("No associated AlgebraicField") + + +class ModuleHomomorphism: + r"""A homomorphism from one module to another.""" + + def __init__(self, domain, codomain, mapping): + r""" + Parameters + ========== + + domain : :py:class:`~.Module` + The domain of the mapping. + + codomain : :py:class:`~.Module` + The codomain of the mapping. + + mapping : callable + An arbitrary callable is accepted, but should be chosen so as + to represent an actual module homomorphism. In particular, should + accept elements of *domain* and return elements of *codomain*. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis, ModuleHomomorphism + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_gens([2*A(j) for j in range(4)]) + >>> phi = ModuleHomomorphism(A, B, lambda x: 6*x) + >>> print(phi.matrix()) # doctest: +SKIP + DomainMatrix([[3, 0, 0, 0], [0, 3, 0, 0], [0, 0, 3, 0], [0, 0, 0, 3]], (4, 4), ZZ) + + """ + self.domain = domain + self.codomain = codomain + self.mapping = mapping + + def matrix(self, modulus=None): + r""" + Compute the matrix of this homomorphism. + + Parameters + ========== + + modulus : int, optional + A positive prime number $p$ if the matrix should be reduced mod + $p$. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The matrix is over :ref:`ZZ`, or else over :ref:`GF(p)` if a + modulus was given. + + """ + basis = self.domain.basis_elements() + cols = [self.codomain.represent(self.mapping(elt)) for elt in basis] + if not cols: + return DomainMatrix.zeros((self.codomain.n, 0), ZZ).to_dense() + M = cols[0].hstack(*cols[1:]) + if modulus: + M = M.convert_to(FF(modulus)) + return M + + def kernel(self, modulus=None): + r""" + Compute a Submodule representing the kernel of this homomorphism. + + Parameters + ========== + + modulus : int, optional + A positive prime number $p$ if the kernel should be computed mod + $p$. + + Returns + ======= + + :py:class:`~.Submodule` + This submodule's generators span the kernel of this + homomorphism over :ref:`ZZ`, or else over :ref:`GF(p)` if a + modulus was given. + + """ + M = self.matrix(modulus=modulus) + if modulus is None: + M = M.convert_to(QQ) + # Note: Even when working over a finite field, what we want here is + # the pullback into the integers, so in this case the conversion to ZZ + # below is appropriate. When working over ZZ, the kernel should be a + # ZZ-submodule, so, while the conversion to QQ above was required in + # order for the nullspace calculation to work, conversion back to ZZ + # afterward should always work. + # TODO: + # Watch , which calls + # for fraction-free algorithms. If this is implemented, we can skip + # the conversion to `QQ` above. + K = M.nullspace().convert_to(ZZ).transpose() + return self.domain.submodule_from_matrix(K) + + +class ModuleEndomorphism(ModuleHomomorphism): + r"""A homomorphism from one module to itself.""" + + def __init__(self, domain, mapping): + r""" + Parameters + ========== + + domain : :py:class:`~.Module` + The common domain and codomain of the mapping. + + mapping : callable + An arbitrary callable is accepted, but should be chosen so as + to represent an actual module endomorphism. In particular, should + accept and return elements of *domain*. + + """ + super().__init__(domain, domain, mapping) + + +class InnerEndomorphism(ModuleEndomorphism): + r""" + An inner endomorphism on a module, i.e. the endomorphism corresponding to + multiplication by a fixed element. + """ + + def __init__(self, domain, multiplier): + r""" + Parameters + ========== + + domain : :py:class:`~.Module` + The domain and codomain of the endomorphism. + + multiplier : :py:class:`~.ModuleElement` + The element $a$ defining the mapping as $x \mapsto a x$. + + """ + super().__init__(domain, lambda x: multiplier * x) + self.multiplier = multiplier + + +class EndomorphismRing: + r"""The ring of endomorphisms on a module.""" + + def __init__(self, domain): + """ + Parameters + ========== + + domain : :py:class:`~.Module` + The domain and codomain of the endomorphisms. + + """ + self.domain = domain + + def inner_endomorphism(self, multiplier): + r""" + Form an inner endomorphism belonging to this endomorphism ring. + + Parameters + ========== + + multiplier : :py:class:`~.ModuleElement` + Element $a$ defining the inner endomorphism $x \mapsto a x$. + + Returns + ======= + + :py:class:`~.InnerEndomorphism` + + """ + return InnerEndomorphism(self.domain, multiplier) + + def represent(self, element): + r""" + Represent an element of this endomorphism ring, as a single column + vector. + + Explanation + =========== + + Let $M$ be a module, and $E$ its ring of endomorphisms. Let $N$ be + another module, and consider a homomorphism $\varphi: N \rightarrow E$. + In the event that $\varphi$ is to be represented by a matrix $A$, each + column of $A$ must represent an element of $E$. This is possible when + the elements of $E$ are themselves representable as matrices, by + stacking the columns of such a matrix into a single column. + + This method supports calculating such matrices $A$, by representing + an element of this endomorphism ring first as a matrix, and then + stacking that matrix's columns into a single column. + + Examples + ======== + + Note that in these examples we print matrix transposes, to make their + columns easier to inspect. + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> from sympy.polys.numberfields.modules import ModuleHomomorphism + >>> T = Poly(cyclotomic_poly(5)) + >>> M = PowerBasis(T) + >>> E = M.endomorphism_ring() + + Let $\zeta$ be a primitive 5th root of unity, a generator of our field, + and consider the inner endomorphism $\tau$ on the ring of integers, + induced by $\zeta$: + + >>> zeta = M(1) + >>> tau = E.inner_endomorphism(zeta) + >>> tau.matrix().transpose() # doctest: +SKIP + DomainMatrix( + [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [-1, -1, -1, -1]], + (4, 4), ZZ) + + The matrix representation of $\tau$ is as expected. The first column + shows that multiplying by $\zeta$ carries $1$ to $\zeta$, the second + column that it carries $\zeta$ to $\zeta^2$, and so forth. + + The ``represent`` method of the endomorphism ring ``E`` stacks these + into a single column: + + >>> E.represent(tau).transpose() # doctest: +SKIP + DomainMatrix( + [[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1]], + (1, 16), ZZ) + + This is useful when we want to consider a homomorphism $\varphi$ having + ``E`` as codomain: + + >>> phi = ModuleHomomorphism(M, E, lambda x: E.inner_endomorphism(x)) + + and we want to compute the matrix of such a homomorphism: + + >>> phi.matrix().transpose() # doctest: +SKIP + DomainMatrix( + [[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1], + [0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1, 1, 0, 0, 0], + [0, 0, 0, 1, -1, -1, -1, -1, 1, 0, 0, 0, 0, 1, 0, 0]], + (4, 16), ZZ) + + Note that the stacked matrix of $\tau$ occurs as the second column in + this example. This is because $\zeta$ is the second basis element of + ``M``, and $\varphi(\zeta) = \tau$. + + Parameters + ========== + + element : :py:class:`~.ModuleEndomorphism` belonging to this ring. + + Returns + ======= + + :py:class:`~.DomainMatrix` + Column vector equalling the vertical stacking of all the columns + of the matrix that represents the given *element* as a mapping. + + """ + if isinstance(element, ModuleEndomorphism) and element.domain == self.domain: + M = element.matrix() + # Transform the matrix into a single column, which should reproduce + # the original columns, one after another. + m, n = M.shape + if n == 0: + return M + return M[:, 0].vstack(*[M[:, j] for j in range(1, n)]) + raise NotImplementedError + + +def find_min_poly(alpha, domain, x=None, powers=None): + r""" + Find a polynomial of least degree (not necessarily irreducible) satisfied + by an element of a finitely-generated ring with unity. + + Examples + ======== + + For the $n$th cyclotomic field, $n$ an odd prime, consider the quadratic + equation whose roots are the two periods of length $(n-1)/2$. Article 356 + of Gauss tells us that we should get $x^2 + x - (n-1)/4$ or + $x^2 + x + (n+1)/4$ according to whether $n$ is 1 or 3 mod 4, respectively. + + >>> from sympy import Poly, cyclotomic_poly, primitive_root, QQ + >>> from sympy.abc import x + >>> from sympy.polys.numberfields.modules import PowerBasis, find_min_poly + >>> n = 13 + >>> g = primitive_root(n) + >>> C = PowerBasis(Poly(cyclotomic_poly(n, x))) + >>> ee = [g**(2*k+1) % n for k in range((n-1)//2)] + >>> eta = sum(C(e) for e in ee) + >>> print(find_min_poly(eta, QQ, x=x).as_expr()) + x**2 + x - 3 + >>> n = 19 + >>> g = primitive_root(n) + >>> C = PowerBasis(Poly(cyclotomic_poly(n, x))) + >>> ee = [g**(2*k+2) % n for k in range((n-1)//2)] + >>> eta = sum(C(e) for e in ee) + >>> print(find_min_poly(eta, QQ, x=x).as_expr()) + x**2 + x + 5 + + Parameters + ========== + + alpha : :py:class:`~.ModuleElement` + The element whose min poly is to be found, and whose module has + multiplication and starts with unity. + + domain : :py:class:`~.Domain` + The desired domain of the polynomial. + + x : :py:class:`~.Symbol`, optional + The desired variable for the polynomial. + + powers : list, optional + If desired, pass an empty list. The powers of *alpha* (as + :py:class:`~.ModuleElement` instances) from the zeroth up to the degree + of the min poly will be recorded here, as we compute them. + + Returns + ======= + + :py:class:`~.Poly`, ``None`` + The minimal polynomial for alpha, or ``None`` if no polynomial could be + found over the desired domain. + + Raises + ====== + + MissingUnityError + If the module to which alpha belongs does not start with unity. + ClosureFailure + If the module to which alpha belongs is not closed under + multiplication. + + """ + R = alpha.module + if not R.starts_with_unity(): + raise MissingUnityError("alpha must belong to finitely generated ring with unity.") + if powers is None: + powers = [] + one = R(0) + powers.append(one) + powers_matrix = one.column(domain=domain) + ak = alpha + m = None + for k in range(1, R.n + 1): + powers.append(ak) + ak_col = ak.column(domain=domain) + try: + X = powers_matrix._solve(ak_col)[0] + except DMBadInputError: + # This means alpha^k still isn't in the domain-span of the lower powers. + powers_matrix = powers_matrix.hstack(ak_col) + ak *= alpha + else: + # alpha^k is in the domain-span of the lower powers, so we have found a + # minimal-degree poly for alpha. + coeffs = [1] + [-c for c in reversed(X.to_list_flat())] + x = x or Dummy('x') + if domain.is_FF: + m = Poly(coeffs, x, modulus=domain.mod) + else: + m = Poly(coeffs, x, domain=domain) + break + return m diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/primes.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/primes.py new file mode 100644 index 0000000000000000000000000000000000000000..8f28f13d94f33ed59cded8eabd05e9cf7d0f103f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/primes.py @@ -0,0 +1,784 @@ +"""Prime ideals in number fields. """ + +from sympy.polys.polytools import Poly +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.integerring import ZZ +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polyutils import IntegerPowerable +from sympy.utilities.decorator import public +from .basis import round_two, nilradical_mod_p +from .exceptions import StructureError +from .modules import ModuleEndomorphism, find_min_poly +from .utilities import coeff_search, supplement_a_subspace + + +def _check_formal_conditions_for_maximal_order(submodule): + r""" + Several functions in this module accept an argument which is to be a + :py:class:`~.Submodule` representing the maximal order in a number field, + such as returned by the :py:func:`~sympy.polys.numberfields.basis.round_two` + algorithm. + + We do not attempt to check that the given ``Submodule`` actually represents + a maximal order, but we do check a basic set of formal conditions that the + ``Submodule`` must satisfy, at a minimum. The purpose is to catch an + obviously ill-formed argument. + """ + prefix = 'The submodule representing the maximal order should ' + cond = None + if not submodule.is_power_basis_submodule(): + cond = 'be a direct submodule of a power basis.' + elif not submodule.starts_with_unity(): + cond = 'have 1 as its first generator.' + elif not submodule.is_sq_maxrank_HNF(): + cond = 'have square matrix, of maximal rank, in Hermite Normal Form.' + if cond is not None: + raise StructureError(prefix + cond) + + +class PrimeIdeal(IntegerPowerable): + r""" + A prime ideal in a ring of algebraic integers. + """ + + def __init__(self, ZK, p, alpha, f, e=None): + """ + Parameters + ========== + + ZK : :py:class:`~.Submodule` + The maximal order where this ideal lives. + p : int + The rational prime this ideal divides. + alpha : :py:class:`~.PowerBasisElement` + Such that the ideal is equal to ``p*ZK + alpha*ZK``. + f : int + The inertia degree. + e : int, ``None``, optional + The ramification index, if already known. If ``None``, we will + compute it here. + + """ + _check_formal_conditions_for_maximal_order(ZK) + self.ZK = ZK + self.p = p + self.alpha = alpha + self.f = f + self._test_factor = None + self.e = e if e is not None else self.valuation(p * ZK) + + def __str__(self): + if self.is_inert: + return f'({self.p})' + return f'({self.p}, {self.alpha.as_expr()})' + + @property + def is_inert(self): + """ + Say whether the rational prime we divide is inert, i.e. stays prime in + our ring of integers. + """ + return self.f == self.ZK.n + + def repr(self, field_gen=None, just_gens=False): + """ + Print a representation of this prime ideal. + + Examples + ======== + + >>> from sympy import cyclotomic_poly, QQ + >>> from sympy.abc import x, zeta + >>> T = cyclotomic_poly(7, x) + >>> K = QQ.algebraic_field((T, zeta)) + >>> P = K.primes_above(11) + >>> print(P[0].repr()) + [ (11, x**3 + 5*x**2 + 4*x - 1) e=1, f=3 ] + >>> print(P[0].repr(field_gen=zeta)) + [ (11, zeta**3 + 5*zeta**2 + 4*zeta - 1) e=1, f=3 ] + >>> print(P[0].repr(field_gen=zeta, just_gens=True)) + (11, zeta**3 + 5*zeta**2 + 4*zeta - 1) + + Parameters + ========== + + field_gen : :py:class:`~.Symbol`, ``None``, optional (default=None) + The symbol to use for the generator of the field. This will appear + in our representation of ``self.alpha``. If ``None``, we use the + variable of the defining polynomial of ``self.ZK``. + just_gens : bool, optional (default=False) + If ``True``, just print the "(p, alpha)" part, showing "just the + generators" of the prime ideal. Otherwise, print a string of the + form "[ (p, alpha) e=..., f=... ]", giving the ramification index + and inertia degree, along with the generators. + + """ + field_gen = field_gen or self.ZK.parent.T.gen + p, alpha, e, f = self.p, self.alpha, self.e, self.f + alpha_rep = str(alpha.numerator(x=field_gen).as_expr()) + if alpha.denom > 1: + alpha_rep = f'({alpha_rep})/{alpha.denom}' + gens = f'({p}, {alpha_rep})' + if just_gens: + return gens + return f'[ {gens} e={e}, f={f} ]' + + def __repr__(self): + return self.repr() + + def as_submodule(self): + r""" + Represent this prime ideal as a :py:class:`~.Submodule`. + + Explanation + =========== + + The :py:class:`~.PrimeIdeal` class serves to bundle information about + a prime ideal, such as its inertia degree, ramification index, and + two-generator representation, as well as to offer helpful methods like + :py:meth:`~.PrimeIdeal.valuation` and + :py:meth:`~.PrimeIdeal.test_factor`. + + However, in order to be added and multiplied by other ideals or + rational numbers, it must first be converted into a + :py:class:`~.Submodule`, which is a class that supports these + operations. + + In many cases, the user need not perform this conversion deliberately, + since it is automatically performed by the arithmetic operator methods + :py:meth:`~.PrimeIdeal.__add__` and :py:meth:`~.PrimeIdeal.__mul__`. + + Raising a :py:class:`~.PrimeIdeal` to a non-negative integer power is + also supported. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly, prime_decomp + >>> T = Poly(cyclotomic_poly(7)) + >>> P0 = prime_decomp(7, T)[0] + >>> print(P0**6 == 7*P0.ZK) + True + + Note that, on both sides of the equation above, we had a + :py:class:`~.Submodule`. In the next equation we recall that adding + ideals yields their GCD. This time, we need a deliberate conversion + to :py:class:`~.Submodule` on the right: + + >>> print(P0 + 7*P0.ZK == P0.as_submodule()) + True + + Returns + ======= + + :py:class:`~.Submodule` + Will be equal to ``self.p * self.ZK + self.alpha * self.ZK``. + + See Also + ======== + + __add__ + __mul__ + + """ + M = self.p * self.ZK + self.alpha * self.ZK + # Pre-set expensive boolean properties whose value we already know: + M._starts_with_unity = False + M._is_sq_maxrank_HNF = True + return M + + def __eq__(self, other): + if isinstance(other, PrimeIdeal): + return self.as_submodule() == other.as_submodule() + return NotImplemented + + def __add__(self, other): + """ + Convert to a :py:class:`~.Submodule` and add to another + :py:class:`~.Submodule`. + + See Also + ======== + + as_submodule + + """ + return self.as_submodule() + other + + __radd__ = __add__ + + def __mul__(self, other): + """ + Convert to a :py:class:`~.Submodule` and multiply by another + :py:class:`~.Submodule` or a rational number. + + See Also + ======== + + as_submodule + + """ + return self.as_submodule() * other + + __rmul__ = __mul__ + + def _zeroth_power(self): + return self.ZK + + def _first_power(self): + return self + + def test_factor(self): + r""" + Compute a test factor for this prime ideal. + + Explanation + =========== + + Write $\mathfrak{p}$ for this prime ideal, $p$ for the rational prime + it divides. Then, for computing $\mathfrak{p}$-adic valuations it is + useful to have a number $\beta \in \mathbb{Z}_K$ such that + $p/\mathfrak{p} = p \mathbb{Z}_K + \beta \mathbb{Z}_K$. + + Essentially, this is the same as the number $\Psi$ (or the "reagent") + from Kummer's 1847 paper (*Ueber die Zerlegung...*, Crelle vol. 35) in + which ideal divisors were invented. + """ + if self._test_factor is None: + self._test_factor = _compute_test_factor(self.p, [self.alpha], self.ZK) + return self._test_factor + + def valuation(self, I): + r""" + Compute the $\mathfrak{p}$-adic valuation of integral ideal I at this + prime ideal. + + Parameters + ========== + + I : :py:class:`~.Submodule` + + See Also + ======== + + prime_valuation + + """ + return prime_valuation(I, self) + + def reduce_element(self, elt): + """ + Reduce a :py:class:`~.PowerBasisElement` to a "small representative" + modulo this prime ideal. + + Parameters + ========== + + elt : :py:class:`~.PowerBasisElement` + The element to be reduced. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + The reduced element. + + See Also + ======== + + reduce_ANP + reduce_alg_num + .Submodule.reduce_element + + """ + return self.as_submodule().reduce_element(elt) + + def reduce_ANP(self, a): + """ + Reduce an :py:class:`~.ANP` to a "small representative" modulo this + prime ideal. + + Parameters + ========== + + elt : :py:class:`~.ANP` + The element to be reduced. + + Returns + ======= + + :py:class:`~.ANP` + The reduced element. + + See Also + ======== + + reduce_element + reduce_alg_num + .Submodule.reduce_element + + """ + elt = self.ZK.parent.element_from_ANP(a) + red = self.reduce_element(elt) + return red.to_ANP() + + def reduce_alg_num(self, a): + """ + Reduce an :py:class:`~.AlgebraicNumber` to a "small representative" + modulo this prime ideal. + + Parameters + ========== + + elt : :py:class:`~.AlgebraicNumber` + The element to be reduced. + + Returns + ======= + + :py:class:`~.AlgebraicNumber` + The reduced element. + + See Also + ======== + + reduce_element + reduce_ANP + .Submodule.reduce_element + + """ + elt = self.ZK.parent.element_from_alg_num(a) + red = self.reduce_element(elt) + return a.field_element(list(reversed(red.QQ_col.flat()))) + + +def _compute_test_factor(p, gens, ZK): + r""" + Compute the test factor for a :py:class:`~.PrimeIdeal` $\mathfrak{p}$. + + Parameters + ========== + + p : int + The rational prime $\mathfrak{p}$ divides + + gens : list of :py:class:`PowerBasisElement` + A complete set of generators for $\mathfrak{p}$ over *ZK*, EXCEPT that + an element equivalent to rational *p* can and should be omitted (since + it has no effect except to waste time). + + ZK : :py:class:`~.Submodule` + The maximal order where the prime ideal $\mathfrak{p}$ lives. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Proposition 4.8.15.) + + """ + _check_formal_conditions_for_maximal_order(ZK) + E = ZK.endomorphism_ring() + matrices = [E.inner_endomorphism(g).matrix(modulus=p) for g in gens] + B = DomainMatrix.zeros((0, ZK.n), FF(p)).vstack(*matrices) + # A nonzero element of the nullspace of B will represent a + # lin comb over the omegas which (i) is not a multiple of p + # (since it is nonzero over FF(p)), while (ii) is such that + # its product with each g in gens _is_ a multiple of p (since + # B represents multiplication by these generators). Theory + # predicts that such an element must exist, so nullspace should + # be non-trivial. + x = B.nullspace()[0, :].transpose() + beta = ZK.parent(ZK.matrix * x.convert_to(ZZ), denom=ZK.denom) + return beta + + +@public +def prime_valuation(I, P): + r""" + Compute the *P*-adic valuation for an integral ideal *I*. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.numberfields import prime_valuation + >>> K = QQ.cyclotomic_field(5) + >>> P = K.primes_above(5) + >>> ZK = K.maximal_order() + >>> print(prime_valuation(25*ZK, P[0])) + 8 + + Parameters + ========== + + I : :py:class:`~.Submodule` + An integral ideal whose valuation is desired. + + P : :py:class:`~.PrimeIdeal` + The prime at which to compute the valuation. + + Returns + ======= + + int + + See Also + ======== + + .PrimeIdeal.valuation + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 4.8.17.) + + """ + p, ZK = P.p, P.ZK + n, W, d = ZK.n, ZK.matrix, ZK.denom + + A = W.convert_to(QQ).inv() * I.matrix * d / I.denom + # Although A must have integer entries, given that I is an integral ideal, + # as a DomainMatrix it will still be over QQ, so we convert back: + A = A.convert_to(ZZ) + D = A.det() + if D % p != 0: + return 0 + + beta = P.test_factor() + + f = d ** n // W.det() + need_complete_test = (f % p == 0) + v = 0 + while True: + # Entering the loop, the cols of A represent lin combs of omegas. + # Turn them into lin combs of thetas: + A = W * A + # And then one column at a time... + for j in range(n): + c = ZK.parent(A[:, j], denom=d) + c *= beta + # ...turn back into lin combs of omegas, after multiplying by beta: + c = ZK.represent(c).flat() + for i in range(n): + A[i, j] = c[i] + if A[n - 1, n - 1].element % p != 0: + break + A = A / p + # As noted above, domain converts to QQ even when division goes evenly. + # So must convert back, even when we don't "need_complete_test". + if need_complete_test: + # In this case, having a non-integer entry is actually just our + # halting condition. + try: + A = A.convert_to(ZZ) + except CoercionFailed: + break + else: + # In this case theory says we should not have any non-integer entries. + A = A.convert_to(ZZ) + v += 1 + return v + + +def _two_elt_rep(gens, ZK, p, f=None, Np=None): + r""" + Given a set of *ZK*-generators of a prime ideal, compute a set of just two + *ZK*-generators for the same ideal, one of which is *p* itself. + + Parameters + ========== + + gens : list of :py:class:`PowerBasisElement` + Generators for the prime ideal over *ZK*, the ring of integers of the + field $K$. + + ZK : :py:class:`~.Submodule` + The maximal order in $K$. + + p : int + The rational prime divided by the prime ideal. + + f : int, optional + The inertia degree of the prime ideal, if known. + + Np : int, optional + The norm $p^f$ of the prime ideal, if known. + NOTE: There is no reason to supply both *f* and *Np*. Either one will + save us from having to compute the norm *Np* ourselves. If both are known, + *Np* is preferred since it saves one exponentiation. + + Returns + ======= + + :py:class:`~.PowerBasisElement` representing a single algebraic integer + alpha such that the prime ideal is equal to ``p*ZK + alpha*ZK``. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 4.7.10.) + + """ + _check_formal_conditions_for_maximal_order(ZK) + pb = ZK.parent + T = pb.T + # Detect the special cases in which either (a) all generators are multiples + # of p, or (b) there are no generators (so `all` is vacuously true): + if all((g % p).equiv(0) for g in gens): + return pb.zero() + + if Np is None: + if f is not None: + Np = p**f + else: + Np = abs(pb.submodule_from_gens(gens).matrix.det()) + + omega = ZK.basis_element_pullbacks() + beta = [p*om for om in omega[1:]] # note: we omit omega[0] == 1 + beta += gens + search = coeff_search(len(beta), 1) + for c in search: + alpha = sum(ci*betai for ci, betai in zip(c, beta)) + # Note: It may be tempting to reduce alpha mod p here, to try to work + # with smaller numbers, but must not do that, as it can result in an + # infinite loop! E.g. try factoring 2 in Q(sqrt(-7)). + n = alpha.norm(T) // Np + if n % p != 0: + # Now can reduce alpha mod p. + return alpha % p + + +def _prime_decomp_easy_case(p, ZK): + r""" + Compute the decomposition of rational prime *p* in the ring of integers + *ZK* (given as a :py:class:`~.Submodule`), in the "easy case", i.e. the + case where *p* does not divide the index of $\theta$ in *ZK*, where + $\theta$ is the generator of the ``PowerBasis`` of which *ZK* is a + ``Submodule``. + """ + T = ZK.parent.T + T_bar = Poly(T, modulus=p) + lc, fl = T_bar.factor_list() + if len(fl) == 1 and fl[0][1] == 1: + return [PrimeIdeal(ZK, p, ZK.parent.zero(), ZK.n, 1)] + return [PrimeIdeal(ZK, p, + ZK.parent.element_from_poly(Poly(t, domain=ZZ)), + t.degree(), e) + for t, e in fl] + + +def _prime_decomp_compute_kernel(I, p, ZK): + r""" + Parameters + ========== + + I : :py:class:`~.Module` + An ideal of ``ZK/pZK``. + p : int + The rational prime being factored. + ZK : :py:class:`~.Submodule` + The maximal order. + + Returns + ======= + + Pair ``(N, G)``, where: + + ``N`` is a :py:class:`~.Module` representing the kernel of the map + ``a |--> a**p - a`` on ``(O/pO)/I``, guaranteed to be a module with + unity. + + ``G`` is a :py:class:`~.Module` representing a basis for the separable + algebra ``A = O/I`` (see Cohen). + + """ + W = I.matrix + n, r = W.shape + # Want to take the Fp-basis given by the columns of I, adjoin (1, 0, ..., 0) + # (which we know is not already in there since I is a basis for a prime ideal) + # and then supplement this with additional columns to make an invertible n x n + # matrix. This will then represent a full basis for ZK, whose first r columns + # are pullbacks of the basis for I. + if r == 0: + B = W.eye(n, ZZ) + else: + B = W.hstack(W.eye(n, ZZ)[:, 0]) + if B.shape[1] < n: + B = supplement_a_subspace(B.convert_to(FF(p))).convert_to(ZZ) + + G = ZK.submodule_from_matrix(B) + # Must compute G's multiplication table _before_ discarding the first r + # columns. (See Step 9 in Alg 6.2.9 in Cohen, where the betas are actually + # needed in order to represent each product of gammas. However, once we've + # found the representations, then we can ignore the betas.) + G.compute_mult_tab() + G = G.discard_before(r) + + phi = ModuleEndomorphism(G, lambda x: x**p - x) + N = phi.kernel(modulus=p) + assert N.starts_with_unity() + return N, G + + +def _prime_decomp_maximal_ideal(I, p, ZK): + r""" + We have reached the case where we have a maximal (hence prime) ideal *I*, + which we know because the quotient ``O/I`` is a field. + + Parameters + ========== + + I : :py:class:`~.Module` + An ideal of ``O/pO``. + p : int + The rational prime being factored. + ZK : :py:class:`~.Submodule` + The maximal order. + + Returns + ======= + + :py:class:`~.PrimeIdeal` instance representing this prime + + """ + m, n = I.matrix.shape + f = m - n + G = ZK.matrix * I.matrix + gens = [ZK.parent(G[:, j], denom=ZK.denom) for j in range(G.shape[1])] + alpha = _two_elt_rep(gens, ZK, p, f=f) + return PrimeIdeal(ZK, p, alpha, f) + + +def _prime_decomp_split_ideal(I, p, N, G, ZK): + r""" + Perform the step in the prime decomposition algorithm where we have determined + the quotient ``ZK/I`` is _not_ a field, and we want to perform a non-trivial + factorization of *I* by locating an idempotent element of ``ZK/I``. + """ + assert I.parent == ZK and G.parent is ZK and N.parent is G + # Since ZK/I is not a field, the kernel computed in the previous step contains + # more than just the prime field Fp, and our basis N for the nullspace therefore + # contains at least a second column (which represents an element outside Fp). + # Let alpha be such an element: + alpha = N(1).to_parent() + assert alpha.module is G + + alpha_powers = [] + m = find_min_poly(alpha, FF(p), powers=alpha_powers) + # TODO (future work): + # We don't actually need full factorization, so might use a faster method + # to just break off a single non-constant factor m1? + lc, fl = m.factor_list() + m1 = fl[0][0] + m2 = m.quo(m1) + U, V, g = m1.gcdex(m2) + # Sanity check: theory says m is squarefree, so m1, m2 should be coprime: + assert g == 1 + E = list(reversed(Poly(U * m1, domain=ZZ).rep.to_list())) + eps1 = sum(E[i]*alpha_powers[i] for i in range(len(E))) + eps2 = 1 - eps1 + idemps = [eps1, eps2] + factors = [] + for eps in idemps: + e = eps.to_parent() + assert e.module is ZK + D = I.matrix.convert_to(FF(p)).hstack(*[ + (e * om).column(domain=FF(p)) for om in ZK.basis_elements() + ]) + W = D.columnspace().convert_to(ZZ) + H = ZK.submodule_from_matrix(W) + factors.append(H) + return factors + + +@public +def prime_decomp(p, T=None, ZK=None, dK=None, radical=None): + r""" + Compute the decomposition of rational prime *p* in a number field. + + Explanation + =========== + + Ordinarily this should be accessed through the + :py:meth:`~.AlgebraicField.primes_above` method of an + :py:class:`~.AlgebraicField`. + + Examples + ======== + + >>> from sympy import Poly, QQ + >>> from sympy.abc import x, theta + >>> T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + >>> K = QQ.algebraic_field((T, theta)) + >>> print(K.primes_above(2)) + [[ (2, x**2 + 1) e=1, f=1 ], [ (2, (x**2 + 3*x + 2)/2) e=1, f=1 ], + [ (2, (3*x**2 + 3*x)/2) e=1, f=1 ]] + + Parameters + ========== + + p : int + The rational prime whose decomposition is desired. + + T : :py:class:`~.Poly`, optional + Monic irreducible polynomial defining the number field $K$ in which to + factor. NOTE: at least one of *T* or *ZK* must be provided. + + ZK : :py:class:`~.Submodule`, optional + The maximal order for $K$, if already known. + NOTE: at least one of *T* or *ZK* must be provided. + + dK : int, optional + The discriminant of the field $K$, if already known. + + radical : :py:class:`~.Submodule`, optional + The nilradical mod *p* in the integers of $K$, if already known. + + Returns + ======= + + List of :py:class:`~.PrimeIdeal` instances. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 6.2.9.) + + """ + if T is None and ZK is None: + raise ValueError('At least one of T or ZK must be provided.') + if ZK is not None: + _check_formal_conditions_for_maximal_order(ZK) + if T is None: + T = ZK.parent.T + radicals = {} + if dK is None or ZK is None: + ZK, dK = round_two(T, radicals=radicals) + dT = T.discriminant() + f_squared = dT // dK + if f_squared % p != 0: + return _prime_decomp_easy_case(p, ZK) + radical = radical or radicals.get(p) or nilradical_mod_p(ZK, p) + stack = [radical] + primes = [] + while stack: + I = stack.pop() + N, G = _prime_decomp_compute_kernel(I, p, ZK) + if N.n == 1: + P = _prime_decomp_maximal_ideal(I, p, ZK) + primes.append(P) + else: + I1, I2 = _prime_decomp_split_ideal(I, p, N, G, ZK) + stack.extend([I1, I2]) + return primes diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/resolvent_lookup.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/resolvent_lookup.py new file mode 100644 index 0000000000000000000000000000000000000000..71812c0d7aec6501039eefe4f3602b1916628071 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/resolvent_lookup.py @@ -0,0 +1,456 @@ +"""Lookup table for Galois resolvents for polys of degree 4 through 6. """ +# This table was generated by a call to +# `sympy.polys.numberfields.galois_resolvents.generate_lambda_lookup()`. +# The entire job took 543.23s. +# Of this, Case (6, 1) took 539.03s. +# The final polynomial of Case (6, 1) alone took 455.09s. +resolvent_coeff_lambdas = { + (4, 0): [ + lambda s1, s2, s3, s4: (-2*s1*s2 + 6*s3), + lambda s1, s2, s3, s4: (2*s1**3*s3 + s1**2*s2**2 + s1**2*s4 - 17*s1*s2*s3 + 2*s2**3 - 8*s2*s4 + 24*s3**2), + lambda s1, s2, s3, s4: (-2*s1**5*s4 - 2*s1**4*s2*s3 + 10*s1**3*s2*s4 + 8*s1**3*s3**2 + 10*s1**2*s2**2*s3 - +12*s1**2*s3*s4 - 2*s1*s2**4 - 54*s1*s2*s3**2 + 32*s1*s4**2 + 8*s2**3*s3 - 32*s2*s3*s4 ++ 56*s3**3), + lambda s1, s2, s3, s4: (2*s1**6*s2*s4 + s1**6*s3**2 - 5*s1**5*s3*s4 - 11*s1**4*s2**2*s4 - 13*s1**4*s2*s3**2 ++ 7*s1**4*s4**2 + 3*s1**3*s2**3*s3 + 30*s1**3*s2*s3*s4 + 22*s1**3*s3**3 + 10*s1**2*s2**3*s4 ++ 33*s1**2*s2**2*s3**2 - 72*s1**2*s2*s4**2 - 36*s1**2*s3**2*s4 - 13*s1*s2**4*s3 + +48*s1*s2**2*s3*s4 - 116*s1*s2*s3**3 + 144*s1*s3*s4**2 + s2**6 - 12*s2**4*s4 + 22*s2**3*s3**2 ++ 48*s2**2*s4**2 - 120*s2*s3**2*s4 + 96*s3**4 - 64*s4**3), + lambda s1, s2, s3, s4: (-2*s1**8*s3*s4 - s1**7*s4**2 + 22*s1**6*s2*s3*s4 + 2*s1**6*s3**3 - 2*s1**5*s2**3*s4 +- s1**5*s2**2*s3**2 - 29*s1**5*s3**2*s4 - 60*s1**4*s2**2*s3*s4 - 19*s1**4*s2*s3**3 ++ 38*s1**4*s3*s4**2 + 9*s1**3*s2**4*s4 + 10*s1**3*s2**3*s3**2 + 24*s1**3*s2**2*s4**2 ++ 134*s1**3*s2*s3**2*s4 + 28*s1**3*s3**4 + 16*s1**3*s4**3 - s1**2*s2**5*s3 - 4*s1**2*s2**3*s3*s4 ++ 34*s1**2*s2**2*s3**3 - 288*s1**2*s2*s3*s4**2 - 104*s1**2*s3**3*s4 - 19*s1*s2**4*s3**2 ++ 120*s1*s2**2*s3**2*s4 - 128*s1*s2*s3**4 + 336*s1*s3**2*s4**2 + 2*s2**6*s3 - 24*s2**4*s3*s4 ++ 28*s2**3*s3**3 + 96*s2**2*s3*s4**2 - 176*s2*s3**3*s4 + 96*s3**5 - 128*s3*s4**3), + lambda s1, s2, s3, s4: (s1**10*s4**2 - 11*s1**8*s2*s4**2 - 2*s1**8*s3**2*s4 + s1**7*s2**2*s3*s4 + 15*s1**7*s3*s4**2 ++ 45*s1**6*s2**2*s4**2 + 17*s1**6*s2*s3**2*s4 + s1**6*s3**4 - 5*s1**6*s4**3 - 12*s1**5*s2**3*s3*s4 +- 133*s1**5*s2*s3*s4**2 - 22*s1**5*s3**3*s4 + s1**4*s2**5*s4 - 76*s1**4*s2**3*s4**2 +- 6*s1**4*s2**2*s3**2*s4 - 12*s1**4*s2*s3**4 + 32*s1**4*s2*s4**3 + 128*s1**4*s3**2*s4**2 ++ 29*s1**3*s2**4*s3*s4 + 2*s1**3*s2**3*s3**3 + 344*s1**3*s2**2*s3*s4**2 + 48*s1**3*s2*s3**3*s4 ++ 16*s1**3*s3**5 - 48*s1**3*s3*s4**3 - 4*s1**2*s2**6*s4 + 32*s1**2*s2**4*s4**2 - 134*s1**2*s2**3*s3**2*s4 ++ 36*s1**2*s2**2*s3**4 - 64*s1**2*s2**2*s4**3 - 648*s1**2*s2*s3**2*s4**2 - 48*s1**2*s3**4*s4 ++ 16*s1*s2**5*s3*s4 - 12*s1*s2**4*s3**3 - 128*s1*s2**3*s3*s4**2 + 296*s1*s2**2*s3**3*s4 +- 96*s1*s2*s3**5 + 256*s1*s2*s3*s4**3 + 416*s1*s3**3*s4**2 + s2**6*s3**2 - 28*s2**4*s3**2*s4 ++ 16*s2**3*s3**4 + 176*s2**2*s3**2*s4**2 - 224*s2*s3**4*s4 + 64*s3**6 - 320*s3**2*s4**3) + ], + (4, 1): [ + lambda s1, s2, s3, s4: (-s2), + lambda s1, s2, s3, s4: (s1*s3 - 4*s4), + lambda s1, s2, s3, s4: (-s1**2*s4 + 4*s2*s4 - s3**2) + ], + (5, 1): [ + lambda s1, s2, s3, s4, s5: (-2*s1*s3 + 8*s4), + lambda s1, s2, s3, s4, s5: (-8*s1**3*s5 + 2*s1**2*s2*s4 + s1**2*s3**2 + 30*s1*s2*s5 - 14*s1*s3*s4 - 6*s2**2*s4 ++ 2*s2*s3**2 - 50*s3*s5 + 40*s4**2), + lambda s1, s2, s3, s4, s5: (16*s1**4*s3*s5 - 2*s1**4*s4**2 - 2*s1**3*s2**2*s5 - 2*s1**3*s2*s3*s4 - 44*s1**3*s4*s5 +- 66*s1**2*s2*s3*s5 + 21*s1**2*s2*s4**2 + 6*s1**2*s3**2*s4 - 50*s1**2*s5**2 + 9*s1*s2**3*s5 ++ 5*s1*s2**2*s3*s4 - 2*s1*s2*s3**3 + 190*s1*s2*s4*s5 + 120*s1*s3**2*s5 - 80*s1*s3*s4**2 +- 15*s2**2*s3*s5 - 40*s2**2*s4**2 + 21*s2*s3**2*s4 + 125*s2*s5**2 - 2*s3**4 - 400*s3*s4*s5 ++ 160*s4**3), + lambda s1, s2, s3, s4, s5: (16*s1**6*s5**2 - 8*s1**5*s2*s4*s5 - 8*s1**5*s3**2*s5 + 2*s1**5*s3*s4**2 + 2*s1**4*s2**2*s3*s5 ++ s1**4*s2**2*s4**2 - 120*s1**4*s2*s5**2 + 68*s1**4*s3*s4*s5 - 8*s1**4*s4**3 + 46*s1**3*s2**2*s4*s5 ++ 28*s1**3*s2*s3**2*s5 - 19*s1**3*s2*s3*s4**2 + 250*s1**3*s3*s5**2 - 144*s1**3*s4**2*s5 +- 9*s1**2*s2**3*s3*s5 - 6*s1**2*s2**3*s4**2 + 3*s1**2*s2**2*s3**2*s4 + 225*s1**2*s2**2*s5**2 +- 354*s1**2*s2*s3*s4*s5 + 76*s1**2*s2*s4**3 - 70*s1**2*s3**3*s5 + 41*s1**2*s3**2*s4**2 +- 200*s1**2*s4*s5**2 - 54*s1*s2**3*s4*s5 + 45*s1*s2**2*s3**2*s5 + 30*s1*s2**2*s3*s4**2 +- 19*s1*s2*s3**3*s4 - 875*s1*s2*s3*s5**2 + 640*s1*s2*s4**2*s5 + 2*s1*s3**5 + 630*s1*s3**2*s4*s5 +- 264*s1*s3*s4**3 + 9*s2**4*s4**2 - 6*s2**3*s3**2*s4 + s2**2*s3**4 + 90*s2**2*s3*s4*s5 +- 136*s2**2*s4**3 - 50*s2*s3**3*s5 + 76*s2*s3**2*s4**2 + 500*s2*s4*s5**2 - 8*s3**4*s4 ++ 625*s3**2*s5**2 - 1400*s3*s4**2*s5 + 400*s4**4), + lambda s1, s2, s3, s4, s5: (-32*s1**7*s3*s5**2 + 8*s1**7*s4**2*s5 + 8*s1**6*s2**2*s5**2 + 8*s1**6*s2*s3*s4*s5 +- 2*s1**6*s2*s4**3 + 48*s1**6*s4*s5**2 - 2*s1**5*s2**3*s4*s5 + 264*s1**5*s2*s3*s5**2 +- 94*s1**5*s2*s4**2*s5 - 24*s1**5*s3**2*s4*s5 + 6*s1**5*s3*s4**3 - 56*s1**5*s5**3 +- 66*s1**4*s2**3*s5**2 - 50*s1**4*s2**2*s3*s4*s5 + 19*s1**4*s2**2*s4**3 + 8*s1**4*s2*s3**3*s5 +- 2*s1**4*s2*s3**2*s4**2 - 318*s1**4*s2*s4*s5**2 - 352*s1**4*s3**2*s5**2 + 166*s1**4*s3*s4**2*s5 ++ 3*s1**4*s4**4 + 15*s1**3*s2**4*s4*s5 - 2*s1**3*s2**3*s3**2*s5 - s1**3*s2**3*s3*s4**2 +- 574*s1**3*s2**2*s3*s5**2 + 347*s1**3*s2**2*s4**2*s5 + 194*s1**3*s2*s3**2*s4*s5 - +89*s1**3*s2*s3*s4**3 + 350*s1**3*s2*s5**3 - 8*s1**3*s3**4*s5 + 4*s1**3*s3**3*s4**2 ++ 1090*s1**3*s3*s4*s5**2 - 364*s1**3*s4**3*s5 + 162*s1**2*s2**4*s5**2 + 33*s1**2*s2**3*s3*s4*s5 +- 51*s1**2*s2**3*s4**3 - 32*s1**2*s2**2*s3**3*s5 + 28*s1**2*s2**2*s3**2*s4**2 + 305*s1**2*s2**2*s4*s5**2 +- 2*s1**2*s2*s3**4*s4 + 1340*s1**2*s2*s3**2*s5**2 - 901*s1**2*s2*s3*s4**2*s5 + 76*s1**2*s2*s4**4 +- 234*s1**2*s3**3*s4*s5 + 102*s1**2*s3**2*s4**3 - 750*s1**2*s3*s5**3 - 550*s1**2*s4**2*s5**2 +- 27*s1*s2**5*s4*s5 + 9*s1*s2**4*s3**2*s5 + 3*s1*s2**4*s3*s4**2 - s1*s2**3*s3**3*s4 ++ 180*s1*s2**3*s3*s5**2 - 366*s1*s2**3*s4**2*s5 - 231*s1*s2**2*s3**2*s4*s5 + 212*s1*s2**2*s3*s4**3 +- 375*s1*s2**2*s5**3 + 112*s1*s2*s3**4*s5 - 89*s1*s2*s3**3*s4**2 - 3075*s1*s2*s3*s4*s5**2 ++ 1640*s1*s2*s4**3*s5 + 6*s1*s3**5*s4 - 850*s1*s3**3*s5**2 + 1220*s1*s3**2*s4**2*s5 +- 384*s1*s3*s4**4 + 2500*s1*s4*s5**3 - 108*s2**5*s5**2 + 117*s2**4*s3*s4*s5 + 32*s2**4*s4**3 +- 31*s2**3*s3**3*s5 - 51*s2**3*s3**2*s4**2 + 525*s2**3*s4*s5**2 + 19*s2**2*s3**4*s4 +- 325*s2**2*s3**2*s5**2 + 260*s2**2*s3*s4**2*s5 - 256*s2**2*s4**4 - 2*s2*s3**6 + 105*s2*s3**3*s4*s5 ++ 76*s2*s3**2*s4**3 + 625*s2*s3*s5**3 - 500*s2*s4**2*s5**2 - 58*s3**5*s5 + 3*s3**4*s4**2 ++ 2750*s3**2*s4*s5**2 - 2400*s3*s4**3*s5 + 512*s4**5 - 3125*s5**4), + lambda s1, s2, s3, s4, s5: (16*s1**8*s3**2*s5**2 - 8*s1**8*s3*s4**2*s5 + s1**8*s4**4 - 8*s1**7*s2**2*s3*s5**2 ++ 2*s1**7*s2**2*s4**2*s5 - 48*s1**7*s3*s4*s5**2 + 12*s1**7*s4**3*s5 + s1**6*s2**4*s5**2 ++ 12*s1**6*s2**2*s4*s5**2 - 144*s1**6*s2*s3**2*s5**2 + 88*s1**6*s2*s3*s4**2*s5 - 13*s1**6*s2*s4**4 ++ 56*s1**6*s3*s5**3 + 86*s1**6*s4**2*s5**2 + 72*s1**5*s2**3*s3*s5**2 - 22*s1**5*s2**3*s4**2*s5 +- 4*s1**5*s2**2*s3**2*s4*s5 + s1**5*s2**2*s3*s4**3 - 14*s1**5*s2**2*s5**3 + 304*s1**5*s2*s3*s4*s5**2 +- 148*s1**5*s2*s4**3*s5 + 152*s1**5*s3**3*s5**2 - 54*s1**5*s3**2*s4**2*s5 + 5*s1**5*s3*s4**4 +- 468*s1**5*s4*s5**3 - 9*s1**4*s2**5*s5**2 + s1**4*s2**4*s3*s4*s5 - 76*s1**4*s2**3*s4*s5**2 ++ 370*s1**4*s2**2*s3**2*s5**2 - 287*s1**4*s2**2*s3*s4**2*s5 + 65*s1**4*s2**2*s4**4 +- 28*s1**4*s2*s3**3*s4*s5 + 5*s1**4*s2*s3**2*s4**3 - 200*s1**4*s2*s3*s5**3 - 294*s1**4*s2*s4**2*s5**2 ++ 8*s1**4*s3**5*s5 - 2*s1**4*s3**4*s4**2 - 676*s1**4*s3**2*s4*s5**2 + 180*s1**4*s3*s4**3*s5 ++ 17*s1**4*s4**5 + 625*s1**4*s5**4 - 210*s1**3*s2**4*s3*s5**2 + 76*s1**3*s2**4*s4**2*s5 ++ 43*s1**3*s2**3*s3**2*s4*s5 - 15*s1**3*s2**3*s3*s4**3 + 50*s1**3*s2**3*s5**3 - 6*s1**3*s2**2*s3**4*s5 ++ 2*s1**3*s2**2*s3**3*s4**2 - 397*s1**3*s2**2*s3*s4*s5**2 + 514*s1**3*s2**2*s4**3*s5 +- 700*s1**3*s2*s3**3*s5**2 + 447*s1**3*s2*s3**2*s4**2*s5 - 118*s1**3*s2*s3*s4**4 + +2300*s1**3*s2*s4*s5**3 - 12*s1**3*s3**4*s4*s5 + 6*s1**3*s3**3*s4**3 + 250*s1**3*s3**2*s5**3 ++ 1470*s1**3*s3*s4**2*s5**2 - 276*s1**3*s4**4*s5 + 27*s1**2*s2**6*s5**2 - 9*s1**2*s2**5*s3*s4*s5 ++ s1**2*s2**5*s4**3 + s1**2*s2**4*s3**3*s5 + 141*s1**2*s2**4*s4*s5**2 - 185*s1**2*s2**3*s3**2*s5**2 ++ 168*s1**2*s2**3*s3*s4**2*s5 - 128*s1**2*s2**3*s4**4 + 93*s1**2*s2**2*s3**3*s4*s5 ++ 19*s1**2*s2**2*s3**2*s4**3 - 125*s1**2*s2**2*s3*s5**3 - 610*s1**2*s2**2*s4**2*s5**2 +- 36*s1**2*s2*s3**5*s5 + 5*s1**2*s2*s3**4*s4**2 + 1995*s1**2*s2*s3**2*s4*s5**2 - 1174*s1**2*s2*s3*s4**3*s5 +- 16*s1**2*s2*s4**5 - 3125*s1**2*s2*s5**4 + 375*s1**2*s3**4*s5**2 - 172*s1**2*s3**3*s4**2*s5 ++ 82*s1**2*s3**2*s4**4 - 3500*s1**2*s3*s4*s5**3 - 1450*s1**2*s4**3*s5**2 + 198*s1*s2**5*s3*s5**2 +- 78*s1*s2**5*s4**2*s5 - 95*s1*s2**4*s3**2*s4*s5 + 44*s1*s2**4*s3*s4**3 + 25*s1*s2**3*s3**4*s5 +- 15*s1*s2**3*s3**3*s4**2 + 15*s1*s2**3*s3*s4*s5**2 - 384*s1*s2**3*s4**3*s5 + s1*s2**2*s3**5*s4 ++ 525*s1*s2**2*s3**3*s5**2 - 528*s1*s2**2*s3**2*s4**2*s5 + 384*s1*s2**2*s3*s4**4 - +1750*s1*s2**2*s4*s5**3 - 29*s1*s2*s3**4*s4*s5 - 118*s1*s2*s3**3*s4**3 + 625*s1*s2*s3**2*s5**3 +- 850*s1*s2*s3*s4**2*s5**2 + 1760*s1*s2*s4**4*s5 + 38*s1*s3**6*s5 + 5*s1*s3**5*s4**2 +- 2050*s1*s3**3*s4*s5**2 + 780*s1*s3**2*s4**3*s5 - 192*s1*s3*s4**5 + 3125*s1*s3*s5**4 ++ 7500*s1*s4**2*s5**3 - 27*s2**7*s5**2 + 18*s2**6*s3*s4*s5 - 4*s2**6*s4**3 - 4*s2**5*s3**3*s5 ++ s2**5*s3**2*s4**2 - 99*s2**5*s4*s5**2 - 150*s2**4*s3**2*s5**2 + 196*s2**4*s3*s4**2*s5 ++ 48*s2**4*s4**4 + 12*s2**3*s3**3*s4*s5 - 128*s2**3*s3**2*s4**3 + 1200*s2**3*s4**2*s5**2 +- 12*s2**2*s3**5*s5 + 65*s2**2*s3**4*s4**2 - 725*s2**2*s3**2*s4*s5**2 - 160*s2**2*s3*s4**3*s5 +- 192*s2**2*s4**5 + 3125*s2**2*s5**4 - 13*s2*s3**6*s4 - 125*s2*s3**4*s5**2 + 590*s2*s3**3*s4**2*s5 +- 16*s2*s3**2*s4**4 - 1250*s2*s3*s4*s5**3 - 2000*s2*s4**3*s5**2 + s3**8 - 124*s3**5*s4*s5 ++ 17*s3**4*s4**3 + 3250*s3**2*s4**2*s5**2 - 1600*s3*s4**4*s5 + 256*s4**6 - 9375*s4*s5**4) + ], + (6, 1): [ + lambda s1, s2, s3, s4, s5, s6: (8*s1*s5 - 2*s2*s4 - 18*s6), + lambda s1, s2, s3, s4, s5, s6: (-50*s1**2*s4*s6 + 40*s1**2*s5**2 + 30*s1*s2*s3*s6 - 14*s1*s2*s4*s5 - 6*s1*s3**2*s5 ++ 2*s1*s3*s4**2 - 30*s1*s5*s6 - 8*s2**3*s6 + 2*s2**2*s3*s5 + s2**2*s4**2 + 114*s2*s4*s6 +- 50*s2*s5**2 - 54*s3**2*s6 + 30*s3*s4*s5 - 8*s4**3 - 135*s6**2), + lambda s1, s2, s3, s4, s5, s6: (125*s1**3*s3*s6**2 - 400*s1**3*s4*s5*s6 + 160*s1**3*s5**3 - 50*s1**2*s2**2*s6**2 + +190*s1**2*s2*s3*s5*s6 + 120*s1**2*s2*s4**2*s6 - 80*s1**2*s2*s4*s5**2 - 15*s1**2*s3**2*s4*s6 +- 40*s1**2*s3**2*s5**2 + 21*s1**2*s3*s4**2*s5 - 2*s1**2*s4**4 + 900*s1**2*s4*s6**2 +- 80*s1**2*s5**2*s6 - 44*s1*s2**3*s5*s6 - 66*s1*s2**2*s3*s4*s6 + 21*s1*s2**2*s3*s5**2 ++ 6*s1*s2**2*s4**2*s5 + 9*s1*s2*s3**3*s6 + 5*s1*s2*s3**2*s4*s5 - 2*s1*s2*s3*s4**3 +- 990*s1*s2*s3*s6**2 + 920*s1*s2*s4*s5*s6 - 400*s1*s2*s5**3 - 135*s1*s3**2*s5*s6 - +126*s1*s3*s4**2*s6 + 190*s1*s3*s4*s5**2 - 44*s1*s4**3*s5 - 2070*s1*s5*s6**2 + 16*s2**4*s4*s6 +- 2*s2**4*s5**2 - 2*s2**3*s3**2*s6 - 2*s2**3*s3*s4*s5 + 304*s2**3*s6**2 - 126*s2**2*s3*s5*s6 +- 232*s2**2*s4**2*s6 + 120*s2**2*s4*s5**2 + 198*s2*s3**2*s4*s6 - 15*s2*s3**2*s5**2 +- 66*s2*s3*s4**2*s5 + 16*s2*s4**4 - 1440*s2*s4*s6**2 + 900*s2*s5**2*s6 - 27*s3**4*s6 ++ 9*s3**3*s4*s5 - 2*s3**2*s4**3 + 1350*s3**2*s6**2 - 990*s3*s4*s5*s6 + 125*s3*s5**3 ++ 304*s4**3*s6 - 50*s4**2*s5**2 + 3240*s6**3), + lambda s1, s2, s3, s4, s5, s6: (500*s1**4*s3*s5*s6**2 + 625*s1**4*s4**2*s6**2 - 1400*s1**4*s4*s5**2*s6 + 400*s1**4*s5**4 +- 200*s1**3*s2**2*s5*s6**2 - 875*s1**3*s2*s3*s4*s6**2 + 640*s1**3*s2*s3*s5**2*s6 + +630*s1**3*s2*s4**2*s5*s6 - 264*s1**3*s2*s4*s5**3 + 90*s1**3*s3**2*s4*s5*s6 - 136*s1**3*s3**2*s5**3 +- 50*s1**3*s3*s4**3*s6 + 76*s1**3*s3*s4**2*s5**2 - 1125*s1**3*s3*s6**3 - 8*s1**3*s4**4*s5 ++ 2550*s1**3*s4*s5*s6**2 - 200*s1**3*s5**3*s6 + 250*s1**2*s2**3*s4*s6**2 - 144*s1**2*s2**3*s5**2*s6 ++ 225*s1**2*s2**2*s3**2*s6**2 - 354*s1**2*s2**2*s3*s4*s5*s6 + 76*s1**2*s2**2*s3*s5**3 +- 70*s1**2*s2**2*s4**3*s6 + 41*s1**2*s2**2*s4**2*s5**2 + 450*s1**2*s2**2*s6**3 - 54*s1**2*s2*s3**3*s5*s6 ++ 45*s1**2*s2*s3**2*s4**2*s6 + 30*s1**2*s2*s3**2*s4*s5**2 - 19*s1**2*s2*s3*s4**3*s5 +- 2880*s1**2*s2*s3*s5*s6**2 + 2*s1**2*s2*s4**5 - 3480*s1**2*s2*s4**2*s6**2 + 4692*s1**2*s2*s4*s5**2*s6 +- 1400*s1**2*s2*s5**4 + 9*s1**2*s3**4*s5**2 - 6*s1**2*s3**3*s4**2*s5 + s1**2*s3**2*s4**4 ++ 1485*s1**2*s3**2*s4*s6**2 - 522*s1**2*s3**2*s5**2*s6 - 1257*s1**2*s3*s4**2*s5*s6 ++ 640*s1**2*s3*s4*s5**3 + 218*s1**2*s4**4*s6 - 144*s1**2*s4**3*s5**2 + 1350*s1**2*s4*s6**3 +- 5175*s1**2*s5**2*s6**2 - 120*s1*s2**4*s3*s6**2 + 68*s1*s2**4*s4*s5*s6 - 8*s1*s2**4*s5**3 ++ 46*s1*s2**3*s3**2*s5*s6 + 28*s1*s2**3*s3*s4**2*s6 - 19*s1*s2**3*s3*s4*s5**2 + 868*s1*s2**3*s5*s6**2 +- 9*s1*s2**2*s3**3*s4*s6 - 6*s1*s2**2*s3**3*s5**2 + 3*s1*s2**2*s3**2*s4**2*s5 + 2484*s1*s2**2*s3*s4*s6**2 +- 1257*s1*s2**2*s3*s5**2*s6 - 1356*s1*s2**2*s4**2*s5*s6 + 630*s1*s2**2*s4*s5**3 - +891*s1*s2*s3**3*s6**2 + 882*s1*s2*s3**2*s4*s5*s6 + 90*s1*s2*s3**2*s5**3 + 84*s1*s2*s3*s4**3*s6 +- 354*s1*s2*s3*s4**2*s5**2 + 3240*s1*s2*s3*s6**3 + 68*s1*s2*s4**4*s5 - 4392*s1*s2*s4*s5*s6**2 ++ 2550*s1*s2*s5**3*s6 + 54*s1*s3**4*s5*s6 - 54*s1*s3**3*s4**2*s6 - 54*s1*s3**3*s4*s5**2 ++ 46*s1*s3**2*s4**3*s5 + 2727*s1*s3**2*s5*s6**2 - 8*s1*s3*s4**5 + 756*s1*s3*s4**2*s6**2 +- 2880*s1*s3*s4*s5**2*s6 + 500*s1*s3*s5**4 + 868*s1*s4**3*s5*s6 - 200*s1*s4**2*s5**3 ++ 8100*s1*s5*s6**3 + 16*s2**6*s6**2 - 8*s2**5*s3*s5*s6 - 8*s2**5*s4**2*s6 + 2*s2**5*s4*s5**2 ++ 2*s2**4*s3**2*s4*s6 + s2**4*s3**2*s5**2 - 688*s2**4*s4*s6**2 + 218*s2**4*s5**2*s6 ++ 234*s2**3*s3**2*s6**2 + 84*s2**3*s3*s4*s5*s6 - 50*s2**3*s3*s5**3 + 168*s2**3*s4**3*s6 +- 70*s2**3*s4**2*s5**2 - 1224*s2**3*s6**3 - 54*s2**2*s3**3*s5*s6 - 144*s2**2*s3**2*s4**2*s6 ++ 45*s2**2*s3**2*s4*s5**2 + 28*s2**2*s3*s4**3*s5 + 756*s2**2*s3*s5*s6**2 - 8*s2**2*s4**5 ++ 4320*s2**2*s4**2*s6**2 - 3480*s2**2*s4*s5**2*s6 + 625*s2**2*s5**4 + 27*s2*s3**4*s4*s6 +- 9*s2*s3**3*s4**2*s5 + 2*s2*s3**2*s4**4 - 4752*s2*s3**2*s4*s6**2 + 1485*s2*s3**2*s5**2*s6 ++ 2484*s2*s3*s4**2*s5*s6 - 875*s2*s3*s4*s5**3 - 688*s2*s4**4*s6 + 250*s2*s4**3*s5**2 +- 4536*s2*s4*s6**3 + 1350*s2*s5**2*s6**2 + 972*s3**4*s6**2 - 891*s3**3*s4*s5*s6 + +234*s3**2*s4**3*s6 + 225*s3**2*s4**2*s5**2 - 1944*s3**2*s6**3 - 120*s3*s4**4*s5 + +3240*s3*s4*s5*s6**2 - 1125*s3*s5**3*s6 + 16*s4**6 - 1224*s4**3*s6**2 + 450*s4**2*s5**2*s6), + lambda s1, s2, s3, s4, s5, s6: (-3125*s1**6*s6**4 + 2500*s1**5*s2*s5*s6**3 + 625*s1**5*s3*s4*s6**3 - 500*s1**5*s3*s5**2*s6**2 ++ 2750*s1**5*s4**2*s5*s6**2 - 2400*s1**5*s4*s5**3*s6 + 512*s1**5*s5**5 - 750*s1**4*s2**2*s4*s6**3 +- 550*s1**4*s2**2*s5**2*s6**2 - 375*s1**4*s2*s3**2*s6**3 - 3075*s1**4*s2*s3*s4*s5*s6**2 ++ 1640*s1**4*s2*s3*s5**3*s6 - 850*s1**4*s2*s4**3*s6**2 + 1220*s1**4*s2*s4**2*s5**2*s6 +- 384*s1**4*s2*s4*s5**4 + 22500*s1**4*s2*s6**4 + 525*s1**4*s3**3*s5*s6**2 - 325*s1**4*s3**2*s4**2*s6**2 ++ 260*s1**4*s3**2*s4*s5**2*s6 - 256*s1**4*s3**2*s5**4 + 105*s1**4*s3*s4**3*s5*s6 + +76*s1**4*s3*s4**2*s5**3 + 375*s1**4*s3*s5*s6**3 - 58*s1**4*s4**5*s6 + 3*s1**4*s4**4*s5**2 +- 12750*s1**4*s4**2*s6**3 + 3700*s1**4*s4*s5**2*s6**2 + 640*s1**4*s5**4*s6 + 350*s1**3*s2**3*s3*s6**3 ++ 1090*s1**3*s2**3*s4*s5*s6**2 - 364*s1**3*s2**3*s5**3*s6 + 305*s1**3*s2**2*s3**2*s5*s6**2 ++ 1340*s1**3*s2**2*s3*s4**2*s6**2 - 901*s1**3*s2**2*s3*s4*s5**2*s6 + 76*s1**3*s2**2*s3*s5**4 +- 234*s1**3*s2**2*s4**3*s5*s6 + 102*s1**3*s2**2*s4**2*s5**3 - 16650*s1**3*s2**2*s5*s6**3 ++ 180*s1**3*s2*s3**3*s4*s6**2 - 366*s1**3*s2*s3**3*s5**2*s6 - 231*s1**3*s2*s3**2*s4**2*s5*s6 ++ 212*s1**3*s2*s3**2*s4*s5**3 + 112*s1**3*s2*s3*s4**4*s6 - 89*s1**3*s2*s3*s4**3*s5**2 ++ 10950*s1**3*s2*s3*s4*s6**3 + 1555*s1**3*s2*s3*s5**2*s6**2 + 6*s1**3*s2*s4**5*s5 +- 9540*s1**3*s2*s4**2*s5*s6**2 + 9016*s1**3*s2*s4*s5**3*s6 - 2400*s1**3*s2*s5**5 - +108*s1**3*s3**5*s6**2 + 117*s1**3*s3**4*s4*s5*s6 + 32*s1**3*s3**4*s5**3 - 31*s1**3*s3**3*s4**3*s6 +- 51*s1**3*s3**3*s4**2*s5**2 - 2025*s1**3*s3**3*s6**3 + 19*s1**3*s3**2*s4**4*s5 + +2955*s1**3*s3**2*s4*s5*s6**2 - 1436*s1**3*s3**2*s5**3*s6 - 2*s1**3*s3*s4**6 + 2770*s1**3*s3*s4**3*s6**2 +- 5123*s1**3*s3*s4**2*s5**2*s6 + 1640*s1**3*s3*s4*s5**4 - 40500*s1**3*s3*s6**4 + 914*s1**3*s4**4*s5*s6 +- 364*s1**3*s4**3*s5**3 + 53550*s1**3*s4*s5*s6**3 - 17930*s1**3*s5**3*s6**2 - 56*s1**2*s2**5*s6**3 +- 318*s1**2*s2**4*s3*s5*s6**2 - 352*s1**2*s2**4*s4**2*s6**2 + 166*s1**2*s2**4*s4*s5**2*s6 ++ 3*s1**2*s2**4*s5**4 - 574*s1**2*s2**3*s3**2*s4*s6**2 + 347*s1**2*s2**3*s3**2*s5**2*s6 ++ 194*s1**2*s2**3*s3*s4**2*s5*s6 - 89*s1**2*s2**3*s3*s4*s5**3 - 8*s1**2*s2**3*s4**4*s6 ++ 4*s1**2*s2**3*s4**3*s5**2 + 560*s1**2*s2**3*s4*s6**3 + 3662*s1**2*s2**3*s5**2*s6**2 ++ 162*s1**2*s2**2*s3**4*s6**2 + 33*s1**2*s2**2*s3**3*s4*s5*s6 - 51*s1**2*s2**2*s3**3*s5**3 +- 32*s1**2*s2**2*s3**2*s4**3*s6 + 28*s1**2*s2**2*s3**2*s4**2*s5**2 + 270*s1**2*s2**2*s3**2*s6**3 +- 2*s1**2*s2**2*s3*s4**4*s5 + 4872*s1**2*s2**2*s3*s4*s5*s6**2 - 5123*s1**2*s2**2*s3*s5**3*s6 ++ 2144*s1**2*s2**2*s4**3*s6**2 - 2812*s1**2*s2**2*s4**2*s5**2*s6 + 1220*s1**2*s2**2*s4*s5**4 +- 37800*s1**2*s2**2*s6**4 - 27*s1**2*s2*s3**5*s5*s6 + 9*s1**2*s2*s3**4*s4**2*s6 + +3*s1**2*s2*s3**4*s4*s5**2 - s1**2*s2*s3**3*s4**3*s5 - 3078*s1**2*s2*s3**3*s5*s6**2 +- 4014*s1**2*s2*s3**2*s4**2*s6**2 + 5412*s1**2*s2*s3**2*s4*s5**2*s6 + 260*s1**2*s2*s3**2*s5**4 +- 310*s1**2*s2*s3*s4**3*s5*s6 - 901*s1**2*s2*s3*s4**2*s5**3 - 3780*s1**2*s2*s3*s5*s6**3 ++ 166*s1**2*s2*s4**4*s5**2 + 40320*s1**2*s2*s4**2*s6**3 - 25344*s1**2*s2*s4*s5**2*s6**2 ++ 3700*s1**2*s2*s5**4*s6 + 918*s1**2*s3**4*s4*s6**2 + 27*s1**2*s3**4*s5**2*s6 - 342*s1**2*s3**3*s4**2*s5*s6 +- 366*s1**2*s3**3*s4*s5**3 + 32*s1**2*s3**2*s4**4*s6 + 347*s1**2*s3**2*s4**3*s5**2 +- 4590*s1**2*s3**2*s4*s6**3 + 594*s1**2*s3**2*s5**2*s6**2 - 94*s1**2*s3*s4**5*s5 + +3618*s1**2*s3*s4**2*s5*s6**2 + 1555*s1**2*s3*s4*s5**3*s6 - 500*s1**2*s3*s5**5 + 8*s1**2*s4**7 +- 7192*s1**2*s4**4*s6**2 + 3662*s1**2*s4**3*s5**2*s6 - 550*s1**2*s4**2*s5**4 - 48600*s1**2*s4*s6**4 ++ 1080*s1**2*s5**2*s6**3 + 48*s1*s2**6*s5*s6**2 + 264*s1*s2**5*s3*s4*s6**2 - 94*s1*s2**5*s3*s5**2*s6 +- 24*s1*s2**5*s4**2*s5*s6 + 6*s1*s2**5*s4*s5**3 - 66*s1*s2**4*s3**3*s6**2 - 50*s1*s2**4*s3**2*s4*s5*s6 ++ 19*s1*s2**4*s3**2*s5**3 + 8*s1*s2**4*s3*s4**3*s6 - 2*s1*s2**4*s3*s4**2*s5**2 - 552*s1*s2**4*s3*s6**3 +- 2560*s1*s2**4*s4*s5*s6**2 + 914*s1*s2**4*s5**3*s6 + 15*s1*s2**3*s3**4*s5*s6 - 2*s1*s2**3*s3**3*s4**2*s6 +- s1*s2**3*s3**3*s4*s5**2 + 1602*s1*s2**3*s3**2*s5*s6**2 - 608*s1*s2**3*s3*s4**2*s6**2 +- 310*s1*s2**3*s3*s4*s5**2*s6 + 105*s1*s2**3*s3*s5**4 + 600*s1*s2**3*s4**3*s5*s6 - +234*s1*s2**3*s4**2*s5**3 + 31368*s1*s2**3*s5*s6**3 + 756*s1*s2**2*s3**3*s4*s6**2 - +342*s1*s2**2*s3**3*s5**2*s6 + 216*s1*s2**2*s3**2*s4**2*s5*s6 - 231*s1*s2**2*s3**2*s4*s5**3 +- 192*s1*s2**2*s3*s4**4*s6 + 194*s1*s2**2*s3*s4**3*s5**2 - 39096*s1*s2**2*s3*s4*s6**3 ++ 3618*s1*s2**2*s3*s5**2*s6**2 - 24*s1*s2**2*s4**5*s5 + 9408*s1*s2**2*s4**2*s5*s6**2 +- 9540*s1*s2**2*s4*s5**3*s6 + 2750*s1*s2**2*s5**5 - 162*s1*s2*s3**5*s6**2 - 378*s1*s2*s3**4*s4*s5*s6 ++ 117*s1*s2*s3**4*s5**3 + 150*s1*s2*s3**3*s4**3*s6 + 33*s1*s2*s3**3*s4**2*s5**2 + +10044*s1*s2*s3**3*s6**3 - 50*s1*s2*s3**2*s4**4*s5 - 8640*s1*s2*s3**2*s4*s5*s6**2 + +2955*s1*s2*s3**2*s5**3*s6 + 8*s1*s2*s3*s4**6 + 6144*s1*s2*s3*s4**3*s6**2 + 4872*s1*s2*s3*s4**2*s5**2*s6 +- 3075*s1*s2*s3*s4*s5**4 + 174960*s1*s2*s3*s6**4 - 2560*s1*s2*s4**4*s5*s6 + 1090*s1*s2*s4**3*s5**3 +- 148824*s1*s2*s4*s5*s6**3 + 53550*s1*s2*s5**3*s6**2 + 81*s1*s3**6*s5*s6 - 27*s1*s3**5*s4**2*s6 +- 27*s1*s3**5*s4*s5**2 + 15*s1*s3**4*s4**3*s5 + 2430*s1*s3**4*s5*s6**2 - 2*s1*s3**3*s4**5 +- 2052*s1*s3**3*s4**2*s6**2 - 3078*s1*s3**3*s4*s5**2*s6 + 525*s1*s3**3*s5**4 + 1602*s1*s3**2*s4**3*s5*s6 ++ 305*s1*s3**2*s4**2*s5**3 + 18144*s1*s3**2*s5*s6**3 - 104*s1*s3*s4**5*s6 - 318*s1*s3*s4**4*s5**2 +- 33696*s1*s3*s4**2*s6**3 - 3780*s1*s3*s4*s5**2*s6**2 + 375*s1*s3*s5**4*s6 + 48*s1*s4**6*s5 ++ 31368*s1*s4**3*s5*s6**2 - 16650*s1*s4**2*s5**3*s6 + 2500*s1*s4*s5**5 + 77760*s1*s5*s6**4 +- 32*s2**7*s4*s6**2 + 8*s2**7*s5**2*s6 + 8*s2**6*s3**2*s6**2 + 8*s2**6*s3*s4*s5*s6 +- 2*s2**6*s3*s5**3 + 96*s2**6*s6**3 - 2*s2**5*s3**3*s5*s6 - 104*s2**5*s3*s5*s6**2 ++ 416*s2**5*s4**2*s6**2 - 58*s2**5*s5**4 - 312*s2**4*s3**2*s4*s6**2 + 32*s2**4*s3**2*s5**2*s6 +- 192*s2**4*s3*s4**2*s5*s6 + 112*s2**4*s3*s4*s5**3 - 8*s2**4*s4**3*s5**2 + 4224*s2**4*s4*s6**3 +- 7192*s2**4*s5**2*s6**2 + 54*s2**3*s3**4*s6**2 + 150*s2**3*s3**3*s4*s5*s6 - 31*s2**3*s3**3*s5**3 +- 32*s2**3*s3**2*s4**2*s5**2 - 864*s2**3*s3**2*s6**3 + 8*s2**3*s3*s4**4*s5 + 6144*s2**3*s3*s4*s5*s6**2 ++ 2770*s2**3*s3*s5**3*s6 - 4032*s2**3*s4**3*s6**2 + 2144*s2**3*s4**2*s5**2*s6 - 850*s2**3*s4*s5**4 +- 16416*s2**3*s6**4 - 27*s2**2*s3**5*s5*s6 + 9*s2**2*s3**4*s4*s5**2 - 2*s2**2*s3**3*s4**3*s5 +- 2052*s2**2*s3**3*s5*s6**2 + 2376*s2**2*s3**2*s4**2*s6**2 - 4014*s2**2*s3**2*s4*s5**2*s6 +- 325*s2**2*s3**2*s5**4 - 608*s2**2*s3*s4**3*s5*s6 + 1340*s2**2*s3*s4**2*s5**3 - 33696*s2**2*s3*s5*s6**3 ++ 416*s2**2*s4**5*s6 - 352*s2**2*s4**4*s5**2 - 6048*s2**2*s4**2*s6**3 + 40320*s2**2*s4*s5**2*s6**2 +- 12750*s2**2*s5**4*s6 - 324*s2*s3**4*s4*s6**2 + 918*s2*s3**4*s5**2*s6 + 756*s2*s3**3*s4**2*s5*s6 ++ 180*s2*s3**3*s4*s5**3 - 312*s2*s3**2*s4**4*s6 - 574*s2*s3**2*s4**3*s5**2 + 43416*s2*s3**2*s4*s6**3 +- 4590*s2*s3**2*s5**2*s6**2 + 264*s2*s3*s4**5*s5 - 39096*s2*s3*s4**2*s5*s6**2 + 10950*s2*s3*s4*s5**3*s6 ++ 625*s2*s3*s5**5 - 32*s2*s4**7 + 4224*s2*s4**4*s6**2 + 560*s2*s4**3*s5**2*s6 - 750*s2*s4**2*s5**4 ++ 85536*s2*s4*s6**4 - 48600*s2*s5**2*s6**3 - 162*s3**5*s4*s5*s6 - 108*s3**5*s5**3 ++ 54*s3**4*s4**3*s6 + 162*s3**4*s4**2*s5**2 - 11664*s3**4*s6**3 - 66*s3**3*s4**4*s5 ++ 10044*s3**3*s4*s5*s6**2 - 2025*s3**3*s5**3*s6 + 8*s3**2*s4**6 - 864*s3**2*s4**3*s6**2 ++ 270*s3**2*s4**2*s5**2*s6 - 375*s3**2*s4*s5**4 - 163296*s3**2*s6**4 - 552*s3*s4**4*s5*s6 ++ 350*s3*s4**3*s5**3 + 174960*s3*s4*s5*s6**3 - 40500*s3*s5**3*s6**2 + 96*s4**6*s6 +- 56*s4**5*s5**2 - 16416*s4**3*s6**3 - 37800*s4**2*s5**2*s6**2 + 22500*s4*s5**4*s6 +- 3125*s5**6 - 93312*s6**5), + lambda s1, s2, s3, s4, s5, s6: (-9375*s1**7*s5*s6**4 + 3125*s1**6*s2*s4*s6**4 + 7500*s1**6*s2*s5**2*s6**3 + 3125*s1**6*s3**2*s6**4 +- 1250*s1**6*s3*s4*s5*s6**3 - 2000*s1**6*s3*s5**3*s6**2 + 3250*s1**6*s4**2*s5**2*s6**2 +- 1600*s1**6*s4*s5**4*s6 + 256*s1**6*s5**6 + 40625*s1**6*s6**5 - 3125*s1**5*s2**2*s3*s6**4 +- 3500*s1**5*s2**2*s4*s5*s6**3 - 1450*s1**5*s2**2*s5**3*s6**2 - 1750*s1**5*s2*s3**2*s5*s6**3 ++ 625*s1**5*s2*s3*s4**2*s6**3 - 850*s1**5*s2*s3*s4*s5**2*s6**2 + 1760*s1**5*s2*s3*s5**4*s6 +- 2050*s1**5*s2*s4**3*s5*s6**2 + 780*s1**5*s2*s4**2*s5**3*s6 - 192*s1**5*s2*s4*s5**5 ++ 35000*s1**5*s2*s5*s6**4 + 1200*s1**5*s3**3*s5**2*s6**2 - 725*s1**5*s3**2*s4**2*s5*s6**2 +- 160*s1**5*s3**2*s4*s5**3*s6 - 192*s1**5*s3**2*s5**5 - 125*s1**5*s3*s4**4*s6**2 + +590*s1**5*s3*s4**3*s5**2*s6 - 16*s1**5*s3*s4**2*s5**4 - 20625*s1**5*s3*s4*s6**4 + +17250*s1**5*s3*s5**2*s6**3 - 124*s1**5*s4**5*s5*s6 + 17*s1**5*s4**4*s5**3 - 20250*s1**5*s4**2*s5*s6**3 ++ 1900*s1**5*s4*s5**3*s6**2 + 1344*s1**5*s5**5*s6 + 625*s1**4*s2**4*s6**4 + 2300*s1**4*s2**3*s3*s5*s6**3 ++ 250*s1**4*s2**3*s4**2*s6**3 + 1470*s1**4*s2**3*s4*s5**2*s6**2 - 276*s1**4*s2**3*s5**4*s6 +- 125*s1**4*s2**2*s3**2*s4*s6**3 - 610*s1**4*s2**2*s3**2*s5**2*s6**2 + 1995*s1**4*s2**2*s3*s4**2*s5*s6**2 +- 1174*s1**4*s2**2*s3*s4*s5**3*s6 - 16*s1**4*s2**2*s3*s5**5 + 375*s1**4*s2**2*s4**4*s6**2 +- 172*s1**4*s2**2*s4**3*s5**2*s6 + 82*s1**4*s2**2*s4**2*s5**4 - 7750*s1**4*s2**2*s4*s6**4 +- 46650*s1**4*s2**2*s5**2*s6**3 + 15*s1**4*s2*s3**3*s4*s5*s6**2 - 384*s1**4*s2*s3**3*s5**3*s6 ++ 525*s1**4*s2*s3**2*s4**3*s6**2 - 528*s1**4*s2*s3**2*s4**2*s5**2*s6 + 384*s1**4*s2*s3**2*s4*s5**4 +- 10125*s1**4*s2*s3**2*s6**4 - 29*s1**4*s2*s3*s4**4*s5*s6 - 118*s1**4*s2*s3*s4**3*s5**3 ++ 36700*s1**4*s2*s3*s4*s5*s6**3 + 2410*s1**4*s2*s3*s5**3*s6**2 + 38*s1**4*s2*s4**6*s6 ++ 5*s1**4*s2*s4**5*s5**2 + 5550*s1**4*s2*s4**3*s6**3 - 10040*s1**4*s2*s4**2*s5**2*s6**2 ++ 5800*s1**4*s2*s4*s5**4*s6 - 1600*s1**4*s2*s5**6 - 292500*s1**4*s2*s6**5 - 99*s1**4*s3**5*s5*s6**2 +- 150*s1**4*s3**4*s4**2*s6**2 + 196*s1**4*s3**4*s4*s5**2*s6 + 48*s1**4*s3**4*s5**4 ++ 12*s1**4*s3**3*s4**3*s5*s6 - 128*s1**4*s3**3*s4**2*s5**3 - 6525*s1**4*s3**3*s5*s6**3 +- 12*s1**4*s3**2*s4**5*s6 + 65*s1**4*s3**2*s4**4*s5**2 + 225*s1**4*s3**2*s4**2*s6**3 ++ 80*s1**4*s3**2*s4*s5**2*s6**2 - 13*s1**4*s3*s4**6*s5 + 5145*s1**4*s3*s4**3*s5*s6**2 +- 6746*s1**4*s3*s4**2*s5**3*s6 + 1760*s1**4*s3*s4*s5**5 - 103500*s1**4*s3*s5*s6**4 ++ s1**4*s4**8 + 954*s1**4*s4**5*s6**2 + 449*s1**4*s4**4*s5**2*s6 - 276*s1**4*s4**3*s5**4 ++ 70125*s1**4*s4**2*s6**4 + 58900*s1**4*s4*s5**2*s6**3 - 23310*s1**4*s5**4*s6**2 - +468*s1**3*s2**5*s5*s6**3 - 200*s1**3*s2**4*s3*s4*s6**3 - 294*s1**3*s2**4*s3*s5**2*s6**2 +- 676*s1**3*s2**4*s4**2*s5*s6**2 + 180*s1**3*s2**4*s4*s5**3*s6 + 17*s1**3*s2**4*s5**5 ++ 50*s1**3*s2**3*s3**3*s6**3 - 397*s1**3*s2**3*s3**2*s4*s5*s6**2 + 514*s1**3*s2**3*s3**2*s5**3*s6 +- 700*s1**3*s2**3*s3*s4**3*s6**2 + 447*s1**3*s2**3*s3*s4**2*s5**2*s6 - 118*s1**3*s2**3*s3*s4*s5**4 ++ 11700*s1**3*s2**3*s3*s6**4 - 12*s1**3*s2**3*s4**4*s5*s6 + 6*s1**3*s2**3*s4**3*s5**3 ++ 10360*s1**3*s2**3*s4*s5*s6**3 + 11404*s1**3*s2**3*s5**3*s6**2 + 141*s1**3*s2**2*s3**4*s5*s6**2 +- 185*s1**3*s2**2*s3**3*s4**2*s6**2 + 168*s1**3*s2**2*s3**3*s4*s5**2*s6 - 128*s1**3*s2**2*s3**3*s5**4 ++ 93*s1**3*s2**2*s3**2*s4**3*s5*s6 + 19*s1**3*s2**2*s3**2*s4**2*s5**3 + 5895*s1**3*s2**2*s3**2*s5*s6**3 +- 36*s1**3*s2**2*s3*s4**5*s6 + 5*s1**3*s2**2*s3*s4**4*s5**2 - 12020*s1**3*s2**2*s3*s4**2*s6**3 +- 5698*s1**3*s2**2*s3*s4*s5**2*s6**2 - 6746*s1**3*s2**2*s3*s5**4*s6 + 5064*s1**3*s2**2*s4**3*s5*s6**2 +- 762*s1**3*s2**2*s4**2*s5**3*s6 + 780*s1**3*s2**2*s4*s5**5 + 93900*s1**3*s2**2*s5*s6**4 ++ 198*s1**3*s2*s3**5*s4*s6**2 - 78*s1**3*s2*s3**5*s5**2*s6 - 95*s1**3*s2*s3**4*s4**2*s5*s6 ++ 44*s1**3*s2*s3**4*s4*s5**3 + 25*s1**3*s2*s3**3*s4**4*s6 - 15*s1**3*s2*s3**3*s4**3*s5**2 ++ 1935*s1**3*s2*s3**3*s4*s6**3 - 2808*s1**3*s2*s3**3*s5**2*s6**2 + s1**3*s2*s3**2*s4**5*s5 +- 4844*s1**3*s2*s3**2*s4**2*s5*s6**2 + 8996*s1**3*s2*s3**2*s4*s5**3*s6 - 160*s1**3*s2*s3**2*s5**5 +- 3616*s1**3*s2*s3*s4**4*s6**2 + 500*s1**3*s2*s3*s4**3*s5**2*s6 - 1174*s1**3*s2*s3*s4**2*s5**4 ++ 72900*s1**3*s2*s3*s4*s6**4 - 55665*s1**3*s2*s3*s5**2*s6**3 + 128*s1**3*s2*s4**5*s5*s6 ++ 180*s1**3*s2*s4**4*s5**3 + 16240*s1**3*s2*s4**2*s5*s6**3 - 9330*s1**3*s2*s4*s5**3*s6**2 ++ 1900*s1**3*s2*s5**5*s6 - 27*s1**3*s3**7*s6**2 + 18*s1**3*s3**6*s4*s5*s6 - 4*s1**3*s3**6*s5**3 +- 4*s1**3*s3**5*s4**3*s6 + s1**3*s3**5*s4**2*s5**2 + 54*s1**3*s3**5*s6**3 + 1143*s1**3*s3**4*s4*s5*s6**2 +- 820*s1**3*s3**4*s5**3*s6 + 923*s1**3*s3**3*s4**3*s6**2 + 57*s1**3*s3**3*s4**2*s5**2*s6 +- 384*s1**3*s3**3*s4*s5**4 + 29700*s1**3*s3**3*s6**4 - 547*s1**3*s3**2*s4**4*s5*s6 ++ 514*s1**3*s3**2*s4**3*s5**3 - 10305*s1**3*s3**2*s4*s5*s6**3 - 7405*s1**3*s3**2*s5**3*s6**2 ++ 108*s1**3*s3*s4**6*s6 - 148*s1**3*s3*s4**5*s5**2 - 11360*s1**3*s3*s4**3*s6**3 + +22209*s1**3*s3*s4**2*s5**2*s6**2 + 2410*s1**3*s3*s4*s5**4*s6 - 2000*s1**3*s3*s5**6 ++ 432000*s1**3*s3*s6**5 + 12*s1**3*s4**7*s5 - 22624*s1**3*s4**4*s5*s6**2 + 11404*s1**3*s4**3*s5**3*s6 +- 1450*s1**3*s4**2*s5**5 - 242100*s1**3*s4*s5*s6**4 + 58430*s1**3*s5**3*s6**3 + 56*s1**2*s2**6*s4*s6**3 ++ 86*s1**2*s2**6*s5**2*s6**2 - 14*s1**2*s2**5*s3**2*s6**3 + 304*s1**2*s2**5*s3*s4*s5*s6**2 +- 148*s1**2*s2**5*s3*s5**3*s6 + 152*s1**2*s2**5*s4**3*s6**2 - 54*s1**2*s2**5*s4**2*s5**2*s6 ++ 5*s1**2*s2**5*s4*s5**4 - 2472*s1**2*s2**5*s6**4 - 76*s1**2*s2**4*s3**3*s5*s6**2 ++ 370*s1**2*s2**4*s3**2*s4**2*s6**2 - 287*s1**2*s2**4*s3**2*s4*s5**2*s6 + 65*s1**2*s2**4*s3**2*s5**4 +- 28*s1**2*s2**4*s3*s4**3*s5*s6 + 5*s1**2*s2**4*s3*s4**2*s5**3 - 8092*s1**2*s2**4*s3*s5*s6**3 ++ 8*s1**2*s2**4*s4**5*s6 - 2*s1**2*s2**4*s4**4*s5**2 + 1096*s1**2*s2**4*s4**2*s6**3 +- 5144*s1**2*s2**4*s4*s5**2*s6**2 + 449*s1**2*s2**4*s5**4*s6 - 210*s1**2*s2**3*s3**4*s4*s6**2 ++ 76*s1**2*s2**3*s3**4*s5**2*s6 + 43*s1**2*s2**3*s3**3*s4**2*s5*s6 - 15*s1**2*s2**3*s3**3*s4*s5**3 +- 6*s1**2*s2**3*s3**2*s4**4*s6 + 2*s1**2*s2**3*s3**2*s4**3*s5**2 + 1962*s1**2*s2**3*s3**2*s4*s6**3 ++ 3181*s1**2*s2**3*s3**2*s5**2*s6**2 + 1684*s1**2*s2**3*s3*s4**2*s5*s6**2 + 500*s1**2*s2**3*s3*s4*s5**3*s6 ++ 590*s1**2*s2**3*s3*s5**5 - 168*s1**2*s2**3*s4**4*s6**2 - 494*s1**2*s2**3*s4**3*s5**2*s6 +- 172*s1**2*s2**3*s4**2*s5**4 - 22080*s1**2*s2**3*s4*s6**4 + 58894*s1**2*s2**3*s5**2*s6**3 ++ 27*s1**2*s2**2*s3**6*s6**2 - 9*s1**2*s2**2*s3**5*s4*s5*s6 + s1**2*s2**2*s3**5*s5**3 ++ s1**2*s2**2*s3**4*s4**3*s6 - 486*s1**2*s2**2*s3**4*s6**3 + 1071*s1**2*s2**2*s3**3*s4*s5*s6**2 ++ 57*s1**2*s2**2*s3**3*s5**3*s6 + 2262*s1**2*s2**2*s3**2*s4**3*s6**2 - 2742*s1**2*s2**2*s3**2*s4**2*s5**2*s6 +- 528*s1**2*s2**2*s3**2*s4*s5**4 - 29160*s1**2*s2**2*s3**2*s6**4 + 772*s1**2*s2**2*s3*s4**4*s5*s6 ++ 447*s1**2*s2**2*s3*s4**3*s5**3 - 96732*s1**2*s2**2*s3*s4*s5*s6**3 + 22209*s1**2*s2**2*s3*s5**3*s6**2 +- 160*s1**2*s2**2*s4**6*s6 - 54*s1**2*s2**2*s4**5*s5**2 - 7992*s1**2*s2**2*s4**3*s6**3 ++ 8634*s1**2*s2**2*s4**2*s5**2*s6**2 - 10040*s1**2*s2**2*s4*s5**4*s6 + 3250*s1**2*s2**2*s5**6 ++ 529200*s1**2*s2**2*s6**5 - 351*s1**2*s2*s3**5*s5*s6**2 - 1215*s1**2*s2*s3**4*s4**2*s6**2 +- 360*s1**2*s2*s3**4*s4*s5**2*s6 + 196*s1**2*s2*s3**4*s5**4 + 741*s1**2*s2*s3**3*s4**3*s5*s6 ++ 168*s1**2*s2*s3**3*s4**2*s5**3 + 11718*s1**2*s2*s3**3*s5*s6**3 - 106*s1**2*s2*s3**2*s4**5*s6 +- 287*s1**2*s2*s3**2*s4**4*s5**2 + 22572*s1**2*s2*s3**2*s4**2*s6**3 - 8892*s1**2*s2*s3**2*s4*s5**2*s6**2 ++ 80*s1**2*s2*s3**2*s5**4*s6 + 88*s1**2*s2*s3*s4**6*s5 + 22144*s1**2*s2*s3*s4**3*s5*s6**2 +- 5698*s1**2*s2*s3*s4**2*s5**3*s6 - 850*s1**2*s2*s3*s4*s5**5 + 169560*s1**2*s2*s3*s5*s6**4 +- 8*s1**2*s2*s4**8 + 3032*s1**2*s2*s4**5*s6**2 - 5144*s1**2*s2*s4**4*s5**2*s6 + 1470*s1**2*s2*s4**3*s5**4 +- 249480*s1**2*s2*s4**2*s6**4 - 105390*s1**2*s2*s4*s5**2*s6**3 + 58900*s1**2*s2*s5**4*s6**2 ++ 162*s1**2*s3**6*s4*s6**2 + 216*s1**2*s3**6*s5**2*s6 - 216*s1**2*s3**5*s4**2*s5*s6 +- 78*s1**2*s3**5*s4*s5**3 + 36*s1**2*s3**4*s4**4*s6 + 76*s1**2*s3**4*s4**3*s5**2 - +3564*s1**2*s3**4*s4*s6**3 + 8802*s1**2*s3**4*s5**2*s6**2 - 22*s1**2*s3**3*s4**5*s5 +- 11475*s1**2*s3**3*s4**2*s5*s6**2 - 2808*s1**2*s3**3*s4*s5**3*s6 + 1200*s1**2*s3**3*s5**5 ++ 2*s1**2*s3**2*s4**7 + 222*s1**2*s3**2*s4**4*s6**2 + 3181*s1**2*s3**2*s4**3*s5**2*s6 +- 610*s1**2*s3**2*s4**2*s5**4 - 165240*s1**2*s3**2*s4*s6**4 + 118260*s1**2*s3**2*s5**2*s6**3 ++ 572*s1**2*s3*s4**5*s5*s6 - 294*s1**2*s3*s4**4*s5**3 - 32616*s1**2*s3*s4**2*s5*s6**3 +- 55665*s1**2*s3*s4*s5**3*s6**2 + 17250*s1**2*s3*s5**5*s6 - 232*s1**2*s4**7*s6 + 86*s1**2*s4**6*s5**2 ++ 48408*s1**2*s4**4*s6**3 + 58894*s1**2*s4**3*s5**2*s6**2 - 46650*s1**2*s4**2*s5**4*s6 ++ 7500*s1**2*s4*s5**6 - 129600*s1**2*s4*s6**5 + 41040*s1**2*s5**2*s6**4 - 48*s1*s2**7*s4*s5*s6**2 ++ 12*s1*s2**7*s5**3*s6 + 12*s1*s2**6*s3**2*s5*s6**2 - 144*s1*s2**6*s3*s4**2*s6**2 ++ 88*s1*s2**6*s3*s4*s5**2*s6 - 13*s1*s2**6*s3*s5**4 + 1680*s1*s2**6*s5*s6**3 + 72*s1*s2**5*s3**3*s4*s6**2 +- 22*s1*s2**5*s3**3*s5**2*s6 - 4*s1*s2**5*s3**2*s4**2*s5*s6 + s1*s2**5*s3**2*s4*s5**3 +- 144*s1*s2**5*s3*s4*s6**3 + 572*s1*s2**5*s3*s5**2*s6**2 + 736*s1*s2**5*s4**2*s5*s6**2 ++ 128*s1*s2**5*s4*s5**3*s6 - 124*s1*s2**5*s5**5 - 9*s1*s2**4*s3**5*s6**2 + s1*s2**4*s3**4*s4*s5*s6 ++ 36*s1*s2**4*s3**3*s6**3 - 2028*s1*s2**4*s3**2*s4*s5*s6**2 - 547*s1*s2**4*s3**2*s5**3*s6 +- 480*s1*s2**4*s3*s4**3*s6**2 + 772*s1*s2**4*s3*s4**2*s5**2*s6 - 29*s1*s2**4*s3*s4*s5**4 ++ 6336*s1*s2**4*s3*s6**4 - 12*s1*s2**4*s4**3*s5**3 + 4368*s1*s2**4*s4*s5*s6**3 - 22624*s1*s2**4*s5**3*s6**2 ++ 441*s1*s2**3*s3**4*s5*s6**2 + 336*s1*s2**3*s3**3*s4**2*s6**2 + 741*s1*s2**3*s3**3*s4*s5**2*s6 ++ 12*s1*s2**3*s3**3*s5**4 - 868*s1*s2**3*s3**2*s4**3*s5*s6 + 93*s1*s2**3*s3**2*s4**2*s5**3 ++ 11016*s1*s2**3*s3**2*s5*s6**3 + 176*s1*s2**3*s3*s4**5*s6 - 28*s1*s2**3*s3*s4**4*s5**2 ++ 14784*s1*s2**3*s3*s4**2*s6**3 + 22144*s1*s2**3*s3*s4*s5**2*s6**2 + 5145*s1*s2**3*s3*s5**4*s6 +- 11344*s1*s2**3*s4**3*s5*s6**2 + 5064*s1*s2**3*s4**2*s5**3*s6 - 2050*s1*s2**3*s4*s5**5 +- 346896*s1*s2**3*s5*s6**4 - 54*s1*s2**2*s3**5*s4*s6**2 - 216*s1*s2**2*s3**5*s5**2*s6 ++ 324*s1*s2**2*s3**4*s4**2*s5*s6 - 95*s1*s2**2*s3**4*s4*s5**3 - 80*s1*s2**2*s3**3*s4**4*s6 ++ 43*s1*s2**2*s3**3*s4**3*s5**2 - 12204*s1*s2**2*s3**3*s4*s6**3 - 11475*s1*s2**2*s3**3*s5**2*s6**2 +- 4*s1*s2**2*s3**2*s4**5*s5 - 3888*s1*s2**2*s3**2*s4**2*s5*s6**2 - 4844*s1*s2**2*s3**2*s4*s5**3*s6 +- 725*s1*s2**2*s3**2*s5**5 - 1312*s1*s2**2*s3*s4**4*s6**2 + 1684*s1*s2**2*s3*s4**3*s5**2*s6 ++ 1995*s1*s2**2*s3*s4**2*s5**4 + 139104*s1*s2**2*s3*s4*s6**4 - 32616*s1*s2**2*s3*s5**2*s6**3 ++ 736*s1*s2**2*s4**5*s5*s6 - 676*s1*s2**2*s4**4*s5**3 + 131040*s1*s2**2*s4**2*s5*s6**3 ++ 16240*s1*s2**2*s4*s5**3*s6**2 - 20250*s1*s2**2*s5**5*s6 - 27*s1*s2*s3**6*s4*s5*s6 ++ 18*s1*s2*s3**6*s5**3 + 9*s1*s2*s3**5*s4**3*s6 - 9*s1*s2*s3**5*s4**2*s5**2 + 1944*s1*s2*s3**5*s6**3 ++ s1*s2*s3**4*s4**4*s5 + 6156*s1*s2*s3**4*s4*s5*s6**2 + 1143*s1*s2*s3**4*s5**3*s6 ++ 324*s1*s2*s3**3*s4**3*s6**2 + 1071*s1*s2*s3**3*s4**2*s5**2*s6 + 15*s1*s2*s3**3*s4*s5**4 +- 7776*s1*s2*s3**3*s6**4 - 2028*s1*s2*s3**2*s4**4*s5*s6 - 397*s1*s2*s3**2*s4**3*s5**3 ++ 112860*s1*s2*s3**2*s4*s5*s6**3 - 10305*s1*s2*s3**2*s5**3*s6**2 + 336*s1*s2*s3*s4**6*s6 ++ 304*s1*s2*s3*s4**5*s5**2 - 68976*s1*s2*s3*s4**3*s6**3 - 96732*s1*s2*s3*s4**2*s5**2*s6**2 ++ 36700*s1*s2*s3*s4*s5**4*s6 - 1250*s1*s2*s3*s5**6 - 1477440*s1*s2*s3*s6**5 - 48*s1*s2*s4**7*s5 ++ 4368*s1*s2*s4**4*s5*s6**2 + 10360*s1*s2*s4**3*s5**3*s6 - 3500*s1*s2*s4**2*s5**5 ++ 935280*s1*s2*s4*s5*s6**4 - 242100*s1*s2*s5**3*s6**3 - 972*s1*s3**6*s5*s6**2 - 351*s1*s3**5*s4*s5**2*s6 +- 99*s1*s3**5*s5**4 + 441*s1*s3**4*s4**3*s5*s6 + 141*s1*s3**4*s4**2*s5**3 - 36936*s1*s3**4*s5*s6**3 +- 84*s1*s3**3*s4**5*s6 - 76*s1*s3**3*s4**4*s5**2 + 17496*s1*s3**3*s4**2*s6**3 + 11718*s1*s3**3*s4*s5**2*s6**2 +- 6525*s1*s3**3*s5**4*s6 + 12*s1*s3**2*s4**6*s5 + 11016*s1*s3**2*s4**3*s5*s6**2 + +5895*s1*s3**2*s4**2*s5**3*s6 - 1750*s1*s3**2*s4*s5**5 - 252720*s1*s3**2*s5*s6**4 - +2544*s1*s3*s4**5*s6**2 - 8092*s1*s3*s4**4*s5**2*s6 + 2300*s1*s3*s4**3*s5**4 + 536544*s1*s3*s4**2*s6**4 ++ 169560*s1*s3*s4*s5**2*s6**3 - 103500*s1*s3*s5**4*s6**2 + 1680*s1*s4**6*s5*s6 - 468*s1*s4**5*s5**3 +- 346896*s1*s4**3*s5*s6**3 + 93900*s1*s4**2*s5**3*s6**2 + 35000*s1*s4*s5**5*s6 - 9375*s1*s5**7 ++ 108864*s1*s5*s6**5 + 16*s2**8*s4**2*s6**2 - 8*s2**8*s4*s5**2*s6 + s2**8*s5**4 - +8*s2**7*s3**2*s4*s6**2 + 2*s2**7*s3**2*s5**2*s6 - 96*s2**7*s4*s6**3 - 232*s2**7*s5**2*s6**2 ++ s2**6*s3**4*s6**2 + 24*s2**6*s3**2*s6**3 + 336*s2**6*s3*s4*s5*s6**2 + 108*s2**6*s3*s5**3*s6 +- 32*s2**6*s4**3*s6**2 - 160*s2**6*s4**2*s5**2*s6 + 38*s2**6*s4*s5**4 + 144*s2**6*s6**4 +- 84*s2**5*s3**3*s5*s6**2 + 8*s2**5*s3**2*s4**2*s6**2 - 106*s2**5*s3**2*s4*s5**2*s6 +- 12*s2**5*s3**2*s5**4 + 176*s2**5*s3*s4**3*s5*s6 - 36*s2**5*s3*s4**2*s5**3 - 2544*s2**5*s3*s5*s6**3 +- 32*s2**5*s4**5*s6 + 8*s2**5*s4**4*s5**2 - 3072*s2**5*s4**2*s6**3 + 3032*s2**5*s4*s5**2*s6**2 ++ 954*s2**5*s5**4*s6 + 36*s2**4*s3**4*s5**2*s6 - 80*s2**4*s3**3*s4**2*s5*s6 + 25*s2**4*s3**3*s4*s5**3 ++ 16*s2**4*s3**2*s4**4*s6 - 6*s2**4*s3**2*s4**3*s5**2 + 2520*s2**4*s3**2*s4*s6**3 ++ 222*s2**4*s3**2*s5**2*s6**2 - 1312*s2**4*s3*s4**2*s5*s6**2 - 3616*s2**4*s3*s4*s5**3*s6 +- 125*s2**4*s3*s5**5 + 1296*s2**4*s4**4*s6**2 - 168*s2**4*s4**3*s5**2*s6 + 375*s2**4*s4**2*s5**4 ++ 19296*s2**4*s4*s6**4 + 48408*s2**4*s5**2*s6**3 + 9*s2**3*s3**5*s4*s5*s6 - 4*s2**3*s3**5*s5**3 +- 2*s2**3*s3**4*s4**3*s6 + s2**3*s3**4*s4**2*s5**2 - 432*s2**3*s3**4*s6**3 + 324*s2**3*s3**3*s4*s5*s6**2 ++ 923*s2**3*s3**3*s5**3*s6 - 752*s2**3*s3**2*s4**3*s6**2 + 2262*s2**3*s3**2*s4**2*s5**2*s6 ++ 525*s2**3*s3**2*s4*s5**4 - 9936*s2**3*s3**2*s6**4 - 480*s2**3*s3*s4**4*s5*s6 - 700*s2**3*s3*s4**3*s5**3 +- 68976*s2**3*s3*s4*s5*s6**3 - 11360*s2**3*s3*s5**3*s6**2 - 32*s2**3*s4**6*s6 + 152*s2**3*s4**5*s5**2 ++ 6912*s2**3*s4**3*s6**3 - 7992*s2**3*s4**2*s5**2*s6**2 + 5550*s2**3*s4*s5**4*s6 - +29376*s2**3*s6**5 + 108*s2**2*s3**4*s4**2*s6**2 - 1215*s2**2*s3**4*s4*s5**2*s6 - 150*s2**2*s3**4*s5**4 ++ 336*s2**2*s3**3*s4**3*s5*s6 - 185*s2**2*s3**3*s4**2*s5**3 + 17496*s2**2*s3**3*s5*s6**3 ++ 8*s2**2*s3**2*s4**5*s6 + 370*s2**2*s3**2*s4**4*s5**2 - 864*s2**2*s3**2*s4**2*s6**3 ++ 22572*s2**2*s3**2*s4*s5**2*s6**2 + 225*s2**2*s3**2*s5**4*s6 - 144*s2**2*s3*s4**6*s5 ++ 14784*s2**2*s3*s4**3*s5*s6**2 - 12020*s2**2*s3*s4**2*s5**3*s6 + 625*s2**2*s3*s4*s5**5 ++ 536544*s2**2*s3*s5*s6**4 + 16*s2**2*s4**8 - 3072*s2**2*s4**5*s6**2 + 1096*s2**2*s4**4*s5**2*s6 ++ 250*s2**2*s4**3*s5**4 - 93744*s2**2*s4**2*s6**4 - 249480*s2**2*s4*s5**2*s6**3 + +70125*s2**2*s5**4*s6**2 + 162*s2*s3**6*s5**2*s6 - 54*s2*s3**5*s4**2*s5*s6 + 198*s2*s3**5*s4*s5**3 +- 210*s2*s3**4*s4**3*s5**2 - 3564*s2*s3**4*s5**2*s6**2 + 72*s2*s3**3*s4**5*s5 - 12204*s2*s3**3*s4**2*s5*s6**2 ++ 1935*s2*s3**3*s4*s5**3*s6 - 8*s2*s3**2*s4**7 + 2520*s2*s3**2*s4**4*s6**2 + 1962*s2*s3**2*s4**3*s5**2*s6 +- 125*s2*s3**2*s4**2*s5**4 - 178848*s2*s3**2*s4*s6**4 - 165240*s2*s3**2*s5**2*s6**3 +- 144*s2*s3*s4**5*s5*s6 - 200*s2*s3*s4**4*s5**3 + 139104*s2*s3*s4**2*s5*s6**3 + 72900*s2*s3*s4*s5**3*s6**2 +- 20625*s2*s3*s5**5*s6 - 96*s2*s4**7*s6 + 56*s2*s4**6*s5**2 + 19296*s2*s4**4*s6**3 +- 22080*s2*s4**3*s5**2*s6**2 - 7750*s2*s4**2*s5**4*s6 + 3125*s2*s4*s5**6 + 248832*s2*s4*s6**5 +- 129600*s2*s5**2*s6**4 - 27*s3**7*s5**3 + 27*s3**6*s4**2*s5**2 - 9*s3**5*s4**4*s5 ++ 1944*s3**5*s4*s5*s6**2 + 54*s3**5*s5**3*s6 + s3**4*s4**6 - 432*s3**4*s4**3*s6**2 +- 486*s3**4*s4**2*s5**2*s6 + 46656*s3**4*s6**4 + 36*s3**3*s4**4*s5*s6 + 50*s3**3*s4**3*s5**3 +- 7776*s3**3*s4*s5*s6**3 + 29700*s3**3*s5**3*s6**2 + 24*s3**2*s4**6*s6 - 14*s3**2*s4**5*s5**2 +- 9936*s3**2*s4**3*s6**3 - 29160*s3**2*s4**2*s5**2*s6**2 - 10125*s3**2*s4*s5**4*s6 ++ 3125*s3**2*s5**6 + 1026432*s3**2*s6**5 + 6336*s3*s4**4*s5*s6**2 + 11700*s3*s4**3*s5**3*s6 +- 3125*s3*s4**2*s5**5 - 1477440*s3*s4*s5*s6**4 + 432000*s3*s5**3*s6**3 + 144*s4**6*s6**2 +- 2472*s4**5*s5**2*s6 + 625*s4**4*s5**4 - 29376*s4**3*s6**4 + 529200*s4**2*s5**2*s6**3 +- 292500*s4*s5**4*s6**2 + 40625*s5**6*s6 - 186624*s6**6) + ], + (6, 2): [ + lambda s1, s2, s3, s4, s5, s6: (-s3), + lambda s1, s2, s3, s4, s5, s6: (-s1*s5 + s2*s4 - 9*s6), + lambda s1, s2, s3, s4, s5, s6: (s1*s2*s6 + 2*s1*s3*s5 - s1*s4**2 - s2**2*s5 + 6*s3*s6 + s4*s5), + lambda s1, s2, s3, s4, s5, s6: (s1**2*s4*s6 - s1**2*s5**2 - 3*s1*s2*s3*s6 + s1*s2*s4*s5 + 9*s1*s5*s6 + s2**3*s6 - +9*s2*s4*s6 + s2*s5**2 + 3*s3**2*s6 - 3*s3*s4*s5 + s4**3 + 27*s6**2), + lambda s1, s2, s3, s4, s5, s6: (-2*s1**3*s6**2 + 2*s1**2*s2*s5*s6 + 2*s1**2*s3*s4*s6 - s1**2*s3*s5**2 - s1*s2**2*s4*s6 +- 3*s1*s2*s6**2 - 16*s1*s3*s5*s6 + 4*s1*s4**2*s6 + 2*s1*s4*s5**2 + 4*s2**2*s5*s6 + +s2*s3*s4*s6 + 2*s2*s3*s5**2 - s2*s4**2*s5 - 9*s3*s6**2 - 3*s4*s5*s6 - 2*s5**3), + lambda s1, s2, s3, s4, s5, s6: (s1**3*s3*s6**2 - 3*s1**3*s4*s5*s6 + s1**3*s5**3 - s1**2*s2**2*s6**2 + s1**2*s2*s3*s5*s6 +- 2*s1**2*s4*s6**2 + 6*s1**2*s5**2*s6 + 16*s1*s2*s3*s6**2 - 3*s1*s2*s5**3 - s1*s3**2*s5*s6 +- 2*s1*s3*s4**2*s6 + s1*s3*s4*s5**2 - 30*s1*s5*s6**2 - 4*s2**3*s6**2 - 2*s2**2*s3*s5*s6 ++ s2**2*s4**2*s6 + 18*s2*s4*s6**2 - 2*s2*s5**2*s6 - 15*s3**2*s6**2 + 16*s3*s4*s5*s6 ++ s3*s5**3 - 4*s4**3*s6 - s4**2*s5**2 - 27*s6**3), + lambda s1, s2, s3, s4, s5, s6: (s1**4*s5*s6**2 + 2*s1**3*s2*s4*s6**2 - s1**3*s2*s5**2*s6 - s1**3*s3**2*s6**2 + 9*s1**3*s6**3 +- 14*s1**2*s2*s5*s6**2 - 11*s1**2*s3*s4*s6**2 + 6*s1**2*s3*s5**2*s6 + 3*s1**2*s4**2*s5*s6 +- s1**2*s4*s5**3 + 3*s1*s2**2*s5**2*s6 + 3*s1*s2*s3**2*s6**2 - s1*s2*s3*s4*s5*s6 + +39*s1*s3*s5*s6**2 - 14*s1*s4*s5**2*s6 + s1*s5**4 - 11*s2*s3*s5**2*s6 + 2*s2*s4*s5**3 +- 3*s3**3*s6**2 + 3*s3**2*s4*s5*s6 - s3**2*s5**3 + 9*s5**3*s6), + lambda s1, s2, s3, s4, s5, s6: (-s1**4*s2*s6**3 + s1**4*s3*s5*s6**2 - 4*s1**3*s3*s6**3 + 10*s1**3*s4*s5*s6**2 - 4*s1**3*s5**3*s6 ++ 8*s1**2*s2**2*s6**3 - 8*s1**2*s2*s3*s5*s6**2 - 2*s1**2*s2*s4**2*s6**2 + s1**2*s2*s4*s5**2*s6 ++ s1**2*s3**2*s4*s6**2 - 6*s1**2*s4*s6**3 - 7*s1**2*s5**2*s6**2 - 24*s1*s2*s3*s6**3 +- 4*s1*s2*s4*s5*s6**2 + 10*s1*s2*s5**3*s6 + 8*s1*s3**2*s5*s6**2 + 8*s1*s3*s4**2*s6**2 +- 8*s1*s3*s4*s5**2*s6 + s1*s3*s5**4 + 36*s1*s5*s6**3 + 8*s2**2*s3*s5*s6**2 - 2*s2**2*s4*s5**2*s6 +- 2*s2*s3**2*s4*s6**2 + s2*s3**2*s5**2*s6 - 6*s2*s5**2*s6**2 + 18*s3**2*s6**3 - 24*s3*s4*s5*s6**2 +- 4*s3*s5**3*s6 + 8*s4**2*s5**2*s6 - s4*s5**4), + lambda s1, s2, s3, s4, s5, s6: (-s1**5*s4*s6**3 - 2*s1**4*s5*s6**3 + 3*s1**3*s2*s5**2*s6**2 + 3*s1**3*s3**2*s6**3 +- s1**3*s3*s4*s5*s6**2 - 8*s1**3*s6**4 + 16*s1**2*s2*s5*s6**3 + 8*s1**2*s3*s4*s6**3 +- 6*s1**2*s3*s5**2*s6**2 - 8*s1**2*s4**2*s5*s6**2 + 3*s1**2*s4*s5**3*s6 - 8*s1*s2**2*s5**2*s6**2 +- 8*s1*s2*s3**2*s6**3 + 8*s1*s2*s3*s4*s5*s6**2 - s1*s2*s3*s5**3*s6 - s1*s3**3*s5*s6**2 +- 24*s1*s3*s5*s6**3 + 16*s1*s4*s5**2*s6**2 - 2*s1*s5**4*s6 + 8*s2*s3*s5**2*s6**2 - +s2*s5**5 + 8*s3**3*s6**3 - 8*s3**2*s4*s5*s6**2 + 3*s3**2*s5**3*s6 - 8*s5**3*s6**2), + lambda s1, s2, s3, s4, s5, s6: (s1**6*s6**4 - 4*s1**4*s2*s6**4 - 2*s1**4*s3*s5*s6**3 + s1**4*s4**2*s6**3 + 8*s1**3*s3*s6**4 +- 4*s1**3*s4*s5*s6**3 + 2*s1**3*s5**3*s6**2 + 8*s1**2*s2*s3*s5*s6**3 - 2*s1**2*s2*s4*s5**2*s6**2 +- 2*s1**2*s3**2*s4*s6**3 + s1**2*s3**2*s5**2*s6**2 - 4*s1*s2*s5**3*s6**2 - 12*s1*s3**2*s5*s6**3 ++ 8*s1*s3*s4*s5**2*s6**2 - 2*s1*s3*s5**4*s6 + s2**2*s5**4*s6 - 2*s2*s3**2*s5**2*s6**2 ++ s3**4*s6**3 + 8*s3*s5**3*s6**2 - 4*s4*s5**4*s6 + s5**6) + ], +} diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/subfield.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/subfield.py new file mode 100644 index 0000000000000000000000000000000000000000..c56d0662e4a38b4c0fcaa385c2e0166490354790 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/subfield.py @@ -0,0 +1,516 @@ +r""" +Functions in ``polys.numberfields.subfield`` solve the "Subfield Problem" and +allied problems, for algebraic number fields. + +Following Cohen (see [Cohen93]_ Section 4.5), we can define the main problem as +follows: + +* **Subfield Problem:** + + Given two number fields $\mathbb{Q}(\alpha)$, $\mathbb{Q}(\beta)$ + via the minimal polynomials for their generators $\alpha$ and $\beta$, decide + whether one field is isomorphic to a subfield of the other. + +From a solution to this problem flow solutions to the following problems as +well: + +* **Primitive Element Problem:** + + Given several algebraic numbers + $\alpha_1, \ldots, \alpha_m$, compute a single algebraic number $\theta$ + such that $\mathbb{Q}(\alpha_1, \ldots, \alpha_m) = \mathbb{Q}(\theta)$. + +* **Field Isomorphism Problem:** + + Decide whether two number fields + $\mathbb{Q}(\alpha)$, $\mathbb{Q}(\beta)$ are isomorphic. + +* **Field Membership Problem:** + + Given two algebraic numbers $\alpha$, + $\beta$, decide whether $\alpha \in \mathbb{Q}(\beta)$, and if so write + $\alpha = f(\beta)$ for some $f(x) \in \mathbb{Q}[x]$. +""" + +from sympy.core.add import Add +from sympy.core.numbers import AlgebraicNumber +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify, _sympify +from sympy.ntheory import sieve +from sympy.polys.densetools import dup_eval +from sympy.polys.domains import QQ +from sympy.polys.numberfields.minpoly import _choose_factor, minimal_polynomial +from sympy.polys.polyerrors import IsomorphismFailed +from sympy.polys.polytools import Poly, PurePoly, factor_list +from sympy.utilities import public + +from mpmath import MPContext + + +def is_isomorphism_possible(a, b): + """Necessary but not sufficient test for isomorphism. """ + n = a.minpoly.degree() + m = b.minpoly.degree() + + if m % n != 0: + return False + + if n == m: + return True + + da = a.minpoly.discriminant() + db = b.minpoly.discriminant() + + i, k, half = 1, m//n, db//2 + + while True: + p = sieve[i] + P = p**k + + if P > half: + break + + if ((da % p) % 2) and not (db % P): + return False + + i += 1 + + return True + + +def field_isomorphism_pslq(a, b): + """Construct field isomorphism using PSLQ algorithm. """ + if not a.root.is_real or not b.root.is_real: + raise NotImplementedError("PSLQ doesn't support complex coefficients") + + f = a.minpoly + g = b.minpoly.replace(f.gen) + + n, m, prev = 100, b.minpoly.degree(), None + ctx = MPContext() + + for i in range(1, 5): + A = a.root.evalf(n) + B = b.root.evalf(n) + + basis = [1, B] + [ B**i for i in range(2, m) ] + [-A] + + ctx.dps = n + coeffs = ctx.pslq(basis, maxcoeff=10**10, maxsteps=1000) + + if coeffs is None: + # PSLQ can't find an integer linear combination. Give up. + break + + if coeffs != prev: + prev = coeffs + else: + # Increasing precision didn't produce anything new. Give up. + break + + # We have + # c0 + c1*B + c2*B^2 + ... + cm-1*B^(m-1) - cm*A ~ 0. + # So bring cm*A to the other side, and divide through by cm, + # for an approximate representation of A as a polynomial in B. + # (We know cm != 0 since `b.minpoly` is irreducible.) + coeffs = [S(c)/coeffs[-1] for c in coeffs[:-1]] + + # Throw away leading zeros. + while not coeffs[-1]: + coeffs.pop() + + coeffs = list(reversed(coeffs)) + h = Poly(coeffs, f.gen, domain='QQ') + + # We only have A ~ h(B). We must check whether the relation is exact. + if f.compose(h).rem(g).is_zero: + # Now we know that h(b) is in fact equal to _some conjugate of_ a. + # But from the very precise approximation A ~ h(B) we can assume + # the conjugate is a itself. + return coeffs + else: + n *= 2 + + return None + + +def field_isomorphism_factor(a, b): + """Construct field isomorphism via factorization. """ + _, factors = factor_list(a.minpoly, extension=b) + for f, _ in factors: + if f.degree() == 1: + # Any linear factor f(x) represents some conjugate of a in QQ(b). + # We want to know whether this linear factor represents a itself. + # Let f = x - c + c = -f.rep.TC() + # Write c as polynomial in b + coeffs = c.to_sympy_list() + d, terms = len(coeffs) - 1, [] + for i, coeff in enumerate(coeffs): + terms.append(coeff*b.root**(d - i)) + r = Add(*terms) + # Check whether we got the number a + if a.minpoly.same_root(r, a): + return coeffs + + # If none of the linear factors represented a in QQ(b), then in fact a is + # not an element of QQ(b). + return None + + +@public +def field_isomorphism(a, b, *, fast=True): + r""" + Find an embedding of one number field into another. + + Explanation + =========== + + This function looks for an isomorphism from $\mathbb{Q}(a)$ onto some + subfield of $\mathbb{Q}(b)$. Thus, it solves the Subfield Problem. + + Examples + ======== + + >>> from sympy import sqrt, field_isomorphism, I + >>> print(field_isomorphism(3, sqrt(2))) # doctest: +SKIP + [3] + >>> print(field_isomorphism( I*sqrt(3), I*sqrt(3)/2)) # doctest: +SKIP + [2, 0] + + Parameters + ========== + + a : :py:class:`~.Expr` + Any expression representing an algebraic number. + b : :py:class:`~.Expr` + Any expression representing an algebraic number. + fast : boolean, optional (default=True) + If ``True``, we first attempt a potentially faster way of computing the + isomorphism, falling back on a slower method if this fails. If + ``False``, we go directly to the slower method, which is guaranteed to + return a result. + + Returns + ======= + + List of rational numbers, or None + If $\mathbb{Q}(a)$ is not isomorphic to some subfield of + $\mathbb{Q}(b)$, then return ``None``. Otherwise, return a list of + rational numbers representing an element of $\mathbb{Q}(b)$ to which + $a$ may be mapped, in order to define a monomorphism, i.e. an + isomorphism from $\mathbb{Q}(a)$ to some subfield of $\mathbb{Q}(b)$. + The elements of the list are the coefficients of falling powers of $b$. + + """ + a, b = sympify(a), sympify(b) + + if not a.is_AlgebraicNumber: + a = AlgebraicNumber(a) + + if not b.is_AlgebraicNumber: + b = AlgebraicNumber(b) + + a = a.to_primitive_element() + b = b.to_primitive_element() + + if a == b: + return a.coeffs() + + n = a.minpoly.degree() + m = b.minpoly.degree() + + if n == 1: + return [a.root] + + if m % n != 0: + return None + + if fast: + try: + result = field_isomorphism_pslq(a, b) + + if result is not None: + return result + except NotImplementedError: + pass + + return field_isomorphism_factor(a, b) + + +def _switch_domain(g, K): + # An algebraic relation f(a, b) = 0 over Q can also be written + # g(b) = 0 where g is in Q(a)[x] and h(a) = 0 where h is in Q(b)[x]. + # This function transforms g into h where Q(b) = K. + frep = g.rep.inject() + hrep = frep.eject(K, front=True) + + return g.new(hrep, g.gens[0]) + + +def _linsolve(p): + # Compute root of linear polynomial. + c, d = p.rep.to_list() + return -d/c + + +@public +def primitive_element(extension, x=None, *, ex=False, polys=False): + r""" + Find a single generator for a number field given by several generators. + + Explanation + =========== + + The basic problem is this: Given several algebraic numbers + $\alpha_1, \alpha_2, \ldots, \alpha_n$, find a single algebraic number + $\theta$ such that + $\mathbb{Q}(\alpha_1, \alpha_2, \ldots, \alpha_n) = \mathbb{Q}(\theta)$. + + This function actually guarantees that $\theta$ will be a linear + combination of the $\alpha_i$, with non-negative integer coefficients. + + Furthermore, if desired, this function will tell you how to express each + $\alpha_i$ as a $\mathbb{Q}$-linear combination of the powers of $\theta$. + + Examples + ======== + + >>> from sympy import primitive_element, sqrt, S, minpoly, simplify + >>> from sympy.abc import x + >>> f, lincomb, reps = primitive_element([sqrt(2), sqrt(3)], x, ex=True) + + Then ``lincomb`` tells us the primitive element as a linear combination of + the given generators ``sqrt(2)`` and ``sqrt(3)``. + + >>> print(lincomb) + [1, 1] + + This means the primtiive element is $\sqrt{2} + \sqrt{3}$. + Meanwhile ``f`` is the minimal polynomial for this primitive element. + + >>> print(f) + x**4 - 10*x**2 + 1 + >>> print(minpoly(sqrt(2) + sqrt(3), x)) + x**4 - 10*x**2 + 1 + + Finally, ``reps`` (which was returned only because we set keyword arg + ``ex=True``) tells us how to recover each of the generators $\sqrt{2}$ and + $\sqrt{3}$ as $\mathbb{Q}$-linear combinations of the powers of the + primitive element $\sqrt{2} + \sqrt{3}$. + + >>> print([S(r) for r in reps[0]]) + [1/2, 0, -9/2, 0] + >>> theta = sqrt(2) + sqrt(3) + >>> print(simplify(theta**3/2 - 9*theta/2)) + sqrt(2) + >>> print([S(r) for r in reps[1]]) + [-1/2, 0, 11/2, 0] + >>> print(simplify(-theta**3/2 + 11*theta/2)) + sqrt(3) + + Parameters + ========== + + extension : list of :py:class:`~.Expr` + Each expression must represent an algebraic number $\alpha_i$. + x : :py:class:`~.Symbol`, optional (default=None) + The desired symbol to appear in the computed minimal polynomial for the + primitive element $\theta$. If ``None``, we use a dummy symbol. + ex : boolean, optional (default=False) + If and only if ``True``, compute the representation of each $\alpha_i$ + as a $\mathbb{Q}$-linear combination over the powers of $\theta$. + polys : boolean, optional (default=False) + If ``True``, return the minimal polynomial as a :py:class:`~.Poly`. + Otherwise return it as an :py:class:`~.Expr`. + + Returns + ======= + + Pair (f, coeffs) or triple (f, coeffs, reps), where: + ``f`` is the minimal polynomial for the primitive element. + ``coeffs`` gives the primitive element as a linear combination of the + given generators. + ``reps`` is present if and only if argument ``ex=True`` was passed, + and is a list of lists of rational numbers. Each list gives the + coefficients of falling powers of the primitive element, to recover + one of the original, given generators. + + """ + if not extension: + raise ValueError("Cannot compute primitive element for empty extension") + extension = [_sympify(ext) for ext in extension] + + if x is not None: + x, cls = sympify(x), Poly + else: + x, cls = Dummy('x'), PurePoly + + def _canonicalize(f): + _, f = f.primitive() + if f.LC() < 0: + f = -f + return f + + if not ex: + gen, coeffs = extension[0], [1] + g = minimal_polynomial(gen, x, polys=True) + for ext in extension[1:]: + if ext.is_Rational: + coeffs.append(0) + continue + _, factors = factor_list(g, extension=ext) + g = _choose_factor(factors, x, gen) + [s], _, g = g.sqf_norm() + gen += s*ext + coeffs.append(s) + + g = _canonicalize(g) + if not polys: + return g.as_expr(), coeffs + else: + return cls(g), coeffs + + gen, coeffs = extension[0], [1] + f = minimal_polynomial(gen, x, polys=True) + K = QQ.algebraic_field((f, gen)) # incrementally constructed field + reps = [K.unit] # representations of extension elements in K + for ext in extension[1:]: + if ext.is_Rational: + coeffs.append(0) # rational ext is not included in the expression of a primitive element + reps.append(K.convert(ext)) # but it is included in reps + continue + p = minimal_polynomial(ext, x, polys=True) + L = QQ.algebraic_field((p, ext)) + _, factors = factor_list(f, domain=L) + f = _choose_factor(factors, x, gen) + [s], g, f = f.sqf_norm() + gen += s*ext + coeffs.append(s) + K = QQ.algebraic_field((f, gen)) + h = _switch_domain(g, K) + erep = _linsolve(h.gcd(p)) # ext as element of K + ogen = K.unit - s*erep # old gen as element of K + reps = [dup_eval(_.to_list(), ogen, K) for _ in reps] + [erep] + + if K.ext.root.is_Rational: # all extensions are rational + H = [K.convert(_).rep for _ in extension] + coeffs = [0]*len(extension) + f = cls(x, domain=QQ) + else: + H = [_.to_list() for _ in reps] + + f = _canonicalize(f) + if not polys: + return f.as_expr(), coeffs, H + else: + return f, coeffs, H + + +@public +def to_number_field(extension, theta=None, *, gen=None, alias=None): + r""" + Express one algebraic number in the field generated by another. + + Explanation + =========== + + Given two algebraic numbers $\eta, \theta$, this function either expresses + $\eta$ as an element of $\mathbb{Q}(\theta)$, or else raises an exception + if $\eta \not\in \mathbb{Q}(\theta)$. + + This function is essentially just a convenience, utilizing + :py:func:`~.field_isomorphism` (our solution of the Subfield Problem) to + solve this, the Field Membership Problem. + + As an additional convenience, this function allows you to pass a list of + algebraic numbers $\alpha_1, \alpha_2, \ldots, \alpha_n$ instead of $\eta$. + It then computes $\eta$ for you, as a solution of the Primitive Element + Problem, using :py:func:`~.primitive_element` on the list of $\alpha_i$. + + Examples + ======== + + >>> from sympy import sqrt, to_number_field + >>> eta = sqrt(2) + >>> theta = sqrt(2) + sqrt(3) + >>> a = to_number_field(eta, theta) + >>> print(type(a)) + + >>> a.root + sqrt(2) + sqrt(3) + >>> print(a) + sqrt(2) + >>> a.coeffs() + [1/2, 0, -9/2, 0] + + We get an :py:class:`~.AlgebraicNumber`, whose ``.root`` is $\theta$, whose + value is $\eta$, and whose ``.coeffs()`` show how to write $\eta$ as a + $\mathbb{Q}$-linear combination in falling powers of $\theta$. + + Parameters + ========== + + extension : :py:class:`~.Expr` or list of :py:class:`~.Expr` + Either the algebraic number that is to be expressed in the other field, + or else a list of algebraic numbers, a primitive element for which is + to be expressed in the other field. + theta : :py:class:`~.Expr`, None, optional (default=None) + If an :py:class:`~.Expr` representing an algebraic number, behavior is + as described under **Explanation**. If ``None``, then this function + reduces to a shorthand for calling :py:func:`~.primitive_element` on + ``extension`` and turning the computed primitive element into an + :py:class:`~.AlgebraicNumber`. + gen : :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the generator symbol for the minimal + polynomial in the returned :py:class:`~.AlgebraicNumber`. + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the returned + :py:class:`~.AlgebraicNumber`. + + Returns + ======= + + AlgebraicNumber + Belonging to $\mathbb{Q}(\theta)$ and equaling $\eta$. + + Raises + ====== + + IsomorphismFailed + If $\eta \not\in \mathbb{Q}(\theta)$. + + See Also + ======== + + field_isomorphism + primitive_element + + """ + if hasattr(extension, '__iter__'): + extension = list(extension) + else: + extension = [extension] + + if len(extension) == 1 and isinstance(extension[0], tuple): + return AlgebraicNumber(extension[0], alias=alias) + + minpoly, coeffs = primitive_element(extension, gen, polys=True) + root = sum(coeff*ext for coeff, ext in zip(coeffs, extension)) + + if theta is None: + return AlgebraicNumber((minpoly, root), alias=alias) + else: + theta = sympify(theta) + + if not theta.is_AlgebraicNumber: + theta = AlgebraicNumber(theta, gen=gen, alias=alias) + + coeffs = field_isomorphism(root, theta) + + if coeffs is not None: + return AlgebraicNumber(theta, coeffs, alias=alias) + else: + raise IsomorphismFailed( + "%s is not in a subfield of %s" % (root, theta.root)) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__init__.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb61e50d4864aec26ae376c199c4448826fe94a0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_basis.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_basis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce39299f3c8649ebd84ac6952cba09ad5964b210 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_basis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_galoisgroups.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_galoisgroups.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3ecc9f7c94a0232db7288506837de03c64ab627 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_galoisgroups.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_minpoly.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_minpoly.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..395ecf44e21c35bf9f05b4aaa8e5891b6b8d8d31 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_minpoly.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_modules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da659800aae4dfd57333c35c51942c189ab65ef Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_numbers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_numbers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8fd4cfc4eedf31737b6aa3d49c58845fe800fd5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_numbers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_primes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_primes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ea0e368d3965010d9a27a7571ccd29b114df471 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_primes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_subfield.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_subfield.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82ae1695a7e3c65d9b141c053c1c382f91c861ff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_subfield.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_utilities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..948bd0506b23cc385c875c81cf773cc20f9b54c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/__pycache__/test_utilities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_basis.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_basis.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ed017936cc5c24da63ac02ceca0480f1945feb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_basis.py @@ -0,0 +1,85 @@ +from sympy.abc import x +from sympy.core import S +from sympy.core.numbers import AlgebraicNumber +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys import Poly, cyclotomic_poly +from sympy.polys.domains import QQ +from sympy.polys.matrices import DomainMatrix, DM +from sympy.polys.numberfields.basis import round_two +from sympy.testing.pytest import raises + + +def test_round_two(): + # Poly must be irreducible, and over ZZ or QQ: + raises(ValueError, lambda: round_two(Poly(x ** 2 - 1))) + raises(ValueError, lambda: round_two(Poly(x ** 2 + sqrt(2)))) + + # Test on many fields: + cases = ( + # A couple of cyclotomic fields: + (cyclotomic_poly(5), DomainMatrix.eye(4, QQ), 125), + (cyclotomic_poly(7), DomainMatrix.eye(6, QQ), -16807), + # A couple of quadratic fields (one 1 mod 4, one 3 mod 4): + (x ** 2 - 5, DM([[1, (1, 2)], [0, (1, 2)]], QQ), 5), + (x ** 2 - 7, DM([[1, 0], [0, 1]], QQ), 28), + # Dedekind's example of a field with 2 as essential disc divisor: + (x ** 3 + x ** 2 - 2 * x + 8, DM([[1, 0, 0], [0, 1, 0], [0, (1, 2), (1, 2)]], QQ).transpose(), -503), + # A bunch of cubics with various forms for F -- all of these require + # second or third enlargements. (Five of them require a third, while the rest require just a second.) + # F = 2^2 + (x**3 + 3 * x**2 - 4 * x + 4, DM([((1, 2), (1, 4), (1, 4)), (0, (1, 2), (1, 2)), (0, 0, 1)], QQ).transpose(), -83), + # F = 2^2 * 3 + (x**3 + 3 * x**2 + 3 * x - 3, DM([((1, 2), 0, (1, 2)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), -108), + # F = 2^3 + (x**3 + 5 * x**2 - x + 3, DM([((1, 4), 0, (3, 4)), (0, (1, 2), (1, 2)), (0, 0, 1)], QQ).transpose(), -31), + # F = 2^2 * 5 + (x**3 + 5 * x**2 - 5 * x - 5, DM([((1, 2), 0, (1, 2)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), 1300), + # F = 3^2 + (x**3 + 3 * x**2 + 5, DM([((1, 3), (1, 3), (1, 3)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), -135), + # F = 3^3 + (x**3 + 6 * x**2 + 3 * x - 1, DM([((1, 3), (1, 3), (1, 3)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), 81), + # F = 2^2 * 3^2 + (x**3 + 6 * x**2 + 4, DM([((1, 3), (2, 3), (1, 3)), (0, 1, 0), (0, 0, (1, 2))], QQ).transpose(), -108), + # F = 2^3 * 7 + (x**3 + 7 * x**2 + 7 * x - 7, DM([((1, 4), 0, (3, 4)), (0, (1, 2), (1, 2)), (0, 0, 1)], QQ).transpose(), 49), + # F = 2^2 * 13 + (x**3 + 7 * x**2 - x + 5, DM([((1, 2), 0, (1, 2)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), -2028), + # F = 2^4 + (x**3 + 7 * x**2 - 5 * x + 5, DM([((1, 4), 0, (3, 4)), (0, (1, 2), (1, 2)), (0, 0, 1)], QQ).transpose(), -140), + # F = 5^2 + (x**3 + 4 * x**2 - 3 * x + 7, DM([((1, 5), (4, 5), (4, 5)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), -175), + # F = 7^2 + (x**3 + 8 * x**2 + 5 * x - 1, DM([((1, 7), (6, 7), (2, 7)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), 49), + # F = 2 * 5 * 7 + (x**3 + 8 * x**2 - 2 * x + 6, DM([(1, 0, 0), (0, 1, 0), (0, 0, 1)], QQ).transpose(), -14700), + # F = 2^2 * 3 * 5 + (x**3 + 6 * x**2 - 3 * x + 8, DM([(1, 0, 0), (0, (1, 4), (1, 4)), (0, 0, 1)], QQ).transpose(), -675), + # F = 2 * 3^2 * 7 + (x**3 + 9 * x**2 + 6 * x - 8, DM([(1, 0, 0), (0, (1, 2), (1, 2)), (0, 0, 1)], QQ).transpose(), 3969), + # F = 2^2 * 3^2 * 7 + (x**3 + 15 * x**2 - 9 * x + 13, DM([((1, 6), (1, 3), (1, 6)), (0, 1, 0), (0, 0, 1)], QQ).transpose(), -5292), + # Polynomial need not be monic + (5*x**3 + 5*x**2 - 10 * x + 40, DM([[1, 0, 0], [0, 1, 0], [0, (1, 2), (1, 2)]], QQ).transpose(), -503), + # Polynomial can have non-integer rational coeffs + (QQ(5, 3)*x**3 + QQ(5, 3)*x**2 - QQ(10, 3)*x + QQ(40, 3), DM([[1, 0, 0], [0, 1, 0], [0, (1, 2), (1, 2)]], QQ).transpose(), -503), + ) + for f, B_exp, d_exp in cases: + K = QQ.alg_field_from_poly(f) + B = K.maximal_order().QQ_matrix + d = K.discriminant() + assert d == d_exp + # The computed basis need not equal the expected one, but their quotient + # must be unimodular: + assert (B.inv()*B_exp).det()**2 == 1 + + +def test_AlgebraicField_integral_basis(): + alpha = AlgebraicNumber(sqrt(5), alias='alpha') + k = QQ.algebraic_field(alpha) + B0 = k.integral_basis() + B1 = k.integral_basis(fmt='sympy') + B2 = k.integral_basis(fmt='alg') + assert B0 == [k([1]), k([S.Half, S.Half])] + assert B1 == [1, S.Half + alpha/2] + assert B2 == [k.ext.field_element([1]), + k.ext.field_element([S.Half, S.Half])] diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_galoisgroups.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_galoisgroups.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cb3d51bcdfad7764b3f6f62dbd2049e466e9e1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_galoisgroups.py @@ -0,0 +1,143 @@ +"""Tests for computing Galois groups. """ + +from sympy.abc import x +from sympy.combinatorics.galois import ( + S1TransitiveSubgroups, S2TransitiveSubgroups, S3TransitiveSubgroups, + S4TransitiveSubgroups, S5TransitiveSubgroups, S6TransitiveSubgroups, +) +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.numberfields.galoisgroups import ( + tschirnhausen_transformation, + galois_group, + _galois_group_degree_4_root_approx, + _galois_group_degree_5_hybrid, +) +from sympy.polys.numberfields.subfield import field_isomorphism +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises + + +def test_tschirnhausen_transformation(): + for T in [ + Poly(x**2 - 2), + Poly(x**2 + x + 1), + Poly(x**4 + 1), + Poly(x**4 - x**3 + x**2 - x + 1), + ]: + _, U = tschirnhausen_transformation(T) + assert U.degree() == T.degree() + assert U.is_monic + assert U.is_irreducible + K = QQ.alg_field_from_poly(T) + L = QQ.alg_field_from_poly(U) + assert field_isomorphism(K.ext, L.ext) is not None + + +# Test polys are from: +# Cohen, H. *A Course in Computational Algebraic Number Theory*. +test_polys_by_deg = { + # Degree 1 + 1: [ + (x, S1TransitiveSubgroups.S1, True) + ], + # Degree 2 + 2: [ + (x**2 + x + 1, S2TransitiveSubgroups.S2, False) + ], + # Degree 3 + 3: [ + (x**3 + x**2 - 2*x - 1, S3TransitiveSubgroups.A3, True), + (x**3 + 2, S3TransitiveSubgroups.S3, False), + ], + # Degree 4 + 4: [ + (x**4 + x**3 + x**2 + x + 1, S4TransitiveSubgroups.C4, False), + (x**4 + 1, S4TransitiveSubgroups.V, True), + (x**4 - 2, S4TransitiveSubgroups.D4, False), + (x**4 + 8*x + 12, S4TransitiveSubgroups.A4, True), + (x**4 + x + 1, S4TransitiveSubgroups.S4, False), + ], + # Degree 5 + 5: [ + (x**5 + x**4 - 4*x**3 - 3*x**2 + 3*x + 1, S5TransitiveSubgroups.C5, True), + (x**5 - 5*x + 12, S5TransitiveSubgroups.D5, True), + (x**5 + 2, S5TransitiveSubgroups.M20, False), + (x**5 + 20*x + 16, S5TransitiveSubgroups.A5, True), + (x**5 - x + 1, S5TransitiveSubgroups.S5, False), + ], + # Degree 6 + 6: [ + (x**6 + x**5 + x**4 + x**3 + x**2 + x + 1, S6TransitiveSubgroups.C6, False), + (x**6 + 108, S6TransitiveSubgroups.S3, False), + (x**6 + 2, S6TransitiveSubgroups.D6, False), + (x**6 - 3*x**2 - 1, S6TransitiveSubgroups.A4, True), + (x**6 + 3*x**3 + 3, S6TransitiveSubgroups.G18, False), + (x**6 - 3*x**2 + 1, S6TransitiveSubgroups.A4xC2, False), + (x**6 - 4*x**2 - 1, S6TransitiveSubgroups.S4p, True), + (x**6 - 3*x**5 + 6*x**4 - 7*x**3 + 2*x**2 + x - 4, S6TransitiveSubgroups.S4m, False), + (x**6 + 2*x**3 - 2, S6TransitiveSubgroups.G36m, False), + (x**6 + 2*x**2 + 2, S6TransitiveSubgroups.S4xC2, False), + (x**6 + 10*x**5 + 55*x**4 + 140*x**3 + 175*x**2 + 170*x + 25, S6TransitiveSubgroups.PSL2F5, True), + (x**6 + 10*x**5 + 55*x**4 + 140*x**3 + 175*x**2 - 3019*x + 25, S6TransitiveSubgroups.PGL2F5, False), + (x**6 + 6*x**4 + 2*x**3 + 9*x**2 + 6*x - 4, S6TransitiveSubgroups.G36p, True), + (x**6 + 2*x**4 + 2*x**3 + x**2 + 2*x + 2, S6TransitiveSubgroups.G72, False), + (x**6 + 24*x - 20, S6TransitiveSubgroups.A6, True), + (x**6 + x + 1, S6TransitiveSubgroups.S6, False), + ], +} + + +def test_galois_group(): + """ + Try all the test polys. + """ + for deg in range(1, 7): + polys = test_polys_by_deg[deg] + for T, G, alt in polys: + assert galois_group(T, by_name=True) == (G, alt) + + +def test_galois_group_degree_out_of_bounds(): + raises(ValueError, lambda: galois_group(Poly(0, x))) + raises(ValueError, lambda: galois_group(Poly(1, x))) + raises(ValueError, lambda: galois_group(Poly(x ** 7 + 1))) + + +def test_galois_group_not_by_name(): + """ + Check at least one polynomial of each supported degree, to see that + conversion from name to group works. + """ + for deg in range(1, 7): + T, G_name, _ = test_polys_by_deg[deg][0] + G, _ = galois_group(T) + assert G == G_name.get_perm_group() + + +def test_galois_group_not_monic_over_ZZ(): + """ + Check that we can work with polys that are not monic over ZZ. + """ + for deg in range(1, 7): + T, G, alt = test_polys_by_deg[deg][0] + assert galois_group(T/2, by_name=True) == (G, alt) + + +def test__galois_group_degree_4_root_approx(): + for T, G, alt in test_polys_by_deg[4]: + assert _galois_group_degree_4_root_approx(Poly(T)) == (G, alt) + + +def test__galois_group_degree_5_hybrid(): + for T, G, alt in test_polys_by_deg[5]: + assert _galois_group_degree_5_hybrid(Poly(T)) == (G, alt) + + +def test_AlgebraicField_galois_group(): + k = QQ.alg_field_from_poly(Poly(x**4 + 1)) + G, _ = k.galois_group(by_name=True) + assert G == S4TransitiveSubgroups.V + + k = QQ.alg_field_from_poly(Poly(x**4 - 2)) + G, _ = k.galois_group(by_name=True) + assert G == S4TransitiveSubgroups.D4 diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_minpoly.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_minpoly.py new file mode 100644 index 0000000000000000000000000000000000000000..792e5ad6e136bb00abda0b0739b2fff4fd41937b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_minpoly.py @@ -0,0 +1,490 @@ +"""Tests for minimal polynomials. """ + +from sympy.core.function import expand +from sympy.core import (GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (AlgebraicNumber, I, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.ntheory.generate import nextprime +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.solvers.solveset import nonlinsolve +from sympy.geometry import Circle, intersection +from sympy.testing.pytest import raises, slow +from sympy.sets.sets import FiniteSet +from sympy.geometry.point import Point2D +from sympy.polys.numberfields.minpoly import ( + minimal_polynomial, + _choose_factor, + _minpoly_op_algebraic_element, + _separate_sq, + _minpoly_groebner, +) +from sympy.polys.partfrac import apart +from sympy.polys.polyerrors import ( + NotAlgebraic, + GeneratorsError, +) + +from sympy.polys.domains import QQ +from sympy.polys.rootoftools import rootof +from sympy.polys.polytools import degree + +from sympy.abc import x, y, z + +Q = Rational + + +def test_minimal_polynomial(): + assert minimal_polynomial(-7, x) == x + 7 + assert minimal_polynomial(-1, x) == x + 1 + assert minimal_polynomial( 0, x) == x + assert minimal_polynomial( 1, x) == x - 1 + assert minimal_polynomial( 7, x) == x - 7 + + assert minimal_polynomial(sqrt(2), x) == x**2 - 2 + assert minimal_polynomial(sqrt(5), x) == x**2 - 5 + assert minimal_polynomial(sqrt(6), x) == x**2 - 6 + + assert minimal_polynomial(2*sqrt(2), x) == x**2 - 8 + assert minimal_polynomial(3*sqrt(5), x) == x**2 - 45 + assert minimal_polynomial(4*sqrt(6), x) == x**2 - 96 + + assert minimal_polynomial(2*sqrt(2) + 3, x) == x**2 - 6*x + 1 + assert minimal_polynomial(3*sqrt(5) + 6, x) == x**2 - 12*x - 9 + assert minimal_polynomial(4*sqrt(6) + 7, x) == x**2 - 14*x - 47 + + assert minimal_polynomial(2*sqrt(2) - 3, x) == x**2 + 6*x + 1 + assert minimal_polynomial(3*sqrt(5) - 6, x) == x**2 + 12*x - 9 + assert minimal_polynomial(4*sqrt(6) - 7, x) == x**2 + 14*x - 47 + + assert minimal_polynomial(sqrt(1 + sqrt(6)), x) == x**4 - 2*x**2 - 5 + assert minimal_polynomial(sqrt(I + sqrt(6)), x) == x**8 - 10*x**4 + 49 + + assert minimal_polynomial(2*I + sqrt(2 + I), x) == x**4 + 4*x**2 + 8*x + 37 + + assert minimal_polynomial(sqrt(2) + sqrt(3), x) == x**4 - 10*x**2 + 1 + assert minimal_polynomial( + sqrt(2) + sqrt(3) + sqrt(6), x) == x**4 - 22*x**2 - 48*x - 23 + + a = 1 - 9*sqrt(2) + 7*sqrt(3) + + assert minimal_polynomial( + 1/a, x) == 392*x**4 - 1232*x**3 + 612*x**2 + 4*x - 1 + assert minimal_polynomial( + 1/sqrt(a), x) == 392*x**8 - 1232*x**6 + 612*x**4 + 4*x**2 - 1 + + raises(NotAlgebraic, lambda: minimal_polynomial(oo, x)) + raises(NotAlgebraic, lambda: minimal_polynomial(2**y, x)) + raises(NotAlgebraic, lambda: minimal_polynomial(sin(1), x)) + + assert minimal_polynomial(sqrt(2)).dummy_eq(x**2 - 2) + assert minimal_polynomial(sqrt(2), x) == x**2 - 2 + + assert minimal_polynomial(sqrt(2), polys=True) == Poly(x**2 - 2) + assert minimal_polynomial(sqrt(2), x, polys=True) == Poly(x**2 - 2, domain='QQ') + assert minimal_polynomial(sqrt(2), x, polys=True, compose=False) == Poly(x**2 - 2, domain='QQ') + + a = AlgebraicNumber(sqrt(2)) + b = AlgebraicNumber(sqrt(3)) + + assert minimal_polynomial(a, x) == x**2 - 2 + assert minimal_polynomial(b, x) == x**2 - 3 + + assert minimal_polynomial(a, x, polys=True) == Poly(x**2 - 2, domain='QQ') + assert minimal_polynomial(b, x, polys=True) == Poly(x**2 - 3, domain='QQ') + + assert minimal_polynomial(sqrt(a/2 + 17), x) == 2*x**4 - 68*x**2 + 577 + assert minimal_polynomial(sqrt(b/2 + 17), x) == 4*x**4 - 136*x**2 + 1153 + + a, b = sqrt(2)/3 + 7, AlgebraicNumber(sqrt(2)/3 + 7) + + f = 81*x**8 - 2268*x**6 - 4536*x**5 + 22644*x**4 + 63216*x**3 - \ + 31608*x**2 - 189648*x + 141358 + + assert minimal_polynomial(sqrt(a) + sqrt(sqrt(a)), x) == f + assert minimal_polynomial(sqrt(b) + sqrt(sqrt(b)), x) == f + + assert minimal_polynomial( + a**Q(3, 2), x) == 729*x**4 - 506898*x**2 + 84604519 + + # issue 5994 + eq = S(''' + -1/(800*sqrt(-1/240 + 1/(18000*(-1/17280000 + + sqrt(15)*I/28800000)**(1/3)) + 2*(-1/17280000 + + sqrt(15)*I/28800000)**(1/3)))''') + assert minimal_polynomial(eq, x) == 8000*x**2 - 1 + + ex = (sqrt(5)*sqrt(I)/(5*sqrt(1 + 125*I)) + + 25*sqrt(5)/(I**Q(5,2)*(1 + 125*I)**Q(3,2)) + + 3125*sqrt(5)/(I**Q(11,2)*(1 + 125*I)**Q(3,2)) + + 5*I*sqrt(1 - I/125)) + mp = minimal_polynomial(ex, x) + assert mp == 25*x**4 + 5000*x**2 + 250016 + + ex = 1 + sqrt(2) + sqrt(3) + mp = minimal_polynomial(ex, x) + assert mp == x**4 - 4*x**3 - 4*x**2 + 16*x - 8 + + ex = 1/(1 + sqrt(2) + sqrt(3)) + mp = minimal_polynomial(ex, x) + assert mp == 8*x**4 - 16*x**3 + 4*x**2 + 4*x - 1 + + p = (expand((1 + sqrt(2) - 2*sqrt(3) + sqrt(7))**3))**Rational(1, 3) + mp = minimal_polynomial(p, x) + assert mp == x**8 - 8*x**7 - 56*x**6 + 448*x**5 + 480*x**4 - 5056*x**3 + 1984*x**2 + 7424*x - 3008 + p = expand((1 + sqrt(2) - 2*sqrt(3) + sqrt(7))**3) + mp = minimal_polynomial(p, x) + assert mp == x**8 - 512*x**7 - 118208*x**6 + 31131136*x**5 + 647362560*x**4 - 56026611712*x**3 + 116994310144*x**2 + 404854931456*x - 27216576512 + + assert minimal_polynomial(S("-sqrt(5)/2 - 1/2 + (-sqrt(5)/2 - 1/2)**2"), x) == x - 1 + a = 1 + sqrt(2) + assert minimal_polynomial((a*sqrt(2) + a)**3, x) == x**2 - 198*x + 1 + + p = 1/(1 + sqrt(2) + sqrt(3)) + assert minimal_polynomial(p, x, compose=False) == 8*x**4 - 16*x**3 + 4*x**2 + 4*x - 1 + + p = 2/(1 + sqrt(2) + sqrt(3)) + assert minimal_polynomial(p, x, compose=False) == x**4 - 4*x**3 + 2*x**2 + 4*x - 2 + + assert minimal_polynomial(1 + sqrt(2)*I, x, compose=False) == x**2 - 2*x + 3 + assert minimal_polynomial(1/(1 + sqrt(2)) + 1, x, compose=False) == x**2 - 2 + assert minimal_polynomial(sqrt(2)*I + I*(1 + sqrt(2)), x, + compose=False) == x**4 + 18*x**2 + 49 + + # minimal polynomial of I + assert minimal_polynomial(I, x, domain=QQ.algebraic_field(I)) == x - I + K = QQ.algebraic_field(I*(sqrt(2) + 1)) + assert minimal_polynomial(I, x, domain=K) == x - I + assert minimal_polynomial(I, x, domain=QQ) == x**2 + 1 + assert minimal_polynomial(I, x, domain='QQ(y)') == x**2 + 1 + + #issue 11553 + assert minimal_polynomial(GoldenRatio, x) == x**2 - x - 1 + assert minimal_polynomial(TribonacciConstant + 3, x) == x**3 - 10*x**2 + 32*x - 34 + assert minimal_polynomial(GoldenRatio, x, domain=QQ.algebraic_field(sqrt(5))) == \ + 2*x - sqrt(5) - 1 + assert minimal_polynomial(TribonacciConstant, x, domain=QQ.algebraic_field(cbrt(19 - 3*sqrt(33)))) == \ + 48*x - 19*(19 - 3*sqrt(33))**Rational(2, 3) - 3*sqrt(33)*(19 - 3*sqrt(33))**Rational(2, 3) \ + - 16*(19 - 3*sqrt(33))**Rational(1, 3) - 16 + + # AlgebraicNumber with an alias. + # Wester H24 + phi = AlgebraicNumber(S.GoldenRatio.expand(func=True), alias='phi') + assert minimal_polynomial(phi, x) == x**2 - x - 1 + + +def test_issue_26903(): + p1 = nextprime(10**16) # greater than 10**15 + p2 = nextprime(p1) + assert sqrt(p1**2*p2).is_Pow # square not extracted + zero = sqrt(p1**2*p2) - p1*sqrt(p2) + assert minimal_polynomial(zero, x) == x + assert minimal_polynomial(sqrt(2) - zero, x) == x**2 - 2 + + +def test_issue_8353(): + assert minimal_polynomial(exp(3*I*pi, evaluate=False), x) == x + 1 + assert minimal_polynomial(Pow(8, S(1)/3, evaluate=False), x + ) == x - 2 + + +def test_minimal_polynomial_issue_19732(): + # https://github.com/sympy/sympy/issues/19732 + expr = (-280898097948878450887044002323982963174671632174995451265117559518123750720061943079105185551006003416773064305074191140286225850817291393988597615/(-488144716373031204149459129212782509078221364279079444636386844223983756114492222145074506571622290776245390771587888364089507840000000*sqrt(238368341569)*sqrt(S(11918417078450)/63568729 + - 24411360*sqrt(238368341569)/63568729) + + 238326799225996604451373809274348704114327860564921529846705817404208077866956345381951726531296652901169111729944612727047670549086208000000*sqrt(S(11918417078450)/63568729 + - 24411360*sqrt(238368341569)/63568729)) - + 180561807339168676696180573852937120123827201075968945871075967679148461189459480842956689723484024031016208588658753107/(-59358007109636562851035004992802812513575019937126272896569856090962677491318275291141463850327474176000000*sqrt(238368341569)*sqrt(S(11918417078450)/63568729 + - 24411360*sqrt(238368341569)/63568729) + + 28980348180319251787320809875930301310576055074938369007463004788921613896002936637780993064387310446267596800000*sqrt(S(11918417078450)/63568729 + - 24411360*sqrt(238368341569)/63568729))) + poly = (2151288870990266634727173620565483054187142169311153766675688628985237817262915166497766867289157986631135400926544697981091151416655364879773546003475813114962656742744975460025956167152918469472166170500512008351638710934022160294849059721218824490226159355197136265032810944357335461128949781377875451881300105989490353140886315677977149440000000000000000000000*x**4 + - 5773274155644072033773937864114266313663195672820501581692669271302387257492905909558846459600429795784309388968498783843631580008547382703258503404023153694528041873101120067477617592651525155101107144042679962433039557235772239171616433004024998230222455940044709064078962397144550855715640331680262171410099614469231080995436488414164502751395405398078353242072696360734131090111239998110773292915337556205692674790561090109440000000000000*x**2 + + 211295968822207088328287206509522887719741955693091053353263782924470627623790749534705683380138972642560898936171035770539616881000369889020398551821767092685775598633794696371561234818461806577723412581353857653829324364446419444210520602157621008010129702779407422072249192199762604318993590841636967747488049176548615614290254356975376588506729604345612047361483789518445332415765213187893207704958013682516462853001964919444736320672860140355089) + assert minimal_polynomial(expr, x) == poly + + +def test_minimal_polynomial_hi_prec(): + p = 1/sqrt(1 - 9*sqrt(2) + 7*sqrt(3) + Rational(1, 10)**30) + mp = minimal_polynomial(p, x) + # checked with Wolfram Alpha + assert mp.coeff(x**6) == -1232000000000000000000000000001223999999999999999999999999999987999999999999999999999999999996000000000000000000000000000000 + + +def test_minimal_polynomial_sq(): + from sympy.core.add import Add + from sympy.core.function import expand_multinomial + p = expand_multinomial((1 + 5*sqrt(2) + 2*sqrt(3))**3) + mp = minimal_polynomial(p**Rational(1, 3), x) + assert mp == x**4 - 4*x**3 - 118*x**2 + 244*x + 1321 + p = expand_multinomial((1 + sqrt(2) - 2*sqrt(3) + sqrt(7))**3) + mp = minimal_polynomial(p**Rational(1, 3), x) + assert mp == x**8 - 8*x**7 - 56*x**6 + 448*x**5 + 480*x**4 - 5056*x**3 + 1984*x**2 + 7424*x - 3008 + p = Add(*[sqrt(i) for i in range(1, 12)]) + mp = minimal_polynomial(p, x) + assert mp.subs({x: 0}) == -71965773323122507776 + + +def test_minpoly_compose(): + # issue 6868 + eq = S(''' + -1/(800*sqrt(-1/240 + 1/(18000*(-1/17280000 + + sqrt(15)*I/28800000)**(1/3)) + 2*(-1/17280000 + + sqrt(15)*I/28800000)**(1/3)))''') + mp = minimal_polynomial(eq + 3, x) + assert mp == 8000*x**2 - 48000*x + 71999 + + # issue 5888 + assert minimal_polynomial(exp(I*pi/8), x) == x**8 + 1 + + mp = minimal_polynomial(sin(pi/7) + sqrt(2), x) + assert mp == 4096*x**12 - 63488*x**10 + 351488*x**8 - 826496*x**6 + \ + 770912*x**4 - 268432*x**2 + 28561 + mp = minimal_polynomial(cos(pi/7) + sqrt(2), x) + assert mp == 64*x**6 - 64*x**5 - 432*x**4 + 304*x**3 + 712*x**2 - \ + 232*x - 239 + mp = minimal_polynomial(exp(I*pi/7) + sqrt(2), x) + assert mp == x**12 - 2*x**11 - 9*x**10 + 16*x**9 + 43*x**8 - 70*x**7 - 97*x**6 + 126*x**5 + 211*x**4 - 212*x**3 - 37*x**2 + 142*x + 127 + + mp = minimal_polynomial(sin(pi/7) + sqrt(2), x) + assert mp == 4096*x**12 - 63488*x**10 + 351488*x**8 - 826496*x**6 + \ + 770912*x**4 - 268432*x**2 + 28561 + mp = minimal_polynomial(cos(pi/7) + sqrt(2), x) + assert mp == 64*x**6 - 64*x**5 - 432*x**4 + 304*x**3 + 712*x**2 - \ + 232*x - 239 + mp = minimal_polynomial(exp(I*pi/7) + sqrt(2), x) + assert mp == x**12 - 2*x**11 - 9*x**10 + 16*x**9 + 43*x**8 - 70*x**7 - 97*x**6 + 126*x**5 + 211*x**4 - 212*x**3 - 37*x**2 + 142*x + 127 + + mp = minimal_polynomial(exp(I*pi*Rational(2, 7)), x) + assert mp == x**6 + x**5 + x**4 + x**3 + x**2 + x + 1 + mp = minimal_polynomial(exp(I*pi*Rational(2, 15)), x) + assert mp == x**8 - x**7 + x**5 - x**4 + x**3 - x + 1 + mp = minimal_polynomial(cos(pi*Rational(2, 7)), x) + assert mp == 8*x**3 + 4*x**2 - 4*x - 1 + mp = minimal_polynomial(sin(pi*Rational(2, 7)), x) + ex = (5*cos(pi*Rational(2, 7)) - 7)/(9*cos(pi/7) - 5*cos(pi*Rational(3, 7))) + mp = minimal_polynomial(ex, x) + assert mp == x**3 + 2*x**2 - x - 1 + assert minimal_polynomial(-1/(2*cos(pi/7)), x) == x**3 + 2*x**2 - x - 1 + assert minimal_polynomial(sin(pi*Rational(2, 15)), x) == \ + 256*x**8 - 448*x**6 + 224*x**4 - 32*x**2 + 1 + assert minimal_polynomial(sin(pi*Rational(5, 14)), x) == 8*x**3 - 4*x**2 - 4*x + 1 + assert minimal_polynomial(cos(pi/15), x) == 16*x**4 + 8*x**3 - 16*x**2 - 8*x + 1 + + ex = rootof(x**3 +x*4 + 1, 0) + mp = minimal_polynomial(ex, x) + assert mp == x**3 + 4*x + 1 + mp = minimal_polynomial(ex + 1, x) + assert mp == x**3 - 3*x**2 + 7*x - 4 + assert minimal_polynomial(exp(I*pi/3), x) == x**2 - x + 1 + assert minimal_polynomial(exp(I*pi/4), x) == x**4 + 1 + assert minimal_polynomial(exp(I*pi/6), x) == x**4 - x**2 + 1 + assert minimal_polynomial(exp(I*pi/9), x) == x**6 - x**3 + 1 + assert minimal_polynomial(exp(I*pi/10), x) == x**8 - x**6 + x**4 - x**2 + 1 + assert minimal_polynomial(sin(pi/9), x) == 64*x**6 - 96*x**4 + 36*x**2 - 3 + assert minimal_polynomial(sin(pi/11), x) == 1024*x**10 - 2816*x**8 + \ + 2816*x**6 - 1232*x**4 + 220*x**2 - 11 + assert minimal_polynomial(sin(pi/21), x) == 4096*x**12 - 11264*x**10 + \ + 11264*x**8 - 4992*x**6 + 960*x**4 - 64*x**2 + 1 + assert minimal_polynomial(cos(pi/9), x) == 8*x**3 - 6*x - 1 + + ex = 2**Rational(1, 3)*exp(2*I*pi/3) + assert minimal_polynomial(ex, x) == x**3 - 2 + + raises(NotAlgebraic, lambda: minimal_polynomial(cos(pi*sqrt(2)), x)) + raises(NotAlgebraic, lambda: minimal_polynomial(sin(pi*sqrt(2)), x)) + raises(NotAlgebraic, lambda: minimal_polynomial(exp(1.618*I*pi), x)) + raises(NotAlgebraic, lambda: minimal_polynomial(exp(I*pi*sqrt(2)), x)) + + # issue 5934 + ex = 1/(-36000 - 7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) + + 24*sqrt(10)*sqrt(-sqrt(5) + 5))**2) + 1 + raises(ZeroDivisionError, lambda: minimal_polynomial(ex, x)) + + ex = sqrt(1 + 2**Rational(1,3)) + sqrt(1 + 2**Rational(1,4)) + sqrt(2) + mp = minimal_polynomial(ex, x) + assert degree(mp) == 48 and mp.subs({x:0}) == -16630256576 + + ex = tan(pi/5, evaluate=False) + mp = minimal_polynomial(ex, x) + assert mp == x**4 - 10*x**2 + 5 + assert mp.subs(x, tan(pi/5)).is_zero + + ex = tan(pi/6, evaluate=False) + mp = minimal_polynomial(ex, x) + assert mp == 3*x**2 - 1 + assert mp.subs(x, tan(pi/6)).is_zero + + ex = tan(pi/10, evaluate=False) + mp = minimal_polynomial(ex, x) + assert mp == 5*x**4 - 10*x**2 + 1 + assert mp.subs(x, tan(pi/10)).is_zero + + raises(NotAlgebraic, lambda: minimal_polynomial(tan(pi*sqrt(2)), x)) + + +def test_minpoly_issue_7113(): + # see discussion in https://github.com/sympy/sympy/pull/2234 + from sympy.simplify.simplify import nsimplify + r = nsimplify(pi, tolerance=0.000000001) + mp = minimal_polynomial(r, x) + assert mp == 1768292677839237920489538677417507171630859375*x**109 - \ + 2734577732179183863586489182929671773182898498218854181690460140337930774573792597743853652058046464 + + +def test_minpoly_issue_23677(): + r1 = CRootOf(4000000*x**3 - 239960000*x**2 + 4782399900*x - 31663998001, 0) + r2 = CRootOf(4000000*x**3 - 239960000*x**2 + 4782399900*x - 31663998001, 1) + num = (7680000000000000000*r1**4*r2**4 - 614323200000000000000*r1**4*r2**3 + + 18458112576000000000000*r1**4*r2**2 - 246896663036160000000000*r1**4*r2 + + 1240473830323209600000000*r1**4 - 614323200000000000000*r1**3*r2**4 + - 1476464424954240000000000*r1**3*r2**2 - 99225501687553535904000000*r1**3 + + 18458112576000000000000*r1**2*r2**4 - 1476464424954240000000000*r1**2*r2**3 + - 593391458458356671712000000*r1**2*r2 + 2981354896834339226880720000*r1**2 + - 246896663036160000000000*r1*r2**4 - 593391458458356671712000000*r1*r2**2 + - 39878756418031796275267195200*r1 + 1240473830323209600000000*r2**4 + - 99225501687553535904000000*r2**3 + 2981354896834339226880720000*r2**2 - + 39878756418031796275267195200*r2 + 200361370275616536577343808012) + mp = (x**3 + 59426520028417434406408556687919*x**2 + + 1161475464966574421163316896737773190861975156439163671112508400*x + + 7467465541178623874454517208254940823818304424383315270991298807299003671748074773558707779600) + assert minimal_polynomial(num, x) == mp + + +def test_minpoly_issue_7574(): + ex = -(-1)**Rational(1, 3) + (-1)**Rational(2,3) + assert minimal_polynomial(ex, x) == x + 1 + + +def test_choose_factor(): + # Test that this does not enter an infinite loop: + bad_factors = [Poly(x-2, x), Poly(x+2, x)] + raises(NotImplementedError, lambda: _choose_factor(bad_factors, x, sqrt(3))) + + +def test_minpoly_fraction_field(): + assert minimal_polynomial(1/x, y) == -x*y + 1 + assert minimal_polynomial(1 / (x + 1), y) == (x + 1)*y - 1 + + assert minimal_polynomial(sqrt(x), y) == y**2 - x + assert minimal_polynomial(sqrt(x + 1), y) == y**2 - x - 1 + assert minimal_polynomial(sqrt(x) / x, y) == x*y**2 - 1 + assert minimal_polynomial(sqrt(2) * sqrt(x), y) == y**2 - 2 * x + assert minimal_polynomial(sqrt(2) + sqrt(x), y) == \ + y**4 + (-2*x - 4)*y**2 + x**2 - 4*x + 4 + + assert minimal_polynomial(x**Rational(1,3), y) == y**3 - x + assert minimal_polynomial(x**Rational(1,3) + sqrt(x), y) == \ + y**6 - 3*x*y**4 - 2*x*y**3 + 3*x**2*y**2 - 6*x**2*y - x**3 + x**2 + + assert minimal_polynomial(sqrt(x) / z, y) == z**2*y**2 - x + assert minimal_polynomial(sqrt(x) / (z + 1), y) == (z**2 + 2*z + 1)*y**2 - x + + assert minimal_polynomial(1/x, y, polys=True) == Poly(-x*y + 1, y, domain='ZZ(x)') + assert minimal_polynomial(1 / (x + 1), y, polys=True) == \ + Poly((x + 1)*y - 1, y, domain='ZZ(x)') + assert minimal_polynomial(sqrt(x), y, polys=True) == Poly(y**2 - x, y, domain='ZZ(x)') + assert minimal_polynomial(sqrt(x) / z, y, polys=True) == \ + Poly(z**2*y**2 - x, y, domain='ZZ(x, z)') + + # this is (sqrt(1 + x**3)/x).integrate(x).diff(x) - sqrt(1 + x**3)/x + a = sqrt(x)/sqrt(1 + x**(-3)) - sqrt(x**3 + 1)/x + 1/(x**Rational(5, 2)* \ + (1 + x**(-3))**Rational(3, 2)) + 1/(x**Rational(11, 2)*(1 + x**(-3))**Rational(3, 2)) + + assert minimal_polynomial(a, y) == y + + raises(NotAlgebraic, lambda: minimal_polynomial(exp(x), y)) + raises(GeneratorsError, lambda: minimal_polynomial(sqrt(x), x)) + raises(GeneratorsError, lambda: minimal_polynomial(sqrt(x) - y, x)) + raises(NotImplementedError, lambda: minimal_polynomial(sqrt(x), y, compose=False)) + +@slow +def test_minpoly_fraction_field_slow(): + assert minimal_polynomial(minimal_polynomial(sqrt(x**Rational(1,5) - 1), + y).subs(y, sqrt(x**Rational(1,5) - 1)), z) == z + +def test_minpoly_domain(): + assert minimal_polynomial(sqrt(2), x, domain=QQ.algebraic_field(sqrt(2))) == \ + x - sqrt(2) + assert minimal_polynomial(sqrt(8), x, domain=QQ.algebraic_field(sqrt(2))) == \ + x - 2*sqrt(2) + assert minimal_polynomial(sqrt(Rational(3,2)), x, + domain=QQ.algebraic_field(sqrt(2))) == 2*x**2 - 3 + + raises(NotAlgebraic, lambda: minimal_polynomial(y, x, domain=QQ)) + + +def test_issue_14831(): + a = -2*sqrt(2)*sqrt(12*sqrt(2) + 17) + assert minimal_polynomial(a, x) == x**2 + 16*x - 8 + e = (-3*sqrt(12*sqrt(2) + 17) + 12*sqrt(2) + + 17 - 2*sqrt(2)*sqrt(12*sqrt(2) + 17)) + assert minimal_polynomial(e, x) == x + + +def test_issue_18248(): + assert nonlinsolve([x*y**3-sqrt(2)/3, x*y**6-4/(9*(sqrt(3)))],x,y) == \ + FiniteSet((sqrt(3)/2, sqrt(6)/3), (sqrt(3)/2, -sqrt(6)/6 - sqrt(2)*I/2), + (sqrt(3)/2, -sqrt(6)/6 + sqrt(2)*I/2)) + + +def test_issue_13230(): + c1 = Circle(Point2D(3, sqrt(5)), 5) + c2 = Circle(Point2D(4, sqrt(7)), 6) + assert intersection(c1, c2) == [Point2D(-1 + (-sqrt(7) + sqrt(5))*(-2*sqrt(7)/29 + + 9*sqrt(5)/29 + sqrt(196*sqrt(35) + 1941)/29), -2*sqrt(7)/29 + 9*sqrt(5)/29 + + sqrt(196*sqrt(35) + 1941)/29), Point2D(-1 + (-sqrt(7) + sqrt(5))*(-sqrt(196*sqrt(35) + + 1941)/29 - 2*sqrt(7)/29 + 9*sqrt(5)/29), -sqrt(196*sqrt(35) + 1941)/29 - 2*sqrt(7)/29 + 9*sqrt(5)/29)] + +def test_issue_19760(): + e = 1/(sqrt(1 + sqrt(2)) - sqrt(2)*sqrt(1 + sqrt(2))) + 1 + mp_expected = x**4 - 4*x**3 + 4*x**2 - 2 + + for comp in (True, False): + mp = Poly(minimal_polynomial(e, compose=comp)) + assert mp(x) == mp_expected, "minimal_polynomial(e, compose=%s) = %s; %s expected" % (comp, mp(x), mp_expected) + + +def test_issue_20163(): + assert apart(1/(x**6+1), extension=[sqrt(3), I]) == \ + (sqrt(3) + I)/(2*x + sqrt(3) + I)/6 + \ + (sqrt(3) - I)/(2*x + sqrt(3) - I)/6 - \ + (sqrt(3) - I)/(2*x - sqrt(3) + I)/6 - \ + (sqrt(3) + I)/(2*x - sqrt(3) - I)/6 + \ + I/(x + I)/6 - I/(x - I)/6 + + +def test_issue_22559(): + alpha = AlgebraicNumber(sqrt(2)) + assert minimal_polynomial(alpha**3, x) == x**2 - 8 + + +def test_issue_22561(): + a = AlgebraicNumber(sqrt(2) + sqrt(3), [S(1) / 2, 0, S(-9) / 2, 0], gen=x) + assert a.as_expr() == sqrt(2) + assert minimal_polynomial(a, x) == x**2 - 2 + assert minimal_polynomial(a**3, x) == x**2 - 8 + + +def test_separate_sq_not_impl(): + raises(NotImplementedError, lambda: _separate_sq(x**(S(1)/3) + x)) + + +def test_minpoly_op_algebraic_element_not_impl(): + raises(NotImplementedError, + lambda: _minpoly_op_algebraic_element(Pow, sqrt(2), sqrt(3), x, QQ)) + + +def test_minpoly_groebner(): + assert _minpoly_groebner(S(2)/3, x, Poly) == 3*x - 2 + assert _minpoly_groebner( + (sqrt(2) + 3)*(sqrt(2) + 1), x, Poly) == x**2 - 10*x - 7 + assert _minpoly_groebner((sqrt(2) + 3)**(S(1)/3)*(sqrt(2) + 1)**(S(1)/3), + x, Poly) == x**6 - 10*x**3 - 7 + assert _minpoly_groebner((sqrt(2) + 3)**(-S(1)/3)*(sqrt(2) + 1)**(S(1)/3), + x, Poly) == 7*x**6 - 2*x**3 - 1 + raises(NotAlgebraic, lambda: _minpoly_groebner(pi**2, x, Poly)) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_modules.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c61c98e33d3c78e79eeed45efcfa1f74478645 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_modules.py @@ -0,0 +1,752 @@ +from sympy.abc import x, zeta +from sympy.polys import Poly, cyclotomic_poly +from sympy.polys.domains import FF, QQ, ZZ +from sympy.polys.matrices import DomainMatrix, DM +from sympy.polys.numberfields.exceptions import ( + ClosureFailure, MissingUnityError, StructureError +) +from sympy.polys.numberfields.modules import ( + Module, ModuleElement, ModuleEndomorphism, PowerBasis, PowerBasisElement, + find_min_poly, is_sq_maxrank_HNF, make_mod_elt, to_col, +) +from sympy.polys.numberfields.utilities import is_int +from sympy.polys.polyerrors import UnificationFailed +from sympy.testing.pytest import raises + + +def test_to_col(): + c = [1, 2, 3, 4] + m = to_col(c) + assert m.domain.is_ZZ + assert m.shape == (4, 1) + assert m.flat() == c + + +def test_Module_NotImplemented(): + M = Module() + raises(NotImplementedError, lambda: M.n) + raises(NotImplementedError, lambda: M.mult_tab()) + raises(NotImplementedError, lambda: M.represent(None)) + raises(NotImplementedError, lambda: M.starts_with_unity()) + raises(NotImplementedError, lambda: M.element_from_rational(QQ(2, 3))) + + +def test_Module_ancestors(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = B.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + D = B.submodule_from_matrix(5 * DomainMatrix.eye(4, ZZ)) + assert C.ancestors(include_self=True) == [A, B, C] + assert D.ancestors(include_self=True) == [A, B, D] + assert C.power_basis_ancestor() == A + assert C.nearest_common_ancestor(D) == B + M = Module() + assert M.power_basis_ancestor() is None + + +def test_Module_compat_col(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + col = to_col([1, 2, 3, 4]) + row = col.transpose() + assert A.is_compat_col(col) is True + assert A.is_compat_col(row) is False + assert A.is_compat_col(1) is False + assert A.is_compat_col(DomainMatrix.eye(3, ZZ)[:, 0]) is False + assert A.is_compat_col(DomainMatrix.eye(4, QQ)[:, 0]) is False + assert A.is_compat_col(DomainMatrix.eye(4, ZZ)[:, 0]) is True + + +def test_Module_call(): + T = Poly(cyclotomic_poly(5, x)) + B = PowerBasis(T) + assert B(0).col.flat() == [1, 0, 0, 0] + assert B(1).col.flat() == [0, 1, 0, 0] + col = DomainMatrix.eye(4, ZZ)[:, 2] + assert B(col).col == col + raises(ValueError, lambda: B(-1)) + + +def test_Module_starts_with_unity(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + assert A.starts_with_unity() is True + assert B.starts_with_unity() is False + + +def test_Module_basis_elements(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + basis = B.basis_elements() + bp = B.basis_element_pullbacks() + for i, (e, p) in enumerate(zip(basis, bp)): + c = [0] * 4 + assert e.module == B + assert p.module == A + c[i] = 1 + assert e == B(to_col(c)) + c[i] = 2 + assert p == A(to_col(c)) + + +def test_Module_zero(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + assert A.zero().col.flat() == [0, 0, 0, 0] + assert A.zero().module == A + assert B.zero().col.flat() == [0, 0, 0, 0] + assert B.zero().module == B + + +def test_Module_one(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + assert A.one().col.flat() == [1, 0, 0, 0] + assert A.one().module == A + assert B.one().col.flat() == [1, 0, 0, 0] + assert B.one().module == A + + +def test_Module_element_from_rational(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + rA = A.element_from_rational(QQ(22, 7)) + rB = B.element_from_rational(QQ(22, 7)) + assert rA.coeffs == [22, 0, 0, 0] + assert rA.denom == 7 + assert rA.module == A + assert rB.coeffs == [22, 0, 0, 0] + assert rB.denom == 7 + assert rB.module == A + + +def test_Module_submodule_from_gens(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + gens = [2*A(0), 2*A(1), 6*A(0), 6*A(1)] + B = A.submodule_from_gens(gens) + # Because the 3rd and 4th generators do not add anything new, we expect + # the cols of the matrix of B to just reproduce the first two gens: + M = gens[0].column().hstack(gens[1].column()) + assert B.matrix == M + # At least one generator must be provided: + raises(ValueError, lambda: A.submodule_from_gens([])) + # All generators must belong to A: + raises(ValueError, lambda: A.submodule_from_gens([3*A(0), B(0)])) + + +def test_Module_submodule_from_matrix(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + e = B(to_col([1, 2, 3, 4])) + f = e.to_parent() + assert f.col.flat() == [2, 4, 6, 8] + # Matrix must be over ZZ: + raises(ValueError, lambda: A.submodule_from_matrix(DomainMatrix.eye(4, QQ))) + # Number of rows of matrix must equal number of generators of module A: + raises(ValueError, lambda: A.submodule_from_matrix(2 * DomainMatrix.eye(5, ZZ))) + + +def test_Module_whole_submodule(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.whole_submodule() + e = B(to_col([1, 2, 3, 4])) + f = e.to_parent() + assert f.col.flat() == [1, 2, 3, 4] + e0, e1, e2, e3 = B(0), B(1), B(2), B(3) + assert e2 * e3 == e0 + assert e3 ** 2 == e1 + + +def test_PowerBasis_repr(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + assert repr(A) == 'PowerBasis(x**4 + x**3 + x**2 + x + 1)' + + +def test_PowerBasis_eq(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = PowerBasis(T) + assert A == B + + +def test_PowerBasis_mult_tab(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + M = A.mult_tab() + exp = {0: {0: [1, 0, 0, 0], 1: [0, 1, 0, 0], 2: [0, 0, 1, 0], 3: [0, 0, 0, 1]}, + 1: {1: [0, 0, 1, 0], 2: [0, 0, 0, 1], 3: [-1, -1, -1, -1]}, + 2: {2: [-1, -1, -1, -1], 3: [1, 0, 0, 0]}, + 3: {3: [0, 1, 0, 0]}} + # We get the table we expect: + assert M == exp + # And all entries are of expected type: + assert all(is_int(c) for u in M for v in M[u] for c in M[u][v]) + + +def test_PowerBasis_represent(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + col = to_col([1, 2, 3, 4]) + a = A(col) + assert A.represent(a) == col + b = A(col, denom=2) + raises(ClosureFailure, lambda: A.represent(b)) + + +def test_PowerBasis_element_from_poly(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + f = Poly(1 + 2*x) + g = Poly(x**4) + h = Poly(0, x) + assert A.element_from_poly(f).coeffs == [1, 2, 0, 0] + assert A.element_from_poly(g).coeffs == [-1, -1, -1, -1] + assert A.element_from_poly(h).coeffs == [0, 0, 0, 0] + + +def test_PowerBasis_element__conversions(): + k = QQ.cyclotomic_field(5) + L = QQ.cyclotomic_field(7) + B = PowerBasis(k) + + # ANP --> PowerBasisElement + a = k([QQ(1, 2), QQ(1, 3), 5, 7]) + e = B.element_from_ANP(a) + assert e.coeffs == [42, 30, 2, 3] + assert e.denom == 6 + + # PowerBasisElement --> ANP + assert e.to_ANP() == a + + # Cannot convert ANP from different field + d = L([QQ(1, 2), QQ(1, 3), 5, 7]) + raises(UnificationFailed, lambda: B.element_from_ANP(d)) + + # AlgebraicNumber --> PowerBasisElement + alpha = k.to_alg_num(a) + eps = B.element_from_alg_num(alpha) + assert eps.coeffs == [42, 30, 2, 3] + assert eps.denom == 6 + + # PowerBasisElement --> AlgebraicNumber + assert eps.to_alg_num() == alpha + + # Cannot convert AlgebraicNumber from different field + delta = L.to_alg_num(d) + raises(UnificationFailed, lambda: B.element_from_alg_num(delta)) + + # When we don't know the field: + C = PowerBasis(k.ext.minpoly) + # Can convert from AlgebraicNumber: + eps = C.element_from_alg_num(alpha) + assert eps.coeffs == [42, 30, 2, 3] + assert eps.denom == 6 + # But can't convert back: + raises(StructureError, lambda: eps.to_alg_num()) + + +def test_Submodule_repr(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ), denom=3) + assert repr(B) == 'Submodule[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]]/3' + + +def test_Submodule_reduced(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = A.submodule_from_matrix(6 * DomainMatrix.eye(4, ZZ), denom=3) + D = C.reduced() + assert D.denom == 1 and D == C == B + + +def test_Submodule_discard_before(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + B.compute_mult_tab() + C = B.discard_before(2) + assert C.parent == B.parent + assert B.is_sq_maxrank_HNF() and not C.is_sq_maxrank_HNF() + assert C.matrix == B.matrix[:, 2:] + assert C.mult_tab() == {0: {0: [-2, -2], 1: [0, 0]}, 1: {1: [0, 0]}} + + +def test_Submodule_QQ_matrix(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = A.submodule_from_matrix(6 * DomainMatrix.eye(4, ZZ), denom=3) + assert C.QQ_matrix == B.QQ_matrix + + +def test_Submodule_represent(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = B.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + a0 = A(to_col([6, 12, 18, 24])) + a1 = A(to_col([2, 4, 6, 8])) + a2 = A(to_col([1, 3, 5, 7])) + + b1 = B.represent(a1) + assert b1.flat() == [1, 2, 3, 4] + + c0 = C.represent(a0) + assert c0.flat() == [1, 2, 3, 4] + + Y = A.submodule_from_matrix(DomainMatrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + ], (3, 4), ZZ).transpose()) + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + z0 = Z(to_col([1, 2, 3, 4, 5, 6])) + + raises(ClosureFailure, lambda: Y.represent(A(3))) + raises(ClosureFailure, lambda: B.represent(a2)) + raises(ClosureFailure, lambda: B.represent(z0)) + + +def test_Submodule_is_compat_submodule(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + D = C.submodule_from_matrix(5 * DomainMatrix.eye(4, ZZ)) + assert B.is_compat_submodule(C) is True + assert B.is_compat_submodule(A) is False + assert B.is_compat_submodule(D) is False + + +def test_Submodule_eq(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = A.submodule_from_matrix(6 * DomainMatrix.eye(4, ZZ), denom=3) + assert C == B + + +def test_Submodule_add(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(DomainMatrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + ], (2, 4), ZZ).transpose(), denom=6) + C = A.submodule_from_matrix(DomainMatrix([ + [0, 10, 0, 0], + [0, 0, 7, 0], + ], (2, 4), ZZ).transpose(), denom=15) + D = A.submodule_from_matrix(DomainMatrix([ + [20, 0, 0, 0], + [ 0, 20, 0, 0], + [ 0, 0, 14, 0], + ], (3, 4), ZZ).transpose(), denom=30) + assert B + C == D + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + Y = Z.submodule_from_gens([Z(0), Z(1)]) + raises(TypeError, lambda: B + Y) + + +def test_Submodule_mul(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + C = A.submodule_from_matrix(DomainMatrix([ + [0, 10, 0, 0], + [0, 0, 7, 0], + ], (2, 4), ZZ).transpose(), denom=15) + C1 = A.submodule_from_matrix(DomainMatrix([ + [0, 20, 0, 0], + [0, 0, 14, 0], + ], (2, 4), ZZ).transpose(), denom=3) + C2 = A.submodule_from_matrix(DomainMatrix([ + [0, 0, 10, 0], + [0, 0, 0, 7], + ], (2, 4), ZZ).transpose(), denom=15) + C3_unred = A.submodule_from_matrix(DomainMatrix([ + [0, 0, 100, 0], + [0, 0, 0, 70], + [0, 0, 0, 70], + [-49, -49, -49, -49] + ], (4, 4), ZZ).transpose(), denom=225) + C3 = A.submodule_from_matrix(DomainMatrix([ + [4900, 4900, 0, 0], + [4410, 4410, 10, 0], + [2107, 2107, 7, 7] + ], (3, 4), ZZ).transpose(), denom=225) + assert C * 1 == C + assert C ** 1 == C + assert C * 10 == C1 + assert C * A(1) == C2 + assert C.mul(C, hnf=False) == C3_unred + assert C * C == C3 + assert C ** 2 == C3 + + +def test_Submodule_reduce_element(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.whole_submodule() + b = B(to_col([90, 84, 80, 75]), denom=120) + + C = B.submodule_from_matrix(DomainMatrix.eye(4, ZZ), denom=2) + b_bar_expected = B(to_col([30, 24, 20, 15]), denom=120) + b_bar = C.reduce_element(b) + assert b_bar == b_bar_expected + + C = B.submodule_from_matrix(DomainMatrix.eye(4, ZZ), denom=4) + b_bar_expected = B(to_col([0, 24, 20, 15]), denom=120) + b_bar = C.reduce_element(b) + assert b_bar == b_bar_expected + + C = B.submodule_from_matrix(DomainMatrix.eye(4, ZZ), denom=8) + b_bar_expected = B(to_col([0, 9, 5, 0]), denom=120) + b_bar = C.reduce_element(b) + assert b_bar == b_bar_expected + + a = A(to_col([1, 2, 3, 4])) + raises(NotImplementedError, lambda: C.reduce_element(a)) + + C = B.submodule_from_matrix(DomainMatrix([ + [5, 4, 3, 2], + [0, 8, 7, 6], + [0, 0,11,12], + [0, 0, 0, 1] + ], (4, 4), ZZ).transpose()) + raises(StructureError, lambda: C.reduce_element(b)) + + +def test_is_HNF(): + M = DM([ + [3, 2, 1], + [0, 2, 1], + [0, 0, 1] + ], ZZ) + M1 = DM([ + [3, 2, 1], + [0, -2, 1], + [0, 0, 1] + ], ZZ) + M2 = DM([ + [3, 2, 3], + [0, 2, 1], + [0, 0, 1] + ], ZZ) + assert is_sq_maxrank_HNF(M) is True + assert is_sq_maxrank_HNF(M1) is False + assert is_sq_maxrank_HNF(M2) is False + + +def test_make_mod_elt(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + col = to_col([1, 2, 3, 4]) + eA = make_mod_elt(A, col) + eB = make_mod_elt(B, col) + assert isinstance(eA, PowerBasisElement) + assert not isinstance(eB, PowerBasisElement) + + +def test_ModuleElement_repr(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 2, 3, 4]), denom=2) + assert repr(e) == '[1, 2, 3, 4]/2' + + +def test_ModuleElement_reduced(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([2, 4, 6, 8]), denom=2) + f = e.reduced() + assert f.denom == 1 and f == e + + +def test_ModuleElement_reduced_mod_p(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([20, 40, 60, 80])) + f = e.reduced_mod_p(7) + assert f.coeffs == [-1, -2, -3, 3] + + +def test_ModuleElement_from_int_list(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + c = [1, 2, 3, 4] + assert ModuleElement.from_int_list(A, c).coeffs == c + + +def test_ModuleElement_len(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(0) + assert len(e) == 4 + + +def test_ModuleElement_column(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(0) + col1 = e.column() + assert col1 == e.col and col1 is not e.col + col2 = e.column(domain=FF(5)) + assert col2.domain.is_FF + + +def test_ModuleElement_QQ_col(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 2, 3, 4]), denom=1) + f = A(to_col([3, 6, 9, 12]), denom=3) + assert e.QQ_col == f.QQ_col + + +def test_ModuleElement_to_ancestors(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = B.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + D = C.submodule_from_matrix(5 * DomainMatrix.eye(4, ZZ)) + eD = D(0) + eC = eD.to_parent() + eB = eD.to_ancestor(B) + eA = eD.over_power_basis() + assert eC.module is C and eC.coeffs == [5, 0, 0, 0] + assert eB.module is B and eB.coeffs == [15, 0, 0, 0] + assert eA.module is A and eA.coeffs == [30, 0, 0, 0] + + a = A(0) + raises(ValueError, lambda: a.to_parent()) + + +def test_ModuleElement_compatibility(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = B.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + D = B.submodule_from_matrix(5 * DomainMatrix.eye(4, ZZ)) + assert C(0).is_compat(C(1)) is True + assert C(0).is_compat(D(0)) is False + u, v = C(0).unify(D(0)) + assert u.module is B and v.module is B + assert C(C.represent(u)) == C(0) and D(D.represent(v)) == D(0) + + u, v = C(0).unify(C(1)) + assert u == C(0) and v == C(1) + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + raises(UnificationFailed, lambda: C(0).unify(Z(1))) + + +def test_ModuleElement_eq(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 2, 3, 4]), denom=1) + f = A(to_col([3, 6, 9, 12]), denom=3) + assert e == f + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + assert e != Z(0) + assert e != 3.14 + + +def test_ModuleElement_equiv(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 2, 3, 4]), denom=1) + f = A(to_col([3, 6, 9, 12]), denom=3) + assert e.equiv(f) + + C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + g = C(to_col([1, 2, 3, 4]), denom=1) + h = A(to_col([3, 6, 9, 12]), denom=1) + assert g.equiv(h) + assert C(to_col([5, 0, 0, 0]), denom=7).equiv(QQ(15, 7)) + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + raises(UnificationFailed, lambda: e.equiv(Z(0))) + + assert e.equiv(3.14) is False + + +def test_ModuleElement_add(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + e = A(to_col([1, 2, 3, 4]), denom=6) + f = A(to_col([5, 6, 7, 8]), denom=10) + g = C(to_col([1, 1, 1, 1]), denom=2) + assert e + f == A(to_col([10, 14, 18, 22]), denom=15) + assert e - f == A(to_col([-5, -4, -3, -2]), denom=15) + assert e + g == A(to_col([10, 11, 12, 13]), denom=6) + assert e + QQ(7, 10) == A(to_col([26, 10, 15, 20]), denom=30) + assert g + QQ(7, 10) == A(to_col([22, 15, 15, 15]), denom=10) + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + raises(TypeError, lambda: e + Z(0)) + raises(TypeError, lambda: e + 3.14) + + +def test_ModuleElement_mul(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + e = A(to_col([0, 2, 0, 0]), denom=3) + f = A(to_col([0, 0, 0, 7]), denom=5) + g = C(to_col([0, 0, 0, 1]), denom=2) + h = A(to_col([0, 0, 3, 1]), denom=7) + assert e * f == A(to_col([-14, -14, -14, -14]), denom=15) + assert e * g == A(to_col([-1, -1, -1, -1])) + assert e * h == A(to_col([-2, -2, -2, 4]), denom=21) + assert e * QQ(6, 5) == A(to_col([0, 4, 0, 0]), denom=5) + assert (g * QQ(10, 21)).equiv(A(to_col([0, 0, 0, 5]), denom=7)) + assert e // QQ(6, 5) == A(to_col([0, 5, 0, 0]), denom=9) + + U = Poly(cyclotomic_poly(7, x)) + Z = PowerBasis(U) + raises(TypeError, lambda: e * Z(0)) + raises(TypeError, lambda: e * 3.14) + raises(TypeError, lambda: e // 3.14) + raises(ZeroDivisionError, lambda: e // 0) + + +def test_ModuleElement_div(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + e = A(to_col([0, 2, 0, 0]), denom=3) + f = A(to_col([0, 0, 0, 7]), denom=5) + g = C(to_col([1, 1, 1, 1])) + assert e // f == 10*A(3)//21 + assert e // g == -2*A(2)//9 + assert 3 // g == -A(1) + + +def test_ModuleElement_pow(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + e = A(to_col([0, 2, 0, 0]), denom=3) + g = C(to_col([0, 0, 0, 1]), denom=2) + assert e ** 3 == A(to_col([0, 0, 0, 8]), denom=27) + assert g ** 2 == C(to_col([0, 3, 0, 0]), denom=4) + assert e ** 0 == A(to_col([1, 0, 0, 0])) + assert g ** 0 == A(to_col([1, 0, 0, 0])) + assert e ** 1 == e + assert g ** 1 == g + + +def test_ModuleElement_mod(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 15, 8, 0]), denom=2) + assert e % 7 == A(to_col([1, 1, 8, 0]), denom=2) + assert e % QQ(1, 2) == A.zero() + assert e % QQ(1, 3) == A(to_col([1, 1, 0, 0]), denom=6) + + B = A.submodule_from_gens([A(0), 5*A(1), 3*A(2), A(3)]) + assert e % B == A(to_col([1, 5, 2, 0]), denom=2) + + C = B.whole_submodule() + raises(TypeError, lambda: e % C) + + +def test_PowerBasisElement_polys(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 15, 8, 0]), denom=2) + assert e.numerator(x=zeta) == Poly(8 * zeta ** 2 + 15 * zeta + 1, domain=ZZ) + assert e.poly(x=zeta) == Poly(4 * zeta ** 2 + QQ(15, 2) * zeta + QQ(1, 2), domain=QQ) + + +def test_PowerBasisElement_norm(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + lam = A(to_col([1, -1, 0, 0])) + assert lam.norm() == 5 + + +def test_PowerBasisElement_inverse(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + e = A(to_col([1, 1, 1, 1])) + assert 2 // e == -2*A(1) + assert e ** -3 == -A(3) + + +def test_ModuleHomomorphism_matrix(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + phi = ModuleEndomorphism(A, lambda a: a ** 2) + M = phi.matrix() + assert M == DomainMatrix([ + [1, 0, -1, 0], + [0, 0, -1, 1], + [0, 1, -1, 0], + [0, 0, -1, 0] + ], (4, 4), ZZ) + + +def test_ModuleHomomorphism_kernel(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + phi = ModuleEndomorphism(A, lambda a: a ** 5) + N = phi.kernel() + assert N.n == 3 + + +def test_EndomorphismRing_represent(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + R = A.endomorphism_ring() + phi = R.inner_endomorphism(A(1)) + col = R.represent(phi) + assert col.transpose() == DomainMatrix([ + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1] + ], (1, 16), ZZ) + + B = A.submodule_from_matrix(DomainMatrix.zeros((4, 0), ZZ)) + S = B.endomorphism_ring() + psi = S.inner_endomorphism(A(1)) + col = S.represent(psi) + assert col == DomainMatrix([], (0, 0), ZZ) + + raises(NotImplementedError, lambda: R.represent(3.14)) + + +def test_find_min_poly(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + powers = [] + m = find_min_poly(A(1), QQ, x=x, powers=powers) + assert m == Poly(T, domain=QQ) + assert len(powers) == 5 + + # powers list need not be passed + m = find_min_poly(A(1), QQ, x=x) + assert m == Poly(T, domain=QQ) + + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + raises(MissingUnityError, lambda: find_min_poly(B(1), QQ)) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_numbers.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f350719cc740901a29d03e45ae9f3978446f31 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_numbers.py @@ -0,0 +1,202 @@ +"""Tests on algebraic numbers. """ + +from sympy.core.containers import Tuple +from sympy.core.numbers import (AlgebraicNumber, I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.polytools import Poly +from sympy.polys.numberfields.subfield import to_number_field +from sympy.polys.polyclasses import DMP +from sympy.polys.domains import QQ +from sympy.polys.rootoftools import CRootOf +from sympy.abc import x, y + + +def test_AlgebraicNumber(): + minpoly, root = x**2 - 2, sqrt(2) + + a = AlgebraicNumber(root, gen=x) + + assert a.rep == DMP([QQ(1), QQ(0)], QQ) + assert a.root == root + assert a.alias is None + assert a.minpoly == minpoly + assert a.is_number + + assert a.is_aliased is False + + assert a.coeffs() == [S.One, S.Zero] + assert a.native_coeffs() == [QQ(1), QQ(0)] + + a = AlgebraicNumber(root, gen=x, alias='y') + + assert a.rep == DMP([QQ(1), QQ(0)], QQ) + assert a.root == root + assert a.alias == Symbol('y') + assert a.minpoly == minpoly + assert a.is_number + + assert a.is_aliased is True + + a = AlgebraicNumber(root, gen=x, alias=Symbol('y')) + + assert a.rep == DMP([QQ(1), QQ(0)], QQ) + assert a.root == root + assert a.alias == Symbol('y') + assert a.minpoly == minpoly + assert a.is_number + + assert a.is_aliased is True + + assert AlgebraicNumber(sqrt(2), []).rep == DMP([], QQ) + assert AlgebraicNumber(sqrt(2), ()).rep == DMP([], QQ) + assert AlgebraicNumber(sqrt(2), (0, 0)).rep == DMP([], QQ) + + assert AlgebraicNumber(sqrt(2), [8]).rep == DMP([QQ(8)], QQ) + assert AlgebraicNumber(sqrt(2), [Rational(8, 3)]).rep == DMP([QQ(8, 3)], QQ) + + assert AlgebraicNumber(sqrt(2), [7, 3]).rep == DMP([QQ(7), QQ(3)], QQ) + assert AlgebraicNumber( + sqrt(2), [Rational(7, 9), Rational(3, 2)]).rep == DMP([QQ(7, 9), QQ(3, 2)], QQ) + + assert AlgebraicNumber(sqrt(2), [1, 2, 3]).rep == DMP([QQ(2), QQ(5)], QQ) + + a = AlgebraicNumber(AlgebraicNumber(root, gen=x), [1, 2]) + + assert a.rep == DMP([QQ(1), QQ(2)], QQ) + assert a.root == root + assert a.alias is None + assert a.minpoly == minpoly + assert a.is_number + + assert a.is_aliased is False + + assert a.coeffs() == [S.One, S(2)] + assert a.native_coeffs() == [QQ(1), QQ(2)] + + a = AlgebraicNumber((minpoly, root), [1, 2]) + + assert a.rep == DMP([QQ(1), QQ(2)], QQ) + assert a.root == root + assert a.alias is None + assert a.minpoly == minpoly + assert a.is_number + + assert a.is_aliased is False + + a = AlgebraicNumber((Poly(minpoly), root), [1, 2]) + + assert a.rep == DMP([QQ(1), QQ(2)], QQ) + assert a.root == root + assert a.alias is None + assert a.minpoly == minpoly + assert a.is_number + + assert a.is_aliased is False + + assert AlgebraicNumber( sqrt(3)).rep == DMP([ QQ(1), QQ(0)], QQ) + assert AlgebraicNumber(-sqrt(3)).rep == DMP([ QQ(1), QQ(0)], QQ) + + a = AlgebraicNumber(sqrt(2)) + b = AlgebraicNumber(sqrt(2)) + + assert a == b + + c = AlgebraicNumber(sqrt(2), gen=x) + + assert a == b + assert a == c + + a = AlgebraicNumber(sqrt(2), [1, 2]) + b = AlgebraicNumber(sqrt(2), [1, 3]) + + assert a != b and a != sqrt(2) + 3 + + assert (a == x) is False and (a != x) is True + + a = AlgebraicNumber(sqrt(2), [1, 0]) + b = AlgebraicNumber(sqrt(2), [1, 0], alias=y) + + assert a.as_poly(x) == Poly(x, domain='QQ') + assert b.as_poly() == Poly(y, domain='QQ') + + assert a.as_expr() == sqrt(2) + assert a.as_expr(x) == x + assert b.as_expr() == sqrt(2) + assert b.as_expr(x) == x + + a = AlgebraicNumber(sqrt(2), [2, 3]) + b = AlgebraicNumber(sqrt(2), [2, 3], alias=y) + + p = a.as_poly() + + assert p == Poly(2*p.gen + 3) + + assert a.as_poly(x) == Poly(2*x + 3, domain='QQ') + assert b.as_poly() == Poly(2*y + 3, domain='QQ') + + assert a.as_expr() == 2*sqrt(2) + 3 + assert a.as_expr(x) == 2*x + 3 + assert b.as_expr() == 2*sqrt(2) + 3 + assert b.as_expr(x) == 2*x + 3 + + a = AlgebraicNumber(sqrt(2)) + b = to_number_field(sqrt(2)) + assert a.args == b.args == (sqrt(2), Tuple(1, 0)) + b = AlgebraicNumber(sqrt(2), alias='alpha') + assert b.args == (sqrt(2), Tuple(1, 0), Symbol('alpha')) + + a = AlgebraicNumber(sqrt(2), [1, 2, 3]) + assert a.args == (sqrt(2), Tuple(1, 2, 3)) + + a = AlgebraicNumber(sqrt(2), [1, 2], "alpha") + b = AlgebraicNumber(a) + c = AlgebraicNumber(a, alias="gamma") + assert a == b + assert c.alias.name == "gamma" + + a = AlgebraicNumber(sqrt(2) + sqrt(3), [S(1)/2, 0, S(-9)/2, 0]) + b = AlgebraicNumber(a, [1, 0, 0]) + assert b.root == a.root + assert a.to_root() == sqrt(2) + assert b.to_root() == 2 + + a = AlgebraicNumber(2) + assert a.is_primitive_element is True + + +def test_to_algebraic_integer(): + a = AlgebraicNumber(sqrt(3), gen=x).to_algebraic_integer() + + assert a.minpoly == x**2 - 3 + assert a.root == sqrt(3) + assert a.rep == DMP([QQ(1), QQ(0)], QQ) + + a = AlgebraicNumber(2*sqrt(3), gen=x).to_algebraic_integer() + assert a.minpoly == x**2 - 12 + assert a.root == 2*sqrt(3) + assert a.rep == DMP([QQ(1), QQ(0)], QQ) + + a = AlgebraicNumber(sqrt(3)/2, gen=x).to_algebraic_integer() + + assert a.minpoly == x**2 - 12 + assert a.root == 2*sqrt(3) + assert a.rep == DMP([QQ(1), QQ(0)], QQ) + + a = AlgebraicNumber(sqrt(3)/2, [Rational(7, 19), 3], gen=x).to_algebraic_integer() + + assert a.minpoly == x**2 - 12 + assert a.root == 2*sqrt(3) + assert a.rep == DMP([QQ(7, 19), QQ(3)], QQ) + + +def test_AlgebraicNumber_to_root(): + assert AlgebraicNumber(sqrt(2)).to_root() == sqrt(2) + + zeta5_squared = AlgebraicNumber(CRootOf(x**5 - 1, 4), coeffs=[1, 0, 0]) + assert zeta5_squared.to_root() == CRootOf(x**4 + x**3 + x**2 + x + 1, 1) + + zeta3_squared = AlgebraicNumber(CRootOf(x**3 - 1, 2), coeffs=[1, 0, 0]) + assert zeta3_squared.to_root() == -S(1)/2 - sqrt(3)*I/2 + assert zeta3_squared.to_root(radicals=False) == CRootOf(x**2 + x + 1, 0) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_primes.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_primes.py new file mode 100644 index 0000000000000000000000000000000000000000..f121d60d272fe65345de773748828a8a67eb0028 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_primes.py @@ -0,0 +1,296 @@ +from math import prod + +from sympy import QQ, ZZ +from sympy.abc import x, theta +from sympy.ntheory import factorint +from sympy.ntheory.residue_ntheory import n_order +from sympy.polys import Poly, cyclotomic_poly +from sympy.polys.matrices import DomainMatrix +from sympy.polys.numberfields.basis import round_two +from sympy.polys.numberfields.exceptions import StructureError +from sympy.polys.numberfields.modules import PowerBasis, to_col +from sympy.polys.numberfields.primes import ( + prime_decomp, _two_elt_rep, + _check_formal_conditions_for_maximal_order, +) +from sympy.testing.pytest import raises + + +def test_check_formal_conditions_for_maximal_order(): + T = Poly(cyclotomic_poly(5, x)) + A = PowerBasis(T) + B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + C = B.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) + D = A.submodule_from_matrix(DomainMatrix.eye(4, ZZ)[:, :-1]) + # Is a direct submodule of a power basis, but lacks 1 as first generator: + raises(StructureError, lambda: _check_formal_conditions_for_maximal_order(B)) + # Is not a direct submodule of a power basis: + raises(StructureError, lambda: _check_formal_conditions_for_maximal_order(C)) + # Is direct submod of pow basis, and starts with 1, but not sq/max rank/HNF: + raises(StructureError, lambda: _check_formal_conditions_for_maximal_order(D)) + + +def test_two_elt_rep(): + ell = 7 + T = Poly(cyclotomic_poly(ell)) + ZK, dK = round_two(T) + for p in [29, 13, 11, 5]: + P = prime_decomp(p, T) + for Pi in P: + # We have Pi in two-element representation, and, because we are + # looking at a cyclotomic field, this was computed by the "easy" + # method that just factors T mod p. We will now convert this to + # a set of Z-generators, then convert that back into a two-element + # rep. The latter need not be identical to the two-elt rep we + # already have, but it must have the same HNF. + H = p*ZK + Pi.alpha*ZK + gens = H.basis_element_pullbacks() + # Note: we could supply f = Pi.f, but prefer to test behavior without it. + b = _two_elt_rep(gens, ZK, p) + if b != Pi.alpha: + H2 = p*ZK + b*ZK + assert H2 == H + + +def test_valuation_at_prime_ideal(): + p = 7 + T = Poly(cyclotomic_poly(p)) + ZK, dK = round_two(T) + P = prime_decomp(p, T, dK=dK, ZK=ZK) + assert len(P) == 1 + P0 = P[0] + v = P0.valuation(p*ZK) + assert v == P0.e + # Test easy 0 case: + assert P0.valuation(5*ZK) == 0 + + +def test_decomp_1(): + # All prime decompositions in cyclotomic fields are in the "easy case," + # since the index is unity. + # Here we check the ramified prime. + T = Poly(cyclotomic_poly(7)) + raises(ValueError, lambda: prime_decomp(7)) + P = prime_decomp(7, T) + assert len(P) == 1 + P0 = P[0] + assert P0.e == 6 + assert P0.f == 1 + # Test powers: + assert P0**0 == P0.ZK + assert P0**1 == P0 + assert P0**6 == 7 * P0.ZK + + +def test_decomp_2(): + # More easy cyclotomic cases, but here we check unramified primes. + ell = 7 + T = Poly(cyclotomic_poly(ell)) + for p in [29, 13, 11, 5]: + f_exp = n_order(p, ell) + g_exp = (ell - 1) // f_exp + P = prime_decomp(p, T) + assert len(P) == g_exp + for Pi in P: + assert Pi.e == 1 + assert Pi.f == f_exp + + +def test_decomp_3(): + T = Poly(x ** 2 - 35) + rad = {} + ZK, dK = round_two(T, radicals=rad) + # 35 is 3 mod 4, so field disc is 4*5*7, and theory says each of the + # rational primes 2, 5, 7 should be the square of a prime ideal. + for p in [2, 5, 7]: + P = prime_decomp(p, T, dK=dK, ZK=ZK, radical=rad.get(p)) + assert len(P) == 1 + assert P[0].e == 2 + assert P[0]**2 == p*ZK + + +def test_decomp_4(): + T = Poly(x ** 2 - 21) + rad = {} + ZK, dK = round_two(T, radicals=rad) + # 21 is 1 mod 4, so field disc is 3*7, and theory says the + # rational primes 3, 7 should be the square of a prime ideal. + for p in [3, 7]: + P = prime_decomp(p, T, dK=dK, ZK=ZK, radical=rad.get(p)) + assert len(P) == 1 + assert P[0].e == 2 + assert P[0]**2 == p*ZK + + +def test_decomp_5(): + # Here is our first test of the "hard case" of prime decomposition. + # We work in a quadratic extension Q(sqrt(d)) where d is 1 mod 4, and + # we consider the factorization of the rational prime 2, which divides + # the index. + # Theory says the form of p's factorization depends on the residue of + # d mod 8, so we consider both cases, d = 1 mod 8 and d = 5 mod 8. + for d in [-7, -3]: + T = Poly(x ** 2 - d) + rad = {} + ZK, dK = round_two(T, radicals=rad) + p = 2 + P = prime_decomp(p, T, dK=dK, ZK=ZK, radical=rad.get(p)) + if d % 8 == 1: + assert len(P) == 2 + assert all(P[i].e == 1 and P[i].f == 1 for i in range(2)) + assert prod(Pi**Pi.e for Pi in P) == p * ZK + else: + assert d % 8 == 5 + assert len(P) == 1 + assert P[0].e == 1 + assert P[0].f == 2 + assert P[0].as_submodule() == p * ZK + + +def test_decomp_6(): + # Another case where 2 divides the index. This is Dedekind's example of + # an essential discriminant divisor. (See Cohen, Exercise 6.10.) + T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + rad = {} + ZK, dK = round_two(T, radicals=rad) + p = 2 + P = prime_decomp(p, T, dK=dK, ZK=ZK, radical=rad.get(p)) + assert len(P) == 3 + assert all(Pi.e == Pi.f == 1 for Pi in P) + assert prod(Pi**Pi.e for Pi in P) == p*ZK + + +def test_decomp_7(): + # Try working through an AlgebraicField + T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + K = QQ.alg_field_from_poly(T) + p = 2 + P = K.primes_above(p) + ZK = K.maximal_order() + assert len(P) == 3 + assert all(Pi.e == Pi.f == 1 for Pi in P) + assert prod(Pi**Pi.e for Pi in P) == p*ZK + + +def test_decomp_8(): + # This time we consider various cubics, and try factoring all primes + # dividing the index. + cases = ( + x ** 3 + 3 * x ** 2 - 4 * x + 4, + x ** 3 + 3 * x ** 2 + 3 * x - 3, + x ** 3 + 5 * x ** 2 - x + 3, + x ** 3 + 5 * x ** 2 - 5 * x - 5, + x ** 3 + 3 * x ** 2 + 5, + x ** 3 + 6 * x ** 2 + 3 * x - 1, + x ** 3 + 6 * x ** 2 + 4, + x ** 3 + 7 * x ** 2 + 7 * x - 7, + x ** 3 + 7 * x ** 2 - x + 5, + x ** 3 + 7 * x ** 2 - 5 * x + 5, + x ** 3 + 4 * x ** 2 - 3 * x + 7, + x ** 3 + 8 * x ** 2 + 5 * x - 1, + x ** 3 + 8 * x ** 2 - 2 * x + 6, + x ** 3 + 6 * x ** 2 - 3 * x + 8, + x ** 3 + 9 * x ** 2 + 6 * x - 8, + x ** 3 + 15 * x ** 2 - 9 * x + 13, + ) + def display(T, p, radical, P, I, J): + """Useful for inspection, when running test manually.""" + print('=' * 20) + print(T, p, radical) + for Pi in P: + print(f' ({Pi!r})') + print("I: ", I) + print("J: ", J) + print(f'Equal: {I == J}') + inspect = False + for g in cases: + T = Poly(g) + rad = {} + ZK, dK = round_two(T, radicals=rad) + dT = T.discriminant() + f_squared = dT // dK + F = factorint(f_squared) + for p in F: + radical = rad.get(p) + P = prime_decomp(p, T, dK=dK, ZK=ZK, radical=radical) + I = prod(Pi**Pi.e for Pi in P) + J = p * ZK + if inspect: + display(T, p, radical, P, I, J) + assert I == J + + +def test_PrimeIdeal_eq(): + # `==` should fail on objects of different types, so even a completely + # inert PrimeIdeal should test unequal to the rational prime it divides. + T = Poly(cyclotomic_poly(7)) + P0 = prime_decomp(5, T)[0] + assert P0.f == 6 + assert P0.as_submodule() == 5 * P0.ZK + assert P0 != 5 + + +def test_PrimeIdeal_add(): + T = Poly(cyclotomic_poly(7)) + P0 = prime_decomp(7, T)[0] + # Adding ideals computes their GCD, so adding the ramified prime dividing + # 7 to 7 itself should reproduce this prime (as a submodule). + assert P0 + 7 * P0.ZK == P0.as_submodule() + + +def test_str(): + # Without alias: + k = QQ.alg_field_from_poly(Poly(x**2 + 7)) + frp = k.primes_above(2)[0] + assert str(frp) == '(2, 3*_x/2 + 1/2)' + + frp = k.primes_above(3)[0] + assert str(frp) == '(3)' + + # With alias: + k = QQ.alg_field_from_poly(Poly(x ** 2 + 7), alias='alpha') + frp = k.primes_above(2)[0] + assert str(frp) == '(2, 3*alpha/2 + 1/2)' + + frp = k.primes_above(3)[0] + assert str(frp) == '(3)' + + +def test_repr(): + T = Poly(x**2 + 7) + ZK, dK = round_two(T) + P = prime_decomp(2, T, dK=dK, ZK=ZK) + assert repr(P[0]) == '[ (2, (3*x + 1)/2) e=1, f=1 ]' + assert P[0].repr(field_gen=theta) == '[ (2, (3*theta + 1)/2) e=1, f=1 ]' + assert P[0].repr(field_gen=theta, just_gens=True) == '(2, (3*theta + 1)/2)' + + +def test_PrimeIdeal_reduce(): + k = QQ.alg_field_from_poly(Poly(x ** 3 + x ** 2 - 2 * x + 8)) + Zk = k.maximal_order() + P = k.primes_above(2) + frp = P[2] + + # reduce_element + a = Zk.parent(to_col([23, 20, 11]), denom=6) + a_bar_expected = Zk.parent(to_col([11, 5, 2]), denom=6) + a_bar = frp.reduce_element(a) + assert a_bar == a_bar_expected + + # reduce_ANP + a = k([QQ(11, 6), QQ(20, 6), QQ(23, 6)]) + a_bar_expected = k([QQ(2, 6), QQ(5, 6), QQ(11, 6)]) + a_bar = frp.reduce_ANP(a) + assert a_bar == a_bar_expected + + # reduce_alg_num + a = k.to_alg_num(a) + a_bar_expected = k.to_alg_num(a_bar_expected) + a_bar = frp.reduce_alg_num(a) + assert a_bar == a_bar_expected + + +def test_issue_23402(): + k = QQ.alg_field_from_poly(Poly(x ** 3 + x ** 2 - 2 * x + 8)) + P = k.primes_above(3) + assert P[0].alpha.equiv(0) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_subfield.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_subfield.py new file mode 100644 index 0000000000000000000000000000000000000000..b152dd684aa20034f9233eedb1866aac2639b5f9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_subfield.py @@ -0,0 +1,317 @@ +"""Tests for the subfield problem and allied problems. """ + +from sympy.core.numbers import (AlgebraicNumber, I, pi, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.external.gmpy import MPQ +from sympy.polys.numberfields.subfield import ( + is_isomorphism_possible, + field_isomorphism_pslq, + field_isomorphism, + primitive_element, + to_number_field, +) +from sympy.polys.domains import QQ +from sympy.polys.polyerrors import IsomorphismFailed +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.testing.pytest import raises + +from sympy.abc import x + +Q = Rational + + +def test_field_isomorphism_pslq(): + a = AlgebraicNumber(I) + b = AlgebraicNumber(I*sqrt(3)) + + raises(NotImplementedError, lambda: field_isomorphism_pslq(a, b)) + + a = AlgebraicNumber(sqrt(2)) + b = AlgebraicNumber(sqrt(3)) + c = AlgebraicNumber(sqrt(7)) + d = AlgebraicNumber(sqrt(2) + sqrt(3)) + e = AlgebraicNumber(sqrt(2) + sqrt(3) + sqrt(7)) + + assert field_isomorphism_pslq(a, a) == [1, 0] + assert field_isomorphism_pslq(a, b) is None + assert field_isomorphism_pslq(a, c) is None + assert field_isomorphism_pslq(a, d) == [Q(1, 2), 0, -Q(9, 2), 0] + assert field_isomorphism_pslq( + a, e) == [Q(1, 80), 0, -Q(1, 2), 0, Q(59, 20), 0] + + assert field_isomorphism_pslq(b, a) is None + assert field_isomorphism_pslq(b, b) == [1, 0] + assert field_isomorphism_pslq(b, c) is None + assert field_isomorphism_pslq(b, d) == [-Q(1, 2), 0, Q(11, 2), 0] + assert field_isomorphism_pslq(b, e) == [-Q( + 3, 640), 0, Q(67, 320), 0, -Q(297, 160), 0, Q(313, 80), 0] + + assert field_isomorphism_pslq(c, a) is None + assert field_isomorphism_pslq(c, b) is None + assert field_isomorphism_pslq(c, c) == [1, 0] + assert field_isomorphism_pslq(c, d) is None + assert field_isomorphism_pslq(c, e) == [Q( + 3, 640), 0, -Q(71, 320), 0, Q(377, 160), 0, -Q(469, 80), 0] + + assert field_isomorphism_pslq(d, a) is None + assert field_isomorphism_pslq(d, b) is None + assert field_isomorphism_pslq(d, c) is None + assert field_isomorphism_pslq(d, d) == [1, 0] + assert field_isomorphism_pslq(d, e) == [-Q( + 3, 640), 0, Q(71, 320), 0, -Q(377, 160), 0, Q(549, 80), 0] + + assert field_isomorphism_pslq(e, a) is None + assert field_isomorphism_pslq(e, b) is None + assert field_isomorphism_pslq(e, c) is None + assert field_isomorphism_pslq(e, d) is None + assert field_isomorphism_pslq(e, e) == [1, 0] + + f = AlgebraicNumber(3*sqrt(2) + 8*sqrt(7) - 5) + + assert field_isomorphism_pslq( + f, e) == [Q(3, 80), 0, -Q(139, 80), 0, Q(347, 20), 0, -Q(761, 20), -5] + + +def test_field_isomorphism(): + assert field_isomorphism(3, sqrt(2)) == [3] + + assert field_isomorphism( I*sqrt(3), I*sqrt(3)/2) == [ 2, 0] + assert field_isomorphism(-I*sqrt(3), I*sqrt(3)/2) == [-2, 0] + + assert field_isomorphism( I*sqrt(3), -I*sqrt(3)/2) == [-2, 0] + assert field_isomorphism(-I*sqrt(3), -I*sqrt(3)/2) == [ 2, 0] + + assert field_isomorphism( 2*I*sqrt(3)/7, 5*I*sqrt(3)/3) == [ Rational(6, 35), 0] + assert field_isomorphism(-2*I*sqrt(3)/7, 5*I*sqrt(3)/3) == [Rational(-6, 35), 0] + + assert field_isomorphism( 2*I*sqrt(3)/7, -5*I*sqrt(3)/3) == [Rational(-6, 35), 0] + assert field_isomorphism(-2*I*sqrt(3)/7, -5*I*sqrt(3)/3) == [ Rational(6, 35), 0] + + assert field_isomorphism( + 2*I*sqrt(3)/7 + 27, 5*I*sqrt(3)/3) == [ Rational(6, 35), 27] + assert field_isomorphism( + -2*I*sqrt(3)/7 + 27, 5*I*sqrt(3)/3) == [Rational(-6, 35), 27] + + assert field_isomorphism( + 2*I*sqrt(3)/7 + 27, -5*I*sqrt(3)/3) == [Rational(-6, 35), 27] + assert field_isomorphism( + -2*I*sqrt(3)/7 + 27, -5*I*sqrt(3)/3) == [ Rational(6, 35), 27] + + p = AlgebraicNumber( sqrt(2) + sqrt(3)) + q = AlgebraicNumber(-sqrt(2) + sqrt(3)) + r = AlgebraicNumber( sqrt(2) - sqrt(3)) + s = AlgebraicNumber(-sqrt(2) - sqrt(3)) + + pos_coeffs = [ S.Half, S.Zero, Rational(-9, 2), S.Zero] + neg_coeffs = [Rational(-1, 2), S.Zero, Rational(9, 2), S.Zero] + + a = AlgebraicNumber(sqrt(2)) + + assert is_isomorphism_possible(a, p) is True + assert is_isomorphism_possible(a, q) is True + assert is_isomorphism_possible(a, r) is True + assert is_isomorphism_possible(a, s) is True + + assert field_isomorphism(a, p, fast=True) == pos_coeffs + assert field_isomorphism(a, q, fast=True) == neg_coeffs + assert field_isomorphism(a, r, fast=True) == pos_coeffs + assert field_isomorphism(a, s, fast=True) == neg_coeffs + + assert field_isomorphism(a, p, fast=False) == pos_coeffs + assert field_isomorphism(a, q, fast=False) == neg_coeffs + assert field_isomorphism(a, r, fast=False) == pos_coeffs + assert field_isomorphism(a, s, fast=False) == neg_coeffs + + a = AlgebraicNumber(-sqrt(2)) + + assert is_isomorphism_possible(a, p) is True + assert is_isomorphism_possible(a, q) is True + assert is_isomorphism_possible(a, r) is True + assert is_isomorphism_possible(a, s) is True + + assert field_isomorphism(a, p, fast=True) == neg_coeffs + assert field_isomorphism(a, q, fast=True) == pos_coeffs + assert field_isomorphism(a, r, fast=True) == neg_coeffs + assert field_isomorphism(a, s, fast=True) == pos_coeffs + + assert field_isomorphism(a, p, fast=False) == neg_coeffs + assert field_isomorphism(a, q, fast=False) == pos_coeffs + assert field_isomorphism(a, r, fast=False) == neg_coeffs + assert field_isomorphism(a, s, fast=False) == pos_coeffs + + pos_coeffs = [ S.Half, S.Zero, Rational(-11, 2), S.Zero] + neg_coeffs = [Rational(-1, 2), S.Zero, Rational(11, 2), S.Zero] + + a = AlgebraicNumber(sqrt(3)) + + assert is_isomorphism_possible(a, p) is True + assert is_isomorphism_possible(a, q) is True + assert is_isomorphism_possible(a, r) is True + assert is_isomorphism_possible(a, s) is True + + assert field_isomorphism(a, p, fast=True) == neg_coeffs + assert field_isomorphism(a, q, fast=True) == neg_coeffs + assert field_isomorphism(a, r, fast=True) == pos_coeffs + assert field_isomorphism(a, s, fast=True) == pos_coeffs + + assert field_isomorphism(a, p, fast=False) == neg_coeffs + assert field_isomorphism(a, q, fast=False) == neg_coeffs + assert field_isomorphism(a, r, fast=False) == pos_coeffs + assert field_isomorphism(a, s, fast=False) == pos_coeffs + + a = AlgebraicNumber(-sqrt(3)) + + assert is_isomorphism_possible(a, p) is True + assert is_isomorphism_possible(a, q) is True + assert is_isomorphism_possible(a, r) is True + assert is_isomorphism_possible(a, s) is True + + assert field_isomorphism(a, p, fast=True) == pos_coeffs + assert field_isomorphism(a, q, fast=True) == pos_coeffs + assert field_isomorphism(a, r, fast=True) == neg_coeffs + assert field_isomorphism(a, s, fast=True) == neg_coeffs + + assert field_isomorphism(a, p, fast=False) == pos_coeffs + assert field_isomorphism(a, q, fast=False) == pos_coeffs + assert field_isomorphism(a, r, fast=False) == neg_coeffs + assert field_isomorphism(a, s, fast=False) == neg_coeffs + + pos_coeffs = [ Rational(3, 2), S.Zero, Rational(-33, 2), -S(8)] + neg_coeffs = [Rational(-3, 2), S.Zero, Rational(33, 2), -S(8)] + + a = AlgebraicNumber(3*sqrt(3) - 8) + + assert is_isomorphism_possible(a, p) is True + assert is_isomorphism_possible(a, q) is True + assert is_isomorphism_possible(a, r) is True + assert is_isomorphism_possible(a, s) is True + + assert field_isomorphism(a, p, fast=True) == neg_coeffs + assert field_isomorphism(a, q, fast=True) == neg_coeffs + assert field_isomorphism(a, r, fast=True) == pos_coeffs + assert field_isomorphism(a, s, fast=True) == pos_coeffs + + assert field_isomorphism(a, p, fast=False) == neg_coeffs + assert field_isomorphism(a, q, fast=False) == neg_coeffs + assert field_isomorphism(a, r, fast=False) == pos_coeffs + assert field_isomorphism(a, s, fast=False) == pos_coeffs + + a = AlgebraicNumber(3*sqrt(2) + 2*sqrt(3) + 1) + + pos_1_coeffs = [ S.Half, S.Zero, Rational(-5, 2), S.One] + neg_5_coeffs = [Rational(-5, 2), S.Zero, Rational(49, 2), S.One] + pos_5_coeffs = [ Rational(5, 2), S.Zero, Rational(-49, 2), S.One] + neg_1_coeffs = [Rational(-1, 2), S.Zero, Rational(5, 2), S.One] + + assert is_isomorphism_possible(a, p) is True + assert is_isomorphism_possible(a, q) is True + assert is_isomorphism_possible(a, r) is True + assert is_isomorphism_possible(a, s) is True + + assert field_isomorphism(a, p, fast=True) == pos_1_coeffs + assert field_isomorphism(a, q, fast=True) == neg_5_coeffs + assert field_isomorphism(a, r, fast=True) == pos_5_coeffs + assert field_isomorphism(a, s, fast=True) == neg_1_coeffs + + assert field_isomorphism(a, p, fast=False) == pos_1_coeffs + assert field_isomorphism(a, q, fast=False) == neg_5_coeffs + assert field_isomorphism(a, r, fast=False) == pos_5_coeffs + assert field_isomorphism(a, s, fast=False) == neg_1_coeffs + + a = AlgebraicNumber(sqrt(2)) + b = AlgebraicNumber(sqrt(3)) + c = AlgebraicNumber(sqrt(7)) + + assert is_isomorphism_possible(a, b) is True + assert is_isomorphism_possible(b, a) is True + + assert is_isomorphism_possible(c, p) is False + + assert field_isomorphism(sqrt(2), sqrt(3), fast=True) is None + assert field_isomorphism(sqrt(3), sqrt(2), fast=True) is None + + assert field_isomorphism(sqrt(2), sqrt(3), fast=False) is None + assert field_isomorphism(sqrt(3), sqrt(2), fast=False) is None + + a = AlgebraicNumber(sqrt(2)) + b = AlgebraicNumber(2 ** (S(1) / 3)) + + assert is_isomorphism_possible(a, b) is False + assert field_isomorphism(a, b) is None + + +def test_primitive_element(): + assert primitive_element([sqrt(2)], x) == (x**2 - 2, [1]) + assert primitive_element( + [sqrt(2), sqrt(3)], x) == (x**4 - 10*x**2 + 1, [1, 1]) + + assert primitive_element([sqrt(2)], x, polys=True) == (Poly(x**2 - 2, domain='QQ'), [1]) + assert primitive_element([sqrt( + 2), sqrt(3)], x, polys=True) == (Poly(x**4 - 10*x**2 + 1, domain='QQ'), [1, 1]) + + assert primitive_element( + [sqrt(2)], x, ex=True) == (x**2 - 2, [1], [[1, 0]]) + assert primitive_element([sqrt(2), sqrt(3)], x, ex=True) == \ + (x**4 - 10*x**2 + 1, [1, 1], [[Q(1, 2), 0, -Q(9, 2), 0], [- + Q(1, 2), 0, Q(11, 2), 0]]) + + assert primitive_element( + [sqrt(2)], x, ex=True, polys=True) == (Poly(x**2 - 2, domain='QQ'), [1], [[1, 0]]) + assert primitive_element([sqrt(2), sqrt(3)], x, ex=True, polys=True) == \ + (Poly(x**4 - 10*x**2 + 1, domain='QQ'), [1, 1], [[Q(1, 2), 0, -Q(9, 2), + 0], [-Q(1, 2), 0, Q(11, 2), 0]]) + + assert primitive_element([sqrt(2)], polys=True) == (Poly(x**2 - 2), [1]) + + raises(ValueError, lambda: primitive_element([], x, ex=False)) + raises(ValueError, lambda: primitive_element([], x, ex=True)) + + # Issue 14117 + a, b = I*sqrt(2*sqrt(2) + 3), I*sqrt(-2*sqrt(2) + 3) + assert primitive_element([a, b, I], x) == (x**4 + 6*x**2 + 1, [1, 0, 0]) + + assert primitive_element([sqrt(2), 0], x) == (x**2 - 2, [1, 0]) + assert primitive_element([0, sqrt(2)], x) == (x**2 - 2, [1, 1]) + assert primitive_element([sqrt(2), 0], x, ex=True) == (x**2 - 2, [1, 0], [[MPQ(1,1), MPQ(0,1)], []]) + assert primitive_element([0, sqrt(2)], x, ex=True) == (x**2 - 2, [1, 1], [[], [MPQ(1,1), MPQ(0,1)]]) + + +def test_to_number_field(): + assert to_number_field(sqrt(2)) == AlgebraicNumber(sqrt(2)) + assert to_number_field( + [sqrt(2), sqrt(3)]) == AlgebraicNumber(sqrt(2) + sqrt(3)) + + a = AlgebraicNumber(sqrt(2) + sqrt(3), [S.Half, S.Zero, Rational(-9, 2), S.Zero]) + + assert to_number_field(sqrt(2), sqrt(2) + sqrt(3)) == a + assert to_number_field(sqrt(2), AlgebraicNumber(sqrt(2) + sqrt(3))) == a + + raises(IsomorphismFailed, lambda: to_number_field(sqrt(2), sqrt(3))) + + +def test_issue_22561(): + a = to_number_field(sqrt(2), sqrt(2) + sqrt(3)) + b = to_number_field(sqrt(2), sqrt(2) + sqrt(5)) + assert field_isomorphism(a, b) == [1, 0] + + +def test_issue_22736(): + a = CRootOf(x**4 + x**3 + x**2 + x + 1, -1) + a._reset() + b = exp(2*I*pi/5) + assert field_isomorphism(a, b) == [1, 0] + + +def test_issue_27798(): + # https://github.com/sympy/sympy/issues/27798 + a, b = CRootOf(49*x**3 - 49*x**2 + 14*x - 1, 2), CRootOf(49*x**3 - 49*x**2 + 14*x - 1, 0) + assert primitive_element([a, b], polys=True)[0].primitive()[0] == 1 + assert primitive_element([a, b], polys=True, ex=True)[0].primitive()[0] == 1 + + f1, f2 = QQ.algebraic_field(a), QQ.algebraic_field(b) + f3 = f1.unify(f2) + assert f3.mod.primitive()[0] == 1 + assert Poly(x, x, domain=f1) + Poly(x, x, domain=f2) == Poly(2*x, x, domain=f3) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_utilities.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..134853ef0c88045ef9cc7e215bb98db37041e63a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/tests/test_utilities.py @@ -0,0 +1,113 @@ +from sympy.abc import x +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys import Poly, cyclotomic_poly +from sympy.polys.domains import FF, QQ +from sympy.polys.matrices import DomainMatrix, DM +from sympy.polys.matrices.exceptions import DMRankError +from sympy.polys.numberfields.utilities import ( + AlgIntPowers, coeff_search, extract_fundamental_discriminant, + isolate, supplement_a_subspace, +) +from sympy.printing.lambdarepr import IntervalPrinter +from sympy.testing.pytest import raises + + +def test_AlgIntPowers_01(): + T = Poly(cyclotomic_poly(5)) + zeta_pow = AlgIntPowers(T) + raises(ValueError, lambda: zeta_pow[-1]) + for e in range(10): + a = e % 5 + if a < 4: + c = zeta_pow[e] + assert c[a] == 1 and all(c[i] == 0 for i in range(4) if i != a) + else: + assert zeta_pow[e] == [-1] * 4 + + +def test_AlgIntPowers_02(): + T = Poly(x**3 + 2*x**2 + 3*x + 4) + m = 7 + theta_pow = AlgIntPowers(T, m) + for e in range(10): + computed = theta_pow[e] + coeffs = (Poly(x)**e % T + Poly(x**3)).rep.to_list()[1:] + expected = [c % m for c in reversed(coeffs)] + assert computed == expected + + +def test_coeff_search(): + C = [] + search = coeff_search(2, 1) + for i, c in enumerate(search): + C.append(c) + if i == 12: + break + assert C == [[1, 1], [1, 0], [1, -1], [0, 1], [2, 2], [2, 1], [2, 0], [2, -1], [2, -2], [1, 2], [1, -2], [0, 2], [3, 3]] + + +def test_extract_fundamental_discriminant(): + # To extract, integer must be 0 or 1 mod 4. + raises(ValueError, lambda: extract_fundamental_discriminant(2)) + raises(ValueError, lambda: extract_fundamental_discriminant(3)) + # Try many cases, of different forms: + cases = ( + (0, {}, {0: 1}), + (1, {}, {}), + (8, {2: 3}, {}), + (-8, {2: 3, -1: 1}, {}), + (12, {2: 2, 3: 1}, {}), + (36, {}, {2: 1, 3: 1}), + (45, {5: 1}, {3: 1}), + (48, {2: 2, 3: 1}, {2: 1}), + (1125, {5: 1}, {3: 1, 5: 1}), + ) + for a, D_expected, F_expected in cases: + D, F = extract_fundamental_discriminant(a) + assert D == D_expected + assert F == F_expected + + +def test_supplement_a_subspace_1(): + M = DM([[1, 7, 0], [2, 3, 4]], QQ).transpose() + + # First supplement over QQ: + B = supplement_a_subspace(M) + assert B[:, :2] == M + assert B[:, 2] == DomainMatrix.eye(3, QQ).to_dense()[:, 0] + + # Now supplement over FF(7): + M = M.convert_to(FF(7)) + B = supplement_a_subspace(M) + assert B[:, :2] == M + # When we work mod 7, first col of M goes to [1, 0, 0], + # so the supplementary vector cannot equal this, as it did + # when we worked over QQ. Instead, we get the second std basis vector: + assert B[:, 2] == DomainMatrix.eye(3, FF(7)).to_dense()[:, 1] + + +def test_supplement_a_subspace_2(): + M = DM([[1, 0, 0], [2, 0, 0]], QQ).transpose() + with raises(DMRankError): + supplement_a_subspace(M) + + +def test_IntervalPrinter(): + ip = IntervalPrinter() + assert ip.doprint(x**Rational(1, 3)) == "x**(mpi('1/3'))" + assert ip.doprint(sqrt(x)) == "x**(mpi('1/2'))" + + +def test_isolate(): + assert isolate(1) == (1, 1) + assert isolate(S.Half) == (S.Half, S.Half) + + assert isolate(sqrt(2)) == (1, 2) + assert isolate(-sqrt(2)) == (-2, -1) + + assert isolate(sqrt(2), eps=Rational(1, 100)) == (Rational(24, 17), Rational(17, 12)) + assert isolate(-sqrt(2), eps=Rational(1, 100)) == (Rational(-17, 12), Rational(-24, 17)) + + raises(NotImplementedError, lambda: isolate(I)) diff --git a/phivenv/Lib/site-packages/sympy/polys/numberfields/utilities.py b/phivenv/Lib/site-packages/sympy/polys/numberfields/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..fe583efb440f02f1b16c38fb7d03621c1f97e83d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/numberfields/utilities.py @@ -0,0 +1,474 @@ +"""Utilities for algebraic number theory. """ + +from sympy.core.sympify import sympify +from sympy.ntheory.factor_ import factorint +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.integerring import ZZ +from sympy.polys.matrices.exceptions import DMRankError +from sympy.polys.numberfields.minpoly import minpoly +from sympy.printing.lambdarepr import IntervalPrinter +from sympy.utilities.decorator import public +from sympy.utilities.lambdify import lambdify + +from mpmath import mp + + +def is_rat(c): + r""" + Test whether an argument is of an acceptable type to be used as a rational + number. + + Explanation + =========== + + Returns ``True`` on any argument of type ``int``, :ref:`ZZ`, or :ref:`QQ`. + + See Also + ======== + + is_int + + """ + # ``c in QQ`` is too accepting (e.g. ``3.14 in QQ`` is ``True``), + # ``QQ.of_type(c)`` is too demanding (e.g. ``QQ.of_type(3)`` is ``False``). + # + # Meanwhile, if gmpy2 is installed then ``ZZ.of_type()`` accepts only + # ``mpz``, not ``int``, so we need another clause to ensure ``int`` is + # accepted. + return isinstance(c, int) or ZZ.of_type(c) or QQ.of_type(c) + + +def is_int(c): + r""" + Test whether an argument is of an acceptable type to be used as an integer. + + Explanation + =========== + + Returns ``True`` on any argument of type ``int`` or :ref:`ZZ`. + + See Also + ======== + + is_rat + + """ + # If gmpy2 is installed then ``ZZ.of_type()`` accepts only + # ``mpz``, not ``int``, so we need another clause to ensure ``int`` is + # accepted. + return isinstance(c, int) or ZZ.of_type(c) + + +def get_num_denom(c): + r""" + Given any argument on which :py:func:`~.is_rat` is ``True``, return the + numerator and denominator of this number. + + See Also + ======== + + is_rat + + """ + r = QQ(c) + return r.numerator, r.denominator + + +@public +def extract_fundamental_discriminant(a): + r""" + Extract a fundamental discriminant from an integer *a*. + + Explanation + =========== + + Given any rational integer *a* that is 0 or 1 mod 4, write $a = d f^2$, + where $d$ is either 1 or a fundamental discriminant, and return a pair + of dictionaries ``(D, F)`` giving the prime factorizations of $d$ and $f$ + respectively, in the same format returned by :py:func:`~.factorint`. + + A fundamental discriminant $d$ is different from unity, and is either + 1 mod 4 and squarefree, or is 0 mod 4 and such that $d/4$ is squarefree + and 2 or 3 mod 4. This is the same as being the discriminant of some + quadratic field. + + Examples + ======== + + >>> from sympy.polys.numberfields.utilities import extract_fundamental_discriminant + >>> print(extract_fundamental_discriminant(-432)) + ({3: 1, -1: 1}, {2: 2, 3: 1}) + + For comparison: + + >>> from sympy import factorint + >>> print(factorint(-432)) + {2: 4, 3: 3, -1: 1} + + Parameters + ========== + + a: int, must be 0 or 1 mod 4 + + Returns + ======= + + Pair ``(D, F)`` of dictionaries. + + Raises + ====== + + ValueError + If *a* is not 0 or 1 mod 4. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Prop. 5.1.3) + + """ + if a % 4 not in [0, 1]: + raise ValueError('To extract fundamental discriminant, number must be 0 or 1 mod 4.') + if a == 0: + return {}, {0: 1} + if a == 1: + return {}, {} + a_factors = factorint(a) + D = {} + F = {} + # First pass: just make d squarefree, and a/d a perfect square. + # We'll count primes (and units! i.e. -1) that are 3 mod 4 and present in d. + num_3_mod_4 = 0 + for p, e in a_factors.items(): + if e % 2 == 1: + D[p] = 1 + if p % 4 == 3: + num_3_mod_4 += 1 + if e >= 3: + F[p] = (e - 1) // 2 + else: + F[p] = e // 2 + # Second pass: if d is cong. to 2 or 3 mod 4, then we must steal away + # another factor of 4 from f**2 and give it to d. + even = 2 in D + if even or num_3_mod_4 % 2 == 1: + e2 = F[2] + assert e2 > 0 + if e2 == 1: + del F[2] + else: + F[2] = e2 - 1 + D[2] = 3 if even else 2 + return D, F + + +@public +class AlgIntPowers: + r""" + Compute the powers of an algebraic integer. + + Explanation + =========== + + Given an algebraic integer $\theta$ by its monic irreducible polynomial + ``T`` over :ref:`ZZ`, this class computes representations of arbitrarily + high powers of $\theta$, as :ref:`ZZ`-linear combinations over + $\{1, \theta, \ldots, \theta^{n-1}\}$, where $n = \deg(T)$. + + The representations are computed using the linear recurrence relations for + powers of $\theta$, derived from the polynomial ``T``. See [1], Sec. 4.2.2. + + Optionally, the representations may be reduced with respect to a modulus. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.utilities import AlgIntPowers + >>> T = Poly(cyclotomic_poly(5)) + >>> zeta_pow = AlgIntPowers(T) + >>> print(zeta_pow[0]) + [1, 0, 0, 0] + >>> print(zeta_pow[1]) + [0, 1, 0, 0] + >>> print(zeta_pow[4]) # doctest: +SKIP + [-1, -1, -1, -1] + >>> print(zeta_pow[24]) # doctest: +SKIP + [-1, -1, -1, -1] + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + + """ + + def __init__(self, T, modulus=None): + """ + Parameters + ========== + + T : :py:class:`~.Poly` + The monic irreducible polynomial over :ref:`ZZ` defining the + algebraic integer. + + modulus : int, None, optional + If not ``None``, all representations will be reduced w.r.t. this. + + """ + self.T = T + self.modulus = modulus + self.n = T.degree() + self.powers_n_and_up = [[-c % self for c in reversed(T.rep.to_list())][:-1]] + self.max_so_far = self.n + + def red(self, exp): + return exp if self.modulus is None else exp % self.modulus + + def __rmod__(self, other): + return self.red(other) + + def compute_up_through(self, e): + m = self.max_so_far + if e <= m: return + n = self.n + r = self.powers_n_and_up + c = r[0] + for k in range(m+1, e+1): + b = r[k-1-n][n-1] + r.append( + [c[0]*b % self] + [ + (r[k-1-n][i-1] + c[i]*b) % self for i in range(1, n) + ] + ) + self.max_so_far = e + + def get(self, e): + n = self.n + if e < 0: + raise ValueError('Exponent must be non-negative.') + elif e < n: + return [1 if i == e else 0 for i in range(n)] + else: + self.compute_up_through(e) + return self.powers_n_and_up[e - n] + + def __getitem__(self, item): + return self.get(item) + + +@public +def coeff_search(m, R): + r""" + Generate coefficients for searching through polynomials. + + Explanation + =========== + + Lead coeff is always non-negative. Explore all combinations with coeffs + bounded in absolute value before increasing the bound. Skip the all-zero + list, and skip any repeats. See examples. + + Examples + ======== + + >>> from sympy.polys.numberfields.utilities import coeff_search + >>> cs = coeff_search(2, 1) + >>> C = [next(cs) for i in range(13)] + >>> print(C) + [[1, 1], [1, 0], [1, -1], [0, 1], [2, 2], [2, 1], [2, 0], [2, -1], [2, -2], + [1, 2], [1, -2], [0, 2], [3, 3]] + + Parameters + ========== + + m : int + Length of coeff list. + R : int + Initial max abs val for coeffs (will increase as search proceeds). + + Returns + ======= + + generator + Infinite generator of lists of coefficients. + + """ + R0 = R + c = [R] * m + while True: + if R == R0 or R in c or -R in c: + yield c[:] + j = m - 1 + while c[j] == -R: + j -= 1 + c[j] -= 1 + for i in range(j + 1, m): + c[i] = R + for j in range(m): + if c[j] != 0: + break + else: + R += 1 + c = [R] * m + + +def supplement_a_subspace(M): + r""" + Extend a basis for a subspace to a basis for the whole space. + + Explanation + =========== + + Given an $n \times r$ matrix *M* of rank $r$ (so $r \leq n$), this function + computes an invertible $n \times n$ matrix $B$ such that the first $r$ + columns of $B$ equal *M*. + + This operation can be interpreted as a way of extending a basis for a + subspace, to give a basis for the whole space. + + To be precise, suppose you have an $n$-dimensional vector space $V$, with + basis $\{v_1, v_2, \ldots, v_n\}$, and an $r$-dimensional subspace $W$ of + $V$, spanned by a basis $\{w_1, w_2, \ldots, w_r\}$, where the $w_j$ are + given as linear combinations of the $v_i$. If the columns of *M* represent + the $w_j$ as such linear combinations, then the columns of the matrix $B$ + computed by this function give a new basis $\{u_1, u_2, \ldots, u_n\}$ for + $V$, again relative to the $\{v_i\}$ basis, and such that $u_j = w_j$ + for $1 \leq j \leq r$. + + Examples + ======== + + Note: The function works in terms of columns, so in these examples we + print matrix transposes in order to make the columns easier to inspect. + + >>> from sympy.polys.matrices import DM + >>> from sympy import QQ, FF + >>> from sympy.polys.numberfields.utilities import supplement_a_subspace + >>> M = DM([[1, 7, 0], [2, 3, 4]], QQ).transpose() + >>> print(supplement_a_subspace(M).to_Matrix().transpose()) + Matrix([[1, 7, 0], [2, 3, 4], [1, 0, 0]]) + + >>> M2 = M.convert_to(FF(7)) + >>> print(M2.to_Matrix().transpose()) + Matrix([[1, 0, 0], [2, 3, -3]]) + >>> print(supplement_a_subspace(M2).to_Matrix().transpose()) + Matrix([[1, 0, 0], [2, 3, -3], [0, 1, 0]]) + + Parameters + ========== + + M : :py:class:`~.DomainMatrix` + The columns give the basis for the subspace. + + Returns + ======= + + :py:class:`~.DomainMatrix` + This matrix is invertible and its first $r$ columns equal *M*. + + Raises + ====== + + DMRankError + If *M* was not of maximal rank. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory* + (See Sec. 2.3.2.) + + """ + n, r = M.shape + # Let In be the n x n identity matrix. + # Form the augmented matrix [M | In] and compute RREF. + Maug = M.hstack(M.eye(n, M.domain)) + R, pivots = Maug.rref() + if pivots[:r] != tuple(range(r)): + raise DMRankError('M was not of maximal rank') + # Let J be the n x r matrix equal to the first r columns of In. + # Since M is of rank r, RREF reduces [M | In] to [J | A], where A is the product of + # elementary matrices Ei corresp. to the row ops performed by RREF. Since the Ei are + # invertible, so is A. Let B = A^(-1). + A = R[:, r:] + B = A.inv() + # Then B is the desired matrix. It is invertible, since B^(-1) == A. + # And A * [M | In] == [J | A] + # => A * M == J + # => M == B * J == the first r columns of B. + return B + + +@public +def isolate(alg, eps=None, fast=False): + """ + Find a rational isolating interval for a real algebraic number. + + Examples + ======== + + >>> from sympy import isolate, sqrt, Rational + >>> print(isolate(sqrt(2))) # doctest: +SKIP + (1, 2) + >>> print(isolate(sqrt(2), eps=Rational(1, 100))) + (24/17, 17/12) + + Parameters + ========== + + alg : str, int, :py:class:`~.Expr` + The algebraic number to be isolated. Must be a real number, to use this + particular function. However, see also :py:meth:`.Poly.intervals`, + which isolates complex roots when you pass ``all=True``. + eps : positive element of :ref:`QQ`, None, optional (default=None) + Precision to be passed to :py:meth:`.Poly.refine_root` + fast : boolean, optional (default=False) + Say whether fast refinement procedure should be used. + (Will be passed to :py:meth:`.Poly.refine_root`.) + + Returns + ======= + + Pair of rational numbers defining an isolating interval for the given + algebraic number. + + See Also + ======== + + .Poly.intervals + + """ + alg = sympify(alg) + + if alg.is_Rational: + return (alg, alg) + elif not alg.is_real: + raise NotImplementedError( + "complex algebraic numbers are not supported") + + func = lambdify((), alg, modules="mpmath", printer=IntervalPrinter()) + + poly = minpoly(alg, polys=True) + intervals = poly.intervals(sqf=True) + + dps, done = mp.dps, False + + try: + while not done: + alg = func() + + for a, b in intervals: + if a <= alg.a and alg.b <= b: + done = True + break + else: + mp.dps *= 2 + finally: + mp.dps = dps + + if eps is not None: + a, b = poly.refine_root(a, b, eps=eps, fast=fast) + + return (a, b) diff --git a/phivenv/Lib/site-packages/sympy/polys/puiseux.py b/phivenv/Lib/site-packages/sympy/polys/puiseux.py new file mode 100644 index 0000000000000000000000000000000000000000..446dc9c1a5e0d873cdf23da37d2c2430ba0bac6e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/puiseux.py @@ -0,0 +1,795 @@ +""" +Puiseux rings. These are used by the ring_series module to represented +truncated Puiseux series. Elements of a Puiseux ring are like polynomials +except that the exponents can be negative or rational rather than just +non-negative integers. +""" + +# Previously the ring_series module used PolyElement to represent Puiseux +# series. This is problematic because it means that PolyElement has to support +# negative and non-integer exponents which most polynomial representations do +# not support. This module provides an implementation of a ring for Puiseux +# series that can be used by ring_series without breaking the basic invariants +# of polynomial rings. +# +# Ideally there would be more of a proper series type that can keep track of +# not just the leading terms of a truncated series but also the precision +# of the series. For now the rings here are just introduced to keep the +# interface that ring_series was using before. + +from __future__ import annotations + +from sympy.polys.domains import QQ +from sympy.polys.rings import PolyRing, PolyElement +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.external.gmpy import gcd, lcm + + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from typing import Any, Unpack + from sympy.core.expr import Expr + from sympy.polys.domains import Domain + from collections.abc import Iterable, Iterator + + +def puiseux_ring( + symbols: str | list[Expr], domain: Domain +) -> tuple[PuiseuxRing, Unpack[tuple[PuiseuxPoly, ...]]]: + """Construct a Puiseux ring. + + This function constructs a Puiseux ring with the given symbols and domain. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x, y = puiseux_ring('x y', QQ) + >>> R + PuiseuxRing((x, y), QQ) + >>> p = 5*x**QQ(1,2) + 7/y + >>> p + 7*y**(-1) + 5*x**(1/2) + """ + ring = PuiseuxRing(symbols, domain) + return (ring,) + ring.gens # type: ignore + + +class PuiseuxRing: + """Ring of Puiseux polynomials. + + A Puiseux polynomial is a truncated Puiseux series. The exponents of the + monomials can be negative or rational numbers. This ring is used by the + ring_series module: + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import rs_exp, rs_nth_root + >>> ring, x, y = puiseux_ring('x y', QQ) + >>> f = x**2 + y**3 + >>> f + y**3 + x**2 + >>> f.diff(x) + 2*x + >>> rs_exp(x, x, 5) + 1 + x + 1/2*x**2 + 1/6*x**3 + 1/24*x**4 + + Importantly the Puiseux ring can represent truncated series with negative + and fractional exponents: + + >>> f = 1/x + 1/y**2 + >>> f + x**(-1) + y**(-2) + >>> f.diff(x) + -1*x**(-2) + + >>> rs_nth_root(8*x + x**2 + x**3, 3, x, 5) + 2*x**(1/3) + 1/12*x**(4/3) + 23/288*x**(7/3) + -139/20736*x**(10/3) + + See Also + ======== + + sympy.polys.ring_series.rs_series + PuiseuxPoly + """ + def __init__(self, symbols: str | list[Expr], domain: Domain): + + poly_ring = PolyRing(symbols, domain) + + domain = poly_ring.domain + ngens = poly_ring.ngens + + self.poly_ring = poly_ring + self.domain = domain + + self.symbols = poly_ring.symbols + self.gens = tuple([self.from_poly(g) for g in poly_ring.gens]) + self.ngens = ngens + + self.zero = self.from_poly(poly_ring.zero) + self.one = self.from_poly(poly_ring.one) + + self.zero_monom = poly_ring.zero_monom # type: ignore + self.monomial_mul = poly_ring.monomial_mul # type: ignore + + def __repr__(self) -> str: + return f"PuiseuxRing({self.symbols}, {self.domain})" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, PuiseuxRing): + return NotImplemented + return self.symbols == other.symbols and self.domain == other.domain + + def from_poly(self, poly: PolyElement) -> PuiseuxPoly: + """Create a Puiseux polynomial from a polynomial. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.puiseux import puiseux_ring + >>> R1, x1 = ring('x', QQ) + >>> R2, x2 = puiseux_ring('x', QQ) + >>> R2.from_poly(x1**2) + x**2 + """ + return PuiseuxPoly(poly, self) + + def from_dict(self, terms: dict[tuple[int, ...], Any]) -> PuiseuxPoly: + """Create a Puiseux polynomial from a dictionary of terms. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x = puiseux_ring('x', QQ) + >>> R.from_dict({(QQ(1,2),): QQ(3)}) + 3*x**(1/2) + """ + return PuiseuxPoly.from_dict(terms, self) + + def from_int(self, n: int) -> PuiseuxPoly: + """Create a Puiseux polynomial from an integer. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x = puiseux_ring('x', QQ) + >>> R.from_int(3) + 3 + """ + return self.from_poly(self.poly_ring(n)) + + def domain_new(self, arg: Any) -> Any: + """Create a new element of the domain. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x = puiseux_ring('x', QQ) + >>> R.domain_new(3) + 3 + >>> QQ.of_type(_) + True + """ + return self.poly_ring.domain_new(arg) + + def ground_new(self, arg: Any) -> PuiseuxPoly: + """Create a new element from a ground element. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring, PuiseuxPoly + >>> R, x = puiseux_ring('x', QQ) + >>> R.ground_new(3) + 3 + >>> isinstance(_, PuiseuxPoly) + True + """ + return self.from_poly(self.poly_ring.ground_new(arg)) + + def __call__(self, arg: Any) -> PuiseuxPoly: + """Coerce an element into the ring. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x = puiseux_ring('x', QQ) + >>> R(3) + 3 + >>> R({(QQ(1,2),): QQ(3)}) + 3*x**(1/2) + """ + if isinstance(arg, dict): + return self.from_dict(arg) + else: + return self.from_poly(self.poly_ring(arg)) + + def index(self, x: PuiseuxPoly) -> int: + """Return the index of a generator. + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x, y = puiseux_ring('x y', QQ) + >>> R.index(x) + 0 + >>> R.index(y) + 1 + """ + return self.gens.index(x) + + +def _div_poly_monom(poly: PolyElement, monom: Iterable[int]) -> PolyElement: + ring = poly.ring + div = ring.monomial_div + return ring.from_dict({div(m, monom): c for m, c in poly.terms()}) + + +def _mul_poly_monom(poly: PolyElement, monom: Iterable[int]) -> PolyElement: + ring = poly.ring + mul = ring.monomial_mul + return ring.from_dict({mul(m, monom): c for m, c in poly.terms()}) + + +def _div_monom(monom: Iterable[int], div: Iterable[int]) -> tuple[int, ...]: + return tuple(mi - di for mi, di in zip(monom, div)) + + +class PuiseuxPoly: + """Puiseux polynomial. Represents a truncated Puiseux series. + + See the :class:`PuiseuxRing` class for more information. + + >>> from sympy import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> p = 5*x**2 + 7*y**3 + >>> p + 7*y**3 + 5*x**2 + + The internal representation of a Puiseux polynomial wraps a normal + polynomial. To support negative powers the polynomial is considered to be + divided by a monomial. + + >>> p2 = 1/x + 1/y**2 + >>> p2.monom # x*y**2 + (1, 2) + >>> p2.poly + x + y**2 + >>> (y**2 + x) / (x*y**2) == p2 + True + + To support fractional powers the polynomial is considered to be a function + of ``x**(1/nx), y**(1/ny), ...``. The representation keeps track of a + monomial and a list of exponent denominators so that the polynomial can be + used to represent both negative and fractional powers. + + >>> p3 = x**QQ(1,2) + y**QQ(2,3) + >>> p3.ns + (2, 3) + >>> p3.poly + x + y**2 + + See Also + ======== + + sympy.polys.puiseux.PuiseuxRing + sympy.polys.rings.PolyElement + """ + + ring: PuiseuxRing + poly: PolyElement + monom: tuple[int, ...] | None + ns: tuple[int, ...] | None + + def __new__(cls, poly: PolyElement, ring: PuiseuxRing) -> PuiseuxPoly: + return cls._new(ring, poly, None, None) + + @classmethod + def _new( + cls, + ring: PuiseuxRing, + poly: PolyElement, + monom: tuple[int, ...] | None, + ns: tuple[int, ...] | None, + ) -> PuiseuxPoly: + poly, monom, ns = cls._normalize(poly, monom, ns) + return cls._new_raw(ring, poly, monom, ns) + + @classmethod + def _new_raw( + cls, + ring: PuiseuxRing, + poly: PolyElement, + monom: tuple[int, ...] | None, + ns: tuple[int, ...] | None, + ) -> PuiseuxPoly: + obj = object.__new__(cls) + obj.ring = ring + obj.poly = poly + obj.monom = monom + obj.ns = ns + return obj + + def __eq__(self, other: Any) -> bool: + if isinstance(other, PuiseuxPoly): + return ( + self.poly == other.poly + and self.monom == other.monom + and self.ns == other.ns + ) + elif self.monom is None and self.ns is None: + return self.poly.__eq__(other) + else: + return NotImplemented + + @classmethod + def _normalize( + cls, + poly: PolyElement, + monom: tuple[int, ...] | None, + ns: tuple[int, ...] | None, + ) -> tuple[PolyElement, tuple[int, ...] | None, tuple[int, ...] | None]: + if monom is None and ns is None: + return poly, None, None + + if monom is not None: + degs = [max(d, 0) for d in poly.tail_degrees()] + if all(di >= mi for di, mi in zip(degs, monom)): + poly = _div_poly_monom(poly, monom) + monom = None + elif any(degs): + poly = _div_poly_monom(poly, degs) + monom = _div_monom(monom, degs) + + if ns is not None: + factors_d, [poly_d] = poly.deflate() + degrees = poly.degrees() + monom_d = monom if monom is not None else [0] * len(degrees) + ns_new = [] + monom_new = [] + inflations = [] + for fi, ni, di, mi in zip(factors_d, ns, degrees, monom_d): + if di == 0: + g = gcd(ni, mi) + else: + g = gcd(fi, ni, mi) + ns_new.append(ni // g) + monom_new.append(mi // g) + inflations.append(fi // g) + + if any(infl > 1 for infl in inflations): + poly_d = poly_d.inflate(inflations) + + poly = poly_d + + if monom is not None: + monom = tuple(monom_new) + + if all(n == 1 for n in ns_new): + ns = None + else: + ns = tuple(ns_new) + + return poly, monom, ns + + @classmethod + def _monom_fromint( + cls, + monom: tuple[int, ...], + dmonom: tuple[int, ...] | None, + ns: tuple[int, ...] | None, + ) -> tuple[Any, ...]: + if dmonom is not None and ns is not None: + return tuple(QQ(mi - di, ni) for mi, di, ni in zip(monom, dmonom, ns)) + elif dmonom is not None: + return tuple(QQ(mi - di) for mi, di in zip(monom, dmonom)) + elif ns is not None: + return tuple(QQ(mi, ni) for mi, ni in zip(monom, ns)) + else: + return tuple(QQ(mi) for mi in monom) + + @classmethod + def _monom_toint( + cls, + monom: tuple[Any, ...], + dmonom: tuple[int, ...] | None, + ns: tuple[int, ...] | None, + ) -> tuple[int, ...]: + if dmonom is not None and ns is not None: + return tuple( + int((mi * ni).numerator + di) for mi, di, ni in zip(monom, dmonom, ns) + ) + elif dmonom is not None: + return tuple(int(mi.numerator + di) for mi, di in zip(monom, dmonom)) + elif ns is not None: + return tuple(int((mi * ni).numerator) for mi, ni in zip(monom, ns)) + else: + return tuple(int(mi.numerator) for mi in monom) + + def itermonoms(self) -> Iterator[tuple[Any, ...]]: + """Iterate over the monomials of a Puiseux polynomial. + + >>> from sympy import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> p = 5*x**2 + 7*y**3 + >>> list(p.itermonoms()) + [(2, 0), (0, 3)] + >>> p[(2, 0)] + 5 + """ + monom, ns = self.monom, self.ns + for m in self.poly.itermonoms(): + yield self._monom_fromint(m, monom, ns) + + def monoms(self) -> list[tuple[Any, ...]]: + """Return a list of the monomials of a Puiseux polynomial.""" + return list(self.itermonoms()) + + def __iter__(self) -> Iterator[tuple[tuple[Any, ...], Any]]: + return self.itermonoms() + + def __getitem__(self, monom: tuple[int, ...]) -> Any: + monom = self._monom_toint(monom, self.monom, self.ns) + return self.poly[monom] + + def __len__(self) -> int: + return len(self.poly) + + def iterterms(self) -> Iterator[tuple[tuple[Any, ...], Any]]: + """Iterate over the terms of a Puiseux polynomial. + + >>> from sympy import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> p = 5*x**2 + 7*y**3 + >>> list(p.iterterms()) + [((2, 0), 5), ((0, 3), 7)] + """ + monom, ns = self.monom, self.ns + for m, coeff in self.poly.iterterms(): + mq = self._monom_fromint(m, monom, ns) + yield mq, coeff + + def terms(self) -> list[tuple[tuple[Any, ...], Any]]: + """Return a list of the terms of a Puiseux polynomial.""" + return list(self.iterterms()) + + @property + def is_term(self) -> bool: + """Return True if the Puiseux polynomial is a single term.""" + return self.poly.is_term + + def to_dict(self) -> dict[tuple[int, ...], Any]: + """Return a dictionary representation of a Puiseux polynomial.""" + return dict(self.iterterms()) + + @classmethod + def from_dict( + cls, terms: dict[tuple[Any, ...], Any], ring: PuiseuxRing + ) -> PuiseuxPoly: + """Create a Puiseux polynomial from a dictionary of terms. + + >>> from sympy import QQ + >>> from sympy.polys.puiseux import puiseux_ring, PuiseuxPoly + >>> R, x = puiseux_ring('x', QQ) + >>> PuiseuxPoly.from_dict({(QQ(1,2),): QQ(3)}, R) + 3*x**(1/2) + >>> R.from_dict({(QQ(1,2),): QQ(3)}) + 3*x**(1/2) + """ + ns = [1] * ring.ngens + mon = [0] * ring.ngens + for mo in terms: + ns = [lcm(n, m.denominator) for n, m in zip(ns, mo)] + mon = [min(m, n) for m, n in zip(mo, mon)] + + if not any(mon): + monom = None + else: + monom = tuple(-int((m * n).numerator) for m, n in zip(mon, ns)) + + if all(n == 1 for n in ns): + ns_final = None + else: + ns_final = tuple(ns) + + terms_p = {cls._monom_toint(m, monom, ns_final): coeff for m, coeff in terms.items()} + + poly = ring.poly_ring.from_dict(terms_p) + + return cls._new(ring, poly, monom, ns_final) + + def as_expr(self) -> Expr: + """Convert a Puiseux polynomial to :class:`~sympy.core.expr.Expr`. + + >>> from sympy import QQ, Expr + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x = puiseux_ring('x', QQ) + >>> p = 5*x**2 + 7*x**3 + >>> p.as_expr() + 7*x**3 + 5*x**2 + >>> isinstance(_, Expr) + True + """ + ring = self.ring + dom = ring.domain + symbols = ring.symbols + terms = [] + for monom, coeff in self.iterterms(): + coeff_expr = dom.to_sympy(coeff) + monoms_expr = [] + for i, m in enumerate(monom): + monoms_expr.append(symbols[i] ** m) + terms.append(Mul(coeff_expr, *monoms_expr)) + return Add(*terms) + + def __repr__(self) -> str: + + def format_power(base: str, exp: int) -> str: + if exp == 1: + return base + elif exp >= 0 and int(exp) == exp: + return f"{base}**{exp}" + else: + return f"{base}**({exp})" + + ring = self.ring + dom = ring.domain + + syms = [str(s) for s in ring.symbols] + terms_str = [] + for monom, coeff in sorted(self.terms()): + monom_str = "*".join(format_power(s, e) for s, e in zip(syms, monom) if e) + if coeff == dom.one: + if monom_str: + terms_str.append(monom_str) + else: + terms_str.append("1") + elif not monom_str: + terms_str.append(str(coeff)) + else: + terms_str.append(f"{coeff}*{monom_str}") + + return " + ".join(terms_str) + + def _unify( + self, other: PuiseuxPoly + ) -> tuple[ + PolyElement, PolyElement, tuple[int, ...] | None, tuple[int, ...] | None + ]: + """Bring two Puiseux polynomials to a common monom and ns.""" + poly1, monom1, ns1 = self.poly, self.monom, self.ns + poly2, monom2, ns2 = other.poly, other.monom, other.ns + + if monom1 == monom2 and ns1 == ns2: + return poly1, poly2, monom1, ns1 + + if ns1 == ns2: + ns = ns1 + elif ns1 is not None and ns2 is not None: + ns = tuple(lcm(n1, n2) for n1, n2 in zip(ns1, ns2)) + f1 = [n // n1 for n, n1 in zip(ns, ns1)] + f2 = [n // n2 for n, n2 in zip(ns, ns2)] + poly1 = poly1.inflate(f1) + poly2 = poly2.inflate(f2) + if monom1 is not None: + monom1 = tuple(m * f for m, f in zip(monom1, f1)) + if monom2 is not None: + monom2 = tuple(m * f for m, f in zip(monom2, f2)) + elif ns2 is not None: + ns = ns2 + poly1 = poly1.inflate(ns) + if monom1 is not None: + monom1 = tuple(m * n for m, n in zip(monom1, ns)) + elif ns1 is not None: + ns = ns1 + poly2 = poly2.inflate(ns) + if monom2 is not None: + monom2 = tuple(m * n for m, n in zip(monom2, ns)) + else: + assert False + + if monom1 == monom2: + monom = monom1 + elif monom1 is not None and monom2 is not None: + monom = tuple(max(m1, m2) for m1, m2 in zip(monom1, monom2)) + poly1 = _mul_poly_monom(poly1, _div_monom(monom, monom1)) + poly2 = _mul_poly_monom(poly2, _div_monom(monom, monom2)) + elif monom2 is not None: + monom = monom2 + poly1 = _mul_poly_monom(poly1, monom2) + elif monom1 is not None: + monom = monom1 + poly2 = _mul_poly_monom(poly2, monom1) + else: + assert False + + return poly1, poly2, monom, ns + + def __pos__(self) -> PuiseuxPoly: + return self + + def __neg__(self) -> PuiseuxPoly: + return self._new_raw(self.ring, -self.poly, self.monom, self.ns) + + def __add__(self, other: Any) -> PuiseuxPoly: + if isinstance(other, PuiseuxPoly): + if self.ring != other.ring: + raise ValueError("Cannot add Puiseux polynomials from different rings") + return self._add(other) + domain = self.ring.domain + if isinstance(other, int): + return self._add_ground(domain.convert_from(QQ(other), QQ)) + elif domain.of_type(other): + return self._add_ground(other) + else: + return NotImplemented + + def __radd__(self, other: Any) -> PuiseuxPoly: + domain = self.ring.domain + if isinstance(other, int): + return self._add_ground(domain.convert_from(QQ(other), QQ)) + elif domain.of_type(other): + return self._add_ground(other) + else: + return NotImplemented + + def __sub__(self, other: Any) -> PuiseuxPoly: + if isinstance(other, PuiseuxPoly): + if self.ring != other.ring: + raise ValueError( + "Cannot subtract Puiseux polynomials from different rings" + ) + return self._sub(other) + domain = self.ring.domain + if isinstance(other, int): + return self._sub_ground(domain.convert_from(QQ(other), QQ)) + elif domain.of_type(other): + return self._sub_ground(other) + else: + return NotImplemented + + def __rsub__(self, other: Any) -> PuiseuxPoly: + domain = self.ring.domain + if isinstance(other, int): + return self._rsub_ground(domain.convert_from(QQ(other), QQ)) + elif domain.of_type(other): + return self._rsub_ground(other) + else: + return NotImplemented + + def __mul__(self, other: Any) -> PuiseuxPoly: + if isinstance(other, PuiseuxPoly): + if self.ring != other.ring: + raise ValueError( + "Cannot multiply Puiseux polynomials from different rings" + ) + return self._mul(other) + domain = self.ring.domain + if isinstance(other, int): + return self._mul_ground(domain.convert_from(QQ(other), QQ)) + elif domain.of_type(other): + return self._mul_ground(other) + else: + return NotImplemented + + def __rmul__(self, other: Any) -> PuiseuxPoly: + domain = self.ring.domain + if isinstance(other, int): + return self._mul_ground(domain.convert_from(QQ(other), QQ)) + elif domain.of_type(other): + return self._mul_ground(other) + else: + return NotImplemented + + def __pow__(self, other: Any) -> PuiseuxPoly: + if isinstance(other, int): + if other >= 0: + return self._pow_pint(other) + else: + return self._pow_nint(-other) + elif QQ.of_type(other): + return self._pow_rational(other) + else: + return NotImplemented + + def __truediv__(self, other: Any) -> PuiseuxPoly: + if isinstance(other, PuiseuxPoly): + if self.ring != other.ring: + raise ValueError( + "Cannot divide Puiseux polynomials from different rings" + ) + return self._mul(other._inv()) + domain = self.ring.domain + if isinstance(other, int): + return self._mul_ground(domain.convert_from(QQ(1, other), QQ)) + elif domain.of_type(other): + return self._div_ground(other) + else: + return NotImplemented + + def __rtruediv__(self, other: Any) -> PuiseuxPoly: + if isinstance(other, int): + return self._inv()._mul_ground(self.ring.domain.convert_from(QQ(other), QQ)) + elif self.ring.domain.of_type(other): + return self._inv()._mul_ground(other) + else: + return NotImplemented + + def _add(self, other: PuiseuxPoly) -> PuiseuxPoly: + poly1, poly2, monom, ns = self._unify(other) + return self._new(self.ring, poly1 + poly2, monom, ns) + + def _add_ground(self, ground: Any) -> PuiseuxPoly: + return self._add(self.ring.ground_new(ground)) + + def _sub(self, other: PuiseuxPoly) -> PuiseuxPoly: + poly1, poly2, monom, ns = self._unify(other) + return self._new(self.ring, poly1 - poly2, monom, ns) + + def _sub_ground(self, ground: Any) -> PuiseuxPoly: + return self._sub(self.ring.ground_new(ground)) + + def _rsub_ground(self, ground: Any) -> PuiseuxPoly: + return self.ring.ground_new(ground)._sub(self) + + def _mul(self, other: PuiseuxPoly) -> PuiseuxPoly: + poly1, poly2, monom, ns = self._unify(other) + if monom is not None: + monom = tuple(2 * e for e in monom) + return self._new(self.ring, poly1 * poly2, monom, ns) + + def _mul_ground(self, ground: Any) -> PuiseuxPoly: + return self._new_raw(self.ring, self.poly * ground, self.monom, self.ns) + + def _div_ground(self, ground: Any) -> PuiseuxPoly: + return self._new_raw(self.ring, self.poly / ground, self.monom, self.ns) + + def _pow_pint(self, n: int) -> PuiseuxPoly: + assert n >= 0 + monom = self.monom + if monom is not None: + monom = tuple(m * n for m in monom) + return self._new(self.ring, self.poly**n, monom, self.ns) + + def _pow_nint(self, n: int) -> PuiseuxPoly: + return self._inv()._pow_pint(n) + + def _pow_rational(self, n: Any) -> PuiseuxPoly: + if not self.is_term: + raise ValueError("Only monomials can be raised to a rational power") + [(monom, coeff)] = self.terms() + domain = self.ring.domain + if not domain.is_one(coeff): + raise ValueError("Only monomials can be raised to a rational power") + monom = tuple(m * n for m in monom) + return self.ring.from_dict({monom: domain.one}) + + def _inv(self) -> PuiseuxPoly: + if not self.is_term: + raise ValueError("Only terms can be inverted") + [(monom, coeff)] = self.terms() + domain = self.ring.domain + if not domain.is_Field and not domain.is_one(coeff): + raise ValueError("Cannot invert non-unit coefficient") + monom = tuple(-m for m in monom) + coeff = 1 / coeff + return self.ring.from_dict({monom: coeff}) + + def diff(self, x: PuiseuxPoly) -> PuiseuxPoly: + """Differentiate a Puiseux polynomial with respect to a variable. + + >>> from sympy import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> p = 5*x**2 + 7*y**3 + >>> p.diff(x) + 10*x + >>> p.diff(y) + 21*y**2 + """ + ring = self.ring + i = ring.index(x) + g = {} + for expv, coeff in self.iterterms(): + n = expv[i] + if n: + e = list(expv) + e[i] -= 1 + g[tuple(e)] = coeff * n + return ring(g) diff --git a/phivenv/Lib/site-packages/sympy/polys/rationaltools.py b/phivenv/Lib/site-packages/sympy/polys/rationaltools.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca513ff2d4af96baaaf1c82caf501750b1524da --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/rationaltools.py @@ -0,0 +1,85 @@ +"""Tools for manipulation of rational expressions. """ + + +from sympy.core import Basic, Add, sympify +from sympy.core.exprtools import gcd_terms +from sympy.utilities import public +from sympy.utilities.iterables import iterable + + +@public +def together(expr, deep=False, fraction=True): + """ + Denest and combine rational expressions using symbolic methods. + + This function takes an expression or a container of expressions + and puts it (them) together by denesting and combining rational + subexpressions. No heroic measures are taken to minimize degree + of the resulting numerator and denominator. To obtain completely + reduced expression use :func:`~.cancel`. However, :func:`~.together` + can preserve as much as possible of the structure of the input + expression in the output (no expansion is performed). + + A wide variety of objects can be put together including lists, + tuples, sets, relational objects, integrals and others. It is + also possible to transform interior of function applications, + by setting ``deep`` flag to ``True``. + + By definition, :func:`~.together` is a complement to :func:`~.apart`, + so ``apart(together(expr))`` should return expr unchanged. Note + however, that :func:`~.together` uses only symbolic methods, so + it might be necessary to use :func:`~.cancel` to perform algebraic + simplification and minimize degree of the numerator and denominator. + + Examples + ======== + + >>> from sympy import together, exp + >>> from sympy.abc import x, y, z + + >>> together(1/x + 1/y) + (x + y)/(x*y) + >>> together(1/x + 1/y + 1/z) + (x*y + x*z + y*z)/(x*y*z) + + >>> together(1/(x*y) + 1/y**2) + (x + y)/(x*y**2) + + >>> together(1/(1 + 1/x) + 1/(1 + 1/y)) + (x*(y + 1) + y*(x + 1))/((x + 1)*(y + 1)) + + >>> together(exp(1/x + 1/y)) + exp(1/y + 1/x) + >>> together(exp(1/x + 1/y), deep=True) + exp((x + y)/(x*y)) + + >>> together(1/exp(x) + 1/(x*exp(x))) + (x + 1)*exp(-x)/x + + >>> together(1/exp(2*x) + 1/(x*exp(3*x))) + (x*exp(x) + 1)*exp(-3*x)/x + + """ + def _together(expr): + if isinstance(expr, Basic): + if expr.is_Atom or (expr.is_Function and not deep): + return expr + elif expr.is_Add: + return gcd_terms(list(map(_together, Add.make_args(expr))), fraction=fraction) + elif expr.is_Pow: + base = _together(expr.base) + + if deep: + exp = _together(expr.exp) + else: + exp = expr.exp + + return expr.func(base, exp) + else: + return expr.func(*[ _together(arg) for arg in expr.args ]) + elif iterable(expr): + return expr.__class__([ _together(ex) for ex in expr ]) + + return expr + + return _together(sympify(expr)) diff --git a/phivenv/Lib/site-packages/sympy/polys/ring_series.py b/phivenv/Lib/site-packages/sympy/polys/ring_series.py new file mode 100644 index 0000000000000000000000000000000000000000..b4333f0add9365991794d920a2699722900e8a5e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/ring_series.py @@ -0,0 +1,2127 @@ +"""Power series evaluation and manipulation using sparse Polynomials + +Implementing a new function +--------------------------- + +There are a few things to be kept in mind when adding a new function here:: + + - The implementation should work on all possible input domains/rings. + Special cases include the ``EX`` ring and a constant term in the series + to be expanded. There can be two types of constant terms in the series: + + + A constant value or symbol. + + A term of a multivariate series not involving the generator, with + respect to which the series is to expanded. + + Strictly speaking, a generator of a ring should not be considered a + constant. However, for series expansion both the cases need similar + treatment (as the user does not care about inner details), i.e, use an + addition formula to separate the constant part and the variable part (see + rs_sin for reference). + + - All the algorithms used here are primarily designed to work for Taylor + series (number of iterations in the algo equals the required order). + Hence, it becomes tricky to get the series of the right order if a + Puiseux series is input. Use rs_puiseux? in your function if your + algorithm is not designed to handle fractional powers. + +Extending rs_series +------------------- + +To make a function work with rs_series you need to do two things:: + + - Many sure it works with a constant term (as explained above). + - If the series contains constant terms, you might need to extend its ring. + You do so by adding the new terms to the rings as generators. + ``PolyRing.compose`` and ``PolyRing.add_gens`` are two functions that do + so and need to be called every time you expand a series containing a + constant term. + +Look at rs_sin and rs_series for further reference. + +""" + +from sympy.polys.domains import QQ, EX +from sympy.polys.rings import PolyElement, ring, sring +from sympy.polys.puiseux import PuiseuxPoly +from sympy.polys.polyerrors import DomainError +from sympy.polys.monomials import (monomial_min, monomial_mul, monomial_div, + monomial_ldiv) +from mpmath.libmp.libintmath import ifac +from sympy.core import PoleError, Function, Expr +from sympy.core.numbers import Rational +from sympy.core.intfunc import igcd +from sympy.functions import (sin, cos, tan, atan, exp, atanh, asinh, tanh, log, + ceiling, sinh, cosh) +from sympy.utilities.misc import as_int +from mpmath.libmp.libintmath import giant_steps +import math + + +def _invert_monoms(p1): + """ + Compute ``x**n * p1(1/x)`` for a univariate polynomial ``p1`` in ``x``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import _invert_monoms + >>> R, x = ring('x', ZZ) + >>> p = x**2 + 2*x + 3 + >>> _invert_monoms(p) + 3*x**2 + 2*x + 1 + + See Also + ======== + + sympy.polys.densebasic.dup_reverse + """ + terms = list(p1.items()) + terms.sort() + deg = p1.degree() + R = p1.ring + p = R.zero + cv = p1.listcoeffs() + mv = p1.listmonoms() + for mvi, cvi in zip(mv, cv): + p[(deg - mvi[0],)] = cvi + return p + +def _giant_steps(target): + """Return a list of precision steps for the Newton's method""" + # We use ceil here because giant_steps cannot handle flint.fmpq + res = giant_steps(2, math.ceil(target)) + if res[0] != 2: + res = [2] + res + return res + +def rs_trunc(p1, x, prec): + """ + Truncate the series in the ``x`` variable with precision ``prec``, + that is, modulo ``O(x**prec)`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_trunc + >>> R, x = ring('x', QQ) + >>> p = x**10 + x**5 + x + 1 + >>> rs_trunc(p, x, 12) + x**10 + x**5 + x + 1 + >>> rs_trunc(p, x, 10) + x**5 + x + 1 + """ + R = p1.ring + p = {} + i = R.gens.index(x) + for exp1 in p1: + if exp1[i] >= prec: + continue + p[exp1] = p1[exp1] + return R(p) + +def rs_is_puiseux(p, x): + """ + Test if ``p`` is Puiseux series in ``x``. + + Raise an exception if it has a negative power in ``x``. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import rs_is_puiseux + >>> R, x = puiseux_ring('x', QQ) + >>> p = x**QQ(2,5) + x**QQ(2,3) + x + >>> rs_is_puiseux(p, x) + True + """ + index = p.ring.gens.index(x) + for k in p.itermonoms(): + if k[index] != int(k[index]): + return True + if k[index] < 0: + raise ValueError('The series is not regular in %s' % x) + return False + +def rs_puiseux(f, p, x, prec): + """ + Return the puiseux series for `f(p, x, prec)`. + + To be used when function ``f`` is implemented only for regular series. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import rs_puiseux, rs_exp + >>> R, x = puiseux_ring('x', QQ) + >>> p = x**QQ(2,5) + x**QQ(2,3) + x + >>> rs_puiseux(rs_exp,p, x, 1) + 1 + x**(2/5) + x**(2/3) + 1/2*x**(4/5) + """ + index = p.ring.gens.index(x) + n = 1 + for k in p: + power = k[index] + if isinstance(power, Rational): + num, den = power.as_numer_denom() + n = int(n*den // igcd(n, den)) + elif power != int(power): + den = power.denominator + n = int(n*den // igcd(n, den)) + if n != 1: + p1 = pow_xin(p, index, n) + r = f(p1, x, prec*n) + n1 = QQ(1, n) + if isinstance(r, tuple): + r = tuple([pow_xin(rx, index, n1) for rx in r]) + else: + r = pow_xin(r, index, n1) + else: + r = f(p, x, prec) + return r + +def rs_puiseux2(f, p, q, x, prec): + """ + Return the puiseux series for `f(p, q, x, prec)`. + + To be used when function ``f`` is implemented only for regular series. + """ + index = p.ring.gens.index(x) + n = 1 + for k in p: + power = k[index] + if isinstance(power, Rational): + num, den = power.as_numer_denom() + n = n*den // igcd(n, den) + elif power != int(power): + den = power.denominator + n = n*den // igcd(n, den) + if n != 1: + p1 = pow_xin(p, index, n) + r = f(p1, q, x, prec*n) + n1 = QQ(1, n) + r = pow_xin(r, index, n1) + else: + r = f(p, q, x, prec) + return r + +def rs_mul(p1, p2, x, prec): + """ + Return the product of the given two series, modulo ``O(x**prec)``. + + ``x`` is the series variable or its position in the generators. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_mul + >>> R, x = ring('x', QQ) + >>> p1 = x**2 + 2*x + 1 + >>> p2 = x + 1 + >>> rs_mul(p1, p2, x, 3) + 3*x**2 + 3*x + 1 + """ + R = p1.ring + p = {} + if R.__class__ != p2.ring.__class__ or R != p2.ring: + raise ValueError('p1 and p2 must have the same ring') + iv = R.gens.index(x) + if not isinstance(p2, (PolyElement, PuiseuxPoly)): + raise ValueError('p2 must be a polynomial') + if R == p2.ring: + get = p.get + items2 = p2.terms() + items2.sort(key=lambda e: e[0][iv]) + if R.ngens == 1: + for exp1, v1 in p1.iterterms(): + for exp2, v2 in items2: + exp = exp1[0] + exp2[0] + if exp < prec: + exp = (exp, ) + p[exp] = get(exp, 0) + v1*v2 + else: + break + else: + monomial_mul = R.monomial_mul + for exp1, v1 in p1.iterterms(): + for exp2, v2 in items2: + if exp1[iv] + exp2[iv] < prec: + exp = monomial_mul(exp1, exp2) + p[exp] = get(exp, 0) + v1*v2 + else: + break + + return R(p) + +def rs_square(p1, x, prec): + """ + Square the series modulo ``O(x**prec)`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_square + >>> R, x = ring('x', QQ) + >>> p = x**2 + 2*x + 1 + >>> rs_square(p, x, 3) + 6*x**2 + 4*x + 1 + """ + R = p1.ring + p = {} + iv = R.gens.index(x) + get = p.get + items = p1.terms() + items.sort(key=lambda e: e[0][iv]) + monomial_mul = R.monomial_mul + for i in range(len(items)): + exp1, v1 = items[i] + for j in range(i): + exp2, v2 = items[j] + if exp1[iv] + exp2[iv] < prec: + exp = monomial_mul(exp1, exp2) + p[exp] = get(exp, 0) + v1*v2 + else: + break + p = {m: 2*v for m, v in p.items()} + get = p.get + for expv, v in p1.iterterms(): + if 2*expv[iv] < prec: + e2 = monomial_mul(expv, expv) + p[e2] = get(e2, 0) + v**2 + return R(p) + +def rs_pow(p1, n, x, prec): + """ + Return ``p1**n`` modulo ``O(x**prec)`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_pow + >>> R, x = ring('x', QQ) + >>> p = x + 1 + >>> rs_pow(p, 4, x, 3) + 6*x**2 + 4*x + 1 + """ + R = p1.ring + if isinstance(n, Rational): + np = int(n.p) + nq = int(n.q) + if nq != 1: + res = rs_nth_root(p1, nq, x, prec) + if np != 1: + res = rs_pow(res, np, x, prec) + else: + res = rs_pow(p1, np, x, prec) + return res + + n = as_int(n) + if n == 0: + if p1: + return R(1) + else: + raise ValueError('0**0 is undefined') + if n < 0: + p1 = rs_pow(p1, -n, x, prec) + return rs_series_inversion(p1, x, prec) + if n == 1: + return rs_trunc(p1, x, prec) + if n == 2: + return rs_square(p1, x, prec) + if n == 3: + p2 = rs_square(p1, x, prec) + return rs_mul(p1, p2, x, prec) + p = R(1) + while 1: + if n & 1: + p = rs_mul(p1, p, x, prec) + n -= 1 + if not n: + break + p1 = rs_square(p1, x, prec) + n = n // 2 + return p + +def rs_subs(p, rules, x, prec): + """ + Substitution with truncation according to the mapping in ``rules``. + + Return a series with precision ``prec`` in the generator ``x`` + + Note that substitutions are not done one after the other + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_subs + >>> R, x, y = ring('x, y', QQ) + >>> p = x**2 + y**2 + >>> rs_subs(p, {x: x+ y, y: x+ 2*y}, x, 3) + 2*x**2 + 6*x*y + 5*y**2 + >>> (x + y)**2 + (x + 2*y)**2 + 2*x**2 + 6*x*y + 5*y**2 + + which differs from + + >>> rs_subs(rs_subs(p, {x: x+ y}, x, 3), {y: x+ 2*y}, x, 3) + 5*x**2 + 12*x*y + 8*y**2 + + Parameters + ---------- + p : :class:`~.PolyElement` Input series. + rules : ``dict`` with substitution mappings. + x : :class:`~.PolyElement` in which the series truncation is to be done. + prec : :class:`~.Integer` order of the series after truncation. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_subs + >>> R, x, y = ring('x, y', QQ) + >>> rs_subs(x**2+y**2, {y: (x+y)**2}, x, 3) + 6*x**2*y**2 + x**2 + 4*x*y**3 + y**4 + """ + R = p.ring + ngens = R.ngens + d = R(0) + for i in range(ngens): + d[(i, 1)] = R.gens[i] + for var in rules: + d[(R.index(var), 1)] = rules[var] + p1 = R(0) + p_keys = sorted(p.keys()) + for expv in p_keys: + p2 = R(1) + for i in range(ngens): + power = expv[i] + if power == 0: + continue + if (i, power) not in d: + q, r = divmod(power, 2) + if r == 0 and (i, q) in d: + d[(i, power)] = rs_square(d[(i, q)], x, prec) + elif (i, power - 1) in d: + d[(i, power)] = rs_mul(d[(i, power - 1)], d[(i, 1)], + x, prec) + else: + d[(i, power)] = rs_pow(d[(i, 1)], power, x, prec) + p2 = rs_mul(p2, d[(i, power)], x, prec) + p1 += p2*p[expv] + return p1 + +def _has_constant_term(p, x): + """ + Check if ``p`` has a constant term in ``x`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import _has_constant_term + >>> R, x = ring('x', QQ) + >>> p = x**2 + x + 1 + >>> _has_constant_term(p, x) + True + """ + R = p.ring + iv = R.gens.index(x) + zm = R.zero_monom + a = [0]*R.ngens + a[iv] = 1 + miv = tuple(a) + return any(monomial_min(expv, miv) == zm for expv in p) + +def _get_constant_term(p, x): + """Return constant term in p with respect to x + + Note that it is not simply `p[R.zero_monom]` as there might be multiple + generators in the ring R. We want the `x`-free term which can contain other + generators. + """ + R = p.ring + i = R.gens.index(x) + zm = R.zero_monom + a = [0]*R.ngens + a[i] = 1 + miv = tuple(a) + c = 0 + for expv in p: + if monomial_min(expv, miv) == zm: + c += R({expv: p[expv]}) + return c + +def _check_series_var(p, x, name): + index = p.ring.gens.index(x) + m = min(p, key=lambda k: k[index])[index] + if m < 0: + raise PoleError("Asymptotic expansion of %s around [oo] not " + "implemented." % name) + return index, m + +def _series_inversion1(p, x, prec): + """ + Univariate series inversion ``1/p`` modulo ``O(x**prec)``. + + The Newton method is used. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import _series_inversion1 + >>> R, x = ring('x', QQ) + >>> p = x + 1 + >>> _series_inversion1(p, x, 4) + -x**3 + x**2 - x + 1 + """ + if rs_is_puiseux(p, x): + return rs_puiseux(_series_inversion1, p, x, prec) + R = p.ring + zm = R.zero_monom + c = p[zm] + + # giant_steps does not seem to work with PythonRational numbers with 1 as + # denominator. This makes sure such a number is converted to integer. + if prec == int(prec): + prec = int(prec) + + if zm not in p: + raise ValueError("No constant term in series") + if _has_constant_term(p - c, x): + raise ValueError("p cannot contain a constant term depending on " + "parameters") + if not R.domain.is_unit(c): + raise ValueError(f"Constant term {c} must be a unit in {R.domain}") + + one = R(1) + if R.domain is EX: + one = 1 + if c != one: + p1 = R(1)/c + else: + p1 = R(1) + for precx in _giant_steps(prec): + t = 1 - rs_mul(p1, p, x, precx) + p1 = p1 + rs_mul(p1, t, x, precx) + return p1 + +def rs_series_inversion(p, x, prec): + """ + Multivariate series inversion ``1/p`` modulo ``O(x**prec)``. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_series_inversion + >>> R, x, y = ring('x, y', QQ) + >>> rs_series_inversion(1 + x*y**2, x, 4) + -x**3*y**6 + x**2*y**4 - x*y**2 + 1 + >>> rs_series_inversion(1 + x*y**2, y, 4) + -x*y**2 + 1 + >>> rs_series_inversion(x + x**2, x, 4) + x**3 - x**2 + x - 1 + x**(-1) + """ + R = p.ring + if p == R.zero: + raise ZeroDivisionError + zm = R.zero_monom + index = R.gens.index(x) + m = min(p, key=lambda k: k[index])[index] + if m: + p = mul_xin(p, index, -m) + prec = prec + m + if zm not in p: + raise NotImplementedError("No constant term in series") + + if _has_constant_term(p - p[zm], x): + raise NotImplementedError("p - p[0] must not have a constant term in " + "the series variables") + r = _series_inversion1(p, x, prec) + if m != 0: + r = mul_xin(r, index, -m) + return r + +def _coefficient_t(p, t): + r"""Coefficient of `x_i**j` in p, where ``t`` = (i, j)""" + i, j = t + R = p.ring + expv1 = [0]*R.ngens + expv1[i] = j + expv1 = tuple(expv1) + p1 = R(0) + for expv in p: + if expv[i] == j: + p1[monomial_div(expv, expv1)] = p[expv] + return p1 + +def rs_series_reversion(p, x, n, y): + r""" + Reversion of a series. + + ``p`` is a series with ``O(x**n)`` of the form $p = ax + f(x)$ + where $a$ is a number different from 0. + + $f(x) = \sum_{k=2}^{n-1} a_kx_k$ + + Parameters + ========== + + a_k : Can depend polynomially on other variables, not indicated. + x : Variable with name x. + y : Variable with name y. + + Returns + ======= + + Solve $p = y$, that is, given $ax + f(x) - y = 0$, + find the solution $x = r(y)$ up to $O(y^n)$. + + Algorithm + ========= + + If $r_i$ is the solution at order $i$, then: + $ar_i + f(r_i) - y = O\left(y^{i + 1}\right)$ + + and if $r_{i + 1}$ is the solution at order $i + 1$, then: + $ar_{i + 1} + f(r_{i + 1}) - y = O\left(y^{i + 2}\right)$ + + We have, $r_{i + 1} = r_i + e$, such that, + $ae + f(r_i) = O\left(y^{i + 2}\right)$ + or $e = -f(r_i)/a$ + + So we use the recursion relation: + $r_{i + 1} = r_i - f(r_i)/a$ + with the boundary condition: $r_1 = y$ + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_series_reversion, rs_trunc + >>> R, x, y, a, b = ring('x, y, a, b', QQ) + >>> p = x - x**2 - 2*b*x**2 + 2*a*b*x**2 + >>> p1 = rs_series_reversion(p, x, 3, y); p1 + -2*y**2*a*b + 2*y**2*b + y**2 + y + >>> rs_trunc(p.compose(x, p1), y, 3) + y + """ + if rs_is_puiseux(p, x): + raise NotImplementedError + R = p.ring + nx = R.gens.index(x) + y = R(y) + ny = R.gens.index(y) + if _has_constant_term(p, x): + raise ValueError("p must not contain a constant term in the series " + "variable") + a = _coefficient_t(p, (nx, 1)) + zm = R.zero_monom + assert zm in a and len(a) == 1 + a = a[zm] + r = y/a + for i in range(2, n): + sp = rs_subs(p, {x: r}, y, i + 1) + sp = _coefficient_t(sp, (ny, i))*y**i + r -= sp/a + return r + +def rs_series_from_list(p, c, x, prec, concur=1): + """ + Return a series `sum c[n]*p**n` modulo `O(x**prec)`. + + It reduces the number of multiplications by summing concurrently. + + `ax = [1, p, p**2, .., p**(J - 1)]` + `s = sum(c[i]*ax[i]` for i in `range(r, (r + 1)*J))*p**((K - 1)*J)` + with `K >= (n + 1)/J` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_series_from_list, rs_trunc + >>> R, x = ring('x', QQ) + >>> p = x**2 + x + 1 + >>> c = [1, 2, 3] + >>> rs_series_from_list(p, c, x, 4) + 6*x**3 + 11*x**2 + 8*x + 6 + >>> rs_trunc(1 + 2*p + 3*p**2, x, 4) + 6*x**3 + 11*x**2 + 8*x + 6 + >>> pc = R.from_list(list(reversed(c))) + >>> rs_trunc(pc.compose(x, p), x, 4) + 6*x**3 + 11*x**2 + 8*x + 6 + + See Also + ======== + + sympy.polys.rings.PolyRing.compose + + """ + R = p.ring + n = len(c) + if not concur: + q = R(1) + s = c[0]*q + for i in range(1, n): + q = rs_mul(q, p, x, prec) + s += c[i]*q + return s + J = int(math.sqrt(n) + 1) + K, r = divmod(n, J) + if r: + K += 1 + ax = [R(1)] + q = R(1) + if len(p) < 20: + for i in range(1, J): + q = rs_mul(q, p, x, prec) + ax.append(q) + else: + for i in range(1, J): + if i % 2 == 0: + q = rs_square(ax[i//2], x, prec) + else: + q = rs_mul(q, p, x, prec) + ax.append(q) + # optimize using rs_square + pj = rs_mul(ax[-1], p, x, prec) + b = R(1) + s = R(0) + for k in range(K - 1): + r = J*k + s1 = c[r] + for j in range(1, J): + s1 += c[r + j]*ax[j] + s1 = rs_mul(s1, b, x, prec) + s += s1 + b = rs_mul(b, pj, x, prec) + if not b: + break + k = K - 1 + r = J*k + if r < n: + s1 = c[r]*R(1) + for j in range(1, J): + if r + j >= n: + break + s1 += c[r + j]*ax[j] + s1 = rs_mul(s1, b, x, prec) + s += s1 + return s + +def rs_diff(p, x): + """ + Return partial derivative of ``p`` with respect to ``x``. + + Parameters + ========== + + x : :class:`~.PolyElement` with respect to which ``p`` is differentiated. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_diff + >>> R, x, y = ring('x, y', QQ) + >>> p = x + x**2*y**3 + >>> rs_diff(p, x) + 2*x*y**3 + 1 + """ + R = p.ring + n = R.gens.index(x) + p1 = {} + mn = [0]*R.ngens + mn[n] = 1 + mn = tuple(mn) + for expv in p: + if expv[n]: + e = monomial_ldiv(expv, mn) + p1[e] = R.domain_new(p[expv]*expv[n]) + return R(p1) + +def rs_integrate(p, x): + """ + Integrate ``p`` with respect to ``x``. + + Parameters + ========== + + x : :class:`~.PolyElement` with respect to which ``p`` is integrated. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_integrate + >>> R, x, y = ring('x, y', QQ) + >>> p = x + x**2*y**3 + >>> rs_integrate(p, x) + 1/3*x**3*y**3 + 1/2*x**2 + """ + R = p.ring + p1 = {} + n = R.gens.index(x) + mn = [0]*R.ngens + mn[n] = 1 + mn = tuple(mn) + + for expv in p: + e = monomial_mul(expv, mn) + p1[e] = R.domain_new(p[expv]/(expv[n] + 1)) + return R(p1) + +def rs_fun(p, f, *args): + r""" + Function of a multivariate series computed by substitution. + + The case with f method name is used to compute `rs\_tan` and `rs\_nth\_root` + of a multivariate series: + + `rs\_fun(p, tan, iv, prec)` + + tan series is first computed for a dummy variable _x, + i.e, `rs\_tan(\_x, iv, prec)`. Then we substitute _x with p to get the + desired series + + Parameters + ========== + + p : :class:`~.PolyElement` The multivariate series to be expanded. + f : `ring\_series` function to be applied on `p`. + args[-2] : :class:`~.PolyElement` with respect to which, the series is to be expanded. + args[-1] : Required order of the expanded series. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_fun, _tan1 + >>> R, x, y = ring('x, y', QQ) + >>> p = x + x*y + x**2*y + x**3*y**2 + >>> rs_fun(p, _tan1, x, 4) + 1/3*x**3*y**3 + 2*x**3*y**2 + x**3*y + 1/3*x**3 + x**2*y + x*y + x + """ + _R = p.ring + R1, _x = ring('_x', _R.domain) + h = int(args[-1]) + args1 = args[:-2] + (_x, h) + zm = _R.zero_monom + # separate the constant term of the series + # compute the univariate series f(_x, .., 'x', sum(nv)) + if zm in p: + x1 = _x + p[zm] + p1 = p - p[zm] + else: + x1 = _x + p1 = p + if isinstance(f, str): + q = getattr(x1, f)(*args1) + else: + q = f(x1, *args1) + a = sorted(q.items()) + c = [0]*h + for x in a: + c[x[0][0]] = x[1] + p1 = rs_series_from_list(p1, c, args[-2], args[-1]) + return p1 + +def mul_xin(p, i, n): + r""" + Return `p*x_i**n`. + + `x\_i` is the ith variable in ``p``. + """ + R = p.ring + q = {} + for k, v in p.terms(): + k1 = list(k) + k1[i] += n + q[tuple(k1)] = v + return R(q) + +def pow_xin(p, i, n): + """ + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import pow_xin + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> p = x**QQ(2,5) + x + x**QQ(2,3) + >>> index = p.ring.gens.index(x) + >>> pow_xin(p, index, 15) + x**6 + x**10 + x**15 + """ + R = p.ring + q = {} + for k, v in p.terms(): + k1 = list(k) + k1[i] *= n + q[tuple(k1)] = v + return R(q) + +def _nth_root1(p, n, x, prec): + """ + Univariate series expansion of the nth root of ``p``. + + The Newton method is used. + """ + if rs_is_puiseux(p, x): + return rs_puiseux2(_nth_root1, p, n, x, prec) + R = p.ring + zm = R.zero_monom + if zm not in p: + raise NotImplementedError('No constant term in series') + n = as_int(n) + assert p[zm] == 1 + p1 = R(1) + if p == 1: + return p + if n == 0: + return R(1) + if n == 1: + return p + if n < 0: + n = -n + sign = 1 + else: + sign = 0 + for precx in _giant_steps(prec): + tmp = rs_pow(p1, n + 1, x, precx) + tmp = rs_mul(tmp, p, x, precx) + p1 += p1/n - tmp/n + if sign: + return p1 + else: + return _series_inversion1(p1, x, prec) + +def rs_nth_root(p, n, x, prec): + """ + Multivariate series expansion of the nth root of ``p``. + + Parameters + ========== + + p : Expr + The polynomial to computer the root of. + n : integer + The order of the root to be computed. + x : :class:`~.PolyElement` + prec : integer + Order of the expanded series. + + Notes + ===== + + The result of this function is dependent on the ring over which the + polynomial has been defined. If the answer involves a root of a constant, + make sure that the polynomial is over a real field. It cannot yet handle + roots of symbols. + + Examples + ======== + + >>> from sympy.polys.domains import QQ, RR + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_nth_root + >>> R, x, y = ring('x, y', QQ) + >>> rs_nth_root(1 + x + x*y, -3, x, 3) + 2/9*x**2*y**2 + 4/9*x**2*y + 2/9*x**2 - 1/3*x*y - 1/3*x + 1 + >>> R, x, y = ring('x, y', RR) + >>> rs_nth_root(3 + x + x*y, 3, x, 2) + 0.160249952256379*x*y + 0.160249952256379*x + 1.44224957030741 + """ + if n == 0: + if p == 0: + raise ValueError('0**0 expression') + else: + return p.ring(1) + if n == 1: + return rs_trunc(p, x, prec) + R = p.ring + index = R.gens.index(x) + m = min(p, key=lambda k: k[index])[index] + p = mul_xin(p, index, -m) + prec -= m + + if _has_constant_term(p - 1, x): + zm = R.zero_monom + c = p[zm] + if isinstance(c, PolyElement): + try: + c_expr = c.as_expr() + const = R(c_expr**(QQ(1, n))) + except ValueError: + raise DomainError("The given series cannot be expanded in " + "this domain.") + else: + try: # RealElement doesn't support + const = R(c**Rational(1, n)) # exponentiation with mpq object + except ValueError: # as exponent + raise DomainError("The given series cannot be expanded in " + "this domain.") + res = rs_nth_root(p/c, n, x, prec)*const + else: + res = _nth_root1(p, n, x, prec) + if m: + m = QQ(m) / n + res = mul_xin(res, index, m) + return res + +def rs_log(p, x, prec): + """ + The Logarithm of ``p`` modulo ``O(x**prec)``. + + Notes + ===== + + Truncation of ``integral dx p**-1*d p/dx`` is used. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import rs_log + >>> R, x = puiseux_ring('x', QQ) + >>> rs_log(1 + x, x, 8) + x + -1/2*x**2 + 1/3*x**3 + -1/4*x**4 + 1/5*x**5 + -1/6*x**6 + 1/7*x**7 + >>> rs_log(x**QQ(3, 2) + 1, x, 5) + x**(3/2) + -1/2*x**3 + 1/3*x**(9/2) + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_log, p, x, prec) + R = p.ring + if p == 1: + return R.zero + c = _get_constant_term(p, x) + if c: + const = 0 + if c == 1: + pass + try: + c_expr = c.as_expr() + const = R(log(c_expr)) + except ValueError: + R = R.add_gens([log(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + const = R(log(c_expr)) + + dlog = p.diff(x) + dlog = rs_mul(dlog, _series_inversion1(p, x, prec), x, prec - 1) + return rs_integrate(dlog, x) + const + else: + raise NotImplementedError + +def rs_LambertW(p, x, prec): + """ + Calculate the series expansion of the principal branch of the Lambert W + function. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_LambertW + >>> R, x, y = ring('x, y', QQ) + >>> rs_LambertW(x + x*y, x, 3) + -x**2*y**2 - 2*x**2*y - x**2 + x*y + x + + See Also + ======== + + LambertW + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_LambertW, p, x, prec) + R = p.ring + p1 = R(0) + if _has_constant_term(p, x): + raise NotImplementedError("Polynomial must not have constant term in " + "the series variables") + if x in R.gens: + for precx in _giant_steps(prec): + e = rs_exp(p1, x, precx) + p2 = rs_mul(e, p1, x, precx) - p + p3 = rs_mul(e, p1 + 1, x, precx) + p3 = rs_series_inversion(p3, x, precx) + tmp = rs_mul(p2, p3, x, precx) + p1 -= tmp + return p1 + else: + raise NotImplementedError + +def _exp1(p, x, prec): + r"""Helper function for `rs\_exp`. """ + R = p.ring + p1 = R(1) + for precx in _giant_steps(prec): + pt = p - rs_log(p1, x, precx) + tmp = rs_mul(pt, p1, x, precx) + p1 += tmp + return p1 + +def rs_exp(p, x, prec): + """ + Exponentiation of a series modulo ``O(x**prec)`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_exp + >>> R, x = ring('x', QQ) + >>> rs_exp(x**2, x, 7) + 1/6*x**6 + 1/2*x**4 + x**2 + 1 + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_exp, p, x, prec) + R = p.ring + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + const = R(exp(c_expr)) + except ValueError: + R = R.add_gens([exp(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + const = R(exp(c_expr)) + + p1 = p - c + + # Makes use of SymPy functions to evaluate the values of the cos/sin + # of the constant term. + return const*rs_exp(p1, x, prec) + + if len(p) > 20: + return _exp1(p, x, prec) + one = R(1) + n = 1 + c = [] + for k in range(prec): + c.append(one/n) + k += 1 + n *= k + + r = rs_series_from_list(p, c, x, prec) + return r + +def _atan(p, iv, prec): + """ + Expansion using formula. + + Faster on very small and univariate series. + """ + R = p.ring + mo = R(-1) + c = [-mo] + p2 = rs_square(p, iv, prec) + for k in range(1, prec): + c.append(mo**k/(2*k + 1)) + s = rs_series_from_list(p2, c, iv, prec) + s = rs_mul(s, p, iv, prec) + return s + +def rs_atan(p, x, prec): + """ + The arctangent of a series + + Return the series expansion of the atan of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_atan + >>> R, x, y = ring('x, y', QQ) + >>> rs_atan(x + x*y, x, 4) + -1/3*x**3*y**3 - x**3*y**2 - x**3*y - 1/3*x**3 + x*y + x + + See Also + ======== + + atan + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_atan, p, x, prec) + R = p.ring + const = 0 + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + const = R(atan(c_expr)) + except ValueError: + R = R.add_gens([atan(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + const = R(atan(c_expr)) + + # Instead of using a closed form formula, we differentiate atan(p) to get + # `1/(1+p**2) * dp`, whose series expansion is much easier to calculate. + # Finally we integrate to get back atan + dp = p.diff(x) + p1 = rs_square(p, x, prec) + R(1) + p1 = rs_series_inversion(p1, x, prec - 1) + p1 = rs_mul(dp, p1, x, prec - 1) + return rs_integrate(p1, x) + const + +def rs_asin(p, x, prec): + """ + Arcsine of a series + + Return the series expansion of the asin of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_asin + >>> R, x, y = ring('x, y', QQ) + >>> rs_asin(x, x, 8) + 5/112*x**7 + 3/40*x**5 + 1/6*x**3 + x + + See Also + ======== + + asin + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_asin, p, x, prec) + if _has_constant_term(p, x): + raise NotImplementedError("Polynomial must not have constant term in " + "series variables") + R = p.ring + if x in R.gens: + # get a good value + if len(p) > 20: + dp = rs_diff(p, x) + p1 = 1 - rs_square(p, x, prec - 1) + p1 = rs_nth_root(p1, -2, x, prec - 1) + p1 = rs_mul(dp, p1, x, prec - 1) + return rs_integrate(p1, x) + one = R(1) + c = [0, one, 0] + for k in range(3, prec, 2): + c.append((k - 2)**2*c[-2]/(k*(k - 1))) + c.append(0) + return rs_series_from_list(p, c, x, prec) + + else: + raise NotImplementedError + +def _tan1(p, x, prec): + r""" + Helper function of :func:`rs_tan`. + + Return the series expansion of tan of a univariate series using Newton's + method. It takes advantage of the fact that series expansion of atan is + easier than that of tan. + + Consider `f(x) = y - \arctan(x)` + Let r be a root of f(x) found using Newton's method. + Then `f(r) = 0` + Or `y = \arctan(x)` where `x = \tan(y)` as required. + """ + R = p.ring + p1 = R(0) + for precx in _giant_steps(prec): + tmp = p - rs_atan(p1, x, precx) + tmp = rs_mul(tmp, 1 + rs_square(p1, x, precx), x, precx) + p1 += tmp + return p1 + +def rs_tan(p, x, prec): + """ + Tangent of a series. + + Return the series expansion of the tan of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_tan + >>> R, x, y = ring('x, y', QQ) + >>> rs_tan(x + x*y, x, 4) + 1/3*x**3*y**3 + x**3*y**2 + x**3*y + 1/3*x**3 + x*y + x + + See Also + ======== + + _tan1, tan + """ + if rs_is_puiseux(p, x): + r = rs_puiseux(rs_tan, p, x, prec) + return r + R = p.ring + const = 0 + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + const = R(tan(c_expr)) + except ValueError: + R = R.add_gens([tan(c_expr, )]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + const = R(tan(c_expr)) + + p1 = p - c + + # Makes use of SymPy functions to evaluate the values of the cos/sin + # of the constant term. + t2 = rs_tan(p1, x, prec) + t = rs_series_inversion(1 - const*t2, x, prec) + return rs_mul(const + t2, t, x, prec) + + if R.ngens == 1: + return _tan1(p, x, prec) + else: + return rs_fun(p, rs_tan, x, prec) + +def rs_cot(p, x, prec): + """ + Cotangent of a series + + Return the series expansion of the cot of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_cot + >>> R, x, y = ring('x, y', QQ) + >>> rs_cot(x, x, 6) + -2/945*x**5 - 1/45*x**3 - 1/3*x + x**(-1) + + See Also + ======== + + cot + """ + # It can not handle series like `p = x + x*y` where the coefficient of the + # linear term in the series variable is symbolic. + if rs_is_puiseux(p, x): + r = rs_puiseux(rs_cot, p, x, prec) + return r + i, m = _check_series_var(p, x, 'cot') + prec1 = int(prec + 2*m) + c, s = rs_cos_sin(p, x, prec1) + s = mul_xin(s, i, -m) + s = rs_series_inversion(s, x, prec1) + res = rs_mul(c, s, x, prec1) + res = mul_xin(res, i, -m) + res = rs_trunc(res, x, prec) + return res + +def rs_sin(p, x, prec): + """ + Sine of a series + + Return the series expansion of the sin of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import rs_sin + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> rs_sin(x + x*y, x, 4) + x + x*y + -1/6*x**3 + -1/2*x**3*y + -1/2*x**3*y**2 + -1/6*x**3*y**3 + >>> rs_sin(x**QQ(3, 2) + x*y**QQ(7, 5), x, 4) + x*y**(7/5) + x**(3/2) + -1/6*x**3*y**(21/5) + -1/2*x**(7/2)*y**(14/5) + + See Also + ======== + + sin + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_sin, p, x, prec) + R = x.ring + if not p: + return R(0) + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + t1, t2 = R(sin(c_expr)), R(cos(c_expr)) + except ValueError: + R = R.add_gens([sin(c_expr), cos(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + t1, t2 = R(sin(c_expr)), R(cos(c_expr)) + + p1 = p - c + + # Makes use of SymPy cos, sin functions to evaluate the values of the + # cos/sin of the constant term. + p_cos, p_sin = rs_cos_sin(p1, x, prec) + return p_sin*t2 + p_cos*t1 + + # Series is calculated in terms of tan as its evaluation is fast. + if len(p) > 20 and R.ngens == 1: + t = rs_tan(p/2, x, prec) + t2 = rs_square(t, x, prec) + p1 = rs_series_inversion(1 + t2, x, prec) + return rs_mul(p1, 2*t, x, prec) + one = R(1) + n = 1 + c = [0] + for k in range(2, prec + 2, 2): + c.append(one/n) + c.append(0) + n *= -k*(k + 1) + return rs_series_from_list(p, c, x, prec) + +def rs_cos(p, x, prec): + """ + Cosine of a series + + Return the series expansion of the cos of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.puiseux import puiseux_ring + >>> from sympy.polys.ring_series import rs_cos + >>> R, x, y = puiseux_ring('x, y', QQ) + >>> rs_cos(x + x*y, x, 4) + 1 + -1/2*x**2 + -1*x**2*y + -1/2*x**2*y**2 + >>> rs_cos(x + x*y, x, 4)/x**QQ(7, 5) + x**(-7/5) + -1/2*x**(3/5) + -1*x**(3/5)*y + -1/2*x**(3/5)*y**2 + + See Also + ======== + + cos + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_cos, p, x, prec) + R = p.ring + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + t1, t2 = R(sin(c_expr)), R(cos(c_expr)) + except ValueError: + R = R.add_gens([sin(c_expr), cos(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + t1, t2 = R(sin(c_expr)), R(cos(c_expr)) + + p1 = p - c + # Makes use of SymPy cos, sin functions to evaluate the values of the + # cos/sin of the constant term. + p_cos, p_sin = rs_cos_sin(p1, x, prec) + return p_cos*t2 - p_sin*t1 + + # Series is calculated in terms of tan as its evaluation is fast. + if len(p) > 20 and R.ngens == 1: + t = rs_tan(p/2, x, prec) + t2 = rs_square(t, x, prec) + p1 = rs_series_inversion(1+t2, x, prec) + return rs_mul(p1, 1 - t2, x, prec) + one = R(1) + n = 1 + c = [] + for k in range(2, prec + 2, 2): + c.append(one/n) + c.append(0) + n *= -k*(k - 1) + return rs_series_from_list(p, c, x, prec) + +def rs_cos_sin(p, x, prec): + """ + Cosine and sine of a series + + Return the series expansion of the cosine and sine of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_cos_sin + >>> R, x, y = ring('x, y', QQ) + >>> c, s = rs_cos_sin(x + x*y, x, 4) + >>> c + -1/2*x**2*y**2 - x**2*y - 1/2*x**2 + 1 + >>> s + -1/6*x**3*y**3 - 1/2*x**3*y**2 - 1/2*x**3*y - 1/6*x**3 + x*y + x + + See Also + ======== + + rs_cos, rs_sin + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_cos_sin, p, x, prec) + R = p.ring + if not p: + return R(0), R(0) + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + t1, t2 = R(sin(c_expr)), R(cos(c_expr)) + except ValueError: + R = R.add_gens([sin(c_expr), cos(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + t1, t2 = R(sin(c_expr)), R(cos(c_expr)) + + p1 = p - c + p_cos, p_sin = rs_cos_sin(p1, x, prec) + return p_cos*t2 - p_sin*t1, p_cos*t1 + p_sin*t2 + + if len(p) > 20 and R.ngens == 1: + t = rs_tan(p/2, x, prec) + t2 = rs_square(t, x, prec) + p1 = rs_series_inversion(1 + t2, x, prec) + return (rs_mul(p1, 1 - t2, x, prec), rs_mul(p1, 2*t, x, prec)) + + one = R(1) + coeffs = [] + cn, sn = 1, 1 + for k in range(2, prec+2, 2): + coeffs.extend([(one/cn, 0), (0, one/sn)]) + cn, sn = -cn*k*(k - 1), -sn*k*(k + 1) + + c, s = zip(*coeffs) + return (rs_series_from_list(p, c, x, prec), rs_series_from_list(p, s, x, prec)) + +def _atanh(p, x, prec): + """ + Expansion using formula + + Faster for very small and univariate series + """ + R = p.ring + one = R(1) + c = [one] + p2 = rs_square(p, x, prec) + for k in range(1, prec): + c.append(one/(2*k + 1)) + s = rs_series_from_list(p2, c, x, prec) + s = rs_mul(s, p, x, prec) + return s + +def rs_atanh(p, x, prec): + """ + Hyperbolic arctangent of a series + + Return the series expansion of the atanh of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_atanh + >>> R, x, y = ring('x, y', QQ) + >>> rs_atanh(x + x*y, x, 4) + 1/3*x**3*y**3 + x**3*y**2 + x**3*y + 1/3*x**3 + x*y + x + + See Also + ======== + + atanh + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_atanh, p, x, prec) + R = p.ring + const = 0 + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + const = R(atanh(c_expr)) + except ValueError: + raise DomainError("The given series cannot be expanded in " + "this domain.") + + # Instead of using a closed form formula, we differentiate atanh(p) to get + # `1/(1-p**2) * dp`, whose series expansion is much easier to calculate. + # Finally we integrate to get back atanh + dp = rs_diff(p, x) + p1 = - rs_square(p, x, prec) + 1 + p1 = rs_series_inversion(p1, x, prec - 1) + p1 = rs_mul(dp, p1, x, prec - 1) + return rs_integrate(p1, x) + const + +def rs_asinh(p, x, prec): + """ + Hyperbolic arcsine of a series + + Return the series expansion of the arcsinh of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_asinh + >>> R, x = ring('x', QQ) + >>> rs_asinh(x, x, 9) + -5/112*x**7 + 3/40*x**5 - 1/6*x**3 + x + + See Also + ======== + + asinh + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_asinh, p, x, prec) + R = p.ring + const = 0 + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + const = R(asinh(c_expr)) + except ValueError: + raise DomainError("The given series cannot be expanded in " + "this domain.") + + # Instead of using a closed form formula, we differentiate asinh(p) to get + # `1/sqrt(1+p**2) * dp`, whose series expansion is much easier to calculate. + # Finally we integrate to get back asinh + dp = rs_diff(p, x) + p_squared = rs_square(p, x, prec) + denom = p_squared + R(1) + p1 = rs_nth_root(denom, -2, x, prec - 1) + p1 = rs_mul(dp, p1, x, prec - 1) + return rs_integrate(p1, x) + const + +def rs_sinh(p, x, prec): + """ + Hyperbolic sine of a series + + Return the series expansion of the sinh of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_sinh + >>> R, x, y = ring('x, y', QQ) + >>> rs_sinh(x + x*y, x, 4) + 1/6*x**3*y**3 + 1/2*x**3*y**2 + 1/2*x**3*y + 1/6*x**3 + x*y + x + + See Also + ======== + + sinh + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_sinh, p, x, prec) + R = p.ring + if not p: + return R(0) + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + t1, t2 = R(sinh(c_expr)), R(cosh(c_expr)) + except ValueError: + R = R.add_gens([sinh(c_expr), cosh(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + t1, t2 = R(sinh(c_expr)), R(cosh(c_expr)) + + p1 = p - c + p_cosh, p_sinh = rs_cosh_sinh(p1, x, prec) + return p_sinh * t2 + p_cosh * t1 + + t = rs_exp(p, x, prec) + t1 = rs_series_inversion(t, x, prec) + return (t - t1)/2 + +def rs_cosh(p, x, prec): + """ + Hyperbolic cosine of a series + + Return the series expansion of the cosh of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_cosh + >>> R, x, y = ring('x, y', QQ) + >>> rs_cosh(x + x*y, x, 4) + 1/2*x**2*y**2 + x**2*y + 1/2*x**2 + 1 + + See Also + ======== + + cosh + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_cosh, p, x, prec) + R = p.ring + if not p: + return R(0) + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + t1, t2 = R(sinh(c_expr)), R(cosh(c_expr)) + except ValueError: + R = R.add_gens([sinh(c_expr), cosh(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + t1, t2 = R(sinh(c_expr)), R(cosh(c_expr)) + + p1 = p - c + p_cosh, p_sinh = rs_cosh_sinh(p1, x, prec) + return p_cosh * t2 + p_sinh * t1 + + t = rs_exp(p, x, prec) + t1 = rs_series_inversion(t, x, prec) + return (t + t1)/2 + +def rs_cosh_sinh(p, x, prec): + """ + Hyperbolic cosine and sine of a series + + Return the series expansion of the hyperbolic cosine and sine of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_cosh_sinh + >>> R, x, y = ring('x, y', QQ) + >>> c, s = rs_cosh_sinh(x + x*y, x, 4) + >>> c + 1/2*x**2*y**2 + x**2*y + 1/2*x**2 + 1 + >>> s + 1/6*x**3*y**3 + 1/2*x**3*y**2 + 1/2*x**3*y + 1/6*x**3 + x*y + x + + See Also + ======== + + rs_cosh, rs_sinh + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_cosh_sinh, p, x, prec) + R = p.ring + if not p: + return R(0), R(0) + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + t1, t2 = R(sinh(c_expr)), R(cosh(c_expr)) + except ValueError: + R = R.add_gens([sinh(c_expr), cosh(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + t1, t2 = R(sinh(c_expr)), R(cosh(c_expr)) + + p1 = p - c + p_cosh, p_sinh = rs_cosh_sinh(p1, x, prec) + return p_cosh * t2 + p_sinh * t1, p_sinh * t2 + p_cosh * t1 + + t = rs_exp(p, x, prec) + t1 = rs_series_inversion(t, x, prec) + return (t + t1)/2, (t - t1)/2 + + +def _tanh(p, x, prec): + r""" + Helper function of :func:`rs_tanh` + + Return the series expansion of tanh of a univariate series using Newton's + method. It takes advantage of the fact that series expansion of atanh is + easier than that of tanh. + + See Also + ======== + + _tanh + """ + R = p.ring + p1 = R(0) + for precx in _giant_steps(prec): + tmp = p - rs_atanh(p1, x, precx) + tmp = rs_mul(tmp, 1 - rs_square(p1, x, prec), x, precx) + p1 += tmp + return p1 + +def rs_tanh(p, x, prec): + """ + Hyperbolic tangent of a series + + Return the series expansion of the tanh of ``p``, about 0. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_tanh + >>> R, x, y = ring('x, y', QQ) + >>> rs_tanh(x + x*y, x, 4) + -1/3*x**3*y**3 - x**3*y**2 - x**3*y - 1/3*x**3 + x*y + x + + See Also + ======== + + tanh + """ + if rs_is_puiseux(p, x): + return rs_puiseux(rs_tanh, p, x, prec) + R = p.ring + const = 0 + c = _get_constant_term(p, x) + if c: + try: + c_expr = c.as_expr() + const = R(tanh(c_expr)) + except ValueError: + R = R.add_gens([tanh(c_expr)]) + p = p.set_ring(R) + x = x.set_ring(R) + c = c.set_ring(R) + const = R(tanh(c_expr)) + + p1 = p - c + t1 = rs_tanh(p1, x, prec) + t = rs_series_inversion(1 + const*t1, x, prec) + return rs_mul(const + t1, t, x, prec) + + if R.ngens == 1: + return _tanh(p, x, prec) + else: + return rs_fun(p, _tanh, x, prec) + +def rs_newton(p, x, prec): + """ + Compute the truncated Newton sum of the polynomial ``p`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_newton + >>> R, x = ring('x', QQ) + >>> p = x**2 - 2 + >>> rs_newton(p, x, 5) + 8*x**4 + 4*x**2 + 2 + """ + deg = p.degree() + p1 = _invert_monoms(p) + p2 = rs_series_inversion(p1, x, prec) + p3 = rs_mul(p1.diff(x), p2, x, prec) + res = deg - p3*x + return res + +def rs_hadamard_exp(p1, inverse=False): + """ + Return ``sum f_i/i!*x**i`` from ``sum f_i*x**i``, + where ``x`` is the first variable. + + If ``inverse=True`` return ``sum f_i*i!*x**i`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_hadamard_exp + >>> R, x = ring('x', QQ) + >>> p = 1 + x + x**2 + x**3 + >>> rs_hadamard_exp(p) + 1/6*x**3 + 1/2*x**2 + x + 1 + """ + R = p1.ring + if R.domain != QQ: + raise NotImplementedError + p = R.zero + if not inverse: + for exp1, v1 in p1.items(): + p[exp1] = v1/int(ifac(exp1[0])) + else: + for exp1, v1 in p1.items(): + p[exp1] = v1*int(ifac(exp1[0])) + return p + +def rs_compose_add(p1, p2): + """ + compute the composed sum ``prod(p2(x - beta) for beta root of p1)`` + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + >>> from sympy.polys.ring_series import rs_compose_add + >>> R, x = ring('x', QQ) + >>> f = x**2 - 2 + >>> g = x**2 - 3 + >>> rs_compose_add(f, g) + x**4 - 10*x**2 + 1 + + References + ========== + + .. [1] A. Bostan, P. Flajolet, B. Salvy and E. Schost + "Fast Computation with Two Algebraic Numbers", + (2002) Research Report 4579, Institut + National de Recherche en Informatique et en Automatique + """ + R = p1.ring + x = R.gens[0] + prec = p1.degree()*p2.degree() + 1 + np1 = rs_newton(p1, x, prec) + np1e = rs_hadamard_exp(np1) + np2 = rs_newton(p2, x, prec) + np2e = rs_hadamard_exp(np2) + np3e = rs_mul(np1e, np2e, x, prec) + np3 = rs_hadamard_exp(np3e, True) + np3a = (np3[(0,)] - np3) / x + q = rs_integrate(np3a, x) + q = rs_exp(q, x, prec) + q = _invert_monoms(q) + q = q.primitive()[1] + dp = p1.degree()*p2.degree() - q.degree() + # `dp` is the multiplicity of the zeroes of the resultant; + # these zeroes are missed in this computation so they are put here. + # if p1 and p2 are monic irreducible polynomials, + # there are zeroes in the resultant + # if and only if p1 = p2 ; in fact in that case p1 and p2 have a + # root in common, so gcd(p1, p2) != 1; being p1 and p2 irreducible + # this means p1 = p2 + if dp: + q = q*x**dp + return q + + +_convert_func = { + 'sin': 'rs_sin', + 'cos': 'rs_cos', + 'exp': 'rs_exp', + 'tan': 'rs_tan', + 'log': 'rs_log', + 'atan': 'rs_atan', + 'sinh': 'rs_sinh', + 'cosh': 'rs_cosh', + 'tanh': 'rs_tanh' + } + +def rs_min_pow(expr, series_rs, a): + """Find the minimum power of `a` in the series expansion of expr""" + series = 0 + n = 2 + while series == 0: + series = _rs_series(expr, series_rs, a, n) + n *= 2 + R = series.ring + a = R(a) + i = R.gens.index(a) + return min(series, key=lambda t: t[i])[i] + + +def _rs_series(expr, series_rs, a, prec): + # TODO Use _parallel_dict_from_expr instead of sring as sring is + # inefficient. For details, read the todo in sring. + args = expr.args + R = series_rs.ring + + # expr does not contain any function to be expanded + if not any(arg.has(Function) for arg in args) and not expr.is_Function: + return series_rs + + if not expr.has(a): + return series_rs + + elif expr.is_Function: + arg = args[0] + if len(args) > 1: + raise NotImplementedError + R1, series = sring(arg, domain=QQ, expand=False, series=True) + series_inner = _rs_series(arg, series, a, prec) + + # Why do we need to compose these three rings? + # + # We want to use a simple domain (like ``QQ`` or ``RR``) but they don't + # support symbolic coefficients. We need a ring that for example lets + # us have `sin(1)` and `cos(1)` as coefficients if we are expanding + # `sin(x + 1)`. The ``EX`` domain allows all symbolic coefficients, but + # that makes it very complex and hence slow. + # + # To solve this problem, we add only those symbolic elements as + # generators to our ring, that we need. Here, series_inner might + # involve terms like `sin(4)`, `exp(a)`, etc, which are not there in + # R1 or R. Hence, we compose these three rings to create one that has + # the generators of all three. + R = R.compose(R1).compose(series_inner.ring) + series_inner = series_inner.set_ring(R) + series = eval(_convert_func[str(expr.func)])(series_inner, + R(a), prec) + return series + + elif expr.is_Mul: + n = len(args) + for arg in args: # XXX Looks redundant + if not arg.is_Number: + R1, _ = sring(arg, expand=False, series=True) + R = R.compose(R1) + min_pows = list(map(rs_min_pow, args, [R(arg) for arg in args], + [a]*len(args))) + sum_pows = sum(min_pows) + series = R(1) + + for i in range(n): + _series = _rs_series(args[i], R(args[i]), a, ceiling(prec + - sum_pows + min_pows[i])) + R = R.compose(_series.ring) + _series = _series.set_ring(R) + series = series.set_ring(R) + series *= _series + series = rs_trunc(series, R(a), prec) + return series + + elif expr.is_Add: + n = len(args) + series = R(0) + for i in range(n): + _series = _rs_series(args[i], R(args[i]), a, prec) + R = R.compose(_series.ring) + _series = _series.set_ring(R) + series = series.set_ring(R) + series += _series + return series + + elif expr.is_Pow: + R1, _ = sring(expr.base, domain=QQ, expand=False, series=True) + R = R.compose(R1) + series_inner = _rs_series(expr.base, R(expr.base), a, prec) + return rs_pow(series_inner, expr.exp, series_inner.ring(a), prec) + + # The `is_constant` method is buggy hence we check it at the end. + # See issue #9786 for details. + elif isinstance(expr, Expr) and expr.is_constant(): + return sring(expr, domain=QQ, expand=False, series=True)[1] + + else: + raise NotImplementedError + +def rs_series(expr, a, prec): + """Return the series expansion of an expression about 0. + + Parameters + ========== + + expr : :class:`~.Expr` + a : :class:`~.Symbol` with respect to which expr is to be expanded + prec : order of the series expansion + + Currently supports multivariate Taylor series expansion. This is much + faster that SymPy's series method as it uses sparse polynomial operations. + + It automatically creates the simplest ring required to represent the series + expansion through repeated calls to sring. + + Examples + ======== + + >>> from sympy.polys.ring_series import rs_series + >>> from sympy import sin, cos, exp, tan, symbols, QQ + >>> a, b, c = symbols('a, b, c') + >>> rs_series(sin(a) + exp(a), a, 5) + 1/24*a**4 + 1/2*a**2 + 2*a + 1 + >>> series = rs_series(tan(a + b)*cos(a + c), a, 2) + >>> series.as_expr() + -a*sin(c)*tan(b) + a*cos(c)*tan(b)**2 + a*cos(c) + cos(c)*tan(b) + >>> series = rs_series(exp(a**QQ(1,3) + a**QQ(2, 5)), a, 1) + >>> series.as_expr() + a**(11/15) + a**(4/5)/2 + a**(2/5) + a**(2/3)/2 + a**(1/3) + 1 + + """ + R, series = sring(expr, domain=QQ, expand=False, series=True) + if a not in R.symbols: + R = R.add_gens([a, ]) + series = series.set_ring(R) + series = _rs_series(expr, series, a, prec) + R = series.ring + gen = R(a) + prec_got = series.degree(gen) + 1 + + if prec_got >= prec: + return rs_trunc(series, gen, prec) + else: + # increase the requested number of terms to get the desired + # number keep increasing (up to 9) until the received order + # is different than the original order and then predict how + # many additional terms are needed + for more in range(1, 9): + p1 = _rs_series(expr, series, a, prec=prec + more) + gen = gen.set_ring(p1.ring) + new_prec = p1.degree(gen) + 1 + if new_prec != prec_got: + prec_do = ceiling(prec + (prec - prec_got)*more/(new_prec - + prec_got)) + p1 = _rs_series(expr, series, a, prec=prec_do) + while p1.degree(gen) + 1 < prec: + p1 = _rs_series(expr, series, a, prec=prec_do) + gen = gen.set_ring(p1.ring) + prec_do *= 2 + break + else: + break + else: + raise ValueError('Could not calculate %s terms for %s' + % (str(prec), expr)) + return rs_trunc(p1, gen, prec) diff --git a/phivenv/Lib/site-packages/sympy/polys/rings.py b/phivenv/Lib/site-packages/sympy/polys/rings.py new file mode 100644 index 0000000000000000000000000000000000000000..9df84dcf1691ac9bcd4aa01a85ca34b7ffc53e5d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/rings.py @@ -0,0 +1,3096 @@ +"""Sparse polynomial rings. """ + +from __future__ import annotations + +from operator import add, mul, lt, le, gt, ge +from functools import reduce +from types import GeneratorType + +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.intfunc import igcd +from sympy.core.symbol import Symbol, symbols as _symbols +from sympy.core.sympify import CantSympify, sympify +from sympy.ntheory.multinomial import multinomial_coefficients +from sympy.polys.compatibility import IPolys +from sympy.polys.constructor import construct_domain +from sympy.polys.densebasic import ninf, dmp_to_dict, dmp_from_dict +from sympy.polys.domains.domain import Domain +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.domains.polynomialring import PolynomialRing +from sympy.polys.heuristicgcd import heugcd +from sympy.polys.monomials import MonomialOps +from sympy.polys.orderings import lex, MonomialOrder +from sympy.polys.polyerrors import ( + CoercionFailed, GeneratorsError, + ExactQuotientFailed, MultivariatePolynomialError) +from sympy.polys.polyoptions import (Domain as DomainOpt, + Order as OrderOpt, build_options) +from sympy.polys.polyutils import (expr_from_dict, _dict_reorder, + _parallel_dict_from_expr) +from sympy.printing.defaults import DefaultPrinting +from sympy.utilities import public, subsets +from sympy.utilities.iterables import is_sequence +from sympy.utilities.magic import pollute + +@public +def ring(symbols, domain, order: MonomialOrder|str = lex): + """Construct a polynomial ring returning ``(ring, x_1, ..., x_n)``. + + Parameters + ========== + + symbols : str + Symbol/Expr or sequence of str, Symbol/Expr (non-empty) + domain : :class:`~.Domain` or coercible + order : :class:`~.MonomialOrder` or coercible, optional, defaults to ``lex`` + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.orderings import lex + + >>> R, x, y, z = ring("x,y,z", ZZ, lex) + >>> R + Polynomial ring in x, y, z over ZZ with lex order + >>> x + y + z + x + y + z + >>> type(_) + + + """ + _ring = PolyRing(symbols, domain, order) + return (_ring,) + _ring.gens + +@public +def xring(symbols, domain, order=lex): + """Construct a polynomial ring returning ``(ring, (x_1, ..., x_n))``. + + Parameters + ========== + + symbols : str + Symbol/Expr or sequence of str, Symbol/Expr (non-empty) + domain : :class:`~.Domain` or coercible + order : :class:`~.MonomialOrder` or coercible, optional, defaults to ``lex`` + + Examples + ======== + + >>> from sympy.polys.rings import xring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.orderings import lex + + >>> R, (x, y, z) = xring("x,y,z", ZZ, lex) + >>> R + Polynomial ring in x, y, z over ZZ with lex order + >>> x + y + z + x + y + z + >>> type(_) + + + """ + _ring = PolyRing(symbols, domain, order) + return (_ring, _ring.gens) + +@public +def vring(symbols, domain, order=lex): + """Construct a polynomial ring and inject ``x_1, ..., x_n`` into the global namespace. + + Parameters + ========== + + symbols : str + Symbol/Expr or sequence of str, Symbol/Expr (non-empty) + domain : :class:`~.Domain` or coercible + order : :class:`~.MonomialOrder` or coercible, optional, defaults to ``lex`` + + Examples + ======== + + >>> from sympy.polys.rings import vring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.orderings import lex + + >>> vring("x,y,z", ZZ, lex) + Polynomial ring in x, y, z over ZZ with lex order + >>> x + y + z # noqa: + x + y + z + >>> type(_) + + + """ + _ring = PolyRing(symbols, domain, order) + pollute([ sym.name for sym in _ring.symbols ], _ring.gens) + return _ring + +@public +def sring(exprs, *symbols, **options): + """Construct a ring deriving generators and domain from options and input expressions. + + Parameters + ========== + + exprs : :class:`~.Expr` or sequence of :class:`~.Expr` (sympifiable) + symbols : sequence of :class:`~.Symbol`/:class:`~.Expr` + options : keyword arguments understood by :class:`~.Options` + + Examples + ======== + + >>> from sympy import sring, symbols + + >>> x, y, z = symbols("x,y,z") + >>> R, f = sring(x + 2*y + 3*z) + >>> R + Polynomial ring in x, y, z over ZZ with lex order + >>> f + x + 2*y + 3*z + >>> type(_) + + + """ + single = False + + if not is_sequence(exprs): + exprs, single = [exprs], True + + exprs = list(map(sympify, exprs)) + opt = build_options(symbols, options) + + # TODO: rewrite this so that it doesn't use expand() (see poly()). + reps, opt = _parallel_dict_from_expr(exprs, opt) + + if opt.domain is None: + coeffs = sum([ list(rep.values()) for rep in reps ], []) + + opt.domain, coeffs_dom = construct_domain(coeffs, opt=opt) + + coeff_map = dict(zip(coeffs, coeffs_dom)) + reps = [{m: coeff_map[c] for m, c in rep.items()} for rep in reps] + + _ring = PolyRing(opt.gens, opt.domain, opt.order) + polys = list(map(_ring.from_dict, reps)) + + if single: + return (_ring, polys[0]) + else: + return (_ring, polys) + +def _parse_symbols(symbols): + if isinstance(symbols, str): + return _symbols(symbols, seq=True) if symbols else () + elif isinstance(symbols, Expr): + return (symbols,) + elif is_sequence(symbols): + if all(isinstance(s, str) for s in symbols): + return _symbols(symbols) + elif all(isinstance(s, Expr) for s in symbols): + return symbols + + raise GeneratorsError("expected a string, Symbol or expression or a non-empty sequence of strings, Symbols or expressions") + + +class PolyRing(DefaultPrinting, IPolys): + """Multivariate distributed polynomial ring. """ + + gens: tuple[PolyElement, ...] + symbols: tuple[Expr, ...] + ngens: int + domain: Domain + order: MonomialOrder + + def __new__(cls, symbols, domain, order=lex): + symbols = tuple(_parse_symbols(symbols)) + ngens = len(symbols) + domain = DomainOpt.preprocess(domain) + order = OrderOpt.preprocess(order) + + _hash_tuple = (cls.__name__, symbols, ngens, domain, order) + + if domain.is_Composite and set(symbols) & set(domain.symbols): + raise GeneratorsError("polynomial ring and it's ground domain share generators") + + obj = object.__new__(cls) + obj._hash_tuple = _hash_tuple + obj._hash = hash(_hash_tuple) + obj.symbols = symbols + obj.ngens = ngens + obj.domain = domain + obj.order = order + + obj.dtype = PolyElement(obj, ()).new + + obj.zero_monom = (0,)*ngens + obj.gens = obj._gens() + obj._gens_set = set(obj.gens) + + obj._one = [(obj.zero_monom, domain.one)] + + if ngens: + # These expect monomials in at least one variable + codegen = MonomialOps(ngens) + obj.monomial_mul = codegen.mul() + obj.monomial_pow = codegen.pow() + obj.monomial_mulpow = codegen.mulpow() + obj.monomial_ldiv = codegen.ldiv() + obj.monomial_div = codegen.div() + obj.monomial_lcm = codegen.lcm() + obj.monomial_gcd = codegen.gcd() + else: + monunit = lambda a, b: () + obj.monomial_mul = monunit + obj.monomial_pow = monunit + obj.monomial_mulpow = lambda a, b, c: () + obj.monomial_ldiv = monunit + obj.monomial_div = monunit + obj.monomial_lcm = monunit + obj.monomial_gcd = monunit + + + if order is lex: + obj.leading_expv = max + else: + obj.leading_expv = lambda f: max(f, key=order) + + for symbol, generator in zip(obj.symbols, obj.gens): + if isinstance(symbol, Symbol): + name = symbol.name + + if not hasattr(obj, name): + setattr(obj, name, generator) + + return obj + + def _gens(self): + """Return a list of polynomial generators. """ + one = self.domain.one + _gens = [] + for i in range(self.ngens): + expv = self.monomial_basis(i) + poly = self.zero + poly[expv] = one + _gens.append(poly) + return tuple(_gens) + + def __getnewargs__(self): + return (self.symbols, self.domain, self.order) + + def __getstate__(self): + state = self.__dict__.copy() + del state["leading_expv"] + + for key in state: + if key.startswith("monomial_"): + del state[key] + + return state + + def __hash__(self): + return self._hash + + def __eq__(self, other): + return isinstance(other, PolyRing) and \ + (self.symbols, self.domain, self.ngens, self.order) == \ + (other.symbols, other.domain, other.ngens, other.order) + + def __ne__(self, other): + return not self == other + + def clone(self, symbols=None, domain=None, order=None): + # Need a hashable tuple for cacheit to work + if symbols is not None and isinstance(symbols, list): + symbols = tuple(symbols) + return self._clone(symbols, domain, order) + + @cacheit + def _clone(self, symbols, domain, order): + return self.__class__(symbols or self.symbols, domain or self.domain, order or self.order) + + def monomial_basis(self, i): + """Return the ith-basis element. """ + basis = [0]*self.ngens + basis[i] = 1 + return tuple(basis) + + @property + def zero(self): + return self.dtype([]) + + @property + def one(self): + return self.dtype(self._one) + + def is_element(self, element): + """True if ``element`` is an element of this ring. False otherwise. """ + return isinstance(element, PolyElement) and element.ring == self + + def domain_new(self, element, orig_domain=None): + return self.domain.convert(element, orig_domain) + + def ground_new(self, coeff): + return self.term_new(self.zero_monom, coeff) + + def term_new(self, monom, coeff): + coeff = self.domain_new(coeff) + poly = self.zero + if coeff: + poly[monom] = coeff + return poly + + def ring_new(self, element): + if isinstance(element, PolyElement): + if self == element.ring: + return element + elif isinstance(self.domain, PolynomialRing) and self.domain.ring == element.ring: + return self.ground_new(element) + else: + raise NotImplementedError("conversion") + elif isinstance(element, str): + raise NotImplementedError("parsing") + elif isinstance(element, dict): + return self.from_dict(element) + elif isinstance(element, list): + try: + return self.from_terms(element) + except ValueError: + return self.from_list(element) + elif isinstance(element, Expr): + return self.from_expr(element) + else: + return self.ground_new(element) + + __call__ = ring_new + + def from_dict(self, element, orig_domain=None): + domain_new = self.domain_new + poly = self.zero + + for monom, coeff in element.items(): + coeff = domain_new(coeff, orig_domain) + if coeff: + poly[monom] = coeff + + return poly + + def from_terms(self, element, orig_domain=None): + return self.from_dict(dict(element), orig_domain) + + def from_list(self, element): + return self.from_dict(dmp_to_dict(element, self.ngens-1, self.domain)) + + def _rebuild_expr(self, expr, mapping): + domain = self.domain + + def _rebuild(expr): + generator = mapping.get(expr) + + if generator is not None: + return generator + elif expr.is_Add: + return reduce(add, list(map(_rebuild, expr.args))) + elif expr.is_Mul: + return reduce(mul, list(map(_rebuild, expr.args))) + else: + # XXX: Use as_base_exp() to handle Pow(x, n) and also exp(n) + # XXX: E can be a generator e.g. sring([exp(2)]) -> ZZ[E] + base, exp = expr.as_base_exp() + if exp.is_Integer and exp > 1: + return _rebuild(base)**int(exp) + else: + return self.ground_new(domain.convert(expr)) + + return _rebuild(sympify(expr)) + + def from_expr(self, expr): + mapping = dict(list(zip(self.symbols, self.gens))) + + try: + poly = self._rebuild_expr(expr, mapping) + except CoercionFailed: + raise ValueError("expected an expression convertible to a polynomial in %s, got %s" % (self, expr)) + else: + return self.ring_new(poly) + + def index(self, gen): + """Compute index of ``gen`` in ``self.gens``. """ + if gen is None: + if self.ngens: + i = 0 + else: + i = -1 # indicate impossible choice + elif isinstance(gen, int): + i = gen + + if 0 <= i and i < self.ngens: + pass + elif -self.ngens <= i and i <= -1: + i = -i - 1 + else: + raise ValueError("invalid generator index: %s" % gen) + elif self.is_element(gen): + try: + i = self.gens.index(gen) + except ValueError: + raise ValueError("invalid generator: %s" % gen) + elif isinstance(gen, str): + try: + i = self.symbols.index(gen) + except ValueError: + raise ValueError("invalid generator: %s" % gen) + else: + raise ValueError("expected a polynomial generator, an integer, a string or None, got %s" % gen) + + return i + + def drop(self, *gens): + """Remove specified generators from this ring. """ + indices = set(map(self.index, gens)) + symbols = [ s for i, s in enumerate(self.symbols) if i not in indices ] + + if not symbols: + return self.domain + else: + return self.clone(symbols=symbols) + + def __getitem__(self, key): + symbols = self.symbols[key] + + if not symbols: + return self.domain + else: + return self.clone(symbols=symbols) + + def to_ground(self): + # TODO: should AlgebraicField be a Composite domain? + if self.domain.is_Composite or hasattr(self.domain, 'domain'): + return self.clone(domain=self.domain.domain) + else: + raise ValueError("%s is not a composite domain" % self.domain) + + def to_domain(self): + return PolynomialRing(self) + + def to_field(self): + from sympy.polys.fields import FracField + return FracField(self.symbols, self.domain, self.order) + + @property + def is_univariate(self): + return len(self.gens) == 1 + + @property + def is_multivariate(self): + return len(self.gens) > 1 + + def add(self, *objs): + """ + Add a sequence of polynomials or containers of polynomials. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> R, x = ring("x", ZZ) + >>> R.add([ x**2 + 2*i + 3 for i in range(4) ]) + 4*x**2 + 24 + >>> _.factor_list() + (4, [(x**2 + 6, 1)]) + + """ + p = self.zero + + for obj in objs: + if is_sequence(obj, include=GeneratorType): + p += self.add(*obj) + else: + p += obj + + return p + + def mul(self, *objs): + """ + Multiply a sequence of polynomials or containers of polynomials. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> R, x = ring("x", ZZ) + >>> R.mul([ x**2 + 2*i + 3 for i in range(4) ]) + x**8 + 24*x**6 + 206*x**4 + 744*x**2 + 945 + >>> _.factor_list() + (1, [(x**2 + 3, 1), (x**2 + 5, 1), (x**2 + 7, 1), (x**2 + 9, 1)]) + + """ + p = self.one + + for obj in objs: + if is_sequence(obj, include=GeneratorType): + p *= self.mul(*obj) + else: + p *= obj + + return p + + def drop_to_ground(self, *gens): + r""" + Remove specified generators from the ring and inject them into + its domain. + """ + indices = set(map(self.index, gens)) + symbols = [s for i, s in enumerate(self.symbols) if i not in indices] + gens = [gen for i, gen in enumerate(self.gens) if i not in indices] + + if not symbols: + return self + else: + return self.clone(symbols=symbols, domain=self.drop(*gens)) + + def compose(self, other): + """Add the generators of ``other`` to ``self``""" + if self != other: + syms = set(self.symbols).union(set(other.symbols)) + return self.clone(symbols=list(syms)) + else: + return self + + def add_gens(self, symbols): + """Add the elements of ``symbols`` as generators to ``self``""" + syms = set(self.symbols).union(set(symbols)) + return self.clone(symbols=list(syms)) + + def symmetric_poly(self, n): + """ + Return the elementary symmetric polynomial of degree *n* over + this ring's generators. + """ + if n < 0 or n > self.ngens: + raise ValueError("Cannot generate symmetric polynomial of order %s for %s" % (n, self.gens)) + elif not n: + return self.one + else: + poly = self.zero + for s in subsets(range(self.ngens), int(n)): + monom = tuple(int(i in s) for i in range(self.ngens)) + poly += self.term_new(monom, self.domain.one) + return poly + + +class PolyElement(DomainElement, DefaultPrinting, CantSympify, dict): + """Element of multivariate distributed polynomial ring. """ + + def __init__(self, ring, init): + super().__init__(init) + self.ring = ring + # This check would be too slow to run every time: + # self._check() + + def _check(self): + assert isinstance(self, PolyElement) + assert isinstance(self.ring, PolyRing) + dom = self.ring.domain + assert isinstance(dom, Domain) + for monom, coeff in self.items(): + assert dom.of_type(coeff) + assert len(monom) == self.ring.ngens + assert all(isinstance(exp, int) and exp >= 0 for exp in monom) + + def new(self, init): + return self.__class__(self.ring, init) + + def parent(self): + return self.ring.to_domain() + + def __getnewargs__(self): + return (self.ring, list(self.iterterms())) + + _hash = None + + def __hash__(self): + # XXX: This computes a hash of a dictionary, but currently we don't + # protect dictionary from being changed so any use site modifications + # will make hashing go wrong. Use this feature with caution until we + # figure out how to make a safe API without compromising speed of this + # low-level class. + _hash = self._hash + if _hash is None: + self._hash = _hash = hash((self.ring, frozenset(self.items()))) + return _hash + + def copy(self): + """Return a copy of polynomial self. + + Polynomials are mutable; if one is interested in preserving + a polynomial, and one plans to use inplace operations, one + can copy the polynomial. This method makes a shallow copy. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> R, x, y = ring('x, y', ZZ) + >>> p = (x + y)**2 + >>> p1 = p.copy() + >>> p2 = p + >>> p[R.zero_monom] = 3 + >>> p + x**2 + 2*x*y + y**2 + 3 + >>> p1 + x**2 + 2*x*y + y**2 + >>> p2 + x**2 + 2*x*y + y**2 + 3 + + """ + return self.new(self) + + def set_ring(self, new_ring): + if self.ring == new_ring: + return self + elif self.ring.symbols != new_ring.symbols: + terms = list(zip(*_dict_reorder(self, self.ring.symbols, new_ring.symbols))) + return new_ring.from_terms(terms, self.ring.domain) + else: + return new_ring.from_dict(self, self.ring.domain) + + def as_expr(self, *symbols): + if not symbols: + symbols = self.ring.symbols + elif len(symbols) != self.ring.ngens: + raise ValueError( + "Wrong number of symbols, expected %s got %s" % + (self.ring.ngens, len(symbols)) + ) + + return expr_from_dict(self.as_expr_dict(), *symbols) + + def as_expr_dict(self): + to_sympy = self.ring.domain.to_sympy + return {monom: to_sympy(coeff) for monom, coeff in self.iterterms()} + + def clear_denoms(self): + domain = self.ring.domain + + if not domain.is_Field or not domain.has_assoc_Ring: + return domain.one, self + + ground_ring = domain.get_ring() + common = ground_ring.one + lcm = ground_ring.lcm + denom = domain.denom + + for coeff in self.values(): + common = lcm(common, denom(coeff)) + + poly = self.new([ (k, v*common) for k, v in self.items() ]) + return common, poly + + def strip_zero(self): + """Eliminate monomials with zero coefficient. """ + for k, v in list(self.items()): + if not v: + del self[k] + + def __eq__(p1, p2): + """Equality test for polynomials. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', ZZ) + >>> p1 = (x + y)**2 + (x - y)**2 + >>> p1 == 4*x*y + False + >>> p1 == 2*(x**2 + y**2) + True + + """ + if not p2: + return not p1 + elif p1.ring.is_element(p2): + return dict.__eq__(p1, p2) + elif len(p1) > 1: + return False + else: + return p1.get(p1.ring.zero_monom) == p2 + + def __ne__(p1, p2): + return not p1 == p2 + + def almosteq(p1, p2, tolerance=None): + """Approximate equality test for polynomials. """ + ring = p1.ring + + if ring.is_element(p2): + if set(p1.keys()) != set(p2.keys()): + return False + + almosteq = ring.domain.almosteq + + for k in p1.keys(): + if not almosteq(p1[k], p2[k], tolerance): + return False + return True + elif len(p1) > 1: + return False + else: + try: + p2 = ring.domain.convert(p2) + except CoercionFailed: + return False + else: + return ring.domain.almosteq(p1.const(), p2, tolerance) + + def sort_key(self): + return (len(self), self.terms()) + + def _cmp(p1, p2, op): + if p1.ring.is_element(p2): + return op(p1.sort_key(), p2.sort_key()) + else: + return NotImplemented + + def __lt__(p1, p2): + return p1._cmp(p2, lt) + def __le__(p1, p2): + return p1._cmp(p2, le) + def __gt__(p1, p2): + return p1._cmp(p2, gt) + def __ge__(p1, p2): + return p1._cmp(p2, ge) + + def _drop(self, gen): + ring = self.ring + i = ring.index(gen) + + if ring.ngens == 1: + return i, ring.domain + else: + symbols = list(ring.symbols) + del symbols[i] + return i, ring.clone(symbols=symbols) + + def drop(self, gen): + i, ring = self._drop(gen) + + if self.ring.ngens == 1: + if self.is_ground: + return self.coeff(1) + else: + raise ValueError("Cannot drop %s" % gen) + else: + poly = ring.zero + + for k, v in self.items(): + if k[i] == 0: + K = list(k) + del K[i] + poly[tuple(K)] = v + else: + raise ValueError("Cannot drop %s" % gen) + + return poly + + def _drop_to_ground(self, gen): + ring = self.ring + i = ring.index(gen) + + symbols = list(ring.symbols) + del symbols[i] + return i, ring.clone(symbols=symbols, domain=ring[i]) + + def drop_to_ground(self, gen): + if self.ring.ngens == 1: + raise ValueError("Cannot drop only generator to ground") + + i, ring = self._drop_to_ground(gen) + poly = ring.zero + gen = ring.domain.gens[0] + + for monom, coeff in self.iterterms(): + mon = monom[:i] + monom[i+1:] + if mon not in poly: + poly[mon] = (gen**monom[i]).mul_ground(coeff) + else: + poly[mon] += (gen**monom[i]).mul_ground(coeff) + + return poly + + def to_dense(self): + return dmp_from_dict(self, self.ring.ngens-1, self.ring.domain) + + def to_dict(self): + return dict(self) + + def str(self, printer, precedence, exp_pattern, mul_symbol): + if not self: + return printer._print(self.ring.domain.zero) + prec_mul = precedence["Mul"] + prec_atom = precedence["Atom"] + ring = self.ring + symbols = ring.symbols + ngens = ring.ngens + zm = ring.zero_monom + sexpvs = [] + for expv, coeff in self.terms(): + negative = ring.domain.is_negative(coeff) + sign = " - " if negative else " + " + sexpvs.append(sign) + if expv == zm: + scoeff = printer._print(coeff) + if negative and scoeff.startswith("-"): + scoeff = scoeff[1:] + else: + if negative: + coeff = -coeff + if coeff != self.ring.domain.one: + scoeff = printer.parenthesize(coeff, prec_mul, strict=True) + else: + scoeff = '' + sexpv = [] + for i in range(ngens): + exp = expv[i] + if not exp: + continue + symbol = printer.parenthesize(symbols[i], prec_atom, strict=True) + if exp != 1: + if exp != int(exp) or exp < 0: + sexp = printer.parenthesize(exp, prec_atom, strict=False) + else: + sexp = exp + sexpv.append(exp_pattern % (symbol, sexp)) + else: + sexpv.append('%s' % symbol) + if scoeff: + sexpv = [scoeff] + sexpv + sexpvs.append(mul_symbol.join(sexpv)) + if sexpvs[0] in [" + ", " - "]: + head = sexpvs.pop(0) + if head == " - ": + sexpvs.insert(0, "-") + return "".join(sexpvs) + + @property + def is_generator(self): + return self in self.ring._gens_set + + @property + def is_ground(self): + return not self or (len(self) == 1 and self.ring.zero_monom in self) + + @property + def is_monomial(self): + return not self or (len(self) == 1 and self.LC == 1) + + @property + def is_term(self): + return len(self) <= 1 + + @property + def is_negative(self): + return self.ring.domain.is_negative(self.LC) + + @property + def is_positive(self): + return self.ring.domain.is_positive(self.LC) + + @property + def is_nonnegative(self): + return self.ring.domain.is_nonnegative(self.LC) + + @property + def is_nonpositive(self): + return self.ring.domain.is_nonpositive(self.LC) + + @property + def is_zero(f): + return not f + + @property + def is_one(f): + return f == f.ring.one + + @property + def is_monic(f): + return f.ring.domain.is_one(f.LC) + + @property + def is_primitive(f): + return f.ring.domain.is_one(f.content()) + + @property + def is_linear(f): + return all(sum(monom) <= 1 for monom in f.itermonoms()) + + @property + def is_quadratic(f): + return all(sum(monom) <= 2 for monom in f.itermonoms()) + + @property + def is_squarefree(f): + if not f.ring.ngens: + return True + return f.ring.dmp_sqf_p(f) + + @property + def is_irreducible(f): + if not f.ring.ngens: + return True + return f.ring.dmp_irreducible_p(f) + + @property + def is_cyclotomic(f): + if f.ring.is_univariate: + return f.ring.dup_cyclotomic_p(f) + else: + raise MultivariatePolynomialError("cyclotomic polynomial") + + def __neg__(self): + return self.new([ (monom, -coeff) for monom, coeff in self.iterterms() ]) + + def __pos__(self): + return self + + def __add__(p1, p2): + """Add two polynomials. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', ZZ) + >>> (x + y)**2 + (x - y)**2 + 2*x**2 + 2*y**2 + + """ + if not p2: + return p1.copy() + ring = p1.ring + if ring.is_element(p2): + p = p1.copy() + get = p.get + zero = ring.domain.zero + for k, v in p2.items(): + v = get(k, zero) + v + if v: + p[k] = v + else: + del p[k] + return p + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__radd__(p1) + else: + return NotImplemented + + try: + cp2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + p = p1.copy() + if not cp2: + return p + zm = ring.zero_monom + if zm not in p1.keys(): + p[zm] = cp2 + else: + if p2 == -p[zm]: + del p[zm] + else: + p[zm] += cp2 + return p + + def __radd__(p1, n): + p = p1.copy() + if not n: + return p + ring = p1.ring + try: + n = ring.domain_new(n) + except CoercionFailed: + return NotImplemented + else: + zm = ring.zero_monom + if zm not in p1.keys(): + p[zm] = n + else: + if n == -p[zm]: + del p[zm] + else: + p[zm] += n + return p + + def __sub__(p1, p2): + """Subtract polynomial p2 from p1. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', ZZ) + >>> p1 = x + y**2 + >>> p2 = x*y + y**2 + >>> p1 - p2 + -x*y + x + + """ + if not p2: + return p1.copy() + ring = p1.ring + if ring.is_element(p2): + p = p1.copy() + get = p.get + zero = ring.domain.zero + for k, v in p2.items(): + v = get(k, zero) - v + if v: + p[k] = v + else: + del p[k] + return p + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__rsub__(p1) + else: + return NotImplemented + + try: + p2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + p = p1.copy() + zm = ring.zero_monom + if zm not in p1.keys(): + p[zm] = -p2 + else: + if p2 == p[zm]: + del p[zm] + else: + p[zm] -= p2 + return p + + def __rsub__(p1, n): + """n - p1 with n convertible to the coefficient domain. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', ZZ) + >>> p = x + y + >>> 4 - p + -x - y + 4 + + """ + ring = p1.ring + try: + n = ring.domain_new(n) + except CoercionFailed: + return NotImplemented + else: + p = ring.zero + for expv in p1: + p[expv] = -p1[expv] + p += n + # p._check() + return p + + def __mul__(p1, p2): + """Multiply two polynomials. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', QQ) + >>> p1 = x + y + >>> p2 = x - y + >>> p1*p2 + x**2 - y**2 + + """ + ring = p1.ring + p = ring.zero + if not p1 or not p2: + return p + elif ring.is_element(p2): + get = p.get + zero = ring.domain.zero + monomial_mul = ring.monomial_mul + p2it = list(p2.items()) + for exp1, v1 in p1.items(): + for exp2, v2 in p2it: + exp = monomial_mul(exp1, exp2) + p[exp] = get(exp, zero) + v1*v2 + p.strip_zero() + # p._check() + return p + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__rmul__(p1) + else: + return NotImplemented + + try: + p2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + for exp1, v1 in p1.items(): + v = v1*p2 + if v: + p[exp1] = v + # p._check() + return p + + def __rmul__(p1, p2): + """p2 * p1 with p2 in the coefficient domain of p1. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', ZZ) + >>> p = x + y + >>> 4 * p + 4*x + 4*y + + """ + p = p1.ring.zero + if not p2: + return p + try: + p2 = p.ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + for exp1, v1 in p1.items(): + v = p2*v1 + if v: + p[exp1] = v + return p + + def __pow__(self, n): + """raise polynomial to power `n` + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.rings import ring + + >>> _, x, y = ring('x, y', ZZ) + >>> p = x + y**2 + >>> p**3 + x**3 + 3*x**2*y**2 + 3*x*y**4 + y**6 + + """ + if not isinstance(n, int): + raise TypeError("exponent must be an integer, got %s" % n) + elif n < 0: + raise ValueError("exponent must be a non-negative integer, got %s" % n) + + ring = self.ring + + if not n: + if self: + return ring.one + else: + raise ValueError("0**0") + elif len(self) == 1: + monom, coeff = list(self.items())[0] + p = ring.zero + if coeff == ring.domain.one: + p[ring.monomial_pow(monom, n)] = coeff + else: + p[ring.monomial_pow(monom, n)] = coeff**n + # p._check() + return p + + # For ring series, we need negative and rational exponent support only + # with monomials. + n = int(n) + if n < 0: + raise ValueError("Negative exponent") + + elif n == 1: + return self.copy() + elif n == 2: + return self.square() + elif n == 3: + return self*self.square() + elif len(self) <= 5: # TODO: use an actual density measure + return self._pow_multinomial(n) + else: + return self._pow_generic(n) + + def _pow_generic(self, n): + p = self.ring.one + c = self + + while True: + if n & 1: + p = p*c + n -= 1 + if not n: + break + + c = c.square() + n = n // 2 + + return p + + def _pow_multinomial(self, n): + multinomials = multinomial_coefficients(len(self), n).items() + monomial_mulpow = self.ring.monomial_mulpow + zero_monom = self.ring.zero_monom + terms = self.items() + zero = self.ring.domain.zero + poly = self.ring.zero + + for multinomial, multinomial_coeff in multinomials: + product_monom = zero_monom + product_coeff = multinomial_coeff + + for exp, (monom, coeff) in zip(multinomial, terms): + if exp: + product_monom = monomial_mulpow(product_monom, monom, exp) + product_coeff *= coeff**exp + + monom = tuple(product_monom) + coeff = product_coeff + + coeff = poly.get(monom, zero) + coeff + + if coeff: + poly[monom] = coeff + elif monom in poly: + del poly[monom] + + return poly + + def square(self): + """square of a polynomial + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring('x, y', ZZ) + >>> p = x + y**2 + >>> p.square() + x**2 + 2*x*y**2 + y**4 + + """ + ring = self.ring + p = ring.zero + get = p.get + keys = list(self.keys()) + zero = ring.domain.zero + monomial_mul = ring.monomial_mul + for i in range(len(keys)): + k1 = keys[i] + pk = self[k1] + for j in range(i): + k2 = keys[j] + exp = monomial_mul(k1, k2) + p[exp] = get(exp, zero) + pk*self[k2] + p = p.imul_num(2) + get = p.get + for k, v in self.items(): + k2 = monomial_mul(k, k) + p[k2] = get(k2, zero) + v**2 + p.strip_zero() + # p._check() + return p + + def __divmod__(p1, p2): + ring = p1.ring + + if not p2: + raise ZeroDivisionError("polynomial division") + elif ring.is_element(p2): + return p1.div(p2) + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__rdivmod__(p1) + else: + return NotImplemented + + try: + p2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + return (p1.quo_ground(p2), p1.rem_ground(p2)) + + def __rdivmod__(p1, p2): + ring = p1.ring + try: + p2 = ring.ground_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p2.div(p1) + + def __mod__(p1, p2): + ring = p1.ring + + if not p2: + raise ZeroDivisionError("polynomial division") + elif ring.is_element(p2): + return p1.rem(p2) + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__rmod__(p1) + else: + return NotImplemented + + try: + p2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p1.rem_ground(p2) + + def __rmod__(p1, p2): + ring = p1.ring + try: + p2 = ring.ground_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p2.rem(p1) + + def __floordiv__(p1, p2): + ring = p1.ring + + if not p2: + raise ZeroDivisionError("polynomial division") + elif ring.is_element(p2): + return p1.quo(p2) + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__rtruediv__(p1) + else: + return NotImplemented + + try: + p2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p1.quo_ground(p2) + + def __rfloordiv__(p1, p2): + ring = p1.ring + try: + p2 = ring.ground_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p2.quo(p1) + + def __truediv__(p1, p2): + ring = p1.ring + + if not p2: + raise ZeroDivisionError("polynomial division") + elif ring.is_element(p2): + return p1.exquo(p2) + elif isinstance(p2, PolyElement): + if isinstance(ring.domain, PolynomialRing) and ring.domain.ring == p2.ring: + pass + elif isinstance(p2.ring.domain, PolynomialRing) and p2.ring.domain.ring == ring: + return p2.__rtruediv__(p1) + else: + return NotImplemented + + try: + p2 = ring.domain_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p1.quo_ground(p2) + + def __rtruediv__(p1, p2): + ring = p1.ring + try: + p2 = ring.ground_new(p2) + except CoercionFailed: + return NotImplemented + else: + return p2.exquo(p1) + + def _term_div(self): + zm = self.ring.zero_monom + domain = self.ring.domain + domain_quo = domain.quo + monomial_div = self.ring.monomial_div + + if domain.is_Field: + def term_div(a_lm_a_lc, b_lm_b_lc): + a_lm, a_lc = a_lm_a_lc + b_lm, b_lc = b_lm_b_lc + if b_lm == zm: # apparently this is a very common case + monom = a_lm + else: + monom = monomial_div(a_lm, b_lm) + if monom is not None: + return monom, domain_quo(a_lc, b_lc) + else: + return None + else: + def term_div(a_lm_a_lc, b_lm_b_lc): + a_lm, a_lc = a_lm_a_lc + b_lm, b_lc = b_lm_b_lc + if b_lm == zm: # apparently this is a very common case + monom = a_lm + else: + monom = monomial_div(a_lm, b_lm) + if not (monom is None or a_lc % b_lc): + return monom, domain_quo(a_lc, b_lc) + else: + return None + + return term_div + + def div(self, fv): + """Division algorithm, see [CLO] p64. + + fv array of polynomials + return qv, r such that + self = sum(fv[i]*qv[i]) + r + + All polynomials are required not to be Laurent polynomials. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring('x, y', ZZ) + >>> f = x**3 + >>> f0 = x - y**2 + >>> f1 = x - y + >>> qv, r = f.div((f0, f1)) + >>> qv[0] + x**2 + x*y**2 + y**4 + >>> qv[1] + 0 + >>> r + y**6 + + """ + ring = self.ring + ret_single = False + if isinstance(fv, PolyElement): + ret_single = True + fv = [fv] + if not all(fv): + raise ZeroDivisionError("polynomial division") + if not self: + if ret_single: + return ring.zero, ring.zero + else: + return [], ring.zero + for f in fv: + if f.ring != ring: + raise ValueError('self and f must have the same ring') + s = len(fv) + qv = [ring.zero for i in range(s)] + p = self.copy() + r = ring.zero + term_div = self._term_div() + expvs = [fx.leading_expv() for fx in fv] + while p: + i = 0 + divoccurred = 0 + while i < s and divoccurred == 0: + expv = p.leading_expv() + term = term_div((expv, p[expv]), (expvs[i], fv[i][expvs[i]])) + if term is not None: + expv1, c = term + qv[i] = qv[i]._iadd_monom((expv1, c)) + p = p._iadd_poly_monom(fv[i], (expv1, -c)) + divoccurred = 1 + else: + i += 1 + if not divoccurred: + expv = p.leading_expv() + r = r._iadd_monom((expv, p[expv])) + del p[expv] + if expv == ring.zero_monom: + r += p + if ret_single: + if not qv: + return ring.zero, r + else: + return qv[0], r + else: + return qv, r + + def rem(self, G): + f = self + if isinstance(G, PolyElement): + G = [G] + if not all(G): + raise ZeroDivisionError("polynomial division") + ring = f.ring + domain = ring.domain + zero = domain.zero + monomial_mul = ring.monomial_mul + r = ring.zero + term_div = f._term_div() + ltf = f.LT + f = f.copy() + get = f.get + while f: + for g in G: + tq = term_div(ltf, g.LT) + if tq is not None: + m, c = tq + for mg, cg in g.iterterms(): + m1 = monomial_mul(mg, m) + c1 = get(m1, zero) - c*cg + if not c1: + del f[m1] + else: + f[m1] = c1 + ltm = f.leading_expv() + if ltm is not None: + ltf = ltm, f[ltm] + + break + else: + ltm, ltc = ltf + if ltm in r: + r[ltm] += ltc + else: + r[ltm] = ltc + del f[ltm] + ltm = f.leading_expv() + if ltm is not None: + ltf = ltm, f[ltm] + + return r + + def quo(f, G): + return f.div(G)[0] + + def exquo(f, G): + q, r = f.div(G) + + if not r: + return q + else: + raise ExactQuotientFailed(f, G) + + def _iadd_monom(self, mc): + """add to self the monomial coeff*x0**i0*x1**i1*... + unless self is a generator -- then just return the sum of the two. + + mc is a tuple, (monom, coeff), where monomial is (i0, i1, ...) + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring('x, y', ZZ) + >>> p = x**4 + 2*y + >>> m = (1, 2) + >>> p1 = p._iadd_monom((m, 5)) + >>> p1 + x**4 + 5*x*y**2 + 2*y + >>> p1 is p + True + >>> p = x + >>> p1 = p._iadd_monom((m, 5)) + >>> p1 + 5*x*y**2 + x + >>> p1 is p + False + + """ + if self in self.ring._gens_set: + cpself = self.copy() + else: + cpself = self + expv, coeff = mc + c = cpself.get(expv) + if c is None: + cpself[expv] = coeff + else: + c += coeff + if c: + cpself[expv] = c + else: + del cpself[expv] + return cpself + + def _iadd_poly_monom(self, p2, mc): + """add to self the product of (p)*(coeff*x0**i0*x1**i1*...) + unless self is a generator -- then just return the sum of the two. + + mc is a tuple, (monom, coeff), where monomial is (i0, i1, ...) + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y, z = ring('x, y, z', ZZ) + >>> p1 = x**4 + 2*y + >>> p2 = y + z + >>> m = (1, 2, 3) + >>> p1 = p1._iadd_poly_monom(p2, (m, 3)) + >>> p1 + x**4 + 3*x*y**3*z**3 + 3*x*y**2*z**4 + 2*y + + """ + p1 = self + if p1 in p1.ring._gens_set: + p1 = p1.copy() + (m, c) = mc + get = p1.get + zero = p1.ring.domain.zero + monomial_mul = p1.ring.monomial_mul + for k, v in p2.items(): + ka = monomial_mul(k, m) + coeff = get(ka, zero) + v*c + if coeff: + p1[ka] = coeff + else: + del p1[ka] + return p1 + + def degree(f, x=None): + """ + The leading degree in ``x`` or the main variable. + + Note that the degree of 0 is negative infinity (``float('-inf')``) + + """ + i = f.ring.index(x) + + if not f: + return ninf + elif i < 0: + return 0 + else: + return max(monom[i] for monom in f.itermonoms()) + + def degrees(f): + """ + A tuple containing leading degrees in all variables. + + Note that the degree of 0 is negative infinity (``float('-inf')``) + + """ + if not f: + return (ninf,)*f.ring.ngens + else: + return tuple(map(max, list(zip(*f.itermonoms())))) + + def tail_degree(f, x=None): + """ + The tail degree in ``x`` or the main variable. + + Note that the degree of 0 is negative infinity (``float('-inf')``) + + """ + i = f.ring.index(x) + + if not f: + return ninf + elif i < 0: + return 0 + else: + return min(monom[i] for monom in f.itermonoms()) + + def tail_degrees(f): + """ + A tuple containing tail degrees in all variables. + + Note that the degree of 0 is negative infinity (``float('-inf')``) + + """ + if not f: + return (ninf,)*f.ring.ngens + else: + return tuple(map(min, list(zip(*f.itermonoms())))) + + def leading_expv(self): + """Leading monomial tuple according to the monomial ordering. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y, z = ring('x, y, z', ZZ) + >>> p = x**4 + x**3*y + x**2*z**2 + z**7 + >>> p.leading_expv() + (4, 0, 0) + + """ + if self: + return self.ring.leading_expv(self) + else: + return None + + def _get_coeff(self, expv): + return self.get(expv, self.ring.domain.zero) + + def coeff(self, element): + """ + Returns the coefficient that stands next to the given monomial. + + Parameters + ========== + + element : PolyElement (with ``is_monomial = True``) or 1 + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y, z = ring("x,y,z", ZZ) + >>> f = 3*x**2*y - x*y*z + 7*z**3 + 23 + + >>> f.coeff(x**2*y) + 3 + >>> f.coeff(x*y) + 0 + >>> f.coeff(1) + 23 + + """ + if element == 1: + return self._get_coeff(self.ring.zero_monom) + elif self.ring.is_element(element): + terms = list(element.iterterms()) + if len(terms) == 1: + monom, coeff = terms[0] + if coeff == self.ring.domain.one: + return self._get_coeff(monom) + + raise ValueError("expected a monomial, got %s" % element) + + def const(self): + """Returns the constant coefficient. """ + return self._get_coeff(self.ring.zero_monom) + + @property + def LC(self): + return self._get_coeff(self.leading_expv()) + + @property + def LM(self): + expv = self.leading_expv() + if expv is None: + return self.ring.zero_monom + else: + return expv + + def leading_monom(self): + """ + Leading monomial as a polynomial element. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring('x, y', ZZ) + >>> (3*x*y + y**2).leading_monom() + x*y + + """ + p = self.ring.zero + expv = self.leading_expv() + if expv: + p[expv] = self.ring.domain.one + return p + + @property + def LT(self): + expv = self.leading_expv() + if expv is None: + return (self.ring.zero_monom, self.ring.domain.zero) + else: + return (expv, self._get_coeff(expv)) + + def leading_term(self): + """Leading term as a polynomial element. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring('x, y', ZZ) + >>> (3*x*y + y**2).leading_term() + 3*x*y + + """ + p = self.ring.zero + expv = self.leading_expv() + if expv is not None: + p[expv] = self[expv] + return p + + def _sorted(self, seq, order): + if order is None: + order = self.ring.order + else: + order = OrderOpt.preprocess(order) + + if order is lex: + return sorted(seq, key=lambda monom: monom[0], reverse=True) + else: + return sorted(seq, key=lambda monom: order(monom[0]), reverse=True) + + def coeffs(self, order=None): + """Ordered list of polynomial coefficients. + + Parameters + ========== + + order : :class:`~.MonomialOrder` or coercible, optional + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.orderings import lex, grlex + + >>> _, x, y = ring("x, y", ZZ, lex) + >>> f = x*y**7 + 2*x**2*y**3 + + >>> f.coeffs() + [2, 1] + >>> f.coeffs(grlex) + [1, 2] + + """ + return [ coeff for _, coeff in self.terms(order) ] + + def monoms(self, order=None): + """Ordered list of polynomial monomials. + + Parameters + ========== + + order : :class:`~.MonomialOrder` or coercible, optional + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.orderings import lex, grlex + + >>> _, x, y = ring("x, y", ZZ, lex) + >>> f = x*y**7 + 2*x**2*y**3 + + >>> f.monoms() + [(2, 3), (1, 7)] + >>> f.monoms(grlex) + [(1, 7), (2, 3)] + + """ + return [ monom for monom, _ in self.terms(order) ] + + def terms(self, order=None): + """Ordered list of polynomial terms. + + Parameters + ========== + + order : :class:`~.MonomialOrder` or coercible, optional + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.orderings import lex, grlex + + >>> _, x, y = ring("x, y", ZZ, lex) + >>> f = x*y**7 + 2*x**2*y**3 + + >>> f.terms() + [((2, 3), 2), ((1, 7), 1)] + >>> f.terms(grlex) + [((1, 7), 1), ((2, 3), 2)] + + """ + return self._sorted(list(self.items()), order) + + def itercoeffs(self): + """Iterator over coefficients of a polynomial. """ + return iter(self.values()) + + def itermonoms(self): + """Iterator over monomials of a polynomial. """ + return iter(self.keys()) + + def iterterms(self): + """Iterator over terms of a polynomial. """ + return iter(self.items()) + + def listcoeffs(self): + """Unordered list of polynomial coefficients. """ + return list(self.values()) + + def listmonoms(self): + """Unordered list of polynomial monomials. """ + return list(self.keys()) + + def listterms(self): + """Unordered list of polynomial terms. """ + return list(self.items()) + + def imul_num(p, c): + """multiply inplace the polynomial p by an element in the + coefficient ring, provided p is not one of the generators; + else multiply not inplace + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring('x, y', ZZ) + >>> p = x + y**2 + >>> p1 = p.imul_num(3) + >>> p1 + 3*x + 3*y**2 + >>> p1 is p + True + >>> p = x + >>> p1 = p.imul_num(3) + >>> p1 + 3*x + >>> p1 is p + False + + """ + if p in p.ring._gens_set: + return p*c + if not c: + p.clear() + return + for exp in p: + p[exp] *= c + return p + + def content(f): + """Returns GCD of polynomial's coefficients. """ + domain = f.ring.domain + cont = domain.zero + gcd = domain.gcd + + for coeff in f.itercoeffs(): + cont = gcd(cont, coeff) + + return cont + + def primitive(f): + """Returns content and a primitive polynomial. """ + cont = f.content() + if cont == f.ring.domain.zero: + return (cont, f) + return cont, f.quo_ground(cont) + + def monic(f): + """Divides all coefficients by the leading coefficient. """ + if not f: + return f + else: + return f.quo_ground(f.LC) + + def mul_ground(f, x): + if not x: + return f.ring.zero + + terms = [ (monom, coeff*x) for monom, coeff in f.iterterms() ] + return f.new(terms) + + def mul_monom(f, monom): + monomial_mul = f.ring.monomial_mul + terms = [ (monomial_mul(f_monom, monom), f_coeff) for f_monom, f_coeff in f.items() ] + return f.new(terms) + + def mul_term(f, term): + monom, coeff = term + + if not f or not coeff: + return f.ring.zero + elif monom == f.ring.zero_monom: + return f.mul_ground(coeff) + + monomial_mul = f.ring.monomial_mul + terms = [ (monomial_mul(f_monom, monom), f_coeff*coeff) for f_monom, f_coeff in f.items() ] + return f.new(terms) + + def quo_ground(f, x): + domain = f.ring.domain + + if not x: + raise ZeroDivisionError('polynomial division') + if not f or x == domain.one: + return f + + if domain.is_Field: + quo = domain.quo + terms = [ (monom, quo(coeff, x)) for monom, coeff in f.iterterms() ] + else: + terms = [ (monom, coeff // x) for monom, coeff in f.iterterms() if not (coeff % x) ] + + return f.new(terms) + + def quo_term(f, term): + monom, coeff = term + + if not coeff: + raise ZeroDivisionError("polynomial division") + elif not f: + return f.ring.zero + elif monom == f.ring.zero_monom: + return f.quo_ground(coeff) + + term_div = f._term_div() + + terms = [ term_div(t, term) for t in f.iterterms() ] + return f.new([ t for t in terms if t is not None ]) + + def trunc_ground(f, p): + if f.ring.domain.is_ZZ: + terms = [] + + for monom, coeff in f.iterterms(): + coeff = coeff % p + + if coeff > p // 2: + coeff = coeff - p + + terms.append((monom, coeff)) + else: + terms = [ (monom, coeff % p) for monom, coeff in f.iterterms() ] + + poly = f.new(terms) + poly.strip_zero() + return poly + + rem_ground = trunc_ground + + def extract_ground(self, g): + f = self + fc = f.content() + gc = g.content() + + gcd = f.ring.domain.gcd(fc, gc) + + f = f.quo_ground(gcd) + g = g.quo_ground(gcd) + + return gcd, f, g + + def _norm(f, norm_func): + if not f: + return f.ring.domain.zero + else: + ground_abs = f.ring.domain.abs + return norm_func([ ground_abs(coeff) for coeff in f.itercoeffs() ]) + + def max_norm(f): + return f._norm(max) + + def l1_norm(f): + return f._norm(sum) + + def deflate(f, *G): + ring = f.ring + polys = [f] + list(G) + + J = [0]*ring.ngens + + for p in polys: + for monom in p.itermonoms(): + for i, m in enumerate(monom): + J[i] = igcd(J[i], m) + + for i, b in enumerate(J): + if not b: + J[i] = 1 + + J = tuple(J) + + if all(b == 1 for b in J): + return J, polys + + H = [] + + for p in polys: + h = ring.zero + + for I, coeff in p.iterterms(): + N = [ i // j for i, j in zip(I, J) ] + h[tuple(N)] = coeff + + H.append(h) + + return J, H + + def inflate(f, J): + poly = f.ring.zero + + for I, coeff in f.iterterms(): + N = [ i*j for i, j in zip(I, J) ] + poly[tuple(N)] = coeff + + return poly + + def lcm(self, g): + f = self + domain = f.ring.domain + + if not domain.is_Field: + fc, f = f.primitive() + gc, g = g.primitive() + c = domain.lcm(fc, gc) + + h = (f*g).quo(f.gcd(g)) + + if not domain.is_Field: + return h.mul_ground(c) + else: + return h.monic() + + def gcd(f, g): + return f.cofactors(g)[0] + + def cofactors(f, g): + if not f and not g: + zero = f.ring.zero + return zero, zero, zero + elif not f: + h, cff, cfg = f._gcd_zero(g) + return h, cff, cfg + elif not g: + h, cfg, cff = g._gcd_zero(f) + return h, cff, cfg + elif len(f) == 1: + h, cff, cfg = f._gcd_monom(g) + return h, cff, cfg + elif len(g) == 1: + h, cfg, cff = g._gcd_monom(f) + return h, cff, cfg + + J, (f, g) = f.deflate(g) + h, cff, cfg = f._gcd(g) + + return (h.inflate(J), cff.inflate(J), cfg.inflate(J)) + + def _gcd_zero(f, g): + one, zero = f.ring.one, f.ring.zero + if g.is_nonnegative: + return g, zero, one + else: + return -g, zero, -one + + def _gcd_monom(f, g): + ring = f.ring + ground_gcd = ring.domain.gcd + ground_quo = ring.domain.quo + monomial_gcd = ring.monomial_gcd + monomial_ldiv = ring.monomial_ldiv + mf, cf = list(f.iterterms())[0] + _mgcd, _cgcd = mf, cf + for mg, cg in g.iterterms(): + _mgcd = monomial_gcd(_mgcd, mg) + _cgcd = ground_gcd(_cgcd, cg) + h = f.new([(_mgcd, _cgcd)]) + cff = f.new([(monomial_ldiv(mf, _mgcd), ground_quo(cf, _cgcd))]) + cfg = f.new([(monomial_ldiv(mg, _mgcd), ground_quo(cg, _cgcd)) for mg, cg in g.iterterms()]) + return h, cff, cfg + + def _gcd(f, g): + ring = f.ring + + if ring.domain.is_QQ: + return f._gcd_QQ(g) + elif ring.domain.is_ZZ: + return f._gcd_ZZ(g) + else: # TODO: don't use dense representation (port PRS algorithms) + return ring.dmp_inner_gcd(f, g) + + def _gcd_ZZ(f, g): + return heugcd(f, g) + + def _gcd_QQ(self, g): + f = self + ring = f.ring + new_ring = ring.clone(domain=ring.domain.get_ring()) + + cf, f = f.clear_denoms() + cg, g = g.clear_denoms() + + f = f.set_ring(new_ring) + g = g.set_ring(new_ring) + + h, cff, cfg = f._gcd_ZZ(g) + + h = h.set_ring(ring) + c, h = h.LC, h.monic() + + cff = cff.set_ring(ring).mul_ground(ring.domain.quo(c, cf)) + cfg = cfg.set_ring(ring).mul_ground(ring.domain.quo(c, cg)) + + return h, cff, cfg + + def cancel(self, g): + """ + Cancel common factors in a rational function ``f/g``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> (2*x**2 - 2).cancel(x**2 - 2*x + 1) + (2*x + 2, x - 1) + + """ + f = self + ring = f.ring + + if not f: + return f, ring.one + + domain = ring.domain + + if not (domain.is_Field and domain.has_assoc_Ring): + _, p, q = f.cofactors(g) + else: + new_ring = ring.clone(domain=domain.get_ring()) + + cq, f = f.clear_denoms() + cp, g = g.clear_denoms() + + f = f.set_ring(new_ring) + g = g.set_ring(new_ring) + + _, p, q = f.cofactors(g) + _, cp, cq = new_ring.domain.cofactors(cp, cq) + + p = p.set_ring(ring) + q = q.set_ring(ring) + + p = p.mul_ground(cp) + q = q.mul_ground(cq) + + # Make canonical with respect to sign or quadrant in the case of ZZ_I + # or QQ_I. This ensures that the LC of the denominator is canonical by + # multiplying top and bottom by a unit of the ring. + u = q.canonical_unit() + if u == domain.one: + pass + elif u == -domain.one: + p, q = -p, -q + else: + p = p.mul_ground(u) + q = q.mul_ground(u) + + return p, q + + def canonical_unit(f): + domain = f.ring.domain + return domain.canonical_unit(f.LC) + + def diff(f, x): + """Computes partial derivative in ``x``. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + + >>> _, x, y = ring("x,y", ZZ) + >>> p = x + x**2*y**3 + >>> p.diff(x) + 2*x*y**3 + 1 + + """ + ring = f.ring + i = ring.index(x) + m = ring.monomial_basis(i) + g = ring.zero + for expv, coeff in f.iterterms(): + if expv[i]: + e = ring.monomial_ldiv(expv, m) + g[e] = ring.domain_new(coeff*expv[i]) + return g + + def __call__(f, *values): + if 0 < len(values) <= f.ring.ngens: + return f.evaluate(list(zip(f.ring.gens, values))) + else: + raise ValueError("expected at least 1 and at most %s values, got %s" % (f.ring.ngens, len(values))) + + def evaluate(self, x, a=None): + f = self + + if isinstance(x, list) and a is None: + (X, a), x = x[0], x[1:] + f = f.evaluate(X, a) + + if not x: + return f + else: + x = [ (Y.drop(X), a) for (Y, a) in x ] + return f.evaluate(x) + + ring = f.ring + i = ring.index(x) + a = ring.domain.convert(a) + + if ring.ngens == 1: + result = ring.domain.zero + + for (n,), coeff in f.iterterms(): + result += coeff*a**n + + return result + else: + poly = ring.drop(x).zero + + for monom, coeff in f.iterterms(): + n, monom = monom[i], monom[:i] + monom[i+1:] + coeff = coeff*a**n + + if monom in poly: + coeff = coeff + poly[monom] + + if coeff: + poly[monom] = coeff + else: + del poly[monom] + else: + if coeff: + poly[monom] = coeff + + return poly + + def subs(self, x, a=None): + f = self + + if isinstance(x, list) and a is None: + for X, a in x: + f = f.subs(X, a) + return f + + ring = f.ring + i = ring.index(x) + a = ring.domain.convert(a) + + if ring.ngens == 1: + result = ring.domain.zero + + for (n,), coeff in f.iterterms(): + result += coeff*a**n + + return ring.ground_new(result) + else: + poly = ring.zero + + for monom, coeff in f.iterterms(): + n, monom = monom[i], monom[:i] + (0,) + monom[i+1:] + coeff = coeff*a**n + + if monom in poly: + coeff = coeff + poly[monom] + + if coeff: + poly[monom] = coeff + else: + del poly[monom] + else: + if coeff: + poly[monom] = coeff + + return poly + + def symmetrize(self): + r""" + Rewrite *self* in terms of elementary symmetric polynomials. + + Explanation + =========== + + If this :py:class:`~.PolyElement` belongs to a ring of $n$ variables, + we can try to write it as a function of the elementary symmetric + polynomials on $n$ variables. We compute a symmetric part, and a + remainder for any part we were not able to symmetrize. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> R, x, y = ring("x,y", ZZ) + + >>> f = x**2 + y**2 + >>> f.symmetrize() + (x**2 - 2*y, 0, [(x, x + y), (y, x*y)]) + + >>> f = x**2 - y**2 + >>> f.symmetrize() + (x**2 - 2*y, -2*y**2, [(x, x + y), (y, x*y)]) + + Returns + ======= + + Triple ``(p, r, m)`` + ``p`` is a :py:class:`~.PolyElement` that represents our attempt + to express *self* as a function of elementary symmetric + polynomials. Each variable in ``p`` stands for one of the + elementary symmetric polynomials. The correspondence is given + by ``m``. + + ``r`` is the remainder. + + ``m`` is a list of pairs, giving the mapping from variables in + ``p`` to elementary symmetric polynomials. + + The triple satisfies the equation ``p.compose(m) + r == self``. + If the remainder ``r`` is zero, *self* is symmetric. If it is + nonzero, we were not able to represent *self* as symmetric. + + See Also + ======== + + sympy.polys.polyfuncs.symmetrize + + References + ========== + + .. [1] Lauer, E. Algorithms for symmetrical polynomials, Proc. 1976 + ACM Symp. on Symbolic and Algebraic Computing, NY 242-247. + https://dl.acm.org/doi/pdf/10.1145/800205.806342 + + """ + f = self.copy() + ring = f.ring + n = ring.ngens + + if not n: + return f, ring.zero, [] + + polys = [ring.symmetric_poly(i+1) for i in range(n)] + + poly_powers = {} + def get_poly_power(i, n): + if (i, n) not in poly_powers: + poly_powers[(i, n)] = polys[i]**n + return poly_powers[(i, n)] + + indices = list(range(n - 1)) + weights = list(range(n, 0, -1)) + + symmetric = ring.zero + + while f: + _height, _monom, _coeff = -1, None, None + + for i, (monom, coeff) in enumerate(f.terms()): + if all(monom[i] >= monom[i + 1] for i in indices): + height = max(n*m for n, m in zip(weights, monom)) + + if height > _height: + _height, _monom, _coeff = height, monom, coeff + + if _height != -1: + monom, coeff = _monom, _coeff + else: + break + + exponents = [] + for m1, m2 in zip(monom, monom[1:] + (0,)): + exponents.append(m1 - m2) + + symmetric += ring.term_new(tuple(exponents), coeff) + + product = coeff + for i, n in enumerate(exponents): + product *= get_poly_power(i, n) + f -= product + + mapping = list(zip(ring.gens, polys)) + + return symmetric, f, mapping + + def compose(f, x, a=None): + ring = f.ring + poly = ring.zero + gens_map = dict(zip(ring.gens, range(ring.ngens))) + + if a is not None: + replacements = [(x, a)] + else: + if isinstance(x, list): + replacements = list(x) + elif isinstance(x, dict): + replacements = sorted(x.items(), key=lambda k: gens_map[k[0]]) + else: + raise ValueError("expected a generator, value pair a sequence of such pairs") + + for k, (x, g) in enumerate(replacements): + replacements[k] = (gens_map[x], ring.ring_new(g)) + + for monom, coeff in f.iterterms(): + monom = list(monom) + subpoly = ring.one + + for i, g in replacements: + n, monom[i] = monom[i], 0 + if n: + subpoly *= g**n + + subpoly = subpoly.mul_term((tuple(monom), coeff)) + poly += subpoly + + return poly + + def coeff_wrt(self, x, deg): + """ + Coefficient of ``self`` with respect to ``x**deg``. + + Treating ``self`` as a univariate polynomial in ``x`` this finds the + coefficient of ``x**deg`` as a polynomial in the other generators. + + Parameters + ========== + + x : generator or generator index + The generator or generator index to compute the expression for. + deg : int + The degree of the monomial to compute the expression for. + + Returns + ======= + + :py:class:`~.PolyElement` + The coefficient of ``x**deg`` as a polynomial in the same ring. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x, y, z = ring("x, y, z", ZZ) + + >>> p = 2*x**4 + 3*y**4 + 10*z**2 + 10*x*z**2 + >>> deg = 2 + >>> p.coeff_wrt(2, deg) # Using the generator index + 10*x + 10 + >>> p.coeff_wrt(z, deg) # Using the generator + 10*x + 10 + >>> p.coeff(z**2) # shows the difference between coeff and coeff_wrt + 10 + + See Also + ======== + + coeff, coeffs + + """ + p = self + i = p.ring.index(x) + terms = [(m, c) for m, c in p.iterterms() if m[i] == deg] + + if not terms: + return p.ring.zero + + monoms, coeffs = zip(*terms) + monoms = [m[:i] + (0,) + m[i + 1:] for m in monoms] + return p.ring.from_dict(dict(zip(monoms, coeffs))) + + def prem(self, g, x=None): + """ + Pseudo-remainder of the polynomial ``self`` with respect to ``g``. + + The pseudo-quotient ``q`` and pseudo-remainder ``r`` with respect to + ``z`` when dividing ``f`` by ``g`` satisfy ``m*f = g*q + r``, + where ``deg(r,z) < deg(g,z)`` and + ``m = LC(g,z)**(deg(f,z) - deg(g,z)+1)``. + + See :meth:`pdiv` for explanation of pseudo-division. + + + Parameters + ========== + + g : :py:class:`~.PolyElement` + The polynomial to divide ``self`` by. + x : generator or generator index, optional + The main variable of the polynomials and default is first generator. + + Returns + ======= + + :py:class:`~.PolyElement` + The pseudo-remainder polynomial. + + Raises + ====== + + ZeroDivisionError : If ``g`` is the zero polynomial. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2 + x*y + >>> g = 2*x + 2 + >>> f.prem(g) # first generator is chosen by default if it is not given + -4*y + 4 + >>> f.rem(g) # shows the difference between prem and rem + x**2 + x*y + >>> f.prem(g, y) # generator is given + 0 + >>> f.prem(g, 1) # generator index is given + 0 + + See Also + ======== + + pdiv, pquo, pexquo, sympy.polys.domains.ring.Ring.rem + + """ + f = self + x = f.ring.index(x) + df = f.degree(x) + dg = g.degree(x) + + if dg < 0: + raise ZeroDivisionError('polynomial division') + + r, dr = f, df + + if df < dg: + return r + + N = df - dg + 1 + + lc_g = g.coeff_wrt(x, dg) + + xp = f.ring.gens[x] + + while True: + + lc_r = r.coeff_wrt(x, dr) + j, N = dr - dg, N - 1 + + R = r * lc_g + G = g * lc_r * xp**j + r = R - G + + dr = r.degree(x) + + if dr < dg: + break + + c = lc_g ** N + + return r * c + + def pdiv(self, g, x=None): + """ + Computes the pseudo-division of the polynomial ``self`` with respect to ``g``. + + The pseudo-division algorithm is used to find the pseudo-quotient ``q`` + and pseudo-remainder ``r`` such that ``m*f = g*q + r``, where ``m`` + represents the multiplier and ``f`` is the dividend polynomial. + + The pseudo-quotient ``q`` and pseudo-remainder ``r`` are polynomials in + the variable ``x``, with the degree of ``r`` with respect to ``x`` + being strictly less than the degree of ``g`` with respect to ``x``. + + The multiplier ``m`` is defined as + ``LC(g, x) ^ (deg(f, x) - deg(g, x) + 1)``, + where ``LC(g, x)`` represents the leading coefficient of ``g``. + + It is important to note that in the context of the ``prem`` method, + multivariate polynomials in a ring, such as ``R[x,y,z]``, are treated + as univariate polynomials with coefficients that are polynomials, + such as ``R[x,y][z]``. When dividing ``f`` by ``g`` with respect to the + variable ``z``, the pseudo-quotient ``q`` and pseudo-remainder ``r`` + satisfy ``m*f = g*q + r``, where ``deg(r, z) < deg(g, z)`` + and ``m = LC(g, z)^(deg(f, z) - deg(g, z) + 1)``. + + In this function, the pseudo-remainder ``r`` can be obtained using the + ``prem`` method, the pseudo-quotient ``q`` can + be obtained using the ``pquo`` method, and + the function ``pdiv`` itself returns a tuple ``(q, r)``. + + + Parameters + ========== + + g : :py:class:`~.PolyElement` + The polynomial to divide ``self`` by. + x : generator or generator index, optional + The main variable of the polynomials and default is first generator. + + Returns + ======= + + :py:class:`~.PolyElement` + The pseudo-division polynomial (tuple of ``q`` and ``r``). + + Raises + ====== + + ZeroDivisionError : If ``g`` is the zero polynomial. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2 + x*y + >>> g = 2*x + 2 + >>> f.pdiv(g) # first generator is chosen by default if it is not given + (2*x + 2*y - 2, -4*y + 4) + >>> f.div(g) # shows the difference between pdiv and div + (0, x**2 + x*y) + >>> f.pdiv(g, y) # generator is given + (2*x**3 + 2*x**2*y + 6*x**2 + 2*x*y + 8*x + 4, 0) + >>> f.pdiv(g, 1) # generator index is given + (2*x**3 + 2*x**2*y + 6*x**2 + 2*x*y + 8*x + 4, 0) + + See Also + ======== + + prem + Computes only the pseudo-remainder more efficiently than + `f.pdiv(g)[1]`. + pquo + Returns only the pseudo-quotient. + pexquo + Returns only an exact pseudo-quotient having no remainder. + div + Returns quotient and remainder of f and g polynomials. + + """ + f = self + x = f.ring.index(x) + + df = f.degree(x) + dg = g.degree(x) + + if dg < 0: + raise ZeroDivisionError("polynomial division") + + q, r, dr = x, f, df + + if df < dg: + return q, r + + N = df - dg + 1 + lc_g = g.coeff_wrt(x, dg) + + xp = f.ring.gens[x] + + while True: + + lc_r = r.coeff_wrt(x, dr) + j, N = dr - dg, N - 1 + + Q = q * lc_g + + q = Q + (lc_r)*xp**j + + R = r * lc_g + + G = g * lc_r * xp**j + + r = R - G + + dr = r.degree(x) + + if dr < dg: + break + + c = lc_g**N + + q = q * c + r = r * c + + return q, r + + def pquo(self, g, x=None): + """ + Polynomial pseudo-quotient in multivariate polynomial ring. + + Examples + ======== + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**2 + x*y + >>> g = 2*x + 2*y + >>> h = 2*x + 2 + >>> f.pquo(g) + 2*x + >>> f.quo(g) # shows the difference between pquo and quo + 0 + >>> f.pquo(h) + 2*x + 2*y - 2 + >>> f.quo(h) # shows the difference between pquo and quo + 0 + + See Also + ======== + + prem, pdiv, pexquo, sympy.polys.domains.ring.Ring.quo + + """ + f = self + return f.pdiv(g, x)[0] + + def pexquo(self, g, x=None): + """ + Polynomial exact pseudo-quotient in multivariate polynomial ring. + + Examples + ======== + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**2 + x*y + >>> g = 2*x + 2*y + >>> h = 2*x + 2 + >>> f.pexquo(g) + 2*x + >>> f.exquo(g) # shows the difference between pexquo and exquo + Traceback (most recent call last): + ... + ExactQuotientFailed: 2*x + 2*y does not divide x**2 + x*y + >>> f.pexquo(h) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2*x + 2 does not divide x**2 + x*y + + See Also + ======== + + prem, pdiv, pquo, sympy.polys.domains.ring.Ring.exquo + + """ + f = self + q, r = f.pdiv(g, x) + + if r.is_zero: + return q + else: + raise ExactQuotientFailed(f, g) + + def subresultants(self, g, x=None): + """ + Computes the subresultant PRS of two polynomials ``self`` and ``g``. + + Parameters + ========== + + g : :py:class:`~.PolyElement` + The second polynomial. + x : generator or generator index + The variable with respect to which the subresultant sequence is computed. + + Returns + ======= + + R : list + Returns a list polynomials representing the subresultant PRS. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2*y + x*y + >>> g = x + y + >>> f.subresultants(g) # first generator is chosen by default if not given + [x**2*y + x*y, x + y, y**3 - y**2] + >>> f.subresultants(g, 0) # generator index is given + [x**2*y + x*y, x + y, y**3 - y**2] + >>> f.subresultants(g, y) # generator is given + [x**2*y + x*y, x + y, x**3 + x**2] + + """ + f = self + x = f.ring.index(x) + n = f.degree(x) + m = g.degree(x) + + if n < m: + f, g = g, f + n, m = m, n + + if f == 0: + return [0, 0] + + if g == 0: + return [f, 1] + + R = [f, g] + + d = n - m + b = (-1) ** (d + 1) + + # Compute the pseudo-remainder for f and g + h = f.prem(g, x) + h = h * b + + # Compute the coefficient of g with respect to x**m + lc = g.coeff_wrt(x, m) + + c = lc ** d + + S = [1, c] + + c = -c + + while h: + k = h.degree(x) + + R.append(h) + f, g, m, d = g, h, k, m - k + + b = -lc * c ** d + h = f.prem(g, x) + h = h.exquo(b) + + lc = g.coeff_wrt(x, k) + + if d > 1: + p = (-lc) ** d + q = c ** (d - 1) + c = p.exquo(q) + else: + c = -lc + + S.append(-c) + + return R + + # TODO: following methods should point to polynomial + # representation independent algorithm implementations. + + def half_gcdex(f, g): + return f.ring.dmp_half_gcdex(f, g) + + def gcdex(f, g): + return f.ring.dmp_gcdex(f, g) + + def resultant(f, g): + return f.ring.dmp_resultant(f, g) + + def discriminant(f): + return f.ring.dmp_discriminant(f) + + def decompose(f): + if f.ring.is_univariate: + return f.ring.dup_decompose(f) + else: + raise MultivariatePolynomialError("polynomial decomposition") + + def shift(f, a): + if f.ring.is_univariate: + return f.ring.dup_shift(f, a) + else: + raise MultivariatePolynomialError("shift: use shift_list instead") + + def shift_list(f, a): + return f.ring.dmp_shift(f, a) + + def sturm(f): + if f.ring.is_univariate: + return f.ring.dup_sturm(f) + else: + raise MultivariatePolynomialError("sturm sequence") + + def gff_list(f): + return f.ring.dmp_gff_list(f) + + def norm(f): + return f.ring.dmp_norm(f) + + def sqf_norm(f): + return f.ring.dmp_sqf_norm(f) + + def sqf_part(f): + return f.ring.dmp_sqf_part(f) + + def sqf_list(f, all=False): + return f.ring.dmp_sqf_list(f, all=all) + + def factor_list(f): + return f.ring.dmp_factor_list(f) diff --git a/phivenv/Lib/site-packages/sympy/polys/rootisolation.py b/phivenv/Lib/site-packages/sympy/polys/rootisolation.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f8fd115e49ce8dcf4db8659a60c3361818b7bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/rootisolation.py @@ -0,0 +1,2190 @@ +"""Real and complex root isolation and refinement algorithms. """ + + +from sympy.polys.densearith import ( + dup_neg, dup_rshift, dup_rem, + dup_l2_norm_squared) +from sympy.polys.densebasic import ( + dup_LC, dup_TC, dup_degree, + dup_strip, dup_reverse, + dup_convert, + dup_terms_gcd) +from sympy.polys.densetools import ( + dup_clear_denoms, + dup_mirror, dup_scale, dup_shift, + dup_transform, + dup_diff, + dup_eval, dmp_eval_in, + dup_sign_variations, + dup_real_imag) +from sympy.polys.euclidtools import ( + dup_discriminant) +from sympy.polys.factortools import ( + dup_factor_list) +from sympy.polys.polyerrors import ( + RefinementFailed, + DomainError, + PolynomialError) +from sympy.polys.sqfreetools import ( + dup_sqf_part, dup_sqf_list) + + +def dup_sturm(f, K): + """ + Computes the Sturm sequence of ``f`` in ``F[x]``. + + Given a univariate, square-free polynomial ``f(x)`` returns the + associated Sturm sequence ``f_0(x), ..., f_n(x)`` defined by:: + + f_0(x), f_1(x) = f(x), f'(x) + f_n = -rem(f_{n-2}(x), f_{n-1}(x)) + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> R.dup_sturm(x**3 - 2*x**2 + x - 3) + [x**3 - 2*x**2 + x - 3, 3*x**2 - 4*x + 1, 2/9*x + 25/9, -2079/4] + + References + ========== + + .. [1] [Davenport88]_ + + """ + if not K.is_Field: + raise DomainError("Cannot compute Sturm sequence over %s" % K) + + f = dup_sqf_part(f, K) + + sturm = [f, dup_diff(f, 1, K)] + + while sturm[-1]: + s = dup_rem(sturm[-2], sturm[-1], K) + sturm.append(dup_neg(s, K)) + + return sturm[:-1] + +def dup_root_upper_bound(f, K): + """Compute the LMQ upper bound for the positive roots of `f`; + LMQ (Local Max Quadratic) was developed by Akritas-Strzebonski-Vigklas. + + References + ========== + .. [1] Alkiviadis G. Akritas: "Linear and Quadratic Complexity Bounds on the + Values of the Positive Roots of Polynomials" + Journal of Universal Computer Science, Vol. 15, No. 3, 523-537, 2009. + """ + n, P = len(f), [] + t = n * [K.one] + if dup_LC(f, K) < 0: + f = dup_neg(f, K) + f = list(reversed(f)) + + for i in range(0, n): + if f[i] >= 0: + continue + + a, QL = K.log(-f[i], 2), [] + + for j in range(i + 1, n): + + if f[j] <= 0: + continue + + q = t[j] + a - K.log(f[j], 2) + QL.append([q // (j - i), j]) + + if not QL: + continue + + q = min(QL) + + t[q[1]] = t[q[1]] + 1 + + P.append(q[0]) + + if not P: + return None + else: + return K.get_field()(2)**(max(P) + 1) + +def dup_root_lower_bound(f, K): + """Compute the LMQ lower bound for the positive roots of `f`; + LMQ (Local Max Quadratic) was developed by Akritas-Strzebonski-Vigklas. + + References + ========== + .. [1] Alkiviadis G. Akritas: "Linear and Quadratic Complexity Bounds on the + Values of the Positive Roots of Polynomials" + Journal of Universal Computer Science, Vol. 15, No. 3, 523-537, 2009. + """ + bound = dup_root_upper_bound(dup_reverse(f), K) + + if bound is not None: + return 1/bound + else: + return None + +def dup_cauchy_upper_bound(f, K): + """ + Compute the Cauchy upper bound on the absolute value of all roots of f, + real or complex. + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots#Lagrange's_and_Cauchy's_bounds + """ + n = dup_degree(f) + if n < 1: + raise PolynomialError('Polynomial has no roots.') + + if K.is_ZZ: + L = K.get_field() + f, K = dup_convert(f, K, L), L + elif not K.is_QQ or K.is_RR or K.is_CC: + # We need to compute absolute value, and we are not supporting cases + # where this would take us outside the domain (or its quotient field). + raise DomainError('Cauchy bound not supported over %s' % K) + else: + f = f[:] + + while K.is_zero(f[-1]): + f.pop() + if len(f) == 1: + # Monomial. All roots are zero. + return K.zero + + lc = f[0] + return K.one + max(abs(n / lc) for n in f[1:]) + +def dup_cauchy_lower_bound(f, K): + """Compute the Cauchy lower bound on the absolute value of all non-zero + roots of f, real or complex.""" + g = dup_reverse(f) + if len(g) < 2: + raise PolynomialError('Polynomial has no non-zero roots.') + if K.is_ZZ: + K = K.get_field() + b = dup_cauchy_upper_bound(g, K) + return K.one / b + +def dup_mignotte_sep_bound_squared(f, K): + """ + Return the square of the Mignotte lower bound on separation between + distinct roots of f. The square is returned so that the bound lies in + K or its quotient field. + + References + ========== + + .. [1] Mignotte, Maurice. "Some useful bounds." Computer algebra. + Springer, Vienna, 1982. 259-263. + https://people.dm.unipi.it/gianni/AC-EAG/Mignotte.pdf + """ + n = dup_degree(f) + if n < 2: + raise PolynomialError('Polynomials of degree < 2 have no distinct roots.') + + if K.is_ZZ: + L = K.get_field() + f, K = dup_convert(f, K, L), L + elif not K.is_QQ or K.is_RR or K.is_CC: + # We need to compute absolute value, and we are not supporting cases + # where this would take us outside the domain (or its quotient field). + raise DomainError('Mignotte bound not supported over %s' % K) + + D = dup_discriminant(f, K) + l2sq = dup_l2_norm_squared(f, K) + return K(3)*K.abs(D) / ( K(n)**(n+1) * l2sq**(n-1) ) + +def _mobius_from_interval(I, field): + """Convert an open interval to a Mobius transform. """ + s, t = I + + a, c = field.numer(s), field.denom(s) + b, d = field.numer(t), field.denom(t) + + return a, b, c, d + +def _mobius_to_interval(M, field): + """Convert a Mobius transform to an open interval. """ + a, b, c, d = M + + s, t = field(a, c), field(b, d) + + if s <= t: + return (s, t) + else: + return (t, s) + +def dup_step_refine_real_root(f, M, K, fast=False): + """One step of positive real root refinement algorithm. """ + a, b, c, d = M + + if a == b and c == d: + return f, (a, b, c, d) + + A = dup_root_lower_bound(f, K) + + if A is not None: + A = K(int(A)) + else: + A = K.zero + + if fast and A > 16: + f = dup_scale(f, A, K) + a, c, A = A*a, A*c, K.one + + if A >= K.one: + f = dup_shift(f, A, K) + b, d = A*a + b, A*c + d + + if not dup_eval(f, K.zero, K): + return f, (b, b, d, d) + + f, g = dup_shift(f, K.one, K), f + + a1, b1, c1, d1 = a, a + b, c, c + d + + if not dup_eval(f, K.zero, K): + return f, (b1, b1, d1, d1) + + k = dup_sign_variations(f, K) + + if k == 1: + a, b, c, d = a1, b1, c1, d1 + else: + f = dup_shift(dup_reverse(g), K.one, K) + + if not dup_eval(f, K.zero, K): + f = dup_rshift(f, 1, K) + + a, b, c, d = b, a + b, d, c + d + + return f, (a, b, c, d) + +def dup_inner_refine_real_root(f, M, K, eps=None, steps=None, disjoint=None, fast=False, mobius=False): + """Refine a positive root of `f` given a Mobius transform or an interval. """ + F = K.get_field() + + if len(M) == 2: + a, b, c, d = _mobius_from_interval(M, F) + else: + a, b, c, d = M + + while not c: + f, (a, b, c, d) = dup_step_refine_real_root(f, (a, b, c, + d), K, fast=fast) + + if eps is not None and steps is not None: + for i in range(0, steps): + if abs(F(a, c) - F(b, d)) >= eps: + f, (a, b, c, d) = dup_step_refine_real_root(f, (a, b, c, d), K, fast=fast) + else: + break + else: + if eps is not None: + while abs(F(a, c) - F(b, d)) >= eps: + f, (a, b, c, d) = dup_step_refine_real_root(f, (a, b, c, d), K, fast=fast) + + if steps is not None: + for i in range(0, steps): + f, (a, b, c, d) = dup_step_refine_real_root(f, (a, b, c, d), K, fast=fast) + + if disjoint is not None: + while True: + u, v = _mobius_to_interval((a, b, c, d), F) + + if v <= disjoint or disjoint <= u: + break + else: + f, (a, b, c, d) = dup_step_refine_real_root(f, (a, b, c, d), K, fast=fast) + + if not mobius: + return _mobius_to_interval((a, b, c, d), F) + else: + return f, (a, b, c, d) + +def dup_outer_refine_real_root(f, s, t, K, eps=None, steps=None, disjoint=None, fast=False): + """Refine a positive root of `f` given an interval `(s, t)`. """ + a, b, c, d = _mobius_from_interval((s, t), K.get_field()) + + f = dup_transform(f, dup_strip([a, b]), + dup_strip([c, d]), K) + + if dup_sign_variations(f, K) != 1: + raise RefinementFailed("there should be exactly one root in (%s, %s) interval" % (s, t)) + + return dup_inner_refine_real_root(f, (a, b, c, d), K, eps=eps, steps=steps, disjoint=disjoint, fast=fast) + +def dup_refine_real_root(f, s, t, K, eps=None, steps=None, disjoint=None, fast=False): + """Refine real root's approximating interval to the given precision. """ + if K.is_QQ: + (_, f), K = dup_clear_denoms(f, K, convert=True), K.get_ring() + elif not K.is_ZZ: + raise DomainError("real root refinement not supported over %s" % K) + + if s == t: + return (s, t) + + if s > t: + s, t = t, s + + negative = False + + if s < 0: + if t <= 0: + f, s, t, negative = dup_mirror(f, K), -t, -s, True + else: + raise ValueError("Cannot refine a real root in (%s, %s)" % (s, t)) + + if negative and disjoint is not None: + if disjoint < 0: + disjoint = -disjoint + else: + disjoint = None + + s, t = dup_outer_refine_real_root( + f, s, t, K, eps=eps, steps=steps, disjoint=disjoint, fast=fast) + + if negative: + return (-t, -s) + else: + return ( s, t) + +def dup_inner_isolate_real_roots(f, K, eps=None, fast=False): + """Internal function for isolation positive roots up to given precision. + + References + ========== + 1. Alkiviadis G. Akritas and Adam W. Strzebonski: A Comparative Study of Two Real Root + Isolation Methods . Nonlinear Analysis: Modelling and Control, Vol. 10, No. 4, 297-304, 2005. + 2. Alkiviadis G. Akritas, Adam W. Strzebonski and Panagiotis S. Vigklas: Improving the + Performance of the Continued Fractions Method Using new Bounds of Positive Roots. Nonlinear + Analysis: Modelling and Control, Vol. 13, No. 3, 265-279, 2008. + """ + a, b, c, d = K.one, K.zero, K.zero, K.one + + k = dup_sign_variations(f, K) + + if k == 0: + return [] + if k == 1: + roots = [dup_inner_refine_real_root( + f, (a, b, c, d), K, eps=eps, fast=fast, mobius=True)] + else: + roots, stack = [], [(a, b, c, d, f, k)] + + while stack: + a, b, c, d, f, k = stack.pop() + + A = dup_root_lower_bound(f, K) + + if A is not None: + A = K(int(A)) + else: + A = K.zero + + if fast and A > 16: + f = dup_scale(f, A, K) + a, c, A = A*a, A*c, K.one + + if A >= K.one: + f = dup_shift(f, A, K) + b, d = A*a + b, A*c + d + + if not dup_TC(f, K): + roots.append((f, (b, b, d, d))) + f = dup_rshift(f, 1, K) + + k = dup_sign_variations(f, K) + + if k == 0: + continue + if k == 1: + roots.append(dup_inner_refine_real_root( + f, (a, b, c, d), K, eps=eps, fast=fast, mobius=True)) + continue + + f1 = dup_shift(f, K.one, K) + + a1, b1, c1, d1, r = a, a + b, c, c + d, 0 + + if not dup_TC(f1, K): + roots.append((f1, (b1, b1, d1, d1))) + f1, r = dup_rshift(f1, 1, K), 1 + + k1 = dup_sign_variations(f1, K) + k2 = k - k1 - r + + a2, b2, c2, d2 = b, a + b, d, c + d + + if k2 > 1: + f2 = dup_shift(dup_reverse(f), K.one, K) + + if not dup_TC(f2, K): + f2 = dup_rshift(f2, 1, K) + + k2 = dup_sign_variations(f2, K) + else: + f2 = None + + if k1 < k2: + a1, a2, b1, b2 = a2, a1, b2, b1 + c1, c2, d1, d2 = c2, c1, d2, d1 + f1, f2, k1, k2 = f2, f1, k2, k1 + + if not k1: + continue + + if f1 is None: + f1 = dup_shift(dup_reverse(f), K.one, K) + + if not dup_TC(f1, K): + f1 = dup_rshift(f1, 1, K) + + if k1 == 1: + roots.append(dup_inner_refine_real_root( + f1, (a1, b1, c1, d1), K, eps=eps, fast=fast, mobius=True)) + else: + stack.append((a1, b1, c1, d1, f1, k1)) + + if not k2: + continue + + if f2 is None: + f2 = dup_shift(dup_reverse(f), K.one, K) + + if not dup_TC(f2, K): + f2 = dup_rshift(f2, 1, K) + + if k2 == 1: + roots.append(dup_inner_refine_real_root( + f2, (a2, b2, c2, d2), K, eps=eps, fast=fast, mobius=True)) + else: + stack.append((a2, b2, c2, d2, f2, k2)) + + return roots + +def _discard_if_outside_interval(f, M, inf, sup, K, negative, fast, mobius): + """Discard an isolating interval if outside ``(inf, sup)``. """ + F = K.get_field() + + while True: + u, v = _mobius_to_interval(M, F) + + if negative: + u, v = -v, -u + + if (inf is None or u >= inf) and (sup is None or v <= sup): + if not mobius: + return u, v + else: + return f, M + elif (sup is not None and u > sup) or (inf is not None and v < inf): + return None + else: + f, M = dup_step_refine_real_root(f, M, K, fast=fast) + +def dup_inner_isolate_positive_roots(f, K, eps=None, inf=None, sup=None, fast=False, mobius=False): + """Iteratively compute disjoint positive root isolation intervals. """ + if sup is not None and sup < 0: + return [] + + roots = dup_inner_isolate_real_roots(f, K, eps=eps, fast=fast) + + F, results = K.get_field(), [] + + if inf is not None or sup is not None: + for f, M in roots: + result = _discard_if_outside_interval(f, M, inf, sup, K, False, fast, mobius) + + if result is not None: + results.append(result) + elif not mobius: + results.extend(_mobius_to_interval(M, F) for _, M in roots) + else: + results = roots + + return results + +def dup_inner_isolate_negative_roots(f, K, inf=None, sup=None, eps=None, fast=False, mobius=False): + """Iteratively compute disjoint negative root isolation intervals. """ + if inf is not None and inf >= 0: + return [] + + roots = dup_inner_isolate_real_roots(dup_mirror(f, K), K, eps=eps, fast=fast) + + F, results = K.get_field(), [] + + if inf is not None or sup is not None: + for f, M in roots: + result = _discard_if_outside_interval(f, M, inf, sup, K, True, fast, mobius) + + if result is not None: + results.append(result) + elif not mobius: + for f, M in roots: + u, v = _mobius_to_interval(M, F) + results.append((-v, -u)) + else: + results = roots + + return results + +def _isolate_zero(f, K, inf, sup, basis=False, sqf=False): + """Handle special case of CF algorithm when ``f`` is homogeneous. """ + j, f = dup_terms_gcd(f, K) + + if j > 0: + F = K.get_field() + + if (inf is None or inf <= 0) and (sup is None or 0 <= sup): + if not sqf: + if not basis: + return [((F.zero, F.zero), j)], f + else: + return [((F.zero, F.zero), j, [K.one, K.zero])], f + else: + return [(F.zero, F.zero)], f + + return [], f + +def dup_isolate_real_roots_sqf(f, K, eps=None, inf=None, sup=None, fast=False, blackbox=False): + """Isolate real roots of a square-free polynomial using the Vincent-Akritas-Strzebonski (VAS) CF approach. + + References + ========== + .. [1] Alkiviadis G. Akritas and Adam W. Strzebonski: A Comparative + Study of Two Real Root Isolation Methods. Nonlinear Analysis: + Modelling and Control, Vol. 10, No. 4, 297-304, 2005. + .. [2] Alkiviadis G. Akritas, Adam W. Strzebonski and Panagiotis S. + Vigklas: Improving the Performance of the Continued Fractions + Method Using New Bounds of Positive Roots. Nonlinear Analysis: + Modelling and Control, Vol. 13, No. 3, 265-279, 2008. + + """ + if K.is_QQ: + (_, f), K = dup_clear_denoms(f, K, convert=True), K.get_ring() + elif not K.is_ZZ: + raise DomainError("isolation of real roots not supported over %s" % K) + + if dup_degree(f) <= 0: + return [] + + I_zero, f = _isolate_zero(f, K, inf, sup, basis=False, sqf=True) + + I_neg = dup_inner_isolate_negative_roots(f, K, eps=eps, inf=inf, sup=sup, fast=fast) + I_pos = dup_inner_isolate_positive_roots(f, K, eps=eps, inf=inf, sup=sup, fast=fast) + + roots = sorted(I_neg + I_zero + I_pos) + + if not blackbox: + return roots + else: + return [ RealInterval((a, b), f, K) for (a, b) in roots ] + +def dup_isolate_real_roots(f, K, eps=None, inf=None, sup=None, basis=False, fast=False): + """Isolate real roots using Vincent-Akritas-Strzebonski (VAS) continued fractions approach. + + References + ========== + + .. [1] Alkiviadis G. Akritas and Adam W. Strzebonski: A Comparative + Study of Two Real Root Isolation Methods. Nonlinear Analysis: + Modelling and Control, Vol. 10, No. 4, 297-304, 2005. + .. [2] Alkiviadis G. Akritas, Adam W. Strzebonski and Panagiotis S. + Vigklas: Improving the Performance of the Continued Fractions + Method Using New Bounds of Positive Roots. + Nonlinear Analysis: Modelling and Control, Vol. 13, No. 3, 265-279, 2008. + + """ + if K.is_QQ: + (_, f), K = dup_clear_denoms(f, K, convert=True), K.get_ring() + elif not K.is_ZZ: + raise DomainError("isolation of real roots not supported over %s" % K) + + if dup_degree(f) <= 0: + return [] + + I_zero, f = _isolate_zero(f, K, inf, sup, basis=basis, sqf=False) + + _, factors = dup_sqf_list(f, K) + + if len(factors) == 1: + ((f, k),) = factors + + I_neg = dup_inner_isolate_negative_roots(f, K, eps=eps, inf=inf, sup=sup, fast=fast) + I_pos = dup_inner_isolate_positive_roots(f, K, eps=eps, inf=inf, sup=sup, fast=fast) + + I_neg = [ ((u, v), k) for u, v in I_neg ] + I_pos = [ ((u, v), k) for u, v in I_pos ] + else: + I_neg, I_pos = _real_isolate_and_disjoin(factors, K, + eps=eps, inf=inf, sup=sup, basis=basis, fast=fast) + + return sorted(I_neg + I_zero + I_pos) + +def dup_isolate_real_roots_list(polys, K, eps=None, inf=None, sup=None, strict=False, basis=False, fast=False): + """Isolate real roots of a list of polynomial using Vincent-Akritas-Strzebonski (VAS) CF approach. + + References + ========== + + .. [1] Alkiviadis G. Akritas and Adam W. Strzebonski: A Comparative + Study of Two Real Root Isolation Methods. Nonlinear Analysis: + Modelling and Control, Vol. 10, No. 4, 297-304, 2005. + .. [2] Alkiviadis G. Akritas, Adam W. Strzebonski and Panagiotis S. + Vigklas: Improving the Performance of the Continued Fractions + Method Using New Bounds of Positive Roots. + Nonlinear Analysis: Modelling and Control, Vol. 13, No. 3, 265-279, 2008. + + """ + if K.is_QQ: + K, F, polys = K.get_ring(), K, polys[:] + + for i, p in enumerate(polys): + polys[i] = dup_clear_denoms(p, F, K, convert=True)[1] + elif not K.is_ZZ: + raise DomainError("isolation of real roots not supported over %s" % K) + + zeros, factors_dict = False, {} + + if (inf is None or inf <= 0) and (sup is None or 0 <= sup): + zeros, zero_indices = True, {} + + for i, p in enumerate(polys): + j, p = dup_terms_gcd(p, K) + + if zeros and j > 0: + zero_indices[i] = j + + for f, k in dup_factor_list(p, K)[1]: + f = tuple(f) + + if f not in factors_dict: + factors_dict[f] = {i: k} + else: + factors_dict[f][i] = k + + factors_list = [(list(f), indices) for f, indices in factors_dict.items()] + I_neg, I_pos = _real_isolate_and_disjoin(factors_list, K, eps=eps, + inf=inf, sup=sup, strict=strict, basis=basis, fast=fast) + + F = K.get_field() + + if not zeros or not zero_indices: + I_zero = [] + else: + if not basis: + I_zero = [((F.zero, F.zero), zero_indices)] + else: + I_zero = [((F.zero, F.zero), zero_indices, [K.one, K.zero])] + + return sorted(I_neg + I_zero + I_pos) + +def _disjoint_p(M, N, strict=False): + """Check if Mobius transforms define disjoint intervals. """ + a1, b1, c1, d1 = M + a2, b2, c2, d2 = N + + a1d1, b1c1 = a1*d1, b1*c1 + a2d2, b2c2 = a2*d2, b2*c2 + + if a1d1 == b1c1 and a2d2 == b2c2: + return True + + if a1d1 > b1c1: + a1, c1, b1, d1 = b1, d1, a1, c1 + + if a2d2 > b2c2: + a2, c2, b2, d2 = b2, d2, a2, c2 + + if not strict: + return a2*d1 >= c2*b1 or b2*c1 <= d2*a1 + else: + return a2*d1 > c2*b1 or b2*c1 < d2*a1 + +def _real_isolate_and_disjoin(factors, K, eps=None, inf=None, sup=None, strict=False, basis=False, fast=False): + """Isolate real roots of a list of polynomials and disjoin intervals. """ + I_pos, I_neg = [], [] + + for i, (f, k) in enumerate(factors): + for F, M in dup_inner_isolate_positive_roots(f, K, eps=eps, inf=inf, sup=sup, fast=fast, mobius=True): + I_pos.append((F, M, k, f)) + + for G, N in dup_inner_isolate_negative_roots(f, K, eps=eps, inf=inf, sup=sup, fast=fast, mobius=True): + I_neg.append((G, N, k, f)) + + for i, (f, M, k, F) in enumerate(I_pos): + for j, (g, N, m, G) in enumerate(I_pos[i + 1:]): + while not _disjoint_p(M, N, strict=strict): + f, M = dup_inner_refine_real_root(f, M, K, steps=1, fast=fast, mobius=True) + g, N = dup_inner_refine_real_root(g, N, K, steps=1, fast=fast, mobius=True) + + I_pos[i + j + 1] = (g, N, m, G) + + I_pos[i] = (f, M, k, F) + + for i, (f, M, k, F) in enumerate(I_neg): + for j, (g, N, m, G) in enumerate(I_neg[i + 1:]): + while not _disjoint_p(M, N, strict=strict): + f, M = dup_inner_refine_real_root(f, M, K, steps=1, fast=fast, mobius=True) + g, N = dup_inner_refine_real_root(g, N, K, steps=1, fast=fast, mobius=True) + + I_neg[i + j + 1] = (g, N, m, G) + + I_neg[i] = (f, M, k, F) + + if strict: + for i, (f, M, k, F) in enumerate(I_neg): + if not M[0]: + while not M[0]: + f, M = dup_inner_refine_real_root(f, M, K, steps=1, fast=fast, mobius=True) + + I_neg[i] = (f, M, k, F) + break + + for j, (g, N, m, G) in enumerate(I_pos): + if not N[0]: + while not N[0]: + g, N = dup_inner_refine_real_root(g, N, K, steps=1, fast=fast, mobius=True) + + I_pos[j] = (g, N, m, G) + break + + field = K.get_field() + + I_neg = [ (_mobius_to_interval(M, field), k, f) for (_, M, k, f) in I_neg ] + I_pos = [ (_mobius_to_interval(M, field), k, f) for (_, M, k, f) in I_pos ] + + I_neg = [((-v, -u), k, f) for ((u, v), k, f) in I_neg] + + if not basis: + I_neg = [((u, v), k) for ((u, v), k, _) in I_neg] + I_pos = [((u, v), k) for ((u, v), k, _) in I_pos] + + return I_neg, I_pos + +def dup_count_real_roots(f, K, inf=None, sup=None): + """Returns the number of distinct real roots of ``f`` in ``[inf, sup]``. """ + if dup_degree(f) <= 0: + return 0 + + if not K.is_Field: + R, K = K, K.get_field() + f = dup_convert(f, R, K) + + sturm = dup_sturm(f, K) + + if inf is None: + signs_inf = dup_sign_variations([ dup_LC(s, K)*(-1)**dup_degree(s) for s in sturm ], K) + else: + signs_inf = dup_sign_variations([ dup_eval(s, inf, K) for s in sturm ], K) + + if sup is None: + signs_sup = dup_sign_variations([ dup_LC(s, K) for s in sturm ], K) + else: + signs_sup = dup_sign_variations([ dup_eval(s, sup, K) for s in sturm ], K) + + count = abs(signs_inf - signs_sup) + + if inf is not None and not dup_eval(f, inf, K): + count += 1 + + return count + +OO = 'OO' # Origin of (re, im) coordinate system + +Q1 = 'Q1' # Quadrant #1 (++): re > 0 and im > 0 +Q2 = 'Q2' # Quadrant #2 (-+): re < 0 and im > 0 +Q3 = 'Q3' # Quadrant #3 (--): re < 0 and im < 0 +Q4 = 'Q4' # Quadrant #4 (+-): re > 0 and im < 0 + +A1 = 'A1' # Axis #1 (+0): re > 0 and im = 0 +A2 = 'A2' # Axis #2 (0+): re = 0 and im > 0 +A3 = 'A3' # Axis #3 (-0): re < 0 and im = 0 +A4 = 'A4' # Axis #4 (0-): re = 0 and im < 0 + +_rules_simple = { + # Q --> Q (same) => no change + (Q1, Q1): 0, + (Q2, Q2): 0, + (Q3, Q3): 0, + (Q4, Q4): 0, + + # A -- CCW --> Q => +1/4 (CCW) + (A1, Q1): 1, + (A2, Q2): 1, + (A3, Q3): 1, + (A4, Q4): 1, + + # A -- CW --> Q => -1/4 (CCW) + (A1, Q4): 2, + (A2, Q1): 2, + (A3, Q2): 2, + (A4, Q3): 2, + + # Q -- CCW --> A => +1/4 (CCW) + (Q1, A2): 3, + (Q2, A3): 3, + (Q3, A4): 3, + (Q4, A1): 3, + + # Q -- CW --> A => -1/4 (CCW) + (Q1, A1): 4, + (Q2, A2): 4, + (Q3, A3): 4, + (Q4, A4): 4, + + # Q -- CCW --> Q => +1/2 (CCW) + (Q1, Q2): +5, + (Q2, Q3): +5, + (Q3, Q4): +5, + (Q4, Q1): +5, + + # Q -- CW --> Q => -1/2 (CW) + (Q1, Q4): -5, + (Q2, Q1): -5, + (Q3, Q2): -5, + (Q4, Q3): -5, +} + +_rules_ambiguous = { + # A -- CCW --> Q => { +1/4 (CCW), -9/4 (CW) } + (A1, OO, Q1): -1, + (A2, OO, Q2): -1, + (A3, OO, Q3): -1, + (A4, OO, Q4): -1, + + # A -- CW --> Q => { -1/4 (CCW), +7/4 (CW) } + (A1, OO, Q4): -2, + (A2, OO, Q1): -2, + (A3, OO, Q2): -2, + (A4, OO, Q3): -2, + + # Q -- CCW --> A => { +1/4 (CCW), -9/4 (CW) } + (Q1, OO, A2): -3, + (Q2, OO, A3): -3, + (Q3, OO, A4): -3, + (Q4, OO, A1): -3, + + # Q -- CW --> A => { -1/4 (CCW), +7/4 (CW) } + (Q1, OO, A1): -4, + (Q2, OO, A2): -4, + (Q3, OO, A3): -4, + (Q4, OO, A4): -4, + + # A -- OO --> A => { +1 (CCW), -1 (CW) } + (A1, A3): 7, + (A2, A4): 7, + (A3, A1): 7, + (A4, A2): 7, + + (A1, OO, A3): 7, + (A2, OO, A4): 7, + (A3, OO, A1): 7, + (A4, OO, A2): 7, + + # Q -- DIA --> Q => { +1 (CCW), -1 (CW) } + (Q1, Q3): 8, + (Q2, Q4): 8, + (Q3, Q1): 8, + (Q4, Q2): 8, + + (Q1, OO, Q3): 8, + (Q2, OO, Q4): 8, + (Q3, OO, Q1): 8, + (Q4, OO, Q2): 8, + + # A --- R ---> A => { +1/2 (CCW), -3/2 (CW) } + (A1, A2): 9, + (A2, A3): 9, + (A3, A4): 9, + (A4, A1): 9, + + (A1, OO, A2): 9, + (A2, OO, A3): 9, + (A3, OO, A4): 9, + (A4, OO, A1): 9, + + # A --- L ---> A => { +3/2 (CCW), -1/2 (CW) } + (A1, A4): 10, + (A2, A1): 10, + (A3, A2): 10, + (A4, A3): 10, + + (A1, OO, A4): 10, + (A2, OO, A1): 10, + (A3, OO, A2): 10, + (A4, OO, A3): 10, + + # Q --- 1 ---> A => { +3/4 (CCW), -5/4 (CW) } + (Q1, A3): 11, + (Q2, A4): 11, + (Q3, A1): 11, + (Q4, A2): 11, + + (Q1, OO, A3): 11, + (Q2, OO, A4): 11, + (Q3, OO, A1): 11, + (Q4, OO, A2): 11, + + # Q --- 2 ---> A => { +5/4 (CCW), -3/4 (CW) } + (Q1, A4): 12, + (Q2, A1): 12, + (Q3, A2): 12, + (Q4, A3): 12, + + (Q1, OO, A4): 12, + (Q2, OO, A1): 12, + (Q3, OO, A2): 12, + (Q4, OO, A3): 12, + + # A --- 1 ---> Q => { +5/4 (CCW), -3/4 (CW) } + (A1, Q3): 13, + (A2, Q4): 13, + (A3, Q1): 13, + (A4, Q2): 13, + + (A1, OO, Q3): 13, + (A2, OO, Q4): 13, + (A3, OO, Q1): 13, + (A4, OO, Q2): 13, + + # A --- 2 ---> Q => { +3/4 (CCW), -5/4 (CW) } + (A1, Q2): 14, + (A2, Q3): 14, + (A3, Q4): 14, + (A4, Q1): 14, + + (A1, OO, Q2): 14, + (A2, OO, Q3): 14, + (A3, OO, Q4): 14, + (A4, OO, Q1): 14, + + # Q --> OO --> Q => { +1/2 (CCW), -3/2 (CW) } + (Q1, OO, Q2): 15, + (Q2, OO, Q3): 15, + (Q3, OO, Q4): 15, + (Q4, OO, Q1): 15, + + # Q --> OO --> Q => { +3/2 (CCW), -1/2 (CW) } + (Q1, OO, Q4): 16, + (Q2, OO, Q1): 16, + (Q3, OO, Q2): 16, + (Q4, OO, Q3): 16, + + # A --> OO --> A => { +2 (CCW), 0 (CW) } + (A1, OO, A1): 17, + (A2, OO, A2): 17, + (A3, OO, A3): 17, + (A4, OO, A4): 17, + + # Q --> OO --> Q => { +2 (CCW), 0 (CW) } + (Q1, OO, Q1): 18, + (Q2, OO, Q2): 18, + (Q3, OO, Q3): 18, + (Q4, OO, Q4): 18, +} + +_values = { + 0: [( 0, 1)], + 1: [(+1, 4)], + 2: [(-1, 4)], + 3: [(+1, 4)], + 4: [(-1, 4)], + -1: [(+9, 4), (+1, 4)], + -2: [(+7, 4), (-1, 4)], + -3: [(+9, 4), (+1, 4)], + -4: [(+7, 4), (-1, 4)], + +5: [(+1, 2)], + -5: [(-1, 2)], + 7: [(+1, 1), (-1, 1)], + 8: [(+1, 1), (-1, 1)], + 9: [(+1, 2), (-3, 2)], + 10: [(+3, 2), (-1, 2)], + 11: [(+3, 4), (-5, 4)], + 12: [(+5, 4), (-3, 4)], + 13: [(+5, 4), (-3, 4)], + 14: [(+3, 4), (-5, 4)], + 15: [(+1, 2), (-3, 2)], + 16: [(+3, 2), (-1, 2)], + 17: [(+2, 1), ( 0, 1)], + 18: [(+2, 1), ( 0, 1)], +} + +def _classify_point(re, im): + """Return the half-axis (or origin) on which (re, im) point is located. """ + if not re and not im: + return OO + + if not re: + if im > 0: + return A2 + else: + return A4 + elif not im: + if re > 0: + return A1 + else: + return A3 + +def _intervals_to_quadrants(intervals, f1, f2, s, t, F): + """Generate a sequence of extended quadrants from a list of critical points. """ + if not intervals: + return [] + + Q = [] + + if not f1: + (a, b), _, _ = intervals[0] + + if a == b == s: + if len(intervals) == 1: + if dup_eval(f2, t, F) > 0: + return [OO, A2] + else: + return [OO, A4] + else: + (a, _), _, _ = intervals[1] + + if dup_eval(f2, (s + a)/2, F) > 0: + Q.extend([OO, A2]) + f2_sgn = +1 + else: + Q.extend([OO, A4]) + f2_sgn = -1 + + intervals = intervals[1:] + else: + if dup_eval(f2, s, F) > 0: + Q.append(A2) + f2_sgn = +1 + else: + Q.append(A4) + f2_sgn = -1 + + for (a, _), indices, _ in intervals: + Q.append(OO) + + if indices[1] % 2 == 1: + f2_sgn = -f2_sgn + + if a != t: + if f2_sgn > 0: + Q.append(A2) + else: + Q.append(A4) + + return Q + + if not f2: + (a, b), _, _ = intervals[0] + + if a == b == s: + if len(intervals) == 1: + if dup_eval(f1, t, F) > 0: + return [OO, A1] + else: + return [OO, A3] + else: + (a, _), _, _ = intervals[1] + + if dup_eval(f1, (s + a)/2, F) > 0: + Q.extend([OO, A1]) + f1_sgn = +1 + else: + Q.extend([OO, A3]) + f1_sgn = -1 + + intervals = intervals[1:] + else: + if dup_eval(f1, s, F) > 0: + Q.append(A1) + f1_sgn = +1 + else: + Q.append(A3) + f1_sgn = -1 + + for (a, _), indices, _ in intervals: + Q.append(OO) + + if indices[0] % 2 == 1: + f1_sgn = -f1_sgn + + if a != t: + if f1_sgn > 0: + Q.append(A1) + else: + Q.append(A3) + + return Q + + re = dup_eval(f1, s, F) + im = dup_eval(f2, s, F) + + if not re or not im: + Q.append(_classify_point(re, im)) + + if len(intervals) == 1: + re = dup_eval(f1, t, F) + im = dup_eval(f2, t, F) + else: + (a, _), _, _ = intervals[1] + + re = dup_eval(f1, (s + a)/2, F) + im = dup_eval(f2, (s + a)/2, F) + + intervals = intervals[1:] + + if re > 0: + f1_sgn = +1 + else: + f1_sgn = -1 + + if im > 0: + f2_sgn = +1 + else: + f2_sgn = -1 + + sgn = { + (+1, +1): Q1, + (-1, +1): Q2, + (-1, -1): Q3, + (+1, -1): Q4, + } + + Q.append(sgn[(f1_sgn, f2_sgn)]) + + for (a, b), indices, _ in intervals: + if a == b: + re = dup_eval(f1, a, F) + im = dup_eval(f2, a, F) + + cls = _classify_point(re, im) + + if cls is not None: + Q.append(cls) + + if 0 in indices: + if indices[0] % 2 == 1: + f1_sgn = -f1_sgn + + if 1 in indices: + if indices[1] % 2 == 1: + f2_sgn = -f2_sgn + + if not (a == b and b == t): + Q.append(sgn[(f1_sgn, f2_sgn)]) + + return Q + +def _traverse_quadrants(Q_L1, Q_L2, Q_L3, Q_L4, exclude=None): + """Transform sequences of quadrants to a sequence of rules. """ + if exclude is True: + edges = [1, 1, 0, 0] + + corners = { + (0, 1): 1, + (1, 2): 1, + (2, 3): 0, + (3, 0): 1, + } + else: + edges = [0, 0, 0, 0] + + corners = { + (0, 1): 0, + (1, 2): 0, + (2, 3): 0, + (3, 0): 0, + } + + if exclude is not None and exclude is not True: + exclude = set(exclude) + + for i, edge in enumerate(['S', 'E', 'N', 'W']): + if edge in exclude: + edges[i] = 1 + + for i, corner in enumerate(['SW', 'SE', 'NE', 'NW']): + if corner in exclude: + corners[((i - 1) % 4, i)] = 1 + + QQ, rules = [Q_L1, Q_L2, Q_L3, Q_L4], [] + + for i, Q in enumerate(QQ): + if not Q: + continue + + if Q[-1] == OO: + Q = Q[:-1] + + if Q[0] == OO: + j, Q = (i - 1) % 4, Q[1:] + qq = (QQ[j][-2], OO, Q[0]) + + if qq in _rules_ambiguous: + rules.append((_rules_ambiguous[qq], corners[(j, i)])) + else: + raise NotImplementedError("3 element rule (corner): " + str(qq)) + + q1, k = Q[0], 1 + + while k < len(Q): + q2, k = Q[k], k + 1 + + if q2 != OO: + qq = (q1, q2) + + if qq in _rules_simple: + rules.append((_rules_simple[qq], 0)) + elif qq in _rules_ambiguous: + rules.append((_rules_ambiguous[qq], edges[i])) + else: + raise NotImplementedError("2 element rule (inside): " + str(qq)) + else: + qq, k = (q1, q2, Q[k]), k + 1 + + if qq in _rules_ambiguous: + rules.append((_rules_ambiguous[qq], edges[i])) + else: + raise NotImplementedError("3 element rule (edge): " + str(qq)) + + q1 = qq[-1] + + return rules + +def _reverse_intervals(intervals): + """Reverse intervals for traversal from right to left and from top to bottom. """ + return [ ((b, a), indices, f) for (a, b), indices, f in reversed(intervals) ] + +def _winding_number(T, field): + """Compute the winding number of the input polynomial, i.e. the number of roots. """ + return int(sum(field(*_values[t][i]) for t, i in T) / field(2)) + +def dup_count_complex_roots(f, K, inf=None, sup=None, exclude=None): + """Count all roots in [u + v*I, s + t*I] rectangle using Collins-Krandick algorithm. """ + if not K.is_ZZ and not K.is_QQ: + raise DomainError("complex root counting is not supported over %s" % K) + + if K.is_ZZ: + R, F = K, K.get_field() + else: + R, F = K.get_ring(), K + + f = dup_convert(f, K, F) + + if inf is None or sup is None: + _, lc = dup_degree(f), abs(dup_LC(f, F)) + B = 2*max(F.quo(abs(c), lc) for c in f) + + if inf is None: + (u, v) = (-B, -B) + else: + (u, v) = inf + + if sup is None: + (s, t) = (+B, +B) + else: + (s, t) = sup + + f1, f2 = dup_real_imag(f, F) + + f1L1F = dmp_eval_in(f1, v, 1, 1, F) + f2L1F = dmp_eval_in(f2, v, 1, 1, F) + + _, f1L1R = dup_clear_denoms(f1L1F, F, R, convert=True) + _, f2L1R = dup_clear_denoms(f2L1F, F, R, convert=True) + + f1L2F = dmp_eval_in(f1, s, 0, 1, F) + f2L2F = dmp_eval_in(f2, s, 0, 1, F) + + _, f1L2R = dup_clear_denoms(f1L2F, F, R, convert=True) + _, f2L2R = dup_clear_denoms(f2L2F, F, R, convert=True) + + f1L3F = dmp_eval_in(f1, t, 1, 1, F) + f2L3F = dmp_eval_in(f2, t, 1, 1, F) + + _, f1L3R = dup_clear_denoms(f1L3F, F, R, convert=True) + _, f2L3R = dup_clear_denoms(f2L3F, F, R, convert=True) + + f1L4F = dmp_eval_in(f1, u, 0, 1, F) + f2L4F = dmp_eval_in(f2, u, 0, 1, F) + + _, f1L4R = dup_clear_denoms(f1L4F, F, R, convert=True) + _, f2L4R = dup_clear_denoms(f2L4F, F, R, convert=True) + + S_L1 = [f1L1R, f2L1R] + S_L2 = [f1L2R, f2L2R] + S_L3 = [f1L3R, f2L3R] + S_L4 = [f1L4R, f2L4R] + + I_L1 = dup_isolate_real_roots_list(S_L1, R, inf=u, sup=s, fast=True, basis=True, strict=True) + I_L2 = dup_isolate_real_roots_list(S_L2, R, inf=v, sup=t, fast=True, basis=True, strict=True) + I_L3 = dup_isolate_real_roots_list(S_L3, R, inf=u, sup=s, fast=True, basis=True, strict=True) + I_L4 = dup_isolate_real_roots_list(S_L4, R, inf=v, sup=t, fast=True, basis=True, strict=True) + + I_L3 = _reverse_intervals(I_L3) + I_L4 = _reverse_intervals(I_L4) + + Q_L1 = _intervals_to_quadrants(I_L1, f1L1F, f2L1F, u, s, F) + Q_L2 = _intervals_to_quadrants(I_L2, f1L2F, f2L2F, v, t, F) + Q_L3 = _intervals_to_quadrants(I_L3, f1L3F, f2L3F, s, u, F) + Q_L4 = _intervals_to_quadrants(I_L4, f1L4F, f2L4F, t, v, F) + + T = _traverse_quadrants(Q_L1, Q_L2, Q_L3, Q_L4, exclude=exclude) + + return _winding_number(T, F) + +def _vertical_bisection(N, a, b, I, Q, F1, F2, f1, f2, F): + """Vertical bisection step in Collins-Krandick root isolation algorithm. """ + (u, v), (s, t) = a, b + + I_L1, I_L2, I_L3, I_L4 = I + Q_L1, Q_L2, Q_L3, Q_L4 = Q + + f1L1F, f1L2F, f1L3F, f1L4F = F1 + f2L1F, f2L2F, f2L3F, f2L4F = F2 + + x = (u + s) / 2 + + f1V = dmp_eval_in(f1, x, 0, 1, F) + f2V = dmp_eval_in(f2, x, 0, 1, F) + + I_V = dup_isolate_real_roots_list([f1V, f2V], F, inf=v, sup=t, fast=True, strict=True, basis=True) + + I_L1_L, I_L1_R = [], [] + I_L2_L, I_L2_R = I_V, I_L2 + I_L3_L, I_L3_R = [], [] + I_L4_L, I_L4_R = I_L4, _reverse_intervals(I_V) + + for I in I_L1: + (a, b), indices, h = I + + if a == b: + if a == x: + I_L1_L.append(I) + I_L1_R.append(I) + elif a < x: + I_L1_L.append(I) + else: + I_L1_R.append(I) + else: + if b <= x: + I_L1_L.append(I) + elif a >= x: + I_L1_R.append(I) + else: + a, b = dup_refine_real_root(h, a, b, F.get_ring(), disjoint=x, fast=True) + + if b <= x: + I_L1_L.append(((a, b), indices, h)) + if a >= x: + I_L1_R.append(((a, b), indices, h)) + + for I in I_L3: + (b, a), indices, h = I + + if a == b: + if a == x: + I_L3_L.append(I) + I_L3_R.append(I) + elif a < x: + I_L3_L.append(I) + else: + I_L3_R.append(I) + else: + if b <= x: + I_L3_L.append(I) + elif a >= x: + I_L3_R.append(I) + else: + a, b = dup_refine_real_root(h, a, b, F.get_ring(), disjoint=x, fast=True) + + if b <= x: + I_L3_L.append(((b, a), indices, h)) + if a >= x: + I_L3_R.append(((b, a), indices, h)) + + Q_L1_L = _intervals_to_quadrants(I_L1_L, f1L1F, f2L1F, u, x, F) + Q_L2_L = _intervals_to_quadrants(I_L2_L, f1V, f2V, v, t, F) + Q_L3_L = _intervals_to_quadrants(I_L3_L, f1L3F, f2L3F, x, u, F) + Q_L4_L = Q_L4 + + Q_L1_R = _intervals_to_quadrants(I_L1_R, f1L1F, f2L1F, x, s, F) + Q_L2_R = Q_L2 + Q_L3_R = _intervals_to_quadrants(I_L3_R, f1L3F, f2L3F, s, x, F) + Q_L4_R = _intervals_to_quadrants(I_L4_R, f1V, f2V, t, v, F) + + T_L = _traverse_quadrants(Q_L1_L, Q_L2_L, Q_L3_L, Q_L4_L, exclude=True) + T_R = _traverse_quadrants(Q_L1_R, Q_L2_R, Q_L3_R, Q_L4_R, exclude=True) + + N_L = _winding_number(T_L, F) + N_R = _winding_number(T_R, F) + + I_L = (I_L1_L, I_L2_L, I_L3_L, I_L4_L) + Q_L = (Q_L1_L, Q_L2_L, Q_L3_L, Q_L4_L) + + I_R = (I_L1_R, I_L2_R, I_L3_R, I_L4_R) + Q_R = (Q_L1_R, Q_L2_R, Q_L3_R, Q_L4_R) + + F1_L = (f1L1F, f1V, f1L3F, f1L4F) + F2_L = (f2L1F, f2V, f2L3F, f2L4F) + + F1_R = (f1L1F, f1L2F, f1L3F, f1V) + F2_R = (f2L1F, f2L2F, f2L3F, f2V) + + a, b = (u, v), (x, t) + c, d = (x, v), (s, t) + + D_L = (N_L, a, b, I_L, Q_L, F1_L, F2_L) + D_R = (N_R, c, d, I_R, Q_R, F1_R, F2_R) + + return D_L, D_R + +def _horizontal_bisection(N, a, b, I, Q, F1, F2, f1, f2, F): + """Horizontal bisection step in Collins-Krandick root isolation algorithm. """ + (u, v), (s, t) = a, b + + I_L1, I_L2, I_L3, I_L4 = I + Q_L1, Q_L2, Q_L3, Q_L4 = Q + + f1L1F, f1L2F, f1L3F, f1L4F = F1 + f2L1F, f2L2F, f2L3F, f2L4F = F2 + + y = (v + t) / 2 + + f1H = dmp_eval_in(f1, y, 1, 1, F) + f2H = dmp_eval_in(f2, y, 1, 1, F) + + I_H = dup_isolate_real_roots_list([f1H, f2H], F, inf=u, sup=s, fast=True, strict=True, basis=True) + + I_L1_B, I_L1_U = I_L1, I_H + I_L2_B, I_L2_U = [], [] + I_L3_B, I_L3_U = _reverse_intervals(I_H), I_L3 + I_L4_B, I_L4_U = [], [] + + for I in I_L2: + (a, b), indices, h = I + + if a == b: + if a == y: + I_L2_B.append(I) + I_L2_U.append(I) + elif a < y: + I_L2_B.append(I) + else: + I_L2_U.append(I) + else: + if b <= y: + I_L2_B.append(I) + elif a >= y: + I_L2_U.append(I) + else: + a, b = dup_refine_real_root(h, a, b, F.get_ring(), disjoint=y, fast=True) + + if b <= y: + I_L2_B.append(((a, b), indices, h)) + if a >= y: + I_L2_U.append(((a, b), indices, h)) + + for I in I_L4: + (b, a), indices, h = I + + if a == b: + if a == y: + I_L4_B.append(I) + I_L4_U.append(I) + elif a < y: + I_L4_B.append(I) + else: + I_L4_U.append(I) + else: + if b <= y: + I_L4_B.append(I) + elif a >= y: + I_L4_U.append(I) + else: + a, b = dup_refine_real_root(h, a, b, F.get_ring(), disjoint=y, fast=True) + + if b <= y: + I_L4_B.append(((b, a), indices, h)) + if a >= y: + I_L4_U.append(((b, a), indices, h)) + + Q_L1_B = Q_L1 + Q_L2_B = _intervals_to_quadrants(I_L2_B, f1L2F, f2L2F, v, y, F) + Q_L3_B = _intervals_to_quadrants(I_L3_B, f1H, f2H, s, u, F) + Q_L4_B = _intervals_to_quadrants(I_L4_B, f1L4F, f2L4F, y, v, F) + + Q_L1_U = _intervals_to_quadrants(I_L1_U, f1H, f2H, u, s, F) + Q_L2_U = _intervals_to_quadrants(I_L2_U, f1L2F, f2L2F, y, t, F) + Q_L3_U = Q_L3 + Q_L4_U = _intervals_to_quadrants(I_L4_U, f1L4F, f2L4F, t, y, F) + + T_B = _traverse_quadrants(Q_L1_B, Q_L2_B, Q_L3_B, Q_L4_B, exclude=True) + T_U = _traverse_quadrants(Q_L1_U, Q_L2_U, Q_L3_U, Q_L4_U, exclude=True) + + N_B = _winding_number(T_B, F) + N_U = _winding_number(T_U, F) + + I_B = (I_L1_B, I_L2_B, I_L3_B, I_L4_B) + Q_B = (Q_L1_B, Q_L2_B, Q_L3_B, Q_L4_B) + + I_U = (I_L1_U, I_L2_U, I_L3_U, I_L4_U) + Q_U = (Q_L1_U, Q_L2_U, Q_L3_U, Q_L4_U) + + F1_B = (f1L1F, f1L2F, f1H, f1L4F) + F2_B = (f2L1F, f2L2F, f2H, f2L4F) + + F1_U = (f1H, f1L2F, f1L3F, f1L4F) + F2_U = (f2H, f2L2F, f2L3F, f2L4F) + + a, b = (u, v), (s, y) + c, d = (u, y), (s, t) + + D_B = (N_B, a, b, I_B, Q_B, F1_B, F2_B) + D_U = (N_U, c, d, I_U, Q_U, F1_U, F2_U) + + return D_B, D_U + +def _depth_first_select(rectangles): + """Find a rectangle of minimum area for bisection. """ + min_area, j = None, None + + for i, (_, (u, v), (s, t), _, _, _, _) in enumerate(rectangles): + area = (s - u)*(t - v) + + if min_area is None or area < min_area: + min_area, j = area, i + + return rectangles.pop(j) + +def _rectangle_small_p(a, b, eps): + """Return ``True`` if the given rectangle is small enough. """ + (u, v), (s, t) = a, b + + if eps is not None: + return s - u < eps and t - v < eps + else: + return True + +def dup_isolate_complex_roots_sqf(f, K, eps=None, inf=None, sup=None, blackbox=False): + """Isolate complex roots of a square-free polynomial using Collins-Krandick algorithm. """ + if not K.is_ZZ and not K.is_QQ: + raise DomainError("isolation of complex roots is not supported over %s" % K) + + if dup_degree(f) <= 0: + return [] + + if K.is_ZZ: + F = K.get_field() + else: + F = K + + f = dup_convert(f, K, F) + + lc = abs(dup_LC(f, F)) + B = 2*max(F.quo(abs(c), lc) for c in f) + + (u, v), (s, t) = (-B, F.zero), (B, B) + + if inf is not None: + u = inf + + if sup is not None: + s = sup + + if v < 0 or t <= v or s <= u: + raise ValueError("not a valid complex isolation rectangle") + + f1, f2 = dup_real_imag(f, F) + + f1L1 = dmp_eval_in(f1, v, 1, 1, F) + f2L1 = dmp_eval_in(f2, v, 1, 1, F) + + f1L2 = dmp_eval_in(f1, s, 0, 1, F) + f2L2 = dmp_eval_in(f2, s, 0, 1, F) + + f1L3 = dmp_eval_in(f1, t, 1, 1, F) + f2L3 = dmp_eval_in(f2, t, 1, 1, F) + + f1L4 = dmp_eval_in(f1, u, 0, 1, F) + f2L4 = dmp_eval_in(f2, u, 0, 1, F) + + S_L1 = [f1L1, f2L1] + S_L2 = [f1L2, f2L2] + S_L3 = [f1L3, f2L3] + S_L4 = [f1L4, f2L4] + + I_L1 = dup_isolate_real_roots_list(S_L1, F, inf=u, sup=s, fast=True, strict=True, basis=True) + I_L2 = dup_isolate_real_roots_list(S_L2, F, inf=v, sup=t, fast=True, strict=True, basis=True) + I_L3 = dup_isolate_real_roots_list(S_L3, F, inf=u, sup=s, fast=True, strict=True, basis=True) + I_L4 = dup_isolate_real_roots_list(S_L4, F, inf=v, sup=t, fast=True, strict=True, basis=True) + + I_L3 = _reverse_intervals(I_L3) + I_L4 = _reverse_intervals(I_L4) + + Q_L1 = _intervals_to_quadrants(I_L1, f1L1, f2L1, u, s, F) + Q_L2 = _intervals_to_quadrants(I_L2, f1L2, f2L2, v, t, F) + Q_L3 = _intervals_to_quadrants(I_L3, f1L3, f2L3, s, u, F) + Q_L4 = _intervals_to_quadrants(I_L4, f1L4, f2L4, t, v, F) + + T = _traverse_quadrants(Q_L1, Q_L2, Q_L3, Q_L4) + N = _winding_number(T, F) + + if not N: + return [] + + I = (I_L1, I_L2, I_L3, I_L4) + Q = (Q_L1, Q_L2, Q_L3, Q_L4) + + F1 = (f1L1, f1L2, f1L3, f1L4) + F2 = (f2L1, f2L2, f2L3, f2L4) + + rectangles, roots = [(N, (u, v), (s, t), I, Q, F1, F2)], [] + + while rectangles: + N, (u, v), (s, t), I, Q, F1, F2 = _depth_first_select(rectangles) + + if s - u > t - v: + D_L, D_R = _vertical_bisection(N, (u, v), (s, t), I, Q, F1, F2, f1, f2, F) + + N_L, a, b, I_L, Q_L, F1_L, F2_L = D_L + N_R, c, d, I_R, Q_R, F1_R, F2_R = D_R + + if N_L >= 1: + if N_L == 1 and _rectangle_small_p(a, b, eps): + roots.append(ComplexInterval(a, b, I_L, Q_L, F1_L, F2_L, f1, f2, F)) + else: + rectangles.append(D_L) + + if N_R >= 1: + if N_R == 1 and _rectangle_small_p(c, d, eps): + roots.append(ComplexInterval(c, d, I_R, Q_R, F1_R, F2_R, f1, f2, F)) + else: + rectangles.append(D_R) + else: + D_B, D_U = _horizontal_bisection(N, (u, v), (s, t), I, Q, F1, F2, f1, f2, F) + + N_B, a, b, I_B, Q_B, F1_B, F2_B = D_B + N_U, c, d, I_U, Q_U, F1_U, F2_U = D_U + + if N_B >= 1: + if N_B == 1 and _rectangle_small_p(a, b, eps): + roots.append(ComplexInterval( + a, b, I_B, Q_B, F1_B, F2_B, f1, f2, F)) + else: + rectangles.append(D_B) + + if N_U >= 1: + if N_U == 1 and _rectangle_small_p(c, d, eps): + roots.append(ComplexInterval( + c, d, I_U, Q_U, F1_U, F2_U, f1, f2, F)) + else: + rectangles.append(D_U) + + _roots, roots = sorted(roots, key=lambda r: (r.ax, r.ay)), [] + + for root in _roots: + roots.extend([root.conjugate(), root]) + + if blackbox: + return roots + else: + return [ r.as_tuple() for r in roots ] + +def dup_isolate_all_roots_sqf(f, K, eps=None, inf=None, sup=None, fast=False, blackbox=False): + """Isolate real and complex roots of a square-free polynomial ``f``. """ + return ( + dup_isolate_real_roots_sqf( f, K, eps=eps, inf=inf, sup=sup, fast=fast, blackbox=blackbox), + dup_isolate_complex_roots_sqf(f, K, eps=eps, inf=inf, sup=sup, blackbox=blackbox)) + +def dup_isolate_all_roots(f, K, eps=None, inf=None, sup=None, fast=False): + """Isolate real and complex roots of a non-square-free polynomial ``f``. """ + if not K.is_ZZ and not K.is_QQ: + raise DomainError("isolation of real and complex roots is not supported over %s" % K) + + _, factors = dup_sqf_list(f, K) + + if len(factors) == 1: + ((f, k),) = factors + + real_part, complex_part = dup_isolate_all_roots_sqf( + f, K, eps=eps, inf=inf, sup=sup, fast=fast) + + real_part = [ ((a, b), k) for (a, b) in real_part ] + complex_part = [ ((a, b), k) for (a, b) in complex_part ] + + return real_part, complex_part + else: + raise NotImplementedError( "only trivial square-free polynomials are supported") + +class RealInterval: + """A fully qualified representation of a real isolation interval. """ + + def __init__(self, data, f, dom): + """Initialize new real interval with complete information. """ + if len(data) == 2: + s, t = data + + self.neg = False + + if s < 0: + if t <= 0: + f, s, t, self.neg = dup_mirror(f, dom), -t, -s, True + else: + raise ValueError("Cannot refine a real root in (%s, %s)" % (s, t)) + + a, b, c, d = _mobius_from_interval((s, t), dom.get_field()) + + f = dup_transform(f, dup_strip([a, b]), + dup_strip([c, d]), dom) + + self.mobius = a, b, c, d + else: + self.mobius = data[:-1] + self.neg = data[-1] + + self.f, self.dom = f, dom + + @property + def func(self): + return RealInterval + + @property + def args(self): + i = self + return (i.mobius + (i.neg,), i.f, i.dom) + + def __eq__(self, other): + if type(other) is not type(self): + return False + return self.args == other.args + + @property + def a(self): + """Return the position of the left end. """ + field = self.dom.get_field() + a, b, c, d = self.mobius + + if not self.neg: + if a*d < b*c: + return field(a, c) + return field(b, d) + else: + if a*d > b*c: + return -field(a, c) + return -field(b, d) + + @property + def b(self): + """Return the position of the right end. """ + was = self.neg + self.neg = not was + rv = -self.a + self.neg = was + return rv + + @property + def dx(self): + """Return width of the real isolating interval. """ + return self.b - self.a + + @property + def center(self): + """Return the center of the real isolating interval. """ + return (self.a + self.b)/2 + + @property + def max_denom(self): + """Return the largest denominator occurring in either endpoint. """ + return max(self.a.denominator, self.b.denominator) + + def as_tuple(self): + """Return tuple representation of real isolating interval. """ + return (self.a, self.b) + + def __repr__(self): + return "(%s, %s)" % (self.a, self.b) + + def __contains__(self, item): + """ + Say whether a complex number belongs to this real interval. + + Parameters + ========== + + item : pair (re, im) or number re + Either a pair giving the real and imaginary parts of the number, + or else a real number. + + """ + if isinstance(item, tuple): + re, im = item + else: + re, im = item, 0 + return im == 0 and self.a <= re <= self.b + + def is_disjoint(self, other): + """Return ``True`` if two isolation intervals are disjoint. """ + if isinstance(other, RealInterval): + return (self.b < other.a or other.b < self.a) + assert isinstance(other, ComplexInterval) + return (self.b < other.ax or other.bx < self.a + or other.ay*other.by > 0) + + def _inner_refine(self): + """Internal one step real root refinement procedure. """ + if self.mobius is None: + return self + + f, mobius = dup_inner_refine_real_root( + self.f, self.mobius, self.dom, steps=1, mobius=True) + + return RealInterval(mobius + (self.neg,), f, self.dom) + + def refine_disjoint(self, other): + """Refine an isolating interval until it is disjoint with another one. """ + expr = self + while not expr.is_disjoint(other): + expr, other = expr._inner_refine(), other._inner_refine() + + return expr, other + + def refine_size(self, dx): + """Refine an isolating interval until it is of sufficiently small size. """ + expr = self + while not (expr.dx < dx): + expr = expr._inner_refine() + + return expr + + def refine_step(self, steps=1): + """Perform several steps of real root refinement algorithm. """ + expr = self + for _ in range(steps): + expr = expr._inner_refine() + + return expr + + def refine(self): + """Perform one step of real root refinement algorithm. """ + return self._inner_refine() + + +class ComplexInterval: + """A fully qualified representation of a complex isolation interval. + The printed form is shown as (ax, bx) x (ay, by) where (ax, ay) + and (bx, by) are the coordinates of the southwest and northeast + corners of the interval's rectangle, respectively. + + Examples + ======== + + >>> from sympy import CRootOf, S + >>> from sympy.abc import x + >>> CRootOf.clear_cache() # for doctest reproducibility + >>> root = CRootOf(x**10 - 2*x + 3, 9) + >>> i = root._get_interval(); i + (3/64, 3/32) x (9/8, 75/64) + + The real part of the root lies within the range [0, 3/4] while + the imaginary part lies within the range [9/8, 3/2]: + + >>> root.n(3) + 0.0766 + 1.14*I + + The width of the ranges in the x and y directions on the complex + plane are: + + >>> i.dx, i.dy + (3/64, 3/64) + + The center of the range is + + >>> i.center + (9/128, 147/128) + + The northeast coordinate of the rectangle bounding the root in the + complex plane is given by attribute b and the x and y components + are accessed by bx and by: + + >>> i.b, i.bx, i.by + ((3/32, 75/64), 3/32, 75/64) + + The southwest coordinate is similarly given by i.a + + >>> i.a, i.ax, i.ay + ((3/64, 9/8), 3/64, 9/8) + + Although the interval prints to show only the real and imaginary + range of the root, all the information of the underlying root + is contained as properties of the interval. + + For example, an interval with a nonpositive imaginary range is + considered to be the conjugate. Since the y values of y are in the + range [0, 1/4] it is not the conjugate: + + >>> i.conj + False + + The conjugate's interval is + + >>> ic = i.conjugate(); ic + (3/64, 3/32) x (-75/64, -9/8) + + NOTE: the values printed still represent the x and y range + in which the root -- conjugate, in this case -- is located, + but the underlying a and b values of a root and its conjugate + are the same: + + >>> assert i.a == ic.a and i.b == ic.b + + What changes are the reported coordinates of the bounding rectangle: + + >>> (i.ax, i.ay), (i.bx, i.by) + ((3/64, 9/8), (3/32, 75/64)) + >>> (ic.ax, ic.ay), (ic.bx, ic.by) + ((3/64, -75/64), (3/32, -9/8)) + + The interval can be refined once: + + >>> i # for reference, this is the current interval + (3/64, 3/32) x (9/8, 75/64) + + >>> i.refine() + (3/64, 3/32) x (9/8, 147/128) + + Several refinement steps can be taken: + + >>> i.refine_step(2) # 2 steps + (9/128, 3/32) x (9/8, 147/128) + + It is also possible to refine to a given tolerance: + + >>> tol = min(i.dx, i.dy)/2 + >>> i.refine_size(tol) + (9/128, 21/256) x (9/8, 291/256) + + A disjoint interval is one whose bounding rectangle does not + overlap with another. An interval, necessarily, is not disjoint with + itself, but any interval is disjoint with a conjugate since the + conjugate rectangle will always be in the lower half of the complex + plane and the non-conjugate in the upper half: + + >>> i.is_disjoint(i), i.is_disjoint(i.conjugate()) + (False, True) + + The following interval j is not disjoint from i: + + >>> close = CRootOf(x**10 - 2*x + 300/S(101), 9) + >>> j = close._get_interval(); j + (75/1616, 75/808) x (225/202, 1875/1616) + >>> i.is_disjoint(j) + False + + The two can be made disjoint, however: + + >>> newi, newj = i.refine_disjoint(j) + >>> newi + (39/512, 159/2048) x (2325/2048, 4653/4096) + >>> newj + (3975/51712, 2025/25856) x (29325/25856, 117375/103424) + + Even though the real ranges overlap, the imaginary do not, so + the roots have been resolved as distinct. Intervals are disjoint + when either the real or imaginary component of the intervals is + distinct. In the case above, the real components have not been + resolved (so we do not know, yet, which root has the smaller real + part) but the imaginary part of ``close`` is larger than ``root``: + + >>> close.n(3) + 0.0771 + 1.13*I + >>> root.n(3) + 0.0766 + 1.14*I + """ + + def __init__(self, a, b, I, Q, F1, F2, f1, f2, dom, conj=False): + """Initialize new complex interval with complete information. """ + # a and b are the SW and NE corner of the bounding interval, + # (ax, ay) and (bx, by), respectively, for the NON-CONJUGATE + # root (the one with the positive imaginary part); when working + # with the conjugate, the a and b value are still non-negative + # but the ay, by are reversed and have oppositite sign + self.a, self.b = a, b + self.I, self.Q = I, Q + + self.f1, self.F1 = f1, F1 + self.f2, self.F2 = f2, F2 + + self.dom = dom + self.conj = conj + + @property + def func(self): + return ComplexInterval + + @property + def args(self): + i = self + return (i.a, i.b, i.I, i.Q, i.F1, i.F2, i.f1, i.f2, i.dom, i.conj) + + def __eq__(self, other): + if type(other) is not type(self): + return False + return self.args == other.args + + @property + def ax(self): + """Return ``x`` coordinate of south-western corner. """ + return self.a[0] + + @property + def ay(self): + """Return ``y`` coordinate of south-western corner. """ + if not self.conj: + return self.a[1] + else: + return -self.b[1] + + @property + def bx(self): + """Return ``x`` coordinate of north-eastern corner. """ + return self.b[0] + + @property + def by(self): + """Return ``y`` coordinate of north-eastern corner. """ + if not self.conj: + return self.b[1] + else: + return -self.a[1] + + @property + def dx(self): + """Return width of the complex isolating interval. """ + return self.b[0] - self.a[0] + + @property + def dy(self): + """Return height of the complex isolating interval. """ + return self.b[1] - self.a[1] + + @property + def center(self): + """Return the center of the complex isolating interval. """ + return ((self.ax + self.bx)/2, (self.ay + self.by)/2) + + @property + def max_denom(self): + """Return the largest denominator occurring in either endpoint. """ + return max(self.ax.denominator, self.bx.denominator, + self.ay.denominator, self.by.denominator) + + def as_tuple(self): + """Return tuple representation of the complex isolating + interval's SW and NE corners, respectively. """ + return ((self.ax, self.ay), (self.bx, self.by)) + + def __repr__(self): + return "(%s, %s) x (%s, %s)" % (self.ax, self.bx, self.ay, self.by) + + def conjugate(self): + """This complex interval really is located in lower half-plane. """ + return ComplexInterval(self.a, self.b, self.I, self.Q, + self.F1, self.F2, self.f1, self.f2, self.dom, conj=True) + + def __contains__(self, item): + """ + Say whether a complex number belongs to this complex rectangular + region. + + Parameters + ========== + + item : pair (re, im) or number re + Either a pair giving the real and imaginary parts of the number, + or else a real number. + + """ + if isinstance(item, tuple): + re, im = item + else: + re, im = item, 0 + return self.ax <= re <= self.bx and self.ay <= im <= self.by + + def is_disjoint(self, other): + """Return ``True`` if two isolation intervals are disjoint. """ + if isinstance(other, RealInterval): + return other.is_disjoint(self) + if self.conj != other.conj: # above and below real axis + return True + re_distinct = (self.bx < other.ax or other.bx < self.ax) + if re_distinct: + return True + im_distinct = (self.by < other.ay or other.by < self.ay) + return im_distinct + + def _inner_refine(self): + """Internal one step complex root refinement procedure. """ + (u, v), (s, t) = self.a, self.b + + I, Q = self.I, self.Q + + f1, F1 = self.f1, self.F1 + f2, F2 = self.f2, self.F2 + + dom = self.dom + + if s - u > t - v: + D_L, D_R = _vertical_bisection(1, (u, v), (s, t), I, Q, F1, F2, f1, f2, dom) + + if D_L[0] == 1: + _, a, b, I, Q, F1, F2 = D_L + else: + _, a, b, I, Q, F1, F2 = D_R + else: + D_B, D_U = _horizontal_bisection(1, (u, v), (s, t), I, Q, F1, F2, f1, f2, dom) + + if D_B[0] == 1: + _, a, b, I, Q, F1, F2 = D_B + else: + _, a, b, I, Q, F1, F2 = D_U + + return ComplexInterval(a, b, I, Q, F1, F2, f1, f2, dom, self.conj) + + def refine_disjoint(self, other): + """Refine an isolating interval until it is disjoint with another one. """ + expr = self + while not expr.is_disjoint(other): + expr, other = expr._inner_refine(), other._inner_refine() + + return expr, other + + def refine_size(self, dx, dy=None): + """Refine an isolating interval until it is of sufficiently small size. """ + if dy is None: + dy = dx + expr = self + while not (expr.dx < dx and expr.dy < dy): + expr = expr._inner_refine() + + return expr + + def refine_step(self, steps=1): + """Perform several steps of complex root refinement algorithm. """ + expr = self + for _ in range(steps): + expr = expr._inner_refine() + + return expr + + def refine(self): + """Perform one step of complex root refinement algorithm. """ + return self._inner_refine() diff --git a/phivenv/Lib/site-packages/sympy/polys/rootoftools.py b/phivenv/Lib/site-packages/sympy/polys/rootoftools.py new file mode 100644 index 0000000000000000000000000000000000000000..d68d8b008281c7e9b5aac618c6c76f74fa236d9e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/rootoftools.py @@ -0,0 +1,1298 @@ +"""Implementation of RootOf class and related tools. """ + + + +from sympy.core.basic import Basic +from sympy.core import (S, Expr, Integer, Float, I, oo, Add, Lambda, + symbols, sympify, Rational, Dummy) +from sympy.core.cache import cacheit +from sympy.core.relational import is_le +from sympy.core.sorting import ordered +from sympy.polys.domains import QQ +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + GeneratorsNeeded, + PolynomialError, + DomainError) +from sympy.polys.polyfuncs import symmetrize, viete +from sympy.polys.polyroots import ( + roots_linear, roots_quadratic, roots_binomial, + preprocess_roots, roots) +from sympy.polys.polytools import Poly, PurePoly, factor +from sympy.polys.rationaltools import together +from sympy.polys.rootisolation import ( + dup_isolate_complex_roots_sqf, + dup_isolate_real_roots_sqf) +from sympy.utilities import lambdify, public, sift, numbered_symbols + +from mpmath import mpf, mpc, findroot, workprec +from mpmath.libmp.libmpf import dps_to_prec, prec_to_dps +from sympy.multipledispatch import dispatch +from itertools import chain + + +__all__ = ['CRootOf'] + + + +class _pure_key_dict: + """A minimal dictionary that makes sure that the key is a + univariate PurePoly instance. + + Examples + ======== + + Only the following actions are guaranteed: + + >>> from sympy.polys.rootoftools import _pure_key_dict + >>> from sympy import PurePoly + >>> from sympy.abc import x, y + + 1) creation + + >>> P = _pure_key_dict() + + 2) assignment for a PurePoly or univariate polynomial + + >>> P[x] = 1 + >>> P[PurePoly(x - y, x)] = 2 + + 3) retrieval based on PurePoly key comparison (use this + instead of the get method) + + >>> P[y] + 1 + + 4) KeyError when trying to retrieve a nonexisting key + + >>> P[y + 1] + Traceback (most recent call last): + ... + KeyError: PurePoly(y + 1, y, domain='ZZ') + + 5) ability to query with ``in`` + + >>> x + 1 in P + False + + NOTE: this is a *not* a dictionary. It is a very basic object + for internal use that makes sure to always address its cache + via PurePoly instances. It does not, for example, implement + ``get`` or ``setdefault``. + """ + def __init__(self): + self._dict = {} + + def __getitem__(self, k): + if not isinstance(k, PurePoly): + if not (isinstance(k, Expr) and len(k.free_symbols) == 1): + raise KeyError + k = PurePoly(k, expand=False) + return self._dict[k] + + def __setitem__(self, k, v): + if not isinstance(k, PurePoly): + if not (isinstance(k, Expr) and len(k.free_symbols) == 1): + raise ValueError('expecting univariate expression') + k = PurePoly(k, expand=False) + self._dict[k] = v + + def __contains__(self, k): + try: + self[k] + return True + except KeyError: + return False + +_reals_cache = _pure_key_dict() +_complexes_cache = _pure_key_dict() + + +def _pure_factors(poly): + _, factors = poly.factor_list() + return [(PurePoly(f, expand=False), m) for f, m in factors] + + +def _imag_count_of_factor(f): + """Return the number of imaginary roots for irreducible + univariate polynomial ``f``. + """ + terms = [(i, j) for (i,), j in f.terms()] + if any(i % 2 for i, j in terms): + return 0 + # update signs + even = [(i, I**i*j) for i, j in terms] + even = Poly.from_dict(dict(even), Dummy('x')) + return int(even.count_roots(-oo, oo)) + + +@public +def rootof(f, x, index=None, radicals=True, expand=True): + """An indexed root of a univariate polynomial. + + Returns either a :obj:`ComplexRootOf` object or an explicit + expression involving radicals. + + Parameters + ========== + + f : Expr + Univariate polynomial. + x : Symbol, optional + Generator for ``f``. + index : int or Integer + radicals : bool + Return a radical expression if possible. + expand : bool + Expand ``f``. + """ + return CRootOf(f, x, index=index, radicals=radicals, expand=expand) + + +@public +class RootOf(Expr): + """Represents a root of a univariate polynomial. + + Base class for roots of different kinds of polynomials. + Only complex roots are currently supported. + """ + + __slots__ = ('poly',) + + def __new__(cls, f, x, index=None, radicals=True, expand=True): + """Construct a new ``CRootOf`` object for ``k``-th root of ``f``.""" + return rootof(f, x, index=index, radicals=radicals, expand=expand) + +@public +class ComplexRootOf(RootOf): + """Represents an indexed complex root of a polynomial. + + Roots of a univariate polynomial separated into disjoint + real or complex intervals and indexed in a fixed order: + + * real roots come first and are sorted in increasing order; + * complex roots come next and are sorted primarily by increasing + real part, secondarily by increasing imaginary part. + + Currently only rational coefficients are allowed. + Can be imported as ``CRootOf``. To avoid confusion, the + generator must be a Symbol. + + + Examples + ======== + + >>> from sympy import CRootOf, rootof + >>> from sympy.abc import x + + CRootOf is a way to reference a particular root of a + polynomial. If there is a rational root, it will be returned: + + >>> CRootOf.clear_cache() # for doctest reproducibility + >>> CRootOf(x**2 - 4, 0) + -2 + + Whether roots involving radicals are returned or not + depends on whether the ``radicals`` flag is true (which is + set to True with rootof): + + >>> CRootOf(x**2 - 3, 0) + CRootOf(x**2 - 3, 0) + >>> CRootOf(x**2 - 3, 0, radicals=True) + -sqrt(3) + >>> rootof(x**2 - 3, 0) + -sqrt(3) + + The following cannot be expressed in terms of radicals: + + >>> r = rootof(4*x**5 + 16*x**3 + 12*x**2 + 7, 0); r + CRootOf(4*x**5 + 16*x**3 + 12*x**2 + 7, 0) + + The root bounds can be seen, however, and they are used by the + evaluation methods to get numerical approximations for the root. + + >>> interval = r._get_interval(); interval + (-1, 0) + >>> r.evalf(2) + -0.98 + + The evalf method refines the width of the root bounds until it + guarantees that any decimal approximation within those bounds + will satisfy the desired precision. It then stores the refined + interval so subsequent requests at or below the requested + precision will not have to recompute the root bounds and will + return very quickly. + + Before evaluation above, the interval was + + >>> interval + (-1, 0) + + After evaluation it is now + + >>> r._get_interval() # doctest: +SKIP + (-165/169, -206/211) + + To reset all intervals for a given polynomial, the :meth:`_reset` method + can be called from any CRootOf instance of the polynomial: + + >>> r._reset() + >>> r._get_interval() + (-1, 0) + + The :meth:`eval_approx` method will also find the root to a given + precision but the interval is not modified unless the search + for the root fails to converge within the root bounds. And + the secant method is used to find the root. (The ``evalf`` + method uses bisection and will always update the interval.) + + >>> r.eval_approx(2) + -0.98 + + The interval needed to be slightly updated to find that root: + + >>> r._get_interval() + (-1, -1/2) + + The ``evalf_rational`` will compute a rational approximation + of the root to the desired accuracy or precision. + + >>> r.eval_rational(n=2) + -69629/71318 + + >>> t = CRootOf(x**3 + 10*x + 1, 1) + >>> t.eval_rational(1e-1) + 15/256 - 805*I/256 + >>> t.eval_rational(1e-1, 1e-4) + 3275/65536 - 414645*I/131072 + >>> t.eval_rational(1e-4, 1e-4) + 6545/131072 - 414645*I/131072 + >>> t.eval_rational(n=2) + 104755/2097152 - 6634255*I/2097152 + + Notes + ===== + + Although a PurePoly can be constructed from a non-symbol generator + RootOf instances of non-symbols are disallowed to avoid confusion + over what root is being represented. + + >>> from sympy import exp, PurePoly + >>> PurePoly(x) == PurePoly(exp(x)) + True + >>> CRootOf(x - 1, 0) + 1 + >>> CRootOf(exp(x) - 1, 0) # would correspond to x == 0 + Traceback (most recent call last): + ... + sympy.polys.polyerrors.PolynomialError: generator must be a Symbol + + See Also + ======== + + eval_approx + eval_rational + + """ + + __slots__ = ('index',) + is_complex = True + is_number = True + is_finite = True + is_algebraic = True + + def __new__(cls, f, x, index=None, radicals=False, expand=True): + """ Construct an indexed complex root of a polynomial. + + See ``rootof`` for the parameters. + + The default value of ``radicals`` is ``False`` to satisfy + ``eval(srepr(expr) == expr``. + """ + x = sympify(x) + + if index is None and x.is_Integer: + x, index = None, x + else: + index = sympify(index) + + if index is not None and index.is_Integer: + index = int(index) + else: + raise ValueError("expected an integer root index, got %s" % index) + + poly = PurePoly(f, x, greedy=False, expand=expand) + + if not poly.is_univariate: + raise PolynomialError("only univariate polynomials are allowed") + + if not poly.gen.is_Symbol: + # PurePoly(sin(x) + 1) == PurePoly(x + 1) but the roots of + # x for each are not the same: issue 8617 + raise PolynomialError("generator must be a Symbol") + + degree = poly.degree() + + if degree <= 0: + raise PolynomialError("Cannot construct CRootOf object for %s" % f) + + if index < -degree or index >= degree: + raise IndexError("root index out of [%d, %d] range, got %d" % + (-degree, degree - 1, index)) + elif index < 0: + index += degree + + dom = poly.get_domain() + + if not dom.is_Exact: + poly = poly.to_exact() + + roots = cls._roots_trivial(poly, radicals) + + if roots is not None: + return roots[index] + + coeff, poly = preprocess_roots(poly) + dom = poly.get_domain() + + if not dom.is_ZZ: + raise NotImplementedError("CRootOf is not supported over %s" % dom) + + root = cls._indexed_root(poly, index, lazy=True) + return coeff * cls._postprocess_root(root, radicals) + + @classmethod + def _new(cls, poly, index): + """Construct new ``CRootOf`` object from raw data. """ + obj = Expr.__new__(cls) + + obj.poly = PurePoly(poly) + obj.index = index + + try: + _reals_cache[obj.poly] = _reals_cache[poly] + _complexes_cache[obj.poly] = _complexes_cache[poly] + except KeyError: + pass + + return obj + + def _hashable_content(self): + return (self.poly, self.index) + + @property + def expr(self): + return self.poly.as_expr() + + @property + def args(self): + return (self.expr, Integer(self.index)) + + @property + def free_symbols(self): + # CRootOf currently only works with univariate expressions + # whose poly attribute should be a PurePoly with no free + # symbols + return set() + + def _eval_is_real(self): + """Return ``True`` if the root is real. """ + self._ensure_reals_init() + return self.index < len(_reals_cache[self.poly]) + + def _eval_is_imaginary(self): + """Return ``True`` if the root is imaginary. """ + self._ensure_reals_init() + if self.index >= len(_reals_cache[self.poly]): + ivl = self._get_interval() + return ivl.ax*ivl.bx <= 0 # all others are on one side or the other + return False # XXX is this necessary? + + @classmethod + def real_roots(cls, poly, radicals=True): + """Get real roots of a polynomial. """ + return cls._get_roots("_real_roots", poly, radicals) + + @classmethod + def all_roots(cls, poly, radicals=True): + """Get real and complex roots of a polynomial. """ + return cls._get_roots("_all_roots", poly, radicals) + + @classmethod + def _get_reals_sqf(cls, currentfactor, use_cache=True): + """Get real root isolating intervals for a square-free factor.""" + if use_cache and currentfactor in _reals_cache: + real_part = _reals_cache[currentfactor] + else: + _reals_cache[currentfactor] = real_part = \ + dup_isolate_real_roots_sqf( + currentfactor.rep.to_list(), currentfactor.rep.dom, blackbox=True) + + return real_part + + @classmethod + def _get_complexes_sqf(cls, currentfactor, use_cache=True): + """Get complex root isolating intervals for a square-free factor.""" + if use_cache and currentfactor in _complexes_cache: + complex_part = _complexes_cache[currentfactor] + else: + _complexes_cache[currentfactor] = complex_part = \ + dup_isolate_complex_roots_sqf( + currentfactor.rep.to_list(), currentfactor.rep.dom, blackbox=True) + return complex_part + + @classmethod + def _get_reals(cls, factors, use_cache=True): + """Compute real root isolating intervals for a list of factors. """ + reals = [] + + for currentfactor, k in factors: + try: + if not use_cache: + raise KeyError + r = _reals_cache[currentfactor] + reals.extend([(i, currentfactor, k) for i in r]) + except KeyError: + real_part = cls._get_reals_sqf(currentfactor, use_cache) + new = [(root, currentfactor, k) for root in real_part] + reals.extend(new) + + reals = cls._reals_sorted(reals) + return reals + + @classmethod + def _get_complexes(cls, factors, use_cache=True): + """Compute complex root isolating intervals for a list of factors. """ + complexes = [] + + for currentfactor, k in ordered(factors): + try: + if not use_cache: + raise KeyError + c = _complexes_cache[currentfactor] + complexes.extend([(i, currentfactor, k) for i in c]) + except KeyError: + complex_part = cls._get_complexes_sqf(currentfactor, use_cache) + new = [(root, currentfactor, k) for root in complex_part] + complexes.extend(new) + + complexes = cls._complexes_sorted(complexes) + return complexes + + @classmethod + def _reals_sorted(cls, reals): + """Make real isolating intervals disjoint and sort roots. """ + cache = {} + + for i, (u, f, k) in enumerate(reals): + for j, (v, g, m) in enumerate(reals[i + 1:]): + u, v = u.refine_disjoint(v) + reals[i + j + 1] = (v, g, m) + + reals[i] = (u, f, k) + + reals = sorted(reals, key=lambda r: r[0].a) + + for root, currentfactor, _ in reals: + if currentfactor in cache: + cache[currentfactor].append(root) + else: + cache[currentfactor] = [root] + + for currentfactor, root in cache.items(): + _reals_cache[currentfactor] = root + + return reals + + @classmethod + def _refine_imaginary(cls, complexes): + sifted = sift(complexes, lambda c: c[1]) + complexes = [] + for f in ordered(sifted): + nimag = _imag_count_of_factor(f) + if nimag == 0: + # refine until xbounds are neg or pos + for u, f, k in sifted[f]: + while u.ax*u.bx <= 0: + u = u._inner_refine() + complexes.append((u, f, k)) + else: + # refine until all but nimag xbounds are neg or pos + potential_imag = list(range(len(sifted[f]))) + while True: + assert len(potential_imag) > 1 + for i in list(potential_imag): + u, f, k = sifted[f][i] + if u.ax*u.bx > 0: + potential_imag.remove(i) + elif u.ax != u.bx: + u = u._inner_refine() + sifted[f][i] = u, f, k + if len(potential_imag) == nimag: + break + complexes.extend(sifted[f]) + return complexes + + @classmethod + def _refine_complexes(cls, complexes): + """return complexes such that no bounding rectangles of non-conjugate + roots would intersect. In addition, assure that neither ay nor by is + 0 to guarantee that non-real roots are distinct from real roots in + terms of the y-bounds. + """ + # get the intervals pairwise-disjoint. + # If rectangles were drawn around the coordinates of the bounding + # rectangles, no rectangles would intersect after this procedure. + for i, (u, f, k) in enumerate(complexes): + for j, (v, g, m) in enumerate(complexes[i + 1:]): + u, v = u.refine_disjoint(v) + complexes[i + j + 1] = (v, g, m) + + complexes[i] = (u, f, k) + + # refine until the x-bounds are unambiguously positive or negative + # for non-imaginary roots + complexes = cls._refine_imaginary(complexes) + + # make sure that all y bounds are off the real axis + # and on the same side of the axis + for i, (u, f, k) in enumerate(complexes): + while u.ay*u.by <= 0: + u = u.refine() + complexes[i] = u, f, k + return complexes + + @classmethod + def _complexes_sorted(cls, complexes): + """Make complex isolating intervals disjoint and sort roots. """ + complexes = cls._refine_complexes(complexes) + # XXX don't sort until you are sure that it is compatible + # with the indexing method but assert that the desired state + # is not broken + C, F = 0, 1 # location of ComplexInterval and factor + fs = {i[F] for i in complexes} + for i in range(1, len(complexes)): + if complexes[i][F] != complexes[i - 1][F]: + # if this fails the factors of a root were not + # contiguous because a discontinuity should only + # happen once + fs.remove(complexes[i - 1][F]) + for i, cmplx in enumerate(complexes): + # negative im part (conj=True) comes before + # positive im part (conj=False) + assert cmplx[C].conj is (i % 2 == 0) + + # update cache + cache = {} + # -- collate + for root, currentfactor, _ in complexes: + cache.setdefault(currentfactor, []).append(root) + # -- store + for currentfactor, root in cache.items(): + _complexes_cache[currentfactor] = root + + return complexes + + @classmethod + def _reals_index(cls, reals, index): + """ + Map initial real root index to an index in a factor where + the root belongs. + """ + i = 0 + + for j, (_, currentfactor, k) in enumerate(reals): + if index < i + k: + poly, index = currentfactor, 0 + + for _, currentfactor, _ in reals[:j]: + if currentfactor == poly: + index += 1 + + return poly, index + else: + i += k + + @classmethod + def _complexes_index(cls, complexes, index): + """ + Map initial complex root index to an index in a factor where + the root belongs. + """ + i = 0 + for j, (_, currentfactor, k) in enumerate(complexes): + if index < i + k: + poly, index = currentfactor, 0 + + for _, currentfactor, _ in complexes[:j]: + if currentfactor == poly: + index += 1 + + index += len(_reals_cache[poly]) + + return poly, index + else: + i += k + + @classmethod + def _count_roots(cls, roots): + """Count the number of real or complex roots with multiplicities.""" + return sum(k for _, _, k in roots) + + @classmethod + def _indexed_root(cls, poly, index, lazy=False): + """Get a root of a composite polynomial by index. """ + factors = _pure_factors(poly) + + # If the given poly is already irreducible, then the index does not + # need to be adjusted, and we can postpone the heavy lifting of + # computing and refining isolating intervals until that is needed. + # Note, however, that `_pure_factors()` extracts a negative leading + # coeff if present, so `factors[0][0]` may differ from `poly`, and + # is the "normalized" version of `poly` that we must return. + if lazy and len(factors) == 1 and factors[0][1] == 1: + return factors[0][0], index + + reals = cls._get_reals(factors) + reals_count = cls._count_roots(reals) + + if index < reals_count: + return cls._reals_index(reals, index) + else: + complexes = cls._get_complexes(factors) + return cls._complexes_index(complexes, index - reals_count) + + def _ensure_reals_init(self): + """Ensure that our poly has entries in the reals cache. """ + if self.poly not in _reals_cache: + self._indexed_root(self.poly, self.index) + + def _ensure_complexes_init(self): + """Ensure that our poly has entries in the complexes cache. """ + if self.poly not in _complexes_cache: + self._indexed_root(self.poly, self.index) + + @classmethod + def _real_roots(cls, poly): + """Get real roots of a composite polynomial. """ + factors = _pure_factors(poly) + + reals = cls._get_reals(factors) + reals_count = cls._count_roots(reals) + + roots = [] + + for index in range(0, reals_count): + roots.append(cls._reals_index(reals, index)) + + return roots + + def _reset(self): + """ + Reset all intervals + """ + self._all_roots(self.poly, use_cache=False) + + @classmethod + def _all_roots(cls, poly, use_cache=True): + """Get real and complex roots of a composite polynomial. """ + factors = _pure_factors(poly) + + reals = cls._get_reals(factors, use_cache=use_cache) + reals_count = cls._count_roots(reals) + + roots = [] + + for index in range(0, reals_count): + roots.append(cls._reals_index(reals, index)) + + complexes = cls._get_complexes(factors, use_cache=use_cache) + complexes_count = cls._count_roots(complexes) + + for index in range(0, complexes_count): + roots.append(cls._complexes_index(complexes, index)) + + return roots + + @classmethod + @cacheit + def _roots_trivial(cls, poly, radicals): + """Compute roots in linear, quadratic and binomial cases. """ + if poly.degree() == 1: + return roots_linear(poly) + + if not radicals: + return None + + if poly.degree() == 2: + return roots_quadratic(poly) + elif poly.length() == 2 and poly.TC(): + return roots_binomial(poly) + else: + return None + + @classmethod + def _preprocess_roots(cls, poly): + """Take heroic measures to make ``poly`` compatible with ``CRootOf``.""" + dom = poly.get_domain() + + if not dom.is_Exact: + poly = poly.to_exact() + + coeff, poly = preprocess_roots(poly) + dom = poly.get_domain() + + if not dom.is_ZZ: + raise NotImplementedError( + "sorted roots not supported over %s" % dom) + + return coeff, poly + + @classmethod + def _postprocess_root(cls, root, radicals): + """Return the root if it is trivial or a ``CRootOf`` object. """ + poly, index = root + roots = cls._roots_trivial(poly, radicals) + + if roots is not None: + return roots[index] + else: + return cls._new(poly, index) + + @classmethod + def _get_roots(cls, method, poly, radicals): + """Return postprocessed roots of specified kind. """ + if not poly.is_univariate: + raise PolynomialError("only univariate polynomials are allowed") + + dom = poly.get_domain() + + # get rid of gen and it's free symbol + d = Dummy() + poly = poly.subs(poly.gen, d) + x = symbols('x') + # see what others are left and select x or a numbered x + # that doesn't clash + free_names = {str(i) for i in poly.free_symbols} + for x in chain((symbols('x'),), numbered_symbols('x')): + if x.name not in free_names: + poly = poly.replace(d, x) + break + + if dom.is_QQ or dom.is_ZZ: + return cls._get_roots_qq(method, poly, radicals) + elif dom.is_AlgebraicField or dom.is_ZZ_I or dom.is_QQ_I: + return cls._get_roots_alg(method, poly, radicals) + else: + # XXX: not sure how to handle ZZ[x] which appears in some tests? + # this makes the tests pass alright but has to be a better way? + return cls._get_roots_qq(method, poly, radicals) + + + @classmethod + def _get_roots_qq(cls, method, poly, radicals): + """Return postprocessed roots of specified kind + for polynomials with rational coefficients. """ + coeff, poly = cls._preprocess_roots(poly) + roots = [] + + for root in getattr(cls, method)(poly): + roots.append(coeff*cls._postprocess_root(root, radicals)) + + return roots + + @classmethod + def _get_roots_alg(cls, method, poly, radicals): + """Return postprocessed roots of specified kind + for polynomials with algebraic coefficients. It assumes + the domain is already an algebraic field. First it + finds the roots using _get_roots_qq, then uses the + square-free factors to filter roots and get the correct + multiplicity. + """ + + # Existing QQ code can find and sort the roots + roots = cls._get_roots_qq(method, poly.lift(), radicals) + + subroots = {} + for f, m in poly.sqf_list()[1]: + if method == "_real_roots": + roots_filt = f.which_real_roots(roots) + elif method == "_all_roots": + roots_filt = f.which_all_roots(roots) + for r in roots_filt: + subroots[r] = m + + roots_seen = set() + roots_flat = [] + for r in roots: + if r in subroots and r not in roots_seen: + m = subroots[r] + roots_flat.extend([r] * m) + roots_seen.add(r) + + return roots_flat + + @classmethod + def clear_cache(cls): + """Reset cache for reals and complexes. + + The intervals used to approximate a root instance are updated + as needed. When a request is made to see the intervals, the + most current values are shown. `clear_cache` will reset all + CRootOf instances back to their original state. + + See Also + ======== + + _reset + """ + global _reals_cache, _complexes_cache + _reals_cache = _pure_key_dict() + _complexes_cache = _pure_key_dict() + + def _get_interval(self): + """Internal function for retrieving isolation interval from cache. """ + self._ensure_reals_init() + if self.is_real: + return _reals_cache[self.poly][self.index] + else: + reals_count = len(_reals_cache[self.poly]) + self._ensure_complexes_init() + return _complexes_cache[self.poly][self.index - reals_count] + + def _set_interval(self, interval): + """Internal function for updating isolation interval in cache. """ + self._ensure_reals_init() + if self.is_real: + _reals_cache[self.poly][self.index] = interval + else: + reals_count = len(_reals_cache[self.poly]) + self._ensure_complexes_init() + _complexes_cache[self.poly][self.index - reals_count] = interval + + def _eval_subs(self, old, new): + # don't allow subs to change anything + return self + + def _eval_conjugate(self): + if self.is_real: + return self + expr, i = self.args + return self.func(expr, i + (1 if self._get_interval().conj else -1)) + + def eval_approx(self, n, return_mpmath=False): + """Evaluate this complex root to the given precision. + + This uses secant method and root bounds are used to both + generate an initial guess and to check that the root + returned is valid. If ever the method converges outside the + root bounds, the bounds will be made smaller and updated. + """ + prec = dps_to_prec(n) + with workprec(prec): + g = self.poly.gen + if not g.is_Symbol: + d = Dummy('x') + if self.is_imaginary: + d *= I + func = lambdify(d, self.expr.subs(g, d)) + else: + expr = self.expr + if self.is_imaginary: + expr = self.expr.subs(g, I*g) + func = lambdify(g, expr) + + interval = self._get_interval() + while True: + if self.is_real: + a = mpf(str(interval.a)) + b = mpf(str(interval.b)) + if a == b: + root = a + break + x0 = mpf(str(interval.center)) + x1 = x0 + mpf(str(interval.dx))/4 + elif self.is_imaginary: + a = mpf(str(interval.ay)) + b = mpf(str(interval.by)) + if a == b: + root = mpc(mpf('0'), a) + break + x0 = mpf(str(interval.center[1])) + x1 = x0 + mpf(str(interval.dy))/4 + else: + ax = mpf(str(interval.ax)) + bx = mpf(str(interval.bx)) + ay = mpf(str(interval.ay)) + by = mpf(str(interval.by)) + if ax == bx and ay == by: + root = mpc(ax, ay) + break + x0 = mpc(*map(str, interval.center)) + x1 = x0 + mpc(*map(str, (interval.dx, interval.dy)))/4 + try: + # without a tolerance, this will return when (to within + # the given precision) x_i == x_{i-1} + root = findroot(func, (x0, x1)) + # If the (real or complex) root is not in the 'interval', + # then keep refining the interval. This happens if findroot + # accidentally finds a different root outside of this + # interval because our initial estimate 'x0' was not close + # enough. It is also possible that the secant method will + # get trapped by a max/min in the interval; the root + # verification by findroot will raise a ValueError in this + # case and the interval will then be tightened -- and + # eventually the root will be found. + # + # It is also possible that findroot will not have any + # successful iterations to process (in which case it + # will fail to initialize a variable that is tested + # after the iterations and raise an UnboundLocalError). + if self.is_real or self.is_imaginary: + if not bool(root.imag) == self.is_real and ( + a <= root <= b): + if self.is_imaginary: + root = mpc(mpf('0'), root.real) + break + elif (ax <= root.real <= bx and ay <= root.imag <= by): + break + except (UnboundLocalError, ValueError): + pass + interval = interval.refine() + + # update the interval so we at least (for this precision or + # less) don't have much work to do to recompute the root + self._set_interval(interval) + if return_mpmath: + return root + return (Float._new(root.real._mpf_, prec) + + I*Float._new(root.imag._mpf_, prec)) + + def _eval_evalf(self, prec, **kwargs): + """Evaluate this complex root to the given precision.""" + # all kwargs are ignored + return self.eval_rational(n=prec_to_dps(prec))._evalf(prec) + + def eval_rational(self, dx=None, dy=None, n=15): + """ + Return a Rational approximation of ``self`` that has real + and imaginary component approximations that are within ``dx`` + and ``dy`` of the true values, respectively. Alternatively, + ``n`` digits of precision can be specified. + + The interval is refined with bisection and is sure to + converge. The root bounds are updated when the refinement + is complete so recalculation at the same or lesser precision + will not have to repeat the refinement and should be much + faster. + + The following example first obtains Rational approximation to + 1e-8 accuracy for all roots of the 4-th order Legendre + polynomial. Since the roots are all less than 1, this will + ensure the decimal representation of the approximation will be + correct (including rounding) to 6 digits: + + >>> from sympy import legendre_poly, Symbol + >>> x = Symbol("x") + >>> p = legendre_poly(4, x, polys=True) + >>> r = p.real_roots()[-1] + >>> r.eval_rational(10**-8).n(6) + 0.861136 + + It is not necessary to a two-step calculation, however: the + decimal representation can be computed directly: + + >>> r.evalf(17) + 0.86113631159405258 + + """ + dy = dy or dx + if dx: + rtol = None + dx = dx if isinstance(dx, Rational) else Rational(str(dx)) + dy = dy if isinstance(dy, Rational) else Rational(str(dy)) + else: + # 5 binary (or 2 decimal) digits are needed to ensure that + # a given digit is correctly rounded + # prec_to_dps(dps_to_prec(n) + 5) - n <= 2 (tested for + # n in range(1000000) + rtol = S(10)**-(n + 2) # +2 for guard digits + interval = self._get_interval() + while True: + if self.is_real: + if rtol: + dx = abs(interval.center*rtol) + interval = interval.refine_size(dx=dx) + c = interval.center + real = Rational(c) + imag = S.Zero + if not rtol or interval.dx < abs(c*rtol): + break + elif self.is_imaginary: + if rtol: + dy = abs(interval.center[1]*rtol) + dx = 1 + interval = interval.refine_size(dx=dx, dy=dy) + c = interval.center[1] + imag = Rational(c) + real = S.Zero + if not rtol or interval.dy < abs(c*rtol): + break + else: + if rtol: + dx = abs(interval.center[0]*rtol) + dy = abs(interval.center[1]*rtol) + interval = interval.refine_size(dx, dy) + c = interval.center + real, imag = map(Rational, c) + if not rtol or ( + interval.dx < abs(c[0]*rtol) and + interval.dy < abs(c[1]*rtol)): + break + + # update the interval so we at least (for this precision or + # less) don't have much work to do to recompute the root + self._set_interval(interval) + return real + I*imag + + +CRootOf = ComplexRootOf + + +@dispatch(ComplexRootOf, ComplexRootOf) +def _eval_is_eq(lhs, rhs): # noqa:F811 + # if we use is_eq to check here, we get infinite recursion + return lhs == rhs + + +@dispatch(ComplexRootOf, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa:F811 + # CRootOf represents a Root, so if rhs is that root, it should set + # the expression to zero *and* it should be in the interval of the + # CRootOf instance. It must also be a number that agrees with the + # is_real value of the CRootOf instance. + if not rhs.is_number: + return None + if not rhs.is_finite: + return False + z = lhs.expr.subs(lhs.expr.free_symbols.pop(), rhs).is_zero + if z is False: # all roots will make z True but we don't know + # whether this is the right root if z is True + return False + o = rhs.is_real, rhs.is_imaginary + s = lhs.is_real, lhs.is_imaginary + assert None not in s # this is part of initial refinement + if o != s and None not in o: + return False + re, im = rhs.as_real_imag() + if lhs.is_real: + if im: + return False + i = lhs._get_interval() + a, b = [Rational(str(_)) for _ in (i.a, i.b)] + return sympify(a <= rhs and rhs <= b) + i = lhs._get_interval() + r1, r2, i1, i2 = [Rational(str(j)) for j in ( + i.ax, i.bx, i.ay, i.by)] + return is_le(r1, re) and is_le(re,r2) and is_le(i1,im) and is_le(im,i2) + + +@public +class RootSum(Expr): + """Represents a sum of all roots of a univariate polynomial. """ + + __slots__ = ('poly', 'fun', 'auto') + + def __new__(cls, expr, func=None, x=None, auto=True, quadratic=False): + """Construct a new ``RootSum`` instance of roots of a polynomial.""" + coeff, poly = cls._transform(expr, x) + + if not poly.is_univariate: + raise MultivariatePolynomialError( + "only univariate polynomials are allowed") + + if func is None: + func = Lambda(poly.gen, poly.gen) + else: + is_func = getattr(func, 'is_Function', False) + + if is_func and 1 in func.nargs: + if not isinstance(func, Lambda): + func = Lambda(poly.gen, func(poly.gen)) + else: + raise ValueError( + "expected a univariate function, got %s" % func) + + var, expr = func.variables[0], func.expr + + if coeff is not S.One: + expr = expr.subs(var, coeff*var) + + deg = poly.degree() + + if not expr.has(var): + return deg*expr + + if expr.is_Add: + add_const, expr = expr.as_independent(var) + else: + add_const = S.Zero + + if expr.is_Mul: + mul_const, expr = expr.as_independent(var) + else: + mul_const = S.One + + func = Lambda(var, expr) + + rational = cls._is_func_rational(poly, func) + factors, terms = _pure_factors(poly), [] + + for poly, k in factors: + if poly.is_linear: + term = func(roots_linear(poly)[0]) + elif quadratic and poly.is_quadratic: + term = sum(map(func, roots_quadratic(poly))) + else: + if not rational or not auto: + term = cls._new(poly, func, auto) + else: + term = cls._rational_case(poly, func) + + terms.append(k*term) + + return mul_const*Add(*terms) + deg*add_const + + @classmethod + def _new(cls, poly, func, auto=True): + """Construct new raw ``RootSum`` instance. """ + obj = Expr.__new__(cls) + + obj.poly = poly + obj.fun = func + obj.auto = auto + + return obj + + @classmethod + def new(cls, poly, func, auto=True): + """Construct new ``RootSum`` instance. """ + if not func.expr.has(*func.variables): + return func.expr + + rational = cls._is_func_rational(poly, func) + + if not rational or not auto: + return cls._new(poly, func, auto) + else: + return cls._rational_case(poly, func) + + @classmethod + def _transform(cls, expr, x): + """Transform an expression to a polynomial. """ + poly = PurePoly(expr, x, greedy=False) + return preprocess_roots(poly) + + @classmethod + def _is_func_rational(cls, poly, func): + """Check if a lambda is a rational function. """ + var, expr = func.variables[0], func.expr + return expr.is_rational_function(var) + + @classmethod + def _rational_case(cls, poly, func): + """Handle the rational function case. """ + roots = symbols('r:%d' % poly.degree()) + var, expr = func.variables[0], func.expr + + f = sum(expr.subs(var, r) for r in roots) + p, q = together(f).as_numer_denom() + + domain = QQ[roots] + + p = p.expand() + q = q.expand() + + try: + p = Poly(p, domain=domain, expand=False) + except GeneratorsNeeded: + p, p_coeff = None, (p,) + else: + p_monom, p_coeff = zip(*p.terms()) + + try: + q = Poly(q, domain=domain, expand=False) + except GeneratorsNeeded: + q, q_coeff = None, (q,) + else: + q_monom, q_coeff = zip(*q.terms()) + + coeffs, mapping = symmetrize(p_coeff + q_coeff, formal=True) + formulas, values = viete(poly, roots), [] + + for (sym, _), (_, val) in zip(mapping, formulas): + values.append((sym, val)) + + for i, (coeff, _) in enumerate(coeffs): + coeffs[i] = coeff.subs(values) + + n = len(p_coeff) + + p_coeff = coeffs[:n] + q_coeff = coeffs[n:] + + if p is not None: + p = Poly(dict(zip(p_monom, p_coeff)), *p.gens).as_expr() + else: + (p,) = p_coeff + + if q is not None: + q = Poly(dict(zip(q_monom, q_coeff)), *q.gens).as_expr() + else: + (q,) = q_coeff + + return factor(p/q) + + def _hashable_content(self): + return (self.poly, self.fun) + + @property + def expr(self): + return self.poly.as_expr() + + @property + def args(self): + return (self.expr, self.fun, self.poly.gen) + + @property + def free_symbols(self): + return self.poly.free_symbols | self.fun.free_symbols + + @property + def is_commutative(self): + return True + + def doit(self, **hints): + if not hints.get('roots', True): + return self + + _roots = roots(self.poly, multiple=True) + + if len(_roots) < self.poly.degree(): + return self + else: + return Add(*[self.fun(r) for r in _roots]) + + def _eval_evalf(self, prec): + try: + _roots = self.poly.nroots(n=prec_to_dps(prec)) + except (DomainError, PolynomialError): + return self + else: + return Add(*[self.fun(r) for r in _roots]) + + def _eval_derivative(self, x): + var, expr = self.fun.args + func = Lambda(var, expr.diff(x)) + return self.new(self.poly, func, self.auto) diff --git a/phivenv/Lib/site-packages/sympy/polys/solvers.py b/phivenv/Lib/site-packages/sympy/polys/solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..b333e81d975a8cd71e7eb683c2b943d8538f6ac5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/solvers.py @@ -0,0 +1,435 @@ +"""Low-level linear systems solver. """ + + +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import connected_components + +from sympy.core.sympify import sympify +from sympy.core.numbers import Integer, Rational +from sympy.matrices.dense import MutableDenseMatrix +from sympy.polys.domains import ZZ, QQ + +from sympy.polys.domains import EX +from sympy.polys.rings import sring +from sympy.polys.polyerrors import NotInvertible +from sympy.polys.domainmatrix import DomainMatrix + + +class PolyNonlinearError(Exception): + """Raised by solve_lin_sys for nonlinear equations""" + pass + + +class RawMatrix(MutableDenseMatrix): + """ + .. deprecated:: 1.9 + + This class fundamentally is broken by design. Use ``DomainMatrix`` if + you want a matrix over the polys domains or ``Matrix`` for a matrix + with ``Expr`` elements. The ``RawMatrix`` class will be removed/broken + in future in order to reestablish the invariant that the elements of a + Matrix should be of type ``Expr``. + + """ + _sympify = staticmethod(lambda x, *args, **kwargs: x) + + def __init__(self, *args, **kwargs): + sympy_deprecation_warning( + """ + The RawMatrix class is deprecated. Use either DomainMatrix or + Matrix instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-rawmatrix", + ) + + domain = ZZ + for i in range(self.rows): + for j in range(self.cols): + val = self[i,j] + if getattr(val, 'is_Poly', False): + K = val.domain[val.gens] + val_sympy = val.as_expr() + elif hasattr(val, 'parent'): + K = val.parent() + val_sympy = K.to_sympy(val) + elif isinstance(val, (int, Integer)): + K = ZZ + val_sympy = sympify(val) + elif isinstance(val, Rational): + K = QQ + val_sympy = val + else: + for K in ZZ, QQ: + if K.of_type(val): + val_sympy = K.to_sympy(val) + break + else: + raise TypeError + domain = domain.unify(K) + self[i,j] = val_sympy + self.ring = domain + + +def eqs_to_matrix(eqs_coeffs, eqs_rhs, gens, domain): + """Get matrix from linear equations in dict format. + + Explanation + =========== + + Get the matrix representation of a system of linear equations represented + as dicts with low-level DomainElement coefficients. This is an + *internal* function that is used by solve_lin_sys. + + Parameters + ========== + + eqs_coeffs: list[dict[Symbol, DomainElement]] + The left hand sides of the equations as dicts mapping from symbols to + coefficients where the coefficients are instances of + DomainElement. + eqs_rhs: list[DomainElements] + The right hand sides of the equations as instances of + DomainElement. + gens: list[Symbol] + The unknowns in the system of equations. + domain: Domain + The domain for coefficients of both lhs and rhs. + + Returns + ======= + + The augmented matrix representation of the system as a DomainMatrix. + + Examples + ======== + + >>> from sympy import symbols, ZZ + >>> from sympy.polys.solvers import eqs_to_matrix + >>> x, y = symbols('x, y') + >>> eqs_coeff = [{x:ZZ(1), y:ZZ(1)}, {x:ZZ(1), y:ZZ(-1)}] + >>> eqs_rhs = [ZZ(0), ZZ(-1)] + >>> eqs_to_matrix(eqs_coeff, eqs_rhs, [x, y], ZZ) + DomainMatrix([[1, 1, 0], [1, -1, 1]], (2, 3), ZZ) + + See also + ======== + + solve_lin_sys: Uses :func:`~eqs_to_matrix` internally + """ + sym2index = {x: n for n, x in enumerate(gens)} + nrows = len(eqs_coeffs) + ncols = len(gens) + 1 + rows = [[domain.zero] * ncols for _ in range(nrows)] + for row, eq_coeff, eq_rhs in zip(rows, eqs_coeffs, eqs_rhs): + for sym, coeff in eq_coeff.items(): + row[sym2index[sym]] = domain.convert(coeff) + row[-1] = -domain.convert(eq_rhs) + + return DomainMatrix(rows, (nrows, ncols), domain) + + +def sympy_eqs_to_ring(eqs, symbols): + """Convert a system of equations from Expr to a PolyRing + + Explanation + =========== + + High-level functions like ``solve`` expect Expr as inputs but can use + ``solve_lin_sys`` internally. This function converts equations from + ``Expr`` to the low-level poly types used by the ``solve_lin_sys`` + function. + + Parameters + ========== + + eqs: List of Expr + A list of equations as Expr instances + symbols: List of Symbol + A list of the symbols that are the unknowns in the system of + equations. + + Returns + ======= + + Tuple[List[PolyElement], Ring]: The equations as PolyElement instances + and the ring of polynomials within which each equation is represented. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.polys.solvers import sympy_eqs_to_ring + >>> a, x, y = symbols('a, x, y') + >>> eqs = [x-y, x+a*y] + >>> eqs_ring, ring = sympy_eqs_to_ring(eqs, [x, y]) + >>> eqs_ring + [x - y, x + a*y] + >>> type(eqs_ring[0]) + + >>> ring + ZZ(a)[x,y] + + With the equations in this form they can be passed to ``solve_lin_sys``: + + >>> from sympy.polys.solvers import solve_lin_sys + >>> solve_lin_sys(eqs_ring, ring) + {y: 0, x: 0} + """ + try: + K, eqs_K = sring(eqs, symbols, field=True, extension=True) + except NotInvertible: + # https://github.com/sympy/sympy/issues/18874 + K, eqs_K = sring(eqs, symbols, domain=EX) + return eqs_K, K.to_domain() + + +def solve_lin_sys(eqs, ring, _raw=True): + """Solve a system of linear equations from a PolynomialRing + + Explanation + =========== + + Solves a system of linear equations given as PolyElement instances of a + PolynomialRing. The basic arithmetic is carried out using instance of + DomainElement which is more efficient than :class:`~sympy.core.expr.Expr` + for the most common inputs. + + While this is a public function it is intended primarily for internal use + so its interface is not necessarily convenient. Users are suggested to use + the :func:`sympy.solvers.solveset.linsolve` function (which uses this + function internally) instead. + + Parameters + ========== + + eqs: list[PolyElement] + The linear equations to be solved as elements of a + PolynomialRing (assumed equal to zero). + ring: PolynomialRing + The polynomial ring from which eqs are drawn. The generators of this + ring are the unknowns to be solved for and the domain of the ring is + the domain of the coefficients of the system of equations. + _raw: bool + If *_raw* is False, the keys and values in the returned dictionary + will be of type Expr (and the unit of the field will be removed from + the keys) otherwise the low-level polys types will be returned, e.g. + PolyElement: PythonRational. + + Returns + ======= + + ``None`` if the system has no solution. + + dict[Symbol, Expr] if _raw=False + + dict[Symbol, DomainElement] if _raw=True. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.polys.solvers import solve_lin_sys, sympy_eqs_to_ring + >>> x, y = symbols('x, y') + >>> eqs = [x - y, x + y - 2] + >>> eqs_ring, ring = sympy_eqs_to_ring(eqs, [x, y]) + >>> solve_lin_sys(eqs_ring, ring) + {y: 1, x: 1} + + Passing ``_raw=False`` returns the same result except that the keys are + ``Expr`` rather than low-level poly types. + + >>> solve_lin_sys(eqs_ring, ring, _raw=False) + {x: 1, y: 1} + + See also + ======== + + sympy_eqs_to_ring: prepares the inputs to ``solve_lin_sys``. + linsolve: ``linsolve`` uses ``solve_lin_sys`` internally. + sympy.solvers.solvers.solve: ``solve`` uses ``solve_lin_sys`` internally. + """ + as_expr = not _raw + + assert ring.domain.is_Field + + eqs_dict = [dict(eq) for eq in eqs] + + one_monom = ring.one.monoms()[0] + zero = ring.domain.zero + + eqs_rhs = [] + eqs_coeffs = [] + for eq_dict in eqs_dict: + eq_rhs = eq_dict.pop(one_monom, zero) + eq_coeffs = {} + for monom, coeff in eq_dict.items(): + if sum(monom) != 1: + msg = "Nonlinear term encountered in solve_lin_sys" + raise PolyNonlinearError(msg) + eq_coeffs[ring.gens[monom.index(1)]] = coeff + if not eq_coeffs: + if not eq_rhs: + continue + else: + return None + eqs_rhs.append(eq_rhs) + eqs_coeffs.append(eq_coeffs) + + result = _solve_lin_sys(eqs_coeffs, eqs_rhs, ring) + + if result is not None and as_expr: + + def to_sympy(x): + as_expr = getattr(x, 'as_expr', None) + if as_expr: + return as_expr() + else: + return ring.domain.to_sympy(x) + + tresult = {to_sympy(sym): to_sympy(val) for sym, val in result.items()} + + # Remove 1.0x + result = {} + for k, v in tresult.items(): + if k.is_Mul: + c, s = k.as_coeff_Mul() + result[s] = v/c + else: + result[k] = v + + return result + + +def _solve_lin_sys(eqs_coeffs, eqs_rhs, ring): + """Solve a linear system from dict of PolynomialRing coefficients + + Explanation + =========== + + This is an **internal** function used by :func:`solve_lin_sys` after the + equations have been preprocessed. The role of this function is to split + the system into connected components and pass those to + :func:`_solve_lin_sys_component`. + + Examples + ======== + + Setup a system for $x-y=0$ and $x+y=2$ and solve: + + >>> from sympy import symbols, sring + >>> from sympy.polys.solvers import _solve_lin_sys + >>> x, y = symbols('x, y') + >>> R, (xr, yr) = sring([x, y], [x, y]) + >>> eqs = [{xr:R.one, yr:-R.one}, {xr:R.one, yr:R.one}] + >>> eqs_rhs = [R.zero, -2*R.one] + >>> _solve_lin_sys(eqs, eqs_rhs, R) + {y: 1, x: 1} + + See also + ======== + + solve_lin_sys: This function is used internally by :func:`solve_lin_sys`. + """ + V = ring.gens + E = [] + for eq_coeffs in eqs_coeffs: + syms = list(eq_coeffs) + E.extend(zip(syms[:-1], syms[1:])) + G = V, E + + components = connected_components(G) + + sym2comp = {} + for n, component in enumerate(components): + for sym in component: + sym2comp[sym] = n + + subsystems = [([], []) for _ in range(len(components))] + for eq_coeff, eq_rhs in zip(eqs_coeffs, eqs_rhs): + sym = next(iter(eq_coeff), None) + sub_coeff, sub_rhs = subsystems[sym2comp[sym]] + sub_coeff.append(eq_coeff) + sub_rhs.append(eq_rhs) + + sol = {} + for subsystem in subsystems: + subsol = _solve_lin_sys_component(subsystem[0], subsystem[1], ring) + if subsol is None: + return None + sol.update(subsol) + + return sol + + +def _solve_lin_sys_component(eqs_coeffs, eqs_rhs, ring): + """Solve a linear system from dict of PolynomialRing coefficients + + Explanation + =========== + + This is an **internal** function used by :func:`solve_lin_sys` after the + equations have been preprocessed. After :func:`_solve_lin_sys` splits the + system into connected components this function is called for each + component. The system of equations is solved using Gauss-Jordan + elimination with division followed by back-substitution. + + Examples + ======== + + Setup a system for $x-y=0$ and $x+y=2$ and solve: + + >>> from sympy import symbols, sring + >>> from sympy.polys.solvers import _solve_lin_sys_component + >>> x, y = symbols('x, y') + >>> R, (xr, yr) = sring([x, y], [x, y]) + >>> eqs = [{xr:R.one, yr:-R.one}, {xr:R.one, yr:R.one}] + >>> eqs_rhs = [R.zero, -2*R.one] + >>> _solve_lin_sys_component(eqs, eqs_rhs, R) + {y: 1, x: 1} + + See also + ======== + + solve_lin_sys: This function is used internally by :func:`solve_lin_sys`. + """ + + # transform from equations to matrix form + matrix = eqs_to_matrix(eqs_coeffs, eqs_rhs, ring.gens, ring.domain) + + # convert to a field for rref + if not matrix.domain.is_Field: + matrix = matrix.to_field() + + # solve by row-reduction + echelon, pivots = matrix.rref() + + # construct the returnable form of the solutions + keys = ring.gens + + if pivots and pivots[-1] == len(keys): + return None + + if len(pivots) == len(keys): + sol = [] + for s in [row[-1] for row in echelon.rep.to_ddm()]: + a = s + sol.append(a) + sols = dict(zip(keys, sol)) + else: + sols = {} + g = ring.gens + # Extract ground domain coefficients and convert to the ring: + if hasattr(ring, 'ring'): + convert = ring.ring.ground_new + else: + convert = ring.ground_new + echelon = echelon.rep.to_ddm() + vals_set = {v for row in echelon for v in row} + vals_map = {v: convert(v) for v in vals_set} + echelon = [[vals_map[eij] for eij in ei] for ei in echelon] + for i, p in enumerate(pivots): + v = echelon[i][-1] - sum(echelon[i][j]*g[j] for j in range(p+1, len(g)) if echelon[i][j]) + sols[keys[p]] = v + + return sols diff --git a/phivenv/Lib/site-packages/sympy/polys/specialpolys.py b/phivenv/Lib/site-packages/sympy/polys/specialpolys.py new file mode 100644 index 0000000000000000000000000000000000000000..3e85de8679cda3084f1c263a045f4d8f817bed98 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/specialpolys.py @@ -0,0 +1,340 @@ +"""Functions for generating interesting polynomials, e.g. for benchmarking. """ + + +from sympy.core import Add, Mul, Symbol, sympify, Dummy, symbols +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.ntheory import nextprime +from sympy.polys.densearith import ( + dmp_add_term, dmp_neg, dmp_mul, dmp_sqr +) +from sympy.polys.densebasic import ( + dmp_zero, dmp_one, dmp_ground, + dup_from_raw_dict, dmp_raise, dup_random +) +from sympy.polys.domains import ZZ +from sympy.polys.factortools import dup_zz_cyclotomic_poly +from sympy.polys.polyclasses import DMP +from sympy.polys.polytools import Poly, PurePoly +from sympy.polys.polyutils import _analyze_gens +from sympy.utilities import subsets, public, filldedent + + +@public +def swinnerton_dyer_poly(n, x=None, polys=False): + """Generates n-th Swinnerton-Dyer polynomial in `x`. + + Parameters + ---------- + n : int + `n` decides the order of polynomial + x : optional + polys : bool, optional + ``polys=True`` returns an expression, otherwise + (default) returns an expression. + """ + if n <= 0: + raise ValueError( + "Cannot generate Swinnerton-Dyer polynomial of order %s" % n) + + if x is not None: + sympify(x) + else: + x = Dummy('x') + + if n > 3: + from sympy.functions.elementary.miscellaneous import sqrt + from .numberfields import minimal_polynomial + p = 2 + a = [sqrt(2)] + for i in range(2, n + 1): + p = nextprime(p) + a.append(sqrt(p)) + return minimal_polynomial(Add(*a), x, polys=polys) + + if n == 1: + ex = x**2 - 2 + elif n == 2: + ex = x**4 - 10*x**2 + 1 + elif n == 3: + ex = x**8 - 40*x**6 + 352*x**4 - 960*x**2 + 576 + + return PurePoly(ex, x) if polys else ex + + +@public +def cyclotomic_poly(n, x=None, polys=False): + """Generates cyclotomic polynomial of order `n` in `x`. + + Parameters + ---------- + n : int + `n` decides the order of polynomial + x : optional + polys : bool, optional + ``polys=True`` returns an expression, otherwise + (default) returns an expression. + """ + if n <= 0: + raise ValueError( + "Cannot generate cyclotomic polynomial of order %s" % n) + + poly = DMP(dup_zz_cyclotomic_poly(int(n), ZZ), ZZ) + + if x is not None: + poly = Poly.new(poly, x) + else: + poly = PurePoly.new(poly, Dummy('x')) + + return poly if polys else poly.as_expr() + + +@public +def symmetric_poly(n, *gens, polys=False): + """ + Generates symmetric polynomial of order `n`. + + Parameters + ========== + + polys: bool, optional (default: False) + Returns a Poly object when ``polys=True``, otherwise + (default) returns an expression. + """ + gens = _analyze_gens(gens) + + if n < 0 or n > len(gens) or not gens: + raise ValueError("Cannot generate symmetric polynomial of order %s for %s" % (n, gens)) + elif not n: + poly = S.One + else: + poly = Add(*[Mul(*s) for s in subsets(gens, int(n))]) + + return Poly(poly, *gens) if polys else poly + + +@public +def random_poly(x, n, inf, sup, domain=ZZ, polys=False): + """Generates a polynomial of degree ``n`` with coefficients in + ``[inf, sup]``. + + Parameters + ---------- + x + `x` is the independent term of polynomial + n : int + `n` decides the order of polynomial + inf + Lower limit of range in which coefficients lie + sup + Upper limit of range in which coefficients lie + domain : optional + Decides what ring the coefficients are supposed + to belong. Default is set to Integers. + polys : bool, optional + ``polys=True`` returns an expression, otherwise + (default) returns an expression. + """ + poly = Poly(dup_random(n, inf, sup, domain), x, domain=domain) + + return poly if polys else poly.as_expr() + + +@public +def interpolating_poly(n, x, X='x', Y='y'): + """Construct Lagrange interpolating polynomial for ``n`` + data points. If a sequence of values are given for ``X`` and ``Y`` + then the first ``n`` values will be used. + """ + ok = getattr(x, 'free_symbols', None) + + if isinstance(X, str): + X = symbols("%s:%s" % (X, n)) + elif ok and ok & Tuple(*X).free_symbols: + ok = False + + if isinstance(Y, str): + Y = symbols("%s:%s" % (Y, n)) + elif ok and ok & Tuple(*Y).free_symbols: + ok = False + + if not ok: + raise ValueError(filldedent(''' + Expecting symbol for x that does not appear in X or Y. + Use `interpolate(list(zip(X, Y)), x)` instead.''')) + + coeffs = [] + numert = Mul(*[x - X[i] for i in range(n)]) + + for i in range(n): + numer = numert/(x - X[i]) + denom = Mul(*[(X[i] - X[j]) for j in range(n) if i != j]) + coeffs.append(numer/denom) + + return Add(*[coeff*y for coeff, y in zip(coeffs, Y)]) + + +def fateman_poly_F_1(n): + """Fateman's GCD benchmark: trivial GCD """ + Y = [Symbol('y_' + str(i)) for i in range(n + 1)] + + y_0, y_1 = Y[0], Y[1] + + u = y_0 + Add(*Y[1:]) + v = y_0**2 + Add(*[y**2 for y in Y[1:]]) + + F = ((u + 1)*(u + 2)).as_poly(*Y) + G = ((v + 1)*(-3*y_1*y_0**2 + y_1**2 - 1)).as_poly(*Y) + + H = Poly(1, *Y) + + return F, G, H + + +def dmp_fateman_poly_F_1(n, K): + """Fateman's GCD benchmark: trivial GCD """ + u = [K(1), K(0)] + + for i in range(n): + u = [dmp_one(i, K), u] + + v = [K(1), K(0), K(0)] + + for i in range(0, n): + v = [dmp_one(i, K), dmp_zero(i), v] + + m = n - 1 + + U = dmp_add_term(u, dmp_ground(K(1), m), 0, n, K) + V = dmp_add_term(u, dmp_ground(K(2), m), 0, n, K) + + f = [[-K(3), K(0)], [], [K(1), K(0), -K(1)]] + + W = dmp_add_term(v, dmp_ground(K(1), m), 0, n, K) + Y = dmp_raise(f, m, 1, K) + + F = dmp_mul(U, V, n, K) + G = dmp_mul(W, Y, n, K) + + H = dmp_one(n, K) + + return F, G, H + + +def fateman_poly_F_2(n): + """Fateman's GCD benchmark: linearly dense quartic inputs """ + Y = [Symbol('y_' + str(i)) for i in range(n + 1)] + + y_0 = Y[0] + + u = Add(*Y[1:]) + + H = Poly((y_0 + u + 1)**2, *Y) + + F = Poly((y_0 - u - 2)**2, *Y) + G = Poly((y_0 + u + 2)**2, *Y) + + return H*F, H*G, H + + +def dmp_fateman_poly_F_2(n, K): + """Fateman's GCD benchmark: linearly dense quartic inputs """ + u = [K(1), K(0)] + + for i in range(n - 1): + u = [dmp_one(i, K), u] + + m = n - 1 + + v = dmp_add_term(u, dmp_ground(K(2), m - 1), 0, n, K) + + f = dmp_sqr([dmp_one(m, K), dmp_neg(v, m, K)], n, K) + g = dmp_sqr([dmp_one(m, K), v], n, K) + + v = dmp_add_term(u, dmp_one(m - 1, K), 0, n, K) + + h = dmp_sqr([dmp_one(m, K), v], n, K) + + return dmp_mul(f, h, n, K), dmp_mul(g, h, n, K), h + + +def fateman_poly_F_3(n): + """Fateman's GCD benchmark: sparse inputs (deg f ~ vars f) """ + Y = [Symbol('y_' + str(i)) for i in range(n + 1)] + + y_0 = Y[0] + + u = Add(*[y**(n + 1) for y in Y[1:]]) + + H = Poly((y_0**(n + 1) + u + 1)**2, *Y) + + F = Poly((y_0**(n + 1) - u - 2)**2, *Y) + G = Poly((y_0**(n + 1) + u + 2)**2, *Y) + + return H*F, H*G, H + + +def dmp_fateman_poly_F_3(n, K): + """Fateman's GCD benchmark: sparse inputs (deg f ~ vars f) """ + u = dup_from_raw_dict({n + 1: K.one}, K) + + for i in range(0, n - 1): + u = dmp_add_term([u], dmp_one(i, K), n + 1, i + 1, K) + + v = dmp_add_term(u, dmp_ground(K(2), n - 2), 0, n, K) + + f = dmp_sqr( + dmp_add_term([dmp_neg(v, n - 1, K)], dmp_one(n - 1, K), n + 1, n, K), n, K) + g = dmp_sqr(dmp_add_term([v], dmp_one(n - 1, K), n + 1, n, K), n, K) + + v = dmp_add_term(u, dmp_one(n - 2, K), 0, n - 1, K) + + h = dmp_sqr(dmp_add_term([v], dmp_one(n - 1, K), n + 1, n, K), n, K) + + return dmp_mul(f, h, n, K), dmp_mul(g, h, n, K), h + +# A few useful polynomials from Wang's paper ('78). + +from sympy.polys.rings import ring + +def _f_0(): + R, x, y, z = ring("x,y,z", ZZ) + return x**2*y*z**2 + 2*x**2*y*z + 3*x**2*y + 2*x**2 + 3*x + 4*y**2*z**2 + 5*y**2*z + 6*y**2 + y*z**2 + 2*y*z + y + 1 + +def _f_1(): + R, x, y, z = ring("x,y,z", ZZ) + return x**3*y*z + x**2*y**2*z**2 + x**2*y**2 + 20*x**2*y*z + 30*x**2*y + x**2*z**2 + 10*x**2*z + x*y**3*z + 30*x*y**2*z + 20*x*y**2 + x*y*z**3 + 10*x*y*z**2 + x*y*z + 610*x*y + 20*x*z**2 + 230*x*z + 300*x + y**2*z**2 + 10*y**2*z + 30*y*z**2 + 320*y*z + 200*y + 600*z + 6000 + +def _f_2(): + R, x, y, z = ring("x,y,z", ZZ) + return x**5*y**3 + x**5*y**2*z + x**5*y*z**2 + x**5*z**3 + x**3*y**2 + x**3*y*z + 90*x**3*y + 90*x**3*z + x**2*y**2*z - 11*x**2*y**2 + x**2*z**3 - 11*x**2*z**2 + y*z - 11*y + 90*z - 990 + +def _f_3(): + R, x, y, z = ring("x,y,z", ZZ) + return x**5*y**2 + x**4*z**4 + x**4 + x**3*y**3*z + x**3*z + x**2*y**4 + x**2*y**3*z**3 + x**2*y*z**5 + x**2*y*z + x*y**2*z**4 + x*y**2 + x*y*z**7 + x*y*z**3 + x*y*z**2 + y**2*z + y*z**4 + +def _f_4(): + R, x, y, z = ring("x,y,z", ZZ) + return -x**9*y**8*z - x**8*y**5*z**3 - x**7*y**12*z**2 - 5*x**7*y**8 - x**6*y**9*z**4 + x**6*y**7*z**3 + 3*x**6*y**7*z - 5*x**6*y**5*z**2 - x**6*y**4*z**3 + x**5*y**4*z**5 + 3*x**5*y**4*z**3 - x**5*y*z**5 + x**4*y**11*z**4 + 3*x**4*y**11*z**2 - x**4*y**8*z**4 + 5*x**4*y**7*z**2 + 15*x**4*y**7 - 5*x**4*y**4*z**2 + x**3*y**8*z**6 + 3*x**3*y**8*z**4 - x**3*y**5*z**6 + 5*x**3*y**4*z**4 + 15*x**3*y**4*z**2 + x**3*y**3*z**5 + 3*x**3*y**3*z**3 - 5*x**3*y*z**4 + x**2*z**7 + 3*x**2*z**5 + x*y**7*z**6 + 3*x*y**7*z**4 + 5*x*y**3*z**4 + 15*x*y**3*z**2 + y**4*z**8 + 3*y**4*z**6 + 5*z**6 + 15*z**4 + +def _f_5(): + R, x, y, z = ring("x,y,z", ZZ) + return -x**3 - 3*x**2*y + 3*x**2*z - 3*x*y**2 + 6*x*y*z - 3*x*z**2 - y**3 + 3*y**2*z - 3*y*z**2 + z**3 + +def _f_6(): + R, x, y, z, t = ring("x,y,z,t", ZZ) + return 2115*x**4*y + 45*x**3*z**3*t**2 - 45*x**3*t**2 - 423*x*y**4 - 47*x*y**3 + 141*x*y*z**3 + 94*x*y*z*t - 9*y**3*z**3*t**2 + 9*y**3*t**2 - y**2*z**3*t**2 + y**2*t**2 + 3*z**6*t**2 + 2*z**4*t**3 - 3*z**3*t**2 - 2*z*t**3 + +def _w_1(): + R, x, y, z = ring("x,y,z", ZZ) + return 4*x**6*y**4*z**2 + 4*x**6*y**3*z**3 - 4*x**6*y**2*z**4 - 4*x**6*y*z**5 + x**5*y**4*z**3 + 12*x**5*y**3*z - x**5*y**2*z**5 + 12*x**5*y**2*z**2 - 12*x**5*y*z**3 - 12*x**5*z**4 + 8*x**4*y**4 + 6*x**4*y**3*z**2 + 8*x**4*y**3*z - 4*x**4*y**2*z**4 + 4*x**4*y**2*z**3 - 8*x**4*y**2*z**2 - 4*x**4*y*z**5 - 2*x**4*y*z**4 - 8*x**4*y*z**3 + 2*x**3*y**4*z + x**3*y**3*z**3 - x**3*y**2*z**5 - 2*x**3*y**2*z**3 + 9*x**3*y**2*z - 12*x**3*y*z**3 + 12*x**3*y*z**2 - 12*x**3*z**4 + 3*x**3*z**3 + 6*x**2*y**3 - 6*x**2*y**2*z**2 + 8*x**2*y**2*z - 2*x**2*y*z**4 - 8*x**2*y*z**3 + 2*x**2*y*z**2 + 2*x*y**3*z - 2*x*y**2*z**3 - 3*x*y*z + 3*x*z**3 - 2*y**2 + 2*y*z**2 + +def _w_2(): + R, x, y = ring("x,y", ZZ) + return 24*x**8*y**3 + 48*x**8*y**2 + 24*x**7*y**5 - 72*x**7*y**2 + 25*x**6*y**4 + 2*x**6*y**3 + 4*x**6*y + 8*x**6 + x**5*y**6 + x**5*y**3 - 12*x**5 + x**4*y**5 - x**4*y**4 - 2*x**4*y**3 + 292*x**4*y**2 - x**3*y**6 + 3*x**3*y**3 - x**2*y**5 + 12*x**2*y**3 + 48*x**2 - 12*y**3 + +def f_polys(): + return _f_0(), _f_1(), _f_2(), _f_3(), _f_4(), _f_5(), _f_6() + +def w_polys(): + return _w_1(), _w_2() diff --git a/phivenv/Lib/site-packages/sympy/polys/sqfreetools.py b/phivenv/Lib/site-packages/sympy/polys/sqfreetools.py new file mode 100644 index 0000000000000000000000000000000000000000..b2bf434cab542a42c0f7d67058e1a3c01857335d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/sqfreetools.py @@ -0,0 +1,795 @@ +"""Square-free decomposition algorithms and related tools. """ + + +from sympy.polys.densearith import ( + dup_neg, dmp_neg, + dup_sub, dmp_sub, + dup_mul, dmp_mul, + dup_quo, dmp_quo, + dup_mul_ground, dmp_mul_ground) +from sympy.polys.densebasic import ( + dup_strip, + dup_LC, dmp_ground_LC, + dmp_zero_p, + dmp_ground, + dup_degree, dmp_degree, dmp_degree_in, dmp_degree_list, + dmp_raise, dmp_inject, + dup_convert) +from sympy.polys.densetools import ( + dup_diff, dmp_diff, dmp_diff_in, + dup_shift, dmp_shift, + dup_monic, dmp_ground_monic, + dup_primitive, dmp_ground_primitive) +from sympy.polys.euclidtools import ( + dup_inner_gcd, dmp_inner_gcd, + dup_gcd, dmp_gcd, + dmp_resultant, dmp_primitive) +from sympy.polys.galoistools import ( + gf_sqf_list, gf_sqf_part) +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + DomainError) + + +def _dup_check_degrees(f, result): + """Sanity check the degrees of a computed factorization in K[x].""" + deg = sum(k * dup_degree(fac) for (fac, k) in result) + assert deg == dup_degree(f) + + +def _dmp_check_degrees(f, u, result): + """Sanity check the degrees of a computed factorization in K[X].""" + degs = [0] * (u + 1) + for fac, k in result: + degs_fac = dmp_degree_list(fac, u) + degs = [d1 + k * d2 for d1, d2 in zip(degs, degs_fac)] + assert tuple(degs) == dmp_degree_list(f, u) + + +def dup_sqf_p(f, K): + """ + Return ``True`` if ``f`` is a square-free polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sqf_p(x**2 - 2*x + 1) + False + >>> R.dup_sqf_p(x**2 - 1) + True + + """ + if not f: + return True + else: + return not dup_degree(dup_gcd(f, dup_diff(f, 1, K), K)) + + +def dmp_sqf_p(f, u, K): + """ + Return ``True`` if ``f`` is a square-free polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sqf_p(x**2 + 2*x*y + y**2) + False + >>> R.dmp_sqf_p(x**2 + y**2) + True + + """ + if dmp_zero_p(f, u): + return True + + for i in range(u+1): + + fp = dmp_diff_in(f, 1, i, u, K) + + if dmp_zero_p(fp, u): + continue + + gcd = dmp_gcd(f, fp, u, K) + + if dmp_degree_in(gcd, i, u) != 0: + return False + + return True + + +def dup_sqf_norm(f, K): + r""" + Find a shift of `f` in `K[x]` that has square-free norm. + + The domain `K` must be an algebraic number field `k(a)` (see :ref:`QQ(a)`). + + Returns `(s,g,r)`, such that `g(x)=f(x-sa)`, `r(x)=\text{Norm}(g(x))` and + `r` is a square-free polynomial over `k`. + + Examples + ======== + + We first create the algebraic number field `K=k(a)=\mathbb{Q}(\sqrt{3})` + and rings `K[x]` and `k[x]`: + + >>> from sympy.polys import ring, QQ + >>> from sympy import sqrt + + >>> K = QQ.algebraic_field(sqrt(3)) + >>> R, x = ring("x", K) + >>> _, X = ring("x", QQ) + + We can now find a square free norm for a shift of `f`: + + >>> f = x**2 - 1 + >>> s, g, r = R.dup_sqf_norm(f) + + The choice of shift `s` is arbitrary and the particular values returned for + `g` and `r` are determined by `s`. + + >>> s == 1 + True + >>> g == x**2 - 2*sqrt(3)*x + 2 + True + >>> r == X**4 - 8*X**2 + 4 + True + + The invariants are: + + >>> g == f.shift(-s*K.unit) + True + >>> g.norm() == r + True + >>> r.is_squarefree + True + + Explanation + =========== + + This is part of Trager's algorithm for factorizing polynomials over + algebraic number fields. In particular this function is algorithm + ``sqfr_norm`` from [Trager76]_. + + See Also + ======== + + dmp_sqf_norm: + Analogous function for multivariate polynomials over ``k(a)``. + dmp_norm: + Computes the norm of `f` directly without any shift. + dup_ext_factor: + Function implementing Trager's algorithm that uses this. + sympy.polys.polytools.sqf_norm: + High-level interface for using this function. + """ + if not K.is_Algebraic: + raise DomainError("ground domain must be algebraic") + + s, g = 0, dmp_raise(K.mod.to_list(), 1, 0, K.dom) + + while True: + h, _ = dmp_inject(f, 0, K, front=True) + r = dmp_resultant(g, h, 1, K.dom) + + if dup_sqf_p(r, K.dom): + break + else: + f, s = dup_shift(f, -K.unit, K), s + 1 + + return s, f, r + + +def _dmp_sqf_norm_shifts(f, u, K): + """Generate a sequence of candidate shifts for dmp_sqf_norm.""" + # + # We want to find a minimal shift if possible because shifting high degree + # variables can be expensive e.g. x**10 -> (x + 1)**10. We try a few easy + # cases first before the final infinite loop that is guaranteed to give + # only finitely many bad shifts (see Trager76 for proof of this in the + # univariate case). + # + + # First the trivial shift [0, 0, ...] + n = u + 1 + s0 = [0] * n + yield s0, f + + # Shift in multiples of the generator of the extension field K + a = K.unit + + # Variables of degree > 0 ordered by increasing degree + d = dmp_degree_list(f, u) + var_indices = [i for di, i in sorted(zip(d, range(u+1))) if di > 0] + + # Now try [1, 0, 0, ...], [0, 1, 0, ...] + for i in var_indices: + s1 = s0.copy() + s1[i] = 1 + a1 = [-a*s1i for s1i in s1] + f1 = dmp_shift(f, a1, u, K) + yield s1, f1 + + # Now try [1, 1, 1, ...], [2, 2, 2, ...] + j = 0 + while True: + j += 1 + sj = [j] * n + aj = [-a*j] * n + fj = dmp_shift(f, aj, u, K) + yield sj, fj + + +def dmp_sqf_norm(f, u, K): + r""" + Find a shift of ``f`` in ``K[X]`` that has square-free norm. + + The domain `K` must be an algebraic number field `k(a)` (see :ref:`QQ(a)`). + + Returns `(s,g,r)`, such that `g(x_1,x_2,\cdots)=f(x_1-s_1 a, x_2 - s_2 a, + \cdots)`, `r(x)=\text{Norm}(g(x))` and `r` is a square-free polynomial over + `k`. + + Examples + ======== + + We first create the algebraic number field `K=k(a)=\mathbb{Q}(i)` and rings + `K[x,y]` and `k[x,y]`: + + >>> from sympy.polys import ring, QQ + >>> from sympy import I + + >>> K = QQ.algebraic_field(I) + >>> R, x, y = ring("x,y", K) + >>> _, X, Y = ring("x,y", QQ) + + We can now find a square free norm for a shift of `f`: + + >>> f = x*y + y**2 + >>> s, g, r = R.dmp_sqf_norm(f) + + The choice of shifts ``s`` is arbitrary and the particular values returned + for ``g`` and ``r`` are determined by ``s``. + + >>> s + [0, 1] + >>> g == x*y - I*x + y**2 - 2*I*y - 1 + True + >>> r == X**2*Y**2 + X**2 + 2*X*Y**3 + 2*X*Y + Y**4 + 2*Y**2 + 1 + True + + The required invariants are: + + >>> g == f.shift_list([-si*K.unit for si in s]) + True + >>> g.norm() == r + True + >>> r.is_squarefree + True + + Explanation + =========== + + This is part of Trager's algorithm for factorizing polynomials over + algebraic number fields. In particular this function is a multivariate + generalization of algorithm ``sqfr_norm`` from [Trager76]_. + + See Also + ======== + + dup_sqf_norm: + Analogous function for univariate polynomials over ``k(a)``. + dmp_norm: + Computes the norm of `f` directly without any shift. + dmp_ext_factor: + Function implementing Trager's algorithm that uses this. + sympy.polys.polytools.sqf_norm: + High-level interface for using this function. + """ + if not u: + s, g, r = dup_sqf_norm(f, K) + return [s], g, r + + if not K.is_Algebraic: + raise DomainError("ground domain must be algebraic") + + g = dmp_raise(K.mod.to_list(), u + 1, 0, K.dom) + + for s, f in _dmp_sqf_norm_shifts(f, u, K): + + h, _ = dmp_inject(f, u, K, front=True) + r = dmp_resultant(g, h, u + 1, K.dom) + + if dmp_sqf_p(r, u, K.dom): + break + + return s, f, r + + +def dmp_norm(f, u, K): + r""" + Norm of ``f`` in ``K[X]``, often not square-free. + + The domain `K` must be an algebraic number field `k(a)` (see :ref:`QQ(a)`). + + Examples + ======== + + We first define the algebraic number field `K = k(a) = \mathbb{Q}(\sqrt{2})`: + + >>> from sympy import QQ, sqrt + >>> from sympy.polys.sqfreetools import dmp_norm + >>> k = QQ + >>> K = k.algebraic_field(sqrt(2)) + + We can now compute the norm of a polynomial `p` in `K[x,y]`: + + >>> p = [[K(1)], [K(1),K.unit]] # x + y + sqrt(2) + >>> N = [[k(1)], [k(2),k(0)], [k(1),k(0),k(-2)]] # x**2 + 2*x*y + y**2 - 2 + >>> dmp_norm(p, 1, K) == N + True + + In higher level functions that is: + + >>> from sympy import expand, roots, minpoly + >>> from sympy.abc import x, y + >>> from math import prod + >>> a = sqrt(2) + >>> e = (x + y + a) + >>> e.as_poly([x, y], extension=a).norm() + Poly(x**2 + 2*x*y + y**2 - 2, x, y, domain='QQ') + + This is equal to the product of the expressions `x + y + a_i` where the + `a_i` are the conjugates of `a`: + + >>> pa = minpoly(a) + >>> pa + _x**2 - 2 + >>> rs = roots(pa, multiple=True) + >>> rs + [sqrt(2), -sqrt(2)] + >>> n = prod(e.subs(a, r) for r in rs) + >>> n + (x + y - sqrt(2))*(x + y + sqrt(2)) + >>> expand(n) + x**2 + 2*x*y + y**2 - 2 + + Explanation + =========== + + Given an algebraic number field `K = k(a)` any element `b` of `K` can be + represented as polynomial function `b=g(a)` where `g` is in `k[x]`. If the + minimal polynomial of `a` over `k` is `p_a` then the roots `a_1`, `a_2`, + `\cdots` of `p_a(x)` are the conjugates of `a`. The norm of `b` is the + product `g(a1) \times g(a2) \times \cdots` and is an element of `k`. + + As in [Trager76]_ we extend this norm to multivariate polynomials over `K`. + If `b(x)` is a polynomial in `k(a)[X]` then we can think of `b` as being + alternately a function `g_X(a)` where `g_X` is an element of `k[X][y]` i.e. + a polynomial function with coefficients that are elements of `k[X]`. Then + the norm of `b` is the product `g_X(a1) \times g_X(a2) \times \cdots` and + will be an element of `k[X]`. + + See Also + ======== + + dmp_sqf_norm: + Compute a shift of `f` so that the `\text{Norm}(f)` is square-free. + sympy.polys.polytools.Poly.norm: + Higher-level function that calls this. + """ + if not K.is_Algebraic: + raise DomainError("ground domain must be algebraic") + + g = dmp_raise(K.mod.to_list(), u + 1, 0, K.dom) + h, _ = dmp_inject(f, u, K, front=True) + + return dmp_resultant(g, h, u + 1, K.dom) + + +def dup_gf_sqf_part(f, K): + """Compute square-free part of ``f`` in ``GF(p)[x]``. """ + f = dup_convert(f, K, K.dom) + g = gf_sqf_part(f, K.mod, K.dom) + return dup_convert(g, K.dom, K) + + +def dmp_gf_sqf_part(f, u, K): + """Compute square-free part of ``f`` in ``GF(p)[X]``. """ + raise NotImplementedError('multivariate polynomials over finite fields') + + +def dup_sqf_part(f, K): + """ + Returns square-free part of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sqf_part(x**3 - 3*x - 2) + x**2 - x - 2 + + See Also + ======== + + sympy.polys.polytools.Poly.sqf_part + """ + if K.is_FiniteField: + return dup_gf_sqf_part(f, K) + + if not f: + return f + + if K.is_negative(dup_LC(f, K)): + f = dup_neg(f, K) + + gcd = dup_gcd(f, dup_diff(f, 1, K), K) + sqf = dup_quo(f, gcd, K) + + if K.is_Field: + return dup_monic(sqf, K) + else: + return dup_primitive(sqf, K)[1] + + +def dmp_sqf_part(f, u, K): + """ + Returns square-free part of a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sqf_part(x**3 + 2*x**2*y + x*y**2) + x**2 + x*y + + """ + if not u: + return dup_sqf_part(f, K) + + if K.is_FiniteField: + return dmp_gf_sqf_part(f, u, K) + + if dmp_zero_p(f, u): + return f + + if K.is_negative(dmp_ground_LC(f, u, K)): + f = dmp_neg(f, u, K) + + gcd = f + for i in range(u+1): + gcd = dmp_gcd(gcd, dmp_diff_in(f, 1, i, u, K), u, K) + sqf = dmp_quo(f, gcd, u, K) + + if K.is_Field: + return dmp_ground_monic(sqf, u, K) + else: + return dmp_ground_primitive(sqf, u, K)[1] + + +def dup_gf_sqf_list(f, K, all=False): + """Compute square-free decomposition of ``f`` in ``GF(p)[x]``. """ + f_orig = f + + f = dup_convert(f, K, K.dom) + + coeff, factors = gf_sqf_list(f, K.mod, K.dom, all=all) + + for i, (f, k) in enumerate(factors): + factors[i] = (dup_convert(f, K.dom, K), k) + + _dup_check_degrees(f_orig, factors) + + return K.convert(coeff, K.dom), factors + + +def dmp_gf_sqf_list(f, u, K, all=False): + """Compute square-free decomposition of ``f`` in ``GF(p)[X]``. """ + raise NotImplementedError('multivariate polynomials over finite fields') + + +def dup_sqf_list(f, K, all=False): + """ + Return square-free decomposition of a polynomial in ``K[x]``. + + Uses Yun's algorithm from [Yun76]_. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> f = 2*x**5 + 16*x**4 + 50*x**3 + 76*x**2 + 56*x + 16 + + >>> R.dup_sqf_list(f) + (2, [(x + 1, 2), (x + 2, 3)]) + >>> R.dup_sqf_list(f, all=True) + (2, [(1, 1), (x + 1, 2), (x + 2, 3)]) + + See Also + ======== + + dmp_sqf_list: + Corresponding function for multivariate polynomials. + sympy.polys.polytools.sqf_list: + High-level function for square-free factorization of expressions. + sympy.polys.polytools.Poly.sqf_list: + Analogous method on :class:`~.Poly`. + + References + ========== + + [Yun76]_ + """ + if K.is_FiniteField: + return dup_gf_sqf_list(f, K, all=all) + + f_orig = f + + if K.is_Field: + coeff = dup_LC(f, K) + f = dup_monic(f, K) + else: + coeff, f = dup_primitive(f, K) + + if K.is_negative(dup_LC(f, K)): + f = dup_neg(f, K) + coeff = -coeff + + if dup_degree(f) <= 0: + return coeff, [] + + result, i = [], 1 + + h = dup_diff(f, 1, K) + g, p, q = dup_inner_gcd(f, h, K) + + while True: + d = dup_diff(p, 1, K) + h = dup_sub(q, d, K) + + if not h: + result.append((p, i)) + break + + g, p, q = dup_inner_gcd(p, h, K) + + if all or dup_degree(g) > 0: + result.append((g, i)) + + i += 1 + + _dup_check_degrees(f_orig, result) + + return coeff, result + + +def dup_sqf_list_include(f, K, all=False): + """ + Return square-free decomposition of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> f = 2*x**5 + 16*x**4 + 50*x**3 + 76*x**2 + 56*x + 16 + + >>> R.dup_sqf_list_include(f) + [(2, 1), (x + 1, 2), (x + 2, 3)] + >>> R.dup_sqf_list_include(f, all=True) + [(2, 1), (x + 1, 2), (x + 2, 3)] + + """ + coeff, factors = dup_sqf_list(f, K, all=all) + + if factors and factors[0][1] == 1: + g = dup_mul_ground(factors[0][0], coeff, K) + return [(g, 1)] + factors[1:] + else: + g = dup_strip([coeff]) + return [(g, 1)] + factors + + +def dmp_sqf_list(f, u, K, all=False): + """ + Return square-free decomposition of a polynomial in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**5 + 2*x**4*y + x**3*y**2 + + >>> R.dmp_sqf_list(f) + (1, [(x + y, 2), (x, 3)]) + >>> R.dmp_sqf_list(f, all=True) + (1, [(1, 1), (x + y, 2), (x, 3)]) + + Explanation + =========== + + Uses Yun's algorithm for univariate polynomials from [Yun76]_ recursively. + The multivariate polynomial is treated as a univariate polynomial in its + leading variable. Then Yun's algorithm computes the square-free + factorization of the primitive and the content is factored recursively. + + It would be better to use a dedicated algorithm for multivariate + polynomials instead. + + See Also + ======== + + dup_sqf_list: + Corresponding function for univariate polynomials. + sympy.polys.polytools.sqf_list: + High-level function for square-free factorization of expressions. + sympy.polys.polytools.Poly.sqf_list: + Analogous method on :class:`~.Poly`. + """ + if not u: + return dup_sqf_list(f, K, all=all) + + if K.is_FiniteField: + return dmp_gf_sqf_list(f, u, K, all=all) + + f_orig = f + + if K.is_Field: + coeff = dmp_ground_LC(f, u, K) + f = dmp_ground_monic(f, u, K) + else: + coeff, f = dmp_ground_primitive(f, u, K) + + if K.is_negative(dmp_ground_LC(f, u, K)): + f = dmp_neg(f, u, K) + coeff = -coeff + + deg = dmp_degree(f, u) + if deg < 0: + return coeff, [] + + # Yun's algorithm requires the polynomial to be primitive as a univariate + # polynomial in its main variable. + content, f = dmp_primitive(f, u, K) + + result = {} + + if deg != 0: + + h = dmp_diff(f, 1, u, K) + g, p, q = dmp_inner_gcd(f, h, u, K) + + i = 1 + + while True: + d = dmp_diff(p, 1, u, K) + h = dmp_sub(q, d, u, K) + + if dmp_zero_p(h, u): + result[i] = p + break + + g, p, q = dmp_inner_gcd(p, h, u, K) + + if all or dmp_degree(g, u) > 0: + result[i] = g + + i += 1 + + coeff_content, result_content = dmp_sqf_list(content, u-1, K, all=all) + + coeff *= coeff_content + + # Combine factors of the content and primitive part that have the same + # multiplicity to produce a list in ascending order of multiplicity. + for fac, i in result_content: + fac = [fac] + if i in result: + result[i] = dmp_mul(result[i], fac, u, K) + else: + result[i] = fac + + result = [(result[i], i) for i in sorted(result)] + + _dmp_check_degrees(f_orig, u, result) + + return coeff, result + + +def dmp_sqf_list_include(f, u, K, all=False): + """ + Return square-free decomposition of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**5 + 2*x**4*y + x**3*y**2 + + >>> R.dmp_sqf_list_include(f) + [(1, 1), (x + y, 2), (x, 3)] + >>> R.dmp_sqf_list_include(f, all=True) + [(1, 1), (x + y, 2), (x, 3)] + + """ + if not u: + return dup_sqf_list_include(f, K, all=all) + + coeff, factors = dmp_sqf_list(f, u, K, all=all) + + if factors and factors[0][1] == 1: + g = dmp_mul_ground(factors[0][0], coeff, u, K) + return [(g, 1)] + factors[1:] + else: + g = dmp_ground(coeff, u) + return [(g, 1)] + factors + + +def dup_gff_list(f, K): + """ + Compute greatest factorial factorization of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_gff_list(x**5 + 2*x**4 - x**3 - 2*x**2) + [(x, 1), (x + 2, 4)] + + """ + if not f: + raise ValueError("greatest factorial factorization doesn't exist for a zero polynomial") + + f = dup_monic(f, K) + + if not dup_degree(f): + return [] + else: + g = dup_gcd(f, dup_shift(f, K.one, K), K) + H = dup_gff_list(g, K) + + for i, (h, k) in enumerate(H): + g = dup_mul(g, dup_shift(h, -K(k), K), K) + H[i] = (h, k + 1) + + f = dup_quo(f, g, K) + + if not dup_degree(f): + return H + else: + return [(f, 1)] + H + + +def dmp_gff_list(f, u, K): + """ + Compute greatest factorial factorization of ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + """ + if not u: + return dup_gff_list(f, K) + else: + raise MultivariatePolynomialError(f) diff --git a/phivenv/Lib/site-packages/sympy/polys/subresultants_qq_zz.py b/phivenv/Lib/site-packages/sympy/polys/subresultants_qq_zz.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce8d5c88d44022621d13d8e82956e676a7e75ae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/subresultants_qq_zz.py @@ -0,0 +1,2558 @@ +""" +This module contains functions for the computation +of Euclidean, (generalized) Sturmian, (modified) subresultant +polynomial remainder sequences (prs's) of two polynomials; +included are also three functions for the computation of the +resultant of two polynomials. + +Except for the function res_z(), which computes the resultant +of two polynomials, the pseudo-remainder function prem() +of sympy is _not_ used by any of the functions in the module. + +Instead of prem() we use the function + +rem_z(). + +Included is also the function quo_z(). + +An explanation of why we avoid prem() can be found in the +references stated in the docstring of rem_z(). + +1. Theoretical background: +========================== +Consider the polynomials f, g in Z[x] of degrees deg(f) = n and +deg(g) = m with n >= m. + +Definition 1: +============= +The sign sequence of a polynomial remainder sequence (prs) is the +sequence of signs of the leading coefficients of its polynomials. + +Sign sequences can be computed with the function: + +sign_seq(poly_seq, x) + +Definition 2: +============= +A polynomial remainder sequence (prs) is called complete if the +degree difference between any two consecutive polynomials is 1; +otherwise, it called incomplete. + +It is understood that f, g belong to the sequences mentioned in +the two definitions above. + +1A. Euclidean and subresultant prs's: +===================================== +The subresultant prs of f, g is a sequence of polynomials in Z[x] +analogous to the Euclidean prs, the sequence obtained by applying +on f, g Euclid's algorithm for polynomial greatest common divisors +(gcd) in Q[x]. + +The subresultant prs differs from the Euclidean prs in that the +coefficients of each polynomial in the former sequence are determinants +--- also referred to as subresultants --- of appropriately selected +sub-matrices of sylvester1(f, g, x), Sylvester's matrix of 1840 of +dimensions (n + m) * (n + m). + +Recall that the determinant of sylvester1(f, g, x) itself is +called the resultant of f, g and serves as a criterion of whether +the two polynomials have common roots or not. + +In SymPy the resultant is computed with the function +resultant(f, g, x). This function does _not_ evaluate the +determinant of sylvester(f, g, x, 1); instead, it returns +the last member of the subresultant prs of f, g, multiplied +(if needed) by an appropriate power of -1; see the caveat below. + +In this module we use three functions to compute the +resultant of f, g: +a) res(f, g, x) computes the resultant by evaluating +the determinant of sylvester(f, g, x, 1); +b) res_q(f, g, x) computes the resultant recursively, by +performing polynomial divisions in Q[x] with the function rem(); +c) res_z(f, g, x) computes the resultant recursively, by +performing polynomial divisions in Z[x] with the function prem(). + +Caveat: If Df = degree(f, x) and Dg = degree(g, x), then: + +resultant(f, g, x) = (-1)**(Df*Dg) * resultant(g, f, x). + +For complete prs's the sign sequence of the Euclidean prs of f, g +is identical to the sign sequence of the subresultant prs of f, g +and the coefficients of one sequence are easily computed from the +coefficients of the other. + +For incomplete prs's the polynomials in the subresultant prs, generally +differ in sign from those of the Euclidean prs, and --- unlike the +case of complete prs's --- it is not at all obvious how to compute +the coefficients of one sequence from the coefficients of the other. + +1B. Sturmian and modified subresultant prs's: +============================================= +For the same polynomials f, g in Z[x] mentioned above, their ``modified'' +subresultant prs is a sequence of polynomials similar to the Sturmian +prs, the sequence obtained by applying in Q[x] Sturm's algorithm on f, g. + +The two sequences differ in that the coefficients of each polynomial +in the modified subresultant prs are the determinants --- also referred +to as modified subresultants --- of appropriately selected sub-matrices +of sylvester2(f, g, x), Sylvester's matrix of 1853 of dimensions 2n x 2n. + +The determinant of sylvester2 itself is called the modified resultant +of f, g and it also can serve as a criterion of whether the two +polynomials have common roots or not. + +For complete prs's the sign sequence of the Sturmian prs of f, g is +identical to the sign sequence of the modified subresultant prs of +f, g and the coefficients of one sequence are easily computed from +the coefficients of the other. + +For incomplete prs's the polynomials in the modified subresultant prs, +generally differ in sign from those of the Sturmian prs, and --- unlike +the case of complete prs's --- it is not at all obvious how to compute +the coefficients of one sequence from the coefficients of the other. + +As Sylvester pointed out, the coefficients of the polynomial remainders +obtained as (modified) subresultants are the smallest possible without +introducing rationals and without computing (integer) greatest common +divisors. + +1C. On terminology: +=================== +Whence the terminology? Well generalized Sturmian prs's are +``modifications'' of Euclidean prs's; the hint came from the title +of the Pell-Gordon paper of 1917. + +In the literature one also encounters the name ``non signed'' and +``signed'' prs for Euclidean and Sturmian prs respectively. + +Likewise ``non signed'' and ``signed'' subresultant prs for +subresultant and modified subresultant prs respectively. + +2. Functions in the module: +=========================== +No function utilizes SymPy's function prem(). + +2A. Matrices: +============= +The functions sylvester(f, g, x, method=1) and +sylvester(f, g, x, method=2) compute either Sylvester matrix. +They can be used to compute (modified) subresultant prs's by +direct determinant evaluation. + +The function bezout(f, g, x, method='prs') provides a matrix of +smaller dimensions than either Sylvester matrix. It is the function +of choice for computing (modified) subresultant prs's by direct +determinant evaluation. + +sylvester(f, g, x, method=1) +sylvester(f, g, x, method=2) +bezout(f, g, x, method='prs') + +The following identity holds: + +bezout(f, g, x, method='prs') = +backward_eye(deg(f))*bezout(f, g, x, method='bz')*backward_eye(deg(f)) + +2B. Subresultant and modified subresultant prs's by +=================================================== +determinant evaluations: +======================= +We use the Sylvester matrices of 1840 and 1853 to +compute, respectively, subresultant and modified +subresultant polynomial remainder sequences. However, +for large matrices this approach takes a lot of time. + +Instead of utilizing the Sylvester matrices, we can +employ the Bezout matrix which is of smaller dimensions. + +subresultants_sylv(f, g, x) +modified_subresultants_sylv(f, g, x) +subresultants_bezout(f, g, x) +modified_subresultants_bezout(f, g, x) + +2C. Subresultant prs's by ONE determinant evaluation: +===================================================== +All three functions in this section evaluate one determinant +per remainder polynomial; this is the determinant of an +appropriately selected sub-matrix of sylvester1(f, g, x), +Sylvester's matrix of 1840. + +To compute the remainder polynomials the function +subresultants_rem(f, g, x) employs rem(f, g, x). +By contrast, the other two functions implement Van Vleck's ideas +of 1900 and compute the remainder polynomials by trinagularizing +sylvester2(f, g, x), Sylvester's matrix of 1853. + + +subresultants_rem(f, g, x) +subresultants_vv(f, g, x) +subresultants_vv_2(f, g, x). + +2E. Euclidean, Sturmian prs's in Q[x]: +====================================== +euclid_q(f, g, x) +sturm_q(f, g, x) + +2F. Euclidean, Sturmian and (modified) subresultant prs's P-G: +============================================================== +All functions in this section are based on the Pell-Gordon (P-G) +theorem of 1917. +Computations are done in Q[x], employing the function rem(f, g, x) +for the computation of the remainder polynomials. + +euclid_pg(f, g, x) +sturm pg(f, g, x) +subresultants_pg(f, g, x) +modified_subresultants_pg(f, g, x) + +2G. Euclidean, Sturmian and (modified) subresultant prs's A-M-V: +================================================================ +All functions in this section are based on the Akritas-Malaschonok- +Vigklas (A-M-V) theorem of 2015. +Computations are done in Z[x], employing the function rem_z(f, g, x) +for the computation of the remainder polynomials. + +euclid_amv(f, g, x) +sturm_amv(f, g, x) +subresultants_amv(f, g, x) +modified_subresultants_amv(f, g, x) + +2Ga. Exception: +=============== +subresultants_amv_q(f, g, x) + +This function employs rem(f, g, x) for the computation of +the remainder polynomials, despite the fact that it implements +the A-M-V Theorem. + +It is included in our module in order to show that theorems P-G +and A-M-V can be implemented utilizing either the function +rem(f, g, x) or the function rem_z(f, g, x). + +For clearly historical reasons --- since the Collins-Brown-Traub +coefficients-reduction factor beta_i was not available in 1917 --- +we have implemented the Pell-Gordon theorem with the function +rem(f, g, x) and the A-M-V Theorem with the function rem_z(f, g, x). + +2H. Resultants: +=============== +res(f, g, x) +res_q(f, g, x) +res_z(f, g, x) +""" + + +from sympy.concrete.summations import summation +from sympy.core.function import expand +from sympy.core.numbers import nan +from sympy.core.singleton import S +from sympy.core.symbol import Dummy as var +from sympy.functions.elementary.complexes import Abs, sign +from sympy.functions.elementary.integers import floor +from sympy.matrices.dense import eye, Matrix, zeros +from sympy.printing.pretty.pretty import pretty_print as pprint +from sympy.simplify.simplify import simplify +from sympy.polys.domains import QQ +from sympy.polys.polytools import degree, LC, Poly, pquo, quo, prem, rem +from sympy.polys.polyerrors import PolynomialError + + +def sylvester(f, g, x, method = 1): + ''' + The input polynomials f, g are in Z[x] or in Q[x]. Let m = degree(f, x), + n = degree(g, x) and mx = max(m, n). + + a. If method = 1 (default), computes sylvester1, Sylvester's matrix of 1840 + of dimension (m + n) x (m + n). The determinants of properly chosen + submatrices of this matrix (a.k.a. subresultants) can be + used to compute the coefficients of the Euclidean PRS of f, g. + + b. If method = 2, computes sylvester2, Sylvester's matrix of 1853 + of dimension (2*mx) x (2*mx). The determinants of properly chosen + submatrices of this matrix (a.k.a. ``modified'' subresultants) can be + used to compute the coefficients of the Sturmian PRS of f, g. + + Applications of these Matrices can be found in the references below. + Especially, for applications of sylvester2, see the first reference!! + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On a Theorem + by Van Vleck Regarding Sturm Sequences. Serdica Journal of Computing, + Vol. 7, No 4, 101-134, 2013. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + ''' + # obtain degrees of polys + m, n = degree( Poly(f, x), x), degree( Poly(g, x), x) + + # Special cases: + # A:: case m = n < 0 (i.e. both polys are 0) + if m == n and n < 0: + return Matrix([]) + + # B:: case m = n = 0 (i.e. both polys are constants) + if m == n and n == 0: + return Matrix([]) + + # C:: m == 0 and n < 0 or m < 0 and n == 0 + # (i.e. one poly is constant and the other is 0) + if m == 0 and n < 0: + return Matrix([]) + elif m < 0 and n == 0: + return Matrix([]) + + # D:: m >= 1 and n < 0 or m < 0 and n >=1 + # (i.e. one poly is of degree >=1 and the other is 0) + if m >= 1 and n < 0: + return Matrix([0]) + elif m < 0 and n >= 1: + return Matrix([0]) + + fp = Poly(f, x).all_coeffs() + gp = Poly(g, x).all_coeffs() + + # Sylvester's matrix of 1840 (default; a.k.a. sylvester1) + if method <= 1: + M = zeros(m + n) + k = 0 + for i in range(n): + j = k + for coeff in fp: + M[i, j] = coeff + j = j + 1 + k = k + 1 + k = 0 + for i in range(n, m + n): + j = k + for coeff in gp: + M[i, j] = coeff + j = j + 1 + k = k + 1 + return M + + # Sylvester's matrix of 1853 (a.k.a sylvester2) + else: + if len(fp) < len(gp): + h = [] + for i in range(len(gp) - len(fp)): + h.append(0) + fp[ : 0] = h + else: + h = [] + for i in range(len(fp) - len(gp)): + h.append(0) + gp[ : 0] = h + mx = max(m, n) + dim = 2*mx + M = zeros( dim ) + k = 0 + for i in range( mx ): + j = k + for coeff in fp: + M[2*i, j] = coeff + j = j + 1 + j = k + for coeff in gp: + M[2*i + 1, j] = coeff + j = j + 1 + k = k + 1 + return M + +def process_matrix_output(poly_seq, x): + """ + poly_seq is a polynomial remainder sequence computed either by + (modified_)subresultants_bezout or by (modified_)subresultants_sylv. + + This function removes from poly_seq all zero polynomials as well + as all those whose degree is equal to the degree of a preceding + polynomial in poly_seq, as we scan it from left to right. + + """ + L = poly_seq[:] # get a copy of the input sequence + d = degree(L[1], x) + i = 2 + while i < len(L): + d_i = degree(L[i], x) + if d_i < 0: # zero poly + L.remove(L[i]) + i = i - 1 + if d == d_i: # poly degree equals degree of previous poly + L.remove(L[i]) + i = i - 1 + if d_i >= 0: + d = d_i + i = i + 1 + + return L + +def subresultants_sylv(f, g, x): + """ + The input polynomials f, g are in Z[x] or in Q[x]. It is assumed + that deg(f) >= deg(g). + + Computes the subresultant polynomial remainder sequence (prs) + of f, g by evaluating determinants of appropriately selected + submatrices of sylvester(f, g, x, 1). The dimensions of the + latter are (deg(f) + deg(g)) x (deg(f) + deg(g)). + + Each coefficient is computed by evaluating the determinant of the + corresponding submatrix of sylvester(f, g, x, 1). + + If the subresultant prs is complete, then the output coincides + with the Euclidean sequence of the polynomials f, g. + + References: + =========== + 1. G.M.Diaz-Toca,L.Gonzalez-Vega: Various New Expressions for Subresultants + and Their Applications. Appl. Algebra in Engin., Communic. and Comp., + Vol. 15, 233-266, 2004. + + """ + + # make sure neither f nor g is 0 + if f == 0 or g == 0: + return [f, g] + + n = degF = degree(f, x) + m = degG = degree(g, x) + + # make sure proper degrees + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, degF, degG, f, g = m, n, degG, degF, g, f + if n > 0 and m == 0: + return [f, g] + + SR_L = [f, g] # subresultant list + + # form matrix sylvester(f, g, x, 1) + S = sylvester(f, g, x, 1) + + # pick appropriate submatrices of S + # and form subresultant polys + j = m - 1 + + while j > 0: + Sp = S[:, :] # copy of S + # delete last j rows of coeffs of g + for ind in range(m + n - j, m + n): + Sp.row_del(m + n - j) + # delete last j rows of coeffs of f + for ind in range(m - j, m): + Sp.row_del(m - j) + + # evaluate determinants and form coefficients list + coeff_L, k, l = [], Sp.rows, 0 + while l <= j: + coeff_L.append(Sp[:, 0:k].det()) + Sp.col_swap(k - 1, k + l) + l += 1 + + # form poly and append to SP_L + SR_L.append(Poly(coeff_L, x).as_expr()) + j -= 1 + + # j = 0 + SR_L.append(S.det()) + + return process_matrix_output(SR_L, x) + +def modified_subresultants_sylv(f, g, x): + """ + The input polynomials f, g are in Z[x] or in Q[x]. It is assumed + that deg(f) >= deg(g). + + Computes the modified subresultant polynomial remainder sequence (prs) + of f, g by evaluating determinants of appropriately selected + submatrices of sylvester(f, g, x, 2). The dimensions of the + latter are (2*deg(f)) x (2*deg(f)). + + Each coefficient is computed by evaluating the determinant of the + corresponding submatrix of sylvester(f, g, x, 2). + + If the modified subresultant prs is complete, then the output coincides + with the Sturmian sequence of the polynomials f, g. + + References: + =========== + 1. A. G. Akritas,G.I. Malaschonok and P.S. Vigklas: + Sturm Sequences and Modified Subresultant Polynomial Remainder + Sequences. Serdica Journal of Computing, Vol. 8, No 1, 29--46, 2014. + + """ + + # make sure neither f nor g is 0 + if f == 0 or g == 0: + return [f, g] + + n = degF = degree(f, x) + m = degG = degree(g, x) + + # make sure proper degrees + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, degF, degG, f, g = m, n, degG, degF, g, f + if n > 0 and m == 0: + return [f, g] + + SR_L = [f, g] # modified subresultant list + + # form matrix sylvester(f, g, x, 2) + S = sylvester(f, g, x, 2) + + # pick appropriate submatrices of S + # and form modified subresultant polys + j = m - 1 + + while j > 0: + # delete last 2*j rows of pairs of coeffs of f, g + Sp = S[0:2*n - 2*j, :] # copy of first 2*n - 2*j rows of S + + # evaluate determinants and form coefficients list + coeff_L, k, l = [], Sp.rows, 0 + while l <= j: + coeff_L.append(Sp[:, 0:k].det()) + Sp.col_swap(k - 1, k + l) + l += 1 + + # form poly and append to SP_L + SR_L.append(Poly(coeff_L, x).as_expr()) + j -= 1 + + # j = 0 + SR_L.append(S.det()) + + return process_matrix_output(SR_L, x) + +def res(f, g, x): + """ + The input polynomials f, g are in Z[x] or in Q[x]. + + The output is the resultant of f, g computed by evaluating + the determinant of the matrix sylvester(f, g, x, 1). + + References: + =========== + 1. J. S. Cohen: Computer Algebra and Symbolic Computation + - Mathematical Methods. A. K. Peters, 2003. + + """ + if f == 0 or g == 0: + raise PolynomialError("The resultant of %s and %s is not defined" % (f, g)) + else: + return sylvester(f, g, x, 1).det() + +def res_q(f, g, x): + """ + The input polynomials f, g are in Z[x] or in Q[x]. + + The output is the resultant of f, g computed recursively + by polynomial divisions in Q[x], using the function rem. + See Cohen's book p. 281. + + References: + =========== + 1. J. S. Cohen: Computer Algebra and Symbolic Computation + - Mathematical Methods. A. K. Peters, 2003. + """ + m = degree(f, x) + n = degree(g, x) + if m < n: + return (-1)**(m*n) * res_q(g, f, x) + elif n == 0: # g is a constant + return g**m + else: + r = rem(f, g, x) + if r == 0: + return 0 + else: + s = degree(r, x) + l = LC(g, x) + return (-1)**(m*n) * l**(m-s)*res_q(g, r, x) + +def res_z(f, g, x): + """ + The input polynomials f, g are in Z[x] or in Q[x]. + + The output is the resultant of f, g computed recursively + by polynomial divisions in Z[x], using the function prem(). + See Cohen's book p. 283. + + References: + =========== + 1. J. S. Cohen: Computer Algebra and Symbolic Computation + - Mathematical Methods. A. K. Peters, 2003. + """ + m = degree(f, x) + n = degree(g, x) + if m < n: + return (-1)**(m*n) * res_z(g, f, x) + elif n == 0: # g is a constant + return g**m + else: + r = prem(f, g, x) + if r == 0: + return 0 + else: + delta = m - n + 1 + w = (-1)**(m*n) * res_z(g, r, x) + s = degree(r, x) + l = LC(g, x) + k = delta * n - m + s + return quo(w, l**k, x) + +def sign_seq(poly_seq, x): + """ + Given a sequence of polynomials poly_seq, it returns + the sequence of signs of the leading coefficients of + the polynomials in poly_seq. + + """ + return [sign(LC(poly_seq[i], x)) for i in range(len(poly_seq))] + +def bezout(p, q, x, method='bz'): + """ + The input polynomials p, q are in Z[x] or in Q[x]. Let + mx = max(degree(p, x), degree(q, x)). + + The default option bezout(p, q, x, method='bz') returns Bezout's + symmetric matrix of p and q, of dimensions (mx) x (mx). The + determinant of this matrix is equal to the determinant of sylvester2, + Sylvester's matrix of 1853, whose dimensions are (2*mx) x (2*mx); + however the subresultants of these two matrices may differ. + + The other option, bezout(p, q, x, 'prs'), is of interest to us + in this module because it returns a matrix equivalent to sylvester2. + In this case all subresultants of the two matrices are identical. + + Both the subresultant polynomial remainder sequence (prs) and + the modified subresultant prs of p and q can be computed by + evaluating determinants of appropriately selected submatrices of + bezout(p, q, x, 'prs') --- one determinant per coefficient of the + remainder polynomials. + + The matrices bezout(p, q, x, 'bz') and bezout(p, q, x, 'prs') + are related by the formula + + bezout(p, q, x, 'prs') = + backward_eye(deg(p)) * bezout(p, q, x, 'bz') * backward_eye(deg(p)), + + where backward_eye() is the backward identity function. + + References + ========== + 1. G.M.Diaz-Toca,L.Gonzalez-Vega: Various New Expressions for Subresultants + and Their Applications. Appl. Algebra in Engin., Communic. and Comp., + Vol. 15, 233-266, 2004. + + """ + # obtain degrees of polys + m, n = degree( Poly(p, x), x), degree( Poly(q, x), x) + + # Special cases: + # A:: case m = n < 0 (i.e. both polys are 0) + if m == n and n < 0: + return Matrix([]) + + # B:: case m = n = 0 (i.e. both polys are constants) + if m == n and n == 0: + return Matrix([]) + + # C:: m == 0 and n < 0 or m < 0 and n == 0 + # (i.e. one poly is constant and the other is 0) + if m == 0 and n < 0: + return Matrix([]) + elif m < 0 and n == 0: + return Matrix([]) + + # D:: m >= 1 and n < 0 or m < 0 and n >=1 + # (i.e. one poly is of degree >=1 and the other is 0) + if m >= 1 and n < 0: + return Matrix([0]) + elif m < 0 and n >= 1: + return Matrix([0]) + + y = var('y') + + # expr is 0 when x = y + expr = p * q.subs({x:y}) - p.subs({x:y}) * q + + # hence expr is exactly divisible by x - y + poly = Poly( quo(expr, x-y), x, y) + + # form Bezout matrix and store them in B as indicated to get + # the LC coefficient of each poly either in the first position + # of each row (method='prs') or in the last (method='bz'). + mx = max(m, n) + B = zeros(mx) + for i in range(mx): + for j in range(mx): + if method == 'prs': + B[mx - 1 - i, mx - 1 - j] = poly.nth(i, j) + else: + B[i, j] = poly.nth(i, j) + return B + +def backward_eye(n): + ''' + Returns the backward identity matrix of dimensions n x n. + + Needed to "turn" the Bezout matrices + so that the leading coefficients are first. + See docstring of the function bezout(p, q, x, method='bz'). + ''' + M = eye(n) # identity matrix of order n + + for i in range(int(M.rows / 2)): + M.row_swap(0 + i, M.rows - 1 - i) + + return M + +def subresultants_bezout(p, q, x): + """ + The input polynomials p, q are in Z[x] or in Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the subresultant polynomial remainder sequence + of p, q by evaluating determinants of appropriately selected + submatrices of bezout(p, q, x, 'prs'). The dimensions of the + latter are deg(p) x deg(p). + + Each coefficient is computed by evaluating the determinant of the + corresponding submatrix of bezout(p, q, x, 'prs'). + + bezout(p, q, x, 'prs) is used instead of sylvester(p, q, x, 1), + Sylvester's matrix of 1840, because the dimensions of the latter + are (deg(p) + deg(q)) x (deg(p) + deg(q)). + + If the subresultant prs is complete, then the output coincides + with the Euclidean sequence of the polynomials p, q. + + References + ========== + 1. G.M.Diaz-Toca,L.Gonzalez-Vega: Various New Expressions for Subresultants + and Their Applications. Appl. Algebra in Engin., Communic. and Comp., + Vol. 15, 233-266, 2004. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + f, g = p, q + n = degF = degree(f, x) + m = degG = degree(g, x) + + # make sure proper degrees + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, degF, degG, f, g = m, n, degG, degF, g, f + if n > 0 and m == 0: + return [f, g] + + SR_L = [f, g] # subresultant list + F = LC(f, x)**(degF - degG) + + # form the bezout matrix + B = bezout(f, g, x, 'prs') + + # pick appropriate submatrices of B + # and form subresultant polys + if degF > degG: + j = 2 + if degF == degG: + j = 1 + while j <= degF: + M = B[0:j, :] + k, coeff_L = j - 1, [] + while k <= degF - 1: + coeff_L.append(M[:, 0:j].det()) + if k < degF - 1: + M.col_swap(j - 1, k + 1) + k = k + 1 + + # apply Theorem 2.1 in the paper by Toca & Vega 2004 + # to get correct signs + SR_L.append(int((-1)**(j*(j-1)/2)) * (Poly(coeff_L, x) / F).as_expr()) + j = j + 1 + + return process_matrix_output(SR_L, x) + +def modified_subresultants_bezout(p, q, x): + """ + The input polynomials p, q are in Z[x] or in Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the modified subresultant polynomial remainder sequence + of p, q by evaluating determinants of appropriately selected + submatrices of bezout(p, q, x, 'prs'). The dimensions of the + latter are deg(p) x deg(p). + + Each coefficient is computed by evaluating the determinant of the + corresponding submatrix of bezout(p, q, x, 'prs'). + + bezout(p, q, x, 'prs') is used instead of sylvester(p, q, x, 2), + Sylvester's matrix of 1853, because the dimensions of the latter + are 2*deg(p) x 2*deg(p). + + If the modified subresultant prs is complete, and LC( p ) > 0, the output + coincides with the (generalized) Sturm's sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + 2. G.M.Diaz-Toca,L.Gonzalez-Vega: Various New Expressions for Subresultants + and Their Applications. Appl. Algebra in Engin., Communic. and Comp., + Vol. 15, 233-266, 2004. + + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + f, g = p, q + n = degF = degree(f, x) + m = degG = degree(g, x) + + # make sure proper degrees + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, degF, degG, f, g = m, n, degG, degF, g, f + if n > 0 and m == 0: + return [f, g] + + SR_L = [f, g] # subresultant list + + # form the bezout matrix + B = bezout(f, g, x, 'prs') + + # pick appropriate submatrices of B + # and form subresultant polys + if degF > degG: + j = 2 + if degF == degG: + j = 1 + while j <= degF: + M = B[0:j, :] + k, coeff_L = j - 1, [] + while k <= degF - 1: + coeff_L.append(M[:, 0:j].det()) + if k < degF - 1: + M.col_swap(j - 1, k + 1) + k = k + 1 + + ## Theorem 2.1 in the paper by Toca & Vega 2004 is _not needed_ + ## in this case since + ## the bezout matrix is equivalent to sylvester2 + SR_L.append(( Poly(coeff_L, x)).as_expr()) + j = j + 1 + + return process_matrix_output(SR_L, x) + +def sturm_pg(p, q, x, method=0): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the (generalized) Sturm sequence of p and q in Z[x] or Q[x]. + If q = diff(p, x, 1) it is the usual Sturm sequence. + + A. If method == 0, default, the remainder coefficients of the sequence + are (in absolute value) ``modified'' subresultants, which for non-monic + polynomials are greater than the coefficients of the corresponding + subresultants by the factor Abs(LC(p)**( deg(p)- deg(q))). + + B. If method == 1, the remainder coefficients of the sequence are (in + absolute value) subresultants, which for non-monic polynomials are + smaller than the coefficients of the corresponding ``modified'' + subresultants by the factor Abs(LC(p)**( deg(p)- deg(q))). + + If the Sturm sequence is complete, method=0 and LC( p ) > 0, the coefficients + of the polynomials in the sequence are ``modified'' subresultants. + That is, they are determinants of appropriately selected submatrices of + sylvester2, Sylvester's matrix of 1853. In this case the Sturm sequence + coincides with the ``modified'' subresultant prs, of the polynomials + p, q. + + If the Sturm sequence is incomplete and method=0 then the signs of the + coefficients of the polynomials in the sequence may differ from the signs + of the coefficients of the corresponding polynomials in the ``modified'' + subresultant prs; however, the absolute values are the same. + + To compute the coefficients, no determinant evaluation takes place. Instead, + polynomial divisions in Q[x] are performed, using the function rem(p, q, x); + the coefficients of the remainders computed this way become (``modified'') + subresultants with the help of the Pell-Gordon Theorem of 1917. + See also the function euclid_pg(p, q, x). + + References + ========== + 1. Pell A. J., R. L. Gordon. The Modified Remainders Obtained in Finding + the Highest Common Factor of Two Polynomials. Annals of MatheMatics, + Second Series, 18 (1917), No. 4, 188-193. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + d0 = degree(p, x) + d1 = degree(q, x) + if d0 == 0 and d1 == 0: + return [p, q] + if d1 > d0: + d0, d1 = d1, d0 + p, q = q, p + if d0 > 0 and d1 == 0: + return [p,q] + + # make sure LC(p) > 0 + flag = 0 + if LC(p,x) < 0: + flag = 1 + p = -p + q = -q + + # initialize + lcf = LC(p, x)**(d0 - d1) # lcf * subr = modified subr + a0, a1 = p, q # the input polys + sturm_seq = [a0, a1] # the output list + del0 = d0 - d1 # degree difference + rho1 = LC(a1, x) # leading coeff of a1 + exp_deg = d1 - 1 # expected degree of a2 + a2 = - rem(a0, a1, domain=QQ) # first remainder + rho2 = LC(a2,x) # leading coeff of a2 + d2 = degree(a2, x) # actual degree of a2 + deg_diff_new = exp_deg - d2 # expected - actual degree + del1 = d1 - d2 # degree difference + + # mul_fac is the factor by which a2 is multiplied to + # get integer coefficients + mul_fac_old = rho1**(del0 + del1 - deg_diff_new) + + # append accordingly + if method == 0: + sturm_seq.append( simplify(lcf * a2 * Abs(mul_fac_old))) + else: + sturm_seq.append( simplify( a2 * Abs(mul_fac_old))) + + # main loop + deg_diff_old = deg_diff_new + while d2 > 0: + a0, a1, d0, d1 = a1, a2, d1, d2 # update polys and degrees + del0 = del1 # update degree difference + exp_deg = d1 - 1 # new expected degree + a2 = - rem(a0, a1, domain=QQ) # new remainder + rho3 = LC(a2, x) # leading coeff of a2 + d2 = degree(a2, x) # actual degree of a2 + deg_diff_new = exp_deg - d2 # expected - actual degree + del1 = d1 - d2 # degree difference + + # take into consideration the power + # rho1**deg_diff_old that was "left out" + expo_old = deg_diff_old # rho1 raised to this power + expo_new = del0 + del1 - deg_diff_new # rho2 raised to this power + + # update variables and append + mul_fac_new = rho2**(expo_new) * rho1**(expo_old) * mul_fac_old + deg_diff_old, mul_fac_old = deg_diff_new, mul_fac_new + rho1, rho2 = rho2, rho3 + if method == 0: + sturm_seq.append( simplify(lcf * a2 * Abs(mul_fac_old))) + else: + sturm_seq.append( simplify( a2 * Abs(mul_fac_old))) + + if flag: # change the sign of the sequence + sturm_seq = [-i for i in sturm_seq] + + # gcd is of degree > 0 ? + m = len(sturm_seq) + if sturm_seq[m - 1] == nan or sturm_seq[m - 1] == 0: + sturm_seq.pop(m - 1) + + return sturm_seq + +def sturm_q(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the (generalized) Sturm sequence of p and q in Q[x]. + Polynomial divisions in Q[x] are performed, using the function rem(p, q, x). + + The coefficients of the polynomials in the Sturm sequence can be uniquely + determined from the corresponding coefficients of the polynomials found + either in: + + (a) the ``modified'' subresultant prs, (references 1, 2) + + or in + + (b) the subresultant prs (reference 3). + + References + ========== + 1. Pell A. J., R. L. Gordon. The Modified Remainders Obtained in Finding + the Highest Common Factor of Two Polynomials. Annals of MatheMatics, + Second Series, 18 (1917), No. 4, 188-193. + + 2 Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + 3. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + d0 = degree(p, x) + d1 = degree(q, x) + if d0 == 0 and d1 == 0: + return [p, q] + if d1 > d0: + d0, d1 = d1, d0 + p, q = q, p + if d0 > 0 and d1 == 0: + return [p,q] + + # make sure LC(p) > 0 + flag = 0 + if LC(p,x) < 0: + flag = 1 + p = -p + q = -q + + # initialize + a0, a1 = p, q # the input polys + sturm_seq = [a0, a1] # the output list + a2 = -rem(a0, a1, domain=QQ) # first remainder + d2 = degree(a2, x) # degree of a2 + sturm_seq.append( a2 ) + + # main loop + while d2 > 0: + a0, a1, d0, d1 = a1, a2, d1, d2 # update polys and degrees + a2 = -rem(a0, a1, domain=QQ) # new remainder + d2 = degree(a2, x) # actual degree of a2 + sturm_seq.append( a2 ) + + if flag: # change the sign of the sequence + sturm_seq = [-i for i in sturm_seq] + + # gcd is of degree > 0 ? + m = len(sturm_seq) + if sturm_seq[m - 1] == nan or sturm_seq[m - 1] == 0: + sturm_seq.pop(m - 1) + + return sturm_seq + +def sturm_amv(p, q, x, method=0): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the (generalized) Sturm sequence of p and q in Z[x] or Q[x]. + If q = diff(p, x, 1) it is the usual Sturm sequence. + + A. If method == 0, default, the remainder coefficients of the + sequence are (in absolute value) ``modified'' subresultants, which + for non-monic polynomials are greater than the coefficients of the + corresponding subresultants by the factor Abs(LC(p)**( deg(p)- deg(q))). + + B. If method == 1, the remainder coefficients of the sequence are (in + absolute value) subresultants, which for non-monic polynomials are + smaller than the coefficients of the corresponding ``modified'' + subresultants by the factor Abs( LC(p)**( deg(p)- deg(q)) ). + + If the Sturm sequence is complete, method=0 and LC( p ) > 0, then the + coefficients of the polynomials in the sequence are ``modified'' subresultants. + That is, they are determinants of appropriately selected submatrices of + sylvester2, Sylvester's matrix of 1853. In this case the Sturm sequence + coincides with the ``modified'' subresultant prs, of the polynomials + p, q. + + If the Sturm sequence is incomplete and method=0 then the signs of the + coefficients of the polynomials in the sequence may differ from the signs + of the coefficients of the corresponding polynomials in the ``modified'' + subresultant prs; however, the absolute values are the same. + + To compute the coefficients, no determinant evaluation takes place. + Instead, we first compute the euclidean sequence of p and q using + euclid_amv(p, q, x) and then: (a) change the signs of the remainders in the + Euclidean sequence according to the pattern "-, -, +, +, -, -, +, +,..." + (see Lemma 1 in the 1st reference or Theorem 3 in the 2nd reference) + and (b) if method=0, assuming deg(p) > deg(q), we multiply the remainder + coefficients of the Euclidean sequence times the factor + Abs( LC(p)**( deg(p)- deg(q)) ) to make them modified subresultants. + See also the function sturm_pg(p, q, x). + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On the Remainders + Obtained in Finding the Greatest Common Divisor of Two Polynomials.'' Serdica + Journal of Computing 9(2) (2015), 123-138. + + 3. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Subresultant Polynomial + Remainder Sequences Obtained by Polynomial Divisions in Q[x] or in Z[x].'' + Serdica Journal of Computing 10 (2016), No.3-4, 197-217. + + """ + # compute the euclidean sequence + prs = euclid_amv(p, q, x) + + # defensive + if prs == [] or len(prs) == 2: + return prs + + # the coefficients in prs are subresultants and hence are smaller + # than the corresponding subresultants by the factor + # Abs( LC(prs[0])**( deg(prs[0]) - deg(prs[1])) ); Theorem 2, 2nd reference. + lcf = Abs( LC(prs[0])**( degree(prs[0], x) - degree(prs[1], x) ) ) + + # the signs of the first two polys in the sequence stay the same + sturm_seq = [prs[0], prs[1]] + + # change the signs according to "-, -, +, +, -, -, +, +,..." + # and multiply times lcf if needed + flag = 0 + m = len(prs) + i = 2 + while i <= m-1: + if flag == 0: + sturm_seq.append( - prs[i] ) + i = i + 1 + if i == m: + break + sturm_seq.append( - prs[i] ) + i = i + 1 + flag = 1 + elif flag == 1: + sturm_seq.append( prs[i] ) + i = i + 1 + if i == m: + break + sturm_seq.append( prs[i] ) + i = i + 1 + flag = 0 + + # subresultants or modified subresultants? + if method == 0 and lcf > 1: + aux_seq = [sturm_seq[0], sturm_seq[1]] + for i in range(2, m): + aux_seq.append(simplify(sturm_seq[i] * lcf )) + sturm_seq = aux_seq + + return sturm_seq + +def euclid_pg(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the Euclidean sequence of p and q in Z[x] or Q[x]. + + If the Euclidean sequence is complete the coefficients of the polynomials + in the sequence are subresultants. That is, they are determinants of + appropriately selected submatrices of sylvester1, Sylvester's matrix of 1840. + In this case the Euclidean sequence coincides with the subresultant prs + of the polynomials p, q. + + If the Euclidean sequence is incomplete the signs of the coefficients of the + polynomials in the sequence may differ from the signs of the coefficients of + the corresponding polynomials in the subresultant prs; however, the absolute + values are the same. + + To compute the Euclidean sequence, no determinant evaluation takes place. + We first compute the (generalized) Sturm sequence of p and q using + sturm_pg(p, q, x, 1), in which case the coefficients are (in absolute value) + equal to subresultants. Then we change the signs of the remainders in the + Sturm sequence according to the pattern "-, -, +, +, -, -, +, +,..." ; + see Lemma 1 in the 1st reference or Theorem 3 in the 2nd reference as well as + the function sturm_pg(p, q, x). + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On the Remainders + Obtained in Finding the Greatest Common Divisor of Two Polynomials.'' Serdica + Journal of Computing 9(2) (2015), 123-138. + + 3. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Subresultant Polynomial + Remainder Sequences Obtained by Polynomial Divisions in Q[x] or in Z[x].'' + Serdica Journal of Computing 10 (2016), No.3-4, 197-217. + """ + # compute the sturmian sequence using the Pell-Gordon (or AMV) theorem + # with the coefficients in the prs being (in absolute value) subresultants + prs = sturm_pg(p, q, x, 1) ## any other method would do + + # defensive + if prs == [] or len(prs) == 2: + return prs + + # the signs of the first two polys in the sequence stay the same + euclid_seq = [prs[0], prs[1]] + + # change the signs according to "-, -, +, +, -, -, +, +,..." + flag = 0 + m = len(prs) + i = 2 + while i <= m-1: + if flag == 0: + euclid_seq.append(- prs[i] ) + i = i + 1 + if i == m: + break + euclid_seq.append(- prs[i] ) + i = i + 1 + flag = 1 + elif flag == 1: + euclid_seq.append(prs[i] ) + i = i + 1 + if i == m: + break + euclid_seq.append(prs[i] ) + i = i + 1 + flag = 0 + + return euclid_seq + +def euclid_q(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the Euclidean sequence of p and q in Q[x]. + Polynomial divisions in Q[x] are performed, using the function rem(p, q, x). + + The coefficients of the polynomials in the Euclidean sequence can be uniquely + determined from the corresponding coefficients of the polynomials found + either in: + + (a) the ``modified'' subresultant polynomial remainder sequence, + (references 1, 2) + + or in + + (b) the subresultant polynomial remainder sequence (references 3). + + References + ========== + 1. Pell A. J., R. L. Gordon. The Modified Remainders Obtained in Finding + the Highest Common Factor of Two Polynomials. Annals of MatheMatics, + Second Series, 18 (1917), No. 4, 188-193. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + 3. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + d0 = degree(p, x) + d1 = degree(q, x) + if d0 == 0 and d1 == 0: + return [p, q] + if d1 > d0: + d0, d1 = d1, d0 + p, q = q, p + if d0 > 0 and d1 == 0: + return [p,q] + + # make sure LC(p) > 0 + flag = 0 + if LC(p,x) < 0: + flag = 1 + p = -p + q = -q + + # initialize + a0, a1 = p, q # the input polys + euclid_seq = [a0, a1] # the output list + a2 = rem(a0, a1, domain=QQ) # first remainder + d2 = degree(a2, x) # degree of a2 + euclid_seq.append( a2 ) + + # main loop + while d2 > 0: + a0, a1, d0, d1 = a1, a2, d1, d2 # update polys and degrees + a2 = rem(a0, a1, domain=QQ) # new remainder + d2 = degree(a2, x) # actual degree of a2 + euclid_seq.append( a2 ) + + if flag: # change the sign of the sequence + euclid_seq = [-i for i in euclid_seq] + + # gcd is of degree > 0 ? + m = len(euclid_seq) + if euclid_seq[m - 1] == nan or euclid_seq[m - 1] == 0: + euclid_seq.pop(m - 1) + + return euclid_seq + +def euclid_amv(f, g, x): + """ + f, g are polynomials in Z[x] or Q[x]. It is assumed + that degree(f, x) >= degree(g, x). + + Computes the Euclidean sequence of p and q in Z[x] or Q[x]. + + If the Euclidean sequence is complete the coefficients of the polynomials + in the sequence are subresultants. That is, they are determinants of + appropriately selected submatrices of sylvester1, Sylvester's matrix of 1840. + In this case the Euclidean sequence coincides with the subresultant prs, + of the polynomials p, q. + + If the Euclidean sequence is incomplete the signs of the coefficients of the + polynomials in the sequence may differ from the signs of the coefficients of + the corresponding polynomials in the subresultant prs; however, the absolute + values are the same. + + To compute the coefficients, no determinant evaluation takes place. + Instead, polynomial divisions in Z[x] or Q[x] are performed, using + the function rem_z(f, g, x); the coefficients of the remainders + computed this way become subresultants with the help of the + Collins-Brown-Traub formula for coefficient reduction. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Subresultant Polynomial + remainder Sequences Obtained by Polynomial Divisions in Q[x] or in Z[x].'' + Serdica Journal of Computing 10 (2016), No.3-4, 197-217. + + """ + # make sure neither f nor g is 0 + if f == 0 or g == 0: + return [f, g] + + # make sure proper degrees + d0 = degree(f, x) + d1 = degree(g, x) + if d0 == 0 and d1 == 0: + return [f, g] + if d1 > d0: + d0, d1 = d1, d0 + f, g = g, f + if d0 > 0 and d1 == 0: + return [f, g] + + # initialize + a0 = f + a1 = g + euclid_seq = [a0, a1] + deg_dif_p1, c = degree(a0, x) - degree(a1, x) + 1, -1 + + # compute the first polynomial of the prs + i = 1 + a2 = rem_z(a0, a1, x) / Abs( (-1)**deg_dif_p1 ) # first remainder + euclid_seq.append( a2 ) + d2 = degree(a2, x) # actual degree of a2 + + # main loop + while d2 >= 1: + a0, a1, d0, d1 = a1, a2, d1, d2 # update polys and degrees + i += 1 + sigma0 = -LC(a0) + c = (sigma0**(deg_dif_p1 - 1)) / (c**(deg_dif_p1 - 2)) + deg_dif_p1 = degree(a0, x) - d2 + 1 + a2 = rem_z(a0, a1, x) / Abs( (c**(deg_dif_p1 - 1)) * sigma0 ) + euclid_seq.append( a2 ) + d2 = degree(a2, x) # actual degree of a2 + + # gcd is of degree > 0 ? + m = len(euclid_seq) + if euclid_seq[m - 1] == nan or euclid_seq[m - 1] == 0: + euclid_seq.pop(m - 1) + + return euclid_seq + +def modified_subresultants_pg(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the ``modified'' subresultant prs of p and q in Z[x] or Q[x]; + the coefficients of the polynomials in the sequence are + ``modified'' subresultants. That is, they are determinants of appropriately + selected submatrices of sylvester2, Sylvester's matrix of 1853. + + To compute the coefficients, no determinant evaluation takes place. Instead, + polynomial divisions in Q[x] are performed, using the function rem(p, q, x); + the coefficients of the remainders computed this way become ``modified'' + subresultants with the help of the Pell-Gordon Theorem of 1917. + + If the ``modified'' subresultant prs is complete, and LC( p ) > 0, it coincides + with the (generalized) Sturm sequence of the polynomials p, q. + + References + ========== + 1. Pell A. J., R. L. Gordon. The Modified Remainders Obtained in Finding + the Highest Common Factor of Two Polynomials. Annals of MatheMatics, + Second Series, 18 (1917), No. 4, 188-193. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + d0 = degree(p,x) + d1 = degree(q,x) + if d0 == 0 and d1 == 0: + return [p, q] + if d1 > d0: + d0, d1 = d1, d0 + p, q = q, p + if d0 > 0 and d1 == 0: + return [p,q] + + # initialize + k = var('k') # index in summation formula + u_list = [] # of elements (-1)**u_i + subres_l = [p, q] # mod. subr. prs output list + a0, a1 = p, q # the input polys + del0 = d0 - d1 # degree difference + degdif = del0 # save it + rho_1 = LC(a0) # lead. coeff (a0) + + # Initialize Pell-Gordon variables + rho_list_minus_1 = sign( LC(a0, x)) # sign of LC(a0) + rho1 = LC(a1, x) # leading coeff of a1 + rho_list = [ sign(rho1)] # of signs + p_list = [del0] # of degree differences + u = summation(k, (k, 1, p_list[0])) # value of u + u_list.append(u) # of u values + v = sum(p_list) # v value + + # first remainder + exp_deg = d1 - 1 # expected degree of a2 + a2 = - rem(a0, a1, domain=QQ) # first remainder + rho2 = LC(a2, x) # leading coeff of a2 + d2 = degree(a2, x) # actual degree of a2 + deg_diff_new = exp_deg - d2 # expected - actual degree + del1 = d1 - d2 # degree difference + + # mul_fac is the factor by which a2 is multiplied to + # get integer coefficients + mul_fac_old = rho1**(del0 + del1 - deg_diff_new) + + # update Pell-Gordon variables + p_list.append(1 + deg_diff_new) # deg_diff_new is 0 for complete seq + + # apply Pell-Gordon formula (7) in second reference + num = 1 # numerator of fraction + for u in u_list: + num *= (-1)**u + num = num * (-1)**v + + # denominator depends on complete / incomplete seq + if deg_diff_new == 0: # complete seq + den = 1 + for k in range(len(rho_list)): + den *= rho_list[k]**(p_list[k] + p_list[k + 1]) + den = den * rho_list_minus_1 + else: # incomplete seq + den = 1 + for k in range(len(rho_list)-1): + den *= rho_list[k]**(p_list[k] + p_list[k + 1]) + den = den * rho_list_minus_1 + expo = (p_list[len(rho_list) - 1] + p_list[len(rho_list)] - deg_diff_new) + den = den * rho_list[len(rho_list) - 1]**expo + + # the sign of the determinant depends on sg(num / den) + if sign(num / den) > 0: + subres_l.append( simplify(rho_1**degdif*a2* Abs(mul_fac_old) ) ) + else: + subres_l.append(- simplify(rho_1**degdif*a2* Abs(mul_fac_old) ) ) + + # update Pell-Gordon variables + k = var('k') + rho_list.append( sign(rho2)) + u = summation(k, (k, 1, p_list[len(p_list) - 1])) + u_list.append(u) + v = sum(p_list) + deg_diff_old=deg_diff_new + + # main loop + while d2 > 0: + a0, a1, d0, d1 = a1, a2, d1, d2 # update polys and degrees + del0 = del1 # update degree difference + exp_deg = d1 - 1 # new expected degree + a2 = - rem(a0, a1, domain=QQ) # new remainder + rho3 = LC(a2, x) # leading coeff of a2 + d2 = degree(a2, x) # actual degree of a2 + deg_diff_new = exp_deg - d2 # expected - actual degree + del1 = d1 - d2 # degree difference + + # take into consideration the power + # rho1**deg_diff_old that was "left out" + expo_old = deg_diff_old # rho1 raised to this power + expo_new = del0 + del1 - deg_diff_new # rho2 raised to this power + + mul_fac_new = rho2**(expo_new) * rho1**(expo_old) * mul_fac_old + + # update variables + deg_diff_old, mul_fac_old = deg_diff_new, mul_fac_new + rho1, rho2 = rho2, rho3 + + # update Pell-Gordon variables + p_list.append(1 + deg_diff_new) # deg_diff_new is 0 for complete seq + + # apply Pell-Gordon formula (7) in second reference + num = 1 # numerator + for u in u_list: + num *= (-1)**u + num = num * (-1)**v + + # denominator depends on complete / incomplete seq + if deg_diff_new == 0: # complete seq + den = 1 + for k in range(len(rho_list)): + den *= rho_list[k]**(p_list[k] + p_list[k + 1]) + den = den * rho_list_minus_1 + else: # incomplete seq + den = 1 + for k in range(len(rho_list)-1): + den *= rho_list[k]**(p_list[k] + p_list[k + 1]) + den = den * rho_list_minus_1 + expo = (p_list[len(rho_list) - 1] + p_list[len(rho_list)] - deg_diff_new) + den = den * rho_list[len(rho_list) - 1]**expo + + # the sign of the determinant depends on sg(num / den) + if sign(num / den) > 0: + subres_l.append( simplify(rho_1**degdif*a2* Abs(mul_fac_old) ) ) + else: + subres_l.append(- simplify(rho_1**degdif*a2* Abs(mul_fac_old) ) ) + + # update Pell-Gordon variables + k = var('k') + rho_list.append( sign(rho2)) + u = summation(k, (k, 1, p_list[len(p_list) - 1])) + u_list.append(u) + v = sum(p_list) + + # gcd is of degree > 0 ? + m = len(subres_l) + if subres_l[m - 1] == nan or subres_l[m - 1] == 0: + subres_l.pop(m - 1) + + # LC( p ) < 0 + m = len(subres_l) # list may be shorter now due to deg(gcd ) > 0 + if LC( p ) < 0: + aux_seq = [subres_l[0], subres_l[1]] + for i in range(2, m): + aux_seq.append(simplify(subres_l[i] * (-1) )) + subres_l = aux_seq + + return subres_l + +def subresultants_pg(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the subresultant prs of p and q in Z[x] or Q[x], from + the modified subresultant prs of p and q. + + The coefficients of the polynomials in these two sequences differ only + in sign and the factor LC(p)**( deg(p)- deg(q)) as stated in + Theorem 2 of the reference. + + The coefficients of the polynomials in the output sequence are + subresultants. That is, they are determinants of appropriately + selected submatrices of sylvester1, Sylvester's matrix of 1840. + + If the subresultant prs is complete, then it coincides with the + Euclidean sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: "On the Remainders + Obtained in Finding the Greatest Common Divisor of Two Polynomials." + Serdica Journal of Computing 9(2) (2015), 123-138. + + """ + # compute the modified subresultant prs + lst = modified_subresultants_pg(p,q,x) ## any other method would do + + # defensive + if lst == [] or len(lst) == 2: + return lst + + # the coefficients in lst are modified subresultants and, hence, are + # greater than those of the corresponding subresultants by the factor + # LC(lst[0])**( deg(lst[0]) - deg(lst[1])); see Theorem 2 in reference. + lcf = LC(lst[0])**( degree(lst[0], x) - degree(lst[1], x) ) + + # Initialize the subresultant prs list + subr_seq = [lst[0], lst[1]] + + # compute the degree sequences m_i and j_i of Theorem 2 in reference. + deg_seq = [degree(Poly(poly, x), x) for poly in lst] + deg = deg_seq[0] + deg_seq_s = deg_seq[1:-1] + m_seq = [m-1 for m in deg_seq_s] + j_seq = [deg - m for m in m_seq] + + # compute the AMV factors of Theorem 2 in reference. + fact = [(-1)**( j*(j-1)/S(2) ) for j in j_seq] + + # shortened list without the first two polys + lst_s = lst[2:] + + # poly lst_s[k] is multiplied times fact[k], divided by lcf + # and appended to the subresultant prs list + m = len(fact) + for k in range(m): + if sign(fact[k]) == -1: + subr_seq.append(-lst_s[k] / lcf) + else: + subr_seq.append(lst_s[k] / lcf) + + return subr_seq + +def subresultants_amv_q(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the subresultant prs of p and q in Q[x]; + the coefficients of the polynomials in the sequence are + subresultants. That is, they are determinants of appropriately + selected submatrices of sylvester1, Sylvester's matrix of 1840. + + To compute the coefficients, no determinant evaluation takes place. + Instead, polynomial divisions in Q[x] are performed, using the + function rem(p, q, x); the coefficients of the remainders + computed this way become subresultants with the help of the + Akritas-Malaschonok-Vigklas Theorem of 2015. + + If the subresultant prs is complete, then it coincides with the + Euclidean sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Subresultant Polynomial + remainder Sequences Obtained by Polynomial Divisions in Q[x] or in Z[x].'' + Serdica Journal of Computing 10 (2016), No.3-4, 197-217. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + d0 = degree(p, x) + d1 = degree(q, x) + if d0 == 0 and d1 == 0: + return [p, q] + if d1 > d0: + d0, d1 = d1, d0 + p, q = q, p + if d0 > 0 and d1 == 0: + return [p, q] + + # initialize + i, s = 0, 0 # counters for remainders & odd elements + p_odd_index_sum = 0 # contains the sum of p_1, p_3, etc + subres_l = [p, q] # subresultant prs output list + a0, a1 = p, q # the input polys + sigma1 = LC(a1, x) # leading coeff of a1 + p0 = d0 - d1 # degree difference + if p0 % 2 == 1: + s += 1 + phi = floor( (s + 1) / 2 ) + mul_fac = 1 + d2 = d1 + + # main loop + while d2 > 0: + i += 1 + a2 = rem(a0, a1, domain= QQ) # new remainder + if i == 1: + sigma2 = LC(a2, x) + else: + sigma3 = LC(a2, x) + sigma1, sigma2 = sigma2, sigma3 + d2 = degree(a2, x) + p1 = d1 - d2 + psi = i + phi + p_odd_index_sum + + # new mul_fac + mul_fac = sigma1**(p0 + 1) * mul_fac + + ## compute the sign of the first fraction in formula (9) of the paper + # numerator + num = (-1)**psi + # denominator + den = sign(mul_fac) + + # the sign of the determinant depends on sign( num / den ) != 0 + if sign(num / den) > 0: + subres_l.append( simplify(expand(a2* Abs(mul_fac)))) + else: + subres_l.append(- simplify(expand(a2* Abs(mul_fac)))) + + ## bring into mul_fac the missing power of sigma if there was a degree gap + if p1 - 1 > 0: + mul_fac = mul_fac * sigma1**(p1 - 1) + + # update AMV variables + a0, a1, d0, d1 = a1, a2, d1, d2 + p0 = p1 + if p0 % 2 ==1: + s += 1 + phi = floor( (s + 1) / 2 ) + if i%2 == 1: + p_odd_index_sum += p0 # p_i has odd index + + # gcd is of degree > 0 ? + m = len(subres_l) + if subres_l[m - 1] == nan or subres_l[m - 1] == 0: + subres_l.pop(m - 1) + + return subres_l + +def compute_sign(base, expo): + ''' + base != 0 and expo >= 0 are integers; + + returns the sign of base**expo without + evaluating the power itself! + ''' + sb = sign(base) + if sb == 1: + return 1 + pe = expo % 2 + if pe == 0: + return -sb + else: + return sb + +def rem_z(p, q, x): + ''' + Intended mainly for p, q polynomials in Z[x] so that, + on dividing p by q, the remainder will also be in Z[x]. (However, + it also works fine for polynomials in Q[x].) It is assumed + that degree(p, x) >= degree(q, x). + + It premultiplies p by the _absolute_ value of the leading coefficient + of q, raised to the power deg(p) - deg(q) + 1 and then performs + polynomial division in Q[x], using the function rem(p, q, x). + + By contrast the function prem(p, q, x) does _not_ use the absolute + value of the leading coefficient of q. + This results not only in ``messing up the signs'' of the Euclidean and + Sturmian prs's as mentioned in the second reference, + but also in violation of the main results of the first and third + references --- Theorem 4 and Theorem 1 respectively. Theorems 4 and 1 + establish a one-to-one correspondence between the Euclidean and the + Sturmian prs of p, q, on one hand, and the subresultant prs of p, q, + on the other. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On the Remainders + Obtained in Finding the Greatest Common Divisor of Two Polynomials.'' + Serdica Journal of Computing, 9(2) (2015), 123-138. + + 2. https://planetMath.org/sturmstheorem + + 3. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result on + the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + ''' + if (p.as_poly().is_univariate and q.as_poly().is_univariate and + p.as_poly().gens == q.as_poly().gens): + delta = (degree(p, x) - degree(q, x) + 1) + return rem(Abs(LC(q, x))**delta * p, q, x) + else: + return prem(p, q, x) + +def quo_z(p, q, x): + """ + Intended mainly for p, q polynomials in Z[x] so that, + on dividing p by q, the quotient will also be in Z[x]. (However, + it also works fine for polynomials in Q[x].) It is assumed + that degree(p, x) >= degree(q, x). + + It premultiplies p by the _absolute_ value of the leading coefficient + of q, raised to the power deg(p) - deg(q) + 1 and then performs + polynomial division in Q[x], using the function quo(p, q, x). + + By contrast the function pquo(p, q, x) does _not_ use the absolute + value of the leading coefficient of q. + + See also function rem_z(p, q, x) for additional comments and references. + + """ + if (p.as_poly().is_univariate and q.as_poly().is_univariate and + p.as_poly().gens == q.as_poly().gens): + delta = (degree(p, x) - degree(q, x) + 1) + return quo(Abs(LC(q, x))**delta * p, q, x) + else: + return pquo(p, q, x) + +def subresultants_amv(f, g, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(f, x) >= degree(g, x). + + Computes the subresultant prs of p and q in Z[x] or Q[x]; + the coefficients of the polynomials in the sequence are + subresultants. That is, they are determinants of appropriately + selected submatrices of sylvester1, Sylvester's matrix of 1840. + + To compute the coefficients, no determinant evaluation takes place. + Instead, polynomial divisions in Z[x] or Q[x] are performed, using + the function rem_z(p, q, x); the coefficients of the remainders + computed this way become subresultants with the help of the + Akritas-Malaschonok-Vigklas Theorem of 2015 and the Collins-Brown- + Traub formula for coefficient reduction. + + If the subresultant prs is complete, then it coincides with the + Euclidean sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``A Basic Result + on the Theory of Subresultants.'' Serdica Journal of Computing 10 (2016), No.1, 31-48. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Subresultant Polynomial + remainder Sequences Obtained by Polynomial Divisions in Q[x] or in Z[x].'' + Serdica Journal of Computing 10 (2016), No.3-4, 197-217. + + """ + # make sure neither f nor g is 0 + if f == 0 or g == 0: + return [f, g] + + # make sure proper degrees + d0 = degree(f, x) + d1 = degree(g, x) + if d0 == 0 and d1 == 0: + return [f, g] + if d1 > d0: + d0, d1 = d1, d0 + f, g = g, f + if d0 > 0 and d1 == 0: + return [f, g] + + # initialize + a0 = f + a1 = g + subres_l = [a0, a1] + deg_dif_p1, c = degree(a0, x) - degree(a1, x) + 1, -1 + + # initialize AMV variables + sigma1 = LC(a1, x) # leading coeff of a1 + i, s = 0, 0 # counters for remainders & odd elements + p_odd_index_sum = 0 # contains the sum of p_1, p_3, etc + p0 = deg_dif_p1 - 1 + if p0 % 2 == 1: + s += 1 + phi = floor( (s + 1) / 2 ) + + # compute the first polynomial of the prs + i += 1 + a2 = rem_z(a0, a1, x) / Abs( (-1)**deg_dif_p1 ) # first remainder + sigma2 = LC(a2, x) # leading coeff of a2 + d2 = degree(a2, x) # actual degree of a2 + p1 = d1 - d2 # degree difference + + # sgn_den is the factor, the denominator 1st fraction of (9), + # by which a2 is multiplied to get integer coefficients + sgn_den = compute_sign( sigma1, p0 + 1 ) + + ## compute sign of the 1st fraction in formula (9) of the paper + # numerator + psi = i + phi + p_odd_index_sum + num = (-1)**psi + # denominator + den = sgn_den + + # the sign of the determinant depends on sign(num / den) != 0 + if sign(num / den) > 0: + subres_l.append( a2 ) + else: + subres_l.append( -a2 ) + + # update AMV variable + if p1 % 2 == 1: + s += 1 + + # bring in the missing power of sigma if there was gap + if p1 - 1 > 0: + sgn_den = sgn_den * compute_sign( sigma1, p1 - 1 ) + + # main loop + while d2 >= 1: + phi = floor( (s + 1) / 2 ) + if i%2 == 1: + p_odd_index_sum += p1 # p_i has odd index + a0, a1, d0, d1 = a1, a2, d1, d2 # update polys and degrees + p0 = p1 # update degree difference + i += 1 + sigma0 = -LC(a0) + c = (sigma0**(deg_dif_p1 - 1)) / (c**(deg_dif_p1 - 2)) + deg_dif_p1 = degree(a0, x) - d2 + 1 + a2 = rem_z(a0, a1, x) / Abs( (c**(deg_dif_p1 - 1)) * sigma0 ) + sigma3 = LC(a2, x) # leading coeff of a2 + d2 = degree(a2, x) # actual degree of a2 + p1 = d1 - d2 # degree difference + psi = i + phi + p_odd_index_sum + + # update variables + sigma1, sigma2 = sigma2, sigma3 + + # new sgn_den + sgn_den = compute_sign( sigma1, p0 + 1 ) * sgn_den + + # compute the sign of the first fraction in formula (9) of the paper + # numerator + num = (-1)**psi + # denominator + den = sgn_den + + # the sign of the determinant depends on sign( num / den ) != 0 + if sign(num / den) > 0: + subres_l.append( a2 ) + else: + subres_l.append( -a2 ) + + # update AMV variable + if p1 % 2 ==1: + s += 1 + + # bring in the missing power of sigma if there was gap + if p1 - 1 > 0: + sgn_den = sgn_den * compute_sign( sigma1, p1 - 1 ) + + # gcd is of degree > 0 ? + m = len(subres_l) + if subres_l[m - 1] == nan or subres_l[m - 1] == 0: + subres_l.pop(m - 1) + + return subres_l + +def modified_subresultants_amv(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the modified subresultant prs of p and q in Z[x] or Q[x], + from the subresultant prs of p and q. + The coefficients of the polynomials in the two sequences differ only + in sign and the factor LC(p)**( deg(p)- deg(q)) as stated in + Theorem 2 of the reference. + + The coefficients of the polynomials in the output sequence are + modified subresultants. That is, they are determinants of appropriately + selected submatrices of sylvester2, Sylvester's matrix of 1853. + + If the modified subresultant prs is complete, and LC( p ) > 0, it coincides + with the (generalized) Sturm's sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: "On the Remainders + Obtained in Finding the Greatest Common Divisor of Two Polynomials." + Serdica Journal of Computing, Serdica Journal of Computing, 9(2) (2015), 123-138. + + """ + # compute the subresultant prs + lst = subresultants_amv(p,q,x) ## any other method would do + + # defensive + if lst == [] or len(lst) == 2: + return lst + + # the coefficients in lst are subresultants and, hence, smaller than those + # of the corresponding modified subresultants by the factor + # LC(lst[0])**( deg(lst[0]) - deg(lst[1])); see Theorem 2. + lcf = LC(lst[0])**( degree(lst[0], x) - degree(lst[1], x) ) + + # Initialize the modified subresultant prs list + subr_seq = [lst[0], lst[1]] + + # compute the degree sequences m_i and j_i of Theorem 2 + deg_seq = [degree(Poly(poly, x), x) for poly in lst] + deg = deg_seq[0] + deg_seq_s = deg_seq[1:-1] + m_seq = [m-1 for m in deg_seq_s] + j_seq = [deg - m for m in m_seq] + + # compute the AMV factors of Theorem 2 + fact = [(-1)**( j*(j-1)/S(2) ) for j in j_seq] + + # shortened list without the first two polys + lst_s = lst[2:] + + # poly lst_s[k] is multiplied times fact[k] and times lcf + # and appended to the subresultant prs list + m = len(fact) + for k in range(m): + if sign(fact[k]) == -1: + subr_seq.append( simplify(-lst_s[k] * lcf) ) + else: + subr_seq.append( simplify(lst_s[k] * lcf) ) + + return subr_seq + +def correct_sign(deg_f, deg_g, s1, rdel, cdel): + """ + Used in various subresultant prs algorithms. + + Evaluates the determinant, (a.k.a. subresultant) of a properly selected + submatrix of s1, Sylvester's matrix of 1840, to get the correct sign + and value of the leading coefficient of a given polynomial remainder. + + deg_f, deg_g are the degrees of the original polynomials p, q for which the + matrix s1 = sylvester(p, q, x, 1) was constructed. + + rdel denotes the expected degree of the remainder; it is the number of + rows to be deleted from each group of rows in s1 as described in the + reference below. + + cdel denotes the expected degree minus the actual degree of the remainder; + it is the number of columns to be deleted --- starting with the last column + forming the square matrix --- from the matrix resulting after the row deletions. + + References + ========== + Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``Sturm Sequences + and Modified Subresultant Polynomial Remainder Sequences.'' + Serdica Journal of Computing, Vol. 8, No 1, 29-46, 2014. + + """ + M = s1[:, :] # copy of matrix s1 + + # eliminate rdel rows from the first deg_g rows + for i in range(M.rows - deg_f - 1, M.rows - deg_f - rdel - 1, -1): + M.row_del(i) + + # eliminate rdel rows from the last deg_f rows + for i in range(M.rows - 1, M.rows - rdel - 1, -1): + M.row_del(i) + + # eliminate cdel columns + for i in range(cdel): + M.col_del(M.rows - 1) + + # define submatrix + Md = M[:, 0: M.rows] + + return Md.det() + +def subresultants_rem(p, q, x): + """ + p, q are polynomials in Z[x] or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the subresultant prs of p and q in Z[x] or Q[x]; + the coefficients of the polynomials in the sequence are + subresultants. That is, they are determinants of appropriately + selected submatrices of sylvester1, Sylvester's matrix of 1840. + + To compute the coefficients polynomial divisions in Q[x] are + performed, using the function rem(p, q, x). The coefficients + of the remainders computed this way become subresultants by evaluating + one subresultant per remainder --- that of the leading coefficient. + This way we obtain the correct sign and value of the leading coefficient + of the remainder and we easily ``force'' the rest of the coefficients + to become subresultants. + + If the subresultant prs is complete, then it coincides with the + Euclidean sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G.:``Three New Methods for Computing Subresultant + Polynomial Remainder Sequences (PRS's).'' Serdica Journal of Computing 9(1) (2015), 1-26. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + f, g = p, q + n = deg_f = degree(f, x) + m = deg_g = degree(g, x) + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, deg_f, deg_g, f, g = m, n, deg_g, deg_f, g, f + if n > 0 and m == 0: + return [f, g] + + # initialize + s1 = sylvester(f, g, x, 1) + sr_list = [f, g] # subresultant list + + # main loop + while deg_g > 0: + r = rem(p, q, x) + d = degree(r, x) + if d < 0: + return sr_list + + # make coefficients subresultants evaluating ONE determinant + exp_deg = deg_g - 1 # expected degree + sign_value = correct_sign(n, m, s1, exp_deg, exp_deg - d) + r = simplify((r / LC(r, x)) * sign_value) + + # append poly with subresultant coeffs + sr_list.append(r) + + # update degrees and polys + deg_f, deg_g = deg_g, d + p, q = q, r + + # gcd is of degree > 0 ? + m = len(sr_list) + if sr_list[m - 1] == nan or sr_list[m - 1] == 0: + sr_list.pop(m - 1) + + return sr_list + +def pivot(M, i, j): + ''' + M is a matrix, and M[i, j] specifies the pivot element. + + All elements below M[i, j], in the j-th column, will + be zeroed, if they are not already 0, according to + Dodgson-Bareiss' integer preserving transformations. + + References + ========== + 1. Akritas, A. G.: ``A new method for computing polynomial greatest + common divisors and polynomial remainder sequences.'' + Numerische MatheMatik 52, 119-127, 1988. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On a Theorem + by Van Vleck Regarding Sturm Sequences.'' + Serdica Journal of Computing, 7, No 4, 101-134, 2013. + + ''' + ma = M[:, :] # copy of matrix M + rs = ma.rows # No. of rows + cs = ma.cols # No. of cols + for r in range(i+1, rs): + if ma[r, j] != 0: + for c in range(j + 1, cs): + ma[r, c] = ma[i, j] * ma[r, c] - ma[i, c] * ma[r, j] + ma[r, j] = 0 + return ma + +def rotate_r(L, k): + ''' + Rotates right by k. L is a row of a matrix or a list. + + ''' + ll = list(L) + if ll == []: + return [] + for i in range(k): + el = ll.pop(len(ll) - 1) + ll.insert(0, el) + return ll if isinstance(L, list) else Matrix([ll]) + +def rotate_l(L, k): + ''' + Rotates left by k. L is a row of a matrix or a list. + + ''' + ll = list(L) + if ll == []: + return [] + for i in range(k): + el = ll.pop(0) + ll.insert(len(ll) - 1, el) + return ll if isinstance(L, list) else Matrix([ll]) + +def row2poly(row, deg, x): + ''' + Converts the row of a matrix to a poly of degree deg and variable x. + Some entries at the beginning and/or at the end of the row may be zero. + + ''' + k = 0 + poly = [] + leng = len(row) + + # find the beginning of the poly ; i.e. the first + # non-zero element of the row + while row[k] == 0: + k = k + 1 + + # append the next deg + 1 elements to poly + for j in range( deg + 1): + if k + j <= leng: + poly.append(row[k + j]) + + return Poly(poly, x) + +def create_ma(deg_f, deg_g, row1, row2, col_num): + ''' + Creates a ``small'' matrix M to be triangularized. + + deg_f, deg_g are the degrees of the divident and of the + divisor polynomials respectively, deg_g > deg_f. + + The coefficients of the divident poly are the elements + in row2 and those of the divisor poly are the elements + in row1. + + col_num defines the number of columns of the matrix M. + + ''' + if deg_g - deg_f >= 1: + print('Reverse degrees') + return + + m = zeros(deg_f - deg_g + 2, col_num) + + for i in range(deg_f - deg_g + 1): + m[i, :] = rotate_r(row1, i) + m[deg_f - deg_g + 1, :] = row2 + + return m + +def find_degree(M, deg_f): + ''' + Finds the degree of the poly corresponding (after triangularization) + to the _last_ row of the ``small'' matrix M, created by create_ma(). + + deg_f is the degree of the divident poly. + If _last_ row is all 0's returns None. + + ''' + j = deg_f + for i in range(0, M.cols): + if M[M.rows - 1, i] == 0: + j = j - 1 + else: + return max(j, 0) + +def final_touches(s2, r, deg_g): + """ + s2 is sylvester2, r is the row pointer in s2, + deg_g is the degree of the poly last inserted in s2. + + After a gcd of degree > 0 has been found with Van Vleck's + method, and was inserted into s2, if its last term is not + in the last column of s2, then it is inserted as many + times as needed, rotated right by one each time, until + the condition is met. + + """ + R = s2.row(r-1) + + # find the first non zero term + for i in range(s2.cols): + if R[0,i] == 0: + continue + else: + break + + # missing rows until last term is in last column + mr = s2.cols - (i + deg_g + 1) + + # insert them by replacing the existing entries in the row + i = 0 + while mr != 0 and r + i < s2.rows : + s2[r + i, : ] = rotate_r(R, i + 1) + i += 1 + mr -= 1 + + return s2 + +def subresultants_vv(p, q, x, method = 0): + """ + p, q are polynomials in Z[x] (intended) or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the subresultant prs of p, q by triangularizing, + in Z[x] or in Q[x], all the smaller matrices encountered in the + process of triangularizing sylvester2, Sylvester's matrix of 1853; + see references 1 and 2 for Van Vleck's method. With each remainder, + sylvester2 gets updated and is prepared to be printed if requested. + + If sylvester2 has small dimensions and you want to see the final, + triangularized matrix use this version with method=1; otherwise, + use either this version with method=0 (default) or the faster version, + subresultants_vv_2(p, q, x), where sylvester2 is used implicitly. + + Sylvester's matrix sylvester1 is also used to compute one + subresultant per remainder; namely, that of the leading + coefficient, in order to obtain the correct sign and to + force the remainder coefficients to become subresultants. + + If the subresultant prs is complete, then it coincides with the + Euclidean sequence of the polynomials p, q. + + If the final, triangularized matrix s2 is printed, then: + (a) if deg(p) - deg(q) > 1 or deg( gcd(p, q) ) > 0, several + of the last rows in s2 will remain unprocessed; + (b) if deg(p) - deg(q) == 0, p will not appear in the final matrix. + + References + ========== + 1. Akritas, A. G.: ``A new method for computing polynomial greatest + common divisors and polynomial remainder sequences.'' + Numerische MatheMatik 52, 119-127, 1988. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On a Theorem + by Van Vleck Regarding Sturm Sequences.'' + Serdica Journal of Computing, 7, No 4, 101-134, 2013. + + 3. Akritas, A. G.:``Three New Methods for Computing Subresultant + Polynomial Remainder Sequences (PRS's).'' Serdica Journal of Computing 9(1) (2015), 1-26. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + f, g = p, q + n = deg_f = degree(f, x) + m = deg_g = degree(g, x) + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, deg_f, deg_g, f, g = m, n, deg_g, deg_f, g, f + if n > 0 and m == 0: + return [f, g] + + # initialize + s1 = sylvester(f, g, x, 1) + s2 = sylvester(f, g, x, 2) + sr_list = [f, g] + col_num = 2 * n # columns in s2 + + # make two rows (row0, row1) of poly coefficients + row0 = Poly(f, x, domain = QQ).all_coeffs() + leng0 = len(row0) + for i in range(col_num - leng0): + row0.append(0) + row0 = Matrix([row0]) + row1 = Poly(g,x, domain = QQ).all_coeffs() + leng1 = len(row1) + for i in range(col_num - leng1): + row1.append(0) + row1 = Matrix([row1]) + + # row pointer for deg_f - deg_g == 1; may be reset below + r = 2 + + # modify first rows of s2 matrix depending on poly degrees + if deg_f - deg_g > 1: + r = 1 + # replacing the existing entries in the rows of s2, + # insert row0 (deg_f - deg_g - 1) times, rotated each time + for i in range(deg_f - deg_g - 1): + s2[r + i, : ] = rotate_r(row0, i + 1) + r = r + deg_f - deg_g - 1 + # insert row1 (deg_f - deg_g) times, rotated each time + for i in range(deg_f - deg_g): + s2[r + i, : ] = rotate_r(row1, r + i) + r = r + deg_f - deg_g + + if deg_f - deg_g == 0: + r = 0 + + # main loop + while deg_g > 0: + # create a small matrix M, and triangularize it; + M = create_ma(deg_f, deg_g, row1, row0, col_num) + # will need only the first and last rows of M + for i in range(deg_f - deg_g + 1): + M1 = pivot(M, i, i) + M = M1[:, :] + + # treat last row of M as poly; find its degree + d = find_degree(M, deg_f) + if d is None: + break + exp_deg = deg_g - 1 + + # evaluate one determinant & make coefficients subresultants + sign_value = correct_sign(n, m, s1, exp_deg, exp_deg - d) + poly = row2poly(M[M.rows - 1, :], d, x) + temp2 = LC(poly, x) + poly = simplify((poly / temp2) * sign_value) + + # update s2 by inserting first row of M as needed + row0 = M[0, :] + for i in range(deg_g - d): + s2[r + i, :] = rotate_r(row0, r + i) + r = r + deg_g - d + + # update s2 by inserting last row of M as needed + row1 = rotate_l(M[M.rows - 1, :], deg_f - d) + row1 = (row1 / temp2) * sign_value + for i in range(deg_g - d): + s2[r + i, :] = rotate_r(row1, r + i) + r = r + deg_g - d + + # update degrees + deg_f, deg_g = deg_g, d + + # append poly with subresultant coeffs + sr_list.append(poly) + + # final touches to print the s2 matrix + if method != 0 and s2.rows > 2: + s2 = final_touches(s2, r, deg_g) + pprint(s2) + elif method != 0 and s2.rows == 2: + s2[1, :] = rotate_r(s2.row(1), 1) + pprint(s2) + + return sr_list + +def subresultants_vv_2(p, q, x): + """ + p, q are polynomials in Z[x] (intended) or Q[x]. It is assumed + that degree(p, x) >= degree(q, x). + + Computes the subresultant prs of p, q by triangularizing, + in Z[x] or in Q[x], all the smaller matrices encountered in the + process of triangularizing sylvester2, Sylvester's matrix of 1853; + see references 1 and 2 for Van Vleck's method. + + If the sylvester2 matrix has big dimensions use this version, + where sylvester2 is used implicitly. If you want to see the final, + triangularized matrix sylvester2, then use the first version, + subresultants_vv(p, q, x, 1). + + sylvester1, Sylvester's matrix of 1840, is also used to compute + one subresultant per remainder; namely, that of the leading + coefficient, in order to obtain the correct sign and to + ``force'' the remainder coefficients to become subresultants. + + If the subresultant prs is complete, then it coincides with the + Euclidean sequence of the polynomials p, q. + + References + ========== + 1. Akritas, A. G.: ``A new method for computing polynomial greatest + common divisors and polynomial remainder sequences.'' + Numerische MatheMatik 52, 119-127, 1988. + + 2. Akritas, A. G., G.I. Malaschonok and P.S. Vigklas: ``On a Theorem + by Van Vleck Regarding Sturm Sequences.'' + Serdica Journal of Computing, 7, No 4, 101-134, 2013. + + 3. Akritas, A. G.:``Three New Methods for Computing Subresultant + Polynomial Remainder Sequences (PRS's).'' Serdica Journal of Computing 9(1) (2015), 1-26. + + """ + # make sure neither p nor q is 0 + if p == 0 or q == 0: + return [p, q] + + # make sure proper degrees + f, g = p, q + n = deg_f = degree(f, x) + m = deg_g = degree(g, x) + if n == 0 and m == 0: + return [f, g] + if n < m: + n, m, deg_f, deg_g, f, g = m, n, deg_g, deg_f, g, f + if n > 0 and m == 0: + return [f, g] + + # initialize + s1 = sylvester(f, g, x, 1) + sr_list = [f, g] # subresultant list + col_num = 2 * n # columns in sylvester2 + + # make two rows (row0, row1) of poly coefficients + row0 = Poly(f, x, domain = QQ).all_coeffs() + leng0 = len(row0) + for i in range(col_num - leng0): + row0.append(0) + row0 = Matrix([row0]) + row1 = Poly(g,x, domain = QQ).all_coeffs() + leng1 = len(row1) + for i in range(col_num - leng1): + row1.append(0) + row1 = Matrix([row1]) + + # main loop + while deg_g > 0: + # create a small matrix M, and triangularize it + M = create_ma(deg_f, deg_g, row1, row0, col_num) + for i in range(deg_f - deg_g + 1): + M1 = pivot(M, i, i) + M = M1[:, :] + + # treat last row of M as poly; find its degree + d = find_degree(M, deg_f) + if d is None: + return sr_list + exp_deg = deg_g - 1 + + # evaluate one determinant & make coefficients subresultants + sign_value = correct_sign(n, m, s1, exp_deg, exp_deg - d) + poly = row2poly(M[M.rows - 1, :], d, x) + poly = simplify((poly / LC(poly, x)) * sign_value) + + # append poly with subresultant coeffs + sr_list.append(poly) + + # update degrees and rows + deg_f, deg_g = deg_g, d + row0 = row1 + row1 = Poly(poly, x, domain = QQ).all_coeffs() + leng1 = len(row1) + for i in range(col_num - leng1): + row1.append(0) + row1 = Matrix([row1]) + + return sr_list diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__init__.py b/phivenv/Lib/site-packages/sympy/polys/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a56d6bc47dbc36b0782382701142442518be8792 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_appellseqs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_appellseqs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc81d64730c4e0839b9b2cb10a5811896efe8e6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_appellseqs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_constructor.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_constructor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d95ea312cccdfdf03f1ebe55f86b292ad7a087 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_constructor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densearith.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densearith.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43c2e59b5425ec4b00d9801fb009286db866448a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densearith.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densebasic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densebasic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..014502ee5211d67b1d0128d23d9f2085f1c927d8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densebasic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densetools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densetools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6831885f8a38db75ebc11cf1ff451426b1e69ca4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_densetools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_dispersion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_dispersion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97b7b5e1f30193905ababf0c413fce3559dd94d9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_dispersion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_distributedmodules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_distributedmodules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63974d2f61504fe1a965f3c87863341427d957e9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_distributedmodules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_euclidtools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_euclidtools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8e6348fb28434b15ad884b7a7bf75d1abd204f1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_euclidtools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_factortools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_factortools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26c5e93bd8050b7ed58720df8c4fbfa9a99db6de Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_factortools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_fields.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_fields.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db321382c6277aa9d41f2c56d9f649be24eacc0d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_fields.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_galoistools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_galoistools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e5e6849b4f6aac21cdabf46355b33d155cc7c28 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_galoistools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_groebnertools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_groebnertools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65cc06f7fdc9b9d72c4437b0b2bd8702fad36cd5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_groebnertools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_heuristicgcd.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_heuristicgcd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..957a990481cd362cecc284b2d706424449fea122 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_heuristicgcd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_hypothesis.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_hypothesis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a366185aa302131ad72f20802b51cd291c1ba9d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_hypothesis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_injections.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_injections.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..057c45c2978a8dc955263d4e40c494642a9b8f8f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_injections.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_modulargcd.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_modulargcd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd9b35070897f9b80d881dacc8abf780bf772402 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_modulargcd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_monomials.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_monomials.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3c60757da44675068027b14aa53b56269d5851e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_monomials.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_multivariate_resultants.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_multivariate_resultants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b553df724871ab0c90f17e38201268378ea50b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_multivariate_resultants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_orderings.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_orderings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf155fe6cab0de0006558b0671afa5e2dd1eaa4e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_orderings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_orthopolys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_orthopolys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2c091fe4b7e649827e96bddf1cd36a4601d1a74 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_orthopolys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_partfrac.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_partfrac.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65bed34454b07147fee0cac7ffdf1a7c2ed17f45 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_partfrac.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyclasses.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyclasses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed84b118537046ab068fcad918f8fff8d17aca09 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyclasses.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyfuncs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyfuncs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6489f94fe1b2d8ca866d8327d731c6eb253453fb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyfuncs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polymatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polymatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfc794936068046c9c45d08fa3eb4b010ebb671c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polymatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyoptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyoptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5139289f70461718d9a5bf8b7bbe09579c705fde Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyoptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyroots.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyroots.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb678e78d0e067bfae248661dc1642ed4a7150ed Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyroots.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb58adde3c99efc4b90a780086b3827d537067a0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_polyutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_puiseux.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_puiseux.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7aa7d86fbb838ee2adeab253037be8c31ddc83d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_puiseux.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_pythonrational.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_pythonrational.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae3503ce04e876cb0b2e55ee31b3e66665847507 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_pythonrational.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rationaltools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rationaltools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebaf8c5d4ec9f0fbc0aa6c31ce0979801fe096ba Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rationaltools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_ring_series.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_ring_series.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b56322ba887141c7fdf92ca21b0c959785c007d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_ring_series.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rings.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..201d980b4beeaca52ac9eac89984729ac191d304 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rootisolation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rootisolation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd664649e7de18eb036f44bf4b5bf38857a54e6f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rootisolation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rootoftools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rootoftools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79999dd86e2c59c7d781b696a98b6a17e598170a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_rootoftools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_solvers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_solvers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61a62db50e506f0195018da56e5f17ed645d0e3c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_solvers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_specialpolys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_specialpolys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381a937ea30ce31026726c4d78607eaf5c05ecec Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_specialpolys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_sqfreetools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_sqfreetools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1923efc9feef2bf15e55dbb276df54e3bf665dc8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_sqfreetools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_subresultants_qq_zz.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_subresultants_qq_zz.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a6ce947fcba7a44c62a01c7792c8ce95daf01b8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/polys/tests/__pycache__/test_subresultants_qq_zz.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_appellseqs.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_appellseqs.py new file mode 100644 index 0000000000000000000000000000000000000000..f4718a2da272ac6f36a968572dc246ebc699e5c4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_appellseqs.py @@ -0,0 +1,91 @@ +"""Tests for efficient functions for generating Appell sequences.""" +from sympy.core.numbers import Rational as Q +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises +from sympy.polys.appellseqs import (bernoulli_poly, bernoulli_c_poly, + euler_poly, genocchi_poly, andre_poly) +from sympy.abc import x + +def test_bernoulli_poly(): + raises(ValueError, lambda: bernoulli_poly(-1, x)) + assert bernoulli_poly(1, x, polys=True) == Poly(x - Q(1,2)) + + assert bernoulli_poly(0, x) == 1 + assert bernoulli_poly(1, x) == x - Q(1,2) + assert bernoulli_poly(2, x) == x**2 - x + Q(1,6) + assert bernoulli_poly(3, x) == x**3 - Q(3,2)*x**2 + Q(1,2)*x + assert bernoulli_poly(4, x) == x**4 - 2*x**3 + x**2 - Q(1,30) + assert bernoulli_poly(5, x) == x**5 - Q(5,2)*x**4 + Q(5,3)*x**3 - Q(1,6)*x + assert bernoulli_poly(6, x) == x**6 - 3*x**5 + Q(5,2)*x**4 - Q(1,2)*x**2 + Q(1,42) + + assert bernoulli_poly(1).dummy_eq(x - Q(1,2)) + assert bernoulli_poly(1, polys=True) == Poly(x - Q(1,2)) + +def test_bernoulli_c_poly(): + raises(ValueError, lambda: bernoulli_c_poly(-1, x)) + assert bernoulli_c_poly(1, x, polys=True) == Poly(x, domain='QQ') + + assert bernoulli_c_poly(0, x) == 1 + assert bernoulli_c_poly(1, x) == x + assert bernoulli_c_poly(2, x) == x**2 - Q(1,3) + assert bernoulli_c_poly(3, x) == x**3 - x + assert bernoulli_c_poly(4, x) == x**4 - 2*x**2 + Q(7,15) + assert bernoulli_c_poly(5, x) == x**5 - Q(10,3)*x**3 + Q(7,3)*x + assert bernoulli_c_poly(6, x) == x**6 - 5*x**4 + 7*x**2 - Q(31,21) + + assert bernoulli_c_poly(1).dummy_eq(x) + assert bernoulli_c_poly(1, polys=True) == Poly(x, domain='QQ') + + assert 2**8 * bernoulli_poly(8, (x+1)/2).expand() == bernoulli_c_poly(8, x) + assert 2**9 * bernoulli_poly(9, (x+1)/2).expand() == bernoulli_c_poly(9, x) + +def test_genocchi_poly(): + raises(ValueError, lambda: genocchi_poly(-1, x)) + assert genocchi_poly(2, x, polys=True) == Poly(-2*x + 1) + + assert genocchi_poly(0, x) == 0 + assert genocchi_poly(1, x) == -1 + assert genocchi_poly(2, x) == 1 - 2*x + assert genocchi_poly(3, x) == 3*x - 3*x**2 + assert genocchi_poly(4, x) == -1 + 6*x**2 - 4*x**3 + assert genocchi_poly(5, x) == -5*x + 10*x**3 - 5*x**4 + assert genocchi_poly(6, x) == 3 - 15*x**2 + 15*x**4 - 6*x**5 + + assert genocchi_poly(2).dummy_eq(-2*x + 1) + assert genocchi_poly(2, polys=True) == Poly(-2*x + 1) + + assert 2 * (bernoulli_poly(8, x) - bernoulli_c_poly(8, x)) == genocchi_poly(8, x) + assert 2 * (bernoulli_poly(9, x) - bernoulli_c_poly(9, x)) == genocchi_poly(9, x) + +def test_euler_poly(): + raises(ValueError, lambda: euler_poly(-1, x)) + assert euler_poly(1, x, polys=True) == Poly(x - Q(1,2)) + + assert euler_poly(0, x) == 1 + assert euler_poly(1, x) == x - Q(1,2) + assert euler_poly(2, x) == x**2 - x + assert euler_poly(3, x) == x**3 - Q(3,2)*x**2 + Q(1,4) + assert euler_poly(4, x) == x**4 - 2*x**3 + x + assert euler_poly(5, x) == x**5 - Q(5,2)*x**4 + Q(5,2)*x**2 - Q(1,2) + assert euler_poly(6, x) == x**6 - 3*x**5 + 5*x**3 - 3*x + + assert euler_poly(1).dummy_eq(x - Q(1,2)) + assert euler_poly(1, polys=True) == Poly(x - Q(1,2)) + + assert genocchi_poly(9, x) == euler_poly(8, x) * -9 + assert genocchi_poly(10, x) == euler_poly(9, x) * -10 + +def test_andre_poly(): + raises(ValueError, lambda: andre_poly(-1, x)) + assert andre_poly(1, x, polys=True) == Poly(x) + + assert andre_poly(0, x) == 1 + assert andre_poly(1, x) == x + assert andre_poly(2, x) == x**2 - 1 + assert andre_poly(3, x) == x**3 - 3*x + assert andre_poly(4, x) == x**4 - 6*x**2 + 5 + assert andre_poly(5, x) == x**5 - 10*x**3 + 25*x + assert andre_poly(6, x) == x**6 - 15*x**4 + 75*x**2 - 61 + + assert andre_poly(1).dummy_eq(x) + assert andre_poly(1, polys=True) == Poly(x) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_constructor.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..b02d8a4b360dd09b993bbed80cdec307d09908fc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_constructor.py @@ -0,0 +1,236 @@ +"""Tests for tools for constructing domains for expressions. """ + +from sympy.testing.pytest import tooslow + +from sympy.polys.constructor import construct_domain +from sympy.polys.domains import ZZ, QQ, ZZ_I, QQ_I, RR, CC, EX +from sympy.polys.domains.realfield import RealField +from sympy.polys.domains.complexfield import ComplexField + +from sympy.core import (Catalan, GoldenRatio) +from sympy.core.numbers import (E, Float, I, Rational, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy import rootof + +from sympy.abc import x, y + + +def test_construct_domain(): + + assert construct_domain([1, 2, 3]) == (ZZ, [ZZ(1), ZZ(2), ZZ(3)]) + assert construct_domain([1, 2, 3], field=True) == (QQ, [QQ(1), QQ(2), QQ(3)]) + + assert construct_domain([S.One, S(2), S(3)]) == (ZZ, [ZZ(1), ZZ(2), ZZ(3)]) + assert construct_domain([S.One, S(2), S(3)], field=True) == (QQ, [QQ(1), QQ(2), QQ(3)]) + + assert construct_domain([S.Half, S(2)]) == (QQ, [QQ(1, 2), QQ(2)]) + result = construct_domain([3.14, 1, S.Half]) + assert isinstance(result[0], RealField) + assert result[1] == [RR(3.14), RR(1.0), RR(0.5)] + + result = construct_domain([3.14, I, S.Half]) + assert isinstance(result[0], ComplexField) + assert result[1] == [CC(3.14), CC(1.0j), CC(0.5)] + + assert construct_domain([1.0+I]) == (CC, [CC(1.0, 1.0)]) + assert construct_domain([2.0+3.0*I]) == (CC, [CC(2.0, 3.0)]) + + assert construct_domain([1, I]) == (ZZ_I, [ZZ_I(1, 0), ZZ_I(0, 1)]) + assert construct_domain([1, I/2]) == (QQ_I, [QQ_I(1, 0), QQ_I(0, S.Half)]) + + assert construct_domain([3.14, sqrt(2)], extension=None) == (EX, [EX(3.14), EX(sqrt(2))]) + assert construct_domain([3.14, sqrt(2)], extension=True) == (EX, [EX(3.14), EX(sqrt(2))]) + + assert construct_domain([1, sqrt(2)], extension=None) == (EX, [EX(1), EX(sqrt(2))]) + + assert construct_domain([x, sqrt(x)]) == (EX, [EX(x), EX(sqrt(x))]) + assert construct_domain([x, sqrt(x), sqrt(y)]) == (EX, [EX(x), EX(sqrt(x)), EX(sqrt(y))]) + + alg = QQ.algebraic_field(sqrt(2)) + + assert construct_domain([7, S.Half, sqrt(2)], extension=True) == \ + (alg, [alg.convert(7), alg.convert(S.Half), alg.convert(sqrt(2))]) + + alg = QQ.algebraic_field(sqrt(2) + sqrt(3)) + + assert construct_domain([7, sqrt(2), sqrt(3)], extension=True) == \ + (alg, [alg.convert(7), alg.convert(sqrt(2)), alg.convert(sqrt(3))]) + + dom = ZZ[x] + + assert construct_domain([2*x, 3]) == \ + (dom, [dom.convert(2*x), dom.convert(3)]) + + dom = ZZ[x, y] + + assert construct_domain([2*x, 3*y]) == \ + (dom, [dom.convert(2*x), dom.convert(3*y)]) + + dom = QQ[x] + + assert construct_domain([x/2, 3]) == \ + (dom, [dom.convert(x/2), dom.convert(3)]) + + dom = QQ[x, y] + + assert construct_domain([x/2, 3*y]) == \ + (dom, [dom.convert(x/2), dom.convert(3*y)]) + + dom = ZZ_I[x] + + assert construct_domain([2*x, I]) == \ + (dom, [dom.convert(2*x), dom.convert(I)]) + + dom = ZZ_I[x, y] + + assert construct_domain([2*x, I*y]) == \ + (dom, [dom.convert(2*x), dom.convert(I*y)]) + + dom = QQ_I[x] + + assert construct_domain([x/2, I]) == \ + (dom, [dom.convert(x/2), dom.convert(I)]) + + dom = QQ_I[x, y] + + assert construct_domain([x/2, I*y]) == \ + (dom, [dom.convert(x/2), dom.convert(I*y)]) + + dom = RR[x] + + assert construct_domain([x/2, 3.5]) == \ + (dom, [dom.convert(x/2), dom.convert(3.5)]) + + dom = RR[x, y] + + assert construct_domain([x/2, 3.5*y]) == \ + (dom, [dom.convert(x/2), dom.convert(3.5*y)]) + + dom = CC[x] + + assert construct_domain([I*x/2, 3.5]) == \ + (dom, [dom.convert(I*x/2), dom.convert(3.5)]) + + dom = CC[x, y] + + assert construct_domain([I*x/2, 3.5*y]) == \ + (dom, [dom.convert(I*x/2), dom.convert(3.5*y)]) + + dom = CC[x] + + assert construct_domain([x/2, I*3.5]) == \ + (dom, [dom.convert(x/2), dom.convert(I*3.5)]) + + dom = CC[x, y] + + assert construct_domain([x/2, I*3.5*y]) == \ + (dom, [dom.convert(x/2), dom.convert(I*3.5*y)]) + + dom = ZZ.frac_field(x) + + assert construct_domain([2/x, 3]) == \ + (dom, [dom.convert(2/x), dom.convert(3)]) + + dom = ZZ.frac_field(x, y) + + assert construct_domain([2/x, 3*y]) == \ + (dom, [dom.convert(2/x), dom.convert(3*y)]) + + dom = RR.frac_field(x) + + assert construct_domain([2/x, 3.5]) == \ + (dom, [dom.convert(2/x), dom.convert(3.5)]) + + dom = RR.frac_field(x, y) + + assert construct_domain([2/x, 3.5*y]) == \ + (dom, [dom.convert(2/x), dom.convert(3.5*y)]) + + dom = RealField(prec=336)[x] + + assert construct_domain([pi.evalf(100)*x]) == \ + (dom, [dom.convert(pi.evalf(100)*x)]) + + assert construct_domain(2) == (ZZ, ZZ(2)) + assert construct_domain(S(2)/3) == (QQ, QQ(2, 3)) + assert construct_domain(Rational(2, 3)) == (QQ, QQ(2, 3)) + + assert construct_domain({}) == (ZZ, {}) + + +def test_complex_exponential(): + w = exp(-I*2*pi/3, evaluate=False) + alg = QQ.algebraic_field(w) + assert construct_domain([w**2, w, 1], extension=True) == ( + alg, + [alg.convert(w**2), + alg.convert(w), + alg.convert(1)] + ) + + +def test_rootof(): + r1 = rootof(x**3 + x + 1, 0) + r2 = rootof(x**3 + x + 1, 1) + K1 = QQ.algebraic_field(r1) + K2 = QQ.algebraic_field(r2) + assert construct_domain([r1]) == (EX, [EX(r1)]) + assert construct_domain([r2]) == (EX, [EX(r2)]) + assert construct_domain([r1, r2]) == (EX, [EX(r1), EX(r2)]) + + assert construct_domain([r1], extension=True) == ( + K1, [K1.from_sympy(r1)]) + assert construct_domain([r2], extension=True) == ( + K2, [K2.from_sympy(r2)]) + + +@tooslow +def test_rootof_primitive_element(): + r1 = rootof(x**3 + x + 1, 0) + r2 = rootof(x**3 + x + 1, 1) + K12 = QQ.algebraic_field(r1 + r2) + assert construct_domain([r1, r2], extension=True) == ( + K12, [K12.from_sympy(r1), K12.from_sympy(r2)]) + + +def test_composite_option(): + assert construct_domain({(1,): sin(y)}, composite=False) == \ + (EX, {(1,): EX(sin(y))}) + + assert construct_domain({(1,): y}, composite=False) == \ + (EX, {(1,): EX(y)}) + + assert construct_domain({(1, 1): 1}, composite=False) == \ + (ZZ, {(1, 1): 1}) + + assert construct_domain({(1, 0): y}, composite=False) == \ + (EX, {(1, 0): EX(y)}) + + +def test_precision(): + f1 = Float("1.01") + f2 = Float("1.0000000000000000000001") + for u in [1, 1e-2, 1e-6, 1e-13, 1e-14, 1e-16, 1e-20, 1e-100, 1e-300, + f1, f2]: + result = construct_domain([u]) + v = float(result[1][0]) + assert abs(u - v) / u < 1e-14 # Test relative accuracy + + result = construct_domain([f1]) + y = result[1][0] + assert y-1 > 1e-50 + + result = construct_domain([f2]) + y = result[1][0] + assert y-1 > 1e-50 + + +def test_issue_11538(): + for n in [E, pi, Catalan]: + assert construct_domain(n)[0] == ZZ[n] + assert construct_domain(x + n)[0] == ZZ[x, n] + assert construct_domain(GoldenRatio)[0] == EX + assert construct_domain(x + GoldenRatio)[0] == EX diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_densearith.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_densearith.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb29d50867ad578274ed11c766e0515d8e4da35 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_densearith.py @@ -0,0 +1,1007 @@ +"""Tests for dense recursive polynomials' arithmetics. """ + +from sympy.external.gmpy import GROUND_TYPES + +from sympy.polys.densebasic import ( + dup_normal, dmp_normal, +) + +from sympy.polys.densearith import ( + dup_add_term, dmp_add_term, + dup_sub_term, dmp_sub_term, + dup_mul_term, dmp_mul_term, + dup_add_ground, dmp_add_ground, + dup_sub_ground, dmp_sub_ground, + dup_mul_ground, dmp_mul_ground, + dup_quo_ground, dmp_quo_ground, + dup_exquo_ground, dmp_exquo_ground, + dup_lshift, dup_rshift, + dup_abs, dmp_abs, + dup_neg, dmp_neg, + dup_add, dmp_add, + dup_sub, dmp_sub, + dup_mul, dmp_mul, + dup_sqr, dmp_sqr, + dup_pow, dmp_pow, + dup_add_mul, dmp_add_mul, + dup_sub_mul, dmp_sub_mul, + dup_pdiv, dup_prem, dup_pquo, dup_pexquo, + dmp_pdiv, dmp_prem, dmp_pquo, dmp_pexquo, + dup_rr_div, dmp_rr_div, + dup_ff_div, dmp_ff_div, + dup_div, dup_rem, dup_quo, dup_exquo, + dmp_div, dmp_rem, dmp_quo, dmp_exquo, + dup_max_norm, dmp_max_norm, + dup_l1_norm, dmp_l1_norm, + dup_l2_norm_squared, dmp_l2_norm_squared, + dup_expand, dmp_expand, +) + +from sympy.polys.polyerrors import ( + ExactQuotientFailed, +) + +from sympy.polys.specialpolys import f_polys, Symbol, Poly +from sympy.polys.domains import FF, ZZ, QQ, CC + +from sympy.testing.pytest import raises + +x = Symbol('x') + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = [ f.to_dense() for f in f_polys() ] +F_0 = dmp_mul_ground(dmp_normal(f_0, 2, QQ), QQ(1, 7), 2, QQ) + +def test_dup_add_term(): + f = dup_normal([], ZZ) + + assert dup_add_term(f, ZZ(0), 0, ZZ) == dup_normal([], ZZ) + + assert dup_add_term(f, ZZ(1), 0, ZZ) == dup_normal([1], ZZ) + assert dup_add_term(f, ZZ(1), 1, ZZ) == dup_normal([1, 0], ZZ) + assert dup_add_term(f, ZZ(1), 2, ZZ) == dup_normal([1, 0, 0], ZZ) + + f = dup_normal([1, 1, 1], ZZ) + + assert dup_add_term(f, ZZ(1), 0, ZZ) == dup_normal([1, 1, 2], ZZ) + assert dup_add_term(f, ZZ(1), 1, ZZ) == dup_normal([1, 2, 1], ZZ) + assert dup_add_term(f, ZZ(1), 2, ZZ) == dup_normal([2, 1, 1], ZZ) + + assert dup_add_term(f, ZZ(1), 3, ZZ) == dup_normal([1, 1, 1, 1], ZZ) + assert dup_add_term(f, ZZ(1), 4, ZZ) == dup_normal([1, 0, 1, 1, 1], ZZ) + assert dup_add_term(f, ZZ(1), 5, ZZ) == dup_normal([1, 0, 0, 1, 1, 1], ZZ) + assert dup_add_term( + f, ZZ(1), 6, ZZ) == dup_normal([1, 0, 0, 0, 1, 1, 1], ZZ) + + assert dup_add_term(f, ZZ(-1), 2, ZZ) == dup_normal([1, 1], ZZ) + + +def test_dmp_add_term(): + assert dmp_add_term([ZZ(1), ZZ(1), ZZ(1)], ZZ(1), 2, 0, ZZ) == \ + dup_add_term([ZZ(1), ZZ(1), ZZ(1)], ZZ(1), 2, ZZ) + assert dmp_add_term(f_0, [[]], 3, 2, ZZ) == f_0 + assert dmp_add_term(F_0, [[]], 3, 2, QQ) == F_0 + + +def test_dup_sub_term(): + f = dup_normal([], ZZ) + + assert dup_sub_term(f, ZZ(0), 0, ZZ) == dup_normal([], ZZ) + + assert dup_sub_term(f, ZZ(1), 0, ZZ) == dup_normal([-1], ZZ) + assert dup_sub_term(f, ZZ(1), 1, ZZ) == dup_normal([-1, 0], ZZ) + assert dup_sub_term(f, ZZ(1), 2, ZZ) == dup_normal([-1, 0, 0], ZZ) + + f = dup_normal([1, 1, 1], ZZ) + + assert dup_sub_term(f, ZZ(2), 0, ZZ) == dup_normal([ 1, 1, -1], ZZ) + assert dup_sub_term(f, ZZ(2), 1, ZZ) == dup_normal([ 1, -1, 1], ZZ) + assert dup_sub_term(f, ZZ(2), 2, ZZ) == dup_normal([-1, 1, 1], ZZ) + + assert dup_sub_term(f, ZZ(1), 3, ZZ) == dup_normal([-1, 1, 1, 1], ZZ) + assert dup_sub_term(f, ZZ(1), 4, ZZ) == dup_normal([-1, 0, 1, 1, 1], ZZ) + assert dup_sub_term(f, ZZ(1), 5, ZZ) == dup_normal([-1, 0, 0, 1, 1, 1], ZZ) + assert dup_sub_term( + f, ZZ(1), 6, ZZ) == dup_normal([-1, 0, 0, 0, 1, 1, 1], ZZ) + + assert dup_sub_term(f, ZZ(1), 2, ZZ) == dup_normal([1, 1], ZZ) + + +def test_dmp_sub_term(): + assert dmp_sub_term([ZZ(1), ZZ(1), ZZ(1)], ZZ(1), 2, 0, ZZ) == \ + dup_sub_term([ZZ(1), ZZ(1), ZZ(1)], ZZ(1), 2, ZZ) + assert dmp_sub_term(f_0, [[]], 3, 2, ZZ) == f_0 + assert dmp_sub_term(F_0, [[]], 3, 2, QQ) == F_0 + + +def test_dup_mul_term(): + f = dup_normal([], ZZ) + + assert dup_mul_term(f, ZZ(2), 3, ZZ) == dup_normal([], ZZ) + + f = dup_normal([1, 1], ZZ) + + assert dup_mul_term(f, ZZ(0), 3, ZZ) == dup_normal([], ZZ) + + f = dup_normal([1, 2, 3], ZZ) + + assert dup_mul_term(f, ZZ(2), 0, ZZ) == dup_normal([2, 4, 6], ZZ) + assert dup_mul_term(f, ZZ(2), 1, ZZ) == dup_normal([2, 4, 6, 0], ZZ) + assert dup_mul_term(f, ZZ(2), 2, ZZ) == dup_normal([2, 4, 6, 0, 0], ZZ) + assert dup_mul_term(f, ZZ(2), 3, ZZ) == dup_normal([2, 4, 6, 0, 0, 0], ZZ) + + +def test_dmp_mul_term(): + assert dmp_mul_term([ZZ(1), ZZ(2), ZZ(3)], ZZ(2), 1, 0, ZZ) == \ + dup_mul_term([ZZ(1), ZZ(2), ZZ(3)], ZZ(2), 1, ZZ) + + assert dmp_mul_term([[]], [ZZ(2)], 3, 1, ZZ) == [[]] + assert dmp_mul_term([[ZZ(1)]], [], 3, 1, ZZ) == [[]] + + assert dmp_mul_term([[ZZ(1), ZZ(2)], [ZZ(3)]], [ZZ(2)], 2, 1, ZZ) == \ + [[ZZ(2), ZZ(4)], [ZZ(6)], [], []] + + assert dmp_mul_term([[]], [QQ(2, 3)], 3, 1, QQ) == [[]] + assert dmp_mul_term([[QQ(1, 2)]], [], 3, 1, QQ) == [[]] + + assert dmp_mul_term([[QQ(1, 5), QQ(2, 5)], [QQ(3, 5)]], [QQ(2, 3)], 2, 1, QQ) == \ + [[QQ(2, 15), QQ(4, 15)], [QQ(6, 15)], [], []] + + +def test_dup_add_ground(): + f = ZZ.map([1, 2, 3, 4]) + g = ZZ.map([1, 2, 3, 8]) + + assert dup_add_ground(f, ZZ(4), ZZ) == g + + +def test_dmp_add_ground(): + f = ZZ.map([[1], [2], [3], [4]]) + g = ZZ.map([[1], [2], [3], [8]]) + + assert dmp_add_ground(f, ZZ(4), 1, ZZ) == g + + +def test_dup_sub_ground(): + f = ZZ.map([1, 2, 3, 4]) + g = ZZ.map([1, 2, 3, 0]) + + assert dup_sub_ground(f, ZZ(4), ZZ) == g + + +def test_dmp_sub_ground(): + f = ZZ.map([[1], [2], [3], [4]]) + g = ZZ.map([[1], [2], [3], []]) + + assert dmp_sub_ground(f, ZZ(4), 1, ZZ) == g + + +def test_dup_mul_ground(): + f = dup_normal([], ZZ) + + assert dup_mul_ground(f, ZZ(2), ZZ) == dup_normal([], ZZ) + + f = dup_normal([1, 2, 3], ZZ) + + assert dup_mul_ground(f, ZZ(0), ZZ) == dup_normal([], ZZ) + assert dup_mul_ground(f, ZZ(2), ZZ) == dup_normal([2, 4, 6], ZZ) + + +def test_dmp_mul_ground(): + assert dmp_mul_ground(f_0, ZZ(2), 2, ZZ) == [ + [[ZZ(2), ZZ(4), ZZ(6)], [ZZ(4)]], + [[ZZ(6)]], + [[ZZ(8), ZZ(10), ZZ(12)], [ZZ(2), ZZ(4), ZZ(2)], [ZZ(2)]] + ] + + assert dmp_mul_ground(F_0, QQ(1, 2), 2, QQ) == [ + [[QQ(1, 14), QQ(2, 14), QQ(3, 14)], [QQ(2, 14)]], + [[QQ(3, 14)]], + [[QQ(4, 14), QQ(5, 14), QQ(6, 14)], [QQ(1, 14), QQ(2, 14), + QQ(1, 14)], [QQ(1, 14)]] + ] + + +def test_dup_quo_ground(): + raises(ZeroDivisionError, lambda: dup_quo_ground(dup_normal([1, 2, + 3], ZZ), ZZ(0), ZZ)) + + f = dup_normal([], ZZ) + + assert dup_quo_ground(f, ZZ(3), ZZ) == dup_normal([], ZZ) + + f = dup_normal([6, 2, 8], ZZ) + + assert dup_quo_ground(f, ZZ(1), ZZ) == f + assert dup_quo_ground(f, ZZ(2), ZZ) == dup_normal([3, 1, 4], ZZ) + + assert dup_quo_ground(f, ZZ(3), ZZ) == dup_normal([2, 0, 2], ZZ) + + f = dup_normal([6, 2, 8], QQ) + + assert dup_quo_ground(f, QQ(1), QQ) == f + assert dup_quo_ground(f, QQ(2), QQ) == [QQ(3), QQ(1), QQ(4)] + assert dup_quo_ground(f, QQ(7), QQ) == [QQ(6, 7), QQ(2, 7), QQ(8, 7)] + + +def test_dup_exquo_ground(): + raises(ZeroDivisionError, lambda: dup_exquo_ground(dup_normal([1, + 2, 3], ZZ), ZZ(0), ZZ)) + raises(ExactQuotientFailed, lambda: dup_exquo_ground(dup_normal([1, + 2, 3], ZZ), ZZ(3), ZZ)) + + f = dup_normal([], ZZ) + + assert dup_exquo_ground(f, ZZ(3), ZZ) == dup_normal([], ZZ) + + f = dup_normal([6, 2, 8], ZZ) + + assert dup_exquo_ground(f, ZZ(1), ZZ) == f + assert dup_exquo_ground(f, ZZ(2), ZZ) == dup_normal([3, 1, 4], ZZ) + + f = dup_normal([6, 2, 8], QQ) + + assert dup_exquo_ground(f, QQ(1), QQ) == f + assert dup_exquo_ground(f, QQ(2), QQ) == [QQ(3), QQ(1), QQ(4)] + assert dup_exquo_ground(f, QQ(7), QQ) == [QQ(6, 7), QQ(2, 7), QQ(8, 7)] + + +def test_dmp_quo_ground(): + f = dmp_normal([[6], [2], [8]], 1, ZZ) + + assert dmp_quo_ground(f, ZZ(1), 1, ZZ) == f + assert dmp_quo_ground( + f, ZZ(2), 1, ZZ) == dmp_normal([[3], [1], [4]], 1, ZZ) + + assert dmp_normal(dmp_quo_ground( + f, ZZ(3), 1, ZZ), 1, ZZ) == dmp_normal([[2], [], [2]], 1, ZZ) + + +def test_dmp_exquo_ground(): + f = dmp_normal([[6], [2], [8]], 1, ZZ) + + assert dmp_exquo_ground(f, ZZ(1), 1, ZZ) == f + assert dmp_exquo_ground( + f, ZZ(2), 1, ZZ) == dmp_normal([[3], [1], [4]], 1, ZZ) + + +def test_dup_lshift(): + assert dup_lshift([], 3, ZZ) == [] + assert dup_lshift([1], 3, ZZ) == [1, 0, 0, 0] + + +def test_dup_rshift(): + assert dup_rshift([], 3, ZZ) == [] + assert dup_rshift([1, 0, 0, 0], 3, ZZ) == [1] + + +def test_dup_abs(): + assert dup_abs([], ZZ) == [] + assert dup_abs([ZZ( 1)], ZZ) == [ZZ(1)] + assert dup_abs([ZZ(-7)], ZZ) == [ZZ(7)] + assert dup_abs([ZZ(-1), ZZ(2), ZZ(3)], ZZ) == [ZZ(1), ZZ(2), ZZ(3)] + + assert dup_abs([], QQ) == [] + assert dup_abs([QQ( 1, 2)], QQ) == [QQ(1, 2)] + assert dup_abs([QQ(-7, 3)], QQ) == [QQ(7, 3)] + assert dup_abs( + [QQ(-1, 7), QQ(2, 7), QQ(3, 7)], QQ) == [QQ(1, 7), QQ(2, 7), QQ(3, 7)] + + +def test_dmp_abs(): + assert dmp_abs([ZZ(-1)], 0, ZZ) == [ZZ(1)] + assert dmp_abs([QQ(-1, 2)], 0, QQ) == [QQ(1, 2)] + + assert dmp_abs([[[]]], 2, ZZ) == [[[]]] + assert dmp_abs([[[ZZ(1)]]], 2, ZZ) == [[[ZZ(1)]]] + assert dmp_abs([[[ZZ(-7)]]], 2, ZZ) == [[[ZZ(7)]]] + + assert dmp_abs([[[]]], 2, QQ) == [[[]]] + assert dmp_abs([[[QQ(1, 2)]]], 2, QQ) == [[[QQ(1, 2)]]] + assert dmp_abs([[[QQ(-7, 9)]]], 2, QQ) == [[[QQ(7, 9)]]] + + +def test_dup_neg(): + assert dup_neg([], ZZ) == [] + assert dup_neg([ZZ(1)], ZZ) == [ZZ(-1)] + assert dup_neg([ZZ(-7)], ZZ) == [ZZ(7)] + assert dup_neg([ZZ(-1), ZZ(2), ZZ(3)], ZZ) == [ZZ(1), ZZ(-2), ZZ(-3)] + + assert dup_neg([], QQ) == [] + assert dup_neg([QQ(1, 2)], QQ) == [QQ(-1, 2)] + assert dup_neg([QQ(-7, 9)], QQ) == [QQ(7, 9)] + assert dup_neg([QQ( + -1, 7), QQ(2, 7), QQ(3, 7)], QQ) == [QQ(1, 7), QQ(-2, 7), QQ(-3, 7)] + + +def test_dmp_neg(): + assert dmp_neg([ZZ(-1)], 0, ZZ) == [ZZ(1)] + assert dmp_neg([QQ(-1, 2)], 0, QQ) == [QQ(1, 2)] + + assert dmp_neg([[[]]], 2, ZZ) == [[[]]] + assert dmp_neg([[[ZZ(1)]]], 2, ZZ) == [[[ZZ(-1)]]] + assert dmp_neg([[[ZZ(-7)]]], 2, ZZ) == [[[ZZ(7)]]] + + assert dmp_neg([[[]]], 2, QQ) == [[[]]] + assert dmp_neg([[[QQ(1, 9)]]], 2, QQ) == [[[QQ(-1, 9)]]] + assert dmp_neg([[[QQ(-7, 9)]]], 2, QQ) == [[[QQ(7, 9)]]] + + +def test_dup_add(): + assert dup_add([], [], ZZ) == [] + assert dup_add([ZZ(1)], [], ZZ) == [ZZ(1)] + assert dup_add([], [ZZ(1)], ZZ) == [ZZ(1)] + assert dup_add([ZZ(1)], [ZZ(1)], ZZ) == [ZZ(2)] + assert dup_add([ZZ(1)], [ZZ(2)], ZZ) == [ZZ(3)] + + assert dup_add([ZZ(1), ZZ(2)], [ZZ(1)], ZZ) == [ZZ(1), ZZ(3)] + assert dup_add([ZZ(1)], [ZZ(1), ZZ(2)], ZZ) == [ZZ(1), ZZ(3)] + + assert dup_add([ZZ(1), ZZ( + 2), ZZ(3)], [ZZ(8), ZZ(9), ZZ(10)], ZZ) == [ZZ(9), ZZ(11), ZZ(13)] + + assert dup_add([], [], QQ) == [] + assert dup_add([QQ(1, 2)], [], QQ) == [QQ(1, 2)] + assert dup_add([], [QQ(1, 2)], QQ) == [QQ(1, 2)] + assert dup_add([QQ(1, 4)], [QQ(1, 4)], QQ) == [QQ(1, 2)] + assert dup_add([QQ(1, 4)], [QQ(1, 2)], QQ) == [QQ(3, 4)] + + assert dup_add([QQ(1, 2), QQ(2, 3)], [QQ(1)], QQ) == [QQ(1, 2), QQ(5, 3)] + assert dup_add([QQ(1)], [QQ(1, 2), QQ(2, 3)], QQ) == [QQ(1, 2), QQ(5, 3)] + + assert dup_add([QQ(1, 7), QQ(2, 7), QQ(3, 7)], [QQ( + 8, 7), QQ(9, 7), QQ(10, 7)], QQ) == [QQ(9, 7), QQ(11, 7), QQ(13, 7)] + + +def test_dmp_add(): + assert dmp_add([ZZ(1), ZZ(2)], [ZZ(1)], 0, ZZ) == \ + dup_add([ZZ(1), ZZ(2)], [ZZ(1)], ZZ) + assert dmp_add([QQ(1, 2), QQ(2, 3)], [QQ(1)], 0, QQ) == \ + dup_add([QQ(1, 2), QQ(2, 3)], [QQ(1)], QQ) + + assert dmp_add([[[]]], [[[]]], 2, ZZ) == [[[]]] + assert dmp_add([[[ZZ(1)]]], [[[]]], 2, ZZ) == [[[ZZ(1)]]] + assert dmp_add([[[]]], [[[ZZ(1)]]], 2, ZZ) == [[[ZZ(1)]]] + assert dmp_add([[[ZZ(2)]]], [[[ZZ(1)]]], 2, ZZ) == [[[ZZ(3)]]] + assert dmp_add([[[ZZ(1)]]], [[[ZZ(2)]]], 2, ZZ) == [[[ZZ(3)]]] + + assert dmp_add([[[]]], [[[]]], 2, QQ) == [[[]]] + assert dmp_add([[[QQ(1, 2)]]], [[[]]], 2, QQ) == [[[QQ(1, 2)]]] + assert dmp_add([[[]]], [[[QQ(1, 2)]]], 2, QQ) == [[[QQ(1, 2)]]] + assert dmp_add([[[QQ(2, 7)]]], [[[QQ(1, 7)]]], 2, QQ) == [[[QQ(3, 7)]]] + assert dmp_add([[[QQ(1, 7)]]], [[[QQ(2, 7)]]], 2, QQ) == [[[QQ(3, 7)]]] + + +def test_dup_sub(): + assert dup_sub([], [], ZZ) == [] + assert dup_sub([ZZ(1)], [], ZZ) == [ZZ(1)] + assert dup_sub([], [ZZ(1)], ZZ) == [ZZ(-1)] + assert dup_sub([ZZ(1)], [ZZ(1)], ZZ) == [] + assert dup_sub([ZZ(1)], [ZZ(2)], ZZ) == [ZZ(-1)] + + assert dup_sub([ZZ(1), ZZ(2)], [ZZ(1)], ZZ) == [ZZ(1), ZZ(1)] + assert dup_sub([ZZ(1)], [ZZ(1), ZZ(2)], ZZ) == [ZZ(-1), ZZ(-1)] + + assert dup_sub([ZZ(3), ZZ( + 2), ZZ(1)], [ZZ(8), ZZ(9), ZZ(10)], ZZ) == [ZZ(-5), ZZ(-7), ZZ(-9)] + + assert dup_sub([], [], QQ) == [] + assert dup_sub([QQ(1, 2)], [], QQ) == [QQ(1, 2)] + assert dup_sub([], [QQ(1, 2)], QQ) == [QQ(-1, 2)] + assert dup_sub([QQ(1, 3)], [QQ(1, 3)], QQ) == [] + assert dup_sub([QQ(1, 3)], [QQ(2, 3)], QQ) == [QQ(-1, 3)] + + assert dup_sub([QQ(1, 7), QQ(2, 7)], [QQ(1)], QQ) == [QQ(1, 7), QQ(-5, 7)] + assert dup_sub([QQ(1)], [QQ(1, 7), QQ(2, 7)], QQ) == [QQ(-1, 7), QQ(5, 7)] + + assert dup_sub([QQ(3, 7), QQ(2, 7), QQ(1, 7)], [QQ( + 8, 7), QQ(9, 7), QQ(10, 7)], QQ) == [QQ(-5, 7), QQ(-7, 7), QQ(-9, 7)] + + +def test_dmp_sub(): + assert dmp_sub([ZZ(1), ZZ(2)], [ZZ(1)], 0, ZZ) == \ + dup_sub([ZZ(1), ZZ(2)], [ZZ(1)], ZZ) + assert dmp_sub([QQ(1, 2), QQ(2, 3)], [QQ(1)], 0, QQ) == \ + dup_sub([QQ(1, 2), QQ(2, 3)], [QQ(1)], QQ) + + assert dmp_sub([[[]]], [[[]]], 2, ZZ) == [[[]]] + assert dmp_sub([[[ZZ(1)]]], [[[]]], 2, ZZ) == [[[ZZ(1)]]] + assert dmp_sub([[[]]], [[[ZZ(1)]]], 2, ZZ) == [[[ZZ(-1)]]] + assert dmp_sub([[[ZZ(2)]]], [[[ZZ(1)]]], 2, ZZ) == [[[ZZ(1)]]] + assert dmp_sub([[[ZZ(1)]]], [[[ZZ(2)]]], 2, ZZ) == [[[ZZ(-1)]]] + + assert dmp_sub([[[]]], [[[]]], 2, QQ) == [[[]]] + assert dmp_sub([[[QQ(1, 2)]]], [[[]]], 2, QQ) == [[[QQ(1, 2)]]] + assert dmp_sub([[[]]], [[[QQ(1, 2)]]], 2, QQ) == [[[QQ(-1, 2)]]] + assert dmp_sub([[[QQ(2, 7)]]], [[[QQ(1, 7)]]], 2, QQ) == [[[QQ(1, 7)]]] + assert dmp_sub([[[QQ(1, 7)]]], [[[QQ(2, 7)]]], 2, QQ) == [[[QQ(-1, 7)]]] + + +def test_dup_add_mul(): + assert dup_add_mul([ZZ(1), ZZ(2), ZZ(3)], [ZZ(3), ZZ(2), ZZ(1)], + [ZZ(1), ZZ(2)], ZZ) == [ZZ(3), ZZ(9), ZZ(7), ZZ(5)] + assert dmp_add_mul([[ZZ(1), ZZ(2)], [ZZ(3)]], [[ZZ(3)], [ZZ(2), ZZ(1)]], + [[ZZ(1)], [ZZ(2)]], 1, ZZ) == [[ZZ(3)], [ZZ(3), ZZ(9)], [ZZ(4), ZZ(5)]] + + +def test_dup_sub_mul(): + assert dup_sub_mul([ZZ(1), ZZ(2), ZZ(3)], [ZZ(3), ZZ(2), ZZ(1)], + [ZZ(1), ZZ(2)], ZZ) == [ZZ(-3), ZZ(-7), ZZ(-3), ZZ(1)] + assert dmp_sub_mul([[ZZ(1), ZZ(2)], [ZZ(3)]], [[ZZ(3)], [ZZ(2), ZZ(1)]], + [[ZZ(1)], [ZZ(2)]], 1, ZZ) == [[ZZ(-3)], [ZZ(-1), ZZ(-5)], [ZZ(-4), ZZ(1)]] + + +def test_dup_mul(): + assert dup_mul([], [], ZZ) == [] + assert dup_mul([], [ZZ(1)], ZZ) == [] + assert dup_mul([ZZ(1)], [], ZZ) == [] + assert dup_mul([ZZ(1)], [ZZ(1)], ZZ) == [ZZ(1)] + assert dup_mul([ZZ(5)], [ZZ(7)], ZZ) == [ZZ(35)] + + assert dup_mul([], [], QQ) == [] + assert dup_mul([], [QQ(1, 2)], QQ) == [] + assert dup_mul([QQ(1, 2)], [], QQ) == [] + assert dup_mul([QQ(1, 2)], [QQ(4, 7)], QQ) == [QQ(2, 7)] + assert dup_mul([QQ(5, 7)], [QQ(3, 7)], QQ) == [QQ(15, 49)] + + f = dup_normal([3, 0, 0, 6, 1, 2], ZZ) + g = dup_normal([4, 0, 1, 0], ZZ) + h = dup_normal([12, 0, 3, 24, 4, 14, 1, 2, 0], ZZ) + + assert dup_mul(f, g, ZZ) == h + assert dup_mul(g, f, ZZ) == h + + f = dup_normal([2, 0, 0, 1, 7], ZZ) + h = dup_normal([4, 0, 0, 4, 28, 0, 1, 14, 49], ZZ) + + assert dup_mul(f, f, ZZ) == h + + K = FF(6) + + assert dup_mul([K(2), K(1)], [K(3), K(4)], K) == [K(5), K(4)] + + p1 = dup_normal([79, -1, 78, -94, -10, 11, 32, -19, 78, 2, -89, 30, 73, 42, + 85, 77, 83, -30, -34, -2, 95, -81, 37, -49, -46, -58, -16, 37, 35, -11, + -57, -15, -31, 67, -20, 27, 76, 2, 70, 67, -65, 65, -26, -93, -44, -12, + -92, 57, -90, -57, -11, -67, -98, -69, 97, -41, 89, 33, 89, -50, 81, + -31, 60, -27, 43, 29, -77, 44, 21, -91, 32, -57, 33, 3, 53, -51, -38, + -99, -84, 23, -50, 66, -100, 1, -75, -25, 27, -60, 98, -51, -87, 6, 8, + 78, -28, -95, -88, 12, -35, 26, -9, 16, -92, 55, -7, -86, 68, -39, -46, + 84, 94, 45, 60, 92, 68, -75, -74, -19, 8, 75, 78, 91, 57, 34, 14, -3, + -49, 65, 78, -18, 6, -29, -80, -98, 17, 13, 58, 21, 20, 9, 37, 7, -30, + -53, -20, 34, 67, -42, 89, -22, 73, 43, -6, 5, 51, -8, -15, -52, -22, + -58, -72, -3, 43, -92, 82, 83, -2, -13, -23, -60, 16, -94, -8, -28, + -95, -72, 63, -90, 76, 6, -43, -100, -59, 76, 3, 3, 46, -85, 75, 62, + -71, -76, 88, 97, -72, -1, 30, -64, 72, -48, 14, -78, 58, 63, -91, 24, + -87, -27, -80, -100, -44, 98, 70, 100, -29, -38, 11, 77, 100, 52, 86, + 65, -5, -42, -81, -38, -42, 43, -2, -70, -63, -52], ZZ) + p2 = dup_normal([65, -19, -47, 1, 90, 81, -15, -34, 25, -75, 9, -83, 50, -5, + -44, 31, 1, 70, -7, 78, 74, 80, 85, 65, 21, 41, 66, 19, -40, 63, -21, + -27, 32, 69, 83, 34, -35, 14, 81, 57, -75, 32, -67, -89, -100, -61, 46, + 84, -78, -29, -50, -94, -24, -32, -68, -16, 100, -7, -72, -89, 35, 82, + 58, 81, -92, 62, 5, -47, -39, -58, -72, -13, 84, 44, 55, -25, 48, -54, + -31, -56, -11, -50, -84, 10, 67, 17, 13, -14, 61, 76, -64, -44, -40, + -96, 11, -11, -94, 2, 6, 27, -6, 68, -54, 66, -74, -14, -1, -24, -73, + 96, 89, -11, -89, 56, -53, 72, -43, 96, 25, 63, -31, 29, 68, 83, 91, + -93, -19, -38, -40, 40, -12, -19, -79, 44, 100, -66, -29, -77, 62, 39, + -8, 11, -97, 14, 87, 64, 21, -18, 13, 15, -59, -75, -99, -88, 57, 54, + 56, -67, 6, -63, -59, -14, 28, 87, -20, -39, 84, -91, -2, 49, -75, 11, + -24, -95, 36, 66, 5, 25, -72, -40, 86, 90, 37, -33, 57, -35, 29, -18, + 4, -79, 64, -17, -27, 21, 29, -5, -44, -87, -24, 52, 78, 11, -23, -53, + 36, 42, 21, -68, 94, -91, -51, -21, 51, -76, 72, 31, 24, -48, -80, -9, + 37, -47, -6, -8, -63, -91, 79, -79, -100, 38, -20, 38, 100, 83, -90, + 87, 63, -36, 82, -19, 18, -98, -38, 26, 98, -70, 79, 92, 12, 12, 70, + 74, 36, 48, -13, 31, 31, -47, -71, -12, -64, 36, -42, 32, -86, 60, 83, + 70, 55, 0, 1, 29, -35, 8, -82, 8, -73, -46, -50, 43, 48, -5, -86, -72, + 44, -90, 19, 19, 5, -20, 97, -13, -66, -5, 5, -69, 64, -30, 41, 51, 36, + 13, -99, -61, 94, -12, 74, 98, 68, 24, 46, -97, -87, -6, -27, 82, 62, + -11, -77, 86, 66, -47, -49, -50, 13, 18, 89, -89, 46, -80, 13, 98, -35, + -36, -25, 12, 20, 26, -52, 79, 27, 79, 100, 8, 62, -58, -28, 37], ZZ) + res = dup_normal([5135, -1566, 1376, -7466, 4579, 11710, 8001, -7183, + -3737, -7439, 345, -10084, 24522, -1201, 1070, -10245, 9582, 9264, + 1903, 23312, 18953, 10037, -15268, -5450, 6442, -6243, -3777, 5110, + 10936, -16649, -6022, 16255, 31300, 24818, 31922, 32760, 7854, 27080, + 15766, 29596, 7139, 31945, -19810, 465, -38026, -3971, 9641, 465, + -19375, 5524, -30112, -11960, -12813, 13535, 30670, 5925, -43725, + -14089, 11503, -22782, 6371, 43881, 37465, -33529, -33590, -39798, + -37854, -18466, -7908, -35825, -26020, -36923, -11332, -5699, 25166, + -3147, 19885, 12962, -20659, -1642, 27723, -56331, -24580, -11010, + -20206, 20087, -23772, -16038, 38580, 20901, -50731, 32037, -4299, + 26508, 18038, -28357, 31846, -7405, -20172, -15894, 2096, 25110, + -45786, 45918, -55333, -31928, -49428, -29824, -58796, -24609, -15408, + 69, -35415, -18439, 10123, -20360, -65949, 33356, -20333, 26476, + -32073, 33621, 930, 28803, -42791, 44716, 38164, 12302, -1739, 11421, + 73385, -7613, 14297, 38155, -414, 77587, 24338, -21415, 29367, 42639, + 13901, -288, 51027, -11827, 91260, 43407, 88521, -15186, 70572, -12049, + 5090, -12208, -56374, 15520, -623, -7742, 50825, 11199, -14894, 40892, + 59591, -31356, -28696, -57842, -87751, -33744, -28436, -28945, -40287, + 37957, -35638, 33401, -61534, 14870, 40292, 70366, -10803, 102290, + -71719, -85251, 7902, -22409, 75009, 99927, 35298, -1175, -762, -34744, + -10587, -47574, -62629, -19581, -43659, -54369, -32250, -39545, 15225, + -24454, 11241, -67308, -30148, 39929, 37639, 14383, -73475, -77636, + -81048, -35992, 41601, -90143, 76937, -8112, 56588, 9124, -40094, + -32340, 13253, 10898, -51639, 36390, 12086, -1885, 100714, -28561, + -23784, -18735, 18916, 16286, 10742, -87360, -13697, 10689, -19477, + -29770, 5060, 20189, -8297, 112407, 47071, 47743, 45519, -4109, 17468, + -68831, 78325, -6481, -21641, -19459, 30919, 96115, 8607, 53341, 32105, + -16211, 23538, 57259, -76272, -40583, 62093, 38511, -34255, -40665, + -40604, -37606, -15274, 33156, -13885, 103636, 118678, -14101, -92682, + -100791, 2634, 63791, 98266, 19286, -34590, -21067, -71130, 25380, + -40839, -27614, -26060, 52358, -15537, 27138, -6749, 36269, -33306, + 13207, -91084, -5540, -57116, 69548, 44169, -57742, -41234, -103327, + -62904, -8566, 41149, -12866, 71188, 23980, 1838, 58230, 73950, 5594, + 43113, -8159, -15925, 6911, 85598, -75016, -16214, -62726, -39016, + 8618, -63882, -4299, 23182, 49959, 49342, -3238, -24913, -37138, 78361, + 32451, 6337, -11438, -36241, -37737, 8169, -3077, -24829, 57953, 53016, + -31511, -91168, 12599, -41849, 41576, 55275, -62539, 47814, -62319, + 12300, -32076, -55137, -84881, -27546, 4312, -3433, -54382, 113288, + -30157, 74469, 18219, 79880, -2124, 98911, 17655, -33499, -32861, + 47242, -37393, 99765, 14831, -44483, 10800, -31617, -52710, 37406, + 22105, 29704, -20050, 13778, 43683, 36628, 8494, 60964, -22644, 31550, + -17693, 33805, -124879, -12302, 19343, 20400, -30937, -21574, -34037, + -33380, 56539, -24993, -75513, -1527, 53563, 65407, -101, 53577, 37991, + 18717, -23795, -8090, -47987, -94717, 41967, 5170, -14815, -94311, + 17896, -17734, -57718, -774, -38410, 24830, 29682, 76480, 58802, + -46416, -20348, -61353, -68225, -68306, 23822, -31598, 42972, 36327, + 28968, -65638, -21638, 24354, -8356, 26777, 52982, -11783, -44051, + -26467, -44721, -28435, -53265, -25574, -2669, 44155, 22946, -18454, + -30718, -11252, 58420, 8711, 67447, 4425, 41749, 67543, 43162, 11793, + -41907, 20477, -13080, 6559, -6104, -13244, 42853, 42935, 29793, 36730, + -28087, 28657, 17946, 7503, 7204, 21491, -27450, -24241, -98156, + -18082, -42613, -24928, 10775, -14842, -44127, 55910, 14777, 31151, -2194, + 39206, -2100, -4211, 11827, -8918, -19471, 72567, 36447, -65590, -34861, + -17147, -45303, 9025, -7333, -35473, 11101, 11638, 3441, 6626, -41800, + 9416, 13679, 33508, 40502, -60542, 16358, 8392, -43242, -35864, -34127, + -48721, 35878, 30598, 28630, 20279, -19983, -14638, -24455, -1851, -11344, + 45150, 42051, 26034, -28889, -32382, -3527, -14532, 22564, -22346, 477, + 11706, 28338, -25972, -9185, -22867, -12522, 32120, -4424, 11339, -33913, + -7184, 5101, -23552, -17115, -31401, -6104, 21906, 25708, 8406, 6317, + -7525, 5014, 20750, 20179, 22724, 11692, 13297, 2493, -253, -16841, -17339, + -6753, -4808, 2976, -10881, -10228, -13816, -12686, 1385, 2316, 2190, -875, + -1924], ZZ) + + assert dup_mul(p1, p2, ZZ) == res + + p1 = dup_normal([83, -61, -86, -24, 12, 43, -88, -9, 42, 55, -66, 74, 95, + -25, -12, 68, -99, 4, 45, 6, -15, -19, 78, 65, -55, 47, -13, 17, 86, + 81, -58, -27, 50, -40, -24, 39, -41, -92, 75, 90, -1, 40, -15, -27, + -35, 68, 70, -64, -40, 78, -88, -58, -39, 69, 46, 12, 28, -94, -37, + -50, -80, -96, -61, 25, 1, 71, 4, 12, 48, 4, 34, -47, -75, 5, 48, 82, + 88, 23, 98, 35, 17, -10, 48, -61, -95, 47, 65, -19, -66, -57, -6, -51, + -42, -89, 66, -13, 18, 37, 90, -23, 72, 96, -53, 0, 40, -73, -52, -68, + 32, -25, -53, 79, -52, 18, 44, 73, -81, 31, -90, 70, 3, 36, 48, 76, + -24, -44, 23, 98, -4, 73, 69, 88, -70, 14, -68, 94, -78, -15, -64, -97, + -70, -35, 65, 88, 49, -53, -7, 12, -45, -7, 59, -94, 99, -2, 67, -60, + -71, 29, -62, -77, 1, 51, 17, 80, -20, -47, -19, 24, -9, 39, -23, 21, + -84, 10, 84, 56, -17, -21, -66, 85, 70, 46, -51, -22, -95, 78, -60, + -96, -97, -45, 72, 35, 30, -61, -92, -93, -60, -61, 4, -4, -81, -73, + 46, 53, -11, 26, 94, 45, 14, -78, 55, 84, -68, 98, 60, 23, 100, -63, + 68, 96, -16, 3, 56, 21, -58, 62, -67, 66, 85, 41, -79, -22, 97, -67, + 82, 82, -96, -20, -7, 48, -67, 48, -9, -39, 78], ZZ) + p2 = dup_normal([52, 88, 76, 66, 9, -64, 46, -20, -28, 69, 60, 96, -36, + -92, -30, -11, -35, 35, 55, 63, -92, -7, 25, -58, 74, 55, -6, 4, 47, + -92, -65, 67, -45, 74, -76, 59, -6, 69, 39, 24, -71, -7, 39, -45, 60, + -68, 98, 97, -79, 17, 4, 94, -64, 68, -100, -96, -2, 3, 22, 96, 54, + -77, -86, 67, 6, 57, 37, 40, 89, -78, 64, -94, -45, -92, 57, 87, -26, + 36, 19, 97, 25, 77, -87, 24, 43, -5, 35, 57, 83, 71, 35, 63, 61, 96, + -22, 8, -1, 96, 43, 45, 94, -93, 36, 71, -41, -99, 85, -48, 59, 52, + -17, 5, 87, -16, -68, -54, 76, -18, 100, 91, -42, -70, -66, -88, -12, + 1, 95, -82, 52, 43, -29, 3, 12, 72, -99, -43, -32, -93, -51, 16, -20, + -12, -11, 5, 33, -38, 93, -5, -74, 25, 74, -58, 93, 59, -63, -86, 63, + -20, -4, -74, -73, -95, 29, -28, 93, -91, -2, -38, -62, 77, -58, -85, + -28, 95, 38, 19, -69, 86, 94, 25, -2, -4, 47, 34, -59, 35, -48, 29, + -63, -53, 34, 29, 66, 73, 6, 92, -84, 89, 15, 81, 93, 97, 51, -72, -78, + 25, 60, 90, -45, 39, 67, -84, -62, 57, 26, -32, -56, -14, -83, 76, 5, + -2, 99, -100, 28, 46, 94, -7, 53, -25, 16, -23, -36, 89, -78, -63, 31, + 1, 84, -99, -52, 76, 48, 90, -76, 44, -19, 54, -36, -9, -73, -100, -69, + 31, 42, 25, -39, 76, -26, -8, -14, 51, 3, 37, 45, 2, -54, 13, -34, -92, + 17, -25, -65, 53, -63, 30, 4, -70, -67, 90, 52, 51, 18, -3, 31, -45, + -9, 59, 63, -87, 22, -32, 29, -38, 21, 36, -82, 27, -11], ZZ) + res = dup_normal([4316, 4132, -3532, -7974, -11303, -10069, 5484, -3330, + -5874, 7734, 4673, 11327, -9884, -8031, 17343, 21035, -10570, -9285, + 15893, 3780, -14083, 8819, 17592, 10159, 7174, -11587, 8598, -16479, + 3602, 25596, 9781, 12163, 150, 18749, -21782, -12307, 27578, -2757, + -12573, 12565, 6345, -18956, 19503, -15617, 1443, -16778, 36851, 23588, + -28474, 5749, 40695, -7521, -53669, -2497, -18530, 6770, 57038, 3926, + -6927, -15399, 1848, -64649, -27728, 3644, 49608, 15187, -8902, -9480, + -7398, -40425, 4824, 23767, -7594, -6905, 33089, 18786, 12192, 24670, + 31114, 35334, -4501, -14676, 7107, -59018, -21352, 20777, 19661, 20653, + 33754, -885, -43758, 6269, 51897, -28719, -97488, -9527, 13746, 11644, + 17644, -21720, 23782, -10481, 47867, 20752, 33810, -1875, 39918, -7710, + -40840, 19808, -47075, 23066, 46616, 25201, 9287, 35436, -1602, 9645, + -11978, 13273, 15544, 33465, 20063, 44539, 11687, 27314, -6538, -37467, + 14031, 32970, -27086, 41323, 29551, 65910, -39027, -37800, -22232, + 8212, 46316, -28981, -55282, 50417, -44929, -44062, 73879, 37573, + -2596, -10877, -21893, -133218, -33707, -25753, -9531, 17530, 61126, + 2748, -56235, 43874, -10872, -90459, -30387, 115267, -7264, -44452, + 122626, 14839, -599, 10337, 57166, -67467, -54957, 63669, 1202, 18488, + 52594, 7205, -97822, 612, 78069, -5403, -63562, 47236, 36873, -154827, + -26188, 82427, -39521, 5628, 7416, 5276, -53095, 47050, 26121, -42207, + 79021, -13035, 2499, -66943, 29040, -72355, -23480, 23416, -12885, + -44225, -42688, -4224, 19858, 55299, 15735, 11465, 101876, -39169, + 51786, 14723, 43280, -68697, 16410, 92295, 56767, 7183, 111850, 4550, + 115451, -38443, -19642, -35058, 10230, 93829, 8925, 63047, 3146, 29250, + 8530, 5255, -98117, -115517, -76817, -8724, 41044, 1312, -35974, 79333, + -28567, 7547, -10580, -24559, -16238, 10794, -3867, 24848, 57770, + -51536, -35040, 71033, 29853, 62029, -7125, -125585, -32169, -47907, + 156811, -65176, -58006, -15757, -57861, 11963, 30225, -41901, -41681, + 31310, 27982, 18613, 61760, 60746, -59096, 33499, 30097, -17997, 24032, + 56442, -83042, 23747, -20931, -21978, -158752, -9883, -73598, -7987, + -7333, -125403, -116329, 30585, 53281, 51018, -29193, 88575, 8264, + -40147, -16289, 113088, 12810, -6508, 101552, -13037, 34440, -41840, + 101643, 24263, 80532, 61748, 65574, 6423, -20672, 6591, -10834, -71716, + 86919, -92626, 39161, 28490, 81319, 46676, 106720, 43530, 26998, 57456, + -8862, 60989, 13982, 3119, -2224, 14743, 55415, -49093, -29303, 28999, + 1789, 55953, -84043, -7780, -65013, 57129, -47251, 61484, 61994, + -78361, -82778, 22487, -26894, 9756, -74637, -15519, -4360, 30115, + 42433, 35475, 15286, 69768, 21509, -20214, 78675, -21163, 13596, 11443, + -10698, -53621, -53867, -24155, 64500, -42784, -33077, -16500, 873, + -52788, 14546, -38011, 36974, -39849, -34029, -94311, 83068, -50437, + -26169, -46746, 59185, 42259, -101379, -12943, 30089, -59086, 36271, + 22723, -30253, -52472, -70826, -23289, 3331, -31687, 14183, -857, + -28627, 35246, -51284, 5636, -6933, 66539, 36654, 50927, 24783, 3457, + 33276, 45281, 45650, -4938, -9968, -22590, 47995, 69229, 5214, -58365, + -17907, -14651, 18668, 18009, 12649, -11851, -13387, 20339, 52472, + -1087, -21458, -68647, 52295, 15849, 40608, 15323, 25164, -29368, + 10352, -7055, 7159, 21695, -5373, -54849, 101103, -24963, -10511, + 33227, 7659, 41042, -69588, 26718, -20515, 6441, 38135, -63, 24088, + -35364, -12785, -18709, 47843, 48533, -48575, 17251, -19394, 32878, + -9010, -9050, 504, -12407, 28076, -3429, 25324, -4210, -26119, 752, + -29203, 28251, -11324, -32140, -3366, -25135, 18702, -31588, -7047, + -24267, 49987, -14975, -33169, 37744, -7720, -9035, 16964, -2807, -421, + 14114, -17097, -13662, 40628, -12139, -9427, 5369, 17551, -13232, -16211, + 9804, -7422, 2677, 28635, -8280, -4906, 2908, -22558, 5604, 12459, 8756, + -3980, -4745, -18525, 7913, 5970, -16457, 20230, -6247, -13812, 2505, + 11899, 1409, -15094, 22540, -18863, 137, 11123, -4516, 2290, -8594, 12150, + -10380, 3005, 5235, -7350, 2535, -858], ZZ) + + assert dup_mul(p1, p2, ZZ) == res + + +def test_dmp_mul(): + assert dmp_mul([ZZ(5)], [ZZ(7)], 0, ZZ) == \ + dup_mul([ZZ(5)], [ZZ(7)], ZZ) + assert dmp_mul([QQ(5, 7)], [QQ(3, 7)], 0, QQ) == \ + dup_mul([QQ(5, 7)], [QQ(3, 7)], QQ) + + assert dmp_mul([[[]]], [[[]]], 2, ZZ) == [[[]]] + assert dmp_mul([[[ZZ(1)]]], [[[]]], 2, ZZ) == [[[]]] + assert dmp_mul([[[]]], [[[ZZ(1)]]], 2, ZZ) == [[[]]] + assert dmp_mul([[[ZZ(2)]]], [[[ZZ(1)]]], 2, ZZ) == [[[ZZ(2)]]] + assert dmp_mul([[[ZZ(1)]]], [[[ZZ(2)]]], 2, ZZ) == [[[ZZ(2)]]] + + assert dmp_mul([[[]]], [[[]]], 2, QQ) == [[[]]] + assert dmp_mul([[[QQ(1, 2)]]], [[[]]], 2, QQ) == [[[]]] + assert dmp_mul([[[]]], [[[QQ(1, 2)]]], 2, QQ) == [[[]]] + assert dmp_mul([[[QQ(2, 7)]]], [[[QQ(1, 3)]]], 2, QQ) == [[[QQ(2, 21)]]] + assert dmp_mul([[[QQ(1, 7)]]], [[[QQ(2, 3)]]], 2, QQ) == [[[QQ(2, 21)]]] + + K = FF(6) + + assert dmp_mul( + [[K(2)], [K(1)]], [[K(3)], [K(4)]], 1, K) == [[K(5)], [K(4)]] + + +def test_dup_sqr(): + assert dup_sqr([], ZZ) == [] + assert dup_sqr([ZZ(2)], ZZ) == [ZZ(4)] + assert dup_sqr([ZZ(1), ZZ(2)], ZZ) == [ZZ(1), ZZ(4), ZZ(4)] + + assert dup_sqr([], QQ) == [] + assert dup_sqr([QQ(2, 3)], QQ) == [QQ(4, 9)] + assert dup_sqr([QQ(1, 3), QQ(2, 3)], QQ) == [QQ(1, 9), QQ(4, 9), QQ(4, 9)] + + f = dup_normal([2, 0, 0, 1, 7], ZZ) + + assert dup_sqr(f, ZZ) == dup_normal([4, 0, 0, 4, 28, 0, 1, 14, 49], ZZ) + + K = FF(9) + + assert dup_sqr([K(3), K(4)], K) == [K(6), K(7)] + + +def test_dmp_sqr(): + assert dmp_sqr([ZZ(1), ZZ(2)], 0, ZZ) == \ + dup_sqr([ZZ(1), ZZ(2)], ZZ) + + assert dmp_sqr([[[]]], 2, ZZ) == [[[]]] + assert dmp_sqr([[[ZZ(2)]]], 2, ZZ) == [[[ZZ(4)]]] + + assert dmp_sqr([[[]]], 2, QQ) == [[[]]] + assert dmp_sqr([[[QQ(2, 3)]]], 2, QQ) == [[[QQ(4, 9)]]] + + K = FF(9) + + assert dmp_sqr([[K(3)], [K(4)]], 1, K) == [[K(6)], [K(7)]] + + +def test_dup_pow(): + assert dup_pow([], 0, ZZ) == [ZZ(1)] + assert dup_pow([], 0, QQ) == [QQ(1)] + + assert dup_pow([], 1, ZZ) == [] + assert dup_pow([], 7, ZZ) == [] + + assert dup_pow([ZZ(1)], 0, ZZ) == [ZZ(1)] + assert dup_pow([ZZ(1)], 1, ZZ) == [ZZ(1)] + assert dup_pow([ZZ(1)], 7, ZZ) == [ZZ(1)] + + assert dup_pow([ZZ(3)], 0, ZZ) == [ZZ(1)] + assert dup_pow([ZZ(3)], 1, ZZ) == [ZZ(3)] + assert dup_pow([ZZ(3)], 7, ZZ) == [ZZ(2187)] + + assert dup_pow([QQ(1, 1)], 0, QQ) == [QQ(1, 1)] + assert dup_pow([QQ(1, 1)], 1, QQ) == [QQ(1, 1)] + assert dup_pow([QQ(1, 1)], 7, QQ) == [QQ(1, 1)] + + assert dup_pow([QQ(3, 7)], 0, QQ) == [QQ(1, 1)] + assert dup_pow([QQ(3, 7)], 1, QQ) == [QQ(3, 7)] + assert dup_pow([QQ(3, 7)], 7, QQ) == [QQ(2187, 823543)] + + f = dup_normal([2, 0, 0, 1, 7], ZZ) + + assert dup_pow(f, 0, ZZ) == dup_normal([1], ZZ) + assert dup_pow(f, 1, ZZ) == dup_normal([2, 0, 0, 1, 7], ZZ) + assert dup_pow(f, 2, ZZ) == dup_normal([4, 0, 0, 4, 28, 0, 1, 14, 49], ZZ) + assert dup_pow(f, 3, ZZ) == dup_normal( + [8, 0, 0, 12, 84, 0, 6, 84, 294, 1, 21, 147, 343], ZZ) + + +def test_dmp_pow(): + assert dmp_pow([[]], 0, 1, ZZ) == [[ZZ(1)]] + assert dmp_pow([[]], 0, 1, QQ) == [[QQ(1)]] + + assert dmp_pow([[]], 1, 1, ZZ) == [[]] + assert dmp_pow([[]], 7, 1, ZZ) == [[]] + + assert dmp_pow([[ZZ(1)]], 0, 1, ZZ) == [[ZZ(1)]] + assert dmp_pow([[ZZ(1)]], 1, 1, ZZ) == [[ZZ(1)]] + assert dmp_pow([[ZZ(1)]], 7, 1, ZZ) == [[ZZ(1)]] + + assert dmp_pow([[QQ(3, 7)]], 0, 1, QQ) == [[QQ(1, 1)]] + assert dmp_pow([[QQ(3, 7)]], 1, 1, QQ) == [[QQ(3, 7)]] + assert dmp_pow([[QQ(3, 7)]], 7, 1, QQ) == [[QQ(2187, 823543)]] + + f = dup_normal([2, 0, 0, 1, 7], ZZ) + + assert dmp_pow(f, 2, 0, ZZ) == dup_pow(f, 2, ZZ) + + +def test_dup_pdiv(): + f = dup_normal([3, 1, 1, 5], ZZ) + g = dup_normal([5, -3, 1], ZZ) + + q = dup_normal([15, 14], ZZ) + r = dup_normal([52, 111], ZZ) + + assert dup_pdiv(f, g, ZZ) == (q, r) + assert dup_pquo(f, g, ZZ) == q + assert dup_prem(f, g, ZZ) == r + + raises(ExactQuotientFailed, lambda: dup_pexquo(f, g, ZZ)) + + f = dup_normal([3, 1, 1, 5], QQ) + g = dup_normal([5, -3, 1], QQ) + + q = dup_normal([15, 14], QQ) + r = dup_normal([52, 111], QQ) + + assert dup_pdiv(f, g, QQ) == (q, r) + assert dup_pquo(f, g, QQ) == q + assert dup_prem(f, g, QQ) == r + + raises(ExactQuotientFailed, lambda: dup_pexquo(f, g, QQ)) + + +def test_dmp_pdiv(): + f = dmp_normal([[1], [], [1, 0, 0]], 1, ZZ) + g = dmp_normal([[1], [-1, 0]], 1, ZZ) + + q = dmp_normal([[1], [1, 0]], 1, ZZ) + r = dmp_normal([[2, 0, 0]], 1, ZZ) + + assert dmp_pdiv(f, g, 1, ZZ) == (q, r) + assert dmp_pquo(f, g, 1, ZZ) == q + assert dmp_prem(f, g, 1, ZZ) == r + + raises(ExactQuotientFailed, lambda: dmp_pexquo(f, g, 1, ZZ)) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, ZZ) + g = dmp_normal([[2], [-2, 0]], 1, ZZ) + + q = dmp_normal([[2], [2, 0]], 1, ZZ) + r = dmp_normal([[8, 0, 0]], 1, ZZ) + + assert dmp_pdiv(f, g, 1, ZZ) == (q, r) + assert dmp_pquo(f, g, 1, ZZ) == q + assert dmp_prem(f, g, 1, ZZ) == r + + raises(ExactQuotientFailed, lambda: dmp_pexquo(f, g, 1, ZZ)) + + +def test_dup_rr_div(): + raises(ZeroDivisionError, lambda: dup_rr_div([1, 2, 3], [], ZZ)) + + f = dup_normal([3, 1, 1, 5], ZZ) + g = dup_normal([5, -3, 1], ZZ) + + q, r = [], f + + assert dup_rr_div(f, g, ZZ) == (q, r) + + +def test_dmp_rr_div(): + raises(ZeroDivisionError, lambda: dmp_rr_div([[1, 2], [3]], [[]], 1, ZZ)) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, ZZ) + g = dmp_normal([[1], [-1, 0]], 1, ZZ) + + q = dmp_normal([[1], [1, 0]], 1, ZZ) + r = dmp_normal([[2, 0, 0]], 1, ZZ) + + assert dmp_rr_div(f, g, 1, ZZ) == (q, r) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, ZZ) + g = dmp_normal([[-1], [1, 0]], 1, ZZ) + + q = dmp_normal([[-1], [-1, 0]], 1, ZZ) + r = dmp_normal([[2, 0, 0]], 1, ZZ) + + assert dmp_rr_div(f, g, 1, ZZ) == (q, r) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, ZZ) + g = dmp_normal([[2], [-2, 0]], 1, ZZ) + + q, r = [[]], f + + assert dmp_rr_div(f, g, 1, ZZ) == (q, r) + + +def test_dup_ff_div(): + raises(ZeroDivisionError, lambda: dup_ff_div([1, 2, 3], [], QQ)) + + f = dup_normal([3, 1, 1, 5], QQ) + g = dup_normal([5, -3, 1], QQ) + + q = [QQ(3, 5), QQ(14, 25)] + r = [QQ(52, 25), QQ(111, 25)] + + assert dup_ff_div(f, g, QQ) == (q, r) + +def test_dup_ff_div_gmpy2(): + if GROUND_TYPES != 'gmpy2': + return + + from gmpy2 import mpq + from sympy.polys.domains import GMPYRationalField + K = GMPYRationalField() + + f = [mpq(1,3), mpq(3,2)] + g = [mpq(2,1)] + assert dmp_ff_div(f, g, 0, K) == ([mpq(1,6), mpq(3,4)], []) + + f = [mpq(1,2), mpq(1,3), mpq(1,4), mpq(1,5)] + g = [mpq(-1,1), mpq(1,1), mpq(-1,1)] + assert dmp_ff_div(f, g, 0, K) == ([mpq(-1,2), mpq(-5,6)], [mpq(7,12), mpq(-19,30)]) + +def test_dmp_ff_div(): + raises(ZeroDivisionError, lambda: dmp_ff_div([[1, 2], [3]], [[]], 1, QQ)) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, QQ) + g = dmp_normal([[1], [-1, 0]], 1, QQ) + + q = [[QQ(1, 1)], [QQ(1, 1), QQ(0, 1)]] + r = [[QQ(2, 1), QQ(0, 1), QQ(0, 1)]] + + assert dmp_ff_div(f, g, 1, QQ) == (q, r) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, QQ) + g = dmp_normal([[-1], [1, 0]], 1, QQ) + + q = [[QQ(-1, 1)], [QQ(-1, 1), QQ(0, 1)]] + r = [[QQ(2, 1), QQ(0, 1), QQ(0, 1)]] + + assert dmp_ff_div(f, g, 1, QQ) == (q, r) + + f = dmp_normal([[1], [], [1, 0, 0]], 1, QQ) + g = dmp_normal([[2], [-2, 0]], 1, QQ) + + q = [[QQ(1, 2)], [QQ(1, 2), QQ(0, 1)]] + r = [[QQ(2, 1), QQ(0, 1), QQ(0, 1)]] + + assert dmp_ff_div(f, g, 1, QQ) == (q, r) + + +def test_dup_div(): + f, g, q, r = [5, 4, 3, 2, 1], [1, 2, 3], [5, -6, 0], [20, 1] + + assert dup_div(f, g, ZZ) == (q, r) + assert dup_quo(f, g, ZZ) == q + assert dup_rem(f, g, ZZ) == r + + raises(ExactQuotientFailed, lambda: dup_exquo(f, g, ZZ)) + + f, g, q, r = [5, 4, 3, 2, 1, 0], [1, 2, 0, 0, 9], [5, -6], [15, 2, -44, 54] + + assert dup_div(f, g, ZZ) == (q, r) + assert dup_quo(f, g, ZZ) == q + assert dup_rem(f, g, ZZ) == r + + raises(ExactQuotientFailed, lambda: dup_exquo(f, g, ZZ)) + + +def test_dmp_div(): + f, g, q, r = [5, 4, 3, 2, 1], [1, 2, 3], [5, -6, 0], [20, 1] + + assert dmp_div(f, g, 0, ZZ) == (q, r) + assert dmp_quo(f, g, 0, ZZ) == q + assert dmp_rem(f, g, 0, ZZ) == r + + raises(ExactQuotientFailed, lambda: dmp_exquo(f, g, 0, ZZ)) + + f, g, q, r = [[[1]]], [[[2]], [1]], [[[]]], [[[1]]] + + assert dmp_div(f, g, 2, ZZ) == (q, r) + assert dmp_quo(f, g, 2, ZZ) == q + assert dmp_rem(f, g, 2, ZZ) == r + + raises(ExactQuotientFailed, lambda: dmp_exquo(f, g, 2, ZZ)) + + +def test_dup_max_norm(): + assert dup_max_norm([], ZZ) == 0 + assert dup_max_norm([1], ZZ) == 1 + + assert dup_max_norm([1, 4, 2, 3], ZZ) == 4 + + +def test_dmp_max_norm(): + assert dmp_max_norm([[[]]], 2, ZZ) == 0 + assert dmp_max_norm([[[1]]], 2, ZZ) == 1 + + assert dmp_max_norm(f_0, 2, ZZ) == 6 + + +def test_dup_l1_norm(): + assert dup_l1_norm([], ZZ) == 0 + assert dup_l1_norm([1], ZZ) == 1 + assert dup_l1_norm([1, 4, 2, 3], ZZ) == 10 + + +def test_dmp_l1_norm(): + assert dmp_l1_norm([[[]]], 2, ZZ) == 0 + assert dmp_l1_norm([[[1]]], 2, ZZ) == 1 + + assert dmp_l1_norm(f_0, 2, ZZ) == 31 + + +def test_dup_l2_norm_squared(): + assert dup_l2_norm_squared([], ZZ) == 0 + assert dup_l2_norm_squared([1], ZZ) == 1 + assert dup_l2_norm_squared([1, 4, 2, 3], ZZ) == 30 + + +def test_dmp_l2_norm_squared(): + assert dmp_l2_norm_squared([[[]]], 2, ZZ) == 0 + assert dmp_l2_norm_squared([[[1]]], 2, ZZ) == 1 + assert dmp_l2_norm_squared(f_0, 2, ZZ) == 111 + + +def test_dup_expand(): + assert dup_expand((), ZZ) == [1] + assert dup_expand(([1, 2, 3], [1, 2], [7, 5, 4, 3]), ZZ) == \ + dup_mul([1, 2, 3], dup_mul([1, 2], [7, 5, 4, 3], ZZ), ZZ) + + +def test_dmp_expand(): + assert dmp_expand((), 1, ZZ) == [[1]] + assert dmp_expand(([[1], [2], [3]], [[1], [2]], [[7], [5], [4], [3]]), 1, ZZ) == \ + dmp_mul([[1], [2], [3]], dmp_mul([[1], [2]], [[7], [5], [ + 4], [3]], 1, ZZ), 1, ZZ) + +def test_dup_mul_poly(): + p = Poly(18786186952704.0*x**165 + 9.31746684052255e+31*x**82, x, domain='RR') + px = Poly(18786186952704.0*x**166 + 9.31746684052255e+31*x**83, x, domain='RR') + + assert p * x == px + assert p.set_domain(QQ) * x == px.set_domain(QQ) + assert p.set_domain(CC) * x == px.set_domain(CC) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_densebasic.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_densebasic.py new file mode 100644 index 0000000000000000000000000000000000000000..43386d86d0e6ec7b20d3962d8063aa6402165f9a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_densebasic.py @@ -0,0 +1,730 @@ +"""Tests for dense recursive polynomials' basic tools. """ + +from sympy.polys.densebasic import ( + ninf, + dup_LC, dmp_LC, + dup_TC, dmp_TC, + dmp_ground_LC, dmp_ground_TC, + dmp_true_LT, + dup_degree, dmp_degree, + dmp_degree_in, dmp_degree_list, + dup_strip, dmp_strip, + dmp_validate, + dup_reverse, + dup_copy, dmp_copy, + dup_normal, dmp_normal, + dup_convert, dmp_convert, + dup_from_sympy, dmp_from_sympy, + dup_nth, dmp_nth, dmp_ground_nth, + dmp_zero_p, dmp_zero, + dmp_one_p, dmp_one, + dmp_ground_p, dmp_ground, + dmp_negative_p, dmp_positive_p, + dmp_zeros, dmp_grounds, + dup_from_dict, dup_from_raw_dict, + dup_to_dict, dup_to_raw_dict, + dmp_from_dict, dmp_to_dict, + dmp_swap, dmp_permute, + dmp_nest, dmp_raise, + dup_deflate, dmp_deflate, + dup_multi_deflate, dmp_multi_deflate, + dup_inflate, dmp_inflate, + dmp_exclude, dmp_include, + dmp_inject, dmp_eject, + dup_terms_gcd, dmp_terms_gcd, + dmp_list_terms, dmp_apply_pairs, + dup_slice, + dup_random, +) + +from sympy.polys.specialpolys import f_polys +from sympy.polys.domains import ZZ, QQ +from sympy.polys.rings import ring + +from sympy.core.singleton import S +from sympy.testing.pytest import raises + +from sympy.core.numbers import oo + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = [ f.to_dense() for f in f_polys() ] + +def test_dup_LC(): + assert dup_LC([], ZZ) == 0 + assert dup_LC([2, 3, 4, 5], ZZ) == 2 + + +def test_dup_TC(): + assert dup_TC([], ZZ) == 0 + assert dup_TC([2, 3, 4, 5], ZZ) == 5 + + +def test_dmp_LC(): + assert dmp_LC([[]], ZZ) == [] + assert dmp_LC([[2, 3, 4], [5]], ZZ) == [2, 3, 4] + assert dmp_LC([[[]]], ZZ) == [[]] + assert dmp_LC([[[2], [3, 4]], [[5]]], ZZ) == [[2], [3, 4]] + + +def test_dmp_TC(): + assert dmp_TC([[]], ZZ) == [] + assert dmp_TC([[2, 3, 4], [5]], ZZ) == [5] + assert dmp_TC([[[]]], ZZ) == [[]] + assert dmp_TC([[[2], [3, 4]], [[5]]], ZZ) == [[5]] + + +def test_dmp_ground_LC(): + assert dmp_ground_LC([[]], 1, ZZ) == 0 + assert dmp_ground_LC([[2, 3, 4], [5]], 1, ZZ) == 2 + assert dmp_ground_LC([[[]]], 2, ZZ) == 0 + assert dmp_ground_LC([[[2], [3, 4]], [[5]]], 2, ZZ) == 2 + + +def test_dmp_ground_TC(): + assert dmp_ground_TC([[]], 1, ZZ) == 0 + assert dmp_ground_TC([[2, 3, 4], [5]], 1, ZZ) == 5 + assert dmp_ground_TC([[[]]], 2, ZZ) == 0 + assert dmp_ground_TC([[[2], [3, 4]], [[5]]], 2, ZZ) == 5 + + +def test_dmp_true_LT(): + assert dmp_true_LT([[]], 1, ZZ) == ((0, 0), 0) + assert dmp_true_LT([[7]], 1, ZZ) == ((0, 0), 7) + + assert dmp_true_LT([[1, 0]], 1, ZZ) == ((0, 1), 1) + assert dmp_true_LT([[1], []], 1, ZZ) == ((1, 0), 1) + assert dmp_true_LT([[1, 0], []], 1, ZZ) == ((1, 1), 1) + + +def test_dup_degree(): + assert ninf == float('-inf') + assert dup_degree([]) is ninf + assert dup_degree([1]) == 0 + assert dup_degree([1, 0]) == 1 + assert dup_degree([1, 0, 0, 0, 1]) == 4 + + +def test_dmp_degree(): + assert dmp_degree([[]], 1) is ninf + assert dmp_degree([[[]]], 2) is ninf + + assert dmp_degree([[1]], 1) == 0 + assert dmp_degree([[2], [1]], 1) == 1 + + +def test_dmp_degree_in(): + assert dmp_degree_in([[[]]], 0, 2) is ninf + assert dmp_degree_in([[[]]], 1, 2) is ninf + assert dmp_degree_in([[[]]], 2, 2) is ninf + + assert dmp_degree_in([[[1]]], 0, 2) == 0 + assert dmp_degree_in([[[1]]], 1, 2) == 0 + assert dmp_degree_in([[[1]]], 2, 2) == 0 + + assert dmp_degree_in(f_4, 0, 2) == 9 + assert dmp_degree_in(f_4, 1, 2) == 12 + assert dmp_degree_in(f_4, 2, 2) == 8 + + assert dmp_degree_in(f_6, 0, 2) == 4 + assert dmp_degree_in(f_6, 1, 2) == 4 + assert dmp_degree_in(f_6, 2, 2) == 6 + assert dmp_degree_in(f_6, 3, 3) == 3 + + raises(IndexError, lambda: dmp_degree_in([[1]], -5, 1)) + + +def test_dmp_degree_list(): + assert dmp_degree_list([[[[ ]]]], 3) == (-oo, -oo, -oo, -oo) + assert dmp_degree_list([[[[1]]]], 3) == ( 0, 0, 0, 0) + + assert dmp_degree_list(f_0, 2) == (2, 2, 2) + assert dmp_degree_list(f_1, 2) == (3, 3, 3) + assert dmp_degree_list(f_2, 2) == (5, 3, 3) + assert dmp_degree_list(f_3, 2) == (5, 4, 7) + assert dmp_degree_list(f_4, 2) == (9, 12, 8) + assert dmp_degree_list(f_5, 2) == (3, 3, 3) + assert dmp_degree_list(f_6, 3) == (4, 4, 6, 3) + + +def test_dup_strip(): + assert dup_strip([]) == [] + assert dup_strip([0]) == [] + assert dup_strip([0, 0, 0]) == [] + + assert dup_strip([1]) == [1] + assert dup_strip([0, 1]) == [1] + assert dup_strip([0, 0, 0, 1]) == [1] + + assert dup_strip([1, 2, 0]) == [1, 2, 0] + assert dup_strip([0, 1, 2, 0]) == [1, 2, 0] + assert dup_strip([0, 0, 0, 1, 2, 0]) == [1, 2, 0] + + +def test_dmp_strip(): + assert dmp_strip([0, 1, 0], 0) == [1, 0] + + assert dmp_strip([[]], 1) == [[]] + assert dmp_strip([[], []], 1) == [[]] + assert dmp_strip([[], [], []], 1) == [[]] + + assert dmp_strip([[[]]], 2) == [[[]]] + assert dmp_strip([[[]], [[]]], 2) == [[[]]] + assert dmp_strip([[[]], [[]], [[]]], 2) == [[[]]] + + assert dmp_strip([[[1]]], 2) == [[[1]]] + assert dmp_strip([[[]], [[1]]], 2) == [[[1]]] + assert dmp_strip([[[]], [[1]], [[]]], 2) == [[[1]], [[]]] + + +def test_dmp_validate(): + assert dmp_validate([]) == ([], 0) + assert dmp_validate([0, 0, 0, 1, 0]) == ([1, 0], 0) + + assert dmp_validate([[[]]]) == ([[[]]], 2) + assert dmp_validate([[0], [], [0], [1], [0]]) == ([[1], []], 1) + + raises(ValueError, lambda: dmp_validate([[0], 0, [0], [1], [0]])) + + +def test_dup_reverse(): + assert dup_reverse([1, 2, 0, 3]) == [3, 0, 2, 1] + assert dup_reverse([1, 2, 3, 0]) == [3, 2, 1] + + +def test_dup_copy(): + f = [ZZ(1), ZZ(0), ZZ(2)] + g = dup_copy(f) + + g[0], g[2] = ZZ(7), ZZ(0) + + assert f != g + + +def test_dmp_copy(): + f = [[ZZ(1)], [ZZ(2), ZZ(0)]] + g = dmp_copy(f, 1) + + g[0][0], g[1][1] = ZZ(7), ZZ(1) + + assert f != g + + +def test_dup_normal(): + assert dup_normal([0, 0, 2, 1, 0, 11, 0], ZZ) == \ + [ZZ(2), ZZ(1), ZZ(0), ZZ(11), ZZ(0)] + + +def test_dmp_normal(): + assert dmp_normal([[0], [], [0, 2, 1], [0], [11], []], 1, ZZ) == \ + [[ZZ(2), ZZ(1)], [], [ZZ(11)], []] + + +def test_dup_convert(): + K0, K1 = ZZ['x'], ZZ + + f = [K0(1), K0(2), K0(0), K0(3)] + + assert dup_convert(f, K0, K1) == \ + [ZZ(1), ZZ(2), ZZ(0), ZZ(3)] + + +def test_dmp_convert(): + K0, K1 = ZZ['x'], ZZ + + f = [[K0(1)], [K0(2)], [], [K0(3)]] + + assert dmp_convert(f, 1, K0, K1) == \ + [[ZZ(1)], [ZZ(2)], [], [ZZ(3)]] + + +def test_dup_from_sympy(): + assert dup_from_sympy([S.One, S(2)], ZZ) == \ + [ZZ(1), ZZ(2)] + assert dup_from_sympy([S.Half, S(3)], QQ) == \ + [QQ(1, 2), QQ(3, 1)] + + +def test_dmp_from_sympy(): + assert dmp_from_sympy([[S.One, S(2)], [S.Zero]], 1, ZZ) == \ + [[ZZ(1), ZZ(2)], []] + assert dmp_from_sympy([[S.Half, S(2)]], 1, QQ) == \ + [[QQ(1, 2), QQ(2, 1)]] + + +def test_dup_nth(): + assert dup_nth([1, 2, 3], 0, ZZ) == 3 + assert dup_nth([1, 2, 3], 1, ZZ) == 2 + assert dup_nth([1, 2, 3], 2, ZZ) == 1 + + assert dup_nth([1, 2, 3], 9, ZZ) == 0 + + raises(IndexError, lambda: dup_nth([3, 4, 5], -1, ZZ)) + + +def test_dmp_nth(): + assert dmp_nth([[1], [2], [3]], 0, 1, ZZ) == [3] + assert dmp_nth([[1], [2], [3]], 1, 1, ZZ) == [2] + assert dmp_nth([[1], [2], [3]], 2, 1, ZZ) == [1] + + assert dmp_nth([[1], [2], [3]], 9, 1, ZZ) == [] + + raises(IndexError, lambda: dmp_nth([[3], [4], [5]], -1, 1, ZZ)) + + +def test_dmp_ground_nth(): + assert dmp_ground_nth([[]], (0, 0), 1, ZZ) == 0 + assert dmp_ground_nth([[1], [2], [3]], (0, 0), 1, ZZ) == 3 + assert dmp_ground_nth([[1], [2], [3]], (1, 0), 1, ZZ) == 2 + assert dmp_ground_nth([[1], [2], [3]], (2, 0), 1, ZZ) == 1 + + assert dmp_ground_nth([[1], [2], [3]], (2, 1), 1, ZZ) == 0 + assert dmp_ground_nth([[1], [2], [3]], (3, 0), 1, ZZ) == 0 + + raises(IndexError, lambda: dmp_ground_nth([[3], [4], [5]], (2, -1), 1, ZZ)) + + +def test_dmp_zero_p(): + assert dmp_zero_p([], 0) is True + assert dmp_zero_p([[]], 1) is True + + assert dmp_zero_p([[[]]], 2) is True + assert dmp_zero_p([[[1]]], 2) is False + + +def test_dmp_zero(): + assert dmp_zero(0) == [] + assert dmp_zero(2) == [[[]]] + + +def test_dmp_one_p(): + assert dmp_one_p([1], 0, ZZ) is True + assert dmp_one_p([[1]], 1, ZZ) is True + assert dmp_one_p([[[1]]], 2, ZZ) is True + assert dmp_one_p([[[12]]], 2, ZZ) is False + + +def test_dmp_one(): + assert dmp_one(0, ZZ) == [ZZ(1)] + assert dmp_one(2, ZZ) == [[[ZZ(1)]]] + + +def test_dmp_ground_p(): + assert dmp_ground_p([], 0, 0) is True + assert dmp_ground_p([[]], 0, 1) is True + assert dmp_ground_p([[]], 1, 1) is False + + assert dmp_ground_p([[ZZ(1)]], 1, 1) is True + assert dmp_ground_p([[[ZZ(2)]]], 2, 2) is True + + assert dmp_ground_p([[[ZZ(2)]]], 3, 2) is False + assert dmp_ground_p([[[ZZ(3)], []]], 3, 2) is False + + assert dmp_ground_p([], None, 0) is True + assert dmp_ground_p([[]], None, 1) is True + + assert dmp_ground_p([ZZ(1)], None, 0) is True + assert dmp_ground_p([[[ZZ(1)]]], None, 2) is True + + assert dmp_ground_p([[[ZZ(3)], []]], None, 2) is False + + +def test_dmp_ground(): + assert dmp_ground(ZZ(0), 2) == [[[]]] + + assert dmp_ground(ZZ(7), -1) == ZZ(7) + assert dmp_ground(ZZ(7), 0) == [ZZ(7)] + assert dmp_ground(ZZ(7), 2) == [[[ZZ(7)]]] + + +def test_dmp_zeros(): + assert dmp_zeros(4, 0, ZZ) == [[], [], [], []] + + assert dmp_zeros(0, 2, ZZ) == [] + assert dmp_zeros(1, 2, ZZ) == [[[[]]]] + assert dmp_zeros(2, 2, ZZ) == [[[[]]], [[[]]]] + assert dmp_zeros(3, 2, ZZ) == [[[[]]], [[[]]], [[[]]]] + + assert dmp_zeros(3, -1, ZZ) == [0, 0, 0] + + +def test_dmp_grounds(): + assert dmp_grounds(ZZ(7), 0, 2) == [] + + assert dmp_grounds(ZZ(7), 1, 2) == [[[[7]]]] + assert dmp_grounds(ZZ(7), 2, 2) == [[[[7]]], [[[7]]]] + assert dmp_grounds(ZZ(7), 3, 2) == [[[[7]]], [[[7]]], [[[7]]]] + + assert dmp_grounds(ZZ(7), 3, -1) == [7, 7, 7] + + +def test_dmp_negative_p(): + assert dmp_negative_p([[[]]], 2, ZZ) is False + assert dmp_negative_p([[[1], [2]]], 2, ZZ) is False + assert dmp_negative_p([[[-1], [2]]], 2, ZZ) is True + + +def test_dmp_positive_p(): + assert dmp_positive_p([[[]]], 2, ZZ) is False + assert dmp_positive_p([[[1], [2]]], 2, ZZ) is True + assert dmp_positive_p([[[-1], [2]]], 2, ZZ) is False + + +def test_dup_from_to_dict(): + assert dup_from_raw_dict({}, ZZ) == [] + assert dup_from_dict({}, ZZ) == [] + + assert dup_to_raw_dict([]) == {} + assert dup_to_dict([]) == {} + + assert dup_to_raw_dict([], ZZ, zero=True) == {0: ZZ(0)} + assert dup_to_dict([], ZZ, zero=True) == {(0,): ZZ(0)} + + f = [3, 0, 0, 2, 0, 0, 0, 0, 8] + g = {8: 3, 5: 2, 0: 8} + h = {(8,): 3, (5,): 2, (0,): 8} + + assert dup_from_raw_dict(g, ZZ) == f + assert dup_from_dict(h, ZZ) == f + + assert dup_to_raw_dict(f) == g + assert dup_to_dict(f) == h + + R, x,y = ring("x,y", ZZ) + K = R.to_domain() + + f = [R(3), R(0), R(2), R(0), R(0), R(8)] + g = {5: R(3), 3: R(2), 0: R(8)} + h = {(5,): R(3), (3,): R(2), (0,): R(8)} + + assert dup_from_raw_dict(g, K) == f + assert dup_from_dict(h, K) == f + + assert dup_to_raw_dict(f) == g + assert dup_to_dict(f) == h + + +def test_dmp_from_to_dict(): + assert dmp_from_dict({}, 1, ZZ) == [[]] + assert dmp_to_dict([[]], 1) == {} + + assert dmp_to_dict([], 0, ZZ, zero=True) == {(0,): ZZ(0)} + assert dmp_to_dict([[]], 1, ZZ, zero=True) == {(0, 0): ZZ(0)} + + f = [[3], [], [], [2], [], [], [], [], [8]] + g = {(8, 0): 3, (5, 0): 2, (0, 0): 8} + + assert dmp_from_dict(g, 1, ZZ) == f + assert dmp_to_dict(f, 1) == g + + +def test_dmp_swap(): + f = dmp_normal([[1, 0, 0], [], [1, 0], [], [1]], 1, ZZ) + g = dmp_normal([[1, 0, 0, 0, 0], [1, 0, 0], [1]], 1, ZZ) + + assert dmp_swap(f, 1, 1, 1, ZZ) == f + + assert dmp_swap(f, 0, 1, 1, ZZ) == g + assert dmp_swap(g, 0, 1, 1, ZZ) == f + + raises(IndexError, lambda: dmp_swap(f, -1, -7, 1, ZZ)) + + +def test_dmp_permute(): + f = dmp_normal([[1, 0, 0], [], [1, 0], [], [1]], 1, ZZ) + g = dmp_normal([[1, 0, 0, 0, 0], [1, 0, 0], [1]], 1, ZZ) + + assert dmp_permute(f, [0, 1], 1, ZZ) == f + assert dmp_permute(g, [0, 1], 1, ZZ) == g + + assert dmp_permute(f, [1, 0], 1, ZZ) == g + assert dmp_permute(g, [1, 0], 1, ZZ) == f + + +def test_dmp_nest(): + assert dmp_nest(ZZ(1), 2, ZZ) == [[[1]]] + + assert dmp_nest([[1]], 0, ZZ) == [[1]] + assert dmp_nest([[1]], 1, ZZ) == [[[1]]] + assert dmp_nest([[1]], 2, ZZ) == [[[[1]]]] + + +def test_dmp_raise(): + assert dmp_raise([], 2, 0, ZZ) == [[[]]] + assert dmp_raise([[1]], 0, 1, ZZ) == [[1]] + + assert dmp_raise([[1, 2, 3], [], [2, 3]], 2, 1, ZZ) == \ + [[[[1]], [[2]], [[3]]], [[[]]], [[[2]], [[3]]]] + + +def test_dup_deflate(): + assert dup_deflate([], ZZ) == (1, []) + assert dup_deflate([2], ZZ) == (1, [2]) + assert dup_deflate([1, 2, 3], ZZ) == (1, [1, 2, 3]) + assert dup_deflate([1, 0, 2, 0, 3], ZZ) == (2, [1, 2, 3]) + + assert dup_deflate(dup_from_raw_dict({7: 1, 1: 1}, ZZ), ZZ) == \ + (1, [1, 0, 0, 0, 0, 0, 1, 0]) + assert dup_deflate(dup_from_raw_dict({7: 1, 0: 1}, ZZ), ZZ) == \ + (7, [1, 1]) + assert dup_deflate(dup_from_raw_dict({7: 1, 3: 1}, ZZ), ZZ) == \ + (1, [1, 0, 0, 0, 1, 0, 0, 0]) + + assert dup_deflate(dup_from_raw_dict({7: 1, 4: 1}, ZZ), ZZ) == \ + (1, [1, 0, 0, 1, 0, 0, 0, 0]) + assert dup_deflate(dup_from_raw_dict({8: 1, 4: 1}, ZZ), ZZ) == \ + (4, [1, 1, 0]) + + assert dup_deflate(dup_from_raw_dict({8: 1}, ZZ), ZZ) == \ + (8, [1, 0]) + assert dup_deflate(dup_from_raw_dict({7: 1}, ZZ), ZZ) == \ + (7, [1, 0]) + assert dup_deflate(dup_from_raw_dict({1: 1}, ZZ), ZZ) == \ + (1, [1, 0]) + + +def test_dmp_deflate(): + assert dmp_deflate([[]], 1, ZZ) == ((1, 1), [[]]) + assert dmp_deflate([[2]], 1, ZZ) == ((1, 1), [[2]]) + + f = [[1, 0, 0], [], [1, 0], [], [1]] + + assert dmp_deflate(f, 1, ZZ) == ((2, 1), [[1, 0, 0], [1, 0], [1]]) + + +def test_dup_multi_deflate(): + assert dup_multi_deflate(([2],), ZZ) == (1, ([2],)) + assert dup_multi_deflate(([], []), ZZ) == (1, ([], [])) + + assert dup_multi_deflate(([1, 2, 3],), ZZ) == (1, ([1, 2, 3],)) + assert dup_multi_deflate(([1, 0, 2, 0, 3],), ZZ) == (2, ([1, 2, 3],)) + + assert dup_multi_deflate(([1, 0, 2, 0, 3], [2, 0, 0]), ZZ) == \ + (2, ([1, 2, 3], [2, 0])) + assert dup_multi_deflate(([1, 0, 2, 0, 3], [2, 1, 0]), ZZ) == \ + (1, ([1, 0, 2, 0, 3], [2, 1, 0])) + + +def test_dmp_multi_deflate(): + assert dmp_multi_deflate(([[]],), 1, ZZ) == \ + ((1, 1), ([[]],)) + assert dmp_multi_deflate(([[]], [[]]), 1, ZZ) == \ + ((1, 1), ([[]], [[]])) + + assert dmp_multi_deflate(([[1]], [[]]), 1, ZZ) == \ + ((1, 1), ([[1]], [[]])) + assert dmp_multi_deflate(([[1]], [[2]]), 1, ZZ) == \ + ((1, 1), ([[1]], [[2]])) + assert dmp_multi_deflate(([[1]], [[2, 0]]), 1, ZZ) == \ + ((1, 1), ([[1]], [[2, 0]])) + + assert dmp_multi_deflate(([[2, 0]], [[2, 0]]), 1, ZZ) == \ + ((1, 1), ([[2, 0]], [[2, 0]])) + + assert dmp_multi_deflate( + ([[2]], [[2, 0, 0]]), 1, ZZ) == ((1, 2), ([[2]], [[2, 0]])) + assert dmp_multi_deflate( + ([[2, 0, 0]], [[2, 0, 0]]), 1, ZZ) == ((1, 2), ([[2, 0]], [[2, 0]])) + + assert dmp_multi_deflate(([2, 0, 0], [1, 0, 4, 0, 1]), 0, ZZ) == \ + ((2,), ([2, 0], [1, 4, 1])) + + f = [[1, 0, 0], [], [1, 0], [], [1]] + g = [[1, 0, 1, 0], [], [1]] + + assert dmp_multi_deflate((f,), 1, ZZ) == \ + ((2, 1), ([[1, 0, 0], [1, 0], [1]],)) + + assert dmp_multi_deflate((f, g), 1, ZZ) == \ + ((2, 1), ([[1, 0, 0], [1, 0], [1]], + [[1, 0, 1, 0], [1]])) + + +def test_dup_inflate(): + assert dup_inflate([], 17, ZZ) == [] + + assert dup_inflate([1, 2, 3], 1, ZZ) == [1, 2, 3] + assert dup_inflate([1, 2, 3], 2, ZZ) == [1, 0, 2, 0, 3] + assert dup_inflate([1, 2, 3], 3, ZZ) == [1, 0, 0, 2, 0, 0, 3] + assert dup_inflate([1, 2, 3], 4, ZZ) == [1, 0, 0, 0, 2, 0, 0, 0, 3] + + raises(IndexError, lambda: dup_inflate([1, 2, 3], 0, ZZ)) + + +def test_dmp_inflate(): + assert dmp_inflate([1], (3,), 0, ZZ) == [1] + + assert dmp_inflate([[]], (3, 7), 1, ZZ) == [[]] + assert dmp_inflate([[2]], (1, 2), 1, ZZ) == [[2]] + + assert dmp_inflate([[2, 0]], (1, 1), 1, ZZ) == [[2, 0]] + assert dmp_inflate([[2, 0]], (1, 2), 1, ZZ) == [[2, 0, 0]] + assert dmp_inflate([[2, 0]], (1, 3), 1, ZZ) == [[2, 0, 0, 0]] + + assert dmp_inflate([[1, 0, 0], [1], [1, 0]], (2, 1), 1, ZZ) == \ + [[1, 0, 0], [], [1], [], [1, 0]] + + raises(IndexError, lambda: dmp_inflate([[]], (-3, 7), 1, ZZ)) + + +def test_dmp_exclude(): + assert dmp_exclude([[[]]], 2, ZZ) == ([], [[[]]], 2) + assert dmp_exclude([[[7]]], 2, ZZ) == ([], [[[7]]], 2) + + assert dmp_exclude([1, 2, 3], 0, ZZ) == ([], [1, 2, 3], 0) + assert dmp_exclude([[1], [2, 3]], 1, ZZ) == ([], [[1], [2, 3]], 1) + + assert dmp_exclude([[1, 2, 3]], 1, ZZ) == ([0], [1, 2, 3], 0) + assert dmp_exclude([[1], [2], [3]], 1, ZZ) == ([1], [1, 2, 3], 0) + + assert dmp_exclude([[[1, 2, 3]]], 2, ZZ) == ([0, 1], [1, 2, 3], 0) + assert dmp_exclude([[[1]], [[2]], [[3]]], 2, ZZ) == ([1, 2], [1, 2, 3], 0) + + +def test_dmp_include(): + assert dmp_include([1, 2, 3], [], 0, ZZ) == [1, 2, 3] + + assert dmp_include([1, 2, 3], [0], 0, ZZ) == [[1, 2, 3]] + assert dmp_include([1, 2, 3], [1], 0, ZZ) == [[1], [2], [3]] + + assert dmp_include([1, 2, 3], [0, 1], 0, ZZ) == [[[1, 2, 3]]] + assert dmp_include([1, 2, 3], [1, 2], 0, ZZ) == [[[1]], [[2]], [[3]]] + + +def test_dmp_inject(): + R, x,y = ring("x,y", ZZ) + K = R.to_domain() + + assert dmp_inject([], 0, K) == ([[[]]], 2) + assert dmp_inject([[]], 1, K) == ([[[[]]]], 3) + + assert dmp_inject([R(1)], 0, K) == ([[[1]]], 2) + assert dmp_inject([[R(1)]], 1, K) == ([[[[1]]]], 3) + + assert dmp_inject([R(1), 2*x + 3*y + 4], 0, K) == ([[[1]], [[2], [3, 4]]], 2) + + f = [3*x**2 + 7*x*y + 5*y**2, 2*x, R(0), x*y**2 + 11] + g = [[[3], [7, 0], [5, 0, 0]], [[2], []], [[]], [[1, 0, 0], [11]]] + + assert dmp_inject(f, 0, K) == (g, 2) + + +def test_dmp_eject(): + R, x,y = ring("x,y", ZZ) + K = R.to_domain() + + assert dmp_eject([[[]]], 2, K) == [] + assert dmp_eject([[[[]]]], 3, K) == [[]] + + assert dmp_eject([[[1]]], 2, K) == [R(1)] + assert dmp_eject([[[[1]]]], 3, K) == [[R(1)]] + + assert dmp_eject([[[1]], [[2], [3, 4]]], 2, K) == [R(1), 2*x + 3*y + 4] + + f = [3*x**2 + 7*x*y + 5*y**2, 2*x, R(0), x*y**2 + 11] + g = [[[3], [7, 0], [5, 0, 0]], [[2], []], [[]], [[1, 0, 0], [11]]] + + assert dmp_eject(g, 2, K) == f + + +def test_dup_terms_gcd(): + assert dup_terms_gcd([], ZZ) == (0, []) + assert dup_terms_gcd([1, 0, 1], ZZ) == (0, [1, 0, 1]) + assert dup_terms_gcd([1, 0, 1, 0], ZZ) == (1, [1, 0, 1]) + + +def test_dmp_terms_gcd(): + assert dmp_terms_gcd([[]], 1, ZZ) == ((0, 0), [[]]) + + assert dmp_terms_gcd([1, 0, 1, 0], 0, ZZ) == ((1,), [1, 0, 1]) + assert dmp_terms_gcd([[1], [], [1], []], 1, ZZ) == ((1, 0), [[1], [], [1]]) + + assert dmp_terms_gcd( + [[1, 0], [], [1]], 1, ZZ) == ((0, 0), [[1, 0], [], [1]]) + assert dmp_terms_gcd( + [[1, 0], [1, 0, 0], [], []], 1, ZZ) == ((2, 1), [[1], [1, 0]]) + + +def test_dmp_list_terms(): + assert dmp_list_terms([[[]]], 2, ZZ) == [((0, 0, 0), 0)] + assert dmp_list_terms([[[1]]], 2, ZZ) == [((0, 0, 0), 1)] + + assert dmp_list_terms([1, 2, 4, 3, 5], 0, ZZ) == \ + [((4,), 1), ((3,), 2), ((2,), 4), ((1,), 3), ((0,), 5)] + + assert dmp_list_terms([[1], [2, 4], [3, 5, 0]], 1, ZZ) == \ + [((2, 0), 1), ((1, 1), 2), ((1, 0), 4), ((0, 2), 3), ((0, 1), 5)] + + f = [[2, 0, 0, 0], [1, 0, 0], []] + + assert dmp_list_terms(f, 1, ZZ, order='lex') == [((2, 3), 2), ((1, 2), 1)] + assert dmp_list_terms( + f, 1, ZZ, order='grlex') == [((2, 3), 2), ((1, 2), 1)] + + f = [[2, 0, 0, 0], [1, 0, 0, 0, 0, 0], []] + + assert dmp_list_terms(f, 1, ZZ, order='lex') == [((2, 3), 2), ((1, 5), 1)] + assert dmp_list_terms( + f, 1, ZZ, order='grlex') == [((1, 5), 1), ((2, 3), 2)] + + +def test_dmp_apply_pairs(): + h = lambda a, b: a*b + + assert dmp_apply_pairs([1, 2, 3], [4, 5, 6], h, [], 0, ZZ) == [4, 10, 18] + + assert dmp_apply_pairs([2, 3], [4, 5, 6], h, [], 0, ZZ) == [10, 18] + assert dmp_apply_pairs([1, 2, 3], [5, 6], h, [], 0, ZZ) == [10, 18] + + assert dmp_apply_pairs( + [[1, 2], [3]], [[4, 5], [6]], h, [], 1, ZZ) == [[4, 10], [18]] + + assert dmp_apply_pairs( + [[1, 2], [3]], [[4], [5, 6]], h, [], 1, ZZ) == [[8], [18]] + assert dmp_apply_pairs( + [[1], [2, 3]], [[4, 5], [6]], h, [], 1, ZZ) == [[5], [18]] + + +def test_dup_slice(): + f = [1, 2, 3, 4] + + assert dup_slice(f, 0, 0, ZZ) == [] + assert dup_slice(f, 0, 1, ZZ) == [4] + assert dup_slice(f, 0, 2, ZZ) == [3, 4] + assert dup_slice(f, 0, 3, ZZ) == [2, 3, 4] + assert dup_slice(f, 0, 4, ZZ) == [1, 2, 3, 4] + + assert dup_slice(f, 0, 4, ZZ) == f + assert dup_slice(f, 0, 9, ZZ) == f + + assert dup_slice(f, 1, 0, ZZ) == [] + assert dup_slice(f, 1, 1, ZZ) == [] + assert dup_slice(f, 1, 2, ZZ) == [3, 0] + assert dup_slice(f, 1, 3, ZZ) == [2, 3, 0] + assert dup_slice(f, 1, 4, ZZ) == [1, 2, 3, 0] + + assert dup_slice([1, 2], 0, 3, ZZ) == [1, 2] + + g = [1, 0, 0, 2] + + assert dup_slice(g, 0, 3, ZZ) == [2] + + +def test_dup_random(): + f = dup_random(0, -10, 10, ZZ) + + assert dup_degree(f) == 0 + assert all(-10 <= c <= 10 for c in f) + + f = dup_random(1, -20, 20, ZZ) + + assert dup_degree(f) == 1 + assert all(-20 <= c <= 20 for c in f) + + f = dup_random(2, -30, 30, ZZ) + + assert dup_degree(f) == 2 + assert all(-30 <= c <= 30 for c in f) + + f = dup_random(3, -40, 40, ZZ) + + assert dup_degree(f) == 3 + assert all(-40 <= c <= 40 for c in f) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_densetools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_densetools.py new file mode 100644 index 0000000000000000000000000000000000000000..b4bebd2a6f061a13a7d34b7689c696456310f62e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_densetools.py @@ -0,0 +1,714 @@ +"""Tests for dense recursive polynomials' tools. """ + +from sympy.polys.densebasic import ( + dup_normal, dmp_normal, + dup_from_raw_dict, + dmp_convert, dmp_swap, +) + +from sympy.polys.densearith import dmp_mul_ground + +from sympy.polys.densetools import ( + dup_clear_denoms, dmp_clear_denoms, + dup_integrate, dmp_integrate, dmp_integrate_in, + dup_diff, dmp_diff, dmp_diff_in, + dup_eval, dmp_eval, dmp_eval_in, + dmp_eval_tail, dmp_diff_eval_in, + dup_trunc, dmp_trunc, dmp_ground_trunc, + dup_monic, dmp_ground_monic, + dup_content, dmp_ground_content, + dup_primitive, dmp_ground_primitive, + dup_extract, dmp_ground_extract, + dup_real_imag, + dup_mirror, dup_scale, dup_shift, dmp_shift, + dup_transform, + dup_compose, dmp_compose, + dup_decompose, + dmp_lift, + dup_sign_variations, + dup_revert, dmp_revert, +) +from sympy.polys.polyclasses import ANP + +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + ExactQuotientFailed, + NotReversible, + DomainError, +) + +from sympy.polys.specialpolys import f_polys + +from sympy.polys.domains import FF, ZZ, QQ, ZZ_I, QQ_I, EX, RR +from sympy.polys.rings import ring + +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.functions.elementary.trigonometric import sin + +from sympy.abc import x +from sympy.testing.pytest import raises + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = [ f.to_dense() for f in f_polys() ] + +def test_dup_integrate(): + assert dup_integrate([], 1, QQ) == [] + assert dup_integrate([], 2, QQ) == [] + + assert dup_integrate([QQ(1)], 1, QQ) == [QQ(1), QQ(0)] + assert dup_integrate([QQ(1)], 2, QQ) == [QQ(1, 2), QQ(0), QQ(0)] + + assert dup_integrate([QQ(1), QQ(2), QQ(3)], 0, QQ) == \ + [QQ(1), QQ(2), QQ(3)] + assert dup_integrate([QQ(1), QQ(2), QQ(3)], 1, QQ) == \ + [QQ(1, 3), QQ(1), QQ(3), QQ(0)] + assert dup_integrate([QQ(1), QQ(2), QQ(3)], 2, QQ) == \ + [QQ(1, 12), QQ(1, 3), QQ(3, 2), QQ(0), QQ(0)] + assert dup_integrate([QQ(1), QQ(2), QQ(3)], 3, QQ) == \ + [QQ(1, 60), QQ(1, 12), QQ(1, 2), QQ(0), QQ(0), QQ(0)] + + assert dup_integrate(dup_from_raw_dict({29: QQ(17)}, QQ), 3, QQ) == \ + dup_from_raw_dict({32: QQ(17, 29760)}, QQ) + + assert dup_integrate(dup_from_raw_dict({29: QQ(17), 5: QQ(1, 2)}, QQ), 3, QQ) == \ + dup_from_raw_dict({32: QQ(17, 29760), 8: QQ(1, 672)}, QQ) + + +def test_dmp_integrate(): + assert dmp_integrate([QQ(1)], 2, 0, QQ) == [QQ(1, 2), QQ(0), QQ(0)] + + assert dmp_integrate([[[]]], 1, 2, QQ) == [[[]]] + assert dmp_integrate([[[]]], 2, 2, QQ) == [[[]]] + + assert dmp_integrate([[[QQ(1)]]], 1, 2, QQ) == [[[QQ(1)]], [[]]] + assert dmp_integrate([[[QQ(1)]]], 2, 2, QQ) == [[[QQ(1, 2)]], [[]], [[]]] + + assert dmp_integrate([[QQ(1)], [QQ(2)], [QQ(3)]], 0, 1, QQ) == \ + [[QQ(1)], [QQ(2)], [QQ(3)]] + assert dmp_integrate([[QQ(1)], [QQ(2)], [QQ(3)]], 1, 1, QQ) == \ + [[QQ(1, 3)], [QQ(1)], [QQ(3)], []] + assert dmp_integrate([[QQ(1)], [QQ(2)], [QQ(3)]], 2, 1, QQ) == \ + [[QQ(1, 12)], [QQ(1, 3)], [QQ(3, 2)], [], []] + assert dmp_integrate([[QQ(1)], [QQ(2)], [QQ(3)]], 3, 1, QQ) == \ + [[QQ(1, 60)], [QQ(1, 12)], [QQ(1, 2)], [], [], []] + + +def test_dmp_integrate_in(): + f = dmp_convert(f_6, 3, ZZ, QQ) + + assert dmp_integrate_in(f, 2, 1, 3, QQ) == \ + dmp_swap( + dmp_integrate(dmp_swap(f, 0, 1, 3, QQ), 2, 3, QQ), 0, 1, 3, QQ) + assert dmp_integrate_in(f, 3, 1, 3, QQ) == \ + dmp_swap( + dmp_integrate(dmp_swap(f, 0, 1, 3, QQ), 3, 3, QQ), 0, 1, 3, QQ) + assert dmp_integrate_in(f, 2, 2, 3, QQ) == \ + dmp_swap( + dmp_integrate(dmp_swap(f, 0, 2, 3, QQ), 2, 3, QQ), 0, 2, 3, QQ) + assert dmp_integrate_in(f, 3, 2, 3, QQ) == \ + dmp_swap( + dmp_integrate(dmp_swap(f, 0, 2, 3, QQ), 3, 3, QQ), 0, 2, 3, QQ) + + raises(IndexError, lambda: dmp_integrate_in(f, 1, -1, 3, QQ)) + raises(IndexError, lambda: dmp_integrate_in(f, 1, 4, 3, QQ)) + + +def test_dup_diff(): + assert dup_diff([], 1, ZZ) == [] + assert dup_diff([7], 1, ZZ) == [] + assert dup_diff([2, 7], 1, ZZ) == [2] + assert dup_diff([1, 2, 1], 1, ZZ) == [2, 2] + assert dup_diff([1, 2, 3, 4], 1, ZZ) == [3, 4, 3] + assert dup_diff([1, -1, 0, 0, 2], 1, ZZ) == [4, -3, 0, 0] + + f = dup_normal([17, 34, 56, -345, 23, 76, 0, 0, 12, 3, 7], ZZ) + + assert dup_diff(f, 0, ZZ) == f + assert dup_diff(f, 1, ZZ) == [170, 306, 448, -2415, 138, 380, 0, 0, 24, 3] + assert dup_diff(f, 2, ZZ) == dup_diff(dup_diff(f, 1, ZZ), 1, ZZ) + assert dup_diff( + f, 3, ZZ) == dup_diff(dup_diff(dup_diff(f, 1, ZZ), 1, ZZ), 1, ZZ) + + K = FF(3) + f = dup_normal([17, 34, 56, -345, 23, 76, 0, 0, 12, 3, 7], K) + + assert dup_diff(f, 1, K) == dup_normal([2, 0, 1, 0, 0, 2, 0, 0, 0, 0], K) + assert dup_diff(f, 2, K) == dup_normal([1, 0, 0, 2, 0, 0, 0], K) + assert dup_diff(f, 3, K) == dup_normal([], K) + + assert dup_diff(f, 0, K) == f + assert dup_diff(f, 2, K) == dup_diff(dup_diff(f, 1, K), 1, K) + assert dup_diff( + f, 3, K) == dup_diff(dup_diff(dup_diff(f, 1, K), 1, K), 1, K) + + +def test_dmp_diff(): + assert dmp_diff([], 1, 0, ZZ) == [] + assert dmp_diff([[]], 1, 1, ZZ) == [[]] + assert dmp_diff([[[]]], 1, 2, ZZ) == [[[]]] + + assert dmp_diff([[[1], [2]]], 1, 2, ZZ) == [[[]]] + + assert dmp_diff([[[1]], [[]]], 1, 2, ZZ) == [[[1]]] + assert dmp_diff([[[3]], [[1]], [[]]], 1, 2, ZZ) == [[[6]], [[1]]] + + assert dmp_diff([1, -1, 0, 0, 2], 1, 0, ZZ) == \ + dup_diff([1, -1, 0, 0, 2], 1, ZZ) + + assert dmp_diff(f_6, 0, 3, ZZ) == f_6 + assert dmp_diff(f_6, 1, 3, ZZ) == [[[[8460]], [[]]], + [[[135, 0, 0], [], [], [-135, 0, 0]]], + [[[]]], + [[[-423]], [[-47]], [[]], [[141], [], [94, 0], []], [[]]]] + assert dmp_diff( + f_6, 2, 3, ZZ) == dmp_diff(dmp_diff(f_6, 1, 3, ZZ), 1, 3, ZZ) + assert dmp_diff(f_6, 3, 3, ZZ) == dmp_diff( + dmp_diff(dmp_diff(f_6, 1, 3, ZZ), 1, 3, ZZ), 1, 3, ZZ) + + K = FF(23) + F_6 = dmp_normal(f_6, 3, K) + + assert dmp_diff(F_6, 0, 3, K) == F_6 + assert dmp_diff(F_6, 1, 3, K) == dmp_diff(F_6, 1, 3, K) + assert dmp_diff(F_6, 2, 3, K) == dmp_diff(dmp_diff(F_6, 1, 3, K), 1, 3, K) + assert dmp_diff(F_6, 3, 3, K) == dmp_diff( + dmp_diff(dmp_diff(F_6, 1, 3, K), 1, 3, K), 1, 3, K) + + +def test_dmp_diff_in(): + assert dmp_diff_in(f_6, 2, 1, 3, ZZ) == \ + dmp_swap(dmp_diff(dmp_swap(f_6, 0, 1, 3, ZZ), 2, 3, ZZ), 0, 1, 3, ZZ) + assert dmp_diff_in(f_6, 3, 1, 3, ZZ) == \ + dmp_swap(dmp_diff(dmp_swap(f_6, 0, 1, 3, ZZ), 3, 3, ZZ), 0, 1, 3, ZZ) + assert dmp_diff_in(f_6, 2, 2, 3, ZZ) == \ + dmp_swap(dmp_diff(dmp_swap(f_6, 0, 2, 3, ZZ), 2, 3, ZZ), 0, 2, 3, ZZ) + assert dmp_diff_in(f_6, 3, 2, 3, ZZ) == \ + dmp_swap(dmp_diff(dmp_swap(f_6, 0, 2, 3, ZZ), 3, 3, ZZ), 0, 2, 3, ZZ) + + raises(IndexError, lambda: dmp_diff_in(f_6, 1, -1, 3, ZZ)) + raises(IndexError, lambda: dmp_diff_in(f_6, 1, 4, 3, ZZ)) + +def test_dup_eval(): + assert dup_eval([], 7, ZZ) == 0 + assert dup_eval([1, 2], 0, ZZ) == 2 + assert dup_eval([1, 2, 3], 7, ZZ) == 66 + + +def test_dmp_eval(): + assert dmp_eval([], 3, 0, ZZ) == 0 + + assert dmp_eval([[]], 3, 1, ZZ) == [] + assert dmp_eval([[[]]], 3, 2, ZZ) == [[]] + + assert dmp_eval([[1, 2]], 0, 1, ZZ) == [1, 2] + + assert dmp_eval([[[1]]], 3, 2, ZZ) == [[1]] + assert dmp_eval([[[1, 2]]], 3, 2, ZZ) == [[1, 2]] + + assert dmp_eval([[3, 2], [1, 2]], 3, 1, ZZ) == [10, 8] + assert dmp_eval([[[3, 2]], [[1, 2]]], 3, 2, ZZ) == [[10, 8]] + + +def test_dmp_eval_in(): + assert dmp_eval_in( + f_6, -2, 1, 3, ZZ) == dmp_eval(dmp_swap(f_6, 0, 1, 3, ZZ), -2, 3, ZZ) + assert dmp_eval_in( + f_6, 7, 1, 3, ZZ) == dmp_eval(dmp_swap(f_6, 0, 1, 3, ZZ), 7, 3, ZZ) + assert dmp_eval_in(f_6, -2, 2, 3, ZZ) == dmp_swap( + dmp_eval(dmp_swap(f_6, 0, 2, 3, ZZ), -2, 3, ZZ), 0, 1, 2, ZZ) + assert dmp_eval_in(f_6, 7, 2, 3, ZZ) == dmp_swap( + dmp_eval(dmp_swap(f_6, 0, 2, 3, ZZ), 7, 3, ZZ), 0, 1, 2, ZZ) + + f = [[[int(45)]], [[]], [[]], [[int(-9)], [-1], [], [int(3), int(0), int(10), int(0)]]] + + assert dmp_eval_in(f, -2, 2, 2, ZZ) == \ + [[45], [], [], [-9, -1, 0, -44]] + + raises(IndexError, lambda: dmp_eval_in(f_6, ZZ(1), -1, 3, ZZ)) + raises(IndexError, lambda: dmp_eval_in(f_6, ZZ(1), 4, 3, ZZ)) + + +def test_dmp_eval_tail(): + assert dmp_eval_tail([[]], [1], 1, ZZ) == [] + assert dmp_eval_tail([[[]]], [1], 2, ZZ) == [[]] + assert dmp_eval_tail([[[]]], [1, 2], 2, ZZ) == [] + + assert dmp_eval_tail(f_0, [], 2, ZZ) == f_0 + + assert dmp_eval_tail(f_0, [1, -17, 8], 2, ZZ) == 84496 + assert dmp_eval_tail(f_0, [-17, 8], 2, ZZ) == [-1409, 3, 85902] + assert dmp_eval_tail(f_0, [8], 2, ZZ) == [[83, 2], [3], [302, 81, 1]] + + assert dmp_eval_tail(f_1, [-17, 8], 2, ZZ) == [-136, 15699, 9166, -27144] + + assert dmp_eval_tail( + f_2, [-12, 3], 2, ZZ) == [-1377, 0, -702, -1224, 0, -624] + assert dmp_eval_tail( + f_3, [-12, 3], 2, ZZ) == [144, 82, -5181, -28872, -14868, -540] + + assert dmp_eval_tail( + f_4, [25, -1], 2, ZZ) == [152587890625, 9765625, -59605407714843750, + -3839159765625, -1562475, 9536712644531250, 610349546750, -4, 24414375000, 1562520] + assert dmp_eval_tail(f_5, [25, -1], 2, ZZ) == [-1, -78, -2028, -17576] + + assert dmp_eval_tail(f_6, [0, 2, 4], 3, ZZ) == [5040, 0, 0, 4480] + + +def test_dmp_diff_eval_in(): + assert dmp_diff_eval_in(f_6, 2, 7, 1, 3, ZZ) == \ + dmp_eval(dmp_diff(dmp_swap(f_6, 0, 1, 3, ZZ), 2, 3, ZZ), 7, 3, ZZ) + + assert dmp_diff_eval_in(f_6, 2, 7, 0, 3, ZZ) == \ + dmp_eval(dmp_diff(f_6, 2, 3, ZZ), 7, 3, ZZ) + + raises(IndexError, lambda: dmp_diff_eval_in(f_6, 1, ZZ(1), 4, 3, ZZ)) + + +def test_dup_revert(): + f = [-QQ(1, 720), QQ(0), QQ(1, 24), QQ(0), -QQ(1, 2), QQ(0), QQ(1)] + g = [QQ(61, 720), QQ(0), QQ(5, 24), QQ(0), QQ(1, 2), QQ(0), QQ(1)] + + assert dup_revert(f, 8, QQ) == g + + raises(NotReversible, lambda: dup_revert([QQ(1), QQ(0)], 3, QQ)) + + +def test_dmp_revert(): + f = [-QQ(1, 720), QQ(0), QQ(1, 24), QQ(0), -QQ(1, 2), QQ(0), QQ(1)] + g = [QQ(61, 720), QQ(0), QQ(5, 24), QQ(0), QQ(1, 2), QQ(0), QQ(1)] + + assert dmp_revert(f, 8, 0, QQ) == g + + raises(MultivariatePolynomialError, lambda: dmp_revert([[1]], 2, 1, QQ)) + + +def test_dup_trunc(): + assert dup_trunc([1, 2, 3, 4, 5, 6], ZZ(3), ZZ) == [1, -1, 0, 1, -1, 0] + assert dup_trunc([6, 5, 4, 3, 2, 1], ZZ(3), ZZ) == [-1, 1, 0, -1, 1] + + R = ZZ_I + assert dup_trunc([R(3), R(4), R(5)], R(3), R) == [R(1), R(-1)] + + K = FF(5) + assert dup_trunc([K(3), K(4), K(5)], K(3), K) == [K(1), K(0)] + + +def test_dmp_trunc(): + assert dmp_trunc([[]], [1, 2], 2, ZZ) == [[]] + assert dmp_trunc([[1, 2], [1, 4, 1], [1]], [1, 2], 1, ZZ) == [[-3], [1]] + + +def test_dmp_ground_trunc(): + assert dmp_ground_trunc(f_0, ZZ(3), 2, ZZ) == \ + dmp_normal( + [[[1, -1, 0], [-1]], [[]], [[1, -1, 0], [1, -1, 1], [1]]], 2, ZZ) + + +def test_dup_monic(): + assert dup_monic([3, 6, 9], ZZ) == [1, 2, 3] + + raises(ExactQuotientFailed, lambda: dup_monic([3, 4, 5], ZZ)) + + assert dup_monic([], QQ) == [] + assert dup_monic([QQ(1)], QQ) == [QQ(1)] + assert dup_monic([QQ(7), QQ(1), QQ(21)], QQ) == [QQ(1), QQ(1, 7), QQ(3)] + + +def test_dmp_ground_monic(): + assert dmp_ground_monic([3, 6, 9], 0, ZZ) == [1, 2, 3] + + assert dmp_ground_monic([[3], [6], [9]], 1, ZZ) == [[1], [2], [3]] + + raises( + ExactQuotientFailed, lambda: dmp_ground_monic([[3], [4], [5]], 1, ZZ)) + + assert dmp_ground_monic([[]], 1, QQ) == [[]] + assert dmp_ground_monic([[QQ(1)]], 1, QQ) == [[QQ(1)]] + assert dmp_ground_monic( + [[QQ(7)], [QQ(1)], [QQ(21)]], 1, QQ) == [[QQ(1)], [QQ(1, 7)], [QQ(3)]] + + +def test_dup_content(): + assert dup_content([], ZZ) == ZZ(0) + assert dup_content([1], ZZ) == ZZ(1) + assert dup_content([-1], ZZ) == ZZ(1) + assert dup_content([1, 1], ZZ) == ZZ(1) + assert dup_content([2, 2], ZZ) == ZZ(2) + assert dup_content([1, 2, 1], ZZ) == ZZ(1) + assert dup_content([2, 4, 2], ZZ) == ZZ(2) + + assert dup_content([QQ(2, 3), QQ(4, 9)], QQ) == QQ(2, 9) + assert dup_content([QQ(2, 3), QQ(4, 5)], QQ) == QQ(2, 15) + + +def test_dmp_ground_content(): + assert dmp_ground_content([[]], 1, ZZ) == ZZ(0) + assert dmp_ground_content([[]], 1, QQ) == QQ(0) + assert dmp_ground_content([[1]], 1, ZZ) == ZZ(1) + assert dmp_ground_content([[-1]], 1, ZZ) == ZZ(1) + assert dmp_ground_content([[1], [1]], 1, ZZ) == ZZ(1) + assert dmp_ground_content([[2], [2]], 1, ZZ) == ZZ(2) + assert dmp_ground_content([[1], [2], [1]], 1, ZZ) == ZZ(1) + assert dmp_ground_content([[2], [4], [2]], 1, ZZ) == ZZ(2) + + assert dmp_ground_content([[QQ(2, 3)], [QQ(4, 9)]], 1, QQ) == QQ(2, 9) + assert dmp_ground_content([[QQ(2, 3)], [QQ(4, 5)]], 1, QQ) == QQ(2, 15) + + assert dmp_ground_content(f_0, 2, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_0, ZZ(2), 2, ZZ), 2, ZZ) == ZZ(2) + + assert dmp_ground_content(f_1, 2, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_1, ZZ(3), 2, ZZ), 2, ZZ) == ZZ(3) + + assert dmp_ground_content(f_2, 2, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_2, ZZ(4), 2, ZZ), 2, ZZ) == ZZ(4) + + assert dmp_ground_content(f_3, 2, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_3, ZZ(5), 2, ZZ), 2, ZZ) == ZZ(5) + + assert dmp_ground_content(f_4, 2, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_4, ZZ(6), 2, ZZ), 2, ZZ) == ZZ(6) + + assert dmp_ground_content(f_5, 2, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_5, ZZ(7), 2, ZZ), 2, ZZ) == ZZ(7) + + assert dmp_ground_content(f_6, 3, ZZ) == ZZ(1) + assert dmp_ground_content( + dmp_mul_ground(f_6, ZZ(8), 3, ZZ), 3, ZZ) == ZZ(8) + + +def test_dup_primitive(): + assert dup_primitive([], ZZ) == (ZZ(0), []) + assert dup_primitive([ZZ(1)], ZZ) == (ZZ(1), [ZZ(1)]) + assert dup_primitive([ZZ(1), ZZ(1)], ZZ) == (ZZ(1), [ZZ(1), ZZ(1)]) + assert dup_primitive([ZZ(2), ZZ(2)], ZZ) == (ZZ(2), [ZZ(1), ZZ(1)]) + assert dup_primitive( + [ZZ(1), ZZ(2), ZZ(1)], ZZ) == (ZZ(1), [ZZ(1), ZZ(2), ZZ(1)]) + assert dup_primitive( + [ZZ(2), ZZ(4), ZZ(2)], ZZ) == (ZZ(2), [ZZ(1), ZZ(2), ZZ(1)]) + + assert dup_primitive([], QQ) == (QQ(0), []) + assert dup_primitive([QQ(1)], QQ) == (QQ(1), [QQ(1)]) + assert dup_primitive([QQ(1), QQ(1)], QQ) == (QQ(1), [QQ(1), QQ(1)]) + assert dup_primitive([QQ(2), QQ(2)], QQ) == (QQ(2), [QQ(1), QQ(1)]) + assert dup_primitive( + [QQ(1), QQ(2), QQ(1)], QQ) == (QQ(1), [QQ(1), QQ(2), QQ(1)]) + assert dup_primitive( + [QQ(2), QQ(4), QQ(2)], QQ) == (QQ(2), [QQ(1), QQ(2), QQ(1)]) + + assert dup_primitive( + [QQ(2, 3), QQ(4, 9)], QQ) == (QQ(2, 9), [QQ(3), QQ(2)]) + assert dup_primitive( + [QQ(2, 3), QQ(4, 5)], QQ) == (QQ(2, 15), [QQ(5), QQ(6)]) + + +def test_dmp_ground_primitive(): + assert dmp_ground_primitive([ZZ(1)], 0, ZZ) == (ZZ(1), [ZZ(1)]) + + assert dmp_ground_primitive([[]], 1, ZZ) == (ZZ(0), [[]]) + + assert dmp_ground_primitive(f_0, 2, ZZ) == (ZZ(1), f_0) + assert dmp_ground_primitive( + dmp_mul_ground(f_0, ZZ(2), 2, ZZ), 2, ZZ) == (ZZ(2), f_0) + + assert dmp_ground_primitive(f_1, 2, ZZ) == (ZZ(1), f_1) + assert dmp_ground_primitive( + dmp_mul_ground(f_1, ZZ(3), 2, ZZ), 2, ZZ) == (ZZ(3), f_1) + + assert dmp_ground_primitive(f_2, 2, ZZ) == (ZZ(1), f_2) + assert dmp_ground_primitive( + dmp_mul_ground(f_2, ZZ(4), 2, ZZ), 2, ZZ) == (ZZ(4), f_2) + + assert dmp_ground_primitive(f_3, 2, ZZ) == (ZZ(1), f_3) + assert dmp_ground_primitive( + dmp_mul_ground(f_3, ZZ(5), 2, ZZ), 2, ZZ) == (ZZ(5), f_3) + + assert dmp_ground_primitive(f_4, 2, ZZ) == (ZZ(1), f_4) + assert dmp_ground_primitive( + dmp_mul_ground(f_4, ZZ(6), 2, ZZ), 2, ZZ) == (ZZ(6), f_4) + + assert dmp_ground_primitive(f_5, 2, ZZ) == (ZZ(1), f_5) + assert dmp_ground_primitive( + dmp_mul_ground(f_5, ZZ(7), 2, ZZ), 2, ZZ) == (ZZ(7), f_5) + + assert dmp_ground_primitive(f_6, 3, ZZ) == (ZZ(1), f_6) + assert dmp_ground_primitive( + dmp_mul_ground(f_6, ZZ(8), 3, ZZ), 3, ZZ) == (ZZ(8), f_6) + + assert dmp_ground_primitive([[ZZ(2)]], 1, ZZ) == (ZZ(2), [[ZZ(1)]]) + assert dmp_ground_primitive([[QQ(2)]], 1, QQ) == (QQ(2), [[QQ(1)]]) + + assert dmp_ground_primitive( + [[QQ(2, 3)], [QQ(4, 9)]], 1, QQ) == (QQ(2, 9), [[QQ(3)], [QQ(2)]]) + assert dmp_ground_primitive( + [[QQ(2, 3)], [QQ(4, 5)]], 1, QQ) == (QQ(2, 15), [[QQ(5)], [QQ(6)]]) + + +def test_dup_extract(): + f = dup_normal([2930944, 0, 2198208, 0, 549552, 0, 45796], ZZ) + g = dup_normal([17585664, 0, 8792832, 0, 1099104, 0], ZZ) + + F = dup_normal([64, 0, 48, 0, 12, 0, 1], ZZ) + G = dup_normal([384, 0, 192, 0, 24, 0], ZZ) + + assert dup_extract(f, g, ZZ) == (45796, F, G) + + +def test_dmp_ground_extract(): + f = dmp_normal( + [[2930944], [], [2198208], [], [549552], [], [45796]], 1, ZZ) + g = dmp_normal([[17585664], [], [8792832], [], [1099104], []], 1, ZZ) + + F = dmp_normal([[64], [], [48], [], [12], [], [1]], 1, ZZ) + G = dmp_normal([[384], [], [192], [], [24], []], 1, ZZ) + + assert dmp_ground_extract(f, g, 1, ZZ) == (45796, F, G) + + +def test_dup_real_imag(): + assert dup_real_imag([], ZZ) == ([[]], [[]]) + assert dup_real_imag([1], ZZ) == ([[1]], [[]]) + + assert dup_real_imag([1, 1], ZZ) == ([[1], [1]], [[1, 0]]) + assert dup_real_imag([1, 2], ZZ) == ([[1], [2]], [[1, 0]]) + + assert dup_real_imag( + [1, 2, 3], ZZ) == ([[1], [2], [-1, 0, 3]], [[2, 0], [2, 0]]) + + assert dup_real_imag([ZZ(1), ZZ(0), ZZ(1), ZZ(3)], ZZ) == ( + [[ZZ(1)], [], [ZZ(-3), ZZ(0), ZZ(1)], [ZZ(3)]], + [[ZZ(3), ZZ(0)], [], [ZZ(-1), ZZ(0), ZZ(1), ZZ(0)]] + ) + + raises(DomainError, lambda: dup_real_imag([EX(1), EX(2)], EX)) + + + +def test_dup_mirror(): + assert dup_mirror([], ZZ) == [] + assert dup_mirror([1], ZZ) == [1] + + assert dup_mirror([1, 2, 3, 4, 5], ZZ) == [1, -2, 3, -4, 5] + assert dup_mirror([1, 2, 3, 4, 5, 6], ZZ) == [-1, 2, -3, 4, -5, 6] + + +def test_dup_scale(): + assert dup_scale([], -1, ZZ) == [] + assert dup_scale([1], -1, ZZ) == [1] + + assert dup_scale([1, 2, 3, 4, 5], -1, ZZ) == [1, -2, 3, -4, 5] + assert dup_scale([1, 2, 3, 4, 5], -7, ZZ) == [2401, -686, 147, -28, 5] + + +def test_dup_shift(): + assert dup_shift([], 1, ZZ) == [] + assert dup_shift([1], 1, ZZ) == [1] + + assert dup_shift([1, 2, 3, 4, 5], 1, ZZ) == [1, 6, 15, 20, 15] + assert dup_shift([1, 2, 3, 4, 5], 7, ZZ) == [1, 30, 339, 1712, 3267] + + +def test_dmp_shift(): + assert dmp_shift([ZZ(1), ZZ(2)], [ZZ(1)], 0, ZZ) == [ZZ(1), ZZ(3)] + + assert dmp_shift([[]], [ZZ(1), ZZ(2)], 1, ZZ) == [[]] + + xy = [[ZZ(1), ZZ(0)], []] # x*y + x1y2 = [[ZZ(1), ZZ(2)], [ZZ(1), ZZ(2)]] # (x+1)*(y+2) + assert dmp_shift(xy, [ZZ(1), ZZ(2)], 1, ZZ) == x1y2 + + +def test_dup_transform(): + assert dup_transform([], [], [1, 1], ZZ) == [] + assert dup_transform([], [1], [1, 1], ZZ) == [] + assert dup_transform([], [1, 2], [1, 1], ZZ) == [] + + assert dup_transform([6, -5, 4, -3, 17], [1, -3, 4], [2, -3], ZZ) == \ + [6, -82, 541, -2205, 6277, -12723, 17191, -13603, 4773] + + +def test_dup_compose(): + assert dup_compose([], [], ZZ) == [] + assert dup_compose([], [1], ZZ) == [] + assert dup_compose([], [1, 2], ZZ) == [] + + assert dup_compose([1], [], ZZ) == [1] + + assert dup_compose([1, 2, 0], [], ZZ) == [] + assert dup_compose([1, 2, 1], [], ZZ) == [1] + + assert dup_compose([1, 2, 1], [1], ZZ) == [4] + assert dup_compose([1, 2, 1], [7], ZZ) == [64] + + assert dup_compose([1, 2, 1], [1, -1], ZZ) == [1, 0, 0] + assert dup_compose([1, 2, 1], [1, 1], ZZ) == [1, 4, 4] + assert dup_compose([1, 2, 1], [1, 2, 1], ZZ) == [1, 4, 8, 8, 4] + + +def test_dmp_compose(): + assert dmp_compose([1, 2, 1], [1, 2, 1], 0, ZZ) == [1, 4, 8, 8, 4] + + assert dmp_compose([[[]]], [[[]]], 2, ZZ) == [[[]]] + assert dmp_compose([[[]]], [[[1]]], 2, ZZ) == [[[]]] + assert dmp_compose([[[]]], [[[1]], [[2]]], 2, ZZ) == [[[]]] + + assert dmp_compose([[[1]]], [], 2, ZZ) == [[[1]]] + + assert dmp_compose([[1], [2], [ ]], [[]], 1, ZZ) == [[]] + assert dmp_compose([[1], [2], [1]], [[]], 1, ZZ) == [[1]] + + assert dmp_compose([[1], [2], [1]], [[1]], 1, ZZ) == [[4]] + assert dmp_compose([[1], [2], [1]], [[7]], 1, ZZ) == [[64]] + + assert dmp_compose([[1], [2], [1]], [[1], [-1]], 1, ZZ) == [[1], [ ], [ ]] + assert dmp_compose([[1], [2], [1]], [[1], [ 1]], 1, ZZ) == [[1], [4], [4]] + + assert dmp_compose( + [[1], [2], [1]], [[1], [2], [1]], 1, ZZ) == [[1], [4], [8], [8], [4]] + + +def test_dup_decompose(): + assert dup_decompose([1], ZZ) == [[1]] + + assert dup_decompose([1, 0], ZZ) == [[1, 0]] + assert dup_decompose([1, 0, 0, 0], ZZ) == [[1, 0, 0, 0]] + + assert dup_decompose([1, 0, 0, 0, 0], ZZ) == [[1, 0, 0], [1, 0, 0]] + assert dup_decompose( + [1, 0, 0, 0, 0, 0, 0], ZZ) == [[1, 0, 0, 0], [1, 0, 0]] + + assert dup_decompose([7, 0, 0, 0, 1], ZZ) == [[7, 0, 1], [1, 0, 0]] + assert dup_decompose([4, 0, 3, 0, 2], ZZ) == [[4, 3, 2], [1, 0, 0]] + + f = [1, 0, 20, 0, 150, 0, 500, 0, 625, -2, 0, -10, 9] + + assert dup_decompose(f, ZZ) == [[1, 0, 0, -2, 9], [1, 0, 5, 0]] + + f = [2, 0, 40, 0, 300, 0, 1000, 0, 1250, -4, 0, -20, 18] + + assert dup_decompose(f, ZZ) == [[2, 0, 0, -4, 18], [1, 0, 5, 0]] + + f = [1, 0, 20, -8, 150, -120, 524, -600, 865, -1034, 600, -170, 29] + + assert dup_decompose(f, ZZ) == [[1, -8, 24, -34, 29], [1, 0, 5, 0]] + + R, t = ring("t", ZZ) + f = [6*t**2 - 42, + 48*t**2 + 96, + 144*t**2 + 648*t + 288, + 624*t**2 + 864*t + 384, + 108*t**3 + 312*t**2 + 432*t + 192] + + assert dup_decompose(f, R.to_domain()) == [f] + + +def test_dmp_lift(): + q = [QQ(1, 1), QQ(0, 1), QQ(1, 1)] + + f_a = [ANP([QQ(1, 1)], q, QQ), ANP([], q, QQ), ANP([], q, QQ), + ANP([QQ(1, 1), QQ(0, 1)], q, QQ), ANP([QQ(17, 1), QQ(0, 1)], q, QQ)] + + f_lift = QQ.map([1, 0, 0, 0, 0, 0, 1, 34, 289]) + + assert dmp_lift(f_a, 0, QQ.algebraic_field(I)) == f_lift + + f_g = [QQ_I(1), QQ_I(0), QQ_I(0), QQ_I(0, 1), QQ_I(0, 17)] + + assert dmp_lift(f_g, 0, QQ_I) == f_lift + + raises(DomainError, lambda: dmp_lift([EX(1), EX(2)], 0, EX)) + + +def test_dup_sign_variations(): + assert dup_sign_variations([], ZZ) == 0 + assert dup_sign_variations([1, 0], ZZ) == 0 + assert dup_sign_variations([1, 0, 2], ZZ) == 0 + assert dup_sign_variations([1, 0, 3, 0], ZZ) == 0 + assert dup_sign_variations([1, 0, 4, 0, 5], ZZ) == 0 + + assert dup_sign_variations([-1, 0, 2], ZZ) == 1 + assert dup_sign_variations([-1, 0, 3, 0], ZZ) == 1 + assert dup_sign_variations([-1, 0, 4, 0, 5], ZZ) == 1 + + assert dup_sign_variations([-1, -4, -5], ZZ) == 0 + assert dup_sign_variations([ 1, -4, -5], ZZ) == 1 + assert dup_sign_variations([ 1, 4, -5], ZZ) == 1 + assert dup_sign_variations([ 1, -4, 5], ZZ) == 2 + assert dup_sign_variations([-1, 4, -5], ZZ) == 2 + assert dup_sign_variations([-1, 4, 5], ZZ) == 1 + assert dup_sign_variations([-1, -4, 5], ZZ) == 1 + assert dup_sign_variations([ 1, 4, 5], ZZ) == 0 + + assert dup_sign_variations([-1, 0, -4, 0, -5], ZZ) == 0 + assert dup_sign_variations([ 1, 0, -4, 0, -5], ZZ) == 1 + assert dup_sign_variations([ 1, 0, 4, 0, -5], ZZ) == 1 + assert dup_sign_variations([ 1, 0, -4, 0, 5], ZZ) == 2 + assert dup_sign_variations([-1, 0, 4, 0, -5], ZZ) == 2 + assert dup_sign_variations([-1, 0, 4, 0, 5], ZZ) == 1 + assert dup_sign_variations([-1, 0, -4, 0, 5], ZZ) == 1 + assert dup_sign_variations([ 1, 0, 4, 0, 5], ZZ) == 0 + + +def test_dup_clear_denoms(): + assert dup_clear_denoms([], QQ, ZZ) == (ZZ(1), []) + + assert dup_clear_denoms([QQ(1)], QQ, ZZ) == (ZZ(1), [QQ(1)]) + assert dup_clear_denoms([QQ(7)], QQ, ZZ) == (ZZ(1), [QQ(7)]) + + assert dup_clear_denoms([QQ(7, 3)], QQ) == (ZZ(3), [QQ(7)]) + assert dup_clear_denoms([QQ(7, 3)], QQ, ZZ) == (ZZ(3), [QQ(7)]) + + assert dup_clear_denoms( + [QQ(3), QQ(1), QQ(0)], QQ, ZZ) == (ZZ(1), [QQ(3), QQ(1), QQ(0)]) + assert dup_clear_denoms( + [QQ(1), QQ(1, 2), QQ(0)], QQ, ZZ) == (ZZ(2), [QQ(2), QQ(1), QQ(0)]) + + assert dup_clear_denoms([QQ(3), QQ( + 1), QQ(0)], QQ, ZZ, convert=True) == (ZZ(1), [ZZ(3), ZZ(1), ZZ(0)]) + assert dup_clear_denoms([QQ(1), QQ( + 1, 2), QQ(0)], QQ, ZZ, convert=True) == (ZZ(2), [ZZ(2), ZZ(1), ZZ(0)]) + + assert dup_clear_denoms( + [EX(S(3)/2), EX(S(9)/4)], EX) == (EX(4), [EX(6), EX(9)]) + + assert dup_clear_denoms([EX(7)], EX) == (EX(1), [EX(7)]) + assert dup_clear_denoms([EX(sin(x)/x), EX(0)], EX) == (EX(x), [EX(sin(x)), EX(0)]) + + F = RR.frac_field(x) + result = dup_clear_denoms([F(8.48717/(8.0089*x + 2.83)), F(0.0)], F) + assert str(result) == "(x + 0.353356890459364, [1.05971731448763, 0.0])" + +def test_dmp_clear_denoms(): + assert dmp_clear_denoms([[]], 1, QQ, ZZ) == (ZZ(1), [[]]) + + assert dmp_clear_denoms([[QQ(1)]], 1, QQ, ZZ) == (ZZ(1), [[QQ(1)]]) + assert dmp_clear_denoms([[QQ(7)]], 1, QQ, ZZ) == (ZZ(1), [[QQ(7)]]) + + assert dmp_clear_denoms([[QQ(7, 3)]], 1, QQ) == (ZZ(3), [[QQ(7)]]) + assert dmp_clear_denoms([[QQ(7, 3)]], 1, QQ, ZZ) == (ZZ(3), [[QQ(7)]]) + + assert dmp_clear_denoms( + [[QQ(3)], [QQ(1)], []], 1, QQ, ZZ) == (ZZ(1), [[QQ(3)], [QQ(1)], []]) + assert dmp_clear_denoms([[QQ( + 1)], [QQ(1, 2)], []], 1, QQ, ZZ) == (ZZ(2), [[QQ(2)], [QQ(1)], []]) + + assert dmp_clear_denoms([QQ(3), QQ( + 1), QQ(0)], 0, QQ, ZZ, convert=True) == (ZZ(1), [ZZ(3), ZZ(1), ZZ(0)]) + assert dmp_clear_denoms([QQ(1), QQ(1, 2), QQ( + 0)], 0, QQ, ZZ, convert=True) == (ZZ(2), [ZZ(2), ZZ(1), ZZ(0)]) + + assert dmp_clear_denoms([[QQ(3)], [QQ( + 1)], []], 1, QQ, ZZ, convert=True) == (ZZ(1), [[QQ(3)], [QQ(1)], []]) + assert dmp_clear_denoms([[QQ(1)], [QQ(1, 2)], []], 1, QQ, ZZ, + convert=True) == (ZZ(2), [[QQ(2)], [QQ(1)], []]) + + assert dmp_clear_denoms( + [[EX(S(3)/2)], [EX(S(9)/4)]], 1, EX) == (EX(4), [[EX(6)], [EX(9)]]) + assert dmp_clear_denoms([[EX(7)]], 1, EX) == (EX(1), [[EX(7)]]) + assert dmp_clear_denoms([[EX(sin(x)/x), EX(0)]], 1, EX) == (EX(x), [[EX(sin(x)), EX(0)]]) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_dispersion.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_dispersion.py new file mode 100644 index 0000000000000000000000000000000000000000..ad56b7bebd73c38e037085d36625a41729c0369a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_dispersion.py @@ -0,0 +1,95 @@ +from sympy.core import Symbol, S, oo +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys import poly +from sympy.polys.dispersion import dispersion, dispersionset + + +def test_dispersion(): + x = Symbol("x") + a = Symbol("a") + + fp = poly(S.Zero, x) + assert sorted(dispersionset(fp)) == [0] + + fp = poly(S(2), x) + assert sorted(dispersionset(fp)) == [0] + + fp = poly(x + 1, x) + assert sorted(dispersionset(fp)) == [0] + assert dispersion(fp) == 0 + + fp = poly((x + 1)*(x + 2), x) + assert sorted(dispersionset(fp)) == [0, 1] + assert dispersion(fp) == 1 + + fp = poly(x*(x + 3), x) + assert sorted(dispersionset(fp)) == [0, 3] + assert dispersion(fp) == 3 + + fp = poly((x - 3)*(x + 3), x) + assert sorted(dispersionset(fp)) == [0, 6] + assert dispersion(fp) == 6 + + fp = poly(x**4 - 3*x**2 + 1, x) + gp = fp.shift(-3) + assert sorted(dispersionset(fp, gp)) == [2, 3, 4] + assert dispersion(fp, gp) == 4 + assert sorted(dispersionset(gp, fp)) == [] + assert dispersion(gp, fp) is -oo + + fp = poly(x*(3*x**2+a)*(x-2536)*(x**3+a), x) + gp = fp.as_expr().subs(x, x-345).as_poly(x) + assert sorted(dispersionset(fp, gp)) == [345, 2881] + assert sorted(dispersionset(gp, fp)) == [2191] + + gp = poly((x-2)**2*(x-3)**3*(x-5)**3, x) + assert sorted(dispersionset(gp)) == [0, 1, 2, 3] + assert sorted(dispersionset(gp, (gp+4)**2)) == [1, 2] + + fp = poly(x*(x+2)*(x-1), x) + assert sorted(dispersionset(fp)) == [0, 1, 2, 3] + + fp = poly(x**2 + sqrt(5)*x - 1, x, domain='QQ') + gp = poly(x**2 + (2 + sqrt(5))*x + sqrt(5), x, domain='QQ') + assert sorted(dispersionset(fp, gp)) == [2] + assert sorted(dispersionset(gp, fp)) == [1, 4] + + # There are some difficulties if we compute over Z[a] + # and alpha happens to lie in Z[a] instead of simply Z. + # Hence we can not decide if alpha is indeed integral + # in general. + + fp = poly(4*x**4 + (4*a + 8)*x**3 + (a**2 + 6*a + 4)*x**2 + (a**2 + 2*a)*x, x) + assert sorted(dispersionset(fp)) == [0, 1] + + # For any specific value of a, the dispersion is 3*a + # but the algorithm can not find this in general. + # This is the point where the resultant based Ansatz + # is superior to the current one. + fp = poly(a**2*x**3 + (a**3 + a**2 + a + 1)*x, x) + gp = fp.as_expr().subs(x, x - 3*a).as_poly(x) + assert sorted(dispersionset(fp, gp)) == [] + + fpa = fp.as_expr().subs(a, 2).as_poly(x) + gpa = gp.as_expr().subs(a, 2).as_poly(x) + assert sorted(dispersionset(fpa, gpa)) == [6] + + # Work with Expr instead of Poly + f = (x + 1)*(x + 2) + assert sorted(dispersionset(f)) == [0, 1] + assert dispersion(f) == 1 + + f = x**4 - 3*x**2 + 1 + g = x**4 - 12*x**3 + 51*x**2 - 90*x + 55 + assert sorted(dispersionset(f, g)) == [2, 3, 4] + assert dispersion(f, g) == 4 + + # Work with Expr and specify a generator + f = (x + 1)*(x + 2) + assert sorted(dispersionset(f, None, x)) == [0, 1] + assert dispersion(f, None, x) == 1 + + f = x**4 - 3*x**2 + 1 + g = x**4 - 12*x**3 + 51*x**2 - 90*x + 55 + assert sorted(dispersionset(f, g, x)) == [2, 3, 4] + assert dispersion(f, g, x) == 4 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_distributedmodules.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_distributedmodules.py new file mode 100644 index 0000000000000000000000000000000000000000..c95672f99f878f3def660aadec901afbde9adf8b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_distributedmodules.py @@ -0,0 +1,208 @@ +"""Tests for sparse distributed modules. """ + +from sympy.polys.distributedmodules import ( + sdm_monomial_mul, sdm_monomial_deg, sdm_monomial_divides, + sdm_add, sdm_LM, sdm_LT, sdm_mul_term, sdm_zero, sdm_deg, + sdm_LC, sdm_from_dict, + sdm_spoly, sdm_ecart, sdm_nf_mora, sdm_groebner, + sdm_from_vector, sdm_to_vector, sdm_monomial_lcm +) + +from sympy.polys.orderings import lex, grlex, InverseOrder +from sympy.polys.domains import QQ + +from sympy.abc import x, y, z + + +def test_sdm_monomial_mul(): + assert sdm_monomial_mul((1, 1, 0), (1, 3)) == (1, 2, 3) + + +def test_sdm_monomial_deg(): + assert sdm_monomial_deg((5, 2, 1)) == 3 + + +def test_sdm_monomial_lcm(): + assert sdm_monomial_lcm((1, 2, 3), (1, 5, 0)) == (1, 5, 3) + + +def test_sdm_monomial_divides(): + assert sdm_monomial_divides((1, 0, 0), (1, 0, 0)) is True + assert sdm_monomial_divides((1, 0, 0), (1, 2, 1)) is True + assert sdm_monomial_divides((5, 1, 1), (5, 2, 1)) is True + + assert sdm_monomial_divides((1, 0, 0), (2, 0, 0)) is False + assert sdm_monomial_divides((1, 1, 0), (1, 0, 0)) is False + assert sdm_monomial_divides((5, 1, 2), (5, 0, 1)) is False + + +def test_sdm_LC(): + assert sdm_LC([((1, 2, 3), QQ(5))], QQ) == QQ(5) + + +def test_sdm_from_dict(): + dic = {(1, 2, 1, 1): QQ(1), (1, 1, 2, 1): QQ(1), (1, 0, 2, 1): QQ(1), + (1, 0, 0, 3): QQ(1), (1, 1, 1, 0): QQ(1)} + assert sdm_from_dict(dic, grlex) == \ + [((1, 2, 1, 1), QQ(1)), ((1, 1, 2, 1), QQ(1)), + ((1, 0, 2, 1), QQ(1)), ((1, 0, 0, 3), QQ(1)), ((1, 1, 1, 0), QQ(1))] + +# TODO test to_dict? + + +def test_sdm_add(): + assert sdm_add([((1, 1, 1), QQ(1))], [((2, 0, 0), QQ(1))], lex, QQ) == \ + [((2, 0, 0), QQ(1)), ((1, 1, 1), QQ(1))] + assert sdm_add([((1, 1, 1), QQ(1))], [((1, 1, 1), QQ(-1))], lex, QQ) == [] + assert sdm_add([((1, 0, 0), QQ(1))], [((1, 0, 0), QQ(2))], lex, QQ) == \ + [((1, 0, 0), QQ(3))] + assert sdm_add([((1, 0, 1), QQ(1))], [((1, 1, 0), QQ(1))], lex, QQ) == \ + [((1, 1, 0), QQ(1)), ((1, 0, 1), QQ(1))] + + +def test_sdm_LM(): + dic = {(1, 2, 3): QQ(1), (4, 0, 0): QQ(1), (4, 0, 1): QQ(1)} + assert sdm_LM(sdm_from_dict(dic, lex)) == (4, 0, 1) + + +def test_sdm_LT(): + dic = {(1, 2, 3): QQ(1), (4, 0, 0): QQ(2), (4, 0, 1): QQ(3)} + assert sdm_LT(sdm_from_dict(dic, lex)) == ((4, 0, 1), QQ(3)) + + +def test_sdm_mul_term(): + assert sdm_mul_term([((1, 0, 0), QQ(1))], ((0, 0), QQ(0)), lex, QQ) == [] + assert sdm_mul_term([], ((1, 0), QQ(1)), lex, QQ) == [] + assert sdm_mul_term([((1, 0, 0), QQ(1))], ((1, 0), QQ(1)), lex, QQ) == \ + [((1, 1, 0), QQ(1))] + f = [((2, 0, 1), QQ(4)), ((1, 1, 0), QQ(3))] + assert sdm_mul_term(f, ((1, 1), QQ(2)), lex, QQ) == \ + [((2, 1, 2), QQ(8)), ((1, 2, 1), QQ(6))] + + +def test_sdm_zero(): + assert sdm_zero() == [] + + +def test_sdm_deg(): + assert sdm_deg([((1, 2, 3), 1), ((10, 0, 1), 1), ((2, 3, 4), 4)]) == 7 + + +def test_sdm_spoly(): + f = [((2, 1, 1), QQ(1)), ((1, 0, 1), QQ(1))] + g = [((2, 3, 0), QQ(1))] + h = [((1, 2, 3), QQ(1))] + assert sdm_spoly(f, h, lex, QQ) == [] + assert sdm_spoly(f, g, lex, QQ) == [((1, 2, 1), QQ(1))] + + +def test_sdm_ecart(): + assert sdm_ecart([((1, 2, 3), 1), ((1, 0, 1), 1)]) == 0 + assert sdm_ecart([((2, 2, 1), 1), ((1, 5, 1), 1)]) == 3 + + +def test_sdm_nf_mora(): + f = sdm_from_dict({(1, 2, 1, 1): QQ(1), (1, 1, 2, 1): QQ(1), + (1, 0, 2, 1): QQ(1), (1, 0, 0, 3): QQ(1), (1, 1, 1, 0): QQ(1)}, + grlex) + f1 = sdm_from_dict({(1, 1, 1, 0): QQ(1), (1, 0, 2, 0): QQ(1), + (1, 0, 0, 0): QQ(-1)}, grlex) + f2 = sdm_from_dict({(1, 1, 1, 0): QQ(1)}, grlex) + (id0, id1, id2) = [sdm_from_dict({(i, 0, 0, 0): QQ(1)}, grlex) + for i in range(3)] + + assert sdm_nf_mora(f, [f1, f2], grlex, QQ, phantom=(id0, [id1, id2])) == \ + ([((1, 0, 2, 1), QQ(1)), ((1, 0, 0, 3), QQ(1)), ((1, 1, 1, 0), QQ(1)), + ((1, 1, 0, 1), QQ(1))], + [((1, 1, 0, 1), QQ(-1)), ((0, 0, 0, 0), QQ(1))]) + assert sdm_nf_mora(f, [f2, f1], grlex, QQ, phantom=(id0, [id2, id1])) == \ + ([((1, 0, 2, 1), QQ(1)), ((1, 0, 0, 3), QQ(1)), ((1, 1, 1, 0), QQ(1))], + [((2, 1, 0, 1), QQ(-1)), ((2, 0, 1, 1), QQ(-1)), ((0, 0, 0, 0), QQ(1))]) + + f = sdm_from_vector([x*z, y**2 + y*z - z, y], lex, QQ, gens=[x, y, z]) + f1 = sdm_from_vector([x, y, 1], lex, QQ, gens=[x, y, z]) + f2 = sdm_from_vector([x*y, z, z**2], lex, QQ, gens=[x, y, z]) + assert sdm_nf_mora(f, [f1, f2], lex, QQ) == \ + sdm_nf_mora(f, [f2, f1], lex, QQ) == \ + [((1, 0, 1, 1), QQ(1)), ((1, 0, 0, 1), QQ(-1)), ((0, 1, 1, 0), QQ(-1)), + ((0, 1, 0, 1), QQ(1))] + + +def test_conversion(): + f = [x**2 + y**2, 2*z] + g = [((1, 0, 0, 1), QQ(2)), ((0, 2, 0, 0), QQ(1)), ((0, 0, 2, 0), QQ(1))] + assert sdm_to_vector(g, [x, y, z], QQ) == f + assert sdm_from_vector(f, lex, QQ) == g + assert sdm_from_vector( + [x, 1], lex, QQ) == [((1, 0), QQ(1)), ((0, 1), QQ(1))] + assert sdm_to_vector([((1, 1, 0, 0), 1)], [x, y, z], QQ, n=3) == [0, x, 0] + assert sdm_from_vector([0, 0], lex, QQ, gens=[x, y]) == sdm_zero() + + +def test_nontrivial(): + gens = [x, y, z] + + def contains(I, f): + S = [sdm_from_vector([g], lex, QQ, gens=gens) for g in I] + G = sdm_groebner(S, sdm_nf_mora, lex, QQ) + return sdm_nf_mora(sdm_from_vector([f], lex, QQ, gens=gens), + G, lex, QQ) == sdm_zero() + + assert contains([x, y], x) + assert contains([x, y], x + y) + assert not contains([x, y], 1) + assert not contains([x, y], z) + assert contains([x**2 + y, x**2 + x], x - y) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x**2) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**3) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**4) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x*y**2) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x**4 + y**3 + 2*z*y*x) + assert contains([x + y + z, x*y + x*z + y*z, x*y*z], x*y*z) + assert contains([x, 1 + x + y, 5 - 7*y], 1) + assert contains( + [x**3 + y**3, y**3 + z**3, z**3 + x**3, x**2*y + x**2*z + y**2*z], + x**3) + assert not contains( + [x**3 + y**3, y**3 + z**3, z**3 + x**3, x**2*y + x**2*z + y**2*z], + x**2 + y**2) + + # compare local order + assert not contains([x*(1 + x + y), y*(1 + z)], x) + assert not contains([x*(1 + x + y), y*(1 + z)], x + y) + + +def test_local(): + igrlex = InverseOrder(grlex) + gens = [x, y, z] + + def contains(I, f): + S = [sdm_from_vector([g], igrlex, QQ, gens=gens) for g in I] + G = sdm_groebner(S, sdm_nf_mora, igrlex, QQ) + return sdm_nf_mora(sdm_from_vector([f], lex, QQ, gens=gens), + G, lex, QQ) == sdm_zero() + assert contains([x, y], x) + assert contains([x, y], x + y) + assert not contains([x, y], 1) + assert not contains([x, y], z) + assert contains([x**2 + y, x**2 + x], x - y) + assert not contains([x + y + z, x*y + x*z + y*z, x*y*z], x**2) + assert contains([x*(1 + x + y), y*(1 + z)], x) + assert contains([x*(1 + x + y), y*(1 + z)], x + y) + + +def test_uncovered_line(): + gens = [x, y] + f1 = sdm_zero() + f2 = sdm_from_vector([x, 0], lex, QQ, gens=gens) + f3 = sdm_from_vector([0, y], lex, QQ, gens=gens) + + assert sdm_spoly(f1, f2, lex, QQ) == sdm_zero() + assert sdm_spoly(f3, f2, lex, QQ) == sdm_zero() + + +def test_chain_criterion(): + gens = [x] + f1 = sdm_from_vector([1, x], grlex, QQ, gens=gens) + f2 = sdm_from_vector([0, x - 2], grlex, QQ, gens=gens) + assert len(sdm_groebner([f1, f2], sdm_nf_mora, grlex, QQ)) == 2 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_euclidtools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_euclidtools.py new file mode 100644 index 0000000000000000000000000000000000000000..3061be73f987163951a5836ff50125d29abc60c7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_euclidtools.py @@ -0,0 +1,712 @@ +"""Tests for Euclidean algorithms, GCDs, LCMs and polynomial remainder sequences. """ + +from sympy.polys.rings import ring +from sympy.polys.domains import ZZ, QQ, RR + +from sympy.polys.specialpolys import ( + f_polys, + dmp_fateman_poly_F_1, + dmp_fateman_poly_F_2, + dmp_fateman_poly_F_3) + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = f_polys() + +def test_dup_gcdex(): + R, x = ring("x", QQ) + + f = x**4 - 2*x**3 - 6*x**2 + 12*x + 15 + g = x**3 + x**2 - 4*x - 4 + + s = -QQ(1,5)*x + QQ(3,5) + t = QQ(1,5)*x**2 - QQ(6,5)*x + 2 + h = x + 1 + + assert R.dup_half_gcdex(f, g) == (s, h) + assert R.dup_gcdex(f, g) == (s, t, h) + + f = x**4 + 4*x**3 - x + 1 + g = x**3 - x + 1 + + s, t, h = R.dup_gcdex(f, g) + S, T, H = R.dup_gcdex(g, f) + + assert R.dup_add(R.dup_mul(s, f), + R.dup_mul(t, g)) == h + assert R.dup_add(R.dup_mul(S, g), + R.dup_mul(T, f)) == H + + f = 2*x + g = x**2 - 16 + + s = QQ(1,32)*x + t = -QQ(1,16) + h = 1 + + assert R.dup_half_gcdex(f, g) == (s, h) + assert R.dup_gcdex(f, g) == (s, t, h) + + +def test_dup_invert(): + R, x = ring("x", QQ) + assert R.dup_invert(2*x, x**2 - 16) == QQ(1,32)*x + + +def test_dup_euclidean_prs(): + R, x = ring("x", QQ) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + assert R.dup_euclidean_prs(f, g) == [ + f, + g, + -QQ(5,9)*x**4 + QQ(1,9)*x**2 - QQ(1,3), + -QQ(117,25)*x**2 - 9*x + QQ(441,25), + QQ(233150,19773)*x - QQ(102500,6591), + -QQ(1288744821,543589225)] + + +def test_dup_primitive_prs(): + R, x = ring("x", ZZ) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + assert R.dup_primitive_prs(f, g) == [ + f, + g, + -5*x**4 + x**2 - 3, + 13*x**2 + 25*x - 49, + 4663*x - 6150, + 1] + + +def test_dup_subresultants(): + R, x = ring("x", ZZ) + + assert R.dup_resultant(0, 0) == 0 + + assert R.dup_resultant(1, 0) == 0 + assert R.dup_resultant(0, 1) == 0 + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + a = 15*x**4 - 3*x**2 + 9 + b = 65*x**2 + 125*x - 245 + c = 9326*x - 12300 + d = 260708 + + assert R.dup_subresultants(f, g) == [f, g, a, b, c, d] + assert R.dup_resultant(f, g) == R.dup_LC(d) + + f = x**2 - 2*x + 1 + g = x**2 - 1 + + a = 2*x - 2 + + assert R.dup_subresultants(f, g) == [f, g, a] + assert R.dup_resultant(f, g) == 0 + + f = x**2 + 1 + g = x**2 - 1 + + a = -2 + + assert R.dup_subresultants(f, g) == [f, g, a] + assert R.dup_resultant(f, g) == 4 + + f = x**2 - 1 + g = x**3 - x**2 + 2 + + assert R.dup_resultant(f, g) == 0 + + f = 3*x**3 - x + g = 5*x**2 + 1 + + assert R.dup_resultant(f, g) == 64 + + f = x**2 - 2*x + 7 + g = x**3 - x + 5 + + assert R.dup_resultant(f, g) == 265 + + f = x**3 - 6*x**2 + 11*x - 6 + g = x**3 - 15*x**2 + 74*x - 120 + + assert R.dup_resultant(f, g) == -8640 + + f = x**3 - 6*x**2 + 11*x - 6 + g = x**3 - 10*x**2 + 29*x - 20 + + assert R.dup_resultant(f, g) == 0 + + f = x**3 - 1 + g = x**3 + 2*x**2 + 2*x - 1 + + assert R.dup_resultant(f, g) == 16 + + f = x**8 - 2 + g = x - 1 + + assert R.dup_resultant(f, g) == -1 + + +def test_dmp_subresultants(): + R, x, y = ring("x,y", ZZ) + + assert R.dmp_resultant(0, 0) == 0 + assert R.dmp_prs_resultant(0, 0)[0] == 0 + assert R.dmp_zz_collins_resultant(0, 0) == 0 + assert R.dmp_qq_collins_resultant(0, 0) == 0 + + assert R.dmp_resultant(1, 0) == 0 + assert R.dmp_resultant(1, 0) == 0 + assert R.dmp_resultant(1, 0) == 0 + + assert R.dmp_resultant(0, 1) == 0 + assert R.dmp_prs_resultant(0, 1)[0] == 0 + assert R.dmp_zz_collins_resultant(0, 1) == 0 + assert R.dmp_qq_collins_resultant(0, 1) == 0 + + f = 3*x**2*y - y**3 - 4 + g = x**2 + x*y**3 - 9 + + a = 3*x*y**4 + y**3 - 27*y + 4 + b = -3*y**10 - 12*y**7 + y**6 - 54*y**4 + 8*y**3 + 729*y**2 - 216*y + 16 + + r = R.dmp_LC(b) + + assert R.dmp_subresultants(f, g) == [f, g, a, b] + + assert R.dmp_resultant(f, g) == r + assert R.dmp_prs_resultant(f, g)[0] == r + assert R.dmp_zz_collins_resultant(f, g) == r + assert R.dmp_qq_collins_resultant(f, g) == r + + f = -x**3 + 5 + g = 3*x**2*y + x**2 + + a = 45*y**2 + 30*y + 5 + b = 675*y**3 + 675*y**2 + 225*y + 25 + + r = R.dmp_LC(b) + + assert R.dmp_subresultants(f, g) == [f, g, a] + assert R.dmp_resultant(f, g) == r + assert R.dmp_prs_resultant(f, g)[0] == r + assert R.dmp_zz_collins_resultant(f, g) == r + assert R.dmp_qq_collins_resultant(f, g) == r + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + + f = 6*x**2 - 3*x*y - 2*x*z + y*z + g = x**2 - x*u - x*v + u*v + + r = y**2*z**2 - 3*y**2*z*u - 3*y**2*z*v + 9*y**2*u*v - 2*y*z**2*u \ + - 2*y*z**2*v + 6*y*z*u**2 + 12*y*z*u*v + 6*y*z*v**2 - 18*y*u**2*v \ + - 18*y*u*v**2 + 4*z**2*u*v - 12*z*u**2*v - 12*z*u*v**2 + 36*u**2*v**2 + + assert R.dmp_zz_collins_resultant(f, g) == r.drop(x) + + R, x, y, z, u, v = ring("x,y,z,u,v", QQ) + + f = x**2 - QQ(1,2)*x*y - QQ(1,3)*x*z + QQ(1,6)*y*z + g = x**2 - x*u - x*v + u*v + + r = QQ(1,36)*y**2*z**2 - QQ(1,12)*y**2*z*u - QQ(1,12)*y**2*z*v + QQ(1,4)*y**2*u*v \ + - QQ(1,18)*y*z**2*u - QQ(1,18)*y*z**2*v + QQ(1,6)*y*z*u**2 + QQ(1,3)*y*z*u*v \ + + QQ(1,6)*y*z*v**2 - QQ(1,2)*y*u**2*v - QQ(1,2)*y*u*v**2 + QQ(1,9)*z**2*u*v \ + - QQ(1,3)*z*u**2*v - QQ(1,3)*z*u*v**2 + u**2*v**2 + + assert R.dmp_qq_collins_resultant(f, g) == r.drop(x) + + Rt, t = ring("t", ZZ) + Rx, x = ring("x", Rt) + + f = x**6 - 5*x**4 + 5*x**2 + 4 + g = -6*t*x**5 + x**4 + 20*t*x**3 - 3*x**2 - 10*t*x + 6 + + assert Rx.dup_resultant(f, g) == 2930944*t**6 + 2198208*t**4 + 549552*t**2 + 45796 + + +def test_dup_discriminant(): + R, x = ring("x", ZZ) + + assert R.dup_discriminant(0) == 0 + assert R.dup_discriminant(x) == 1 + + assert R.dup_discriminant(x**3 + 3*x**2 + 9*x - 13) == -11664 + assert R.dup_discriminant(5*x**5 + x**3 + 2) == 31252160 + assert R.dup_discriminant(x**4 + 2*x**3 + 6*x**2 - 22*x + 13) == 0 + assert R.dup_discriminant(12*x**7 + 15*x**4 + 30*x**3 + x**2 + 1) == -220289699947514112 + + +def test_dmp_discriminant(): + R, x = ring("x", ZZ) + + assert R.dmp_discriminant(0) == 0 + + R, x, y = ring("x,y", ZZ) + + assert R.dmp_discriminant(0) == 0 + assert R.dmp_discriminant(y) == 0 + + assert R.dmp_discriminant(x**3 + 3*x**2 + 9*x - 13) == -11664 + assert R.dmp_discriminant(5*x**5 + x**3 + 2) == 31252160 + assert R.dmp_discriminant(x**4 + 2*x**3 + 6*x**2 - 22*x + 13) == 0 + assert R.dmp_discriminant(12*x**7 + 15*x**4 + 30*x**3 + x**2 + 1) == -220289699947514112 + + assert R.dmp_discriminant(x**2*y + 2*y) == (-8*y**2).drop(x) + assert R.dmp_discriminant(x*y**2 + 2*x) == 1 + + R, x, y, z = ring("x,y,z", ZZ) + assert R.dmp_discriminant(x*y + z) == 1 + + R, x, y, z, u = ring("x,y,z,u", ZZ) + assert R.dmp_discriminant(x**2*y + x*z + u) == (-4*y*u + z**2).drop(x) + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + assert R.dmp_discriminant(x**3*y + x**2*z + x*u + v) == \ + (-27*y**2*v**2 + 18*y*z*u*v - 4*y*u**3 - 4*z**3*v + z**2*u**2).drop(x) + + +def test_dup_gcd(): + R, x = ring("x", ZZ) + + f, g = 0, 0 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (0, 0, 0) + + f, g = 2, 0 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, 1, 0) + + f, g = -2, 0 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, -1, 0) + + f, g = 0, -2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, 0, -1) + + f, g = 0, 2*x + 4 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2*x + 4, 0, 1) + + f, g = 2*x + 4, 0 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2*x + 4, 1, 0) + + f, g = 2, 2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, 1, 1) + + f, g = -2, 2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, -1, 1) + + f, g = 2, -2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, 1, -1) + + f, g = -2, -2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, -1, -1) + + f, g = x**2 + 2*x + 1, 1 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (1, x**2 + 2*x + 1, 1) + + f, g = x**2 + 2*x + 1, 2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (1, x**2 + 2*x + 1, 2) + + f, g = 2*x**2 + 4*x + 2, 2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, x**2 + 2*x + 1, 1) + + f, g = 2, 2*x**2 + 4*x + 2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (2, 1, x**2 + 2*x + 1) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (x + 1, 1, 2*x + 2) + + f, g = x - 31, x + assert R.dup_zz_heu_gcd(f, g) == R.dup_rr_prs_gcd(f, g) == (1, f, g) + + f = x**4 + 8*x**3 + 21*x**2 + 22*x + 8 + g = x**3 + 6*x**2 + 11*x + 6 + + h = x**2 + 3*x + 2 + + cff = x**2 + 5*x + 4 + cfg = x + 3 + + assert R.dup_zz_heu_gcd(f, g) == (h, cff, cfg) + assert R.dup_rr_prs_gcd(f, g) == (h, cff, cfg) + + f = x**4 - 4 + g = x**4 + 4*x**2 + 4 + + h = x**2 + 2 + + cff = x**2 - 2 + cfg = x**2 + 2 + + assert R.dup_zz_heu_gcd(f, g) == (h, cff, cfg) + assert R.dup_rr_prs_gcd(f, g) == (h, cff, cfg) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + h = 1 + + cff = f + cfg = g + + assert R.dup_zz_heu_gcd(f, g) == (h, cff, cfg) + assert R.dup_rr_prs_gcd(f, g) == (h, cff, cfg) + + R, x = ring("x", QQ) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + h = 1 + + cff = f + cfg = g + + assert R.dup_qq_heu_gcd(f, g) == (h, cff, cfg) + assert R.dup_ff_prs_gcd(f, g) == (h, cff, cfg) + + R, x = ring("x", ZZ) + + f = - 352518131239247345597970242177235495263669787845475025293906825864749649589178600387510272*x**49 \ + + 46818041807522713962450042363465092040687472354933295397472942006618953623327997952*x**42 \ + + 378182690892293941192071663536490788434899030680411695933646320291525827756032*x**35 \ + + 112806468807371824947796775491032386836656074179286744191026149539708928*x**28 \ + - 12278371209708240950316872681744825481125965781519138077173235712*x**21 \ + + 289127344604779611146960547954288113529690984687482920704*x**14 \ + + 19007977035740498977629742919480623972236450681*x**7 \ + + 311973482284542371301330321821976049 + + g = 365431878023781158602430064717380211405897160759702125019136*x**21 \ + + 197599133478719444145775798221171663643171734081650688*x**14 \ + - 9504116979659010018253915765478924103928886144*x**7 \ + - 311973482284542371301330321821976049 + + assert R.dup_zz_heu_gcd(f, R.dup_diff(f, 1))[0] == g + assert R.dup_rr_prs_gcd(f, R.dup_diff(f, 1))[0] == g + + R, x = ring("x", QQ) + + f = QQ(1,2)*x**2 + x + QQ(1,2) + g = QQ(1,2)*x + QQ(1,2) + + h = x + 1 + + assert R.dup_qq_heu_gcd(f, g) == (h, g, QQ(1,2)) + assert R.dup_ff_prs_gcd(f, g) == (h, g, QQ(1,2)) + + R, x = ring("x", ZZ) + + f = 1317378933230047068160*x + 2945748836994210856960 + g = 120352542776360960*x + 269116466014453760 + + h = 120352542776360960*x + 269116466014453760 + cff = 10946 + cfg = 1 + + assert R.dup_zz_heu_gcd(f, g) == (h, cff, cfg) + + +def test_dmp_gcd(): + R, x, y = ring("x,y", ZZ) + + f, g = 0, 0 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (0, 0, 0) + + f, g = 2, 0 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, 1, 0) + + f, g = -2, 0 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, -1, 0) + + f, g = 0, -2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, 0, -1) + + f, g = 0, 2*x + 4 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2*x + 4, 0, 1) + + f, g = 2*x + 4, 0 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2*x + 4, 1, 0) + + f, g = 2, 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, 1, 1) + + f, g = -2, 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, -1, 1) + + f, g = 2, -2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, 1, -1) + + f, g = -2, -2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, -1, -1) + + f, g = x**2 + 2*x + 1, 1 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (1, x**2 + 2*x + 1, 1) + + f, g = x**2 + 2*x + 1, 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (1, x**2 + 2*x + 1, 2) + + f, g = 2*x**2 + 4*x + 2, 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, x**2 + 2*x + 1, 1) + + f, g = 2, 2*x**2 + 4*x + 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (2, 1, x**2 + 2*x + 1) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (x + 1, 1, 2*x + 2) + + R, x, y, z, u = ring("x,y,z,u", ZZ) + + f, g = u**2 + 2*u + 1, 2*u + 2 + assert R.dmp_zz_heu_gcd(f, g) == R.dmp_rr_prs_gcd(f, g) == (u + 1, u + 1, 2) + + f, g = z**2*u**2 + 2*z**2*u + z**2 + z*u + z, u**2 + 2*u + 1 + h, cff, cfg = u + 1, z**2*u + z**2 + z, u + 1 + + assert R.dmp_zz_heu_gcd(f, g) == (h, cff, cfg) + assert R.dmp_rr_prs_gcd(f, g) == (h, cff, cfg) + + assert R.dmp_zz_heu_gcd(g, f) == (h, cfg, cff) + assert R.dmp_rr_prs_gcd(g, f) == (h, cfg, cff) + + R, x, y, z = ring("x,y,z", ZZ) + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_1(2, ZZ)) + H, cff, cfg = R.dmp_zz_heu_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + H, cff, cfg = R.dmp_rr_prs_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_1(4, ZZ)) + H, cff, cfg = R.dmp_zz_heu_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + R, x, y, z, u, v, a, b = ring("x,y,z,u,v,a,b", ZZ) + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_1(6, ZZ)) + H, cff, cfg = R.dmp_zz_heu_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + R, x, y, z, u, v, a, b, c, d = ring("x,y,z,u,v,a,b,c,d", ZZ) + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_1(8, ZZ)) + H, cff, cfg = R.dmp_zz_heu_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + R, x, y, z = ring("x,y,z", ZZ) + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_2(2, ZZ)) + H, cff, cfg = R.dmp_zz_heu_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + H, cff, cfg = R.dmp_rr_prs_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_3(2, ZZ)) + H, cff, cfg = R.dmp_zz_heu_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + H, cff, cfg = R.dmp_rr_prs_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + + f, g, h = map(R.from_dense, dmp_fateman_poly_F_3(4, ZZ)) + H, cff, cfg = R.dmp_inner_gcd(f, g) + + assert H == h and R.dmp_mul(H, cff) == f \ + and R.dmp_mul(H, cfg) == g + + R, x, y = ring("x,y", QQ) + + f = QQ(1,2)*x**2 + x + QQ(1,2) + g = QQ(1,2)*x + QQ(1,2) + + h = x + 1 + + assert R.dmp_qq_heu_gcd(f, g) == (h, g, QQ(1,2)) + assert R.dmp_ff_prs_gcd(f, g) == (h, g, QQ(1,2)) + + R, x, y = ring("x,y", RR) + + f = 2.1*x*y**2 - 2.2*x*y + 2.1*x + g = 1.0*x**3 + + assert R.dmp_ff_prs_gcd(f, g) == \ + (1.0*x, 2.1*y**2 - 2.2*y + 2.1, 1.0*x**2) + + +def test_dup_lcm(): + R, x = ring("x", ZZ) + + assert R.dup_lcm(2, 6) == 6 + + assert R.dup_lcm(2*x**3, 6*x) == 6*x**3 + assert R.dup_lcm(2*x**3, 3*x) == 6*x**3 + + assert R.dup_lcm(x**2 + x, x) == x**2 + x + assert R.dup_lcm(x**2 + x, 2*x) == 2*x**2 + 2*x + assert R.dup_lcm(x**2 + 2*x, x) == x**2 + 2*x + assert R.dup_lcm(2*x**2 + x, x) == 2*x**2 + x + assert R.dup_lcm(2*x**2 + x, 2*x) == 4*x**2 + 2*x + + +def test_dmp_lcm(): + R, x, y = ring("x,y", ZZ) + + assert R.dmp_lcm(2, 6) == 6 + assert R.dmp_lcm(x, y) == x*y + + assert R.dmp_lcm(2*x**3, 6*x*y**2) == 6*x**3*y**2 + assert R.dmp_lcm(2*x**3, 3*x*y**2) == 6*x**3*y**2 + + assert R.dmp_lcm(x**2*y, x*y**2) == x**2*y**2 + + f = 2*x*y**5 - 3*x*y**4 - 2*x*y**3 + 3*x*y**2 + g = y**5 - 2*y**3 + y + h = 2*x*y**7 - 3*x*y**6 - 4*x*y**5 + 6*x*y**4 + 2*x*y**3 - 3*x*y**2 + + assert R.dmp_lcm(f, g) == h + + f = x**3 - 3*x**2*y - 9*x*y**2 - 5*y**3 + g = x**4 + 6*x**3*y + 12*x**2*y**2 + 10*x*y**3 + 3*y**4 + h = x**5 + x**4*y - 18*x**3*y**2 - 50*x**2*y**3 - 47*x*y**4 - 15*y**5 + + assert R.dmp_lcm(f, g) == h + + +def test_dmp_content(): + R, x,y = ring("x,y", ZZ) + + assert R.dmp_content(-2) == 2 + + f, g, F = 3*y**2 + 2*y + 1, 1, 0 + + for i in range(0, 5): + g *= f + F += x**i*g + + assert R.dmp_content(F) == f.drop(x) + + R, x,y,z = ring("x,y,z", ZZ) + + assert R.dmp_content(f_4) == 1 + assert R.dmp_content(f_5) == 1 + + R, x,y,z,t = ring("x,y,z,t", ZZ) + assert R.dmp_content(f_6) == 1 + + +def test_dmp_primitive(): + R, x,y = ring("x,y", ZZ) + + assert R.dmp_primitive(0) == (0, 0) + assert R.dmp_primitive(1) == (1, 1) + + f, g, F = 3*y**2 + 2*y + 1, 1, 0 + + for i in range(0, 5): + g *= f + F += x**i*g + + assert R.dmp_primitive(F) == (f.drop(x), F / f) + + R, x,y,z = ring("x,y,z", ZZ) + + cont, f = R.dmp_primitive(f_4) + assert cont == 1 and f == f_4 + cont, f = R.dmp_primitive(f_5) + assert cont == 1 and f == f_5 + + R, x,y,z,t = ring("x,y,z,t", ZZ) + + cont, f = R.dmp_primitive(f_6) + assert cont == 1 and f == f_6 + + +def test_dup_cancel(): + R, x = ring("x", ZZ) + + f = 2*x**2 - 2 + g = x**2 - 2*x + 1 + + p = 2*x + 2 + q = x - 1 + + assert R.dup_cancel(f, g) == (p, q) + assert R.dup_cancel(f, g, include=False) == (1, 1, p, q) + + f = -x - 2 + g = 3*x - 4 + + F = x + 2 + G = -3*x + 4 + + assert R.dup_cancel(f, g) == (f, g) + assert R.dup_cancel(F, G) == (f, g) + + assert R.dup_cancel(0, 0) == (0, 0) + assert R.dup_cancel(0, 0, include=False) == (1, 1, 0, 0) + + assert R.dup_cancel(x, 0) == (1, 0) + assert R.dup_cancel(x, 0, include=False) == (1, 1, 1, 0) + + assert R.dup_cancel(0, x) == (0, 1) + assert R.dup_cancel(0, x, include=False) == (1, 1, 0, 1) + + f = 0 + g = x + one = 1 + + assert R.dup_cancel(f, g, include=True) == (f, one) + + +def test_dmp_cancel(): + R, x, y = ring("x,y", ZZ) + + f = 2*x**2 - 2 + g = x**2 - 2*x + 1 + + p = 2*x + 2 + q = x - 1 + + assert R.dmp_cancel(f, g) == (p, q) + assert R.dmp_cancel(f, g, include=False) == (1, 1, p, q) + + assert R.dmp_cancel(0, 0) == (0, 0) + assert R.dmp_cancel(0, 0, include=False) == (1, 1, 0, 0) + + assert R.dmp_cancel(y, 0) == (1, 0) + assert R.dmp_cancel(y, 0, include=False) == (1, 1, 1, 0) + + assert R.dmp_cancel(0, y) == (0, 1) + assert R.dmp_cancel(0, y, include=False) == (1, 1, 0, 1) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_factortools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_factortools.py new file mode 100644 index 0000000000000000000000000000000000000000..7f99097c71e9cde761a800b01b149ec5c9896266 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_factortools.py @@ -0,0 +1,784 @@ +"""Tools for polynomial factorization routines in characteristic zero. """ + +from sympy.polys.rings import ring, xring +from sympy.polys.domains import FF, ZZ, QQ, ZZ_I, QQ_I, RR, EX + +from sympy.polys import polyconfig as config +from sympy.polys.polyerrors import DomainError +from sympy.polys.polyclasses import ANP +from sympy.polys.specialpolys import f_polys, w_polys + +from sympy.core.numbers import I +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.ntheory.generate import nextprime +from sympy.testing.pytest import raises, XFAIL + + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = f_polys() +w_1, w_2 = w_polys() + +def test_dup_trial_division(): + R, x = ring("x", ZZ) + assert R.dup_trial_division(x**5 + 8*x**4 + 25*x**3 + 38*x**2 + 28*x + 8, (x + 1, x + 2)) == [(x + 1, 2), (x + 2, 3)] + + +def test_dmp_trial_division(): + R, x, y = ring("x,y", ZZ) + assert R.dmp_trial_division(x**5 + 8*x**4 + 25*x**3 + 38*x**2 + 28*x + 8, (x + 1, x + 2)) == [(x + 1, 2), (x + 2, 3)] + + +def test_dup_zz_mignotte_bound(): + R, x = ring("x", ZZ) + assert R.dup_zz_mignotte_bound(2*x**2 + 3*x + 4) == 6 + assert R.dup_zz_mignotte_bound(x**3 + 14*x**2 + 56*x + 64) == 152 + + +def test_dmp_zz_mignotte_bound(): + R, x, y = ring("x,y", ZZ) + assert R.dmp_zz_mignotte_bound(2*x**2 + 3*x + 4) == 32 + + +def test_dup_zz_hensel_step(): + R, x = ring("x", ZZ) + + f = x**4 - 1 + g = x**3 + 2*x**2 - x - 2 + h = x - 2 + s = -2 + t = 2*x**2 - 2*x - 1 + + G, H, S, T = R.dup_zz_hensel_step(5, f, g, h, s, t) + + assert G == x**3 + 7*x**2 - x - 7 + assert H == x - 7 + assert S == 8 + assert T == -8*x**2 - 12*x - 1 + + +def test_dup_zz_hensel_lift(): + R, x = ring("x", ZZ) + + f = x**4 - 1 + F = [x - 1, x - 2, x + 2, x + 1] + + assert R.dup_zz_hensel_lift(ZZ(5), f, F, 4) == \ + [x - 1, x - 182, x + 182, x + 1] + + +def test_dup_zz_irreducible_p(): + R, x = ring("x", ZZ) + + assert R.dup_zz_irreducible_p(3*x**4 + 2*x**3 + 6*x**2 + 8*x + 7) is None + assert R.dup_zz_irreducible_p(3*x**4 + 2*x**3 + 6*x**2 + 8*x + 4) is None + + assert R.dup_zz_irreducible_p(3*x**4 + 2*x**3 + 6*x**2 + 8*x + 10) is True + assert R.dup_zz_irreducible_p(3*x**4 + 2*x**3 + 6*x**2 + 8*x + 14) is True + + +def test_dup_cyclotomic_p(): + R, x = ring("x", ZZ) + + assert R.dup_cyclotomic_p(x - 1) is True + assert R.dup_cyclotomic_p(x + 1) is True + assert R.dup_cyclotomic_p(x**2 + x + 1) is True + assert R.dup_cyclotomic_p(x**2 + 1) is True + assert R.dup_cyclotomic_p(x**4 + x**3 + x**2 + x + 1) is True + assert R.dup_cyclotomic_p(x**2 - x + 1) is True + assert R.dup_cyclotomic_p(x**6 + x**5 + x**4 + x**3 + x**2 + x + 1) is True + assert R.dup_cyclotomic_p(x**4 + 1) is True + assert R.dup_cyclotomic_p(x**6 + x**3 + 1) is True + + assert R.dup_cyclotomic_p(0) is False + assert R.dup_cyclotomic_p(1) is False + assert R.dup_cyclotomic_p(x) is False + assert R.dup_cyclotomic_p(x + 2) is False + assert R.dup_cyclotomic_p(3*x + 1) is False + assert R.dup_cyclotomic_p(x**2 - 1) is False + + f = x**16 + x**14 - x**10 + x**8 - x**6 + x**2 + 1 + assert R.dup_cyclotomic_p(f) is False + + g = x**16 + x**14 - x**10 - x**8 - x**6 + x**2 + 1 + assert R.dup_cyclotomic_p(g) is True + + R, x = ring("x", QQ) + assert R.dup_cyclotomic_p(x**2 + x + 1) is True + assert R.dup_cyclotomic_p(QQ(1,2)*x**2 + x + 1) is False + + R, x = ring("x", ZZ["y"]) + assert R.dup_cyclotomic_p(x**2 + x + 1) is False + + +def test_dup_zz_cyclotomic_poly(): + R, x = ring("x", ZZ) + + assert R.dup_zz_cyclotomic_poly(1) == x - 1 + assert R.dup_zz_cyclotomic_poly(2) == x + 1 + assert R.dup_zz_cyclotomic_poly(3) == x**2 + x + 1 + assert R.dup_zz_cyclotomic_poly(4) == x**2 + 1 + assert R.dup_zz_cyclotomic_poly(5) == x**4 + x**3 + x**2 + x + 1 + assert R.dup_zz_cyclotomic_poly(6) == x**2 - x + 1 + assert R.dup_zz_cyclotomic_poly(7) == x**6 + x**5 + x**4 + x**3 + x**2 + x + 1 + assert R.dup_zz_cyclotomic_poly(8) == x**4 + 1 + assert R.dup_zz_cyclotomic_poly(9) == x**6 + x**3 + 1 + + +def test_dup_zz_cyclotomic_factor(): + R, x = ring("x", ZZ) + + assert R.dup_zz_cyclotomic_factor(0) is None + assert R.dup_zz_cyclotomic_factor(1) is None + + assert R.dup_zz_cyclotomic_factor(2*x**10 - 1) is None + assert R.dup_zz_cyclotomic_factor(x**10 - 3) is None + assert R.dup_zz_cyclotomic_factor(x**10 + x**5 - 1) is None + + assert R.dup_zz_cyclotomic_factor(x + 1) == [x + 1] + assert R.dup_zz_cyclotomic_factor(x - 1) == [x - 1] + + assert R.dup_zz_cyclotomic_factor(x**2 + 1) == [x**2 + 1] + assert R.dup_zz_cyclotomic_factor(x**2 - 1) == [x - 1, x + 1] + + assert R.dup_zz_cyclotomic_factor(x**27 + 1) == \ + [x + 1, x**2 - x + 1, x**6 - x**3 + 1, x**18 - x**9 + 1] + assert R.dup_zz_cyclotomic_factor(x**27 - 1) == \ + [x - 1, x**2 + x + 1, x**6 + x**3 + 1, x**18 + x**9 + 1] + + +def test_dup_zz_factor(): + R, x = ring("x", ZZ) + + assert R.dup_zz_factor(0) == (0, []) + assert R.dup_zz_factor(7) == (7, []) + assert R.dup_zz_factor(-7) == (-7, []) + + assert R.dup_zz_factor_sqf(0) == (0, []) + assert R.dup_zz_factor_sqf(7) == (7, []) + assert R.dup_zz_factor_sqf(-7) == (-7, []) + + assert R.dup_zz_factor(2*x + 4) == (2, [(x + 2, 1)]) + assert R.dup_zz_factor_sqf(2*x + 4) == (2, [x + 2]) + + f = x**4 + x + 1 + + for i in range(0, 20): + assert R.dup_zz_factor(f) == (1, [(f, 1)]) + + assert R.dup_zz_factor(x**2 + 2*x + 2) == \ + (1, [(x**2 + 2*x + 2, 1)]) + + assert R.dup_zz_factor(18*x**2 + 12*x + 2) == \ + (2, [(3*x + 1, 2)]) + + assert R.dup_zz_factor(-9*x**2 + 1) == \ + (-1, [(3*x - 1, 1), + (3*x + 1, 1)]) + + assert R.dup_zz_factor_sqf(-9*x**2 + 1) == \ + (-1, [3*x - 1, + 3*x + 1]) + + # The order of the factors will be different when the ground types are + # flint. At the higher level dup_factor_list will sort the factors. + c, factors = R.dup_zz_factor(x**3 - 6*x**2 + 11*x - 6) + assert c == 1 + assert set(factors) == {(x - 3, 1), (x - 2, 1), (x - 1, 1)} + + assert R.dup_zz_factor_sqf(x**3 - 6*x**2 + 11*x - 6) == \ + (1, [x - 3, + x - 2, + x - 1]) + + assert R.dup_zz_factor(3*x**3 + 10*x**2 + 13*x + 10) == \ + (1, [(x + 2, 1), + (3*x**2 + 4*x + 5, 1)]) + + assert R.dup_zz_factor_sqf(3*x**3 + 10*x**2 + 13*x + 10) == \ + (1, [x + 2, + 3*x**2 + 4*x + 5]) + + c, factors = R.dup_zz_factor(-x**6 + x**2) + assert c == -1 + assert set(factors) == {(x, 2), (x - 1, 1), (x + 1, 1), (x**2 + 1, 1)} + + f = 1080*x**8 + 5184*x**7 + 2099*x**6 + 744*x**5 + 2736*x**4 - 648*x**3 + 129*x**2 - 324 + + assert R.dup_zz_factor(f) == \ + (1, [(5*x**4 + 24*x**3 + 9*x**2 + 12, 1), + (216*x**4 + 31*x**2 - 27, 1)]) + + f = -29802322387695312500000000000000000000*x**25 \ + + 2980232238769531250000000000000000*x**20 \ + + 1743435859680175781250000000000*x**15 \ + + 114142894744873046875000000*x**10 \ + - 210106372833251953125*x**5 \ + + 95367431640625 + + c, factors = R.dup_zz_factor(f) + assert c == -95367431640625 + assert set(factors) == { + (5*x - 1, 1), + (100*x**2 + 10*x - 1, 2), + (625*x**4 + 125*x**3 + 25*x**2 + 5*x + 1, 1), + (10000*x**4 - 3000*x**3 + 400*x**2 - 20*x + 1, 2), + (10000*x**4 + 2000*x**3 + 400*x**2 + 30*x + 1, 2), + } + + f = x**10 - 1 + + config.setup('USE_CYCLOTOMIC_FACTOR', True) + c0, F_0 = R.dup_zz_factor(f) + + config.setup('USE_CYCLOTOMIC_FACTOR', False) + c1, F_1 = R.dup_zz_factor(f) + + assert c0 == c1 == 1 + assert set(F_0) == set(F_1) == { + (x - 1, 1), + (x + 1, 1), + (x**4 - x**3 + x**2 - x + 1, 1), + (x**4 + x**3 + x**2 + x + 1, 1), + } + + config.setup('USE_CYCLOTOMIC_FACTOR') + + f = x**10 + 1 + + config.setup('USE_CYCLOTOMIC_FACTOR', True) + F_0 = R.dup_zz_factor(f) + + config.setup('USE_CYCLOTOMIC_FACTOR', False) + F_1 = R.dup_zz_factor(f) + + assert F_0 == F_1 == \ + (1, [(x**2 + 1, 1), + (x**8 - x**6 + x**4 - x**2 + 1, 1)]) + + config.setup('USE_CYCLOTOMIC_FACTOR') + +def test_dmp_zz_wang(): + R, x,y,z = ring("x,y,z", ZZ) + UV, _x = ring("x", ZZ) + + p = ZZ(nextprime(R.dmp_zz_mignotte_bound(w_1))) + assert p == 6291469 + + t_1, k_1, e_1 = y, 1, ZZ(-14) + t_2, k_2, e_2 = z, 2, ZZ(3) + t_3, k_3, e_3 = y + z, 2, ZZ(-11) + t_4, k_4, e_4 = y - z, 1, ZZ(-17) + + T = [t_1, t_2, t_3, t_4] + K = [k_1, k_2, k_3, k_4] + E = [e_1, e_2, e_3, e_4] + + T = zip([ t.drop(x) for t in T ], K) + + A = [ZZ(-14), ZZ(3)] + + S = R.dmp_eval_tail(w_1, A) + cs, s = UV.dup_primitive(S) + + assert cs == 1 and s == S == \ + 1036728*_x**6 + 915552*_x**5 + 55748*_x**4 + 105621*_x**3 - 17304*_x**2 - 26841*_x - 644 + + assert R.dmp_zz_wang_non_divisors(E, cs, ZZ(4)) == [7, 3, 11, 17] + assert UV.dup_sqf_p(s) and UV.dup_degree(s) == R.dmp_degree(w_1) + + _, H = UV.dup_zz_factor_sqf(s) + + h_1 = 44*_x**2 + 42*_x + 1 + h_2 = 126*_x**2 - 9*_x + 28 + h_3 = 187*_x**2 - 23 + + assert H == [h_1, h_2, h_3] + + LC = [ lc.drop(x) for lc in [-4*y - 4*z, -y*z**2, y**2 - z**2] ] + + assert R.dmp_zz_wang_lead_coeffs(w_1, T, cs, E, H, A) == (w_1, H, LC) + + factors = R.dmp_zz_wang_hensel_lifting(w_1, H, LC, A, p) + assert R.dmp_expand(factors) == w_1 + + +@XFAIL +def test_dmp_zz_wang_fail(): + R, x,y,z = ring("x,y,z", ZZ) + UV, _x = ring("x", ZZ) + + p = ZZ(nextprime(R.dmp_zz_mignotte_bound(w_1))) + assert p == 6291469 + + H_1 = [44*x**2 + 42*x + 1, 126*x**2 - 9*x + 28, 187*x**2 - 23] + H_2 = [-4*x**2*y - 12*x**2 - 3*x*y + 1, -9*x**2*y - 9*x - 2*y, x**2*y**2 - 9*x**2 + y - 9] + H_3 = [-4*x**2*y - 12*x**2 - 3*x*y + 1, -9*x**2*y - 9*x - 2*y, x**2*y**2 - 9*x**2 + y - 9] + + c_1 = -70686*x**5 - 5863*x**4 - 17826*x**3 + 2009*x**2 + 5031*x + 74 + c_2 = 9*x**5*y**4 + 12*x**5*y**3 - 45*x**5*y**2 - 108*x**5*y - 324*x**5 + 18*x**4*y**3 - 216*x**4*y**2 - 810*x**4*y + 2*x**3*y**4 + 9*x**3*y**3 - 252*x**3*y**2 - 288*x**3*y - 945*x**3 - 30*x**2*y**2 - 414*x**2*y + 2*x*y**3 - 54*x*y**2 - 3*x*y + 81*x + 12*y + c_3 = -36*x**4*y**2 - 108*x**4*y - 27*x**3*y**2 - 36*x**3*y - 108*x**3 - 8*x**2*y**2 - 42*x**2*y - 6*x*y**2 + 9*x + 2*y + + assert R.dmp_zz_diophantine(H_1, c_1, [], 5, p) == [-3*x, -2, 1] + assert R.dmp_zz_diophantine(H_2, c_2, [ZZ(-14)], 5, p) == [-x*y, -3*x, -6] + assert R.dmp_zz_diophantine(H_3, c_3, [ZZ(-14)], 5, p) == [0, 0, -1] + + +def test_issue_6355(): + # This tests a bug in the Wang algorithm that occurred only with a very + # specific set of random numbers. + random_sequence = [-1, -1, 0, 0, 0, 0, -1, -1, 0, -1, 3, -1, 3, 3, 3, 3, -1, 3] + + R, x, y, z = ring("x,y,z", ZZ) + f = 2*x**2 + y*z - y - z**2 + z + + assert R.dmp_zz_wang(f, seed=random_sequence) == [f] + + +def test_dmp_zz_factor(): + R, x = ring("x", ZZ) + assert R.dmp_zz_factor(0) == (0, []) + assert R.dmp_zz_factor(7) == (7, []) + assert R.dmp_zz_factor(-7) == (-7, []) + + assert R.dmp_zz_factor(x**2 - 9) == (1, [(x - 3, 1), (x + 3, 1)]) + + R, x, y = ring("x,y", ZZ) + assert R.dmp_zz_factor(0) == (0, []) + assert R.dmp_zz_factor(7) == (7, []) + assert R.dmp_zz_factor(-7) == (-7, []) + + assert R.dmp_zz_factor(x) == (1, [(x, 1)]) + assert R.dmp_zz_factor(4*x) == (4, [(x, 1)]) + assert R.dmp_zz_factor(4*x + 2) == (2, [(2*x + 1, 1)]) + assert R.dmp_zz_factor(x*y + 1) == (1, [(x*y + 1, 1)]) + assert R.dmp_zz_factor(y**2 + 1) == (1, [(y**2 + 1, 1)]) + assert R.dmp_zz_factor(y**2 - 1) == (1, [(y - 1, 1), (y + 1, 1)]) + + assert R.dmp_zz_factor(x**2*y**2 + 6*x**2*y + 9*x**2 - 1) == (1, [(x*y + 3*x - 1, 1), (x*y + 3*x + 1, 1)]) + assert R.dmp_zz_factor(x**2*y**2 - 9) == (1, [(x*y - 3, 1), (x*y + 3, 1)]) + + R, x, y, z = ring("x,y,z", ZZ) + assert R.dmp_zz_factor(x**2*y**2*z**2 - 9) == \ + (1, [(x*y*z - 3, 1), + (x*y*z + 3, 1)]) + + R, x, y, z, u = ring("x,y,z,u", ZZ) + assert R.dmp_zz_factor(x**2*y**2*z**2*u**2 - 9) == \ + (1, [(x*y*z*u - 3, 1), + (x*y*z*u + 3, 1)]) + + R, x, y, z = ring("x,y,z", ZZ) + assert R.dmp_zz_factor(f_1) == \ + (1, [(x + y*z + 20, 1), + (x*y + z + 10, 1), + (x*z + y + 30, 1)]) + + assert R.dmp_zz_factor(f_2) == \ + (1, [(x**2*y**2 + x**2*z**2 + y + 90, 1), + (x**3*y + x**3*z + z - 11, 1)]) + + assert R.dmp_zz_factor(f_3) == \ + (1, [(x**2*y**2 + x*z**4 + x + z, 1), + (x**3 + x*y*z + y**2 + y*z**3, 1)]) + + assert R.dmp_zz_factor(f_4) == \ + (-1, [(x*y**3 + z**2, 1), + (x**2*z + y**4*z**2 + 5, 1), + (x**3*y - z**2 - 3, 1), + (x**3*y**4 + z**2, 1)]) + + assert R.dmp_zz_factor(f_5) == \ + (-1, [(x + y - z, 3)]) + + R, x, y, z, t = ring("x,y,z,t", ZZ) + assert R.dmp_zz_factor(f_6) == \ + (1, [(47*x*y + z**3*t**2 - t**2, 1), + (45*x**3 - 9*y**3 - y**2 + 3*z**3 + 2*z*t, 1)]) + + R, x, y, z = ring("x,y,z", ZZ) + assert R.dmp_zz_factor(w_1) == \ + (1, [(x**2*y**2 - x**2*z**2 + y - z**2, 1), + (x**2*y*z**2 + 3*x*z + 2*y, 1), + (4*x**2*y + 4*x**2*z + x*y*z - 1, 1)]) + + R, x, y = ring("x,y", ZZ) + f = -12*x**16*y + 240*x**12*y**3 - 768*x**10*y**4 + 1080*x**8*y**5 - 768*x**6*y**6 + 240*x**4*y**7 - 12*y**9 + + assert R.dmp_zz_factor(f) == \ + (-12, [(y, 1), + (x**2 - y, 6), + (x**4 + 6*x**2*y + y**2, 1)]) + + +def test_dup_qq_i_factor(): + R, x = ring("x", QQ_I) + i = QQ_I(0, 1) + + assert R.dup_qq_i_factor(x**2 - 2) == (QQ_I(1, 0), [(x**2 - 2, 1)]) + + assert R.dup_qq_i_factor(x**2 - 1) == (QQ_I(1, 0), [(x - 1, 1), (x + 1, 1)]) + + assert R.dup_qq_i_factor(x**2 + 1) == (QQ_I(1, 0), [(x - i, 1), (x + i, 1)]) + + assert R.dup_qq_i_factor(x**2/4 + 1) == \ + (QQ_I(QQ(1, 4), 0), [(x - 2*i, 1), (x + 2*i, 1)]) + + assert R.dup_qq_i_factor(x**2 + 4) == \ + (QQ_I(1, 0), [(x - 2*i, 1), (x + 2*i, 1)]) + + assert R.dup_qq_i_factor(x**2 + 2*x + 1) == \ + (QQ_I(1, 0), [(x + 1, 2)]) + + assert R.dup_qq_i_factor(x**2 + 2*i*x - 1) == \ + (QQ_I(1, 0), [(x + i, 2)]) + + f = 8192*x**2 + x*(22656 + 175232*i) - 921416 + 242313*i + + assert R.dup_qq_i_factor(f) == \ + (QQ_I(8192, 0), [(x + QQ_I(QQ(177, 128), QQ(1369, 128)), 2)]) + + +def test_dmp_qq_i_factor(): + R, x, y = ring("x, y", QQ_I) + i = QQ_I(0, 1) + + assert R.dmp_qq_i_factor(x**2 + 2*y**2) == \ + (QQ_I(1, 0), [(x**2 + 2*y**2, 1)]) + + assert R.dmp_qq_i_factor(x**2 + y**2) == \ + (QQ_I(1, 0), [(x - i*y, 1), (x + i*y, 1)]) + + assert R.dmp_qq_i_factor(x**2 + y**2/4) == \ + (QQ_I(1, 0), [(x - i*y/2, 1), (x + i*y/2, 1)]) + + assert R.dmp_qq_i_factor(4*x**2 + y**2) == \ + (QQ_I(4, 0), [(x - i*y/2, 1), (x + i*y/2, 1)]) + + +def test_dup_zz_i_factor(): + R, x = ring("x", ZZ_I) + i = ZZ_I(0, 1) + + assert R.dup_zz_i_factor(x**2 - 2) == (ZZ_I(1, 0), [(x**2 - 2, 1)]) + + assert R.dup_zz_i_factor(x**2 - 1) == (ZZ_I(1, 0), [(x - 1, 1), (x + 1, 1)]) + + assert R.dup_zz_i_factor(x**2 + 1) == (ZZ_I(1, 0), [(x - i, 1), (x + i, 1)]) + + assert R.dup_zz_i_factor(x**2 + 4) == \ + (ZZ_I(1, 0), [(x - 2*i, 1), (x + 2*i, 1)]) + + assert R.dup_zz_i_factor(x**2 + 2*x + 1) == \ + (ZZ_I(1, 0), [(x + 1, 2)]) + + assert R.dup_zz_i_factor(x**2 + 2*i*x - 1) == \ + (ZZ_I(1, 0), [(x + i, 2)]) + + f = 8192*x**2 + x*(22656 + 175232*i) - 921416 + 242313*i + + assert R.dup_zz_i_factor(f) == \ + (ZZ_I(0, 1), [((64 - 64*i)*x + (773 + 596*i), 2)]) + + +def test_dmp_zz_i_factor(): + R, x, y = ring("x, y", ZZ_I) + i = ZZ_I(0, 1) + + assert R.dmp_zz_i_factor(x**2 + 2*y**2) == \ + (ZZ_I(1, 0), [(x**2 + 2*y**2, 1)]) + + assert R.dmp_zz_i_factor(x**2 + y**2) == \ + (ZZ_I(1, 0), [(x - i*y, 1), (x + i*y, 1)]) + + assert R.dmp_zz_i_factor(4*x**2 + y**2) == \ + (ZZ_I(1, 0), [(2*x - i*y, 1), (2*x + i*y, 1)]) + + +def test_dup_ext_factor(): + R, x = ring("x", QQ.algebraic_field(I)) + def anp(element): + return ANP(element, [QQ(1), QQ(0), QQ(1)], QQ) + + assert R.dup_ext_factor(0) == (anp([]), []) + + f = anp([QQ(1)])*x + anp([QQ(1)]) + + assert R.dup_ext_factor(f) == (anp([QQ(1)]), [(f, 1)]) + + g = anp([QQ(2)])*x + anp([QQ(2)]) + + assert R.dup_ext_factor(g) == (anp([QQ(2)]), [(f, 1)]) + + f = anp([QQ(7)])*x**4 + anp([QQ(1, 1)]) + g = anp([QQ(1)])*x**4 + anp([QQ(1, 7)]) + + assert R.dup_ext_factor(f) == (anp([QQ(7)]), [(g, 1)]) + + f = anp([QQ(1)])*x**4 + anp([QQ(1)]) + + assert R.dup_ext_factor(f) == \ + (anp([QQ(1, 1)]), [(anp([QQ(1)])*x**2 + anp([QQ(-1), QQ(0)]), 1), + (anp([QQ(1)])*x**2 + anp([QQ( 1), QQ(0)]), 1)]) + + f = anp([QQ(4, 1)])*x**2 + anp([QQ(9, 1)]) + + assert R.dup_ext_factor(f) == \ + (anp([QQ(4, 1)]), [(anp([QQ(1, 1)])*x + anp([-QQ(3, 2), QQ(0, 1)]), 1), + (anp([QQ(1, 1)])*x + anp([ QQ(3, 2), QQ(0, 1)]), 1)]) + + f = anp([QQ(4, 1)])*x**4 + anp([QQ(8, 1)])*x**3 + anp([QQ(77, 1)])*x**2 + anp([QQ(18, 1)])*x + anp([QQ(153, 1)]) + + assert R.dup_ext_factor(f) == \ + (anp([QQ(4, 1)]), [(anp([QQ(1, 1)])*x + anp([-QQ(4, 1), QQ(1, 1)]), 1), + (anp([QQ(1, 1)])*x + anp([-QQ(3, 2), QQ(0, 1)]), 1), + (anp([QQ(1, 1)])*x + anp([ QQ(3, 2), QQ(0, 1)]), 1), + (anp([QQ(1, 1)])*x + anp([ QQ(4, 1), QQ(1, 1)]), 1)]) + + R, x = ring("x", QQ.algebraic_field(sqrt(2))) + def anp(element): + return ANP(element, [QQ(1), QQ(0), QQ(-2)], QQ) + + f = anp([QQ(1)])*x**4 + anp([QQ(1, 1)]) + + assert R.dup_ext_factor(f) == \ + (anp([QQ(1)]), [(anp([QQ(1)])*x**2 + anp([QQ(-1), QQ(0)])*x + anp([QQ(1)]), 1), + (anp([QQ(1)])*x**2 + anp([QQ( 1), QQ(0)])*x + anp([QQ(1)]), 1)]) + + f = anp([QQ(1, 1)])*x**2 + anp([QQ(2), QQ(0)])*x + anp([QQ(2, 1)]) + + assert R.dup_ext_factor(f) == \ + (anp([QQ(1, 1)]), [(anp([1])*x + anp([1, 0]), 2)]) + + assert R.dup_ext_factor(f**3) == \ + (anp([QQ(1, 1)]), [(anp([1])*x + anp([1, 0]), 6)]) + + f *= anp([QQ(2, 1)]) + + assert R.dup_ext_factor(f) == \ + (anp([QQ(2, 1)]), [(anp([1])*x + anp([1, 0]), 2)]) + + assert R.dup_ext_factor(f**3) == \ + (anp([QQ(8, 1)]), [(anp([1])*x + anp([1, 0]), 6)]) + + +def test_dmp_ext_factor(): + K = QQ.algebraic_field(sqrt(2)) + R, x,y = ring("x,y", K) + sqrt2 = K.unit + + def anp(x): + return ANP(x, [QQ(1), QQ(0), QQ(-2)], QQ) + + assert R.dmp_ext_factor(0) == (anp([]), []) + + f = anp([QQ(1)])*x + anp([QQ(1)]) + + assert R.dmp_ext_factor(f) == (anp([QQ(1)]), [(f, 1)]) + + g = anp([QQ(2)])*x + anp([QQ(2)]) + + assert R.dmp_ext_factor(g) == (anp([QQ(2)]), [(f, 1)]) + + f = anp([QQ(1)])*x**2 + anp([QQ(-2)])*y**2 + + assert R.dmp_ext_factor(f) == \ + (anp([QQ(1)]), [(anp([QQ(1)])*x + anp([QQ(-1), QQ(0)])*y, 1), + (anp([QQ(1)])*x + anp([QQ( 1), QQ(0)])*y, 1)]) + + f = anp([QQ(2)])*x**2 + anp([QQ(-4)])*y**2 + + assert R.dmp_ext_factor(f) == \ + (anp([QQ(2)]), [(anp([QQ(1)])*x + anp([QQ(-1), QQ(0)])*y, 1), + (anp([QQ(1)])*x + anp([QQ( 1), QQ(0)])*y, 1)]) + + f1 = y + 1 + f2 = y + sqrt2 + f3 = x**2 + x + 2 + 3*sqrt2 + f = f1**2 * f2**2 * f3**2 + assert R.dmp_ext_factor(f) == (K.one, [(f1, 2), (f2, 2), (f3, 2)]) + + +def test_dup_factor_list(): + R, x = ring("x", ZZ) + assert R.dup_factor_list(0) == (0, []) + assert R.dup_factor_list(7) == (7, []) + + R, x = ring("x", QQ) + assert R.dup_factor_list(0) == (0, []) + assert R.dup_factor_list(QQ(1, 7)) == (QQ(1, 7), []) + + R, x = ring("x", ZZ['t']) + assert R.dup_factor_list(0) == (0, []) + assert R.dup_factor_list(7) == (7, []) + + R, x = ring("x", QQ['t']) + assert R.dup_factor_list(0) == (0, []) + assert R.dup_factor_list(QQ(1, 7)) == (QQ(1, 7), []) + + R, x = ring("x", ZZ) + assert R.dup_factor_list_include(0) == [(0, 1)] + assert R.dup_factor_list_include(7) == [(7, 1)] + + assert R.dup_factor_list(x**2 + 2*x + 1) == (1, [(x + 1, 2)]) + assert R.dup_factor_list_include(x**2 + 2*x + 1) == [(x + 1, 2)] + # issue 8037 + assert R.dup_factor_list(6*x**2 - 5*x - 6) == (1, [(2*x - 3, 1), (3*x + 2, 1)]) + + R, x = ring("x", QQ) + assert R.dup_factor_list(QQ(1,2)*x**2 + x + QQ(1,2)) == (QQ(1, 2), [(x + 1, 2)]) + + R, x = ring("x", FF(2)) + assert R.dup_factor_list(x**2 + 1) == (1, [(x + 1, 2)]) + + R, x = ring("x", RR) + assert R.dup_factor_list(1.0*x**2 + 2.0*x + 1.0) == (1.0, [(1.0*x + 1.0, 2)]) + assert R.dup_factor_list(2.0*x**2 + 4.0*x + 2.0) == (2.0, [(1.0*x + 1.0, 2)]) + + f = 6.7225336055071*x**2 - 10.6463972754741*x - 0.33469524022264 + coeff, factors = R.dup_factor_list(f) + assert coeff == RR(10.6463972754741) + assert len(factors) == 1 + assert factors[0][0].max_norm() == RR(1.0) + assert factors[0][1] == 1 + + Rt, t = ring("t", ZZ) + R, x = ring("x", Rt) + + f = 4*t*x**2 + 4*t**2*x + + assert R.dup_factor_list(f) == \ + (4*t, [(x, 1), + (x + t, 1)]) + + Rt, t = ring("t", QQ) + R, x = ring("x", Rt) + + f = QQ(1, 2)*t*x**2 + QQ(1, 2)*t**2*x + + assert R.dup_factor_list(f) == \ + (QQ(1, 2)*t, [(x, 1), + (x + t, 1)]) + + R, x = ring("x", QQ.algebraic_field(I)) + def anp(element): + return ANP(element, [QQ(1), QQ(0), QQ(1)], QQ) + + f = anp([QQ(1, 1)])*x**4 + anp([QQ(2, 1)])*x**2 + + assert R.dup_factor_list(f) == \ + (anp([QQ(1, 1)]), [(anp([QQ(1, 1)])*x, 2), + (anp([QQ(1, 1)])*x**2 + anp([])*x + anp([QQ(2, 1)]), 1)]) + + R, x = ring("x", EX) + raises(DomainError, lambda: R.dup_factor_list(EX(sin(1)))) + + +def test_dmp_factor_list(): + R, x, y = ring("x,y", ZZ) + assert R.dmp_factor_list(0) == (ZZ(0), []) + assert R.dmp_factor_list(7) == (7, []) + + R, x, y = ring("x,y", QQ) + assert R.dmp_factor_list(0) == (QQ(0), []) + assert R.dmp_factor_list(QQ(1, 7)) == (QQ(1, 7), []) + + Rt, t = ring("t", ZZ) + R, x, y = ring("x,y", Rt) + assert R.dmp_factor_list(0) == (0, []) + assert R.dmp_factor_list(7) == (ZZ(7), []) + + Rt, t = ring("t", QQ) + R, x, y = ring("x,y", Rt) + assert R.dmp_factor_list(0) == (0, []) + assert R.dmp_factor_list(QQ(1, 7)) == (QQ(1, 7), []) + + R, x, y = ring("x,y", ZZ) + assert R.dmp_factor_list_include(0) == [(0, 1)] + assert R.dmp_factor_list_include(7) == [(7, 1)] + + R, X = xring("x:200", ZZ) + + f, g = X[0]**2 + 2*X[0] + 1, X[0] + 1 + assert R.dmp_factor_list(f) == (1, [(g, 2)]) + + f, g = X[-1]**2 + 2*X[-1] + 1, X[-1] + 1 + assert R.dmp_factor_list(f) == (1, [(g, 2)]) + + R, x = ring("x", ZZ) + assert R.dmp_factor_list(x**2 + 2*x + 1) == (1, [(x + 1, 2)]) + R, x = ring("x", QQ) + assert R.dmp_factor_list(QQ(1,2)*x**2 + x + QQ(1,2)) == (QQ(1,2), [(x + 1, 2)]) + + R, x, y = ring("x,y", ZZ) + assert R.dmp_factor_list(x**2 + 2*x + 1) == (1, [(x + 1, 2)]) + R, x, y = ring("x,y", QQ) + assert R.dmp_factor_list(QQ(1,2)*x**2 + x + QQ(1,2)) == (QQ(1,2), [(x + 1, 2)]) + + R, x, y = ring("x,y", ZZ) + f = 4*x**2*y + 4*x*y**2 + + assert R.dmp_factor_list(f) == \ + (4, [(y, 1), + (x, 1), + (x + y, 1)]) + + assert R.dmp_factor_list_include(f) == \ + [(4*y, 1), + (x, 1), + (x + y, 1)] + + R, x, y = ring("x,y", QQ) + f = QQ(1,2)*x**2*y + QQ(1,2)*x*y**2 + + assert R.dmp_factor_list(f) == \ + (QQ(1,2), [(y, 1), + (x, 1), + (x + y, 1)]) + + R, x, y = ring("x,y", RR) + f = 2.0*x**2 - 8.0*y**2 + + assert R.dmp_factor_list(f) == \ + (RR(8.0), [(0.5*x - y, 1), + (0.5*x + y, 1)]) + + f = 6.7225336055071*x**2*y**2 - 10.6463972754741*x*y - 0.33469524022264 + coeff, factors = R.dmp_factor_list(f) + assert coeff == RR(10.6463972754741) + assert len(factors) == 1 + assert factors[0][0].max_norm() == RR(1.0) + assert factors[0][1] == 1 + + Rt, t = ring("t", ZZ) + R, x, y = ring("x,y", Rt) + f = 4*t*x**2 + 4*t**2*x + + assert R.dmp_factor_list(f) == \ + (4*t, [(x, 1), + (x + t, 1)]) + + Rt, t = ring("t", QQ) + R, x, y = ring("x,y", Rt) + f = QQ(1, 2)*t*x**2 + QQ(1, 2)*t**2*x + + assert R.dmp_factor_list(f) == \ + (QQ(1, 2)*t, [(x, 1), + (x + t, 1)]) + + R, x, y = ring("x,y", FF(2)) + raises(NotImplementedError, lambda: R.dmp_factor_list(x**2 + y**2)) + + R, x, y = ring("x,y", EX) + raises(DomainError, lambda: R.dmp_factor_list(EX(sin(1)))) + + +def test_dup_irreducible_p(): + R, x = ring("x", ZZ) + assert R.dup_irreducible_p(x**2 + x + 1) is True + assert R.dup_irreducible_p(x**2 + 2*x + 1) is False + + +def test_dmp_irreducible_p(): + R, x, y = ring("x,y", ZZ) + assert R.dmp_irreducible_p(x**2 + x + 1) is True + assert R.dmp_irreducible_p(x**2 + 2*x + 1) is False diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_fields.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..4f85a00d75dc02ab794ff94c83ba18ddc2023313 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_fields.py @@ -0,0 +1,353 @@ +"""Test sparse rational functions. """ + +from sympy.polys.fields import field, sfield, FracField, FracElement +from sympy.polys.rings import ring +from sympy.polys.domains import ZZ, QQ +from sympy.polys.orderings import lex + +from sympy.testing.pytest import raises, XFAIL +from sympy.core import symbols, E +from sympy.core.numbers import Rational +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt + +def test_FracField___init__(): + F1 = FracField("x,y", ZZ, lex) + F2 = FracField("x,y", ZZ, lex) + F3 = FracField("x,y,z", ZZ, lex) + + assert F1.x == F1.gens[0] + assert F1.y == F1.gens[1] + assert F1.x == F2.x + assert F1.y == F2.y + assert F1.x != F3.x + assert F1.y != F3.y + +def test_FracField___hash__(): + F, x, y, z = field("x,y,z", QQ) + assert hash(F) + +def test_FracField___eq__(): + assert field("x,y,z", QQ)[0] == field("x,y,z", QQ)[0] + assert field("x,y,z", QQ)[0] != field("x,y,z", ZZ)[0] + assert field("x,y,z", ZZ)[0] != field("x,y,z", QQ)[0] + assert field("x,y,z", QQ)[0] != field("x,y", QQ)[0] + assert field("x,y", QQ)[0] != field("x,y,z", QQ)[0] + +def test_sfield(): + x = symbols("x") + + F = FracField((E, exp(exp(x)), exp(x)), ZZ, lex) + e, exex, ex = F.gens + assert sfield(exp(x)*exp(exp(x) + 1 + log(exp(x) + 3)/2)**2/(exp(x) + 3)) \ + == (F, e**2*exex**2*ex) + + F = FracField((x, exp(1/x), log(x), x**QQ(1, 3)), ZZ, lex) + _, ex, lg, x3 = F.gens + assert sfield(((x-3)*log(x)+4*x**2)*exp(1/x+log(x)/3)/x**2) == \ + (F, (4*F.x**2*ex + F.x*ex*lg - 3*ex*lg)/x3**5) + + F = FracField((x, log(x), sqrt(x + log(x))), ZZ, lex) + _, lg, srt = F.gens + assert sfield((x + 1) / (x * (x + log(x))**QQ(3, 2)) - 1/(x * log(x)**2)) \ + == (F, (F.x*lg**2 - F.x*srt + lg**2 - lg*srt)/ + (F.x**2*lg**2*srt + F.x*lg**3*srt)) + +def test_FracElement___hash__(): + F, x, y, z = field("x,y,z", QQ) + assert hash(x*y/z) + +def test_FracElement_copy(): + F, x, y, z = field("x,y,z", ZZ) + + f = x*y/3*z + g = f.copy() + + assert f == g + g.numer[(1, 1, 1)] = 7 + assert f != g + +def test_FracElement_as_expr(): + F, x, y, z = field("x,y,z", ZZ) + f = (3*x**2*y - x*y*z)/(7*z**3 + 1) + + X, Y, Z = F.symbols + g = (3*X**2*Y - X*Y*Z)/(7*Z**3 + 1) + + assert f != g + assert f.as_expr() == g + + X, Y, Z = symbols("x,y,z") + g = (3*X**2*Y - X*Y*Z)/(7*Z**3 + 1) + + assert f != g + assert f.as_expr(X, Y, Z) == g + + raises(ValueError, lambda: f.as_expr(X)) + +def test_FracElement_from_expr(): + x, y, z = symbols("x,y,z") + F, X, Y, Z = field((x, y, z), ZZ) + + f = F.from_expr(1) + assert f == 1 and F.is_element(f) + + f = F.from_expr(Rational(3, 7)) + assert f == F(3)/7 and F.is_element(f) + + f = F.from_expr(x) + assert f == X and F.is_element(f) + + f = F.from_expr(Rational(3,7)*x) + assert f == X*Rational(3, 7) and F.is_element(f) + + f = F.from_expr(1/x) + assert f == 1/X and F.is_element(f) + + f = F.from_expr(x*y*z) + assert f == X*Y*Z and F.is_element(f) + + f = F.from_expr(x*y/z) + assert f == X*Y/Z and F.is_element(f) + + f = F.from_expr(x*y*z + x*y + x) + assert f == X*Y*Z + X*Y + X and F.is_element(f) + + f = F.from_expr((x*y*z + x*y + x)/(x*y + 7)) + assert f == (X*Y*Z + X*Y + X)/(X*Y + 7) and F.is_element(f) + + f = F.from_expr(x**3*y*z + x**2*y**7 + 1) + assert f == X**3*Y*Z + X**2*Y**7 + 1 and F.is_element(f) + + raises(ValueError, lambda: F.from_expr(2**x)) + raises(ValueError, lambda: F.from_expr(7*x + sqrt(2))) + + assert isinstance(ZZ[2**x].get_field().convert(2**(-x)), + FracElement) + assert isinstance(ZZ[x**2].get_field().convert(x**(-6)), + FracElement) + assert isinstance(ZZ[exp(Rational(1, 3))].get_field().convert(E), + FracElement) + + +def test_FracField_nested(): + a, b, x = symbols('a b x') + F1 = ZZ.frac_field(a, b) + F2 = F1.frac_field(x) + frac = F2(a + b) + assert frac.numer == F1.poly_ring(x)(a + b) + assert frac.numer.coeffs() == [F1(a + b)] + assert frac.denom == F1.poly_ring(x)(1) + + F3 = ZZ.poly_ring(a, b) + F4 = F3.frac_field(x) + frac = F4(a + b) + assert frac.numer == F3.poly_ring(x)(a + b) + assert frac.numer.coeffs() == [F3(a + b)] + assert frac.denom == F3.poly_ring(x)(1) + + frac = F2(F3(a + b)) + assert frac.numer == F1.poly_ring(x)(a + b) + assert frac.numer.coeffs() == [F1(a + b)] + assert frac.denom == F1.poly_ring(x)(1) + + frac = F4(F1(a + b)) + assert frac.numer == F3.poly_ring(x)(a + b) + assert frac.numer.coeffs() == [F3(a + b)] + assert frac.denom == F3.poly_ring(x)(1) + + +def test_FracElement__lt_le_gt_ge__(): + F, x, y = field("x,y", ZZ) + + assert F(1) < 1/x < 1/x**2 < 1/x**3 + assert F(1) <= 1/x <= 1/x**2 <= 1/x**3 + + assert -7/x < 1/x < 3/x < y/x < 1/x**2 + assert -7/x <= 1/x <= 3/x <= y/x <= 1/x**2 + + assert 1/x**3 > 1/x**2 > 1/x > F(1) + assert 1/x**3 >= 1/x**2 >= 1/x >= F(1) + + assert 1/x**2 > y/x > 3/x > 1/x > -7/x + assert 1/x**2 >= y/x >= 3/x >= 1/x >= -7/x + +def test_FracElement___neg__(): + F, x,y = field("x,y", QQ) + + f = (7*x - 9)/y + g = (-7*x + 9)/y + + assert -f == g + assert -g == f + +def test_FracElement___add__(): + F, x,y = field("x,y", QQ) + + f, g = 1/x, 1/y + assert f + g == g + f == (x + y)/(x*y) + + assert x + F.ring.gens[0] == F.ring.gens[0] + x == 2*x + + F, x,y = field("x,y", ZZ) + assert x + 3 == 3 + x + assert x + QQ(3,7) == QQ(3,7) + x == (7*x + 3)/7 + + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + + f = (u*v + x)/(y + u*v) + assert dict(f.numer) == {(1, 0, 0, 0): 1, (0, 0, 0, 0): u*v} + assert dict(f.denom) == {(0, 1, 0, 0): 1, (0, 0, 0, 0): u*v} + + Ruv, u,v = ring("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Ruv) + + f = (u*v + x)/(y + u*v) + assert dict(f.numer) == {(1, 0, 0, 0): 1, (0, 0, 0, 0): u*v} + assert dict(f.denom) == {(0, 1, 0, 0): 1, (0, 0, 0, 0): u*v} + +def test_FracElement___sub__(): + F, x,y = field("x,y", QQ) + + f, g = 1/x, 1/y + assert f - g == (-x + y)/(x*y) + + assert x - F.ring.gens[0] == F.ring.gens[0] - x == 0 + + F, x,y = field("x,y", ZZ) + assert x - 3 == -(3 - x) + assert x - QQ(3,7) == -(QQ(3,7) - x) == (7*x - 3)/7 + + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + + f = (u*v - x)/(y - u*v) + assert dict(f.numer) == {(1, 0, 0, 0):-1, (0, 0, 0, 0): u*v} + assert dict(f.denom) == {(0, 1, 0, 0): 1, (0, 0, 0, 0):-u*v} + + Ruv, u,v = ring("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Ruv) + + f = (u*v - x)/(y - u*v) + assert dict(f.numer) == {(1, 0, 0, 0):-1, (0, 0, 0, 0): u*v} + assert dict(f.denom) == {(0, 1, 0, 0): 1, (0, 0, 0, 0):-u*v} + +def test_FracElement___mul__(): + F, x,y = field("x,y", QQ) + + f, g = 1/x, 1/y + assert f*g == g*f == 1/(x*y) + + assert x*F.ring.gens[0] == F.ring.gens[0]*x == x**2 + + F, x,y = field("x,y", ZZ) + assert x*3 == 3*x + assert x*QQ(3,7) == QQ(3,7)*x == x*Rational(3, 7) + + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + + f = ((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1) + assert dict(f.numer) == {(1, 1, 0, 0): u + 1, (0, 0, 0, 0): 1} + assert dict(f.denom) == {(0, 0, 1, 0): v - 1, (0, 0, 0, 1): -u*v, (0, 0, 0, 0): -1} + + Ruv, u,v = ring("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Ruv) + + f = ((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1) + assert dict(f.numer) == {(1, 1, 0, 0): u + 1, (0, 0, 0, 0): 1} + assert dict(f.denom) == {(0, 0, 1, 0): v - 1, (0, 0, 0, 1): -u*v, (0, 0, 0, 0): -1} + +def test_FracElement___truediv__(): + F, x,y = field("x,y", QQ) + + f, g = 1/x, 1/y + assert f/g == y/x + + assert x/F.ring.gens[0] == F.ring.gens[0]/x == 1 + + F, x,y = field("x,y", ZZ) + assert x*3 == 3*x + assert x/QQ(3,7) == (QQ(3,7)/x)**-1 == x*Rational(7, 3) + + raises(ZeroDivisionError, lambda: x/0) + raises(ZeroDivisionError, lambda: 1/(x - x)) + raises(ZeroDivisionError, lambda: x/(x - x)) + + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + + f = (u*v)/(x*y) + assert dict(f.numer) == {(0, 0, 0, 0): u*v} + assert dict(f.denom) == {(1, 1, 0, 0): 1} + + g = (x*y)/(u*v) + assert dict(g.numer) == {(1, 1, 0, 0): 1} + assert dict(g.denom) == {(0, 0, 0, 0): u*v} + + Ruv, u,v = ring("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Ruv) + + f = (u*v)/(x*y) + assert dict(f.numer) == {(0, 0, 0, 0): u*v} + assert dict(f.denom) == {(1, 1, 0, 0): 1} + + g = (x*y)/(u*v) + assert dict(g.numer) == {(1, 1, 0, 0): 1} + assert dict(g.denom) == {(0, 0, 0, 0): u*v} + +def test_FracElement___pow__(): + F, x,y = field("x,y", QQ) + + f, g = 1/x, 1/y + + assert f**3 == 1/x**3 + assert g**3 == 1/y**3 + + assert (f*g)**3 == 1/(x**3*y**3) + assert (f*g)**-3 == (x*y)**3 + + raises(ZeroDivisionError, lambda: (x - x)**-3) + +def test_FracElement_diff(): + F, x,y,z = field("x,y,z", ZZ) + + assert ((x**2 + y)/(z + 1)).diff(x) == 2*x/(z + 1) + +@XFAIL +def test_FracElement___call__(): + F, x,y,z = field("x,y,z", ZZ) + f = (x**2 + 3*y)/z + + r = f(1, 1, 1) + assert r == 4 and not isinstance(r, FracElement) + raises(ZeroDivisionError, lambda: f(1, 1, 0)) + +def test_FracElement_evaluate(): + F, x,y,z = field("x,y,z", ZZ) + Fyz = field("y,z", ZZ)[0] + f = (x**2 + 3*y)/z + + assert f.evaluate(x, 0) == 3*Fyz.y/Fyz.z + raises(ZeroDivisionError, lambda: f.evaluate(z, 0)) + +def test_FracElement_subs(): + F, x,y,z = field("x,y,z", ZZ) + f = (x**2 + 3*y)/z + + assert f.subs(x, 0) == 3*y/z + raises(ZeroDivisionError, lambda: f.subs(z, 0)) + +def test_FracElement_compose(): + pass + +def test_FracField_index(): + a = symbols("a") + F, x, y, z = field('x y z', QQ) + assert F.index(x) == 0 + assert F.index(y) == 1 + + raises(ValueError, lambda: F.index(1)) + raises(ValueError, lambda: F.index(a)) + pass diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_galoistools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_galoistools.py new file mode 100644 index 0000000000000000000000000000000000000000..e512bdd865c300bb138cb40b4ff78f393b323c22 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_galoistools.py @@ -0,0 +1,875 @@ +from sympy.polys.galoistools import ( + gf_crt, gf_crt1, gf_crt2, gf_int, + gf_degree, gf_strip, gf_trunc, gf_normal, + gf_from_dict, gf_to_dict, + gf_from_int_poly, gf_to_int_poly, + gf_neg, gf_add_ground, gf_sub_ground, gf_mul_ground, + gf_add, gf_sub, gf_add_mul, gf_sub_mul, gf_mul, gf_sqr, + gf_div, gf_rem, gf_quo, gf_exquo, + gf_lshift, gf_rshift, gf_expand, + gf_pow, gf_pow_mod, + gf_gcdex, gf_gcd, gf_lcm, gf_cofactors, + gf_LC, gf_TC, gf_monic, + gf_eval, gf_multi_eval, + gf_compose, gf_compose_mod, + gf_trace_map, + gf_diff, + gf_irreducible, gf_irreducible_p, + gf_irred_p_ben_or, gf_irred_p_rabin, + gf_sqf_list, gf_sqf_part, gf_sqf_p, + gf_Qmatrix, gf_Qbasis, + gf_ddf_zassenhaus, gf_ddf_shoup, + gf_edf_zassenhaus, gf_edf_shoup, + gf_berlekamp, + gf_factor_sqf, gf_factor, + gf_value, linear_congruence, _csolve_prime_las_vegas, + csolve_prime, gf_csolve, gf_frobenius_map, gf_frobenius_monomial_base +) + +from sympy.polys.polyerrors import ( + ExactQuotientFailed, +) + +from sympy.polys import polyconfig as config + +from sympy.polys.domains import ZZ +from sympy.core.numbers import pi +from sympy.ntheory.generate import nextprime +from sympy.testing.pytest import raises + + +def test_gf_crt(): + U = [49, 76, 65] + M = [99, 97, 95] + + p = 912285 + u = 639985 + + assert gf_crt(U, M, ZZ) == u + + E = [9215, 9405, 9603] + S = [62, 24, 12] + + assert gf_crt1(M, ZZ) == (p, E, S) + assert gf_crt2(U, M, p, E, S, ZZ) == u + + +def test_gf_int(): + assert gf_int(0, 5) == 0 + assert gf_int(1, 5) == 1 + assert gf_int(2, 5) == 2 + assert gf_int(3, 5) == -2 + assert gf_int(4, 5) == -1 + assert gf_int(5, 5) == 0 + + +def test_gf_degree(): + assert gf_degree([]) == -1 + assert gf_degree([1]) == 0 + assert gf_degree([1, 0]) == 1 + assert gf_degree([1, 0, 0, 0, 1]) == 4 + + +def test_gf_strip(): + assert gf_strip([]) == [] + assert gf_strip([0]) == [] + assert gf_strip([0, 0, 0]) == [] + + assert gf_strip([1]) == [1] + assert gf_strip([0, 1]) == [1] + assert gf_strip([0, 0, 0, 1]) == [1] + + assert gf_strip([1, 2, 0]) == [1, 2, 0] + assert gf_strip([0, 1, 2, 0]) == [1, 2, 0] + assert gf_strip([0, 0, 0, 1, 2, 0]) == [1, 2, 0] + + +def test_gf_trunc(): + assert gf_trunc([], 11) == [] + assert gf_trunc([1], 11) == [1] + assert gf_trunc([22], 11) == [] + assert gf_trunc([12], 11) == [1] + + assert gf_trunc([11, 22, 17, 1, 0], 11) == [6, 1, 0] + assert gf_trunc([12, 23, 17, 1, 0], 11) == [1, 1, 6, 1, 0] + + +def test_gf_normal(): + assert gf_normal([11, 22, 17, 1, 0], 11, ZZ) == [6, 1, 0] + + +def test_gf_from_to_dict(): + f = {11: 12, 6: 2, 0: 25} + F = {11: 1, 6: 2, 0: 3} + g = [1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3] + + assert gf_from_dict(f, 11, ZZ) == g + assert gf_to_dict(g, 11) == F + + f = {11: -5, 4: 0, 3: 1, 0: 12} + F = {11: -5, 3: 1, 0: 1} + g = [6, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1] + + assert gf_from_dict(f, 11, ZZ) == g + assert gf_to_dict(g, 11) == F + + assert gf_to_dict([10], 11, symmetric=True) == {0: -1} + assert gf_to_dict([10], 11, symmetric=False) == {0: 10} + + +def test_gf_from_to_int_poly(): + assert gf_from_int_poly([1, 0, 7, 2, 20], 5) == [1, 0, 2, 2, 0] + assert gf_to_int_poly([1, 0, 4, 2, 3], 5) == [1, 0, -1, 2, -2] + + assert gf_to_int_poly([10], 11, symmetric=True) == [-1] + assert gf_to_int_poly([10], 11, symmetric=False) == [10] + + +def test_gf_LC(): + assert gf_LC([], ZZ) == 0 + assert gf_LC([1], ZZ) == 1 + assert gf_LC([1, 2], ZZ) == 1 + + +def test_gf_TC(): + assert gf_TC([], ZZ) == 0 + assert gf_TC([1], ZZ) == 1 + assert gf_TC([1, 2], ZZ) == 2 + + +def test_gf_monic(): + assert gf_monic(ZZ.map([]), 11, ZZ) == (0, []) + + assert gf_monic(ZZ.map([1]), 11, ZZ) == (1, [1]) + assert gf_monic(ZZ.map([2]), 11, ZZ) == (2, [1]) + + assert gf_monic(ZZ.map([1, 2, 3, 4]), 11, ZZ) == (1, [1, 2, 3, 4]) + assert gf_monic(ZZ.map([2, 3, 4, 5]), 11, ZZ) == (2, [1, 7, 2, 8]) + + +def test_gf_arith(): + assert gf_neg([], 11, ZZ) == [] + assert gf_neg([1], 11, ZZ) == [10] + assert gf_neg([1, 2, 3], 11, ZZ) == [10, 9, 8] + + assert gf_add_ground([], 0, 11, ZZ) == [] + assert gf_sub_ground([], 0, 11, ZZ) == [] + + assert gf_add_ground([], 3, 11, ZZ) == [3] + assert gf_sub_ground([], 3, 11, ZZ) == [8] + + assert gf_add_ground([1], 3, 11, ZZ) == [4] + assert gf_sub_ground([1], 3, 11, ZZ) == [9] + + assert gf_add_ground([8], 3, 11, ZZ) == [] + assert gf_sub_ground([3], 3, 11, ZZ) == [] + + assert gf_add_ground([1, 2, 3], 3, 11, ZZ) == [1, 2, 6] + assert gf_sub_ground([1, 2, 3], 3, 11, ZZ) == [1, 2, 0] + + assert gf_mul_ground([], 0, 11, ZZ) == [] + assert gf_mul_ground([], 1, 11, ZZ) == [] + + assert gf_mul_ground([1], 0, 11, ZZ) == [] + assert gf_mul_ground([1], 1, 11, ZZ) == [1] + + assert gf_mul_ground([1, 2, 3], 0, 11, ZZ) == [] + assert gf_mul_ground([1, 2, 3], 1, 11, ZZ) == [1, 2, 3] + assert gf_mul_ground([1, 2, 3], 7, 11, ZZ) == [7, 3, 10] + + assert gf_add([], [], 11, ZZ) == [] + assert gf_add([1], [], 11, ZZ) == [1] + assert gf_add([], [1], 11, ZZ) == [1] + assert gf_add([1], [1], 11, ZZ) == [2] + assert gf_add([1], [2], 11, ZZ) == [3] + + assert gf_add([1, 2], [1], 11, ZZ) == [1, 3] + assert gf_add([1], [1, 2], 11, ZZ) == [1, 3] + + assert gf_add([1, 2, 3], [8, 9, 10], 11, ZZ) == [9, 0, 2] + + assert gf_sub([], [], 11, ZZ) == [] + assert gf_sub([1], [], 11, ZZ) == [1] + assert gf_sub([], [1], 11, ZZ) == [10] + assert gf_sub([1], [1], 11, ZZ) == [] + assert gf_sub([1], [2], 11, ZZ) == [10] + + assert gf_sub([1, 2], [1], 11, ZZ) == [1, 1] + assert gf_sub([1], [1, 2], 11, ZZ) == [10, 10] + + assert gf_sub([3, 2, 1], [8, 9, 10], 11, ZZ) == [6, 4, 2] + + assert gf_add_mul( + [1, 5, 6], [7, 3], [8, 0, 6, 1], 11, ZZ) == [1, 2, 10, 8, 9] + assert gf_sub_mul( + [1, 5, 6], [7, 3], [8, 0, 6, 1], 11, ZZ) == [10, 9, 3, 2, 3] + + assert gf_mul([], [], 11, ZZ) == [] + assert gf_mul([], [1], 11, ZZ) == [] + assert gf_mul([1], [], 11, ZZ) == [] + assert gf_mul([1], [1], 11, ZZ) == [1] + assert gf_mul([5], [7], 11, ZZ) == [2] + + assert gf_mul([3, 0, 0, 6, 1, 2], [4, 0, 1, 0], 11, ZZ) == [1, 0, + 3, 2, 4, 3, 1, 2, 0] + assert gf_mul([4, 0, 1, 0], [3, 0, 0, 6, 1, 2], 11, ZZ) == [1, 0, + 3, 2, 4, 3, 1, 2, 0] + + assert gf_mul([2, 0, 0, 1, 7], [2, 0, 0, 1, 7], 11, ZZ) == [4, 0, + 0, 4, 6, 0, 1, 3, 5] + + assert gf_sqr([], 11, ZZ) == [] + assert gf_sqr([2], 11, ZZ) == [4] + assert gf_sqr([1, 2], 11, ZZ) == [1, 4, 4] + + assert gf_sqr([2, 0, 0, 1, 7], 11, ZZ) == [4, 0, 0, 4, 6, 0, 1, 3, 5] + + +def test_gf_division(): + raises(ZeroDivisionError, lambda: gf_div([1, 2, 3], [], 11, ZZ)) + raises(ZeroDivisionError, lambda: gf_rem([1, 2, 3], [], 11, ZZ)) + raises(ZeroDivisionError, lambda: gf_quo([1, 2, 3], [], 11, ZZ)) + raises(ZeroDivisionError, lambda: gf_quo([1, 2, 3], [], 11, ZZ)) + + assert gf_div([1], [1, 2, 3], 7, ZZ) == ([], [1]) + assert gf_rem([1], [1, 2, 3], 7, ZZ) == [1] + assert gf_quo([1], [1, 2, 3], 7, ZZ) == [] + + f = ZZ.map([5, 4, 3, 2, 1, 0]) + g = ZZ.map([1, 2, 3]) + q = [5, 1, 0, 6] + r = [3, 3] + + assert gf_div(f, g, 7, ZZ) == (q, r) + assert gf_rem(f, g, 7, ZZ) == r + assert gf_quo(f, g, 7, ZZ) == q + + raises(ExactQuotientFailed, lambda: gf_exquo(f, g, 7, ZZ)) + + f = ZZ.map([5, 4, 3, 2, 1, 0]) + g = ZZ.map([1, 2, 3, 0]) + q = [5, 1, 0] + r = [6, 1, 0] + + assert gf_div(f, g, 7, ZZ) == (q, r) + assert gf_rem(f, g, 7, ZZ) == r + assert gf_quo(f, g, 7, ZZ) == q + + raises(ExactQuotientFailed, lambda: gf_exquo(f, g, 7, ZZ)) + + assert gf_quo(ZZ.map([1, 2, 1]), ZZ.map([1, 1]), 11, ZZ) == [1, 1] + + +def test_gf_shift(): + f = [1, 2, 3, 4, 5] + + assert gf_lshift([], 5, ZZ) == [] + assert gf_rshift([], 5, ZZ) == ([], []) + + assert gf_lshift(f, 1, ZZ) == [1, 2, 3, 4, 5, 0] + assert gf_lshift(f, 2, ZZ) == [1, 2, 3, 4, 5, 0, 0] + + assert gf_rshift(f, 0, ZZ) == (f, []) + assert gf_rshift(f, 1, ZZ) == ([1, 2, 3, 4], [5]) + assert gf_rshift(f, 3, ZZ) == ([1, 2], [3, 4, 5]) + assert gf_rshift(f, 5, ZZ) == ([], f) + + +def test_gf_expand(): + F = [([1, 1], 2), ([1, 2], 3)] + + assert gf_expand(F, 11, ZZ) == [1, 8, 3, 5, 6, 8] + assert gf_expand((4, F), 11, ZZ) == [4, 10, 1, 9, 2, 10] + + +def test_gf_powering(): + assert gf_pow([1, 0, 0, 1, 8], 0, 11, ZZ) == [1] + assert gf_pow([1, 0, 0, 1, 8], 1, 11, ZZ) == [1, 0, 0, 1, 8] + assert gf_pow([1, 0, 0, 1, 8], 2, 11, ZZ) == [1, 0, 0, 2, 5, 0, 1, 5, 9] + + assert gf_pow([1, 0, 0, 1, 8], 5, 11, ZZ) == \ + [1, 0, 0, 5, 7, 0, 10, 6, 2, 10, 9, 6, 10, 6, 6, 0, 5, 2, 5, 9, 10] + + assert gf_pow([1, 0, 0, 1, 8], 8, 11, ZZ) == \ + [1, 0, 0, 8, 9, 0, 6, 8, 10, 1, 2, 5, 10, 7, 7, 9, 1, 2, 0, 0, 6, 2, + 5, 2, 5, 7, 7, 9, 10, 10, 7, 5, 5] + + assert gf_pow([1, 0, 0, 1, 8], 45, 11, ZZ) == \ + [ 1, 0, 0, 1, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 4, 10, 0, 0, 0, 0, 0, 0, + 10, 0, 0, 10, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 6, 0, 0, 6, 4, 0, 0, 0, 0, 0, 0, 8, 0, 0, 8, 9, 0, 0, 0, 0, 0, 0, + 10, 0, 0, 10, 3, 0, 0, 0, 0, 0, 0, 4, 0, 0, 4, 10, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 8, 9, 0, 0, 0, 0, 0, 0, 9, 0, 0, 9, 6, 0, 0, 0, 0, 0, 0, + 3, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 10, 0, 0, 10, 3, 0, 0, 0, 0, 0, 0, + 10, 0, 0, 10, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 5, 0, 0, 0, 0, 0, 0, + 4, 0, 0, 4, 10] + + assert gf_pow_mod(ZZ.map([1, 0, 0, 1, 8]), 0, ZZ.map([2, 0, 7]), 11, ZZ) == [1] + assert gf_pow_mod(ZZ.map([1, 0, 0, 1, 8]), 1, ZZ.map([2, 0, 7]), 11, ZZ) == [1, 1] + assert gf_pow_mod(ZZ.map([1, 0, 0, 1, 8]), 2, ZZ.map([2, 0, 7]), 11, ZZ) == [2, 3] + assert gf_pow_mod(ZZ.map([1, 0, 0, 1, 8]), 5, ZZ.map([2, 0, 7]), 11, ZZ) == [7, 8] + assert gf_pow_mod(ZZ.map([1, 0, 0, 1, 8]), 8, ZZ.map([2, 0, 7]), 11, ZZ) == [1, 5] + assert gf_pow_mod(ZZ.map([1, 0, 0, 1, 8]), 45, ZZ.map([2, 0, 7]), 11, ZZ) == [5, 4] + + +def test_gf_gcdex(): + assert gf_gcdex(ZZ.map([]), ZZ.map([]), 11, ZZ) == ([1], [], []) + assert gf_gcdex(ZZ.map([2]), ZZ.map([]), 11, ZZ) == ([6], [], [1]) + assert gf_gcdex(ZZ.map([]), ZZ.map([2]), 11, ZZ) == ([], [6], [1]) + assert gf_gcdex(ZZ.map([2]), ZZ.map([2]), 11, ZZ) == ([], [6], [1]) + + assert gf_gcdex(ZZ.map([]), ZZ.map([3, 0]), 11, ZZ) == ([], [4], [1, 0]) + assert gf_gcdex(ZZ.map([3, 0]), ZZ.map([]), 11, ZZ) == ([4], [], [1, 0]) + + assert gf_gcdex(ZZ.map([3, 0]), ZZ.map([3, 0]), 11, ZZ) == ([], [4], [1, 0]) + + assert gf_gcdex(ZZ.map([1, 8, 7]), ZZ.map([1, 7, 1, 7]), 11, ZZ) == ([5, 6], [6], [1, 7]) + + +def test_gf_gcd(): + assert gf_gcd(ZZ.map([]), ZZ.map([]), 11, ZZ) == [] + assert gf_gcd(ZZ.map([2]), ZZ.map([]), 11, ZZ) == [1] + assert gf_gcd(ZZ.map([]), ZZ.map([2]), 11, ZZ) == [1] + assert gf_gcd(ZZ.map([2]), ZZ.map([2]), 11, ZZ) == [1] + + assert gf_gcd(ZZ.map([]), ZZ.map([1, 0]), 11, ZZ) == [1, 0] + assert gf_gcd(ZZ.map([1, 0]), ZZ.map([]), 11, ZZ) == [1, 0] + + assert gf_gcd(ZZ.map([3, 0]), ZZ.map([3, 0]), 11, ZZ) == [1, 0] + assert gf_gcd(ZZ.map([1, 8, 7]), ZZ.map([1, 7, 1, 7]), 11, ZZ) == [1, 7] + + +def test_gf_lcm(): + assert gf_lcm(ZZ.map([]), ZZ.map([]), 11, ZZ) == [] + assert gf_lcm(ZZ.map([2]), ZZ.map([]), 11, ZZ) == [] + assert gf_lcm(ZZ.map([]), ZZ.map([2]), 11, ZZ) == [] + assert gf_lcm(ZZ.map([2]), ZZ.map([2]), 11, ZZ) == [1] + + assert gf_lcm(ZZ.map([]), ZZ.map([1, 0]), 11, ZZ) == [] + assert gf_lcm(ZZ.map([1, 0]), ZZ.map([]), 11, ZZ) == [] + + assert gf_lcm(ZZ.map([3, 0]), ZZ.map([3, 0]), 11, ZZ) == [1, 0] + assert gf_lcm(ZZ.map([1, 8, 7]), ZZ.map([1, 7, 1, 7]), 11, ZZ) == [1, 8, 8, 8, 7] + + +def test_gf_cofactors(): + assert gf_cofactors(ZZ.map([]), ZZ.map([]), 11, ZZ) == ([], [], []) + assert gf_cofactors(ZZ.map([2]), ZZ.map([]), 11, ZZ) == ([1], [2], []) + assert gf_cofactors(ZZ.map([]), ZZ.map([2]), 11, ZZ) == ([1], [], [2]) + assert gf_cofactors(ZZ.map([2]), ZZ.map([2]), 11, ZZ) == ([1], [2], [2]) + + assert gf_cofactors(ZZ.map([]), ZZ.map([1, 0]), 11, ZZ) == ([1, 0], [], [1]) + assert gf_cofactors(ZZ.map([1, 0]), ZZ.map([]), 11, ZZ) == ([1, 0], [1], []) + + assert gf_cofactors(ZZ.map([3, 0]), ZZ.map([3, 0]), 11, ZZ) == ( + [1, 0], [3], [3]) + assert gf_cofactors(ZZ.map([1, 8, 7]), ZZ.map([1, 7, 1, 7]), 11, ZZ) == ( + ([1, 7], [1, 1], [1, 0, 1])) + + +def test_gf_diff(): + assert gf_diff([], 11, ZZ) == [] + assert gf_diff([7], 11, ZZ) == [] + + assert gf_diff([7, 3], 11, ZZ) == [7] + assert gf_diff([7, 3, 1], 11, ZZ) == [3, 3] + + assert gf_diff([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 11, ZZ) == [] + + +def test_gf_eval(): + assert gf_eval([], 4, 11, ZZ) == 0 + assert gf_eval([], 27, 11, ZZ) == 0 + assert gf_eval([7], 4, 11, ZZ) == 7 + assert gf_eval([7], 27, 11, ZZ) == 7 + + assert gf_eval([1, 0, 3, 2, 4, 3, 1, 2, 0], 0, 11, ZZ) == 0 + assert gf_eval([1, 0, 3, 2, 4, 3, 1, 2, 0], 4, 11, ZZ) == 9 + assert gf_eval([1, 0, 3, 2, 4, 3, 1, 2, 0], 27, 11, ZZ) == 5 + + assert gf_eval([4, 0, 0, 4, 6, 0, 1, 3, 5], 0, 11, ZZ) == 5 + assert gf_eval([4, 0, 0, 4, 6, 0, 1, 3, 5], 4, 11, ZZ) == 3 + assert gf_eval([4, 0, 0, 4, 6, 0, 1, 3, 5], 27, 11, ZZ) == 9 + + assert gf_multi_eval([3, 2, 1], [0, 1, 2, 3], 11, ZZ) == [1, 6, 6, 1] + + +def test_gf_compose(): + assert gf_compose([], [1, 0], 11, ZZ) == [] + assert gf_compose_mod([], [1, 0], [1, 0], 11, ZZ) == [] + + assert gf_compose([1], [], 11, ZZ) == [1] + assert gf_compose([1, 0], [], 11, ZZ) == [] + assert gf_compose([1, 0], [1, 0], 11, ZZ) == [1, 0] + + f = ZZ.map([1, 1, 4, 9, 1]) + g = ZZ.map([1, 1, 1]) + h = ZZ.map([1, 0, 0, 2]) + + assert gf_compose(g, h, 11, ZZ) == [1, 0, 0, 5, 0, 0, 7] + assert gf_compose_mod(g, h, f, 11, ZZ) == [3, 9, 6, 10] + + +def test_gf_trace_map(): + f = ZZ.map([1, 1, 4, 9, 1]) + a = [1, 1, 1] + c = ZZ.map([1, 0]) + b = gf_pow_mod(c, 11, f, 11, ZZ) + + assert gf_trace_map(a, b, c, 0, f, 11, ZZ) == \ + ([1, 1, 1], [1, 1, 1]) + assert gf_trace_map(a, b, c, 1, f, 11, ZZ) == \ + ([5, 2, 10, 3], [5, 3, 0, 4]) + assert gf_trace_map(a, b, c, 2, f, 11, ZZ) == \ + ([5, 9, 5, 3], [10, 1, 5, 7]) + assert gf_trace_map(a, b, c, 3, f, 11, ZZ) == \ + ([1, 10, 6, 0], [7]) + assert gf_trace_map(a, b, c, 4, f, 11, ZZ) == \ + ([1, 1, 1], [1, 1, 8]) + assert gf_trace_map(a, b, c, 5, f, 11, ZZ) == \ + ([5, 2, 10, 3], [5, 3, 0, 0]) + assert gf_trace_map(a, b, c, 11, f, 11, ZZ) == \ + ([1, 10, 6, 0], [10]) + + +def test_gf_irreducible(): + assert gf_irreducible_p(gf_irreducible(1, 11, ZZ), 11, ZZ) is True + assert gf_irreducible_p(gf_irreducible(2, 11, ZZ), 11, ZZ) is True + assert gf_irreducible_p(gf_irreducible(3, 11, ZZ), 11, ZZ) is True + assert gf_irreducible_p(gf_irreducible(4, 11, ZZ), 11, ZZ) is True + assert gf_irreducible_p(gf_irreducible(5, 11, ZZ), 11, ZZ) is True + assert gf_irreducible_p(gf_irreducible(6, 11, ZZ), 11, ZZ) is True + assert gf_irreducible_p(gf_irreducible(7, 11, ZZ), 11, ZZ) is True + + +def test_gf_irreducible_p(): + assert gf_irred_p_ben_or(ZZ.map([7]), 11, ZZ) is True + assert gf_irred_p_ben_or(ZZ.map([7, 3]), 11, ZZ) is True + assert gf_irred_p_ben_or(ZZ.map([7, 3, 1]), 11, ZZ) is False + + assert gf_irred_p_rabin(ZZ.map([7]), 11, ZZ) is True + assert gf_irred_p_rabin(ZZ.map([7, 3]), 11, ZZ) is True + assert gf_irred_p_rabin(ZZ.map([7, 3, 1]), 11, ZZ) is False + + config.setup('GF_IRRED_METHOD', 'ben-or') + + assert gf_irreducible_p(ZZ.map([7]), 11, ZZ) is True + assert gf_irreducible_p(ZZ.map([7, 3]), 11, ZZ) is True + assert gf_irreducible_p(ZZ.map([7, 3, 1]), 11, ZZ) is False + + config.setup('GF_IRRED_METHOD', 'rabin') + + assert gf_irreducible_p(ZZ.map([7]), 11, ZZ) is True + assert gf_irreducible_p(ZZ.map([7, 3]), 11, ZZ) is True + assert gf_irreducible_p(ZZ.map([7, 3, 1]), 11, ZZ) is False + + config.setup('GF_IRRED_METHOD', 'other') + raises(KeyError, lambda: gf_irreducible_p([7], 11, ZZ)) + config.setup('GF_IRRED_METHOD') + + f = ZZ.map([1, 9, 9, 13, 16, 15, 6, 7, 7, 7, 10]) + g = ZZ.map([1, 7, 16, 7, 15, 13, 13, 11, 16, 10, 9]) + + h = gf_mul(f, g, 17, ZZ) + + assert gf_irred_p_ben_or(f, 17, ZZ) is True + assert gf_irred_p_ben_or(g, 17, ZZ) is True + + assert gf_irred_p_ben_or(h, 17, ZZ) is False + + assert gf_irred_p_rabin(f, 17, ZZ) is True + assert gf_irred_p_rabin(g, 17, ZZ) is True + + assert gf_irred_p_rabin(h, 17, ZZ) is False + + +def test_gf_squarefree(): + assert gf_sqf_list([], 11, ZZ) == (0, []) + assert gf_sqf_list([1], 11, ZZ) == (1, []) + assert gf_sqf_list([1, 1], 11, ZZ) == (1, [([1, 1], 1)]) + + assert gf_sqf_p([], 11, ZZ) is True + assert gf_sqf_p([1], 11, ZZ) is True + assert gf_sqf_p([1, 1], 11, ZZ) is True + + f = gf_from_dict({11: 1, 0: 1}, 11, ZZ) + + assert gf_sqf_p(f, 11, ZZ) is False + + assert gf_sqf_list(f, 11, ZZ) == \ + (1, [([1, 1], 11)]) + + f = [1, 5, 8, 4] + + assert gf_sqf_p(f, 11, ZZ) is False + + assert gf_sqf_list(f, 11, ZZ) == \ + (1, [([1, 1], 1), + ([1, 2], 2)]) + + assert gf_sqf_part(f, 11, ZZ) == [1, 3, 2] + + f = [1, 0, 0, 2, 0, 0, 2, 0, 0, 1, 0] + + assert gf_sqf_list(f, 3, ZZ) == \ + (1, [([1, 0], 1), + ([1, 1], 3), + ([1, 2], 6)]) + +def test_gf_frobenius_map(): + f = ZZ.map([2, 0, 1, 0, 2, 2, 0, 2, 2, 2]) + g = ZZ.map([1,1,0,2,0,1,0,2,0,1]) + p = 3 + b = gf_frobenius_monomial_base(g, p, ZZ) + h = gf_frobenius_map(f, g, b, p, ZZ) + h1 = gf_pow_mod(f, p, g, p, ZZ) + assert h == h1 + + +def test_gf_berlekamp(): + f = gf_from_int_poly([1, -3, 1, -3, -1, -3, 1], 11) + + Q = [[1, 0, 0, 0, 0, 0], + [3, 5, 8, 8, 6, 5], + [3, 6, 6, 1, 10, 0], + [9, 4, 10, 3, 7, 9], + [7, 8, 10, 0, 0, 8], + [8, 10, 7, 8, 10, 8]] + + V = [[1, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0], + [0, 0, 7, 9, 0, 1]] + + assert gf_Qmatrix(f, 11, ZZ) == Q + assert gf_Qbasis(Q, 11, ZZ) == V + + assert gf_berlekamp(f, 11, ZZ) == \ + [[1, 1], [1, 5, 3], [1, 2, 3, 4]] + + f = ZZ.map([1, 0, 1, 0, 10, 10, 8, 2, 8]) + + Q = ZZ.map([[1, 0, 0, 0, 0, 0, 0, 0], + [2, 1, 7, 11, 10, 12, 5, 11], + [3, 6, 4, 3, 0, 4, 7, 2], + [4, 3, 6, 5, 1, 6, 2, 3], + [2, 11, 8, 8, 3, 1, 3, 11], + [6, 11, 8, 6, 2, 7, 10, 9], + [5, 11, 7, 10, 0, 11, 7, 12], + [3, 3, 12, 5, 0, 11, 9, 12]]) + + V = [[1, 0, 0, 0, 0, 0, 0, 0], + [0, 5, 5, 0, 9, 5, 1, 0], + [0, 9, 11, 9, 10, 12, 0, 1]] + + assert gf_Qmatrix(f, 13, ZZ) == Q + assert gf_Qbasis(Q, 13, ZZ) == V + + assert gf_berlekamp(f, 13, ZZ) == \ + [[1, 3], [1, 8, 4, 12], [1, 2, 3, 4, 6]] + + +def test_gf_ddf(): + f = gf_from_dict({15: ZZ(1), 0: ZZ(-1)}, 11, ZZ) + g = [([1, 0, 0, 0, 0, 10], 1), + ([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], 2)] + + assert gf_ddf_zassenhaus(f, 11, ZZ) == g + assert gf_ddf_shoup(f, 11, ZZ) == g + + f = gf_from_dict({63: ZZ(1), 0: ZZ(1)}, 2, ZZ) + g = [([1, 1], 1), + ([1, 1, 1], 2), + ([1, 1, 1, 1, 1, 1, 1], 3), + ([1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1], 6)] + + assert gf_ddf_zassenhaus(f, 2, ZZ) == g + assert gf_ddf_shoup(f, 2, ZZ) == g + + f = gf_from_dict({6: ZZ(1), 5: ZZ(-1), 4: ZZ(1), 3: ZZ(1), 1: ZZ(-1)}, 3, ZZ) + g = [([1, 1, 0], 1), + ([1, 1, 0, 1, 2], 2)] + + assert gf_ddf_zassenhaus(f, 3, ZZ) == g + assert gf_ddf_shoup(f, 3, ZZ) == g + + f = ZZ.map([1, 2, 5, 26, 677, 436, 791, 325, 456, 24, 577]) + g = [([1, 701], 1), + ([1, 110, 559, 532, 694, 151, 110, 70, 735, 122], 9)] + + assert gf_ddf_zassenhaus(f, 809, ZZ) == g + assert gf_ddf_shoup(f, 809, ZZ) == g + + p = ZZ(nextprime(int((2**15 * pi).evalf()))) + f = gf_from_dict({15: 1, 1: 1, 0: 1}, p, ZZ) + g = [([1, 22730, 68144], 2), + ([1, 64876, 83977, 10787, 12561, 68608, 52650, 88001, 84356], 4), + ([1, 15347, 95022, 84569, 94508, 92335], 5)] + + assert gf_ddf_zassenhaus(f, p, ZZ) == g + assert gf_ddf_shoup(f, p, ZZ) == g + + +def test_gf_edf(): + f = ZZ.map([1, 1, 0, 1, 2]) + g = ZZ.map([[1, 0, 1], [1, 1, 2]]) + + assert gf_edf_zassenhaus(f, 2, 3, ZZ) == g + assert gf_edf_shoup(f, 2, 3, ZZ) == g + + +def test_issue_23174(): + f = ZZ.map([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + g = ZZ.map([[1, 0, 0, 1, 1, 1, 0, 0, 1], [1, 1, 1, 0, 1, 0, 1, 1, 1]]) + + assert gf_edf_zassenhaus(f, 8, 2, ZZ) == g + + +def test_gf_factor(): + assert gf_factor([], 11, ZZ) == (0, []) + assert gf_factor([1], 11, ZZ) == (1, []) + assert gf_factor([1, 1], 11, ZZ) == (1, [([1, 1], 1)]) + + assert gf_factor_sqf([], 11, ZZ) == (0, []) + assert gf_factor_sqf([1], 11, ZZ) == (1, []) + assert gf_factor_sqf([1, 1], 11, ZZ) == (1, [[1, 1]]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + + assert gf_factor_sqf([], 11, ZZ) == (0, []) + assert gf_factor_sqf([1], 11, ZZ) == (1, []) + assert gf_factor_sqf([1, 1], 11, ZZ) == (1, [[1, 1]]) + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + + assert gf_factor_sqf([], 11, ZZ) == (0, []) + assert gf_factor_sqf([1], 11, ZZ) == (1, []) + assert gf_factor_sqf([1, 1], 11, ZZ) == (1, [[1, 1]]) + + config.setup('GF_FACTOR_METHOD', 'shoup') + + assert gf_factor_sqf(ZZ.map([]), 11, ZZ) == (0, []) + assert gf_factor_sqf(ZZ.map([1]), 11, ZZ) == (1, []) + assert gf_factor_sqf(ZZ.map([1, 1]), 11, ZZ) == (1, [[1, 1]]) + + f, p = ZZ.map([1, 0, 0, 1, 0]), 2 + + g = (1, [([1, 0], 1), + ([1, 1], 1), + ([1, 1, 1], 1)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + g = (1, [[1, 0], + [1, 1], + [1, 1, 1]]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor_sqf(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor_sqf(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor_sqf(f, p, ZZ) == g + + f, p = gf_from_int_poly([1, -3, 1, -3, -1, -3, 1], 11), 11 + + g = (1, [([1, 1], 1), + ([1, 5, 3], 1), + ([1, 2, 3, 4], 1)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + f, p = [1, 5, 8, 4], 11 + + g = (1, [([1, 1], 1), ([1, 2], 2)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + f, p = [1, 1, 10, 1, 0, 10, 10, 10, 0, 0], 11 + + g = (1, [([1, 0], 2), ([1, 9, 5], 1), ([1, 3, 0, 8, 5, 2], 1)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + f, p = gf_from_dict({32: 1, 0: 1}, 11, ZZ), 11 + + g = (1, [([1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 10], 1), + ([1, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 10], 1)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + f, p = gf_from_dict({32: ZZ(8), 0: ZZ(5)}, 11, ZZ), 11 + + g = (8, [([1, 3], 1), + ([1, 8], 1), + ([1, 0, 9], 1), + ([1, 2, 2], 1), + ([1, 9, 2], 1), + ([1, 0, 5, 0, 7], 1), + ([1, 0, 6, 0, 7], 1), + ([1, 0, 0, 0, 1, 0, 0, 0, 6], 1), + ([1, 0, 0, 0, 10, 0, 0, 0, 6], 1)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + f, p = gf_from_dict({63: ZZ(8), 0: ZZ(5)}, 11, ZZ), 11 + + g = (8, [([1, 7], 1), + ([1, 4, 5], 1), + ([1, 6, 8, 2], 1), + ([1, 9, 9, 2], 1), + ([1, 0, 0, 9, 0, 0, 4], 1), + ([1, 2, 0, 8, 4, 6, 4], 1), + ([1, 2, 3, 8, 0, 6, 4], 1), + ([1, 2, 6, 0, 8, 4, 4], 1), + ([1, 3, 3, 1, 6, 8, 4], 1), + ([1, 5, 6, 0, 8, 6, 4], 1), + ([1, 6, 2, 7, 9, 8, 4], 1), + ([1, 10, 4, 7, 10, 7, 4], 1), + ([1, 10, 10, 1, 4, 9, 4], 1)]) + + config.setup('GF_FACTOR_METHOD', 'berlekamp') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + # Gathen polynomials: x**n + x + 1 (mod p > 2**n * pi) + + p = ZZ(nextprime(int((2**15 * pi).evalf()))) + f = gf_from_dict({15: 1, 1: 1, 0: 1}, p, ZZ) + + assert gf_sqf_p(f, p, ZZ) is True + + g = (1, [([1, 22730, 68144], 1), + ([1, 81553, 77449, 86810, 4724], 1), + ([1, 86276, 56779, 14859, 31575], 1), + ([1, 15347, 95022, 84569, 94508, 92335], 1)]) + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + g = (1, [[1, 22730, 68144], + [1, 81553, 77449, 86810, 4724], + [1, 86276, 56779, 14859, 31575], + [1, 15347, 95022, 84569, 94508, 92335]]) + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor_sqf(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor_sqf(f, p, ZZ) == g + + # Shoup polynomials: f = a_0 x**n + a_1 x**(n-1) + ... + a_n + # (mod p > 2**(n-2) * pi), where a_n = a_{n-1}**2 + 1, a_0 = 1 + + p = ZZ(nextprime(int((2**4 * pi).evalf()))) + f = ZZ.map([1, 2, 5, 26, 41, 39, 38]) + + assert gf_sqf_p(f, p, ZZ) is True + + g = (1, [([1, 44, 26], 1), + ([1, 11, 25, 18, 30], 1)]) + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor(f, p, ZZ) == g + + g = (1, [[1, 44, 26], + [1, 11, 25, 18, 30]]) + + config.setup('GF_FACTOR_METHOD', 'zassenhaus') + assert gf_factor_sqf(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'shoup') + assert gf_factor_sqf(f, p, ZZ) == g + + config.setup('GF_FACTOR_METHOD', 'other') + raises(KeyError, lambda: gf_factor([1, 1], 11, ZZ)) + config.setup('GF_FACTOR_METHOD') + + +def test_gf_csolve(): + assert gf_value([1, 7, 2, 4], 11) == 2204 + + assert linear_congruence(4, 3, 5) == [2] + assert linear_congruence(0, 3, 5) == [] + assert linear_congruence(6, 1, 4) == [] + assert linear_congruence(0, 5, 5) == [0, 1, 2, 3, 4] + assert linear_congruence(3, 12, 15) == [4, 9, 14] + assert linear_congruence(6, 0, 18) == [0, 3, 6, 9, 12, 15] + # _csolve_prime_las_vegas + assert _csolve_prime_las_vegas([2, 3, 1], 5) == [2, 4] + assert _csolve_prime_las_vegas([2, 0, 1], 5) == [] + from sympy.ntheory import primerange + for p in primerange(2, 100): + # f = x**(p-1) - 1 + f = gf_sub_ground(gf_pow([1, 0], p - 1, p, ZZ), 1, p, ZZ) + assert _csolve_prime_las_vegas(f, p) == list(range(1, p)) + # with power = 1 + assert csolve_prime([1, 3, 2, 17], 7) == [3] + assert csolve_prime([1, 3, 1, 5], 5) == [0, 1] + assert csolve_prime([3, 6, 9, 3], 3) == [0, 1, 2] + # with power > 1 + assert csolve_prime( + [1, 1, 223], 3, 4) == [4, 13, 22, 31, 40, 49, 58, 67, 76] + assert csolve_prime([3, 5, 2, 25], 5, 3) == [16, 50, 99] + assert csolve_prime([3, 2, 2, 49], 7, 3) == [147, 190, 234] + + assert gf_csolve([1, 1, 7], 189) == [13, 49, 76, 112, 139, 175] + assert gf_csolve([1, 3, 4, 1, 30], 60) == [10, 30] + assert gf_csolve([1, 1, 7], 15) == [] diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_groebnertools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_groebnertools.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d0fc112047ac26f67d096db02eb8a1c91cab89 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_groebnertools.py @@ -0,0 +1,533 @@ +"""Tests for Groebner bases. """ + +from sympy.polys.groebnertools import ( + groebner, sig, sig_key, + lbp, lbp_key, critical_pair, + cp_key, is_rewritable_or_comparable, + Sign, Polyn, Num, s_poly, f5_reduce, + groebner_lcm, groebner_gcd, is_groebner, + is_reduced +) + +from sympy.polys.fglmtools import _representing_matrices +from sympy.polys.orderings import lex, grlex + +from sympy.polys.rings import ring, xring +from sympy.polys.domains import ZZ, QQ + +from sympy.testing.pytest import slow +from sympy.polys import polyconfig as config + +def _do_test_groebner(): + R, x,y = ring("x,y", QQ, lex) + f = x**2 + 2*x*y**2 + g = x*y + 2*y**3 - 1 + + assert groebner([f, g], R) == [x, y**3 - QQ(1,2)] + + R, y,x = ring("y,x", QQ, lex) + f = 2*x**2*y + y**2 + g = 2*x**3 + x*y - 1 + + assert groebner([f, g], R) == [y, x**3 - QQ(1,2)] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = x - z**2 + g = y - z**3 + + assert groebner([f, g], R) == [f, g] + + R, x,y = ring("x,y", QQ, grlex) + f = x**3 - 2*x*y + g = x**2*y + x - 2*y**2 + + assert groebner([f, g], R) == [x**2, x*y, -QQ(1,2)*x + y**2] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = -x**2 + y + g = -x**3 + z + + assert groebner([f, g], R) == [x**2 - y, x*y - z, x*z - y**2, y**3 - z**2] + + R, x,y,z = ring("x,y,z", QQ, grlex) + f = -x**2 + y + g = -x**3 + z + + assert groebner([f, g], R) == [y**3 - z**2, x**2 - y, x*y - z, x*z - y**2] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = -x**2 + z + g = -x**3 + y + + assert groebner([f, g], R) == [x**2 - z, x*y - z**2, x*z - y, y**2 - z**3] + + R, x,y,z = ring("x,y,z", QQ, grlex) + f = -x**2 + z + g = -x**3 + y + + assert groebner([f, g], R) == [-y**2 + z**3, x**2 - z, x*y - z**2, x*z - y] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = x - y**2 + g = -y**3 + z + + assert groebner([f, g], R) == [x - y**2, y**3 - z] + + R, x,y,z = ring("x,y,z", QQ, grlex) + f = x - y**2 + g = -y**3 + z + + assert groebner([f, g], R) == [x**2 - y*z, x*y - z, -x + y**2] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = x - z**2 + g = y - z**3 + + assert groebner([f, g], R) == [x - z**2, y - z**3] + + R, x,y,z = ring("x,y,z", QQ, grlex) + f = x - z**2 + g = y - z**3 + + assert groebner([f, g], R) == [x**2 - y*z, x*z - y, -x + z**2] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = -y**2 + z + g = x - y**3 + + assert groebner([f, g], R) == [x - y*z, y**2 - z] + + R, x,y,z = ring("x,y,z", QQ, grlex) + f = -y**2 + z + g = x - y**3 + + assert groebner([f, g], R) == [-x**2 + z**3, x*y - z**2, y**2 - z, -x + y*z] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = y - z**2 + g = x - z**3 + + assert groebner([f, g], R) == [x - z**3, y - z**2] + + R, x,y,z = ring("x,y,z", QQ, grlex) + f = y - z**2 + g = x - z**3 + + assert groebner([f, g], R) == [-x**2 + y**3, x*z - y**2, -x + y*z, -y + z**2] + + R, x,y,z = ring("x,y,z", QQ, lex) + f = 4*x**2*y**2 + 4*x*y + 1 + g = x**2 + y**2 - 1 + + assert groebner([f, g], R) == [ + x - 4*y**7 + 8*y**5 - 7*y**3 + 3*y, + y**8 - 2*y**6 + QQ(3,2)*y**4 - QQ(1,2)*y**2 + QQ(1,16), + ] + +def test_groebner_buchberger(): + with config.using(groebner='buchberger'): + _do_test_groebner() + +def test_groebner_f5b(): + with config.using(groebner='f5b'): + _do_test_groebner() + +def _do_test_benchmark_minpoly(): + R, x,y,z = ring("x,y,z", QQ, lex) + + F = [x**3 + x + 1, y**2 + y + 1, (x + y) * z - (x**2 + y)] + G = [x + QQ(155,2067)*z**5 - QQ(355,689)*z**4 + QQ(6062,2067)*z**3 - QQ(3687,689)*z**2 + QQ(6878,2067)*z - QQ(25,53), + y + QQ(4,53)*z**5 - QQ(91,159)*z**4 + QQ(523,159)*z**3 - QQ(387,53)*z**2 + QQ(1043,159)*z - QQ(308,159), + z**6 - 7*z**5 + 41*z**4 - 82*z**3 + 89*z**2 - 46*z + 13] + + assert groebner(F, R) == G + +def test_benchmark_minpoly_buchberger(): + with config.using(groebner='buchberger'): + _do_test_benchmark_minpoly() + +def test_benchmark_minpoly_f5b(): + with config.using(groebner='f5b'): + _do_test_benchmark_minpoly() + + +def test_benchmark_coloring(): + V = range(1, 12 + 1) + E = [(1, 2), (2, 3), (1, 4), (1, 6), (1, 12), (2, 5), (2, 7), (3, 8), (3, 10), + (4, 11), (4, 9), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), + (11, 12), (5, 12), (5, 9), (6, 10), (7, 11), (8, 12), (3, 4)] + + R, V = xring([ "x%d" % v for v in V ], QQ, lex) + E = [(V[i - 1], V[j - 1]) for i, j in E] + + x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = V + + I3 = [x**3 - 1 for x in V] + Ig = [x**2 + x*y + y**2 for x, y in E] + + I = I3 + Ig + + assert groebner(I[:-1], R) == [ + x1 + x11 + x12, + x2 - x11, + x3 - x12, + x4 - x12, + x5 + x11 + x12, + x6 - x11, + x7 - x12, + x8 + x11 + x12, + x9 - x11, + x10 + x11 + x12, + x11**2 + x11*x12 + x12**2, + x12**3 - 1, + ] + + assert groebner(I, R) == [1] + + +def _do_test_benchmark_katsura_3(): + R, x0,x1,x2 = ring("x:3", ZZ, lex) + I = [x0 + 2*x1 + 2*x2 - 1, + x0**2 + 2*x1**2 + 2*x2**2 - x0, + 2*x0*x1 + 2*x1*x2 - x1] + + assert groebner(I, R) == [ + -7 + 7*x0 + 8*x2 + 158*x2**2 - 420*x2**3, + 7*x1 + 3*x2 - 79*x2**2 + 210*x2**3, + x2 + x2**2 - 40*x2**3 + 84*x2**4, + ] + + R, x0,x1,x2 = ring("x:3", ZZ, grlex) + I = [ i.set_ring(R) for i in I ] + + assert groebner(I, R) == [ + 7*x1 + 3*x2 - 79*x2**2 + 210*x2**3, + -x1 + x2 - 3*x2**2 + 5*x1**2, + -x1 - 4*x2 + 10*x1*x2 + 12*x2**2, + -1 + x0 + 2*x1 + 2*x2, + ] + +def test_benchmark_katsura3_buchberger(): + with config.using(groebner='buchberger'): + _do_test_benchmark_katsura_3() + +def test_benchmark_katsura3_f5b(): + with config.using(groebner='f5b'): + _do_test_benchmark_katsura_3() + +def _do_test_benchmark_katsura_4(): + R, x0,x1,x2,x3 = ring("x:4", ZZ, lex) + I = [x0 + 2*x1 + 2*x2 + 2*x3 - 1, + x0**2 + 2*x1**2 + 2*x2**2 + 2*x3**2 - x0, + 2*x0*x1 + 2*x1*x2 + 2*x2*x3 - x1, + x1**2 + 2*x0*x2 + 2*x1*x3 - x2] + + assert groebner(I, R) == [ + 5913075*x0 - 159690237696*x3**7 + 31246269696*x3**6 + 27439610544*x3**5 - 6475723368*x3**4 - 838935856*x3**3 + 275119624*x3**2 + 4884038*x3 - 5913075, + 1971025*x1 - 97197721632*x3**7 + 73975630752*x3**6 - 12121915032*x3**5 - 2760941496*x3**4 + 814792828*x3**3 - 1678512*x3**2 - 9158924*x3, + 5913075*x2 + 371438283744*x3**7 - 237550027104*x3**6 + 22645939824*x3**5 + 11520686172*x3**4 - 2024910556*x3**3 - 132524276*x3**2 + 30947828*x3, + 128304*x3**8 - 93312*x3**7 + 15552*x3**6 + 3144*x3**5 - + 1120*x3**4 + 36*x3**3 + 15*x3**2 - x3, + ] + + R, x0,x1,x2,x3 = ring("x:4", ZZ, grlex) + I = [ i.set_ring(R) for i in I ] + + assert groebner(I, R) == [ + 393*x1 - 4662*x2**2 + 4462*x2*x3 - 59*x2 + 224532*x3**4 - 91224*x3**3 - 678*x3**2 + 2046*x3, + -x1 + 196*x2**3 - 21*x2**2 + 60*x2*x3 - 18*x2 - 168*x3**3 + 83*x3**2 - 9*x3, + -6*x1 + 1134*x2**2*x3 - 189*x2**2 - 466*x2*x3 + 32*x2 - 630*x3**3 + 57*x3**2 + 51*x3, + 33*x1 + 63*x2**2 + 2268*x2*x3**2 - 188*x2*x3 + 34*x2 + 2520*x3**3 - 849*x3**2 + 3*x3, + 7*x1**2 - x1 - 7*x2**2 - 24*x2*x3 + 3*x2 - 15*x3**2 + 5*x3, + 14*x1*x2 - x1 + 14*x2**2 + 18*x2*x3 - 4*x2 + 6*x3**2 - 2*x3, + 14*x1*x3 - x1 + 7*x2**2 + 32*x2*x3 - 4*x2 + 27*x3**2 - 9*x3, + x0 + 2*x1 + 2*x2 + 2*x3 - 1, + ] + +def test_benchmark_kastura_4_buchberger(): + with config.using(groebner='buchberger'): + _do_test_benchmark_katsura_4() + +def test_benchmark_kastura_4_f5b(): + with config.using(groebner='f5b'): + _do_test_benchmark_katsura_4() + +def _do_test_benchmark_czichowski(): + R, x,t = ring("x,t", ZZ, lex) + I = [9*x**8 + 36*x**7 - 32*x**6 - 252*x**5 - 78*x**4 + 468*x**3 + 288*x**2 - 108*x + 9, + (-72 - 72*t)*x**7 + (-256 - 252*t)*x**6 + (192 + 192*t)*x**5 + (1280 + 1260*t)*x**4 + (312 + 312*t)*x**3 + (-404*t)*x**2 + (-576 - 576*t)*x + 96 + 108*t] + + assert groebner(I, R) == [ + 3725588592068034903797967297424801242396746870413359539263038139343329273586196480000*x - + 160420835591776763325581422211936558925462474417709511019228211783493866564923546661604487873*t**7 - + 1406108495478033395547109582678806497509499966197028487131115097902188374051595011248311352864*t**6 - + 5241326875850889518164640374668786338033653548841427557880599579174438246266263602956254030352*t**5 - + 10758917262823299139373269714910672770004760114329943852726887632013485035262879510837043892416*t**4 - + 13119383576444715672578819534846747735372132018341964647712009275306635391456880068261130581248*t**3 - + 9491412317016197146080450036267011389660653495578680036574753839055748080962214787557853941760*t**2 - + 3767520915562795326943800040277726397326609797172964377014046018280260848046603967211258368000*t - + 632314652371226552085897259159210286886724229880266931574701654721512325555116066073245696000, + 610733380717522355121*t**8 + + 6243748742141230639968*t**7 + + 27761407182086143225024*t**6 + + 70066148869420956398592*t**5 + + 109701225644313784229376*t**4 + + 109009005495588442152960*t**3 + + 67072101084384786432000*t**2 + + 23339979742629593088000*t + + 3513592776846090240000, + ] + + R, x,t = ring("x,t", ZZ, grlex) + I = [ i.set_ring(R) for i in I ] + + assert groebner(I, R) == [ + 16996618586000601590732959134095643086442*t**3*x - + 32936701459297092865176560282688198064839*t**3 + + 78592411049800639484139414821529525782364*t**2*x - + 120753953358671750165454009478961405619916*t**2 + + 120988399875140799712152158915653654637280*t*x - + 144576390266626470824138354942076045758736*t + + 60017634054270480831259316163620768960*x**2 + + 61976058033571109604821862786675242894400*x - + 56266268491293858791834120380427754600960, + 576689018321912327136790519059646508441672750656050290242749*t**4 + + 2326673103677477425562248201573604572527893938459296513327336*t**3 + + 110743790416688497407826310048520299245819959064297990236000*t**2*x + + 3308669114229100853338245486174247752683277925010505284338016*t**2 + + 323150205645687941261103426627818874426097912639158572428800*t*x + + 1914335199925152083917206349978534224695445819017286960055680*t + + 861662882561803377986838989464278045397192862768588480000*x**2 + + 235296483281783440197069672204341465480107019878814196672000*x + + 361850798943225141738895123621685122544503614946436727532800, + -117584925286448670474763406733005510014188341867*t**3 + + 68566565876066068463853874568722190223721653044*t**2*x - + 435970731348366266878180788833437896139920683940*t**2 + + 196297602447033751918195568051376792491869233408*t*x - + 525011527660010557871349062870980202067479780112*t + + 517905853447200553360289634770487684447317120*x**3 + + 569119014870778921949288951688799397569321920*x**2 + + 138877356748142786670127389526667463202210102080*x - + 205109210539096046121625447192779783475018619520, + -3725142681462373002731339445216700112264527*t**3 + + 583711207282060457652784180668273817487940*t**2*x - + 12381382393074485225164741437227437062814908*t**2 + + 151081054097783125250959636747516827435040*t*x**2 + + 1814103857455163948531448580501928933873280*t*x - + 13353115629395094645843682074271212731433648*t + + 236415091385250007660606958022544983766080*x**2 + + 1390443278862804663728298060085399578417600*x - + 4716885828494075789338754454248931750698880, + ] + +# NOTE: This is very slow (> 2 minutes on 3.4 GHz) without GMPY +@slow +def test_benchmark_czichowski_buchberger(): + with config.using(groebner='buchberger'): + _do_test_benchmark_czichowski() + +def test_benchmark_czichowski_f5b(): + with config.using(groebner='f5b'): + _do_test_benchmark_czichowski() + +def _do_test_benchmark_cyclic_4(): + R, a,b,c,d = ring("a,b,c,d", ZZ, lex) + + I = [a + b + c + d, + a*b + a*d + b*c + b*d, + a*b*c + a*b*d + a*c*d + b*c*d, + a*b*c*d - 1] + + assert groebner(I, R) == [ + 4*a + 3*d**9 - 4*d**5 - 3*d, + 4*b + 4*c - 3*d**9 + 4*d**5 + 7*d, + 4*c**2 + 3*d**10 - 4*d**6 - 3*d**2, + 4*c*d**4 + 4*c - d**9 + 4*d**5 + 5*d, d**12 - d**8 - d**4 + 1 + ] + + R, a,b,c,d = ring("a,b,c,d", ZZ, grlex) + I = [ i.set_ring(R) for i in I ] + + assert groebner(I, R) == [ + 3*b*c - c**2 + d**6 - 3*d**2, + -b + 3*c**2*d**3 - c - d**5 - 4*d, + -b + 3*c*d**4 + 2*c + 2*d**5 + 2*d, + c**4 + 2*c**2*d**2 - d**4 - 2, + c**3*d + c*d**3 + d**4 + 1, + b*c**2 - c**3 - c**2*d - 2*c*d**2 - d**3, + b**2 - c**2, b*d + c**2 + c*d + d**2, + a + b + c + d + ] + +def test_benchmark_cyclic_4_buchberger(): + with config.using(groebner='buchberger'): + _do_test_benchmark_cyclic_4() + +def test_benchmark_cyclic_4_f5b(): + with config.using(groebner='f5b'): + _do_test_benchmark_cyclic_4() + +def test_sig_key(): + s1 = sig((0,) * 3, 2) + s2 = sig((1,) * 3, 4) + s3 = sig((2,) * 3, 2) + + assert sig_key(s1, lex) > sig_key(s2, lex) + assert sig_key(s2, lex) < sig_key(s3, lex) + + +def test_lbp_key(): + R, x,y,z,t = ring("x,y,z,t", ZZ, lex) + + p1 = lbp(sig((0,) * 4, 3), R.zero, 12) + p2 = lbp(sig((0,) * 4, 4), R.zero, 13) + p3 = lbp(sig((0,) * 4, 4), R.zero, 12) + + assert lbp_key(p1) > lbp_key(p2) + assert lbp_key(p2) < lbp_key(p3) + + +def test_critical_pair(): + # from cyclic4 with grlex + R, x,y,z,t = ring("x,y,z,t", QQ, grlex) + + p1 = (((0, 0, 0, 0), 4), y*z*t**2 + z**2*t**2 - t**4 - 1, 4) + q1 = (((0, 0, 0, 0), 2), -y**2 - y*t - z*t - t**2, 2) + + p2 = (((0, 0, 0, 2), 3), z**3*t**2 + z**2*t**3 - z - t, 5) + q2 = (((0, 0, 2, 2), 2), y*z + z*t**5 + z*t + t**6, 13) + + assert critical_pair(p1, q1, R) == ( + ((0, 0, 1, 2), 2), ((0, 0, 1, 2), QQ(-1, 1)), (((0, 0, 0, 0), 2), -y**2 - y*t - z*t - t**2, 2), + ((0, 1, 0, 0), 4), ((0, 1, 0, 0), QQ(1, 1)), (((0, 0, 0, 0), 4), y*z*t**2 + z**2*t**2 - t**4 - 1, 4) + ) + assert critical_pair(p2, q2, R) == ( + ((0, 0, 4, 2), 2), ((0, 0, 2, 0), QQ(1, 1)), (((0, 0, 2, 2), 2), y*z + z*t**5 + z*t + t**6, 13), + ((0, 0, 0, 5), 3), ((0, 0, 0, 3), QQ(1, 1)), (((0, 0, 0, 2), 3), z**3*t**2 + z**2*t**3 - z - t, 5) + ) + +def test_cp_key(): + # from cyclic4 with grlex + R, x,y,z,t = ring("x,y,z,t", QQ, grlex) + + p1 = (((0, 0, 0, 0), 4), y*z*t**2 + z**2*t**2 - t**4 - 1, 4) + q1 = (((0, 0, 0, 0), 2), -y**2 - y*t - z*t - t**2, 2) + + p2 = (((0, 0, 0, 2), 3), z**3*t**2 + z**2*t**3 - z - t, 5) + q2 = (((0, 0, 2, 2), 2), y*z + z*t**5 + z*t + t**6, 13) + + cp1 = critical_pair(p1, q1, R) + cp2 = critical_pair(p2, q2, R) + + assert cp_key(cp1, R) < cp_key(cp2, R) + + cp1 = critical_pair(p1, p2, R) + cp2 = critical_pair(q1, q2, R) + + assert cp_key(cp1, R) < cp_key(cp2, R) + + +def test_is_rewritable_or_comparable(): + # from katsura4 with grlex + R, x,y,z,t = ring("x,y,z,t", QQ, grlex) + + p = lbp(sig((0, 0, 2, 1), 2), R.zero, 2) + B = [lbp(sig((0, 0, 0, 1), 2), QQ(2,45)*y**2 + QQ(1,5)*y*z + QQ(5,63)*y*t + z**2*t + QQ(4,45)*z**2 + QQ(76,35)*z*t**2 - QQ(32,105)*z*t + QQ(13,7)*t**3 - QQ(13,21)*t**2, 6)] + + # rewritable: + assert is_rewritable_or_comparable(Sign(p), Num(p), B) is True + + p = lbp(sig((0, 1, 1, 0), 2), R.zero, 7) + B = [lbp(sig((0, 0, 0, 0), 3), QQ(10,3)*y*z + QQ(4,3)*y*t - QQ(1,3)*y + 4*z**2 + QQ(22,3)*z*t - QQ(4,3)*z + 4*t**2 - QQ(4,3)*t, 3)] + + # comparable: + assert is_rewritable_or_comparable(Sign(p), Num(p), B) is True + + +def test_f5_reduce(): + # katsura3 with lex + R, x,y,z = ring("x,y,z", QQ, lex) + + F = [(((0, 0, 0), 1), x + 2*y + 2*z - 1, 1), + (((0, 0, 0), 2), 6*y**2 + 8*y*z - 2*y + 6*z**2 - 2*z, 2), + (((0, 0, 0), 3), QQ(10,3)*y*z - QQ(1,3)*y + 4*z**2 - QQ(4,3)*z, 3), + (((0, 0, 1), 2), y + 30*z**3 - QQ(79,7)*z**2 + QQ(3,7)*z, 4), + (((0, 0, 2), 2), z**4 - QQ(10,21)*z**3 + QQ(1,84)*z**2 + QQ(1,84)*z, 5)] + + cp = critical_pair(F[0], F[1], R) + s = s_poly(cp) + + assert f5_reduce(s, F) == (((0, 2, 0), 1), R.zero, 1) + + s = lbp(sig(Sign(s)[0], 100), Polyn(s), Num(s)) + assert f5_reduce(s, F) == s + + +def test_representing_matrices(): + R, x,y = ring("x,y", QQ, grlex) + + basis = [(0, 0), (0, 1), (1, 0), (1, 1)] + F = [x**2 - x - 3*y + 1, -2*x + y**2 + y - 1] + + assert _representing_matrices(basis, F, R) == [ + [[QQ(0, 1), QQ(0, 1),-QQ(1, 1), QQ(3, 1)], + [QQ(0, 1), QQ(0, 1), QQ(3, 1),-QQ(4, 1)], + [QQ(1, 1), QQ(0, 1), QQ(1, 1), QQ(6, 1)], + [QQ(0, 1), QQ(1, 1), QQ(0, 1), QQ(1, 1)]], + [[QQ(0, 1), QQ(1, 1), QQ(0, 1),-QQ(2, 1)], + [QQ(1, 1),-QQ(1, 1), QQ(0, 1), QQ(6, 1)], + [QQ(0, 1), QQ(2, 1), QQ(0, 1), QQ(3, 1)], + [QQ(0, 1), QQ(0, 1), QQ(1, 1),-QQ(1, 1)]]] + +def test_groebner_lcm(): + R, x,y,z = ring("x,y,z", ZZ) + + assert groebner_lcm(x**2 - y**2, x - y) == x**2 - y**2 + assert groebner_lcm(2*x**2 - 2*y**2, 2*x - 2*y) == 2*x**2 - 2*y**2 + + R, x,y,z = ring("x,y,z", QQ) + + assert groebner_lcm(x**2 - y**2, x - y) == x**2 - y**2 + assert groebner_lcm(2*x**2 - 2*y**2, 2*x - 2*y) == 2*x**2 - 2*y**2 + + R, x,y = ring("x,y", ZZ) + + assert groebner_lcm(x**2*y, x*y**2) == x**2*y**2 + + f = 2*x*y**5 - 3*x*y**4 - 2*x*y**3 + 3*x*y**2 + g = y**5 - 2*y**3 + y + h = 2*x*y**7 - 3*x*y**6 - 4*x*y**5 + 6*x*y**4 + 2*x*y**3 - 3*x*y**2 + + assert groebner_lcm(f, g) == h + + f = x**3 - 3*x**2*y - 9*x*y**2 - 5*y**3 + g = x**4 + 6*x**3*y + 12*x**2*y**2 + 10*x*y**3 + 3*y**4 + h = x**5 + x**4*y - 18*x**3*y**2 - 50*x**2*y**3 - 47*x*y**4 - 15*y**5 + + assert groebner_lcm(f, g) == h + +def test_groebner_gcd(): + R, x,y,z = ring("x,y,z", ZZ) + + assert groebner_gcd(x**2 - y**2, x - y) == x - y + assert groebner_gcd(2*x**2 - 2*y**2, 2*x - 2*y) == 2*x - 2*y + + R, x,y,z = ring("x,y,z", QQ) + + assert groebner_gcd(x**2 - y**2, x - y) == x - y + assert groebner_gcd(2*x**2 - 2*y**2, 2*x - 2*y) == x - y + +def test_is_groebner(): + R, x,y = ring("x,y", QQ, grlex) + valid_groebner = [x**2, x*y, -QQ(1,2)*x + y**2] + invalid_groebner = [x**3, x*y, -QQ(1,2)*x + y**2] + assert is_groebner(valid_groebner, R) is True + assert is_groebner(invalid_groebner, R) is False + +def test_is_reduced(): + R, x, y = ring("x,y", QQ, lex) + f = x**2 + 2*x*y**2 + g = x*y + 2*y**3 - 1 + assert is_reduced([f, g], R) == False + G = groebner([f, g], R) + assert is_reduced(G, R) == True diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_heuristicgcd.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_heuristicgcd.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff6bd6ea4effbd49c5e942ea8925cfcca4ba162 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_heuristicgcd.py @@ -0,0 +1,152 @@ +from sympy.polys.rings import ring +from sympy.polys.domains import ZZ +from sympy.polys.heuristicgcd import heugcd + + +def test_heugcd_univariate_integers(): + R, x = ring("x", ZZ) + + f = x**4 + 8*x**3 + 21*x**2 + 22*x + 8 + g = x**3 + 6*x**2 + 11*x + 6 + + h = x**2 + 3*x + 2 + + cff = x**2 + 5*x + 4 + cfg = x + 3 + + assert heugcd(f, g) == (h, cff, cfg) + + f = x**4 - 4 + g = x**4 + 4*x**2 + 4 + + h = x**2 + 2 + + cff = x**2 - 2 + cfg = x**2 + 2 + + assert heugcd(f, g) == (h, cff, cfg) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + h = 1 + + cff = f + cfg = g + + assert heugcd(f, g) == (h, cff, cfg) + + f = - 352518131239247345597970242177235495263669787845475025293906825864749649589178600387510272*x**49 \ + + 46818041807522713962450042363465092040687472354933295397472942006618953623327997952*x**42 \ + + 378182690892293941192071663536490788434899030680411695933646320291525827756032*x**35 \ + + 112806468807371824947796775491032386836656074179286744191026149539708928*x**28 \ + - 12278371209708240950316872681744825481125965781519138077173235712*x**21 \ + + 289127344604779611146960547954288113529690984687482920704*x**14 \ + + 19007977035740498977629742919480623972236450681*x**7 \ + + 311973482284542371301330321821976049 + + g = 365431878023781158602430064717380211405897160759702125019136*x**21 \ + + 197599133478719444145775798221171663643171734081650688*x**14 \ + - 9504116979659010018253915765478924103928886144*x**7 \ + - 311973482284542371301330321821976049 + + # TODO: assert heugcd(f, f.diff(x))[0] == g + + f = 1317378933230047068160*x + 2945748836994210856960 + g = 120352542776360960*x + 269116466014453760 + + h = 120352542776360960*x + 269116466014453760 + cff = 10946 + cfg = 1 + + assert heugcd(f, g) == (h, cff, cfg) + +def test_heugcd_multivariate_integers(): + R, x, y = ring("x,y", ZZ) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert heugcd(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert heugcd(f, g) == (x + 1, 1, 2*x + 2) + + R, x, y, z, u = ring("x,y,z,u", ZZ) + + f, g = u**2 + 2*u + 1, 2*u + 2 + assert heugcd(f, g) == (u + 1, u + 1, 2) + + f, g = z**2*u**2 + 2*z**2*u + z**2 + z*u + z, u**2 + 2*u + 1 + h, cff, cfg = u + 1, z**2*u + z**2 + z, u + 1 + + assert heugcd(f, g) == (h, cff, cfg) + assert heugcd(g, f) == (h, cfg, cff) + + R, x, y, z = ring("x,y,z", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v, a, b = ring("x,y,z,u,v,a,b", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v, a, b, c, d = ring("x,y,z,u,v,a,b,c,d", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z = ring("x,y,z", ZZ) + + f, g, h = R.fateman_poly_F_2() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + f, g, h = R.fateman_poly_F_3() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, t = ring("x,y,z,t", ZZ) + + f, g, h = R.fateman_poly_F_3() + H, cff, cfg = heugcd(f, g) + + assert H == h and H*cff == f and H*cfg == g + + +def test_issue_10996(): + R, x, y, z = ring("x,y,z", ZZ) + + f = 12*x**6*y**7*z**3 - 3*x**4*y**9*z**3 + 12*x**3*y**5*z**4 + g = -48*x**7*y**8*z**3 + 12*x**5*y**10*z**3 - 48*x**5*y**7*z**2 + \ + 36*x**4*y**7*z - 48*x**4*y**6*z**4 + 12*x**3*y**9*z**2 - 48*x**3*y**4 \ + - 9*x**2*y**9*z - 48*x**2*y**5*z**3 + 12*x*y**6 + 36*x*y**5*z**2 - 48*y**2*z + + H, cff, cfg = heugcd(f, g) + + assert H == 12*x**3*y**4 - 3*x*y**6 + 12*y**2*z + assert H*cff == f and H*cfg == g + + +def test_issue_25793(): + R, x = ring("x", ZZ) + f = x - 4851 # failure starts for values more than 4850 + g = f*(2*x + 1) + H, cff, cfg = R.dup_zz_heu_gcd(f, g) + assert H == f + # needs a test for dmp, too, that fails in master before this change diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_hypothesis.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_hypothesis.py new file mode 100644 index 0000000000000000000000000000000000000000..78c2369179c3f0ea4d34b8a7868417506177e3c5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_hypothesis.py @@ -0,0 +1,36 @@ +from hypothesis import given +from hypothesis import strategies as st +from sympy.abc import x +from sympy.polys.polytools import Poly + + +def polys(*, nonzero=False, domain="ZZ"): + # This is a simple strategy, but sufficient the tests below + elems = {"ZZ": st.integers(), "QQ": st.fractions()} + coeff_st = st.lists(elems[domain]) + if nonzero: + coeff_st = coeff_st.filter(any) + return st.builds(Poly, coeff_st, st.just(x), domain=st.just(domain)) + + +@given(f=polys(), g=polys(), r=polys()) +def test_gcd_hypothesis(f, g, r): + gcd_1 = f.gcd(g) + gcd_2 = g.gcd(f) + assert gcd_1 == gcd_2 + + # multiply by r + gcd_3 = g.gcd(f + r * g) + assert gcd_1 == gcd_3 + + +@given(f_z=polys(), g_z=polys(nonzero=True)) +def test_poly_hypothesis_integers(f_z, g_z): + remainder_z = f_z.rem(g_z) + assert g_z.degree() >= remainder_z.degree() or remainder_z.degree() == 0 + + +@given(f_q=polys(domain="QQ"), g_q=polys(nonzero=True, domain="QQ")) +def test_poly_hypothesis_rationals(f_q, g_q): + remainder_q = f_q.rem(g_q) + assert g_q.degree() >= remainder_q.degree() or remainder_q.degree() == 0 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_injections.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_injections.py new file mode 100644 index 0000000000000000000000000000000000000000..63a5537c94f00e52a3899c97f0d78bfadab78a67 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_injections.py @@ -0,0 +1,39 @@ +"""Tests for functions that inject symbols into the global namespace. """ + +from sympy.polys.rings import vring +from sympy.polys.fields import vfield +from sympy.polys.domains import QQ + +def test_vring(): + ns = {'vring':vring, 'QQ':QQ} + exec('R = vring("r", QQ)', ns) + exec('assert r == R.gens[0]', ns) + + exec('R = vring("rb rbb rcc rzz _rx", QQ)', ns) + exec('assert rb == R.gens[0]', ns) + exec('assert rbb == R.gens[1]', ns) + exec('assert rcc == R.gens[2]', ns) + exec('assert rzz == R.gens[3]', ns) + exec('assert _rx == R.gens[4]', ns) + + exec('R = vring(["rd", "re", "rfg"], QQ)', ns) + exec('assert rd == R.gens[0]', ns) + exec('assert re == R.gens[1]', ns) + exec('assert rfg == R.gens[2]', ns) + +def test_vfield(): + ns = {'vfield':vfield, 'QQ':QQ} + exec('F = vfield("f", QQ)', ns) + exec('assert f == F.gens[0]', ns) + + exec('F = vfield("fb fbb fcc fzz _fx", QQ)', ns) + exec('assert fb == F.gens[0]', ns) + exec('assert fbb == F.gens[1]', ns) + exec('assert fcc == F.gens[2]', ns) + exec('assert fzz == F.gens[3]', ns) + exec('assert _fx == F.gens[4]', ns) + + exec('F = vfield(["fd", "fe", "ffg"], QQ)', ns) + exec('assert fd == F.gens[0]', ns) + exec('assert fe == F.gens[1]', ns) + exec('assert ffg == F.gens[2]', ns) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_modulargcd.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_modulargcd.py new file mode 100644 index 0000000000000000000000000000000000000000..20510f59186524ed4008ade943fab526a9ae7194 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_modulargcd.py @@ -0,0 +1,325 @@ +from sympy.polys.rings import ring +from sympy.polys.domains import ZZ, QQ, AlgebraicField +from sympy.polys.modulargcd import ( + modgcd_univariate, + modgcd_bivariate, + _chinese_remainder_reconstruction_multivariate, + modgcd_multivariate, + _to_ZZ_poly, + _to_ANP_poly, + func_field_modgcd, + _func_field_modgcd_m) +from sympy.functions.elementary.miscellaneous import sqrt + + +def test_modgcd_univariate_integers(): + R, x = ring("x", ZZ) + + f, g = R.zero, R.zero + assert modgcd_univariate(f, g) == (0, 0, 0) + + f, g = R.zero, x + assert modgcd_univariate(f, g) == (x, 0, 1) + assert modgcd_univariate(g, f) == (x, 1, 0) + + f, g = R.zero, -x + assert modgcd_univariate(f, g) == (x, 0, -1) + assert modgcd_univariate(g, f) == (x, -1, 0) + + f, g = 2*x, R(2) + assert modgcd_univariate(f, g) == (2, x, 1) + + f, g = 2*x + 2, 6*x**2 - 6 + assert modgcd_univariate(f, g) == (2*x + 2, 1, 3*x - 3) + + f = x**4 + 8*x**3 + 21*x**2 + 22*x + 8 + g = x**3 + 6*x**2 + 11*x + 6 + + h = x**2 + 3*x + 2 + + cff = x**2 + 5*x + 4 + cfg = x + 3 + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + f = x**4 - 4 + g = x**4 + 4*x**2 + 4 + + h = x**2 + 2 + + cff = x**2 - 2 + cfg = x**2 + 2 + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + h = 1 + + cff = f + cfg = g + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + f = - 352518131239247345597970242177235495263669787845475025293906825864749649589178600387510272*x**49 \ + + 46818041807522713962450042363465092040687472354933295397472942006618953623327997952*x**42 \ + + 378182690892293941192071663536490788434899030680411695933646320291525827756032*x**35 \ + + 112806468807371824947796775491032386836656074179286744191026149539708928*x**28 \ + - 12278371209708240950316872681744825481125965781519138077173235712*x**21 \ + + 289127344604779611146960547954288113529690984687482920704*x**14 \ + + 19007977035740498977629742919480623972236450681*x**7 \ + + 311973482284542371301330321821976049 + + g = 365431878023781158602430064717380211405897160759702125019136*x**21 \ + + 197599133478719444145775798221171663643171734081650688*x**14 \ + - 9504116979659010018253915765478924103928886144*x**7 \ + - 311973482284542371301330321821976049 + + assert modgcd_univariate(f, f.diff(x))[0] == g + + f = 1317378933230047068160*x + 2945748836994210856960 + g = 120352542776360960*x + 269116466014453760 + + h = 120352542776360960*x + 269116466014453760 + cff = 10946 + cfg = 1 + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + +def test_modgcd_bivariate_integers(): + R, x, y = ring("x,y", ZZ) + + f, g = R.zero, R.zero + assert modgcd_bivariate(f, g) == (0, 0, 0) + + f, g = 2*x, R(2) + assert modgcd_bivariate(f, g) == (2, x, 1) + + f, g = x + 2*y, x + y + assert modgcd_bivariate(f, g) == (1, f, g) + + f, g = x**2 + 2*x*y + y**2, x**3 + y**3 + assert modgcd_bivariate(f, g) == (x + y, x + y, x**2 - x*y + y**2) + + f, g = x*y**2 + 2*x*y + x, x*y**3 + x + assert modgcd_bivariate(f, g) == (x*y + x, y + 1, y**2 - y + 1) + + f, g = x**2*y**2 + x**2*y + 1, x*y**2 + x*y + 1 + assert modgcd_bivariate(f, g) == (1, f, g) + + f = 2*x*y**2 + 4*x*y + 2*x + y**2 + 2*y + 1 + g = 2*x*y**3 + 2*x + y**3 + 1 + assert modgcd_bivariate(f, g) == (2*x*y + 2*x + y + 1, y + 1, y**2 - y + 1) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert modgcd_bivariate(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert modgcd_bivariate(f, g) == (x + 1, 1, 2*x + 2) + + f = 2*x**2 + 4*x*y - 2*x - 4*y + g = x**2 + x - 2 + assert modgcd_bivariate(f, g) == (x - 1, 2*x + 4*y, x + 2) + + f = 2*x**2 + 2*x*y - 3*x - 3*y + g = 4*x*y - 2*x + 4*y**2 - 2*y + assert modgcd_bivariate(f, g) == (x + y, 2*x - 3, 4*y - 2) + + +def test_chinese_remainder(): + R, x, y = ring("x, y", ZZ) + p, q = 3, 5 + + hp = x**3*y - x**2 - 1 + hq = -x**3*y - 2*x*y**2 + 2 + + hpq = _chinese_remainder_reconstruction_multivariate(hp, hq, p, q) + + assert hpq.trunc_ground(p) == hp + assert hpq.trunc_ground(q) == hq + + T, z = ring("z", R) + p, q = 3, 7 + + hp = (x*y + 1)*z**2 + x + hq = (x**2 - 3*y)*z + 2 + + hpq = _chinese_remainder_reconstruction_multivariate(hp, hq, p, q) + + assert hpq.trunc_ground(p) == hp + assert hpq.trunc_ground(q) == hq + + +def test_modgcd_multivariate_integers(): + R, x, y = ring("x,y", ZZ) + + f, g = R.zero, R.zero + assert modgcd_multivariate(f, g) == (0, 0, 0) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert modgcd_multivariate(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert modgcd_multivariate(f, g) == (x + 1, 1, 2*x + 2) + + f = 2*x**2 + 2*x*y - 3*x - 3*y + g = 4*x*y - 2*x + 4*y**2 - 2*y + assert modgcd_multivariate(f, g) == (x + y, 2*x - 3, 4*y - 2) + + f, g = x*y**2 + 2*x*y + x, x*y**3 + x + assert modgcd_multivariate(f, g) == (x*y + x, y + 1, y**2 - y + 1) + + f, g = x**2*y**2 + x**2*y + 1, x*y**2 + x*y + 1 + assert modgcd_multivariate(f, g) == (1, f, g) + + f = x**4 + 8*x**3 + 21*x**2 + 22*x + 8 + g = x**3 + 6*x**2 + 11*x + 6 + + h = x**2 + 3*x + 2 + + cff = x**2 + 5*x + 4 + cfg = x + 3 + + assert modgcd_multivariate(f, g) == (h, cff, cfg) + + R, x, y, z, u = ring("x,y,z,u", ZZ) + + f, g = x + y + z, -x - y - z - u + assert modgcd_multivariate(f, g) == (1, f, g) + + f, g = u**2 + 2*u + 1, 2*u + 2 + assert modgcd_multivariate(f, g) == (u + 1, u + 1, 2) + + f, g = z**2*u**2 + 2*z**2*u + z**2 + z*u + z, u**2 + 2*u + 1 + h, cff, cfg = u + 1, z**2*u + z**2 + z, u + 1 + + assert modgcd_multivariate(f, g) == (h, cff, cfg) + assert modgcd_multivariate(g, f) == (h, cfg, cff) + + R, x, y, z = ring("x,y,z", ZZ) + + f, g = x - y*z, x - y*z + assert modgcd_multivariate(f, g) == (x - y*z, 1, 1) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v, a, b = ring("x,y,z,u,v,a,b", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v, a, b, c, d = ring("x,y,z,u,v,a,b,c,d", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z = ring("x,y,z", ZZ) + + f, g, h = R.fateman_poly_F_2() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + f, g, h = R.fateman_poly_F_3() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, t = ring("x,y,z,t", ZZ) + + f, g, h = R.fateman_poly_F_3() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + +def test_to_ZZ_ANP_poly(): + A = AlgebraicField(QQ, sqrt(2)) + R, x = ring("x", A) + f = x*(sqrt(2) + 1) + + T, x_, z_ = ring("x_, z_", ZZ) + f_ = x_*z_ + x_ + + assert _to_ZZ_poly(f, T) == f_ + assert _to_ANP_poly(f_, R) == f + + R, x, t, s = ring("x, t, s", A) + f = x*t**2 + x*s + sqrt(2) + + D, t_, s_ = ring("t_, s_", ZZ) + T, x_, z_ = ring("x_, z_", D) + f_ = (t_**2 + s_)*x_ + z_ + + assert _to_ZZ_poly(f, T) == f_ + assert _to_ANP_poly(f_, R) == f + + +def test_modgcd_algebraic_field(): + A = AlgebraicField(QQ, sqrt(2)) + R, x = ring("x", A) + one = A.one + + f, g = 2*x, R(2) + assert func_field_modgcd(f, g) == (one, f, g) + + f, g = 2*x, R(sqrt(2)) + assert func_field_modgcd(f, g) == (one, f, g) + + f, g = 2*x + 2, 6*x**2 - 6 + assert func_field_modgcd(f, g) == (x + 1, R(2), 6*x - 6) + + R, x, y = ring("x, y", A) + + f, g = x + sqrt(2)*y, x + y + assert func_field_modgcd(f, g) == (one, f, g) + + f, g = x*y + sqrt(2)*y**2, R(sqrt(2))*y + assert func_field_modgcd(f, g) == (y, x + sqrt(2)*y, R(sqrt(2))) + + f, g = x**2 + 2*sqrt(2)*x*y + 2*y**2, x + sqrt(2)*y + assert func_field_modgcd(f, g) == (g, g, one) + + A = AlgebraicField(QQ, sqrt(2), sqrt(3)) + R, x, y, z = ring("x, y, z", A) + + h = x**2*y**7 + sqrt(6)/21*z + f, g = h*(27*y**3 + 1), h*(y + x) + assert func_field_modgcd(f, g) == (h, 27*y**3+1, y+x) + + h = x**13*y**3 + 1/2*x**10 + 1/sqrt(2) + f, g = h*(x + 1), h*sqrt(2)/sqrt(3) + assert func_field_modgcd(f, g) == (h, x + 1, R(sqrt(2)/sqrt(3))) + + A = AlgebraicField(QQ, sqrt(2)**(-1)*sqrt(3)) + R, x = ring("x", A) + + f, g = x + 1, x - 1 + assert func_field_modgcd(f, g) == (A.one, f, g) + + +# when func_field_modgcd supports function fields, this test can be changed +def test_modgcd_func_field(): + D, t = ring("t", ZZ) + R, x, z = ring("x, z", D) + + minpoly = (z**2*t**2 + z**2*t - 1).drop(0) + f, g = x + 1, x - 1 + + assert _func_field_modgcd_m(f, g, minpoly) == R.one diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_monomials.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_monomials.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ed28ba0e8e3f8e9f85c543a4fffcaef855fff8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_monomials.py @@ -0,0 +1,269 @@ +"""Tests for tools and arithmetics for monomials of distributed polynomials. """ + +from sympy.polys.monomials import ( + itermonomials, monomial_count, + monomial_mul, monomial_div, + monomial_gcd, monomial_lcm, + monomial_max, monomial_min, + monomial_divides, monomial_pow, + Monomial, +) + +from sympy.polys.polyerrors import ExactQuotientFailed + +from sympy.abc import a, b, c, x, y, z +from sympy.core import S, symbols +from sympy.testing.pytest import raises + +def test_monomials(): + + # total_degree tests + assert set(itermonomials([], 0)) == {S.One} + assert set(itermonomials([], 1)) == {S.One} + assert set(itermonomials([], 2)) == {S.One} + + assert set(itermonomials([], 0, 0)) == {S.One} + assert set(itermonomials([], 1, 0)) == {S.One} + assert set(itermonomials([], 2, 0)) == {S.One} + + raises(StopIteration, lambda: next(itermonomials([], 0, 1))) + raises(StopIteration, lambda: next(itermonomials([], 0, 2))) + raises(StopIteration, lambda: next(itermonomials([], 0, 3))) + + assert set(itermonomials([], 0, 1)) == set() + assert set(itermonomials([], 0, 2)) == set() + assert set(itermonomials([], 0, 3)) == set() + + raises(ValueError, lambda: set(itermonomials([], -1))) + raises(ValueError, lambda: set(itermonomials([x], -1))) + raises(ValueError, lambda: set(itermonomials([x, y], -1))) + + assert set(itermonomials([x], 0)) == {S.One} + assert set(itermonomials([x], 1)) == {S.One, x} + assert set(itermonomials([x], 2)) == {S.One, x, x**2} + assert set(itermonomials([x], 3)) == {S.One, x, x**2, x**3} + + assert set(itermonomials([x, y], 0)) == {S.One} + assert set(itermonomials([x, y], 1)) == {S.One, x, y} + assert set(itermonomials([x, y], 2)) == {S.One, x, y, x**2, y**2, x*y} + assert set(itermonomials([x, y], 3)) == \ + {S.One, x, y, x**2, x**3, y**2, y**3, x*y, x*y**2, y*x**2} + + i, j, k = symbols('i j k', commutative=False) + assert set(itermonomials([i, j, k], 0)) == {S.One} + assert set(itermonomials([i, j, k], 1)) == {S.One, i, j, k} + assert set(itermonomials([i, j, k], 2)) == \ + {S.One, i, j, k, i**2, j**2, k**2, i*j, i*k, j*i, j*k, k*i, k*j} + + assert set(itermonomials([i, j, k], 3)) == \ + {S.One, i, j, k, i**2, j**2, k**2, i*j, i*k, j*i, j*k, k*i, k*j, + i**3, j**3, k**3, + i**2 * j, i**2 * k, j * i**2, k * i**2, + j**2 * i, j**2 * k, i * j**2, k * j**2, + k**2 * i, k**2 * j, i * k**2, j * k**2, + i*j*i, i*k*i, j*i*j, j*k*j, k*i*k, k*j*k, + i*j*k, i*k*j, j*i*k, j*k*i, k*i*j, k*j*i, + } + + assert set(itermonomials([x, i, j], 0)) == {S.One} + assert set(itermonomials([x, i, j], 1)) == {S.One, x, i, j} + assert set(itermonomials([x, i, j], 2)) == {S.One, x, i, j, x*i, x*j, i*j, j*i, x**2, i**2, j**2} + assert set(itermonomials([x, i, j], 3)) == \ + {S.One, x, i, j, x*i, x*j, i*j, j*i, x**2, i**2, j**2, + x**3, i**3, j**3, + x**2 * i, x**2 * j, + x * i**2, j * i**2, i**2 * j, i*j*i, + x * j**2, i * j**2, j**2 * i, j*i*j, + x * i * j, x * j * i + } + + # degree_list tests + assert set(itermonomials([], [])) == {S.One} + + raises(ValueError, lambda: set(itermonomials([], [0]))) + raises(ValueError, lambda: set(itermonomials([], [1]))) + raises(ValueError, lambda: set(itermonomials([], [2]))) + + raises(ValueError, lambda: set(itermonomials([x], [1], []))) + raises(ValueError, lambda: set(itermonomials([x], [1, 2], []))) + raises(ValueError, lambda: set(itermonomials([x], [1, 2, 3], []))) + + raises(ValueError, lambda: set(itermonomials([x], [], [1]))) + raises(ValueError, lambda: set(itermonomials([x], [], [1, 2]))) + raises(ValueError, lambda: set(itermonomials([x], [], [1, 2, 3]))) + + raises(ValueError, lambda: set(itermonomials([x, y], [1, 2], [1, 2, 3]))) + raises(ValueError, lambda: set(itermonomials([x, y, z], [1, 2, 3], [0, 1]))) + + raises(ValueError, lambda: set(itermonomials([x], [1], [-1]))) + raises(ValueError, lambda: set(itermonomials([x, y], [1, 2], [1, -1]))) + + raises(ValueError, lambda: set(itermonomials([], [], 1))) + raises(ValueError, lambda: set(itermonomials([], [], 2))) + raises(ValueError, lambda: set(itermonomials([], [], 3))) + + raises(ValueError, lambda: set(itermonomials([x, y], [0, 1], [1, 2]))) + raises(ValueError, lambda: set(itermonomials([x, y, z], [0, 0, 3], [0, 1, 2]))) + + assert set(itermonomials([x], [0])) == {S.One} + assert set(itermonomials([x], [1])) == {S.One, x} + assert set(itermonomials([x], [2])) == {S.One, x, x**2} + assert set(itermonomials([x], [3])) == {S.One, x, x**2, x**3} + + assert set(itermonomials([x], [3], [1])) == {x, x**3, x**2} + assert set(itermonomials([x], [3], [2])) == {x**3, x**2} + + assert set(itermonomials([x, y], 3, 3)) == {x**3, x**2*y, x*y**2, y**3} + assert set(itermonomials([x, y], 3, 2)) == {x**2, x*y, y**2, x**3, x**2*y, x*y**2, y**3} + + assert set(itermonomials([x, y], [0, 0])) == {S.One} + assert set(itermonomials([x, y], [0, 1])) == {S.One, y} + assert set(itermonomials([x, y], [0, 2])) == {S.One, y, y**2} + assert set(itermonomials([x, y], [0, 2], [0, 1])) == {y, y**2} + assert set(itermonomials([x, y], [0, 2], [0, 2])) == {y**2} + + assert set(itermonomials([x, y], [1, 0])) == {S.One, x} + assert set(itermonomials([x, y], [1, 1])) == {S.One, x, y, x*y} + assert set(itermonomials([x, y], [1, 2])) == {S.One, x, y, x*y, y**2, x*y**2} + assert set(itermonomials([x, y], [1, 2], [1, 1])) == {x*y, x*y**2} + assert set(itermonomials([x, y], [1, 2], [1, 2])) == {x*y**2} + + assert set(itermonomials([x, y], [2, 0])) == {S.One, x, x**2} + assert set(itermonomials([x, y], [2, 1])) == {S.One, x, y, x*y, x**2, x**2*y} + assert set(itermonomials([x, y], [2, 2])) == \ + {S.One, y**2, x*y**2, x, x*y, x**2, x**2*y**2, y, x**2*y} + + i, j, k = symbols('i j k', commutative=False) + assert set(itermonomials([i, j, k], 2, 2)) == \ + {k*i, i**2, i*j, j*k, j*i, k**2, j**2, k*j, i*k} + assert set(itermonomials([i, j, k], 3, 2)) == \ + {j*k**2, i*k**2, k*i*j, k*i**2, k**2, j*k*j, k*j**2, i*k*i, i*j, + j**2*k, i**2*j, j*i*k, j**3, i**3, k*j*i, j*k*i, j*i, + k**2*j, j*i**2, k*j, k*j*k, i*j*i, j*i*j, i*j**2, j**2, + k*i*k, i**2, j*k, i*k, i*k*j, k**3, i**2*k, j**2*i, k**2*i, + i*j*k, k*i + } + assert set(itermonomials([i, j, k], [0, 0, 0])) == {S.One} + assert set(itermonomials([i, j, k], [0, 0, 1])) == {1, k} + assert set(itermonomials([i, j, k], [0, 1, 0])) == {1, j} + assert set(itermonomials([i, j, k], [1, 0, 0])) == {i, 1} + assert set(itermonomials([i, j, k], [0, 0, 2])) == {k**2, 1, k} + assert set(itermonomials([i, j, k], [0, 2, 0])) == {1, j, j**2} + assert set(itermonomials([i, j, k], [2, 0, 0])) == {i, 1, i**2} + assert set(itermonomials([i, j, k], [1, 1, 1])) == {1, k, j, j*k, i*k, i, i*j, i*j*k} + assert set(itermonomials([i, j, k], [2, 2, 2])) == \ + {1, k, i**2*k**2, j*k, j**2, i, i*k, j*k**2, i*j**2*k**2, + i**2*j, i**2*j**2, k**2, j**2*k, i*j**2*k, + j**2*k**2, i*j, i**2*k, i**2*j**2*k, j, i**2*j*k, + i*j**2, i*k**2, i*j*k, i**2*j**2*k**2, i*j*k**2, i**2, i**2*j*k**2 + } + + assert set(itermonomials([x, j, k], [0, 0, 0])) == {S.One} + assert set(itermonomials([x, j, k], [0, 0, 1])) == {1, k} + assert set(itermonomials([x, j, k], [0, 1, 0])) == {1, j} + assert set(itermonomials([x, j, k], [1, 0, 0])) == {x, 1} + assert set(itermonomials([x, j, k], [0, 0, 2])) == {k**2, 1, k} + assert set(itermonomials([x, j, k], [0, 2, 0])) == {1, j, j**2} + assert set(itermonomials([x, j, k], [2, 0, 0])) == {x, 1, x**2} + assert set(itermonomials([x, j, k], [1, 1, 1])) == {1, k, j, j*k, x*k, x, x*j, x*j*k} + assert set(itermonomials([x, j, k], [2, 2, 2])) == \ + {1, k, x**2*k**2, j*k, j**2, x, x*k, j*k**2, x*j**2*k**2, + x**2*j, x**2*j**2, k**2, j**2*k, x*j**2*k, + j**2*k**2, x*j, x**2*k, x**2*j**2*k, j, x**2*j*k, + x*j**2, x*k**2, x*j*k, x**2*j**2*k**2, x*j*k**2, x**2, x**2*j*k**2 + } + +def test_monomial_count(): + assert monomial_count(2, 2) == 6 + assert monomial_count(2, 3) == 10 + +def test_monomial_mul(): + assert monomial_mul((3, 4, 1), (1, 2, 0)) == (4, 6, 1) + +def test_monomial_div(): + assert monomial_div((3, 4, 1), (1, 2, 0)) == (2, 2, 1) + +def test_monomial_gcd(): + assert monomial_gcd((3, 4, 1), (1, 2, 0)) == (1, 2, 0) + +def test_monomial_lcm(): + assert monomial_lcm((3, 4, 1), (1, 2, 0)) == (3, 4, 1) + +def test_monomial_max(): + assert monomial_max((3, 4, 5), (0, 5, 1), (6, 3, 9)) == (6, 5, 9) + +def test_monomial_pow(): + assert monomial_pow((1, 2, 3), 3) == (3, 6, 9) + +def test_monomial_min(): + assert monomial_min((3, 4, 5), (0, 5, 1), (6, 3, 9)) == (0, 3, 1) + +def test_monomial_divides(): + assert monomial_divides((1, 2, 3), (4, 5, 6)) is True + assert monomial_divides((1, 2, 3), (0, 5, 6)) is False + +def test_Monomial(): + m = Monomial((3, 4, 1), (x, y, z)) + n = Monomial((1, 2, 0), (x, y, z)) + + assert m.as_expr() == x**3*y**4*z + assert n.as_expr() == x**1*y**2 + + assert m.as_expr(a, b, c) == a**3*b**4*c + assert n.as_expr(a, b, c) == a**1*b**2 + + assert m.exponents == (3, 4, 1) + assert m.gens == (x, y, z) + + assert n.exponents == (1, 2, 0) + assert n.gens == (x, y, z) + + assert m == (3, 4, 1) + assert n != (3, 4, 1) + assert m != (1, 2, 0) + assert n == (1, 2, 0) + assert (m == 1) is False + + assert m[0] == m[-3] == 3 + assert m[1] == m[-2] == 4 + assert m[2] == m[-1] == 1 + + assert n[0] == n[-3] == 1 + assert n[1] == n[-2] == 2 + assert n[2] == n[-1] == 0 + + assert m[:2] == (3, 4) + assert n[:2] == (1, 2) + + assert m*n == Monomial((4, 6, 1)) + assert m/n == Monomial((2, 2, 1)) + + assert m*(1, 2, 0) == Monomial((4, 6, 1)) + assert m/(1, 2, 0) == Monomial((2, 2, 1)) + + assert m.gcd(n) == Monomial((1, 2, 0)) + assert m.lcm(n) == Monomial((3, 4, 1)) + + assert m.gcd((1, 2, 0)) == Monomial((1, 2, 0)) + assert m.lcm((1, 2, 0)) == Monomial((3, 4, 1)) + + assert m**0 == Monomial((0, 0, 0)) + assert m**1 == m + assert m**2 == Monomial((6, 8, 2)) + assert m**3 == Monomial((9, 12, 3)) + _a = Monomial((0, 0, 0)) + for n in range(10): + assert _a == m**n + _a *= m + + raises(ExactQuotientFailed, lambda: m/Monomial((5, 2, 0))) + + mm = Monomial((1, 2, 3)) + raises(ValueError, lambda: mm.as_expr()) + assert str(mm) == 'Monomial((1, 2, 3))' + assert str(m) == 'x**3*y**4*z**1' + raises(NotImplementedError, lambda: m*1) + raises(NotImplementedError, lambda: m/1) + raises(ValueError, lambda: m**-1) + raises(TypeError, lambda: m.gcd(3)) + raises(TypeError, lambda: m.lcm(3)) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_multivariate_resultants.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_multivariate_resultants.py new file mode 100644 index 0000000000000000000000000000000000000000..0799feb41fc875cf038723916a3efd62ff31b1b4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_multivariate_resultants.py @@ -0,0 +1,294 @@ +"""Tests for Dixon's and Macaulay's classes. """ + +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import factor +from sympy.core import symbols +from sympy.tensor.indexed import IndexedBase + +from sympy.polys.multivariate_resultants import (DixonResultant, + MacaulayResultant) + +c, d = symbols("a, b") +x, y = symbols("x, y") + +p = c * x + y +q = x + d * y + +dixon = DixonResultant(polynomials=[p, q], variables=[x, y]) +macaulay = MacaulayResultant(polynomials=[p, q], variables=[x, y]) + +def test_dixon_resultant_init(): + """Test init method of DixonResultant.""" + a = IndexedBase("alpha") + + assert dixon.polynomials == [p, q] + assert dixon.variables == [x, y] + assert dixon.n == 2 + assert dixon.m == 2 + assert dixon.dummy_variables == [a[0], a[1]] + +def test_get_dixon_polynomial_numerical(): + """Test Dixon's polynomial for a numerical example.""" + a = IndexedBase("alpha") + + p = x + y + q = x ** 2 + y **3 + h = x ** 2 + y + + dixon = DixonResultant([p, q, h], [x, y]) + polynomial = -x * y ** 2 * a[0] - x * y ** 2 * a[1] - x * y * a[0] \ + * a[1] - x * y * a[1] ** 2 - x * a[0] * a[1] ** 2 + x * a[0] - \ + y ** 2 * a[0] * a[1] + y ** 2 * a[1] - y * a[0] * a[1] ** 2 + y * \ + a[1] ** 2 + + assert dixon.get_dixon_polynomial().as_expr().expand() == polynomial + +def test_get_max_degrees(): + """Tests max degrees function.""" + + p = x + y + q = x ** 2 + y **3 + h = x ** 2 + y + + dixon = DixonResultant(polynomials=[p, q, h], variables=[x, y]) + dixon_polynomial = dixon.get_dixon_polynomial() + + assert dixon.get_max_degrees(dixon_polynomial) == [1, 2] + +def test_get_dixon_matrix(): + """Test Dixon's resultant for a numerical example.""" + + x, y = symbols('x, y') + + p = x + y + q = x ** 2 + y ** 3 + h = x ** 2 + y + + dixon = DixonResultant([p, q, h], [x, y]) + polynomial = dixon.get_dixon_polynomial() + + assert dixon.get_dixon_matrix(polynomial).det() == 0 + +def test_get_dixon_matrix_example_two(): + """Test Dixon's matrix for example from [Palancz08]_.""" + x, y, z = symbols('x, y, z') + + f = x ** 2 + y ** 2 - 1 + z * 0 + g = x ** 2 + z ** 2 - 1 + y * 0 + h = y ** 2 + z ** 2 - 1 + + example_two = DixonResultant([f, g, h], [y, z]) + poly = example_two.get_dixon_polynomial() + matrix = example_two.get_dixon_matrix(poly) + + expr = 1 - 8 * x ** 2 + 24 * x ** 4 - 32 * x ** 6 + 16 * x ** 8 + assert (matrix.det() - expr).expand() == 0 + +def test_KSY_precondition(): + """Tests precondition for KSY Resultant.""" + A, B, C = symbols('A, B, C') + + m1 = Matrix([[1, 2, 3], + [4, 5, 12], + [6, 7, 18]]) + + m2 = Matrix([[0, C**2], + [-2 * C, -C ** 2]]) + + m3 = Matrix([[1, 0], + [0, 1]]) + + m4 = Matrix([[A**2, 0, 1], + [A, 1, 1 / A]]) + + m5 = Matrix([[5, 1], + [2, B], + [0, 1], + [0, 0]]) + + assert dixon.KSY_precondition(m1) == False + assert dixon.KSY_precondition(m2) == True + assert dixon.KSY_precondition(m3) == True + assert dixon.KSY_precondition(m4) == False + assert dixon.KSY_precondition(m5) == True + +def test_delete_zero_rows_and_columns(): + """Tests method for deleting rows and columns containing only zeros.""" + A, B, C = symbols('A, B, C') + + m1 = Matrix([[0, 0], + [0, 0], + [1, 2]]) + + m2 = Matrix([[0, 1, 2], + [0, 3, 4], + [0, 5, 6]]) + + m3 = Matrix([[0, 0, 0, 0], + [0, 1, 2, 0], + [0, 3, 4, 0], + [0, 0, 0, 0]]) + + m4 = Matrix([[1, 0, 2], + [0, 0, 0], + [3, 0, 4]]) + + m5 = Matrix([[0, 0, 0, 1], + [0, 0, 0, 2], + [0, 0, 0, 3], + [0, 0, 0, 4]]) + + m6 = Matrix([[0, 0, A], + [B, 0, 0], + [0, 0, C]]) + + assert dixon.delete_zero_rows_and_columns(m1) == Matrix([[1, 2]]) + + assert dixon.delete_zero_rows_and_columns(m2) == Matrix([[1, 2], + [3, 4], + [5, 6]]) + + assert dixon.delete_zero_rows_and_columns(m3) == Matrix([[1, 2], + [3, 4]]) + + assert dixon.delete_zero_rows_and_columns(m4) == Matrix([[1, 2], + [3, 4]]) + + assert dixon.delete_zero_rows_and_columns(m5) == Matrix([[1], + [2], + [3], + [4]]) + + assert dixon.delete_zero_rows_and_columns(m6) == Matrix([[0, A], + [B, 0], + [0, C]]) + +def test_product_leading_entries(): + """Tests product of leading entries method.""" + A, B = symbols('A, B') + + m1 = Matrix([[1, 2, 3], + [0, 4, 5], + [0, 0, 6]]) + + m2 = Matrix([[0, 0, 1], + [2, 0, 3]]) + + m3 = Matrix([[0, 0, 0], + [1, 2, 3], + [0, 0, 0]]) + + m4 = Matrix([[0, 0, A], + [1, 2, 3], + [B, 0, 0]]) + + assert dixon.product_leading_entries(m1) == 24 + assert dixon.product_leading_entries(m2) == 2 + assert dixon.product_leading_entries(m3) == 1 + assert dixon.product_leading_entries(m4) == A * B + +def test_get_KSY_Dixon_resultant_example_one(): + """Tests the KSY Dixon resultant for example one""" + x, y, z = symbols('x, y, z') + + p = x * y * z + q = x**2 - z**2 + h = x + y + z + dixon = DixonResultant([p, q, h], [x, y]) + dixon_poly = dixon.get_dixon_polynomial() + dixon_matrix = dixon.get_dixon_matrix(dixon_poly) + D = dixon.get_KSY_Dixon_resultant(dixon_matrix) + + assert D == -z**3 + +def test_get_KSY_Dixon_resultant_example_two(): + """Tests the KSY Dixon resultant for example two""" + x, y, A = symbols('x, y, A') + + p = x * y + x * A + x - A**2 - A + y**2 + y + q = x**2 + x * A - x + x * y + y * A - y + h = x**2 + x * y + 2 * x - x * A - y * A - 2 * A + + dixon = DixonResultant([p, q, h], [x, y]) + dixon_poly = dixon.get_dixon_polynomial() + dixon_matrix = dixon.get_dixon_matrix(dixon_poly) + D = factor(dixon.get_KSY_Dixon_resultant(dixon_matrix)) + + assert D == -8*A*(A - 1)*(A + 2)*(2*A - 1)**2 + +def test_macaulay_resultant_init(): + """Test init method of MacaulayResultant.""" + + assert macaulay.polynomials == [p, q] + assert macaulay.variables == [x, y] + assert macaulay.n == 2 + assert macaulay.degrees == [1, 1] + assert macaulay.degree_m == 1 + assert macaulay.monomials_size == 2 + +def test_get_degree_m(): + assert macaulay._get_degree_m() == 1 + +def test_get_size(): + assert macaulay.get_size() == 2 + +def test_macaulay_example_one(): + """Tests the Macaulay for example from [Bruce97]_""" + + x, y, z = symbols('x, y, z') + a_1_1, a_1_2, a_1_3 = symbols('a_1_1, a_1_2, a_1_3') + a_2_2, a_2_3, a_3_3 = symbols('a_2_2, a_2_3, a_3_3') + b_1_1, b_1_2, b_1_3 = symbols('b_1_1, b_1_2, b_1_3') + b_2_2, b_2_3, b_3_3 = symbols('b_2_2, b_2_3, b_3_3') + c_1, c_2, c_3 = symbols('c_1, c_2, c_3') + + f_1 = a_1_1 * x ** 2 + a_1_2 * x * y + a_1_3 * x * z + \ + a_2_2 * y ** 2 + a_2_3 * y * z + a_3_3 * z ** 2 + f_2 = b_1_1 * x ** 2 + b_1_2 * x * y + b_1_3 * x * z + \ + b_2_2 * y ** 2 + b_2_3 * y * z + b_3_3 * z ** 2 + f_3 = c_1 * x + c_2 * y + c_3 * z + + mac = MacaulayResultant([f_1, f_2, f_3], [x, y, z]) + + assert mac.degrees == [2, 2, 1] + assert mac.degree_m == 3 + + assert mac.monomial_set == [x ** 3, x ** 2 * y, x ** 2 * z, + x * y ** 2, + x * y * z, x * z ** 2, y ** 3, + y ** 2 *z, y * z ** 2, z ** 3] + assert mac.monomials_size == 10 + assert mac.get_row_coefficients() == [[x, y, z], [x, y, z], + [x * y, x * z, y * z, z ** 2]] + + matrix = mac.get_matrix() + assert matrix.shape == (mac.monomials_size, mac.monomials_size) + assert mac.get_submatrix(matrix) == Matrix([[a_1_1, a_2_2], + [b_1_1, b_2_2]]) + +def test_macaulay_example_two(): + """Tests the Macaulay formulation for example from [Stiller96]_.""" + + x, y, z = symbols('x, y, z') + a_0, a_1, a_2 = symbols('a_0, a_1, a_2') + b_0, b_1, b_2 = symbols('b_0, b_1, b_2') + c_0, c_1, c_2, c_3, c_4 = symbols('c_0, c_1, c_2, c_3, c_4') + + f = a_0 * y - a_1 * x + a_2 * z + g = b_1 * x ** 2 + b_0 * y ** 2 - b_2 * z ** 2 + h = c_0 * y - c_1 * x ** 3 + c_2 * x ** 2 * z - c_3 * x * z ** 2 + \ + c_4 * z ** 3 + + mac = MacaulayResultant([f, g, h], [x, y, z]) + + assert mac.degrees == [1, 2, 3] + assert mac.degree_m == 4 + assert mac.monomials_size == 15 + assert len(mac.get_row_coefficients()) == mac.n + + matrix = mac.get_matrix() + assert matrix.shape == (mac.monomials_size, mac.monomials_size) + assert mac.get_submatrix(matrix) == Matrix([[-a_1, a_0, a_2, 0], + [0, -a_1, 0, 0], + [0, 0, -a_1, 0], + [0, 0, 0, -a_1]]) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_orderings.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_orderings.py new file mode 100644 index 0000000000000000000000000000000000000000..d61d4887754c9d9f49905c2e131d253a45cf2ffd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_orderings.py @@ -0,0 +1,124 @@ +"""Tests of monomial orderings. """ + +from sympy.polys.orderings import ( + monomial_key, lex, grlex, grevlex, ilex, igrlex, + LexOrder, InverseOrder, ProductOrder, build_product_order, +) + +from sympy.abc import x, y, z, t +from sympy.core import S +from sympy.testing.pytest import raises + +def test_lex_order(): + assert lex((1, 2, 3)) == (1, 2, 3) + assert str(lex) == 'lex' + + assert lex((1, 2, 3)) == lex((1, 2, 3)) + + assert lex((2, 2, 3)) > lex((1, 2, 3)) + assert lex((1, 3, 3)) > lex((1, 2, 3)) + assert lex((1, 2, 4)) > lex((1, 2, 3)) + + assert lex((0, 2, 3)) < lex((1, 2, 3)) + assert lex((1, 1, 3)) < lex((1, 2, 3)) + assert lex((1, 2, 2)) < lex((1, 2, 3)) + + assert lex.is_global is True + assert lex == LexOrder() + assert lex != grlex + +def test_grlex_order(): + assert grlex((1, 2, 3)) == (6, (1, 2, 3)) + assert str(grlex) == 'grlex' + + assert grlex((1, 2, 3)) == grlex((1, 2, 3)) + + assert grlex((2, 2, 3)) > grlex((1, 2, 3)) + assert grlex((1, 3, 3)) > grlex((1, 2, 3)) + assert grlex((1, 2, 4)) > grlex((1, 2, 3)) + + assert grlex((0, 2, 3)) < grlex((1, 2, 3)) + assert grlex((1, 1, 3)) < grlex((1, 2, 3)) + assert grlex((1, 2, 2)) < grlex((1, 2, 3)) + + assert grlex((2, 2, 3)) > grlex((1, 2, 4)) + assert grlex((1, 3, 3)) > grlex((1, 2, 4)) + + assert grlex((0, 2, 3)) < grlex((1, 2, 2)) + assert grlex((1, 1, 3)) < grlex((1, 2, 2)) + + assert grlex((0, 1, 1)) > grlex((0, 0, 2)) + assert grlex((0, 3, 1)) < grlex((2, 2, 1)) + + assert grlex.is_global is True + +def test_grevlex_order(): + assert grevlex((1, 2, 3)) == (6, (-3, -2, -1)) + assert str(grevlex) == 'grevlex' + + assert grevlex((1, 2, 3)) == grevlex((1, 2, 3)) + + assert grevlex((2, 2, 3)) > grevlex((1, 2, 3)) + assert grevlex((1, 3, 3)) > grevlex((1, 2, 3)) + assert grevlex((1, 2, 4)) > grevlex((1, 2, 3)) + + assert grevlex((0, 2, 3)) < grevlex((1, 2, 3)) + assert grevlex((1, 1, 3)) < grevlex((1, 2, 3)) + assert grevlex((1, 2, 2)) < grevlex((1, 2, 3)) + + assert grevlex((2, 2, 3)) > grevlex((1, 2, 4)) + assert grevlex((1, 3, 3)) > grevlex((1, 2, 4)) + + assert grevlex((0, 2, 3)) < grevlex((1, 2, 2)) + assert grevlex((1, 1, 3)) < grevlex((1, 2, 2)) + + assert grevlex((0, 1, 1)) > grevlex((0, 0, 2)) + assert grevlex((0, 3, 1)) < grevlex((2, 2, 1)) + + assert grevlex.is_global is True + +def test_InverseOrder(): + ilex = InverseOrder(lex) + igrlex = InverseOrder(grlex) + + assert ilex((1, 2, 3)) > ilex((2, 0, 3)) + assert igrlex((1, 2, 3)) < igrlex((0, 2, 3)) + assert str(ilex) == "ilex" + assert str(igrlex) == "igrlex" + assert ilex.is_global is False + assert igrlex.is_global is False + assert ilex != igrlex + assert ilex == InverseOrder(LexOrder()) + +def test_ProductOrder(): + P = ProductOrder((grlex, lambda m: m[:2]), (grlex, lambda m: m[2:])) + assert P((1, 3, 3, 4, 5)) > P((2, 1, 5, 5, 5)) + assert str(P) == "ProductOrder(grlex, grlex)" + assert P.is_global is True + assert ProductOrder((grlex, None), (ilex, None)).is_global is None + assert ProductOrder((igrlex, None), (ilex, None)).is_global is False + +def test_monomial_key(): + assert monomial_key() == lex + + assert monomial_key('lex') == lex + assert monomial_key('grlex') == grlex + assert monomial_key('grevlex') == grevlex + + raises(ValueError, lambda: monomial_key('foo')) + raises(ValueError, lambda: monomial_key(1)) + + M = [x, x**2*z**2, x*y, x**2, S.One, y**2, x**3, y, z, x*y**2*z, x**2*y**2] + assert sorted(M, key=monomial_key('lex', [z, y, x])) == \ + [S.One, x, x**2, x**3, y, x*y, y**2, x**2*y**2, z, x*y**2*z, x**2*z**2] + assert sorted(M, key=monomial_key('grlex', [z, y, x])) == \ + [S.One, x, y, z, x**2, x*y, y**2, x**3, x**2*y**2, x*y**2*z, x**2*z**2] + assert sorted(M, key=monomial_key('grevlex', [z, y, x])) == \ + [S.One, x, y, z, x**2, x*y, y**2, x**3, x**2*y**2, x**2*z**2, x*y**2*z] + +def test_build_product_order(): + assert build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t])((4, 5, 6, 7)) == \ + ((9, (4, 5)), (13, (6, 7))) + + assert build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t]) == \ + build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t]) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_orthopolys.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_orthopolys.py new file mode 100644 index 0000000000000000000000000000000000000000..e81fbe75aa6285d229ba817026f44b23b76abd6a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_orthopolys.py @@ -0,0 +1,175 @@ +"""Tests for efficient functions for generating orthogonal polynomials. """ + +from sympy.core.numbers import Rational as Q +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises + +from sympy.polys.orthopolys import ( + jacobi_poly, + gegenbauer_poly, + chebyshevt_poly, + chebyshevu_poly, + hermite_poly, + hermite_prob_poly, + legendre_poly, + laguerre_poly, + spherical_bessel_fn, +) + +from sympy.abc import x, a, b + + +def test_jacobi_poly(): + raises(ValueError, lambda: jacobi_poly(-1, a, b, x)) + + assert jacobi_poly(1, a, b, x, polys=True) == Poly( + (a/2 + b/2 + 1)*x + a/2 - b/2, x, domain='ZZ(a,b)') + + assert jacobi_poly(0, a, b, x) == 1 + assert jacobi_poly(1, a, b, x) == a/2 - b/2 + x*(a/2 + b/2 + 1) + assert jacobi_poly(2, a, b, x) == (a**2/8 - a*b/4 - a/8 + b**2/8 - b/8 + + x**2*(a**2/8 + a*b/4 + a*Q(7, 8) + b**2/8 + + b*Q(7, 8) + Q(3, 2)) + x*(a**2/4 + + a*Q(3, 4) - b**2/4 - b*Q(3, 4)) - S.Half) + + assert jacobi_poly(1, a, b, polys=True) == Poly( + (a/2 + b/2 + 1)*x + a/2 - b/2, x, domain='ZZ(a,b)') + + +def test_gegenbauer_poly(): + raises(ValueError, lambda: gegenbauer_poly(-1, a, x)) + + assert gegenbauer_poly( + 1, a, x, polys=True) == Poly(2*a*x, x, domain='ZZ(a)') + + assert gegenbauer_poly(0, a, x) == 1 + assert gegenbauer_poly(1, a, x) == 2*a*x + assert gegenbauer_poly(2, a, x) == -a + x**2*(2*a**2 + 2*a) + assert gegenbauer_poly( + 3, a, x) == x**3*(4*a**3/3 + 4*a**2 + a*Q(8, 3)) + x*(-2*a**2 - 2*a) + + assert gegenbauer_poly(1, S.Half).dummy_eq(x) + assert gegenbauer_poly(1, a, polys=True) == Poly(2*a*x, x, domain='ZZ(a)') + + +def test_chebyshevt_poly(): + raises(ValueError, lambda: chebyshevt_poly(-1, x)) + + assert chebyshevt_poly(1, x, polys=True) == Poly(x) + + assert chebyshevt_poly(0, x) == 1 + assert chebyshevt_poly(1, x) == x + assert chebyshevt_poly(2, x) == 2*x**2 - 1 + assert chebyshevt_poly(3, x) == 4*x**3 - 3*x + assert chebyshevt_poly(4, x) == 8*x**4 - 8*x**2 + 1 + assert chebyshevt_poly(5, x) == 16*x**5 - 20*x**3 + 5*x + assert chebyshevt_poly(6, x) == 32*x**6 - 48*x**4 + 18*x**2 - 1 + assert chebyshevt_poly(75, x) == (2*chebyshevt_poly(37, x)*chebyshevt_poly(38, x) - x).expand() + assert chebyshevt_poly(100, x) == (2*chebyshevt_poly(50, x)**2 - 1).expand() + + assert chebyshevt_poly(1).dummy_eq(x) + assert chebyshevt_poly(1, polys=True) == Poly(x) + + +def test_chebyshevu_poly(): + raises(ValueError, lambda: chebyshevu_poly(-1, x)) + + assert chebyshevu_poly(1, x, polys=True) == Poly(2*x) + + assert chebyshevu_poly(0, x) == 1 + assert chebyshevu_poly(1, x) == 2*x + assert chebyshevu_poly(2, x) == 4*x**2 - 1 + assert chebyshevu_poly(3, x) == 8*x**3 - 4*x + assert chebyshevu_poly(4, x) == 16*x**4 - 12*x**2 + 1 + assert chebyshevu_poly(5, x) == 32*x**5 - 32*x**3 + 6*x + assert chebyshevu_poly(6, x) == 64*x**6 - 80*x**4 + 24*x**2 - 1 + + assert chebyshevu_poly(1).dummy_eq(2*x) + assert chebyshevu_poly(1, polys=True) == Poly(2*x) + + +def test_hermite_poly(): + raises(ValueError, lambda: hermite_poly(-1, x)) + + assert hermite_poly(1, x, polys=True) == Poly(2*x) + + assert hermite_poly(0, x) == 1 + assert hermite_poly(1, x) == 2*x + assert hermite_poly(2, x) == 4*x**2 - 2 + assert hermite_poly(3, x) == 8*x**3 - 12*x + assert hermite_poly(4, x) == 16*x**4 - 48*x**2 + 12 + assert hermite_poly(5, x) == 32*x**5 - 160*x**3 + 120*x + assert hermite_poly(6, x) == 64*x**6 - 480*x**4 + 720*x**2 - 120 + + assert hermite_poly(1).dummy_eq(2*x) + assert hermite_poly(1, polys=True) == Poly(2*x) + + +def test_hermite_prob_poly(): + raises(ValueError, lambda: hermite_prob_poly(-1, x)) + + assert hermite_prob_poly(1, x, polys=True) == Poly(x) + + assert hermite_prob_poly(0, x) == 1 + assert hermite_prob_poly(1, x) == x + assert hermite_prob_poly(2, x) == x**2 - 1 + assert hermite_prob_poly(3, x) == x**3 - 3*x + assert hermite_prob_poly(4, x) == x**4 - 6*x**2 + 3 + assert hermite_prob_poly(5, x) == x**5 - 10*x**3 + 15*x + assert hermite_prob_poly(6, x) == x**6 - 15*x**4 + 45*x**2 - 15 + + assert hermite_prob_poly(1).dummy_eq(x) + assert hermite_prob_poly(1, polys=True) == Poly(x) + + +def test_legendre_poly(): + raises(ValueError, lambda: legendre_poly(-1, x)) + + assert legendre_poly(1, x, polys=True) == Poly(x, domain='QQ') + + assert legendre_poly(0, x) == 1 + assert legendre_poly(1, x) == x + assert legendre_poly(2, x) == Q(3, 2)*x**2 - Q(1, 2) + assert legendre_poly(3, x) == Q(5, 2)*x**3 - Q(3, 2)*x + assert legendre_poly(4, x) == Q(35, 8)*x**4 - Q(30, 8)*x**2 + Q(3, 8) + assert legendre_poly(5, x) == Q(63, 8)*x**5 - Q(70, 8)*x**3 + Q(15, 8)*x + assert legendre_poly(6, x) == Q( + 231, 16)*x**6 - Q(315, 16)*x**4 + Q(105, 16)*x**2 - Q(5, 16) + + assert legendre_poly(1).dummy_eq(x) + assert legendre_poly(1, polys=True) == Poly(x) + + +def test_laguerre_poly(): + raises(ValueError, lambda: laguerre_poly(-1, x)) + + assert laguerre_poly(1, x, polys=True) == Poly(-x + 1, domain='QQ') + + assert laguerre_poly(0, x) == 1 + assert laguerre_poly(1, x) == -x + 1 + assert laguerre_poly(2, x) == Q(1, 2)*x**2 - Q(4, 2)*x + 1 + assert laguerre_poly(3, x) == -Q(1, 6)*x**3 + Q(9, 6)*x**2 - Q(18, 6)*x + 1 + assert laguerre_poly(4, x) == Q( + 1, 24)*x**4 - Q(16, 24)*x**3 + Q(72, 24)*x**2 - Q(96, 24)*x + 1 + assert laguerre_poly(5, x) == -Q(1, 120)*x**5 + Q(25, 120)*x**4 - Q( + 200, 120)*x**3 + Q(600, 120)*x**2 - Q(600, 120)*x + 1 + assert laguerre_poly(6, x) == Q(1, 720)*x**6 - Q(36, 720)*x**5 + Q(450, 720)*x**4 - Q(2400, 720)*x**3 + Q(5400, 720)*x**2 - Q(4320, 720)*x + 1 + + assert laguerre_poly(0, x, a) == 1 + assert laguerre_poly(1, x, a) == -x + a + 1 + assert laguerre_poly(2, x, a) == x**2/2 + (-a - 2)*x + a**2/2 + a*Q(3, 2) + 1 + assert laguerre_poly(3, x, a) == -x**3/6 + (a/2 + Q( + 3)/2)*x**2 + (-a**2/2 - a*Q(5, 2) - 3)*x + a**3/6 + a**2 + a*Q(11, 6) + 1 + + assert laguerre_poly(1).dummy_eq(-x + 1) + assert laguerre_poly(1, polys=True) == Poly(-x + 1) + + +def test_spherical_bessel_fn(): + x, z = symbols("x z") + assert spherical_bessel_fn(1, z) == 1/z**2 + assert spherical_bessel_fn(2, z) == -1/z + 3/z**3 + assert spherical_bessel_fn(3, z) == -6/z**2 + 15/z**4 + assert spherical_bessel_fn(4, z) == 1/z - 45/z**3 + 105/z**5 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_partfrac.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_partfrac.py new file mode 100644 index 0000000000000000000000000000000000000000..83c5d48383d20e67dbb53c081093ad35e654c9a0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_partfrac.py @@ -0,0 +1,249 @@ +"""Tests for algorithms for partial fraction decomposition of rational +functions. """ + +from sympy.polys.partfrac import ( + apart_undetermined_coeffs, + apart, + apart_list, assemble_partfrac_list +) + +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.numbers import (E, I, Rational, pi, all_close) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import (Poly, factor) +from sympy.polys.rationaltools import together +from sympy.polys.rootoftools import RootSum +from sympy.testing.pytest import raises, XFAIL +from sympy.abc import x, y, a, b, c + + +def test_apart(): + assert apart(1) == 1 + assert apart(1, x) == 1 + + f, g = (x**2 + 1)/(x + 1), 2/(x + 1) + x - 1 + + assert apart(f, full=False) == g + assert apart(f, full=True) == g + + f, g = 1/(x + 2)/(x + 1), 1/(1 + x) - 1/(2 + x) + + assert apart(f, full=False) == g + assert apart(f, full=True) == g + + f, g = 1/(x + 1)/(x + 5), -1/(5 + x)/4 + 1/(1 + x)/4 + + assert apart(f, full=False) == g + assert apart(f, full=True) == g + + assert apart((E*x + 2)/(x - pi)*(x - 1), x) == \ + 2 - E + E*pi + E*x + (E*pi + 2)*(pi - 1)/(x - pi) + + assert apart(Eq((x**2 + 1)/(x + 1), x), x) == Eq(x - 1 + 2/(x + 1), x) + + assert apart(x/2, y) == x/2 + + f, g = (x+y)/(2*x - y), Rational(3, 2)*y/(2*x - y) + S.Half + + assert apart(f, x, full=False) == g + assert apart(f, x, full=True) == g + + f, g = (x+y)/(2*x - y), 3*x/(2*x - y) - 1 + + assert apart(f, y, full=False) == g + assert apart(f, y, full=True) == g + + raises(NotImplementedError, lambda: apart(1/(x + 1)/(y + 2))) + + +def test_apart_matrix(): + M = Matrix(2, 2, lambda i, j: 1/(x + i + 1)/(x + j)) + + assert apart(M) == Matrix([ + [1/x - 1/(x + 1), (x + 1)**(-2)], + [1/(2*x) - (S.Half)/(x + 2), 1/(x + 1) - 1/(x + 2)], + ]) + + +def test_apart_symbolic(): + f = a*x**4 + (2*b + 2*a*c)*x**3 + (4*b*c - a**2 + a*c**2)*x**2 + \ + (-2*a*b + 2*b*c**2)*x - b**2 + g = a**2*x**4 + (2*a*b + 2*c*a**2)*x**3 + (4*a*b*c + b**2 + + a**2*c**2)*x**2 + (2*c*b**2 + 2*a*b*c**2)*x + b**2*c**2 + + assert apart(f/g, x) == 1/a - 1/(x + c)**2 - b**2/(a*(a*x + b)**2) + + assert apart(1/((x + a)*(x + b)*(x + c)), x) == \ + 1/((a - c)*(b - c)*(c + x)) - 1/((a - b)*(b - c)*(b + x)) + \ + 1/((a - b)*(a - c)*(a + x)) + + +def _make_extension_example(): + # https://github.com/sympy/sympy/issues/18531 + from sympy.core import Mul + def mul2(expr): + # 2-arg mul hack... + return Mul(2, expr, evaluate=False) + + f = ((x**2 + 1)**3/((x - 1)**2*(x + 1)**2*(-x**2 + 2*x + 1)*(x**2 + 2*x - 1))) + g = (1/mul2(x - sqrt(2) + 1) + - 1/mul2(x - sqrt(2) - 1) + + 1/mul2(x + 1 + sqrt(2)) + - 1/mul2(x - 1 + sqrt(2)) + + 1/mul2((x + 1)**2) + + 1/mul2((x - 1)**2)) + return f, g + + +def test_apart_extension(): + f = 2/(x**2 + 1) + g = I/(x + I) - I/(x - I) + + assert apart(f, extension=I) == g + assert apart(f, gaussian=True) == g + + f = x/((x - 2)*(x + I)) + + assert factor(together(apart(f)).expand()) == f + + f, g = _make_extension_example() + + # XXX: Only works with dotprodsimp. See test_apart_extension_xfail below + from sympy.matrices import dotprodsimp + with dotprodsimp(True): + assert apart(f, x, extension={sqrt(2)}) == g + + +def test_apart_extension_xfail(): + f, g = _make_extension_example() + assert apart(f, x, extension={sqrt(2)}) == g + + +def test_apart_full(): + f = 1/(x**2 + 1) + + assert apart(f, full=False) == f + assert apart(f, full=True).dummy_eq( + -RootSum(x**2 + 1, Lambda(a, a/(x - a)), auto=False)/2) + + f = 1/(x**3 + x + 1) + + assert apart(f, full=False) == f + assert apart(f, full=True).dummy_eq( + RootSum(x**3 + x + 1, + Lambda(a, (a**2*Rational(6, 31) - a*Rational(9, 31) + Rational(4, 31))/(x - a)), auto=False)) + + f = 1/(x**5 + 1) + + assert apart(f, full=False) == \ + (Rational(-1, 5))*((x**3 - 2*x**2 + 3*x - 4)/(x**4 - x**3 + x**2 - + x + 1)) + (Rational(1, 5))/(x + 1) + assert apart(f, full=True).dummy_eq( + -RootSum(x**4 - x**3 + x**2 - x + 1, + Lambda(a, a/(x - a)), auto=False)/5 + (Rational(1, 5))/(x + 1)) + + +def test_apart_full_floats(): + # https://github.com/sympy/sympy/issues/26648 + f = ( + 6.43369157032015e-9*x**3 + 1.35203404799555e-5*x**2 + + 0.00357538393743079*x + 0.085 + )/( + 4.74334912634438e-11*x**4 + 4.09576274286244e-6*x**3 + + 0.00334241812250921*x**2 + 0.15406018058983*x + 1.0 + ) + + expected = ( + 133.599202650992/(x + 85524.0054884464) + + 1.07757928431867/(x + 774.88576677949) + + 0.395006955518971/(x + 40.7977016133126) + + 0.564264854137341/(x + 7.79746609204661) + ) + + f_apart = apart(f, full=True).evalf() + + # There is a significant floating point error in this operation. + assert all_close(f_apart, expected, rtol=1e-3, atol=1e-5) + + +def test_apart_undetermined_coeffs(): + p = Poly(2*x - 3) + q = Poly(x**9 - x**8 - x**6 + x**5 - 2*x**2 + 3*x - 1) + r = (-x**7 - x**6 - x**5 + 4)/(x**8 - x**5 - 2*x + 1) + 1/(x - 1) + + assert apart_undetermined_coeffs(p, q) == r + + p = Poly(1, x, domain='ZZ[a,b]') + q = Poly((x + a)*(x + b), x, domain='ZZ[a,b]') + r = 1/((a - b)*(b + x)) - 1/((a - b)*(a + x)) + + assert apart_undetermined_coeffs(p, q) == r + + +def test_apart_list(): + from sympy.utilities.iterables import numbered_symbols + def dummy_eq(i, j): + if type(i) in (list, tuple): + return all(dummy_eq(i, j) for i, j in zip(i, j)) + return i == j or i.dummy_eq(j) + + w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2") + _a = Dummy("a") + + f = (-2*x - 2*x**2) / (3*x**2 - 6*x) + got = apart_list(f, x, dummies=numbered_symbols("w")) + ans = (-1, Poly(Rational(2, 3), x, domain='QQ'), + [(Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 2), Lambda(_a, -_a + x), 1)]) + assert dummy_eq(got, ans) + + got = apart_list(2/(x**2-2), x, dummies=numbered_symbols("w")) + ans = (1, Poly(0, x, domain='ZZ'), [(Poly(w0**2 - 2, w0, domain='ZZ'), + Lambda(_a, _a/2), + Lambda(_a, -_a + x), 1)]) + assert dummy_eq(got, ans) + + f = 36 / (x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2) + got = apart_list(f, x, dummies=numbered_symbols("w")) + ans = (1, Poly(0, x, domain='ZZ'), + [(Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1), + (Poly(w1**2 - 1, w1, domain='ZZ'), Lambda(_a, -3*_a - 6), Lambda(_a, -_a + x), 2), + (Poly(w2 + 1, w2, domain='ZZ'), Lambda(_a, -4), Lambda(_a, -_a + x), 1)]) + assert dummy_eq(got, ans) + + +def test_assemble_partfrac_list(): + f = 36 / (x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2) + pfd = apart_list(f) + assert assemble_partfrac_list(pfd) == -4/(x + 1) - 3/(x + 1)**2 - 9/(x - 1)**2 + 4/(x - 2) + + a = Dummy("a") + pfd = (1, Poly(0, x, domain='ZZ'), [([sqrt(2),-sqrt(2)], Lambda(a, a/2), Lambda(a, -a + x), 1)]) + assert assemble_partfrac_list(pfd) == -1/(sqrt(2)*(x + sqrt(2))) + 1/(sqrt(2)*(x - sqrt(2))) + + +@XFAIL +def test_noncommutative_pseudomultivariate(): + # apart doesn't go inside noncommutative expressions + class foo(Expr): + is_commutative=False + e = x/(x + x*y) + c = 1/(1 + y) + assert apart(e + foo(e)) == c + foo(c) + assert apart(e*foo(e)) == c*foo(c) + +def test_noncommutative(): + class foo(Expr): + is_commutative=False + e = x/(x + x*y) + c = 1/(1 + y) + assert apart(e + foo()) == c + foo() + +def test_issue_5798(): + assert apart( + 2*x/(x**2 + 1) - (x - 1)/(2*(x**2 + 1)) + 1/(2*(x + 1)) - 2/x) == \ + (3*x + 1)/(x**2 + 1)/2 + 1/(x + 1)/2 - 2/x diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polyclasses.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2c8f2c3ca94c42fc524c3ec1c0300d881cf3a5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyclasses.py @@ -0,0 +1,601 @@ +"""Tests for OO layer of several polynomial representations. """ + +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.domains import ZZ, QQ +from sympy.polys.polyclasses import DMP, DMF, ANP +from sympy.polys.polyerrors import (CoercionFailed, ExactQuotientFailed, + NotInvertible) +from sympy.polys.specialpolys import f_polys +from sympy.testing.pytest import raises, warns_deprecated_sympy + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = [ f.to_dense() for f in f_polys() ] + +def test_DMP___init__(): + f = DMP([[ZZ(0)], [], [ZZ(0), ZZ(1), ZZ(2)], [ZZ(3)]], ZZ) + + assert f._rep == [[1, 2], [3]] + assert f.dom == ZZ + assert f.lev == 1 + + f = DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ, 1) + + assert f._rep == [[1, 2], [3]] + assert f.dom == ZZ + assert f.lev == 1 + + f = DMP.from_dict({(1, 1): ZZ(1), (0, 0): ZZ(2)}, 1, ZZ) + + assert f._rep == [[1, 0], [2]] + assert f.dom == ZZ + assert f.lev == 1 + + +def test_DMP_rep_deprecation(): + f = DMP([1, 2, 3], ZZ) + + with warns_deprecated_sympy(): + assert f.rep == [1, 2, 3] + + +def test_DMP___eq__(): + assert DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ) == \ + DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ) + + assert DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ) == \ + DMP([[QQ(1), QQ(2)], [QQ(3)]], QQ) + assert DMP([[QQ(1), QQ(2)], [QQ(3)]], QQ) == \ + DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ) + + assert DMP([[[ZZ(1)]]], ZZ) != DMP([[ZZ(1)]], ZZ) + assert DMP([[ZZ(1)]], ZZ) != DMP([[[ZZ(1)]]], ZZ) + + +def test_DMP___bool__(): + assert bool(DMP([[]], ZZ)) is False + assert bool(DMP([[ZZ(1)]], ZZ)) is True + + +def test_DMP_to_dict(): + f = DMP([[ZZ(3)], [], [ZZ(2)], [], [ZZ(8)]], ZZ) + + assert f.to_dict() == \ + {(4, 0): 3, (2, 0): 2, (0, 0): 8} + assert f.to_sympy_dict() == \ + {(4, 0): ZZ.to_sympy(3), (2, 0): ZZ.to_sympy(2), (0, 0): + ZZ.to_sympy(8)} + + +def test_DMP_properties(): + assert DMP([[]], ZZ).is_zero is True + assert DMP([[ZZ(1)]], ZZ).is_zero is False + + assert DMP([[ZZ(1)]], ZZ).is_one is True + assert DMP([[ZZ(2)]], ZZ).is_one is False + + assert DMP([[ZZ(1)]], ZZ).is_ground is True + assert DMP([[ZZ(1)], [ZZ(2)], [ZZ(1)]], ZZ).is_ground is False + + assert DMP([[ZZ(1)], [ZZ(2), ZZ(0)], [ZZ(1), ZZ(0)]], ZZ).is_sqf is True + assert DMP([[ZZ(1)], [ZZ(2), ZZ(0)], [ZZ(1), ZZ(0), ZZ(0)]], ZZ).is_sqf is False + + assert DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ).is_monic is True + assert DMP([[ZZ(2), ZZ(2)], [ZZ(3)]], ZZ).is_monic is False + + assert DMP([[ZZ(1), ZZ(2)], [ZZ(3)]], ZZ).is_primitive is True + assert DMP([[ZZ(2), ZZ(4)], [ZZ(6)]], ZZ).is_primitive is False + + +def test_DMP_arithmetics(): + f = DMP([[ZZ(2)], [ZZ(2), ZZ(0)]], ZZ) + + assert f.mul_ground(2) == DMP([[ZZ(4)], [ZZ(4), ZZ(0)]], ZZ) + assert f.quo_ground(2) == DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + + raises(ExactQuotientFailed, lambda: f.exquo_ground(3)) + + f = DMP([[ZZ(-5)]], ZZ) + g = DMP([[ZZ(5)]], ZZ) + + assert f.abs() == g + assert abs(f) == g + + assert g.neg() == f + assert -g == f + + h = DMP([[]], ZZ) + + assert f.add(g) == h + assert f + g == h + assert g + f == h + assert f + 5 == h + assert 5 + f == h + + h = DMP([[ZZ(-10)]], ZZ) + + assert f.sub(g) == h + assert f - g == h + assert g - f == -h + assert f - 5 == h + assert 5 - f == -h + + h = DMP([[ZZ(-25)]], ZZ) + + assert f.mul(g) == h + assert f * g == h + assert g * f == h + assert f * 5 == h + assert 5 * f == h + + h = DMP([[ZZ(25)]], ZZ) + + assert f.sqr() == h + assert f.pow(2) == h + assert f**2 == h + + raises(TypeError, lambda: f.pow('x')) + + f = DMP([[ZZ(1)], [], [ZZ(1), ZZ(0), ZZ(0)]], ZZ) + g = DMP([[ZZ(2)], [ZZ(-2), ZZ(0)]], ZZ) + + q = DMP([[ZZ(2)], [ZZ(2), ZZ(0)]], ZZ) + r = DMP([[ZZ(8), ZZ(0), ZZ(0)]], ZZ) + + assert f.pdiv(g) == (q, r) + assert f.pquo(g) == q + assert f.prem(g) == r + + raises(ExactQuotientFailed, lambda: f.pexquo(g)) + + f = DMP([[ZZ(1)], [], [ZZ(1), ZZ(0), ZZ(0)]], ZZ) + g = DMP([[ZZ(1)], [ZZ(-1), ZZ(0)]], ZZ) + + q = DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + r = DMP([[ZZ(2), ZZ(0), ZZ(0)]], ZZ) + + assert f.div(g) == (q, r) + assert f.quo(g) == q + assert f.rem(g) == r + + assert divmod(f, g) == (q, r) + assert f // g == q + assert f % g == r + + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f = DMP([ZZ(1), ZZ(0), ZZ(-1)], ZZ) + g = DMP([ZZ(2), ZZ(-2)], ZZ) + + q = DMP([], ZZ) + r = f + + pq = DMP([ZZ(2), ZZ(2)], ZZ) + pr = DMP([], ZZ) + + assert f.div(g) == (q, r) + assert f.quo(g) == q + assert f.rem(g) == r + + assert divmod(f, g) == (q, r) + assert f // g == q + assert f % g == r + + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + assert f.pdiv(g) == (pq, pr) + assert f.pquo(g) == pq + assert f.prem(g) == pr + assert f.pexquo(g) == pq + + +def test_DMP_functionality(): + f = DMP([[ZZ(1)], [ZZ(2), ZZ(0)], [ZZ(1), ZZ(0), ZZ(0)]], ZZ) + g = DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + h = DMP([[ZZ(1)]], ZZ) + + assert f.degree() == 2 + assert f.degree_list() == (2, 2) + assert f.total_degree() == 2 + + assert f.LC() == ZZ(1) + assert f.TC() == ZZ(0) + assert f.nth(1, 1) == ZZ(2) + + raises(TypeError, lambda: f.nth(0, 'x')) + + assert f.max_norm() == 2 + assert f.l1_norm() == 4 + + u = DMP([[ZZ(2)], [ZZ(2), ZZ(0)]], ZZ) + + assert f.diff(m=1, j=0) == u + assert f.diff(m=1, j=1) == u + + raises(TypeError, lambda: f.diff(m='x', j=0)) + + u = DMP([ZZ(1), ZZ(2), ZZ(1)], ZZ) + v = DMP([ZZ(1), ZZ(2), ZZ(1)], ZZ) + + assert f.eval(a=1, j=0) == u + assert f.eval(a=1, j=1) == v + + assert f.eval(1).eval(1) == ZZ(4) + + assert f.cofactors(g) == (g, g, h) + assert f.gcd(g) == g + assert f.lcm(g) == f + + u = DMP([[QQ(45), QQ(30), QQ(5)]], QQ) + v = DMP([[QQ(1), QQ(2, 3), QQ(1, 9)]], QQ) + + assert u.monic() == v + + assert (4*f).content() == ZZ(4) + assert (4*f).primitive() == (ZZ(4), f) + + f = DMP([QQ(1,3), QQ(1)], QQ) + g = DMP([QQ(1,7), QQ(1)], QQ) + + assert f.cancel(g) == f.cancel(g, include=True) == ( + DMP([QQ(7), QQ(21)], QQ), + DMP([QQ(3), QQ(21)], QQ) + ) + assert f.cancel(g, include=False) == ( + QQ(7), + QQ(3), + DMP([QQ(1), QQ(3)], QQ), + DMP([QQ(1), QQ(7)], QQ) + ) + + f = DMP([[ZZ(1)], [ZZ(2)], [ZZ(3)], [ZZ(4)], [ZZ(5)], [ZZ(6)]], ZZ) + + assert f.trunc(3) == DMP([[ZZ(1)], [ZZ(-1)], [], [ZZ(1)], [ZZ(-1)], []], ZZ) + + f = DMP(f_4, ZZ) + + assert f.sqf_part() == -f + assert f.sqf_list() == (ZZ(-1), [(-f, 1)]) + + f = DMP([[ZZ(-1)], [], [], [ZZ(5)]], ZZ) + g = DMP([[ZZ(3), ZZ(1)], [], []], ZZ) + h = DMP([[ZZ(45), ZZ(30), ZZ(5)]], ZZ) + + r = DMP([ZZ(675), ZZ(675), ZZ(225), ZZ(25)], ZZ) + + assert f.subresultants(g) == [f, g, h] + assert f.resultant(g) == r + + f = DMP([ZZ(1), ZZ(3), ZZ(9), ZZ(-13)], ZZ) + + assert f.discriminant() == -11664 + + f = DMP([QQ(2), QQ(0)], QQ) + g = DMP([QQ(1), QQ(0), QQ(-16)], QQ) + + s = DMP([QQ(1, 32), QQ(0)], QQ) + t = DMP([QQ(-1, 16)], QQ) + h = DMP([QQ(1)], QQ) + + assert f.half_gcdex(g) == (s, h) + assert f.gcdex(g) == (s, t, h) + + assert f.invert(g) == s + + f = DMP([[QQ(1)], [QQ(2)], [QQ(3)]], QQ) + + raises(ValueError, lambda: f.half_gcdex(f)) + raises(ValueError, lambda: f.gcdex(f)) + + raises(ValueError, lambda: f.invert(f)) + + f = DMP(ZZ.map([1, 0, 20, 0, 150, 0, 500, 0, 625, -2, 0, -10, 9]), ZZ) + g = DMP([ZZ(1), ZZ(0), ZZ(0), ZZ(-2), ZZ(9)], ZZ) + h = DMP([ZZ(1), ZZ(0), ZZ(5), ZZ(0)], ZZ) + + assert g.compose(h) == f + assert f.decompose() == [g, h] + + f = DMP([[QQ(1)], [QQ(2)], [QQ(3)]], QQ) + + raises(ValueError, lambda: f.decompose()) + raises(ValueError, lambda: f.sturm()) + + +def test_DMP_exclude(): + f = [[[[[[[[[[[[[[[[[[[[[[[[[[ZZ(1)]], [[]]]]]]]]]]]]]]]]]]]]]]]]]] + J = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 24, 25] + + assert DMP(f, ZZ).exclude() == (J, DMP([ZZ(1), ZZ(0)], ZZ)) + assert DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ).exclude() ==\ + ([], DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ)) + + +def test_DMF__init__(): + f = DMF(([[0], [], [0, 1, 2], [3]], [[1, 2, 3]]), ZZ) + + assert f.num == [[1, 2], [3]] + assert f.den == [[1, 2, 3]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[1, 2], [3]], [[1, 2, 3]]), ZZ, 1) + + assert f.num == [[1, 2], [3]] + assert f.den == [[1, 2, 3]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[-1], [-2]], [[3], [-4]]), ZZ) + + assert f.num == [[-1], [-2]] + assert f.den == [[3], [-4]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[1], [2]], [[-3], [4]]), ZZ) + + assert f.num == [[-1], [-2]] + assert f.den == [[3], [-4]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[1], [2]], [[-3], [4]]), ZZ) + + assert f.num == [[-1], [-2]] + assert f.den == [[3], [-4]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[]], [[-3], [4]]), ZZ) + + assert f.num == [[]] + assert f.den == [[1]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(17, ZZ, 1) + + assert f.num == [[17]] + assert f.den == [[1]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[1], [2]]), ZZ) + + assert f.num == [[1], [2]] + assert f.den == [[1]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF([[0], [], [0, 1, 2], [3]], ZZ) + + assert f.num == [[1, 2], [3]] + assert f.den == [[1]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF({(1, 1): 1, (0, 0): 2}, ZZ, 1) + + assert f.num == [[1, 0], [2]] + assert f.den == [[1]] + assert f.lev == 1 + assert f.dom == ZZ + + f = DMF(([[QQ(1)], [QQ(2)]], [[-QQ(3)], [QQ(4)]]), QQ) + + assert f.num == [[-QQ(1)], [-QQ(2)]] + assert f.den == [[QQ(3)], [-QQ(4)]] + assert f.lev == 1 + assert f.dom == QQ + + f = DMF(([[QQ(1, 5)], [QQ(2, 5)]], [[-QQ(3, 7)], [QQ(4, 7)]]), QQ) + + assert f.num == [[-QQ(7)], [-QQ(14)]] + assert f.den == [[QQ(15)], [-QQ(20)]] + assert f.lev == 1 + assert f.dom == QQ + + raises(ValueError, lambda: DMF(([1], [[1]]), ZZ)) + raises(ZeroDivisionError, lambda: DMF(([1], []), ZZ)) + + +def test_DMF__bool__(): + assert bool(DMF([[]], ZZ)) is False + assert bool(DMF([[1]], ZZ)) is True + + +def test_DMF_properties(): + assert DMF([[]], ZZ).is_zero is True + assert DMF([[]], ZZ).is_one is False + + assert DMF([[1]], ZZ).is_zero is False + assert DMF([[1]], ZZ).is_one is True + + assert DMF(([[1]], [[2]]), ZZ).is_one is False + + +def test_DMF_arithmetics(): + f = DMF([[7], [-9]], ZZ) + g = DMF([[-7], [9]], ZZ) + + assert f.neg() == -f == g + + f = DMF(([[1]], [[1], []]), ZZ) + g = DMF(([[1]], [[1, 0]]), ZZ) + + h = DMF(([[1], [1, 0]], [[1, 0], []]), ZZ) + + assert f.add(g) == f + g == h + assert g.add(f) == g + f == h + + h = DMF(([[-1], [1, 0]], [[1, 0], []]), ZZ) + + assert f.sub(g) == f - g == h + + h = DMF(([[1]], [[1, 0], []]), ZZ) + + assert f.mul(g) == f*g == h + assert g.mul(f) == g*f == h + + h = DMF(([[1, 0]], [[1], []]), ZZ) + + assert f.quo(g) == f/g == h + + h = DMF(([[1]], [[1], [], [], []]), ZZ) + + assert f.pow(3) == f**3 == h + + h = DMF(([[1]], [[1, 0, 0, 0]]), ZZ) + + assert g.pow(3) == g**3 == h + + h = DMF(([[1, 0]], [[1]]), ZZ) + + assert g.pow(-1) == g**-1 == h + + +def test_ANP___init__(): + rep = [QQ(1), QQ(1)] + mod = [QQ(1), QQ(0), QQ(1)] + + f = ANP(rep, mod, QQ) + + assert f.to_list() == [QQ(1), QQ(1)] + assert f.mod_to_list() == [QQ(1), QQ(0), QQ(1)] + assert f.dom == QQ + + rep = {1: QQ(1), 0: QQ(1)} + mod = {2: QQ(1), 0: QQ(1)} + + f = ANP(rep, mod, QQ) + + assert f.to_list() == [QQ(1), QQ(1)] + assert f.mod_to_list() == [QQ(1), QQ(0), QQ(1)] + assert f.dom == QQ + + f = ANP(1, mod, QQ) + + assert f.to_list() == [QQ(1)] + assert f.mod_to_list() == [QQ(1), QQ(0), QQ(1)] + assert f.dom == QQ + + f = ANP([1, 0.5], mod, QQ) + + assert all(QQ.of_type(a) for a in f.to_list()) + + raises(CoercionFailed, lambda: ANP([sqrt(2)], mod, QQ)) + + +def test_ANP___eq__(): + a = ANP([QQ(1), QQ(1)], [QQ(1), QQ(0), QQ(1)], QQ) + b = ANP([QQ(1), QQ(1)], [QQ(1), QQ(0), QQ(2)], QQ) + + assert (a == a) is True + assert (a != a) is False + + assert (a == b) is False + assert (a != b) is True + + b = ANP([QQ(1), QQ(2)], [QQ(1), QQ(0), QQ(1)], QQ) + + assert (a == b) is False + assert (a != b) is True + + +def test_ANP___bool__(): + assert bool(ANP([], [QQ(1), QQ(0), QQ(1)], QQ)) is False + assert bool(ANP([QQ(1)], [QQ(1), QQ(0), QQ(1)], QQ)) is True + + +def test_ANP_properties(): + mod = [QQ(1), QQ(0), QQ(1)] + + assert ANP([QQ(0)], mod, QQ).is_zero is True + assert ANP([QQ(1)], mod, QQ).is_zero is False + + assert ANP([QQ(1)], mod, QQ).is_one is True + assert ANP([QQ(2)], mod, QQ).is_one is False + + +def test_ANP_arithmetics(): + mod = [QQ(1), QQ(0), QQ(0), QQ(-2)] + + a = ANP([QQ(2), QQ(-1), QQ(1)], mod, QQ) + b = ANP([QQ(1), QQ(2)], mod, QQ) + + c = ANP([QQ(-2), QQ(1), QQ(-1)], mod, QQ) + + assert a.neg() == -a == c + + c = ANP([QQ(2), QQ(0), QQ(3)], mod, QQ) + + assert a.add(b) == a + b == c + assert b.add(a) == b + a == c + + c = ANP([QQ(2), QQ(-2), QQ(-1)], mod, QQ) + + assert a.sub(b) == a - b == c + + c = ANP([QQ(-2), QQ(2), QQ(1)], mod, QQ) + + assert b.sub(a) == b - a == c + + c = ANP([QQ(3), QQ(-1), QQ(6)], mod, QQ) + + assert a.mul(b) == a*b == c + assert b.mul(a) == b*a == c + + c = ANP([QQ(-1, 43), QQ(9, 43), QQ(5, 43)], mod, QQ) + + assert a.pow(0) == a**(0) == ANP(1, mod, QQ) + assert a.pow(1) == a**(1) == a + + assert a.pow(-1) == a**(-1) == c + + assert a.quo(a) == a.mul(a.pow(-1)) == a*a**(-1) == ANP(1, mod, QQ) + + c = ANP([], [1, 0, 0, -2], QQ) + r1 = a.rem(b) + + (q, r2) = a.div(b) + + assert r1 == r2 == c == a % b + + raises(NotInvertible, lambda: a.div(c)) + raises(NotInvertible, lambda: a.rem(c)) + + # Comparison with "hard-coded" value fails despite looking identical + # from sympy import Rational + # c = ANP([Rational(11, 10), Rational(-1, 5), Rational(-3, 5)], [1, 0, 0, -2], QQ) + + assert q == a/b # == c + +def test_ANP_unify(): + mod_z = [ZZ(1), ZZ(0), ZZ(-2)] + mod_q = [QQ(1), QQ(0), QQ(-2)] + + a = ANP([QQ(1)], mod_q, QQ) + b = ANP([ZZ(1)], mod_z, ZZ) + + assert a.unify(b)[0] == QQ + assert b.unify(a)[0] == QQ + assert a.unify(a)[0] == QQ + assert b.unify(b)[0] == ZZ + + assert a.unify_ANP(b)[-1] == QQ + assert b.unify_ANP(a)[-1] == QQ + assert a.unify_ANP(a)[-1] == QQ + assert b.unify_ANP(b)[-1] == ZZ + + +def test_zero_poly(): + from sympy import Symbol + x = Symbol('x') + + R_old = ZZ.old_poly_ring(x) + zero_poly_old = R_old(0) + cont_old, prim_old = zero_poly_old.primitive() + + assert cont_old == 0 + assert prim_old == zero_poly_old + assert prim_old.is_primitive is False diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polyfuncs.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyfuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..496f63bf14e4dd9f68cf653004eb35a3ed7615ca --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyfuncs.py @@ -0,0 +1,126 @@ +"""Tests for high-level polynomials manipulation functions. """ + +from sympy.polys.polyfuncs import ( + symmetrize, horner, interpolate, rational_interpolate, viete, +) + +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, +) + +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.testing.pytest import raises + +from sympy.abc import a, b, c, d, e, x, y, z + + +def test_symmetrize(): + assert symmetrize(0, x, y, z) == (0, 0) + assert symmetrize(1, x, y, z) == (1, 0) + + s1 = x + y + z + s2 = x*y + x*z + y*z + + assert symmetrize(1) == (1, 0) + assert symmetrize(1, formal=True) == (1, 0, []) + + assert symmetrize(x) == (x, 0) + assert symmetrize(x + 1) == (x + 1, 0) + + assert symmetrize(x, x, y) == (x + y, -y) + assert symmetrize(x + 1, x, y) == (x + y + 1, -y) + + assert symmetrize(x, x, y, z) == (s1, -y - z) + assert symmetrize(x + 1, x, y, z) == (s1 + 1, -y - z) + + assert symmetrize(x**2, x, y, z) == (s1**2 - 2*s2, -y**2 - z**2) + + assert symmetrize(x**2 + y**2) == (-2*x*y + (x + y)**2, 0) + assert symmetrize(x**2 - y**2) == (-2*x*y + (x + y)**2, -2*y**2) + + assert symmetrize(x**3 + y**2 + a*x**2 + b*y**3, x, y) == \ + (-3*x*y*(x + y) - 2*a*x*y + a*(x + y)**2 + (x + y)**3, + y**2*(1 - a) + y**3*(b - 1)) + + U = [u0, u1, u2] = symbols('u:3') + + assert symmetrize(x + 1, x, y, z, formal=True, symbols=U) == \ + (u0 + 1, -y - z, [(u0, x + y + z), (u1, x*y + x*z + y*z), (u2, x*y*z)]) + + assert symmetrize([1, 2, 3]) == [(1, 0), (2, 0), (3, 0)] + assert symmetrize([1, 2, 3], formal=True) == ([(1, 0), (2, 0), (3, 0)], []) + + assert symmetrize([x + y, x - y]) == [(x + y, 0), (x + y, -2*y)] + + +def test_horner(): + assert horner(0) == 0 + assert horner(1) == 1 + assert horner(x) == x + + assert horner(x + 1) == x + 1 + assert horner(x**2 + 1) == x**2 + 1 + assert horner(x**2 + x) == (x + 1)*x + assert horner(x**2 + x + 1) == (x + 1)*x + 1 + + assert horner( + 9*x**4 + 8*x**3 + 7*x**2 + 6*x + 5) == (((9*x + 8)*x + 7)*x + 6)*x + 5 + assert horner( + a*x**4 + b*x**3 + c*x**2 + d*x + e) == (((a*x + b)*x + c)*x + d)*x + e + + assert horner(4*x**2*y**2 + 2*x**2*y + 2*x*y**2 + x*y, wrt=x) == (( + 4*y + 2)*x*y + (2*y + 1)*y)*x + assert horner(4*x**2*y**2 + 2*x**2*y + 2*x*y**2 + x*y, wrt=y) == (( + 4*x + 2)*y*x + (2*x + 1)*x)*y + + +def test_interpolate(): + assert interpolate([1, 4, 9, 16], x) == x**2 + assert interpolate([1, 4, 9, 25], x) == S(3)*x**3/2 - S(8)*x**2 + S(33)*x/2 - 9 + assert interpolate([(1, 1), (2, 4), (3, 9)], x) == x**2 + assert interpolate([(1, 2), (2, 5), (3, 10)], x) == 1 + x**2 + assert interpolate({1: 2, 2: 5, 3: 10}, x) == 1 + x**2 + assert interpolate({5: 2, 7: 5, 8: 10, 9: 13}, x) == \ + -S(13)*x**3/24 + S(12)*x**2 - S(2003)*x/24 + 187 + assert interpolate([(1, 3), (0, 6), (2, 5), (5, 7), (-2, 4)], x) == \ + S(-61)*x**4/280 + S(247)*x**3/210 + S(139)*x**2/280 - S(1871)*x/420 + 6 + assert interpolate((9, 4, 9), 3) == 9 + assert interpolate((1, 9, 16), 1) is S.One + assert interpolate(((x, 1), (2, 3)), x) is S.One + assert interpolate({x: 1, 2: 3}, x) is S.One + assert interpolate(((2, x), (1, 3)), x) == x**2 - 4*x + 6 + + +def test_rational_interpolate(): + x, y = symbols('x,y') + xdata = [1, 2, 3, 4, 5, 6] + ydata1 = [120, 150, 200, 255, 312, 370] + ydata2 = [-210, -35, 105, 231, 350, 465] + assert rational_interpolate(list(zip(xdata, ydata1)), 2) == ( + (60*x**2 + 60)/x ) + assert rational_interpolate(list(zip(xdata, ydata1)), 3) == ( + (60*x**2 + 60)/x ) + assert rational_interpolate(list(zip(xdata, ydata2)), 2, X=y) == ( + (105*y**2 - 525)/(y + 1) ) + xdata = list(range(1,11)) + ydata = [-1923885361858460, -5212158811973685, -9838050145867125, + -15662936261217245, -22469424125057910, -30073793365223685, + -38332297297028735, -47132954289530109, -56387719094026320, + -66026548943876885] + assert rational_interpolate(list(zip(xdata, ydata)), 5) == ( + (-12986226192544605*x**4 + + 8657484128363070*x**3 - 30301194449270745*x**2 + 4328742064181535*x + - 4328742064181535)/(x**3 + 9*x**2 - 3*x + 11)) + + +def test_viete(): + r1, r2 = symbols('r1, r2') + + assert viete( + a*x**2 + b*x + c, [r1, r2], x) == [(r1 + r2, -b/a), (r1*r2, c/a)] + + raises(ValueError, lambda: viete(1, [], x)) + raises(ValueError, lambda: viete(x**2 + 1, [r1])) + + raises(MultivariatePolynomialError, lambda: viete(x + y, [r1])) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polymatrix.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polymatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..287f23d537392510acda094e764a8c3dbbd1ef73 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polymatrix.py @@ -0,0 +1,185 @@ +from sympy.testing.pytest import raises + +from sympy.polys.polymatrix import PolyMatrix +from sympy.polys import Poly + +from sympy.core.singleton import S +from sympy.matrices.dense import Matrix +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ + +from sympy.abc import x, y + + +def _test_polymatrix(): + pm1 = PolyMatrix([[Poly(x**2, x), Poly(-x, x)], [Poly(x**3, x), Poly(-1 + x, x)]]) + v1 = PolyMatrix([[1, 0], [-1, 0]], ring='ZZ[x]') + m1 = PolyMatrix([[1, 0], [-1, 0]], ring='ZZ[x]') + A = PolyMatrix([[Poly(x**2 + x, x), Poly(0, x)], \ + [Poly(x**3 - x + 1, x), Poly(0, x)]]) + B = PolyMatrix([[Poly(x**2, x), Poly(-x, x)], [Poly(-x**2, x), Poly(x, x)]]) + assert A.ring == ZZ[x] + assert isinstance(pm1*v1, PolyMatrix) + assert pm1*v1 == A + assert pm1*m1 == A + assert v1*pm1 == B + + pm2 = PolyMatrix([[Poly(x**2, x, domain='QQ'), Poly(0, x, domain='QQ'), Poly(-x**2, x, domain='QQ'), \ + Poly(x**3, x, domain='QQ'), Poly(0, x, domain='QQ'), Poly(-x**3, x, domain='QQ')]]) + assert pm2.ring == QQ[x] + v2 = PolyMatrix([1, 0, 0, 0, 0, 0], ring='ZZ[x]') + m2 = PolyMatrix([1, 0, 0, 0, 0, 0], ring='ZZ[x]') + C = PolyMatrix([[Poly(x**2, x, domain='QQ')]]) + assert pm2*v2 == C + assert pm2*m2 == C + + pm3 = PolyMatrix([[Poly(x**2, x), S.One]], ring='ZZ[x]') + v3 = S.Half*pm3 + assert v3 == PolyMatrix([[Poly(S.Half*x**2, x, domain='QQ'), S.Half]], ring='QQ[x]') + assert pm3*S.Half == v3 + assert v3.ring == QQ[x] + + pm4 = PolyMatrix([[Poly(x**2, x, domain='ZZ'), Poly(-x**2, x, domain='ZZ')]]) + v4 = PolyMatrix([1, -1], ring='ZZ[x]') + assert pm4*v4 == PolyMatrix([[Poly(2*x**2, x, domain='ZZ')]]) + + assert len(PolyMatrix(ring=ZZ[x])) == 0 + assert PolyMatrix([1, 0, 0, 1], x)/(-1) == PolyMatrix([-1, 0, 0, -1], x) + + +def test_polymatrix_constructor(): + M1 = PolyMatrix([[x, y]], ring=QQ[x,y]) + assert M1.ring == QQ[x,y] + assert M1.domain == QQ + assert M1.gens == (x, y) + assert M1.shape == (1, 2) + assert M1.rows == 1 + assert M1.cols == 2 + assert len(M1) == 2 + assert list(M1) == [Poly(x, (x, y), domain=QQ), Poly(y, (x, y), domain=QQ)] + + M2 = PolyMatrix([[x, y]], ring=QQ[x][y]) + assert M2.ring == QQ[x][y] + assert M2.domain == QQ[x] + assert M2.gens == (y,) + assert M2.shape == (1, 2) + assert M2.rows == 1 + assert M2.cols == 2 + assert len(M2) == 2 + assert list(M2) == [Poly(x, (y,), domain=QQ[x]), Poly(y, (y,), domain=QQ[x])] + + assert PolyMatrix([[x, y]], y) == PolyMatrix([[x, y]], ring=ZZ.frac_field(x)[y]) + assert PolyMatrix([[x, y]], ring='ZZ[x,y]') == PolyMatrix([[x, y]], ring=ZZ[x,y]) + + assert PolyMatrix([[x, y]], (x, y)) == PolyMatrix([[x, y]], ring=QQ[x,y]) + assert PolyMatrix([[x, y]], x, y) == PolyMatrix([[x, y]], ring=QQ[x,y]) + assert PolyMatrix([x, y]) == PolyMatrix([[x], [y]], ring=QQ[x,y]) + assert PolyMatrix(1, 2, [x, y]) == PolyMatrix([[x, y]], ring=QQ[x,y]) + assert PolyMatrix(1, 2, lambda i,j: [x,y][j]) == PolyMatrix([[x, y]], ring=QQ[x,y]) + assert PolyMatrix(0, 2, [], x, y).shape == (0, 2) + assert PolyMatrix(2, 0, [], x, y).shape == (2, 0) + assert PolyMatrix([[], []], x, y).shape == (2, 0) + assert PolyMatrix(ring=QQ[x,y]) == PolyMatrix(0, 0, [], ring=QQ[x,y]) == PolyMatrix([], ring=QQ[x,y]) + raises(TypeError, lambda: PolyMatrix()) + raises(TypeError, lambda: PolyMatrix(1)) + + assert PolyMatrix([Poly(x), Poly(y)]) == PolyMatrix([[x], [y]], ring=ZZ[x,y]) + + # XXX: Maybe a bug in parallel_poly_from_expr (x lost from gens and domain): + assert PolyMatrix([Poly(y, x), 1]) == PolyMatrix([[y], [1]], ring=QQ[y]) + + +def test_polymatrix_eq(): + assert (PolyMatrix([x]) == PolyMatrix([x])) is True + assert (PolyMatrix([y]) == PolyMatrix([x])) is False + assert (PolyMatrix([x]) != PolyMatrix([x])) is False + assert (PolyMatrix([y]) != PolyMatrix([x])) is True + + assert PolyMatrix([[x, y]]) != PolyMatrix([x, y]) == PolyMatrix([[x], [y]]) + + assert PolyMatrix([x], ring=QQ[x]) != PolyMatrix([x], ring=ZZ[x]) + + assert PolyMatrix([x]) != Matrix([x]) + assert PolyMatrix([x]).to_Matrix() == Matrix([x]) + + assert PolyMatrix([1], x) == PolyMatrix([1], x) + assert PolyMatrix([1], x) != PolyMatrix([1], y) + + +def test_polymatrix_from_Matrix(): + assert PolyMatrix.from_Matrix(Matrix([1, 2]), x) == PolyMatrix([1, 2], x, ring=QQ[x]) + assert PolyMatrix.from_Matrix(Matrix([1]), ring=QQ[x]) == PolyMatrix([1], x) + pmx = PolyMatrix([1, 2], x) + pmy = PolyMatrix([1, 2], y) + assert pmx != pmy + assert pmx.set_gens(y) == pmy + + +def test_polymatrix_repr(): + assert repr(PolyMatrix([[1, 2]], x)) == 'PolyMatrix([[1, 2]], ring=QQ[x])' + assert repr(PolyMatrix(0, 2, [], x)) == 'PolyMatrix(0, 2, [], ring=QQ[x])' + + +def test_polymatrix_getitem(): + M = PolyMatrix([[1, 2], [3, 4]], x) + assert M[:, :] == M + assert M[0, :] == PolyMatrix([[1, 2]], x) + assert M[:, 0] == PolyMatrix([1, 3], x) + assert M[0, 0] == Poly(1, x, domain=QQ) + assert M[0] == Poly(1, x, domain=QQ) + assert M[:2] == [Poly(1, x, domain=QQ), Poly(2, x, domain=QQ)] + + +def test_polymatrix_arithmetic(): + M = PolyMatrix([[1, 2], [3, 4]], x) + assert M + M == PolyMatrix([[2, 4], [6, 8]], x) + assert M - M == PolyMatrix([[0, 0], [0, 0]], x) + assert -M == PolyMatrix([[-1, -2], [-3, -4]], x) + raises(TypeError, lambda: M + 1) + raises(TypeError, lambda: M - 1) + raises(TypeError, lambda: 1 + M) + raises(TypeError, lambda: 1 - M) + + assert M * M == PolyMatrix([[7, 10], [15, 22]], x) + assert 2 * M == PolyMatrix([[2, 4], [6, 8]], x) + assert M * 2 == PolyMatrix([[2, 4], [6, 8]], x) + assert S(2) * M == PolyMatrix([[2, 4], [6, 8]], x) + assert M * S(2) == PolyMatrix([[2, 4], [6, 8]], x) + raises(TypeError, lambda: [] * M) + raises(TypeError, lambda: M * []) + M2 = PolyMatrix([[1, 2]], ring=ZZ[x]) + assert S.Half * M2 == PolyMatrix([[S.Half, 1]], ring=QQ[x]) + assert M2 * S.Half == PolyMatrix([[S.Half, 1]], ring=QQ[x]) + + assert M / 2 == PolyMatrix([[S(1)/2, 1], [S(3)/2, 2]], x) + assert M / Poly(2, x) == PolyMatrix([[S(1)/2, 1], [S(3)/2, 2]], x) + raises(TypeError, lambda: M / []) + + +def test_polymatrix_manipulations(): + M1 = PolyMatrix([[1, 2], [3, 4]], x) + assert M1.transpose() == PolyMatrix([[1, 3], [2, 4]], x) + M2 = PolyMatrix([[5, 6], [7, 8]], x) + assert M1.row_join(M2) == PolyMatrix([[1, 2, 5, 6], [3, 4, 7, 8]], x) + assert M1.col_join(M2) == PolyMatrix([[1, 2], [3, 4], [5, 6], [7, 8]], x) + assert M1.applyfunc(lambda e: 2*e) == PolyMatrix([[2, 4], [6, 8]], x) + + +def test_polymatrix_ones_zeros(): + assert PolyMatrix.zeros(1, 2, x) == PolyMatrix([[0, 0]], x) + assert PolyMatrix.eye(2, x) == PolyMatrix([[1, 0], [0, 1]], x) + + +def test_polymatrix_rref(): + M = PolyMatrix([[1, 2], [3, 4]], x) + assert M.rref() == (PolyMatrix.eye(2, x), (0, 1)) + raises(ValueError, lambda: PolyMatrix([1, 2], ring=ZZ[x]).rref()) + raises(ValueError, lambda: PolyMatrix([1, x], ring=QQ[x]).rref()) + + +def test_polymatrix_nullspace(): + M = PolyMatrix([[1, 2], [3, 6]], x) + assert M.nullspace() == [PolyMatrix([-2, 1], x)] + raises(ValueError, lambda: PolyMatrix([1, 2], ring=ZZ[x]).nullspace()) + raises(ValueError, lambda: PolyMatrix([1, x], ring=QQ[x]).nullspace()) + assert M.rank() == 1 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polyoptions.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyoptions.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2e6054bad43aef5470949180ea5c2ffdc11f30 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyoptions.py @@ -0,0 +1,485 @@ +"""Tests for options manager for :class:`Poly` and public API functions. """ + +from sympy.polys.polyoptions import ( + Options, Expand, Gens, Wrt, Sort, Order, Field, Greedy, Domain, + Split, Gaussian, Extension, Modulus, Symmetric, Strict, Auto, + Frac, Formal, Polys, Include, All, Gen, Symbols, Method) + +from sympy.polys.orderings import lex +from sympy.polys.domains import FF, GF, ZZ, QQ, QQ_I, RR, CC, EX + +from sympy.polys.polyerrors import OptionError, GeneratorsError + +from sympy.core.numbers import (I, Integer) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.testing.pytest import raises +from sympy.abc import x, y, z + + +def test_Options_clone(): + opt = Options((x, y, z), {'domain': 'ZZ'}) + + assert opt.gens == (x, y, z) + assert opt.domain == ZZ + assert ('order' in opt) is False + + new_opt = opt.clone({'gens': (x, y), 'order': 'lex'}) + + assert opt.gens == (x, y, z) + assert opt.domain == ZZ + assert ('order' in opt) is False + + assert new_opt.gens == (x, y) + assert new_opt.domain == ZZ + assert ('order' in new_opt) is True + + +def test_Expand_preprocess(): + assert Expand.preprocess(False) is False + assert Expand.preprocess(True) is True + + assert Expand.preprocess(0) is False + assert Expand.preprocess(1) is True + + raises(OptionError, lambda: Expand.preprocess(x)) + + +def test_Expand_postprocess(): + opt = {'expand': True} + Expand.postprocess(opt) + + assert opt == {'expand': True} + + +def test_Gens_preprocess(): + assert Gens.preprocess((None,)) == () + assert Gens.preprocess((x, y, z)) == (x, y, z) + assert Gens.preprocess(((x, y, z),)) == (x, y, z) + + a = Symbol('a', commutative=False) + + raises(GeneratorsError, lambda: Gens.preprocess((x, x, y))) + raises(GeneratorsError, lambda: Gens.preprocess((x, y, a))) + + +def test_Gens_postprocess(): + opt = {'gens': (x, y)} + Gens.postprocess(opt) + + assert opt == {'gens': (x, y)} + + +def test_Wrt_preprocess(): + assert Wrt.preprocess(x) == ['x'] + assert Wrt.preprocess('') == [] + assert Wrt.preprocess(' ') == [] + assert Wrt.preprocess('x,y') == ['x', 'y'] + assert Wrt.preprocess('x y') == ['x', 'y'] + assert Wrt.preprocess('x, y') == ['x', 'y'] + assert Wrt.preprocess('x , y') == ['x', 'y'] + assert Wrt.preprocess(' x, y') == ['x', 'y'] + assert Wrt.preprocess(' x, y') == ['x', 'y'] + assert Wrt.preprocess([x, y]) == ['x', 'y'] + + raises(OptionError, lambda: Wrt.preprocess(',')) + raises(OptionError, lambda: Wrt.preprocess(0)) + + +def test_Wrt_postprocess(): + opt = {'wrt': ['x']} + Wrt.postprocess(opt) + + assert opt == {'wrt': ['x']} + + +def test_Sort_preprocess(): + assert Sort.preprocess([x, y, z]) == ['x', 'y', 'z'] + assert Sort.preprocess((x, y, z)) == ['x', 'y', 'z'] + + assert Sort.preprocess('x > y > z') == ['x', 'y', 'z'] + assert Sort.preprocess('x>y>z') == ['x', 'y', 'z'] + + raises(OptionError, lambda: Sort.preprocess(0)) + raises(OptionError, lambda: Sort.preprocess({x, y, z})) + + +def test_Sort_postprocess(): + opt = {'sort': 'x > y'} + Sort.postprocess(opt) + + assert opt == {'sort': 'x > y'} + + +def test_Order_preprocess(): + assert Order.preprocess('lex') == lex + + +def test_Order_postprocess(): + opt = {'order': True} + Order.postprocess(opt) + + assert opt == {'order': True} + + +def test_Field_preprocess(): + assert Field.preprocess(False) is False + assert Field.preprocess(True) is True + + assert Field.preprocess(0) is False + assert Field.preprocess(1) is True + + raises(OptionError, lambda: Field.preprocess(x)) + + +def test_Field_postprocess(): + opt = {'field': True} + Field.postprocess(opt) + + assert opt == {'field': True} + + +def test_Greedy_preprocess(): + assert Greedy.preprocess(False) is False + assert Greedy.preprocess(True) is True + + assert Greedy.preprocess(0) is False + assert Greedy.preprocess(1) is True + + raises(OptionError, lambda: Greedy.preprocess(x)) + + +def test_Greedy_postprocess(): + opt = {'greedy': True} + Greedy.postprocess(opt) + + assert opt == {'greedy': True} + + +def test_Domain_preprocess(): + assert Domain.preprocess(ZZ) == ZZ + assert Domain.preprocess(QQ) == QQ + assert Domain.preprocess(EX) == EX + assert Domain.preprocess(FF(2)) == FF(2) + assert Domain.preprocess(ZZ[x, y]) == ZZ[x, y] + + assert Domain.preprocess('Z') == ZZ + assert Domain.preprocess('Q') == QQ + + assert Domain.preprocess('ZZ') == ZZ + assert Domain.preprocess('QQ') == QQ + + assert Domain.preprocess('EX') == EX + + assert Domain.preprocess('FF(23)') == FF(23) + assert Domain.preprocess('GF(23)') == GF(23) + + raises(OptionError, lambda: Domain.preprocess('Z[]')) + + assert Domain.preprocess('Z[x]') == ZZ[x] + assert Domain.preprocess('Q[x]') == QQ[x] + assert Domain.preprocess('R[x]') == RR[x] + assert Domain.preprocess('C[x]') == CC[x] + + assert Domain.preprocess('ZZ[x]') == ZZ[x] + assert Domain.preprocess('QQ[x]') == QQ[x] + assert Domain.preprocess('RR[x]') == RR[x] + assert Domain.preprocess('CC[x]') == CC[x] + + assert Domain.preprocess('Z[x,y]') == ZZ[x, y] + assert Domain.preprocess('Q[x,y]') == QQ[x, y] + assert Domain.preprocess('R[x,y]') == RR[x, y] + assert Domain.preprocess('C[x,y]') == CC[x, y] + + assert Domain.preprocess('ZZ[x,y]') == ZZ[x, y] + assert Domain.preprocess('QQ[x,y]') == QQ[x, y] + assert Domain.preprocess('RR[x,y]') == RR[x, y] + assert Domain.preprocess('CC[x,y]') == CC[x, y] + + raises(OptionError, lambda: Domain.preprocess('Z()')) + + assert Domain.preprocess('Z(x)') == ZZ.frac_field(x) + assert Domain.preprocess('Q(x)') == QQ.frac_field(x) + + assert Domain.preprocess('ZZ(x)') == ZZ.frac_field(x) + assert Domain.preprocess('QQ(x)') == QQ.frac_field(x) + + assert Domain.preprocess('Z(x,y)') == ZZ.frac_field(x, y) + assert Domain.preprocess('Q(x,y)') == QQ.frac_field(x, y) + + assert Domain.preprocess('ZZ(x,y)') == ZZ.frac_field(x, y) + assert Domain.preprocess('QQ(x,y)') == QQ.frac_field(x, y) + + assert Domain.preprocess('Q') == QQ.algebraic_field(I) + assert Domain.preprocess('QQ') == QQ.algebraic_field(I) + + assert Domain.preprocess('Q') == QQ.algebraic_field(sqrt(2), I) + assert Domain.preprocess( + 'QQ') == QQ.algebraic_field(sqrt(2), I) + + raises(OptionError, lambda: Domain.preprocess('abc')) + + +def test_Domain_postprocess(): + raises(GeneratorsError, lambda: Domain.postprocess({'gens': (x, y), + 'domain': ZZ[y, z]})) + + raises(GeneratorsError, lambda: Domain.postprocess({'gens': (), + 'domain': EX})) + raises(GeneratorsError, lambda: Domain.postprocess({'domain': EX})) + + +def test_Split_preprocess(): + assert Split.preprocess(False) is False + assert Split.preprocess(True) is True + + assert Split.preprocess(0) is False + assert Split.preprocess(1) is True + + raises(OptionError, lambda: Split.preprocess(x)) + + +def test_Split_postprocess(): + raises(NotImplementedError, lambda: Split.postprocess({'split': True})) + + +def test_Gaussian_preprocess(): + assert Gaussian.preprocess(False) is False + assert Gaussian.preprocess(True) is True + + assert Gaussian.preprocess(0) is False + assert Gaussian.preprocess(1) is True + + raises(OptionError, lambda: Gaussian.preprocess(x)) + + +def test_Gaussian_postprocess(): + opt = {'gaussian': True} + Gaussian.postprocess(opt) + + assert opt == { + 'gaussian': True, + 'domain': QQ_I, + } + + +def test_Extension_preprocess(): + assert Extension.preprocess(True) is True + assert Extension.preprocess(1) is True + + assert Extension.preprocess([]) is None + + assert Extension.preprocess(sqrt(2)) == {sqrt(2)} + assert Extension.preprocess([sqrt(2)]) == {sqrt(2)} + + assert Extension.preprocess([sqrt(2), I]) == {sqrt(2), I} + + raises(OptionError, lambda: Extension.preprocess(False)) + raises(OptionError, lambda: Extension.preprocess(0)) + + +def test_Extension_postprocess(): + opt = {'extension': {sqrt(2)}} + Extension.postprocess(opt) + + assert opt == { + 'extension': {sqrt(2)}, + 'domain': QQ.algebraic_field(sqrt(2)), + } + + opt = {'extension': True} + Extension.postprocess(opt) + + assert opt == {'extension': True} + + +def test_Modulus_preprocess(): + assert Modulus.preprocess(23) == 23 + assert Modulus.preprocess(Integer(23)) == 23 + + raises(OptionError, lambda: Modulus.preprocess(0)) + raises(OptionError, lambda: Modulus.preprocess(x)) + + +def test_Modulus_postprocess(): + opt = {'modulus': 5} + Modulus.postprocess(opt) + + assert opt == { + 'modulus': 5, + 'domain': FF(5), + } + + opt = {'modulus': 5, 'symmetric': False} + Modulus.postprocess(opt) + + assert opt == { + 'modulus': 5, + 'domain': FF(5, False), + 'symmetric': False, + } + + +def test_Symmetric_preprocess(): + assert Symmetric.preprocess(False) is False + assert Symmetric.preprocess(True) is True + + assert Symmetric.preprocess(0) is False + assert Symmetric.preprocess(1) is True + + raises(OptionError, lambda: Symmetric.preprocess(x)) + + +def test_Symmetric_postprocess(): + opt = {'symmetric': True} + Symmetric.postprocess(opt) + + assert opt == {'symmetric': True} + + +def test_Strict_preprocess(): + assert Strict.preprocess(False) is False + assert Strict.preprocess(True) is True + + assert Strict.preprocess(0) is False + assert Strict.preprocess(1) is True + + raises(OptionError, lambda: Strict.preprocess(x)) + + +def test_Strict_postprocess(): + opt = {'strict': True} + Strict.postprocess(opt) + + assert opt == {'strict': True} + + +def test_Auto_preprocess(): + assert Auto.preprocess(False) is False + assert Auto.preprocess(True) is True + + assert Auto.preprocess(0) is False + assert Auto.preprocess(1) is True + + raises(OptionError, lambda: Auto.preprocess(x)) + + +def test_Auto_postprocess(): + opt = {'auto': True} + Auto.postprocess(opt) + + assert opt == {'auto': True} + + +def test_Frac_preprocess(): + assert Frac.preprocess(False) is False + assert Frac.preprocess(True) is True + + assert Frac.preprocess(0) is False + assert Frac.preprocess(1) is True + + raises(OptionError, lambda: Frac.preprocess(x)) + + +def test_Frac_postprocess(): + opt = {'frac': True} + Frac.postprocess(opt) + + assert opt == {'frac': True} + + +def test_Formal_preprocess(): + assert Formal.preprocess(False) is False + assert Formal.preprocess(True) is True + + assert Formal.preprocess(0) is False + assert Formal.preprocess(1) is True + + raises(OptionError, lambda: Formal.preprocess(x)) + + +def test_Formal_postprocess(): + opt = {'formal': True} + Formal.postprocess(opt) + + assert opt == {'formal': True} + + +def test_Polys_preprocess(): + assert Polys.preprocess(False) is False + assert Polys.preprocess(True) is True + + assert Polys.preprocess(0) is False + assert Polys.preprocess(1) is True + + raises(OptionError, lambda: Polys.preprocess(x)) + + +def test_Polys_postprocess(): + opt = {'polys': True} + Polys.postprocess(opt) + + assert opt == {'polys': True} + + +def test_Include_preprocess(): + assert Include.preprocess(False) is False + assert Include.preprocess(True) is True + + assert Include.preprocess(0) is False + assert Include.preprocess(1) is True + + raises(OptionError, lambda: Include.preprocess(x)) + + +def test_Include_postprocess(): + opt = {'include': True} + Include.postprocess(opt) + + assert opt == {'include': True} + + +def test_All_preprocess(): + assert All.preprocess(False) is False + assert All.preprocess(True) is True + + assert All.preprocess(0) is False + assert All.preprocess(1) is True + + raises(OptionError, lambda: All.preprocess(x)) + + +def test_All_postprocess(): + opt = {'all': True} + All.postprocess(opt) + + assert opt == {'all': True} + + +def test_Gen_postprocess(): + opt = {'gen': x} + Gen.postprocess(opt) + + assert opt == {'gen': x} + + +def test_Symbols_preprocess(): + raises(OptionError, lambda: Symbols.preprocess(x)) + + +def test_Symbols_postprocess(): + opt = {'symbols': [x, y, z]} + Symbols.postprocess(opt) + + assert opt == {'symbols': [x, y, z]} + + +def test_Method_preprocess(): + raises(OptionError, lambda: Method.preprocess(10)) + + +def test_Method_postprocess(): + opt = {'method': 'f5b'} + Method.postprocess(opt) + + assert opt == {'method': 'f5b'} diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polyroots.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyroots.py new file mode 100644 index 0000000000000000000000000000000000000000..7f96b1930f6789ce3150ae2c920ba7d9faa68791 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyroots.py @@ -0,0 +1,758 @@ +"""Tests for algorithms for computing symbolic roots of polynomials. """ + +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.elementary.complexes import (conjugate, im, re) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.polys.domains.integerring import ZZ +from sympy.sets.sets import Interval +from sympy.simplify.powsimp import powsimp + +from sympy.polys import Poly, cyclotomic_poly, intervals, nroots, rootof + +from sympy.polys.polyroots import (root_factors, roots_linear, + roots_quadratic, roots_cubic, roots_quartic, roots_quintic, + roots_cyclotomic, roots_binomial, preprocess_roots, roots) + +from sympy.polys.orthopolys import legendre_poly +from sympy.polys.polyerrors import PolynomialError, \ + UnsolvableFactorError +from sympy.polys.polyutils import _nsort + +from sympy.testing.pytest import raises, slow +from sympy.core.random import verify_numerically +import mpmath +from itertools import product + + + +a, b, c, d, e, q, t, x, y, z = symbols('a,b,c,d,e,q,t,x,y,z') + + +def _check(roots): + # this is the desired invariant for roots returned + # by all_roots. It is trivially true for linear + # polynomials. + nreal = sum(1 if i.is_real else 0 for i in roots) + assert sorted(roots[:nreal]) == list(roots[:nreal]) + for ix in range(nreal, len(roots), 2): + if not ( + roots[ix + 1] == roots[ix] or + roots[ix + 1] == conjugate(roots[ix])): + return False + return True + + +def test_roots_linear(): + assert roots_linear(Poly(2*x + 1, x)) == [Rational(-1, 2)] + + +def test_roots_quadratic(): + assert roots_quadratic(Poly(2*x**2, x)) == [0, 0] + assert roots_quadratic(Poly(2*x**2 + 3*x, x)) == [Rational(-3, 2), 0] + assert roots_quadratic(Poly(2*x**2 + 3, x)) == [-I*sqrt(6)/2, I*sqrt(6)/2] + assert roots_quadratic(Poly(2*x**2 + 4*x + 3, x)) == [-1 - I*sqrt(2)/2, -1 + I*sqrt(2)/2] + _check(Poly(2*x**2 + 4*x + 3, x).all_roots()) + + f = x**2 + (2*a*e + 2*c*e)/(a - c)*x + (d - b + a*e**2 - c*e**2)/(a - c) + assert roots_quadratic(Poly(f, x)) == \ + [-e*(a + c)/(a - c) - sqrt(a*b + c*d - a*d - b*c + 4*a*c*e**2)/(a - c), + -e*(a + c)/(a - c) + sqrt(a*b + c*d - a*d - b*c + 4*a*c*e**2)/(a - c)] + + # check for simplification + f = Poly(y*x**2 - 2*x - 2*y, x) + assert roots_quadratic(f) == \ + [-sqrt(2*y**2 + 1)/y + 1/y, sqrt(2*y**2 + 1)/y + 1/y] + f = Poly(x**2 + (-y**2 - 2)*x + y**2 + 1, x) + assert roots_quadratic(f) == \ + [1,y**2 + 1] + + f = Poly(sqrt(2)*x**2 - 1, x) + r = roots_quadratic(f) + assert r == _nsort(r) + + # issue 8255 + f = Poly(-24*x**2 - 180*x + 264) + assert [w.n(2) for w in f.all_roots(radicals=True)] == \ + [w.n(2) for w in f.all_roots(radicals=False)] + for _a, _b, _c in product((-2, 2), (-2, 2), (0, -1)): + f = Poly(_a*x**2 + _b*x + _c) + roots = roots_quadratic(f) + assert roots == _nsort(roots) + + +def test_issue_7724(): + eq = Poly(x**4*I + x**2 + I, x) + assert roots(eq) == { + sqrt(I/2 + sqrt(5)*I/2): 1, + sqrt(-sqrt(5)*I/2 + I/2): 1, + -sqrt(I/2 + sqrt(5)*I/2): 1, + -sqrt(-sqrt(5)*I/2 + I/2): 1} + + +def test_issue_8438(): + p = Poly([1, y, -2, -3], x).as_expr() + roots = roots_cubic(Poly(p, x), x) + z = Rational(-3, 2) - I*7/2 # this will fail in code given in commit msg + post = [r.subs(y, z) for r in roots] + assert set(post) == \ + set(roots_cubic(Poly(p.subs(y, z), x))) + # /!\ if p is not made an expression, this is *very* slow + assert all(p.subs({y: z, x: i}).n(2, chop=True) == 0 for i in post) + + +def test_issue_8285(): + roots = (Poly(4*x**8 - 1, x)*Poly(x**2 + 1)).all_roots() + assert _check(roots) + f = Poly(x**4 + 5*x**2 + 6, x) + ro = [rootof(f, i) for i in range(4)] + roots = Poly(x**4 + 5*x**2 + 6, x).all_roots() + assert roots == ro + assert _check(roots) + # more than 2 complex roots from which to identify the + # imaginary ones + roots = Poly(2*x**8 - 1).all_roots() + assert _check(roots) + assert len(Poly(2*x**10 - 1).all_roots()) == 10 # doesn't fail + + +def test_issue_8289(): + roots = (Poly(x**2 + 2)*Poly(x**4 + 2)).all_roots() + assert _check(roots) + roots = Poly(x**6 + 3*x**3 + 2, x).all_roots() + assert _check(roots) + roots = Poly(x**6 - x + 1).all_roots() + assert _check(roots) + # all imaginary roots with multiplicity of 2 + roots = Poly(x**4 + 4*x**2 + 4, x).all_roots() + assert _check(roots) + + +def test_issue_14291(): + assert Poly(((x - 1)**2 + 1)*((x - 1)**2 + 2)*(x - 1) + ).all_roots() == [1, 1 - I, 1 + I, 1 - sqrt(2)*I, 1 + sqrt(2)*I] + p = x**4 + 10*x**2 + 1 + ans = [rootof(p, i) for i in range(4)] + assert Poly(p).all_roots() == ans + _check(ans) + + +def test_issue_13340(): + eq = Poly(y**3 + exp(x)*y + x, y, domain='EX') + roots_d = roots(eq) + assert len(roots_d) == 3 + + +def test_issue_14522(): + eq = Poly(x**4 + x**3*(16 + 32*I) + x**2*(-285 + 386*I) + x*(-2824 - 448*I) - 2058 - 6053*I, x) + roots_eq = roots(eq) + assert all(eq(r) == 0 for r in roots_eq) + + +def test_issue_15076(): + sol = roots_quartic(Poly(t**4 - 6*t**2 + t/x - 3, t)) + assert sol[0].has(x) + + +def test_issue_16589(): + eq = Poly(x**4 - 8*sqrt(2)*x**3 + 4*x**3 - 64*sqrt(2)*x**2 + 1024*x, x) + roots_eq = roots(eq) + assert 0 in roots_eq + + +def test_roots_cubic(): + assert roots_cubic(Poly(2*x**3, x)) == [0, 0, 0] + assert roots_cubic(Poly(x**3 - 3*x**2 + 3*x - 1, x)) == [1, 1, 1] + + # valid for arbitrary y (issue 21263) + r = root(y, 3) + assert roots_cubic(Poly(x**3 - y, x)) == [r, + r*(-S.Half + sqrt(3)*I/2), + r*(-S.Half - sqrt(3)*I/2)] + # simpler form when y is negative + assert roots_cubic(Poly(x**3 - -1, x)) == \ + [-1, S.Half - I*sqrt(3)/2, S.Half + I*sqrt(3)/2] + assert roots_cubic(Poly(2*x**3 - 3*x**2 - 3*x - 1, x))[0] == \ + S.Half + 3**Rational(1, 3)/2 + 3**Rational(2, 3)/2 + eq = -x**3 + 2*x**2 + 3*x - 2 + assert roots(eq, trig=True, multiple=True) == \ + roots_cubic(Poly(eq, x), trig=True) == [ + Rational(2, 3) + 2*sqrt(13)*cos(acos(8*sqrt(13)/169)/3)/3, + -2*sqrt(13)*sin(-acos(8*sqrt(13)/169)/3 + pi/6)/3 + Rational(2, 3), + -2*sqrt(13)*cos(-acos(8*sqrt(13)/169)/3 + pi/3)/3 + Rational(2, 3), + ] + + +def test_roots_quartic(): + assert roots_quartic(Poly(x**4, x)) == [0, 0, 0, 0] + assert roots_quartic(Poly(x**4 + x**3, x)) in [ + [-1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, -1] + ] + assert roots_quartic(Poly(x**4 - x**3, x)) in [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1] + ] + + lhs = roots_quartic(Poly(x**4 + x, x)) + rhs = [S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2, S.Zero, -S.One] + + assert sorted(lhs, key=hash) == sorted(rhs, key=hash) + + # test of all branches of roots quartic + for i, (a, b, c, d) in enumerate([(1, 2, 3, 0), + (3, -7, -9, 9), + (1, 2, 3, 4), + (1, 2, 3, 4), + (-7, -3, 3, -6), + (-3, 5, -6, -4), + (6, -5, -10, -3)]): + if i == 2: + c = -a*(a**2/S(8) - b/S(2)) + elif i == 3: + d = a*(a*(a**2*Rational(3, 256) - b/S(16)) + c/S(4)) + eq = x**4 + a*x**3 + b*x**2 + c*x + d + ans = roots_quartic(Poly(eq, x)) + assert all(eq.subs(x, ai).n(chop=True) == 0 for ai in ans) + + # not all symbolic quartics are unresolvable + eq = Poly(q*x + q/4 + x**4 + x**3 + 2*x**2 - Rational(1, 3), x) + sol = roots_quartic(eq) + assert all(verify_numerically(eq.subs(x, i), 0) for i in sol) + z = symbols('z', negative=True) + eq = x**4 + 2*x**3 + 3*x**2 + x*(z + 11) + 5 + zans = roots_quartic(Poly(eq, x)) + assert all(verify_numerically(eq.subs(((x, i), (z, -1))), 0) for i in zans) + # but some are (see also issue 4989) + # it's ok if the solution is not Piecewise, but the tests below should pass + eq = Poly(y*x**4 + x**3 - x + z, x) + ans = roots_quartic(eq) + assert all(type(i) == Piecewise for i in ans) + reps = ( + {"y": Rational(-1, 3), "z": Rational(-1, 4)}, # 4 real + {"y": Rational(-1, 3), "z": Rational(-1, 2)}, # 2 real + {"y": Rational(-1, 3), "z": -2}) # 0 real + for rep in reps: + sol = roots_quartic(Poly(eq.subs(rep), x)) + assert all(verify_numerically(w.subs(rep) - s, 0) for w, s in zip(ans, sol)) + + +def test_issue_21287(): + assert not any(isinstance(i, Piecewise) for i in roots_quartic( + Poly(x**4 - x**2*(3 + 5*I) + 2*x*(-1 + I) - 1 + 3*I, x))) + + +def test_roots_quintic(): + eqs = (x**5 - 2, + (x/2 + 1)**5 - 5*(x/2 + 1) + 12, + x**5 - 110*x**3 - 55*x**2 + 2310*x + 979) + for eq in eqs: + roots = roots_quintic(Poly(eq)) + assert len(roots) == 5 + assert all(eq.subs(x, r.n(10)).n(chop = 1e-5) == 0 for r in roots) + + +def test_roots_cyclotomic(): + assert roots_cyclotomic(cyclotomic_poly(1, x, polys=True)) == [1] + assert roots_cyclotomic(cyclotomic_poly(2, x, polys=True)) == [-1] + assert roots_cyclotomic(cyclotomic_poly( + 3, x, polys=True)) == [Rational(-1, 2) - I*sqrt(3)/2, Rational(-1, 2) + I*sqrt(3)/2] + assert roots_cyclotomic(cyclotomic_poly(4, x, polys=True)) == [-I, I] + assert roots_cyclotomic(cyclotomic_poly( + 6, x, polys=True)) == [S.Half - I*sqrt(3)/2, S.Half + I*sqrt(3)/2] + + assert roots_cyclotomic(cyclotomic_poly(7, x, polys=True)) == [ + -cos(pi/7) - I*sin(pi/7), + -cos(pi/7) + I*sin(pi/7), + -cos(pi*Rational(3, 7)) - I*sin(pi*Rational(3, 7)), + -cos(pi*Rational(3, 7)) + I*sin(pi*Rational(3, 7)), + cos(pi*Rational(2, 7)) - I*sin(pi*Rational(2, 7)), + cos(pi*Rational(2, 7)) + I*sin(pi*Rational(2, 7)), + ] + + assert roots_cyclotomic(cyclotomic_poly(8, x, polys=True)) == [ + -sqrt(2)/2 - I*sqrt(2)/2, + -sqrt(2)/2 + I*sqrt(2)/2, + sqrt(2)/2 - I*sqrt(2)/2, + sqrt(2)/2 + I*sqrt(2)/2, + ] + + assert roots_cyclotomic(cyclotomic_poly(12, x, polys=True)) == [ + -sqrt(3)/2 - I/2, + -sqrt(3)/2 + I/2, + sqrt(3)/2 - I/2, + sqrt(3)/2 + I/2, + ] + + assert roots_cyclotomic( + cyclotomic_poly(1, x, polys=True), factor=True) == [1] + assert roots_cyclotomic( + cyclotomic_poly(2, x, polys=True), factor=True) == [-1] + + assert roots_cyclotomic(cyclotomic_poly(3, x, polys=True), factor=True) == \ + [-root(-1, 3), -1 + root(-1, 3)] + assert roots_cyclotomic(cyclotomic_poly(4, x, polys=True), factor=True) == \ + [-I, I] + assert roots_cyclotomic(cyclotomic_poly(5, x, polys=True), factor=True) == \ + [-root(-1, 5), -root(-1, 5)**3, root(-1, 5)**2, -1 - root(-1, 5)**2 + root(-1, 5) + root(-1, 5)**3] + + assert roots_cyclotomic(cyclotomic_poly(6, x, polys=True), factor=True) == \ + [1 - root(-1, 3), root(-1, 3)] + + +def test_roots_binomial(): + assert roots_binomial(Poly(5*x, x)) == [0] + assert roots_binomial(Poly(5*x**4, x)) == [0, 0, 0, 0] + assert roots_binomial(Poly(5*x + 2, x)) == [Rational(-2, 5)] + + A = 10**Rational(3, 4)/10 + + assert roots_binomial(Poly(5*x**4 + 2, x)) == \ + [-A - A*I, -A + A*I, A - A*I, A + A*I] + _check(roots_binomial(Poly(x**8 - 2))) + + a1 = Symbol('a1', nonnegative=True) + b1 = Symbol('b1', nonnegative=True) + + r0 = roots_quadratic(Poly(a1*x**2 + b1, x)) + r1 = roots_binomial(Poly(a1*x**2 + b1, x)) + + assert powsimp(r0[0]) == powsimp(r1[0]) + assert powsimp(r0[1]) == powsimp(r1[1]) + for a, b, s, n in product((1, 2), (1, 2), (-1, 1), (2, 3, 4, 5)): + if a == b and a != 1: # a == b == 1 is sufficient + continue + p = Poly(a*x**n + s*b) + ans = roots_binomial(p) + assert ans == _nsort(ans) + + # issue 8813 + assert roots(Poly(2*x**3 - 16*y**3, x)) == { + 2*y*(Rational(-1, 2) - sqrt(3)*I/2): 1, + 2*y: 1, + 2*y*(Rational(-1, 2) + sqrt(3)*I/2): 1} + + +def test_roots_preprocessing(): + f = a*y*x**2 + y - b + + coeff, poly = preprocess_roots(Poly(f, x)) + + assert coeff == 1 + assert poly == Poly(a*y*x**2 + y - b, x) + + f = c**3*x**3 + c**2*x**2 + c*x + a + + coeff, poly = preprocess_roots(Poly(f, x)) + + assert coeff == 1/c + assert poly == Poly(x**3 + x**2 + x + a, x) + + f = c**3*x**3 + c**2*x**2 + a + + coeff, poly = preprocess_roots(Poly(f, x)) + + assert coeff == 1/c + assert poly == Poly(x**3 + x**2 + a, x) + + f = c**3*x**3 + c*x + a + + coeff, poly = preprocess_roots(Poly(f, x)) + + assert coeff == 1/c + assert poly == Poly(x**3 + x + a, x) + + f = c**3*x**3 + a + + coeff, poly = preprocess_roots(Poly(f, x)) + + assert coeff == 1/c + assert poly == Poly(x**3 + a, x) + + E, F, J, L = symbols("E,F,J,L") + + f = -21601054687500000000*E**8*J**8/L**16 + \ + 508232812500000000*F*x*E**7*J**7/L**14 - \ + 4269543750000000*E**6*F**2*J**6*x**2/L**12 + \ + 16194716250000*E**5*F**3*J**5*x**3/L**10 - \ + 27633173750*E**4*F**4*J**4*x**4/L**8 + \ + 14840215*E**3*F**5*J**3*x**5/L**6 + \ + 54794*E**2*F**6*J**2*x**6/(5*L**4) - \ + 1153*E*J*F**7*x**7/(80*L**2) + \ + 633*F**8*x**8/160000 + + coeff, poly = preprocess_roots(Poly(f, x)) + + assert coeff == 20*E*J/(F*L**2) + assert poly == 633*x**8 - 115300*x**7 + 4383520*x**6 + 296804300*x**5 - 27633173750*x**4 + \ + 809735812500*x**3 - 10673859375000*x**2 + 63529101562500*x - 135006591796875 + + f = Poly(-y**2 + x**2*exp(x), y, domain=ZZ[x, exp(x)]) + g = Poly(-y**2 + exp(x), y, domain=ZZ[exp(x)]) + + assert preprocess_roots(f) == (x, g) + + +def test_roots0(): + assert roots(1, x) == {} + assert roots(x, x) == {S.Zero: 1} + assert roots(x**9, x) == {S.Zero: 9} + assert roots(((x - 2)*(x + 3)*(x - 4)).expand(), x) == {-S(3): 1, S(2): 1, S(4): 1} + + assert roots(2*x + 1, x) == {Rational(-1, 2): 1} + assert roots((2*x + 1)**2, x) == {Rational(-1, 2): 2} + assert roots((2*x + 1)**5, x) == {Rational(-1, 2): 5} + assert roots((2*x + 1)**10, x) == {Rational(-1, 2): 10} + + assert roots(x**4 - 1, x) == {I: 1, S.One: 1, -S.One: 1, -I: 1} + assert roots((x**4 - 1)**2, x) == {I: 2, S.One: 2, -S.One: 2, -I: 2} + + assert roots(((2*x - 3)**2).expand(), x) == {Rational( 3, 2): 2} + assert roots(((2*x + 3)**2).expand(), x) == {Rational(-3, 2): 2} + + assert roots(((2*x - 3)**3).expand(), x) == {Rational( 3, 2): 3} + assert roots(((2*x + 3)**3).expand(), x) == {Rational(-3, 2): 3} + + assert roots(((2*x - 3)**5).expand(), x) == {Rational( 3, 2): 5} + assert roots(((2*x + 3)**5).expand(), x) == {Rational(-3, 2): 5} + + assert roots(((a*x - b)**5).expand(), x) == { b/a: 5} + assert roots(((a*x + b)**5).expand(), x) == {-b/a: 5} + + assert roots(x**2 + (-a - 1)*x + a, x) == {a: 1, S.One: 1} + + assert roots(x**4 - 2*x**2 + 1, x) == {S.One: 2, S.NegativeOne: 2} + + assert roots(x**6 - 4*x**4 + 4*x**3 - x**2, x) == \ + {S.One: 2, -1 - sqrt(2): 1, S.Zero: 2, -1 + sqrt(2): 1} + + assert roots(x**8 - 1, x) == { + sqrt(2)/2 + I*sqrt(2)/2: 1, + sqrt(2)/2 - I*sqrt(2)/2: 1, + -sqrt(2)/2 + I*sqrt(2)/2: 1, + -sqrt(2)/2 - I*sqrt(2)/2: 1, + S.One: 1, -S.One: 1, I: 1, -I: 1 + } + + f = -2016*x**2 - 5616*x**3 - 2056*x**4 + 3324*x**5 + 2176*x**6 - \ + 224*x**7 - 384*x**8 - 64*x**9 + + assert roots(f) == {S.Zero: 2, -S(2): 2, S(2): 1, Rational(-7, 2): 1, + Rational(-3, 2): 1, Rational(-1, 2): 1, Rational(3, 2): 1} + + assert roots((a + b + c)*x - (a + b + c + d), x) == {(a + b + c + d)/(a + b + c): 1} + + assert roots(x**3 + x**2 - x + 1, x, cubics=False) == {} + assert roots(((x - 2)*( + x + 3)*(x - 4)).expand(), x, cubics=False) == {-S(3): 1, S(2): 1, S(4): 1} + assert roots(((x - 2)*(x + 3)*(x - 4)*(x - 5)).expand(), x, cubics=False) == \ + {-S(3): 1, S(2): 1, S(4): 1, S(5): 1} + assert roots(x**3 + 2*x**2 + 4*x + 8, x) == {-S(2): 1, -2*I: 1, 2*I: 1} + assert roots(x**3 + 2*x**2 + 4*x + 8, x, cubics=True) == \ + {-2*I: 1, 2*I: 1, -S(2): 1} + assert roots((x**2 - x)*(x**3 + 2*x**2 + 4*x + 8), x ) == \ + {S.One: 1, S.Zero: 1, -S(2): 1, -2*I: 1, 2*I: 1} + + r1_2, r1_3 = S.Half, Rational(1, 3) + + x0 = (3*sqrt(33) + 19)**r1_3 + x1 = 4/x0/3 + x2 = x0/3 + x3 = sqrt(3)*I/2 + x4 = x3 - r1_2 + x5 = -x3 - r1_2 + assert roots(x**3 + x**2 - x + 1, x, cubics=True) == { + -x1 - x2 - r1_3: 1, + -x1/x4 - x2*x4 - r1_3: 1, + -x1/x5 - x2*x5 - r1_3: 1, + } + + f = (x**2 + 2*x + 3).subs(x, 2*x**2 + 3*x).subs(x, 5*x - 4) + + r13_20, r1_20 = [ Rational(*r) + for r in ((13, 20), (1, 20)) ] + + s2 = sqrt(2) + assert roots(f, x) == { + r13_20 + r1_20*sqrt(1 - 8*I*s2): 1, + r13_20 - r1_20*sqrt(1 - 8*I*s2): 1, + r13_20 + r1_20*sqrt(1 + 8*I*s2): 1, + r13_20 - r1_20*sqrt(1 + 8*I*s2): 1, + } + + f = x**4 + x**3 + x**2 + x + 1 + + r1_4, r1_8, r5_8 = [ Rational(*r) for r in ((1, 4), (1, 8), (5, 8)) ] + + assert roots(f, x) == { + -r1_4 + r1_4*5**r1_2 + I*(r5_8 + r1_8*5**r1_2)**r1_2: 1, + -r1_4 + r1_4*5**r1_2 - I*(r5_8 + r1_8*5**r1_2)**r1_2: 1, + -r1_4 - r1_4*5**r1_2 + I*(r5_8 - r1_8*5**r1_2)**r1_2: 1, + -r1_4 - r1_4*5**r1_2 - I*(r5_8 - r1_8*5**r1_2)**r1_2: 1, + } + + f = z**3 + (-2 - y)*z**2 + (1 + 2*y - 2*x**2)*z - y + 2*x**2 + + assert roots(f, z) == { + S.One: 1, + S.Half + S.Half*y + S.Half*sqrt(1 - 2*y + y**2 + 8*x**2): 1, + S.Half + S.Half*y - S.Half*sqrt(1 - 2*y + y**2 + 8*x**2): 1, + } + + assert roots(a*b*c*x**3 + 2*x**2 + 4*x + 8, x, cubics=False) == {} + assert roots(a*b*c*x**3 + 2*x**2 + 4*x + 8, x, cubics=True) != {} + + assert roots(x**4 - 1, x, filter='Z') == {S.One: 1, -S.One: 1} + assert roots(x**4 - 1, x, filter='I') == {I: 1, -I: 1} + + assert roots((x - 1)*(x + 1), x) == {S.One: 1, -S.One: 1} + assert roots( + (x - 1)*(x + 1), x, predicate=lambda r: r.is_positive) == {S.One: 1} + + assert roots(x**4 - 1, x, filter='Z', multiple=True) == [-S.One, S.One] + assert roots(x**4 - 1, x, filter='I', multiple=True) == [I, -I] + + ar, br = symbols('a, b', real=True) + p = x**2*(ar-br)**2 + 2*x*(br-ar) + 1 + assert roots(p, x, filter='R') == {1/(ar - br): 2} + + assert roots(x**3, x, multiple=True) == [S.Zero, S.Zero, S.Zero] + assert roots(1234, x, multiple=True) == [] + + f = x**6 - x**5 + x**4 - x**3 + x**2 - x + 1 + + assert roots(f) == { + -I*sin(pi/7) + cos(pi/7): 1, + -I*sin(pi*Rational(2, 7)) - cos(pi*Rational(2, 7)): 1, + -I*sin(pi*Rational(3, 7)) + cos(pi*Rational(3, 7)): 1, + I*sin(pi/7) + cos(pi/7): 1, + I*sin(pi*Rational(2, 7)) - cos(pi*Rational(2, 7)): 1, + I*sin(pi*Rational(3, 7)) + cos(pi*Rational(3, 7)): 1, + } + + g = ((x**2 + 1)*f**2).expand() + + assert roots(g) == { + -I*sin(pi/7) + cos(pi/7): 2, + -I*sin(pi*Rational(2, 7)) - cos(pi*Rational(2, 7)): 2, + -I*sin(pi*Rational(3, 7)) + cos(pi*Rational(3, 7)): 2, + I*sin(pi/7) + cos(pi/7): 2, + I*sin(pi*Rational(2, 7)) - cos(pi*Rational(2, 7)): 2, + I*sin(pi*Rational(3, 7)) + cos(pi*Rational(3, 7)): 2, + -I: 1, I: 1, + } + + r = roots(x**3 + 40*x + 64) + real_root = [rx for rx in r if rx.is_real][0] + cr = 108 + 6*sqrt(1074) + assert real_root == -2*root(cr, 3)/3 + 20/root(cr, 3) + + eq = Poly((7 + 5*sqrt(2))*x**3 + (-6 - 4*sqrt(2))*x**2 + (-sqrt(2) - 1)*x + 2, x, domain='EX') + assert roots(eq) == {-1 + sqrt(2): 1, -2 + 2*sqrt(2): 1, -sqrt(2) + 1: 1} + + eq = Poly(41*x**5 + 29*sqrt(2)*x**5 - 153*x**4 - 108*sqrt(2)*x**4 + + 175*x**3 + 125*sqrt(2)*x**3 - 45*x**2 - 30*sqrt(2)*x**2 - 26*sqrt(2)*x - + 26*x + 24, x, domain='EX') + assert roots(eq) == {-sqrt(2) + 1: 1, -2 + 2*sqrt(2): 1, -1 + sqrt(2): 1, + -4 + 4*sqrt(2): 1, -3 + 3*sqrt(2): 1} + + eq = Poly(x**3 - 2*x**2 + 6*sqrt(2)*x**2 - 8*sqrt(2)*x + 23*x - 14 + + 14*sqrt(2), x, domain='EX') + assert roots(eq) == {-2*sqrt(2) + 2: 1, -2*sqrt(2) + 1: 1, -2*sqrt(2) - 1: 1} + + assert roots(Poly((x + sqrt(2))**3 - 7, x, domain='EX')) == \ + {-sqrt(2) + root(7, 3)*(-S.Half - sqrt(3)*I/2): 1, + -sqrt(2) + root(7, 3)*(-S.Half + sqrt(3)*I/2): 1, + -sqrt(2) + root(7, 3): 1} + +def test_roots_slow(): + """Just test that calculating these roots does not hang. """ + a, b, c, d, x = symbols("a,b,c,d,x") + + f1 = x**2*c + (a/b) + x*c*d - a + f2 = x**2*(a + b*(c - d)*a) + x*a*b*c/(b*d - d) + (a*d - c/d) + + assert list(roots(f1, x).values()) == [1, 1] + assert list(roots(f2, x).values()) == [1, 1] + + (zz, yy, xx, zy, zx, yx, k) = symbols("zz,yy,xx,zy,zx,yx,k") + + e1 = (zz - k)*(yy - k)*(xx - k) + zy*yx*zx + zx - zy - yx + e2 = (zz - k)*yx*yx + zx*(yy - k)*zx + zy*zy*(xx - k) + + assert list(roots(e1 - e2, k).values()) == [1, 1, 1] + + f = x**3 + 2*x**2 + 8 + R = list(roots(f).keys()) + + assert not any(i for i in [f.subs(x, ri).n(chop=True) for ri in R]) + + +def test_roots_inexact(): + R1 = roots(x**2 + x + 1, x, multiple=True) + R2 = roots(x**2 + x + 1.0, x, multiple=True) + + for r1, r2 in zip(R1, R2): + assert abs(r1 - r2) < 1e-12 + + f = x**4 + 3.0*sqrt(2.0)*x**3 - (78.0 + 24.0*sqrt(3.0))*x**2 \ + + 144.0*(2*sqrt(3.0) + 9.0) + + R1 = roots(f, multiple=True) + R2 = (-12.7530479110482, -3.85012393732929, + 4.89897948556636, 7.46155167569183) + + for r1, r2 in zip(R1, R2): + assert abs(r1 - r2) < 1e-10 + + +def test_roots_preprocessed(): + E, F, J, L = symbols("E,F,J,L") + + f = -21601054687500000000*E**8*J**8/L**16 + \ + 508232812500000000*F*x*E**7*J**7/L**14 - \ + 4269543750000000*E**6*F**2*J**6*x**2/L**12 + \ + 16194716250000*E**5*F**3*J**5*x**3/L**10 - \ + 27633173750*E**4*F**4*J**4*x**4/L**8 + \ + 14840215*E**3*F**5*J**3*x**5/L**6 + \ + 54794*E**2*F**6*J**2*x**6/(5*L**4) - \ + 1153*E*J*F**7*x**7/(80*L**2) + \ + 633*F**8*x**8/160000 + + assert roots(f, x) == {} + + R1 = roots(f.evalf(), x, multiple=True) + R2 = [-1304.88375606366, 97.1168816800648, 186.946430171876, 245.526792947065, + 503.441004174773, 791.549343830097, 1273.16678129348, 1850.10650616851] + + w = Wild('w') + p = w*E*J/(F*L**2) + + assert len(R1) == len(R2) + + for r1, r2 in zip(R1, R2): + match = r1.match(p) + assert match is not None and abs(match[w] - r2) < 1e-10 + + +def test_roots_strict(): + assert roots(x**2 - 2*x + 1, strict=False) == {1: 2} + assert roots(x**2 - 2*x + 1, strict=True) == {1: 2} + + assert roots(x**6 - 2*x**5 - x**2 + 3*x - 2, strict=False) == {2: 1} + raises(UnsolvableFactorError, lambda: roots(x**6 - 2*x**5 - x**2 + 3*x - 2, strict=True)) + + +def test_roots_mixed(): + f = -1936 - 5056*x - 7592*x**2 + 2704*x**3 - 49*x**4 + + _re, _im = intervals(f, all=True) + _nroots = nroots(f) + _sroots = roots(f, multiple=True) + + _re = [ Interval(a, b) for (a, b), _ in _re ] + _im = [ Interval(re(a), re(b))*Interval(im(a), im(b)) for (a, b), + _ in _im ] + + _intervals = _re + _im + _sroots = [ r.evalf() for r in _sroots ] + + _nroots = sorted(_nroots, key=lambda x: x.sort_key()) + _sroots = sorted(_sroots, key=lambda x: x.sort_key()) + + for _roots in (_nroots, _sroots): + for i, r in zip(_intervals, _roots): + if r.is_real: + assert r in i + else: + assert (re(r), im(r)) in i + + +def test_root_factors(): + assert root_factors(Poly(1, x)) == [Poly(1, x)] + assert root_factors(Poly(x, x)) == [Poly(x, x)] + + assert root_factors(x**2 - 1, x) == [x + 1, x - 1] + assert root_factors(x**2 - y, x) == [x - sqrt(y), x + sqrt(y)] + + assert root_factors((x**4 - 1)**2) == \ + [x + 1, x + 1, x - 1, x - 1, x - I, x - I, x + I, x + I] + + assert root_factors(Poly(x**4 - 1, x), filter='Z') == \ + [Poly(x + 1, x), Poly(x - 1, x), Poly(x**2 + 1, x)] + assert root_factors(8*x**2 + 12*x**4 + 6*x**6 + x**8, x, filter='Q') == \ + [x, x, x**6 + 6*x**4 + 12*x**2 + 8] + + +@slow +def test_nroots1(): + n = 64 + p = legendre_poly(n, x, polys=True) + + raises(mpmath.mp.NoConvergence, lambda: p.nroots(n=3, maxsteps=5)) + + roots = p.nroots(n=3) + # The order of roots matters. They are ordered from smallest to the + # largest. + assert [str(r) for r in roots] == \ + ['-0.999', '-0.996', '-0.991', '-0.983', '-0.973', '-0.961', + '-0.946', '-0.930', '-0.911', '-0.889', '-0.866', '-0.841', + '-0.813', '-0.784', '-0.753', '-0.720', '-0.685', '-0.649', + '-0.611', '-0.572', '-0.531', '-0.489', '-0.446', '-0.402', + '-0.357', '-0.311', '-0.265', '-0.217', '-0.170', '-0.121', + '-0.0730', '-0.0243', '0.0243', '0.0730', '0.121', '0.170', + '0.217', '0.265', '0.311', '0.357', '0.402', '0.446', '0.489', + '0.531', '0.572', '0.611', '0.649', '0.685', '0.720', '0.753', + '0.784', '0.813', '0.841', '0.866', '0.889', '0.911', '0.930', + '0.946', '0.961', '0.973', '0.983', '0.991', '0.996', '0.999'] + +def test_nroots2(): + p = Poly(x**5 + 3*x + 1, x) + + roots = p.nroots(n=3) + # The order of roots matters. The roots are ordered by their real + # components (if they agree, then by their imaginary components), + # with real roots appearing first. + assert [str(r) for r in roots] == \ + ['-0.332', '-0.839 - 0.944*I', '-0.839 + 0.944*I', + '1.01 - 0.937*I', '1.01 + 0.937*I'] + + roots = p.nroots(n=5) + assert [str(r) for r in roots] == \ + ['-0.33199', '-0.83907 - 0.94385*I', '-0.83907 + 0.94385*I', + '1.0051 - 0.93726*I', '1.0051 + 0.93726*I'] + + +def test_roots_composite(): + assert len(roots(Poly(y**3 + y**2*sqrt(x) + y + x, y, composite=True))) == 3 + + +def test_issue_19113(): + eq = cos(x)**3 - cos(x) + 1 + raises(PolynomialError, lambda: roots(eq)) + + +def test_issue_17454(): + assert roots([1, -3*(-4 - 4*I)**2/8 + 12*I, 0], multiple=True) == [0, 0] + + +def test_issue_20913(): + assert Poly(x + 9671406556917067856609794, x).real_roots() == [-9671406556917067856609794] + assert Poly(x**3 + 4, x).real_roots() == [-2**(S(2)/3)] + + +def test_issue_22768(): + e = Rational(1, 3) + r = (-1/a)**e*(a + 1)**(5*e) + assert roots(Poly(a*x**3 + (a + 1)**5, x)) == { + r: 1, + -r*(1 + sqrt(3)*I)/2: 1, + r*(-1 + sqrt(3)*I)/2: 1} diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polytools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polytools.py new file mode 100644 index 0000000000000000000000000000000000000000..a4096447cecea9db6e7559c305af6312b2a72725 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polytools.py @@ -0,0 +1,3976 @@ +"""Tests for user-friendly public interface to polynomial functions. """ + +import pickle + +from sympy.polys.polytools import ( + Poly, PurePoly, poly, + parallel_poly_from_expr, + degree, degree_list, + total_degree, + LC, LM, LT, + pdiv, prem, pquo, pexquo, + div, rem, quo, exquo, + half_gcdex, gcdex, invert, + subresultants, + resultant, discriminant, + terms_gcd, cofactors, + gcd, gcd_list, + lcm, lcm_list, + trunc, + monic, content, primitive, + compose, decompose, + sturm, + gff_list, gff, + sqf_norm, sqf_part, sqf_list, sqf, + factor_list, factor, + intervals, refine_root, count_roots, + all_roots, real_roots, nroots, ground_roots, + nth_power_roots_poly, + cancel, reduced, groebner, + GroebnerBasis, is_zero_dimensional, + _torational_factor_list, + to_rational_coeffs) + +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + ExactQuotientFailed, + PolificationFailed, + ComputationFailed, + UnificationFailed, + RefinementFailed, + GeneratorsNeeded, + GeneratorsError, + PolynomialError, + CoercionFailed, + DomainError, + OptionError, + FlagError) + +from sympy.polys.polyclasses import DMP + +from sympy.polys.fields import field +from sympy.polys.domains import FF, ZZ, QQ, ZZ_I, QQ_I, RR, EX +from sympy.polys.domains.realfield import RealField +from sympy.polys.domains.complexfield import ComplexField +from sympy.polys.orderings import lex, grlex, grevlex + +from sympy.combinatorics.galois import S4TransitiveSubgroups +from sympy.core.add import Add +from sympy.core.basic import _aresame +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, diff, expand) +from sympy.core.mul import _keep_coeff, Mul +from sympy.core.numbers import (Float, I, Integer, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys.rootoftools import rootof +from sympy.simplify.simplify import signsimp +from sympy.utilities.iterables import iterable +from sympy.utilities.exceptions import SymPyDeprecationWarning + +from sympy.testing.pytest import ( + raises, warns_deprecated_sympy, warns, tooslow, XFAIL +) + +from sympy.abc import a, b, c, d, p, q, t, w, x, y, z + + +def _epsilon_eq(a, b): + for u, v in zip(a, b): + if abs(u - v) > 1e-10: + return False + return True + + +def _strict_eq(a, b): + if type(a) == type(b): + if iterable(a): + if len(a) == len(b): + return all(_strict_eq(c, d) for c, d in zip(a, b)) + else: + return False + else: + return isinstance(a, Poly) and a.eq(b, strict=True) + else: + return False + + +def test_Poly_mixed_operations(): + p = Poly(x, x) + with warns_deprecated_sympy(): + p * exp(x) + with warns_deprecated_sympy(): + p + exp(x) + with warns_deprecated_sympy(): + p - exp(x) + + +def test_Poly_from_dict(): + K = FF(3) + + assert Poly.from_dict( + {0: 1, 1: 2}, gens=x, domain=K).rep == DMP([K(2), K(1)], K) + assert Poly.from_dict( + {0: 1, 1: 5}, gens=x, domain=K).rep == DMP([K(2), K(1)], K) + + assert Poly.from_dict( + {(0,): 1, (1,): 2}, gens=x, domain=K).rep == DMP([K(2), K(1)], K) + assert Poly.from_dict( + {(0,): 1, (1,): 5}, gens=x, domain=K).rep == DMP([K(2), K(1)], K) + + assert Poly.from_dict({(0, 0): 1, (1, 1): 2}, gens=( + x, y), domain=K).rep == DMP([[K(2), K(0)], [K(1)]], K) + + assert Poly.from_dict({0: 1, 1: 2}, gens=x).rep == DMP([ZZ(2), ZZ(1)], ZZ) + assert Poly.from_dict( + {0: 1, 1: 2}, gens=x, field=True).rep == DMP([QQ(2), QQ(1)], QQ) + + assert Poly.from_dict( + {0: 1, 1: 2}, gens=x, domain=ZZ).rep == DMP([ZZ(2), ZZ(1)], ZZ) + assert Poly.from_dict( + {0: 1, 1: 2}, gens=x, domain=QQ).rep == DMP([QQ(2), QQ(1)], QQ) + + assert Poly.from_dict( + {(0,): 1, (1,): 2}, gens=x).rep == DMP([ZZ(2), ZZ(1)], ZZ) + assert Poly.from_dict( + {(0,): 1, (1,): 2}, gens=x, field=True).rep == DMP([QQ(2), QQ(1)], QQ) + + assert Poly.from_dict( + {(0,): 1, (1,): 2}, gens=x, domain=ZZ).rep == DMP([ZZ(2), ZZ(1)], ZZ) + assert Poly.from_dict( + {(0,): 1, (1,): 2}, gens=x, domain=QQ).rep == DMP([QQ(2), QQ(1)], QQ) + + assert Poly.from_dict({(1,): sin(y)}, gens=x, composite=False) == \ + Poly(sin(y)*x, x, domain='EX') + assert Poly.from_dict({(1,): y}, gens=x, composite=False) == \ + Poly(y*x, x, domain='EX') + assert Poly.from_dict({(1, 1): 1}, gens=(x, y), composite=False) == \ + Poly(x*y, x, y, domain='ZZ') + assert Poly.from_dict({(1, 0): y}, gens=(x, z), composite=False) == \ + Poly(y*x, x, z, domain='EX') + + +def test_Poly_from_list(): + K = FF(3) + + assert Poly.from_list([2, 1], gens=x, domain=K).rep == DMP([K(2), K(1)], K) + assert Poly.from_list([5, 1], gens=x, domain=K).rep == DMP([K(2), K(1)], K) + + assert Poly.from_list([2, 1], gens=x).rep == DMP([ZZ(2), ZZ(1)], ZZ) + assert Poly.from_list([2, 1], gens=x, field=True).rep == DMP([QQ(2), QQ(1)], QQ) + + assert Poly.from_list([2, 1], gens=x, domain=ZZ).rep == DMP([ZZ(2), ZZ(1)], ZZ) + assert Poly.from_list([2, 1], gens=x, domain=QQ).rep == DMP([QQ(2), QQ(1)], QQ) + + assert Poly.from_list([0, 1.0], gens=x).rep == DMP([RR(1.0)], RR) + assert Poly.from_list([1.0, 0], gens=x).rep == DMP([RR(1.0), RR(0.0)], RR) + + raises(MultivariatePolynomialError, lambda: Poly.from_list([[]], gens=(x, y))) + + +def test_Poly_from_poly(): + f = Poly(x + 7, x, domain=ZZ) + g = Poly(x + 2, x, modulus=3) + h = Poly(x + y, x, y, domain=ZZ) + + K = FF(3) + + assert Poly.from_poly(f) == f + assert Poly.from_poly(f, domain=K).rep == DMP([K(1), K(1)], K) + assert Poly.from_poly(f, domain=ZZ).rep == DMP([ZZ(1), ZZ(7)], ZZ) + assert Poly.from_poly(f, domain=QQ).rep == DMP([QQ(1), QQ(7)], QQ) + + assert Poly.from_poly(f, gens=x) == f + assert Poly.from_poly(f, gens=x, domain=K).rep == DMP([K(1), K(1)], K) + assert Poly.from_poly(f, gens=x, domain=ZZ).rep == DMP([ZZ(1), ZZ(7)], ZZ) + assert Poly.from_poly(f, gens=x, domain=QQ).rep == DMP([QQ(1), QQ(7)], QQ) + + assert Poly.from_poly(f, gens=y) == Poly(x + 7, y, domain='ZZ[x]') + raises(CoercionFailed, lambda: Poly.from_poly(f, gens=y, domain=K)) + raises(CoercionFailed, lambda: Poly.from_poly(f, gens=y, domain=ZZ)) + raises(CoercionFailed, lambda: Poly.from_poly(f, gens=y, domain=QQ)) + + assert Poly.from_poly(f, gens=(x, y)) == Poly(x + 7, x, y, domain='ZZ') + assert Poly.from_poly( + f, gens=(x, y), domain=ZZ) == Poly(x + 7, x, y, domain='ZZ') + assert Poly.from_poly( + f, gens=(x, y), domain=QQ) == Poly(x + 7, x, y, domain='QQ') + assert Poly.from_poly( + f, gens=(x, y), modulus=3) == Poly(x + 7, x, y, domain='FF(3)') + + K = FF(2) + + assert Poly.from_poly(g) == g + assert Poly.from_poly(g, domain=ZZ).rep == DMP([ZZ(1), ZZ(-1)], ZZ) + raises(CoercionFailed, lambda: Poly.from_poly(g, domain=QQ)) + assert Poly.from_poly(g, domain=K).rep == DMP([K(1), K(0)], K) + + assert Poly.from_poly(g, gens=x) == g + assert Poly.from_poly(g, gens=x, domain=ZZ).rep == DMP([ZZ(1), ZZ(-1)], ZZ) + raises(CoercionFailed, lambda: Poly.from_poly(g, gens=x, domain=QQ)) + assert Poly.from_poly(g, gens=x, domain=K).rep == DMP([K(1), K(0)], K) + + K = FF(3) + + assert Poly.from_poly(h) == h + assert Poly.from_poly( + h, domain=ZZ).rep == DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + assert Poly.from_poly( + h, domain=QQ).rep == DMP([[QQ(1)], [QQ(1), QQ(0)]], QQ) + assert Poly.from_poly(h, domain=K).rep == DMP([[K(1)], [K(1), K(0)]], K) + + assert Poly.from_poly(h, gens=x) == Poly(x + y, x, domain=ZZ[y]) + raises(CoercionFailed, lambda: Poly.from_poly(h, gens=x, domain=ZZ)) + assert Poly.from_poly( + h, gens=x, domain=ZZ[y]) == Poly(x + y, x, domain=ZZ[y]) + raises(CoercionFailed, lambda: Poly.from_poly(h, gens=x, domain=QQ)) + assert Poly.from_poly( + h, gens=x, domain=QQ[y]) == Poly(x + y, x, domain=QQ[y]) + raises(CoercionFailed, lambda: Poly.from_poly(h, gens=x, modulus=3)) + + assert Poly.from_poly(h, gens=y) == Poly(x + y, y, domain=ZZ[x]) + raises(CoercionFailed, lambda: Poly.from_poly(h, gens=y, domain=ZZ)) + assert Poly.from_poly( + h, gens=y, domain=ZZ[x]) == Poly(x + y, y, domain=ZZ[x]) + raises(CoercionFailed, lambda: Poly.from_poly(h, gens=y, domain=QQ)) + assert Poly.from_poly( + h, gens=y, domain=QQ[x]) == Poly(x + y, y, domain=QQ[x]) + raises(CoercionFailed, lambda: Poly.from_poly(h, gens=y, modulus=3)) + + assert Poly.from_poly(h, gens=(x, y)) == h + assert Poly.from_poly( + h, gens=(x, y), domain=ZZ).rep == DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + assert Poly.from_poly( + h, gens=(x, y), domain=QQ).rep == DMP([[QQ(1)], [QQ(1), QQ(0)]], QQ) + assert Poly.from_poly( + h, gens=(x, y), domain=K).rep == DMP([[K(1)], [K(1), K(0)]], K) + + assert Poly.from_poly( + h, gens=(y, x)).rep == DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + assert Poly.from_poly( + h, gens=(y, x), domain=ZZ).rep == DMP([[ZZ(1)], [ZZ(1), ZZ(0)]], ZZ) + assert Poly.from_poly( + h, gens=(y, x), domain=QQ).rep == DMP([[QQ(1)], [QQ(1), QQ(0)]], QQ) + assert Poly.from_poly( + h, gens=(y, x), domain=K).rep == DMP([[K(1)], [K(1), K(0)]], K) + + assert Poly.from_poly( + h, gens=(x, y), field=True).rep == DMP([[QQ(1)], [QQ(1), QQ(0)]], QQ) + assert Poly.from_poly( + h, gens=(x, y), field=True).rep == DMP([[QQ(1)], [QQ(1), QQ(0)]], QQ) + + +def test_Poly_from_expr(): + raises(GeneratorsNeeded, lambda: Poly.from_expr(S.Zero)) + raises(GeneratorsNeeded, lambda: Poly.from_expr(S(7))) + + F3 = FF(3) + + assert Poly.from_expr(x + 5, domain=F3).rep == DMP([F3(1), F3(2)], F3) + assert Poly.from_expr(y + 5, domain=F3).rep == DMP([F3(1), F3(2)], F3) + + assert Poly.from_expr(x + 5, x, domain=F3).rep == DMP([F3(1), F3(2)], F3) + assert Poly.from_expr(y + 5, y, domain=F3).rep == DMP([F3(1), F3(2)], F3) + + assert Poly.from_expr(x + y, domain=F3).rep == DMP([[F3(1)], [F3(1), F3(0)]], F3) + assert Poly.from_expr(x + y, x, y, domain=F3).rep == DMP([[F3(1)], [F3(1), F3(0)]], F3) + + assert Poly.from_expr(x + 5).rep == DMP([ZZ(1), ZZ(5)], ZZ) + assert Poly.from_expr(y + 5).rep == DMP([ZZ(1), ZZ(5)], ZZ) + + assert Poly.from_expr(x + 5, x).rep == DMP([ZZ(1), ZZ(5)], ZZ) + assert Poly.from_expr(y + 5, y).rep == DMP([ZZ(1), ZZ(5)], ZZ) + + assert Poly.from_expr(x + 5, domain=ZZ).rep == DMP([ZZ(1), ZZ(5)], ZZ) + assert Poly.from_expr(y + 5, domain=ZZ).rep == DMP([ZZ(1), ZZ(5)], ZZ) + + assert Poly.from_expr(x + 5, x, domain=ZZ).rep == DMP([ZZ(1), ZZ(5)], ZZ) + assert Poly.from_expr(y + 5, y, domain=ZZ).rep == DMP([ZZ(1), ZZ(5)], ZZ) + + assert Poly.from_expr(x + 5, x, y, domain=ZZ).rep == DMP([[ZZ(1)], [ZZ(5)]], ZZ) + assert Poly.from_expr(y + 5, x, y, domain=ZZ).rep == DMP([[ZZ(1), ZZ(5)]], ZZ) + + +def test_Poly_rootof_extension(): + r1 = rootof(x**3 + x + 3, 0) + r2 = rootof(x**3 + x + 3, 1) + K1 = QQ.algebraic_field(r1) + K2 = QQ.algebraic_field(r2) + assert Poly(r1, y) == Poly(r1, y, domain=EX) + assert Poly(r2, y) == Poly(r2, y, domain=EX) + assert Poly(r1, y, extension=True) == Poly(r1, y, domain=K1) + assert Poly(r2, y, extension=True) == Poly(r2, y, domain=K2) + + +@tooslow +def test_Poly_rootof_extension_primitive_element(): + r1 = rootof(x**3 + x + 3, 0) + r2 = rootof(x**3 + x + 3, 1) + K12 = QQ.algebraic_field(r1 + r2) + assert Poly(r1*y + r2, y, extension=True) == Poly(r1*y + r2, y, domain=K12) + + +@XFAIL +def test_Poly_rootof_same_symbol_issue_26808(): + # XXX: This fails because r1 contains x. + r1 = rootof(x**3 + x + 3, 0) + K1 = QQ.algebraic_field(r1) + assert Poly(r1, x) == Poly(r1, x, domain=EX) + assert Poly(r1, x, extension=True) == Poly(r1, x, domain=K1) + + +def test_Poly_rootof_extension_to_sympy(): + # Verify that when primitive elements and RootOf are used, the expression + # is not exploded on the way back to sympy. + r1 = rootof(y**3 + y**2 - 1, 0) + r2 = rootof(z**5 + z**2 - 1, 0) + p = -x**5 + x**2 + x*r1 - r2 + 3*r1**2 + assert p.as_poly(x, extension=True).as_expr() == p + + +def test_poly_from_domain_element(): + dom = ZZ[x] + assert Poly(dom(x+1), y, domain=dom).rep == DMP([dom(x+1)], dom) + dom = dom.get_field() + assert Poly(dom(x+1), y, domain=dom).rep == DMP([dom(x+1)], dom) + + dom = QQ[x] + assert Poly(dom(x+1), y, domain=dom).rep == DMP([dom(x+1)], dom) + dom = dom.get_field() + assert Poly(dom(x+1), y, domain=dom).rep == DMP([dom(x+1)], dom) + + dom = ZZ.old_poly_ring(x) + assert Poly(dom([ZZ(1), ZZ(1)]), y, domain=dom).rep == DMP([dom([ZZ(1), ZZ(1)])], dom) + dom = dom.get_field() + assert Poly(dom([ZZ(1), ZZ(1)]), y, domain=dom).rep == DMP([dom([ZZ(1), ZZ(1)])], dom) + + dom = QQ.old_poly_ring(x) + assert Poly(dom([QQ(1), QQ(1)]), y, domain=dom).rep == DMP([dom([QQ(1), QQ(1)])], dom) + dom = dom.get_field() + assert Poly(dom([QQ(1), QQ(1)]), y, domain=dom).rep == DMP([dom([QQ(1), QQ(1)])], dom) + + dom = QQ.algebraic_field(I) + assert Poly(dom([1, 1]), x, domain=dom).rep == DMP([dom([1, 1])], dom) + + +def test_Poly__new__(): + raises(GeneratorsError, lambda: Poly(x + 1, x, x)) + + raises(GeneratorsError, lambda: Poly(x + y, x, y, domain=ZZ[x])) + raises(GeneratorsError, lambda: Poly(x + y, x, y, domain=ZZ[y])) + + raises(OptionError, lambda: Poly(x, x, symmetric=True)) + raises(OptionError, lambda: Poly(x + 2, x, modulus=3, domain=QQ)) + + raises(OptionError, lambda: Poly(x + 2, x, domain=ZZ, gaussian=True)) + raises(OptionError, lambda: Poly(x + 2, x, modulus=3, gaussian=True)) + + raises(OptionError, lambda: Poly(x + 2, x, domain=ZZ, extension=[sqrt(3)])) + raises(OptionError, lambda: Poly(x + 2, x, modulus=3, extension=[sqrt(3)])) + + raises(OptionError, lambda: Poly(x + 2, x, domain=ZZ, extension=True)) + raises(OptionError, lambda: Poly(x + 2, x, modulus=3, extension=True)) + + raises(OptionError, lambda: Poly(x + 2, x, domain=ZZ, greedy=True)) + raises(OptionError, lambda: Poly(x + 2, x, domain=QQ, field=True)) + + raises(OptionError, lambda: Poly(x + 2, x, domain=ZZ, greedy=False)) + raises(OptionError, lambda: Poly(x + 2, x, domain=QQ, field=False)) + + raises(NotImplementedError, lambda: Poly(x + 1, x, modulus=3, order='grlex')) + raises(NotImplementedError, lambda: Poly(x + 1, x, order='grlex')) + + raises(GeneratorsNeeded, lambda: Poly({1: 2, 0: 1})) + raises(GeneratorsNeeded, lambda: Poly([2, 1])) + raises(GeneratorsNeeded, lambda: Poly((2, 1))) + + raises(GeneratorsNeeded, lambda: Poly(1)) + + assert Poly('x-x') == Poly(0, x) + + f = a*x**2 + b*x + c + + assert Poly({2: a, 1: b, 0: c}, x) == f + assert Poly(iter([a, b, c]), x) == f + assert Poly([a, b, c], x) == f + assert Poly((a, b, c), x) == f + + f = Poly({}, x, y, z) + + assert f.gens == (x, y, z) and f.as_expr() == 0 + + assert Poly(Poly(a*x + b*y, x, y), x) == Poly(a*x + b*y, x) + + assert Poly(3*x**2 + 2*x + 1, domain='ZZ').all_coeffs() == [3, 2, 1] + assert Poly(3*x**2 + 2*x + 1, domain='QQ').all_coeffs() == [3, 2, 1] + assert Poly(3*x**2 + 2*x + 1, domain='RR').all_coeffs() == [3.0, 2.0, 1.0] + + raises(CoercionFailed, lambda: Poly(3*x**2/5 + x*Rational(2, 5) + 1, domain='ZZ')) + assert Poly( + 3*x**2/5 + x*Rational(2, 5) + 1, domain='QQ').all_coeffs() == [Rational(3, 5), Rational(2, 5), 1] + assert _epsilon_eq( + Poly(3*x**2/5 + x*Rational(2, 5) + 1, domain='RR').all_coeffs(), [0.6, 0.4, 1.0]) + + assert Poly(3.0*x**2 + 2.0*x + 1, domain='ZZ').all_coeffs() == [3, 2, 1] + assert Poly(3.0*x**2 + 2.0*x + 1, domain='QQ').all_coeffs() == [3, 2, 1] + assert Poly( + 3.0*x**2 + 2.0*x + 1, domain='RR').all_coeffs() == [3.0, 2.0, 1.0] + + raises(CoercionFailed, lambda: Poly(3.1*x**2 + 2.1*x + 1, domain='ZZ')) + assert Poly(3.1*x**2 + 2.1*x + 1, domain='QQ').all_coeffs() == [Rational(31, 10), Rational(21, 10), 1] + assert Poly(3.1*x**2 + 2.1*x + 1, domain='RR').all_coeffs() == [3.1, 2.1, 1.0] + + assert Poly({(2, 1): 1, (1, 2): 2, (1, 1): 3}, x, y) == \ + Poly(x**2*y + 2*x*y**2 + 3*x*y, x, y) + + assert Poly(x**2 + 1, extension=I).get_domain() == QQ.algebraic_field(I) + + f = 3*x**5 - x**4 + x**3 - x** 2 + 65538 + + assert Poly(f, x, modulus=65537, symmetric=True) == \ + Poly(3*x**5 - x**4 + x**3 - x** 2 + 1, x, modulus=65537, + symmetric=True) + assert Poly(f, x, modulus=65537, symmetric=False) == \ + Poly(3*x**5 + 65536*x**4 + x**3 + 65536*x** 2 + 1, x, + modulus=65537, symmetric=False) + + N = 10**100 + assert Poly(-1, x, modulus=N, symmetric=False).as_expr() == N - 1 + + assert isinstance(Poly(x**2 + x + 1.0).get_domain(), RealField) + assert isinstance(Poly(x**2 + x + I + 1.0).get_domain(), ComplexField) + + +def test_Poly__args(): + assert Poly(x**2 + 1).args == (x**2 + 1, x) + + +def test_Poly__gens(): + assert Poly((x - p)*(x - q), x).gens == (x,) + assert Poly((x - p)*(x - q), p).gens == (p,) + assert Poly((x - p)*(x - q), q).gens == (q,) + + assert Poly((x - p)*(x - q), x, p).gens == (x, p) + assert Poly((x - p)*(x - q), x, q).gens == (x, q) + + assert Poly((x - p)*(x - q), x, p, q).gens == (x, p, q) + assert Poly((x - p)*(x - q), p, x, q).gens == (p, x, q) + assert Poly((x - p)*(x - q), p, q, x).gens == (p, q, x) + + assert Poly((x - p)*(x - q)).gens == (x, p, q) + + assert Poly((x - p)*(x - q), sort='x > p > q').gens == (x, p, q) + assert Poly((x - p)*(x - q), sort='p > x > q').gens == (p, x, q) + assert Poly((x - p)*(x - q), sort='p > q > x').gens == (p, q, x) + + assert Poly((x - p)*(x - q), x, p, q, sort='p > q > x').gens == (x, p, q) + + assert Poly((x - p)*(x - q), wrt='x').gens == (x, p, q) + assert Poly((x - p)*(x - q), wrt='p').gens == (p, x, q) + assert Poly((x - p)*(x - q), wrt='q').gens == (q, x, p) + + assert Poly((x - p)*(x - q), wrt=x).gens == (x, p, q) + assert Poly((x - p)*(x - q), wrt=p).gens == (p, x, q) + assert Poly((x - p)*(x - q), wrt=q).gens == (q, x, p) + + assert Poly((x - p)*(x - q), x, p, q, wrt='p').gens == (x, p, q) + + assert Poly((x - p)*(x - q), wrt='p', sort='q > x').gens == (p, q, x) + assert Poly((x - p)*(x - q), wrt='q', sort='p > x').gens == (q, p, x) + + +def test_Poly_zero(): + assert Poly(x).zero == Poly(0, x, domain=ZZ) + assert Poly(x/2).zero == Poly(0, x, domain=QQ) + + +def test_Poly_one(): + assert Poly(x).one == Poly(1, x, domain=ZZ) + assert Poly(x/2).one == Poly(1, x, domain=QQ) + + +def test_Poly__unify(): + raises(UnificationFailed, lambda: Poly(x)._unify(y)) + + F3 = FF(3) + + assert Poly(x, x, modulus=3)._unify(Poly(y, y, modulus=3))[2:] == ( + DMP([[F3(1)], []], F3), DMP([[F3(1), F3(0)]], F3)) + raises(UnificationFailed, lambda: Poly(x, x, modulus=3)._unify(Poly(y, y, modulus=5))) + + raises(UnificationFailed, lambda: Poly(y, x, y)._unify(Poly(x, x, modulus=3))) + raises(UnificationFailed, lambda: Poly(x, x, modulus=3)._unify(Poly(y, x, y))) + + assert Poly(x + 1, x)._unify(Poly(x + 2, x))[2:] ==\ + (DMP([ZZ(1), ZZ(1)], ZZ), DMP([ZZ(1), ZZ(2)], ZZ)) + assert Poly(x + 1, x, domain='QQ')._unify(Poly(x + 2, x))[2:] ==\ + (DMP([QQ(1), QQ(1)], QQ), DMP([QQ(1), QQ(2)], QQ)) + assert Poly(x + 1, x)._unify(Poly(x + 2, x, domain='QQ'))[2:] ==\ + (DMP([QQ(1), QQ(1)], QQ), DMP([QQ(1), QQ(2)], QQ)) + + assert Poly(x + 1, x)._unify(Poly(x + 2, x, y))[2:] ==\ + (DMP([[ZZ(1)], [ZZ(1)]], ZZ), DMP([[ZZ(1)], [ZZ(2)]], ZZ)) + assert Poly(x + 1, x, domain='QQ')._unify(Poly(x + 2, x, y))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + assert Poly(x + 1, x)._unify(Poly(x + 2, x, y, domain='QQ'))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + + assert Poly(x + 1, x, y)._unify(Poly(x + 2, x))[2:] ==\ + (DMP([[ZZ(1)], [ZZ(1)]], ZZ), DMP([[ZZ(1)], [ZZ(2)]], ZZ)) + assert Poly(x + 1, x, y, domain='QQ')._unify(Poly(x + 2, x))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + assert Poly(x + 1, x, y)._unify(Poly(x + 2, x, domain='QQ'))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + + assert Poly(x + 1, x, y)._unify(Poly(x + 2, x, y))[2:] ==\ + (DMP([[ZZ(1)], [ZZ(1)]], ZZ), DMP([[ZZ(1)], [ZZ(2)]], ZZ)) + assert Poly(x + 1, x, y, domain='QQ')._unify(Poly(x + 2, x, y))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + assert Poly(x + 1, x, y)._unify(Poly(x + 2, x, y, domain='QQ'))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + + assert Poly(x + 1, x)._unify(Poly(x + 2, y, x))[2:] ==\ + (DMP([[ZZ(1), ZZ(1)]], ZZ), DMP([[ZZ(1), ZZ(2)]], ZZ)) + assert Poly(x + 1, x, domain='QQ')._unify(Poly(x + 2, y, x))[2:] ==\ + (DMP([[QQ(1), QQ(1)]], QQ), DMP([[QQ(1), QQ(2)]], QQ)) + assert Poly(x + 1, x)._unify(Poly(x + 2, y, x, domain='QQ'))[2:] ==\ + (DMP([[QQ(1), QQ(1)]], QQ), DMP([[QQ(1), QQ(2)]], QQ)) + + assert Poly(x + 1, y, x)._unify(Poly(x + 2, x))[2:] ==\ + (DMP([[ZZ(1), ZZ(1)]], ZZ), DMP([[ZZ(1), ZZ(2)]], ZZ)) + assert Poly(x + 1, y, x, domain='QQ')._unify(Poly(x + 2, x))[2:] ==\ + (DMP([[QQ(1), QQ(1)]], QQ), DMP([[QQ(1), QQ(2)]], QQ)) + assert Poly(x + 1, y, x)._unify(Poly(x + 2, x, domain='QQ'))[2:] ==\ + (DMP([[QQ(1), QQ(1)]], QQ), DMP([[QQ(1), QQ(2)]], QQ)) + + assert Poly(x + 1, x, y)._unify(Poly(x + 2, y, x))[2:] ==\ + (DMP([[ZZ(1)], [ZZ(1)]], ZZ), DMP([[ZZ(1)], [ZZ(2)]], ZZ)) + assert Poly(x + 1, x, y, domain='QQ')._unify(Poly(x + 2, y, x))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + assert Poly(x + 1, x, y)._unify(Poly(x + 2, y, x, domain='QQ'))[2:] ==\ + (DMP([[QQ(1)], [QQ(1)]], QQ), DMP([[QQ(1)], [QQ(2)]], QQ)) + + assert Poly(x + 1, y, x)._unify(Poly(x + 2, x, y))[2:] ==\ + (DMP([[ZZ(1), ZZ(1)]], ZZ), DMP([[ZZ(1), ZZ(2)]], ZZ)) + assert Poly(x + 1, y, x, domain='QQ')._unify(Poly(x + 2, x, y))[2:] ==\ + (DMP([[QQ(1), QQ(1)]], QQ), DMP([[QQ(1), QQ(2)]], QQ)) + assert Poly(x + 1, y, x)._unify(Poly(x + 2, x, y, domain='QQ'))[2:] ==\ + (DMP([[QQ(1), QQ(1)]], QQ), DMP([[QQ(1), QQ(2)]], QQ)) + + assert Poly(x**2 + I, x, domain=ZZ_I).unify(Poly(x**2 + sqrt(2), x, extension=True)) == \ + (Poly(x**2 + I, x, domain='QQ'), Poly(x**2 + sqrt(2), x, domain='QQ')) + + F, A, B = field("a,b", ZZ) + + assert Poly(a*x, x, domain='ZZ[a]')._unify(Poly(a*b*x, x, domain='ZZ(a,b)'))[2:] == \ + (DMP([A, F(0)], F.to_domain()), DMP([A*B, F(0)], F.to_domain())) + + assert Poly(a*x, x, domain='ZZ(a)')._unify(Poly(a*b*x, x, domain='ZZ(a,b)'))[2:] == \ + (DMP([A, F(0)], F.to_domain()), DMP([A*B, F(0)], F.to_domain())) + + raises(CoercionFailed, lambda: Poly(Poly(x**2 + x**2*z, y, field=True), domain='ZZ(x)')) + + f = Poly(t**2 + t/3 + x, t, domain='QQ(x)') + g = Poly(t**2 + t/3 + x, t, domain='QQ[x]') + + assert f._unify(g)[2:] == (f.rep, f.rep) + + +def test_Poly_free_symbols(): + assert Poly(x**2 + 1).free_symbols == {x} + assert Poly(x**2 + y*z).free_symbols == {x, y, z} + assert Poly(x**2 + y*z, x).free_symbols == {x, y, z} + assert Poly(x**2 + sin(y*z)).free_symbols == {x, y, z} + assert Poly(x**2 + sin(y*z), x).free_symbols == {x, y, z} + assert Poly(x**2 + sin(y*z), x, domain=EX).free_symbols == {x, y, z} + assert Poly(1 + x + x**2, x, y, z).free_symbols == {x} + assert Poly(x + sin(y), z).free_symbols == {x, y} + + +def test_PurePoly_free_symbols(): + assert PurePoly(x**2 + 1).free_symbols == set() + assert PurePoly(x**2 + y*z).free_symbols == set() + assert PurePoly(x**2 + y*z, x).free_symbols == {y, z} + assert PurePoly(x**2 + sin(y*z)).free_symbols == set() + assert PurePoly(x**2 + sin(y*z), x).free_symbols == {y, z} + assert PurePoly(x**2 + sin(y*z), x, domain=EX).free_symbols == {y, z} + + +def test_Poly__eq__(): + assert (Poly(x, x) == Poly(x, x)) is True + assert (Poly(x, x, domain=QQ) == Poly(x, x)) is False + assert (Poly(x, x) == Poly(x, x, domain=QQ)) is False + + assert (Poly(x, x, domain=ZZ[a]) == Poly(x, x)) is False + assert (Poly(x, x) == Poly(x, x, domain=ZZ[a])) is False + + assert (Poly(x*y, x, y) == Poly(x, x)) is False + + assert (Poly(x, x, y) == Poly(x, x)) is False + assert (Poly(x, x) == Poly(x, x, y)) is False + + assert (Poly(x**2 + 1, x) == Poly(y**2 + 1, y)) is False + assert (Poly(y**2 + 1, y) == Poly(x**2 + 1, x)) is False + + f = Poly(x, x, domain=ZZ) + g = Poly(x, x, domain=QQ) + + assert f.eq(g) is False + assert f.ne(g) is True + + assert f.eq(g, strict=True) is False + assert f.ne(g, strict=True) is True + + t0 = Symbol('t0') + + f = Poly((t0/2 + x**2)*t**2 - x**2*t, t, domain='QQ[x,t0]') + g = Poly((t0/2 + x**2)*t**2 - x**2*t, t, domain='ZZ(x,t0)') + + assert (f == g) is False + + +def test_PurePoly__eq__(): + assert (PurePoly(x, x) == PurePoly(x, x)) is True + assert (PurePoly(x, x, domain=QQ) == PurePoly(x, x)) is True + assert (PurePoly(x, x) == PurePoly(x, x, domain=QQ)) is True + + assert (PurePoly(x, x, domain=ZZ[a]) == PurePoly(x, x)) is True + assert (PurePoly(x, x) == PurePoly(x, x, domain=ZZ[a])) is True + + assert (PurePoly(x*y, x, y) == PurePoly(x, x)) is False + + assert (PurePoly(x, x, y) == PurePoly(x, x)) is False + assert (PurePoly(x, x) == PurePoly(x, x, y)) is False + + assert (PurePoly(x**2 + 1, x) == PurePoly(y**2 + 1, y)) is True + assert (PurePoly(y**2 + 1, y) == PurePoly(x**2 + 1, x)) is True + + f = PurePoly(x, x, domain=ZZ) + g = PurePoly(x, x, domain=QQ) + + assert f.eq(g) is True + assert f.ne(g) is False + + assert f.eq(g, strict=True) is False + assert f.ne(g, strict=True) is True + + f = PurePoly(x, x, domain=ZZ) + g = PurePoly(y, y, domain=QQ) + + assert f.eq(g) is True + assert f.ne(g) is False + + assert f.eq(g, strict=True) is False + assert f.ne(g, strict=True) is True + + +def test_PurePoly_Poly(): + assert isinstance(PurePoly(Poly(x**2 + 1)), PurePoly) is True + assert isinstance(Poly(PurePoly(x**2 + 1)), Poly) is True + + +def test_Poly_get_domain(): + assert Poly(2*x).get_domain() == ZZ + + assert Poly(2*x, domain='ZZ').get_domain() == ZZ + assert Poly(2*x, domain='QQ').get_domain() == QQ + + assert Poly(x/2).get_domain() == QQ + + raises(CoercionFailed, lambda: Poly(x/2, domain='ZZ')) + assert Poly(x/2, domain='QQ').get_domain() == QQ + + assert isinstance(Poly(0.2*x).get_domain(), RealField) + + +def test_Poly_set_domain(): + assert Poly(2*x + 1).set_domain(ZZ) == Poly(2*x + 1) + assert Poly(2*x + 1).set_domain('ZZ') == Poly(2*x + 1) + + assert Poly(2*x + 1).set_domain(QQ) == Poly(2*x + 1, domain='QQ') + assert Poly(2*x + 1).set_domain('QQ') == Poly(2*x + 1, domain='QQ') + + assert Poly(Rational(2, 10)*x + Rational(1, 10)).set_domain('RR') == Poly(0.2*x + 0.1) + assert Poly(0.2*x + 0.1).set_domain('QQ') == Poly(Rational(2, 10)*x + Rational(1, 10)) + + raises(CoercionFailed, lambda: Poly(x/2 + 1).set_domain(ZZ)) + raises(CoercionFailed, lambda: Poly(x + 1, modulus=2).set_domain(QQ)) + + raises(GeneratorsError, lambda: Poly(x*y, x, y).set_domain(ZZ[y])) + + +def test_Poly_get_modulus(): + assert Poly(x**2 + 1, modulus=2).get_modulus() == 2 + raises(PolynomialError, lambda: Poly(x**2 + 1).get_modulus()) + + +def test_Poly_set_modulus(): + assert Poly( + x**2 + 1, modulus=2).set_modulus(7) == Poly(x**2 + 1, modulus=7) + assert Poly( + x**2 + 5, modulus=7).set_modulus(2) == Poly(x**2 + 1, modulus=2) + + assert Poly(x**2 + 1).set_modulus(2) == Poly(x**2 + 1, modulus=2) + + raises(CoercionFailed, lambda: Poly(x/2 + 1).set_modulus(2)) + + +def test_Poly_add_ground(): + assert Poly(x + 1).add_ground(2) == Poly(x + 3) + + +def test_Poly_sub_ground(): + assert Poly(x + 1).sub_ground(2) == Poly(x - 1) + + +def test_Poly_mul_ground(): + assert Poly(x + 1).mul_ground(2) == Poly(2*x + 2) + + +def test_Poly_quo_ground(): + assert Poly(2*x + 4).quo_ground(2) == Poly(x + 2) + assert Poly(2*x + 3).quo_ground(2) == Poly(x + 1) + + +def test_Poly_exquo_ground(): + assert Poly(2*x + 4).exquo_ground(2) == Poly(x + 2) + raises(ExactQuotientFailed, lambda: Poly(2*x + 3).exquo_ground(2)) + + +def test_Poly_abs(): + assert Poly(-x + 1, x).abs() == abs(Poly(-x + 1, x)) == Poly(x + 1, x) + + +def test_Poly_neg(): + assert Poly(-x + 1, x).neg() == -Poly(-x + 1, x) == Poly(x - 1, x) + + +def test_Poly_add(): + assert Poly(0, x).add(Poly(0, x)) == Poly(0, x) + assert Poly(0, x) + Poly(0, x) == Poly(0, x) + + assert Poly(1, x).add(Poly(0, x)) == Poly(1, x) + assert Poly(1, x, y) + Poly(0, x) == Poly(1, x, y) + assert Poly(0, x).add(Poly(1, x, y)) == Poly(1, x, y) + assert Poly(0, x, y) + Poly(1, x, y) == Poly(1, x, y) + + assert Poly(1, x) + x == Poly(x + 1, x) + with warns_deprecated_sympy(): + Poly(1, x) + sin(x) + + assert Poly(x, x) + 1 == Poly(x + 1, x) + assert 1 + Poly(x, x) == Poly(x + 1, x) + + +def test_Poly_sub(): + assert Poly(0, x).sub(Poly(0, x)) == Poly(0, x) + assert Poly(0, x) - Poly(0, x) == Poly(0, x) + + assert Poly(1, x).sub(Poly(0, x)) == Poly(1, x) + assert Poly(1, x, y) - Poly(0, x) == Poly(1, x, y) + assert Poly(0, x).sub(Poly(1, x, y)) == Poly(-1, x, y) + assert Poly(0, x, y) - Poly(1, x, y) == Poly(-1, x, y) + + assert Poly(1, x) - x == Poly(1 - x, x) + with warns_deprecated_sympy(): + Poly(1, x) - sin(x) + + assert Poly(x, x) - 1 == Poly(x - 1, x) + assert 1 - Poly(x, x) == Poly(1 - x, x) + + +def test_Poly_mul(): + assert Poly(0, x).mul(Poly(0, x)) == Poly(0, x) + assert Poly(0, x) * Poly(0, x) == Poly(0, x) + + assert Poly(2, x).mul(Poly(4, x)) == Poly(8, x) + assert Poly(2, x, y) * Poly(4, x) == Poly(8, x, y) + assert Poly(4, x).mul(Poly(2, x, y)) == Poly(8, x, y) + assert Poly(4, x, y) * Poly(2, x, y) == Poly(8, x, y) + + assert Poly(1, x) * x == Poly(x, x) + with warns_deprecated_sympy(): + Poly(1, x) * sin(x) + + assert Poly(x, x) * 2 == Poly(2*x, x) + assert 2 * Poly(x, x) == Poly(2*x, x) + +def test_issue_13079(): + assert Poly(x)*x == Poly(x**2, x, domain='ZZ') + assert x*Poly(x) == Poly(x**2, x, domain='ZZ') + assert -2*Poly(x) == Poly(-2*x, x, domain='ZZ') + assert S(-2)*Poly(x) == Poly(-2*x, x, domain='ZZ') + assert Poly(x)*S(-2) == Poly(-2*x, x, domain='ZZ') + +def test_Poly_sqr(): + assert Poly(x*y, x, y).sqr() == Poly(x**2*y**2, x, y) + + +def test_Poly_pow(): + assert Poly(x, x).pow(10) == Poly(x**10, x) + assert Poly(x, x).pow(Integer(10)) == Poly(x**10, x) + + assert Poly(2*y, x, y).pow(4) == Poly(16*y**4, x, y) + assert Poly(2*y, x, y).pow(Integer(4)) == Poly(16*y**4, x, y) + + assert Poly(7*x*y, x, y)**3 == Poly(343*x**3*y**3, x, y) + + raises(TypeError, lambda: Poly(x*y + 1, x, y)**(-1)) + raises(TypeError, lambda: Poly(x*y + 1, x, y)**x) + + +def test_Poly_divmod(): + f, g = Poly(x**2), Poly(x) + q, r = g, Poly(0, x) + + assert divmod(f, g) == (q, r) + assert f // g == q + assert f % g == r + + assert divmod(f, x) == (q, r) + assert f // x == q + assert f % x == r + + q, r = Poly(0, x), Poly(2, x) + + assert divmod(2, g) == (q, r) + assert 2 // g == q + assert 2 % g == r + + assert Poly(x)/Poly(x) == 1 + assert Poly(x**2)/Poly(x) == x + assert Poly(x)/Poly(x**2) == 1/x + + +def test_Poly_eq_ne(): + assert (Poly(x + y, x, y) == Poly(x + y, x, y)) is True + assert (Poly(x + y, x) == Poly(x + y, x, y)) is False + assert (Poly(x + y, x, y) == Poly(x + y, x)) is False + assert (Poly(x + y, x) == Poly(x + y, x)) is True + assert (Poly(x + y, y) == Poly(x + y, y)) is True + + assert (Poly(x + y, x, y) == x + y) is True + assert (Poly(x + y, x) == x + y) is True + assert (Poly(x + y, x, y) == x + y) is True + assert (Poly(x + y, x) == x + y) is True + assert (Poly(x + y, y) == x + y) is True + + assert (Poly(x + y, x, y) != Poly(x + y, x, y)) is False + assert (Poly(x + y, x) != Poly(x + y, x, y)) is True + assert (Poly(x + y, x, y) != Poly(x + y, x)) is True + assert (Poly(x + y, x) != Poly(x + y, x)) is False + assert (Poly(x + y, y) != Poly(x + y, y)) is False + + assert (Poly(x + y, x, y) != x + y) is False + assert (Poly(x + y, x) != x + y) is False + assert (Poly(x + y, x, y) != x + y) is False + assert (Poly(x + y, x) != x + y) is False + assert (Poly(x + y, y) != x + y) is False + + assert (Poly(x, x) == sin(x)) is False + assert (Poly(x, x) != sin(x)) is True + + +def test_Poly_nonzero(): + assert not bool(Poly(0, x)) is True + assert not bool(Poly(1, x)) is False + + +def test_Poly_properties(): + assert Poly(0, x).is_zero is True + assert Poly(1, x).is_zero is False + + assert Poly(1, x).is_one is True + assert Poly(2, x).is_one is False + + assert Poly(x - 1, x).is_sqf is True + assert Poly((x - 1)**2, x).is_sqf is False + + assert Poly(x - 1, x).is_monic is True + assert Poly(2*x - 1, x).is_monic is False + + assert Poly(3*x + 2, x).is_primitive is True + assert Poly(4*x + 2, x).is_primitive is False + + assert Poly(1, x).is_ground is True + assert Poly(x, x).is_ground is False + + assert Poly(x + y + z + 1).is_linear is True + assert Poly(x*y*z + 1).is_linear is False + + assert Poly(x*y + z + 1).is_quadratic is True + assert Poly(x*y*z + 1).is_quadratic is False + + assert Poly(x*y).is_monomial is True + assert Poly(x*y + 1).is_monomial is False + + assert Poly(x**2 + x*y).is_homogeneous is True + assert Poly(x**3 + x*y).is_homogeneous is False + + assert Poly(x).is_univariate is True + assert Poly(x*y).is_univariate is False + + assert Poly(x*y).is_multivariate is True + assert Poly(x).is_multivariate is False + + assert Poly( + x**16 + x**14 - x**10 + x**8 - x**6 + x**2 + 1).is_cyclotomic is False + assert Poly( + x**16 + x**14 - x**10 - x**8 - x**6 + x**2 + 1).is_cyclotomic is True + + +def test_Poly_is_irreducible(): + assert Poly(x**2 + x + 1).is_irreducible is True + assert Poly(x**2 + 2*x + 1).is_irreducible is False + + assert Poly(7*x + 3, modulus=11).is_irreducible is True + assert Poly(7*x**2 + 3*x + 1, modulus=11).is_irreducible is False + + +def test_Poly_subs(): + assert Poly(x + 1).subs(x, 0) == 1 + + assert Poly(x + 1).subs(x, x) == Poly(x + 1) + assert Poly(x + 1).subs(x, y) == Poly(y + 1) + + assert Poly(x*y, x).subs(y, x) == x**2 + assert Poly(x*y, x).subs(x, y) == y**2 + + +def test_Poly_replace(): + assert Poly(x + 1).replace(x) == Poly(x + 1) + assert Poly(x + 1).replace(y) == Poly(y + 1) + + raises(PolynomialError, lambda: Poly(x + y).replace(z)) + + assert Poly(x + 1).replace(x, x) == Poly(x + 1) + assert Poly(x + 1).replace(x, y) == Poly(y + 1) + + assert Poly(x + y).replace(x, x) == Poly(x + y) + assert Poly(x + y).replace(x, z) == Poly(z + y, z, y) + + assert Poly(x + y).replace(y, y) == Poly(x + y) + assert Poly(x + y).replace(y, z) == Poly(x + z, x, z) + assert Poly(x + y).replace(z, t) == Poly(x + y) + + raises(PolynomialError, lambda: Poly(x + y).replace(x, y)) + + assert Poly(x + y, x).replace(x, z) == Poly(z + y, z) + assert Poly(x + y, y).replace(y, z) == Poly(x + z, z) + + raises(PolynomialError, lambda: Poly(x + y, x).replace(x, y)) + raises(PolynomialError, lambda: Poly(x + y, y).replace(y, x)) + + +def test_Poly_reorder(): + raises(PolynomialError, lambda: Poly(x + y).reorder(x, z)) + + assert Poly(x + y, x, y).reorder(x, y) == Poly(x + y, x, y) + assert Poly(x + y, x, y).reorder(y, x) == Poly(x + y, y, x) + + assert Poly(x + y, y, x).reorder(x, y) == Poly(x + y, x, y) + assert Poly(x + y, y, x).reorder(y, x) == Poly(x + y, y, x) + + assert Poly(x + y, x, y).reorder(wrt=x) == Poly(x + y, x, y) + assert Poly(x + y, x, y).reorder(wrt=y) == Poly(x + y, y, x) + + +def test_Poly_ltrim(): + f = Poly(y**2 + y*z**2, x, y, z).ltrim(y) + assert f.as_expr() == y**2 + y*z**2 and f.gens == (y, z) + assert Poly(x*y - x, z, x, y).ltrim(1) == Poly(x*y - x, x, y) + + raises(PolynomialError, lambda: Poly(x*y**2 + y**2, x, y).ltrim(y)) + raises(PolynomialError, lambda: Poly(x*y - x, x, y).ltrim(-1)) + +def test_Poly_has_only_gens(): + assert Poly(x*y + 1, x, y, z).has_only_gens(x, y) is True + assert Poly(x*y + z, x, y, z).has_only_gens(x, y) is False + + raises(GeneratorsError, lambda: Poly(x*y**2 + y**2, x, y).has_only_gens(t)) + + +def test_Poly_to_ring(): + assert Poly(2*x + 1, domain='ZZ').to_ring() == Poly(2*x + 1, domain='ZZ') + assert Poly(2*x + 1, domain='QQ').to_ring() == Poly(2*x + 1, domain='ZZ') + + raises(CoercionFailed, lambda: Poly(x/2 + 1).to_ring()) + raises(DomainError, lambda: Poly(2*x + 1, modulus=3).to_ring()) + + +def test_Poly_to_field(): + assert Poly(2*x + 1, domain='ZZ').to_field() == Poly(2*x + 1, domain='QQ') + assert Poly(2*x + 1, domain='QQ').to_field() == Poly(2*x + 1, domain='QQ') + + assert Poly(x/2 + 1, domain='QQ').to_field() == Poly(x/2 + 1, domain='QQ') + assert Poly(2*x + 1, modulus=3).to_field() == Poly(2*x + 1, modulus=3) + + assert Poly(2.0*x + 1.0).to_field() == Poly(2.0*x + 1.0) + + +def test_Poly_to_exact(): + assert Poly(2*x).to_exact() == Poly(2*x) + assert Poly(x/2).to_exact() == Poly(x/2) + + assert Poly(0.1*x).to_exact() == Poly(x/10) + + +def test_Poly_retract(): + f = Poly(x**2 + 1, x, domain=QQ[y]) + + assert f.retract() == Poly(x**2 + 1, x, domain='ZZ') + assert f.retract(field=True) == Poly(x**2 + 1, x, domain='QQ') + + assert Poly(0, x, y).retract() == Poly(0, x, y) + + +def test_Poly_slice(): + f = Poly(x**3 + 2*x**2 + 3*x + 4) + + assert f.slice(0, 0) == Poly(0, x) + assert f.slice(0, 1) == Poly(4, x) + assert f.slice(0, 2) == Poly(3*x + 4, x) + assert f.slice(0, 3) == Poly(2*x**2 + 3*x + 4, x) + assert f.slice(0, 4) == Poly(x**3 + 2*x**2 + 3*x + 4, x) + + assert f.slice(x, 0, 0) == Poly(0, x) + assert f.slice(x, 0, 1) == Poly(4, x) + assert f.slice(x, 0, 2) == Poly(3*x + 4, x) + assert f.slice(x, 0, 3) == Poly(2*x**2 + 3*x + 4, x) + assert f.slice(x, 0, 4) == Poly(x**3 + 2*x**2 + 3*x + 4, x) + + g = Poly(x**3 + 1) + + assert g.slice(0, 3) == Poly(1, x) + + +def test_Poly_coeffs(): + assert Poly(0, x).coeffs() == [0] + assert Poly(1, x).coeffs() == [1] + + assert Poly(2*x + 1, x).coeffs() == [2, 1] + + assert Poly(7*x**2 + 2*x + 1, x).coeffs() == [7, 2, 1] + assert Poly(7*x**4 + 2*x + 1, x).coeffs() == [7, 2, 1] + + assert Poly(x*y**7 + 2*x**2*y**3).coeffs('lex') == [2, 1] + assert Poly(x*y**7 + 2*x**2*y**3).coeffs('grlex') == [1, 2] + + +def test_Poly_monoms(): + assert Poly(0, x).monoms() == [(0,)] + assert Poly(1, x).monoms() == [(0,)] + + assert Poly(2*x + 1, x).monoms() == [(1,), (0,)] + + assert Poly(7*x**2 + 2*x + 1, x).monoms() == [(2,), (1,), (0,)] + assert Poly(7*x**4 + 2*x + 1, x).monoms() == [(4,), (1,), (0,)] + + assert Poly(x*y**7 + 2*x**2*y**3).monoms('lex') == [(2, 3), (1, 7)] + assert Poly(x*y**7 + 2*x**2*y**3).monoms('grlex') == [(1, 7), (2, 3)] + + +def test_Poly_terms(): + assert Poly(0, x).terms() == [((0,), 0)] + assert Poly(1, x).terms() == [((0,), 1)] + + assert Poly(2*x + 1, x).terms() == [((1,), 2), ((0,), 1)] + + assert Poly(7*x**2 + 2*x + 1, x).terms() == [((2,), 7), ((1,), 2), ((0,), 1)] + assert Poly(7*x**4 + 2*x + 1, x).terms() == [((4,), 7), ((1,), 2), ((0,), 1)] + + assert Poly( + x*y**7 + 2*x**2*y**3).terms('lex') == [((2, 3), 2), ((1, 7), 1)] + assert Poly( + x*y**7 + 2*x**2*y**3).terms('grlex') == [((1, 7), 1), ((2, 3), 2)] + + +def test_Poly_all_coeffs(): + assert Poly(0, x).all_coeffs() == [0] + assert Poly(1, x).all_coeffs() == [1] + + assert Poly(2*x + 1, x).all_coeffs() == [2, 1] + + assert Poly(7*x**2 + 2*x + 1, x).all_coeffs() == [7, 2, 1] + assert Poly(7*x**4 + 2*x + 1, x).all_coeffs() == [7, 0, 0, 2, 1] + + +def test_Poly_all_monoms(): + assert Poly(0, x).all_monoms() == [(0,)] + assert Poly(1, x).all_monoms() == [(0,)] + + assert Poly(2*x + 1, x).all_monoms() == [(1,), (0,)] + + assert Poly(7*x**2 + 2*x + 1, x).all_monoms() == [(2,), (1,), (0,)] + assert Poly(7*x**4 + 2*x + 1, x).all_monoms() == [(4,), (3,), (2,), (1,), (0,)] + + +def test_Poly_all_terms(): + assert Poly(0, x).all_terms() == [((0,), 0)] + assert Poly(1, x).all_terms() == [((0,), 1)] + + assert Poly(2*x + 1, x).all_terms() == [((1,), 2), ((0,), 1)] + + assert Poly(7*x**2 + 2*x + 1, x).all_terms() == \ + [((2,), 7), ((1,), 2), ((0,), 1)] + assert Poly(7*x**4 + 2*x + 1, x).all_terms() == \ + [((4,), 7), ((3,), 0), ((2,), 0), ((1,), 2), ((0,), 1)] + + +def test_Poly_termwise(): + f = Poly(x**2 + 20*x + 400) + g = Poly(x**2 + 2*x + 4) + + def func(monom, coeff): + (k,) = monom + return coeff//10**(2 - k) + + assert f.termwise(func) == g + + def func(monom, coeff): + (k,) = monom + return (k,), coeff//10**(2 - k) + + assert f.termwise(func) == g + + +def test_Poly_length(): + assert Poly(0, x).length() == 0 + assert Poly(1, x).length() == 1 + assert Poly(x, x).length() == 1 + + assert Poly(x + 1, x).length() == 2 + assert Poly(x**2 + 1, x).length() == 2 + assert Poly(x**2 + x + 1, x).length() == 3 + + +def test_Poly_as_dict(): + assert Poly(0, x).as_dict() == {} + assert Poly(0, x, y, z).as_dict() == {} + + assert Poly(1, x).as_dict() == {(0,): 1} + assert Poly(1, x, y, z).as_dict() == {(0, 0, 0): 1} + + assert Poly(x**2 + 3, x).as_dict() == {(2,): 1, (0,): 3} + assert Poly(x**2 + 3, x, y, z).as_dict() == {(2, 0, 0): 1, (0, 0, 0): 3} + + assert Poly(3*x**2*y*z**3 + 4*x*y + 5*x*z).as_dict() == {(2, 1, 3): 3, + (1, 1, 0): 4, (1, 0, 1): 5} + + +def test_Poly_as_expr(): + assert Poly(0, x).as_expr() == 0 + assert Poly(0, x, y, z).as_expr() == 0 + + assert Poly(1, x).as_expr() == 1 + assert Poly(1, x, y, z).as_expr() == 1 + + assert Poly(x**2 + 3, x).as_expr() == x**2 + 3 + assert Poly(x**2 + 3, x, y, z).as_expr() == x**2 + 3 + + assert Poly( + 3*x**2*y*z**3 + 4*x*y + 5*x*z).as_expr() == 3*x**2*y*z**3 + 4*x*y + 5*x*z + + f = Poly(x**2 + 2*x*y**2 - y, x, y) + + assert f.as_expr() == -y + x**2 + 2*x*y**2 + + assert f.as_expr({x: 5}) == 25 - y + 10*y**2 + assert f.as_expr({y: 6}) == -6 + 72*x + x**2 + + assert f.as_expr({x: 5, y: 6}) == 379 + assert f.as_expr(5, 6) == 379 + + raises(GeneratorsError, lambda: f.as_expr({z: 7})) + + +def test_Poly_lift(): + p = Poly(x**4 - I*x + 17*I, x, gaussian=True) + assert p.lift() == Poly(x**8 + x**2 - 34*x + 289, x, domain='QQ') + + +def test_Poly_lift_multiple(): + + r1 = rootof(y**3 + y**2 - 1, 0) + r2 = rootof(z**5 + z**2 - 1, 0) + p = Poly(r1*x + 3*r1**2 - r2 + x**2 - x**5, x, extension=True) + + assert p.lift() == Poly( + -x**75 + 15*x**72 - 5*x**71 + 15*x**70 - 105*x**69 + 70*x**68 - + 220*x**67 + 560*x**66 - 635*x**65 + 1495*x**64 - 2735*x**63 + + 4415*x**62 - 7410*x**61 + 12741*x**60 - 22090*x**59 + 32125*x**58 - + 56281*x**57 + 88157*x**56 - 126842*x**55 + 214223*x**54 - 311802*x**53 + + 462667*x**52 - 700883*x**51 + 1006278*x**50 - 1480950*x**49 + + 2078055*x**48 - 3004675*x**47 + 4140410*x**46 - 5664222*x**45 + + 8029445*x**44 - 10528785*x**43 + 14309614*x**42 - 19032988*x**41 + + 24570573*x**40 - 32530459*x**39 + 41239581*x**38 - 52968051*x**37 + + 65891606*x**36 - 81997276*x**35 + 102530732*x**34 - 122009994*x**33 + + 150227996*x**32 - 176452478*x**31 + 206393768*x**30 - 245291426*x**29 + + 276598718*x**28 - 320005297*x**27 + 353649032*x**26 + - 393246309*x**25 + 434566186*x**24 - 460608964*x**23 + 508052079*x**22 + - 513976618*x**21 + 539374498*x**20 - 557851717*x**19 + 540788016*x**18 + - 564949060*x**17 + 520866566*x**16 + - 507861375*x**15 + 474999819*x**14 - 423619160*x**13 + 414540540*x**12 + - 322522367*x**11 + 311586511*x**10 - 238812299*x**9 + 184482053*x**8 + - 189265274*x**7 + 93619528*x**6 - 106852385*x**5 + 57294385*x**4 - + 26486666*x**3 + 42614683*x**2 - 1511583*x + 15975845, x, domain='QQ' + ) + + +def test_Poly_deflate(): + assert Poly(0, x).deflate() == ((1,), Poly(0, x)) + assert Poly(1, x).deflate() == ((1,), Poly(1, x)) + assert Poly(x, x).deflate() == ((1,), Poly(x, x)) + + assert Poly(x**2, x).deflate() == ((2,), Poly(x, x)) + assert Poly(x**17, x).deflate() == ((17,), Poly(x, x)) + + assert Poly( + x**2*y*z**11 + x**4*z**11).deflate() == ((2, 1, 11), Poly(x*y*z + x**2*z)) + + +def test_Poly_inject(): + f = Poly(x**2*y + x*y**3 + x*y + 1, x) + + assert f.inject() == Poly(x**2*y + x*y**3 + x*y + 1, x, y) + assert f.inject(front=True) == Poly(y**3*x + y*x**2 + y*x + 1, y, x) + + +def test_Poly_eject(): + f = Poly(x**2*y + x*y**3 + x*y + 1, x, y) + + assert f.eject(x) == Poly(x*y**3 + (x**2 + x)*y + 1, y, domain='ZZ[x]') + assert f.eject(y) == Poly(y*x**2 + (y**3 + y)*x + 1, x, domain='ZZ[y]') + + ex = x + y + z + t + w + g = Poly(ex, x, y, z, t, w) + + assert g.eject(x) == Poly(ex, y, z, t, w, domain='ZZ[x]') + assert g.eject(x, y) == Poly(ex, z, t, w, domain='ZZ[x, y]') + assert g.eject(x, y, z) == Poly(ex, t, w, domain='ZZ[x, y, z]') + assert g.eject(w) == Poly(ex, x, y, z, t, domain='ZZ[w]') + assert g.eject(t, w) == Poly(ex, x, y, z, domain='ZZ[t, w]') + assert g.eject(z, t, w) == Poly(ex, x, y, domain='ZZ[z, t, w]') + + raises(DomainError, lambda: Poly(x*y, x, y, domain=ZZ[z]).eject(y)) + raises(NotImplementedError, lambda: Poly(x*y, x, y, z).eject(y)) + + +def test_Poly_exclude(): + assert Poly(x, x, y).exclude() == Poly(x, x) + assert Poly(x*y, x, y).exclude() == Poly(x*y, x, y) + assert Poly(1, x, y).exclude() == Poly(1, x, y) + + +def test_Poly__gen_to_level(): + assert Poly(1, x, y)._gen_to_level(-2) == 0 + assert Poly(1, x, y)._gen_to_level(-1) == 1 + assert Poly(1, x, y)._gen_to_level( 0) == 0 + assert Poly(1, x, y)._gen_to_level( 1) == 1 + + raises(PolynomialError, lambda: Poly(1, x, y)._gen_to_level(-3)) + raises(PolynomialError, lambda: Poly(1, x, y)._gen_to_level( 2)) + + assert Poly(1, x, y)._gen_to_level(x) == 0 + assert Poly(1, x, y)._gen_to_level(y) == 1 + + assert Poly(1, x, y)._gen_to_level('x') == 0 + assert Poly(1, x, y)._gen_to_level('y') == 1 + + raises(PolynomialError, lambda: Poly(1, x, y)._gen_to_level(z)) + raises(PolynomialError, lambda: Poly(1, x, y)._gen_to_level('z')) + + +def test_Poly_degree(): + assert Poly(0, x).degree() is -oo + assert Poly(1, x).degree() == 0 + assert Poly(x, x).degree() == 1 + + assert Poly(0, x).degree(gen=0) is -oo + assert Poly(1, x).degree(gen=0) == 0 + assert Poly(x, x).degree(gen=0) == 1 + + assert Poly(0, x).degree(gen=x) is -oo + assert Poly(1, x).degree(gen=x) == 0 + assert Poly(x, x).degree(gen=x) == 1 + + assert Poly(0, x).degree(gen='x') is -oo + assert Poly(1, x).degree(gen='x') == 0 + assert Poly(x, x).degree(gen='x') == 1 + + raises(PolynomialError, lambda: Poly(1, x).degree(gen=1)) + raises(PolynomialError, lambda: Poly(1, x).degree(gen=y)) + raises(PolynomialError, lambda: Poly(1, x).degree(gen='y')) + + assert Poly(1, x, y).degree() == 0 + assert Poly(2*y, x, y).degree() == 0 + assert Poly(x*y, x, y).degree() == 1 + + assert Poly(1, x, y).degree(gen=x) == 0 + assert Poly(2*y, x, y).degree(gen=x) == 0 + assert Poly(x*y, x, y).degree(gen=x) == 1 + + assert Poly(1, x, y).degree(gen=y) == 0 + assert Poly(2*y, x, y).degree(gen=y) == 1 + assert Poly(x*y, x, y).degree(gen=y) == 1 + + assert degree(0, x) is -oo + assert degree(1, x) == 0 + assert degree(x, x) == 1 + + assert degree(x*y**2, x) == 1 + assert degree(x*y**2, y) == 2 + assert degree(x*y**2, z) == 0 + + assert degree(pi) == 1 + + raises(TypeError, lambda: degree(y**2 + x**3)) + raises(TypeError, lambda: degree(y**2 + x**3, 1)) + raises(PolynomialError, lambda: degree(x, 1.1)) + raises(PolynomialError, lambda: degree(x**2/(x**3 + 1), x)) + + assert degree(Poly(0,x),z) is -oo + assert degree(Poly(1,x),z) == 0 + assert degree(Poly(x**2+y**3,y)) == 3 + assert degree(Poly(y**2 + x**3, y, x), 1) == 3 + assert degree(Poly(y**2 + x**3, x), z) == 0 + assert degree(Poly(y**2 + x**3 + z**4, x), z) == 4 + +def test_Poly_degree_list(): + assert Poly(0, x).degree_list() == (-oo,) + assert Poly(0, x, y).degree_list() == (-oo, -oo) + assert Poly(0, x, y, z).degree_list() == (-oo, -oo, -oo) + + assert Poly(1, x).degree_list() == (0,) + assert Poly(1, x, y).degree_list() == (0, 0) + assert Poly(1, x, y, z).degree_list() == (0, 0, 0) + + assert Poly(x**2*y + x**3*z**2 + 1).degree_list() == (3, 1, 2) + + assert degree_list(1, x) == (0,) + assert degree_list(x, x) == (1,) + + assert degree_list(x*y**2) == (1, 2) + + raises(ComputationFailed, lambda: degree_list(1)) + + +def test_Poly_total_degree(): + assert Poly(x**2*y + x**3*z**2 + 1).total_degree() == 5 + assert Poly(x**2 + z**3).total_degree() == 3 + assert Poly(x*y*z + z**4).total_degree() == 4 + assert Poly(x**3 + x + 1).total_degree() == 3 + + assert total_degree(x*y + z**3) == 3 + assert total_degree(x*y + z**3, x, y) == 2 + assert total_degree(1) == 0 + assert total_degree(Poly(y**2 + x**3 + z**4)) == 4 + assert total_degree(Poly(y**2 + x**3 + z**4, x)) == 3 + assert total_degree(Poly(y**2 + x**3 + z**4, x), z) == 4 + assert total_degree(Poly(x**9 + x*z*y + x**3*z**2 + z**7,x), z) == 7 + +def test_Poly_homogenize(): + assert Poly(x**2+y).homogenize(z) == Poly(x**2+y*z) + assert Poly(x+y).homogenize(z) == Poly(x+y, x, y, z) + assert Poly(x+y**2).homogenize(y) == Poly(x*y+y**2) + + +def test_Poly_homogeneous_order(): + assert Poly(0, x, y).homogeneous_order() is -oo + assert Poly(1, x, y).homogeneous_order() == 0 + assert Poly(x, x, y).homogeneous_order() == 1 + assert Poly(x*y, x, y).homogeneous_order() == 2 + + assert Poly(x + 1, x, y).homogeneous_order() is None + assert Poly(x*y + x, x, y).homogeneous_order() is None + + assert Poly(x**5 + 2*x**3*y**2 + 9*x*y**4).homogeneous_order() == 5 + assert Poly(x**5 + 2*x**3*y**3 + 9*x*y**4).homogeneous_order() is None + + +def test_Poly_LC(): + assert Poly(0, x).LC() == 0 + assert Poly(1, x).LC() == 1 + assert Poly(2*x**2 + x, x).LC() == 2 + + assert Poly(x*y**7 + 2*x**2*y**3).LC('lex') == 2 + assert Poly(x*y**7 + 2*x**2*y**3).LC('grlex') == 1 + + assert LC(x*y**7 + 2*x**2*y**3, order='lex') == 2 + assert LC(x*y**7 + 2*x**2*y**3, order='grlex') == 1 + + +def test_Poly_TC(): + assert Poly(0, x).TC() == 0 + assert Poly(1, x).TC() == 1 + assert Poly(2*x**2 + x, x).TC() == 0 + + +def test_Poly_EC(): + assert Poly(0, x).EC() == 0 + assert Poly(1, x).EC() == 1 + assert Poly(2*x**2 + x, x).EC() == 1 + + assert Poly(x*y**7 + 2*x**2*y**3).EC('lex') == 1 + assert Poly(x*y**7 + 2*x**2*y**3).EC('grlex') == 2 + + +def test_Poly_coeff(): + assert Poly(0, x).coeff_monomial(1) == 0 + assert Poly(0, x).coeff_monomial(x) == 0 + + assert Poly(1, x).coeff_monomial(1) == 1 + assert Poly(1, x).coeff_monomial(x) == 0 + + assert Poly(x**8, x).coeff_monomial(1) == 0 + assert Poly(x**8, x).coeff_monomial(x**7) == 0 + assert Poly(x**8, x).coeff_monomial(x**8) == 1 + assert Poly(x**8, x).coeff_monomial(x**9) == 0 + + assert Poly(3*x*y**2 + 1, x, y).coeff_monomial(1) == 1 + assert Poly(3*x*y**2 + 1, x, y).coeff_monomial(x*y**2) == 3 + + p = Poly(24*x*y*exp(8) + 23*x, x, y) + + assert p.coeff_monomial(x) == 23 + assert p.coeff_monomial(y) == 0 + assert p.coeff_monomial(x*y) == 24*exp(8) + + assert p.as_expr().coeff(x) == 24*y*exp(8) + 23 + raises(NotImplementedError, lambda: p.coeff(x)) + + raises(ValueError, lambda: Poly(x + 1).coeff_monomial(0)) + raises(ValueError, lambda: Poly(x + 1).coeff_monomial(3*x)) + raises(ValueError, lambda: Poly(x + 1).coeff_monomial(3*x*y)) + + +def test_Poly_nth(): + assert Poly(0, x).nth(0) == 0 + assert Poly(0, x).nth(1) == 0 + + assert Poly(1, x).nth(0) == 1 + assert Poly(1, x).nth(1) == 0 + + assert Poly(x**8, x).nth(0) == 0 + assert Poly(x**8, x).nth(7) == 0 + assert Poly(x**8, x).nth(8) == 1 + assert Poly(x**8, x).nth(9) == 0 + + assert Poly(3*x*y**2 + 1, x, y).nth(0, 0) == 1 + assert Poly(3*x*y**2 + 1, x, y).nth(1, 2) == 3 + + raises(ValueError, lambda: Poly(x*y + 1, x, y).nth(1)) + + +def test_Poly_LM(): + assert Poly(0, x).LM() == (0,) + assert Poly(1, x).LM() == (0,) + assert Poly(2*x**2 + x, x).LM() == (2,) + + assert Poly(x*y**7 + 2*x**2*y**3).LM('lex') == (2, 3) + assert Poly(x*y**7 + 2*x**2*y**3).LM('grlex') == (1, 7) + + assert LM(x*y**7 + 2*x**2*y**3, order='lex') == x**2*y**3 + assert LM(x*y**7 + 2*x**2*y**3, order='grlex') == x*y**7 + + +def test_Poly_LM_custom_order(): + f = Poly(x**2*y**3*z + x**2*y*z**3 + x*y*z + 1) + rev_lex = lambda monom: tuple(reversed(monom)) + + assert f.LM(order='lex') == (2, 3, 1) + assert f.LM(order=rev_lex) == (2, 1, 3) + + +def test_Poly_EM(): + assert Poly(0, x).EM() == (0,) + assert Poly(1, x).EM() == (0,) + assert Poly(2*x**2 + x, x).EM() == (1,) + + assert Poly(x*y**7 + 2*x**2*y**3).EM('lex') == (1, 7) + assert Poly(x*y**7 + 2*x**2*y**3).EM('grlex') == (2, 3) + + +def test_Poly_LT(): + assert Poly(0, x).LT() == ((0,), 0) + assert Poly(1, x).LT() == ((0,), 1) + assert Poly(2*x**2 + x, x).LT() == ((2,), 2) + + assert Poly(x*y**7 + 2*x**2*y**3).LT('lex') == ((2, 3), 2) + assert Poly(x*y**7 + 2*x**2*y**3).LT('grlex') == ((1, 7), 1) + + assert LT(x*y**7 + 2*x**2*y**3, order='lex') == 2*x**2*y**3 + assert LT(x*y**7 + 2*x**2*y**3, order='grlex') == x*y**7 + + +def test_Poly_ET(): + assert Poly(0, x).ET() == ((0,), 0) + assert Poly(1, x).ET() == ((0,), 1) + assert Poly(2*x**2 + x, x).ET() == ((1,), 1) + + assert Poly(x*y**7 + 2*x**2*y**3).ET('lex') == ((1, 7), 1) + assert Poly(x*y**7 + 2*x**2*y**3).ET('grlex') == ((2, 3), 2) + + +def test_Poly_max_norm(): + assert Poly(-1, x).max_norm() == 1 + assert Poly( 0, x).max_norm() == 0 + assert Poly( 1, x).max_norm() == 1 + + +def test_Poly_l1_norm(): + assert Poly(-1, x).l1_norm() == 1 + assert Poly( 0, x).l1_norm() == 0 + assert Poly( 1, x).l1_norm() == 1 + + +def test_Poly_clear_denoms(): + coeff, poly = Poly(x + 2, x).clear_denoms() + assert coeff == 1 and poly == Poly( + x + 2, x, domain='ZZ') and poly.get_domain() == ZZ + + coeff, poly = Poly(x/2 + 1, x).clear_denoms() + assert coeff == 2 and poly == Poly( + x + 2, x, domain='QQ') and poly.get_domain() == QQ + + coeff, poly = Poly(2*x**2 + 3, modulus=5).clear_denoms() + assert coeff == 1 and poly == Poly( + 2*x**2 + 3, x, modulus=5) and poly.get_domain() == FF(5) + + coeff, poly = Poly(x/2 + 1, x).clear_denoms(convert=True) + assert coeff == 2 and poly == Poly( + x + 2, x, domain='ZZ') and poly.get_domain() == ZZ + + coeff, poly = Poly(x/y + 1, x).clear_denoms(convert=True) + assert coeff == y and poly == Poly( + x + y, x, domain='ZZ[y]') and poly.get_domain() == ZZ[y] + + coeff, poly = Poly(x/3 + sqrt(2), x, domain='EX').clear_denoms() + assert coeff == 3 and poly == Poly( + x + 3*sqrt(2), x, domain='EX') and poly.get_domain() == EX + + coeff, poly = Poly( + x/3 + sqrt(2), x, domain='EX').clear_denoms(convert=True) + assert coeff == 3 and poly == Poly( + x + 3*sqrt(2), x, domain='EX') and poly.get_domain() == EX + + +def test_Poly_rat_clear_denoms(): + f = Poly(x**2/y + 1, x) + g = Poly(x**3 + y, x) + + assert f.rat_clear_denoms(g) == \ + (Poly(x**2 + y, x), Poly(y*x**3 + y**2, x)) + + f = f.set_domain(EX) + g = g.set_domain(EX) + + assert f.rat_clear_denoms(g) == (f, g) + + +def test_issue_20427(): + f = Poly(-117968192370600*18**(S(1)/3)/(217603955769048*(24201 + + 253*sqrt(9165))**(S(1)/3) + 2273005839412*sqrt(9165)*(24201 + + 253*sqrt(9165))**(S(1)/3)) - 15720318185*2**(S(2)/3)*3**(S(1)/3)*(24201 + + 253*sqrt(9165))**(S(2)/3)/(217603955769048*(24201 + 253*sqrt(9165))** + (S(1)/3) + 2273005839412*sqrt(9165)*(24201 + 253*sqrt(9165))**(S(1)/3)) + + 15720318185*12**(S(1)/3)*(24201 + 253*sqrt(9165))**(S(2)/3)/( + 217603955769048*(24201 + 253*sqrt(9165))**(S(1)/3) + 2273005839412* + sqrt(9165)*(24201 + 253*sqrt(9165))**(S(1)/3)) + 117968192370600*2**( + S(1)/3)*3**(S(2)/3)/(217603955769048*(24201 + 253*sqrt(9165))**(S(1)/3) + + 2273005839412*sqrt(9165)*(24201 + 253*sqrt(9165))**(S(1)/3)), x) + assert f == Poly(0, x, domain='EX') + + +def test_Poly_integrate(): + assert Poly(x + 1).integrate() == Poly(x**2/2 + x) + assert Poly(x + 1).integrate(x) == Poly(x**2/2 + x) + assert Poly(x + 1).integrate((x, 1)) == Poly(x**2/2 + x) + + assert Poly(x*y + 1).integrate(x) == Poly(x**2*y/2 + x) + assert Poly(x*y + 1).integrate(y) == Poly(x*y**2/2 + y) + + assert Poly(x*y + 1).integrate(x, x) == Poly(x**3*y/6 + x**2/2) + assert Poly(x*y + 1).integrate(y, y) == Poly(x*y**3/6 + y**2/2) + + assert Poly(x*y + 1).integrate((x, 2)) == Poly(x**3*y/6 + x**2/2) + assert Poly(x*y + 1).integrate((y, 2)) == Poly(x*y**3/6 + y**2/2) + + assert Poly(x*y + 1).integrate(x, y) == Poly(x**2*y**2/4 + x*y) + assert Poly(x*y + 1).integrate(y, x) == Poly(x**2*y**2/4 + x*y) + + +def test_Poly_diff(): + assert Poly(x**2 + x).diff() == Poly(2*x + 1) + assert Poly(x**2 + x).diff(x) == Poly(2*x + 1) + assert Poly(x**2 + x).diff((x, 1)) == Poly(2*x + 1) + + assert Poly(x**2*y**2 + x*y).diff(x) == Poly(2*x*y**2 + y) + assert Poly(x**2*y**2 + x*y).diff(y) == Poly(2*x**2*y + x) + + assert Poly(x**2*y**2 + x*y).diff(x, x) == Poly(2*y**2, x, y) + assert Poly(x**2*y**2 + x*y).diff(y, y) == Poly(2*x**2, x, y) + + assert Poly(x**2*y**2 + x*y).diff((x, 2)) == Poly(2*y**2, x, y) + assert Poly(x**2*y**2 + x*y).diff((y, 2)) == Poly(2*x**2, x, y) + + assert Poly(x**2*y**2 + x*y).diff(x, y) == Poly(4*x*y + 1) + assert Poly(x**2*y**2 + x*y).diff(y, x) == Poly(4*x*y + 1) + + +def test_issue_9585(): + assert diff(Poly(x**2 + x)) == Poly(2*x + 1) + assert diff(Poly(x**2 + x), x, evaluate=False) == \ + Derivative(Poly(x**2 + x), x) + assert Derivative(Poly(x**2 + x), x).doit() == Poly(2*x + 1) + + +def test_Poly_eval(): + assert Poly(0, x).eval(7) == 0 + assert Poly(1, x).eval(7) == 1 + assert Poly(x, x).eval(7) == 7 + + assert Poly(0, x).eval(0, 7) == 0 + assert Poly(1, x).eval(0, 7) == 1 + assert Poly(x, x).eval(0, 7) == 7 + + assert Poly(0, x).eval(x, 7) == 0 + assert Poly(1, x).eval(x, 7) == 1 + assert Poly(x, x).eval(x, 7) == 7 + + assert Poly(0, x).eval('x', 7) == 0 + assert Poly(1, x).eval('x', 7) == 1 + assert Poly(x, x).eval('x', 7) == 7 + + raises(PolynomialError, lambda: Poly(1, x).eval(1, 7)) + raises(PolynomialError, lambda: Poly(1, x).eval(y, 7)) + raises(PolynomialError, lambda: Poly(1, x).eval('y', 7)) + + assert Poly(123, x, y).eval(7) == Poly(123, y) + assert Poly(2*y, x, y).eval(7) == Poly(2*y, y) + assert Poly(x*y, x, y).eval(7) == Poly(7*y, y) + + assert Poly(123, x, y).eval(x, 7) == Poly(123, y) + assert Poly(2*y, x, y).eval(x, 7) == Poly(2*y, y) + assert Poly(x*y, x, y).eval(x, 7) == Poly(7*y, y) + + assert Poly(123, x, y).eval(y, 7) == Poly(123, x) + assert Poly(2*y, x, y).eval(y, 7) == Poly(14, x) + assert Poly(x*y, x, y).eval(y, 7) == Poly(7*x, x) + + assert Poly(x*y + y, x, y).eval({x: 7}) == Poly(8*y, y) + assert Poly(x*y + y, x, y).eval({y: 7}) == Poly(7*x + 7, x) + + assert Poly(x*y + y, x, y).eval({x: 6, y: 7}) == 49 + assert Poly(x*y + y, x, y).eval({x: 7, y: 6}) == 48 + + assert Poly(x*y + y, x, y).eval((6, 7)) == 49 + assert Poly(x*y + y, x, y).eval([6, 7]) == 49 + + assert Poly(x + 1, domain='ZZ').eval(S.Half) == Rational(3, 2) + assert Poly(x + 1, domain='ZZ').eval(sqrt(2)) == sqrt(2) + 1 + + raises(ValueError, lambda: Poly(x*y + y, x, y).eval((6, 7, 8))) + raises(DomainError, lambda: Poly(x + 1, domain='ZZ').eval(S.Half, auto=False)) + + # issue 6344 + alpha = Symbol('alpha') + result = (2*alpha*z - 2*alpha + z**2 + 3)/(z**2 - 2*z + 1) + + f = Poly(x**2 + (alpha - 1)*x - alpha + 1, x, domain='ZZ[alpha]') + assert f.eval((z + 1)/(z - 1)) == result + + g = Poly(x**2 + (alpha - 1)*x - alpha + 1, x, y, domain='ZZ[alpha]') + assert g.eval((z + 1)/(z - 1)) == Poly(result, y, domain='ZZ(alpha,z)') + +def test_Poly___call__(): + f = Poly(2*x*y + 3*x + y + 2*z) + + assert f(2) == Poly(5*y + 2*z + 6) + assert f(2, 5) == Poly(2*z + 31) + assert f(2, 5, 7) == 45 + + +def test_parallel_poly_from_expr(): + assert parallel_poly_from_expr( + [x - 1, x**2 - 1], x)[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [Poly(x - 1, x), x**2 - 1], x)[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [x - 1, Poly(x**2 - 1, x)], x)[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr([Poly( + x - 1, x), Poly(x**2 - 1, x)], x)[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + + assert parallel_poly_from_expr( + [x - 1, x**2 - 1], x, y)[0] == [Poly(x - 1, x, y), Poly(x**2 - 1, x, y)] + assert parallel_poly_from_expr([Poly( + x - 1, x), x**2 - 1], x, y)[0] == [Poly(x - 1, x, y), Poly(x**2 - 1, x, y)] + assert parallel_poly_from_expr([x - 1, Poly( + x**2 - 1, x)], x, y)[0] == [Poly(x - 1, x, y), Poly(x**2 - 1, x, y)] + assert parallel_poly_from_expr([Poly(x - 1, x), Poly( + x**2 - 1, x)], x, y)[0] == [Poly(x - 1, x, y), Poly(x**2 - 1, x, y)] + + assert parallel_poly_from_expr( + [x - 1, x**2 - 1])[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [Poly(x - 1, x), x**2 - 1])[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [x - 1, Poly(x**2 - 1, x)])[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [Poly(x - 1, x), Poly(x**2 - 1, x)])[0] == [Poly(x - 1, x), Poly(x**2 - 1, x)] + + assert parallel_poly_from_expr( + [1, x**2 - 1])[0] == [Poly(1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [1, x**2 - 1])[0] == [Poly(1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [1, Poly(x**2 - 1, x)])[0] == [Poly(1, x), Poly(x**2 - 1, x)] + assert parallel_poly_from_expr( + [1, Poly(x**2 - 1, x)])[0] == [Poly(1, x), Poly(x**2 - 1, x)] + + assert parallel_poly_from_expr( + [x**2 - 1, 1])[0] == [Poly(x**2 - 1, x), Poly(1, x)] + assert parallel_poly_from_expr( + [x**2 - 1, 1])[0] == [Poly(x**2 - 1, x), Poly(1, x)] + assert parallel_poly_from_expr( + [Poly(x**2 - 1, x), 1])[0] == [Poly(x**2 - 1, x), Poly(1, x)] + assert parallel_poly_from_expr( + [Poly(x**2 - 1, x), 1])[0] == [Poly(x**2 - 1, x), Poly(1, x)] + + assert parallel_poly_from_expr([Poly(x, x, y), Poly(y, x, y)], x, y, order='lex')[0] == \ + [Poly(x, x, y, domain='ZZ'), Poly(y, x, y, domain='ZZ')] + + raises(PolificationFailed, lambda: parallel_poly_from_expr([0, 1])) + + +def test_pdiv(): + f, g = x**2 - y**2, x - y + q, r = x + y, 0 + + F, G, Q, R = [ Poly(h, x, y) for h in (f, g, q, r) ] + + assert F.pdiv(G) == (Q, R) + assert F.prem(G) == R + assert F.pquo(G) == Q + assert F.pexquo(G) == Q + + assert pdiv(f, g) == (q, r) + assert prem(f, g) == r + assert pquo(f, g) == q + assert pexquo(f, g) == q + + assert pdiv(f, g, x, y) == (q, r) + assert prem(f, g, x, y) == r + assert pquo(f, g, x, y) == q + assert pexquo(f, g, x, y) == q + + assert pdiv(f, g, (x, y)) == (q, r) + assert prem(f, g, (x, y)) == r + assert pquo(f, g, (x, y)) == q + assert pexquo(f, g, (x, y)) == q + + assert pdiv(F, G) == (Q, R) + assert prem(F, G) == R + assert pquo(F, G) == Q + assert pexquo(F, G) == Q + + assert pdiv(f, g, polys=True) == (Q, R) + assert prem(f, g, polys=True) == R + assert pquo(f, g, polys=True) == Q + assert pexquo(f, g, polys=True) == Q + + assert pdiv(F, G, polys=False) == (q, r) + assert prem(F, G, polys=False) == r + assert pquo(F, G, polys=False) == q + assert pexquo(F, G, polys=False) == q + + raises(ComputationFailed, lambda: pdiv(4, 2)) + raises(ComputationFailed, lambda: prem(4, 2)) + raises(ComputationFailed, lambda: pquo(4, 2)) + raises(ComputationFailed, lambda: pexquo(4, 2)) + + +def test_div(): + f, g = x**2 - y**2, x - y + q, r = x + y, 0 + + F, G, Q, R = [ Poly(h, x, y) for h in (f, g, q, r) ] + + assert F.div(G) == (Q, R) + assert F.rem(G) == R + assert F.quo(G) == Q + assert F.exquo(G) == Q + + assert div(f, g) == (q, r) + assert rem(f, g) == r + assert quo(f, g) == q + assert exquo(f, g) == q + + assert div(f, g, x, y) == (q, r) + assert rem(f, g, x, y) == r + assert quo(f, g, x, y) == q + assert exquo(f, g, x, y) == q + + assert div(f, g, (x, y)) == (q, r) + assert rem(f, g, (x, y)) == r + assert quo(f, g, (x, y)) == q + assert exquo(f, g, (x, y)) == q + + assert div(F, G) == (Q, R) + assert rem(F, G) == R + assert quo(F, G) == Q + assert exquo(F, G) == Q + + assert div(f, g, polys=True) == (Q, R) + assert rem(f, g, polys=True) == R + assert quo(f, g, polys=True) == Q + assert exquo(f, g, polys=True) == Q + + assert div(F, G, polys=False) == (q, r) + assert rem(F, G, polys=False) == r + assert quo(F, G, polys=False) == q + assert exquo(F, G, polys=False) == q + + raises(ComputationFailed, lambda: div(4, 2)) + raises(ComputationFailed, lambda: rem(4, 2)) + raises(ComputationFailed, lambda: quo(4, 2)) + raises(ComputationFailed, lambda: exquo(4, 2)) + + f, g = x**2 + 1, 2*x - 4 + + qz, rz = 0, x**2 + 1 + qq, rq = x/2 + 1, 5 + + assert div(f, g) == (qq, rq) + assert div(f, g, auto=True) == (qq, rq) + assert div(f, g, auto=False) == (qz, rz) + assert div(f, g, domain=ZZ) == (qz, rz) + assert div(f, g, domain=QQ) == (qq, rq) + assert div(f, g, domain=ZZ, auto=True) == (qq, rq) + assert div(f, g, domain=ZZ, auto=False) == (qz, rz) + assert div(f, g, domain=QQ, auto=True) == (qq, rq) + assert div(f, g, domain=QQ, auto=False) == (qq, rq) + + assert rem(f, g) == rq + assert rem(f, g, auto=True) == rq + assert rem(f, g, auto=False) == rz + assert rem(f, g, domain=ZZ) == rz + assert rem(f, g, domain=QQ) == rq + assert rem(f, g, domain=ZZ, auto=True) == rq + assert rem(f, g, domain=ZZ, auto=False) == rz + assert rem(f, g, domain=QQ, auto=True) == rq + assert rem(f, g, domain=QQ, auto=False) == rq + + assert quo(f, g) == qq + assert quo(f, g, auto=True) == qq + assert quo(f, g, auto=False) == qz + assert quo(f, g, domain=ZZ) == qz + assert quo(f, g, domain=QQ) == qq + assert quo(f, g, domain=ZZ, auto=True) == qq + assert quo(f, g, domain=ZZ, auto=False) == qz + assert quo(f, g, domain=QQ, auto=True) == qq + assert quo(f, g, domain=QQ, auto=False) == qq + + f, g, q = x**2, 2*x, x/2 + + assert exquo(f, g) == q + assert exquo(f, g, auto=True) == q + raises(ExactQuotientFailed, lambda: exquo(f, g, auto=False)) + raises(ExactQuotientFailed, lambda: exquo(f, g, domain=ZZ)) + assert exquo(f, g, domain=QQ) == q + assert exquo(f, g, domain=ZZ, auto=True) == q + raises(ExactQuotientFailed, lambda: exquo(f, g, domain=ZZ, auto=False)) + assert exquo(f, g, domain=QQ, auto=True) == q + assert exquo(f, g, domain=QQ, auto=False) == q + + f, g = Poly(x**2), Poly(x) + + q, r = f.div(g) + assert q.get_domain().is_ZZ and r.get_domain().is_ZZ + r = f.rem(g) + assert r.get_domain().is_ZZ + q = f.quo(g) + assert q.get_domain().is_ZZ + q = f.exquo(g) + assert q.get_domain().is_ZZ + + f, g = Poly(x+y, x), Poly(2*x+y, x) + q, r = f.div(g) + assert q.get_domain().is_Frac and r.get_domain().is_Frac + + # https://github.com/sympy/sympy/issues/19579 + p = Poly(2+3*I, x, domain=ZZ_I) + q = Poly(1-I, x, domain=ZZ_I) + assert p.div(q, auto=False) == \ + (Poly(0, x, domain='ZZ_I'), Poly(2 + 3*I, x, domain='ZZ_I')) + assert p.div(q, auto=True) == \ + (Poly(-S(1)/2 + 5*I/2, x, domain='QQ_I'), Poly(0, x, domain='QQ_I')) + + f = 5*x**2 + 10*x + 3 + g = 2*x + 2 + assert div(f, g, domain=ZZ) == (0, f) + + +def test_issue_7864(): + q, r = div(a, .408248290463863*a) + assert abs(q - 2.44948974278318) < 1e-14 + assert r == 0 + + +def test_gcdex(): + f, g = 2*x, x**2 - 16 + s, t, h = x/32, Rational(-1, 16), 1 + + F, G, S, T, H = [ Poly(u, x, domain='QQ') for u in (f, g, s, t, h) ] + + assert F.half_gcdex(G) == (S, H) + assert F.gcdex(G) == (S, T, H) + assert F.invert(G) == S + + assert half_gcdex(f, g) == (s, h) + assert gcdex(f, g) == (s, t, h) + assert invert(f, g) == s + + assert half_gcdex(f, g, x) == (s, h) + assert gcdex(f, g, x) == (s, t, h) + assert invert(f, g, x) == s + + assert half_gcdex(f, g, (x,)) == (s, h) + assert gcdex(f, g, (x,)) == (s, t, h) + assert invert(f, g, (x,)) == s + + assert half_gcdex(F, G) == (S, H) + assert gcdex(F, G) == (S, T, H) + assert invert(F, G) == S + + assert half_gcdex(f, g, polys=True) == (S, H) + assert gcdex(f, g, polys=True) == (S, T, H) + assert invert(f, g, polys=True) == S + + assert half_gcdex(F, G, polys=False) == (s, h) + assert gcdex(F, G, polys=False) == (s, t, h) + assert invert(F, G, polys=False) == s + + assert half_gcdex(100, 2004) == (-20, 4) + assert gcdex(100, 2004) == (-20, 1, 4) + assert invert(3, 7) == 5 + + raises(DomainError, lambda: half_gcdex(x + 1, 2*x + 1, auto=False)) + raises(DomainError, lambda: gcdex(x + 1, 2*x + 1, auto=False)) + raises(DomainError, lambda: invert(x + 1, 2*x + 1, auto=False)) + + +def test_revert(): + f = Poly(1 - x**2/2 + x**4/24 - x**6/720) + g = Poly(61*x**6/720 + 5*x**4/24 + x**2/2 + 1) + + assert f.revert(8) == g + + +def test_subresultants(): + f, g, h = x**2 - 2*x + 1, x**2 - 1, 2*x - 2 + F, G, H = Poly(f), Poly(g), Poly(h) + + assert F.subresultants(G) == [F, G, H] + assert subresultants(f, g) == [f, g, h] + assert subresultants(f, g, x) == [f, g, h] + assert subresultants(f, g, (x,)) == [f, g, h] + assert subresultants(F, G) == [F, G, H] + assert subresultants(f, g, polys=True) == [F, G, H] + assert subresultants(F, G, polys=False) == [f, g, h] + + raises(ComputationFailed, lambda: subresultants(4, 2)) + + +def test_resultant(): + f, g, h = x**2 - 2*x + 1, x**2 - 1, 0 + F, G = Poly(f), Poly(g) + + assert F.resultant(G) == h + assert resultant(f, g) == h + assert resultant(f, g, x) == h + assert resultant(f, g, (x,)) == h + assert resultant(F, G) == h + assert resultant(f, g, polys=True) == h + assert resultant(F, G, polys=False) == h + assert resultant(f, g, includePRS=True) == (h, [f, g, 2*x - 2]) + + f, g, h = x - a, x - b, a - b + F, G, H = Poly(f), Poly(g), Poly(h) + + assert F.resultant(G) == H + assert resultant(f, g) == h + assert resultant(f, g, x) == h + assert resultant(f, g, (x,)) == h + assert resultant(F, G) == H + assert resultant(f, g, polys=True) == H + assert resultant(F, G, polys=False) == h + + raises(ComputationFailed, lambda: resultant(4, 2)) + + +def test_discriminant(): + f, g = x**3 + 3*x**2 + 9*x - 13, -11664 + F = Poly(f) + + assert F.discriminant() == g + assert discriminant(f) == g + assert discriminant(f, x) == g + assert discriminant(f, (x,)) == g + assert discriminant(F) == g + assert discriminant(f, polys=True) == g + assert discriminant(F, polys=False) == g + + f, g = a*x**2 + b*x + c, b**2 - 4*a*c + F, G = Poly(f), Poly(g) + + assert F.discriminant() == G + assert discriminant(f) == g + assert discriminant(f, x, a, b, c) == g + assert discriminant(f, (x, a, b, c)) == g + assert discriminant(F) == G + assert discriminant(f, polys=True) == G + assert discriminant(F, polys=False) == g + + raises(ComputationFailed, lambda: discriminant(4)) + + +def test_dispersion(): + # We test only the API here. For more mathematical + # tests see the dedicated test file. + fp = poly((x + 1)*(x + 2), x) + assert sorted(fp.dispersionset()) == [0, 1] + assert fp.dispersion() == 1 + + fp = poly(x**4 - 3*x**2 + 1, x) + gp = fp.shift(-3) + assert sorted(fp.dispersionset(gp)) == [2, 3, 4] + assert fp.dispersion(gp) == 4 + + +def test_gcd_list(): + F = [x**3 - 1, x**2 - 1, x**2 - 3*x + 2] + + assert gcd_list(F) == x - 1 + assert gcd_list(F, polys=True) == Poly(x - 1) + + assert gcd_list([]) == 0 + assert gcd_list([1, 2]) == 1 + assert gcd_list([4, 6, 8]) == 2 + + assert gcd_list([x*(y + 42) - x*y - x*42]) == 0 + + gcd = gcd_list([], x) + assert gcd.is_Number and gcd is S.Zero + + gcd = gcd_list([], x, polys=True) + assert gcd.is_Poly and gcd.is_zero + + a = sqrt(2) + assert gcd_list([a, -a]) == gcd_list([-a, a]) == a + + raises(ComputationFailed, lambda: gcd_list([], polys=True)) + + +def test_lcm_list(): + F = [x**3 - 1, x**2 - 1, x**2 - 3*x + 2] + + assert lcm_list(F) == x**5 - x**4 - 2*x**3 - x**2 + x + 2 + assert lcm_list(F, polys=True) == Poly(x**5 - x**4 - 2*x**3 - x**2 + x + 2) + + assert lcm_list([]) == 1 + assert lcm_list([1, 2]) == 2 + assert lcm_list([4, 6, 8]) == 24 + + assert lcm_list([x*(y + 42) - x*y - x*42]) == 0 + + lcm = lcm_list([], x) + assert lcm.is_Number and lcm is S.One + + lcm = lcm_list([], x, polys=True) + assert lcm.is_Poly and lcm.is_one + + raises(ComputationFailed, lambda: lcm_list([], polys=True)) + + +def test_gcd(): + f, g = x**3 - 1, x**2 - 1 + s, t = x**2 + x + 1, x + 1 + h, r = x - 1, x**4 + x**3 - x - 1 + + F, G, S, T, H, R = [ Poly(u) for u in (f, g, s, t, h, r) ] + + assert F.cofactors(G) == (H, S, T) + assert F.gcd(G) == H + assert F.lcm(G) == R + + assert cofactors(f, g) == (h, s, t) + assert gcd(f, g) == h + assert lcm(f, g) == r + + assert cofactors(f, g, x) == (h, s, t) + assert gcd(f, g, x) == h + assert lcm(f, g, x) == r + + assert cofactors(f, g, (x,)) == (h, s, t) + assert gcd(f, g, (x,)) == h + assert lcm(f, g, (x,)) == r + + assert cofactors(F, G) == (H, S, T) + assert gcd(F, G) == H + assert lcm(F, G) == R + + assert cofactors(f, g, polys=True) == (H, S, T) + assert gcd(f, g, polys=True) == H + assert lcm(f, g, polys=True) == R + + assert cofactors(F, G, polys=False) == (h, s, t) + assert gcd(F, G, polys=False) == h + assert lcm(F, G, polys=False) == r + + f, g = 1.0*x**2 - 1.0, 1.0*x - 1.0 + h, s, t = g, 1.0*x + 1.0, 1.0 + + assert cofactors(f, g) == (h, s, t) + assert gcd(f, g) == h + assert lcm(f, g) == f + + f, g = 1.0*x**2 - 1.0, 1.0*x - 1.0 + h, s, t = g, 1.0*x + 1.0, 1.0 + + assert cofactors(f, g) == (h, s, t) + assert gcd(f, g) == h + assert lcm(f, g) == f + + assert cofactors(8, 6) == (2, 4, 3) + assert gcd(8, 6) == 2 + assert lcm(8, 6) == 24 + + f, g = x**2 - 3*x - 4, x**3 - 4*x**2 + x - 4 + l = x**4 - 3*x**3 - 3*x**2 - 3*x - 4 + h, s, t = x - 4, x + 1, x**2 + 1 + + assert cofactors(f, g, modulus=11) == (h, s, t) + assert gcd(f, g, modulus=11) == h + assert lcm(f, g, modulus=11) == l + + f, g = x**2 + 8*x + 7, x**3 + 7*x**2 + x + 7 + l = x**4 + 8*x**3 + 8*x**2 + 8*x + 7 + h, s, t = x + 7, x + 1, x**2 + 1 + + assert cofactors(f, g, modulus=11, symmetric=False) == (h, s, t) + assert gcd(f, g, modulus=11, symmetric=False) == h + assert lcm(f, g, modulus=11, symmetric=False) == l + + a, b = sqrt(2), -sqrt(2) + assert gcd(a, b) == gcd(b, a) == sqrt(2) + + a, b = sqrt(-2), -sqrt(-2) + assert gcd(a, b) == gcd(b, a) == sqrt(2) + + assert gcd(Poly(x - 2, x), Poly(I*x, x)) == Poly(1, x, domain=ZZ_I) + + raises(TypeError, lambda: gcd(x)) + raises(TypeError, lambda: lcm(x)) + + f = Poly(-1, x) + g = Poly(1, x) + assert lcm(f, g) == Poly(1, x) + + f = Poly(0, x) + g = Poly([1, 1], x) + for i in (f, g): + assert lcm(i, 0) == 0 + assert lcm(0, i) == 0 + assert lcm(i, f) == 0 + assert lcm(f, i) == 0 + + f = 4*x**2 + x + 2 + pfz = Poly(f, domain=ZZ) + pfq = Poly(f, domain=QQ) + + assert pfz.gcd(pfz) == pfz + assert pfz.lcm(pfz) == pfz + assert pfq.gcd(pfq) == pfq.monic() + assert pfq.lcm(pfq) == pfq.monic() + assert gcd(f, f) == f + assert lcm(f, f) == f + assert gcd(f, f, domain=QQ) == monic(f) + assert lcm(f, f, domain=QQ) == monic(f) + + +def test_gcd_numbers_vs_polys(): + assert isinstance(gcd(3, 9), Integer) + assert isinstance(gcd(3*x, 9), Integer) + + assert gcd(3, 9) == 3 + assert gcd(3*x, 9) == 3 + + assert isinstance(gcd(Rational(3, 2), Rational(9, 4)), Rational) + assert isinstance(gcd(Rational(3, 2)*x, Rational(9, 4)), Rational) + + assert gcd(Rational(3, 2), Rational(9, 4)) == Rational(3, 4) + assert gcd(Rational(3, 2)*x, Rational(9, 4)) == 1 + + assert isinstance(gcd(3.0, 9.0), Float) + assert isinstance(gcd(3.0*x, 9.0), Float) + + assert gcd(3.0, 9.0) == 1.0 + assert gcd(3.0*x, 9.0) == 1.0 + + # partial fix of 20597 + assert gcd(Mul(2, 3, evaluate=False), 2) == 2 + + +def test_terms_gcd(): + assert terms_gcd(1) == 1 + assert terms_gcd(1, x) == 1 + + assert terms_gcd(x - 1) == x - 1 + assert terms_gcd(-x - 1) == -x - 1 + + assert terms_gcd(2*x + 3) == 2*x + 3 + assert terms_gcd(6*x + 4) == Mul(2, 3*x + 2, evaluate=False) + + assert terms_gcd(x**3*y + x*y**3) == x*y*(x**2 + y**2) + assert terms_gcd(2*x**3*y + 2*x*y**3) == 2*x*y*(x**2 + y**2) + assert terms_gcd(x**3*y/2 + x*y**3/2) == x*y/2*(x**2 + y**2) + + assert terms_gcd(x**3*y + 2*x*y**3) == x*y*(x**2 + 2*y**2) + assert terms_gcd(2*x**3*y + 4*x*y**3) == 2*x*y*(x**2 + 2*y**2) + assert terms_gcd(2*x**3*y/3 + 4*x*y**3/5) == x*y*Rational(2, 15)*(5*x**2 + 6*y**2) + + assert terms_gcd(2.0*x**3*y + 4.1*x*y**3) == x*y*(2.0*x**2 + 4.1*y**2) + assert _aresame(terms_gcd(2.0*x + 3), 2.0*x + 3) + + assert terms_gcd((3 + 3*x)*(x + x*y), expand=False) == \ + (3*x + 3)*(x*y + x) + assert terms_gcd((3 + 3*x)*(x + x*sin(3 + 3*y)), expand=False, deep=True) == \ + 3*x*(x + 1)*(sin(Mul(3, y + 1, evaluate=False)) + 1) + assert terms_gcd(sin(x + x*y), deep=True) == \ + sin(x*(y + 1)) + + eq = Eq(2*x, 2*y + 2*z*y) + assert terms_gcd(eq) == Eq(2*x, 2*y*(z + 1)) + assert terms_gcd(eq, deep=True) == Eq(2*x, 2*y*(z + 1)) + + raises(TypeError, lambda: terms_gcd(x < 2)) + + +def test_trunc(): + f, g = x**5 + 2*x**4 + 3*x**3 + 4*x**2 + 5*x + 6, x**5 - x**4 + x**2 - x + F, G = Poly(f), Poly(g) + + assert F.trunc(3) == G + assert trunc(f, 3) == g + assert trunc(f, 3, x) == g + assert trunc(f, 3, (x,)) == g + assert trunc(F, 3) == G + assert trunc(f, 3, polys=True) == G + assert trunc(F, 3, polys=False) == g + + f, g = 6*x**5 + 5*x**4 + 4*x**3 + 3*x**2 + 2*x + 1, -x**4 + x**3 - x + 1 + F, G = Poly(f), Poly(g) + + assert F.trunc(3) == G + assert trunc(f, 3) == g + assert trunc(f, 3, x) == g + assert trunc(f, 3, (x,)) == g + assert trunc(F, 3) == G + assert trunc(f, 3, polys=True) == G + assert trunc(F, 3, polys=False) == g + + f = Poly(x**2 + 2*x + 3, modulus=5) + + assert f.trunc(2) == Poly(x**2 + 1, modulus=5) + + +def test_monic(): + f, g = 2*x - 1, x - S.Half + F, G = Poly(f, domain='QQ'), Poly(g) + + assert F.monic() == G + assert monic(f) == g + assert monic(f, x) == g + assert monic(f, (x,)) == g + assert monic(F) == G + assert monic(f, polys=True) == G + assert monic(F, polys=False) == g + + raises(ComputationFailed, lambda: monic(4)) + + assert monic(2*x**2 + 6*x + 4, auto=False) == x**2 + 3*x + 2 + raises(ExactQuotientFailed, lambda: monic(2*x + 6*x + 1, auto=False)) + + assert monic(2.0*x**2 + 6.0*x + 4.0) == 1.0*x**2 + 3.0*x + 2.0 + assert monic(2*x**2 + 3*x + 4, modulus=5) == x**2 - x + 2 + + +def test_content(): + f, F = 4*x + 2, Poly(4*x + 2) + + assert F.content() == 2 + assert content(f) == 2 + + raises(ComputationFailed, lambda: content(4)) + + f = Poly(2*x, modulus=3) + + assert f.content() == 1 + + +def test_primitive(): + f, g = 4*x + 2, 2*x + 1 + F, G = Poly(f), Poly(g) + + assert F.primitive() == (2, G) + assert primitive(f) == (2, g) + assert primitive(f, x) == (2, g) + assert primitive(f, (x,)) == (2, g) + assert primitive(F) == (2, G) + assert primitive(f, polys=True) == (2, G) + assert primitive(F, polys=False) == (2, g) + + raises(ComputationFailed, lambda: primitive(4)) + + f = Poly(2*x, modulus=3) + g = Poly(2.0*x, domain=RR) + + assert f.primitive() == (1, f) + assert g.primitive() == (1.0, g) + + assert primitive(S('-3*x/4 + y + 11/8')) == \ + S('(1/8, -6*x + 8*y + 11)') + + +def test_compose(): + f = x**12 + 20*x**10 + 150*x**8 + 500*x**6 + 625*x**4 - 2*x**3 - 10*x + 9 + g = x**4 - 2*x + 9 + h = x**3 + 5*x + + F, G, H = map(Poly, (f, g, h)) + + assert G.compose(H) == F + assert compose(g, h) == f + assert compose(g, h, x) == f + assert compose(g, h, (x,)) == f + assert compose(G, H) == F + assert compose(g, h, polys=True) == F + assert compose(G, H, polys=False) == f + + assert F.decompose() == [G, H] + assert decompose(f) == [g, h] + assert decompose(f, x) == [g, h] + assert decompose(f, (x,)) == [g, h] + assert decompose(F) == [G, H] + assert decompose(f, polys=True) == [G, H] + assert decompose(F, polys=False) == [g, h] + + raises(ComputationFailed, lambda: compose(4, 2)) + raises(ComputationFailed, lambda: decompose(4)) + + assert compose(x**2 - y**2, x - y, x, y) == x**2 - 2*x*y + assert compose(x**2 - y**2, x - y, y, x) == -y**2 + 2*x*y + + +def test_shift(): + assert Poly(x**2 - 2*x + 1, x).shift(2) == Poly(x**2 + 2*x + 1, x) + + +def test_shift_list(): + assert Poly(x*y, [x,y]).shift_list([1,2]) == Poly((x+1)*(y+2), [x,y]) + + +def test_transform(): + # Also test that 3-way unification is done correctly + assert Poly(x**2 - 2*x + 1, x).transform(Poly(x + 1), Poly(x - 1)) == \ + Poly(4, x) == \ + cancel((x - 1)**2*(x**2 - 2*x + 1).subs(x, (x + 1)/(x - 1))) + + assert Poly(x**2 - x/2 + 1, x).transform(Poly(x + 1), Poly(x - 1)) == \ + Poly(3*x**2/2 + Rational(5, 2), x) == \ + cancel((x - 1)**2*(x**2 - x/2 + 1).subs(x, (x + 1)/(x - 1))) + + assert Poly(x**2 - 2*x + 1, x).transform(Poly(x + S.Half), Poly(x - 1)) == \ + Poly(Rational(9, 4), x) == \ + cancel((x - 1)**2*(x**2 - 2*x + 1).subs(x, (x + S.Half)/(x - 1))) + + assert Poly(x**2 - 2*x + 1, x).transform(Poly(x + 1), Poly(x - S.Half)) == \ + Poly(Rational(9, 4), x) == \ + cancel((x - S.Half)**2*(x**2 - 2*x + 1).subs(x, (x + 1)/(x - S.Half))) + + # Unify ZZ, QQ, and RR + assert Poly(x**2 - 2*x + 1, x).transform(Poly(x + 1.0), Poly(x - S.Half)) == \ + Poly(Rational(9, 4), x, domain='RR') == \ + cancel((x - S.Half)**2*(x**2 - 2*x + 1).subs(x, (x + 1.0)/(x - S.Half))) + + raises(ValueError, lambda: Poly(x*y).transform(Poly(x + 1), Poly(x - 1))) + raises(ValueError, lambda: Poly(x).transform(Poly(y + 1), Poly(x - 1))) + raises(ValueError, lambda: Poly(x).transform(Poly(x + 1), Poly(y - 1))) + raises(ValueError, lambda: Poly(x).transform(Poly(x*y + 1), Poly(x - 1))) + raises(ValueError, lambda: Poly(x).transform(Poly(x + 1), Poly(x*y - 1))) + + +def test_sturm(): + f, F = x, Poly(x, domain='QQ') + g, G = 1, Poly(1, x, domain='QQ') + + assert F.sturm() == [F, G] + assert sturm(f) == [f, g] + assert sturm(f, x) == [f, g] + assert sturm(f, (x,)) == [f, g] + assert sturm(F) == [F, G] + assert sturm(f, polys=True) == [F, G] + assert sturm(F, polys=False) == [f, g] + + raises(ComputationFailed, lambda: sturm(4)) + raises(DomainError, lambda: sturm(f, auto=False)) + + f = Poly(S(1024)/(15625*pi**8)*x**5 + - S(4096)/(625*pi**8)*x**4 + + S(32)/(15625*pi**4)*x**3 + - S(128)/(625*pi**4)*x**2 + + Rational(1, 62500)*x + - Rational(1, 625), x, domain='ZZ(pi)') + + assert sturm(f) == \ + [Poly(x**3 - 100*x**2 + pi**4/64*x - 25*pi**4/16, x, domain='ZZ(pi)'), + Poly(3*x**2 - 200*x + pi**4/64, x, domain='ZZ(pi)'), + Poly((Rational(20000, 9) - pi**4/96)*x + 25*pi**4/18, x, domain='ZZ(pi)'), + Poly((-3686400000000*pi**4 - 11520000*pi**8 - 9*pi**12)/(26214400000000 - 245760000*pi**4 + 576*pi**8), x, domain='ZZ(pi)')] + + +def test_gff(): + f = x**5 + 2*x**4 - x**3 - 2*x**2 + + assert Poly(f).gff_list() == [(Poly(x), 1), (Poly(x + 2), 4)] + assert gff_list(f) == [(x, 1), (x + 2, 4)] + + raises(NotImplementedError, lambda: gff(f)) + + f = x*(x - 1)**3*(x - 2)**2*(x - 4)**2*(x - 5) + + assert Poly(f).gff_list() == [( + Poly(x**2 - 5*x + 4), 1), (Poly(x**2 - 5*x + 4), 2), (Poly(x), 3)] + assert gff_list(f) == [(x**2 - 5*x + 4, 1), (x**2 - 5*x + 4, 2), (x, 3)] + + raises(NotImplementedError, lambda: gff(f)) + + +def test_norm(): + a, b = sqrt(2), sqrt(3) + f = Poly(a*x + b*y, x, y, extension=(a, b)) + assert f.norm() == Poly(4*x**4 - 12*x**2*y**2 + 9*y**4, x, y, domain='QQ') + + +def test_sqf_norm(): + assert sqf_norm(x**2 - 2, extension=sqrt(3)) == \ + ([1], x**2 - 2*sqrt(3)*x + 1, x**4 - 10*x**2 + 1) + assert sqf_norm(x**2 - 3, extension=sqrt(2)) == \ + ([1], x**2 - 2*sqrt(2)*x - 1, x**4 - 10*x**2 + 1) + + assert Poly(x**2 - 2, extension=sqrt(3)).sqf_norm() == \ + ([1], Poly(x**2 - 2*sqrt(3)*x + 1, x, extension=sqrt(3)), + Poly(x**4 - 10*x**2 + 1, x, domain='QQ')) + + assert Poly(x**2 - 3, extension=sqrt(2)).sqf_norm() == \ + ([1], Poly(x**2 - 2*sqrt(2)*x - 1, x, extension=sqrt(2)), + Poly(x**4 - 10*x**2 + 1, x, domain='QQ')) + + +def test_sqf(): + f = x**5 - x**3 - x**2 + 1 + g = x**3 + 2*x**2 + 2*x + 1 + h = x - 1 + + p = x**4 + x**3 - x - 1 + + F, G, H, P = map(Poly, (f, g, h, p)) + + assert F.sqf_part() == P + assert sqf_part(f) == p + assert sqf_part(f, x) == p + assert sqf_part(f, (x,)) == p + assert sqf_part(F) == P + assert sqf_part(f, polys=True) == P + assert sqf_part(F, polys=False) == p + + assert F.sqf_list() == (1, [(G, 1), (H, 2)]) + assert sqf_list(f) == (1, [(g, 1), (h, 2)]) + assert sqf_list(f, x) == (1, [(g, 1), (h, 2)]) + assert sqf_list(f, (x,)) == (1, [(g, 1), (h, 2)]) + assert sqf_list(F) == (1, [(G, 1), (H, 2)]) + assert sqf_list(f, polys=True) == (1, [(G, 1), (H, 2)]) + assert sqf_list(F, polys=False) == (1, [(g, 1), (h, 2)]) + + assert F.sqf_list_include() == [(G, 1), (H, 2)] + + raises(ComputationFailed, lambda: sqf_part(4)) + + assert sqf(1) == 1 + assert sqf_list(1) == (1, []) + + assert sqf((2*x**2 + 2)**7) == 128*(x**2 + 1)**7 + + assert sqf(f) == g*h**2 + assert sqf(f, x) == g*h**2 + assert sqf(f, (x,)) == g*h**2 + + d = x**2 + y**2 + + assert sqf(f/d) == (g*h**2)/d + assert sqf(f/d, x) == (g*h**2)/d + assert sqf(f/d, (x,)) == (g*h**2)/d + + assert sqf(x - 1) == x - 1 + assert sqf(-x - 1) == -x - 1 + + assert sqf(x - 1) == x - 1 + assert sqf(6*x - 10) == Mul(2, 3*x - 5, evaluate=False) + + assert sqf((6*x - 10)/(3*x - 6)) == Rational(2, 3)*((3*x - 5)/(x - 2)) + assert sqf(Poly(x**2 - 2*x + 1)) == (x - 1)**2 + + f = 3 + x - x*(1 + x) + x**2 + + assert sqf(f) == 3 + + f = (x**2 + 2*x + 1)**20000000000 + + assert sqf(f) == (x + 1)**40000000000 + assert sqf_list(f) == (1, [(x + 1, 40000000000)]) + + # https://github.com/sympy/sympy/issues/26497 + assert sqf(expand(((y - 2)**2 * (y + 2) * (x + 1)))) == \ + (y - 2)**2 * expand((y + 2) * (x + 1)) + assert sqf(expand(((y - 2)**2 * (y + 2) * (z + 1)))) == \ + (y - 2)**2 * expand((y + 2) * (z + 1)) + assert sqf(expand(((y - I)**2 * (y + I) * (x + 1)))) == \ + (y - I)**2 * expand((y + I) * (x + 1)) + assert sqf(expand(((y - I)**2 * (y + I) * (z + 1)))) == \ + (y - I)**2 * expand((y + I) * (z + 1)) + + # Check that factors are combined and sorted. + p = (x - 2)**2*(x - 1)*(x + y)**2*(y - 2)**2*(y - 1) + assert Poly(p).sqf_list() == (1, [ + (Poly(x*y - x - y + 1), 1), + (Poly(x**2*y - 2*x**2 + x*y**2 - 4*x*y + 4*x - 2*y**2 + 4*y), 2) + ]) + + +def test_factor(): + f = x**5 - x**3 - x**2 + 1 + + u = x + 1 + v = x - 1 + w = x**2 + x + 1 + + F, U, V, W = map(Poly, (f, u, v, w)) + + assert F.factor_list() == (1, [(U, 1), (V, 2), (W, 1)]) + assert factor_list(f) == (1, [(u, 1), (v, 2), (w, 1)]) + assert factor_list(f, x) == (1, [(u, 1), (v, 2), (w, 1)]) + assert factor_list(f, (x,)) == (1, [(u, 1), (v, 2), (w, 1)]) + assert factor_list(F) == (1, [(U, 1), (V, 2), (W, 1)]) + assert factor_list(f, polys=True) == (1, [(U, 1), (V, 2), (W, 1)]) + assert factor_list(F, polys=False) == (1, [(u, 1), (v, 2), (w, 1)]) + + assert F.factor_list_include() == [(U, 1), (V, 2), (W, 1)] + + assert factor_list(1) == (1, []) + assert factor_list(6) == (6, []) + assert factor_list(sqrt(3), x) == (sqrt(3), []) + assert factor_list((-1)**x, x) == (1, [(-1, x)]) + assert factor_list((2*x)**y, x) == (1, [(2, y), (x, y)]) + assert factor_list(sqrt(x*y), x) == (1, [(x*y, S.Half)]) + + assert factor(6) == 6 and factor(6).is_Integer + + assert factor_list(3*x) == (3, [(x, 1)]) + assert factor_list(3*x**2) == (3, [(x, 2)]) + + assert factor(3*x) == 3*x + assert factor(3*x**2) == 3*x**2 + + assert factor((2*x**2 + 2)**7) == 128*(x**2 + 1)**7 + + assert factor(f) == u*v**2*w + assert factor(f, x) == u*v**2*w + assert factor(f, (x,)) == u*v**2*w + + g, p, q, r = x**2 - y**2, x - y, x + y, x**2 + 1 + + assert factor(f/g) == (u*v**2*w)/(p*q) + assert factor(f/g, x) == (u*v**2*w)/(p*q) + assert factor(f/g, (x,)) == (u*v**2*w)/(p*q) + + p = Symbol('p', positive=True) + i = Symbol('i', integer=True) + r = Symbol('r', real=True) + + assert factor(sqrt(x*y)).is_Pow is True + + assert factor(sqrt(3*x**2 - 3)) == sqrt(3)*sqrt((x - 1)*(x + 1)) + assert factor(sqrt(3*x**2 + 3)) == sqrt(3)*sqrt(x**2 + 1) + + assert factor((y*x**2 - y)**i) == y**i*(x - 1)**i*(x + 1)**i + assert factor((y*x**2 + y)**i) == y**i*(x**2 + 1)**i + + assert factor((y*x**2 - y)**t) == (y*(x - 1)*(x + 1))**t + assert factor((y*x**2 + y)**t) == (y*(x**2 + 1))**t + + f = sqrt(expand((r**2 + 1)*(p + 1)*(p - 1)*(p - 2)**3)) + g = sqrt((p - 2)**3*(p - 1))*sqrt(p + 1)*sqrt(r**2 + 1) + + assert factor(f) == g + assert factor(g) == g + + g = (x - 1)**5*(r**2 + 1) + f = sqrt(expand(g)) + + assert factor(f) == sqrt(g) + + f = Poly(sin(1)*x + 1, x, domain=EX) + + assert f.factor_list() == (1, [(f, 1)]) + + f = x**4 + 1 + + assert factor(f) == f + assert factor(f, extension=I) == (x**2 - I)*(x**2 + I) + assert factor(f, gaussian=True) == (x**2 - I)*(x**2 + I) + assert factor( + f, extension=sqrt(2)) == (x**2 + sqrt(2)*x + 1)*(x**2 - sqrt(2)*x + 1) + + assert factor(x**2 + 4*I*x - 4) == (x + 2*I)**2 + + f = x**2 + 2*I*x - 4 + + assert factor(f) == f + + f = 8192*x**2 + x*(22656 + 175232*I) - 921416 + 242313*I + f_zzi = I*(x*(64 - 64*I) + 773 + 596*I)**2 + f_qqi = 8192*(x + S(177)/128 + 1369*I/128)**2 + + assert factor(f) == f_zzi + assert factor(f, domain=ZZ_I) == f_zzi + assert factor(f, domain=QQ_I) == f_qqi + + f = x**2 + 2*sqrt(2)*x + 2 + + assert factor(f, extension=sqrt(2)) == (x + sqrt(2))**2 + assert factor(f**3, extension=sqrt(2)) == (x + sqrt(2))**6 + + assert factor(x**2 - 2*y**2, extension=sqrt(2)) == \ + (x + sqrt(2)*y)*(x - sqrt(2)*y) + assert factor(2*x**2 - 4*y**2, extension=sqrt(2)) == \ + 2*((x + sqrt(2)*y)*(x - sqrt(2)*y)) + + assert factor(x - 1) == x - 1 + assert factor(-x - 1) == -x - 1 + + assert factor(x - 1) == x - 1 + + assert factor(6*x - 10) == Mul(2, 3*x - 5, evaluate=False) + + assert factor(x**11 + x + 1, modulus=65537, symmetric=True) == \ + (x**2 + x + 1)*(x**9 - x**8 + x**6 - x**5 + x**3 - x** 2 + 1) + assert factor(x**11 + x + 1, modulus=65537, symmetric=False) == \ + (x**2 + x + 1)*(x**9 + 65536*x**8 + x**6 + 65536*x**5 + + x**3 + 65536*x** 2 + 1) + + f = x/pi + x*sin(x)/pi + g = y/(pi**2 + 2*pi + 1) + y*sin(x)/(pi**2 + 2*pi + 1) + + assert factor(f) == x*(sin(x) + 1)/pi + assert factor(g) == y*(sin(x) + 1)/(pi + 1)**2 + + assert factor(Eq( + x**2 + 2*x + 1, x**3 + 1)) == Eq((x + 1)**2, (x + 1)*(x**2 - x + 1)) + + f = (x**2 - 1)/(x**2 + 4*x + 4) + + assert factor(f) == (x + 1)*(x - 1)/(x + 2)**2 + assert factor(f, x) == (x + 1)*(x - 1)/(x + 2)**2 + + f = 3 + x - x*(1 + x) + x**2 + + assert factor(f) == 3 + assert factor(f, x) == 3 + + assert factor(1/(x**2 + 2*x + 1/x) - 1) == -((1 - x + 2*x**2 + + x**3)/(1 + 2*x**2 + x**3)) + + assert factor(f, expand=False) == f + raises(PolynomialError, lambda: factor(f, x, expand=False)) + + raises(FlagError, lambda: factor(x**2 - 1, polys=True)) + + assert factor([x, Eq(x**2 - y**2, Tuple(x**2 - z**2, 1/x + 1/y))]) == \ + [x, Eq((x - y)*(x + y), Tuple((x - z)*(x + z), (x + y)/x/y))] + + assert not isinstance( + Poly(x**3 + x + 1).factor_list()[1][0][0], PurePoly) is True + assert isinstance( + PurePoly(x**3 + x + 1).factor_list()[1][0][0], PurePoly) is True + + assert factor(sqrt(-x)) == sqrt(-x) + + # issue 5917 + e = (-2*x*(-x + 1)*(x - 1)*(-x*(-x + 1)*(x - 1) - x*(x - 1)**2)*(x**2*(x - + 1) - x*(x - 1) - x) - (-2*x**2*(x - 1)**2 - x*(-x + 1)*(-x*(-x + 1) + + x*(x - 1)))*(x**2*(x - 1)**4 - x*(-x*(-x + 1)*(x - 1) - x*(x - 1)**2))) + assert factor(e) == 0 + + # deep option + assert factor(sin(x**2 + x) + x, deep=True) == sin(x*(x + 1)) + x + assert factor(sin(x**2 + x)*x, deep=True) == sin(x*(x + 1))*x + + assert factor(sqrt(x**2)) == sqrt(x**2) + + # issue 13149 + assert factor(expand((0.5*x+1)*(0.5*y+1))) == Mul(1.0, 0.5*x + 1.0, + 0.5*y + 1.0, evaluate = False) + assert factor(expand((0.5*x+0.5)**2)) == 0.25*(1.0*x + 1.0)**2 + + eq = x**2*y**2 + 11*x**2*y + 30*x**2 + 7*x*y**2 + 77*x*y + 210*x + 12*y**2 + 132*y + 360 + assert factor(eq, x) == (x + 3)*(x + 4)*(y**2 + 11*y + 30) + assert factor(eq, x, deep=True) == (x + 3)*(x + 4)*(y**2 + 11*y + 30) + assert factor(eq, y, deep=True) == (y + 5)*(y + 6)*(x**2 + 7*x + 12) + + # fraction option + f = 5*x + 3*exp(2 - 7*x) + assert factor(f, deep=True) == factor(f, deep=True, fraction=True) + assert factor(f, deep=True, fraction=False) == 5*x + 3*exp(2)*exp(-7*x) + + assert factor_list(x**3 - x*y**2, t, w, x) == ( + 1, [(x, 1), (x - y, 1), (x + y, 1)]) + assert factor_list((x+1)*(x**6-1)) == ( + 1, [(x - 1, 1), (x + 1, 2), (x**2 - x + 1, 1), (x**2 + x + 1, 1)]) + + # https://github.com/sympy/sympy/issues/24952 + s2, s2p, s2n = sqrt(2), 1 + sqrt(2), 1 - sqrt(2) + pip, pin = 1 + pi, 1 - pi + assert factor_list(s2p*s2n) == (-1, [(-s2n, 1), (s2p, 1)]) + assert factor_list(pip*pin) == (-1, [(-pin, 1), (pip, 1)]) + # Not sure about this one. Maybe coeff should be 1 or -1? + assert factor_list(s2*s2n) == (-s2, [(-s2n, 1)]) + assert factor_list(pi*pin) == (-1, [(-pin, 1), (pi, 1)]) + assert factor_list(s2p*s2n, x) == (s2p*s2n, []) + assert factor_list(pip*pin, x) == (pip*pin, []) + assert factor_list(s2*s2n, x) == (s2*s2n, []) + assert factor_list(pi*pin, x) == (pi*pin, []) + assert factor_list((x - sqrt(2)*pi)*(x + sqrt(2)*pi), x) == ( + 1, [(x - sqrt(2)*pi, 1), (x + sqrt(2)*pi, 1)]) + + # https://github.com/sympy/sympy/issues/26497 + p = ((y - I)**2 * (y + I) * (x + 1)) + assert factor(expand(p)) == p + + p = ((x - I)**2 * (x + I) * (y + 1)) + assert factor(expand(p)) == p + + p = (y + 1)**2*(y + sqrt(2))**2*(x**2 + x + 2 + 3*sqrt(2))**2 + assert factor(expand(p), extension=True) == p + + e = ( + -x**2*y**4/(y**2 + 1) + 2*I*x**2*y**3/(y**2 + 1) + 2*I*x**2*y/(y**2 + 1) + + x**2/(y**2 + 1) - 2*x*y**4/(y**2 + 1) + 4*I*x*y**3/(y**2 + 1) + + 4*I*x*y/(y**2 + 1) + 2*x/(y**2 + 1) - y**4 - y**4/(y**2 + 1) + 2*I*y**3 + + 2*I*y**3/(y**2 + 1) + 2*I*y + 2*I*y/(y**2 + 1) + 1 + 1/(y**2 + 1) + ) + assert factor(e) == -(y - I)**3*(y + I)*(x**2 + 2*x + y**2 + 2)/(y**2 + 1) + + # issue 27506 + e = (I*t*x*y - 3*I*t - I*x*y*z - 6*x*y + 3*I*z + 18) + assert factor(e) == -I*(x*y - 3)*(-t + z - 6*I) + + e = (8*x**2*z**2 - 32*x**2*z*t + 24*x**2*t**2 - 4*I*x*y*z**2 + 16*I*x*y*z*t - + 12*I*x*y*t**2 + z**4 - 8*z**3*t + 22*z**2*t**2 - 24*z*t**3 + 9*t**4) + assert factor(e) == (-3*t + z)*(-t + z)*(3*t**2 - 4*t*z + 8*x**2 - 4*I*x*y + z**2) + + +def test_factor_large(): + f = (x**2 + 4*x + 4)**10000000*(x**2 + 1)*(x**2 + 2*x + 1)**1234567 + g = ((x**2 + 2*x + 1)**3000*y**2 + (x**2 + 2*x + 1)**3000*2*y + ( + x**2 + 2*x + 1)**3000) + + assert factor(f) == (x + 2)**20000000*(x**2 + 1)*(x + 1)**2469134 + assert factor(g) == (x + 1)**6000*(y + 1)**2 + + assert factor_list( + f) == (1, [(x + 1, 2469134), (x + 2, 20000000), (x**2 + 1, 1)]) + assert factor_list(g) == (1, [(y + 1, 2), (x + 1, 6000)]) + + f = (x**2 - y**2)**200000*(x**7 + 1) + g = (x**2 + y**2)**200000*(x**7 + 1) + + assert factor(f) == \ + (x + 1)*(x - y)**200000*(x + y)**200000*(x**6 - x**5 + + x**4 - x**3 + x**2 - x + 1) + assert factor(g, gaussian=True) == \ + (x + 1)*(x - I*y)**200000*(x + I*y)**200000*(x**6 - x**5 + + x**4 - x**3 + x**2 - x + 1) + + assert factor_list(f) == \ + (1, [(x + 1, 1), (x - y, 200000), (x + y, 200000), (x**6 - + x**5 + x**4 - x**3 + x**2 - x + 1, 1)]) + assert factor_list(g, gaussian=True) == \ + (1, [(x + 1, 1), (x - I*y, 200000), (x + I*y, 200000), ( + x**6 - x**5 + x**4 - x**3 + x**2 - x + 1, 1)]) + + +def test_factor_noeval(): + assert factor(6*x - 10) == Mul(2, 3*x - 5, evaluate=False) + assert factor((6*x - 10)/(3*x - 6)) == Mul(Rational(2, 3), 3*x - 5, 1/(x - 2)) + + +def test_intervals(): + assert intervals(0) == [] + assert intervals(1) == [] + + assert intervals(x, sqf=True) == [(0, 0)] + assert intervals(x) == [((0, 0), 1)] + + assert intervals(x**128) == [((0, 0), 128)] + assert intervals([x**2, x**4]) == [((0, 0), {0: 2, 1: 4})] + + f = Poly((x*Rational(2, 5) - Rational(17, 3))*(4*x + Rational(1, 257))) + + assert f.intervals(sqf=True) == [(-1, 0), (14, 15)] + assert f.intervals() == [((-1, 0), 1), ((14, 15), 1)] + + assert f.intervals(fast=True, sqf=True) == [(-1, 0), (14, 15)] + assert f.intervals(fast=True) == [((-1, 0), 1), ((14, 15), 1)] + + assert f.intervals(eps=Rational(1, 10)) == f.intervals(eps=0.1) == \ + [((Rational(-1, 258), 0), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + assert f.intervals(eps=Rational(1, 100)) == f.intervals(eps=0.01) == \ + [((Rational(-1, 258), 0), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + assert f.intervals(eps=Rational(1, 1000)) == f.intervals(eps=0.001) == \ + [((Rational(-1, 1002), 0), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + assert f.intervals(eps=Rational(1, 10000)) == f.intervals(eps=0.0001) == \ + [((Rational(-1, 1028), Rational(-1, 1028)), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + + f = (x*Rational(2, 5) - Rational(17, 3))*(4*x + Rational(1, 257)) + + assert intervals(f, sqf=True) == [(-1, 0), (14, 15)] + assert intervals(f) == [((-1, 0), 1), ((14, 15), 1)] + + assert intervals(f, eps=Rational(1, 10)) == intervals(f, eps=0.1) == \ + [((Rational(-1, 258), 0), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + assert intervals(f, eps=Rational(1, 100)) == intervals(f, eps=0.01) == \ + [((Rational(-1, 258), 0), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + assert intervals(f, eps=Rational(1, 1000)) == intervals(f, eps=0.001) == \ + [((Rational(-1, 1002), 0), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + assert intervals(f, eps=Rational(1, 10000)) == intervals(f, eps=0.0001) == \ + [((Rational(-1, 1028), Rational(-1, 1028)), 1), ((Rational(85, 6), Rational(85, 6)), 1)] + + f = Poly((x**2 - 2)*(x**2 - 3)**7*(x + 1)*(7*x + 3)**3) + + assert f.intervals() == \ + [((-2, Rational(-3, 2)), 7), ((Rational(-3, 2), -1), 1), + ((-1, -1), 1), ((-1, 0), 3), + ((1, Rational(3, 2)), 1), ((Rational(3, 2), 2), 7)] + + assert intervals([x**5 - 200, x**5 - 201]) == \ + [((Rational(75, 26), Rational(101, 35)), {0: 1}), ((Rational(309, 107), Rational(26, 9)), {1: 1})] + + assert intervals([x**5 - 200, x**5 - 201], fast=True) == \ + [((Rational(75, 26), Rational(101, 35)), {0: 1}), ((Rational(309, 107), Rational(26, 9)), {1: 1})] + + assert intervals([x**2 - 200, x**2 - 201]) == \ + [((Rational(-71, 5), Rational(-85, 6)), {1: 1}), ((Rational(-85, 6), -14), {0: 1}), + ((14, Rational(85, 6)), {0: 1}), ((Rational(85, 6), Rational(71, 5)), {1: 1})] + + assert intervals([x + 1, x + 2, x - 1, x + 1, 1, x - 1, x - 1, (x - 2)**2]) == \ + [((-2, -2), {1: 1}), ((-1, -1), {0: 1, 3: 1}), ((1, 1), {2: + 1, 5: 1, 6: 1}), ((2, 2), {7: 2})] + + f, g, h = x**2 - 2, x**4 - 4*x**2 + 4, x - 1 + + assert intervals(f, inf=Rational(7, 4), sqf=True) == [] + assert intervals(f, inf=Rational(7, 5), sqf=True) == [(Rational(7, 5), Rational(3, 2))] + assert intervals(f, sup=Rational(7, 4), sqf=True) == [(-2, -1), (1, Rational(3, 2))] + assert intervals(f, sup=Rational(7, 5), sqf=True) == [(-2, -1)] + + assert intervals(g, inf=Rational(7, 4)) == [] + assert intervals(g, inf=Rational(7, 5)) == [((Rational(7, 5), Rational(3, 2)), 2)] + assert intervals(g, sup=Rational(7, 4)) == [((-2, -1), 2), ((1, Rational(3, 2)), 2)] + assert intervals(g, sup=Rational(7, 5)) == [((-2, -1), 2)] + + assert intervals([g, h], inf=Rational(7, 4)) == [] + assert intervals([g, h], inf=Rational(7, 5)) == [((Rational(7, 5), Rational(3, 2)), {0: 2})] + assert intervals([g, h], sup=S( + 7)/4) == [((-2, -1), {0: 2}), ((1, 1), {1: 1}), ((1, Rational(3, 2)), {0: 2})] + assert intervals( + [g, h], sup=Rational(7, 5)) == [((-2, -1), {0: 2}), ((1, 1), {1: 1})] + + assert intervals([x + 2, x**2 - 2]) == \ + [((-2, -2), {0: 1}), ((-2, -1), {1: 1}), ((1, 2), {1: 1})] + assert intervals([x + 2, x**2 - 2], strict=True) == \ + [((-2, -2), {0: 1}), ((Rational(-3, 2), -1), {1: 1}), ((1, 2), {1: 1})] + + f = 7*z**4 - 19*z**3 + 20*z**2 + 17*z + 20 + + assert intervals(f) == [] + + real_part, complex_part = intervals(f, all=True, sqf=True) + + assert real_part == [] + assert all(re(a) < re(r) < re(b) and im( + a) < im(r) < im(b) for (a, b), r in zip(complex_part, nroots(f))) + + assert complex_part == [(Rational(-40, 7) - I*40/7, 0), + (Rational(-40, 7), I*40/7), + (I*Rational(-40, 7), Rational(40, 7)), + (0, Rational(40, 7) + I*40/7)] + + real_part, complex_part = intervals(f, all=True, sqf=True, eps=Rational(1, 10)) + + assert real_part == [] + assert all(re(a) < re(r) < re(b) and im( + a) < im(r) < im(b) for (a, b), r in zip(complex_part, nroots(f))) + + raises(ValueError, lambda: intervals(x**2 - 2, eps=10**-100000)) + raises(ValueError, lambda: Poly(x**2 - 2).intervals(eps=10**-100000)) + raises( + ValueError, lambda: intervals([x**2 - 2, x**2 - 3], eps=10**-100000)) + + +def test_refine_root(): + f = Poly(x**2 - 2) + + assert f.refine_root(1, 2, steps=0) == (1, 2) + assert f.refine_root(-2, -1, steps=0) == (-2, -1) + + assert f.refine_root(1, 2, steps=None) == (1, Rational(3, 2)) + assert f.refine_root(-2, -1, steps=None) == (Rational(-3, 2), -1) + + assert f.refine_root(1, 2, steps=1) == (1, Rational(3, 2)) + assert f.refine_root(-2, -1, steps=1) == (Rational(-3, 2), -1) + + assert f.refine_root(1, 2, steps=1, fast=True) == (1, Rational(3, 2)) + assert f.refine_root(-2, -1, steps=1, fast=True) == (Rational(-3, 2), -1) + + assert f.refine_root(1, 2, eps=Rational(1, 100)) == (Rational(24, 17), Rational(17, 12)) + assert f.refine_root(1, 2, eps=1e-2) == (Rational(24, 17), Rational(17, 12)) + + raises(PolynomialError, lambda: (f**2).refine_root(1, 2, check_sqf=True)) + + raises(RefinementFailed, lambda: (f**2).refine_root(1, 2)) + raises(RefinementFailed, lambda: (f**2).refine_root(2, 3)) + + f = x**2 - 2 + + assert refine_root(f, 1, 2, steps=1) == (1, Rational(3, 2)) + assert refine_root(f, -2, -1, steps=1) == (Rational(-3, 2), -1) + + assert refine_root(f, 1, 2, steps=1, fast=True) == (1, Rational(3, 2)) + assert refine_root(f, -2, -1, steps=1, fast=True) == (Rational(-3, 2), -1) + + assert refine_root(f, 1, 2, eps=Rational(1, 100)) == (Rational(24, 17), Rational(17, 12)) + assert refine_root(f, 1, 2, eps=1e-2) == (Rational(24, 17), Rational(17, 12)) + + raises(PolynomialError, lambda: refine_root(1, 7, 8, eps=Rational(1, 100))) + + raises(ValueError, lambda: Poly(f).refine_root(1, 2, eps=10**-100000)) + raises(ValueError, lambda: refine_root(f, 1, 2, eps=10**-100000)) + + +def test_count_roots(): + assert count_roots(x**2 - 2) == 2 + + assert count_roots(x**2 - 2, inf=-oo) == 2 + assert count_roots(x**2 - 2, sup=+oo) == 2 + assert count_roots(x**2 - 2, inf=-oo, sup=+oo) == 2 + + assert count_roots(x**2 - 2, inf=-2) == 2 + assert count_roots(x**2 - 2, inf=-1) == 1 + + assert count_roots(x**2 - 2, sup=1) == 1 + assert count_roots(x**2 - 2, sup=2) == 2 + + assert count_roots(x**2 - 2, inf=-1, sup=1) == 0 + assert count_roots(x**2 - 2, inf=-2, sup=2) == 2 + + assert count_roots(x**2 - 2, inf=-1, sup=1) == 0 + assert count_roots(x**2 - 2, inf=-2, sup=2) == 2 + + assert count_roots(x**2 + 2) == 0 + assert count_roots(x**2 + 2, inf=-2*I) == 2 + assert count_roots(x**2 + 2, sup=+2*I) == 2 + assert count_roots(x**2 + 2, inf=-2*I, sup=+2*I) == 2 + + assert count_roots(x**2 + 2, inf=0) == 0 + assert count_roots(x**2 + 2, sup=0) == 0 + + assert count_roots(x**2 + 2, inf=-I) == 1 + assert count_roots(x**2 + 2, sup=+I) == 1 + + assert count_roots(x**2 + 2, inf=+I/2, sup=+I) == 0 + assert count_roots(x**2 + 2, inf=-I, sup=-I/2) == 0 + + raises(PolynomialError, lambda: count_roots(1)) + + +def test_count_roots_extension(): + + p1 = Poly(sqrt(2)*x**2 - 2, x, extension=True) + assert p1.count_roots() == 2 + assert p1.count_roots(inf=0) == 1 + assert p1.count_roots(sup=0) == 1 + + p2 = Poly(x**2 + sqrt(2), x, extension=True) + assert p2.count_roots() == 0 + + p3 = Poly(x**2 + 2*sqrt(2)*x + 1, x, extension=True) + assert p3.count_roots() == 2 + assert p3.count_roots(inf=-10, sup=10) == 2 + assert p3.count_roots(inf=-10, sup=0) == 2 + assert p3.count_roots(inf=-10, sup=-3) == 0 + assert p3.count_roots(inf=-3, sup=-2) == 1 + assert p3.count_roots(inf=-1, sup=0) == 1 + + +def test_Poly_root(): + f = Poly(2*x**3 - 7*x**2 + 4*x + 4) + + assert f.root(0) == Rational(-1, 2) + assert f.root(1) == 2 + assert f.root(2) == 2 + raises(IndexError, lambda: f.root(3)) + + assert Poly(x**5 + x + 1).root(0) == rootof(x**3 - x**2 + 1, 0) + + +def test_real_roots(): + + assert real_roots(x) == [0] + assert real_roots(x, multiple=False) == [(0, 1)] + + assert real_roots(x**3) == [0, 0, 0] + assert real_roots(x**3, multiple=False) == [(0, 3)] + + assert real_roots(x*(x**3 + x + 3)) == [rootof(x**3 + x + 3, 0), 0] + assert real_roots(x*(x**3 + x + 3), multiple=False) == [(rootof( + x**3 + x + 3, 0), 1), (0, 1)] + + assert real_roots( + x**3*(x**3 + x + 3)) == [rootof(x**3 + x + 3, 0), 0, 0, 0] + assert real_roots(x**3*(x**3 + x + 3), multiple=False) == [(rootof( + x**3 + x + 3, 0), 1), (0, 3)] + + assert real_roots(x**2 - 2, radicals=False) == [ + rootof(x**2 - 2, 0, radicals=False), + rootof(x**2 - 2, 1, radicals=False), + ] + + f = 2*x**3 - 7*x**2 + 4*x + 4 + g = x**3 + x + 1 + + assert Poly(f).real_roots() == [Rational(-1, 2), 2, 2] + assert Poly(g).real_roots() == [rootof(g, 0)] + + # testing extension + f = x**2 - sqrt(2) + roots = [-2**(S(1)/4), 2**(S(1)/4)] + raises(NotImplementedError, lambda: real_roots(f)) + raises(NotImplementedError, lambda: real_roots(Poly(f, x))) + assert real_roots(f, extension=True) == roots + assert real_roots(Poly(f, extension=True)) == roots + assert real_roots(Poly(f), extension=True) == roots + + +def test_all_roots(): + + f = 2*x**3 - 7*x**2 + 4*x + 4 + froots = [Rational(-1, 2), 2, 2] + assert all_roots(f) == Poly(f).all_roots() == froots + + g = x**3 + x + 1 + groots = [rootof(g, 0), rootof(g, 1), rootof(g, 2)] + assert all_roots(g) == Poly(g).all_roots() == groots + + assert all_roots(x**2 - 2) == [-sqrt(2), sqrt(2)] + assert all_roots(x**2 - 2, multiple=False) == [(-sqrt(2), 1), (sqrt(2), 1)] + assert all_roots(x**2 - 2, radicals=False) == [ + rootof(x**2 - 2, 0, radicals=False), + rootof(x**2 - 2, 1, radicals=False), + ] + + p = x**5 - x - 1 + assert all_roots(p) == [ + rootof(p, 0), rootof(p, 1), rootof(p, 2), rootof(p, 3), rootof(p, 4) + ] + + # testing extension + f = x**2 + sqrt(2) + roots = [-2**(S(1)/4)*I, 2**(S(1)/4)*I] + raises(NotImplementedError, lambda: all_roots(f)) + raises(NotImplementedError, lambda : all_roots(Poly(f, x))) + assert all_roots(f, extension=True) == roots + assert all_roots(Poly(f, extension=True)) == roots + assert all_roots(Poly(f), extension=True) == roots + + +def test_nroots(): + assert Poly(0, x).nroots() == [] + assert Poly(1, x).nroots() == [] + + assert Poly(x**2 - 1, x).nroots() == [-1.0, 1.0] + assert Poly(x**2 + 1, x).nroots() == [-1.0*I, 1.0*I] + + roots = Poly(x**2 - 1, x).nroots() + assert roots == [-1.0, 1.0] + + roots = Poly(x**2 + 1, x).nroots() + assert roots == [-1.0*I, 1.0*I] + + roots = Poly(x**2/3 - Rational(1, 3), x).nroots() + assert roots == [-1.0, 1.0] + + roots = Poly(x**2/3 + Rational(1, 3), x).nroots() + assert roots == [-1.0*I, 1.0*I] + + assert Poly(x**2 + 2*I, x).nroots() == [-1.0 + 1.0*I, 1.0 - 1.0*I] + assert Poly( + x**2 + 2*I, x, extension=I).nroots() == [-1.0 + 1.0*I, 1.0 - 1.0*I] + + assert Poly(0.2*x + 0.1).nroots() == [-0.5] + + roots = nroots(x**5 + x + 1, n=5) + eps = Float("1e-5") + + assert re(roots[0]).epsilon_eq(-0.75487, eps) is S.true + assert im(roots[0]) == 0 + assert re(roots[1]) == Float(-0.5, 5) + assert im(roots[1]).epsilon_eq(-0.86602, eps) is S.true + assert re(roots[2]) == Float(-0.5, 5) + assert im(roots[2]).epsilon_eq(+0.86602, eps) is S.true + assert re(roots[3]).epsilon_eq(+0.87743, eps) is S.true + assert im(roots[3]).epsilon_eq(-0.74486, eps) is S.true + assert re(roots[4]).epsilon_eq(+0.87743, eps) is S.true + assert im(roots[4]).epsilon_eq(+0.74486, eps) is S.true + + eps = Float("1e-6") + + assert re(roots[0]).epsilon_eq(-0.75487, eps) is S.false + assert im(roots[0]) == 0 + assert re(roots[1]) == Float(-0.5, 5) + assert im(roots[1]).epsilon_eq(-0.86602, eps) is S.false + assert re(roots[2]) == Float(-0.5, 5) + assert im(roots[2]).epsilon_eq(+0.86602, eps) is S.false + assert re(roots[3]).epsilon_eq(+0.87743, eps) is S.false + assert im(roots[3]).epsilon_eq(-0.74486, eps) is S.false + assert re(roots[4]).epsilon_eq(+0.87743, eps) is S.false + assert im(roots[4]).epsilon_eq(+0.74486, eps) is S.false + + raises(DomainError, lambda: Poly(x + y, x).nroots()) + raises(MultivariatePolynomialError, lambda: Poly(x + y).nroots()) + + assert nroots(x**2 - 1) == [-1.0, 1.0] + + roots = nroots(x**2 - 1) + assert roots == [-1.0, 1.0] + + assert nroots(x + I) == [-1.0*I] + assert nroots(x + 2*I) == [-2.0*I] + + raises(PolynomialError, lambda: nroots(0)) + + # issue 8296 + f = Poly(x**4 - 1) + assert f.nroots(2) == [w.n(2) for w in f.all_roots()] + + assert str(Poly(x**16 + 32*x**14 + 508*x**12 + 5440*x**10 + + 39510*x**8 + 204320*x**6 + 755548*x**4 + 1434496*x**2 + + 877969).nroots(2)) == ('[-1.7 - 1.9*I, -1.7 + 1.9*I, -1.7 ' + '- 2.5*I, -1.7 + 2.5*I, -1.0*I, 1.0*I, -1.7*I, 1.7*I, -2.8*I, ' + '2.8*I, -3.4*I, 3.4*I, 1.7 - 1.9*I, 1.7 + 1.9*I, 1.7 - 2.5*I, ' + '1.7 + 2.5*I]') + assert str(Poly(1e-15*x**2 -1).nroots()) == ('[-31622776.6016838, 31622776.6016838]') + + # https://github.com/sympy/sympy/issues/23861 + + i = Float('3.000000000000000000000000000000000000000000000000001') + [r] = nroots(x + I*i, n=300) + assert abs(r + I*i) < 1e-300 + + +def test_ground_roots(): + f = x**6 - 4*x**4 + 4*x**3 - x**2 + + assert Poly(f).ground_roots() == {S.One: 2, S.Zero: 2} + assert ground_roots(f) == {S.One: 2, S.Zero: 2} + + +def test_nth_power_roots_poly(): + f = x**4 - x**2 + 1 + + f_2 = (x**2 - x + 1)**2 + f_3 = (x**2 + 1)**2 + f_4 = (x**2 + x + 1)**2 + f_12 = (x - 1)**4 + + assert nth_power_roots_poly(f, 1) == f + + raises(ValueError, lambda: nth_power_roots_poly(f, 0)) + raises(ValueError, lambda: nth_power_roots_poly(f, x)) + + assert factor(nth_power_roots_poly(f, 2)) == f_2 + assert factor(nth_power_roots_poly(f, 3)) == f_3 + assert factor(nth_power_roots_poly(f, 4)) == f_4 + assert factor(nth_power_roots_poly(f, 12)) == f_12 + + raises(MultivariatePolynomialError, lambda: nth_power_roots_poly( + x + y, 2, x, y)) + +def test_which_real_roots(): + f = Poly(x**4 - 1) + + assert f.which_real_roots([1, -1]) == [1, -1] + assert f.which_real_roots([1, -1, 2, 4]) == [1, -1] + assert f.which_real_roots([1, -1, -1, 1, 2, 5]) == [1, -1] + assert f.which_real_roots([10, 8, 7, -1, 1]) == [-1, 1] + + # no real roots + # (technically its still a superset) + f = Poly(x**2 + 1) + assert f.which_real_roots([5, 10]) == [] + + # not square free + f = Poly((x-1)**2) + assert f.which_real_roots([1, 1, -1, 2]) == [1] + + # candidates not superset + f = Poly(x**2 - 1) + assert f.which_real_roots([0, 2]) == [0, 2] + +def test_which_all_roots(): + f = Poly(x**4 - 1) + + assert f.which_all_roots([1, -1, I, -I]) == [1, -1, I, -I] + assert f.which_all_roots([I, I, -I, 1, -1]) == [I, -I, 1, -1] + + f = Poly(x**2 + 1) + assert f.which_all_roots([I, -I, I/2]) == [I, -I] + + # not square free + f = Poly((x-I)**2) + assert f.which_all_roots([I, I, 1, -1, 0]) == [I] + + # candidates not superset + f = Poly(x**2 + 1) + assert f.which_all_roots([I/2, -I/2]) == [I/2, -I/2] + +def test_same_root(): + f = Poly(x**4 + x**3 + x**2 + x + 1) + eq = f.same_root + r0 = exp(2 * I * pi / 5) + assert [i for i, r in enumerate(f.all_roots()) if eq(r, r0)] == [3] + + raises(PolynomialError, + lambda: Poly(x + 1, domain=QQ).same_root(0, 0)) + raises(DomainError, + lambda: Poly(x**2 + 1, domain=FF(7)).same_root(0, 0)) + raises(DomainError, + lambda: Poly(x ** 2 + 1, domain=ZZ_I).same_root(0, 0)) + raises(DomainError, + lambda: Poly(y * x**2 + 1, domain=ZZ[y]).same_root(0, 0)) + raises(MultivariatePolynomialError, + lambda: Poly(x * y + 1, domain=ZZ).same_root(0, 0)) + + +def test_torational_factor_list(): + p = expand(((x**2-1)*(x-2)).subs({x:x*(1 + sqrt(2))})) + assert _torational_factor_list(p, x) == (-2, [ + (-x*(1 + sqrt(2))/2 + 1, 1), + (-x*(1 + sqrt(2)) - 1, 1), + (-x*(1 + sqrt(2)) + 1, 1)]) + + + p = expand(((x**2-1)*(x-2)).subs({x:x*(1 + 2**Rational(1, 4))})) + assert _torational_factor_list(p, x) is None + + +def test_cancel(): + assert cancel(0) == 0 + assert cancel(7) == 7 + assert cancel(x) == x + + assert cancel(oo) is oo + + raises(ValueError, lambda: cancel((1, 2, 3))) + + # test first tuple returnr + assert (t:=cancel((2, 3))) == (1, 2, 3) + assert isinstance(t, tuple) + + # tests 2nd tuple return + assert (t:=cancel((1, 0), x)) == (1, 1, 0) + assert isinstance(t, tuple) + assert cancel((0, 1), x) == (1, 0, 1) + + f, g, p, q = 4*x**2 - 4, 2*x - 2, 2*x + 2, 1 + F, G, P, Q = [ Poly(u, x) for u in (f, g, p, q) ] + + assert F.cancel(G) == (1, P, Q) + assert cancel((f, g)) == (1, p, q) + assert cancel((f, g), x) == (1, p, q) + assert cancel((f, g), (x,)) == (1, p, q) + # tests 3rd tuple return + assert (t:=cancel((F, G))) == (1, P, Q) + assert isinstance(t, tuple) + assert cancel((f, g), polys=True) == (1, P, Q) + assert cancel((F, G), polys=False) == (1, p, q) + + f = (x**2 - 2)/(x + sqrt(2)) + + assert cancel(f) == f + assert cancel(f, greedy=False) == x - sqrt(2) + + f = (x**2 - 2)/(x - sqrt(2)) + + assert cancel(f) == f + assert cancel(f, greedy=False) == x + sqrt(2) + + assert cancel((x**2/4 - 1, x/2 - 1)) == (1, x + 2, 2) + # assert cancel((x**2/4 - 1, x/2 - 1)) == (S.Half, x + 2, 1) + + assert cancel((x**2 - y)/(x - y)) == 1/(x - y)*(x**2 - y) + + assert cancel((x**2 - y**2)/(x - y), x) == x + y + assert cancel((x**2 - y**2)/(x - y), y) == x + y + assert cancel((x**2 - y**2)/(x - y)) == x + y + + assert cancel((x**3 - 1)/(x**2 - 1)) == (x**2 + x + 1)/(x + 1) + assert cancel((x**3/2 - S.Half)/(x**2 - 1)) == (x**2 + x + 1)/(2*x + 2) + + assert cancel((exp(2*x) + 2*exp(x) + 1)/(exp(x) + 1)) == exp(x) + 1 + + f = Poly(x**2 - a**2, x) + g = Poly(x - a, x) + + F = Poly(x + a, x, domain='ZZ[a]') + G = Poly(1, x, domain='ZZ[a]') + + assert cancel((f, g)) == (1, F, G) + + f = x**3 + (sqrt(2) - 2)*x**2 - (2*sqrt(2) + 3)*x - 3*sqrt(2) + g = x**2 - 2 + + assert cancel((f, g), extension=True) == (1, x**2 - 2*x - 3, x - sqrt(2)) + + f = Poly(-2*x + 3, x) + g = Poly(-x**9 + x**8 + x**6 - x**5 + 2*x**2 - 3*x + 1, x) + + assert cancel((f, g)) == (1, -f, -g) + + f = Poly(x/3 + 1, x) + g = Poly(x/7 + 1, x) + + assert f.cancel(g) == (S(7)/3, + Poly(x + 3, x, domain=QQ), + Poly(x + 7, x, domain=QQ)) + assert f.cancel(g, include=True) == ( + Poly(7*x + 21, x, domain=QQ), + Poly(3*x + 21, x, domain=QQ)) + + pairs = [ + (1 + x, 1 + x, 1, 1, 1), + (1 + x, 1 - x, -1, -1-x, -1+x), + (1 - x, 1 + x, -1, 1-x, 1+x), + (1 - x, 1 - x, 1, 1, 1), + ] + for f, g, coeff, p, q in pairs: + assert cancel((f, g)) == (1, p, q) + pf = Poly(f, x) + pg = Poly(g, x) + pp = Poly(p, x) + pq = Poly(q, x) + assert pf.cancel(pg) == (coeff, coeff*pp, pq) + assert pf.rep.cancel(pg.rep) == (pp.rep, pq.rep) + assert pf.rep.cancel(pg.rep, include=True) == (pp.rep, pq.rep) + + f = Poly(y, y, domain='ZZ(x)') + g = Poly(1, y, domain='ZZ[x]') + + assert f.cancel( + g) == (1, Poly(y, y, domain='ZZ(x)'), Poly(1, y, domain='ZZ(x)')) + assert f.cancel(g, include=True) == ( + Poly(y, y, domain='ZZ(x)'), Poly(1, y, domain='ZZ(x)')) + + f = Poly(5*x*y + x, y, domain='ZZ(x)') + g = Poly(2*x**2*y, y, domain='ZZ(x)') + + assert f.cancel(g, include=True) == ( + Poly(5*y + 1, y, domain='ZZ(x)'), Poly(2*x*y, y, domain='ZZ(x)')) + + f = -(-2*x - 4*y + 0.005*(z - y)**2)/((z - y)*(-z + y + 2)) + assert cancel(f).is_Mul == True + + P = tanh(x - 3.0) + Q = tanh(x + 3.0) + f = ((-2*P**2 + 2)*(-P**2 + 1)*Q**2/2 + (-2*P**2 + 2)*(-2*Q**2 + 2)*P*Q - (-2*P**2 + 2)*P**2*Q**2 + (-2*Q**2 + 2)*(-Q**2 + 1)*P**2/2 - (-2*Q**2 + 2)*P**2*Q**2)/(2*sqrt(P**2*Q**2 + 0.0001)) \ + + (-(-2*P**2 + 2)*P*Q**2/2 - (-2*Q**2 + 2)*P**2*Q/2)*((-2*P**2 + 2)*P*Q**2/2 + (-2*Q**2 + 2)*P**2*Q/2)/(2*(P**2*Q**2 + 0.0001)**Rational(3, 2)) + assert cancel(f).is_Mul == True + + # issue 7022 + A = Symbol('A', commutative=False) + p1 = Piecewise((A*(x**2 - 1)/(x + 1), x > 1), ((x + 2)/(x**2 + 2*x), True)) + p2 = Piecewise((A*(x - 1), x > 1), (1/x, True)) + assert cancel(p1) == p2 + assert cancel(2*p1) == 2*p2 + assert cancel(1 + p1) == 1 + p2 + assert cancel((x**2 - 1)/(x + 1)*p1) == (x - 1)*p2 + assert cancel((x**2 - 1)/(x + 1) + p1) == (x - 1) + p2 + p3 = Piecewise(((x**2 - 1)/(x + 1), x > 1), ((x + 2)/(x**2 + 2*x), True)) + p4 = Piecewise(((x - 1), x > 1), (1/x, True)) + assert cancel(p3) == p4 + assert cancel(2*p3) == 2*p4 + assert cancel(1 + p3) == 1 + p4 + assert cancel((x**2 - 1)/(x + 1)*p3) == (x - 1)*p4 + assert cancel((x**2 - 1)/(x + 1) + p3) == (x - 1) + p4 + + # issue 4077 + q = S('''(2*1*(x - 1/x)/(x*(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - + 1/x)) - 2/x)) - 2*1*((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - + 1/x)))*((-x + 1/x)*((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - + 1/x)))/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - 1/x)) - + 2/x) + 1)*((x - 1/x)/((x - 1/x)**2) - ((x - 1/x)/((x*(x - 1/x)**2)) - + 1/(x*(x - 1/x)))**2/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x + - 1/x)) - 2/x) - 1/(x - 1/x))*(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - + 1/(x**2*(x - 1/x)) - 2/x)/x - 1/x)*(((-x + 1/x)/((x*(x - 1/x)**2)) + + 1/(x*(x - 1/x)))*((-(x - 1/x)/(x*(x - 1/x)) - 1/x)*((x - 1/x)/((x*(x - + 1/x)**2)) - 1/(x*(x - 1/x)))/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - + 1/(x**2*(x - 1/x)) - 2/x) - 1 + (x - 1/x)/(x - 1/x))/((x*((x - + 1/x)/((x - 1/x)**2) - ((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - + 1/x)))**2/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - 1/x)) - + 2/x) - 1/(x - 1/x))*(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x + - 1/x)) - 2/x))) + ((x - 1/x)/((x*(x - 1/x))) + 1/x)/((x*(2*x - (-x + + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - 1/x)) - 2/x))) + 1/x)/(2*x + + 2*((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - 1/x)))*((-(x - 1/x)/(x*(x + - 1/x)) - 1/x)*((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - 1/x)))/(2*x - + (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - 1/x)) - 2/x) - 1 + (x - + 1/x)/(x - 1/x))/((x*((x - 1/x)/((x - 1/x)**2) - ((x - 1/x)/((x*(x - + 1/x)**2)) - 1/(x*(x - 1/x)))**2/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) + - 1/(x**2*(x - 1/x)) - 2/x) - 1/(x - 1/x))*(2*x - (-x + 1/x)/(x**2*(x + - 1/x)**2) - 1/(x**2*(x - 1/x)) - 2/x))) - 2*((x - 1/x)/((x*(x - + 1/x))) + 1/x)/(x*(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - + 1/x)) - 2/x)) - 2/x) - ((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - + 1/x)))*((-x + 1/x)*((x - 1/x)/((x*(x - 1/x)**2)) - 1/(x*(x - + 1/x)))/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - 1/x)) - + 2/x) + 1)/(x*((x - 1/x)/((x - 1/x)**2) - ((x - 1/x)/((x*(x - 1/x)**2)) + - 1/(x*(x - 1/x)))**2/(2*x - (-x + 1/x)/(x**2*(x - 1/x)**2) - + 1/(x**2*(x - 1/x)) - 2/x) - 1/(x - 1/x))*(2*x - (-x + 1/x)/(x**2*(x - + 1/x)**2) - 1/(x**2*(x - 1/x)) - 2/x)) + (x - 1/x)/((x*(2*x - (-x + + 1/x)/(x**2*(x - 1/x)**2) - 1/(x**2*(x - 1/x)) - 2/x))) - 1/x''', + evaluate=False) + assert cancel(q, _signsimp=False) is S.NaN + assert q.subs(x, 2) is S.NaN + assert signsimp(q) is S.NaN + + # issue 9363 + M = MatrixSymbol('M', 5, 5) + assert cancel(M[0,0] + 7) == M[0,0] + 7 + expr = sin(M[1, 4] + M[2, 1] * 5 * M[4, 0]) - 5 * M[1, 2] / z + assert cancel(expr) == (z*sin(M[1, 4] + M[2, 1] * 5 * M[4, 0]) - 5 * M[1, 2]) / z + + assert cancel((x**2 + 1)/(x - I)) == x + I + + +def test_cancel_modulus(): + assert cancel((x**2 - 1)/(x + 1), modulus=2) == x + 1 + assert Poly(x**2 - 1, modulus=2).cancel(Poly(x + 1, modulus=2)) ==\ + (1, Poly(x + 1, modulus=2), Poly(1, x, modulus=2)) + + +def test_make_monic_over_integers_by_scaling_roots(): + f = Poly(x**2 + 3*x + 4, x, domain='ZZ') + g, c = f.make_monic_over_integers_by_scaling_roots() + assert g == f + assert c == ZZ.one + + f = Poly(x**2 + 3*x + 4, x, domain='QQ') + g, c = f.make_monic_over_integers_by_scaling_roots() + assert g == f.to_ring() + assert c == ZZ.one + + f = Poly(x**2/2 + S(1)/4 * x + S(1)/8, x, domain='QQ') + g, c = f.make_monic_over_integers_by_scaling_roots() + assert g == Poly(x**2 + 2*x + 4, x, domain='ZZ') + assert c == 4 + + f = Poly(x**3/2 + S(1)/4 * x + S(1)/8, x, domain='QQ') + g, c = f.make_monic_over_integers_by_scaling_roots() + assert g == Poly(x**3 + 8*x + 16, x, domain='ZZ') + assert c == 4 + + f = Poly(x*y, x, y) + raises(ValueError, lambda: f.make_monic_over_integers_by_scaling_roots()) + + f = Poly(x, domain='RR') + raises(ValueError, lambda: f.make_monic_over_integers_by_scaling_roots()) + + +def test_galois_group(): + f = Poly(x ** 4 - 2) + G, alt = f.galois_group(by_name=True) + assert G == S4TransitiveSubgroups.D4 + assert alt is False + + +def test_reduced(): + f = 2*x**4 + y**2 - x**2 + y**3 + G = [x**3 - x, y**3 - y] + + Q = [2*x, 1] + r = x**2 + y**2 + y + + assert reduced(f, G) == (Q, r) + assert reduced(f, G, x, y) == (Q, r) + + H = groebner(G) + + assert H.reduce(f) == (Q, r) + + Q = [Poly(2*x, x, y), Poly(1, x, y)] + r = Poly(x**2 + y**2 + y, x, y) + + assert _strict_eq(reduced(f, G, polys=True), (Q, r)) + assert _strict_eq(reduced(f, G, x, y, polys=True), (Q, r)) + + H = groebner(G, polys=True) + + assert _strict_eq(H.reduce(f), (Q, r)) + + f = 2*x**3 + y**3 + 3*y + G = groebner([x**2 + y**2 - 1, x*y - 2]) + + Q = [x**2 - x*y**3/2 + x*y/2 + y**6/4 - y**4/2 + y**2/4, -y**5/4 + y**3/2 + y*Rational(3, 4)] + r = 0 + + assert reduced(f, G) == (Q, r) + assert G.reduce(f) == (Q, r) + + assert reduced(f, G, auto=False)[1] != 0 + assert G.reduce(f, auto=False)[1] != 0 + + assert G.contains(f) is True + assert G.contains(f + 1) is False + + assert reduced(1, [1], x) == ([1], 0) + raises(ComputationFailed, lambda: reduced(1, [1])) + + f_poly = Poly(2*x**3 + y**3 + 3*y) + G_poly = groebner([Poly(x**2 + y**2 - 1), Poly(x*y - 2)]) + + Q_poly = [Poly(x**2 - 1/2*x*y**3 + 1/2*x*y + 1/4*y**6 - 1/2*y**4 + 1/4*y**2, x, y, domain='QQ'), + Poly(-1/4*y**5 + 1/2*y**3 + 3/4*y, x, y, domain='QQ')] + r_poly = Poly(0, x, y, domain='QQ') + + assert G_poly.reduce(f_poly) == (Q_poly, r_poly) + + Q, r = G_poly.reduce(f) + assert all(isinstance(q, Poly) for q in Q) + assert isinstance(r, Poly) + + f_wrong_gens = Poly(2*x**3 + y**3 + 3*y, x, y, z) + raises(ValueError, lambda: G_poly.reduce(f_wrong_gens)) + + zero_poly = Poly(0, x, y) + Q, r = G_poly.reduce(zero_poly) + assert all(q.is_zero for q in Q) + assert r.is_zero + + const_poly = Poly(1, x, y) + Q, r = G_poly.reduce(const_poly) + assert isinstance(r, Poly) + assert r.as_expr() == 1 + assert all(q.is_zero for q in Q) + + +def test_groebner(): + assert groebner([], x, y, z) == [] + + assert groebner([x**2 + 1, y**4*x + x**3], x, y, order='lex') == [1 + x**2, -1 + y**4] + assert groebner([x**2 + 1, y**4*x + x**3, x*y*z**3], x, y, z, order='grevlex') == [-1 + y**4, z**3, 1 + x**2] + + assert groebner([x**2 + 1, y**4*x + x**3], x, y, order='lex', polys=True) == \ + [Poly(1 + x**2, x, y), Poly(-1 + y**4, x, y)] + assert groebner([x**2 + 1, y**4*x + x**3, x*y*z**3], x, y, z, order='grevlex', polys=True) == \ + [Poly(-1 + y**4, x, y, z), Poly(z**3, x, y, z), Poly(1 + x**2, x, y, z)] + + assert groebner([x**3 - 1, x**2 - 1]) == [x - 1] + assert groebner([Eq(x**3, 1), Eq(x**2, 1)]) == [x - 1] + + F = [3*x**2 + y*z - 5*x - 1, 2*x + 3*x*y + y**2, x - 3*y + x*z - 2*z**2] + f = z**9 - x**2*y**3 - 3*x*y**2*z + 11*y*z**2 + x**2*z**2 - 5 + + G = groebner(F, x, y, z, modulus=7, symmetric=False) + + assert G == [1 + x + y + 3*z + 2*z**2 + 2*z**3 + 6*z**4 + z**5, + 1 + 3*y + y**2 + 6*z**2 + 3*z**3 + 3*z**4 + 3*z**5 + 4*z**6, + 1 + 4*y + 4*z + y*z + 4*z**3 + z**4 + z**6, + 6 + 6*z + z**2 + 4*z**3 + 3*z**4 + 6*z**5 + 3*z**6 + z**7] + + Q, r = reduced(f, G, x, y, z, modulus=7, symmetric=False, polys=True) + + assert sum([ q*g for q, g in zip(Q, G.polys)], r) == Poly(f, modulus=7) + + F = [x*y - 2*y, 2*y**2 - x**2] + + assert groebner(F, x, y, order='grevlex') == \ + [y**3 - 2*y, x**2 - 2*y**2, x*y - 2*y] + assert groebner(F, y, x, order='grevlex') == \ + [x**3 - 2*x**2, -x**2 + 2*y**2, x*y - 2*y] + assert groebner(F, order='grevlex', field=True) == \ + [y**3 - 2*y, x**2 - 2*y**2, x*y - 2*y] + + assert groebner([1], x) == [1] + + assert groebner([x**2 + 2.0*y], x, y) == [1.0*x**2 + 2.0*y] + raises(ComputationFailed, lambda: groebner([1])) + + assert groebner([x**2 - 1, x**3 + 1], method='buchberger') == [x + 1] + assert groebner([x**2 - 1, x**3 + 1], method='f5b') == [x + 1] + + raises(ValueError, lambda: groebner([x, y], method='unknown')) + + +def test_fglm(): + F = [a + b + c + d, a*b + a*d + b*c + b*d, a*b*c + a*b*d + a*c*d + b*c*d, a*b*c*d - 1] + G = groebner(F, a, b, c, d, order=grlex) + + B = [ + 4*a + 3*d**9 - 4*d**5 - 3*d, + 4*b + 4*c - 3*d**9 + 4*d**5 + 7*d, + 4*c**2 + 3*d**10 - 4*d**6 - 3*d**2, + 4*c*d**4 + 4*c - d**9 + 4*d**5 + 5*d, + d**12 - d**8 - d**4 + 1, + ] + + assert groebner(F, a, b, c, d, order=lex) == B + assert G.fglm(lex) == B + + F = [9*x**8 + 36*x**7 - 32*x**6 - 252*x**5 - 78*x**4 + 468*x**3 + 288*x**2 - 108*x + 9, + -72*t*x**7 - 252*t*x**6 + 192*t*x**5 + 1260*t*x**4 + 312*t*x**3 - 404*t*x**2 - 576*t*x + \ + 108*t - 72*x**7 - 256*x**6 + 192*x**5 + 1280*x**4 + 312*x**3 - 576*x + 96] + G = groebner(F, t, x, order=grlex) + + B = [ + 203577793572507451707*t + 627982239411707112*x**7 - 666924143779443762*x**6 - \ + 10874593056632447619*x**5 + 5119998792707079562*x**4 + 72917161949456066376*x**3 + \ + 20362663855832380362*x**2 - 142079311455258371571*x + 183756699868981873194, + 9*x**8 + 36*x**7 - 32*x**6 - 252*x**5 - 78*x**4 + 468*x**3 + 288*x**2 - 108*x + 9, + ] + + assert groebner(F, t, x, order=lex) == B + assert G.fglm(lex) == B + + F = [x**2 - x - 3*y + 1, -2*x + y**2 + y - 1] + G = groebner(F, x, y, order=lex) + + B = [ + x**2 - x - 3*y + 1, + y**2 - 2*x + y - 1, + ] + + assert groebner(F, x, y, order=grlex) == B + assert G.fglm(grlex) == B + + +def test_is_zero_dimensional(): + assert is_zero_dimensional([x, y], x, y) is True + assert is_zero_dimensional([x**3 + y**2], x, y) is False + + assert is_zero_dimensional([x, y, z], x, y, z) is True + assert is_zero_dimensional([x, y, z], x, y, z, t) is False + + F = [x*y - z, y*z - x, x*y - y] + assert is_zero_dimensional(F, x, y, z) is True + + F = [x**2 - 2*x*z + 5, x*y**2 + y*z**3, 3*y**2 - 8*z**2] + assert is_zero_dimensional(F, x, y, z) is True + + +def test_GroebnerBasis(): + F = [x*y - 2*y, 2*y**2 - x**2] + + G = groebner(F, x, y, order='grevlex') + H = [y**3 - 2*y, x**2 - 2*y**2, x*y - 2*y] + P = [ Poly(h, x, y) for h in H ] + + assert groebner(F + [0], x, y, order='grevlex') == G + assert isinstance(G, GroebnerBasis) is True + + assert len(G) == 3 + + assert G[0] == H[0] and not G[0].is_Poly + assert G[1] == H[1] and not G[1].is_Poly + assert G[2] == H[2] and not G[2].is_Poly + + assert G[1:] == H[1:] and not any(g.is_Poly for g in G[1:]) + assert G[:2] == H[:2] and not any(g.is_Poly for g in G[1:]) + + assert G.exprs == H + assert G.polys == P + assert G.gens == (x, y) + assert G.domain == ZZ + assert G.order == grevlex + + assert G == H + assert G == tuple(H) + assert G == P + assert G == tuple(P) + + assert G != [] + + G = groebner(F, x, y, order='grevlex', polys=True) + + assert G[0] == P[0] and G[0].is_Poly + assert G[1] == P[1] and G[1].is_Poly + assert G[2] == P[2] and G[2].is_Poly + + assert G[1:] == P[1:] and all(g.is_Poly for g in G[1:]) + assert G[:2] == P[:2] and all(g.is_Poly for g in G[1:]) + + +def test_poly(): + assert poly(x) == Poly(x, x) + assert poly(y) == Poly(y, y) + + assert poly(x + y) == Poly(x + y, x, y) + assert poly(x + sin(x)) == Poly(x + sin(x), x, sin(x)) + + assert poly(x + y, wrt=y) == Poly(x + y, y, x) + assert poly(x + sin(x), wrt=sin(x)) == Poly(x + sin(x), sin(x), x) + + assert poly(x*y + 2*x*z**2 + 17) == Poly(x*y + 2*x*z**2 + 17, x, y, z) + + assert poly(2*(y + z)**2 - 1) == Poly(2*y**2 + 4*y*z + 2*z**2 - 1, y, z) + assert poly( + x*(y + z)**2 - 1) == Poly(x*y**2 + 2*x*y*z + x*z**2 - 1, x, y, z) + assert poly(2*x*( + y + z)**2 - 1) == Poly(2*x*y**2 + 4*x*y*z + 2*x*z**2 - 1, x, y, z) + + assert poly(2*( + y + z)**2 - x - 1) == Poly(2*y**2 + 4*y*z + 2*z**2 - x - 1, x, y, z) + assert poly(x*( + y + z)**2 - x - 1) == Poly(x*y**2 + 2*x*y*z + x*z**2 - x - 1, x, y, z) + assert poly(2*x*(y + z)**2 - x - 1) == Poly(2*x*y**2 + 4*x*y*z + 2* + x*z**2 - x - 1, x, y, z) + + assert poly(x*y + (x + y)**2 + (x + z)**2) == \ + Poly(2*x*z + 3*x*y + y**2 + z**2 + 2*x**2, x, y, z) + assert poly(x*y*(x + y)*(x + z)**2) == \ + Poly(x**3*y**2 + x*y**2*z**2 + y*x**2*z**2 + 2*z*x**2* + y**2 + 2*y*z*x**3 + y*x**4, x, y, z) + + assert poly(Poly(x + y + z, y, x, z)) == Poly(x + y + z, y, x, z) + + assert poly((x + y)**2, x) == Poly(x**2 + 2*x*y + y**2, x, domain=ZZ[y]) + assert poly((x + y)**2, y) == Poly(x**2 + 2*x*y + y**2, y, domain=ZZ[x]) + + assert poly(1, x) == Poly(1, x) + raises(GeneratorsNeeded, lambda: poly(1)) + + # issue 6184 + assert poly(x + y, x, y) == Poly(x + y, x, y) + assert poly(x + y, y, x) == Poly(x + y, y, x) + + # https://github.com/sympy/sympy/issues/19755 + expr1 = x + (2*x + 3)**2/5 + S(6)/5 + assert poly(expr1).as_expr() == expr1.expand() + expr2 = y*(y+1) + S(1)/3 + assert poly(expr2).as_expr() == expr2.expand() + + +def test_keep_coeff(): + u = Mul(2, x + 1, evaluate=False) + assert _keep_coeff(S.One, x) == x + assert _keep_coeff(S.NegativeOne, x) == -x + assert _keep_coeff(S(1.0), x) == 1.0*x + assert _keep_coeff(S(-1.0), x) == -1.0*x + assert _keep_coeff(S.One, 2*x) == 2*x + assert _keep_coeff(S(2), x/2) == x + assert _keep_coeff(S(2), sin(x)) == 2*sin(x) + assert _keep_coeff(S(2), x + 1) == u + assert _keep_coeff(x, 1/x) == 1 + assert _keep_coeff(x + 1, S(2)) == u + assert _keep_coeff(S.Half, S.One) == S.Half + p = Pow(2, 3, evaluate=False) + assert _keep_coeff(S(-1), p) == Mul(-1, p, evaluate=False) + a = Add(2, p, evaluate=False) + assert _keep_coeff(S.Half, a, clear=True + ) == Mul(S.Half, a, evaluate=False) + assert _keep_coeff(S.Half, a, clear=False + ) == Add(1, Mul(S.Half, p, evaluate=False), evaluate=False) + + +def test_poly_matching_consistency(): + # Test for this issue: + # https://github.com/sympy/sympy/issues/5514 + assert I * Poly(x, x) == Poly(I*x, x) + assert Poly(x, x) * I == Poly(I*x, x) + + +def test_issue_5786(): + assert expand(factor(expand( + (x - I*y)*(z - I*t)), extension=[I])) == -I*t*x - t*y + x*z - I*y*z + + +def test_noncommutative(): + class foo(Expr): + is_commutative=False + e = x/(x + x*y) + c = 1/( 1 + y) + assert cancel(foo(e)) == foo(c) + assert cancel(e + foo(e)) == c + foo(c) + assert cancel(e*foo(c)) == c*foo(c) + + +def test_to_rational_coeffs(): + assert to_rational_coeffs( + Poly(x**3 + y*x**2 + sqrt(y), x, domain='EX')) is None + # issue 21268 + assert to_rational_coeffs( + Poly(y**3 + sqrt(2)*y**2*sin(x) + 1, y)) is None + + assert to_rational_coeffs(Poly(x, y)) is None + assert to_rational_coeffs(Poly(sqrt(2)*y)) is None + + +def test_factor_terms(): + # issue 7067 + assert factor_list(x*(x + y)) == (1, [(x, 1), (x + y, 1)]) + assert sqf_list(x*(x + y)) == (1, [(x**2 + x*y, 1)]) + + +def test_as_list(): + # issue 14496 + assert Poly(x**3 + 2, x, domain='ZZ').as_list() == [1, 0, 0, 2] + assert Poly(x**2 + y + 1, x, y, domain='ZZ').as_list() == [[1], [], [1, 1]] + assert Poly(x**2 + y + 1, x, y, z, domain='ZZ').as_list() == \ + [[[1]], [[]], [[1], [1]]] + + +def test_issue_11198(): + assert factor_list(sqrt(2)*x) == (sqrt(2), [(x, 1)]) + assert factor_list(sqrt(2)*sin(x), sin(x)) == (sqrt(2), [(sin(x), 1)]) + + +def test_Poly_precision(): + # Make sure Poly doesn't lose precision + p = Poly(pi.evalf(100)*x) + assert p.as_expr() == pi.evalf(100)*x + + +def test_issue_12400(): + # Correction of check for negative exponents + assert poly(1/(1+sqrt(2)), x) == \ + Poly(1/(1+sqrt(2)), x, domain='EX') + +def test_issue_14364(): + assert gcd(S(6)*(1 + sqrt(3))/5, S(3)*(1 + sqrt(3))/10) == Rational(3, 10) * (1 + sqrt(3)) + assert gcd(sqrt(5)*Rational(4, 7), sqrt(5)*Rational(2, 3)) == sqrt(5)*Rational(2, 21) + + assert lcm(Rational(2, 3)*sqrt(3), Rational(5, 6)*sqrt(3)) == S(10)*sqrt(3)/3 + assert lcm(3*sqrt(3), 4/sqrt(3)) == 12*sqrt(3) + assert lcm(S(5)*(1 + 2**Rational(1, 3))/6, S(3)*(1 + 2**Rational(1, 3))/8) == Rational(15, 2) * (1 + 2**Rational(1, 3)) + + assert gcd(Rational(2, 3)*sqrt(3), Rational(5, 6)/sqrt(3)) == sqrt(3)/18 + assert gcd(S(4)*sqrt(13)/7, S(3)*sqrt(13)/14) == sqrt(13)/14 + + # gcd_list and lcm_list + assert gcd([S(2)*sqrt(47)/7, S(6)*sqrt(47)/5, S(8)*sqrt(47)/5]) == sqrt(47)*Rational(2, 35) + assert gcd([S(6)*(1 + sqrt(7))/5, S(2)*(1 + sqrt(7))/7, S(4)*(1 + sqrt(7))/13]) == (1 + sqrt(7))*Rational(2, 455) + assert lcm((Rational(7, 2)/sqrt(15), Rational(5, 6)/sqrt(15), Rational(5, 8)/sqrt(15))) == Rational(35, 2)/sqrt(15) + assert lcm([S(5)*(2 + 2**Rational(5, 7))/6, S(7)*(2 + 2**Rational(5, 7))/2, S(13)*(2 + 2**Rational(5, 7))/4]) == Rational(455, 2) * (2 + 2**Rational(5, 7)) + + +def test_issue_15669(): + x = Symbol("x", positive=True) + expr = (16*x**3/(-x**2 + sqrt(8*x**2 + (x**2 - 2)**2) + 2)**2 - + 2*2**Rational(4, 5)*x*(-x**2 + sqrt(8*x**2 + (x**2 - 2)**2) + 2)**Rational(3, 5) + 10*x) + assert factor(expr, deep=True) == x*(x**2 + 2) + + +def test_issue_17988(): + x = Symbol('x') + p = poly(x - 1) + with warns_deprecated_sympy(): + M = Matrix([[poly(x + 1), poly(x + 1)]]) + with warns(SymPyDeprecationWarning, test_stacklevel=False): + assert p * M == M * p == Matrix([[poly(x**2 - 1), poly(x**2 - 1)]]) + + +def test_issue_18205(): + assert cancel((2 + I)*(3 - I)) == 7 + I + assert cancel((2 + I)*(2 - I)) == 5 + + +def test_issue_8695(): + p = (x**2 + 1) * (x - 1)**2 * (x - 2)**3 * (x - 3)**3 + result = (1, [(x**2 + 1, 1), (x - 1, 2), (x**2 - 5*x + 6, 3)]) + assert sqf_list(p) == result + + +def test_issue_19113(): + eq = sin(x)**3 - sin(x) + 1 + raises(PolynomialError, lambda: refine_root(eq, 1, 2, 1e-2)) + raises(PolynomialError, lambda: count_roots(eq, -1, 1)) + raises(PolynomialError, lambda: real_roots(eq)) + raises(PolynomialError, lambda: nroots(eq)) + raises(PolynomialError, lambda: ground_roots(eq)) + raises(PolynomialError, lambda: nth_power_roots_poly(eq, 2)) + + +def test_issue_19360(): + f = 2*x**2 - 2*sqrt(2)*x*y + y**2 + assert factor(f, extension=sqrt(2)) == 2*(x - (sqrt(2)*y/2))**2 + + f = -I*t*x - t*y + x*z - I*y*z + assert factor(f, extension=I) == (x - I*y)*(-I*t + z) + + +def test_poly_copy_equals_original(): + poly = Poly(x + y, x, y, z) + copy = poly.copy() + assert poly == copy, ( + "Copied polynomial not equal to original.") + assert poly.gens == copy.gens, ( + "Copied polynomial has different generators than original.") + + +def test_deserialized_poly_equals_original(): + poly = Poly(x + y, x, y, z) + deserialized = pickle.loads(pickle.dumps(poly)) + assert poly == deserialized, ( + "Deserialized polynomial not equal to original.") + assert poly.gens == deserialized.gens, ( + "Deserialized polynomial has different generators than original.") + + +def test_issue_20389(): + result = degree(x * (x + 1) - x ** 2 - x, x) + assert result == -oo + + +def test_issue_20985(): + from sympy.core.symbol import symbols + w, R = symbols('w R') + poly = Poly(1.0 + I*w/R, w, 1/R) + assert poly.degree() == S(1) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_polyutils.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyutils.py new file mode 100644 index 0000000000000000000000000000000000000000..f39561a1c5035fed52add5e49476d0eea91bdae0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_polyutils.py @@ -0,0 +1,300 @@ +"""Tests for useful utilities for higher level polynomial classes. """ + +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.testing.pytest import raises + +from sympy.polys.polyutils import ( + _nsort, + _sort_gens, + _unify_gens, + _analyze_gens, + _sort_factors, + parallel_dict_from_expr, + dict_from_expr, +) + +from sympy.polys.polyerrors import PolynomialError + +from sympy.polys.domains import ZZ + +x, y, z, p, q, r, s, t, u, v, w = symbols('x,y,z,p,q,r,s,t,u,v,w') +A, B = symbols('A,B', commutative=False) + + +def test__nsort(): + # issue 6137 + r = S('''[3/2 + sqrt(-14/3 - 2*(-415/216 + 13*I/12)**(1/3) - 4/sqrt(-7/3 + + 61/(18*(-415/216 + 13*I/12)**(1/3)) + 2*(-415/216 + 13*I/12)**(1/3)) - + 61/(18*(-415/216 + 13*I/12)**(1/3)))/2 - sqrt(-7/3 + 61/(18*(-415/216 + + 13*I/12)**(1/3)) + 2*(-415/216 + 13*I/12)**(1/3))/2, 3/2 - sqrt(-7/3 + + 61/(18*(-415/216 + 13*I/12)**(1/3)) + 2*(-415/216 + + 13*I/12)**(1/3))/2 - sqrt(-14/3 - 2*(-415/216 + 13*I/12)**(1/3) - + 4/sqrt(-7/3 + 61/(18*(-415/216 + 13*I/12)**(1/3)) + 2*(-415/216 + + 13*I/12)**(1/3)) - 61/(18*(-415/216 + 13*I/12)**(1/3)))/2, 3/2 + + sqrt(-14/3 - 2*(-415/216 + 13*I/12)**(1/3) + 4/sqrt(-7/3 + + 61/(18*(-415/216 + 13*I/12)**(1/3)) + 2*(-415/216 + 13*I/12)**(1/3)) - + 61/(18*(-415/216 + 13*I/12)**(1/3)))/2 + sqrt(-7/3 + 61/(18*(-415/216 + + 13*I/12)**(1/3)) + 2*(-415/216 + 13*I/12)**(1/3))/2, 3/2 + sqrt(-7/3 + + 61/(18*(-415/216 + 13*I/12)**(1/3)) + 2*(-415/216 + + 13*I/12)**(1/3))/2 - sqrt(-14/3 - 2*(-415/216 + 13*I/12)**(1/3) + + 4/sqrt(-7/3 + 61/(18*(-415/216 + 13*I/12)**(1/3)) + 2*(-415/216 + + 13*I/12)**(1/3)) - 61/(18*(-415/216 + 13*I/12)**(1/3)))/2]''') + ans = [r[1], r[0], r[-1], r[-2]] + assert _nsort(r) == ans + assert len(_nsort(r, separated=True)[0]) == 0 + b, c, a = exp(-1000), exp(-999), exp(-1001) + assert _nsort((b, c, a)) == [a, b, c] + # issue 12560 + a = cos(1)**2 + sin(1)**2 - 1 + assert _nsort([a]) == [a] + + +def test__sort_gens(): + assert _sort_gens([]) == () + + assert _sort_gens([x]) == (x,) + assert _sort_gens([p]) == (p,) + assert _sort_gens([q]) == (q,) + + assert _sort_gens([x, p]) == (x, p) + assert _sort_gens([p, x]) == (x, p) + assert _sort_gens([q, p]) == (p, q) + + assert _sort_gens([q, p, x]) == (x, p, q) + + assert _sort_gens([x, p, q], wrt=x) == (x, p, q) + assert _sort_gens([x, p, q], wrt=p) == (p, x, q) + assert _sort_gens([x, p, q], wrt=q) == (q, x, p) + + assert _sort_gens([x, p, q], wrt='x') == (x, p, q) + assert _sort_gens([x, p, q], wrt='p') == (p, x, q) + assert _sort_gens([x, p, q], wrt='q') == (q, x, p) + + assert _sort_gens([x, p, q], wrt='x,q') == (x, q, p) + assert _sort_gens([x, p, q], wrt='q,x') == (q, x, p) + assert _sort_gens([x, p, q], wrt='p,q') == (p, q, x) + assert _sort_gens([x, p, q], wrt='q,p') == (q, p, x) + + assert _sort_gens([x, p, q], wrt='x, q') == (x, q, p) + assert _sort_gens([x, p, q], wrt='q, x') == (q, x, p) + assert _sort_gens([x, p, q], wrt='p, q') == (p, q, x) + assert _sort_gens([x, p, q], wrt='q, p') == (q, p, x) + + assert _sort_gens([x, p, q], wrt=[x, 'q']) == (x, q, p) + assert _sort_gens([x, p, q], wrt=[q, 'x']) == (q, x, p) + assert _sort_gens([x, p, q], wrt=[p, 'q']) == (p, q, x) + assert _sort_gens([x, p, q], wrt=[q, 'p']) == (q, p, x) + + assert _sort_gens([x, p, q], wrt=['x', 'q']) == (x, q, p) + assert _sort_gens([x, p, q], wrt=['q', 'x']) == (q, x, p) + assert _sort_gens([x, p, q], wrt=['p', 'q']) == (p, q, x) + assert _sort_gens([x, p, q], wrt=['q', 'p']) == (q, p, x) + + assert _sort_gens([x, p, q], sort='x > p > q') == (x, p, q) + assert _sort_gens([x, p, q], sort='p > x > q') == (p, x, q) + assert _sort_gens([x, p, q], sort='p > q > x') == (p, q, x) + + assert _sort_gens([x, p, q], wrt='x', sort='q > p') == (x, q, p) + assert _sort_gens([x, p, q], wrt='p', sort='q > x') == (p, q, x) + assert _sort_gens([x, p, q], wrt='q', sort='p > x') == (q, p, x) + + # https://github.com/sympy/sympy/issues/19353 + n1 = Symbol('\n1') + assert _sort_gens([n1]) == (n1,) + assert _sort_gens([x, n1]) == (x, n1) + + X = symbols('x0,x1,x2,x10,x11,x12,x20,x21,x22') + + assert _sort_gens(X) == X + + +def test__unify_gens(): + assert _unify_gens([], []) == () + + assert _unify_gens([x], [x]) == (x,) + assert _unify_gens([y], [y]) == (y,) + + assert _unify_gens([x, y], [x]) == (x, y) + assert _unify_gens([x], [x, y]) == (x, y) + + assert _unify_gens([x, y], [x, y]) == (x, y) + assert _unify_gens([y, x], [y, x]) == (y, x) + + assert _unify_gens([x], [y]) == (x, y) + assert _unify_gens([y], [x]) == (y, x) + + assert _unify_gens([x], [y, x]) == (y, x) + assert _unify_gens([y, x], [x]) == (y, x) + + assert _unify_gens([x, y, z], [x, y, z]) == (x, y, z) + assert _unify_gens([z, y, x], [x, y, z]) == (z, y, x) + assert _unify_gens([x, y, z], [z, y, x]) == (x, y, z) + assert _unify_gens([z, y, x], [z, y, x]) == (z, y, x) + + assert _unify_gens([x, y, z], [t, x, p, q, z]) == (t, x, y, p, q, z) + + +def test__analyze_gens(): + assert _analyze_gens((x, y, z)) == (x, y, z) + assert _analyze_gens([x, y, z]) == (x, y, z) + + assert _analyze_gens(([x, y, z],)) == (x, y, z) + assert _analyze_gens(((x, y, z),)) == (x, y, z) + + +def test__sort_factors(): + assert _sort_factors([], multiple=True) == [] + assert _sort_factors([], multiple=False) == [] + + F = [[1, 2, 3], [1, 2], [1]] + G = [[1], [1, 2], [1, 2, 3]] + + assert _sort_factors(F, multiple=False) == G + + F = [[1, 2], [1, 2, 3], [1, 2], [1]] + G = [[1], [1, 2], [1, 2], [1, 2, 3]] + + assert _sort_factors(F, multiple=False) == G + + F = [[2, 2], [1, 2, 3], [1, 2], [1]] + G = [[1], [1, 2], [2, 2], [1, 2, 3]] + + assert _sort_factors(F, multiple=False) == G + + F = [([1, 2, 3], 1), ([1, 2], 1), ([1], 1)] + G = [([1], 1), ([1, 2], 1), ([1, 2, 3], 1)] + + assert _sort_factors(F, multiple=True) == G + + F = [([1, 2], 1), ([1, 2, 3], 1), ([1, 2], 1), ([1], 1)] + G = [([1], 1), ([1, 2], 1), ([1, 2], 1), ([1, 2, 3], 1)] + + assert _sort_factors(F, multiple=True) == G + + F = [([2, 2], 1), ([1, 2, 3], 1), ([1, 2], 1), ([1], 1)] + G = [([1], 1), ([1, 2], 1), ([2, 2], 1), ([1, 2, 3], 1)] + + assert _sort_factors(F, multiple=True) == G + + F = [([2, 2], 1), ([1, 2, 3], 1), ([1, 2], 2), ([1], 1)] + G = [([1], 1), ([2, 2], 1), ([1, 2], 2), ([1, 2, 3], 1)] + + assert _sort_factors(F, multiple=True) == G + + +def test__dict_from_expr_if_gens(): + assert dict_from_expr( + Integer(17), gens=(x,)) == ({(0,): Integer(17)}, (x,)) + assert dict_from_expr( + Integer(17), gens=(x, y)) == ({(0, 0): Integer(17)}, (x, y)) + assert dict_from_expr( + Integer(17), gens=(x, y, z)) == ({(0, 0, 0): Integer(17)}, (x, y, z)) + + assert dict_from_expr( + Integer(-17), gens=(x,)) == ({(0,): Integer(-17)}, (x,)) + assert dict_from_expr( + Integer(-17), gens=(x, y)) == ({(0, 0): Integer(-17)}, (x, y)) + assert dict_from_expr(Integer( + -17), gens=(x, y, z)) == ({(0, 0, 0): Integer(-17)}, (x, y, z)) + + assert dict_from_expr( + Integer(17)*x, gens=(x,)) == ({(1,): Integer(17)}, (x,)) + assert dict_from_expr( + Integer(17)*x, gens=(x, y)) == ({(1, 0): Integer(17)}, (x, y)) + assert dict_from_expr(Integer( + 17)*x, gens=(x, y, z)) == ({(1, 0, 0): Integer(17)}, (x, y, z)) + + assert dict_from_expr( + Integer(17)*x**7, gens=(x,)) == ({(7,): Integer(17)}, (x,)) + assert dict_from_expr( + Integer(17)*x**7*y, gens=(x, y)) == ({(7, 1): Integer(17)}, (x, y)) + assert dict_from_expr(Integer(17)*x**7*y*z**12, gens=( + x, y, z)) == ({(7, 1, 12): Integer(17)}, (x, y, z)) + + assert dict_from_expr(x + 2*y + 3*z, gens=(x,)) == \ + ({(1,): Integer(1), (0,): 2*y + 3*z}, (x,)) + assert dict_from_expr(x + 2*y + 3*z, gens=(x, y)) == \ + ({(1, 0): Integer(1), (0, 1): Integer(2), (0, 0): 3*z}, (x, y)) + assert dict_from_expr(x + 2*y + 3*z, gens=(x, y, z)) == \ + ({(1, 0, 0): Integer( + 1), (0, 1, 0): Integer(2), (0, 0, 1): Integer(3)}, (x, y, z)) + + assert dict_from_expr(x*y + 2*x*z + 3*y*z, gens=(x,)) == \ + ({(1,): y + 2*z, (0,): 3*y*z}, (x,)) + assert dict_from_expr(x*y + 2*x*z + 3*y*z, gens=(x, y)) == \ + ({(1, 1): Integer(1), (1, 0): 2*z, (0, 1): 3*z}, (x, y)) + assert dict_from_expr(x*y + 2*x*z + 3*y*z, gens=(x, y, z)) == \ + ({(1, 1, 0): Integer( + 1), (1, 0, 1): Integer(2), (0, 1, 1): Integer(3)}, (x, y, z)) + + assert dict_from_expr(2**y*x, gens=(x,)) == ({(1,): 2**y}, (x,)) + assert dict_from_expr(Integral(x, (x, 1, 2)) + x) == ( + {(0, 1): 1, (1, 0): 1}, (x, Integral(x, (x, 1, 2)))) + raises(PolynomialError, lambda: dict_from_expr(2**y*x, gens=(x, y))) + + +def test__dict_from_expr_no_gens(): + assert dict_from_expr(Integer(17)) == ({(): Integer(17)}, ()) + + assert dict_from_expr(x) == ({(1,): Integer(1)}, (x,)) + assert dict_from_expr(y) == ({(1,): Integer(1)}, (y,)) + + assert dict_from_expr(x*y) == ({(1, 1): Integer(1)}, (x, y)) + assert dict_from_expr( + x + y) == ({(1, 0): Integer(1), (0, 1): Integer(1)}, (x, y)) + + assert dict_from_expr(sqrt(2)) == ({(1,): Integer(1)}, (sqrt(2),)) + assert dict_from_expr(sqrt(2), greedy=False) == ({(): sqrt(2)}, ()) + + assert dict_from_expr(x*y, domain=ZZ[x]) == ({(1,): x}, (y,)) + assert dict_from_expr(x*y, domain=ZZ[y]) == ({(1,): y}, (x,)) + + assert dict_from_expr(3*sqrt( + 2)*pi*x*y, extension=None) == ({(1, 1, 1, 1): 3}, (x, y, pi, sqrt(2))) + assert dict_from_expr(3*sqrt( + 2)*pi*x*y, extension=True) == ({(1, 1, 1): 3*sqrt(2)}, (x, y, pi)) + + assert dict_from_expr(3*sqrt( + 2)*pi*x*y, extension=True) == ({(1, 1, 1): 3*sqrt(2)}, (x, y, pi)) + + f = cos(x)*sin(x) + cos(x)*sin(y) + cos(y)*sin(x) + cos(y)*sin(y) + + assert dict_from_expr(f) == ({(0, 1, 0, 1): 1, (0, 1, 1, 0): 1, + (1, 0, 0, 1): 1, (1, 0, 1, 0): 1}, (cos(x), cos(y), sin(x), sin(y))) + + +def test__parallel_dict_from_expr_if_gens(): + assert parallel_dict_from_expr([x + 2*y + 3*z, Integer(7)], gens=(x,)) == \ + ([{(1,): Integer(1), (0,): 2*y + 3*z}, {(0,): Integer(7)}], (x,)) + + +def test__parallel_dict_from_expr_no_gens(): + assert parallel_dict_from_expr([x*y, Integer(3)]) == \ + ([{(1, 1): Integer(1)}, {(0, 0): Integer(3)}], (x, y)) + assert parallel_dict_from_expr([x*y, 2*z, Integer(3)]) == \ + ([{(1, 1, 0): Integer( + 1)}, {(0, 0, 1): Integer(2)}, {(0, 0, 0): Integer(3)}], (x, y, z)) + assert parallel_dict_from_expr((Mul(x, x**2, evaluate=False),)) == \ + ([{(3,): 1}], (x,)) + + +def test_parallel_dict_from_expr(): + assert parallel_dict_from_expr([Eq(x, 1), Eq( + x**2, 2)]) == ([{(0,): -Integer(1), (1,): Integer(1)}, + {(0,): -Integer(2), (2,): Integer(1)}], (x,)) + raises(PolynomialError, lambda: parallel_dict_from_expr([A*B - B*A])) + + +def test_dict_from_expr(): + assert dict_from_expr(Eq(x, 1)) == \ + ({(0,): -Integer(1), (1,): Integer(1)}, (x,)) + raises(PolynomialError, lambda: dict_from_expr(A*B - B*A)) + raises(PolynomialError, lambda: dict_from_expr(S.true)) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_puiseux.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_puiseux.py new file mode 100644 index 0000000000000000000000000000000000000000..031881e9d12c53053d8ec7136374bd8b3a385df0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_puiseux.py @@ -0,0 +1,204 @@ +# +# Tests for PuiseuxRing and PuiseuxPoly +# + +from sympy.testing.pytest import raises + +from sympy import ZZ, QQ, ring +from sympy.polys.puiseux import PuiseuxRing, PuiseuxPoly, puiseux_ring + +from sympy.abc import x, y + + +def test_puiseux_ring(): + R, px = puiseux_ring('x', QQ) + R2, px2 = puiseux_ring([x], QQ) + assert isinstance(R, PuiseuxRing) + assert isinstance(px, PuiseuxPoly) + assert R == R2 + assert px == px2 + assert R == PuiseuxRing('x', QQ) + assert R == PuiseuxRing([x], QQ) + assert R != PuiseuxRing('y', QQ) + assert R != PuiseuxRing('x', ZZ) + assert R != PuiseuxRing('x, y', QQ) + assert R != QQ + assert str(R) == 'PuiseuxRing((x,), QQ)' + + +def test_puiseux_ring_attributes(): + R1, px1, py1 = ring('x, y', QQ) + R2, px2, py2 = puiseux_ring('x, y', QQ) + assert R2.domain == QQ + assert R2.symbols == (x, y) + assert R2.gens == (px2, py2) + assert R2.ngens == 2 + assert R2.poly_ring == R1 + assert R2.zero == PuiseuxPoly(R1.zero, R2) + assert R2.one == PuiseuxPoly(R1.one, R2) + assert R2.zero_monom == R1.zero_monom == (0, 0) # type: ignore + assert R2.monomial_mul((1, 2), (3, 4)) == (4, 6) + + +def test_puiseux_ring_methods(): + R1, px1, py1 = ring('x, y', QQ) + R2, px2, py2 = puiseux_ring('x, y', QQ) + assert R2({(1, 2): 3}) == 3*px2*py2**2 + assert R2(px1) == px2 + assert R2(1) == R2.one + assert R2(QQ(1,2)) == QQ(1,2)*R2.one + assert R2.from_poly(px1) == px2 + assert R2.from_poly(px1) != py2 + assert R2.from_dict({(1, 2): QQ(3)}) == 3*px2*py2**2 + assert R2.from_dict({(QQ(1,2), 2): QQ(3)}) == 3*px2**QQ(1,2)*py2**2 + assert R2.from_int(3) == 3*R2.one + assert R2.domain_new(3) == QQ(3) + assert QQ.of_type(R2.domain_new(3)) + assert R2.ground_new(3) == 3*R2.one + assert isinstance(R2.ground_new(3), PuiseuxPoly) + assert R2.index(px2) == 0 + assert R2.index(py2) == 1 + + +def test_puiseux_poly(): + R1, px1 = ring('x', QQ) + R2, px2 = puiseux_ring('x', QQ) + assert PuiseuxPoly(px1, R2) == px2 + assert px2.ring == R2 + assert px2.as_expr() == px1.as_expr() == x + assert px1 != px2 + assert R2.one == px2**0 == 1 + assert px2 == px1 + assert px2 != 2.0 + assert px2**QQ(1,2) != px1 + + +def test_puiseux_poly_normalization(): + R, x = puiseux_ring('x', QQ) + assert (x**2 + 1) / x == x + 1/x == R({(1,): 1, (-1,): 1}) + assert (x**QQ(1,6))**2 == x**QQ(1,3) == R({(QQ(1,3),): 1}) + assert (x**QQ(1,6))**(-2) == x**(-QQ(1,3)) == R({(-QQ(1,3),): 1}) + assert (x**QQ(1,6))**QQ(1,2) == x**QQ(1,12) == R({(QQ(1,12),): 1}) + assert (x**QQ(1,6))**6 == x == R({(1,): 1}) + assert x**QQ(1,6) * x**QQ(1,3) == x**QQ(1,2) == R({(QQ(1,2),): 1}) + assert 1/x * x**2 == x == R({(1,): 1}) + assert 1/x**QQ(1,3) * x**QQ(1,3) == 1 == R({(0,): 1}) + + +def test_puiseux_poly_monoms(): + R, x = puiseux_ring('x', QQ) + assert x.monoms() == [(1,)] + assert list(x) == [(1,)] + assert (x**2 + 1).monoms() == [(2,), (0,)] + assert R({(1,): 1, (-1,): 1}).monoms() == [(1,), (-1,)] + assert R({(QQ(1,3),): 1}).monoms() == [(QQ(1,3),)] + assert R({(-QQ(1,3),): 1}).monoms() == [(-QQ(1,3),)] + p = x**QQ(1,6) + assert p[(QQ(1,6),)] == 1 + raises(KeyError, lambda: p[(1,)]) + assert p.to_dict() == {(QQ(1,6),): 1} + assert R(p.to_dict()) == p + assert PuiseuxPoly.from_dict({(QQ(1,6),): 1}, R) == p + + +def test_puiseux_poly_repr(): + R, x = puiseux_ring('x', QQ) + assert repr(x) == 'x' + assert repr(x**QQ(1,2)) == 'x**(1/2)' + assert repr(1/x) == 'x**(-1)' + assert repr(2*x**2 + 1) == '1 + 2*x**2' + assert repr(R.one) == '1' + assert repr(2*R.one) == '2' + + +def test_puiseux_poly_unify(): + R, x = puiseux_ring('x', QQ) + assert 1/x + x == x + 1/x == R({(1,): 1, (-1,): 1}) + assert repr(1/x + x) == 'x**(-1) + x' + assert 1/x + 1/x == 2/x == R({(-1,): 2}) + assert repr(1/x + 1/x) == '2*x**(-1)' + assert x**QQ(1,2) + x**QQ(1,2) == 2*x**QQ(1,2) == R({(QQ(1,2),): 2}) + assert repr(x**QQ(1,2) + x**QQ(1,2)) == '2*x**(1/2)' + assert x**QQ(1,2) + x**QQ(1,3) == R({(QQ(1,2),): 1, (QQ(1,3),): 1}) + assert repr(x**QQ(1,2) + x**QQ(1,3)) == 'x**(1/3) + x**(1/2)' + assert x + x**QQ(1,2) == R({(1,): 1, (QQ(1,2),): 1}) + assert repr(x + x**QQ(1,2)) == 'x**(1/2) + x' + assert 1/x**QQ(1,2) + 1/x**QQ(1,3) == R({(-QQ(1,2),): 1, (-QQ(1,3),): 1}) + assert repr(1/x**QQ(1,2) + 1/x**QQ(1,3)) == 'x**(-1/2) + x**(-1/3)' + assert 1/x + x**QQ(1,2) == x**QQ(1,2) + 1/x == R({(-1,): 1, (QQ(1,2),): 1}) + assert repr(1/x + x**QQ(1,2)) == 'x**(-1) + x**(1/2)' + + +def test_puiseux_poly_arit(): + R, x = puiseux_ring('x', QQ) + R2, y = puiseux_ring('y', QQ) + p = x**2 + 1 + assert +p == p + assert -p == -1 - x**2 + assert p + p == 2*p == 2*x**2 + 2 + assert p + 1 == 1 + p == x**2 + 2 + assert p + QQ(1,2) == QQ(1,2) + p == x**2 + QQ(3,2) + assert p - p == 0 + assert p - 1 == -1 + p == x**2 + assert p - QQ(1,2) == -QQ(1,2) + p == x**2 + QQ(1,2) + assert 1 - p == -p + 1 == -x**2 + assert QQ(1,2) - p == -p + QQ(1,2) == -x**2 - QQ(1,2) + assert p * p == x**4 + 2*x**2 + 1 + assert p * 1 == 1 * p == p + assert 2 * p == p * 2 == 2*x**2 + 2 + assert p * QQ(1,2) == QQ(1,2) * p == QQ(1,2)*x**2 + QQ(1,2) + assert x**QQ(1,2) * x**QQ(1,2) == x + raises(ValueError, lambda: x + y) + raises(ValueError, lambda: x - y) + raises(ValueError, lambda: x * y) + raises(TypeError, lambda: x + None) + raises(TypeError, lambda: x - None) + raises(TypeError, lambda: x * None) + raises(TypeError, lambda: None + x) + raises(TypeError, lambda: None - x) + raises(TypeError, lambda: None * x) + + +def test_puiseux_poly_div(): + R, x = puiseux_ring('x', QQ) + R2, y = puiseux_ring('y', QQ) + p = x**2 - 1 + assert p / 1 == p + assert p / QQ(1,2) == 2*p == 2*x**2 - 2 + assert p / x == x - 1/x == R({(1,): 1, (-1,): -1}) + assert 2 / x == 2*x**-1 == R({(-1,): 2}) + assert QQ(1,2) / x == QQ(1,2)*x**-1 == 1/(2*x) == 1/x/2 == R({(-1,): QQ(1,2)}) + raises(ZeroDivisionError, lambda: p / 0) + raises(ValueError, lambda: (x + 1) / (x + 2)) + raises(ValueError, lambda: (x + 1) / (x + 1)) + raises(ValueError, lambda: x / y) + raises(TypeError, lambda: x / None) + raises(TypeError, lambda: None / x) + + +def test_puiseux_poly_pow(): + R, x = puiseux_ring('x', QQ) + Rz, xz = puiseux_ring('x', ZZ) + assert x**0 == 1 == R({(0,): 1}) + assert x**1 == x == R({(1,): 1}) + assert x**2 == x*x == R({(2,): 1}) + assert x**QQ(1,2) == R({(QQ(1,2),): 1}) + assert x**-1 == 1/x == R({(-1,): 1}) + assert x**-QQ(1,2) == 1/x**QQ(1,2) == R({(-QQ(1,2),): 1}) + assert (2*x)**-1 == 1/(2*x) == QQ(1,2)/x == QQ(1,2)*x**-1 == R({(-1,): QQ(1,2)}) + assert 2/x**2 == 2*x**-2 == R({(-2,): 2}) + assert 2/xz**2 == 2*xz**-2 == Rz({(-2,): 2}) + raises(TypeError, lambda: x**None) + raises(ValueError, lambda: (x + 1)**-1) + raises(ValueError, lambda: (x + 1)**QQ(1,2)) + raises(ValueError, lambda: (2*x)**QQ(1,2)) + raises(ValueError, lambda: (2*xz)**-1) + + +def test_puiseux_poly_diff(): + R, x, y = puiseux_ring('x, y', QQ) + assert (x**2 + 1).diff(x) == 2*x + assert (x**2 + 1).diff(y) == 0 + assert (x**2 + y**2).diff(x) == 2*x + assert (x**QQ(1,2) + y**QQ(1,2)).diff(x) == QQ(1,2)*x**-QQ(1,2) + assert ((x*y)**QQ(1,2)).diff(x) == QQ(1,2)*y**QQ(1,2)*x**-QQ(1,2) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_pythonrational.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_pythonrational.py new file mode 100644 index 0000000000000000000000000000000000000000..547a5679626fd3a6165b151364bb506a574bb1db --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_pythonrational.py @@ -0,0 +1,139 @@ +"""Tests for PythonRational type. """ + +from sympy.polys.domains import PythonRational as QQ +from sympy.testing.pytest import raises + +def test_PythonRational__init__(): + assert QQ(0).numerator == 0 + assert QQ(0).denominator == 1 + assert QQ(0, 1).numerator == 0 + assert QQ(0, 1).denominator == 1 + assert QQ(0, -1).numerator == 0 + assert QQ(0, -1).denominator == 1 + + assert QQ(1).numerator == 1 + assert QQ(1).denominator == 1 + assert QQ(1, 1).numerator == 1 + assert QQ(1, 1).denominator == 1 + assert QQ(-1, -1).numerator == 1 + assert QQ(-1, -1).denominator == 1 + + assert QQ(-1).numerator == -1 + assert QQ(-1).denominator == 1 + assert QQ(-1, 1).numerator == -1 + assert QQ(-1, 1).denominator == 1 + assert QQ( 1, -1).numerator == -1 + assert QQ( 1, -1).denominator == 1 + + assert QQ(1, 2).numerator == 1 + assert QQ(1, 2).denominator == 2 + assert QQ(3, 4).numerator == 3 + assert QQ(3, 4).denominator == 4 + + assert QQ(2, 2).numerator == 1 + assert QQ(2, 2).denominator == 1 + assert QQ(2, 4).numerator == 1 + assert QQ(2, 4).denominator == 2 + +def test_PythonRational__hash__(): + assert hash(QQ(0)) == hash(0) + assert hash(QQ(1)) == hash(1) + assert hash(QQ(117)) == hash(117) + +def test_PythonRational__int__(): + assert int(QQ(-1, 4)) == 0 + assert int(QQ( 1, 4)) == 0 + assert int(QQ(-5, 4)) == -1 + assert int(QQ( 5, 4)) == 1 + +def test_PythonRational__float__(): + assert float(QQ(-1, 2)) == -0.5 + assert float(QQ( 1, 2)) == 0.5 + +def test_PythonRational__abs__(): + assert abs(QQ(-1, 2)) == QQ(1, 2) + assert abs(QQ( 1, 2)) == QQ(1, 2) + +def test_PythonRational__pos__(): + assert +QQ(-1, 2) == QQ(-1, 2) + assert +QQ( 1, 2) == QQ( 1, 2) + +def test_PythonRational__neg__(): + assert -QQ(-1, 2) == QQ( 1, 2) + assert -QQ( 1, 2) == QQ(-1, 2) + +def test_PythonRational__add__(): + assert QQ(-1, 2) + QQ( 1, 2) == QQ(0) + assert QQ( 1, 2) + QQ(-1, 2) == QQ(0) + + assert QQ(1, 2) + QQ(1, 2) == QQ(1) + assert QQ(1, 2) + QQ(3, 2) == QQ(2) + assert QQ(3, 2) + QQ(1, 2) == QQ(2) + assert QQ(3, 2) + QQ(3, 2) == QQ(3) + + assert 1 + QQ(1, 2) == QQ(3, 2) + assert QQ(1, 2) + 1 == QQ(3, 2) + +def test_PythonRational__sub__(): + assert QQ(-1, 2) - QQ( 1, 2) == QQ(-1) + assert QQ( 1, 2) - QQ(-1, 2) == QQ( 1) + + assert QQ(1, 2) - QQ(1, 2) == QQ( 0) + assert QQ(1, 2) - QQ(3, 2) == QQ(-1) + assert QQ(3, 2) - QQ(1, 2) == QQ( 1) + assert QQ(3, 2) - QQ(3, 2) == QQ( 0) + + assert 1 - QQ(1, 2) == QQ( 1, 2) + assert QQ(1, 2) - 1 == QQ(-1, 2) + +def test_PythonRational__mul__(): + assert QQ(-1, 2) * QQ( 1, 2) == QQ(-1, 4) + assert QQ( 1, 2) * QQ(-1, 2) == QQ(-1, 4) + + assert QQ(1, 2) * QQ(1, 2) == QQ(1, 4) + assert QQ(1, 2) * QQ(3, 2) == QQ(3, 4) + assert QQ(3, 2) * QQ(1, 2) == QQ(3, 4) + assert QQ(3, 2) * QQ(3, 2) == QQ(9, 4) + + assert 2 * QQ(1, 2) == QQ(1) + assert QQ(1, 2) * 2 == QQ(1) + +def test_PythonRational__truediv__(): + assert QQ(-1, 2) / QQ( 1, 2) == QQ(-1) + assert QQ( 1, 2) / QQ(-1, 2) == QQ(-1) + + assert QQ(1, 2) / QQ(1, 2) == QQ(1) + assert QQ(1, 2) / QQ(3, 2) == QQ(1, 3) + assert QQ(3, 2) / QQ(1, 2) == QQ(3) + assert QQ(3, 2) / QQ(3, 2) == QQ(1) + + assert 2 / QQ(1, 2) == QQ(4) + assert QQ(1, 2) / 2 == QQ(1, 4) + + raises(ZeroDivisionError, lambda: QQ(1, 2) / QQ(0)) + raises(ZeroDivisionError, lambda: QQ(1, 2) / 0) + +def test_PythonRational__pow__(): + assert QQ(1)**10 == QQ(1) + assert QQ(2)**10 == QQ(1024) + + assert QQ(1)**(-10) == QQ(1) + assert QQ(2)**(-10) == QQ(1, 1024) + +def test_PythonRational__eq__(): + assert (QQ(1, 2) == QQ(1, 2)) is True + assert (QQ(1, 2) != QQ(1, 2)) is False + + assert (QQ(1, 2) == QQ(1, 3)) is False + assert (QQ(1, 2) != QQ(1, 3)) is True + +def test_PythonRational__lt_le_gt_ge__(): + assert (QQ(1, 2) < QQ(1, 4)) is False + assert (QQ(1, 2) <= QQ(1, 4)) is False + assert (QQ(1, 2) > QQ(1, 4)) is True + assert (QQ(1, 2) >= QQ(1, 4)) is True + + assert (QQ(1, 4) < QQ(1, 2)) is True + assert (QQ(1, 4) <= QQ(1, 2)) is True + assert (QQ(1, 4) > QQ(1, 2)) is False + assert (QQ(1, 4) >= QQ(1, 2)) is False diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_rationaltools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_rationaltools.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee0192a3fbc8997347df081663015afd91dd8ad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_rationaltools.py @@ -0,0 +1,63 @@ +"""Tests for tools for manipulation of rational expressions. """ + +from sympy.polys.rationaltools import together + +from sympy.core.mul import Mul +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.abc import x, y, z + +A, B = symbols('A,B', commutative=False) + + +def test_together(): + assert together(0) == 0 + assert together(1) == 1 + + assert together(x*y*z) == x*y*z + assert together(x + y) == x + y + + assert together(1/x) == 1/x + + assert together(1/x + 1) == (x + 1)/x + assert together(1/x + 3) == (3*x + 1)/x + assert together(1/x + x) == (x**2 + 1)/x + + assert together(1/x + S.Half) == (x + 2)/(2*x) + assert together(S.Half + x/2) == Mul(S.Half, x + 1, evaluate=False) + + assert together(1/x + 2/y) == (2*x + y)/(y*x) + assert together(1/(1 + 1/x)) == x/(1 + x) + assert together(x/(1 + 1/x)) == x**2/(1 + x) + + assert together(1/x + 1/y + 1/z) == (x*y + x*z + y*z)/(x*y*z) + assert together(1/(1 + x + 1/y + 1/z)) == y*z/(y + z + y*z + x*y*z) + + assert together(1/(x*y) + 1/(x*y)**2) == y**(-2)*x**(-2)*(1 + x*y) + assert together(1/(x*y) + 1/(x*y)**4) == y**(-4)*x**(-4)*(1 + x**3*y**3) + assert together(1/(x**7*y) + 1/(x*y)**4) == y**(-4)*x**(-7)*(x**3 + y**3) + + assert together(5/(2 + 6/(3 + 7/(4 + 8/(5 + 9/x))))) == \ + Rational(5, 2)*((171 + 119*x)/(279 + 203*x)) + + assert together(1 + 1/(x + 1)**2) == (1 + (x + 1)**2)/(x + 1)**2 + assert together(1 + 1/(x*(1 + x))) == (1 + x*(1 + x))/(x*(1 + x)) + assert together( + 1/(x*(x + 1)) + 1/(x*(x + 2))) == (3 + 2*x)/(x*(1 + x)*(2 + x)) + assert together(1 + 1/(2*x + 2)**2) == (4*(x + 1)**2 + 1)/(4*(x + 1)**2) + + assert together(sin(1/x + 1/y)) == sin(1/x + 1/y) + assert together(sin(1/x + 1/y), deep=True) == sin((x + y)/(x*y)) + + assert together(1/exp(x) + 1/(x*exp(x))) == (1 + x)/(x*exp(x)) + assert together(1/exp(2*x) + 1/(x*exp(3*x))) == (1 + exp(x)*x)/(x*exp(3*x)) + + assert together(Integral(1/x + 1/y, x)) == Integral((x + y)/(x*y), x) + assert together(Eq(1/x + 1/y, 1 + 1/z)) == Eq((x + y)/(x*y), (z + 1)/z) + + assert together((A*B)**-1 + (B*A)**-1) == (A*B)**-1 + (B*A)**-1 diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_ring_series.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_ring_series.py new file mode 100644 index 0000000000000000000000000000000000000000..d983fc99f8ffcf9361d8d069f1d381928ac0aada --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_ring_series.py @@ -0,0 +1,831 @@ +from sympy.polys.domains import ZZ, QQ, EX, RR +from sympy.polys.rings import ring +from sympy.polys.puiseux import puiseux_ring +from sympy.polys.ring_series import (_invert_monoms, rs_integrate, + rs_trunc, rs_mul, rs_square, rs_pow, _has_constant_term, rs_hadamard_exp, + rs_series_from_list, rs_exp, rs_log, rs_newton, rs_series_inversion, + rs_compose_add, rs_asin, _atan, rs_atan, _atanh, rs_atanh, rs_asinh, rs_tan, + rs_cot, rs_sin, rs_cos, rs_cos_sin, rs_sinh, rs_cosh, rs_cosh_sinh, rs_tanh, + _tan1, rs_fun, rs_nth_root, rs_LambertW, rs_series_reversion, rs_is_puiseux, + rs_series) +from sympy.testing.pytest import raises, slow +from sympy.core.symbol import symbols +from sympy.functions import (sin, cos, exp, tan, cot, sinh, cosh, atan, atanh, + asinh, tanh, log, sqrt) +from sympy.core.numbers import Rational, pi +from sympy.core import expand, S + +def is_close(a, b): + tol = 10**(-10) + assert abs(a - b) < tol + + +def test_ring_series1(): + R, x = ring('x', QQ) + p = x**4 + 2*x**3 + 3*x + 4 + assert _invert_monoms(p) == 4*x**4 + 3*x**3 + 2*x + 1 + assert rs_hadamard_exp(p) == x**4/24 + x**3/3 + 3*x + 4 + R, x = ring('x', QQ) + p = x**4 + 2*x**3 + 3*x + 4 + assert rs_integrate(p, x) == x**5/5 + x**4/2 + 3*x**2/2 + 4*x + R, x, y = ring('x, y', QQ) + p = x**2*y**2 + x + 1 + assert rs_integrate(p, x) == x**3*y**2/3 + x**2/2 + x + assert rs_integrate(p, y) == x**2*y**3/3 + x*y + y + + +def test_trunc(): + R, x, y, t = ring('x, y, t', QQ) + p = (y + t*x)**4 + p1 = rs_trunc(p, x, 3) + assert p1 == y**4 + 4*y**3*t*x + 6*y**2*t**2*x**2 + + +def test_mul_trunc(): + R, x, y, t = ring('x, y, t', QQ) + p = 1 + t*x + t*y + for i in range(2): + p = rs_mul(p, p, t, 3) + + assert p == 6*x**2*t**2 + 12*x*y*t**2 + 6*y**2*t**2 + 4*x*t + 4*y*t + 1 + p = 1 + t*x + t*y + t**2*x*y + p1 = rs_mul(p, p, t, 2) + assert p1 == 1 + 2*t*x + 2*t*y + R1, z = ring('z', QQ) + raises(ValueError, lambda: rs_mul(p, z, x, 2)) + + p1 = 2 + 2*x + 3*x**2 + p2 = 3 + x**2 + assert rs_mul(p1, p2, x, 4) == 2*x**3 + 11*x**2 + 6*x + 6 + + +def test_square_trunc(): + R, x, y, t = ring('x, y, t', QQ) + p = (1 + t*x + t*y)*2 + p1 = rs_mul(p, p, x, 3) + p2 = rs_square(p, x, 3) + assert p1 == p2 + p = 1 + x + x**2 + x**3 + assert rs_square(p, x, 4) == 4*x**3 + 3*x**2 + 2*x + 1 + + +def test_pow_trunc(): + R, x, y, z = ring('x, y, z', QQ) + p0 = y + x*z + p = p0**16 + for xx in (x, y, z): + p1 = rs_trunc(p, xx, 8) + p2 = rs_pow(p0, 16, xx, 8) + assert p1 == p2 + + p = 1 + x + p1 = rs_pow(p, 3, x, 2) + assert p1 == 1 + 3*x + assert rs_pow(p, 0, x, 2) == 1 + assert rs_pow(p, -2, x, 2) == 1 - 2*x + p = x + y + assert rs_pow(p, 3, y, 3) == x**3 + 3*x**2*y + 3*x*y**2 + assert rs_pow(1 + x, Rational(2, 3), x, 4) == 4*x**3/81 - x**2/9 + x*Rational(2, 3) + 1 + + +def test_has_constant_term(): + R, x, y, z = ring('x, y, z', QQ) + p = y + x*z + assert _has_constant_term(p, x) + p = x + x**4 + assert not _has_constant_term(p, x) + p = 1 + x + x**4 + assert _has_constant_term(p, x) + p = x + y + x*z + + +def test_inversion(): + R, x = ring('x', QQ) + p = 2 + x + 2*x**2 + n = 5 + p1 = rs_series_inversion(p, x, n) + assert rs_trunc(p*p1, x, n) == 1 + R, x, y = ring('x, y', QQ) + p = 2 + x + 2*x**2 + y*x + x**2*y + p1 = rs_series_inversion(p, x, n) + assert rs_trunc(p*p1, x, n) == 1 + + R, x, y = ring('x, y', QQ) + p = 1 + x + y + raises(NotImplementedError, lambda: rs_series_inversion(p, x, 4)) + p = R.zero + raises(ZeroDivisionError, lambda: rs_series_inversion(p, x, 3)) + + R, x = ring('x', ZZ) + p = 2 + x + raises(ValueError, lambda: rs_series_inversion(p, x, 3)) + + +def test_series_reversion(): + R, x, y = ring('x, y', QQ) + + p = rs_tan(x, x, 10) + assert rs_series_reversion(p, x, 8, y) == rs_atan(y, y, 8) + + p = rs_sin(x, x, 10) + assert rs_series_reversion(p, x, 8, y) == 5*y**7/112 + 3*y**5/40 + \ + y**3/6 + y + + +def test_series_from_list(): + R, x = ring('x', QQ) + p = 1 + 2*x + x**2 + 3*x**3 + c = [1, 2, 0, 4, 4] + r = rs_series_from_list(p, c, x, 5) + pc = R.from_list(list(reversed(c))) + r1 = rs_trunc(pc.compose(x, p), x, 5) + assert r == r1 + R, x, y = ring('x, y', QQ) + c = [1, 3, 5, 7] + p1 = rs_series_from_list(x + y, c, x, 3, concur=0) + p2 = rs_trunc((1 + 3*(x+y) + 5*(x+y)**2 + 7*(x+y)**3), x, 3) + assert p1 == p2 + + R, x = ring('x', QQ) + h = 25 + p = rs_exp(x, x, h) - 1 + p1 = rs_series_from_list(p, c, x, h) + p2 = 0 + for i, cx in enumerate(c): + p2 += cx*rs_pow(p, i, x, h) + assert p1 == p2 + + +def test_log(): + R, x = ring('x', QQ) + p = 1 + x + assert rs_log(p, x, 4) == x - x**2/2 + x**3/3 + p = 1 + x +2*x**2/3 + p1 = rs_log(p, x, 9) + assert p1 == -17*x**8/648 + 13*x**7/189 - 11*x**6/162 - x**5/45 + \ + 7*x**4/36 - x**3/3 + x**2/6 + x + p2 = rs_series_inversion(p, x, 9) + p3 = rs_log(p2, x, 9) + assert p3 == -p1 + + R, x, y = ring('x, y', QQ) + p = 1 + x + 2*y*x**2 + p1 = rs_log(p, x, 6) + assert p1 == (4*x**5*y**2 - 2*x**5*y - 2*x**4*y**2 + x**5/5 + 2*x**4*y - + x**4/4 - 2*x**3*y + x**3/3 + 2*x**2*y - x**2/2 + x) + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', EX) + assert rs_log(x + a, x, 5) == -EX(1/(4*a**4))*x**4 + EX(1/(3*a**3))*x**3 \ + - EX(1/(2*a**2))*x**2 + EX(1/a)*x + EX(log(a)) + assert rs_log(x + x**2*y + a, x, 4) == -EX(a**(-2))*x**3*y + \ + EX(1/(3*a**3))*x**3 + EX(1/a)*x**2*y - EX(1/(2*a**2))*x**2 + \ + EX(1/a)*x + EX(log(a)) + + p = x + x**2 + 3 + assert rs_log(p, x, 10).compose(x, 5) == EX(log(3) + Rational(19281291595, 9920232)) + + +def test_exp(): + R, x = ring('x', QQ) + p = x + x**4 + for h in [10, 30]: + q = rs_series_inversion(1 + p, x, h) - 1 + p1 = rs_exp(q, x, h) + q1 = rs_log(p1, x, h) + assert q1 == q + p1 = rs_exp(p, x, 30) + assert p1.coeff(x**29) == QQ(74274246775059676726972369, 353670479749588078181744640000) + prec = 21 + p = rs_log(1 + x, x, prec) + p1 = rs_exp(p, x, prec) + assert p1 == x + 1 + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', QQ[exp(a), a]) + assert rs_exp(x + a, x, 5) == exp(a)*x**4/24 + exp(a)*x**3/6 + \ + exp(a)*x**2/2 + exp(a)*x + exp(a) + assert rs_exp(x + x**2*y + a, x, 5) == exp(a)*x**4*y**2/2 + \ + exp(a)*x**4*y/2 + exp(a)*x**4/24 + exp(a)*x**3*y + \ + exp(a)*x**3/6 + exp(a)*x**2*y + exp(a)*x**2/2 + exp(a)*x + exp(a) + + R, x, y = ring('x, y', EX) + assert rs_exp(x + a, x, 5) == EX(exp(a)/24)*x**4 + EX(exp(a)/6)*x**3 + \ + EX(exp(a)/2)*x**2 + EX(exp(a))*x + EX(exp(a)) + assert rs_exp(x + x**2*y + a, x, 5) == EX(exp(a)/2)*x**4*y**2 + \ + EX(exp(a)/2)*x**4*y + EX(exp(a)/24)*x**4 + EX(exp(a))*x**3*y + \ + EX(exp(a)/6)*x**3 + EX(exp(a))*x**2*y + EX(exp(a)/2)*x**2 + \ + EX(exp(a))*x + EX(exp(a)) + + +def test_newton(): + R, x = ring('x', QQ) + p = x**2 - 2 + r = rs_newton(p, x, 4) + assert r == 8*x**4 + 4*x**2 + 2 + + +def test_compose_add(): + R, x = ring('x', QQ) + p1 = x**3 - 1 + p2 = x**2 - 2 + assert rs_compose_add(p1, p2) == x**6 - 6*x**4 - 2*x**3 + 12*x**2 - 12*x - 7 + + +def test_fun(): + R, x, y = ring('x, y', QQ) + p = x*y + x**2*y**3 + x**5*y + assert rs_fun(p, rs_tan, x, 10) == rs_tan(p, x, 10) + assert rs_fun(p, _tan1, x, 10) == _tan1(p, x, 10) + + +def test_nth_root(): + R, x, y = puiseux_ring('x, y', QQ) + assert rs_nth_root(1 + x**2*y, 4, x, 10) == -77*x**8*y**4/2048 + \ + 7*x**6*y**3/128 - 3*x**4*y**2/32 + x**2*y/4 + 1 + assert rs_nth_root(1 + x*y + x**2*y**3, 3, x, 5) == -x**4*y**6/9 + \ + 5*x**4*y**5/27 - 10*x**4*y**4/243 - 2*x**3*y**4/9 + 5*x**3*y**3/81 + \ + x**2*y**3/3 - x**2*y**2/9 + x*y/3 + 1 + assert rs_nth_root(8*x, 3, x, 3) == 2*x**QQ(1, 3) + assert rs_nth_root(8*x + x**2 + x**3, 3, x, 3) == x**QQ(4,3)/12 + 2*x**QQ(1,3) + r = rs_nth_root(8*x + x**2*y + x**3, 3, x, 4) + assert r == -x**QQ(7,3)*y**2/288 + x**QQ(7,3)/12 + x**QQ(4,3)*y/12 + 2*x**QQ(1,3) + + # Constant term in series + a = symbols('a') + R, x, y = puiseux_ring('x, y', EX) + assert rs_nth_root(x + EX(a), 3, x, 4) == EX(5/(81*a**QQ(8, 3)))*x**3 - \ + EX(1/(9*a**QQ(5, 3)))*x**2 + EX(1/(3*a**QQ(2, 3)))*x + EX(a**QQ(1, 3)) + assert rs_nth_root(x**QQ(2, 3) + x**2*y + 5, 2, x, 3) == -EX(sqrt(5)/100)*\ + x**QQ(8, 3)*y - EX(sqrt(5)/16000)*x**QQ(8, 3) + EX(sqrt(5)/10)*x**2*y + \ + EX(sqrt(5)/2000)*x**2 - EX(sqrt(5)/200)*x**QQ(4, 3) + \ + EX(sqrt(5)/10)*x**QQ(2, 3) + EX(sqrt(5)) + + +def test_atan(): + R, x, y = ring('x, y', QQ) + assert rs_atan(x, x, 9) == -x**7/7 + x**5/5 - x**3/3 + x + assert rs_atan(x*y + x**2*y**3, x, 9) == 2*x**8*y**11 - x**8*y**9 + \ + 2*x**7*y**9 - x**7*y**7/7 - x**6*y**9/3 + x**6*y**7 - x**5*y**7 + \ + x**5*y**5/5 - x**4*y**5 - x**3*y**3/3 + x**2*y**3 + x*y + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', EX) + assert rs_atan(x + a, x, 5) == -EX((a**3 - a)/(a**8 + 4*a**6 + 6*a**4 + \ + 4*a**2 + 1))*x**4 + EX((3*a**2 - 1)/(3*a**6 + 9*a**4 + \ + 9*a**2 + 3))*x**3 - EX(a/(a**4 + 2*a**2 + 1))*x**2 + \ + EX(1/(a**2 + 1))*x + EX(atan(a)) + assert rs_atan(x + x**2*y + a, x, 4) == -EX(2*a/(a**4 + 2*a**2 + 1)) \ + *x**3*y + EX((3*a**2 - 1)/(3*a**6 + 9*a**4 + 9*a**2 + 3))*x**3 + \ + EX(1/(a**2 + 1))*x**2*y - EX(a/(a**4 + 2*a**2 + 1))*x**2 + EX(1/(a**2 \ + + 1))*x + EX(atan(a)) + + # Test for _atan faster for small and univariate series + R, x = ring('x', QQ) + p = x**2 + 2*x + assert _atan(p, x, 5) == rs_atan(p, x, 5) + + R, x = ring('x', EX) + p = x**2 + 2*x + assert _atan(p, x, 9) == rs_atan(p, x, 9) + + +def test_asin(): + R, x, y = ring('x, y', QQ) + assert rs_asin(x + x*y, x, 5) == x**3*y**3/6 + x**3*y**2/2 + x**3*y/2 + \ + x**3/6 + x*y + x + assert rs_asin(x*y + x**2*y**3, x, 6) == x**5*y**7/2 + 3*x**5*y**5/40 + \ + x**4*y**5/2 + x**3*y**3/6 + x**2*y**3 + x*y + + +def test_tan(): + R, x, y = ring('x, y', QQ) + assert rs_tan(x, x, 9) == x + x**3/3 + QQ(2,15)*x**5 + QQ(17,315)*x**7 + assert rs_tan(x*y + x**2*y**3, x, 9) == 4*x**8*y**11/3 + 17*x**8*y**9/45 + \ + 4*x**7*y**9/3 + 17*x**7*y**7/315 + x**6*y**9/3 + 2*x**6*y**7/3 + \ + x**5*y**7 + 2*x**5*y**5/15 + x**4*y**5 + x**3*y**3/3 + x**2*y**3 + x*y + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', QQ[tan(a), a]) + assert rs_tan(x + a, x, 5) == (tan(a)**5 + 5*tan(a)**3/3 + + 2*tan(a)/3)*x**4 + (tan(a)**4 + 4*tan(a)**2/3 + Rational(1, 3))*x**3 + \ + (tan(a)**3 + tan(a))*x**2 + (tan(a)**2 + 1)*x + tan(a) + assert rs_tan(x + x**2*y + a, x, 4) == (2*tan(a)**3 + 2*tan(a))*x**3*y + \ + (tan(a)**4 + Rational(4, 3)*tan(a)**2 + Rational(1, 3))*x**3 + (tan(a)**2 + 1)*x**2*y + \ + (tan(a)**3 + tan(a))*x**2 + (tan(a)**2 + 1)*x + tan(a) + + R, x, y = ring('x, y', EX) + assert rs_tan(x + a, x, 5) == EX(tan(a)**5 + 5*tan(a)**3/3 + + 2*tan(a)/3)*x**4 + EX(tan(a)**4 + 4*tan(a)**2/3 + EX(1)/3)*x**3 + \ + EX(tan(a)**3 + tan(a))*x**2 + EX(tan(a)**2 + 1)*x + EX(tan(a)) + assert rs_tan(x + x**2*y + a, x, 4) == EX(2*tan(a)**3 + + 2*tan(a))*x**3*y + EX(tan(a)**4 + 4*tan(a)**2/3 + EX(1)/3)*x**3 + \ + EX(tan(a)**2 + 1)*x**2*y + EX(tan(a)**3 + tan(a))*x**2 + \ + EX(tan(a)**2 + 1)*x + EX(tan(a)) + + p = x + x**2 + 5 + assert rs_atan(p, x, 10).compose(x, 10) == EX(atan(5) + S(67701870330562640) / \ + 668083460499) + + +def test_cot(): + R, x, y = puiseux_ring('x, y', QQ) + assert rs_cot(x**6 + x**7, x, 8) == x**(-6) - x**(-5) + x**(-4) - \ + x**(-3) + x**(-2) - x**(-1) + 1 - x + x**2 - x**3 + x**4 - x**5 + \ + 2*x**6/3 - 4*x**7/3 + assert rs_cot(x + x**2*y, x, 5) == -x**4*y**5 - x**4*y/15 + x**3*y**4 - \ + x**3/45 - x**2*y**3 - x**2*y/3 + x*y**2 - x/3 - y + x**(-1) + + +def test_sin(): + R, x, y = ring('x, y', QQ) + assert rs_sin(x, x, 9) == x - x**3/6 + x**5/120 - x**7/5040 + assert rs_sin(x*y + x**2*y**3, x, 9) == x**8*y**11/12 - \ + x**8*y**9/720 + x**7*y**9/12 - x**7*y**7/5040 - x**6*y**9/6 + \ + x**6*y**7/24 - x**5*y**7/2 + x**5*y**5/120 - x**4*y**5/2 - \ + x**3*y**3/6 + x**2*y**3 + x*y + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', QQ[sin(a), cos(a), a]) + assert rs_sin(x + a, x, 5) == sin(a)*x**4/24 - cos(a)*x**3/6 - \ + sin(a)*x**2/2 + cos(a)*x + sin(a) + assert rs_sin(x + x**2*y + a, x, 5) == -sin(a)*x**4*y**2/2 - \ + cos(a)*x**4*y/2 + sin(a)*x**4/24 - sin(a)*x**3*y - cos(a)*x**3/6 + \ + cos(a)*x**2*y - sin(a)*x**2/2 + cos(a)*x + sin(a) + + R, x, y = ring('x, y', EX) + assert rs_sin(x + a, x, 5) == EX(sin(a)/24)*x**4 - EX(cos(a)/6)*x**3 - \ + EX(sin(a)/2)*x**2 + EX(cos(a))*x + EX(sin(a)) + assert rs_sin(x + x**2*y + a, x, 5) == -EX(sin(a)/2)*x**4*y**2 - \ + EX(cos(a)/2)*x**4*y + EX(sin(a)/24)*x**4 - EX(sin(a))*x**3*y - \ + EX(cos(a)/6)*x**3 + EX(cos(a))*x**2*y - EX(sin(a)/2)*x**2 + \ + EX(cos(a))*x + EX(sin(a)) + + +def test_cos(): + R, x, y = ring('x, y', QQ) + assert rs_cos(x, x, 9) == 1 - x**2/2 + x**4/24 - x**6/720 + x**8/40320 + assert rs_cos(x*y + x**2*y**3, x, 9) == x**8*y**12/24 - \ + x**8*y**10/48 + x**8*y**8/40320 + x**7*y**10/6 - \ + x**7*y**8/120 + x**6*y**8/4 - x**6*y**6/720 + x**5*y**6/6 - \ + x**4*y**6/2 + x**4*y**4/24 - x**3*y**4 - x**2*y**2/2 + 1 + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', QQ[sin(a), cos(a), a]) + assert rs_cos(x + a, x, 5) == cos(a)*x**4/24 + sin(a)*x**3/6 - \ + cos(a)*x**2/2 - sin(a)*x + cos(a) + assert rs_cos(x + x**2*y + a, x, 5) == -cos(a)*x**4*y**2/2 + \ + sin(a)*x**4*y/2 + cos(a)*x**4/24 - cos(a)*x**3*y + sin(a)*x**3/6 - \ + sin(a)*x**2*y - cos(a)*x**2/2 - sin(a)*x + cos(a) + + R, x, y = ring('x, y', EX) + assert rs_cos(x + a, x, 5) == EX(cos(a)/24)*x**4 + EX(sin(a)/6)*x**3 - \ + EX(cos(a)/2)*x**2 - EX(sin(a))*x + EX(cos(a)) + assert rs_cos(x + x**2*y + a, x, 5) == -EX(cos(a)/2)*x**4*y**2 + \ + EX(sin(a)/2)*x**4*y + EX(cos(a)/24)*x**4 - EX(cos(a))*x**3*y + \ + EX(sin(a)/6)*x**3 - EX(sin(a))*x**2*y - EX(cos(a)/2)*x**2 - \ + EX(sin(a))*x + EX(cos(a)) + + +def test_cos_sin(): + R, x, y = ring('x, y', QQ) + c, s = rs_cos_sin(x, x, 9) + assert c == rs_cos(x, x, 9) + assert s == rs_sin(x, x, 9) + c, s = rs_cos_sin(x + x*y, x, 5) + assert c == rs_cos(x + x*y, x, 5) + assert s == rs_sin(x + x*y, x, 5) + + # constant term in series + c, s = rs_cos_sin(1 + x + x**2, x, 5) + assert c == rs_cos(1 + x + x**2, x, 5) + assert s == rs_sin(1 + x + x**2, x, 5) + + a = symbols('a') + R, x, y = ring('x, y', QQ[sin(a), cos(a), a]) + c, s = rs_cos_sin(x + a, x, 5) + assert c == rs_cos(x + a, x, 5) + assert s == rs_sin(x + a, x, 5) + + R, x, y = ring('x, y', EX) + c, s = rs_cos_sin(x + a, x, 5) + assert c == rs_cos(x + a, x, 5) + assert s == rs_sin(x + a, x, 5) + + +def test_atanh(): + R, x, y = ring('x, y', QQ) + assert rs_atanh(x, x, 9) == x + x**3/3 + x**5/5 + x**7/7 + assert rs_atanh(x*y + x**2*y**3, x, 9) == 2*x**8*y**11 + x**8*y**9 + \ + 2*x**7*y**9 + x**7*y**7/7 + x**6*y**9/3 + x**6*y**7 + x**5*y**7 + \ + x**5*y**5/5 + x**4*y**5 + x**3*y**3/3 + x**2*y**3 + x*y + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', EX) + assert rs_atanh(x + a, x, 5) == EX((a**3 + a)/(a**8 - 4*a**6 + 6*a**4 - \ + 4*a**2 + 1))*x**4 - EX((3*a**2 + 1)/(3*a**6 - 9*a**4 + \ + 9*a**2 - 3))*x**3 + EX(a/(a**4 - 2*a**2 + 1))*x**2 - EX(1/(a**2 - \ + 1))*x + EX(atanh(a)) + assert rs_atanh(x + x**2*y + a, x, 4) == EX(2*a/(a**4 - 2*a**2 + \ + 1))*x**3*y - EX((3*a**2 + 1)/(3*a**6 - 9*a**4 + 9*a**2 - 3))*x**3 - \ + EX(1/(a**2 - 1))*x**2*y + EX(a/(a**4 - 2*a**2 + 1))*x**2 - \ + EX(1/(a**2 - 1))*x + EX(atanh(a)) + + p = x + x**2 + 5 + assert rs_atanh(p, x, 10).compose(x, 10) == EX(Rational(-733442653682135, 5079158784) \ + + atanh(5)) + + # Test for _atanh faster for small and univariate series + R,x = ring('x', QQ) + p = x**2 + 2*x + assert _atanh(p, x, 5) == rs_atanh(p, x, 5) + + R,x = ring('x', EX) + p = x**2 + 2*x + assert _atanh(p, x, 9) == rs_atanh(p, x, 9) + + +def test_asinh(): + R, x, y = ring('x, y', QQ) + assert rs_asinh(x, x, 9) == -5/112*x**7 + 3/40*x**5 - 1/6*x**3 + x + assert rs_asinh(x*y + x**2*y**3, x, 9) == 3/4*x**8*y**11 - 5/16*x**8*y**9 + \ + 3/4*x**7*y**9 - 5/112*x**7*y**7 - 1/6*x**6*y**9 + 3/8*x**6*y**7 - 1/2*x \ + **5*y**7 + 3/40*x**5*y**5 - 1/2*x**4*y**5 - 1/6*x**3*y**3 + x**2*y**3 + x*y + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', EX) + assert rs_asinh(x + a, x, 3) == -EX(a/(2*a**2*sqrt(a**2 + 1) + 2*sqrt(a**2 + 1))) \ + *x**2 + EX(1/sqrt(a**2 + 1))*x + EX(asinh(a)) + assert rs_asinh(x + x**2*y + a, x, 3) == EX(1/sqrt(a**2 + 1))*x**2*y - EX(a/(2*a**2 \ + *sqrt(a**2 + 1) + 2*sqrt(a**2 + 1)))*x**2 + EX(1/sqrt(a**2 + 1))*x + EX(asinh(a)) + + p = x + x ** 2 + 5 + assert rs_asinh(p, x, 10).compose(x, 10) == EX(asinh(5) + 4643789843094995*sqrt(26)/\ + 205564141692) + + +def test_sinh(): + R, x, y = ring('x, y', QQ) + assert rs_sinh(x, x, 9) == x + x**3/6 + x**5/120 + x**7/5040 + assert rs_sinh(x*y + x**2*y**3, x, 9) == x**8*y**11/12 + \ + x**8*y**9/720 + x**7*y**9/12 + x**7*y**7/5040 + x**6*y**9/6 + \ + x**6*y**7/24 + x**5*y**7/2 + x**5*y**5/120 + x**4*y**5/2 + \ + x**3*y**3/6 + x**2*y**3 + x*y + + # constant term in series + a = symbols('a') + R, x, y = ring('x, y', QQ[sinh(a), cosh(a), a]) + assert rs_sinh(x + a, x, 5) == 1/24*x**4*(sinh(a)) + 1/6*x**3*(cosh(a)) + 1/\ + 2*x**2*(sinh(a)) + x*(cosh(a)) + (sinh(a)) + assert rs_sinh(x + x**2*y + a, x, 5) == 1/2*(sinh(a))*x**4*y**2 + 1/2*(cosh(a))\ + *x**4*y + 1/24*(sinh(a))*x**4 + (sinh(a))*x**3*y + 1/6*(cosh(a))*x**3 + \ + (cosh(a))*x**2*y + 1/2*(sinh(a))*x**2 + (cosh(a))*x + (sinh(a)) + + R, x, y = ring('x, y', EX) + assert rs_sinh(x + a, x, 5) == EX(sinh(a)/24)*x**4 + EX(cosh(a)/6)*x**3 + \ + EX(sinh(a)/2)*x**2 + EX(cosh(a))*x + EX(sinh(a)) + assert rs_sinh(x + x**2*y + a, x, 5) == EX(sinh(a)/2)*x**4*y**2 + EX(cosh(a)/\ + 2)*x**4*y + EX(sinh(a)/24)*x**4 + EX(sinh(a))*x**3*y + EX(cosh(a)/6)*x**3 \ + + EX(cosh(a))*x**2*y + EX(sinh(a)/2)*x**2 + EX(cosh(a))*x + EX(sinh(a)) + + +def test_cosh(): + R, x, y = ring('x, y', QQ) + assert rs_cosh(x, x, 9) == 1 + x**2/2 + x**4/24 + x**6/720 + x**8/40320 + assert rs_cosh(x*y + x**2*y**3, x, 9) == x**8*y**12/24 + \ + x**8*y**10/48 + x**8*y**8/40320 + x**7*y**10/6 + \ + x**7*y**8/120 + x**6*y**8/4 + x**6*y**6/720 + x**5*y**6/6 + \ + x**4*y**6/2 + x**4*y**4/24 + x**3*y**4 + x**2*y**2/2 + 1 + + # constant term in series + a = symbols('a') + R, x, y = ring('x, y', QQ[sinh(a), cosh(a), a]) + assert rs_cosh(x + a, x, 5) == 1/24*(cosh(a))*x**4 + 1/6*(sinh(a))*x**3 + \ + 1/2*(cosh(a))*x**2 + (sinh(a))*x + (cosh(a)) + assert rs_cosh(x + x**2*y + a, x, 5) == 1/2*(cosh(a))*x**4*y**2 + 1/2*(sinh(a))\ + *x**4*y + 1/24*(cosh(a))*x**4 + (cosh(a))*x**3*y + 1/6*(sinh(a))*x**3 + \ + (sinh(a))*x**2*y + 1/2*(cosh(a))*x**2 + (sinh(a))*x + (cosh(a)) + R, x, y = ring('x, y', EX) + assert rs_cosh(x + a, x, 5) == EX(cosh(a)/24)*x**4 + EX(sinh(a)/6)*x**3 + \ + EX(cosh(a)/2)*x**2 + EX(sinh(a))*x + EX(cosh(a)) + assert rs_cosh(x + x**2*y + a, x, 5) == EX(cosh(a)/2)*x**4*y**2 + EX(sinh(a)/\ + 2)*x**4*y + EX(cosh(a)/24)*x**4 + EX(cosh(a))*x**3*y + EX(sinh(a)/6)*x**3 \ + + EX(sinh(a))*x**2*y + EX(cosh(a)/2)*x**2 + EX(sinh(a))*x + EX(cosh(a)) + + +def test_cosh_sinh(): + R, x, y = ring('x, y', QQ) + ch, sh = rs_cosh_sinh(x, x, 9) + assert ch == rs_cosh(x, x, 9) + assert sh == rs_sinh(x, x, 9) + ch, sh = rs_cosh_sinh(x + x*y, x, 5) + assert ch == rs_cosh(x + x*y, x, 5) + assert sh == rs_sinh(x + x*y, x, 5) + + # constant term in series + c, s = rs_cosh_sinh(1 + x + x**2, x, 5) + assert c == rs_cosh(1 + x + x**2, x, 5) + assert s == rs_sinh(1 + x + x**2, x, 5) + + a = symbols('a') + R, x, y = ring('x, y', QQ[sinh(a), cosh(a), a]) + ch, sh = rs_cosh_sinh(x + a, x, 5) + assert ch == rs_cosh(x + a, x, 5) + assert sh == rs_sinh(x + a, x, 5) + R, x, y = ring('x, y', EX) + ch, sh = rs_cosh_sinh(x + a, x, 5) + assert ch == rs_cosh(x + a, x, 5) + assert sh == rs_sinh(x + a, x, 5) + + +def test_tanh(): + R, x, y = ring('x, y', QQ) + assert rs_tanh(x, x, 9) == x - QQ(1,3)*x**3 + QQ(2,15)*x**5 - QQ(17,315)*x**7 + assert rs_tanh(x*y + x**2*y**3, x, 9) == 4*x**8*y**11/3 - \ + 17*x**8*y**9/45 + 4*x**7*y**9/3 - 17*x**7*y**7/315 - x**6*y**9/3 + \ + 2*x**6*y**7/3 - x**5*y**7 + 2*x**5*y**5/15 - x**4*y**5 - \ + x**3*y**3/3 + x**2*y**3 + x*y + + # Constant term in series + a = symbols('a') + R, x, y = ring('x, y', EX) + assert rs_tanh(x + a, x, 5) == EX(tanh(a)**5 - 5*tanh(a)**3/3 + + 2*tanh(a)/3)*x**4 + EX(-tanh(a)**4 + 4*tanh(a)**2/3 - QQ(1, 3))*x**3 + \ + EX(tanh(a)**3 - tanh(a))*x**2 + EX(-tanh(a)**2 + 1)*x + EX(tanh(a)) + + p = rs_tanh(x + x**2*y + a, x, 4) + assert (p.compose(x, 10)).compose(y, 5) == EX(-1000*tanh(a)**4 + \ + 10100*tanh(a)**3 + 2470*tanh(a)**2/3 - 10099*tanh(a) + QQ(530, 3)) + + +def test_RR(): + rs_funcs = [rs_sin, rs_cos, rs_tan, rs_cot, rs_atan, rs_tanh] + sympy_funcs = [sin, cos, tan, cot, atan, tanh] + R, x, y = ring('x, y', RR) + a = symbols('a') + for rs_func, sympy_func in zip(rs_funcs, sympy_funcs): + p = rs_func(2 + x, x, 5).compose(x, 5) + q = sympy_func(2 + a).series(a, 0, 5).removeO() + is_close(p.as_expr(), q.subs(a, 5).n()) + + p = rs_nth_root(2 + x, 5, x, 5).compose(x, 5) + q = ((2 + a)**QQ(1, 5)).series(a, 0, 5).removeO() + is_close(p.as_expr(), q.subs(a, 5).n()) + + +def test_is_regular(): + R, x, y = puiseux_ring('x, y', QQ) + p = 1 + 2*x + x**2 + 3*x**3 + assert not rs_is_puiseux(p, x) + + p = x + x**QQ(1,5)*y + assert rs_is_puiseux(p, x) + assert not rs_is_puiseux(p, y) + + p = x + x**2*y**QQ(1,5)*y + assert not rs_is_puiseux(p, x) + + +def test_puiseux(): + R, x, y = puiseux_ring('x, y', QQ) + p = x**QQ(2,5) + x**QQ(2,3) + x + + r = rs_series_inversion(p, x, 1) + r1 = -x**QQ(14,15) + x**QQ(4,5) - 3*x**QQ(11,15) + x**QQ(2,3) + \ + 2*x**QQ(7,15) - x**QQ(2,5) - x**QQ(1,5) + x**QQ(2,15) - x**QQ(-2,15) \ + + x**QQ(-2,5) + assert r == r1 + + r = rs_nth_root(1 + p, 3, x, 1) + assert r == -x**QQ(4,5)/9 + x**QQ(2,3)/3 + x**QQ(2,5)/3 + 1 + + r = rs_log(1 + p, x, 1) + assert r == -x**QQ(4,5)/2 + x**QQ(2,3) + x**QQ(2,5) + + r = rs_LambertW(p, x, 1) + assert r == -x**QQ(4,5) + x**QQ(2,3) + x**QQ(2,5) + + p1 = x + x**QQ(1,5)*y + r = rs_exp(p1, x, 1) + assert r == x**QQ(4,5)*y**4/24 + x**QQ(3,5)*y**3/6 + x**QQ(2,5)*y**2/2 + \ + x**QQ(1,5)*y + 1 + + r = rs_atan(p, x, 2) + assert r == -x**QQ(9,5) - x**QQ(26,15) - x**QQ(22,15) - x**QQ(6,5)/3 + \ + x + x**QQ(2,3) + x**QQ(2,5) + + r = rs_atan(p1, x, 2) + assert r == x**QQ(9,5)*y**9/9 + x**QQ(9,5)*y**4 - x**QQ(7,5)*y**7/7 - \ + x**QQ(7,5)*y**2 + x*y**5/5 + x - x**QQ(3,5)*y**3/3 + x**QQ(1,5)*y + + r = rs_tan(p, x, 2) + assert r == x**QQ(2,5) + x**QQ(2,3) + x + QQ(1,3)*x**QQ(6,5) + x**QQ(22,15)\ + + x**QQ(26,15) + x**QQ(9,5) + + r = rs_sin(p, x, 2) + assert r == x**QQ(2,5) + x**QQ(2,3) + x - QQ(1,6)*x**QQ(6,5) - QQ(1,2)*x**\ + QQ(22,15) - QQ(1,2)*x**QQ(26,15) - QQ(1,2)*x**QQ(9,5) + + r = rs_cos(p, x, 2) + assert r == 1 - QQ(1,2)*x**QQ(4,5) - x**QQ(16,15) - QQ(1,2)*x**QQ(4,3) - \ + x**QQ(7,5) + QQ(1,24)*x**QQ(8,5) - x**QQ(5,3) + QQ(1,6)*x**QQ(28,15) + + r = rs_asin(p, x, 2) + assert r == x**QQ(9,5)/2 + x**QQ(26,15)/2 + x**QQ(22,15)/2 + \ + x**QQ(6,5)/6 + x + x**QQ(2,3) + x**QQ(2,5) + + r = rs_cot(p, x, 1) + assert r == -x**QQ(14,15) + x**QQ(4,5) - 3*x**QQ(11,15) + \ + 2*x**QQ(2,3)/3 + 2*x**QQ(7,15) - 4*x**QQ(2,5)/3 - x**QQ(1,5) + \ + x**QQ(2,15) - x**QQ(-2,15) + x**QQ(-2,5) + + r = rs_cos_sin(p, x, 2) + assert r[0] == x**QQ(28,15)/6 - x**QQ(5,3) + x**QQ(8,5)/24 - x**QQ(7,5) - \ + x**QQ(4,3)/2 - x**QQ(16,15) - x**QQ(4,5)/2 + 1 + assert r[1] == -x**QQ(9,5)/2 - x**QQ(26,15)/2 - x**QQ(22,15)/2 - \ + x**QQ(6,5)/6 + x + x**QQ(2,3) + x**QQ(2,5) + + r = rs_atanh(p, x, 2) + assert r == x**QQ(9,5) + x**QQ(26,15) + x**QQ(22,15) + x**QQ(6,5)/3 + x + \ + x**QQ(2,3) + x**QQ(2,5) + + r = rs_asinh(p, x, 2) + assert r == x**QQ(2,5) + x**QQ(2,3) + x - QQ(1,6)*x**QQ(6,5) - QQ(1,2)*x**\ + QQ(22,15) - QQ(1,2)*x**QQ(26,15) - QQ(1,2)*x**QQ(9,5) + + r = rs_cosh(p, x, 2) + assert r == x**QQ(28,15)/6 + x**QQ(5,3) + x**QQ(8,5)/24 + x**QQ(7,5) + \ + x**QQ(4,3)/2 + x**QQ(16,15) + x**QQ(4,5)/2 + 1 + + r = rs_sinh(p, x, 2) + assert r == x**QQ(9,5)/2 + x**QQ(26,15)/2 + x**QQ(22,15)/2 + \ + x**QQ(6,5)/6 + x + x**QQ(2,3) + x**QQ(2,5) + + r = rs_cosh_sinh(p, x, 2) + assert r[0] == x**QQ(28,15)/6 + x**QQ(5,3) + x**QQ(8,5)/24 + x**QQ(7,5) + \ + x**QQ(4,3)/2 + x**QQ(16,15) + x**QQ(4,5)/2 + 1 + assert r[1] == x**QQ(9,5)/2 + x**QQ(26,15)/2 + x**QQ(22,15)/2 + \ + x**QQ(6,5)/6 + x + x**QQ(2,3) + x**QQ(2,5) + + r = rs_tanh(p, x, 2) + assert r == -x**QQ(9,5) - x**QQ(26,15) - x**QQ(22,15) - x**QQ(6,5)/3 + \ + x + x**QQ(2,3) + x**QQ(2,5) + + +def test_puiseux_algebraic(): # https://github.com/sympy/sympy/issues/24395 + + K = QQ.algebraic_field(sqrt(2)) + sqrt2 = K.from_sympy(sqrt(2)) + x, y = symbols('x, y') + R, xr, yr = puiseux_ring([x, y], K) + p = (1+sqrt2)*xr**QQ(1,2) + (1-sqrt2)*yr**QQ(2,3) + + assert p.to_dict() == {(QQ(1,2),QQ(0)):1+sqrt2, (QQ(0),QQ(2,3)):1-sqrt2} + assert p.as_expr() == (1 + sqrt(2))*x**(S(1)/2) + (1 - sqrt(2))*y**(S(2)/3) + + +def test1(): + R, x = puiseux_ring('x', QQ) + r = rs_sin(x, x, 15)*x**(-5) + assert r == x**8/6227020800 - x**6/39916800 + x**4/362880 - x**2/5040 + \ + QQ(1,120) - x**-2/6 + x**-4 + + p = rs_sin(x, x, 10) + r = rs_nth_root(p, 2, x, 10) + assert r == -67*x**QQ(17,2)/29030400 - x**QQ(13,2)/24192 + \ + x**QQ(9,2)/1440 - x**QQ(5,2)/12 + x**QQ(1,2) + + p = rs_sin(x, x, 10) + r = rs_nth_root(p, 7, x, 10) + r = rs_pow(r, 5, x, 10) + assert r == -97*x**QQ(61,7)/124467840 - x**QQ(47,7)/16464 + \ + 11*x**QQ(33,7)/3528 - 5*x**QQ(19,7)/42 + x**QQ(5,7) + + r = rs_exp(x**QQ(1,2), x, 10) + assert r == x**QQ(19,2)/121645100408832000 + x**9/6402373705728000 + \ + x**QQ(17,2)/355687428096000 + x**8/20922789888000 + \ + x**QQ(15,2)/1307674368000 + x**7/87178291200 + \ + x**QQ(13,2)/6227020800 + x**6/479001600 + x**QQ(11,2)/39916800 + \ + x**5/3628800 + x**QQ(9,2)/362880 + x**4/40320 + x**QQ(7,2)/5040 + \ + x**3/720 + x**QQ(5,2)/120 + x**2/24 + x**QQ(3,2)/6 + x/2 + \ + x**QQ(1,2) + 1 + + +def test_puiseux2(): + R, y = ring('y', QQ) + S, x = puiseux_ring('x', R.to_domain()) + + p = x + x**QQ(1,5)*y + r = rs_atan(p, x, 3) + assert r == (y**13/13 + y**8 + 2*y**3)*x**QQ(13,5) - (y**11/11 + y**6 + + y)*x**QQ(11,5) + (y**9/9 + y**4)*x**QQ(9,5) - (y**7/7 + + y**2)*x**QQ(7,5) + (y**5/5 + 1)*x - y**3*x**QQ(3,5)/3 + y*x**QQ(1,5) + + +@slow +def test_rs_series(): + x, a, b, c = symbols('x, a, b, c') + + assert rs_series(a, a, 5).as_expr() == a + assert rs_series(sin(a), a, 5).as_expr() == (sin(a).series(a, 0, + 5)).removeO() + assert rs_series(sin(a) + cos(a), a, 5).as_expr() == ((sin(a) + + cos(a)).series(a, 0, 5)).removeO() + assert rs_series(sin(a)*cos(a), a, 5).as_expr() == ((sin(a)* + cos(a)).series(a, 0, 5)).removeO() + + p = (sin(a) - a)*(cos(a**2) + a**4/2) + assert expand(rs_series(p, a, 10).as_expr()) == expand(p.series(a, 0, + 10).removeO()) + + p = sin(a**2/2 + a/3) + cos(a/5)*sin(a/2)**3 + assert expand(rs_series(p, a, 5).as_expr()) == expand(p.series(a, 0, + 5).removeO()) + + p = sin(x**2 + a)*(cos(x**3 - 1) - a - a**2) + assert expand(rs_series(p, a, 5).as_expr()) == expand(p.series(a, 0, + 5).removeO()) + + p = sin(a**2 - a/3 + 2)**5*exp(a**3 - a/2) + assert expand(rs_series(p, a, 10).as_expr()) == expand(p.series(a, 0, + 10).removeO()) + + p = sin(a + b + c) + assert expand(rs_series(p, a, 5).as_expr()) == expand(p.series(a, 0, + 5).removeO()) + + p = tan(sin(a**2 + 4) + b + c) + assert expand(rs_series(p, a, 6).as_expr()) == expand(p.series(a, 0, + 6).removeO()) + + p = a**QQ(2,5) + a**QQ(2,3) + a + + r = rs_series(tan(p), a, 2) + assert r.as_expr() == a**QQ(9,5) + a**QQ(26,15) + a**QQ(22,15) + a**QQ(6,5)/3 + \ + a + a**QQ(2,3) + a**QQ(2,5) + + r = rs_series(exp(p), a, 1) + assert r.as_expr() == a**QQ(4,5)/2 + a**QQ(2,3) + a**QQ(2,5) + 1 + + r = rs_series(sin(p), a, 2) + assert r.as_expr() == -a**QQ(9,5)/2 - a**QQ(26,15)/2 - a**QQ(22,15)/2 - \ + a**QQ(6,5)/6 + a + a**QQ(2,3) + a**QQ(2,5) + + r = rs_series(cos(p), a, 2) + assert r.as_expr() == a**QQ(28,15)/6 - a**QQ(5,3) + a**QQ(8,5)/24 - a**QQ(7,5) - \ + a**QQ(4,3)/2 - a**QQ(16,15) - a**QQ(4,5)/2 + 1 + + assert rs_series(sin(a)/7, a, 5).as_expr() == (sin(a)/7).series(a, 0, + 5).removeO() + + +def test_rs_series_ConstantInExpr(): + x, a = symbols('x a') + assert rs_series(log(1 + x), x, 5).as_expr() == -x**4/4 + x**3/3 - \ + x**2/2 + x + assert rs_series(log(1 + 4*x), x, 5).as_expr() == -64*x**4 + 64*x**3/3 - \ + 8*x**2 + 4*x + assert rs_series(log(1 + x + x**2), x, 10).as_expr() == -2*x**9/9 + \ + x**8/8 + x**7/7 - x**6/3 + x**5/5 + x**4/4 - 2*x**3/3 + x**2/2 + x + assert rs_series(log(1 + x*a**2), x, 7).as_expr() == -x**6*a**12/6 + \ + x**5*a**10/5 - x**4*a**8/4 + x**3*a**6/3 - x**2*a**4/2 + x*a**2 + + assert rs_series(atan(1 + x), x, 9).as_expr() == -x**7/112 + x**6/48 - x**5/40 \ + + x**3/12 - x**2/4 + x/2 + pi/4 + assert rs_series(atan(1 + x + x**2),x, 9).as_expr() == -15*x**7/112 - x**6/48 + \ + 9*x**5/40 - 5*x**3/12 + x**2/4 + x/2 + pi/4 + assert rs_series(atan(1 + x * a), x, 9).as_expr() == -a**7*x**7/112 + a**6*x**6/48 \ + - a**5*x**5/40 + a**3*x**3/12 - a**2*x**2/4 + a*x/2 + pi/4 + + assert rs_series(tanh(1 + x), x, 5).as_expr() == -5*x**4*tanh(1)**3/3 + x**4* \ + tanh(1)**5 + 2*x**4*tanh(1)/3 - x**3*tanh(1)**4 - x**3/3 + 4*x**3*tanh(1) \ + **2/3 - x**2*tanh(1) + x**2*tanh(1)**3 - x*tanh(1)**2 + x + tanh(1) + assert rs_series(tanh(1 + x * a), x, 3).as_expr() == -a**2*x**2*tanh(1) + a**2*x** \ + 2*tanh(1)**3 - a*x*tanh(1)**2 + a*x + tanh(1) + + assert rs_series(sinh(1 + x), x, 5).as_expr() == x**4*sinh(1)/24 + x**3*cosh(1)/6 + \ + x**2*sinh(1)/2 + x*cosh(1) + sinh(1) + assert rs_series(sinh(1 + x * a), x, 5).as_expr() == a**4*x**4*sinh(1)/24 + \ + a**3*x**3*cosh(1)/6 + a**2*x**2*sinh(1)/2 + a*x*cosh(1) + sinh(1) + + assert rs_series(cosh(1 + x), x, 5).as_expr() == x**4*cosh(1)/24 + x**3*sinh(1)/6 + \ + x**2*cosh(1)/2 + x*sinh(1) + cosh(1) + assert rs_series(cosh(1 + x * a), x, 5).as_expr() == a**4*x**4*cosh(1)/24 + \ + a**3*x**3*sinh(1)/6 + a**2*x**2*cosh(1)/2 + a*x*sinh(1) + cosh(1) + + +def test_issue(): + # https://github.com/sympy/sympy/issues/10191 + # https://github.com/sympy/sympy/issues/19543 + + a, b = symbols('a b') + assert rs_series(sin(a**QQ(3,7))*exp(a + b**QQ(6,7)), a,2).as_expr() == \ + a**QQ(10,7)*exp(b**QQ(6,7)) - a**QQ(9,7)*exp(b**QQ(6,7))/6 + a**QQ(3,7)*exp(b**QQ(6,7)) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_rings.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_rings.py new file mode 100644 index 0000000000000000000000000000000000000000..455cc319908d0173737531b339e22def8e4a26fc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_rings.py @@ -0,0 +1,1591 @@ +"""Test sparse polynomials. """ + +from functools import reduce +from operator import add, mul + +from sympy.polys.rings import ring, xring, sring, PolyRing, PolyElement +from sympy.polys.fields import field, FracField +from sympy.polys.densebasic import ninf +from sympy.polys.domains import ZZ, QQ, RR, FF, EX +from sympy.polys.orderings import lex, grlex +from sympy.polys.polyerrors import GeneratorsError, \ + ExactQuotientFailed, MultivariatePolynomialError, CoercionFailed + +from sympy.testing.pytest import raises +from sympy.core import Symbol, symbols +from sympy.core.singleton import S +from sympy.core.numbers import pi +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt + +def test_PolyRing___init__(): + x, y, z, t = map(Symbol, "xyzt") + + assert len(PolyRing("x,y,z", ZZ, lex).gens) == 3 + assert len(PolyRing(x, ZZ, lex).gens) == 1 + assert len(PolyRing(("x", "y", "z"), ZZ, lex).gens) == 3 + assert len(PolyRing((x, y, z), ZZ, lex).gens) == 3 + assert len(PolyRing("", ZZ, lex).gens) == 0 + assert len(PolyRing([], ZZ, lex).gens) == 0 + + raises(GeneratorsError, lambda: PolyRing(0, ZZ, lex)) + + assert PolyRing("x", ZZ[t], lex).domain == ZZ[t] + assert PolyRing("x", 'ZZ[t]', lex).domain == ZZ[t] + assert PolyRing("x", PolyRing("t", ZZ, lex), lex).domain == ZZ[t] + + raises(GeneratorsError, lambda: PolyRing("x", PolyRing("x", ZZ, lex), lex)) + + _lex = Symbol("lex") + assert PolyRing("x", ZZ, lex).order == lex + assert PolyRing("x", ZZ, _lex).order == lex + assert PolyRing("x", ZZ, 'lex').order == lex + + R1 = PolyRing("x,y", ZZ, lex) + R2 = PolyRing("x,y", ZZ, lex) + R3 = PolyRing("x,y,z", ZZ, lex) + + assert R1.x == R1.gens[0] + assert R1.y == R1.gens[1] + assert R1.x == R2.x + assert R1.y == R2.y + assert R1.x != R3.x + assert R1.y != R3.y + +def test_PolyRing___hash__(): + R, x, y, z = ring("x,y,z", QQ) + assert hash(R) + +def test_PolyRing___eq__(): + assert ring("x,y,z", QQ)[0] == ring("x,y,z", QQ)[0] + assert ring("x,y,z", QQ)[0] != ring("x,y,z", ZZ)[0] + assert ring("x,y,z", ZZ)[0] != ring("x,y,z", QQ)[0] + assert ring("x,y,z", QQ)[0] != ring("x,y", QQ)[0] + assert ring("x,y", QQ)[0] != ring("x,y,z", QQ)[0] + +def test_PolyRing_ring_new(): + R, x, y, z = ring("x,y,z", QQ) + + assert R.ring_new(7) == R(7) + assert R.ring_new(7*x*y*z) == 7*x*y*z + + f = x**2 + 2*x*y + 3*x + 4*z**2 + 5*z + 6 + + assert R.ring_new([[[1]], [[2], [3]], [[4, 5, 6]]]) == f + assert R.ring_new({(2, 0, 0): 1, (1, 1, 0): 2, (1, 0, 0): 3, (0, 0, 2): 4, (0, 0, 1): 5, (0, 0, 0): 6}) == f + assert R.ring_new([((2, 0, 0), 1), ((1, 1, 0), 2), ((1, 0, 0), 3), ((0, 0, 2), 4), ((0, 0, 1), 5), ((0, 0, 0), 6)]) == f + + R, = ring("", QQ) + assert R.ring_new([((), 7)]) == R(7) + +def test_PolyRing_drop(): + R, x,y,z = ring("x,y,z", ZZ) + + assert R.drop(x) == PolyRing("y,z", ZZ, lex) + assert R.drop(y) == PolyRing("x,z", ZZ, lex) + assert R.drop(z) == PolyRing("x,y", ZZ, lex) + + assert R.drop(0) == PolyRing("y,z", ZZ, lex) + assert R.drop(0).drop(0) == PolyRing("z", ZZ, lex) + assert R.drop(0).drop(0).drop(0) == ZZ + + assert R.drop(1) == PolyRing("x,z", ZZ, lex) + + assert R.drop(2) == PolyRing("x,y", ZZ, lex) + assert R.drop(2).drop(1) == PolyRing("x", ZZ, lex) + assert R.drop(2).drop(1).drop(0) == ZZ + + raises(ValueError, lambda: R.drop(3)) + raises(ValueError, lambda: R.drop(x).drop(y)) + +def test_PolyRing___getitem__(): + R, x,y,z = ring("x,y,z", ZZ) + + assert R[0:] == PolyRing("x,y,z", ZZ, lex) + assert R[1:] == PolyRing("y,z", ZZ, lex) + assert R[2:] == PolyRing("z", ZZ, lex) + assert R[3:] == ZZ + +def test_PolyRing_is_(): + R = PolyRing("x", QQ, lex) + + assert R.is_univariate is True + assert R.is_multivariate is False + + R = PolyRing("x,y,z", QQ, lex) + + assert R.is_univariate is False + assert R.is_multivariate is True + + R = PolyRing("", QQ, lex) + + assert R.is_univariate is False + assert R.is_multivariate is False + +def test_PolyRing_add(): + R, x = ring("x", ZZ) + F = [ x**2 + 2*i + 3 for i in range(4) ] + + assert R.add(F) == reduce(add, F) == 4*x**2 + 24 + + R, = ring("", ZZ) + + assert R.add([2, 5, 7]) == 14 + +def test_PolyRing_mul(): + R, x = ring("x", ZZ) + F = [ x**2 + 2*i + 3 for i in range(4) ] + + assert R.mul(F) == reduce(mul, F) == x**8 + 24*x**6 + 206*x**4 + 744*x**2 + 945 + + R, = ring("", ZZ) + + assert R.mul([2, 3, 5]) == 30 + +def test_PolyRing_symmetric_poly(): + R, x, y, z, t = ring("x,y,z,t", ZZ) + + raises(ValueError, lambda: R.symmetric_poly(-1)) + raises(ValueError, lambda: R.symmetric_poly(5)) + + assert R.symmetric_poly(0) == R.one + assert R.symmetric_poly(1) == x + y + z + t + assert R.symmetric_poly(2) == x*y + x*z + x*t + y*z + y*t + z*t + assert R.symmetric_poly(3) == x*y*z + x*y*t + x*z*t + y*z*t + assert R.symmetric_poly(4) == x*y*z*t + +def test_sring(): + x, y, z, t = symbols("x,y,z,t") + + R = PolyRing("x,y,z", ZZ, lex) + assert sring(x + 2*y + 3*z) == (R, R.x + 2*R.y + 3*R.z) + + R = PolyRing("x,y,z", QQ, lex) + assert sring(x + 2*y + z/3) == (R, R.x + 2*R.y + R.z/3) + assert sring([x, 2*y, z/3]) == (R, [R.x, 2*R.y, R.z/3]) + + Rt = PolyRing("t", ZZ, lex) + R = PolyRing("x,y,z", Rt, lex) + assert sring(x + 2*t*y + 3*t**2*z, x, y, z) == (R, R.x + 2*Rt.t*R.y + 3*Rt.t**2*R.z) + + Rt = PolyRing("t", QQ, lex) + R = PolyRing("x,y,z", Rt, lex) + assert sring(x + t*y/2 + t**2*z/3, x, y, z) == (R, R.x + Rt.t*R.y/2 + Rt.t**2*R.z/3) + + Rt = FracField("t", ZZ, lex) + R = PolyRing("x,y,z", Rt, lex) + assert sring(x + 2*y/t + t**2*z/3, x, y, z) == (R, R.x + 2*R.y/Rt.t + Rt.t**2*R.z/3) + + r = sqrt(2) - sqrt(3) + R, a = sring(r, extension=True) + assert R.domain == QQ.algebraic_field(sqrt(2) + sqrt(3)) + assert R.gens == () + assert a == R.domain.from_sympy(r) + +def test_PolyElement___hash__(): + R, x, y, z = ring("x,y,z", QQ) + assert hash(x*y*z) + +def test_PolyElement___eq__(): + R, x, y = ring("x,y", ZZ, lex) + + assert ((x*y + 5*x*y) == 6) == False + assert ((x*y + 5*x*y) == 6*x*y) == True + assert (6 == (x*y + 5*x*y)) == False + assert (6*x*y == (x*y + 5*x*y)) == True + + assert ((x*y - x*y) == 0) == True + assert (0 == (x*y - x*y)) == True + + assert ((x*y - x*y) == 1) == False + assert (1 == (x*y - x*y)) == False + + assert ((x*y - x*y) == 1) == False + assert (1 == (x*y - x*y)) == False + + assert ((x*y + 5*x*y) != 6) == True + assert ((x*y + 5*x*y) != 6*x*y) == False + assert (6 != (x*y + 5*x*y)) == True + assert (6*x*y != (x*y + 5*x*y)) == False + + assert ((x*y - x*y) != 0) == False + assert (0 != (x*y - x*y)) == False + + assert ((x*y - x*y) != 1) == True + assert (1 != (x*y - x*y)) == True + + assert R.one == QQ(1, 1) == R.one + assert R.one == 1 == R.one + + Rt, t = ring("t", ZZ) + R, x, y = ring("x,y", Rt) + + assert (t**3*x/x == t**3) == True + assert (t**3*x/x == t**4) == False + +def test_PolyElement__lt_le_gt_ge__(): + R, x, y = ring("x,y", ZZ) + + assert R(1) < x < x**2 < x**3 + assert R(1) <= x <= x**2 <= x**3 + + assert x**3 > x**2 > x > R(1) + assert x**3 >= x**2 >= x >= R(1) + +def test_PolyElement__str__(): + x, y = symbols('x, y') + + for dom in [ZZ, QQ, ZZ[x], ZZ[x,y], ZZ[x][y]]: + R, t = ring('t', dom) + assert str(2*t**2 + 1) == '2*t**2 + 1' + + for dom in [EX, EX[x]]: + R, t = ring('t', dom) + assert str(2*t**2 + 1) == 'EX(2)*t**2 + EX(1)' + +def test_PolyElement_copy(): + R, x, y, z = ring("x,y,z", ZZ) + + f = x*y + 3*z + g = f.copy() + + assert f == g + g[(1, 1, 1)] = 7 + assert f != g + +def test_PolyElement_as_expr(): + R, x, y, z = ring("x,y,z", ZZ) + f = 3*x**2*y - x*y*z + 7*z**3 + 1 + + X, Y, Z = R.symbols + g = 3*X**2*Y - X*Y*Z + 7*Z**3 + 1 + + assert f != g + assert f.as_expr() == g + + U, V, W = symbols("u,v,w") + g = 3*U**2*V - U*V*W + 7*W**3 + 1 + + assert f != g + assert f.as_expr(U, V, W) == g + + raises(ValueError, lambda: f.as_expr(X)) + + R, = ring("", ZZ) + assert R(3).as_expr() == 3 + +def test_PolyElement_from_expr(): + x, y, z = symbols("x,y,z") + R, X, Y, Z = ring((x, y, z), ZZ) + + f = R.from_expr(1) + assert f == 1 and R.is_element(f) + + f = R.from_expr(x) + assert f == X and R.is_element(f) + + f = R.from_expr(x*y*z) + assert f == X*Y*Z and R.is_element(f) + + f = R.from_expr(x*y*z + x*y + x) + assert f == X*Y*Z + X*Y + X and R.is_element(f) + + f = R.from_expr(x**3*y*z + x**2*y**7 + 1) + assert f == X**3*Y*Z + X**2*Y**7 + 1 and R.is_element(f) + + r, F = sring([exp(2)]) + f = r.from_expr(exp(2)) + assert f == F[0] and r.is_element(f) + + raises(ValueError, lambda: R.from_expr(1/x)) + raises(ValueError, lambda: R.from_expr(2**x)) + raises(ValueError, lambda: R.from_expr(7*x + sqrt(2))) + + R, = ring("", ZZ) + f = R.from_expr(1) + assert f == 1 and R.is_element(f) + +def test_PolyElement_degree(): + R, x,y,z = ring("x,y,z", ZZ) + + assert ninf == float('-inf') + + assert R(0).degree() is ninf + assert R(1).degree() == 0 + assert (x + 1).degree() == 1 + assert (2*y**3 + z).degree() == 0 + assert (x*y**3 + z).degree() == 1 + assert (x**5*y**3 + z).degree() == 5 + + assert R(0).degree(x) is ninf + assert R(1).degree(x) == 0 + assert (x + 1).degree(x) == 1 + assert (2*y**3 + z).degree(x) == 0 + assert (x*y**3 + z).degree(x) == 1 + assert (7*x**5*y**3 + z).degree(x) == 5 + + assert R(0).degree(y) is ninf + assert R(1).degree(y) == 0 + assert (x + 1).degree(y) == 0 + assert (2*y**3 + z).degree(y) == 3 + assert (x*y**3 + z).degree(y) == 3 + assert (7*x**5*y**3 + z).degree(y) == 3 + + assert R(0).degree(z) is ninf + assert R(1).degree(z) == 0 + assert (x + 1).degree(z) == 0 + assert (2*y**3 + z).degree(z) == 1 + assert (x*y**3 + z).degree(z) == 1 + assert (7*x**5*y**3 + z).degree(z) == 1 + + R, = ring("", ZZ) + assert R(0).degree() is ninf + assert R(1).degree() == 0 + +def test_PolyElement_tail_degree(): + R, x,y,z = ring("x,y,z", ZZ) + + assert R(0).tail_degree() is ninf + assert R(1).tail_degree() == 0 + assert (x + 1).tail_degree() == 0 + assert (2*y**3 + x**3*z).tail_degree() == 0 + assert (x*y**3 + x**3*z).tail_degree() == 1 + assert (x**5*y**3 + x**3*z).tail_degree() == 3 + + assert R(0).tail_degree(x) is ninf + assert R(1).tail_degree(x) == 0 + assert (x + 1).tail_degree(x) == 0 + assert (2*y**3 + x**3*z).tail_degree(x) == 0 + assert (x*y**3 + x**3*z).tail_degree(x) == 1 + assert (7*x**5*y**3 + x**3*z).tail_degree(x) == 3 + + assert R(0).tail_degree(y) is ninf + assert R(1).tail_degree(y) == 0 + assert (x + 1).tail_degree(y) == 0 + assert (2*y**3 + x**3*z).tail_degree(y) == 0 + assert (x*y**3 + x**3*z).tail_degree(y) == 0 + assert (7*x**5*y**3 + x**3*z).tail_degree(y) == 0 + + assert R(0).tail_degree(z) is ninf + assert R(1).tail_degree(z) == 0 + assert (x + 1).tail_degree(z) == 0 + assert (2*y**3 + x**3*z).tail_degree(z) == 0 + assert (x*y**3 + x**3*z).tail_degree(z) == 0 + assert (7*x**5*y**3 + x**3*z).tail_degree(z) == 0 + + R, = ring("", ZZ) + assert R(0).tail_degree() is ninf + assert R(1).tail_degree() == 0 + +def test_PolyElement_degrees(): + R, x,y,z = ring("x,y,z", ZZ) + + assert R(0).degrees() == (ninf, ninf, ninf) + assert R(1).degrees() == (0, 0, 0) + assert (x**2*y + x**3*z**2).degrees() == (3, 1, 2) + +def test_PolyElement_tail_degrees(): + R, x,y,z = ring("x,y,z", ZZ) + + assert R(0).tail_degrees() == (ninf, ninf, ninf) + assert R(1).tail_degrees() == (0, 0, 0) + assert (x**2*y + x**3*z**2).tail_degrees() == (2, 0, 0) + +def test_PolyElement_coeff(): + R, x, y, z = ring("x,y,z", ZZ, lex) + f = 3*x**2*y - x*y*z + 7*z**3 + 23 + + assert f.coeff(1) == 23 + raises(ValueError, lambda: f.coeff(3)) + + assert f.coeff(x) == 0 + assert f.coeff(y) == 0 + assert f.coeff(z) == 0 + + assert f.coeff(x**2*y) == 3 + assert f.coeff(x*y*z) == -1 + assert f.coeff(z**3) == 7 + + raises(ValueError, lambda: f.coeff(3*x**2*y)) + raises(ValueError, lambda: f.coeff(-x*y*z)) + raises(ValueError, lambda: f.coeff(7*z**3)) + + R, = ring("", ZZ) + assert R(3).coeff(1) == 3 + +def test_PolyElement_LC(): + R, x, y = ring("x,y", QQ, lex) + assert R(0).LC == QQ(0) + assert (QQ(1,2)*x).LC == QQ(1, 2) + assert (QQ(1,4)*x*y + QQ(1,2)*x).LC == QQ(1, 4) + +def test_PolyElement_LM(): + R, x, y = ring("x,y", QQ, lex) + assert R(0).LM == (0, 0) + assert (QQ(1,2)*x).LM == (1, 0) + assert (QQ(1,4)*x*y + QQ(1,2)*x).LM == (1, 1) + +def test_PolyElement_LT(): + R, x, y = ring("x,y", QQ, lex) + assert R(0).LT == ((0, 0), QQ(0)) + assert (QQ(1,2)*x).LT == ((1, 0), QQ(1, 2)) + assert (QQ(1,4)*x*y + QQ(1,2)*x).LT == ((1, 1), QQ(1, 4)) + + R, = ring("", ZZ) + assert R(0).LT == ((), 0) + assert R(1).LT == ((), 1) + +def test_PolyElement_leading_monom(): + R, x, y = ring("x,y", QQ, lex) + assert R(0).leading_monom() == 0 + assert (QQ(1,2)*x).leading_monom() == x + assert (QQ(1,4)*x*y + QQ(1,2)*x).leading_monom() == x*y + +def test_PolyElement_leading_term(): + R, x, y = ring("x,y", QQ, lex) + assert R(0).leading_term() == 0 + assert (QQ(1,2)*x).leading_term() == QQ(1,2)*x + assert (QQ(1,4)*x*y + QQ(1,2)*x).leading_term() == QQ(1,4)*x*y + +def test_PolyElement_terms(): + R, x,y,z = ring("x,y,z", QQ) + terms = (x**2/3 + y**3/4 + z**4/5).terms() + assert terms == [((2,0,0), QQ(1,3)), ((0,3,0), QQ(1,4)), ((0,0,4), QQ(1,5))] + + R, x,y = ring("x,y", ZZ, lex) + f = x*y**7 + 2*x**2*y**3 + + assert f.terms() == f.terms(lex) == f.terms('lex') == [((2, 3), 2), ((1, 7), 1)] + assert f.terms(grlex) == f.terms('grlex') == [((1, 7), 1), ((2, 3), 2)] + + R, x,y = ring("x,y", ZZ, grlex) + f = x*y**7 + 2*x**2*y**3 + + assert f.terms() == f.terms(grlex) == f.terms('grlex') == [((1, 7), 1), ((2, 3), 2)] + assert f.terms(lex) == f.terms('lex') == [((2, 3), 2), ((1, 7), 1)] + + R, = ring("", ZZ) + assert R(3).terms() == [((), 3)] + +def test_PolyElement_monoms(): + R, x,y,z = ring("x,y,z", QQ) + monoms = (x**2/3 + y**3/4 + z**4/5).monoms() + assert monoms == [(2,0,0), (0,3,0), (0,0,4)] + + R, x,y = ring("x,y", ZZ, lex) + f = x*y**7 + 2*x**2*y**3 + + assert f.monoms() == f.monoms(lex) == f.monoms('lex') == [(2, 3), (1, 7)] + assert f.monoms(grlex) == f.monoms('grlex') == [(1, 7), (2, 3)] + + R, x,y = ring("x,y", ZZ, grlex) + f = x*y**7 + 2*x**2*y**3 + + assert f.monoms() == f.monoms(grlex) == f.monoms('grlex') == [(1, 7), (2, 3)] + assert f.monoms(lex) == f.monoms('lex') == [(2, 3), (1, 7)] + +def test_PolyElement_coeffs(): + R, x,y,z = ring("x,y,z", QQ) + coeffs = (x**2/3 + y**3/4 + z**4/5).coeffs() + assert coeffs == [QQ(1,3), QQ(1,4), QQ(1,5)] + + R, x,y = ring("x,y", ZZ, lex) + f = x*y**7 + 2*x**2*y**3 + + assert f.coeffs() == f.coeffs(lex) == f.coeffs('lex') == [2, 1] + assert f.coeffs(grlex) == f.coeffs('grlex') == [1, 2] + + R, x,y = ring("x,y", ZZ, grlex) + f = x*y**7 + 2*x**2*y**3 + + assert f.coeffs() == f.coeffs(grlex) == f.coeffs('grlex') == [1, 2] + assert f.coeffs(lex) == f.coeffs('lex') == [2, 1] + +def test_PolyElement___add__(): + Rt, t = ring("t", ZZ) + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + + assert dict(x + 3*y) == {(1, 0, 0): 1, (0, 1, 0): 3} + + assert dict(u + x) == dict(x + u) == {(1, 0, 0): 1, (0, 0, 0): u} + assert dict(u + x*y) == dict(x*y + u) == {(1, 1, 0): 1, (0, 0, 0): u} + assert dict(u + x*y + z) == dict(x*y + z + u) == {(1, 1, 0): 1, (0, 0, 1): 1, (0, 0, 0): u} + + assert dict(u*x + x) == dict(x + u*x) == {(1, 0, 0): u + 1} + assert dict(u*x + x*y) == dict(x*y + u*x) == {(1, 1, 0): 1, (1, 0, 0): u} + assert dict(u*x + x*y + z) == dict(x*y + z + u*x) == {(1, 1, 0): 1, (0, 0, 1): 1, (1, 0, 0): u} + + raises(TypeError, lambda: t + x) + raises(TypeError, lambda: x + t) + raises(TypeError, lambda: t + u) + raises(TypeError, lambda: u + t) + + Fuv, u,v = field("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Fuv) + + assert dict(u + x) == dict(x + u) == {(1, 0, 0): 1, (0, 0, 0): u} + + Rxyz, x,y,z = ring("x,y,z", EX) + + assert dict(EX(pi) + x*y*z) == dict(x*y*z + EX(pi)) == {(1, 1, 1): EX(1), (0, 0, 0): EX(pi)} + +def test_PolyElement___sub__(): + Rt, t = ring("t", ZZ) + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + + assert dict(x - 3*y) == {(1, 0, 0): 1, (0, 1, 0): -3} + + assert dict(-u + x) == dict(x - u) == {(1, 0, 0): 1, (0, 0, 0): -u} + assert dict(-u + x*y) == dict(x*y - u) == {(1, 1, 0): 1, (0, 0, 0): -u} + assert dict(-u + x*y + z) == dict(x*y + z - u) == {(1, 1, 0): 1, (0, 0, 1): 1, (0, 0, 0): -u} + + assert dict(-u*x + x) == dict(x - u*x) == {(1, 0, 0): -u + 1} + assert dict(-u*x + x*y) == dict(x*y - u*x) == {(1, 1, 0): 1, (1, 0, 0): -u} + assert dict(-u*x + x*y + z) == dict(x*y + z - u*x) == {(1, 1, 0): 1, (0, 0, 1): 1, (1, 0, 0): -u} + + raises(TypeError, lambda: t - x) + raises(TypeError, lambda: x - t) + raises(TypeError, lambda: t - u) + raises(TypeError, lambda: u - t) + + Fuv, u,v = field("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Fuv) + + assert dict(-u + x) == dict(x - u) == {(1, 0, 0): 1, (0, 0, 0): -u} + + Rxyz, x,y,z = ring("x,y,z", EX) + + assert dict(-EX(pi) + x*y*z) == dict(x*y*z - EX(pi)) == {(1, 1, 1): EX(1), (0, 0, 0): -EX(pi)} + +def test_PolyElement___mul__(): + Rt, t = ring("t", ZZ) + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + + assert dict(u*x) == dict(x*u) == {(1, 0, 0): u} + + assert dict(2*u*x + z) == dict(x*2*u + z) == {(1, 0, 0): 2*u, (0, 0, 1): 1} + assert dict(u*2*x + z) == dict(2*x*u + z) == {(1, 0, 0): 2*u, (0, 0, 1): 1} + assert dict(2*u*x + z) == dict(x*2*u + z) == {(1, 0, 0): 2*u, (0, 0, 1): 1} + assert dict(u*x*2 + z) == dict(x*u*2 + z) == {(1, 0, 0): 2*u, (0, 0, 1): 1} + + assert dict(2*u*x*y + z) == dict(x*y*2*u + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + assert dict(u*2*x*y + z) == dict(2*x*y*u + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + assert dict(2*u*x*y + z) == dict(x*y*2*u + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + assert dict(u*x*y*2 + z) == dict(x*y*u*2 + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + + assert dict(2*u*y*x + z) == dict(y*x*2*u + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + assert dict(u*2*y*x + z) == dict(2*y*x*u + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + assert dict(2*u*y*x + z) == dict(y*x*2*u + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + assert dict(u*y*x*2 + z) == dict(y*x*u*2 + z) == {(1, 1, 0): 2*u, (0, 0, 1): 1} + + assert dict(3*u*(x + y) + z) == dict((x + y)*3*u + z) == {(1, 0, 0): 3*u, (0, 1, 0): 3*u, (0, 0, 1): 1} + + raises(TypeError, lambda: t*x + z) + raises(TypeError, lambda: x*t + z) + raises(TypeError, lambda: t*u + z) + raises(TypeError, lambda: u*t + z) + + Fuv, u,v = field("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Fuv) + + assert dict(u*x) == dict(x*u) == {(1, 0, 0): u} + + Rxyz, x,y,z = ring("x,y,z", EX) + + assert dict(EX(pi)*x*y*z) == dict(x*y*z*EX(pi)) == {(1, 1, 1): EX(pi)} + +def test_PolyElement___truediv__(): + R, x,y,z = ring("x,y,z", ZZ) + + assert (2*x**2 - 4)/2 == x**2 - 2 + assert (2*x**2 - 3)/2 == x**2 + + assert (x**2 - 1).quo(x) == x + assert (x**2 - x).quo(x) == x - 1 + + raises(ExactQuotientFailed, lambda: (x**2 - 1)/x) + assert (x**2 - x)/x == x - 1 + raises(ExactQuotientFailed, lambda: (x**2 - 1)/(2*x)) + + assert (x**2 - 1).quo(2*x) == 0 + assert (x**2 - x)/(x - 1) == (x**2 - x).quo(x - 1) == x + + + R, x,y,z = ring("x,y,z", ZZ) + assert len((x**2/3 + y**3/4 + z**4/5).terms()) == 0 + + R, x,y,z = ring("x,y,z", QQ) + assert len((x**2/3 + y**3/4 + z**4/5).terms()) == 3 + + Rt, t = ring("t", ZZ) + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + + assert dict((u**2*x + u)/u) == {(1, 0, 0): u, (0, 0, 0): 1} + raises(ExactQuotientFailed, lambda: u/(u**2*x + u)) + + raises(TypeError, lambda: t/x) + raises(TypeError, lambda: x/t) + raises(TypeError, lambda: t/u) + raises(TypeError, lambda: u/t) + + R, x = ring("x", ZZ) + f, g = x**2 + 2*x + 3, R(0) + + raises(ZeroDivisionError, lambda: f.div(g)) + raises(ZeroDivisionError, lambda: divmod(f, g)) + raises(ZeroDivisionError, lambda: f.rem(g)) + raises(ZeroDivisionError, lambda: f % g) + raises(ZeroDivisionError, lambda: f.quo(g)) + raises(ZeroDivisionError, lambda: f / g) + raises(ZeroDivisionError, lambda: f.exquo(g)) + + R, x, y = ring("x,y", ZZ) + f, g = x*y + 2*x + 3, R(0) + + raises(ZeroDivisionError, lambda: f.div(g)) + raises(ZeroDivisionError, lambda: divmod(f, g)) + raises(ZeroDivisionError, lambda: f.rem(g)) + raises(ZeroDivisionError, lambda: f % g) + raises(ZeroDivisionError, lambda: f.quo(g)) + raises(ZeroDivisionError, lambda: f / g) + raises(ZeroDivisionError, lambda: f.exquo(g)) + + R, x = ring("x", ZZ) + + f, g = x**2 + 1, 2*x - 4 + q, r = R(0), x**2 + 1 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = 3*x**3 + x**2 + x + 5, 5*x**2 - 3*x + 1 + q, r = R(0), f + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = 5*x**4 + 4*x**3 + 3*x**2 + 2*x + 1, x**2 + 2*x + 3 + q, r = 5*x**2 - 6*x, 20*x + 1 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = 5*x**5 + 4*x**4 + 3*x**3 + 2*x**2 + x, x**4 + 2*x**3 + 9 + q, r = 5*x - 6, 15*x**3 + 2*x**2 - 44*x + 54 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + R, x = ring("x", QQ) + + f, g = x**2 + 1, 2*x - 4 + q, r = x/2 + 1, R(5) + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = 3*x**3 + x**2 + x + 5, 5*x**2 - 3*x + 1 + q, r = QQ(3, 5)*x + QQ(14, 25), QQ(52, 25)*x + QQ(111, 25) + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + R, x,y = ring("x,y", ZZ) + + f, g = x**2 - y**2, x - y + q, r = x + y, R(0) + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + assert f.exquo(g) == f / g == q + + f, g = x**2 + y**2, x - y + q, r = x + y, 2*y**2 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = x**2 + y**2, -x + y + q, r = -x - y, 2*y**2 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = x**2 + y**2, 2*x - 2*y + q, r = R(0), f + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + R, x,y = ring("x,y", QQ) + + f, g = x**2 - y**2, x - y + q, r = x + y, R(0) + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + assert f.exquo(g) == f / g == q + + f, g = x**2 + y**2, x - y + q, r = x + y, 2*y**2 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = x**2 + y**2, -x + y + q, r = -x - y, 2*y**2 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + + f, g = x**2 + y**2, 2*x - 2*y + q, r = x/2 + y/2, 2*y**2 + + assert f.div(g) == divmod(f, g) == (q, r) + assert f.rem(g) == f % g == r + assert f.quo(g) == q + raises(ExactQuotientFailed, lambda: f / g) + raises(ExactQuotientFailed, lambda: f.exquo(g)) + +def test_PolyElement___pow__(): + R, x = ring("x", ZZ, grlex) + f = 2*x + 3 + + assert f**0 == 1 + assert f**1 == f + raises(ValueError, lambda: f**(-1)) + + assert f**2 == f._pow_generic(2) == f._pow_multinomial(2) == 4*x**2 + 12*x + 9 + assert f**3 == f._pow_generic(3) == f._pow_multinomial(3) == 8*x**3 + 36*x**2 + 54*x + 27 + assert f**4 == f._pow_generic(4) == f._pow_multinomial(4) == 16*x**4 + 96*x**3 + 216*x**2 + 216*x + 81 + assert f**5 == f._pow_generic(5) == f._pow_multinomial(5) == 32*x**5 + 240*x**4 + 720*x**3 + 1080*x**2 + 810*x + 243 + + R, x,y,z = ring("x,y,z", ZZ, grlex) + f = x**3*y - 2*x*y**2 - 3*z + 1 + g = x**6*y**2 - 4*x**4*y**3 - 6*x**3*y*z + 2*x**3*y + 4*x**2*y**4 + 12*x*y**2*z - 4*x*y**2 + 9*z**2 - 6*z + 1 + + assert f**2 == f._pow_generic(2) == f._pow_multinomial(2) == g + + R, t = ring("t", ZZ) + f = -11200*t**4 - 2604*t**2 + 49 + g = 15735193600000000*t**16 + 14633730048000000*t**14 + 4828147466240000*t**12 \ + + 598976863027200*t**10 + 3130812416256*t**8 - 2620523775744*t**6 \ + + 92413760096*t**4 - 1225431984*t**2 + 5764801 + + assert f**4 == f._pow_generic(4) == f._pow_multinomial(4) == g + +def test_PolyElement_div(): + R, x = ring("x", ZZ, grlex) + + f = x**3 - 12*x**2 - 42 + g = x - 3 + + q = x**2 - 9*x - 27 + r = -123 + + assert f.div([g]) == ([q], r) + + R, x = ring("x", ZZ, grlex) + f = x**2 + 2*x + 2 + assert f.div([R(1)]) == ([f], 0) + + R, x = ring("x", QQ, grlex) + f = x**2 + 2*x + 2 + assert f.div([R(2)]) == ([QQ(1,2)*x**2 + x + 1], 0) + + R, x,y = ring("x,y", ZZ, grlex) + f = 4*x**2*y - 2*x*y + 4*x - 2*y + 8 + + assert f.div([R(2)]) == ([2*x**2*y - x*y + 2*x - y + 4], 0) + assert f.div([2*y]) == ([2*x**2 - x - 1], 4*x + 8) + + f = x - 1 + g = y - 1 + + assert f.div([g]) == ([0], f) + + f = x*y**2 + 1 + G = [x*y + 1, y + 1] + + Q = [y, -1] + r = 2 + + assert f.div(G) == (Q, r) + + f = x**2*y + x*y**2 + y**2 + G = [x*y - 1, y**2 - 1] + + Q = [x + y, 1] + r = x + y + 1 + + assert f.div(G) == (Q, r) + + G = [y**2 - 1, x*y - 1] + + Q = [x + 1, x] + r = 2*x + 1 + + assert f.div(G) == (Q, r) + + R, = ring("", ZZ) + assert R(3).div(R(2)) == (0, 3) + + R, = ring("", QQ) + assert R(3).div(R(2)) == (QQ(3, 2), 0) + +def test_PolyElement_rem(): + R, x = ring("x", ZZ, grlex) + + f = x**3 - 12*x**2 - 42 + g = x - 3 + r = -123 + + assert f.rem([g]) == f.div([g])[1] == r + + R, x,y = ring("x,y", ZZ, grlex) + + f = 4*x**2*y - 2*x*y + 4*x - 2*y + 8 + + assert f.rem([R(2)]) == f.div([R(2)])[1] == 0 + assert f.rem([2*y]) == f.div([2*y])[1] == 4*x + 8 + + f = x - 1 + g = y - 1 + + assert f.rem([g]) == f.div([g])[1] == f + + f = x*y**2 + 1 + G = [x*y + 1, y + 1] + r = 2 + + assert f.rem(G) == f.div(G)[1] == r + + f = x**2*y + x*y**2 + y**2 + G = [x*y - 1, y**2 - 1] + r = x + y + 1 + + assert f.rem(G) == f.div(G)[1] == r + + G = [y**2 - 1, x*y - 1] + r = 2*x + 1 + + assert f.rem(G) == f.div(G)[1] == r + +def test_PolyElement_deflate(): + R, x = ring("x", ZZ) + + assert (2*x**2).deflate(x**4 + 4*x**2 + 1) == ((2,), [2*x, x**2 + 4*x + 1]) + + R, x,y = ring("x,y", ZZ) + + assert R(0).deflate(R(0)) == ((1, 1), [0, 0]) + assert R(1).deflate(R(0)) == ((1, 1), [1, 0]) + assert R(1).deflate(R(2)) == ((1, 1), [1, 2]) + assert R(1).deflate(2*y) == ((1, 1), [1, 2*y]) + assert (2*y).deflate(2*y) == ((1, 1), [2*y, 2*y]) + assert R(2).deflate(2*y**2) == ((1, 2), [2, 2*y]) + assert (2*y**2).deflate(2*y**2) == ((1, 2), [2*y, 2*y]) + + f = x**4*y**2 + x**2*y + 1 + g = x**2*y**3 + x**2*y + 1 + + assert f.deflate(g) == ((2, 1), [x**2*y**2 + x*y + 1, x*y**3 + x*y + 1]) + +def test_PolyElement_clear_denoms(): + R, x,y = ring("x,y", QQ) + + assert R(1).clear_denoms() == (ZZ(1), 1) + assert R(7).clear_denoms() == (ZZ(1), 7) + + assert R(QQ(7,3)).clear_denoms() == (3, 7) + assert R(QQ(7,3)).clear_denoms() == (3, 7) + + assert (3*x**2 + x).clear_denoms() == (1, 3*x**2 + x) + assert (x**2 + QQ(1,2)*x).clear_denoms() == (2, 2*x**2 + x) + + rQQ, x,t = ring("x,t", QQ, lex) + rZZ, X,T = ring("x,t", ZZ, lex) + + F = [x - QQ(17824537287975195925064602467992950991718052713078834557692023531499318507213727406844943097,413954288007559433755329699713866804710749652268151059918115348815925474842910720000)*t**7 + - QQ(4882321164854282623427463828745855894130208215961904469205260756604820743234704900167747753,12936071500236232304854053116058337647210926633379720622441104650497671088840960000)*t**6 + - QQ(36398103304520066098365558157422127347455927422509913596393052633155821154626830576085097433,25872143000472464609708106232116675294421853266759441244882209300995342177681920000)*t**5 + - QQ(168108082231614049052707339295479262031324376786405372698857619250210703675982492356828810819,58212321751063045371843239022262519412449169850208742800984970927239519899784320000)*t**4 + - QQ(5694176899498574510667890423110567593477487855183144378347226247962949388653159751849449037,1617008937529529038106756639507292205901365829172465077805138081312208886105120000)*t**3 + - QQ(154482622347268833757819824809033388503591365487934245386958884099214649755244381307907779,60637835157357338929003373981523457721301218593967440417692678049207833228942000)*t**2 + - QQ(2452813096069528207645703151222478123259511586701148682951852876484544822947007791153163,2425513406294293557160134959260938308852048743758697616707707121968313329157680)*t + - QQ(34305265428126440542854669008203683099323146152358231964773310260498715579162112959703,202126117191191129763344579938411525737670728646558134725642260164026110763140), + t**8 + QQ(693749860237914515552,67859264524169150569)*t**7 + + QQ(27761407182086143225024,610733380717522355121)*t**6 + + QQ(7785127652157884044288,67859264524169150569)*t**5 + + QQ(36567075214771261409792,203577793572507451707)*t**4 + + QQ(36336335165196147384320,203577793572507451707)*t**3 + + QQ(7452455676042754048000,67859264524169150569)*t**2 + + QQ(2593331082514399232000,67859264524169150569)*t + + QQ(390399197427343360000,67859264524169150569)] + + G = [3725588592068034903797967297424801242396746870413359539263038139343329273586196480000*X - + 160420835591776763325581422211936558925462474417709511019228211783493866564923546661604487873*T**7 - + 1406108495478033395547109582678806497509499966197028487131115097902188374051595011248311352864*T**6 - + 5241326875850889518164640374668786338033653548841427557880599579174438246266263602956254030352*T**5 - + 10758917262823299139373269714910672770004760114329943852726887632013485035262879510837043892416*T**4 - + 13119383576444715672578819534846747735372132018341964647712009275306635391456880068261130581248*T**3 - + 9491412317016197146080450036267011389660653495578680036574753839055748080962214787557853941760*T**2 - + 3767520915562795326943800040277726397326609797172964377014046018280260848046603967211258368000*T - + 632314652371226552085897259159210286886724229880266931574701654721512325555116066073245696000, + 610733380717522355121*T**8 + + 6243748742141230639968*T**7 + + 27761407182086143225024*T**6 + + 70066148869420956398592*T**5 + + 109701225644313784229376*T**4 + + 109009005495588442152960*T**3 + + 67072101084384786432000*T**2 + + 23339979742629593088000*T + + 3513592776846090240000] + + assert [ f.clear_denoms()[1].set_ring(rZZ) for f in F ] == G + +def test_PolyElement_cofactors(): + R, x, y = ring("x,y", ZZ) + + f, g = R(0), R(0) + assert f.cofactors(g) == (0, 0, 0) + + f, g = R(2), R(0) + assert f.cofactors(g) == (2, 1, 0) + + f, g = R(-2), R(0) + assert f.cofactors(g) == (2, -1, 0) + + f, g = R(0), R(-2) + assert f.cofactors(g) == (2, 0, -1) + + f, g = R(0), 2*x + 4 + assert f.cofactors(g) == (2*x + 4, 0, 1) + + f, g = 2*x + 4, R(0) + assert f.cofactors(g) == (2*x + 4, 1, 0) + + f, g = R(2), R(2) + assert f.cofactors(g) == (2, 1, 1) + + f, g = R(-2), R(2) + assert f.cofactors(g) == (2, -1, 1) + + f, g = R(2), R(-2) + assert f.cofactors(g) == (2, 1, -1) + + f, g = R(-2), R(-2) + assert f.cofactors(g) == (2, -1, -1) + + f, g = x**2 + 2*x + 1, R(1) + assert f.cofactors(g) == (1, x**2 + 2*x + 1, 1) + + f, g = x**2 + 2*x + 1, R(2) + assert f.cofactors(g) == (1, x**2 + 2*x + 1, 2) + + f, g = 2*x**2 + 4*x + 2, R(2) + assert f.cofactors(g) == (2, x**2 + 2*x + 1, 1) + + f, g = R(2), 2*x**2 + 4*x + 2 + assert f.cofactors(g) == (2, 1, x**2 + 2*x + 1) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert f.cofactors(g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert f.cofactors(g) == (x + 1, 1, 2*x + 2) + + R, x, y, z, t = ring("x,y,z,t", ZZ) + + f, g = t**2 + 2*t + 1, 2*t + 2 + assert f.cofactors(g) == (t + 1, t + 1, 2) + + f, g = z**2*t**2 + 2*z**2*t + z**2 + z*t + z, t**2 + 2*t + 1 + h, cff, cfg = t + 1, z**2*t + z**2 + z, t + 1 + + assert f.cofactors(g) == (h, cff, cfg) + assert g.cofactors(f) == (h, cfg, cff) + + R, x, y = ring("x,y", QQ) + + f = QQ(1,2)*x**2 + x + QQ(1,2) + g = QQ(1,2)*x + QQ(1,2) + + h = x + 1 + + assert f.cofactors(g) == (h, g, QQ(1,2)) + assert g.cofactors(f) == (h, QQ(1,2), g) + + R, x, y = ring("x,y", RR) + + f = 2.1*x*y**2 - 2.1*x*y + 2.1*x + g = 2.1*x**3 + h = 1.0*x + + assert f.cofactors(g) == (h, f/h, g/h) + assert g.cofactors(f) == (h, g/h, f/h) + +def test_PolyElement_gcd(): + R, x, y = ring("x,y", QQ) + + f = QQ(1,2)*x**2 + x + QQ(1,2) + g = QQ(1,2)*x + QQ(1,2) + + assert f.gcd(g) == x + 1 + +def test_PolyElement_cancel(): + R, x, y = ring("x,y", ZZ) + + f = 2*x**3 + 4*x**2 + 2*x + g = 3*x**2 + 3*x + F = 2*x + 2 + G = 3 + + assert f.cancel(g) == (F, G) + + assert (-f).cancel(g) == (-F, G) + assert f.cancel(-g) == (-F, G) + + R, x, y = ring("x,y", QQ) + + f = QQ(1,2)*x**3 + x**2 + QQ(1,2)*x + g = QQ(1,3)*x**2 + QQ(1,3)*x + F = 3*x + 3 + G = 2 + + assert f.cancel(g) == (F, G) + + assert (-f).cancel(g) == (-F, G) + assert f.cancel(-g) == (-F, G) + + Fx, x = field("x", ZZ) + Rt, t = ring("t", Fx) + + f = (-x**2 - 4)/4*t + g = t**2 + (x**2 + 2)/2 + + assert f.cancel(g) == ((-x**2 - 4)*t, 4*t**2 + 2*x**2 + 4) + +def test_PolyElement_max_norm(): + R, x, y = ring("x,y", ZZ) + + assert R(0).max_norm() == 0 + assert R(1).max_norm() == 1 + + assert (x**3 + 4*x**2 + 2*x + 3).max_norm() == 4 + +def test_PolyElement_l1_norm(): + R, x, y = ring("x,y", ZZ) + + assert R(0).l1_norm() == 0 + assert R(1).l1_norm() == 1 + + assert (x**3 + 4*x**2 + 2*x + 3).l1_norm() == 10 + +def test_PolyElement_diff(): + R, X = xring("x:11", QQ) + + f = QQ(288,5)*X[0]**8*X[1]**6*X[4]**3*X[10]**2 + 8*X[0]**2*X[2]**3*X[4]**3 +2*X[0]**2 - 2*X[1]**2 + + assert f.diff(X[0]) == QQ(2304,5)*X[0]**7*X[1]**6*X[4]**3*X[10]**2 + 16*X[0]*X[2]**3*X[4]**3 + 4*X[0] + assert f.diff(X[4]) == QQ(864,5)*X[0]**8*X[1]**6*X[4]**2*X[10]**2 + 24*X[0]**2*X[2]**3*X[4]**2 + assert f.diff(X[10]) == QQ(576,5)*X[0]**8*X[1]**6*X[4]**3*X[10] + +def test_PolyElement___call__(): + R, x = ring("x", ZZ) + f = 3*x + 1 + + assert f(0) == 1 + assert f(1) == 4 + + raises(ValueError, lambda: f()) + raises(ValueError, lambda: f(0, 1)) + + raises(CoercionFailed, lambda: f(QQ(1,7))) + + R, x,y = ring("x,y", ZZ) + f = 3*x + y**2 + 1 + + assert f(0, 0) == 1 + assert f(1, 7) == 53 + + Ry = R.drop(x) + + assert f(0) == Ry.y**2 + 1 + assert f(1) == Ry.y**2 + 4 + + raises(ValueError, lambda: f()) + raises(ValueError, lambda: f(0, 1, 2)) + + raises(CoercionFailed, lambda: f(1, QQ(1,7))) + raises(CoercionFailed, lambda: f(QQ(1,7), 1)) + raises(CoercionFailed, lambda: f(QQ(1,7), QQ(1,7))) + +def test_PolyElement_evaluate(): + R, x = ring("x", ZZ) + f = x**3 + 4*x**2 + 2*x + 3 + + r = f.evaluate(x, 0) + assert r == 3 and not isinstance(r, PolyElement) + + raises(CoercionFailed, lambda: f.evaluate(x, QQ(1,7))) + + R, x, y, z = ring("x,y,z", ZZ) + f = (x*y)**3 + 4*(x*y)**2 + 2*x*y + 3 + + r = f.evaluate(x, 0) + assert r == 3 and R.drop(x).is_element(r) + r = f.evaluate([(x, 0), (y, 0)]) + assert r == 3 and R.drop(x, y).is_element(r) + r = f.evaluate(y, 0) + assert r == 3 and R.drop(y).is_element(r) + r = f.evaluate([(y, 0), (x, 0)]) + assert r == 3 and R.drop(y, x).is_element(r) + + r = f.evaluate([(x, 0), (y, 0), (z, 0)]) + assert r == 3 and not isinstance(r, PolyElement) + + raises(CoercionFailed, lambda: f.evaluate([(x, 1), (y, QQ(1,7))])) + raises(CoercionFailed, lambda: f.evaluate([(x, QQ(1,7)), (y, 1)])) + raises(CoercionFailed, lambda: f.evaluate([(x, QQ(1,7)), (y, QQ(1,7))])) + +def test_PolyElement_subs(): + R, x = ring("x", ZZ) + f = x**3 + 4*x**2 + 2*x + 3 + + r = f.subs(x, 0) + assert r == 3 and R.is_element(r) + + raises(CoercionFailed, lambda: f.subs(x, QQ(1,7))) + + R, x, y, z = ring("x,y,z", ZZ) + f = x**3 + 4*x**2 + 2*x + 3 + + r = f.subs(x, 0) + assert r == 3 and R.is_element(r) + r = f.subs([(x, 0), (y, 0)]) + assert r == 3 and R.is_element(r) + + raises(CoercionFailed, lambda: f.subs([(x, 1), (y, QQ(1,7))])) + raises(CoercionFailed, lambda: f.subs([(x, QQ(1,7)), (y, 1)])) + raises(CoercionFailed, lambda: f.subs([(x, QQ(1,7)), (y, QQ(1,7))])) + +def test_PolyElement_symmetrize(): + R, x, y = ring("x,y", ZZ) + + # Homogeneous, symmetric + f = x**2 + y**2 + sym, rem, m = f.symmetrize() + assert rem == 0 + assert sym.compose(m) + rem == f + + # Homogeneous, asymmetric + f = x**2 - y**2 + sym, rem, m = f.symmetrize() + assert rem != 0 + assert sym.compose(m) + rem == f + + # Inhomogeneous, symmetric + f = x*y + 7 + sym, rem, m = f.symmetrize() + assert rem == 0 + assert sym.compose(m) + rem == f + + # Inhomogeneous, asymmetric + f = y + 7 + sym, rem, m = f.symmetrize() + assert rem != 0 + assert sym.compose(m) + rem == f + + # Constant + f = R.from_expr(3) + sym, rem, m = f.symmetrize() + assert rem == 0 + assert sym.compose(m) + rem == f + + # Constant constructed from sring + R, f = sring(3) + sym, rem, m = f.symmetrize() + assert rem == 0 + assert sym.compose(m) + rem == f + +def test_PolyElement_compose(): + R, x = ring("x", ZZ) + f = x**3 + 4*x**2 + 2*x + 3 + + r = f.compose(x, 0) + assert r == 3 and R.is_element(r) + + assert f.compose(x, x) == f + assert f.compose(x, x**2) == x**6 + 4*x**4 + 2*x**2 + 3 + + raises(CoercionFailed, lambda: f.compose(x, QQ(1,7))) + + R, x, y, z = ring("x,y,z", ZZ) + f = x**3 + 4*x**2 + 2*x + 3 + + r = f.compose(x, 0) + assert r == 3 and R.is_element(r) + r = f.compose([(x, 0), (y, 0)]) + assert r == 3 and R.is_element(r) + + r = (x**3 + 4*x**2 + 2*x*y*z + 3).compose(x, y*z**2 - 1) + q = (y*z**2 - 1)**3 + 4*(y*z**2 - 1)**2 + 2*(y*z**2 - 1)*y*z + 3 + assert r == q and R.is_element(r) + +def test_PolyElement_is_(): + R, x,y,z = ring("x,y,z", QQ) + + assert (x - x).is_generator == False + assert (x - x).is_ground == True + assert (x - x).is_monomial == True + assert (x - x).is_term == True + + assert (x - x + 1).is_generator == False + assert (x - x + 1).is_ground == True + assert (x - x + 1).is_monomial == True + assert (x - x + 1).is_term == True + + assert x.is_generator == True + assert x.is_ground == False + assert x.is_monomial == True + assert x.is_term == True + + assert (x*y).is_generator == False + assert (x*y).is_ground == False + assert (x*y).is_monomial == True + assert (x*y).is_term == True + + assert (3*x).is_generator == False + assert (3*x).is_ground == False + assert (3*x).is_monomial == False + assert (3*x).is_term == True + + assert (3*x + 1).is_generator == False + assert (3*x + 1).is_ground == False + assert (3*x + 1).is_monomial == False + assert (3*x + 1).is_term == False + + assert R(0).is_zero is True + assert R(1).is_zero is False + + assert R(0).is_one is False + assert R(1).is_one is True + + assert (x - 1).is_monic is True + assert (2*x - 1).is_monic is False + + assert (3*x + 2).is_primitive is True + assert (4*x + 2).is_primitive is False + + assert (x + y + z + 1).is_linear is True + assert (x*y*z + 1).is_linear is False + + assert (x*y + z + 1).is_quadratic is True + assert (x*y*z + 1).is_quadratic is False + + assert (x - 1).is_squarefree is True + assert ((x - 1)**2).is_squarefree is False + + assert (x**2 + x + 1).is_irreducible is True + assert (x**2 + 2*x + 1).is_irreducible is False + + _, t = ring("t", FF(11)) + + assert (7*t + 3).is_irreducible is True + assert (7*t**2 + 3*t + 1).is_irreducible is False + + _, u = ring("u", ZZ) + f = u**16 + u**14 - u**10 - u**8 - u**6 + u**2 + + assert f.is_cyclotomic is False + assert (f + 1).is_cyclotomic is True + + raises(MultivariatePolynomialError, lambda: x.is_cyclotomic) + + R, = ring("", ZZ) + assert R(4).is_squarefree is True + assert R(6).is_irreducible is True + +def test_PolyElement_drop(): + R, x,y,z = ring("x,y,z", ZZ) + + assert R(1).drop(0).ring == PolyRing("y,z", ZZ, lex) + assert R(1).drop(0).drop(0).ring == PolyRing("z", ZZ, lex) + assert R.is_element(R(1).drop(0).drop(0).drop(0)) is False + + raises(ValueError, lambda: z.drop(0).drop(0).drop(0)) + raises(ValueError, lambda: x.drop(0)) + +def test_PolyElement_coeff_wrt(): + R, x, y, z = ring("x, y, z", ZZ) + + p = 4*x**3 + 5*y**2 + 6*y**2*z + 7 + assert p.coeff_wrt(1, 2) == 6*z + 5 # using generator index + assert p.coeff_wrt(x, 3) == 4 # using generator + + p = 2*x**4 + 3*x*y**2*z + 10*y**2 + 10*x*z**2 + assert p.coeff_wrt(x, 1) == 3*y**2*z + 10*z**2 + assert p.coeff_wrt(y, 2) == 3*x*z + 10 + + p = 4*x**2 + 2*x*y + 5 + assert p.coeff_wrt(z, 1) == R(0) + assert p.coeff_wrt(y, 2) == R(0) + +def test_PolyElement_prem(): + R, x, y = ring("x, y", ZZ) + + f, g = x**2 + x*y, 2*x + 2 + assert f.prem(g) == -4*y + 4 # first generator is chosen by default if it is not given + + f, g = x**2 + 1, 2*x - 4 + assert f.prem(g) == f.prem(g, x) == 20 + assert f.prem(g, 1) == R(0) + + f, g = x*y + 2*x + 1, x + y + assert f.prem(g) == -y**2 - 2*y + 1 + assert f.prem(g, 1) == f.prem(g, y) == -x**2 + 2*x + 1 + + raises(ZeroDivisionError, lambda: f.prem(R(0))) + +def test_PolyElement_pdiv(): + R, x, y = ring("x,y", ZZ) + + f, g = x**4 + 5*x**3 + 7*x**2, 2*x**2 + 3 + assert f.pdiv(g) == f.pdiv(g, x) == (4*x**2 + 20*x + 22, -60*x - 66) + + f, g = x**2 - y**2, x - y + assert f.pdiv(g) == f.pdiv(g, 0) == (x + y, 0) + + f, g = x*y + 2*x + 1, x + y + assert f.pdiv(g) == (y + 2, -y**2 - 2*y + 1) + assert f.pdiv(g, y) == f.pdiv(g, 1) == (x + 1, -x**2 + 2*x + 1) + + assert R(0).pdiv(g) == (0, 0) + raises(ZeroDivisionError, lambda: f.prem(R(0))) + +def test_PolyElement_pquo(): + R, x, y = ring("x, y", ZZ) + + f, g = x**4 - 4*x**2*y + 4*y**2, x**2 - 2*y + assert f.pquo(g) == f.pquo(g, x) == x**2 - 2*y + assert f.pquo(g, y) == 4*x**2 - 8*y + 4 + + f, g = x**4 - y**4, x**2 - y**2 + assert f.pquo(g) == f.pquo(g, 0) == x**2 + y**2 + +def test_PolyElement_pexquo(): + R, x, y = ring("x, y", ZZ) + + f, g = x**2 - y**2, x - y + assert f.pexquo(g) == f.pexquo(g, x) == x + y + assert f.pexquo(g, y) == f.pexquo(g, 1) == x + y + 1 + + f, g = x**2 + 3*x + 6, x + 2 + raises(ExactQuotientFailed, lambda: f.pexquo(g)) + +def test_PolyElement_gcdex(): + _, x = ring("x", QQ) + + f, g = 2*x, x**2 - 16 + s, t, h = x/32, -QQ(1, 16), 1 + + assert f.half_gcdex(g) == (s, h) + assert f.gcdex(g) == (s, t, h) + +def test_PolyElement_subresultants(): + R, x, y = ring("x, y", ZZ) + + f, g = x**2*y + x*y, x + y # degree(f, x) > degree(g, x) + h = y**3 - y**2 + assert f.subresultants(g) == [f, g, h] # first generator is chosen default + + # generator index or generator is given + assert f.subresultants(g, 0) == f.subresultants(g, x) == [f, g, h] + + assert f.subresultants(g, y) == [x**2*y + x*y, x + y, x**3 + x**2] + + f, g = 2*x - y, x**2 + 2*y + x # degree(f, x) < degree(g, x) + assert f.subresultants(g) == [x**2 + x + 2*y, 2*x - y, y**2 + 10*y] + + f, g = R(0), y**3 - y**2 # f = 0 + assert f.subresultants(g) == [y**3 - y**2, 1] + + f, g = x**2*y + x*y, R(0) # g = 0 + assert f.subresultants(g) == [x**2*y + x*y, 1] + + f, g = R(0), R(0) # f = 0 and g = 0 + assert f.subresultants(g) == [0, 0] + + f, g = x**2 + x, x**2 + x # f and g are same polynomial + assert f.subresultants(g) == [x**2 + x, x**2 + x] + +def test_PolyElement_resultant(): + _, x = ring("x", ZZ) + f, g, h = x**2 - 2*x + 1, x**2 - 1, 0 + + assert f.resultant(g) == h + +def test_PolyElement_discriminant(): + _, x = ring("x", ZZ) + f, g = x**3 + 3*x**2 + 9*x - 13, -11664 + + assert f.discriminant() == g + + F, a, b, c = ring("a,b,c", ZZ) + _, x = ring("x", F) + + f, g = a*x**2 + b*x + c, b**2 - 4*a*c + + assert f.discriminant() == g + +def test_PolyElement_decompose(): + _, x = ring("x", ZZ) + + f = x**12 + 20*x**10 + 150*x**8 + 500*x**6 + 625*x**4 - 2*x**3 - 10*x + 9 + g = x**4 - 2*x + 9 + h = x**3 + 5*x + + assert g.compose(x, h) == f + assert f.decompose() == [g, h] + +def test_PolyElement_shift(): + _, x = ring("x", ZZ) + assert (x**2 - 2*x + 1).shift(2) == x**2 + 2*x + 1 + assert (x**2 - 2*x + 1).shift_list([2]) == x**2 + 2*x + 1 + + R, x, y = ring("x, y", ZZ) + assert (x*y).shift_list([1, 2]) == (x+1)*(y+2) + + raises(MultivariatePolynomialError, lambda: (x*y).shift(1)) + +def test_PolyElement_sturm(): + F, t = field("t", ZZ) + _, x = ring("x", F) + + f = 1024/(15625*t**8)*x**5 - 4096/(625*t**8)*x**4 + 32/(15625*t**4)*x**3 - 128/(625*t**4)*x**2 + F(1)/62500*x - F(1)/625 + + assert f.sturm() == [ + x**3 - 100*x**2 + t**4/64*x - 25*t**4/16, + 3*x**2 - 200*x + t**4/64, + (-t**4/96 + F(20000)/9)*x + 25*t**4/18, + (-9*t**12 - 11520000*t**8 - 3686400000000*t**4)/(576*t**8 - 245760000*t**4 + 26214400000000), + ] + +def test_PolyElement_gff_list(): + _, x = ring("x", ZZ) + + f = x**5 + 2*x**4 - x**3 - 2*x**2 + assert f.gff_list() == [(x, 1), (x + 2, 4)] + + f = x*(x - 1)**3*(x - 2)**2*(x - 4)**2*(x - 5) + assert f.gff_list() == [(x**2 - 5*x + 4, 1), (x**2 - 5*x + 4, 2), (x, 3)] + +def test_PolyElement_norm(): + k = QQ + K = QQ.algebraic_field(sqrt(2)) + sqrt2 = K.unit + _, X, Y = ring("x,y", k) + _, x, y = ring("x,y", K) + + assert (x*y + sqrt2).norm() == X**2*Y**2 - 2 + +def test_PolyElement_sqf_norm(): + R, x = ring("x", QQ.algebraic_field(sqrt(3))) + X = R.to_ground().x + + assert (x**2 - 2).sqf_norm() == ([1], x**2 - 2*sqrt(3)*x + 1, X**4 - 10*X**2 + 1) + + R, x = ring("x", QQ.algebraic_field(sqrt(2))) + X = R.to_ground().x + + assert (x**2 - 3).sqf_norm() == ([1], x**2 - 2*sqrt(2)*x - 1, X**4 - 10*X**2 + 1) + +def test_PolyElement_sqf_list(): + _, x = ring("x", ZZ) + + f = x**5 - x**3 - x**2 + 1 + g = x**3 + 2*x**2 + 2*x + 1 + h = x - 1 + p = x**4 + x**3 - x - 1 + + assert f.sqf_part() == p + assert f.sqf_list() == (1, [(g, 1), (h, 2)]) + +def test_issue_18894(): + items = [S(3)/16 + sqrt(3*sqrt(3) + 10)/8, S(1)/8 + 3*sqrt(3)/16, S(1)/8 + 3*sqrt(3)/16, -S(3)/16 + sqrt(3*sqrt(3) + 10)/8] + R, a = sring(items, extension=True) + assert R.domain == QQ.algebraic_field(sqrt(3)+sqrt(3*sqrt(3)+10)) + assert R.gens == () + result = [] + for item in items: + result.append(R.domain.from_sympy(item)) + assert a == result + +def test_PolyElement_factor_list(): + _, x = ring("x", ZZ) + + f = x**5 - x**3 - x**2 + 1 + + u = x + 1 + v = x - 1 + w = x**2 + x + 1 + + assert f.factor_list() == (1, [(u, 1), (v, 2), (w, 1)]) + + +def test_issue_21410(): + R, x = ring('x', FF(2)) + p = x**6 + x**5 + x**4 + x**3 + 1 + assert p._pow_multinomial(4) == x**24 + x**20 + x**16 + x**12 + 1 + + +def test_zero_polynomial_primitive(): + + x = symbols('x') + + R = ZZ[x] + zero_poly = R(0) + cont, prim = zero_poly.primitive() + assert cont == 0 + assert prim == zero_poly + assert prim.is_primitive is False diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_rootisolation.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_rootisolation.py new file mode 100644 index 0000000000000000000000000000000000000000..9661c1d6b63bfb941157c7e904ba4e048afbc538 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_rootisolation.py @@ -0,0 +1,823 @@ +"""Tests for real and complex root isolation and refinement algorithms. """ + +from sympy.polys.rings import ring +from sympy.polys.domains import ZZ, QQ, ZZ_I, EX +from sympy.polys.polyerrors import DomainError, RefinementFailed, PolynomialError +from sympy.polys.rootisolation import ( + dup_cauchy_upper_bound, dup_cauchy_lower_bound, + dup_mignotte_sep_bound_squared, +) +from sympy.testing.pytest import raises + +def test_dup_sturm(): + R, x = ring("x", QQ) + + assert R.dup_sturm(5) == [1] + assert R.dup_sturm(x) == [x, 1] + + f = x**3 - 2*x**2 + 3*x - 5 + assert R.dup_sturm(f) == [f, 3*x**2 - 4*x + 3, -QQ(10,9)*x + QQ(13,3), -QQ(3303,100)] + + +def test_dup_cauchy_upper_bound(): + raises(PolynomialError, lambda: dup_cauchy_upper_bound([], QQ)) + raises(PolynomialError, lambda: dup_cauchy_upper_bound([QQ(1)], QQ)) + raises(DomainError, lambda: dup_cauchy_upper_bound([ZZ_I(1), ZZ_I(1)], ZZ_I)) + + assert dup_cauchy_upper_bound([QQ(1), QQ(0), QQ(0)], QQ) == QQ.zero + assert dup_cauchy_upper_bound([QQ(1), QQ(0), QQ(-2)], QQ) == QQ(3) + + +def test_dup_cauchy_lower_bound(): + raises(PolynomialError, lambda: dup_cauchy_lower_bound([], QQ)) + raises(PolynomialError, lambda: dup_cauchy_lower_bound([QQ(1)], QQ)) + raises(PolynomialError, lambda: dup_cauchy_lower_bound([QQ(1), QQ(0), QQ(0)], QQ)) + raises(DomainError, lambda: dup_cauchy_lower_bound([ZZ_I(1), ZZ_I(1)], ZZ_I)) + + assert dup_cauchy_lower_bound([QQ(1), QQ(0), QQ(-2)], QQ) == QQ(2, 3) + + +def test_dup_mignotte_sep_bound_squared(): + raises(PolynomialError, lambda: dup_mignotte_sep_bound_squared([], QQ)) + raises(PolynomialError, lambda: dup_mignotte_sep_bound_squared([QQ(1)], QQ)) + + assert dup_mignotte_sep_bound_squared([QQ(1), QQ(0), QQ(-2)], QQ) == QQ(3, 5) + + +def test_dup_refine_real_root(): + R, x = ring("x", ZZ) + f = x**2 - 2 + + assert R.dup_refine_real_root(f, QQ(1), QQ(1), steps=1) == (QQ(1), QQ(1)) + assert R.dup_refine_real_root(f, QQ(1), QQ(1), steps=9) == (QQ(1), QQ(1)) + + raises(ValueError, lambda: R.dup_refine_real_root(f, QQ(-2), QQ(2))) + + s, t = QQ(1, 1), QQ(2, 1) + + assert R.dup_refine_real_root(f, s, t, steps=0) == (QQ(1, 1), QQ(2, 1)) + assert R.dup_refine_real_root(f, s, t, steps=1) == (QQ(1, 1), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=2) == (QQ(4, 3), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=3) == (QQ(7, 5), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=4) == (QQ(7, 5), QQ(10, 7)) + + s, t = QQ(1, 1), QQ(3, 2) + + assert R.dup_refine_real_root(f, s, t, steps=0) == (QQ(1, 1), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=1) == (QQ(4, 3), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=2) == (QQ(7, 5), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=3) == (QQ(7, 5), QQ(10, 7)) + assert R.dup_refine_real_root(f, s, t, steps=4) == (QQ(7, 5), QQ(17, 12)) + + s, t = QQ(1, 1), QQ(5, 3) + + assert R.dup_refine_real_root(f, s, t, steps=0) == (QQ(1, 1), QQ(5, 3)) + assert R.dup_refine_real_root(f, s, t, steps=1) == (QQ(1, 1), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=2) == (QQ(7, 5), QQ(3, 2)) + assert R.dup_refine_real_root(f, s, t, steps=3) == (QQ(7, 5), QQ(13, 9)) + assert R.dup_refine_real_root(f, s, t, steps=4) == (QQ(7, 5), QQ(27, 19)) + + s, t = QQ(-1, 1), QQ(-2, 1) + + assert R.dup_refine_real_root(f, s, t, steps=0) == (-QQ(2, 1), -QQ(1, 1)) + assert R.dup_refine_real_root(f, s, t, steps=1) == (-QQ(3, 2), -QQ(1, 1)) + assert R.dup_refine_real_root(f, s, t, steps=2) == (-QQ(3, 2), -QQ(4, 3)) + assert R.dup_refine_real_root(f, s, t, steps=3) == (-QQ(3, 2), -QQ(7, 5)) + assert R.dup_refine_real_root(f, s, t, steps=4) == (-QQ(10, 7), -QQ(7, 5)) + + raises(RefinementFailed, lambda: R.dup_refine_real_root(f, QQ(0), QQ(1))) + + s, t, u, v, w = QQ(1), QQ(2), QQ(24, 17), QQ(17, 12), QQ(7, 5) + + assert R.dup_refine_real_root(f, s, t, eps=QQ(1, 100)) == (u, v) + assert R.dup_refine_real_root(f, s, t, steps=6) == (u, v) + + assert R.dup_refine_real_root(f, s, t, eps=QQ(1, 100), steps=5) == (w, v) + assert R.dup_refine_real_root(f, s, t, eps=QQ(1, 100), steps=6) == (u, v) + assert R.dup_refine_real_root(f, s, t, eps=QQ(1, 100), steps=7) == (u, v) + + s, t, u, v = QQ(-2), QQ(-1), QQ(-3, 2), QQ(-4, 3) + + assert R.dup_refine_real_root(f, s, t, disjoint=QQ(-5)) == (s, t) + assert R.dup_refine_real_root(f, s, t, disjoint=-v) == (s, t) + assert R.dup_refine_real_root(f, s, t, disjoint=v) == (u, v) + + s, t, u, v = QQ(1), QQ(2), QQ(4, 3), QQ(3, 2) + + assert R.dup_refine_real_root(f, s, t, disjoint=QQ(5)) == (s, t) + assert R.dup_refine_real_root(f, s, t, disjoint=-u) == (s, t) + assert R.dup_refine_real_root(f, s, t, disjoint=u) == (u, v) + + +def test_dup_isolate_real_roots_sqf(): + R, x = ring("x", ZZ) + + assert R.dup_isolate_real_roots_sqf(0) == [] + assert R.dup_isolate_real_roots_sqf(5) == [] + + assert R.dup_isolate_real_roots_sqf(x**2 + x) == [(-1, -1), (0, 0)] + assert R.dup_isolate_real_roots_sqf(x**2 - x) == [( 0, 0), (1, 1)] + + assert R.dup_isolate_real_roots_sqf(x**4 + x + 1) == [] + + I = [(-2, -1), (1, 2)] + + assert R.dup_isolate_real_roots_sqf(x**2 - 2) == I + assert R.dup_isolate_real_roots_sqf(-x**2 + 2) == I + + assert R.dup_isolate_real_roots_sqf(x - 1) == \ + [(1, 1)] + assert R.dup_isolate_real_roots_sqf(x**2 - 3*x + 2) == \ + [(1, 1), (2, 2)] + assert R.dup_isolate_real_roots_sqf(x**3 - 6*x**2 + 11*x - 6) == \ + [(1, 1), (2, 2), (3, 3)] + assert R.dup_isolate_real_roots_sqf(x**4 - 10*x**3 + 35*x**2 - 50*x + 24) == \ + [(1, 1), (2, 2), (3, 3), (4, 4)] + assert R.dup_isolate_real_roots_sqf(x**5 - 15*x**4 + 85*x**3 - 225*x**2 + 274*x - 120) == \ + [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)] + + assert R.dup_isolate_real_roots_sqf(x - 10) == \ + [(10, 10)] + assert R.dup_isolate_real_roots_sqf(x**2 - 30*x + 200) == \ + [(10, 10), (20, 20)] + assert R.dup_isolate_real_roots_sqf(x**3 - 60*x**2 + 1100*x - 6000) == \ + [(10, 10), (20, 20), (30, 30)] + assert R.dup_isolate_real_roots_sqf(x**4 - 100*x**3 + 3500*x**2 - 50000*x + 240000) == \ + [(10, 10), (20, 20), (30, 30), (40, 40)] + assert R.dup_isolate_real_roots_sqf(x**5 - 150*x**4 + 8500*x**3 - 225000*x**2 + 2740000*x - 12000000) == \ + [(10, 10), (20, 20), (30, 30), (40, 40), (50, 50)] + + assert R.dup_isolate_real_roots_sqf(x + 1) == \ + [(-1, -1)] + assert R.dup_isolate_real_roots_sqf(x**2 + 3*x + 2) == \ + [(-2, -2), (-1, -1)] + assert R.dup_isolate_real_roots_sqf(x**3 + 6*x**2 + 11*x + 6) == \ + [(-3, -3), (-2, -2), (-1, -1)] + assert R.dup_isolate_real_roots_sqf(x**4 + 10*x**3 + 35*x**2 + 50*x + 24) == \ + [(-4, -4), (-3, -3), (-2, -2), (-1, -1)] + assert R.dup_isolate_real_roots_sqf(x**5 + 15*x**4 + 85*x**3 + 225*x**2 + 274*x + 120) == \ + [(-5, -5), (-4, -4), (-3, -3), (-2, -2), (-1, -1)] + + assert R.dup_isolate_real_roots_sqf(x + 10) == \ + [(-10, -10)] + assert R.dup_isolate_real_roots_sqf(x**2 + 30*x + 200) == \ + [(-20, -20), (-10, -10)] + assert R.dup_isolate_real_roots_sqf(x**3 + 60*x**2 + 1100*x + 6000) == \ + [(-30, -30), (-20, -20), (-10, -10)] + assert R.dup_isolate_real_roots_sqf(x**4 + 100*x**3 + 3500*x**2 + 50000*x + 240000) == \ + [(-40, -40), (-30, -30), (-20, -20), (-10, -10)] + assert R.dup_isolate_real_roots_sqf(x**5 + 150*x**4 + 8500*x**3 + 225000*x**2 + 2740000*x + 12000000) == \ + [(-50, -50), (-40, -40), (-30, -30), (-20, -20), (-10, -10)] + + assert R.dup_isolate_real_roots_sqf(x**2 - 5) == [(-3, -2), (2, 3)] + assert R.dup_isolate_real_roots_sqf(x**3 - 5) == [(1, 2)] + assert R.dup_isolate_real_roots_sqf(x**4 - 5) == [(-2, -1), (1, 2)] + assert R.dup_isolate_real_roots_sqf(x**5 - 5) == [(1, 2)] + assert R.dup_isolate_real_roots_sqf(x**6 - 5) == [(-2, -1), (1, 2)] + assert R.dup_isolate_real_roots_sqf(x**7 - 5) == [(1, 2)] + assert R.dup_isolate_real_roots_sqf(x**8 - 5) == [(-2, -1), (1, 2)] + assert R.dup_isolate_real_roots_sqf(x**9 - 5) == [(1, 2)] + + assert R.dup_isolate_real_roots_sqf(x**2 - 1) == \ + [(-1, -1), (1, 1)] + assert R.dup_isolate_real_roots_sqf(x**3 + 2*x**2 - x - 2) == \ + [(-2, -2), (-1, -1), (1, 1)] + assert R.dup_isolate_real_roots_sqf(x**4 - 5*x**2 + 4) == \ + [(-2, -2), (-1, -1), (1, 1), (2, 2)] + assert R.dup_isolate_real_roots_sqf(x**5 + 3*x**4 - 5*x**3 - 15*x**2 + 4*x + 12) == \ + [(-3, -3), (-2, -2), (-1, -1), (1, 1), (2, 2)] + assert R.dup_isolate_real_roots_sqf(x**6 - 14*x**4 + 49*x**2 - 36) == \ + [(-3, -3), (-2, -2), (-1, -1), (1, 1), (2, 2), (3, 3)] + assert R.dup_isolate_real_roots_sqf(2*x**7 + x**6 - 28*x**5 - 14*x**4 + 98*x**3 + 49*x**2 - 72*x - 36) == \ + [(-3, -3), (-2, -2), (-1, -1), (-1, 0), (1, 1), (2, 2), (3, 3)] + assert R.dup_isolate_real_roots_sqf(4*x**8 - 57*x**6 + 210*x**4 - 193*x**2 + 36) == \ + [(-3, -3), (-2, -2), (-1, -1), (-1, 0), (0, 1), (1, 1), (2, 2), (3, 3)] + + f = 9*x**2 - 2 + + assert R.dup_isolate_real_roots_sqf(f) == \ + [(-1, 0), (0, 1)] + + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 10)) == \ + [(QQ(-1, 2), QQ(-3, 7)), (QQ(3, 7), QQ(1, 2))] + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 100)) == \ + [(QQ(-9, 19), QQ(-8, 17)), (QQ(8, 17), QQ(9, 19))] + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 1000)) == \ + [(QQ(-33, 70), QQ(-8, 17)), (QQ(8, 17), QQ(33, 70))] + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 10000)) == \ + [(QQ(-33, 70), QQ(-107, 227)), (QQ(107, 227), QQ(33, 70))] + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 100000)) == \ + [(QQ(-305, 647), QQ(-272, 577)), (QQ(272, 577), QQ(305, 647))] + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 1000000)) == \ + [(QQ(-1121, 2378), QQ(-272, 577)), (QQ(272, 577), QQ(1121, 2378))] + + f = 200100012*x**5 - 700390052*x**4 + 700490079*x**3 - 200240054*x**2 + 40017*x - 2 + + assert R.dup_isolate_real_roots_sqf(f) == \ + [(QQ(0), QQ(1, 10002)), (QQ(1, 10002), QQ(1, 10002)), + (QQ(1, 2), QQ(1, 2)), (QQ(1), QQ(1)), (QQ(2), QQ(2))] + + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 100000)) == \ + [(QQ(1, 10003), QQ(1, 10003)), (QQ(1, 10002), QQ(1, 10002)), + (QQ(1, 2), QQ(1, 2)), (QQ(1), QQ(1)), (QQ(2), QQ(2))] + + a, b, c, d = 10000090000001, 2000100003, 10000300007, 10000005000008 + + f = 20001600074001600021*x**4 \ + + 1700135866278935491773999857*x**3 \ + - 2000179008931031182161141026995283662899200197*x**2 \ + - 800027600594323913802305066986600025*x \ + + 100000950000540000725000008 + + assert R.dup_isolate_real_roots_sqf(f) == \ + [(-a, -a), (-1, 0), (0, 1), (d, d)] + + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 100000000000)) == \ + [(-QQ(a), -QQ(a)), (-QQ(1, b), -QQ(1, b)), (QQ(1, c), QQ(1, c)), (QQ(d), QQ(d))] + + (u, v), B, C, (s, t) = R.dup_isolate_real_roots_sqf(f, fast=True) + + assert u < -a < v and B == (-QQ(1), QQ(0)) and C == (QQ(0), QQ(1)) and s < d < t + + assert R.dup_isolate_real_roots_sqf(f, fast=True, eps=QQ(1, 100000000000000000000000000000)) == \ + [(-QQ(a), -QQ(a)), (-QQ(1, b), -QQ(1, b)), (QQ(1, c), QQ(1, c)), (QQ(d), QQ(d))] + + f = -10*x**4 + 8*x**3 + 80*x**2 - 32*x - 160 + + assert R.dup_isolate_real_roots_sqf(f) == \ + [(-2, -2), (-2, -1), (2, 2), (2, 3)] + + assert R.dup_isolate_real_roots_sqf(f, eps=QQ(1, 100)) == \ + [(-QQ(2), -QQ(2)), (-QQ(23, 14), -QQ(18, 11)), (QQ(2), QQ(2)), (QQ(39, 16), QQ(22, 9))] + + f = x - 1 + + assert R.dup_isolate_real_roots_sqf(f, inf=2) == [] + assert R.dup_isolate_real_roots_sqf(f, sup=0) == [] + + assert R.dup_isolate_real_roots_sqf(f) == [(1, 1)] + assert R.dup_isolate_real_roots_sqf(f, inf=1) == [(1, 1)] + assert R.dup_isolate_real_roots_sqf(f, sup=1) == [(1, 1)] + assert R.dup_isolate_real_roots_sqf(f, inf=1, sup=1) == [(1, 1)] + + f = x**2 - 2 + + assert R.dup_isolate_real_roots_sqf(f, inf=QQ(7, 4)) == [] + assert R.dup_isolate_real_roots_sqf(f, inf=QQ(7, 5)) == [(QQ(7, 5), QQ(3, 2))] + assert R.dup_isolate_real_roots_sqf(f, sup=QQ(7, 5)) == [(-2, -1)] + assert R.dup_isolate_real_roots_sqf(f, sup=QQ(7, 4)) == [(-2, -1), (1, QQ(3, 2))] + assert R.dup_isolate_real_roots_sqf(f, sup=-QQ(7, 4)) == [] + assert R.dup_isolate_real_roots_sqf(f, sup=-QQ(7, 5)) == [(-QQ(3, 2), -QQ(7, 5))] + assert R.dup_isolate_real_roots_sqf(f, inf=-QQ(7, 5)) == [(1, 2)] + assert R.dup_isolate_real_roots_sqf(f, inf=-QQ(7, 4)) == [(-QQ(3, 2), -1), (1, 2)] + + I = [(-2, -1), (1, 2)] + + assert R.dup_isolate_real_roots_sqf(f, inf=-2) == I + assert R.dup_isolate_real_roots_sqf(f, sup=+2) == I + + assert R.dup_isolate_real_roots_sqf(f, inf=-2, sup=2) == I + + R, x = ring("x", QQ) + f = QQ(8, 5)*x**2 - QQ(87374, 3855)*x - QQ(17, 771) + + assert R.dup_isolate_real_roots_sqf(f) == [(-1, 0), (14, 15)] + + R, x = ring("x", EX) + raises(DomainError, lambda: R.dup_isolate_real_roots_sqf(x + 3)) + +def test_dup_isolate_real_roots(): + R, x = ring("x", ZZ) + + assert R.dup_isolate_real_roots(0) == [] + assert R.dup_isolate_real_roots(3) == [] + + assert R.dup_isolate_real_roots(5*x) == [((0, 0), 1)] + assert R.dup_isolate_real_roots(7*x**4) == [((0, 0), 4)] + + assert R.dup_isolate_real_roots(x**2 + x) == [((-1, -1), 1), ((0, 0), 1)] + assert R.dup_isolate_real_roots(x**2 - x) == [((0, 0), 1), ((1, 1), 1)] + + assert R.dup_isolate_real_roots(x**4 + x + 1) == [] + + I = [((-2, -1), 1), ((1, 2), 1)] + + assert R.dup_isolate_real_roots(x**2 - 2) == I + assert R.dup_isolate_real_roots(-x**2 + 2) == I + + f = 16*x**14 - 96*x**13 + 24*x**12 + 936*x**11 - 1599*x**10 - 2880*x**9 + 9196*x**8 \ + + 552*x**7 - 21831*x**6 + 13968*x**5 + 21690*x**4 - 26784*x**3 - 2916*x**2 + 15552*x - 5832 + g = R.dup_sqf_part(f) + + assert R.dup_isolate_real_roots(f) == \ + [((-QQ(2), -QQ(3, 2)), 2), ((-QQ(3, 2), -QQ(1, 1)), 3), ((QQ(1), QQ(3, 2)), 3), + ((QQ(3, 2), QQ(3, 2)), 4), ((QQ(5, 3), QQ(2)), 2)] + + assert R.dup_isolate_real_roots_sqf(g) == \ + [(-QQ(2), -QQ(3, 2)), (-QQ(3, 2), -QQ(1, 1)), (QQ(1), QQ(3, 2)), + (QQ(3, 2), QQ(3, 2)), (QQ(3, 2), QQ(2))] + assert R.dup_isolate_real_roots(g) == \ + [((-QQ(2), -QQ(3, 2)), 1), ((-QQ(3, 2), -QQ(1, 1)), 1), ((QQ(1), QQ(3, 2)), 1), + ((QQ(3, 2), QQ(3, 2)), 1), ((QQ(3, 2), QQ(2)), 1)] + + f = x - 1 + + assert R.dup_isolate_real_roots(f, inf=2) == [] + assert R.dup_isolate_real_roots(f, sup=0) == [] + + assert R.dup_isolate_real_roots(f) == [((1, 1), 1)] + assert R.dup_isolate_real_roots(f, inf=1) == [((1, 1), 1)] + assert R.dup_isolate_real_roots(f, sup=1) == [((1, 1), 1)] + assert R.dup_isolate_real_roots(f, inf=1, sup=1) == [((1, 1), 1)] + + f = x**4 - 4*x**2 + 4 + + assert R.dup_isolate_real_roots(f, inf=QQ(7, 4)) == [] + assert R.dup_isolate_real_roots(f, inf=QQ(7, 5)) == [((QQ(7, 5), QQ(3, 2)), 2)] + assert R.dup_isolate_real_roots(f, sup=QQ(7, 5)) == [((-2, -1), 2)] + assert R.dup_isolate_real_roots(f, sup=QQ(7, 4)) == [((-2, -1), 2), ((1, QQ(3, 2)), 2)] + assert R.dup_isolate_real_roots(f, sup=-QQ(7, 4)) == [] + assert R.dup_isolate_real_roots(f, sup=-QQ(7, 5)) == [((-QQ(3, 2), -QQ(7, 5)), 2)] + assert R.dup_isolate_real_roots(f, inf=-QQ(7, 5)) == [((1, 2), 2)] + assert R.dup_isolate_real_roots(f, inf=-QQ(7, 4)) == [((-QQ(3, 2), -1), 2), ((1, 2), 2)] + + I = [((-2, -1), 2), ((1, 2), 2)] + + assert R.dup_isolate_real_roots(f, inf=-2) == I + assert R.dup_isolate_real_roots(f, sup=+2) == I + + assert R.dup_isolate_real_roots(f, inf=-2, sup=2) == I + + f = x**11 - 3*x**10 - x**9 + 11*x**8 - 8*x**7 - 8*x**6 + 12*x**5 - 4*x**4 + + assert R.dup_isolate_real_roots(f, basis=False) == \ + [((-2, -1), 2), ((0, 0), 4), ((1, 1), 3), ((1, 2), 2)] + assert R.dup_isolate_real_roots(f, basis=True) == \ + [((-2, -1), 2, [1, 0, -2]), ((0, 0), 4, [1, 0]), ((1, 1), 3, [1, -1]), ((1, 2), 2, [1, 0, -2])] + + f = (x**45 - 45*x**44 + 990*x**43 - 1) + g = (x**46 - 15180*x**43 + 9366819*x**40 - 53524680*x**39 + 260932815*x**38 - 1101716330*x**37 + 4076350421*x**36 - 13340783196*x**35 + 38910617655*x**34 - 101766230790*x**33 + 239877544005*x**32 - 511738760544*x**31 + 991493848554*x**30 - 1749695026860*x**29 + 2818953098830*x**28 - 4154246671960*x**27 + 5608233007146*x**26 - 6943526580276*x**25 + 7890371113950*x**24 - 8233430727600*x**23 + 7890371113950*x**22 - 6943526580276*x**21 + 5608233007146*x**20 - 4154246671960*x**19 + 2818953098830*x**18 - 1749695026860*x**17 + 991493848554*x**16 - 511738760544*x**15 + 239877544005*x**14 - 101766230790*x**13 + 38910617655*x**12 - 13340783196*x**11 + 4076350421*x**10 - 1101716330*x**9 + 260932815*x**8 - 53524680*x**7 + 9366819*x**6 - 1370754*x**5 + 163185*x**4 - 15180*x**3 + 1035*x**2 - 47*x + 1) + + assert R.dup_isolate_real_roots(f*g) == \ + [((0, QQ(1, 2)), 1), ((QQ(2, 3), QQ(3, 4)), 1), ((QQ(3, 4), 1), 1), ((6, 7), 1), ((24, 25), 1)] + + R, x = ring("x", EX) + raises(DomainError, lambda: R.dup_isolate_real_roots(x + 3)) + + +def test_dup_isolate_real_roots_list(): + R, x = ring("x", ZZ) + + assert R.dup_isolate_real_roots_list([x**2 + x, x]) == \ + [((-1, -1), {0: 1}), ((0, 0), {0: 1, 1: 1})] + assert R.dup_isolate_real_roots_list([x**2 - x, x]) == \ + [((0, 0), {0: 1, 1: 1}), ((1, 1), {0: 1})] + + assert R.dup_isolate_real_roots_list([x + 1, x + 2, x - 1, x + 1, x - 1, x - 1]) == \ + [((-QQ(2), -QQ(2)), {1: 1}), ((-QQ(1), -QQ(1)), {0: 1, 3: 1}), ((QQ(1), QQ(1)), {2: 1, 4: 1, 5: 1})] + + assert R.dup_isolate_real_roots_list([x + 1, x + 2, x - 1, x + 1, x - 1, x + 2]) == \ + [((-QQ(2), -QQ(2)), {1: 1, 5: 1}), ((-QQ(1), -QQ(1)), {0: 1, 3: 1}), ((QQ(1), QQ(1)), {2: 1, 4: 1})] + + f, g = x**4 - 4*x**2 + 4, x - 1 + + assert R.dup_isolate_real_roots_list([f, g], inf=QQ(7, 4)) == [] + assert R.dup_isolate_real_roots_list([f, g], inf=QQ(7, 5)) == \ + [((QQ(7, 5), QQ(3, 2)), {0: 2})] + assert R.dup_isolate_real_roots_list([f, g], sup=QQ(7, 5)) == \ + [((-2, -1), {0: 2}), ((1, 1), {1: 1})] + assert R.dup_isolate_real_roots_list([f, g], sup=QQ(7, 4)) == \ + [((-2, -1), {0: 2}), ((1, 1), {1: 1}), ((1, QQ(3, 2)), {0: 2})] + assert R.dup_isolate_real_roots_list([f, g], sup=-QQ(7, 4)) == [] + assert R.dup_isolate_real_roots_list([f, g], sup=-QQ(7, 5)) == \ + [((-QQ(3, 2), -QQ(7, 5)), {0: 2})] + assert R.dup_isolate_real_roots_list([f, g], inf=-QQ(7, 5)) == \ + [((1, 1), {1: 1}), ((1, 2), {0: 2})] + assert R.dup_isolate_real_roots_list([f, g], inf=-QQ(7, 4)) == \ + [((-QQ(3, 2), -1), {0: 2}), ((1, 1), {1: 1}), ((1, 2), {0: 2})] + + f, g = 2*x**2 - 1, x**2 - 2 + + assert R.dup_isolate_real_roots_list([f, g]) == \ + [((-QQ(2), -QQ(1)), {1: 1}), ((-QQ(1), QQ(0)), {0: 1}), + ((QQ(0), QQ(1)), {0: 1}), ((QQ(1), QQ(2)), {1: 1})] + assert R.dup_isolate_real_roots_list([f, g], strict=True) == \ + [((-QQ(3, 2), -QQ(4, 3)), {1: 1}), ((-QQ(1), -QQ(2, 3)), {0: 1}), + ((QQ(2, 3), QQ(1)), {0: 1}), ((QQ(4, 3), QQ(3, 2)), {1: 1})] + + f, g = x**2 - 2, x**3 - x**2 - 2*x + 2 + + assert R.dup_isolate_real_roots_list([f, g]) == \ + [((-QQ(2), -QQ(1)), {1: 1, 0: 1}), ((QQ(1), QQ(1)), {1: 1}), ((QQ(1), QQ(2)), {1: 1, 0: 1})] + + f, g = x**3 - 2*x, x**5 - x**4 - 2*x**3 + 2*x**2 + + assert R.dup_isolate_real_roots_list([f, g]) == \ + [((-QQ(2), -QQ(1)), {1: 1, 0: 1}), ((QQ(0), QQ(0)), {0: 1, 1: 2}), + ((QQ(1), QQ(1)), {1: 1}), ((QQ(1), QQ(2)), {1: 1, 0: 1})] + + f, g = x**9 - 3*x**8 - x**7 + 11*x**6 - 8*x**5 - 8*x**4 + 12*x**3 - 4*x**2, x**5 - 2*x**4 + 3*x**3 - 4*x**2 + 2*x + + assert R.dup_isolate_real_roots_list([f, g], basis=False) == \ + [((-2, -1), {0: 2}), ((0, 0), {0: 2, 1: 1}), ((1, 1), {0: 3, 1: 2}), ((1, 2), {0: 2})] + assert R.dup_isolate_real_roots_list([f, g], basis=True) == \ + [((-2, -1), {0: 2}, [1, 0, -2]), ((0, 0), {0: 2, 1: 1}, [1, 0]), + ((1, 1), {0: 3, 1: 2}, [1, -1]), ((1, 2), {0: 2}, [1, 0, -2])] + + R, x = ring("x", EX) + raises(DomainError, lambda: R.dup_isolate_real_roots_list([x + 3])) + + +def test_dup_isolate_real_roots_list_QQ(): + R, x = ring("x", ZZ) + + f = x**5 - 200 + g = x**5 - 201 + + assert R.dup_isolate_real_roots_list([f, g]) == \ + [((QQ(75, 26), QQ(101, 35)), {0: 1}), ((QQ(309, 107), QQ(26, 9)), {1: 1})] + + R, x = ring("x", QQ) + + f = -QQ(1, 200)*x**5 + 1 + g = -QQ(1, 201)*x**5 + 1 + + assert R.dup_isolate_real_roots_list([f, g]) == \ + [((QQ(75, 26), QQ(101, 35)), {0: 1}), ((QQ(309, 107), QQ(26, 9)), {1: 1})] + + +def test_dup_count_real_roots(): + R, x = ring("x", ZZ) + + assert R.dup_count_real_roots(0) == 0 + assert R.dup_count_real_roots(7) == 0 + + f = x - 1 + assert R.dup_count_real_roots(f) == 1 + assert R.dup_count_real_roots(f, inf=1) == 1 + assert R.dup_count_real_roots(f, sup=0) == 0 + assert R.dup_count_real_roots(f, sup=1) == 1 + assert R.dup_count_real_roots(f, inf=0, sup=1) == 1 + assert R.dup_count_real_roots(f, inf=0, sup=2) == 1 + assert R.dup_count_real_roots(f, inf=1, sup=2) == 1 + + f = x**2 - 2 + assert R.dup_count_real_roots(f) == 2 + assert R.dup_count_real_roots(f, sup=0) == 1 + assert R.dup_count_real_roots(f, inf=-1, sup=1) == 0 + + +# parameters for test_dup_count_complex_roots_n(): n = 1..8 +a, b = (-QQ(1), -QQ(1)), (QQ(1), QQ(1)) +c, d = ( QQ(0), QQ(0)), (QQ(1), QQ(1)) + +def test_dup_count_complex_roots_1(): + R, x = ring("x", ZZ) + + # z-1 + f = x - 1 + assert R.dup_count_complex_roots(f, a, b) == 1 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # z+1 + f = x + 1 + assert R.dup_count_complex_roots(f, a, b) == 1 + assert R.dup_count_complex_roots(f, c, d) == 0 + + +def test_dup_count_complex_roots_2(): + R, x = ring("x", ZZ) + + # (z-1)*(z) + f = x**2 - x + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-1)*(-z) + f = -x**2 + x + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z+1)*(z) + f = x**2 + x + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z+1)*(-z) + f = -x**2 - x + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 1 + + +def test_dup_count_complex_roots_3(): + R, x = ring("x", ZZ) + + # (z-1)*(z+1) + f = x**2 - 1 + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-1)*(z+1)*(z) + f = x**3 - x + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-1)*(z+1)*(-z) + f = -x**3 + x + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 2 + + +def test_dup_count_complex_roots_4(): + R, x = ring("x", ZZ) + + # (z-I)*(z+I) + f = x**2 + 1 + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I)*(z+I)*(z) + f = x**3 + x + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I)*(z+I)*(-z) + f = -x**3 - x + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I)*(z+I)*(z-1) + f = x**3 - x**2 + x - 1 + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I)*(z+I)*(z-1)*(z) + f = x**4 - x**3 + x**2 - x + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 3 + + # (z-I)*(z+I)*(z-1)*(-z) + f = -x**4 + x**3 - x**2 + x + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 3 + + # (z-I)*(z+I)*(z-1)*(z+1) + f = x**4 - 1 + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I)*(z+I)*(z-1)*(z+1)*(z) + f = x**5 - x + assert R.dup_count_complex_roots(f, a, b) == 5 + assert R.dup_count_complex_roots(f, c, d) == 3 + + # (z-I)*(z+I)*(z-1)*(z+1)*(-z) + f = -x**5 + x + assert R.dup_count_complex_roots(f, a, b) == 5 + assert R.dup_count_complex_roots(f, c, d) == 3 + + +def test_dup_count_complex_roots_5(): + R, x = ring("x", ZZ) + + # (z-I+1)*(z+I+1) + f = x**2 + 2*x + 2 + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 0 + + # (z-I+1)*(z+I+1)*(z-1) + f = x**3 + x**2 - 2 + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I+1)*(z+I+1)*(z-1)*z + f = x**4 + x**3 - 2*x + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I+1)*(z+I+1)*(z+1) + f = x**3 + 3*x**2 + 4*x + 2 + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 0 + + # (z-I+1)*(z+I+1)*(z+1)*z + f = x**4 + 3*x**3 + 4*x**2 + 2*x + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I+1)*(z+I+1)*(z-1)*(z+1) + f = x**4 + 2*x**3 + x**2 - 2*x - 2 + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I+1)*(z+I+1)*(z-1)*(z+1)*z + f = x**5 + 2*x**4 + x**3 - 2*x**2 - 2*x + assert R.dup_count_complex_roots(f, a, b) == 5 + assert R.dup_count_complex_roots(f, c, d) == 2 + + +def test_dup_count_complex_roots_6(): + R, x = ring("x", ZZ) + + # (z-I-1)*(z+I-1) + f = x**2 - 2*x + 2 + assert R.dup_count_complex_roots(f, a, b) == 2 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I-1)*(z+I-1)*(z-1) + f = x**3 - 3*x**2 + 4*x - 2 + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I-1)*(z+I-1)*(z-1)*z + f = x**4 - 3*x**3 + 4*x**2 - 2*x + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 3 + + # (z-I-1)*(z+I-1)*(z+1) + f = x**3 - x**2 + 2 + assert R.dup_count_complex_roots(f, a, b) == 3 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I-1)*(z+I-1)*(z+1)*z + f = x**4 - x**3 + 2*x + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I-1)*(z+I-1)*(z-1)*(z+1) + f = x**4 - 2*x**3 + x**2 + 2*x - 2 + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I-1)*(z+I-1)*(z-1)*(z+1)*z + f = x**5 - 2*x**4 + x**3 + 2*x**2 - 2*x + assert R.dup_count_complex_roots(f, a, b) == 5 + assert R.dup_count_complex_roots(f, c, d) == 3 + + +def test_dup_count_complex_roots_7(): + R, x = ring("x", ZZ) + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1) + f = x**4 + 4 + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-2) + f = x**5 - 2*x**4 + 4*x - 8 + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z**2-2) + f = x**6 - 2*x**4 + 4*x**2 - 8 + assert R.dup_count_complex_roots(f, a, b) == 4 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1) + f = x**5 - x**4 + 4*x - 4 + assert R.dup_count_complex_roots(f, a, b) == 5 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1)*z + f = x**6 - x**5 + 4*x**2 - 4*x + assert R.dup_count_complex_roots(f, a, b) == 6 + assert R.dup_count_complex_roots(f, c, d) == 3 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z+1) + f = x**5 + x**4 + 4*x + 4 + assert R.dup_count_complex_roots(f, a, b) == 5 + assert R.dup_count_complex_roots(f, c, d) == 1 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z+1)*z + f = x**6 + x**5 + 4*x**2 + 4*x + assert R.dup_count_complex_roots(f, a, b) == 6 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1)*(z+1) + f = x**6 - x**4 + 4*x**2 - 4 + assert R.dup_count_complex_roots(f, a, b) == 6 + assert R.dup_count_complex_roots(f, c, d) == 2 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1)*(z+1)*z + f = x**7 - x**5 + 4*x**3 - 4*x + assert R.dup_count_complex_roots(f, a, b) == 7 + assert R.dup_count_complex_roots(f, c, d) == 3 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1)*(z+1)*(z-I)*(z+I) + f = x**8 + 3*x**4 - 4 + assert R.dup_count_complex_roots(f, a, b) == 8 + assert R.dup_count_complex_roots(f, c, d) == 3 + + +def test_dup_count_complex_roots_8(): + R, x = ring("x", ZZ) + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1)*(z+1)*(z-I)*(z+I)*z + f = x**9 + 3*x**5 - 4*x + assert R.dup_count_complex_roots(f, a, b) == 9 + assert R.dup_count_complex_roots(f, c, d) == 4 + + # (z-I-1)*(z+I-1)*(z-I+1)*(z+I+1)*(z-1)*(z+1)*(z-I)*(z+I)*(z**2-2)*z + f = x**11 - 2*x**9 + 3*x**7 - 6*x**5 - 4*x**3 + 8*x + assert R.dup_count_complex_roots(f, a, b) == 9 + assert R.dup_count_complex_roots(f, c, d) == 4 + + +def test_dup_count_complex_roots_implicit(): + R, x = ring("x", ZZ) + + # z*(z-1)*(z+1)*(z-I)*(z+I) + f = x**5 - x + + assert R.dup_count_complex_roots(f) == 5 + + assert R.dup_count_complex_roots(f, sup=(0, 0)) == 3 + assert R.dup_count_complex_roots(f, inf=(0, 0)) == 3 + + +def test_dup_count_complex_roots_exclude(): + R, x = ring("x", ZZ) + + # z*(z-1)*(z+1)*(z-I)*(z+I) + f = x**5 - x + + a, b = (-QQ(1), QQ(0)), (QQ(1), QQ(1)) + + assert R.dup_count_complex_roots(f, a, b) == 4 + + assert R.dup_count_complex_roots(f, a, b, exclude=['S']) == 3 + assert R.dup_count_complex_roots(f, a, b, exclude=['N']) == 3 + + assert R.dup_count_complex_roots(f, a, b, exclude=['S', 'N']) == 2 + + assert R.dup_count_complex_roots(f, a, b, exclude=['E']) == 4 + assert R.dup_count_complex_roots(f, a, b, exclude=['W']) == 4 + + assert R.dup_count_complex_roots(f, a, b, exclude=['E', 'W']) == 4 + + assert R.dup_count_complex_roots(f, a, b, exclude=['N', 'S', 'E', 'W']) == 2 + + assert R.dup_count_complex_roots(f, a, b, exclude=['SW']) == 3 + assert R.dup_count_complex_roots(f, a, b, exclude=['SE']) == 3 + + assert R.dup_count_complex_roots(f, a, b, exclude=['SW', 'SE']) == 2 + assert R.dup_count_complex_roots(f, a, b, exclude=['SW', 'SE', 'S']) == 1 + assert R.dup_count_complex_roots(f, a, b, exclude=['SW', 'SE', 'S', 'N']) == 0 + + a, b = (QQ(0), QQ(0)), (QQ(1), QQ(1)) + + assert R.dup_count_complex_roots(f, a, b, exclude=True) == 1 + + +def test_dup_isolate_complex_roots_sqf(): + R, x = ring("x", ZZ) + f = x**2 - 2*x + 3 + + assert R.dup_isolate_complex_roots_sqf(f) == \ + [((0, -6), (6, 0)), ((0, 0), (6, 6))] + assert [ r.as_tuple() for r in R.dup_isolate_complex_roots_sqf(f, blackbox=True) ] == \ + [((0, -6), (6, 0)), ((0, 0), (6, 6))] + + assert R.dup_isolate_complex_roots_sqf(f, eps=QQ(1, 10)) == \ + [((QQ(15, 16), -QQ(3, 2)), (QQ(33, 32), -QQ(45, 32))), + ((QQ(15, 16), QQ(45, 32)), (QQ(33, 32), QQ(3, 2)))] + assert R.dup_isolate_complex_roots_sqf(f, eps=QQ(1, 100)) == \ + [((QQ(255, 256), -QQ(363, 256)), (QQ(513, 512), -QQ(723, 512))), + ((QQ(255, 256), QQ(723, 512)), (QQ(513, 512), QQ(363, 256)))] + + f = 7*x**4 - 19*x**3 + 20*x**2 + 17*x + 20 + + assert R.dup_isolate_complex_roots_sqf(f) == \ + [((-QQ(40, 7), -QQ(40, 7)), (0, 0)), ((-QQ(40, 7), 0), (0, QQ(40, 7))), + ((0, -QQ(40, 7)), (QQ(40, 7), 0)), ((0, 0), (QQ(40, 7), QQ(40, 7)))] + + +def test_dup_isolate_all_roots_sqf(): + R, x = ring("x", ZZ) + f = 4*x**4 - x**3 + 2*x**2 + 5*x + + assert R.dup_isolate_all_roots_sqf(f) == \ + ([(-1, 0), (0, 0)], + [((0, -QQ(5, 2)), (QQ(5, 2), 0)), ((0, 0), (QQ(5, 2), QQ(5, 2)))]) + + assert R.dup_isolate_all_roots_sqf(f, eps=QQ(1, 10)) == \ + ([(QQ(-7, 8), QQ(-6, 7)), (0, 0)], + [((QQ(35, 64), -QQ(35, 32)), (QQ(5, 8), -QQ(65, 64))), ((QQ(35, 64), QQ(65, 64)), (QQ(5, 8), QQ(35, 32)))]) + + +def test_dup_isolate_all_roots(): + R, x = ring("x", ZZ) + f = 4*x**4 - x**3 + 2*x**2 + 5*x + + assert R.dup_isolate_all_roots(f) == \ + ([((-1, 0), 1), ((0, 0), 1)], + [(((0, -QQ(5, 2)), (QQ(5, 2), 0)), 1), + (((0, 0), (QQ(5, 2), QQ(5, 2))), 1)]) + + assert R.dup_isolate_all_roots(f, eps=QQ(1, 10)) == \ + ([((QQ(-7, 8), QQ(-6, 7)), 1), ((0, 0), 1)], + [(((QQ(35, 64), -QQ(35, 32)), (QQ(5, 8), -QQ(65, 64))), 1), + (((QQ(35, 64), QQ(65, 64)), (QQ(5, 8), QQ(35, 32))), 1)]) + + f = x**5 + x**4 - 2*x**3 - 2*x**2 + x + 1 + raises(NotImplementedError, lambda: R.dup_isolate_all_roots(f)) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_rootoftools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_rootoftools.py new file mode 100644 index 0000000000000000000000000000000000000000..de9dbcabd0a7e2bed0c5adb7127041b4be058379 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_rootoftools.py @@ -0,0 +1,697 @@ +"""Tests for the implementation of RootOf class and related tools. """ + +from sympy.polys.polytools import Poly +import sympy.polys.rootoftools as rootoftools +from sympy.polys.rootoftools import (rootof, RootOf, CRootOf, RootSum, + _pure_key_dict as D) + +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + GeneratorsNeeded, + PolynomialError, +) + +from sympy.core.function import (Function, Lambda) +from sympy.core.numbers import (Float, I, Rational) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import tan +from sympy.integrals.integrals import Integral +from sympy.polys.orthopolys import legendre_poly +from sympy.solvers.solvers import solve + + +from sympy.testing.pytest import raises, slow +from sympy.core.expr import unchanged + +from sympy.abc import a, b, x, y, z, r + + +def test_CRootOf___new__(): + assert rootof(x, 0) == 0 + assert rootof(x, -1) == 0 + + assert rootof(x, S.Zero) == 0 + + assert rootof(x - 1, 0) == 1 + assert rootof(x - 1, -1) == 1 + + assert rootof(x + 1, 0) == -1 + assert rootof(x + 1, -1) == -1 + + assert rootof(x**2 + 2*x + 3, 0) == -1 - I*sqrt(2) + assert rootof(x**2 + 2*x + 3, 1) == -1 + I*sqrt(2) + assert rootof(x**2 + 2*x + 3, -1) == -1 + I*sqrt(2) + assert rootof(x**2 + 2*x + 3, -2) == -1 - I*sqrt(2) + + r = rootof(x**2 + 2*x + 3, 0, radicals=False) + assert isinstance(r, RootOf) is True + + r = rootof(x**2 + 2*x + 3, 1, radicals=False) + assert isinstance(r, RootOf) is True + + r = rootof(x**2 + 2*x + 3, -1, radicals=False) + assert isinstance(r, RootOf) is True + + r = rootof(x**2 + 2*x + 3, -2, radicals=False) + assert isinstance(r, RootOf) is True + + assert rootof((x - 1)*(x + 1), 0, radicals=False) == -1 + assert rootof((x - 1)*(x + 1), 1, radicals=False) == 1 + assert rootof((x - 1)*(x + 1), -1, radicals=False) == 1 + assert rootof((x - 1)*(x + 1), -2, radicals=False) == -1 + + assert rootof((x - 1)*(x + 1), 0, radicals=True) == -1 + assert rootof((x - 1)*(x + 1), 1, radicals=True) == 1 + assert rootof((x - 1)*(x + 1), -1, radicals=True) == 1 + assert rootof((x - 1)*(x + 1), -2, radicals=True) == -1 + + assert rootof((x - 1)*(x**3 + x + 3), 0) == rootof(x**3 + x + 3, 0) + assert rootof((x - 1)*(x**3 + x + 3), 1) == 1 + assert rootof((x - 1)*(x**3 + x + 3), 2) == rootof(x**3 + x + 3, 1) + assert rootof((x - 1)*(x**3 + x + 3), 3) == rootof(x**3 + x + 3, 2) + assert rootof((x - 1)*(x**3 + x + 3), -1) == rootof(x**3 + x + 3, 2) + assert rootof((x - 1)*(x**3 + x + 3), -2) == rootof(x**3 + x + 3, 1) + assert rootof((x - 1)*(x**3 + x + 3), -3) == 1 + assert rootof((x - 1)*(x**3 + x + 3), -4) == rootof(x**3 + x + 3, 0) + + assert rootof(x**4 + 3*x**3, 0) == -3 + assert rootof(x**4 + 3*x**3, 1) == 0 + assert rootof(x**4 + 3*x**3, 2) == 0 + assert rootof(x**4 + 3*x**3, 3) == 0 + + raises(GeneratorsNeeded, lambda: rootof(0, 0)) + raises(GeneratorsNeeded, lambda: rootof(1, 0)) + + raises(PolynomialError, lambda: rootof(Poly(0, x), 0)) + raises(PolynomialError, lambda: rootof(Poly(1, x), 0)) + raises(PolynomialError, lambda: rootof(x - y, 0)) + # issue 8617 + raises(PolynomialError, lambda: rootof(exp(x), 0)) + + raises(NotImplementedError, lambda: rootof(x**3 - x + sqrt(2), 0)) + raises(NotImplementedError, lambda: rootof(x**3 - x + I, 0)) + + raises(IndexError, lambda: rootof(x**2 - 1, -4)) + raises(IndexError, lambda: rootof(x**2 - 1, -3)) + raises(IndexError, lambda: rootof(x**2 - 1, 2)) + raises(IndexError, lambda: rootof(x**2 - 1, 3)) + raises(ValueError, lambda: rootof(x**2 - 1, x)) + + assert rootof(Poly(x - y, x), 0) == y + + assert rootof(Poly(x**2 - y, x), 0) == -sqrt(y) + assert rootof(Poly(x**2 - y, x), 1) == sqrt(y) + + assert rootof(Poly(x**3 - y, x), 0) == y**Rational(1, 3) + + assert rootof(y*x**3 + y*x + 2*y, x, 0) == -1 + raises(NotImplementedError, lambda: rootof(x**3 + x + 2*y, x, 0)) + + assert rootof(x**3 + x + 1, 0).is_commutative is True + + +def test_CRootOf_attributes(): + r = rootof(x**3 + x + 3, 0) + assert r.is_number + assert r.free_symbols == set() + # if the following assertion fails then multivariate polynomials + # are apparently supported and the RootOf.free_symbols routine + # should be changed to return whatever symbols would not be + # the PurePoly dummy symbol + raises(NotImplementedError, lambda: rootof(Poly(x**3 + y*x + 1, x), 0)) + + +def test_CRootOf___eq__(): + assert (rootof(x**3 + x + 3, 0) == rootof(x**3 + x + 3, 0)) is True + assert (rootof(x**3 + x + 3, 0) == rootof(x**3 + x + 3, 1)) is False + assert (rootof(x**3 + x + 3, 1) == rootof(x**3 + x + 3, 1)) is True + assert (rootof(x**3 + x + 3, 1) == rootof(x**3 + x + 3, 2)) is False + assert (rootof(x**3 + x + 3, 2) == rootof(x**3 + x + 3, 2)) is True + + assert (rootof(x**3 + x + 3, 0) == rootof(y**3 + y + 3, 0)) is True + assert (rootof(x**3 + x + 3, 0) == rootof(y**3 + y + 3, 1)) is False + assert (rootof(x**3 + x + 3, 1) == rootof(y**3 + y + 3, 1)) is True + assert (rootof(x**3 + x + 3, 1) == rootof(y**3 + y + 3, 2)) is False + assert (rootof(x**3 + x + 3, 2) == rootof(y**3 + y + 3, 2)) is True + + +def test_CRootOf___eval_Eq__(): + f = Function('f') + eq = x**3 + x + 3 + r = rootof(eq, 2) + r1 = rootof(eq, 1) + assert Eq(r, r1) is S.false + assert Eq(r, r) is S.true + assert unchanged(Eq, r, x) + assert Eq(r, 0) is S.false + assert Eq(r, S.Infinity) is S.false + assert Eq(r, I) is S.false + assert unchanged(Eq, r, f(0)) + sol = solve(eq) + for s in sol: + if s.is_real: + assert Eq(r, s) is S.false + r = rootof(eq, 0) + for s in sol: + if s.is_real: + assert Eq(r, s) is S.true + eq = x**3 + x + 1 + sol = solve(eq) + assert [Eq(rootof(eq, i), j) for i in range(3) for j in sol + ].count(True) == 3 + assert Eq(rootof(eq, 0), 1 + S.ImaginaryUnit) == False + + +def test_CRootOf_is_real(): + assert rootof(x**3 + x + 3, 0).is_real is True + assert rootof(x**3 + x + 3, 1).is_real is False + assert rootof(x**3 + x + 3, 2).is_real is False + + +def test_CRootOf_is_complex(): + assert rootof(x**3 + x + 3, 0).is_complex is True + + +def test_CRootOf_is_algebraic(): + assert rootof(x**3 + x + 3, 0).is_algebraic is True + assert rootof(x**3 + x + 3, 1).is_algebraic is True + assert rootof(x**3 + x + 3, 2).is_algebraic is True + + +def test_CRootOf_subs(): + assert rootof(x**3 + x + 1, 0).subs(x, y) == rootof(y**3 + y + 1, 0) + + +def test_CRootOf_diff(): + assert rootof(x**3 + x + 1, 0).diff(x) == 0 + assert rootof(x**3 + x + 1, 0).diff(y) == 0 + +@slow +def test_CRootOf_evalf(): + real = rootof(x**3 + x + 3, 0).evalf(n=20) + + assert real.epsilon_eq(Float("-1.2134116627622296341")) + + re, im = rootof(x**3 + x + 3, 1).evalf(n=20).as_real_imag() + + assert re.epsilon_eq( Float("0.60670583138111481707")) + assert im.epsilon_eq(-Float("1.45061224918844152650")) + + re, im = rootof(x**3 + x + 3, 2).evalf(n=20).as_real_imag() + + assert re.epsilon_eq(Float("0.60670583138111481707")) + assert im.epsilon_eq(Float("1.45061224918844152650")) + + p = legendre_poly(4, x, polys=True) + roots = [str(r.n(17)) for r in p.real_roots()] + # magnitudes are given by + # sqrt(3/S(7) - 2*sqrt(6/S(5))/7) + # and + # sqrt(3/S(7) + 2*sqrt(6/S(5))/7) + assert roots == [ + "-0.86113631159405258", + "-0.33998104358485626", + "0.33998104358485626", + "0.86113631159405258", + ] + + re = rootof(x**5 - 5*x + 12, 0).evalf(n=20) + assert re.epsilon_eq(Float("-1.84208596619025438271")) + + re, im = rootof(x**5 - 5*x + 12, 1).evalf(n=20).as_real_imag() + assert re.epsilon_eq(Float("-0.351854240827371999559")) + assert im.epsilon_eq(Float("-1.709561043370328882010")) + + re, im = rootof(x**5 - 5*x + 12, 2).evalf(n=20).as_real_imag() + assert re.epsilon_eq(Float("-0.351854240827371999559")) + assert im.epsilon_eq(Float("+1.709561043370328882010")) + + re, im = rootof(x**5 - 5*x + 12, 3).evalf(n=20).as_real_imag() + assert re.epsilon_eq(Float("+1.272897223922499190910")) + assert im.epsilon_eq(Float("-0.719798681483861386681")) + + re, im = rootof(x**5 - 5*x + 12, 4).evalf(n=20).as_real_imag() + assert re.epsilon_eq(Float("+1.272897223922499190910")) + assert im.epsilon_eq(Float("+0.719798681483861386681")) + + # issue 6393 + assert str(rootof(x**5 + 2*x**4 + x**3 - 68719476736, 0).n(3)) == '147.' + eq = (531441*x**11 + 3857868*x**10 + 13730229*x**9 + 32597882*x**8 + + 55077472*x**7 + 60452000*x**6 + 32172064*x**5 - 4383808*x**4 - + 11942912*x**3 - 1506304*x**2 + 1453312*x + 512) + a, b = rootof(eq, 1).n(2).as_real_imag() + c, d = rootof(eq, 2).n(2).as_real_imag() + assert a == c + assert b < d + assert b == -d + # issue 6451 + r = rootof(legendre_poly(64, x), 7) + assert r.n(2) == r.n(100).n(2) + # issue 9019 + r0 = rootof(x**2 + 1, 0, radicals=False) + r1 = rootof(x**2 + 1, 1, radicals=False) + assert r0.n(4) == Float(-1.0, 4) * I + assert r1.n(4) == Float(1.0, 4) * I + + # make sure verification is used in case a max/min traps the "root" + assert str(rootof(4*x**5 + 16*x**3 + 12*x**2 + 7, 0).n(3)) == '-0.976' + + # watch out for UnboundLocalError + c = CRootOf(90720*x**6 - 4032*x**4 + 84*x**2 - 1, 0) + assert c._eval_evalf(2) # doesn't fail + + # watch out for imaginary parts that don't want to evaluate + assert str(RootOf(x**16 + 32*x**14 + 508*x**12 + 5440*x**10 + + 39510*x**8 + 204320*x**6 + 755548*x**4 + 1434496*x**2 + + 877969, 10).n(2)) == '-3.4*I' + assert abs(RootOf(x**4 + 10*x**2 + 1, 0).n(2)) < 0.4 + + # check reset and args + r = [RootOf(x**3 + x + 3, i) for i in range(3)] + r[0]._reset() + for ri in r: + i = ri._get_interval() + ri.n(2) + assert i != ri._get_interval() + ri._reset() + assert i == ri._get_interval() + assert i == i.func(*i.args) + + +def test_issue_24978(): + # Irreducible poly with negative leading coeff is normalized + # (factor of -1 is extracted), before being stored as CRootOf.poly. + f = -x**2 + 2 + r = CRootOf(f, 0) + assert r.poly.as_expr() == x**2 - 2 + # An action that prompts calculation of an interval puts r.poly in + # the cache. + r.n() + assert r.poly in rootoftools._reals_cache + + +def test_CRootOf_evalf_caching_bug(): + r = rootof(x**5 - 5*x + 12, 1) + r.n() + a = r._get_interval() + r = rootof(x**5 - 5*x + 12, 1) + r.n() + b = r._get_interval() + assert a == b + + +def test_CRootOf_real_roots(): + assert Poly(x**5 + x + 1).real_roots() == [rootof(x**3 - x**2 + 1, 0)] + assert Poly(x**5 + x + 1).real_roots(radicals=False) == [rootof( + x**3 - x**2 + 1, 0)] + + # https://github.com/sympy/sympy/issues/20902 + p = Poly(-3*x**4 - 10*x**3 - 12*x**2 - 6*x - 1, x, domain='ZZ') + assert CRootOf.real_roots(p) == [S(-1), S(-1), S(-1), S(-1)/3] + + # with real algebraic coefficients + assert Poly(x**3 + sqrt(2)*x**2 - 1, x, extension=True).real_roots() == [ + rootof(x**6 - 2*x**4 - 2*x**3 + 1, 0) + ] + assert Poly(x**5 + sqrt(2) * x**3 - 1, x, extension=True).real_roots() == [ + rootof(x**10 - 2*x**6 - 2*x**5 + 1, 0) + ] + r = rootof(y**5 + y**3 - 1, 0) + assert Poly(x**5 + r*x - 1, x, extension=True).real_roots() ==\ + [ + rootof(x**25 - 5*x**20 + x**17 + 10*x**15 - 3*x**12 - + 10*x**10 + 3*x**7 + 6*x**5 - x**2 - 1, 0) + ] + # roots with multiplicity + assert Poly((x-1) * (x-sqrt(2))**2, x, extension=True).real_roots() ==\ + [ + S(1), sqrt(2), sqrt(2) + ] + + +def test_CRootOf_all_roots(): + assert Poly(x**5 + x + 1).all_roots() == [ + rootof(x**3 - x**2 + 1, 0), + Rational(-1, 2) - sqrt(3)*I/2, + Rational(-1, 2) + sqrt(3)*I/2, + rootof(x**3 - x**2 + 1, 1), + rootof(x**3 - x**2 + 1, 2), + ] + + assert Poly(x**5 + x + 1).all_roots(radicals=False) == [ + rootof(x**3 - x**2 + 1, 0), + rootof(x**2 + x + 1, 0, radicals=False), + rootof(x**2 + x + 1, 1, radicals=False), + rootof(x**3 - x**2 + 1, 1), + rootof(x**3 - x**2 + 1, 2), + ] + + # with real algebraic coefficients + assert Poly(x**3 + sqrt(2)*x**2 - 1, x, extension=True).all_roots() ==\ + [ + rootof(x**6 - 2*x**4 - 2*x**3 + 1, 0), + rootof(x**6 - 2*x**4 - 2*x**3 + 1, 2), + rootof(x**6 - 2*x**4 - 2*x**3 + 1, 3) + ] + # roots with multiplicity + assert Poly((x-1) * (x-sqrt(2))**2 * (x-I) * (x+I), x, extension=True).all_roots() ==\ + [ + S(1), sqrt(2), sqrt(2), -I, I + ] + + # imaginary algebraic coeffs (gaussian domain) + assert Poly(x**2 - I/2, x, extension=True).all_roots() ==\ + [ + S(1)/2 + I/2, + -S(1)/2 - I/2 + ] + + +def test_CRootOf_eval_rational(): + p = legendre_poly(4, x, polys=True) + roots = [r.eval_rational(n=18) for r in p.real_roots()] + for root in roots: + assert isinstance(root, Rational) + roots = [str(root.n(17)) for root in roots] + assert roots == [ + "-0.86113631159405258", + "-0.33998104358485626", + "0.33998104358485626", + "0.86113631159405258", + ] + + +def test_CRootOf_lazy(): + # irreducible poly with both real and complex roots: + f = Poly(x**3 + 2*x + 2) + + # real root: + CRootOf.clear_cache() + r = CRootOf(f, 0) + # Not yet in cache, after construction: + assert r.poly not in rootoftools._reals_cache + assert r.poly not in rootoftools._complexes_cache + r.evalf() + # In cache after evaluation: + assert r.poly in rootoftools._reals_cache + assert r.poly not in rootoftools._complexes_cache + + # complex root: + CRootOf.clear_cache() + r = CRootOf(f, 1) + # Not yet in cache, after construction: + assert r.poly not in rootoftools._reals_cache + assert r.poly not in rootoftools._complexes_cache + r.evalf() + # In cache after evaluation: + assert r.poly in rootoftools._reals_cache + assert r.poly in rootoftools._complexes_cache + + # composite poly with both real and complex roots: + f = Poly((x**2 - 2)*(x**2 + 1)) + + # real root: + CRootOf.clear_cache() + r = CRootOf(f, 0) + # In cache immediately after construction: + assert r.poly in rootoftools._reals_cache + assert r.poly not in rootoftools._complexes_cache + + # complex root: + CRootOf.clear_cache() + r = CRootOf(f, 2) + # In cache immediately after construction: + assert r.poly in rootoftools._reals_cache + assert r.poly in rootoftools._complexes_cache + + +def test_RootSum___new__(): + f = x**3 + x + 3 + + g = Lambda(r, log(r*x)) + s = RootSum(f, g) + + assert isinstance(s, RootSum) is True + + assert RootSum(f**2, g) == 2*RootSum(f, g) + assert RootSum((x - 7)*f**3, g) == log(7*x) + 3*RootSum(f, g) + + # issue 5571 + assert hash(RootSum((x - 7)*f**3, g)) == hash(log(7*x) + 3*RootSum(f, g)) + + raises(MultivariatePolynomialError, lambda: RootSum(x**3 + x + y)) + raises(ValueError, lambda: RootSum(x**2 + 3, lambda x: x)) + + assert RootSum(f, exp) == RootSum(f, Lambda(x, exp(x))) + assert RootSum(f, log) == RootSum(f, Lambda(x, log(x))) + + assert isinstance(RootSum(f, auto=False), RootSum) is True + + assert RootSum(f) == 0 + assert RootSum(f, Lambda(x, x)) == 0 + assert RootSum(f, Lambda(x, x**2)) == -2 + + assert RootSum(f, Lambda(x, 1)) == 3 + assert RootSum(f, Lambda(x, 2)) == 6 + + assert RootSum(f, auto=False).is_commutative is True + + assert RootSum(f, Lambda(x, 1/(x + x**2))) == Rational(11, 3) + assert RootSum(f, Lambda(x, y/(x + x**2))) == Rational(11, 3)*y + + assert RootSum(x**2 - 1, Lambda(x, 3*x**2), x) == 6 + assert RootSum(x**2 - y, Lambda(x, 3*x**2), x) == 6*y + + assert RootSum(x**2 - 1, Lambda(x, z*x**2), x) == 2*z + assert RootSum(x**2 - y, Lambda(x, z*x**2), x) == 2*z*y + + assert RootSum( + x**2 - 1, Lambda(x, exp(x)), quadratic=True) == exp(-1) + exp(1) + + assert RootSum(x**3 + a*x + a**3, tan, x) == \ + RootSum(x**3 + x + 1, Lambda(x, tan(a*x))) + assert RootSum(a**3*x**3 + a*x + 1, tan, x) == \ + RootSum(x**3 + x + 1, Lambda(x, tan(x/a))) + + +def test_RootSum_free_symbols(): + assert RootSum(x**3 + x + 3, Lambda(r, exp(r))).free_symbols == set() + assert RootSum(x**3 + x + 3, Lambda(r, exp(a*r))).free_symbols == {a} + assert RootSum( + x**3 + x + y, Lambda(r, exp(a*r)), x).free_symbols == {a, y} + + +def test_RootSum___eq__(): + f = Lambda(x, exp(x)) + + assert (RootSum(x**3 + x + 1, f) == RootSum(x**3 + x + 1, f)) is True + assert (RootSum(x**3 + x + 1, f) == RootSum(y**3 + y + 1, f)) is True + + assert (RootSum(x**3 + x + 1, f) == RootSum(x**3 + x + 2, f)) is False + assert (RootSum(x**3 + x + 1, f) == RootSum(y**3 + y + 2, f)) is False + + +def test_RootSum_doit(): + rs = RootSum(x**2 + 1, exp) + + assert isinstance(rs, RootSum) is True + assert rs.doit() == exp(-I) + exp(I) + + rs = RootSum(x**2 + a, exp, x) + + assert isinstance(rs, RootSum) is True + assert rs.doit() == exp(-sqrt(-a)) + exp(sqrt(-a)) + + +def test_RootSum_evalf(): + rs = RootSum(x**2 + 1, exp) + + assert rs.evalf(n=20, chop=True).epsilon_eq(Float("1.0806046117362794348")) + assert rs.evalf(n=15, chop=True).epsilon_eq(Float("1.08060461173628")) + + rs = RootSum(x**2 + a, exp, x) + + assert rs.evalf() == rs + + +def test_RootSum_diff(): + f = x**3 + x + 3 + + g = Lambda(r, exp(r*x)) + h = Lambda(r, r*exp(r*x)) + + assert RootSum(f, g).diff(x) == RootSum(f, h) + + +def test_RootSum_subs(): + f = x**3 + x + 3 + g = Lambda(r, exp(r*x)) + + F = y**3 + y + 3 + G = Lambda(r, exp(r*y)) + + assert RootSum(f, g).subs(y, 1) == RootSum(f, g) + assert RootSum(f, g).subs(x, y) == RootSum(F, G) + + +def test_RootSum_rational(): + assert RootSum( + z**5 - z + 1, Lambda(z, z/(x - z))) == (4*x - 5)/(x**5 - x + 1) + + f = 161*z**3 + 115*z**2 + 19*z + 1 + g = Lambda(z, z*log( + -3381*z**4/4 - 3381*z**3/4 - 625*z**2/2 - z*Rational(125, 2) - 5 + exp(x))) + + assert RootSum(f, g).diff(x) == -( + (5*exp(2*x) - 6*exp(x) + 4)*exp(x)/(exp(3*x) - exp(2*x) + 1))/7 + + +def test_RootSum_independent(): + f = (x**3 - a)**2*(x**4 - b)**3 + + g = Lambda(x, 5*tan(x) + 7) + h = Lambda(x, tan(x)) + + r0 = RootSum(x**3 - a, h, x) + r1 = RootSum(x**4 - b, h, x) + + assert RootSum(f, g, x).as_ordered_terms() == [10*r0, 15*r1, 126] + + +def test_issue_7876(): + l1 = Poly(x**6 - x + 1, x).all_roots() + l2 = [rootof(x**6 - x + 1, i) for i in range(6)] + assert frozenset(l1) == frozenset(l2) + + +def test_issue_8316(): + f = Poly(7*x**8 - 9) + assert len(f.all_roots()) == 8 + f = Poly(7*x**8 - 10) + assert len(f.all_roots()) == 8 + + +def test__imag_count(): + from sympy.polys.rootoftools import _imag_count_of_factor + def imag_count(p): + return sum(_imag_count_of_factor(f)*m for f, m in + p.factor_list()[1]) + assert imag_count(Poly(x**6 + 10*x**2 + 1)) == 2 + assert imag_count(Poly(x**2)) == 0 + assert imag_count(Poly([1]*3 + [-1], x)) == 0 + assert imag_count(Poly(x**3 + 1)) == 0 + assert imag_count(Poly(x**2 + 1)) == 2 + assert imag_count(Poly(x**2 - 1)) == 0 + assert imag_count(Poly(x**4 - 1)) == 2 + assert imag_count(Poly(x**4 + 1)) == 0 + assert imag_count(Poly([1, 2, 3], x)) == 0 + assert imag_count(Poly(x**3 + x + 1)) == 0 + assert imag_count(Poly(x**4 + x + 1)) == 0 + def q(r1, r2, p): + return Poly(((x - r1)*(x - r2)).subs(x, x**p), x) + assert imag_count(q(-1, -2, 2)) == 4 + assert imag_count(q(-1, 2, 2)) == 2 + assert imag_count(q(1, 2, 2)) == 0 + assert imag_count(q(1, 2, 4)) == 4 + assert imag_count(q(-1, 2, 4)) == 2 + assert imag_count(q(-1, -2, 4)) == 0 + + +def test_RootOf_is_imaginary(): + r = RootOf(x**4 + 4*x**2 + 1, 1) + i = r._get_interval() + assert r.is_imaginary and i.ax*i.bx <= 0 + + +def test_is_disjoint(): + eq = x**3 + 5*x + 1 + ir = rootof(eq, 0)._get_interval() + ii = rootof(eq, 1)._get_interval() + assert ir.is_disjoint(ii) + assert ii.is_disjoint(ir) + + +def test_pure_key_dict(): + p = D() + assert (x in p) is False + assert (1 in p) is False + p[x] = 1 + assert x in p + assert y in p + assert p[y] == 1 + raises(KeyError, lambda: p[1]) + def dont(k): + p[k] = 2 + raises(ValueError, lambda: dont(1)) + + +@slow +def test_eval_approx_relative(): + CRootOf.clear_cache() + t = [CRootOf(x**3 + 10*x + 1, i) for i in range(3)] + assert [i.eval_rational(1e-1) for i in t] == [ + Rational(-21, 220), Rational(15, 256) - I*805/256, + Rational(15, 256) + I*805/256] + t[0]._reset() + assert [i.eval_rational(1e-1, 1e-4) for i in t] == [ + Rational(-21, 220), Rational(3275, 65536) - I*414645/131072, + Rational(3275, 65536) + I*414645/131072] + assert S(t[0]._get_interval().dx) < 1e-1 + assert S(t[1]._get_interval().dx) < 1e-1 + assert S(t[1]._get_interval().dy) < 1e-4 + assert S(t[2]._get_interval().dx) < 1e-1 + assert S(t[2]._get_interval().dy) < 1e-4 + t[0]._reset() + assert [i.eval_rational(1e-4, 1e-4) for i in t] == [ + Rational(-2001, 20020), Rational(6545, 131072) - I*414645/131072, + Rational(6545, 131072) + I*414645/131072] + assert S(t[0]._get_interval().dx) < 1e-4 + assert S(t[1]._get_interval().dx) < 1e-4 + assert S(t[1]._get_interval().dy) < 1e-4 + assert S(t[2]._get_interval().dx) < 1e-4 + assert S(t[2]._get_interval().dy) < 1e-4 + # in the following, the actual relative precision is + # less than tested, but it should never be greater + t[0]._reset() + assert [i.eval_rational(n=2) for i in t] == [ + Rational(-202201, 2024022), Rational(104755, 2097152) - I*6634255/2097152, + Rational(104755, 2097152) + I*6634255/2097152] + assert abs(S(t[0]._get_interval().dx)/t[0]) < 1e-2 + assert abs(S(t[1]._get_interval().dx)/t[1]).n() < 1e-2 + assert abs(S(t[1]._get_interval().dy)/t[1]).n() < 1e-2 + assert abs(S(t[2]._get_interval().dx)/t[2]).n() < 1e-2 + assert abs(S(t[2]._get_interval().dy)/t[2]).n() < 1e-2 + t[0]._reset() + assert [i.eval_rational(n=3) for i in t] == [ + Rational(-202201, 2024022), Rational(1676045, 33554432) - I*106148135/33554432, + Rational(1676045, 33554432) + I*106148135/33554432] + assert abs(S(t[0]._get_interval().dx)/t[0]) < 1e-3 + assert abs(S(t[1]._get_interval().dx)/t[1]).n() < 1e-3 + assert abs(S(t[1]._get_interval().dy)/t[1]).n() < 1e-3 + assert abs(S(t[2]._get_interval().dx)/t[2]).n() < 1e-3 + assert abs(S(t[2]._get_interval().dy)/t[2]).n() < 1e-3 + + t[0]._reset() + a = [i.eval_approx(2) for i in t] + assert [str(i) for i in a] == [ + '-0.10', '0.05 - 3.2*I', '0.05 + 3.2*I'] + assert all(abs(((a[i] - t[i])/t[i]).n()) < 1e-2 for i in range(len(a))) + + +def test_issue_15920(): + r = rootof(x**5 - x + 1, 0) + p = Integral(x, (x, 1, y)) + assert unchanged(Eq, r, p) + + +def test_issue_19113(): + eq = y**3 - y + 1 + # generator is a canonical x in RootOf + assert str(Poly(eq).real_roots()) == '[CRootOf(x**3 - x + 1, 0)]' + assert str(Poly(eq.subs(y, tan(y))).real_roots() + ) == '[CRootOf(x**3 - x + 1, 0)]' + assert str(Poly(eq.subs(y, tan(x))).real_roots() + ) == '[CRootOf(x**3 - x + 1, 0)]' diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_solvers.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8708314466b6a8676ba1a4438eb84924d0030c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_solvers.py @@ -0,0 +1,112 @@ +"""Tests for low-level linear systems solver. """ + +from sympy.matrices import Matrix +from sympy.polys.domains import ZZ, QQ +from sympy.polys.fields import field +from sympy.polys.rings import ring +from sympy.polys.solvers import solve_lin_sys, eqs_to_matrix + + +def test_solve_lin_sys_2x2_one(): + domain, x1,x2 = ring("x1,x2", QQ) + eqs = [x1 + x2 - 5, + 2*x1 - x2] + sol = {x1: QQ(5, 3), x2: QQ(10, 3)} + _sol = solve_lin_sys(eqs, domain) + assert _sol == sol and all(s.ring == domain for s in _sol) + +def test_solve_lin_sys_2x4_none(): + domain, x1,x2 = ring("x1,x2", QQ) + eqs = [x1 - 1, + x1 - x2, + x1 - 2*x2, + x2 - 1] + assert solve_lin_sys(eqs, domain) is None + + +def test_solve_lin_sys_3x4_one(): + domain, x1,x2,x3 = ring("x1,x2,x3", QQ) + eqs = [x1 + 2*x2 + 3*x3, + 2*x1 - x2 + x3, + 3*x1 + x2 + x3, + 5*x2 + 2*x3] + sol = {x1: 0, x2: 0, x3: 0} + assert solve_lin_sys(eqs, domain) == sol + +def test_solve_lin_sys_3x3_inf(): + domain, x1,x2,x3 = ring("x1,x2,x3", QQ) + eqs = [x1 - x2 + 2*x3 - 1, + 2*x1 + x2 + x3 - 8, + x1 + x2 - 5] + sol = {x1: -x3 + 3, x2: x3 + 2} + assert solve_lin_sys(eqs, domain) == sol + +def test_solve_lin_sys_3x4_none(): + domain, x1,x2,x3,x4 = ring("x1,x2,x3,x4", QQ) + eqs = [2*x1 + x2 + 7*x3 - 7*x4 - 2, + -3*x1 + 4*x2 - 5*x3 - 6*x4 - 3, + x1 + x2 + 4*x3 - 5*x4 - 2] + assert solve_lin_sys(eqs, domain) is None + + +def test_solve_lin_sys_4x7_inf(): + domain, x1,x2,x3,x4,x5,x6,x7 = ring("x1,x2,x3,x4,x5,x6,x7", QQ) + eqs = [x1 + 4*x2 - x4 + 7*x6 - 9*x7 - 3, + 2*x1 + 8*x2 - x3 + 3*x4 + 9*x5 - 13*x6 + 7*x7 - 9, + 2*x3 - 3*x4 - 4*x5 + 12*x6 - 8*x7 - 1, + -x1 - 4*x2 + 2*x3 + 4*x4 + 8*x5 - 31*x6 + 37*x7 - 4] + sol = {x1: 4 - 4*x2 - 2*x5 - x6 + 3*x7, + x3: 2 - x5 + 3*x6 - 5*x7, + x4: 1 - 2*x5 + 6*x6 - 6*x7} + assert solve_lin_sys(eqs, domain) == sol + +def test_solve_lin_sys_5x5_inf(): + domain, x1,x2,x3,x4,x5 = ring("x1,x2,x3,x4,x5", QQ) + eqs = [x1 - x2 - 2*x3 + x4 + 11*x5 - 13, + x1 - x2 + x3 + x4 + 5*x5 - 16, + 2*x1 - 2*x2 + x4 + 10*x5 - 21, + 2*x1 - 2*x2 - x3 + 3*x4 + 20*x5 - 38, + 2*x1 - 2*x2 + x3 + x4 + 8*x5 - 22] + sol = {x1: 6 + x2 - 3*x5, + x3: 1 + 2*x5, + x4: 9 - 4*x5} + assert solve_lin_sys(eqs, domain) == sol + +def test_solve_lin_sys_6x6_1(): + ground, d,r,e,g,i,j,l,o,m,p,q = field("d,r,e,g,i,j,l,o,m,p,q", ZZ) + domain, c,f,h,k,n,b = ring("c,f,h,k,n,b", ground) + + eqs = [b + q/d - c/d, c*(1/d + 1/e + 1/g) - f/g - q/d, f*(1/g + 1/i + 1/j) - c/g - h/i, h*(1/i + 1/l + 1/m) - f/i - k/m, k*(1/m + 1/o + 1/p) - h/m - n/p, n/p - k/p] + sol = { + b: (e*i*l*q + e*i*m*q + e*i*o*q + e*j*l*q + e*j*m*q + e*j*o*q + e*l*m*q + e*l*o*q + g*i*l*q + g*i*m*q + g*i*o*q + g*j*l*q + g*j*m*q + g*j*o*q + g*l*m*q + g*l*o*q + i*j*l*q + i*j*m*q + i*j*o*q + j*l*m*q + j*l*o*q)/(-d*e*i*l - d*e*i*m - d*e*i*o - d*e*j*l - d*e*j*m - d*e*j*o - d*e*l*m - d*e*l*o - d*g*i*l - d*g*i*m - d*g*i*o - d*g*j*l - d*g*j*m - d*g*j*o - d*g*l*m - d*g*l*o - d*i*j*l - d*i*j*m - d*i*j*o - d*j*l*m - d*j*l*o - e*g*i*l - e*g*i*m - e*g*i*o - e*g*j*l - e*g*j*m - e*g*j*o - e*g*l*m - e*g*l*o - e*i*j*l - e*i*j*m - e*i*j*o - e*j*l*m - e*j*l*o), + c: (-e*g*i*l*q - e*g*i*m*q - e*g*i*o*q - e*g*j*l*q - e*g*j*m*q - e*g*j*o*q - e*g*l*m*q - e*g*l*o*q - e*i*j*l*q - e*i*j*m*q - e*i*j*o*q - e*j*l*m*q - e*j*l*o*q)/(-d*e*i*l - d*e*i*m - d*e*i*o - d*e*j*l - d*e*j*m - d*e*j*o - d*e*l*m - d*e*l*o - d*g*i*l - d*g*i*m - d*g*i*o - d*g*j*l - d*g*j*m - d*g*j*o - d*g*l*m - d*g*l*o - d*i*j*l - d*i*j*m - d*i*j*o - d*j*l*m - d*j*l*o - e*g*i*l - e*g*i*m - e*g*i*o - e*g*j*l - e*g*j*m - e*g*j*o - e*g*l*m - e*g*l*o - e*i*j*l - e*i*j*m - e*i*j*o - e*j*l*m - e*j*l*o), + f: (-e*i*j*l*q - e*i*j*m*q - e*i*j*o*q - e*j*l*m*q - e*j*l*o*q)/(-d*e*i*l - d*e*i*m - d*e*i*o - d*e*j*l - d*e*j*m - d*e*j*o - d*e*l*m - d*e*l*o - d*g*i*l - d*g*i*m - d*g*i*o - d*g*j*l - d*g*j*m - d*g*j*o - d*g*l*m - d*g*l*o - d*i*j*l - d*i*j*m - d*i*j*o - d*j*l*m - d*j*l*o - e*g*i*l - e*g*i*m - e*g*i*o - e*g*j*l - e*g*j*m - e*g*j*o - e*g*l*m - e*g*l*o - e*i*j*l - e*i*j*m - e*i*j*o - e*j*l*m - e*j*l*o), + h: (-e*j*l*m*q - e*j*l*o*q)/(-d*e*i*l - d*e*i*m - d*e*i*o - d*e*j*l - d*e*j*m - d*e*j*o - d*e*l*m - d*e*l*o - d*g*i*l - d*g*i*m - d*g*i*o - d*g*j*l - d*g*j*m - d*g*j*o - d*g*l*m - d*g*l*o - d*i*j*l - d*i*j*m - d*i*j*o - d*j*l*m - d*j*l*o - e*g*i*l - e*g*i*m - e*g*i*o - e*g*j*l - e*g*j*m - e*g*j*o - e*g*l*m - e*g*l*o - e*i*j*l - e*i*j*m - e*i*j*o - e*j*l*m - e*j*l*o), + k: e*j*l*o*q/(d*e*i*l + d*e*i*m + d*e*i*o + d*e*j*l + d*e*j*m + d*e*j*o + d*e*l*m + d*e*l*o + d*g*i*l + d*g*i*m + d*g*i*o + d*g*j*l + d*g*j*m + d*g*j*o + d*g*l*m + d*g*l*o + d*i*j*l + d*i*j*m + d*i*j*o + d*j*l*m + d*j*l*o + e*g*i*l + e*g*i*m + e*g*i*o + e*g*j*l + e*g*j*m + e*g*j*o + e*g*l*m + e*g*l*o + e*i*j*l + e*i*j*m + e*i*j*o + e*j*l*m + e*j*l*o), + n: e*j*l*o*q/(d*e*i*l + d*e*i*m + d*e*i*o + d*e*j*l + d*e*j*m + d*e*j*o + d*e*l*m + d*e*l*o + d*g*i*l + d*g*i*m + d*g*i*o + d*g*j*l + d*g*j*m + d*g*j*o + d*g*l*m + d*g*l*o + d*i*j*l + d*i*j*m + d*i*j*o + d*j*l*m + d*j*l*o + e*g*i*l + e*g*i*m + e*g*i*o + e*g*j*l + e*g*j*m + e*g*j*o + e*g*l*m + e*g*l*o + e*i*j*l + e*i*j*m + e*i*j*o + e*j*l*m + e*j*l*o), + } + + assert solve_lin_sys(eqs, domain) == sol + +def test_solve_lin_sys_6x6_2(): + ground, d,r,e,g,i,j,l,o,m,p,q = field("d,r,e,g,i,j,l,o,m,p,q", ZZ) + domain, c,f,h,k,n,b = ring("c,f,h,k,n,b", ground) + + eqs = [b + r/d - c/d, c*(1/d + 1/e + 1/g) - f/g - r/d, f*(1/g + 1/i + 1/j) - c/g - h/i, h*(1/i + 1/l + 1/m) - f/i - k/m, k*(1/m + 1/o + 1/p) - h/m - n/p, n*(1/p + 1/q) - k/p] + sol = { + b: -((l*q*e*o + l*q*g*o + i*m*q*e + i*l*q*e + i*l*p*e + i*j*o*q + j*e*o*q + g*j*o*q + i*e*o*q + g*i*o*q + e*l*o*p + e*l*m*p + e*l*m*o + e*i*o*p + e*i*m*p + e*i*m*o + e*i*l*o + j*e*o*p + j*e*m*q + j*e*m*p + j*e*m*o + j*l*m*q + j*l*m*p + j*l*m*o + i*j*m*p + i*j*m*o + i*j*l*q + i*j*l*o + i*j*m*q + j*l*o*p + j*e*l*o + g*j*o*p + g*j*m*q + g*j*m*p + i*j*l*p + i*j*o*p + j*e*l*q + j*e*l*p + j*l*o*q + g*j*m*o + g*j*l*q + g*j*l*p + g*j*l*o + g*l*o*p + g*l*m*p + g*l*m*o + g*i*m*o + g*i*o*p + g*i*m*q + g*i*m*p + g*i*l*q + g*i*l*p + g*i*l*o + l*m*q*e + l*m*q*g)*r)/(l*q*d*e*o + l*q*d*g*o + l*q*e*g*o + i*j*d*o*q + i*j*e*o*q + j*d*e*o*q + g*j*d*o*q + g*j*e*o*q + g*i*e*o*q + i*d*e*o*q + g*i*d*o*q + g*i*d*o*p + g*i*d*m*q + g*i*d*m*p + g*i*d*m*o + g*i*d*l*q + g*i*d*l*p + g*i*d*l*o + g*e*l*m*p + g*e*l*o*p + g*j*e*l*q + g*e*l*m*o + g*j*e*m*p + g*j*e*m*o + d*e*l*m*p + d*e*l*m*o + i*d*e*m*p + g*j*e*l*p + g*j*e*l*o + d*e*l*o*p + i*j*d*l*o + i*j*e*o*p + i*j*e*m*q + i*j*d*m*q + i*j*d*m*p + i*j*d*m*o + i*j*d*l*q + i*j*d*l*p + i*j*e*m*p + i*j*e*m*o + i*j*e*l*q + i*j*e*l*p + i*j*e*l*o + i*d*e*m*q + i*d*e*m*o + i*d*e*l*q + i*d*e*l*p + j*d*l*o*p + j*d*e*l*o + g*j*d*o*p + g*j*d*m*q + g*j*d*m*p + g*j*d*m*o + g*j*d*l*q + g*j*d*l*p + g*j*d*l*o + g*j*e*o*p + g*j*e*m*q + g*d*l*o*p + g*d*l*m*p + g*d*l*m*o + j*d*e*m*p + i*d*e*o*p + j*e*o*q*l + j*e*o*p*l + j*e*m*q*l + j*d*e*o*p + j*d*e*m*q + i*j*d*o*p + g*i*e*o*p + j*d*e*m*o + j*d*e*l*q + j*d*e*l*p + j*e*m*p*l + j*e*m*o*l + g*i*e*m*q + g*i*e*m*p + g*i*e*m*o + g*i*e*l*q + g*i*e*l*p + g*i*e*l*o + j*d*l*o*q + j*d*l*m*q + j*d*l*m*p + j*d*l*m*o + i*d*e*l*o + l*m*q*d*e + l*m*q*d*g + l*m*q*e*g), + c: (r*e*(l*q*g*o + i*j*o*q + g*j*o*q + g*i*o*q + j*l*m*q + j*l*m*p + j*l*m*o + i*j*m*p + i*j*m*o + i*j*l*q + i*j*l*o + i*j*m*q + j*l*o*p + g*j*o*p + g*j*m*q + g*j*m*p + i*j*l*p + i*j*o*p + j*l*o*q + g*j*m*o + g*j*l*q + g*j*l*p + g*j*l*o + g*l*o*p + g*l*m*p + g*l*m*o + g*i*m*o + g*i*o*p + g*i*m*q + g*i*m*p + g*i*l*q + g*i*l*p + g*i*l*o + l*m*q*g))/(l*q*d*e*o + l*q*d*g*o + l*q*e*g*o + i*j*d*o*q + i*j*e*o*q + j*d*e*o*q + g*j*d*o*q + g*j*e*o*q + g*i*e*o*q + i*d*e*o*q + g*i*d*o*q + g*i*d*o*p + g*i*d*m*q + g*i*d*m*p + g*i*d*m*o + g*i*d*l*q + g*i*d*l*p + g*i*d*l*o + g*e*l*m*p + g*e*l*o*p + g*j*e*l*q + g*e*l*m*o + g*j*e*m*p + g*j*e*m*o + d*e*l*m*p + d*e*l*m*o + i*d*e*m*p + g*j*e*l*p + g*j*e*l*o + d*e*l*o*p + i*j*d*l*o + i*j*e*o*p + i*j*e*m*q + i*j*d*m*q + i*j*d*m*p + i*j*d*m*o + i*j*d*l*q + i*j*d*l*p + i*j*e*m*p + i*j*e*m*o + i*j*e*l*q + i*j*e*l*p + i*j*e*l*o + i*d*e*m*q + i*d*e*m*o + i*d*e*l*q + i*d*e*l*p + j*d*l*o*p + j*d*e*l*o + g*j*d*o*p + g*j*d*m*q + g*j*d*m*p + g*j*d*m*o + g*j*d*l*q + g*j*d*l*p + g*j*d*l*o + g*j*e*o*p + g*j*e*m*q + g*d*l*o*p + g*d*l*m*p + g*d*l*m*o + j*d*e*m*p + i*d*e*o*p + j*e*o*q*l + j*e*o*p*l + j*e*m*q*l + j*d*e*o*p + j*d*e*m*q + i*j*d*o*p + g*i*e*o*p + j*d*e*m*o + j*d*e*l*q + j*d*e*l*p + j*e*m*p*l + j*e*m*o*l + g*i*e*m*q + g*i*e*m*p + g*i*e*m*o + g*i*e*l*q + g*i*e*l*p + g*i*e*l*o + j*d*l*o*q + j*d*l*m*q + j*d*l*m*p + j*d*l*m*o + i*d*e*l*o + l*m*q*d*e + l*m*q*d*g + l*m*q*e*g), + f: (r*e*j*(l*q*o + l*o*p + l*m*q + l*m*p + l*m*o + i*o*q + i*o*p + i*m*q + i*m*p + i*m*o + i*l*q + i*l*p + i*l*o))/(l*q*d*e*o + l*q*d*g*o + l*q*e*g*o + i*j*d*o*q + i*j*e*o*q + j*d*e*o*q + g*j*d*o*q + g*j*e*o*q + g*i*e*o*q + i*d*e*o*q + g*i*d*o*q + g*i*d*o*p + g*i*d*m*q + g*i*d*m*p + g*i*d*m*o + g*i*d*l*q + g*i*d*l*p + g*i*d*l*o + g*e*l*m*p + g*e*l*o*p + g*j*e*l*q + g*e*l*m*o + g*j*e*m*p + g*j*e*m*o + d*e*l*m*p + d*e*l*m*o + i*d*e*m*p + g*j*e*l*p + g*j*e*l*o + d*e*l*o*p + i*j*d*l*o + i*j*e*o*p + i*j*e*m*q + i*j*d*m*q + i*j*d*m*p + i*j*d*m*o + i*j*d*l*q + i*j*d*l*p + i*j*e*m*p + i*j*e*m*o + i*j*e*l*q + i*j*e*l*p + i*j*e*l*o + i*d*e*m*q + i*d*e*m*o + i*d*e*l*q + i*d*e*l*p + j*d*l*o*p + j*d*e*l*o + g*j*d*o*p + g*j*d*m*q + g*j*d*m*p + g*j*d*m*o + g*j*d*l*q + g*j*d*l*p + g*j*d*l*o + g*j*e*o*p + g*j*e*m*q + g*d*l*o*p + g*d*l*m*p + g*d*l*m*o + j*d*e*m*p + i*d*e*o*p + j*e*o*q*l + j*e*o*p*l + j*e*m*q*l + j*d*e*o*p + j*d*e*m*q + i*j*d*o*p + g*i*e*o*p + j*d*e*m*o + j*d*e*l*q + j*d*e*l*p + j*e*m*p*l + j*e*m*o*l + g*i*e*m*q + g*i*e*m*p + g*i*e*m*o + g*i*e*l*q + g*i*e*l*p + g*i*e*l*o + j*d*l*o*q + j*d*l*m*q + j*d*l*m*p + j*d*l*m*o + i*d*e*l*o + l*m*q*d*e + l*m*q*d*g + l*m*q*e*g), + h: (j*e*r*l*(o*q + o*p + m*q + m*p + m*o))/(l*q*d*e*o + l*q*d*g*o + l*q*e*g*o + i*j*d*o*q + i*j*e*o*q + j*d*e*o*q + g*j*d*o*q + g*j*e*o*q + g*i*e*o*q + i*d*e*o*q + g*i*d*o*q + g*i*d*o*p + g*i*d*m*q + g*i*d*m*p + g*i*d*m*o + g*i*d*l*q + g*i*d*l*p + g*i*d*l*o + g*e*l*m*p + g*e*l*o*p + g*j*e*l*q + g*e*l*m*o + g*j*e*m*p + g*j*e*m*o + d*e*l*m*p + d*e*l*m*o + i*d*e*m*p + g*j*e*l*p + g*j*e*l*o + d*e*l*o*p + i*j*d*l*o + i*j*e*o*p + i*j*e*m*q + i*j*d*m*q + i*j*d*m*p + i*j*d*m*o + i*j*d*l*q + i*j*d*l*p + i*j*e*m*p + i*j*e*m*o + i*j*e*l*q + i*j*e*l*p + i*j*e*l*o + i*d*e*m*q + i*d*e*m*o + i*d*e*l*q + i*d*e*l*p + j*d*l*o*p + j*d*e*l*o + g*j*d*o*p + g*j*d*m*q + g*j*d*m*p + g*j*d*m*o + g*j*d*l*q + g*j*d*l*p + g*j*d*l*o + g*j*e*o*p + g*j*e*m*q + g*d*l*o*p + g*d*l*m*p + g*d*l*m*o + j*d*e*m*p + i*d*e*o*p + j*e*o*q*l + j*e*o*p*l + j*e*m*q*l + j*d*e*o*p + j*d*e*m*q + i*j*d*o*p + g*i*e*o*p + j*d*e*m*o + j*d*e*l*q + j*d*e*l*p + j*e*m*p*l + j*e*m*o*l + g*i*e*m*q + g*i*e*m*p + g*i*e*m*o + g*i*e*l*q + g*i*e*l*p + g*i*e*l*o + j*d*l*o*q + j*d*l*m*q + j*d*l*m*p + j*d*l*m*o + i*d*e*l*o + l*m*q*d*e + l*m*q*d*g + l*m*q*e*g), + k: (j*e*r*o*l*(q + p))/(l*q*d*e*o + l*q*d*g*o + l*q*e*g*o + i*j*d*o*q + i*j*e*o*q + j*d*e*o*q + g*j*d*o*q + g*j*e*o*q + g*i*e*o*q + i*d*e*o*q + g*i*d*o*q + g*i*d*o*p + g*i*d*m*q + g*i*d*m*p + g*i*d*m*o + g*i*d*l*q + g*i*d*l*p + g*i*d*l*o + g*e*l*m*p + g*e*l*o*p + g*j*e*l*q + g*e*l*m*o + g*j*e*m*p + g*j*e*m*o + d*e*l*m*p + d*e*l*m*o + i*d*e*m*p + g*j*e*l*p + g*j*e*l*o + d*e*l*o*p + i*j*d*l*o + i*j*e*o*p + i*j*e*m*q + i*j*d*m*q + i*j*d*m*p + i*j*d*m*o + i*j*d*l*q + i*j*d*l*p + i*j*e*m*p + i*j*e*m*o + i*j*e*l*q + i*j*e*l*p + i*j*e*l*o + i*d*e*m*q + i*d*e*m*o + i*d*e*l*q + i*d*e*l*p + j*d*l*o*p + j*d*e*l*o + g*j*d*o*p + g*j*d*m*q + g*j*d*m*p + g*j*d*m*o + g*j*d*l*q + g*j*d*l*p + g*j*d*l*o + g*j*e*o*p + g*j*e*m*q + g*d*l*o*p + g*d*l*m*p + g*d*l*m*o + j*d*e*m*p + i*d*e*o*p + j*e*o*q*l + j*e*o*p*l + j*e*m*q*l + j*d*e*o*p + j*d*e*m*q + i*j*d*o*p + g*i*e*o*p + j*d*e*m*o + j*d*e*l*q + j*d*e*l*p + j*e*m*p*l + j*e*m*o*l + g*i*e*m*q + g*i*e*m*p + g*i*e*m*o + g*i*e*l*q + g*i*e*l*p + g*i*e*l*o + j*d*l*o*q + j*d*l*m*q + j*d*l*m*p + j*d*l*m*o + i*d*e*l*o + l*m*q*d*e + l*m*q*d*g + l*m*q*e*g), + n: (j*e*r*o*q*l)/(l*q*d*e*o + l*q*d*g*o + l*q*e*g*o + i*j*d*o*q + i*j*e*o*q + j*d*e*o*q + g*j*d*o*q + g*j*e*o*q + g*i*e*o*q + i*d*e*o*q + g*i*d*o*q + g*i*d*o*p + g*i*d*m*q + g*i*d*m*p + g*i*d*m*o + g*i*d*l*q + g*i*d*l*p + g*i*d*l*o + g*e*l*m*p + g*e*l*o*p + g*j*e*l*q + g*e*l*m*o + g*j*e*m*p + g*j*e*m*o + d*e*l*m*p + d*e*l*m*o + i*d*e*m*p + g*j*e*l*p + g*j*e*l*o + d*e*l*o*p + i*j*d*l*o + i*j*e*o*p + i*j*e*m*q + i*j*d*m*q + i*j*d*m*p + i*j*d*m*o + i*j*d*l*q + i*j*d*l*p + i*j*e*m*p + i*j*e*m*o + i*j*e*l*q + i*j*e*l*p + i*j*e*l*o + i*d*e*m*q + i*d*e*m*o + i*d*e*l*q + i*d*e*l*p + j*d*l*o*p + j*d*e*l*o + g*j*d*o*p + g*j*d*m*q + g*j*d*m*p + g*j*d*m*o + g*j*d*l*q + g*j*d*l*p + g*j*d*l*o + g*j*e*o*p + g*j*e*m*q + g*d*l*o*p + g*d*l*m*p + g*d*l*m*o + j*d*e*m*p + i*d*e*o*p + j*e*o*q*l + j*e*o*p*l + j*e*m*q*l + j*d*e*o*p + j*d*e*m*q + i*j*d*o*p + g*i*e*o*p + j*d*e*m*o + j*d*e*l*q + j*d*e*l*p + j*e*m*p*l + j*e*m*o*l + g*i*e*m*q + g*i*e*m*p + g*i*e*m*o + g*i*e*l*q + g*i*e*l*p + g*i*e*l*o + j*d*l*o*q + j*d*l*m*q + j*d*l*m*p + j*d*l*m*o + i*d*e*l*o + l*m*q*d*e + l*m*q*d*g + l*m*q*e*g), + } + + assert solve_lin_sys(eqs, domain) == sol + +def test_eqs_to_matrix(): + domain, x1,x2 = ring("x1,x2", QQ) + eqs_coeff = [{x1: QQ(1), x2: QQ(1)}, {x1: QQ(2), x2: QQ(-1)}] + eqs_rhs = [QQ(-5), QQ(0)] + M = eqs_to_matrix(eqs_coeff, eqs_rhs, [x1, x2], QQ) + assert M.to_Matrix() == Matrix([[1, 1, 5], [2, -1, 0]]) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_specialpolys.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_specialpolys.py new file mode 100644 index 0000000000000000000000000000000000000000..39f551c9e70b5c2bae748ea681b9c8a8cb349fe1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_specialpolys.py @@ -0,0 +1,152 @@ +"""Tests for functions for generating interesting polynomials. """ + +from sympy.core.add import Add +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory.generate import prime +from sympy.polys.domains.integerring import ZZ +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import permute_signs +from sympy.testing.pytest import raises + +from sympy.polys.specialpolys import ( + swinnerton_dyer_poly, + cyclotomic_poly, + symmetric_poly, + random_poly, + interpolating_poly, + fateman_poly_F_1, + dmp_fateman_poly_F_1, + fateman_poly_F_2, + dmp_fateman_poly_F_2, + fateman_poly_F_3, + dmp_fateman_poly_F_3, +) + +from sympy.abc import x, y, z + + +def test_swinnerton_dyer_poly(): + raises(ValueError, lambda: swinnerton_dyer_poly(0, x)) + + assert swinnerton_dyer_poly(1, x, polys=True) == Poly(x**2 - 2) + + assert swinnerton_dyer_poly(1, x) == x**2 - 2 + assert swinnerton_dyer_poly(2, x) == x**4 - 10*x**2 + 1 + assert swinnerton_dyer_poly( + 3, x) == x**8 - 40*x**6 + 352*x**4 - 960*x**2 + 576 + # we only need to check that the polys arg works but + # we may as well test that the roots are correct + p = [sqrt(prime(i)) for i in range(1, 5)] + assert str([i.n(3) for i in + swinnerton_dyer_poly(4, polys=True).all_roots()] + ) == str(sorted([Add(*i).n(3) for i in permute_signs(p)])) + + +def test_cyclotomic_poly(): + raises(ValueError, lambda: cyclotomic_poly(0, x)) + + assert cyclotomic_poly(1, x, polys=True) == Poly(x - 1) + + assert cyclotomic_poly(1, x) == x - 1 + assert cyclotomic_poly(2, x) == x + 1 + assert cyclotomic_poly(3, x) == x**2 + x + 1 + assert cyclotomic_poly(4, x) == x**2 + 1 + assert cyclotomic_poly(5, x) == x**4 + x**3 + x**2 + x + 1 + assert cyclotomic_poly(6, x) == x**2 - x + 1 + + +def test_symmetric_poly(): + raises(ValueError, lambda: symmetric_poly(-1, x, y, z)) + raises(ValueError, lambda: symmetric_poly(5, x, y, z)) + + assert symmetric_poly(1, x, y, z, polys=True) == Poly(x + y + z) + assert symmetric_poly(1, (x, y, z), polys=True) == Poly(x + y + z) + + assert symmetric_poly(0, x, y, z) == 1 + assert symmetric_poly(1, x, y, z) == x + y + z + assert symmetric_poly(2, x, y, z) == x*y + x*z + y*z + assert symmetric_poly(3, x, y, z) == x*y*z + + +def test_random_poly(): + poly = random_poly(x, 10, -100, 100, polys=False) + + assert Poly(poly).degree() == 10 + assert all(-100 <= coeff <= 100 for coeff in Poly(poly).coeffs()) is True + + poly = random_poly(x, 10, -100, 100, polys=True) + + assert poly.degree() == 10 + assert all(-100 <= coeff <= 100 for coeff in poly.coeffs()) is True + + +def test_interpolating_poly(): + x0, x1, x2, x3, y0, y1, y2, y3 = symbols('x:4, y:4') + + assert interpolating_poly(0, x) == 0 + assert interpolating_poly(1, x) == y0 + + assert interpolating_poly(2, x) == \ + y0*(x - x1)/(x0 - x1) + y1*(x - x0)/(x1 - x0) + + assert interpolating_poly(3, x) == \ + y0*(x - x1)*(x - x2)/((x0 - x1)*(x0 - x2)) + \ + y1*(x - x0)*(x - x2)/((x1 - x0)*(x1 - x2)) + \ + y2*(x - x0)*(x - x1)/((x2 - x0)*(x2 - x1)) + + assert interpolating_poly(4, x) == \ + y0*(x - x1)*(x - x2)*(x - x3)/((x0 - x1)*(x0 - x2)*(x0 - x3)) + \ + y1*(x - x0)*(x - x2)*(x - x3)/((x1 - x0)*(x1 - x2)*(x1 - x3)) + \ + y2*(x - x0)*(x - x1)*(x - x3)/((x2 - x0)*(x2 - x1)*(x2 - x3)) + \ + y3*(x - x0)*(x - x1)*(x - x2)/((x3 - x0)*(x3 - x1)*(x3 - x2)) + + raises(ValueError, lambda: + interpolating_poly(2, x, (x, 2), (1, 3))) + raises(ValueError, lambda: + interpolating_poly(2, x, (x + y, 2), (1, 3))) + raises(ValueError, lambda: + interpolating_poly(2, x + y, (x, 2), (1, 3))) + raises(ValueError, lambda: + interpolating_poly(2, 3, (4, 5), (6, 7))) + raises(ValueError, lambda: + interpolating_poly(2, 3, (4, 5), (6, 7, 8))) + assert interpolating_poly(0, x, (1, 2), (3, 4)) == 0 + assert interpolating_poly(1, x, (1, 2), (3, 4)) == 3 + assert interpolating_poly(2, x, (1, 2), (3, 4)) == x + 2 + + +def test_fateman_poly_F_1(): + f, g, h = fateman_poly_F_1(1) + F, G, H = dmp_fateman_poly_F_1(1, ZZ) + + assert [ t.rep.to_list() for t in [f, g, h] ] == [F, G, H] + + f, g, h = fateman_poly_F_1(3) + F, G, H = dmp_fateman_poly_F_1(3, ZZ) + + assert [ t.rep.to_list() for t in [f, g, h] ] == [F, G, H] + + +def test_fateman_poly_F_2(): + f, g, h = fateman_poly_F_2(1) + F, G, H = dmp_fateman_poly_F_2(1, ZZ) + + assert [ t.rep.to_list() for t in [f, g, h] ] == [F, G, H] + + f, g, h = fateman_poly_F_2(3) + F, G, H = dmp_fateman_poly_F_2(3, ZZ) + + assert [ t.rep.to_list() for t in [f, g, h] ] == [F, G, H] + + +def test_fateman_poly_F_3(): + f, g, h = fateman_poly_F_3(1) + F, G, H = dmp_fateman_poly_F_3(1, ZZ) + + assert [ t.rep.to_list() for t in [f, g, h] ] == [F, G, H] + + f, g, h = fateman_poly_F_3(3) + F, G, H = dmp_fateman_poly_F_3(3, ZZ) + + assert [ t.rep.to_list() for t in [f, g, h] ] == [F, G, H] diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_sqfreetools.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_sqfreetools.py new file mode 100644 index 0000000000000000000000000000000000000000..b772a05a50e2eacd5a7c80352b1eadd52c69c3fa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_sqfreetools.py @@ -0,0 +1,160 @@ +"""Tests for square-free decomposition algorithms and related tools. """ + +from sympy.polys.rings import ring +from sympy.polys.domains import FF, ZZ, QQ +from sympy.polys.specialpolys import f_polys + +from sympy.testing.pytest import raises +from sympy.external.gmpy import MPQ + +f_0, f_1, f_2, f_3, f_4, f_5, f_6 = f_polys() + +def test_dup_sqf(): + R, x = ring("x", ZZ) + + assert R.dup_sqf_part(0) == 0 + assert R.dup_sqf_p(0) is True + + assert R.dup_sqf_part(7) == 1 + assert R.dup_sqf_p(7) is True + + assert R.dup_sqf_part(2*x + 2) == x + 1 + assert R.dup_sqf_p(2*x + 2) is True + + assert R.dup_sqf_part(x**3 + x + 1) == x**3 + x + 1 + assert R.dup_sqf_p(x**3 + x + 1) is True + + assert R.dup_sqf_part(-x**3 + x + 1) == x**3 - x - 1 + assert R.dup_sqf_p(-x**3 + x + 1) is True + + assert R.dup_sqf_part(2*x**3 + 3*x**2) == 2*x**2 + 3*x + assert R.dup_sqf_p(2*x**3 + 3*x**2) is False + + assert R.dup_sqf_part(-2*x**3 + 3*x**2) == 2*x**2 - 3*x + assert R.dup_sqf_p(-2*x**3 + 3*x**2) is False + + assert R.dup_sqf_list(0) == (0, []) + assert R.dup_sqf_list(1) == (1, []) + + assert R.dup_sqf_list(x) == (1, [(x, 1)]) + assert R.dup_sqf_list(2*x**2) == (2, [(x, 2)]) + assert R.dup_sqf_list(3*x**3) == (3, [(x, 3)]) + + assert R.dup_sqf_list(-x**5 + x**4 + x - 1) == \ + (-1, [(x**3 + x**2 + x + 1, 1), (x - 1, 2)]) + assert R.dup_sqf_list(x**8 + 6*x**6 + 12*x**4 + 8*x**2) == \ + ( 1, [(x, 2), (x**2 + 2, 3)]) + + assert R.dup_sqf_list(2*x**2 + 4*x + 2) == (2, [(x + 1, 2)]) + + R, x = ring("x", QQ) + assert R.dup_sqf_list(2*x**2 + 4*x + 2) == (2, [(x + 1, 2)]) + + R, x = ring("x", FF(2)) + assert R.dup_sqf_list(x**2 + 1) == (1, [(x + 1, 2)]) + + R, x = ring("x", FF(3)) + assert R.dup_sqf_list(x**10 + 2*x**7 + 2*x**4 + x) == \ + (1, [(x, 1), + (x + 1, 3), + (x + 2, 6)]) + + R1, x = ring("x", ZZ) + R2, y = ring("y", FF(3)) + + f = x**3 + 1 + g = y**3 + 1 + + assert R1.dup_sqf_part(f) == f + assert R2.dup_sqf_part(g) == y + 1 + + assert R1.dup_sqf_p(f) is True + assert R2.dup_sqf_p(g) is False + + R, x, y = ring("x,y", ZZ) + + A = x**4 - 3*x**2 + 6 + D = x**6 - 5*x**4 + 5*x**2 + 4 + + f, g = D, R.dmp_sub(A, R.dmp_mul(R.dmp_diff(D, 1), y)) + res = R.dmp_resultant(f, g) + h = (4*y**2 + 1).drop(x) + + assert R.drop(x).dup_sqf_list(res) == (45796, [(h, 3)]) + + Rt, t = ring("t", ZZ) + R, x = ring("x", Rt) + assert R.dup_sqf_list_include(t**3*x**2) == [(t**3, 1), (x, 2)] + + +def test_dmp_sqf(): + R, x, y = ring("x,y", ZZ) + assert R.dmp_sqf_part(0) == 0 + assert R.dmp_sqf_p(0) is True + + assert R.dmp_sqf_part(7) == 1 + assert R.dmp_sqf_p(7) is True + + assert R.dmp_sqf_list(3) == (3, []) + assert R.dmp_sqf_list_include(3) == [(3, 1)] + + R, x, y, z = ring("x,y,z", ZZ) + assert R.dmp_sqf_p(f_0) is True + assert R.dmp_sqf_p(f_0**2) is False + assert R.dmp_sqf_p(f_1) is True + assert R.dmp_sqf_p(f_1**2) is False + assert R.dmp_sqf_p(f_2) is True + assert R.dmp_sqf_p(f_2**2) is False + assert R.dmp_sqf_p(f_3) is True + assert R.dmp_sqf_p(f_3**2) is False + assert R.dmp_sqf_p(f_5) is False + assert R.dmp_sqf_p(f_5**2) is False + + assert R.dmp_sqf_p(f_4) is True + assert R.dmp_sqf_part(f_4) == -f_4 + + assert R.dmp_sqf_part(f_5) == x + y - z + + R, x, y, z, t = ring("x,y,z,t", ZZ) + assert R.dmp_sqf_p(f_6) is True + assert R.dmp_sqf_part(f_6) == f_6 + + R, x = ring("x", ZZ) + f = -x**5 + x**4 + x - 1 + + assert R.dmp_sqf_list(f) == (-1, [(x**3 + x**2 + x + 1, 1), (x - 1, 2)]) + assert R.dmp_sqf_list_include(f) == [(-x**3 - x**2 - x - 1, 1), (x - 1, 2)] + + R, x, y = ring("x,y", ZZ) + f = -x**5 + x**4 + x - 1 + + assert R.dmp_sqf_list(f) == (-1, [(x**3 + x**2 + x + 1, 1), (x - 1, 2)]) + assert R.dmp_sqf_list_include(f) == [(-x**3 - x**2 - x - 1, 1), (x - 1, 2)] + + f = -x**2 + 2*x - 1 + assert R.dmp_sqf_list_include(f) == [(-1, 1), (x - 1, 2)] + + f = (y**2 + 1)**2*(x**2 + 2*x + 2) + assert R.dmp_sqf_p(f) is False + assert R.dmp_sqf_list(f) == (1, [(x**2 + 2*x + 2, 1), (y**2 + 1, 2)]) + + R, x, y = ring("x,y", FF(2)) + raises(NotImplementedError, lambda: R.dmp_sqf_list(y**2 + 1)) + + +def test_dup_gff_list(): + R, x = ring("x", ZZ) + + f = x**5 + 2*x**4 - x**3 - 2*x**2 + assert R.dup_gff_list(f) == [(x, 1), (x + 2, 4)] + + g = x**9 - 20*x**8 + 166*x**7 - 744*x**6 + 1965*x**5 - 3132*x**4 + 2948*x**3 - 1504*x**2 + 320*x + assert R.dup_gff_list(g) == [(x**2 - 5*x + 4, 1), (x**2 - 5*x + 4, 2), (x, 3)] + + raises(ValueError, lambda: R.dup_gff_list(0)) + +def test_issue_26178(): + R, x, y, z = ring(['x', 'y', 'z'], QQ) + assert (x**2 - 2*y**2 + 1).sqf_list() == (MPQ(1,1), [(x**2 - 2*y**2 + 1, 1)]) + assert (x**2 - 2*z**2 + 1).sqf_list() == (MPQ(1,1), [(x**2 - 2*z**2 + 1, 1)]) + assert (y**2 - 2*z**2 + 1).sqf_list() == (MPQ(1,1), [(y**2 - 2*z**2 + 1, 1)]) diff --git a/phivenv/Lib/site-packages/sympy/polys/tests/test_subresultants_qq_zz.py b/phivenv/Lib/site-packages/sympy/polys/tests/test_subresultants_qq_zz.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7560dfeaf93b20f7cf68cdc597c024cb519cca --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/tests/test_subresultants_qq_zz.py @@ -0,0 +1,347 @@ +from sympy.core.symbol import Symbol +from sympy.polys.polytools import (pquo, prem, sturm, subresultants) +from sympy.matrices import Matrix +from sympy.polys.subresultants_qq_zz import (sylvester, res, res_q, res_z, bezout, + subresultants_sylv, modified_subresultants_sylv, + subresultants_bezout, modified_subresultants_bezout, + backward_eye, + sturm_pg, sturm_q, sturm_amv, euclid_pg, euclid_q, + euclid_amv, modified_subresultants_pg, subresultants_pg, + subresultants_amv_q, quo_z, rem_z, subresultants_amv, + modified_subresultants_amv, subresultants_rem, + subresultants_vv, subresultants_vv_2) + + +def test_sylvester(): + x = Symbol('x') + + assert sylvester(x**3 -7, 0, x) == sylvester(x**3 -7, 0, x, 1) == Matrix([[0]]) + assert sylvester(0, x**3 -7, x) == sylvester(0, x**3 -7, x, 1) == Matrix([[0]]) + assert sylvester(x**3 -7, 0, x, 2) == Matrix([[0]]) + assert sylvester(0, x**3 -7, x, 2) == Matrix([[0]]) + + assert sylvester(x**3 -7, 7, x).det() == sylvester(x**3 -7, 7, x, 1).det() == 343 + assert sylvester(7, x**3 -7, x).det() == sylvester(7, x**3 -7, x, 1).det() == 343 + assert sylvester(x**3 -7, 7, x, 2).det() == -343 + assert sylvester(7, x**3 -7, x, 2).det() == 343 + + assert sylvester(3, 7, x).det() == sylvester(3, 7, x, 1).det() == sylvester(3, 7, x, 2).det() == 1 + + assert sylvester(3, 0, x).det() == sylvester(3, 0, x, 1).det() == sylvester(3, 0, x, 2).det() == 1 + + assert sylvester(x - 3, x - 8, x) == sylvester(x - 3, x - 8, x, 1) == sylvester(x - 3, x - 8, x, 2) == Matrix([[1, -3], [1, -8]]) + + assert sylvester(x**3 - 7*x + 7, 3*x**2 - 7, x) == sylvester(x**3 - 7*x + 7, 3*x**2 - 7, x, 1) == Matrix([[1, 0, -7, 7, 0], [0, 1, 0, -7, 7], [3, 0, -7, 0, 0], [0, 3, 0, -7, 0], [0, 0, 3, 0, -7]]) + + assert sylvester(x**3 - 7*x + 7, 3*x**2 - 7, x, 2) == Matrix([ +[1, 0, -7, 7, 0, 0], [0, 3, 0, -7, 0, 0], [0, 1, 0, -7, 7, 0], [0, 0, 3, 0, -7, 0], [0, 0, 1, 0, -7, 7], [0, 0, 0, 3, 0, -7]]) + +def test_subresultants_sylv(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_sylv(p, q, x) == subresultants(p, q, x) + assert subresultants_sylv(p, q, x)[-1] == res(p, q, x) + assert subresultants_sylv(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_sylv(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_sylv(p, q, x) == euclid_amv(p, q, x) + +def test_modified_subresultants_sylv(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + amv_factors = [1, 1, -1, 1, -1, 1] + assert modified_subresultants_sylv(p, q, x) == [i*j for i, j in zip(amv_factors, subresultants_amv(p, q, x))] + assert modified_subresultants_sylv(p, q, x)[-1] != res_q(p + x**8, q, x) + assert modified_subresultants_sylv(p, q, x) != sturm_amv(p, q, x) + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert modified_subresultants_sylv(p, q, x) == sturm_amv(p, q, x) + assert modified_subresultants_sylv(-p, q, x) != sturm_amv(-p, q, x) + +def test_res(): + x = Symbol('x') + + assert res(3, 5, x) == 1 + +def test_res_q(): + x = Symbol('x') + + assert res_q(3, 5, x) == 1 + +def test_res_z(): + x = Symbol('x') + + assert res_z(3, 5, x) == 1 + assert res(3, 5, x) == res_q(3, 5, x) == res_z(3, 5, x) + +def test_bezout(): + x = Symbol('x') + + p = -2*x**5+7*x**3+9*x**2-3*x+1 + q = -10*x**4+21*x**2+18*x-3 + assert bezout(p, q, x, 'bz').det() == sylvester(p, q, x, 2).det() + assert bezout(p, q, x, 'bz').det() != sylvester(p, q, x, 1).det() + assert bezout(p, q, x, 'prs') == backward_eye(5) * bezout(p, q, x, 'bz') * backward_eye(5) + +def test_subresultants_bezout(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_bezout(p, q, x) == subresultants(p, q, x) + assert subresultants_bezout(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_bezout(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_bezout(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_bezout(p, q, x) == euclid_amv(p, q, x) + +def test_modified_subresultants_bezout(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + amv_factors = [1, 1, -1, 1, -1, 1] + assert modified_subresultants_bezout(p, q, x) == [i*j for i, j in zip(amv_factors, subresultants_amv(p, q, x))] + assert modified_subresultants_bezout(p, q, x)[-1] != sylvester(p + x**8, q, x).det() + assert modified_subresultants_bezout(p, q, x) != sturm_amv(p, q, x) + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert modified_subresultants_bezout(p, q, x) == sturm_amv(p, q, x) + assert modified_subresultants_bezout(-p, q, x) != sturm_amv(-p, q, x) + +def test_sturm_pg(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert sturm_pg(p, q, x)[-1] != sylvester(p, q, x, 2).det() + sam_factors = [1, 1, -1, -1, 1, 1] + assert sturm_pg(p, q, x) == [i*j for i,j in zip(sam_factors, euclid_pg(p, q, x))] + + p = -9*x**5 - 5*x**3 - 9 + q = -45*x**4 - 15*x**2 + assert sturm_pg(p, q, x, 1)[-1] == sylvester(p, q, x, 1).det() + assert sturm_pg(p, q, x)[-1] != sylvester(p, q, x, 2).det() + assert sturm_pg(-p, q, x)[-1] == sylvester(-p, q, x, 2).det() + assert sturm_pg(-p, q, x) == modified_subresultants_pg(-p, q, x) + +def test_sturm_q(): + x = Symbol('x') + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert sturm_q(p, q, x) == sturm(p) + assert sturm_q(-p, -q, x) != sturm(-p) + + +def test_sturm_amv(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert sturm_amv(p, q, x)[-1] != sylvester(p, q, x, 2).det() + sam_factors = [1, 1, -1, -1, 1, 1] + assert sturm_amv(p, q, x) == [i*j for i,j in zip(sam_factors, euclid_amv(p, q, x))] + + p = -9*x**5 - 5*x**3 - 9 + q = -45*x**4 - 15*x**2 + assert sturm_amv(p, q, x, 1)[-1] == sylvester(p, q, x, 1).det() + assert sturm_amv(p, q, x)[-1] != sylvester(p, q, x, 2).det() + assert sturm_amv(-p, q, x)[-1] == sylvester(-p, q, x, 2).det() + assert sturm_pg(-p, q, x) == modified_subresultants_pg(-p, q, x) + + +def test_euclid_pg(): + x = Symbol('x') + + p = x**6+x**5-x**4-x**3+x**2-x+1 + q = 6*x**5+5*x**4-4*x**3-3*x**2+2*x-1 + assert euclid_pg(p, q, x)[-1] == sylvester(p, q, x).det() + assert euclid_pg(p, q, x) == subresultants_pg(p, q, x) + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert euclid_pg(p, q, x)[-1] != sylvester(p, q, x, 2).det() + sam_factors = [1, 1, -1, -1, 1, 1] + assert euclid_pg(p, q, x) == [i*j for i,j in zip(sam_factors, sturm_pg(p, q, x))] + + +def test_euclid_q(): + x = Symbol('x') + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert euclid_q(p, q, x)[-1] == -sturm(p)[-1] + + +def test_euclid_amv(): + x = Symbol('x') + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert euclid_amv(p, q, x)[-1] == sylvester(p, q, x).det() + assert euclid_amv(p, q, x) == subresultants_amv(p, q, x) + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert euclid_amv(p, q, x)[-1] != sylvester(p, q, x, 2).det() + sam_factors = [1, 1, -1, -1, 1, 1] + assert euclid_amv(p, q, x) == [i*j for i,j in zip(sam_factors, sturm_amv(p, q, x))] + + +def test_modified_subresultants_pg(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + amv_factors = [1, 1, -1, 1, -1, 1] + assert modified_subresultants_pg(p, q, x) == [i*j for i, j in zip(amv_factors, subresultants_pg(p, q, x))] + assert modified_subresultants_pg(p, q, x)[-1] != sylvester(p + x**8, q, x).det() + assert modified_subresultants_pg(p, q, x) != sturm_pg(p, q, x) + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert modified_subresultants_pg(p, q, x) == sturm_pg(p, q, x) + assert modified_subresultants_pg(-p, q, x) != sturm_pg(-p, q, x) + + +def test_subresultants_pg(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_pg(p, q, x) == subresultants(p, q, x) + assert subresultants_pg(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_pg(p, q, x) != euclid_pg(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_pg(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_pg(p, q, x) == euclid_pg(p, q, x) + + +def test_subresultants_amv_q(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_amv_q(p, q, x) == subresultants(p, q, x) + assert subresultants_amv_q(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_amv_q(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_amv_q(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_amv(p, q, x) == euclid_amv(p, q, x) + + +def test_rem_z(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert rem_z(p, -q, x) != prem(p, -q, x) + +def test_quo_z(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert quo_z(p, -q, x) != pquo(p, -q, x) + + y = Symbol('y') + q = 3*x**6 + 5*y**4 - 4*x**2 - 9*x + 21 + assert quo_z(p, -q, x) == pquo(p, -q, x) + +def test_subresultants_amv(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_amv(p, q, x) == subresultants(p, q, x) + assert subresultants_amv(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_amv(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_amv(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_amv(p, q, x) == euclid_amv(p, q, x) + + +def test_modified_subresultants_amv(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + amv_factors = [1, 1, -1, 1, -1, 1] + assert modified_subresultants_amv(p, q, x) == [i*j for i, j in zip(amv_factors, subresultants_amv(p, q, x))] + assert modified_subresultants_amv(p, q, x)[-1] != sylvester(p + x**8, q, x).det() + assert modified_subresultants_amv(p, q, x) != sturm_amv(p, q, x) + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert modified_subresultants_amv(p, q, x) == sturm_amv(p, q, x) + assert modified_subresultants_amv(-p, q, x) != sturm_amv(-p, q, x) + + +def test_subresultants_rem(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_rem(p, q, x) == subresultants(p, q, x) + assert subresultants_rem(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_rem(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_rem(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_rem(p, q, x) == euclid_amv(p, q, x) + + +def test_subresultants_vv(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_vv(p, q, x) == subresultants(p, q, x) + assert subresultants_vv(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_vv(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_vv(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_vv(p, q, x) == euclid_amv(p, q, x) + + +def test_subresultants_vv_2(): + x = Symbol('x') + + p = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + q = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + assert subresultants_vv_2(p, q, x) == subresultants(p, q, x) + assert subresultants_vv_2(p, q, x)[-1] == sylvester(p, q, x).det() + assert subresultants_vv_2(p, q, x) != euclid_amv(p, q, x) + amv_factors = [1, 1, -1, 1, -1, 1] + assert subresultants_vv_2(p, q, x) == [i*j for i, j in zip(amv_factors, modified_subresultants_amv(p, q, x))] + + p = x**3 - 7*x + 7 + q = 3*x**2 - 7 + assert subresultants_vv_2(p, q, x) == euclid_amv(p, q, x) diff --git a/phivenv/Lib/site-packages/sympy/printing/__init__.py b/phivenv/Lib/site-packages/sympy/printing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15dfaf70eb3777195b7c9a0930894bb2187bbb50 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/__init__.py @@ -0,0 +1,111 @@ +"""Printing subsystem""" + +from .pretty import pager_print, pretty, pretty_print, pprint, pprint_use_unicode, pprint_try_use_unicode + +from .latex import latex, print_latex, multiline_latex + +from .mathml import mathml, print_mathml + +from .python import python, print_python + +from .pycode import pycode + +from .codeprinter import print_ccode, print_fcode + +from .codeprinter import ccode, fcode, cxxcode, rust_code # noqa:F811 + +from .smtlib import smtlib_code + +from .glsl import glsl_code, print_glsl + +from .rcode import rcode, print_rcode + +from .jscode import jscode, print_jscode + +from .julia import julia_code + +from .mathematica import mathematica_code + +from .octave import octave_code + +from .gtk import print_gtk + +from .preview import preview + +from .repr import srepr + +from .tree import print_tree + +from .str import StrPrinter, sstr, sstrrepr + +from .tableform import TableForm + +from .dot import dotprint + +from .maple import maple_code, print_maple_code + +__all__ = [ + # sympy.printing.pretty + 'pager_print', 'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode', + 'pprint_try_use_unicode', + + # sympy.printing.latex + 'latex', 'print_latex', 'multiline_latex', + + # sympy.printing.mathml + 'mathml', 'print_mathml', + + # sympy.printing.python + 'python', 'print_python', + + # sympy.printing.pycode + 'pycode', + + # sympy.printing.codeprinter + 'ccode', 'print_ccode', 'cxxcode', 'fcode', 'print_fcode', 'rust_code', + + # sympy.printing.smtlib + 'smtlib_code', + + # sympy.printing.glsl + 'glsl_code', 'print_glsl', + + # sympy.printing.rcode + 'rcode', 'print_rcode', + + # sympy.printing.jscode + 'jscode', 'print_jscode', + + # sympy.printing.julia + 'julia_code', + + # sympy.printing.mathematica + 'mathematica_code', + + # sympy.printing.octave + 'octave_code', + + # sympy.printing.gtk + 'print_gtk', + + # sympy.printing.preview + 'preview', + + # sympy.printing.repr + 'srepr', + + # sympy.printing.tree + 'print_tree', + + # sympy.printing.str + 'StrPrinter', 'sstr', 'sstrrepr', + + # sympy.printing.tableform + 'TableForm', + + # sympy.printing.dot + 'dotprint', + + # sympy.printing.maple + 'maple_code', 'print_maple_code', +] diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31223f79fd2204966bb2c5170fa90f95471bb96e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/aesaracode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/aesaracode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bb4b3c8f652aed1466cbafb07917f7b5a5ef4d8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/aesaracode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/c.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/c.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8aabcebe312a5996f597713fb320ea45db828e0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/c.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/codeprinter.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/codeprinter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a830620d45a2f815106520b13dd04e6d8e063b31 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/codeprinter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/conventions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/conventions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42463ae1b3d01e9d2866bfa7e6764bb67fbe47e2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/conventions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/cxx.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/cxx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abb1bd2cc77064840b56edd82807eb7cfc36ec6b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/cxx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/defaults.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/defaults.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53ed8a7a815ddb9511ca401badd116c41a9e86c3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/defaults.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/dot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/dot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a81667b55b5cafeb9f74ecff659827599fc43e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/dot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/fortran.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/fortran.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8e8624003599f9f209d8d50bbc614e05890a654 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/fortran.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/glsl.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/glsl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e44940d25488827ce23504789de477a970e4b49c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/glsl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/gtk.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/gtk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..549735bd96d17caa0453b124cd46003cdefd7409 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/gtk.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/jscode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/jscode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d910ee850875974f87e2a78e8a91d8805161642 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/jscode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/julia.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/julia.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dbbe3bef70dab9a148cc4d8ccdfcf9c0f9cf47f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/julia.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..906826efc7cd5b7768964b236b10163db08721a6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e350aad70b599eb1d1f0d62fe69515e50c984a02 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/maple.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/maple.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bb05f107fb2e9400a0c33aa66a249412de30c35 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/maple.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/mathematica.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/mathematica.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33ccaa79a3fd3dd23e410ac425e84722baed5d9d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/mathematica.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/mathml.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/mathml.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa75b3dab994a75f4c7acaacda0e68345408b30 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/mathml.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/numpy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/numpy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74986986bf5b9505ba3abdefb4ff1b9f912cb8df Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/numpy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/octave.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/octave.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac90d5da3c5169e1ba21a1c47df3f0edbab4032a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/octave.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/precedence.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/precedence.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7991ce4cf6409f19998d32ccbbee22a536f46885 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/precedence.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/preview.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/preview.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9e9928107ed62f5a884422b9c39097e305a028c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/preview.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/printer.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/printer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38704c994d9cec222921ebbe92867ba96aa1e7b3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/printer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/pycode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/pycode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b6b3446afe404bcf961ea9b9746e78d053f51cd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/pycode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/python.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/python.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a80c7a1ef5c0fee3d5309c16bba9660408cb7c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/python.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/pytorch.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/pytorch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800da60d297bcd5a008b41fd9f0eb131f23cc64a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/pytorch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/rcode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/rcode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d46e07608a3edf714a42ed79e15d752ae9a1b7d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/rcode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/repr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/repr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b1816658db75d7c341a32724bfe88b7e7aa1d8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/repr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/rust.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/rust.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cceb19e286bc470f59638939b237c4fb57cced2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/rust.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/smtlib.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/smtlib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ce15d4ef989d8be2b08beae421d63802edb4b65 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/smtlib.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/str.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/str.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc4a5992311028f407737aaf31a6a818d496f0a1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/str.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/tableform.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/tableform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d68babd39b0f2985a03af472e6ed6a0a5f8cf445 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/tableform.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/tensorflow.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/tensorflow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048a9f495d509730d04a01174470667bd18956e5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/tensorflow.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/theanocode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/theanocode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1fb6de4a38ab8a34bcd3445af083eabc34f3314 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/theanocode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/__pycache__/tree.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/__pycache__/tree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c6ecb3868f08027f8a0b45c8b9b58d87373641c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/__pycache__/tree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/aesaracode.py b/phivenv/Lib/site-packages/sympy/printing/aesaracode.py new file mode 100644 index 0000000000000000000000000000000000000000..1e31c6940a86bd25ef8805420b84c22c5e08bca9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/aesaracode.py @@ -0,0 +1,563 @@ +from __future__ import annotations +import math +from typing import Any + +from sympy.external import import_module +from sympy.printing.printer import Printer +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +import sympy +from functools import partial + + +aesara = import_module('aesara') + +if aesara: + aes = aesara.scalar + aet = aesara.tensor + from aesara.tensor import nlinalg + from aesara.tensor.elemwise import Elemwise + from aesara.tensor.elemwise import DimShuffle + + # `true_divide` replaced `true_div` in Aesara 2.8.11 (released 2023) to + # match NumPy + # XXX: Remove this when not needed to support older versions. + true_divide = getattr(aet, 'true_divide', None) + if true_divide is None: + true_divide = aet.true_div + + mapping = { + sympy.Add: aet.add, + sympy.Mul: aet.mul, + sympy.Abs: aet.abs, + sympy.sign: aet.sgn, + sympy.ceiling: aet.ceil, + sympy.floor: aet.floor, + sympy.log: aet.log, + sympy.exp: aet.exp, + sympy.sqrt: aet.sqrt, + sympy.cos: aet.cos, + sympy.acos: aet.arccos, + sympy.sin: aet.sin, + sympy.asin: aet.arcsin, + sympy.tan: aet.tan, + sympy.atan: aet.arctan, + sympy.atan2: aet.arctan2, + sympy.cosh: aet.cosh, + sympy.acosh: aet.arccosh, + sympy.sinh: aet.sinh, + sympy.asinh: aet.arcsinh, + sympy.tanh: aet.tanh, + sympy.atanh: aet.arctanh, + sympy.re: aet.real, + sympy.im: aet.imag, + sympy.arg: aet.angle, + sympy.erf: aet.erf, + sympy.gamma: aet.gamma, + sympy.loggamma: aet.gammaln, + sympy.Pow: aet.pow, + sympy.Eq: aet.eq, + sympy.StrictGreaterThan: aet.gt, + sympy.StrictLessThan: aet.lt, + sympy.LessThan: aet.le, + sympy.GreaterThan: aet.ge, + sympy.And: aet.bitwise_and, # bitwise + sympy.Or: aet.bitwise_or, # bitwise + sympy.Not: aet.invert, # bitwise + sympy.Xor: aet.bitwise_xor, # bitwise + sympy.Max: aet.maximum, # Sympy accept >2 inputs, Aesara only 2 + sympy.Min: aet.minimum, # Sympy accept >2 inputs, Aesara only 2 + sympy.conjugate: aet.conj, + sympy.core.numbers.ImaginaryUnit: lambda:aet.complex(0,1), + # Matrices + sympy.MatAdd: Elemwise(aes.add), + sympy.HadamardProduct: Elemwise(aes.mul), + sympy.Trace: nlinalg.trace, + sympy.Determinant : nlinalg.det, + sympy.Inverse: nlinalg.matrix_inverse, + sympy.Transpose: DimShuffle((False, False), [1, 0]), + } + + +class AesaraPrinter(Printer): + """ + .. deprecated:: 1.14. + The ``Aesara Code printing`` is deprecated.See its documentation for + more information. See :ref:`deprecated-aesaraprinter` for details. + + Code printer which creates Aesara symbolic expression graphs. + + Parameters + ========== + + cache : dict + Cache dictionary to use. If None (default) will use + the global cache. To create a printer which does not depend on or alter + global state pass an empty dictionary. Note: the dictionary is not + copied on initialization of the printer and will be updated in-place, + so using the same dict object when creating multiple printers or making + multiple calls to :func:`.aesara_code` or :func:`.aesara_function` means + the cache is shared between all these applications. + + Attributes + ========== + + cache : dict + A cache of Aesara variables which have been created for SymPy + symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or + :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to + ensure that all references to a given symbol in an expression (or + multiple expressions) are printed as the same Aesara variable, which is + created only once. Symbols are differentiated only by name and type. The + format of the cache's contents should be considered opaque to the user. + """ + printmethod = "_aesara" + + def __init__(self, *args, **kwargs): + self.cache = kwargs.pop('cache', {}) + super().__init__(*args, **kwargs) + + def _get_key(self, s, name=None, dtype=None, broadcastable=None): + """ Get the cache key for a SymPy object. + + Parameters + ========== + + s : sympy.core.basic.Basic + SymPy object to get key for. + + name : str + Name of object, if it does not have a ``name`` attribute. + """ + + if name is None: + name = s.name + + return (name, type(s), s.args, dtype, broadcastable) + + def _get_or_create(self, s, name=None, dtype=None, broadcastable=None): + """ + Get the Aesara variable for a SymPy symbol from the cache, or create it + if it does not exist. + """ + + # Defaults + if name is None: + name = s.name + if dtype is None: + dtype = 'floatX' + if broadcastable is None: + broadcastable = () + + key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable) + + if key in self.cache: + return self.cache[key] + + value = aet.tensor(name=name, dtype=dtype, shape=broadcastable) + self.cache[key] = value + return value + + def _print_Symbol(self, s, **kwargs): + dtype = kwargs.get('dtypes', {}).get(s) + bc = kwargs.get('broadcastables', {}).get(s) + return self._get_or_create(s, dtype=dtype, broadcastable=bc) + + def _print_AppliedUndef(self, s, **kwargs): + name = str(type(s)) + '_' + str(s.args[0]) + dtype = kwargs.get('dtypes', {}).get(s) + bc = kwargs.get('broadcastables', {}).get(s) + return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc) + + def _print_Basic(self, expr, **kwargs): + op = mapping[type(expr)] + children = [self._print(arg, **kwargs) for arg in expr.args] + return op(*children) + + def _print_Number(self, n, **kwargs): + # Integers already taken care of below, interpret as float + return float(n.evalf()) + + def _print_MatrixSymbol(self, X, **kwargs): + dtype = kwargs.get('dtypes', {}).get(X) + return self._get_or_create(X, dtype=dtype, broadcastable=(None, None)) + + def _print_DenseMatrix(self, X, **kwargs): + if not hasattr(aet, 'stacklists'): + raise NotImplementedError( + "Matrix translation not yet supported in this version of Aesara") + + return aet.stacklists([ + [self._print(arg, **kwargs) for arg in L] + for L in X.tolist() + ]) + + _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix + + def _print_MatMul(self, expr, **kwargs): + children = [self._print(arg, **kwargs) for arg in expr.args] + result = children[0] + for child in children[1:]: + result = aet.dot(result, child) + return result + + def _print_MatPow(self, expr, **kwargs): + children = [self._print(arg, **kwargs) for arg in expr.args] + result = 1 + if isinstance(children[1], int) and children[1] > 0: + for i in range(children[1]): + result = aet.dot(result, children[0]) + else: + raise NotImplementedError('''Only non-negative integer + powers of matrices can be handled by Aesara at the moment''') + return result + + def _print_MatrixSlice(self, expr, **kwargs): + parent = self._print(expr.parent, **kwargs) + rowslice = self._print(slice(*expr.rowslice), **kwargs) + colslice = self._print(slice(*expr.colslice), **kwargs) + return parent[rowslice, colslice] + + def _print_BlockMatrix(self, expr, **kwargs): + nrows, ncols = expr.blocks.shape + blocks = [[self._print(expr.blocks[r, c], **kwargs) + for c in range(ncols)] + for r in range(nrows)] + return aet.join(0, *[aet.join(1, *row) for row in blocks]) + + + def _print_slice(self, expr, **kwargs): + return slice(*[self._print(i, **kwargs) + if isinstance(i, sympy.Basic) else i + for i in (expr.start, expr.stop, expr.step)]) + + def _print_Pi(self, expr, **kwargs): + return math.pi + + def _print_Piecewise(self, expr, **kwargs): + import numpy as np + e, cond = expr.args[0].args # First condition and corresponding value + + # Print conditional expression and value for first condition + p_cond = self._print(cond, **kwargs) + p_e = self._print(e, **kwargs) + + # One condition only + if len(expr.args) == 1: + # Return value if condition else NaN + return aet.switch(p_cond, p_e, np.nan) + + # Return value_1 if condition_1 else evaluate remaining conditions + p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs) + return aet.switch(p_cond, p_e, p_remaining) + + def _print_Rational(self, expr, **kwargs): + return true_divide(self._print(expr.p, **kwargs), + self._print(expr.q, **kwargs)) + + def _print_Integer(self, expr, **kwargs): + return expr.p + + def _print_factorial(self, expr, **kwargs): + return self._print(sympy.gamma(expr.args[0] + 1), **kwargs) + + def _print_Derivative(self, deriv, **kwargs): + from aesara.gradient import Rop + + rv = self._print(deriv.expr, **kwargs) + for var in deriv.variables: + var = self._print(var, **kwargs) + rv = Rop(rv, var, aet.ones_like(var)) + return rv + + def emptyPrinter(self, expr): + return expr + + def doprint(self, expr, dtypes=None, broadcastables=None): + """ Convert a SymPy expression to a Aesara graph variable. + + The ``dtypes`` and ``broadcastables`` arguments are used to specify the + data type, dimension, and broadcasting behavior of the Aesara variables + corresponding to the free symbols in ``expr``. Each is a mapping from + SymPy symbols to the value of the corresponding argument to + ``aesara.tensor.var.TensorVariable``. + + See the corresponding `documentation page`__ for more information on + broadcasting in Aesara. + + + .. __: https://aesara.readthedocs.io/en/latest/reference/tensor/shapes.html#broadcasting + + Parameters + ========== + + expr : sympy.core.expr.Expr + SymPy expression to print. + + dtypes : dict + Mapping from SymPy symbols to Aesara datatypes to use when creating + new Aesara variables for those symbols. Corresponds to the ``dtype`` + argument to ``aesara.tensor.var.TensorVariable``. Defaults to ``'floatX'`` + for symbols not included in the mapping. + + broadcastables : dict + Mapping from SymPy symbols to the value of the ``broadcastable`` + argument to ``aesara.tensor.var.TensorVariable`` to use when creating Aesara + variables for those symbols. Defaults to the empty tuple for symbols + not included in the mapping (resulting in a scalar). + + Returns + ======= + + aesara.graph.basic.Variable + A variable corresponding to the expression's value in a Aesara + symbolic expression graph. + + """ + if dtypes is None: + dtypes = {} + if broadcastables is None: + broadcastables = {} + + return self._print(expr, dtypes=dtypes, broadcastables=broadcastables) + + +global_cache: dict[Any, Any] = {} + + +def aesara_code(expr, cache=None, **kwargs): + """ + Convert a SymPy expression into a Aesara graph variable. + + Parameters + ========== + + expr : sympy.core.expr.Expr + SymPy expression object to convert. + + cache : dict + Cached Aesara variables (see :class:`AesaraPrinter.cache + `). Defaults to the module-level global cache. + + dtypes : dict + Passed to :meth:`.AesaraPrinter.doprint`. + + broadcastables : dict + Passed to :meth:`.AesaraPrinter.doprint`. + + Returns + ======= + + aesara.graph.basic.Variable + A variable corresponding to the expression's value in a Aesara symbolic + expression graph. + + """ + sympy_deprecation_warning( + """ + The aesara_code function is deprecated. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-aesaraprinter', + ) + + if not aesara: + raise ImportError("aesara is required for aesara_code") + + if cache is None: + cache = global_cache + + return AesaraPrinter(cache=cache, settings={}).doprint(expr, **kwargs) + + +def dim_handling(inputs, dim=None, dims=None, broadcastables=None): + r""" + Get value of ``broadcastables`` argument to :func:`.aesara_code` from + keyword arguments to :func:`.aesara_function`. + + Included for backwards compatibility. + + Parameters + ========== + + inputs + Sequence of input symbols. + + dim : int + Common number of dimensions for all inputs. Overrides other arguments + if given. + + dims : dict + Mapping from input symbols to number of dimensions. Overrides + ``broadcastables`` argument if given. + + broadcastables : dict + Explicit value of ``broadcastables`` argument to + :meth:`.AesaraPrinter.doprint`. If not None function will return this value unchanged. + + Returns + ======= + dict + Dictionary mapping elements of ``inputs`` to their "broadcastable" + values (tuple of ``bool``\ s). + """ + if dim is not None: + return dict.fromkeys(inputs, (False,) * dim) + + if dims is not None: + maxdim = max(dims.values()) + return { + s: (False,) * d + (True,) * (maxdim - d) + for s, d in dims.items() + } + + if broadcastables is not None: + return broadcastables + + return {} + + +def aesara_function(inputs, outputs, scalar=False, *, + dim=None, dims=None, broadcastables=None, **kwargs): + """ + Create a Aesara function from SymPy expressions. + + The inputs and outputs are converted to Aesara variables using + :func:`.aesara_code` and then passed to ``aesara.function``. + + Parameters + ========== + + inputs + Sequence of symbols which constitute the inputs of the function. + + outputs + Sequence of expressions which constitute the outputs(s) of the + function. The free symbols of each expression must be a subset of + ``inputs``. + + scalar : bool + Convert 0-dimensional arrays in output to scalars. This will return a + Python wrapper function around the Aesara function object. + + cache : dict + Cached Aesara variables (see :class:`AesaraPrinter.cache + `). Defaults to the module-level global cache. + + dtypes : dict + Passed to :meth:`.AesaraPrinter.doprint`. + + broadcastables : dict + Passed to :meth:`.AesaraPrinter.doprint`. + + dims : dict + Alternative to ``broadcastables`` argument. Mapping from elements of + ``inputs`` to integers indicating the dimension of their associated + arrays/tensors. Overrides ``broadcastables`` argument if given. + + dim : int + Another alternative to the ``broadcastables`` argument. Common number of + dimensions to use for all arrays/tensors. + ``aesara_function([x, y], [...], dim=2)`` is equivalent to using + ``broadcastables={x: (False, False), y: (False, False)}``. + + Returns + ======= + callable + A callable object which takes values of ``inputs`` as positional + arguments and returns an output array for each of the expressions + in ``outputs``. If ``outputs`` is a single expression the function will + return a Numpy array, if it is a list of multiple expressions the + function will return a list of arrays. See description of the ``squeeze`` + argument above for the behavior when a single output is passed in a list. + The returned object will either be an instance of + ``aesara.compile.function.types.Function`` or a Python wrapper + function around one. In both cases, the returned value will have a + ``aesara_function`` attribute which points to the return value of + ``aesara.function``. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.printing.aesaracode import aesara_function + + A simple function with one input and one output: + + >>> f1 = aesara_function([x], [x**2 - 1], scalar=True) + >>> f1(3) + 8.0 + + A function with multiple inputs and one output: + + >>> f2 = aesara_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True) + >>> f2(3, 4, 2) + 5.0 + + A function with multiple inputs and multiple outputs: + + >>> f3 = aesara_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True) + >>> f3(2, 3) + [13.0, -5.0] + + See also + ======== + + dim_handling + + """ + sympy_deprecation_warning( + """ + The aesara_function function is deprecated. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-aesaraprinter', + ) + + if not aesara: + raise ImportError("Aesara is required for aesara_function") + + # Pop off non-aesara keyword args + cache = kwargs.pop('cache', {}) + dtypes = kwargs.pop('dtypes', {}) + + broadcastables = dim_handling( + inputs, dim=dim, dims=dims, broadcastables=broadcastables, + ) + + # Print inputs/outputs + code = partial(aesara_code, cache=cache, dtypes=dtypes, + broadcastables=broadcastables) + tinputs = list(map(code, inputs)) + toutputs = list(map(code, outputs)) + + #fix constant expressions as variables + toutputs = [output if isinstance(output, aesara.graph.basic.Variable) else aet.as_tensor_variable(output) for output in toutputs] + + if len(toutputs) == 1: + toutputs = toutputs[0] + + # Compile aesara func + func = aesara.function(tinputs, toutputs, **kwargs) + + is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs] + + # No wrapper required + if not scalar or not any(is_0d): + func.aesara_function = func + return func + + # Create wrapper to convert 0-dimensional outputs to scalars + def wrapper(*args): + out = func(*args) + # out can be array(1.0) or [array(1.0), array(2.0)] + + if is_sequence(out): + return [o[()] if is_0d[i] else o for i, o in enumerate(out)] + else: + return out[()] + + wrapper.__wrapped__ = func + wrapper.__doc__ = func.__doc__ + wrapper.aesara_function = func + return wrapper diff --git a/phivenv/Lib/site-packages/sympy/printing/c.py b/phivenv/Lib/site-packages/sympy/printing/c.py new file mode 100644 index 0000000000000000000000000000000000000000..34c4b8f021073aeee7672248838557b5fa85fbae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/c.py @@ -0,0 +1,747 @@ +""" +C code printer + +The C89CodePrinter & C99CodePrinter converts single SymPy expressions into +single C expressions, using the functions defined in math.h where possible. + +A complete code generator, which uses ccode extensively, can be found in +sympy.utilities.codegen. The codegen module can be used to generate complete +source code files that are compilable without further modifications. + + +""" + +from __future__ import annotations +from typing import Any + +from functools import wraps +from itertools import chain + +from sympy.core import S +from sympy.core.numbers import equal_valued, Float +from sympy.codegen.ast import ( + Assignment, Pointer, Variable, Declaration, Type, + real, complex_, integer, bool_, float32, float64, float80, + complex64, complex128, intc, value_const, pointer_const, + int8, int16, int32, int64, uint8, uint16, uint32, uint64, untyped, + none +) +from sympy.printing.codeprinter import CodePrinter, requires +from sympy.printing.precedence import precedence, PRECEDENCE +from sympy.sets.fancysets import Range + +# These are defined in the other file so we can avoid importing sympy.codegen +# from the top-level 'import sympy'. Export them here as well. +from sympy.printing.codeprinter import ccode, print_ccode # noqa:F401 + +# dictionary mapping SymPy function to (argument_conditions, C_function). +# Used in C89CodePrinter._print_Function(self) +known_functions_C89 = { + "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")], + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + "exp": "exp", + "log": "log", + "log10": "log10", + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "floor": "floor", + "ceiling": "ceil", + "sqrt": "sqrt", # To enable automatic rewrites +} + +known_functions_C99 = dict(known_functions_C89, **{ + 'exp2': 'exp2', + 'expm1': 'expm1', + 'log2': 'log2', + 'log1p': 'log1p', + 'Cbrt': 'cbrt', + 'hypot': 'hypot', + 'fma': 'fma', + 'loggamma': 'lgamma', + 'erfc': 'erfc', + 'Max': 'fmax', + 'Min': 'fmin', + "asinh": "asinh", + "acosh": "acosh", + "atanh": "atanh", + "erf": "erf", + "gamma": "tgamma", +}) + +# These are the core reserved words in the C language. Taken from: +# https://en.cppreference.com/w/c/keyword + +reserved_words = [ + 'auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do', + 'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', 'int', + 'long', 'register', 'return', 'short', 'signed', 'sizeof', 'static', + 'struct', 'entry', # never standardized, we'll leave it here anyway + 'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile', 'while' +] + +reserved_words_c99 = ['inline', 'restrict'] + +def get_math_macros(): + """ Returns a dictionary with math-related macros from math.h/cmath + + Note that these macros are not strictly required by the C/C++-standard. + For MSVC they are enabled by defining "_USE_MATH_DEFINES" (preferably + via a compilation flag). + + Returns + ======= + + Dictionary mapping SymPy expressions to strings (macro names) + + """ + from sympy.codegen.cfunctions import log2, Sqrt + from sympy.functions.elementary.exponential import log + from sympy.functions.elementary.miscellaneous import sqrt + + return { + S.Exp1: 'M_E', + log2(S.Exp1): 'M_LOG2E', + 1/log(2): 'M_LOG2E', + log(2): 'M_LN2', + log(10): 'M_LN10', + S.Pi: 'M_PI', + S.Pi/2: 'M_PI_2', + S.Pi/4: 'M_PI_4', + 1/S.Pi: 'M_1_PI', + 2/S.Pi: 'M_2_PI', + 2/sqrt(S.Pi): 'M_2_SQRTPI', + 2/Sqrt(S.Pi): 'M_2_SQRTPI', + sqrt(2): 'M_SQRT2', + Sqrt(2): 'M_SQRT2', + 1/sqrt(2): 'M_SQRT1_2', + 1/Sqrt(2): 'M_SQRT1_2' + } + + +def _as_macro_if_defined(meth): + """ Decorator for printer methods + + When a Printer's method is decorated using this decorator the expressions printed + will first be looked for in the attribute ``math_macros``, and if present it will + print the macro name in ``math_macros`` followed by a type suffix for the type + ``real``. e.g. printing ``sympy.pi`` would print ``M_PIl`` if real is mapped to float80. + + """ + @wraps(meth) + def _meth_wrapper(self, expr, **kwargs): + if expr in self.math_macros: + return '%s%s' % (self.math_macros[expr], self._get_math_macro_suffix(real)) + else: + return meth(self, expr, **kwargs) + + return _meth_wrapper + + +class C89CodePrinter(CodePrinter): + """A printer to convert Python expressions to strings of C code""" + printmethod = "_ccode" + language = "C" + standard = "C89" + reserved_words = set(reserved_words) + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'dereference': set(), + 'error_on_reserved': False, + }) + + type_aliases = { + real: float64, + complex_: complex128, + integer: intc + } + + type_mappings: dict[Type, Any] = { + real: 'double', + intc: 'int', + float32: 'float', + float64: 'double', + integer: 'int', + bool_: 'bool', + int8: 'int8_t', + int16: 'int16_t', + int32: 'int32_t', + int64: 'int64_t', + uint8: 'int8_t', + uint16: 'int16_t', + uint32: 'int32_t', + uint64: 'int64_t', + } + + type_headers = { + bool_: {'stdbool.h'}, + int8: {'stdint.h'}, + int16: {'stdint.h'}, + int32: {'stdint.h'}, + int64: {'stdint.h'}, + uint8: {'stdint.h'}, + uint16: {'stdint.h'}, + uint32: {'stdint.h'}, + uint64: {'stdint.h'}, + } + + # Macros needed to be defined when using a Type + type_macros: dict[Type, tuple[str, ...]] = {} + + type_func_suffixes = { + float32: 'f', + float64: '', + float80: 'l' + } + + type_literal_suffixes = { + float32: 'F', + float64: '', + float80: 'L' + } + + type_math_macro_suffixes = { + float80: 'l' + } + + math_macros = None + + _ns = '' # namespace, C++ uses 'std::' + # known_functions-dict to copy + _kf: dict[str, Any] = known_functions_C89 + + def __init__(self, settings=None): + settings = settings or {} + if self.math_macros is None: + self.math_macros = settings.pop('math_macros', get_math_macros()) + self.type_aliases = dict(chain(self.type_aliases.items(), + settings.pop('type_aliases', {}).items())) + self.type_mappings = dict(chain(self.type_mappings.items(), + settings.pop('type_mappings', {}).items())) + self.type_headers = dict(chain(self.type_headers.items(), + settings.pop('type_headers', {}).items())) + self.type_macros = dict(chain(self.type_macros.items(), + settings.pop('type_macros', {}).items())) + self.type_func_suffixes = dict(chain(self.type_func_suffixes.items(), + settings.pop('type_func_suffixes', {}).items())) + self.type_literal_suffixes = dict(chain(self.type_literal_suffixes.items(), + settings.pop('type_literal_suffixes', {}).items())) + self.type_math_macro_suffixes = dict(chain(self.type_math_macro_suffixes.items(), + settings.pop('type_math_macro_suffixes', {}).items())) + super().__init__(settings) + self.known_functions = dict(self._kf, **settings.get('user_functions', {})) + self._dereference = set(settings.get('dereference', [])) + self.headers = set() + self.libraries = set() + self.macros = set() + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + """ Get code string as a statement - i.e. ending with a semicolon. """ + return codestring if codestring.endswith(';') else codestring + ';' + + def _get_comment(self, text): + return "/* {} */".format(text) + + def _declare_number_const(self, name, value): + type_ = self.type_aliases[real] + var = Variable(name, type=type_, value=value.evalf(type_.decimal_dig), attrs={value_const}) + decl = Declaration(var) + return self._get_statement(self._print(decl)) + + def _format_code(self, lines): + return self.indent_code(lines) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for i in range(rows) for j in range(cols)) + + @_as_macro_if_defined + def _print_Mul(self, expr, **kwargs): + return super()._print_Mul(expr, **kwargs) + + @_as_macro_if_defined + def _print_Pow(self, expr): + if "Pow" in self.known_functions: + return self._print_Function(expr) + PREC = precedence(expr) + suffix = self._get_func_suffix(real) + if equal_valued(expr.exp, -1): + return '%s/%s' % (self._print_Float(Float(1.0)), self.parenthesize(expr.base, PREC)) + elif equal_valued(expr.exp, 0.5): + return '%ssqrt%s(%s)' % (self._ns, suffix, self._print(expr.base)) + elif expr.exp == S.One/3 and self.standard != 'C89': + return '%scbrt%s(%s)' % (self._ns, suffix, self._print(expr.base)) + else: + return '%spow%s(%s, %s)' % (self._ns, suffix, self._print(expr.base), + self._print(expr.exp)) + + def _print_Mod(self, expr): + num, den = expr.args + if num.is_integer and den.is_integer: + PREC = precedence(expr) + snum, sden = [self.parenthesize(arg, PREC) for arg in expr.args] + # % is remainder (same sign as numerator), not modulo (same sign as + # denominator), in C. Hence, % only works as modulo if both numbers + # have the same sign + if (num.is_nonnegative and den.is_nonnegative or + num.is_nonpositive and den.is_nonpositive): + return f"{snum} % {sden}" + return f"(({snum} % {sden}) + {sden}) % {sden}" + # Not guaranteed integer + return self._print_math_func(expr, known='fmod') + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + suffix = self._get_literal_suffix(real) + return '%d.0%s/%d.0%s' % (p, suffix, q, suffix) + + def _print_Indexed(self, expr): + # calculate index for 1d array + offset = getattr(expr.base, 'offset', S.Zero) + strides = getattr(expr.base, 'strides', None) + indices = expr.indices + + if strides is None or isinstance(strides, str): + dims = expr.shape + shift = S.One + temp = () + if strides == 'C' or strides is None: + traversal = reversed(range(expr.rank)) + indices = indices[::-1] + elif strides == 'F': + traversal = range(expr.rank) + + for i in traversal: + temp += (shift,) + shift *= dims[i] + strides = temp + flat_index = sum(x[0]*x[1] for x in zip(indices, strides)) + offset + return "%s[%s]" % (self._print(expr.base.label), + self._print(flat_index)) + + @_as_macro_if_defined + def _print_NumberSymbol(self, expr): + return super()._print_NumberSymbol(expr) + + def _print_Infinity(self, expr): + return 'HUGE_VAL' + + def _print_NegativeInfinity(self, expr): + return '-HUGE_VAL' + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if expr.has(Assignment): + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) {" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else {") + else: + lines.append("else if (%s) {" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + lines.append("}") + return "\n".join(lines) + else: + # The piecewise was used in an expression, need to do inline + # operators. This has the downside that inline operators will + # not work for statements that span multiple lines (Matrix or + # Indexed expressions). + ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c), + self._print(e)) + for e, c in expr.args[:-1]] + last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr) + return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)]) + + def _print_ITE(self, expr): + from sympy.functions import Piecewise + return self._print(expr.rewrite(Piecewise, deep=False)) + + def _print_MatrixElement(self, expr): + return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"], + strict=True), expr.j + expr.i*expr.parent.shape[1]) + + def _print_Symbol(self, expr): + name = super()._print_Symbol(expr) + if expr in self._settings['dereference']: + return '(*{})'.format(name) + else: + return name + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_For(self, expr): + target = self._print(expr.target) + if isinstance(expr.iterable, Range): + start, stop, step = expr.iterable.args + else: + raise NotImplementedError("Only iterable currently supported is Range") + body = self._print(expr.body) + return ('for ({target} = {start}; {target} < {stop}; {target} += ' + '{step}) {{\n{body}\n}}').format(target=target, start=start, + stop=stop, step=step, body=body) + + def _print_sign(self, func): + return '((({0}) > 0) - (({0}) < 0))'.format(self._print(func.args[0])) + + def _print_Max(self, expr): + if "Max" in self.known_functions: + return self._print_Function(expr) + def inner_print_max(args): # The more natural abstraction of creating + if len(args) == 1: # and printing smaller Max objects is slow + return self._print(args[0]) # when there are many arguments. + half = len(args) // 2 + return "((%(a)s > %(b)s) ? %(a)s : %(b)s)" % { + 'a': inner_print_max(args[:half]), + 'b': inner_print_max(args[half:]) + } + return inner_print_max(expr.args) + + def _print_Min(self, expr): + if "Min" in self.known_functions: + return self._print_Function(expr) + def inner_print_min(args): # The more natural abstraction of creating + if len(args) == 1: # and printing smaller Min objects is slow + return self._print(args[0]) # when there are many arguments. + half = len(args) // 2 + return "((%(a)s < %(b)s) ? %(a)s : %(b)s)" % { + 'a': inner_print_min(args[:half]), + 'b': inner_print_min(args[half:]) + } + return inner_print_min(expr.args) + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [line.lstrip(' \t') for line in code] + + increase = [int(any(map(line.endswith, inc_token))) for line in code] + decrease = [int(any(map(line.startswith, dec_token))) for line in code] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + def _get_func_suffix(self, type_): + return self.type_func_suffixes[self.type_aliases.get(type_, type_)] + + def _get_literal_suffix(self, type_): + return self.type_literal_suffixes[self.type_aliases.get(type_, type_)] + + def _get_math_macro_suffix(self, type_): + alias = self.type_aliases.get(type_, type_) + dflt = self.type_math_macro_suffixes.get(alias, '') + return self.type_math_macro_suffixes.get(type_, dflt) + + def _print_Tuple(self, expr): + return '{'+', '.join(self._print(e) for e in expr)+'}' + + _print_List = _print_Tuple + + def _print_Type(self, type_): + self.headers.update(self.type_headers.get(type_, set())) + self.macros.update(self.type_macros.get(type_, set())) + return self._print(self.type_mappings.get(type_, type_.name)) + + def _print_Declaration(self, decl): + from sympy.codegen.cnodes import restrict + var = decl.variable + val = var.value + if var.type == untyped: + raise ValueError("C does not support untyped variables") + + if isinstance(var, Pointer): + result = '{vc}{t} *{pc} {r}{s}'.format( + vc='const ' if value_const in var.attrs else '', + t=self._print(var.type), + pc=' const' if pointer_const in var.attrs else '', + r='restrict ' if restrict in var.attrs else '', + s=self._print(var.symbol) + ) + elif isinstance(var, Variable): + result = '{vc}{t} {s}'.format( + vc='const ' if value_const in var.attrs else '', + t=self._print(var.type), + s=self._print(var.symbol) + ) + else: + raise NotImplementedError("Unknown type of var: %s" % type(var)) + if val != None: # Must be "!= None", cannot be "is not None" + result += ' = %s' % self._print(val) + return result + + def _print_Float(self, flt): + type_ = self.type_aliases.get(real, real) + self.macros.update(self.type_macros.get(type_, set())) + suffix = self._get_literal_suffix(type_) + num = str(flt.evalf(type_.decimal_dig)) + if 'e' not in num and '.' not in num: + num += '.0' + num_parts = num.split('e') + num_parts[0] = num_parts[0].rstrip('0') + if num_parts[0].endswith('.'): + num_parts[0] += '0' + return 'e'.join(num_parts) + suffix + + @requires(headers={'stdbool.h'}) + def _print_BooleanTrue(self, expr): + return 'true' + + @requires(headers={'stdbool.h'}) + def _print_BooleanFalse(self, expr): + return 'false' + + def _print_Element(self, elem): + if elem.strides == None: # Must be "== None", cannot be "is None" + if elem.offset != None: # Must be "!= None", cannot be "is not None" + raise ValueError("Expected strides when offset is given") + idxs = ']['.join((self._print(arg) for arg in elem.indices)) + else: + global_idx = sum(i*s for i, s in zip(elem.indices, elem.strides)) + if elem.offset != None: # Must be "!= None", cannot be "is not None" + global_idx += elem.offset + idxs = self._print(global_idx) + + return "{symb}[{idxs}]".format( + symb=self._print(elem.symbol), + idxs=idxs + ) + + def _print_CodeBlock(self, expr): + """ Elements of code blocks printed as statements. """ + return '\n'.join([self._get_statement(self._print(i)) for i in expr.args]) + + def _print_While(self, expr): + return 'while ({condition}) {{\n{body}\n}}'.format(**expr.kwargs( + apply=lambda arg: self._print(arg))) + + def _print_Scope(self, expr): + return '{\n%s\n}' % self._print_CodeBlock(expr.body) + + @requires(headers={'stdio.h'}) + def _print_Print(self, expr): + if expr.file == none: + template = 'printf({fmt}, {pargs})' + else: + template = 'fprintf(%(out)s, {fmt}, {pargs})' % { + 'out': self._print(expr.file) + } + return template.format( + fmt="%s\n" if expr.format_string == none else self._print(expr.format_string), + pargs=', '.join((self._print(arg) for arg in expr.print_args)) + ) + + def _print_Stream(self, strm): + return strm.name + + def _print_FunctionPrototype(self, expr): + pars = ', '.join((self._print(Declaration(arg)) for arg in expr.parameters)) + return "%s %s(%s)" % ( + tuple((self._print(arg) for arg in (expr.return_type, expr.name))) + (pars,) + ) + + def _print_FunctionDefinition(self, expr): + return "%s%s" % (self._print_FunctionPrototype(expr), + self._print_Scope(expr)) + + def _print_Return(self, expr): + arg, = expr.args + return 'return %s' % self._print(arg) + + def _print_CommaOperator(self, expr): + return '(%s)' % ', '.join((self._print(arg) for arg in expr.args)) + + def _print_Label(self, expr): + if expr.body == none: + return '%s:' % str(expr.name) + if len(expr.body.args) == 1: + return '%s:\n%s' % (str(expr.name), self._print_CodeBlock(expr.body)) + return '%s:\n{\n%s\n}' % (str(expr.name), self._print_CodeBlock(expr.body)) + + def _print_goto(self, expr): + return 'goto %s' % expr.label.name + + def _print_PreIncrement(self, expr): + arg, = expr.args + return '++(%s)' % self._print(arg) + + def _print_PostIncrement(self, expr): + arg, = expr.args + return '(%s)++' % self._print(arg) + + def _print_PreDecrement(self, expr): + arg, = expr.args + return '--(%s)' % self._print(arg) + + def _print_PostDecrement(self, expr): + arg, = expr.args + return '(%s)--' % self._print(arg) + + def _print_struct(self, expr): + return "%(keyword)s %(name)s {\n%(lines)s}" % { + "keyword": expr.__class__.__name__, "name": expr.name, "lines": ';\n'.join( + [self._print(decl) for decl in expr.declarations] + ['']) + } + + def _print_BreakToken(self, _): + return 'break' + + def _print_ContinueToken(self, _): + return 'continue' + + _print_union = _print_struct + +class C99CodePrinter(C89CodePrinter): + standard = 'C99' + reserved_words = set(reserved_words + reserved_words_c99) + type_mappings=dict(chain(C89CodePrinter.type_mappings.items(), { + complex64: 'float complex', + complex128: 'double complex', + }.items())) + type_headers = dict(chain(C89CodePrinter.type_headers.items(), { + complex64: {'complex.h'}, + complex128: {'complex.h'} + }.items())) + + # known_functions-dict to copy + _kf: dict[str, Any] = known_functions_C99 + + # functions with versions with 'f' and 'l' suffixes: + _prec_funcs = ('fabs fmod remainder remquo fma fmax fmin fdim nan exp exp2' + ' expm1 log log10 log2 log1p pow sqrt cbrt hypot sin cos tan' + ' asin acos atan atan2 sinh cosh tanh asinh acosh atanh erf' + ' erfc tgamma lgamma ceil floor trunc round nearbyint rint' + ' frexp ldexp modf scalbn ilogb logb nextafter copysign').split() + + def _print_Infinity(self, expr): + return 'INFINITY' + + def _print_NegativeInfinity(self, expr): + return '-INFINITY' + + def _print_NaN(self, expr): + return 'NAN' + + # tgamma was already covered by 'known_functions' dict + + @requires(headers={'math.h'}, libraries={'m'}) + @_as_macro_if_defined + def _print_math_func(self, expr, nest=False, known=None): + if known is None: + known = self.known_functions[expr.__class__.__name__] + if not isinstance(known, str): + for cb, name in known: + if cb(*expr.args): + known = name + break + else: + raise ValueError("No matching printer") + try: + return known(self, *expr.args) + except TypeError: + suffix = self._get_func_suffix(real) if self._ns + known in self._prec_funcs else '' + + if nest: + args = self._print(expr.args[0]) + if len(expr.args) > 1: + paren_pile = '' + for curr_arg in expr.args[1:-1]: + paren_pile += ')' + args += ', {ns}{name}{suffix}({next}'.format( + ns=self._ns, + name=known, + suffix=suffix, + next = self._print(curr_arg) + ) + args += ', %s%s' % ( + self._print(expr.func(expr.args[-1])), + paren_pile + ) + else: + args = ', '.join((self._print(arg) for arg in expr.args)) + return '{ns}{name}{suffix}({args})'.format( + ns=self._ns, + name=known, + suffix=suffix, + args=args + ) + + def _print_Max(self, expr): + return self._print_math_func(expr, nest=True) + + def _print_Min(self, expr): + return self._print_math_func(expr, nest=True) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + loopstart = "for (int %(var)s=%(start)s; %(var)s<%(end)s; %(var)s++){" # C99 + for i in indices: + # C arrays start at 0 and end at dimension-1 + open_lines.append(loopstart % { + 'var': self._print(i.label), + 'start': self._print(i.lower), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + +for k in ('Abs Sqrt exp exp2 expm1 log log10 log2 log1p Cbrt hypot fma' + ' loggamma sin cos tan asin acos atan atan2 sinh cosh tanh asinh acosh ' + 'atanh erf erfc loggamma gamma ceiling floor').split(): + setattr(C99CodePrinter, '_print_%s' % k, C99CodePrinter._print_math_func) + + +class C11CodePrinter(C99CodePrinter): + + @requires(headers={'stdalign.h'}) + def _print_alignof(self, expr): + arg, = expr.args + return 'alignof(%s)' % self._print(arg) + + +c_code_printers = { + 'c89': C89CodePrinter, + 'c99': C99CodePrinter, + 'c11': C11CodePrinter +} diff --git a/phivenv/Lib/site-packages/sympy/printing/codeprinter.py b/phivenv/Lib/site-packages/sympy/printing/codeprinter.py new file mode 100644 index 0000000000000000000000000000000000000000..1faaa0f054cbd8ff438b90e914808f720d2da90a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/codeprinter.py @@ -0,0 +1,1039 @@ +from __future__ import annotations +from typing import Any + +from functools import wraps + +from sympy.core import Add, Mul, Pow, S, sympify, Float +from sympy.core.basic import Basic +from sympy.core.expr import Expr, UnevaluatedExpr +from sympy.core.function import Lambda +from sympy.core.mul import _keep_coeff +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import re +from sympy.printing.str import StrPrinter +from sympy.printing.precedence import precedence, PRECEDENCE + + +class requires: + """ Decorator for registering requirements on print methods. """ + def __init__(self, **kwargs): + self._req = kwargs + + def __call__(self, method): + def _method_wrapper(self_, *args, **kwargs): + for k, v in self._req.items(): + getattr(self_, k).update(v) + return method(self_, *args, **kwargs) + return wraps(method)(_method_wrapper) + + +class AssignmentError(Exception): + """ + Raised if an assignment variable for a loop is missing. + """ + pass + +class PrintMethodNotImplementedError(NotImplementedError): + """ + Raised if a _print_* method is missing in the Printer. + """ + pass + +def _convert_python_lists(arg): + if isinstance(arg, list): + from sympy.codegen.abstract_nodes import List + return List(*(_convert_python_lists(e) for e in arg)) + elif isinstance(arg, tuple): + return tuple(_convert_python_lists(e) for e in arg) + else: + return arg + + +class CodePrinter(StrPrinter): + """ + The base class for code-printing subclasses. + """ + + _operators = { + 'and': '&&', + 'or': '||', + 'not': '!', + } + + _default_settings: dict[str, Any] = { + 'order': None, + 'full_prec': 'auto', + 'error_on_reserved': False, + 'reserved_word_suffix': '_', + 'human': True, + 'inline': False, + 'allow_unknown_functions': False, + 'strict': None # True or False; None => True if human == True + } + + # Functions which are "simple" to rewrite to other functions that + # may be supported + # function_to_rewrite : (function_to_rewrite_to, iterable_with_other_functions_required) + _rewriteable_functions = { + 'cot': ('tan', []), + 'csc': ('sin', []), + 'sec': ('cos', []), + 'acot': ('atan', []), + 'acsc': ('asin', []), + 'asec': ('acos', []), + 'coth': ('exp', []), + 'csch': ('exp', []), + 'sech': ('exp', []), + 'acoth': ('log', []), + 'acsch': ('log', []), + 'asech': ('log', []), + 'catalan': ('gamma', []), + 'fibonacci': ('sqrt', []), + 'lucas': ('sqrt', []), + 'beta': ('gamma', []), + 'sinc': ('sin', ['Piecewise']), + 'Mod': ('floor', []), + 'factorial': ('gamma', []), + 'factorial2': ('gamma', ['Piecewise']), + 'subfactorial': ('uppergamma', []), + 'RisingFactorial': ('gamma', ['Piecewise']), + 'FallingFactorial': ('gamma', ['Piecewise']), + 'binomial': ('gamma', []), + 'frac': ('floor', []), + 'Max': ('Piecewise', []), + 'Min': ('Piecewise', []), + 'Heaviside': ('Piecewise', []), + 'erf2': ('erf', []), + 'erfc': ('erf', []), + 'Li': ('li', []), + 'Ei': ('li', []), + 'dirichlet_eta': ('zeta', []), + 'riemann_xi': ('zeta', ['gamma']), + 'SingularityFunction': ('Piecewise', []), + } + + def __init__(self, settings=None): + super().__init__(settings=settings) + if self._settings.get('strict', True) == None: + # for backwards compatibility, human=False need not to throw: + self._settings['strict'] = self._settings.get('human', True) == True + if not hasattr(self, 'reserved_words'): + self.reserved_words = set() + + def _handle_UnevaluatedExpr(self, expr): + return expr.replace(re, lambda arg: arg if isinstance( + arg, UnevaluatedExpr) and arg.args[0].is_real else re(arg)) + + def doprint(self, expr, assign_to=None): + """ + Print the expression as code. + + Parameters + ---------- + expr : Expression + The expression to be printed. + + assign_to : Symbol, string, MatrixSymbol, list of strings or Symbols (optional) + If provided, the printed code will set the expression to a variable or multiple variables + with the name or names given in ``assign_to``. + """ + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.codegen.ast import CodeBlock, Assignment + + def _handle_assign_to(expr, assign_to): + if assign_to is None: + return sympify(expr) + if isinstance(assign_to, (list, tuple)): + if len(expr) != len(assign_to): + raise ValueError('Failed to assign an expression of length {} to {} variables'.format(len(expr), len(assign_to))) + return CodeBlock(*[_handle_assign_to(lhs, rhs) for lhs, rhs in zip(expr, assign_to)]) + if isinstance(assign_to, str): + if expr.is_Matrix: + assign_to = MatrixSymbol(assign_to, *expr.shape) + else: + assign_to = Symbol(assign_to) + elif not isinstance(assign_to, Basic): + raise TypeError("{} cannot assign to object of type {}".format( + type(self).__name__, type(assign_to))) + return Assignment(assign_to, expr) + + expr = _convert_python_lists(expr) + expr = _handle_assign_to(expr, assign_to) + + # Remove re(...) nodes due to UnevaluatedExpr.is_real always is None: + expr = self._handle_UnevaluatedExpr(expr) + + # keep a set of expressions that are not strictly translatable to Code + # and number constants that must be declared and initialized + self._not_supported = set() + self._number_symbols = set() + + lines = self._print(expr).splitlines() + + # format the output + if self._settings["human"]: + frontlines = [] + if self._not_supported: + frontlines.append(self._get_comment( + "Not supported in {}:".format(self.language))) + for expr in sorted(self._not_supported, key=str): + frontlines.append(self._get_comment(type(expr).__name__)) + for name, value in sorted(self._number_symbols, key=str): + frontlines.append(self._declare_number_const(name, value)) + lines = frontlines + lines + lines = self._format_code(lines) + result = "\n".join(lines) + else: + lines = self._format_code(lines) + num_syms = {(k, self._print(v)) for k, v in self._number_symbols} + result = (num_syms, self._not_supported, "\n".join(lines)) + self._not_supported = set() + self._number_symbols = set() + return result + + def _doprint_loops(self, expr, assign_to=None): + # Here we print an expression that contains Indexed objects, they + # correspond to arrays in the generated code. The low-level implementation + # involves looping over array elements and possibly storing results in temporary + # variables or accumulate it in the assign_to object. + + if self._settings.get('contract', True): + from sympy.tensor import get_contraction_structure + # Setup loops over non-dummy indices -- all terms need these + indices = self._get_expression_indices(expr, assign_to) + # Setup loops over dummy indices -- each term needs separate treatment + dummies = get_contraction_structure(expr) + else: + indices = [] + dummies = {None: (expr,)} + openloop, closeloop = self._get_loop_opening_ending(indices) + + # terms with no summations first + if None in dummies: + text = StrPrinter.doprint(self, Add(*dummies[None])) + else: + # If all terms have summations we must initialize array to Zero + text = StrPrinter.doprint(self, 0) + + # skip redundant assignments (where lhs == rhs) + lhs_printed = self._print(assign_to) + lines = [] + if text != lhs_printed: + lines.extend(openloop) + if assign_to is not None: + text = self._get_statement("%s = %s" % (lhs_printed, text)) + lines.append(text) + lines.extend(closeloop) + + # then terms with summations + for d in dummies: + if isinstance(d, tuple): + indices = self._sort_optimized(d, expr) + openloop_d, closeloop_d = self._get_loop_opening_ending( + indices) + + for term in dummies[d]: + if term in dummies and not ([list(f.keys()) for f in dummies[term]] + == [[None] for f in dummies[term]]): + # If one factor in the term has it's own internal + # contractions, those must be computed first. + # (temporary variables?) + raise NotImplementedError( + "FIXME: no support for contractions in factor yet") + else: + + # We need the lhs expression as an accumulator for + # the loops, i.e + # + # for (int d=0; d < dim; d++){ + # lhs[] = lhs[] + term[][d] + # } ^.................. the accumulator + # + # We check if the expression already contains the + # lhs, and raise an exception if it does, as that + # syntax is currently undefined. FIXME: What would be + # a good interpretation? + if assign_to is None: + raise AssignmentError( + "need assignment variable for loops") + if term.has(assign_to): + raise ValueError("FIXME: lhs present in rhs,\ + this is undefined in CodePrinter") + + lines.extend(openloop) + lines.extend(openloop_d) + text = "%s = %s" % (lhs_printed, StrPrinter.doprint( + self, assign_to + term)) + lines.append(self._get_statement(text)) + lines.extend(closeloop_d) + lines.extend(closeloop) + + return "\n".join(lines) + + def _get_expression_indices(self, expr, assign_to): + from sympy.tensor import get_indices + rinds, junk = get_indices(expr) + linds, junk = get_indices(assign_to) + + # support broadcast of scalar + if linds and not rinds: + rinds = linds + if rinds != linds: + raise ValueError("lhs indices must match non-dummy" + " rhs indices in %s" % expr) + + return self._sort_optimized(rinds, assign_to) + + def _sort_optimized(self, indices, expr): + + from sympy.tensor.indexed import Indexed + + if not indices: + return [] + + # determine optimized loop order by giving a score to each index + # the index with the highest score are put in the innermost loop. + score_table = {} + for i in indices: + score_table[i] = 0 + + arrays = expr.atoms(Indexed) + for arr in arrays: + for p, ind in enumerate(arr.indices): + try: + score_table[ind] += self._rate_index_position(p) + except KeyError: + pass + + return sorted(indices, key=lambda x: score_table[x]) + + def _rate_index_position(self, p): + """function to calculate score based on position among indices + + This method is used to sort loops in an optimized order, see + CodePrinter._sort_optimized() + """ + raise NotImplementedError("This function must be implemented by " + "subclass of CodePrinter.") + + def _get_statement(self, codestring): + """Formats a codestring with the proper line ending.""" + raise NotImplementedError("This function must be implemented by " + "subclass of CodePrinter.") + + def _get_comment(self, text): + """Formats a text string as a comment.""" + raise NotImplementedError("This function must be implemented by " + "subclass of CodePrinter.") + + def _declare_number_const(self, name, value): + """Declare a numeric constant at the top of a function""" + raise NotImplementedError("This function must be implemented by " + "subclass of CodePrinter.") + + def _format_code(self, lines): + """Take in a list of lines of code, and format them accordingly. + + This may include indenting, wrapping long lines, etc...""" + raise NotImplementedError("This function must be implemented by " + "subclass of CodePrinter.") + + def _get_loop_opening_ending(self, indices): + """Returns a tuple (open_lines, close_lines) containing lists + of codelines""" + raise NotImplementedError("This function must be implemented by " + "subclass of CodePrinter.") + + def _print_Dummy(self, expr): + if expr.name.startswith('Dummy_'): + return '_' + expr.name + else: + return '%s_%d' % (expr.name, expr.dummy_index) + + def _print_Idx(self, expr): + return self._print(expr.label) + + def _print_CodeBlock(self, expr): + return '\n'.join([self._print(i) for i in expr.args]) + + def _print_String(self, string): + return str(string) + + def _print_QuotedString(self, arg): + return '"%s"' % arg.text + + def _print_Comment(self, string): + return self._get_comment(str(string)) + + def _print_Assignment(self, expr): + from sympy.codegen.ast import Assignment + from sympy.functions.elementary.piecewise import Piecewise + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.tensor.indexed import IndexedBase + lhs = expr.lhs + rhs = expr.rhs + # We special case assignments that take multiple lines + if isinstance(expr.rhs, Piecewise): + # Here we modify Piecewise so each expression is now + # an Assignment, and then continue on the print. + expressions = [] + conditions = [] + for (e, c) in rhs.args: + expressions.append(Assignment(lhs, e)) + conditions.append(c) + temp = Piecewise(*zip(expressions, conditions)) + return self._print(temp) + elif isinstance(lhs, MatrixSymbol): + # Here we form an Assignment for each element in the array, + # printing each one. + lines = [] + for (i, j) in self._traverse_matrix_indices(lhs): + temp = Assignment(lhs[i, j], rhs[i, j]) + code0 = self._print(temp) + lines.append(code0) + return "\n".join(lines) + elif self._settings.get("contract", False) and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + def _print_AugmentedAssignment(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + return self._get_statement("{} {} {}".format( + *(self._print(arg) for arg in [lhs_code, expr.op, rhs_code]))) + + def _print_FunctionCall(self, expr): + return '%s(%s)' % ( + expr.name, + ', '.join((self._print(arg) for arg in expr.function_args))) + + def _print_Variable(self, expr): + return self._print(expr.symbol) + + def _print_Symbol(self, expr): + name = super()._print_Symbol(expr) + + if name in self.reserved_words: + if self._settings['error_on_reserved']: + msg = ('This expression includes the symbol "{}" which is a ' + 'reserved keyword in this language.') + raise ValueError(msg.format(name)) + return name + self._settings['reserved_word_suffix'] + else: + return name + + def _can_print(self, name): + """ Check if function ``name`` is either a known function or has its own + printing method. Used to check if rewriting is possible.""" + return name in self.known_functions or getattr(self, '_print_{}'.format(name), False) + + def _print_Function(self, expr): + if expr.func.__name__ in self.known_functions: + cond_func = self.known_functions[expr.func.__name__] + if isinstance(cond_func, str): + return "%s(%s)" % (cond_func, self.stringify(expr.args, ", ")) + else: + for cond, func in cond_func: + if cond(*expr.args): + break + if func is not None: + try: + return func(*[self.parenthesize(item, 0) for item in expr.args]) + except TypeError: + return "%s(%s)" % (func, self.stringify(expr.args, ", ")) + elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda): + # inlined function + return self._print(expr._imp_(*expr.args)) + elif expr.func.__name__ in self._rewriteable_functions: + # Simple rewrite to supported function possible + target_f, required_fs = self._rewriteable_functions[expr.func.__name__] + if self._can_print(target_f) and all(self._can_print(f) for f in required_fs): + return '(' + self._print(expr.rewrite(target_f)) + ')' + + if expr.is_Function and self._settings.get('allow_unknown_functions', False): + return '%s(%s)' % (self._print(expr.func), ', '.join(map(self._print, expr.args))) + else: + return self._print_not_supported(expr) + + _print_Expr = _print_Function + + def _print_Derivative(self, expr): + obj, *wrt_order_pairs = expr.args + for func_arg in obj.args: + if not func_arg.is_Symbol: + raise ValueError("%s._print_Derivative(...) only supports functions with symbols as arguments." % + self.__class__.__name__) + meth_name = '_print_Derivative_%s' % obj.func.__name__ + pmeth = getattr(self, meth_name, None) + if pmeth is None: + if self._settings.get('strict', False): + raise PrintMethodNotImplementedError( + f"Unsupported by {type(self)}: {type(expr)}" + + f"\nPrinter has no method: {meth_name}" + + "\nSet the printer option 'strict' to False in order to generate partially printed code." + ) + return self._print_not_supported(expr) + orders = dict(wrt_order_pairs) + seq_orders = [orders[arg] for arg in obj.args] + return pmeth(obj.args, seq_orders) + + # Don't inherit the str-printer method for Heaviside to the code printers + _print_Heaviside = None + + def _print_NumberSymbol(self, expr): + if self._settings.get("inline", False): + return self._print(Float(expr.evalf(self._settings["precision"]))) + else: + # A Number symbol that is not implemented here or with _printmethod + # is registered and evaluated + self._number_symbols.add((expr, + Float(expr.evalf(self._settings["precision"])))) + return str(expr) + + def _print_Catalan(self, expr): + return self._print_NumberSymbol(expr) + def _print_EulerGamma(self, expr): + return self._print_NumberSymbol(expr) + def _print_GoldenRatio(self, expr): + return self._print_NumberSymbol(expr) + def _print_TribonacciConstant(self, expr): + return self._print_NumberSymbol(expr) + def _print_Exp1(self, expr): + return self._print_NumberSymbol(expr) + def _print_Pi(self, expr): + return self._print_NumberSymbol(expr) + + def _print_And(self, expr): + PREC = precedence(expr) + return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC) + for a in sorted(expr.args, key=default_sort_key)) + + def _print_Or(self, expr): + PREC = precedence(expr) + return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC) + for a in sorted(expr.args, key=default_sort_key)) + + def _print_Xor(self, expr): + if self._operators.get('xor') is None: + return self._print(expr.to_nnf()) + PREC = precedence(expr) + return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC) + for a in expr.args) + + def _print_Equivalent(self, expr): + if self._operators.get('equivalent') is None: + return self._print(expr.to_nnf()) + PREC = precedence(expr) + return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC) + for a in expr.args) + + def _print_Not(self, expr): + PREC = precedence(expr) + return self._operators['not'] + self.parenthesize(expr.args[0], PREC) + + def _print_BooleanFunction(self, expr): + return self._print(expr.to_nnf()) + + def _print_isnan(self, arg): + return 'isnan(%s)' % self._print(*arg.args) + + def _print_isinf(self, arg): + return 'isinf(%s)' % self._print(*arg.args) + + def _print_Mul(self, expr): + + prec = precedence(expr) + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = "-" + else: + sign = "" + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + pow_paren = [] # Will collect all pow with more than one base element and exp = -1 + + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + # Gather args for numerator/denominator + for item in args: + if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160 + pow_paren.append(item) + b.append(Pow(item.base, -item.exp)) + else: + a.append(item) + + a = a or [S.One] + + if len(a) == 1 and sign == "-": + # Unary minus does not have a SymPy class, and hence there's no + # precedence weight associated with it, Python's unary minus has + # an operator precedence between multiplication and exponentiation, + # so we use this to compute a weight. + a_str = [self.parenthesize(a[0], 0.5*(PRECEDENCE["Pow"]+PRECEDENCE["Mul"]))] + else: + a_str = [self.parenthesize(x, prec) for x in a] + b_str = [self.parenthesize(x, prec) for x in b] + + # To parenthesize Pow with exp = -1 and having more than one Symbol + for item in pow_paren: + if item.base in b: + b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)] + + if not b: + return sign + '*'.join(a_str) + elif len(b) == 1: + return sign + '*'.join(a_str) + "/" + b_str[0] + else: + return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str) + + def _print_not_supported(self, expr): + if self._settings.get('strict', False): + raise PrintMethodNotImplementedError( + f"Unsupported by {type(self)}: {type(expr)}" + + "\nSet the printer option 'strict' to False in order to generate partially printed code." + ) + try: + self._not_supported.add(expr) + except TypeError: + # not hashable + pass + return self.emptyPrinter(expr) + + # The following can not be simply translated into C or Fortran + _print_Basic = _print_not_supported + _print_ComplexInfinity = _print_not_supported + _print_ExprCondPair = _print_not_supported + _print_GeometryEntity = _print_not_supported + _print_Infinity = _print_not_supported + _print_Integral = _print_not_supported + _print_Interval = _print_not_supported + _print_AccumulationBounds = _print_not_supported + _print_Limit = _print_not_supported + _print_MatrixBase = _print_not_supported + _print_DeferredVector = _print_not_supported + _print_NaN = _print_not_supported + _print_NegativeInfinity = _print_not_supported + _print_Order = _print_not_supported + _print_RootOf = _print_not_supported + _print_RootsOf = _print_not_supported + _print_RootSum = _print_not_supported + _print_Uniform = _print_not_supported + _print_Unit = _print_not_supported + _print_Wild = _print_not_supported + _print_WildFunction = _print_not_supported + _print_Relational = _print_not_supported + + +# Code printer functions. These are included in this file so that they can be +# imported in the top-level __init__.py without importing the sympy.codegen +# module. + +def ccode(expr, assign_to=None, standard='c99', **settings): + """Converts an expr to a string of c code + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of + line-wrapping, or for expressions that generate multi-line statements. + standard : str, optional + String specifying the standard. If your compiler supports a more modern + standard you may set this to 'c99' to allow the printer to use more math + functions. [default='c89']. + precision : integer, optional + The precision for numbers such as pi [default=17]. + user_functions : dict, optional + A dictionary where the keys are string representations of either + ``FunctionClass`` or ``UndefinedFunction`` instances and the values + are their desired C string representations. Alternatively, the + dictionary value can be a list of tuples i.e. [(argument_test, + cfunction_string)] or [(argument_test, cfunction_formater)]. See below + for examples. + dereference : iterable, optional + An iterable of symbols that should be dereferenced in the printed code + expression. These would be values passed by address to the function. + For example, if ``dereference=[a]``, the resulting code would print + ``(*a)`` instead of ``a``. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + + Examples + ======== + + >>> from sympy import ccode, symbols, Rational, sin, ceiling, Abs, Function + >>> x, tau = symbols("x, tau") + >>> expr = (2*tau)**Rational(7, 2) + >>> ccode(expr) + '8*M_SQRT2*pow(tau, 7.0/2.0)' + >>> ccode(expr, math_macros={}) + '8*sqrt(2)*pow(tau, 7.0/2.0)' + >>> ccode(sin(x), assign_to="s") + 's = sin(x);' + >>> from sympy.codegen.ast import real, float80 + >>> ccode(expr, type_aliases={real: float80}) + '8*M_SQRT2l*powl(tau, 7.0L/2.0L)' + + Simple custom printing can be defined for certain types by passing a + dictionary of {"type" : "function"} to the ``user_functions`` kwarg. + Alternatively, the dictionary value can be a list of tuples i.e. + [(argument_test, cfunction_string)]. + + >>> custom_functions = { + ... "ceiling": "CEIL", + ... "Abs": [(lambda x: not x.is_integer, "fabs"), + ... (lambda x: x.is_integer, "ABS")], + ... "func": "f" + ... } + >>> func = Function('func') + >>> ccode(func(Abs(x) + ceiling(x)), standard='C89', user_functions=custom_functions) + 'f(fabs(x) + CEIL(x))' + + or if the C-function takes a subset of the original arguments: + + >>> ccode(2**x + 3**x, standard='C99', user_functions={'Pow': [ + ... (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e), + ... (lambda b, e: b != 2, 'pow')]}) + 'exp2(x) + pow(3, x)' + + ``Piecewise`` expressions are converted into conditionals. If an + ``assign_to`` variable is provided an if statement is created, otherwise + the ternary operator is used. Note that if the ``Piecewise`` lacks a + default term, represented by ``(expr, True)`` then an error will be thrown. + This is to prevent generating an expression that may not evaluate to + anything. + + >>> from sympy import Piecewise + >>> expr = Piecewise((x + 1, x > 0), (x, True)) + >>> print(ccode(expr, tau, standard='C89')) + if (x > 0) { + tau = x + 1; + } + else { + tau = x; + } + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> ccode(e.rhs, assign_to=e.lhs, contract=False, standard='C89') + 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);' + + Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions + must be provided to ``assign_to``. Note that any expression that can be + generated normally can also exist inside a Matrix: + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)]) + >>> A = MatrixSymbol('A', 3, 1) + >>> print(ccode(mat, A, standard='C89')) + A[0] = pow(x, 2); + if (x > 0) { + A[1] = x + 1; + } + else { + A[1] = x; + } + A[2] = sin(x); + """ + from sympy.printing.c import c_code_printers + return c_code_printers[standard.lower()](settings).doprint(expr, assign_to) + +def print_ccode(expr, **settings): + """Prints C representation of the given expression.""" + print(ccode(expr, **settings)) + +def fcode(expr, assign_to=None, **settings): + """Converts an expr to a string of fortran code + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of + line-wrapping, or for expressions that generate multi-line statements. + precision : integer, optional + DEPRECATED. Use type_mappings instead. The precision for numbers such + as pi [default=17]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, cfunction_string)]. See below + for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + source_format : optional + The source format can be either 'fixed' or 'free'. [default='fixed'] + standard : integer, optional + The Fortran standard to be followed. This is specified as an integer. + Acceptable standards are 66, 77, 90, 95, 2003, and 2008. Default is 77. + Note that currently the only distinction internally is between + standards before 95, and those 95 and after. This may change later as + more features are added. + name_mangling : bool, optional + If True, then the variables that would become identical in + case-insensitive Fortran are mangled by appending different number + of ``_`` at the end. If False, SymPy Will not interfere with naming of + variables. [default=True] + + Examples + ======== + + >>> from sympy import fcode, symbols, Rational, sin, ceiling, floor + >>> x, tau = symbols("x, tau") + >>> fcode((2*tau)**Rational(7, 2)) + ' 8*sqrt(2.0d0)*tau**(7.0d0/2.0d0)' + >>> fcode(sin(x), assign_to="s") + ' s = sin(x)' + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e. [(argument_test, + cfunction_string)]. + + >>> custom_functions = { + ... "ceiling": "CEIL", + ... "floor": [(lambda x: not x.is_integer, "FLOOR1"), + ... (lambda x: x.is_integer, "FLOOR2")] + ... } + >>> fcode(floor(x) + ceiling(x), user_functions=custom_functions) + ' CEIL(x) + FLOOR1(x)' + + ``Piecewise`` expressions are converted into conditionals. If an + ``assign_to`` variable is provided an if statement is created, otherwise + the ternary operator is used. Note that if the ``Piecewise`` lacks a + default term, represented by ``(expr, True)`` then an error will be thrown. + This is to prevent generating an expression that may not evaluate to + anything. + + >>> from sympy import Piecewise + >>> expr = Piecewise((x + 1, x > 0), (x, True)) + >>> print(fcode(expr, tau)) + if (x > 0) then + tau = x + 1 + else + tau = x + end if + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> fcode(e.rhs, assign_to=e.lhs, contract=False) + ' Dy(i) = (y(i + 1) - y(i))/(t(i + 1) - t(i))' + + Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions + must be provided to ``assign_to``. Note that any expression that can be + generated normally can also exist inside a Matrix: + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)]) + >>> A = MatrixSymbol('A', 3, 1) + >>> print(fcode(mat, A)) + A(1, 1) = x**2 + if (x > 0) then + A(2, 1) = x + 1 + else + A(2, 1) = x + end if + A(3, 1) = sin(x) + """ + from sympy.printing.fortran import FCodePrinter + return FCodePrinter(settings).doprint(expr, assign_to) + + +def print_fcode(expr, **settings): + """Prints the Fortran representation of the given expression. + + See fcode for the meaning of the optional arguments. + """ + print(fcode(expr, **settings)) + +def cxxcode(expr, assign_to=None, standard='c++11', **settings): + """ C++ equivalent of :func:`~.ccode`. """ + from sympy.printing.cxx import cxx_code_printers + return cxx_code_printers[standard.lower()](settings).doprint(expr, assign_to) + + +def rust_code(expr, assign_to=None, **settings): + """Converts an expr to a string of Rust code + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of + line-wrapping, or for expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=15]. + user_functions : dict, optional + A dictionary where the keys are string representations of either + ``FunctionClass`` or ``UndefinedFunction`` instances and the values + are their desired C string representations. Alternatively, the + dictionary value can be a list of tuples i.e. [(argument_test, + cfunction_string)]. See below for examples. + dereference : iterable, optional + An iterable of symbols that should be dereferenced in the printed code + expression. These would be values passed by address to the function. + For example, if ``dereference=[a]``, the resulting code would print + ``(*a)`` instead of ``a``. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + + Examples + ======== + + >>> from sympy import rust_code, symbols, Rational, sin, ceiling, Abs, Function + >>> x, tau = symbols("x, tau") + >>> rust_code((2*tau)**Rational(7, 2)) + '8.0*1.4142135623731*tau.powf(7_f64/2.0)' + >>> rust_code(sin(x), assign_to="s") + 's = x.sin();' + + Simple custom printing can be defined for certain types by passing a + dictionary of {"type" : "function"} to the ``user_functions`` kwarg. + Alternatively, the dictionary value can be a list of tuples i.e. + [(argument_test, cfunction_string)]. + + >>> custom_functions = { + ... "ceiling": "CEIL", + ... "Abs": [(lambda x: not x.is_integer, "fabs", 4), + ... (lambda x: x.is_integer, "ABS", 4)], + ... "func": "f" + ... } + >>> func = Function('func') + >>> rust_code(func(Abs(x) + ceiling(x)), user_functions=custom_functions) + '(fabs(x) + x.ceil()).f()' + + ``Piecewise`` expressions are converted into conditionals. If an + ``assign_to`` variable is provided an if statement is created, otherwise + the ternary operator is used. Note that if the ``Piecewise`` lacks a + default term, represented by ``(expr, True)`` then an error will be thrown. + This is to prevent generating an expression that may not evaluate to + anything. + + >>> from sympy import Piecewise + >>> expr = Piecewise((x + 1, x > 0), (x, True)) + >>> print(rust_code(expr, tau)) + tau = if (x > 0.0) { + x + 1 + } else { + x + }; + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> rust_code(e.rhs, assign_to=e.lhs, contract=False) + 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);' + + Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions + must be provided to ``assign_to``. Note that any expression that can be + generated normally can also exist inside a Matrix: + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)]) + >>> A = MatrixSymbol('A', 3, 1) + >>> print(rust_code(mat, A)) + A = [x.powi(2), if (x > 0.0) { + x + 1 + } else { + x + }, x.sin()]; + """ + from sympy.printing.rust import RustCodePrinter + printer = RustCodePrinter(settings) + expr = printer._rewrite_known_functions(expr) + if isinstance(expr, Expr): + for src_func, dst_func in printer.function_overrides.values(): + expr = expr.replace(src_func, dst_func) + return printer.doprint(expr, assign_to) + + +def print_rust_code(expr, **settings): + """Prints Rust representation of the given expression.""" + print(rust_code(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/conventions.py b/phivenv/Lib/site-packages/sympy/printing/conventions.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5545ae38511e0bb0366da5c9fb4e59156095c8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/conventions.py @@ -0,0 +1,88 @@ +""" +A few practical conventions common to all printers. +""" + +import re + +from collections.abc import Iterable +from sympy.core.function import Derivative + +_name_with_digits_p = re.compile(r'^([^\W\d_]+)(\d+)$', re.UNICODE) + + +def split_super_sub(text): + """Split a symbol name into a name, superscripts and subscripts + + The first part of the symbol name is considered to be its actual + 'name', followed by super- and subscripts. Each superscript is + preceded with a "^" character or by "__". Each subscript is preceded + by a "_" character. The three return values are the actual name, a + list with superscripts and a list with subscripts. + + Examples + ======== + + >>> from sympy.printing.conventions import split_super_sub + >>> split_super_sub('a_x^1') + ('a', ['1'], ['x']) + >>> split_super_sub('var_sub1__sup_sub2') + ('var', ['sup'], ['sub1', 'sub2']) + + """ + if not text: + return text, [], [] + + pos = 0 + name = None + supers = [] + subs = [] + while pos < len(text): + start = pos + 1 + if text[pos:pos + 2] == "__": + start += 1 + pos_hat = text.find("^", start) + if pos_hat < 0: + pos_hat = len(text) + pos_usc = text.find("_", start) + if pos_usc < 0: + pos_usc = len(text) + pos_next = min(pos_hat, pos_usc) + part = text[pos:pos_next] + pos = pos_next + if name is None: + name = part + elif part.startswith("^"): + supers.append(part[1:]) + elif part.startswith("__"): + supers.append(part[2:]) + elif part.startswith("_"): + subs.append(part[1:]) + else: + raise RuntimeError("This should never happen.") + + # Make a little exception when a name ends with digits, i.e. treat them + # as a subscript too. + m = _name_with_digits_p.match(name) + if m: + name, sub = m.groups() + subs.insert(0, sub) + + return name, supers, subs + + +def requires_partial(expr): + """Return whether a partial derivative symbol is required for printing + + This requires checking how many free variables there are, + filtering out the ones that are integers. Some expressions do not have + free variables. In that case, check its variable list explicitly to + get the context of the expression. + """ + + if isinstance(expr, Derivative): + return requires_partial(expr.expr) + + if not isinstance(expr.free_symbols, Iterable): + return len(set(expr.variables)) > 1 + + return sum(not s.is_integer for s in expr.free_symbols) > 1 diff --git a/phivenv/Lib/site-packages/sympy/printing/cxx.py b/phivenv/Lib/site-packages/sympy/printing/cxx.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed4f468b866e1b44aba6ae94a85c740dd324689 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/cxx.py @@ -0,0 +1,181 @@ +""" +C++ code printer +""" + +from itertools import chain +from sympy.codegen.ast import Type, none +from .codeprinter import requires +from .c import C89CodePrinter, C99CodePrinter + +# These are defined in the other file so we can avoid importing sympy.codegen +# from the top-level 'import sympy'. Export them here as well. +from sympy.printing.codeprinter import cxxcode # noqa:F401 + +# from https://en.cppreference.com/w/cpp/keyword +reserved = { + 'C++98': [ + 'and', 'and_eq', 'asm', 'auto', 'bitand', 'bitor', 'bool', 'break', + 'case', 'catch,', 'char', 'class', 'compl', 'const', 'const_cast', + 'continue', 'default', 'delete', 'do', 'double', 'dynamic_cast', + 'else', 'enum', 'explicit', 'export', 'extern', 'false', 'float', + 'for', 'friend', 'goto', 'if', 'inline', 'int', 'long', 'mutable', + 'namespace', 'new', 'not', 'not_eq', 'operator', 'or', 'or_eq', + 'private', 'protected', 'public', 'register', 'reinterpret_cast', + 'return', 'short', 'signed', 'sizeof', 'static', 'static_cast', + 'struct', 'switch', 'template', 'this', 'throw', 'true', 'try', + 'typedef', 'typeid', 'typename', 'union', 'unsigned', 'using', + 'virtual', 'void', 'volatile', 'wchar_t', 'while', 'xor', 'xor_eq' + ] +} + +reserved['C++11'] = reserved['C++98'][:] + [ + 'alignas', 'alignof', 'char16_t', 'char32_t', 'constexpr', 'decltype', + 'noexcept', 'nullptr', 'static_assert', 'thread_local' +] +reserved['C++17'] = reserved['C++11'][:] +reserved['C++17'].remove('register') +# TM TS: atomic_cancel, atomic_commit, atomic_noexcept, synchronized +# concepts TS: concept, requires +# module TS: import, module + + +_math_functions = { + 'C++98': { + 'Mod': 'fmod', + 'ceiling': 'ceil', + }, + 'C++11': { + 'gamma': 'tgamma', + }, + 'C++17': { + 'beta': 'beta', + 'Ei': 'expint', + 'zeta': 'riemann_zeta', + } +} + +# from https://en.cppreference.com/w/cpp/header/cmath +for k in ('Abs', 'exp', 'log', 'log10', 'sqrt', 'sin', 'cos', 'tan', # 'Pow' + 'asin', 'acos', 'atan', 'atan2', 'sinh', 'cosh', 'tanh', 'floor'): + _math_functions['C++98'][k] = k.lower() + + +for k in ('asinh', 'acosh', 'atanh', 'erf', 'erfc'): + _math_functions['C++11'][k] = k.lower() + + +def _attach_print_method(cls, sympy_name, func_name): + meth_name = '_print_%s' % sympy_name + if hasattr(cls, meth_name): + raise ValueError("Edit method (or subclass) instead of overwriting.") + def _print_method(self, expr): + return '{}{}({})'.format(self._ns, func_name, ', '.join(map(self._print, expr.args))) + _print_method.__doc__ = "Prints code for %s" % k + setattr(cls, meth_name, _print_method) + + +def _attach_print_methods(cls, cont): + for sympy_name, cxx_name in cont[cls.standard].items(): + _attach_print_method(cls, sympy_name, cxx_name) + + +class _CXXCodePrinterBase: + printmethod = "_cxxcode" + language = 'C++' + _ns = 'std::' # namespace + + def __init__(self, settings=None): + super().__init__(settings or {}) + + @requires(headers={'algorithm'}) + def _print_Max(self, expr): + from sympy.functions.elementary.miscellaneous import Max + if len(expr.args) == 1: + return self._print(expr.args[0]) + return "%smax(%s, %s)" % (self._ns, self._print(expr.args[0]), + self._print(Max(*expr.args[1:]))) + + @requires(headers={'algorithm'}) + def _print_Min(self, expr): + from sympy.functions.elementary.miscellaneous import Min + if len(expr.args) == 1: + return self._print(expr.args[0]) + return "%smin(%s, %s)" % (self._ns, self._print(expr.args[0]), + self._print(Min(*expr.args[1:]))) + + def _print_using(self, expr): + if expr.alias == none: + return 'using %s' % expr.type + else: + raise ValueError("C++98 does not support type aliases") + + def _print_Raise(self, rs): + arg, = rs.args + return 'throw %s' % self._print(arg) + + @requires(headers={'stdexcept'}) + def _print_RuntimeError_(self, re): + message, = re.args + return "%sruntime_error(%s)" % (self._ns, self._print(message)) + + +class CXX98CodePrinter(_CXXCodePrinterBase, C89CodePrinter): + standard = 'C++98' + reserved_words = set(reserved['C++98']) + + +# _attach_print_methods(CXX98CodePrinter, _math_functions) + + +class CXX11CodePrinter(_CXXCodePrinterBase, C99CodePrinter): + standard = 'C++11' + reserved_words = set(reserved['C++11']) + type_mappings = dict(chain( + CXX98CodePrinter.type_mappings.items(), + { + Type('int8'): ('int8_t', {'cstdint'}), + Type('int16'): ('int16_t', {'cstdint'}), + Type('int32'): ('int32_t', {'cstdint'}), + Type('int64'): ('int64_t', {'cstdint'}), + Type('uint8'): ('uint8_t', {'cstdint'}), + Type('uint16'): ('uint16_t', {'cstdint'}), + Type('uint32'): ('uint32_t', {'cstdint'}), + Type('uint64'): ('uint64_t', {'cstdint'}), + Type('complex64'): ('std::complex', {'complex'}), + Type('complex128'): ('std::complex', {'complex'}), + Type('bool'): ('bool', None), + }.items() + )) + + def _print_using(self, expr): + if expr.alias == none: + return super()._print_using(expr) + else: + return 'using %(alias)s = %(type)s' % expr.kwargs(apply=self._print) + +# _attach_print_methods(CXX11CodePrinter, _math_functions) + + +class CXX17CodePrinter(_CXXCodePrinterBase, C99CodePrinter): + standard = 'C++17' + reserved_words = set(reserved['C++17']) + + _kf = dict(C99CodePrinter._kf, **_math_functions['C++17']) + + def _print_beta(self, expr): + return self._print_math_func(expr) + + def _print_Ei(self, expr): + return self._print_math_func(expr) + + def _print_zeta(self, expr): + return self._print_math_func(expr) + + +# _attach_print_methods(CXX17CodePrinter, _math_functions) + +cxx_code_printers = { + 'c++98': CXX98CodePrinter, + 'c++11': CXX11CodePrinter, + 'c++17': CXX17CodePrinter +} diff --git a/phivenv/Lib/site-packages/sympy/printing/defaults.py b/phivenv/Lib/site-packages/sympy/printing/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..77a88d353fed4bd70496456ddd03cc429a4ba5e7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/defaults.py @@ -0,0 +1,5 @@ +from sympy.core._print_helpers import Printable + +# alias for compatibility +Printable.__module__ = __name__ +DefaultPrinting = Printable diff --git a/phivenv/Lib/site-packages/sympy/printing/dot.py b/phivenv/Lib/site-packages/sympy/printing/dot.py new file mode 100644 index 0000000000000000000000000000000000000000..c968fee389c16108b757b8fcad531ac6fa4ddb2f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/dot.py @@ -0,0 +1,294 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol +from sympy.core.numbers import Integer, Rational, Float +from sympy.printing.repr import srepr + +__all__ = ['dotprint'] + +default_styles = ( + (Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'}) +) + +slotClasses = (Symbol, Integer, Rational, Float) +def purestr(x, with_args=False): + """A string that follows ```obj = type(obj)(*obj.args)``` exactly. + + Parameters + ========== + + with_args : boolean, optional + If ``True``, there will be a second argument for the return + value, which is a tuple containing ``purestr`` applied to each + of the subnodes. + + If ``False``, there will not be a second argument for the + return. + + Default is ``False`` + + Examples + ======== + + >>> from sympy import Float, Symbol, MatrixSymbol + >>> from sympy import Integer # noqa: F401 + >>> from sympy.core.symbol import Str # noqa: F401 + >>> from sympy.printing.dot import purestr + + Applying ``purestr`` for basic symbolic object: + >>> code = purestr(Symbol('x')) + >>> code + "Symbol('x')" + >>> eval(code) == Symbol('x') + True + + For basic numeric object: + >>> purestr(Float(2)) + "Float('2.0', precision=53)" + + For matrix symbol: + >>> code = purestr(MatrixSymbol('x', 2, 2)) + >>> code + "MatrixSymbol(Str('x'), Integer(2), Integer(2))" + >>> eval(code) == MatrixSymbol('x', 2, 2) + True + + With ``with_args=True``: + >>> purestr(Float(2), with_args=True) + ("Float('2.0', precision=53)", ()) + >>> purestr(MatrixSymbol('x', 2, 2), with_args=True) + ("MatrixSymbol(Str('x'), Integer(2), Integer(2))", + ("Str('x')", 'Integer(2)', 'Integer(2)')) + """ + sargs = () + if not isinstance(x, Basic): + rv = str(x) + elif not x.args: + rv = srepr(x) + else: + args = x.args + sargs = tuple(map(purestr, args)) + rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs)) + if with_args: + rv = rv, sargs + return rv + + +def styleof(expr, styles=default_styles): + """ Merge style dictionaries in order + + Examples + ======== + + >>> from sympy import Symbol, Basic, Expr, S + >>> from sympy.printing.dot import styleof + >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + ... (Expr, {'color': 'black'})] + + >>> styleof(Basic(S(1)), styles) + {'color': 'blue', 'shape': 'ellipse'} + + >>> x = Symbol('x') + >>> styleof(x + 1, styles) # this is an Expr + {'color': 'black', 'shape': 'ellipse'} + """ + style = {} + for typ, sty in styles: + if isinstance(expr, typ): + style.update(sty) + return style + + +def attrprint(d, delimiter=', '): + """ Print a dictionary of attributes + + Examples + ======== + + >>> from sympy.printing.dot import attrprint + >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'})) + "color"="blue", "shape"="ellipse" + """ + return delimiter.join('"%s"="%s"'%item for item in sorted(d.items())) + + +def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True): + """ String defining a node + + Examples + ======== + + >>> from sympy.printing.dot import dotnode + >>> from sympy.abc import x + >>> print(dotnode(x)) + "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"]; + """ + style = styleof(expr, styles) + + if isinstance(expr, Basic) and not expr.is_Atom: + label = str(expr.__class__.__name__) + else: + label = labelfunc(expr) + style['label'] = label + expr_str = purestr(expr) + if repeat: + expr_str += '_%s' % str(pos) + return '"%s" [%s];' % (expr_str, attrprint(style)) + + +def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True): + """ List of strings for all expr->expr.arg pairs + + See the docstring of dotprint for explanations of the options. + + Examples + ======== + + >>> from sympy.printing.dot import dotedges + >>> from sympy.abc import x + >>> for e in dotedges(x+2): + ... print(e) + "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; + "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; + """ + if atom(expr): + return [] + else: + expr_str, arg_strs = purestr(expr, with_args=True) + if repeat: + expr_str += '_%s' % str(pos) + arg_strs = ['%s_%s' % (a, str(pos + (i,))) + for i, a in enumerate(arg_strs)] + return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs] + +template = \ +"""digraph{ + +# Graph style +%(graphstyle)s + +######### +# Nodes # +######### + +%(nodes)s + +######### +# Edges # +######### + +%(edges)s +}""" + +_graphstyle = {'rankdir': 'TD', 'ordering': 'out'} + +def dotprint(expr, + styles=default_styles, atom=lambda x: not isinstance(x, Basic), + maxdepth=None, repeat=True, labelfunc=str, **kwargs): + """DOT description of a SymPy expression tree + + Parameters + ========== + + styles : list of lists composed of (Class, mapping), optional + Styles for different classes. + + The default is + + .. code-block:: python + + ( + (Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'}) + ) + + atom : function, optional + Function used to determine if an arg is an atom. + + A good choice is ``lambda x: not x.args``. + + The default is ``lambda x: not isinstance(x, Basic)``. + + maxdepth : integer, optional + The maximum depth. + + The default is ``None``, meaning no limit. + + repeat : boolean, optional + Whether to use different nodes for common subexpressions. + + The default is ``True``. + + For example, for ``x + x*y`` with ``repeat=True``, it will have + two nodes for ``x``; with ``repeat=False``, it will have one + node. + + .. warning:: + Even if a node appears twice in the same object like ``x`` in + ``Pow(x, x)``, it will still only appear once. + Hence, with ``repeat=False``, the number of arrows out of an + object might not equal the number of args it has. + + labelfunc : function, optional + A function to create a label for a given leaf node. + + The default is ``str``. + + Another good option is ``srepr``. + + For example with ``str``, the leaf nodes of ``x + 1`` are labeled, + ``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')`` + and ``Integer(1)``. + + **kwargs : optional + Additional keyword arguments are included as styles for the graph. + + Examples + ======== + + >>> from sympy import dotprint + >>> from sympy.abc import x + >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE + digraph{ + + # Graph style + "ordering"="out" + "rankdir"="TD" + + ######### + # Nodes # + ######### + + "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; + "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"]; + "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"]; + + ######### + # Edges # + ######### + + "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; + "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; + } + + """ + # repeat works by adding a signature tuple to the end of each node for its + # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the + # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0]. + graphstyle = _graphstyle.copy() + graphstyle.update(kwargs) + + nodes = [] + edges = [] + def traverse(e, depth, pos=()): + nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat)) + if maxdepth and depth >= maxdepth: + return + edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat)) + [traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)] + traverse(expr, 0) + + return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'), + 'nodes': '\n'.join(nodes), + 'edges': '\n'.join(edges)} diff --git a/phivenv/Lib/site-packages/sympy/printing/fortran.py b/phivenv/Lib/site-packages/sympy/printing/fortran.py new file mode 100644 index 0000000000000000000000000000000000000000..7cea812d72ddcd3ccb56c7258f74e6fe3b8d5211 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/fortran.py @@ -0,0 +1,779 @@ +""" +Fortran code printer + +The FCodePrinter converts single SymPy expressions into single Fortran +expressions, using the functions defined in the Fortran 77 standard where +possible. Some useful pointers to Fortran can be found on wikipedia: + +https://en.wikipedia.org/wiki/Fortran + +Most of the code below is based on the "Professional Programmer\'s Guide to +Fortran77" by Clive G. Page: + +https://www.star.le.ac.uk/~cgp/prof77.html + +Fortran is a case-insensitive language. This might cause trouble because +SymPy is case sensitive. So, fcode adds underscores to variable names when +it is necessary to make them different for Fortran. +""" + +from __future__ import annotations +from typing import Any + +from collections import defaultdict +from itertools import chain +import string + +from sympy.codegen.ast import ( + Assignment, Declaration, Pointer, value_const, + float32, float64, float80, complex64, complex128, int8, int16, int32, + int64, intc, real, integer, bool_, complex_, none, stderr, stdout +) +from sympy.codegen.fnodes import ( + allocatable, isign, dsign, cmplx, merge, literal_dp, elemental, pure, + intent_in, intent_out, intent_inout +) +from sympy.core import S, Add, N, Float, Symbol +from sympy.core.function import Function +from sympy.core.numbers import equal_valued +from sympy.core.relational import Eq +from sympy.sets import Range +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from sympy.printing.printer import printer_context + +# These are defined in the other file so we can avoid importing sympy.codegen +# from the top-level 'import sympy'. Export them here as well. +from sympy.printing.codeprinter import fcode, print_fcode # noqa:F401 + +known_functions = { + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "log": "log", + "exp": "exp", + "erf": "erf", + "Abs": "abs", + "conjugate": "conjg", + "Max": "max", + "Min": "min", +} + + +class FCodePrinter(CodePrinter): + """A printer to convert SymPy expressions to strings of Fortran code""" + printmethod = "_fcode" + language = "Fortran" + + type_aliases = { + integer: int32, + real: float64, + complex_: complex128, + } + + type_mappings = { + intc: 'integer(c_int)', + float32: 'real*4', # real(kind(0.e0)) + float64: 'real*8', # real(kind(0.d0)) + float80: 'real*10', # real(kind(????)) + complex64: 'complex*8', + complex128: 'complex*16', + int8: 'integer*1', + int16: 'integer*2', + int32: 'integer*4', + int64: 'integer*8', + bool_: 'logical' + } + + type_modules = { + intc: {'iso_c_binding': 'c_int'} + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'source_format': 'fixed', + 'contract': True, + 'standard': 77, + 'name_mangling': True, + }) + + _operators = { + 'and': '.and.', + 'or': '.or.', + 'xor': '.neqv.', + 'equivalent': '.eqv.', + 'not': '.not. ', + } + + _relationals = { + '!=': '/=', + } + + def __init__(self, settings=None): + if not settings: + settings = {} + self.mangled_symbols = {} # Dict showing mapping of all words + self.used_name = [] + self.type_aliases = dict(chain(self.type_aliases.items(), + settings.pop('type_aliases', {}).items())) + self.type_mappings = dict(chain(self.type_mappings.items(), + settings.pop('type_mappings', {}).items())) + super().__init__(settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + # leading columns depend on fixed or free format + standards = {66, 77, 90, 95, 2003, 2008} + if self._settings['standard'] not in standards: + raise ValueError("Unknown Fortran standard: %s" % self._settings[ + 'standard']) + self.module_uses = defaultdict(set) # e.g.: use iso_c_binding, only: c_int + + @property + def _lead(self): + if self._settings['source_format'] == 'fixed': + return {'code': " ", 'cont': " @ ", 'comment': "C "} + elif self._settings['source_format'] == 'free': + return {'code': "", 'cont': " ", 'comment': "! "} + else: + raise ValueError("Unknown source format: %s" % self._settings['source_format']) + + def _print_Symbol(self, expr): + if self._settings['name_mangling'] == True: + if expr not in self.mangled_symbols: + name = expr.name + while name.lower() in self.used_name: + name += '_' + self.used_name.append(name.lower()) + if name == expr.name: + self.mangled_symbols[expr] = expr + else: + self.mangled_symbols[expr] = Symbol(name) + + expr = expr.xreplace(self.mangled_symbols) + + name = super()._print_Symbol(expr) + return name + + def _rate_index_position(self, p): + return -p*5 + + def _get_statement(self, codestring): + return codestring + + def _get_comment(self, text): + return "! {}".format(text) + + def _declare_number_const(self, name, value): + return "parameter ({} = {})".format(name, self._print(value)) + + def _print_NumberSymbol(self, expr): + # A Number symbol that is not implemented here or with _printmethod + # is registered and evaluated + self._number_symbols.add((expr, Float(expr.evalf(self._settings['precision'])))) + return str(expr) + + def _format_code(self, lines): + return self._wrap_fortran(self.indent_code(lines)) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for j in range(cols) for i in range(rows)) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + for i in indices: + # fortran arrays start at 1 and end at dimension + var, start, stop = map(self._print, + [i.label, i.lower + 1, i.upper + 1]) + open_lines.append("do %s = %s, %s" % (var, start, stop)) + close_lines.append("end do") + return open_lines, close_lines + + def _print_sign(self, expr): + from sympy.functions.elementary.complexes import Abs + arg, = expr.args + if arg.is_integer: + new_expr = merge(0, isign(1, arg), Eq(arg, 0)) + elif (arg.is_complex or arg.is_infinite): + new_expr = merge(cmplx(literal_dp(0), literal_dp(0)), arg/Abs(arg), Eq(Abs(arg), literal_dp(0))) + else: + new_expr = merge(literal_dp(0), dsign(literal_dp(1), arg), Eq(arg, literal_dp(0))) + return self._print(new_expr) + + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if expr.has(Assignment): + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) then" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else") + else: + lines.append("else if (%s) then" % self._print(c)) + lines.append(self._print(e)) + lines.append("end if") + return "\n".join(lines) + elif self._settings["standard"] >= 95: + # Only supported in F95 and newer: + # The piecewise was used in an expression, need to do inline + # operators. This has the downside that inline operators will + # not work for statements that span multiple lines (Matrix or + # Indexed expressions). + pattern = "merge({T}, {F}, {COND})" + code = self._print(expr.args[-1].expr) + terms = list(expr.args[:-1]) + while terms: + e, c = terms.pop() + expr = self._print(e) + cond = self._print(c) + code = pattern.format(T=expr, F=code, COND=cond) + return code + else: + # `merge` is not supported prior to F95 + raise NotImplementedError("Using Piecewise as an expression using " + "inline operators is not supported in " + "standards earlier than Fortran95.") + + def _print_MatrixElement(self, expr): + return "{}({}, {})".format(self.parenthesize(expr.parent, + PRECEDENCE["Atom"], strict=True), expr.i + 1, expr.j + 1) + + def _print_Add(self, expr): + # purpose: print complex numbers nicely in Fortran. + # collect the purely real and purely imaginary parts: + pure_real = [] + pure_imaginary = [] + mixed = [] + for arg in expr.args: + if arg.is_number and arg.is_real: + pure_real.append(arg) + elif arg.is_number and arg.is_imaginary: + pure_imaginary.append(arg) + else: + mixed.append(arg) + if pure_imaginary: + if mixed: + PREC = precedence(expr) + term = Add(*mixed) + t = self._print(term) + if t.startswith('-'): + sign = "-" + t = t[1:] + else: + sign = "+" + if precedence(term) < PREC: + t = "(%s)" % t + + return "cmplx(%s,%s) %s %s" % ( + self._print(Add(*pure_real)), + self._print(-S.ImaginaryUnit*Add(*pure_imaginary)), + sign, t, + ) + else: + return "cmplx(%s,%s)" % ( + self._print(Add(*pure_real)), + self._print(-S.ImaginaryUnit*Add(*pure_imaginary)), + ) + else: + return CodePrinter._print_Add(self, expr) + + def _print_Function(self, expr): + # All constant function args are evaluated as floats + prec = self._settings['precision'] + args = [N(a, prec) for a in expr.args] + eval_expr = expr.func(*args) + if not isinstance(eval_expr, Function): + return self._print(eval_expr) + else: + return CodePrinter._print_Function(self, expr.func(*args)) + + def _print_Mod(self, expr): + # NOTE : Fortran has the functions mod() and modulo(). modulo() behaves + # the same wrt to the sign of the arguments as Python and SymPy's + # modulus computations (% and Mod()) but is not available in Fortran 66 + # or Fortran 77, thus we raise an error. + if self._settings['standard'] in [66, 77]: + msg = ("Python % operator and SymPy's Mod() function are not " + "supported by Fortran 66 or 77 standards.") + raise NotImplementedError(msg) + else: + x, y = expr.args + return " modulo({}, {})".format(self._print(x), self._print(y)) + + def _print_ImaginaryUnit(self, expr): + # purpose: print complex numbers nicely in Fortran. + return "cmplx(0,1)" + + def _print_int(self, expr): + return str(expr) + + def _print_Mul(self, expr): + # purpose: print complex numbers nicely in Fortran. + if expr.is_number and expr.is_imaginary: + return "cmplx(0,%s)" % ( + self._print(-S.ImaginaryUnit*expr) + ) + else: + return CodePrinter._print_Mul(self, expr) + + def _print_Pow(self, expr): + PREC = precedence(expr) + if equal_valued(expr.exp, -1): + return '%s/%s' % ( + self._print(literal_dp(1)), + self.parenthesize(expr.base, PREC) + ) + elif equal_valued(expr.exp, 0.5): + if expr.base.is_integer: + # Fortran intrinsic sqrt() does not accept integer argument + if expr.base.is_Number: + return 'sqrt(%s.0d0)' % self._print(expr.base) + else: + return 'sqrt(dble(%s))' % self._print(expr.base) + else: + return 'sqrt(%s)' % self._print(expr.base) + else: + return CodePrinter._print_Pow(self, expr) + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + return "%d.0d0/%d.0d0" % (p, q) + + def _print_Float(self, expr): + printed = CodePrinter._print_Float(self, expr) + e = printed.find('e') + if e > -1: + return "%sd%s" % (printed[:e], printed[e + 1:]) + return "%sd0" % printed + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + op = op if op not in self._relationals else self._relationals[op] + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds)) + + def _print_AugmentedAssignment(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + return self._get_statement("{0} = {0} {1} {2}".format( + self._print(lhs_code), self._print(expr.binop), self._print(rhs_code))) + + def _print_sum_(self, sm): + params = self._print(sm.array) + if sm.dim != None: # Must use '!= None', cannot use 'is not None' + params += ', ' + self._print(sm.dim) + if sm.mask != None: # Must use '!= None', cannot use 'is not None' + params += ', mask=' + self._print(sm.mask) + return '%s(%s)' % (sm.__class__.__name__.rstrip('_'), params) + + def _print_product_(self, prod): + return self._print_sum_(prod) + + def _print_Do(self, do): + excl = ['concurrent'] + if do.step == 1: + excl.append('step') + step = '' + else: + step = ', {step}' + + return ( + 'do {concurrent}{counter} = {first}, {last}'+step+'\n' + '{body}\n' + 'end do\n' + ).format( + concurrent='concurrent ' if do.concurrent else '', + **do.kwargs(apply=lambda arg: self._print(arg), exclude=excl) + ) + + def _print_ImpliedDoLoop(self, idl): + step = '' if idl.step == 1 else ', {step}' + return ('({expr}, {counter} = {first}, {last}'+step+')').format( + **idl.kwargs(apply=lambda arg: self._print(arg)) + ) + + def _print_For(self, expr): + target = self._print(expr.target) + if isinstance(expr.iterable, Range): + start, stop, step = expr.iterable.args + else: + raise NotImplementedError("Only iterable currently supported is Range") + body = self._print(expr.body) + return ('do {target} = {start}, {stop}, {step}\n' + '{body}\n' + 'end do').format(target=target, start=start, stop=stop - 1, + step=step, body=body) + + def _print_Type(self, type_): + type_ = self.type_aliases.get(type_, type_) + type_str = self.type_mappings.get(type_, type_.name) + module_uses = self.type_modules.get(type_) + if module_uses: + for k, v in module_uses: + self.module_uses[k].add(v) + return type_str + + def _print_Element(self, elem): + return '{symbol}({idxs})'.format( + symbol=self._print(elem.symbol), + idxs=', '.join((self._print(arg) for arg in elem.indices)) + ) + + def _print_Extent(self, ext): + return str(ext) + + def _print_Declaration(self, expr): + var = expr.variable + val = var.value + dim = var.attr_params('dimension') + intents = [intent in var.attrs for intent in (intent_in, intent_out, intent_inout)] + if intents.count(True) == 0: + intent = '' + elif intents.count(True) == 1: + intent = ', intent(%s)' % ['in', 'out', 'inout'][intents.index(True)] + else: + raise ValueError("Multiple intents specified for %s" % self) + + if isinstance(var, Pointer): + raise NotImplementedError("Pointers are not available by default in Fortran.") + if self._settings["standard"] >= 90: + result = '{t}{vc}{dim}{intent}{alloc} :: {s}'.format( + t=self._print(var.type), + vc=', parameter' if value_const in var.attrs else '', + dim=', dimension(%s)' % ', '.join((self._print(arg) for arg in dim)) if dim else '', + intent=intent, + alloc=', allocatable' if allocatable in var.attrs else '', + s=self._print(var.symbol) + ) + if val != None: # Must be "!= None", cannot be "is not None" + result += ' = %s' % self._print(val) + else: + if value_const in var.attrs or val: + raise NotImplementedError("F77 init./parameter statem. req. multiple lines.") + result = ' '.join((self._print(arg) for arg in [var.type, var.symbol])) + + return result + + + def _print_Infinity(self, expr): + return '(huge(%s) + 1)' % self._print(literal_dp(0)) + + def _print_While(self, expr): + return 'do while ({condition})\n{body}\nend do'.format(**expr.kwargs( + apply=lambda arg: self._print(arg))) + + def _print_BooleanTrue(self, expr): + return '.true.' + + def _print_BooleanFalse(self, expr): + return '.false.' + + def _pad_leading_columns(self, lines): + result = [] + for line in lines: + if line.startswith('!'): + result.append(self._lead['comment'] + line[1:].lstrip()) + else: + result.append(self._lead['code'] + line) + return result + + def _wrap_fortran(self, lines): + """Wrap long Fortran lines + + Argument: + lines -- a list of lines (without \\n character) + + A comment line is split at white space. Code lines are split with a more + complex rule to give nice results. + """ + # routine to find split point in a code line + my_alnum = set("_+-." + string.digits + string.ascii_letters) + my_white = set(" \t()") + + def split_pos_code(line, endpos): + if len(line) <= endpos: + return len(line) + pos = endpos + split = lambda pos: \ + (line[pos] in my_alnum and line[pos - 1] not in my_alnum) or \ + (line[pos] not in my_alnum and line[pos - 1] in my_alnum) or \ + (line[pos] in my_white and line[pos - 1] not in my_white) or \ + (line[pos] not in my_white and line[pos - 1] in my_white) + while not split(pos): + pos -= 1 + if pos == 0: + return endpos + return pos + # split line by line and add the split lines to result + result = [] + if self._settings['source_format'] == 'free': + trailing = ' &' + else: + trailing = '' + for line in lines: + if line.startswith(self._lead['comment']): + # comment line + if len(line) > 72: + pos = line.rfind(" ", 6, 72) + if pos == -1: + pos = 72 + hunk = line[:pos] + line = line[pos:].lstrip() + result.append(hunk) + while line: + pos = line.rfind(" ", 0, 66) + if pos == -1 or len(line) < 66: + pos = 66 + hunk = line[:pos] + line = line[pos:].lstrip() + result.append("%s%s" % (self._lead['comment'], hunk)) + else: + result.append(line) + elif line.startswith(self._lead['code']): + # code line + pos = split_pos_code(line, 72) + hunk = line[:pos].rstrip() + line = line[pos:].lstrip() + if line: + hunk += trailing + result.append(hunk) + while line: + pos = split_pos_code(line, 65) + hunk = line[:pos].rstrip() + line = line[pos:].lstrip() + if line: + hunk += trailing + result.append("%s%s" % (self._lead['cont'], hunk)) + else: + result.append(line) + return result + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + free = self._settings['source_format'] == 'free' + code = [ line.lstrip(' \t') for line in code ] + + inc_keyword = ('do ', 'if(', 'if ', 'do\n', 'else', 'program', 'interface') + dec_keyword = ('end do', 'enddo', 'end if', 'endif', 'else', 'end program', 'end interface') + + increase = [ int(any(map(line.startswith, inc_keyword))) + for line in code ] + decrease = [ int(any(map(line.startswith, dec_keyword))) + for line in code ] + continuation = [ int(any(map(line.endswith, ['&', '&\n']))) + for line in code ] + + level = 0 + cont_padding = 0 + tabwidth = 3 + new_code = [] + for i, line in enumerate(code): + if line in ('', '\n'): + new_code.append(line) + continue + level -= decrease[i] + + if free: + padding = " "*(level*tabwidth + cont_padding) + else: + padding = " "*level*tabwidth + + line = "%s%s" % (padding, line) + if not free: + line = self._pad_leading_columns([line])[0] + + new_code.append(line) + + if continuation[i]: + cont_padding = 2*tabwidth + else: + cont_padding = 0 + level += increase[i] + + if not free: + return self._wrap_fortran(new_code) + return new_code + + def _print_GoTo(self, goto): + if goto.expr: # computed goto + return "go to ({labels}), {expr}".format( + labels=', '.join((self._print(arg) for arg in goto.labels)), + expr=self._print(goto.expr) + ) + else: + lbl, = goto.labels + return "go to %s" % self._print(lbl) + + def _print_Program(self, prog): + return ( + "program {name}\n" + "{body}\n" + "end program\n" + ).format(**prog.kwargs(apply=lambda arg: self._print(arg))) + + def _print_Module(self, mod): + return ( + "module {name}\n" + "{declarations}\n" + "\ncontains\n\n" + "{definitions}\n" + "end module\n" + ).format(**mod.kwargs(apply=lambda arg: self._print(arg))) + + def _print_Stream(self, strm): + if strm.name == 'stdout' and self._settings["standard"] >= 2003: + self.module_uses['iso_c_binding'].add('stdint=>input_unit') + return 'input_unit' + elif strm.name == 'stderr' and self._settings["standard"] >= 2003: + self.module_uses['iso_c_binding'].add('stdint=>error_unit') + return 'error_unit' + else: + if strm.name == 'stdout': + return '*' + else: + return strm.name + + def _print_Print(self, ps): + if ps.format_string == none: # Must be '!= None', cannot be 'is not None' + template = "print {fmt}, {iolist}" + fmt = '*' + else: + template = 'write(%(out)s, fmt="{fmt}", advance="no"), {iolist}' % { + 'out': {stderr: '0', stdout: '6'}.get(ps.file, '*') + } + fmt = self._print(ps.format_string) + return template.format(fmt=fmt, iolist=', '.join( + (self._print(arg) for arg in ps.print_args))) + + def _print_Return(self, rs): + arg, = rs.args + return "{result_name} = {arg}".format( + result_name=self._context.get('result_name', 'sympy_result'), + arg=self._print(arg) + ) + + def _print_FortranReturn(self, frs): + arg, = frs.args + if arg: + return 'return %s' % self._print(arg) + else: + return 'return' + + def _head(self, entity, fp, **kwargs): + bind_C_params = fp.attr_params('bind_C') + if bind_C_params is None: + bind = '' + else: + bind = ' bind(C, name="%s")' % bind_C_params[0] if bind_C_params else ' bind(C)' + result_name = self._settings.get('result_name', None) + return ( + "{entity}{name}({arg_names}){result}{bind}\n" + "{arg_declarations}" + ).format( + entity=entity, + name=self._print(fp.name), + arg_names=', '.join([self._print(arg.symbol) for arg in fp.parameters]), + result=(' result(%s)' % result_name) if result_name else '', + bind=bind, + arg_declarations='\n'.join((self._print(Declaration(arg)) for arg in fp.parameters)) + ) + + def _print_FunctionPrototype(self, fp): + entity = "{} function ".format(self._print(fp.return_type)) + return ( + "interface\n" + "{function_head}\n" + "end function\n" + "end interface" + ).format(function_head=self._head(entity, fp)) + + def _print_FunctionDefinition(self, fd): + if elemental in fd.attrs: + prefix = 'elemental ' + elif pure in fd.attrs: + prefix = 'pure ' + else: + prefix = '' + + entity = "{} function ".format(self._print(fd.return_type)) + with printer_context(self, result_name=fd.name): + return ( + "{prefix}{function_head}\n" + "{body}\n" + "end function\n" + ).format( + prefix=prefix, + function_head=self._head(entity, fd), + body=self._print(fd.body) + ) + + def _print_Subroutine(self, sub): + return ( + '{subroutine_head}\n' + '{body}\n' + 'end subroutine\n' + ).format( + subroutine_head=self._head('subroutine ', sub), + body=self._print(sub.body) + ) + + def _print_SubroutineCall(self, scall): + return 'call {name}({args})'.format( + name=self._print(scall.name), + args=', '.join((self._print(arg) for arg in scall.subroutine_args)) + ) + + def _print_use_rename(self, rnm): + return "%s => %s" % tuple((self._print(arg) for arg in rnm.args)) + + def _print_use(self, use): + result = 'use %s' % self._print(use.namespace) + if use.rename != None: # Must be '!= None', cannot be 'is not None' + result += ', ' + ', '.join([self._print(rnm) for rnm in use.rename]) + if use.only != None: # Must be '!= None', cannot be 'is not None' + result += ', only: ' + ', '.join([self._print(nly) for nly in use.only]) + return result + + def _print_BreakToken(self, _): + return 'exit' + + def _print_ContinueToken(self, _): + return 'cycle' + + def _print_ArrayConstructor(self, ac): + fmtstr = "[%s]" if self._settings["standard"] >= 2003 else '(/%s/)' + return fmtstr % ', '.join((self._print(arg) for arg in ac.elements)) + + def _print_ArrayElement(self, elem): + return '{symbol}({idxs})'.format( + symbol=self._print(elem.name), + idxs=', '.join((self._print(arg) for arg in elem.indices)) + ) diff --git a/phivenv/Lib/site-packages/sympy/printing/glsl.py b/phivenv/Lib/site-packages/sympy/printing/glsl.py new file mode 100644 index 0000000000000000000000000000000000000000..f98df8d46abec5f891b6bc9836a13ca69934275c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/glsl.py @@ -0,0 +1,548 @@ +from __future__ import annotations + +from sympy.core import Basic, S +from sympy.core.function import Lambda +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence +from functools import reduce + +known_functions = { + 'Abs': 'abs', + 'sin': 'sin', + 'cos': 'cos', + 'tan': 'tan', + 'acos': 'acos', + 'asin': 'asin', + 'atan': 'atan', + 'atan2': 'atan', + 'ceiling': 'ceil', + 'floor': 'floor', + 'sign': 'sign', + 'exp': 'exp', + 'log': 'log', + 'add': 'add', + 'sub': 'sub', + 'mul': 'mul', + 'pow': 'pow' +} + +class GLSLPrinter(CodePrinter): + """ + Rudimentary, generic GLSL printing tools. + + Additional settings: + 'use_operators': Boolean (should the printer use operators for +,-,*, or functions?) + """ + _not_supported: set[Basic] = set() + printmethod = "_glsl" + language = "GLSL" + + _default_settings = dict(CodePrinter._default_settings, **{ + 'use_operators': True, + 'zero': 0, + 'mat_nested': False, + 'mat_separator': ',\n', + 'mat_transpose': False, + 'array_type': 'float', + 'glsl_types': True, + + 'precision': 9, + 'user_functions': {}, + 'contract': True, + }) + + def __init__(self, settings={}): + CodePrinter.__init__(self, settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + return "%s;" % codestring + + def _get_comment(self, text): + return "// {}".format(text) + + def _declare_number_const(self, name, value): + return "float {} = {};".format(name, value) + + def _format_code(self, lines): + return self.indent_code(lines) + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [line.lstrip(' \t') for line in code] + + increase = [int(any(map(line.endswith, inc_token))) for line in code] + decrease = [int(any(map(line.startswith, dec_token))) for line in code] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + def _print_MatrixBase(self, mat): + mat_separator = self._settings['mat_separator'] + mat_transpose = self._settings['mat_transpose'] + column_vector = (mat.rows == 1) if mat_transpose else (mat.cols == 1) + A = mat.transpose() if mat_transpose != column_vector else mat + + glsl_types = self._settings['glsl_types'] + array_type = self._settings['array_type'] + array_size = A.cols*A.rows + array_constructor = "{}[{}]".format(array_type, array_size) + + if A.cols == 1: + return self._print(A[0]) + if A.rows <= 4 and A.cols <= 4 and glsl_types: + if A.rows == 1: + return "vec{}{}".format( + A.cols, A.table(self,rowstart='(',rowend=')') + ) + elif A.rows == A.cols: + return "mat{}({})".format( + A.rows, A.table(self,rowsep=', ', + rowstart='',rowend='') + ) + else: + return "mat{}x{}({})".format( + A.cols, A.rows, + A.table(self,rowsep=', ', + rowstart='',rowend='') + ) + elif S.One in A.shape: + return "{}({})".format( + array_constructor, + A.table(self,rowsep=mat_separator,rowstart='',rowend='') + ) + elif not self._settings['mat_nested']: + return "{}(\n{}\n) /* a {}x{} matrix */".format( + array_constructor, + A.table(self,rowsep=mat_separator,rowstart='',rowend=''), + A.rows, A.cols + ) + elif self._settings['mat_nested']: + return "{}[{}][{}](\n{}\n)".format( + array_type, A.rows, A.cols, + A.table(self,rowsep=mat_separator,rowstart='float[](',rowend=')') + ) + + def _print_SparseRepMatrix(self, mat): + # do not allow sparse matrices to be made dense + return self._print_not_supported(mat) + + def _traverse_matrix_indices(self, mat): + mat_transpose = self._settings['mat_transpose'] + if mat_transpose: + rows,cols = mat.shape + else: + cols,rows = mat.shape + return ((i, j) for i in range(cols) for j in range(rows)) + + def _print_MatrixElement(self, expr): + # print('begin _print_MatrixElement') + nest = self._settings['mat_nested'] + glsl_types = self._settings['glsl_types'] + mat_transpose = self._settings['mat_transpose'] + if mat_transpose: + cols,rows = expr.parent.shape + i,j = expr.j,expr.i + else: + rows,cols = expr.parent.shape + i,j = expr.i,expr.j + pnt = self._print(expr.parent) + if glsl_types and ((rows <= 4 and cols <=4) or nest): + return "{}[{}][{}]".format(pnt, i, j) + else: + return "{}[{}]".format(pnt, i + j*rows) + + def _print_list(self, expr): + l = ', '.join(self._print(item) for item in expr) + glsl_types = self._settings['glsl_types'] + array_type = self._settings['array_type'] + array_size = len(expr) + array_constructor = '{}[{}]'.format(array_type, array_size) + + if array_size <= 4 and glsl_types: + return 'vec{}({})'.format(array_size, l) + else: + return '{}({})'.format(array_constructor, l) + + _print_tuple = _print_list + _print_Tuple = _print_list + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + loopstart = "for (int %(varble)s=%(start)s; %(varble)s<%(end)s; %(varble)s++){" + for i in indices: + # GLSL arrays start at 0 and end at dimension-1 + open_lines.append(loopstart % { + 'varble': self._print(i.label), + 'start': self._print(i.lower), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + def _print_Function_with_args(self, func, func_args): + if func in self.known_functions: + cond_func = self.known_functions[func] + func = None + if isinstance(cond_func, str): + func = cond_func + else: + for cond, func in cond_func: + if cond(func_args): + break + if func is not None: + try: + return func(*[self.parenthesize(item, 0) for item in func_args]) + except TypeError: + return '{}({})'.format(func, self.stringify(func_args, ", ")) + elif isinstance(func, Lambda): + # inlined function + return self._print(func(*func_args)) + else: + return self._print_not_supported(func) + + def _print_Piecewise(self, expr): + from sympy.codegen.ast import Assignment + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if expr.has(Assignment): + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) {" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else {") + else: + lines.append("else if (%s) {" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + lines.append("}") + return "\n".join(lines) + else: + # The piecewise was used in an expression, need to do inline + # operators. This has the downside that inline operators will + # not work for statements that span multiple lines (Matrix or + # Indexed expressions). + ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c), + self._print(e)) + for e, c in expr.args[:-1]] + last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr) + return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)]) + + def _print_Indexed(self, expr): + # calculate index for 1d array + dims = expr.shape + elem = S.Zero + offset = S.One + for i in reversed(range(expr.rank)): + elem += expr.indices[i]*offset + offset *= dims[i] + return "{}[{}]".format( + self._print(expr.base.label), + self._print(elem) + ) + + def _print_Pow(self, expr): + PREC = precedence(expr) + if equal_valued(expr.exp, -1): + return '1.0/%s' % (self.parenthesize(expr.base, PREC)) + elif equal_valued(expr.exp, 0.5): + return 'sqrt(%s)' % self._print(expr.base) + else: + try: + e = self._print(float(expr.exp)) + except TypeError: + e = self._print(expr.exp) + return self._print_Function_with_args('pow', ( + self._print(expr.base), + e + )) + + def _print_int(self, expr): + return str(float(expr)) + + def _print_Rational(self, expr): + return "{}.0/{}.0".format(expr.p, expr.q) + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Add(self, expr, order=None): + if self._settings['use_operators']: + return CodePrinter._print_Add(self, expr, order=order) + + terms = expr.as_ordered_terms() + + def partition(p,l): + return reduce(lambda x, y: (x[0]+[y], x[1]) if p(y) else (x[0], x[1]+[y]), l, ([], [])) + def add(a,b): + return self._print_Function_with_args('add', (a, b)) + # return self.known_functions['add']+'(%s, %s)' % (a,b) + neg, pos = partition(lambda arg: arg.could_extract_minus_sign(), terms) + if pos: + s = pos = reduce(lambda a,b: add(a,b), (self._print(t) for t in pos)) + else: + s = pos = self._print(self._settings['zero']) + + if neg: + # sum the absolute values of the negative terms + neg = reduce(lambda a,b: add(a,b), (self._print(-n) for n in neg)) + # then subtract them from the positive terms + s = self._print_Function_with_args('sub', (pos,neg)) + # s = self.known_functions['sub']+'(%s, %s)' % (pos,neg) + return s + + def _print_Mul(self, expr, **kwargs): + if self._settings['use_operators']: + return CodePrinter._print_Mul(self, expr, **kwargs) + terms = expr.as_ordered_factors() + def mul(a,b): + # return self.known_functions['mul']+'(%s, %s)' % (a,b) + return self._print_Function_with_args('mul', (a,b)) + + s = reduce(lambda a,b: mul(a,b), (self._print(t) for t in terms)) + return s + +def glsl_code(expr,assign_to=None,**settings): + """Converts an expr to a string of GLSL code + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used for naming the variable or variables + to which the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol`` or ``Indexed`` type object. In cases where ``expr`` + would be printed as an array, a list of string or ``Symbol`` objects + can also be passed. + + This is helpful in case of line-wrapping, or for expressions that + generate multi-line statements. It can also be used to spread an array-like + expression into multiple assignments. + use_operators: bool, optional + If set to False, then *,/,+,- operators will be replaced with functions + mul, add, and sub, which must be implemented by the user, e.g. for + implementing non-standard rings or emulated quad/octal precision. + [default=True] + glsl_types: bool, optional + Set this argument to ``False`` in order to avoid using the ``vec`` and ``mat`` + types. The printer will instead use arrays (or nested arrays). + [default=True] + mat_nested: bool, optional + GLSL version 4.3 and above support nested arrays (arrays of arrays). Set this to ``True`` + to render matrices as nested arrays. + [default=False] + mat_separator: str, optional + By default, matrices are rendered with newlines using this separator, + making them easier to read, but less compact. By removing the newline + this option can be used to make them more vertically compact. + [default=',\n'] + mat_transpose: bool, optional + GLSL's matrix multiplication implementation assumes column-major indexing. + By default, this printer ignores that convention. Setting this option to + ``True`` transposes all matrix output. + [default=False] + array_type: str, optional + The GLSL array constructor type. + [default='float'] + precision : integer, optional + The precision for numbers such as pi [default=15]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, js_function_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + + Examples + ======== + + >>> from sympy import glsl_code, symbols, Rational, sin, ceiling, Abs + >>> x, tau = symbols("x, tau") + >>> glsl_code((2*tau)**Rational(7, 2)) + '8*sqrt(2)*pow(tau, 3.5)' + >>> glsl_code(sin(x), assign_to="float y") + 'float y = sin(x);' + + Various GLSL types are supported: + >>> from sympy import Matrix, glsl_code + >>> glsl_code(Matrix([1,2,3])) + 'vec3(1, 2, 3)' + + >>> glsl_code(Matrix([[1, 2],[3, 4]])) + 'mat2(1, 2, 3, 4)' + + Pass ``mat_transpose = True`` to switch to column-major indexing: + >>> glsl_code(Matrix([[1, 2],[3, 4]]), mat_transpose = True) + 'mat2(1, 3, 2, 4)' + + By default, larger matrices get collapsed into float arrays: + >>> print(glsl_code( Matrix([[1,2,3,4,5],[6,7,8,9,10]]) )) + float[10]( + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10 + ) /* a 2x5 matrix */ + + The type of array constructor used to print GLSL arrays can be controlled + via the ``array_type`` parameter: + >>> glsl_code(Matrix([1,2,3,4,5]), array_type='int') + 'int[5](1, 2, 3, 4, 5)' + + Passing a list of strings or ``symbols`` to the ``assign_to`` parameter will yield + a multi-line assignment for each item in an array-like expression: + >>> x_struct_members = symbols('x.a x.b x.c x.d') + >>> print(glsl_code(Matrix([1,2,3,4]), assign_to=x_struct_members)) + x.a = 1; + x.b = 2; + x.c = 3; + x.d = 4; + + This could be useful in cases where it's desirable to modify members of a + GLSL ``Struct``. It could also be used to spread items from an array-like + expression into various miscellaneous assignments: + >>> misc_assignments = ('x[0]', 'x[1]', 'float y', 'float z') + >>> print(glsl_code(Matrix([1,2,3,4]), assign_to=misc_assignments)) + x[0] = 1; + x[1] = 2; + float y = 3; + float z = 4; + + Passing ``mat_nested = True`` instead prints out nested float arrays, which are + supported in GLSL 4.3 and above. + >>> mat = Matrix([ + ... [ 0, 1, 2], + ... [ 3, 4, 5], + ... [ 6, 7, 8], + ... [ 9, 10, 11], + ... [12, 13, 14]]) + >>> print(glsl_code( mat, mat_nested = True )) + float[5][3]( + float[]( 0, 1, 2), + float[]( 3, 4, 5), + float[]( 6, 7, 8), + float[]( 9, 10, 11), + float[](12, 13, 14) + ) + + + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e. [(argument_test, + js_function_string)]. + + >>> custom_functions = { + ... "ceiling": "CEIL", + ... "Abs": [(lambda x: not x.is_integer, "fabs"), + ... (lambda x: x.is_integer, "ABS")] + ... } + >>> glsl_code(Abs(x) + ceiling(x), user_functions=custom_functions) + 'fabs(x) + CEIL(x)' + + If further control is needed, addition, subtraction, multiplication and + division operators can be replaced with ``add``, ``sub``, and ``mul`` + functions. This is done by passing ``use_operators = False``: + + >>> x,y,z = symbols('x,y,z') + >>> glsl_code(x*(y+z), use_operators = False) + 'mul(x, add(y, z))' + >>> glsl_code(x*(y+z*(x-y)**z), use_operators = False) + 'mul(x, add(y, mul(z, pow(sub(x, y), z))))' + + ``Piecewise`` expressions are converted into conditionals. If an + ``assign_to`` variable is provided an if statement is created, otherwise + the ternary operator is used. Note that if the ``Piecewise`` lacks a + default term, represented by ``(expr, True)`` then an error will be thrown. + This is to prevent generating an expression that may not evaluate to + anything. + + >>> from sympy import Piecewise + >>> expr = Piecewise((x + 1, x > 0), (x, True)) + >>> print(glsl_code(expr, tau)) + if (x > 0) { + tau = x + 1; + } + else { + tau = x; + } + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> glsl_code(e.rhs, assign_to=e.lhs, contract=False) + 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);' + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)]) + >>> A = MatrixSymbol('A', 3, 1) + >>> print(glsl_code(mat, A)) + A[0][0] = pow(x, 2.0); + if (x > 0) { + A[1][0] = x + 1; + } + else { + A[1][0] = x; + } + A[2][0] = sin(x); + """ + return GLSLPrinter(settings).doprint(expr,assign_to) + +def print_glsl(expr, **settings): + """Prints the GLSL representation of the given expression. + + See GLSLPrinter init function for settings. + """ + print(glsl_code(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/gtk.py b/phivenv/Lib/site-packages/sympy/printing/gtk.py new file mode 100644 index 0000000000000000000000000000000000000000..4123d7231c730bbde28e33f441470c28b21c78d0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/gtk.py @@ -0,0 +1,16 @@ +from sympy.printing.mathml import mathml +from sympy.utilities.mathml import c2p +import tempfile +import subprocess + + +def print_gtk(x, start_viewer=True): + """Print to Gtkmathview, a gtk widget capable of rendering MathML. + + Needs libgtkmathview-bin""" + with tempfile.NamedTemporaryFile('w') as file: + file.write(c2p(mathml(x), simple=True)) + file.flush() + + if start_viewer: + subprocess.check_call(('mathmlviewer', file.name)) diff --git a/phivenv/Lib/site-packages/sympy/printing/jscode.py b/phivenv/Lib/site-packages/sympy/printing/jscode.py new file mode 100644 index 0000000000000000000000000000000000000000..753eb3291dd719ff53b06584de8ebe76c4471a3f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/jscode.py @@ -0,0 +1,332 @@ +""" +Javascript code printer + +The JavascriptCodePrinter converts single SymPy expressions into single +Javascript expressions, using the functions defined in the Javascript +Math object where possible. + +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import S +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE + + +# dictionary mapping SymPy function to (argument_conditions, Javascript_function). +# Used in JavascriptCodePrinter._print_Function(self) +known_functions = { + 'Abs': 'Math.abs', + 'acos': 'Math.acos', + 'acosh': 'Math.acosh', + 'asin': 'Math.asin', + 'asinh': 'Math.asinh', + 'atan': 'Math.atan', + 'atan2': 'Math.atan2', + 'atanh': 'Math.atanh', + 'ceiling': 'Math.ceil', + 'cos': 'Math.cos', + 'cosh': 'Math.cosh', + 'exp': 'Math.exp', + 'floor': 'Math.floor', + 'log': 'Math.log', + 'Max': 'Math.max', + 'Min': 'Math.min', + 'sign': 'Math.sign', + 'sin': 'Math.sin', + 'sinh': 'Math.sinh', + 'tan': 'Math.tan', + 'tanh': 'Math.tanh', +} + + +class JavascriptCodePrinter(CodePrinter): + """"A Printer to convert Python expressions to strings of JavaScript code + """ + printmethod = '_javascript' + language = 'JavaScript' + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + }) + + def __init__(self, settings={}): + CodePrinter.__init__(self, settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + return "%s;" % codestring + + def _get_comment(self, text): + return "// {}".format(text) + + def _declare_number_const(self, name, value): + return "var {} = {};".format(name, value.evalf(self._settings['precision'])) + + def _format_code(self, lines): + return self.indent_code(lines) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for i in range(rows) for j in range(cols)) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + loopstart = "for (var %(varble)s=%(start)s; %(varble)s<%(end)s; %(varble)s++){" + for i in indices: + # Javascript arrays start at 0 and end at dimension-1 + open_lines.append(loopstart % { + 'varble': self._print(i.label), + 'start': self._print(i.lower), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + def _print_Pow(self, expr): + PREC = precedence(expr) + if equal_valued(expr.exp, -1): + return '1/%s' % (self.parenthesize(expr.base, PREC)) + elif equal_valued(expr.exp, 0.5): + return 'Math.sqrt(%s)' % self._print(expr.base) + elif expr.exp == S.One/3: + return 'Math.cbrt(%s)' % self._print(expr.base) + else: + return 'Math.pow(%s, %s)' % (self._print(expr.base), + self._print(expr.exp)) + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + return '%d/%d' % (p, q) + + def _print_Mod(self, expr): + num, den = expr.args + PREC = precedence(expr) + snum, sden = [self.parenthesize(arg, PREC) for arg in expr.args] + # % is remainder (same sign as numerator), not modulo (same sign as + # denominator), in js. Hence, % only works as modulo if both numbers + # have the same sign + if (num.is_nonnegative and den.is_nonnegative or + num.is_nonpositive and den.is_nonpositive): + return f"{snum} % {sden}" + return f"(({snum} % {sden}) + {sden}) % {sden}" + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Indexed(self, expr): + # calculate index for 1d array + dims = expr.shape + elem = S.Zero + offset = S.One + for i in reversed(range(expr.rank)): + elem += expr.indices[i]*offset + offset *= dims[i] + return "%s[%s]" % (self._print(expr.base.label), self._print(elem)) + + def _print_Exp1(self, expr): + return "Math.E" + + def _print_Pi(self, expr): + return 'Math.PI' + + def _print_Infinity(self, expr): + return 'Number.POSITIVE_INFINITY' + + def _print_NegativeInfinity(self, expr): + return 'Number.NEGATIVE_INFINITY' + + def _print_Piecewise(self, expr): + from sympy.codegen.ast import Assignment + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if expr.has(Assignment): + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) {" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else {") + else: + lines.append("else if (%s) {" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + lines.append("}") + return "\n".join(lines) + else: + # The piecewise was used in an expression, need to do inline + # operators. This has the downside that inline operators will + # not work for statements that span multiple lines (Matrix or + # Indexed expressions). + ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c), self._print(e)) + for e, c in expr.args[:-1]] + last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr) + return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)]) + + def _print_MatrixElement(self, expr): + return "{}[{}]".format(self.parenthesize(expr.parent, + PRECEDENCE["Atom"], strict=True), + expr.j + expr.i*expr.parent.shape[1]) + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(map(line.endswith, inc_token))) for line in code ] + decrease = [ int(any(map(line.startswith, dec_token))) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + +def jscode(expr, assign_to=None, **settings): + """Converts an expr to a string of javascript code + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of + line-wrapping, or for expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=15]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, js_function_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + + Examples + ======== + + >>> from sympy import jscode, symbols, Rational, sin, ceiling, Abs + >>> x, tau = symbols("x, tau") + >>> jscode((2*tau)**Rational(7, 2)) + '8*Math.sqrt(2)*Math.pow(tau, 7/2)' + >>> jscode(sin(x), assign_to="s") + 's = Math.sin(x);' + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e. [(argument_test, + js_function_string)]. + + >>> custom_functions = { + ... "ceiling": "CEIL", + ... "Abs": [(lambda x: not x.is_integer, "fabs"), + ... (lambda x: x.is_integer, "ABS")] + ... } + >>> jscode(Abs(x) + ceiling(x), user_functions=custom_functions) + 'fabs(x) + CEIL(x)' + + ``Piecewise`` expressions are converted into conditionals. If an + ``assign_to`` variable is provided an if statement is created, otherwise + the ternary operator is used. Note that if the ``Piecewise`` lacks a + default term, represented by ``(expr, True)`` then an error will be thrown. + This is to prevent generating an expression that may not evaluate to + anything. + + >>> from sympy import Piecewise + >>> expr = Piecewise((x + 1, x > 0), (x, True)) + >>> print(jscode(expr, tau)) + if (x > 0) { + tau = x + 1; + } + else { + tau = x; + } + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> jscode(e.rhs, assign_to=e.lhs, contract=False) + 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);' + + Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions + must be provided to ``assign_to``. Note that any expression that can be + generated normally can also exist inside a Matrix: + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)]) + >>> A = MatrixSymbol('A', 3, 1) + >>> print(jscode(mat, A)) + A[0] = Math.pow(x, 2); + if (x > 0) { + A[1] = x + 1; + } + else { + A[1] = x; + } + A[2] = Math.sin(x); + """ + + return JavascriptCodePrinter(settings).doprint(expr, assign_to) + + +def print_jscode(expr, **settings): + """Prints the Javascript representation of the given expression. + + See jscode for the meaning of the optional arguments. + """ + print(jscode(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/julia.py b/phivenv/Lib/site-packages/sympy/printing/julia.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab815add4d87fe953f646409e3a7bb383b1bbc6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/julia.py @@ -0,0 +1,652 @@ +""" +Julia code printer + +The `JuliaCodePrinter` converts SymPy expressions into Julia expressions. + +A complete code generator, which uses `julia_code` extensively, can be found +in `sympy.utilities.codegen`. The `codegen` module can be used to generate +complete source code files. + +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import Mul, Pow, S, Rational +from sympy.core.mul import _keep_coeff +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from re import search + +# List of known functions. First, those that have the same name in +# SymPy and Julia. This is almost certainly incomplete! +known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc", + "asin", "acos", "atan", "acot", "asec", "acsc", + "sinh", "cosh", "tanh", "coth", "sech", "csch", + "asinh", "acosh", "atanh", "acoth", "asech", "acsch", + "atan2", "sign", "floor", "log", "exp", + "cbrt", "sqrt", "erf", "erfc", "erfi", + "factorial", "gamma", "digamma", "trigamma", + "polygamma", "beta", + "airyai", "airyaiprime", "airybi", "airybiprime", + "besselj", "bessely", "besseli", "besselk", + "erfinv", "erfcinv"] +# These functions have different names ("SymPy": "Julia"), more +# generally a mapping to (argument_conditions, julia_function). +known_fcns_src2 = { + "Abs": "abs", + "ceiling": "ceil", + "conjugate": "conj", + "hankel1": "hankelh1", + "hankel2": "hankelh2", + "im": "imag", + "re": "real" +} + + +class JuliaCodePrinter(CodePrinter): + """ + A printer to convert expressions to strings of Julia code. + """ + printmethod = "_julia" + language = "Julia" + + _operators = { + 'and': '&&', + 'or': '||', + 'not': '!', + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'inline': True, + }) + # Note: contract is for expressing tensors as loops (if True), or just + # assignment (if False). FIXME: this should be looked a more carefully + # for Julia. + + def __init__(self, settings={}): + super().__init__(settings) + self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1)) + self.known_functions.update(dict(known_fcns_src2)) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + + def _rate_index_position(self, p): + return p*5 + + + def _get_statement(self, codestring): + return "%s" % codestring + + + def _get_comment(self, text): + return "# {}".format(text) + + + def _declare_number_const(self, name, value): + return "const {} = {}".format(name, value) + + + def _format_code(self, lines): + return self.indent_code(lines) + + + def _traverse_matrix_indices(self, mat): + # Julia uses Fortran order (column-major) + rows, cols = mat.shape + return ((i, j) for j in range(cols) for i in range(rows)) + + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + for i in indices: + # Julia arrays start at 1 and end at dimension + var, start, stop = map(self._print, + [i.label, i.lower + 1, i.upper + 1]) + open_lines.append("for %s = %s:%s" % (var, start, stop)) + close_lines.append("end") + return open_lines, close_lines + + + def _print_Mul(self, expr): + # print complex numbers nicely in Julia + if (expr.is_number and expr.is_imaginary and + expr.as_coeff_Mul()[0].is_integer): + return "%sim" % self._print(-S.ImaginaryUnit*expr) + + # cribbed from str.py + prec = precedence(expr) + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = "-" + else: + sign = "" + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + pow_paren = [] # Will collect all pow with more than one base element and exp = -1 + + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + # Gather args for numerator/denominator + for item in args: + if (item.is_commutative and item.is_Pow and item.exp.is_Rational + and item.exp.is_negative): + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160 + pow_paren.append(item) + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity and item.p == 1: + # Save the Rational type in julia Unless the numerator is 1. + # For example: + # julia_code(Rational(3, 7)*x) --> (3 // 7) * x + # julia_code(x/3) --> x / 3 but not x * (1 // 3) + b.append(Rational(item.q)) + else: + a.append(item) + + a = a or [S.One] + + a_str = [self.parenthesize(x, prec) for x in a] + b_str = [self.parenthesize(x, prec) for x in b] + + # To parenthesize Pow with exp = -1 and having more than one Symbol + for item in pow_paren: + if item.base in b: + b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)] + + # from here it differs from str.py to deal with "*" and ".*" + def multjoin(a, a_str): + # here we probably are assuming the constants will come first + r = a_str[0] + for i in range(1, len(a)): + mulsym = '*' if a[i-1].is_number else '.*' + r = "%s %s %s" % (r, mulsym, a_str[i]) + return r + + if not b: + return sign + multjoin(a, a_str) + elif len(b) == 1: + divsym = '/' if b[0].is_number else './' + return "%s %s %s" % (sign+multjoin(a, a_str), divsym, b_str[0]) + else: + divsym = '/' if all(bi.is_number for bi in b) else './' + return "%s %s (%s)" % (sign + multjoin(a, a_str), divsym, multjoin(b, b_str)) + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Pow(self, expr): + powsymbol = '^' if all(x.is_number for x in expr.args) else '.^' + + PREC = precedence(expr) + + if equal_valued(expr.exp, 0.5): + return "sqrt(%s)" % self._print(expr.base) + + if expr.is_commutative: + if equal_valued(expr.exp, -0.5): + sym = '/' if expr.base.is_number else './' + return "1 %s sqrt(%s)" % (sym, self._print(expr.base)) + if equal_valued(expr.exp, -1): + sym = '/' if expr.base.is_number else './' + return "1 %s %s" % (sym, self.parenthesize(expr.base, PREC)) + + return '%s %s %s' % (self.parenthesize(expr.base, PREC), powsymbol, + self.parenthesize(expr.exp, PREC)) + + + def _print_MatPow(self, expr): + PREC = precedence(expr) + return '%s ^ %s' % (self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC)) + + + def _print_Pi(self, expr): + if self._settings["inline"]: + return "pi" + else: + return super()._print_NumberSymbol(expr) + + + def _print_ImaginaryUnit(self, expr): + return "im" + + + def _print_Exp1(self, expr): + if self._settings["inline"]: + return "e" + else: + return super()._print_NumberSymbol(expr) + + + def _print_EulerGamma(self, expr): + if self._settings["inline"]: + return "eulergamma" + else: + return super()._print_NumberSymbol(expr) + + + def _print_Catalan(self, expr): + if self._settings["inline"]: + return "catalan" + else: + return super()._print_NumberSymbol(expr) + + + def _print_GoldenRatio(self, expr): + if self._settings["inline"]: + return "golden" + else: + return super()._print_NumberSymbol(expr) + + + def _print_Assignment(self, expr): + from sympy.codegen.ast import Assignment + from sympy.functions.elementary.piecewise import Piecewise + from sympy.tensor.indexed import IndexedBase + # Copied from codeprinter, but remove special MatrixSymbol treatment + lhs = expr.lhs + rhs = expr.rhs + # We special case assignments that take multiple lines + if not self._settings["inline"] and isinstance(expr.rhs, Piecewise): + # Here we modify Piecewise so each expression is now + # an Assignment, and then continue on the print. + expressions = [] + conditions = [] + for (e, c) in rhs.args: + expressions.append(Assignment(lhs, e)) + conditions.append(c) + temp = Piecewise(*zip(expressions, conditions)) + return self._print(temp) + if self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + + def _print_Infinity(self, expr): + return 'Inf' + + + def _print_NegativeInfinity(self, expr): + return '-Inf' + + + def _print_NaN(self, expr): + return 'NaN' + + + def _print_list(self, expr): + return 'Any[' + ', '.join(self._print(a) for a in expr) + ']' + + + def _print_tuple(self, expr): + if len(expr) == 1: + return "(%s,)" % self._print(expr[0]) + else: + return "(%s)" % self.stringify(expr, ", ") + _print_Tuple = _print_tuple + + + def _print_BooleanTrue(self, expr): + return "true" + + + def _print_BooleanFalse(self, expr): + return "false" + + + def _print_bool(self, expr): + return str(expr).lower() + + + # Could generate quadrature code for definite Integrals? + #_print_Integral = _print_not_supported + + + def _print_MatrixBase(self, A): + # Handle zero dimensions: + if S.Zero in A.shape: + return 'zeros(%s, %s)' % (A.rows, A.cols) + elif (A.rows, A.cols) == (1, 1): + return "[%s]" % A[0, 0] + elif A.rows == 1: + return "[%s]" % A.table(self, rowstart='', rowend='', colsep=' ') + elif A.cols == 1: + # note .table would unnecessarily equispace the rows + return "[%s]" % ", ".join([self._print(a) for a in A]) + return "[%s]" % A.table(self, rowstart='', rowend='', + rowsep=';\n', colsep=' ') + + + def _print_SparseRepMatrix(self, A): + from sympy.matrices import Matrix + L = A.col_list() + # make row vectors of the indices and entries + I = Matrix([k[0] + 1 for k in L]) + J = Matrix([k[1] + 1 for k in L]) + AIJ = Matrix([k[2] for k in L]) + return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J), + self._print(AIJ), A.rows, A.cols) + + + def _print_MatrixElement(self, expr): + return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \ + + '[%s,%s]' % (expr.i + 1, expr.j + 1) + + + def _print_MatrixSlice(self, expr): + def strslice(x, lim): + l = x[0] + 1 + h = x[1] + step = x[2] + lstr = self._print(l) + hstr = 'end' if h == lim else self._print(h) + if step == 1: + if l == 1 and h == lim: + return ':' + if l == h: + return lstr + else: + return lstr + ':' + hstr + else: + return ':'.join((lstr, self._print(step), hstr)) + return (self._print(expr.parent) + '[' + + strslice(expr.rowslice, expr.parent.shape[0]) + ',' + + strslice(expr.colslice, expr.parent.shape[1]) + ']') + + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s[%s]" % (self._print(expr.base.label), ",".join(inds)) + + def _print_Identity(self, expr): + return "eye(%s)" % self._print(expr.shape[0]) + + def _print_HadamardProduct(self, expr): + return ' .* '.join([self.parenthesize(arg, precedence(expr)) + for arg in expr.args]) + + def _print_HadamardPower(self, expr): + PREC = precedence(expr) + return '.**'.join([ + self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC) + ]) + + def _print_Rational(self, expr): + if expr.q == 1: + return str(expr.p) + return "%s // %s" % (expr.p, expr.q) + + # Note: as of 2022, Julia doesn't have spherical Bessel functions + def _print_jn(self, expr): + from sympy.functions import sqrt, besselj + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x) + return self._print(expr2) + + + def _print_yn(self, expr): + from sympy.functions import sqrt, bessely + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x) + return self._print(expr2) + + def _print_sinc(self, expr): + # Julia has the normalized sinc function + return "sinc({})".format(self._print(expr.args[0] / S.Pi)) + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if self._settings["inline"]: + # Express each (cond, expr) pair in a nested Horner form: + # (condition) .* (expr) + (not cond) .* () + # Expressions that result in multiple statements won't work here. + ecpairs = ["({}) ? ({}) :".format + (self._print(c), self._print(e)) + for e, c in expr.args[:-1]] + elast = " (%s)" % self._print(expr.args[-1].expr) + pw = "\n".join(ecpairs) + elast + # Note: current need these outer brackets for 2*pw. Would be + # nicer to teach parenthesize() to do this for us when needed! + return "(" + pw + ")" + else: + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s)" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else") + else: + lines.append("elseif (%s)" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + if i == len(expr.args) - 1: + lines.append("end") + return "\n".join(lines) + + def _print_MatMul(self, expr): + c, m = expr.as_coeff_mmul() + + sign = "" + if c.is_number: + re, im = c.as_real_imag() + if im.is_zero and re.is_negative: + expr = _keep_coeff(-c, m) + sign = "-" + elif re.is_zero and im.is_negative: + expr = _keep_coeff(-c, m) + sign = "-" + + return sign + ' * '.join( + (self.parenthesize(arg, precedence(expr)) for arg in expr.args) + ) + + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + # code mostly copied from ccode + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ') + dec_regex = ('^end$', '^elseif ', '^else$') + + # pre-strip left-space from the code + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(search(re, line) for re in inc_regex)) + for line in code ] + decrease = [ int(any(search(re, line) for re in dec_regex)) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + +def julia_code(expr, assign_to=None, **settings): + r"""Converts `expr` to a string of Julia code. + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This can be helpful for + expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=16]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, cfunction_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + inline: bool, optional + If True, we try to create single-statement code instead of multiple + statements. [default=True]. + + Examples + ======== + + >>> from sympy import julia_code, symbols, sin, pi + >>> x = symbols('x') + >>> julia_code(sin(x).series(x).removeO()) + 'x .^ 5 / 120 - x .^ 3 / 6 + x' + + >>> from sympy import Rational, ceiling + >>> x, y, tau = symbols("x, y, tau") + >>> julia_code((2*tau)**Rational(7, 2)) + '8 * sqrt(2) * tau .^ (7 // 2)' + + Note that element-wise (Hadamard) operations are used by default between + symbols. This is because its possible in Julia to write "vectorized" + code. It is harmless if the values are scalars. + + >>> julia_code(sin(pi*x*y), assign_to="s") + 's = sin(pi * x .* y)' + + If you need a matrix product "*" or matrix power "^", you can specify the + symbol as a ``MatrixSymbol``. + + >>> from sympy import Symbol, MatrixSymbol + >>> n = Symbol('n', integer=True, positive=True) + >>> A = MatrixSymbol('A', n, n) + >>> julia_code(3*pi*A**3) + '(3 * pi) * A ^ 3' + + This class uses several rules to decide which symbol to use a product. + Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*". + A HadamardProduct can be used to specify componentwise multiplication ".*" + of two MatrixSymbols. There is currently there is no easy way to specify + scalar symbols, so sometimes the code might have some minor cosmetic + issues. For example, suppose x and y are scalars and A is a Matrix, then + while a human programmer might write "(x^2*y)*A^3", we generate: + + >>> julia_code(x**2*y*A**3) + '(x .^ 2 .* y) * A ^ 3' + + Matrices are supported using Julia inline notation. When using + ``assign_to`` with matrices, the name can be specified either as a string + or as a ``MatrixSymbol``. The dimensions must align in the latter case. + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([[x**2, sin(x), ceiling(x)]]) + >>> julia_code(mat, assign_to='A') + 'A = [x .^ 2 sin(x) ceil(x)]' + + ``Piecewise`` expressions are implemented with logical masking by default. + Alternatively, you can pass "inline=False" to use if-else conditionals. + Note that if the ``Piecewise`` lacks a default term, represented by + ``(expr, True)`` then an error will be thrown. This is to prevent + generating an expression that may not evaluate to anything. + + >>> from sympy import Piecewise + >>> pw = Piecewise((x + 1, x > 0), (x, True)) + >>> julia_code(pw, assign_to=tau) + 'tau = ((x > 0) ? (x + 1) : (x))' + + Note that any expression that can be generated normally can also exist + inside a Matrix: + + >>> mat = Matrix([[x**2, pw, sin(x)]]) + >>> julia_code(mat, assign_to='A') + 'A = [x .^ 2 ((x > 0) ? (x + 1) : (x)) sin(x)]' + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e., [(argument_test, + cfunction_string)]. This can be used to call a custom Julia function. + + >>> from sympy import Function + >>> f = Function('f') + >>> g = Function('g') + >>> custom_functions = { + ... "f": "existing_julia_fcn", + ... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"), + ... (lambda x: not x.is_Matrix, "my_fcn")] + ... } + >>> mat = Matrix([[1, x]]) + >>> julia_code(f(x) + g(x) + g(mat), user_functions=custom_functions) + 'existing_julia_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])' + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> julia_code(e.rhs, assign_to=e.lhs, contract=False) + 'Dy[i] = (y[i + 1] - y[i]) ./ (t[i + 1] - t[i])' + """ + return JuliaCodePrinter(settings).doprint(expr, assign_to) + + +def print_julia_code(expr, **settings): + """Prints the Julia representation of the given expression. + + See `julia_code` for the meaning of the optional arguments. + """ + print(julia_code(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/lambdarepr.py b/phivenv/Lib/site-packages/sympy/printing/lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..87fa0988d138d54d68ab8aef1bbc0f27b243b472 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/lambdarepr.py @@ -0,0 +1,251 @@ +from .pycode import ( + PythonCodePrinter, + MpmathPrinter, +) +from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility +from sympy.core.sorting import default_sort_key + + +__all__ = [ + 'PythonCodePrinter', + 'MpmathPrinter', # MpmathPrinter is published for backward compatibility + 'NumPyPrinter', + 'LambdaPrinter', + 'NumPyPrinter', + 'IntervalPrinter', + 'lambdarepr', +] + + +class LambdaPrinter(PythonCodePrinter): + """ + This printer converts expressions into strings that can be used by + lambdify. + """ + printmethod = "_lambdacode" + + + def _print_And(self, expr): + result = ['('] + for arg in sorted(expr.args, key=default_sort_key): + result.extend(['(', self._print(arg), ')']) + result.append(' and ') + result = result[:-1] + result.append(')') + return ''.join(result) + + def _print_Or(self, expr): + result = ['('] + for arg in sorted(expr.args, key=default_sort_key): + result.extend(['(', self._print(arg), ')']) + result.append(' or ') + result = result[:-1] + result.append(')') + return ''.join(result) + + def _print_Not(self, expr): + result = ['(', 'not (', self._print(expr.args[0]), '))'] + return ''.join(result) + + def _print_BooleanTrue(self, expr): + return "True" + + def _print_BooleanFalse(self, expr): + return "False" + + def _print_ITE(self, expr): + result = [ + '((', self._print(expr.args[1]), + ') if (', self._print(expr.args[0]), + ') else (', self._print(expr.args[2]), '))' + ] + return ''.join(result) + + def _print_NumberSymbol(self, expr): + return str(expr) + + def _print_Pow(self, expr, **kwargs): + # XXX Temporary workaround. Should Python math printer be + # isolated from PythonCodePrinter? + return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs) + + +# numexpr works by altering the string passed to numexpr.evaluate +# rather than by populating a namespace. Thus a special printer... +class NumExprPrinter(LambdaPrinter): + # key, value pairs correspond to SymPy name and numexpr name + # functions not appearing in this dict will raise a TypeError + printmethod = "_numexprcode" + + _numexpr_functions = { + 'sin' : 'sin', + 'cos' : 'cos', + 'tan' : 'tan', + 'asin': 'arcsin', + 'acos': 'arccos', + 'atan': 'arctan', + 'atan2' : 'arctan2', + 'sinh' : 'sinh', + 'cosh' : 'cosh', + 'tanh' : 'tanh', + 'asinh': 'arcsinh', + 'acosh': 'arccosh', + 'atanh': 'arctanh', + 'ln' : 'log', + 'log': 'log', + 'exp': 'exp', + 'sqrt' : 'sqrt', + 'Abs' : 'abs', + 'conjugate' : 'conj', + 'im' : 'imag', + 're' : 'real', + 'where' : 'where', + 'complex' : 'complex', + 'contains' : 'contains', + } + + module = 'numexpr' + + def _print_ImaginaryUnit(self, expr): + return '1j' + + def _print_seq(self, seq, delimiter=', '): + # simplified _print_seq taken from pretty.py + s = [self._print(item) for item in seq] + if s: + return delimiter.join(s) + else: + return "" + + def _print_Function(self, e): + func_name = e.func.__name__ + + nstr = self._numexpr_functions.get(func_name, None) + if nstr is None: + # check for implemented_function + if hasattr(e, '_imp_'): + return "(%s)" % self._print(e._imp_(*e.args)) + else: + raise TypeError("numexpr does not support function '%s'" % + func_name) + return "%s(%s)" % (nstr, self._print_seq(e.args)) + + def _print_Piecewise(self, expr): + "Piecewise function printer" + exprs = [self._print(arg.expr) for arg in expr.args] + conds = [self._print(arg.cond) for arg in expr.args] + # If [default_value, True] is a (expr, cond) sequence in a Piecewise object + # it will behave the same as passing the 'default' kwarg to select() + # *as long as* it is the last element in expr.args. + # If this is not the case, it may be triggered prematurely. + ans = [] + parenthesis_count = 0 + is_last_cond_True = False + for cond, expr in zip(conds, exprs): + if cond == 'True': + ans.append(expr) + is_last_cond_True = True + break + else: + ans.append('where(%s, %s, ' % (cond, expr)) + parenthesis_count += 1 + if not is_last_cond_True: + # See https://github.com/pydata/numexpr/issues/298 + # + # simplest way to put a nan but raises + # 'RuntimeWarning: invalid value encountered in log' + # + # There are other ways to do this such as + # + # >>> import numexpr as ne + # >>> nan = float('nan') + # >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan}) + # array([-1., nan, nan]) + # + # That needs to be handled in the lambdified function though rather + # than here in the printer. + ans.append('log(-1)') + return ''.join(ans) + ')' * parenthesis_count + + def _print_ITE(self, expr): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def blacklisted(self, expr): + raise TypeError("numexpr cannot be used with %s" % + expr.__class__.__name__) + + # blacklist all Matrix printing + _print_SparseRepMatrix = \ + _print_MutableSparseMatrix = \ + _print_ImmutableSparseMatrix = \ + _print_Matrix = \ + _print_DenseMatrix = \ + _print_MutableDenseMatrix = \ + _print_ImmutableMatrix = \ + _print_ImmutableDenseMatrix = \ + blacklisted + # blacklist some Python expressions + _print_list = \ + _print_tuple = \ + _print_Tuple = \ + _print_dict = \ + _print_Dict = \ + blacklisted + + def _print_NumExprEvaluate(self, expr): + evaluate = self._module_format(self.module +".evaluate") + return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr)) + + def doprint(self, expr): + from sympy.codegen.ast import CodegenAST + from sympy.codegen.pynodes import NumExprEvaluate + if not isinstance(expr, CodegenAST): + expr = NumExprEvaluate(expr) + return super().doprint(expr) + + def _print_Return(self, expr): + from sympy.codegen.pynodes import NumExprEvaluate + r, = expr.args + if not isinstance(r, NumExprEvaluate): + expr = expr.func(NumExprEvaluate(r)) + return super()._print_Return(expr) + + def _print_Assignment(self, expr): + from sympy.codegen.pynodes import NumExprEvaluate + lhs, rhs, *args = expr.args + if not isinstance(rhs, NumExprEvaluate): + expr = expr.func(lhs, NumExprEvaluate(rhs), *args) + return super()._print_Assignment(expr) + + def _print_CodeBlock(self, expr): + from sympy.codegen.ast import CodegenAST + from sympy.codegen.pynodes import NumExprEvaluate + args = [ arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg) for arg in expr.args ] + return super()._print_CodeBlock(self, expr.func(*args)) + + +class IntervalPrinter(MpmathPrinter, LambdaPrinter): + """Use ``lambda`` printer but print numbers as ``mpi`` intervals. """ + + def _print_Integer(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Integer(expr) + + def _print_Rational(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr) + + def _print_Half(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr) + + def _print_Pow(self, expr): + return super(MpmathPrinter, self)._print_Pow(expr, rational=True) + + +for k in NumExprPrinter._numexpr_functions: + setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function) + +def lambdarepr(expr, **settings): + """ + Returns a string usable for lambdifying. + """ + return LambdaPrinter(settings).doprint(expr) diff --git a/phivenv/Lib/site-packages/sympy/printing/latex.py b/phivenv/Lib/site-packages/sympy/printing/latex.py new file mode 100644 index 0000000000000000000000000000000000000000..724df719d560e001deb175649a3769703bdf5ca5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/latex.py @@ -0,0 +1,3318 @@ +""" +A Printer which converts an expression into its LaTeX equivalent. +""" +from __future__ import annotations +from typing import Any, Callable, TYPE_CHECKING + +import itertools + +from sympy.core import Add, Float, Mod, Mul, Number, S, Symbol, Expr +from sympy.core.alphabets import greeks +from sympy.core.containers import Tuple +from sympy.core.function import Function, AppliedUndef, Derivative +from sympy.core.operations import AssocOp +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import SympifyError +from sympy.logic.boolalg import true, BooleanTrue, BooleanFalse + + +# sympy.printing imports +from sympy.printing.precedence import precedence_traditional +from sympy.printing.printer import Printer, print_function +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.printing.precedence import precedence, PRECEDENCE + +from mpmath.libmp.libmpf import prec_to_dps, to_str as mlib_to_str + +from sympy.utilities.iterables import has_variety, sift + +import re + +if TYPE_CHECKING: + from sympy.tensor.array import NDimArray + from sympy.vector.basisdependent import BasisDependent + +# Hand-picked functions which can be used directly in both LaTeX and MathJax +# Complete list at +# https://docs.mathjax.org/en/latest/tex.html#supported-latex-commands +# This variable only contains those functions which SymPy uses. +accepted_latex_functions = ['arcsin', 'arccos', 'arctan', 'sin', 'cos', 'tan', + 'sinh', 'cosh', 'tanh', 'sqrt', 'ln', 'log', 'sec', + 'csc', 'cot', 'coth', 're', 'im', 'frac', 'root', + 'arg', + ] + +tex_greek_dictionary = { + 'Alpha': r'\mathrm{A}', + 'Beta': r'\mathrm{B}', + 'Gamma': r'\Gamma', + 'Delta': r'\Delta', + 'Epsilon': r'\mathrm{E}', + 'Zeta': r'\mathrm{Z}', + 'Eta': r'\mathrm{H}', + 'Theta': r'\Theta', + 'Iota': r'\mathrm{I}', + 'Kappa': r'\mathrm{K}', + 'Lambda': r'\Lambda', + 'Mu': r'\mathrm{M}', + 'Nu': r'\mathrm{N}', + 'Xi': r'\Xi', + 'omicron': 'o', + 'Omicron': r'\mathrm{O}', + 'Pi': r'\Pi', + 'Rho': r'\mathrm{P}', + 'Sigma': r'\Sigma', + 'Tau': r'\mathrm{T}', + 'Upsilon': r'\Upsilon', + 'Phi': r'\Phi', + 'Chi': r'\mathrm{X}', + 'Psi': r'\Psi', + 'Omega': r'\Omega', + 'lamda': r'\lambda', + 'Lamda': r'\Lambda', + 'khi': r'\chi', + 'Khi': r'\mathrm{X}', + 'varepsilon': r'\varepsilon', + 'varkappa': r'\varkappa', + 'varphi': r'\varphi', + 'varpi': r'\varpi', + 'varrho': r'\varrho', + 'varsigma': r'\varsigma', + 'vartheta': r'\vartheta', +} + +other_symbols = {'aleph', 'beth', 'daleth', 'gimel', 'ell', 'eth', 'hbar', + 'hslash', 'mho', 'wp'} + +# Variable name modifiers +modifier_dict: dict[str, Callable[[str], str]] = { + # Accents + 'mathring': lambda s: r'\mathring{'+s+r'}', + 'ddddot': lambda s: r'\ddddot{'+s+r'}', + 'dddot': lambda s: r'\dddot{'+s+r'}', + 'ddot': lambda s: r'\ddot{'+s+r'}', + 'dot': lambda s: r'\dot{'+s+r'}', + 'check': lambda s: r'\check{'+s+r'}', + 'breve': lambda s: r'\breve{'+s+r'}', + 'acute': lambda s: r'\acute{'+s+r'}', + 'grave': lambda s: r'\grave{'+s+r'}', + 'tilde': lambda s: r'\tilde{'+s+r'}', + 'hat': lambda s: r'\hat{'+s+r'}', + 'bar': lambda s: r'\bar{'+s+r'}', + 'vec': lambda s: r'\vec{'+s+r'}', + 'prime': lambda s: "{"+s+"}'", + 'prm': lambda s: "{"+s+"}'", + # Faces + 'bold': lambda s: r'\boldsymbol{'+s+r'}', + 'bm': lambda s: r'\boldsymbol{'+s+r'}', + 'cal': lambda s: r'\mathcal{'+s+r'}', + 'scr': lambda s: r'\mathscr{'+s+r'}', + 'frak': lambda s: r'\mathfrak{'+s+r'}', + # Brackets + 'norm': lambda s: r'\left\|{'+s+r'}\right\|', + 'avg': lambda s: r'\left\langle{'+s+r'}\right\rangle', + 'abs': lambda s: r'\left|{'+s+r'}\right|', + 'mag': lambda s: r'\left|{'+s+r'}\right|', +} + +greek_letters_set = frozenset(greeks) + +_between_two_numbers_p = ( + re.compile(r'[0-9][} ]*$'), # search + re.compile(r'(\d|\\frac{\d+}{\d+})'), # match +) + + +def latex_escape(s: str) -> str: + """ + Escape a string such that latex interprets it as plaintext. + + We cannot use verbatim easily with mathjax, so escaping is easier. + Rules from https://tex.stackexchange.com/a/34586/41112. + """ + s = s.replace('\\', r'\textbackslash') + for c in '&%$#_{}': + s = s.replace(c, '\\' + c) + s = s.replace('~', r'\textasciitilde') + s = s.replace('^', r'\textasciicircum') + return s + + +class LatexPrinter(Printer): + printmethod = "_latex" + + _default_settings: dict[str, Any] = { + "full_prec": False, + "fold_frac_powers": False, + "fold_func_brackets": False, + "fold_short_frac": None, + "inv_trig_style": "abbreviated", + "itex": False, + "ln_notation": False, + "long_frac_ratio": None, + "mat_delim": "[", + "mat_str": None, + "mode": "plain", + "mul_symbol": None, + "order": None, + "symbol_names": {}, + "root_notation": True, + "mat_symbol_style": "plain", + "imaginary_unit": "i", + "gothic_re_im": False, + "decimal_separator": "period", + "perm_cyclic": True, + "parenthesize_super": True, + "min": None, + "max": None, + "diff_operator": "d", + "adjoint_style": "dagger", + "disable_split_super_sub": False, + } + + def __init__(self, settings=None): + Printer.__init__(self, settings) + + if 'mode' in self._settings: + valid_modes = ['inline', 'plain', 'equation', + 'equation*'] + if self._settings['mode'] not in valid_modes: + raise ValueError("'mode' must be one of 'inline', 'plain', " + "'equation' or 'equation*'") + + if self._settings['fold_short_frac'] is None and \ + self._settings['mode'] == 'inline': + self._settings['fold_short_frac'] = True + + mul_symbol_table = { + None: r" ", + "ldot": r" \,.\, ", + "dot": r" \cdot ", + "times": r" \times " + } + try: + self._settings['mul_symbol_latex'] = \ + mul_symbol_table[self._settings['mul_symbol']] + except KeyError: + self._settings['mul_symbol_latex'] = \ + self._settings['mul_symbol'] + try: + self._settings['mul_symbol_latex_numbers'] = \ + mul_symbol_table[self._settings['mul_symbol'] or 'dot'] + except KeyError: + if (self._settings['mul_symbol'].strip() in + ['', ' ', '\\', '\\,', '\\:', '\\;', '\\quad']): + self._settings['mul_symbol_latex_numbers'] = \ + mul_symbol_table['dot'] + else: + self._settings['mul_symbol_latex_numbers'] = \ + self._settings['mul_symbol'] + + self._delim_dict = {'(': ')', '[': ']'} + + imaginary_unit_table = { + None: r"i", + "i": r"i", + "ri": r"\mathrm{i}", + "ti": r"\text{i}", + "j": r"j", + "rj": r"\mathrm{j}", + "tj": r"\text{j}", + } + imag_unit = self._settings['imaginary_unit'] + self._settings['imaginary_unit_latex'] = imaginary_unit_table.get(imag_unit, imag_unit) + + diff_operator_table = { + None: r"d", + "d": r"d", + "rd": r"\mathrm{d}", + "td": r"\text{d}", + } + diff_operator = self._settings['diff_operator'] + self._settings["diff_operator_latex"] = diff_operator_table.get(diff_operator, diff_operator) + + def _add_parens(self, s) -> str: + return r"\left({}\right)".format(s) + + # TODO: merge this with the above, which requires a lot of test changes + def _add_parens_lspace(self, s) -> str: + return r"\left( {}\right)".format(s) + + def parenthesize(self, item, level, is_neg=False, strict=False) -> str: + prec_val = precedence_traditional(item) + if is_neg and strict: + return self._add_parens(self._print(item)) + + if (prec_val < level) or ((not strict) and prec_val <= level): + return self._add_parens(self._print(item)) + else: + return self._print(item) + + def parenthesize_super(self, s): + """ + Protect superscripts in s + + If the parenthesize_super option is set, protect with parentheses, else + wrap in braces. + """ + if "^" in s: + if self._settings['parenthesize_super']: + return self._add_parens(s) + else: + return "{{{}}}".format(s) + return s + + def doprint(self, expr) -> str: + tex = Printer.doprint(self, expr) + + if self._settings['mode'] == 'plain': + return tex + elif self._settings['mode'] == 'inline': + return r"$%s$" % tex + elif self._settings['itex']: + return r"$$%s$$" % tex + else: + env_str = self._settings['mode'] + return r"\begin{%s}%s\end{%s}" % (env_str, tex, env_str) + + def _needs_brackets(self, expr) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + printed, False otherwise. For example: a + b => True; a => False; + 10 => False; -10 => True. + """ + return not ((expr.is_Integer and expr.is_nonnegative) + or (expr.is_Atom and (expr is not S.NegativeOne + and expr.is_Rational is False))) + + def _needs_function_brackets(self, expr) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + passed as an argument to a function, False otherwise. This is a more + liberal version of _needs_brackets, in that many expressions which need + to be wrapped in brackets when added/subtracted/raised to a power do + not need them when passed to a function. Such an example is a*b. + """ + if not self._needs_brackets(expr): + return False + else: + # Muls of the form a*b*c... can be folded + if expr.is_Mul and not self._mul_is_clean(expr): + return True + # Pows which don't need brackets can be folded + elif expr.is_Pow and not self._pow_is_clean(expr): + return True + # Add and Function always need brackets + elif expr.is_Add or expr.is_Function: + return True + else: + return False + + def _needs_mul_brackets(self, expr, first=False, last=False) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + printed as part of a Mul, False otherwise. This is True for Add, + but also for some container objects that would not need brackets + when appearing last in a Mul, e.g. an Integral. ``last=True`` + specifies that this expr is the last to appear in a Mul. + ``first=True`` specifies that this expr is the first to appear in + a Mul. + """ + from sympy.concrete.products import Product + from sympy.concrete.summations import Sum + from sympy.integrals.integrals import Integral + + if expr.is_Mul: + if not first and expr.could_extract_minus_sign(): + return True + elif precedence_traditional(expr) < PRECEDENCE["Mul"]: + return True + elif expr.is_Relational: + return True + if expr.is_Piecewise: + return True + if any(expr.has(x) for x in (Mod,)): + return True + if (not last and + any(expr.has(x) for x in (Integral, Product, Sum))): + return True + + return False + + def _needs_add_brackets(self, expr) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + printed as part of an Add, False otherwise. This is False for most + things. + """ + if expr.is_Relational: + return True + if any(expr.has(x) for x in (Mod,)): + return True + if expr.is_Add: + return True + return False + + def _mul_is_clean(self, expr) -> bool: + for arg in expr.args: + if arg.is_Function: + return False + return True + + def _pow_is_clean(self, expr) -> bool: + return not self._needs_brackets(expr.base) + + def _do_exponent(self, expr: str, exp): + if exp is not None: + return r"\left(%s\right)^{%s}" % (expr, exp) + else: + return expr + + def _print_Basic(self, expr): + name = self._deal_with_super_sub(expr.__class__.__name__) + if expr.args: + ls = [self._print(o) for o in expr.args] + s = r"\operatorname{{{}}}\left({}\right)" + return s.format(name, ", ".join(ls)) + else: + return r"\text{{{}}}".format(name) + + def _print_bool(self, e: bool | BooleanTrue | BooleanFalse): + return r"\text{%s}" % e + + _print_BooleanTrue = _print_bool + _print_BooleanFalse = _print_bool + + def _print_NoneType(self, e): + return r"\text{%s}" % e + + def _print_Add(self, expr, order=None): + terms = self._as_ordered_terms(expr, order=order) + + tex = "" + for i, term in enumerate(terms): + if i == 0: + pass + elif term.could_extract_minus_sign(): + tex += " - " + term = -term + else: + tex += " + " + term_tex = self._print(term) + if self._needs_add_brackets(term): + term_tex = r"\left(%s\right)" % term_tex + tex += term_tex + + return tex + + def _print_Cycle(self, expr): + from sympy.combinatorics.permutations import Permutation + if expr.size == 0: + return r"\left( \right)" + expr = Permutation(expr) + expr_perm = expr.cyclic_form + siz = expr.size + if expr.array_form[-1] == siz - 1: + expr_perm = expr_perm + [[siz - 1]] + term_tex = '' + for i in expr_perm: + term_tex += str(i).replace(',', r"\;") + term_tex = term_tex.replace('[', r"\left( ") + term_tex = term_tex.replace(']', r"\right)") + return term_tex + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation + from sympy.utilities.exceptions import sympy_deprecation_warning + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=8, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + return self._print_Cycle(expr) + + if expr.size == 0: + return r"\left( \right)" + + lower = [self._print(arg) for arg in expr.array_form] + upper = [self._print(arg) for arg in range(len(lower))] + + row1 = " & ".join(upper) + row2 = " & ".join(lower) + mat = r" \\ ".join((row1, row2)) + return r"\begin{pmatrix} %s \end{pmatrix}" % mat + + + def _print_AppliedPermutation(self, expr): + perm, var = expr.args + return r"\sigma_{%s}(%s)" % (self._print(perm), self._print(var)) + + def _print_Float(self, expr): + # Based off of that in StrPrinter + dps = prec_to_dps(expr._prec) + strip = False if self._settings['full_prec'] else True + low = self._settings["min"] if "min" in self._settings else None + high = self._settings["max"] if "max" in self._settings else None + str_real = mlib_to_str(expr._mpf_, dps, strip_zeros=strip, min_fixed=low, max_fixed=high) + + # Must always have a mul symbol (as 2.5 10^{20} just looks odd) + # thus we use the number separator + separator = self._settings['mul_symbol_latex_numbers'] + + if 'e' in str_real: + (mant, exp) = str_real.split('e') + + if exp[0] == '+': + exp = exp[1:] + if self._settings['decimal_separator'] == 'comma': + mant = mant.replace('.','{,}') + + return r"%s%s10^{%s}" % (mant, separator, exp) + elif str_real == "+inf": + return r"\infty" + elif str_real == "-inf": + return r"- \infty" + else: + if self._settings['decimal_separator'] == 'comma': + str_real = str_real.replace('.','{,}') + return str_real + + def _print_Cross(self, expr): + vec1 = expr._expr1 + vec2 = expr._expr2 + return r"%s \times %s" % (self.parenthesize(vec1, PRECEDENCE['Mul']), + self.parenthesize(vec2, PRECEDENCE['Mul'])) + + def _print_Curl(self, expr): + vec = expr._expr + return r"\nabla\times %s" % self.parenthesize(vec, PRECEDENCE['Mul']) + + def _print_Divergence(self, expr): + vec = expr._expr + return r"\nabla\cdot %s" % self.parenthesize(vec, PRECEDENCE['Mul']) + + def _print_Dot(self, expr): + vec1 = expr._expr1 + vec2 = expr._expr2 + return r"%s \cdot %s" % (self.parenthesize(vec1, PRECEDENCE['Mul']), + self.parenthesize(vec2, PRECEDENCE['Mul'])) + + def _print_Gradient(self, expr): + func = expr._expr + return r"\nabla %s" % self.parenthesize(func, PRECEDENCE['Mul']) + + def _print_Laplacian(self, expr): + func = expr._expr + return r"\Delta %s" % self.parenthesize(func, PRECEDENCE['Mul']) + + def _print_Mul(self, expr: Expr): + from sympy.simplify import fraction + separator: str = self._settings['mul_symbol_latex'] + numbersep: str = self._settings['mul_symbol_latex_numbers'] + + def convert(expr) -> str: + if not expr.is_Mul: + return str(self._print(expr)) + else: + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + args = list(expr.args) + + # If there are quantities or prefixes, append them at the back. + units, nonunits = sift(args, lambda x: (hasattr(x, "_scale_factor") or hasattr(x, "is_physical_constant")) or + (isinstance(x, Pow) and + hasattr(x.base, "is_physical_constant")), binary=True) + prefixes, units = sift(units, lambda x: hasattr(x, "_scale_factor"), binary=True) + return convert_args(nonunits + prefixes + units) + + def convert_args(args) -> str: + _tex = last_term_tex = "" + + for i, term in enumerate(args): + term_tex = self._print(term) + if not (hasattr(term, "_scale_factor") or hasattr(term, "is_physical_constant")): + if self._needs_mul_brackets(term, first=(i == 0), + last=(i == len(args) - 1)): + term_tex = r"\left(%s\right)" % term_tex + + if _between_two_numbers_p[0].search(last_term_tex) and \ + _between_two_numbers_p[1].match(term_tex): + # between two numbers + _tex += numbersep + elif _tex: + _tex += separator + elif _tex: + _tex += separator + + _tex += term_tex + last_term_tex = term_tex + return _tex + + # Check for unevaluated Mul. In this case we need to make sure the + # identities are visible, multiple Rational factors are not combined + # etc so we display in a straight-forward form that fully preserves all + # args and their order. + # XXX: _print_Pow calls this routine with instances of Pow... + if isinstance(expr, Mul): + args = expr.args + if args[0] is S.One or any(isinstance(arg, Number) for arg in args[1:]): + return convert_args(args) + + include_parens = False + if expr.could_extract_minus_sign(): + expr = -expr + tex = "- " + if expr.is_Add: + tex += "(" + include_parens = True + else: + tex = "" + + numer, denom = fraction(expr, exact=True) + + if denom is S.One and Pow(1, -1, evaluate=False) not in expr.args: + # use the original expression here, since fraction() may have + # altered it when producing numer and denom + tex += convert(expr) + + else: + snumer = convert(numer) + sdenom = convert(denom) + ldenom = len(sdenom.split()) + ratio = self._settings['long_frac_ratio'] + if self._settings['fold_short_frac'] and ldenom <= 2 and \ + "^" not in sdenom: + # handle short fractions + if self._needs_mul_brackets(numer, last=False): + tex += r"\left(%s\right) / %s" % (snumer, sdenom) + else: + tex += r"%s / %s" % (snumer, sdenom) + elif ratio is not None and \ + len(snumer.split()) > ratio*ldenom: + # handle long fractions + if self._needs_mul_brackets(numer, last=True): + tex += r"\frac{1}{%s}%s\left(%s\right)" \ + % (sdenom, separator, snumer) + elif numer.is_Mul: + # split a long numerator + a = S.One + b = S.One + for x in numer.args: + if self._needs_mul_brackets(x, last=False) or \ + len(convert(a*x).split()) > ratio*ldenom or \ + (b.is_commutative is x.is_commutative is False): + b *= x + else: + a *= x + if self._needs_mul_brackets(b, last=True): + tex += r"\frac{%s}{%s}%s\left(%s\right)" \ + % (convert(a), sdenom, separator, convert(b)) + else: + tex += r"\frac{%s}{%s}%s%s" \ + % (convert(a), sdenom, separator, convert(b)) + else: + tex += r"\frac{1}{%s}%s%s" % (sdenom, separator, snumer) + else: + tex += r"\frac{%s}{%s}" % (snumer, sdenom) + + if include_parens: + tex += ")" + return tex + + def _print_AlgebraicNumber(self, expr): + if expr.is_aliased: + return self._print(expr.as_poly().as_expr()) + else: + return self._print(expr.as_expr()) + + def _print_PrimeIdeal(self, expr): + p = self._print(expr.p) + if expr.is_inert: + return rf'\left({p}\right)' + alpha = self._print(expr.alpha.as_expr()) + return rf'\left({p}, {alpha}\right)' + + def _print_Pow(self, expr: Pow): + # Treat x**Rational(1,n) as special case + if expr.exp.is_Rational: + p: int = expr.exp.p # type: ignore + q: int = expr.exp.q # type: ignore + if abs(p) == 1 and q != 1 and self._settings['root_notation']: + base = self._print(expr.base) + if q == 2: + tex = r"\sqrt{%s}" % base + elif self._settings['itex']: + tex = r"\root{%d}{%s}" % (q, base) + else: + tex = r"\sqrt[%d]{%s}" % (q, base) + if expr.exp.is_negative: + return r"\frac{1}{%s}" % tex + else: + return tex + elif self._settings['fold_frac_powers'] and q != 1: + base = self.parenthesize(expr.base, PRECEDENCE['Pow']) + # issue #12886: add parentheses for superscripts raised to powers + if expr.base.is_Symbol: + base = self.parenthesize_super(base) + if expr.base.is_Function: + return self._print(expr.base, exp="%s/%s" % (p, q)) + return r"%s^{%s/%s}" % (base, p, q) + elif expr.exp.is_negative and expr.base.is_commutative: + # special case for 1^(-x), issue 9216 + if expr.base == 1: + return r"%s^{%s}" % (expr.base, expr.exp) + # special case for (1/x)^(-y) and (-1/-x)^(-y), issue 20252 + if expr.base.is_Rational: + base_p: int = expr.base.p # type: ignore + base_q: int = expr.base.q # type: ignore + if base_p * base_q == abs(base_q): + if expr.exp == -1: + return r"\frac{1}{\frac{%s}{%s}}" % (base_p, base_q) + else: + return r"\frac{1}{(\frac{%s}{%s})^{%s}}" % (base_p, base_q, abs(expr.exp)) + # things like 1/x + return self._print_Mul(expr) + if expr.base.is_Function: + return self._print(expr.base, exp=self._print(expr.exp)) + tex = r"%s^{%s}" + return self._helper_print_standard_power(expr, tex) + + def _helper_print_standard_power(self, expr, template: str) -> str: + exp = self._print(expr.exp) + # issue #12886: add parentheses around superscripts raised + # to powers + base = self.parenthesize(expr.base, PRECEDENCE['Pow']) + if expr.base.is_Symbol: + base = self.parenthesize_super(base) + elif expr.base.is_Float: + base = r"{%s}" % base + elif (isinstance(expr.base, Derivative) + and base.startswith(r'\left(') + and re.match(r'\\left\(\\d?d?dot', base) + and base.endswith(r'\right)')): + # don't use parentheses around dotted derivative + base = base[6: -7] # remove outermost added parens + return template % (base, exp) + + def _print_UnevaluatedExpr(self, expr): + return self._print(expr.args[0]) + + def _print_Sum(self, expr): + if len(expr.limits) == 1: + tex = r"\sum_{%s=%s}^{%s} " % \ + tuple([self._print(i) for i in expr.limits[0]]) + else: + def _format_ineq(l): + return r"%s \leq %s \leq %s" % \ + tuple([self._print(s) for s in (l[1], l[0], l[2])]) + + tex = r"\sum_{\substack{%s}} " % \ + str.join('\\\\', [_format_ineq(l) for l in expr.limits]) + + if isinstance(expr.function, Add): + tex += r"\left(%s\right)" % self._print(expr.function) + else: + tex += self._print(expr.function) + + return tex + + def _print_Product(self, expr): + if len(expr.limits) == 1: + tex = r"\prod_{%s=%s}^{%s} " % \ + tuple([self._print(i) for i in expr.limits[0]]) + else: + def _format_ineq(l): + return r"%s \leq %s \leq %s" % \ + tuple([self._print(s) for s in (l[1], l[0], l[2])]) + + tex = r"\prod_{\substack{%s}} " % \ + str.join('\\\\', [_format_ineq(l) for l in expr.limits]) + + if isinstance(expr.function, Add): + tex += r"\left(%s\right)" % self._print(expr.function) + else: + tex += self._print(expr.function) + + return tex + + def _print_BasisDependent(self, expr: 'BasisDependent'): + from sympy.vector import Vector + + o1: list[str] = [] + if expr == expr.zero: + return expr.zero._latex_form + if isinstance(expr, Vector): + items = expr.separate().items() + else: + items = [(0, expr)] + + for system, vect in items: + inneritems = list(vect.components.items()) + inneritems.sort(key=lambda x: x[0].__str__()) + for k, v in inneritems: + if v == 1: + o1.append(' + ' + k._latex_form) + elif v == -1: + o1.append(' - ' + k._latex_form) + else: + arg_str = r'\left(' + self._print(v) + r'\right)' + o1.append(' + ' + arg_str + k._latex_form) + + outstr = (''.join(o1)) + if outstr[1] != '-': + outstr = outstr[3:] + else: + outstr = outstr[1:] + return outstr + + def _print_Indexed(self, expr): + tex_base = self._print(expr.base) + tex = '{'+tex_base+'}'+'_{%s}' % ','.join( + map(self._print, expr.indices)) + return tex + + def _print_IndexedBase(self, expr): + return self._print(expr.label) + + def _print_Idx(self, expr): + label = self._print(expr.label) + if expr.upper is not None: + upper = self._print(expr.upper) + if expr.lower is not None: + lower = self._print(expr.lower) + else: + lower = self._print(S.Zero) + interval = '{lower}\\mathrel{{..}}\\nobreak {upper}'.format( + lower = lower, upper = upper) + return '{{{label}}}_{{{interval}}}'.format( + label = label, interval = interval) + #if no bounds are defined this just prints the label + return label + + def _print_Derivative(self, expr): + if requires_partial(expr.expr): + diff_symbol = r'\partial' + else: + diff_symbol = self._settings["diff_operator_latex"] + + tex = "" + dim = 0 + for x, num in reversed(expr.variable_count): + dim += num + if num == 1: + tex += r"%s %s" % (diff_symbol, self._print(x)) + else: + tex += r"%s %s^{%s}" % (diff_symbol, + self.parenthesize_super(self._print(x)), + self._print(num)) + + if dim == 1: + tex = r"\frac{%s}{%s}" % (diff_symbol, tex) + else: + tex = r"\frac{%s^{%s}}{%s}" % (diff_symbol, self._print(dim), tex) + + if any(i.could_extract_minus_sign() for i in expr.args): + return r"%s %s" % (tex, self.parenthesize(expr.expr, + PRECEDENCE["Mul"], + is_neg=True, + strict=True)) + + return r"%s %s" % (tex, self.parenthesize(expr.expr, + PRECEDENCE["Mul"], + is_neg=False, + strict=True)) + + def _print_Subs(self, subs): + expr, old, new = subs.args + latex_expr = self._print(expr) + latex_old = (self._print(e) for e in old) + latex_new = (self._print(e) for e in new) + latex_subs = r'\\ '.join( + e[0] + '=' + e[1] for e in zip(latex_old, latex_new)) + return r'\left. %s \right|_{\substack{ %s }}' % (latex_expr, + latex_subs) + + def _print_Integral(self, expr): + tex, symbols = "", [] + diff_symbol = self._settings["diff_operator_latex"] + + # Only up to \iiiint exists + if len(expr.limits) <= 4 and all(len(lim) == 1 for lim in expr.limits): + # Use len(expr.limits)-1 so that syntax highlighters don't think + # \" is an escaped quote + tex = r"\i" + "i"*(len(expr.limits) - 1) + "nt" + symbols = [r"\, %s%s" % (diff_symbol, self._print(symbol[0])) + for symbol in expr.limits] + + else: + for lim in reversed(expr.limits): + symbol = lim[0] + tex += r"\int" + + if len(lim) > 1: + if self._settings['mode'] != 'inline' \ + and not self._settings['itex']: + tex += r"\limits" + + if len(lim) == 3: + tex += "_{%s}^{%s}" % (self._print(lim[1]), + self._print(lim[2])) + if len(lim) == 2: + tex += "^{%s}" % (self._print(lim[1])) + + symbols.insert(0, r"\, %s%s" % (diff_symbol, self._print(symbol))) + + return r"%s %s%s" % (tex, self.parenthesize(expr.function, + PRECEDENCE["Mul"], + is_neg=any(i.could_extract_minus_sign() for i in expr.args), + strict=True), + "".join(symbols)) + + def _print_Limit(self, expr): + e, z, z0, dir = expr.args + + tex = r"\lim_{%s \to " % self._print(z) + if str(dir) == '+-' or z0 in (S.Infinity, S.NegativeInfinity): + tex += r"%s}" % self._print(z0) + else: + tex += r"%s^%s}" % (self._print(z0), self._print(dir)) + + if isinstance(e, AssocOp): + return r"%s\left(%s\right)" % (tex, self._print(e)) + else: + return r"%s %s" % (tex, self._print(e)) + + def _hprint_Function(self, func: str) -> str: + r''' + Logic to decide how to render a function to latex + - if it is a recognized latex name, use the appropriate latex command + - if it is a single letter, excluding sub- and superscripts, just use that letter + - if it is a longer name, then put \operatorname{} around it and be + mindful of undercores in the name + ''' + func = self._deal_with_super_sub(func) + superscriptidx = func.find("^") + subscriptidx = func.find("_") + if func in accepted_latex_functions: + name = r"\%s" % func + elif len(func) == 1 or func.startswith('\\') or subscriptidx == 1 or superscriptidx == 1: + name = func + else: + if superscriptidx > 0 and subscriptidx > 0: + name = r"\operatorname{%s}%s" %( + func[:min(subscriptidx,superscriptidx)], + func[min(subscriptidx,superscriptidx):]) + elif superscriptidx > 0: + name = r"\operatorname{%s}%s" %( + func[:superscriptidx], + func[superscriptidx:]) + elif subscriptidx > 0: + name = r"\operatorname{%s}%s" %( + func[:subscriptidx], + func[subscriptidx:]) + else: + name = r"\operatorname{%s}" % func + return name + + def _print_Function(self, expr: Function, exp=None) -> str: + r''' + Render functions to LaTeX, handling functions that LaTeX knows about + e.g., sin, cos, ... by using the proper LaTeX command (\sin, \cos, ...). + For single-letter function names, render them as regular LaTeX math + symbols. For multi-letter function names that LaTeX does not know + about, (e.g., Li, sech) use \operatorname{} so that the function name + is rendered in Roman font and LaTeX handles spacing properly. + + expr is the expression involving the function + exp is an exponent + ''' + func = expr.func.__name__ + if hasattr(self, '_print_' + func) and \ + not isinstance(expr, AppliedUndef): + return getattr(self, '_print_' + func)(expr, exp) + else: + args = [str(self._print(arg)) for arg in expr.args] + # How inverse trig functions should be displayed, formats are: + # abbreviated: asin, full: arcsin, power: sin^-1 + inv_trig_style = self._settings['inv_trig_style'] + # If we are dealing with a power-style inverse trig function + inv_trig_power_case = False + # If it is applicable to fold the argument brackets + can_fold_brackets = self._settings['fold_func_brackets'] and \ + len(args) == 1 and \ + not self._needs_function_brackets(expr.args[0]) + + inv_trig_table = [ + "asin", "acos", "atan", + "acsc", "asec", "acot", + "asinh", "acosh", "atanh", + "acsch", "asech", "acoth", + ] + + # If the function is an inverse trig function, handle the style + if func in inv_trig_table: + if inv_trig_style == "abbreviated": + pass + elif inv_trig_style == "full": + func = ("ar" if func[-1] == "h" else "arc") + func[1:] + elif inv_trig_style == "power": + func = func[1:] + inv_trig_power_case = True + + # Can never fold brackets if we're raised to a power + if exp is not None: + can_fold_brackets = False + + if inv_trig_power_case: + if func in accepted_latex_functions: + name = r"\%s^{-1}" % func + else: + name = r"\operatorname{%s}^{-1}" % func + elif exp is not None: + func_tex = self._hprint_Function(func) + func_tex = self.parenthesize_super(func_tex) + name = r'%s^{%s}' % (func_tex, exp) + else: + name = self._hprint_Function(func) + + if can_fold_brackets: + if func in accepted_latex_functions: + # Wrap argument safely to avoid parse-time conflicts + # with the function name itself + name += r" {%s}" + else: + name += r"%s" + else: + name += r"{\left(%s \right)}" + + if inv_trig_power_case and exp is not None: + name += r"^{%s}" % exp + + return name % ",".join(args) + + def _print_UndefinedFunction(self, expr): + return self._hprint_Function(str(expr)) + + def _print_ElementwiseApplyFunction(self, expr): + return r"{%s}_{\circ}\left({%s}\right)" % ( + self._print(expr.function), + self._print(expr.expr), + ) + + @property + def _special_function_classes(self): + from sympy.functions.special.tensor_functions import KroneckerDelta + from sympy.functions.special.gamma_functions import gamma, lowergamma + from sympy.functions.special.beta_functions import beta + from sympy.functions.special.delta_functions import DiracDelta + from sympy.functions.special.error_functions import Chi + return {KroneckerDelta: r'\delta', + gamma: r'\Gamma', + lowergamma: r'\gamma', + beta: r'\operatorname{B}', + DiracDelta: r'\delta', + Chi: r'\operatorname{Chi}'} + + def _print_FunctionClass(self, expr): + for cls in self._special_function_classes: + if issubclass(expr, cls) and expr.__name__ == cls.__name__: + return self._special_function_classes[cls] + return self._hprint_Function(str(expr)) + + def _print_Lambda(self, expr): + symbols, expr = expr.args + + if len(symbols) == 1: + symbols = self._print(symbols[0]) + else: + symbols = self._print(tuple(symbols)) + + tex = r"\left( %s \mapsto %s \right)" % (symbols, self._print(expr)) + + return tex + + def _print_IdentityFunction(self, expr): + return r"\left( x \mapsto x \right)" + + def _hprint_variadic_function(self, expr, exp=None) -> str: + args = sorted(expr.args, key=default_sort_key) + texargs = [r"%s" % self._print(symbol) for symbol in args] + tex = r"\%s\left(%s\right)" % (str(expr.func).lower(), + ", ".join(texargs)) + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + _print_Min = _print_Max = _hprint_variadic_function + + def _print_floor(self, expr, exp=None): + tex = r"\left\lfloor{%s}\right\rfloor" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_ceiling(self, expr, exp=None): + tex = r"\left\lceil{%s}\right\rceil" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_log(self, expr, exp=None): + if not self._settings["ln_notation"]: + tex = r"\log{\left(%s \right)}" % self._print(expr.args[0]) + else: + tex = r"\ln{\left(%s \right)}" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_Abs(self, expr, exp=None): + tex = r"\left|{%s}\right|" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_re(self, expr, exp=None): + if self._settings['gothic_re_im']: + tex = r"\Re{%s}" % self.parenthesize(expr.args[0], PRECEDENCE['Atom']) + else: + tex = r"\operatorname{{re}}{{{}}}".format(self.parenthesize(expr.args[0], PRECEDENCE['Atom'])) + + return self._do_exponent(tex, exp) + + def _print_im(self, expr, exp=None): + if self._settings['gothic_re_im']: + tex = r"\Im{%s}" % self.parenthesize(expr.args[0], PRECEDENCE['Atom']) + else: + tex = r"\operatorname{{im}}{{{}}}".format(self.parenthesize(expr.args[0], PRECEDENCE['Atom'])) + + return self._do_exponent(tex, exp) + + def _print_Not(self, e): + from sympy.logic.boolalg import (Equivalent, Implies) + if isinstance(e.args[0], Equivalent): + return self._print_Equivalent(e.args[0], r"\not\Leftrightarrow") + if isinstance(e.args[0], Implies): + return self._print_Implies(e.args[0], r"\not\Rightarrow") + if (e.args[0].is_Boolean): + return r"\neg \left(%s\right)" % self._print(e.args[0]) + else: + return r"\neg %s" % self._print(e.args[0]) + + def _print_LogOp(self, args, char): + arg = args[0] + if arg.is_Boolean and not arg.is_Not: + tex = r"\left(%s\right)" % self._print(arg) + else: + tex = r"%s" % self._print(arg) + + for arg in args[1:]: + if arg.is_Boolean and not arg.is_Not: + tex += r" %s \left(%s\right)" % (char, self._print(arg)) + else: + tex += r" %s %s" % (char, self._print(arg)) + + return tex + + def _print_And(self, e): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, r"\wedge") + + def _print_Or(self, e): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, r"\vee") + + def _print_Xor(self, e): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, r"\veebar") + + def _print_Implies(self, e, altchar=None): + return self._print_LogOp(e.args, altchar or r"\Rightarrow") + + def _print_Equivalent(self, e, altchar=None): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, altchar or r"\Leftrightarrow") + + def _print_conjugate(self, expr, exp=None): + tex = r"\overline{%s}" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_polar_lift(self, expr, exp=None): + func = r"\operatorname{polar\_lift}" + arg = r"{\left(%s \right)}" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}%s" % (func, exp, arg) + else: + return r"%s%s" % (func, arg) + + def _print_ExpBase(self, expr, exp=None): + # TODO should exp_polar be printed differently? + # what about exp_polar(0), exp_polar(1)? + tex = r"e^{%s}" % self._print(expr.args[0]) + return self._do_exponent(tex, exp) + + def _print_Exp1(self, expr, exp=None): + return "e" + + def _print_elliptic_k(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"K^{%s}%s" % (exp, tex) + else: + return r"K%s" % tex + + def _print_elliptic_f(self, expr, exp=None): + tex = r"\left(%s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1])) + if exp is not None: + return r"F^{%s}%s" % (exp, tex) + else: + return r"F%s" % tex + + def _print_elliptic_e(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"\left(%s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1])) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"E^{%s}%s" % (exp, tex) + else: + return r"E%s" % tex + + def _print_elliptic_pi(self, expr, exp=None): + if len(expr.args) == 3: + tex = r"\left(%s; %s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1]), + self._print(expr.args[2])) + else: + tex = r"\left(%s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1])) + if exp is not None: + return r"\Pi^{%s}%s" % (exp, tex) + else: + return r"\Pi%s" % tex + + def _print_beta(self, expr, exp=None): + x = expr.args[0] + # Deal with unevaluated single argument beta + y = expr.args[0] if len(expr.args) == 1 else expr.args[1] + tex = rf"\left({x}, {y}\right)" + + if exp is not None: + return r"\operatorname{B}^{%s}%s" % (exp, tex) + else: + return r"\operatorname{B}%s" % tex + + def _print_betainc(self, expr, exp=None, operator='B'): + largs = [self._print(arg) for arg in expr.args] + tex = r"\left(%s, %s\right)" % (largs[0], largs[1]) + + if exp is not None: + return r"\operatorname{%s}_{(%s, %s)}^{%s}%s" % (operator, largs[2], largs[3], exp, tex) + else: + return r"\operatorname{%s}_{(%s, %s)}%s" % (operator, largs[2], largs[3], tex) + + def _print_betainc_regularized(self, expr, exp=None): + return self._print_betainc(expr, exp, operator='I') + + def _print_uppergamma(self, expr, exp=None): + tex = r"\left(%s, %s\right)" % (self._print(expr.args[0]), + self._print(expr.args[1])) + + if exp is not None: + return r"\Gamma^{%s}%s" % (exp, tex) + else: + return r"\Gamma%s" % tex + + def _print_lowergamma(self, expr, exp=None): + tex = r"\left(%s, %s\right)" % (self._print(expr.args[0]), + self._print(expr.args[1])) + + if exp is not None: + return r"\gamma^{%s}%s" % (exp, tex) + else: + return r"\gamma%s" % tex + + def _hprint_one_arg_func(self, expr, exp=None) -> str: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}%s" % (self._print(expr.func), exp, tex) + else: + return r"%s%s" % (self._print(expr.func), tex) + + _print_gamma = _hprint_one_arg_func + + def _print_Chi(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"\operatorname{Chi}^{%s}%s" % (exp, tex) + else: + return r"\operatorname{Chi}%s" % tex + + def _print_expint(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[1]) + nu = self._print(expr.args[0]) + + if exp is not None: + return r"\operatorname{E}_{%s}^{%s}%s" % (nu, exp, tex) + else: + return r"\operatorname{E}_{%s}%s" % (nu, tex) + + def _print_fresnels(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"S^{%s}%s" % (exp, tex) + else: + return r"S%s" % tex + + def _print_fresnelc(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"C^{%s}%s" % (exp, tex) + else: + return r"C%s" % tex + + def _print_subfactorial(self, expr, exp=None): + tex = r"!%s" % self.parenthesize(expr.args[0], PRECEDENCE["Func"]) + + if exp is not None: + return r"\left(%s\right)^{%s}" % (tex, exp) + else: + return tex + + def _print_factorial(self, expr, exp=None): + tex = r"%s!" % self.parenthesize(expr.args[0], PRECEDENCE["Func"]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_factorial2(self, expr, exp=None): + tex = r"%s!!" % self.parenthesize(expr.args[0], PRECEDENCE["Func"]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_binomial(self, expr, exp=None): + tex = r"{\binom{%s}{%s}}" % (self._print(expr.args[0]), + self._print(expr.args[1])) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_RisingFactorial(self, expr, exp=None): + n, k = expr.args + base = r"%s" % self.parenthesize(n, PRECEDENCE['Func']) + + tex = r"{%s}^{\left(%s\right)}" % (base, self._print(k)) + + return self._do_exponent(tex, exp) + + def _print_FallingFactorial(self, expr, exp=None): + n, k = expr.args + sub = r"%s" % self.parenthesize(k, PRECEDENCE['Func']) + + tex = r"{\left(%s\right)}_{%s}" % (self._print(n), sub) + + return self._do_exponent(tex, exp) + + def _hprint_BesselBase(self, expr, exp, sym: str) -> str: + tex = r"%s" % (sym) + + need_exp = False + if exp is not None: + if tex.find('^') == -1: + tex = r"%s^{%s}" % (tex, exp) + else: + need_exp = True + + tex = r"%s_{%s}\left(%s\right)" % (tex, self._print(expr.order), + self._print(expr.argument)) + + if need_exp: + tex = self._do_exponent(tex, exp) + return tex + + def _hprint_vec(self, vec) -> str: + if not vec: + return "" + s = "" + for i in vec[:-1]: + s += "%s, " % self._print(i) + s += self._print(vec[-1]) + return s + + def _print_besselj(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'J') + + def _print_besseli(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'I') + + def _print_besselk(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'K') + + def _print_bessely(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'Y') + + def _print_yn(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'y') + + def _print_jn(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'j') + + def _print_hankel1(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'H^{(1)}') + + def _print_hankel2(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'H^{(2)}') + + def _print_hn1(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'h^{(1)}') + + def _print_hn2(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'h^{(2)}') + + def _hprint_airy(self, expr, exp=None, notation="") -> str: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}%s" % (notation, exp, tex) + else: + return r"%s%s" % (notation, tex) + + def _hprint_airy_prime(self, expr, exp=None, notation="") -> str: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"{%s^\prime}^{%s}%s" % (notation, exp, tex) + else: + return r"%s^\prime%s" % (notation, tex) + + def _print_airyai(self, expr, exp=None): + return self._hprint_airy(expr, exp, 'Ai') + + def _print_airybi(self, expr, exp=None): + return self._hprint_airy(expr, exp, 'Bi') + + def _print_airyaiprime(self, expr, exp=None): + return self._hprint_airy_prime(expr, exp, 'Ai') + + def _print_airybiprime(self, expr, exp=None): + return self._hprint_airy_prime(expr, exp, 'Bi') + + def _print_hyper(self, expr, exp=None): + tex = r"{{}_{%s}F_{%s}\left(\begin{matrix} %s \\ %s \end{matrix}" \ + r"\middle| {%s} \right)}" % \ + (self._print(len(expr.ap)), self._print(len(expr.bq)), + self._hprint_vec(expr.ap), self._hprint_vec(expr.bq), + self._print(expr.argument)) + + if exp is not None: + tex = r"{%s}^{%s}" % (tex, exp) + return tex + + def _print_meijerg(self, expr, exp=None): + tex = r"{G_{%s, %s}^{%s, %s}\left(\begin{matrix} %s & %s \\" \ + r"%s & %s \end{matrix} \middle| {%s} \right)}" % \ + (self._print(len(expr.ap)), self._print(len(expr.bq)), + self._print(len(expr.bm)), self._print(len(expr.an)), + self._hprint_vec(expr.an), self._hprint_vec(expr.aother), + self._hprint_vec(expr.bm), self._hprint_vec(expr.bother), + self._print(expr.argument)) + + if exp is not None: + tex = r"{%s}^{%s}" % (tex, exp) + return tex + + def _print_dirichlet_eta(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\eta^{%s}%s" % (exp, tex) + return r"\eta%s" % tex + + def _print_zeta(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"\left(%s, %s\right)" % tuple(map(self._print, expr.args)) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\zeta^{%s}%s" % (exp, tex) + return r"\zeta%s" % tex + + def _print_stieltjes(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"_{%s}\left(%s\right)" % tuple(map(self._print, expr.args)) + else: + tex = r"_{%s}" % self._print(expr.args[0]) + if exp is not None: + return r"\gamma%s^{%s}" % (tex, exp) + return r"\gamma%s" % tex + + def _print_lerchphi(self, expr, exp=None): + tex = r"\left(%s, %s, %s\right)" % tuple(map(self._print, expr.args)) + if exp is None: + return r"\Phi%s" % tex + return r"\Phi^{%s}%s" % (exp, tex) + + def _print_polylog(self, expr, exp=None): + s, z = map(self._print, expr.args) + tex = r"\left(%s\right)" % z + if exp is None: + return r"\operatorname{Li}_{%s}%s" % (s, tex) + return r"\operatorname{Li}_{%s}^{%s}%s" % (s, exp, tex) + + def _print_jacobi(self, expr, exp=None): + n, a, b, x = map(self._print, expr.args) + tex = r"P_{%s}^{\left(%s,%s\right)}\left(%s\right)" % (n, a, b, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_gegenbauer(self, expr, exp=None): + n, a, x = map(self._print, expr.args) + tex = r"C_{%s}^{\left(%s\right)}\left(%s\right)" % (n, a, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_chebyshevt(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"T_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_chebyshevu(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"U_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_legendre(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"P_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_assoc_legendre(self, expr, exp=None): + n, a, x = map(self._print, expr.args) + tex = r"P_{%s}^{\left(%s\right)}\left(%s\right)" % (n, a, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_hermite(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"H_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_laguerre(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"L_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_assoc_laguerre(self, expr, exp=None): + n, a, x = map(self._print, expr.args) + tex = r"L_{%s}^{\left(%s\right)}\left(%s\right)" % (n, a, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_Ynm(self, expr, exp=None): + n, m, theta, phi = map(self._print, expr.args) + tex = r"Y_{%s}^{%s}\left(%s,%s\right)" % (n, m, theta, phi) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_Znm(self, expr, exp=None): + n, m, theta, phi = map(self._print, expr.args) + tex = r"Z_{%s}^{%s}\left(%s,%s\right)" % (n, m, theta, phi) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def __print_mathieu_functions(self, character, args, prime=False, exp=None): + a, q, z = map(self._print, args) + sup = r"^{\prime}" if prime else "" + exp = "" if not exp else "^{%s}" % exp + return r"%s%s\left(%s, %s, %s\right)%s" % (character, sup, a, q, z, exp) + + def _print_mathieuc(self, expr, exp=None): + return self.__print_mathieu_functions("C", expr.args, exp=exp) + + def _print_mathieus(self, expr, exp=None): + return self.__print_mathieu_functions("S", expr.args, exp=exp) + + def _print_mathieucprime(self, expr, exp=None): + return self.__print_mathieu_functions("C", expr.args, prime=True, exp=exp) + + def _print_mathieusprime(self, expr, exp=None): + return self.__print_mathieu_functions("S", expr.args, prime=True, exp=exp) + + def _print_Rational(self, expr): + if expr.q != 1: + sign = "" + p = expr.p + if expr.p < 0: + sign = "- " + p = -p + if self._settings['fold_short_frac']: + return r"%s%d / %d" % (sign, p, expr.q) + return r"%s\frac{%d}{%d}" % (sign, p, expr.q) + else: + return self._print(expr.p) + + def _print_Order(self, expr): + s = self._print(expr.expr) + if expr.point and any(p != S.Zero for p in expr.point) or \ + len(expr.variables) > 1: + s += '; ' + if len(expr.variables) > 1: + s += self._print(expr.variables) + elif expr.variables: + s += self._print(expr.variables[0]) + s += r'\rightarrow ' + if len(expr.point) > 1: + s += self._print(expr.point) + else: + s += self._print(expr.point[0]) + return r"O\left(%s\right)" % s + + def _print_Symbol(self, expr: Symbol, style='plain'): + name: str = self._settings['symbol_names'].get(expr) + if name is not None: + return name + + return self._deal_with_super_sub(expr.name, style=style) + + _print_RandomSymbol = _print_Symbol + + def _split_super_sub(self, name: str) -> tuple[str, list[str], list[str]]: + if name is None or '{' in name: + return (name, [], []) + elif self._settings["disable_split_super_sub"]: + name, supers, subs = (name.replace('_', '\\_').replace('^', '\\^'), [], []) + else: + name, supers, subs = split_super_sub(name) + name = translate(name) + supers = [translate(sup) for sup in supers] + subs = [translate(sub) for sub in subs] + return (name, supers, subs) + + def _deal_with_super_sub(self, string: str, style='plain') -> str: + name, supers, subs = self._split_super_sub(string) + + # apply the style only to the name + if style == 'bold': + name = "\\mathbf{{{}}}".format(name) + + # glue all items together: + if supers: + name += "^{%s}" % " ".join(supers) + if subs: + name += "_{%s}" % " ".join(subs) + + return name + + def _print_Relational(self, expr): + if self._settings['itex']: + gt = r"\gt" + lt = r"\lt" + else: + gt = ">" + lt = "<" + + charmap = { + "==": "=", + ">": gt, + "<": lt, + ">=": r"\geq", + "<=": r"\leq", + "!=": r"\neq", + } + + return "%s %s %s" % (self._print(expr.lhs), + charmap[expr.rel_op], self._print(expr.rhs)) + + def _print_Piecewise(self, expr): + ecpairs = [r"%s & \text{for}\: %s" % (self._print(e), self._print(c)) + for e, c in expr.args[:-1]] + if expr.args[-1].cond == true: + ecpairs.append(r"%s & \text{otherwise}" % + self._print(expr.args[-1].expr)) + else: + ecpairs.append(r"%s & \text{for}\: %s" % + (self._print(expr.args[-1].expr), + self._print(expr.args[-1].cond))) + tex = r"\begin{cases} %s \end{cases}" + return tex % r" \\".join(ecpairs) + + def _print_matrix_contents(self, expr): + lines = [] + + for line in range(expr.rows): # horrible, should be 'rows' + lines.append(" & ".join([self._print(i) for i in expr[line, :]])) + + mat_str = self._settings['mat_str'] + if mat_str is None: + if self._settings['mode'] == 'inline': + mat_str = 'smallmatrix' + else: + if (expr.cols <= 10) is True: + mat_str = 'matrix' + else: + mat_str = 'array' + + out_str = r'\begin{%MATSTR%}%s\end{%MATSTR%}' + out_str = out_str.replace('%MATSTR%', mat_str) + if mat_str == 'array': + out_str = out_str.replace('%s', '{' + 'c'*expr.cols + '}%s') + return out_str % r"\\".join(lines) + + def _print_MatrixBase(self, expr): + out_str = self._print_matrix_contents(expr) + if self._settings['mat_delim']: + left_delim = self._settings['mat_delim'] + right_delim = self._delim_dict[left_delim] + out_str = r'\left' + left_delim + out_str + \ + r'\right' + right_delim + return out_str + + def _print_MatrixElement(self, expr): + matrix_part = self.parenthesize(expr.parent, PRECEDENCE['Atom'], strict=True) + index_part = f"{self._print(expr.i)},{self._print(expr.j)}" + return f"{{{matrix_part}}}_{{{index_part}}}" + + def _print_MatrixSlice(self, expr): + def latexslice(x, dim): + x = list(x) + if x[2] == 1: + del x[2] + if x[0] == 0: + x[0] = None + if x[1] == dim: + x[1] = None + return ':'.join(self._print(xi) if xi is not None else '' for xi in x) + return (self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) + r'\left[' + + latexslice(expr.rowslice, expr.parent.rows) + ', ' + + latexslice(expr.colslice, expr.parent.cols) + r'\right]') + + def _print_BlockMatrix(self, expr): + return self._print(expr.blocks) + + def _print_Transpose(self, expr): + mat = expr.arg + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + return r"\left(%s\right)^{T}" % self._print(mat) + else: + s = self.parenthesize(mat, precedence_traditional(expr), True) + if '^' in s: + return r"\left(%s\right)^{T}" % s + else: + return "%s^{T}" % s + + def _print_Trace(self, expr): + mat = expr.arg + return r"\operatorname{tr}\left(%s \right)" % self._print(mat) + + def _print_Adjoint(self, expr): + style_to_latex = { + "dagger" : r"\dagger", + "star" : r"\ast", + "hermitian": r"\mathsf{H}" + } + adjoint_style = style_to_latex.get(self._settings["adjoint_style"], r"\dagger") + mat = expr.arg + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + return r"\left(%s\right)^{%s}" % (self._print(mat), adjoint_style) + else: + s = self.parenthesize(mat, precedence_traditional(expr), True) + if '^' in s: + return r"\left(%s\right)^{%s}" % (s, adjoint_style) + else: + return r"%s^{%s}" % (s, adjoint_style) + + def _print_MatMul(self, expr): + from sympy import MatMul + + # Parenthesize nested MatMul but not other types of Mul objects: + parens = lambda x: self._print(x) if isinstance(x, Mul) and not isinstance(x, MatMul) else \ + self.parenthesize(x, precedence_traditional(expr), False) + + args = list(expr.args) + if expr.could_extract_minus_sign(): + if args[0] == -1: + args = args[1:] + else: + args[0] = -args[0] + return '- ' + ' '.join(map(parens, args)) + else: + return ' '.join(map(parens, args)) + + def _print_DotProduct(self, expr): + level = precedence_traditional(expr) + left, right = expr.args + return rf"{self.parenthesize(left, level)} \cdot {self.parenthesize(right, level)}" + + def _print_Determinant(self, expr): + mat = expr.arg + if mat.is_MatrixExpr: + from sympy.matrices.expressions.blockmatrix import BlockMatrix + if isinstance(mat, BlockMatrix): + return r"\left|{%s}\right|" % self._print_matrix_contents(mat.blocks) + return r"\left|{%s}\right|" % self._print(mat) + return r"\left|{%s}\right|" % self._print_matrix_contents(mat) + + + def _print_Mod(self, expr, exp=None): + if exp is not None: + return r'\left(%s \bmod %s\right)^{%s}' % \ + (self.parenthesize(expr.args[0], PRECEDENCE['Mul'], + strict=True), + self.parenthesize(expr.args[1], PRECEDENCE['Mul'], + strict=True), + exp) + return r'%s \bmod %s' % (self.parenthesize(expr.args[0], + PRECEDENCE['Mul'], + strict=True), + self.parenthesize(expr.args[1], + PRECEDENCE['Mul'], + strict=True)) + + def _print_HadamardProduct(self, expr): + args = expr.args + prec = PRECEDENCE['Pow'] + parens = self.parenthesize + + return r' \circ '.join( + (parens(arg, prec, strict=True) for arg in args)) + + def _print_HadamardPower(self, expr): + if precedence_traditional(expr.exp) < PRECEDENCE["Mul"]: + template = r"%s^{\circ \left({%s}\right)}" + else: + template = r"%s^{\circ {%s}}" + return self._helper_print_standard_power(expr, template) + + def _print_KroneckerProduct(self, expr): + args = expr.args + prec = PRECEDENCE['Pow'] + parens = self.parenthesize + + return r' \otimes '.join( + (parens(arg, prec, strict=True) for arg in args)) + + def _print_MatPow(self, expr): + base, exp = expr.base, expr.exp + from sympy.matrices import MatrixSymbol + if not isinstance(base, MatrixSymbol) and base.is_MatrixExpr: + return "\\left(%s\\right)^{%s}" % (self._print(base), + self._print(exp)) + else: + base_str = self._print(base) + if '^' in base_str: + return r"\left(%s\right)^{%s}" % (base_str, self._print(exp)) + else: + return "%s^{%s}" % (base_str, self._print(exp)) + + def _print_MatrixSymbol(self, expr): + return self._print_Symbol(expr, style=self._settings[ + 'mat_symbol_style']) + + def _print_ZeroMatrix(self, Z): + return "0" if self._settings[ + 'mat_symbol_style'] == 'plain' else r"\mathbf{0}" + + def _print_OneMatrix(self, O): + return "1" if self._settings[ + 'mat_symbol_style'] == 'plain' else r"\mathbf{1}" + + def _print_Identity(self, I): + return r"\mathbb{I}" if self._settings[ + 'mat_symbol_style'] == 'plain' else r"\mathbf{I}" + + def _print_PermutationMatrix(self, P): + perm_str = self._print(P.args[0]) + return "P_{%s}" % perm_str + + def _print_NDimArray(self, expr: NDimArray): + + if expr.rank() == 0: + return self._print(expr[()]) + + mat_str = self._settings['mat_str'] + if mat_str is None: + if self._settings['mode'] == 'inline': + mat_str = 'smallmatrix' + else: + if (expr.rank() == 0) or (expr.shape[-1] <= 10): + mat_str = 'matrix' + else: + mat_str = 'array' + block_str = r'\begin{%MATSTR%}%s\end{%MATSTR%}' + block_str = block_str.replace('%MATSTR%', mat_str) + if mat_str == 'array': + block_str = block_str.replace('%s', '{' + 'c'*expr.shape[0] + '}%s') + + if self._settings['mat_delim']: + left_delim: str = self._settings['mat_delim'] + right_delim = self._delim_dict[left_delim] + block_str = r'\left' + left_delim + block_str + \ + r'\right' + right_delim + + if expr.rank() == 0: + return block_str % "" + + level_str: list[list[str]] = [[] for i in range(expr.rank() + 1)] + shape_ranges = [list(range(i)) for i in expr.shape] + for outer_i in itertools.product(*shape_ranges): + level_str[-1].append(self._print(expr[outer_i])) + even = True + for back_outer_i in range(expr.rank()-1, -1, -1): + if len(level_str[back_outer_i+1]) < expr.shape[back_outer_i]: + break + if even: + level_str[back_outer_i].append( + r" & ".join(level_str[back_outer_i+1])) + else: + level_str[back_outer_i].append( + block_str % (r"\\".join(level_str[back_outer_i+1]))) + if len(level_str[back_outer_i+1]) == 1: + level_str[back_outer_i][-1] = r"\left[" + \ + level_str[back_outer_i][-1] + r"\right]" + even = not even + level_str[back_outer_i+1] = [] + + out_str = level_str[0][0] + + if expr.rank() % 2 == 1: + out_str = block_str % out_str + + return out_str + + def _printer_tensor_indices(self, name, indices, index_map: dict): + out_str = self._print(name) + last_valence = None + prev_map = None + for index in indices: + new_valence = index.is_up + if ((index in index_map) or prev_map) and \ + last_valence == new_valence: + out_str += "," + if last_valence != new_valence: + if last_valence is not None: + out_str += "}" + if index.is_up: + out_str += "{}^{" + else: + out_str += "{}_{" + out_str += self._print(index.args[0]) + if index in index_map: + out_str += "=" + out_str += self._print(index_map[index]) + prev_map = True + else: + prev_map = False + last_valence = new_valence + if last_valence is not None: + out_str += "}" + return out_str + + def _print_Tensor(self, expr): + name = expr.args[0].args[0] + indices = expr.get_indices() + return self._printer_tensor_indices(name, indices, {}) + + def _print_TensorElement(self, expr): + name = expr.expr.args[0].args[0] + indices = expr.expr.get_indices() + index_map = expr.index_map + return self._printer_tensor_indices(name, indices, index_map) + + def _print_TensMul(self, expr): + # prints expressions like "A(a)", "3*A(a)", "(1+x)*A(a)" + sign, args = expr._get_args_for_traditional_printer() + return sign + "".join( + [self.parenthesize(arg, precedence(expr)) for arg in args] + ) + + def _print_TensAdd(self, expr): + a = [] + args = expr.args + for x in args: + a.append(self.parenthesize(x, precedence(expr))) + a.sort() + s = ' + '.join(a) + s = s.replace('+ -', '- ') + return s + + def _print_TensorIndex(self, expr): + return "{}%s{%s}" % ( + "^" if expr.is_up else "_", + self._print(expr.args[0]) + ) + + def _print_PartialDerivative(self, expr): + if len(expr.variables) == 1: + return r"\frac{\partial}{\partial {%s}}{%s}" % ( + self._print(expr.variables[0]), + self.parenthesize(expr.expr, PRECEDENCE["Mul"], False) + ) + else: + return r"\frac{\partial^{%s}}{%s}{%s}" % ( + len(expr.variables), + " ".join([r"\partial {%s}" % self._print(i) for i in expr.variables]), + self.parenthesize(expr.expr, PRECEDENCE["Mul"], False) + ) + + def _print_ArraySymbol(self, expr): + return self._print(expr.name) + + def _print_ArrayElement(self, expr): + return "{{%s}_{%s}}" % ( + self.parenthesize(expr.name, PRECEDENCE["Func"], True), + ", ".join([f"{self._print(i)}" for i in expr.indices])) + + def _print_UniversalSet(self, expr): + return r"\mathbb{U}" + + def _print_frac(self, expr, exp=None): + if exp is None: + return r"\operatorname{frac}{\left(%s\right)}" % self._print(expr.args[0]) + else: + return r"\operatorname{frac}{\left(%s\right)}^{%s}" % ( + self._print(expr.args[0]), exp) + + def _print_tuple(self, expr): + if self._settings['decimal_separator'] == 'comma': + sep = ";" + elif self._settings['decimal_separator'] == 'period': + sep = "," + else: + raise ValueError('Unknown Decimal Separator') + + if len(expr) == 1: + # 1-tuple needs a trailing separator + return self._add_parens_lspace(self._print(expr[0]) + sep) + else: + return self._add_parens_lspace( + (sep + r" \ ").join([self._print(i) for i in expr])) + + def _print_TensorProduct(self, expr): + elements = [self._print(a) for a in expr.args] + return r' \otimes '.join(elements) + + def _print_WedgeProduct(self, expr): + elements = [self._print(a) for a in expr.args] + return r' \wedge '.join(elements) + + def _print_Tuple(self, expr): + return self._print_tuple(expr) + + def _print_list(self, expr): + if self._settings['decimal_separator'] == 'comma': + return r"\left[ %s\right]" % \ + r"; \ ".join([self._print(i) for i in expr]) + elif self._settings['decimal_separator'] == 'period': + return r"\left[ %s\right]" % \ + r", \ ".join([self._print(i) for i in expr]) + else: + raise ValueError('Unknown Decimal Separator') + + + def _print_dict(self, d): + keys = sorted(d.keys(), key=default_sort_key) + items = [] + + for key in keys: + val = d[key] + items.append("%s : %s" % (self._print(key), self._print(val))) + + return r"\left\{ %s\right\}" % r", \ ".join(items) + + def _print_Dict(self, expr): + return self._print_dict(expr) + + def _print_DiracDelta(self, expr, exp=None): + if len(expr.args) == 1 or expr.args[1] == 0: + tex = r"\delta\left(%s\right)" % self._print(expr.args[0]) + else: + tex = r"\delta^{\left( %s \right)}\left( %s \right)" % ( + self._print(expr.args[1]), self._print(expr.args[0])) + if exp: + tex = r"\left(%s\right)^{%s}" % (tex, exp) + return tex + + def _print_SingularityFunction(self, expr, exp=None): + shift = self._print(expr.args[0] - expr.args[1]) + power = self._print(expr.args[2]) + tex = r"{\left\langle %s \right\rangle}^{%s}" % (shift, power) + if exp is not None: + tex = r"{\left({\langle %s \rangle}^{%s}\right)}^{%s}" % (shift, power, exp) + return tex + + def _print_Heaviside(self, expr, exp=None): + pargs = ', '.join(self._print(arg) for arg in expr.pargs) + tex = r"\theta\left(%s\right)" % pargs + if exp: + tex = r"\left(%s\right)^{%s}" % (tex, exp) + return tex + + def _print_KroneckerDelta(self, expr, exp=None): + i = self._print(expr.args[0]) + j = self._print(expr.args[1]) + if expr.args[0].is_Atom and expr.args[1].is_Atom: + tex = r'\delta_{%s %s}' % (i, j) + else: + tex = r'\delta_{%s, %s}' % (i, j) + if exp is not None: + tex = r'\left(%s\right)^{%s}' % (tex, exp) + return tex + + def _print_LeviCivita(self, expr, exp=None): + indices = map(self._print, expr.args) + if all(x.is_Atom for x in expr.args): + tex = r'\varepsilon_{%s}' % " ".join(indices) + else: + tex = r'\varepsilon_{%s}' % ", ".join(indices) + if exp: + tex = r'\left(%s\right)^{%s}' % (tex, exp) + return tex + + def _print_RandomDomain(self, d): + if hasattr(d, 'as_boolean'): + return '\\text{Domain: }' + self._print(d.as_boolean()) + elif hasattr(d, 'set'): + return ('\\text{Domain: }' + self._print(d.symbols) + ' \\in ' + + self._print(d.set)) + elif hasattr(d, 'symbols'): + return '\\text{Domain on }' + self._print(d.symbols) + else: + return self._print(None) + + def _print_FiniteSet(self, s): + items = sorted(s.args, key=default_sort_key) + return self._print_set(items) + + def _print_set(self, s): + items = sorted(s, key=default_sort_key) + if self._settings['decimal_separator'] == 'comma': + items = "; ".join(map(self._print, items)) + elif self._settings['decimal_separator'] == 'period': + items = ", ".join(map(self._print, items)) + else: + raise ValueError('Unknown Decimal Separator') + return r"\left\{%s\right\}" % items + + + _print_frozenset = _print_set + + def _print_Range(self, s): + def _print_symbolic_range(): + # Symbolic Range that cannot be resolved + if s.args[0] == 0: + if s.args[2] == 1: + cont = self._print(s.args[1]) + else: + cont = ", ".join(self._print(arg) for arg in s.args) + else: + if s.args[2] == 1: + cont = ", ".join(self._print(arg) for arg in s.args[:2]) + else: + cont = ", ".join(self._print(arg) for arg in s.args) + + return(f"\\text{{Range}}\\left({cont}\\right)") + + dots = object() + + if s.start.is_infinite and s.stop.is_infinite: + if s.step.is_positive: + printset = dots, -1, 0, 1, dots + else: + printset = dots, 1, 0, -1, dots + elif s.start.is_infinite: + printset = dots, s[-1] - s.step, s[-1] + elif s.stop.is_infinite: + it = iter(s) + printset = next(it), next(it), dots + elif s.is_empty is not None: + if (s.size < 4) == True: + printset = tuple(s) + elif s.is_iterable: + it = iter(s) + printset = next(it), next(it), dots, s[-1] + else: + return _print_symbolic_range() + else: + return _print_symbolic_range() + return (r"\left\{" + + r", ".join(self._print(el) if el is not dots else r'\ldots' for el in printset) + + r"\right\}") + + def __print_number_polynomial(self, expr, letter, exp=None): + if len(expr.args) == 2: + if exp is not None: + return r"%s_{%s}^{%s}\left(%s\right)" % (letter, + self._print(expr.args[0]), exp, + self._print(expr.args[1])) + return r"%s_{%s}\left(%s\right)" % (letter, + self._print(expr.args[0]), self._print(expr.args[1])) + + tex = r"%s_{%s}" % (letter, self._print(expr.args[0])) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + return tex + + def _print_bernoulli(self, expr, exp=None): + return self.__print_number_polynomial(expr, "B", exp) + + def _print_genocchi(self, expr, exp=None): + return self.__print_number_polynomial(expr, "G", exp) + + def _print_bell(self, expr, exp=None): + if len(expr.args) == 3: + tex1 = r"B_{%s, %s}" % (self._print(expr.args[0]), + self._print(expr.args[1])) + tex2 = r"\left(%s\right)" % r", ".join(self._print(el) for + el in expr.args[2]) + if exp is not None: + tex = r"%s^{%s}%s" % (tex1, exp, tex2) + else: + tex = tex1 + tex2 + return tex + return self.__print_number_polynomial(expr, "B", exp) + + def _print_fibonacci(self, expr, exp=None): + return self.__print_number_polynomial(expr, "F", exp) + + def _print_lucas(self, expr, exp=None): + tex = r"L_{%s}" % self._print(expr.args[0]) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + return tex + + def _print_tribonacci(self, expr, exp=None): + return self.__print_number_polynomial(expr, "T", exp) + + def _print_mobius(self, expr, exp=None): + if exp is None: + return r'\mu\left(%s\right)' % self._print(expr.args[0]) + return r'\mu^{%s}\left(%s\right)' % (exp, self._print(expr.args[0])) + + def _print_SeqFormula(self, s): + dots = object() + if len(s.start.free_symbols) > 0 or len(s.stop.free_symbols) > 0: + return r"\left\{%s\right\}_{%s=%s}^{%s}" % ( + self._print(s.formula), + self._print(s.variables[0]), + self._print(s.start), + self._print(s.stop) + ) + if s.start is S.NegativeInfinity: + stop = s.stop + printset = (dots, s.coeff(stop - 3), s.coeff(stop - 2), + s.coeff(stop - 1), s.coeff(stop)) + elif s.stop is S.Infinity or s.length > 4: + printset = s[:4] + printset.append(dots) + else: + printset = tuple(s) + + return (r"\left[" + + r", ".join(self._print(el) if el is not dots else r'\ldots' for el in printset) + + r"\right]") + + _print_SeqPer = _print_SeqFormula + _print_SeqAdd = _print_SeqFormula + _print_SeqMul = _print_SeqFormula + + def _print_Interval(self, i): + if i.start == i.end: + return r"\left\{%s\right\}" % self._print(i.start) + + else: + if i.left_open: + left = '(' + else: + left = '[' + + if i.right_open: + right = ')' + else: + right = ']' + + return r"\left%s%s, %s\right%s" % \ + (left, self._print(i.start), self._print(i.end), right) + + def _print_AccumulationBounds(self, i): + return r"\left\langle %s, %s\right\rangle" % \ + (self._print(i.min), self._print(i.max)) + + def _print_Union(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \cup ".join(args_str) + + def _print_Complement(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \setminus ".join(args_str) + + def _print_Intersection(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \cap ".join(args_str) + + def _print_SymmetricDifference(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \triangle ".join(args_str) + + def _print_ProductSet(self, p): + prec = precedence_traditional(p) + if len(p.sets) >= 1 and not has_variety(p.sets): + return self.parenthesize(p.sets[0], prec) + "^{%d}" % len(p.sets) + return r" \times ".join( + self.parenthesize(set, prec) for set in p.sets) + + def _print_EmptySet(self, e): + return r"\emptyset" + + def _print_Naturals(self, n): + return r"\mathbb{N}" + + def _print_Naturals0(self, n): + return r"\mathbb{N}_0" + + def _print_Integers(self, i): + return r"\mathbb{Z}" + + def _print_Rationals(self, i): + return r"\mathbb{Q}" + + def _print_Reals(self, i): + return r"\mathbb{R}" + + def _print_Complexes(self, i): + return r"\mathbb{C}" + + def _print_ImageSet(self, s): + expr = s.lamda.expr + sig = s.lamda.signature + xys = ((self._print(x), self._print(y)) for x, y in zip(sig, s.base_sets)) + xinys = r", ".join(r"%s \in %s" % xy for xy in xys) + return r"\left\{%s\; \middle|\; %s\right\}" % (self._print(expr), xinys) + + def _print_ConditionSet(self, s): + vars_print = ', '.join([self._print(var) for var in Tuple(s.sym)]) + if s.base_set is S.UniversalSet: + return r"\left\{%s\; \middle|\; %s \right\}" % \ + (vars_print, self._print(s.condition)) + + return r"\left\{%s\; \middle|\; %s \in %s \wedge %s \right\}" % ( + vars_print, + vars_print, + self._print(s.base_set), + self._print(s.condition)) + + def _print_PowerSet(self, expr): + arg_print = self._print(expr.args[0]) + return r"\mathcal{{P}}\left({}\right)".format(arg_print) + + def _print_ComplexRegion(self, s): + vars_print = ', '.join([self._print(var) for var in s.variables]) + return r"\left\{%s\; \middle|\; %s \in %s \right\}" % ( + self._print(s.expr), + vars_print, + self._print(s.sets)) + + def _print_Contains(self, e): + return r"%s \in %s" % tuple(self._print(a) for a in e.args) + + def _print_FourierSeries(self, s): + if s.an.formula is S.Zero and s.bn.formula is S.Zero: + return self._print(s.a0) + return self._print_Add(s.truncate()) + r' + \ldots' + + def _print_FormalPowerSeries(self, s): + return self._print_Add(s.infinite) + + def _print_FiniteField(self, expr): + return r"\mathbb{F}_{%s}" % expr.mod + + def _print_IntegerRing(self, expr): + return r"\mathbb{Z}" + + def _print_RationalField(self, expr): + return r"\mathbb{Q}" + + def _print_RealField(self, expr): + return r"\mathbb{R}" + + def _print_ComplexField(self, expr): + return r"\mathbb{C}" + + def _print_PolynomialRing(self, expr): + domain = self._print(expr.domain) + symbols = ", ".join(map(self._print, expr.symbols)) + return r"%s\left[%s\right]" % (domain, symbols) + + def _print_FractionField(self, expr): + domain = self._print(expr.domain) + symbols = ", ".join(map(self._print, expr.symbols)) + return r"%s\left(%s\right)" % (domain, symbols) + + def _print_PolynomialRingBase(self, expr): + domain = self._print(expr.domain) + symbols = ", ".join(map(self._print, expr.symbols)) + inv = "" + if not expr.is_Poly: + inv = r"S_<^{-1}" + return r"%s%s\left[%s\right]" % (inv, domain, symbols) + + def _print_Poly(self, poly): + cls = poly.__class__.__name__ + terms = [] + for monom, coeff in poly.terms(): + s_monom = '' + for i, exp in enumerate(monom): + if exp > 0: + if exp == 1: + s_monom += self._print(poly.gens[i]) + else: + s_monom += self._print(pow(poly.gens[i], exp)) + + if coeff.is_Add: + if s_monom: + s_coeff = r"\left(%s\right)" % self._print(coeff) + else: + s_coeff = self._print(coeff) + else: + if s_monom: + if coeff is S.One: + terms.extend(['+', s_monom]) + continue + + if coeff is S.NegativeOne: + terms.extend(['-', s_monom]) + continue + + s_coeff = self._print(coeff) + + if not s_monom: + s_term = s_coeff + else: + s_term = s_coeff + " " + s_monom + + if s_term.startswith('-'): + terms.extend(['-', s_term[1:]]) + else: + terms.extend(['+', s_term]) + + if terms[0] in ('-', '+'): + modifier = terms.pop(0) + + if modifier == '-': + terms[0] = '-' + terms[0] + + expr = ' '.join(terms) + gens = list(map(self._print, poly.gens)) + domain = "domain=%s" % self._print(poly.get_domain()) + + args = ", ".join([expr] + gens + [domain]) + if cls in accepted_latex_functions: + tex = r"\%s {\left(%s \right)}" % (cls, args) + else: + tex = r"\operatorname{%s}{\left( %s \right)}" % (cls, args) + + return tex + + def _print_ComplexRootOf(self, root): + cls = root.__class__.__name__ + if cls == "ComplexRootOf": + cls = "CRootOf" + expr = self._print(root.expr) + index = root.index + if cls in accepted_latex_functions: + return r"\%s {\left(%s, %d\right)}" % (cls, expr, index) + else: + return r"\operatorname{%s} {\left(%s, %d\right)}" % (cls, expr, + index) + + def _print_RootSum(self, expr): + cls = expr.__class__.__name__ + args = [self._print(expr.expr)] + + if expr.fun is not S.IdentityFunction: + args.append(self._print(expr.fun)) + + if cls in accepted_latex_functions: + return r"\%s {\left(%s\right)}" % (cls, ", ".join(args)) + else: + return r"\operatorname{%s} {\left(%s\right)}" % (cls, + ", ".join(args)) + + def _print_OrdinalOmega(self, expr): + return r"\omega" + + def _print_OmegaPower(self, expr): + exp, mul = expr.args + if mul != 1: + if exp != 1: + return r"{} \omega^{{{}}}".format(mul, exp) + else: + return r"{} \omega".format(mul) + else: + if exp != 1: + return r"\omega^{{{}}}".format(exp) + else: + return r"\omega" + + def _print_Ordinal(self, expr): + return " + ".join([self._print(arg) for arg in expr.args]) + + def _print_PolyElement(self, poly): + mul_symbol = self._settings['mul_symbol_latex'] + return poly.str(self, PRECEDENCE, "{%s}^{%d}", mul_symbol) + + def _print_FracElement(self, frac): + if frac.denom == 1: + return self._print(frac.numer) + else: + numer = self._print(frac.numer) + denom = self._print(frac.denom) + return r"\frac{%s}{%s}" % (numer, denom) + + def _print_euler(self, expr, exp=None): + m, x = (expr.args[0], None) if len(expr.args) == 1 else expr.args + tex = r"E_{%s}" % self._print(m) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + if x is not None: + tex = r"%s\left(%s\right)" % (tex, self._print(x)) + return tex + + def _print_catalan(self, expr, exp=None): + tex = r"C_{%s}" % self._print(expr.args[0]) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + return tex + + def _print_UnifiedTransform(self, expr, s, inverse=False): + return r"\mathcal{{{}}}{}_{{{}}}\left[{}\right]\left({}\right)".format(s, '^{-1}' if inverse else '', self._print(expr.args[1]), self._print(expr.args[0]), self._print(expr.args[2])) + + def _print_MellinTransform(self, expr): + return self._print_UnifiedTransform(expr, 'M') + + def _print_InverseMellinTransform(self, expr): + return self._print_UnifiedTransform(expr, 'M', True) + + def _print_LaplaceTransform(self, expr): + return self._print_UnifiedTransform(expr, 'L') + + def _print_InverseLaplaceTransform(self, expr): + return self._print_UnifiedTransform(expr, 'L', True) + + def _print_FourierTransform(self, expr): + return self._print_UnifiedTransform(expr, 'F') + + def _print_InverseFourierTransform(self, expr): + return self._print_UnifiedTransform(expr, 'F', True) + + def _print_SineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'SIN') + + def _print_InverseSineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'SIN', True) + + def _print_CosineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'COS') + + def _print_InverseCosineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'COS', True) + + def _print_DMP(self, p): + try: + if p.ring is not None: + # TODO incorporate order + return self._print(p.ring.to_sympy(p)) + except SympifyError: + pass + return self._print(repr(p)) + + def _print_DMF(self, p): + return self._print_DMP(p) + + def _print_Object(self, object): + return self._print(Symbol(object.name)) + + def _print_LambertW(self, expr, exp=None): + arg0 = self._print(expr.args[0]) + exp = r"^{%s}" % (exp,) if exp is not None else "" + if len(expr.args) == 1: + result = r"W%s\left(%s\right)" % (exp, arg0) + else: + arg1 = self._print(expr.args[1]) + result = "W{0}_{{{1}}}\\left({2}\\right)".format(exp, arg1, arg0) + return result + + def _print_Expectation(self, expr): + return r"\operatorname{{E}}\left[{}\right]".format(self._print(expr.args[0])) + + def _print_Variance(self, expr): + return r"\operatorname{{Var}}\left({}\right)".format(self._print(expr.args[0])) + + def _print_Covariance(self, expr): + return r"\operatorname{{Cov}}\left({}\right)".format(", ".join(self._print(arg) for arg in expr.args)) + + def _print_Probability(self, expr): + return r"\operatorname{{P}}\left({}\right)".format(self._print(expr.args[0])) + + def _print_Morphism(self, morphism): + domain = self._print(morphism.domain) + codomain = self._print(morphism.codomain) + return "%s\\rightarrow %s" % (domain, codomain) + + def _print_TransferFunction(self, expr): + num, den = self._print(expr.num), self._print(expr.den) + return r"\frac{%s}{%s}" % (num, den) + + def _print_Series(self, expr): + args = list(expr.args) + parens = lambda x: self.parenthesize(x, precedence_traditional(expr), + False) + return ' '.join(map(parens, args)) + + def _print_MIMOSeries(self, expr): + from sympy.physics.control.lti import MIMOParallel + args = list(expr.args)[::-1] + parens = lambda x: self.parenthesize(x, precedence_traditional(expr), + False) if isinstance(x, MIMOParallel) else self._print(x) + return r"\cdot".join(map(parens, args)) + + def _print_Parallel(self, expr): + return ' + '.join(map(self._print, expr.args)) + + def _print_MIMOParallel(self, expr): + return ' + '.join(map(self._print, expr.args)) + + def _print_Feedback(self, expr): + from sympy.physics.control import TransferFunction, Series + + num, tf = expr.sys1, TransferFunction(1, 1, expr.var) + num_arg_list = list(num.args) if isinstance(num, Series) else [num] + den_arg_list = list(expr.sys2.args) if \ + isinstance(expr.sys2, Series) else [expr.sys2] + den_term_1 = tf + + if isinstance(num, Series) and isinstance(expr.sys2, Series): + den_term_2 = Series(*num_arg_list, *den_arg_list) + elif isinstance(num, Series) and isinstance(expr.sys2, TransferFunction): + if expr.sys2 == tf: + den_term_2 = Series(*num_arg_list) + else: + den_term_2 = tf, Series(*num_arg_list, expr.sys2) + elif isinstance(num, TransferFunction) and isinstance(expr.sys2, Series): + if num == tf: + den_term_2 = Series(*den_arg_list) + else: + den_term_2 = Series(num, *den_arg_list) + else: + if num == tf: + den_term_2 = Series(*den_arg_list) + elif expr.sys2 == tf: + den_term_2 = Series(*num_arg_list) + else: + den_term_2 = Series(*num_arg_list, *den_arg_list) + + numer = self._print(num) + denom_1 = self._print(den_term_1) + denom_2 = self._print(den_term_2) + _sign = "+" if expr.sign == -1 else "-" + + return r"\frac{%s}{%s %s %s}" % (numer, denom_1, _sign, denom_2) + + def _print_MIMOFeedback(self, expr): + from sympy.physics.control import MIMOSeries + inv_mat = self._print(MIMOSeries(expr.sys2, expr.sys1)) + sys1 = self._print(expr.sys1) + _sign = "+" if expr.sign == -1 else "-" + return r"\left(I_{\tau} %s %s\right)^{-1} \cdot %s" % (_sign, inv_mat, sys1) + + def _print_TransferFunctionMatrix(self, expr): + mat = self._print(expr._expr_mat) + return r"%s_\tau" % mat + + def _print_DFT(self, expr): + return r"\text{{{}}}_{{{}}}".format(expr.__class__.__name__, expr.n) + _print_IDFT = _print_DFT + + def _print_NamedMorphism(self, morphism): + pretty_name = self._print(Symbol(morphism.name)) + pretty_morphism = self._print_Morphism(morphism) + return "%s:%s" % (pretty_name, pretty_morphism) + + def _print_IdentityMorphism(self, morphism): + from sympy.categories import NamedMorphism + return self._print_NamedMorphism(NamedMorphism( + morphism.domain, morphism.codomain, "id")) + + def _print_CompositeMorphism(self, morphism): + # All components of the morphism have names and it is thus + # possible to build the name of the composite. + component_names_list = [self._print(Symbol(component.name)) for + component in morphism.components] + component_names_list.reverse() + component_names = "\\circ ".join(component_names_list) + ":" + + pretty_morphism = self._print_Morphism(morphism) + return component_names + pretty_morphism + + def _print_Category(self, morphism): + return r"\mathbf{{{}}}".format(self._print(Symbol(morphism.name))) + + def _print_Diagram(self, diagram): + if not diagram.premises: + # This is an empty diagram. + return self._print(S.EmptySet) + + latex_result = self._print(diagram.premises) + if diagram.conclusions: + latex_result += "\\Longrightarrow %s" % \ + self._print(diagram.conclusions) + + return latex_result + + def _print_DiagramGrid(self, grid): + latex_result = "\\begin{array}{%s}\n" % ("c" * grid.width) + + for i in range(grid.height): + for j in range(grid.width): + if grid[i, j]: + latex_result += latex(grid[i, j]) + latex_result += " " + if j != grid.width - 1: + latex_result += "& " + + if i != grid.height - 1: + latex_result += "\\\\" + latex_result += "\n" + + latex_result += "\\end{array}\n" + return latex_result + + def _print_FreeModule(self, M): + return '{{{}}}^{{{}}}'.format(self._print(M.ring), self._print(M.rank)) + + def _print_FreeModuleElement(self, m): + # Print as row vector for convenience, for now. + return r"\left[ {} \right]".format(",".join( + '{' + self._print(x) + '}' for x in m)) + + def _print_SubModule(self, m): + gens = [[self._print(m.ring.to_sympy(x)) for x in g] for g in m.gens] + curly = lambda o: r"{" + o + r"}" + square = lambda o: r"\left[ " + o + r" \right]" + gens_latex = ",".join(curly(square(",".join(curly(x) for x in g))) for g in gens) + return r"\left\langle {} \right\rangle".format(gens_latex) + + def _print_SubQuotientModule(self, m): + gens_latex = ",".join(["{" + self._print(g) + "}" for g in m.gens]) + return r"\left\langle {} \right\rangle".format(gens_latex) + + def _print_ModuleImplementedIdeal(self, m): + gens = [m.ring.to_sympy(x) for [x] in m._module.gens] + gens_latex = ",".join('{' + self._print(x) + '}' for x in gens) + return r"\left\langle {} \right\rangle".format(gens_latex) + + def _print_Quaternion(self, expr): + # TODO: This expression is potentially confusing, + # shall we print it as `Quaternion( ... )`? + s = [self.parenthesize(i, PRECEDENCE["Mul"], strict=True) + for i in expr.args] + a = [s[0]] + [i+" "+j for i, j in zip(s[1:], "ijk")] + return " + ".join(a) + + def _print_QuotientRing(self, R): + # TODO nicer fractions for few generators... + return r"\frac{{{}}}{{{}}}".format(self._print(R.ring), + self._print(R.base_ideal)) + + def _print_QuotientRingElement(self, x): + x_latex = self._print(x.ring.to_sympy(x)) + return r"{{{}}} + {{{}}}".format(x_latex, + self._print(x.ring.base_ideal)) + + def _print_QuotientModuleElement(self, m): + data = [m.module.ring.to_sympy(x) for x in m.data] + data_latex = r"\left[ {} \right]".format(",".join( + '{' + self._print(x) + '}' for x in data)) + return r"{{{}}} + {{{}}}".format(data_latex, + self._print(m.module.killed_module)) + + def _print_QuotientModule(self, M): + # TODO nicer fractions for few generators... + return r"\frac{{{}}}{{{}}}".format(self._print(M.base), + self._print(M.killed_module)) + + def _print_MatrixHomomorphism(self, h): + return r"{{{}}} : {{{}}} \to {{{}}}".format(self._print(h._sympy_matrix()), + self._print(h.domain), self._print(h.codomain)) + + def _print_Manifold(self, manifold): + name, supers, subs = self._split_super_sub(manifold.name.name) + + name = r'\text{%s}' % name + if supers: + name += "^{%s}" % " ".join(supers) + if subs: + name += "_{%s}" % " ".join(subs) + + return name + + def _print_Patch(self, patch): + return r'\text{%s}_{%s}' % (self._print(patch.name), self._print(patch.manifold)) + + def _print_CoordSystem(self, coordsys): + return r'\text{%s}^{\text{%s}}_{%s}' % ( + self._print(coordsys.name), self._print(coordsys.patch.name), self._print(coordsys.manifold) + ) + + def _print_CovarDerivativeOp(self, cvd): + return r'\mathbb{\nabla}_{%s}' % self._print(cvd._wrt) + + def _print_BaseScalarField(self, field): + string = field._coord_sys.symbols[field._index].name + return r'\mathbf{{{}}}'.format(self._print(Symbol(string))) + + def _print_BaseVectorField(self, field): + string = field._coord_sys.symbols[field._index].name + return r'\partial_{{{}}}'.format(self._print(Symbol(string))) + + def _print_Differential(self, diff): + field = diff._form_field + if hasattr(field, '_coord_sys'): + string = field._coord_sys.symbols[field._index].name + return r'\operatorname{{d}}{}'.format(self._print(Symbol(string))) + else: + string = self._print(field) + return r'\operatorname{{d}}\left({}\right)'.format(string) + + def _print_Tr(self, p): + # TODO: Handle indices + contents = self._print(p.args[0]) + return r'\operatorname{{tr}}\left({}\right)'.format(contents) + + def _print_totient(self, expr, exp=None): + if exp is not None: + return r'\left(\phi\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\phi\left(%s\right)' % self._print(expr.args[0]) + + def _print_reduced_totient(self, expr, exp=None): + if exp is not None: + return r'\left(\lambda\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\lambda\left(%s\right)' % self._print(expr.args[0]) + + def _print_divisor_sigma(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"_%s\left(%s\right)" % tuple(map(self._print, + (expr.args[1], expr.args[0]))) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\sigma^{%s}%s" % (exp, tex) + return r"\sigma%s" % tex + + def _print_udivisor_sigma(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"_%s\left(%s\right)" % tuple(map(self._print, + (expr.args[1], expr.args[0]))) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\sigma^*^{%s}%s" % (exp, tex) + return r"\sigma^*%s" % tex + + def _print_primenu(self, expr, exp=None): + if exp is not None: + return r'\left(\nu\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\nu\left(%s\right)' % self._print(expr.args[0]) + + def _print_primeomega(self, expr, exp=None): + if exp is not None: + return r'\left(\Omega\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\Omega\left(%s\right)' % self._print(expr.args[0]) + + def _print_Str(self, s): + return str(s.name) + + def _print_float(self, expr): + return self._print(Float(expr)) + + def _print_int(self, expr): + return str(expr) + + def _print_mpz(self, expr): + return str(expr) + + def _print_mpq(self, expr): + return str(expr) + + def _print_fmpz(self, expr): + return str(expr) + + def _print_fmpq(self, expr): + return str(expr) + + def _print_Predicate(self, expr): + return r"\operatorname{{Q}}_{{\text{{{}}}}}".format(latex_escape(str(expr.name))) + + def _print_AppliedPredicate(self, expr): + pred = expr.function + args = expr.arguments + pred_latex = self._print(pred) + args_latex = ', '.join([self._print(a) for a in args]) + return '%s(%s)' % (pred_latex, args_latex) + + def emptyPrinter(self, expr): + # default to just printing as monospace, like would normally be shown + s = super().emptyPrinter(expr) + + return r"\mathtt{\text{%s}}" % latex_escape(s) + + +def translate(s: str) -> str: + r''' + Check for a modifier ending the string. If present, convert the + modifier to latex and translate the rest recursively. + + Given a description of a Greek letter or other special character, + return the appropriate latex. + + Let everything else pass as given. + + >>> from sympy.printing.latex import translate + >>> translate('alphahatdotprime') + "{\\dot{\\hat{\\alpha}}}'" + ''' + # Process the rest + tex = tex_greek_dictionary.get(s) + if tex: + return tex + elif s.lower() in greek_letters_set: + return "\\" + s.lower() + elif s in other_symbols: + return "\\" + s + else: + # Process modifiers, if any, and recurse + for key in sorted(modifier_dict.keys(), key=len, reverse=True): + if s.lower().endswith(key) and len(s) > len(key): + return modifier_dict[key](translate(s[:-len(key)])) + return s + + + +@print_function(LatexPrinter) +def latex(expr, **settings): + r"""Convert the given expression to LaTeX string representation. + + Parameters + ========== + full_prec: boolean, optional + If set to True, a floating point number is printed with full precision. + fold_frac_powers : boolean, optional + Emit ``^{p/q}`` instead of ``^{\frac{p}{q}}`` for fractional powers. + fold_func_brackets : boolean, optional + Fold function brackets where applicable. + fold_short_frac : boolean, optional + Emit ``p / q`` instead of ``\frac{p}{q}`` when the denominator is + simple enough (at most two terms and no powers). The default value is + ``True`` for inline mode, ``False`` otherwise. + inv_trig_style : string, optional + How inverse trig functions should be displayed. Can be one of + ``'abbreviated'``, ``'full'``, or ``'power'``. Defaults to + ``'abbreviated'``. + itex : boolean, optional + Specifies if itex-specific syntax is used, including emitting + ``$$...$$``. + ln_notation : boolean, optional + If set to ``True``, ``\ln`` is used instead of default ``\log``. + long_frac_ratio : float or None, optional + The allowed ratio of the width of the numerator to the width of the + denominator before the printer breaks off long fractions. If ``None`` + (the default value), long fractions are not broken up. + mat_delim : string, optional + The delimiter to wrap around matrices. Can be one of ``'['``, ``'('``, + or the empty string ``''``. Defaults to ``'['``. + mat_str : string, optional + Which matrix environment string to emit. ``'smallmatrix'``, + ``'matrix'``, ``'array'``, etc. Defaults to ``'smallmatrix'`` for + inline mode, ``'matrix'`` for matrices of no more than 10 columns, and + ``'array'`` otherwise. + mode: string, optional + Specifies how the generated code will be delimited. ``mode`` can be one + of ``'plain'``, ``'inline'``, ``'equation'`` or ``'equation*'``. If + ``mode`` is set to ``'plain'``, then the resulting code will not be + delimited at all (this is the default). If ``mode`` is set to + ``'inline'`` then inline LaTeX ``$...$`` will be used. If ``mode`` is + set to ``'equation'`` or ``'equation*'``, the resulting code will be + enclosed in the ``equation`` or ``equation*`` environment (remember to + import ``amsmath`` for ``equation*``), unless the ``itex`` option is + set. In the latter case, the ``$$...$$`` syntax is used. + mul_symbol : string or None, optional + The symbol to use for multiplication. Can be one of ``None``, + ``'ldot'``, ``'dot'``, or ``'times'``. + order: string, optional + Any of the supported monomial orderings (currently ``'lex'``, + ``'grlex'``, or ``'grevlex'``), ``'old'``, and ``'none'``. This + parameter does nothing for `~.Mul` objects. Setting order to ``'old'`` + uses the compatibility ordering for ``~.Add`` defined in Printer. For + very large expressions, set the ``order`` keyword to ``'none'`` if + speed is a concern. + symbol_names : dictionary of strings mapped to symbols, optional + Dictionary of symbols and the custom strings they should be emitted as. + root_notation : boolean, optional + If set to ``False``, exponents of the form 1/n are printed in fractonal + form. Default is ``True``, to print exponent in root form. + mat_symbol_style : string, optional + Can be either ``'plain'`` (default) or ``'bold'``. If set to + ``'bold'``, a `~.MatrixSymbol` A will be printed as ``\mathbf{A}``, + otherwise as ``A``. + imaginary_unit : string, optional + String to use for the imaginary unit. Defined options are ``'i'`` + (default) and ``'j'``. Adding ``r`` or ``t`` in front gives ``\mathrm`` + or ``\text``, so ``'ri'`` leads to ``\mathrm{i}`` which gives + `\mathrm{i}`. + gothic_re_im : boolean, optional + If set to ``True``, `\Re` and `\Im` is used for ``re`` and ``im``, respectively. + The default is ``False`` leading to `\operatorname{re}` and `\operatorname{im}`. + decimal_separator : string, optional + Specifies what separator to use to separate the whole and fractional parts of a + floating point number as in `2.5` for the default, ``period`` or `2{,}5` + when ``comma`` is specified. Lists, sets, and tuple are printed with semicolon + separating the elements when ``comma`` is chosen. For example, [1; 2; 3] when + ``comma`` is chosen and [1,2,3] for when ``period`` is chosen. + parenthesize_super : boolean, optional + If set to ``False``, superscripted expressions will not be parenthesized when + powered. Default is ``True``, which parenthesizes the expression when powered. + min: Integer or None, optional + Sets the lower bound for the exponent to print floating point numbers in + fixed-point format. + max: Integer or None, optional + Sets the upper bound for the exponent to print floating point numbers in + fixed-point format. + diff_operator: string, optional + String to use for differential operator. Default is ``'d'``, to print in italic + form. ``'rd'``, ``'td'`` are shortcuts for ``\mathrm{d}`` and ``\text{d}``. + adjoint_style: string, optional + String to use for the adjoint symbol. Defined options are ``'dagger'`` + (default),``'star'``, and ``'hermitian'``. + + Notes + ===== + + Not using a print statement for printing, results in double backslashes for + latex commands since that's the way Python escapes backslashes in strings. + + >>> from sympy import latex, Rational + >>> from sympy.abc import tau + >>> latex((2*tau)**Rational(7,2)) + '8 \\sqrt{2} \\tau^{\\frac{7}{2}}' + >>> print(latex((2*tau)**Rational(7,2))) + 8 \sqrt{2} \tau^{\frac{7}{2}} + + Examples + ======== + + >>> from sympy import latex, pi, sin, asin, Integral, Matrix, Rational, log + >>> from sympy.abc import x, y, mu, r, tau + + Basic usage: + + >>> print(latex((2*tau)**Rational(7,2))) + 8 \sqrt{2} \tau^{\frac{7}{2}} + + ``mode`` and ``itex`` options: + + >>> print(latex((2*mu)**Rational(7,2), mode='plain')) + 8 \sqrt{2} \mu^{\frac{7}{2}} + >>> print(latex((2*tau)**Rational(7,2), mode='inline')) + $8 \sqrt{2} \tau^{7 / 2}$ + >>> print(latex((2*mu)**Rational(7,2), mode='equation*')) + \begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*} + >>> print(latex((2*mu)**Rational(7,2), mode='equation')) + \begin{equation}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation} + >>> print(latex((2*mu)**Rational(7,2), mode='equation', itex=True)) + $$8 \sqrt{2} \mu^{\frac{7}{2}}$$ + >>> print(latex((2*mu)**Rational(7,2), mode='plain')) + 8 \sqrt{2} \mu^{\frac{7}{2}} + >>> print(latex((2*tau)**Rational(7,2), mode='inline')) + $8 \sqrt{2} \tau^{7 / 2}$ + >>> print(latex((2*mu)**Rational(7,2), mode='equation*')) + \begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*} + >>> print(latex((2*mu)**Rational(7,2), mode='equation')) + \begin{equation}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation} + >>> print(latex((2*mu)**Rational(7,2), mode='equation', itex=True)) + $$8 \sqrt{2} \mu^{\frac{7}{2}}$$ + + Fraction options: + + >>> print(latex((2*tau)**Rational(7,2), fold_frac_powers=True)) + 8 \sqrt{2} \tau^{7/2} + >>> print(latex((2*tau)**sin(Rational(7,2)))) + \left(2 \tau\right)^{\sin{\left(\frac{7}{2} \right)}} + >>> print(latex((2*tau)**sin(Rational(7,2)), fold_func_brackets=True)) + \left(2 \tau\right)^{\sin {\frac{7}{2}}} + >>> print(latex(3*x**2/y)) + \frac{3 x^{2}}{y} + >>> print(latex(3*x**2/y, fold_short_frac=True)) + 3 x^{2} / y + >>> print(latex(Integral(r, r)/2/pi, long_frac_ratio=2)) + \frac{\int r\, dr}{2 \pi} + >>> print(latex(Integral(r, r)/2/pi, long_frac_ratio=0)) + \frac{1}{2 \pi} \int r\, dr + + Multiplication options: + + >>> print(latex((2*tau)**sin(Rational(7,2)), mul_symbol="times")) + \left(2 \times \tau\right)^{\sin{\left(\frac{7}{2} \right)}} + + Trig options: + + >>> print(latex(asin(Rational(7,2)))) + \operatorname{asin}{\left(\frac{7}{2} \right)} + >>> print(latex(asin(Rational(7,2)), inv_trig_style="full")) + \arcsin{\left(\frac{7}{2} \right)} + >>> print(latex(asin(Rational(7,2)), inv_trig_style="power")) + \sin^{-1}{\left(\frac{7}{2} \right)} + + Matrix options: + + >>> print(latex(Matrix(2, 1, [x, y]))) + \left[\begin{matrix}x\\y\end{matrix}\right] + >>> print(latex(Matrix(2, 1, [x, y]), mat_str = "array")) + \left[\begin{array}{c}x\\y\end{array}\right] + >>> print(latex(Matrix(2, 1, [x, y]), mat_delim="(")) + \left(\begin{matrix}x\\y\end{matrix}\right) + + Custom printing of symbols: + + >>> print(latex(x**2, symbol_names={x: 'x_i'})) + x_i^{2} + + Logarithms: + + >>> print(latex(log(10))) + \log{\left(10 \right)} + >>> print(latex(log(10), ln_notation=True)) + \ln{\left(10 \right)} + + ``latex()`` also supports the builtin container types :class:`list`, + :class:`tuple`, and :class:`dict`: + + >>> print(latex([2/x, y], mode='inline')) + $\left[ 2 / x, \ y\right]$ + + Unsupported types are rendered as monospaced plaintext: + + >>> print(latex(int)) + \mathtt{\text{}} + >>> print(latex("plain % text")) + \mathtt{\text{plain \% text}} + + See :ref:`printer_method_example` for an example of how to override + this behavior for your own types by implementing ``_latex``. + + .. versionchanged:: 1.7.0 + Unsupported types no longer have their ``str`` representation treated as valid latex. + + """ + return LatexPrinter(settings).doprint(expr) + + +def print_latex(expr, **settings): + """Prints LaTeX representation of the given expression. Takes the same + settings as ``latex()``.""" + + print(latex(expr, **settings)) + + +def multiline_latex(lhs, rhs, terms_per_line=1, environment="align*", use_dots=False, **settings): + r""" + This function generates a LaTeX equation with a multiline right-hand side + in an ``align*``, ``eqnarray`` or ``IEEEeqnarray`` environment. + + Parameters + ========== + + lhs : Expr + Left-hand side of equation + + rhs : Expr + Right-hand side of equation + + terms_per_line : integer, optional + Number of terms per line to print. Default is 1. + + environment : "string", optional + Which LaTeX wnvironment to use for the output. Options are "align*" + (default), "eqnarray", and "IEEEeqnarray". + + use_dots : boolean, optional + If ``True``, ``\\dots`` is added to the end of each line. Default is ``False``. + + Examples + ======== + + >>> from sympy import multiline_latex, symbols, sin, cos, exp, log, I + >>> x, y, alpha = symbols('x y alpha') + >>> expr = sin(alpha*y) + exp(I*alpha) - cos(log(y)) + >>> print(multiline_latex(x, expr)) + \begin{align*} + x = & e^{i \alpha} \\ + & + \sin{\left(\alpha y \right)} \\ + & - \cos{\left(\log{\left(y \right)} \right)} + \end{align*} + + Using at most two terms per line: + >>> print(multiline_latex(x, expr, 2)) + \begin{align*} + x = & e^{i \alpha} + \sin{\left(\alpha y \right)} \\ + & - \cos{\left(\log{\left(y \right)} \right)} + \end{align*} + + Using ``eqnarray`` and dots: + >>> print(multiline_latex(x, expr, terms_per_line=2, environment="eqnarray", use_dots=True)) + \begin{eqnarray} + x & = & e^{i \alpha} + \sin{\left(\alpha y \right)} \dots\nonumber\\ + & & - \cos{\left(\log{\left(y \right)} \right)} + \end{eqnarray} + + Using ``IEEEeqnarray``: + >>> print(multiline_latex(x, expr, environment="IEEEeqnarray")) + \begin{IEEEeqnarray}{rCl} + x & = & e^{i \alpha} \nonumber\\ + & & + \sin{\left(\alpha y \right)} \nonumber\\ + & & - \cos{\left(\log{\left(y \right)} \right)} + \end{IEEEeqnarray} + + Notes + ===== + + All optional parameters from ``latex`` can also be used. + + """ + + # Based on code from https://github.com/sympy/sympy/issues/3001 + l = LatexPrinter(**settings) + if environment == "eqnarray": + result = r'\begin{eqnarray}' + '\n' + first_term = '& = &' + nonumber = r'\nonumber' + end_term = '\n\\end{eqnarray}' + doubleet = True + elif environment == "IEEEeqnarray": + result = r'\begin{IEEEeqnarray}{rCl}' + '\n' + first_term = '& = &' + nonumber = r'\nonumber' + end_term = '\n\\end{IEEEeqnarray}' + doubleet = True + elif environment == "align*": + result = r'\begin{align*}' + '\n' + first_term = '= &' + nonumber = '' + end_term = '\n\\end{align*}' + doubleet = False + else: + raise ValueError("Unknown environment: {}".format(environment)) + dots = '' + if use_dots: + dots=r'\dots' + terms = rhs.as_ordered_terms() + n_terms = len(terms) + term_count = 1 + for i in range(n_terms): + term = terms[i] + term_start = '' + term_end = '' + sign = '+' + if term_count > terms_per_line: + if doubleet: + term_start = '& & ' + else: + term_start = '& ' + term_count = 1 + if term_count == terms_per_line: + # End of line + if i < n_terms-1: + # There are terms remaining + term_end = dots + nonumber + r'\\' + '\n' + else: + term_end = '' + + if term.as_ordered_factors()[0] == -1: + term = -1*term + sign = r'-' + if i == 0: # beginning + if sign == '+': + sign = '' + result += r'{:s} {:s}{:s} {:s} {:s}'.format(l.doprint(lhs), + first_term, sign, l.doprint(term), term_end) + else: + result += r'{:s}{:s} {:s} {:s}'.format(term_start, sign, + l.doprint(term), term_end) + term_count += 1 + result += end_term + return result diff --git a/phivenv/Lib/site-packages/sympy/printing/llvmjitcode.py b/phivenv/Lib/site-packages/sympy/printing/llvmjitcode.py new file mode 100644 index 0000000000000000000000000000000000000000..0e657f3b854be62fb4b6b4be96f82579b718f6ca --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/llvmjitcode.py @@ -0,0 +1,490 @@ +''' +Use llvmlite to create executable functions from SymPy expressions + +This module requires llvmlite (https://github.com/numba/llvmlite). +''' + +import ctypes + +from sympy.external import import_module +from sympy.printing.printer import Printer +from sympy.core.singleton import S +from sympy.tensor.indexed import IndexedBase +from sympy.utilities.decorator import doctest_depends_on + +llvmlite = import_module('llvmlite') +if llvmlite: + ll = import_module('llvmlite.ir').ir + llvm = import_module('llvmlite.binding').binding + llvm.initialize() + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + + +__doctest_requires__ = {('llvm_callable'): ['llvmlite']} + + +class LLVMJitPrinter(Printer): + '''Convert expressions to LLVM IR''' + def __init__(self, module, builder, fn, *args, **kwargs): + self.func_arg_map = kwargs.pop("func_arg_map", {}) + if not llvmlite: + raise ImportError("llvmlite is required for LLVMJITPrinter") + super().__init__(*args, **kwargs) + self.fp_type = ll.DoubleType() + self.module = module + self.builder = builder + self.fn = fn + self.ext_fn = {} # keep track of wrappers to external functions + self.tmp_var = {} + + def _add_tmp_var(self, name, value): + self.tmp_var[name] = value + + def _print_Number(self, n): + return ll.Constant(self.fp_type, float(n)) + + def _print_Integer(self, expr): + return ll.Constant(self.fp_type, float(expr.p)) + + def _print_Symbol(self, s): + val = self.tmp_var.get(s) + if not val: + # look up parameter with name s + val = self.func_arg_map.get(s) + if not val: + raise LookupError("Symbol not found: %s" % s) + return val + + def _print_Pow(self, expr): + base0 = self._print(expr.base) + if expr.exp == S.NegativeOne: + return self.builder.fdiv(ll.Constant(self.fp_type, 1.0), base0) + if expr.exp == S.Half: + fn = self.ext_fn.get("sqrt") + if not fn: + fn_type = ll.FunctionType(self.fp_type, [self.fp_type]) + fn = ll.Function(self.module, fn_type, "sqrt") + self.ext_fn["sqrt"] = fn + return self.builder.call(fn, [base0], "sqrt") + if expr.exp == 2: + return self.builder.fmul(base0, base0) + + exp0 = self._print(expr.exp) + fn = self.ext_fn.get("pow") + if not fn: + fn_type = ll.FunctionType(self.fp_type, [self.fp_type, self.fp_type]) + fn = ll.Function(self.module, fn_type, "pow") + self.ext_fn["pow"] = fn + return self.builder.call(fn, [base0, exp0], "pow") + + def _print_Mul(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.fmul(e, node) + return e + + def _print_Add(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.fadd(e, node) + return e + + # TODO - assumes all called functions take one double precision argument. + # Should have a list of math library functions to validate this. + def _print_Function(self, expr): + name = expr.func.__name__ + e0 = self._print(expr.args[0]) + fn = self.ext_fn.get(name) + if not fn: + fn_type = ll.FunctionType(self.fp_type, [self.fp_type]) + fn = ll.Function(self.module, fn_type, name) + self.ext_fn[name] = fn + return self.builder.call(fn, [e0], name) + + def emptyPrinter(self, expr): + raise TypeError("Unsupported type for LLVM JIT conversion: %s" + % type(expr)) + + +# Used when parameters are passed by array. Often used in callbacks to +# handle a variable number of parameters. +class LLVMJitCallbackPrinter(LLVMJitPrinter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _print_Indexed(self, expr): + array, idx = self.func_arg_map[expr.base] + offset = int(expr.indices[0].evalf()) + array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), offset)]) + fp_array_ptr = self.builder.bitcast(array_ptr, ll.PointerType(self.fp_type)) + value = self.builder.load(fp_array_ptr) + return value + + def _print_Symbol(self, s): + val = self.tmp_var.get(s) + if val: + return val + + array, idx = self.func_arg_map.get(s, [None, 0]) + if not array: + raise LookupError("Symbol not found: %s" % s) + array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), idx)]) + fp_array_ptr = self.builder.bitcast(array_ptr, + ll.PointerType(self.fp_type)) + value = self.builder.load(fp_array_ptr) + return value + + +# ensure lifetime of the execution engine persists (else call to compiled +# function will seg fault) +exe_engines = [] + +# ensure names for generated functions are unique +link_names = set() +current_link_suffix = 0 + + +class LLVMJitCode: + def __init__(self, signature): + self.signature = signature + self.fp_type = ll.DoubleType() + self.module = ll.Module('mod1') + self.fn = None + self.llvm_arg_types = [] + self.llvm_ret_type = self.fp_type + self.param_dict = {} # map symbol name to LLVM function argument + self.link_name = '' + + def _from_ctype(self, ctype): + if ctype == ctypes.c_int: + return ll.IntType(32) + if ctype == ctypes.c_double: + return self.fp_type + if ctype == ctypes.POINTER(ctypes.c_double): + return ll.PointerType(self.fp_type) + if ctype == ctypes.c_void_p: + return ll.PointerType(ll.IntType(32)) + if ctype == ctypes.py_object: + return ll.PointerType(ll.IntType(32)) + + print("Unhandled ctype = %s" % str(ctype)) + + def _create_args(self, func_args): + """Create types for function arguments""" + self.llvm_ret_type = self._from_ctype(self.signature.ret_type) + self.llvm_arg_types = \ + [self._from_ctype(a) for a in self.signature.arg_ctypes] + + def _create_function_base(self): + """Create function with name and type signature""" + global current_link_suffix + default_link_name = 'jit_func' + current_link_suffix += 1 + self.link_name = default_link_name + str(current_link_suffix) + link_names.add(self.link_name) + + fn_type = ll.FunctionType(self.llvm_ret_type, self.llvm_arg_types) + self.fn = ll.Function(self.module, fn_type, name=self.link_name) + + def _create_param_dict(self, func_args): + """Mapping of symbolic values to function arguments""" + for i, a in enumerate(func_args): + self.fn.args[i].name = str(a) + self.param_dict[a] = self.fn.args[i] + + def _create_function(self, expr): + """Create function body and return LLVM IR""" + bb_entry = self.fn.append_basic_block('entry') + builder = ll.IRBuilder(bb_entry) + + lj = LLVMJitPrinter(self.module, builder, self.fn, + func_arg_map=self.param_dict) + + ret = self._convert_expr(lj, expr) + lj.builder.ret(self._wrap_return(lj, ret)) + + strmod = str(self.module) + return strmod + + def _wrap_return(self, lj, vals): + # Return a single double if there is one return value, + # else return a tuple of doubles. + + # Don't wrap return value in this case + if self.signature.ret_type == ctypes.c_double: + return vals[0] + + # Use this instead of a real PyObject* + void_ptr = ll.PointerType(ll.IntType(32)) + + # Create a wrapped double: PyObject* PyFloat_FromDouble(double v) + wrap_type = ll.FunctionType(void_ptr, [self.fp_type]) + wrap_fn = ll.Function(lj.module, wrap_type, "PyFloat_FromDouble") + + wrapped_vals = [lj.builder.call(wrap_fn, [v]) for v in vals] + if len(vals) == 1: + final_val = wrapped_vals[0] + else: + # Create a tuple: PyObject* PyTuple_Pack(Py_ssize_t n, ...) + + # This should be Py_ssize_t + tuple_arg_types = [ll.IntType(32)] + + tuple_arg_types.extend([void_ptr]*len(vals)) + tuple_type = ll.FunctionType(void_ptr, tuple_arg_types) + tuple_fn = ll.Function(lj.module, tuple_type, "PyTuple_Pack") + + tuple_args = [ll.Constant(ll.IntType(32), len(wrapped_vals))] + tuple_args.extend(wrapped_vals) + + final_val = lj.builder.call(tuple_fn, tuple_args) + + return final_val + + def _convert_expr(self, lj, expr): + try: + # Match CSE return data structure. + if len(expr) == 2: + tmp_exprs = expr[0] + final_exprs = expr[1] + if len(final_exprs) != 1 and self.signature.ret_type == ctypes.c_double: + raise NotImplementedError("Return of multiple expressions not supported for this callback") + for name, e in tmp_exprs: + val = lj._print(e) + lj._add_tmp_var(name, val) + except TypeError: + final_exprs = [expr] + + vals = [lj._print(e) for e in final_exprs] + + return vals + + def _compile_function(self, strmod): + llmod = llvm.parse_assembly(strmod) + + pmb = llvm.create_pass_manager_builder() + pmb.opt_level = 2 + pass_manager = llvm.create_module_pass_manager() + pmb.populate(pass_manager) + + pass_manager.run(llmod) + + target_machine = \ + llvm.Target.from_default_triple().create_target_machine() + exe_eng = llvm.create_mcjit_compiler(llmod, target_machine) + exe_eng.finalize_object() + exe_engines.append(exe_eng) + + if False: + print("Assembly") + print(target_machine.emit_assembly(llmod)) + + fptr = exe_eng.get_function_address(self.link_name) + + return fptr + + +class LLVMJitCodeCallback(LLVMJitCode): + def __init__(self, signature): + super().__init__(signature) + + def _create_param_dict(self, func_args): + for i, a in enumerate(func_args): + if isinstance(a, IndexedBase): + self.param_dict[a] = (self.fn.args[i], i) + self.fn.args[i].name = str(a) + else: + self.param_dict[a] = (self.fn.args[self.signature.input_arg], + i) + + def _create_function(self, expr): + """Create function body and return LLVM IR""" + bb_entry = self.fn.append_basic_block('entry') + builder = ll.IRBuilder(bb_entry) + + lj = LLVMJitCallbackPrinter(self.module, builder, self.fn, + func_arg_map=self.param_dict) + + ret = self._convert_expr(lj, expr) + + if self.signature.ret_arg: + output_fp_ptr = builder.bitcast(self.fn.args[self.signature.ret_arg], + ll.PointerType(self.fp_type)) + for i, val in enumerate(ret): + index = ll.Constant(ll.IntType(32), i) + output_array_ptr = builder.gep(output_fp_ptr, [index]) + builder.store(val, output_array_ptr) + builder.ret(ll.Constant(ll.IntType(32), 0)) # return success + else: + lj.builder.ret(self._wrap_return(lj, ret)) + + strmod = str(self.module) + return strmod + + +class CodeSignature: + def __init__(self, ret_type): + self.ret_type = ret_type + self.arg_ctypes = [] + + # Input argument array element index + self.input_arg = 0 + + # For the case output value is referenced through a parameter rather + # than the return value + self.ret_arg = None + + +def _llvm_jit_code(args, expr, signature, callback_type): + """Create a native code function from a SymPy expression""" + if callback_type is None: + jit = LLVMJitCode(signature) + else: + jit = LLVMJitCodeCallback(signature) + + jit._create_args(args) + jit._create_function_base() + jit._create_param_dict(args) + strmod = jit._create_function(expr) + if False: + print("LLVM IR") + print(strmod) + fptr = jit._compile_function(strmod) + return fptr + + +@doctest_depends_on(modules=('llvmlite', 'scipy')) +def llvm_callable(args, expr, callback_type=None): + '''Compile function from a SymPy expression + + Expressions are evaluated using double precision arithmetic. + Some single argument math functions (exp, sin, cos, etc.) are supported + in expressions. + + Parameters + ========== + + args : List of Symbol + Arguments to the generated function. Usually the free symbols in + the expression. Currently each one is assumed to convert to + a double precision scalar. + expr : Expr, or (Replacements, Expr) as returned from 'cse' + Expression to compile. + callback_type : string + Create function with signature appropriate to use as a callback. + Currently supported: + 'scipy.integrate' + 'scipy.integrate.test' + 'cubature' + + Returns + ======= + + Compiled function that can evaluate the expression. + + Examples + ======== + + >>> import sympy.printing.llvmjitcode as jit + >>> from sympy.abc import a + >>> e = a*a + a + 1 + >>> e1 = jit.llvm_callable([a], e) + >>> e.subs(a, 1.1) # Evaluate via substitution + 3.31000000000000 + >>> e1(1.1) # Evaluate using JIT-compiled code + 3.3100000000000005 + + + Callbacks for integration functions can be JIT compiled. + + >>> import sympy.printing.llvmjitcode as jit + >>> from sympy.abc import a + >>> from sympy import integrate + >>> from scipy.integrate import quad + >>> e = a*a + >>> e1 = jit.llvm_callable([a], e, callback_type='scipy.integrate') + >>> integrate(e, (a, 0.0, 2.0)) + 2.66666666666667 + >>> quad(e1, 0.0, 2.0)[0] + 2.66666666666667 + + The 'cubature' callback is for the Python wrapper around the + cubature package ( https://github.com/saullocastro/cubature ) + and ( http://ab-initio.mit.edu/wiki/index.php/Cubature ) + + There are two signatures for the SciPy integration callbacks. + The first ('scipy.integrate') is the function to be passed to the + integration routine, and will pass the signature checks. + The second ('scipy.integrate.test') is only useful for directly calling + the function using ctypes variables. It will not pass the signature checks + for scipy.integrate. + + The return value from the cse module can also be compiled. This + can improve the performance of the compiled function. If multiple + expressions are given to cse, the compiled function returns a tuple. + The 'cubature' callback handles multiple expressions (set `fdim` + to match in the integration call.) + + >>> import sympy.printing.llvmjitcode as jit + >>> from sympy import cse + >>> from sympy.abc import x,y + >>> e1 = x*x + y*y + >>> e2 = 4*(x*x + y*y) + 8.0 + >>> after_cse = cse([e1,e2]) + >>> after_cse + ([(x0, x**2), (x1, y**2)], [x0 + x1, 4*x0 + 4*x1 + 8.0]) + >>> j1 = jit.llvm_callable([x,y], after_cse) + >>> j1(1.0, 2.0) + (5.0, 28.0) + ''' + + if not llvmlite: + raise ImportError("llvmlite is required for llvmjitcode") + + signature = CodeSignature(ctypes.py_object) + + arg_ctypes = [] + if callback_type is None: + for _ in args: + arg_ctype = ctypes.c_double + arg_ctypes.append(arg_ctype) + elif callback_type in ('scipy.integrate', 'scipy.integrate.test'): + signature.ret_type = ctypes.c_double + arg_ctypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_double)] + arg_ctypes_formal = [ctypes.c_int, ctypes.c_double] + signature.input_arg = 1 + elif callback_type == 'cubature': + arg_ctypes = [ctypes.c_int, + ctypes.POINTER(ctypes.c_double), + ctypes.c_void_p, + ctypes.c_int, + ctypes.POINTER(ctypes.c_double) + ] + signature.ret_type = ctypes.c_int + signature.input_arg = 1 + signature.ret_arg = 4 + else: + raise ValueError("Unknown callback type: %s" % callback_type) + + signature.arg_ctypes = arg_ctypes + + fptr = _llvm_jit_code(args, expr, signature, callback_type) + + if callback_type and callback_type == 'scipy.integrate': + arg_ctypes = arg_ctypes_formal + + # PYFUNCTYPE holds the GIL which is needed to prevent a segfault when + # calling PyFloat_FromDouble on Python 3.10. Probably it is better to use + # ctypes.c_double when returning a float rather than using ctypes.py_object + # and returning a PyFloat from inside the jitted function (i.e. let ctypes + # handle the conversion from double to PyFloat). + if signature.ret_type == ctypes.py_object: + FUNCTYPE = ctypes.PYFUNCTYPE + else: + FUNCTYPE = ctypes.CFUNCTYPE + + cfunc = FUNCTYPE(signature.ret_type, *arg_ctypes)(fptr) + return cfunc diff --git a/phivenv/Lib/site-packages/sympy/printing/maple.py b/phivenv/Lib/site-packages/sympy/printing/maple.py new file mode 100644 index 0000000000000000000000000000000000000000..2c937cd262ab7f3ee5f32b3f4b5eb5633bc6bb3c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/maple.py @@ -0,0 +1,311 @@ +""" +Maple code printer + +The MapleCodePrinter converts single SymPy expressions into single +Maple expressions, using the functions defined in the Maple objects where possible. + + +FIXME: This module is still under actively developed. Some functions may be not completed. +""" + +from sympy.core import S +from sympy.core.numbers import Integer, IntegerConstant, equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE + +import sympy + +_known_func_same_name = ( + 'sin', 'cos', 'tan', 'sec', 'csc', 'cot', 'sinh', 'cosh', 'tanh', 'sech', + 'csch', 'coth', 'exp', 'floor', 'factorial', 'bernoulli', 'euler', + 'fibonacci', 'gcd', 'lcm', 'conjugate', 'Ci', 'Chi', 'Ei', 'Li', 'Si', 'Shi', + 'erf', 'erfc', 'harmonic', 'LambertW', + 'sqrt', # For automatic rewrites +) + +known_functions = { + # SymPy -> Maple + 'Abs': 'abs', + 'log': 'ln', + 'asin': 'arcsin', + 'acos': 'arccos', + 'atan': 'arctan', + 'asec': 'arcsec', + 'acsc': 'arccsc', + 'acot': 'arccot', + 'asinh': 'arcsinh', + 'acosh': 'arccosh', + 'atanh': 'arctanh', + 'asech': 'arcsech', + 'acsch': 'arccsch', + 'acoth': 'arccoth', + 'ceiling': 'ceil', + 'Max' : 'max', + 'Min' : 'min', + + 'factorial2': 'doublefactorial', + 'RisingFactorial': 'pochhammer', + 'besseli': 'BesselI', + 'besselj': 'BesselJ', + 'besselk': 'BesselK', + 'bessely': 'BesselY', + 'hankelh1': 'HankelH1', + 'hankelh2': 'HankelH2', + 'airyai': 'AiryAi', + 'airybi': 'AiryBi', + 'appellf1': 'AppellF1', + 'fresnelc': 'FresnelC', + 'fresnels': 'FresnelS', + 'lerchphi' : 'LerchPhi', +} + +for _func in _known_func_same_name: + known_functions[_func] = _func + +number_symbols = { + # SymPy -> Maple + S.Pi: 'Pi', + S.Exp1: 'exp(1)', + S.Catalan: 'Catalan', + S.EulerGamma: 'gamma', + S.GoldenRatio: '(1/2 + (1/2)*sqrt(5))' +} + +spec_relational_ops = { + # SymPy -> Maple + '==': '=', + '!=': '<>' +} + +not_supported_symbol = [ + S.ComplexInfinity +] + +class MapleCodePrinter(CodePrinter): + """ + Printer which converts a SymPy expression into a maple code. + """ + printmethod = "_maple" + language = "maple" + + _operators = { + 'and': 'and', + 'or': 'or', + 'not': 'not ', + } + + _default_settings = dict(CodePrinter._default_settings, **{ + 'inline': True, + 'allow_unknown_functions': True, + }) + + def __init__(self, settings=None): + if settings is None: + settings = {} + super().__init__(settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + def _get_statement(self, codestring): + return "%s;" % codestring + + def _get_comment(self, text): + return "# {}".format(text) + + def _declare_number_const(self, name, value): + return "{} := {};".format(name, + value.evalf(self._settings['precision'])) + + def _format_code(self, lines): + return lines + + def _print_tuple(self, expr): + return self._print(list(expr)) + + def _print_Tuple(self, expr): + return self._print(list(expr)) + + def _print_Assignment(self, expr): + lhs = self._print(expr.lhs) + rhs = self._print(expr.rhs) + return "{lhs} := {rhs}".format(lhs=lhs, rhs=rhs) + + def _print_Pow(self, expr, **kwargs): + PREC = precedence(expr) + if equal_valued(expr.exp, -1): + return '1/%s' % (self.parenthesize(expr.base, PREC)) + elif equal_valued(expr.exp, 0.5): + return 'sqrt(%s)' % self._print(expr.base) + elif equal_valued(expr.exp, -0.5): + return '1/sqrt(%s)' % self._print(expr.base) + else: + return '{base}^{exp}'.format( + base=self.parenthesize(expr.base, PREC), + exp=self.parenthesize(expr.exp, PREC)) + + def _print_Piecewise(self, expr): + if (expr.args[-1].cond is not True) and (expr.args[-1].cond != S.BooleanTrue): + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + _coup_list = [ + ("{c}, {e}".format(c=self._print(c), + e=self._print(e)) if c is not True and c is not S.BooleanTrue else "{e}".format( + e=self._print(e))) + for e, c in expr.args] + _inbrace = ', '.join(_coup_list) + return 'piecewise({_inbrace})'.format(_inbrace=_inbrace) + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + return "{p}/{q}".format(p=str(p), q=str(q)) + + def _print_Relational(self, expr): + PREC=precedence(expr) + lhs_code = self.parenthesize(expr.lhs, PREC) + rhs_code = self.parenthesize(expr.rhs, PREC) + op = expr.rel_op + if op in spec_relational_ops: + op = spec_relational_ops[op] + return "{lhs} {rel_op} {rhs}".format(lhs=lhs_code, rel_op=op, rhs=rhs_code) + + def _print_NumberSymbol(self, expr): + return number_symbols[expr] + + def _print_NegativeInfinity(self, expr): + return '-infinity' + + def _print_Infinity(self, expr): + return 'infinity' + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + def _print_bool(self, expr): + return 'true' if expr else 'false' + + def _print_NaN(self, expr): + return 'undefined' + + def _get_matrix(self, expr, sparse=False): + if S.Zero in expr.shape: + _strM = 'Matrix([], storage = {storage})'.format( + storage='sparse' if sparse else 'rectangular') + else: + _strM = 'Matrix({list}, storage = {storage})'.format( + list=self._print(expr.tolist()), + storage='sparse' if sparse else 'rectangular') + return _strM + + def _print_MatrixElement(self, expr): + return "{parent}[{i_maple}, {j_maple}]".format( + parent=self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True), + i_maple=self._print(expr.i + 1), + j_maple=self._print(expr.j + 1)) + + def _print_MatrixBase(self, expr): + return self._get_matrix(expr, sparse=False) + + def _print_SparseRepMatrix(self, expr): + return self._get_matrix(expr, sparse=True) + + def _print_Identity(self, expr): + if isinstance(expr.rows, (Integer, IntegerConstant)): + return self._print(sympy.SparseMatrix(expr)) + else: + return "Matrix({var_size}, shape = identity)".format(var_size=self._print(expr.rows)) + + def _print_MatMul(self, expr): + PREC=precedence(expr) + _fact_list = list(expr.args) + _const = None + if not isinstance(_fact_list[0], (sympy.MatrixBase, sympy.MatrixExpr, + sympy.MatrixSlice, sympy.MatrixSymbol)): + _const, _fact_list = _fact_list[0], _fact_list[1:] + + if _const is None or _const == 1: + return '.'.join(self.parenthesize(_m, PREC) for _m in _fact_list) + else: + return '{c}*{m}'.format(c=_const, m='.'.join(self.parenthesize(_m, PREC) for _m in _fact_list)) + + def _print_MatPow(self, expr): + # This function requires LinearAlgebra Function in Maple + return 'MatrixPower({A}, {n})'.format(A=self._print(expr.base), n=self._print(expr.exp)) + + def _print_HadamardProduct(self, expr): + PREC = precedence(expr) + _fact_list = list(expr.args) + return '*'.join(self.parenthesize(_m, PREC) for _m in _fact_list) + + def _print_Derivative(self, expr): + _f, (_var, _order) = expr.args + + if _order != 1: + _second_arg = '{var}${order}'.format(var=self._print(_var), + order=self._print(_order)) + else: + _second_arg = '{var}'.format(var=self._print(_var)) + return 'diff({func_expr}, {sec_arg})'.format(func_expr=self._print(_f), sec_arg=_second_arg) + + +def maple_code(expr, assign_to=None, **settings): + r"""Converts ``expr`` to a string of Maple code. + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This can be helpful for + expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=16]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, cfunction_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + inline: bool, optional + If True, we try to create single-statement code instead of multiple + statements. [default=True]. + + """ + return MapleCodePrinter(settings).doprint(expr, assign_to) + + +def print_maple_code(expr, **settings): + """Prints the Maple representation of the given expression. + + See :func:`maple_code` for the meaning of the optional arguments. + + Examples + ======== + + >>> from sympy import print_maple_code, symbols + >>> x, y = symbols('x y') + >>> print_maple_code(x, assign_to=y) + y := x + """ + print(maple_code(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/mathematica.py b/phivenv/Lib/site-packages/sympy/printing/mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..064925ec1a9a477d7509110c57311239bac9fcaf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/mathematica.py @@ -0,0 +1,353 @@ +""" +Mathematica code printer +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import Basic, Expr, Float +from sympy.core.sorting import default_sort_key + +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence + +# Used in MCodePrinter._print_Function(self) +known_functions = { + "exp": [(lambda x: True, "Exp")], + "log": [(lambda x: True, "Log")], + "sin": [(lambda x: True, "Sin")], + "cos": [(lambda x: True, "Cos")], + "tan": [(lambda x: True, "Tan")], + "cot": [(lambda x: True, "Cot")], + "sec": [(lambda x: True, "Sec")], + "csc": [(lambda x: True, "Csc")], + "asin": [(lambda x: True, "ArcSin")], + "acos": [(lambda x: True, "ArcCos")], + "atan": [(lambda x: True, "ArcTan")], + "acot": [(lambda x: True, "ArcCot")], + "asec": [(lambda x: True, "ArcSec")], + "acsc": [(lambda x: True, "ArcCsc")], + "sinh": [(lambda x: True, "Sinh")], + "cosh": [(lambda x: True, "Cosh")], + "tanh": [(lambda x: True, "Tanh")], + "coth": [(lambda x: True, "Coth")], + "sech": [(lambda x: True, "Sech")], + "csch": [(lambda x: True, "Csch")], + "asinh": [(lambda x: True, "ArcSinh")], + "acosh": [(lambda x: True, "ArcCosh")], + "atanh": [(lambda x: True, "ArcTanh")], + "acoth": [(lambda x: True, "ArcCoth")], + "asech": [(lambda x: True, "ArcSech")], + "acsch": [(lambda x: True, "ArcCsch")], + "sinc": [(lambda x: True, "Sinc")], + "conjugate": [(lambda x: True, "Conjugate")], + "Max": [(lambda *x: True, "Max")], + "Min": [(lambda *x: True, "Min")], + "erf": [(lambda x: True, "Erf")], + "erf2": [(lambda *x: True, "Erf")], + "erfc": [(lambda x: True, "Erfc")], + "erfi": [(lambda x: True, "Erfi")], + "erfinv": [(lambda x: True, "InverseErf")], + "erfcinv": [(lambda x: True, "InverseErfc")], + "erf2inv": [(lambda *x: True, "InverseErf")], + "expint": [(lambda *x: True, "ExpIntegralE")], + "Ei": [(lambda x: True, "ExpIntegralEi")], + "fresnelc": [(lambda x: True, "FresnelC")], + "fresnels": [(lambda x: True, "FresnelS")], + "gamma": [(lambda x: True, "Gamma")], + "uppergamma": [(lambda *x: True, "Gamma")], + "polygamma": [(lambda *x: True, "PolyGamma")], + "loggamma": [(lambda x: True, "LogGamma")], + "beta": [(lambda *x: True, "Beta")], + "Ci": [(lambda x: True, "CosIntegral")], + "Si": [(lambda x: True, "SinIntegral")], + "Chi": [(lambda x: True, "CoshIntegral")], + "Shi": [(lambda x: True, "SinhIntegral")], + "li": [(lambda x: True, "LogIntegral")], + "factorial": [(lambda x: True, "Factorial")], + "factorial2": [(lambda x: True, "Factorial2")], + "subfactorial": [(lambda x: True, "Subfactorial")], + "catalan": [(lambda x: True, "CatalanNumber")], + "harmonic": [(lambda *x: True, "HarmonicNumber")], + "lucas": [(lambda x: True, "LucasL")], + "RisingFactorial": [(lambda *x: True, "Pochhammer")], + "FallingFactorial": [(lambda *x: True, "FactorialPower")], + "laguerre": [(lambda *x: True, "LaguerreL")], + "assoc_laguerre": [(lambda *x: True, "LaguerreL")], + "hermite": [(lambda *x: True, "HermiteH")], + "jacobi": [(lambda *x: True, "JacobiP")], + "gegenbauer": [(lambda *x: True, "GegenbauerC")], + "chebyshevt": [(lambda *x: True, "ChebyshevT")], + "chebyshevu": [(lambda *x: True, "ChebyshevU")], + "legendre": [(lambda *x: True, "LegendreP")], + "assoc_legendre": [(lambda *x: True, "LegendreP")], + "mathieuc": [(lambda *x: True, "MathieuC")], + "mathieus": [(lambda *x: True, "MathieuS")], + "mathieucprime": [(lambda *x: True, "MathieuCPrime")], + "mathieusprime": [(lambda *x: True, "MathieuSPrime")], + "stieltjes": [(lambda x: True, "StieltjesGamma")], + "elliptic_e": [(lambda *x: True, "EllipticE")], + "elliptic_f": [(lambda *x: True, "EllipticE")], + "elliptic_k": [(lambda x: True, "EllipticK")], + "elliptic_pi": [(lambda *x: True, "EllipticPi")], + "zeta": [(lambda *x: True, "Zeta")], + "dirichlet_eta": [(lambda x: True, "DirichletEta")], + "riemann_xi": [(lambda x: True, "RiemannXi")], + "besseli": [(lambda *x: True, "BesselI")], + "besselj": [(lambda *x: True, "BesselJ")], + "besselk": [(lambda *x: True, "BesselK")], + "bessely": [(lambda *x: True, "BesselY")], + "hankel1": [(lambda *x: True, "HankelH1")], + "hankel2": [(lambda *x: True, "HankelH2")], + "airyai": [(lambda x: True, "AiryAi")], + "airybi": [(lambda x: True, "AiryBi")], + "airyaiprime": [(lambda x: True, "AiryAiPrime")], + "airybiprime": [(lambda x: True, "AiryBiPrime")], + "polylog": [(lambda *x: True, "PolyLog")], + "lerchphi": [(lambda *x: True, "LerchPhi")], + "gcd": [(lambda *x: True, "GCD")], + "lcm": [(lambda *x: True, "LCM")], + "jn": [(lambda *x: True, "SphericalBesselJ")], + "yn": [(lambda *x: True, "SphericalBesselY")], + "hyper": [(lambda *x: True, "HypergeometricPFQ")], + "meijerg": [(lambda *x: True, "MeijerG")], + "appellf1": [(lambda *x: True, "AppellF1")], + "DiracDelta": [(lambda x: True, "DiracDelta")], + "Heaviside": [(lambda x: True, "HeavisideTheta")], + "KroneckerDelta": [(lambda *x: True, "KroneckerDelta")], + "sqrt": [(lambda x: True, "Sqrt")], # For automatic rewrites +} + + +class MCodePrinter(CodePrinter): + """A printer to convert Python expressions to + strings of the Wolfram's Mathematica code + """ + printmethod = "_mcode" + language = "Wolfram Language" + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 15, + 'user_functions': {}, + }) + + _number_symbols: set[tuple[Expr, Float]] = set() + _not_supported: set[Basic] = set() + + def __init__(self, settings={}): + """Register function mappings supplied by user""" + CodePrinter.__init__(self, settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}).copy() + for k, v in userfuncs.items(): + if not isinstance(v, list): + userfuncs[k] = [(lambda *x: True, v)] + self.known_functions.update(userfuncs) + + def _format_code(self, lines): + return lines + + def _print_Pow(self, expr): + PREC = precedence(expr) + return '%s^%s' % (self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC)) + + def _print_Mul(self, expr): + PREC = precedence(expr) + c, nc = expr.args_cnc() + res = super()._print_Mul(expr.func(*c)) + if nc: + res += '*' + res += '**'.join(self.parenthesize(a, PREC) for a in nc) + return res + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + # Primitive numbers + def _print_Zero(self, expr): + return '0' + + def _print_One(self, expr): + return '1' + + def _print_NegativeOne(self, expr): + return '-1' + + def _print_Half(self, expr): + return '1/2' + + def _print_ImaginaryUnit(self, expr): + return 'I' + + + # Infinity and invalid numbers + def _print_Infinity(self, expr): + return 'Infinity' + + def _print_NegativeInfinity(self, expr): + return '-Infinity' + + def _print_ComplexInfinity(self, expr): + return 'ComplexInfinity' + + def _print_NaN(self, expr): + return 'Indeterminate' + + + # Mathematical constants + def _print_Exp1(self, expr): + return 'E' + + def _print_Pi(self, expr): + return 'Pi' + + def _print_GoldenRatio(self, expr): + return 'GoldenRatio' + + def _print_TribonacciConstant(self, expr): + expanded = expr.expand(func=True) + PREC = precedence(expr) + return self.parenthesize(expanded, PREC) + + def _print_EulerGamma(self, expr): + return 'EulerGamma' + + def _print_Catalan(self, expr): + return 'Catalan' + + + def _print_list(self, expr): + return '{' + ', '.join(self.doprint(a) for a in expr) + '}' + _print_tuple = _print_list + _print_Tuple = _print_list + + def _print_ImmutableDenseMatrix(self, expr): + return self.doprint(expr.tolist()) + + def _print_ImmutableSparseMatrix(self, expr): + + def print_rule(pos, val): + return '{} -> {}'.format( + self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val)) + + def print_data(): + items = sorted(expr.todok().items(), key=default_sort_key) + return '{' + \ + ', '.join(print_rule(k, v) for k, v in items) + \ + '}' + + def print_dims(): + return self.doprint(expr.shape) + + return 'SparseArray[{}, {}]'.format(print_data(), print_dims()) + + def _print_ImmutableDenseNDimArray(self, expr): + return self.doprint(expr.tolist()) + + def _print_ImmutableSparseNDimArray(self, expr): + def print_string_list(string_list): + return '{' + ', '.join(a for a in string_list) + '}' + + def to_mathematica_index(*args): + """Helper function to change Python style indexing to + Pathematica indexing. + + Python indexing (0, 1 ... n-1) + -> Mathematica indexing (1, 2 ... n) + """ + return tuple(i + 1 for i in args) + + def print_rule(pos, val): + """Helper function to print a rule of Mathematica""" + return '{} -> {}'.format(self.doprint(pos), self.doprint(val)) + + def print_data(): + """Helper function to print data part of Mathematica + sparse array. + + It uses the fourth notation ``SparseArray[data,{d1,d2,...}]`` + from + https://reference.wolfram.com/language/ref/SparseArray.html + + ``data`` must be formatted with rule. + """ + return print_string_list( + [print_rule( + to_mathematica_index(*(expr._get_tuple_index(key))), + value) + for key, value in sorted(expr._sparse_array.items())] + ) + + def print_dims(): + """Helper function to print dimensions part of Mathematica + sparse array. + + It uses the fourth notation ``SparseArray[data,{d1,d2,...}]`` + from + https://reference.wolfram.com/language/ref/SparseArray.html + """ + return self.doprint(expr.shape) + + return 'SparseArray[{}, {}]'.format(print_data(), print_dims()) + + def _print_Function(self, expr): + if expr.func.__name__ in self.known_functions: + cond_mfunc = self.known_functions[expr.func.__name__] + for cond, mfunc in cond_mfunc: + if cond(*expr.args): + return "%s[%s]" % (mfunc, self.stringify(expr.args, ", ")) + elif expr.func.__name__ in self._rewriteable_functions: + # Simple rewrite to supported function possible + target_f, required_fs = self._rewriteable_functions[expr.func.__name__] + if self._can_print(target_f) and all(self._can_print(f) for f in required_fs): + return self._print(expr.rewrite(target_f)) + return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ") + + _print_MinMaxBase = _print_Function + + def _print_LambertW(self, expr): + if len(expr.args) == 1: + return "ProductLog[{}]".format(self._print(expr.args[0])) + return "ProductLog[{}, {}]".format( + self._print(expr.args[1]), self._print(expr.args[0])) + + def _print_atan2(self, expr): + return "ArcTan[{}, {}]".format( + self._print(expr.args[1]), self._print(expr.args[0])) + + def _print_Integral(self, expr): + if len(expr.variables) == 1 and not expr.limits[0][1:]: + args = [expr.args[0], expr.variables[0]] + else: + args = expr.args + return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]" + + def _print_Sum(self, expr): + return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]" + + def _print_Derivative(self, expr): + dexpr = expr.expr + dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count] + return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]" + + + def _get_comment(self, text): + return "(* {} *)".format(text) + + +def mathematica_code(expr, **settings): + r"""Converts an expr to a string of the Wolfram Mathematica code + + Examples + ======== + + >>> from sympy import mathematica_code as mcode, symbols, sin + >>> x = symbols('x') + >>> mcode(sin(x).series(x).removeO()) + '(1/120)*x^5 - 1/6*x^3 + x' + """ + return MCodePrinter(settings).doprint(expr) diff --git a/phivenv/Lib/site-packages/sympy/printing/mathml.py b/phivenv/Lib/site-packages/sympy/printing/mathml.py new file mode 100644 index 0000000000000000000000000000000000000000..4dff74cd64b17036d3ff2a766253c9af850f088d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/mathml.py @@ -0,0 +1,2157 @@ +""" +A MathML printer. +""" + +from __future__ import annotations +from typing import Any + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.printing.precedence import \ + precedence_traditional, PRECEDENCE, PRECEDENCE_TRADITIONAL +from sympy.printing.pretty.pretty_symbology import greek_unicode +from sympy.printing.printer import Printer, print_function + +from mpmath.libmp import prec_to_dps, repr_dps, to_str as mlib_to_str + + +class MathMLPrinterBase(Printer): + """Contains common code required for MathMLContentPrinter and + MathMLPresentationPrinter. + """ + + _default_settings: dict[str, Any] = { + "order": None, + "encoding": "utf-8", + "fold_frac_powers": False, + "fold_func_brackets": False, + "fold_short_frac": None, + "inv_trig_style": "abbreviated", + "ln_notation": False, + "long_frac_ratio": None, + "mat_delim": "[", + "mat_symbol_style": "plain", + "mul_symbol": None, + "root_notation": True, + "symbol_names": {}, + "mul_symbol_mathml_numbers": '·', + "disable_split_super_sub": False, + } + + def __init__(self, settings=None): + Printer.__init__(self, settings) + from xml.dom.minidom import Document, Text + + self.dom = Document() + + # Workaround to allow strings to remain unescaped + # Based on + # https://stackoverflow.com/questions/38015864/python-xml-dom-minidom-\ + # please-dont-escape-my-strings/38041194 + class RawText(Text): + def writexml(self, writer, indent='', addindent='', newl=''): + if self.data: + writer.write('{}{}{}'.format(indent, self.data, newl)) + + def createRawTextNode(data): + r = RawText() + r.data = data + r.ownerDocument = self.dom + return r + + self.dom.createTextNode = createRawTextNode + + def doprint(self, expr): + """ + Prints the expression as MathML. + """ + mathML = Printer._print(self, expr) + unistr = mathML.toxml() + xmlbstr = unistr.encode('ascii', 'xmlcharrefreplace') + res = xmlbstr.decode() + return res + + def _split_super_sub(self, name): + if self._settings["disable_split_super_sub"]: + return (name, [], []) + else: + return split_super_sub(name) + + +class MathMLContentPrinter(MathMLPrinterBase): + """Prints an expression to the Content MathML markup language. + + References: https://www.w3.org/TR/MathML2/chapter4.html + """ + printmethod = "_mathml_content" + + def mathml_tag(self, e): + """Returns the MathML tag for an expression.""" + translate = { + 'Add': 'plus', + 'Mul': 'times', + 'Derivative': 'diff', + 'Number': 'cn', + 'int': 'cn', + 'Pow': 'power', + 'Max': 'max', + 'Min': 'min', + 'Abs': 'abs', + 'And': 'and', + 'Or': 'or', + 'Xor': 'xor', + 'Not': 'not', + 'Implies': 'implies', + 'Symbol': 'ci', + 'MatrixSymbol': 'ci', + 'RandomSymbol': 'ci', + 'Integral': 'int', + 'Sum': 'sum', + 'sin': 'sin', + 'cos': 'cos', + 'tan': 'tan', + 'cot': 'cot', + 'csc': 'csc', + 'sec': 'sec', + 'sinh': 'sinh', + 'cosh': 'cosh', + 'tanh': 'tanh', + 'coth': 'coth', + 'csch': 'csch', + 'sech': 'sech', + 'asin': 'arcsin', + 'asinh': 'arcsinh', + 'acos': 'arccos', + 'acosh': 'arccosh', + 'atan': 'arctan', + 'atanh': 'arctanh', + 'atan2': 'arctan', + 'acot': 'arccot', + 'acoth': 'arccoth', + 'asec': 'arcsec', + 'asech': 'arcsech', + 'acsc': 'arccsc', + 'acsch': 'arccsch', + 'log': 'ln', + 'Equality': 'eq', + 'Unequality': 'neq', + 'GreaterThan': 'geq', + 'LessThan': 'leq', + 'StrictGreaterThan': 'gt', + 'StrictLessThan': 'lt', + 'Union': 'union', + 'Intersection': 'intersect', + } + + for cls in e.__class__.__mro__: + n = cls.__name__ + if n in translate: + return translate[n] + # Not found in the MRO set + n = e.__class__.__name__ + return n.lower() + + def _print_Mul(self, expr): + + if expr.could_extract_minus_sign(): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('minus')) + x.appendChild(self._print_Mul(-expr)) + return x + + from sympy.simplify import fraction + numer, denom = fraction(expr) + + if denom is not S.One: + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('divide')) + x.appendChild(self._print(numer)) + x.appendChild(self._print(denom)) + return x + + coeff, terms = expr.as_coeff_mul() + if coeff is S.One and len(terms) == 1: + # XXX since the negative coefficient has been handled, I don't + # think a coeff of 1 can remain + return self._print(terms[0]) + + if self.order != 'old': + terms = Mul._from_args(terms).as_ordered_factors() + + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('times')) + if coeff != 1: + x.appendChild(self._print(coeff)) + for term in terms: + x.appendChild(self._print(term)) + return x + + def _print_Add(self, expr, order=None): + args = self._as_ordered_terms(expr, order=order) + lastProcessed = self._print(args[0]) + plusNodes = [] + for arg in args[1:]: + if arg.could_extract_minus_sign(): + # use minus + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('minus')) + x.appendChild(lastProcessed) + x.appendChild(self._print(-arg)) + # invert expression since this is now minused + lastProcessed = x + if arg == args[-1]: + plusNodes.append(lastProcessed) + else: + plusNodes.append(lastProcessed) + lastProcessed = self._print(arg) + if arg == args[-1]: + plusNodes.append(self._print(arg)) + if len(plusNodes) == 1: + return lastProcessed + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('plus')) + while plusNodes: + x.appendChild(plusNodes.pop(0)) + return x + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + root = self.dom.createElement('piecewise') + for i, (e, c) in enumerate(expr.args): + if i == len(expr.args) - 1 and c == True: + piece = self.dom.createElement('otherwise') + piece.appendChild(self._print(e)) + else: + piece = self.dom.createElement('piece') + piece.appendChild(self._print(e)) + piece.appendChild(self._print(c)) + root.appendChild(piece) + return root + + def _print_MatrixBase(self, m): + x = self.dom.createElement('matrix') + for i in range(m.rows): + x_r = self.dom.createElement('matrixrow') + for j in range(m.cols): + x_r.appendChild(self._print(m[i, j])) + x.appendChild(x_r) + return x + + def _print_Rational(self, e): + if e.q == 1: + # don't divide + x = self.dom.createElement('cn') + x.appendChild(self.dom.createTextNode(str(e.p))) + return x + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('divide')) + # numerator + xnum = self.dom.createElement('cn') + xnum.appendChild(self.dom.createTextNode(str(e.p))) + # denominator + xdenom = self.dom.createElement('cn') + xdenom.appendChild(self.dom.createTextNode(str(e.q))) + x.appendChild(xnum) + x.appendChild(xdenom) + return x + + def _print_Limit(self, e): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement(self.mathml_tag(e))) + + x_1 = self.dom.createElement('bvar') + x_2 = self.dom.createElement('lowlimit') + x_1.appendChild(self._print(e.args[1])) + x_2.appendChild(self._print(e.args[2])) + + x.appendChild(x_1) + x.appendChild(x_2) + x.appendChild(self._print(e.args[0])) + return x + + def _print_ImaginaryUnit(self, e): + return self.dom.createElement('imaginaryi') + + def _print_EulerGamma(self, e): + return self.dom.createElement('eulergamma') + + def _print_GoldenRatio(self, e): + """We use unicode #x3c6 for Greek letter phi as defined here + https://www.w3.org/2003/entities/2007doc/isogrk1.html""" + x = self.dom.createElement('cn') + x.appendChild(self.dom.createTextNode("\N{GREEK SMALL LETTER PHI}")) + return x + + def _print_Exp1(self, e): + return self.dom.createElement('exponentiale') + + def _print_Pi(self, e): + return self.dom.createElement('pi') + + def _print_Infinity(self, e): + return self.dom.createElement('infinity') + + def _print_NaN(self, e): + return self.dom.createElement('notanumber') + + def _print_EmptySet(self, e): + return self.dom.createElement('emptyset') + + def _print_BooleanTrue(self, e): + return self.dom.createElement('true') + + def _print_BooleanFalse(self, e): + return self.dom.createElement('false') + + def _print_NegativeInfinity(self, e): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('minus')) + x.appendChild(self.dom.createElement('infinity')) + return x + + def _print_Integral(self, e): + def lime_recur(limits): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement(self.mathml_tag(e))) + bvar_elem = self.dom.createElement('bvar') + bvar_elem.appendChild(self._print(limits[0][0])) + x.appendChild(bvar_elem) + + if len(limits[0]) == 3: + low_elem = self.dom.createElement('lowlimit') + low_elem.appendChild(self._print(limits[0][1])) + x.appendChild(low_elem) + up_elem = self.dom.createElement('uplimit') + up_elem.appendChild(self._print(limits[0][2])) + x.appendChild(up_elem) + if len(limits[0]) == 2: + up_elem = self.dom.createElement('uplimit') + up_elem.appendChild(self._print(limits[0][1])) + x.appendChild(up_elem) + if len(limits) == 1: + x.appendChild(self._print(e.function)) + else: + x.appendChild(lime_recur(limits[1:])) + return x + + limits = list(e.limits) + limits.reverse() + return lime_recur(limits) + + def _print_Sum(self, e): + # Printer can be shared because Sum and Integral have the + # same internal representation. + return self._print_Integral(e) + + def _print_Symbol(self, sym): + ci = self.dom.createElement(self.mathml_tag(sym)) + + def join(items): + if len(items) > 1: + mrow = self.dom.createElement('mml:mrow') + for i, item in enumerate(items): + if i > 0: + mo = self.dom.createElement('mml:mo') + mo.appendChild(self.dom.createTextNode(" ")) + mrow.appendChild(mo) + mi = self.dom.createElement('mml:mi') + mi.appendChild(self.dom.createTextNode(item)) + mrow.appendChild(mi) + return mrow + else: + mi = self.dom.createElement('mml:mi') + mi.appendChild(self.dom.createTextNode(items[0])) + return mi + + # translate name, supers and subs to unicode characters + def translate(s): + if s in greek_unicode: + return greek_unicode.get(s) + else: + return s + + name, supers, subs = self._split_super_sub(sym.name) + name = translate(name) + supers = [translate(sup) for sup in supers] + subs = [translate(sub) for sub in subs] + + mname = self.dom.createElement('mml:mi') + mname.appendChild(self.dom.createTextNode(name)) + if not supers: + if not subs: + ci.appendChild(self.dom.createTextNode(name)) + else: + msub = self.dom.createElement('mml:msub') + msub.appendChild(mname) + msub.appendChild(join(subs)) + ci.appendChild(msub) + else: + if not subs: + msup = self.dom.createElement('mml:msup') + msup.appendChild(mname) + msup.appendChild(join(supers)) + ci.appendChild(msup) + else: + msubsup = self.dom.createElement('mml:msubsup') + msubsup.appendChild(mname) + msubsup.appendChild(join(subs)) + msubsup.appendChild(join(supers)) + ci.appendChild(msubsup) + return ci + + _print_MatrixSymbol = _print_Symbol + _print_RandomSymbol = _print_Symbol + + def _print_Pow(self, e): + # Here we use root instead of power if the exponent is the reciprocal + # of an integer + if (self._settings['root_notation'] and e.exp.is_Rational + and e.exp.p == 1): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('root')) + if e.exp.q != 2: + xmldeg = self.dom.createElement('degree') + xmlcn = self.dom.createElement('cn') + xmlcn.appendChild(self.dom.createTextNode(str(e.exp.q))) + xmldeg.appendChild(xmlcn) + x.appendChild(xmldeg) + x.appendChild(self._print(e.base)) + return x + + x = self.dom.createElement('apply') + x_1 = self.dom.createElement(self.mathml_tag(e)) + x.appendChild(x_1) + x.appendChild(self._print(e.base)) + x.appendChild(self._print(e.exp)) + return x + + def _print_Number(self, e): + x = self.dom.createElement(self.mathml_tag(e)) + x.appendChild(self.dom.createTextNode(str(e))) + return x + + def _print_Float(self, e): + x = self.dom.createElement(self.mathml_tag(e)) + repr_e = mlib_to_str(e._mpf_, repr_dps(e._prec)) + x.appendChild(self.dom.createTextNode(repr_e)) + return x + + def _print_Derivative(self, e): + x = self.dom.createElement('apply') + diff_symbol = self.mathml_tag(e) + if requires_partial(e.expr): + diff_symbol = 'partialdiff' + x.appendChild(self.dom.createElement(diff_symbol)) + x_1 = self.dom.createElement('bvar') + + for sym, times in reversed(e.variable_count): + x_1.appendChild(self._print(sym)) + if times > 1: + degree = self.dom.createElement('degree') + degree.appendChild(self._print(sympify(times))) + x_1.appendChild(degree) + + x.appendChild(x_1) + x.appendChild(self._print(e.expr)) + return x + + def _print_Function(self, e): + x = self.dom.createElement("apply") + x.appendChild(self.dom.createElement(self.mathml_tag(e))) + for arg in e.args: + x.appendChild(self._print(arg)) + return x + + def _print_Basic(self, e): + x = self.dom.createElement(self.mathml_tag(e)) + for arg in e.args: + x.appendChild(self._print(arg)) + return x + + def _print_AssocOp(self, e): + x = self.dom.createElement('apply') + x_1 = self.dom.createElement(self.mathml_tag(e)) + x.appendChild(x_1) + for arg in e.args: + x.appendChild(self._print(arg)) + return x + + def _print_Relational(self, e): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement(self.mathml_tag(e))) + x.appendChild(self._print(e.lhs)) + x.appendChild(self._print(e.rhs)) + return x + + def _print_list(self, seq): + """MathML reference for the element: + https://www.w3.org/TR/MathML2/chapter4.html#contm.list""" + dom_element = self.dom.createElement('list') + for item in seq: + dom_element.appendChild(self._print(item)) + return dom_element + + def _print_int(self, p): + dom_element = self.dom.createElement(self.mathml_tag(p)) + dom_element.appendChild(self.dom.createTextNode(str(p))) + return dom_element + + _print_Implies = _print_AssocOp + _print_Not = _print_AssocOp + _print_Xor = _print_AssocOp + + def _print_FiniteSet(self, e): + x = self.dom.createElement('set') + for arg in e.args: + x.appendChild(self._print(arg)) + return x + + def _print_Complement(self, e): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('setdiff')) + for arg in e.args: + x.appendChild(self._print(arg)) + return x + + def _print_ProductSet(self, e): + x = self.dom.createElement('apply') + x.appendChild(self.dom.createElement('cartesianproduct')) + for arg in e.args: + x.appendChild(self._print(arg)) + return x + + def _print_Lambda(self, e): + # MathML reference for the lambda element: + # https://www.w3.org/TR/MathML2/chapter4.html#id.4.2.1.7 + x = self.dom.createElement(self.mathml_tag(e)) + for arg in e.signature: + x_1 = self.dom.createElement('bvar') + x_1.appendChild(self._print(arg)) + x.appendChild(x_1) + x.appendChild(self._print(e.expr)) + return x + + # XXX Symmetric difference is not supported for MathML content printers. + + +class MathMLPresentationPrinter(MathMLPrinterBase): + """Prints an expression to the Presentation MathML markup language. + + References: https://www.w3.org/TR/MathML2/chapter3.html + """ + printmethod = "_mathml_presentation" + + def mathml_tag(self, e): + """Returns the MathML tag for an expression.""" + translate = { + 'Number': 'mn', + 'Limit': '→', + 'Derivative': 'ⅆ', + 'int': 'mn', + 'Symbol': 'mi', + 'Integral': '∫', + 'Sum': '∑', + 'sin': 'sin', + 'cos': 'cos', + 'tan': 'tan', + 'cot': 'cot', + 'asin': 'arcsin', + 'asinh': 'arcsinh', + 'acos': 'arccos', + 'acosh': 'arccosh', + 'atan': 'arctan', + 'atanh': 'arctanh', + 'acot': 'arccot', + 'atan2': 'arctan', + 'Equality': '=', + 'Unequality': '≠', + 'GreaterThan': '≥', + 'LessThan': '≤', + 'StrictGreaterThan': '>', + 'StrictLessThan': '<', + 'lerchphi': 'Φ', + 'zeta': 'ζ', + 'dirichlet_eta': 'η', + 'elliptic_k': 'Κ', + 'lowergamma': 'γ', + 'uppergamma': 'Γ', + 'gamma': 'Γ', + 'totient': 'ϕ', + 'reduced_totient': 'λ', + 'primenu': 'ν', + 'primeomega': 'Ω', + 'fresnels': 'S', + 'fresnelc': 'C', + 'LambertW': 'W', + 'Heaviside': 'Θ', + 'BooleanTrue': 'True', + 'BooleanFalse': 'False', + 'NoneType': 'None', + 'mathieus': 'S', + 'mathieuc': 'C', + 'mathieusprime': 'S′', + 'mathieucprime': 'C′', + 'Lambda': 'lambda', + } + + def mul_symbol_selection(): + if (self._settings["mul_symbol"] is None or + self._settings["mul_symbol"] == 'None'): + return '⁢' + elif self._settings["mul_symbol"] == 'times': + return '×' + elif self._settings["mul_symbol"] == 'dot': + return '·' + elif self._settings["mul_symbol"] == 'ldot': + return '․' + elif not isinstance(self._settings["mul_symbol"], str): + raise TypeError + else: + return self._settings["mul_symbol"] + for cls in e.__class__.__mro__: + n = cls.__name__ + if n in translate: + return translate[n] + # Not found in the MRO set + if e.__class__.__name__ == "Mul": + return mul_symbol_selection() + n = e.__class__.__name__ + return n.lower() + + def _l_paren(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('(')) + return mo + + def _r_paren(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(')')) + return mo + + def _l_brace(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('{')) + return mo + + def _r_brace(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('}')) + return mo + + def _comma(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(',')) + return mo + + def _bar(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('|')) + return mo + + def _semicolon(self): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(';')) + return mo + + def _paren_comma_separated(self, *args): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self._l_paren()) + for i, arg in enumerate(args): + if i: + mrow.appendChild(self._comma()) + mrow.appendChild(self._print(arg)) + mrow.appendChild(self._r_paren()) + return mrow + + def _paren_bar_separated(self, *args): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self._l_paren()) + for i, arg in enumerate(args): + if i: + mrow.appendChild(self._bar()) + mrow.appendChild(self._print(arg)) + mrow.appendChild(self._r_paren()) + return mrow + + def parenthesize(self, item, level, strict=False): + prec_val = precedence_traditional(item) + if (prec_val < level) or ((not strict) and prec_val <= level): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self._l_paren()) + mrow.appendChild(self._print(item)) + mrow.appendChild(self._r_paren()) + return mrow + return self._print(item) + + def _print_Mul(self, expr): + + def multiply(expr, mrow): + from sympy.simplify import fraction + numer, denom = fraction(expr) + if denom is not S.One: + frac = self.dom.createElement('mfrac') + if self._settings["fold_short_frac"] and len(str(expr)) < 7: + frac.setAttribute('bevelled', 'true') + xnum = self._print(numer) + xden = self._print(denom) + frac.appendChild(xnum) + frac.appendChild(xden) + mrow.appendChild(frac) + return mrow + + coeff, terms = expr.as_coeff_mul() + if coeff is S.One and len(terms) == 1: + mrow.appendChild(self._print(terms[0])) + return mrow + if self.order != 'old': + terms = Mul._from_args(terms).as_ordered_factors() + + if coeff != 1: + x = self._print(coeff) + y = self.dom.createElement('mo') + y.appendChild(self.dom.createTextNode(self.mathml_tag(expr))) + mrow.appendChild(x) + mrow.appendChild(y) + for term in terms: + mrow.appendChild(self.parenthesize(term, PRECEDENCE['Mul'])) + if not term == terms[-1]: + y = self.dom.createElement('mo') + y.appendChild(self.dom.createTextNode(self.mathml_tag(expr))) + mrow.appendChild(y) + return mrow + mrow = self.dom.createElement('mrow') + if expr.could_extract_minus_sign(): + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode('-')) + mrow.appendChild(x) + mrow = multiply(-expr, mrow) + else: + mrow = multiply(expr, mrow) + + return mrow + + def _print_Add(self, expr, order=None): + mrow = self.dom.createElement('mrow') + args = self._as_ordered_terms(expr, order=order) + mrow.appendChild(self._print(args[0])) + for arg in args[1:]: + if arg.could_extract_minus_sign(): + # use minus + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode('-')) + y = self._print(-arg) + # invert expression since this is now minused + else: + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode('+')) + y = self._print(arg) + mrow.appendChild(x) + mrow.appendChild(y) + + return mrow + + def _print_MatrixBase(self, m): + table = self.dom.createElement('mtable') + for i in range(m.rows): + x = self.dom.createElement('mtr') + for j in range(m.cols): + y = self.dom.createElement('mtd') + y.appendChild(self._print(m[i, j])) + x.appendChild(y) + table.appendChild(x) + mat_delim = self._settings["mat_delim"] + if mat_delim == '': + return table + left = self.dom.createElement('mo') + right = self.dom.createElement('mo') + if mat_delim == "[": + left.appendChild(self.dom.createTextNode("[")) + right.appendChild(self.dom.createTextNode("]")) + else: + left.appendChild(self.dom.createTextNode("(")) + right.appendChild(self.dom.createTextNode(")")) + mrow = self.dom.createElement('mrow') + mrow.appendChild(left) + mrow.appendChild(table) + mrow.appendChild(right) + return mrow + + def _get_printed_Rational(self, e, folded=None): + if e.p < 0: + p = -e.p + else: + p = e.p + x = self.dom.createElement('mfrac') + if folded or self._settings["fold_short_frac"]: + x.setAttribute('bevelled', 'true') + x.appendChild(self._print(p)) + x.appendChild(self._print(e.q)) + if e.p < 0: + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('-')) + mrow.appendChild(mo) + mrow.appendChild(x) + return mrow + else: + return x + + def _print_Rational(self, e): + if e.q == 1: + # don't divide + return self._print(e.p) + + return self._get_printed_Rational(e, self._settings["fold_short_frac"]) + + def _print_Limit(self, e): + mrow = self.dom.createElement('mrow') + munder = self.dom.createElement('munder') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode('lim')) + + x = self.dom.createElement('mrow') + x_1 = self._print(e.args[1]) + arrow = self.dom.createElement('mo') + arrow.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + x_2 = self._print(e.args[2]) + x.appendChild(x_1) + x.appendChild(arrow) + x.appendChild(x_2) + + munder.appendChild(mi) + munder.appendChild(x) + mrow.appendChild(munder) + mrow.appendChild(self._print(e.args[0])) + + return mrow + + def _print_ImaginaryUnit(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('ⅈ')) + return x + + def _print_GoldenRatio(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('Φ')) + return x + + def _print_Exp1(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('ⅇ')) + return x + + def _print_Pi(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('π')) + return x + + def _print_Infinity(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('∞')) + return x + + def _print_NegativeInfinity(self, e): + mrow = self.dom.createElement('mrow') + y = self.dom.createElement('mo') + y.appendChild(self.dom.createTextNode('-')) + x = self._print_Infinity(e) + mrow.appendChild(y) + mrow.appendChild(x) + return mrow + + def _print_HBar(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('ℏ')) + return x + + def _print_EulerGamma(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('γ')) + return x + + def _print_TribonacciConstant(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('TribonacciConstant')) + return x + + def _print_Dagger(self, e): + msup = self.dom.createElement('msup') + msup.appendChild(self._print(e.args[0])) + msup.appendChild(self.dom.createTextNode('†')) + return msup + + def _print_Contains(self, e): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self._print(e.args[0])) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∈')) + mrow.appendChild(mo) + mrow.appendChild(self._print(e.args[1])) + return mrow + + def _print_HilbertSpace(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('ℋ')) + return x + + def _print_ComplexSpace(self, e): + msup = self.dom.createElement('msup') + msup.appendChild(self.dom.createTextNode('𝒞')) + msup.appendChild(self._print(e.args[0])) + return msup + + def _print_FockSpace(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('ℱ')) + return x + + + def _print_Integral(self, expr): + intsymbols = {1: "∫", 2: "∬", 3: "∭"} + + mrow = self.dom.createElement('mrow') + if len(expr.limits) <= 3 and all(len(lim) == 1 for lim in expr.limits): + # Only up to three-integral signs exists + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(intsymbols[len(expr.limits)])) + mrow.appendChild(mo) + else: + # Either more than three or limits provided + for lim in reversed(expr.limits): + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(intsymbols[1])) + if len(lim) == 1: + mrow.appendChild(mo) + if len(lim) == 2: + msup = self.dom.createElement('msup') + msup.appendChild(mo) + msup.appendChild(self._print(lim[1])) + mrow.appendChild(msup) + if len(lim) == 3: + msubsup = self.dom.createElement('msubsup') + msubsup.appendChild(mo) + msubsup.appendChild(self._print(lim[1])) + msubsup.appendChild(self._print(lim[2])) + mrow.appendChild(msubsup) + # print function + mrow.appendChild(self.parenthesize(expr.function, PRECEDENCE["Mul"], + strict=True)) + # print integration variables + for lim in reversed(expr.limits): + d = self.dom.createElement('mo') + d.appendChild(self.dom.createTextNode('ⅆ')) + mrow.appendChild(d) + mrow.appendChild(self._print(lim[0])) + return mrow + + def _print_Sum(self, e): + limits = list(e.limits) + subsup = self.dom.createElement('munderover') + low_elem = self._print(limits[0][1]) + up_elem = self._print(limits[0][2]) + summand = self.dom.createElement('mo') + summand.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + + low = self.dom.createElement('mrow') + var = self._print(limits[0][0]) + equal = self.dom.createElement('mo') + equal.appendChild(self.dom.createTextNode('=')) + low.appendChild(var) + low.appendChild(equal) + low.appendChild(low_elem) + + subsup.appendChild(summand) + subsup.appendChild(low) + subsup.appendChild(up_elem) + + mrow = self.dom.createElement('mrow') + mrow.appendChild(subsup) + mrow.appendChild(self.parenthesize(e.function, precedence_traditional(e))) + return mrow + + def _print_Symbol(self, sym, style='plain'): + def join(items): + if len(items) > 1: + mrow = self.dom.createElement('mrow') + for i, item in enumerate(items): + if i > 0: + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(" ")) + mrow.appendChild(mo) + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(item)) + mrow.appendChild(mi) + return mrow + else: + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(items[0])) + return mi + + # translate name, supers and subs to unicode characters + def translate(s): + if s in greek_unicode: + return greek_unicode.get(s) + else: + return s + + name, supers, subs = self._split_super_sub(sym.name) + name = translate(name) + supers = [translate(sup) for sup in supers] + subs = [translate(sub) for sub in subs] + + mname = self.dom.createElement('mi') + mname.appendChild(self.dom.createTextNode(name)) + if len(supers) == 0: + if len(subs) == 0: + x = mname + else: + x = self.dom.createElement('msub') + x.appendChild(mname) + x.appendChild(join(subs)) + else: + if len(subs) == 0: + x = self.dom.createElement('msup') + x.appendChild(mname) + x.appendChild(join(supers)) + else: + x = self.dom.createElement('msubsup') + x.appendChild(mname) + x.appendChild(join(subs)) + x.appendChild(join(supers)) + # Set bold font? + if style == 'bold': + x.setAttribute('mathvariant', 'bold') + return x + + def _print_MatrixSymbol(self, sym): + return self._print_Symbol(sym, + style=self._settings['mat_symbol_style']) + + _print_RandomSymbol = _print_Symbol + + def _print_conjugate(self, expr): + enc = self.dom.createElement('menclose') + enc.setAttribute('notation', 'top') + enc.appendChild(self._print(expr.args[0])) + return enc + + def _print_operator_after(self, op, expr): + row = self.dom.createElement('mrow') + row.appendChild(self.parenthesize(expr, PRECEDENCE["Func"])) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(op)) + row.appendChild(mo) + return row + + def _print_factorial(self, expr): + return self._print_operator_after('!', expr.args[0]) + + def _print_factorial2(self, expr): + return self._print_operator_after('!!', expr.args[0]) + + def _print_binomial(self, expr): + frac = self.dom.createElement('mfrac') + frac.setAttribute('linethickness', '0') + frac.appendChild(self._print(expr.args[0])) + frac.appendChild(self._print(expr.args[1])) + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(frac) + brac.appendChild(self._r_paren()) + return brac + + def _print_Pow(self, e): + # Here we use root instead of power if the exponent is the + # reciprocal of an integer + if (e.exp.is_Rational and abs(e.exp.p) == 1 and e.exp.q != 1 and + self._settings['root_notation']): + if e.exp.q == 2: + x = self.dom.createElement('msqrt') + x.appendChild(self._print(e.base)) + if e.exp.q != 2: + x = self.dom.createElement('mroot') + x.appendChild(self._print(e.base)) + x.appendChild(self._print(e.exp.q)) + if e.exp.p == -1: + frac = self.dom.createElement('mfrac') + frac.appendChild(self._print(1)) + frac.appendChild(x) + return frac + else: + return x + + if e.exp.is_Rational and e.exp.q != 1: + if e.exp.is_negative: + top = self.dom.createElement('mfrac') + top.appendChild(self._print(1)) + x = self.dom.createElement('msup') + x.appendChild(self.parenthesize(e.base, PRECEDENCE['Pow'])) + x.appendChild(self._get_printed_Rational(-e.exp, + self._settings['fold_frac_powers'])) + top.appendChild(x) + return top + else: + x = self.dom.createElement('msup') + x.appendChild(self.parenthesize(e.base, PRECEDENCE['Pow'])) + x.appendChild(self._get_printed_Rational(e.exp, + self._settings['fold_frac_powers'])) + return x + + if e.exp.is_negative: + top = self.dom.createElement('mfrac') + top.appendChild(self._print(1)) + if e.exp == -1: + top.appendChild(self._print(e.base)) + else: + x = self.dom.createElement('msup') + x.appendChild(self.parenthesize(e.base, PRECEDENCE['Pow'])) + x.appendChild(self._print(-e.exp)) + top.appendChild(x) + return top + + x = self.dom.createElement('msup') + x.appendChild(self.parenthesize(e.base, PRECEDENCE['Pow'])) + x.appendChild(self._print(e.exp)) + return x + + def _print_Number(self, e): + x = self.dom.createElement(self.mathml_tag(e)) + x.appendChild(self.dom.createTextNode(str(e))) + return x + + def _print_AccumulationBounds(self, i): + left = self.dom.createElement('mo') + left.appendChild(self.dom.createTextNode('\u27e8')) + right = self.dom.createElement('mo') + right.appendChild(self.dom.createTextNode('\u27e9')) + brac = self.dom.createElement('mrow') + brac.appendChild(left) + brac.appendChild(self._print(i.min)) + brac.appendChild(self._comma()) + brac.appendChild(self._print(i.max)) + brac.appendChild(right) + return brac + + def _print_Derivative(self, e): + + if requires_partial(e.expr): + d = '∂' + else: + d = self.mathml_tag(e) + + # Determine denominator + m = self.dom.createElement('mrow') + dim = 0 # Total diff dimension, for numerator + for sym, num in reversed(e.variable_count): + dim += num + if num >= 2: + x = self.dom.createElement('msup') + xx = self.dom.createElement('mo') + xx.appendChild(self.dom.createTextNode(d)) + x.appendChild(xx) + x.appendChild(self._print(num)) + else: + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode(d)) + m.appendChild(x) + y = self._print(sym) + m.appendChild(y) + + mnum = self.dom.createElement('mrow') + if dim >= 2: + x = self.dom.createElement('msup') + xx = self.dom.createElement('mo') + xx.appendChild(self.dom.createTextNode(d)) + x.appendChild(xx) + x.appendChild(self._print(dim)) + else: + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode(d)) + + mnum.appendChild(x) + mrow = self.dom.createElement('mrow') + frac = self.dom.createElement('mfrac') + frac.appendChild(mnum) + frac.appendChild(m) + mrow.appendChild(frac) + + # Print function + mrow.appendChild(self._print(e.expr)) + + return mrow + + def _print_Function(self, e): + x = self.dom.createElement('mi') + if self.mathml_tag(e) == 'log' and self._settings["ln_notation"]: + x.appendChild(self.dom.createTextNode('ln')) + else: + x.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + mrow = self.dom.createElement('mrow') + mrow.appendChild(x) + mrow.appendChild(self._paren_comma_separated(*e.args)) + return mrow + + def _print_Float(self, expr): + # Based off of that in StrPrinter + dps = prec_to_dps(expr._prec) + str_real = mlib_to_str(expr._mpf_, dps, strip_zeros=True) + + # Must always have a mul symbol (as 2.5 10^{20} just looks odd) + # thus we use the number separator + separator = self._settings['mul_symbol_mathml_numbers'] + mrow = self.dom.createElement('mrow') + if 'e' in str_real: + (mant, exp) = str_real.split('e') + + if exp[0] == '+': + exp = exp[1:] + + mn = self.dom.createElement('mn') + mn.appendChild(self.dom.createTextNode(mant)) + mrow.appendChild(mn) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode(separator)) + mrow.appendChild(mo) + msup = self.dom.createElement('msup') + mn = self.dom.createElement('mn') + mn.appendChild(self.dom.createTextNode("10")) + msup.appendChild(mn) + mn = self.dom.createElement('mn') + mn.appendChild(self.dom.createTextNode(exp)) + msup.appendChild(mn) + mrow.appendChild(msup) + return mrow + elif str_real == "+inf": + return self._print_Infinity(None) + elif str_real == "-inf": + return self._print_NegativeInfinity(None) + else: + mn = self.dom.createElement('mn') + mn.appendChild(self.dom.createTextNode(str_real)) + return mn + + def _print_polylog(self, expr): + mrow = self.dom.createElement('mrow') + m = self.dom.createElement('msub') + + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode('Li')) + m.appendChild(mi) + m.appendChild(self._print(expr.args[0])) + mrow.appendChild(m) + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(expr.args[1])) + brac.appendChild(self._r_paren()) + mrow.appendChild(brac) + return mrow + + def _print_Basic(self, e): + mrow = self.dom.createElement('mrow') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + mrow.appendChild(mi) + mrow.appendChild(self._paren_comma_separated(*e.args)) + return mrow + + def _print_Tuple(self, e): + return self._paren_comma_separated(*e.args) + + def _print_Interval(self, i): + right = self.dom.createElement('mo') + if i.right_open: + right.appendChild(self.dom.createTextNode(')')) + else: + right.appendChild(self.dom.createTextNode(']')) + left = self.dom.createElement('mo') + if i.left_open: + left.appendChild(self.dom.createTextNode('(')) + else: + left.appendChild(self.dom.createTextNode('[')) + mrow = self.dom.createElement('mrow') + mrow.appendChild(left) + mrow.appendChild(self._print(i.start)) + mrow.appendChild(self._comma()) + mrow.appendChild(self._print(i.end)) + mrow.appendChild(right) + return mrow + + def _print_Abs(self, expr, exp=None): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self._bar()) + mrow.appendChild(self._print(expr.args[0])) + mrow.appendChild(self._bar()) + return mrow + + _print_Determinant = _print_Abs + + def _print_re_im(self, c, expr): + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(expr)) + brac.appendChild(self._r_paren()) + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(c)) + mrow = self.dom.createElement('mrow') + mrow.appendChild(mi) + mrow.appendChild(brac) + return mrow + + def _print_re(self, expr, exp=None): + return self._print_re_im('\u211C', expr.args[0]) + + def _print_im(self, expr, exp=None): + return self._print_re_im('\u2111', expr.args[0]) + + def _print_AssocOp(self, e): + mrow = self.dom.createElement('mrow') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + mrow.appendChild(mi) + for arg in e.args: + mrow.appendChild(self._print(arg)) + return mrow + + def _print_SetOp(self, expr, symbol, prec): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self.parenthesize(expr.args[0], prec)) + for arg in expr.args[1:]: + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode(symbol)) + y = self.parenthesize(arg, prec) + mrow.appendChild(x) + mrow.appendChild(y) + return mrow + + def _print_Union(self, expr): + prec = PRECEDENCE_TRADITIONAL['Union'] + return self._print_SetOp(expr, '∪', prec) + + def _print_Intersection(self, expr): + prec = PRECEDENCE_TRADITIONAL['Intersection'] + return self._print_SetOp(expr, '∩', prec) + + def _print_Complement(self, expr): + prec = PRECEDENCE_TRADITIONAL['Complement'] + return self._print_SetOp(expr, '∖', prec) + + def _print_SymmetricDifference(self, expr): + prec = PRECEDENCE_TRADITIONAL['SymmetricDifference'] + return self._print_SetOp(expr, '∆', prec) + + def _print_ProductSet(self, expr): + prec = PRECEDENCE_TRADITIONAL['ProductSet'] + return self._print_SetOp(expr, '×', prec) + + def _print_FiniteSet(self, s): + return self._print_set(s.args) + + def _print_set(self, s): + items = sorted(s, key=default_sort_key) + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_brace()) + for i, item in enumerate(items): + if i: + brac.appendChild(self._comma()) + brac.appendChild(self._print(item)) + brac.appendChild(self._r_brace()) + return brac + + _print_frozenset = _print_set + + def _print_LogOp(self, args, symbol): + mrow = self.dom.createElement('mrow') + if args[0].is_Boolean and not args[0].is_Not: + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(args[0])) + brac.appendChild(self._r_paren()) + mrow.appendChild(brac) + else: + mrow.appendChild(self._print(args[0])) + for arg in args[1:]: + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode(symbol)) + if arg.is_Boolean and not arg.is_Not: + y = self.dom.createElement('mrow') + y.appendChild(self._l_paren()) + y.appendChild(self._print(arg)) + y.appendChild(self._r_paren()) + else: + y = self._print(arg) + mrow.appendChild(x) + mrow.appendChild(y) + return mrow + + def _print_BasisDependent(self, expr): + from sympy.vector import Vector + + if expr == expr.zero: + # Not clear if this is ever called + return self._print(expr.zero) + if isinstance(expr, Vector): + items = expr.separate().items() + else: + items = [(0, expr)] + + mrow = self.dom.createElement('mrow') + for system, vect in items: + inneritems = list(vect.components.items()) + inneritems.sort(key = lambda x:x[0].__str__()) + for i, (k, v) in enumerate(inneritems): + if v == 1: + if i: # No + for first item + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('+')) + mrow.appendChild(mo) + mrow.appendChild(self._print(k)) + elif v == -1: + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('-')) + mrow.appendChild(mo) + mrow.appendChild(self._print(k)) + else: + if i: # No + for first item + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('+')) + mrow.appendChild(mo) + mbrac = self.dom.createElement('mrow') + mbrac.appendChild(self._l_paren()) + mbrac.appendChild(self._print(v)) + mbrac.appendChild(self._r_paren()) + mrow.appendChild(mbrac) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('⁢')) + mrow.appendChild(mo) + mrow.appendChild(self._print(k)) + return mrow + + + def _print_And(self, expr): + args = sorted(expr.args, key=default_sort_key) + return self._print_LogOp(args, '∧') + + def _print_Or(self, expr): + args = sorted(expr.args, key=default_sort_key) + return self._print_LogOp(args, '∨') + + def _print_Xor(self, expr): + args = sorted(expr.args, key=default_sort_key) + return self._print_LogOp(args, '⊻') + + def _print_Implies(self, expr): + return self._print_LogOp(expr.args, '⇒') + + def _print_Equivalent(self, expr): + args = sorted(expr.args, key=default_sort_key) + return self._print_LogOp(args, '⇔') + + def _print_Not(self, e): + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('¬')) + mrow.appendChild(mo) + if (e.args[0].is_Boolean): + x = self.dom.createElement('mrow') + x.appendChild(self._l_paren()) + x.appendChild(self._print(e.args[0])) + x.appendChild(self._r_paren()) + else: + x = self._print(e.args[0]) + mrow.appendChild(x) + return mrow + + def _print_bool(self, e): + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + return mi + + _print_BooleanTrue = _print_bool + _print_BooleanFalse = _print_bool + + def _print_NoneType(self, e): + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + return mi + + def _print_Range(self, s): + dots = "\u2026" + if s.start.is_infinite and s.stop.is_infinite: + if s.step.is_positive: + printset = dots, -1, 0, 1, dots + else: + printset = dots, 1, 0, -1, dots + elif s.start.is_infinite: + printset = dots, s[-1] - s.step, s[-1] + elif s.stop.is_infinite: + it = iter(s) + printset = next(it), next(it), dots + elif len(s) > 4: + it = iter(s) + printset = next(it), next(it), dots, s[-1] + else: + printset = tuple(s) + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_brace()) + for i, el in enumerate(printset): + if i: + brac.appendChild(self._comma()) + if el == dots: + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(dots)) + brac.appendChild(mi) + else: + brac.appendChild(self._print(el)) + brac.appendChild(self._r_brace()) + return brac + + def _hprint_variadic_function(self, expr): + args = sorted(expr.args, key=default_sort_key) + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode((str(expr.func)).lower())) + mrow.appendChild(mo) + mrow.appendChild(self._paren_comma_separated(*args)) + return mrow + + _print_Min = _print_Max = _hprint_variadic_function + + def _print_exp(self, expr): + msup = self.dom.createElement('msup') + msup.appendChild(self._print_Exp1(None)) + msup.appendChild(self._print(expr.args[0])) + return msup + + def _print_Relational(self, e): + mrow = self.dom.createElement('mrow') + mrow.appendChild(self._print(e.lhs)) + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode(self.mathml_tag(e))) + mrow.appendChild(x) + mrow.appendChild(self._print(e.rhs)) + return mrow + + def _print_int(self, p): + dom_element = self.dom.createElement(self.mathml_tag(p)) + dom_element.appendChild(self.dom.createTextNode(str(p))) + return dom_element + + def _print_BaseScalar(self, e): + msub = self.dom.createElement('msub') + index, system = e._id + mi = self.dom.createElement('mi') + mi.setAttribute('mathvariant', 'bold') + mi.appendChild(self.dom.createTextNode(system._variable_names[index])) + msub.appendChild(mi) + mi = self.dom.createElement('mi') + mi.setAttribute('mathvariant', 'bold') + mi.appendChild(self.dom.createTextNode(system._name)) + msub.appendChild(mi) + return msub + + def _print_BaseVector(self, e): + msub = self.dom.createElement('msub') + index, system = e._id + mover = self.dom.createElement('mover') + mi = self.dom.createElement('mi') + mi.setAttribute('mathvariant', 'bold') + mi.appendChild(self.dom.createTextNode(system._vector_names[index])) + mover.appendChild(mi) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('^')) + mover.appendChild(mo) + msub.appendChild(mover) + mi = self.dom.createElement('mi') + mi.setAttribute('mathvariant', 'bold') + mi.appendChild(self.dom.createTextNode(system._name)) + msub.appendChild(mi) + return msub + + def _print_VectorZero(self, e): + mover = self.dom.createElement('mover') + mi = self.dom.createElement('mi') + mi.setAttribute('mathvariant', 'bold') + mi.appendChild(self.dom.createTextNode("0")) + mover.appendChild(mi) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('^')) + mover.appendChild(mo) + return mover + + def _print_Cross(self, expr): + mrow = self.dom.createElement('mrow') + vec1 = expr._expr1 + vec2 = expr._expr2 + mrow.appendChild(self.parenthesize(vec1, PRECEDENCE['Mul'])) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('×')) + mrow.appendChild(mo) + mrow.appendChild(self.parenthesize(vec2, PRECEDENCE['Mul'])) + return mrow + + def _print_Curl(self, expr): + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∇')) + mrow.appendChild(mo) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('×')) + mrow.appendChild(mo) + mrow.appendChild(self.parenthesize(expr._expr, PRECEDENCE['Mul'])) + return mrow + + def _print_Divergence(self, expr): + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∇')) + mrow.appendChild(mo) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('·')) + mrow.appendChild(mo) + mrow.appendChild(self.parenthesize(expr._expr, PRECEDENCE['Mul'])) + return mrow + + def _print_Dot(self, expr): + mrow = self.dom.createElement('mrow') + vec1 = expr._expr1 + vec2 = expr._expr2 + mrow.appendChild(self.parenthesize(vec1, PRECEDENCE['Mul'])) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('·')) + mrow.appendChild(mo) + mrow.appendChild(self.parenthesize(vec2, PRECEDENCE['Mul'])) + return mrow + + def _print_Gradient(self, expr): + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∇')) + mrow.appendChild(mo) + mrow.appendChild(self.parenthesize(expr._expr, PRECEDENCE['Mul'])) + return mrow + + def _print_Laplacian(self, expr): + mrow = self.dom.createElement('mrow') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∆')) + mrow.appendChild(mo) + mrow.appendChild(self.parenthesize(expr._expr, PRECEDENCE['Mul'])) + return mrow + + def _print_Integers(self, e): + x = self.dom.createElement('mi') + x.setAttribute('mathvariant', 'normal') + x.appendChild(self.dom.createTextNode('ℤ')) + return x + + def _print_Complexes(self, e): + x = self.dom.createElement('mi') + x.setAttribute('mathvariant', 'normal') + x.appendChild(self.dom.createTextNode('ℂ')) + return x + + def _print_Reals(self, e): + x = self.dom.createElement('mi') + x.setAttribute('mathvariant', 'normal') + x.appendChild(self.dom.createTextNode('ℝ')) + return x + + def _print_Naturals(self, e): + x = self.dom.createElement('mi') + x.setAttribute('mathvariant', 'normal') + x.appendChild(self.dom.createTextNode('ℕ')) + return x + + def _print_Naturals0(self, e): + sub = self.dom.createElement('msub') + x = self.dom.createElement('mi') + x.setAttribute('mathvariant', 'normal') + x.appendChild(self.dom.createTextNode('ℕ')) + sub.appendChild(x) + sub.appendChild(self._print(S.Zero)) + return sub + + def _print_SingularityFunction(self, expr): + shift = expr.args[0] - expr.args[1] + power = expr.args[2] + left = self.dom.createElement('mo') + left.appendChild(self.dom.createTextNode('\u27e8')) + right = self.dom.createElement('mo') + right.appendChild(self.dom.createTextNode('\u27e9')) + brac = self.dom.createElement('mrow') + brac.appendChild(left) + brac.appendChild(self._print(shift)) + brac.appendChild(right) + sup = self.dom.createElement('msup') + sup.appendChild(brac) + sup.appendChild(self._print(power)) + return sup + + def _print_NaN(self, e): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('NaN')) + return x + + def _print_number_function(self, e, name): + # Print name_arg[0] for one argument or name_arg[0](arg[1]) + # for more than one argument + sub = self.dom.createElement('msub') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode(name)) + sub.appendChild(mi) + sub.appendChild(self._print(e.args[0])) + if len(e.args) == 1: + return sub + mrow = self.dom.createElement('mrow') + mrow.appendChild(sub) + mrow.appendChild(self._paren_comma_separated(*e.args[1:])) + return mrow + + def _print_bernoulli(self, e): + return self._print_number_function(e, 'B') + + _print_bell = _print_bernoulli + + def _print_catalan(self, e): + return self._print_number_function(e, 'C') + + def _print_euler(self, e): + return self._print_number_function(e, 'E') + + def _print_fibonacci(self, e): + return self._print_number_function(e, 'F') + + def _print_lucas(self, e): + return self._print_number_function(e, 'L') + + def _print_stieltjes(self, e): + return self._print_number_function(e, 'γ') + + def _print_tribonacci(self, e): + return self._print_number_function(e, 'T') + + def _print_ComplexInfinity(self, e): + x = self.dom.createElement('mover') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∞')) + x.appendChild(mo) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('~')) + x.appendChild(mo) + return x + + def _print_EmptySet(self, e): + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode('∅')) + return x + + def _print_UniversalSet(self, e): + x = self.dom.createElement('mo') + x.appendChild(self.dom.createTextNode('𝕌')) + return x + + def _print_Adjoint(self, expr): + from sympy.matrices import MatrixSymbol + mat = expr.arg + sup = self.dom.createElement('msup') + if not isinstance(mat, MatrixSymbol): + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(mat)) + brac.appendChild(self._r_paren()) + sup.appendChild(brac) + else: + sup.appendChild(self._print(mat)) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('†')) + sup.appendChild(mo) + return sup + + def _print_Transpose(self, expr): + from sympy.matrices import MatrixSymbol + mat = expr.arg + sup = self.dom.createElement('msup') + if not isinstance(mat, MatrixSymbol): + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(mat)) + brac.appendChild(self._r_paren()) + sup.appendChild(brac) + else: + sup.appendChild(self._print(mat)) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('T')) + sup.appendChild(mo) + return sup + + def _print_Inverse(self, expr): + from sympy.matrices import MatrixSymbol + mat = expr.arg + sup = self.dom.createElement('msup') + if not isinstance(mat, MatrixSymbol): + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(mat)) + brac.appendChild(self._r_paren()) + sup.appendChild(brac) + else: + sup.appendChild(self._print(mat)) + sup.appendChild(self._print(-1)) + return sup + + def _print_MatMul(self, expr): + from sympy.matrices.expressions.matmul import MatMul + + x = self.dom.createElement('mrow') + args = expr.args + if isinstance(args[0], Mul): + args = args[0].as_ordered_factors() + list(args[1:]) + else: + args = list(args) + + if isinstance(expr, MatMul) and expr.could_extract_minus_sign(): + if args[0] == -1: + args = args[1:] + else: + args[0] = -args[0] + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('-')) + x.appendChild(mo) + + for arg in args[:-1]: + x.appendChild(self.parenthesize(arg, precedence_traditional(expr), + False)) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('⁢')) + x.appendChild(mo) + x.appendChild(self.parenthesize(args[-1], precedence_traditional(expr), + False)) + return x + + def _print_MatPow(self, expr): + from sympy.matrices import MatrixSymbol + base, exp = expr.base, expr.exp + sup = self.dom.createElement('msup') + if not isinstance(base, MatrixSymbol): + brac = self.dom.createElement('mrow') + brac.appendChild(self._l_paren()) + brac.appendChild(self._print(base)) + brac.appendChild(self._r_paren()) + sup.appendChild(brac) + else: + sup.appendChild(self._print(base)) + sup.appendChild(self._print(exp)) + return sup + + def _print_HadamardProduct(self, expr): + x = self.dom.createElement('mrow') + args = expr.args + for arg in args[:-1]: + x.appendChild( + self.parenthesize(arg, precedence_traditional(expr), False)) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('∘')) + x.appendChild(mo) + x.appendChild( + self.parenthesize(args[-1], precedence_traditional(expr), False)) + return x + + def _print_ZeroMatrix(self, Z): + x = self.dom.createElement('mn') + x.appendChild(self.dom.createTextNode('𝟘')) + return x + + def _print_OneMatrix(self, Z): + x = self.dom.createElement('mn') + x.appendChild(self.dom.createTextNode('𝟙')) + return x + + def _print_Identity(self, I): + x = self.dom.createElement('mi') + x.appendChild(self.dom.createTextNode('𝕀')) + return x + + def _print_floor(self, e): + left = self.dom.createElement('mo') + left.appendChild(self.dom.createTextNode('\u230A')) + right = self.dom.createElement('mo') + right.appendChild(self.dom.createTextNode('\u230B')) + mrow = self.dom.createElement('mrow') + mrow.appendChild(left) + mrow.appendChild(self._print(e.args[0])) + mrow.appendChild(right) + return mrow + + def _print_ceiling(self, e): + left = self.dom.createElement('mo') + left.appendChild(self.dom.createTextNode('\u2308')) + right = self.dom.createElement('mo') + right.appendChild(self.dom.createTextNode('\u2309')) + mrow = self.dom.createElement('mrow') + mrow.appendChild(left) + mrow.appendChild(self._print(e.args[0])) + mrow.appendChild(right) + return mrow + + def _print_Lambda(self, e): + mrow = self.dom.createElement('mrow') + symbols = e.args[0] + if len(symbols) == 1: + symbols = self._print(symbols[0]) + else: + symbols = self._print(symbols) + mrow.appendChild(self._l_paren()) + mrow.appendChild(symbols) + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('↦')) + mrow.appendChild(mo) + mrow.appendChild(self._print(e.args[1])) + mrow.appendChild(self._r_paren()) + return mrow + + def _print_tuple(self, e): + return self._paren_comma_separated(*e) + + def _print_IndexedBase(self, e): + return self._print(e.label) + + def _print_Indexed(self, e): + x = self.dom.createElement('msub') + x.appendChild(self._print(e.base)) + if len(e.indices) == 1: + x.appendChild(self._print(e.indices[0])) + return x + x.appendChild(self._print(e.indices)) + return x + + def _print_MatrixElement(self, e): + x = self.dom.createElement('msub') + x.appendChild(self.parenthesize(e.parent, PRECEDENCE["Atom"], strict = True)) + brac = self.dom.createElement('mrow') + for i, arg in enumerate(e.indices): + if i: + brac.appendChild(self._comma()) + brac.appendChild(self._print(arg)) + x.appendChild(brac) + return x + + def _print_elliptic_f(self, e): + x = self.dom.createElement('mrow') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode('𝖥')) + x.appendChild(mi) + x.appendChild(self._paren_bar_separated(*e.args)) + return x + + def _print_elliptic_e(self, e): + x = self.dom.createElement('mrow') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode('𝖤')) + x.appendChild(mi) + x.appendChild(self._paren_bar_separated(*e.args)) + return x + + def _print_elliptic_pi(self, e): + x = self.dom.createElement('mrow') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode('𝛱')) + x.appendChild(mi) + y = self.dom.createElement('mrow') + y.appendChild(self._l_paren()) + if len(e.args) == 2: + n, m = e.args + y.appendChild(self._print(n)) + y.appendChild(self._bar()) + y.appendChild(self._print(m)) + else: + n, m, z = e.args + y.appendChild(self._print(n)) + y.appendChild(self._semicolon()) + y.appendChild(self._print(m)) + y.appendChild(self._bar()) + y.appendChild(self._print(z)) + y.appendChild(self._r_paren()) + x.appendChild(y) + return x + + def _print_Ei(self, e): + x = self.dom.createElement('mrow') + mi = self.dom.createElement('mi') + mi.appendChild(self.dom.createTextNode('Ei')) + x.appendChild(mi) + x.appendChild(self._print(e.args)) + return x + + def _print_expint(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msub') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('E')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + x.appendChild(y) + x.appendChild(self._print(e.args[1:])) + return x + + def _print_jacobi(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msubsup') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('P')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + y.appendChild(self._print(e.args[1:3])) + x.appendChild(y) + x.appendChild(self._print(e.args[3:])) + return x + + def _print_gegenbauer(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msubsup') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('C')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + y.appendChild(self._print(e.args[1:2])) + x.appendChild(y) + x.appendChild(self._print(e.args[2:])) + return x + + def _print_chebyshevt(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msub') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('T')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + x.appendChild(y) + x.appendChild(self._print(e.args[1:])) + return x + + def _print_chebyshevu(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msub') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('U')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + x.appendChild(y) + x.appendChild(self._print(e.args[1:])) + return x + + def _print_legendre(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msub') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('P')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + x.appendChild(y) + x.appendChild(self._print(e.args[1:])) + return x + + def _print_assoc_legendre(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msubsup') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('P')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + y.appendChild(self._print(e.args[1:2])) + x.appendChild(y) + x.appendChild(self._print(e.args[2:])) + return x + + def _print_laguerre(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msub') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('L')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + x.appendChild(y) + x.appendChild(self._print(e.args[1:])) + return x + + def _print_assoc_laguerre(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msubsup') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('L')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + y.appendChild(self._print(e.args[1:2])) + x.appendChild(y) + x.appendChild(self._print(e.args[2:])) + return x + + def _print_hermite(self, e): + x = self.dom.createElement('mrow') + y = self.dom.createElement('msub') + mo = self.dom.createElement('mo') + mo.appendChild(self.dom.createTextNode('H')) + y.appendChild(mo) + y.appendChild(self._print(e.args[0])) + x.appendChild(y) + x.appendChild(self._print(e.args[1:])) + return x + + +@print_function(MathMLPrinterBase) +def mathml(expr, printer='content', **settings): + """Returns the MathML representation of expr. If printer is presentation + then prints Presentation MathML else prints content MathML. + """ + if printer == 'presentation': + return MathMLPresentationPrinter(settings).doprint(expr) + else: + return MathMLContentPrinter(settings).doprint(expr) + + +def print_mathml(expr, printer='content', **settings): + """ + Prints a pretty representation of the MathML code for expr. If printer is + presentation then prints Presentation MathML else prints content MathML. + + Examples + ======== + + >>> ## + >>> from sympy import print_mathml + >>> from sympy.abc import x + >>> print_mathml(x+1) #doctest: +NORMALIZE_WHITESPACE + + + x + 1 + + >>> print_mathml(x+1, printer='presentation') + + x + + + 1 + + + """ + if printer == 'presentation': + s = MathMLPresentationPrinter(settings) + else: + s = MathMLContentPrinter(settings) + xml = s._print(sympify(expr)) + pretty_xml = xml.toprettyxml() + + print(pretty_xml) + + +# For backward compatibility +MathMLPrinter = MathMLContentPrinter diff --git a/phivenv/Lib/site-packages/sympy/printing/numpy.py b/phivenv/Lib/site-packages/sympy/printing/numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff68454bb287bc0a1d2dfc1fe68fb05b3c22a74 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/numpy.py @@ -0,0 +1,541 @@ +from sympy.core import S +from sympy.core.function import Lambda +from sympy.core.power import Pow +from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter +from .codeprinter import CodePrinter + + +_not_in_numpy = 'erf erfc factorial gamma loggamma'.split() +_in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy] +_known_functions_numpy = dict(_in_numpy, **{ + 'acos': 'arccos', + 'acosh': 'arccosh', + 'asin': 'arcsin', + 'asinh': 'arcsinh', + 'atan': 'arctan', + 'atan2': 'arctan2', + 'atanh': 'arctanh', + 'exp2': 'exp2', + 'sign': 'sign', + 'logaddexp': 'logaddexp', + 'logaddexp2': 'logaddexp2', + 'isinf': 'isinf', + 'isnan': 'isnan', + +}) +_known_constants_numpy = { + 'Exp1': 'e', + 'Pi': 'pi', + 'EulerGamma': 'euler_gamma', + 'NaN': 'nan', + 'Infinity': 'inf', +} + +_numpy_known_functions = {k: 'numpy.' + v for k, v in _known_functions_numpy.items()} +_numpy_known_constants = {k: 'numpy.' + v for k, v in _known_constants_numpy.items()} + +class NumPyPrinter(ArrayPrinter, PythonCodePrinter): + """ + Numpy printer which handles vectorized piecewise functions, + logical operators, etc. + """ + + _module = 'numpy' + _kf = _numpy_known_functions + _kc = _numpy_known_constants + + def __init__(self, settings=None): + """ + `settings` is passed to CodePrinter.__init__() + `module` specifies the array module to use, currently 'NumPy', 'CuPy' + or 'JAX'. + """ + self.language = "Python with {}".format(self._module) + self.printmethod = "_{}code".format(self._module) + + self._kf = {**PythonCodePrinter._kf, **self._kf} + + super().__init__(settings=settings) + + + def _print_seq(self, seq): + "General sequence printer: converts to tuple" + # Print tuples here instead of lists because numba supports + # tuples in nopython mode. + delimiter=', ' + return '({},)'.format(delimiter.join(self._print(item) for item in seq)) + + def _print_NegativeInfinity(self, expr): + return '-' + self._print(S.Infinity) + + def _print_MatMul(self, expr): + "Matrix multiplication printer" + if expr.as_coeff_matrices()[0] is not S.One: + expr_list = expr.as_coeff_matrices()[1]+[(expr.as_coeff_matrices()[0])] + return '({})'.format(').dot('.join(self._print(i) for i in expr_list)) + return '({})'.format(').dot('.join(self._print(i) for i in expr.args)) + + def _print_MatPow(self, expr): + "Matrix power printer" + return '{}({}, {})'.format(self._module_format(self._module + '.linalg.matrix_power'), + self._print(expr.args[0]), self._print(expr.args[1])) + + def _print_Inverse(self, expr): + "Matrix inverse printer" + return '{}({})'.format(self._module_format(self._module + '.linalg.inv'), + self._print(expr.args[0])) + + def _print_DotProduct(self, expr): + # DotProduct allows any shape order, but numpy.dot does matrix + # multiplication, so we have to make sure it gets 1 x n by n x 1. + arg1, arg2 = expr.args + if arg1.shape[0] != 1: + arg1 = arg1.T + if arg2.shape[1] != 1: + arg2 = arg2.T + + return "%s(%s, %s)" % (self._module_format(self._module + '.dot'), + self._print(arg1), + self._print(arg2)) + + def _print_MatrixSolve(self, expr): + return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'), + self._print(expr.matrix), + self._print(expr.vector)) + + def _print_ZeroMatrix(self, expr): + return '{}({})'.format(self._module_format(self._module + '.zeros'), + self._print(expr.shape)) + + def _print_OneMatrix(self, expr): + return '{}({})'.format(self._module_format(self._module + '.ones'), + self._print(expr.shape)) + + def _print_FunctionMatrix(self, expr): + from sympy.abc import i, j + lamda = expr.lamda + if not isinstance(lamda, Lambda): + lamda = Lambda((i, j), lamda(i, j)) + return '{}(lambda {}: {}, {})'.format(self._module_format(self._module + '.fromfunction'), + ', '.join(self._print(arg) for arg in lamda.args[0]), + self._print(lamda.args[1]), self._print(expr.shape)) + + def _print_HadamardProduct(self, expr): + func = self._module_format(self._module + '.multiply') + return ''.join('{}({}, '.format(func, self._print(arg)) \ + for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]), + ')' * (len(expr.args) - 1)) + + def _print_KroneckerProduct(self, expr): + func = self._module_format(self._module + '.kron') + return ''.join('{}({}, '.format(func, self._print(arg)) \ + for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]), + ')' * (len(expr.args) - 1)) + + def _print_Adjoint(self, expr): + return '{}({}({}))'.format( + self._module_format(self._module + '.conjugate'), + self._module_format(self._module + '.transpose'), + self._print(expr.args[0])) + + def _print_DiagonalOf(self, expr): + vect = '{}({})'.format( + self._module_format(self._module + '.diag'), + self._print(expr.arg)) + return '{}({}, (-1, 1))'.format( + self._module_format(self._module + '.reshape'), vect) + + def _print_DiagMatrix(self, expr): + return '{}({})'.format(self._module_format(self._module + '.diagflat'), + self._print(expr.args[0])) + + def _print_DiagonalMatrix(self, expr): + return '{}({}, {}({}, {}))'.format(self._module_format(self._module + '.multiply'), + self._print(expr.arg), self._module_format(self._module + '.eye'), + self._print(expr.shape[0]), self._print(expr.shape[1])) + + def _print_Piecewise(self, expr): + "Piecewise function printer" + from sympy.logic.boolalg import ITE, simplify_logic + def print_cond(cond): + """ Problem having an ITE in the cond. """ + if cond.has(ITE): + return self._print(simplify_logic(cond)) + else: + return self._print(cond) + exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args)) + conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args)) + # If [default_value, True] is a (expr, cond) sequence in a Piecewise object + # it will behave the same as passing the 'default' kwarg to select() + # *as long as* it is the last element in expr.args. + # If this is not the case, it may be triggered prematurely. + return '{}({}, {}, default={})'.format( + self._module_format(self._module + '.select'), conds, exprs, + self._print(S.NaN)) + + def _print_Relational(self, expr): + "Relational printer for Equality and Unequality" + op = { + '==' :'equal', + '!=' :'not_equal', + '<' :'less', + '<=' :'less_equal', + '>' :'greater', + '>=' :'greater_equal', + } + if expr.rel_op in op: + lhs = self._print(expr.lhs) + rhs = self._print(expr.rhs) + return '{op}({lhs}, {rhs})'.format(op=self._module_format(self._module + '.'+op[expr.rel_op]), + lhs=lhs, rhs=rhs) + return super()._print_Relational(expr) + + def _print_And(self, expr): + "Logical And printer" + # We have to override LambdaPrinter because it uses Python 'and' keyword. + # If LambdaPrinter didn't define it, we could use StrPrinter's + # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS. + return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_and'), ','.join(self._print(i) for i in expr.args)) + + def _print_Or(self, expr): + "Logical Or printer" + # We have to override LambdaPrinter because it uses Python 'or' keyword. + # If LambdaPrinter didn't define it, we could use StrPrinter's + # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS. + return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_or'), ','.join(self._print(i) for i in expr.args)) + + def _print_Not(self, expr): + "Logical Not printer" + # We have to override LambdaPrinter because it uses Python 'not' keyword. + # If LambdaPrinter didn't define it, we would still have to define our + # own because StrPrinter doesn't define it. + return '{}({})'.format(self._module_format(self._module + '.logical_not'), ','.join(self._print(i) for i in expr.args)) + + def _print_Pow(self, expr, rational=False): + # XXX Workaround for negative integer power error + if expr.exp.is_integer and expr.exp.is_negative: + expr = Pow(expr.base, expr.exp.evalf(), evaluate=False) + return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt') + + def _helper_minimum_maximum(self, op: str, *args): + if len(args) == 0: + raise NotImplementedError(f"Need at least one argument for {op}") + elif len(args) == 1: + return self._print(args[0]) + _reduce = self._module_format('functools.reduce') + s_args = [self._print(arg) for arg in args] + return f"{_reduce}({op}, [{', '.join(s_args)}])" + + def _print_Min(self, expr): + return self._print_minimum(expr) + + def _print_amin(self, expr): + return '{}({}, axis={})'.format(self._module_format(self._module + '.amin'), self._print(expr.array), self._print(expr.axis)) + + def _print_minimum(self, expr): + op = self._module_format(self._module + '.minimum') + return self._helper_minimum_maximum(op, *expr.args) + + def _print_Max(self, expr): + return self._print_maximum(expr) + + def _print_amax(self, expr): + return '{}({}, axis={})'.format(self._module_format(self._module + '.amax'), self._print(expr.array), self._print(expr.axis)) + + def _print_maximum(self, expr): + op = self._module_format(self._module + '.maximum') + return self._helper_minimum_maximum(op, *expr.args) + + def _print_arg(self, expr): + return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0])) + + def _print_im(self, expr): + return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0])) + + def _print_Mod(self, expr): + return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join( + (self._print(arg) for arg in expr.args))) + + def _print_re(self, expr): + return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0])) + + def _print_sinc(self, expr): + return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi)) + + def _print_MatrixBase(self, expr): + if 0 in expr.shape: + func = self._module_format(f'{self._module}.{self._zeros}') + return f"{func}({self._print(expr.shape)})" + func = self.known_functions.get(expr.__class__.__name__, None) + if func is None: + func = self._module_format(f'{self._module}.array') + return "%s(%s)" % (func, self._print(expr.tolist())) + + def _print_Identity(self, expr): + shape = expr.shape + if all(dim.is_Integer for dim in shape): + return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0])) + else: + raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices") + + def _print_BlockMatrix(self, expr): + return '{}({})'.format(self._module_format(self._module + '.block'), + self._print(expr.args[0].tolist())) + + def _print_NDimArray(self, expr): + if expr.rank() == 0: + func = self._module_format(f'{self._module}.array') + return f"{func}({self._print(expr[()])})" + if 0 in expr.shape: + func = self._module_format(f'{self._module}.{self._zeros}') + return f"{func}({self._print(expr.shape)})" + func = self._module_format(f'{self._module}.array') + return f"{func}({self._print(expr.tolist())})" + + _add = "add" + _einsum = "einsum" + _transpose = "transpose" + _ones = "ones" + _zeros = "zeros" + + _print_lowergamma = CodePrinter._print_not_supported + _print_uppergamma = CodePrinter._print_not_supported + _print_fresnelc = CodePrinter._print_not_supported + _print_fresnels = CodePrinter._print_not_supported + +for func in _numpy_known_functions: + setattr(NumPyPrinter, f'_print_{func}', _print_known_func) + +for const in _numpy_known_constants: + setattr(NumPyPrinter, f'_print_{const}', _print_known_const) + + +_known_functions_scipy_special = { + 'Ei': 'expi', + 'erf': 'erf', + 'erfc': 'erfc', + 'besselj': 'jv', + 'bessely': 'yv', + 'besseli': 'iv', + 'besselk': 'kv', + 'cosm1': 'cosm1', + 'powm1': 'powm1', + 'factorial': 'factorial', + 'gamma': 'gamma', + 'loggamma': 'gammaln', + 'digamma': 'psi', + 'polygamma': 'polygamma', + 'RisingFactorial': 'poch', + 'jacobi': 'eval_jacobi', + 'gegenbauer': 'eval_gegenbauer', + 'chebyshevt': 'eval_chebyt', + 'chebyshevu': 'eval_chebyu', + 'legendre': 'eval_legendre', + 'hermite': 'eval_hermite', + 'laguerre': 'eval_laguerre', + 'assoc_laguerre': 'eval_genlaguerre', + 'beta': 'beta', + 'LambertW' : 'lambertw', +} + +_known_constants_scipy_constants = { + 'GoldenRatio': 'golden_ratio', + 'Pi': 'pi', +} +_scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()} +_scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()} + +class SciPyPrinter(NumPyPrinter): + + _kf = {**NumPyPrinter._kf, **_scipy_known_functions} + _kc = {**NumPyPrinter._kc, **_scipy_known_constants} + + def __init__(self, settings=None): + super().__init__(settings=settings) + self.language = "Python with SciPy and NumPy" + + def _print_SparseRepMatrix(self, expr): + i, j, data = [], [], [] + for (r, c), v in expr.todok().items(): + i.append(r) + j.append(c) + data.append(v) + + return "{name}(({data}, ({i}, {j})), shape={shape})".format( + name=self._module_format('scipy.sparse.coo_matrix'), + data=data, i=i, j=j, shape=expr.shape + ) + + _print_ImmutableSparseMatrix = _print_SparseRepMatrix + + # SciPy's lpmv has a different order of arguments from assoc_legendre + def _print_assoc_legendre(self, expr): + return "{0}({2}, {1}, {3})".format( + self._module_format('scipy.special.lpmv'), + self._print(expr.args[0]), + self._print(expr.args[1]), + self._print(expr.args[2])) + + def _print_lowergamma(self, expr): + return "{0}({2})*{1}({2}, {3})".format( + self._module_format('scipy.special.gamma'), + self._module_format('scipy.special.gammainc'), + self._print(expr.args[0]), + self._print(expr.args[1])) + + def _print_uppergamma(self, expr): + return "{0}({2})*{1}({2}, {3})".format( + self._module_format('scipy.special.gamma'), + self._module_format('scipy.special.gammaincc'), + self._print(expr.args[0]), + self._print(expr.args[1])) + + def _print_betainc(self, expr): + betainc = self._module_format('scipy.special.betainc') + beta = self._module_format('scipy.special.beta') + args = [self._print(arg) for arg in expr.args] + return f"({betainc}({args[0]}, {args[1]}, {args[3]}) - {betainc}({args[0]}, {args[1]}, {args[2]})) \ + * {beta}({args[0]}, {args[1]})" + + def _print_betainc_regularized(self, expr): + return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format( + self._module_format('scipy.special.betainc'), + self._print(expr.args[0]), + self._print(expr.args[1]), + self._print(expr.args[2]), + self._print(expr.args[3])) + + def _print_fresnels(self, expr): + return "{}({})[0]".format( + self._module_format("scipy.special.fresnel"), + self._print(expr.args[0])) + + def _print_fresnelc(self, expr): + return "{}({})[1]".format( + self._module_format("scipy.special.fresnel"), + self._print(expr.args[0])) + + def _print_airyai(self, expr): + return "{}({})[0]".format( + self._module_format("scipy.special.airy"), + self._print(expr.args[0])) + + def _print_airyaiprime(self, expr): + return "{}({})[1]".format( + self._module_format("scipy.special.airy"), + self._print(expr.args[0])) + + def _print_airybi(self, expr): + return "{}({})[2]".format( + self._module_format("scipy.special.airy"), + self._print(expr.args[0])) + + def _print_airybiprime(self, expr): + return "{}({})[3]".format( + self._module_format("scipy.special.airy"), + self._print(expr.args[0])) + + def _print_bernoulli(self, expr): + # scipy's bernoulli is inconsistent with SymPy's so rewrite + return self._print(expr._eval_rewrite_as_zeta(*expr.args)) + + def _print_harmonic(self, expr): + return self._print(expr._eval_rewrite_as_zeta(*expr.args)) + + def _print_Integral(self, e): + integration_vars, limits = _unpack_integral_limits(e) + + if len(limits) == 1: + # nicer (but not necessary) to prefer quad over nquad for 1D case + module_str = self._module_format("scipy.integrate.quad") + limit_str = "%s, %s" % tuple(map(self._print, limits[0])) + else: + module_str = self._module_format("scipy.integrate.nquad") + limit_str = "({})".format(", ".join( + "(%s, %s)" % tuple(map(self._print, l)) for l in limits)) + + return "{}(lambda {}: {}, {})[0]".format( + module_str, + ", ".join(map(self._print, integration_vars)), + self._print(e.args[0]), + limit_str) + + def _print_Si(self, expr): + return "{}({})[0]".format( + self._module_format("scipy.special.sici"), + self._print(expr.args[0])) + + def _print_Ci(self, expr): + return "{}({})[1]".format( + self._module_format("scipy.special.sici"), + self._print(expr.args[0])) + +for func in _scipy_known_functions: + setattr(SciPyPrinter, f'_print_{func}', _print_known_func) + +for const in _scipy_known_constants: + setattr(SciPyPrinter, f'_print_{const}', _print_known_const) + + +_cupy_known_functions = {k : "cupy." + v for k, v in _known_functions_numpy.items()} +_cupy_known_constants = {k : "cupy." + v for k, v in _known_constants_numpy.items()} + +class CuPyPrinter(NumPyPrinter): + """ + CuPy printer which handles vectorized piecewise functions, + logical operators, etc. + """ + + _module = 'cupy' + _kf = _cupy_known_functions + _kc = _cupy_known_constants + + def __init__(self, settings=None): + super().__init__(settings=settings) + +for func in _cupy_known_functions: + setattr(CuPyPrinter, f'_print_{func}', _print_known_func) + +for const in _cupy_known_constants: + setattr(CuPyPrinter, f'_print_{const}', _print_known_const) + + +_jax_known_functions = {k: 'jax.numpy.' + v for k, v in _known_functions_numpy.items()} +_jax_known_constants = {k: 'jax.numpy.' + v for k, v in _known_constants_numpy.items()} + +class JaxPrinter(NumPyPrinter): + """ + JAX printer which handles vectorized piecewise functions, + logical operators, etc. + """ + _module = "jax.numpy" + + _kf = _jax_known_functions + _kc = _jax_known_constants + + def __init__(self, settings=None): + super().__init__(settings=settings) + self.printmethod = '_jaxcode' + + # These need specific override to allow for the lack of "jax.numpy.reduce" + def _print_And(self, expr): + "Logical And printer" + return "{}({}.asarray([{}]), axis=0)".format( + self._module_format(self._module + ".all"), + self._module_format(self._module), + ",".join(self._print(i) for i in expr.args), + ) + + def _print_Or(self, expr): + "Logical Or printer" + return "{}({}.asarray([{}]), axis=0)".format( + self._module_format(self._module + ".any"), + self._module_format(self._module), + ",".join(self._print(i) for i in expr.args), + ) + +for func in _jax_known_functions: + setattr(JaxPrinter, f'_print_{func}', _print_known_func) + +for const in _jax_known_constants: + setattr(JaxPrinter, f'_print_{const}', _print_known_const) diff --git a/phivenv/Lib/site-packages/sympy/printing/octave.py b/phivenv/Lib/site-packages/sympy/printing/octave.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf2d6a5754668d7a95ef5dc7b27b4864756a9e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/octave.py @@ -0,0 +1,711 @@ +""" +Octave (and Matlab) code printer + +The `OctaveCodePrinter` converts SymPy expressions into Octave expressions. +It uses a subset of the Octave language for Matlab compatibility. + +A complete code generator, which uses `octave_code` extensively, can be found +in `sympy.utilities.codegen`. The `codegen` module can be used to generate +complete source code files. + +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import Mul, Pow, S, Rational +from sympy.core.mul import _keep_coeff +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from re import search + +# List of known functions. First, those that have the same name in +# SymPy and Octave. This is almost certainly incomplete! +known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc", + "asin", "acos", "acot", "atan", "atan2", "asec", "acsc", + "sinh", "cosh", "tanh", "coth", "csch", "sech", + "asinh", "acosh", "atanh", "acoth", "asech", "acsch", + "erfc", "erfi", "erf", "erfinv", "erfcinv", + "besseli", "besselj", "besselk", "bessely", + "bernoulli", "beta", "euler", "exp", "factorial", "floor", + "fresnelc", "fresnels", "gamma", "harmonic", "log", + "polylog", "sign", "zeta", "legendre"] + +# These functions have different names ("SymPy": "Octave"), more +# generally a mapping to (argument_conditions, octave_function). +known_fcns_src2 = { + "Abs": "abs", + "arg": "angle", # arg/angle ok in Octave but only angle in Matlab + "binomial": "bincoeff", + "ceiling": "ceil", + "chebyshevu": "chebyshevU", + "chebyshevt": "chebyshevT", + "Chi": "coshint", + "Ci": "cosint", + "conjugate": "conj", + "DiracDelta": "dirac", + "Heaviside": "heaviside", + "im": "imag", + "laguerre": "laguerreL", + "LambertW": "lambertw", + "li": "logint", + "loggamma": "gammaln", + "Max": "max", + "Min": "min", + "Mod": "mod", + "polygamma": "psi", + "re": "real", + "RisingFactorial": "pochhammer", + "Shi": "sinhint", + "Si": "sinint", +} + + +class OctaveCodePrinter(CodePrinter): + """ + A printer to convert expressions to strings of Octave/Matlab code. + """ + printmethod = "_octave" + language = "Octave" + + _operators = { + 'and': '&', + 'or': '|', + 'not': '~', + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'inline': True, + }) + # Note: contract is for expressing tensors as loops (if True), or just + # assignment (if False). FIXME: this should be looked a more carefully + # for Octave. + + + def __init__(self, settings={}): + super().__init__(settings) + self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1)) + self.known_functions.update(dict(known_fcns_src2)) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + + def _rate_index_position(self, p): + return p*5 + + + def _get_statement(self, codestring): + return "%s;" % codestring + + + def _get_comment(self, text): + return "% {}".format(text) + + + def _declare_number_const(self, name, value): + return "{} = {};".format(name, value) + + + def _format_code(self, lines): + return self.indent_code(lines) + + + def _traverse_matrix_indices(self, mat): + # Octave uses Fortran order (column-major) + rows, cols = mat.shape + return ((i, j) for j in range(cols) for i in range(rows)) + + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + for i in indices: + # Octave arrays start at 1 and end at dimension + var, start, stop = map(self._print, + [i.label, i.lower + 1, i.upper + 1]) + open_lines.append("for %s = %s:%s" % (var, start, stop)) + close_lines.append("end") + return open_lines, close_lines + + + def _print_Mul(self, expr): + # print complex numbers nicely in Octave + if (expr.is_number and expr.is_imaginary and + (S.ImaginaryUnit*expr).is_Integer): + return "%si" % self._print(-S.ImaginaryUnit*expr) + + # cribbed from str.py + prec = precedence(expr) + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = "-" + else: + sign = "" + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + pow_paren = [] # Will collect all pow with more than one base element and exp = -1 + + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + # Gather args for numerator/denominator + for item in args: + if (item.is_commutative and item.is_Pow and item.exp.is_Rational + and item.exp.is_negative): + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160 + pow_paren.append(item) + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity: + if item.p != 1: + a.append(Rational(item.p)) + if item.q != 1: + b.append(Rational(item.q)) + else: + a.append(item) + + a = a or [S.One] + + a_str = [self.parenthesize(x, prec) for x in a] + b_str = [self.parenthesize(x, prec) for x in b] + + # To parenthesize Pow with exp = -1 and having more than one Symbol + for item in pow_paren: + if item.base in b: + b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)] + + # from here it differs from str.py to deal with "*" and ".*" + def multjoin(a, a_str): + # here we probably are assuming the constants will come first + r = a_str[0] + for i in range(1, len(a)): + mulsym = '*' if a[i-1].is_number else '.*' + r = r + mulsym + a_str[i] + return r + + if not b: + return sign + multjoin(a, a_str) + elif len(b) == 1: + divsym = '/' if b[0].is_number else './' + return sign + multjoin(a, a_str) + divsym + b_str[0] + else: + divsym = '/' if all(bi.is_number for bi in b) else './' + return (sign + multjoin(a, a_str) + + divsym + "(%s)" % multjoin(b, b_str)) + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Pow(self, expr): + powsymbol = '^' if all(x.is_number for x in expr.args) else '.^' + + PREC = precedence(expr) + + if equal_valued(expr.exp, 0.5): + return "sqrt(%s)" % self._print(expr.base) + + if expr.is_commutative: + if equal_valued(expr.exp, -0.5): + sym = '/' if expr.base.is_number else './' + return "1" + sym + "sqrt(%s)" % self._print(expr.base) + if equal_valued(expr.exp, -1): + sym = '/' if expr.base.is_number else './' + return "1" + sym + "%s" % self.parenthesize(expr.base, PREC) + + return '%s%s%s' % (self.parenthesize(expr.base, PREC), powsymbol, + self.parenthesize(expr.exp, PREC)) + + + def _print_MatPow(self, expr): + PREC = precedence(expr) + return '%s^%s' % (self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC)) + + def _print_MatrixSolve(self, expr): + PREC = precedence(expr) + return "%s \\ %s" % (self.parenthesize(expr.matrix, PREC), + self.parenthesize(expr.vector, PREC)) + + def _print_Pi(self, expr): + return 'pi' + + + def _print_ImaginaryUnit(self, expr): + return "1i" + + + def _print_Exp1(self, expr): + return "exp(1)" + + + def _print_GoldenRatio(self, expr): + # FIXME: how to do better, e.g., for octave_code(2*GoldenRatio)? + #return self._print((1+sqrt(S(5)))/2) + return "(1+sqrt(5))/2" + + + def _print_Assignment(self, expr): + from sympy.codegen.ast import Assignment + from sympy.functions.elementary.piecewise import Piecewise + from sympy.tensor.indexed import IndexedBase + # Copied from codeprinter, but remove special MatrixSymbol treatment + lhs = expr.lhs + rhs = expr.rhs + # We special case assignments that take multiple lines + if not self._settings["inline"] and isinstance(expr.rhs, Piecewise): + # Here we modify Piecewise so each expression is now + # an Assignment, and then continue on the print. + expressions = [] + conditions = [] + for (e, c) in rhs.args: + expressions.append(Assignment(lhs, e)) + conditions.append(c) + temp = Piecewise(*zip(expressions, conditions)) + return self._print(temp) + if self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + + def _print_Infinity(self, expr): + return 'inf' + + + def _print_NegativeInfinity(self, expr): + return '-inf' + + + def _print_NaN(self, expr): + return 'NaN' + + + def _print_list(self, expr): + return '{' + ', '.join(self._print(a) for a in expr) + '}' + _print_tuple = _print_list + _print_Tuple = _print_list + _print_List = _print_list + + + def _print_BooleanTrue(self, expr): + return "true" + + + def _print_BooleanFalse(self, expr): + return "false" + + + def _print_bool(self, expr): + return str(expr).lower() + + + # Could generate quadrature code for definite Integrals? + #_print_Integral = _print_not_supported + + + def _print_MatrixBase(self, A): + # Handle zero dimensions: + if (A.rows, A.cols) == (0, 0): + return '[]' + elif S.Zero in A.shape: + return 'zeros(%s, %s)' % (A.rows, A.cols) + elif (A.rows, A.cols) == (1, 1): + # Octave does not distinguish between scalars and 1x1 matrices + return self._print(A[0, 0]) + return "[%s]" % "; ".join(" ".join([self._print(a) for a in A[r, :]]) + for r in range(A.rows)) + + + def _print_SparseRepMatrix(self, A): + from sympy.matrices import Matrix + L = A.col_list() + # make row vectors of the indices and entries + I = Matrix([[k[0] + 1 for k in L]]) + J = Matrix([[k[1] + 1 for k in L]]) + AIJ = Matrix([[k[2] for k in L]]) + return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J), + self._print(AIJ), A.rows, A.cols) + + + def _print_MatrixElement(self, expr): + return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \ + + '(%s, %s)' % (expr.i + 1, expr.j + 1) + + + def _print_MatrixSlice(self, expr): + def strslice(x, lim): + l = x[0] + 1 + h = x[1] + step = x[2] + lstr = self._print(l) + hstr = 'end' if h == lim else self._print(h) + if step == 1: + if l == 1 and h == lim: + return ':' + if l == h: + return lstr + else: + return lstr + ':' + hstr + else: + return ':'.join((lstr, self._print(step), hstr)) + return (self._print(expr.parent) + '(' + + strslice(expr.rowslice, expr.parent.shape[0]) + ', ' + + strslice(expr.colslice, expr.parent.shape[1]) + ')') + + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds)) + + + def _print_KroneckerDelta(self, expr): + prec = PRECEDENCE["Pow"] + return "double(%s == %s)" % tuple(self.parenthesize(x, prec) + for x in expr.args) + + def _print_HadamardProduct(self, expr): + return '.*'.join([self.parenthesize(arg, precedence(expr)) + for arg in expr.args]) + + def _print_HadamardPower(self, expr): + PREC = precedence(expr) + return '.**'.join([ + self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC) + ]) + + def _print_Identity(self, expr): + shape = expr.shape + if len(shape) == 2 and shape[0] == shape[1]: + shape = [shape[0]] + s = ", ".join(self._print(n) for n in shape) + return "eye(" + s + ")" + + def _print_lowergamma(self, expr): + # Octave implements regularized incomplete gamma function + return "(gammainc({1}, {0}).*gamma({0}))".format( + self._print(expr.args[0]), self._print(expr.args[1])) + + + def _print_uppergamma(self, expr): + return "(gammainc({1}, {0}, 'upper').*gamma({0}))".format( + self._print(expr.args[0]), self._print(expr.args[1])) + + + def _print_sinc(self, expr): + #Note: Divide by pi because Octave implements normalized sinc function. + return "sinc(%s)" % self._print(expr.args[0]/S.Pi) + + + def _print_hankel1(self, expr): + return "besselh(%s, 1, %s)" % (self._print(expr.order), + self._print(expr.argument)) + + + def _print_hankel2(self, expr): + return "besselh(%s, 2, %s)" % (self._print(expr.order), + self._print(expr.argument)) + + + # Note: as of 2015, Octave doesn't have spherical Bessel functions + def _print_jn(self, expr): + from sympy.functions import sqrt, besselj + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x) + return self._print(expr2) + + + def _print_yn(self, expr): + from sympy.functions import sqrt, bessely + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x) + return self._print(expr2) + + + def _print_airyai(self, expr): + return "airy(0, %s)" % self._print(expr.args[0]) + + + def _print_airyaiprime(self, expr): + return "airy(1, %s)" % self._print(expr.args[0]) + + + def _print_airybi(self, expr): + return "airy(2, %s)" % self._print(expr.args[0]) + + + def _print_airybiprime(self, expr): + return "airy(3, %s)" % self._print(expr.args[0]) + + + def _print_expint(self, expr): + mu, x = expr.args + if mu != 1: + return self._print_not_supported(expr) + return "expint(%s)" % self._print(x) + + + def _one_or_two_reversed_args(self, expr): + assert len(expr.args) <= 2 + return '{name}({args})'.format( + name=self.known_functions[expr.__class__.__name__], + args=", ".join([self._print(x) for x in reversed(expr.args)]) + ) + + + _print_DiracDelta = _print_LambertW = _one_or_two_reversed_args + + + def _nested_binary_math_func(self, expr): + return '{name}({arg1}, {arg2})'.format( + name=self.known_functions[expr.__class__.__name__], + arg1=self._print(expr.args[0]), + arg2=self._print(expr.func(*expr.args[1:])) + ) + + _print_Max = _print_Min = _nested_binary_math_func + + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if self._settings["inline"]: + # Express each (cond, expr) pair in a nested Horner form: + # (condition) .* (expr) + (not cond) .* () + # Expressions that result in multiple statements won't work here. + ecpairs = ["({0}).*({1}) + (~({0})).*(".format + (self._print(c), self._print(e)) + for e, c in expr.args[:-1]] + elast = "%s" % self._print(expr.args[-1].expr) + pw = " ...\n".join(ecpairs) + elast + ")"*len(ecpairs) + # Note: current need these outer brackets for 2*pw. Would be + # nicer to teach parenthesize() to do this for us when needed! + return "(" + pw + ")" + else: + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s)" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else") + else: + lines.append("elseif (%s)" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + if i == len(expr.args) - 1: + lines.append("end") + return "\n".join(lines) + + + def _print_zeta(self, expr): + if len(expr.args) == 1: + return "zeta(%s)" % self._print(expr.args[0]) + else: + # Matlab two argument zeta is not equivalent to SymPy's + return self._print_not_supported(expr) + + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + # code mostly copied from ccode + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ') + dec_regex = ('^end$', '^elseif ', '^else$') + + # pre-strip left-space from the code + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(search(re, line) for re in inc_regex)) + for line in code ] + decrease = [ int(any(search(re, line) for re in dec_regex)) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + +def octave_code(expr, assign_to=None, **settings): + r"""Converts `expr` to a string of Octave (or Matlab) code. + + The string uses a subset of the Octave language for Matlab compatibility. + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This can be helpful for + expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=16]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, cfunction_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + inline: bool, optional + If True, we try to create single-statement code instead of multiple + statements. [default=True]. + + Examples + ======== + + >>> from sympy import octave_code, symbols, sin, pi + >>> x = symbols('x') + >>> octave_code(sin(x).series(x).removeO()) + 'x.^5/120 - x.^3/6 + x' + + >>> from sympy import Rational, ceiling + >>> x, y, tau = symbols("x, y, tau") + >>> octave_code((2*tau)**Rational(7, 2)) + '8*sqrt(2)*tau.^(7/2)' + + Note that element-wise (Hadamard) operations are used by default between + symbols. This is because its very common in Octave to write "vectorized" + code. It is harmless if the values are scalars. + + >>> octave_code(sin(pi*x*y), assign_to="s") + 's = sin(pi*x.*y);' + + If you need a matrix product "*" or matrix power "^", you can specify the + symbol as a ``MatrixSymbol``. + + >>> from sympy import Symbol, MatrixSymbol + >>> n = Symbol('n', integer=True, positive=True) + >>> A = MatrixSymbol('A', n, n) + >>> octave_code(3*pi*A**3) + '(3*pi)*A^3' + + This class uses several rules to decide which symbol to use a product. + Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*". + A HadamardProduct can be used to specify componentwise multiplication ".*" + of two MatrixSymbols. There is currently there is no easy way to specify + scalar symbols, so sometimes the code might have some minor cosmetic + issues. For example, suppose x and y are scalars and A is a Matrix, then + while a human programmer might write "(x^2*y)*A^3", we generate: + + >>> octave_code(x**2*y*A**3) + '(x.^2.*y)*A^3' + + Matrices are supported using Octave inline notation. When using + ``assign_to`` with matrices, the name can be specified either as a string + or as a ``MatrixSymbol``. The dimensions must align in the latter case. + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([[x**2, sin(x), ceiling(x)]]) + >>> octave_code(mat, assign_to='A') + 'A = [x.^2 sin(x) ceil(x)];' + + ``Piecewise`` expressions are implemented with logical masking by default. + Alternatively, you can pass "inline=False" to use if-else conditionals. + Note that if the ``Piecewise`` lacks a default term, represented by + ``(expr, True)`` then an error will be thrown. This is to prevent + generating an expression that may not evaluate to anything. + + >>> from sympy import Piecewise + >>> pw = Piecewise((x + 1, x > 0), (x, True)) + >>> octave_code(pw, assign_to=tau) + 'tau = ((x > 0).*(x + 1) + (~(x > 0)).*(x));' + + Note that any expression that can be generated normally can also exist + inside a Matrix: + + >>> mat = Matrix([[x**2, pw, sin(x)]]) + >>> octave_code(mat, assign_to='A') + 'A = [x.^2 ((x > 0).*(x + 1) + (~(x > 0)).*(x)) sin(x)];' + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e., [(argument_test, + cfunction_string)]. This can be used to call a custom Octave function. + + >>> from sympy import Function + >>> f = Function('f') + >>> g = Function('g') + >>> custom_functions = { + ... "f": "existing_octave_fcn", + ... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"), + ... (lambda x: not x.is_Matrix, "my_fcn")] + ... } + >>> mat = Matrix([[1, x]]) + >>> octave_code(f(x) + g(x) + g(mat), user_functions=custom_functions) + 'existing_octave_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])' + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> octave_code(e.rhs, assign_to=e.lhs, contract=False) + 'Dy(i) = (y(i + 1) - y(i))./(t(i + 1) - t(i));' + """ + return OctaveCodePrinter(settings).doprint(expr, assign_to) + + +def print_octave_code(expr, **settings): + """Prints the Octave (or Matlab) representation of the given expression. + + See `octave_code` for the meaning of the optional arguments. + """ + print(octave_code(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/precedence.py b/phivenv/Lib/site-packages/sympy/printing/precedence.py new file mode 100644 index 0000000000000000000000000000000000000000..d22d5746aeee51bddcf273d4575c30c3c27db71a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/precedence.py @@ -0,0 +1,180 @@ +"""A module providing information about the necessity of brackets""" + + +# Default precedence values for some basic types +PRECEDENCE = { + "Lambda": 1, + "Xor": 10, + "Or": 20, + "And": 30, + "Relational": 35, + "Add": 40, + "Mul": 50, + "Pow": 60, + "Func": 70, + "Not": 100, + "Atom": 1000, + "BitwiseOr": 36, + "BitwiseXor": 37, + "BitwiseAnd": 38 +} + +# A dictionary assigning precedence values to certain classes. These values are +# treated like they were inherited, so not every single class has to be named +# here. +# Do not use this with printers other than StrPrinter +PRECEDENCE_VALUES = { + "Equivalent": PRECEDENCE["Xor"], + "Xor": PRECEDENCE["Xor"], + "Implies": PRECEDENCE["Xor"], + "Or": PRECEDENCE["Or"], + "And": PRECEDENCE["And"], + "Add": PRECEDENCE["Add"], + "Pow": PRECEDENCE["Pow"], + "Relational": PRECEDENCE["Relational"], + "Sub": PRECEDENCE["Add"], + "Not": PRECEDENCE["Not"], + "Function" : PRECEDENCE["Func"], + "NegativeInfinity": PRECEDENCE["Add"], + "MatAdd": PRECEDENCE["Add"], + "MatPow": PRECEDENCE["Pow"], + "MatrixSolve": PRECEDENCE["Mul"], + "Mod": PRECEDENCE["Mul"], + "TensAdd": PRECEDENCE["Add"], + # As soon as `TensMul` is a subclass of `Mul`, remove this: + "TensMul": PRECEDENCE["Mul"], + "HadamardProduct": PRECEDENCE["Mul"], + "HadamardPower": PRECEDENCE["Pow"], + "KroneckerProduct": PRECEDENCE["Mul"], + "Equality": PRECEDENCE["Mul"], + "Unequality": PRECEDENCE["Mul"], +} + +# Sometimes it's not enough to assign a fixed precedence value to a +# class. Then a function can be inserted in this dictionary that takes +# an instance of this class as argument and returns the appropriate +# precedence value. + +# Precedence functions + + +def precedence_Mul(item): + from sympy.core.function import Function + if any(hasattr(arg, 'precedence') and isinstance(arg, Function) and + arg.precedence < PRECEDENCE["Mul"] for arg in item.args): + return PRECEDENCE["Mul"] + + if item.could_extract_minus_sign(): + return PRECEDENCE["Add"] + return PRECEDENCE["Mul"] + + +def precedence_Rational(item): + if item.p < 0: + return PRECEDENCE["Add"] + return PRECEDENCE["Mul"] + + +def precedence_Integer(item): + if item.p < 0: + return PRECEDENCE["Add"] + return PRECEDENCE["Atom"] + + +def precedence_Float(item): + if item < 0: + return PRECEDENCE["Add"] + return PRECEDENCE["Atom"] + + +def precedence_PolyElement(item): + if item.is_generator: + return PRECEDENCE["Atom"] + elif item.is_ground: + return precedence(item.coeff(1)) + elif item.is_term: + return PRECEDENCE["Mul"] + else: + return PRECEDENCE["Add"] + + +def precedence_FracElement(item): + if item.denom == 1: + return precedence_PolyElement(item.numer) + else: + return PRECEDENCE["Mul"] + + +def precedence_UnevaluatedExpr(item): + return precedence(item.args[0]) - 0.5 + + +PRECEDENCE_FUNCTIONS = { + "Integer": precedence_Integer, + "Mul": precedence_Mul, + "Rational": precedence_Rational, + "Float": precedence_Float, + "PolyElement": precedence_PolyElement, + "FracElement": precedence_FracElement, + "UnevaluatedExpr": precedence_UnevaluatedExpr, +} + + +def precedence(item): + """Returns the precedence of a given object. + + This is the precedence for StrPrinter. + """ + if hasattr(item, "precedence"): + return item.precedence + if not isinstance(item, type): + for i in type(item).mro(): + n = i.__name__ + if n in PRECEDENCE_FUNCTIONS: + return PRECEDENCE_FUNCTIONS[n](item) + elif n in PRECEDENCE_VALUES: + return PRECEDENCE_VALUES[n] + return PRECEDENCE["Atom"] + + +PRECEDENCE_TRADITIONAL = PRECEDENCE.copy() +PRECEDENCE_TRADITIONAL['Integral'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Sum'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Product'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Limit'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Derivative'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['TensorProduct'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Transpose'] = PRECEDENCE["Pow"] +PRECEDENCE_TRADITIONAL['Adjoint'] = PRECEDENCE["Pow"] +PRECEDENCE_TRADITIONAL['Dot'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Cross'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Gradient'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Divergence'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Curl'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Laplacian'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Union'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['Intersection'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['Complement'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['SymmetricDifference'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['ProductSet'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['DotProduct'] = PRECEDENCE_TRADITIONAL['Dot'] + + +def precedence_traditional(item): + """Returns the precedence of a given object according to the + traditional rules of mathematics. + + This is the precedence for the LaTeX and pretty printer. + """ + # Integral, Sum, Product, Limit have the precedence of Mul in LaTeX, + # the precedence of Atom for other printers: + from sympy.core.expr import UnevaluatedExpr + + if isinstance(item, UnevaluatedExpr): + return precedence_traditional(item.args[0]) + + n = item.__class__.__name__ + if n in PRECEDENCE_TRADITIONAL: + return PRECEDENCE_TRADITIONAL[n] + + return precedence(item) diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/__init__.py b/phivenv/Lib/site-packages/sympy/printing/pretty/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbabc649152a3c353a37225d342064634fbb5805 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pretty/__init__.py @@ -0,0 +1,12 @@ +"""ASCII-ART 2D pretty-printer""" + +from .pretty import (pretty, pretty_print, pprint, pprint_use_unicode, + pprint_try_use_unicode, pager_print) + +# if unicode output is available -- let's use it +pprint_try_use_unicode() + +__all__ = [ + 'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode', + 'pprint_try_use_unicode', 'pager_print', +] diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac6afe1d1a4bb998784f8790c99b687924062b50 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/pretty.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/pretty.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a79536943dca8e9d62ef1c7214a7a9b44455ef8c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/pretty.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..741fdf8a07f76ea07e8664226db059259e525ba6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb7b12959c5457f4552f72a8465314006740d48d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/pretty.py b/phivenv/Lib/site-packages/sympy/printing/pretty/pretty.py new file mode 100644 index 0000000000000000000000000000000000000000..b945f009119b24fc95e8452d91359957baba26a8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pretty/pretty.py @@ -0,0 +1,2937 @@ +import itertools + +from sympy.core import S +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import Number, Rational +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.core.sympify import SympifyError +from sympy.printing.conventions import requires_partial +from sympy.printing.precedence import PRECEDENCE, precedence, precedence_traditional +from sympy.printing.printer import Printer, print_function +from sympy.printing.str import sstr +from sympy.utilities.iterables import has_variety +from sympy.utilities.exceptions import sympy_deprecation_warning + +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.pretty.pretty_symbology import hobj, vobj, xobj, \ + xsym, pretty_symbol, pretty_atom, pretty_use_unicode, greek_unicode, U, \ + pretty_try_use_unicode, annotated, is_subscriptable_in_unicode, center_pad, root as nth_root + +# rename for usage from outside +pprint_use_unicode = pretty_use_unicode +pprint_try_use_unicode = pretty_try_use_unicode + + +class PrettyPrinter(Printer): + """Printer, which converts an expression into 2D ASCII-art figure.""" + printmethod = "_pretty" + + _default_settings = { + "order": None, + "full_prec": "auto", + "use_unicode": None, + "wrap_line": True, + "num_columns": None, + "use_unicode_sqrt_char": True, + "root_notation": True, + "mat_symbol_style": "plain", + "imaginary_unit": "i", + "perm_cyclic": True + } + + def __init__(self, settings=None): + Printer.__init__(self, settings) + + if not isinstance(self._settings['imaginary_unit'], str): + raise TypeError("'imaginary_unit' must a string, not {}".format(self._settings['imaginary_unit'])) + elif self._settings['imaginary_unit'] not in ("i", "j"): + raise ValueError("'imaginary_unit' must be either 'i' or 'j', not '{}'".format(self._settings['imaginary_unit'])) + + def emptyPrinter(self, expr): + return prettyForm(str(expr)) + + @property + def _use_unicode(self): + if self._settings['use_unicode']: + return True + else: + return pretty_use_unicode() + + def doprint(self, expr): + return self._print(expr).render(**self._settings) + + # empty op so _print(stringPict) returns the same + def _print_stringPict(self, e): + return e + + def _print_basestring(self, e): + return prettyForm(e) + + def _print_atan2(self, e): + pform = prettyForm(*self._print_seq(e.args).parens()) + pform = prettyForm(*pform.left('atan2')) + return pform + + def _print_Symbol(self, e, bold_name=False): + symb = pretty_symbol(e.name, bold_name) + return prettyForm(symb) + _print_RandomSymbol = _print_Symbol + def _print_MatrixSymbol(self, e): + return self._print_Symbol(e, self._settings['mat_symbol_style'] == "bold") + + def _print_Float(self, e): + # we will use StrPrinter's Float printer, but we need to handle the + # full_prec ourselves, according to the self._print_level + full_prec = self._settings["full_prec"] + if full_prec == "auto": + full_prec = self._print_level == 1 + return prettyForm(sstr(e, full_prec=full_prec)) + + def _print_Cross(self, e): + vec1 = e._expr1 + vec2 = e._expr2 + pform = self._print(vec2) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('MULTIPLICATION SIGN')))) + pform = prettyForm(*pform.left(')')) + pform = prettyForm(*pform.left(self._print(vec1))) + pform = prettyForm(*pform.left('(')) + return pform + + def _print_Curl(self, e): + vec = e._expr + pform = self._print(vec) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('MULTIPLICATION SIGN')))) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Divergence(self, e): + vec = e._expr + pform = self._print(vec) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('DOT OPERATOR')))) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Dot(self, e): + vec1 = e._expr1 + vec2 = e._expr2 + pform = self._print(vec2) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('DOT OPERATOR')))) + pform = prettyForm(*pform.left(')')) + pform = prettyForm(*pform.left(self._print(vec1))) + pform = prettyForm(*pform.left('(')) + return pform + + def _print_Gradient(self, e): + func = e._expr + pform = self._print(func) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Laplacian(self, e): + func = e._expr + pform = self._print(func) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('INCREMENT')))) + return pform + + def _print_Atom(self, e): + try: + # print atoms like Exp1 or Pi + return prettyForm(pretty_atom(e.__class__.__name__, printer=self)) + except KeyError: + return self.emptyPrinter(e) + + # Infinity inherits from Number, so we have to override _print_XXX order + _print_Infinity = _print_Atom + _print_NegativeInfinity = _print_Atom + _print_EmptySet = _print_Atom + _print_Naturals = _print_Atom + _print_Naturals0 = _print_Atom + _print_Integers = _print_Atom + _print_Rationals = _print_Atom + _print_Complexes = _print_Atom + + _print_EmptySequence = _print_Atom + + def _print_Reals(self, e): + if self._use_unicode: + return self._print_Atom(e) + else: + inf_list = ['-oo', 'oo'] + return self._print_seq(inf_list, '(', ')') + + def _print_subfactorial(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('!')) + return pform + + def _print_factorial(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right('!')) + return pform + + def _print_factorial2(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right('!!')) + return pform + + def _print_binomial(self, e): + n, k = e.args + + n_pform = self._print(n) + k_pform = self._print(k) + + bar = ' '*max(n_pform.width(), k_pform.width()) + + pform = prettyForm(*k_pform.above(bar)) + pform = prettyForm(*pform.above(n_pform)) + pform = prettyForm(*pform.parens('(', ')')) + + pform.baseline = (pform.baseline + 1)//2 + + return pform + + def _print_Relational(self, e): + op = prettyForm(' ' + xsym(e.rel_op) + ' ') + + l = self._print(e.lhs) + r = self._print(e.rhs) + pform = prettyForm(*stringPict.next(l, op, r), binding=prettyForm.OPEN) + return pform + + def _print_Not(self, e): + from sympy.logic.boolalg import (Equivalent, Implies) + if self._use_unicode: + arg = e.args[0] + pform = self._print(arg) + if isinstance(arg, Equivalent): + return self._print_Equivalent(arg, altchar=pretty_atom('NotEquiv')) + if isinstance(arg, Implies): + return self._print_Implies(arg, altchar=pretty_atom('NotArrow')) + + if arg.is_Boolean and not arg.is_Not: + pform = prettyForm(*pform.parens()) + + return prettyForm(*pform.left(pretty_atom('Not'))) + else: + return self._print_Function(e) + + def __print_Boolean(self, e, char, sort=True): + args = e.args + if sort: + args = sorted(e.args, key=default_sort_key) + arg = args[0] + pform = self._print(arg) + + if arg.is_Boolean and not arg.is_Not: + pform = prettyForm(*pform.parens()) + + for arg in args[1:]: + pform_arg = self._print(arg) + + if arg.is_Boolean and not arg.is_Not: + pform_arg = prettyForm(*pform_arg.parens()) + + pform = prettyForm(*pform.right(' %s ' % char)) + pform = prettyForm(*pform.right(pform_arg)) + + return pform + + def _print_And(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('And')) + else: + return self._print_Function(e, sort=True) + + def _print_Or(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Or')) + else: + return self._print_Function(e, sort=True) + + def _print_Xor(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom("Xor")) + else: + return self._print_Function(e, sort=True) + + def _print_Nand(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Nand')) + else: + return self._print_Function(e, sort=True) + + def _print_Nor(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Nor')) + else: + return self._print_Function(e, sort=True) + + def _print_Implies(self, e, altchar=None): + if self._use_unicode: + return self.__print_Boolean(e, altchar or pretty_atom('Arrow'), sort=False) + else: + return self._print_Function(e) + + def _print_Equivalent(self, e, altchar=None): + if self._use_unicode: + return self.__print_Boolean(e, altchar or pretty_atom('Equiv')) + else: + return self._print_Function(e, sort=True) + + def _print_conjugate(self, e): + pform = self._print(e.args[0]) + return prettyForm( *pform.above( hobj('_', pform.width())) ) + + def _print_Abs(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('|', '|')) + return pform + + def _print_floor(self, e): + if self._use_unicode: + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('lfloor', 'rfloor')) + return pform + else: + return self._print_Function(e) + + def _print_ceiling(self, e): + if self._use_unicode: + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('lceil', 'rceil')) + return pform + else: + return self._print_Function(e) + + def _print_Derivative(self, deriv): + if requires_partial(deriv.expr) and self._use_unicode: + deriv_symbol = U('PARTIAL DIFFERENTIAL') + else: + deriv_symbol = r'd' + x = None + count_total_deriv = 0 + + for sym, num in reversed(deriv.variable_count): + s = self._print(sym) + ds = prettyForm(*s.left(deriv_symbol)) + count_total_deriv += num + + if (not num.is_Integer) or (num > 1): + ds = ds**prettyForm(str(num)) + + if x is None: + x = ds + else: + x = prettyForm(*x.right(' ')) + x = prettyForm(*x.right(ds)) + + f = prettyForm( + binding=prettyForm.FUNC, *self._print(deriv.expr).parens()) + + pform = prettyForm(deriv_symbol) + + if (count_total_deriv > 1) != False: + pform = pform**prettyForm(str(count_total_deriv)) + + pform = prettyForm(*pform.below(stringPict.LINE, x)) + pform.baseline = pform.baseline + 1 + pform = prettyForm(*stringPict.next(pform, f)) + pform.binding = prettyForm.MUL + + return pform + + def _print_Cycle(self, dc): + from sympy.combinatorics.permutations import Permutation, Cycle + # for Empty Cycle + if dc == Cycle(): + cyc = stringPict('') + return prettyForm(*cyc.parens()) + + dc_list = Permutation(dc.list()).cyclic_form + # for Identity Cycle + if dc_list == []: + cyc = self._print(dc.size - 1) + return prettyForm(*cyc.parens()) + + cyc = stringPict('') + for i in dc_list: + l = self._print(str(tuple(i)).replace(',', '')) + cyc = prettyForm(*cyc.right(l)) + return cyc + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + return self._print_Cycle(Cycle(expr)) + + lower = expr.array_form + upper = list(range(len(lower))) + + result = stringPict('') + first = True + for u, l in zip(upper, lower): + s1 = self._print(u) + s2 = self._print(l) + col = prettyForm(*s1.below(s2)) + if first: + first = False + else: + col = prettyForm(*col.left(" ")) + result = prettyForm(*result.right(col)) + return prettyForm(*result.parens()) + + + def _print_Integral(self, integral): + f = integral.function + + # Add parentheses if arg involves addition of terms and + # create a pretty form for the argument + prettyF = self._print(f) + # XXX generalize parens + if f.is_Add: + prettyF = prettyForm(*prettyF.parens()) + + # dx dy dz ... + arg = prettyF + for x in integral.limits: + prettyArg = self._print(x[0]) + # XXX qparens (parens if needs-parens) + if prettyArg.width() > 1: + prettyArg = prettyForm(*prettyArg.parens()) + + arg = prettyForm(*arg.right(' d', prettyArg)) + + # \int \int \int ... + firstterm = True + s = None + for lim in integral.limits: + # Create bar based on the height of the argument + h = arg.height() + H = h + 2 + + # XXX hack! + ascii_mode = not self._use_unicode + if ascii_mode: + H += 2 + + vint = vobj('int', H) + + # Construct the pretty form with the integral sign and the argument + pform = prettyForm(vint) + pform.baseline = arg.baseline + ( + H - h)//2 # covering the whole argument + + if len(lim) > 1: + # Create pretty forms for endpoints, if definite integral. + # Do not print empty endpoints. + if len(lim) == 2: + prettyA = prettyForm("") + prettyB = self._print(lim[1]) + if len(lim) == 3: + prettyA = self._print(lim[1]) + prettyB = self._print(lim[2]) + + if ascii_mode: # XXX hack + # Add spacing so that endpoint can more easily be + # identified with the correct integral sign + spc = max(1, 3 - prettyB.width()) + prettyB = prettyForm(*prettyB.left(' ' * spc)) + + spc = max(1, 4 - prettyA.width()) + prettyA = prettyForm(*prettyA.right(' ' * spc)) + + pform = prettyForm(*pform.above(prettyB)) + pform = prettyForm(*pform.below(prettyA)) + + if not ascii_mode: # XXX hack + pform = prettyForm(*pform.right(' ')) + + if firstterm: + s = pform # first term + firstterm = False + else: + s = prettyForm(*s.left(pform)) + + pform = prettyForm(*arg.left(s)) + pform.binding = prettyForm.MUL + return pform + + def _print_Product(self, expr): + func = expr.term + pretty_func = self._print(func) + + horizontal_chr = xobj('_', 1) + corner_chr = xobj('_', 1) + vertical_chr = xobj('|', 1) + + if self._use_unicode: + # use unicode corners + horizontal_chr = xobj('-', 1) + corner_chr = xobj('UpTack', 1) + + func_height = pretty_func.height() + + first = True + max_upper = 0 + sign_height = 0 + + for lim in expr.limits: + pretty_lower, pretty_upper = self.__print_SumProduct_Limits(lim) + + width = (func_height + 2) * 5 // 3 - 2 + sign_lines = [horizontal_chr + corner_chr + (horizontal_chr * (width-2)) + corner_chr + horizontal_chr] + for _ in range(func_height + 1): + sign_lines.append(' ' + vertical_chr + (' ' * (width-2)) + vertical_chr + ' ') + + pretty_sign = stringPict('') + pretty_sign = prettyForm(*pretty_sign.stack(*sign_lines)) + + + max_upper = max(max_upper, pretty_upper.height()) + + if first: + sign_height = pretty_sign.height() + + pretty_sign = prettyForm(*pretty_sign.above(pretty_upper)) + pretty_sign = prettyForm(*pretty_sign.below(pretty_lower)) + + if first: + pretty_func.baseline = 0 + first = False + + height = pretty_sign.height() + padding = stringPict('') + padding = prettyForm(*padding.stack(*[' ']*(height - 1))) + pretty_sign = prettyForm(*pretty_sign.right(padding)) + + pretty_func = prettyForm(*pretty_sign.right(pretty_func)) + + pretty_func.baseline = max_upper + sign_height//2 + pretty_func.binding = prettyForm.MUL + return pretty_func + + def __print_SumProduct_Limits(self, lim): + def print_start(lhs, rhs): + op = prettyForm(' ' + xsym("==") + ' ') + l = self._print(lhs) + r = self._print(rhs) + pform = prettyForm(*stringPict.next(l, op, r)) + return pform + + prettyUpper = self._print(lim[2]) + prettyLower = print_start(lim[0], lim[1]) + return prettyLower, prettyUpper + + def _print_Sum(self, expr): + ascii_mode = not self._use_unicode + + def asum(hrequired, lower, upper, use_ascii): + def adjust(s, wid=None, how='<^>'): + if not wid or len(s) > wid: + return s + need = wid - len(s) + if how in ('<^>', "<") or how not in list('<^>'): + return s + ' '*need + half = need//2 + lead = ' '*half + if how == ">": + return " "*need + s + return lead + s + ' '*(need - len(lead)) + + h = max(hrequired, 2) + d = h//2 + w = d + 1 + more = hrequired % 2 + + lines = [] + if use_ascii: + lines.append("_"*(w) + ' ') + lines.append(r"\%s`" % (' '*(w - 1))) + for i in range(1, d): + lines.append('%s\\%s' % (' '*i, ' '*(w - i))) + if more: + lines.append('%s)%s' % (' '*(d), ' '*(w - d))) + for i in reversed(range(1, d)): + lines.append('%s/%s' % (' '*i, ' '*(w - i))) + lines.append("/" + "_"*(w - 1) + ',') + return d, h + more, lines, more + else: + w = w + more + d = d + more + vsum = vobj('sum', 4) + lines.append("_"*(w)) + for i in range(0, d): + lines.append('%s%s%s' % (' '*i, vsum[2], ' '*(w - i - 1))) + for i in reversed(range(0, d)): + lines.append('%s%s%s' % (' '*i, vsum[4], ' '*(w - i - 1))) + lines.append(vsum[8]*(w)) + return d, h + 2*more, lines, more + + f = expr.function + + prettyF = self._print(f) + + if f.is_Add: # add parens + prettyF = prettyForm(*prettyF.parens()) + + H = prettyF.height() + 2 + + # \sum \sum \sum ... + first = True + max_upper = 0 + sign_height = 0 + + for lim in expr.limits: + prettyLower, prettyUpper = self.__print_SumProduct_Limits(lim) + + max_upper = max(max_upper, prettyUpper.height()) + + # Create sum sign based on the height of the argument + d, h, slines, adjustment = asum( + H, prettyLower.width(), prettyUpper.width(), ascii_mode) + prettySign = stringPict('') + prettySign = prettyForm(*prettySign.stack(*slines)) + + if first: + sign_height = prettySign.height() + + prettySign = prettyForm(*prettySign.above(prettyUpper)) + prettySign = prettyForm(*prettySign.below(prettyLower)) + + if first: + # change F baseline so it centers on the sign + prettyF.baseline -= d - (prettyF.height()//2 - + prettyF.baseline) + first = False + + # put padding to the right + pad = stringPict('') + pad = prettyForm(*pad.stack(*[' ']*h)) + prettySign = prettyForm(*prettySign.right(pad)) + # put the present prettyF to the right + prettyF = prettyForm(*prettySign.right(prettyF)) + + # adjust baseline of ascii mode sigma with an odd height so that it is + # exactly through the center + ascii_adjustment = ascii_mode if not adjustment else 0 + prettyF.baseline = max_upper + sign_height//2 + ascii_adjustment + + prettyF.binding = prettyForm.MUL + return prettyF + + def _print_Limit(self, l): + e, z, z0, dir = l.args + + E = self._print(e) + if precedence(e) <= PRECEDENCE["Mul"]: + E = prettyForm(*E.parens('(', ')')) + Lim = prettyForm('lim') + + LimArg = self._print(z) + if self._use_unicode: + LimArg = prettyForm(*LimArg.right(f"{xobj('-', 1)}{pretty_atom('Arrow')}")) + else: + LimArg = prettyForm(*LimArg.right('->')) + LimArg = prettyForm(*LimArg.right(self._print(z0))) + + if str(dir) == '+-' or z0 in (S.Infinity, S.NegativeInfinity): + dir = "" + else: + if self._use_unicode: + dir = pretty_atom('SuperscriptPlus') if str(dir) == "+" else pretty_atom('SuperscriptMinus') + + LimArg = prettyForm(*LimArg.right(self._print(dir))) + + Lim = prettyForm(*Lim.below(LimArg)) + Lim = prettyForm(*Lim.right(E), binding=prettyForm.MUL) + + return Lim + + def _print_matrix_contents(self, e): + """ + This method factors out what is essentially grid printing. + """ + M = e # matrix + Ms = {} # i,j -> pretty(M[i,j]) + for i in range(M.rows): + for j in range(M.cols): + Ms[i, j] = self._print(M[i, j]) + + # h- and v- spacers + hsep = 2 + vsep = 1 + + # max width for columns + maxw = [-1] * M.cols + + for j in range(M.cols): + maxw[j] = max([Ms[i, j].width() for i in range(M.rows)] or [0]) + + # drawing result + D = None + + for i in range(M.rows): + + D_row = None + for j in range(M.cols): + s = Ms[i, j] + + # reshape s to maxw + # XXX this should be generalized, and go to stringPict.reshape ? + assert s.width() <= maxw[j] + + # hcenter it, +0.5 to the right 2 + # ( it's better to align formula starts for say 0 and r ) + # XXX this is not good in all cases -- maybe introduce vbaseline? + left, right = center_pad(s.width(), maxw[j]) + + s = prettyForm(*s.right(right)) + s = prettyForm(*s.left(left)) + + # we don't need vcenter cells -- this is automatically done in + # a pretty way because when their baselines are taking into + # account in .right() + + if D_row is None: + D_row = s # first box in a row + continue + + D_row = prettyForm(*D_row.right(' '*hsep)) # h-spacer + D_row = prettyForm(*D_row.right(s)) + + if D is None: + D = D_row # first row in a picture + continue + + # v-spacer + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + + D = prettyForm(*D.below(D_row)) + + if D is None: + D = prettyForm('') # Empty Matrix + + return D + + def _print_MatrixBase(self, e, lparens='[', rparens=']'): + D = self._print_matrix_contents(e) + D.baseline = D.height()//2 + D = prettyForm(*D.parens(lparens, rparens)) + return D + + def _print_Determinant(self, e): + mat = e.arg + if mat.is_MatrixExpr: + from sympy.matrices.expressions.blockmatrix import BlockMatrix + if isinstance(mat, BlockMatrix): + return self._print_MatrixBase(mat.blocks, lparens='|', rparens='|') + D = self._print(mat) + D.baseline = D.height()//2 + return prettyForm(*D.parens('|', '|')) + else: + return self._print_MatrixBase(mat, lparens='|', rparens='|') + + def _print_TensorProduct(self, expr): + # This should somehow share the code with _print_WedgeProduct: + if self._use_unicode: + circled_times = "\u2297" + else: + circled_times = ".*" + return self._print_seq(expr.args, None, None, circled_times, + parenthesize=lambda x: precedence_traditional(x) <= PRECEDENCE["Mul"]) + + def _print_WedgeProduct(self, expr): + # This should somehow share the code with _print_TensorProduct: + if self._use_unicode: + wedge_symbol = "\u2227" + else: + wedge_symbol = '/\\' + return self._print_seq(expr.args, None, None, wedge_symbol, + parenthesize=lambda x: precedence_traditional(x) <= PRECEDENCE["Mul"]) + + def _print_Trace(self, e): + D = self._print(e.arg) + D = prettyForm(*D.parens('(',')')) + D.baseline = D.height()//2 + D = prettyForm(*D.left('\n'*(0) + 'tr')) + return D + + + def _print_MatrixElement(self, expr): + from sympy.matrices import MatrixSymbol + if (isinstance(expr.parent, MatrixSymbol) + and expr.i.is_number and expr.j.is_number): + return self._print( + Symbol(expr.parent.name + '_%d%d' % (expr.i, expr.j))) + else: + prettyFunc = self._print(expr.parent) + prettyFunc = prettyForm(*prettyFunc.parens()) + prettyIndices = self._print_seq((expr.i, expr.j), delimiter=', ' + ).parens(left='[', right=']')[0] + pform = prettyForm(binding=prettyForm.FUNC, + *stringPict.next(prettyFunc, prettyIndices)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyIndices + + return pform + + + def _print_MatrixSlice(self, m): + # XXX works only for applied functions + from sympy.matrices import MatrixSymbol + prettyFunc = self._print(m.parent) + if not isinstance(m.parent, MatrixSymbol): + prettyFunc = prettyForm(*prettyFunc.parens()) + def ppslice(x, dim): + x = list(x) + if x[2] == 1: + del x[2] + if x[0] == 0: + x[0] = '' + if x[1] == dim: + x[1] = '' + return prettyForm(*self._print_seq(x, delimiter=':')) + prettyArgs = self._print_seq((ppslice(m.rowslice, m.parent.rows), + ppslice(m.colslice, m.parent.cols)), delimiter=', ').parens(left='[', right=']')[0] + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_Transpose(self, expr): + mat = expr.arg + pform = self._print(mat) + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + pform = prettyForm(*pform.parens()) + pform = pform**(prettyForm('T')) + return pform + + def _print_Adjoint(self, expr): + mat = expr.arg + pform = self._print(mat) + if self._use_unicode: + dag = prettyForm(pretty_atom('Dagger')) + else: + dag = prettyForm('+') + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + pform = prettyForm(*pform.parens()) + pform = pform**dag + return pform + + def _print_BlockMatrix(self, B): + if B.blocks.shape == (1, 1): + return self._print(B.blocks[0, 0]) + return self._print(B.blocks) + + def _print_MatAdd(self, expr): + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + coeff = item.as_coeff_mmul()[0] + if S(coeff).could_extract_minus_sign(): + s = prettyForm(*stringPict.next(s, ' ')) + pform = self._print(item) + else: + s = prettyForm(*stringPict.next(s, ' + ')) + s = prettyForm(*stringPict.next(s, pform)) + + return s + + def _print_MatMul(self, expr): + args = list(expr.args) + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.matadd import MatAdd + for i, a in enumerate(args): + if (isinstance(a, (Add, MatAdd, HadamardProduct, KroneckerProduct)) + and len(expr.args) > 1): + args[i] = prettyForm(*self._print(a).parens()) + else: + args[i] = self._print(a) + + return prettyForm.__mul__(*args) + + def _print_Identity(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('IdentityMatrix')) + else: + return prettyForm('I') + + def _print_ZeroMatrix(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('ZeroMatrix')) + else: + return prettyForm('0') + + def _print_OneMatrix(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom("OneMatrix")) + else: + return prettyForm('1') + + def _print_DotProduct(self, expr): + args = list(expr.args) + + for i, a in enumerate(args): + args[i] = self._print(a) + return prettyForm.__mul__(*args) + + def _print_MatPow(self, expr): + pform = self._print(expr.base) + from sympy.matrices import MatrixSymbol + if not isinstance(expr.base, MatrixSymbol) and expr.base.is_MatrixExpr: + pform = prettyForm(*pform.parens()) + pform = pform**(self._print(expr.exp)) + return pform + + def _print_HadamardProduct(self, expr): + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.matadd import MatAdd + from sympy.matrices.expressions.matmul import MatMul + if self._use_unicode: + delim = pretty_atom('Ring') + else: + delim = '.*' + return self._print_seq(expr.args, None, None, delim, + parenthesize=lambda x: isinstance(x, (MatAdd, MatMul, HadamardProduct))) + + def _print_HadamardPower(self, expr): + # from sympy import MatAdd, MatMul + if self._use_unicode: + circ = pretty_atom('Ring') + else: + circ = self._print('.') + pretty_base = self._print(expr.base) + pretty_exp = self._print(expr.exp) + if precedence(expr.exp) < PRECEDENCE["Mul"]: + pretty_exp = prettyForm(*pretty_exp.parens()) + pretty_circ_exp = prettyForm( + binding=prettyForm.LINE, + *stringPict.next(circ, pretty_exp) + ) + return pretty_base**pretty_circ_exp + + def _print_KroneckerProduct(self, expr): + from sympy.matrices.expressions.matadd import MatAdd + from sympy.matrices.expressions.matmul import MatMul + if self._use_unicode: + delim = f" {pretty_atom('TensorProduct')} " + else: + delim = ' x ' + return self._print_seq(expr.args, None, None, delim, + parenthesize=lambda x: isinstance(x, (MatAdd, MatMul))) + + def _print_FunctionMatrix(self, X): + D = self._print(X.lamda.expr) + D = prettyForm(*D.parens('[', ']')) + return D + + def _print_TransferFunction(self, expr): + if not expr.num == 1: + num, den = expr.num, expr.den + res = Mul(num, Pow(den, -1, evaluate=False), evaluate=False) + return self._print_Mul(res) + else: + return self._print(1)/self._print(expr.den) + + def _print_Series(self, expr): + args = list(expr.args) + for i, a in enumerate(expr.args): + args[i] = prettyForm(*self._print(a).parens()) + return prettyForm.__mul__(*args) + + def _print_MIMOSeries(self, expr): + from sympy.physics.control.lti import MIMOParallel + args = list(expr.args) + pretty_args = [] + for a in reversed(args): + if (isinstance(a, MIMOParallel) and len(expr.args) > 1): + expression = self._print(a) + expression.baseline = expression.height()//2 + pretty_args.append(prettyForm(*expression.parens())) + else: + expression = self._print(a) + expression.baseline = expression.height()//2 + pretty_args.append(expression) + return prettyForm.__mul__(*pretty_args) + + def _print_Parallel(self, expr): + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + s = prettyForm(*stringPict.next(s)) + s.baseline = s.height()//2 + s = prettyForm(*stringPict.next(s, ' + ')) + s = prettyForm(*stringPict.next(s, pform)) + return s + + def _print_MIMOParallel(self, expr): + from sympy.physics.control.lti import TransferFunctionMatrix + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + s = prettyForm(*stringPict.next(s)) + s.baseline = s.height()//2 + s = prettyForm(*stringPict.next(s, ' + ')) + if isinstance(item, TransferFunctionMatrix): + s.baseline = s.height() - 1 + s = prettyForm(*stringPict.next(s, pform)) + # s.baseline = s.height()//2 + return s + + def _print_Feedback(self, expr): + from sympy.physics.control import TransferFunction, Series + + num, tf = expr.sys1, TransferFunction(1, 1, expr.var) + num_arg_list = list(num.args) if isinstance(num, Series) else [num] + den_arg_list = list(expr.sys2.args) if \ + isinstance(expr.sys2, Series) else [expr.sys2] + + if isinstance(num, Series) and isinstance(expr.sys2, Series): + den = Series(*num_arg_list, *den_arg_list) + elif isinstance(num, Series) and isinstance(expr.sys2, TransferFunction): + if expr.sys2 == tf: + den = Series(*num_arg_list) + else: + den = Series(*num_arg_list, expr.sys2) + elif isinstance(num, TransferFunction) and isinstance(expr.sys2, Series): + if num == tf: + den = Series(*den_arg_list) + else: + den = Series(num, *den_arg_list) + else: + if num == tf: + den = Series(*den_arg_list) + elif expr.sys2 == tf: + den = Series(*num_arg_list) + else: + den = Series(*num_arg_list, *den_arg_list) + + denom = prettyForm(*stringPict.next(self._print(tf))) + denom.baseline = denom.height()//2 + denom = prettyForm(*stringPict.next(denom, ' + ')) if expr.sign == -1 \ + else prettyForm(*stringPict.next(denom, ' - ')) + denom = prettyForm(*stringPict.next(denom, self._print(den))) + + return self._print(num)/denom + + def _print_MIMOFeedback(self, expr): + from sympy.physics.control import MIMOSeries, TransferFunctionMatrix + + inv_mat = self._print(MIMOSeries(expr.sys2, expr.sys1)) + plant = self._print(expr.sys1) + _feedback = prettyForm(*stringPict.next(inv_mat)) + _feedback = prettyForm(*stringPict.right("I + ", _feedback)) if expr.sign == -1 \ + else prettyForm(*stringPict.right("I - ", _feedback)) + _feedback = prettyForm(*stringPict.parens(_feedback)) + _feedback.baseline = 0 + _feedback = prettyForm(*stringPict.right(_feedback, '-1 ')) + _feedback.baseline = _feedback.height()//2 + _feedback = prettyForm.__mul__(_feedback, prettyForm(" ")) + if isinstance(expr.sys1, TransferFunctionMatrix): + _feedback.baseline = _feedback.height() - 1 + _feedback = prettyForm(*stringPict.next(_feedback, plant)) + return _feedback + + def _print_TransferFunctionMatrix(self, expr): + mat = self._print(expr._expr_mat) + mat.baseline = mat.height() - 1 + subscript = greek_unicode['tau'] if self._use_unicode else r'{t}' + mat = prettyForm(*mat.right(subscript)) + return mat + + def _print_StateSpace(self, expr): + from sympy.matrices.expressions.blockmatrix import BlockMatrix + A = expr._A + B = expr._B + C = expr._C + D = expr._D + mat = BlockMatrix([[A, B], [C, D]]) + return self._print(mat.blocks) + + def _print_BasisDependent(self, expr): + from sympy.vector import Vector + + if not self._use_unicode: + raise NotImplementedError("ASCII pretty printing of BasisDependent is not implemented") + + if expr == expr.zero: + return prettyForm(expr.zero._pretty_form) + o1 = [] + vectstrs = [] + if isinstance(expr, Vector): + items = expr.separate().items() + else: + items = [(0, expr)] + for system, vect in items: + inneritems = list(vect.components.items()) + inneritems.sort(key = lambda x: x[0].__str__()) + for k, v in inneritems: + #if the coef of the basis vector is 1 + #we skip the 1 + if v == 1: + o1.append("" + + k._pretty_form) + #Same for -1 + elif v == -1: + o1.append("(-1) " + + k._pretty_form) + #For a general expr + else: + #We always wrap the measure numbers in + #parentheses + arg_str = self._print( + v).parens()[0] + + o1.append(arg_str + ' ' + k._pretty_form) + vectstrs.append(k._pretty_form) + + #outstr = u("").join(o1) + if o1[0].startswith(" + "): + o1[0] = o1[0][3:] + elif o1[0].startswith(" "): + o1[0] = o1[0][1:] + #Fixing the newlines + lengths = [] + strs = [''] + flag = [] + for i, partstr in enumerate(o1): + flag.append(0) + # XXX: What is this hack? + if '\n' in partstr: + tempstr = partstr + tempstr = tempstr.replace(vectstrs[i], '') + if xobj(')_ext', 1) in tempstr: # If scalar is a fraction + for paren in range(len(tempstr)): + flag[i] = 1 + if tempstr[paren] == xobj(')_ext', 1) and tempstr[paren + 1] == '\n': + # We want to place the vector string after all the right parentheses, because + # otherwise, the vector will be in the middle of the string + tempstr = tempstr[:paren] + xobj(')_ext', 1)\ + + ' ' + vectstrs[i] + tempstr[paren + 1:] + break + elif xobj(')_lower_hook', 1) in tempstr: + # We want to place the vector string after all the right parentheses, because + # otherwise, the vector will be in the middle of the string. For this reason, + # we insert the vector string at the rightmost index. + index = tempstr.rfind(xobj(')_lower_hook', 1)) + if index != -1: # then this character was found in this string + flag[i] = 1 + tempstr = tempstr[:index] + xobj(')_lower_hook', 1)\ + + ' ' + vectstrs[i] + tempstr[index + 1:] + o1[i] = tempstr + + o1 = [x.split('\n') for x in o1] + n_newlines = max(len(x) for x in o1) # Width of part in its pretty form + + if 1 in flag: # If there was a fractional scalar + for i, parts in enumerate(o1): + if len(parts) == 1: # If part has no newline + parts.insert(0, ' ' * (len(parts[0]))) + flag[i] = 1 + + for i, parts in enumerate(o1): + lengths.append(len(parts[flag[i]])) + for j in range(n_newlines): + if j+1 <= len(parts): + if j >= len(strs): + strs.append(' ' * (sum(lengths[:-1]) + + 3*(len(lengths)-1))) + if j == flag[i]: + strs[flag[i]] += parts[flag[i]] + ' + ' + else: + strs[j] += parts[j] + ' '*(lengths[-1] - + len(parts[j])+ + 3) + else: + if j >= len(strs): + strs.append(' ' * (sum(lengths[:-1]) + + 3*(len(lengths)-1))) + strs[j] += ' '*(lengths[-1]+3) + + return prettyForm('\n'.join([s[:-3] for s in strs])) + + def _print_NDimArray(self, expr): + from sympy.matrices.immutable import ImmutableMatrix + + if expr.rank() == 0: + return self._print(expr[()]) + + level_str = [[]] + [[] for i in range(expr.rank())] + shape_ranges = [list(range(i)) for i in expr.shape] + # leave eventual matrix elements unflattened + mat = lambda x: ImmutableMatrix(x, evaluate=False) + for outer_i in itertools.product(*shape_ranges): + level_str[-1].append(expr[outer_i]) + even = True + for back_outer_i in range(expr.rank()-1, -1, -1): + if len(level_str[back_outer_i+1]) < expr.shape[back_outer_i]: + break + if even: + level_str[back_outer_i].append(level_str[back_outer_i+1]) + else: + level_str[back_outer_i].append(mat( + level_str[back_outer_i+1])) + if len(level_str[back_outer_i + 1]) == 1: + level_str[back_outer_i][-1] = mat( + [[level_str[back_outer_i][-1]]]) + even = not even + level_str[back_outer_i+1] = [] + + out_expr = level_str[0][0] + if expr.rank() % 2 == 1: + out_expr = mat([out_expr]) + + return self._print(out_expr) + + def _printer_tensor_indices(self, name, indices, index_map={}): + center = stringPict(name) + top = stringPict(" "*center.width()) + bot = stringPict(" "*center.width()) + + last_valence = None + prev_map = None + + for index in indices: + indpic = self._print(index.args[0]) + if ((index in index_map) or prev_map) and last_valence == index.is_up: + if index.is_up: + top = prettyForm(*stringPict.next(top, ",")) + else: + bot = prettyForm(*stringPict.next(bot, ",")) + if index in index_map: + indpic = prettyForm(*stringPict.next(indpic, "=")) + indpic = prettyForm(*stringPict.next(indpic, self._print(index_map[index]))) + prev_map = True + else: + prev_map = False + if index.is_up: + top = stringPict(*top.right(indpic)) + center = stringPict(*center.right(" "*indpic.width())) + bot = stringPict(*bot.right(" "*indpic.width())) + else: + bot = stringPict(*bot.right(indpic)) + center = stringPict(*center.right(" "*indpic.width())) + top = stringPict(*top.right(" "*indpic.width())) + last_valence = index.is_up + + pict = prettyForm(*center.above(top)) + pict = prettyForm(*pict.below(bot)) + return pict + + def _print_Tensor(self, expr): + name = expr.args[0].name + indices = expr.get_indices() + return self._printer_tensor_indices(name, indices) + + def _print_TensorElement(self, expr): + name = expr.expr.args[0].name + indices = expr.expr.get_indices() + index_map = expr.index_map + return self._printer_tensor_indices(name, indices, index_map) + + def _print_TensMul(self, expr): + sign, args = expr._get_args_for_traditional_printer() + args = [ + prettyForm(*self._print(i).parens()) if + precedence_traditional(i) < PRECEDENCE["Mul"] else self._print(i) + for i in args + ] + pform = prettyForm.__mul__(*args) + if sign: + return prettyForm(*pform.left(sign)) + else: + return pform + + def _print_TensAdd(self, expr): + args = [ + prettyForm(*self._print(i).parens()) if + precedence_traditional(i) < PRECEDENCE["Mul"] else self._print(i) + for i in expr.args + ] + return prettyForm.__add__(*args) + + def _print_TensorIndex(self, expr): + sym = expr.args[0] + if not expr.is_up: + sym = -sym + return self._print(sym) + + def _print_PartialDerivative(self, deriv): + if self._use_unicode: + deriv_symbol = U('PARTIAL DIFFERENTIAL') + else: + deriv_symbol = r'd' + x = None + + for variable in reversed(deriv.variables): + s = self._print(variable) + ds = prettyForm(*s.left(deriv_symbol)) + + if x is None: + x = ds + else: + x = prettyForm(*x.right(' ')) + x = prettyForm(*x.right(ds)) + + f = prettyForm( + binding=prettyForm.FUNC, *self._print(deriv.expr).parens()) + + pform = prettyForm(deriv_symbol) + + if len(deriv.variables) > 1: + pform = pform**self._print(len(deriv.variables)) + + pform = prettyForm(*pform.below(stringPict.LINE, x)) + pform.baseline = pform.baseline + 1 + pform = prettyForm(*stringPict.next(pform, f)) + pform.binding = prettyForm.MUL + + return pform + + def _print_Piecewise(self, pexpr): + + P = {} + for n, ec in enumerate(pexpr.args): + P[n, 0] = self._print(ec.expr) + if ec.cond == True: + P[n, 1] = prettyForm('otherwise') + else: + P[n, 1] = prettyForm( + *prettyForm('for ').right(self._print(ec.cond))) + hsep = 2 + vsep = 1 + len_args = len(pexpr.args) + + # max widths + maxw = [max(P[i, j].width() for i in range(len_args)) + for j in range(2)] + + # FIXME: Refactor this code and matrix into some tabular environment. + # drawing result + D = None + + for i in range(len_args): + D_row = None + for j in range(2): + p = P[i, j] + assert p.width() <= maxw[j] + + wdelta = maxw[j] - p.width() + wleft = wdelta // 2 + wright = wdelta - wleft + + p = prettyForm(*p.right(' '*wright)) + p = prettyForm(*p.left(' '*wleft)) + + if D_row is None: + D_row = p + continue + + D_row = prettyForm(*D_row.right(' '*hsep)) # h-spacer + D_row = prettyForm(*D_row.right(p)) + if D is None: + D = D_row # first row in a picture + continue + + # v-spacer + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + + D = prettyForm(*D.below(D_row)) + + D = prettyForm(*D.parens('{', '')) + D.baseline = D.height()//2 + D.binding = prettyForm.OPEN + return D + + def _print_ITE(self, ite): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(ite.rewrite(Piecewise)) + + def _hprint_vec(self, v): + D = None + + for a in v: + p = a + if D is None: + D = p + else: + D = prettyForm(*D.right(', ')) + D = prettyForm(*D.right(p)) + if D is None: + D = stringPict(' ') + + return D + + def _hprint_vseparator(self, p1, p2, left=None, right=None, delimiter='', ifascii_nougly=False): + if ifascii_nougly and not self._use_unicode: + return self._print_seq((p1, '|', p2), left=left, right=right, + delimiter=delimiter, ifascii_nougly=True) + tmp = self._print_seq((p1, p2,), left=left, right=right, delimiter=delimiter) + sep = stringPict(vobj('|', tmp.height()), baseline=tmp.baseline) + return self._print_seq((p1, sep, p2), left=left, right=right, + delimiter=delimiter) + + def _print_hyper(self, e): + # FIXME refactor Matrix, Piecewise, and this into a tabular environment + ap = [self._print(a) for a in e.ap] + bq = [self._print(b) for b in e.bq] + + P = self._print(e.argument) + P.baseline = P.height()//2 + + # Drawing result - first create the ap, bq vectors + D = None + for v in [ap, bq]: + D_row = self._hprint_vec(v) + if D is None: + D = D_row # first row in a picture + else: + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + + # make sure that the argument `z' is centred vertically + D.baseline = D.height()//2 + + # insert horizontal separator + P = prettyForm(*P.left(' ')) + D = prettyForm(*D.right(' ')) + + # insert separating `|` + D = self._hprint_vseparator(D, P) + + # add parens + D = prettyForm(*D.parens('(', ')')) + + # create the F symbol + above = D.height()//2 - 1 + below = D.height() - above - 1 + + sz, t, b, add, img = annotated('F') + F = prettyForm('\n' * (above - t) + img + '\n' * (below - b), + baseline=above + sz) + add = (sz + 1)//2 + + F = prettyForm(*F.left(self._print(len(e.ap)))) + F = prettyForm(*F.right(self._print(len(e.bq)))) + F.baseline = above + add + + D = prettyForm(*F.right(' ', D)) + + return D + + def _print_meijerg(self, e): + # FIXME refactor Matrix, Piecewise, and this into a tabular environment + + v = {} + v[(0, 0)] = [self._print(a) for a in e.an] + v[(0, 1)] = [self._print(a) for a in e.aother] + v[(1, 0)] = [self._print(b) for b in e.bm] + v[(1, 1)] = [self._print(b) for b in e.bother] + + P = self._print(e.argument) + P.baseline = P.height()//2 + + vp = {} + for idx in v: + vp[idx] = self._hprint_vec(v[idx]) + + for i in range(2): + maxw = max(vp[(0, i)].width(), vp[(1, i)].width()) + for j in range(2): + s = vp[(j, i)] + left = (maxw - s.width()) // 2 + right = maxw - left - s.width() + s = prettyForm(*s.left(' ' * left)) + s = prettyForm(*s.right(' ' * right)) + vp[(j, i)] = s + + D1 = prettyForm(*vp[(0, 0)].right(' ', vp[(0, 1)])) + D1 = prettyForm(*D1.below(' ')) + D2 = prettyForm(*vp[(1, 0)].right(' ', vp[(1, 1)])) + D = prettyForm(*D1.below(D2)) + + # make sure that the argument `z' is centred vertically + D.baseline = D.height()//2 + + # insert horizontal separator + P = prettyForm(*P.left(' ')) + D = prettyForm(*D.right(' ')) + + # insert separating `|` + D = self._hprint_vseparator(D, P) + + # add parens + D = prettyForm(*D.parens('(', ')')) + + # create the G symbol + above = D.height()//2 - 1 + below = D.height() - above - 1 + + sz, t, b, add, img = annotated('G') + F = prettyForm('\n' * (above - t) + img + '\n' * (below - b), + baseline=above + sz) + + pp = self._print(len(e.ap)) + pq = self._print(len(e.bq)) + pm = self._print(len(e.bm)) + pn = self._print(len(e.an)) + + def adjust(p1, p2): + diff = p1.width() - p2.width() + if diff == 0: + return p1, p2 + elif diff > 0: + return p1, prettyForm(*p2.left(' '*diff)) + else: + return prettyForm(*p1.left(' '*-diff)), p2 + pp, pm = adjust(pp, pm) + pq, pn = adjust(pq, pn) + pu = prettyForm(*pm.right(', ', pn)) + pl = prettyForm(*pp.right(', ', pq)) + + ht = F.baseline - above - 2 + if ht > 0: + pu = prettyForm(*pu.below('\n'*ht)) + p = prettyForm(*pu.below(pl)) + + F.baseline = above + F = prettyForm(*F.right(p)) + + F.baseline = above + add + + D = prettyForm(*F.right(' ', D)) + + return D + + def _print_ExpBase(self, e): + # TODO should exp_polar be printed differently? + # what about exp_polar(0), exp_polar(1)? + base = prettyForm(pretty_atom('Exp1', 'e')) + return base ** self._print(e.args[0]) + + def _print_Exp1(self, e): + return prettyForm(pretty_atom('Exp1', 'e')) + + def _print_Function(self, e, sort=False, func_name=None, left='(', + right=')'): + # optional argument func_name for supplying custom names + # XXX works only for applied functions + return self._helper_print_function(e.func, e.args, sort=sort, func_name=func_name, left=left, right=right) + + def _print_mathieuc(self, e): + return self._print_Function(e, func_name='C') + + def _print_mathieus(self, e): + return self._print_Function(e, func_name='S') + + def _print_mathieucprime(self, e): + return self._print_Function(e, func_name="C'") + + def _print_mathieusprime(self, e): + return self._print_Function(e, func_name="S'") + + def _helper_print_function(self, func, args, sort=False, func_name=None, + delimiter=', ', elementwise=False, left='(', + right=')'): + if sort: + args = sorted(args, key=default_sort_key) + + if not func_name and hasattr(func, "__name__"): + func_name = func.__name__ + + if func_name: + prettyFunc = self._print(Symbol(func_name)) + else: + prettyFunc = prettyForm(*self._print(func).parens()) + + if elementwise: + if self._use_unicode: + circ = pretty_atom('Modifier Letter Low Ring') + else: + circ = '.' + circ = self._print(circ) + prettyFunc = prettyForm( + binding=prettyForm.LINE, + *stringPict.next(prettyFunc, circ) + ) + + prettyArgs = prettyForm(*self._print_seq(args, delimiter=delimiter).parens( + left=left, right=right)) + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_ElementwiseApplyFunction(self, e): + func = e.function + arg = e.expr + args = [arg] + return self._helper_print_function(func, args, delimiter="", elementwise=True) + + @property + def _special_function_classes(self): + from sympy.functions.special.tensor_functions import KroneckerDelta + from sympy.functions.special.gamma_functions import gamma, lowergamma + from sympy.functions.special.zeta_functions import lerchphi + from sympy.functions.special.beta_functions import beta + from sympy.functions.special.delta_functions import DiracDelta + from sympy.functions.special.error_functions import Chi + return {KroneckerDelta: [greek_unicode['delta'], 'delta'], + gamma: [greek_unicode['Gamma'], 'Gamma'], + lerchphi: [greek_unicode['Phi'], 'lerchphi'], + lowergamma: [greek_unicode['gamma'], 'gamma'], + beta: [greek_unicode['Beta'], 'B'], + DiracDelta: [greek_unicode['delta'], 'delta'], + Chi: ['Chi', 'Chi']} + + def _print_FunctionClass(self, expr): + for cls in self._special_function_classes: + if issubclass(expr, cls) and expr.__name__ == cls.__name__: + if self._use_unicode: + return prettyForm(self._special_function_classes[cls][0]) + else: + return prettyForm(self._special_function_classes[cls][1]) + func_name = expr.__name__ + return prettyForm(pretty_symbol(func_name)) + + def _print_GeometryEntity(self, expr): + # GeometryEntity is based on Tuple but should not print like a Tuple + return self.emptyPrinter(expr) + + def _print_polylog(self, e): + subscript = self._print(e.args[0]) + if self._use_unicode and is_subscriptable_in_unicode(subscript): + return self._print_Function(Function('Li_%s' % subscript)(e.args[1])) + return self._print_Function(e) + + def _print_lerchphi(self, e): + func_name = greek_unicode['Phi'] if self._use_unicode else 'lerchphi' + return self._print_Function(e, func_name=func_name) + + def _print_dirichlet_eta(self, e): + func_name = greek_unicode['eta'] if self._use_unicode else 'dirichlet_eta' + return self._print_Function(e, func_name=func_name) + + def _print_Heaviside(self, e): + func_name = greek_unicode['theta'] if self._use_unicode else 'Heaviside' + if e.args[1] is S.Half: + pform = prettyForm(*self._print(e.args[0]).parens()) + pform = prettyForm(*pform.left(func_name)) + return pform + else: + return self._print_Function(e, func_name=func_name) + + def _print_fresnels(self, e): + return self._print_Function(e, func_name="S") + + def _print_fresnelc(self, e): + return self._print_Function(e, func_name="C") + + def _print_airyai(self, e): + return self._print_Function(e, func_name="Ai") + + def _print_airybi(self, e): + return self._print_Function(e, func_name="Bi") + + def _print_airyaiprime(self, e): + return self._print_Function(e, func_name="Ai'") + + def _print_airybiprime(self, e): + return self._print_Function(e, func_name="Bi'") + + def _print_LambertW(self, e): + return self._print_Function(e, func_name="W") + + def _print_Covariance(self, e): + return self._print_Function(e, func_name="Cov") + + def _print_Variance(self, e): + return self._print_Function(e, func_name="Var") + + def _print_Probability(self, e): + return self._print_Function(e, func_name="P") + + def _print_Expectation(self, e): + return self._print_Function(e, func_name="E", left='[', right=']') + + def _print_Lambda(self, e): + expr = e.expr + sig = e.signature + if self._use_unicode: + arrow = f" {pretty_atom('ArrowFromBar')} " + else: + arrow = " -> " + if len(sig) == 1 and sig[0].is_symbol: + sig = sig[0] + var_form = self._print(sig) + + return prettyForm(*stringPict.next(var_form, arrow, self._print(expr)), binding=8) + + def _print_Order(self, expr): + pform = self._print(expr.expr) + if (expr.point and any(p != S.Zero for p in expr.point)) or \ + len(expr.variables) > 1: + pform = prettyForm(*pform.right("; ")) + if len(expr.variables) > 1: + pform = prettyForm(*pform.right(self._print(expr.variables))) + elif len(expr.variables): + pform = prettyForm(*pform.right(self._print(expr.variables[0]))) + if self._use_unicode: + pform = prettyForm(*pform.right(f" {pretty_atom('Arrow')} ")) + else: + pform = prettyForm(*pform.right(" -> ")) + if len(expr.point) > 1: + pform = prettyForm(*pform.right(self._print(expr.point))) + else: + pform = prettyForm(*pform.right(self._print(expr.point[0]))) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left("O")) + return pform + + def _print_SingularityFunction(self, e): + if self._use_unicode: + shift = self._print(e.args[0]-e.args[1]) + n = self._print(e.args[2]) + base = prettyForm("<") + base = prettyForm(*base.right(shift)) + base = prettyForm(*base.right(">")) + pform = base**n + return pform + else: + n = self._print(e.args[2]) + shift = self._print(e.args[0]-e.args[1]) + base = self._print_seq(shift, "<", ">", ' ') + return base**n + + def _print_beta(self, e): + func_name = greek_unicode['Beta'] if self._use_unicode else 'B' + return self._print_Function(e, func_name=func_name) + + def _print_betainc(self, e): + func_name = "B'" + return self._print_Function(e, func_name=func_name) + + def _print_betainc_regularized(self, e): + func_name = 'I' + return self._print_Function(e, func_name=func_name) + + def _print_gamma(self, e): + func_name = greek_unicode['Gamma'] if self._use_unicode else 'Gamma' + return self._print_Function(e, func_name=func_name) + + def _print_uppergamma(self, e): + func_name = greek_unicode['Gamma'] if self._use_unicode else 'Gamma' + return self._print_Function(e, func_name=func_name) + + def _print_lowergamma(self, e): + func_name = greek_unicode['gamma'] if self._use_unicode else 'lowergamma' + return self._print_Function(e, func_name=func_name) + + def _print_DiracDelta(self, e): + if self._use_unicode: + if len(e.args) == 2: + a = prettyForm(greek_unicode['delta']) + b = self._print(e.args[1]) + b = prettyForm(*b.parens()) + c = self._print(e.args[0]) + c = prettyForm(*c.parens()) + pform = a**b + pform = prettyForm(*pform.right(' ')) + pform = prettyForm(*pform.right(c)) + return pform + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(greek_unicode['delta'])) + return pform + else: + return self._print_Function(e) + + def _print_expint(self, e): + subscript = self._print(e.args[0]) + if self._use_unicode and is_subscriptable_in_unicode(subscript): + return self._print_Function(Function('E_%s' % subscript)(e.args[1])) + return self._print_Function(e) + + def _print_Chi(self, e): + # This needs a special case since otherwise it comes out as greek + # letter chi... + prettyFunc = prettyForm("Chi") + prettyArgs = prettyForm(*self._print_seq(e.args).parens()) + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_elliptic_e(self, e): + pforma0 = self._print(e.args[0]) + if len(e.args) == 1: + pform = pforma0 + else: + pforma1 = self._print(e.args[1]) + pform = self._hprint_vseparator(pforma0, pforma1) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('E')) + return pform + + def _print_elliptic_k(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('K')) + return pform + + def _print_elliptic_f(self, e): + pforma0 = self._print(e.args[0]) + pforma1 = self._print(e.args[1]) + pform = self._hprint_vseparator(pforma0, pforma1) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('F')) + return pform + + def _print_elliptic_pi(self, e): + name = greek_unicode['Pi'] if self._use_unicode else 'Pi' + pforma0 = self._print(e.args[0]) + pforma1 = self._print(e.args[1]) + if len(e.args) == 2: + pform = self._hprint_vseparator(pforma0, pforma1) + else: + pforma2 = self._print(e.args[2]) + pforma = self._hprint_vseparator(pforma1, pforma2, ifascii_nougly=False) + pforma = prettyForm(*pforma.left('; ')) + pform = prettyForm(*pforma.left(pforma0)) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(name)) + return pform + + def _print_GoldenRatio(self, expr): + if self._use_unicode: + return prettyForm(pretty_symbol('phi')) + return self._print(Symbol("GoldenRatio")) + + def _print_EulerGamma(self, expr): + if self._use_unicode: + return prettyForm(pretty_symbol('gamma')) + return self._print(Symbol("EulerGamma")) + + def _print_Catalan(self, expr): + return self._print(Symbol("G")) + + def _print_Mod(self, expr): + pform = self._print(expr.args[0]) + if pform.binding > prettyForm.MUL: + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right(' mod ')) + pform = prettyForm(*pform.right(self._print(expr.args[1]))) + pform.binding = prettyForm.OPEN + return pform + + def _print_Add(self, expr, order=None): + terms = self._as_ordered_terms(expr, order=order) + pforms, indices = [], [] + + def pretty_negative(pform, index): + """Prepend a minus sign to a pretty form. """ + #TODO: Move this code to prettyForm + if index == 0: + if pform.height() > 1: + pform_neg = '- ' + else: + pform_neg = '-' + else: + pform_neg = ' - ' + + if (pform.binding > prettyForm.NEG + or pform.binding == prettyForm.ADD): + p = stringPict(*pform.parens()) + else: + p = pform + p = stringPict.next(pform_neg, p) + # Lower the binding to NEG, even if it was higher. Otherwise, it + # will print as a + ( - (b)), instead of a - (b). + return prettyForm(binding=prettyForm.NEG, *p) + + for i, term in enumerate(terms): + if term.is_Mul and term.could_extract_minus_sign(): + coeff, other = term.as_coeff_mul(rational=False) + if coeff == -1: + negterm = Mul(*other, evaluate=False) + else: + negterm = Mul(-coeff, *other, evaluate=False) + pform = self._print(negterm) + pforms.append(pretty_negative(pform, i)) + elif term.is_Rational and term.q > 1: + pforms.append(None) + indices.append(i) + elif term.is_Number and term < 0: + pform = self._print(-term) + pforms.append(pretty_negative(pform, i)) + elif term.is_Relational: + pforms.append(prettyForm(*self._print(term).parens())) + else: + pforms.append(self._print(term)) + + if indices: + large = True + + for pform in pforms: + if pform is not None and pform.height() > 1: + break + else: + large = False + + for i in indices: + term, negative = terms[i], False + + if term < 0: + term, negative = -term, True + + if large: + pform = prettyForm(str(term.p))/prettyForm(str(term.q)) + else: + pform = self._print(term) + + if negative: + pform = pretty_negative(pform, i) + + pforms[i] = pform + + return prettyForm.__add__(*pforms) + + def _print_Mul(self, product): + from sympy.physics.units import Quantity + + # Check for unevaluated Mul. In this case we need to make sure the + # identities are visible, multiple Rational factors are not combined + # etc so we display in a straight-forward form that fully preserves all + # args and their order. + args = product.args + if args[0] is S.One or any(isinstance(arg, Number) for arg in args[1:]): + strargs = list(map(self._print, args)) + # XXX: This is a hack to work around the fact that + # prettyForm.__mul__ absorbs a leading -1 in the args. Probably it + # would be better to fix this in prettyForm.__mul__ instead. + negone = strargs[0] == '-1' + if negone: + strargs[0] = prettyForm('1', 0, 0) + obj = prettyForm.__mul__(*strargs) + if negone: + obj = prettyForm('-' + obj.s, obj.baseline, obj.binding) + return obj + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + if self.order not in ('old', 'none'): + args = product.as_ordered_factors() + else: + args = list(product.args) + + # If quantities are present append them at the back + args = sorted(args, key=lambda x: isinstance(x, Quantity) or + (isinstance(x, Pow) and isinstance(x.base, Quantity))) + + # Gather terms for numerator/denominator + for item in args: + if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity: + if item.p != 1: + a.append( Rational(item.p) ) + if item.q != 1: + b.append( Rational(item.q) ) + else: + a.append(item) + + # Convert to pretty forms. Parentheses are added by `__mul__`. + a = [self._print(ai) for ai in a] + b = [self._print(bi) for bi in b] + + # Construct a pretty form + if len(b) == 0: + return prettyForm.__mul__(*a) + else: + if len(a) == 0: + a.append( self._print(S.One) ) + return prettyForm.__mul__(*a)/prettyForm.__mul__(*b) + + # A helper function for _print_Pow to print x**(1/n) + def _print_nth_root(self, base, root): + bpretty = self._print(base) + + # In very simple cases, use a single-char root sign + if (self._settings['use_unicode_sqrt_char'] and self._use_unicode + and root == 2 and bpretty.height() == 1 + and (bpretty.width() == 1 + or (base.is_Integer and base.is_nonnegative))): + return prettyForm(*bpretty.left(nth_root[2])) + + # Construct root sign, start with the \/ shape + _zZ = xobj('/', 1) + rootsign = xobj('\\', 1) + _zZ + # Constructing the number to put on root + rpretty = self._print(root) + # roots look bad if they are not a single line + if rpretty.height() != 1: + return self._print(base)**self._print(1/root) + # If power is half, no number should appear on top of root sign + exp = '' if root == 2 else str(rpretty).ljust(2) + if len(exp) > 2: + rootsign = ' '*(len(exp) - 2) + rootsign + # Stack the exponent + rootsign = stringPict(exp + '\n' + rootsign) + rootsign.baseline = 0 + # Diagonal: length is one less than height of base + linelength = bpretty.height() - 1 + diagonal = stringPict('\n'.join( + ' '*(linelength - i - 1) + _zZ + ' '*i + for i in range(linelength) + )) + # Put baseline just below lowest line: next to exp + diagonal.baseline = linelength - 1 + # Make the root symbol + rootsign = prettyForm(*rootsign.right(diagonal)) + # Det the baseline to match contents to fix the height + # but if the height of bpretty is one, the rootsign must be one higher + rootsign.baseline = max(1, bpretty.baseline) + #build result + s = prettyForm(hobj('_', 2 + bpretty.width())) + s = prettyForm(*bpretty.above(s)) + s = prettyForm(*s.left(rootsign)) + return s + + def _print_Pow(self, power): + from sympy.simplify.simplify import fraction + b, e = power.as_base_exp() + if power.is_commutative: + if e is S.NegativeOne: + return prettyForm("1")/self._print(b) + n, d = fraction(e) + if n is S.One and d.is_Atom and not e.is_Integer and (e.is_Rational or d.is_Symbol) \ + and self._settings['root_notation']: + return self._print_nth_root(b, d) + if e.is_Rational and e < 0: + return prettyForm("1")/self._print(Pow(b, -e, evaluate=False)) + + if b.is_Relational: + return prettyForm(*self._print(b).parens()).__pow__(self._print(e)) + + return self._print(b)**self._print(e) + + def _print_UnevaluatedExpr(self, expr): + return self._print(expr.args[0]) + + def __print_numer_denom(self, p, q): + if q == 1: + if p < 0: + return prettyForm(str(p), binding=prettyForm.NEG) + else: + return prettyForm(str(p)) + elif abs(p) >= 10 and abs(q) >= 10: + # If more than one digit in numer and denom, print larger fraction + if p < 0: + return prettyForm(str(p), binding=prettyForm.NEG)/prettyForm(str(q)) + # Old printing method: + #pform = prettyForm(str(-p))/prettyForm(str(q)) + #return prettyForm(binding=prettyForm.NEG, *pform.left('- ')) + else: + return prettyForm(str(p))/prettyForm(str(q)) + else: + return None + + def _print_Rational(self, expr): + result = self.__print_numer_denom(expr.p, expr.q) + + if result is not None: + return result + else: + return self.emptyPrinter(expr) + + def _print_Fraction(self, expr): + result = self.__print_numer_denom(expr.numerator, expr.denominator) + + if result is not None: + return result + else: + return self.emptyPrinter(expr) + + def _print_ProductSet(self, p): + if len(p.sets) >= 1 and not has_variety(p.sets): + return self._print(p.sets[0]) ** self._print(len(p.sets)) + else: + prod_char = pretty_atom('Multiplication') if self._use_unicode else 'x' + return self._print_seq(p.sets, None, None, ' %s ' % prod_char, + parenthesize=lambda set: set.is_Union or + set.is_Intersection or set.is_ProductSet) + + def _print_FiniteSet(self, s): + items = sorted(s.args, key=default_sort_key) + return self._print_seq(items, '{', '}', ', ' ) + + def _print_Range(self, s): + + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + + if s.start.is_infinite and s.stop.is_infinite: + if s.step.is_positive: + printset = dots, -1, 0, 1, dots + else: + printset = dots, 1, 0, -1, dots + elif s.start.is_infinite: + printset = dots, s[-1] - s.step, s[-1] + elif s.stop.is_infinite: + it = iter(s) + printset = next(it), next(it), dots + elif len(s) > 4: + it = iter(s) + printset = next(it), next(it), dots, s[-1] + else: + printset = tuple(s) + + return self._print_seq(printset, '{', '}', ', ' ) + + def _print_Interval(self, i): + if i.start == i.end: + return self._print_seq(i.args[:1], '{', '}') + + else: + if i.left_open: + left = '(' + else: + left = '[' + + if i.right_open: + right = ')' + else: + right = ']' + + return self._print_seq(i.args[:2], left, right) + + def _print_AccumulationBounds(self, i): + left = '<' + right = '>' + + return self._print_seq(i.args[:2], left, right) + + def _print_Intersection(self, u): + + delimiter = ' %s ' % pretty_atom('Intersection', 'n') + + return self._print_seq(u.args, None, None, delimiter, + parenthesize=lambda set: set.is_ProductSet or + set.is_Union or set.is_Complement) + + def _print_Union(self, u): + + union_delimiter = ' %s ' % pretty_atom('Union', 'U') + + return self._print_seq(u.args, None, None, union_delimiter, + parenthesize=lambda set: set.is_ProductSet or + set.is_Intersection or set.is_Complement) + + def _print_SymmetricDifference(self, u): + if not self._use_unicode: + raise NotImplementedError("ASCII pretty printing of SymmetricDifference is not implemented") + + sym_delimeter = ' %s ' % pretty_atom('SymmetricDifference') + + return self._print_seq(u.args, None, None, sym_delimeter) + + def _print_Complement(self, u): + + delimiter = r' \ ' + + return self._print_seq(u.args, None, None, delimiter, + parenthesize=lambda set: set.is_ProductSet or set.is_Intersection + or set.is_Union) + + def _print_ImageSet(self, ts): + if self._use_unicode: + inn = pretty_atom("SmallElementOf") + else: + inn = 'in' + fun = ts.lamda + sets = ts.base_sets + signature = fun.signature + expr = self._print(fun.expr) + + # TODO: the stuff to the left of the | and the stuff to the right of + # the | should have independent baselines, that way something like + # ImageSet(Lambda(x, 1/x**2), S.Naturals) prints the "x in N" part + # centered on the right instead of aligned with the fraction bar on + # the left. The same also applies to ConditionSet and ComplexRegion + if len(signature) == 1: + S = self._print_seq((signature[0], inn, sets[0]), + delimiter=' ') + return self._hprint_vseparator(expr, S, + left='{', right='}', + ifascii_nougly=True, delimiter=' ') + else: + pargs = tuple(j for var, setv in zip(signature, sets) for j in + (var, ' ', inn, ' ', setv, ", ")) + S = self._print_seq(pargs[:-1], delimiter='') + return self._hprint_vseparator(expr, S, + left='{', right='}', + ifascii_nougly=True, delimiter=' ') + + def _print_ConditionSet(self, ts): + if self._use_unicode: + inn = pretty_atom('SmallElementOf') + # using _and because and is a keyword and it is bad practice to + # overwrite them + _and = pretty_atom('And') + else: + inn = 'in' + _and = 'and' + + variables = self._print_seq(Tuple(ts.sym)) + as_expr = getattr(ts.condition, 'as_expr', None) + if as_expr is not None: + cond = self._print(ts.condition.as_expr()) + else: + cond = self._print(ts.condition) + if self._use_unicode: + cond = self._print(cond) + cond = prettyForm(*cond.parens()) + + if ts.base_set is S.UniversalSet: + return self._hprint_vseparator(variables, cond, left="{", + right="}", ifascii_nougly=True, + delimiter=' ') + + base = self._print(ts.base_set) + C = self._print_seq((variables, inn, base, _and, cond), + delimiter=' ') + return self._hprint_vseparator(variables, C, left="{", right="}", + ifascii_nougly=True, delimiter=' ') + + def _print_ComplexRegion(self, ts): + if self._use_unicode: + inn = pretty_atom('SmallElementOf') + else: + inn = 'in' + variables = self._print_seq(ts.variables) + expr = self._print(ts.expr) + prodsets = self._print(ts.sets) + + C = self._print_seq((variables, inn, prodsets), + delimiter=' ') + return self._hprint_vseparator(expr, C, left="{", right="}", + ifascii_nougly=True, delimiter=' ') + + def _print_Contains(self, e): + var, set = e.args + if self._use_unicode: + el = f" {pretty_atom('ElementOf')} " + return prettyForm(*stringPict.next(self._print(var), + el, self._print(set)), binding=8) + else: + return prettyForm(sstr(e)) + + def _print_FourierSeries(self, s): + if s.an.formula is S.Zero and s.bn.formula is S.Zero: + return self._print(s.a0) + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + return self._print_Add(s.truncate()) + self._print(dots) + + def _print_FormalPowerSeries(self, s): + return self._print_Add(s.infinite) + + def _print_SetExpr(self, se): + pretty_set = prettyForm(*self._print(se.set).parens()) + pretty_name = self._print(Symbol("SetExpr")) + return prettyForm(*pretty_name.right(pretty_set)) + + def _print_SeqFormula(self, s): + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + + if len(s.start.free_symbols) > 0 or len(s.stop.free_symbols) > 0: + raise NotImplementedError("Pretty printing of sequences with symbolic bound not implemented") + + if s.start is S.NegativeInfinity: + stop = s.stop + printset = (dots, s.coeff(stop - 3), s.coeff(stop - 2), + s.coeff(stop - 1), s.coeff(stop)) + elif s.stop is S.Infinity or s.length > 4: + printset = s[:4] + printset.append(dots) + printset = tuple(printset) + else: + printset = tuple(s) + return self._print_list(printset) + + _print_SeqPer = _print_SeqFormula + _print_SeqAdd = _print_SeqFormula + _print_SeqMul = _print_SeqFormula + + def _print_seq(self, seq, left=None, right=None, delimiter=', ', + parenthesize=lambda x: False, ifascii_nougly=True): + + pforms = [] + for item in seq: + pform = self._print(item) + if parenthesize(item): + pform = prettyForm(*pform.parens()) + if pforms: + pforms.append(delimiter) + pforms.append(pform) + + if not pforms: + s = stringPict('') + else: + s = prettyForm(*stringPict.next(*pforms)) + + s = prettyForm(*s.parens(left, right, ifascii_nougly=ifascii_nougly)) + return s + + def join(self, delimiter, args): + pform = None + + for arg in args: + if pform is None: + pform = arg + else: + pform = prettyForm(*pform.right(delimiter)) + pform = prettyForm(*pform.right(arg)) + + if pform is None: + return prettyForm("") + else: + return pform + + def _print_list(self, l): + return self._print_seq(l, '[', ']') + + def _print_tuple(self, t): + if len(t) == 1: + ptuple = prettyForm(*stringPict.next(self._print(t[0]), ',')) + return prettyForm(*ptuple.parens('(', ')', ifascii_nougly=True)) + else: + return self._print_seq(t, '(', ')') + + def _print_Tuple(self, expr): + return self._print_tuple(expr) + + def _print_dict(self, d): + keys = sorted(d.keys(), key=default_sort_key) + items = [] + + for k in keys: + K = self._print(k) + V = self._print(d[k]) + s = prettyForm(*stringPict.next(K, ': ', V)) + + items.append(s) + + return self._print_seq(items, '{', '}') + + def _print_Dict(self, d): + return self._print_dict(d) + + def _print_set(self, s): + if not s: + return prettyForm('set()') + items = sorted(s, key=default_sort_key) + pretty = self._print_seq(items) + pretty = prettyForm(*pretty.parens('{', '}', ifascii_nougly=True)) + return pretty + + def _print_frozenset(self, s): + if not s: + return prettyForm('frozenset()') + items = sorted(s, key=default_sort_key) + pretty = self._print_seq(items) + pretty = prettyForm(*pretty.parens('{', '}', ifascii_nougly=True)) + pretty = prettyForm(*pretty.parens('(', ')', ifascii_nougly=True)) + pretty = prettyForm(*stringPict.next(type(s).__name__, pretty)) + return pretty + + def _print_UniversalSet(self, s): + if self._use_unicode: + return prettyForm(pretty_atom('Universe')) + else: + return prettyForm('UniversalSet') + + def _print_PolyRing(self, ring): + return prettyForm(sstr(ring)) + + def _print_FracField(self, field): + return prettyForm(sstr(field)) + + def _print_FreeGroupElement(self, elm): + return prettyForm(str(elm)) + + def _print_PolyElement(self, poly): + return prettyForm(sstr(poly)) + + def _print_FracElement(self, frac): + return prettyForm(sstr(frac)) + + def _print_AlgebraicNumber(self, expr): + if expr.is_aliased: + return self._print(expr.as_poly().as_expr()) + else: + return self._print(expr.as_expr()) + + def _print_ComplexRootOf(self, expr): + args = [self._print_Add(expr.expr, order='lex'), expr.index] + pform = prettyForm(*self._print_seq(args).parens()) + pform = prettyForm(*pform.left('CRootOf')) + return pform + + def _print_RootSum(self, expr): + args = [self._print_Add(expr.expr, order='lex')] + + if expr.fun is not S.IdentityFunction: + args.append(self._print(expr.fun)) + + pform = prettyForm(*self._print_seq(args).parens()) + pform = prettyForm(*pform.left('RootSum')) + + return pform + + def _print_FiniteField(self, expr): + if self._use_unicode: + form = f"{pretty_atom('Integers')}_%d" + else: + form = 'GF(%d)' + + return prettyForm(pretty_symbol(form % expr.mod)) + + def _print_IntegerRing(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('Integers')) + else: + return prettyForm('ZZ') + + def _print_RationalField(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('Rationals')) + else: + return prettyForm('QQ') + + def _print_RealField(self, domain): + if self._use_unicode: + prefix = pretty_atom("Reals") + else: + prefix = 'RR' + + if domain.has_default_precision: + return prettyForm(prefix) + else: + return self._print(pretty_symbol(prefix + "_" + str(domain.precision))) + + def _print_ComplexField(self, domain): + if self._use_unicode: + prefix = pretty_atom('Complexes') + else: + prefix = 'CC' + + if domain.has_default_precision: + return prettyForm(prefix) + else: + return self._print(pretty_symbol(prefix + "_" + str(domain.precision))) + + def _print_PolynomialRing(self, expr): + args = list(expr.symbols) + + if not expr.order.is_default: + order = prettyForm(*prettyForm("order=").right(self._print(expr.order))) + args.append(order) + + pform = self._print_seq(args, '[', ']') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_FractionField(self, expr): + args = list(expr.symbols) + + if not expr.order.is_default: + order = prettyForm(*prettyForm("order=").right(self._print(expr.order))) + args.append(order) + + pform = self._print_seq(args, '(', ')') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_PolynomialRingBase(self, expr): + g = expr.symbols + if str(expr.order) != str(expr.default_order): + g = g + ("order=" + str(expr.order),) + pform = self._print_seq(g, '[', ']') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_GroebnerBasis(self, basis): + exprs = [ self._print_Add(arg, order=basis.order) + for arg in basis.exprs ] + exprs = prettyForm(*self.join(", ", exprs).parens(left="[", right="]")) + + gens = [ self._print(gen) for gen in basis.gens ] + + domain = prettyForm( + *prettyForm("domain=").right(self._print(basis.domain))) + order = prettyForm( + *prettyForm("order=").right(self._print(basis.order))) + + pform = self.join(", ", [exprs] + gens + [domain, order]) + + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(basis.__class__.__name__)) + + return pform + + def _print_Subs(self, e): + pform = self._print(e.expr) + pform = prettyForm(*pform.parens()) + + h = pform.height() if pform.height() > 1 else 2 + rvert = stringPict(vobj('|', h), baseline=pform.baseline) + pform = prettyForm(*pform.right(rvert)) + + b = pform.baseline + pform.baseline = pform.height() - 1 + pform = prettyForm(*pform.right(self._print_seq([ + self._print_seq((self._print(v[0]), xsym('=='), self._print(v[1])), + delimiter='') for v in zip(e.variables, e.point) ]))) + + pform.baseline = b + return pform + + def _print_number_function(self, e, name): + # Print name_arg[0] for one argument or name_arg[0](arg[1]) + # for more than one argument + pform = prettyForm(name) + arg = self._print(e.args[0]) + pform_arg = prettyForm(" "*arg.width()) + pform_arg = prettyForm(*pform_arg.below(arg)) + pform = prettyForm(*pform.right(pform_arg)) + if len(e.args) == 1: + return pform + m, x = e.args + # TODO: copy-pasted from _print_Function: can we do better? + prettyFunc = pform + prettyArgs = prettyForm(*self._print_seq([x]).parens()) + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + return pform + + def _print_euler(self, e): + return self._print_number_function(e, "E") + + def _print_catalan(self, e): + return self._print_number_function(e, "C") + + def _print_bernoulli(self, e): + return self._print_number_function(e, "B") + + _print_bell = _print_bernoulli + + def _print_lucas(self, e): + return self._print_number_function(e, "L") + + def _print_fibonacci(self, e): + return self._print_number_function(e, "F") + + def _print_tribonacci(self, e): + return self._print_number_function(e, "T") + + def _print_stieltjes(self, e): + if self._use_unicode: + return self._print_number_function(e, greek_unicode['gamma']) + else: + return self._print_number_function(e, "stieltjes") + + def _print_KroneckerDelta(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(self._print(e.args[1]))) + if self._use_unicode: + a = stringPict(pretty_symbol('delta')) + else: + a = stringPict('d') + b = pform + top = stringPict(*b.left(' '*a.width())) + bot = stringPict(*a.right(' '*b.width())) + return prettyForm(binding=prettyForm.POW, *bot.below(top)) + + def _print_RandomDomain(self, d): + if hasattr(d, 'as_boolean'): + pform = self._print('Domain: ') + pform = prettyForm(*pform.right(self._print(d.as_boolean()))) + return pform + elif hasattr(d, 'set'): + pform = self._print('Domain: ') + pform = prettyForm(*pform.right(self._print(d.symbols))) + pform = prettyForm(*pform.right(self._print(' in '))) + pform = prettyForm(*pform.right(self._print(d.set))) + return pform + elif hasattr(d, 'symbols'): + pform = self._print('Domain on ') + pform = prettyForm(*pform.right(self._print(d.symbols))) + return pform + else: + return self._print(None) + + def _print_DMP(self, p): + try: + if p.ring is not None: + # TODO incorporate order + return self._print(p.ring.to_sympy(p)) + except SympifyError: + pass + return self._print(repr(p)) + + def _print_DMF(self, p): + return self._print_DMP(p) + + def _print_Object(self, object): + return self._print(pretty_symbol(object.name)) + + def _print_Morphism(self, morphism): + arrow = xsym("-->") + + domain = self._print(morphism.domain) + codomain = self._print(morphism.codomain) + tail = domain.right(arrow, codomain)[0] + + return prettyForm(tail) + + def _print_NamedMorphism(self, morphism): + pretty_name = self._print(pretty_symbol(morphism.name)) + pretty_morphism = self._print_Morphism(morphism) + return prettyForm(pretty_name.right(":", pretty_morphism)[0]) + + def _print_IdentityMorphism(self, morphism): + from sympy.categories import NamedMorphism + return self._print_NamedMorphism( + NamedMorphism(morphism.domain, morphism.codomain, "id")) + + def _print_CompositeMorphism(self, morphism): + + circle = xsym(".") + + # All components of the morphism have names and it is thus + # possible to build the name of the composite. + component_names_list = [pretty_symbol(component.name) for + component in morphism.components] + component_names_list.reverse() + component_names = circle.join(component_names_list) + ":" + + pretty_name = self._print(component_names) + pretty_morphism = self._print_Morphism(morphism) + return prettyForm(pretty_name.right(pretty_morphism)[0]) + + def _print_Category(self, category): + return self._print(pretty_symbol(category.name)) + + def _print_Diagram(self, diagram): + if not diagram.premises: + # This is an empty diagram. + return self._print(S.EmptySet) + + pretty_result = self._print(diagram.premises) + if diagram.conclusions: + results_arrow = " %s " % xsym("==>") + + pretty_conclusions = self._print(diagram.conclusions)[0] + pretty_result = pretty_result.right( + results_arrow, pretty_conclusions) + + return prettyForm(pretty_result[0]) + + def _print_DiagramGrid(self, grid): + from sympy.matrices import Matrix + matrix = Matrix([[grid[i, j] if grid[i, j] else Symbol(" ") + for j in range(grid.width)] + for i in range(grid.height)]) + return self._print_matrix_contents(matrix) + + def _print_FreeModuleElement(self, m): + # Print as row vector for convenience, for now. + return self._print_seq(m, '[', ']') + + def _print_SubModule(self, M): + gens = [[M.ring.to_sympy(g) for g in gen] for gen in M.gens] + return self._print_seq(gens, '<', '>') + + def _print_FreeModule(self, M): + return self._print(M.ring)**self._print(M.rank) + + def _print_ModuleImplementedIdeal(self, M): + sym = M.ring.to_sympy + return self._print_seq([sym(x) for [x] in M._module.gens], '<', '>') + + def _print_QuotientRing(self, R): + return self._print(R.ring) / self._print(R.base_ideal) + + def _print_QuotientRingElement(self, R): + return self._print(R.ring.to_sympy(R)) + self._print(R.ring.base_ideal) + + def _print_QuotientModuleElement(self, m): + return self._print(m.data) + self._print(m.module.killed_module) + + def _print_QuotientModule(self, M): + return self._print(M.base) / self._print(M.killed_module) + + def _print_MatrixHomomorphism(self, h): + matrix = self._print(h._sympy_matrix()) + matrix.baseline = matrix.height() // 2 + pform = prettyForm(*matrix.right(' : ', self._print(h.domain), + ' %s> ' % hobj('-', 2), self._print(h.codomain))) + return pform + + def _print_Manifold(self, manifold): + return self._print(manifold.name) + + def _print_Patch(self, patch): + return self._print(patch.name) + + def _print_CoordSystem(self, coords): + return self._print(coords.name) + + def _print_BaseScalarField(self, field): + string = field._coord_sys.symbols[field._index].name + return self._print(pretty_symbol(string)) + + def _print_BaseVectorField(self, field): + s = U('PARTIAL DIFFERENTIAL') + '_' + field._coord_sys.symbols[field._index].name + return self._print(pretty_symbol(s)) + + def _print_Differential(self, diff): + if self._use_unicode: + d = pretty_atom('Differential') + else: + d = 'd' + field = diff._form_field + if hasattr(field, '_coord_sys'): + string = field._coord_sys.symbols[field._index].name + return self._print(d + ' ' + pretty_symbol(string)) + else: + pform = self._print(field) + pform = prettyForm(*pform.parens()) + return prettyForm(*pform.left(d)) + + def _print_Tr(self, p): + #TODO: Handle indices + pform = self._print(p.args[0]) + pform = prettyForm(*pform.left('%s(' % (p.__class__.__name__))) + pform = prettyForm(*pform.right(')')) + return pform + + def _print_primenu(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + if self._use_unicode: + pform = prettyForm(*pform.left(greek_unicode['nu'])) + else: + pform = prettyForm(*pform.left('nu')) + return pform + + def _print_primeomega(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + if self._use_unicode: + pform = prettyForm(*pform.left(greek_unicode['Omega'])) + else: + pform = prettyForm(*pform.left('Omega')) + return pform + + def _print_Quantity(self, e): + if e.name.name == 'degree': + if self._use_unicode: + pform = self._print(pretty_atom('Degree')) + else: + pform = self._print(chr(176)) + return pform + else: + return self.emptyPrinter(e) + + def _print_AssignmentBase(self, e): + + op = prettyForm(' ' + xsym(e.op) + ' ') + + l = self._print(e.lhs) + r = self._print(e.rhs) + pform = prettyForm(*stringPict.next(l, op, r)) + return pform + + def _print_Str(self, s): + return self._print(s.name) + + +@print_function(PrettyPrinter) +def pretty(expr, **settings): + """Returns a string containing the prettified form of expr. + + For information on keyword arguments see pretty_print function. + + """ + pp = PrettyPrinter(settings) + + # XXX: this is an ugly hack, but at least it works + use_unicode = pp._settings['use_unicode'] + uflag = pretty_use_unicode(use_unicode) + + try: + return pp.doprint(expr) + finally: + pretty_use_unicode(uflag) + + +def pretty_print(expr, **kwargs): + """Prints expr in pretty form. + + pprint is just a shortcut for this function. + + Parameters + ========== + + expr : expression + The expression to print. + + wrap_line : bool, optional (default=True) + Line wrapping enabled/disabled. + + num_columns : int or None, optional (default=None) + Number of columns before line breaking (default to None which reads + the terminal width), useful when using SymPy without terminal. + + use_unicode : bool or None, optional (default=None) + Use unicode characters, such as the Greek letter pi instead of + the string pi. + + full_prec : bool or string, optional (default="auto") + Use full precision. + + order : bool or string, optional (default=None) + Set to 'none' for long expressions if slow; default is None. + + use_unicode_sqrt_char : bool, optional (default=True) + Use compact single-character square root symbol (when unambiguous). + + root_notation : bool, optional (default=True) + Set to 'False' for printing exponents of the form 1/n in fractional form. + By default exponent is printed in root form. + + mat_symbol_style : string, optional (default="plain") + Set to "bold" for printing MatrixSymbols using a bold mathematical symbol face. + By default the standard face is used. + + imaginary_unit : string, optional (default="i") + Letter to use for imaginary unit when use_unicode is True. + Can be "i" (default) or "j". + """ + print(pretty(expr, **kwargs)) + +pprint = pretty_print + + +def pager_print(expr, **settings): + """Prints expr using the pager, in pretty form. + + This invokes a pager command using pydoc. Lines are not wrapped + automatically. This routine is meant to be used with a pager that allows + sideways scrolling, like ``less -S``. + + Parameters are the same as for ``pretty_print``. If you wish to wrap lines, + pass ``num_columns=None`` to auto-detect the width of the terminal. + + """ + from pydoc import pager + from locale import getpreferredencoding + if 'num_columns' not in settings: + settings['num_columns'] = 500000 # disable line wrap + pager(pretty(expr, **settings).encode(getpreferredencoding())) diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/pretty_symbology.py b/phivenv/Lib/site-packages/sympy/printing/pretty/pretty_symbology.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb6ec556c6ed7b15dfcddcfc3da189102d5395b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pretty/pretty_symbology.py @@ -0,0 +1,731 @@ +"""Symbolic primitives + unicode/ASCII abstraction for pretty.py""" + +import sys +import warnings +from string import ascii_lowercase, ascii_uppercase +import unicodedata + +unicode_warnings = '' + +def U(name): + """ + Get a unicode character by name or, None if not found. + + This exists because older versions of Python use older unicode databases. + """ + try: + return unicodedata.lookup(name) + except KeyError: + global unicode_warnings + unicode_warnings += 'No \'%s\' in unicodedata\n' % name + return None + +from sympy.printing.conventions import split_super_sub +from sympy.core.alphabets import greeks +from sympy.utilities.exceptions import sympy_deprecation_warning + +# prefix conventions when constructing tables +# L - LATIN i +# G - GREEK beta +# D - DIGIT 0 +# S - SYMBOL + + + +__all__ = ['greek_unicode', 'sub', 'sup', 'xsym', 'vobj', 'hobj', 'pretty_symbol', + 'annotated', 'center_pad', 'center'] + + +_use_unicode = False + + +def pretty_use_unicode(flag=None): + """Set whether pretty-printer should use unicode by default""" + global _use_unicode, unicode_warnings + if flag is None: + return _use_unicode + + if flag and unicode_warnings: + # print warnings (if any) on first unicode usage + warnings.warn(unicode_warnings) + unicode_warnings = '' + + use_unicode_prev = _use_unicode + _use_unicode = flag + return use_unicode_prev + + +def pretty_try_use_unicode(): + """See if unicode output is available and leverage it if possible""" + + encoding = getattr(sys.stdout, 'encoding', None) + + # this happens when e.g. stdout is redirected through a pipe, or is + # e.g. a cStringIO.StringO + if encoding is None: + return # sys.stdout has no encoding + + symbols = [] + + # see if we can represent greek alphabet + symbols += greek_unicode.values() + + # and atoms + symbols += atoms_table.values() + + for s in symbols: + if s is None: + return # common symbols not present! + + try: + s.encode(encoding) + except UnicodeEncodeError: + return + + # all the characters were present and encodable + pretty_use_unicode(True) + + +def xstr(*args): + sympy_deprecation_warning( + """ + The sympy.printing.pretty.pretty_symbology.xstr() function is + deprecated. Use str() instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions" + ) + return str(*args) + +# GREEK +g = lambda l: U('GREEK SMALL LETTER %s' % l.upper()) +G = lambda l: U('GREEK CAPITAL LETTER %s' % l.upper()) + +greek_letters = list(greeks) # make a copy +# deal with Unicode's funny spelling of lambda +greek_letters[greek_letters.index('lambda')] = 'lamda' + +# {} greek letter -> (g,G) +greek_unicode = {L: g(L) for L in greek_letters} +greek_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_letters) + +# aliases +greek_unicode['lambda'] = greek_unicode['lamda'] +greek_unicode['Lambda'] = greek_unicode['Lamda'] +greek_unicode['varsigma'] = '\N{GREEK SMALL LETTER FINAL SIGMA}' + +# BOLD +b = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper()) +B = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper()) + +bold_unicode = {l: b(l) for l in ascii_lowercase} +bold_unicode.update((L, B(L)) for L in ascii_uppercase) + +# GREEK BOLD +gb = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper()) +GB = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper()) + +greek_bold_letters = list(greeks) # make a copy, not strictly required here +# deal with Unicode's funny spelling of lambda +greek_bold_letters[greek_bold_letters.index('lambda')] = 'lamda' + +# {} greek letter -> (g,G) +greek_bold_unicode = {L: g(L) for L in greek_bold_letters} +greek_bold_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_bold_letters) +greek_bold_unicode['lambda'] = greek_unicode['lamda'] +greek_bold_unicode['Lambda'] = greek_unicode['Lamda'] +greek_bold_unicode['varsigma'] = '\N{MATHEMATICAL BOLD SMALL FINAL SIGMA}' + +digit_2txt = { + '0': 'ZERO', + '1': 'ONE', + '2': 'TWO', + '3': 'THREE', + '4': 'FOUR', + '5': 'FIVE', + '6': 'SIX', + '7': 'SEVEN', + '8': 'EIGHT', + '9': 'NINE', +} + +symb_2txt = { + '+': 'PLUS SIGN', + '-': 'MINUS', + '=': 'EQUALS SIGN', + '(': 'LEFT PARENTHESIS', + ')': 'RIGHT PARENTHESIS', + '[': 'LEFT SQUARE BRACKET', + ']': 'RIGHT SQUARE BRACKET', + '{': 'LEFT CURLY BRACKET', + '}': 'RIGHT CURLY BRACKET', + + # non-std + '{}': 'CURLY BRACKET', + 'sum': 'SUMMATION', + 'int': 'INTEGRAL', +} + +# SUBSCRIPT & SUPERSCRIPT +LSUB = lambda letter: U('LATIN SUBSCRIPT SMALL LETTER %s' % letter.upper()) +GSUB = lambda letter: U('GREEK SUBSCRIPT SMALL LETTER %s' % letter.upper()) +DSUB = lambda digit: U('SUBSCRIPT %s' % digit_2txt[digit]) +SSUB = lambda symb: U('SUBSCRIPT %s' % symb_2txt[symb]) + +LSUP = lambda letter: U('SUPERSCRIPT LATIN SMALL LETTER %s' % letter.upper()) +DSUP = lambda digit: U('SUPERSCRIPT %s' % digit_2txt[digit]) +SSUP = lambda symb: U('SUPERSCRIPT %s' % symb_2txt[symb]) + +sub = {} # symb -> subscript symbol +sup = {} # symb -> superscript symbol + +# latin subscripts +for l in 'aeioruvxhklmnpst': + sub[l] = LSUB(l) + +for l in 'in': + sup[l] = LSUP(l) + +for gl in ['beta', 'gamma', 'rho', 'phi', 'chi']: + sub[gl] = GSUB(gl) + +for d in [str(i) for i in range(10)]: + sub[d] = DSUB(d) + sup[d] = DSUP(d) + +for s in '+-=()': + sub[s] = SSUB(s) + sup[s] = SSUP(s) + +# Variable modifiers +# TODO: Make brackets adjust to height of contents +modifier_dict = { + # Accents + 'mathring': lambda s: center_accent(s, '\N{COMBINING RING ABOVE}'), + 'ddddot': lambda s: center_accent(s, '\N{COMBINING FOUR DOTS ABOVE}'), + 'dddot': lambda s: center_accent(s, '\N{COMBINING THREE DOTS ABOVE}'), + 'ddot': lambda s: center_accent(s, '\N{COMBINING DIAERESIS}'), + 'dot': lambda s: center_accent(s, '\N{COMBINING DOT ABOVE}'), + 'check': lambda s: center_accent(s, '\N{COMBINING CARON}'), + 'breve': lambda s: center_accent(s, '\N{COMBINING BREVE}'), + 'acute': lambda s: center_accent(s, '\N{COMBINING ACUTE ACCENT}'), + 'grave': lambda s: center_accent(s, '\N{COMBINING GRAVE ACCENT}'), + 'tilde': lambda s: center_accent(s, '\N{COMBINING TILDE}'), + 'hat': lambda s: center_accent(s, '\N{COMBINING CIRCUMFLEX ACCENT}'), + 'bar': lambda s: center_accent(s, '\N{COMBINING OVERLINE}'), + 'vec': lambda s: center_accent(s, '\N{COMBINING RIGHT ARROW ABOVE}'), + 'prime': lambda s: s+'\N{PRIME}', + 'prm': lambda s: s+'\N{PRIME}', + # # Faces -- these are here for some compatibility with latex printing + # 'bold': lambda s: s, + # 'bm': lambda s: s, + # 'cal': lambda s: s, + # 'scr': lambda s: s, + # 'frak': lambda s: s, + # Brackets + 'norm': lambda s: '\N{DOUBLE VERTICAL LINE}'+s+'\N{DOUBLE VERTICAL LINE}', + 'avg': lambda s: '\N{MATHEMATICAL LEFT ANGLE BRACKET}'+s+'\N{MATHEMATICAL RIGHT ANGLE BRACKET}', + 'abs': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}', + 'mag': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}', +} + +# VERTICAL OBJECTS +HUP = lambda symb: U('%s UPPER HOOK' % symb_2txt[symb]) +CUP = lambda symb: U('%s UPPER CORNER' % symb_2txt[symb]) +MID = lambda symb: U('%s MIDDLE PIECE' % symb_2txt[symb]) +EXT = lambda symb: U('%s EXTENSION' % symb_2txt[symb]) +HLO = lambda symb: U('%s LOWER HOOK' % symb_2txt[symb]) +CLO = lambda symb: U('%s LOWER CORNER' % symb_2txt[symb]) +TOP = lambda symb: U('%s TOP' % symb_2txt[symb]) +BOT = lambda symb: U('%s BOTTOM' % symb_2txt[symb]) + +# {} '(' -> (extension, start, end, middle) 1-character +_xobj_unicode = { + + # vertical symbols + # (( ext, top, bot, mid ), c1) + '(': (( EXT('('), HUP('('), HLO('(') ), '('), + ')': (( EXT(')'), HUP(')'), HLO(')') ), ')'), + '[': (( EXT('['), CUP('['), CLO('[') ), '['), + ']': (( EXT(']'), CUP(']'), CLO(']') ), ']'), + '{': (( EXT('{}'), HUP('{'), HLO('{'), MID('{') ), '{'), + '}': (( EXT('{}'), HUP('}'), HLO('}'), MID('}') ), '}'), + '|': U('BOX DRAWINGS LIGHT VERTICAL'), + 'Tee': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'), + 'UpTack': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL'), + 'corner_up_centre' + '(_ext': U('LEFT PARENTHESIS EXTENSION'), + ')_ext': U('RIGHT PARENTHESIS EXTENSION'), + '(_lower_hook': U('LEFT PARENTHESIS LOWER HOOK'), + ')_lower_hook': U('RIGHT PARENTHESIS LOWER HOOK'), + '(_upper_hook': U('LEFT PARENTHESIS UPPER HOOK'), + ')_upper_hook': U('RIGHT PARENTHESIS UPPER HOOK'), + '<': ((U('BOX DRAWINGS LIGHT VERTICAL'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT')), '<'), + + '>': ((U('BOX DRAWINGS LIGHT VERTICAL'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), '>'), + + 'lfloor': (( EXT('['), EXT('['), CLO('[') ), U('LEFT FLOOR')), + 'rfloor': (( EXT(']'), EXT(']'), CLO(']') ), U('RIGHT FLOOR')), + 'lceil': (( EXT('['), CUP('['), EXT('[') ), U('LEFT CEILING')), + 'rceil': (( EXT(']'), CUP(']'), EXT(']') ), U('RIGHT CEILING')), + + 'int': (( EXT('int'), U('TOP HALF INTEGRAL'), U('BOTTOM HALF INTEGRAL') ), U('INTEGRAL')), + 'sum': (( U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), '_', U('OVERLINE'), U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), U('N-ARY SUMMATION')), + + # horizontal objects + #'-': '-', + '-': U('BOX DRAWINGS LIGHT HORIZONTAL'), + '_': U('LOW LINE'), + # We used to use this, but LOW LINE looks better for roots, as it's a + # little lower (i.e., it lines up with the / perfectly. But perhaps this + # one would still be wanted for some cases? + # '_': U('HORIZONTAL SCAN LINE-9'), + + # diagonal objects '\' & '/' ? + '/': U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'), + '\\': U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), +} + +_xobj_ascii = { + # vertical symbols + # (( ext, top, bot, mid ), c1) + '(': (( '|', '/', '\\' ), '('), + ')': (( '|', '\\', '/' ), ')'), + +# XXX this looks ugly +# '[': (( '|', '-', '-' ), '['), +# ']': (( '|', '-', '-' ), ']'), +# XXX not so ugly :( + '[': (( '[', '[', '[' ), '['), + ']': (( ']', ']', ']' ), ']'), + + '{': (( '|', '/', '\\', '<' ), '{'), + '}': (( '|', '\\', '/', '>' ), '}'), + '|': '|', + + '<': (( '|', '/', '\\' ), '<'), + '>': (( '|', '\\', '/' ), '>'), + + 'int': ( ' | ', ' /', '/ ' ), + + # horizontal objects + '-': '-', + '_': '_', + + # diagonal objects '\' & '/' ? + '/': '/', + '\\': '\\', +} + + +def xobj(symb, length): + """Construct spatial object of given length. + + return: [] of equal-length strings + """ + + if length <= 0: + raise ValueError("Length should be greater than 0") + + # TODO robustify when no unicodedat available + if _use_unicode: + _xobj = _xobj_unicode + else: + _xobj = _xobj_ascii + + vinfo = _xobj[symb] + + c1 = top = bot = mid = None + + if not isinstance(vinfo, tuple): # 1 entry + ext = vinfo + else: + if isinstance(vinfo[0], tuple): # (vlong), c1 + vlong = vinfo[0] + c1 = vinfo[1] + else: # (vlong), c1 + vlong = vinfo + + ext = vlong[0] + + try: + top = vlong[1] + bot = vlong[2] + mid = vlong[3] + except IndexError: + pass + + if c1 is None: + c1 = ext + if top is None: + top = ext + if bot is None: + bot = ext + if mid is not None: + if (length % 2) == 0: + # even height, but we have to print it somehow anyway... + # XXX is it ok? + length += 1 + + else: + mid = ext + + if length == 1: + return c1 + + res = [] + next = (length - 2)//2 + nmid = (length - 2) - next*2 + + res += [top] + res += [ext]*next + res += [mid]*nmid + res += [ext]*next + res += [bot] + + return res + + +def vobj(symb, height): + """Construct vertical object of a given height + + see: xobj + """ + return '\n'.join( xobj(symb, height) ) + + +def hobj(symb, width): + """Construct horizontal object of a given width + + see: xobj + """ + return ''.join( xobj(symb, width) ) + +# RADICAL +# n -> symbol +root = { + 2: U('SQUARE ROOT'), # U('RADICAL SYMBOL BOTTOM') + 3: U('CUBE ROOT'), + 4: U('FOURTH ROOT'), +} + + +# RATIONAL +VF = lambda txt: U('VULGAR FRACTION %s' % txt) + +# (p,q) -> symbol +frac = { + (1, 2): VF('ONE HALF'), + (1, 3): VF('ONE THIRD'), + (2, 3): VF('TWO THIRDS'), + (1, 4): VF('ONE QUARTER'), + (3, 4): VF('THREE QUARTERS'), + (1, 5): VF('ONE FIFTH'), + (2, 5): VF('TWO FIFTHS'), + (3, 5): VF('THREE FIFTHS'), + (4, 5): VF('FOUR FIFTHS'), + (1, 6): VF('ONE SIXTH'), + (5, 6): VF('FIVE SIXTHS'), + (1, 8): VF('ONE EIGHTH'), + (3, 8): VF('THREE EIGHTHS'), + (5, 8): VF('FIVE EIGHTHS'), + (7, 8): VF('SEVEN EIGHTHS'), +} + + +# atom symbols +_xsym = { + '==': ('=', '='), + '<': ('<', '<'), + '>': ('>', '>'), + '<=': ('<=', U('LESS-THAN OR EQUAL TO')), + '>=': ('>=', U('GREATER-THAN OR EQUAL TO')), + '!=': ('!=', U('NOT EQUAL TO')), + ':=': (':=', ':='), + '+=': ('+=', '+='), + '-=': ('-=', '-='), + '*=': ('*=', '*='), + '/=': ('/=', '/='), + '%=': ('%=', '%='), + '*': ('*', U('DOT OPERATOR')), + '-->': ('-->', U('EM DASH') + U('EM DASH') + + U('BLACK RIGHT-POINTING TRIANGLE') if U('EM DASH') + and U('BLACK RIGHT-POINTING TRIANGLE') else None), + '==>': ('==>', U('BOX DRAWINGS DOUBLE HORIZONTAL') + + U('BOX DRAWINGS DOUBLE HORIZONTAL') + + U('BLACK RIGHT-POINTING TRIANGLE') if + U('BOX DRAWINGS DOUBLE HORIZONTAL') and + U('BOX DRAWINGS DOUBLE HORIZONTAL') and + U('BLACK RIGHT-POINTING TRIANGLE') else None), + '.': ('*', U('RING OPERATOR')), +} + + +def xsym(sym): + """get symbology for a 'character'""" + op = _xsym[sym] + + if _use_unicode: + return op[1] + else: + return op[0] + + +# SYMBOLS + +atoms_table = { + # class how-to-display + 'Exp1': U('SCRIPT SMALL E'), + 'Pi': U('GREEK SMALL LETTER PI'), + 'Infinity': U('INFINITY'), + 'NegativeInfinity': U('INFINITY') and ('-' + U('INFINITY')), # XXX what to do here + #'ImaginaryUnit': U('GREEK SMALL LETTER IOTA'), + #'ImaginaryUnit': U('MATHEMATICAL ITALIC SMALL I'), + 'ImaginaryUnit': U('DOUBLE-STRUCK ITALIC SMALL I'), + 'EmptySet': U('EMPTY SET'), + 'Naturals': U('DOUBLE-STRUCK CAPITAL N'), + 'Naturals0': (U('DOUBLE-STRUCK CAPITAL N') and + (U('DOUBLE-STRUCK CAPITAL N') + + U('SUBSCRIPT ZERO'))), + 'Integers': U('DOUBLE-STRUCK CAPITAL Z'), + 'Rationals': U('DOUBLE-STRUCK CAPITAL Q'), + 'Reals': U('DOUBLE-STRUCK CAPITAL R'), + 'Complexes': U('DOUBLE-STRUCK CAPITAL C'), + 'Universe': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL U'), + 'IdentityMatrix': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL I'), + 'ZeroMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ZERO'), + 'OneMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ONE'), + 'Differential': U('DOUBLE-STRUCK ITALIC SMALL D'), + 'Union': U('UNION'), + 'ElementOf': U('ELEMENT OF'), + 'SmallElementOf': U('SMALL ELEMENT OF'), + 'SymmetricDifference': U('INCREMENT'), + 'Intersection': U('INTERSECTION'), + 'Ring': U('RING OPERATOR'), + 'Multiplication': U('MULTIPLICATION SIGN'), + 'TensorProduct': U('N-ARY CIRCLED TIMES OPERATOR'), + 'Dots': U('HORIZONTAL ELLIPSIS'), + 'Modifier Letter Low Ring':U('Modifier Letter Low Ring'), + 'EmptySequence': 'EmptySequence', + 'SuperscriptPlus': U('SUPERSCRIPT PLUS SIGN'), + 'SuperscriptMinus': U('SUPERSCRIPT MINUS'), + 'Dagger': U('DAGGER'), + 'Degree': U('DEGREE SIGN'), + #Logic Symbols + 'And': U('LOGICAL AND'), + 'Or': U('LOGICAL OR'), + 'Not': U('NOT SIGN'), + 'Nor': U('NOR'), + 'Nand': U('NAND'), + 'Xor': U('XOR'), + 'Equiv': U('LEFT RIGHT DOUBLE ARROW'), + 'NotEquiv': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'), + 'Implies': U('LEFT RIGHT DOUBLE ARROW'), + 'NotImplies': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'), + 'Arrow': U('RIGHTWARDS ARROW'), + 'ArrowFromBar': U('RIGHTWARDS ARROW FROM BAR'), + 'NotArrow': U('RIGHTWARDS ARROW WITH STROKE'), + 'Tautology': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'), + 'Contradiction': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL') +} + + +def pretty_atom(atom_name, default=None, printer=None): + """return pretty representation of an atom""" + if _use_unicode: + if printer is not None and atom_name == 'ImaginaryUnit' and printer._settings['imaginary_unit'] == 'j': + return U('DOUBLE-STRUCK ITALIC SMALL J') + else: + return atoms_table[atom_name] + else: + if default is not None: + return default + + raise KeyError('only unicode') # send it default printer + + +def pretty_symbol(symb_name, bold_name=False): + """return pretty representation of a symbol""" + # let's split symb_name into symbol + index + # UC: beta1 + # UC: f_beta + + if not _use_unicode: + return symb_name + + name, sups, subs = split_super_sub(symb_name) + + def translate(s, bold_name) : + if bold_name: + gG = greek_bold_unicode.get(s) + else: + gG = greek_unicode.get(s) + if gG is not None: + return gG + for key in sorted(modifier_dict.keys(), key=lambda k:len(k), reverse=True) : + if s.lower().endswith(key) and len(s)>len(key): + return modifier_dict[key](translate(s[:-len(key)], bold_name)) + if bold_name: + return ''.join([bold_unicode[c] for c in s]) + return s + + name = translate(name, bold_name) + + # Let's prettify sups/subs. If it fails at one of them, pretty sups/subs are + # not used at all. + def pretty_list(l, mapping): + result = [] + for s in l: + pretty = mapping.get(s) + if pretty is None: + try: # match by separate characters + pretty = ''.join([mapping[c] for c in s]) + except (TypeError, KeyError): + return None + result.append(pretty) + return result + + pretty_sups = pretty_list(sups, sup) + if pretty_sups is not None: + pretty_subs = pretty_list(subs, sub) + else: + pretty_subs = None + + # glue the results into one string + if pretty_subs is None: # nice formatting of sups/subs did not work + if subs: + name += '_'+'_'.join([translate(s, bold_name) for s in subs]) + if sups: + name += '__'+'__'.join([translate(s, bold_name) for s in sups]) + return name + else: + sups_result = ' '.join(pretty_sups) + subs_result = ' '.join(pretty_subs) + + return ''.join([name, sups_result, subs_result]) + + +def annotated(letter): + """ + Return a stylised drawing of the letter ``letter``, together with + information on how to put annotations (super- and subscripts to the + left and to the right) on it. + + See pretty.py functions _print_meijerg, _print_hyper on how to use this + information. + """ + ucode_pics = { + 'F': (2, 0, 2, 0, '\N{BOX DRAWINGS LIGHT DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n' + '\N{BOX DRAWINGS LIGHT VERTICAL AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n' + '\N{BOX DRAWINGS LIGHT UP}'), + 'G': (3, 0, 3, 1, '\N{BOX DRAWINGS LIGHT ARC DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC DOWN AND LEFT}\n' + '\N{BOX DRAWINGS LIGHT VERTICAL}\N{BOX DRAWINGS LIGHT RIGHT}\N{BOX DRAWINGS LIGHT DOWN AND LEFT}\n' + '\N{BOX DRAWINGS LIGHT ARC UP AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC UP AND LEFT}') + } + ascii_pics = { + 'F': (3, 0, 3, 0, ' _\n|_\n|\n'), + 'G': (3, 0, 3, 1, ' __\n/__\n\\_|') + } + + if _use_unicode: + return ucode_pics[letter] + else: + return ascii_pics[letter] + +_remove_combining = dict.fromkeys(list(range(ord('\N{COMBINING GRAVE ACCENT}'), ord('\N{COMBINING LATIN SMALL LETTER X}'))) + + list(range(ord('\N{COMBINING LEFT HARPOON ABOVE}'), ord('\N{COMBINING ASTERISK ABOVE}')))) + +def is_combining(sym): + """Check whether symbol is a unicode modifier. """ + + return ord(sym) in _remove_combining + + +def center_accent(string, accent): + """ + Returns a string with accent inserted on the middle character. Useful to + put combining accents on symbol names, including multi-character names. + + Parameters + ========== + + string : string + The string to place the accent in. + accent : string + The combining accent to insert + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Combining_character + .. [2] https://en.wikipedia.org/wiki/Combining_Diacritical_Marks + + """ + + # Accent is placed on the previous character, although it may not always look + # like that depending on console + midpoint = len(string) // 2 + 1 + firstpart = string[:midpoint] + secondpart = string[midpoint:] + return firstpart + accent + secondpart + + +def line_width(line): + """Unicode combining symbols (modifiers) are not ever displayed as + separate symbols and thus should not be counted + """ + return len(line.translate(_remove_combining)) + + +def is_subscriptable_in_unicode(subscript): + """ + Checks whether a string is subscriptable in unicode or not. + + Parameters + ========== + + subscript: the string which needs to be checked + + Examples + ======== + + >>> from sympy.printing.pretty.pretty_symbology import is_subscriptable_in_unicode + >>> is_subscriptable_in_unicode('abc') + False + >>> is_subscriptable_in_unicode('123') + True + + """ + return all(character in sub for character in subscript) + + +def center_pad(wstring, wtarget, fillchar=' '): + """ + Return the padding strings necessary to center a string of + wstring characters wide in a wtarget wide space. + + The line_width wstring should always be less or equal to wtarget + or else a ValueError will be raised. + """ + if wstring > wtarget: + raise ValueError('not enough space for string') + wdelta = wtarget - wstring + + wleft = wdelta // 2 # favor left '1 ' + wright = wdelta - wleft + + left = fillchar * wleft + right = fillchar * wright + + return left, right + + +def center(string, width, fillchar=' '): + """Return a centered string of length determined by `line_width` + that uses `fillchar` for padding. + """ + left, right = center_pad(line_width(string), width, fillchar) + return ''.join([left, string, right]) diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/stringpict.py b/phivenv/Lib/site-packages/sympy/printing/pretty/stringpict.py new file mode 100644 index 0000000000000000000000000000000000000000..b6055f09c83b2abbe0c492991aaee4dff5b34f49 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pretty/stringpict.py @@ -0,0 +1,537 @@ +"""Prettyprinter by Jurjen Bos. +(I hate spammers: mail me at pietjepuk314 at the reverse of ku.oc.oohay). +All objects have a method that create a "stringPict", +that can be used in the str method for pretty printing. + +Updates by Jason Gedge (email at cs mun ca) + - terminal_string() method + - minor fixes and changes (mostly to prettyForm) + +TODO: + - Allow left/center/right alignment options for above/below and + top/center/bottom alignment options for left/right +""" + +import shutil + +from .pretty_symbology import hobj, vobj, xsym, xobj, pretty_use_unicode, line_width, center +from sympy.utilities.exceptions import sympy_deprecation_warning + +_GLOBAL_WRAP_LINE = None + +class stringPict: + """An ASCII picture. + The pictures are represented as a list of equal length strings. + """ + #special value for stringPict.below + LINE = 'line' + + def __init__(self, s, baseline=0): + """Initialize from string. + Multiline strings are centered. + """ + self.s = s + #picture is a string that just can be printed + self.picture = stringPict.equalLengths(s.splitlines()) + #baseline is the line number of the "base line" + self.baseline = baseline + self.binding = None + + @staticmethod + def equalLengths(lines): + # empty lines + if not lines: + return [''] + + width = max(line_width(line) for line in lines) + return [center(line, width) for line in lines] + + def height(self): + """The height of the picture in characters.""" + return len(self.picture) + + def width(self): + """The width of the picture in characters.""" + return line_width(self.picture[0]) + + @staticmethod + def next(*args): + """Put a string of stringPicts next to each other. + Returns string, baseline arguments for stringPict. + """ + #convert everything to stringPicts + objects = [] + for arg in args: + if isinstance(arg, str): + arg = stringPict(arg) + objects.append(arg) + + #make a list of pictures, with equal height and baseline + newBaseline = max(obj.baseline for obj in objects) + newHeightBelowBaseline = max( + obj.height() - obj.baseline + for obj in objects) + newHeight = newBaseline + newHeightBelowBaseline + + pictures = [] + for obj in objects: + oneEmptyLine = [' '*obj.width()] + basePadding = newBaseline - obj.baseline + totalPadding = newHeight - obj.height() + pictures.append( + oneEmptyLine * basePadding + + obj.picture + + oneEmptyLine * (totalPadding - basePadding)) + + result = [''.join(lines) for lines in zip(*pictures)] + return '\n'.join(result), newBaseline + + def right(self, *args): + r"""Put pictures next to this one. + Returns string, baseline arguments for stringPict. + (Multiline) strings are allowed, and are given a baseline of 0. + + Examples + ======== + + >>> from sympy.printing.pretty.stringpict import stringPict + >>> print(stringPict("10").right(" + ",stringPict("1\r-\r2",1))[0]) + 1 + 10 + - + 2 + + """ + return stringPict.next(self, *args) + + def left(self, *args): + """Put pictures (left to right) at left. + Returns string, baseline arguments for stringPict. + """ + return stringPict.next(*(args + (self,))) + + @staticmethod + def stack(*args): + """Put pictures on top of each other, + from top to bottom. + Returns string, baseline arguments for stringPict. + The baseline is the baseline of the second picture. + Everything is centered. + Baseline is the baseline of the second picture. + Strings are allowed. + The special value stringPict.LINE is a row of '-' extended to the width. + """ + #convert everything to stringPicts; keep LINE + objects = [] + for arg in args: + if arg is not stringPict.LINE and isinstance(arg, str): + arg = stringPict(arg) + objects.append(arg) + + #compute new width + newWidth = max( + obj.width() + for obj in objects + if obj is not stringPict.LINE) + + lineObj = stringPict(hobj('-', newWidth)) + + #replace LINE with proper lines + for i, obj in enumerate(objects): + if obj is stringPict.LINE: + objects[i] = lineObj + + #stack the pictures, and center the result + newPicture = [center(line, newWidth) for obj in objects for line in obj.picture] + newBaseline = objects[0].height() + objects[1].baseline + return '\n'.join(newPicture), newBaseline + + def below(self, *args): + """Put pictures under this picture. + Returns string, baseline arguments for stringPict. + Baseline is baseline of top picture + + Examples + ======== + + >>> from sympy.printing.pretty.stringpict import stringPict + >>> print(stringPict("x+3").below( + ... stringPict.LINE, '3')[0]) #doctest: +NORMALIZE_WHITESPACE + x+3 + --- + 3 + + """ + s, baseline = stringPict.stack(self, *args) + return s, self.baseline + + def above(self, *args): + """Put pictures above this picture. + Returns string, baseline arguments for stringPict. + Baseline is baseline of bottom picture. + """ + string, baseline = stringPict.stack(*(args + (self,))) + baseline = len(string.splitlines()) - self.height() + self.baseline + return string, baseline + + def parens(self, left='(', right=')', ifascii_nougly=False): + """Put parentheses around self. + Returns string, baseline arguments for stringPict. + + left or right can be None or empty string which means 'no paren from + that side' + """ + h = self.height() + b = self.baseline + + # XXX this is a hack -- ascii parens are ugly! + if ifascii_nougly and not pretty_use_unicode(): + h = 1 + b = 0 + + res = self + + if left: + lparen = stringPict(vobj(left, h), baseline=b) + res = stringPict(*lparen.right(self)) + if right: + rparen = stringPict(vobj(right, h), baseline=b) + res = stringPict(*res.right(rparen)) + + return ('\n'.join(res.picture), res.baseline) + + def leftslash(self): + """Precede object by a slash of the proper size. + """ + # XXX not used anywhere ? + height = max( + self.baseline, + self.height() - 1 - self.baseline)*2 + 1 + slash = '\n'.join( + ' '*(height - i - 1) + xobj('/', 1) + ' '*i + for i in range(height) + ) + return self.left(stringPict(slash, height//2)) + + def root(self, n=None): + """Produce a nice root symbol. + Produces ugly results for big n inserts. + """ + # XXX not used anywhere + # XXX duplicate of root drawing in pretty.py + #put line over expression + result = self.above('_'*self.width()) + #construct right half of root symbol + height = self.height() + slash = '\n'.join( + ' ' * (height - i - 1) + '/' + ' ' * i + for i in range(height) + ) + slash = stringPict(slash, height - 1) + #left half of root symbol + if height > 2: + downline = stringPict('\\ \n \\', 1) + else: + downline = stringPict('\\') + #put n on top, as low as possible + if n is not None and n.width() > downline.width(): + downline = downline.left(' '*(n.width() - downline.width())) + downline = downline.above(n) + #build root symbol + root = downline.right(slash) + #glue it on at the proper height + #normally, the root symbel is as high as self + #which is one less than result + #this moves the root symbol one down + #if the root became higher, the baseline has to grow too + root.baseline = result.baseline - result.height() + root.height() + return result.left(root) + + def render(self, * args, **kwargs): + """Return the string form of self. + + Unless the argument line_break is set to False, it will + break the expression in a form that can be printed + on the terminal without being broken up. + """ + if _GLOBAL_WRAP_LINE is not None: + kwargs["wrap_line"] = _GLOBAL_WRAP_LINE + + if kwargs["wrap_line"] is False: + return "\n".join(self.picture) + + if kwargs["num_columns"] is not None: + # Read the argument num_columns if it is not None + ncols = kwargs["num_columns"] + else: + # Attempt to get a terminal width + ncols = self.terminal_width() + + if ncols <= 0: + ncols = 80 + + # If smaller than the terminal width, no need to correct + if self.width() <= ncols: + return type(self.picture[0])(self) + + """ + Break long-lines in a visually pleasing format. + without overflow indicators | with overflow indicators + | 2 2 3 | | 2 2 3 ↪| + |6*x *y + 4*x*y + | |6*x *y + 4*x*y + ↪| + | | | | + | 3 4 4 | |↪ 3 4 4 | + |4*y*x + x + y | |↪ 4*y*x + x + y | + |a*c*e + a*c*f + a*d | |a*c*e + a*c*f + a*d ↪| + |*e + a*d*f + b*c*e | | | + |+ b*c*f + b*d*e + b | |↪ *e + a*d*f + b*c* ↪| + |*d*f | | | + | | |↪ e + b*c*f + b*d*e ↪| + | | | | + | | |↪ + b*d*f | + """ + + overflow_first = "" + if kwargs["use_unicode"] or pretty_use_unicode(): + overflow_start = "\N{RIGHTWARDS ARROW WITH HOOK} " + overflow_end = " \N{RIGHTWARDS ARROW WITH HOOK}" + else: + overflow_start = "> " + overflow_end = " >" + + def chunks(line): + """Yields consecutive chunks of line_width ncols""" + prefix = overflow_first + width, start = line_width(prefix + overflow_end), 0 + for i, x in enumerate(line): + wx = line_width(x) + # Only flush the screen when the current character overflows. + # This way, combining marks can be appended even when width == ncols. + if width + wx > ncols: + yield prefix + line[start:i] + overflow_end + prefix = overflow_start + width, start = line_width(prefix + overflow_end), i + width += wx + yield prefix + line[start:] + + # Concurrently assemble chunks of all lines into individual screens + pictures = zip(*map(chunks, self.picture)) + + # Join lines of each screen into sub-pictures + pictures = ["\n".join(picture) for picture in pictures] + + # Add spacers between sub-pictures + return "\n\n".join(pictures) + + def terminal_width(self): + """Return the terminal width if possible, otherwise return 0. + """ + size = shutil.get_terminal_size(fallback=(0, 0)) + return size.columns + + def __eq__(self, o): + if isinstance(o, str): + return '\n'.join(self.picture) == o + elif isinstance(o, stringPict): + return o.picture == self.picture + return False + + def __hash__(self): + return super().__hash__() + + def __str__(self): + return '\n'.join(self.picture) + + def __repr__(self): + return "stringPict(%r,%d)" % ('\n'.join(self.picture), self.baseline) + + def __getitem__(self, index): + return self.picture[index] + + def __len__(self): + return len(self.s) + + +class prettyForm(stringPict): + """ + Extension of the stringPict class that knows about basic math applications, + optimizing double minus signs. + + "Binding" is interpreted as follows:: + + ATOM this is an atom: never needs to be parenthesized + FUNC this is a function application: parenthesize if added (?) + DIV this is a division: make wider division if divided + POW this is a power: only parenthesize if exponent + MUL this is a multiplication: parenthesize if powered + ADD this is an addition: parenthesize if multiplied or powered + NEG this is a negative number: optimize if added, parenthesize if + multiplied or powered + OPEN this is an open object: parenthesize if added, multiplied, or + powered (example: Piecewise) + """ + ATOM, FUNC, DIV, POW, MUL, ADD, NEG, OPEN = range(8) + + def __init__(self, s, baseline=0, binding=0, unicode=None): + """Initialize from stringPict and binding power.""" + stringPict.__init__(self, s, baseline) + self.binding = binding + if unicode is not None: + sympy_deprecation_warning( + """ + The unicode argument to prettyForm is deprecated. Only the s + argument (the first positional argument) should be passed. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions") + self._unicode = unicode or s + + @property + def unicode(self): + sympy_deprecation_warning( + """ + The prettyForm.unicode attribute is deprecated. Use the + prettyForm.s attribute instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions") + return self._unicode + + # Note: code to handle subtraction is in _print_Add + + def __add__(self, *others): + """Make a pretty addition. + Addition of negative numbers is simplified. + """ + arg = self + if arg.binding > prettyForm.NEG: + arg = stringPict(*arg.parens()) + result = [arg] + for arg in others: + #add parentheses for weak binders + if arg.binding > prettyForm.NEG: + arg = stringPict(*arg.parens()) + #use existing minus sign if available + if arg.binding != prettyForm.NEG: + result.append(' + ') + result.append(arg) + return prettyForm(binding=prettyForm.ADD, *stringPict.next(*result)) + + def __truediv__(self, den, slashed=False): + """Make a pretty division; stacked or slashed. + """ + if slashed: + raise NotImplementedError("Can't do slashed fraction yet") + num = self + if num.binding == prettyForm.DIV: + num = stringPict(*num.parens()) + if den.binding == prettyForm.DIV: + den = stringPict(*den.parens()) + + if num.binding==prettyForm.NEG: + num = num.right(" ")[0] + + return prettyForm(binding=prettyForm.DIV, *stringPict.stack( + num, + stringPict.LINE, + den)) + + def __mul__(self, *others): + """Make a pretty multiplication. + Parentheses are needed around +, - and neg. + """ + quantity = { + 'degree': "\N{DEGREE SIGN}" + } + + if len(others) == 0: + return self # We aren't actually multiplying... So nothing to do here. + + # add parens on args that need them + arg = self + if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG: + arg = stringPict(*arg.parens()) + result = [arg] + for arg in others: + if arg.picture[0] not in quantity.values(): + result.append(xsym('*')) + #add parentheses for weak binders + if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG: + arg = stringPict(*arg.parens()) + result.append(arg) + + len_res = len(result) + for i in range(len_res): + if i < len_res - 1 and result[i] == '-1' and result[i + 1] == xsym('*'): + # substitute -1 by -, like in -1*x -> -x + result.pop(i) + result.pop(i) + result.insert(i, '-') + if result[0][0] == '-': + # if there is a - sign in front of all + # This test was failing to catch a prettyForm.__mul__(prettyForm("-1", 0, 6)) being negative + bin = prettyForm.NEG + if result[0] == '-': + right = result[1] + if right.picture[right.baseline][0] == '-': + result[0] = '- ' + else: + bin = prettyForm.MUL + return prettyForm(binding=bin, *stringPict.next(*result)) + + def __repr__(self): + return "prettyForm(%r,%d,%d)" % ( + '\n'.join(self.picture), + self.baseline, + self.binding) + + def __pow__(self, b): + """Make a pretty power. + """ + a = self + use_inline_func_form = False + if b.binding == prettyForm.POW: + b = stringPict(*b.parens()) + if a.binding > prettyForm.FUNC: + a = stringPict(*a.parens()) + elif a.binding == prettyForm.FUNC: + # heuristic for when to use inline power + if b.height() > 1: + a = stringPict(*a.parens()) + else: + use_inline_func_form = True + + if use_inline_func_form: + # 2 + # sin + + (x) + b.baseline = a.prettyFunc.baseline + b.height() + func = stringPict(*a.prettyFunc.right(b)) + return prettyForm(*func.right(a.prettyArgs)) + else: + # 2 <-- top + # (x+y) <-- bot + top = stringPict(*b.left(' '*a.width())) + bot = stringPict(*a.right(' '*b.width())) + + return prettyForm(binding=prettyForm.POW, *bot.above(top)) + + simpleFunctions = ["sin", "cos", "tan"] + + @staticmethod + def apply(function, *args): + """Functions of one or more variables. + """ + if function in prettyForm.simpleFunctions: + #simple function: use only space if possible + assert len( + args) == 1, "Simple function %s must have 1 argument" % function + arg = args[0].__pretty__() + if arg.binding <= prettyForm.DIV: + #optimization: no parentheses necessary + return prettyForm(binding=prettyForm.FUNC, *arg.left(function + ' ')) + argumentList = [] + for arg in args: + argumentList.append(',') + argumentList.append(arg.__pretty__()) + argumentList = stringPict(*stringPict.next(*argumentList[1:])) + argumentList = stringPict(*argumentList.parens()) + return prettyForm(binding=prettyForm.ATOM, *argumentList.left(function)) diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/tests/__init__.py b/phivenv/Lib/site-packages/sympy/printing/pretty/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..494dcc8bd8e98df32e61f479a2cf48d1e75a6b1f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/pretty/tests/test_pretty.py b/phivenv/Lib/site-packages/sympy/printing/pretty/tests/test_pretty.py new file mode 100644 index 0000000000000000000000000000000000000000..1cca79bd1dc5c3ba81483c8fe2e87c35926d1b94 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pretty/tests/test_pretty.py @@ -0,0 +1,7972 @@ +# -*- coding: utf-8 -*- +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.function import (Derivative, Function, Lambda, Subs) +from sympy.core.mul import Mul +from sympy.core import (EulerGamma, GoldenRatio, Catalan) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import LambertW +from sympy.functions.special.bessel import (airyai, airyaiprime, airybi, airybiprime) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.error_functions import (fresnelc, fresnels) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.zeta_functions import dirichlet_eta +from sympy.geometry.line import (Ray, Segment) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Equivalent, ITE, Implies, Nand, Nor, Not, Or, Xor) +from sympy.matrices.dense import (Matrix, diag) +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.trace import Trace +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.realfield import RR +from sympy.polys.orderings import (grlex, ilex) +from sympy.polys.polytools import groebner +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Complement, FiniteSet, Intersection, Interval, Union) +from sympy.codegen.ast import (Assignment, AddAugmentedAssignment, + SubAugmentedAssignment, MulAugmentedAssignment, DivAugmentedAssignment, ModAugmentedAssignment) +from sympy.core.expr import UnevaluatedExpr +from sympy.physics.quantum.trace import Tr + +from sympy.functions import (Abs, Chi, Ci, Ei, KroneckerDelta, + Piecewise, Shi, Si, atan2, beta, binomial, catalan, ceiling, cos, + euler, exp, expint, factorial, factorial2, floor, gamma, hyper, log, + meijerg, sin, sqrt, subfactorial, tan, uppergamma, lerchphi, polylog, + elliptic_k, elliptic_f, elliptic_e, elliptic_pi, DiracDelta, bell, + bernoulli, fibonacci, tribonacci, lucas, stieltjes, mathieuc, mathieus, + mathieusprime, mathieucprime) + +from sympy.matrices import (Adjoint, Inverse, MatrixSymbol, Transpose, + KroneckerProduct, BlockMatrix, OneMatrix, ZeroMatrix) +from sympy.matrices.expressions import hadamard_power + +from sympy.physics import mechanics +from sympy.physics.control.lti import (TransferFunction, Feedback, TransferFunctionMatrix, + Series, Parallel, MIMOSeries, MIMOParallel, MIMOFeedback, StateSpace) +from sympy.physics.units import joule, degree +from sympy.printing.pretty import pprint, pretty as xpretty +from sympy.printing.pretty.pretty_symbology import center_accent, is_combining, center +from sympy.sets.conditionset import ConditionSet + +from sympy.sets import ImageSet, ProductSet +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray, tensorproduct) +from sympy.tensor.functions import TensorProduct +from sympy.tensor.tensor import (TensorIndexType, tensor_indices, TensorHead, + TensorElement, tensor_heads) + +from sympy.testing.pytest import raises, _both_exp_pow, warns_deprecated_sympy + +from sympy.vector import CoordSys3D, Gradient, Curl, Divergence, Dot, Cross, Laplacian + + + +import sympy as sym +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + +a, b, c, d, x, y, z, k, n, s, p = symbols('a,b,c,d,x,y,z,k,n,s,p') +f = Function("f") +th = Symbol('theta') +ph = Symbol('phi') + +""" +Expressions whose pretty-printing is tested here: +(A '#' to the right of an expression indicates that its various acceptable +orderings are accounted for by the tests.) + + +BASIC EXPRESSIONS: + +oo +(x**2) +1/x +y*x**-2 +x**Rational(-5,2) +(-2)**x +Pow(3, 1, evaluate=False) +(x**2 + x + 1) # +1-x # +1-2*x # +x/y +-x/y +(x+2)/y # +(1+x)*y #3 +-5*x/(x+10) # correct placement of negative sign +1 - Rational(3,2)*(x+1) +-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5) # issue 5524 + + +ORDERING: + +x**2 + x + 1 +1 - x +1 - 2*x +2*x**4 + y**2 - x**2 + y**3 + + +RELATIONAL: + +Eq(x, y) +Lt(x, y) +Gt(x, y) +Le(x, y) +Ge(x, y) +Ne(x/(y+1), y**2) # + + +RATIONAL NUMBERS: + +y*x**-2 +y**Rational(3,2) * x**Rational(-5,2) +sin(x)**3/tan(x)**2 + + +FUNCTIONS (ABS, CONJ, EXP, FUNCTION BRACES, FACTORIAL, FLOOR, CEILING): + +(2*x + exp(x)) # +Abs(x) +Abs(x/(x**2+1)) # +Abs(1 / (y - Abs(x))) +factorial(n) +factorial(2*n) +subfactorial(n) +subfactorial(2*n) +factorial(factorial(factorial(n))) +factorial(n+1) # +conjugate(x) +conjugate(f(x+1)) # +f(x) +f(x, y) +f(x/(y+1), y) # +f(x**x**x**x**x**x) +sin(x)**2 +conjugate(a+b*I) +conjugate(exp(a+b*I)) +conjugate( f(1 + conjugate(f(x))) ) # +f(x/(y+1), y) # denom of first arg +floor(1 / (y - floor(x))) +ceiling(1 / (y - ceiling(x))) + + +SQRT: + +sqrt(2) +2**Rational(1,3) +2**Rational(1,1000) +sqrt(x**2 + 1) +(1 + sqrt(5))**Rational(1,3) +2**(1/x) +sqrt(2+pi) +(2+(1+x**2)/(2+x))**Rational(1,4)+(1+x**Rational(1,1000))/sqrt(3+x**2) + + +DERIVATIVES: + +Derivative(log(x), x, evaluate=False) +Derivative(log(x), x, evaluate=False) + x # +Derivative(log(x) + x**2, x, y, evaluate=False) +Derivative(2*x*y, y, x, evaluate=False) + x**2 # +beta(alpha).diff(alpha) + + +INTEGRALS: + +Integral(log(x), x) +Integral(x**2, x) +Integral((sin(x))**2 / (tan(x))**2) +Integral(x**(2**x), x) +Integral(x**2, (x,1,2)) +Integral(x**2, (x,Rational(1,2),10)) +Integral(x**2*y**2, x,y) +Integral(x**2, (x, None, 1)) +Integral(x**2, (x, 1, None)) +Integral(sin(th)/cos(ph), (th,0,pi), (ph, 0, 2*pi)) + + +MATRICES: + +Matrix([[x**2+1, 1], [y, x+y]]) # +Matrix([[x/y, y, th], [0, exp(I*k*ph), 1]]) + + +PIECEWISE: + +Piecewise((x,x<1),(x**2,True)) + +ITE: + +ITE(x, y, z) + +SEQUENCES (TUPLES, LISTS, DICTIONARIES): + +() +[] +{} +(1/x,) +[x**2, 1/x, x, y, sin(th)**2/cos(ph)**2] +(x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) +{x: sin(x)} +{1/x: 1/y, x: sin(x)**2} # +[x**2] +(x**2,) +{x**2: 1} + + +LIMITS: + +Limit(x, x, oo) +Limit(x**2, x, 0) +Limit(1/x, x, 0) +Limit(sin(x)/x, x, 0) + + +UNITS: + +joule => kg*m**2/s + + +SUBS: + +Subs(f(x), x, ph**2) +Subs(f(x).diff(x), x, 0) +Subs(f(x).diff(x)/y, (x, y), (0, Rational(1, 2))) + + +ORDER: + +O(1) +O(1/x) +O(x**2 + y**2) + +""" + + +def pretty(expr, order=None): + """ASCII pretty-printing""" + return xpretty(expr, order=order, use_unicode=False, wrap_line=False) + + +def upretty(expr, order=None): + """Unicode pretty-printing""" + return xpretty(expr, order=order, use_unicode=True, wrap_line=False) + + +def test_pretty_ascii_str(): + assert pretty( 'xxx' ) == 'xxx' + assert pretty( "xxx" ) == 'xxx' + assert pretty( 'xxx\'xxx' ) == 'xxx\'xxx' + assert pretty( 'xxx"xxx' ) == 'xxx\"xxx' + assert pretty( 'xxx\"xxx' ) == 'xxx\"xxx' + assert pretty( "xxx'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\"xxx" ) == 'xxx\"xxx' + assert pretty( "xxx\"xxx\'xxx" ) == 'xxx"xxx\'xxx' + assert pretty( "xxx\nxxx" ) == 'xxx\nxxx' + + +def test_pretty_unicode_str(): + assert pretty( 'xxx' ) == 'xxx' + assert pretty( 'xxx' ) == 'xxx' + assert pretty( 'xxx\'xxx' ) == 'xxx\'xxx' + assert pretty( 'xxx"xxx' ) == 'xxx\"xxx' + assert pretty( 'xxx\"xxx' ) == 'xxx\"xxx' + assert pretty( "xxx'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\"xxx" ) == 'xxx\"xxx' + assert pretty( "xxx\"xxx\'xxx" ) == 'xxx"xxx\'xxx' + assert pretty( "xxx\nxxx" ) == 'xxx\nxxx' + + +def test_upretty_greek(): + assert upretty( oo ) == '∞' + assert upretty( Symbol('alpha^+_1') ) == 'α⁺₁' + assert upretty( Symbol('beta') ) == 'β' + assert upretty(Symbol('lambda')) == 'λ' + + +def test_upretty_multiindex(): + assert upretty( Symbol('beta12') ) == 'β₁₂' + assert upretty( Symbol('Y00') ) == 'Y₀₀' + assert upretty( Symbol('Y_00') ) == 'Y₀₀' + assert upretty( Symbol('F^+-') ) == 'F⁺⁻' + + +def test_upretty_sub_super(): + assert upretty( Symbol('beta_1_2') ) == 'β₁ ₂' + assert upretty( Symbol('beta^1^2') ) == 'β¹ ²' + assert upretty( Symbol('beta_1^2') ) == 'β²₁' + assert upretty( Symbol('beta_10_20') ) == 'β₁₀ ₂₀' + assert upretty( Symbol('beta_ax_gamma^i') ) == 'βⁱₐₓ ᵧ' + assert upretty( Symbol("F^1^2_3_4") ) == 'F¹ ²₃ ₄' + assert upretty( Symbol("F_1_2^3^4") ) == 'F³ ⁴₁ ₂' + assert upretty( Symbol("F_1_2_3_4") ) == 'F₁ ₂ ₃ ₄' + assert upretty( Symbol("F^1^2^3^4") ) == 'F¹ ² ³ ⁴' + + +def test_upretty_subs_missing_in_24(): + assert upretty( Symbol('F_beta') ) == 'Fᵦ' + assert upretty( Symbol('F_gamma') ) == 'Fᵧ' + assert upretty( Symbol('F_rho') ) == 'Fᵨ' + assert upretty( Symbol('F_phi') ) == 'Fᵩ' + assert upretty( Symbol('F_chi') ) == 'Fᵪ' + + assert upretty( Symbol('F_a') ) == 'Fₐ' + assert upretty( Symbol('F_e') ) == 'Fₑ' + assert upretty( Symbol('F_i') ) == 'Fᵢ' + assert upretty( Symbol('F_o') ) == 'Fₒ' + assert upretty( Symbol('F_u') ) == 'Fᵤ' + assert upretty( Symbol('F_r') ) == 'Fᵣ' + assert upretty( Symbol('F_v') ) == 'Fᵥ' + assert upretty( Symbol('F_x') ) == 'Fₓ' + + +def test_missing_in_2X_issue_9047(): + assert upretty( Symbol('F_h') ) == 'Fₕ' + assert upretty( Symbol('F_k') ) == 'Fₖ' + assert upretty( Symbol('F_l') ) == 'Fₗ' + assert upretty( Symbol('F_m') ) == 'Fₘ' + assert upretty( Symbol('F_n') ) == 'Fₙ' + assert upretty( Symbol('F_p') ) == 'Fₚ' + assert upretty( Symbol('F_s') ) == 'Fₛ' + assert upretty( Symbol('F_t') ) == 'Fₜ' + + +def test_upretty_modifiers(): + # Accents + assert upretty( Symbol('Fmathring') ) == 'F̊' + assert upretty( Symbol('Fddddot') ) == 'F⃜' + assert upretty( Symbol('Fdddot') ) == 'F⃛' + assert upretty( Symbol('Fddot') ) == 'F̈' + assert upretty( Symbol('Fdot') ) == 'Ḟ' + assert upretty( Symbol('Fcheck') ) == 'F̌' + assert upretty( Symbol('Fbreve') ) == 'F̆' + assert upretty( Symbol('Facute') ) == 'F́' + assert upretty( Symbol('Fgrave') ) == 'F̀' + assert upretty( Symbol('Ftilde') ) == 'F̃' + assert upretty( Symbol('Fhat') ) == 'F̂' + assert upretty( Symbol('Fbar') ) == 'F̅' + assert upretty( Symbol('Fvec') ) == 'F⃗' + assert upretty( Symbol('Fprime') ) == 'F′' + assert upretty( Symbol('Fprm') ) == 'F′' + # No faces are actually implemented, but test to make sure the modifiers are stripped + assert upretty( Symbol('Fbold') ) == 'Fbold' + assert upretty( Symbol('Fbm') ) == 'Fbm' + assert upretty( Symbol('Fcal') ) == 'Fcal' + assert upretty( Symbol('Fscr') ) == 'Fscr' + assert upretty( Symbol('Ffrak') ) == 'Ffrak' + # Brackets + assert upretty( Symbol('Fnorm') ) == '‖F‖' + assert upretty( Symbol('Favg') ) == '⟨F⟩' + assert upretty( Symbol('Fabs') ) == '|F|' + assert upretty( Symbol('Fmag') ) == '|F|' + # Combinations + assert upretty( Symbol('xvecdot') ) == 'x⃗̇' + assert upretty( Symbol('xDotVec') ) == 'ẋ⃗' + assert upretty( Symbol('xHATNorm') ) == '‖x̂‖' + assert upretty( Symbol('xMathring_yCheckPRM__zbreveAbs') ) == 'x̊_y̌′__|z̆|' + assert upretty( Symbol('alphadothat_nVECDOT__tTildePrime') ) == 'α̇̂_n⃗̇__t̃′' + assert upretty( Symbol('x_dot') ) == 'x_dot' + assert upretty( Symbol('x__dot') ) == 'x__dot' + + +def test_pretty_Cycle(): + from sympy.combinatorics.permutations import Cycle + assert pretty(Cycle(1, 2)) == '(1 2)' + assert pretty(Cycle(2)) == '(2)' + assert pretty(Cycle(1, 3)(4, 5)) == '(1 3)(4 5)' + assert pretty(Cycle()) == '()' + + +def test_pretty_Permutation(): + from sympy.combinatorics.permutations import Permutation + p1 = Permutation(1, 2)(3, 4) + assert xpretty(p1, perm_cyclic=True, use_unicode=True) == "(1 2)(3 4)" + assert xpretty(p1, perm_cyclic=True, use_unicode=False) == "(1 2)(3 4)" + assert xpretty(p1, perm_cyclic=False, use_unicode=True) == \ + '⎛0 1 2 3 4⎞\n'\ + '⎝0 2 1 4 3⎠' + assert xpretty(p1, perm_cyclic=False, use_unicode=False) == \ + "/0 1 2 3 4\\\n"\ + "\\0 2 1 4 3/" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert xpretty(p1, use_unicode=True) == \ + '⎛0 1 2 3 4⎞\n'\ + '⎝0 2 1 4 3⎠' + assert xpretty(p1, use_unicode=False) == \ + "/0 1 2 3 4\\\n"\ + "\\0 2 1 4 3/" + Permutation.print_cyclic = old_print_cyclic + + +def test_pretty_basic(): + assert pretty( -Rational(1)/2 ) == '-1/2' + assert pretty( -Rational(13)/22 ) == \ +"""\ +-13 \n\ +----\n\ + 22 \ +""" + expr = oo + ascii_str = \ +"""\ +oo\ +""" + ucode_str = \ +"""\ +∞\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2) + ascii_str = \ +"""\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 1/x + ascii_str = \ +"""\ +1\n\ +-\n\ +x\ +""" + ucode_str = \ +"""\ +1\n\ +─\n\ +x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # not the same as 1/x + expr = x**-1.0 + ascii_str = \ +"""\ + -1.0\n\ +x \ +""" + ucode_str = \ +"""\ + -1.0\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # see issue #2860 + expr = Pow(S(2), -1.0, evaluate=False) + ascii_str = \ +"""\ + -1.0\n\ +2 \ +""" + ucode_str = \ +"""\ + -1.0\n\ +2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = y*x**-2 + ascii_str = \ +"""\ +y \n\ +--\n\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ +y \n\ +──\n\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + #see issue #14033 + expr = x**Rational(1, 3) + ascii_str = \ +"""\ + 1/3\n\ +x \ +""" + ucode_str = \ +"""\ + 1/3\n\ +x \ +""" + assert xpretty(expr, use_unicode=False, wrap_line=False,\ + root_notation = False) == ascii_str + assert xpretty(expr, use_unicode=True, wrap_line=False,\ + root_notation = False) == ucode_str + + expr = x**Rational(-5, 2) + ascii_str = \ +"""\ + 1 \n\ +----\n\ + 5/2\n\ +x \ +""" + ucode_str = \ +"""\ + 1 \n\ +────\n\ + 5/2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (-2)**x + ascii_str = \ +"""\ + x\n\ +(-2) \ +""" + ucode_str = \ +"""\ + x\n\ +(-2) \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # See issue 4923 + expr = Pow(3, 1, evaluate=False) + ascii_str = \ +"""\ + 1\n\ +3 \ +""" + ucode_str = \ +"""\ + 1\n\ +3 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2 + x + 1) + ascii_str_1 = \ +"""\ + 2\n\ +1 + x + x \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ +x + x + 1\ +""" + ascii_str_3 = \ +"""\ + 2 \n\ +x + 1 + x\ +""" + ucode_str_1 = \ +"""\ + 2\n\ +1 + x + x \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ +x + x + 1\ +""" + ucode_str_3 = \ +"""\ + 2 \n\ +x + 1 + x\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = 1 - x + ascii_str_1 = \ +"""\ +1 - x\ +""" + ascii_str_2 = \ +"""\ +-x + 1\ +""" + ucode_str_1 = \ +"""\ +1 - x\ +""" + ucode_str_2 = \ +"""\ +-x + 1\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = 1 - 2*x + ascii_str_1 = \ +"""\ +1 - 2*x\ +""" + ascii_str_2 = \ +"""\ +-2*x + 1\ +""" + ucode_str_1 = \ +"""\ +1 - 2⋅x\ +""" + ucode_str_2 = \ +"""\ +-2⋅x + 1\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = x/y + ascii_str = \ +"""\ +x\n\ +-\n\ +y\ +""" + ucode_str = \ +"""\ +x\n\ +─\n\ +y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -x/y + ascii_str = \ +"""\ +-x \n\ +---\n\ + y \ +""" + ucode_str = \ +"""\ +-x \n\ +───\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x + 2)/y + ascii_str_1 = \ +"""\ +2 + x\n\ +-----\n\ + y \ +""" + ascii_str_2 = \ +"""\ +x + 2\n\ +-----\n\ + y \ +""" + ucode_str_1 = \ +"""\ +2 + x\n\ +─────\n\ + y \ +""" + ucode_str_2 = \ +"""\ +x + 2\n\ +─────\n\ + y \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = (1 + x)*y + ascii_str_1 = \ +"""\ +y*(1 + x)\ +""" + ascii_str_2 = \ +"""\ +(1 + x)*y\ +""" + ascii_str_3 = \ +"""\ +y*(x + 1)\ +""" + ucode_str_1 = \ +"""\ +y⋅(1 + x)\ +""" + ucode_str_2 = \ +"""\ +(1 + x)⋅y\ +""" + ucode_str_3 = \ +"""\ +y⋅(x + 1)\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + # Test for correct placement of the negative sign + expr = -5*x/(x + 10) + ascii_str_1 = \ +"""\ +-5*x \n\ +------\n\ +10 + x\ +""" + ascii_str_2 = \ +"""\ +-5*x \n\ +------\n\ +x + 10\ +""" + ucode_str_1 = \ +"""\ +-5⋅x \n\ +──────\n\ +10 + x\ +""" + ucode_str_2 = \ +"""\ +-5⋅x \n\ +──────\n\ +x + 10\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = -S.Half - 3*x + ascii_str = \ +"""\ +-3*x - 1/2\ +""" + ucode_str = \ +"""\ +-3⋅x - 1/2\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = S.Half - 3*x + ascii_str = \ +"""\ +1/2 - 3*x\ +""" + ucode_str = \ +"""\ +1/2 - 3⋅x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -S.Half - 3*x/2 + ascii_str = \ +"""\ + 3*x 1\n\ +- --- - -\n\ + 2 2\ +""" + ucode_str = \ +"""\ + 3⋅x 1\n\ +- ─── - ─\n\ + 2 2\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = S.Half - 3*x/2 + ascii_str = \ +"""\ +1 3*x\n\ +- - ---\n\ +2 2 \ +""" + ucode_str = \ +"""\ +1 3⋅x\n\ +─ - ───\n\ +2 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_negative_fractions(): + expr = -x/y + ascii_str =\ +"""\ +-x \n\ +---\n\ + y \ +""" + ucode_str =\ +"""\ +-x \n\ +───\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x*z/y + ascii_str =\ +"""\ +-x*z \n\ +-----\n\ + y \ +""" + ucode_str =\ +"""\ +-x⋅z \n\ +─────\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = x**2/y + ascii_str =\ +"""\ + 2\n\ +x \n\ +--\n\ +y \ +""" + ucode_str =\ +"""\ + 2\n\ +x \n\ +──\n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x**2/y + ascii_str =\ +"""\ + 2 \n\ +-x \n\ +----\n\ + y \ +""" + ucode_str =\ +"""\ + 2 \n\ +-x \n\ +────\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x/(y*z) + ascii_str =\ +"""\ +-x \n\ +---\n\ +y*z\ +""" + ucode_str =\ +"""\ +-x \n\ +───\n\ +y⋅z\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -a/y**2 + ascii_str =\ +"""\ +-a \n\ +---\n\ + 2 \n\ +y \ +""" + ucode_str =\ +"""\ +-a \n\ +───\n\ + 2 \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = y**(-a/b) + ascii_str =\ +"""\ + -a \n\ + ---\n\ + b \n\ +y \ +""" + ucode_str =\ +"""\ + -a \n\ + ───\n\ + b \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -1/y**2 + ascii_str =\ +"""\ +-1 \n\ +---\n\ + 2 \n\ +y \ +""" + ucode_str =\ +"""\ +-1 \n\ +───\n\ + 2 \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -10/b**2 + ascii_str =\ +"""\ +-10 \n\ +----\n\ + 2 \n\ + b \ +""" + ucode_str =\ +"""\ +-10 \n\ +────\n\ + 2 \n\ + b \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = Rational(-200, 37) + ascii_str =\ +"""\ +-200 \n\ +-----\n\ + 37 \ +""" + ucode_str =\ +"""\ +-200 \n\ +─────\n\ + 37 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_Mul(): + expr = Mul(0, 1, evaluate=False) + assert pretty(expr) == "0*1" + assert upretty(expr) == "0⋅1" + expr = Mul(1, 0, evaluate=False) + assert pretty(expr) == "1*0" + assert upretty(expr) == "1⋅0" + expr = Mul(1, 1, evaluate=False) + assert pretty(expr) == "1*1" + assert upretty(expr) == "1⋅1" + expr = Mul(1, 1, 1, evaluate=False) + assert pretty(expr) == "1*1*1" + assert upretty(expr) == "1⋅1⋅1" + expr = Mul(1, 2, evaluate=False) + assert pretty(expr) == "1*2" + assert upretty(expr) == "1⋅2" + expr = Add(0, 1, evaluate=False) + assert pretty(expr) == "0 + 1" + assert upretty(expr) == "0 + 1" + expr = Mul(1, 1, 2, evaluate=False) + assert pretty(expr) == "1*1*2" + assert upretty(expr) == "1⋅1⋅2" + expr = Add(0, 0, 1, evaluate=False) + assert pretty(expr) == "0 + 0 + 1" + assert upretty(expr) == "0 + 0 + 1" + expr = Mul(1, -1, evaluate=False) + assert pretty(expr) == "1*-1" + assert upretty(expr) == "1⋅-1" + expr = Mul(1.0, x, evaluate=False) + assert pretty(expr) == "1.0*x" + assert upretty(expr) == "1.0⋅x" + expr = Mul(1, 1, 2, 3, x, evaluate=False) + assert pretty(expr) == "1*1*2*3*x" + assert upretty(expr) == "1⋅1⋅2⋅3⋅x" + expr = Mul(-1, 1, evaluate=False) + assert pretty(expr) == "-1*1" + assert upretty(expr) == "-1⋅1" + expr = Mul(4, 3, 2, 1, 0, y, x, evaluate=False) + assert pretty(expr) == "4*3*2*1*0*y*x" + assert upretty(expr) == "4⋅3⋅2⋅1⋅0⋅y⋅x" + expr = Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False) + assert pretty(expr) == "4*3*2*(z + 1)*0*y*x" + assert upretty(expr) == "4⋅3⋅2⋅(z + 1)⋅0⋅y⋅x" + expr = Mul(Rational(2, 3), Rational(5, 7), evaluate=False) + assert pretty(expr) == "2/3*5/7" + assert upretty(expr) == "2/3⋅5/7" + expr = Mul(x + y, Rational(1, 2), evaluate=False) + assert pretty(expr) == "(x + y)*1/2" + assert upretty(expr) == "(x + y)⋅1/2" + expr = Mul(Rational(1, 2), x + y, evaluate=False) + assert pretty(expr) == "x + y\n-----\n 2 " + assert upretty(expr) == "x + y\n─────\n 2 " + expr = Mul(S.One, x + y, evaluate=False) + assert pretty(expr) == "1*(x + y)" + assert upretty(expr) == "1⋅(x + y)" + expr = Mul(x - y, S.One, evaluate=False) + assert pretty(expr) == "(x - y)*1" + assert upretty(expr) == "(x - y)⋅1" + expr = Mul(Rational(1, 2), x - y, S.One, x + y, evaluate=False) + assert pretty(expr) == "1/2*(x - y)*1*(x + y)" + assert upretty(expr) == "1/2⋅(x - y)⋅1⋅(x + y)" + expr = Mul(x + y, Rational(3, 4), S.One, y - z, evaluate=False) + assert pretty(expr) == "(x + y)*3/4*1*(y - z)" + assert upretty(expr) == "(x + y)⋅3/4⋅1⋅(y - z)" + expr = Mul(x + y, Rational(1, 1), Rational(3, 4), Rational(5, 6),evaluate=False) + assert pretty(expr) == "(x + y)*1*3/4*5/6" + assert upretty(expr) == "(x + y)⋅1⋅3/4⋅5/6" + expr = Mul(Rational(3, 4), x + y, S.One, y - z, evaluate=False) + assert pretty(expr) == "3/4*(x + y)*1*(y - z)" + assert upretty(expr) == "3/4⋅(x + y)⋅1⋅(y - z)" + + +def test_issue_5524(): + assert pretty(-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5)) == \ +"""\ + 2 / ___ \\\n\ +- (5 - y) + (x - 5)*\\-x - 2*\\/ 2 + 5/\ +""" + + assert upretty(-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5)) == \ +"""\ + 2 \n\ +- (5 - y) + (x - 5)⋅(-x - 2⋅√2 + 5)\ +""" + + +def test_pretty_ordering(): + assert pretty(x**2 + x + 1, order='lex') == \ +"""\ + 2 \n\ +x + x + 1\ +""" + assert pretty(x**2 + x + 1, order='rev-lex') == \ +"""\ + 2\n\ +1 + x + x \ +""" + assert pretty(1 - x, order='lex') == '-x + 1' + assert pretty(1 - x, order='rev-lex') == '1 - x' + + assert pretty(1 - 2*x, order='lex') == '-2*x + 1' + assert pretty(1 - 2*x, order='rev-lex') == '1 - 2*x' + + f = 2*x**4 + y**2 - x**2 + y**3 + assert pretty(f, order=None) == \ +"""\ + 4 2 3 2\n\ +2*x - x + y + y \ +""" + assert pretty(f, order='lex') == \ +"""\ + 4 2 3 2\n\ +2*x - x + y + y \ +""" + assert pretty(f, order='rev-lex') == \ +"""\ + 2 3 2 4\n\ +y + y - x + 2*x \ +""" + + expr = x - x**3/6 + x**5/120 + O(x**6) + ascii_str = \ +"""\ + 3 5 \n\ + x x / 6\\\n\ +x - -- + --- + O\\x /\n\ + 6 120 \ +""" + ucode_str = \ +"""\ + 3 5 \n\ + x x ⎛ 6⎞\n\ +x - ── + ─── + O⎝x ⎠\n\ + 6 120 \ +""" + assert pretty(expr, order=None) == ascii_str + assert upretty(expr, order=None) == ucode_str + + assert pretty(expr, order='lex') == ascii_str + assert upretty(expr, order='lex') == ucode_str + + assert pretty(expr, order='rev-lex') == ascii_str + assert upretty(expr, order='rev-lex') == ucode_str + + +def test_EulerGamma(): + assert pretty(EulerGamma) == str(EulerGamma) == "EulerGamma" + assert upretty(EulerGamma) == "γ" + + +def test_GoldenRatio(): + assert pretty(GoldenRatio) == str(GoldenRatio) == "GoldenRatio" + assert upretty(GoldenRatio) == "φ" + + +def test_Catalan(): + assert pretty(Catalan) == upretty(Catalan) == "G" + + +def test_pretty_relational(): + expr = Eq(x, y) + ascii_str = \ +"""\ +x = y\ +""" + ucode_str = \ +"""\ +x = y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lt(x, y) + ascii_str = \ +"""\ +x < y\ +""" + ucode_str = \ +"""\ +x < y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Gt(x, y) + ascii_str = \ +"""\ +x > y\ +""" + ucode_str = \ +"""\ +x > y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Le(x, y) + ascii_str = \ +"""\ +x <= y\ +""" + ucode_str = \ +"""\ +x ≤ y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Ge(x, y) + ascii_str = \ +"""\ +x >= y\ +""" + ucode_str = \ +"""\ +x ≥ y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Ne(x/(y + 1), y**2) + ascii_str_1 = \ +"""\ + x 2\n\ +----- != y \n\ +1 + y \ +""" + ascii_str_2 = \ +"""\ + x 2\n\ +----- != y \n\ +y + 1 \ +""" + ucode_str_1 = \ +"""\ + x 2\n\ +───── ≠ y \n\ +1 + y \ +""" + ucode_str_2 = \ +"""\ + x 2\n\ +───── ≠ y \n\ +y + 1 \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + +def test_Assignment(): + expr = Assignment(x, y) + ascii_str = \ +"""\ +x := y\ +""" + ucode_str = \ +"""\ +x := y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_AugmentedAssignment(): + expr = AddAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x += y\ +""" + ucode_str = \ +"""\ +x += y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = SubAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x -= y\ +""" + ucode_str = \ +"""\ +x -= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = MulAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x *= y\ +""" + ucode_str = \ +"""\ +x *= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = DivAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x /= y\ +""" + ucode_str = \ +"""\ +x /= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = ModAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x %= y\ +""" + ucode_str = \ +"""\ +x %= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_rational(): + expr = y*x**-2 + ascii_str = \ +"""\ +y \n\ +--\n\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ +y \n\ +──\n\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = y**Rational(3, 2) * x**Rational(-5, 2) + ascii_str = \ +"""\ + 3/2\n\ +y \n\ +----\n\ + 5/2\n\ +x \ +""" + ucode_str = \ +"""\ + 3/2\n\ +y \n\ +────\n\ + 5/2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sin(x)**3/tan(x)**2 + ascii_str = \ +"""\ + 3 \n\ +sin (x)\n\ +-------\n\ + 2 \n\ +tan (x)\ +""" + ucode_str = \ +"""\ + 3 \n\ +sin (x)\n\ +───────\n\ + 2 \n\ +tan (x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +@_both_exp_pow +def test_pretty_functions(): + """Tests for Abs, conjugate, exp, function braces, and factorial.""" + expr = (2*x + exp(x)) + ascii_str_1 = \ +"""\ + x\n\ +2*x + e \ +""" + ascii_str_2 = \ +"""\ + x \n\ +e + 2*x\ +""" + ucode_str_1 = \ +"""\ + x\n\ +2⋅x + ℯ \ +""" + ucode_str_2 = \ +"""\ + x \n\ +ℯ + 2⋅x\ +""" + ucode_str_3 = \ +"""\ + x \n\ +ℯ + 2⋅x\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Abs(x) + ascii_str = \ +"""\ +|x|\ +""" + ucode_str = \ +"""\ +│x│\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Abs(x/(x**2 + 1)) + ascii_str_1 = \ +"""\ +| x |\n\ +|------|\n\ +| 2|\n\ +|1 + x |\ +""" + ascii_str_2 = \ +"""\ +| x |\n\ +|------|\n\ +| 2 |\n\ +|x + 1|\ +""" + ucode_str_1 = \ +"""\ +│ x │\n\ +│──────│\n\ +│ 2│\n\ +│1 + x │\ +""" + ucode_str_2 = \ +"""\ +│ x │\n\ +│──────│\n\ +│ 2 │\n\ +│x + 1│\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = Abs(1 / (y - Abs(x))) + ascii_str = \ +"""\ + 1 \n\ +---------\n\ +|y - |x||\ +""" + ucode_str = \ +"""\ + 1 \n\ +─────────\n\ +│y - │x││\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + n = Symbol('n', integer=True) + expr = factorial(n) + ascii_str = \ +"""\ +n!\ +""" + ucode_str = \ +"""\ +n!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(2*n) + ascii_str = \ +"""\ +(2*n)!\ +""" + ucode_str = \ +"""\ +(2⋅n)!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(factorial(factorial(n))) + ascii_str = \ +"""\ +((n!)!)!\ +""" + ucode_str = \ +"""\ +((n!)!)!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(n + 1) + ascii_str_1 = \ +"""\ +(1 + n)!\ +""" + ascii_str_2 = \ +"""\ +(n + 1)!\ +""" + ucode_str_1 = \ +"""\ +(1 + n)!\ +""" + ucode_str_2 = \ +"""\ +(n + 1)!\ +""" + + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = subfactorial(n) + ascii_str = \ +"""\ +!n\ +""" + ucode_str = \ +"""\ +!n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = subfactorial(2*n) + ascii_str = \ +"""\ +!(2*n)\ +""" + ucode_str = \ +"""\ +!(2⋅n)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + n = Symbol('n', integer=True) + expr = factorial2(n) + ascii_str = \ +"""\ +n!!\ +""" + ucode_str = \ +"""\ +n!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(2*n) + ascii_str = \ +"""\ +(2*n)!!\ +""" + ucode_str = \ +"""\ +(2⋅n)!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(factorial2(factorial2(n))) + ascii_str = \ +"""\ +((n!!)!!)!!\ +""" + ucode_str = \ +"""\ +((n!!)!!)!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(n + 1) + ascii_str_1 = \ +"""\ +(1 + n)!!\ +""" + ascii_str_2 = \ +"""\ +(n + 1)!!\ +""" + ucode_str_1 = \ +"""\ +(1 + n)!!\ +""" + ucode_str_2 = \ +"""\ +(n + 1)!!\ +""" + + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = 2*binomial(n, k) + ascii_str = \ +"""\ + /n\\\n\ +2*| |\n\ + \\k/\ +""" + ucode_str = \ +"""\ + ⎛n⎞\n\ +2⋅⎜ ⎟\n\ + ⎝k⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*binomial(2*n, k) + ascii_str = \ +"""\ + /2*n\\\n\ +2*| |\n\ + \\ k /\ +""" + ucode_str = \ +"""\ + ⎛2⋅n⎞\n\ +2⋅⎜ ⎟\n\ + ⎝ k ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*binomial(n**2, k) + ascii_str = \ +"""\ + / 2\\\n\ + |n |\n\ +2*| |\n\ + \\k /\ +""" + ucode_str = \ +"""\ + ⎛ 2⎞\n\ + ⎜n ⎟\n\ +2⋅⎜ ⎟\n\ + ⎝k ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = catalan(n) + ascii_str = \ +"""\ +C \n\ + n\ +""" + ucode_str = \ +"""\ +C \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = catalan(n) + ascii_str = \ +"""\ +C \n\ + n\ +""" + ucode_str = \ +"""\ +C \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bell(n) + ascii_str = \ +"""\ +B \n\ + n\ +""" + ucode_str = \ +"""\ +B \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bernoulli(n) + ascii_str = \ +"""\ +B \n\ + n\ +""" + ucode_str = \ +"""\ +B \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bernoulli(n, x) + ascii_str = \ +"""\ +B (x)\n\ + n \ +""" + ucode_str = \ +"""\ +B (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = fibonacci(n) + ascii_str = \ +"""\ +F \n\ + n\ +""" + ucode_str = \ +"""\ +F \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = lucas(n) + ascii_str = \ +"""\ +L \n\ + n\ +""" + ucode_str = \ +"""\ +L \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = tribonacci(n) + ascii_str = \ +"""\ +T \n\ + n\ +""" + ucode_str = \ +"""\ +T \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = stieltjes(n) + ascii_str = \ +"""\ +stieltjes \n\ + n\ +""" + ucode_str = \ +"""\ +γ \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = stieltjes(n, x) + ascii_str = \ +"""\ +stieltjes (x)\n\ + n \ +""" + ucode_str = \ +"""\ +γ (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieuc(x, y, z) + ascii_str = 'C(x, y, z)' + ucode_str = 'C(x, y, z)' + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieus(x, y, z) + ascii_str = 'S(x, y, z)' + ucode_str = 'S(x, y, z)' + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieucprime(x, y, z) + ascii_str = "C'(x, y, z)" + ucode_str = "C'(x, y, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieusprime(x, y, z) + ascii_str = "S'(x, y, z)" + ucode_str = "S'(x, y, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(x) + ascii_str = \ +"""\ +_\n\ +x\ +""" + ucode_str = \ +"""\ +_\n\ +x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + f = Function('f') + expr = conjugate(f(x + 1)) + ascii_str_1 = \ +"""\ +________\n\ +f(1 + x)\ +""" + ascii_str_2 = \ +"""\ +________\n\ +f(x + 1)\ +""" + ucode_str_1 = \ +"""\ +________\n\ +f(1 + x)\ +""" + ucode_str_2 = \ +"""\ +________\n\ +f(x + 1)\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x) + ascii_str = \ +"""\ +f(x)\ +""" + ucode_str = \ +"""\ +f(x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = f(x, y) + ascii_str = \ +"""\ +f(x, y)\ +""" + ucode_str = \ +"""\ +f(x, y)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = f(x/(y + 1), y) + ascii_str_1 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\1 + y /\ +""" + ascii_str_2 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\y + 1 /\ +""" + ucode_str_1 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝1 + y ⎠\ +""" + ucode_str_2 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝y + 1 ⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x**x**x**x**x**x) + ascii_str = \ +"""\ + / / / / / x\\\\\\\\\\ + | | | | \\x /|||| + | | | \\x /||| + | | \\x /|| + | \\x /| +f\\x /\ +""" + ucode_str = \ +"""\ + ⎛ ⎛ ⎛ ⎛ ⎛ x⎞⎞⎞⎞⎞ + ⎜ ⎜ ⎜ ⎜ ⎝x ⎠⎟⎟⎟⎟ + ⎜ ⎜ ⎜ ⎝x ⎠⎟⎟⎟ + ⎜ ⎜ ⎝x ⎠⎟⎟ + ⎜ ⎝x ⎠⎟ +f⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sin(x)**2 + ascii_str = \ +"""\ + 2 \n\ +sin (x)\ +""" + ucode_str = \ +"""\ + 2 \n\ +sin (x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(a + b*I) + ascii_str = \ +"""\ +_ _\n\ +a - I*b\ +""" + ucode_str = \ +"""\ +_ _\n\ +a - ⅈ⋅b\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(exp(a + b*I)) + ascii_str = \ +"""\ + _ _\n\ + a - I*b\n\ +e \ +""" + ucode_str = \ +"""\ + _ _\n\ + a - ⅈ⋅b\n\ +ℯ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate( f(1 + conjugate(f(x))) ) + ascii_str_1 = \ +"""\ +___________\n\ + / ____\\\n\ +f\\1 + f(x)/\ +""" + ascii_str_2 = \ +"""\ +___________\n\ + /____ \\\n\ +f\\f(x) + 1/\ +""" + ucode_str_1 = \ +"""\ +___________\n\ + ⎛ ____⎞\n\ +f⎝1 + f(x)⎠\ +""" + ucode_str_2 = \ +"""\ +___________\n\ + ⎛____ ⎞\n\ +f⎝f(x) + 1⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x/(y + 1), y) + ascii_str_1 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\1 + y /\ +""" + ascii_str_2 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\y + 1 /\ +""" + ucode_str_1 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝1 + y ⎠\ +""" + ucode_str_2 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝y + 1 ⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = floor(1 / (y - floor(x))) + ascii_str = \ +"""\ + / 1 \\\n\ +floor|------------|\n\ + \\y - floor(x)/\ +""" + ucode_str = \ +"""\ +⎢ 1 ⎥\n\ +⎢───────⎥\n\ +⎣y - ⌊x⌋⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = ceiling(1 / (y - ceiling(x))) + ascii_str = \ +"""\ + / 1 \\\n\ +ceiling|--------------|\n\ + \\y - ceiling(x)/\ +""" + ucode_str = \ +"""\ +⎡ 1 ⎤\n\ +⎢───────⎥\n\ +⎢y - ⌈x⌉⎥\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n) + ascii_str = \ +"""\ +E \n\ + n\ +""" + ucode_str = \ +"""\ +E \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(1/(1 + 1/(1 + 1/n))) + ascii_str = \ +"""\ +E \n\ + 1 \n\ + ---------\n\ + 1 \n\ + 1 + -----\n\ + 1\n\ + 1 + -\n\ + n\ +""" + + ucode_str = \ +"""\ +E \n\ + 1 \n\ + ─────────\n\ + 1 \n\ + 1 + ─────\n\ + 1\n\ + 1 + ─\n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n, x) + ascii_str = \ +"""\ +E (x)\n\ + n \ +""" + ucode_str = \ +"""\ +E (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n, x/2) + ascii_str = \ +"""\ + /x\\\n\ +E |-|\n\ + n\\2/\ +""" + ucode_str = \ +"""\ + ⎛x⎞\n\ +E ⎜─⎟\n\ + n⎝2⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_sqrt(): + expr = sqrt(2) + ascii_str = \ +"""\ + ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"√2" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**Rational(1, 3) + ascii_str = \ +"""\ +3 ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"""\ +3 ___\n\ +╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**Rational(1, 1000) + ascii_str = \ +"""\ +1000___\n\ + \\/ 2 \ +""" + ucode_str = \ +"""\ +1000___\n\ + ╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sqrt(x**2 + 1) + ascii_str = \ +"""\ + ________\n\ + / 2 \n\ +\\/ x + 1 \ +""" + ucode_str = \ +"""\ + ________\n\ + ╱ 2 \n\ +╲╱ x + 1 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (1 + sqrt(5))**Rational(1, 3) + ascii_str = \ +"""\ + ___________\n\ +3 / ___ \n\ +\\/ 1 + \\/ 5 \ +""" + ucode_str = \ +"""\ +3 ________\n\ +╲╱ 1 + √5 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**(1/x) + ascii_str = \ +"""\ +x ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"""\ +x ___\n\ +╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sqrt(2 + pi) + ascii_str = \ +"""\ + ________\n\ +\\/ 2 + pi \ +""" + ucode_str = \ +"""\ + _______\n\ +╲╱ 2 + π \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (2 + ( + 1 + x**2)/(2 + x))**Rational(1, 4) + (1 + x**Rational(1, 1000))/sqrt(3 + x**2) + ascii_str = \ +"""\ + ____________ \n\ + / 2 1000___ \n\ + / x + 1 \\/ x + 1\n\ +4 / 2 + ------ + -----------\n\ +\\/ x + 2 ________\n\ + / 2 \n\ + \\/ x + 3 \ +""" + ucode_str = \ +"""\ + ____________ \n\ + ╱ 2 1000___ \n\ + ╱ x + 1 ╲╱ x + 1\n\ +4 ╱ 2 + ────── + ───────────\n\ +╲╱ x + 2 ________\n\ + ╱ 2 \n\ + ╲╱ x + 3 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_sqrt_char_knob(): + # See PR #9234. + expr = sqrt(2) + ucode_str1 = \ +"""\ + ___\n\ +╲╱ 2 \ +""" + ucode_str2 = \ +"√2" + assert xpretty(expr, use_unicode=True, + use_unicode_sqrt_char=False) == ucode_str1 + assert xpretty(expr, use_unicode=True, + use_unicode_sqrt_char=True) == ucode_str2 + + +def test_pretty_sqrt_longsymbol_no_sqrt_char(): + # Do not use unicode sqrt char for long symbols (see PR #9234). + expr = sqrt(Symbol('C1')) + ucode_str = \ +"""\ + ____\n\ +╲╱ C₁ \ +""" + assert upretty(expr) == ucode_str + + +def test_pretty_KroneckerDelta(): + x, y = symbols("x, y") + expr = KroneckerDelta(x, y) + ascii_str = \ +"""\ +d \n\ + x,y\ +""" + ucode_str = \ +"""\ +δ \n\ + x,y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_product(): + n, m, k, l = symbols('n m k l') + f = symbols('f', cls=Function) + expr = Product(f((n/3)**2), (n, k**2, l)) + + unicode_str = \ +"""\ + l \n\ +─┬──────┬─ \n\ + │ │ ⎛ 2⎞\n\ + │ │ ⎜n ⎟\n\ + │ │ f⎜──⎟\n\ + │ │ ⎝9 ⎠\n\ + │ │ \n\ + 2 \n\ + n = k """ + ascii_str = \ +"""\ + l \n\ +__________ \n\ + | | / 2\\\n\ + | | |n |\n\ + | | f|--|\n\ + | | \\9 /\n\ + | | \n\ + 2 \n\ + n = k """ + + expr = Product(f((n/3)**2), (n, k**2, l), (l, 1, m)) + + unicode_str = \ +"""\ + m l \n\ +─┬──────┬─ ─┬──────┬─ \n\ + │ │ │ │ ⎛ 2⎞\n\ + │ │ │ │ ⎜n ⎟\n\ + │ │ │ │ f⎜──⎟\n\ + │ │ │ │ ⎝9 ⎠\n\ + │ │ │ │ \n\ + l = 1 2 \n\ + n = k """ + ascii_str = \ +"""\ + m l \n\ +__________ __________ \n\ + | | | | / 2\\\n\ + | | | | |n |\n\ + | | | | f|--|\n\ + | | | | \\9 /\n\ + | | | | \n\ + l = 1 2 \n\ + n = k """ + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + +def test_pretty_Lambda(): + # S.IdentityFunction is a special case + expr = Lambda(y, y) + assert pretty(expr) == "x -> x" + assert upretty(expr) == "x ↦ x" + + expr = Lambda(x, x+1) + assert pretty(expr) == "x -> x + 1" + assert upretty(expr) == "x ↦ x + 1" + + expr = Lambda(x, x**2) + ascii_str = \ +"""\ + 2\n\ +x -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +x ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda(x, x**2)**2 + ascii_str = \ +"""\ + 2 +/ 2\\ \n\ +\\x -> x / \ +""" + ucode_str = \ +"""\ + 2 +⎛ 2⎞ \n\ +⎝x ↦ x ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda((x, y), x) + ascii_str = "(x, y) -> x" + ucode_str = "(x, y) ↦ x" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda((x, y), x**2) + ascii_str = \ +"""\ + 2\n\ +(x, y) -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +(x, y) ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda(((x, y),), x**2) + ascii_str = \ +"""\ + 2\n\ +((x, y),) -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +((x, y),) ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_TransferFunction(): + tf1 = TransferFunction(s - 1, s + 1, s) + assert upretty(tf1) == "s - 1\n─────\ns + 1" + tf2 = TransferFunction(2*s + 1, 3 - p, s) + assert upretty(tf2) == "2⋅s + 1\n───────\n 3 - p " + tf3 = TransferFunction(p, p + 1, p) + assert upretty(tf3) == " p \n─────\np + 1" + + +def test_pretty_Series(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(x**2 + y, y - x, y) + tf4 = TransferFunction(2, 3, y) + + tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm2 = TransferFunctionMatrix([[tf3], [-tf4]]) + tfm3 = TransferFunctionMatrix([[tf1, -tf2, -tf3], [tf3, -tf4, tf2]]) + tfm4 = TransferFunctionMatrix([[tf1, tf2], [tf3, -tf4], [-tf2, -tf1]]) + tfm5 = TransferFunctionMatrix([[-tf2, -tf1], [tf4, -tf3], [tf1, tf2]]) + + expected1 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y ⎞ ⎜x + y⎟\n\ +⎜───────⎟⋅⎜──────⎟\n\ +⎝x - 2⋅y⎠ ⎝-x + y⎠\ +""" + expected2 = \ +"""\ +⎛-x + y⎞ ⎛-x - y ⎞\n\ +⎜──────⎟⋅⎜───────⎟\n\ +⎝x + y ⎠ ⎝x - 2⋅y⎠\ +""" + expected3 = \ +"""\ +⎛ 2 ⎞ \n\ +⎜x + y⎟ ⎛ x + y ⎞ ⎛-x - y x - y⎞\n\ +⎜──────⎟⋅⎜───────⎟⋅⎜─────── + ─────⎟\n\ +⎝-x + y⎠ ⎝x - 2⋅y⎠ ⎝x - 2⋅y x + y⎠\ +""" + expected4 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y x - y⎞ ⎜x - y x + y⎟\n\ +⎜─────── + ─────⎟⋅⎜───── + ──────⎟\n\ +⎝x - 2⋅y x + y⎠ ⎝x + y -x + y⎠\ +""" + expected5 = \ +"""\ +⎡ x + y x - y⎤ ⎡ 2 ⎤ \n\ +⎢─────── ─────⎥ ⎢x + y⎥ \n\ +⎢x - 2⋅y x + y⎥ ⎢──────⎥ \n\ +⎢ ⎥ ⎢-x + y⎥ \n\ +⎢ 2 ⎥ ⋅⎢ ⎥ \n\ +⎢x + y 2 ⎥ ⎢ -2 ⎥ \n\ +⎢────── ─ ⎥ ⎢ ─── ⎥ \n\ +⎣-x + y 3 ⎦τ ⎣ 3 ⎦τ\ +""" + expected6 = \ +"""\ + ⎛⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎞\n\ + ⎜⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎟\n\ +⎡ x + y x - y⎤ ⎡ 2 ⎤ ⎜⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎟\n\ +⎢─────── ─────⎥ ⎢ x + y -x + y - x - y⎥ ⎜⎢ ⎥ ⎢ ⎥ ⎟\n\ +⎢x - 2⋅y x + y⎥ ⎢─────── ────── ────────⎥ ⎜⎢ 2 ⎥ ⎢ 2 ⎥ ⎟\n\ +⎢ ⎥ ⎢x - 2⋅y x + y -x + y ⎥ ⎜⎢x + y -2 ⎥ ⎢ -2 x + y ⎥ ⎟\n\ +⎢ 2 ⎥ ⋅⎢ ⎥ ⋅⎜⎢────── ─── ⎥ + ⎢ ─── ────── ⎥ ⎟\n\ +⎢x + y 2 ⎥ ⎢ 2 ⎥ ⎜⎢-x + y 3 ⎥ ⎢ 3 -x + y ⎥ ⎟\n\ +⎢────── ─ ⎥ ⎢x + y -2 x - y ⎥ ⎜⎢ ⎥ ⎢ ⎥ ⎟\n\ +⎣-x + y 3 ⎦τ ⎢────── ─── ───── ⎥ ⎜⎢-x + y -x - y ⎥ ⎢-x - y -x + y ⎥ ⎟\n\ + ⎣-x + y 3 x + y ⎦τ ⎜⎢────── ───────⎥ ⎢─────── ────── ⎥ ⎟\n\ + ⎝⎣x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ⎠\ +""" + + assert upretty(Series(tf1, tf3)) == expected1 + assert upretty(Series(-tf2, -tf1)) == expected2 + assert upretty(Series(tf3, tf1, Parallel(-tf1, tf2))) == expected3 + assert upretty(Series(Parallel(tf1, tf2), Parallel(tf2, tf3))) == expected4 + assert upretty(MIMOSeries(tfm2, tfm1)) == expected5 + assert upretty(MIMOSeries(MIMOParallel(tfm4, -tfm5), tfm3, tfm1)) == expected6 + + +def test_pretty_Parallel(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(x**2 + y, y - x, y) + tf4 = TransferFunction(y**2 - x, x**3 + x, y) + + tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf3, -tf4], [-tf2, -tf1]]) + tfm2 = TransferFunctionMatrix([[-tf2, -tf1], [tf4, -tf3], [tf1, tf2]]) + tfm3 = TransferFunctionMatrix([[-tf1, tf2], [-tf3, tf4], [tf2, tf1]]) + tfm4 = TransferFunctionMatrix([[-tf1, -tf2], [-tf3, -tf4]]) + + expected1 = \ +"""\ + x + y x - y\n\ +─────── + ─────\n\ +x - 2⋅y x + y\ +""" + expected2 = \ +"""\ +-x + y -x - y \n\ +────── + ─────── +x + y x - 2⋅y\ +""" + expected3 = \ +"""\ + 2 \n\ +x + y x + y ⎛-x - y ⎞ ⎛x - y⎞ +────── + ─────── + ⎜───────⎟⋅⎜─────⎟ +-x + y x - 2⋅y ⎝x - 2⋅y⎠ ⎝x + y⎠\ +""" + + expected4 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y ⎞ ⎛x - y⎞ ⎛x - y⎞ ⎜x + y⎟\n\ +⎜───────⎟⋅⎜─────⎟ + ⎜─────⎟⋅⎜──────⎟\n\ +⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x + y⎠ ⎝-x + y⎠\ +""" + expected5 = \ +"""\ +⎡ x + y -x + y ⎤ ⎡ x - y x + y ⎤ ⎡ x + y x - y ⎤ \n\ +⎢─────── ────── ⎥ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ \n\ +⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎢x - 2⋅y x + y ⎥ \n\ +⎢ ⎥ ⎢ ⎥ ⎢ ⎥ \n\ +⎢ 2 2 ⎥ ⎢ 2 2 ⎥ ⎢ 2 2 ⎥ \n\ +⎢x + y x - y ⎥ ⎢x - y x + y ⎥ ⎢x + y x - y ⎥ \n\ +⎢────── ────── ⎥ + ⎢────── ────── ⎥ + ⎢────── ────── ⎥ \n\ +⎢-x + y 3 ⎥ ⎢ 3 -x + y ⎥ ⎢-x + y 3 ⎥ \n\ +⎢ x + x ⎥ ⎢x + x ⎥ ⎢ x + x ⎥ \n\ +⎢ ⎥ ⎢ ⎥ ⎢ ⎥ \n\ +⎢-x + y -x - y ⎥ ⎢-x - y -x + y ⎥ ⎢-x + y -x - y ⎥ \n\ +⎢────── ───────⎥ ⎢─────── ────── ⎥ ⎢────── ───────⎥ \n\ +⎣x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ ⎣x + y x - 2⋅y⎦τ\ +""" + expected6 = \ +"""\ +⎡ x - y x + y ⎤ ⎡-x + y -x - y ⎤ \n\ +⎢ ───── ───────⎥ ⎢────── ─────── ⎥ \n\ +⎢ x + y x - 2⋅y⎥ ⎡-x - y -x + y⎤ ⎢x + y x - 2⋅y ⎥ \n\ +⎢ ⎥ ⎢─────── ──────⎥ ⎢ ⎥ \n\ +⎢ 2 2 ⎥ ⎢x - 2⋅y x + y ⎥ ⎢ 2 2 ⎥ \n\ +⎢x - y x + y ⎥ ⎢ ⎥ ⎢-x + y - x - y⎥ \n\ +⎢────── ────── ⎥ ⋅⎢ 2 2⎥ + ⎢─────── ────────⎥ \n\ +⎢ 3 -x + y ⎥ ⎢- x - y x - y ⎥ ⎢ 3 -x + y ⎥ \n\ +⎢x + x ⎥ ⎢──────── ──────⎥ ⎢x + x ⎥ \n\ +⎢ ⎥ ⎢ -x + y 3 ⎥ ⎢ ⎥ \n\ +⎢-x - y -x + y ⎥ ⎣ x + x⎦τ ⎢ x + y x - y ⎥ \n\ +⎢─────── ────── ⎥ ⎢─────── ───── ⎥ \n\ +⎣x - 2⋅y x + y ⎦τ ⎣x - 2⋅y x + y ⎦τ\ +""" + assert upretty(Parallel(tf1, tf2)) == expected1 + assert upretty(Parallel(-tf2, -tf1)) == expected2 + assert upretty(Parallel(tf3, tf1, Series(-tf1, tf2))) == expected3 + assert upretty(Parallel(Series(tf1, tf2), Series(tf2, tf3))) == expected4 + assert upretty(MIMOParallel(-tfm3, -tfm2, tfm1)) == expected5 + assert upretty(MIMOParallel(MIMOSeries(tfm4, -tfm2), tfm2)) == expected6 + + +def test_pretty_Feedback(): + tf = TransferFunction(1, 1, y) + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(y**2 - 2*y + 1, y + 5, y) + tf4 = TransferFunction(x - 2*y**3, x + y, x) + tf5 = TransferFunction(1 - x, x - y, y) + tf6 = TransferFunction(2, 2, x) + expected1 = \ +"""\ + ⎛1⎞ \n\ + ⎜─⎟ \n\ + ⎝1⎠ \n\ +─────────────\n\ +1 ⎛ x + y ⎞\n\ +─ + ⎜───────⎟\n\ +1 ⎝x - 2⋅y⎠\ +""" + expected2 = \ +"""\ + ⎛1⎞ \n\ + ⎜─⎟ \n\ + ⎝1⎠ \n\ +────────────────────────────────────\n\ + ⎛ 2 ⎞\n\ +1 ⎛x - y⎞ ⎛ x + y ⎞ ⎜y - 2⋅y + 1⎟\n\ +─ + ⎜─────⎟⋅⎜───────⎟⋅⎜────────────⎟\n\ +1 ⎝x + y⎠ ⎝x - 2⋅y⎠ ⎝ y + 5 ⎠\ +""" + expected3 = \ +"""\ + ⎛ x + y ⎞ \n\ + ⎜───────⎟ \n\ + ⎝x - 2⋅y⎠ \n\ +────────────────────────────────────────────\n\ + ⎛ 2 ⎞ \n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟⋅⎜────────────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝ y + 5 ⎠ ⎝x - y⎠\ +""" + expected4 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠\ +""" + expected5 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎛1 - x⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x - y⎠\ +""" + expected6 = \ +"""\ + ⎛ 2 ⎞ \n\ + ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞ \n\ + ⎜────────────⎟⋅⎜─────⎟ \n\ + ⎝ y + 5 ⎠ ⎝x - y⎠ \n\ +────────────────────────────────────────────\n\ + ⎛ 2 ⎞ \n\ +1 ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞ ⎛x - y⎞ ⎛ x + y ⎞\n\ +─ + ⎜────────────⎟⋅⎜─────⎟⋅⎜─────⎟⋅⎜───────⎟\n\ +1 ⎝ y + 5 ⎠ ⎝x - y⎠ ⎝x + y⎠ ⎝x - 2⋅y⎠\ +""" + expected7 = \ +"""\ + ⎛ 3⎞ \n\ + ⎜x - 2⋅y ⎟ \n\ + ⎜────────⎟ \n\ + ⎝ x + y ⎠ \n\ +──────────────────\n\ + ⎛ 3⎞ \n\ +1 ⎜x - 2⋅y ⎟ ⎛2⎞\n\ +─ + ⎜────────⎟⋅⎜─⎟\n\ +1 ⎝ x + y ⎠ ⎝2⎠\ +""" + expected8 = \ +"""\ + ⎛1 - x⎞ \n\ + ⎜─────⎟ \n\ + ⎝x - y⎠ \n\ +───────────\n\ +1 ⎛1 - x⎞\n\ +─ + ⎜─────⎟\n\ +1 ⎝x - y⎠\ +""" + expected9 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎛1 - x⎞\n\ +─ - ⎜───────⎟⋅⎜─────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x - y⎠\ +""" + expected10 = \ +"""\ + ⎛1 - x⎞ \n\ + ⎜─────⎟ \n\ + ⎝x - y⎠ \n\ +───────────\n\ +1 ⎛1 - x⎞\n\ +─ - ⎜─────⎟\n\ +1 ⎝x - y⎠\ +""" + assert upretty(Feedback(tf, tf1)) == expected1 + assert upretty(Feedback(tf, tf2*tf1*tf3)) == expected2 + assert upretty(Feedback(tf1, tf2*tf3*tf5)) == expected3 + assert upretty(Feedback(tf1*tf2, tf)) == expected4 + assert upretty(Feedback(tf1*tf2, tf5)) == expected5 + assert upretty(Feedback(tf3*tf5, tf2*tf1)) == expected6 + assert upretty(Feedback(tf4, tf6)) == expected7 + assert upretty(Feedback(tf5, tf)) == expected8 + + assert upretty(Feedback(tf1*tf2, tf5, 1)) == expected9 + assert upretty(Feedback(tf5, tf, 1)) == expected10 + + +def test_pretty_MIMOFeedback(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_3 = TransferFunctionMatrix([[tf1, tf1], [tf2, tf2]]) + + expected1 = \ +"""\ +⎛ ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎞-1 ⎡ x + y x - y ⎤ \n\ +⎜ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎟ ⎢─────── ───── ⎥ \n\ +⎜ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎟ ⎢x - 2⋅y x + y ⎥ \n\ +⎜I - ⎢ ⎥ ⋅⎢ ⎥ ⎟ ⋅ ⎢ ⎥ \n\ +⎜ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ ⎟ ⎢ x - y x + y ⎥ \n\ +⎜ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ ⎟ ⎢ ───── ───────⎥ \n\ +⎝ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ⎠ ⎣ x + y x - 2⋅y⎦τ\ +""" + expected2 = \ +"""\ +⎛ ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎡ x + y x + y ⎤ ⎞-1 ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ \n\ +⎜ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎢─────── ───────⎥ ⎟ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ \n\ +⎜ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎢x - 2⋅y x - 2⋅y⎥ ⎟ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ \n\ +⎜I + ⎢ ⎥ ⋅⎢ ⎥ ⋅⎢ ⎥ ⎟ ⋅ ⎢ ⎥ ⋅⎢ ⎥ \n\ +⎜ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ ⎢ x - y x - y ⎥ ⎟ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ \n\ +⎜ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ ⎢ ───── ───── ⎥ ⎟ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ \n\ +⎝ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ ⎣ x + y x + y ⎦τ⎠ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ\ +""" + + assert upretty(MIMOFeedback(tfm_1, tfm_2, 1)) == \ + expected1 # Positive MIMOFeedback + assert upretty(MIMOFeedback(tfm_1*tfm_2, tfm_3)) == \ + expected2 # Negative MIMOFeedback (Default) + + +def test_pretty_TransferFunctionMatrix(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(y**2 - 2*y + 1, y + 5, y) + tf4 = TransferFunction(y, x**2 + x + 1, y) + tf5 = TransferFunction(1 - x, x - y, y) + tf6 = TransferFunction(2, 2, y) + expected1 = \ +"""\ +⎡ x + y ⎤ \n\ +⎢───────⎥ \n\ +⎢x - 2⋅y⎥ \n\ +⎢ ⎥ \n\ +⎢ x - y ⎥ \n\ +⎢ ───── ⎥ \n\ +⎣ x + y ⎦τ\ +""" + expected2 = \ +"""\ +⎡ x + y ⎤ \n\ +⎢ ─────── ⎥ \n\ +⎢ x - 2⋅y ⎥ \n\ +⎢ ⎥ \n\ +⎢ x - y ⎥ \n\ +⎢ ───── ⎥ \n\ +⎢ x + y ⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢- y + 2⋅y - 1⎥ \n\ +⎢──────────────⎥ \n\ +⎣ y + 5 ⎦τ\ +""" + expected3 = \ +"""\ +⎡ x + y x - y ⎤ \n\ +⎢ ─────── ───── ⎥ \n\ +⎢ x - 2⋅y x + y ⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢y - 2⋅y + 1 y ⎥ \n\ +⎢──────────── ──────────⎥ \n\ +⎢ y + 5 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 1 - x 2 ⎥ \n\ +⎢ ───── ─ ⎥ \n\ +⎣ x - y 2 ⎦τ\ +""" + expected4 = \ +"""\ +⎡ x - y x + y y ⎤ \n\ +⎢ ───── ─────── ──────────⎥ \n\ +⎢ x + y x - 2⋅y 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢- y + 2⋅y - 1 x - 1 -2 ⎥ \n\ +⎢────────────── ───── ─── ⎥ \n\ +⎣ y + 5 x - y 2 ⎦τ\ +""" + expected5 = \ +"""\ +⎡ x + y x - y x + y y ⎤ \n\ +⎢───────⋅───── ─────── ──────────⎥ \n\ +⎢x - 2⋅y x + y x - 2⋅y 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 1 - x 2 x + y -2 ⎥ \n\ +⎢ ───── + ─ ─────── ─── ⎥ \n\ +⎣ x - y 2 x - 2⋅y 2 ⎦τ\ +""" + + assert upretty(TransferFunctionMatrix([[tf1], [tf2]])) == expected1 + assert upretty(TransferFunctionMatrix([[tf1], [tf2], [-tf3]])) == expected2 + assert upretty(TransferFunctionMatrix([[tf1, tf2], [tf3, tf4], [tf5, tf6]])) == expected3 + assert upretty(TransferFunctionMatrix([[tf2, tf1, tf4], [-tf3, -tf5, -tf6]])) == expected4 + assert upretty(TransferFunctionMatrix([[Series(tf2, tf1), tf1, tf4], [Parallel(tf6, tf5), tf1, -tf6]])) == \ + expected5 + + +def test_pretty_StateSpace(): + ss1 = StateSpace(Matrix([a]), Matrix([b]), Matrix([c]), Matrix([d])) + A = Matrix([[0, 1], [1, 0]]) + B = Matrix([1, 0]) + C = Matrix([[0, 1]]) + D = Matrix([0]) + ss2 = StateSpace(A, B, C, D) + ss3 = StateSpace(Matrix([[-1.5, -2], [1, 0]]), + Matrix([[0.5, 0], [0, 1]]), + Matrix([[0, 1], [0, 2]]), + Matrix([[2, 2], [1, 1]])) + + expected1 = \ +"""\ +⎡[a] [b]⎤\n\ +⎢ ⎥\n\ +⎣[c] [d]⎦\ +""" + expected2 = \ +"""\ +⎡⎡0 1⎤ ⎡1⎤⎤\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎣1 0⎦ ⎣0⎦⎥\n\ +⎢ ⎥\n\ +⎣[0 1] [0]⎦\ +""" + expected3 = \ +"""\ +⎡⎡-1.5 -2⎤ ⎡0.5 0⎤⎤\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎣ 1 0 ⎦ ⎣ 0 1⎦⎥\n\ +⎢ ⎥\n\ +⎢ ⎡0 1⎤ ⎡2 2⎤ ⎥\n\ +⎢ ⎢ ⎥ ⎢ ⎥ ⎥\n\ +⎣ ⎣0 2⎦ ⎣1 1⎦ ⎦\ +""" + + assert upretty(ss1) == expected1 + assert upretty(ss2) == expected2 + assert upretty(ss3) == expected3 + +def test_pretty_order(): + expr = O(1) + ascii_str = \ +"""\ +O(1)\ +""" + ucode_str = \ +"""\ +O(1)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1/x) + ascii_str = \ +"""\ + /1\\\n\ +O|-|\n\ + \\x/\ +""" + ucode_str = \ +"""\ + ⎛1⎞\n\ +O⎜─⎟\n\ + ⎝x⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(x**2 + y**2) + ascii_str = \ +"""\ + / 2 2 \\\n\ +O\\x + y ; (x, y) -> (0, 0)/\ +""" + ucode_str = \ +"""\ + ⎛ 2 2 ⎞\n\ +O⎝x + y ; (x, y) → (0, 0)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1, (x, oo)) + ascii_str = \ +"""\ +O(1; x -> oo)\ +""" + ucode_str = \ +"""\ +O(1; x → ∞)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1/x, (x, oo)) + ascii_str = \ +"""\ + /1 \\\n\ +O|-; x -> oo|\n\ + \\x /\ +""" + ucode_str = \ +"""\ + ⎛1 ⎞\n\ +O⎜─; x → ∞⎟\n\ + ⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(x**2 + y**2, (x, oo), (y, oo)) + ascii_str = \ +"""\ + / 2 2 \\\n\ +O\\x + y ; (x, y) -> (oo, oo)/\ +""" + ucode_str = \ +"""\ + ⎛ 2 2 ⎞\n\ +O⎝x + y ; (x, y) → (∞, ∞)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_derivatives(): + # Simple + expr = Derivative(log(x), x, evaluate=False) + ascii_str = \ +"""\ +d \n\ +--(log(x))\n\ +dx \ +""" + ucode_str = \ +"""\ +d \n\ +──(log(x))\n\ +dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(log(x), x, evaluate=False) + x + ascii_str_1 = \ +"""\ + d \n\ +x + --(log(x))\n\ + dx \ +""" + ascii_str_2 = \ +"""\ +d \n\ +--(log(x)) + x\n\ +dx \ +""" + ucode_str_1 = \ +"""\ + d \n\ +x + ──(log(x))\n\ + dx \ +""" + ucode_str_2 = \ +"""\ +d \n\ +──(log(x)) + x\n\ +dx \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + # basic partial derivatives + expr = Derivative(log(x + y) + x, x) + ascii_str_1 = \ +"""\ +d \n\ +--(log(x + y) + x)\n\ +dx \ +""" + ascii_str_2 = \ +"""\ +d \n\ +--(x + log(x + y))\n\ +dx \ +""" + ucode_str_1 = \ +"""\ +∂ \n\ +──(log(x + y) + x)\n\ +∂x \ +""" + ucode_str_2 = \ +"""\ +∂ \n\ +──(x + log(x + y))\n\ +∂x \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2], upretty(expr) + + # Multiple symbols + expr = Derivative(log(x) + x**2, x, y) + ascii_str_1 = \ +"""\ + 2 \n\ + d / 2\\\n\ +-----\\log(x) + x /\n\ +dy dx \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ + d / 2 \\\n\ +-----\\x + log(x)/\n\ +dy dx \ +""" + ascii_str_3 = \ +"""\ + 2 \n\ + d / 2 \\\n\ +-----\\x + log(x)/\n\ +dy dx \ +""" + ucode_str_1 = \ +"""\ + 2 \n\ + d ⎛ 2⎞\n\ +─────⎝log(x) + x ⎠\n\ +dy dx \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ + d ⎛ 2 ⎞\n\ +─────⎝x + log(x)⎠\n\ +dy dx \ +""" + ucode_str_3 = \ +"""\ + 2 \n\ + d ⎛ 2 ⎞\n\ +─────⎝x + log(x)⎠\n\ +dy dx \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Derivative(2*x*y, y, x) + x**2 + ascii_str_1 = \ +"""\ + 2 \n\ + d 2\n\ +-----(2*x*y) + x \n\ +dx dy \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ + 2 d \n\ +x + -----(2*x*y)\n\ + dx dy \ +""" + ascii_str_3 = \ +"""\ + 2 \n\ + 2 d \n\ +x + -----(2*x*y)\n\ + dx dy \ +""" + ucode_str_1 = \ +"""\ + 2 \n\ + ∂ 2\n\ +─────(2⋅x⋅y) + x \n\ +∂x ∂y \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ + 2 ∂ \n\ +x + ─────(2⋅x⋅y)\n\ + ∂x ∂y \ +""" + ucode_str_3 = \ +"""\ + 2 \n\ + 2 ∂ \n\ +x + ─────(2⋅x⋅y)\n\ + ∂x ∂y \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Derivative(2*x*y, x, x) + ascii_str = \ +"""\ + 2 \n\ +d \n\ +---(2*x*y)\n\ + 2 \n\ +dx \ +""" + ucode_str = \ +"""\ + 2 \n\ +∂ \n\ +───(2⋅x⋅y)\n\ + 2 \n\ +∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(2*x*y, x, 17) + ascii_str = \ +"""\ + 17 \n\ +d \n\ +----(2*x*y)\n\ + 17 \n\ +dx \ +""" + ucode_str = \ +"""\ + 17 \n\ +∂ \n\ +────(2⋅x⋅y)\n\ + 17 \n\ +∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(2*x*y, x, x, y) + ascii_str = \ +"""\ + 3 \n\ + d \n\ +------(2*x*y)\n\ + 2 \n\ +dy dx \ +""" + ucode_str = \ +"""\ + 3 \n\ + ∂ \n\ +──────(2⋅x⋅y)\n\ + 2 \n\ +∂y ∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # Greek letters + alpha = Symbol('alpha') + beta = Function('beta') + expr = beta(alpha).diff(alpha) + ascii_str = \ +"""\ + d \n\ +------(beta(alpha))\n\ +dalpha \ +""" + ucode_str = \ +"""\ +d \n\ +──(β(α))\n\ +dα \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(f(x), (x, n)) + + ascii_str = \ +"""\ + n \n\ +d \n\ +---(f(x))\n\ + n \n\ +dx \ +""" + ucode_str = \ +"""\ + n \n\ +d \n\ +───(f(x))\n\ + n \n\ +dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_integrals(): + expr = Integral(log(x), x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | log(x) dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ log(x) dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral((sin(x))**2 / (tan(x))**2) + ascii_str = \ +"""\ + / \n\ + | \n\ + | 2 \n\ + | sin (x) \n\ + | ------- dx\n\ + | 2 \n\ + | tan (x) \n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ 2 \n\ +⎮ sin (x) \n\ +⎮ ─────── dx\n\ +⎮ 2 \n\ +⎮ tan (x) \n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**(2**x), x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | / x\\ \n\ + | \\2 / \n\ + | x dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ ⎛ x⎞ \n\ +⎮ ⎝2 ⎠ \n\ +⎮ x dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, (x, 1, 2)) + ascii_str = \ +"""\ + 2 \n\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \n\ +1 \ +""" + ucode_str = \ +"""\ +2 \n\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \n\ +1 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, (x, Rational(1, 2), 10)) + ascii_str = \ +"""\ + 10 \n\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \n\ +1/2 \ +""" + ucode_str = \ +"""\ +10 \n\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \n\ +1/2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2*y**2, x, y) + ascii_str = \ +"""\ + / / \n\ + | | \n\ + | | 2 2 \n\ + | | x *y dx dy\n\ + | | \n\ +/ / \ +""" + ucode_str = \ +"""\ +⌠ ⌠ \n\ +⎮ ⎮ 2 2 \n\ +⎮ ⎮ x ⋅y dx dy\n\ +⌡ ⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(sin(th)/cos(ph), (th, 0, pi), (ph, 0, 2*pi)) + ascii_str = \ +"""\ + 2*pi pi \n\ + / / \n\ + | | \n\ + | | sin(theta) \n\ + | | ---------- d(theta) d(phi)\n\ + | | cos(phi) \n\ + | | \n\ + / / \n\ +0 0 \ +""" + ucode_str = \ +"""\ +2⋅π π \n\ + ⌠ ⌠ \n\ + ⎮ ⎮ sin(θ) \n\ + ⎮ ⎮ ────── dθ dφ\n\ + ⎮ ⎮ cos(φ) \n\ + ⌡ ⌡ \n\ + 0 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_matrix(): + # Empty Matrix + expr = Matrix() + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix(2, 0, lambda i, j: 0) + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix(0, 2, lambda i, j: 0) + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix([[x**2 + 1, 1], [y, x + y]]) + ascii_str_1 = \ +"""\ +[ 2 ] +[1 + x 1 ] +[ ] +[ y x + y]\ +""" + ascii_str_2 = \ +"""\ +[ 2 ] +[x + 1 1 ] +[ ] +[ y x + y]\ +""" + ucode_str_1 = \ +"""\ +⎡ 2 ⎤ +⎢1 + x 1 ⎥ +⎢ ⎥ +⎣ y x + y⎦\ +""" + ucode_str_2 = \ +"""\ +⎡ 2 ⎤ +⎢x + 1 1 ⎥ +⎢ ⎥ +⎣ y x + y⎦\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = Matrix([[x/y, y, th], [0, exp(I*k*ph), 1]]) + ascii_str = \ +"""\ +[x ] +[- y theta] +[y ] +[ ] +[ I*k*phi ] +[0 e 1 ]\ +""" + ucode_str = \ +"""\ +⎡x ⎤ +⎢─ y θ⎥ +⎢y ⎥ +⎢ ⎥ +⎢ ⅈ⋅k⋅φ ⎥ +⎣0 ℯ 1⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + unicode_str = \ +"""\ +⎡v̇_msc_00 0 0 ⎤ +⎢ ⎥ +⎢ 0 v̇_msc_01 0 ⎥ +⎢ ⎥ +⎣ 0 0 v̇_msc_02⎦\ +""" + + expr = diag(*MatrixSymbol('vdot_msc',1,3)) + assert upretty(expr) == unicode_str + + +def test_pretty_ndim_arrays(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert pretty(M) == "x" + assert upretty(M) == "x" + + M = ArrayType([[1/x, y], [z, w]]) + M1 = ArrayType([1/x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + ascii_str = \ +"""\ +[1 ]\n\ +[- y]\n\ +[x ]\n\ +[ ]\n\ +[z w]\ +""" + ucode_str = \ +"""\ +⎡1 ⎤\n\ +⎢─ y⎥\n\ +⎢x ⎥\n\ +⎢ ⎥\n\ +⎣z w⎦\ +""" + assert pretty(M) == ascii_str + assert upretty(M) == ucode_str + + ascii_str = \ +"""\ +[1 ]\n\ +[- y z]\n\ +[x ]\ +""" + ucode_str = \ +"""\ +⎡1 ⎤\n\ +⎢─ y z⎥\n\ +⎣x ⎦\ +""" + assert pretty(M1) == ascii_str + assert upretty(M1) == ucode_str + + ascii_str = \ +"""\ +[[1 y] ]\n\ +[[-- -] [z ]]\n\ +[[ 2 x] [ y 2 ] [- y*z]]\n\ +[[x ] [ - y ] [x ]]\n\ +[[ ] [ x ] [ ]]\n\ +[[z w] [ ] [ 2 ]]\n\ +[[- -] [y*z w*y] [z w*z]]\n\ +[[x x] ]\ +""" + ucode_str = \ +"""\ +⎡⎡1 y⎤ ⎤\n\ +⎢⎢── ─⎥ ⎡z ⎤⎥\n\ +⎢⎢ 2 x⎥ ⎡ y 2 ⎤ ⎢─ y⋅z⎥⎥\n\ +⎢⎢x ⎥ ⎢ ─ y ⎥ ⎢x ⎥⎥\n\ +⎢⎢ ⎥ ⎢ x ⎥ ⎢ ⎥⎥\n\ +⎢⎢z w⎥ ⎢ ⎥ ⎢ 2 ⎥⎥\n\ +⎢⎢─ ─⎥ ⎣y⋅z w⋅y⎦ ⎣z w⋅z⎦⎥\n\ +⎣⎣x x⎦ ⎦\ +""" + assert pretty(M2) == ascii_str + assert upretty(M2) == ucode_str + + ascii_str = \ +"""\ +[ [1 y] ]\n\ +[ [-- -] ]\n\ +[ [ 2 x] [ y 2 ]]\n\ +[ [x ] [ - y ]]\n\ +[ [ ] [ x ]]\n\ +[ [z w] [ ]]\n\ +[ [- -] [y*z w*y]]\n\ +[ [x x] ]\n\ +[ ]\n\ +[[z ] [ w ]]\n\ +[[- y*z] [ - w*y]]\n\ +[[x ] [ x ]]\n\ +[[ ] [ ]]\n\ +[[ 2 ] [ 2 ]]\n\ +[[z w*z] [w*z w ]]\ +""" + ucode_str = \ +"""\ +⎡ ⎡1 y⎤ ⎤\n\ +⎢ ⎢── ─⎥ ⎥\n\ +⎢ ⎢ 2 x⎥ ⎡ y 2 ⎤⎥\n\ +⎢ ⎢x ⎥ ⎢ ─ y ⎥⎥\n\ +⎢ ⎢ ⎥ ⎢ x ⎥⎥\n\ +⎢ ⎢z w⎥ ⎢ ⎥⎥\n\ +⎢ ⎢─ ─⎥ ⎣y⋅z w⋅y⎦⎥\n\ +⎢ ⎣x x⎦ ⎥\n\ +⎢ ⎥\n\ +⎢⎡z ⎤ ⎡ w ⎤⎥\n\ +⎢⎢─ y⋅z⎥ ⎢ ─ w⋅y⎥⎥\n\ +⎢⎢x ⎥ ⎢ x ⎥⎥\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎢ 2 ⎥ ⎢ 2 ⎥⎥\n\ +⎣⎣z w⋅z⎦ ⎣w⋅z w ⎦⎦\ +""" + assert pretty(M3) == ascii_str + assert upretty(M3) == ucode_str + + Mrow = ArrayType([[x, y, 1 / z]]) + Mcolumn = ArrayType([[x], [y], [1 / z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + ascii_str = \ +"""\ +[[ 1]]\n\ +[[x y -]]\n\ +[[ z]]\ +""" + ucode_str = \ +"""\ +⎡⎡ 1⎤⎤\n\ +⎢⎢x y ─⎥⎥\n\ +⎣⎣ z⎦⎦\ +""" + assert pretty(Mrow) == ascii_str + assert upretty(Mrow) == ucode_str + + ascii_str = \ +"""\ +[x]\n\ +[ ]\n\ +[y]\n\ +[ ]\n\ +[1]\n\ +[-]\n\ +[z]\ +""" + ucode_str = \ +"""\ +⎡x⎤\n\ +⎢ ⎥\n\ +⎢y⎥\n\ +⎢ ⎥\n\ +⎢1⎥\n\ +⎢─⎥\n\ +⎣z⎦\ +""" + assert pretty(Mcolumn) == ascii_str + assert upretty(Mcolumn) == ucode_str + + ascii_str = \ +"""\ +[[x]]\n\ +[[ ]]\n\ +[[y]]\n\ +[[ ]]\n\ +[[1]]\n\ +[[-]]\n\ +[[z]]\ +""" + ucode_str = \ +"""\ +⎡⎡x⎤⎤\n\ +⎢⎢ ⎥⎥\n\ +⎢⎢y⎥⎥\n\ +⎢⎢ ⎥⎥\n\ +⎢⎢1⎥⎥\n\ +⎢⎢─⎥⎥\n\ +⎣⎣z⎦⎦\ +""" + assert pretty(Mcol2) == ascii_str + assert upretty(Mcol2) == ucode_str + + +def test_tensor_TensorProduct(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert upretty(TensorProduct(A, B)) == "A\u2297B" + assert upretty(TensorProduct(A, B, A)) == "A\u2297B\u2297A" + + +def test_diffgeom_print_WedgeProduct(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert upretty(wp) == "ⅆ x∧ⅆ y" + assert pretty(wp) == r"d x/\d y" + + +def test_Adjoint(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert pretty(Adjoint(X)) == " +\nX " + assert pretty(Adjoint(X + Y)) == " +\n(X + Y) " + assert pretty(Adjoint(X) + Adjoint(Y)) == " + +\nX + Y " + assert pretty(Adjoint(X*Y)) == " +\n(X*Y) " + assert pretty(Adjoint(Y)*Adjoint(X)) == " + +\nY *X " + assert pretty(Adjoint(X**2)) == " +\n/ 2\\ \n\\X / " + assert pretty(Adjoint(X)**2) == " 2\n/ +\\ \n\\X / " + assert pretty(Adjoint(Inverse(X))) == " +\n/ -1\\ \n\\X / " + assert pretty(Inverse(Adjoint(X))) == " -1\n/ +\\ \n\\X / " + assert pretty(Adjoint(Transpose(X))) == " +\n/ T\\ \n\\X / " + assert pretty(Transpose(Adjoint(X))) == " T\n/ +\\ \n\\X / " + assert upretty(Adjoint(X)) == " †\nX " + assert upretty(Adjoint(X + Y)) == " †\n(X + Y) " + assert upretty(Adjoint(X) + Adjoint(Y)) == " † †\nX + Y " + assert upretty(Adjoint(X*Y)) == " †\n(X⋅Y) " + assert upretty(Adjoint(Y)*Adjoint(X)) == " † †\nY ⋅X " + assert upretty(Adjoint(X**2)) == \ + " †\n⎛ 2⎞ \n⎝X ⎠ " + assert upretty(Adjoint(X)**2) == \ + " 2\n⎛ †⎞ \n⎝X ⎠ " + assert upretty(Adjoint(Inverse(X))) == \ + " †\n⎛ -1⎞ \n⎝X ⎠ " + assert upretty(Inverse(Adjoint(X))) == \ + " -1\n⎛ †⎞ \n⎝X ⎠ " + assert upretty(Adjoint(Transpose(X))) == \ + " †\n⎛ T⎞ \n⎝X ⎠ " + assert upretty(Transpose(Adjoint(X))) == \ + " T\n⎛ †⎞ \n⎝X ⎠ " + m = Matrix(((1, 2), (3, 4))) + assert upretty(Adjoint(m)) == \ + ' †\n'\ + '⎡1 2⎤ \n'\ + '⎢ ⎥ \n'\ + '⎣3 4⎦ ' + assert upretty(Adjoint(m+X)) == \ + ' †\n'\ + '⎛⎡1 2⎤ ⎞ \n'\ + '⎜⎢ ⎥ + X⎟ \n'\ + '⎝⎣3 4⎦ ⎠ ' + assert upretty(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + ' †\n'\ + '⎡ 𝟙 X⎤ \n'\ + '⎢ ⎥ \n'\ + '⎢⎡1 2⎤ ⎥ \n'\ + '⎢⎢ ⎥ 𝟘⎥ \n'\ + '⎣⎣3 4⎦ ⎦ ' + + +def test_Transpose(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert pretty(Transpose(X)) == " T\nX " + assert pretty(Transpose(X + Y)) == " T\n(X + Y) " + assert pretty(Transpose(X) + Transpose(Y)) == " T T\nX + Y " + assert pretty(Transpose(X*Y)) == " T\n(X*Y) " + assert pretty(Transpose(Y)*Transpose(X)) == " T T\nY *X " + assert pretty(Transpose(X**2)) == " T\n/ 2\\ \n\\X / " + assert pretty(Transpose(X)**2) == " 2\n/ T\\ \n\\X / " + assert pretty(Transpose(Inverse(X))) == " T\n/ -1\\ \n\\X / " + assert pretty(Inverse(Transpose(X))) == " -1\n/ T\\ \n\\X / " + assert upretty(Transpose(X)) == " T\nX " + assert upretty(Transpose(X + Y)) == " T\n(X + Y) " + assert upretty(Transpose(X) + Transpose(Y)) == " T T\nX + Y " + assert upretty(Transpose(X*Y)) == " T\n(X⋅Y) " + assert upretty(Transpose(Y)*Transpose(X)) == " T T\nY ⋅X " + assert upretty(Transpose(X**2)) == \ + " T\n⎛ 2⎞ \n⎝X ⎠ " + assert upretty(Transpose(X)**2) == \ + " 2\n⎛ T⎞ \n⎝X ⎠ " + assert upretty(Transpose(Inverse(X))) == \ + " T\n⎛ -1⎞ \n⎝X ⎠ " + assert upretty(Inverse(Transpose(X))) == \ + " -1\n⎛ T⎞ \n⎝X ⎠ " + m = Matrix(((1, 2), (3, 4))) + assert upretty(Transpose(m)) == \ + ' T\n'\ + '⎡1 2⎤ \n'\ + '⎢ ⎥ \n'\ + '⎣3 4⎦ ' + assert upretty(Transpose(m+X)) == \ + ' T\n'\ + '⎛⎡1 2⎤ ⎞ \n'\ + '⎜⎢ ⎥ + X⎟ \n'\ + '⎝⎣3 4⎦ ⎠ ' + assert upretty(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + ' T\n'\ + '⎡ 𝟙 X⎤ \n'\ + '⎢ ⎥ \n'\ + '⎢⎡1 2⎤ ⎥ \n'\ + '⎢⎢ ⎥ 𝟘⎥ \n'\ + '⎣⎣3 4⎦ ⎦ ' + + +def test_pretty_Trace_issue_9044(): + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[2, 4], [6, 8]]) + ascii_str_1 = \ +"""\ + /[1 2]\\ +tr|[ ]| + \\[3 4]/\ +""" + ucode_str_1 = \ +"""\ + ⎛⎡1 2⎤⎞ +tr⎜⎢ ⎥⎟ + ⎝⎣3 4⎦⎠\ +""" + ascii_str_2 = \ +"""\ + /[1 2]\\ /[2 4]\\ +tr|[ ]| + tr|[ ]| + \\[3 4]/ \\[6 8]/\ +""" + ucode_str_2 = \ +"""\ + ⎛⎡1 2⎤⎞ ⎛⎡2 4⎤⎞ +tr⎜⎢ ⎥⎟ + tr⎜⎢ ⎥⎟ + ⎝⎣3 4⎦⎠ ⎝⎣6 8⎦⎠\ +""" + assert pretty(Trace(X)) == ascii_str_1 + assert upretty(Trace(X)) == ucode_str_1 + + assert pretty(Trace(X) + Trace(Y)) == ascii_str_2 + assert upretty(Trace(X) + Trace(Y)) == ucode_str_2 + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + expr = MatrixSlice(X, (None, None, None), (None, None, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = X[x:x + 1, y:y + 1] + assert pretty(expr) == upretty(expr) == 'X[x:x + 1, y:y + 1]' + expr = X[x:x + 1:2, y:y + 1:2] + assert pretty(expr) == upretty(expr) == 'X[x:x + 1:2, y:y + 1:2]' + expr = X[:x, y:] + assert pretty(expr) == upretty(expr) == 'X[:x, y:]' + expr = X[:x, y:] + assert pretty(expr) == upretty(expr) == 'X[:x, y:]' + expr = X[x:, :y] + assert pretty(expr) == upretty(expr) == 'X[x:, :y]' + expr = X[x:y, z:w] + assert pretty(expr) == upretty(expr) == 'X[x:y, z:w]' + expr = X[x:y:t, w:t:x] + assert pretty(expr) == upretty(expr) == 'X[x:y:t, w:t:x]' + expr = X[x::y, t::w] + assert pretty(expr) == upretty(expr) == 'X[x::y, t::w]' + expr = X[:x:y, :t:w] + assert pretty(expr) == upretty(expr) == 'X[:x:y, :t:w]' + expr = X[::x, ::y] + assert pretty(expr) == upretty(expr) == 'X[::x, ::y]' + expr = MatrixSlice(X, (0, None, None), (0, None, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (None, n, None), (None, n, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (0, n, None), (0, n, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (0, n, 2), (0, n, 2)) + assert pretty(expr) == upretty(expr) == 'X[::2, ::2]' + expr = X[1:2:3, 4:5:6] + assert pretty(expr) == upretty(expr) == 'X[1:2:3, 4:5:6]' + expr = X[1:3:5, 4:6:8] + assert pretty(expr) == upretty(expr) == 'X[1:3:5, 4:6:8]' + expr = X[1:10:2] + assert pretty(expr) == upretty(expr) == 'X[1:10:2, :]' + expr = Y[:5, 1:9:2] + assert pretty(expr) == upretty(expr) == 'Y[:5, 1:9:2]' + expr = Y[:5, 1:10:2] + assert pretty(expr) == upretty(expr) == 'Y[:5, 1::2]' + expr = Y[5, :5:2] + assert pretty(expr) == upretty(expr) == 'Y[5:6, :5:2]' + expr = X[0:1, 0:1] + assert pretty(expr) == upretty(expr) == 'X[:1, :1]' + expr = X[0:1:2, 0:1:2] + assert pretty(expr) == upretty(expr) == 'X[:1:2, :1:2]' + expr = (Y + Z)[2:, 2:] + assert pretty(expr) == upretty(expr) == '(Y + Z)[2:, 2:]' + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert pretty(X) == upretty(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + + ascii_str = """\ + / T \\\n\ +(d -> sin(d)).\\X *X/\ +""" + ucode_str = """\ + ⎛ T ⎞\n\ +(d ↦ sin(d))˳⎝X ⋅X⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + ascii_str = """\ +/ 1\\ \n\ +|x -> -|.(n*X)\n\ +\\ x/ \ +""" + ucode_str = """\ +⎛ 1⎞ \n\ +⎜x ↦ ─⎟˳(n⋅X)\n\ +⎝ x⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_dotproduct(): + from sympy.matrices.expressions.dotproduct import DotProduct + n = symbols("n", integer=True) + A = MatrixSymbol('A', n, 1) + B = MatrixSymbol('B', n, 1) + C = Matrix(1, 3, [1, 2, 3]) + D = Matrix(1, 3, [1, 3, 4]) + + assert pretty(DotProduct(A, B)) == "A*B" + assert pretty(DotProduct(C, D)) == "[1 2 3]*[1 3 4]" + assert upretty(DotProduct(A, B)) == "A⋅B" + assert upretty(DotProduct(C, D)) == "[1 2 3]⋅[1 3 4]" + + +def test_pretty_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert upretty(Determinant(m)) == '│1 2│\n│ │\n│3 4│' + assert upretty(Determinant(Inverse(m))) == \ + '│ -1│\n'\ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ │\n'\ + '│⎣3 4⎦ │' + X = MatrixSymbol('X', 2, 2) + assert upretty(Determinant(X)) == '│X│' + assert upretty(Determinant(X + m)) == \ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ + X│\n'\ + '│⎣3 4⎦ │' + assert upretty(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '│ 𝟙 X│\n'\ + '│ │\n'\ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ 𝟘│\n'\ + '│⎣3 4⎦ │' + + +def test_pretty_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + ascii_str = \ +"""\ +/x for x < 1\n\ +| \n\ +< 2 \n\ +|x otherwise\n\ +\\ \ +""" + ucode_str = \ +"""\ +⎧x for x < 1\n\ +⎪ \n\ +⎨ 2 \n\ +⎪x otherwise\n\ +⎩ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -Piecewise((x, x < 1), (x**2, True)) + ascii_str = \ +"""\ + //x for x < 1\\\n\ + || |\n\ +-|< 2 |\n\ + ||x otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x for x < 1⎞\n\ + ⎜⎪ ⎟\n\ +-⎜⎨ 2 ⎟\n\ + ⎜⎪x otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x + Piecewise((x, x > 0), (y, True)) + Piecewise((x/y, x < 2), + (y**2, x > 2), (1, True)) + 1 + ascii_str = \ +"""\ + //x \\ \n\ + ||- for x < 2| \n\ + ||y | \n\ + //x for x > 0\\ || | \n\ +x + |< | + |< 2 | + 1\n\ + \\\\y otherwise/ ||y for x > 2| \n\ + || | \n\ + ||1 otherwise| \n\ + \\\\ / \ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞ \n\ + ⎜⎪─ for x < 2⎟ \n\ + ⎜⎪y ⎟ \n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟ \n\ +x + ⎜⎨ ⎟ + ⎜⎨ 2 ⎟ + 1\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟ \n\ + ⎜⎪ ⎟ \n\ + ⎜⎪1 otherwise⎟ \n\ + ⎝⎩ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x - Piecewise((x, x > 0), (y, True)) + Piecewise((x/y, x < 2), + (y**2, x > 2), (1, True)) + 1 + ascii_str = \ +"""\ + //x \\ \n\ + ||- for x < 2| \n\ + ||y | \n\ + //x for x > 0\\ || | \n\ +x - |< | + |< 2 | + 1\n\ + \\\\y otherwise/ ||y for x > 2| \n\ + || | \n\ + ||1 otherwise| \n\ + \\\\ / \ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞ \n\ + ⎜⎪─ for x < 2⎟ \n\ + ⎜⎪y ⎟ \n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟ \n\ +x - ⎜⎨ ⎟ + ⎜⎨ 2 ⎟ + 1\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟ \n\ + ⎜⎪ ⎟ \n\ + ⎜⎪1 otherwise⎟ \n\ + ⎝⎩ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x*Piecewise((x, x > 0), (y, True)) + ascii_str = \ +"""\ + //x for x > 0\\\n\ +x*|< |\n\ + \\\\y otherwise/\ +""" + ucode_str = \ +"""\ + ⎛⎧x for x > 0⎞\n\ +x⋅⎜⎨ ⎟\n\ + ⎝⎩y otherwise⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Piecewise((x, x > 0), (y, True))*Piecewise((x/y, x < 2), (y**2, x > + 2), (1, True)) + ascii_str = \ +"""\ + //x \\\n\ + ||- for x < 2|\n\ + ||y |\n\ +//x for x > 0\\ || |\n\ +|< |*|< 2 |\n\ +\\\\y otherwise/ ||y for x > 2|\n\ + || |\n\ + ||1 otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞\n\ + ⎜⎪─ for x < 2⎟\n\ + ⎜⎪y ⎟\n\ +⎛⎧x for x > 0⎞ ⎜⎪ ⎟\n\ +⎜⎨ ⎟⋅⎜⎨ 2 ⎟\n\ +⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟\n\ + ⎜⎪ ⎟\n\ + ⎜⎪1 otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -Piecewise((x, x > 0), (y, True))*Piecewise((x/y, x < 2), (y**2, x + > 2), (1, True)) + ascii_str = \ +"""\ + //x \\\n\ + ||- for x < 2|\n\ + ||y |\n\ + //x for x > 0\\ || |\n\ +-|< |*|< 2 |\n\ + \\\\y otherwise/ ||y for x > 2|\n\ + || |\n\ + ||1 otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞\n\ + ⎜⎪─ for x < 2⎟\n\ + ⎜⎪y ⎟\n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟\n\ +-⎜⎨ ⎟⋅⎜⎨ 2 ⎟\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟\n\ + ⎜⎪ ⎟\n\ + ⎜⎪1 otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Piecewise((0, Abs(1/y) < 1), (1, Abs(y) < 1), (y*meijerg(((2, 1), + ()), ((), (1, 0)), 1/y), True)) + ascii_str = \ +"""\ +/ 1 \n\ +| 0 for --- < 1\n\ +| |y| \n\ +| \n\ +< 1 for |y| < 1\n\ +| \n\ +| __0, 2 /1, 2 | 1\\ \n\ +|y*/__ | | -| otherwise \n\ +\\ \\_|2, 2 \\ 0, 1 | y/ \ +""" + ucode_str = \ +"""\ +⎧ 1 \n\ +⎪ 0 for ─── < 1\n\ +⎪ │y│ \n\ +⎪ \n\ +⎨ 1 for │y│ < 1\n\ +⎪ \n\ +⎪ ╭─╮0, 2 ⎛1, 2 │ 1⎞ \n\ +⎪y⋅│╶┐ ⎜ │ ─⎟ otherwise \n\ +⎩ ╰─╯2, 2 ⎝ 0, 1 │ y⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # XXX: We have to use evaluate=False here because Piecewise._eval_power + # denests the power. + expr = Pow(Piecewise((x, x > 0), (y, True)), 2, evaluate=False) + ascii_str = \ +"""\ + 2\n\ +//x for x > 0\\ \n\ +|< | \n\ +\\\\y otherwise/ \ +""" + ucode_str = \ +"""\ + 2\n\ +⎛⎧x for x > 0⎞ \n\ +⎜⎨ ⎟ \n\ +⎝⎩y otherwise⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_ITE(): + expr = ITE(x, y, z) + assert pretty(expr) == ( + '/y for x \n' + '< \n' + '\\z otherwise' + ) + assert upretty(expr) == """\ +⎧y for x \n\ +⎨ \n\ +⎩z otherwise\ +""" + + +def test_pretty_seq(): + expr = () + ascii_str = \ +"""\ +()\ +""" + ucode_str = \ +"""\ +()\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = [] + ascii_str = \ +"""\ +[]\ +""" + ucode_str = \ +"""\ +[]\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {} + expr_2 = {} + ascii_str = \ +"""\ +{}\ +""" + ucode_str = \ +"""\ +{}\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + expr = (1/x,) + ascii_str = \ +"""\ + 1 \n\ +(-,)\n\ + x \ +""" + ucode_str = \ +"""\ +⎛1 ⎞\n\ +⎜─,⎟\n\ +⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = [x**2, 1/x, x, y, sin(th)**2/cos(ph)**2] + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +[x , -, x, y, -----------]\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤\n\ +⎢ 2 1 sin (θ)⎥\n\ +⎢x , ─, x, y, ───────⎥\n\ +⎢ x 2 ⎥\n\ +⎣ cos (φ)⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +(x , -, x, y, -----------)\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎜ 2 1 sin (θ)⎟\n\ +⎜x , ─, x, y, ───────⎟\n\ +⎜ x 2 ⎟\n\ +⎝ cos (φ)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Tuple(x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +(x , -, x, y, -----------)\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎜ 2 1 sin (θ)⎟\n\ +⎜x , ─, x, y, ───────⎟\n\ +⎜ x 2 ⎟\n\ +⎝ cos (φ)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {x: sin(x)} + expr_2 = Dict({x: sin(x)}) + ascii_str = \ +"""\ +{x: sin(x)}\ +""" + ucode_str = \ +"""\ +{x: sin(x)}\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + expr = {1/x: 1/y, x: sin(x)**2} + expr_2 = Dict({1/x: 1/y, x: sin(x)**2}) + ascii_str = \ +"""\ + 1 1 2 \n\ +{-: -, x: sin (x)}\n\ + x y \ +""" + ucode_str = \ +"""\ +⎧1 1 2 ⎫\n\ +⎨─: ─, x: sin (x)⎬\n\ +⎩x y ⎭\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + # There used to be a bug with pretty-printing sequences of even height. + expr = [x**2] + ascii_str = \ +"""\ + 2 \n\ +[x ]\ +""" + ucode_str = \ +"""\ +⎡ 2⎤\n\ +⎣x ⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2,) + ascii_str = \ +"""\ + 2 \n\ +(x ,)\ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎝x ,⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Tuple(x**2) + ascii_str = \ +"""\ + 2 \n\ +(x ,)\ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎝x ,⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {x**2: 1} + expr_2 = Dict({x**2: 1}) + ascii_str = \ +"""\ + 2 \n\ +{x : 1}\ +""" + ucode_str = \ +"""\ +⎧ 2 ⎫\n\ +⎨x : 1⎬\n\ +⎩ ⎭\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + +def test_any_object_in_sequence(): + # Cf. issue 5306 + b1 = Basic() + b2 = Basic(Basic()) + + expr = [b2, b1] + assert pretty(expr) == "[Basic(Basic()), Basic()]" + assert upretty(expr) == "[Basic(Basic()), Basic()]" + + expr = {b2, b1} + assert pretty(expr) == "{Basic(), Basic(Basic())}" + assert upretty(expr) == "{Basic(), Basic(Basic())}" + + expr = {b2: b1, b1: b2} + expr2 = Dict({b2: b1, b1: b2}) + assert pretty(expr) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert pretty( + expr2) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert upretty( + expr) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert upretty( + expr2) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + + +def test_print_builtin_set(): + assert pretty(set()) == 'set()' + assert upretty(set()) == 'set()' + + assert pretty(frozenset()) == 'frozenset()' + assert upretty(frozenset()) == 'frozenset()' + + s1 = {1/x, x} + s2 = frozenset(s1) + + assert pretty(s1) == \ +"""\ + 1 \n\ +{-, x} + x \ +""" + assert upretty(s1) == \ +"""\ +⎧1 ⎫ +⎨─, x⎬ +⎩x ⎭\ +""" + + assert pretty(s2) == \ +"""\ + 1 \n\ +frozenset({-, x}) + x \ +""" + assert upretty(s2) == \ +"""\ + ⎛⎧1 ⎫⎞ +frozenset⎜⎨─, x⎬⎟ + ⎝⎩x ⎭⎠\ +""" + + +def test_pretty_sets(): + s = FiniteSet + assert pretty(s(*[x*y, x**2])) == \ +"""\ + 2 \n\ +{x , x*y}\ +""" + assert pretty(s(*range(1, 6))) == "{1, 2, 3, 4, 5}" + assert pretty(s(*range(1, 13))) == "{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}" + + assert pretty({x*y, x**2}) == \ +"""\ + 2 \n\ +{x , x*y}\ +""" + assert pretty(set(range(1, 6))) == "{1, 2, 3, 4, 5}" + assert pretty(set(range(1, 13))) == \ + "{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}" + + assert pretty(frozenset([x*y, x**2])) == \ +"""\ + 2 \n\ +frozenset({x , x*y})\ +""" + assert pretty(frozenset(range(1, 6))) == "frozenset({1, 2, 3, 4, 5})" + assert pretty(frozenset(range(1, 13))) == \ + "frozenset({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})" + + assert pretty(Range(0, 3, 1)) == '{0, 1, 2}' + + ascii_str = '{0, 1, ..., 29}' + ucode_str = '{0, 1, …, 29}' + assert pretty(Range(0, 30, 1)) == ascii_str + assert upretty(Range(0, 30, 1)) == ucode_str + + ascii_str = '{30, 29, ..., 2}' + ucode_str = '{30, 29, …, 2}' + assert pretty(Range(30, 1, -1)) == ascii_str + assert upretty(Range(30, 1, -1)) == ucode_str + + ascii_str = '{0, 2, ...}' + ucode_str = '{0, 2, …}' + assert pretty(Range(0, oo, 2)) == ascii_str + assert upretty(Range(0, oo, 2)) == ucode_str + + ascii_str = '{..., 2, 0}' + ucode_str = '{…, 2, 0}' + assert pretty(Range(oo, -2, -2)) == ascii_str + assert upretty(Range(oo, -2, -2)) == ucode_str + + ascii_str = '{-2, -3, ...}' + ucode_str = '{-2, -3, …}' + assert pretty(Range(-2, -oo, -1)) == ascii_str + assert upretty(Range(-2, -oo, -1)) == ucode_str + + +def test_pretty_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + ascii_str = "SetExpr([1, 3])" + ucode_str = "SetExpr([1, 3])" + assert pretty(se) == ascii_str + assert upretty(se) == ucode_str + + +def test_pretty_ImageSet(): + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + ascii_str = '{x + y | x in {1, 2, 3}, y in {3, 4}}' + ucode_str = '{x + y │ x ∊ {1, 2, 3}, y ∊ {3, 4}}' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + ascii_str = '{x + y | (x, y) in {1, 2, 3} x {3, 4}}' + ucode_str = '{x + y │ (x, y) ∊ {1, 2, 3} × {3, 4}}' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda(x, x**2), S.Naturals) + ascii_str = '''\ + 2 \n\ +{x | x in Naturals}''' + ucode_str = '''\ +⎧ 2 │ ⎫\n\ +⎨x │ x ∊ ℕ⎬\n\ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + # TODO: The "x in N" parts below should be centered independently of the + # 1/x**2 fraction + imgset = ImageSet(Lambda(x, 1/x**2), S.Naturals) + ascii_str = '''\ + 1 \n\ +{-- | x in Naturals} + 2 \n\ + x ''' + ucode_str = '''\ +⎧1 │ ⎫\n\ +⎪── │ x ∊ ℕ⎪\n\ +⎨ 2 │ ⎬\n\ +⎪x │ ⎪\n\ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda((x, y), 1/(x + y)**2), S.Naturals, S.Naturals) + ascii_str = '''\ + 1 \n\ +{-------- | x in Naturals, y in Naturals} + 2 \n\ + (x + y) ''' + ucode_str = '''\ +⎧ 1 │ ⎫ +⎪──────── │ x ∊ ℕ, y ∊ ℕ⎪ +⎨ 2 │ ⎬ +⎪(x + y) │ ⎪ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + # issue 23449 centering issue + assert upretty([Symbol("ihat") / (Symbol("i") + 1)]) == '''\ +⎡ î ⎤ +⎢─────⎥ +⎣i + 1⎦\ +''' + assert upretty(Matrix([Symbol("ihat"), Symbol("i") + 1])) == '''\ +⎡ î ⎤ +⎢ ⎥ +⎣i + 1⎦\ +''' + + +def test_pretty_ConditionSet(): + ascii_str = '{x | x in (-oo, oo) and sin(x) = 0}' + ucode_str = '{x │ x ∊ ℝ ∧ (sin(x) = 0)}' + assert pretty(ConditionSet(x, Eq(sin(x), 0), S.Reals)) == ascii_str + assert upretty(ConditionSet(x, Eq(sin(x), 0), S.Reals)) == ucode_str + + assert pretty(ConditionSet(x, Contains(x, S.Reals, evaluate=False), FiniteSet(1))) == '{1}' + assert upretty(ConditionSet(x, Contains(x, S.Reals, evaluate=False), FiniteSet(1))) == '{1}' + + assert pretty(ConditionSet(x, And(x > 1, x < -1), FiniteSet(1, 2, 3))) == "EmptySet" + assert upretty(ConditionSet(x, And(x > 1, x < -1), FiniteSet(1, 2, 3))) == "∅" + + assert pretty(ConditionSet(x, Or(x > 1, x < -1), FiniteSet(1, 2))) == '{2}' + assert upretty(ConditionSet(x, Or(x > 1, x < -1), FiniteSet(1, 2))) == '{2}' + + condset = ConditionSet(x, 1/x**2 > 0) + ascii_str = '''\ + 1 \n\ +{x | -- > 0} + 2 \n\ + x ''' + ucode_str = '''\ +⎧ │ ⎛1 ⎞⎫ +⎪x │ ⎜── > 0⎟⎪ +⎨ │ ⎜ 2 ⎟⎬ +⎪ │ ⎝x ⎠⎪ +⎩ │ ⎭''' + assert pretty(condset) == ascii_str + assert upretty(condset) == ucode_str + + condset = ConditionSet(x, 1/x**2 > 0, S.Reals) + ascii_str = '''\ + 1 \n\ +{x | x in (-oo, oo) and -- > 0} + 2 \n\ + x ''' + ucode_str = '''\ +⎧ │ ⎛1 ⎞⎫ +⎪x │ x ∊ ℝ ∧ ⎜── > 0⎟⎪ +⎨ │ ⎜ 2 ⎟⎬ +⎪ │ ⎝x ⎠⎪ +⎩ │ ⎭''' + assert pretty(condset) == ascii_str + assert upretty(condset) == ucode_str + + +def test_pretty_ComplexRegion(): + from sympy.sets.fancysets import ComplexRegion + cregion = ComplexRegion(Interval(3, 5)*Interval(4, 6)) + ascii_str = '{x + y*I | x, y in [3, 5] x [4, 6]}' + ucode_str = '{x + y⋅ⅈ │ x, y ∊ [3, 5] × [4, 6]}' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True) + ascii_str = '{r*(I*sin(theta) + cos(theta)) | r, theta in [0, 1] x [0, 2*pi)}' + ucode_str = '{r⋅(ⅈ⋅sin(θ) + cos(θ)) │ r, θ ∊ [0, 1] × [0, 2⋅π)}' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(3, 1/a**2)*Interval(4, 6)) + ascii_str = '''\ + 1 \n\ +{x + y*I | x, y in [3, --] x [4, 6]} + 2 \n\ + a ''' + ucode_str = '''\ +⎧ │ ⎡ 1 ⎤ ⎫ +⎪x + y⋅ⅈ │ x, y ∊ ⎢3, ──⎥ × [4, 6]⎪ +⎨ │ ⎢ 2⎥ ⎬ +⎪ │ ⎣ a ⎦ ⎪ +⎩ │ ⎭''' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(0, 1/a**2)*Interval(0, 2*pi), polar=True) + ascii_str = '''\ + 1 \n\ +{r*(I*sin(theta) + cos(theta)) | r, theta in [0, --] x [0, 2*pi)} + 2 \n\ + a ''' + ucode_str = '''\ +⎧ │ ⎡ 1 ⎤ ⎫ +⎪r⋅(ⅈ⋅sin(θ) + cos(θ)) │ r, θ ∊ ⎢0, ──⎥ × [0, 2⋅π)⎪ +⎨ │ ⎢ 2⎥ ⎬ +⎪ │ ⎣ a ⎦ ⎪ +⎩ │ ⎭''' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + +def test_pretty_Union_issue_10414(): + a, b = Interval(2, 3), Interval(4, 7) + ucode_str = '[2, 3] ∪ [4, 7]' + ascii_str = '[2, 3] U [4, 7]' + assert upretty(Union(a, b)) == ucode_str + assert pretty(Union(a, b)) == ascii_str + + +def test_pretty_Intersection_issue_10414(): + x, y, z, w = symbols('x, y, z, w') + a, b = Interval(x, y), Interval(z, w) + ucode_str = '[x, y] ∩ [z, w]' + ascii_str = '[x, y] n [z, w]' + assert upretty(Intersection(a, b)) == ucode_str + assert pretty(Intersection(a, b)) == ascii_str + + +def test_ProductSet_exponent(): + ucode_str = ' 1\n[0, 1] ' + assert upretty(Interval(0, 1)**1) == ucode_str + ucode_str = ' 2\n[0, 1] ' + assert upretty(Interval(0, 1)**2) == ucode_str + + +def test_ProductSet_parenthesis(): + ucode_str = '([4, 7] × {1, 2}) ∪ ([2, 3] × [4, 7])' + + a, b = Interval(2, 3), Interval(4, 7) + assert upretty(Union(a*b, b*FiniteSet(1, 2))) == ucode_str + + +def test_ProductSet_prod_char_issue_10413(): + ascii_str = '[2, 3] x [4, 7]' + ucode_str = '[2, 3] × [4, 7]' + + a, b = Interval(2, 3), Interval(4, 7) + assert pretty(a*b) == ascii_str + assert upretty(a*b) == ucode_str + + +def test_pretty_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + ascii_str = '[0, 1, 4, 9, ...]' + ucode_str = '[0, 1, 4, 9, …]' + + assert pretty(s1) == ascii_str + assert upretty(s1) == ucode_str + + ascii_str = '[1, 2, 1, 2, ...]' + ucode_str = '[1, 2, 1, 2, …]' + assert pretty(s2) == ascii_str + assert upretty(s2) == ucode_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + ascii_str = '[0, 1, 4]' + ucode_str = '[0, 1, 4]' + + assert pretty(s3) == ascii_str + assert upretty(s3) == ucode_str + + ascii_str = '[1, 2, 1]' + ucode_str = '[1, 2, 1]' + assert pretty(s4) == ascii_str + assert upretty(s4) == ucode_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + ascii_str = '[..., 9, 4, 1, 0]' + ucode_str = '[…, 9, 4, 1, 0]' + + assert pretty(s5) == ascii_str + assert upretty(s5) == ucode_str + + ascii_str = '[..., 2, 1, 2, 1]' + ucode_str = '[…, 2, 1, 2, 1]' + assert pretty(s6) == ascii_str + assert upretty(s6) == ucode_str + + ascii_str = '[1, 3, 5, 11, ...]' + ucode_str = '[1, 3, 5, 11, …]' + + assert pretty(SeqAdd(s1, s2)) == ascii_str + assert upretty(SeqAdd(s1, s2)) == ucode_str + + ascii_str = '[1, 3, 5]' + ucode_str = '[1, 3, 5]' + + assert pretty(SeqAdd(s3, s4)) == ascii_str + assert upretty(SeqAdd(s3, s4)) == ucode_str + + ascii_str = '[..., 11, 5, 3, 1]' + ucode_str = '[…, 11, 5, 3, 1]' + + assert pretty(SeqAdd(s5, s6)) == ascii_str + assert upretty(SeqAdd(s5, s6)) == ucode_str + + ascii_str = '[0, 2, 4, 18, ...]' + ucode_str = '[0, 2, 4, 18, …]' + + assert pretty(SeqMul(s1, s2)) == ascii_str + assert upretty(SeqMul(s1, s2)) == ucode_str + + ascii_str = '[0, 2, 4]' + ucode_str = '[0, 2, 4]' + + assert pretty(SeqMul(s3, s4)) == ascii_str + assert upretty(SeqMul(s3, s4)) == ucode_str + + ascii_str = '[..., 18, 4, 2, 0]' + ucode_str = '[…, 18, 4, 2, 0]' + + assert pretty(SeqMul(s5, s6)) == ascii_str + assert upretty(SeqMul(s5, s6)) == ucode_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + raises(NotImplementedError, lambda: pretty(s7)) + raises(NotImplementedError, lambda: upretty(s7)) + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + ascii_str = '[0, b, 4*b]' + ucode_str = '[0, b, 4⋅b]' + assert pretty(s8) == ascii_str + assert upretty(s8) == ucode_str + + +def test_pretty_FourierSeries(): + f = fourier_series(x, (x, -pi, pi)) + + ascii_str = \ +"""\ + 2*sin(3*x) \n\ +2*sin(x) - sin(2*x) + ---------- + ...\n\ + 3 \ +""" + + ucode_str = \ +"""\ + 2⋅sin(3⋅x) \n\ +2⋅sin(x) - sin(2⋅x) + ────────── + …\n\ + 3 \ +""" + + assert pretty(f) == ascii_str + assert upretty(f) == ucode_str + + +def test_pretty_FormalPowerSeries(): + f = fps(log(1 + x)) + + + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ -k k \n\ + \\ -(-1) *x \n\ + / -----------\n\ + / k \n\ +/___, \n\ +k = 1 \ +""" + + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ -k k \n\ + ╲ -(-1) ⋅x \n\ + ╱ ───────────\n\ + ╱ k \n\ +╱ \n\ +‾‾‾‾ \n\ +k = 1 \ +""" + + assert pretty(f) == ascii_str + assert upretty(f) == ucode_str + + +def test_pretty_limits(): + expr = Limit(x, x, oo) + ascii_str = \ +"""\ + lim x\n\ +x->oo \ +""" + ucode_str = \ +"""\ +lim x\n\ +x─→∞ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x**2, x, 0) + ascii_str = \ +"""\ + 2\n\ + lim x \n\ +x->0+ \ +""" + ucode_str = \ +"""\ + 2\n\ + lim x \n\ +x─→0⁺ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(1/x, x, 0) + ascii_str = \ +"""\ + 1\n\ + lim -\n\ +x->0+x\ +""" + ucode_str = \ +"""\ + 1\n\ + lim ─\n\ +x─→0⁺x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x)/x, x, 0) + ascii_str = \ +"""\ + /sin(x)\\\n\ + lim |------|\n\ +x->0+\\ x /\ +""" + ucode_str = \ +"""\ + ⎛sin(x)⎞\n\ + lim ⎜──────⎟\n\ +x─→0⁺⎝ x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x)/x, x, 0, "-") + ascii_str = \ +"""\ + /sin(x)\\\n\ + lim |------|\n\ +x->0-\\ x /\ +""" + ucode_str = \ +"""\ + ⎛sin(x)⎞\n\ + lim ⎜──────⎟\n\ +x─→0⁻⎝ x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x + sin(x), x, 0) + ascii_str = \ +"""\ + lim (x + sin(x))\n\ +x->0+ \ +""" + ucode_str = \ +"""\ + lim (x + sin(x))\n\ +x─→0⁺ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x, x, 0)**2 + ascii_str = \ +"""\ + 2\n\ +/ lim x\\ \n\ +\\x->0+ / \ +""" + ucode_str = \ +"""\ + 2\n\ +⎛ lim x⎞ \n\ +⎝x─→0⁺ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x*Limit(y/2,y,0), x, 0) + ascii_str = \ +"""\ + / /y\\\\\n\ + lim |x* lim |-||\n\ +x->0+\\ y->0+\\2//\ +""" + ucode_str = \ +"""\ + ⎛ ⎛y⎞⎞\n\ + lim ⎜x⋅ lim ⎜─⎟⎟\n\ +x─→0⁺⎝ y─→0⁺⎝2⎠⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*Limit(x*Limit(y/2,y,0), x, 0) + ascii_str = \ +"""\ + / /y\\\\\n\ +2* lim |x* lim |-||\n\ + x->0+\\ y->0+\\2//\ +""" + ucode_str = \ +"""\ + ⎛ ⎛y⎞⎞\n\ +2⋅ lim ⎜x⋅ lim ⎜─⎟⎟\n\ + x─→0⁺⎝ y─→0⁺⎝2⎠⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x), x, 0, dir='+-') + ascii_str = \ +"""\ +lim sin(x)\n\ +x->0 \ +""" + ucode_str = \ +"""\ +lim sin(x)\n\ +x─→0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_ComplexRootOf(): + expr = rootof(x**5 + 11*x - 2, 0) + ascii_str = \ +"""\ + / 5 \\\n\ +CRootOf\\x + 11*x - 2, 0/\ +""" + ucode_str = \ +"""\ + ⎛ 5 ⎞\n\ +CRootOf⎝x + 11⋅x - 2, 0⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_RootSum(): + expr = RootSum(x**5 + 11*x - 2, auto=False) + ascii_str = \ +"""\ + / 5 \\\n\ +RootSum\\x + 11*x - 2/\ +""" + ucode_str = \ +"""\ + ⎛ 5 ⎞\n\ +RootSum⎝x + 11⋅x - 2⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = RootSum(x**5 + 11*x - 2, Lambda(z, exp(z))) + ascii_str = \ +"""\ + / 5 z\\\n\ +RootSum\\x + 11*x - 2, z -> e /\ +""" + ucode_str = \ +"""\ + ⎛ 5 z⎞\n\ +RootSum⎝x + 11⋅x - 2, z ↦ ℯ ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_GroebnerBasis(): + expr = groebner([], x, y) + + ascii_str = \ +"""\ +GroebnerBasis([], x, y, domain=ZZ, order=lex)\ +""" + ucode_str = \ +"""\ +GroebnerBasis([], x, y, domain=ℤ, order=lex)\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + expr = groebner(F, x, y, order='grlex') + + ascii_str = \ +"""\ + /[ 2 2 ] \\\n\ +GroebnerBasis\\[x - x - 3*y + 1, y - 2*x + y - 1], x, y, domain=ZZ, order=grlex/\ +""" + ucode_str = \ +"""\ + ⎛⎡ 2 2 ⎤ ⎞\n\ +GroebnerBasis⎝⎣x - x - 3⋅y + 1, y - 2⋅x + y - 1⎦, x, y, domain=ℤ, order=grlex⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = expr.fglm('lex') + + ascii_str = \ +"""\ + /[ 2 4 3 2 ] \\\n\ +GroebnerBasis\\[2*x - y - y + 1, y + 2*y - 3*y - 16*y + 7], x, y, domain=ZZ, order=lex/\ +""" + ucode_str = \ +"""\ + ⎛⎡ 2 4 3 2 ⎤ ⎞\n\ +GroebnerBasis⎝⎣2⋅x - y - y + 1, y + 2⋅y - 3⋅y - 16⋅y + 7⎦, x, y, domain=ℤ, order=lex⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_UniversalSet(): + assert pretty(S.UniversalSet) == "UniversalSet" + assert upretty(S.UniversalSet) == '𝕌' + + +def test_pretty_Boolean(): + expr = Not(x, evaluate=False) + + assert pretty(expr) == "Not(x)" + assert upretty(expr) == "¬x" + + expr = And(x, y) + + assert pretty(expr) == "And(x, y)" + assert upretty(expr) == "x ∧ y" + + expr = Or(x, y) + + assert pretty(expr) == "Or(x, y)" + assert upretty(expr) == "x ∨ y" + + syms = symbols('a:f') + expr = And(*syms) + + assert pretty(expr) == "And(a, b, c, d, e, f)" + assert upretty(expr) == "a ∧ b ∧ c ∧ d ∧ e ∧ f" + + expr = Or(*syms) + + assert pretty(expr) == "Or(a, b, c, d, e, f)" + assert upretty(expr) == "a ∨ b ∨ c ∨ d ∨ e ∨ f" + + expr = Xor(x, y, evaluate=False) + + assert pretty(expr) == "Xor(x, y)" + assert upretty(expr) == "x ⊻ y" + + expr = Nand(x, y, evaluate=False) + + assert pretty(expr) == "Nand(x, y)" + assert upretty(expr) == "x ⊼ y" + + expr = Nor(x, y, evaluate=False) + + assert pretty(expr) == "Nor(x, y)" + assert upretty(expr) == "x ⊽ y" + + expr = Implies(x, y, evaluate=False) + + assert pretty(expr) == "Implies(x, y)" + assert upretty(expr) == "x → y" + + # don't sort args + expr = Implies(y, x, evaluate=False) + + assert pretty(expr) == "Implies(y, x)" + assert upretty(expr) == "y → x" + + expr = Equivalent(x, y, evaluate=False) + + assert pretty(expr) == "Equivalent(x, y)" + assert upretty(expr) == "x ⇔ y" + + expr = Equivalent(y, x, evaluate=False) + + assert pretty(expr) == "Equivalent(x, y)" + assert upretty(expr) == "x ⇔ y" + + +def test_pretty_Domain(): + expr = FF(23) + + assert pretty(expr) == "GF(23)" + assert upretty(expr) == "ℤ₂₃" + + expr = ZZ + + assert pretty(expr) == "ZZ" + assert upretty(expr) == "ℤ" + + expr = QQ + + assert pretty(expr) == "QQ" + assert upretty(expr) == "ℚ" + + expr = RR + + assert pretty(expr) == "RR" + assert upretty(expr) == "ℝ" + + expr = QQ[x] + + assert pretty(expr) == "QQ[x]" + assert upretty(expr) == "ℚ[x]" + + expr = QQ[x, y] + + assert pretty(expr) == "QQ[x, y]" + assert upretty(expr) == "ℚ[x, y]" + + expr = ZZ.frac_field(x) + + assert pretty(expr) == "ZZ(x)" + assert upretty(expr) == "ℤ(x)" + + expr = ZZ.frac_field(x, y) + + assert pretty(expr) == "ZZ(x, y)" + assert upretty(expr) == "ℤ(x, y)" + + expr = QQ.poly_ring(x, y, order=grlex) + + assert pretty(expr) == "QQ[x, y, order=grlex]" + assert upretty(expr) == "ℚ[x, y, order=grlex]" + + expr = QQ.poly_ring(x, y, order=ilex) + + assert pretty(expr) == "QQ[x, y, order=ilex]" + assert upretty(expr) == "ℚ[x, y, order=ilex]" + + +def test_pretty_prec(): + assert xpretty(S("0.3"), full_prec=True, wrap_line=False) == "0.300000000000000" + assert xpretty(S("0.3"), full_prec="auto", wrap_line=False) == "0.300000000000000" + assert xpretty(S("0.3"), full_prec=False, wrap_line=False) == "0.3" + assert xpretty(S("0.3")*x, full_prec=True, use_unicode=False, wrap_line=False) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert xpretty(S("0.3")*x, full_prec="auto", use_unicode=False, wrap_line=False) in [ + "0.3*x", + "x*0.3" + ] + assert xpretty(S("0.3")*x, full_prec=False, use_unicode=False, wrap_line=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_pprint(): + import sys + from io import StringIO + fd = StringIO() + sso = sys.stdout + sys.stdout = fd + try: + pprint(pi, use_unicode=False, wrap_line=False) + finally: + sys.stdout = sso + assert fd.getvalue() == 'pi\n' + + +def test_pretty_class(): + """Test that the printer dispatcher correctly handles classes.""" + class C: + pass # C has no .__class__ and this was causing problems + + class D: + pass + + assert pretty( C ) == str( C ) + assert pretty( D ) == str( D ) + + +def test_pretty_no_wrap_line(): + huge_expr = 0 + for i in range(20): + huge_expr += i*sin(i + x) + assert xpretty(huge_expr ).find('\n') != -1 + assert xpretty(huge_expr, wrap_line=False).find('\n') == -1 + + +def test_settings(): + raises(TypeError, lambda: pretty(S(4), method="garbage")) + + +def test_pretty_sum(): + from sympy.abc import x, a, b, k, m, n + + expr = Sum(k**k, (k, 0, n)) + ascii_str = \ +"""\ + n \n\ +___ \n\ +\\ ` \n\ + \\ k\n\ + / k \n\ +/__, \n\ +k = 0 \ +""" + ucode_str = \ +"""\ + n \n\ + ___ \n\ + ╲ \n\ + ╲ k\n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾ \n\ +k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**k, (k, oo, n)) + ascii_str = \ +"""\ + n \n\ + ___ \n\ + \\ ` \n\ + \\ k\n\ + / k \n\ + /__, \n\ +k = oo \ +""" + ucode_str = \ +"""\ + n \n\ + ___ \n\ + ╲ \n\ + ╲ k\n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾ \n\ +k = ∞ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**(Integral(x**n, (x, -oo, oo))), (k, 0, n**n)) + ascii_str = \ +"""\ + n \n\ + n \n\ +______ \n\ +\\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ +/_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ + n \n\ + n \n\ +______ \n\ +╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ +╱ \n\ +‾‾‾‾‾‾ \n\ +k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**( + Integral(x**n, (x, -oo, oo))), (k, 0, Integral(x**x, (x, -oo, oo)))) + ascii_str = \ +"""\ + oo \n\ + / \n\ + | \n\ + | x \n\ + | x dx \n\ + | \n\ +/ \n\ +-oo \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ +∞ \n\ +⌠ \n\ +⎮ x \n\ +⎮ x dx \n\ +⌡ \n\ +-∞ \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**(Integral(x**n, (x, -oo, oo))), ( + k, x + n + x**2 + n**2 + (x/n) + (1/x), Integral(x**x, (x, -oo, oo)))) + ascii_str = \ +"""\ + oo \n\ + / \n\ + | \n\ + | x \n\ + | x dx \n\ + | \n\ + / \n\ + -oo \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + 2 2 1 x \n\ +k = n + n + x + x + - + - \n\ + x n \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ⌠ \n\ + ⎮ x \n\ + ⎮ x dx \n\ + ⌡ \n\ + -∞ \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + 2 2 1 x \n\ +k = n + n + x + x + ─ + ─ \n\ + x n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**( + Integral(x**n, (x, -oo, oo))), (k, 0, x + n + x**2 + n**2 + (x/n) + (1/x))) + ascii_str = \ +"""\ + 2 2 1 x \n\ +n + n + x + x + - + - \n\ + x n \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ + 2 2 1 x \n\ +n + n + x + x + ─ + ─ \n\ + x n \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ + __ \n\ + \\ ` \n\ + ) x\n\ + /_, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ___ \n\ + ╲ \n\ + ╲ \n\ + ╱ x\n\ + ╱ \n\ + ‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x**2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +___ \n\ +\\ ` \n\ + \\ 2\n\ + / x \n\ +/__, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ___ \n\ + ╲ \n\ + ╲ 2\n\ + ╱ x \n\ + ╱ \n\ + ‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x/2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +___ \n\ +\\ ` \n\ + \\ x\n\ + ) -\n\ + / 2\n\ +/__, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ \n\ + ╲ x\n\ + ╱ ─\n\ + ╱ 2\n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x**3/2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ 3\n\ + \\ x \n\ + / --\n\ + / 2 \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ 3\n\ + ╲ x \n\ + ╱ ──\n\ + ╱ 2 \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum((x**3*y**(x/2))**n, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ n\n\ + \\ / x\\ \n\ + ) | -| \n\ + / | 3 2| \n\ + / \\x *y / \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +_____ \n\ +╲ \n\ + ╲ \n\ + ╲ n\n\ + ╲ ⎛ x⎞ \n\ + ╱ ⎜ ─⎟ \n\ + ╱ ⎜ 3 2⎟ \n\ + ╱ ⎝x ⋅y ⎠ \n\ +╱ \n\ +‾‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/x**2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ 1 \n\ + \\ --\n\ + / 2\n\ + / x \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ 1 \n\ + ╲ ──\n\ + ╱ 2\n\ + ╱ x \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/y**(a/b), (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ -a \n\ + \\ ---\n\ + / b \n\ + / y \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ -a \n\ + ╲ ───\n\ + ╱ b \n\ + ╱ y \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/y**(a/b), (x, 0, oo), (y, 1, 2)) + ascii_str = \ +"""\ + 2 oo \n\ +____ ____ \n\ +\\ ` \\ ` \n\ + \\ \\ -a\n\ + \\ \\ --\n\ + / / b \n\ + / / y \n\ +/___, /___, \n\ +y = 1 x = 0 \ +""" + ucode_str = \ +"""\ + 2 ∞ \n\ +____ ____ \n\ +╲ ╲ \n\ + ╲ ╲ -a\n\ + ╲ ╲ ──\n\ + ╱ ╱ b \n\ + ╱ ╱ y \n\ +╱ ╱ \n\ +‾‾‾‾ ‾‾‾‾ \n\ +y = 1 x = 0 \ +""" + expr = Sum(1/(1 + 1/( + 1 + 1/k)) + 1, (k, 111, 1 + 1/n), (k, 1/(1 + m), oo)) + 1/(1 + 1/k) + ascii_str = \ +"""\ + 1 \n\ + 1 + - \n\ + oo n \n\ + _____ _____ \n\ + \\ ` \\ ` \n\ + \\ \\ / 1 \\ \n\ + \\ \\ |1 + ---------| \n\ + \\ \\ | 1 | 1 \n\ + ) ) | 1 + -----| + -----\n\ + / / | 1| 1\n\ + / / | 1 + -| 1 + -\n\ + / / \\ k/ k\n\ + /____, /____, \n\ + 1 k = 111 \n\ +k = ----- \n\ + m + 1 \ +""" + ucode_str = \ +"""\ + 1 \n\ + 1 + ─ \n\ + ∞ n \n\ + ______ ______ \n\ + ╲ ╲ \n\ + ╲ ╲ \n\ + ╲ ╲ ⎛ 1 ⎞ \n\ + ╲ ╲ ⎜1 + ─────────⎟ \n\ + ╲ ╲ ⎜ 1 ⎟ 1 \n\ + ╱ ╱ ⎜ 1 + ─────⎟ + ─────\n\ + ╱ ╱ ⎜ 1⎟ 1\n\ + ╱ ╱ ⎜ 1 + ─⎟ 1 + ─\n\ + ╱ ╱ ⎝ k⎠ k\n\ + ╱ ╱ \n\ + ‾‾‾‾‾‾ ‾‾‾‾‾‾ \n\ + 1 k = 111 \n\ +k = ───── \n\ + m + 1 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_units(): + expr = joule + ascii_str1 = \ +"""\ + 2\n\ +kilogram*meter \n\ +---------------\n\ + 2 \n\ + second \ +""" + unicode_str1 = \ +"""\ + 2\n\ +kilogram⋅meter \n\ +───────────────\n\ + 2 \n\ + second \ +""" + + ascii_str2 = \ +"""\ + 2\n\ +3*x*y*kilogram*meter \n\ +---------------------\n\ + 2 \n\ + second \ +""" + unicode_str2 = \ +"""\ + 2\n\ +3⋅x⋅y⋅kilogram⋅meter \n\ +─────────────────────\n\ + 2 \n\ + second \ +""" + + from sympy.physics.units import kg, m, s + assert upretty(expr) == "joule" + assert pretty(expr) == "joule" + assert upretty(expr.convert_to(kg*m**2/s**2)) == unicode_str1 + assert pretty(expr.convert_to(kg*m**2/s**2)) == ascii_str1 + assert upretty(3*kg*x*m**2*y/s**2) == unicode_str2 + assert pretty(3*kg*x*m**2*y/s**2) == ascii_str2 + + +def test_pretty_Subs(): + f = Function('f') + expr = Subs(f(x), x, ph**2) + ascii_str = \ +"""\ +(f(x))| 2\n\ + |x=phi \ +""" + unicode_str = \ +"""\ +(f(x))│ 2\n\ + │x=φ \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + expr = Subs(f(x).diff(x), x, 0) + ascii_str = \ +"""\ +/d \\| \n\ +|--(f(x))|| \n\ +\\dx /|x=0\ +""" + unicode_str = \ +"""\ +⎛d ⎞│ \n\ +⎜──(f(x))⎟│ \n\ +⎝dx ⎠│x=0\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + expr = Subs(f(x).diff(x)/y, (x, y), (0, Rational(1, 2))) + ascii_str = \ +"""\ +/d \\| \n\ +|--(f(x))|| \n\ +|dx || \n\ +|--------|| \n\ +\\ y /|x=0, y=1/2\ +""" + unicode_str = \ +"""\ +⎛d ⎞│ \n\ +⎜──(f(x))⎟│ \n\ +⎜dx ⎟│ \n\ +⎜────────⎟│ \n\ +⎝ y ⎠│x=0, y=1/2\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + +def test_gammas(): + assert upretty(lowergamma(x, y)) == "γ(x, y)" + assert upretty(uppergamma(x, y)) == "Γ(x, y)" + assert xpretty(gamma(x), use_unicode=True) == 'Γ(x)' + assert xpretty(gamma, use_unicode=True) == 'Γ' + assert xpretty(symbols('gamma', cls=Function)(x), use_unicode=True) == 'γ(x)' + assert xpretty(symbols('gamma', cls=Function), use_unicode=True) == 'γ' + + +def test_beta(): + assert xpretty(beta(x,y), use_unicode=True) == 'Β(x, y)' + assert xpretty(beta(x,y), use_unicode=False) == 'B(x, y)' + assert xpretty(beta, use_unicode=True) == 'Β' + assert xpretty(beta, use_unicode=False) == 'B' + mybeta = Function('beta') + assert xpretty(mybeta(x), use_unicode=True) == 'β(x)' + assert xpretty(mybeta(x, y, z), use_unicode=False) == 'beta(x, y, z)' + assert xpretty(mybeta, use_unicode=True) == 'β' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert xpretty(mygamma, use_unicode=True) == r"mygamma" + assert xpretty(mygamma(x), use_unicode=True) == r"mygamma(x)" + + +def test_SingularityFunction(): + assert xpretty(SingularityFunction(x, 0, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 1, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, -1, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, a, n), use_unicode=True) == ( +"""\ + n\n\ +<-a + x> \ +""") + assert xpretty(SingularityFunction(x, y, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 0, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 1, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, -1, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, a, n), use_unicode=False) == ( +"""\ + n\n\ +<-a + x> \ +""") + assert xpretty(SingularityFunction(x, y, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + + +def test_deltas(): + assert xpretty(DiracDelta(x), use_unicode=True) == 'δ(x)' + assert xpretty(DiracDelta(x, 1), use_unicode=True) == \ +"""\ + (1) \n\ +δ (x)\ +""" + assert xpretty(x*DiracDelta(x, 1), use_unicode=True) == \ +"""\ + (1) \n\ +x⋅δ (x)\ +""" + + +def test_hyper(): + expr = hyper((), (), z) + ucode_str = \ +"""\ + ┌─ ⎛ │ ⎞\n\ + ├─ ⎜ │ z⎟\n\ +0╵ 0 ⎝ │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ / | \\\n\ + | | | z|\n\ +0 0 \\ | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((), (1,), x) + ucode_str = \ +"""\ + ┌─ ⎛ │ ⎞\n\ + ├─ ⎜ │ x⎟\n\ +0╵ 1 ⎝1 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ / | \\\n\ + | | | x|\n\ +0 1 \\1 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper([2], [1], x) + ucode_str = \ +"""\ + ┌─ ⎛2 │ ⎞\n\ + ├─ ⎜ │ x⎟\n\ +1╵ 1 ⎝1 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ /2 | \\\n\ + | | | x|\n\ +1 1 \\1 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((pi/3, -2*k), (3, 4, 5, -3), x) + ucode_str = \ +"""\ + ⎛ π │ ⎞\n\ + ┌─ ⎜ ─, -2⋅k │ ⎟\n\ + ├─ ⎜ 3 │ x⎟\n\ +2╵ 4 ⎜ │ ⎟\n\ + ⎝-3, 3, 4, 5 │ ⎠\ +""" + ascii_str = \ +"""\ + \n\ + _ / pi | \\\n\ + |_ | --, -2*k | |\n\ + | | 3 | x|\n\ +2 4 | | |\n\ + \\-3, 3, 4, 5 | /\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((pi, S('2/3'), -2*k), (3, 4, 5, -3), x**2) + ucode_str = \ +"""\ + ┌─ ⎛2/3, π, -2⋅k │ 2⎞\n\ + ├─ ⎜ │ x ⎟\n\ +3╵ 4 ⎝-3, 3, 4, 5 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ /2/3, pi, -2*k | 2\\ + | | | x | +3 4 \\ -3, 3, 4, 5 | /""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper([1, 2], [3, 4], 1/(1/(1/(1/x + 1) + 1) + 1)) + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ + ⎜ │ ─────────────⎟\n\ + ⎜ │ 1 ⎟\n\ + ┌─ ⎜1, 2 │ 1 + ─────────⎟\n\ + ├─ ⎜ │ 1 ⎟\n\ +2╵ 2 ⎜3, 4 │ 1 + ─────⎟\n\ + ⎜ │ 1⎟\n\ + ⎜ │ 1 + ─⎟\n\ + ⎝ │ x⎠\ +""" + + ascii_str = \ +"""\ + \n\ + / | 1 \\\n\ + | | -------------|\n\ + _ | | 1 |\n\ + |_ |1, 2 | 1 + ---------|\n\ + | | | 1 |\n\ +2 2 |3, 4 | 1 + -----|\n\ + | | 1|\n\ + | | 1 + -|\n\ + \\ | x/\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_meijerg(): + expr = meijerg([pi, pi, x], [1], [0, 1], [1, 2, 3], z) + ucode_str = \ +"""\ +╭─╮2, 3 ⎛π, π, x 1 │ ⎞\n\ +│╶┐ ⎜ │ z⎟\n\ +╰─╯4, 5 ⎝ 0, 1 1, 2, 3 │ ⎠\ +""" + ascii_str = \ +"""\ + __2, 3 /pi, pi, x 1 | \\\n\ +/__ | | z|\n\ +\\_|4, 5 \\ 0, 1 1, 2, 3 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = meijerg([1, pi/7], [2, pi, 5], [], [], z**2) + ucode_str = \ +"""\ + ⎛ π │ ⎞\n\ +╭─╮0, 2 ⎜1, ─ 2, 5, π │ 2⎟\n\ +│╶┐ ⎜ 7 │ z ⎟\n\ +╰─╯5, 0 ⎜ │ ⎟\n\ + ⎝ │ ⎠\ +""" + ascii_str = \ +"""\ + / pi | \\\n\ + __0, 2 |1, -- 2, 5, pi | 2|\n\ +/__ | 7 | z |\n\ +\\_|5, 0 | | |\n\ + \\ | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ucode_str = \ +"""\ +╭─╮ 1, 10 ⎛1, 1, 1, 1, 1, 1, 1, 1, 1, 1 1 │ ⎞\n\ +│╶┐ ⎜ │ z⎟\n\ +╰─╯11, 2 ⎝ 1 1 │ ⎠\ +""" + ascii_str = \ +"""\ + __ 1, 10 /1, 1, 1, 1, 1, 1, 1, 1, 1, 1 1 | \\\n\ +/__ | | z|\n\ +\\_|11, 2 \\ 1 1 | /\ +""" + + expr = meijerg([1]*10, [1], [1], [1], z) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = meijerg([1, 2, ], [4, 3], [3], [4, 5], 1/(1/(1/(1/x + 1) + 1) + 1)) + + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ + ⎜ │ ─────────────⎟\n\ + ⎜ │ 1 ⎟\n\ +╭─╮1, 2 ⎜1, 2 3, 4 │ 1 + ─────────⎟\n\ +│╶┐ ⎜ │ 1 ⎟\n\ +╰─╯4, 3 ⎜ 3 4, 5 │ 1 + ─────⎟\n\ + ⎜ │ 1⎟\n\ + ⎜ │ 1 + ─⎟\n\ + ⎝ │ x⎠\ +""" + + ascii_str = \ +"""\ + / | 1 \\\n\ + | | -------------|\n\ + | | 1 |\n\ + __1, 2 |1, 2 3, 4 | 1 + ---------|\n\ +/__ | | 1 |\n\ +\\_|4, 3 | 3 4, 5 | 1 + -----|\n\ + | | 1|\n\ + | | 1 + -|\n\ + \\ | x/\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(expr, x) + + ucode_str = \ +"""\ +⌠ \n\ +⎮ ⎛ │ 1 ⎞ \n\ +⎮ ⎜ │ ─────────────⎟ \n\ +⎮ ⎜ │ 1 ⎟ \n\ +⎮ ╭─╮1, 2 ⎜1, 2 3, 4 │ 1 + ─────────⎟ \n\ +⎮ │╶┐ ⎜ │ 1 ⎟ dx\n\ +⎮ ╰─╯4, 3 ⎜ 3 4, 5 │ 1 + ─────⎟ \n\ +⎮ ⎜ │ 1⎟ \n\ +⎮ ⎜ │ 1 + ─⎟ \n\ +⎮ ⎝ │ x⎠ \n\ +⌡ \ +""" + + ascii_str = \ +"""\ + / \n\ + | \n\ + | / | 1 \\ \n\ + | | | -------------| \n\ + | | | 1 | \n\ + | __1, 2 |1, 2 3, 4 | 1 + ---------| \n\ + | /__ | | 1 | dx\n\ + | \\_|4, 3 | 3 4, 5 | 1 + -----| \n\ + | | | 1| \n\ + | | | 1 + -| \n\ + | \\ | x/ \n\ + | \n\ +/ \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + expr = A*B*C**-1 + ascii_str = \ +"""\ + -1\n\ +A*B*C \ +""" + ucode_str = \ +"""\ + -1\n\ +A⋅B⋅C \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = C**-1*A*B + ascii_str = \ +"""\ + -1 \n\ +C *A*B\ +""" + ucode_str = \ +"""\ + -1 \n\ +C ⋅A⋅B\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A*C**-1*B + ascii_str = \ +"""\ + -1 \n\ +A*C *B\ +""" + ucode_str = \ +"""\ + -1 \n\ +A⋅C ⋅B\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A*C**-1*B/x + ascii_str = \ +"""\ + -1 \n\ +A*C *B\n\ +-------\n\ + x \ +""" + ucode_str = \ +"""\ + -1 \n\ +A⋅C ⋅B\n\ +───────\n\ + x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_special_functions(): + x, y = symbols("x y") + + # atan2 + expr = atan2(y/sqrt(200), sqrt(x)) + ascii_str = \ +"""\ + / ___ \\\n\ + |\\/ 2 *y ___|\n\ +atan2|-------, \\/ x |\n\ + \\ 20 /\ +""" + ucode_str = \ +"""\ + ⎛√2⋅y ⎞\n\ +atan2⎜────, √x⎟\n\ + ⎝ 20 ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_geometry(): + e = Segment((0, 1), (0, 2)) + assert pretty(e) == 'Segment2D(Point2D(0, 1), Point2D(0, 2))' + e = Ray((1, 1), angle=4.02*pi) + assert pretty(e) == 'Ray2D(Point2D(1, 1), Point2D(2, tan(pi/50) + 1))' + + +def test_expint(): + expr = Ei(x) + string = 'Ei(x)' + assert pretty(expr) == string + assert upretty(expr) == string + + expr = expint(1, z) + ucode_str = "E₁(z)" + ascii_str = "expint(1, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + assert pretty(Shi(x)) == 'Shi(x)' + assert pretty(Si(x)) == 'Si(x)' + assert pretty(Ci(x)) == 'Ci(x)' + assert pretty(Chi(x)) == 'Chi(x)' + assert upretty(Shi(x)) == 'Shi(x)' + assert upretty(Si(x)) == 'Si(x)' + assert upretty(Ci(x)) == 'Ci(x)' + assert upretty(Chi(x)) == 'Chi(x)' + + +def test_elliptic_functions(): + ascii_str = \ +"""\ + / 1 \\\n\ +K|-----|\n\ + \\z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ 1 ⎞\n\ +K⎜─────⎟\n\ + ⎝z + 1⎠\ +""" + expr = elliptic_k(1/(z + 1)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / | 1 \\\n\ +F|1|-----|\n\ + \\ |z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ +F⎜1│─────⎟\n\ + ⎝ │z + 1⎠\ +""" + expr = elliptic_f(1, 1/(1 + z)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / 1 \\\n\ +E|-----|\n\ + \\z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ 1 ⎞\n\ +E⎜─────⎟\n\ + ⎝z + 1⎠\ +""" + expr = elliptic_e(1/(z + 1)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / | 1 \\\n\ +E|1|-----|\n\ + \\ |z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ +E⎜1│─────⎟\n\ + ⎝ │z + 1⎠\ +""" + expr = elliptic_e(1, 1/(1 + z)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / |4\\\n\ +Pi|3|-|\n\ + \\ |x/\ +""" + ucode_str = \ +"""\ + ⎛ │4⎞\n\ +Π⎜3│─⎟\n\ + ⎝ │x⎠\ +""" + expr = elliptic_pi(3, 4/x) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / 4| \\\n\ +Pi|3; -|6|\n\ + \\ x| /\ +""" + ucode_str = \ +"""\ + ⎛ 4│ ⎞\n\ +Π⎜3; ─│6⎟\n\ + ⎝ x│ ⎠\ +""" + expr = elliptic_pi(3, 4/x, 6) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert upretty(where(X > 0)) == "Domain: 0 < x₁ ∧ x₁ < ∞" + + D = Die('d1', 6) + assert upretty(where(D > 4)) == 'Domain: d₁ = 5 ∨ d₁ = 6' + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert upretty(pspace(Tuple(A, B)).domain) == \ + 'Domain: 0 ≤ a ∧ 0 ≤ b ∧ a < ∞ ∧ b < ∞' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ.poly_ring(x, y) + + expr = F.convert(x/(x + y)) + assert pretty(expr) == "x/(x + y)" + assert upretty(expr) == "x/(x + y)" + + expr = R.convert(x + y) + assert pretty(expr) == "x + y" + assert upretty(expr) == "x + y" + + +def test_issue_6285(): + assert pretty(Pow(2, -5, evaluate=False)) == '1 \n--\n 5\n2 ' + assert pretty(Pow(x, (1/pi))) == \ + ' 1 \n'\ + ' --\n'\ + ' pi\n'\ + 'x ' + + +def test_issue_6359(): + assert pretty(Integral(x**2, x)**2) == \ +"""\ + 2 +/ / \\ \n\ +| | | \n\ +| | 2 | \n\ +| | x dx| \n\ +| | | \n\ +\\/ / \ +""" + assert upretty(Integral(x**2, x)**2) == \ +"""\ + 2 +⎛⌠ ⎞ \n\ +⎜⎮ 2 ⎟ \n\ +⎜⎮ x dx⎟ \n\ +⎝⌡ ⎠ \ +""" + + assert pretty(Sum(x**2, (x, 0, 1))**2) == \ +"""\ + 2\n\ +/ 1 \\ \n\ +|___ | \n\ +|\\ ` | \n\ +| \\ 2| \n\ +| / x | \n\ +|/__, | \n\ +\\x = 0 / \ +""" + assert upretty(Sum(x**2, (x, 0, 1))**2) == \ +"""\ + 2 +⎛ 1 ⎞ \n\ +⎜ ___ ⎟ \n\ +⎜ ╲ ⎟ \n\ +⎜ ╲ 2⎟ \n\ +⎜ ╱ x ⎟ \n\ +⎜ ╱ ⎟ \n\ +⎜ ‾‾‾ ⎟ \n\ +⎝x = 0 ⎠ \ +""" + + assert pretty(Product(x**2, (x, 1, 2))**2) == \ +"""\ + 2 +/ 2 \\ \n\ +|______ | \n\ +| | | 2| \n\ +| | | x | \n\ +| | | | \n\ +\\x = 1 / \ +""" + assert upretty(Product(x**2, (x, 1, 2))**2) == \ +"""\ + 2 +⎛ 2 ⎞ \n\ +⎜─┬──┬─ ⎟ \n\ +⎜ │ │ 2⎟ \n\ +⎜ │ │ x ⎟ \n\ +⎜ │ │ ⎟ \n\ +⎝x = 1 ⎠ \ +""" + + f = Function('f') + assert pretty(Derivative(f(x), x)**2) == \ +"""\ + 2 +/d \\ \n\ +|--(f(x))| \n\ +\\dx / \ +""" + assert upretty(Derivative(f(x), x)**2) == \ +"""\ + 2 +⎛d ⎞ \n\ +⎜──(f(x))⎟ \n\ +⎝dx ⎠ \ +""" + + +def test_issue_6739(): + ascii_str = \ +"""\ + 1 \n\ +-----\n\ + ___\n\ +\\/ x \ +""" + ucode_str = \ +"""\ +1 \n\ +──\n\ +√x\ +""" + assert pretty(1/sqrt(x)) == ascii_str + assert upretty(1/sqrt(x)) == ucode_str + + +def test_complicated_symbol_unchanged(): + for symb_name in ["dexpr2_d1tau", "dexpr2^d1tau"]: + assert pretty(Symbol(symb_name)) == symb_name + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert pretty(A1) == "A1" + assert upretty(A1) == "A₁" + + assert pretty(f1) == "f1:A1-->A2" + assert upretty(f1) == "f₁:A₁——▶A₂" + assert pretty(id_A1) == "id:A1-->A1" + assert upretty(id_A1) == "id:A₁——▶A₁" + + assert pretty(f2*f1) == "f2*f1:A1-->A3" + assert upretty(f2*f1) == "f₂∘f₁:A₁——▶A₃" + + assert pretty(K1) == "K1" + assert upretty(K1) == "K₁" + + # Test how diagrams are printed. + d = Diagram() + assert pretty(d) == "EmptySet" + assert upretty(d) == "∅" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert pretty(d) == "{f2*f1:A1-->A3: EmptySet, id:A1-->A1: " \ + "EmptySet, id:A2-->A2: EmptySet, id:A3-->A3: " \ + "EmptySet, f1:A1-->A2: {unique}, f2:A2-->A3: EmptySet}" + + assert upretty(d) == "{f₂∘f₁:A₁——▶A₃: ∅, id:A₁——▶A₁: ∅, " \ + "id:A₂——▶A₂: ∅, id:A₃——▶A₃: ∅, f₁:A₁——▶A₂: {unique}, f₂:A₂——▶A₃: ∅}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert pretty(d) == "{f2*f1:A1-->A3: EmptySet, id:A1-->A1: " \ + "EmptySet, id:A2-->A2: EmptySet, id:A3-->A3: " \ + "EmptySet, f1:A1-->A2: {unique}, f2:A2-->A3: EmptySet}" \ + " ==> {f2*f1:A1-->A3: {unique}}" + assert upretty(d) == "{f₂∘f₁:A₁——▶A₃: ∅, id:A₁——▶A₁: ∅, id:A₂——▶A₂: " \ + "∅, id:A₃——▶A₃: ∅, f₁:A₁——▶A₂: {unique}, f₂:A₂——▶A₃: ∅}" \ + " ══▶ {f₂∘f₁:A₁——▶A₃: {unique}}" + + grid = DiagramGrid(d) + assert pretty(grid) == "A1 A2\n \nA3 " + assert upretty(grid) == "A₁ A₂\n \nA₃ " + + +def test_PrettyModules(): + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + ucode_str = \ +"""\ + 2\n\ +ℚ[x, y] \ +""" + ascii_str = \ +"""\ + 2\n\ +QQ[x, y] \ +""" + + assert upretty(F) == ucode_str + assert pretty(F) == ascii_str + + ucode_str = \ +"""\ +╱ ⎡ 2⎤╲\n\ +╲[x, y], ⎣1, x ⎦╱\ +""" + ascii_str = \ +"""\ + 2 \n\ +<[x, y], [1, x ]>\ +""" + + assert upretty(M) == ucode_str + assert pretty(M) == ascii_str + + I = R.ideal(x**2, y) + + ucode_str = \ +"""\ +╱ 2 ╲\n\ +╲x , y╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ +\ +""" + + assert upretty(I) == ucode_str + assert pretty(I) == ascii_str + + Q = F / M + + ucode_str = \ +"""\ + 2 \n\ + ℚ[x, y] \n\ +─────────────────\n\ +╱ ⎡ 2⎤╲\n\ +╲[x, y], ⎣1, x ⎦╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ + QQ[x, y] \n\ +-----------------\n\ + 2 \n\ +<[x, y], [1, x ]>\ +""" + + assert upretty(Q) == ucode_str + assert pretty(Q) == ascii_str + + ucode_str = \ +"""\ +╱⎡ 3⎤ ╲\n\ +│⎢ x ⎥ ╱ ⎡ 2⎤╲ ╱ ⎡ 2⎤╲│\n\ +│⎢1, ──⎥ + ╲[x, y], ⎣1, x ⎦╱, [2, y] + ╲[x, y], ⎣1, x ⎦╱│\n\ +╲⎣ 2 ⎦ ╱\ +""" + + ascii_str = \ +"""\ + 3 \n\ + x 2 2 \n\ +<[1, --] + <[x, y], [1, x ]>, [2, y] + <[x, y], [1, x ]>>\n\ + 2 \ +""" + + +def test_QuotientRing(): + R = QQ.old_poly_ring(x)/[x**2 + 1] + + ucode_str = \ +"""\ + ℚ[x] \n\ +────────\n\ +╱ 2 ╲\n\ +╲x + 1╱\ +""" + + ascii_str = \ +"""\ + QQ[x] \n\ +--------\n\ + 2 \n\ +\ +""" + + assert upretty(R) == ucode_str + assert pretty(R) == ascii_str + + ucode_str = \ +"""\ + ╱ 2 ╲\n\ +1 + ╲x + 1╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ +1 + \ +""" + + assert upretty(R.one) == ucode_str + assert pretty(R.one) == ascii_str + + +def test_Homomorphism(): + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x) + + expr = homomorphism(R.free_module(1), R.free_module(1), [0]) + + ucode_str = \ +"""\ + 1 1\n\ +[0] : ℚ[x] ──> ℚ[x] \ +""" + + ascii_str = \ +"""\ + 1 1\n\ +[0] : QQ[x] --> QQ[x] \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + expr = homomorphism(R.free_module(2), R.free_module(2), [0, 0]) + + ucode_str = \ +"""\ +⎡0 0⎤ 2 2\n\ +⎢ ⎥ : ℚ[x] ──> ℚ[x] \n\ +⎣0 0⎦ \ +""" + + ascii_str = \ +"""\ +[0 0] 2 2\n\ +[ ] : QQ[x] --> QQ[x] \n\ +[0 0] \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + expr = homomorphism(R.free_module(1), R.free_module(1) / [[x]], [0]) + + ucode_str = \ +"""\ + 1\n\ + 1 ℚ[x] \n\ +[0] : ℚ[x] ──> ─────\n\ + <[x]>\ +""" + + ascii_str = \ +"""\ + 1\n\ + 1 QQ[x] \n\ +[0] : QQ[x] --> ------\n\ + <[x]> \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert pretty(t) == r'Tr(A*B)' + assert upretty(t) == 'Tr(A⋅B)' + + +def test_pretty_Add(): + eq = Mul(-2, x - 2, evaluate=False) + 5 + assert pretty(eq) == '5 - 2*(x - 2)' + + +def test_issue_7179(): + assert upretty(Not(Equivalent(x, y))) == 'x ⇎ y' + assert upretty(Not(Implies(x, y))) == 'x ↛ y' + + +def test_issue_7180(): + assert upretty(Equivalent(x, y)) == 'x ⇔ y' + + +def test_pretty_Complement(): + assert pretty(S.Reals - S.Naturals) == '(-oo, oo) \\ Naturals' + assert upretty(S.Reals - S.Naturals) == 'ℝ \\ ℕ' + assert pretty(S.Reals - S.Naturals0) == '(-oo, oo) \\ Naturals0' + assert upretty(S.Reals - S.Naturals0) == 'ℝ \\ ℕ₀' + + +def test_pretty_SymmetricDifference(): + from sympy.sets.sets import SymmetricDifference + assert upretty(SymmetricDifference(Interval(2,3), Interval(3,5), \ + evaluate = False)) == '[2, 3] ∆ [3, 5]' + with raises(NotImplementedError): + pretty(SymmetricDifference(Interval(2,3), Interval(3,5), evaluate = False)) + + +def test_pretty_Contains(): + assert pretty(Contains(x, S.Integers)) == 'Contains(x, Integers)' + assert upretty(Contains(x, S.Integers)) == 'x ∈ ℤ' + + +def test_issue_8292(): + from sympy.core import sympify + e = sympify('((x+x**4)/(x-1))-(2*(x-1)**4/(x-1)**4)', evaluate=False) + ucode_str = \ +"""\ + 4 4 \n\ + 2⋅(x - 1) x + x\n\ +- ────────── + ──────\n\ + 4 x - 1 \n\ + (x - 1) \ +""" + ascii_str = \ +"""\ + 4 4 \n\ + 2*(x - 1) x + x\n\ +- ---------- + ------\n\ + 4 x - 1 \n\ + (x - 1) \ +""" + assert pretty(e) == ascii_str + assert upretty(e) == ucode_str + + +def test_issue_4335(): + y = Function('y') + expr = -y(x).diff(x) + ucode_str = \ +"""\ + d \n\ +-──(y(x))\n\ + dx \ +""" + ascii_str = \ +"""\ + d \n\ +- --(y(x))\n\ + dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_8344(): + from sympy.core import sympify + e = sympify('2*x*y**2/1**2 + 1', evaluate=False) + ucode_str = \ +"""\ + 2 \n\ +2⋅x⋅y \n\ +────── + 1\n\ + 2 \n\ + 1 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_6324(): + x = Pow(2, 3, evaluate=False) + y = Pow(10, -2, evaluate=False) + e = Mul(x, y, evaluate=False) + ucode_str = \ +"""\ + 3 \n\ +2 \n\ +───\n\ + 2\n\ +10 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_7927(): + e = sin(x/2)**cos(x/2) + ucode_str = \ +"""\ + ⎛x⎞\n\ + cos⎜─⎟\n\ + ⎝2⎠\n\ +⎛ ⎛x⎞⎞ \n\ +⎜sin⎜─⎟⎟ \n\ +⎝ ⎝2⎠⎠ \ +""" + assert upretty(e) == ucode_str + e = sin(x)**(S(11)/13) + ucode_str = \ +"""\ + 11\n\ + ──\n\ + 13\n\ +(sin(x)) \ +""" + assert upretty(e) == ucode_str + + +def test_issue_6134(): + from sympy.abc import lamda, t + phi = Function('phi') + + e = lamda*x*Integral(phi(t)*pi*sin(pi*t), (t, 0, 1)) + lamda*x**2*Integral(phi(t)*2*pi*sin(2*pi*t), (t, 0, 1)) + ucode_str = \ +"""\ + 1 1 \n\ + 2 ⌠ ⌠ \n\ +λ⋅x ⋅⎮ 2⋅π⋅φ(t)⋅sin(2⋅π⋅t) dt + λ⋅x⋅⎮ π⋅φ(t)⋅sin(π⋅t) dt\n\ + ⌡ ⌡ \n\ + 0 0 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_9877(): + ucode_str1 = '(2, 3) ∪ ([1, 2] \\ {x})' + a, b, c = Interval(2, 3, True, True), Interval(1, 2), FiniteSet(x) + assert upretty(Union(a, Complement(b, c))) == ucode_str1 + + ucode_str2 = '{x} ∩ {y} ∩ ({z} \\ [1, 2])' + d, e, f, g = FiniteSet(x), FiniteSet(y), FiniteSet(z), Interval(1, 2) + assert upretty(Intersection(d, e, Complement(f, g))) == ucode_str2 + + +def test_issue_13651(): + expr1 = c + Mul(-1, a + b, evaluate=False) + assert pretty(expr1) == 'c - (a + b)' + expr2 = c + Mul(-1, a - b + d, evaluate=False) + assert pretty(expr2) == 'c - (a - b + d)' + + +def test_pretty_primenu(): + from sympy.functions.combinatorial.numbers import primenu + + ascii_str1 = "nu(n)" + ucode_str1 = "ν(n)" + + n = symbols('n', integer=True) + assert pretty(primenu(n)) == ascii_str1 + assert upretty(primenu(n)) == ucode_str1 + + +def test_pretty_primeomega(): + from sympy.functions.combinatorial.numbers import primeomega + + ascii_str1 = "Omega(n)" + ucode_str1 = "Ω(n)" + + n = symbols('n', integer=True) + assert pretty(primeomega(n)) == ascii_str1 + assert upretty(primeomega(n)) == ucode_str1 + + +def test_pretty_Mod(): + from sympy.core import Mod + + ascii_str1 = "x mod 7" + ucode_str1 = "x mod 7" + + ascii_str2 = "(x + 1) mod 7" + ucode_str2 = "(x + 1) mod 7" + + ascii_str3 = "2*x mod 7" + ucode_str3 = "2⋅x mod 7" + + ascii_str4 = "(x mod 7) + 1" + ucode_str4 = "(x mod 7) + 1" + + ascii_str5 = "2*(x mod 7)" + ucode_str5 = "2⋅(x mod 7)" + + x = symbols('x', integer=True) + assert pretty(Mod(x, 7)) == ascii_str1 + assert upretty(Mod(x, 7)) == ucode_str1 + assert pretty(Mod(x + 1, 7)) == ascii_str2 + assert upretty(Mod(x + 1, 7)) == ucode_str2 + assert pretty(Mod(2 * x, 7)) == ascii_str3 + assert upretty(Mod(2 * x, 7)) == ucode_str3 + assert pretty(Mod(x, 7) + 1) == ascii_str4 + assert upretty(Mod(x, 7) + 1) == ucode_str4 + assert pretty(2 * Mod(x, 7)) == ascii_str5 + assert upretty(2 * Mod(x, 7)) == ucode_str5 + + +def test_issue_11801(): + assert pretty(Symbol("")) == "" + assert upretty(Symbol("")) == "" + + +def test_pretty_UnevaluatedExpr(): + x = symbols('x') + he = UnevaluatedExpr(1/x) + + ucode_str = \ +"""\ +1\n\ +─\n\ +x\ +""" + + assert upretty(he) == ucode_str + + ucode_str = \ +"""\ + 2\n\ +⎛1⎞ \n\ +⎜─⎟ \n\ +⎝x⎠ \ +""" + + assert upretty(he**2) == ucode_str + + ucode_str = \ +"""\ + 1\n\ +1 + ─\n\ + x\ +""" + + assert upretty(he + 1) == ucode_str + + ucode_str = \ +('''\ + 1\n\ +x⋅─\n\ + x\ +''') + assert upretty(x*he) == ucode_str + + +def test_issue_10472(): + M = (Matrix([[0, 0], [0, 0]]), Matrix([0, 0])) + + ucode_str = \ +"""\ +⎛⎡0 0⎤ ⎡0⎤⎞ +⎜⎢ ⎥, ⎢ ⎥⎟ +⎝⎣0 0⎦ ⎣0⎦⎠\ +""" + assert upretty(M) == ucode_str + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + ascii_str1 = "A_00" + ucode_str1 = "A₀₀" + assert pretty(A[0, 0]) == ascii_str1 + assert upretty(A[0, 0]) == ucode_str1 + + ascii_str1 = "3*A_00" + ucode_str1 = "3⋅A₀₀" + assert pretty(3*A[0, 0]) == ascii_str1 + assert upretty(3*A[0, 0]) == ucode_str1 + + ascii_str1 = "(-B + A)[0, 0]" + ucode_str1 = "(-B + A)[0, 0]" + F = C[0, 0].subs(C, A - B) + assert pretty(F) == ascii_str1 + assert upretty(F) == ucode_str1 + + +def test_issue_12675(): + x, y, t, j = symbols('x y t j') + e = CoordSys3D('e') + + ucode_str = \ +"""\ +⎛ t⎞ \n\ +⎜⎛x⎞ ⎟ j_e\n\ +⎜⎜─⎟ ⎟ \n\ +⎝⎝y⎠ ⎠ \ +""" + assert upretty((x/y)**t*e.j) == ucode_str + ucode_str = \ +"""\ +⎛1⎞ \n\ +⎜─⎟ j_e\n\ +⎝y⎠ \ +""" + assert upretty((1/y)*e.j) == ucode_str + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + assert pretty(-A*B*C) == "-A*B*C" + assert pretty(A - B) == "-B + A" + assert pretty(A*B*C - A*B - B*C) == "-A*B -B*C + A*B*C" + + # issue #14814 + x = MatrixSymbol('x', n, n) + y = MatrixSymbol('y*', n, n) + assert pretty(x + y) == "x + y*" + ascii_str = \ +"""\ + 2 \n\ +-2*y* -a*x\ +""" + assert pretty(-a*x + -2*y*y) == ascii_str + + +def test_degree_printing(): + expr1 = 90*degree + assert pretty(expr1) == '90°' + expr2 = x*degree + assert pretty(expr2) == 'x°' + expr3 = cos(x*degree + 90*degree) + assert pretty(expr3) == 'cos(x° + 90°)' + + +def test_vector_expr_pretty_printing(): + A = CoordSys3D('A') + + assert upretty(Cross(A.i, A.x*A.i+3*A.y*A.j)) == "(i_A)×((x_A) i_A + (3⋅y_A) j_A)" + assert upretty(x*Cross(A.i, A.j)) == 'x⋅(i_A)×(j_A)' + + assert upretty(Curl(A.x*A.i + 3*A.y*A.j)) == "∇×((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Divergence(A.x*A.i + 3*A.y*A.j)) == "∇⋅((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Dot(A.i, A.x*A.i+3*A.y*A.j)) == "(i_A)⋅((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Gradient(A.x+3*A.y)) == "∇(x_A + 3⋅y_A)" + assert upretty(Laplacian(A.x+3*A.y)) == "∆(x_A + 3⋅y_A)" + # TODO: add support for ASCII pretty. + + +def test_pretty_print_tensor_expr(): + L = TensorIndexType("L") + i, j, k = tensor_indices("i j k", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + + expr = -i + ascii_str = \ +"""\ +-i\ +""" + ucode_str = \ +"""\ +-i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i) + ascii_str = \ +"""\ + i\n\ +A \n\ + \ +""" + ucode_str = \ +"""\ + i\n\ +A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i0) + ascii_str = \ +"""\ + i_0\n\ +A \n\ + \ +""" + ucode_str = \ +"""\ + i₀\n\ +A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(-i) + ascii_str = \ +"""\ + \n\ +A \n\ + i\ +""" + ucode_str = \ +"""\ + \n\ +A \n\ + i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -3*A(-i) + ascii_str = \ +"""\ + \n\ +-3*A \n\ + i\ +""" + ucode_str = \ +"""\ + \n\ +-3⋅A \n\ + i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -j) + ascii_str = \ +"""\ + i \n\ +H \n\ + j\ +""" + ucode_str = \ +"""\ + i \n\ +H \n\ + j\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -i) + ascii_str = \ +"""\ + L_0 \n\ +H \n\ + L_0\ +""" + ucode_str = \ +"""\ + L₀ \n\ +H \n\ + L₀\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -j)*A(j)*B(k) + ascii_str = \ +"""\ + i L_0 k\n\ +H *A *B \n\ + L_0 \ +""" + ucode_str = \ +"""\ + i L₀ k\n\ +H ⋅A ⋅B \n\ + L₀ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (1+x)*A(i) + ascii_str = \ +"""\ + i\n\ +(x + 1)*A \n\ + \ +""" + ucode_str = \ +"""\ + i\n\ +(x + 1)⋅A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i) + 3*B(i) + ascii_str = \ +"""\ + i i\n\ +3*B + A \n\ + \ +""" + ucode_str = \ +"""\ + i i\n\ +3⋅B + A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_print_tensor_partial_deriv(): + from sympy.tensor.toperators import PartialDerivative + + L = TensorIndexType("L") + i, j, k = tensor_indices("i j k", L) + + A, B, C, D = tensor_heads("A B C D", [L]) + + H = TensorHead("H", [L, L]) + + expr = PartialDerivative(A(i), A(j)) + ascii_str = \ +"""\ + d / i\\\n\ +---|A |\n\ + j\\ /\n\ +dA \n\ + \ +""" + ucode_str = \ +"""\ + ∂ ⎛ i⎞\n\ +───⎜A ⎟\n\ + j⎝ ⎠\n\ +∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i)*PartialDerivative(H(k, -i), A(j)) + ascii_str = \ +"""\ + L_0 d / k \\\n\ +A *---|H |\n\ + j\\ L_0/\n\ + dA \n\ + \ +""" + ucode_str = \ +"""\ + L₀ ∂ ⎛ k ⎞\n\ +A ⋅───⎜H ⎟\n\ + j⎝ L₀⎠\n\ + ∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i)*PartialDerivative(B(k)*C(-i) + 3*H(k, -i), A(j)) + ascii_str = \ +"""\ + L_0 d / k k \\\n\ +A *---|3*H + B *C |\n\ + j\\ L_0 L_0/\n\ + dA \n\ + \ +""" + ucode_str = \ +"""\ + L₀ ∂ ⎛ k k ⎞\n\ +A ⋅───⎜3⋅H + B ⋅C ⎟\n\ + j⎝ L₀ L₀⎠\n\ + ∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (A(i) + B(i))*PartialDerivative(C(j), D(j)) + ascii_str = \ +"""\ +/ i i\\ d / L_0\\\n\ +|A + B |*-----|C |\n\ +\\ / L_0\\ /\n\ + dD \n\ + \ +""" + ucode_str = \ +"""\ +⎛ i i⎞ ∂ ⎛ L₀⎞\n\ +⎜A + B ⎟⋅────⎜C ⎟\n\ +⎝ ⎠ L₀⎝ ⎠\n\ + ∂D \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (A(i) + B(i))*PartialDerivative(C(-i), D(j)) + ascii_str = \ +"""\ +/ L_0 L_0\\ d / \\\n\ +|A + B |*---|C |\n\ +\\ / j\\ L_0/\n\ + dD \n\ + \ +""" + ucode_str = \ +"""\ +⎛ L₀ L₀⎞ ∂ ⎛ ⎞\n\ +⎜A + B ⎟⋅───⎜C ⎟\n\ +⎝ ⎠ j⎝ L₀⎠\n\ + ∂D \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + ucode_str = """\ + 2 \n\ + ∂ ⎛ ⎞\n\ +───────⎜A + B ⎟\n\ + ⎝ i i⎠\n\ +∂A ∂A \n\ + n j \ +""" + assert upretty(expr) == ucode_str + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + ucode_str = """\ + 2 \n\ + ∂ ⎛ ⎞\n\ +───────⎜3⋅A ⎟\n\ + ⎝ i⎠\n\ +∂A ∂A \n\ + n j \ +""" + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {i:1}) + ascii_str = \ +"""\ + i=1,j\n\ +H \n\ + \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {i: 1, j: 1}) + ascii_str = \ +"""\ + i=1,j=1\n\ +H \n\ + \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {j: 1}) + ascii_str = \ +"""\ + i,j=1\n\ +H \n\ + \ +""" + ucode_str = ascii_str + + expr = TensorElement(H(-i, j), {-i: 1}) + ascii_str = \ +"""\ + j\n\ +H \n\ + i=1 \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_15560(): + a = MatrixSymbol('a', 1, 1) + e = pretty(a*(KroneckerProduct(a, a))) + result = 'a*(a x a)' + assert e == result + + +def test_print_polylog(): + # Part of issue 6013 + uresult = 'Li₂(3)' + aresult = 'polylog(2, 3)' + assert pretty(polylog(2, 3)) == aresult + assert upretty(polylog(2, 3)) == uresult + + +# Issue #25312 +def test_print_expint_polylog_symbolic_order(): + s, z = symbols("s, z") + uresult = 'Liₛ(z)' + aresult = 'polylog(s, z)' + assert pretty(polylog(s, z)) == aresult + assert upretty(polylog(s, z)) == uresult + # TODO: TBD polylog(s - 1, z) + uresult = 'Eₛ(z)' + aresult = 'expint(s, z)' + assert pretty(expint(s, z)) == aresult + assert upretty(expint(s, z)) == uresult + + + +def test_print_polylog_long_order_issue_25309(): + s, z = symbols("s, z") + ucode_str = \ +"""\ + ⎛ 2 ⎞\n\ +polylog⎝s , z⎠\ +""" + assert upretty(polylog(s**2, z)) == ucode_str + + +def test_print_lerchphi(): + # Part of issue 6013 + a = Symbol('a') + pretty(lerchphi(a, 1, 2)) + uresult = 'Φ(a, 1, 2)' + aresult = 'lerchphi(a, 1, 2)' + assert pretty(lerchphi(a, 1, 2)) == aresult + assert upretty(lerchphi(a, 1, 2)) == uresult + + +def test_issue_15583(): + + N = mechanics.ReferenceFrame('N') + result = '(n_x, n_y, n_z)' + e = pretty((N.x, N.y, N.z)) + assert e == result + + +def test_matrixSymbolBold(): + # Issue 15871 + def boldpretty(expr): + return xpretty(expr, use_unicode=True, wrap_line=False, mat_symbol_style="bold") + + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert boldpretty(trace(A)) == 'tr(𝐀)' + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert boldpretty(-A) == '-𝐀' + assert boldpretty(A - A*B - B) == '-𝐁 -𝐀⋅𝐁 + 𝐀' + assert boldpretty(-A*B - A*B*C - B) == '-𝐁 -𝐀⋅𝐁 -𝐀⋅𝐁⋅𝐂' + + A = MatrixSymbol("Addot", 3, 3) + assert boldpretty(A) == '𝐀̈' + omega = MatrixSymbol("omega", 3, 3) + assert boldpretty(omega) == 'ω' + omega = MatrixSymbol("omeganorm", 3, 3) + assert boldpretty(omega) == '‖ω‖' + + a = Symbol('alpha') + b = Symbol('b') + c = MatrixSymbol("c", 3, 1) + d = MatrixSymbol("d", 3, 1) + + assert boldpretty(a*B*c+b*d) == 'b⋅𝐝 + α⋅𝐁⋅𝐜' + + d = MatrixSymbol("delta", 3, 1) + B = MatrixSymbol("Beta", 3, 3) + + assert boldpretty(a*B*c+b*d) == 'b⋅δ + α⋅Β⋅𝐜' + + A = MatrixSymbol("A_2", 3, 3) + assert boldpretty(A) == '𝐀₂' + + +def test_center_accent(): + assert center_accent('a', '\N{COMBINING TILDE}') == 'ã' + assert center_accent('aa', '\N{COMBINING TILDE}') == 'aã' + assert center_accent('aaa', '\N{COMBINING TILDE}') == 'aãa' + assert center_accent('aaaa', '\N{COMBINING TILDE}') == 'aaãa' + assert center_accent('aaaaa', '\N{COMBINING TILDE}') == 'aaãaa' + assert center_accent('abcdefg', '\N{COMBINING FOUR DOTS ABOVE}') == 'abcd⃜efg' + + +def test_imaginary_unit(): + from sympy.printing.pretty import pretty # b/c it was redefined above + assert pretty(1 + I, use_unicode=False) == '1 + I' + assert pretty(1 + I, use_unicode=True) == '1 + ⅈ' + assert pretty(1 + I, use_unicode=False, imaginary_unit='j') == '1 + I' + assert pretty(1 + I, use_unicode=True, imaginary_unit='j') == '1 + ⅉ' + + raises(TypeError, lambda: pretty(I, imaginary_unit=I)) + raises(ValueError, lambda: pretty(I, imaginary_unit="kkk")) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert pretty(Identity(4)) == 'I' + assert upretty(Identity(4)) == '𝕀' + assert pretty(ZeroMatrix(2, 2)) == '0' + assert upretty(ZeroMatrix(2, 2)) == '𝟘' + assert pretty(OneMatrix(2, 2)) == '1' + assert upretty(OneMatrix(2, 2)) == '𝟙' + + +def test_pretty_misc_functions(): + assert pretty(LambertW(x)) == 'W(x)' + assert upretty(LambertW(x)) == 'W(x)' + assert pretty(LambertW(x, y)) == 'W(x, y)' + assert upretty(LambertW(x, y)) == 'W(x, y)' + assert pretty(airyai(x)) == 'Ai(x)' + assert upretty(airyai(x)) == 'Ai(x)' + assert pretty(airybi(x)) == 'Bi(x)' + assert upretty(airybi(x)) == 'Bi(x)' + assert pretty(airyaiprime(x)) == "Ai'(x)" + assert upretty(airyaiprime(x)) == "Ai'(x)" + assert pretty(airybiprime(x)) == "Bi'(x)" + assert upretty(airybiprime(x)) == "Bi'(x)" + assert pretty(fresnelc(x)) == 'C(x)' + assert upretty(fresnelc(x)) == 'C(x)' + assert pretty(fresnels(x)) == 'S(x)' + assert upretty(fresnels(x)) == 'S(x)' + assert pretty(Heaviside(x)) == 'Heaviside(x)' + assert upretty(Heaviside(x)) == 'θ(x)' + assert pretty(Heaviside(x, y)) == 'Heaviside(x, y)' + assert upretty(Heaviside(x, y)) == 'θ(x, y)' + assert pretty(dirichlet_eta(x)) == 'dirichlet_eta(x)' + assert upretty(dirichlet_eta(x)) == 'η(x)' + + +def test_hadamard_power(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + B = MatrixSymbol('B', m, n) + + # Testing printer: + expr = hadamard_power(A, n) + ascii_str = \ +"""\ + .n\n\ +A \ +""" + ucode_str = \ +"""\ + ∘n\n\ +A \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hadamard_power(A, 1+n) + ascii_str = \ +"""\ + .(n + 1)\n\ +A \ +""" + ucode_str = \ +"""\ + ∘(n + 1)\n\ +A \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hadamard_power(A*B.T, 1+n) + ascii_str = \ +"""\ + .(n + 1)\n\ +/ T\\ \n\ +\\A*B / \ +""" + ucode_str = \ +"""\ + ∘(n + 1)\n\ +⎛ T⎞ \n\ +⎝A⋅B ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_17258(): + n = Symbol('n', integer=True) + assert pretty(Sum(n, (n, -oo, 1))) == \ + ' 1 \n'\ + ' __ \n'\ + ' \\ ` \n'\ + ' ) n\n'\ + ' /_, \n'\ + 'n = -oo ' + + assert upretty(Sum(n, (n, -oo, 1))) == \ +"""\ + 1 \n\ + ___ \n\ + ╲ \n\ + ╲ \n\ + ╱ n\n\ + ╱ \n\ + ‾‾‾ \n\ +n = -∞ \ +""" + + +def test_is_combining(): + line = "v̇_m" + assert [is_combining(sym) for sym in line] == \ + [False, True, False, False] + + +def test_issue_17616(): + assert pretty(pi**(1/exp(1))) == \ + ' / -1\\\n'\ + ' \\e /\n'\ + 'pi ' + + assert upretty(pi**(1/exp(1))) == \ + ' ⎛ -1⎞\n'\ + ' ⎝ℯ ⎠\n'\ + 'π ' + + assert pretty(pi**(1/pi)) == \ + ' 1 \n'\ + ' --\n'\ + ' pi\n'\ + 'pi ' + + assert upretty(pi**(1/pi)) == \ + ' 1\n'\ + ' ─\n'\ + ' π\n'\ + 'π ' + + assert pretty(pi**(1/EulerGamma)) == \ + ' 1 \n'\ + ' ----------\n'\ + ' EulerGamma\n'\ + 'pi ' + + assert upretty(pi**(1/EulerGamma)) == \ + ' 1\n'\ + ' ─\n'\ + ' γ\n'\ + 'π ' + + z = Symbol("x_17") + assert upretty(7**(1/z)) == \ + 'x₁₇___\n'\ + ' ╲╱ 7 ' + + assert pretty(7**(1/z)) == \ + 'x_17___\n'\ + ' \\/ 7 ' + + +def test_issue_17857(): + assert pretty(Range(-oo, oo)) == '{..., -1, 0, 1, ...}' + assert pretty(Range(oo, -oo, -1)) == '{..., 1, 0, -1, ...}' + + +def test_issue_18272(): + x = Symbol('x') + n = Symbol('n') + + assert upretty(ConditionSet(x, Eq(-x + exp(x), 0), S.Complexes)) == \ + '⎧ │ ⎛ x ⎞⎫\n'\ + '⎨x │ x ∊ ℂ ∧ ⎝-x + ℯ = 0⎠⎬\n'\ + '⎩ │ ⎭' + assert upretty(ConditionSet(x, Contains(n/2, Interval(0, oo)), FiniteSet(-n/2, n/2))) == \ + '⎧ │ ⎧-n n⎫ ⎛n ⎞⎫\n'\ + '⎨x │ x ∊ ⎨───, ─⎬ ∧ ⎜─ ∈ [0, ∞)⎟⎬\n'\ + '⎩ │ ⎩ 2 2⎭ ⎝2 ⎠⎭' + assert upretty(ConditionSet(x, Eq(Piecewise((1, x >= 3), (x/2 - 1/2, x >= 2), (1/2, x >= 1), + (x/2, True)) - 1/2, 0), Interval(0, 3))) == \ + '⎧ │ ⎛⎛⎧ 1 for x ≥ 3⎞ ⎞⎫\n'\ + '⎪ │ ⎜⎜⎪ ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪x ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪─ - 0.5 for x ≥ 2⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪2 ⎟ ⎟⎪\n'\ + '⎨x │ x ∊ [0, 3] ∧ ⎜⎜⎨ ⎟ - 0.5 = 0⎟⎬\n'\ + '⎪ │ ⎜⎜⎪ 0.5 for x ≥ 1⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ x ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ ─ otherwise⎟ ⎟⎪\n'\ + '⎩ │ ⎝⎝⎩ 2 ⎠ ⎠⎭' + + +def test_Str(): + from sympy.core.symbol import Str + assert pretty(Str('x')) == 'x' + + +def test_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert pretty(Expectation(X)) == r'E[X]' + assert pretty(Variance(X)) == r'Var(X)' + assert pretty(Probability(X > 0)) == r'P(X > 0)' + Y = Normal("Y", mu, sigma) + assert pretty(Covariance(X, Y)) == 'Cov(X, Y)' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert upretty(piecewise_fold(fo)) == \ + '⎧ 2⋅sin(3⋅x) \n'\ + '⎪2⋅sin(x) - sin(2⋅x) + ────────── + … for n > -∞ ∧ n < ∞ ∧ n ≠ 0\n'\ + '⎨ 3 \n'\ + '⎪ \n'\ + '⎩ 0 otherwise ' + assert pretty(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert pretty(m) == 'M' + p = Patch('P', m) + assert pretty(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert pretty(rect) == "rect" + b = BaseScalarField(rect, 0) + assert pretty(b) == "x" + + +def test_deprecated_prettyForm(): + with warns_deprecated_sympy(): + from sympy.printing.pretty.pretty_symbology import xstr + assert xstr(1) == '1' + + with warns_deprecated_sympy(): + from sympy.printing.pretty.stringpict import prettyForm + p = prettyForm('s', unicode='s') + + with warns_deprecated_sympy(): + assert p.unicode == p.s == 's' + + +def test_center(): + assert center('1', 2) == '1 ' + assert center('1', 3) == ' 1 ' + assert center('1', 3, '-') == '-1-' + assert center('1', 5, '-') == '--1--' diff --git a/phivenv/Lib/site-packages/sympy/printing/preview.py b/phivenv/Lib/site-packages/sympy/printing/preview.py new file mode 100644 index 0000000000000000000000000000000000000000..b04a344b5b4acc086eb84ff068bc1c6a8b55d811 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/preview.py @@ -0,0 +1,390 @@ +import os +from os.path import join +import shutil +import tempfile +from pathlib import Path + +try: + from subprocess import STDOUT, CalledProcessError, check_output +except ImportError: + pass + +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.misc import debug +from .latex import latex + +__doctest_requires__ = {('preview',): ['pyglet']} + + +def _check_output_no_window(*args, **kwargs): + # Avoid showing a cmd.exe window when running this + # on Windows + if os.name == 'nt': + creation_flag = 0x08000000 # CREATE_NO_WINDOW + else: + creation_flag = 0 # Default value + return check_output(*args, creationflags=creation_flag, **kwargs) + + +def system_default_viewer(fname, fmt): + """ Open fname with the default system viewer. + + In practice, it is impossible for python to know when the system viewer is + done. For this reason, we ensure the passed file will not be deleted under + it, and this function does not attempt to block. + """ + # copy to a new temporary file that will not be deleted + with tempfile.NamedTemporaryFile(prefix='sympy-preview-', + suffix=os.path.splitext(fname)[1], + delete=False) as temp_f: + with open(fname, 'rb') as f: + shutil.copyfileobj(f, temp_f) + + import platform + if platform.system() == 'Darwin': + import subprocess + subprocess.call(('open', temp_f.name)) + elif platform.system() == 'Windows': + os.startfile(temp_f.name) + else: + import subprocess + subprocess.call(('xdg-open', temp_f.name)) + + +def pyglet_viewer(fname, fmt): + try: + from pyglet import window, image, gl + from pyglet.window import key + from pyglet.image.codecs import ImageDecodeException + except ImportError: + raise ImportError("pyglet is required for preview.\n visit https://pyglet.org/") + + try: + img = image.load(fname) + except ImageDecodeException: + raise ValueError("pyglet preview does not work for '{}' files.".format(fmt)) + + offset = 25 + + config = gl.Config(double_buffer=False) + win = window.Window( + width=img.width + 2*offset, + height=img.height + 2*offset, + caption="SymPy", + resizable=False, + config=config + ) + + win.set_vsync(False) + + try: + def on_close(): + win.has_exit = True + + win.on_close = on_close + + def on_key_press(symbol, modifiers): + if symbol in [key.Q, key.ESCAPE]: + on_close() + + win.on_key_press = on_key_press + + def on_expose(): + gl.glClearColor(1.0, 1.0, 1.0, 1.0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + + img.blit( + (win.width - img.width) / 2, + (win.height - img.height) / 2 + ) + + win.on_expose = on_expose + + while not win.has_exit: + win.dispatch_events() + win.flip() + except KeyboardInterrupt: + pass + + win.close() + + +def _get_latex_main(expr, *, preamble=None, packages=(), extra_preamble=None, + euler=True, fontsize=None, **latex_settings): + """ + Generate string of a LaTeX document rendering ``expr``. + """ + if preamble is None: + actual_packages = packages + ("amsmath", "amsfonts") + if euler: + actual_packages += ("euler",) + package_includes = "\n" + "\n".join(["\\usepackage{%s}" % p + for p in actual_packages]) + if extra_preamble: + package_includes += extra_preamble + + if not fontsize: + fontsize = "12pt" + elif isinstance(fontsize, int): + fontsize = "{}pt".format(fontsize) + preamble = r"""\documentclass[varwidth,%s]{standalone} +%s + +\begin{document} +""" % (fontsize, package_includes) + else: + if packages or extra_preamble: + raise ValueError("The \"packages\" or \"extra_preamble\" keywords" + "must not be set if a " + "custom LaTeX preamble was specified") + + if isinstance(expr, str): + latex_string = expr + else: + latex_string = ('$\\displaystyle ' + + latex(expr, mode='plain', **latex_settings) + + '$') + + return preamble + '\n' + latex_string + '\n\n' + r"\end{document}" + + +@doctest_depends_on(exe=('latex', 'dvipng'), modules=('pyglet',), + disable_viewers=('evince', 'gimp', 'superior-dvi-viewer')) +def preview(expr, output='png', viewer=None, euler=True, packages=(), + filename=None, outputbuffer=None, preamble=None, dvioptions=None, + outputTexFile=None, extra_preamble=None, fontsize=None, + **latex_settings): + r""" + View expression or LaTeX markup in PNG, DVI, PostScript or PDF form. + + If the expr argument is an expression, it will be exported to LaTeX and + then compiled using the available TeX distribution. The first argument, + 'expr', may also be a LaTeX string. The function will then run the + appropriate viewer for the given output format or use the user defined + one. By default png output is generated. + + By default pretty Euler fonts are used for typesetting (they were used to + typeset the well known "Concrete Mathematics" book). For that to work, you + need the 'eulervm.sty' LaTeX style (in Debian/Ubuntu, install the + texlive-fonts-extra package). If you prefer default AMS fonts or your + system lacks 'eulervm' LaTeX package then unset the 'euler' keyword + argument. + + To use viewer auto-detection, lets say for 'png' output, issue + + >>> from sympy import symbols, preview, Symbol + >>> x, y = symbols("x,y") + + >>> preview(x + y, output='png') + + This will choose 'pyglet' by default. To select a different one, do + + >>> preview(x + y, output='png', viewer='gimp') + + The 'png' format is considered special. For all other formats the rules + are slightly different. As an example we will take 'dvi' output format. If + you would run + + >>> preview(x + y, output='dvi') + + then 'view' will look for available 'dvi' viewers on your system + (predefined in the function, so it will try evince, first, then kdvi and + xdvi). If nothing is found, it will fall back to using a system file + association (via ``open`` and ``xdg-open``). To always use your system file + association without searching for the above readers, use + + >>> from sympy.printing.preview import system_default_viewer + >>> preview(x + y, output='dvi', viewer=system_default_viewer) + + If this still does not find the viewer you want, it can be set explicitly. + + >>> preview(x + y, output='dvi', viewer='superior-dvi-viewer') + + This will skip auto-detection and will run user specified + 'superior-dvi-viewer'. If ``view`` fails to find it on your system it will + gracefully raise an exception. + + You may also enter ``'file'`` for the viewer argument. Doing so will cause + this function to return a file object in read-only mode, if ``filename`` + is unset. However, if it was set, then 'preview' writes the generated + file to this filename instead. + + There is also support for writing to a ``io.BytesIO`` like object, which + needs to be passed to the ``outputbuffer`` argument. + + >>> from io import BytesIO + >>> obj = BytesIO() + >>> preview(x + y, output='png', viewer='BytesIO', + ... outputbuffer=obj) + + The LaTeX preamble can be customized by setting the 'preamble' keyword + argument. This can be used, e.g., to set a different font size, use a + custom documentclass or import certain set of LaTeX packages. + + >>> preamble = "\\documentclass[10pt]{article}\n" \ + ... "\\usepackage{amsmath,amsfonts}\\begin{document}" + >>> preview(x + y, output='png', preamble=preamble) + + It is also possible to use the standard preamble and provide additional + information to the preamble using the ``extra_preamble`` keyword argument. + + >>> from sympy import sin + >>> extra_preamble = "\\renewcommand{\\sin}{\\cos}" + >>> preview(sin(x), output='png', extra_preamble=extra_preamble) + + If the value of 'output' is different from 'dvi' then command line + options can be set ('dvioptions' argument) for the execution of the + 'dvi'+output conversion tool. These options have to be in the form of a + list of strings (see ``subprocess.Popen``). + + Additional keyword args will be passed to the :func:`~sympy.printing.latex.latex` call, + e.g., the ``symbol_names`` flag. + + >>> phidd = Symbol('phidd') + >>> preview(phidd, symbol_names={phidd: r'\ddot{\varphi}'}) + + For post-processing the generated TeX File can be written to a file by + passing the desired filename to the 'outputTexFile' keyword + argument. To write the TeX code to a file named + ``"sample.tex"`` and run the default png viewer to display the resulting + bitmap, do + + >>> preview(x + y, outputTexFile="sample.tex") + + + """ + # pyglet is the default for png + if viewer is None and output == "png": + try: + import pyglet # noqa: F401 + except ImportError: + pass + else: + viewer = pyglet_viewer + + # look up a known application + if viewer is None: + # sorted in order from most pretty to most ugly + # very discussable, but indeed 'gv' looks awful :) + candidates = { + "dvi": [ "evince", "okular", "kdvi", "xdvi" ], + "ps": [ "evince", "okular", "gsview", "gv" ], + "pdf": [ "evince", "okular", "kpdf", "acroread", "xpdf", "gv" ], + } + + for candidate in candidates.get(output, []): + path = shutil.which(candidate) + if path is not None: + viewer = path + break + + # otherwise, use the system default for file association + if viewer is None: + viewer = system_default_viewer + + if viewer == "file": + if filename is None: + raise ValueError("filename has to be specified if viewer=\"file\"") + elif viewer == "BytesIO": + if outputbuffer is None: + raise ValueError("outputbuffer has to be a BytesIO " + "compatible object if viewer=\"BytesIO\"") + elif not callable(viewer) and not shutil.which(viewer): + raise OSError("Unrecognized viewer: %s" % viewer) + + latex_main = _get_latex_main(expr, preamble=preamble, packages=packages, + euler=euler, extra_preamble=extra_preamble, + fontsize=fontsize, **latex_settings) + + debug("Latex code:") + debug(latex_main) + with tempfile.TemporaryDirectory() as workdir: + Path(join(workdir, 'texput.tex')).write_text(latex_main, encoding='utf-8') + + if outputTexFile is not None: + shutil.copyfile(join(workdir, 'texput.tex'), outputTexFile) + + if not shutil.which('latex'): + raise RuntimeError("latex program is not installed") + + try: + _check_output_no_window( + ['latex', '-halt-on-error', '-interaction=nonstopmode', + 'texput.tex'], + cwd=workdir, + stderr=STDOUT) + except CalledProcessError as e: + raise RuntimeError( + "'latex' exited abnormally with the following output:\n%s" % + e.output) + + src = "texput.%s" % (output) + + if output != "dvi": + # in order of preference + commandnames = { + "ps": ["dvips"], + "pdf": ["dvipdfmx", "dvipdfm", "dvipdf"], + "png": ["dvipng"], + "svg": ["dvisvgm"], + } + try: + cmd_variants = commandnames[output] + except KeyError: + raise ValueError("Invalid output format: %s" % output) from None + + # find an appropriate command + for cmd_variant in cmd_variants: + cmd_path = shutil.which(cmd_variant) + if cmd_path: + cmd = [cmd_path] + break + else: + if len(cmd_variants) > 1: + raise RuntimeError("None of %s are installed" % ", ".join(cmd_variants)) + else: + raise RuntimeError("%s is not installed" % cmd_variants[0]) + + defaultoptions = { + "dvipng": ["-T", "tight", "-z", "9", "--truecolor"], + "dvisvgm": ["--no-fonts"], + } + + commandend = { + "dvips": ["-o", src, "texput.dvi"], + "dvipdf": ["texput.dvi", src], + "dvipdfm": ["-o", src, "texput.dvi"], + "dvipdfmx": ["-o", src, "texput.dvi"], + "dvipng": ["-o", src, "texput.dvi"], + "dvisvgm": ["-o", src, "texput.dvi"], + } + + if dvioptions is not None: + cmd.extend(dvioptions) + else: + cmd.extend(defaultoptions.get(cmd_variant, [])) + cmd.extend(commandend[cmd_variant]) + + try: + _check_output_no_window(cmd, cwd=workdir, stderr=STDOUT) + except CalledProcessError as e: + raise RuntimeError( + "'%s' exited abnormally with the following output:\n%s" % + (' '.join(cmd), e.output)) + + + if viewer == "file": + shutil.move(join(workdir, src), filename) + elif viewer == "BytesIO": + s = Path(join(workdir, src)).read_bytes() + outputbuffer.write(s) + elif callable(viewer): + viewer(join(workdir, src), fmt=output) + else: + try: + _check_output_no_window( + [viewer, src], cwd=workdir, stderr=STDOUT) + except CalledProcessError as e: + raise RuntimeError( + "'%s %s' exited abnormally with the following output:\n%s" % + (viewer, src, e.output)) diff --git a/phivenv/Lib/site-packages/sympy/printing/printer.py b/phivenv/Lib/site-packages/sympy/printing/printer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0a6970920cf0928ad330ed9a3ea4291107a29d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/printer.py @@ -0,0 +1,432 @@ +"""Printing subsystem driver + +SymPy's printing system works the following way: Any expression can be +passed to a designated Printer who then is responsible to return an +adequate representation of that expression. + +**The basic concept is the following:** + +1. Let the object print itself if it knows how. +2. Take the best fitting method defined in the printer. +3. As fall-back use the emptyPrinter method for the printer. + +Which Method is Responsible for Printing? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The whole printing process is started by calling ``.doprint(expr)`` on the printer +which you want to use. This method looks for an appropriate method which can +print the given expression in the given style that the printer defines. +While looking for the method, it follows these steps: + +1. **Let the object print itself if it knows how.** + + The printer looks for a specific method in every object. The name of that method + depends on the specific printer and is defined under ``Printer.printmethod``. + For example, StrPrinter calls ``_sympystr`` and LatexPrinter calls ``_latex``. + Look at the documentation of the printer that you want to use. + The name of the method is specified there. + + This was the original way of doing printing in sympy. Every class had + its own latex, mathml, str and repr methods, but it turned out that it + is hard to produce a high quality printer, if all the methods are spread + out that far. Therefore all printing code was combined into the different + printers, which works great for built-in SymPy objects, but not that + good for user defined classes where it is inconvenient to patch the + printers. + +2. **Take the best fitting method defined in the printer.** + + The printer loops through expr classes (class + its bases), and tries + to dispatch the work to ``_print_`` + + e.g., suppose we have the following class hierarchy:: + + Basic + | + Atom + | + Number + | + Rational + + then, for ``expr=Rational(...)``, the Printer will try + to call printer methods in the order as shown in the figure below:: + + p._print(expr) + | + |-- p._print_Rational(expr) + | + |-- p._print_Number(expr) + | + |-- p._print_Atom(expr) + | + `-- p._print_Basic(expr) + + if ``._print_Rational`` method exists in the printer, then it is called, + and the result is returned back. Otherwise, the printer tries to call + ``._print_Number`` and so on. + +3. **As a fall-back use the emptyPrinter method for the printer.** + + As fall-back ``self.emptyPrinter`` will be called with the expression. If + not defined in the Printer subclass this will be the same as ``str(expr)``. + +.. _printer_example: + +Example of Custom Printer +^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the example below, we have a printer which prints the derivative of a function +in a shorter form. + +.. code-block:: python + + from sympy.core.symbol import Symbol + from sympy.printing.latex import LatexPrinter, print_latex + from sympy.core.function import UndefinedFunction, Function + + + class MyLatexPrinter(LatexPrinter): + \"\"\"Print derivative of a function of symbols in a shorter form. + \"\"\" + def _print_Derivative(self, expr): + function, *vars = expr.args + if not isinstance(type(function), UndefinedFunction) or \\ + not all(isinstance(i, Symbol) for i in vars): + return super()._print_Derivative(expr) + + # If you want the printer to work correctly for nested + # expressions then use self._print() instead of str() or latex(). + # See the example of nested modulo below in the custom printing + # method section. + return "{}_{{{}}}".format( + self._print(Symbol(function.func.__name__)), + ''.join(self._print(i) for i in vars)) + + + def print_my_latex(expr): + \"\"\" Most of the printers define their own wrappers for print(). + These wrappers usually take printer settings. Our printer does not have + any settings. + \"\"\" + print(MyLatexPrinter().doprint(expr)) + + + y = Symbol("y") + x = Symbol("x") + f = Function("f") + expr = f(x, y).diff(x, y) + + # Print the expression using the normal latex printer and our custom + # printer. + print_latex(expr) + print_my_latex(expr) + +The output of the code above is:: + + \\frac{\\partial^{2}}{\\partial x\\partial y} f{\\left(x,y \\right)} + f_{xy} + +.. _printer_method_example: + +Example of Custom Printing Method +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the example below, the latex printing of the modulo operator is modified. +This is done by overriding the method ``_latex`` of ``Mod``. + +>>> from sympy import Symbol, Mod, Integer, print_latex + +>>> # Always use printer._print() +>>> class ModOp(Mod): +... def _latex(self, printer): +... a, b = [printer._print(i) for i in self.args] +... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + +Comparing the output of our custom operator to the builtin one: + +>>> x = Symbol('x') +>>> m = Symbol('m') +>>> print_latex(Mod(x, m)) +x \\bmod m +>>> print_latex(ModOp(x, m)) +\\operatorname{Mod}{\\left(x, m\\right)} + +Common mistakes +~~~~~~~~~~~~~~~ +It's important to always use ``self._print(obj)`` to print subcomponents of +an expression when customizing a printer. Mistakes include: + +1. Using ``self.doprint(obj)`` instead: + + >>> # This example does not work properly, as only the outermost call may use + >>> # doprint. + >>> class ModOpModeWrong(Mod): + ... def _latex(self, printer): + ... a, b = [printer.doprint(i) for i in self.args] + ... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + + This fails when the ``mode`` argument is passed to the printer: + + >>> print_latex(ModOp(x, m), mode='inline') # ok + $\\operatorname{Mod}{\\left(x, m\\right)}$ + >>> print_latex(ModOpModeWrong(x, m), mode='inline') # bad + $\\operatorname{Mod}{\\left($x$, $m$\\right)}$ + +2. Using ``str(obj)`` instead: + + >>> class ModOpNestedWrong(Mod): + ... def _latex(self, printer): + ... a, b = [str(i) for i in self.args] + ... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + + This fails on nested objects: + + >>> # Nested modulo. + >>> print_latex(ModOp(ModOp(x, m), Integer(7))) # ok + \\operatorname{Mod}{\\left(\\operatorname{Mod}{\\left(x, m\\right)}, 7\\right)} + >>> print_latex(ModOpNestedWrong(ModOpNestedWrong(x, m), Integer(7))) # bad + \\operatorname{Mod}{\\left(ModOpNestedWrong(x, m), 7\\right)} + +3. Using ``LatexPrinter()._print(obj)`` instead. + + >>> from sympy.printing.latex import LatexPrinter + >>> class ModOpSettingsWrong(Mod): + ... def _latex(self, printer): + ... a, b = [LatexPrinter()._print(i) for i in self.args] + ... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + + This causes all the settings to be discarded in the subobjects. As an + example, the ``full_prec`` setting which shows floats to full precision is + ignored: + + >>> from sympy import Float + >>> print_latex(ModOp(Float(1) * x, m), full_prec=True) # ok + \\operatorname{Mod}{\\left(1.00000000000000 x, m\\right)} + >>> print_latex(ModOpSettingsWrong(Float(1) * x, m), full_prec=True) # bad + \\operatorname{Mod}{\\left(1.0 x, m\\right)} + +""" + +from __future__ import annotations +import sys +from typing import Any, Type +import inspect +from contextlib import contextmanager +from functools import cmp_to_key, update_wrapper + +from sympy.core.add import Add +from sympy.core.basic import Basic + +from sympy.core.function import AppliedUndef, UndefinedFunction, Function + + + +@contextmanager +def printer_context(printer, **kwargs): + original = printer._context.copy() + try: + printer._context.update(kwargs) + yield + finally: + printer._context = original + + +class Printer: + """ Generic printer + + Its job is to provide infrastructure for implementing new printers easily. + + If you want to define your custom Printer or your custom printing method + for your custom class then see the example above: printer_example_ . + """ + + _global_settings: dict[str, Any] = {} + + _default_settings: dict[str, Any] = {} + + # must be initialized to pass tests and cannot be set to '| None' to pass mypy + printmethod = None # type: str + + @classmethod + def _get_initial_settings(cls): + settings = cls._default_settings.copy() + for key, val in cls._global_settings.items(): + if key in cls._default_settings: + settings[key] = val + return settings + + def __init__(self, settings=None): + self._str = str + + self._settings = self._get_initial_settings() + self._context = {} # mutable during printing + + if settings is not None: + self._settings.update(settings) + + if len(self._settings) > len(self._default_settings): + for key in self._settings: + if key not in self._default_settings: + raise TypeError("Unknown setting '%s'." % key) + + # _print_level is the number of times self._print() was recursively + # called. See StrPrinter._print_Float() for an example of usage + self._print_level = 0 + + @classmethod + def set_global_settings(cls, **settings): + """Set system-wide printing settings. """ + for key, val in settings.items(): + if val is not None: + cls._global_settings[key] = val + + @property + def order(self): + if 'order' in self._settings: + return self._settings['order'] + else: + raise AttributeError("No order defined.") + + def doprint(self, expr): + """Returns printer's representation for expr (as a string)""" + return self._str(self._print(expr)) + + def _print(self, expr, **kwargs) -> str: + """Internal dispatcher + + Tries the following concepts to print an expression: + 1. Let the object print itself if it knows how. + 2. Take the best fitting method defined in the printer. + 3. As fall-back use the emptyPrinter method for the printer. + """ + self._print_level += 1 + try: + # If the printer defines a name for a printing method + # (Printer.printmethod) and the object knows for itself how it + # should be printed, use that method. + if self.printmethod and hasattr(expr, self.printmethod): + if not (isinstance(expr, type) and issubclass(expr, Basic)): + return getattr(expr, self.printmethod)(self, **kwargs) + + # See if the class of expr is known, or if one of its super + # classes is known, and use that print function + # Exception: ignore the subclasses of Undefined, so that, e.g., + # Function('gamma') does not get dispatched to _print_gamma + classes = type(expr).__mro__ + if AppliedUndef in classes: + classes = classes[classes.index(AppliedUndef):] + if UndefinedFunction in classes: + classes = classes[classes.index(UndefinedFunction):] + # Another exception: if someone subclasses a known function, e.g., + # gamma, and changes the name, then ignore _print_gamma + if Function in classes: + i = classes.index(Function) + classes = tuple(c for c in classes[:i] if \ + c.__name__ == classes[0].__name__ or \ + c.__name__.endswith("Base")) + classes[i:] + for cls in classes: + printmethodname = '_print_' + cls.__name__ + printmethod = getattr(self, printmethodname, None) + if printmethod is not None: + return printmethod(expr, **kwargs) + # Unknown object, fall back to the emptyPrinter. + return self.emptyPrinter(expr) + finally: + self._print_level -= 1 + + def emptyPrinter(self, expr): + return str(expr) + + def _as_ordered_terms(self, expr, order=None): + """A compatibility function for ordering terms in Add. """ + order = order or self.order + + if order == 'old': + return sorted(Add.make_args(expr), key=cmp_to_key(self._compare_pretty)) + elif order == 'none': + return list(expr.args) + else: + return expr.as_ordered_terms(order=order) + + def _compare_pretty(self, a, b): + """return -1, 0, 1 if a is canonically less, equal or + greater than b. This is used when 'order=old' is selected + for printing. This puts Order last, orders Rationals + according to value, puts terms in order wrt the power of + the last power appearing in a term. Ties are broken using + Basic.compare. + """ + from sympy.core.numbers import Rational + from sympy.core.symbol import Wild + from sympy.series.order import Order + if isinstance(a, Order) and not isinstance(b, Order): + return 1 + if not isinstance(a, Order) and isinstance(b, Order): + return -1 + + if isinstance(a, Rational) and isinstance(b, Rational): + l = a.p * b.q + r = b.p * a.q + return (l > r) - (l < r) + else: + p1, p2, p3 = Wild("p1"), Wild("p2"), Wild("p3") + r_a = a.match(p1 * p2**p3) + if r_a and p3 in r_a: + a3 = r_a[p3] + r_b = b.match(p1 * p2**p3) + if r_b and p3 in r_b: + b3 = r_b[p3] + c = Basic.compare(a3, b3) + if c != 0: + return c + + # break ties + return Basic.compare(a, b) + + +class _PrintFunction: + """ + Function wrapper to replace ``**settings`` in the signature with printer defaults + """ + def __init__(self, f, print_cls: Type[Printer]): + # find all the non-setting arguments + params = list(inspect.signature(f).parameters.values()) + assert params.pop(-1).kind == inspect.Parameter.VAR_KEYWORD + self.__other_params = params + + self.__print_cls = print_cls + update_wrapper(self, f) + + def __reduce__(self): + # Since this is used as a decorator, it replaces the original function. + # The default pickling will try to pickle self.__wrapped__ and fail + # because the wrapped function can't be retrieved by name. + return self.__wrapped__.__qualname__ + + def __call__(self, *args, **kwargs): + return self.__wrapped__(*args, **kwargs) + + @property + def __signature__(self) -> inspect.Signature: + settings = self.__print_cls._get_initial_settings() + return inspect.Signature( + parameters=self.__other_params + [ + inspect.Parameter(k, inspect.Parameter.KEYWORD_ONLY, default=v) + for k, v in settings.items() + ], + return_annotation=self.__wrapped__.__annotations__.get('return', inspect.Signature.empty) # type:ignore + ) + + +def print_function(print_cls): + """ A decorator to replace kwargs with the printer settings in __signature__ """ + def decorator(f): + if sys.version_info < (3, 9): + # We have to create a subclass so that `help` actually shows the docstring in older Python versions. + # IPython and Sphinx do not need this, only a raw Python console. + cls = type(f'{f.__qualname__}_PrintFunction', (_PrintFunction,), {"__doc__": f.__doc__}) + else: + cls = _PrintFunction + return cls(f, print_cls) + return decorator diff --git a/phivenv/Lib/site-packages/sympy/printing/pycode.py b/phivenv/Lib/site-packages/sympy/printing/pycode.py new file mode 100644 index 0000000000000000000000000000000000000000..09bdc6788775d409c06bdaae0a43c54544894602 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pycode.py @@ -0,0 +1,852 @@ +""" +Python code printers + +This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code. +""" +from collections import defaultdict +from itertools import chain +from sympy.core import S +from sympy.core.mod import Mod +from .precedence import precedence +from .codeprinter import CodePrinter + +_kw = { + 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', + 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', + 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', + 'with', 'yield', 'None', 'False', 'nonlocal', 'True' +} + +_known_functions = { + 'Abs': 'abs', + 'Min': 'min', + 'Max': 'max', +} +_known_functions_math = { + 'acos': 'acos', + 'acosh': 'acosh', + 'asin': 'asin', + 'asinh': 'asinh', + 'atan': 'atan', + 'atan2': 'atan2', + 'atanh': 'atanh', + 'ceiling': 'ceil', + 'cos': 'cos', + 'cosh': 'cosh', + 'erf': 'erf', + 'erfc': 'erfc', + 'exp': 'exp', + 'expm1': 'expm1', + 'factorial': 'factorial', + 'floor': 'floor', + 'gamma': 'gamma', + 'hypot': 'hypot', + 'isinf': 'isinf', + 'isnan': 'isnan', + 'loggamma': 'lgamma', + 'log': 'log', + 'ln': 'log', + 'log10': 'log10', + 'log1p': 'log1p', + 'log2': 'log2', + 'sin': 'sin', + 'sinh': 'sinh', + 'Sqrt': 'sqrt', + 'tan': 'tan', + 'tanh': 'tanh' +} # Not used from ``math``: [copysign isclose isfinite isinf ldexp frexp pow modf +# radians trunc fmod fsum gcd degrees fabs] +_known_constants_math = { + 'Exp1': 'e', + 'Pi': 'pi', + 'E': 'e', + 'Infinity': 'inf', + 'NaN': 'nan', + 'ComplexInfinity': 'nan' +} + +def _print_known_func(self, expr): + known = self.known_functions[expr.__class__.__name__] + return '{name}({args})'.format(name=self._module_format(known), + args=', '.join((self._print(arg) for arg in expr.args))) + + +def _print_known_const(self, expr): + known = self.known_constants[expr.__class__.__name__] + return self._module_format(known) + + +class AbstractPythonCodePrinter(CodePrinter): + printmethod = "_pythoncode" + language = "Python" + reserved_words = _kw + modules = None # initialized to a set in __init__ + tab = ' ' + _kf = dict(chain( + _known_functions.items(), + [(k, 'math.' + v) for k, v in _known_functions_math.items()] + )) + _kc = {k: 'math.'+v for k, v in _known_constants_math.items()} + _operators = {'and': 'and', 'or': 'or', 'not': 'not'} + _default_settings = dict( + CodePrinter._default_settings, + user_functions={}, + precision=17, + inline=True, + fully_qualified_modules=True, + contract=False, + standard='python3', + ) + + def __init__(self, settings=None): + super().__init__(settings) + + # Python standard handler + std = self._settings['standard'] + if std is None: + import sys + std = 'python{}'.format(sys.version_info.major) + if std != 'python3': + raise ValueError('Only Python 3 is supported.') + self.standard = std + + self.module_imports = defaultdict(set) + + # Known functions and constants handler + self.known_functions = dict(self._kf, **(settings or {}).get( + 'user_functions', {})) + self.known_constants = dict(self._kc, **(settings or {}).get( + 'user_constants', {})) + + def _declare_number_const(self, name, value): + return "%s = %s" % (name, value) + + def _module_format(self, fqn, register=True): + parts = fqn.split('.') + if register and len(parts) > 1: + self.module_imports['.'.join(parts[:-1])].add(parts[-1]) + + if self._settings['fully_qualified_modules']: + return fqn + else: + return fqn.split('(')[0].split('[')[0].split('.')[-1] + + def _format_code(self, lines): + return lines + + def _get_statement(self, codestring): + return "{}".format(codestring) + + def _get_comment(self, text): + return " # {}".format(text) + + def _expand_fold_binary_op(self, op, args): + """ + This method expands a fold on binary operations. + + ``functools.reduce`` is an example of a folded operation. + + For example, the expression + + `A + B + C + D` + + is folded into + + `((A + B) + C) + D` + """ + if len(args) == 1: + return self._print(args[0]) + else: + return "%s(%s, %s)" % ( + self._module_format(op), + self._expand_fold_binary_op(op, args[:-1]), + self._print(args[-1]), + ) + + def _expand_reduce_binary_op(self, op, args): + """ + This method expands a reduction on binary operations. + + Notice: this is NOT the same as ``functools.reduce``. + + For example, the expression + + `A + B + C + D` + + is reduced into: + + `(A + B) + (C + D)` + """ + if len(args) == 1: + return self._print(args[0]) + else: + N = len(args) + Nhalf = N // 2 + return "%s(%s, %s)" % ( + self._module_format(op), + self._expand_reduce_binary_op(args[:Nhalf]), + self._expand_reduce_binary_op(args[Nhalf:]), + ) + + def _print_NaN(self, expr): + return "float('nan')" + + def _print_Infinity(self, expr): + return "float('inf')" + + def _print_NegativeInfinity(self, expr): + return "float('-inf')" + + def _print_ComplexInfinity(self, expr): + return self._print_NaN(expr) + + def _print_Mod(self, expr): + PREC = precedence(expr) + return ('{} % {}'.format(*(self.parenthesize(x, PREC) for x in expr.args))) + + def _print_Piecewise(self, expr): + result = [] + i = 0 + for arg in expr.args: + e = arg.expr + c = arg.cond + if i == 0: + result.append('(') + result.append('(') + result.append(self._print(e)) + result.append(')') + result.append(' if ') + result.append(self._print(c)) + result.append(' else ') + i += 1 + result = result[:-1] + if result[-1] == 'True': + result = result[:-2] + result.append(')') + else: + result.append(' else None)') + return ''.join(result) + + def _print_Relational(self, expr): + "Relational printer for Equality and Unequality" + op = { + '==' :'equal', + '!=' :'not_equal', + '<' :'less', + '<=' :'less_equal', + '>' :'greater', + '>=' :'greater_equal', + } + if expr.rel_op in op: + lhs = self._print(expr.lhs) + rhs = self._print(expr.rhs) + return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs) + return super()._print_Relational(expr) + + def _print_ITE(self, expr): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def _print_Sum(self, expr): + loops = ( + 'for {i} in range({a}, {b}+1)'.format( + i=self._print(i), + a=self._print(a), + b=self._print(b)) + for i, a, b in expr.limits[::-1]) + return '(builtins.sum({function} {loops}))'.format( + function=self._print(expr.function), + loops=' '.join(loops)) + + def _print_ImaginaryUnit(self, expr): + return '1j' + + def _print_KroneckerDelta(self, expr): + a, b = expr.args + + return '(1 if {a} == {b} else 0)'.format( + a = self._print(a), + b = self._print(b) + ) + + def _print_MatrixBase(self, expr): + name = expr.__class__.__name__ + func = self.known_functions.get(name, name) + return "%s(%s)" % (func, self._print(expr.tolist())) + + _print_SparseRepMatrix = \ + _print_MutableSparseMatrix = \ + _print_ImmutableSparseMatrix = \ + _print_Matrix = \ + _print_DenseMatrix = \ + _print_MutableDenseMatrix = \ + _print_ImmutableMatrix = \ + _print_ImmutableDenseMatrix = \ + lambda self, expr: self._print_MatrixBase(expr) + + def _indent_codestring(self, codestring): + return '\n'.join([self.tab + line for line in codestring.split('\n')]) + + def _print_FunctionDefinition(self, fd): + body = '\n'.join((self._print(arg) for arg in fd.body)) + return "def {name}({parameters}):\n{body}".format( + name=self._print(fd.name), + parameters=', '.join([self._print(var.symbol) for var in fd.parameters]), + body=self._indent_codestring(body) + ) + + def _print_While(self, whl): + body = '\n'.join((self._print(arg) for arg in whl.body)) + return "while {cond}:\n{body}".format( + cond=self._print(whl.condition), + body=self._indent_codestring(body) + ) + + def _print_Declaration(self, decl): + return '%s = %s' % ( + self._print(decl.variable.symbol), + self._print(decl.variable.value) + ) + + def _print_BreakToken(self, bt): + return 'break' + + def _print_Return(self, ret): + arg, = ret.args + return 'return %s' % self._print(arg) + + def _print_Raise(self, rs): + arg, = rs.args + return 'raise %s' % self._print(arg) + + def _print_RuntimeError_(self, re): + message, = re.args + return "RuntimeError(%s)" % self._print(message) + + def _print_Print(self, prnt): + print_args = ', '.join((self._print(arg) for arg in prnt.print_args)) + from sympy.codegen.ast import none + if prnt.format_string != none: + print_args = '{} % ({}), end=""'.format( + self._print(prnt.format_string), + print_args + ) + if prnt.file != None: # Must be '!= None', cannot be 'is not None' + print_args += ', file=%s' % self._print(prnt.file) + return 'print(%s)' % print_args + + def _print_Stream(self, strm): + if str(strm.name) == 'stdout': + return self._module_format('sys.stdout') + elif str(strm.name) == 'stderr': + return self._module_format('sys.stderr') + else: + return self._print(strm.name) + + def _print_NoneToken(self, arg): + return 'None' + + def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'): + """Printing helper function for ``Pow`` + + Notes + ===== + + This preprocesses the ``sqrt`` as math formatter and prints division + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.printing.pycode import PythonCodePrinter + >>> from sympy.abc import x + + Python code printer automatically looks up ``math.sqrt``. + + >>> printer = PythonCodePrinter() + >>> printer._hprint_Pow(sqrt(x), rational=True) + 'x**(1/2)' + >>> printer._hprint_Pow(sqrt(x), rational=False) + 'math.sqrt(x)' + >>> printer._hprint_Pow(1/sqrt(x), rational=True) + 'x**(-1/2)' + >>> printer._hprint_Pow(1/sqrt(x), rational=False) + '1/math.sqrt(x)' + >>> printer._hprint_Pow(1/x, rational=False) + '1/x' + >>> printer._hprint_Pow(1/x, rational=True) + 'x**(-1)' + + Using sqrt from numpy or mpmath + + >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt') + 'numpy.sqrt(x)' + >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt') + 'mpmath.sqrt(x)' + + See Also + ======== + + sympy.printing.str.StrPrinter._print_Pow + """ + PREC = precedence(expr) + + if expr.exp == S.Half and not rational: + func = self._module_format(sqrt) + arg = self._print(expr.base) + return '{func}({arg})'.format(func=func, arg=arg) + + if expr.is_commutative and not rational: + if -expr.exp is S.Half: + func = self._module_format(sqrt) + num = self._print(S.One) + arg = self._print(expr.base) + return f"{num}/{func}({arg})" + if expr.exp is S.NegativeOne: + num = self._print(S.One) + arg = self.parenthesize(expr.base, PREC, strict=False) + return f"{num}/{arg}" + + + base_str = self.parenthesize(expr.base, PREC, strict=False) + exp_str = self.parenthesize(expr.exp, PREC, strict=False) + return "{}**{}".format(base_str, exp_str) + + +class ArrayPrinter: + + def _arrayify(self, indexed): + from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + try: + return convert_indexed_to_array(indexed) + except Exception: + return indexed + + def _get_einsum_string(self, subranks, contraction_indices): + letters = self._get_letter_generator_for_einsum() + contraction_string = "" + counter = 0 + d = {j: min(i) for i in contraction_indices for j in i} + indices = [] + for rank_arg in subranks: + lindices = [] + for i in range(rank_arg): + if counter in d: + lindices.append(d[counter]) + else: + lindices.append(counter) + counter += 1 + indices.append(lindices) + mapping = {} + letters_free = [] + letters_dum = [] + for i in indices: + for j in i: + if j not in mapping: + l = next(letters) + mapping[j] = l + else: + l = mapping[j] + contraction_string += l + if j in d: + if l not in letters_dum: + letters_dum.append(l) + else: + letters_free.append(l) + contraction_string += "," + contraction_string = contraction_string[:-1] + return contraction_string, letters_free, letters_dum + + def _get_letter_generator_for_einsum(self): + for i in range(97, 123): + yield chr(i) + for i in range(65, 91): + yield chr(i) + raise ValueError("out of letters") + + def _print_ArrayTensorProduct(self, expr): + letters = self._get_letter_generator_for_einsum() + contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks]) + return '%s("%s", %s)' % ( + self._module_format(self._module + "." + self._einsum), + contraction_string, + ", ".join([self._print(arg) for arg in expr.args]) + ) + + def _print_ArrayContraction(self, expr): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + base = expr.expr + contraction_indices = expr.contraction_indices + + if isinstance(base, ArrayTensorProduct): + elems = ",".join(["%s" % (self._print(arg)) for arg in base.args]) + ranks = base.subranks + else: + elems = self._print(base) + ranks = [len(base.shape)] + + contraction_string, letters_free, letters_dum = self._get_einsum_string(ranks, contraction_indices) + + if not contraction_indices: + return self._print(base) + if isinstance(base, ArrayTensorProduct): + elems = ",".join(["%s" % (self._print(arg)) for arg in base.args]) + else: + elems = self._print(base) + return "%s(\"%s\", %s)" % ( + self._module_format(self._module + "." + self._einsum), + "{}->{}".format(contraction_string, "".join(sorted(letters_free))), + elems, + ) + + def _print_ArrayDiagonal(self, expr): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + diagonal_indices = list(expr.diagonal_indices) + if isinstance(expr.expr, ArrayTensorProduct): + subranks = expr.expr.subranks + elems = expr.expr.args + else: + subranks = expr.subranks + elems = [expr.expr] + diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices) + elems = [self._print(i) for i in elems] + return '%s("%s", %s)' % ( + self._module_format(self._module + "." + self._einsum), + "{}->{}".format(diagonal_string, "".join(letters_free+letters_dum)), + ", ".join(elems) + ) + + def _print_PermuteDims(self, expr): + return "%s(%s, %s)" % ( + self._module_format(self._module + "." + self._transpose), + self._print(expr.expr), + self._print(expr.permutation.array_form), + ) + + def _print_ArrayAdd(self, expr): + return self._expand_fold_binary_op(self._module + "." + self._add, expr.args) + + def _print_OneArray(self, expr): + return "%s((%s,))" % ( + self._module_format(self._module+ "." + self._ones), + ','.join(map(self._print,expr.args)) + ) + + def _print_ZeroArray(self, expr): + return "%s((%s,))" % ( + self._module_format(self._module+ "." + self._zeros), + ','.join(map(self._print,expr.args)) + ) + + def _print_Assignment(self, expr): + #XXX: maybe this needs to happen at a higher level e.g. at _print or + #doprint? + lhs = self._print(self._arrayify(expr.lhs)) + rhs = self._print(self._arrayify(expr.rhs)) + return "%s = %s" % ( lhs, rhs ) + + def _print_IndexedBase(self, expr): + return self._print_ArraySymbol(expr) + + +class PythonCodePrinter(AbstractPythonCodePrinter): + + def _print_sign(self, e): + return '(0.0 if {e} == 0 else {f}(1, {e}))'.format( + f=self._module_format('math.copysign'), e=self._print(e.args[0])) + + def _print_Not(self, expr): + PREC = precedence(expr) + return self._operators['not'] + ' ' + self.parenthesize(expr.args[0], PREC) + + def _print_IndexedBase(self, expr): + return expr.name + + def _print_Indexed(self, expr): + base = expr.args[0] + index = expr.args[1:] + return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index])) + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational) + + def _print_Rational(self, expr): + return '{}/{}'.format(expr.p, expr.q) + + def _print_Half(self, expr): + return self._print_Rational(expr) + + def _print_frac(self, expr): + return self._print_Mod(Mod(expr.args[0], 1)) + + def _print_Symbol(self, expr): + + name = super()._print_Symbol(expr) + + if name in self.reserved_words: + if self._settings['error_on_reserved']: + msg = ('This expression includes the symbol "{}" which is a ' + 'reserved keyword in this language.') + raise ValueError(msg.format(name)) + return name + self._settings['reserved_word_suffix'] + elif '{' in name: # Remove curly braces from subscripted variables + return name.replace('{', '').replace('}', '') + else: + return name + + _print_lowergamma = CodePrinter._print_not_supported + _print_uppergamma = CodePrinter._print_not_supported + _print_fresnelc = CodePrinter._print_not_supported + _print_fresnels = CodePrinter._print_not_supported + + +for k in PythonCodePrinter._kf: + setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func) + +for k in _known_constants_math: + setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const) + + +def pycode(expr, **settings): + """ Converts an expr to a string of Python code + + Parameters + ========== + + expr : Expr + A SymPy expression. + fully_qualified_modules : bool + Whether or not to write out full module names of functions + (``math.sin`` vs. ``sin``). default: ``True``. + standard : str or None, optional + Only 'python3' (default) is supported. + This parameter may be removed in the future. + + Examples + ======== + + >>> from sympy import pycode, tan, Symbol + >>> pycode(tan(Symbol('x')) + 1) + 'math.tan(x) + 1' + + """ + return PythonCodePrinter(settings).doprint(expr) + + +from itertools import chain +from sympy.printing.pycode import PythonCodePrinter + +_known_functions_cmath = { + 'exp': 'exp', + 'sqrt': 'sqrt', + 'log': 'log', + 'cos': 'cos', + 'sin': 'sin', + 'tan': 'tan', + 'acos': 'acos', + 'asin': 'asin', + 'atan': 'atan', + 'cosh': 'cosh', + 'sinh': 'sinh', + 'tanh': 'tanh', + 'acosh': 'acosh', + 'asinh': 'asinh', + 'atanh': 'atanh', +} + +_known_constants_cmath = { + 'Pi': 'pi', + 'E': 'e', + 'Infinity': 'inf', + 'NegativeInfinity': '-inf', +} + +class CmathPrinter(PythonCodePrinter): + """ Printer for Python's cmath module """ + printmethod = "_cmathcode" + language = "Python with cmath" + + _kf = dict(chain( + _known_functions_cmath.items() + )) + + _kc = {k: 'cmath.' + v for k, v in _known_constants_cmath.items()} + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational, sqrt='cmath.sqrt') + + def _print_Float(self, e): + return '{func}({val})'.format(func=self._module_format('cmath.mpf'), val=self._print(e)) + + def _print_known_func(self, expr): + func_name = expr.func.__name__ + if func_name in self._kf: + return f"cmath.{self._kf[func_name]}({', '.join(map(self._print, expr.args))})" + return super()._print_Function(expr) + + def _print_known_const(self, expr): + return self._kc[expr.__class__.__name__] + + def _print_re(self, expr): + """Prints `re(z)` as `z.real`""" + return f"({self._print(expr.args[0])}).real" + + def _print_im(self, expr): + """Prints `im(z)` as `z.imag`""" + return f"({self._print(expr.args[0])}).imag" + + +for k in CmathPrinter._kf: + setattr(CmathPrinter, '_print_%s' % k, CmathPrinter._print_known_func) + +for k in _known_constants_cmath: + setattr(CmathPrinter, '_print_%s' % k, CmathPrinter._print_known_const) + + +_not_in_mpmath = 'log1p log2'.split() +_in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath] +_known_functions_mpmath = dict(_in_mpmath, **{ + 'beta': 'beta', + 'frac': 'frac', + 'fresnelc': 'fresnelc', + 'fresnels': 'fresnels', + 'sign': 'sign', + 'loggamma': 'loggamma', + 'hyper': 'hyper', + 'meijerg': 'meijerg', + 'besselj': 'besselj', + 'bessely': 'bessely', + 'besseli': 'besseli', + 'besselk': 'besselk', +}) +_known_constants_mpmath = { + 'Exp1': 'e', + 'Pi': 'pi', + 'GoldenRatio': 'phi', + 'EulerGamma': 'euler', + 'Catalan': 'catalan', + 'NaN': 'nan', + 'Infinity': 'inf', + 'NegativeInfinity': 'ninf' +} + + +def _unpack_integral_limits(integral_expr): + """ helper function for _print_Integral that + - accepts an Integral expression + - returns a tuple of + - a list variables of integration + - a list of tuples of the upper and lower limits of integration + """ + integration_vars = [] + limits = [] + for integration_range in integral_expr.limits: + if len(integration_range) == 3: + integration_var, lower_limit, upper_limit = integration_range + else: + raise NotImplementedError("Only definite integrals are supported") + integration_vars.append(integration_var) + limits.append((lower_limit, upper_limit)) + return integration_vars, limits + + +class MpmathPrinter(PythonCodePrinter): + """ + Lambda printer for mpmath which maintains precision for floats + """ + printmethod = "_mpmathcode" + + language = "Python with mpmath" + + _kf = dict(chain( + _known_functions.items(), + [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()] + )) + _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()} + + def _print_Float(self, e): + # XXX: This does not handle setting mpmath.mp.dps. It is assumed that + # the caller of the lambdified function will have set it to sufficient + # precision to match the Floats in the expression. + + # Remove 'mpz' if gmpy is installed. + args = str(tuple(map(int, e._mpf_))) + return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args) + + + def _print_Rational(self, e): + return "{func}({p})/{func}({q})".format( + func=self._module_format('mpmath.mpf'), + q=self._print(e.q), + p=self._print(e.p) + ) + + def _print_Half(self, e): + return self._print_Rational(e) + + def _print_uppergamma(self, e): + return "{}({}, {}, {})".format( + self._module_format('mpmath.gammainc'), + self._print(e.args[0]), + self._print(e.args[1]), + self._module_format('mpmath.inf')) + + def _print_lowergamma(self, e): + return "{}({}, 0, {})".format( + self._module_format('mpmath.gammainc'), + self._print(e.args[0]), + self._print(e.args[1])) + + def _print_log2(self, e): + return '{0}({1})/{0}(2)'.format( + self._module_format('mpmath.log'), self._print(e.args[0])) + + def _print_log1p(self, e): + return '{}({})'.format( + self._module_format('mpmath.log1p'), self._print(e.args[0])) + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt') + + def _print_Integral(self, e): + integration_vars, limits = _unpack_integral_limits(e) + + return "{}(lambda {}: {}, {})".format( + self._module_format("mpmath.quad"), + ", ".join(map(self._print, integration_vars)), + self._print(e.args[0]), + ", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits)) + + + def _print_Derivative_zeta(self, args, seq_orders): + arg, = args + deriv_order, = seq_orders + return '{}({}, derivative={})'.format( + self._module_format('mpmath.zeta'), + self._print(arg), deriv_order + ) + + +for k in MpmathPrinter._kf: + setattr(MpmathPrinter, '_print_%s' % k, _print_known_func) + +for k in _known_constants_mpmath: + setattr(MpmathPrinter, '_print_%s' % k, _print_known_const) + + +class SymPyPrinter(AbstractPythonCodePrinter): + + language = "Python with SymPy" + + _default_settings = dict( + AbstractPythonCodePrinter._default_settings, + strict=False # any class name will per definition be what we target in SymPyPrinter. + ) + + def _print_Function(self, expr): + mod = expr.func.__module__ or '' + return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__), + ', '.join((self._print(arg) for arg in expr.args))) + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt') diff --git a/phivenv/Lib/site-packages/sympy/printing/python.py b/phivenv/Lib/site-packages/sympy/printing/python.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6862574d99db90f289de65144c7122ed2d731a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/python.py @@ -0,0 +1,92 @@ +import keyword as kw +import sympy +from .repr import ReprPrinter +from .str import StrPrinter + +# A list of classes that should be printed using StrPrinter +STRPRINT = ("Add", "Infinity", "Integer", "Mul", "NegativeInfinity", "Pow") + + +class PythonPrinter(ReprPrinter, StrPrinter): + """A printer which converts an expression into its Python interpretation.""" + + def __init__(self, settings=None): + super().__init__(settings) + self.symbols = [] + self.functions = [] + + # Create print methods for classes that should use StrPrinter instead + # of ReprPrinter. + for name in STRPRINT: + f_name = "_print_%s" % name + f = getattr(StrPrinter, f_name) + setattr(PythonPrinter, f_name, f) + + def _print_Function(self, expr): + func = expr.func.__name__ + if not hasattr(sympy, func) and func not in self.functions: + self.functions.append(func) + return StrPrinter._print_Function(self, expr) + + # procedure (!) for defining symbols which have be defined in print_python() + def _print_Symbol(self, expr): + symbol = self._str(expr) + if symbol not in self.symbols: + self.symbols.append(symbol) + return StrPrinter._print_Symbol(self, expr) + + def _print_module(self, expr): + raise ValueError('Modules in the expression are unacceptable') + + +def python(expr, **settings): + """Return Python interpretation of passed expression + (can be passed to the exec() function without any modifications)""" + + printer = PythonPrinter(settings) + exprp = printer.doprint(expr) + + result = '' + # Returning found symbols and functions + renamings = {} + for symbolname in printer.symbols: + # Remove curly braces from subscripted variables + if '{' in symbolname: + newsymbolname = symbolname.replace('{', '').replace('}', '') + renamings[sympy.Symbol(symbolname)] = newsymbolname + else: + newsymbolname = symbolname + + # Escape symbol names that are reserved Python keywords + if kw.iskeyword(newsymbolname): + while True: + newsymbolname += "_" + if (newsymbolname not in printer.symbols and + newsymbolname not in printer.functions): + renamings[sympy.Symbol( + symbolname)] = sympy.Symbol(newsymbolname) + break + result += newsymbolname + ' = Symbol(\'' + symbolname + '\')\n' + + for functionname in printer.functions: + newfunctionname = functionname + # Escape function names that are reserved Python keywords + if kw.iskeyword(newfunctionname): + while True: + newfunctionname += "_" + if (newfunctionname not in printer.symbols and + newfunctionname not in printer.functions): + renamings[sympy.Function( + functionname)] = sympy.Function(newfunctionname) + break + result += newfunctionname + ' = Function(\'' + functionname + '\')\n' + + if renamings: + exprp = expr.subs(renamings) + result += 'e = ' + printer._str(exprp) + return result + + +def print_python(expr, **settings): + """Print output of python() function""" + print(python(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/pytorch.py b/phivenv/Lib/site-packages/sympy/printing/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8ff01856fa1dce7f8e786065a32bb74d987254 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/pytorch.py @@ -0,0 +1,297 @@ + +from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter +from sympy.matrices.expressions import MatrixExpr +from sympy.core.mul import Mul +from sympy.printing.precedence import PRECEDENCE +from sympy.external import import_module +from sympy.codegen.cfunctions import Sqrt +from sympy import S +from sympy import Integer + +import sympy + +torch = import_module('torch') + + +class TorchPrinter(ArrayPrinter, AbstractPythonCodePrinter): + + printmethod = "_torchcode" + + mapping = { + sympy.Abs: "torch.abs", + sympy.sign: "torch.sign", + + # XXX May raise error for ints. + sympy.ceiling: "torch.ceil", + sympy.floor: "torch.floor", + sympy.log: "torch.log", + sympy.exp: "torch.exp", + Sqrt: "torch.sqrt", + sympy.cos: "torch.cos", + sympy.acos: "torch.acos", + sympy.sin: "torch.sin", + sympy.asin: "torch.asin", + sympy.tan: "torch.tan", + sympy.atan: "torch.atan", + sympy.atan2: "torch.atan2", + # XXX Also may give NaN for complex results. + sympy.cosh: "torch.cosh", + sympy.acosh: "torch.acosh", + sympy.sinh: "torch.sinh", + sympy.asinh: "torch.asinh", + sympy.tanh: "torch.tanh", + sympy.atanh: "torch.atanh", + sympy.Pow: "torch.pow", + + sympy.re: "torch.real", + sympy.im: "torch.imag", + sympy.arg: "torch.angle", + + # XXX May raise error for ints and complexes + sympy.erf: "torch.erf", + sympy.loggamma: "torch.lgamma", + + sympy.Eq: "torch.eq", + sympy.Ne: "torch.ne", + sympy.StrictGreaterThan: "torch.gt", + sympy.StrictLessThan: "torch.lt", + sympy.LessThan: "torch.le", + sympy.GreaterThan: "torch.ge", + + sympy.And: "torch.logical_and", + sympy.Or: "torch.logical_or", + sympy.Not: "torch.logical_not", + sympy.Max: "torch.max", + sympy.Min: "torch.min", + + # Matrices + sympy.MatAdd: "torch.add", + sympy.HadamardProduct: "torch.mul", + sympy.Trace: "torch.trace", + + # XXX May raise error for integer matrices. + sympy.Determinant: "torch.det", + } + + _default_settings = dict( + AbstractPythonCodePrinter._default_settings, + torch_version=None, + requires_grad=False, + dtype="torch.float64", + ) + + def __init__(self, settings=None): + super().__init__(settings) + + version = self._settings['torch_version'] + self.requires_grad = self._settings['requires_grad'] + self.dtype = self._settings['dtype'] + if version is None and torch: + version = torch.__version__ + self.torch_version = version + + def _print_Function(self, expr): + + op = self.mapping.get(type(expr), None) + if op is None: + return super()._print_Basic(expr) + children = [self._print(arg) for arg in expr.args] + if len(children) == 1: + return "%s(%s)" % ( + self._module_format(op), + children[0] + ) + else: + return self._expand_fold_binary_op(op, children) + + # mirrors the tensorflow version + _print_Expr = _print_Function + _print_Application = _print_Function + _print_MatrixExpr = _print_Function + _print_Relational = _print_Function + _print_Not = _print_Function + _print_And = _print_Function + _print_Or = _print_Function + _print_HadamardProduct = _print_Function + _print_Trace = _print_Function + _print_Determinant = _print_Function + + def _print_Inverse(self, expr): + return '{}({})'.format(self._module_format("torch.linalg.inv"), + self._print(expr.args[0])) + + def _print_Transpose(self, expr): + if expr.arg.is_Matrix and expr.arg.shape[0] == expr.arg.shape[1]: + # For square matrices, we can use the .t() method + return "{}({}).t()".format("torch.transpose", self._print(expr.arg)) + else: + # For non-square matrices or more general cases + # transpose first and second dimensions (typical matrix transpose) + return "{}.permute({})".format( + self._print(expr.arg), + ", ".join([str(i) for i in range(len(expr.arg.shape))])[::-1] + ) + + def _print_PermuteDims(self, expr): + return "%s.permute(%s)" % ( + self._print(expr.expr), + ", ".join(str(i) for i in expr.permutation.array_form) + ) + + def _print_Derivative(self, expr): + # this version handles multi-variable and mixed partial derivatives. The tensorflow version does not. + variables = expr.variables + expr_arg = expr.expr + + # Handle multi-variable or repeated derivatives + if len(variables) > 1 or ( + len(variables) == 1 and not isinstance(variables[0], tuple) and variables.count(variables[0]) > 1): + result = self._print(expr_arg) + var_groups = {} + + # Group variables by base symbol + for var in variables: + if isinstance(var, tuple): + base_var, order = var + var_groups[base_var] = var_groups.get(base_var, 0) + order + else: + var_groups[var] = var_groups.get(var, 0) + 1 + + # Apply gradients in sequence + for var, order in var_groups.items(): + for _ in range(order): + result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(var)) + return result + + # Handle single variable case + if len(variables) == 1: + variable = variables[0] + if isinstance(variable, tuple) and len(variable) == 2: + base_var, order = variable + if not isinstance(order, Integer): raise NotImplementedError("Only integer orders are supported") + result = self._print(expr_arg) + for _ in range(order): + result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(base_var)) + return result + return "torch.autograd.grad({}, {})[0]".format(self._print(expr_arg), self._print(variable)) + + return self._print(expr_arg) # Empty variables case + + def _print_Piecewise(self, expr): + from sympy import Piecewise + e, cond = expr.args[0].args + if len(expr.args) == 1: + return '{}({}, {}, {})'.format( + self._module_format("torch.where"), + self._print(cond), + self._print(e), + 0) + + return '{}({}, {}, {})'.format( + self._module_format("torch.where"), + self._print(cond), + self._print(e), + self._print(Piecewise(*expr.args[1:]))) + + def _print_Pow(self, expr): + # XXX May raise error for + # int**float or int**complex or float**complex + base, exp = expr.args + if expr.exp == S.Half: + return "{}({})".format( + self._module_format("torch.sqrt"), self._print(base)) + return "{}({}, {})".format( + self._module_format("torch.pow"), + self._print(base), self._print(exp)) + + def _print_MatMul(self, expr): + # Separate matrix and scalar arguments + mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)] + args = [arg for arg in expr.args if arg not in mat_args] + # Handle scalar multipliers if present + if args: + return "%s*%s" % ( + self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]), + self._expand_fold_binary_op("torch.matmul", mat_args) + ) + else: + return self._expand_fold_binary_op("torch.matmul", mat_args) + + def _print_MatPow(self, expr): + return self._expand_fold_binary_op("torch.mm", [expr.base]*expr.exp) + + def _print_MatrixBase(self, expr): + data = "[" + ", ".join(["[" + ", ".join([self._print(j) for j in i]) + "]" for i in expr.tolist()]) + "]" + params = [str(data)] + params.append(f"dtype={self.dtype}") + if self.requires_grad: + params.append("requires_grad=True") + + return "{}({})".format( + self._module_format("torch.tensor"), + ", ".join(params) + ) + + def _print_isnan(self, expr): + return f'torch.isnan({self._print(expr.args[0])})' + + def _print_isinf(self, expr): + return f'torch.isinf({self._print(expr.args[0])})' + + def _print_Identity(self, expr): + if all(dim.is_Integer for dim in expr.shape): + return "{}({})".format( + self._module_format("torch.eye"), + self._print(expr.shape[0]) + ) + else: + # For symbolic dimensions, fall back to a more general approach + return "{}({}, {})".format( + self._module_format("torch.eye"), + self._print(expr.shape[0]), + self._print(expr.shape[1]) + ) + + def _print_ZeroMatrix(self, expr): + return "{}({})".format( + self._module_format("torch.zeros"), + self._print(expr.shape) + ) + + def _print_OneMatrix(self, expr): + return "{}({})".format( + self._module_format("torch.ones"), + self._print(expr.shape) + ) + + def _print_conjugate(self, expr): + return f"{self._module_format('torch.conj')}({self._print(expr.args[0])})" + + def _print_ImaginaryUnit(self, expr): + return "1j" # uses the Python built-in 1j notation for the imaginary unit + + def _print_Heaviside(self, expr): + args = [self._print(expr.args[0]), "0.5"] + if len(expr.args) > 1: + args[1] = self._print(expr.args[1]) + return f"{self._module_format('torch.heaviside')}({args[0]}, {args[1]})" + + def _print_gamma(self, expr): + return f"{self._module_format('torch.special.gamma')}({self._print(expr.args[0])})" + + def _print_polygamma(self, expr): + if expr.args[0] == S.Zero: + return f"{self._module_format('torch.special.digamma')}({self._print(expr.args[1])})" + else: + raise NotImplementedError("PyTorch only supports digamma (0th order polygamma)") + + _module = "torch" + _einsum = "einsum" + _add = "add" + _transpose = "t" + _ones = "ones" + _zeros = "zeros" + +def torch_code(expr, requires_grad=False, dtype="torch.float64", **settings): + printer = TorchPrinter(settings={'requires_grad': requires_grad, 'dtype': dtype}) + return printer.doprint(expr, **settings) diff --git a/phivenv/Lib/site-packages/sympy/printing/rcode.py b/phivenv/Lib/site-packages/sympy/printing/rcode.py new file mode 100644 index 0000000000000000000000000000000000000000..3107e6e94d5c5acf0b2dc063e4a83af6970f6576 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/rcode.py @@ -0,0 +1,402 @@ +""" +R code printer + +The RCodePrinter converts single SymPy expressions into single R expressions, +using the functions defined in math.h where possible. + + + +""" + +from __future__ import annotations +from typing import Any + +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from sympy.sets.fancysets import Range + +# dictionary mapping SymPy function to (argument_conditions, C_function). +# Used in RCodePrinter._print_Function(self) +known_functions = { + #"Abs": [(lambda x: not x.is_integer, "fabs")], + "Abs": "abs", + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + "exp": "exp", + "log": "log", + "erf": "erf", + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "asinh": "asinh", + "acosh": "acosh", + "atanh": "atanh", + "floor": "floor", + "ceiling": "ceiling", + "sign": "sign", + "Max": "max", + "Min": "min", + "factorial": "factorial", + "gamma": "gamma", + "digamma": "digamma", + "trigamma": "trigamma", + "beta": "beta", + "sqrt": "sqrt", # To enable automatic rewrite +} + +# These are the core reserved words in the R language. Taken from: +# https://cran.r-project.org/doc/manuals/r-release/R-lang.html#Reserved-words + +reserved_words = ['if', + 'else', + 'repeat', + 'while', + 'function', + 'for', + 'in', + 'next', + 'break', + 'TRUE', + 'FALSE', + 'NULL', + 'Inf', + 'NaN', + 'NA', + 'NA_integer_', + 'NA_real_', + 'NA_complex_', + 'NA_character_', + 'volatile'] + + +class RCodePrinter(CodePrinter): + """A printer to convert SymPy expressions to strings of R code""" + printmethod = "_rcode" + language = "R" + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 15, + 'user_functions': {}, + 'contract': True, + 'dereference': set(), + }) + _operators = { + 'and': '&', + 'or': '|', + 'not': '!', + } + + _relationals: dict[str, str] = {} + + def __init__(self, settings={}): + CodePrinter.__init__(self, settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + self._dereference = set(settings.get('dereference', [])) + self.reserved_words = set(reserved_words) + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + return "%s;" % codestring + + def _get_comment(self, text): + return "// {}".format(text) + + def _declare_number_const(self, name, value): + return "{} = {};".format(name, value) + + def _format_code(self, lines): + return self.indent_code(lines) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for i in range(rows) for j in range(cols)) + + def _get_loop_opening_ending(self, indices): + """Returns a tuple (open_lines, close_lines) containing lists of codelines + """ + open_lines = [] + close_lines = [] + loopstart = "for (%(var)s in %(start)s:%(end)s){" + for i in indices: + # R arrays start at 1 and end at dimension + open_lines.append(loopstart % { + 'var': self._print(i.label), + 'start': self._print(i.lower+1), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + def _print_Pow(self, expr): + if "Pow" in self.known_functions: + return self._print_Function(expr) + PREC = precedence(expr) + if equal_valued(expr.exp, -1): + return '1.0/%s' % (self.parenthesize(expr.base, PREC)) + elif equal_valued(expr.exp, 0.5): + return 'sqrt(%s)' % self._print(expr.base) + else: + return '%s^%s' % (self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC)) + + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + return '%d.0/%d.0' % (p, q) + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s[%s]" % (self._print(expr.base.label), ", ".join(inds)) + + def _print_Exp1(self, expr): + return "exp(1)" + + def _print_Pi(self, expr): + return 'pi' + + def _print_Infinity(self, expr): + return 'Inf' + + def _print_NegativeInfinity(self, expr): + return '-Inf' + + def _print_Assignment(self, expr): + from sympy.codegen.ast import Assignment + + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.tensor.indexed import IndexedBase + lhs = expr.lhs + rhs = expr.rhs + # We special case assignments that take multiple lines + #if isinstance(expr.rhs, Piecewise): + # from sympy.functions.elementary.piecewise import Piecewise + # # Here we modify Piecewise so each expression is now + # # an Assignment, and then continue on the print. + # expressions = [] + # conditions = [] + # for (e, c) in rhs.args: + # expressions.append(Assignment(lhs, e)) + # conditions.append(c) + # temp = Piecewise(*zip(expressions, conditions)) + # return self._print(temp) + #elif isinstance(lhs, MatrixSymbol): + if isinstance(lhs, MatrixSymbol): + # Here we form an Assignment for each element in the array, + # printing each one. + lines = [] + for (i, j) in self._traverse_matrix_indices(lhs): + temp = Assignment(lhs[i, j], rhs[i, j]) + code0 = self._print(temp) + lines.append(code0) + return "\n".join(lines) + elif self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + def _print_Piecewise(self, expr): + # This method is called only for inline if constructs + # Top level piecewise is handled in doprint() + if expr.args[-1].cond == True: + last_line = "%s" % self._print(expr.args[-1].expr) + else: + last_line = "ifelse(%s,%s,NA)" % (self._print(expr.args[-1].cond), self._print(expr.args[-1].expr)) + code=last_line + for e, c in reversed(expr.args[:-1]): + code= "ifelse(%s,%s," % (self._print(c), self._print(e))+code+")" + return(code) + + def _print_ITE(self, expr): + from sympy.functions import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def _print_MatrixElement(self, expr): + return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"], + strict=True), expr.j + expr.i*expr.parent.shape[1]) + + def _print_Symbol(self, expr): + name = super()._print_Symbol(expr) + if expr in self._dereference: + return '(*{})'.format(name) + else: + return name + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_AugmentedAssignment(self, expr): + lhs_code = self._print(expr.lhs) + op = expr.op + rhs_code = self._print(expr.rhs) + return "{} {} {};".format(lhs_code, op, rhs_code) + + def _print_For(self, expr): + target = self._print(expr.target) + if isinstance(expr.iterable, Range): + start, stop, step = expr.iterable.args + else: + raise NotImplementedError("Only iterable currently supported is Range") + body = self._print(expr.body) + return 'for({target} in seq(from={start}, to={stop}, by={step}){{\n{body}\n}}'.format(target=target, start=start, + stop=stop-1, step=step, body=body) + + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(map(line.endswith, inc_token))) for line in code ] + decrease = [ int(any(map(line.startswith, dec_token))) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + +def rcode(expr, assign_to=None, **settings): + """Converts an expr to a string of r code + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of + line-wrapping, or for expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=15]. + user_functions : dict, optional + A dictionary where the keys are string representations of either + ``FunctionClass`` or ``UndefinedFunction`` instances and the values + are their desired R string representations. Alternatively, the + dictionary value can be a list of tuples i.e. [(argument_test, + rfunction_string)] or [(argument_test, rfunction_formater)]. See below + for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + + Examples + ======== + + >>> from sympy import rcode, symbols, Rational, sin, ceiling, Abs, Function + >>> x, tau = symbols("x, tau") + >>> rcode((2*tau)**Rational(7, 2)) + '8*sqrt(2)*tau^(7.0/2.0)' + >>> rcode(sin(x), assign_to="s") + 's = sin(x);' + + Simple custom printing can be defined for certain types by passing a + dictionary of {"type" : "function"} to the ``user_functions`` kwarg. + Alternatively, the dictionary value can be a list of tuples i.e. + [(argument_test, cfunction_string)]. + + >>> custom_functions = { + ... "ceiling": "CEIL", + ... "Abs": [(lambda x: not x.is_integer, "fabs"), + ... (lambda x: x.is_integer, "ABS")], + ... "func": "f" + ... } + >>> func = Function('func') + >>> rcode(func(Abs(x) + ceiling(x)), user_functions=custom_functions) + 'f(fabs(x) + CEIL(x))' + + or if the R-function takes a subset of the original arguments: + + >>> rcode(2**x + 3**x, user_functions={'Pow': [ + ... (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e), + ... (lambda b, e: b != 2, 'pow')]}) + 'exp2(x) + pow(3, x)' + + ``Piecewise`` expressions are converted into conditionals. If an + ``assign_to`` variable is provided an if statement is created, otherwise + the ternary operator is used. Note that if the ``Piecewise`` lacks a + default term, represented by ``(expr, True)`` then an error will be thrown. + This is to prevent generating an expression that may not evaluate to + anything. + + >>> from sympy import Piecewise + >>> expr = Piecewise((x + 1, x > 0), (x, True)) + >>> print(rcode(expr, assign_to=tau)) + tau = ifelse(x > 0,x + 1,x); + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> rcode(e.rhs, assign_to=e.lhs, contract=False) + 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);' + + Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions + must be provided to ``assign_to``. Note that any expression that can be + generated normally can also exist inside a Matrix: + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)]) + >>> A = MatrixSymbol('A', 3, 1) + >>> print(rcode(mat, A)) + A[0] = x^2; + A[1] = ifelse(x > 0,x + 1,x); + A[2] = sin(x); + + """ + + return RCodePrinter(settings).doprint(expr, assign_to) + + +def print_rcode(expr, **settings): + """Prints R representation of the given expression.""" + print(rcode(expr, **settings)) diff --git a/phivenv/Lib/site-packages/sympy/printing/repr.py b/phivenv/Lib/site-packages/sympy/printing/repr.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4b756abbab77c3eb0fd77ee1f0bd97382c36fb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/repr.py @@ -0,0 +1,339 @@ +""" +A Printer for generating executable code. + +The most important function here is srepr that returns a string so that the +relation eval(srepr(expr))=expr holds in an appropriate environment. +""" + +from __future__ import annotations +from typing import Any + +from sympy.core.function import AppliedUndef +from sympy.core.mul import Mul +from mpmath.libmp import repr_dps, to_str as mlib_to_str + +from .printer import Printer, print_function + + +class ReprPrinter(Printer): + printmethod = "_sympyrepr" + + _default_settings: dict[str, Any] = { + "order": None, + "perm_cyclic" : True, + } + + def reprify(self, args, sep): + """ + Prints each item in `args` and joins them with `sep`. + """ + return sep.join([self.doprint(item) for item in args]) + + def emptyPrinter(self, expr): + """ + The fallback printer. + """ + if isinstance(expr, str): + return expr + elif hasattr(expr, "__srepr__"): + return expr.__srepr__() + elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"): + l = [] + for o in expr.args: + l.append(self._print(o)) + return expr.__class__.__name__ + '(%s)' % ', '.join(l) + elif hasattr(expr, "__module__") and hasattr(expr, "__name__"): + return "<'%s.%s'>" % (expr.__module__, expr.__name__) + else: + return str(expr) + + def _print_Add(self, expr, order=None): + args = self._as_ordered_terms(expr, order=order) + args = map(self._print, args) + clsname = type(expr).__name__ + return clsname + "(%s)" % ", ".join(args) + + def _print_Cycle(self, expr): + return expr.__repr__() + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + from sympy.utilities.exceptions import sympy_deprecation_warning + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + if not expr.size: + return 'Permutation()' + # before taking Cycle notation, see if the last element is + # a singleton and move it to the head of the string + s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):] + last = s.rfind('(') + if not last == 0 and ',' not in s[last:]: + s = s[last:] + s[:last] + return 'Permutation%s' %s + else: + s = expr.support() + if not s: + if expr.size < 5: + return 'Permutation(%s)' % str(expr.array_form) + return 'Permutation([], size=%s)' % expr.size + trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size + use = full = str(expr.array_form) + if len(trim) < len(full): + use = trim + return 'Permutation(%s)' % use + + def _print_Function(self, expr): + r = self._print(expr.func) + r += '(%s)' % ', '.join([self._print(a) for a in expr.args]) + return r + + def _print_Heaviside(self, expr): + # Same as _print_Function but uses pargs to suppress default value for + # 2nd arg. + r = self._print(expr.func) + r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs]) + return r + + def _print_FunctionClass(self, expr): + if issubclass(expr, AppliedUndef): + return 'Function(%r)' % (expr.__name__) + else: + return expr.__name__ + + def _print_Half(self, expr): + return 'Rational(1, 2)' + + def _print_RationalConstant(self, expr): + return str(expr) + + def _print_AtomicExpr(self, expr): + return str(expr) + + def _print_NumberSymbol(self, expr): + return str(expr) + + def _print_Integer(self, expr): + return 'Integer(%i)' % expr.p + + def _print_Complexes(self, expr): + return 'Complexes' + + def _print_Integers(self, expr): + return 'Integers' + + def _print_Naturals(self, expr): + return 'Naturals' + + def _print_Naturals0(self, expr): + return 'Naturals0' + + def _print_Rationals(self, expr): + return 'Rationals' + + def _print_Reals(self, expr): + return 'Reals' + + def _print_EmptySet(self, expr): + return 'EmptySet' + + def _print_UniversalSet(self, expr): + return 'UniversalSet' + + def _print_EmptySequence(self, expr): + return 'EmptySequence' + + def _print_list(self, expr): + return "[%s]" % self.reprify(expr, ", ") + + def _print_dict(self, expr): + sep = ", " + dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()] + return "{%s}" % sep.join(dict_kvs) + + def _print_set(self, expr): + if not expr: + return "set()" + return "{%s}" % self.reprify(expr, ", ") + + def _print_MatrixBase(self, expr): + # special case for some empty matrices + if (expr.rows == 0) ^ (expr.cols == 0): + return '%s(%s, %s, %s)' % (expr.__class__.__name__, + self._print(expr.rows), + self._print(expr.cols), + self._print([])) + l = [] + for i in range(expr.rows): + l.append([]) + for j in range(expr.cols): + l[-1].append(expr[i, j]) + return '%s(%s)' % (expr.__class__.__name__, self._print(l)) + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + def _print_NaN(self, expr): + return "nan" + + def _print_Mul(self, expr, order=None): + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + args = map(self._print, args) + clsname = type(expr).__name__ + return clsname + "(%s)" % ", ".join(args) + + def _print_Rational(self, expr): + return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q)) + + def _print_PythonRational(self, expr): + return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q) + + def _print_Fraction(self, expr): + return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator)) + + def _print_Float(self, expr): + r = mlib_to_str(expr._mpf_, repr_dps(expr._prec)) + return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec) + + def _print_Sum2(self, expr): + return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i), + self._print(expr.a), self._print(expr.b)) + + def _print_Str(self, s): + return "%s(%s)" % (s.__class__.__name__, self._print(s.name)) + + def _print_Symbol(self, expr): + d = expr._assumptions_orig + # print the dummy_index like it was an assumption + if expr.is_Dummy: + d = d.copy() + d['dummy_index'] = expr.dummy_index + + if d == {}: + return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name)) + else: + attr = ['%s=%s' % (k, v) for k, v in d.items()] + return "%s(%s, %s)" % (expr.__class__.__name__, + self._print(expr.name), ', '.join(attr)) + + def _print_CoordinateSymbol(self, expr): + d = expr._assumptions.generator + + if d == {}: + return "%s(%s, %s)" % ( + expr.__class__.__name__, + self._print(expr.coord_sys), + self._print(expr.index) + ) + else: + attr = ['%s=%s' % (k, v) for k, v in d.items()] + return "%s(%s, %s, %s)" % ( + expr.__class__.__name__, + self._print(expr.coord_sys), + self._print(expr.index), + ', '.join(attr) + ) + + def _print_Predicate(self, expr): + return "Q.%s" % expr.name + + def _print_AppliedPredicate(self, expr): + # will be changed to just expr.args when args overriding is removed + args = expr._args + return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", ")) + + def _print_str(self, expr): + return repr(expr) + + def _print_tuple(self, expr): + if len(expr) == 1: + return "(%s,)" % self._print(expr[0]) + else: + return "(%s)" % self.reprify(expr, ", ") + + def _print_WildFunction(self, expr): + return "%s('%s')" % (expr.__class__.__name__, expr.name) + + def _print_AlgebraicNumber(self, expr): + return "%s(%s, %s)" % (expr.__class__.__name__, + self._print(expr.root), self._print(expr.coeffs())) + + def _print_PolyRing(self, ring): + return "%s(%s, %s, %s)" % (ring.__class__.__name__, + self._print(ring.symbols), self._print(ring.domain), self._print(ring.order)) + + def _print_FracField(self, field): + return "%s(%s, %s, %s)" % (field.__class__.__name__, + self._print(field.symbols), self._print(field.domain), self._print(field.order)) + + def _print_PolyElement(self, poly): + terms = list(poly.terms()) + terms.sort(key=poly.ring.order, reverse=True) + return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms)) + + def _print_FracElement(self, frac): + numer_terms = list(frac.numer.terms()) + numer_terms.sort(key=frac.field.order, reverse=True) + denom_terms = list(frac.denom.terms()) + denom_terms.sort(key=frac.field.order, reverse=True) + numer = self._print(numer_terms) + denom = self._print(denom_terms) + return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom) + + def _print_FractionField(self, domain): + cls = domain.__class__.__name__ + field = self._print(domain.field) + return "%s(%s)" % (cls, field) + + def _print_PolynomialRingBase(self, ring): + cls = ring.__class__.__name__ + dom = self._print(ring.domain) + gens = ', '.join(map(self._print, ring.gens)) + order = str(ring.order) + if order != ring.default_order: + orderstr = ", order=" + order + else: + orderstr = "" + return "%s(%s, %s%s)" % (cls, dom, gens, orderstr) + + def _print_DMP(self, p): + cls = p.__class__.__name__ + rep = self._print(p.to_list()) + dom = self._print(p.dom) + return "%s(%s, %s)" % (cls, rep, dom) + + def _print_MonogenicFiniteExtension(self, ext): + # The expanded tree shown by srepr(ext.modulus) + # is not practical. + return "FiniteExtension(%s)" % str(ext.modulus) + + def _print_ExtensionElement(self, f): + rep = self._print(f.rep) + ext = self._print(f.ext) + return "ExtElem(%s, %s)" % (rep, ext) + +@print_function(ReprPrinter) +def srepr(expr, **settings): + """return expr in repr form""" + return ReprPrinter(settings).doprint(expr) diff --git a/phivenv/Lib/site-packages/sympy/printing/rust.py b/phivenv/Lib/site-packages/sympy/printing/rust.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfd481bec6b7350df281accfb9b7a598cf05baa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/rust.py @@ -0,0 +1,637 @@ +""" +Rust code printer + +The `RustCodePrinter` converts SymPy expressions into Rust expressions. + +A complete code generator, which uses `rust_code` extensively, can be found +in `sympy.utilities.codegen`. The `codegen` module can be used to generate +complete source code files. + +""" + +# Possible Improvement +# +# * make sure we follow Rust Style Guidelines_ +# * make use of pattern matching +# * better support for reference +# * generate generic code and use trait to make sure they have specific methods +# * use crates_ to get more math support +# - num_ +# + BigInt_, BigUint_ +# + Complex_ +# + Rational64_, Rational32_, BigRational_ +# +# .. _crates: https://crates.io/ +# .. _Guidelines: https://github.com/rust-lang/rust/tree/master/src/doc/style +# .. _num: http://rust-num.github.io/num/num/ +# .. _BigInt: http://rust-num.github.io/num/num/bigint/struct.BigInt.html +# .. _BigUint: http://rust-num.github.io/num/num/bigint/struct.BigUint.html +# .. _Complex: http://rust-num.github.io/num/num/complex/struct.Complex.html +# .. _Rational32: http://rust-num.github.io/num/num/rational/type.Rational32.html +# .. _Rational64: http://rust-num.github.io/num/num/rational/type.Rational64.html +# .. _BigRational: http://rust-num.github.io/num/num/rational/type.BigRational.html + +from __future__ import annotations +from functools import reduce +import operator +from typing import Any + +from sympy.codegen.ast import ( + float32, float64, int32, + real, integer, bool_ +) +from sympy.core import S, Rational, Float, Lambda +from sympy.core.expr import Expr +from sympy.core.numbers import equal_valued +from sympy.functions.elementary.integers import ceiling, floor +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import PRECEDENCE + +# Rust's methods for integer and float can be found at here : +# +# * `Rust - Primitive Type f64 `_ +# * `Rust - Primitive Type i64 `_ +# +# Function Style : +# +# 1. args[0].func(args[1:]), method with arguments +# 2. args[0].func(), method without arguments +# 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp()) +# 4. func(args), function with arguments + +# dictionary mapping SymPy function to (argument_conditions, Rust_function). +# Used in RustCodePrinter._print_Function(self) + +class float_floor(floor): + """ + Same as `sympy.floor`, but mimics the Rust behavior of returning a float rather than an integer + """ + def _eval_is_integer(self): + return False + +class float_ceiling(ceiling): + """ + Same as `sympy.ceiling`, but mimics the Rust behavior of returning a float rather than an integer + """ + def _eval_is_integer(self): + return False + + +function_overrides = { + "floor": (floor, float_floor), + "ceiling": (ceiling, float_ceiling), +} + +# f64 method in Rust +known_functions = { + # "": "is_nan", + # "": "is_infinite", + # "": "is_finite", + # "": "is_normal", + # "": "classify", + "float_floor": "floor", + "float_ceiling": "ceil", + # "": "round", + # "": "trunc", + # "": "fract", + "Abs": "abs", + # "": "signum", + # "": "is_sign_positive", + # "": "is_sign_negative", + # "": "mul_add", + "Pow": [(lambda base, exp: equal_valued(exp, -1), "recip", 2), # 1.0/x + (lambda base, exp: equal_valued(exp, 0.5), "sqrt", 2), # x ** 0.5 + (lambda base, exp: equal_valued(exp, -0.5), "sqrt().recip", 2), # 1/(x ** 0.5) + (lambda base, exp: exp == Rational(1, 3), "cbrt", 2), # x ** (1/3) + (lambda base, exp: equal_valued(base, 2), "exp2", 3), # 2 ** x + (lambda base, exp: exp.is_integer, "powi", 1), # x ** y, for i32 + (lambda base, exp: not exp.is_integer, "powf", 1)], # x ** y, for f64 + "exp": [(lambda exp: True, "exp", 2)], # e ** x + "log": "ln", + # "": "log", # number.log(base) + # "": "log2", + # "": "log10", + # "": "to_degrees", + # "": "to_radians", + "Max": "max", + "Min": "min", + # "": "hypot", # (x**2 + y**2) ** 0.5 + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + # "": "sin_cos", + # "": "exp_m1", # e ** x - 1 + # "": "ln_1p", # ln(1 + x) + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "asinh": "asinh", + "acosh": "acosh", + "atanh": "atanh", + "sqrt": "sqrt", # To enable automatic rewrites +} + +# i64 method in Rust +# known_functions_i64 = { +# "": "min_value", +# "": "max_value", +# "": "from_str_radix", +# "": "count_ones", +# "": "count_zeros", +# "": "leading_zeros", +# "": "trainling_zeros", +# "": "rotate_left", +# "": "rotate_right", +# "": "swap_bytes", +# "": "from_be", +# "": "from_le", +# "": "to_be", # to big endian +# "": "to_le", # to little endian +# "": "checked_add", +# "": "checked_sub", +# "": "checked_mul", +# "": "checked_div", +# "": "checked_rem", +# "": "checked_neg", +# "": "checked_shl", +# "": "checked_shr", +# "": "checked_abs", +# "": "saturating_add", +# "": "saturating_sub", +# "": "saturating_mul", +# "": "wrapping_add", +# "": "wrapping_sub", +# "": "wrapping_mul", +# "": "wrapping_div", +# "": "wrapping_rem", +# "": "wrapping_neg", +# "": "wrapping_shl", +# "": "wrapping_shr", +# "": "wrapping_abs", +# "": "overflowing_add", +# "": "overflowing_sub", +# "": "overflowing_mul", +# "": "overflowing_div", +# "": "overflowing_rem", +# "": "overflowing_neg", +# "": "overflowing_shl", +# "": "overflowing_shr", +# "": "overflowing_abs", +# "Pow": "pow", +# "Abs": "abs", +# "sign": "signum", +# "": "is_positive", +# "": "is_negnative", +# } + +# These are the core reserved words in the Rust language. Taken from: +# https://doc.rust-lang.org/reference/keywords.html + +reserved_words = ['abstract', + 'as', + 'async', + 'await', + 'become', + 'box', + 'break', + 'const', + 'continue', + 'crate', + 'do', + 'dyn', + 'else', + 'enum', + 'extern', + 'false', + 'final', + 'fn', + 'for', + 'gen', + 'if', + 'impl', + 'in', + 'let', + 'loop', + 'macro', + 'match', + 'mod', + 'move', + 'mut', + 'override', + 'priv', + 'pub', + 'ref', + 'return', + 'Self', + 'self', + 'static', + 'struct', + 'super', + 'trait', + 'true', + 'try', + 'type', + 'typeof', + 'unsafe', + 'unsized', + 'use', + 'virtual', + 'where', + 'while', + 'yield'] + + +class TypeCast(Expr): + """ + The type casting operator of the Rust language. + """ + + def __init__(self, expr, type_) -> None: + super().__init__() + self.explicit = expr.is_integer and type_ is not integer + self._assumptions = expr._assumptions + if self.explicit: + setattr(self, 'precedence', PRECEDENCE["Func"] + 10) + + @property + def expr(self): + return self.args[0] + + @property + def type_(self): + return self.args[1] + + def sort_key(self, order=None): + return self.args[0].sort_key(order=order) + + +class RustCodePrinter(CodePrinter): + """A printer to convert SymPy expressions to strings of Rust code""" + printmethod = "_rust_code" + language = "Rust" + + type_aliases = { + integer: int32, + real: float64, + } + + type_mappings = { + int32: 'i32', + float32: 'f32', + float64: 'f64', + bool_: 'bool' + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'dereference': set(), + }) + + def __init__(self, settings={}): + CodePrinter.__init__(self, settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + self._dereference = set(settings.get('dereference', [])) + self.reserved_words = set(reserved_words) + self.function_overrides = function_overrides + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + return "%s;" % codestring + + def _get_comment(self, text): + return "// %s" % text + + def _declare_number_const(self, name, value): + type_ = self.type_mappings[self.type_aliases[real]] + return "const %s: %s = %s;" % (name, type_, value) + + def _format_code(self, lines): + return self.indent_code(lines) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for i in range(rows) for j in range(cols)) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + loopstart = "for %(var)s in %(start)s..%(end)s {" + for i in indices: + # Rust arrays start at 0 and end at dimension-1 + open_lines.append(loopstart % { + 'var': self._print(i), + 'start': self._print(i.lower), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + def _print_caller_var(self, expr): + if len(expr.args) > 1: + # for something like `sin(x + y + z)`, + # make sure we can get '(x + y + z).sin()' + # instead of 'x + y + z.sin()' + return '(' + self._print(expr) + ')' + elif expr.is_number: + return self._print(expr, _type=True) + else: + return self._print(expr) + + def _print_Function(self, expr): + """ + basic function for printing `Function` + + Function Style : + + 1. args[0].func(args[1:]), method with arguments + 2. args[0].func(), method without arguments + 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp()) + 4. func(args), function with arguments + """ + + if expr.func.__name__ in self.known_functions: + cond_func = self.known_functions[expr.func.__name__] + func = None + style = 1 + if isinstance(cond_func, str): + func = cond_func + else: + for cond, func, style in cond_func: + if cond(*expr.args): + break + if func is not None: + if style == 1: + ret = "%(var)s.%(method)s(%(args)s)" % { + 'var': self._print_caller_var(expr.args[0]), + 'method': func, + 'args': self.stringify(expr.args[1:], ", ") if len(expr.args) > 1 else '' + } + elif style == 2: + ret = "%(var)s.%(method)s()" % { + 'var': self._print_caller_var(expr.args[0]), + 'method': func, + } + elif style == 3: + ret = "%(var)s.%(method)s()" % { + 'var': self._print_caller_var(expr.args[1]), + 'method': func, + } + else: + ret = "%(func)s(%(args)s)" % { + 'func': func, + 'args': self.stringify(expr.args, ", "), + } + return ret + elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda): + # inlined function + return self._print(expr._imp_(*expr.args)) + else: + return self._print_not_supported(expr) + + def _print_Mul(self, expr): + contains_floats = any(arg.is_real and not arg.is_integer for arg in expr.args) + if contains_floats: + expr = reduce(operator.mul,(self._cast_to_float(arg) if arg != -1 else arg for arg in expr.args)) + + return super()._print_Mul(expr) + + def _print_Add(self, expr, order=None): + contains_floats = any(arg.is_real and not arg.is_integer for arg in expr.args) + if contains_floats: + expr = reduce(operator.add, (self._cast_to_float(arg) for arg in expr.args)) + + return super()._print_Add(expr, order) + + def _print_Pow(self, expr): + if expr.base.is_integer and not expr.exp.is_integer: + expr = type(expr)(Float(expr.base), expr.exp) + return self._print(expr) + return self._print_Function(expr) + + def _print_TypeCast(self, expr): + if not expr.explicit: + return self._print(expr.expr) + else: + return self._print(expr.expr) + ' as %s' % self.type_mappings[self.type_aliases[expr.type_]] + + def _print_Float(self, expr, _type=False): + ret = super()._print_Float(expr) + if _type: + return ret + '_%s' % self.type_mappings[self.type_aliases[real]] + else: + return ret + + def _print_Integer(self, expr, _type=False): + ret = super()._print_Integer(expr) + if _type: + return ret + '_%s' % self.type_mappings[self.type_aliases[integer]] + else: + return ret + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + float_suffix = self.type_mappings[self.type_aliases[real]] + return '%d_%s/%d.0' % (p, float_suffix, q) + + def _print_Relational(self, expr): + if (expr.lhs.is_integer and not expr.rhs.is_integer) or (expr.rhs.is_integer and not expr.lhs.is_integer): + lhs = self._cast_to_float(expr.lhs) + rhs = self._cast_to_float(expr.rhs) + else: + lhs = expr.lhs + rhs = expr.rhs + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Indexed(self, expr): + # calculate index for 1d array + dims = expr.shape + elem = S.Zero + offset = S.One + for i in reversed(range(expr.rank)): + elem += expr.indices[i]*offset + offset *= dims[i] + return "%s[%s]" % (self._print(expr.base.label), self._print(elem)) + + def _print_Idx(self, expr): + return expr.label.name + + def _print_Dummy(self, expr): + return expr.name + + def _print_Exp1(self, expr, _type=False): + return "E" + + def _print_Pi(self, expr, _type=False): + return 'PI' + + def _print_Infinity(self, expr, _type=False): + return 'INFINITY' + + def _print_NegativeInfinity(self, expr, _type=False): + return 'NEG_INFINITY' + + def _print_BooleanTrue(self, expr, _type=False): + return "true" + + def _print_BooleanFalse(self, expr, _type=False): + return "false" + + def _print_bool(self, expr, _type=False): + return str(expr).lower() + + def _print_NaN(self, expr, _type=False): + return "NAN" + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) {" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines[-1] += " else {" + else: + lines[-1] += " else if (%s) {" % self._print(c) + code0 = self._print(e) + lines.append(code0) + lines.append("}") + + if self._settings['inline']: + return " ".join(lines) + else: + return "\n".join(lines) + + def _print_ITE(self, expr): + from sympy.functions import Piecewise + return self._print(expr.rewrite(Piecewise, deep=False)) + + def _print_MatrixBase(self, A): + if A.cols == 1: + return "[%s]" % ", ".join(self._print(a) for a in A) + else: + raise ValueError("Full Matrix Support in Rust need Crates (https://crates.io/keywords/matrix).") + + def _print_SparseRepMatrix(self, mat): + # do not allow sparse matrices to be made dense + return self._print_not_supported(mat) + + def _print_MatrixElement(self, expr): + return "%s[%s]" % (expr.parent, + expr.j + expr.i*expr.parent.shape[1]) + + def _print_Symbol(self, expr): + + name = super()._print_Symbol(expr) + + if expr in self._dereference: + return '(*%s)' % name + else: + return name + + def _print_Assignment(self, expr): + from sympy.tensor.indexed import IndexedBase + lhs = expr.lhs + rhs = expr.rhs + if self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + def _print_sign(self, expr): + arg = self._print(expr.args[0]) + return "(if (%s == 0.0) { 0.0 } else { (%s).signum() })" % (arg, arg) + + def _cast_to_float(self, expr): + if not expr.is_number: + return TypeCast(expr, real) + elif expr.is_integer: + return Float(expr) + return expr + + def _can_print(self, name): + """ Check if function ``name`` is either a known function or has its own + printing method. Used to check if rewriting is possible.""" + + # since the whole point of function_overrides is to enable proper printing, + # we presume they all are printable + + return name in self.known_functions or name in function_overrides or getattr(self, '_print_{}'.format(name), False) + + def _collect_functions(self, expr): + functions = set() + if isinstance(expr, Expr): + if expr.is_Function: + functions.add(expr.func) + for arg in expr.args: + functions = functions.union(self._collect_functions(arg)) + return functions + + def _rewrite_known_functions(self, expr): + if not isinstance(expr, Expr): + return expr + + expression_functions = self._collect_functions(expr) + rewriteable_functions = { + name: (target_f, required_fs) + for name, (target_f, required_fs) in self._rewriteable_functions.items() + if self._can_print(target_f) + and all(self._can_print(f) for f in required_fs) + } + for func in expression_functions: + target_f, _ = rewriteable_functions.get(func.__name__, (None, None)) + if target_f: + expr = expr.rewrite(target_f) + return expr + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(map(line.endswith, inc_token))) for line in code ] + decrease = [ int(any(map(line.startswith, dec_token))) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty diff --git a/phivenv/Lib/site-packages/sympy/printing/smtlib.py b/phivenv/Lib/site-packages/sympy/printing/smtlib.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa015c6198cb32837eb3a0d7fe9d61352da25ad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/smtlib.py @@ -0,0 +1,583 @@ +import typing + +import sympy +from sympy.core import Add, Mul +from sympy.core import Symbol, Expr, Float, Rational, Integer, Basic +from sympy.core.function import UndefinedFunction, Function +from sympy.core.relational import Relational, Unequality, Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log, Pow +from sympy.functions.elementary.hyperbolic import sinh, cosh, tanh +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin, cos, tan, asin, acos, atan, atan2 +from sympy.logic.boolalg import And, Or, Xor, Implies, Boolean +from sympy.logic.boolalg import BooleanTrue, BooleanFalse, BooleanFunction, Not, ITE +from sympy.printing.printer import Printer +from sympy.sets import Interval +from mpmath.libmp.libmpf import prec_to_dps, to_str as mlib_to_str +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.relation.binrel import AppliedBinaryRelation +from sympy.assumptions.ask import Q +from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate + + +class SMTLibPrinter(Printer): + printmethod = "_smtlib" + + # based on dReal, an automated reasoning tool for solving problems that can be encoded as first-order logic formulas over the real numbers. + # dReal's special strength is in handling problems that involve a wide range of nonlinear real functions. + _default_settings: dict = { + 'precision': None, + 'known_types': { + bool: 'Bool', + int: 'Int', + float: 'Real' + }, + 'known_constants': { + # pi: 'MY_VARIABLE_PI_DECLARED_ELSEWHERE', + }, + 'known_functions': { + Add: '+', + Mul: '*', + + Equality: '=', + LessThan: '<=', + GreaterThan: '>=', + StrictLessThan: '<', + StrictGreaterThan: '>', + + EqualityPredicate(): '=', + LessThanPredicate(): '<=', + GreaterThanPredicate(): '>=', + StrictLessThanPredicate(): '<', + StrictGreaterThanPredicate(): '>', + + exp: 'exp', + log: 'log', + Abs: 'abs', + sin: 'sin', + cos: 'cos', + tan: 'tan', + asin: 'arcsin', + acos: 'arccos', + atan: 'arctan', + atan2: 'arctan2', + sinh: 'sinh', + cosh: 'cosh', + tanh: 'tanh', + Min: 'min', + Max: 'max', + Pow: 'pow', + + And: 'and', + Or: 'or', + Xor: 'xor', + Not: 'not', + ITE: 'ite', + Implies: '=>', + } + } + + symbol_table: dict + + def __init__(self, settings: typing.Optional[dict] = None, + symbol_table=None): + settings = settings or {} + self.symbol_table = symbol_table or {} + Printer.__init__(self, settings) + self._precision = self._settings['precision'] + self._known_types = dict(self._settings['known_types']) + self._known_constants = dict(self._settings['known_constants']) + self._known_functions = dict(self._settings['known_functions']) + + for _ in self._known_types.values(): assert self._is_legal_name(_) + for _ in self._known_constants.values(): assert self._is_legal_name(_) + # for _ in self._known_functions.values(): assert self._is_legal_name(_) # +, *, <, >, etc. + + def _is_legal_name(self, s: str): + if not s: return False + if s[0].isnumeric(): return False + return all(_.isalnum() or _ == '_' for _ in s) + + def _s_expr(self, op: str, args: typing.Union[list, tuple]) -> str: + args_str = ' '.join( + a if isinstance(a, str) + else self._print(a) + for a in args + ) + return f'({op} {args_str})' + + def _print_Function(self, e): + if e in self._known_functions: + op = self._known_functions[e] + elif type(e) in self._known_functions: + op = self._known_functions[type(e)] + elif type(type(e)) == UndefinedFunction: + op = e.name + elif isinstance(e, AppliedBinaryRelation) and e.function in self._known_functions: + op = self._known_functions[e.function] + return self._s_expr(op, e.arguments) + else: + op = self._known_functions[e] # throw KeyError + + return self._s_expr(op, e.args) + + def _print_Relational(self, e: Relational): + return self._print_Function(e) + + def _print_BooleanFunction(self, e: BooleanFunction): + return self._print_Function(e) + + def _print_Expr(self, e: Expr): + return self._print_Function(e) + + def _print_Unequality(self, e: Unequality): + if type(e) in self._known_functions: + return self._print_Relational(e) # default + else: + eq_op = self._known_functions[Equality] + not_op = self._known_functions[Not] + return self._s_expr(not_op, [self._s_expr(eq_op, e.args)]) + + def _print_Piecewise(self, e: Piecewise): + def _print_Piecewise_recursive(args: typing.Union[list, tuple]): + e, c = args[0] + if len(args) == 1: + assert (c is True) or isinstance(c, BooleanTrue) + return self._print(e) + else: + ite = self._known_functions[ITE] + return self._s_expr(ite, [ + c, e, _print_Piecewise_recursive(args[1:]) + ]) + + return _print_Piecewise_recursive(e.args) + + def _print_Interval(self, e: Interval): + if e.start.is_infinite and e.end.is_infinite: + return '' + elif e.start.is_infinite != e.end.is_infinite: + raise ValueError(f'One-sided intervals (`{e}`) are not supported in SMT.') + else: + return f'[{e.start}, {e.end}]' + + def _print_AppliedPredicate(self, e: AppliedPredicate): + if e.function == Q.positive: + rel = Q.gt(e.arguments[0],0) + elif e.function == Q.negative: + rel = Q.lt(e.arguments[0], 0) + elif e.function == Q.zero: + rel = Q.eq(e.arguments[0], 0) + elif e.function == Q.nonpositive: + rel = Q.le(e.arguments[0], 0) + elif e.function == Q.nonnegative: + rel = Q.ge(e.arguments[0], 0) + elif e.function == Q.nonzero: + rel = Q.ne(e.arguments[0], 0) + else: + raise ValueError(f"Predicate (`{e}`) is not handled.") + + return self._print_AppliedBinaryRelation(rel) + + def _print_AppliedBinaryRelation(self, e: AppliedPredicate): + if e.function == Q.ne: + return self._print_Unequality(Unequality(*e.arguments)) + else: + return self._print_Function(e) + + # todo: Sympy does not support quantifiers yet as of 2022, but quantifiers can be handy in SMT. + # For now, users can extend this class and build in their own quantifier support. + # See `test_quantifier_extensions()` in test_smtlib.py for an example of how this might look. + + # def _print_ForAll(self, e: ForAll): + # return self._s('forall', [ + # self._s('', [ + # self._s(sym.name, [self._type_name(sym), Interval(start, end)]) + # for sym, start, end in e.limits + # ]), + # e.function + # ]) + + def _print_BooleanTrue(self, x: BooleanTrue): + return 'true' + + def _print_BooleanFalse(self, x: BooleanFalse): + return 'false' + + def _print_Float(self, x: Float): + dps = prec_to_dps(x._prec) + str_real = mlib_to_str(x._mpf_, dps, strip_zeros=True, min_fixed=None, max_fixed=None) + + if 'e' in str_real: + (mant, exp) = str_real.split('e') + + if exp[0] == '+': + exp = exp[1:] + + mul = self._known_functions[Mul] + pow = self._known_functions[Pow] + + return r"(%s %s (%s 10 %s))" % (mul, mant, pow, exp) + elif str_real in ["+inf", "-inf"]: + raise ValueError("Infinite values are not supported in SMT.") + else: + return str_real + + def _print_float(self, x: float): + return self._print(Float(x)) + + def _print_Rational(self, x: Rational): + return self._s_expr('/', [x.p, x.q]) + + def _print_Integer(self, x: Integer): + assert x.q == 1 + return str(x.p) + + def _print_int(self, x: int): + return str(x) + + def _print_Symbol(self, x: Symbol): + assert self._is_legal_name(x.name) + return x.name + + def _print_NumberSymbol(self, x): + name = self._known_constants.get(x) + if name: + return name + else: + f = x.evalf(self._precision) if self._precision else x.evalf() + return self._print_Float(f) + + def _print_UndefinedFunction(self, x): + assert self._is_legal_name(x.name) + return x.name + + def _print_Exp1(self, x): + return ( + self._print_Function(exp(1, evaluate=False)) + if exp in self._known_functions else + self._print_NumberSymbol(x) + ) + + def emptyPrinter(self, expr): + raise NotImplementedError(f'Cannot convert `{repr(expr)}` of type `{type(expr)}` to SMT.') + + +def smtlib_code( + expr, + auto_assert=True, auto_declare=True, + precision=None, + symbol_table=None, + known_types=None, known_constants=None, known_functions=None, + prefix_expressions=None, suffix_expressions=None, + log_warn=None +): + r"""Converts ``expr`` to a string of smtlib code. + + Parameters + ========== + + expr : Expr | List[Expr] + A SymPy expression or system to be converted. + auto_assert : bool, optional + If false, do not modify expr and produce only the S-Expression equivalent of expr. + If true, assume expr is a system and assert each boolean element. + auto_declare : bool, optional + If false, do not produce declarations for the symbols used in expr. + If true, prepend all necessary declarations for variables used in expr based on symbol_table. + precision : integer, optional + The ``evalf(..)`` precision for numbers such as pi. + symbol_table : dict, optional + A dictionary where keys are ``Symbol`` or ``Function`` instances and values are their Python type i.e. ``bool``, ``int``, ``float``, or ``Callable[...]``. + If incomplete, an attempt will be made to infer types from ``expr``. + known_types: dict, optional + A dictionary where keys are ``bool``, ``int``, ``float`` etc. and values are their corresponding SMT type names. + If not given, a partial listing compatible with several solvers will be used. + known_functions : dict, optional + A dictionary where keys are ``Function``, ``Relational``, ``BooleanFunction``, or ``Expr`` instances and values are their SMT string representations. + If not given, a partial listing optimized for dReal solver (but compatible with others) will be used. + known_constants: dict, optional + A dictionary where keys are ``NumberSymbol`` instances and values are their SMT variable names. + When using this feature, extra caution must be taken to avoid naming collisions between user symbols and listed constants. + If not given, constants will be expanded inline i.e. ``3.14159`` instead of ``MY_SMT_VARIABLE_FOR_PI``. + prefix_expressions: list, optional + A list of lists of ``str`` and/or expressions to convert into SMTLib and prefix to the output. + suffix_expressions: list, optional + A list of lists of ``str`` and/or expressions to convert into SMTLib and postfix to the output. + log_warn: lambda function, optional + A function to record all warnings during potentially risky operations. + Soundness is a core value in SMT solving, so it is good to log all assumptions made. + + Examples + ======== + >>> from sympy import smtlib_code, symbols, sin, Eq + >>> x = symbols('x') + >>> smtlib_code(sin(x).series(x).removeO(), log_warn=print) + Could not infer type of `x`. Defaulting to float. + Non-Boolean expression `x**5/120 - x**3/6 + x` will not be asserted. Converting to SMTLib verbatim. + '(declare-const x Real)\n(+ x (* (/ -1 6) (pow x 3)) (* (/ 1 120) (pow x 5)))' + + >>> from sympy import Rational + >>> x, y, tau = symbols("x, y, tau") + >>> smtlib_code((2*tau)**Rational(7, 2), log_warn=print) + Could not infer type of `tau`. Defaulting to float. + Non-Boolean expression `8*sqrt(2)*tau**(7/2)` will not be asserted. Converting to SMTLib verbatim. + '(declare-const tau Real)\n(* 8 (pow 2 (/ 1 2)) (pow tau (/ 7 2)))' + + ``Piecewise`` expressions are implemented with ``ite`` expressions by default. + Note that if the ``Piecewise`` lacks a default term, represented by + ``(expr, True)`` then an error will be thrown. This is to prevent + generating an expression that may not evaluate to anything. + + >>> from sympy import Piecewise + >>> pw = Piecewise((x + 1, x > 0), (x, True)) + >>> smtlib_code(Eq(pw, 3), symbol_table={x: float}, log_warn=print) + '(declare-const x Real)\n(assert (= (ite (> x 0) (+ 1 x) x) 3))' + + Custom printing can be defined for certain types by passing a dictionary of + PythonType : "SMT Name" to the ``known_types``, ``known_constants``, and ``known_functions`` kwargs. + + >>> from typing import Callable + >>> from sympy import Function, Add + >>> f = Function('f') + >>> g = Function('g') + >>> smt_builtin_funcs = { # functions our SMT solver will understand + ... f: "existing_smtlib_fcn", + ... Add: "sum", + ... } + >>> user_def_funcs = { # functions defined by the user must have their types specified explicitly + ... g: Callable[[int], float], + ... } + >>> smtlib_code(f(x) + g(x), symbol_table=user_def_funcs, known_functions=smt_builtin_funcs, log_warn=print) + Non-Boolean expression `f(x) + g(x)` will not be asserted. Converting to SMTLib verbatim. + '(declare-const x Int)\n(declare-fun g (Int) Real)\n(sum (existing_smtlib_fcn x) (g x))' + """ + log_warn = log_warn or (lambda _: None) + + if not isinstance(expr, list): expr = [expr] + expr = [ + sympy.sympify(_, strict=True, evaluate=False, convert_xor=False) + for _ in expr + ] + + if not symbol_table: symbol_table = {} + symbol_table = _auto_infer_smtlib_types( + *expr, symbol_table=symbol_table + ) + # See [FALLBACK RULES] + # Need SMTLibPrinter to populate known_functions and known_constants first. + + settings = {} + if precision: settings['precision'] = precision + del precision + + if known_types: settings['known_types'] = known_types + del known_types + + if known_functions: settings['known_functions'] = known_functions + del known_functions + + if known_constants: settings['known_constants'] = known_constants + del known_constants + + if not prefix_expressions: prefix_expressions = [] + if not suffix_expressions: suffix_expressions = [] + + p = SMTLibPrinter(settings, symbol_table) + del symbol_table + + # [FALLBACK RULES] + for e in expr: + for sym in e.atoms(Symbol, Function): + if ( + sym.is_Symbol and + sym not in p._known_constants and + sym not in p.symbol_table + ): + log_warn(f"Could not infer type of `{sym}`. Defaulting to float.") + p.symbol_table[sym] = float + if ( + sym.is_Function and + type(sym) not in p._known_functions and + type(sym) not in p.symbol_table and + not sym.is_Piecewise + ): raise TypeError( + f"Unknown type of undefined function `{sym}`. " + f"Must be mapped to ``str`` in known_functions or mapped to ``Callable[..]`` in symbol_table." + ) + + declarations = [] + if auto_declare: + constants = {sym.name: sym for e in expr for sym in e.free_symbols + if sym not in p._known_constants} + functions = {fnc.name: fnc for e in expr for fnc in e.atoms(Function) + if type(fnc) not in p._known_functions and not fnc.is_Piecewise} + declarations = \ + [ + _auto_declare_smtlib(sym, p, log_warn) + for sym in constants.values() + ] + [ + _auto_declare_smtlib(fnc, p, log_warn) + for fnc in functions.values() + ] + declarations = [decl for decl in declarations if decl] + + if auto_assert: + expr = [_auto_assert_smtlib(e, p, log_warn) for e in expr] + + # return SMTLibPrinter().doprint(expr) + return '\n'.join([ + # ';; PREFIX EXPRESSIONS', + *[ + e if isinstance(e, str) else p.doprint(e) + for e in prefix_expressions + ], + + # ';; DECLARATIONS', + *sorted(e for e in declarations), + + # ';; EXPRESSIONS', + *[ + e if isinstance(e, str) else p.doprint(e) + for e in expr + ], + + # ';; SUFFIX EXPRESSIONS', + *[ + e if isinstance(e, str) else p.doprint(e) + for e in suffix_expressions + ], + ]) + + +def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn: typing.Callable[[str], None]): + if sym.is_Symbol: + type_signature = p.symbol_table[sym] + assert isinstance(type_signature, type) + type_signature = p._known_types[type_signature] + return p._s_expr('declare-const', [sym, type_signature]) + + elif sym.is_Function: + type_signature = p.symbol_table[type(sym)] + assert callable(type_signature) + type_signature = [p._known_types[_] for _ in type_signature.__args__] + assert len(type_signature) > 0 + params_signature = f"({' '.join(type_signature[:-1])})" + return_signature = type_signature[-1] + return p._s_expr('declare-fun', [type(sym), params_signature, return_signature]) + + else: + log_warn(f"Non-Symbol/Function `{sym}` will not be declared.") + return None + + +def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[str], None]): + if isinstance(e, Boolean) or ( + e in p.symbol_table and p.symbol_table[e] == bool + ) or ( + e.is_Function and + type(e) in p.symbol_table and + p.symbol_table[type(e)].__args__[-1] == bool + ): + return p._s_expr('assert', [e]) + else: + log_warn(f"Non-Boolean expression `{e}` will not be asserted. Converting to SMTLib verbatim.") + return e + + +def _auto_infer_smtlib_types( + *exprs: Basic, + symbol_table: typing.Optional[dict] = None +) -> dict: + # [TYPE INFERENCE RULES] + # X is alone in an expr => X is bool + # X in BooleanFunction.args => X is bool + # X matches to a bool param of a symbol_table function => X is bool + # X matches to an int param of a symbol_table function => X is int + # X.is_integer => X is int + # X == Y, where X is T => Y is T + + # [FALLBACK RULES] + # see _auto_declare_smtlib(..) + # X is not bool and X is not int and X is Symbol => X is float + # else (e.g. X is Function) => error. must be specified explicitly. + + _symbols = dict(symbol_table) if symbol_table else {} + + def safe_update(syms: set, inf): + for s in syms: + assert s.is_Symbol + if (old_type := _symbols.setdefault(s, inf)) != inf: + raise TypeError(f"Could not infer type of `{s}`. Apparently both `{old_type}` and `{inf}`?") + + # EXPLICIT TYPES + safe_update({ + e + for e in exprs + if e.is_Symbol + }, bool) + + safe_update({ + symbol + for e in exprs + for boolfunc in e.atoms(BooleanFunction) + for symbol in boolfunc.args + if symbol.is_Symbol + }, bool) + + safe_update({ + symbol + for e in exprs + for boolfunc in e.atoms(Function) + if type(boolfunc) in _symbols + for symbol, param in zip(boolfunc.args, _symbols[type(boolfunc)].__args__) + if symbol.is_Symbol and param == bool + }, bool) + + safe_update({ + symbol + for e in exprs + for intfunc in e.atoms(Function) + if type(intfunc) in _symbols + for symbol, param in zip(intfunc.args, _symbols[type(intfunc)].__args__) + if symbol.is_Symbol and param == int + }, int) + + safe_update({ + symbol + for e in exprs + for symbol in e.atoms(Symbol) + if symbol.is_integer + }, int) + + safe_update({ + symbol + for e in exprs + for symbol in e.atoms(Symbol) + if symbol.is_real and not symbol.is_integer + }, float) + + # EQUALITY RELATION RULE + rels_eq = [rel for expr in exprs for rel in expr.atoms(Equality)] + rels = [ + (rel.lhs, rel.rhs) for rel in rels_eq if rel.lhs.is_Symbol + ] + [ + (rel.rhs, rel.lhs) for rel in rels_eq if rel.rhs.is_Symbol + ] + for infer, reltd in rels: + inference = ( + _symbols[infer] if infer in _symbols else + _symbols[reltd] if reltd in _symbols else + + _symbols[type(reltd)].__args__[-1] + if reltd.is_Function and type(reltd) in _symbols else + + bool if reltd.is_Boolean else + int if reltd.is_integer or reltd.is_Integer else + float if reltd.is_real else + None + ) + if inference: safe_update({infer}, inference) + + return _symbols diff --git a/phivenv/Lib/site-packages/sympy/printing/str.py b/phivenv/Lib/site-packages/sympy/printing/str.py new file mode 100644 index 0000000000000000000000000000000000000000..9975776fbb73bec9c956fe359387fa8036995795 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/str.py @@ -0,0 +1,1021 @@ +""" +A Printer for generating readable representation of most SymPy classes. +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import S, Rational, Pow, Basic, Mul, Number +from sympy.core.mul import _keep_coeff +from sympy.core.numbers import Integer +from sympy.core.relational import Relational +from sympy.core.sorting import default_sort_key +from sympy.utilities.iterables import sift +from .precedence import precedence, PRECEDENCE +from .printer import Printer, print_function + +from mpmath.libmp import prec_to_dps, to_str as mlib_to_str + + +class StrPrinter(Printer): + printmethod = "_sympystr" + _default_settings: dict[str, Any] = { + "order": None, + "full_prec": "auto", + "sympy_integers": False, + "abbrev": False, + "perm_cyclic": True, + "min": None, + "max": None, + "dps" : None + } + + _relationals: dict[str, str] = {} + + def parenthesize(self, item, level, strict=False): + if (precedence(item) < level) or ((not strict) and precedence(item) <= level): + return "(%s)" % self._print(item) + else: + return self._print(item) + + def stringify(self, args, sep, level=0): + return sep.join([self.parenthesize(item, level) for item in args]) + + def emptyPrinter(self, expr): + if isinstance(expr, str): + return expr + elif isinstance(expr, Basic): + return repr(expr) + else: + return str(expr) + + def _print_Add(self, expr, order=None): + terms = self._as_ordered_terms(expr, order=order) + + prec = precedence(expr) + l = [] + for term in terms: + t = self._print(term) + if t.startswith('-') and not term.is_Add: + sign = "-" + t = t[1:] + else: + sign = "+" + if precedence(term) < prec or term.is_Add: + l.extend([sign, "(%s)" % t]) + else: + l.extend([sign, t]) + sign = l.pop(0) + if sign == '+': + sign = "" + return sign + ' '.join(l) + + def _print_BooleanTrue(self, expr): + return "True" + + def _print_BooleanFalse(self, expr): + return "False" + + def _print_Not(self, expr): + return '~%s' %(self.parenthesize(expr.args[0],PRECEDENCE["Not"])) + + def _print_And(self, expr): + args = list(expr.args) + for j, i in enumerate(args): + if isinstance(i, Relational) and ( + i.canonical.rhs is S.NegativeInfinity): + args.insert(0, args.pop(j)) + return self.stringify(args, " & ", PRECEDENCE["BitwiseAnd"]) + + def _print_Or(self, expr): + return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"]) + + def _print_Xor(self, expr): + return self.stringify(expr.args, " ^ ", PRECEDENCE["BitwiseXor"]) + + def _print_AppliedPredicate(self, expr): + return '%s(%s)' % ( + self._print(expr.function), self.stringify(expr.arguments, ", ")) + + def _print_Basic(self, expr): + l = [self._print(o) for o in expr.args] + return expr.__class__.__name__ + "(%s)" % ", ".join(l) + + def _print_BlockMatrix(self, B): + if B.blocks.shape == (1, 1): + self._print(B.blocks[0, 0]) + return self._print(B.blocks) + + def _print_Catalan(self, expr): + return 'Catalan' + + def _print_ComplexInfinity(self, expr): + return 'zoo' + + def _print_ConditionSet(self, s): + args = tuple([self._print(i) for i in (s.sym, s.condition)]) + if s.base_set is S.UniversalSet: + return 'ConditionSet(%s, %s)' % args + args += (self._print(s.base_set),) + return 'ConditionSet(%s, %s, %s)' % args + + def _print_Derivative(self, expr): + dexpr = expr.expr + dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count] + return 'Derivative(%s)' % ", ".join((self._print(arg) for arg in [dexpr] + dvars)) + + def _print_dict(self, d): + keys = sorted(d.keys(), key=default_sort_key) + items = [] + + for key in keys: + item = "%s: %s" % (self._print(key), self._print(d[key])) + items.append(item) + + return "{%s}" % ", ".join(items) + + def _print_Dict(self, expr): + return self._print_dict(expr) + + def _print_RandomDomain(self, d): + if hasattr(d, 'as_boolean'): + return 'Domain: ' + self._print(d.as_boolean()) + elif hasattr(d, 'set'): + return ('Domain: ' + self._print(d.symbols) + ' in ' + + self._print(d.set)) + else: + return 'Domain on ' + self._print(d.symbols) + + def _print_Dummy(self, expr): + return '_' + expr.name + + def _print_EulerGamma(self, expr): + return 'EulerGamma' + + def _print_Exp1(self, expr): + return 'E' + + def _print_ExprCondPair(self, expr): + return '(%s, %s)' % (self._print(expr.expr), self._print(expr.cond)) + + def _print_Function(self, expr): + return expr.func.__name__ + "(%s)" % self.stringify(expr.args, ", ") + + def _print_GoldenRatio(self, expr): + return 'GoldenRatio' + + def _print_Heaviside(self, expr): + # Same as _print_Function but uses pargs to suppress default 1/2 for + # 2nd args + return expr.func.__name__ + "(%s)" % self.stringify(expr.pargs, ", ") + + def _print_TribonacciConstant(self, expr): + return 'TribonacciConstant' + + def _print_ImaginaryUnit(self, expr): + return 'I' + + def _print_Infinity(self, expr): + return 'oo' + + def _print_Integral(self, expr): + def _xab_tostr(xab): + if len(xab) == 1: + return self._print(xab[0]) + else: + return self._print((xab[0],) + tuple(xab[1:])) + L = ', '.join([_xab_tostr(l) for l in expr.limits]) + return 'Integral(%s, %s)' % (self._print(expr.function), L) + + def _print_Interval(self, i): + fin = 'Interval{m}({a}, {b})' + a, b, l, r = i.args + if a.is_infinite and b.is_infinite: + m = '' + elif a.is_infinite and not r: + m = '' + elif b.is_infinite and not l: + m = '' + elif not l and not r: + m = '' + elif l and r: + m = '.open' + elif l: + m = '.Lopen' + else: + m = '.Ropen' + return fin.format(**{'a': a, 'b': b, 'm': m}) + + def _print_AccumulationBounds(self, i): + return "AccumBounds(%s, %s)" % (self._print(i.min), + self._print(i.max)) + + def _print_Inverse(self, I): + return "%s**(-1)" % self.parenthesize(I.arg, PRECEDENCE["Pow"]) + + def _print_Lambda(self, obj): + expr = obj.expr + sig = obj.signature + if len(sig) == 1 and sig[0].is_symbol: + sig = sig[0] + return "Lambda(%s, %s)" % (self._print(sig), self._print(expr)) + + def _print_LatticeOp(self, expr): + args = sorted(expr.args, key=default_sort_key) + return expr.func.__name__ + "(%s)" % ", ".join(self._print(arg) for arg in args) + + def _print_Limit(self, expr): + e, z, z0, dir = expr.args + return "Limit(%s, %s, %s, dir='%s')" % tuple(map(self._print, (e, z, z0, dir))) + + + def _print_list(self, expr): + return "[%s]" % self.stringify(expr, ", ") + + def _print_List(self, expr): + return self._print_list(expr) + + def _print_MatrixBase(self, expr): + return expr._format_str(self) + + def _print_MatrixElement(self, expr): + return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \ + + '[%s, %s]' % (self._print(expr.i), self._print(expr.j)) + + def _print_MatrixSlice(self, expr): + def strslice(x, dim): + x = list(x) + if x[2] == 1: + del x[2] + if x[0] == 0: + x[0] = '' + if x[1] == dim: + x[1] = '' + return ':'.join((self._print(arg) for arg in x)) + return (self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) + '[' + + strslice(expr.rowslice, expr.parent.rows) + ', ' + + strslice(expr.colslice, expr.parent.cols) + ']') + + def _print_DeferredVector(self, expr): + return expr.name + + def _print_Mul(self, expr): + + prec = precedence(expr) + + # Check for unevaluated Mul. In this case we need to make sure the + # identities are visible, multiple Rational factors are not combined + # etc so we display in a straight-forward form that fully preserves all + # args and their order. + args = expr.args + if args[0] is S.One or any( + isinstance(a, Number) or + a.is_Pow and all(ai.is_Integer for ai in a.args) + for a in args[1:]): + d, n = sift(args, lambda x: + isinstance(x, Pow) and bool(x.exp.as_coeff_Mul()[0] < 0), + binary=True) + for i, di in enumerate(d): + if di.exp.is_Number: + e = -di.exp + else: + dargs = list(di.exp.args) + dargs[0] = -dargs[0] + e = Mul._from_args(dargs) + d[i] = Pow(di.base, e, evaluate=False) if e - 1 else di.base + + pre = [] + # don't parenthesize first factor if negative + if n and not n[0].is_Add and n[0].could_extract_minus_sign(): + pre = [self._print(n.pop(0))] + + nfactors = pre + [self.parenthesize(a, prec, strict=False) + for a in n] + if not nfactors: + nfactors = ['1'] + + # don't parenthesize first of denominator unless singleton + if len(d) > 1 and d[0].could_extract_minus_sign(): + pre = [self._print(d.pop(0))] + else: + pre = [] + dfactors = pre + [self.parenthesize(a, prec, strict=False) + for a in d] + + n = '*'.join(nfactors) + d = '*'.join(dfactors) + if len(dfactors) > 1: + return '%s/(%s)' % (n, d) + elif dfactors: + return '%s/%s' % (n, d) + return n + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = "-" + else: + sign = "" + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + pow_paren = [] # Will collect all pow with more than one base element and exp = -1 + + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + # Gather args for numerator/denominator + def apow(i): + b, e = i.as_base_exp() + eargs = list(Mul.make_args(e)) + if eargs[0] is S.NegativeOne: + eargs = eargs[1:] + else: + eargs[0] = -eargs[0] + e = Mul._from_args(eargs) + if isinstance(i, Pow): + return i.func(b, e, evaluate=False) + return i.func(e, evaluate=False) + for item in args: + if (item.is_commutative and + isinstance(item, Pow) and + bool(item.exp.as_coeff_Mul()[0] < 0)): + if item.exp is not S.NegativeOne: + b.append(apow(item)) + else: + if (len(item.args[0].args) != 1 and + isinstance(item.base, (Mul, Pow))): + # To avoid situations like #14160 + pow_paren.append(item) + b.append(item.base) + elif item.is_Rational and item is not S.Infinity: + if item.p != 1: + a.append(Rational(item.p)) + if item.q != 1: + b.append(Rational(item.q)) + else: + a.append(item) + + a = a or [S.One] + + a_str = [self.parenthesize(x, prec, strict=False) for x in a] + b_str = [self.parenthesize(x, prec, strict=False) for x in b] + + # To parenthesize Pow with exp = -1 and having more than one Symbol + for item in pow_paren: + if item.base in b: + b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)] + + if not b: + return sign + '*'.join(a_str) + elif len(b) == 1: + return sign + '*'.join(a_str) + "/" + b_str[0] + else: + return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str) + + def _print_MatMul(self, expr): + c, m = expr.as_coeff_mmul() + + sign = "" + if c.is_number: + re, im = c.as_real_imag() + if im.is_zero and re.is_negative: + expr = _keep_coeff(-c, m) + sign = "-" + elif re.is_zero and im.is_negative: + expr = _keep_coeff(-c, m) + sign = "-" + + return sign + '*'.join( + [self.parenthesize(arg, precedence(expr)) for arg in expr.args] + ) + + def _print_ElementwiseApplyFunction(self, expr): + return "{}.({})".format( + expr.function, + self._print(expr.expr), + ) + + def _print_NaN(self, expr): + return 'nan' + + def _print_NegativeInfinity(self, expr): + return '-oo' + + def _print_Order(self, expr): + if not expr.variables or all(p is S.Zero for p in expr.point): + if len(expr.variables) <= 1: + return 'O(%s)' % self._print(expr.expr) + else: + return 'O(%s)' % self.stringify((expr.expr,) + expr.variables, ', ', 0) + else: + return 'O(%s)' % self.stringify(expr.args, ', ', 0) + + def _print_Ordinal(self, expr): + return expr.__str__() + + def _print_Cycle(self, expr): + return expr.__str__() + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + from sympy.utilities.exceptions import sympy_deprecation_warning + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + if not expr.size: + return '()' + # before taking Cycle notation, see if the last element is + # a singleton and move it to the head of the string + s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):] + last = s.rfind('(') + if not last == 0 and ',' not in s[last:]: + s = s[last:] + s[:last] + s = s.replace(',', '') + return s + else: + s = expr.support() + if not s: + if expr.size < 5: + return 'Permutation(%s)' % self._print(expr.array_form) + return 'Permutation([], size=%s)' % self._print(expr.size) + trim = self._print(expr.array_form[:s[-1] + 1]) + ', size=%s' % self._print(expr.size) + use = full = self._print(expr.array_form) + if len(trim) < len(full): + use = trim + return 'Permutation(%s)' % use + + def _print_Subs(self, obj): + expr, old, new = obj.args + if len(obj.point) == 1: + old = old[0] + new = new[0] + return "Subs(%s, %s, %s)" % ( + self._print(expr), self._print(old), self._print(new)) + + def _print_TensorIndex(self, expr): + return expr._print() + + def _print_TensorHead(self, expr): + return expr._print() + + def _print_Tensor(self, expr): + return expr._print() + + def _print_TensMul(self, expr): + # prints expressions like "A(a)", "3*A(a)", "(1+x)*A(a)" + sign, args = expr._get_args_for_traditional_printer() + return sign + "*".join( + [self.parenthesize(arg, precedence(expr)) for arg in args] + ) + + def _print_TensAdd(self, expr): + return expr._print() + + def _print_ArraySymbol(self, expr): + return self._print(expr.name) + + def _print_ArrayElement(self, expr): + return "%s[%s]" % ( + self.parenthesize(expr.name, PRECEDENCE["Func"], True), ", ".join([self._print(i) for i in expr.indices])) + + def _print_PermutationGroup(self, expr): + p = [' %s' % self._print(a) for a in expr.args] + return 'PermutationGroup([\n%s])' % ',\n'.join(p) + + def _print_Pi(self, expr): + return 'pi' + + def _print_PolyRing(self, ring): + return "Polynomial ring in %s over %s with %s order" % \ + (", ".join((self._print(rs) for rs in ring.symbols)), + self._print(ring.domain), self._print(ring.order)) + + def _print_FracField(self, field): + return "Rational function field in %s over %s with %s order" % \ + (", ".join((self._print(fs) for fs in field.symbols)), + self._print(field.domain), self._print(field.order)) + + def _print_FreeGroupElement(self, elm): + return elm.__str__() + + def _print_GaussianElement(self, poly): + return "(%s + %s*I)" % (poly.x, poly.y) + + def _print_PolyElement(self, poly): + return poly.str(self, PRECEDENCE, "%s**%s", "*") + + def _print_FracElement(self, frac): + if frac.denom == 1: + return self._print(frac.numer) + else: + numer = self.parenthesize(frac.numer, PRECEDENCE["Mul"], strict=True) + denom = self.parenthesize(frac.denom, PRECEDENCE["Atom"], strict=True) + return numer + "/" + denom + + def _print_Poly(self, expr): + ATOM_PREC = PRECEDENCE["Atom"] - 1 + terms, gens = [], [ self.parenthesize(s, ATOM_PREC) for s in expr.gens ] + + for monom, coeff in expr.terms(): + s_monom = [] + + for i, e in enumerate(monom): + if e > 0: + if e == 1: + s_monom.append(gens[i]) + else: + s_monom.append(gens[i] + "**%d" % e) + + s_monom = "*".join(s_monom) + + if coeff.is_Add: + if s_monom: + s_coeff = "(" + self._print(coeff) + ")" + else: + s_coeff = self._print(coeff) + else: + if s_monom: + if coeff is S.One: + terms.extend(['+', s_monom]) + continue + + if coeff is S.NegativeOne: + terms.extend(['-', s_monom]) + continue + + s_coeff = self._print(coeff) + + if not s_monom: + s_term = s_coeff + else: + s_term = s_coeff + "*" + s_monom + + if s_term.startswith('-'): + terms.extend(['-', s_term[1:]]) + else: + terms.extend(['+', s_term]) + + if terms[0] in ('-', '+'): + modifier = terms.pop(0) + + if modifier == '-': + terms[0] = '-' + terms[0] + + format = expr.__class__.__name__ + "(%s, %s" + + from sympy.polys.polyerrors import PolynomialError + + try: + format += ", modulus=%s" % expr.get_modulus() + except PolynomialError: + format += ", domain='%s'" % expr.get_domain() + + format += ")" + + for index, item in enumerate(gens): + if len(item) > 2 and (item[:1] == "(" and item[len(item) - 1:] == ")"): + gens[index] = item[1:len(item) - 1] + + return format % (' '.join(terms), ', '.join(gens)) + + def _print_UniversalSet(self, p): + return 'UniversalSet' + + def _print_AlgebraicNumber(self, expr): + if expr.is_aliased: + return self._print(expr.as_poly().as_expr()) + else: + return self._print(expr.as_expr()) + + def _print_Pow(self, expr, rational=False): + """Printing helper function for ``Pow`` + + Parameters + ========== + + rational : bool, optional + If ``True``, it will not attempt printing ``sqrt(x)`` or + ``x**S.Half`` as ``sqrt``, and will use ``x**(1/2)`` + instead. + + See examples for additional details + + Examples + ======== + + >>> from sympy import sqrt, StrPrinter + >>> from sympy.abc import x + + How ``rational`` keyword works with ``sqrt``: + + >>> printer = StrPrinter() + >>> printer._print_Pow(sqrt(x), rational=True) + 'x**(1/2)' + >>> printer._print_Pow(sqrt(x), rational=False) + 'sqrt(x)' + >>> printer._print_Pow(1/sqrt(x), rational=True) + 'x**(-1/2)' + >>> printer._print_Pow(1/sqrt(x), rational=False) + '1/sqrt(x)' + + Notes + ===== + + ``sqrt(x)`` is canonicalized as ``Pow(x, S.Half)`` in SymPy, + so there is no need of defining a separate printer for ``sqrt``. + Instead, it should be handled here as well. + """ + PREC = precedence(expr) + + if expr.exp is S.Half and not rational: + return "sqrt(%s)" % self._print(expr.base) + + if expr.is_commutative: + if -expr.exp is S.Half and not rational: + # Note: Don't test "expr.exp == -S.Half" here, because that will + # match -0.5, which we don't want. + return "%s/sqrt(%s)" % tuple((self._print(arg) for arg in (S.One, expr.base))) + if expr.exp is -S.One: + # Similarly to the S.Half case, don't test with "==" here. + return '%s/%s' % (self._print(S.One), + self.parenthesize(expr.base, PREC, strict=False)) + + e = self.parenthesize(expr.exp, PREC, strict=False) + if self.printmethod == '_sympyrepr' and expr.exp.is_Rational and expr.exp.q != 1: + # the parenthesized exp should be '(Rational(a, b))' so strip parens, + # but just check to be sure. + if e.startswith('(Rational'): + return '%s**%s' % (self.parenthesize(expr.base, PREC, strict=False), e[1:-1]) + return '%s**%s' % (self.parenthesize(expr.base, PREC, strict=False), e) + + def _print_UnevaluatedExpr(self, expr): + return self._print(expr.args[0]) + + def _print_MatPow(self, expr): + PREC = precedence(expr) + return '%s**%s' % (self.parenthesize(expr.base, PREC, strict=False), + self.parenthesize(expr.exp, PREC, strict=False)) + + def _print_Integer(self, expr): + if self._settings.get("sympy_integers", False): + return "S(%s)" % (expr) + return str(expr.p) + + def _print_Integers(self, expr): + return 'Integers' + + def _print_Naturals(self, expr): + return 'Naturals' + + def _print_Naturals0(self, expr): + return 'Naturals0' + + def _print_Rationals(self, expr): + return 'Rationals' + + def _print_Reals(self, expr): + return 'Reals' + + def _print_Complexes(self, expr): + return 'Complexes' + + def _print_EmptySet(self, expr): + return 'EmptySet' + + def _print_EmptySequence(self, expr): + return 'EmptySequence' + + def _print_int(self, expr): + return str(expr) + + def _print_mpz(self, expr): + return str(expr) + + def _print_Rational(self, expr): + if expr.q == 1: + return str(expr.p) + else: + if self._settings.get("sympy_integers", False): + return "S(%s)/%s" % (expr.p, expr.q) + return "%s/%s" % (expr.p, expr.q) + + def _print_PythonRational(self, expr): + if expr.q == 1: + return str(expr.p) + else: + return "%d/%d" % (expr.p, expr.q) + + def _print_Fraction(self, expr): + if expr.denominator == 1: + return str(expr.numerator) + else: + return "%s/%s" % (expr.numerator, expr.denominator) + + def _print_mpq(self, expr): + if expr.denominator == 1: + return str(expr.numerator) + else: + return "%s/%s" % (expr.numerator, expr.denominator) + + def _print_Float(self, expr): + prec = expr._prec + dps = self._settings.get('dps', None) + if dps is None: + dps = 0 if prec < 5 else prec_to_dps(expr._prec) + if self._settings["full_prec"] is True: + strip = False + elif self._settings["full_prec"] is False: + strip = True + elif self._settings["full_prec"] == "auto": + strip = self._print_level > 1 + low = self._settings["min"] if "min" in self._settings else None + high = self._settings["max"] if "max" in self._settings else None + rv = mlib_to_str(expr._mpf_, dps, strip_zeros=strip, min_fixed=low, max_fixed=high) + if rv.startswith('-.0'): + rv = '-0.' + rv[3:] + elif rv.startswith('.0'): + rv = '0.' + rv[2:] + rv = rv.removeprefix('+') # e.g., +inf -> inf + return rv + + def _print_Relational(self, expr): + + charmap = { + "==": "Eq", + "!=": "Ne", + ":=": "Assignment", + '+=': "AddAugmentedAssignment", + "-=": "SubAugmentedAssignment", + "*=": "MulAugmentedAssignment", + "/=": "DivAugmentedAssignment", + "%=": "ModAugmentedAssignment", + } + + if expr.rel_op in charmap: + return '%s(%s, %s)' % (charmap[expr.rel_op], self._print(expr.lhs), + self._print(expr.rhs)) + + return '%s %s %s' % (self.parenthesize(expr.lhs, precedence(expr)), + self._relationals.get(expr.rel_op) or expr.rel_op, + self.parenthesize(expr.rhs, precedence(expr))) + + def _print_ComplexRootOf(self, expr): + return "CRootOf(%s, %d)" % (self._print_Add(expr.expr, order='lex'), + expr.index) + + def _print_RootSum(self, expr): + args = [self._print_Add(expr.expr, order='lex')] + + if expr.fun is not S.IdentityFunction: + args.append(self._print(expr.fun)) + + return "RootSum(%s)" % ", ".join(args) + + def _print_GroebnerBasis(self, basis): + cls = basis.__class__.__name__ + + exprs = [self._print_Add(arg, order=basis.order) for arg in basis.exprs] + exprs = "[%s]" % ", ".join(exprs) + + gens = [ self._print(gen) for gen in basis.gens ] + domain = "domain='%s'" % self._print(basis.domain) + order = "order='%s'" % self._print(basis.order) + + args = [exprs] + gens + [domain, order] + + return "%s(%s)" % (cls, ", ".join(args)) + + def _print_set(self, s): + items = sorted(s, key=default_sort_key) + + args = ', '.join(self._print(item) for item in items) + if not args: + return "set()" + return '{%s}' % args + + def _print_FiniteSet(self, s): + from sympy.sets.sets import FiniteSet + items = sorted(s, key=default_sort_key) + + args = ', '.join(self._print(item) for item in items) + if any(item.has(FiniteSet) for item in items): + return 'FiniteSet({})'.format(args) + return '{{{}}}'.format(args) + + def _print_Partition(self, s): + items = sorted(s, key=default_sort_key) + + args = ', '.join(self._print(arg) for arg in items) + return 'Partition({})'.format(args) + + def _print_frozenset(self, s): + if not s: + return "frozenset()" + return "frozenset(%s)" % self._print_set(s) + + def _print_Sum(self, expr): + def _xab_tostr(xab): + if len(xab) == 1: + return self._print(xab[0]) + else: + return self._print((xab[0],) + tuple(xab[1:])) + L = ', '.join([_xab_tostr(l) for l in expr.limits]) + return 'Sum(%s, %s)' % (self._print(expr.function), L) + + def _print_Symbol(self, expr): + return expr.name + _print_MatrixSymbol = _print_Symbol + _print_RandomSymbol = _print_Symbol + + def _print_Identity(self, expr): + return "I" + + def _print_ZeroMatrix(self, expr): + return "0" + + def _print_OneMatrix(self, expr): + return "1" + + def _print_Predicate(self, expr): + return "Q.%s" % expr.name + + def _print_str(self, expr): + return str(expr) + + def _print_tuple(self, expr): + if len(expr) == 1: + return "(%s,)" % self._print(expr[0]) + else: + return "(%s)" % self.stringify(expr, ", ") + + def _print_Tuple(self, expr): + return self._print_tuple(expr) + + def _print_Transpose(self, T): + return "%s.T" % self.parenthesize(T.arg, PRECEDENCE["Pow"]) + + def _print_Uniform(self, expr): + return "Uniform(%s, %s)" % (self._print(expr.a), self._print(expr.b)) + + def _print_Quantity(self, expr): + if self._settings.get("abbrev", False): + return "%s" % expr.abbrev + return "%s" % expr.name + + def _print_Quaternion(self, expr): + s = [self.parenthesize(i, PRECEDENCE["Mul"], strict=True) for i in expr.args] + a = [s[0]] + [i+"*"+j for i, j in zip(s[1:], "ijk")] + return " + ".join(a) + + def _print_Dimension(self, expr): + return str(expr) + + def _print_Wild(self, expr): + return expr.name + '_' + + def _print_WildFunction(self, expr): + return expr.name + '_' + + def _print_WildDot(self, expr): + return expr.name + + def _print_WildPlus(self, expr): + return expr.name + + def _print_WildStar(self, expr): + return expr.name + + def _print_Zero(self, expr): + if self._settings.get("sympy_integers", False): + return "S(0)" + return self._print_Integer(Integer(0)) + + def _print_DMP(self, p): + cls = p.__class__.__name__ + rep = self._print(p.to_list()) + dom = self._print(p.dom) + + return "%s(%s, %s)" % (cls, rep, dom) + + def _print_DMF(self, expr): + cls = expr.__class__.__name__ + num = self._print(expr.num) + den = self._print(expr.den) + dom = self._print(expr.dom) + + return "%s(%s, %s, %s)" % (cls, num, den, dom) + + def _print_Object(self, obj): + return 'Object("%s")' % obj.name + + def _print_IdentityMorphism(self, morphism): + return 'IdentityMorphism(%s)' % morphism.domain + + def _print_NamedMorphism(self, morphism): + return 'NamedMorphism(%s, %s, "%s")' % \ + (morphism.domain, morphism.codomain, morphism.name) + + def _print_Category(self, category): + return 'Category("%s")' % category.name + + def _print_Manifold(self, manifold): + return manifold.name.name + + def _print_Patch(self, patch): + return patch.name.name + + def _print_CoordSystem(self, coords): + return coords.name.name + + def _print_BaseScalarField(self, field): + return field._coord_sys.symbols[field._index].name + + def _print_BaseVectorField(self, field): + return 'e_%s' % field._coord_sys.symbols[field._index].name + + def _print_Differential(self, diff): + field = diff._form_field + if hasattr(field, '_coord_sys'): + return 'd%s' % field._coord_sys.symbols[field._index].name + else: + return 'd(%s)' % self._print(field) + + def _print_Tr(self, expr): + #TODO : Handle indices + return "%s(%s)" % ("Tr", self._print(expr.args[0])) + + def _print_Str(self, s): + return self._print(s.name) + + def _print_AppliedBinaryRelation(self, expr): + rel = expr.function + return '%s(%s, %s)' % (self._print(rel), + self._print(expr.lhs), + self._print(expr.rhs)) + + +@print_function(StrPrinter) +def sstr(expr, **settings): + """Returns the expression as a string. + + For large expressions where speed is a concern, use the setting + order='none'. If abbrev=True setting is used then units are printed in + abbreviated form. + + Examples + ======== + + >>> from sympy import symbols, Eq, sstr + >>> a, b = symbols('a b') + >>> sstr(Eq(a + b, 0)) + 'Eq(a + b, 0)' + """ + + p = StrPrinter(settings) + s = p.doprint(expr) + + return s + + +class StrReprPrinter(StrPrinter): + """(internal) -- see sstrrepr""" + + def _print_str(self, s): + return repr(s) + + def _print_Str(self, s): + # Str does not to be printed same as str here + return "%s(%s)" % (s.__class__.__name__, self._print(s.name)) + +@print_function(StrReprPrinter) +def sstrrepr(expr, **settings): + """return expr in mixed str/repr form + + i.e. strings are returned in repr form with quotes, and everything else + is returned in str form. + + This function could be useful for hooking into sys.displayhook + """ + + p = StrReprPrinter(settings) + s = p.doprint(expr) + + return s diff --git a/phivenv/Lib/site-packages/sympy/printing/tableform.py b/phivenv/Lib/site-packages/sympy/printing/tableform.py new file mode 100644 index 0000000000000000000000000000000000000000..4a84ef96ae92517a6ec01ca9db1a13e9afa67093 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tableform.py @@ -0,0 +1,366 @@ +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import SympifyError + +from types import FunctionType + + +class TableForm: + r""" + Create a nice table representation of data. + + Examples + ======== + + >>> from sympy import TableForm + >>> t = TableForm([[5, 7], [4, 2], [10, 3]]) + >>> print(t) + 5 7 + 4 2 + 10 3 + + You can use the SymPy's printing system to produce tables in any + format (ascii, latex, html, ...). + + >>> print(t.as_latex()) + \begin{tabular}{l l} + $5$ & $7$ \\ + $4$ & $2$ \\ + $10$ & $3$ \\ + \end{tabular} + + """ + + def __init__(self, data, **kwarg): + """ + Creates a TableForm. + + Parameters: + + data ... + 2D data to be put into the table; data can be + given as a Matrix + + headings ... + gives the labels for rows and columns: + + Can be a single argument that applies to both + dimensions: + + - None ... no labels + - "automatic" ... labels are 1, 2, 3, ... + + Can be a list of labels for rows and columns: + The labels for each dimension can be given + as None, "automatic", or [l1, l2, ...] e.g. + ["automatic", None] will number the rows + + [default: None] + + alignments ... + alignment of the columns with: + + - "left" or "<" + - "center" or "^" + - "right" or ">" + + When given as a single value, the value is used for + all columns. The row headings (if given) will be + right justified unless an explicit alignment is + given for it and all other columns. + + [default: "left"] + + formats ... + a list of format strings or functions that accept + 3 arguments (entry, row number, col number) and + return a string for the table entry. (If a function + returns None then the _print method will be used.) + + wipe_zeros ... + Do not show zeros in the table. + + [default: True] + + pad ... + the string to use to indicate a missing value (e.g. + elements that are None or those that are missing + from the end of a row (i.e. any row that is shorter + than the rest is assumed to have missing values). + When None, nothing will be shown for values that + are missing from the end of a row; values that are + None, however, will be shown. + + [default: None] + + Examples + ======== + + >>> from sympy import TableForm, Symbol + >>> TableForm([[5, 7], [4, 2], [10, 3]]) + 5 7 + 4 2 + 10 3 + >>> TableForm([list('.'*i) for i in range(1, 4)], headings='automatic') + | 1 2 3 + --------- + 1 | . + 2 | . . + 3 | . . . + >>> TableForm([[Symbol('.'*(j if not i%2 else 1)) for i in range(3)] + ... for j in range(4)], alignments='rcl') + . + . . . + .. . .. + ... . ... + """ + from sympy.matrices.dense import Matrix + + # We only support 2D data. Check the consistency: + if isinstance(data, Matrix): + data = data.tolist() + _h = len(data) + + # fill out any short lines + pad = kwarg.get('pad', None) + ok_None = False + if pad is None: + pad = " " + ok_None = True + pad = Symbol(pad) + _w = max(len(line) for line in data) + for i, line in enumerate(data): + if len(line) != _w: + line.extend([pad]*(_w - len(line))) + for j, lj in enumerate(line): + if lj is None: + if not ok_None: + lj = pad + else: + try: + lj = S(lj) + except SympifyError: + lj = Symbol(str(lj)) + line[j] = lj + data[i] = line + _lines = Tuple(*[Tuple(*d) for d in data]) + + headings = kwarg.get("headings", [None, None]) + if headings == "automatic": + _headings = [range(1, _h + 1), range(1, _w + 1)] + else: + h1, h2 = headings + if h1 == "automatic": + h1 = range(1, _h + 1) + if h2 == "automatic": + h2 = range(1, _w + 1) + _headings = [h1, h2] + + allow = ('l', 'r', 'c') + alignments = kwarg.get("alignments", "l") + + def _std_align(a): + a = a.strip().lower() + if len(a) > 1: + return {'left': 'l', 'right': 'r', 'center': 'c'}.get(a, a) + else: + return {'<': 'l', '>': 'r', '^': 'c'}.get(a, a) + std_align = _std_align(alignments) + if std_align in allow: + _alignments = [std_align]*_w + else: + _alignments = [] + for a in alignments: + std_align = _std_align(a) + _alignments.append(std_align) + if std_align not in ('l', 'r', 'c'): + raise ValueError('alignment "%s" unrecognized' % + alignments) + if _headings[0] and len(_alignments) == _w + 1: + _head_align = _alignments[0] + _alignments = _alignments[1:] + else: + _head_align = 'r' + if len(_alignments) != _w: + raise ValueError( + 'wrong number of alignments: expected %s but got %s' % + (_w, len(_alignments))) + + _column_formats = kwarg.get("formats", [None]*_w) + + _wipe_zeros = kwarg.get("wipe_zeros", True) + + self._w = _w + self._h = _h + self._lines = _lines + self._headings = _headings + self._head_align = _head_align + self._alignments = _alignments + self._column_formats = _column_formats + self._wipe_zeros = _wipe_zeros + + def __repr__(self): + from .str import sstr + return sstr(self, order=None) + + def __str__(self): + from .str import sstr + return sstr(self, order=None) + + def as_matrix(self): + """Returns the data of the table in Matrix form. + + Examples + ======== + + >>> from sympy import TableForm + >>> t = TableForm([[5, 7], [4, 2], [10, 3]], headings='automatic') + >>> t + | 1 2 + -------- + 1 | 5 7 + 2 | 4 2 + 3 | 10 3 + >>> t.as_matrix() + Matrix([ + [ 5, 7], + [ 4, 2], + [10, 3]]) + """ + from sympy.matrices.dense import Matrix + return Matrix(self._lines) + + def as_str(self): + # XXX obsolete ? + return str(self) + + def as_latex(self): + from .latex import latex + return latex(self) + + def _sympystr(self, p): + """ + Returns the string representation of 'self'. + + Examples + ======== + + >>> from sympy import TableForm + >>> t = TableForm([[5, 7], [4, 2], [10, 3]]) + >>> s = t.as_str() + + """ + column_widths = [0] * self._w + lines = [] + for line in self._lines: + new_line = [] + for i in range(self._w): + # Format the item somehow if needed: + s = str(line[i]) + if self._wipe_zeros and (s == "0"): + s = " " + w = len(s) + if w > column_widths[i]: + column_widths[i] = w + new_line.append(s) + lines.append(new_line) + + # Check heading: + if self._headings[0]: + self._headings[0] = [str(x) for x in self._headings[0]] + _head_width = max(len(x) for x in self._headings[0]) + + if self._headings[1]: + new_line = [] + for i in range(self._w): + # Format the item somehow if needed: + s = str(self._headings[1][i]) + w = len(s) + if w > column_widths[i]: + column_widths[i] = w + new_line.append(s) + self._headings[1] = new_line + + format_str = [] + + def _align(align, w): + return '%%%s%ss' % ( + ("-" if align == "l" else ""), + str(w)) + format_str = [_align(align, w) for align, w in + zip(self._alignments, column_widths)] + if self._headings[0]: + format_str.insert(0, _align(self._head_align, _head_width)) + format_str.insert(1, '|') + format_str = ' '.join(format_str) + '\n' + + s = [] + if self._headings[1]: + d = self._headings[1] + if self._headings[0]: + d = [""] + d + first_line = format_str % tuple(d) + s.append(first_line) + s.append("-" * (len(first_line) - 1) + "\n") + for i, line in enumerate(lines): + d = [l if self._alignments[j] != 'c' else + l.center(column_widths[j]) for j, l in enumerate(line)] + if self._headings[0]: + l = self._headings[0][i] + l = (l if self._head_align != 'c' else + l.center(_head_width)) + d = [l] + d + s.append(format_str % tuple(d)) + return ''.join(s)[:-1] # don't include trailing newline + + def _latex(self, printer): + """ + Returns the string representation of 'self'. + """ + # Check heading: + if self._headings[1]: + new_line = [] + for i in range(self._w): + # Format the item somehow if needed: + new_line.append(str(self._headings[1][i])) + self._headings[1] = new_line + + alignments = [] + if self._headings[0]: + self._headings[0] = [str(x) for x in self._headings[0]] + alignments = [self._head_align] + alignments.extend(self._alignments) + + s = r"\begin{tabular}{" + " ".join(alignments) + "}\n" + + if self._headings[1]: + d = self._headings[1] + if self._headings[0]: + d = [""] + d + first_line = " & ".join(d) + r" \\" + "\n" + s += first_line + s += r"\hline" + "\n" + for i, line in enumerate(self._lines): + d = [] + for j, x in enumerate(line): + if self._wipe_zeros and (x in (0, "0")): + d.append(" ") + continue + f = self._column_formats[j] + if f: + if isinstance(f, FunctionType): + v = f(x, i, j) + if v is None: + v = printer._print(x) + else: + v = f % x + d.append(v) + else: + v = printer._print(x) + d.append("$%s$" % v) + if self._headings[0]: + d = [self._headings[0][i]] + d + s += " & ".join(d) + r" \\" + "\n" + s += r"\end{tabular}" + return s diff --git a/phivenv/Lib/site-packages/sympy/printing/tensorflow.py b/phivenv/Lib/site-packages/sympy/printing/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..78b0df62b611f336468769e4cee1695bc068eee9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tensorflow.py @@ -0,0 +1,224 @@ +import sympy.codegen +import sympy.codegen.cfunctions +from sympy.external.importtools import version_tuple +from collections.abc import Iterable + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.codegen.cfunctions import Sqrt +from sympy.external import import_module +from sympy.printing.precedence import PRECEDENCE +from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter +import sympy + +tensorflow = import_module('tensorflow') + +class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter): + """ + Tensorflow printer which handles vectorized piecewise functions, + logical operators, max/min, and relational operators. + """ + printmethod = "_tensorflowcode" + + mapping = { + sympy.Abs: "tensorflow.math.abs", + sympy.sign: "tensorflow.math.sign", + + # XXX May raise error for ints. + sympy.ceiling: "tensorflow.math.ceil", + sympy.floor: "tensorflow.math.floor", + sympy.log: "tensorflow.math.log", + sympy.exp: "tensorflow.math.exp", + Sqrt: "tensorflow.math.sqrt", + sympy.cos: "tensorflow.math.cos", + sympy.acos: "tensorflow.math.acos", + sympy.sin: "tensorflow.math.sin", + sympy.asin: "tensorflow.math.asin", + sympy.tan: "tensorflow.math.tan", + sympy.atan: "tensorflow.math.atan", + sympy.atan2: "tensorflow.math.atan2", + # XXX Also may give NaN for complex results. + sympy.cosh: "tensorflow.math.cosh", + sympy.acosh: "tensorflow.math.acosh", + sympy.sinh: "tensorflow.math.sinh", + sympy.asinh: "tensorflow.math.asinh", + sympy.tanh: "tensorflow.math.tanh", + sympy.atanh: "tensorflow.math.atanh", + + sympy.re: "tensorflow.math.real", + sympy.im: "tensorflow.math.imag", + sympy.arg: "tensorflow.math.angle", + + # XXX May raise error for ints and complexes + sympy.erf: "tensorflow.math.erf", + sympy.loggamma: "tensorflow.math.lgamma", + + sympy.Eq: "tensorflow.math.equal", + sympy.Ne: "tensorflow.math.not_equal", + sympy.StrictGreaterThan: "tensorflow.math.greater", + sympy.StrictLessThan: "tensorflow.math.less", + sympy.LessThan: "tensorflow.math.less_equal", + sympy.GreaterThan: "tensorflow.math.greater_equal", + + sympy.And: "tensorflow.math.logical_and", + sympy.Or: "tensorflow.math.logical_or", + sympy.Not: "tensorflow.math.logical_not", + sympy.Max: "tensorflow.math.maximum", + sympy.Min: "tensorflow.math.minimum", + + # Matrices + sympy.MatAdd: "tensorflow.math.add", + sympy.HadamardProduct: "tensorflow.math.multiply", + sympy.Trace: "tensorflow.linalg.trace", + + # XXX May raise error for integer matrices. + sympy.Determinant : "tensorflow.linalg.det", + } + + _default_settings = dict( + AbstractPythonCodePrinter._default_settings, + tensorflow_version=None + ) + + def __init__(self, settings=None): + super().__init__(settings) + + version = self._settings['tensorflow_version'] + if version is None and tensorflow: + version = tensorflow.__version__ + self.tensorflow_version = version + + def _print_Function(self, expr): + op = self.mapping.get(type(expr), None) + if op is None: + return super()._print_Basic(expr) + children = [self._print(arg) for arg in expr.args] + if len(children) == 1: + return "%s(%s)" % ( + self._module_format(op), + children[0] + ) + else: + return self._expand_fold_binary_op(op, children) + + _print_Expr = _print_Function + _print_Application = _print_Function + _print_MatrixExpr = _print_Function + # TODO: a better class structure would avoid this mess: + _print_Relational = _print_Function + _print_Not = _print_Function + _print_And = _print_Function + _print_Or = _print_Function + _print_HadamardProduct = _print_Function + _print_Trace = _print_Function + _print_Determinant = _print_Function + + def _print_Inverse(self, expr): + op = self._module_format('tensorflow.linalg.inv') + return "{}({})".format(op, self._print(expr.arg)) + + def _print_Transpose(self, expr): + version = self.tensorflow_version + if version and version_tuple(version) < version_tuple('1.14'): + op = self._module_format('tensorflow.matrix_transpose') + else: + op = self._module_format('tensorflow.linalg.matrix_transpose') + return "{}({})".format(op, self._print(expr.arg)) + + def _print_Derivative(self, expr): + variables = expr.variables + if any(isinstance(i, Iterable) for i in variables): + raise NotImplementedError("derivation by multiple variables is not supported") + def unfold(expr, args): + if not args: + return self._print(expr) + return "%s(%s, %s)[0]" % ( + self._module_format("tensorflow.gradients"), + unfold(expr, args[:-1]), + self._print(args[-1]), + ) + return unfold(expr.expr, variables) + + def _print_Piecewise(self, expr): + version = self.tensorflow_version + if version and version_tuple(version) < version_tuple('1.0'): + tensorflow_piecewise = "tensorflow.select" + else: + tensorflow_piecewise = "tensorflow.where" + + from sympy.functions.elementary.piecewise import Piecewise + e, cond = expr.args[0].args + if len(expr.args) == 1: + return '{}({}, {}, {})'.format( + self._module_format(tensorflow_piecewise), + self._print(cond), + self._print(e), + 0) + + return '{}({}, {}, {})'.format( + self._module_format(tensorflow_piecewise), + self._print(cond), + self._print(e), + self._print(Piecewise(*expr.args[1:]))) + + def _print_Pow(self, expr): + # XXX May raise error for + # int**float or int**complex or float**complex + base, exp = expr.args + if expr.exp == S.Half: + return "{}({})".format( + self._module_format("tensorflow.math.sqrt"), self._print(base)) + return "{}({}, {})".format( + self._module_format("tensorflow.math.pow"), + self._print(base), self._print(exp)) + + def _print_MatrixBase(self, expr): + tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant" + data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]" + return "%s(%s)" % ( + self._module_format(tensorflow_f), + data, + ) + + def _print_MatMul(self, expr): + from sympy.matrices.expressions import MatrixExpr + mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)] + args = [arg for arg in expr.args if arg not in mat_args] + if args: + return "%s*%s" % ( + self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]), + self._expand_fold_binary_op( + "tensorflow.linalg.matmul", mat_args) + ) + else: + return self._expand_fold_binary_op( + "tensorflow.linalg.matmul", mat_args) + + def _print_MatPow(self, expr): + return self._expand_fold_binary_op( + "tensorflow.linalg.matmul", [expr.base]*expr.exp) + + def _print_CodeBlock(self, expr): + # TODO: is this necessary? + ret = [] + for subexpr in expr.args: + ret.append(self._print(subexpr)) + return "\n".join(ret) + + def _print_isnan(self, exp): + return f'tensorflow.math.is_nan({self._print(*exp.args)})' + + def _print_isinf(self, exp): + return f'tensorflow.math.is_inf({self._print(*exp.args)})' + + _module = "tensorflow" + _einsum = "linalg.einsum" + _add = "math.add" + _transpose = "transpose" + _ones = "ones" + _zeros = "zeros" + + +def tensorflow_code(expr, **settings): + printer = TensorflowPrinter(settings) + return printer.doprint(expr) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__init__.py b/phivenv/Lib/site-packages/sympy/printing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f76f145b09282ae9775f1279a5684f6fbc0d4678 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f9cc17f577bec341e515cbf8cc6e26d5f487261 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..593ee425a1ba42b009614bfd1e103fea760e6827 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74de4ad8a275c3e4e93d535624e5778cf130cb53 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0f2f26fa95c8e80bef207d08ddf4f90603e7912 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28929e459fa142395b058bfbcbae2a7385c1953a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d11ffebd24ad99a45174d0145fb64fb76c1580f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f90f709b30562fa494523a6821c095b2f6e079bb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..189fbceaa32b7d069ee7ea039fd5eda6ee6baf62 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bd4355f8fcc8fe5f5831c195c3f7461b1c661eb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22620a83b22581d00a352c013fdaf1f068fa435 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fbfe331b99a0a2d936d42e026c65a2039dc266f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..674d0310e4f9f281b8e316974bb4066bef51f643 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..079a93a7cd22265a613a4931490de089bca69f36 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db3f47827a0678b8f0ca825fe5e6c7fd825e776a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efe4737363680f48234df360d28cf8ec34daa181 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..408a05c7d531c66c8dad3a55e60a895c0bd26303 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3575c6ac9f98597eea1ac15d026910039fd54429 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_mathml.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_mathml.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49ab99a1a89dbd2fae1632bc691c407aa36ab3e7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_mathml.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50cd4c7561e80cb977e84602f38bdc9bf0298bdc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2621bfb1829cdea434ef2b86ada667032107fbb8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa3fa3db947f6f7a1cd4a25748b8c4e501c0df21 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cbf4dc792864e94c6311263a5c4b354213bead1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda30d0420df06f50262d842dc30959a4ca6bbc0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f26edbcf2d866cbce7af03482530a2ddb545e77 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08f23b8797792647f30fa19279777110b9ebce7a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dc1cac5aedb5d5853ab4f50c13b37bc878abf9c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..435eb8d5c4344d893daeab5c8194ab06d00f344f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72b5a305f992260a37191bff02778e1e14e091a8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6175e5fd4d9566b55b903e640b5154689dd8d90 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a2a61430d4999b0eff321f821936d7bd0a9670 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..854bffd6ab03bca6fbfc6fea1674646387b82501 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b1c287afe4b09d3e96051c3d7ffc3290b685c17 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_torch.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_torch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53dbe9e04e03bef04d46439e1064e5b3572605b4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_torch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1554b5f19c63ad07e22b9d5673e47d3087d1be05 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_aesaracode.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_aesaracode.py new file mode 100644 index 0000000000000000000000000000000000000000..13308af65b382e77de33302bcd75344d2b00adbf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_aesaracode.py @@ -0,0 +1,633 @@ +""" +Important note on tests in this module - the Aesara printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the aesara_code_ and +aesara_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +from sympy.utilities.exceptions import ignore_warnings + + +aesaralogger = logging.getLogger('aesara.configdefaults') +aesaralogger.setLevel(logging.CRITICAL) +aesara = import_module('aesara') +aesaralogger.setLevel(logging.WARNING) + + +if aesara: + import numpy as np + aet = aesara.tensor + from aesara.scalar.basic import ScalarType + from aesara.graph.basic import Variable + from aesara.tensor.var import TensorVariable + from aesara.tensor.elemwise import Elemwise, DimShuffle + from aesara.tensor.math import Dot + + from sympy.printing.aesaracode import true_divide + + xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.aesaracode import (aesara_code, dim_handling, + aesara_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def aesara_code_(expr, **kwargs): + """ Wrapper for aesara_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return aesara_code(expr, **kwargs) + +def aesara_function_(inputs, outputs, **kwargs): + """ Wrapper for aesara_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return aesara_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Aesara Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + outs = list(map(aesara_code_, exprs)) + ins = list(aesara.graph.basic.graph_inputs(outs)) + ins, outs = aesara.graph.basic.clone(ins, outs) + return aesara.graph.fg.FunctionGraph(ins, outs) + + +def aesara_simplify(fgraph): + """ Simplify a Aesara Computation. + + Parameters + ========== + fgraph : aesara.graph.fg.FunctionGraph + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + mode = aesara.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.rewrite(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Aesara objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = aesara.printing.debugprint(a, file='str') + bstr = aesara.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'aesara.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Aesara + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, aesara_code_(x)) + assert theq(yt, aesara_code_(y)) + assert theq(zt, aesara_code_(z)) + assert theq(Xt, aesara_code_(X)) + assert theq(Yt, aesara_code_(Y)) + assert theq(Zt, aesara_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a aesara variable. """ + xx = aesara_code_(x) + assert isinstance(xx, Variable) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = aesara_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a aesara variable. """ + XX = aesara_code_(X) + assert isinstance(XX, TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + aesara_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = aesara_code_(f_t) + assert isinstance(ftt, TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = aesara_code_(expr) + assert comp.owner.op == aesara.tensor.add + +def test_trig(): + assert theq(aesara_code_(sy.sin(x)), aet.sin(xt)) + assert theq(aesara_code_(sy.tan(x)), aet.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = aesara_code_(expr) + expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = aesara_code_(expr) + assert isinstance(expr_t.owner.op, Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(aesara_code_(X.T).owner.op, DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(aesara_code_(expr).owner.op, Elemwise) + + +def test_Rationals(): + assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3)) + assert theq(aesara_code_(S.Half), true_divide(1, 2)) + +def test_Integers(): + assert aesara_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert aesara_code_(sy.factorial(n)) + +def test_Derivative(): + with ignore_warnings(UserWarning): + simp = lambda expr: aesara_simplify(fgraph_of(expr)) + assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(aesara.grad(aet.sin(xt), xt))) + + +def test_aesara_function_simple(): + """ Test aesara_function() with single output. """ + f = aesara_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_aesara_function_multi(): + """ Test aesara_function() with multiple outputs. """ + f = aesara_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_aesara_function_numpy(): + """ Test aesara_function() vs Numpy implementation. """ + f = aesara_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_aesara_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = aesara_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_aesara_function_kwargs(): + """ + Test passing additional kwargs from aesara_function() to aesara.function(). + """ + import numpy as np + f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_aesara_function_scalar(): + """ Test the "scalar" argument to aesara_function(). """ + from aesara.compile.function.types import Function + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the aesara_function attribute is set whether wrapped or not + assert isinstance(f.aesara_function, Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.aesara_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_aesara_function_bad_kwarg(): + """ + Passing an unknown keyword argument to aesara_function() should raise an + exception. + """ + raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = aesara_code_(Y, cache=cache) + + s = ScalarType('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache) + # == doesn't work in Aesara like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].data == i for i in range(1, 7)) + + k = sy.Symbol('k') + aesara_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = aesara_code_(Block) + solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)), + aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = aesara_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + from aesara.tensor.basic import Join + + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = aesara_code_(X) + assert isinstance(tX, TensorVariable) + assert isinstance(tX.owner.op, Join) + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = aesara_code_(s1, cache=cache) + + # Test hit with same instance + assert aesara_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert aesara_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert aesara_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.aesaracode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = aesara_code(s) + assert aesara_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = aesara_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert aesara_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = aesara_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + aesara_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = aesara_code_(expr) + + # Iterate through variables in the Aesara computational graph that the + # printed expression depends on + seen = set() + for v in aesara.graph.basic.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, aesara.graph.basic.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = aesara_code_(expr) + assert result.owner.op == aet.switch + + expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = aesara_code_(expr) + expected = aet.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = aesara_code_(expr) + expected = aet.switch(aet.and_(xt>0,xt<2), 0, \ + aet.switch(aet.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt)) + # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement + assert theq(aesara_code_(x > y), xt > yt) + assert theq(aesara_code_(x < y), xt < yt) + assert theq(aesara_code_(x >= y), xt >= yt) + assert theq(aesara_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + dtypes = {x:'complex128', y:'complex128'} + with warns_deprecated_sympy(): + xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes) + from sympy.functions.elementary.complexes import conjugate + from aesara.tensor import as_tensor_variable as atv + from aesara.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj())) + assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = aesara_function([],[1+1j]) + assert(tf()==1+1j) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_c.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_c.py new file mode 100644 index 0000000000000000000000000000000000000000..626e7b6f244ea3227b886cd897d327f5d7bf66ec --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_c.py @@ -0,0 +1,888 @@ +from sympy.core import ( + S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan, + Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr +) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.functions import ( + Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf, + erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh, + sqrt, tan, tanh, fibonacci, lucas +) +from sympy.sets import Range +from sympy.logic import ITE, Implies, Equivalent +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises, XFAIL +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros +from sympy.codegen.ast import ( + AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const, + While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return, + real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString +) +from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt, isnan, isinf +from sympy.codegen.cnodes import restrict +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix + +from sympy.printing.codeprinter import ccode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _ccode(self, printer): + return "fabs(%s)" % printer._print(self.args[0]) + + assert ccode(fabs(x)) == "fabs(x)" + + +def test_ccode_sqrt(): + assert ccode(sqrt(x)) == "sqrt(x)" + assert ccode(x**0.5) == "sqrt(x)" + assert ccode(sqrt(x)) == "sqrt(x)" + + +def test_ccode_Pow(): + assert ccode(x**3) == "pow(x, 3)" + assert ccode(x**(y**3)) == "pow(x, pow(y, 3))" + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)" + assert ccode(x**-1.0) == '1.0/x' + assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)' + assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)' + assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)' + _cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp), + (lambda base, exp: base != 2, 'pow')] + # Related to gh-11353 + assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)' + assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)' + # For issue 14160 + assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_ccode_Max(): + # Test for gh-11926 + assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_ccode_Min_performance(): + #Shouldn't take more than a few seconds + big_min = Min(*symbols('a[0:50]')) + for curr_standard in ('c89', 'c99', 'c11'): + output = ccode(big_min, standard=curr_standard) + assert output.count('(') == output.count(')') + + +def test_ccode_constants_mathh(): + assert ccode(exp(1)) == "M_E" + assert ccode(pi) == "M_PI" + assert ccode(oo, standard='c89') == "HUGE_VAL" + assert ccode(-oo, standard='c89') == "-HUGE_VAL" + assert ccode(oo) == "INFINITY" + assert ccode(-oo, standard='c99') == "-INFINITY" + assert ccode(pi, type_aliases={real: float80}) == "M_PIl" + + +def test_ccode_constants_other(): + assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert ccode( + 2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_ccode_Rational(): + assert ccode(Rational(3, 7)) == "3.0/7.0" + assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(Rational(18, 9)) == "2" + assert ccode(Rational(3, -7)) == "-3.0/7.0" + assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L" + assert ccode(Rational(-3, -7)) == "3.0/7.0" + assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L" + assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x" + assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x" + + +def test_ccode_Integer(): + assert ccode(Integer(67)) == "67" + assert ccode(Integer(-1)) == "-1" + + +def test_ccode_functions(): + assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_ccode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert ccode( + g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert ccode(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i y" + assert ccode(Ge(x, y)) == "x >= y" + + +def test_ccode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + "))") + assert ccode(expr, assign_to="c") == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + assert ccode(expr, assign_to='c') == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else if (x < 2) {\n" + " c = x + 1;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: ccode(expr)) + + +def test_ccode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + assert ccode(expr) == ( + "(((x != 0) ? (\n" + " sin(x)/x\n" + ")\n" + ": (\n" + " 1\n" + ")))") + + +def test_ccode_Piecewise_deep(): + p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == ( + "2*((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + assert ccode(expr) == ( + "pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1") + assert ccode(expr, assign_to='c') == ( + "c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1;") + + +def test_ccode_ITE(): + expr = ITE(x < 1, y, z) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " y\n" + ")\n" + ": (\n" + " z\n" + "))") + + +def test_ccode_settings(): + raises(TypeError, lambda: ccode(sin(x), method="garbage")) + + +def test_ccode_Indexed(): + s, n, m, o = symbols('s n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + A = IndexedBase('A')[i, j] + B = IndexedBase('B')[i, j, k] + + p = C99CodePrinter() + + assert p._print_Indexed(x) == 'x[j]' + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + A = IndexedBase('A', shape=(5,3))[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (3*i + j) + + A = IndexedBase('A', shape=(5,3), strides='F')[i, j] + assert ccode(A) == 'A[%s]' % (i + 5*j) + + A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j] + assert ccode(A) == 'A[o + s*j + i]' + + Abase = IndexedBase('A', strides=(s, m, n), offset=o) + assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]' + assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]' + + +def test_Element(): + assert ccode(Element('x', 'ij')) == 'x[i][j]' + assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]' + assert ccode(Element('x', (3,))) == 'x[3]' + assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]' + + +def test_ccode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = ccode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_ccode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert ccode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert ccode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert ccode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_sparse_matrix(): + # gh-15791 + with raises(PrintMethodNotImplementedError): + ccode(SparseMatrix([[1, 2, 3]])) + + assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]])) + + + +def test_ccode_reserved_words(): + x, y = symbols('x, if') + with raises(ValueError): + ccode(y**2, error_on_reserved=True, standard='C99') + assert ccode(y**2) == 'pow(if_, 2)' + assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x' + assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)' + + +def test_ccode_sign(): + expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))' + expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))' + expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))' + assert ccode(expr1) == ref1 + assert ccode(expr1, 'z') == 'z = %s;' % ref1 + assert ccode(expr2) == ref2 + assert ccode(expr3) == ref3 + +def test_ccode_Assignment(): + assert ccode(Assignment(x, y + z)) == 'x = y + z;' + assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_ccode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n" + " y *= x;\n" + "}") + +def test_ccode_Max_Min(): + assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)' + assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)' + assert ccode(Min(x, 0, sqrt(x)), standard='c89') == ( + '((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))' + ) + +def test_ccode_standard(): + assert ccode(expm1(x), standard='c99') == 'expm1(x)' + assert ccode(nan, standard='c99') == 'NAN' + assert ccode(float('nan'), standard='c99') == 'NAN' + + +def test_C89CodePrinter(): + c89printer = C89CodePrinter() + assert c89printer.language == 'C' + assert c89printer.standard == 'C89' + assert 'void' in c89printer.reserved_words + assert 'template' not in c89printer.reserved_words + assert c89printer.doprint(log10(x)) == 'log10(x)' + + +def test_C99CodePrinter(): + assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)' + assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)' + assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)' + assert C99CodePrinter().doprint(log2(x)) == 'log2(x)' + assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)' + assert C99CodePrinter().doprint(log10(x)) == 'log10(x)' + assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken. + assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)' + assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)' + assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))' + assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)' + c99printer = C99CodePrinter() + assert c99printer.language == 'C' + assert c99printer.standard == 'C99' + assert 'restrict' in c99printer.reserved_words + assert 'using' not in c99printer.reserved_words + + +@XFAIL +def test_C99CodePrinter__precision_f80(): + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)' + + +def test_C99CodePrinter__precision(): + n = symbols('n', integer=True) + p = symbols('p', integer=True, positive=True) + f32_printer = C99CodePrinter({"type_aliases": {real: float32}}) + f64_printer = C99CodePrinter({"type_aliases": {real: float64}}) + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)' + assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)' + assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)' + + for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']): + def check(expr, ref): + assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper()) + check(Abs(n), 'abs(n)') + check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})') + check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))') + check(exp(x*8.0), 'exp{s}(8.0{S}*x)') + check(exp2(x), 'exp2{s}(x)') + check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)') + check(Mod(p, 2), 'p % 2') + check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)') + check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})') + check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})') + check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)') + check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)') + check(log2(x*8.0), 'log2{s}(8.0{S}*x)') + check(log1p(x), 'log1p{s}(x)') + check(2**x, 'pow{s}(2, x)') + check(2.0**x, 'pow{s}(2.0{S}, x)') + check(x**3, 'pow{s}(x, 3)') + check(x**4.0, 'pow{s}(x, 4.0{S})') + check(sqrt(3+x), 'sqrt{s}(x + 3)') + check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})') + check(hypot(x, y), 'hypot{s}(x, y)') + check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})') + check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})') + check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})') + check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})') + check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})') + check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})') + check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)') + + check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})') + check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})') + check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})') + check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})') + check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})') + check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})') + check(erf(42.*x), 'erf{s}(42.0{S}*x)') + check(erfc(42.*x), 'erfc{s}(42.0{S}*x)') + check(gamma(x), 'tgamma{s}(x)') + check(loggamma(x), 'lgamma{s}(x)') + + check(ceiling(x + 2.), "ceil{s}(x) + 2") + check(floor(x + 2.), "floor{s}(x) + 2") + check(fma(x, y, -z), 'fma{s}(x, y, -z)') + check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))') + check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)') + + +def test_get_math_macros(): + macros = get_math_macros() + assert macros[exp(1)] == 'M_E' + assert macros[1/Sqrt(2)] == 'M_SQRT1_2' + + +def test_ccode_Declaration(): + i = symbols('i', integer=True) + var1 = Variable(i, type=Type.from_expr(i)) + dcl1 = Declaration(var1) + assert ccode(dcl1) == 'int i' + + var2 = Variable(x, type=float32, attrs={value_const}) + dcl2a = Declaration(var2) + assert ccode(dcl2a) == 'const float x' + dcl2b = var2.as_Declaration(value=pi) + assert ccode(dcl2b) == 'const float x = M_PI' + + var3 = Variable(y, type=Type('bool')) + dcl3 = Declaration(var3) + printer = C89CodePrinter() + assert 'stdbool.h' not in printer.headers + assert printer.doprint(dcl3) == 'bool y' + assert 'stdbool.h' in printer.headers + + u = symbols('u', real=True) + ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict}) + dcl4 = Declaration(ptr4) + assert ccode(dcl4) == 'double * const restrict u' + + var5 = Variable(x, Type('__float128'), attrs={value_const}) + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const __float128 x' + var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs) + dcl5b = Declaration(var5b) + assert ccode(dcl5b) == 'const __float128 x = M_PI' + + +def test_C99CodePrinter_custom_type(): + # We will look at __float128 (new in glibc 2.26) + f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp) + p128 = C99CodePrinter({ + "type_aliases": {real: f128}, + "type_literal_suffixes": {f128: 'Q'}, + "type_func_suffixes": {f128: 'f128'}, + "type_math_macro_suffixes": { + real: 'f128', + f128: 'f128' + }, + "type_macros": { + f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',) + } + }) + assert p128.doprint(x) == 'x' + assert not p128.headers + assert not p128.libraries + assert not p128.macros + assert p128.doprint(2.0) == '2.0Q' + assert not p128.headers + assert not p128.libraries + assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'} + + assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q' + assert p128.doprint(sin(x)) == 'sinf128(x)' + assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)' + assert p128.doprint(x**-1.0) == '1.0Q/x' + + var5 = Variable(x, f128, attrs={value_const}) + + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const _Float128 x' + var5b = Variable(x, f128, pi, attrs={value_const}) + dcl5b = Declaration(var5b) + assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128' + var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const}) + dcl5c = Declaration(var5b) + assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(ccode(A[0, 0]) == "A[0]") + assert(ccode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(ccode(F) == "(A - B)[0]") + +def test_ccode_math_macros(): + assert ccode(z + exp(1)) == 'z + M_E' + assert ccode(z + log2(exp(1))) == 'z + M_LOG2E' + assert ccode(z + 1/log(2)) == 'z + M_LOG2E' + assert ccode(z + log(2)) == 'z + M_LN2' + assert ccode(z + log(10)) == 'z + M_LN10' + assert ccode(z + pi) == 'z + M_PI' + assert ccode(z + pi/2) == 'z + M_PI_2' + assert ccode(z + pi/4) == 'z + M_PI_4' + assert ccode(z + 1/pi) == 'z + M_1_PI' + assert ccode(z + 2/pi) == 'z + M_2_PI' + assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + Sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2' + assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2' + + +def test_ccode_Type(): + assert ccode(Type('float')) == 'float' + assert ccode(intc) == 'int' + + +def test_ccode_codegen_ast(): + # Note that C only allows comments of the form /* ... */, double forward + # slash is not standard C, and some C compilers will grind to a halt upon + # encountering them. + assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not // + assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == ( + 'while (fabs(x) > 1) {\n' + ' x -= 1;\n' + '}' + ) + assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == ( + '{\n' + ' x += 1;\n' + '}' + ) + inp_x = Declaration(Variable(x, type=real)) + assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)' + assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == ( + 'double pwer(double x){\n' + ' x = pow(x, 2);\n' + '}' + ) + + # Elements of CodeBlock are formatted as statements: + block = CodeBlock( + x, + Print([x, y], "%d %d"), + Print([QuotedString('hello'), y], "%s %d", file=stderr), + FunctionCall('pwer', [x]), + Return(x), + ) + assert ccode(block) == '\n'.join([ + 'x;', + 'printf("%d %d", x, y);', + 'fprintf(stderr, "%s %d", "hello", y);', + 'pwer(x);', + 'return x;', + ]) + +def test_ccode_UnevaluatedExpr(): + assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y" + assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955 + w = symbols('w') + assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)" + + p, q, r = symbols("p q r", real=True) + q_r = UnevaluatedExpr(q + r) + expr = abs(exp(p+q_r)) + assert ccode(expr) == "exp(p + (q + r))" + + +def test_ccode_array_like_containers(): + assert ccode([2,3,4]) == "{2, 3, 4}" + assert ccode((2,3,4)) == "{2, 3, 4}" + +def test_ccode__isinf_isnan(): + assert ccode(isinf(x)) == 'isinf(x)' + assert ccode(isnan(x)) == 'isnan(x)' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_codeprinter.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_codeprinter.py new file mode 100644 index 0000000000000000000000000000000000000000..4b077037eb84e218fcfd4a05fc03e40b211e45b9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_codeprinter.py @@ -0,0 +1,77 @@ +from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError +from sympy.core import symbols +from sympy.core.symbol import Dummy +from sympy.testing.pytest import raises +from sympy import cos +from sympy.utilities.lambdify import lambdify +from math import cos as math_cos +from sympy.printing.lambdarepr import LambdaPrinter + + +def setup_test_printer(**kwargs): + p = CodePrinter(settings=kwargs) + p._not_supported = set() + p._number_symbols = set() + return p + + +def test_print_Dummy(): + d = Dummy('d') + p = setup_test_printer() + assert p._print_Dummy(d) == "d_%i" % d.dummy_index + +def test_print_Symbol(): + + x, y = symbols('x, if') + + p = setup_test_printer() + assert p._print(x) == 'x' + assert p._print(y) == 'if' + + p.reserved_words.update(['if']) + assert p._print(y) == 'if_' + + p = setup_test_printer(error_on_reserved=True) + p.reserved_words.update(['if']) + with raises(ValueError): + p._print(y) + + p = setup_test_printer(reserved_word_suffix='_He_Man') + p.reserved_words.update(['if']) + assert p._print(y) == 'if_He_Man' + + +def test_lambdify_LaTeX_symbols_issue_23374(): + # Create symbols with Latex style names + x1, x2 = symbols("x_{1} x_2") + + # Lambdify the function + f1 = lambdify([x1, x2], cos(x1 ** 2 + x2 ** 2)) + + # Test that the function works correctly (numerically) + assert f1(1, 2) == math_cos(1 ** 2 + 2 ** 2) + + # Explicitly generate a custom printer to verify the naming convention + p = LambdaPrinter() + expr_str = p.doprint(cos(x1 ** 2 + x2 ** 2)) + assert 'x_1' in expr_str + assert 'x_2' in expr_str + + +def test_issue_15791(): + class CrashingCodePrinter(CodePrinter): + def emptyPrinter(self, obj): + raise NotImplementedError + + from sympy.matrices import ( + MutableSparseMatrix, + ImmutableSparseMatrix, + ) + + c = CrashingCodePrinter() + + # these should not silently succeed + with raises(PrintMethodNotImplementedError): + c.doprint(ImmutableSparseMatrix(2, 2, {})) + with raises(PrintMethodNotImplementedError): + c.doprint(MutableSparseMatrix(2, 2, {})) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_conventions.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_conventions.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f1fa8532f96130828b89d1ba5ba11fd5bed7a4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_conventions.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.functions.special.bessel import besselj +from sympy.functions.special.polynomials import legendre +from sympy.functions.combinatorial.numbers import bell +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.testing.pytest import XFAIL + +def test_super_sub(): + assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"]) + assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"]) + assert split_super_sub("beta_13") == ("beta", [], ["13"]) + assert split_super_sub("x_a_b") == ("x", [], ["a", "b"]) + assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"]) + assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"]) + assert split_super_sub("x_a_1") == ("x", [], ["a", "1"]) + assert split_super_sub("x_1_a") == ("x", [], ["1", "a"]) + assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_11^a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_11__a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"]) + assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("alpha_11") == ("alpha", [], ["11"]) + assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"]) + assert split_super_sub("w1") == ("w", [], ["1"]) + assert split_super_sub("w𝟙") == ("w", [], ["𝟙"]) + assert split_super_sub("w11") == ("w", [], ["11"]) + assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"]) + assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"]) + assert split_super_sub("w1^a") == ("w", ["a"], ["1"]) + assert split_super_sub("ω1") == ("ω", [], ["1"]) + assert split_super_sub("ω11") == ("ω", [], ["11"]) + assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"]) + assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"]) + assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"]) + assert split_super_sub("") == ("", [], []) + + +def test_requires_partial(): + x, y, z, t, nu = symbols('x y z t nu') + n = symbols('n', integer=True) + + f = x * y + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, y)) is True + + ## integrating out one of the variables + assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + ## bessel function with smooth parameter + f = besselj(nu, x) + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, nu)) is True + + ## bessel function with integer parameter + f = besselj(n, x) + assert requires_partial(Derivative(f, x)) is False + # this is not really valid (differentiating with respect to an integer) + # but there's no reason to use the partial derivative symbol there. make + # sure we don't throw an exception here, though + assert requires_partial(Derivative(f, n)) is False + + ## bell polynomial + f = bell(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + ## legendre polynomial + f = legendre(0, x) + assert requires_partial(Derivative(f, x)) is False + + f = legendre(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + f = x ** n + assert requires_partial(Derivative(f, x)) is False + + assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + # parametric equation + f = (exp(t), cos(t)) + g = sum(f) + assert requires_partial(Derivative(g, t)) is False + + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f(x), x)) is False + assert requires_partial(Derivative(f(x), y)) is False + assert requires_partial(Derivative(f(x, y), x)) is True + assert requires_partial(Derivative(f(x, y), y)) is True + assert requires_partial(Derivative(f(x, y), z)) is True + assert requires_partial(Derivative(f(x, y), x, y)) is True + +@XFAIL +def test_requires_partial_unspecified_variables(): + x, y = symbols('x y') + # function of unspecified variables + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f, x)) is False + assert requires_partial(Derivative(f, x, y)) is True diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_cupy.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_cupy.py new file mode 100644 index 0000000000000000000000000000000000000000..cf111ec1623390a3dbbf489235d2ed387624a36c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_cupy.py @@ -0,0 +1,56 @@ +from sympy.concrete.summations import Sum +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.utilities.lambdify import lambdify +from sympy.abc import x, i, a, b +from sympy.codegen.numpy_nodes import logaddexp +from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +cp = import_module('cupy') + +def test_cupy_print(): + prntr = CuPyPrinter() + assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)' + assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)' + assert prntr.doprint(log(x)) == 'cupy.log(x)' + assert prntr.doprint("acos(x)") == 'cupy.arccos(x)' + assert prntr.doprint("exp(x)") == 'cupy.exp(x)' + assert prntr.doprint("Abs(x)") == 'abs(x)' + +def test_not_cupy_print(): + prntr = CuPyPrinter() + with raises(NotImplementedError): + prntr.doprint("abcd(x)") + +def test_cupy_sum(): + if not cp: + skip("CuPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'cupy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + +def test_cupy_known_funcs_consts(): + assert _cupy_known_constants['NaN'] == 'cupy.nan' + assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma' + + assert _cupy_known_functions['acos'] == 'cupy.arccos' + assert _cupy_known_functions['log'] == 'cupy.log' + +def test_cupy_print_methods(): + prntr = CuPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_cxx.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_cxx.py new file mode 100644 index 0000000000000000000000000000000000000000..d84ec75cbf0eeb60a1176b9cb3b401a3384454e7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_cxx.py @@ -0,0 +1,86 @@ +from sympy.core.numbers import Float, Integer, Rational +from sympy.core.symbol import symbols +from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac +from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode +from sympy.codegen.cfunctions import log1p + + +x, y, u, v = symbols('x y u v') + + +def test_CXX98CodePrinter(): + assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)') + assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))' + cxx98printer = CXX98CodePrinter() + assert cxx98printer.language == 'C++' + assert cxx98printer.standard == 'C++98' + assert 'template' in cxx98printer.reserved_words + assert 'alignas' not in cxx98printer.reserved_words + + +def test_CXX11CodePrinter(): + assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)' + + cxx11printer = CXX11CodePrinter() + assert cxx11printer.language == 'C++' + assert cxx11printer.standard == 'C++11' + assert 'operator' in cxx11printer.reserved_words + assert 'noexcept' in cxx11printer.reserved_words + assert 'concept' not in cxx11printer.reserved_words + + +def test_subclass_print_method(): + class MyPrinter(CXX11CodePrinter): + def _print_log1p(self, expr): + return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args)) + + assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_subclass_print_method__ns(): + class MyPrinter(CXX11CodePrinter): + _ns = 'my_library::' + + p = CXX11CodePrinter() + myp = MyPrinter() + + assert p.doprint(log1p(x)) == 'std::log1p(x)' + assert myp.doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_CXX17CodePrinter(): + assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)' + assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)' + assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)' + + # Automatic rewrite + assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))' + assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))' + + +def test_cxxcode(): + assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)']) + +def test_cxxcode_nested_minmax(): + assert cxxcode(Max(Min(x, y), Min(u, v))) \ + == 'std::max(std::min(u, v), std::min(x, y))' + assert cxxcode(Min(Max(x, y), Max(u, v))) \ + == 'std::min(std::max(u, v), std::max(x, y))' + +def test_subclass_Integer_Float(): + class MyPrinter(CXX17CodePrinter): + def _print_Integer(self, arg): + return 'bigInt("%s")' % super()._print_Integer(arg) + + def _print_Float(self, arg): + rat = Rational(arg) + return 'bigFloat(%s, %s)' % ( + self._print(Integer(rat.p)), + self._print(Integer(rat.q)) + ) + + p = MyPrinter() + for i in range(13): + assert p.doprint(i) == 'bigInt("%d")' % i + assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))' + assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_dot.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_dot.py new file mode 100644 index 0000000000000000000000000000000000000000..6213e237fb7aac6460a956b4c9fc1f7c8710fec6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_dot.py @@ -0,0 +1,134 @@ +from sympy.printing.dot import (purestr, styleof, attrprint, dotnode, + dotedges, dotprint) +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.numbers import (Float, Integer) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.printing.repr import srepr +from sympy.abc import x + + +def test_purestr(): + assert purestr(Symbol('x')) == "Symbol('x')" + assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))" + assert purestr(Float(2)) == "Float('2.0', precision=53)" + + assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ()) + assert purestr(Basic(S(1), S(2)), with_args=True) == \ + ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)')) + assert purestr(Float(2), with_args=True) == \ + ("Float('2.0', precision=53)", ()) + + +def test_styleof(): + styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'})] + assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'} + + assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'} + + +def test_attrprint(): + assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \ + '"color"="blue", "shape"="ellipse"' + +def test_dotnode(): + + assert dotnode(x, repeat=False) == \ + '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];' + assert dotnode(x+2, repeat=False) == \ + '"Add(Integer(2), Symbol(\'x\'))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];', \ + dotnode(x+2,repeat=0) + + assert dotnode(x + x**2, repeat=False) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + assert dotnode(x + x**2, repeat=True) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + +def test_dotedges(): + assert sorted(dotedges(x+2, repeat=False)) == [ + '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";', + '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";' + ] + assert sorted(dotedges(x + 2, repeat=True)) == [ + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";', + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";' + ] + +def test_dotprint(): + text = dotprint(x+2, repeat=False) + assert all(e in text for e in dotedges(x+2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x+2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=False) + assert all(e in text for e in dotedges(x+x**2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x**2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=True) + assert all(e in text for e in dotedges(x+x**2, repeat=True)) + assert all( + n in text for n in [dotnode(expr, pos=()) + for expr in [x + x**2]]) + + text = dotprint(x**x, repeat=True) + assert all(e in text for e in dotedges(x**x, repeat=True)) + assert all( + n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))]) + assert 'digraph' in text + +def test_dotprint_depth(): + text = dotprint(3*x+2, depth=1) + assert dotnode(3*x+2) in text + assert dotnode(x) not in text + text = dotprint(3*x+2) + assert "depth" not in text + +def test_Matrix_and_non_basics(): + from sympy.matrices.expressions.matexpr import MatrixSymbol + n = Symbol('n') + assert dotprint(MatrixSymbol('X', n, n)) == \ +"""digraph{ + +# Graph style +"ordering"="out" +"rankdir"="TD" + +######### +# Nodes # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"]; +"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"]; +"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"]; +"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"]; + +######### +# Edges # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)"; +}""" + + +def test_labelfunc(): + text = dotprint(x + 2, labelfunc=srepr) + assert "Symbol('x')" in text + assert "Integer(2)" in text + + +def test_commutative(): + x, y = symbols('x y', commutative=False) + assert dotprint(x + y) == dotprint(y + x) + assert dotprint(x*y) != dotprint(y*x) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_fortran.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_fortran.py new file mode 100644 index 0000000000000000000000000000000000000000..c28a1ea16dcf2157b58d763286428dccc1944b71 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_fortran.py @@ -0,0 +1,854 @@ +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda, diff) +from sympy.core.mod import Mod +from sympy.core import (Catalan, EulerGamma, GoldenRatio) +from sympy.core.numbers import (E, Float, I, Integer, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (conjugate, sign) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.sets.fancysets import Range + +from sympy.codegen import For, Assignment, aug_assign +from sympy.codegen.ast import Declaration, Variable, float32, float64, \ + value_const, real, bool_, While, FunctionPrototype, FunctionDefinition, \ + integer, Return, Element +from sympy.core.expr import UnevaluatedExpr +from sympy.core.relational import Relational +from sympy.logic.boolalg import And, Or, Not, Equivalent, Xor +from sympy.matrices import Matrix, MatrixSymbol +from sympy.printing.fortran import fcode, FCodePrinter +from sympy.tensor import IndexedBase, Idx +from sympy.tensor.array.expressions import ArraySymbol, ArrayElement +from sympy.utilities.lambdify import implemented_function +from sympy.testing.pytest import raises + + +def test_UnevaluatedExpr(): + p, q, r = symbols("p q r", real=True) + q_r = UnevaluatedExpr(q + r) + expr = abs(exp(p+q_r)) + assert fcode(expr, source_format="free") == "exp(p + (q + r))" + x, y, z = symbols("x y z") + y_z = UnevaluatedExpr(y + z) + expr2 = abs(exp(x+y_z)) + assert fcode(expr2, human=False)[2].lstrip() == "exp(re(x) + re(y + z))" + assert fcode(expr2, user_functions={"re": "realpart"}).lstrip() == "exp(realpart(x) + realpart(y + z))" + + +def test_printmethod(): + x = symbols('x') + + class nint(Function): + def _fcode(self, printer): + return "nint(%s)" % printer._print(self.args[0]) + assert fcode(nint(x)) == " nint(x)" + + +def test_fcode_sign(): #issue 12267 + x=symbols('x') + y=symbols('y', integer=True) + z=symbols('z', complex=True) + assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)" + assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)" + assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)" + raises(NotImplementedError, lambda: fcode(sign(x))) + + +def test_fcode_Pow(): + x, y = symbols('x,y') + n = symbols('n', integer=True) + + assert fcode(x**3) == " x**3" + assert fcode(x**(y**3)) == " x**(y**3)" + assert fcode(1/(sin(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + " (3.5d0*sin(x))**(-x + y**x)/(x**2 + y)" + assert fcode(sqrt(x)) == ' sqrt(x)' + assert fcode(sqrt(n)) == ' sqrt(dble(n))' + assert fcode(x**0.5) == ' sqrt(x)' + assert fcode(sqrt(x)) == ' sqrt(x)' + assert fcode(sqrt(10)) == ' sqrt(10.0d0)' + assert fcode(x**-1.0) == ' 1d0/x' + assert fcode(x**-2.0, 'y', source_format='free') == 'y = x**(-2.0d0)' # 2823 + assert fcode(x**Rational(3, 7)) == ' x**(3.0d0/7.0d0)' + + +def test_fcode_Rational(): + x = symbols('x') + assert fcode(Rational(3, 7)) == " 3.0d0/7.0d0" + assert fcode(Rational(18, 9)) == " 2" + assert fcode(Rational(3, -7)) == " -3.0d0/7.0d0" + assert fcode(Rational(-3, -7)) == " 3.0d0/7.0d0" + assert fcode(x + Rational(3, 7)) == " x + 3.0d0/7.0d0" + assert fcode(Rational(3, 7)*x) == " (3.0d0/7.0d0)*x" + + +def test_fcode_Integer(): + assert fcode(Integer(67)) == " 67" + assert fcode(Integer(-1)) == " -1" + + +def test_fcode_Float(): + assert fcode(Float(42.0)) == " 42.0000000000000d0" + assert fcode(Float(-1e20)) == " -1.00000000000000d+20" + + +def test_fcode_functions(): + x, y = symbols('x,y') + assert fcode(sin(x) ** cos(y)) == " sin(x)**cos(y)" + raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=66)) + raises(NotImplementedError, lambda: fcode(x % y, standard=66)) + raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=77)) + raises(NotImplementedError, lambda: fcode(x % y, standard=77)) + for standard in [90, 95, 2003, 2008]: + assert fcode(Mod(x, y), standard=standard) == " modulo(x, y)" + assert fcode(x % y, standard=standard) == " modulo(x, y)" + + +def test_case(): + ob = FCodePrinter() + x,x_,x__,y,X,X_,Y = symbols('x,x_,x__,y,X,X_,Y') + assert fcode(exp(x_) + sin(x*y) + cos(X*Y)) == \ + ' exp(x_) + sin(x*y) + cos(X__*Y_)' + assert fcode(exp(x__) + 2*x*Y*X_**Rational(7, 2)) == \ + ' 2*X_**(7.0d0/2.0d0)*Y*x + exp(x__)' + assert fcode(exp(x_) + sin(x*y) + cos(X*Y), name_mangling=False) == \ + ' exp(x_) + sin(x*y) + cos(X*Y)' + assert fcode(x - cos(X), name_mangling=False) == ' x - cos(X)' + assert ob.doprint(X*sin(x) + x_, assign_to='me') == ' me = X*sin(x_) + x__' + assert ob.doprint(X*sin(x), assign_to='mu') == ' mu = X*sin(x_)' + assert ob.doprint(x_, assign_to='ad') == ' ad = x__' + n, m = symbols('n,m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + I = Idx('I', n) + assert fcode(A[i, I]*x[I], assign_to=y[i], source_format='free') == ( + "do i = 1, m\n" + " y(i) = 0\n" + "end do\n" + "do i = 1, m\n" + " do I_ = 1, n\n" + " y(i) = A(i, I_)*x(I_) + y(i)\n" + " end do\n" + "end do" ) + + +#issue 6814 +def test_fcode_functions_with_integers(): + x= symbols('x') + log10_17 = log(10).evalf(17) + loglog10_17 = '0.8340324452479558d0' + assert fcode(x * log(10)) == " x*%sd0" % log10_17 + assert fcode(x * log(10)) == " x*%sd0" % log10_17 + assert fcode(x * log(S(10))) == " x*%sd0" % log10_17 + assert fcode(log(S(10))) == " %sd0" % log10_17 + assert fcode(exp(10)) == " %sd0" % exp(10).evalf(17) + assert fcode(x * log(log(10))) == " x*%s" % loglog10_17 + assert fcode(x * log(log(S(10)))) == " x*%s" % loglog10_17 + + +def test_fcode_NumberSymbol(): + prec = 17 + p = FCodePrinter() + assert fcode(Catalan) == ' parameter (Catalan = %sd0)\n Catalan' % Catalan.evalf(prec) + assert fcode(EulerGamma) == ' parameter (EulerGamma = %sd0)\n EulerGamma' % EulerGamma.evalf(prec) + assert fcode(E) == ' parameter (E = %sd0)\n E' % E.evalf(prec) + assert fcode(GoldenRatio) == ' parameter (GoldenRatio = %sd0)\n GoldenRatio' % GoldenRatio.evalf(prec) + assert fcode(pi) == ' parameter (pi = %sd0)\n pi' % pi.evalf(prec) + assert fcode( + pi, precision=5) == ' parameter (pi = %sd0)\n pi' % pi.evalf(5) + assert fcode(Catalan, human=False) == ({ + (Catalan, p._print(Catalan.evalf(prec)))}, set(), ' Catalan') + assert fcode(EulerGamma, human=False) == ({(EulerGamma, p._print( + EulerGamma.evalf(prec)))}, set(), ' EulerGamma') + assert fcode(E, human=False) == ( + {(E, p._print(E.evalf(prec)))}, set(), ' E') + assert fcode(GoldenRatio, human=False) == ({(GoldenRatio, p._print( + GoldenRatio.evalf(prec)))}, set(), ' GoldenRatio') + assert fcode(pi, human=False) == ( + {(pi, p._print(pi.evalf(prec)))}, set(), ' pi') + assert fcode(pi, precision=5, human=False) == ( + {(pi, p._print(pi.evalf(5)))}, set(), ' pi') + + +def test_fcode_complex(): + assert fcode(I) == " cmplx(0,1)" + x = symbols('x') + assert fcode(4*I) == " cmplx(0,4)" + assert fcode(3 + 4*I) == " cmplx(3,4)" + assert fcode(3 + 4*I + x) == " cmplx(3,4) + x" + assert fcode(I*x) == " cmplx(0,1)*x" + assert fcode(3 + 4*I - x) == " cmplx(3,4) - x" + x = symbols('x', imaginary=True) + assert fcode(5*x) == " 5*x" + assert fcode(I*x) == " cmplx(0,1)*x" + assert fcode(3 + x) == " x + 3" + + +def test_implicit(): + x, y = symbols('x,y') + assert fcode(sin(x)) == " sin(x)" + assert fcode(atan2(x, y)) == " atan2(x, y)" + assert fcode(conjugate(x)) == " conjg(x)" + + +def test_not_fortran(): + x = symbols('x') + g = Function('g') + with raises(NotImplementedError): + fcode(gamma(x)) + assert fcode(Integral(sin(x)), strict=False) == "C Not supported in Fortran:\nC Integral\n Integral(sin(x), x)" + with raises(NotImplementedError): + fcode(g(x)) + + +def test_user_functions(): + x = symbols('x') + assert fcode(sin(x), user_functions={"sin": "zsin"}) == " zsin(x)" + x = symbols('x') + assert fcode( + gamma(x), user_functions={"gamma": "mygamma"}) == " mygamma(x)" + g = Function('g') + assert fcode(g(x), user_functions={"g": "great"}) == " great(x)" + n = symbols('n', integer=True) + assert fcode( + factorial(n), user_functions={"factorial": "fct"}) == " fct(n)" + + +def test_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert fcode(g(x)) == " 2*x" + g = implemented_function('g', Lambda(x, 2*pi/x)) + assert fcode(g(x)) == ( + " parameter (pi = %sd0)\n" + " 2*pi/x" + ) % pi.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert fcode(g(A[i]), assign_to=A[i]) == ( + " do i = 1, n\n" + " A(i) = (A(i) + 1)*(A(i) + 2)*A(i)\n" + " end do" + ) + + +def test_assign_to(): + x = symbols('x') + assert fcode(sin(x), assign_to="s") == " s = sin(x)" + + +def test_line_wrapping(): + x, y = symbols('x,y') + assert fcode(((x + y)**10).expand(), assign_to="var") == ( + " var = x**10 + 10*x**9*y + 45*x**8*y**2 + 120*x**7*y**3 + 210*x**6*\n" + " @ y**4 + 252*x**5*y**5 + 210*x**4*y**6 + 120*x**3*y**7 + 45*x**2*y\n" + " @ **8 + 10*x*y**9 + y**10" + ) + e = [x**i for i in range(11)] + assert fcode(Add(*e)) == ( + " x**10 + x**9 + x**8 + x**7 + x**6 + x**5 + x**4 + x**3 + x**2 + x\n" + " @ + 1" + ) + + +def test_fcode_precedence(): + x, y = symbols("x y") + assert fcode(And(x < y, y < x + 1), source_format="free") == \ + "x < y .and. y < x + 1" + assert fcode(Or(x < y, y < x + 1), source_format="free") == \ + "x < y .or. y < x + 1" + assert fcode(Xor(x < y, y < x + 1, evaluate=False), + source_format="free") == "x < y .neqv. y < x + 1" + assert fcode(Equivalent(x < y, y < x + 1), source_format="free") == \ + "x < y .eqv. y < x + 1" + + +def test_fcode_Logical(): + x, y, z = symbols("x y z") + # unary Not + assert fcode(Not(x), source_format="free") == ".not. x" + # binary And + assert fcode(And(x, y), source_format="free") == "x .and. y" + assert fcode(And(x, Not(y)), source_format="free") == "x .and. .not. y" + assert fcode(And(Not(x), y), source_format="free") == "y .and. .not. x" + assert fcode(And(Not(x), Not(y)), source_format="free") == \ + ".not. x .and. .not. y" + assert fcode(Not(And(x, y), evaluate=False), source_format="free") == \ + ".not. (x .and. y)" + # binary Or + assert fcode(Or(x, y), source_format="free") == "x .or. y" + assert fcode(Or(x, Not(y)), source_format="free") == "x .or. .not. y" + assert fcode(Or(Not(x), y), source_format="free") == "y .or. .not. x" + assert fcode(Or(Not(x), Not(y)), source_format="free") == \ + ".not. x .or. .not. y" + assert fcode(Not(Or(x, y), evaluate=False), source_format="free") == \ + ".not. (x .or. y)" + # mixed And/Or + assert fcode(And(Or(y, z), x), source_format="free") == "x .and. (y .or. z)" + assert fcode(And(Or(z, x), y), source_format="free") == "y .and. (x .or. z)" + assert fcode(And(Or(x, y), z), source_format="free") == "z .and. (x .or. y)" + assert fcode(Or(And(y, z), x), source_format="free") == "x .or. y .and. z" + assert fcode(Or(And(z, x), y), source_format="free") == "y .or. x .and. z" + assert fcode(Or(And(x, y), z), source_format="free") == "z .or. x .and. y" + # trinary And + assert fcode(And(x, y, z), source_format="free") == "x .and. y .and. z" + assert fcode(And(x, y, Not(z)), source_format="free") == \ + "x .and. y .and. .not. z" + assert fcode(And(x, Not(y), z), source_format="free") == \ + "x .and. z .and. .not. y" + assert fcode(And(Not(x), y, z), source_format="free") == \ + "y .and. z .and. .not. x" + assert fcode(Not(And(x, y, z), evaluate=False), source_format="free") == \ + ".not. (x .and. y .and. z)" + # trinary Or + assert fcode(Or(x, y, z), source_format="free") == "x .or. y .or. z" + assert fcode(Or(x, y, Not(z)), source_format="free") == \ + "x .or. y .or. .not. z" + assert fcode(Or(x, Not(y), z), source_format="free") == \ + "x .or. z .or. .not. y" + assert fcode(Or(Not(x), y, z), source_format="free") == \ + "y .or. z .or. .not. x" + assert fcode(Not(Or(x, y, z), evaluate=False), source_format="free") == \ + ".not. (x .or. y .or. z)" + + +def test_fcode_Xlogical(): + x, y, z = symbols("x y z") + # binary Xor + assert fcode(Xor(x, y, evaluate=False), source_format="free") == \ + "x .neqv. y" + assert fcode(Xor(x, Not(y), evaluate=False), source_format="free") == \ + "x .neqv. .not. y" + assert fcode(Xor(Not(x), y, evaluate=False), source_format="free") == \ + "y .neqv. .not. x" + assert fcode(Xor(Not(x), Not(y), evaluate=False), + source_format="free") == ".not. x .neqv. .not. y" + assert fcode(Not(Xor(x, y, evaluate=False), evaluate=False), + source_format="free") == ".not. (x .neqv. y)" + # binary Equivalent + assert fcode(Equivalent(x, y), source_format="free") == "x .eqv. y" + assert fcode(Equivalent(x, Not(y)), source_format="free") == \ + "x .eqv. .not. y" + assert fcode(Equivalent(Not(x), y), source_format="free") == \ + "y .eqv. .not. x" + assert fcode(Equivalent(Not(x), Not(y)), source_format="free") == \ + ".not. x .eqv. .not. y" + assert fcode(Not(Equivalent(x, y), evaluate=False), + source_format="free") == ".not. (x .eqv. y)" + # mixed And/Equivalent + assert fcode(Equivalent(And(y, z), x), source_format="free") == \ + "x .eqv. y .and. z" + assert fcode(Equivalent(And(z, x), y), source_format="free") == \ + "y .eqv. x .and. z" + assert fcode(Equivalent(And(x, y), z), source_format="free") == \ + "z .eqv. x .and. y" + assert fcode(And(Equivalent(y, z), x), source_format="free") == \ + "x .and. (y .eqv. z)" + assert fcode(And(Equivalent(z, x), y), source_format="free") == \ + "y .and. (x .eqv. z)" + assert fcode(And(Equivalent(x, y), z), source_format="free") == \ + "z .and. (x .eqv. y)" + # mixed Or/Equivalent + assert fcode(Equivalent(Or(y, z), x), source_format="free") == \ + "x .eqv. y .or. z" + assert fcode(Equivalent(Or(z, x), y), source_format="free") == \ + "y .eqv. x .or. z" + assert fcode(Equivalent(Or(x, y), z), source_format="free") == \ + "z .eqv. x .or. y" + assert fcode(Or(Equivalent(y, z), x), source_format="free") == \ + "x .or. (y .eqv. z)" + assert fcode(Or(Equivalent(z, x), y), source_format="free") == \ + "y .or. (x .eqv. z)" + assert fcode(Or(Equivalent(x, y), z), source_format="free") == \ + "z .or. (x .eqv. y)" + # mixed Xor/Equivalent + assert fcode(Equivalent(Xor(y, z, evaluate=False), x), + source_format="free") == "x .eqv. (y .neqv. z)" + assert fcode(Equivalent(Xor(z, x, evaluate=False), y), + source_format="free") == "y .eqv. (x .neqv. z)" + assert fcode(Equivalent(Xor(x, y, evaluate=False), z), + source_format="free") == "z .eqv. (x .neqv. y)" + assert fcode(Xor(Equivalent(y, z), x, evaluate=False), + source_format="free") == "x .neqv. (y .eqv. z)" + assert fcode(Xor(Equivalent(z, x), y, evaluate=False), + source_format="free") == "y .neqv. (x .eqv. z)" + assert fcode(Xor(Equivalent(x, y), z, evaluate=False), + source_format="free") == "z .neqv. (x .eqv. y)" + # mixed And/Xor + assert fcode(Xor(And(y, z), x, evaluate=False), source_format="free") == \ + "x .neqv. y .and. z" + assert fcode(Xor(And(z, x), y, evaluate=False), source_format="free") == \ + "y .neqv. x .and. z" + assert fcode(Xor(And(x, y), z, evaluate=False), source_format="free") == \ + "z .neqv. x .and. y" + assert fcode(And(Xor(y, z, evaluate=False), x), source_format="free") == \ + "x .and. (y .neqv. z)" + assert fcode(And(Xor(z, x, evaluate=False), y), source_format="free") == \ + "y .and. (x .neqv. z)" + assert fcode(And(Xor(x, y, evaluate=False), z), source_format="free") == \ + "z .and. (x .neqv. y)" + # mixed Or/Xor + assert fcode(Xor(Or(y, z), x, evaluate=False), source_format="free") == \ + "x .neqv. y .or. z" + assert fcode(Xor(Or(z, x), y, evaluate=False), source_format="free") == \ + "y .neqv. x .or. z" + assert fcode(Xor(Or(x, y), z, evaluate=False), source_format="free") == \ + "z .neqv. x .or. y" + assert fcode(Or(Xor(y, z, evaluate=False), x), source_format="free") == \ + "x .or. (y .neqv. z)" + assert fcode(Or(Xor(z, x, evaluate=False), y), source_format="free") == \ + "y .or. (x .neqv. z)" + assert fcode(Or(Xor(x, y, evaluate=False), z), source_format="free") == \ + "z .or. (x .neqv. y)" + # trinary Xor + assert fcode(Xor(x, y, z, evaluate=False), source_format="free") == \ + "x .neqv. y .neqv. z" + assert fcode(Xor(x, y, Not(z), evaluate=False), source_format="free") == \ + "x .neqv. y .neqv. .not. z" + assert fcode(Xor(x, Not(y), z, evaluate=False), source_format="free") == \ + "x .neqv. z .neqv. .not. y" + assert fcode(Xor(Not(x), y, z, evaluate=False), source_format="free") == \ + "y .neqv. z .neqv. .not. x" + + +def test_fcode_Relational(): + x, y = symbols("x y") + assert fcode(Relational(x, y, "=="), source_format="free") == "x == y" + assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y" + assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y" + assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y" + assert fcode(Relational(x, y, ">"), source_format="free") == "x > y" + assert fcode(Relational(x, y, "<"), source_format="free") == "x < y" + + +def test_fcode_Piecewise(): + x = symbols('x') + expr = Piecewise((x, x < 1), (x**2, True)) + # Check that inline conditional (merge) fails if standard isn't 95+ + raises(NotImplementedError, lambda: fcode(expr)) + code = fcode(expr, standard=95) + expected = " merge(x, x**2, x < 1)" + assert code == expected + assert fcode(Piecewise((x, x < 1), (x**2, True)), assign_to="var") == ( + " if (x < 1) then\n" + " var = x\n" + " else\n" + " var = x**2\n" + " end if" + ) + a = cos(x)/x + b = sin(x)/x + for i in range(10): + a = diff(a, x) + b = diff(b, x) + expected = ( + " if (x < 0) then\n" + " weird_name = -cos(x)/x + 10*sin(x)/x**2 + 90*cos(x)/x**3 - 720*\n" + " @ sin(x)/x**4 - 5040*cos(x)/x**5 + 30240*sin(x)/x**6 + 151200*cos(x\n" + " @ )/x**7 - 604800*sin(x)/x**8 - 1814400*cos(x)/x**9 + 3628800*sin(x\n" + " @ )/x**10 + 3628800*cos(x)/x**11\n" + " else\n" + " weird_name = -sin(x)/x - 10*cos(x)/x**2 + 90*sin(x)/x**3 + 720*\n" + " @ cos(x)/x**4 - 5040*sin(x)/x**5 - 30240*cos(x)/x**6 + 151200*sin(x\n" + " @ )/x**7 + 604800*cos(x)/x**8 - 1814400*sin(x)/x**9 - 3628800*cos(x\n" + " @ )/x**10 + 3628800*sin(x)/x**11\n" + " end if" + ) + code = fcode(Piecewise((a, x < 0), (b, True)), assign_to="weird_name") + assert code == expected + code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95) + expected = " merge(x, merge(x**2, sin(x), x > 1), x < 1)" + assert code == expected + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: fcode(expr)) + + +def test_wrap_fortran(): + # "########################################################################" + printer = FCodePrinter() + lines = [ + "C This is a long comment on a single line that must be wrapped properly to produce nice output", + " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly", + ] + wrapped_lines = printer._wrap_fortran(lines) + expected_lines = [ + "C This is a long comment on a single line that must be wrapped", + "C properly to produce nice output", + " this = is + a + long + and + nasty + fortran + statement + that *", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that *", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ *must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement +", + " @ that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ **must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ **must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement +", + " @ that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)/", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)", + " @ /must + be + wrapped + properly", + ] + for line in wrapped_lines: + assert len(line) <= 72 + for w, e in zip(wrapped_lines, expected_lines): + assert w == e + assert len(wrapped_lines) == len(expected_lines) + + +def test_wrap_fortran_keep_d0(): + printer = FCodePrinter() + lines = [ + ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 10.0d0' + ] + expected = [ + ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 10.0d0' + ] + assert printer._wrap_fortran(lines) == expected + + +def test_settings(): + raises(TypeError, lambda: fcode(S(4), method="garbage")) + + +def test_free_form_code_line(): + x, y = symbols('x,y') + assert fcode(cos(x) + sin(y), source_format='free') == "sin(y) + cos(x)" + + +def test_free_form_continuation_line(): + x, y = symbols('x,y') + result = fcode(((cos(x) + sin(y))**(7)).expand(), source_format='free') + expected = ( + 'sin(y)**7 + 7*sin(y)**6*cos(x) + 21*sin(y)**5*cos(x)**2 + 35*sin(y)**4* &\n' + ' cos(x)**3 + 35*sin(y)**3*cos(x)**4 + 21*sin(y)**2*cos(x)**5 + 7* &\n' + ' sin(y)*cos(x)**6 + cos(x)**7' + ) + assert result == expected + + +def test_free_form_comment_line(): + printer = FCodePrinter({'source_format': 'free'}) + lines = [ "! This is a long comment on a single line that must be wrapped properly to produce nice output"] + expected = [ + '! This is a long comment on a single line that must be wrapped properly', + '! to produce nice output'] + assert printer._wrap_fortran(lines) == expected + + +def test_loops(): + n, m = symbols('n,m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + expected = ( + 'do i = 1, m\n' + ' y(i) = 0\n' + 'end do\n' + 'do i = 1, m\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s\n' + ' end do\n' + 'end do' + ) + + code = fcode(A[i, j]*x[j], assign_to=y[i], source_format='free') + assert (code == expected % {'rhs': 'y(i) + A(i, j)*x(j)'} or + code == expected % {'rhs': 'y(i) + x(j)*A(i, j)'} or + code == expected % {'rhs': 'x(j)*A(i, j) + y(i)'} or + code == expected % {'rhs': 'A(i, j)*x(j) + y(i)'}) + + +def test_dummy_loops(): + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + expected = ( + 'do i_%(icount)i = 1, m_%(mcount)i\n' + ' y(i_%(icount)i) = x(i_%(icount)i)\n' + 'end do' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + code = fcode(x[i], assign_to=y[i], source_format='free') + assert code == expected + + +def test_fcode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = fcode(e.rhs, assign_to=e.lhs, contract=False) + assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))') + + +def test_element_like_objects(): + len_y = 5 + y = ArraySymbol('y', shape=(len_y,)) + x = ArraySymbol('x', shape=(len_y,)) + Dy = ArraySymbol('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = fcode(Assignment(e.lhs, e.rhs)) + assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))') + + class ElementExpr(Element, Expr): + pass + + e = e.subs((a, ElementExpr(a.name, a.indices)) for a in e.atoms(ArrayElement) ) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = fcode(Assignment(e.lhs, e.rhs)) + assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))') + + +def test_derived_classes(): + class MyFancyFCodePrinter(FCodePrinter): + _default_settings = FCodePrinter._default_settings.copy() + + printer = MyFancyFCodePrinter() + x = symbols('x') + assert printer.doprint(sin(x), "bork") == " bork = sin(x)" + + +def test_indent(): + codelines = ( + 'subroutine test(a)\n' + 'integer :: a, i, j\n' + '\n' + 'do\n' + 'do \n' + 'do j = 1, 5\n' + 'if (a>b) then\n' + 'if(b>0) then\n' + 'a = 3\n' + 'donot_indent_me = 2\n' + 'do_not_indent_me_either = 2\n' + 'ifIam_indented_something_went_wrong = 2\n' + 'if_I_am_indented_something_went_wrong = 2\n' + 'end should not be unindented here\n' + 'end if\n' + 'endif\n' + 'end do\n' + 'end do\n' + 'enddo\n' + 'end subroutine\n' + '\n' + 'subroutine test2(a)\n' + 'integer :: a\n' + 'do\n' + 'a = a + 1\n' + 'end do \n' + 'end subroutine\n' + ) + expected = ( + 'subroutine test(a)\n' + 'integer :: a, i, j\n' + '\n' + 'do\n' + ' do \n' + ' do j = 1, 5\n' + ' if (a>b) then\n' + ' if(b>0) then\n' + ' a = 3\n' + ' donot_indent_me = 2\n' + ' do_not_indent_me_either = 2\n' + ' ifIam_indented_something_went_wrong = 2\n' + ' if_I_am_indented_something_went_wrong = 2\n' + ' end should not be unindented here\n' + ' end if\n' + ' endif\n' + ' end do\n' + ' end do\n' + 'enddo\n' + 'end subroutine\n' + '\n' + 'subroutine test2(a)\n' + 'integer :: a\n' + 'do\n' + ' a = a + 1\n' + 'end do \n' + 'end subroutine\n' + ) + p = FCodePrinter({'source_format': 'free'}) + result = p.indent_code(codelines) + assert result == expected + +def test_Matrix_printing(): + x, y, z = symbols('x,y,z') + # Test returning a Matrix + mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert fcode(mat, A) == ( + " A(1, 1) = x*y\n" + " if (y > 0) then\n" + " A(2, 1) = x + 2\n" + " else\n" + " A(2, 1) = y\n" + " end if\n" + " A(3, 1) = sin(z)") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert fcode(expr, standard=95) == ( + " merge(2*A(3, 1), A(3, 1), x > 0) + sin(A(2, 1)) + A(1, 1)") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert fcode(m, M) == ( + " M(1, 1) = sin(q(2, 1))\n" + " M(2, 1) = q(2, 1) + q(3, 1)\n" + " M(3, 1) = 2*q(5, 1)/q(2, 1)\n" + " M(1, 2) = 0\n" + " M(2, 2) = q(4, 1)\n" + " M(3, 2) = sqrt(q(1, 1)) + 4\n" + " M(1, 3) = cos(q(3, 1))\n" + " M(2, 3) = 5\n" + " M(3, 3) = 0") + + +def test_fcode_For(): + x, y = symbols('x y') + + f = For(x, Range(0, 10, 2), [Assignment(y, x * y)]) + sol = fcode(f) + assert sol == (" do x = 0, 9, 2\n" + " y = x*y\n" + " end do") + + +def test_fcode_Declaration(): + def check(expr, ref, **kwargs): + assert fcode(expr, standard=95, source_format='free', **kwargs) == ref + + i = symbols('i', integer=True) + var1 = Variable.deduced(i) + dcl1 = Declaration(var1) + check(dcl1, "integer*4 :: i") + + + x, y = symbols('x y') + var2 = Variable(x, float32, value=42, attrs={value_const}) + dcl2b = Declaration(var2) + check(dcl2b, 'real*4, parameter :: x = 42') + + var3 = Variable(y, type=bool_) + dcl3 = Declaration(var3) + check(dcl3, 'logical :: y') + + check(float32, "real*4") + check(float64, "real*8") + check(real, "real*4", type_aliases={real: float32}) + check(real, "real*8", type_aliases={real: float64}) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(fcode(A[0, 0]) == " A(1, 1)") + assert(fcode(3 * A[0, 0]) == " 3*A(1, 1)") + + F = C[0, 0].subs(C, A - B) + assert(fcode(F) == " (A - B)(1, 1)") + + +def test_aug_assign(): + x = symbols('x') + assert fcode(aug_assign(x, '+', 1), source_format='free') == 'x = x + 1' + + +def test_While(): + x = symbols('x') + assert fcode(While(abs(x) > 1, [aug_assign(x, '-', 1)]), source_format='free') == ( + 'do while (abs(x) > 1)\n' + ' x = x - 1\n' + 'end do' + ) + + +def test_FunctionPrototype_print(): + x = symbols('x') + n = symbols('n', integer=True) + vx = Variable(x, type=real) + vn = Variable(n, type=integer) + fp1 = FunctionPrototype(real, 'power', [vx, vn]) + # Should be changed to proper test once multi-line generation is working + # see https://github.com/sympy/sympy/issues/15824 + raises(NotImplementedError, lambda: fcode(fp1)) + + +def test_FunctionDefinition_print(): + x = symbols('x') + n = symbols('n', integer=True) + vx = Variable(x, type=real) + vn = Variable(n, type=integer) + body = [Assignment(x, x**n), Return(x)] + fd1 = FunctionDefinition(real, 'power', [vx, vn], body) + # Should be changed to proper test once multi-line generation is working + # see https://github.com/sympy/sympy/issues/15824 + raises(NotImplementedError, lambda: fcode(fd1)) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_glsl.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_glsl.py new file mode 100644 index 0000000000000000000000000000000000000000..86ec1dfe4a37d141e8435c369cb692d3a9a3b7bc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_glsl.py @@ -0,0 +1,998 @@ +from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma, + Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.printing.glsl import GLSLPrinter +from sympy.printing.str import StrPrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol +from sympy.core import Tuple +from sympy.printing.glsl import glsl_code +import textwrap + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert glsl_code(Abs(x)) == "abs(x)" + +def test_print_without_operators(): + assert glsl_code(x*y,use_operators = False) == 'mul(x, y)' + assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))' + assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))' + assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))' + +def test_glsl_code_sqrt(): + assert glsl_code(sqrt(x)) == "sqrt(x)" + assert glsl_code(x**0.5) == "sqrt(x)" + assert glsl_code(sqrt(x)) == "sqrt(x)" + + +def test_glsl_code_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(x**3) == "pow(x, 3.0)" + assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))" + assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)" + assert glsl_code(x**-1.0) == '1.0/x' + + +def test_glsl_code_Relational(): + assert glsl_code(Eq(x, y)) == "x == y" + assert glsl_code(Ne(x, y)) == "x != y" + assert glsl_code(Le(x, y)) == "x <= y" + assert glsl_code(Lt(x, y)) == "x < y" + assert glsl_code(Gt(x, y)) == "x > y" + assert glsl_code(Ge(x, y)) == "x >= y" + + +def test_glsl_code_constants_mathh(): + assert glsl_code(exp(1)) == "float E = 2.71828183;\nE" + assert glsl_code(pi) == "float pi = 3.14159265;\npi" + # assert glsl_code(oo) == "Number.POSITIVE_INFINITY" + # assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_glsl_code_constants_other(): + assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio" + assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan" + assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma" + + +def test_glsl_code_Rational(): + assert glsl_code(Rational(3, 7)) == "3.0/7.0" + assert glsl_code(Rational(18, 9)) == "2" + assert glsl_code(Rational(3, -7)) == "-3.0/7.0" + assert glsl_code(Rational(-3, -7)) == "3.0/7.0" + + +def test_glsl_code_Integer(): + assert glsl_code(Integer(67)) == "67" + assert glsl_code(Integer(-1)) == "-1" + + +def test_glsl_code_functions(): + assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_glsl_code_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan" + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert glsl_code(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: glsl_code(expr)) + + +def test_glsl_code_Piecewise_deep(): + p = glsl_code(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + pow(x, 2.0) +))\ +""" + assert p == s + + +def test_glsl_code_settings(): + raises(TypeError, lambda: glsl_code(sin(x), method="garbage")) + + +def test_glsl_code_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = GLSLPrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + +def test_glsl_code_list_tuple_Tuple(): + assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)' + assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)' + assert glsl_code([1,2,3]) == glsl_code((1,2,3)) + assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3)) + + m = MatrixSymbol('A',3,4) + assert glsl_code([m[0],m[1]]) + +def test_glsl_code_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert glsl_code(mat, assign_to=A) == ( +'''A[0][0] = x*y; +if (y > 0) { + A[1][0] = x + 2; +} +else { + A[1][0] = y; +} +A[2][0] = sin(z);''' ) + assert glsl_code(Matrix([A[0],A[1]])) + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert glsl_code(expr) == ( +'''((x > 0) ? ( + 2*A[2][0] +) +: ( + A[2][0] +)) + sin(A[1][0]) + A[0][0]''' ) + + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert glsl_code(m,M) == ( +'''M[0][0] = sin(q[1]); +M[0][1] = 0; +M[0][2] = cos(q[2]); +M[1][0] = q[1] + q[2]; +M[1][1] = q[3]; +M[1][2] = 5; +M[2][0] = 2*q[4]/q[1]; +M[2][1] = sqrt(q[0]) + 4; +M[2][2] = 0;''' + ) + +def test_Matrices_1x7(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Matrices_1x7_array_type_int(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Tuple_array_type_custom(): + gl = glsl_code + A = symbols('a b c') + assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)' + +def test_Matrices_1x7_spread_assign_to_symbols(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g') + assert gl(A, assign_to=assign_to) == textwrap.dedent('''\ + x.a = 1; + x.b = 2; + x.c = 3; + x.d = 4; + x.e = 5; + x.f = 6; + x.g = 7;''' + ) + +def test_spread_assign_to_nested_symbols(): + gl = glsl_code + expr = ((1,2,3), (1,2,3)) + assign_to = (symbols('a b c'), symbols('x y z')) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_spread_assign_to_deeply_nested_symbols(): + gl = glsl_code + a, b, c, x, y, z = symbols('a b c x y z') + expr = (((1,2),3), ((1,2),3)) + assign_to = (((a, b), c), ((x, y), z)) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_matrix_of_tuples_spread_assign_to_symbols(): + gl = glsl_code + with warns_deprecated_sympy(): + expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]]) + assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h')) + assert gl(expr, assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + d = 4; + e = 5; + f = 6; + g = 7; + h = 8;''' + ) + +def test_cannot_assign_to_cause_mismatched_length(): + expr = (1, 2) + assign_to = symbols('x y z') + raises(ValueError, lambda: glsl_code(expr, assign_to)) + +def test_matrix_4x4_assign(): + gl = glsl_code + expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4) + assign_to = MatrixSymbol('X',4,4) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0]; + X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1]; + X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2]; + X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3]; + X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0]; + X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1]; + X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2]; + X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3]; + X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0]; + X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1]; + X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2]; + X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3]; + X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0]; + X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1]; + X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2]; + X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];''' + ) + +def test_1xN_vecs(): + gl = glsl_code + for i in range(1,10): + A = Matrix(range(i)) + assert gl(A.transpose()) == gl(A) + assert gl(A,mat_transpose=True) == gl(A) + if i > 1: + if i <= 4: + assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i))) + else: + assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i))) + +def test_MxN_mats(): + generatedAssertions='def test_misc_mats():\n' + for i in range(1,6): + for j in range(1,6): + A = Matrix([[x + y*j for x in range(j)] for y in range(i)]) + gl = glsl_code(A) + glTransposed = glsl_code(A,mat_transpose=True) + generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n' + generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n' + generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat) == gl\n' + generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n' + if i == 1 and j == 1: + assert gl == '0' + elif i <= 4 and j <= 4 and i>1 and j>1: + assert gl.startswith('mat%s' % j) + assert glTransposed.startswith('mat%s' % i) + elif i == 1 and j <= 4: + assert gl.startswith('vec') + elif j == 1 and i <= 4: + assert gl.startswith('vec') + elif i == 1: + assert gl.startswith('float[%s]('% j*i) + assert glTransposed.startswith('float[%s]('% j*i) + elif j == 1: + assert gl.startswith('float[%s]('% i*j) + assert glTransposed.startswith('float[%s]('% i*j) + else: + assert gl.startswith('float[%s](' % (i*j)) + assert glTransposed.startswith('float[%s](' % (i*j)) + glNested = glsl_code(A,mat_nested=True) + glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True) + assert glNested.startswith('float[%s][%s]' % (i,j)) + assert glNestedTransposed.startswith('float[%s][%s]' % (j,i)) + generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n' + generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n' + generateAssertions = False # set this to true to write bake these generated tests to a file + if generateAssertions: + gen = open('test_glsl_generated_matrices.py','w') + gen.write(generatedAssertions) + gen.close() + + +# these assertions were generated from the previous function +# glsl has complicated rules and this makes it easier to look over all the cases +def test_misc_mats(): + + mat = Matrix([[0]]) + + gl = '''0''' + glTransposed = '''0''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3, 4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0], +[1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3]]) + + gl = '''mat2(0, 1, 2, 3)''' + glTransposed = '''mat2(0, 2, 1, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5]]) + + gl = '''mat3x2(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7]]) + + gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3, 4], +[5, 6, 7, 8, 9]]) + + gl = '''float[10]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9 +) /* a 2x5 matrix */''' + glTransposed = '''float[10]( + 0, 5, + 1, 6, + 2, 7, + 3, 8, + 4, 9 +) /* a 5x2 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[2][5]( + float[](0, 1, 2, 3, 4), + float[](5, 6, 7, 8, 9) +)''' + glNestedTransposed = '''float[5][2]( + float[](0, 5), + float[](1, 6), + float[](2, 7), + float[](3, 8), + float[](4, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5]]) + + gl = '''mat2x3(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8]]) + + gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)''' + glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + + gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14 +) /* a 3x5 matrix */''' + glTransposed = '''float[15]( + 0, 5, 10, + 1, 6, 11, + 2, 7, 12, + 3, 8, 13, + 4, 9, 14 +) /* a 5x3 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[3][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14) +)''' + glNestedTransposed = '''float[5][3]( + float[](0, 5, 10), + float[](1, 6, 11), + float[](2, 7, 12), + float[](3, 8, 13), + float[](4, 9, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7]]) + + gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8], +[9, 10, 11]]) + + gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15]]) + + gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)''' + glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19 +) /* a 4x5 matrix */''' + glTransposed = '''float[20]( + 0, 5, 10, 15, + 1, 6, 11, 16, + 2, 7, 12, 17, + 3, 8, 13, 18, + 4, 9, 14, 19 +) /* a 5x4 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[4][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19) +)''' + glNestedTransposed = '''float[5][4]( + float[](0, 5, 10, 15), + float[](1, 6, 11, 16), + float[](2, 7, 12, 17), + float[](3, 8, 13, 18), + float[](4, 9, 14, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3], +[4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7], +[8, 9]]) + + gl = '''float[10]( + 0, 1, + 2, 3, + 4, 5, + 6, 7, + 8, 9 +) /* a 5x2 matrix */''' + glTransposed = '''float[10]( + 0, 2, 4, 6, 8, + 1, 3, 5, 7, 9 +) /* a 2x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][2]( + float[](0, 1), + float[](2, 3), + float[](4, 5), + float[](6, 7), + float[](8, 9) +)''' + glNestedTransposed = '''float[2][5]( + float[](0, 2, 4, 6, 8), + float[](1, 3, 5, 7, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2], +[ 3, 4, 5], +[ 6, 7, 8], +[ 9, 10, 11], +[12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, + 3, 4, 5, + 6, 7, 8, + 9, 10, 11, + 12, 13, 14 +) /* a 5x3 matrix */''' + glTransposed = '''float[15]( + 0, 3, 6, 9, 12, + 1, 4, 7, 10, 13, + 2, 5, 8, 11, 14 +) /* a 3x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][3]( + float[]( 0, 1, 2), + float[]( 3, 4, 5), + float[]( 6, 7, 8), + float[]( 9, 10, 11), + float[](12, 13, 14) +)''' + glNestedTransposed = '''float[3][5]( + float[](0, 3, 6, 9, 12), + float[](1, 4, 7, 10, 13), + float[](2, 5, 8, 11, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15], +[16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19 +) /* a 5x4 matrix */''' + glTransposed = '''float[20]( + 0, 4, 8, 12, 16, + 1, 5, 9, 13, 17, + 2, 6, 10, 14, 18, + 3, 7, 11, 15, 19 +) /* a 4x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][4]( + float[]( 0, 1, 2, 3), + float[]( 4, 5, 6, 7), + float[]( 8, 9, 10, 11), + float[](12, 13, 14, 15), + float[](16, 17, 18, 19) +)''' + glNestedTransposed = '''float[4][5]( + float[](0, 4, 8, 12, 16), + float[](1, 5, 9, 13, 17), + float[](2, 6, 10, 14, 18), + float[](3, 7, 11, 15, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19], +[20, 21, 22, 23, 24]]) + + gl = '''float[25]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24 +) /* a 5x5 matrix */''' + glTransposed = '''float[25]( + 0, 5, 10, 15, 20, + 1, 6, 11, 16, 21, + 2, 7, 12, 17, 22, + 3, 8, 13, 18, 23, + 4, 9, 14, 19, 24 +) /* a 5x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19), + float[](20, 21, 22, 23, 24) +)''' + glNestedTransposed = '''float[5][5]( + float[](0, 5, 10, 15, 20), + float[](1, 6, 11, 16, 21), + float[](2, 7, 12, 17, 22), + float[](3, 8, 13, 18, 23), + float[](4, 9, 14, 19, 24) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_gtk.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_gtk.py new file mode 100644 index 0000000000000000000000000000000000000000..5a595ab04d3a29d23e06ec12207bf917392aebce --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_gtk.py @@ -0,0 +1,18 @@ +from sympy.functions.elementary.trigonometric import sin +from sympy.printing.gtk import print_gtk +from sympy.testing.pytest import XFAIL, raises + +# this test fails if python-lxml isn't installed. We don't want to depend on +# anything with SymPy + + +@XFAIL +def test_1(): + from sympy.abc import x + print_gtk(x**2, start_viewer=False) + print_gtk(x**2 + sin(x)/4, start_viewer=False) + + +def test_settings(): + from sympy.abc import x + raises(TypeError, lambda: print_gtk(x, method="garbage")) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_jax.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..365d87c5b91fdd49a8e46cfde9c2b5792c23a03c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_jax.py @@ -0,0 +1,370 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Function, Pow, Symbol +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +# Unlike NumPy which will aggressively promote operands to double precision, +# jax always uses single precision. Double precision in jax can be +# configured before the call to `import jax`, however this must be explicitly +# configured and is not fully supported. Thus, the tests here have been modified +# from the tests in test_numpy.py, only in the fact that they assert lambdify +# function accuracy to only single precision accuracy. +# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + +jax = import_module('jax') + +if jax: + deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype) + JAX_DEFAULT_EPSILON = deafult_float_info.eps + + +def test_jax_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = JaxPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)' + assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}} + + +def test_jax_logaddexp(): + lae = logaddexp(a, b) + assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)' + + +def test_jax_sum(): + if not jax: + skip("JAX not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_jax_multiple_sums(): + if not jax: + skip("JAX not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'jax') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_jax_codegen_einsum(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'jax') + + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all() + + +def test_jax_codegen_extra(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + mc = jax.numpy.array([[2, 0], [1, 2]]) + md = jax.numpy.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'jax') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'jax') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'jax') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_jax_relational(): + if not jax: + skip("JAX not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, True]) + + # Multi-condition expressions + e = (x >= 1) & (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = (x >= 1) | (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, True]) + +def test_jax_mod(): + if not jax: + skip("JAX not installed") + + e = Mod(a, b) + f = lambdify((a, b), e, 'jax') + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = 2 + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = jax.numpy.array([2, 2, 2, 2]) + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([2, 3, 4, 5]) + b_ = jax.numpy.array([2, 3, 4, 5]) + assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_jax_pow(): + if not jax: + skip('JAX not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'jax') + assert f() == 0.5 + + +def test_jax_expm1(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), expm1(a), 'jax') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON + + +def test_jax_log1p(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), log1p(a), 'jax') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON + +def test_jax_hypot(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON + +def test_jax_log10(): + if not jax: + skip("JAX not installed") + + assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_exp2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON + + +def test_jax_log2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON + + +def test_jax_Sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_matsolve(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr, 'jax') + f_matsolve = lambdify((M, x), matsolve_expr, 'jax') + + m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert jax.numpy.linalg.matrix_rank(m0) == 3 + + x0 = jax.numpy.array([3, 4, 5]) + + assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not jax: + skip("JAX not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = JaxPrinter() + assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2), 'jax') + ma = jax.numpy.array([[1, 2], [3, 4]]) + mr = jax.numpy.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax')) + + +def test_jax_array(): + assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])' + assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array([1, 2])' + + +def test_jax_known_funcs_consts(): + assert _jax_known_constants['NaN'] == 'jax.numpy.nan' + assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma' + + assert _jax_known_functions['acos'] == 'jax.numpy.arccos' + assert _jax_known_functions['log'] == 'jax.numpy.log' + + +def test_jax_print_methods(): + prntr = JaxPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + + +def test_jax_printmethod(): + printer = JaxPrinter() + assert hasattr(printer, 'printmethod') + assert printer.printmethod == '_jaxcode' + + +def test_jax_custom_print_method(): + + class expm1(Function): + + def _jaxcode(self, printer): + x, = self.args + function = f'expm1({printer._print(x)})' + return printer._module_format(printer._module + '.' + function) + + printer = JaxPrinter() + assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_jscode.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_jscode.py new file mode 100644 index 0000000000000000000000000000000000000000..9199a8e0d62e87f2e964cb1712726a21c894fd20 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_jscode.py @@ -0,0 +1,396 @@ +from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio, + EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le, + Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sinh, cosh, tanh, asin, acos, acosh, Max, Min) +from sympy.testing.pytest import raises +from sympy.printing.jscode import JavascriptCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.jscode import jscode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert jscode(Abs(x)) == "Math.abs(x)" + + +def test_jscode_sqrt(): + assert jscode(sqrt(x)) == "Math.sqrt(x)" + assert jscode(x**0.5) == "Math.sqrt(x)" + assert jscode(x**(S.One/3)) == "Math.cbrt(x)" + + +def test_jscode_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(x**3) == "Math.pow(x, 3)" + assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))" + assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)" + assert jscode(x**-1.0) == '1/x' + + +def test_jscode_constants_mathh(): + assert jscode(exp(1)) == "Math.E" + assert jscode(pi) == "Math.PI" + assert jscode(oo) == "Number.POSITIVE_INFINITY" + assert jscode(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_jscode_constants_other(): + assert jscode( + 2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert jscode( + 2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_jscode_Rational(): + assert jscode(Rational(3, 7)) == "3/7" + assert jscode(Rational(18, 9)) == "2" + assert jscode(Rational(3, -7)) == "-3/7" + assert jscode(Rational(-3, -7)) == "3/7" + + +def test_Relational(): + assert jscode(Eq(x, y)) == "x == y" + assert jscode(Ne(x, y)) == "x != y" + assert jscode(Le(x, y)) == "x <= y" + assert jscode(Lt(x, y)) == "x < y" + assert jscode(Gt(x, y)) == "x > y" + assert jscode(Ge(x, y)) == "x >= y" + + +def test_Mod(): + assert jscode(Mod(x, y)) == '((x % y) + y) % y' + assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)' + p1, p2 = symbols('p1 p2', positive=True) + assert jscode(Mod(p1, p2)) == 'p1 % p2' + assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)' + assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)' + assert jscode(-Mod(p1, p2)) == '-(p1 % p2)' + assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)' + + +def test_jscode_Integer(): + assert jscode(Integer(67)) == "67" + assert jscode(Integer(-1)) == "-1" + + +def test_jscode_functions(): + assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))" + assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)" + assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)" + assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)" + assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)" + + +def test_jscode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert jscode(g(A[i]), assign_to=A[i]) == ( + "for (var i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: jscode(expr)) + + +def test_jscode_Piecewise_deep(): + p = jscode(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + Math.pow(x, 2) +))\ +""" + assert p == s + + +def test_jscode_settings(): + raises(TypeError, lambda: jscode(sin(x), method="garbage")) + + +def test_jscode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = JavascriptCodePrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + + +def test_jscode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (var i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert jscode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = Math.sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert jscode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + Math.sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert jscode(m, M) == ( + "M[0] = Math.sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = Math.cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = Math.sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(jscode(A[0, 0]) == "A[0]") + assert(jscode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(jscode(F) == "(A - B)[0]") diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_julia.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_julia.py new file mode 100644 index 0000000000000000000000000000000000000000..b19c7b4fd4f21d8402ca2f577605322b3ec10f5b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_julia.py @@ -0,0 +1,390 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.testing.pytest import XFAIL + +from sympy.printing.julia import julia_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert julia_code(Integer(67)) == "67" + assert julia_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert julia_code(Rational(3, 7)) == "3 // 7" + assert julia_code(Rational(18, 9)) == "2" + assert julia_code(Rational(3, -7)) == "-3 // 7" + assert julia_code(Rational(-3, -7)) == "3 // 7" + assert julia_code(x + Rational(3, 7)) == "x + 3 // 7" + assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x" + + +def test_Relational(): + assert julia_code(Eq(x, y)) == "x == y" + assert julia_code(Ne(x, y)) == "x != y" + assert julia_code(Le(x, y)) == "x <= y" + assert julia_code(Lt(x, y)) == "x < y" + assert julia_code(Gt(x, y)) == "x > y" + assert julia_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)" + assert julia_code(abs(x)) == "abs(x)" + assert julia_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert julia_code(x**3) == "x .^ 3" + assert julia_code(x**(y**3)) == "x .^ (y .^ 3)" + assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)" + # For issue 14160 + assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2 * x ./ (y .* y)' + + +def test_basic_ops(): + assert julia_code(x*y) == "x .* y" + assert julia_code(x + y) == "x + y" + assert julia_code(x - y) == "x - y" + assert julia_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert julia_code(1/x) == '1 ./ x' + assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x' + assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)' + assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)' + assert julia_code(sqrt(x)) == 'sqrt(x)' + assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)' + assert julia_code(1/pi) == '1 / pi' + assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi' + assert julia_code(pi**-0.5) == '1 / sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert julia_code(3*x) == "3 * x" + assert julia_code(pi*x) == "pi * x" + assert julia_code(3/x) == "3 ./ x" + assert julia_code(pi/x) == "pi ./ x" + assert julia_code(x/3) == "x / 3" + assert julia_code(x/pi) == "x / pi" + assert julia_code(x*y) == "x .* y" + assert julia_code(3*x*y) == "3 * x .* y" + assert julia_code(3*pi*x*y) == "3 * pi * x .* y" + assert julia_code(x/y) == "x ./ y" + assert julia_code(3*x/y) == "3 * x ./ y" + assert julia_code(x*y/z) == "x .* y ./ z" + assert julia_code(x/y*z) == "x .* z ./ y" + assert julia_code(1/x/y) == "1 ./ (x .* y)" + assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)" + assert julia_code(3*pi/x) == "3 * pi ./ x" + assert julia_code(S(3)/5) == "3 // 5" + assert julia_code(S(3)/5*x) == "(3 // 5) * x" + assert julia_code(x/y/z) == "x ./ (y .* z)" + assert julia_code((x+y)/z) == "(x + y) ./ z" + assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)" + assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma" + assert julia_code(x/3/pi) == "x / (3 * pi)" + assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi" + + +def test_mix_number_pow_symbols(): + assert julia_code(pi**3) == 'pi ^ 3' + assert julia_code(x**2) == 'x .^ 2' + assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)' + assert julia_code(x**y) == 'x .^ y' + assert julia_code(x**(y**z)) == 'x .^ (y .^ z)' + assert julia_code((x**y)**z) == '(x .^ y) .^ z' + + +def test_imag(): + I = S('I') + assert julia_code(I) == "im" + assert julia_code(5*I) == "5im" + assert julia_code((S(3)/2)*I) == "(3 // 2) * im" + assert julia_code(3+4*I) == "3 + 4im" + + +def test_constants(): + assert julia_code(pi) == "pi" + assert julia_code(oo) == "Inf" + assert julia_code(-oo) == "-Inf" + assert julia_code(S.NegativeInfinity) == "-Inf" + assert julia_code(S.NaN) == "NaN" + assert julia_code(S.Exp1) == "e" + assert julia_code(exp(1)) == "e" + + +def test_constants_other(): + assert julia_code(2*GoldenRatio) == "2 * golden" + assert julia_code(2*Catalan) == "2 * catalan" + assert julia_code(2*EulerGamma) == "2 * eulergamma" + + +def test_boolean(): + assert julia_code(x & y) == "x && y" + assert julia_code(x | y) == "x || y" + assert julia_code(~x) == "!x" + assert julia_code(x & y & z) == "x && y && z" + assert julia_code(x | y | z) == "x || y || z" + assert julia_code((x & y) | z) == "z || x && y" + assert julia_code((x | y) & z) == "z && (x || y)" + +def test_sinc(): + assert julia_code(sinc(x)) == 'sinc(x / pi)' + assert julia_code(sinc(x + 3)) == 'sinc((x + 3) / pi)' + assert julia_code(sinc(pi * (x + 3))) == 'sinc(x + 3)' + +def test_Matrices(): + assert julia_code(Matrix(1, 1, [10])) == "[10]" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = ("[1 sin(x / 2) abs(x);\n" + "0 1 pi;\n" + "0 e ceil(x)]") + assert julia_code(A) == expected + # row and columns + assert julia_code(A[:,0]) == "[1, 0, 0]" + assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]" + # empty matrices + assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)' + assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]" + assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert julia_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert julia_code(A*B) == "A * B" + assert julia_code(B*A) == "B * A" + assert julia_code(2*A*B) == "2 * A * B" + assert julia_code(B*2*A) == "2 * B * A" + assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)" + assert julia_code(A**(x**2)) == "A ^ (x .^ 2)" + assert julia_code(A**3) == "A ^ 3" + assert julia_code(A**S.Half) == "A ^ (1 // 2)" + + +def test_special_matrices(): + assert julia_code(6*Identity(3)) == "6 * eye(3)" + + +def test_containers(): + assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" + assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" + assert julia_code([1]) == "Any[1]" + assert julia_code((1,)) == "(1,)" + assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" + assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))" + # scalar, matrix, empty matrix and empty list + assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + + +def test_julia_noninline(): + source = julia_code((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "const Catalan = %s\n" + "me = (x + y) / Catalan" + ) % Catalan.evalf(17) + assert source == expected + + +def test_julia_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))" + assert julia_code(expr, assign_to="r") == ( + "r = ((x < 1) ? (x) : (x .^ 2))") + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x\n" + "else\n" + " r = x .^ 2\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1) ? (x .^ 2) :\n" + "(x < 2) ? (x .^ 3) :\n" + "(x < 3) ? (x .^ 4) : (x .^ 5))") + assert julia_code(expr) == expected + assert julia_code(expr, assign_to="r") == "r = " + expected + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x .^ 2\n" + "elseif (x < 2)\n" + " r = x .^ 3\n" + "elseif (x < 3)\n" + " r = x .^ 4\n" + "else\n" + " r = x .^ 5\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: julia_code(expr)) + + +def test_julia_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))" + assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x" + assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)" + assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3" + + +def test_julia_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert julia_code(A, assign_to='a') == "a = [1 2 3]" + A = Matrix([[1, 2], [3, 4]]) + assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]" + + +def test_julia_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert julia_code(A, assign_to=B) == "B = [1 2 3]" + raises(ValueError, lambda: julia_code(A, assign_to=x)) + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert julia_code(A, assign_to=B) == "B = [3]" + # FIXME? + #assert julia_code(A, assign_to=x) == "x = [3]" + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" + A = MatrixSymbol('AA', 1, 3) + assert julia_code(A) == "AA" + assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" + assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + + +def test_julia_boolean(): + assert julia_code(True) == "true" + assert julia_code(S.true) == "true" + assert julia_code(False) == "false" + assert julia_code(S.false) == "false" + + +def test_julia_not_supported(): + with raises(NotImplementedError): + julia_code(S.ComplexInfinity) + + f = Function('f') + assert julia_code(f(x).diff(x), strict=False) == ( + "# Not supported in Julia:\n" + "# Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless') + t2 = S('elsewhere') + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert julia_code(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_haramard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert julia_code(C) == "A .* B" + assert julia_code(C*v) == "(A .* B) * v" + assert julia_code(h*C*v) == "h * (A .* B) * v" + assert julia_code(C*A) == "(A .* B) * A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert julia_code(C*x*y) == "(x .* y) * (A .* B)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x*y + assert julia_code(M) == ( + "sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)" + ) + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert julia_code(f(n, x)) == f.__name__ + '(n, x)' + for f in [airyai, airyaiprime, airybi, airybiprime]: + assert julia_code(f(x)) == f.__name__ + '(x)' + assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)' + assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)' + assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2' + assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(julia_code(A[0, 0]) == "A[1,1]") + assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]") + + F = C[0, 0].subs(C, A - B) + assert(julia_code(F) == "(A - B)[1,1]") diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_lambdarepr.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..94e09ada7a9ce7d01667edd8fc6ec35ebfbb9639 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_lambdarepr.py @@ -0,0 +1,246 @@ +from sympy.concrete.summations import Sum +from sympy.core.expr import Expr +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises + +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter + + +x, y, z = symbols("x,y,z") +i, a, b = symbols("i,a,b") +j, c, d = symbols("j,c,d") + + +def test_basic(): + assert lambdarepr(x*y) == "x*y" + assert lambdarepr(x + y) in ["y + x", "x + y"] + assert lambdarepr(x**y) == "x**y" + + +def test_matrix(): + # Test printing a Matrix that has an element that is printed differently + # with the LambdaPrinter than with the StrPrinter. + e = x % 2 + assert lambdarepr(e) != str(e) + assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])' + + +def test_piecewise(): + # In each case, test eval() the lambdarepr() to make sure there are a + # correct number of parentheses. It will give a SyntaxError if there aren't. + + h = "lambda x: " + + p = Piecewise((x, x < 0)) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 0) else None)" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (0, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else None)" + + p = Piecewise( + (x, x < 1), + (x**2, Interval(3, 4, True, False).contains(x)), + (0, True), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), + (0, True), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else None)" + + p = Piecewise( + (1, x >= 1), + (2, x >= 2), + (3, x >= 3), + (4, x >= 4), + (5, x >= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\ + " else (4) if (x >= 4) else (5) if (x >= 5) else (6))" + + p = Piecewise( + (1, x <= 1), + (2, x <= 2), + (3, x <= 3), + (4, x <= 4), + (5, x <= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\ + " else (4) if (x <= 4) else (5) if (x <= 5) else (6))" + + p = Piecewise( + (1, x > 1), + (2, x > 2), + (3, x > 3), + (4, x > 4), + (5, x > 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\ + " else (4) if (x > 4) else (5) if (x > 5) else (6))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (3, x < 3), + (4, x < 4), + (5, x < 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\ + " else (4) if (x < 4) else (5) if (x < 5) else (6))" + + p = Piecewise( + (Piecewise( + (1, x > 0), + (2, True) + ), y > 0), + (3, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))" + + +def test_sum__1(): + # In each case, test eval() the lambdarepr() to make sure that + # it evaluates to the same results as the symbolic expression + s = Sum(x ** i, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(x**i for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + +def test_sum__2(): + s = Sum(i * x, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(i*x for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + + +def test_multiple_sums(): + s = Sum(i * x + j, (i, a, b), (j, c, d)) + + l = lambdarepr(s) + assert l == "(builtins.sum(i*x + j for j in range(c, d+1) for i in range(a, b+1)))" + + args = x, a, b, c, d + f = lambdify(args, s) + vals = 2, 3, 4, 5, 6 + f_ref = s.subs(zip(args, vals)).doit() + f_res = f(*vals) + assert f_res == f_ref + + +def test_sqrt(): + prntr = LambdaPrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_settings(): + raises(TypeError, lambda: lambdarepr(sin(x), method="garbage")) + + +def test_numexpr(): + # test ITE rewrite as Piecewise + from sympy.logic.boolalg import ITE + expr = ITE(x > 0, True, False, evaluate=False) + assert NumExprPrinter().doprint(expr) == \ + "numexpr.evaluate('where((x > 0), True, False)', truediv=True)" + + from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment + func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)]) + expected = "def foo(x):\n"\ + " y = numexpr.evaluate('x', truediv=True)\n"\ + " return numexpr.evaluate('y**2', truediv=True)" + assert NumExprPrinter().doprint(func_def) == expected + + +class CustomPrintedObject(Expr): + def _lambdacode(self, printer): + return 'lambda' + + def _tensorflowcode(self, printer): + return 'tensorflow' + + def _numpycode(self, printer): + return 'numpy' + + def _numexprcode(self, printer): + return 'numexpr' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + # In each case, printmethod is called to test + # its working + + obj = CustomPrintedObject() + assert LambdaPrinter().doprint(obj) == 'lambda' + assert TensorflowPrinter().doprint(obj) == 'tensorflow' + assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)" + + assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \ + "numexpr.evaluate('where((x >= 0), y, z)', truediv=True)" diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_latex.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..063611d09a923881cd94bd693f3f3f721535fd0c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_latex.py @@ -0,0 +1,3164 @@ +from sympy import MatAdd, MatMul, Array +from sympy.algebras.quaternion import Quaternion +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.permutations import Cycle, Permutation, AppliedPermutation +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple, Dict +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import (Derivative, Function, Lambda, Subs, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core.numbers import (AlgebraicNumber, Float, I, Integer, Rational, oo, pi) +from sympy.core.parameters import evaluate +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial, factorial2, subfactorial) +from sympy.functions.combinatorial.numbers import (bernoulli, bell, catalan, euler, genocchi, + lucas, fibonacci, tribonacci, divisor_sigma, udivisor_sigma, + mobius, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, polar_lift, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, coth) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (Max, Min, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acsc, asin, cos, cot, sin, tan) +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.elliptic_integrals import (elliptic_e, elliptic_f, elliptic_k, elliptic_pi) +from sympy.functions.special.error_functions import (Chi, Ci, Ei, Shi, Si, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma) +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.functions.special.mathieu_functions import (mathieuc, mathieucprime, mathieus, mathieusprime) +from sympy.functions.special.polynomials import (assoc_laguerre, assoc_legendre, chebyshevt, chebyshevu, gegenbauer, hermite, jacobi, laguerre, legendre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.spherical_harmonics import (Ynm, Znm) +from sympy.functions.special.tensor_functions import (KroneckerDelta, LeviCivita) +from sympy.functions.special.zeta_functions import (dirichlet_eta, lerchphi, polylog, stieltjes, zeta) +from sympy.integrals.integrals import Integral +from sympy.integrals.transforms import (CosineTransform, FourierTransform, InverseCosineTransform, InverseFourierTransform, InverseLaplaceTransform, InverseMellinTransform, InverseSineTransform, LaplaceTransform, MellinTransform, SineTransform) +from sympy.logic import Implies +from sympy.logic.boolalg import (And, Or, Xor, Equivalent, false, Not, true) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.kronecker import KroneckerProduct +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.permutation import PermutationMatrix +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.physics.control.lti import TransferFunction, Series, Parallel, Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.quantum import Commutator, Operator +from sympy.physics.quantum.trace import Tr +from sympy.physics.units import meter, gibibyte, gram, microgram, second, milli, micro +from sympy.polys.domains.integerring import ZZ +from sympy.polys.fields import field +from sympy.polys.polytools import Poly +from sympy.polys.rings import ring +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import Order +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.conditionset import ConditionSet +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ComplexRegion, ImageSet, Range) +from sympy.sets.ordinals import Ordinal, OrdinalOmega, OmegaPower +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import (FiniteSet, Interval, Union, Intersection, Complement, SymmetricDifference, ProductSet) +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableSparseNDimArray, + MutableDenseNDimArray, + tensorproduct) +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement +from sympy.tensor.indexed import (Idx, Indexed, IndexedBase) +from sympy.tensor.toperators import PartialDerivative +from sympy.vector import CoordSys3D, Cross, Curl, Dot, Divergence, Gradient, Laplacian + + +from sympy.testing.pytest import (XFAIL, raises, _both_exp_pow, + warns_deprecated_sympy) +from sympy.printing.latex import (latex, translate, greek_letters_set, + tex_greek_dictionary, multiline_latex, + latex_escape, LatexPrinter) + +import sympy as sym + +from sympy.abc import mu, tau + + +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + + +x, y, z, t, w, a, b, c, s, p = symbols('x y z t w a b c s p') +k, m, n = symbols('k m n', integer=True) + + +def test_printmethod(): + class R(Abs): + def _latex(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert latex(R(x)) == r"foo(x)" + + class R(Abs): + def _latex(self, printer): + return "foo" + assert latex(R(x)) == r"foo" + + +def test_latex_basic(): + assert latex(1 + x) == r"x + 1" + assert latex(x**2) == r"x^{2}" + assert latex(x**(1 + x)) == r"x^{x + 1}" + assert latex(x**3 + x + 1 + x**2) == r"x^{3} + x^{2} + x + 1" + + assert latex(2*x*y) == r"2 x y" + assert latex(2*x*y, mul_symbol='dot') == r"2 \cdot x \cdot y" + assert latex(3*x**2*y, mul_symbol='\\,') == r"3\,x^{2}\,y" + assert latex(1.5*3**x, mul_symbol='\\,') == r"1.5 \cdot 3^{x}" + + assert latex(x**S.Half**5) == r"\sqrt[32]{x}" + assert latex(Mul(S.Half, x**2, -5, evaluate=False)) == r"\frac{1}{2} x^{2} \left(-5\right)" + assert latex(Mul(S.Half, x**2, 5, evaluate=False)) == r"\frac{1}{2} x^{2} \cdot 5" + assert latex(Mul(-5, -5, evaluate=False)) == r"\left(-5\right) \left(-5\right)" + assert latex(Mul(5, -5, evaluate=False)) == r"5 \left(-5\right)" + assert latex(Mul(S.Half, -5, S.Half, evaluate=False)) == r"\frac{1}{2} \left(-5\right) \frac{1}{2}" + assert latex(Mul(5, I, 5, evaluate=False)) == r"5 i 5" + assert latex(Mul(5, I, -5, evaluate=False)) == r"5 i \left(-5\right)" + assert latex(Mul(Pow(x, 2), S.Half*x + 1)) == r"x^{2} \left(\frac{x}{2} + 1\right)" + assert latex(Mul(Pow(x, 3), Rational(2, 3)*x + 1)) == r"x^{3} \left(\frac{2 x}{3} + 1\right)" + assert latex(Mul(Pow(x, 11), 2*x + 1)) == r"x^{11} \left(2 x + 1\right)" + + assert latex(Mul(0, 1, evaluate=False)) == r'0 \cdot 1' + assert latex(Mul(1, 0, evaluate=False)) == r'1 \cdot 0' + assert latex(Mul(1, 1, evaluate=False)) == r'1 \cdot 1' + assert latex(Mul(-1, 1, evaluate=False)) == r'\left(-1\right) 1' + assert latex(Mul(1, 1, 1, evaluate=False)) == r'1 \cdot 1 \cdot 1' + assert latex(Mul(1, 2, evaluate=False)) == r'1 \cdot 2' + assert latex(Mul(1, S.Half, evaluate=False)) == r'1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, S.Half, evaluate=False)) == \ + r'1 \cdot 1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, 2, 3, x, evaluate=False)) == \ + r'1 \cdot 1 \cdot 2 \cdot 3 x' + assert latex(Mul(1, -1, evaluate=False)) == r'1 \left(-1\right)' + assert latex(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \cdot 1 \cdot 0 y x' + assert latex(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \left(z + 1\right) 0 y x' + assert latex(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == \ + r'\frac{2}{3} \cdot \frac{5}{7}' + + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/x, fold_short_frac=True) == r"1 / x" + assert latex(-S(3)/2) == r"- \frac{3}{2}" + assert latex(-S(3)/2, fold_short_frac=True) == r"- 3 / 2" + assert latex(1/x**2) == r"\frac{1}{x^{2}}" + assert latex(1/(x + y)/2) == r"\frac{1}{2 \left(x + y\right)}" + assert latex(x/2) == r"\frac{x}{2}" + assert latex(x/2, fold_short_frac=True) == r"x / 2" + assert latex((x + y)/(2*x)) == r"\frac{x + y}{2 x}" + assert latex((x + y)/(2*x), fold_short_frac=True) == \ + r"\left(x + y\right) / 2 x" + assert latex((x + y)/(2*x), long_frac_ratio=0) == \ + r"\frac{1}{2 x} \left(x + y\right)" + assert latex((x + y)/x) == r"\frac{x + y}{x}" + assert latex((x + y)/x, long_frac_ratio=3) == r"\frac{x + y}{x}" + assert latex((2*sqrt(2)*x)/3) == r"\frac{2 \sqrt{2} x}{3}" + assert latex((2*sqrt(2)*x)/3, long_frac_ratio=2) == \ + r"\frac{2 x}{3} \sqrt{2}" + assert latex(binomial(x, y)) == r"{\binom{x}{y}}" + + x_star = Symbol('x^*') + f = Function('f') + assert latex(x_star**2) == r"\left(x^{*}\right)^{2}" + assert latex(x_star**2, parenthesize_super=False) == r"{x^{*}}^{2}" + assert latex(Derivative(f(x_star), x_star,2)) == r"\frac{d^{2}}{d \left(x^{*}\right)^{2}} f{\left(x^{*} \right)}" + assert latex(Derivative(f(x_star), x_star,2), parenthesize_super=False) == r"\frac{d^{2}}{d {x^{*}}^{2}} f{\left(x^{*} \right)}" + + assert latex(2*Integral(x, x)/3) == r"\frac{2 \int x\, dx}{3}" + assert latex(2*Integral(x, x)/3, fold_short_frac=True) == \ + r"\left(2 \int x\, dx\right) / 3" + + assert latex(sqrt(x)) == r"\sqrt{x}" + assert latex(x**Rational(1, 3)) == r"\sqrt[3]{x}" + assert latex(x**Rational(1, 3), root_notation=False) == r"x^{\frac{1}{3}}" + assert latex(sqrt(x)**3) == r"x^{\frac{3}{2}}" + assert latex(sqrt(x), itex=True) == r"\sqrt{x}" + assert latex(x**Rational(1, 3), itex=True) == r"\root{3}{x}" + assert latex(sqrt(x)**3, itex=True) == r"x^{\frac{3}{2}}" + assert latex(x**Rational(3, 4)) == r"x^{\frac{3}{4}}" + assert latex(x**Rational(3, 4), fold_frac_powers=True) == r"x^{3/4}" + assert latex((x + 1)**Rational(3, 4)) == \ + r"\left(x + 1\right)^{\frac{3}{4}}" + assert latex((x + 1)**Rational(3, 4), fold_frac_powers=True) == \ + r"\left(x + 1\right)^{3/4}" + assert latex(AlgebraicNumber(sqrt(2))) == r"\sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), [3, -7])) == r"-7 + 3 \sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), alias='alpha')) == r"\alpha" + assert latex(AlgebraicNumber(sqrt(2), [3, -7], alias='alpha')) == \ + r"3 \alpha - 7" + assert latex(AlgebraicNumber(2**(S(1)/3), [1, 3, -7], alias='beta')) == \ + r"\beta^{2} + 3 \beta - 7" + + k = ZZ.cyclotomic_field(5) + assert latex(k.ext.field_element([1, 2, 3, 4])) == \ + r"\zeta^{3} + 2 \zeta^{2} + 3 \zeta + 4" + assert latex(k.ext.field_element([1, 2, 3, 4]), order='old') == \ + r"4 + 3 \zeta + 2 \zeta^{2} + \zeta^{3}" + assert latex(k.primes_above(19)[0]) == \ + r"\left(19, \zeta^{2} + 5 \zeta + 1\right)" + assert latex(k.primes_above(19)[0], order='old') == \ + r"\left(19, 1 + 5 \zeta + \zeta^{2}\right)" + assert latex(k.primes_above(7)[0]) == r"\left(7\right)" + + assert latex(1.5e20*x) == r"1.5 \cdot 10^{20} x" + assert latex(1.5e20*x, mul_symbol='dot') == r"1.5 \cdot 10^{20} \cdot x" + assert latex(1.5e20*x, mul_symbol='times') == \ + r"1.5 \times 10^{20} \times x" + + assert latex(1/sin(x)) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**-1) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**Rational(3, 2)) == \ + r"\sin^{\frac{3}{2}}{\left(x \right)}" + assert latex(sin(x)**Rational(3, 2), fold_frac_powers=True) == \ + r"\sin^{3/2}{\left(x \right)}" + + assert latex(~x) == r"\neg x" + assert latex(x & y) == r"x \wedge y" + assert latex(x & y & z) == r"x \wedge y \wedge z" + assert latex(x | y) == r"x \vee y" + assert latex(x | y | z) == r"x \vee y \vee z" + assert latex((x & y) | z) == r"z \vee \left(x \wedge y\right)" + assert latex(Implies(x, y)) == r"x \Rightarrow y" + assert latex(~(x >> ~y)) == r"x \not\Rightarrow \neg y" + assert latex(Implies(Or(x,y), z)) == r"\left(x \vee y\right) \Rightarrow z" + assert latex(Implies(z, Or(x,y))) == r"z \Rightarrow \left(x \vee y\right)" + assert latex(~(x & y)) == r"\neg \left(x \wedge y\right)" + + assert latex(~x, symbol_names={x: "x_i"}) == r"\neg x_i" + assert latex(x & y, symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \wedge y_i" + assert latex(x & y & z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \wedge y_i \wedge z_i" + assert latex(x | y, symbol_names={x: "x_i", y: "y_i"}) == r"x_i \vee y_i" + assert latex(x | y | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \vee y_i \vee z_i" + assert latex((x & y) | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"z_i \vee \left(x_i \wedge y_i\right)" + assert latex(Implies(x, y), symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \Rightarrow y_i" + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r"\frac{1}{\frac{1}{3}}" + assert latex(Pow(Rational(1, 3), -2, evaluate=False)) == r"\frac{1}{(\frac{1}{3})^{2}}" + assert latex(Pow(Integer(1)/100, -1, evaluate=False)) == r"\frac{1}{\frac{1}{100}}" + + p = Symbol('p', positive=True) + assert latex(exp(-p)*log(p)) == r"e^{- p} \log{\left(p \right)}" + + assert latex(Pow(Rational(2, 3), -1, evaluate=False)) == r'\frac{1}{\frac{2}{3}}' + assert latex(Pow(Rational(4, 3), -1, evaluate=False)) == r'\frac{1}{\frac{4}{3}}' + assert latex(Pow(Rational(-3, 4), -1, evaluate=False)) == r'\frac{1}{- \frac{3}{4}}' + assert latex(Pow(Rational(-4, 4), -1, evaluate=False)) == r'\frac{1}{-1}' + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r'\frac{1}{\frac{1}{3}}' + assert latex(Pow(Rational(-1, 3), -1, evaluate=False)) == r'\frac{1}{- \frac{1}{3}}' + + +def test_latex_builtins(): + assert latex(True) == r"\text{True}" + assert latex(False) == r"\text{False}" + assert latex(None) == r"\text{None}" + assert latex(true) == r"\text{True}" + assert latex(false) == r'\text{False}' + + +def test_latex_SingularityFunction(): + assert latex(SingularityFunction(x, 4, 5)) == \ + r"{\left\langle x - 4 \right\rangle}^{5}" + assert latex(SingularityFunction(x, -3, 4)) == \ + r"{\left\langle x + 3 \right\rangle}^{4}" + assert latex(SingularityFunction(x, 0, 4)) == \ + r"{\left\langle x \right\rangle}^{4}" + assert latex(SingularityFunction(x, a, n)) == \ + r"{\left\langle - a + x \right\rangle}^{n}" + assert latex(SingularityFunction(x, 4, -2)) == \ + r"{\left\langle x - 4 \right\rangle}^{-2}" + assert latex(SingularityFunction(x, 4, -1)) == \ + r"{\left\langle x - 4 \right\rangle}^{-1}" + + assert latex(SingularityFunction(x, 4, 5)**3) == \ + r"{\left({\langle x - 4 \rangle}^{5}\right)}^{3}" + assert latex(SingularityFunction(x, -3, 4)**3) == \ + r"{\left({\langle x + 3 \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, 0, 4)**3) == \ + r"{\left({\langle x \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, a, n)**3) == \ + r"{\left({\langle - a + x \rangle}^{n}\right)}^{3}" + assert latex(SingularityFunction(x, 4, -2)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-2}\right)}^{3}" + assert latex((SingularityFunction(x, 4, -1)**3)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-1}\right)}^{9}" + + +def test_latex_cycle(): + assert latex(Cycle(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Cycle(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Cycle()) == r"\left( \right)" + + +def test_latex_permutation(): + assert latex(Permutation(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Permutation(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Permutation()) == r"\left( \right)" + assert latex(Permutation(2, 4)*Permutation(5)) == \ + r"\left( 2\; 4\right)\left( 5\right)" + assert latex(Permutation(5)) == r"\left( 5\right)" + + assert latex(Permutation(0, 1), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}" + assert latex(Permutation(0, 1)(2, 3), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + assert latex(Permutation(), perm_cyclic=False) == \ + r"\left( \right)" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert latex(Permutation(0, 1)(2, 3)) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + Permutation.print_cyclic = old_print_cyclic + +def test_latex_Float(): + assert latex(Float(1.0e100)) == r"1.0 \cdot 10^{100}" + assert latex(Float(1.0e-100)) == r"1.0 \cdot 10^{-100}" + assert latex(Float(1.0e-100), mul_symbol="times") == \ + r"1.0 \times 10^{-100}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=2) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=4) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=5) == \ + r"10000.0" + assert latex(Float('0.099999'), full_prec=True, min=-2, max=5) == \ + r"9.99990000000000 \cdot 10^{-2}" + + +def test_latex_vector_expressions(): + A = CoordSys3D('A') + + assert latex(Cross(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Cross(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}" + assert latex(x*Cross(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}\right)" + assert latex(Cross(x*A.i, A.j)) == \ + r'- \mathbf{\hat{j}_{A}} \times \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)' + + assert latex(Curl(3*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*A.x*A.j+A.i)) == \ + r"\nabla\times \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*x*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}} x\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Curl(3*A.x*A.j)) == \ + r"x \left(\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Divergence(3*A.x*A.j+A.i)) == \ + r"\nabla\cdot \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Divergence(3*A.x*A.j)) == \ + r"\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Divergence(3*A.x*A.j)) == \ + r"x \left(\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Dot(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Dot(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}" + assert latex(Dot(x*A.i, A.j)) == \ + r"\mathbf{\hat{j}_{A}} \cdot \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)" + assert latex(x*Dot(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}\right)" + + assert latex(Gradient(A.x)) == r"\nabla \mathbf{{x}_{A}}" + assert latex(Gradient(A.x + 3*A.y)) == \ + r"\nabla \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Gradient(A.x)) == r"x \left(\nabla \mathbf{{x}_{A}}\right)" + assert latex(Gradient(x*A.x)) == r"\nabla \left(\mathbf{{x}_{A}} x\right)" + + assert latex(Laplacian(A.x)) == r"\Delta \mathbf{{x}_{A}}" + assert latex(Laplacian(A.x + 3*A.y)) == \ + r"\Delta \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Laplacian(A.x)) == r"x \left(\Delta \mathbf{{x}_{A}}\right)" + assert latex(Laplacian(x*A.x)) == r"\Delta \left(\mathbf{{x}_{A}} x\right)" + +def test_latex_symbols(): + Gamma, lmbda, rho = symbols('Gamma, lambda, rho') + tau, Tau, TAU, taU = symbols('tau, Tau, TAU, taU') + assert latex(tau) == r"\tau" + assert latex(Tau) == r"\mathrm{T}" + assert latex(TAU) == r"\tau" + assert latex(taU) == r"\tau" + # Check that all capitalized greek letters are handled explicitly + capitalized_letters = {l.capitalize() for l in greek_letters_set} + assert len(capitalized_letters - set(tex_greek_dictionary.keys())) == 0 + assert latex(Gamma + lmbda) == r"\Gamma + \lambda" + assert latex(Gamma * lmbda) == r"\Gamma \lambda" + assert latex(Symbol('q1')) == r"q_{1}" + assert latex(Symbol('q21')) == r"q_{21}" + assert latex(Symbol('epsilon0')) == r"\epsilon_{0}" + assert latex(Symbol('omega1')) == r"\omega_{1}" + assert latex(Symbol('91')) == r"91" + assert latex(Symbol('alpha_new')) == r"\alpha_{new}" + assert latex(Symbol('C^orig')) == r"C^{orig}" + assert latex(Symbol('x^alpha')) == r"x^{\alpha}" + assert latex(Symbol('beta^alpha')) == r"\beta^{\alpha}" + assert latex(Symbol('e^Alpha')) == r"e^{\mathrm{A}}" + assert latex(Symbol('omega_alpha^beta')) == r"\omega^{\beta}_{\alpha}" + assert latex(Symbol('omega') ** Symbol('beta')) == r"\omega^{\beta}" + + +@XFAIL +def test_latex_symbols_failing(): + rho, mass, volume = symbols('rho, mass, volume') + assert latex( + volume * rho == mass) == r"\rho \mathrm{volume} = \mathrm{mass}" + assert latex(volume / mass * rho == 1) == \ + r"\rho \mathrm{volume} {\mathrm{mass}}^{(-1)} = 1" + assert latex(mass**3 * volume**3) == \ + r"{\mathrm{mass}}^{3} \cdot {\mathrm{volume}}^{3}" + + +@_both_exp_pow +def test_latex_functions(): + assert latex(exp(x)) == r"e^{x}" + assert latex(exp(1) + exp(2)) == r"e + e^{2}" + + f = Function('f') + assert latex(f(x)) == r'f{\left(x \right)}' + assert latex(f) == r'f' + + g = Function('g') + assert latex(g(x, y)) == r'g{\left(x,y \right)}' + assert latex(g) == r'g' + + h = Function('h') + assert latex(h(x, y, z)) == r'h{\left(x,y,z \right)}' + assert latex(h) == r'h' + + Li = Function('Li') + assert latex(Li) == r'\operatorname{Li}' + assert latex(Li(x)) == r'\operatorname{Li}{\left(x \right)}' + + mybeta = Function('beta') + # not to be confused with the beta function + assert latex(mybeta(x, y, z)) == r"\beta{\left(x,y,z \right)}" + assert latex(beta(x, y)) == r'\operatorname{B}\left(x, y\right)' + assert latex(beta(x, evaluate=False)) == r'\operatorname{B}\left(x, x\right)' + assert latex(beta(x, y)**2) == r'\operatorname{B}^{2}\left(x, y\right)' + assert latex(mybeta(x)) == r"\beta{\left(x \right)}" + assert latex(mybeta) == r"\beta" + + g = Function('gamma') + # not to be confused with the gamma function + assert latex(g(x, y, z)) == r"\gamma{\left(x,y,z \right)}" + assert latex(g(x)) == r"\gamma{\left(x \right)}" + assert latex(g) == r"\gamma" + + a_1 = Function('a_1') + assert latex(a_1) == r"a_{1}" + assert latex(a_1(x)) == r"a_{1}{\left(x \right)}" + assert latex(Function('a_1')) == r"a_{1}" + + # Issue #16925 + # multi letter function names + # > simple + assert latex(Function('ab')) == r"\operatorname{ab}" + assert latex(Function('ab1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab_12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_c')) == r"\operatorname{ab}_{c}" + assert latex(Function('ab_cd')) == r"\operatorname{ab}_{cd}" + # > with argument + assert latex(Function('ab')(Symbol('x'))) == r"\operatorname{ab}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))) == r"\operatorname{ab}_{12}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab_c')(Symbol('x'))) == r"\operatorname{ab}_{c}{\left(x \right)}" + assert latex(Function('ab_cd')(Symbol('x'))) == r"\operatorname{ab}_{cd}{\left(x \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('ab')()**2) == r"\operatorname{ab}^{2}{\left( \right)}" + assert latex(Function('ab1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab_1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab_12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab')(Symbol('x'))**2) == r"\operatorname{ab}^{2}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))**2) == r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab_12')(Symbol('x'))**2) == \ + r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + + # single letter function names + # > simple + assert latex(Function('a')) == r"a" + assert latex(Function('a1')) == r"a_{1}" + assert latex(Function('a12')) == r"a_{12}" + assert latex(Function('a_1')) == r"a_{1}" + assert latex(Function('a_12')) == r"a_{12}" + + # > with argument + assert latex(Function('a')()) == r"a{\left( \right)}" + assert latex(Function('a1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a12')()) == r"a_{12}{\left( \right)}" + assert latex(Function('a_1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a_12')()) == r"a_{12}{\left( \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('a')()**2) == r"a^{2}{\left( \right)}" + assert latex(Function('a1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a_1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a_12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**2) == r"a^{2}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + + assert latex(Function('a')()**32) == r"a^{32}{\left( \right)}" + assert latex(Function('a1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a_1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a_12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**32) == r"a^{32}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + + assert latex(Function('a')()**a) == r"a^{a}{\left( \right)}" + assert latex(Function('a1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a_1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a_12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**a) == r"a^{a}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + + ab = Symbol('ab') + assert latex(Function('a')()**ab) == r"a^{ab}{\left( \right)}" + assert latex(Function('a1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a_1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a_12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**ab) == r"a^{ab}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + + assert latex(Function('a^12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a^12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a__12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a__12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a_1__1_2')(x)) == R"a^{1}_{1 2}{\left(x \right)}" + + # issue 5868 + omega1 = Function('omega1') + assert latex(omega1) == r"\omega_{1}" + assert latex(omega1(x)) == r"\omega_{1}{\left(x \right)}" + + assert latex(sin(x)) == r"\sin{\left(x \right)}" + assert latex(sin(x), fold_func_brackets=True) == r"\sin {x}" + assert latex(sin(2*x**2), fold_func_brackets=True) == \ + r"\sin {2 x^{2}}" + assert latex(sin(x**2), fold_func_brackets=True) == \ + r"\sin {x^{2}}" + + assert latex(asin(x)**2) == r"\operatorname{asin}^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="full") == \ + r"\arcsin^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="power") == \ + r"\sin^{-1}{\left(x \right)}^{2}" + assert latex(asin(x**2), inv_trig_style="power", + fold_func_brackets=True) == \ + r"\sin^{-1} {x^{2}}" + assert latex(acsc(x), inv_trig_style="full") == \ + r"\operatorname{arccsc}{\left(x \right)}" + assert latex(asinh(x), inv_trig_style="full") == \ + r"\operatorname{arsinh}{\left(x \right)}" + + assert latex(factorial(k)) == r"k!" + assert latex(factorial(-k)) == r"\left(- k\right)!" + assert latex(factorial(k)**2) == r"k!^{2}" + + assert latex(subfactorial(k)) == r"!k" + assert latex(subfactorial(-k)) == r"!\left(- k\right)" + assert latex(subfactorial(k)**2) == r"\left(!k\right)^{2}" + + assert latex(factorial2(k)) == r"k!!" + assert latex(factorial2(-k)) == r"\left(- k\right)!!" + assert latex(factorial2(k)**2) == r"k!!^{2}" + + assert latex(binomial(2, k)) == r"{\binom{2}{k}}" + assert latex(binomial(2, k)**2) == r"{\binom{2}{k}}^{2}" + + assert latex(FallingFactorial(3, k)) == r"{\left(3\right)}_{k}" + assert latex(RisingFactorial(3, k)) == r"{3}^{\left(k\right)}" + + assert latex(floor(x)) == r"\left\lfloor{x}\right\rfloor" + assert latex(ceiling(x)) == r"\left\lceil{x}\right\rceil" + assert latex(frac(x)) == r"\operatorname{frac}{\left(x\right)}" + assert latex(floor(x)**2) == r"\left\lfloor{x}\right\rfloor^{2}" + assert latex(ceiling(x)**2) == r"\left\lceil{x}\right\rceil^{2}" + assert latex(frac(x)**2) == r"\operatorname{frac}{\left(x\right)}^{2}" + + assert latex(Min(x, 2, x**3)) == r"\min\left(2, x, x^{3}\right)" + assert latex(Min(x, y)**2) == r"\min\left(x, y\right)^{2}" + assert latex(Max(x, 2, x**3)) == r"\max\left(2, x, x^{3}\right)" + assert latex(Max(x, y)**2) == r"\max\left(x, y\right)^{2}" + assert latex(Abs(x)) == r"\left|{x}\right|" + assert latex(Abs(x)**2) == r"\left|{x}\right|^{2}" + assert latex(re(x)) == r"\operatorname{re}{\left(x\right)}" + assert latex(re(x + y)) == \ + r"\operatorname{re}{\left(x\right)} + \operatorname{re}{\left(y\right)}" + assert latex(im(x)) == r"\operatorname{im}{\left(x\right)}" + assert latex(conjugate(x)) == r"\overline{x}" + assert latex(conjugate(x)**2) == r"\overline{x}^{2}" + assert latex(conjugate(x**2)) == r"\overline{x}^{2}" + assert latex(gamma(x)) == r"\Gamma\left(x\right)" + w = Wild('w') + assert latex(gamma(w)) == r"\Gamma\left(w\right)" + assert latex(Order(x)) == r"O\left(x\right)" + assert latex(Order(x, x)) == r"O\left(x\right)" + assert latex(Order(x, (x, 0))) == r"O\left(x\right)" + assert latex(Order(x, (x, oo))) == r"O\left(x; x\rightarrow \infty\right)" + assert latex(Order(x - y, (x, y))) == \ + r"O\left(x - y; x\rightarrow y\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, (x, oo), (y, oo))) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( \infty, \ \infty\right)\right)" + assert latex(lowergamma(x, y)) == r'\gamma\left(x, y\right)' + assert latex(lowergamma(x, y)**2) == r'\gamma^{2}\left(x, y\right)' + assert latex(uppergamma(x, y)) == r'\Gamma\left(x, y\right)' + assert latex(uppergamma(x, y)**2) == r'\Gamma^{2}\left(x, y\right)' + + assert latex(cot(x)) == r'\cot{\left(x \right)}' + assert latex(coth(x)) == r'\coth{\left(x \right)}' + assert latex(re(x)) == r'\operatorname{re}{\left(x\right)}' + assert latex(im(x)) == r'\operatorname{im}{\left(x\right)}' + assert latex(root(x, y)) == r'x^{\frac{1}{y}}' + assert latex(arg(x)) == r'\arg{\left(x \right)}' + + assert latex(zeta(x)) == r"\zeta\left(x\right)" + assert latex(zeta(x)**2) == r"\zeta^{2}\left(x\right)" + assert latex(zeta(x, y)) == r"\zeta\left(x, y\right)" + assert latex(zeta(x, y)**2) == r"\zeta^{2}\left(x, y\right)" + assert latex(dirichlet_eta(x)) == r"\eta\left(x\right)" + assert latex(dirichlet_eta(x)**2) == r"\eta^{2}\left(x\right)" + assert latex(polylog(x, y)) == r"\operatorname{Li}_{x}\left(y\right)" + assert latex( + polylog(x, y)**2) == r"\operatorname{Li}_{x}^{2}\left(y\right)" + assert latex(lerchphi(x, y, n)) == r"\Phi\left(x, y, n\right)" + assert latex(lerchphi(x, y, n)**2) == r"\Phi^{2}\left(x, y, n\right)" + assert latex(stieltjes(x)) == r"\gamma_{x}" + assert latex(stieltjes(x)**2) == r"\gamma_{x}^{2}" + assert latex(stieltjes(x, y)) == r"\gamma_{x}\left(y\right)" + assert latex(stieltjes(x, y)**2) == r"\gamma_{x}\left(y\right)^{2}" + + assert latex(elliptic_k(z)) == r"K\left(z\right)" + assert latex(elliptic_k(z)**2) == r"K^{2}\left(z\right)" + assert latex(elliptic_f(x, y)) == r"F\left(x\middle| y\right)" + assert latex(elliptic_f(x, y)**2) == r"F^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)) == r"E\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)**2) == r"E^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(z)) == r"E\left(z\right)" + assert latex(elliptic_e(z)**2) == r"E^{2}\left(z\right)" + assert latex(elliptic_pi(x, y, z)) == r"\Pi\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y, z)**2) == \ + r"\Pi^{2}\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y)) == r"\Pi\left(x\middle| y\right)" + assert latex(elliptic_pi(x, y)**2) == r"\Pi^{2}\left(x\middle| y\right)" + + assert latex(Ei(x)) == r'\operatorname{Ei}{\left(x \right)}' + assert latex(Ei(x)**2) == r'\operatorname{Ei}^{2}{\left(x \right)}' + assert latex(expint(x, y)) == r'\operatorname{E}_{x}\left(y\right)' + assert latex(expint(x, y)**2) == r'\operatorname{E}_{x}^{2}\left(y\right)' + assert latex(Shi(x)**2) == r'\operatorname{Shi}^{2}{\left(x \right)}' + assert latex(Si(x)**2) == r'\operatorname{Si}^{2}{\left(x \right)}' + assert latex(Ci(x)**2) == r'\operatorname{Ci}^{2}{\left(x \right)}' + assert latex(Chi(x)**2) == r'\operatorname{Chi}^{2}\left(x\right)' + assert latex(Chi(x)) == r'\operatorname{Chi}\left(x\right)' + assert latex(jacobi(n, a, b, x)) == \ + r'P_{n}^{\left(a,b\right)}\left(x\right)' + assert latex(jacobi(n, a, b, x)**2) == \ + r'\left(P_{n}^{\left(a,b\right)}\left(x\right)\right)^{2}' + assert latex(gegenbauer(n, a, x)) == \ + r'C_{n}^{\left(a\right)}\left(x\right)' + assert latex(gegenbauer(n, a, x)**2) == \ + r'\left(C_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(chebyshevt(n, x)) == r'T_{n}\left(x\right)' + assert latex(chebyshevt(n, x)**2) == \ + r'\left(T_{n}\left(x\right)\right)^{2}' + assert latex(chebyshevu(n, x)) == r'U_{n}\left(x\right)' + assert latex(chebyshevu(n, x)**2) == \ + r'\left(U_{n}\left(x\right)\right)^{2}' + assert latex(legendre(n, x)) == r'P_{n}\left(x\right)' + assert latex(legendre(n, x)**2) == r'\left(P_{n}\left(x\right)\right)^{2}' + assert latex(assoc_legendre(n, a, x)) == \ + r'P_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_legendre(n, a, x)**2) == \ + r'\left(P_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(laguerre(n, x)) == r'L_{n}\left(x\right)' + assert latex(laguerre(n, x)**2) == r'\left(L_{n}\left(x\right)\right)^{2}' + assert latex(assoc_laguerre(n, a, x)) == \ + r'L_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_laguerre(n, a, x)**2) == \ + r'\left(L_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(hermite(n, x)) == r'H_{n}\left(x\right)' + assert latex(hermite(n, x)**2) == r'\left(H_{n}\left(x\right)\right)^{2}' + + theta = Symbol("theta", real=True) + phi = Symbol("phi", real=True) + assert latex(Ynm(n, m, theta, phi)) == r'Y_{n}^{m}\left(\theta,\phi\right)' + assert latex(Ynm(n, m, theta, phi)**3) == \ + r'\left(Y_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + assert latex(Znm(n, m, theta, phi)) == r'Z_{n}^{m}\left(\theta,\phi\right)' + assert latex(Znm(n, m, theta, phi)**3) == \ + r'\left(Z_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + + # Test latex printing of function names with "_" + assert latex(polar_lift(0)) == \ + r"\operatorname{polar\_lift}{\left(0 \right)}" + assert latex(polar_lift(0)**3) == \ + r"\operatorname{polar\_lift}^{3}{\left(0 \right)}" + + assert latex(totient(n)) == r'\phi\left(n\right)' + assert latex(totient(n) ** 2) == r'\left(\phi\left(n\right)\right)^{2}' + + assert latex(reduced_totient(n)) == r'\lambda\left(n\right)' + assert latex(reduced_totient(n) ** 2) == \ + r'\left(\lambda\left(n\right)\right)^{2}' + + assert latex(divisor_sigma(x)) == r"\sigma\left(x\right)" + assert latex(divisor_sigma(x)**2) == r"\sigma^{2}\left(x\right)" + assert latex(divisor_sigma(x, y)) == r"\sigma_y\left(x\right)" + assert latex(divisor_sigma(x, y)**2) == r"\sigma^{2}_y\left(x\right)" + + assert latex(udivisor_sigma(x)) == r"\sigma^*\left(x\right)" + assert latex(udivisor_sigma(x)**2) == r"\sigma^*^{2}\left(x\right)" + assert latex(udivisor_sigma(x, y)) == r"\sigma^*_y\left(x\right)" + assert latex(udivisor_sigma(x, y)**2) == r"\sigma^*^{2}_y\left(x\right)" + + assert latex(primenu(n)) == r'\nu\left(n\right)' + assert latex(primenu(n) ** 2) == r'\left(\nu\left(n\right)\right)^{2}' + + assert latex(primeomega(n)) == r'\Omega\left(n\right)' + assert latex(primeomega(n) ** 2) == \ + r'\left(\Omega\left(n\right)\right)^{2}' + + assert latex(LambertW(n)) == r'W\left(n\right)' + assert latex(LambertW(n, -1)) == r'W_{-1}\left(n\right)' + assert latex(LambertW(n, k)) == r'W_{k}\left(n\right)' + assert latex(LambertW(n) * LambertW(n)) == r"W^{2}\left(n\right)" + assert latex(Pow(LambertW(n), 2)) == r"W^{2}\left(n\right)" + assert latex(LambertW(n)**k) == r"W^{k}\left(n\right)" + assert latex(LambertW(n, k)**p) == r"W^{p}_{k}\left(n\right)" + + assert latex(Mod(x, 7)) == r'x \bmod 7' + assert latex(Mod(x + 1, 7)) == r'\left(x + 1\right) \bmod 7' + assert latex(Mod(7, x + 1)) == r'7 \bmod \left(x + 1\right)' + assert latex(Mod(2 * x, 7)) == r'2 x \bmod 7' + assert latex(Mod(7, 2 * x)) == r'7 \bmod 2 x' + assert latex(Mod(x, 7) + 1) == r'\left(x \bmod 7\right) + 1' + assert latex(2 * Mod(x, 7)) == r'2 \left(x \bmod 7\right)' + assert latex(Mod(7, 2 * x)**n) == r'\left(7 \bmod 2 x\right)^{n}' + + # some unknown function name should get rendered with \operatorname + fjlkd = Function('fjlkd') + assert latex(fjlkd(x)) == r'\operatorname{fjlkd}{\left(x \right)}' + # even when it is referred to without an argument + assert latex(fjlkd) == r'\operatorname{fjlkd}' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert latex(mygamma) == r"\operatorname{mygamma}" + assert latex(mygamma(x)) == r"\operatorname{mygamma}{\left(x \right)}" + + +def test_hyper_printing(): + from sympy.abc import x, z + + assert latex(meijerg(Tuple(pi, pi, x), Tuple(1), + (0, 1), Tuple(1, 2, 3/pi), z)) == \ + r'{G_{4, 5}^{2, 3}\left(\begin{matrix} \pi, \pi, x & 1 \\0, 1 & 1, 2, '\ + r'\frac{3}{\pi} \end{matrix} \middle| {z} \right)}' + assert latex(meijerg(Tuple(), Tuple(1), (0,), Tuple(), z)) == \ + r'{G_{1, 1}^{1, 0}\left(\begin{matrix} & 1 \\0 & \end{matrix} \middle| {z} \right)}' + assert latex(hyper((x, 2), (3,), z)) == \ + r'{{}_{2}F_{1}\left(\begin{matrix} 2, x ' \ + r'\\ 3 \end{matrix}\middle| {z} \right)}' + assert latex(hyper(Tuple(), Tuple(1), z)) == \ + r'{{}_{0}F_{1}\left(\begin{matrix} ' \ + r'\\ 1 \end{matrix}\middle| {z} \right)}' + + +def test_latex_bessel(): + from sympy.functions.special.bessel import (besselj, bessely, besseli, + besselk, hankel1, hankel2, + jn, yn, hn1, hn2) + from sympy.abc import z + assert latex(besselj(n, z**2)**k) == r'J^{k}_{n}\left(z^{2}\right)' + assert latex(bessely(n, z)) == r'Y_{n}\left(z\right)' + assert latex(besseli(n, z)) == r'I_{n}\left(z\right)' + assert latex(besselk(n, z)) == r'K_{n}\left(z\right)' + assert latex(hankel1(n, z**2)**2) == \ + r'\left(H^{(1)}_{n}\left(z^{2}\right)\right)^{2}' + assert latex(hankel2(n, z)) == r'H^{(2)}_{n}\left(z\right)' + assert latex(jn(n, z)) == r'j_{n}\left(z\right)' + assert latex(yn(n, z)) == r'y_{n}\left(z\right)' + assert latex(hn1(n, z)) == r'h^{(1)}_{n}\left(z\right)' + assert latex(hn2(n, z)) == r'h^{(2)}_{n}\left(z\right)' + + +def test_latex_fresnel(): + from sympy.functions.special.error_functions import (fresnels, fresnelc) + from sympy.abc import z + assert latex(fresnels(z)) == r'S\left(z\right)' + assert latex(fresnelc(z)) == r'C\left(z\right)' + assert latex(fresnels(z)**2) == r'S^{2}\left(z\right)' + assert latex(fresnelc(z)**2) == r'C^{2}\left(z\right)' + + +def test_latex_brackets(): + assert latex((-1)**x) == r"\left(-1\right)^{x}" + + +def test_latex_indexed(): + Psi_symbol = Symbol('Psi_0', complex=True, real=False) + Psi_indexed = IndexedBase(Symbol('Psi', complex=True, real=False)) + symbol_latex = latex(Psi_symbol * conjugate(Psi_symbol)) + indexed_latex = latex(Psi_indexed[0] * conjugate(Psi_indexed[0])) + # \\overline{{\\Psi}_{0}} {\\Psi}_{0} vs. \\Psi_{0} \\overline{\\Psi_{0}} + assert symbol_latex == r'\Psi_{0} \overline{\Psi_{0}}' + assert indexed_latex == r'\overline{{\Psi}_{0}} {\Psi}_{0}' + + # Symbol('gamma') gives r'\gamma' + interval = '\\mathrel{..}\\nobreak ' + assert latex(Indexed('x1', Symbol('i'))) == r'{x_{1}}_{i}' + assert latex(Indexed('x2', Idx('i'))) == r'{x_{2}}_{i}' + assert latex(Indexed('x3', Idx('i', Symbol('N')))) == r'{x_{3}}_{{i}_{0'+interval+'N - 1}}' + assert latex(Indexed('x3', Idx('i', Symbol('N')+1))) == r'{x_{3}}_{{i}_{0'+interval+'N}}' + assert latex(Indexed('x4', Idx('i', (Symbol('a'),Symbol('b'))))) == r'{x_{4}}_{{i}_{a'+interval+'b}}' + assert latex(IndexedBase('gamma')) == r'\gamma' + assert latex(IndexedBase('a b')) == r'a b' + assert latex(IndexedBase('a_b')) == r'a_{b}' + + +def test_latex_derivatives(): + # regular "d" for ordinary derivatives + assert latex(diff(x**3, x, evaluate=False)) == \ + r"\frac{d}{d x} x^{3}" + assert latex(diff(sin(x) + x**2, x, evaluate=False)) == \ + r"\frac{d}{d x} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False))\ + == \ + r"\frac{d^{2}}{d x^{2}} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False), evaluate=False)) == \ + r"\frac{d^{3}}{d x^{3}} \left(x^{2} + \sin{\left(x \right)}\right)" + + # \partial for partial derivatives + assert latex(diff(sin(x * y), x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \sin{\left(x y \right)}" + assert latex(diff(sin(x * y) + x**2, x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial x^{2}} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial x^{3}} \left(x^{2} + \sin{\left(x y \right)}\right)" + + # mixed partial derivatives + f = Function("f") + assert latex(diff(diff(f(x, y), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial y\partial x} " + latex(f(x, y)) + + assert latex(diff(diff(diff(f(x, y), x, evaluate=False), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial y\partial x^{2}} " + latex(f(x, y)) + + # for negative nested Derivative + assert latex(diff(-diff(y**2,x,evaluate=False),x,evaluate=False)) == r'\frac{d}{d x} \left(- \frac{d}{d x} y^{2}\right)' + assert latex(diff(diff(-diff(diff(y,x,evaluate=False),x,evaluate=False),x,evaluate=False),x,evaluate=False)) == \ + r'\frac{d^{2}}{d x^{2}} \left(- \frac{d^{2}}{d x^{2}} y\right)' + + # use ordinary d when one of the variables has been integrated out + assert latex(diff(Integral(exp(-x*y), (x, 0, oo)), y, evaluate=False)) == \ + r"\frac{d}{d y} \int\limits_{0}^{\infty} e^{- x y}\, dx" + + # Derivative wrapped in power: + assert latex(diff(x, x, evaluate=False)**2) == \ + r"\left(\frac{d}{d x} x\right)^{2}" + + assert latex(diff(f(x), x)**2) == \ + r"\left(\frac{d}{d x} f{\left(x \right)}\right)^{2}" + + assert latex(diff(f(x), (x, n))) == \ + r"\frac{d^{n}}{d x^{n}} f{\left(x \right)}" + + x1 = Symbol('x1') + x2 = Symbol('x2') + assert latex(diff(f(x1, x2), x1)) == r'\frac{\partial}{\partial x_{1}} f{\left(x_{1},x_{2} \right)}' + + n1 = Symbol('n1') + assert latex(diff(f(x), (x, n1))) == r'\frac{d^{n_{1}}}{d x^{n_{1}}} f{\left(x \right)}' + + n2 = Symbol('n2') + assert latex(diff(f(x), (x, Max(n1, n2)))) == \ + r'\frac{d^{\max\left(n_{1}, n_{2}\right)}}{d x^{\max\left(n_{1}, n_{2}\right)}} f{\left(x \right)}' + + # set diff operator + assert latex(diff(f(x), x), diff_operator="rd") == r'\frac{\mathrm{d}}{\mathrm{d} x} f{\left(x \right)}' + + +def test_latex_subs(): + assert latex(Subs(x*y, (x, y), (1, 2))) == r'\left. x y \right|_{\substack{ x=1\\ y=2 }}' + + +def test_latex_integrals(): + assert latex(Integral(log(x), x)) == r"\int \log{\left(x \right)}\, dx" + assert latex(Integral(x**2, (x, 0, 1))) == \ + r"\int\limits_{0}^{1} x^{2}\, dx" + assert latex(Integral(x**2, (x, 10, 20))) == \ + r"\int\limits_{10}^{20} x^{2}\, dx" + assert latex(Integral(y*x**2, (x, 0, 1), y)) == \ + r"\int\int\limits_{0}^{1} x^{2} y\, dx\, dy" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*') == \ + r"\begin{equation*}\int\int\limits_{0}^{1} x^{2} y\, dx\, dy\end{equation*}" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*', itex=True) \ + == r"$$\int\int_{0}^{1} x^{2} y\, dx\, dy$$" + assert latex(Integral(x, (x, 0))) == r"\int\limits^{0} x\, dx" + assert latex(Integral(x*y, x, y)) == r"\iint x y\, dx\, dy" + assert latex(Integral(x*y*z, x, y, z)) == r"\iiint x y z\, dx\, dy\, dz" + assert latex(Integral(x*y*z*t, x, y, z, t)) == \ + r"\iiiint t x y z\, dx\, dy\, dz\, dt" + assert latex(Integral(x, x, x, x, x, x, x)) == \ + r"\int\int\int\int\int\int x\, dx\, dx\, dx\, dx\, dx\, dx" + assert latex(Integral(x, x, y, (z, 0, 1))) == \ + r"\int\limits_{0}^{1}\int\int x\, dx\, dy\, dz" + + # for negative nested Integral + assert latex(Integral(-Integral(y**2,x),x)) == \ + r'\int \left(- \int y^{2}\, dx\right)\, dx' + assert latex(Integral(-Integral(-Integral(y,x),x),x)) == \ + r'\int \left(- \int \left(- \int y\, dx\right)\, dx\right)\, dx' + + # fix issue #10806 + assert latex(Integral(z, z)**2) == r"\left(\int z\, dz\right)^{2}" + assert latex(Integral(x + z, z)) == r"\int \left(x + z\right)\, dz" + assert latex(Integral(x+z/2, z)) == \ + r"\int \left(x + \frac{z}{2}\right)\, dz" + assert latex(Integral(x**y, z)) == r"\int x^{y}\, dz" + + # set diff operator + assert latex(Integral(x, x), diff_operator="rd") == r'\int x\, \mathrm{d}x' + assert latex(Integral(x, (x, 0, 1)), diff_operator="rd") == r'\int\limits_{0}^{1} x\, \mathrm{d}x' + + +def test_latex_sets(): + for s in (frozenset, set): + assert latex(s([x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + s = FiniteSet + assert latex(s(*[x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(*range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(*range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + +def test_latex_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + assert latex(se) == r"SetExpr\left(\left[1, 3\right]\right)" + + +def test_latex_Range(): + assert latex(Range(1, 51)) == r'\left\{1, 2, \ldots, 50\right\}' + assert latex(Range(1, 4)) == r'\left\{1, 2, 3\right\}' + assert latex(Range(0, 3, 1)) == r'\left\{0, 1, 2\right\}' + assert latex(Range(0, 30, 1)) == r'\left\{0, 1, \ldots, 29\right\}' + assert latex(Range(30, 1, -1)) == r'\left\{30, 29, \ldots, 2\right\}' + assert latex(Range(0, oo, 2)) == r'\left\{0, 2, \ldots\right\}' + assert latex(Range(oo, -2, -2)) == r'\left\{\ldots, 2, 0\right\}' + assert latex(Range(-2, -oo, -1)) == r'\left\{-2, -3, \ldots\right\}' + assert latex(Range(-oo, oo)) == r'\left\{\ldots, -1, 0, 1, \ldots\right\}' + assert latex(Range(oo, -oo, -1)) == r'\left\{\ldots, 1, 0, -1, \ldots\right\}' + + a, b, c = symbols('a:c') + assert latex(Range(a, b, c)) == r'\text{Range}\left(a, b, c\right)' + assert latex(Range(a, 10, 1)) == r'\text{Range}\left(a, 10\right)' + assert latex(Range(0, b, 1)) == r'\text{Range}\left(b\right)' + assert latex(Range(0, 10, c)) == r'\text{Range}\left(0, 10, c\right)' + + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert latex(Range(i, i + 3)) == r'\left\{i, i + 1, i + 2\right\}' + assert latex(Range(-oo, n, 2)) == r'\left\{\ldots, n - 4, n - 2\right\}' + assert latex(Range(p, oo)) == r'\left\{p, p + 1, \ldots\right\}' + # The following will work if __iter__ is improved + # assert latex(Range(-3, p + 7)) == r'\left\{-3, -2, \ldots, p + 6\right\}' + # Must have integer assumptions + assert latex(Range(a, a + 3)) == r'\text{Range}\left(a, a + 3\right)' + + +def test_latex_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + latex_str = r'\left[0, 1, 4, 9, \ldots\right]' + assert latex(s1) == latex_str + + latex_str = r'\left[1, 2, 1, 2, \ldots\right]' + assert latex(s2) == latex_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + latex_str = r'\left[0, 1, 4\right]' + assert latex(s3) == latex_str + + latex_str = r'\left[1, 2, 1\right]' + assert latex(s4) == latex_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + latex_str = r'\left[\ldots, 9, 4, 1, 0\right]' + assert latex(s5) == latex_str + + latex_str = r'\left[\ldots, 2, 1, 2, 1\right]' + assert latex(s6) == latex_str + + latex_str = r'\left[1, 3, 5, 11, \ldots\right]' + assert latex(SeqAdd(s1, s2)) == latex_str + + latex_str = r'\left[1, 3, 5\right]' + assert latex(SeqAdd(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 11, 5, 3, 1\right]' + assert latex(SeqAdd(s5, s6)) == latex_str + + latex_str = r'\left[0, 2, 4, 18, \ldots\right]' + assert latex(SeqMul(s1, s2)) == latex_str + + latex_str = r'\left[0, 2, 4\right]' + assert latex(SeqMul(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 18, 4, 2, 0\right]' + assert latex(SeqMul(s5, s6)) == latex_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + latex_str = r'\left\{a^{2}\right\}_{a=0}^{x}' + assert latex(s7) == latex_str + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + latex_str = r'\left[0, b, 4 b\right]' + assert latex(s8) == latex_str + + +def test_latex_FourierSeries(): + latex_str = \ + r'2 \sin{\left(x \right)} - \sin{\left(2 x \right)} + \frac{2 \sin{\left(3 x \right)}}{3} + \ldots' + assert latex(fourier_series(x, (x, -pi, pi))) == latex_str + + +def test_latex_FormalPowerSeries(): + latex_str = r'\sum_{k=1}^{\infty} - \frac{\left(-1\right)^{- k} x^{k}}{k}' + assert latex(fps(log(1 + x))) == latex_str + + +def test_latex_intervals(): + a = Symbol('a', real=True) + assert latex(Interval(0, 0)) == r"\left\{0\right\}" + assert latex(Interval(0, a)) == r"\left[0, a\right]" + assert latex(Interval(0, a, False, False)) == r"\left[0, a\right]" + assert latex(Interval(0, a, True, False)) == r"\left(0, a\right]" + assert latex(Interval(0, a, False, True)) == r"\left[0, a\right)" + assert latex(Interval(0, a, True, True)) == r"\left(0, a\right)" + + +def test_latex_AccumuBounds(): + a = Symbol('a', real=True) + assert latex(AccumBounds(0, 1)) == r"\left\langle 0, 1\right\rangle" + assert latex(AccumBounds(0, a)) == r"\left\langle 0, a\right\rangle" + assert latex(AccumBounds(a + 1, a + 2)) == \ + r"\left\langle a + 1, a + 2\right\rangle" + + +def test_latex_emptyset(): + assert latex(S.EmptySet) == r"\emptyset" + + +def test_latex_universalset(): + assert latex(S.UniversalSet) == r"\mathbb{U}" + + +def test_latex_commutator(): + A = Operator('A') + B = Operator('B') + comm = Commutator(B, A) + assert latex(comm.doit()) == r"- (A B - B A)" + + +def test_latex_union(): + assert latex(Union(Interval(0, 1), Interval(2, 3))) == \ + r"\left[0, 1\right] \cup \left[2, 3\right]" + assert latex(Union(Interval(1, 1), Interval(2, 2), Interval(3, 4))) == \ + r"\left\{1, 2\right\} \cup \left[3, 4\right]" + + +def test_latex_intersection(): + assert latex(Intersection(Interval(0, 1), Interval(x, y))) == \ + r"\left[0, 1\right] \cap \left[x, y\right]" + + +def test_latex_symmetric_difference(): + assert latex(SymmetricDifference(Interval(2, 5), Interval(4, 7), + evaluate=False)) == \ + r'\left[2, 5\right] \triangle \left[4, 7\right]' + + +def test_latex_Complement(): + assert latex(Complement(S.Reals, S.Naturals)) == \ + r"\mathbb{R} \setminus \mathbb{N}" + + +def test_latex_productset(): + line = Interval(0, 1) + bigline = Interval(0, 10) + fset = FiniteSet(1, 2, 3) + assert latex(line**2) == r"%s^{2}" % latex(line) + assert latex(line**10) == r"%s^{10}" % latex(line) + assert latex((line * bigline * fset).flatten()) == r"%s \times %s \times %s" % ( + latex(line), latex(bigline), latex(fset)) + + +def test_latex_powerset(): + fset = FiniteSet(1, 2, 3) + assert latex(PowerSet(fset)) == r'\mathcal{P}\left(\left\{1, 2, 3\right\}\right)' + + +def test_latex_ordinals(): + w = OrdinalOmega() + assert latex(w) == r"\omega" + wp = OmegaPower(2, 3) + assert latex(wp) == r'3 \omega^{2}' + assert latex(Ordinal(wp, OmegaPower(1, 1))) == r'3 \omega^{2} + \omega' + assert latex(Ordinal(OmegaPower(2, 1), OmegaPower(1, 2))) == r'\omega^{2} + 2 \omega' + + +def test_set_operators_parenthesis(): + a, b, c, d = symbols('a:d') + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(A, B, evaluate=False) + D2 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert latex(Intersection(A, U2, evaluate=False)) == \ + r'\left\{a\right\} \cap ' \ + r'\left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Intersection(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Intersection(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Union(A, I2, evaluate=False)) == \ + r'\left\{a\right\} \cup ' \ + r'\left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Union(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Union(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Complement(A, C2, evaluate=False)) == \ + r'\left\{a\right\} \setminus \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Complement(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \setminus ' \ + r'\left(\left\{c\right\} \triangle \left\{d\right\}\right)' + assert latex(Complement(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) '\ + r'\setminus \left(\left\{c\right\} \times '\ + r'\left\{d\right\}\right)' + + assert latex(SymmetricDifference(A, D2, evaluate=False)) == \ + r'\left\{a\right\} \triangle \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(SymmetricDifference(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \triangle ' \ + r'\left(\left\{c\right\} \setminus \left\{d\right\}\right)' + assert latex(SymmetricDifference(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + # XXX This can be incorrect since cartesian product is not associative + assert latex(ProductSet(A, P2).flatten()) == \ + r'\left\{a\right\} \times \left\{c\right\} \times ' \ + r'\left\{d\right\}' + assert latex(ProductSet(U1, U2)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(I1, I2)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(C1, C2)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(ProductSet(D1, D2)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + + +def test_latex_Complexes(): + assert latex(S.Complexes) == r"\mathbb{C}" + + +def test_latex_Naturals(): + assert latex(S.Naturals) == r"\mathbb{N}" + + +def test_latex_Naturals0(): + assert latex(S.Naturals0) == r"\mathbb{N}_0" + + +def test_latex_Integers(): + assert latex(S.Integers) == r"\mathbb{Z}" + + +def test_latex_ImageSet(): + x = Symbol('x') + assert latex(ImageSet(Lambda(x, x**2), S.Naturals)) == \ + r"\left\{x^{2}\; \middle|\; x \in \mathbb{N}\right\}" + + y = Symbol('y') + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; x \in \left\{1, 2, 3\right\}, y \in \left\{3, 4\right\}\right\}" + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; \left( x, \ y\right) \in \left\{1, 2, 3\right\} \times \left\{3, 4\right\}\right\}" + + +def test_latex_ConditionSet(): + x = Symbol('x') + assert latex(ConditionSet(x, Eq(x**2, 1), S.Reals)) == \ + r"\left\{x\; \middle|\; x \in \mathbb{R} \wedge x^{2} = 1 \right\}" + assert latex(ConditionSet(x, Eq(x**2, 1), S.UniversalSet)) == \ + r"\left\{x\; \middle|\; x^{2} = 1 \right\}" + + +def test_latex_ComplexRegion(): + assert latex(ComplexRegion(Interval(3, 5)*Interval(4, 6))) == \ + r"\left\{x + y i\; \middle|\; x, y \in \left[3, 5\right] \times \left[4, 6\right] \right\}" + assert latex(ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True)) == \ + r"\left\{r \left(i \sin{\left(\theta \right)} + \cos{\left(\theta "\ + r"\right)}\right)\; \middle|\; r, \theta \in \left[0, 1\right] \times \left[0, 2 \pi\right) \right\}" + + +def test_latex_Contains(): + x = Symbol('x') + assert latex(Contains(x, S.Naturals)) == r"x \in \mathbb{N}" + + +def test_latex_sum(): + assert latex(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\sum_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Sum(x**2, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} x^{2}" + assert latex(Sum(x**2 + y, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} \left(x^{2} + y\right)" + assert latex(Sum(x**2 + y, (x, -2, 2))**2) == \ + r"\left(\sum_{x=-2}^{2} \left(x^{2} + y\right)\right)^{2}" + + +def test_latex_product(): + assert latex(Product(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\prod_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Product(x**2, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} x^{2}" + assert latex(Product(x**2 + y, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} \left(x^{2} + y\right)" + + assert latex(Product(x, (x, -2, 2))**2) == \ + r"\left(\prod_{x=-2}^{2} x\right)^{2}" + + +def test_latex_limits(): + assert latex(Limit(x, x, oo)) == r"\lim_{x \to \infty} x" + + # issue 8175 + f = Function('f') + assert latex(Limit(f(x), x, 0)) == r"\lim_{x \to 0^+} f{\left(x \right)}" + assert latex(Limit(f(x), x, 0, "-")) == \ + r"\lim_{x \to 0^-} f{\left(x \right)}" + + # issue #10806 + assert latex(Limit(f(x), x, 0)**2) == \ + r"\left(\lim_{x \to 0^+} f{\left(x \right)}\right)^{2}" + # bi-directional limit + assert latex(Limit(f(x), x, 0, dir='+-')) == \ + r"\lim_{x \to 0} f{\left(x \right)}" + + +def test_latex_log(): + assert latex(log(x)) == r"\log{\left(x \right)}" + assert latex(log(x), ln_notation=True) == r"\ln{\left(x \right)}" + assert latex(log(x) + log(y)) == \ + r"\log{\left(x \right)} + \log{\left(y \right)}" + assert latex(log(x) + log(y), ln_notation=True) == \ + r"\ln{\left(x \right)} + \ln{\left(y \right)}" + assert latex(pow(log(x), x)) == r"\log{\left(x \right)}^{x}" + assert latex(pow(log(x), x), ln_notation=True) == \ + r"\ln{\left(x \right)}^{x}" + + +def test_issue_3568(): + beta = Symbol(r'\beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + beta = Symbol(r'beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + +def test_latex(): + assert latex((2*tau)**Rational(7, 2)) == r"8 \sqrt{2} \tau^{\frac{7}{2}}" + assert latex((2*mu)**Rational(7, 2), mode='equation*') == \ + r"\begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*}" + assert latex((2*mu)**Rational(7, 2), mode='equation', itex=True) == \ + r"$$8 \sqrt{2} \mu^{\frac{7}{2}}$$" + assert latex([2/x, y]) == r"\left[ \frac{2}{x}, \ y\right]" + + +def test_latex_dict(): + d = {Rational(1): 1, x**2: 2, x: 3, x**3: 4} + assert latex(d) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + D = Dict(d) + assert latex(D) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + + +def test_latex_list(): + ll = [Symbol('omega1'), Symbol('a'), Symbol('alpha')] + assert latex(ll) == r'\left[ \omega_{1}, \ a, \ \alpha\right]' + + +def test_latex_NumberSymbols(): + assert latex(S.Catalan) == "G" + assert latex(S.EulerGamma) == r"\gamma" + assert latex(S.Exp1) == "e" + assert latex(S.GoldenRatio) == r"\phi" + assert latex(S.Pi) == r"\pi" + assert latex(S.TribonacciConstant) == r"\text{TribonacciConstant}" + + +def test_latex_rational(): + # tests issue 3973 + assert latex(-Rational(1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(-1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(1, -2)) == r"- \frac{1}{2}" + assert latex(-Rational(-1, 2)) == r"\frac{1}{2}" + assert latex(-Rational(1, 2)*x) == r"- \frac{x}{2}" + assert latex(-Rational(1, 2)*x + Rational(-2, 3)*y) == \ + r"- \frac{x}{2} - \frac{2 y}{3}" + + +def test_latex_inverse(): + # tests issue 4129 + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/(x + y)) == r"\frac{1}{x + y}" + + +def test_latex_DiracDelta(): + assert latex(DiracDelta(x)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x)**2) == r"\left(\delta\left(x\right)\right)^{2}" + assert latex(DiracDelta(x, 0)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x, 5)) == \ + r"\delta^{\left( 5 \right)}\left( x \right)" + assert latex(DiracDelta(x, 5)**2) == \ + r"\left(\delta^{\left( 5 \right)}\left( x \right)\right)^{2}" + + +def test_latex_Heaviside(): + assert latex(Heaviside(x)) == r"\theta\left(x\right)" + assert latex(Heaviside(x)**2) == r"\left(\theta\left(x\right)\right)^{2}" + + +def test_latex_KroneckerDelta(): + assert latex(KroneckerDelta(x, y)) == r"\delta_{x y}" + assert latex(KroneckerDelta(x, y + 1)) == r"\delta_{x, y + 1}" + # issue 6578 + assert latex(KroneckerDelta(x + 1, y)) == r"\delta_{y, x + 1}" + assert latex(Pow(KroneckerDelta(x, y), 2, evaluate=False)) == \ + r"\left(\delta_{x y}\right)^{2}" + + +def test_latex_LeviCivita(): + assert latex(LeviCivita(x, y, z)) == r"\varepsilon_{x y z}" + assert latex(LeviCivita(x, y, z)**2) == \ + r"\left(\varepsilon_{x y z}\right)^{2}" + assert latex(LeviCivita(x, y, z + 1)) == r"\varepsilon_{x, y, z + 1}" + assert latex(LeviCivita(x, y + 1, z)) == r"\varepsilon_{x, y + 1, z}" + assert latex(LeviCivita(x + 1, y, z)) == r"\varepsilon_{x + 1, y, z}" + + +def test_mode(): + expr = x + y + assert latex(expr) == r'x + y' + assert latex(expr, mode='plain') == r'x + y' + assert latex(expr, mode='inline') == r'$x + y$' + assert latex( + expr, mode='equation*') == r'\begin{equation*}x + y\end{equation*}' + assert latex( + expr, mode='equation') == r'\begin{equation}x + y\end{equation}' + raises(ValueError, lambda: latex(expr, mode='foo')) + + +def test_latex_mathieu(): + assert latex(mathieuc(x, y, z)) == r"C\left(x, y, z\right)" + assert latex(mathieus(x, y, z)) == r"S\left(x, y, z\right)" + assert latex(mathieuc(x, y, z)**2) == r"C\left(x, y, z\right)^{2}" + assert latex(mathieus(x, y, z)**2) == r"S\left(x, y, z\right)^{2}" + assert latex(mathieucprime(x, y, z)) == r"C^{\prime}\left(x, y, z\right)" + assert latex(mathieusprime(x, y, z)) == r"S^{\prime}\left(x, y, z\right)" + assert latex(mathieucprime(x, y, z)**2) == r"C^{\prime}\left(x, y, z\right)^{2}" + assert latex(mathieusprime(x, y, z)**2) == r"S^{\prime}\left(x, y, z\right)^{2}" + +def test_latex_Piecewise(): + p = Piecewise((x, x < 1), (x**2, True)) + assert latex(p) == r"\begin{cases} x & \text{for}\: x < 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + assert latex(p, itex=True) == \ + r"\begin{cases} x & \text{for}\: x \lt 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + p = Piecewise((x, x < 0), (0, x >= 0)) + assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 &' \ + r' \text{otherwise} \end{cases}' + A, B = symbols("A B", commutative=False) + p = Piecewise((A**2, Eq(A, B)), (A*B, True)) + s = r"\begin{cases} A^{2} & \text{for}\: A = B \\A B & \text{otherwise} \end{cases}" + assert latex(p) == s + assert latex(A*p) == r"A \left(%s\right)" % s + assert latex(p*A) == r"\left(%s\right) A" % s + assert latex(Piecewise((x, x < 1), (x**2, x < 2))) == \ + r'\begin{cases} x & ' \ + r'\text{for}\: x < 1 \\x^{2} & \text{for}\: x < 2 \end{cases}' + + +def test_latex_Matrix(): + M = Matrix([[1 + x, y], [y, x - 1]]) + assert latex(M) == \ + r'\left[\begin{matrix}x + 1 & y\\y & x - 1\end{matrix}\right]' + assert latex(M, mode='inline') == \ + r'$\left[\begin{smallmatrix}x + 1 & y\\' \ + r'y & x - 1\end{smallmatrix}\right]$' + assert latex(M, mat_str='array') == \ + r'\left[\begin{array}{cc}x + 1 & y\\y & x - 1\end{array}\right]' + assert latex(M, mat_str='bmatrix') == \ + r'\left[\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}\right]' + assert latex(M, mat_delim=None, mat_str='bmatrix') == \ + r'\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}' + + M2 = Matrix(1, 11, range(11)) + assert latex(M2) == \ + r'\left[\begin{array}{ccccccccccc}' \ + r'0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + + +def test_latex_matrix_with_functions(): + t = symbols('t') + theta1 = symbols('theta1', cls=Function) + + M = Matrix([[sin(theta1(t)), cos(theta1(t))], + [cos(theta1(t).diff(t)), sin(theta1(t).diff(t))]]) + + expected = (r'\left[\begin{matrix}\sin{\left(' + r'\theta_{1}{\left(t \right)} \right)} & ' + r'\cos{\left(\theta_{1}{\left(t \right)} \right)' + r'}\\\cos{\left(\frac{d}{d t} \theta_{1}{\left(t ' + r'\right)} \right)} & \sin{\left(\frac{d}{d t} ' + r'\theta_{1}{\left(t \right)} \right' + r')}\end{matrix}\right]') + + assert latex(M) == expected + + +def test_latex_NDimArray(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert latex(M) == r"x" + + M = ArrayType([[1 / x, y], [z, w]]) + M1 = ArrayType([1 / x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + assert latex(M) == \ + r'\left[\begin{matrix}\frac{1}{x} & y\\z & w\end{matrix}\right]' + assert latex(M1) == \ + r"\left[\begin{matrix}\frac{1}{x} & y & z\end{matrix}\right]" + assert latex(M2) == \ + r"\left[\begin{matrix}" \ + r"\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right]" \ + r"\end{matrix}\right]" + assert latex(M3) == \ + r"""\left[\begin{matrix}"""\ + r"""\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right]\\"""\ + r"""\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{w}{x} & w y\\w z & w^{2}\end{matrix}\right]"""\ + r"""\end{matrix}\right]""" + + Mrow = ArrayType([[x, y, 1/z]]) + Mcolumn = ArrayType([[x], [y], [1/z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + assert latex(Mrow) == \ + r"\left[\left[\begin{matrix}x & y & \frac{1}{z}\end{matrix}\right]\right]" + assert latex(Mcolumn) == \ + r"\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]" + assert latex(Mcol2) == \ + r'\left[\begin{matrix}\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]\end{matrix}\right]' + + +def test_latex_mul_symbol(): + assert latex(4*4**x, mul_symbol='times') == r"4 \times 4^{x}" + assert latex(4*4**x, mul_symbol='dot') == r"4 \cdot 4^{x}" + assert latex(4*4**x, mul_symbol='ldot') == r"4 \,.\, 4^{x}" + + assert latex(4*x, mul_symbol='times') == r"4 \times x" + assert latex(4*x, mul_symbol='dot') == r"4 \cdot x" + assert latex(4*x, mul_symbol='ldot') == r"4 \,.\, x" + + +def test_latex_issue_4381(): + y = 4*4**log(2) + assert latex(y) == r'4 \cdot 4^{\log{\left(2 \right)}}' + assert latex(1/y) == r'\frac{1}{4 \cdot 4^{\log{\left(2 \right)}}}' + + +def test_latex_issue_4576(): + assert latex(Symbol("beta_13_2")) == r"\beta_{13 2}" + assert latex(Symbol("beta_132_20")) == r"\beta_{132 20}" + assert latex(Symbol("beta_13")) == r"\beta_{13}" + assert latex(Symbol("x_a_b")) == r"x_{a b}" + assert latex(Symbol("x_1_2_3")) == r"x_{1 2 3}" + assert latex(Symbol("x_a_b1")) == r"x_{a b1}" + assert latex(Symbol("x_a_1")) == r"x_{a 1}" + assert latex(Symbol("x_1_a")) == r"x_{1 a}" + assert latex(Symbol("x_1^aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_1__aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_11^a")) == r"x^{a}_{11}" + assert latex(Symbol("x_11__a")) == r"x^{a}_{11}" + assert latex(Symbol("x_a_a_a_a")) == r"x_{a a a a}" + assert latex(Symbol("x_a_a^a^a")) == r"x^{a a}_{a a}" + assert latex(Symbol("x_a_a__a__a")) == r"x^{a a}_{a a}" + assert latex(Symbol("alpha_11")) == r"\alpha_{11}" + assert latex(Symbol("alpha_11_11")) == r"\alpha_{11 11}" + assert latex(Symbol("alpha_alpha")) == r"\alpha_{\alpha}" + assert latex(Symbol("alpha^aleph")) == r"\alpha^{\aleph}" + assert latex(Symbol("alpha__aleph")) == r"\alpha^{\aleph}" + + +def test_latex_pow_fraction(): + x = Symbol('x') + # Testing exp + assert r'e^{-x}' in latex(exp(-x)/2).replace(' ', '') # Remove Whitespace + + # Testing e^{-x} in case future changes alter behavior of muls or fracs + # In particular current output is \frac{1}{2}e^{- x} but perhaps this will + # change to \frac{e^{-x}}{2} + + # Testing general, non-exp, power + assert r'3^{-x}' in latex(3**-x/2).replace(' ', '') + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert latex(A*B*C**-1) == r"A B C^{-1}" + assert latex(C**-1*A*B) == r"C^{-1} A B" + assert latex(A*C**-1*B) == r"A C^{-1} B" + + +def test_latex_order(): + expr = x**3 + x**2*y + y**4 + 3*x*y**3 + + assert latex(expr, order='lex') == r"x^{3} + x^{2} y + 3 x y^{3} + y^{4}" + assert latex( + expr, order='rev-lex') == r"y^{4} + 3 x y^{3} + x^{2} y + x^{3}" + assert latex(expr, order='none') == r"x^{3} + y^{4} + y x^{2} + 3 x y^{3}" + + +def test_latex_Lambda(): + assert latex(Lambda(x, x + 1)) == r"\left( x \mapsto x + 1 \right)" + assert latex(Lambda((x, y), x + 1)) == r"\left( \left( x, \ y\right) \mapsto x + 1 \right)" + assert latex(Lambda(x, x)) == r"\left( x \mapsto x \right)" + +def test_latex_PolyElement(): + Ruv, u, v = ring("u,v", ZZ) + Rxyz, x, y, z = ring("x,y,z", Ruv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex((u**2 + 3*u*v + 1)*x**2*y + u + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + u + 1" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x + 1" + assert latex((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == \ + r"-\left({u}^{2} - 3 u v + 1\right) {x}^{2} y - \left(u + 1\right) x - 1" + + assert latex(-(v**2 + v + 1)*x + 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x + 3 u v + 1" + assert latex(-(v**2 + v + 1)*x - 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x - 3 u v + 1" + + +def test_latex_FracElement(): + Fuv, u, v = field("u,v", ZZ) + Fxyzt, x, y, z, t = field("x,y,z,t", Fuv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex(x/3) == r"\frac{x}{3}" + assert latex(x/z) == r"\frac{x}{z}" + assert latex(x*y/z) == r"\frac{x y}{z}" + assert latex(x/(z*t)) == r"\frac{x}{z t}" + assert latex(x*y/(z*t)) == r"\frac{x y}{z t}" + + assert latex((x - 1)/y) == r"\frac{x - 1}{y}" + assert latex((x + 1)/y) == r"\frac{x + 1}{y}" + assert latex((-x - 1)/y) == r"\frac{-x - 1}{y}" + assert latex((x + 1)/(y*z)) == r"\frac{x + 1}{y z}" + assert latex(-y/(x + 1)) == r"\frac{-y}{x + 1}" + assert latex(y*z/(x + 1)) == r"\frac{y z}{x + 1}" + + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - 1}" + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - u v t - 1}" + + +def test_latex_Poly(): + assert latex(Poly(x**2 + 2 * x, x)) == \ + r"\operatorname{Poly}{\left( x^{2} + 2 x, x, domain=\mathbb{Z} \right)}" + assert latex(Poly(x/y, x)) == \ + r"\operatorname{Poly}{\left( \frac{1}{y} x, x, domain=\mathbb{Z}\left(y\right) \right)}" + assert latex(Poly(2.0*x + y)) == \ + r"\operatorname{Poly}{\left( 2.0 x + 1.0 y, x, y, domain=\mathbb{R} \right)}" + + +def test_latex_Poly_order(): + assert latex(Poly([a, 1, b, 2, c, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{5} + x^{4} + b x^{3} + 2 x^{2} + c'\ + r' x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly([a, 1, b+c, 2, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{4} + x^{3} + \left(b + c\right) '\ + r'x^{2} + 2 x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly(a*x**3 + x**2*y - x*y - c*y**3 - b*x*y**2 + y - a*x + b, + (x, y))) == \ + r'\operatorname{Poly}{\left( a x^{3} + x^{2}y - b xy^{2} - xy - '\ + r'a x - c y^{3} + y + b, x, y, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + + +def test_latex_ComplexRootOf(): + assert latex(rootof(x**5 + x + 3, 0)) == \ + r"\operatorname{CRootOf} {\left(x^{5} + x + 3, 0\right)}" + + +def test_latex_RootSum(): + assert latex(RootSum(x**5 + x + 3, sin)) == \ + r"\operatorname{RootSum} {\left(x^{5} + x + 3, \left( x \mapsto \sin{\left(x \right)} \right)\right)}" + + +def test_settings(): + raises(TypeError, lambda: latex(x*y, method="garbage")) + + +def test_latex_numbers(): + assert latex(catalan(n)) == r"C_{n}" + assert latex(catalan(n)**2) == r"C_{n}^{2}" + assert latex(bernoulli(n)) == r"B_{n}" + assert latex(bernoulli(n, x)) == r"B_{n}\left(x\right)" + assert latex(bernoulli(n)**2) == r"B_{n}^{2}" + assert latex(bernoulli(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(genocchi(n)) == r"G_{n}" + assert latex(genocchi(n, x)) == r"G_{n}\left(x\right)" + assert latex(genocchi(n)**2) == r"G_{n}^{2}" + assert latex(genocchi(n, x)**2) == r"G_{n}^{2}\left(x\right)" + assert latex(bell(n)) == r"B_{n}" + assert latex(bell(n, x)) == r"B_{n}\left(x\right)" + assert latex(bell(n, m, (x, y))) == r"B_{n, m}\left(x, y\right)" + assert latex(bell(n)**2) == r"B_{n}^{2}" + assert latex(bell(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(bell(n, m, (x, y))**2) == r"B_{n, m}^{2}\left(x, y\right)" + assert latex(fibonacci(n)) == r"F_{n}" + assert latex(fibonacci(n, x)) == r"F_{n}\left(x\right)" + assert latex(fibonacci(n)**2) == r"F_{n}^{2}" + assert latex(fibonacci(n, x)**2) == r"F_{n}^{2}\left(x\right)" + assert latex(lucas(n)) == r"L_{n}" + assert latex(lucas(n)**2) == r"L_{n}^{2}" + assert latex(tribonacci(n)) == r"T_{n}" + assert latex(tribonacci(n, x)) == r"T_{n}\left(x\right)" + assert latex(tribonacci(n)**2) == r"T_{n}^{2}" + assert latex(tribonacci(n, x)**2) == r"T_{n}^{2}\left(x\right)" + assert latex(mobius(n)) == r"\mu\left(n\right)" + assert latex(mobius(n)**2) == r"\mu^{2}\left(n\right)" + + +def test_latex_euler(): + assert latex(euler(n)) == r"E_{n}" + assert latex(euler(n, x)) == r"E_{n}\left(x\right)" + assert latex(euler(n, x)**2) == r"E_{n}^{2}\left(x\right)" + + +def test_lamda(): + assert latex(Symbol('lamda')) == r"\lambda" + assert latex(Symbol('Lamda')) == r"\Lambda" + + +def test_custom_symbol_names(): + x = Symbol('x') + y = Symbol('y') + assert latex(x) == r"x" + assert latex(x, symbol_names={x: "x_i"}) == r"x_i" + assert latex(x + y, symbol_names={x: "x_i"}) == r"x_i + y" + assert latex(x**2, symbol_names={x: "x_i"}) == r"x_i^{2}" + assert latex(x + y, symbol_names={x: "x_i", y: "y_j"}) == r"x_i + y_j" + + +def test_matAdd(): + C = MatrixSymbol('C', 5, 5) + B = MatrixSymbol('B', 5, 5) + + n = symbols("n") + h = MatrixSymbol("h", 1, 1) + + assert latex(C - 2*B) in [r'- 2 B + C', r'C -2 B'] + assert latex(C + 2*B) in [r'2 B + C', r'C + 2 B'] + assert latex(B - 2*C) in [r'B - 2 C', r'- 2 C + B'] + assert latex(B + 2*C) in [r'B + 2 C', r'2 C + B'] + + assert latex(n * h - (-h + h.T) * (h + h.T)) == 'n h - \\left(- h + h^{T}\\right) \\left(h + h^{T}\\right)' + assert latex(MatAdd(MatAdd(h, h), MatAdd(h, h))) == '\\left(h + h\\right) + \\left(h + h\\right)' + assert latex(MatMul(MatMul(h, h), MatMul(h, h))) == '\\left(h h\\right) \\left(h h\\right)' + + +def test_matMul(): + A = MatrixSymbol('A', 5, 5) + B = MatrixSymbol('B', 5, 5) + x = Symbol('x') + assert latex(2*A) == r'2 A' + assert latex(2*x*A) == r'2 x A' + assert latex(-2*A) == r'- 2 A' + assert latex(1.5*A) == r'1.5 A' + assert latex(sqrt(2)*A) == r'\sqrt{2} A' + assert latex(-sqrt(2)*A) == r'- \sqrt{2} A' + assert latex(2*sqrt(2)*x*A) == r'2 \sqrt{2} x A' + assert latex(-2*A*(A + 2*B)) in [r'- 2 A \left(A + 2 B\right)', + r'- 2 A \left(2 B + A\right)'] + + +def test_latex_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert latex(MatrixSlice(X, (None, None, None), (None, None, None))) == r'X\left[:, :\right]' + assert latex(X[x:x + 1, y:y + 1]) == r'X\left[x:x + 1, y:y + 1\right]' + assert latex(X[x:x + 1:2, y:y + 1:2]) == r'X\left[x:x + 1:2, y:y + 1:2\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[x:, :y]) == r'X\left[x:, :y\right]' + assert latex(X[x:y, z:w]) == r'X\left[x:y, z:w\right]' + assert latex(X[x:y:t, w:t:x]) == r'X\left[x:y:t, w:t:x\right]' + assert latex(X[x::y, t::w]) == r'X\left[x::y, t::w\right]' + assert latex(X[:x:y, :t:w]) == r'X\left[:x:y, :t:w\right]' + assert latex(X[::x, ::y]) == r'X\left[::x, ::y\right]' + assert latex(MatrixSlice(X, (0, None, None), (0, None, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (None, n, None), (None, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, None), (0, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, 2), (0, n, 2))) == r'X\left[::2, ::2\right]' + assert latex(X[1:2:3, 4:5:6]) == r'X\left[1:2:3, 4:5:6\right]' + assert latex(X[1:3:5, 4:6:8]) == r'X\left[1:3:5, 4:6:8\right]' + assert latex(X[1:10:2]) == r'X\left[1:10:2, :\right]' + assert latex(Y[:5, 1:9:2]) == r'Y\left[:5, 1:9:2\right]' + assert latex(Y[:5, 1:10:2]) == r'Y\left[:5, 1::2\right]' + assert latex(Y[5, :5:2]) == r'Y\left[5:6, :5:2\right]' + assert latex(X[0:1, 0:1]) == r'X\left[:1, :1\right]' + assert latex(X[0:1:2, 0:1:2]) == r'X\left[:1:2, :1:2\right]' + assert latex((Y + Z)[2:, 2:]) == r'\left(Y + Z\right)\left[2:, 2:\right]' + + +def test_latex_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + from sympy.stats.rv import RandomDomain + + X = Normal('x1', 0, 1) + assert latex(where(X > 0)) == r"\text{Domain: }0 < x_{1} \wedge x_{1} < \infty" + + D = Die('d1', 6) + assert latex(where(D > 4)) == r"\text{Domain: }d_{1} = 5 \vee d_{1} = 6" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert latex( + pspace(Tuple(A, B)).domain) == \ + r"\text{Domain: }0 \leq a \wedge 0 \leq b \wedge a < \infty \wedge b < \infty" + + assert latex(RandomDomain(FiniteSet(x), FiniteSet(1, 2))) == \ + r'\text{Domain: }\left\{x\right\} \in \left\{1, 2\right\}' + +def test_PrettyPoly(): + from sympy.polys.domains import QQ + F = QQ.frac_field(x, y) + R = QQ[x, y] + + assert latex(F.convert(x/(x + y))) == latex(x/(x + y)) + assert latex(R.convert(x + y)) == latex(x + y) + + +def test_integral_transforms(): + x = Symbol("x") + k = Symbol("k") + f = Function("f") + a = Symbol("a") + b = Symbol("b") + + assert latex(MellinTransform(f(x), x, k)) == \ + r"\mathcal{M}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseMellinTransform(f(k), k, x, a, b)) == \ + r"\mathcal{M}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(LaplaceTransform(f(x), x, k)) == \ + r"\mathcal{L}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseLaplaceTransform(f(k), k, x, (a, b))) == \ + r"\mathcal{L}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(FourierTransform(f(x), x, k)) == \ + r"\mathcal{F}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseFourierTransform(f(k), k, x)) == \ + r"\mathcal{F}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(CosineTransform(f(x), x, k)) == \ + r"\mathcal{COS}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseCosineTransform(f(k), k, x)) == \ + r"\mathcal{COS}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(SineTransform(f(x), x, k)) == \ + r"\mathcal{SIN}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseSineTransform(f(k), k, x)) == \ + r"\mathcal{SIN}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + +def test_PolynomialRingBase(): + from sympy.polys.domains import QQ + assert latex(QQ.old_poly_ring(x, y)) == r"\mathbb{Q}\left[x, y\right]" + assert latex(QQ.old_poly_ring(x, y, order="ilex")) == \ + r"S_<^{-1}\mathbb{Q}\left[x, y\right]" + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, + DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert latex(A1) == r"A_{1}" + assert latex(f1) == r"f_{1}:A_{1}\rightarrow A_{2}" + assert latex(id_A1) == r"id:A_{1}\rightarrow A_{1}" + assert latex(f2*f1) == r"f_{2}\circ f_{1}:A_{1}\rightarrow A_{3}" + + assert latex(K1) == r"\mathbf{K_{1}}" + + d = Diagram() + assert latex(d) == r"\emptyset" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}, " \ + r"\ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}," \ + r" \ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" \ + r"\Longrightarrow \left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \left\{unique\right\}\right\}" + + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert latex(grid) == r"\begin{array}{cc}" + "\n" \ + r"A & B \\" + "\n" \ + r" & C " + "\n" \ + r"\end{array}" + "\n" + + +def test_Modules(): + from sympy.polys.domains import QQ + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + assert latex(F) == r"{\mathbb{Q}\left[x, y\right]}^{2}" + assert latex(M) == \ + r"\left\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle" + + I = R.ideal(x**2, y) + assert latex(I) == r"\left\langle {x^{2}},{y} \right\rangle" + + Q = F / M + assert latex(Q) == \ + r"\frac{{\mathbb{Q}\left[x, y\right]}^{2}}{\left\langle {\left[ {x},"\ + r"{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}" + assert latex(Q.submodule([1, x**3/2], [2, y])) == \ + r"\left\langle {{\left[ {1},{\frac{x^{3}}{2}} \right]} + {\left"\ + r"\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} "\ + r"\right\rangle}},{{\left[ {2},{y} \right]} + {\left\langle {\left[ "\ + r"{x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}} \right\rangle" + + h = homomorphism(QQ.old_poly_ring(x).free_module(2), + QQ.old_poly_ring(x).free_module(2), [0, 0]) + + assert latex(h) == \ + r"{\left[\begin{matrix}0 & 0\\0 & 0\end{matrix}\right]} : "\ + r"{{\mathbb{Q}\left[x\right]}^{2}} \to {{\mathbb{Q}\left[x\right]}^{2}}" + + +def test_QuotientRing(): + from sympy.polys.domains import QQ + R = QQ.old_poly_ring(x)/[x**2 + 1] + + assert latex(R) == \ + r"\frac{\mathbb{Q}\left[x\right]}{\left\langle {x^{2} + 1} \right\rangle}" + assert latex(R.one) == r"{1} + {\left\langle {x^{2} + 1} \right\rangle}" + + +def test_Tr(): + #TODO: Handle indices + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert latex(t) == r'\operatorname{tr}\left(A B\right)' + + +def test_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert latex(Determinant(m)) == '\\left|{\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}}\\right|' + assert latex(Determinant(Inverse(m))) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{-1}}\\right|' + X = MatrixSymbol('X', 2, 2) + assert latex(Determinant(X)) == '\\left|{X}\\right|' + assert latex(Determinant(X + m)) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X}\\right|' + assert latex(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left|{\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}}\\right|' + + +def test_Adjoint(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Adjoint(X)) == r'X^{\dagger}' + assert latex(Adjoint(X + Y)) == r'\left(X + Y\right)^{\dagger}' + assert latex(Adjoint(X) + Adjoint(Y)) == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(X*Y)) == r'\left(X Y\right)^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2)) == r'\left(X^{2}\right)^{\dagger}' + assert latex(Adjoint(X)**2) == r'\left(X^{\dagger}\right)^{2}' + assert latex(Adjoint(Inverse(X))) == r'\left(X^{-1}\right)^{\dagger}' + assert latex(Inverse(Adjoint(X))) == r'\left(X^{\dagger}\right)^{-1}' + assert latex(Adjoint(Transpose(X))) == r'\left(X^{T}\right)^{\dagger}' + assert latex(Transpose(Adjoint(X))) == r'\left(X^{\dagger}\right)^{T}' + assert latex(Transpose(Adjoint(X) + Y)) == r'\left(X^{\dagger} + Y\right)^{T}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Adjoint(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{\\dagger}' + assert latex(Adjoint(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{\\dagger}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{\\dagger}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Adjoint(Mx)) == r'\left(M^{x}\right)^{\dagger}' + + # adjoint style + assert latex(Adjoint(X), adjoint_style="star") == r'X^{\ast}' + assert latex(Adjoint(X + Y), adjoint_style="hermitian") == r'\left(X + Y\right)^{\mathsf{H}}' + assert latex(Adjoint(X) + Adjoint(Y), adjoint_style="dagger") == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2), adjoint_style="star") == r'\left(X^{2}\right)^{\ast}' + assert latex(Adjoint(X)**2, adjoint_style="hermitian") == r'\left(X^{\mathsf{H}}\right)^{2}' + +def test_Transpose(): + from sympy.matrices import Transpose, MatPow, HadamardPower + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Transpose(X)) == r'X^{T}' + assert latex(Transpose(X + Y)) == r'\left(X + Y\right)^{T}' + + assert latex(Transpose(HadamardPower(X, 2))) == r'\left(X^{\circ {2}}\right)^{T}' + assert latex(HadamardPower(Transpose(X), 2)) == r'\left(X^{T}\right)^{\circ {2}}' + assert latex(Transpose(MatPow(X, 2))) == r'\left(X^{2}\right)^{T}' + assert latex(MatPow(Transpose(X), 2)) == r'\left(X^{T}\right)^{2}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Transpose(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{T}' + assert latex(Transpose(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{T}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{T}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Transpose(Mx)) == r'\left(M^{x}\right)^{T}' + + +def test_Hadamard(): + from sympy.matrices import HadamardProduct, HadamardPower + from sympy.matrices.expressions import MatAdd, MatMul, MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(HadamardProduct(X, Y*Y)) == r'X \circ Y^{2}' + assert latex(HadamardProduct(X, Y)*Y) == r'\left(X \circ Y\right) Y' + + assert latex(HadamardPower(X, 2)) == r'X^{\circ {2}}' + assert latex(HadamardPower(X, -1)) == r'X^{\circ \left({-1}\right)}' + assert latex(HadamardPower(MatAdd(X, Y), 2)) == \ + r'\left(X + Y\right)^{\circ {2}}' + assert latex(HadamardPower(MatMul(X, Y), 2)) == \ + r'\left(X Y\right)^{\circ {2}}' + + assert latex(HadamardPower(MatPow(X, -1), -1)) == \ + r'\left(X^{-1}\right)^{\circ \left({-1}\right)}' + assert latex(MatPow(HadamardPower(X, -1), -1)) == \ + r'\left(X^{\circ \left({-1}\right)}\right)^{-1}' + + assert latex(HadamardPower(X, n+1)) == \ + r'X^{\circ \left({n + 1}\right)}' + + +def test_MatPow(): + from sympy.matrices.expressions import MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(MatPow(X, 2)) == 'X^{2}' + assert latex(MatPow(X*X, 2)) == '\\left(X^{2}\\right)^{2}' + assert latex(MatPow(X*Y, 2)) == '\\left(X Y\\right)^{2}' + assert latex(MatPow(X + Y, 2)) == '\\left(X + Y\\right)^{2}' + assert latex(MatPow(X + X, 2)) == '\\left(2 X\\right)^{2}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(MatPow(Mx, 2)) == r'\left(M^{x}\right)^{2}' + + +def test_ElementwiseApplyFunction(): + X = MatrixSymbol('X', 2, 2) + expr = (X.T*X).applyfunc(sin) + assert latex(expr) == r"{\left( d \mapsto \sin{\left(d \right)} \right)}_{\circ}\left({X^{T} X}\right)" + expr = X.applyfunc(Lambda(x, 1/x)) + assert latex(expr) == r'{\left( x \mapsto \frac{1}{x} \right)}_{\circ}\left({X}\right)' + + +def test_ZeroMatrix(): + from sympy.matrices.expressions.special import ZeroMatrix + assert latex(ZeroMatrix(1, 1), mat_symbol_style='plain') == r"0" + assert latex(ZeroMatrix(1, 1), mat_symbol_style='bold') == r"\mathbf{0}" + + +def test_OneMatrix(): + from sympy.matrices.expressions.special import OneMatrix + assert latex(OneMatrix(3, 4), mat_symbol_style='plain') == r"1" + assert latex(OneMatrix(3, 4), mat_symbol_style='bold') == r"\mathbf{1}" + + +def test_Identity(): + from sympy.matrices.expressions.special import Identity + assert latex(Identity(1), mat_symbol_style='plain') == r"\mathbb{I}" + assert latex(Identity(1), mat_symbol_style='bold') == r"\mathbf{I}" + + +def test_latex_DFT_IDFT(): + from sympy.matrices.expressions.fourier import DFT, IDFT + assert latex(DFT(13)) == r"\text{DFT}_{13}" + assert latex(IDFT(x)) == r"\text{IDFT}_{x}" + + +def test_boolean_args_order(): + syms = symbols('a:f') + + expr = And(*syms) + assert latex(expr) == r'a \wedge b \wedge c \wedge d \wedge e \wedge f' + + expr = Or(*syms) + assert latex(expr) == r'a \vee b \vee c \vee d \vee e \vee f' + + expr = Equivalent(*syms) + assert latex(expr) == \ + r'a \Leftrightarrow b \Leftrightarrow c \Leftrightarrow d \Leftrightarrow e \Leftrightarrow f' + + expr = Xor(*syms) + assert latex(expr) == \ + r'a \veebar b \veebar c \veebar d \veebar e \veebar f' + + +def test_imaginary(): + i = sqrt(-1) + assert latex(i) == r'i' + + +def test_builtins_without_args(): + assert latex(sin) == r'\sin' + assert latex(cos) == r'\cos' + assert latex(tan) == r'\tan' + assert latex(log) == r'\log' + assert latex(Ei) == r'\operatorname{Ei}' + assert latex(zeta) == r'\zeta' + + +def test_latex_greek_functions(): + # bug because capital greeks that have roman equivalents should not use + # \Alpha, \Beta, \Eta, etc. + s = Function('Alpha') + assert latex(s) == r'\mathrm{A}' + assert latex(s(x)) == r'\mathrm{A}{\left(x \right)}' + s = Function('Beta') + assert latex(s) == r'\mathrm{B}' + s = Function('Eta') + assert latex(s) == r'\mathrm{H}' + assert latex(s(x)) == r'\mathrm{H}{\left(x \right)}' + + # bug because sympy.core.numbers.Pi is special + p = Function('Pi') + # assert latex(p(x)) == r'\Pi{\left(x \right)}' + assert latex(p) == r'\Pi' + + # bug because not all greeks are included + c = Function('chi') + assert latex(c(x)) == r'\chi{\left(x \right)}' + assert latex(c) == r'\chi' + + +def test_translate(): + s = 'Alpha' + assert translate(s) == r'\mathrm{A}' + s = 'Beta' + assert translate(s) == r'\mathrm{B}' + s = 'Eta' + assert translate(s) == r'\mathrm{H}' + s = 'omicron' + assert translate(s) == r'o' + s = 'Pi' + assert translate(s) == r'\Pi' + s = 'pi' + assert translate(s) == r'\pi' + s = 'LamdaHatDOT' + assert translate(s) == r'\dot{\hat{\Lambda}}' + + +def test_other_symbols(): + from sympy.printing.latex import other_symbols + for s in other_symbols: + assert latex(symbols(s)) == r"" "\\" + s + + +def test_modifiers(): + # Test each modifier individually in the simplest case + # (with funny capitalizations) + assert latex(symbols("xMathring")) == r"\mathring{x}" + assert latex(symbols("xCheck")) == r"\check{x}" + assert latex(symbols("xBreve")) == r"\breve{x}" + assert latex(symbols("xAcute")) == r"\acute{x}" + assert latex(symbols("xGrave")) == r"\grave{x}" + assert latex(symbols("xTilde")) == r"\tilde{x}" + assert latex(symbols("xPrime")) == r"{x}'" + assert latex(symbols("xddDDot")) == r"\ddddot{x}" + assert latex(symbols("xDdDot")) == r"\dddot{x}" + assert latex(symbols("xDDot")) == r"\ddot{x}" + assert latex(symbols("xBold")) == r"\boldsymbol{x}" + assert latex(symbols("xnOrM")) == r"\left\|{x}\right\|" + assert latex(symbols("xAVG")) == r"\left\langle{x}\right\rangle" + assert latex(symbols("xHat")) == r"\hat{x}" + assert latex(symbols("xDot")) == r"\dot{x}" + assert latex(symbols("xBar")) == r"\bar{x}" + assert latex(symbols("xVec")) == r"\vec{x}" + assert latex(symbols("xAbs")) == r"\left|{x}\right|" + assert latex(symbols("xMag")) == r"\left|{x}\right|" + assert latex(symbols("xPrM")) == r"{x}'" + assert latex(symbols("xBM")) == r"\boldsymbol{x}" + # Test strings that are *only* the names of modifiers + assert latex(symbols("Mathring")) == r"Mathring" + assert latex(symbols("Check")) == r"Check" + assert latex(symbols("Breve")) == r"Breve" + assert latex(symbols("Acute")) == r"Acute" + assert latex(symbols("Grave")) == r"Grave" + assert latex(symbols("Tilde")) == r"Tilde" + assert latex(symbols("Prime")) == r"Prime" + assert latex(symbols("DDot")) == r"\dot{D}" + assert latex(symbols("Bold")) == r"Bold" + assert latex(symbols("NORm")) == r"NORm" + assert latex(symbols("AVG")) == r"AVG" + assert latex(symbols("Hat")) == r"Hat" + assert latex(symbols("Dot")) == r"Dot" + assert latex(symbols("Bar")) == r"Bar" + assert latex(symbols("Vec")) == r"Vec" + assert latex(symbols("Abs")) == r"Abs" + assert latex(symbols("Mag")) == r"Mag" + assert latex(symbols("PrM")) == r"PrM" + assert latex(symbols("BM")) == r"BM" + assert latex(symbols("hbar")) == r"\hbar" + # Check a few combinations + assert latex(symbols("xvecdot")) == r"\dot{\vec{x}}" + assert latex(symbols("xDotVec")) == r"\vec{\dot{x}}" + assert latex(symbols("xHATNorm")) == r"\left\|{\hat{x}}\right\|" + # Check a couple big, ugly combinations + assert latex(symbols('xMathringBm_yCheckPRM__zbreveAbs')) == \ + r"\boldsymbol{\mathring{x}}^{\left|{\breve{z}}\right|}_{{\check{y}}'}" + assert latex(symbols('alphadothat_nVECDOT__tTildePrime')) == \ + r"\hat{\dot{\alpha}}^{{\tilde{t}}'}_{\dot{\vec{n}}}" + + +def test_greek_symbols(): + assert latex(Symbol('alpha')) == r'\alpha' + assert latex(Symbol('beta')) == r'\beta' + assert latex(Symbol('gamma')) == r'\gamma' + assert latex(Symbol('delta')) == r'\delta' + assert latex(Symbol('epsilon')) == r'\epsilon' + assert latex(Symbol('zeta')) == r'\zeta' + assert latex(Symbol('eta')) == r'\eta' + assert latex(Symbol('theta')) == r'\theta' + assert latex(Symbol('iota')) == r'\iota' + assert latex(Symbol('kappa')) == r'\kappa' + assert latex(Symbol('lambda')) == r'\lambda' + assert latex(Symbol('mu')) == r'\mu' + assert latex(Symbol('nu')) == r'\nu' + assert latex(Symbol('xi')) == r'\xi' + assert latex(Symbol('omicron')) == r'o' + assert latex(Symbol('pi')) == r'\pi' + assert latex(Symbol('rho')) == r'\rho' + assert latex(Symbol('sigma')) == r'\sigma' + assert latex(Symbol('tau')) == r'\tau' + assert latex(Symbol('upsilon')) == r'\upsilon' + assert latex(Symbol('phi')) == r'\phi' + assert latex(Symbol('chi')) == r'\chi' + assert latex(Symbol('psi')) == r'\psi' + assert latex(Symbol('omega')) == r'\omega' + + assert latex(Symbol('Alpha')) == r'\mathrm{A}' + assert latex(Symbol('Beta')) == r'\mathrm{B}' + assert latex(Symbol('Gamma')) == r'\Gamma' + assert latex(Symbol('Delta')) == r'\Delta' + assert latex(Symbol('Epsilon')) == r'\mathrm{E}' + assert latex(Symbol('Zeta')) == r'\mathrm{Z}' + assert latex(Symbol('Eta')) == r'\mathrm{H}' + assert latex(Symbol('Theta')) == r'\Theta' + assert latex(Symbol('Iota')) == r'\mathrm{I}' + assert latex(Symbol('Kappa')) == r'\mathrm{K}' + assert latex(Symbol('Lambda')) == r'\Lambda' + assert latex(Symbol('Mu')) == r'\mathrm{M}' + assert latex(Symbol('Nu')) == r'\mathrm{N}' + assert latex(Symbol('Xi')) == r'\Xi' + assert latex(Symbol('Omicron')) == r'\mathrm{O}' + assert latex(Symbol('Pi')) == r'\Pi' + assert latex(Symbol('Rho')) == r'\mathrm{P}' + assert latex(Symbol('Sigma')) == r'\Sigma' + assert latex(Symbol('Tau')) == r'\mathrm{T}' + assert latex(Symbol('Upsilon')) == r'\Upsilon' + assert latex(Symbol('Phi')) == r'\Phi' + assert latex(Symbol('Chi')) == r'\mathrm{X}' + assert latex(Symbol('Psi')) == r'\Psi' + assert latex(Symbol('Omega')) == r'\Omega' + + assert latex(Symbol('varepsilon')) == r'\varepsilon' + assert latex(Symbol('varkappa')) == r'\varkappa' + assert latex(Symbol('varphi')) == r'\varphi' + assert latex(Symbol('varpi')) == r'\varpi' + assert latex(Symbol('varrho')) == r'\varrho' + assert latex(Symbol('varsigma')) == r'\varsigma' + assert latex(Symbol('vartheta')) == r'\vartheta' + + +def test_fancyset_symbols(): + assert latex(S.Rationals) == r'\mathbb{Q}' + assert latex(S.Naturals) == r'\mathbb{N}' + assert latex(S.Naturals0) == r'\mathbb{N}_0' + assert latex(S.Integers) == r'\mathbb{Z}' + assert latex(S.Reals) == r'\mathbb{R}' + assert latex(S.Complexes) == r'\mathbb{C}' + + +@XFAIL +def test_builtin_without_args_mismatched_names(): + assert latex(CosineTransform) == r'\mathcal{COS}' + + +def test_builtin_no_args(): + assert latex(Chi) == r'\operatorname{Chi}' + assert latex(beta) == r'\operatorname{B}' + assert latex(gamma) == r'\Gamma' + assert latex(KroneckerDelta) == r'\delta' + assert latex(DiracDelta) == r'\delta' + assert latex(lowergamma) == r'\gamma' + + +def test_issue_6853(): + p = Function('Pi') + assert latex(p(x)) == r"\Pi{\left(x \right)}" + + +def test_Mul(): + e = Mul(-2, x + 1, evaluate=False) + assert latex(e) == r'- 2 \left(x + 1\right)' + e = Mul(2, x + 1, evaluate=False) + assert latex(e) == r'2 \left(x + 1\right)' + e = Mul(S.Half, x + 1, evaluate=False) + assert latex(e) == r'\frac{x + 1}{2}' + e = Mul(y, x + 1, evaluate=False) + assert latex(e) == r'y \left(x + 1\right)' + e = Mul(-y, x + 1, evaluate=False) + assert latex(e) == r'- y \left(x + 1\right)' + e = Mul(-2, x + 1) + assert latex(e) == r'- 2 x - 2' + e = Mul(2, x + 1) + assert latex(e) == r'2 x + 2' + + +def test_Pow(): + e = Pow(2, 2, evaluate=False) + assert latex(e) == r'2^{2}' + assert latex(x**(Rational(-1, 3))) == r'\frac{1}{\sqrt[3]{x}}' + x2 = Symbol(r'x^2') + assert latex(x2**2) == r'\left(x^{2}\right)^{2}' + # Issue 11011 + assert latex(S('1.453e4500')**x) == r'{1.453 \cdot 10^{4500}}^{x}' + + +def test_issue_7180(): + assert latex(Equivalent(x, y)) == r"x \Leftrightarrow y" + assert latex(Not(Equivalent(x, y))) == r"x \not\Leftrightarrow y" + + +def test_issue_8409(): + assert latex(S.Half**n) == r"\left(\frac{1}{2}\right)^{n}" + + +def test_issue_8470(): + from sympy.parsing.sympy_parser import parse_expr + e = parse_expr("-B*A", evaluate=False) + assert latex(e) == r"A \left(- B\right)" + + +def test_issue_15439(): + x = MatrixSymbol('x', 2, 2) + y = MatrixSymbol('y', 2, 2) + assert latex((x * y).subs(y, -y)) == r"x \left(- y\right)" + assert latex((x * y).subs(y, -2*y)) == r"x \left(- 2 y\right)" + assert latex((x * y).subs(x, -x)) == r"\left(- x\right) y" + + +def test_issue_2934(): + assert latex(Symbol(r'\frac{a_1}{b_1}')) == r'\frac{a_1}{b_1}' + + +def test_issue_10489(): + latexSymbolWithBrace = r'C_{x_{0}}' + s = Symbol(latexSymbolWithBrace) + assert latex(s) == latexSymbolWithBrace + assert latex(cos(s)) == r'\cos{\left(C_{x_{0}} \right)}' + + +def test_issue_12886(): + m__1, l__1 = symbols('m__1, l__1') + assert latex(m__1**2 + l__1**2) == \ + r'\left(l^{1}\right)^{2} + \left(m^{1}\right)^{2}' + + +def test_issue_13559(): + from sympy.parsing.sympy_parser import parse_expr + expr = parse_expr('5/1', evaluate=False) + assert latex(expr) == r"\frac{5}{1}" + + +def test_issue_13651(): + expr = c + Mul(-1, a + b, evaluate=False) + assert latex(expr) == r"c - \left(a + b\right)" + + +def test_latex_UnevaluatedExpr(): + x = symbols("x") + he = UnevaluatedExpr(1/x) + assert latex(he) == latex(1/x) == r"\frac{1}{x}" + assert latex(he**2) == r"\left(\frac{1}{x}\right)^{2}" + assert latex(he + 1) == r"1 + \frac{1}{x}" + assert latex(x*he) == r"x \frac{1}{x}" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert latex(A[0, 0]) == r"{A}_{0,0}" + assert latex(3 * A[0, 0]) == r"3 {A}_{0,0}" + + F = C[0, 0].subs(C, A - B) + assert latex(F) == r"{\left(A - B\right)}_{0,0}" + + i, j, k = symbols("i j k") + M = MatrixSymbol("M", k, k) + N = MatrixSymbol("N", k, k) + assert latex((M*N)[i, j]) == \ + r'\sum_{i_{1}=0}^{k - 1} {M}_{i,i_{1}} {N}_{i_{1},j}' + + X_a = MatrixSymbol('X_a', 3, 3) + assert latex(X_a[0, 0]) == r"{X_{a}}_{0,0}" + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A) == r"- A" + assert latex(A - A*B - B) == r"A - A B - B" + assert latex(-A*B - A*B*C - B) == r"- A B - A B C - B" + + +def test_DotProduct_printing(): + X = MatrixSymbol('X', 3, 1) + Y = MatrixSymbol('Y', 3, 1) + a = Symbol('a') + assert latex(DotProduct(X, Y)) == r"X \cdot Y" + assert latex(DotProduct(a * X, Y)) == r"a X \cdot Y" + assert latex(a * DotProduct(X, Y)) == r"a \left(X \cdot Y\right)" + + +def test_KroneckerProduct_printing(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 2, 2) + assert latex(KroneckerProduct(A, B)) == r'A \otimes B' + + +def test_Series_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert latex(Series(tf1, tf2)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right)' + assert latex(Series(tf1, tf2, tf3)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right) \left(\frac{t x^{2} - t^{w} x + w}{t - y}\right)' + assert latex(Series(-tf2, tf1)) == \ + r'\left(\frac{- x + y}{x + y}\right) \left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right)' + + M_1 = Matrix([[5/s], [5/(2*s)]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5, 6*s**3]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + # Brackets + assert latex(T_1*(T_2 + T_2)) == \ + r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left(\left[\begin{matrix}\frac{5}{1} &' \ + r' \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau\right)' \ + == latex(MIMOSeries(MIMOParallel(T_2, T_2), T_1)) + # No Brackets + M_3 = Matrix([[5, 6], [6, 5/s]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1*T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}' \ + r'\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & ' \ + r'\frac{5}{s}\end{matrix}\right]_\tau' == latex(MIMOParallel(MIMOSeries(T_2, T_1), T_3)) + + +def test_TransferFunction_printing(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert latex(tf1) == r"\frac{x - 1}{x + 1}" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert latex(tf2) == r"\frac{x + 1}{2 - y}" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert latex(tf3) == r"\frac{y}{y^{2} + 2 y + 3}" + + +def test_Parallel_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + assert latex(Parallel(tf1, tf2)) == \ + r'\frac{x y^{2} - z}{- t^{3} + y^{3}} + \frac{x - y}{x + y}' + assert latex(Parallel(-tf2, tf1)) == \ + r'\frac{- x + y}{x + y} + \frac{x y^{2} - z}{- t^{3} + y^{3}}' + + M_1 = Matrix([[5, 6], [6, 5/s]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5/s, 6], [6, 5/(s - 1)]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + M_3 = Matrix([[6, 5/(s*(s - 1))], [5, 6]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1 + T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s}\end{matrix}\right]' \ + r'_\tau + \left[\begin{matrix}\frac{5}{s} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s - 1}\end{matrix}\right]_\tau + \left[\begin{matrix}' \ + r'\frac{6}{1} & \frac{5}{s \left(s - 1\right)}\\\frac{5}{1} & \frac{6}{1}\end{matrix}\right]_\tau' \ + == latex(MIMOParallel(T_1, T_2, T_3)) == latex(MIMOParallel(T_1, MIMOParallel(T_2, T_3))) == latex(MIMOParallel(MIMOParallel(T_1, T_2), T_3)) + + +def test_TransferFunctionMatrix_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + tf3 = TransferFunction(p, y**2 + 2*y + 3, p) + assert latex(TransferFunctionMatrix([[tf1], [tf2]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x}\\\frac{p - s}{p + s}\end{matrix}\right]_\tau' + assert latex(TransferFunctionMatrix([[tf1, tf2], [tf3, -tf1]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x} & \frac{p - s}{p + s}\\\frac{p}{y^{2} + 2 y + 3} & \frac{\left(-1\right) p}{p + x}\end{matrix}\right]_\tau' + + +def test_Feedback_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + # Negative Feedback (Default) + assert latex(Feedback(tf1, tf2)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, TransferFunction(1, 1, p))) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + # Positive Feedback + assert latex(Feedback(tf1, tf2, 1)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, sign=1)) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + + +def test_MIMOFeedback_printing(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**2 - 1, s) + tf3 = TransferFunction(s, s - 1, s) + tf4 = TransferFunction(s**2, s**2 - 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf4, tf3], [tf2, tf1]]) + + # Negative Feedback (Default) + assert latex(MIMOFeedback(tfm_1, tfm_2)) == \ + r'\left(I_{\tau} + \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[' \ + r'\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}' \ + r'\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau' + + # Positive Feedback + assert latex(MIMOFeedback(tfm_1*tfm_2, tfm_1, 1)) == \ + r'\left(I_{\tau} - \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left' \ + r'[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1}' \ + r' & \frac{1}{s}\end{matrix}\right]_\tau' + + +def test_Quaternion_latex_printing(): + q = Quaternion(x, y, z, t) + assert latex(q) == r"x + y i + z j + t k" + q = Quaternion(x, y, z, x*t) + assert latex(q) == r"x + y i + z j + t x k" + q = Quaternion(x, y, z, x + t) + assert latex(q) == r"x + y i + z j + \left(t + x\right) k" + + +def test_TensorProduct_printing(): + from sympy.tensor.functions import TensorProduct + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert latex(TensorProduct(A, B)) == r"A \otimes B" + + +def test_WedgeProduct_printing(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert latex(wp) == r"\operatorname{d}x \wedge \operatorname{d}y" + + +def test_issue_9216(): + expr_1 = Pow(1, -1, evaluate=False) + assert latex(expr_1) == r"1^{-1}" + + expr_2 = Pow(1, Pow(1, -1, evaluate=False), evaluate=False) + assert latex(expr_2) == r"1^{1^{-1}}" + + expr_3 = Pow(3, -2, evaluate=False) + assert latex(expr_3) == r"\frac{1}{9}" + + expr_4 = Pow(1, -2, evaluate=False) + assert latex(expr_4) == r"1^{-2}" + + +def test_latex_printer_tensor(): + from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, tensor_heads + L = TensorIndexType("L") + i, j, k, l = tensor_indices("i j k l", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + K = TensorHead("K", [L, L, L, L]) + + assert latex(i) == r"{}^{i}" + assert latex(-i) == r"{}_{i}" + + expr = A(i) + assert latex(expr) == r"A{}^{i}" + + expr = A(i0) + assert latex(expr) == r"A{}^{i_{0}}" + + expr = A(-i) + assert latex(expr) == r"A{}_{i}" + + expr = -3*A(i) + assert latex(expr) == r"-3A{}^{i}" + + expr = K(i, j, -k, -i0) + assert latex(expr) == r"K{}^{ij}{}_{ki_{0}}" + + expr = K(i, -j, -k, i0) + assert latex(expr) == r"K{}^{i}{}_{jk}{}^{i_{0}}" + + expr = K(i, -j, k, -i0) + assert latex(expr) == r"K{}^{i}{}_{j}{}^{k}{}_{i_{0}}" + + expr = H(i, -j) + assert latex(expr) == r"H{}^{i}{}_{j}" + + expr = H(i, j) + assert latex(expr) == r"H{}^{ij}" + + expr = H(-i, -j) + assert latex(expr) == r"H{}_{ij}" + + expr = (1+x)*A(i) + assert latex(expr) == r"\left(x + 1\right)A{}^{i}" + + expr = H(i, -i) + assert latex(expr) == r"H{}^{L_{0}}{}_{L_{0}}" + + expr = H(i, -j)*A(j)*B(k) + assert latex(expr) == r"H{}^{i}{}_{L_{0}}A{}^{L_{0}}B{}^{k}" + + expr = A(i) + 3*B(i) + assert latex(expr) == r"3B{}^{i} + A{}^{i}" + + # Test ``TensorElement``: + from sympy.tensor.tensor import TensorElement + + expr = TensorElement(K(i, j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3,j,k=2,l}' + + expr = TensorElement(K(i, j, k, l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,jkl}' + + expr = TensorElement(K(i, -j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2,l}' + + expr = TensorElement(K(i, -j, k, -l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2}{}_{l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3, -k: 2}) + assert latex(expr) == r'K{}^{i=3,j}{}_{k=2,l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,j}{}_{kl}' + + expr = PartialDerivative(A(i), A(i)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}^{L_{0}}}}{A{}^{L_{0}}}" + + expr = PartialDerivative(A(-i), A(-j)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}_{j}}}{A{}_{i}}" + + expr = PartialDerivative(K(i, j, -k, -l), A(m), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}^{m}} \partial {A{}_{n}}}{K{}^{ij}{}_{kl}}" + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(A{}_{i} + B{}_{i}\right)}" + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(3A{}_{i}\right)}" + + +def test_multiline_latex(): + a, b, c, d, e, f = symbols('a b c d e f') + expr = -a + 2*b -3*c +4*d -5*e + expected = r"\begin{eqnarray}" + "\n"\ + r"f & = &- a \nonumber\\" + "\n"\ + r"& & + 2 b \nonumber\\" + "\n"\ + r"& & - 3 c \nonumber\\" + "\n"\ + r"& & + 4 d \nonumber\\" + "\n"\ + r"& & - 5 e " + "\n"\ + r"\end{eqnarray}" + assert multiline_latex(f, expr, environment="eqnarray") == expected + + expected2 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 2, environment="eqnarray") == expected2 + + expected3 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray") == expected3 + + expected3dots = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \dots\nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray", use_dots=True) == expected3dots + + expected3align = r'\begin{align*}' + '\n'\ + r'f = &- a + 2 b - 3 c \\'+ '\n'\ + r'& + 4 d - 5 e ' + '\n'\ + r'\end{align*}' + + assert multiline_latex(f, expr, 3) == expected3align + assert multiline_latex(f, expr, 3, environment='align*') == expected3align + + expected2ieee = r'\begin{IEEEeqnarray}{rCl}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{IEEEeqnarray}' + + assert multiline_latex(f, expr, 2, environment="IEEEeqnarray") == expected2ieee + + raises(ValueError, lambda: multiline_latex(f, expr, environment="foo")) + +def test_issue_15353(): + a, x = symbols('a x') + # Obtained from nonlinsolve([(sin(a*x)),cos(a*x)],[x,a]) + sol = ConditionSet( + Tuple(x, a), Eq(sin(a*x), 0) & Eq(cos(a*x), 0), S.Complexes**2) + assert latex(sol) == \ + r'\left\{\left( x, \ a\right)\; \middle|\; \left( x, \ a\right) \in ' \ + r'\mathbb{C}^{2} \wedge \sin{\left(a x \right)} = 0 \wedge ' \ + r'\cos{\left(a x \right)} = 0 \right\}' + + +def test_latex_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert latex(Expectation(X)) == r'\operatorname{E}\left[X\right]' + assert latex(Variance(X)) == r'\operatorname{Var}\left(X\right)' + assert latex(Probability(X > 0)) == r'\operatorname{P}\left(X > 0\right)' + Y = Normal("Y", mu, sigma) + assert latex(Covariance(X, Y)) == r'\operatorname{Cov}\left(X, Y\right)' + + +def test_trace(): + # Issue 15303 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A)) == r"\operatorname{tr}\left(A \right)" + assert latex(trace(A**2)) == r"\operatorname{tr}\left(A^{2} \right)" + + +def test_print_basic(): + # Issue 15303 + from sympy.core.basic import Basic + from sympy.core.expr import Expr + + # dummy class for testing printing where the function is not + # implemented in latex.py + class UnimplementedExpr(Expr): + def __new__(cls, e): + return Basic.__new__(cls, e) + + # dummy function for testing + def unimplemented_expr(expr): + return UnimplementedExpr(expr).doit() + + # override class name to use superscript / subscript + def unimplemented_expr_sup_sub(expr): + result = UnimplementedExpr(expr) + result.__class__.__name__ = 'UnimplementedExpr_x^1' + return result + + assert latex(unimplemented_expr(x)) == r'\operatorname{UnimplementedExpr}\left(x\right)' + assert latex(unimplemented_expr(x**2)) == \ + r'\operatorname{UnimplementedExpr}\left(x^{2}\right)' + assert latex(unimplemented_expr_sup_sub(x)) == \ + r'\operatorname{UnimplementedExpr^{1}_{x}}\left(x\right)' + + +def test_MatrixSymbol_bold(): + # Issue #15871 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A), mat_symbol_style='bold') == \ + r"\operatorname{tr}\left(\mathbf{A} \right)" + assert latex(trace(A), mat_symbol_style='plain') == \ + r"\operatorname{tr}\left(A \right)" + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A, mat_symbol_style='bold') == r"- \mathbf{A}" + assert latex(A - A*B - B, mat_symbol_style='bold') == \ + r"\mathbf{A} - \mathbf{A} \mathbf{B} - \mathbf{B}" + assert latex(-A*B - A*B*C - B, mat_symbol_style='bold') == \ + r"- \mathbf{A} \mathbf{B} - \mathbf{A} \mathbf{B} \mathbf{C} - \mathbf{B}" + + A_k = MatrixSymbol("A_k", 3, 3) + assert latex(A_k, mat_symbol_style='bold') == r"\mathbf{A}_{k}" + + A = MatrixSymbol(r"\nabla_k", 3, 3) + assert latex(A, mat_symbol_style='bold') == r"\mathbf{\nabla}_{k}" + +def test_AppliedPermutation(): + p = Permutation(0, 1, 2) + x = Symbol('x') + assert latex(AppliedPermutation(p, x)) == \ + r'\sigma_{\left( 0\; 1\; 2\right)}(x)' + + +def test_PermutationMatrix(): + p = Permutation(0, 1, 2) + assert latex(PermutationMatrix(p)) == r'P_{\left( 0\; 1\; 2\right)}' + p = Permutation(0, 3)(1, 2) + assert latex(PermutationMatrix(p)) == \ + r'P_{\left( 0\; 3\right)\left( 1\; 2\right)}' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert latex(piecewise_fold(fo)) == '\\begin{cases} 2 \\sin{\\left(x \\right)}' \ + ' - \\sin{\\left(2 x \\right)} + \\frac{2 \\sin{\\left(3 x \\right)}}{3} +' \ + ' \\ldots & \\text{for}\\: n > -\\infty \\wedge n < \\infty \\wedge ' \ + 'n \\neq 0 \\\\0 & \\text{otherwise} \\end{cases}' + assert latex(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_imaginary_unit(): + assert latex(1 + I) == r'1 + i' + assert latex(1 + I, imaginary_unit='i') == r'1 + i' + assert latex(1 + I, imaginary_unit='j') == r'1 + j' + assert latex(1 + I, imaginary_unit='foo') == r'1 + foo' + assert latex(I, imaginary_unit="ti") == r'\text{i}' + assert latex(I, imaginary_unit="tj") == r'\text{j}' + + +def test_text_re_im(): + assert latex(im(x), gothic_re_im=True) == r'\Im{\left(x\right)}' + assert latex(im(x), gothic_re_im=False) == r'\operatorname{im}{\left(x\right)}' + assert latex(re(x), gothic_re_im=True) == r'\Re{\left(x\right)}' + assert latex(re(x), gothic_re_im=False) == r'\operatorname{re}{\left(x\right)}' + + +def test_latex_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField, Differential + from sympy.diffgeom.rn import R2 + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert latex(m) == r'\text{M}' + p = Patch('P', m) + assert latex(p) == r'\text{P}_{\text{M}}' + rect = CoordSystem('rect', p, [x, y]) + assert latex(rect) == r'\text{rect}^{\text{P}}_{\text{M}}' + b = BaseScalarField(rect, 0) + assert latex(b) == r'\mathbf{x}' + + g = Function('g') + s_field = g(R2.x, R2.y) + assert latex(Differential(s_field)) == \ + r'\operatorname{d}\left(g{\left(\mathbf{x},\mathbf{y} \right)}\right)' + + +def test_unit_printing(): + assert latex(5*meter) == r'5 \text{m}' + assert latex(3*gibibyte) == r'3 \text{gibibyte}' + assert latex(4*microgram/second) == r'\frac{4 \mu\text{g}}{\text{s}}' + assert latex(4*micro*gram/second) == r'\frac{4 \mu \text{g}}{\text{s}}' + assert latex(5*milli*meter) == r'5 \text{m} \text{m}' + assert latex(milli) == r'\text{m}' + + +def test_issue_17092(): + x_star = Symbol('x^*') + assert latex(Derivative(x_star, x_star,2)) == r'\frac{d^{2}}{d \left(x^{*}\right)^{2}} x^{*}' + + +def test_latex_decimal_separator(): + + x, y, z, t = symbols('x y z t') + k, m, n = symbols('k m n', integer=True) + f, g, h = symbols('f g h', cls=Function) + + # comma decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='comma') == r'\left[ 1; \ 2{,}3; \ 4{,}5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='comma') == r'\left\{1; 2{,}3; 4{,}5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'comma') == r'\left( 1; \ 2{,}3; \ 4{,}6\right)') + assert(latex((1,), decimal_separator='comma') == r'\left( 1;\right)') + + # period decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='period') == r'\left[ 1, \ 2.3, \ 4.5\right]' ) + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'period') == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,), decimal_separator='period') == r'\left( 1,\right)') + + # default decimal_separator + assert(latex([1, 2.3, 4.5]) == r'\left[ 1, \ 2.3, \ 4.5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5)) == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6)) == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,)) == r'\left( 1,\right)') + + assert(latex(Mul(3.4,5.3), decimal_separator = 'comma') == r'18{,}02') + assert(latex(3.4*5.3, decimal_separator = 'comma') == r'18{,}02') + x = symbols('x') + y = symbols('y') + z = symbols('z') + assert(latex(x*5.3 + 2**y**3.4 + 4.5 + z, decimal_separator = 'comma') == r'2^{y^{3{,}4}} + 5{,}3 x + z + 4{,}5') + + assert(latex(0.987, decimal_separator='comma') == r'0{,}987') + assert(latex(S(0.987), decimal_separator='comma') == r'0{,}987') + assert(latex(.3, decimal_separator='comma') == r'0{,}3') + assert(latex(S(.3), decimal_separator='comma') == r'0{,}3') + + + assert(latex(5.8*10**(-7), decimal_separator='comma') == r'5{,}8 \cdot 10^{-7}') + assert(latex(S(5.7)*10**(-7), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + assert(latex(S(5.7*10**(-7)), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + + x = symbols('x') + assert(latex(1.2*x+3.4, decimal_separator='comma') == r'1{,}2 x + 3{,}4') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + + # Error Handling tests + raises(ValueError, lambda: latex([1,2.3,4.5], decimal_separator='non_existing_decimal_separator_in_list')) + raises(ValueError, lambda: latex(FiniteSet(1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_set')) + raises(ValueError, lambda: latex((1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_tuple')) + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == r'x' + +def test_latex_escape(): + assert latex_escape(r"~^\&%$#_{}") == "".join([ + r'\textasciitilde', + r'\textasciicircum', + r'\textbackslash', + r'\&', + r'\%', + r'\$', + r'\#', + r'\_', + r'\{', + r'\}', + ]) + +def test_emptyPrinter(): + class MyObject: + def __repr__(self): + return "" + + # unknown objects are monospaced + assert latex(MyObject()) == r"\mathtt{\text{}}" + + # even if they are nested within other objects + assert latex((MyObject(),)) == r"\left( \mathtt{\text{}},\right)" + +def test_global_settings(): + import inspect + + # settings should be visible in the signature of `latex` + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + try: + # but changing the defaults... + LatexPrinter.set_global_settings(imaginary_unit='j') + # ... should change the signature + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'j' + assert latex(I) == r'j' + finally: + # there's no public API to undo this, but we need to make sure we do + # so as not to impact other tests + del LatexPrinter._global_settings['imaginary_unit'] + + # check we really did undo it + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + +def test_pickleable(): + # this tests that the _PrintFunction instance is pickleable + import pickle + assert pickle.loads(pickle.dumps(latex)) is latex + +def test_printing_latex_array_expressions(): + assert latex(ArraySymbol("A", (2, 3, 4))) == "A" + assert latex(ArrayElement("A", (2, 1/(1-x), 0))) == "{{A}_{2, \\frac{1}{1 - x}, 0}}" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert latex(ArrayElement(M*N, [x, 0])) == "{{\\left(M N\\right)}_{x, 0}}" + +def test_Array(): + arr = Array(range(10)) + assert latex(arr) == r'\left[\begin{matrix}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9\end{matrix}\right]' + + arr = Array(range(11)) + # fill the empty argument with a bunch of 'c' to avoid latex errors + assert latex(arr) == r'\left[\begin{array}{ccccccccccc}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + +def test_latex_with_unevaluated(): + with evaluate(False): + assert latex(a * a) == r"a a" + + +def test_latex_disable_split_super_sub(): + assert latex(Symbol('u^a_b')) == 'u^{a}_{b}' + assert latex(Symbol('u^a_b'), disable_split_super_sub=False) == 'u^{a}_{b}' + assert latex(Symbol('u^a_b'), disable_split_super_sub=True) == 'u\\^a\\_b' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_llvmjit.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_llvmjit.py new file mode 100644 index 0000000000000000000000000000000000000000..709476f1d7517dc629210341594a70dc6f41808f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_llvmjit.py @@ -0,0 +1,224 @@ +from sympy.external import import_module +from sympy.testing.pytest import raises +import ctypes + + +if import_module('llvmlite'): + import sympy.printing.llvmjitcode as g +else: + disabled = True + +import sympy +from sympy.abc import a, b, n + + +# copied from numpy.isclose documentation +def isclose(a, b): + rtol = 1e-5 + atol = 1e-8 + return abs(a-b) <= atol + rtol*abs(b) + + +def test_simple_expr(): + e = a + 1.0 + f = g.llvm_callable([a], e) + res = float(e.subs({a: 4.0}).evalf()) + jit_res = f(4.0) + + assert isclose(jit_res, res) + + +def test_two_arg(): + e = 4.0*a + b + 3.0 + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 4.0, b: 3.0}).evalf()) + jit_res = f(4.0, 3.0) + + assert isclose(jit_res, res) + + +def test_func(): + e = 4.0*sympy.exp(-a) + f = g.llvm_callable([a], e) + res = float(e.subs({a: 1.5}).evalf()) + jit_res = f(1.5) + + assert isclose(jit_res, res) + + +def test_two_func(): + e = 4.0*sympy.exp(-a) + sympy.exp(b) + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_two_sqrt(): + e = 4.0*sympy.sqrt(a) + sympy.sqrt(b) + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_two_pow(): + e = a**1.5 + b**7 + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_callback(): + e = a + 1.2 + f = g.llvm_callable([a], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + inp = {a: 2.2} + array = array_type(inp[a]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_callback_cubature(): + e = a + 1.2 + f = g.llvm_callable([a], e, callback_type='cubature') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + inp = {a: 2.2} + array = array_type(inp[a]) + out_array = array_type(0.0) + jit_ret = f(m, array, None, m, out_array) + + assert jit_ret == 0 + + res = float(e.subs(inp).evalf()) + + assert isclose(out_array[0], res) + + +def test_callback_two(): + e = 3*a*b + f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(2) + array_type = ctypes.c_double * 2 + inp = {a: 0.2, b: 1.7} + array = array_type(inp[a], inp[b]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_callback_alt_two(): + d = sympy.IndexedBase('d') + e = 3*d[0]*d[1] + f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(2) + array_type = ctypes.c_double * 2 + inp = {d[0]: 0.2, d[1]: 1.7} + array = array_type(inp[d[0]], inp[d[1]]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_multiple_statements(): + # Match return from CSE + e = [[(b, 4.0*a)], [b + 5]] + f = g.llvm_callable([a], e) + b_val = e[0][0][1].subs({a: 1.5}) + res = float(e[1][0].subs({b: b_val}).evalf()) + jit_res = f(1.5) + assert isclose(jit_res, res) + + f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + array = array_type(1.5) + jit_callback_res = f_callback(m, array) + assert isclose(jit_callback_res, res) + + +def test_cse(): + e = a*a + b*b + sympy.exp(-a*a - b*b) + e2 = sympy.cse(e) + f = g.llvm_callable([a, b], e2) + res = float(e.subs({a: 2.3, b: 0.1}).evalf()) + jit_res = f(2.3, 0.1) + + assert isclose(jit_res, res) + + +def eval_cse(e, sub_dict): + tmp_dict = {} + for tmp_name, tmp_expr in e[0]: + e2 = tmp_expr.subs(sub_dict) + e3 = e2.subs(tmp_dict) + tmp_dict[tmp_name] = e3 + return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]] + + +def test_cse_multiple(): + e1 = a*a + e2 = a*a + b*b + e3 = sympy.cse([e1, e2]) + + raises(NotImplementedError, + lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate')) + + f = g.llvm_callable([a, b], e3) + jit_res = f(0.1, 1.5) + assert len(jit_res) == 2 + res = eval_cse(e3, {a: 0.1, b: 1.5}) + assert isclose(res[0], jit_res[0]) + assert isclose(res[1], jit_res[1]) + + +def test_callback_cubature_multiple(): + e1 = a*a + e2 = a*a + b*b + e3 = sympy.cse([e1, e2, 4*e2]) + f = g.llvm_callable([a, b], e3, callback_type='cubature') + + # Number of input variables + ndim = 2 + # Number of output expression values + outdim = 3 + + m = ctypes.c_int(ndim) + fdim = ctypes.c_int(outdim) + array_type = ctypes.c_double * ndim + out_array_type = ctypes.c_double * outdim + inp = {a: 0.2, b: 1.5} + array = array_type(inp[a], inp[b]) + out_array = out_array_type() + jit_ret = f(m, array, None, fdim, out_array) + + assert jit_ret == 0 + + res = eval_cse(e3, inp) + + assert isclose(out_array[0], res[0]) + assert isclose(out_array[1], res[1]) + assert isclose(out_array[2], res[2]) + + +def test_symbol_not_found(): + e = a*a + b + raises(LookupError, lambda: g.llvm_callable([a], e)) + + +def test_bad_callback(): + e = a + raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback')) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_maple.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_maple.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb4c512ad3203bd64ae56b350e15734b3a6afb0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_maple.py @@ -0,0 +1,381 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import besseli + +from sympy.printing.maple import maple_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert maple_code(Integer(67)) == "67" + assert maple_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert maple_code(Rational(3, 7)) == "3/7" + assert maple_code(Rational(18, 9)) == "2" + assert maple_code(Rational(3, -7)) == "-3/7" + assert maple_code(Rational(-3, -7)) == "3/7" + assert maple_code(x + Rational(3, 7)) == "x + 3/7" + assert maple_code(Rational(3, 7) * x) == '(3/7)*x' + + +def test_Relational(): + assert maple_code(Eq(x, y)) == "x = y" + assert maple_code(Ne(x, y)) == "x <> y" + assert maple_code(Le(x, y)) == "x <= y" + assert maple_code(Lt(x, y)) == "x < y" + assert maple_code(Gt(x, y)) == "x > y" + assert maple_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert maple_code(abs(x)) == "abs(x)" + assert maple_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert maple_code(x ** 3) == "x^3" + assert maple_code(x ** (y ** 3)) == "x^(y^3)" + + assert maple_code((x ** 3) ** y) == "(x^3)^y" + assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)' + + g = implemented_function('g', Lambda(x, 2 * x)) + assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + # For issue 14160 + assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_basic_ops(): + assert maple_code(x * y) == "x*y" + assert maple_code(x + y) == "x + y" + assert maple_code(x - y) == "x - y" + assert maple_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert maple_code(1 / x) == '1/x' + assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x' + assert maple_code(1 / sqrt(x)) == '1/sqrt(x)' + assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)' + assert maple_code(sqrt(x)) == 'sqrt(x)' + assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)' + assert maple_code(1 / pi) == '1/Pi' + assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi' + assert maple_code(pi ** -0.5) == '1/sqrt(Pi)' + + +def test_mix_number_mult_symbols(): + assert maple_code(3 * x) == "3*x" + assert maple_code(pi * x) == "Pi*x" + assert maple_code(3 / x) == "3/x" + assert maple_code(pi / x) == "Pi/x" + assert maple_code(x / 3) == '(1/3)*x' + assert maple_code(x / pi) == "x/Pi" + assert maple_code(x * y) == "x*y" + assert maple_code(3 * x * y) == "3*x*y" + assert maple_code(3 * pi * x * y) == "3*Pi*x*y" + assert maple_code(x / y) == "x/y" + assert maple_code(3 * x / y) == "3*x/y" + assert maple_code(x * y / z) == "x*y/z" + assert maple_code(x / y * z) == "x*z/y" + assert maple_code(1 / x / y) == "1/(x*y)" + assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)" + assert maple_code(3 * pi / x) == "3*Pi/x" + assert maple_code(S(3) / 5) == "3/5" + assert maple_code(S(3) / 5 * x) == '(3/5)*x' + assert maple_code(x / y / z) == "x/(y*z)" + assert maple_code((x + y) / z) == "(x + y)/z" + assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)" + assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma' + assert maple_code(x / 3 / pi) == '(1/3)*x/Pi' + assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi' + + +def test_mix_number_pow_symbols(): + assert maple_code(pi ** 3) == 'Pi^3' + assert maple_code(x ** 2) == 'x^2' + + assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)' + assert maple_code(x ** y) == 'x^y' + + assert maple_code(x ** (y ** z)) == 'x^(y^z)' + assert maple_code((x ** y) ** z) == '(x^y)^z' + + +def test_imag(): + I = S('I') + assert maple_code(I) == "I" + assert maple_code(5 * I) == "5*I" + + assert maple_code((S(3) / 2) * I) == "(3/2)*I" + assert maple_code(3 + 4 * I) == "3 + 4*I" + + +def test_constants(): + assert maple_code(pi) == "Pi" + assert maple_code(oo) == "infinity" + assert maple_code(-oo) == "-infinity" + assert maple_code(S.NegativeInfinity) == "-infinity" + assert maple_code(S.NaN) == "undefined" + assert maple_code(S.Exp1) == "exp(1)" + assert maple_code(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))' + assert maple_code(2 * Catalan) == '2*Catalan' + assert maple_code(2 * EulerGamma) == "2*gamma" + + +def test_boolean(): + assert maple_code(x & y) == "x and y" + assert maple_code(x | y) == "x or y" + assert maple_code(~x) == "not x" + assert maple_code(x & y & z) == "x and y and z" + assert maple_code(x | y | z) == "x or y or z" + assert maple_code((x & y) | z) == "z or x and y" + assert maple_code((x | y) & z) == "z and (x or y)" + + +def test_Matrices(): + assert maple_code(Matrix(1, 1, [10])) == \ + 'Matrix([[10]], storage = rectangular)' + + A = Matrix([[1, sin(x / 2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = \ + 'Matrix(' \ + '[[1, sin((1/2)*x), abs(x)],' \ + ' [0, 1, Pi],' \ + ' [0, exp(1), ceil(x)]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + # row and columns + assert maple_code(A[:, 0]) == \ + 'Matrix([[1], [0], [0]], storage = rectangular)' + assert maple_code(A[0, :]) == \ + 'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)' + assert maple_code(Matrix([[x, x - y, -y]])) == \ + 'Matrix([[x, x - y, -y]], storage = rectangular)' + + # empty matrices + assert maple_code(Matrix(0, 0, [])) == \ + 'Matrix([], storage = rectangular)' + assert maple_code(Matrix(0, 3, [])) == \ + 'Matrix([], storage = rectangular)' + +def test_SparseMatrices(): + assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)' + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]]) + assert maple_code(A) == \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)' + assert maple_code(A.T) == \ + 'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)' + + +def test_Matrices_entries_not_hadamard(): + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]]) + expected = \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert maple_code(A * B) == "A.B" + assert maple_code(B * A) == "B.A" + assert maple_code(2 * A * B) == "2*A.B" + assert maple_code(B * 2 * A) == "2*B.A" + + assert maple_code( + A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)" + + assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)" + assert maple_code(A ** 3) == "MatrixPower(A, 3)" + assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)" + + +def test_special_matrices(): + assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)" + assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)' + + +def test_containers(): + assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]" + + assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]" + assert maple_code([1]) == "[1]" + assert maple_code((1,)) == "[1]" + assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]" + assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]" + # scalar, matrix, empty matrix and empty list + + assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \ + "[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]" + + +def test_maple_noninline(): + source = maple_code((x + y)/Catalan, assign_to='me', inline=False) + expected = "me := (x + y)/Catalan" + + assert source == expected + + +def test_maple_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)" + A = Matrix([[1, 2], [3, 4]]) + assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)" + + +def test_maple_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)" + raises(ValueError, lambda: maple_code(A, assign_to=x)) + raises(ValueError, lambda: maple_code(A, assign_to=C)) + + +def test_maple_matrix_1x1(): + A = Matrix([[3]]) + assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)" + + +def test_maple_matrix_elements(): + A = Matrix([[x, 2, x * y]]) + + assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2" + AA = MatrixSymbol('AA', 1, 3) + assert maple_code(AA) == "AA" + + assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \ + "sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]" + assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]" + + +def test_maple_boolean(): + assert maple_code(True) == "true" + assert maple_code(S.true) == "true" + assert maple_code(False) == "false" + assert maple_code(S.false) == "false" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x * y + assert maple_code(M) == \ + 'Matrix([[0, 0, 0, 30, 0, 0],' \ + ' [0, 0, 20, 22, 0, 0],' \ + ' [0, 0, 10, 0, 0, 0],' \ + ' [x*y, 0, 0, 0, 0, 0],' \ + ' [0, 0, 0, 0, 0, 0]], ' \ + 'storage = sparse)' + +# Not an important point. +def test_maple_not_supported(): + with raises(NotImplementedError): + maple_code(S.ComplexInfinity) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + + assert (maple_code(A[0, 0]) == "A[1, 1]") + assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]") + + F = A-B + + assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert maple_code(C) == "A*B" + + assert maple_code(C * v) == "(A*B).v" + # HadamardProduct is higher than dot product. + + assert maple_code(h * C * v) == "h.(A*B).v" + + assert maple_code(C * A) == "(A*B).A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + + assert maple_code(C * x * y) == "x*y*(A*B)" + + +def test_maple_piecewise(): + expr = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(expr) == "piecewise(x < 1, x, x^2)" + assert maple_code(expr, assign_to="r") == ( + "r := piecewise(x < 1, x, x^2)") + + expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True)) + expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)" + assert maple_code(expr) == expected + assert maple_code(expr, assign_to="r") == "r := " + expected + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: maple_code(expr)) + + +def test_maple_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)" + assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x" + assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)" + assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)" + + +def test_maple_derivatives(): + f = Function('f') + assert maple_code(f(x).diff(x)) == 'diff(f(x), x)' + assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)' + + +def test_automatic_rewrites(): + assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))' + assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))' + + +def test_specfun(): + assert maple_code('asin(x)') == 'arcsin(x)' + assert maple_code(besseli(x, y)) == 'BesselI(x, y)' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_mathematica.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf6b537677442ae59a4f1bbd2b5774d6646f4e2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_mathematica.py @@ -0,0 +1,287 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple, + Derivative, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.integrals import Integral +from sympy.concrete import Sum +from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max, + Min, gamma, polygamma, loggamma, erf, erfi, erfc, + erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li, + Shi, Chi, uppergamma, beta, subfactorial, erf2inv, + factorial, factorial2, catalan, RisingFactorial, + FallingFactorial, harmonic, atan2, sec, acsc, + hermite, laguerre, assoc_laguerre, jacobi, + gegenbauer, chebyshevt, chebyshevu, legendre, + assoc_legendre, Li, LambertW) + +from sympy.printing.mathematica import mathematica_code as mcode + +x, y, z, w = symbols('x,y,z,w') +f = Function('f') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "(3/7)*x" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(f(x, y, z)) == "f[x, y, z]" + assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]" + assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]" + assert mcode(atan2(y, x)) == "ArcTan[x, y]" + assert mcode(conjugate(x)) == "Conjugate[x]" + assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]" + assert mcode(fresnelc(x)) == "FresnelC[x]" + assert mcode(fresnels(x)) == "FresnelS[x]" + assert mcode(gamma(x)) == "Gamma[x]" + assert mcode(uppergamma(x, y)) == "Gamma[x, y]" + assert mcode(polygamma(x, y)) == "PolyGamma[x, y]" + assert mcode(loggamma(x)) == "LogGamma[x]" + assert mcode(erf(x)) == "Erf[x]" + assert mcode(erfc(x)) == "Erfc[x]" + assert mcode(erfi(x)) == "Erfi[x]" + assert mcode(erf2(x, y)) == "Erf[x, y]" + assert mcode(expint(x, y)) == "ExpIntegralE[x, y]" + assert mcode(erfcinv(x)) == "InverseErfc[x]" + assert mcode(erfinv(x)) == "InverseErf[x]" + assert mcode(erf2inv(x, y)) == "InverseErf[x, y]" + assert mcode(Ei(x)) == "ExpIntegralEi[x]" + assert mcode(Ci(x)) == "CosIntegral[x]" + assert mcode(li(x)) == "LogIntegral[x]" + assert mcode(Si(x)) == "SinIntegral[x]" + assert mcode(Shi(x)) == "SinhIntegral[x]" + assert mcode(Chi(x)) == "CoshIntegral[x]" + assert mcode(beta(x, y)) == "Beta[x, y]" + assert mcode(factorial(x)) == "Factorial[x]" + assert mcode(factorial2(x)) == "Factorial2[x]" + assert mcode(subfactorial(x)) == "Subfactorial[x]" + assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]" + assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]" + assert mcode(catalan(x)) == "CatalanNumber[x]" + assert mcode(harmonic(x)) == "HarmonicNumber[x]" + assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]" + assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]" + assert mcode(LambertW(x)) == "ProductLog[x]" + assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]" + assert mcode(LambertW(x, y)) == "ProductLog[y, x]" + + +def test_special_polynomials(): + assert mcode(hermite(x, y)) == "HermiteH[x, y]" + assert mcode(laguerre(x, y)) == "LaguerreL[x, y]" + assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]" + assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]" + assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]" + assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]" + assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]" + assert mcode(legendre(x, y)) == "LegendreP[x, y]" + assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]" + + +def test_Pow(): + assert mcode(x**3) == "x^3" + assert mcode(x**(y**3)) == "x^(y^3)" + assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*f[x])^(-x + y^x)/(x^2 + y)" + assert mcode(x**-1.0) == 'x^(-1.0)' + assert mcode(x**Rational(2, 3)) == 'x^(2/3)' + + +def test_Mul(): + A, B, C, D = symbols('A B C D', commutative=False) + assert mcode(x*y*z) == "x*y*z" + assert mcode(x*y*A) == "x*y*A" + assert mcode(x*y*A*B) == "x*y*A**B" + assert mcode(x*y*A*B*C) == "x*y*A**B**C" + assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A" + + +def test_constants(): + assert mcode(S.Zero) == "0" + assert mcode(S.One) == "1" + assert mcode(S.NegativeOne) == "-1" + assert mcode(S.Half) == "1/2" + assert mcode(S.ImaginaryUnit) == "I" + + assert mcode(oo) == "Infinity" + assert mcode(S.NegativeInfinity) == "-Infinity" + assert mcode(S.ComplexInfinity) == "ComplexInfinity" + assert mcode(S.NaN) == "Indeterminate" + + assert mcode(S.Exp1) == "E" + assert mcode(pi) == "Pi" + assert mcode(S.GoldenRatio) == "GoldenRatio" + assert mcode(S.TribonacciConstant) == \ + "(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(2*S.TribonacciConstant) == \ + "2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(S.EulerGamma) == "EulerGamma" + assert mcode(S.Catalan) == "Catalan" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + + +def test_matrices(): + from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \ + ImmutableDenseMatrix, ImmutableSparseMatrix + A = MutableDenseMatrix( + [[1, -1, 0, 0], + [0, 1, -1, 0], + [0, 0, 1, -1], + [0, 0, 0, 1]] + ) + B = MutableSparseMatrix(A) + C = ImmutableDenseMatrix(A) + D = ImmutableSparseMatrix(A) + + assert mcode(C) == mcode(A) == \ + "{{1, -1, 0, 0}, " \ + "{0, 1, -1, 0}, " \ + "{0, 0, 1, -1}, " \ + "{0, 0, 0, 1}}" + + assert mcode(D) == mcode(B) == \ + "SparseArray[{" \ + "{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \ + "{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \ + "}, {4, 4}]" + + # Trivial cases of matrices + assert mcode(MutableDenseMatrix(0, 0, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]' + assert mcode(MutableDenseMatrix(0, 3, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]' + assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}' + assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]' + +def test_NDArray(): + from sympy.tensor.array import ( + MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray) + + example = MutableDenseNDimArray( + [[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], + [[13, 14, 15, 16], + [17, 18, 19, 20], + [21, 22, 23, 24]]] + ) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = ImmutableDenseNDimArray(example) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = MutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + example = ImmutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + +def test_Integral(): + assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]" + assert mcode(Integral(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_Derivative(): + assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]" + assert mcode(Derivative(x, x)) == "Hold[D[x, x]]" + assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]" + assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]" + assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]" + + +def test_Sum(): + assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]" + assert mcode(Sum(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_comment(): + from sympy.printing.mathematica import MCodePrinter + assert MCodePrinter()._get_comment("Hello World") == \ + "(* Hello World *)" + + +def test_userfuncs(): + # Dictionary mutation test + some_function = symbols("some_function", cls=Function) + my_user_functions = {"some_function": "SomeFunction"} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + + # List argument test + my_user_functions = \ + {"some_function": [(lambda x: True, "SomeOtherFunction")]} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeOtherFunction[z]' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_mathml.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_mathml.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7c2253c98fb1a4e99375774ad158df9b80b439 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_mathml.py @@ -0,0 +1,2048 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Derivative, Lambda, diff, Function +from sympy.core.numbers import (zoo, Float, Integer, I, oo, pi, E, + Rational) +from sympy.core.relational import Lt, Ge, Ne, Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols, Symbol +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (factorial2, + binomial, factorial) +from sympy.functions.combinatorial.numbers import (lucas, bell, + catalan, euler, tribonacci, fibonacci, bernoulli, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import re, im, conjugate, Abs +from sympy.functions.elementary.exponential import exp, LambertW, log +from sympy.functions.elementary.hyperbolic import (tanh, acoth, atanh, + coth, asinh, acsch, asech, acosh, csch, sinh, cosh, sech) +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.functions.elementary.trigonometric import (csc, sec, tan, + atan, sin, asec, cot, cos, acot, acsc, asin, acos) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.elliptic_integrals import (elliptic_pi, + elliptic_f, elliptic_k, elliptic_e) +from sympy.functions.special.error_functions import (fresnelc, + fresnels, Ei, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma, + lowergamma) +from sympy.functions.special.mathieu_functions import (mathieusprime, + mathieus, mathieucprime, mathieuc) +from sympy.functions.special.polynomials import (jacobi, chebyshevu, + chebyshevt, hermite, assoc_legendre, gegenbauer, assoc_laguerre, + legendre, laguerre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.zeta_functions import (polylog, stieltjes, + lerchphi, dirichlet_eta, zeta) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Xor, Or, false, true, And, Equivalent, + Implies, Not) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.determinant import Determinant +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.physics.quantum import (ComplexSpace, FockSpace, hbar, + HilbertSpace, Dagger) +from sympy.printing.mathml import (MathMLPresentationPrinter, + MathMLPrinter, MathMLContentPrinter, mathml) +from sympy.series.limits import Limit +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Interval, Union, SymmetricDifference, + Complement, FiniteSet, Intersection, ProductSet) +from sympy.stats.rv import RandomSymbol +from sympy.tensor.indexed import IndexedBase +from sympy.vector import (Divergence, CoordSys3D, Cross, Curl, Dot, + Laplacian, Gradient) +from sympy.testing.pytest import raises, XFAIL + +x, y, z, a, b, c, d, e, n = symbols('x:z a:e n') +mp = MathMLContentPrinter() +mpp = MathMLPresentationPrinter() + + +def test_mathml_printer(): + m = MathMLPrinter() + assert m.doprint(1+x) == mp.doprint(1+x) + + +def test_content_printmethod(): + assert mp.doprint(1 + x) == 'x1' + + +def test_content_mathml_core(): + mml_1 = mp._print(1 + x) + assert mml_1.nodeName == 'apply' + nodes = mml_1.childNodes + assert len(nodes) == 3 + assert nodes[0].nodeName == 'plus' + assert nodes[0].hasChildNodes() is False + assert nodes[0].nodeValue is None + assert nodes[1].nodeName in ['cn', 'ci'] + if nodes[1].nodeName == 'cn': + assert nodes[1].childNodes[0].nodeValue == '1' + assert nodes[2].childNodes[0].nodeValue == 'x' + else: + assert nodes[1].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mp._print(x**2) + assert mml_2.nodeName == 'apply' + nodes = mml_2.childNodes + assert nodes[1].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '2' + + mml_3 = mp._print(2*x) + assert mml_3.nodeName == 'apply' + nodes = mml_3.childNodes + assert nodes[0].nodeName == 'times' + assert nodes[1].childNodes[0].nodeValue == '2' + assert nodes[2].childNodes[0].nodeValue == 'x' + + mml = mp._print(Float(1.0, 2)*x) + assert mml.nodeName == 'apply' + nodes = mml.childNodes + assert nodes[0].nodeName == 'times' + assert nodes[1].childNodes[0].nodeValue == '1.0' + assert nodes[2].childNodes[0].nodeValue == 'x' + + +def test_content_mathml_functions(): + mml_1 = mp._print(sin(x)) + assert mml_1.nodeName == 'apply' + assert mml_1.childNodes[0].nodeName == 'sin' + assert mml_1.childNodes[1].nodeName == 'ci' + + mml_2 = mp._print(diff(sin(x), x, evaluate=False)) + assert mml_2.nodeName == 'apply' + assert mml_2.childNodes[0].nodeName == 'diff' + assert mml_2.childNodes[1].nodeName == 'bvar' + assert mml_2.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + + mml_3 = mp._print(diff(cos(x*y), x, evaluate=False)) + assert mml_3.nodeName == 'apply' + assert mml_3.childNodes[0].nodeName == 'partialdiff' + assert mml_3.childNodes[1].nodeName == 'bvar' + assert mml_3.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + + mml_4 = mp._print(Lambda((x, y), x * y)) + assert mml_4.nodeName == 'lambda' + assert mml_4.childNodes[0].nodeName == 'bvar' + assert mml_4.childNodes[0].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + assert mml_4.childNodes[1].nodeName == 'bvar' + assert mml_4.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's y/ci> + assert mml_4.childNodes[2].nodeName == 'apply' + + +def test_content_mathml_limits(): + # XXX No unevaluated limits + lim_fun = sin(x)/x + mml_1 = mp._print(Limit(lim_fun, x, 0)) + assert mml_1.childNodes[0].nodeName == 'limit' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].toxml() == mp._print(lim_fun).toxml() + + +def test_content_mathml_integrals(): + integrand = x + mml_1 = mp._print(Integral(integrand, (x, 0, 1))) + assert mml_1.childNodes[0].nodeName == 'int' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].nodeName == 'uplimit' + assert mml_1.childNodes[4].toxml() == mp._print(integrand).toxml() + + +def test_content_mathml_matrices(): + A = Matrix([1, 2, 3]) + B = Matrix([[0, 5, 4], [2, 3, 1], [9, 7, 9]]) + mll_1 = mp._print(A) + assert mll_1.childNodes[0].nodeName == 'matrixrow' + assert mll_1.childNodes[0].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[0].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_1.childNodes[1].nodeName == 'matrixrow' + assert mll_1.childNodes[1].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_1.childNodes[2].nodeName == 'matrixrow' + assert mll_1.childNodes[2].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[2].childNodes[0].childNodes[0].nodeValue == '3' + mll_2 = mp._print(B) + assert mll_2.childNodes[0].nodeName == 'matrixrow' + assert mll_2.childNodes[0].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[0].childNodes[0].nodeValue == '0' + assert mll_2.childNodes[0].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[1].childNodes[0].nodeValue == '5' + assert mll_2.childNodes[0].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[2].childNodes[0].nodeValue == '4' + assert mll_2.childNodes[1].nodeName == 'matrixrow' + assert mll_2.childNodes[1].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_2.childNodes[1].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[1].childNodes[0].nodeValue == '3' + assert mll_2.childNodes[1].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[2].childNodes[0].nodeValue == '1' + assert mll_2.childNodes[2].nodeName == 'matrixrow' + assert mll_2.childNodes[2].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[0].childNodes[0].nodeValue == '9' + assert mll_2.childNodes[2].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[1].childNodes[0].nodeValue == '7' + assert mll_2.childNodes[2].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[2].childNodes[0].nodeValue == '9' + + +def test_content_mathml_sums(): + summand = x + mml_1 = mp._print(Sum(summand, (x, 1, 10))) + assert mml_1.childNodes[0].nodeName == 'sum' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].nodeName == 'uplimit' + assert mml_1.childNodes[4].toxml() == mp._print(summand).toxml() + + +def test_content_mathml_tuples(): + mml_1 = mp._print([2]) + assert mml_1.nodeName == 'list' + assert mml_1.childNodes[0].nodeName == 'cn' + assert len(mml_1.childNodes) == 1 + + mml_2 = mp._print([2, Integer(1)]) + assert mml_2.nodeName == 'list' + assert mml_2.childNodes[0].nodeName == 'cn' + assert mml_2.childNodes[1].nodeName == 'cn' + assert len(mml_2.childNodes) == 2 + + +def test_content_mathml_add(): + mml = mp._print(x**5 - x**4 + x) + assert mml.childNodes[0].nodeName == 'plus' + assert mml.childNodes[1].childNodes[0].nodeName == 'minus' + assert mml.childNodes[1].childNodes[1].nodeName == 'apply' + + +def test_content_mathml_Rational(): + mml_1 = mp._print(Rational(1, 1)) + """should just return a number""" + assert mml_1.nodeName == 'cn' + + mml_2 = mp._print(Rational(2, 5)) + assert mml_2.childNodes[0].nodeName == 'divide' + + +def test_content_mathml_constants(): + mml = mp._print(I) + assert mml.nodeName == 'imaginaryi' + + mml = mp._print(E) + assert mml.nodeName == 'exponentiale' + + mml = mp._print(oo) + assert mml.nodeName == 'infinity' + + mml = mp._print(pi) + assert mml.nodeName == 'pi' + + assert mathml(hbar) == '' + assert mathml(S.TribonacciConstant) == '' + assert mathml(S.GoldenRatio) == 'φ' + mml = mathml(S.EulerGamma) + assert mml == '' + + mml = mathml(S.EmptySet) + assert mml == '' + + mml = mathml(S.true) + assert mml == '' + + mml = mathml(S.false) + assert mml == '' + + mml = mathml(S.NaN) + assert mml == '' + + +def test_content_mathml_trig(): + mml = mp._print(sin(x)) + assert mml.childNodes[0].nodeName == 'sin' + + mml = mp._print(cos(x)) + assert mml.childNodes[0].nodeName == 'cos' + + mml = mp._print(tan(x)) + assert mml.childNodes[0].nodeName == 'tan' + + mml = mp._print(cot(x)) + assert mml.childNodes[0].nodeName == 'cot' + + mml = mp._print(csc(x)) + assert mml.childNodes[0].nodeName == 'csc' + + mml = mp._print(sec(x)) + assert mml.childNodes[0].nodeName == 'sec' + + mml = mp._print(asin(x)) + assert mml.childNodes[0].nodeName == 'arcsin' + + mml = mp._print(acos(x)) + assert mml.childNodes[0].nodeName == 'arccos' + + mml = mp._print(atan(x)) + assert mml.childNodes[0].nodeName == 'arctan' + + mml = mp._print(acot(x)) + assert mml.childNodes[0].nodeName == 'arccot' + + mml = mp._print(acsc(x)) + assert mml.childNodes[0].nodeName == 'arccsc' + + mml = mp._print(asec(x)) + assert mml.childNodes[0].nodeName == 'arcsec' + + mml = mp._print(sinh(x)) + assert mml.childNodes[0].nodeName == 'sinh' + + mml = mp._print(cosh(x)) + assert mml.childNodes[0].nodeName == 'cosh' + + mml = mp._print(tanh(x)) + assert mml.childNodes[0].nodeName == 'tanh' + + mml = mp._print(coth(x)) + assert mml.childNodes[0].nodeName == 'coth' + + mml = mp._print(csch(x)) + assert mml.childNodes[0].nodeName == 'csch' + + mml = mp._print(sech(x)) + assert mml.childNodes[0].nodeName == 'sech' + + mml = mp._print(asinh(x)) + assert mml.childNodes[0].nodeName == 'arcsinh' + + mml = mp._print(atanh(x)) + assert mml.childNodes[0].nodeName == 'arctanh' + + mml = mp._print(acosh(x)) + assert mml.childNodes[0].nodeName == 'arccosh' + + mml = mp._print(acoth(x)) + assert mml.childNodes[0].nodeName == 'arccoth' + + mml = mp._print(acsch(x)) + assert mml.childNodes[0].nodeName == 'arccsch' + + mml = mp._print(asech(x)) + assert mml.childNodes[0].nodeName == 'arcsech' + + +def test_content_mathml_relational(): + mml_1 = mp._print(Eq(x, 1)) + assert mml_1.nodeName == 'apply' + assert mml_1.childNodes[0].nodeName == 'eq' + assert mml_1.childNodes[1].nodeName == 'ci' + assert mml_1.childNodes[1].childNodes[0].nodeValue == 'x' + assert mml_1.childNodes[2].nodeName == 'cn' + assert mml_1.childNodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mp._print(Ne(1, x)) + assert mml_2.nodeName == 'apply' + assert mml_2.childNodes[0].nodeName == 'neq' + assert mml_2.childNodes[1].nodeName == 'cn' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_2.childNodes[2].nodeName == 'ci' + assert mml_2.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_3 = mp._print(Ge(1, x)) + assert mml_3.nodeName == 'apply' + assert mml_3.childNodes[0].nodeName == 'geq' + assert mml_3.childNodes[1].nodeName == 'cn' + assert mml_3.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_3.childNodes[2].nodeName == 'ci' + assert mml_3.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_4 = mp._print(Lt(1, x)) + assert mml_4.nodeName == 'apply' + assert mml_4.childNodes[0].nodeName == 'lt' + assert mml_4.childNodes[1].nodeName == 'cn' + assert mml_4.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_4.childNodes[2].nodeName == 'ci' + assert mml_4.childNodes[2].childNodes[0].nodeValue == 'x' + + +def test_content_symbol(): + mml = mp._print(x) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeValue == 'x' + del mml + + mml = mp._print(Symbol("x^2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x__2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msub' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x^3_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msubsup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[0].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mp._print(Symbol("x__3_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msubsup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[0].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mp._print(Symbol("x_2_a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msub' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + mml = mp._print(Symbol("x^2^a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + mml = mp._print(Symbol("x__2__a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + +def test_content_mathml_greek(): + mml = mp._print(Symbol('alpha')) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeValue == '\N{GREEK SMALL LETTER ALPHA}' + + assert mp.doprint(Symbol('alpha')) == 'α' + assert mp.doprint(Symbol('beta')) == 'β' + assert mp.doprint(Symbol('gamma')) == 'γ' + assert mp.doprint(Symbol('delta')) == 'δ' + assert mp.doprint(Symbol('epsilon')) == 'ε' + assert mp.doprint(Symbol('zeta')) == 'ζ' + assert mp.doprint(Symbol('eta')) == 'η' + assert mp.doprint(Symbol('theta')) == 'θ' + assert mp.doprint(Symbol('iota')) == 'ι' + assert mp.doprint(Symbol('kappa')) == 'κ' + assert mp.doprint(Symbol('lambda')) == 'λ' + assert mp.doprint(Symbol('mu')) == 'μ' + assert mp.doprint(Symbol('nu')) == 'ν' + assert mp.doprint(Symbol('xi')) == 'ξ' + assert mp.doprint(Symbol('omicron')) == 'ο' + assert mp.doprint(Symbol('pi')) == 'π' + assert mp.doprint(Symbol('rho')) == 'ρ' + assert mp.doprint(Symbol('varsigma')) == 'ς' + assert mp.doprint(Symbol('sigma')) == 'σ' + assert mp.doprint(Symbol('tau')) == 'τ' + assert mp.doprint(Symbol('upsilon')) == 'υ' + assert mp.doprint(Symbol('phi')) == 'φ' + assert mp.doprint(Symbol('chi')) == 'χ' + assert mp.doprint(Symbol('psi')) == 'ψ' + assert mp.doprint(Symbol('omega')) == 'ω' + + assert mp.doprint(Symbol('Alpha')) == 'Α' + assert mp.doprint(Symbol('Beta')) == 'Β' + assert mp.doprint(Symbol('Gamma')) == 'Γ' + assert mp.doprint(Symbol('Delta')) == 'Δ' + assert mp.doprint(Symbol('Epsilon')) == 'Ε' + assert mp.doprint(Symbol('Zeta')) == 'Ζ' + assert mp.doprint(Symbol('Eta')) == 'Η' + assert mp.doprint(Symbol('Theta')) == 'Θ' + assert mp.doprint(Symbol('Iota')) == 'Ι' + assert mp.doprint(Symbol('Kappa')) == 'Κ' + assert mp.doprint(Symbol('Lambda')) == 'Λ' + assert mp.doprint(Symbol('Mu')) == 'Μ' + assert mp.doprint(Symbol('Nu')) == 'Ν' + assert mp.doprint(Symbol('Xi')) == 'Ξ' + assert mp.doprint(Symbol('Omicron')) == 'Ο' + assert mp.doprint(Symbol('Pi')) == 'Π' + assert mp.doprint(Symbol('Rho')) == 'Ρ' + assert mp.doprint(Symbol('Sigma')) == 'Σ' + assert mp.doprint(Symbol('Tau')) == 'Τ' + assert mp.doprint(Symbol('Upsilon')) == 'Υ' + assert mp.doprint(Symbol('Phi')) == 'Φ' + assert mp.doprint(Symbol('Chi')) == 'Χ' + assert mp.doprint(Symbol('Psi')) == 'Ψ' + assert mp.doprint(Symbol('Omega')) == 'Ω' + + +def test_content_mathml_order(): + expr = x**3 + x**2*y + 3*x*y**3 + y**4 + + mp = MathMLContentPrinter({'order': 'lex'}) + mml = mp._print(expr) + + assert mml.childNodes[1].childNodes[0].nodeName == 'power' + assert mml.childNodes[1].childNodes[1].childNodes[0].data == 'x' + assert mml.childNodes[1].childNodes[2].childNodes[0].data == '3' + + assert mml.childNodes[4].childNodes[0].nodeName == 'power' + assert mml.childNodes[4].childNodes[1].childNodes[0].data == 'y' + assert mml.childNodes[4].childNodes[2].childNodes[0].data == '4' + + mp = MathMLContentPrinter({'order': 'rev-lex'}) + mml = mp._print(expr) + + assert mml.childNodes[1].childNodes[0].nodeName == 'power' + assert mml.childNodes[1].childNodes[1].childNodes[0].data == 'y' + assert mml.childNodes[1].childNodes[2].childNodes[0].data == '4' + + assert mml.childNodes[4].childNodes[0].nodeName == 'power' + assert mml.childNodes[4].childNodes[1].childNodes[0].data == 'x' + assert mml.childNodes[4].childNodes[2].childNodes[0].data == '3' + + +def test_content_settings(): + raises(TypeError, lambda: mathml(x, method="garbage")) + + +def test_content_mathml_logic(): + assert mathml(And(x, y)) == 'xy' + assert mathml(Or(x, y)) == 'xy' + assert mathml(Xor(x, y)) == 'xy' + assert mathml(Implies(x, y)) == 'xy' + assert mathml(Not(x)) == 'x' + + +def test_content_finite_sets(): + assert mathml(FiniteSet(a)) == 'a' + assert mathml(FiniteSet(a, b)) == 'ab' + assert mathml(FiniteSet(FiniteSet(a, b), c)) == \ + 'cab' + + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert mathml(U1) == \ + 'ab' + assert mathml(I1) == \ + 'ab' \ + '' + assert mathml(C1) == \ + 'ab' + assert mathml(P1) == \ + 'ab' \ + '' + + assert mathml(Intersection(A, U2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Intersection(U1, U2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + + # XXX Does the parenthesis appear correctly for these examples in mathjax? + assert mathml(Intersection(C1, C2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Intersection(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(Union(A, I2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Union(I1, I2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Union(C1, C2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Union(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(Complement(A, C2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Complement(U1, U2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Complement(I1, I2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Complement(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(ProductSet(A, P2)) == \ + 'a' \ + 'c' \ + 'd' + assert mathml(ProductSet(U1, U2)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(ProductSet(I1, I2)) == \ + 'a' \ + 'b' \ + 'cd' + assert mathml(ProductSet(C1, C2)) == \ + 'a' \ + 'b' \ + 'cd' + + +def test_presentation_printmethod(): + assert mpp.doprint(1 + x) == 'x+1' + assert mpp.doprint(x**2) == 'x2' + assert mpp.doprint(x**-1) == '1x' + assert mpp.doprint(x**-2) == \ + '1x2' + assert mpp.doprint(2*x) == \ + '2x' + + +def test_presentation_mathml_core(): + mml_1 = mpp._print(1 + x) + assert mml_1.nodeName == 'mrow' + nodes = mml_1.childNodes + assert len(nodes) == 3 + assert nodes[0].nodeName in ['mi', 'mn'] + assert nodes[1].nodeName == 'mo' + if nodes[0].nodeName == 'mn': + assert nodes[0].childNodes[0].nodeValue == '1' + assert nodes[2].childNodes[0].nodeValue == 'x' + else: + assert nodes[0].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mpp._print(x**2) + assert mml_2.nodeName == 'msup' + nodes = mml_2.childNodes + assert nodes[0].childNodes[0].nodeValue == 'x' + assert nodes[1].childNodes[0].nodeValue == '2' + + mml_3 = mpp._print(2*x) + assert mml_3.nodeName == 'mrow' + nodes = mml_3.childNodes + assert nodes[0].childNodes[0].nodeValue == '2' + assert nodes[1].childNodes[0].nodeValue == '⁢' + assert nodes[2].childNodes[0].nodeValue == 'x' + + mml = mpp._print(Float(1.0, 2)*x) + assert mml.nodeName == 'mrow' + nodes = mml.childNodes + assert nodes[0].childNodes[0].nodeValue == '1.0' + assert nodes[1].childNodes[0].nodeValue == '⁢' + assert nodes[2].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_functions(): + mml_1 = mpp._print(sin(x)) + assert mml_1.childNodes[0].childNodes[0 + ].nodeValue == 'sin' + assert mml_1.childNodes[1].childNodes[1 + ].childNodes[0].nodeValue == 'x' + + mml_2 = mpp._print(diff(sin(x), x, evaluate=False)) + assert mml_2.nodeName == 'mrow' + assert mml_2.childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == 'ⅆ' + assert mml_2.childNodes[1].childNodes[1 + ].nodeName == 'mrow' + assert mml_2.childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == 'ⅆ' + + mml_3 = mpp._print(diff(cos(x*y), x, evaluate=False)) + assert mml_3.childNodes[0].nodeName == 'mfrac' + assert mml_3.childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '∂' + assert mml_3.childNodes[1].childNodes[0 + ].childNodes[0].nodeValue == 'cos' + + +def test_print_derivative(): + f = Function('f') + d = Derivative(f(x, y, z), x, z, x, z, z, y) + assert mathml(d) == \ + 'yz2xzxxyz' + assert mathml(d, printer='presentation') == \ + '6y2zxzxf(x,y,z)' + + +def test_presentation_mathml_limits(): + lim_fun = sin(x)/x + mml_1 = mpp._print(Limit(lim_fun, x, 0)) + assert mml_1.childNodes[0].nodeName == 'munder' + assert mml_1.childNodes[0].childNodes[0 + ].childNodes[0].nodeValue == 'lim' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[1].childNodes[0 + ].nodeValue == '→' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[2].childNodes[0 + ].nodeValue == '0' + + +def test_presentation_mathml_integrals(): + assert mpp.doprint(Integral(x, (x, 0, 1))) == \ + '01'\ + 'xx' + assert mpp.doprint(Integral(log(x), x)) == \ + 'log(x' \ + ')x' + assert mpp.doprint(Integral(x*y, x, y)) == \ + 'x'\ + 'yyx' + z, w = symbols('z w') + assert mpp.doprint(Integral(x*y*z, x, y, z)) == \ + 'x'\ + 'yz'\ + 'zyx' + assert mpp.doprint(Integral(x*y*z*w, x, y, z, w)) == \ + ''\ + 'w'\ + 'xy'\ + 'zw'\ + 'zyx' + assert mpp.doprint(Integral(x, x, y, (z, 0, 1))) == \ + '01'\ + 'xz'\ + 'yx' + assert mpp.doprint(Integral(x, (x, 0))) == \ + '0x'\ + 'x' + + +def test_presentation_mathml_matrices(): + A = Matrix([1, 2, 3]) + B = Matrix([[0, 5, 4], [2, 3, 1], [9, 7, 9]]) + mll_1 = mpp._print(A) + assert mll_1.childNodes[1].nodeName == 'mtable' + assert mll_1.childNodes[1].childNodes[0].nodeName == 'mtr' + assert len(mll_1.childNodes[1].childNodes) == 3 + assert mll_1.childNodes[1].childNodes[0].childNodes[0].nodeName == 'mtd' + assert len(mll_1.childNodes[1].childNodes[0].childNodes) == 1 + assert mll_1.childNodes[1].childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_1.childNodes[1].childNodes[1].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_1.childNodes[1].childNodes[2].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '3' + mll_2 = mpp._print(B) + assert mll_2.childNodes[1].nodeName == 'mtable' + assert mll_2.childNodes[1].childNodes[0].nodeName == 'mtr' + assert len(mll_2.childNodes[1].childNodes) == 3 + assert mll_2.childNodes[1].childNodes[0].childNodes[0].nodeName == 'mtd' + assert len(mll_2.childNodes[1].childNodes[0].childNodes) == 3 + assert mll_2.childNodes[1].childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '0' + assert mll_2.childNodes[1].childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '5' + assert mll_2.childNodes[1].childNodes[0].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '4' + assert mll_2.childNodes[1].childNodes[1].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_2.childNodes[1].childNodes[1].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '3' + assert mll_2.childNodes[1].childNodes[1].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_2.childNodes[1].childNodes[2].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '9' + assert mll_2.childNodes[1].childNodes[2].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '7' + assert mll_2.childNodes[1].childNodes[2].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '9' + + +def test_presentation_mathml_sums(): + mml_1 = mpp._print(Sum(x, (x, 1, 10))) + assert mml_1.childNodes[0].nodeName == 'munderover' + assert len(mml_1.childNodes[0].childNodes) == 3 + assert mml_1.childNodes[0].childNodes[0].childNodes[0 + ].nodeValue == '∑' + assert len(mml_1.childNodes[0].childNodes[1].childNodes) == 3 + assert mml_1.childNodes[0].childNodes[2].childNodes[0 + ].nodeValue == '10' + assert mml_1.childNodes[1].childNodes[0].nodeValue == 'x' + + assert mpp.doprint(Sum(x, (x, 1, 10))) == \ + 'x=110x' + assert mpp.doprint(Sum(x + y, (x, 1, 10))) == \ + 'x=110(x+y)' + + +def test_presentation_mathml_add(): + mml = mpp._print(x**5 - x**4 + x) + assert len(mml.childNodes) == 5 + assert mml.childNodes[0].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].childNodes[0 + ].nodeValue == '5' + assert mml.childNodes[1].childNodes[0].nodeValue == '-' + assert mml.childNodes[2].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml.childNodes[2].childNodes[1].childNodes[0 + ].nodeValue == '4' + assert mml.childNodes[3].childNodes[0].nodeValue == '+' + assert mml.childNodes[4].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_Rational(): + mml_1 = mpp._print(Rational(1, 1)) + assert mml_1.nodeName == 'mn' + + mml_2 = mpp._print(Rational(2, 5)) + assert mml_2.nodeName == 'mfrac' + assert mml_2.childNodes[0].childNodes[0].nodeValue == '2' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '5' + + +def test_presentation_mathml_constants(): + mml = mpp._print(I) + assert mml.childNodes[0].nodeValue == 'ⅈ' + + mml = mpp._print(E) + assert mml.childNodes[0].nodeValue == 'ⅇ' + + mml = mpp._print(oo) + assert mml.childNodes[0].nodeValue == '∞' + + mml = mpp._print(pi) + assert mml.childNodes[0].nodeValue == 'π' + + assert mathml(hbar, printer='presentation') == '' + assert mathml(S.TribonacciConstant, printer='presentation' + ) == 'TribonacciConstant' + assert mathml(S.EulerGamma, printer='presentation' + ) == 'γ' + assert mathml(S.GoldenRatio, printer='presentation' + ) == 'Φ' + + assert mathml(zoo, printer='presentation') == \ + '~' + + assert mathml(S.NaN, printer='presentation') == 'NaN' + +def test_presentation_mathml_trig(): + mml = mpp._print(sin(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'sin' + + mml = mpp._print(cos(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'cos' + + mml = mpp._print(tan(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'tan' + + mml = mpp._print(asin(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arcsin' + + mml = mpp._print(acos(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arccos' + + mml = mpp._print(atan(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arctan' + + mml = mpp._print(sinh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'sinh' + + mml = mpp._print(cosh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'cosh' + + mml = mpp._print(tanh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'tanh' + + mml = mpp._print(asinh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arcsinh' + + mml = mpp._print(atanh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arctanh' + + mml = mpp._print(acosh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arccosh' + + +def test_presentation_mathml_relational(): + mml_1 = mpp._print(Eq(x, 1)) + assert len(mml_1.childNodes) == 3 + assert mml_1.childNodes[0].nodeName == 'mi' + assert mml_1.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml_1.childNodes[1].nodeName == 'mo' + assert mml_1.childNodes[1].childNodes[0].nodeValue == '=' + assert mml_1.childNodes[2].nodeName == 'mn' + assert mml_1.childNodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mpp._print(Ne(1, x)) + assert len(mml_2.childNodes) == 3 + assert mml_2.childNodes[0].nodeName == 'mn' + assert mml_2.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_2.childNodes[1].nodeName == 'mo' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '≠' + assert mml_2.childNodes[2].nodeName == 'mi' + assert mml_2.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_3 = mpp._print(Ge(1, x)) + assert len(mml_3.childNodes) == 3 + assert mml_3.childNodes[0].nodeName == 'mn' + assert mml_3.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_3.childNodes[1].nodeName == 'mo' + assert mml_3.childNodes[1].childNodes[0].nodeValue == '≥' + assert mml_3.childNodes[2].nodeName == 'mi' + assert mml_3.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_4 = mpp._print(Lt(1, x)) + assert len(mml_4.childNodes) == 3 + assert mml_4.childNodes[0].nodeName == 'mn' + assert mml_4.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_4.childNodes[1].nodeName == 'mo' + assert mml_4.childNodes[1].childNodes[0].nodeValue == '<' + assert mml_4.childNodes[2].nodeName == 'mi' + assert mml_4.childNodes[2].childNodes[0].nodeValue == 'x' + + +def test_presentation_symbol(): + mml = mpp._print(x) + assert mml.nodeName == 'mi' + assert mml.childNodes[0].nodeValue == 'x' + del mml + + mml = mpp._print(Symbol("x^2")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x__2")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x_2")) + assert mml.nodeName == 'msub' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x^3_2")) + assert mml.nodeName == 'msubsup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[2].nodeName == 'mi' + assert mml.childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mpp._print(Symbol("x__3_2")) + assert mml.nodeName == 'msubsup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[2].nodeName == 'mi' + assert mml.childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mpp._print(Symbol("x_2_a")) + assert mml.nodeName == 'msub' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + mml = mpp._print(Symbol("x^2^a")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + mml = mpp._print(Symbol("x__2__a")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + +def test_presentation_mathml_greek(): + mml = mpp._print(Symbol('alpha')) + assert mml.nodeName == 'mi' + assert mml.childNodes[0].nodeValue == '\N{GREEK SMALL LETTER ALPHA}' + + assert mpp.doprint(Symbol('alpha')) == 'α' + assert mpp.doprint(Symbol('beta')) == 'β' + assert mpp.doprint(Symbol('gamma')) == 'γ' + assert mpp.doprint(Symbol('delta')) == 'δ' + assert mpp.doprint(Symbol('epsilon')) == 'ε' + assert mpp.doprint(Symbol('zeta')) == 'ζ' + assert mpp.doprint(Symbol('eta')) == 'η' + assert mpp.doprint(Symbol('theta')) == 'θ' + assert mpp.doprint(Symbol('iota')) == 'ι' + assert mpp.doprint(Symbol('kappa')) == 'κ' + assert mpp.doprint(Symbol('lambda')) == 'λ' + assert mpp.doprint(Symbol('mu')) == 'μ' + assert mpp.doprint(Symbol('nu')) == 'ν' + assert mpp.doprint(Symbol('xi')) == 'ξ' + assert mpp.doprint(Symbol('omicron')) == 'ο' + assert mpp.doprint(Symbol('pi')) == 'π' + assert mpp.doprint(Symbol('rho')) == 'ρ' + assert mpp.doprint(Symbol('varsigma')) == 'ς' + assert mpp.doprint(Symbol('sigma')) == 'σ' + assert mpp.doprint(Symbol('tau')) == 'τ' + assert mpp.doprint(Symbol('upsilon')) == 'υ' + assert mpp.doprint(Symbol('phi')) == 'φ' + assert mpp.doprint(Symbol('chi')) == 'χ' + assert mpp.doprint(Symbol('psi')) == 'ψ' + assert mpp.doprint(Symbol('omega')) == 'ω' + + assert mpp.doprint(Symbol('Alpha')) == 'Α' + assert mpp.doprint(Symbol('Beta')) == 'Β' + assert mpp.doprint(Symbol('Gamma')) == 'Γ' + assert mpp.doprint(Symbol('Delta')) == 'Δ' + assert mpp.doprint(Symbol('Epsilon')) == 'Ε' + assert mpp.doprint(Symbol('Zeta')) == 'Ζ' + assert mpp.doprint(Symbol('Eta')) == 'Η' + assert mpp.doprint(Symbol('Theta')) == 'Θ' + assert mpp.doprint(Symbol('Iota')) == 'Ι' + assert mpp.doprint(Symbol('Kappa')) == 'Κ' + assert mpp.doprint(Symbol('Lambda')) == 'Λ' + assert mpp.doprint(Symbol('Mu')) == 'Μ' + assert mpp.doprint(Symbol('Nu')) == 'Ν' + assert mpp.doprint(Symbol('Xi')) == 'Ξ' + assert mpp.doprint(Symbol('Omicron')) == 'Ο' + assert mpp.doprint(Symbol('Pi')) == 'Π' + assert mpp.doprint(Symbol('Rho')) == 'Ρ' + assert mpp.doprint(Symbol('Sigma')) == 'Σ' + assert mpp.doprint(Symbol('Tau')) == 'Τ' + assert mpp.doprint(Symbol('Upsilon')) == 'Υ' + assert mpp.doprint(Symbol('Phi')) == 'Φ' + assert mpp.doprint(Symbol('Chi')) == 'Χ' + assert mpp.doprint(Symbol('Psi')) == 'Ψ' + assert mpp.doprint(Symbol('Omega')) == 'Ω' + + +def test_presentation_mathml_order(): + expr = x**3 + x**2*y + 3*x*y**3 + y**4 + + mp = MathMLPresentationPrinter({'order': 'lex'}) + mml = mp._print(expr) + assert mml.childNodes[0].nodeName == 'msup' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '3' + + assert mml.childNodes[6].nodeName == 'msup' + assert mml.childNodes[6].childNodes[0].childNodes[0].nodeValue == 'y' + assert mml.childNodes[6].childNodes[1].childNodes[0].nodeValue == '4' + + mp = MathMLPresentationPrinter({'order': 'rev-lex'}) + mml = mp._print(expr) + + assert mml.childNodes[0].nodeName == 'msup' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'y' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '4' + + assert mml.childNodes[6].nodeName == 'msup' + assert mml.childNodes[6].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[6].childNodes[1].childNodes[0].nodeValue == '3' + + +def test_print_intervals(): + a = Symbol('a', real=True) + assert mpp.doprint(Interval(0, a)) == \ + '[0,a]' + assert mpp.doprint(Interval(0, a, False, False)) == \ + '[0,a]' + assert mpp.doprint(Interval(0, a, True, False)) == \ + '(0,a]' + assert mpp.doprint(Interval(0, a, False, True)) == \ + '[0,a)' + assert mpp.doprint(Interval(0, a, True, True)) == \ + '(0,a)' + + +def test_print_tuples(): + assert mpp.doprint(Tuple(0,)) == \ + '(0)' + assert mpp.doprint(Tuple(0, a)) == \ + '(0,a)' + assert mpp.doprint(Tuple(0, a, a)) == \ + '(0,a,a)' + assert mpp.doprint(Tuple(0, 1, 2, 3, 4)) == \ + '(0,1,2,3,4)' + assert mpp.doprint(Tuple(0, 1, Tuple(2, 3, 4))) == \ + '(0,1,(2,3'\ + ',4))' + + +def test_print_re_im(): + assert mpp.doprint(re(x)) == \ + '(x)' + assert mpp.doprint(im(x)) == \ + '(x)' + assert mpp.doprint(re(x + 1, evaluate=False)) == \ + '(x+1)' + assert mpp.doprint(im(x + 1, evaluate=False)) == \ + '(x+1)' + + +def test_print_Abs(): + assert mpp.doprint(Abs(x)) == \ + '|x|' + assert mpp.doprint(Abs(x + 1)) == \ + '|x+1|' + + +def test_print_Determinant(): + assert mpp.doprint(Determinant(Matrix([[1, 2], [3, 4]]))) == \ + '|[1234]|' + + +def test_presentation_settings(): + raises(TypeError, lambda: mathml(x, printer='presentation', + method="garbage")) + + +def test_print_domains(): + from sympy.sets import Integers, Naturals, Naturals0, Reals, Complexes + + assert mpp.doprint(Complexes) == '' + assert mpp.doprint(Integers) == '' + assert mpp.doprint(Naturals) == '' + assert mpp.doprint(Naturals0) == \ + '0' + assert mpp.doprint(Reals) == '' + + +def test_print_expression_with_minus(): + assert mpp.doprint(-x) == '-x' + assert mpp.doprint(-x/y) == \ + '-xy' + assert mpp.doprint(-Rational(1, 2)) == \ + '-12' + + +def test_print_AssocOp(): + from sympy.core.operations import AssocOp + + class TestAssocOp(AssocOp): + identity = 0 + + expr = TestAssocOp(1, 2) + assert mpp.doprint(expr) == \ + 'testassocop12' + + +def test_print_basic(): + expr = Basic(S(1), S(2)) + assert mpp.doprint(expr) == \ + 'basic(1,2)' + assert mp.doprint(expr) == '12' + + +def test_mat_delim_print(): + expr = Matrix([[1, 2], [3, 4]]) + assert mathml(expr, printer='presentation', mat_delim='[') == \ + '[1'\ + '234'\ + ']' + assert mathml(expr, printer='presentation', mat_delim='(') == \ + '(12'\ + '34)' + assert mathml(expr, printer='presentation', mat_delim='') == \ + '12'\ + '34' + + +def test_ln_notation_print(): + expr = log(x) + assert mathml(expr, printer='presentation') == \ + 'log(x)' + assert mathml(expr, printer='presentation', ln_notation=False) == \ + 'log(x)' + assert mathml(expr, printer='presentation', ln_notation=True) == \ + 'ln(x)' + + +def test_mul_symbol_print(): + expr = x * y + assert mathml(expr, printer='presentation') == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol=None) == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol='dot') == \ + 'x·y' + assert mathml(expr, printer='presentation', mul_symbol='ldot') == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol='times') == \ + 'x×y' + + +def test_print_lerchphi(): + assert mpp.doprint(lerchphi(1, 2, 3)) == \ + 'Φ(1,2,3)' + + +def test_print_polylog(): + assert mp.doprint(polylog(x, y)) == \ + 'xy' + assert mpp.doprint(polylog(x, y)) == \ + 'Lix(y)' + + +def test_print_set_frozenset(): + f = frozenset({1, 5, 3}) + assert mpp.doprint(f) == \ + '{1,3,5}' + s = set({1, 2, 3}) + assert mpp.doprint(s) == \ + '{1,2,3}' + + +def test_print_FiniteSet(): + f1 = FiniteSet(x, 1, 3) + assert mpp.doprint(f1) == \ + '{1,3,x}' + + +def test_print_LambertW(): + assert mpp.doprint(LambertW(x)) == 'W(x)' + assert mpp.doprint(LambertW(x, y)) == 'W(x,y)' + + +def test_print_EmptySet(): + assert mpp.doprint(S.EmptySet) == '' + + +def test_print_UniversalSet(): + assert mpp.doprint(S.UniversalSet) == '𝕌' + + +def test_print_spaces(): + assert mpp.doprint(HilbertSpace()) == '' + assert mpp.doprint(ComplexSpace(2)) == '𝒞2' + assert mpp.doprint(FockSpace()) == '' + + +def test_print_constants(): + assert mpp.doprint(hbar) == '' + assert mpp.doprint(S.TribonacciConstant) == 'TribonacciConstant' + assert mpp.doprint(S.GoldenRatio) == 'Φ' + assert mpp.doprint(S.EulerGamma) == 'γ' + + +def test_print_Contains(): + assert mpp.doprint(Contains(x, S.Naturals)) == \ + 'x' + + +def test_print_Dagger(): + x = symbols('x', commutative=False) + assert mpp.doprint(Dagger(x)) == 'x' + + +def test_print_SetOp(): + f1 = FiniteSet(x, 1, 3) + f2 = FiniteSet(y, 2, 4) + + prntr = lambda x: mathml(x, printer='presentation') + + assert prntr(Union(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2,'\ + '4,y}' + assert prntr(Intersection(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2'\ + ',4,y}' + assert prntr(Complement(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2'\ + ',4,y}' + assert prntr(SymmetricDifference(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2'\ + ',4,y}' + + A = FiniteSet(a) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(C, D, evaluate=False) + I1 = Intersection(C, D, evaluate=False) + C1 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(C, D) + + assert prntr(Union(A, I1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}{' \ + 'd})' + assert prntr(Intersection(A, C1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}{' \ + 'd})' + assert prntr(Complement(A, D1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}{' \ + 'd})' + assert prntr(SymmetricDifference(A, P1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}×{' \ + 'd})' + assert prntr(ProductSet(A, U1)) == \ + '{a}' \ + '×({' \ + 'c}{' \ + 'd})' + + +def test_print_logic(): + assert mpp.doprint(And(x, y)) == \ + 'xy' + assert mpp.doprint(Or(x, y)) == \ + 'xy' + assert mpp.doprint(Xor(x, y)) == \ + 'xy' + assert mpp.doprint(Implies(x, y)) == \ + 'xy' + assert mpp.doprint(Equivalent(x, y)) == \ + 'xy' + + assert mpp.doprint(And(Eq(x, y), x > 4)) == \ + 'x=y'\ + 'x>4' + assert mpp.doprint(And(Eq(x, 3), y < 3, x > y + 1)) == \ + 'x=3'\ + 'x>y+1'\ + 'y<3' + assert mpp.doprint(Or(Eq(x, y), x > 4)) == \ + 'x=y'\ + 'x>4' + assert mpp.doprint(And(Eq(x, 3), Or(y < 3, x > y + 1))) == \ + 'x=3'\ + '(x>'\ + 'y+1'\ + 'y<3)' + + assert mpp.doprint(Not(x)) == '¬x' + assert mpp.doprint(Not(And(x, y))) == \ + '¬(xy)' + + +def test_root_notation_print(): + assert mathml(x**(S.One/3), printer='presentation') == \ + 'x3' + assert mathml(x**(S.One/3), printer='presentation', root_notation=False) ==\ + 'x13' + assert mathml(x**(S.One/3), printer='content') == \ + '3x' + assert mathml(x**(S.One/3), printer='content', root_notation=False) == \ + 'x13' + assert mathml(x**(Rational(-1, 3)), printer='presentation') == \ + '1x3' + assert mathml(x**(Rational(-1, 3)), printer='presentation', root_notation=False) \ + == '1x13' + + +def test_fold_frac_powers_print(): + expr = x ** Rational(5, 2) + assert mathml(expr, printer='presentation') == \ + 'x52' + assert mathml(expr, printer='presentation', fold_frac_powers=True) == \ + 'x52' + assert mathml(expr, printer='presentation', fold_frac_powers=False) == \ + 'x52' + + +def test_fold_short_frac_print(): + expr = Rational(2, 5) + assert mathml(expr, printer='presentation') == \ + '25' + assert mathml(expr, printer='presentation', fold_short_frac=True) == \ + '25' + assert mathml(expr, printer='presentation', fold_short_frac=False) == \ + '25' + + +def test_print_factorials(): + assert mpp.doprint(factorial(x)) == 'x!' + assert mpp.doprint(factorial(x + 1)) == \ + '(x+1)!' + assert mpp.doprint(factorial2(x)) == 'x!!' + assert mpp.doprint(factorial2(x + 1)) == \ + '(x+1)!!' + assert mpp.doprint(binomial(x, y)) == \ + '(xy)' + assert mpp.doprint(binomial(4, x + y)) == \ + '(4x'\ + '+y)' + + +def test_print_floor(): + expr = floor(x) + assert mathml(expr, printer='presentation') == \ + 'x' + + +def test_print_ceiling(): + expr = ceiling(x) + assert mathml(expr, printer='presentation') == \ + 'x' + + +def test_print_Lambda(): + expr = Lambda(x, x+1) + assert mathml(expr, printer='presentation') == \ + '(xx+1)' + expr = Lambda((x, y), x + y) + assert mathml(expr, printer='presentation') == \ + '((x,y)x+y)' + + +def test_print_conjugate(): + assert mpp.doprint(conjugate(x)) == \ + 'x' + assert mpp.doprint(conjugate(x + 1)) == \ + 'x+1' + + +def test_print_AccumBounds(): + a = Symbol('a', real=True) + assert mpp.doprint(AccumBounds(0, 1)) == '0,1' + assert mpp.doprint(AccumBounds(0, a)) == '0,a' + assert mpp.doprint(AccumBounds(a + 1, a + 2)) == 'a+1,a+2' + + +def test_print_Float(): + assert mpp.doprint(Float(1e100)) == '1.0·10100' + assert mpp.doprint(Float(1e-100)) == '1.0·10-100' + assert mpp.doprint(Float(-1e100)) == '-1.0·10100' + assert mpp.doprint(Float(1.0*oo)) == '' + assert mpp.doprint(Float(-1.0*oo)) == '-' + + +def test_print_different_functions(): + assert mpp.doprint(gamma(x)) == 'Γ(x)' + assert mpp.doprint(lowergamma(x, y)) == 'γ(x,y)' + assert mpp.doprint(uppergamma(x, y)) == 'Γ(x,y)' + assert mpp.doprint(zeta(x)) == 'ζ(x)' + assert mpp.doprint(zeta(x, y)) == 'ζ(x,y)' + assert mpp.doprint(dirichlet_eta(x)) == 'η(x)' + assert mpp.doprint(elliptic_k(x)) == 'Κ(x)' + assert mpp.doprint(totient(x)) == 'ϕ(x)' + assert mpp.doprint(reduced_totient(x)) == 'λ(x)' + assert mpp.doprint(primenu(x)) == 'ν(x)' + assert mpp.doprint(primeomega(x)) == 'Ω(x)' + assert mpp.doprint(fresnels(x)) == 'S(x)' + assert mpp.doprint(fresnelc(x)) == 'C(x)' + assert mpp.doprint(Heaviside(x)) == 'Θ(x,12)' + + +def test_mathml_builtins(): + assert mpp.doprint(None) == 'None' + assert mpp.doprint(true) == 'True' + assert mpp.doprint(false) == 'False' + + +def test_mathml_Range(): + assert mpp.doprint(Range(1, 51)) == \ + '{1,2,,50}' + assert mpp.doprint(Range(1, 4)) == \ + '{1,2,3}' + assert mpp.doprint(Range(0, 3, 1)) == \ + '{0,1,2}' + assert mpp.doprint(Range(0, 30, 1)) == \ + '{0,1,,29}' + assert mpp.doprint(Range(30, 1, -1)) == \ + '{30,29,,2}' + assert mpp.doprint(Range(0, oo, 2)) == \ + '{0,2,}' + assert mpp.doprint(Range(oo, -2, -2)) == \ + '{,2,0}' + assert mpp.doprint(Range(-2, -oo, -1)) == \ + '{-2,-3,}' + + +def test_print_exp(): + assert mpp.doprint(exp(x)) == \ + 'x' + assert mpp.doprint(exp(1) + exp(2)) == \ + '+2' + + +def test_print_MinMax(): + assert mpp.doprint(Min(x, y)) == \ + 'min(x,y)' + assert mpp.doprint(Min(x, 2, x**3)) == \ + 'min(2,x,x3)' + assert mpp.doprint(Max(x, y)) == \ + 'max(x,y)' + assert mpp.doprint(Max(x, 2, x**3)) == \ + 'max(2,x,x3)' + + +def test_mathml_presentation_numbers(): + n = Symbol('n') + assert mathml(catalan(n), printer='presentation') == \ + 'Cn' + assert mathml(bernoulli(n), printer='presentation') == \ + 'Bn' + assert mathml(bell(n), printer='presentation') == \ + 'Bn' + assert mathml(euler(n), printer='presentation') == \ + 'En' + assert mathml(fibonacci(n), printer='presentation') == \ + 'Fn' + assert mathml(lucas(n), printer='presentation') == \ + 'Ln' + assert mathml(tribonacci(n), printer='presentation') == \ + 'Tn' + assert mathml(bernoulli(n, x), printer='presentation') == \ + mathml(bell(n, x), printer='presentation') == \ + 'Bn(x)' + assert mathml(euler(n, x), printer='presentation') == \ + 'En(x)' + assert mathml(fibonacci(n, x), printer='presentation') == \ + 'Fn(x)' + assert mathml(tribonacci(n, x), printer='presentation') == \ + 'Tn(x)' + + +def test_mathml_presentation_mathieu(): + assert mathml(mathieuc(x, y, z), printer='presentation') == \ + 'C(x,y,z)' + assert mathml(mathieus(x, y, z), printer='presentation') == \ + 'S(x,y,z)' + assert mathml(mathieucprime(x, y, z), printer='presentation') == \ + 'C′(x,y,z)' + assert mathml(mathieusprime(x, y, z), printer='presentation') == \ + 'S′(x,y,z)' + + +def test_mathml_presentation_stieltjes(): + assert mathml(stieltjes(n), printer='presentation') == \ + 'γn' + assert mathml(stieltjes(n, x), printer='presentation') == \ + 'γn(x)' + + +def test_print_matrix_symbol(): + A = MatrixSymbol('A', 1, 2) + assert mpp.doprint(A) == 'A' + assert mp.doprint(A) == 'A' + assert mathml(A, printer='presentation', mat_symbol_style="bold") == \ + 'A' + # No effect in content printer + assert mathml(A, mat_symbol_style="bold") == 'A' + + +def test_print_hadamard(): + from sympy.matrices.expressions import HadamardProduct + from sympy.matrices.expressions import Transpose + + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + + assert mathml(HadamardProduct(X, Y*Y), printer="presentation") == \ + '' \ + 'X' \ + '' \ + 'Y2' \ + '' + + assert mathml(HadamardProduct(X, Y)*Y, printer="presentation") == \ + '' \ + '(' \ + 'XY' \ + ')' \ + 'Y' \ + '' + + assert mathml(HadamardProduct(X, Y, Y), printer="presentation") == \ + '' \ + 'X' \ + 'Y' \ + 'Y' \ + '' + + assert mathml( + Transpose(HadamardProduct(X, Y)), printer="presentation") == \ + '' \ + '(' \ + 'XY' \ + ')' \ + 'T' \ + '' + + +def test_print_random_symbol(): + R = RandomSymbol(Symbol('R')) + assert mpp.doprint(R) == 'R' + assert mp.doprint(R) == 'R' + + +def test_print_IndexedBase(): + assert mathml(IndexedBase(a)[b], printer='presentation') == \ + 'ab' + assert mathml(IndexedBase(a)[b, c, d], printer='presentation') == \ + 'a(b,c,d)' + assert mathml(IndexedBase(a)[b]*IndexedBase(c)[d]*IndexedBase(e), + printer='presentation') == \ + 'ab⁢'\ + 'cde' + + +def test_print_Indexed(): + assert mathml(IndexedBase(a), printer='presentation') == 'a' + assert mathml(IndexedBase(a/b), printer='presentation') == \ + 'ab' + assert mathml(IndexedBase((a, b)), printer='presentation') == \ + '(a,b)' + +def test_print_MatrixElement(): + i, j = symbols('i j') + A = MatrixSymbol('A', i, j) + assert mathml(A[0,0],printer = 'presentation') == \ + 'A0,0' + assert mathml(A[i,j], printer = 'presentation') == \ + 'Ai,j' + assert mathml(A[i*j,0], printer = 'presentation') == \ + 'Aij,0' + + +def test_print_Vector(): + ACS = CoordSys3D('A') + assert mathml(Cross(ACS.i, ACS.j*ACS.x*3 + ACS.k), printer='presentation') == \ + 'i^'\ + 'A×('\ + '(3'\ + 'xA'\ + ')'\ + 'j^'\ + 'A+'\ + 'k^'\ + 'A)' + assert mathml(Cross(ACS.i, ACS.j), printer='presentation') == \ + 'i^'\ + 'A×'\ + 'j^'\ + 'A' + assert mathml(x*Cross(ACS.i, ACS.j), printer='presentation') == \ + 'x('\ + 'i^'\ + 'A×'\ + 'j^'\ + 'A)' + assert mathml(Cross(x*ACS.i, ACS.j), printer='presentation') == \ + '-j'\ + '^A'\ + '×((x)'\ + 'i'\ + '^A'\ + ')' + assert mathml(Curl(3*ACS.x*ACS.j), printer='presentation') == \ + '×(('\ + '3x'\ + 'A)'\ + 'j^'\ + 'A)' + assert mathml(Curl(3*x*ACS.x*ACS.j), printer='presentation') == \ + '×(('\ + '3x'\ + 'A'\ + 'x)'\ + 'j^'\ + 'A)' + assert mathml(x*Curl(3*ACS.x*ACS.j), printer='presentation') == \ + 'x('\ + '×((3'\ + 'x'\ + 'A)'\ + 'j'\ + '^A)'\ + ')' + assert mathml(Curl(3*x*ACS.x*ACS.j + ACS.i), printer='presentation') == \ + '×('\ + 'i^'\ + 'A+('\ + '3x'\ + 'A'\ + 'x)'\ + 'j^'\ + 'A)' + assert mathml(Divergence(3*ACS.x*ACS.j), printer='presentation') == \ + '·(('\ + '3x'\ + 'A)'\ + 'j'\ + '^A)' + assert mathml(x*Divergence(3*ACS.x*ACS.j), printer='presentation') == \ + 'x('\ + '·((3'\ + 'x'\ + 'A)'\ + 'j'\ + '^A'\ + '))' + assert mathml(Divergence(3*x*ACS.x*ACS.j + ACS.i), printer='presentation') == \ + '·('\ + 'i^'\ + 'A+('\ + '3'\ + 'xA'\ + 'x)'\ + 'j'\ + '^A'\ + ')' + assert mathml(Dot(ACS.i, ACS.j*ACS.x*3+ACS.k), printer='presentation') == \ + 'i^'\ + 'A·('\ + '(3'\ + 'xA'\ + ')'\ + 'j^'\ + 'A+'\ + 'k^'\ + 'A)' + assert mathml(Dot(ACS.i, ACS.j), printer='presentation') == \ + 'i^'\ + 'A·'\ + 'j^'\ + 'A' + assert mathml(Dot(x*ACS.i, ACS.j), printer='presentation') == \ + 'j^'\ + 'A·('\ + '(x)'\ + 'i^'\ + 'A)' + assert mathml(x*Dot(ACS.i, ACS.j), printer='presentation') == \ + 'x('\ + 'i^'\ + 'A·'\ + 'j^'\ + 'A)' + assert mathml(Gradient(ACS.x), printer='presentation') == \ + 'x'\ + 'A' + assert mathml(Gradient(ACS.x + 3*ACS.y), printer='presentation') == \ + '('\ + 'xA+3'\ + 'y'\ + 'A)' + assert mathml(x*Gradient(ACS.x), printer='presentation') == \ + 'x('\ + 'xA'\ + ')' + assert mathml(Gradient(x*ACS.x), printer='presentation') == \ + '('\ + 'xA'\ + 'x)' + assert mathml(Cross(ACS.z, ACS.x), printer='presentation') == \ + '-x'\ + 'A×'\ + 'zA' + assert mathml(Laplacian(ACS.x), printer='presentation') == \ + 'x'\ + 'A' + assert mathml(Laplacian(ACS.x + 3*ACS.y), printer='presentation') == \ + '('\ + 'xA+3'\ + 'y'\ + 'A)' + assert mathml(x*Laplacian(ACS.x), printer='presentation') == \ + 'x('\ + 'xA'\ + ')' + assert mathml(Laplacian(x*ACS.x), printer='presentation') == \ + '('\ + 'xA'\ + 'x)' + +@XFAIL +def test_vector_cross_xfail(): + ACS = CoordSys3D('A') + assert mathml(Cross(ACS.x, ACS.z) + Cross(ACS.z, ACS.x), printer='presentation') == \ + '0^' + +def test_print_elliptic_f(): + assert mathml(elliptic_f(x, y), printer = 'presentation') == \ + '𝖥(x|y)' + assert mathml(elliptic_f(x/y, y), printer = 'presentation') == \ + '𝖥(xy|y)' + +def test_print_elliptic_e(): + assert mathml(elliptic_e(x), printer = 'presentation') == \ + '𝖤(x)' + assert mathml(elliptic_e(x, y), printer = 'presentation') == \ + '𝖤(x|y)' + +def test_print_elliptic_pi(): + assert mathml(elliptic_pi(x, y), printer = 'presentation') == \ + '𝛱(x|y)' + assert mathml(elliptic_pi(x, y, z), printer = 'presentation') == \ + '𝛱(x;y|z)' + +def test_print_Ei(): + assert mathml(Ei(x), printer = 'presentation') == \ + 'Ei(x)' + assert mathml(Ei(x**y), printer = 'presentation') == \ + 'Ei(xy)' + +def test_print_expint(): + assert mathml(expint(x, y), printer = 'presentation') == \ + 'Ex(y)' + assert mathml(expint(IndexedBase(x)[1], IndexedBase(x)[2]), printer = 'presentation') == \ + 'Ex1(x2)' + +def test_print_jacobi(): + assert mathml(jacobi(n, a, b, x), printer = 'presentation') == \ + 'Pn(a,b)(x)' + +def test_print_gegenbauer(): + assert mathml(gegenbauer(n, a, x), printer = 'presentation') == \ + 'Cn(a)(x)' + +def test_print_chebyshevt(): + assert mathml(chebyshevt(n, x), printer = 'presentation') == \ + 'Tn(x)' + +def test_print_chebyshevu(): + assert mathml(chebyshevu(n, x), printer = 'presentation') == \ + 'Un(x)' + +def test_print_legendre(): + assert mathml(legendre(n, x), printer = 'presentation') == \ + 'Pn(x)' + +def test_print_assoc_legendre(): + assert mathml(assoc_legendre(n, a, x), printer = 'presentation') == \ + 'Pn(a)(x)' + +def test_print_laguerre(): + assert mathml(laguerre(n, x), printer = 'presentation') == \ + 'Ln(x)' + +def test_print_assoc_laguerre(): + assert mathml(assoc_laguerre(n, a, x), printer = 'presentation') == \ + 'Ln(a)(x)' + +def test_print_hermite(): + assert mathml(hermite(n, x), printer = 'presentation') == \ + 'Hn(x)' + +def test_mathml_SingularityFunction(): + assert mathml(SingularityFunction(x, 4, 5), printer='presentation') == \ + 'x-45' + assert mathml(SingularityFunction(x, -3, 4), printer='presentation') == \ + 'x+34' + assert mathml(SingularityFunction(x, 0, 4), printer='presentation') == \ + 'x4' + assert mathml(SingularityFunction(x, a, n), printer='presentation') == \ + '-a+xn' + assert mathml(SingularityFunction(x, 4, -2), printer='presentation') == \ + 'x-4-2' + assert mathml(SingularityFunction(x, 4, -1), printer='presentation') == \ + 'x-4-1' + + +def test_mathml_matrix_functions(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert mathml(Adjoint(X), printer='presentation') == \ + 'X' + assert mathml(Adjoint(X + Y), printer='presentation') == \ + '(X+Y)' + assert mathml(Adjoint(X) + Adjoint(Y), printer='presentation') == \ + 'X+' \ + 'Y' + assert mathml(Adjoint(X*Y), printer='presentation') == \ + '(X' \ + 'Y)' + assert mathml(Adjoint(Y)*Adjoint(X), printer='presentation') == \ + 'Y⁢' \ + 'X' + assert mathml(Adjoint(X**2), printer='presentation') == \ + '(X2)' + assert mathml(Adjoint(X)**2, printer='presentation') == \ + '(X)2' + assert mathml(Adjoint(Inverse(X)), printer='presentation') == \ + '(X-1)' + assert mathml(Inverse(Adjoint(X)), printer='presentation') == \ + '(X)-1' + assert mathml(Adjoint(Transpose(X)), printer='presentation') == \ + '(XT)' + assert mathml(Transpose(Adjoint(X)), printer='presentation') == \ + '(X)T' + assert mathml(Transpose(Adjoint(X) + Y), printer='presentation') == \ + '(X' \ + '+Y)T' + assert mathml(Transpose(X), printer='presentation') == \ + 'XT' + assert mathml(Transpose(X + Y), printer='presentation') == \ + '(X+Y)T' + + +def test_mathml_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert mathml(Identity(4), printer='presentation') == '𝕀' + assert mathml(ZeroMatrix(2, 2), printer='presentation') == '𝟘' + assert mathml(OneMatrix(2, 2), printer='presentation') == '𝟙' + +def test_mathml_piecewise(): + from sympy.functions.elementary.piecewise import Piecewise + # Content MathML + assert mathml(Piecewise((x, x <= 1), (x**2, True))) == \ + 'xx1x2' + + raises(ValueError, lambda: mathml(Piecewise((x, x <= 1)))) + + +def test_issue_17857(): + assert mathml(Range(-oo, oo), printer='presentation') == \ + '{,-1,0,1,}' + assert mathml(Range(oo, -oo, -1), printer='presentation') == \ + '{,1,0,-1,}' + + +def test_float_roundtrip(): + x = sympify(0.8975979010256552) + y = float(mp.doprint(x).strip('')) + assert x == y + + +def test_content_mathml_disable_split_super_sub(): + mp = MathMLContentPrinter() + assert mp.doprint(Symbol('u_b')) == 'ub' + mp = MathMLContentPrinter({'disable_split_super_sub': False}) + assert mp.doprint(Symbol('u_b')) == 'ub' + mp = MathMLContentPrinter({'disable_split_super_sub': True}) + assert mp.doprint(Symbol('u_b')) == 'u_b' + +def test_presentation_mathml_disable_split_super_sub(): + mpp = MathMLPresentationPrinter() + assert mpp.doprint(Symbol('u_b')) == 'ub' + mpp = MathMLPresentationPrinter({'disable_split_super_sub': False}) + assert mpp.doprint(Symbol('u_b')) == 'ub' + mpp = MathMLPresentationPrinter({'disable_split_super_sub': True}) + assert mpp.doprint(Symbol('u_b')) == 'u_b' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_numpy.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..fee1c6bd95e54790a048220f37b8e5de79017d2f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_numpy.py @@ -0,0 +1,381 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import polygamma +from sympy.functions.special.error_functions import (Si, Ci) +from sympy.matrices import Matrix +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify +from sympy import symbols, Min, Max + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Pow +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \ + _numpy_known_functions, _scipy_known_constants, _scipy_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +np = import_module('numpy') +jax = import_module('jax') + +if np: + deafult_float_info = np.finfo(np.array([]).dtype) + NUMPY_DEFAULT_EPSILON = deafult_float_info.eps + +def test_numpy_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = NumPyPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)' + assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}} + +def test_numpy_logaddexp(): + lae = logaddexp(a, b) + assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)' + + +def test_sum(): + if not np: + skip("NumPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_multiple_sums(): + if not np: + skip("NumPy not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'numpy') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_codegen_einsum(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'numpy') + + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == np.matmul(ma, mb)).all() + + +def test_codegen_extra(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + mc = np.array([[2, 0], [1, 2]]) + md = np.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'numpy') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'numpy') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'numpy') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_relational(): + if not np: + skip("NumPy not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, True]) + + +def test_mod(): + if not np: + skip("NumPy not installed") + + e = Mod(a, b) + f = lambdify((a, b), e) + + a_ = np.array([0, 1, 2, 3]) + b_ = 2 + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([0, 1, 2, 3]) + b_ = np.array([2, 2, 2, 2]) + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([2, 3, 4, 5]) + b_ = np.array([2, 3, 4, 5]) + assert np.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_pow(): + if not np: + skip('NumPy not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'numpy') + assert f() == 0.5 + + +def test_expm1(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), expm1(a), 'numpy') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON + + +def test_log1p(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), log1p(a), 'numpy') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON + +def test_hypot(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON + +def test_log10(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_exp2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON + + +def test_log2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON + + +def test_Sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_matsolve(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr) + f_matsolve = lambdify((M, x), matsolve_expr) + + m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert np.linalg.matrix_rank(m0) == 3 + + x0 = np.array([3, 4, 5]) + + assert np.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not np: + skip("NumPy not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = NumPyPrinter() + assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2)) + ma = np.array([[1, 2], [3, 4]]) + mr = np.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n))) + +def test_jax_tuple_compatibility(): + if not jax: + skip("Jax not installed") + + x, y, z = symbols('x y z') + expr = Max(x, y, z) + Min(x, y, z) + func = lambdify((x, y, z), expr, 'jax') + input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6) + input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2) + assert np.allclose(func(*input_tuple1), func(*input_array1)) + assert np.allclose(func(*input_tuple2), func(*input_array2)) + +def test_numpy_array(): + p = NumPyPrinter() + assert p.doprint(Array([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])' + assert p.doprint(Array([1, 2])) == 'numpy.array([1, 2])' + assert p.doprint(Array([[[1, 2, 3]]])) == 'numpy.array([[[1, 2, 3]]])' + assert p.doprint(Array([], (0,))) == 'numpy.zeros((0,))' + assert p.doprint(Array([], (0, 0))) == 'numpy.zeros((0, 0))' + assert p.doprint(Array([], (0, 1))) == 'numpy.zeros((0, 1))' + assert p.doprint(Array([], (1, 0))) == 'numpy.zeros((1, 0))' + assert p.doprint(Array([1], ())) == 'numpy.array(1)' + +def test_numpy_matrix(): + p = NumPyPrinter() + assert p.doprint(Matrix([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])' + assert p.doprint(Matrix([1, 2])) == 'numpy.array([[1], [2]])' + assert p.doprint(Matrix(0, 0, [])) == 'numpy.zeros((0, 0))' + assert p.doprint(Matrix(0, 1, [])) == 'numpy.zeros((0, 1))' + assert p.doprint(Matrix(1, 0, [])) == 'numpy.zeros((1, 0))' + +def test_numpy_known_funcs_consts(): + assert _numpy_known_constants['NaN'] == 'numpy.nan' + assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma' + + assert _numpy_known_functions['acos'] == 'numpy.arccos' + assert _numpy_known_functions['log'] == 'numpy.log' + +def test_scipy_known_funcs_consts(): + assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio' + assert _scipy_known_constants['Pi'] == 'scipy.constants.pi' + + assert _scipy_known_functions['erf'] == 'scipy.special.erf' + assert _scipy_known_functions['factorial'] == 'scipy.special.factorial' + +def test_numpy_print_methods(): + prntr = NumPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + +def test_scipy_print_methods(): + prntr = SciPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + assert hasattr(prntr, '_print_erf') + assert hasattr(prntr, '_print_factorial') + assert hasattr(prntr, '_print_chebyshevt') + k = Symbol('k', integer=True, nonnegative=True) + x = Symbol('x', real=True) + assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)" + assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]" + assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]" diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_octave.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_octave.py new file mode 100644 index 0000000000000000000000000000000000000000..1aba318f873c48ec702f1b4e3a6cc047f75d647d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_octave.py @@ -0,0 +1,515 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, EulerGamma, GoldenRatio, Catalan, + Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu, + chebyshevt, conjugate, DiracDelta, exp, expint, + factorial, floor, harmonic, Heaviside, im, + laguerre, LambertW, log, Max, Min, Piecewise, + polylog, re, RisingFactorial, sign, sinc, sqrt, + zeta, binomial, legendre, dirichlet_eta, + riemann_xi) +from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot, + atan, asec, acsc, sinh, cosh, tanh, coth, csch, + sech, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix, HadamardPower) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, + uppergamma, loggamma, + polygamma) +from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi, + erfcinv, erfinv, fresnelc, + fresnels, li, Shi, Si, Li, + erf2, Ei) +from sympy.printing.octave import octave_code, octave_code as mcode + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "3*x/7" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)" + assert mcode(sign(x)) == "sign(x)" + assert mcode(exp(x)) == "exp(x)" + assert mcode(log(x)) == "log(x)" + assert mcode(factorial(x)) == "factorial(x)" + assert mcode(floor(x)) == "floor(x)" + assert mcode(atan2(y, x)) == "atan2(y, x)" + assert mcode(beta(x, y)) == 'beta(x, y)' + assert mcode(polylog(x, y)) == 'polylog(x, y)' + assert mcode(harmonic(x)) == 'harmonic(x)' + assert mcode(bernoulli(x)) == "bernoulli(x)" + assert mcode(bernoulli(x, y)) == "bernoulli(x, y)" + assert mcode(legendre(x, y)) == "legendre(x, y)" + + +def test_Function_change_name(): + assert mcode(abs(x)) == "abs(x)" + assert mcode(ceiling(x)) == "ceil(x)" + assert mcode(arg(x)) == "angle(x)" + assert mcode(im(x)) == "imag(x)" + assert mcode(re(x)) == "real(x)" + assert mcode(conjugate(x)) == "conj(x)" + assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)" + assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)" + assert mcode(laguerre(x, y)) == "laguerreL(x, y)" + assert mcode(Chi(x)) == "coshint(x)" + assert mcode(Shi(x)) == "sinhint(x)" + assert mcode(Ci(x)) == "cosint(x)" + assert mcode(Si(x)) == "sinint(x)" + assert mcode(li(x)) == "logint(x)" + assert mcode(loggamma(x)) == "gammaln(x)" + assert mcode(polygamma(x, y)) == "psi(x, y)" + assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)" + assert mcode(DiracDelta(x)) == "dirac(x)" + assert mcode(DiracDelta(x, 3)) == "dirac(3, x)" + assert mcode(Heaviside(x)) == "heaviside(x, 1/2)" + assert mcode(Heaviside(x, y)) == "heaviside(x, y)" + assert mcode(binomial(x, y)) == "bincoeff(x, y)" + assert mcode(Mod(x, y)) == "mod(x, y)" + + +def test_minmax(): + assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)" + assert mcode(Max(x, y, z)) == "max(x, max(y, z))" + assert mcode(Min(x, y, z)) == "min(x, min(y, z))" + + +def test_Pow(): + assert mcode(x**3) == "x.^3" + assert mcode(x**(y**3)) == "x.^(y.^3)" + assert mcode(x**Rational(2, 3)) == 'x.^(2/3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x).^(-x + y.^x)./(x.^2 + y)" + # For issue 14160 + assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x./(y.*y)' + + +def test_basic_ops(): + assert mcode(x*y) == "x.*y" + assert mcode(x + y) == "x + y" + assert mcode(x - y) == "x - y" + assert mcode(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert mcode(1/x) == '1./x' + assert mcode(x**-1) == mcode(x**-1.0) == '1./x' + assert mcode(1/sqrt(x)) == '1./sqrt(x)' + assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)' + assert mcode(sqrt(x)) == 'sqrt(x)' + assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)' + assert mcode(1/pi) == '1/pi' + assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi' + assert mcode(pi**-0.5) == '1/sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert mcode(3*x) == "3*x" + assert mcode(pi*x) == "pi*x" + assert mcode(3/x) == "3./x" + assert mcode(pi/x) == "pi./x" + assert mcode(x/3) == "x/3" + assert mcode(x/pi) == "x/pi" + assert mcode(x*y) == "x.*y" + assert mcode(3*x*y) == "3*x.*y" + assert mcode(3*pi*x*y) == "3*pi*x.*y" + assert mcode(x/y) == "x./y" + assert mcode(3*x/y) == "3*x./y" + assert mcode(x*y/z) == "x.*y./z" + assert mcode(x/y*z) == "x.*z./y" + assert mcode(1/x/y) == "1./(x.*y)" + assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)" + assert mcode(3*pi/x) == "3*pi./x" + assert mcode(S(3)/5) == "3/5" + assert mcode(S(3)/5*x) == "3*x/5" + assert mcode(x/y/z) == "x./(y.*z)" + assert mcode((x+y)/z) == "(x + y)./z" + assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)" + assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17) + assert mcode(x/3/pi) == "x/(3*pi)" + assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)" + + +def test_mix_number_pow_symbols(): + assert mcode(pi**3) == 'pi^3' + assert mcode(x**2) == 'x.^2' + assert mcode(x**(pi**3)) == 'x.^(pi^3)' + assert mcode(x**y) == 'x.^y' + assert mcode(x**(y**z)) == 'x.^(y.^z)' + assert mcode((x**y)**z) == '(x.^y).^z' + + +def test_imag(): + I = S('I') + assert mcode(I) == "1i" + assert mcode(5*I) == "5i" + assert mcode((S(3)/2)*I) == "3*1i/2" + assert mcode(3+4*I) == "3 + 4i" + assert mcode(sqrt(3)*I) == "sqrt(3)*1i" + + +def test_constants(): + assert mcode(pi) == "pi" + assert mcode(oo) == "inf" + assert mcode(-oo) == "-inf" + assert mcode(S.NegativeInfinity) == "-inf" + assert mcode(S.NaN) == "NaN" + assert mcode(S.Exp1) == "exp(1)" + assert mcode(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2" + assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17) + assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17) + + +def test_boolean(): + assert mcode(x & y) == "x & y" + assert mcode(x | y) == "x | y" + assert mcode(~x) == "~x" + assert mcode(x & y & z) == "x & y & z" + assert mcode(x | y | z) == "x | y | z" + assert mcode((x & y) | z) == "z | x & y" + assert mcode((x | y) & z) == "z & (x | y)" + + +def test_KroneckerDelta(): + from sympy.functions import KroneckerDelta + assert mcode(KroneckerDelta(x, y)) == "double(x == y)" + assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))" + assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)" + + +def test_Matrices(): + assert mcode(Matrix(1, 1, [10])) == "10" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]" + assert mcode(A) == expected + # row and columns + assert mcode(A[:,0]) == "[1; 0; 0]" + assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]" + # empty matrices + assert mcode(Matrix(0, 0, [])) == '[]' + assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]" + assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert mcode(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert mcode(A*B) == "A*B" + assert mcode(B*A) == "B*A" + assert mcode(2*A*B) == "2*A*B" + assert mcode(B*2*A) == "2*B*A" + assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)" + assert mcode(A**(x**2)) == "A^(x.^2)" + assert mcode(A**3) == "A^3" + assert mcode(A**S.Half) == "A^(1/2)" + + +def test_MatrixSolve(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + x = MatrixSymbol('x', n, 1) + assert mcode(MatrixSolve(A, x)) == "A \\ x" + +def test_special_matrices(): + assert mcode(6*Identity(3)) == "6*eye(3)" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}" + # scalar, matrix, empty matrix and empty list + assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}" + + +def test_octave_noninline(): + source = mcode((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "Catalan = %s;\n" + "me = (x + y)/Catalan;" + ) % Catalan.evalf(17) + assert source == expected + + +def test_octave_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(expr, assign_to="r") == ( + "r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));") + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x;\n" + "else\n" + " r = x.^2;\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n" + "(x < 2).*(x.^3) + (~(x < 2)).*( ...\n" + "(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))") + assert mcode(expr) == expected + assert mcode(expr, assign_to="r") == "r = " + expected + ";" + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x.^2;\n" + "elseif (x < 2)\n" + " r = x.^3;\n" + "elseif (x < 3)\n" + " r = x.^4;\n" + "else\n" + " r = x.^5;\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: mcode(expr)) + + +def test_octave_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x" + assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)" + assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3" + + +def test_octave_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert mcode(A, assign_to='a') == "a = [1 2 3];" + A = Matrix([[1, 2], [3, 4]]) + assert mcode(A, assign_to='A') == "A = [1 2; 3 4];" + + +def test_octave_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert mcode(A, assign_to=B) == "B = [1 2 3];" + raises(ValueError, lambda: mcode(A, assign_to=x)) + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert mcode(A, assign_to=B) == "B = 3;" + # FIXME? + #assert mcode(A, assign_to=x) == "x = 3;" + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2" + A = MatrixSymbol('AA', 1, 3) + assert mcode(A) == "AA" + assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)" + assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)" + + +def test_octave_boolean(): + assert mcode(True) == "true" + assert mcode(S.true) == "true" + assert mcode(False) == "false" + assert mcode(S.false) == "false" + + +def test_octave_not_supported(): + with raises(NotImplementedError): + mcode(S.ComplexInfinity) + f = Function('f') + assert mcode(f(x).diff(x), strict=False) == ( + "% Not supported in Octave:\n" + "% Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_octave_not_supported_not_on_whitelist(): + from sympy.functions.special.polynomials import assoc_laguerre + with raises(NotImplementedError): + mcode(assoc_laguerre(x, y, z)) + + +def test_octave_expint(): + assert mcode(expint(1, x)) == "expint(x)" + with raises(NotImplementedError): + mcode(expint(2, x)) + assert mcode(expint(y, x), strict=False) == ( + "% Not supported in Octave:\n" + "% expint\n" + "expint(y, x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless') + t2 = S('elsewhere') + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert mcode(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + n = Symbol('n') + assert mcode(C) == "A.*B" + assert mcode(C*v) == "(A.*B)*v" + assert mcode(h*C*v) == "h*(A.*B)*v" + assert mcode(C*A) == "(A.*B)*A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert mcode(C*x*y) == "(x.*y)*(A.*B)" + + # Testing HadamardPower: + assert mcode(HadamardPower(A, n)) == "A.**n" + assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)" + assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x*y + assert mcode(M) == ( + "sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)" + ) + + +def test_sinc(): + assert mcode(sinc(x)) == 'sinc(x/pi)' + assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)' + assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)' + + +def test_trigfun(): + for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc, + sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth, + asech, acsch): + assert octave_code(f(x) == f.__name__ + '(x)') + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert octave_code(f(n, x)) == f.__name__ + '(n, x)' + for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma): + assert octave_code(f(x)) == f.__name__ + '(x)' + assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)' + assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)' + assert octave_code(airyai(x)) == 'airy(0, x)' + assert octave_code(airyaiprime(x)) == 'airy(1, x)' + assert octave_code(airybi(x)) == 'airy(2, x)' + assert octave_code(airybiprime(x)) == 'airy(3, x)' + assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))' + assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))' + assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))' + assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2' + assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2' + assert octave_code(LambertW(x)) == 'lambertw(x)' + assert octave_code(LambertW(x, n)) == 'lambertw(n, x)' + + # Automatic rewrite + assert octave_code(Ei(x)) == '(logint(exp(x)))' + assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))' + assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert mcode(A[0, 0]) == "A(1, 1)" + assert mcode(3 * A[0, 0]) == "3*A(1, 1)" + + F = C[0, 0].subs(C, A - B) + assert mcode(F) == "(A - B)(1, 1)" + + +def test_zeta_printing_issue_14820(): + assert octave_code(zeta(x)) == 'zeta(x)' + with raises(NotImplementedError): + octave_code(zeta(x, y)) + + +def test_automatic_rewrite(): + assert octave_code(Li(x)) == '(logint(x) - logint(2))' + assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))' diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_precedence.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_precedence.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ea07483857e8c2ee7f930aa53d2dacdc58193 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_precedence.py @@ -0,0 +1,128 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative, Function +from sympy.core.numbers import Integer, Rational, Float, oo +from sympy.core.relational import Rel +from sympy.core.symbol import symbols +from sympy.functions import sin +from sympy.integrals.integrals import Integral +from sympy.series.order import Order + +from sympy.printing.precedence import precedence, PRECEDENCE + +x, y = symbols("x,y") + + +def test_Add(): + assert precedence(x + y) == PRECEDENCE["Add"] + assert precedence(x*y + 1) == PRECEDENCE["Add"] + + +def test_Function(): + assert precedence(sin(x)) == PRECEDENCE["Func"] + +def test_Derivative(): + assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"] + +def test_Integral(): + assert precedence(Integral(x, y)) == PRECEDENCE["Atom"] + + +def test_Mul(): + assert precedence(x*y) == PRECEDENCE["Mul"] + assert precedence(-x*y) == PRECEDENCE["Add"] + + +def test_Number(): + assert precedence(Integer(0)) == PRECEDENCE["Atom"] + assert precedence(Integer(1)) == PRECEDENCE["Atom"] + assert precedence(Integer(-1)) == PRECEDENCE["Add"] + assert precedence(Integer(10)) == PRECEDENCE["Atom"] + assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"] + assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"] + assert precedence(Float(5)) == PRECEDENCE["Atom"] + assert precedence(Float(-5)) == PRECEDENCE["Add"] + assert precedence(oo) == PRECEDENCE["Atom"] + assert precedence(-oo) == PRECEDENCE["Add"] + + +def test_Order(): + assert precedence(Order(x)) == PRECEDENCE["Atom"] + + +def test_Pow(): + assert precedence(x**y) == PRECEDENCE["Pow"] + assert precedence(-x**y) == PRECEDENCE["Add"] + assert precedence(x**-y) == PRECEDENCE["Pow"] + + +def test_Product(): + assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Relational(): + assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"] + + +def test_Sum(): + assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Symbol(): + assert precedence(x) == PRECEDENCE["Atom"] + + +def test_And_Or(): + # precedence relations between logical operators, ... + assert precedence(x & y) > precedence(x | y) + assert precedence(~y) > precedence(x & y) + # ... and with other operators (cfr. other programming languages) + assert precedence(x + y) > precedence(x | y) + assert precedence(x + y) > precedence(x & y) + assert precedence(x*y) > precedence(x | y) + assert precedence(x*y) > precedence(x & y) + assert precedence(~y) > precedence(x*y) + assert precedence(~y) > precedence(x - y) + # double checks + assert precedence(x & y) == PRECEDENCE["And"] + assert precedence(x | y) == PRECEDENCE["Or"] + assert precedence(~y) == PRECEDENCE["Not"] + + +def test_custom_function_precedence_comparison(): + """ + Test cases for custom functions with different precedence values, + specifically handling: + 1. Functions with precedence < PRECEDENCE["Mul"] (50) + 2. Functions with precedence = Func (70) + + Key distinction: + 1. Lower precedence functions (45) need parentheses: -2*(x F y) + 2. Higher precedence functions (70) don't: -2*x F y + """ + class LowPrecedenceF(Function): + precedence = PRECEDENCE["Mul"] - 5 + def _sympystr(self, printer): + return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}" + + class HighPrecedenceF(Function): + precedence = PRECEDENCE["Func"] + def _sympystr(self, printer): + return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}" + + def test_low_precedence(): + expr1 = 2 * LowPrecedenceF(x, y) + assert str(expr1) == "2*(x F y)" + + expr2 = -2 * LowPrecedenceF(x, y) + assert str(expr2) == "-2*(x F y)" + + def test_high_precedence(): + expr1 = 2 * HighPrecedenceF(x, y) + assert str(expr1) == "2*x F y" + + expr2 = -2 * HighPrecedenceF(x, y) + assert str(expr2) == "-2*x F y" + + test_low_precedence() + test_high_precedence() diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_preview.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_preview.py new file mode 100644 index 0000000000000000000000000000000000000000..91771ceb0466d6b0fee00570426713d02da14872 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_preview.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.preview import preview + +from io import BytesIO + + +def test_preview(): + x = Symbol('x') + obj = BytesIO() + try: + preview(x, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_unicode_symbol(): + # issue 9107 + a = Symbol('α') + obj = BytesIO() + try: + preview(a, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_latex_construct_in_expr(): + # see PR 9801 + x = Symbol('x') + pw = Piecewise((1, Eq(x, 0)), (0, True)) + obj = BytesIO() + try: + preview(pw, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_pycode.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_pycode.py new file mode 100644 index 0000000000000000000000000000000000000000..2c38fe81d830149cdce6b55f15e6e07513fdd146 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_pycode.py @@ -0,0 +1,493 @@ +from sympy import Not +from sympy.codegen import Assignment +from sympy.codegen.ast import none +from sympy.codegen.cfunctions import expm1, log1p +from sympy.codegen.scipy_nodes import cosm1 +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow +from sympy.core.function import Derivative +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec, log, sin, cos, tan, asin, atan, sinh, cosh, tanh, asinh, acosh, atanh +from sympy.functions.elementary.trigonometric import atan2 +from sympy.logic import And, Or +from sympy.matrices import SparseMatrix, MatrixSymbol, Identity +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.pycode import ( + MpmathPrinter, CmathPrinter, PythonCodePrinter, pycode, SymPyPrinter +) +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter +from sympy.testing.pytest import raises, skip +from sympy.tensor import IndexedBase, Idx +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray +from sympy.external import import_module +from sympy.functions.special.gamma_functions import loggamma + + + +x, y, z = symbols('x y z') +p = IndexedBase("p") + + +def test_PythonCodePrinter(): + prntr = PythonCodePrinter() + + assert not prntr.module_imports + + assert prntr.doprint(x**y) == 'x**y' + assert prntr.doprint(Mod(x, 2)) == 'x % 2' + assert prntr.doprint(-Mod(x, y)) == '-(x % y)' + assert prntr.doprint(Mod(-x, y)) == '(-x) % y' + assert prntr.doprint(And(x, y)) == 'x and y' + assert prntr.doprint(Or(x, y)) == 'x or y' + assert prntr.doprint(1/(x+y)) == '1/(x + y)' + assert prntr.doprint(Not(x)) == 'not x' + assert not prntr.module_imports + + assert prntr.doprint(pi) == 'math.pi' + assert prntr.module_imports == {'math': {'pi'}} + + assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)' + assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)' + assert prntr.module_imports == {'math': {'pi', 'sqrt'}} + + assert prntr.doprint(acos(x)) == 'math.acos(x)' + assert prntr.doprint(cot(x)) == '(1/math.tan(x))' + assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))' + assert prntr.doprint(asec(x)) == '(math.acos(1/x))' + assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))' + + assert prntr.doprint(Assignment(x, 2)) == 'x = 2' + assert prntr.doprint(Piecewise((1, Eq(x, 0)), + (2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)' + assert prntr.doprint(Piecewise((2, Le(x, 0)), + (3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\ + ' (3) if (x > 0) else None)' + assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))' + assert prntr.doprint(p[0, 1]) == 'p[0, 1]' + assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)' + + assert prntr.doprint((2,3)) == "(2, 3)" + assert prntr.doprint([2,3]) == "[2, 3]" + + assert prntr.doprint(Min(x, y)) == "min(x, y)" + assert prntr.doprint(Max(x, y)) == "max(x, y)" + + +def test_PythonCodePrinter_standard(): + prntr = PythonCodePrinter() + + assert prntr.standard == 'python3' + + raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'})) + + +def test_CmathPrinter(): + p = CmathPrinter() + + assert p.doprint(sqrt(x)) == 'cmath.sqrt(x)' + assert p.doprint(log(x)) == 'cmath.log(x)' + + assert p.doprint(sin(x)) == 'cmath.sin(x)' + assert p.doprint(cos(x)) == 'cmath.cos(x)' + assert p.doprint(tan(x)) == 'cmath.tan(x)' + + assert p.doprint(asin(x)) == 'cmath.asin(x)' + assert p.doprint(acos(x)) == 'cmath.acos(x)' + assert p.doprint(atan(x)) == 'cmath.atan(x)' + + assert p.doprint(sinh(x)) == 'cmath.sinh(x)' + assert p.doprint(cosh(x)) == 'cmath.cosh(x)' + assert p.doprint(tanh(x)) == 'cmath.tanh(x)' + + assert p.doprint(asinh(x)) == 'cmath.asinh(x)' + assert p.doprint(acosh(x)) == 'cmath.acosh(x)' + assert p.doprint(atanh(x)) == 'cmath.atanh(x)' + + +def test_MpmathPrinter(): + p = MpmathPrinter() + assert p.doprint(sign(x)) == 'mpmath.sign(x)' + assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)' + + assert p.doprint(S.Exp1) == 'mpmath.e' + assert p.doprint(S.Pi) == 'mpmath.pi' + assert p.doprint(S.GoldenRatio) == 'mpmath.phi' + assert p.doprint(S.EulerGamma) == 'mpmath.euler' + assert p.doprint(S.NaN) == 'mpmath.nan' + assert p.doprint(S.Infinity) == 'mpmath.inf' + assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf' + assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)' + + +def test_NumPyPrinter(): + from sympy.core.function import Lambda + from sympy.matrices.expressions.adjoint import Adjoint + from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf) + from sympy.matrices.expressions.funcmatrix import FunctionMatrix + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix) + from sympy.abc import a, b + p = NumPyPrinter() + assert p.doprint(sign(x)) == 'numpy.sign(x)' + A = MatrixSymbol("A", 2, 2) + B = MatrixSymbol("B", 2, 2) + C = MatrixSymbol("C", 1, 5) + D = MatrixSymbol("D", 3, 4) + assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)" + assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)" + assert p.doprint(Identity(3)) == "numpy.eye(3)" + + u = MatrixSymbol('x', 2, 1) + v = MatrixSymbol('y', 2, 1) + assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)' + assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y' + + assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))" + assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))" + assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \ + "numpy.fromfunction(lambda a, b: a + b, (4, 5))" + assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)" + assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)" + assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))" + assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))" + assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)" + assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))" + + # Workaround for numpy negative integer power errors + assert p.doprint(x**-1) == 'x**(-1.0)' + assert p.doprint(x**-2) == 'x**(-2.0)' + + expr = Pow(2, -1, evaluate=False) + assert p.doprint(expr) == "2**(-1.0)" + + assert p.doprint(S.Exp1) == 'numpy.e' + assert p.doprint(S.Pi) == 'numpy.pi' + assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma' + assert p.doprint(S.NaN) == 'numpy.nan' + assert p.doprint(S.Infinity) == 'numpy.inf' + assert p.doprint(S.NegativeInfinity) == '-numpy.inf' + + # Function rewriting operator precedence fix + assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2' + + +def test_issue_18770(): + numpy = import_module('numpy') + if not numpy: + skip("numpy not installed.") + + from sympy.functions.elementary.miscellaneous import (Max, Min) + from sympy.utilities.lambdify import lambdify + + expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1) + func = lambdify(x, expr1, "numpy") + assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all() + assert func(4) == 3 + + expr1 = Max(x**2, x**3) + func = lambdify(x,expr1, "numpy") + assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all() + assert func(4) == 64 + + +def test_SciPyPrinter(): + p = SciPyPrinter() + expr = acos(x) + assert 'numpy' not in p.module_imports + assert p.doprint(expr) == 'numpy.arccos(x)' + assert 'numpy' in p.module_imports + assert not any(m.startswith('scipy') for m in p.module_imports) + smat = SparseMatrix(2, 5, {(0, 1): 3}) + assert p.doprint(smat) == \ + 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))' + assert 'scipy.sparse' in p.module_imports + + assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio' + assert p.doprint(S.Pi) == 'scipy.constants.pi' + assert p.doprint(S.Exp1) == 'numpy.e' + + +def test_pycode_reserved_words(): + s1, s2 = symbols('if else') + raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True)) + py_str = pycode(s1 + s2) + assert py_str in ('else_ + if_', 'if_ + else_') + + +def test_issue_20762(): + # Make sure pycode removes curly braces from subscripted variables + a_b, b, a_11 = symbols('a_{b} b a_{11}') + expr = a_b*b + assert pycode(expr) == 'a_b*b' + expr = a_11*b + assert pycode(expr) == 'a_11*b' + + +def test_sqrt(): + prntr = PythonCodePrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)' + assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)' + + prntr = PythonCodePrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)' + + prntr = MpmathPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == \ + "x**(mpmath.mpf(1)/mpmath.mpf(2))" + + prntr = NumPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SciPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SymPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_frac(): + from sympy.functions.elementary.integers import frac + + expr = frac(x) + prntr = NumPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == 'x % 1' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.frac(x)' + + prntr = SymPyPrinter() + assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)' + + +class CustomPrintedObject(Expr): + def _numpycode(self, printer): + return 'numpy' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + obj = CustomPrintedObject() + assert NumPyPrinter().doprint(obj) == 'numpy' + assert MpmathPrinter().doprint(obj) == 'mpmath' + + +def test_codegen_ast_nodes(): + assert pycode(none) == 'None' + + +def test_issue_14283(): + prntr = PythonCodePrinter() + + assert prntr.doprint(zoo) == "math.nan" + assert prntr.doprint(-oo) == "float('-inf')" + + +def test_NumPyPrinter_print_seq(): + n = NumPyPrinter() + + assert n._print_seq(range(2)) == '(0, 1,)' + + +def test_issue_16535_16536(): + from sympy.functions.special.gamma_functions import (lowergamma, uppergamma) + + a = symbols('a') + expr1 = lowergamma(a, x) + expr2 = uppergamma(a, x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)' + assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter({'strict': False}) + + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr1) + assert "Not supported" in p_pycode.doprint(expr) + + +def test_Integral(): + from sympy.functions.elementary.exponential import exp + from sympy.integrals.integrals import Integral + + single = Integral(exp(-x), (x, 0, oo)) + double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z)) + indefinite = Integral(x**2, x) + evaluateat = Integral(x**2, (x, 1)) + + prntr = SciPyPrinter() + assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]' + assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + prntr = MpmathPrinter() + assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))' + assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + +def test_fresnel_integrals(): + from sympy.functions.special.error_functions import (fresnelc, fresnels) + + expr1 = fresnelc(x) + expr2 = fresnels(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter() + p_mpmath = MpmathPrinter() + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr) + with raises(NotImplementedError): + p_pycode.doprint(expr) + + assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)' + assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)' + + +def test_beta(): + from sympy.functions.special.beta_functions import beta + + expr = beta(x, y) + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'scipy.special.beta(x, y)' + + prntr = NumPyPrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter({'allow_unknown_functions': True}) + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.beta(x, y)' + +def test_airy(): + from sympy.functions.special.bessel import (airyai, airybi) + + expr1 = airyai(x) + expr2 = airybi(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + +def test_airy_prime(): + from sympy.functions.special.bessel import (airyaiprime, airybiprime) + + expr1 = airyaiprime(x) + expr2 = airybiprime(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + +def test_numerical_accuracy_functions(): + prntr = SciPyPrinter() + assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)' + assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)' + assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)' + +def test_array_printer(): + A = ArraySymbol('A', (4,4,6,6,6)) + I = IndexedBase('I') + i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5)) + + prntr = NumPyPrinter() + assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + prntr = TensorflowPrinter() + assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + +def test_custom_Derivative_methods(): + class MyPrinter(SciPyPrinter): + def _print_Derivative_cosm1(self, args, seq_orders): + arg, = args + order, = seq_orders + return 'my_custom_cosm1(%s, deriv_order=%d)' % (self._print(arg), order) + + def _print_Derivative_atan2(self, args, seq_orders): + arg1, arg2 = args + ord1, ord2 = seq_orders + return 'my_custom_atan2(%s, %s, deriv1=%d, deriv2=%d)' % ( + self._print(arg1), self._print(arg2), ord1, ord2 + ) + + p = MyPrinter() + cosm1_1 = cosm1(x).diff(x, evaluate=False) + assert p.doprint(cosm1_1) == 'my_custom_cosm1(x, deriv_order=1)' + atan2_2_3 = atan2(x, y).diff(x, 2, y, 3, evaluate=False) + assert p.doprint(atan2_2_3) == 'my_custom_atan2(x, y, deriv1=2, deriv2=3)' + + try: + p.doprint(expm1(x).diff(x, evaluate=False)) + except PrintMethodNotImplementedError as e: + assert '_print_Derivative_expm1' in repr(e) + else: + assert False # should have thrown + + try: + p.doprint(Derivative(cosm1(x**2),x)) + except ValueError as e: + assert '_print_Derivative(' in repr(e) + else: + assert False # should have thrown diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_python.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_python.py new file mode 100644 index 0000000000000000000000000000000000000000..fb94a662be90934a672d08b3de44a22e2580d8b6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_python.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.series.limits import limit + +from sympy.printing.python import python + +from sympy.testing.pytest import raises, XFAIL + +x, y = symbols('x,y') +th = Symbol('theta') +ph = Symbol('phi') + + +def test_python_basic(): + # Simple numbers/symbols + assert python(-Rational(1)/2) == "e = Rational(-1, 2)" + assert python(-Rational(13)/22) == "e = Rational(-13, 22)" + assert python(oo) == "e = oo" + + # Powers + assert python(x**2) == "x = Symbol(\'x\')\ne = x**2" + assert python(1/x) == "x = Symbol('x')\ne = 1/x" + assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2" + assert python( + x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)" + + # Sums of terms + assert python(x**2 + x + 1) in [ + "x = Symbol('x')\ne = 1 + x + x**2", + "x = Symbol('x')\ne = x + x**2 + 1", + "x = Symbol('x')\ne = x**2 + x + 1", ] + assert python(1 - x) in [ + "x = Symbol('x')\ne = 1 - x", + "x = Symbol('x')\ne = -x + 1"] + assert python(1 - 2*x) in [ + "x = Symbol('x')\ne = 1 - 2*x", + "x = Symbol('x')\ne = -2*x + 1"] + assert python(1 - Rational(3, 2)*y/x) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x", + "y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1", + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"] + + # Multiplication + assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y" + assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y" + assert python((x + 2)/y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)", + "x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)", + "x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y", + "x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"] + assert python((1 + x)*y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ] + + # Check for proper placement of negative sign + assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)" + assert python(1 - Rational(3, 2)*(x + 1)) in [ + "x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)" + ] + + +def test_python_keyword_symbol_name_escaping(): + # Check for escaping of keywords + assert python( + 5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_" + assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) == + "lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__") + assert (python(5*Symbol("for") + Function("for_")(8)) == + "for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)") + + +def test_python_keyword_function_name_escaping(): + assert python( + 5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)" + + +def test_python_relational(): + assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)" + assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y" + assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y" + assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y" + assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y" + assert python(Ne(x/(y + 1), y**2)) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)", + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"] + + +def test_python_functions(): + # Simple + assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)" + assert python(sqrt(2)) == 'e = sqrt(2)' + assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)' + assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)' + assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)' + assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)' + assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)" + assert python( + Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))", + "x = Symbol('x')\ne = Abs(x/(x**2 + 1))"] + + # Univariate/Multivariate functions + f = Function('f') + assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)" + assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)" + assert python(f(x/(y + 1), y)) in [ + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)", + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"] + + # Nesting of square roots + assert python(sqrt((sqrt(x + 1)) + 1)) in [ + "x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))", + "x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"] + + # Nesting of powers + assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [ + "x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)", + "x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"] + + # Function powers + assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2" + + +@XFAIL +def test_python_functions_conjugates(): + a, b = map(Symbol, 'ab') + assert python( conjugate(a + b*I) ) == '_ _\na - I*b' + assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne ' + + +def test_python_derivatives(): + # Simple + f_1 = Derivative(log(x), x, evaluate=False) + assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)" + + f_2 = Derivative(log(x), x, evaluate=False) + x + assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)" + + # Multiple symbols + f_3 = Derivative(log(x) + x**2, x, y, evaluate=False) + assert python(f_3) == \ + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)" + + f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2 + assert python(f_4) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)", + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"] + + +def test_python_integrals(): + # Simple + f_1 = Integral(log(x), x) + assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)" + + f_2 = Integral(x**2, x) + assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)" + + # Double nesting of pow + f_3 = Integral(x**(2**x), x) + assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)" + + # Definite integrals + f_4 = Integral(x**2, (x, 1, 2)) + assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))" + + f_5 = Integral(x**2, (x, Rational(1, 2), 10)) + assert python( + f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))" + + # Nested integrals + f_6 = Integral(x**2*y**2, x, y) + assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)" + + +def test_python_matrix(): + p = python(Matrix([[x**2+1, 1], [y, x+y]])) + s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])" + assert p == s + +def test_python_limits(): + assert python(limit(x, x, oo)) == 'e = oo' + assert python(limit(x**2, x, 0)) == 'e = 0' + +def test_issue_20762(): + # Make sure Python removes curly braces from subscripted variables + a_b = Symbol('a_{b}') + b = Symbol('b') + expr = a_b*b + assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b" + + +def test_settings(): + raises(TypeError, lambda: python(x, method="garbage")) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_rcode.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_rcode.py new file mode 100644 index 0000000000000000000000000000000000000000..a83235b0654c6bf24c30846dbf68678d29cd3c80 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_rcode.py @@ -0,0 +1,476 @@ +from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + gamma, sign, Max, Min, factorial, beta) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.sets import Range +from sympy.logic import ITE +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises +from sympy.printing.rcode import RCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.rcode import rcode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _rcode(self, printer): + return "abs(%s)" % printer._print(self.args[0]) + + assert rcode(fabs(x)) == "abs(x)" + + +def test_rcode_sqrt(): + assert rcode(sqrt(x)) == "sqrt(x)" + assert rcode(x**0.5) == "sqrt(x)" + assert rcode(sqrt(x)) == "sqrt(x)" + + +def test_rcode_Pow(): + assert rcode(x**3) == "x^3" + assert rcode(x**(y**3)) == "x^(y^3)" + g = implemented_function('g', Lambda(x, 2*x)) + assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + assert rcode(x**-1.0) == '1.0/x' + assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)' + + +def test_rcode_Max(): + # Test for gh-11926 + assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_rcode_constants_mathh(): + assert rcode(exp(1)) == "exp(1)" + assert rcode(pi) == "pi" + assert rcode(oo) == "Inf" + assert rcode(-oo) == "-Inf" + + +def test_rcode_constants_other(): + assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio" + assert rcode( + 2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan" + assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma" + + +def test_rcode_Rational(): + assert rcode(Rational(3, 7)) == "3.0/7.0" + assert rcode(Rational(18, 9)) == "2" + assert rcode(Rational(3, -7)) == "-3.0/7.0" + assert rcode(Rational(-3, -7)) == "3.0/7.0" + assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x" + + +def test_rcode_Integer(): + assert rcode(Integer(67)) == "67" + assert rcode(Integer(-1)) == "-1" + + +def test_rcode_functions(): + assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)" + assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))" + + +def test_rcode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rcode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rcode( + g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n() + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + res=rcode(g(A[i]), assign_to=A[i]) + ref=( + "for (i in 1:n){\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}" + ) + assert res == ref + + +def test_rcode_exceptions(): + assert rcode(ceiling(x)) == "ceiling(x)" + assert rcode(Abs(x)) == "abs(x)" + assert rcode(gamma(x)) == "gamma(x)" + + +def test_rcode_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "myceil", + "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")], + } + assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)" + assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_rcode_boolean(): + assert rcode(True) == "True" + assert rcode(S.true) == "True" + assert rcode(False) == "False" + assert rcode(S.false) == "False" + assert rcode(x & y) == "x & y" + assert rcode(x | y) == "x | y" + assert rcode(~x) == "!x" + assert rcode(x & y & z) == "x & y & z" + assert rcode(x | y | z) == "x | y | z" + assert rcode((x & y) | z) == "z | x & y" + assert rcode((x | y) & z) == "z & (x | y)" + +def test_rcode_Relational(): + assert rcode(Eq(x, y)) == "x == y" + assert rcode(Ne(x, y)) == "x != y" + assert rcode(Le(x, y)) == "x <= y" + assert rcode(Lt(x, y)) == "x < y" + assert rcode(Gt(x, y)) == "x > y" + assert rcode(Ge(x, y)) == "x >= y" + + +def test_rcode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + res=rcode(expr) + ref="ifelse(x < 1,x,x^2)" + assert res == ref + tau=Symbol("tau") + res=rcode(expr,tau) + ref="tau = ifelse(x < 1,x,x^2);" + assert res == ref + + expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True)) + assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))" + res = rcode(expr, assign_to='c') + assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));" + + # Check that Piecewise without a True (default) condition error + #expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + #raises(ValueError, lambda: rcode(expr)) + expr = 2*Piecewise((x, x < 1), (x**2, x<2)) + assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))" + + +def test_rcode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + res = rcode(expr) + ref = "(ifelse(x != 0,sin(x)/x,1))" + assert res == ref + + +def test_rcode_Piecewise_deep(): + p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))" + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + p = rcode(expr) + ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1" + assert p == ref + + ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;" + p = rcode(expr, assign_to='c') + assert p == ref + + +def test_rcode_ITE(): + expr = ITE(x < 1, y, z) + p = rcode(expr) + ref="ifelse(x < 1,y,z)" + assert p == ref + + +def test_rcode_settings(): + raises(TypeError, lambda: rcode(sin(x), method="garbage")) + + +def test_rcode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = RCodePrinter() + p._not_r = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[i, j]' + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[i, j, k]' + + assert p._not_r == set() + +def test_rcode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = rcode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_rcode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = A[i, j]*x[j] + y[i];\n' + ' }\n' + '}' + ) + c = rcode(A[i, j]*x[j], assign_to=y[i]) + assert c == s + + +def test_dummy_loops(): + # the following line could also be + # [Dummy(s, integer=True) for s in 'im'] + # or [Dummy(integer=True) for s in 'im'] + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + expected = ( + 'for (i_%(icount)i in 1:m_%(mcount)i){\n' + ' y[i_%(icount)i] = x[i_%(icount)i];\n' + '}' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + code = rcode(x[i], assign_to=y[i]) + assert code == expected + + +def test_rcode_loops_add(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = x[i] + z[i];\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = A[i, j]*x[j] + y[i];\n' + ' }\n' + '}' + ) + c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' for (l in 1:p){\n' + ' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n' + ' }\n' + ' }\n' + ' }\n' + '}' + ) + c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_addfactor(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' for (l in 1:p){\n' + ' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n' + ' }\n' + ' }\n' + ' }\n' + '}' + ) + c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_multiple_terms(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + + s0 = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + ) + s1 = ( + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n' + ' }\n' + ' }\n' + '}\n' + ) + s2 = ( + 'for (i in 1:m){\n' + ' for (k in 1:o){\n' + ' y[i] = a[i, k]*b[k] + y[i];\n' + ' }\n' + '}\n' + ) + s3 = ( + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = a[i, j]*b[j] + y[i];\n' + ' }\n' + '}\n' + ) + c = rcode( + b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i]) + + ref={} + ref[0] = s0 + s1 + s2 + s3[:-1] + ref[1] = s0 + s1 + s3 + s2[:-1] + ref[2] = s0 + s2 + s1 + s3[:-1] + ref[3] = s0 + s2 + s3 + s1[:-1] + ref[4] = s0 + s3 + s1 + s2[:-1] + ref[5] = s0 + s3 + s2 + s1[:-1] + + assert (c == ref[0] or + c == ref[1] or + c == ref[2] or + c == ref[3] or + c == ref[4] or + c == ref[5]) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))" + + +def test_Matrix_printing(): + # Test returning a Matrix + mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + p = rcode(mat, A) + assert p == ( + "A[0] = x*y;\n" + "A[1] = ifelse(y > 0,x + 2,y);\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + p = rcode(expr) + assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert rcode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_rcode_sgn(): + + expr = sign(x) * y + assert rcode(expr) == 'y*sign(x)' + p = rcode(expr, 'z') + assert p == 'z = y*sign(x);' + + p = rcode(sign(2 * x + x**2) * x + x**2) + assert p == "x^2 + x*sign(x^2 + 2*x)" + + expr = sign(cos(x)) + p = rcode(expr) + assert p == 'sign(cos(x))' + +def test_rcode_Assignment(): + assert rcode(Assignment(x, y + z)) == 'x = y + z;' + assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_rcode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + sol = rcode(f) + assert sol == ("for(x in seq(from=0, to=9, by=2){\n" + " y *= x;\n" + "}") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(rcode(A[0, 0]) == "A[0]") + assert(rcode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(rcode(F) == "(A - B)[0]") diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_repr.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_repr.py new file mode 100644 index 0000000000000000000000000000000000000000..da58883b4fb027ed82db842a0a1ce5f76a49a8bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_repr.py @@ -0,0 +1,382 @@ +from __future__ import annotations +from typing import Any + +from sympy.external.gmpy import GROUND_TYPES +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.assumptions.ask import Q +from sympy.core.function import (Function, WildFunction) +from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.delta_functions import Heaviside +from sympy.logic.boolalg import (false, true) +from sympy.matrices.dense import (Matrix, ones) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.combinatorics import Cycle, Permutation +from sympy.core.symbol import Str +from sympy.geometry import Point, Ellipse +from sympy.printing import srepr +from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly +from sympy.polys.polyclasses import DMP +from sympy.polys.agca.extensions import FiniteExtension + +x, y = symbols('x,y') + +# eval(srepr(expr)) == expr has to succeed in the right environment. The right +# environment is the scope of "from sympy import *" for most cases. +ENV: dict[str, Any] = {"Str": Str} +exec("from sympy import *", ENV) + + +def sT(expr, string, import_stmt=None, **kwargs): + """ + sT := sreprTest + + Tests that srepr delivers the expected string and that + the condition eval(srepr(expr))==expr holds. + """ + if import_stmt is None: + ENV2 = ENV + else: + ENV2 = ENV.copy() + exec(import_stmt, ENV2) + + assert srepr(expr, **kwargs) == string + assert eval(string, ENV2) == expr + + +def test_printmethod(): + class R(Abs): + def _sympyrepr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert srepr(R(x)) == "foo(Symbol('x'))" + + +def test_Add(): + sT(x + y, "Add(Symbol('x'), Symbol('y'))") + assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))" + assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))" + assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))" + + +def test_more_than_255_args_issue_10259(): + from sympy.core.add import Add + from sympy.core.mul import Mul + for op in (Add, Mul): + expr = op(*symbols('x:256')) + assert eval(srepr(expr)) == expr + + +def test_Function(): + sT(Function("f")(x), "Function('f')(Symbol('x'))") + # test unapplied Function + sT(Function('f'), "Function('f')") + + sT(sin(x), "sin(Symbol('x'))") + sT(sin, "sin") + + +def test_Heaviside(): + sT(Heaviside(x), "Heaviside(Symbol('x'))") + sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))") + + +def test_Geometry(): + sT(Point(0, 0), "Point2D(Integer(0), Integer(0))") + sT(Ellipse(Point(0, 0), 5, 1), + "Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))") + # TODO more tests + + +def test_Singletons(): + sT(S.Catalan, 'Catalan') + sT(S.ComplexInfinity, 'zoo') + sT(S.EulerGamma, 'EulerGamma') + sT(S.Exp1, 'E') + sT(S.GoldenRatio, 'GoldenRatio') + sT(S.TribonacciConstant, 'TribonacciConstant') + sT(S.Half, 'Rational(1, 2)') + sT(S.ImaginaryUnit, 'I') + sT(S.Infinity, 'oo') + sT(S.NaN, 'nan') + sT(S.NegativeInfinity, '-oo') + sT(S.NegativeOne, 'Integer(-1)') + sT(S.One, 'Integer(1)') + sT(S.Pi, 'pi') + sT(S.Zero, 'Integer(0)') + sT(S.Complexes, 'Complexes') + sT(S.EmptySequence, 'EmptySequence') + sT(S.EmptySet, 'EmptySet') + # sT(S.IdentityFunction, 'Lambda(_x, _x)') + sT(S.Naturals, 'Naturals') + sT(S.Naturals0, 'Naturals0') + sT(S.Rationals, 'Rationals') + sT(S.Reals, 'Reals') + sT(S.UniversalSet, 'UniversalSet') + + +def test_Integer(): + sT(Integer(4), "Integer(4)") + + +def test_list(): + sT([x, Integer(4)], "[Symbol('x'), Integer(4)]") + + +def test_Matrix(): + for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]: + sT(cls([[x**+1, 1], [y, x + y]]), + "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + sT(cls(), "%s([])" % name) + + sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + +def test_empty_Matrix(): + sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])") + sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])") + sT(ones(0, 0), "MutableDenseMatrix([])") + + +def test_Rational(): + sT(Rational(1, 3), "Rational(1, 3)") + sT(Rational(-1, 3), "Rational(-1, 3)") + + +def test_Float(): + sT(Float('1.23', dps=3), "Float('1.22998', precision=13)") + sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', dps=19), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', dps=15), + "Float('0.60038617995049726', precision=53)") + + sT(Float('1.23', precision=13), "Float('1.22998', precision=13)") + sT(Float('1.23456789', precision=33), + "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', precision=66), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', precision=53), + "Float('0.60038617995049726', precision=53)") + + sT(Float('0.60038617995049726', 15), + "Float('0.60038617995049726', precision=53)") + + +def test_Symbol(): + sT(x, "Symbol('x')") + sT(y, "Symbol('y')") + sT(Symbol('x', negative=True), "Symbol('x', negative=True)") + + +def test_Symbol_two_assumptions(): + x = Symbol('x', negative=0, integer=1) + # order could vary + s1 = "Symbol('x', integer=True, negative=False)" + s2 = "Symbol('x', negative=False, integer=True)" + assert srepr(x) in (s1, s2) + assert eval(srepr(x), ENV) == x + + +def test_Symbol_no_special_commutative_treatment(): + sT(Symbol('x'), "Symbol('x')") + sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)") + sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)") + + +def test_Wild(): + sT(Wild('x', even=True), "Wild('x', even=True)") + + +def test_Dummy(): + d = Dummy('d') + sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index)) + + +def test_Dummy_assumption(): + d = Dummy('d', nonzero=True) + assert d == eval(srepr(d)) + s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index) + s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index) + assert srepr(d) in (s1, s2) + + +def test_Dummy_from_Symbol(): + # should not get the full dictionary of assumptions + n = Symbol('n', integer=True) + d = n.as_dummy() + assert srepr(d + ) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index) + + +def test_tuple(): + sT((x,), "(Symbol('x'),)") + sT((x, y), "(Symbol('x'), Symbol('y'))") + + +def test_WildFunction(): + sT(WildFunction('w'), "WildFunction('w')") + + +def test_settins(): + raises(TypeError, lambda: srepr(x, method="garbage")) + + +def test_Mul(): + sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))") + assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))" + assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))" + + +def test_AlgebraicNumber(): + a = AlgebraicNumber(sqrt(2)) + sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])") + a = AlgebraicNumber(root(-2, 3)) + sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])") + + +def test_PolyRing(): + assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)" + assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_FracField(): + assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)" + assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_PolyElement(): + R, x, y = ring("x,y", ZZ) + assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])" + + +def test_FracElement(): + F, x, y = field("x,y", ZZ) + assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])" + + +def test_FractionField(): + assert srepr(QQ.frac_field(x)) == \ + "FractionField(FracField((Symbol('x'),), QQ, lex))" + assert srepr(QQ.frac_field(x, y, order=grlex)) == \ + "FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))" + + +def test_PolynomialRingBase(): + assert srepr(ZZ.old_poly_ring(x)) == \ + "GlobalPolynomialRing(ZZ, Symbol('x'))" + assert srepr(ZZ[x].old_poly_ring(y)) == \ + "GlobalPolynomialRing(ZZ[x], Symbol('y'))" + assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \ + "GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))" + + +def test_DMP(): + p1 = DMP([1, 2], ZZ) + p2 = ZZ.old_poly_ring(x)([1, 2]) + if GROUND_TYPES != 'flint': + assert srepr(p1) == "DMP_Python([1, 2], ZZ)" + assert srepr(p2) == "DMP_Python([1, 2], ZZ)" + else: + assert srepr(p1) == "DUP_Flint([1, 2], ZZ)" + assert srepr(p2) == "DUP_Flint([1, 2], ZZ)" + + +def test_FiniteExtension(): + assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \ + "FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))" + + +def test_ExtensionElement(): + A = FiniteExtension(Poly(x**2 + 1, x)) + if GROUND_TYPES != 'flint': + ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + else: + ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + assert srepr(A.generator) == ans + +def test_BooleanAtom(): + assert srepr(true) == "true" + assert srepr(false) == "false" + + +def test_Integers(): + sT(S.Integers, "Integers") + + +def test_Naturals(): + sT(S.Naturals, "Naturals") + + +def test_Naturals0(): + sT(S.Naturals0, "Naturals0") + + +def test_Reals(): + sT(S.Reals, "Reals") + + +def test_matrix_expressions(): + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))") + sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + + +def test_Cycle(): + # FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2)) + # adds keys to the Cycle dict (GH-17661) + #import_stmt = "from sympy.combinatorics import Cycle" + #sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt) + assert srepr(Cycle(1, 2)) == "Cycle(1, 2)" + + +def test_Permutation(): + import_stmt = "from sympy.combinatorics import Permutation" + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False) + sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True) + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt) + Permutation.print_cyclic = old_print_cyclic + +def test_dict(): + from sympy.abc import x, y, z + d = {} + assert srepr(d) == "{}" + d = {x: y} + assert srepr(d) == "{Symbol('x'): Symbol('y')}" + d = {x: y, y: z} + assert srepr(d) in ( + "{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}", + "{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}", + ) + d = {x: {y: z}} + assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}" + +def test_set(): + from sympy.abc import x, y + s = set() + assert srepr(s) == "set()" + s = {x, y} + assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}") + +def test_Predicate(): + sT(Q.even, "Q.even") + +def test_AppliedPredicate(): + sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))") diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_rust.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_rust.py new file mode 100644 index 0000000000000000000000000000000000000000..c81d592faca0d4a31e5a9618a48d67cb19ca94d8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_rust.py @@ -0,0 +1,363 @@ +from sympy.core import (S, pi, oo, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy, + Eq, Ne, Le, Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sign, floor) +from sympy.logic import ITE +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix + +from sympy.printing.codeprinter import rust_code + +x, y, z = symbols('x,y,z', integer=False, real=True) +k, m, n = symbols('k,m,n', integer=True) + + +def test_Integer(): + assert rust_code(Integer(42)) == "42" + assert rust_code(Integer(-56)) == "-56" + + +def test_Relational(): + assert rust_code(Eq(x, y)) == "x == y" + assert rust_code(Ne(x, y)) == "x != y" + assert rust_code(Le(x, y)) == "x <= y" + assert rust_code(Lt(x, y)) == "x < y" + assert rust_code(Gt(x, y)) == "x > y" + assert rust_code(Ge(x, y)) == "x >= y" + + +def test_Rational(): + assert rust_code(Rational(3, 7)) == "3_f64/7.0" + assert rust_code(Rational(18, 9)) == "2" + assert rust_code(Rational(3, -7)) == "-3_f64/7.0" + assert rust_code(Rational(-3, -7)) == "3_f64/7.0" + assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0" + assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x" + + +def test_basic_ops(): + assert rust_code(x + y) == "x + y" + assert rust_code(x - y) == "x - y" + assert rust_code(x * y) == "x*y" + assert rust_code(x / y) == "x*y.recip()" + assert rust_code(-x) == "-x" + assert rust_code(2 * x) == "2.0*x" + assert rust_code(y + 2) == "y + 2.0" + assert rust_code(x + n) == "n as f64 + x" + +def test_printmethod(): + class fabs(Abs): + def _rust_code(self, printer): + return "%s.fabs()" % printer._print(self.args[0]) + assert rust_code(fabs(x)) == "x.fabs()" + a = MatrixSymbol("a", 1, 3) + assert rust_code(a[0,0]) == 'a[0]' + + +def test_Functions(): + assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())" + assert rust_code(abs(x)) == "x.abs()" + assert rust_code(ceiling(x)) == "x.ceil()" + assert rust_code(floor(x)) == "x.floor()" + + # Automatic rewrite + assert rust_code(Mod(x, 3)) == 'x - 3.0*((1_f64/3.0)*x).floor()' + + +def test_Pow(): + assert rust_code(1/x) == "x.recip()" + assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()" + assert rust_code(sqrt(x)) == "x.sqrt()" + assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()" + + assert rust_code(1/sqrt(x)) == "x.sqrt().recip()" + assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()" + + assert rust_code(1/pi) == "PI.recip()" + assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()" + assert rust_code(pi**-0.5) == "PI.sqrt().recip()" + + assert rust_code(x**Rational(1, 3)) == "x.cbrt()" + assert rust_code(2**x) == "x.exp2()" + assert rust_code(exp(x)) == "x.exp()" + assert rust_code(x**3) == "x.powi(3)" + assert rust_code(x**(y**3)) == "x.powf(y.powi(3))" + assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)" + + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2.0*x).powf(-x + y.powf(x))/(x.powi(2) + y)" + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1), + (lambda base, exp: not exp.is_integer, "pow", 1)] + assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)' + assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)' + + +def test_constants(): + assert rust_code(pi) == "PI" + assert rust_code(oo) == "INFINITY" + assert rust_code(S.Infinity) == "INFINITY" + assert rust_code(-oo) == "NEG_INFINITY" + assert rust_code(S.NegativeInfinity) == "NEG_INFINITY" + assert rust_code(S.NaN) == "NAN" + assert rust_code(exp(1)) == "E" + assert rust_code(S.Exp1) == "E" + + +def test_constants_other(): + assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2.0*GoldenRatio" % GoldenRatio.evalf(17) + assert rust_code( + 2*Catalan) == "const Catalan: f64 = %s;\n2.0*Catalan" % Catalan.evalf(17) + assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2.0*EulerGamma" % EulerGamma.evalf(17) + + +def test_boolean(): + assert rust_code(True) == "true" + assert rust_code(S.true) == "true" + assert rust_code(False) == "false" + assert rust_code(S.false) == "false" + assert rust_code(k & m) == "k && m" + assert rust_code(k | m) == "k || m" + assert rust_code(~k) == "!k" + assert rust_code(k & m & n) == "k && m && n" + assert rust_code(k | m | n) == "k || m || n" + assert rust_code((k & m) | n) == "n || k && m" + assert rust_code((k | m) & n) == "n && (k || m)" + + +def test_Piecewise(): + expr = Piecewise((x, x < 1), (x + 2, True)) + assert rust_code(expr) == ( + "if (x < 1.0) {\n" + " x\n" + "} else {\n" + " x + 2.0\n" + "}") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1.0) {\n" + " x\n" + "} else {\n" + " x + 2.0\n" + "};") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1.0) { x } else { x + 2.0 };") + expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 };") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1.0) {\n" + " x\n" + "} else if (x < 5.0) {\n" + " x + 1.0\n" + "} else {\n" + " x + 2.0\n" + "};") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42 + assert rust_code(expr, inline=True) == ( + "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 } - 42.0") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: rust_code(expr)) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()" + + +def test_sign(): + expr = sign(x) * y + assert rust_code(expr) == "y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64" + assert rust_code(expr, assign_to='r') == "r = y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64;" + + expr = sign(x + y) + 42 + assert rust_code(expr) == "(if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42" + assert rust_code(expr, assign_to='r') == "r = (if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42;" + + expr = sign(cos(x)) + assert rust_code(expr) == "(if (x.cos() == 0.0) { 0.0 } else { (x.cos()).signum() })" + + +def test_reserved_words(): + + x, y = symbols("x if") + + expr = sin(y) + assert rust_code(expr) == "if_.sin()" + assert rust_code(expr, dereference=[y]) == "(*if_).sin()" + assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()" + + with raises(ValueError): + rust_code(expr, error_on_reserved=True) + + +def test_ITE(): + ekpr = ITE(k < 1, m, n) + assert rust_code(ekpr) == ( + "if (k < 1) {\n" + " m\n" + "} else {\n" + " n\n" + "}") + + +def test_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + assert rust_code(x) == "x[j]" + + A = IndexedBase('A')[i, j] + assert rust_code(A) == "A[m*i + j]" + + B = IndexedBase('B')[i, j, k] + assert rust_code(B) == "B[m*o*i + o*j + k]" + + +def test_dummy_loops(): + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + assert rust_code(x[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i];\n" + "}") + + +def test_loops(): + m, n = symbols('m n', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + assert rust_code(A[i, j]*x[j], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i] + z[i];\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + +def test_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_loops_addfactor(): + m, n, o, p = symbols('m n o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert code == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_settings(): + raises(TypeError, lambda: rust_code(sin(x), method="garbage")) + + +def test_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(g(x)) == "2*x" + + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rust_code(g(x)) == ( + "const Catalan: f64 = %s;\n2.0*x/Catalan" % Catalan.evalf(17)) + + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert rust_code(g(A[i]), assign_to=A[i]) == ( + "for i in 0..n {\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}") + + +def test_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "ceil", + "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)], + } + assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()" + assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_matrix(): + assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]' + with raises(ValueError): + rust_code(Matrix([[1, 2, 3]])) + + +def test_sparse_matrix(): + # gh-15791 + with raises(NotImplementedError): + rust_code(SparseMatrix([[1, 2, 3]])) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_smtlib.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_smtlib.py new file mode 100644 index 0000000000000000000000000000000000000000..48ff3d432d9042bf178f4e52dc46c787059937a3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_smtlib.py @@ -0,0 +1,553 @@ +import contextlib +import itertools +import re +import typing +from enum import Enum +from typing import Callable + +import sympy +from sympy import Add, Implies, sqrt +from sympy.core import Mul, Pow +from sympy.core import (S, pi, symbols, Function, Rational, Integer, + Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, exp, sin, cos +from sympy.assumptions.ask import Q +from sympy.printing.smtlib import smtlib_code +from sympy.testing.pytest import raises, Failed + +x, y, z = symbols('x,y,z') + + +class _W(Enum): + DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.IGNORECASE) + WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.IGNORECASE) + WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.IGNORECASE) + + +@contextlib.contextmanager +def _check_warns(expected: typing.Iterable[_W]): + warns: typing.List[str] = [] + log_warn = warns.append + yield log_warn + + errors = [] + for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)): + if not e: + errors += [f"[{i}] Received unexpected warning `{w}`."] + elif not w: + errors += [f"[{i}] Did not receive expected warning `{e.name}`."] + elif not e.value.match(w): + errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."] + + if errors: raise Failed('\n'.join(errors)) + + +def test_Integer(): + with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(Integer(67), log_warn=w) == "67" + assert smtlib_code(Integer(-1), log_warn=w) == "-1" + with _check_warns([]) as w: + assert smtlib_code(Integer(67)) == "67" + assert smtlib_code(Integer(-1)) == "-1" + + +def test_Rational(): + with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w: + assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)" + assert smtlib_code(Rational(18, 9), log_warn=w) == "2" + assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)" + assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)" + assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \ + "(* (/ 3 7) x)" + + +def test_Relational(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + +def test_AppliedBinaryRelation(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w)) + + +def test_AppliedPredicate(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w: + assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))" + assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))" + assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))" + assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))" + assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))" + assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))" + +def test_Function(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))" + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + abs(x), + symbol_table={x: int, y: bool}, + known_types={int: "INTEGER_TYPE"}, + known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"}, + log_warn=w + ) == "(declare-const x INTEGER_TYPE)\n" \ + "(ABSOLUTE_VALUE_OF x)" + + my_fun1 = Function('f1') + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], float]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Real)\n" \ + "(f1 x)" + + with _check_warns([]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], bool]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Bool)\n" \ + "(assert (f1 x))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(declare-fun f1 (Int Bool) Bool)\n" \ + "(assert (= (f1 x z) y))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w: + assert smtlib_code( + Eq(my_fun1(x, z), y), + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Real)\n" \ + "(declare-const y Real)\n" \ + "(declare-const z Real)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + +def test_Pow(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + def g(x): return 2 * x + + # if x=1, y=2, then expr=2.333... + expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b) + + with _check_warns([]) as w: + assert smtlib_code( + [ + Eq(a < 2, c), + Eq(b > a, c), + c & True, + Eq(expr, 2 + Rational(1, 3)) + ], + log_warn=w + ) == '(declare-const a Int)\n' \ + '(declare-const b Real)\n' \ + '(declare-const c Bool)\n' \ + '(assert (= (< a 2) c))\n' \ + '(assert (= (> b a) c))\n' \ + '(assert c)\n' \ + '(assert (= ' \ + '(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \ + '(/ 7 3)' \ + '))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False), + log_warn=w + ) == '(declare-const b Real)\n' \ + '(declare-const c Real)\n' \ + '(* -2 c (pow (* b b) -1))' + + +def test_basic_ops(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)" + + # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w: + # todo: implement re-write, currently does '(+ x (* -1 y))' instead + # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)" + + +def test_quantifier_extensions(): + from sympy.logic.boolalg import Boolean + from sympy import Interval, Tuple, sympify + + # start For-all quantifier class example + class ForAll(Boolean): + def _smtlib(self, printer): + bound_symbol_declarations = [ + printer._s_expr(sym.name, [ + printer._known_types[printer.symbol_table[sym]], + Interval(start, end) + ]) for sym, start, end in self.limits + ] + return printer._s_expr('forall', [ + printer._s_expr('', bound_symbol_declarations), + self.function + ]) + + @property + def bound_symbols(self): + return {s for s, _, _ in self.limits} + + @property + def free_symbols(self): + bound_symbol_names = {s.name for s in self.bound_symbols} + return { + s for s in self.function.free_symbols + if s.name not in bound_symbol_names + } + + def __new__(cls, *args): + limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))] + function = [sympify(a) for a in args if isinstance(a, Boolean)] + assert len(limits) + len(function) == len(args) + assert len(function) == 1 + function = function[0] + + if isinstance(function, ForAll): return ForAll.__new__( + ForAll, *(limits + function.limits), function.function + ) + inst = Boolean.__new__(cls) + inst._args = tuple(limits + [function]) + inst.limits = limits + inst.function = function + return inst + + # end For-All Quantifier class example + + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + ForAll((x, -42, +21), Eq(f(x), f(x))), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(assert (forall ( (x Real [-42, 21])) true))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w: + assert smtlib_code( + ForAll( + (x, -42, +21), (y, -100, 3), + Implies(Eq(x, y), Eq(f(x), f(y))) + ), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(declare-fun f (Real) Real)\n' \ + '(assert (' \ + 'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \ + '(=> (= x y) (= (f x) (f y)))' \ + '))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + with _check_warns([]) as w: + assert smtlib_code( + ForAll( + (a, 2, 100), ForAll( + (b, 2, 100), + Implies(a < b, sqrt(a) < b) | c + )), + log_warn=w + ) == '(declare-const c Bool)\n' \ + '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \ + '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \ + '))' + + +def test_mix_number_mult_symbols(): + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + 1 / pi, + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + [ + Eq(pi, 3.14, evaluate=False), + 1 / pi, + ], + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(assert (= MY_PI 3.14))\n' \ + '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p', S.GoldenRatio: 'g', + S.Exp1: 'e' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={S.Exp1: 'e'}, + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)' + + +def test_boolean(): + with _check_warns([]) as w: + assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (and x y))' + assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (or x y))' + assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \ + '(assert (not x))' + assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Bool)\n' \ + '(assert (and x y z))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Real)\n' \ + '(assert (or (> z 3) (and x (not y))))' + + f = Function('f') + g = Function('g') + h = Function('h') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + [Gt(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Real)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (> (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Eq(g(f(x)), z), + Eq(h(g(f(x))), x)], + symbol_table={ + f: Callable[[float], int], + g: Callable[[int], bool], + h: Callable[[bool], float] + }, + log_warn=w + ) == '(declare-const x Real)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Real) Int)\n' \ + '(declare-fun g (Int) Bool)\n' \ + '(declare-fun h (Bool) Real)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (= (g (f x)) z))\n' \ + '(assert (= (h (g (f x))) x))' + + +# todo: make smtlib_code support arrays +# def test_containers(): +# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ +# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" +# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" +# assert julia_code([1]) == "Any[1]" +# assert julia_code((1,)) == "(1,)" +# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" +# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))" +# # scalar, matrix, empty matrix and empty list +# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + +def test_smtlib_piecewise(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x, x < 1), + (x ** 2, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) x (pow x 2))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x ** 2, x < 1), + (x ** 3, x < 2), + (x ** 4, x < 3), + (x ** 5, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) (pow x 2) ' \ + '(ite (< x 2) (pow x 3) ' \ + '(ite (< x 3) (pow x 4) ' \ + '(pow x 5))))' + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(AssertionError, lambda: smtlib_code(expr, log_warn=w)) + + +def test_smtlib_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))' + + +# todo: make smtlib_code support arrays / matrices ? +# def test_smtlib_matrix_assign_to(): +# A = Matrix([[1, 2, 3]]) +# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]" +# A = Matrix([[1, 2], [3, 4]]) +# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]" + +# def test_julia_matrix_1x1(): +# A = Matrix([[3]]) +# B = MatrixSymbol('B', 1, 1) +# C = MatrixSymbol('C', 1, 2) +# assert julia_code(A, assign_to=B) == "B = [3]" +# raises(ValueError, lambda: julia_code(A, assign_to=C)) + +# def test_julia_matrix_elements(): +# A = Matrix([[x, 2, x * y]]) +# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" +# A = MatrixSymbol('AA', 1, 3) +# assert julia_code(A) == "AA" +# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \ +# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" +# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + +def test_smtlib_boolean(): + with _check_warns([]) as w: + assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true' + assert smtlib_code(True, log_warn=w) == '(assert true)' + assert smtlib_code(S.true, log_warn=w) == '(assert true)' + assert smtlib_code(S.false, log_warn=w) == '(assert false)' + assert smtlib_code(False, log_warn=w) == '(assert false)' + assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false' + + +def test_not_supported(): + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w)) + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w)) + + +def test_Float(): + assert smtlib_code(0.0) == "0.0" + assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))' + assert smtlib_code(5.3) == "5.3" diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_str.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_str.py new file mode 100644 index 0000000000000000000000000000000000000000..675212964b03bf9a9806088225c28d7f70971ca7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_str.py @@ -0,0 +1,1206 @@ +from sympy import MatAdd +from sympy.algebras.quaternion import Quaternion +from sympy.assumptions.ask import Q +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.partitions import Partition +from sympy.concrete.summations import (Sum, summation) +from sympy.core.add import Add +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import UnevaluatedExpr, Expr +from sympy.core.function import (Derivative, Function, Lambda, Subs, WildFunction) +from sympy.core.mul import Mul +from sympy.core import (Catalan, EulerGamma, GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi, zoo) +from sympy.core.parameters import _exp_is_pow +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Rel, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (factorial, factorial2, subfactorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.zeta_functions import zeta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Equivalent, false, true, Xor) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions import Identity +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices import SparseMatrix +from sympy.polys.polytools import factor +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.sets.sets import (Complement, FiniteSet, Interval, SymmetricDifference) +from sympy.stats import (Covariance, Expectation, Probability, Variance) +from sympy.stats.rv import RandomSymbol +from sympy.external import import_module +from sympy.physics.control.lti import TransferFunction, Series, Parallel, \ + Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.units import second, joule +from sympy.polys import (Poly, rootof, RootSum, groebner, ring, field, ZZ, QQ, + ZZ_I, QQ_I, lex, grlex) +from sympy.geometry import Point, Circle, Polygon, Ellipse, Triangle +from sympy.tensor import NDimArray +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement + +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from sympy.printing import sstr, sstrrepr, StrPrinter +from sympy.physics.quantum.trace import Tr + +x, y, z, w, t = symbols('x,y,z,w,t') +d = Dummy('d') + + +def test_printmethod(): + class R(Abs): + def _sympystr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert sstr(R(x)) == "foo(x)" + + class R(Abs): + def _sympystr(self, printer): + return "foo" + assert sstr(R(x)) == "foo" + + +def test_Abs(): + assert str(Abs(x)) == "Abs(x)" + assert str(Abs(Rational(1, 6))) == "1/6" + assert str(Abs(Rational(-1, 6))) == "1/6" + + +def test_Add(): + assert str(x + y) == "x + y" + assert str(x + 1) == "x + 1" + assert str(x + x**2) == "x**2 + x" + assert str(Add(0, 1, evaluate=False)) == "0 + 1" + assert str(Add(0, 0, 1, evaluate=False)) == "0 + 0 + 1" + assert str(1.0*x) == "1.0*x" + assert str(5 + x + y + x*y + x**2 + y**2) == "x**2 + x*y + x + y**2 + y + 5" + assert str(1 + x + x**2/2 + x**3/3) == "x**3/3 + x**2/2 + x + 1" + assert str(2*x - 7*x**2 + 2 + 3*y) == "-7*x**2 + 2*x + 3*y + 2" + assert str(x - y) == "x - y" + assert str(2 - x) == "2 - x" + assert str(x - 2) == "x - 2" + assert str(x - y - z - w) == "-w + x - y - z" + assert str(x - z*y**2*z*w) == "-w*y**2*z**2 + x" + assert str(x - 1*y*x*y) == "-x*y**2 + x" + assert str(sin(x).series(x, 0, 15)) == "x - x**3/6 + x**5/120 - x**7/5040 + x**9/362880 - x**11/39916800 + x**13/6227020800 + O(x**15)" + assert str(Add(Add(-w, x, evaluate=False), Add(-y, z, evaluate=False), evaluate=False)) == "(-w + x) + (-y + z)" + assert str(Add(Add(-x, -y, evaluate=False), -z, evaluate=False)) == "-z + (-x - y)" + assert str(Add(Add(Add(-x, -y, evaluate=False), -z, evaluate=False), -t, evaluate=False)) == "-t + (-z + (-x - y))" + + +def test_Catalan(): + assert str(Catalan) == "Catalan" + + +def test_ComplexInfinity(): + assert str(zoo) == "zoo" + + +def test_Derivative(): + assert str(Derivative(x, y)) == "Derivative(x, y)" + assert str(Derivative(x**2, x, evaluate=False)) == "Derivative(x**2, x)" + assert str(Derivative( + x**2/y, x, y, evaluate=False)) == "Derivative(x**2/y, x, y)" + + +def test_dict(): + assert str({1: 1 + x}) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str({1: x**2, 2: y*x}) in ("{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr({1: x**2, 2: y*x}) == "{1: x**2, 2: x*y}" + + +def test_Dict(): + assert str(Dict({1: 1 + x})) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str(Dict({1: x**2, 2: y*x})) in ( + "{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr(Dict({1: x**2, 2: y*x})) == "{1: x**2, 2: x*y}" + + +def test_Dummy(): + assert str(d) == "_d" + assert str(d + x) == "_d + x" + + +def test_EulerGamma(): + assert str(EulerGamma) == "EulerGamma" + + +def test_Exp(): + assert str(E) == "E" + with _exp_is_pow(True): + assert str(exp(x)) == "E**x" + + +def test_factorial(): + n = Symbol('n', integer=True) + assert str(factorial(-2)) == "zoo" + assert str(factorial(0)) == "1" + assert str(factorial(7)) == "5040" + assert str(factorial(n)) == "factorial(n)" + assert str(factorial(2*n)) == "factorial(2*n)" + assert str(factorial(factorial(n))) == 'factorial(factorial(n))' + assert str(factorial(factorial2(n))) == 'factorial(factorial2(n))' + assert str(factorial2(factorial(n))) == 'factorial2(factorial(n))' + assert str(factorial2(factorial2(n))) == 'factorial2(factorial2(n))' + assert str(subfactorial(3)) == "2" + assert str(subfactorial(n)) == "subfactorial(n)" + assert str(subfactorial(2*n)) == "subfactorial(2*n)" + + +def test_Function(): + f = Function('f') + fx = f(x) + w = WildFunction('w') + assert str(f) == "f" + assert str(fx) == "f(x)" + assert str(w) == "w_" + + +def test_Geometry(): + assert sstr(Point(0, 0)) == 'Point2D(0, 0)' + assert sstr(Circle(Point(0, 0), 3)) == 'Circle(Point2D(0, 0), 3)' + assert sstr(Ellipse(Point(1, 2), 3, 4)) == 'Ellipse(Point2D(1, 2), 3, 4)' + assert sstr(Triangle(Point(1, 1), Point(7, 8), Point(0, -1))) == \ + 'Triangle(Point2D(1, 1), Point2D(7, 8), Point2D(0, -1))' + assert sstr(Polygon(Point(5, 6), Point(-2, -3), Point(0, 0), Point(4, 7))) == \ + 'Polygon(Point2D(5, 6), Point2D(-2, -3), Point2D(0, 0), Point2D(4, 7))' + assert sstr(Triangle(Point(0, 0), Point(1, 0), Point(0, 1)), sympy_integers=True) == \ + 'Triangle(Point2D(S(0), S(0)), Point2D(S(1), S(0)), Point2D(S(0), S(1)))' + assert sstr(Ellipse(Point(1, 2), 3, 4), sympy_integers=True) == \ + 'Ellipse(Point2D(S(1), S(2)), S(3), S(4))' + + +def test_GoldenRatio(): + assert str(GoldenRatio) == "GoldenRatio" + + +def test_Heaviside(): + assert str(Heaviside(x)) == str(Heaviside(x, S.Half)) == "Heaviside(x)" + assert str(Heaviside(x, 1)) == "Heaviside(x, 1)" + + +def test_TribonacciConstant(): + assert str(TribonacciConstant) == "TribonacciConstant" + + +def test_ImaginaryUnit(): + assert str(I) == "I" + + +def test_Infinity(): + assert str(oo) == "oo" + assert str(oo*I) == "oo*I" + + +def test_Integer(): + assert str(Integer(-1)) == "-1" + assert str(Integer(1)) == "1" + assert str(Integer(-3)) == "-3" + assert str(Integer(0)) == "0" + assert str(Integer(25)) == "25" + + +def test_Integral(): + assert str(Integral(sin(x), y)) == "Integral(sin(x), y)" + assert str(Integral(sin(x), (y, 0, 1))) == "Integral(sin(x), (y, 0, 1))" + + +def test_Interval(): + n = (S.NegativeInfinity, 1, 2, S.Infinity) + for i in range(len(n)): + for j in range(i + 1, len(n)): + for l in (True, False): + for r in (True, False): + ival = Interval(n[i], n[j], l, r) + assert S(str(ival)) == ival + + +def test_AccumBounds(): + a = Symbol('a', real=True) + assert str(AccumBounds(0, a)) == "AccumBounds(0, a)" + assert str(AccumBounds(0, 1)) == "AccumBounds(0, 1)" + + +def test_Lambda(): + assert str(Lambda(d, d**2)) == "Lambda(_d, _d**2)" + # issue 2908 + assert str(Lambda((), 1)) == "Lambda((), 1)" + assert str(Lambda((), x)) == "Lambda((), x)" + assert str(Lambda((x, y), x+y)) == "Lambda((x, y), x + y)" + assert str(Lambda(((x, y),), x+y)) == "Lambda(((x, y),), x + y)" + + +def test_Limit(): + assert str(Limit(sin(x)/x, x, y)) == "Limit(sin(x)/x, x, y, dir='+')" + assert str(Limit(1/x, x, 0)) == "Limit(1/x, x, 0, dir='+')" + assert str( + Limit(sin(x)/x, x, y, dir="-")) == "Limit(sin(x)/x, x, y, dir='-')" + + +def test_list(): + assert str([x]) == sstr([x]) == "[x]" + assert str([x**2, x*y + 1]) == sstr([x**2, x*y + 1]) == "[x**2, x*y + 1]" + assert str([x**2, [y + x]]) == sstr([x**2, [y + x]]) == "[x**2, [x + y]]" + + +def test_Matrix_str(): + M = Matrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + M = Matrix([[1]]) + assert str(M) == sstr(M) == "Matrix([[1]])" + M = Matrix([[1, 2]]) + assert str(M) == sstr(M) == "Matrix([[1, 2]])" + M = Matrix() + assert str(M) == sstr(M) == "Matrix(0, 0, [])" + M = Matrix(0, 1, lambda i, j: 0) + assert str(M) == sstr(M) == "Matrix(0, 1, [])" + + +def test_Mul(): + assert str(x/y) == "x/y" + assert str(y/x) == "y/x" + assert str(x/y/z) == "x/(y*z)" + assert str((x + 1)/(y + 2)) == "(x + 1)/(y + 2)" + assert str(2*x/3) == '2*x/3' + assert str(-2*x/3) == '-2*x/3' + assert str(-1.0*x) == '-1.0*x' + assert str(1.0*x) == '1.0*x' + assert str(Mul(0, 1, evaluate=False)) == '0*1' + assert str(Mul(1, 0, evaluate=False)) == '1*0' + assert str(Mul(1, 1, evaluate=False)) == '1*1' + assert str(Mul(1, 1, 1, evaluate=False)) == '1*1*1' + assert str(Mul(1, 2, evaluate=False)) == '1*2' + assert str(Mul(1, S.Half, evaluate=False)) == '1*(1/2)' + assert str(Mul(1, 1, S.Half, evaluate=False)) == '1*1*(1/2)' + assert str(Mul(1, 1, 2, 3, x, evaluate=False)) == '1*1*2*3*x' + assert str(Mul(1, -1, evaluate=False)) == '1*(-1)' + assert str(Mul(-1, 1, evaluate=False)) == '-1*1' + assert str(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == '4*3*2*1*0*y*x' + assert str(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == '4*3*2*(z + 1)*0*y*x' + assert str(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == '(2/3)*(5/7)' + # For issue 14160 + assert str(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + # issue 21537 + assert str(Mul(x, Pow(1/y, -1, evaluate=False), evaluate=False)) == 'x/(1/y)' + + # Issue 24108 + from sympy.core.parameters import evaluate + with evaluate(False): + assert str(Mul(Pow(Integer(2), Integer(-1)), Add(Integer(-1), Mul(Integer(-1), Integer(1))))) == "(-1 - 1*1)/2" + + class CustomClass1(Expr): + is_commutative = True + + class CustomClass2(Expr): + is_commutative = True + cc1 = CustomClass1() + cc2 = CustomClass2() + assert str(Rational(2)*cc1) == '2*CustomClass1()' + assert str(cc1*Rational(2)) == '2*CustomClass1()' + assert str(cc1*Float("1.5")) == '1.5*CustomClass1()' + assert str(cc2*Rational(2)) == '2*CustomClass2()' + assert str(cc2*Rational(2)*cc1) == '2*CustomClass1()*CustomClass2()' + assert str(cc1*Rational(2)*cc2) == '2*CustomClass1()*CustomClass2()' + + +def test_NaN(): + assert str(nan) == "nan" + + +def test_NegativeInfinity(): + assert str(-oo) == "-oo" + +def test_Order(): + assert str(O(x)) == "O(x)" + assert str(O(x**2)) == "O(x**2)" + assert str(O(x*y)) == "O(x*y, x, y)" + assert str(O(x, x)) == "O(x)" + assert str(O(x, (x, 0))) == "O(x)" + assert str(O(x, (x, oo))) == "O(x, (x, oo))" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, (x, oo), (y, oo))) == "O(x, (x, oo), (y, oo))" + + +def test_Permutation_Cycle(): + from sympy.combinatorics import Permutation, Cycle + + # general principle: economically, canonically show all moved elements + # and the size of the permutation. + + for p, s in [ + (Cycle(), + '()'), + (Cycle(2), + '(2)'), + (Cycle(2, 1), + '(1 2)'), + (Cycle(1, 2)(5)(6, 7)(10), + '(1 2)(6 7)(10)'), + (Cycle(3, 4)(1, 2)(3, 4), + '(1 2)(4)'), + ]: + assert sstr(p) == s + + for p, s in [ + (Permutation([]), + 'Permutation([])'), + (Permutation([], size=1), + 'Permutation([0])'), + (Permutation([], size=2), + 'Permutation([0, 1])'), + (Permutation([], size=10), + 'Permutation([], size=10)'), + (Permutation([1, 0, 2]), + 'Permutation([1, 0, 2])'), + (Permutation([1, 0, 2, 3, 4, 5]), + 'Permutation([1, 0], size=6)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + 'Permutation([1, 0], size=10)'), + ]: + assert sstr(p, perm_cyclic=False) == s + + for p, s in [ + (Permutation([]), + '()'), + (Permutation([], size=1), + '(0)'), + (Permutation([], size=2), + '(1)'), + (Permutation([], size=10), + '(9)'), + (Permutation([1, 0, 2]), + '(2)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5]), + '(5)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + '(9)(0 1)'), + (Permutation([0, 1, 3, 2, 4, 5], size=10), + '(9)(2 3)'), + ]: + assert sstr(p) == s + + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert sstr(Permutation([1, 0, 2])) == 'Permutation([1, 0, 2])' + Permutation.print_cyclic = old_print_cyclic + +def test_Pi(): + assert str(pi) == "pi" + + +def test_Poly(): + assert str(Poly(0, x)) == "Poly(0, x, domain='ZZ')" + assert str(Poly(1, x)) == "Poly(1, x, domain='ZZ')" + assert str(Poly(x, x)) == "Poly(x, x, domain='ZZ')" + + assert str(Poly(2*x + 1, x)) == "Poly(2*x + 1, x, domain='ZZ')" + assert str(Poly(2*x - 1, x)) == "Poly(2*x - 1, x, domain='ZZ')" + + assert str(Poly(-1, x)) == "Poly(-1, x, domain='ZZ')" + assert str(Poly(-x, x)) == "Poly(-x, x, domain='ZZ')" + + assert str(Poly(-2*x + 1, x)) == "Poly(-2*x + 1, x, domain='ZZ')" + assert str(Poly(-2*x - 1, x)) == "Poly(-2*x - 1, x, domain='ZZ')" + + assert str(Poly(x - 1, x)) == "Poly(x - 1, x, domain='ZZ')" + assert str(Poly(2*x + x**5, x)) == "Poly(x**5 + 2*x, x, domain='ZZ')" + + assert str(Poly(3**(2*x), 3**x)) == "Poly((3**x)**2, 3**x, domain='ZZ')" + assert str(Poly((x**2)**x)) == "Poly(((x**2)**x), (x**2)**x, domain='ZZ')" + + assert str(Poly((x + y)**3, (x + y), expand=False) + ) == "Poly((x + y)**3, x + y, domain='ZZ')" + assert str(Poly((x - 1)**2, (x - 1), expand=False) + ) == "Poly((x - 1)**2, x - 1, domain='ZZ')" + + assert str( + Poly(x**2 + 1 + y, x)) == "Poly(x**2 + y + 1, x, domain='ZZ[y]')" + assert str( + Poly(x**2 - 1 + y, x)) == "Poly(x**2 + y - 1, x, domain='ZZ[y]')" + + assert str(Poly(x**2 + I*x, x)) == "Poly(x**2 + I*x, x, domain='ZZ_I')" + assert str(Poly(x**2 - I*x, x)) == "Poly(x**2 - I*x, x, domain='ZZ_I')" + + assert str(Poly(-x*y*z + x*y - 1, x, y, z) + ) == "Poly(-x*y*z + x*y - 1, x, y, z, domain='ZZ')" + assert str(Poly(-w*x**21*y**7*z + (1 + w)*z**3 - 2*x*z + 1, x, y, z)) == \ + "Poly(-w*x**21*y**7*z - 2*x*z + (w + 1)*z**3 + 1, x, y, z, domain='ZZ[w]')" + + assert str(Poly(x**2 + 1, x, modulus=2)) == "Poly(x**2 + 1, x, modulus=2)" + assert str(Poly(2*x**2 + 3*x + 4, x, modulus=17)) == "Poly(2*x**2 + 3*x + 4, x, modulus=17)" + + +def test_PolyRing(): + assert str(ring("x", ZZ, lex)[0]) == "Polynomial ring in x over ZZ with lex order" + assert str(ring("x,y", QQ, grlex)[0]) == "Polynomial ring in x, y over QQ with grlex order" + assert str(ring("x,y,z", ZZ["t"], lex)[0]) == "Polynomial ring in x, y, z over ZZ[t] with lex order" + + +def test_FracField(): + assert str(field("x", ZZ, lex)[0]) == "Rational function field in x over ZZ with lex order" + assert str(field("x,y", QQ, grlex)[0]) == "Rational function field in x, y over QQ with grlex order" + assert str(field("x,y,z", ZZ["t"], lex)[0]) == "Rational function field in x, y, z over ZZ[t] with lex order" + + +def test_PolyElement(): + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + Rx_zzi, xz = ring("x", ZZ_I) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + assert str(x**2) == "x**2" + + assert str((u**2 + 3*u*v + 1)*x**2*y + u + 1) == "(u**2 + 3*u*v + 1)*x**2*y + u + 1" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1" + assert str((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == "-(u**2 - 3*u*v + 1)*x**2*y - (u + 1)*x - 1" + + assert str(-(v**2 + v + 1)*x + 3*u*v + 1) == "-(v**2 + v + 1)*x + 3*u*v + 1" + assert str(-(v**2 + v + 1)*x - 3*u*v + 1) == "-(v**2 + v + 1)*x - 3*u*v + 1" + + assert str((1+I)*xz + 2) == "(1 + 1*I)*x + (2 + 0*I)" + + +def test_FracElement(): + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + Rx_zzi, xz = field("x", QQ_I) + i = QQ_I(0, 1) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + + assert str(x/3) == "x/3" + assert str(x/z) == "x/z" + assert str(x*y/z) == "x*y/z" + assert str(x/(z*t)) == "x/(z*t)" + assert str(x*y/(z*t)) == "x*y/(z*t)" + + assert str((x - 1)/y) == "(x - 1)/y" + assert str((x + 1)/y) == "(x + 1)/y" + assert str((-x - 1)/y) == "(-x - 1)/y" + assert str((x + 1)/(y*z)) == "(x + 1)/(y*z)" + assert str(-y/(x + 1)) == "-y/(x + 1)" + assert str(y*z/(x + 1)) == "y*z/(x + 1)" + + assert str(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - 1)" + assert str(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - u*v*t - 1)" + + assert str((1+i)/xz) == "(1 + 1*I)/x" + assert str(((1+i)*xz - i)/xz) == "((1 + 1*I)*x + (0 + -1*I))/x" + + +def test_GaussianInteger(): + assert str(ZZ_I(1, 0)) == "1" + assert str(ZZ_I(-1, 0)) == "-1" + assert str(ZZ_I(0, 1)) == "I" + assert str(ZZ_I(0, -1)) == "-I" + assert str(ZZ_I(0, 2)) == "2*I" + assert str(ZZ_I(0, -2)) == "-2*I" + assert str(ZZ_I(1, 1)) == "1 + I" + assert str(ZZ_I(-1, -1)) == "-1 - I" + assert str(ZZ_I(-1, -2)) == "-1 - 2*I" + + +def test_GaussianRational(): + assert str(QQ_I(1, 0)) == "1" + assert str(QQ_I(QQ(2, 3), 0)) == "2/3" + assert str(QQ_I(0, QQ(2, 3))) == "2*I/3" + assert str(QQ_I(QQ(1, 2), QQ(-2, 3))) == "1/2 - 2*I/3" + + +def test_Pow(): + assert str(x**-1) == "1/x" + assert str(x**-2) == "x**(-2)" + assert str(x**2) == "x**2" + assert str((x + y)**-1) == "1/(x + y)" + assert str((x + y)**-2) == "(x + y)**(-2)" + assert str((x + y)**2) == "(x + y)**2" + assert str((x + y)**(1 + x)) == "(x + y)**(x + 1)" + assert str(x**Rational(1, 3)) == "x**(1/3)" + assert str(1/x**Rational(1, 3)) == "x**(-1/3)" + assert str(sqrt(sqrt(x))) == "x**(1/4)" + # not the same as x**-1 + assert str(x**-1.0) == 'x**(-1.0)' + # see issue #2860 + assert str(Pow(S(2), -1.0, evaluate=False)) == '2**(-1.0)' + + +def test_sqrt(): + assert str(sqrt(x)) == "sqrt(x)" + assert str(sqrt(x**2)) == "sqrt(x**2)" + assert str(1/sqrt(x)) == "1/sqrt(x)" + assert str(1/sqrt(x**2)) == "1/sqrt(x**2)" + assert str(y/sqrt(x)) == "y/sqrt(x)" + assert str(x**0.5) == "x**0.5" + assert str(1/x**0.5) == "x**(-0.5)" + + +def test_Rational(): + n1 = Rational(1, 4) + n2 = Rational(1, 3) + n3 = Rational(2, 4) + n4 = Rational(2, -4) + n5 = Rational(0) + n7 = Rational(3) + n8 = Rational(-3) + assert str(n1*n2) == "1/12" + assert str(n1*n2) == "1/12" + assert str(n3) == "1/2" + assert str(n1*n3) == "1/8" + assert str(n1 + n3) == "3/4" + assert str(n1 + n2) == "7/12" + assert str(n1 + n4) == "-1/4" + assert str(n4*n4) == "1/4" + assert str(n4 + n2) == "-1/6" + assert str(n4 + n5) == "-1/2" + assert str(n4*n5) == "0" + assert str(n3 + n4) == "0" + assert str(n1**n7) == "1/64" + assert str(n2**n7) == "1/27" + assert str(n2**n8) == "27" + assert str(n7**n8) == "1/27" + assert str(Rational("-25")) == "-25" + assert str(Rational("1.25")) == "5/4" + assert str(Rational("-2.6e-2")) == "-13/500" + assert str(S("25/7")) == "25/7" + assert str(S("-123/569")) == "-123/569" + assert str(S("0.1[23]", rational=1)) == "61/495" + assert str(S("5.1[666]", rational=1)) == "31/6" + assert str(S("-5.1[666]", rational=1)) == "-31/6" + assert str(S("0.[9]", rational=1)) == "1" + assert str(S("-0.[9]", rational=1)) == "-1" + + assert str(sqrt(Rational(1, 4))) == "1/2" + assert str(sqrt(Rational(1, 36))) == "1/6" + + assert str((123**25) ** Rational(1, 25)) == "123" + assert str((123**25 + 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "122" + + assert str(sqrt(Rational(81, 36))**3) == "27/8" + assert str(1/sqrt(Rational(81, 36))**3) == "8/27" + + assert str(sqrt(-4)) == str(2*I) + assert str(2**Rational(1, 10**10)) == "2**(1/10000000000)" + + assert sstr(Rational(2, 3), sympy_integers=True) == "S(2)/3" + x = Symbol("x") + assert sstr(x**Rational(2, 3), sympy_integers=True) == "x**(S(2)/3)" + assert sstr(Eq(x, Rational(2, 3)), sympy_integers=True) == "Eq(x, S(2)/3)" + assert sstr(Limit(x, x, Rational(7, 2)), sympy_integers=True) == \ + "Limit(x, x, S(7)/2, dir='+')" + + +def test_Float(): + # NOTE dps is the whole number of decimal digits + assert str(Float('1.23', dps=1 + 2)) == '1.23' + assert str(Float('1.23456789', dps=1 + 8)) == '1.23456789' + assert str( + Float('1.234567890123456789', dps=1 + 18)) == '1.234567890123456789' + assert str(pi.evalf(1 + 2)) == '3.14' + assert str(pi.evalf(1 + 14)) == '3.14159265358979' + assert str(pi.evalf(1 + 64)) == ('3.141592653589793238462643383279' + '5028841971693993751058209749445923') + assert str(pi.round(-1)) == '0.0' + assert str((pi**400 - (pi**400).round(1)).n(2)) == '-0.e+88' + assert sstr(Float("100"), full_prec=False, min=-2, max=2) == '1.0e+2' + assert sstr(Float("100"), full_prec=False, min=-2, max=3) == '100.0' + assert sstr(Float("0.1"), full_prec=False, min=-2, max=3) == '0.1' + assert sstr(Float("0.099"), min=-2, max=3) == '9.90000000000000e-2' + + +def test_Relational(): + assert str(Rel(x, y, "<")) == "x < y" + assert str(Rel(x + y, y, "==")) == "Eq(x + y, y)" + assert str(Rel(x, y, "!=")) == "Ne(x, y)" + assert str(Eq(x, 1) | Eq(x, 2)) == "Eq(x, 1) | Eq(x, 2)" + assert str(Ne(x, 1) & Ne(x, 2)) == "Ne(x, 1) & Ne(x, 2)" + + +def test_AppliedBinaryRelation(): + assert str(Q.eq(x, y)) == "Q.eq(x, y)" + assert str(Q.ne(x, y)) == "Q.ne(x, y)" + + +def test_CRootOf(): + assert str(rootof(x**5 + 2*x - 1, 0)) == "CRootOf(x**5 + 2*x - 1, 0)" + + +def test_RootSum(): + f = x**5 + 2*x - 1 + + assert str( + RootSum(f, Lambda(z, z), auto=False)) == "RootSum(x**5 + 2*x - 1)" + assert str(RootSum(f, Lambda( + z, z**2), auto=False)) == "RootSum(x**5 + 2*x - 1, Lambda(z, z**2))" + + +def test_GroebnerBasis(): + assert str(groebner( + [], x, y)) == "GroebnerBasis([], x, y, domain='ZZ', order='lex')" + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + + assert str(groebner(F, order='grlex')) == \ + "GroebnerBasis([x**2 - x - 3*y + 1, y**2 - 2*x + y - 1], x, y, domain='ZZ', order='grlex')" + assert str(groebner(F, order='lex')) == \ + "GroebnerBasis([2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7], x, y, domain='ZZ', order='lex')" + +def test_set(): + assert sstr(set()) == 'set()' + assert sstr(frozenset()) == 'frozenset()' + + assert sstr({1}) == '{1}' + assert sstr(frozenset([1])) == 'frozenset({1})' + assert sstr({1, 2, 3}) == '{1, 2, 3}' + assert sstr(frozenset([1, 2, 3])) == 'frozenset({1, 2, 3})' + + assert sstr( + {1, x, x**2, x**3, x**4}) == '{1, x, x**2, x**3, x**4}' + assert sstr( + frozenset([1, x, x**2, x**3, x**4])) == 'frozenset({1, x, x**2, x**3, x**4})' + + +def test_SparseMatrix(): + M = SparseMatrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + + +def test_Sum(): + assert str(summation(cos(3*z), (z, x, y))) == "Sum(cos(3*z), (z, x, y))" + assert str(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + "Sum(x*y**2, (x, -2, 2), (y, -5, 5))" + + +def test_Symbol(): + assert str(y) == "y" + assert str(x) == "x" + e = x + assert str(e) == "x" + + +def test_tuple(): + assert str((x,)) == sstr((x,)) == "(x,)" + assert str((x + y, 1 + x)) == sstr((x + y, 1 + x)) == "(x + y, x + 1)" + assert str((x + y, ( + 1 + x, x**2))) == sstr((x + y, (1 + x, x**2))) == "(x + y, (x + 1, x**2))" + + +def test_Series_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Series(tf1, tf2)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Series(tf1, tf2, tf3)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Series(-tf2, tf1)) == \ + "Series(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOSeries_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOSeries(tfm_1, tfm_2)) == \ + "MIMOSeries(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_TransferFunction_str(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert str(tf1) == "TransferFunction(x - 1, x + 1, x)" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert str(tf2) == "TransferFunction(x + 1, 2 - y, x)" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert str(tf3) == "TransferFunction(y, y**2 + 2*y + 3, y)" + + +def test_Parallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Parallel(tf1, tf2)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Parallel(tf1, tf2, tf3)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Parallel(-tf2, tf1)) == \ + "Parallel(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOParallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOParallel(tfm_1, tfm_2)) == \ + "MIMOParallel(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_Feedback_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Feedback(tf1*tf2, tf3)) == \ + "Feedback(Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), " \ + "TransferFunction(t*x**2 - t**w*x + w, t - y, y), -1)" + assert str(Feedback(tf1, TransferFunction(1, 1, y), 1)) == \ + "Feedback(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(1, 1, y), 1)" + + +def test_MIMOFeedback_str(): + tf1 = TransferFunction(x**2 - y**3, y - z, x) + tf2 = TransferFunction(y - x, z + y, x) + tfm_1 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_2 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + assert (str(MIMOFeedback(tfm_1, tfm_2)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x))," \ + " (TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), " \ + "TransferFunction(-x + y, y + z, x)), (TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), -1)") + assert (str(MIMOFeedback(tfm_1, tfm_2, 1)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)), " \ + "(TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)), "\ + "(TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), 1)") + + +def test_TransferFunctionMatrix_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(TransferFunctionMatrix([[tf1], [tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y),), (TransferFunction(x - y, x + y, y),)))" + assert str(TransferFunctionMatrix([[tf1, tf2], [tf3, tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), (TransferFunction(t*x**2 - t**w*x + w, t - y, y), TransferFunction(x - y, x + y, y))))" + + +def test_Quaternion_str_printer(): + q = Quaternion(x, y, z, t) + assert str(q) == "x + y*i + z*j + t*k" + q = Quaternion(x,y,z,x*t) + assert str(q) == "x + y*i + z*j + t*x*k" + q = Quaternion(x,y,z,x+t) + assert str(q) == "x + y*i + z*j + (t + x)*k" + + +def test_Quantity_str(): + assert sstr(second, abbrev=True) == "s" + assert sstr(joule, abbrev=True) == "J" + assert str(second) == "second" + assert str(joule) == "joule" + + +def test_wild_str(): + # Check expressions containing Wild not causing infinite recursion + w = Wild('x') + assert str(w + 1) == 'x_ + 1' + assert str(exp(2**w) + 5) == 'exp(2**x_) + 5' + assert str(3*w + 1) == '3*x_ + 1' + assert str(1/w + 1) == '1 + 1/x_' + assert str(w**2 + 1) == 'x_**2 + 1' + assert str(1/(1 - w)) == '1/(1 - x_)' + + +def test_wild_matchpy(): + from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar + + matchpy = import_module("matchpy") + + if matchpy is None: + return + + wd = WildDot('w_') + wp = WildPlus('w__') + ws = WildStar('w___') + + assert str(wd) == 'w_' + assert str(wp) == 'w__' + assert str(ws) == 'w___' + + assert str(wp/ws + 2**wd) == '2**w_ + w__/w___' + assert str(sin(wd)*cos(wp)*sqrt(ws)) == 'sqrt(w___)*sin(w_)*cos(w__)' + + +def test_zeta(): + assert str(zeta(3)) == "zeta(3)" + + +def test_issue_3101(): + e = x - y + a = str(e) + b = str(e) + assert a == b + + +def test_issue_3103(): + e = -2*sqrt(x) - y/sqrt(x)/2 + assert str(e) not in ["(-2)*x**1/2(-1/2)*x**(-1/2)*y", + "-2*x**1/2(-1/2)*x**(-1/2)*y", "-2*x**1/2-1/2*x**-1/2*w"] + assert str(e) == "-2*sqrt(x) - y/(2*sqrt(x))" + + +def test_issue_4021(): + e = Integral(x, x) + 1 + assert str(e) == 'Integral(x, x) + 1' + + +def test_sstrrepr(): + assert sstr('abc') == 'abc' + assert sstrrepr('abc') == "'abc'" + + e = ['a', 'b', 'c', x] + assert sstr(e) == "[a, b, c, x]" + assert sstrrepr(e) == "['a', 'b', 'c', x]" + + +def test_infinity(): + assert sstr(oo*I) == "oo*I" + + +def test_full_prec(): + assert sstr(S("0.3"), full_prec=True) == "0.300000000000000" + assert sstr(S("0.3"), full_prec="auto") == "0.300000000000000" + assert sstr(S("0.3"), full_prec=False) == "0.3" + assert sstr(S("0.3")*x, full_prec=True) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert sstr(S("0.3")*x, full_prec="auto") in [ + "0.3*x", + "x*0.3" + ] + assert sstr(S("0.3")*x, full_prec=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert sstr(A*B*C**-1) == "A*B*C**(-1)" + assert sstr(C**-1*A*B) == "C**(-1)*A*B" + assert sstr(A*C**-1*B) == "A*C**(-1)*B" + assert sstr(sqrt(A)) == "sqrt(A)" + assert sstr(1/sqrt(A)) == "A**(-1/2)" + + +def test_empty_printer(): + str_printer = StrPrinter() + assert str_printer.emptyPrinter("foo") == "foo" + assert str_printer.emptyPrinter(x*y) == "x*y" + assert str_printer.emptyPrinter(32) == "32" + +def test_decimal_printer(): + dec_printer = StrPrinter(settings={"dps":3}) + f = Function('f') + assert dec_printer.doprint(f(1.329294)) == "f(1.33)" + + +def test_settings(): + raises(TypeError, lambda: sstr(S(4), method="garbage")) + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert str(where(X > 0)) == "Domain: (0 < x1) & (x1 < oo)" + + D = Die('d1', 6) + assert str(where(D > 4)) == "Domain: Eq(d1, 5) | Eq(d1, 6)" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert str(pspace(Tuple(A, B)).domain) == "Domain: (0 <= a) & (0 <= b) & (a < oo) & (b < oo)" + + +def test_FiniteSet(): + assert str(FiniteSet(*range(1, 51))) == ( + '{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,' + ' 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,' + ' 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}' + ) + assert str(FiniteSet(*range(1, 6))) == '{1, 2, 3, 4, 5}' + assert str(FiniteSet(*[x*y, x**2])) == '{x**2, x*y}' + assert str(FiniteSet(FiniteSet(FiniteSet(x, y), 5), FiniteSet(x,y), 5) + ) == 'FiniteSet(5, FiniteSet(5, {x, y}), {x, y})' + + +def test_Partition(): + assert str(Partition(FiniteSet(x, y), {z})) == 'Partition({z}, {x, y})' + +def test_UniversalSet(): + assert str(S.UniversalSet) == 'UniversalSet' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ[x, y] + assert sstr(F.convert(x/(x + y))) == sstr(x/(x + y)) + assert sstr(R.convert(x + y)) == sstr(x + y) + + +def test_categories(): + from sympy.categories import (Object, NamedMorphism, + IdentityMorphism, Category) + + A = Object("A") + B = Object("B") + + f = NamedMorphism(A, B, "f") + id_A = IdentityMorphism(A) + + K = Category("K") + + assert str(A) == 'Object("A")' + assert str(f) == 'NamedMorphism(Object("A"), Object("B"), "f")' + assert str(id_A) == 'IdentityMorphism(Object("A"))' + + assert str(K) == 'Category("K")' + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert str(t) == 'Tr(A*B)' + + +def test_issue_6387(): + assert str(factor(-3.0*z + 3)) == '-3.0*(1.0*z - 1.0)' + + +def test_MatMul_MatAdd(): + X, Y = MatrixSymbol("X", 2, 2), MatrixSymbol("Y", 2, 2) + assert str(2*(X + Y)) == "2*X + 2*Y" + + assert str(I*X) == "I*X" + assert str(-I*X) == "-I*X" + assert str((1 + I)*X) == '(1 + I)*X' + assert str(-(1 + I)*X) == '(-1 - I)*X' + assert str(MatAdd(MatAdd(X, Y), MatAdd(X, Y))) == '(X + Y) + (X + Y)' + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert str(MatrixSlice(X, (None, None, None), (None, None, None))) == 'X[:, :]' + assert str(X[x:x + 1, y:y + 1]) == 'X[x:x + 1, y:y + 1]' + assert str(X[x:x + 1:2, y:y + 1:2]) == 'X[x:x + 1:2, y:y + 1:2]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[x:, :y]) == 'X[x:, :y]' + assert str(X[x:y, z:w]) == 'X[x:y, z:w]' + assert str(X[x:y:t, w:t:x]) == 'X[x:y:t, w:t:x]' + assert str(X[x::y, t::w]) == 'X[x::y, t::w]' + assert str(X[:x:y, :t:w]) == 'X[:x:y, :t:w]' + assert str(X[::x, ::y]) == 'X[::x, ::y]' + assert str(MatrixSlice(X, (0, None, None), (0, None, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (None, n, None), (None, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, None), (0, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, 2), (0, n, 2))) == 'X[::2, ::2]' + assert str(X[1:2:3, 4:5:6]) == 'X[1:2:3, 4:5:6]' + assert str(X[1:3:5, 4:6:8]) == 'X[1:3:5, 4:6:8]' + assert str(X[1:10:2]) == 'X[1:10:2, :]' + assert str(Y[:5, 1:9:2]) == 'Y[:5, 1:9:2]' + assert str(Y[:5, 1:10:2]) == 'Y[:5, 1::2]' + assert str(Y[5, :5:2]) == 'Y[5:6, :5:2]' + assert str(X[0:1, 0:1]) == 'X[:1, :1]' + assert str(X[0:1:2, 0:1:2]) == 'X[:1:2, :1:2]' + assert str((Y + Z)[2:, 2:]) == '(Y + Z)[2:, 2:]' + +def test_true_false(): + assert str(true) == repr(true) == sstr(true) == "True" + assert str(false) == repr(false) == sstr(false) == "False" + +def test_Equivalent(): + assert str(Equivalent(y, x)) == "Equivalent(x, y)" + +def test_Xor(): + assert str(Xor(y, x, evaluate=False)) == "x ^ y" + +def test_Complement(): + assert str(Complement(S.Reals, S.Naturals)) == 'Complement(Reals, Naturals)' + +def test_SymmetricDifference(): + assert str(SymmetricDifference(Interval(2, 3), Interval(3, 4),evaluate=False)) == \ + 'SymmetricDifference(Interval(2, 3), Interval(3, 4))' + + +def test_UnevaluatedExpr(): + a, b = symbols("a b") + expr1 = 2*UnevaluatedExpr(a+b) + assert str(expr1) == "2*(a + b)" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(str(A[0, 0]) == "A[0, 0]") + assert(str(3 * A[0, 0]) == "3*A[0, 0]") + + F = C[0, 0].subs(C, A - B) + assert str(F) == "(A - B)[0, 0]" + + +def test_MatrixSymbol_printing(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + assert str(A - A*B - B) == "A - A*B - B" + assert str(A*B - (A+B)) == "-A + A*B - B" + assert str(A**(-1)) == "A**(-1)" + assert str(A**3) == "A**3" + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert str(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + assert str(expr) == 'Lambda(_d, sin(_d)).(X.T*X)' + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + assert str(expr) == 'Lambda(x, 1/x).(n*X)' + + +def test_Subs_printing(): + assert str(Subs(x, (x,), (1,))) == 'Subs(x, x, 1)' + assert str(Subs(x + y, (x, y), (1, 2))) == 'Subs(x + y, (x, y), (1, 2))' + + +def test_issue_15716(): + e = Integral(factorial(x), (x, -oo, oo)) + assert e.as_terms() == ([(e, ((1.0, 0.0), (1,), ()))], [e]) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert str(Identity(4)) == 'I' + assert str(ZeroMatrix(2, 2)) == '0' + assert str(OneMatrix(2, 2)) == '1' + + +def test_issue_14567(): + assert factorial(Sum(-1, (x, 0, 0))) + y # doesn't raise an error + + +def test_issue_21823(): + assert str(Partition([1, 2])) == 'Partition({1, 2})' + assert str(Partition({1, 2})) == 'Partition({1, 2})' + + +def test_issue_22689(): + assert str(Mul(Pow(x,-2, evaluate=False), Pow(3,-1,evaluate=False), evaluate=False)) == "1/(x**2*3)" + + +def test_issue_21119_21460(): + ss = lambda x: str(S(x, evaluate=False)) + assert ss('4/2') == '4/2' + assert ss('4/-2') == '4/(-2)' + assert ss('-4/2') == '-4/2' + assert ss('-4/-2') == '-4/(-2)' + assert ss('-2*3/-1') == '-2*3/(-1)' + assert ss('-2*3/-1/2') == '-2*3/(-1*2)' + assert ss('4/2/1') == '4/(2*1)' + assert ss('-2/-1/2') == '-2/(-1*2)' + assert ss('2*3*4**(-2*3)') == '2*3/4**(2*3)' + assert ss('2*3*1*4**(-2*3)') == '2*3*1/4**(2*3)' + + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == 'x' + assert sstrrepr(Str('x')) == "Str('x')" + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert str(m) == "M" + p = Patch('P', m) + assert str(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert str(rect) == "rect" + b = BaseScalarField(rect, 0) + assert str(b) == "x" + +def test_NDimArray(): + assert sstr(NDimArray(1.0), full_prec=True) == '1.00000000000000' + assert sstr(NDimArray(1.0), full_prec=False) == '1.0' + assert sstr(NDimArray([1.0, 2.0]), full_prec=True) == '[1.00000000000000, 2.00000000000000]' + assert sstr(NDimArray([1.0, 2.0]), full_prec=False) == '[1.0, 2.0]' + assert sstr(NDimArray([], (0,))) == 'ImmutableDenseNDimArray([], (0,))' + assert sstr(NDimArray([], (0, 0))) == 'ImmutableDenseNDimArray([], (0, 0))' + assert sstr(NDimArray([], (0, 1))) == 'ImmutableDenseNDimArray([], (0, 1))' + assert sstr(NDimArray([], (1, 0))) == 'ImmutableDenseNDimArray([], (1, 0))' + +def test_Predicate(): + assert sstr(Q.even) == 'Q.even' + +def test_AppliedPredicate(): + assert sstr(Q.even(x)) == 'Q.even(x)' + +def test_printing_str_array_expressions(): + assert sstr(ArraySymbol("A", (2, 3, 4))) == "A" + assert sstr(ArrayElement("A", (2, 1/(1-x), 0))) == "A[2, 1/(1 - x), 0]" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert sstr(ArrayElement(M*N, [x, 0])) == "(M*N)[x, 0]" + +def test_printing_stats(): + # issue 24132 + x = RandomSymbol("x") + y = RandomSymbol("y") + z1 = Probability(x > 0)*Identity(2) + z2 = Expectation(x)*Identity(2) + z3 = Variance(x)*Identity(2) + z4 = Covariance(x, y) * Identity(2) + + assert str(z1) == "Probability(x > 0)*I" + assert str(z2) == "Expectation(x)*I" + assert str(z3) == "Variance(x)*I" + assert str(z4) == "Covariance(x, y)*I" + assert z1.is_commutative == False + assert z2.is_commutative == False + assert z3.is_commutative == False + assert z4.is_commutative == False + assert z2._eval_is_commutative() == False + assert z3._eval_is_commutative() == False + assert z4._eval_is_commutative() == False diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_tableform.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_tableform.py new file mode 100644 index 0000000000000000000000000000000000000000..05802dd104a12f2f53d137167ecf31d201ff8dfc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_tableform.py @@ -0,0 +1,182 @@ +from sympy.core.singleton import S +from sympy.printing.tableform import TableForm +from sympy.printing.latex import latex +from sympy.abc import x +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.testing.pytest import raises + +from textwrap import dedent + + +def test_TableForm(): + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic")) + assert s == ( + ' | 1 2\n' + '-------\n' + '1 | a b\n' + '2 | c d\n' + '3 | e ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic", wipe_zeros=False)) + assert s == dedent('''\ + | 1 2 + ------- + 1 | a b + 2 | c d + 3 | e 0''') + s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]], + headings=("automatic", None))) + assert s == ( + '1 | x**2 b \n' + '2 | c x**2\n' + '3 | e f ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]], + headings=(None, "automatic"))) + assert s == dedent('''\ + 1 2 + --- + a b + c d + e f''') + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]])) + assert s == ( + ' | y1 y2\n' + '---------------\n' + 'Group A | 5 7 \n' + 'Group B | 4 2 \n' + 'Group C | 10 3 ' + ) + raises( + ValueError, + lambda: + TableForm( + [[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="middle") + ) + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="right")) + assert s == dedent('''\ + | y1 y2 + --------------- + Group A | 5 7 + Group B | 4 2 + Group C | 10 3''') + + # other alignment permutations + d = [[1, 100], [100, 1]] + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l') + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + + s = TableForm(d, headings=(('xxx', 'x'), None)) + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + + raises(ValueError, lambda: TableForm(d, alignments='clr')) + + #pad + s = str(TableForm([[None, "-", 2], [1]], pad='?')) + assert s == dedent('''\ + ? - 2 + 1 ? ?''') + + +def test_TableForm_latex(): + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l')) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3)) + assert s == ( + '\\begin{tabular}{l l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & $a$ & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + formats=['(%s)', None], headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & (a) & $x^{3}$ \\\\\n' + '2 & (c) & $\\frac{1}{4}$ \\\\\n' + '3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + + def neg_in_paren(x, i, j): + if i % 2: + return ('(%s)' if x < 0 else '%s') % x + else: + pass # use default print + s = latex(TableForm([[-1, 2], [-3, 4]], + formats=[neg_in_paren]*2, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & -1 & 2 \\\\\n' + '2 & (-3) & 4 \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]])) + assert s == ( + '\\begin{tabular}{l l}\n' + '$a$ & $x^{3}$ \\\\\n' + '$c$ & $\\frac{1}{4}$ \\\\\n' + '$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_tensorflow.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c92cd17b13e1148ebf83f13f66854b983491fe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_tensorflow.py @@ -0,0 +1,493 @@ +import random +from sympy.core.function import Derivative +from sympy.core.symbol import symbols +from sympy import Piecewise +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.external import import_module +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, log +from sympy.codegen.cfunctions import isnan, isinf +from sympy.matrices import Matrix, MatrixBase, eye, randMatrix +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace +from sympy.printing.tensorflow import tensorflow_code +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import skip +from sympy.testing.pytest import XFAIL + + +tf = tensorflow = import_module("tensorflow") + +if tensorflow: + # Hide Tensorflow warnings + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if tf is not None: + llo = [list(range(i, i+3)) for i in range(0, 9, 3)] + m3x3 = tf.constant(llo) + m3x3sympy = Matrix(llo) + + +def _compare_tensorflow_matrix(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [randMatrix(v.rows, v.cols) for v in variables] + else: + random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +# Creating a custom inverse test. +# See https://github.com/sympy/sympy/issues/18469 +def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [eye(v.rows, v.cols)*4 for v in variables] + else: + random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +def _compare_tensorflow_matrix_scalar(variables, expr): + f = lambdify(variables, expr, 'tensorflow') + random_matrices = [ + randMatrix(v.rows, v.cols).evalf() / 100 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_scalar( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).evalf().doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_relational( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).doit() + assert r == e + + +def test_tensorflow_printing(): + assert tensorflow_code(eye(3)) == \ + "tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])" + + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert tensorflow_code(expr) == \ + "tensorflow.Variable(" \ + "[[x, tensorflow.math.sin(y)]," \ + " [tensorflow.math.exp(z), -t]])" + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_math(): + if not tf: + skip("TensorFlow not installed") + + expr = Abs(x) + assert tensorflow_code(expr) == "tensorflow.math.abs(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = sign(x) + assert tensorflow_code(expr) == "tensorflow.math.sign(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = ceiling(x) + assert tensorflow_code(expr) == "tensorflow.math.ceil(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert tensorflow_code(expr) == "tensorflow.math.floor(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert tensorflow_code(expr) == "tensorflow.math.exp(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = sqrt(x) + assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert tensorflow_code(expr) == "tensorflow.math.cos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert tensorflow_code(expr) == "tensorflow.math.acos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95)) + + expr = sin(x) + assert tensorflow_code(expr) == "tensorflow.math.sin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert tensorflow_code(expr) == "tensorflow.math.asin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = tan(x) + assert tensorflow_code(expr) == "tensorflow.math.tan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan(x) + assert tensorflow_code(expr) == "tensorflow.math.atan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan2(y, x) + assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)" + _compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random()) + + expr = cosh(x) + assert tensorflow_code(expr) == "tensorflow.math.cosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acosh(x) + assert tensorflow_code(expr) == "tensorflow.math.acosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = sinh(x) + assert tensorflow_code(expr) == "tensorflow.math.sinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = asinh(x) + assert tensorflow_code(expr) == "tensorflow.math.asinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = tanh(x) + assert tensorflow_code(expr) == "tensorflow.math.tanh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = atanh(x) + assert tensorflow_code(expr) == "tensorflow.math.atanh(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.uniform(-.5, .5)) + + expr = erf(x) + assert tensorflow_code(expr) == "tensorflow.math.erf(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + expr = loggamma(x) + assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + +def test_tensorflow_complexes(): + assert tensorflow_code(re(x)) == "tensorflow.math.real(x)" + assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)" + assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)" + + +def test_tensorflow_relational(): + if not tf: + skip("TensorFlow not installed") + + expr = Eq(x, y) + assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ne(x, y) + assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ge(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Gt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Le(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Lt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less(x, y)" + _compare_tensorflow_relational((x, y), expr) + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_matrices(): + if not tf: + skip("TensorFlow not installed") + + expr = M + assert tensorflow_code(expr) == "M" + _compare_tensorflow_matrix((M,), expr) + + expr = M + N + assert tensorflow_code(expr) == "tensorflow.math.add(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M * N + assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = HadamardProduct(M, N) + assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M*N*P*Q + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(M, N), P), Q)" + _compare_tensorflow_matrix((M, N, P, Q), expr) + + expr = M**3 + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Trace(M) + assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Determinant(M) + assert tensorflow_code(expr) == "tensorflow.linalg.det(M)" + _compare_tensorflow_matrix_scalar((M,), expr) + + expr = Inverse(M) + assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)" + _compare_tensorflow_matrix_inverse((M,), expr, use_float=True) + + expr = M.T + assert tensorflow_code(expr, tensorflow_version='1.14') == \ + "tensorflow.linalg.matrix_transpose(M)" + assert tensorflow_code(expr, tensorflow_version='1.13') == \ + "tensorflow.matrix_transpose(M)" + + _compare_tensorflow_matrix((M,), expr) + + +def test_codegen_einsum(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session(graph=graph) + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'tensorflow') + + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + y = session.run(f(ma, mb)) + c = session.run(tf.matmul(ma, mb)) + assert (y == c).all() + + +def test_codegen_extra(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session() + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + mc = tf.constant([[2, 0], [1, 2]]) + md = tf.constant([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ij,kl", ma, mb)) + assert (y == c).all() + + cg = ArrayAdd(M, N) + assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(ma + mb) + assert (y == c).all() + + cg = ArrayAdd(M, N, P) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P)' + f = lambdify((M, N, P), cg, 'tensorflow') + y = session.run(f(ma, mb, mc)) + c = session.run(ma + mb + mc) + assert (y == c).all() + + cg = ArrayAdd(M, N, P, Q) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(' \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'tensorflow') + y = session.run(f(ma, mb, mc, md)) + c = session.run(ma + mb + mc + md) + assert (y == c).all() + + cg = PermuteDims(M, [1, 0]) + assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])' + f = lambdify((M,), cg, 'tensorflow') + y = session.run(f(ma)) + c = session.run(tf.transpose(ma)) + assert (y == c).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert tensorflow_code(cg) == \ + 'tensorflow.transpose(' \ + 'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0])) + assert (y == c).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ab,bc->acb", ma, mb)) + assert (y == c).all() + + +def test_MatrixElement_printing(): + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert tensorflow_code(A[0, 0]) == "A[0, 0]" + assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]" + + F = C[0, 0].subs(C, A - B) + assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]" + + +def test_tensorflow_Derivative(): + expr = Derivative(sin(x), x) + assert tensorflow_code(expr) == \ + "tensorflow.gradients(tensorflow.math.sin(x), x)[0]" + +def test_tensorflow_isnan_isinf(): + if not tf: + skip("TensorFlow not installed") + + # Test for isnan + x = symbols("x") + # Return 0 if x is of nan value, and 1 otherwise + expression = Piecewise((0.0, isnan(x)), (1.0, True)) + printed_code = tensorflow_code(expression) + expected_printed_code = "tensorflow.where(tensorflow.math.is_nan(x), 0.0, 1.0)" + assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}" + for _input, _expected in [(float('nan'), 0.0), (float('inf'), 1.0), (float('-inf'), 1.0), (1.0, 1.0)]: + _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input])) + assert (_output == _expected).numpy().all() + + # Test for isinf + x = symbols("x") + # Return 0 if x is of nan value, and 1 otherwise + expression = Piecewise((0.0, isinf(x)), (1.0, True)) + printed_code = tensorflow_code(expression) + expected_printed_code = "tensorflow.where(tensorflow.math.is_inf(x), 0.0, 1.0)" + assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}" + for _input, _expected in [(float('inf'), 0.0), (float('-inf'), 0.0), (float('nan'), 1.0), (1.0, 1.0)]: + _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input])) + assert (_output == _expected).numpy().all() diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_theanocode.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_theanocode.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff40f78cb4de16149cb5e780756b7e32b574b71 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_theanocode.py @@ -0,0 +1,639 @@ +""" +Important note on tests in this module - the Theano printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the theano_code_ and +theano_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +theanologger = logging.getLogger('theano.configdefaults') +theanologger.setLevel(logging.CRITICAL) +theano = import_module('theano') +theanologger.setLevel(logging.WARNING) + + +if theano: + import numpy as np + ts = theano.scalar + tt = theano.tensor + xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.theanocode import (theano_code, dim_handling, + theano_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def theano_code_(expr, **kwargs): + """ Wrapper for theano_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_code(expr, **kwargs) + +def theano_function_(inputs, outputs, **kwargs): + """ Wrapper for theano_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Theano Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + theano.gof.FunctionGraph + """ + outs = list(map(theano_code_, exprs)) + ins = theano.gof.graph.inputs(outs) + ins, outs = theano.gof.graph.clone(ins, outs) + return theano.gof.FunctionGraph(ins, outs) + + +def theano_simplify(fgraph): + """ Simplify a Theano Computation. + + Parameters + ========== + fgraph : theano.gof.FunctionGraph + + Returns + ======= + theano.gof.FunctionGraph + """ + mode = theano.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.optimize(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Theano objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = theano.printing.debugprint(a, file='str') + bstr = theano.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'theano.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Theano + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, theano_code_(x)) + assert theq(yt, theano_code_(y)) + assert theq(zt, theano_code_(z)) + assert theq(Xt, theano_code_(X)) + assert theq(Yt, theano_code_(Y)) + assert theq(Zt, theano_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a theano variable. """ + xx = theano_code_(x) + assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable)) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = theano_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a theano variable. """ + XX = theano_code_(X) + assert isinstance(XX, tt.TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + theano_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = theano_code_(f_t) + assert isinstance(ftt, tt.TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = theano_code_(expr) + assert comp.owner.op == theano.tensor.add + +def test_trig(): + assert theq(theano_code_(sy.sin(x)), tt.sin(xt)) + assert theq(theano_code_(sy.tan(x)), tt.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = theano_code_(expr) + expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = theano_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = theano_code_(expr) + assert isinstance(expr_t.owner.op, tt.Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(theano_code_(expr).owner.op, tt.Elemwise) + + +def test_Rationals(): + assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3)) + assert theq(theano_code_(S.Half), tt.true_div(1, 2)) + +def test_Integers(): + assert theano_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert theano_code_(sy.factorial(n)) + +def test_Derivative(): + simp = lambda expr: theano_simplify(fgraph_of(expr)) + assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(theano.grad(tt.sin(xt), xt))) + + +def test_theano_function_simple(): + """ Test theano_function() with single output. """ + f = theano_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_theano_function_multi(): + """ Test theano_function() with multiple outputs. """ + f = theano_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_theano_function_numpy(): + """ Test theano_function() vs Numpy implementation. """ + f = theano_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_theano_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = theano_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_theano_function_kwargs(): + """ + Test passing additional kwargs from theano_function() to theano.function(). + """ + import numpy as np + f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_theano_function_scalar(): + """ Test the "scalar" argument to theano_function(). """ + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the theano_function attribute is set whether wrapped or not + assert isinstance(f.theano_function, theano.compile.function_module.Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.theano_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_theano_function_bad_kwarg(): + """ + Passing an unknown keyword argument to theano_function() should raise an + exception. + """ + raises(Exception, lambda : theano_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + from theano import Constant + + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = theano_code_(Y, cache=cache) + + s = ts.Scalar('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == theano_code_(X, cache=cache) + # == doesn't work in theano like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7)) + + k = sy.Symbol('k') + theano_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = theano_code_(Block) + solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)), + tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = theano_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = theano_code_(X) + assert isinstance(tX, tt.TensorVariable) + assert tX.owner.op == tt.join_ + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = theano_code_(s1, cache=cache) + + # Test hit with same instance + assert theano_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert theano_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert theano_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.theanocode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = theano_code(s) + assert theano_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = theano_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert theano_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = theano_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + theano_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = theano_code_(expr) + + # Iterate through variables in the Theano computational graph that the + # printed expression depends on + seen = set() + for v in theano.gof.graph.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, theano.gof.graph.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = theano_code_(expr) + assert result.owner.op == tt.switch + + expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = theano_code_(expr) + expected = tt.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = theano_code_(expr) + expected = tt.switch(tt.and_(xt>0,xt<2), 0, \ + tt.switch(tt.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt)) + # assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement + assert theq(theano_code_(x > y), xt > yt) + assert theq(theano_code_(x < y), xt < yt) + assert theq(theano_code_(x >= y), xt >= yt) + assert theq(theano_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + with warns_deprecated_sympy(): + xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'}) + from sympy.functions.elementary.complexes import conjugate + from theano.tensor import as_tensor_variable as atv + from theano.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj())) + assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = theano_function_([],[1+1j]) + assert(tf()==1+1j) + + +def test_Exp1(): + """ + Test that exp(1) prints without error and evaluates close to SymPy's E + """ + # sy.exp(1) should yield same instance of E as sy.E (singleton), but extra + # check added for sanity + e_a = sy.exp(1) + e_b = sy.E + + np.testing.assert_allclose(float(e_a), np.e) + np.testing.assert_allclose(float(e_b), np.e) + + e = theano_code_(e_a) + np.testing.assert_allclose(float(e_a), e.eval()) + + e = theano_code_(e_b) + np.testing.assert_allclose(float(e_b), e.eval()) diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_torch.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce2c6cec75e03264f93b472a79eb073742e3486 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_torch.py @@ -0,0 +1,531 @@ +import random +import math + +from sympy import symbols, Derivative +from sympy.printing.pytorch import torch_code +from sympy import (eye, MatrixSymbol, Matrix) +from sympy.tensor.array import NDimArray +from sympy.tensor.array.expressions.array_expressions import ( + ArrayTensorProduct, ArrayAdd, + PermuteDims, ArrayDiagonal, _CodegenArrayAbstract) +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, sqrt +from sympy.testing.pytest import skip +from sympy.external import import_module +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, Trace +from sympy.matrices import randMatrix +from sympy.matrices import Identity, ZeroMatrix, OneMatrix +from sympy import conjugate, I +from sympy import Heaviside, gamma, polygamma + + + +torch = import_module("torch") + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if torch is not None: + llo = [list(range(i, i + 3)) for i in range(0, 9, 3)] + m3x3 = torch.tensor(llo, dtype=torch.float64) + m3x3sympy = Matrix(llo) + + +def _compare_torch_matrix(variables, expr): + f = lambdify(variables, expr, 'torch') + + random_matrices = [randMatrix(i.shape[0], i.shape[1]) for i in variables] + random_variables = [torch.tensor(i.tolist(), dtype=torch.float64) for i in random_matrices] + r = f(*random_variables) + e = expr.subs(dict(zip(variables, random_matrices))).doit() + + if isinstance(e, _CodegenArrayAbstract): + e = e.doit() + + if hasattr(e, 'is_number') and e.is_number: + if isinstance(r, torch.Tensor) and r.dim() == 0: + r = r.item() + e = float(e) + assert abs(r - e) < 1e-6 + return + + if e.is_Matrix or isinstance(e, NDimArray): + e = torch.tensor(e.tolist(), dtype=torch.float64) + assert torch.allclose(r, e, atol=1e-6) + else: + raise TypeError(f"Cannot compare {type(r)} with {type(e)}") + + +def _compare_torch_scalar(variables, expr, rng=lambda: random.uniform(-5, 5)): + f = lambdify(variables, expr, 'torch') + rvs = [rng() for v in variables] + t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs] + r = f(*t_rvs) + if isinstance(r, torch.Tensor): + r = r.item() + e = expr.subs(dict(zip(variables, rvs))).doit() + assert abs(r - e) < 1e-6 + + +def _compare_torch_relational(variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'torch') + rvs = [rng() for v in variables] + t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs] + r = f(*t_rvs) + e = bool(expr.subs(dict(zip(variables, rvs))).doit()) + assert r.item() == e + + +def test_torch_math(): + if not torch: + skip("PyTorch not installed") + + expr = Abs(x) + assert torch_code(expr) == "torch.abs(x)" + f = lambdify(x, expr, 'torch') + ma = torch.tensor([[-1, 2, -3, -4]], dtype=torch.float64) + y_abs = f(ma) + c = torch.abs(ma) + assert torch.all(y_abs == c) + + expr = sign(x) + assert torch_code(expr) == "torch.sign(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-10, 10)) + + expr = ceiling(x) + assert torch_code(expr) == "torch.ceil(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert torch_code(expr) == "torch.floor(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert torch_code(expr) == "torch.exp(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = sqrt(x) + assert torch_code(expr) == "torch.sqrt(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert torch_code(expr) == "torch.pow(x, 4)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert torch_code(expr) == "torch.cos(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert torch_code(expr) == "torch.acos(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99)) + + expr = sin(x) + assert torch_code(expr) == "torch.sin(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert torch_code(expr) == "torch.asin(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99)) + + expr = tan(x) + assert torch_code(expr) == "torch.tan(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-1.5, 1.5)) + + expr = atan(x) + assert torch_code(expr) == "torch.atan(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5)) + + expr = atan2(y, x) + assert torch_code(expr) == "torch.atan2(y, x)" + _compare_torch_scalar((y, x), expr, rng=lambda: random.uniform(-5, 5)) + + expr = cosh(x) + assert torch_code(expr) == "torch.cosh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = acosh(x) + assert torch_code(expr) == "torch.acosh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(1.1, 5)) + + expr = sinh(x) + assert torch_code(expr) == "torch.sinh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = asinh(x) + assert torch_code(expr) == "torch.asinh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5)) + + expr = tanh(x) + assert torch_code(expr) == "torch.tanh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = atanh(x) + assert torch_code(expr) == "torch.atanh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.9, 0.9)) + + expr = erf(x) + assert torch_code(expr) == "torch.erf(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = loggamma(x) + assert torch_code(expr) == "torch.lgamma(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(0.5, 5)) + + +def test_torch_complexes(): + assert torch_code(re(x)) == "torch.real(x)" + assert torch_code(im(x)) == "torch.imag(x)" + assert torch_code(arg(x)) == "torch.angle(x)" + + +def test_torch_relational(): + if not torch: + skip("PyTorch not installed") + + expr = Eq(x, y) + assert torch_code(expr) == "torch.eq(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Ne(x, y) + assert torch_code(expr) == "torch.ne(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Ge(x, y) + assert torch_code(expr) == "torch.ge(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Gt(x, y) + assert torch_code(expr) == "torch.gt(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Le(x, y) + assert torch_code(expr) == "torch.le(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Lt(x, y) + assert torch_code(expr) == "torch.lt(x, y)" + _compare_torch_relational((x, y), expr) + + +def test_torch_matrix(): + if torch is None: + skip("PyTorch not installed") + + expr = M + assert torch_code(expr) == "M" + f = lambdify((M,), expr, "torch") + eye_mat = eye(3) + eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64) + assert torch.allclose(f(eye_tensor), eye_tensor) + + expr = M * N + assert torch_code(expr) == "torch.matmul(M, N)" + _compare_torch_matrix((M, N), expr) + + expr = M ** 3 + assert torch_code(expr) == "torch.mm(torch.mm(M, M), M)" + _compare_torch_matrix((M,), expr) + + expr = M * N * P * Q + assert torch_code(expr) == "torch.matmul(torch.matmul(torch.matmul(M, N), P), Q)" + _compare_torch_matrix((M, N, P, Q), expr) + + expr = Trace(M) + assert torch_code(expr) == "torch.trace(M)" + _compare_torch_matrix((M,), expr) + + expr = Determinant(M) + assert torch_code(expr) == "torch.det(M)" + _compare_torch_matrix((M,), expr) + + expr = HadamardProduct(M, N) + assert torch_code(expr) == "torch.mul(M, N)" + _compare_torch_matrix((M, N), expr) + + expr = Inverse(M) + assert torch_code(expr) == "torch.linalg.inv(M)" + + # For inverse, use a matrix that's guaranteed to be invertible + eye_mat = eye(3) + eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64) + f = lambdify((M,), expr, "torch") + result = f(eye_tensor) + expected = torch.linalg.inv(eye_tensor) + assert torch.allclose(result, expected) + + +def test_torch_array_operations(): + if not torch: + skip("PyTorch not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + + ma = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float64) + mb = torch.tensor([[1., -2.], [-1., 3.]], dtype=torch.float64) + mc = torch.tensor([[2., 0.], [1., 2.]], dtype=torch.float64) + md = torch.tensor([[1., -1.], [4., 7.]], dtype=torch.float64) + + cg = ArrayTensorProduct(M, N) + assert torch_code(cg) == 'torch.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ij,kl", ma, mb) + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N) + assert torch_code(cg) == 'torch.add(M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = ma + mb + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N, P) + assert torch_code(cg) == 'torch.add(torch.add(M, N), P)' + f = lambdify((M, N, P), cg, 'torch') + y = f(ma, mb, mc) + c = ma + mb + mc + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N, P, Q) + assert torch_code(cg) == 'torch.add(torch.add(torch.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'torch') + y = f(ma, mb, mc, md) + c = ma + mb + mc + md + assert torch.allclose(y, c) + + cg = PermuteDims(M, [1, 0]) + assert torch_code(cg) == 'M.permute(1, 0)' + f = lambdify((M,), cg, 'torch') + y = f(ma) + c = ma.T + assert torch.allclose(y, c) + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert torch_code(cg) == 'torch.einsum("ab,cd", M, N).permute(1, 2, 3, 0)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ab,cd", ma, mb).permute(1, 2, 3, 0) + assert torch.allclose(y, c) + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert torch_code(cg) == 'torch.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ab,bc->acb", ma, mb) + assert torch.allclose(y, c) + + +def test_torch_derivative(): + """Test derivative handling.""" + expr = Derivative(sin(x), x) + assert torch_code(expr) == 'torch.autograd.grad(torch.sin(x), x)[0]' + + +def test_torch_printing_dtype(): + if not torch: + skip("PyTorch not installed") + + # matrix printing with default dtype + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert "dtype=torch.float64" in torch_code(expr) + + # explicit dtype + assert "dtype=torch.float32" in torch_code(expr, dtype="torch.float32") + + # with requires_grad + result = torch_code(expr, requires_grad=True) + assert "requires_grad=True" in result + assert "dtype=torch.float64" in result + + # both + result = torch_code(expr, requires_grad=True, dtype="torch.float32") + assert "requires_grad=True" in result + assert "dtype=torch.float32" in result + + +def test_requires_grad(): + if not torch: + skip("PyTorch not installed") + + expr = sin(x) + cos(y) + f = lambdify([x, y], expr, 'torch') + + # make sure the gradients flow + x_val = torch.tensor(1.0, requires_grad=True) + y_val = torch.tensor(2.0, requires_grad=True) + result = f(x_val, y_val) + assert result.requires_grad + result.backward() + + # x_val.grad should be cos(x_val) which is close to cos(1.0) + assert abs(x_val.grad.item() - float(cos(1.0).evalf())) < 1e-6 + + # y_val.grad should be -sin(y_val) which is close to -sin(2.0) + assert abs(y_val.grad.item() - float(-sin(2.0).evalf())) < 1e-6 + + +def test_torch_multi_variable_derivatives(): + if not torch: + skip("PyTorch not installed") + + x, y, z = symbols("x y z") + + expr = Derivative(sin(x), x) + assert torch_code(expr) == "torch.autograd.grad(torch.sin(x), x)[0]" + + expr = Derivative(sin(x), (x, 2)) + assert torch_code( + expr) == "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]" + + expr = Derivative(sin(x * y), x, y) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x*y), x, create_graph=True)[0], y, create_graph=True)[0]" + normalized_result = result.replace(" ", "") + normalized_expected = expected.replace(" ", "") + assert normalized_result == normalized_expected + + expr = Derivative(sin(x), x, x) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]" + assert result == expected + + expr = Derivative(sin(x * y * z), x, (y, 2), z) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.sin(x*y*z), x, create_graph=True)[0], y, create_graph=True)[0], y, create_graph=True)[0], z, create_graph=True)[0]" + normalized_result = result.replace(" ", "") + normalized_expected = expected.replace(" ", "") + assert normalized_result == normalized_expected + + +def test_torch_derivative_lambdify(): + if not torch: + skip("PyTorch not installed") + + x = symbols("x") + y = symbols("y") + + expr = Derivative(x ** 2, x) + f = lambdify(x, expr, 'torch') + x_val = torch.tensor(2.0, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(4.0)) + + expr = Derivative(sin(x), (x, 2)) + f = lambdify(x, expr, 'torch') + # Second derivative of sin(x) at x=0 is 0, not -1 + x_val = torch.tensor(0.0, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(0.0), atol=1e-5) + + x_val = torch.tensor(math.pi / 2, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(-1.0), atol=1e-5) + + expr = Derivative(x * y ** 2, x, y) + f = lambdify((x, y), expr, 'torch') + x_val = torch.tensor(2.0, requires_grad=True) + y_val = torch.tensor(3.0, requires_grad=True) + result = f(x_val, y_val) + assert torch.isclose(result, torch.tensor(6.0)) + + +def test_torch_special_matrices(): + if not torch: + skip("PyTorch not installed") + + expr = Identity(3) + assert torch_code(expr) == "torch.eye(3)" + + n = symbols("n") + expr = Identity(n) + assert torch_code(expr) == "torch.eye(n, n)" + + expr = ZeroMatrix(2, 3) + assert torch_code(expr) == "torch.zeros((2, 3))" + + m, n = symbols("m n") + expr = ZeroMatrix(m, n) + assert torch_code(expr) == "torch.zeros((m, n))" + + expr = OneMatrix(2, 3) + assert torch_code(expr) == "torch.ones((2, 3))" + + expr = OneMatrix(m, n) + assert torch_code(expr) == "torch.ones((m, n))" + + +def test_torch_special_matrices_lambdify(): + if not torch: + skip("PyTorch not installed") + + expr = Identity(3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.eye(3) + assert torch.allclose(result, expected) + + expr = ZeroMatrix(2, 3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.zeros((2, 3)) + assert torch.allclose(result, expected) + + expr = OneMatrix(2, 3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.ones((2, 3)) + assert torch.allclose(result, expected) + + +def test_torch_complex_operations(): + if not torch: + skip("PyTorch not installed") + + expr = conjugate(x) + assert torch_code(expr) == "torch.conj(x)" + + # SymPy distributes conjugate over addition and applies specific rules for each term + expr = conjugate(sin(x) + I * cos(y)) + assert torch_code(expr) == "torch.sin(torch.conj(x)) - 1j*torch.cos(torch.conj(y))" + + expr = I + assert torch_code(expr) == "1j" + + expr = 2 * I + x + assert torch_code(expr) == "x + 2*1j" + + expr = exp(I * x) + assert torch_code(expr) == "torch.exp(1j*x)" + + +def test_torch_special_functions(): + if not torch: + skip("PyTorch not installed") + + expr = Heaviside(x) + assert torch_code(expr) == "torch.heaviside(x, 1/2)" + + expr = Heaviside(x, 0) + assert torch_code(expr) == "torch.heaviside(x, 0)" + + expr = gamma(x) + assert torch_code(expr) == "torch.special.gamma(x)" + + expr = polygamma(0, x) # Use polygamma instead of digamma because sympy will default to that anyway + assert torch_code(expr) == "torch.special.digamma(x)" + + expr = gamma(sin(x)) + assert torch_code(expr) == "torch.special.gamma(torch.sin(x))" diff --git a/phivenv/Lib/site-packages/sympy/printing/tests/test_tree.py b/phivenv/Lib/site-packages/sympy/printing/tests/test_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..cf116d0cac5d38f225815fcd2d4ac90cd0dd96d7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tests/test_tree.py @@ -0,0 +1,196 @@ +from sympy.printing.tree import tree +from sympy.testing.pytest import XFAIL + + +# Remove this flag after making _assumptions cache deterministic. +@XFAIL +def test_print_tree_MatAdd(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = [ + 'MatAdd: A + B\n', + 'algebraic: False\n', + 'commutative: False\n', + 'complex: False\n', + 'composite: False\n', + 'even: False\n', + 'extended_negative: False\n', + 'extended_nonnegative: False\n', + 'extended_nonpositive: False\n', + 'extended_nonzero: False\n', + 'extended_positive: False\n', + 'extended_real: False\n', + 'imaginary: False\n', + 'integer: False\n', + 'irrational: False\n', + 'negative: False\n', + 'noninteger: False\n', + 'nonnegative: False\n', + 'nonpositive: False\n', + 'nonzero: False\n', + 'odd: False\n', + 'positive: False\n', + 'prime: False\n', + 'rational: False\n', + 'real: False\n', + 'transcendental: False\n', + 'zero: False\n', + '+-MatrixSymbol: A\n', + '| algebraic: False\n', + '| commutative: False\n', + '| complex: False\n', + '| composite: False\n', + '| even: False\n', + '| extended_negative: False\n', + '| extended_nonnegative: False\n', + '| extended_nonpositive: False\n', + '| extended_nonzero: False\n', + '| extended_positive: False\n', + '| extended_real: False\n', + '| imaginary: False\n', + '| integer: False\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: False\n', + '| nonpositive: False\n', + '| nonzero: False\n', + '| odd: False\n', + '| positive: False\n', + '| prime: False\n', + '| rational: False\n', + '| real: False\n', + '| transcendental: False\n', + '| zero: False\n', + '| +-Symbol: A\n', + '| | commutative: True\n', + '| +-Integer: 3\n', + '| | algebraic: True\n', + '| | commutative: True\n', + '| | complex: True\n', + '| | extended_negative: False\n', + '| | extended_nonnegative: True\n', + '| | extended_real: True\n', + '| | finite: True\n', + '| | hermitian: True\n', + '| | imaginary: False\n', + '| | infinite: False\n', + '| | integer: True\n', + '| | irrational: False\n', + '| | negative: False\n', + '| | noninteger: False\n', + '| | nonnegative: True\n', + '| | rational: True\n', + '| | real: True\n', + '| | transcendental: False\n', + '| +-Integer: 3\n', + '| algebraic: True\n', + '| commutative: True\n', + '| complex: True\n', + '| extended_negative: False\n', + '| extended_nonnegative: True\n', + '| extended_real: True\n', + '| finite: True\n', + '| hermitian: True\n', + '| imaginary: False\n', + '| infinite: False\n', + '| integer: True\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: True\n', + '| rational: True\n', + '| real: True\n', + '| transcendental: False\n', + '+-MatrixSymbol: B\n', + ' algebraic: False\n', + ' commutative: False\n', + ' complex: False\n', + ' composite: False\n', + ' even: False\n', + ' extended_negative: False\n', + ' extended_nonnegative: False\n', + ' extended_nonpositive: False\n', + ' extended_nonzero: False\n', + ' extended_positive: False\n', + ' extended_real: False\n', + ' imaginary: False\n', + ' integer: False\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: False\n', + ' nonpositive: False\n', + ' nonzero: False\n', + ' odd: False\n', + ' positive: False\n', + ' prime: False\n', + ' rational: False\n', + ' real: False\n', + ' transcendental: False\n', + ' zero: False\n', + ' +-Symbol: B\n', + ' | commutative: True\n', + ' +-Integer: 3\n', + ' | algebraic: True\n', + ' | commutative: True\n', + ' | complex: True\n', + ' | extended_negative: False\n', + ' | extended_nonnegative: True\n', + ' | extended_real: True\n', + ' | finite: True\n', + ' | hermitian: True\n', + ' | imaginary: False\n', + ' | infinite: False\n', + ' | integer: True\n', + ' | irrational: False\n', + ' | negative: False\n', + ' | noninteger: False\n', + ' | nonnegative: True\n', + ' | rational: True\n', + ' | real: True\n', + ' | transcendental: False\n', + ' +-Integer: 3\n', + ' algebraic: True\n', + ' commutative: True\n', + ' complex: True\n', + ' extended_negative: False\n', + ' extended_nonnegative: True\n', + ' extended_real: True\n', + ' finite: True\n', + ' hermitian: True\n', + ' imaginary: False\n', + ' infinite: False\n', + ' integer: True\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: True\n', + ' rational: True\n', + ' real: True\n', + ' transcendental: False\n' + ] + + assert tree(A + B) == "".join(test_str) + + +def test_print_tree_MatAdd_noassumptions(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = \ +"""MatAdd: A + B ++-MatrixSymbol: A +| +-Str: A +| +-Integer: 3 +| +-Integer: 3 ++-MatrixSymbol: B + +-Str: B + +-Integer: 3 + +-Integer: 3 +""" + + assert tree(A + B, assumptions=False) == test_str diff --git a/phivenv/Lib/site-packages/sympy/printing/theanocode.py b/phivenv/Lib/site-packages/sympy/printing/theanocode.py new file mode 100644 index 0000000000000000000000000000000000000000..dce908865d426dabede2b6749ad944e5a420e4cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/theanocode.py @@ -0,0 +1,571 @@ +""" +.. deprecated:: 1.8 + + ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to + Aesara. Use ``sympy.printing.aesaracode`` instead. See + :ref:`theanocode-deprecated` for more information. + +""" +from __future__ import annotations +import math +from typing import Any + +from sympy.external import import_module +from sympy.printing.printer import Printer +from sympy.utilities.iterables import is_sequence +import sympy +from functools import partial + +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.exceptions import sympy_deprecation_warning + + +__doctest_requires__ = {('theano_function',): ['theano']} + + +theano = import_module('theano') + + +if theano: + ts = theano.scalar + tt = theano.tensor + from theano.sandbox import linalg as tlinalg + + mapping = { + sympy.Add: tt.add, + sympy.Mul: tt.mul, + sympy.Abs: tt.abs_, + sympy.sign: tt.sgn, + sympy.ceiling: tt.ceil, + sympy.floor: tt.floor, + sympy.log: tt.log, + sympy.exp: tt.exp, + sympy.sqrt: tt.sqrt, + sympy.cos: tt.cos, + sympy.acos: tt.arccos, + sympy.sin: tt.sin, + sympy.asin: tt.arcsin, + sympy.tan: tt.tan, + sympy.atan: tt.arctan, + sympy.atan2: tt.arctan2, + sympy.cosh: tt.cosh, + sympy.acosh: tt.arccosh, + sympy.sinh: tt.sinh, + sympy.asinh: tt.arcsinh, + sympy.tanh: tt.tanh, + sympy.atanh: tt.arctanh, + sympy.re: tt.real, + sympy.im: tt.imag, + sympy.arg: tt.angle, + sympy.erf: tt.erf, + sympy.gamma: tt.gamma, + sympy.loggamma: tt.gammaln, + sympy.Pow: tt.pow, + sympy.Eq: tt.eq, + sympy.StrictGreaterThan: tt.gt, + sympy.StrictLessThan: tt.lt, + sympy.LessThan: tt.le, + sympy.GreaterThan: tt.ge, + sympy.And: tt.and_, + sympy.Or: tt.or_, + sympy.Max: tt.maximum, # SymPy accept >2 inputs, Theano only 2 + sympy.Min: tt.minimum, # SymPy accept >2 inputs, Theano only 2 + sympy.conjugate: tt.conj, + sympy.core.numbers.ImaginaryUnit: lambda:tt.complex(0,1), + # Matrices + sympy.MatAdd: tt.Elemwise(ts.add), + sympy.HadamardProduct: tt.Elemwise(ts.mul), + sympy.Trace: tlinalg.trace, + sympy.Determinant : tlinalg.det, + sympy.Inverse: tlinalg.matrix_inverse, + sympy.Transpose: tt.DimShuffle((False, False), [1, 0]), + } + + +class TheanoPrinter(Printer): + """ Code printer which creates Theano symbolic expression graphs. + + Parameters + ========== + + cache : dict + Cache dictionary to use. If None (default) will use + the global cache. To create a printer which does not depend on or alter + global state pass an empty dictionary. Note: the dictionary is not + copied on initialization of the printer and will be updated in-place, + so using the same dict object when creating multiple printers or making + multiple calls to :func:`.theano_code` or :func:`.theano_function` means + the cache is shared between all these applications. + + Attributes + ========== + + cache : dict + A cache of Theano variables which have been created for SymPy + symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or + :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to + ensure that all references to a given symbol in an expression (or + multiple expressions) are printed as the same Theano variable, which is + created only once. Symbols are differentiated only by name and type. The + format of the cache's contents should be considered opaque to the user. + """ + printmethod = "_theano" + + def __init__(self, *args, **kwargs): + self.cache = kwargs.pop('cache', {}) + super().__init__(*args, **kwargs) + + def _get_key(self, s, name=None, dtype=None, broadcastable=None): + """ Get the cache key for a SymPy object. + + Parameters + ========== + + s : sympy.core.basic.Basic + SymPy object to get key for. + + name : str + Name of object, if it does not have a ``name`` attribute. + """ + + if name is None: + name = s.name + + return (name, type(s), s.args, dtype, broadcastable) + + def _get_or_create(self, s, name=None, dtype=None, broadcastable=None): + """ + Get the Theano variable for a SymPy symbol from the cache, or create it + if it does not exist. + """ + + # Defaults + if name is None: + name = s.name + if dtype is None: + dtype = 'floatX' + if broadcastable is None: + broadcastable = () + + key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable) + + if key in self.cache: + return self.cache[key] + + value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable) + self.cache[key] = value + return value + + def _print_Symbol(self, s, **kwargs): + dtype = kwargs.get('dtypes', {}).get(s) + bc = kwargs.get('broadcastables', {}).get(s) + return self._get_or_create(s, dtype=dtype, broadcastable=bc) + + def _print_AppliedUndef(self, s, **kwargs): + name = str(type(s)) + '_' + str(s.args[0]) + dtype = kwargs.get('dtypes', {}).get(s) + bc = kwargs.get('broadcastables', {}).get(s) + return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc) + + def _print_Basic(self, expr, **kwargs): + op = mapping[type(expr)] + children = [self._print(arg, **kwargs) for arg in expr.args] + return op(*children) + + def _print_Number(self, n, **kwargs): + # Integers already taken care of below, interpret as float + return float(n.evalf()) + + def _print_MatrixSymbol(self, X, **kwargs): + dtype = kwargs.get('dtypes', {}).get(X) + return self._get_or_create(X, dtype=dtype, broadcastable=(None, None)) + + def _print_DenseMatrix(self, X, **kwargs): + if not hasattr(tt, 'stacklists'): + raise NotImplementedError( + "Matrix translation not yet supported in this version of Theano") + + return tt.stacklists([ + [self._print(arg, **kwargs) for arg in L] + for L in X.tolist() + ]) + + _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix + + def _print_MatMul(self, expr, **kwargs): + children = [self._print(arg, **kwargs) for arg in expr.args] + result = children[0] + for child in children[1:]: + result = tt.dot(result, child) + return result + + def _print_MatPow(self, expr, **kwargs): + children = [self._print(arg, **kwargs) for arg in expr.args] + result = 1 + if isinstance(children[1], int) and children[1] > 0: + for i in range(children[1]): + result = tt.dot(result, children[0]) + else: + raise NotImplementedError('''Only non-negative integer + powers of matrices can be handled by Theano at the moment''') + return result + + def _print_MatrixSlice(self, expr, **kwargs): + parent = self._print(expr.parent, **kwargs) + rowslice = self._print(slice(*expr.rowslice), **kwargs) + colslice = self._print(slice(*expr.colslice), **kwargs) + return parent[rowslice, colslice] + + def _print_BlockMatrix(self, expr, **kwargs): + nrows, ncols = expr.blocks.shape + blocks = [[self._print(expr.blocks[r, c], **kwargs) + for c in range(ncols)] + for r in range(nrows)] + return tt.join(0, *[tt.join(1, *row) for row in blocks]) + + + def _print_slice(self, expr, **kwargs): + return slice(*[self._print(i, **kwargs) + if isinstance(i, sympy.Basic) else i + for i in (expr.start, expr.stop, expr.step)]) + + def _print_Pi(self, expr, **kwargs): + return math.pi + + def _print_Exp1(self, expr, **kwargs): + return ts.exp(1) + + def _print_Piecewise(self, expr, **kwargs): + import numpy as np + e, cond = expr.args[0].args # First condition and corresponding value + + # Print conditional expression and value for first condition + p_cond = self._print(cond, **kwargs) + p_e = self._print(e, **kwargs) + + # One condition only + if len(expr.args) == 1: + # Return value if condition else NaN + return tt.switch(p_cond, p_e, np.nan) + + # Return value_1 if condition_1 else evaluate remaining conditions + p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs) + return tt.switch(p_cond, p_e, p_remaining) + + def _print_Rational(self, expr, **kwargs): + return tt.true_div(self._print(expr.p, **kwargs), + self._print(expr.q, **kwargs)) + + def _print_Integer(self, expr, **kwargs): + return expr.p + + def _print_factorial(self, expr, **kwargs): + return self._print(sympy.gamma(expr.args[0] + 1), **kwargs) + + def _print_Derivative(self, deriv, **kwargs): + rv = self._print(deriv.expr, **kwargs) + for var in deriv.variables: + var = self._print(var, **kwargs) + rv = tt.Rop(rv, var, tt.ones_like(var)) + return rv + + def emptyPrinter(self, expr): + return expr + + def doprint(self, expr, dtypes=None, broadcastables=None): + """ Convert a SymPy expression to a Theano graph variable. + + The ``dtypes`` and ``broadcastables`` arguments are used to specify the + data type, dimension, and broadcasting behavior of the Theano variables + corresponding to the free symbols in ``expr``. Each is a mapping from + SymPy symbols to the value of the corresponding argument to + ``theano.tensor.Tensor``. + + See the corresponding `documentation page`__ for more information on + broadcasting in Theano. + + .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html + + Parameters + ========== + + expr : sympy.core.expr.Expr + SymPy expression to print. + + dtypes : dict + Mapping from SymPy symbols to Theano datatypes to use when creating + new Theano variables for those symbols. Corresponds to the ``dtype`` + argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'`` + for symbols not included in the mapping. + + broadcastables : dict + Mapping from SymPy symbols to the value of the ``broadcastable`` + argument to ``theano.tensor.Tensor`` to use when creating Theano + variables for those symbols. Defaults to the empty tuple for symbols + not included in the mapping (resulting in a scalar). + + Returns + ======= + + theano.gof.graph.Variable + A variable corresponding to the expression's value in a Theano + symbolic expression graph. + + """ + if dtypes is None: + dtypes = {} + if broadcastables is None: + broadcastables = {} + + return self._print(expr, dtypes=dtypes, broadcastables=broadcastables) + + +global_cache: dict[Any, Any] = {} + + +def theano_code(expr, cache=None, **kwargs): + """ + Convert a SymPy expression into a Theano graph variable. + + .. deprecated:: 1.8 + + ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to + Aesara. Use ``sympy.printing.aesaracode`` instead. See + :ref:`theanocode-deprecated` for more information. + + Parameters + ========== + + expr : sympy.core.expr.Expr + SymPy expression object to convert. + + cache : dict + Cached Theano variables (see :class:`TheanoPrinter.cache + `). Defaults to the module-level global cache. + + dtypes : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + broadcastables : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + Returns + ======= + + theano.gof.graph.Variable + A variable corresponding to the expression's value in a Theano symbolic + expression graph. + + """ + sympy_deprecation_warning( + """ + sympy.printing.theanocode is deprecated. Theano has been renamed to + Aesara. Use sympy.printing.aesaracode instead.""", + deprecated_since_version="1.8", + active_deprecations_target='theanocode-deprecated') + + if not theano: + raise ImportError("theano is required for theano_code") + + if cache is None: + cache = global_cache + + return TheanoPrinter(cache=cache, settings={}).doprint(expr, **kwargs) + + +def dim_handling(inputs, dim=None, dims=None, broadcastables=None): + r""" + Get value of ``broadcastables`` argument to :func:`.theano_code` from + keyword arguments to :func:`.theano_function`. + + Included for backwards compatibility. + + Parameters + ========== + + inputs + Sequence of input symbols. + + dim : int + Common number of dimensions for all inputs. Overrides other arguments + if given. + + dims : dict + Mapping from input symbols to number of dimensions. Overrides + ``broadcastables`` argument if given. + + broadcastables : dict + Explicit value of ``broadcastables`` argument to + :meth:`.TheanoPrinter.doprint`. If not None function will return this value unchanged. + + Returns + ======= + dict + Dictionary mapping elements of ``inputs`` to their "broadcastable" + values (tuple of ``bool``\ s). + """ + if dim is not None: + return dict.fromkeys(inputs, (False,) * dim) + + if dims is not None: + maxdim = max(dims.values()) + return { + s: (False,) * d + (True,) * (maxdim - d) + for s, d in dims.items() + } + + if broadcastables is not None: + return broadcastables + + return {} + + +@doctest_depends_on(modules=('theano',)) +def theano_function(inputs, outputs, scalar=False, *, + dim=None, dims=None, broadcastables=None, **kwargs): + """ + Create a Theano function from SymPy expressions. + + .. deprecated:: 1.8 + + ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to + Aesara. Use ``sympy.printing.aesaracode`` instead. See + :ref:`theanocode-deprecated` for more information. + + The inputs and outputs are converted to Theano variables using + :func:`.theano_code` and then passed to ``theano.function``. + + Parameters + ========== + + inputs + Sequence of symbols which constitute the inputs of the function. + + outputs + Sequence of expressions which constitute the outputs(s) of the + function. The free symbols of each expression must be a subset of + ``inputs``. + + scalar : bool + Convert 0-dimensional arrays in output to scalars. This will return a + Python wrapper function around the Theano function object. + + cache : dict + Cached Theano variables (see :class:`TheanoPrinter.cache + `). Defaults to the module-level global cache. + + dtypes : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + broadcastables : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + dims : dict + Alternative to ``broadcastables`` argument. Mapping from elements of + ``inputs`` to integers indicating the dimension of their associated + arrays/tensors. Overrides ``broadcastables`` argument if given. + + dim : int + Another alternative to the ``broadcastables`` argument. Common number of + dimensions to use for all arrays/tensors. + ``theano_function([x, y], [...], dim=2)`` is equivalent to using + ``broadcastables={x: (False, False), y: (False, False)}``. + + Returns + ======= + callable + A callable object which takes values of ``inputs`` as positional + arguments and returns an output array for each of the expressions + in ``outputs``. If ``outputs`` is a single expression the function will + return a Numpy array, if it is a list of multiple expressions the + function will return a list of arrays. See description of the ``squeeze`` + argument above for the behavior when a single output is passed in a list. + The returned object will either be an instance of + ``theano.compile.function_module.Function`` or a Python wrapper + function around one. In both cases, the returned value will have a + ``theano_function`` attribute which points to the return value of + ``theano.function``. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.printing.theanocode import theano_function + + A simple function with one input and one output: + + >>> f1 = theano_function([x], [x**2 - 1], scalar=True) + >>> f1(3) + 8.0 + + A function with multiple inputs and one output: + + >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True) + >>> f2(3, 4, 2) + 5.0 + + A function with multiple inputs and multiple outputs: + + >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True) + >>> f3(2, 3) + [13.0, -5.0] + + See also + ======== + + dim_handling + + """ + sympy_deprecation_warning( + """ + sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""", + deprecated_since_version="1.8", + active_deprecations_target='theanocode-deprecated') + + if not theano: + raise ImportError("theano is required for theano_function") + + # Pop off non-theano keyword args + cache = kwargs.pop('cache', {}) + dtypes = kwargs.pop('dtypes', {}) + + broadcastables = dim_handling( + inputs, dim=dim, dims=dims, broadcastables=broadcastables, + ) + + # Print inputs/outputs + code = partial(theano_code, cache=cache, dtypes=dtypes, + broadcastables=broadcastables) + tinputs = list(map(code, inputs)) + toutputs = list(map(code, outputs)) + + #fix constant expressions as variables + toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs] + + if len(toutputs) == 1: + toutputs = toutputs[0] + + # Compile theano func + func = theano.function(tinputs, toutputs, **kwargs) + + is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs] + + # No wrapper required + if not scalar or not any(is_0d): + func.theano_function = func + return func + + # Create wrapper to convert 0-dimensional outputs to scalars + def wrapper(*args): + out = func(*args) + # out can be array(1.0) or [array(1.0), array(2.0)] + + if is_sequence(out): + return [o[()] if is_0d[i] else o for i, o in enumerate(out)] + else: + return out[()] + + wrapper.__wrapped__ = func + wrapper.__doc__ = func.__doc__ + wrapper.theano_function = func + return wrapper diff --git a/phivenv/Lib/site-packages/sympy/printing/tree.py b/phivenv/Lib/site-packages/sympy/printing/tree.py new file mode 100644 index 0000000000000000000000000000000000000000..82dac013419fbe93f63dcf5b90b3a529d72a32bc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/printing/tree.py @@ -0,0 +1,175 @@ +def pprint_nodes(subtrees): + """ + Prettyprints systems of nodes. + + Examples + ======== + + >>> from sympy.printing.tree import pprint_nodes + >>> print(pprint_nodes(["a", "b1\\nb2", "c"])) + +-a + +-b1 + | b2 + +-c + + """ + def indent(s, type=1): + x = s.split("\n") + r = "+-%s\n" % x[0] + for a in x[1:]: + if a == "": + continue + if type == 1: + r += "| %s\n" % a + else: + r += " %s\n" % a + return r + if not subtrees: + return "" + f = "" + for a in subtrees[:-1]: + f += indent(a) + f += indent(subtrees[-1], 2) + return f + + +def print_node(node, assumptions=True): + """ + Returns information about the "node". + + This includes class name, string representation and assumptions. + + Parameters + ========== + + assumptions : bool, optional + See the ``assumptions`` keyword in ``tree`` + """ + s = "%s: %s\n" % (node.__class__.__name__, str(node)) + + if assumptions: + d = node._assumptions + else: + d = None + + if d: + for a in sorted(d): + v = d[a] + if v is None: + continue + s += "%s: %s\n" % (a, v) + + return s + + +def tree(node, assumptions=True): + """ + Returns a tree representation of "node" as a string. + + It uses print_node() together with pprint_nodes() on node.args recursively. + + Parameters + ========== + + assumptions : bool, optional + The flag to decide whether to print out all the assumption data + (such as ``is_integer`, ``is_real``) associated with the + expression or not. + + Enabling the flag makes the result verbose, and the printed + result may not be deterministic because of the randomness used + in backtracing the assumptions. + + See Also + ======== + + print_tree + + """ + subtrees = [] + for arg in node.args: + subtrees.append(tree(arg, assumptions=assumptions)) + s = print_node(node, assumptions=assumptions) + pprint_nodes(subtrees) + return s + + +def print_tree(node, assumptions=True): + """ + Prints a tree representation of "node". + + Parameters + ========== + + assumptions : bool, optional + The flag to decide whether to print out all the assumption data + (such as ``is_integer`, ``is_real``) associated with the + expression or not. + + Enabling the flag makes the result verbose, and the printed + result may not be deterministic because of the randomness used + in backtracing the assumptions. + + Examples + ======== + + >>> from sympy.printing import print_tree + >>> from sympy import Symbol + >>> x = Symbol('x', odd=True) + >>> y = Symbol('y', even=True) + + Printing with full assumptions information: + + >>> print_tree(y**x) + Pow: y**x + +-Symbol: y + | algebraic: True + | commutative: True + | complex: True + | even: True + | extended_real: True + | finite: True + | hermitian: True + | imaginary: False + | infinite: False + | integer: True + | irrational: False + | noninteger: False + | odd: False + | rational: True + | real: True + | transcendental: False + +-Symbol: x + algebraic: True + commutative: True + complex: True + even: False + extended_nonzero: True + extended_real: True + finite: True + hermitian: True + imaginary: False + infinite: False + integer: True + irrational: False + noninteger: False + nonzero: True + odd: True + rational: True + real: True + transcendental: False + zero: False + + Hiding the assumptions: + + >>> print_tree(y**x, assumptions=False) + Pow: y**x + +-Symbol: y + +-Symbol: x + + See Also + ======== + + tree + + """ + print(tree(node, assumptions=assumptions)) diff --git a/phivenv/Lib/site-packages/sympy/sandbox/__init__.py b/phivenv/Lib/site-packages/sympy/sandbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a84b7517819bb2fc9886274e09d955a74cabca1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sandbox/__init__.py @@ -0,0 +1,8 @@ +""" +Sandbox module of SymPy. + +This module contains experimental code, use at your own risk! + +There is no warranty that this code will still be located here in future +versions of SymPy. +""" diff --git a/phivenv/Lib/site-packages/sympy/sandbox/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sandbox/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f013ad718f640f69d5f7a4d7f033490c0e49b4c8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sandbox/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sandbox/__pycache__/indexed_integrals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sandbox/__pycache__/indexed_integrals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f3ceb9139da29d88f4b97a180af24c60ebe7693 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sandbox/__pycache__/indexed_integrals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sandbox/indexed_integrals.py b/phivenv/Lib/site-packages/sympy/sandbox/indexed_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c17d141448b5a71cb814bff76a710a5bd43f88 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sandbox/indexed_integrals.py @@ -0,0 +1,72 @@ +from sympy.tensor import Indexed +from sympy.core.containers import Tuple +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.integrals.integrals import Integral + + +class IndexedIntegral(Integral): + """ + Experimental class to test integration by indexed variables. + + Usage is analogue to ``Integral``, it simply adds awareness of + integration over indices. + + Contraction of non-identical index symbols referring to the same + ``IndexedBase`` is not yet supported. + + Examples + ======== + + >>> from sympy.sandbox.indexed_integrals import IndexedIntegral + >>> from sympy import IndexedBase, symbols + >>> A = IndexedBase('A') + >>> i, j = symbols('i j', integer=True) + >>> ii = IndexedIntegral(A[i], A[i]) + >>> ii + Integral(_A[i], _A[i]) + >>> ii.doit() + A[i]**2/2 + + If the indices are different, indexed objects are considered to be + different variables: + + >>> i2 = IndexedIntegral(A[j], A[i]) + >>> i2 + Integral(A[j], _A[i]) + >>> i2.doit() + A[i]*A[j] + """ + + def __new__(cls, function, *limits, **assumptions): + repl, limits = IndexedIntegral._indexed_process_limits(limits) + function = sympify(function) + function = function.xreplace(repl) + obj = Integral.__new__(cls, function, *limits, **assumptions) + obj._indexed_repl = repl + obj._indexed_reverse_repl = {val: key for key, val in repl.items()} + return obj + + def doit(self): + res = super().doit() + return res.xreplace(self._indexed_reverse_repl) + + @staticmethod + def _indexed_process_limits(limits): + repl = {} + newlimits = [] + for i in limits: + if isinstance(i, (tuple, list, Tuple)): + v = i[0] + vrest = i[1:] + else: + v = i + vrest = () + if isinstance(v, Indexed): + if v not in repl: + r = Dummy(str(v)) + repl[v] = r + newlimits.append((r,)+vrest) + else: + newlimits.append(i) + return repl, newlimits diff --git a/phivenv/Lib/site-packages/sympy/sandbox/tests/__init__.py b/phivenv/Lib/site-packages/sympy/sandbox/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/sandbox/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sandbox/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97fadc96e331095012b4aa9f858e754c75e71012 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sandbox/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c81f54546e24973c62fb7ca8dc053ffac1ae5959 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sandbox/tests/test_indexed_integrals.py b/phivenv/Lib/site-packages/sympy/sandbox/tests/test_indexed_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..61b98f0ffec29e026f6dfe8e16fde8b5818b0b09 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sandbox/tests/test_indexed_integrals.py @@ -0,0 +1,25 @@ +from sympy.sandbox.indexed_integrals import IndexedIntegral +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.tensor.indexed import (Idx, IndexedBase) + + +def test_indexed_integrals(): + A = IndexedBase('A') + i, j = symbols('i j', integer=True) + a1, a2 = symbols('a1:3', cls=Idx) + assert isinstance(a1, Idx) + + assert IndexedIntegral(1, A[i]).doit() == A[i] + assert IndexedIntegral(A[i], A[i]).doit() == A[i] ** 2 / 2 + assert IndexedIntegral(A[j], A[i]).doit() == A[i] * A[j] + assert IndexedIntegral(A[i] * A[j], A[i]).doit() == A[i] ** 2 * A[j] / 2 + assert IndexedIntegral(sin(A[i]), A[i]).doit() == -cos(A[i]) + assert IndexedIntegral(sin(A[j]), A[i]).doit() == sin(A[j]) * A[i] + + assert IndexedIntegral(1, A[a1]).doit() == A[a1] + assert IndexedIntegral(A[a1], A[a1]).doit() == A[a1] ** 2 / 2 + assert IndexedIntegral(A[a2], A[a1]).doit() == A[a1] * A[a2] + assert IndexedIntegral(A[a1] * A[a2], A[a1]).doit() == A[a1] ** 2 * A[a2] / 2 + assert IndexedIntegral(sin(A[a1]), A[a1]).doit() == -cos(A[a1]) + assert IndexedIntegral(sin(A[a2]), A[a1]).doit() == sin(A[a2]) * A[a1] diff --git a/phivenv/Lib/site-packages/sympy/series/__init__.py b/phivenv/Lib/site-packages/sympy/series/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..953653e21856b82bc0b708ccd922efb728a084ed --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/__init__.py @@ -0,0 +1,23 @@ +"""A module that handles series: find a limit, order the series etc. +""" +from .order import Order +from .limits import limit, Limit +from .gruntz import gruntz +from .series import series +from .approximants import approximants +from .residues import residue +from .sequences import SeqPer, SeqFormula, sequence, SeqAdd, SeqMul +from .fourier import fourier_series +from .formal import fps +from .limitseq import difference_delta, limit_seq + +from sympy.core.singleton import S +EmptySequence = S.EmptySequence + +O = Order + +__all__ = ['Order', 'O', 'limit', 'Limit', 'gruntz', 'series', 'approximants', + 'residue', 'EmptySequence', 'SeqPer', 'SeqFormula', 'sequence', + 'SeqAdd', 'SeqMul', 'fourier_series', 'fps', 'difference_delta', + 'limit_seq' + ] diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b9ba91db94180e1f310196a4ebc096eec0fa42 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/acceleration.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/acceleration.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bcca29537bf824d51ccefa8f1220b4412b0ddfe Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/acceleration.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/approximants.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/approximants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ff5d8c12acaa2bd49383ceb64c532982460f621 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/approximants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/aseries.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/aseries.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..debf2df70ae7c7a303228112adbeb81e63754792 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/aseries.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/formal.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/formal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af6c1ec144291538b1b090723b1ad189b726debf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/formal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/fourier.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/fourier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d2fc10b16d522316ce6597210ec606c82a7231 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/fourier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/gruntz.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/gruntz.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..904586c13c64348bd2208726e28e0a7195f8264f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/gruntz.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/kauers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/kauers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b99c457c18d346a879238621b57fb4bbdfb6525d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/kauers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/limits.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/limits.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f66aeb4b8fd4aa472de304f1b13bb975eacff44 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/limits.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/limitseq.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/limitseq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9623b5ae200622da8a49ec53a6861eec344796f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/limitseq.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/order.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/order.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd3a36222b68f31021bf940dd7d83f7ec90219f2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/order.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/residues.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/residues.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b07d908344e08621e80c8a3879c7166cec0bd1a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/residues.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/sequences.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/sequences.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..502273edb3f55627b7c2555c321fb37ae012561d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/sequences.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/series.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/series.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6761a086aeef663f0c8a6cbe86ea3e5b83ee5ce7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/series.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/__pycache__/series_class.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/__pycache__/series_class.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a539aa9b648f366c95f932f536b7547014bd45d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/__pycache__/series_class.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/acceleration.py b/phivenv/Lib/site-packages/sympy/series/acceleration.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c7c1629a4b0d52e2aa33bd415886dfed515693 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/acceleration.py @@ -0,0 +1,101 @@ +""" +Convergence acceleration / extrapolation methods for series and +sequences. + +References: +Carl M. Bender & Steven A. Orszag, "Advanced Mathematical Methods for +Scientists and Engineers: Asymptotic Methods and Perturbation Theory", +Springer 1999. (Shanks transformation: pp. 368-375, Richardson +extrapolation: pp. 375-377.) +""" + +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.functions.combinatorial.factorials import factorial + + +def richardson(A, k, n, N): + """ + Calculate an approximation for lim k->oo A(k) using Richardson + extrapolation with the terms A(n), A(n+1), ..., A(n+N+1). + Choosing N ~= 2*n often gives good results. + + Examples + ======== + + A simple example is to calculate exp(1) using the limit definition. + This limit converges slowly; n = 100 only produces two accurate + digits: + + >>> from sympy.abc import n + >>> e = (1 + 1/n)**n + >>> print(round(e.subs(n, 100).evalf(), 10)) + 2.7048138294 + + Richardson extrapolation with 11 appropriately chosen terms gives + a value that is accurate to the indicated precision: + + >>> from sympy import E + >>> from sympy.series.acceleration import richardson + >>> print(round(richardson(e, n, 10, 20).evalf(), 10)) + 2.7182818285 + >>> print(round(E.evalf(), 10)) + 2.7182818285 + + Another useful application is to speed up convergence of series. + Computing 100 terms of the zeta(2) series 1/k**2 yields only + two accurate digits: + + >>> from sympy.abc import k, n + >>> from sympy import Sum + >>> A = Sum(k**-2, (k, 1, n)) + >>> print(round(A.subs(n, 100).evalf(), 10)) + 1.6349839002 + + Richardson extrapolation performs much better: + + >>> from sympy import pi + >>> print(round(richardson(A, n, 10, 20).evalf(), 10)) + 1.6449340668 + >>> print(round(((pi**2)/6).evalf(), 10)) # Exact value + 1.6449340668 + + """ + s = S.Zero + for j in range(0, N + 1): + s += (A.subs(k, Integer(n + j)).doit() * (n + j)**N * + S.NegativeOne**(j + N) / (factorial(j) * factorial(N - j))) + return s + + +def shanks(A, k, n, m=1): + """ + Calculate an approximation for lim k->oo A(k) using the n-term Shanks + transformation S(A)(n). With m > 1, calculate the m-fold recursive + Shanks transformation S(S(...S(A)...))(n). + + The Shanks transformation is useful for summing Taylor series that + converge slowly near a pole or singularity, e.g. for log(2): + + >>> from sympy.abc import k, n + >>> from sympy import Sum, Integer + >>> from sympy.series.acceleration import shanks + >>> A = Sum(Integer(-1)**(k+1) / k, (k, 1, n)) + >>> print(round(A.subs(n, 100).doit().evalf(), 10)) + 0.6881721793 + >>> print(round(shanks(A, n, 25).evalf(), 10)) + 0.6931396564 + >>> print(round(shanks(A, n, 25, 5).evalf(), 10)) + 0.6931471806 + + The correct value is 0.6931471805599453094172321215. + """ + table = [A.subs(k, Integer(j)).doit() for j in range(n + m + 2)] + table2 = table.copy() + + for i in range(1, m + 1): + for j in range(i, n + m + 1): + x, y, z = table[j - 1], table[j], table[j + 1] + table2[j] = (z*x - y**2) / (z + x - 2*y) + table = table2.copy() + return table[n] diff --git a/phivenv/Lib/site-packages/sympy/series/approximants.py b/phivenv/Lib/site-packages/sympy/series/approximants.py new file mode 100644 index 0000000000000000000000000000000000000000..3d54ce41bc7367606ae6260f8e9ac00149cedc0f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/approximants.py @@ -0,0 +1,103 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.polys.polytools import lcm +from sympy.utilities import public + +@public +def approximants(l, X=Symbol('x'), simplify=False): + """ + Return a generator for consecutive Pade approximants for a series. + It can also be used for computing the rational generating function of a + series when possible, since the last approximant returned by the generator + will be the generating function (if any). + + Explanation + =========== + + The input list can contain more complex expressions than integer or rational + numbers; symbols may also be involved in the computation. An example below + show how to compute the generating function of the whole Pascal triangle. + + The generator can be asked to apply the sympy.simplify function on each + generated term, which will make the computation slower; however it may be + useful when symbols are involved in the expressions. + + Examples + ======== + + >>> from sympy.series import approximants + >>> from sympy import lucas, fibonacci, symbols, binomial + >>> g = [lucas(k) for k in range(16)] + >>> [e for e in approximants(g)] + [2, -4/(x - 2), (5*x - 2)/(3*x - 1), (x - 2)/(x**2 + x - 1)] + + >>> h = [fibonacci(k) for k in range(16)] + >>> [e for e in approximants(h)] + [x, -x/(x - 1), (x**2 - x)/(2*x - 1), -x/(x**2 + x - 1)] + + >>> x, t = symbols("x,t") + >>> p=[sum(binomial(k,i)*x**i for i in range(k+1)) for k in range(16)] + >>> y = approximants(p, t) + >>> for k in range(3): print(next(y)) + 1 + (x + 1)/((-x - 1)*(t*(x + 1) + (x + 1)/(-x - 1))) + nan + + >>> y = approximants(p, t, simplify=True) + >>> for k in range(3): print(next(y)) + 1 + -1/(t*(x + 1) - 1) + nan + + See Also + ======== + + sympy.concrete.guess.guess_generating_function_rational + mpmath.pade + """ + from sympy.simplify import simplify as simp + from sympy.simplify.radsimp import denom + p1, q1 = [S.One], [S.Zero] + p2, q2 = [S.Zero], [S.One] + while len(l): + b = 0 + while l[b]==0: + b += 1 + if b == len(l): + return + m = [S.One/l[b]] + for k in range(b+1, len(l)): + s = 0 + for j in range(b, k): + s -= l[j+1] * m[b-j-1] + m.append(s/l[b]) + l = m + a, l[0] = l[0], 0 + p = [0] * max(len(p2), b+len(p1)) + q = [0] * max(len(q2), b+len(q1)) + for k in range(len(p2)): + p[k] = a*p2[k] + for k in range(b, b+len(p1)): + p[k] += p1[k-b] + for k in range(len(q2)): + q[k] = a*q2[k] + for k in range(b, b+len(q1)): + q[k] += q1[k-b] + while p[-1]==0: p.pop() + while q[-1]==0: q.pop() + p1, p2 = p2, p + q1, q2 = q2, q + + # yield result + c = 1 + for x in p: + c = lcm(c, denom(x)) + for x in q: + c = lcm(c, denom(x)) + out = ( sum(c*e*X**k for k, e in enumerate(p)) + / sum(c*e*X**k for k, e in enumerate(q)) ) + if simplify: + yield(simp(out)) + else: + yield out + return diff --git a/phivenv/Lib/site-packages/sympy/series/aseries.py b/phivenv/Lib/site-packages/sympy/series/aseries.py new file mode 100644 index 0000000000000000000000000000000000000000..dbbe0664e6d43a9329f37789c16c48143eda5413 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/aseries.py @@ -0,0 +1,10 @@ +from sympy.core.sympify import sympify + + +def aseries(expr, x=None, n=6, bound=0, hir=False): + """ + See the docstring of Expr.aseries() for complete details of this wrapper. + + """ + expr = sympify(expr) + return expr.aseries(x, n, bound, hir) diff --git a/phivenv/Lib/site-packages/sympy/series/benchmarks/__init__.py b/phivenv/Lib/site-packages/sympy/series/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f55651c578cdfde70ab26a5b12b2a8198dbe39e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef48969e4107bf4343377f97742451a20f24a175 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/bench_order.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/bench_order.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51106667a3d62189b793a669b9d384802591b127 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/benchmarks/__pycache__/bench_order.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/benchmarks/bench_limit.py b/phivenv/Lib/site-packages/sympy/series/benchmarks/bench_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..eafc28328848dad4b3ea433537971f5785253afe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/benchmarks/bench_limit.py @@ -0,0 +1,9 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.series.limits import limit + +x = Symbol('x') + + +def timeit_limit_1x(): + limit(1/x, x, oo) diff --git a/phivenv/Lib/site-packages/sympy/series/benchmarks/bench_order.py b/phivenv/Lib/site-packages/sympy/series/benchmarks/bench_order.py new file mode 100644 index 0000000000000000000000000000000000000000..1c85fa173dfc2a478792de8ab816c23ba9d408ef --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/benchmarks/bench_order.py @@ -0,0 +1,10 @@ +from sympy.core.add import Add +from sympy.core.symbol import Symbol +from sympy.series.order import O + +x = Symbol('x') +l = [x**i for i in range(1000)] +l.append(O(x**1001)) + +def timeit_order_1x(): + Add(*l) diff --git a/phivenv/Lib/site-packages/sympy/series/formal.py b/phivenv/Lib/site-packages/sympy/series/formal.py new file mode 100644 index 0000000000000000000000000000000000000000..ada591e03dd607daf8f8a5da1cf38cf283a3ed86 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/formal.py @@ -0,0 +1,1863 @@ +"""Formal Power Series""" + +from collections import defaultdict + +from sympy.core.numbers import (nan, oo, zoo) +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import Derivative, Function, expand +from sympy.core.mul import Mul +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.sets.sets import Interval +from sympy.core.singleton import S +from sympy.core.symbol import Wild, Dummy, symbols, Symbol +from sympy.core.sympify import sympify +from sympy.discrete.convolutions import convolution +from sympy.functions.combinatorial.factorials import binomial, factorial, rf +from sympy.functions.combinatorial.numbers import bell +from sympy.functions.elementary.integers import floor, frac, ceiling +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.series.limits import Limit +from sympy.series.order import Order +from sympy.series.sequences import sequence +from sympy.series.series_class import SeriesBase +from sympy.utilities.iterables import iterable + + + +def rational_algorithm(f, x, k, order=4, full=False): + """ + Rational algorithm for computing + formula of coefficients of Formal Power Series + of a function. + + Explanation + =========== + + Applicable when f(x) or some derivative of f(x) + is a rational function in x. + + :func:`rational_algorithm` uses :func:`~.apart` function for partial fraction + decomposition. :func:`~.apart` by default uses 'undetermined coefficients + method'. By setting ``full=True``, 'Bronstein's algorithm' can be used + instead. + + Looks for derivative of a function up to 4'th order (by default). + This can be overridden using order option. + + Parameters + ========== + + x : Symbol + order : int, optional + Order of the derivative of ``f``, Default is 4. + full : bool + + Returns + ======= + + formula : Expr + ind : Expr + Independent terms. + order : int + full : bool + + Examples + ======== + + >>> from sympy import log, atan + >>> from sympy.series.formal import rational_algorithm as ra + >>> from sympy.abc import x, k + + >>> ra(1 / (1 - x), x, k) + (1, 0, 0) + >>> ra(log(1 + x), x, k) + (-1/((-1)**k*k), 0, 1) + + >>> ra(atan(x), x, k, full=True) + ((-I/(2*(-I)**k) + I/(2*I**k))/k, 0, 1) + + Notes + ===== + + By setting ``full=True``, range of admissible functions to be solved using + ``rational_algorithm`` can be increased. This option should be used + carefully as it can significantly slow down the computation as ``doit`` is + performed on the :class:`~.RootSum` object returned by the :func:`~.apart` + function. Use ``full=False`` whenever possible. + + See Also + ======== + + sympy.polys.partfrac.apart + + References + ========== + + .. [1] Formal Power Series - Dominik Gruntz, Wolfram Koepf + .. [2] Power Series in Computer Algebra - Wolfram Koepf + + """ + from sympy.polys import RootSum, apart + from sympy.integrals import integrate + + diff = f + ds = [] # list of diff + + for i in range(order + 1): + if i: + diff = diff.diff(x) + + if diff.is_rational_function(x): + coeff, sep = S.Zero, S.Zero + + terms = apart(diff, x, full=full) + if terms.has(RootSum): + terms = terms.doit() + + for t in Add.make_args(terms): + num, den = t.as_numer_denom() + if not den.has(x): + sep += t + else: + if isinstance(den, Mul): + # m*(n*x - a)**j -> (n*x - a)**j + ind = den.as_independent(x) + den = ind[1] + num /= ind[0] + + # (n*x - a)**j -> (x - b) + den, j = den.as_base_exp() + a, xterm = den.as_coeff_add(x) + + # term -> m/x**n + if not a: + sep += t + continue + + xc = xterm[0].coeff(x) + a /= -xc + num /= xc**j + + ak = ((-1)**j * num * + binomial(j + k - 1, k).rewrite(factorial) / + a**(j + k)) + coeff += ak + + # Hacky, better way? + if coeff.is_zero: + return None + if (coeff.has(x) or coeff.has(zoo) or coeff.has(oo) or + coeff.has(nan)): + return None + + for j in range(i): + coeff = (coeff / (k + j + 1)) + sep = integrate(sep, x) + sep += (ds.pop() - sep).limit(x, 0) # constant of integration + return (coeff.subs(k, k - i), sep, i) + + else: + ds.append(diff) + + return None + + +def rational_independent(terms, x): + """ + Returns a list of all the rationally independent terms. + + Examples + ======== + + >>> from sympy import sin, cos + >>> from sympy.series.formal import rational_independent + >>> from sympy.abc import x + + >>> rational_independent([cos(x), sin(x)], x) + [cos(x), sin(x)] + >>> rational_independent([x**2, sin(x), x*sin(x), x**3], x) + [x**3 + x**2, x*sin(x) + sin(x)] + """ + if not terms: + return [] + + ind = terms[0:1] + + for t in terms[1:]: + n = t.as_independent(x)[1] + for i, term in enumerate(ind): + d = term.as_independent(x)[1] + q = (n / d).cancel() + if q.is_rational_function(x): + ind[i] += t + break + else: + ind.append(t) + return ind + + +def simpleDE(f, x, g, order=4): + r""" + Generates simple DE. + + Explanation + =========== + + DE is of the form + + .. math:: + f^k(x) + \sum\limits_{j=0}^{k-1} A_j f^j(x) = 0 + + where :math:`A_j` should be rational function in x. + + Generates DE's upto order 4 (default). DE's can also have free parameters. + + By increasing order, higher order DE's can be found. + + Yields a tuple of (DE, order). + """ + from sympy.solvers.solveset import linsolve + + a = symbols('a:%d' % (order)) + + def _makeDE(k): + eq = f.diff(x, k) + Add(*[a[i]*f.diff(x, i) for i in range(0, k)]) + DE = g(x).diff(x, k) + Add(*[a[i]*g(x).diff(x, i) for i in range(0, k)]) + return eq, DE + + found = False + for k in range(1, order + 1): + eq, DE = _makeDE(k) + eq = eq.expand() + terms = eq.as_ordered_terms() + ind = rational_independent(terms, x) + if found or len(ind) == k: + sol = dict(zip(a, (i for s in linsolve(ind, a[:k]) for i in s))) + if sol: + found = True + DE = DE.subs(sol) + DE = DE.as_numer_denom()[0] + DE = DE.factor().as_coeff_mul(Derivative)[1][0] + yield DE.collect(Derivative(g(x))), k + + +def exp_re(DE, r, k): + """Converts a DE with constant coefficients (explike) into a RE. + + Explanation + =========== + + Performs the substitution: + + .. math:: + f^j(x) \\to r(k + j) + + Normalises the terms so that lowest order of a term is always r(k). + + Examples + ======== + + >>> from sympy import Function, Derivative + >>> from sympy.series.formal import exp_re + >>> from sympy.abc import x, k + >>> f, r = Function('f'), Function('r') + + >>> exp_re(-f(x) + Derivative(f(x)), r, k) + -r(k) + r(k + 1) + >>> exp_re(Derivative(f(x), x) + Derivative(f(x), (x, 2)), r, k) + r(k) + r(k + 1) + + See Also + ======== + + sympy.series.formal.hyper_re + """ + RE = S.Zero + + g = DE.atoms(Function).pop() + + mini = None + for t in Add.make_args(DE): + coeff, d = t.as_independent(g) + if isinstance(d, Derivative): + j = d.derivative_count + else: + j = 0 + if mini is None or j < mini: + mini = j + RE += coeff * r(k + j) + if mini: + RE = RE.subs(k, k - mini) + return RE + + +def hyper_re(DE, r, k): + """ + Converts a DE into a RE. + + Explanation + =========== + + Performs the substitution: + + .. math:: + x^l f^j(x) \\to (k + 1 - l)_j . a_{k + j - l} + + Normalises the terms so that lowest order of a term is always r(k). + + Examples + ======== + + >>> from sympy import Function, Derivative + >>> from sympy.series.formal import hyper_re + >>> from sympy.abc import x, k + >>> f, r = Function('f'), Function('r') + + >>> hyper_re(-f(x) + Derivative(f(x)), r, k) + (k + 1)*r(k + 1) - r(k) + >>> hyper_re(-x*f(x) + Derivative(f(x), (x, 2)), r, k) + (k + 2)*(k + 3)*r(k + 3) - r(k) + + See Also + ======== + + sympy.series.formal.exp_re + """ + RE = S.Zero + + g = DE.atoms(Function).pop() + x = g.atoms(Symbol).pop() + + mini = None + for t in Add.make_args(DE.expand()): + coeff, d = t.as_independent(g) + c, v = coeff.as_independent(x) + l = v.as_coeff_exponent(x)[1] + if isinstance(d, Derivative): + j = d.derivative_count + else: + j = 0 + RE += c * rf(k + 1 - l, j) * r(k + j - l) + if mini is None or j - l < mini: + mini = j - l + + RE = RE.subs(k, k - mini) + + m = Wild('m') + return RE.collect(r(k + m)) + + +def _transformation_a(f, x, P, Q, k, m, shift): + f *= x**(-shift) + P = P.subs(k, k + shift) + Q = Q.subs(k, k + shift) + return f, P, Q, m + + +def _transformation_c(f, x, P, Q, k, m, scale): + f = f.subs(x, x**scale) + P = P.subs(k, k / scale) + Q = Q.subs(k, k / scale) + m *= scale + return f, P, Q, m + + +def _transformation_e(f, x, P, Q, k, m): + f = f.diff(x) + P = P.subs(k, k + 1) * (k + m + 1) + Q = Q.subs(k, k + 1) * (k + 1) + return f, P, Q, m + + +def _apply_shift(sol, shift): + return [(res, cond + shift) for res, cond in sol] + + +def _apply_scale(sol, scale): + return [(res, cond / scale) for res, cond in sol] + + +def _apply_integrate(sol, x, k): + return [(res / ((cond + 1)*(cond.as_coeff_Add()[1].coeff(k))), cond + 1) + for res, cond in sol] + + +def _compute_formula(f, x, P, Q, k, m, k_max): + """Computes the formula for f.""" + from sympy.polys import roots + + sol = [] + for i in range(k_max + 1, k_max + m + 1): + if (i < 0) == True: + continue + r = f.diff(x, i).limit(x, 0) / factorial(i) + if r.is_zero: + continue + + kterm = m*k + i + res = r + + p = P.subs(k, kterm) + q = Q.subs(k, kterm) + c1 = p.subs(k, 1/k).leadterm(k)[0] + c2 = q.subs(k, 1/k).leadterm(k)[0] + res *= (-c1 / c2)**k + + res *= Mul(*[rf(-r, k)**mul for r, mul in roots(p, k).items()]) + res /= Mul(*[rf(-r, k)**mul for r, mul in roots(q, k).items()]) + + sol.append((res, kterm)) + + return sol + + +def _rsolve_hypergeometric(f, x, P, Q, k, m): + """ + Recursive wrapper to rsolve_hypergeometric. + + Explanation + =========== + + Returns a Tuple of (formula, series independent terms, + maximum power of x in independent terms) if successful + otherwise ``None``. + + See :func:`rsolve_hypergeometric` for details. + """ + from sympy.polys import lcm, roots + from sympy.integrals import integrate + + # transformation - c + proots, qroots = roots(P, k), roots(Q, k) + all_roots = dict(proots) + all_roots.update(qroots) + scale = lcm([r.as_numer_denom()[1] for r, t in all_roots.items() + if r.is_rational]) + f, P, Q, m = _transformation_c(f, x, P, Q, k, m, scale) + + # transformation - a + qroots = roots(Q, k) + if qroots: + k_min = Min(*qroots.keys()) + else: + k_min = S.Zero + shift = k_min + m + f, P, Q, m = _transformation_a(f, x, P, Q, k, m, shift) + + l = (x*f).limit(x, 0) + if not isinstance(l, Limit) and l != 0: # Ideally should only be l != 0 + return None + + qroots = roots(Q, k) + if qroots: + k_max = Max(*qroots.keys()) + else: + k_max = S.Zero + + ind, mp = S.Zero, -oo + for i in range(k_max + m + 1): + r = f.diff(x, i).limit(x, 0) / factorial(i) + if r.is_finite is False: + old_f = f + f, P, Q, m = _transformation_a(f, x, P, Q, k, m, i) + f, P, Q, m = _transformation_e(f, x, P, Q, k, m) + sol, ind, mp = _rsolve_hypergeometric(f, x, P, Q, k, m) + sol = _apply_integrate(sol, x, k) + sol = _apply_shift(sol, i) + ind = integrate(ind, x) + ind += (old_f - ind).limit(x, 0) # constant of integration + mp += 1 + return sol, ind, mp + elif r: + ind += r*x**(i + shift) + pow_x = Rational((i + shift), scale) + if pow_x > mp: + mp = pow_x # maximum power of x + ind = ind.subs(x, x**(1/scale)) + + sol = _compute_formula(f, x, P, Q, k, m, k_max) + sol = _apply_shift(sol, shift) + sol = _apply_scale(sol, scale) + + return sol, ind, mp + + +def rsolve_hypergeometric(f, x, P, Q, k, m): + """ + Solves RE of hypergeometric type. + + Explanation + =========== + + Attempts to solve RE of the form + + Q(k)*a(k + m) - P(k)*a(k) + + Transformations that preserve Hypergeometric type: + + a. x**n*f(x): b(k + m) = R(k - n)*b(k) + b. f(A*x): b(k + m) = A**m*R(k)*b(k) + c. f(x**n): b(k + n*m) = R(k/n)*b(k) + d. f(x**(1/m)): b(k + 1) = R(k*m)*b(k) + e. f'(x): b(k + m) = ((k + m + 1)/(k + 1))*R(k + 1)*b(k) + + Some of these transformations have been used to solve the RE. + + Returns + ======= + + formula : Expr + ind : Expr + Independent terms. + order : int + + Examples + ======== + + >>> from sympy import exp, ln, S + >>> from sympy.series.formal import rsolve_hypergeometric as rh + >>> from sympy.abc import x, k + + >>> rh(exp(x), x, -S.One, (k + 1), k, 1) + (Piecewise((1/factorial(k), Eq(Mod(k, 1), 0)), (0, True)), 1, 1) + + >>> rh(ln(1 + x), x, k**2, k*(k + 1), k, 1) + (Piecewise(((-1)**(k - 1)*factorial(k - 1)/RisingFactorial(2, k - 1), + Eq(Mod(k, 1), 0)), (0, True)), x, 2) + + References + ========== + + .. [1] Formal Power Series - Dominik Gruntz, Wolfram Koepf + .. [2] Power Series in Computer Algebra - Wolfram Koepf + """ + result = _rsolve_hypergeometric(f, x, P, Q, k, m) + + if result is None: + return None + + sol_list, ind, mp = result + + sol_dict = defaultdict(lambda: S.Zero) + for res, cond in sol_list: + j, mk = cond.as_coeff_Add() + c = mk.coeff(k) + + if j.is_integer is False: + res *= x**frac(j) + j = floor(j) + + res = res.subs(k, (k - j) / c) + cond = Eq(k % c, j % c) + sol_dict[cond] += res # Group together formula for same conditions + + sol = [(res, cond) for cond, res in sol_dict.items()] + sol.append((S.Zero, True)) + sol = Piecewise(*sol) + + if mp is -oo: + s = S.Zero + elif mp.is_integer is False: + s = ceiling(mp) + else: + s = mp + 1 + + # save all the terms of + # form 1/x**k in ind + if s < 0: + ind += sum(sequence(sol * x**k, (k, s, -1))) + s = S.Zero + + return (sol, ind, s) + + +def _solve_hyper_RE(f, x, RE, g, k): + """See docstring of :func:`rsolve_hypergeometric` for details.""" + terms = Add.make_args(RE) + + if len(terms) == 2: + gs = list(RE.atoms(Function)) + P, Q = map(RE.coeff, gs) + m = gs[1].args[0] - gs[0].args[0] + if m < 0: + P, Q = Q, P + m = abs(m) + return rsolve_hypergeometric(f, x, P, Q, k, m) + + +def _solve_explike_DE(f, x, DE, g, k): + """Solves DE with constant coefficients.""" + from sympy.solvers import rsolve + + for t in Add.make_args(DE): + coeff, d = t.as_independent(g) + if coeff.free_symbols: + return + + RE = exp_re(DE, g, k) + + init = {} + for i in range(len(Add.make_args(RE))): + if i: + f = f.diff(x) + init[g(k).subs(k, i)] = f.limit(x, 0) + + sol = rsolve(RE, g(k), init) + + if sol: + return (sol / factorial(k), S.Zero, S.Zero) + + +def _solve_simple(f, x, DE, g, k): + """Converts DE into RE and solves using :func:`rsolve`.""" + from sympy.solvers import rsolve + + RE = hyper_re(DE, g, k) + + init = {} + for i in range(len(Add.make_args(RE))): + if i: + f = f.diff(x) + init[g(k).subs(k, i)] = f.limit(x, 0) / factorial(i) + + sol = rsolve(RE, g(k), init) + + if sol: + return (sol, S.Zero, S.Zero) + + +def _transform_explike_DE(DE, g, x, order, syms): + """Converts DE with free parameters into DE with constant coefficients.""" + from sympy.solvers.solveset import linsolve + + eq = [] + highest_coeff = DE.coeff(Derivative(g(x), x, order)) + for i in range(order): + coeff = DE.coeff(Derivative(g(x), x, i)) + coeff = (coeff / highest_coeff).expand().collect(x) + eq.extend(Add.make_args(coeff)) + temp = [] + for e in eq: + if e.has(x): + break + elif e.has(Symbol): + temp.append(e) + else: + eq = temp + if eq: + sol = dict(zip(syms, (i for s in linsolve(eq, list(syms)) for i in s))) + if sol: + DE = DE.subs(sol) + DE = DE.factor().as_coeff_mul(Derivative)[1][0] + DE = DE.collect(Derivative(g(x))) + return DE + + +def _transform_DE_RE(DE, g, k, order, syms): + """Converts DE with free parameters into RE of hypergeometric type.""" + from sympy.solvers.solveset import linsolve + + RE = hyper_re(DE, g, k) + + eq = [RE.coeff(g(k + i)) for i in range(1, order)] + sol = dict(zip(syms, (i for s in linsolve(eq, list(syms)) for i in s))) + if sol: + m = Wild('m') + RE = RE.subs(sol) + RE = RE.factor().as_numer_denom()[0].collect(g(k + m)) + RE = RE.as_coeff_mul(g)[1][0] + for i in range(order): # smallest order should be g(k) + if RE.coeff(g(k + i)) and i: + RE = RE.subs(k, k - i) + break + return RE + + +def solve_de(f, x, DE, order, g, k): + """ + Solves the DE. + + Explanation + =========== + + Tries to solve DE by either converting into a RE containing two terms or + converting into a DE having constant coefficients. + + Returns + ======= + + formula : Expr + ind : Expr + Independent terms. + order : int + + Examples + ======== + + >>> from sympy import Derivative as D, Function + >>> from sympy import exp, ln + >>> from sympy.series.formal import solve_de + >>> from sympy.abc import x, k + >>> f = Function('f') + + >>> solve_de(exp(x), x, D(f(x), x) - f(x), 1, f, k) + (Piecewise((1/factorial(k), Eq(Mod(k, 1), 0)), (0, True)), 1, 1) + + >>> solve_de(ln(1 + x), x, (x + 1)*D(f(x), x, 2) + D(f(x)), 2, f, k) + (Piecewise(((-1)**(k - 1)*factorial(k - 1)/RisingFactorial(2, k - 1), + Eq(Mod(k, 1), 0)), (0, True)), x, 2) + """ + sol = None + syms = DE.free_symbols.difference({g, x}) + + if syms: + RE = _transform_DE_RE(DE, g, k, order, syms) + else: + RE = hyper_re(DE, g, k) + if not RE.free_symbols.difference({k}): + sol = _solve_hyper_RE(f, x, RE, g, k) + + if sol: + return sol + + if syms: + DE = _transform_explike_DE(DE, g, x, order, syms) + if not DE.free_symbols.difference({x}): + sol = _solve_explike_DE(f, x, DE, g, k) + + if sol: + return sol + + +def hyper_algorithm(f, x, k, order=4): + """ + Hypergeometric algorithm for computing Formal Power Series. + + Explanation + =========== + + Steps: + * Generates DE + * Convert the DE into RE + * Solves the RE + + Examples + ======== + + >>> from sympy import exp, ln + >>> from sympy.series.formal import hyper_algorithm + + >>> from sympy.abc import x, k + + >>> hyper_algorithm(exp(x), x, k) + (Piecewise((1/factorial(k), Eq(Mod(k, 1), 0)), (0, True)), 1, 1) + + >>> hyper_algorithm(ln(1 + x), x, k) + (Piecewise(((-1)**(k - 1)*factorial(k - 1)/RisingFactorial(2, k - 1), + Eq(Mod(k, 1), 0)), (0, True)), x, 2) + + See Also + ======== + + sympy.series.formal.simpleDE + sympy.series.formal.solve_de + """ + g = Function('g') + + des = [] # list of DE's + sol = None + for DE, i in simpleDE(f, x, g, order): + if DE is not None: + sol = solve_de(f, x, DE, i, g, k) + if sol: + return sol + if not DE.free_symbols.difference({x}): + des.append(DE) + + # If nothing works + # Try plain rsolve + for DE in des: + sol = _solve_simple(f, x, DE, g, k) + if sol: + return sol + + +def _compute_fps(f, x, x0, dir, hyper, order, rational, full): + """Recursive wrapper to compute fps. + + See :func:`compute_fps` for details. + """ + if x0 in [S.Infinity, S.NegativeInfinity]: + dir = S.One if x0 is S.Infinity else -S.One + temp = f.subs(x, 1/x) + result = _compute_fps(temp, x, 0, dir, hyper, order, rational, full) + if result is None: + return None + return (result[0], result[1].subs(x, 1/x), result[2].subs(x, 1/x)) + elif x0 or dir == -S.One: + if dir == -S.One: + rep = -x + x0 + rep2 = -x + rep2b = x0 + else: + rep = x + x0 + rep2 = x + rep2b = -x0 + temp = f.subs(x, rep) + result = _compute_fps(temp, x, 0, S.One, hyper, order, rational, full) + if result is None: + return None + return (result[0], result[1].subs(x, rep2 + rep2b), + result[2].subs(x, rep2 + rep2b)) + + if f.is_polynomial(x): + k = Dummy('k') + ak = sequence(Coeff(f, x, k), (k, 1, oo)) + xk = sequence(x**k, (k, 0, oo)) + ind = f.coeff(x, 0) + return ak, xk, ind + + # Break instances of Add + # this allows application of different + # algorithms on different terms increasing the + # range of admissible functions. + if isinstance(f, Add): + result = False + ak = sequence(S.Zero, (0, oo)) + ind, xk = S.Zero, None + for t in Add.make_args(f): + res = _compute_fps(t, x, 0, S.One, hyper, order, rational, full) + if res: + if not result: + result = True + xk = res[1] + if res[0].start > ak.start: + seq = ak + s, f = ak.start, res[0].start + else: + seq = res[0] + s, f = res[0].start, ak.start + save = Add(*[z[0]*z[1] for z in zip(seq[0:(f - s)], xk[s:f])]) + ak += res[0] + ind += res[2] + save + else: + ind += t + if result: + return ak, xk, ind + return None + + # The symbolic term - symb, if present, is being separated from the function + # Otherwise symb is being set to S.One + syms = f.free_symbols.difference({x}) + (f, symb) = expand(f).as_independent(*syms) + + result = None + + # from here on it's x0=0 and dir=1 handling + k = Dummy('k') + if rational: + result = rational_algorithm(f, x, k, order, full) + + if result is None and hyper: + result = hyper_algorithm(f, x, k, order) + + if result is None: + return None + + from sympy.simplify.powsimp import powsimp + if symb.is_zero: + symb = S.One + else: + symb = powsimp(symb) + ak = sequence(result[0], (k, result[2], oo)) + xk_formula = powsimp(x**k * symb) + xk = sequence(xk_formula, (k, 0, oo)) + ind = powsimp(result[1] * symb) + + return ak, xk, ind + + +def compute_fps(f, x, x0=0, dir=1, hyper=True, order=4, rational=True, + full=False): + """ + Computes the formula for Formal Power Series of a function. + + Explanation + =========== + + Tries to compute the formula by applying the following techniques + (in order): + + * rational_algorithm + * Hypergeometric algorithm + + Parameters + ========== + + x : Symbol + x0 : number, optional + Point to perform series expansion about. Default is 0. + dir : {1, -1, '+', '-'}, optional + If dir is 1 or '+' the series is calculated from the right and + for -1 or '-' the series is calculated from the left. For smooth + functions this flag will not alter the results. Default is 1. + hyper : {True, False}, optional + Set hyper to False to skip the hypergeometric algorithm. + By default it is set to False. + order : int, optional + Order of the derivative of ``f``, Default is 4. + rational : {True, False}, optional + Set rational to False to skip rational algorithm. By default it is set + to True. + full : {True, False}, optional + Set full to True to increase the range of rational algorithm. + See :func:`rational_algorithm` for details. By default it is set to + False. + + Returns + ======= + + ak : sequence + Sequence of coefficients. + xk : sequence + Sequence of powers of x. + ind : Expr + Independent terms. + mul : Pow + Common terms. + + See Also + ======== + + sympy.series.formal.rational_algorithm + sympy.series.formal.hyper_algorithm + """ + f = sympify(f) + x = sympify(x) + + if not f.has(x): + return None + + x0 = sympify(x0) + + if dir == '+': + dir = S.One + elif dir == '-': + dir = -S.One + elif dir not in [S.One, -S.One]: + raise ValueError("Dir must be '+' or '-'") + else: + dir = sympify(dir) + + return _compute_fps(f, x, x0, dir, hyper, order, rational, full) + + +class Coeff(Function): + """ + Coeff(p, x, n) represents the nth coefficient of the polynomial p in x + """ + @classmethod + def eval(cls, p, x, n): + if p.is_polynomial(x) and n.is_integer: + return p.coeff(x, n) + + +class FormalPowerSeries(SeriesBase): + """ + Represents Formal Power Series of a function. + + Explanation + =========== + + No computation is performed. This class should only to be used to represent + a series. No checks are performed. + + For computing a series use :func:`fps`. + + See Also + ======== + + sympy.series.formal.fps + """ + def __new__(cls, *args): + args = map(sympify, args) + return Expr.__new__(cls, *args) + + def __init__(self, *args): + ak = args[4][0] + k = ak.variables[0] + self.ak_seq = sequence(ak.formula, (k, 1, oo)) + self.fact_seq = sequence(factorial(k), (k, 1, oo)) + self.bell_coeff_seq = self.ak_seq * self.fact_seq + self.sign_seq = sequence((-1, 1), (k, 1, oo)) + + @property + def function(self): + return self.args[0] + + @property + def x(self): + return self.args[1] + + @property + def x0(self): + return self.args[2] + + @property + def dir(self): + return self.args[3] + + @property + def ak(self): + return self.args[4][0] + + @property + def xk(self): + return self.args[4][1] + + @property + def ind(self): + return self.args[4][2] + + @property + def interval(self): + return Interval(0, oo) + + @property + def start(self): + return self.interval.inf + + @property + def stop(self): + return self.interval.sup + + @property + def length(self): + return oo + + @property + def infinite(self): + """Returns an infinite representation of the series""" + from sympy.concrete import Sum + ak, xk = self.ak, self.xk + k = ak.variables[0] + inf_sum = Sum(ak.formula * xk.formula, (k, ak.start, ak.stop)) + + return self.ind + inf_sum + + def _get_pow_x(self, term): + """Returns the power of x in a term.""" + xterm, pow_x = term.as_independent(self.x)[1].as_base_exp() + if not xterm.has(self.x): + return S.Zero + return pow_x + + def polynomial(self, n=6): + """ + Truncated series as polynomial. + + Explanation + =========== + + Returns series expansion of ``f`` upto order ``O(x**n)`` + as a polynomial(without ``O`` term). + """ + terms = [] + sym = self.free_symbols + for i, t in enumerate(self): + xp = self._get_pow_x(t) + if xp.has(*sym): + xp = xp.as_coeff_add(*sym)[0] + if xp >= n: + break + elif xp.is_integer is True and i == n + 1: + break + elif t is not S.Zero: + terms.append(t) + + return Add(*terms) + + def truncate(self, n=6): + """ + Truncated series. + + Explanation + =========== + + Returns truncated series expansion of f upto + order ``O(x**n)``. + + If n is ``None``, returns an infinite iterator. + """ + if n is None: + return iter(self) + + x, x0 = self.x, self.x0 + pt_xk = self.xk.coeff(n) + if x0 is S.NegativeInfinity: + x0 = S.Infinity + + return self.polynomial(n) + Order(pt_xk, (x, x0)) + + def zero_coeff(self): + return self._eval_term(0) + + def _eval_term(self, pt): + try: + pt_xk = self.xk.coeff(pt) + pt_ak = self.ak.coeff(pt).simplify() # Simplify the coefficients + except IndexError: + term = S.Zero + else: + term = (pt_ak * pt_xk) + + if self.ind: + ind = S.Zero + sym = self.free_symbols + for t in Add.make_args(self.ind): + pow_x = self._get_pow_x(t) + if pow_x.has(*sym): + pow_x = pow_x.as_coeff_add(*sym)[0] + if pt == 0 and pow_x < 1: + ind += t + elif pow_x >= pt and pow_x < pt + 1: + ind += t + term += ind + + return term.collect(self.x) + + def _eval_subs(self, old, new): + x = self.x + if old.has(x): + return self + + def _eval_as_leading_term(self, x, logx, cdir): + for t in self: + if t is not S.Zero: + return t + + def _eval_derivative(self, x): + f = self.function.diff(x) + ind = self.ind.diff(x) + + pow_xk = self._get_pow_x(self.xk.formula) + ak = self.ak + k = ak.variables[0] + if ak.formula.has(x): + form = [] + for e, c in ak.formula.args: + temp = S.Zero + for t in Add.make_args(e): + pow_x = self._get_pow_x(t) + temp += t * (pow_xk + pow_x) + form.append((temp, c)) + form = Piecewise(*form) + ak = sequence(form.subs(k, k + 1), (k, ak.start - 1, ak.stop)) + else: + ak = sequence((ak.formula * pow_xk).subs(k, k + 1), + (k, ak.start - 1, ak.stop)) + + return self.func(f, self.x, self.x0, self.dir, (ak, self.xk, ind)) + + def integrate(self, x=None, **kwargs): + """ + Integrate Formal Power Series. + + Examples + ======== + + >>> from sympy import fps, sin, integrate + >>> from sympy.abc import x + >>> f = fps(sin(x)) + >>> f.integrate(x).truncate() + -1 + x**2/2 - x**4/24 + O(x**6) + >>> integrate(f, (x, 0, 1)) + 1 - cos(1) + """ + from sympy.integrals import integrate + + if x is None: + x = self.x + elif iterable(x): + return integrate(self.function, x) + + f = integrate(self.function, x) + ind = integrate(self.ind, x) + ind += (f - ind).limit(x, 0) # constant of integration + + pow_xk = self._get_pow_x(self.xk.formula) + ak = self.ak + k = ak.variables[0] + if ak.formula.has(x): + form = [] + for e, c in ak.formula.args: + temp = S.Zero + for t in Add.make_args(e): + pow_x = self._get_pow_x(t) + temp += t / (pow_xk + pow_x + 1) + form.append((temp, c)) + form = Piecewise(*form) + ak = sequence(form.subs(k, k - 1), (k, ak.start + 1, ak.stop)) + else: + ak = sequence((ak.formula / (pow_xk + 1)).subs(k, k - 1), + (k, ak.start + 1, ak.stop)) + + return self.func(f, self.x, self.x0, self.dir, (ak, self.xk, ind)) + + def product(self, other, x=None, n=6): + """ + Multiplies two Formal Power Series, using discrete convolution and + return the truncated terms upto specified order. + + Parameters + ========== + + n : Number, optional + Specifies the order of the term up to which the polynomial should + be truncated. + + Examples + ======== + + >>> from sympy import fps, sin, exp + >>> from sympy.abc import x + >>> f1 = fps(sin(x)) + >>> f2 = fps(exp(x)) + + >>> f1.product(f2, x).truncate(4) + x + x**2 + x**3/3 + O(x**4) + + See Also + ======== + + sympy.discrete.convolutions + sympy.series.formal.FormalPowerSeriesProduct + + """ + + if n is None: + return iter(self) + + other = sympify(other) + + if not isinstance(other, FormalPowerSeries): + raise ValueError("Both series should be an instance of FormalPowerSeries" + " class.") + + if self.dir != other.dir: + raise ValueError("Both series should be calculated from the" + " same direction.") + elif self.x0 != other.x0: + raise ValueError("Both series should be calculated about the" + " same point.") + + elif self.x != other.x: + raise ValueError("Both series should have the same symbol.") + + return FormalPowerSeriesProduct(self, other) + + def coeff_bell(self, n): + r""" + self.coeff_bell(n) returns a sequence of Bell polynomials of the second kind. + Note that ``n`` should be a integer. + + The second kind of Bell polynomials (are sometimes called "partial" Bell + polynomials or incomplete Bell polynomials) are defined as + + .. math:: + B_{n,k}(x_1, x_2,\dotsc x_{n-k+1}) = + \sum_{j_1+j_2+j_2+\dotsb=k \atop j_1+2j_2+3j_2+\dotsb=n} + \frac{n!}{j_1!j_2!\dotsb j_{n-k+1}!} + \left(\frac{x_1}{1!} \right)^{j_1} + \left(\frac{x_2}{2!} \right)^{j_2} \dotsb + \left(\frac{x_{n-k+1}}{(n-k+1)!} \right) ^{j_{n-k+1}}. + + * ``bell(n, k, (x1, x2, ...))`` gives Bell polynomials of the second kind, + `B_{n,k}(x_1, x_2, \dotsc, x_{n-k+1})`. + + See Also + ======== + + sympy.functions.combinatorial.numbers.bell + + """ + + inner_coeffs = [bell(n, j, tuple(self.bell_coeff_seq[:n-j+1])) for j in range(1, n+1)] + + k = Dummy('k') + return sequence(tuple(inner_coeffs), (k, 1, oo)) + + def compose(self, other, x=None, n=6): + r""" + Returns the truncated terms of the formal power series of the composed function, + up to specified ``n``. + + Explanation + =========== + + If ``f`` and ``g`` are two formal power series of two different functions, + then the coefficient sequence ``ak`` of the composed formal power series `fp` + will be as follows. + + .. math:: + \sum\limits_{k=0}^{n} b_k B_{n,k}(x_1, x_2, \dotsc, x_{n-k+1}) + + Parameters + ========== + + n : Number, optional + Specifies the order of the term up to which the polynomial should + be truncated. + + Examples + ======== + + >>> from sympy import fps, sin, exp + >>> from sympy.abc import x + >>> f1 = fps(exp(x)) + >>> f2 = fps(sin(x)) + + >>> f1.compose(f2, x).truncate() + 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + + >>> f1.compose(f2, x).truncate(8) + 1 + x + x**2/2 - x**4/8 - x**5/15 - x**6/240 + x**7/90 + O(x**8) + + See Also + ======== + + sympy.functions.combinatorial.numbers.bell + sympy.series.formal.FormalPowerSeriesCompose + + References + ========== + + .. [1] Comtet, Louis: Advanced combinatorics; the art of finite and infinite expansions. Reidel, 1974. + + """ + + if n is None: + return iter(self) + + other = sympify(other) + + if not isinstance(other, FormalPowerSeries): + raise ValueError("Both series should be an instance of FormalPowerSeries" + " class.") + + if self.dir != other.dir: + raise ValueError("Both series should be calculated from the" + " same direction.") + elif self.x0 != other.x0: + raise ValueError("Both series should be calculated about the" + " same point.") + + elif self.x != other.x: + raise ValueError("Both series should have the same symbol.") + + if other._eval_term(0).as_coeff_mul(other.x)[0] is not S.Zero: + raise ValueError("The formal power series of the inner function should not have any " + "constant coefficient term.") + + return FormalPowerSeriesCompose(self, other) + + def inverse(self, x=None, n=6): + r""" + Returns the truncated terms of the inverse of the formal power series, + up to specified ``n``. + + Explanation + =========== + + If ``f`` and ``g`` are two formal power series of two different functions, + then the coefficient sequence ``ak`` of the composed formal power series ``fp`` + will be as follows. + + .. math:: + \sum\limits_{k=0}^{n} (-1)^{k} x_0^{-k-1} B_{n,k}(x_1, x_2, \dotsc, x_{n-k+1}) + + Parameters + ========== + + n : Number, optional + Specifies the order of the term up to which the polynomial should + be truncated. + + Examples + ======== + + >>> from sympy import fps, exp, cos + >>> from sympy.abc import x + >>> f1 = fps(exp(x)) + >>> f2 = fps(cos(x)) + + >>> f1.inverse(x).truncate() + 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + O(x**6) + + >>> f2.inverse(x).truncate(8) + 1 + x**2/2 + 5*x**4/24 + 61*x**6/720 + O(x**8) + + See Also + ======== + + sympy.functions.combinatorial.numbers.bell + sympy.series.formal.FormalPowerSeriesInverse + + References + ========== + + .. [1] Comtet, Louis: Advanced combinatorics; the art of finite and infinite expansions. Reidel, 1974. + + """ + + if n is None: + return iter(self) + + if self._eval_term(0).is_zero: + raise ValueError("Constant coefficient should exist for an inverse of a formal" + " power series to exist.") + + return FormalPowerSeriesInverse(self) + + def __add__(self, other): + other = sympify(other) + + if isinstance(other, FormalPowerSeries): + if self.dir != other.dir: + raise ValueError("Both series should be calculated from the" + " same direction.") + elif self.x0 != other.x0: + raise ValueError("Both series should be calculated about the" + " same point.") + + x, y = self.x, other.x + f = self.function + other.function.subs(y, x) + + if self.x not in f.free_symbols: + return f + + ak = self.ak + other.ak + if self.ak.start > other.ak.start: + seq = other.ak + s, e = other.ak.start, self.ak.start + else: + seq = self.ak + s, e = self.ak.start, other.ak.start + save = Add(*[z[0]*z[1] for z in zip(seq[0:(e - s)], self.xk[s:e])]) + ind = self.ind + other.ind + save + + return self.func(f, x, self.x0, self.dir, (ak, self.xk, ind)) + + elif not other.has(self.x): + f = self.function + other + ind = self.ind + other + + return self.func(f, self.x, self.x0, self.dir, + (self.ak, self.xk, ind)) + + return Add(self, other) + + def __radd__(self, other): + return self.__add__(other) + + def __neg__(self): + return self.func(-self.function, self.x, self.x0, self.dir, + (-self.ak, self.xk, -self.ind)) + + def __sub__(self, other): + return self.__add__(-other) + + def __rsub__(self, other): + return (-self).__add__(other) + + def __mul__(self, other): + other = sympify(other) + + if other.has(self.x): + return Mul(self, other) + + f = self.function * other + ak = self.ak.coeff_mul(other) + ind = self.ind * other + + return self.func(f, self.x, self.x0, self.dir, (ak, self.xk, ind)) + + def __rmul__(self, other): + return self.__mul__(other) + + +class FiniteFormalPowerSeries(FormalPowerSeries): + """Base Class for Product, Compose and Inverse classes""" + + def __init__(self, *args): + pass + + @property + def ffps(self): + return self.args[0] + + @property + def gfps(self): + return self.args[1] + + @property + def f(self): + return self.ffps.function + + @property + def g(self): + return self.gfps.function + + @property + def infinite(self): + raise NotImplementedError("No infinite version for an object of" + " FiniteFormalPowerSeries class.") + + def _eval_terms(self, n): + raise NotImplementedError("(%s)._eval_terms()" % self) + + def _eval_term(self, pt): + raise NotImplementedError("By the current logic, one can get terms" + "upto a certain order, instead of getting term by term.") + + def polynomial(self, n): + return self._eval_terms(n) + + def truncate(self, n=6): + ffps = self.ffps + pt_xk = ffps.xk.coeff(n) + x, x0 = ffps.x, ffps.x0 + + return self.polynomial(n) + Order(pt_xk, (x, x0)) + + def _eval_derivative(self, x): + raise NotImplementedError + + def integrate(self, x): + raise NotImplementedError + + +class FormalPowerSeriesProduct(FiniteFormalPowerSeries): + """Represents the product of two formal power series of two functions. + + Explanation + =========== + + No computation is performed. Terms are calculated using a term by term logic, + instead of a point by point logic. + + There are two differences between a :obj:`FormalPowerSeries` object and a + :obj:`FormalPowerSeriesProduct` object. The first argument contains the two + functions involved in the product. Also, the coefficient sequence contains + both the coefficient sequence of the formal power series of the involved functions. + + See Also + ======== + + sympy.series.formal.FormalPowerSeries + sympy.series.formal.FiniteFormalPowerSeries + + """ + + def __init__(self, *args): + ffps, gfps = self.ffps, self.gfps + + k = ffps.ak.variables[0] + self.coeff1 = sequence(ffps.ak.formula, (k, 0, oo)) + + k = gfps.ak.variables[0] + self.coeff2 = sequence(gfps.ak.formula, (k, 0, oo)) + + @property + def function(self): + """Function of the product of two formal power series.""" + return self.f * self.g + + def _eval_terms(self, n): + """ + Returns the first ``n`` terms of the product formal power series. + Term by term logic is implemented here. + + Examples + ======== + + >>> from sympy import fps, sin, exp + >>> from sympy.abc import x + >>> f1 = fps(sin(x)) + >>> f2 = fps(exp(x)) + >>> fprod = f1.product(f2, x) + + >>> fprod._eval_terms(4) + x**3/3 + x**2 + x + + See Also + ======== + + sympy.series.formal.FormalPowerSeries.product + + """ + coeff1, coeff2 = self.coeff1, self.coeff2 + + aks = convolution(coeff1[:n], coeff2[:n]) + + terms = [] + for i in range(0, n): + terms.append(aks[i] * self.ffps.xk.coeff(i)) + + return Add(*terms) + + +class FormalPowerSeriesCompose(FiniteFormalPowerSeries): + """ + Represents the composed formal power series of two functions. + + Explanation + =========== + + No computation is performed. Terms are calculated using a term by term logic, + instead of a point by point logic. + + There are two differences between a :obj:`FormalPowerSeries` object and a + :obj:`FormalPowerSeriesCompose` object. The first argument contains the outer + function and the inner function involved in the omposition. Also, the + coefficient sequence contains the generic sequence which is to be multiplied + by a custom ``bell_seq`` finite sequence. The finite terms will then be added up to + get the final terms. + + See Also + ======== + + sympy.series.formal.FormalPowerSeries + sympy.series.formal.FiniteFormalPowerSeries + + """ + + @property + def function(self): + """Function for the composed formal power series.""" + f, g, x = self.f, self.g, self.ffps.x + return f.subs(x, g) + + def _eval_terms(self, n): + """ + Returns the first `n` terms of the composed formal power series. + Term by term logic is implemented here. + + Explanation + =========== + + The coefficient sequence of the :obj:`FormalPowerSeriesCompose` object is the generic sequence. + It is multiplied by ``bell_seq`` to get a sequence, whose terms are added up to get + the final terms for the polynomial. + + Examples + ======== + + >>> from sympy import fps, sin, exp + >>> from sympy.abc import x + >>> f1 = fps(exp(x)) + >>> f2 = fps(sin(x)) + >>> fcomp = f1.compose(f2, x) + + >>> fcomp._eval_terms(6) + -x**5/15 - x**4/8 + x**2/2 + x + 1 + + >>> fcomp._eval_terms(8) + x**7/90 - x**6/240 - x**5/15 - x**4/8 + x**2/2 + x + 1 + + See Also + ======== + + sympy.series.formal.FormalPowerSeries.compose + sympy.series.formal.FormalPowerSeries.coeff_bell + + """ + + ffps, gfps = self.ffps, self.gfps + terms = [ffps.zero_coeff()] + + for i in range(1, n): + bell_seq = gfps.coeff_bell(i) + seq = (ffps.bell_coeff_seq * bell_seq) + terms.append(Add(*(seq[:i])) / ffps.fact_seq[i-1] * ffps.xk.coeff(i)) + + return Add(*terms) + + +class FormalPowerSeriesInverse(FiniteFormalPowerSeries): + """ + Represents the Inverse of a formal power series. + + Explanation + =========== + + No computation is performed. Terms are calculated using a term by term logic, + instead of a point by point logic. + + There is a single difference between a :obj:`FormalPowerSeries` object and a + :obj:`FormalPowerSeriesInverse` object. The coefficient sequence contains the + generic sequence which is to be multiplied by a custom ``bell_seq`` finite sequence. + The finite terms will then be added up to get the final terms. + + See Also + ======== + + sympy.series.formal.FormalPowerSeries + sympy.series.formal.FiniteFormalPowerSeries + + """ + def __init__(self, *args): + ffps = self.ffps + k = ffps.xk.variables[0] + + inv = ffps.zero_coeff() + inv_seq = sequence(inv ** (-(k + 1)), (k, 1, oo)) + self.aux_seq = ffps.sign_seq * ffps.fact_seq * inv_seq + + @property + def function(self): + """Function for the inverse of a formal power series.""" + f = self.f + return 1 / f + + @property + def g(self): + raise ValueError("Only one function is considered while performing" + "inverse of a formal power series.") + + @property + def gfps(self): + raise ValueError("Only one function is considered while performing" + "inverse of a formal power series.") + + def _eval_terms(self, n): + """ + Returns the first ``n`` terms of the composed formal power series. + Term by term logic is implemented here. + + Explanation + =========== + + The coefficient sequence of the `FormalPowerSeriesInverse` object is the generic sequence. + It is multiplied by ``bell_seq`` to get a sequence, whose terms are added up to get + the final terms for the polynomial. + + Examples + ======== + + >>> from sympy import fps, exp, cos + >>> from sympy.abc import x + >>> f1 = fps(exp(x)) + >>> f2 = fps(cos(x)) + >>> finv1, finv2 = f1.inverse(), f2.inverse() + + >>> finv1._eval_terms(6) + -x**5/120 + x**4/24 - x**3/6 + x**2/2 - x + 1 + + >>> finv2._eval_terms(8) + 61*x**6/720 + 5*x**4/24 + x**2/2 + 1 + + See Also + ======== + + sympy.series.formal.FormalPowerSeries.inverse + sympy.series.formal.FormalPowerSeries.coeff_bell + + """ + ffps = self.ffps + terms = [ffps.zero_coeff()] + + for i in range(1, n): + bell_seq = ffps.coeff_bell(i) + seq = (self.aux_seq * bell_seq) + terms.append(Add(*(seq[:i])) / ffps.fact_seq[i-1] * ffps.xk.coeff(i)) + + return Add(*terms) + + +def fps(f, x=None, x0=0, dir=1, hyper=True, order=4, rational=True, full=False): + """ + Generates Formal Power Series of ``f``. + + Explanation + =========== + + Returns the formal series expansion of ``f`` around ``x = x0`` + with respect to ``x`` in the form of a ``FormalPowerSeries`` object. + + Formal Power Series is represented using an explicit formula + computed using different algorithms. + + See :func:`compute_fps` for the more details regarding the computation + of formula. + + Parameters + ========== + + x : Symbol, optional + If x is None and ``f`` is univariate, the univariate symbols will be + supplied, otherwise an error will be raised. + x0 : number, optional + Point to perform series expansion about. Default is 0. + dir : {1, -1, '+', '-'}, optional + If dir is 1 or '+' the series is calculated from the right and + for -1 or '-' the series is calculated from the left. For smooth + functions this flag will not alter the results. Default is 1. + hyper : {True, False}, optional + Set hyper to False to skip the hypergeometric algorithm. + By default it is set to False. + order : int, optional + Order of the derivative of ``f``, Default is 4. + rational : {True, False}, optional + Set rational to False to skip rational algorithm. By default it is set + to True. + full : {True, False}, optional + Set full to True to increase the range of rational algorithm. + See :func:`rational_algorithm` for details. By default it is set to + False. + + Examples + ======== + + >>> from sympy import fps, ln, atan, sin + >>> from sympy.abc import x, n + + Rational Functions + + >>> fps(ln(1 + x)).truncate() + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + >>> fps(atan(x), full=True).truncate() + x - x**3/3 + x**5/5 + O(x**6) + + Symbolic Functions + + >>> fps(x**n*sin(x**2), x).truncate(8) + -x**(n + 6)/6 + x**(n + 2) + O(x**(n + 8)) + + See Also + ======== + + sympy.series.formal.FormalPowerSeries + sympy.series.formal.compute_fps + """ + f = sympify(f) + + if x is None: + free = f.free_symbols + if len(free) == 1: + x = free.pop() + elif not free: + return f + else: + raise NotImplementedError("multivariate formal power series") + + result = compute_fps(f, x, x0, dir, hyper, order, rational, full) + + if result is None: + return f + + return FormalPowerSeries(f, x, x0, dir, result) diff --git a/phivenv/Lib/site-packages/sympy/series/fourier.py b/phivenv/Lib/site-packages/sympy/series/fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ad43a75d0e1d24bfc591383f68965fd8b504c7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/fourier.py @@ -0,0 +1,811 @@ +"""Fourier Series""" + +from sympy.core.numbers import (oo, pi) +from sympy.core.symbol import Wild +from sympy.core.expr import Expr +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import sin, cos, sinc +from sympy.series.series_class import SeriesBase +from sympy.series.sequences import SeqFormula +from sympy.sets.sets import Interval +from sympy.utilities.iterables import is_sequence + + +__doctest_requires__ = {('fourier_series',): ['matplotlib']} + + +def fourier_cos_seq(func, limits, n): + """Returns the cos sequence in a Fourier series""" + from sympy.integrals import integrate + x, L = limits[0], limits[2] - limits[1] + cos_term = cos(2*n*pi*x / L) + formula = 2 * cos_term * integrate(func * cos_term, limits) / L + a0 = formula.subs(n, S.Zero) / 2 + return a0, SeqFormula(2 * cos_term * integrate(func * cos_term, limits) + / L, (n, 1, oo)) + + +def fourier_sin_seq(func, limits, n): + """Returns the sin sequence in a Fourier series""" + from sympy.integrals import integrate + x, L = limits[0], limits[2] - limits[1] + sin_term = sin(2*n*pi*x / L) + return SeqFormula(2 * sin_term * integrate(func * sin_term, limits) + / L, (n, 1, oo)) + + +def _process_limits(func, limits): + """ + Limits should be of the form (x, start, stop). + x should be a symbol. Both start and stop should be bounded. + + Explanation + =========== + + * If x is not given, x is determined from func. + * If limits is None. Limit of the form (x, -pi, pi) is returned. + + Examples + ======== + + >>> from sympy.series.fourier import _process_limits as pari + >>> from sympy.abc import x + >>> pari(x**2, (x, -2, 2)) + (x, -2, 2) + >>> pari(x**2, (-2, 2)) + (x, -2, 2) + >>> pari(x**2, None) + (x, -pi, pi) + """ + def _find_x(func): + free = func.free_symbols + if len(free) == 1: + return free.pop() + elif not free: + return Dummy('k') + else: + raise ValueError( + " specify dummy variables for %s. If the function contains" + " more than one free symbol, a dummy variable should be" + " supplied explicitly e.g. FourierSeries(m*n**2, (n, -pi, pi))" + % func) + + x, start, stop = None, None, None + if limits is None: + x, start, stop = _find_x(func), -pi, pi + if is_sequence(limits, Tuple): + if len(limits) == 3: + x, start, stop = limits + elif len(limits) == 2: + x = _find_x(func) + start, stop = limits + + if not isinstance(x, Symbol) or start is None or stop is None: + raise ValueError('Invalid limits given: %s' % str(limits)) + + unbounded = [S.NegativeInfinity, S.Infinity] + if start in unbounded or stop in unbounded: + raise ValueError("Both the start and end value should be bounded") + + return sympify((x, start, stop)) + + +def finite_check(f, x, L): + + def check_fx(exprs, x): + return x not in exprs.free_symbols + + def check_sincos(_expr, x, L): + if isinstance(_expr, (sin, cos)): + sincos_args = _expr.args[0] + + if sincos_args.match(a*(pi/L)*x + b) is not None: + return True + else: + return False + + from sympy.simplify.fu import TR2, TR1, sincos_to_sum + _expr = sincos_to_sum(TR2(TR1(f))) + add_coeff = _expr.as_coeff_add() + + a = Wild('a', properties=[lambda k: k.is_Integer, lambda k: k != S.Zero, ]) + b = Wild('b', properties=[lambda k: x not in k.free_symbols, ]) + + for s in add_coeff[1]: + mul_coeffs = s.as_coeff_mul()[1] + for t in mul_coeffs: + if not (check_fx(t, x) or check_sincos(t, x, L)): + return False, f + + return True, _expr + + +class FourierSeries(SeriesBase): + r"""Represents Fourier sine/cosine series. + + Explanation + =========== + + This class only represents a fourier series. + No computation is performed. + + For how to compute Fourier series, see the :func:`fourier_series` + docstring. + + See Also + ======== + + sympy.series.fourier.fourier_series + """ + def __new__(cls, *args): + args = map(sympify, args) + return Expr.__new__(cls, *args) + + @property + def function(self): + return self.args[0] + + @property + def x(self): + return self.args[1][0] + + @property + def period(self): + return (self.args[1][1], self.args[1][2]) + + @property + def a0(self): + return self.args[2][0] + + @property + def an(self): + return self.args[2][1] + + @property + def bn(self): + return self.args[2][2] + + @property + def interval(self): + return Interval(0, oo) + + @property + def start(self): + return self.interval.inf + + @property + def stop(self): + return self.interval.sup + + @property + def length(self): + return oo + + @property + def L(self): + return abs(self.period[1] - self.period[0]) / 2 + + def _eval_subs(self, old, new): + x = self.x + if old.has(x): + return self + + def truncate(self, n=3): + """ + Return the first n nonzero terms of the series. + + If ``n`` is None return an iterator. + + Parameters + ========== + + n : int or None + Amount of non-zero terms in approximation or None. + + Returns + ======= + + Expr or iterator : + Approximation of function expanded into Fourier series. + + Examples + ======== + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> s = fourier_series(x, (x, -pi, pi)) + >>> s.truncate(4) + 2*sin(x) - sin(2*x) + 2*sin(3*x)/3 - sin(4*x)/2 + + See Also + ======== + + sympy.series.fourier.FourierSeries.sigma_approximation + """ + if n is None: + return iter(self) + + terms = [] + for t in self: + if len(terms) == n: + break + if t is not S.Zero: + terms.append(t) + + return Add(*terms) + + def sigma_approximation(self, n=3): + r""" + Return :math:`\sigma`-approximation of Fourier series with respect + to order n. + + Explanation + =========== + + Sigma approximation adjusts a Fourier summation to eliminate the Gibbs + phenomenon which would otherwise occur at discontinuities. + A sigma-approximated summation for a Fourier series of a T-periodical + function can be written as + + .. math:: + s(\theta) = \frac{1}{2} a_0 + \sum _{k=1}^{m-1} + \operatorname{sinc} \Bigl( \frac{k}{m} \Bigr) \cdot + \left[ a_k \cos \Bigl( \frac{2\pi k}{T} \theta \Bigr) + + b_k \sin \Bigl( \frac{2\pi k}{T} \theta \Bigr) \right], + + where :math:`a_0, a_k, b_k, k=1,\ldots,{m-1}` are standard Fourier + series coefficients and + :math:`\operatorname{sinc} \Bigl( \frac{k}{m} \Bigr)` is a Lanczos + :math:`\sigma` factor (expressed in terms of normalized + :math:`\operatorname{sinc}` function). + + Parameters + ========== + + n : int + Highest order of the terms taken into account in approximation. + + Returns + ======= + + Expr : + Sigma approximation of function expanded into Fourier series. + + Examples + ======== + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> s = fourier_series(x, (x, -pi, pi)) + >>> s.sigma_approximation(4) + 2*sin(x)*sinc(pi/4) - 2*sin(2*x)/pi + 2*sin(3*x)*sinc(3*pi/4)/3 + + See Also + ======== + + sympy.series.fourier.FourierSeries.truncate + + Notes + ===== + + The behaviour of + :meth:`~sympy.series.fourier.FourierSeries.sigma_approximation` + is different from :meth:`~sympy.series.fourier.FourierSeries.truncate` + - it takes all nonzero terms of degree smaller than n, rather than + first n nonzero ones. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gibbs_phenomenon + .. [2] https://en.wikipedia.org/wiki/Sigma_approximation + """ + terms = [sinc(pi * i / n) * t for i, t in enumerate(self[:n]) + if t is not S.Zero] + return Add(*terms) + + def shift(self, s): + """ + Shift the function by a term independent of x. + + Explanation + =========== + + f(x) -> f(x) + s + + This is fast, if Fourier series of f(x) is already + computed. + + Examples + ======== + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> s = fourier_series(x**2, (x, -pi, pi)) + >>> s.shift(1).truncate() + -4*cos(x) + cos(2*x) + 1 + pi**2/3 + """ + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + a0 = self.a0 + s + sfunc = self.function + s + + return self.func(sfunc, self.args[1], (a0, self.an, self.bn)) + + def shiftx(self, s): + """ + Shift x by a term independent of x. + + Explanation + =========== + + f(x) -> f(x + s) + + This is fast, if Fourier series of f(x) is already + computed. + + Examples + ======== + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> s = fourier_series(x**2, (x, -pi, pi)) + >>> s.shiftx(1).truncate() + -4*cos(x + 1) + cos(2*x + 2) + pi**2/3 + """ + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + an = self.an.subs(x, x + s) + bn = self.bn.subs(x, x + s) + sfunc = self.function.subs(x, x + s) + + return self.func(sfunc, self.args[1], (self.a0, an, bn)) + + def scale(self, s): + """ + Scale the function by a term independent of x. + + Explanation + =========== + + f(x) -> s * f(x) + + This is fast, if Fourier series of f(x) is already + computed. + + Examples + ======== + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> s = fourier_series(x**2, (x, -pi, pi)) + >>> s.scale(2).truncate() + -8*cos(x) + 2*cos(2*x) + 2*pi**2/3 + """ + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + an = self.an.coeff_mul(s) + bn = self.bn.coeff_mul(s) + a0 = self.a0 * s + sfunc = self.args[0] * s + + return self.func(sfunc, self.args[1], (a0, an, bn)) + + def scalex(self, s): + """ + Scale x by a term independent of x. + + Explanation + =========== + + f(x) -> f(s*x) + + This is fast, if Fourier series of f(x) is already + computed. + + Examples + ======== + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> s = fourier_series(x**2, (x, -pi, pi)) + >>> s.scalex(2).truncate() + -4*cos(2*x) + cos(4*x) + pi**2/3 + """ + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + an = self.an.subs(x, x * s) + bn = self.bn.subs(x, x * s) + sfunc = self.function.subs(x, x * s) + + return self.func(sfunc, self.args[1], (self.a0, an, bn)) + + def _eval_as_leading_term(self, x, logx, cdir): + for t in self: + if t is not S.Zero: + return t + + def _eval_term(self, pt): + if pt == 0: + return self.a0 + return self.an.coeff(pt) + self.bn.coeff(pt) + + def __neg__(self): + return self.scale(-1) + + def __add__(self, other): + if isinstance(other, FourierSeries): + if self.period != other.period: + raise ValueError("Both the series should have same periods") + + x, y = self.x, other.x + function = self.function + other.function.subs(y, x) + + if self.x not in function.free_symbols: + return function + + an = self.an + other.an + bn = self.bn + other.bn + a0 = self.a0 + other.a0 + + return self.func(function, self.args[1], (a0, an, bn)) + + return Add(self, other) + + def __sub__(self, other): + return self.__add__(-other) + + +class FiniteFourierSeries(FourierSeries): + r"""Represents Finite Fourier sine/cosine series. + + For how to compute Fourier series, see the :func:`fourier_series` + docstring. + + Parameters + ========== + + f : Expr + Expression for finding fourier_series + + limits : ( x, start, stop) + x is the independent variable for the expression f + (start, stop) is the period of the fourier series + + exprs: (a0, an, bn) or Expr + a0 is the constant term a0 of the fourier series + an is a dictionary of coefficients of cos terms + an[k] = coefficient of cos(pi*(k/L)*x) + bn is a dictionary of coefficients of sin terms + bn[k] = coefficient of sin(pi*(k/L)*x) + + or exprs can be an expression to be converted to fourier form + + Methods + ======= + + This class is an extension of FourierSeries class. + Please refer to sympy.series.fourier.FourierSeries for + further information. + + See Also + ======== + + sympy.series.fourier.FourierSeries + sympy.series.fourier.fourier_series + """ + + def __new__(cls, f, limits, exprs): + f = sympify(f) + limits = sympify(limits) + exprs = sympify(exprs) + + if not (isinstance(exprs, Tuple) and len(exprs) == 3): # exprs is not of form (a0, an, bn) + # Converts the expression to fourier form + c, e = exprs.as_coeff_add() + from sympy.simplify.fu import TR10 + rexpr = c + Add(*[TR10(i) for i in e]) + a0, exp_ls = rexpr.expand(trig=False, power_base=False, power_exp=False, log=False).as_coeff_add() + + x = limits[0] + L = abs(limits[2] - limits[1]) / 2 + + a = Wild('a', properties=[lambda k: k.is_Integer, lambda k: k is not S.Zero, ]) + b = Wild('b', properties=[lambda k: x not in k.free_symbols, ]) + + an = {} + bn = {} + + # separates the coefficients of sin and cos terms in dictionaries an, and bn + for p in exp_ls: + t = p.match(b * cos(a * (pi / L) * x)) + q = p.match(b * sin(a * (pi / L) * x)) + if t: + an[t[a]] = t[b] + an.get(t[a], S.Zero) + elif q: + bn[q[a]] = q[b] + bn.get(q[a], S.Zero) + else: + a0 += p + + exprs = Tuple(a0, an, bn) + + return Expr.__new__(cls, f, limits, exprs) + + @property + def interval(self): + _length = 1 if self.a0 else 0 + _length += max(set(self.an.keys()).union(set(self.bn.keys()))) + 1 + return Interval(0, _length) + + @property + def length(self): + return self.stop - self.start + + def shiftx(self, s): + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + _expr = self.truncate().subs(x, x + s) + sfunc = self.function.subs(x, x + s) + + return self.func(sfunc, self.args[1], _expr) + + def scale(self, s): + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + _expr = self.truncate() * s + sfunc = self.function * s + + return self.func(sfunc, self.args[1], _expr) + + def scalex(self, s): + s, x = sympify(s), self.x + + if x in s.free_symbols: + raise ValueError("'%s' should be independent of %s" % (s, x)) + + _expr = self.truncate().subs(x, x * s) + sfunc = self.function.subs(x, x * s) + + return self.func(sfunc, self.args[1], _expr) + + def _eval_term(self, pt): + if pt == 0: + return self.a0 + + _term = self.an.get(pt, S.Zero) * cos(pt * (pi / self.L) * self.x) \ + + self.bn.get(pt, S.Zero) * sin(pt * (pi / self.L) * self.x) + return _term + + def __add__(self, other): + if isinstance(other, FourierSeries): + return other.__add__(fourier_series(self.function, self.args[1],\ + finite=False)) + elif isinstance(other, FiniteFourierSeries): + if self.period != other.period: + raise ValueError("Both the series should have same periods") + + x, y = self.x, other.x + function = self.function + other.function.subs(y, x) + + if self.x not in function.free_symbols: + return function + + return fourier_series(function, limits=self.args[1]) + + +def fourier_series(f, limits=None, finite=True): + r"""Computes the Fourier trigonometric series expansion. + + Explanation + =========== + + Fourier trigonometric series of $f(x)$ over the interval $(a, b)$ + is defined as: + + .. math:: + \frac{a_0}{2} + \sum_{n=1}^{\infty} + (a_n \cos(\frac{2n \pi x}{L}) + b_n \sin(\frac{2n \pi x}{L})) + + where the coefficients are: + + .. math:: + L = b - a + + .. math:: + a_0 = \frac{2}{L} \int_{a}^{b}{f(x) dx} + + .. math:: + a_n = \frac{2}{L} \int_{a}^{b}{f(x) \cos(\frac{2n \pi x}{L}) dx} + + .. math:: + b_n = \frac{2}{L} \int_{a}^{b}{f(x) \sin(\frac{2n \pi x}{L}) dx} + + The condition whether the function $f(x)$ given should be periodic + or not is more than necessary, because it is sufficient to consider + the series to be converging to $f(x)$ only in the given interval, + not throughout the whole real line. + + This also brings a lot of ease for the computation because + you do not have to make $f(x)$ artificially periodic by + wrapping it with piecewise, modulo operations, + but you can shape the function to look like the desired periodic + function only in the interval $(a, b)$, and the computed series will + automatically become the series of the periodic version of $f(x)$. + + This property is illustrated in the examples section below. + + Parameters + ========== + + limits : (sym, start, end), optional + *sym* denotes the symbol the series is computed with respect to. + + *start* and *end* denotes the start and the end of the interval + where the fourier series converges to the given function. + + Default range is specified as $-\pi$ and $\pi$. + + Returns + ======= + + FourierSeries + A symbolic object representing the Fourier trigonometric series. + + Examples + ======== + + Computing the Fourier series of $f(x) = x^2$: + + >>> from sympy import fourier_series, pi + >>> from sympy.abc import x + >>> f = x**2 + >>> s = fourier_series(f, (x, -pi, pi)) + >>> s1 = s.truncate(n=3) + >>> s1 + -4*cos(x) + cos(2*x) + pi**2/3 + + Shifting of the Fourier series: + + >>> s.shift(1).truncate() + -4*cos(x) + cos(2*x) + 1 + pi**2/3 + >>> s.shiftx(1).truncate() + -4*cos(x + 1) + cos(2*x + 2) + pi**2/3 + + Scaling of the Fourier series: + + >>> s.scale(2).truncate() + -8*cos(x) + 2*cos(2*x) + 2*pi**2/3 + >>> s.scalex(2).truncate() + -4*cos(2*x) + cos(4*x) + pi**2/3 + + Computing the Fourier series of $f(x) = x$: + + This illustrates how truncating to the higher order gives better + convergence. + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import fourier_series, pi, plot + >>> from sympy.abc import x + >>> f = x + >>> s = fourier_series(f, (x, -pi, pi)) + >>> s1 = s.truncate(n = 3) + >>> s2 = s.truncate(n = 5) + >>> s3 = s.truncate(n = 7) + >>> p = plot(f, s1, s2, s3, (x, -pi, pi), show=False, legend=True) + + >>> p[0].line_color = (0, 0, 0) + >>> p[0].label = 'x' + >>> p[1].line_color = (0.7, 0.7, 0.7) + >>> p[1].label = 'n=3' + >>> p[2].line_color = (0.5, 0.5, 0.5) + >>> p[2].label = 'n=5' + >>> p[3].line_color = (0.3, 0.3, 0.3) + >>> p[3].label = 'n=7' + + >>> p.show() + + This illustrates how the series converges to different sawtooth + waves if the different ranges are specified. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> s1 = fourier_series(x, (x, -1, 1)).truncate(10) + >>> s2 = fourier_series(x, (x, -pi, pi)).truncate(10) + >>> s3 = fourier_series(x, (x, 0, 1)).truncate(10) + >>> p = plot(x, s1, s2, s3, (x, -5, 5), show=False, legend=True) + + >>> p[0].line_color = (0, 0, 0) + >>> p[0].label = 'x' + >>> p[1].line_color = (0.7, 0.7, 0.7) + >>> p[1].label = '[-1, 1]' + >>> p[2].line_color = (0.5, 0.5, 0.5) + >>> p[2].label = '[-pi, pi]' + >>> p[3].line_color = (0.3, 0.3, 0.3) + >>> p[3].label = '[0, 1]' + + >>> p.show() + + Notes + ===== + + Computing Fourier series can be slow + due to the integration required in computing + an, bn. + + It is faster to compute Fourier series of a function + by using shifting and scaling on an already + computed Fourier series rather than computing + again. + + e.g. If the Fourier series of ``x**2`` is known + the Fourier series of ``x**2 - 1`` can be found by shifting by ``-1``. + + See Also + ======== + + sympy.series.fourier.FourierSeries + + References + ========== + + .. [1] https://mathworld.wolfram.com/FourierSeries.html + """ + f = sympify(f) + + limits = _process_limits(f, limits) + x = limits[0] + + if x not in f.free_symbols: + return f + + if finite: + L = abs(limits[2] - limits[1]) / 2 + is_finite, res_f = finite_check(f, x, L) + if is_finite: + return FiniteFourierSeries(f, limits, res_f) + + n = Dummy('n') + center = (limits[1] + limits[2]) / 2 + if center.is_zero: + neg_f = f.subs(x, -x) + if f == neg_f: + a0, an = fourier_cos_seq(f, limits, n) + bn = SeqFormula(0, (1, oo)) + return FourierSeries(f, limits, (a0, an, bn)) + elif f == -neg_f: + a0 = S.Zero + an = SeqFormula(0, (1, oo)) + bn = fourier_sin_seq(f, limits, n) + return FourierSeries(f, limits, (a0, an, bn)) + a0, an = fourier_cos_seq(f, limits, n) + bn = fourier_sin_seq(f, limits, n) + return FourierSeries(f, limits, (a0, an, bn)) diff --git a/phivenv/Lib/site-packages/sympy/series/gruntz.py b/phivenv/Lib/site-packages/sympy/series/gruntz.py new file mode 100644 index 0000000000000000000000000000000000000000..20ba3150e9918384141e755a39f5bacb0e76db55 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/gruntz.py @@ -0,0 +1,701 @@ +""" +Limits +====== + +Implemented according to the PhD thesis +https://www.cybertester.com/data/gruntz.pdf, which contains very thorough +descriptions of the algorithm including many examples. We summarize here +the gist of it. + +All functions are sorted according to how rapidly varying they are at +infinity using the following rules. Any two functions f and g can be +compared using the properties of L: + +L=lim log|f(x)| / log|g(x)| (for x -> oo) + +We define >, < ~ according to:: + + 1. f > g .... L=+-oo + + we say that: + - f is greater than any power of g + - f is more rapidly varying than g + - f goes to infinity/zero faster than g + + 2. f < g .... L=0 + + we say that: + - f is lower than any power of g + + 3. f ~ g .... L!=0, +-oo + + we say that: + - both f and g are bounded from above and below by suitable integral + powers of the other + +Examples +======== +:: + 2 < x < exp(x) < exp(x**2) < exp(exp(x)) + 2 ~ 3 ~ -5 + x ~ x**2 ~ x**3 ~ 1/x ~ x**m ~ -x + exp(x) ~ exp(-x) ~ exp(2x) ~ exp(x)**2 ~ exp(x+exp(-x)) + f ~ 1/f + +So we can divide all the functions into comparability classes (x and x^2 +belong to one class, exp(x) and exp(-x) belong to some other class). In +principle, we could compare any two functions, but in our algorithm, we +do not compare anything below the class 2~3~-5 (for example log(x) is +below this), so we set 2~3~-5 as the lowest comparability class. + +Given the function f, we find the list of most rapidly varying (mrv set) +subexpressions of it. This list belongs to the same comparability class. +Let's say it is {exp(x), exp(2x)}. Using the rule f ~ 1/f we find an +element "w" (either from the list or a new one) from the same +comparability class which goes to zero at infinity. In our example we +set w=exp(-x) (but we could also set w=exp(-2x) or w=exp(-3x) ...). We +rewrite the mrv set using w, in our case {1/w, 1/w^2}, and substitute it +into f. Then we expand f into a series in w:: + + f = c0*w^e0 + c1*w^e1 + ... + O(w^en), where e0oo, lim f = lim c0*w^e0, because all the other terms go to zero, +because w goes to zero faster than the ci and ei. So:: + + for e0>0, lim f = 0 + for e0<0, lim f = +-oo (the sign depends on the sign of c0) + for e0=0, lim f = lim c0 + +We need to recursively compute limits at several places of the algorithm, but +as is shown in the PhD thesis, it always finishes. + +Important functions from the implementation: + +compare(a, b, x) compares "a" and "b" by computing the limit L. +mrv(e, x) returns list of most rapidly varying (mrv) subexpressions of "e" +rewrite(e, Omega, x, wsym) rewrites "e" in terms of w +leadterm(f, x) returns the lowest power term in the series of f +mrv_leadterm(e, x) returns the lead term (c0, e0) for e +limitinf(e, x) computes lim e (for x->oo) +limit(e, z, z0) computes any limit by converting it to the case x->oo + +All the functions are really simple and straightforward except +rewrite(), which is the most difficult/complex part of the algorithm. +When the algorithm fails, the bugs are usually in the series expansion +(i.e. in SymPy) or in rewrite. + +This code is almost exact rewrite of the Maple code inside the Gruntz +thesis. + +Debugging +--------- + +Because the gruntz algorithm is highly recursive, it's difficult to +figure out what went wrong inside a debugger. Instead, turn on nice +debug prints by defining the environment variable SYMPY_DEBUG. For +example: + +[user@localhost]: SYMPY_DEBUG=True ./bin/isympy + +In [1]: limit(sin(x)/x, x, 0) +limitinf(_x*sin(1/_x), _x) = 1 ++-mrv_leadterm(_x*sin(1/_x), _x) = (1, 0) +| +-mrv(_x*sin(1/_x), _x) = set([_x]) +| | +-mrv(_x, _x) = set([_x]) +| | +-mrv(sin(1/_x), _x) = set([_x]) +| | +-mrv(1/_x, _x) = set([_x]) +| | +-mrv(_x, _x) = set([_x]) +| +-mrv_leadterm(exp(_x)*sin(exp(-_x)), _x, set([exp(_x)])) = (1, 0) +| +-rewrite(exp(_x)*sin(exp(-_x)), set([exp(_x)]), _x, _w) = (1/_w*sin(_w), -_x) +| +-sign(_x, _x) = 1 +| +-mrv_leadterm(1, _x) = (1, 0) ++-sign(0, _x) = 0 ++-limitinf(1, _x) = 1 + +And check manually which line is wrong. Then go to the source code and +debug this function to figure out the exact problem. + +""" +from functools import reduce + +from sympy.core import Basic, S, Mul, PoleError +from sympy.core.cache import cacheit +from sympy.core.function import AppliedUndef +from sympy.core.intfunc import ilcm +from sympy.core.numbers import I, oo +from sympy.core.symbol import Dummy, Wild +from sympy.core.traversal import bottom_up + +from sympy.functions import log, exp, sign as _sign +from sympy.series.order import Order +from sympy.utilities.misc import debug_decorator as debug +from sympy.utilities.timeutils import timethis + +timeit = timethis('gruntz') + + +def compare(a, b, x): + """Returns "<" if a" for a>b""" + # log(exp(...)) must always be simplified here for termination + la, lb = log(a), log(b) + if isinstance(a, Basic) and (isinstance(a, exp) or (a.is_Pow and a.base == S.Exp1)): + la = a.exp + if isinstance(b, Basic) and (isinstance(b, exp) or (b.is_Pow and b.base == S.Exp1)): + lb = b.exp + + c = limitinf(la/lb, x) + if c == 0: + return "<" + elif c.is_infinite: + return ">" + else: + return "=" + + +class SubsSet(dict): + """ + Stores (expr, dummy) pairs, and how to rewrite expr-s. + + Explanation + =========== + + The gruntz algorithm needs to rewrite certain expressions in term of a new + variable w. We cannot use subs, because it is just too smart for us. For + example:: + + > Omega=[exp(exp(_p - exp(-_p))/(1 - 1/_p)), exp(exp(_p))] + > O2=[exp(-exp(_p) + exp(-exp(-_p))*exp(_p)/(1 - 1/_p))/_w, 1/_w] + > e = exp(exp(_p - exp(-_p))/(1 - 1/_p)) - exp(exp(_p)) + > e.subs(Omega[0],O2[0]).subs(Omega[1],O2[1]) + -1/w + exp(exp(p)*exp(-exp(-p))/(1 - 1/p)) + + is really not what we want! + + So we do it the hard way and keep track of all the things we potentially + want to substitute by dummy variables. Consider the expression:: + + exp(x - exp(-x)) + exp(x) + x. + + The mrv set is {exp(x), exp(-x), exp(x - exp(-x))}. + We introduce corresponding dummy variables d1, d2, d3 and rewrite:: + + d3 + d1 + x. + + This class first of all keeps track of the mapping expr->variable, i.e. + will at this stage be a dictionary:: + + {exp(x): d1, exp(-x): d2, exp(x - exp(-x)): d3}. + + [It turns out to be more convenient this way round.] + But sometimes expressions in the mrv set have other expressions from the + mrv set as subexpressions, and we need to keep track of that as well. In + this case, d3 is really exp(x - d2), so rewrites at this stage is:: + + {d3: exp(x-d2)}. + + The function rewrite uses all this information to correctly rewrite our + expression in terms of w. In this case w can be chosen to be exp(-x), + i.e. d2. The correct rewriting then is:: + + exp(-w)/w + 1/w + x. + """ + def __init__(self): + self.rewrites = {} + + def __repr__(self): + return super().__repr__() + ', ' + self.rewrites.__repr__() + + def __getitem__(self, key): + if key not in self: + self[key] = Dummy() + return dict.__getitem__(self, key) + + def do_subs(self, e): + """Substitute the variables with expressions""" + for expr, var in self.items(): + e = e.xreplace({var: expr}) + return e + + def meets(self, s2): + """Tell whether or not self and s2 have non-empty intersection""" + return set(self.keys()).intersection(list(s2.keys())) != set() + + def union(self, s2, exps=None): + """Compute the union of self and s2, adjusting exps""" + res = self.copy() + tr = {} + for expr, var in s2.items(): + if expr in self: + if exps: + exps = exps.xreplace({var: res[expr]}) + tr[var] = res[expr] + else: + res[expr] = var + for var, rewr in s2.rewrites.items(): + res.rewrites[var] = rewr.xreplace(tr) + return res, exps + + def copy(self): + """Create a shallow copy of SubsSet""" + r = SubsSet() + r.rewrites = self.rewrites.copy() + for expr, var in self.items(): + r[expr] = var + return r + + +@debug +def mrv(e, x): + """Returns a SubsSet of most rapidly varying (mrv) subexpressions of 'e', + and e rewritten in terms of these""" + from sympy.simplify.powsimp import powsimp + e = powsimp(e, deep=True, combine='exp') + if not isinstance(e, Basic): + raise TypeError("e should be an instance of Basic") + if not e.has(x): + return SubsSet(), e + elif e == x: + s = SubsSet() + return s, s[x] + elif e.is_Mul or e.is_Add: + i, d = e.as_independent(x) # throw away x-independent terms + if d.func != e.func: + s, expr = mrv(d, x) + return s, e.func(i, expr) + a, b = d.as_two_terms() + s1, e1 = mrv(a, x) + s2, e2 = mrv(b, x) + return mrv_max1(s1, s2, e.func(i, e1, e2), x) + elif e.is_Pow and e.base != S.Exp1: + e1 = S.One + while e.is_Pow: + b1 = e.base + e1 *= e.exp + e = b1 + if b1 == 1: + return SubsSet(), b1 + if e1.has(x): + return mrv(exp(e1*log(b1)), x) + else: + s, expr = mrv(b1, x) + return s, expr**e1 + elif isinstance(e, log): + s, expr = mrv(e.args[0], x) + return s, log(expr) + elif isinstance(e, exp) or (e.is_Pow and e.base == S.Exp1): + # We know from the theory of this algorithm that exp(log(...)) may always + # be simplified here, and doing so is vital for termination. + if isinstance(e.exp, log): + return mrv(e.exp.args[0], x) + # if a product has an infinite factor the result will be + # infinite if there is no zero, otherwise NaN; here, we + # consider the result infinite if any factor is infinite + li = limitinf(e.exp, x) + if any(_.is_infinite for _ in Mul.make_args(li)): + s1 = SubsSet() + e1 = s1[e] + s2, e2 = mrv(e.exp, x) + su = s1.union(s2)[0] + su.rewrites[e1] = exp(e2) + return mrv_max3(s1, e1, s2, exp(e2), su, e1, x) + else: + s, expr = mrv(e.exp, x) + return s, exp(expr) + elif isinstance(e, AppliedUndef): + raise ValueError("MRV set computation for UndefinedFunction is not allowed") + elif e.is_Function: + l = [mrv(a, x) for a in e.args] + l2 = [s for (s, _) in l if s != SubsSet()] + if len(l2) != 1: + # e.g. something like BesselJ(x, x) + raise NotImplementedError("MRV set computation for functions in" + " several variables not implemented.") + s, ss = l2[0], SubsSet() + args = [ss.do_subs(x[1]) for x in l] + return s, e.func(*args) + elif e.is_Derivative: + raise NotImplementedError("MRV set computation for derivatives" + " not implemented yet.") + raise NotImplementedError( + "Don't know how to calculate the mrv of '%s'" % e) + + +def mrv_max3(f, expsf, g, expsg, union, expsboth, x): + """ + Computes the maximum of two sets of expressions f and g, which + are in the same comparability class, i.e. max() compares (two elements of) + f and g and returns either (f, expsf) [if f is larger], (g, expsg) + [if g is larger] or (union, expsboth) [if f, g are of the same class]. + """ + if not isinstance(f, SubsSet): + raise TypeError("f should be an instance of SubsSet") + if not isinstance(g, SubsSet): + raise TypeError("g should be an instance of SubsSet") + if f == SubsSet(): + return g, expsg + elif g == SubsSet(): + return f, expsf + elif f.meets(g): + return union, expsboth + + c = compare(list(f.keys())[0], list(g.keys())[0], x) + if c == ">": + return f, expsf + elif c == "<": + return g, expsg + else: + if c != "=": + raise ValueError("c should be =") + return union, expsboth + + +def mrv_max1(f, g, exps, x): + """Computes the maximum of two sets of expressions f and g, which + are in the same comparability class, i.e. mrv_max1() compares (two elements of) + f and g and returns the set, which is in the higher comparability class + of the union of both, if they have the same order of variation. + Also returns exps, with the appropriate substitutions made. + """ + u, b = f.union(g, exps) + return mrv_max3(f, g.do_subs(exps), g, f.do_subs(exps), + u, b, x) + + +@debug +@cacheit +@timeit +def sign(e, x): + """ + Returns a sign of an expression e(x) for x->oo. + + :: + + e > 0 for x sufficiently large ... 1 + e == 0 for x sufficiently large ... 0 + e < 0 for x sufficiently large ... -1 + + The result of this function is currently undefined if e changes sign + arbitrarily often for arbitrarily large x (e.g. sin(x)). + + Note that this returns zero only if e is *constantly* zero + for x sufficiently large. [If e is constant, of course, this is just + the same thing as the sign of e.] + """ + if not isinstance(e, Basic): + raise TypeError("e should be an instance of Basic") + + if e.is_positive: + return 1 + elif e.is_negative: + return -1 + elif e.is_zero: + return 0 + + elif not e.has(x): + from sympy.simplify import logcombine + e = logcombine(e) + return _sign(e) + elif e == x: + return 1 + elif e.is_Mul: + a, b = e.as_two_terms() + sa = sign(a, x) + if not sa: + return 0 + return sa * sign(b, x) + elif isinstance(e, exp): + return 1 + elif e.is_Pow: + if e.base == S.Exp1: + return 1 + s = sign(e.base, x) + if s == 1: + return 1 + if e.exp.is_Integer: + return s**e.exp + elif isinstance(e, log) and e.args[0].is_positive: + return sign(e.args[0] - 1, x) + + # if all else fails, do it the hard way + c0, e0 = mrv_leadterm(e, x) + return sign(c0, x) + + +@debug +@timeit +@cacheit +def limitinf(e, x): + """Limit e(x) for x-> oo.""" + # rewrite e in terms of tractable functions only + + old = e + if not e.has(x): + return e # e is a constant + from sympy.simplify.powsimp import powdenest + from sympy.calculus.util import AccumBounds + if e.has(Order): + e = e.expand().removeO() + if not x.is_positive or x.is_integer: + # We make sure that x.is_positive is True and x.is_integer is None + # so we get all the correct mathematical behavior from the expression. + # We need a fresh variable. + p = Dummy('p', positive=True) + e = e.subs(x, p) + x = p + e = e.rewrite('tractable', deep=True, limitvar=x) + e = powdenest(e) + if isinstance(e, AccumBounds): + if mrv_leadterm(e.min, x) != mrv_leadterm(e.max, x): + raise NotImplementedError + c0, e0 = mrv_leadterm(e.min, x) + else: + c0, e0 = mrv_leadterm(e, x) + sig = sign(e0, x) + if sig == 1: + return S.Zero # e0>0: lim f = 0 + elif sig == -1: # e0<0: lim f = +-oo (the sign depends on the sign of c0) + if c0.match(I*Wild("a", exclude=[I])): + return c0*oo + s = sign(c0, x) + # the leading term shouldn't be 0: + if s == 0: + raise ValueError("Leading term should not be 0") + return s*oo + elif sig == 0: + if c0 == old: + c0 = c0.cancel() + return limitinf(c0, x) # e0=0: lim f = lim c0 + else: + raise ValueError("{} could not be evaluated".format(sig)) + + +def moveup2(s, x): + r = SubsSet() + for expr, var in s.items(): + r[expr.xreplace({x: exp(x)})] = var + for var, expr in s.rewrites.items(): + r.rewrites[var] = s.rewrites[var].xreplace({x: exp(x)}) + return r + + +def moveup(l, x): + return [e.xreplace({x: exp(x)}) for e in l] + + +@debug +@timeit +@cacheit +def mrv_leadterm(e, x): + """Returns (c0, e0) for e.""" + Omega = SubsSet() + if not e.has(x): + return (e, S.Zero) + if Omega == SubsSet(): + Omega, exps = mrv(e, x) + if not Omega: + # e really does not depend on x after simplification + return exps, S.Zero + if x in Omega: + # move the whole omega up (exponentiate each term): + Omega_up = moveup2(Omega, x) + exps_up = moveup([exps], x)[0] + # NOTE: there is no need to move this down! + Omega = Omega_up + exps = exps_up + # + # The positive dummy, w, is used here so log(w*2) etc. will expand; + # a unique dummy is needed in this algorithm + # + # For limits of complex functions, the algorithm would have to be + # improved, or just find limits of Re and Im components separately. + # + w = Dummy("w", positive=True) + f, logw = rewrite(exps, Omega, x, w) + + # Ensure expressions of the form exp(log(...)) don't get simplified automatically in the previous steps. + # see: https://github.com/sympy/sympy/issues/15323#issuecomment-478639399 + f = f.replace(lambda f: f.is_Pow and f.has(x), lambda f: exp(log(f.base)*f.exp)) + + try: + lt = f.leadterm(w, logx=logw) + except (NotImplementedError, PoleError, ValueError): + n0 = 1 + _series = Order(1) + incr = S.One + while _series.is_Order: + _series = f._eval_nseries(w, n=n0+incr, logx=logw) + incr *= 2 + series = _series.expand().removeO() + try: + lt = series.leadterm(w, logx=logw) + except (NotImplementedError, PoleError, ValueError): + lt = f.as_coeff_exponent(w) + if lt[0].has(w): + base = f.as_base_exp()[0].as_coeff_exponent(w) + ex = f.as_base_exp()[1] + lt = (base[0]**ex, base[1]*ex) + return (lt[0].subs(log(w), logw), lt[1]) + + +def build_expression_tree(Omega, rewrites): + r""" Helper function for rewrite. + + We need to sort Omega (mrv set) so that we replace an expression before + we replace any expression in terms of which it has to be rewritten:: + + e1 ---> e2 ---> e3 + \ + -> e4 + + Here we can do e1, e2, e3, e4 or e1, e2, e4, e3. + To do this we assemble the nodes into a tree, and sort them by height. + + This function builds the tree, rewrites then sorts the nodes. + """ + class Node: + def __init__(self): + self.before = [] + self.expr = None + self.var = None + def ht(self): + return reduce(lambda x, y: x + y, + [x.ht() for x in self.before], 1) + nodes = {} + for expr, v in Omega: + n = Node() + n.var = v + n.expr = expr + nodes[v] = n + for _, v in Omega: + if v in rewrites: + n = nodes[v] + r = rewrites[v] + for _, v2 in Omega: + if r.has(v2): + n.before.append(nodes[v2]) + + return nodes + + +@debug +@timeit +def rewrite(e, Omega, x, wsym): + """e(x) ... the function + Omega ... the mrv set + wsym ... the symbol which is going to be used for w + + Returns the rewritten e in terms of w and log(w). See test_rewrite1() + for examples and correct results. + """ + + from sympy import AccumBounds + if not isinstance(Omega, SubsSet): + raise TypeError("Omega should be an instance of SubsSet") + if len(Omega) == 0: + raise ValueError("Length cannot be 0") + # all items in Omega must be exponentials + for t in Omega.keys(): + if not isinstance(t, exp): + raise ValueError("Value should be exp") + rewrites = Omega.rewrites + Omega = list(Omega.items()) + + nodes = build_expression_tree(Omega, rewrites) + Omega.sort(key=lambda x: nodes[x[1]].ht(), reverse=True) + + # make sure we know the sign of each exp() term; after the loop, + # g is going to be the "w" - the simplest one in the mrv set + for g, _ in Omega: + sig = sign(g.exp, x) + if sig != 1 and sig != -1 and not sig.has(AccumBounds): + raise NotImplementedError('Result depends on the sign of %s' % sig) + if sig == 1: + wsym = 1/wsym # if g goes to oo, substitute 1/w + # O2 is a list, which results by rewriting each item in Omega using "w" + O2 = [] + denominators = [] + for f, var in Omega: + c = limitinf(f.exp/g.exp, x) + if c.is_Rational: + denominators.append(c.q) + arg = f.exp + if var in rewrites: + if not isinstance(rewrites[var], exp): + raise ValueError("Value should be exp") + arg = rewrites[var].args[0] + O2.append((var, exp((arg - c*g.exp))*wsym**c)) + + # Remember that Omega contains subexpressions of "e". So now we find + # them in "e" and substitute them for our rewriting, stored in O2 + + # the following powsimp is necessary to automatically combine exponentials, + # so that the .xreplace() below succeeds: + # TODO this should not be necessary + from sympy.simplify.powsimp import powsimp + f = powsimp(e, deep=True, combine='exp') + for a, b in O2: + f = f.xreplace({a: b}) + + for _, var in Omega: + assert not f.has(var) + + # finally compute the logarithm of w (logw). + logw = g.exp + if sig == 1: + logw = -logw # log(w)->log(1/w)=-log(w) + + # Some parts of SymPy have difficulty computing series expansions with + # non-integral exponents. The following heuristic improves the situation: + exponent = reduce(ilcm, denominators, 1) + f = f.subs({wsym: wsym**exponent}) + logw /= exponent + + # bottom_up function is required for a specific case - when f is + # -exp(p/(p + 1)) + exp(-p**2/(p + 1) + p). No current simplification + # methods reduce this to 0 while not expanding polynomials. + f = bottom_up(f, lambda w: getattr(w, 'normal', lambda: w)()) + + return f, logw + + +def gruntz(e, z, z0, dir="+"): + """ + Compute the limit of e(z) at the point z0 using the Gruntz algorithm. + + Explanation + =========== + + ``z0`` can be any expression, including oo and -oo. + + For ``dir="+"`` (default) it calculates the limit from the right + (z->z0+) and for ``dir="-"`` the limit from the left (z->z0-). For infinite z0 + (oo or -oo), the dir argument does not matter. + + This algorithm is fully described in the module docstring in the gruntz.py + file. It relies heavily on the series expansion. Most frequently, gruntz() + is only used if the faster limit() function (which uses heuristics) fails. + """ + if not z.is_symbol: + raise NotImplementedError("Second argument must be a Symbol") + + # convert all limits to the limit z->oo; sign of z is handled in limitinf + r = None + if z0 in (oo, I*oo): + e0 = e + elif z0 in (-oo, -I*oo): + e0 = e.subs(z, -z) + else: + if str(dir) == "-": + e0 = e.subs(z, z0 - 1/z) + elif str(dir) == "+": + e0 = e.subs(z, z0 + 1/z) + else: + raise NotImplementedError("dir must be '+' or '-'") + + r = limitinf(e0, z) + + # This is a bit of a heuristic for nice results... we always rewrite + # tractable functions in terms of familiar intractable ones. + # It might be nicer to rewrite the exactly to what they were initially, + # but that would take some work to implement. + return r.rewrite('intractable', deep=True) diff --git a/phivenv/Lib/site-packages/sympy/series/kauers.py b/phivenv/Lib/site-packages/sympy/series/kauers.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9645ff15ee5ae3c1d1c8709f76aed1b366f50a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/kauers.py @@ -0,0 +1,51 @@ +def finite_diff(expression, variable, increment=1): + """ + Takes as input a polynomial expression and the variable used to construct + it and returns the difference between function's value when the input is + incremented to 1 and the original function value. If you want an increment + other than one supply it as a third argument. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.series.kauers import finite_diff + >>> finite_diff(x**2, x) + 2*x + 1 + >>> finite_diff(y**3 + 2*y**2 + 3*y + 4, y) + 3*y**2 + 7*y + 6 + >>> finite_diff(x**2 + 3*x + 8, x, 2) + 4*x + 10 + >>> finite_diff(z**3 + 8*z, z, 3) + 9*z**2 + 27*z + 51 + """ + expression = expression.expand() + expression2 = expression.subs(variable, variable + increment) + expression2 = expression2.expand() + return expression2 - expression + +def finite_diff_kauers(sum): + """ + Takes as input a Sum instance and returns the difference between the sum + with the upper index incremented by 1 and the original sum. For example, + if S(n) is a sum, then finite_diff_kauers will return S(n + 1) - S(n). + + Examples + ======== + + >>> from sympy.series.kauers import finite_diff_kauers + >>> from sympy import Sum + >>> from sympy.abc import x, y, m, n, k + >>> finite_diff_kauers(Sum(k, (k, 1, n))) + n + 1 + >>> finite_diff_kauers(Sum(1/k, (k, 1, n))) + 1/(n + 1) + >>> finite_diff_kauers(Sum((x*y**2), (x, 1, n), (y, 1, m))) + (m + 1)**2*(n + 1) + >>> finite_diff_kauers(Sum((x*y), (x, 1, m), (y, 1, n))) + (m + 1)*(n + 1) + """ + function = sum.function + for l in sum.limits: + function = function.subs(l[0], l[- 1] + 1) + return function diff --git a/phivenv/Lib/site-packages/sympy/series/limits.py b/phivenv/Lib/site-packages/sympy/series/limits.py new file mode 100644 index 0000000000000000000000000000000000000000..e15f7a1243452075a76553903cabf60e43942d8c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/limits.py @@ -0,0 +1,394 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core import S, Symbol, Add, sympify, Expr, PoleError, Mul +from sympy.core.exprtools import factor_terms +from sympy.core.numbers import Float, _illegal +from sympy.core.function import AppliedUndef +from sympy.core.symbol import Dummy +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (Abs, sign, arg, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.special.gamma_functions import gamma +from sympy.polys import PolynomialError, factor +from sympy.series.order import Order +from .gruntz import gruntz + +def limit(e, z, z0, dir="+"): + """Computes the limit of ``e(z)`` at the point ``z0``. + + Parameters + ========== + + e : expression, the limit of which is to be taken + + z : symbol representing the variable in the limit. + Other symbols are treated as constants. Multivariate limits + are not supported. + + z0 : the value toward which ``z`` tends. Can be any expression, + including ``oo`` and ``-oo``. + + dir : string, optional (default: "+") + The limit is bi-directional if ``dir="+-"``, from the right + (z->z0+) if ``dir="+"``, and from the left (z->z0-) if + ``dir="-"``. For infinite ``z0`` (``oo`` or ``-oo``), the ``dir`` + argument is determined from the direction of the infinity + (i.e., ``dir="-"`` for ``oo``). + + Examples + ======== + + >>> from sympy import limit, sin, oo + >>> from sympy.abc import x + >>> limit(sin(x)/x, x, 0) + 1 + >>> limit(1/x, x, 0) # default dir='+' + oo + >>> limit(1/x, x, 0, dir="-") + -oo + >>> limit(1/x, x, 0, dir='+-') + zoo + >>> limit(1/x, x, oo) + 0 + + Notes + ===== + + First we try some heuristics for easy and frequent cases like "x", "1/x", + "x**2" and similar, so that it's fast. For all other cases, we use the + Gruntz algorithm (see the gruntz() function). + + See Also + ======== + + limit_seq : returns the limit of a sequence. + """ + + return Limit(e, z, z0, dir).doit(deep=False) + + +def heuristics(e, z, z0, dir): + """Computes the limit of an expression term-wise. + Parameters are the same as for the ``limit`` function. + Works with the arguments of expression ``e`` one by one, computing + the limit of each and then combining the results. This approach + works only for simple limits, but it is fast. + """ + + rv = None + if z0 is S.Infinity: + rv = limit(e.subs(z, 1/z), z, S.Zero, "+") + if isinstance(rv, Limit): + return + elif (e.is_Mul or e.is_Add or e.is_Pow or (e.is_Function and not isinstance(e, AppliedUndef))): + r = [] + from sympy.simplify.simplify import together + for a in e.args: + l = limit(a, z, z0, dir) + if l.has(S.Infinity) and l.is_finite is None: + if isinstance(e, Add): + m = factor_terms(e) + if not isinstance(m, Mul): # try together + m = together(m) + if not isinstance(m, Mul): # try factor if the previous methods failed + m = factor(e) + if isinstance(m, Mul): + return heuristics(m, z, z0, dir) + return + return + elif isinstance(l, Limit): + return + elif l is S.NaN: + return + else: + r.append(l) + if r: + rv = e.func(*r) + if rv is S.NaN and e.is_Mul and any(isinstance(rr, AccumBounds) for rr in r): + r2 = [] + e2 = [] + for ii, rval in enumerate(r): + if isinstance(rval, AccumBounds): + r2.append(rval) + else: + e2.append(e.args[ii]) + + if len(e2) > 0: + e3 = Mul(*e2).simplify() + l = limit(e3, z, z0, dir) + rv = l * Mul(*r2) + + if rv is S.NaN: + try: + from sympy.simplify.ratsimp import ratsimp + rat_e = ratsimp(e) + except PolynomialError: + return + if rat_e is S.NaN or rat_e == e: + return + return limit(rat_e, z, z0, dir) + return rv + + +class Limit(Expr): + """Represents an unevaluated limit. + + Examples + ======== + + >>> from sympy import Limit, sin + >>> from sympy.abc import x + >>> Limit(sin(x)/x, x, 0) + Limit(sin(x)/x, x, 0, dir='+') + >>> Limit(1/x, x, 0, dir="-") + Limit(1/x, x, 0, dir='-') + + """ + + def __new__(cls, e, z, z0, dir="+"): + e = sympify(e) + z = sympify(z) + z0 = sympify(z0) + + if z0 in (S.Infinity, S.ImaginaryUnit*S.Infinity): + dir = "-" + elif z0 in (S.NegativeInfinity, S.ImaginaryUnit*S.NegativeInfinity): + dir = "+" + + if(z0.has(z)): + raise NotImplementedError("Limits approaching a variable point are" + " not supported (%s -> %s)" % (z, z0)) + if isinstance(dir, str): + dir = Symbol(dir) + elif not isinstance(dir, Symbol): + raise TypeError("direction must be of type basestring or " + "Symbol, not %s" % type(dir)) + if str(dir) not in ('+', '-', '+-'): + raise ValueError("direction must be one of '+', '-' " + "or '+-', not %s" % dir) + + obj = Expr.__new__(cls) + obj._args = (e, z, z0, dir) + return obj + + + @property + def free_symbols(self): + e = self.args[0] + isyms = e.free_symbols + isyms.difference_update(self.args[1].free_symbols) + isyms.update(self.args[2].free_symbols) + return isyms + + + def pow_heuristics(self, e): + _, z, z0, _ = self.args + b1, e1 = e.base, e.exp + if not b1.has(z): + res = limit(e1*log(b1), z, z0) + return exp(res) + + ex_lim = limit(e1, z, z0) + base_lim = limit(b1, z, z0) + + if base_lim is S.One: + if ex_lim in (S.Infinity, S.NegativeInfinity): + res = limit(e1*(b1 - 1), z, z0) + return exp(res) + if base_lim is S.NegativeInfinity and ex_lim is S.Infinity: + return S.ComplexInfinity + + + def doit(self, **hints): + """Evaluates the limit. + + Parameters + ========== + + deep : bool, optional (default: True) + Invoke the ``doit`` method of the expressions involved before + taking the limit. + + hints : optional keyword arguments + To be passed to ``doit`` methods; only used if deep is True. + """ + + e, z, z0, dir = self.args + + if str(dir) == '+-': + r = limit(e, z, z0, dir='+') + l = limit(e, z, z0, dir='-') + if isinstance(r, Limit) and isinstance(l, Limit): + if r.args[0] == l.args[0]: + return self + if r == l: + return l + if r.is_infinite and l.is_infinite: + return S.ComplexInfinity + raise ValueError("The limit does not exist since " + "left hand limit = %s and right hand limit = %s" + % (l, r)) + + if z0 is S.ComplexInfinity: + raise NotImplementedError("Limits at complex " + "infinity are not implemented") + + if z0.is_infinite: + cdir = sign(z0) + cdir = cdir/abs(cdir) + e = e.subs(z, cdir*z) + dir = "-" + z0 = S.Infinity + + if hints.get('deep', True): + e = e.doit(**hints) + z = z.doit(**hints) + z0 = z0.doit(**hints) + + if e == z: + return z0 + + if not e.has(z): + return e + + if z0 is S.NaN: + return S.NaN + + if e.has(*_illegal): + return self + + if e.is_Order: + return Order(limit(e.expr, z, z0), *e.args[1:]) + + cdir = S.Zero + if str(dir) == "+": + cdir = S.One + elif str(dir) == "-": + cdir = S.NegativeOne + + def set_signs(expr): + if not expr.args: + return expr + newargs = tuple(set_signs(arg) for arg in expr.args) + if newargs != expr.args: + expr = expr.func(*newargs) + abs_flag = isinstance(expr, Abs) + arg_flag = isinstance(expr, arg) + sign_flag = isinstance(expr, sign) + if abs_flag or sign_flag or arg_flag: + try: + sig = limit(expr.args[0], z, z0, dir) + if sig.is_zero: + sig = limit(1/expr.args[0], z, z0, dir) + except NotImplementedError: + return expr + else: + if sig.is_extended_real: + if (sig < 0) == True: + return (-expr.args[0] if abs_flag else + S.NegativeOne if sign_flag else S.Pi) + elif (sig > 0) == True: + return (expr.args[0] if abs_flag else + S.One if sign_flag else S.Zero) + return expr + + if e.has(Float): + # Convert floats like 0.5 to exact SymPy numbers like S.Half, to + # prevent rounding errors which can lead to unexpected execution + # of conditional blocks that work on comparisons + # Also see comments in https://github.com/sympy/sympy/issues/19453 + from sympy.simplify.simplify import nsimplify + e = nsimplify(e) + e = set_signs(e) + + + if e.is_meromorphic(z, z0): + if z0 is S.Infinity: + newe = e.subs(z, 1/z) + # cdir changes sign as oo- should become 0+ + cdir = -cdir + else: + newe = e.subs(z, z + z0) + try: + coeff, ex = newe.leadterm(z, cdir=cdir) + except ValueError: + pass + else: + if ex > 0: + return S.Zero + elif ex == 0: + return coeff + if cdir == 1 or not(int(ex) & 1): + return S.Infinity*sign(coeff) + elif cdir == -1: + return S.NegativeInfinity*sign(coeff) + else: + return S.ComplexInfinity + + if z0 is S.Infinity: + if e.is_Mul: + e = factor_terms(e) + dummy = Dummy('z', positive=z.is_positive, negative=z.is_negative, real=z.is_real) + newe = e.subs(z, 1/dummy) + # cdir changes sign as oo- should become 0+ + cdir = -cdir + newz = dummy + else: + newe = e.subs(z, z + z0) + newz = z + try: + coeff, ex = newe.leadterm(newz, cdir=cdir) + except (ValueError, NotImplementedError, PoleError): + # The NotImplementedError catching is for custom functions + from sympy.simplify.powsimp import powsimp + e = powsimp(e) + if e.is_Pow: + r = self.pow_heuristics(e) + if r is not None: + return r + try: + coeff = newe.as_leading_term(newz, cdir=cdir) + if coeff != newe and (coeff.has(exp) or coeff.has(S.Exp1)): + return gruntz(coeff, newz, 0, "-" if re(cdir).is_negative else "+") + except (ValueError, NotImplementedError, PoleError): + pass + else: + if isinstance(coeff, AccumBounds) and ex == S.Zero: + return coeff + if coeff.has(S.Infinity, S.NegativeInfinity, S.ComplexInfinity, S.NaN): + return self + if not coeff.has(newz): + if ex.is_positive: + return S.Zero + elif ex == 0: + return coeff + elif ex.is_negative: + if cdir == 1: + return S.Infinity*sign(coeff) + elif cdir == -1: + return S.NegativeInfinity*sign(coeff)*S.NegativeOne**(S.One + ex) + else: + return S.ComplexInfinity + else: + raise NotImplementedError("Not sure of sign of %s" % ex) + + # gruntz fails on factorials but works with the gamma function + # If no factorial term is present, e should remain unchanged. + # factorial is defined to be zero for negative inputs (which + # differs from gamma) so only rewrite for non-negative z0. + if z0.is_extended_nonnegative: + e = e.rewrite(factorial, gamma) + + l = None + + try: + r = gruntz(e, z, z0, dir) + if r is S.NaN or l is S.NaN: + raise PoleError() + except (PoleError, ValueError): + if l is not None: + raise + r = heuristics(e, z, z0, dir) + if r is None: + return self + + return r diff --git a/phivenv/Lib/site-packages/sympy/series/limitseq.py b/phivenv/Lib/site-packages/sympy/series/limitseq.py new file mode 100644 index 0000000000000000000000000000000000000000..ceac4e7b63bfc09d9dfc26c12c7d2acc8b8d44da --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/limitseq.py @@ -0,0 +1,257 @@ +"""Limits of sequences""" + +from sympy.calculus.accumulationbounds import AccumulationBounds +from sympy.core.add import Add +from sympy.core.function import PoleError +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.numbers import fibonacci +from sympy.functions.combinatorial.factorials import factorial, subfactorial +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.series.limits import Limit + + +def difference_delta(expr, n=None, step=1): + """Difference Operator. + + Explanation + =========== + + Discrete analog of differential operator. Given a sequence x[n], + returns the sequence x[n + step] - x[n]. + + Examples + ======== + + >>> from sympy import difference_delta as dd + >>> from sympy.abc import n + >>> dd(n*(n + 1), n) + 2*n + 2 + >>> dd(n*(n + 1), n, 2) + 4*n + 6 + + References + ========== + + .. [1] https://reference.wolfram.com/language/ref/DifferenceDelta.html + """ + expr = sympify(expr) + + if n is None: + f = expr.free_symbols + if len(f) == 1: + n = f.pop() + elif len(f) == 0: + return S.Zero + else: + raise ValueError("Since there is more than one variable in the" + " expression, a variable must be supplied to" + " take the difference of %s" % expr) + step = sympify(step) + if step.is_number is False or step.is_finite is False: + raise ValueError("Step should be a finite number.") + + if hasattr(expr, '_eval_difference_delta'): + result = expr._eval_difference_delta(n, step) + if result: + return result + + return expr.subs(n, n + step) - expr + + +def dominant(expr, n): + """Finds the dominant term in a sum, that is a term that dominates + every other term. + + Explanation + =========== + + If limit(a/b, n, oo) is oo then a dominates b. + If limit(a/b, n, oo) is 0 then b dominates a. + Otherwise, a and b are comparable. + + If there is no unique dominant term, then returns ``None``. + + Examples + ======== + + >>> from sympy import Sum + >>> from sympy.series.limitseq import dominant + >>> from sympy.abc import n, k + >>> dominant(5*n**3 + 4*n**2 + n + 1, n) + 5*n**3 + >>> dominant(2**n + Sum(k, (k, 0, n)), n) + 2**n + + See Also + ======== + + sympy.series.limitseq.dominant + """ + terms = Add.make_args(expr.expand(func=True)) + term0 = terms[-1] + comp = [term0] # comparable terms + for t in terms[:-1]: + r = term0/t + e = r.gammasimp() + if e == r: + e = r.factor() + l = limit_seq(e, n) + if l is None: + return None + elif l.is_zero: + term0 = t + comp = [term0] + elif l not in [S.Infinity, S.NegativeInfinity]: + comp.append(t) + if len(comp) > 1: + return None + return term0 + + +def _limit_inf(expr, n): + try: + return Limit(expr, n, S.Infinity).doit(deep=False) + except (NotImplementedError, PoleError): + return None + + +def _limit_seq(expr, n, trials): + from sympy.concrete.summations import Sum + + for i in range(trials): + if not expr.has(Sum): + result = _limit_inf(expr, n) + if result is not None: + return result + + num, den = expr.as_numer_denom() + if not den.has(n) or not num.has(n): + result = _limit_inf(expr.doit(), n) + if result is not None: + return result + return None + + num, den = (difference_delta(t.expand(), n) for t in [num, den]) + expr = (num / den).gammasimp() + + if not expr.has(Sum): + result = _limit_inf(expr, n) + if result is not None: + return result + + num, den = expr.as_numer_denom() + + num = dominant(num, n) + if num is None: + return None + + den = dominant(den, n) + if den is None: + return None + + expr = (num / den).gammasimp() + + +def limit_seq(expr, n=None, trials=5): + """Finds the limit of a sequence as index ``n`` tends to infinity. + + Parameters + ========== + + expr : Expr + SymPy expression for the ``n-th`` term of the sequence + n : Symbol, optional + The index of the sequence, an integer that tends to positive + infinity. If None, inferred from the expression unless it has + multiple symbols. + trials: int, optional + The algorithm is highly recursive. ``trials`` is a safeguard from + infinite recursion in case the limit is not easily computed by the + algorithm. Try increasing ``trials`` if the algorithm returns ``None``. + + Admissible Terms + ================ + + The algorithm is designed for sequences built from rational functions, + indefinite sums, and indefinite products over an indeterminate n. Terms of + alternating sign are also allowed, but more complex oscillatory behavior is + not supported. + + Examples + ======== + + >>> from sympy import limit_seq, Sum, binomial + >>> from sympy.abc import n, k, m + >>> limit_seq((5*n**3 + 3*n**2 + 4) / (3*n**3 + 4*n - 5), n) + 5/3 + >>> limit_seq(binomial(2*n, n) / Sum(binomial(2*k, k), (k, 1, n)), n) + 3/4 + >>> limit_seq(Sum(k**2 * Sum(2**m/m, (m, 1, k)), (k, 1, n)) / (2**n*n), n) + 4 + + See Also + ======== + + sympy.series.limitseq.dominant + + References + ========== + + .. [1] Computing Limits of Sequences - Manuel Kauers + """ + + from sympy.concrete.summations import Sum + if n is None: + free = expr.free_symbols + if len(free) == 1: + n = free.pop() + elif not free: + return expr + else: + raise ValueError("Expression has more than one variable. " + "Please specify a variable.") + elif n not in expr.free_symbols: + return expr + + expr = expr.rewrite(fibonacci, S.GoldenRatio) + expr = expr.rewrite(factorial, subfactorial, gamma) + n_ = Dummy("n", integer=True, positive=True) + n1 = Dummy("n", odd=True, positive=True) + n2 = Dummy("n", even=True, positive=True) + + # If there is a negative term raised to a power involving n, or a + # trigonometric function, then consider even and odd n separately. + powers = (p.as_base_exp() for p in expr.atoms(Pow)) + if (any(b.is_negative and e.has(n) for b, e in powers) or + expr.has(cos, sin)): + L1 = _limit_seq(expr.xreplace({n: n1}), n1, trials) + if L1 is not None: + L2 = _limit_seq(expr.xreplace({n: n2}), n2, trials) + if L1 != L2: + if L1.is_comparable and L2.is_comparable: + return AccumulationBounds(Min(L1, L2), Max(L1, L2)) + else: + return None + else: + L1 = _limit_seq(expr.xreplace({n: n_}), n_, trials) + if L1 is not None: + return L1 + else: + if expr.is_Add: + limits = [limit_seq(term, n, trials) for term in expr.args] + if any(result is None for result in limits): + return None + else: + return Add(*limits) + # Maybe the absolute value is easier to deal with (though not if + # it has a Sum). If it tends to 0, the limit is 0. + elif not expr.has(Sum): + lim = _limit_seq(Abs(expr.xreplace({n: n_})), n_, trials) + if lim is not None and lim.is_zero: + return S.Zero diff --git a/phivenv/Lib/site-packages/sympy/series/order.py b/phivenv/Lib/site-packages/sympy/series/order.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfd4309c2b7094ce02feab129e5f051c442d8cd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/order.py @@ -0,0 +1,522 @@ +from sympy.core import S, sympify, Expr, Dummy, Add, Mul +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.function import Function, PoleError, expand_power_base, expand_log +from sympy.core.sorting import default_sort_key +from sympy.functions.elementary.exponential import exp, log +from sympy.sets.sets import Complement +from sympy.utilities.iterables import uniq, is_sequence + + +class Order(Expr): + r""" Represents the limiting behavior of some function. + + Explanation + =========== + + The order of a function characterizes the function based on the limiting + behavior of the function as it goes to some limit. Only taking the limit + point to be a number is currently supported. This is expressed in + big O notation [1]_. + + The formal definition for the order of a function `g(x)` about a point `a` + is such that `g(x) = O(f(x))` as `x \rightarrow a` if and only if there + exists a `\delta > 0` and an `M > 0` such that `|g(x)| \leq M|f(x)|` for + `|x-a| < \delta`. This is equivalent to `\limsup_{x \rightarrow a} + |g(x)/f(x)| < \infty`. + + Let's illustrate it on the following example by taking the expansion of + `\sin(x)` about 0: + + .. math :: + \sin(x) = x - x^3/3! + O(x^5) + + where in this case `O(x^5) = x^5/5! - x^7/7! + \cdots`. By the definition + of `O`, there is a `\delta > 0` and an `M` such that: + + .. math :: + |x^5/5! - x^7/7! + ....| <= M|x^5| \text{ for } |x| < \delta + + or by the alternate definition: + + .. math :: + \lim_{x \rightarrow 0} | (x^5/5! - x^7/7! + ....) / x^5| < \infty + + which surely is true, because + + .. math :: + \lim_{x \rightarrow 0} | (x^5/5! - x^7/7! + ....) / x^5| = 1/5! + + + As it is usually used, the order of a function can be intuitively thought + of representing all terms of powers greater than the one specified. For + example, `O(x^3)` corresponds to any terms proportional to `x^3, + x^4,\ldots` and any higher power. For a polynomial, this leaves terms + proportional to `x^2`, `x` and constants. + + Examples + ======== + + >>> from sympy import O, oo, cos, pi + >>> from sympy.abc import x, y + + >>> O(x + x**2) + O(x) + >>> O(x + x**2, (x, 0)) + O(x) + >>> O(x + x**2, (x, oo)) + O(x**2, (x, oo)) + + >>> O(1 + x*y) + O(1, x, y) + >>> O(1 + x*y, (x, 0), (y, 0)) + O(1, x, y) + >>> O(1 + x*y, (x, oo), (y, oo)) + O(x*y, (x, oo), (y, oo)) + + >>> O(1) in O(1, x) + True + >>> O(1, x) in O(1) + False + >>> O(x) in O(1, x) + True + >>> O(x**2) in O(x) + True + + >>> O(x)*x + O(x**2) + >>> O(x) - O(x) + O(x) + >>> O(cos(x)) + O(1) + >>> O(cos(x), (x, pi/2)) + O(x - pi/2, (x, pi/2)) + + References + ========== + + .. [1] `Big O notation `_ + + Notes + ===== + + In ``O(f(x), x)`` the expression ``f(x)`` is assumed to have a leading + term. ``O(f(x), x)`` is automatically transformed to + ``O(f(x).as_leading_term(x),x)``. + + ``O(expr*f(x), x)`` is ``O(f(x), x)`` + + ``O(expr, x)`` is ``O(1)`` + + ``O(0, x)`` is 0. + + Multivariate O is also supported: + + ``O(f(x, y), x, y)`` is transformed to + ``O(f(x, y).as_leading_term(x,y).as_leading_term(y), x, y)`` + + In the multivariate case, it is assumed the limits w.r.t. the various + symbols commute. + + If no symbols are passed then all symbols in the expression are used + and the limit point is assumed to be zero. + + """ + + is_Order = True + + __slots__ = () + + @cacheit + def __new__(cls, expr, *args, **kwargs): + expr = sympify(expr) + + if not args: + if expr.is_Order: + variables = expr.variables + point = expr.point + else: + variables = list(expr.free_symbols) + point = [S.Zero]*len(variables) + else: + args = list(args if is_sequence(args) else [args]) + variables, point = [], [] + if is_sequence(args[0]): + for a in args: + v, p = list(map(sympify, a)) + variables.append(v) + point.append(p) + else: + variables = list(map(sympify, args)) + point = [S.Zero]*len(variables) + + if not all(v.is_symbol for v in variables): + raise TypeError('Variables are not symbols, got %s' % variables) + + if len(list(uniq(variables))) != len(variables): + raise ValueError('Variables are supposed to be unique symbols, got %s' % variables) + + if expr.is_Order: + expr_vp = dict(expr.args[1:]) + new_vp = dict(expr_vp) + vp = dict(zip(variables, point)) + for v, p in vp.items(): + if v in new_vp.keys(): + if p != new_vp[v]: + raise NotImplementedError( + "Mixing Order at different points is not supported.") + else: + new_vp[v] = p + if set(expr_vp.keys()) == set(new_vp.keys()): + return expr + else: + variables = list(new_vp.keys()) + point = [new_vp[v] for v in variables] + + if expr is S.NaN: + return S.NaN + + if any(x in p.free_symbols for x in variables for p in point): + raise ValueError('Got %s as a point.' % point) + + if variables: + if any(p != point[0] for p in point): + raise NotImplementedError( + "Multivariable orders at different points are not supported.") + if point[0] in (S.Infinity, S.Infinity*S.ImaginaryUnit): + s = {k: 1/Dummy() for k in variables} + rs = {1/v: 1/k for k, v in s.items()} + ps = [S.Zero for p in point] + elif point[0] in (S.NegativeInfinity, S.NegativeInfinity*S.ImaginaryUnit): + s = {k: -1/Dummy() for k in variables} + rs = {-1/v: -1/k for k, v in s.items()} + ps = [S.Zero for p in point] + elif point[0] is not S.Zero: + s = {k: Dummy() + point[0] for k in variables} + rs = {(v - point[0]).together(): k - point[0] for k, v in s.items()} + ps = [S.Zero for p in point] + else: + s = () + rs = () + ps = list(point) + + expr = expr.subs(s) + + if expr.is_Add: + expr = expr.factor() + + if s: + args = tuple([r[0] for r in rs.items()]) + else: + args = tuple(variables) + + if len(variables) > 1: + # XXX: better way? We need this expand() to + # workaround e.g: expr = x*(x + y). + # (x*(x + y)).as_leading_term(x, y) currently returns + # x*y (wrong order term!). That's why we want to deal with + # expand()'ed expr (handled in "if expr.is_Add" branch below). + expr = expr.expand() + + old_expr = None + while old_expr != expr: + old_expr = expr + if expr.is_Add: + lst = expr.extract_leading_order(args) + expr = Add(*[f.expr for (e, f) in lst]) + + elif expr: + try: + expr = expr.as_leading_term(*args) + except PoleError: + if isinstance(expr, Function) or\ + all(isinstance(arg, Function) for arg in expr.args): + # It is not possible to simplify an expression + # containing only functions (which raise error on + # call to leading term) further + pass + else: + orders = [] + pts = tuple(zip(args, ps)) + for arg in expr.args: + try: + lt = arg.as_leading_term(*args) + except PoleError: + lt = arg + if lt not in args: + order = Order(lt) + else: + order = Order(lt, *pts) + orders.append(order) + if expr.is_Add: + new_expr = Order(Add(*orders), *pts) + if new_expr.is_Add: + new_expr = Order(Add(*[a.expr for a in new_expr.args]), *pts) + expr = new_expr.expr + elif expr.is_Mul: + expr = Mul(*[a.expr for a in orders]) + elif expr.is_Pow: + e = expr.exp + b = expr.base + expr = exp(e * log(b)) + + # It would probably be better to handle this somewhere + # else. This is needed for a testcase in which there is a + # symbol with the assumptions zero=True. + if expr.is_zero: + expr = S.Zero + else: + expr = expr.as_independent(*args, as_Add=False)[1] + + expr = expand_power_base(expr) + expr = expand_log(expr) + + if len(args) == 1: + # The definition of O(f(x)) symbol explicitly stated that + # the argument of f(x) is irrelevant. That's why we can + # combine some power exponents (only "on top" of the + # expression tree for f(x)), e.g.: + # x**p * (-x)**q -> x**(p+q) for real p, q. + x = args[0] + margs = list(Mul.make_args( + expr.as_independent(x, as_Add=False)[1])) + + for i, t in enumerate(margs): + if t.is_Pow: + b, q = t.args + if b in (x, -x) and q.is_real and not q.has(x): + margs[i] = x**q + elif b.is_Pow and not b.exp.has(x): + b, r = b.args + if b in (x, -x) and r.is_real: + margs[i] = x**(r*q) + elif b.is_Mul and b.args[0] is S.NegativeOne: + b = -b + if b.is_Pow and not b.exp.has(x): + b, r = b.args + if b in (x, -x) and r.is_real: + margs[i] = x**(r*q) + + expr = Mul(*margs) + + expr = expr.subs(rs) + + if expr.is_Order: + expr = expr.expr + + if not expr.has(*variables) and not expr.is_zero: + expr = S.One + + # create Order instance: + vp = dict(zip(variables, point)) + variables.sort(key=default_sort_key) + point = [vp[v] for v in variables] + args = (expr,) + Tuple(*zip(variables, point)) + obj = Expr.__new__(cls, *args) + return obj + + def _eval_nseries(self, x, n, logx, cdir=0): + return self + + @property + def expr(self): + return self.args[0] + + @property + def variables(self): + if self.args[1:]: + return tuple(x[0] for x in self.args[1:]) + else: + return () + + @property + def point(self): + if self.args[1:]: + return tuple(x[1] for x in self.args[1:]) + else: + return () + + @property + def free_symbols(self): + return self.expr.free_symbols | set(self.variables) + + def _eval_power(b, e): + if e.is_Number and e.is_nonnegative: + return b.func(b.expr ** e, *b.args[1:]) + if e == O(1): + return b + return + + def as_expr_variables(self, order_symbols): + if order_symbols is None: + order_symbols = self.args[1:] + else: + if (not all(o[1] == order_symbols[0][1] for o in order_symbols) and + not all(p == self.point[0] for p in self.point)): # pragma: no cover + raise NotImplementedError('Order at points other than 0 ' + 'or oo not supported, got %s as a point.' % self.point) + if order_symbols and order_symbols[0][1] != self.point[0]: + raise NotImplementedError( + "Multiplying Order at different points is not supported.") + order_symbols = dict(order_symbols) + for s, p in dict(self.args[1:]).items(): + if s not in order_symbols.keys(): + order_symbols[s] = p + order_symbols = sorted(order_symbols.items(), key=lambda x: default_sort_key(x[0])) + return self.expr, tuple(order_symbols) + + def removeO(self): + return S.Zero + + def getO(self): + return self + + @cacheit + def contains(self, expr): + r""" + Return True if expr belongs to Order(self.expr, \*self.variables). + Return False if self belongs to expr. + Return None if the inclusion relation cannot be determined + (e.g. when self and expr have different symbols). + """ + expr = sympify(expr) + if expr.is_zero: + return True + if expr is S.NaN: + return False + point = self.point[0] if self.point else S.Zero + if expr.is_Order: + if (any(p != point for p in expr.point) or + any(p != point for p in self.point)): + return None + if expr.expr == self.expr: + # O(1) + O(1), O(1) + O(1, x), etc. + return all(x in self.args[1:] for x in expr.args[1:]) + if expr.expr.is_Add: + return all(self.contains(x) for x in expr.expr.args) + if self.expr.is_Add and point.is_zero: + return any(self.func(x, *self.args[1:]).contains(expr) + for x in self.expr.args) + if self.variables and expr.variables: + common_symbols = tuple( + [s for s in self.variables if s in expr.variables]) + elif self.variables: + common_symbols = self.variables + else: + common_symbols = expr.variables + if not common_symbols: + return None + if (self.expr.is_Pow and len(self.variables) == 1 + and self.variables == expr.variables): + symbol = self.variables[0] + other = expr.expr.as_independent(symbol, as_Add=False)[1] + if (other.is_Pow and other.base == symbol and + self.expr.base == symbol): + if point.is_zero: + rv = (self.expr.exp - other.exp).is_nonpositive + if point.is_infinite: + rv = (self.expr.exp - other.exp).is_nonnegative + if rv is not None: + return rv + + from sympy.simplify.powsimp import powsimp + r = None + ratio = self.expr/expr.expr + ratio = powsimp(ratio, deep=True, combine='exp') + for s in common_symbols: + from sympy.series.limits import Limit + l = Limit(ratio, s, point).doit(heuristics=False) + if not isinstance(l, Limit): + l = l != 0 + else: + l = None + if r is None: + r = l + else: + if r != l: + return + return r + + if self.expr.is_Pow and len(self.variables) == 1: + symbol = self.variables[0] + other = expr.as_independent(symbol, as_Add=False)[1] + if (other.is_Pow and other.base == symbol and + self.expr.base == symbol): + if point.is_zero: + rv = (self.expr.exp - other.exp).is_nonpositive + if point.is_infinite: + rv = (self.expr.exp - other.exp).is_nonnegative + if rv is not None: + return rv + + obj = self.func(expr, *self.args[1:]) + return self.contains(obj) + + def __contains__(self, other): + result = self.contains(other) + if result is None: + raise TypeError('contains did not evaluate to a bool') + return result + + def _eval_subs(self, old, new): + if old in self.variables: + newexpr = self.expr.subs(old, new) + i = self.variables.index(old) + newvars = list(self.variables) + newpt = list(self.point) + if new.is_symbol: + newvars[i] = new + else: + syms = new.free_symbols + if len(syms) == 1 or old in syms: + if old in syms: + var = self.variables[i] + else: + var = syms.pop() + # First, try to substitute self.point in the "new" + # expr to see if this is a fixed point. + # E.g. O(y).subs(y, sin(x)) + from sympy import limit + if new.has(Order) and limit(new.getO().expr, var, new.getO().point[0]) == self.point[i]: + point = new.getO().point[0] + return Order(newexpr, *zip([var], [point])) + else: + point = new.subs(var, self.point[i]) + if point != self.point[i]: + from sympy.solvers.solveset import solveset + d = Dummy() + sol = solveset(old - new.subs(var, d), d) + if isinstance(sol, Complement): + e1 = sol.args[0] + e2 = sol.args[1] + sol = set(e1) - set(e2) + res = [dict(zip((d, ), sol))] + point = d.subs(res[0]).limit(old, self.point[i]) + newvars[i] = var + newpt[i] = point + elif old not in syms: + del newvars[i], newpt[i] + if not syms and new == self.point[i]: + newvars.extend(syms) + newpt.extend([S.Zero]*len(syms)) + else: + return + return Order(newexpr, *zip(newvars, newpt)) + + def _eval_conjugate(self): + expr = self.expr._eval_conjugate() + if expr is not None: + return self.func(expr, *self.args[1:]) + + def _eval_derivative(self, x): + return self.func(self.expr.diff(x), *self.args[1:]) or self + + def _eval_transpose(self): + expr = self.expr._eval_transpose() + if expr is not None: + return self.func(expr, *self.args[1:]) + + def __neg__(self): + return self + +O = Order diff --git a/phivenv/Lib/site-packages/sympy/series/residues.py b/phivenv/Lib/site-packages/sympy/series/residues.py new file mode 100644 index 0000000000000000000000000000000000000000..a426f9e799bd040eea5124f718c2fa43e5de026b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/residues.py @@ -0,0 +1,73 @@ +""" +This module implements the Residue function and related tools for working +with residues. +""" + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.utilities.timeutils import timethis + + +@timethis('residue') +def residue(expr, x, x0): + """ + Finds the residue of ``expr`` at the point x=x0. + + The residue is defined as the coefficient of ``1/(x-x0)`` in the power series + expansion about ``x=x0``. + + Examples + ======== + + >>> from sympy import Symbol, residue, sin + >>> x = Symbol("x") + >>> residue(1/x, x, 0) + 1 + >>> residue(1/x**2, x, 0) + 0 + >>> residue(2/sin(x), x, 0) + 2 + + This function is essential for the Residue Theorem [1]. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Residue_theorem + """ + # The current implementation uses series expansion to + # calculate it. A more general implementation is explained in + # the section 5.6 of the Bronstein's book {M. Bronstein: + # Symbolic Integration I, Springer Verlag (2005)}. For purely + # rational functions, the algorithm is much easier. See + # sections 2.4, 2.5, and 2.7 (this section actually gives an + # algorithm for computing any Laurent series coefficient for + # a rational function). The theory in section 2.4 will help to + # understand why the resultant works in the general algorithm. + # For the definition of a resultant, see section 1.4 (and any + # previous sections for more review). + + from sympy.series.order import Order + from sympy.simplify.radsimp import collect + expr = sympify(expr) + if x0 != 0: + expr = expr.subs(x, x + x0) + for n in (0, 1, 2, 4, 8, 16, 32): + s = expr.nseries(x, n=n) + if not s.has(Order) or s.getn() >= 0: + break + s = collect(s.removeO(), x) + if s.is_Add: + args = s.args + else: + args = [s] + res = S.Zero + for arg in args: + c, m = arg.as_coeff_mul(x) + m = Mul(*m) + if not (m in (S.One, x) or (m.is_Pow and m.exp.is_Integer)): + raise NotImplementedError('term of unexpected form: %s' % m) + if m == 1/x: + res += c + return res diff --git a/phivenv/Lib/site-packages/sympy/series/sequences.py b/phivenv/Lib/site-packages/sympy/series/sequences.py new file mode 100644 index 0000000000000000000000000000000000000000..7787515ddb05afaf34751bf451544935723d0921 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/sequences.py @@ -0,0 +1,1239 @@ +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.decorators import call_highest_priority +from sympy.core.parameters import global_parameters +from sympy.core.function import AppliedUndef, expand +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.relational import Eq +from sympy.core.singleton import S, Singleton +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy, Symbol, Wild +from sympy.core.sympify import sympify +from sympy.matrices import Matrix +from sympy.polys import lcm, factor +from sympy.sets.sets import Interval, Intersection +from sympy.tensor.indexed import Idx +from sympy.utilities.iterables import flatten, is_sequence, iterable + + +############################################################################### +# SEQUENCES # +############################################################################### + + +class SeqBase(Basic): + """Base class for sequences""" + + is_commutative = True + _op_priority = 15 + + @staticmethod + def _start_key(expr): + """Return start (if possible) else S.Infinity. + + adapted from Set._infimum_key + """ + try: + start = expr.start + except NotImplementedError: + start = S.Infinity + return start + + def _intersect_interval(self, other): + """Returns start and stop. + + Takes intersection over the two intervals. + """ + interval = Intersection(self.interval, other.interval) + return interval.inf, interval.sup + + @property + def gen(self): + """Returns the generator for the sequence""" + raise NotImplementedError("(%s).gen" % self) + + @property + def interval(self): + """The interval on which the sequence is defined""" + raise NotImplementedError("(%s).interval" % self) + + @property + def start(self): + """The starting point of the sequence. This point is included""" + raise NotImplementedError("(%s).start" % self) + + @property + def stop(self): + """The ending point of the sequence. This point is included""" + raise NotImplementedError("(%s).stop" % self) + + @property + def length(self): + """Length of the sequence""" + raise NotImplementedError("(%s).length" % self) + + @property + def variables(self): + """Returns a tuple of variables that are bounded""" + return () + + @property + def free_symbols(self): + """ + This method returns the symbols in the object, excluding those + that take on a specific value (i.e. the dummy symbols). + + Examples + ======== + + >>> from sympy import SeqFormula + >>> from sympy.abc import n, m + >>> SeqFormula(m*n**2, (n, 0, 5)).free_symbols + {m} + """ + return ({j for i in self.args for j in i.free_symbols + .difference(self.variables)}) + + @cacheit + def coeff(self, pt): + """Returns the coefficient at point pt""" + if pt < self.start or pt > self.stop: + raise IndexError("Index %s out of bounds %s" % (pt, self.interval)) + return self._eval_coeff(pt) + + def _eval_coeff(self, pt): + raise NotImplementedError("The _eval_coeff method should be added to" + "%s to return coefficient so it is available" + "when coeff calls it." + % self.func) + + def _ith_point(self, i): + """Returns the i'th point of a sequence. + + Explanation + =========== + + If start point is negative infinity, point is returned from the end. + Assumes the first point to be indexed zero. + + Examples + ========= + + >>> from sympy import oo + >>> from sympy.series.sequences import SeqPer + + bounded + + >>> SeqPer((1, 2, 3), (-10, 10))._ith_point(0) + -10 + >>> SeqPer((1, 2, 3), (-10, 10))._ith_point(5) + -5 + + End is at infinity + + >>> SeqPer((1, 2, 3), (0, oo))._ith_point(5) + 5 + + Starts at negative infinity + + >>> SeqPer((1, 2, 3), (-oo, 0))._ith_point(5) + -5 + """ + if self.start is S.NegativeInfinity: + initial = self.stop + else: + initial = self.start + + if self.start is S.NegativeInfinity: + step = -1 + else: + step = 1 + + return initial + i*step + + def _add(self, other): + """ + Should only be used internally. + + Explanation + =========== + + self._add(other) returns a new, term-wise added sequence if self + knows how to add with other, otherwise it returns ``None``. + + ``other`` should only be a sequence object. + + Used within :class:`SeqAdd` class. + """ + return None + + def _mul(self, other): + """ + Should only be used internally. + + Explanation + =========== + + self._mul(other) returns a new, term-wise multiplied sequence if self + knows how to multiply with other, otherwise it returns ``None``. + + ``other`` should only be a sequence object. + + Used within :class:`SeqMul` class. + """ + return None + + def coeff_mul(self, other): + """ + Should be used when ``other`` is not a sequence. Should be + defined to define custom behaviour. + + Examples + ======== + + >>> from sympy import SeqFormula + >>> from sympy.abc import n + >>> SeqFormula(n**2).coeff_mul(2) + SeqFormula(2*n**2, (n, 0, oo)) + + Notes + ===== + + '*' defines multiplication of sequences with sequences only. + """ + return Mul(self, other) + + def __add__(self, other): + """Returns the term-wise addition of 'self' and 'other'. + + ``other`` should be a sequence. + + Examples + ======== + + >>> from sympy import SeqFormula + >>> from sympy.abc import n + >>> SeqFormula(n**2) + SeqFormula(n**3) + SeqFormula(n**3 + n**2, (n, 0, oo)) + """ + if not isinstance(other, SeqBase): + raise TypeError('cannot add sequence and %s' % type(other)) + return SeqAdd(self, other) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self + other + + def __sub__(self, other): + """Returns the term-wise subtraction of ``self`` and ``other``. + + ``other`` should be a sequence. + + Examples + ======== + + >>> from sympy import SeqFormula + >>> from sympy.abc import n + >>> SeqFormula(n**2) - (SeqFormula(n)) + SeqFormula(n**2 - n, (n, 0, oo)) + """ + if not isinstance(other, SeqBase): + raise TypeError('cannot subtract sequence and %s' % type(other)) + return SeqAdd(self, -other) + + @call_highest_priority('__sub__') + def __rsub__(self, other): + return (-self) + other + + def __neg__(self): + """Negates the sequence. + + Examples + ======== + + >>> from sympy import SeqFormula + >>> from sympy.abc import n + >>> -SeqFormula(n**2) + SeqFormula(-n**2, (n, 0, oo)) + """ + return self.coeff_mul(-1) + + def __mul__(self, other): + """Returns the term-wise multiplication of 'self' and 'other'. + + ``other`` should be a sequence. For ``other`` not being a + sequence see :func:`coeff_mul` method. + + Examples + ======== + + >>> from sympy import SeqFormula + >>> from sympy.abc import n + >>> SeqFormula(n**2) * (SeqFormula(n)) + SeqFormula(n**3, (n, 0, oo)) + """ + if not isinstance(other, SeqBase): + raise TypeError('cannot multiply sequence and %s' % type(other)) + return SeqMul(self, other) + + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self * other + + def __iter__(self): + for i in range(self.length): + pt = self._ith_point(i) + yield self.coeff(pt) + + def __getitem__(self, index): + if isinstance(index, int): + index = self._ith_point(index) + return self.coeff(index) + elif isinstance(index, slice): + start, stop = index.start, index.stop + if start is None: + start = 0 + if stop is None: + stop = self.length + return [self.coeff(self._ith_point(i)) for i in + range(start, stop, index.step or 1)] + + def find_linear_recurrence(self,n,d=None,gfvar=None): + r""" + Finds the shortest linear recurrence that satisfies the first n + terms of sequence of order `\leq` ``n/2`` if possible. + If ``d`` is specified, find shortest linear recurrence of order + `\leq` min(d, n/2) if possible. + Returns list of coefficients ``[b(1), b(2), ...]`` corresponding to the + recurrence relation ``x(n) = b(1)*x(n-1) + b(2)*x(n-2) + ...`` + Returns ``[]`` if no recurrence is found. + If gfvar is specified, also returns ordinary generating function as a + function of gfvar. + + Examples + ======== + + >>> from sympy import sequence, sqrt, oo, lucas + >>> from sympy.abc import n, x, y + >>> sequence(n**2).find_linear_recurrence(10, 2) + [] + >>> sequence(n**2).find_linear_recurrence(10) + [3, -3, 1] + >>> sequence(2**n).find_linear_recurrence(10) + [2] + >>> sequence(23*n**4+91*n**2).find_linear_recurrence(10) + [5, -10, 10, -5, 1] + >>> sequence(sqrt(5)*(((1 + sqrt(5))/2)**n - (-(1 + sqrt(5))/2)**(-n))/5).find_linear_recurrence(10) + [1, 1] + >>> sequence(x+y*(-2)**(-n), (n, 0, oo)).find_linear_recurrence(30) + [1/2, 1/2] + >>> sequence(3*5**n + 12).find_linear_recurrence(20,gfvar=x) + ([6, -5], 3*(5 - 21*x)/((x - 1)*(5*x - 1))) + >>> sequence(lucas(n)).find_linear_recurrence(15,gfvar=x) + ([1, 1], (x - 2)/(x**2 + x - 1)) + """ + from sympy.simplify import simplify + x = [simplify(expand(t)) for t in self[:n]] + lx = len(x) + if d is None: + r = lx//2 + else: + r = min(d,lx//2) + coeffs = [] + for l in range(1, r+1): + l2 = 2*l + mlist = [] + for k in range(l): + mlist.append(x[k:k+l]) + m = Matrix(mlist) + if m.det() != 0: + y = simplify(m.LUsolve(Matrix(x[l:l2]))) + if lx == l2: + coeffs = flatten(y[::-1]) + break + mlist = [] + for k in range(l,lx-l): + mlist.append(x[k:k+l]) + m = Matrix(mlist) + if m*y == Matrix(x[l2:]): + coeffs = flatten(y[::-1]) + break + if gfvar is None: + return coeffs + else: + l = len(coeffs) + if l == 0: + return [], None + else: + n, d = x[l-1]*gfvar**(l-1), 1 - coeffs[l-1]*gfvar**l + for i in range(l-1): + n += x[i]*gfvar**i + for j in range(l-i-1): + n -= coeffs[i]*x[j]*gfvar**(i+j+1) + d -= coeffs[i]*gfvar**(i+1) + return coeffs, simplify(factor(n)/factor(d)) + +class EmptySequence(SeqBase, metaclass=Singleton): + """Represents an empty sequence. + + The empty sequence is also available as a singleton as + ``S.EmptySequence``. + + Examples + ======== + + >>> from sympy import EmptySequence, SeqPer + >>> from sympy.abc import x + >>> EmptySequence + EmptySequence + >>> SeqPer((1, 2), (x, 0, 10)) + EmptySequence + SeqPer((1, 2), (x, 0, 10)) + >>> SeqPer((1, 2)) * EmptySequence + EmptySequence + >>> EmptySequence.coeff_mul(-1) + EmptySequence + """ + + @property + def interval(self): + return S.EmptySet + + @property + def length(self): + return S.Zero + + def coeff_mul(self, coeff): + """See docstring of SeqBase.coeff_mul""" + return self + + def __iter__(self): + return iter([]) + + +class SeqExpr(SeqBase): + """Sequence expression class. + + Various sequences should inherit from this class. + + Examples + ======== + + >>> from sympy.series.sequences import SeqExpr + >>> from sympy.abc import x + >>> from sympy import Tuple + >>> s = SeqExpr(Tuple(1, 2, 3), Tuple(x, 0, 10)) + >>> s.gen + (1, 2, 3) + >>> s.interval + Interval(0, 10) + >>> s.length + 11 + + See Also + ======== + + sympy.series.sequences.SeqPer + sympy.series.sequences.SeqFormula + """ + + @property + def gen(self): + return self.args[0] + + @property + def interval(self): + return Interval(self.args[1][1], self.args[1][2]) + + @property + def start(self): + return self.interval.inf + + @property + def stop(self): + return self.interval.sup + + @property + def length(self): + return self.stop - self.start + 1 + + @property + def variables(self): + return (self.args[1][0],) + + +class SeqPer(SeqExpr): + """ + Represents a periodic sequence. + + The elements are repeated after a given period. + + Examples + ======== + + >>> from sympy import SeqPer, oo + >>> from sympy.abc import k + + >>> s = SeqPer((1, 2, 3), (0, 5)) + >>> s.periodical + (1, 2, 3) + >>> s.period + 3 + + For value at a particular point + + >>> s.coeff(3) + 1 + + supports slicing + + >>> s[:] + [1, 2, 3, 1, 2, 3] + + iterable + + >>> list(s) + [1, 2, 3, 1, 2, 3] + + sequence starts from negative infinity + + >>> SeqPer((1, 2, 3), (-oo, 0))[0:6] + [1, 2, 3, 1, 2, 3] + + Periodic formulas + + >>> SeqPer((k, k**2, k**3), (k, 0, oo))[0:6] + [0, 1, 8, 3, 16, 125] + + See Also + ======== + + sympy.series.sequences.SeqFormula + """ + + def __new__(cls, periodical, limits=None): + periodical = sympify(periodical) + + def _find_x(periodical): + free = periodical.free_symbols + if len(periodical.free_symbols) == 1: + return free.pop() + else: + return Dummy('k') + + x, start, stop = None, None, None + if limits is None: + x, start, stop = _find_x(periodical), 0, S.Infinity + if is_sequence(limits, Tuple): + if len(limits) == 3: + x, start, stop = limits + elif len(limits) == 2: + x = _find_x(periodical) + start, stop = limits + + if not isinstance(x, (Symbol, Idx)) or start is None or stop is None: + raise ValueError('Invalid limits given: %s' % str(limits)) + + if start is S.NegativeInfinity and stop is S.Infinity: + raise ValueError("Both the start and end value" + "cannot be unbounded") + + limits = sympify((x, start, stop)) + + if is_sequence(periodical, Tuple): + periodical = sympify(tuple(flatten(periodical))) + else: + raise ValueError("invalid period %s should be something " + "like e.g (1, 2) " % periodical) + + if Interval(limits[1], limits[2]) is S.EmptySet: + return S.EmptySequence + + return Basic.__new__(cls, periodical, limits) + + @property + def period(self): + return len(self.gen) + + @property + def periodical(self): + return self.gen + + def _eval_coeff(self, pt): + if self.start is S.NegativeInfinity: + idx = (self.stop - pt) % self.period + else: + idx = (pt - self.start) % self.period + return self.periodical[idx].subs(self.variables[0], pt) + + def _add(self, other): + """See docstring of SeqBase._add""" + if isinstance(other, SeqPer): + per1, lper1 = self.periodical, self.period + per2, lper2 = other.periodical, other.period + + per_length = lcm(lper1, lper2) + + new_per = [] + for x in range(per_length): + ele1 = per1[x % lper1] + ele2 = per2[x % lper2] + new_per.append(ele1 + ele2) + + start, stop = self._intersect_interval(other) + return SeqPer(new_per, (self.variables[0], start, stop)) + + def _mul(self, other): + """See docstring of SeqBase._mul""" + if isinstance(other, SeqPer): + per1, lper1 = self.periodical, self.period + per2, lper2 = other.periodical, other.period + + per_length = lcm(lper1, lper2) + + new_per = [] + for x in range(per_length): + ele1 = per1[x % lper1] + ele2 = per2[x % lper2] + new_per.append(ele1 * ele2) + + start, stop = self._intersect_interval(other) + return SeqPer(new_per, (self.variables[0], start, stop)) + + def coeff_mul(self, coeff): + """See docstring of SeqBase.coeff_mul""" + coeff = sympify(coeff) + per = [x * coeff for x in self.periodical] + return SeqPer(per, self.args[1]) + + +class SeqFormula(SeqExpr): + """ + Represents sequence based on a formula. + + Elements are generated using a formula. + + Examples + ======== + + >>> from sympy import SeqFormula, oo, Symbol + >>> n = Symbol('n') + >>> s = SeqFormula(n**2, (n, 0, 5)) + >>> s.formula + n**2 + + For value at a particular point + + >>> s.coeff(3) + 9 + + supports slicing + + >>> s[:] + [0, 1, 4, 9, 16, 25] + + iterable + + >>> list(s) + [0, 1, 4, 9, 16, 25] + + sequence starts from negative infinity + + >>> SeqFormula(n**2, (-oo, 0))[0:6] + [0, 1, 4, 9, 16, 25] + + See Also + ======== + + sympy.series.sequences.SeqPer + """ + + def __new__(cls, formula, limits=None): + formula = sympify(formula) + + def _find_x(formula): + free = formula.free_symbols + if len(free) == 1: + return free.pop() + elif not free: + return Dummy('k') + else: + raise ValueError( + " specify dummy variables for %s. If the formula contains" + " more than one free symbol, a dummy variable should be" + " supplied explicitly e.g., SeqFormula(m*n**2, (n, 0, 5))" + % formula) + + x, start, stop = None, None, None + if limits is None: + x, start, stop = _find_x(formula), 0, S.Infinity + if is_sequence(limits, Tuple): + if len(limits) == 3: + x, start, stop = limits + elif len(limits) == 2: + x = _find_x(formula) + start, stop = limits + + if not isinstance(x, (Symbol, Idx)) or start is None or stop is None: + raise ValueError('Invalid limits given: %s' % str(limits)) + + if start is S.NegativeInfinity and stop is S.Infinity: + raise ValueError("Both the start and end value " + "cannot be unbounded") + limits = sympify((x, start, stop)) + + if Interval(limits[1], limits[2]) is S.EmptySet: + return S.EmptySequence + + return Basic.__new__(cls, formula, limits) + + @property + def formula(self): + return self.gen + + def _eval_coeff(self, pt): + d = self.variables[0] + return self.formula.subs(d, pt) + + def _add(self, other): + """See docstring of SeqBase._add""" + if isinstance(other, SeqFormula): + form1, v1 = self.formula, self.variables[0] + form2, v2 = other.formula, other.variables[0] + formula = form1 + form2.subs(v2, v1) + start, stop = self._intersect_interval(other) + return SeqFormula(formula, (v1, start, stop)) + + def _mul(self, other): + """See docstring of SeqBase._mul""" + if isinstance(other, SeqFormula): + form1, v1 = self.formula, self.variables[0] + form2, v2 = other.formula, other.variables[0] + formula = form1 * form2.subs(v2, v1) + start, stop = self._intersect_interval(other) + return SeqFormula(formula, (v1, start, stop)) + + def coeff_mul(self, coeff): + """See docstring of SeqBase.coeff_mul""" + coeff = sympify(coeff) + formula = self.formula * coeff + return SeqFormula(formula, self.args[1]) + + def expand(self, *args, **kwargs): + return SeqFormula(expand(self.formula, *args, **kwargs), self.args[1]) + +class RecursiveSeq(SeqBase): + """ + A finite degree recursive sequence. + + Explanation + =========== + + That is, a sequence a(n) that depends on a fixed, finite number of its + previous values. The general form is + + a(n) = f(a(n - 1), a(n - 2), ..., a(n - d)) + + for some fixed, positive integer d, where f is some function defined by a + SymPy expression. + + Parameters + ========== + + recurrence : SymPy expression defining recurrence + This is *not* an equality, only the expression that the nth term is + equal to. For example, if :code:`a(n) = f(a(n - 1), ..., a(n - d))`, + then the expression should be :code:`f(a(n - 1), ..., a(n - d))`. + + yn : applied undefined function + Represents the nth term of the sequence as e.g. :code:`y(n)` where + :code:`y` is an undefined function and `n` is the sequence index. + + n : symbolic argument + The name of the variable that the recurrence is in, e.g., :code:`n` if + the recurrence function is :code:`y(n)`. + + initial : iterable with length equal to the degree of the recurrence + The initial values of the recurrence. + + start : start value of sequence (inclusive) + + Examples + ======== + + >>> from sympy import Function, symbols + >>> from sympy.series.sequences import RecursiveSeq + >>> y = Function("y") + >>> n = symbols("n") + >>> fib = RecursiveSeq(y(n - 1) + y(n - 2), y(n), n, [0, 1]) + + >>> fib.coeff(3) # Value at a particular point + 2 + + >>> fib[:6] # supports slicing + [0, 1, 1, 2, 3, 5] + + >>> fib.recurrence # inspect recurrence + Eq(y(n), y(n - 2) + y(n - 1)) + + >>> fib.degree # automatically determine degree + 2 + + >>> for x in zip(range(10), fib): # supports iteration + ... print(x) + (0, 0) + (1, 1) + (2, 1) + (3, 2) + (4, 3) + (5, 5) + (6, 8) + (7, 13) + (8, 21) + (9, 34) + + See Also + ======== + + sympy.series.sequences.SeqFormula + + """ + + def __new__(cls, recurrence, yn, n, initial=None, start=0): + if not isinstance(yn, AppliedUndef): + raise TypeError("recurrence sequence must be an applied undefined function" + ", found `{}`".format(yn)) + + if not isinstance(n, Basic) or not n.is_symbol: + raise TypeError("recurrence variable must be a symbol" + ", found `{}`".format(n)) + + if yn.args != (n,): + raise TypeError("recurrence sequence does not match symbol") + + y = yn.func + + k = Wild("k", exclude=(n,)) + degree = 0 + + # Find all applications of y in the recurrence and check that: + # 1. The function y is only being used with a single argument; and + # 2. All arguments are n + k for constant negative integers k. + + prev_ys = recurrence.find(y) + for prev_y in prev_ys: + if len(prev_y.args) != 1: + raise TypeError("Recurrence should be in a single variable") + + shift = prev_y.args[0].match(n + k)[k] + if not (shift.is_constant() and shift.is_integer and shift < 0): + raise TypeError("Recurrence should have constant," + " negative, integer shifts" + " (found {})".format(prev_y)) + + if -shift > degree: + degree = -shift + + if not initial: + initial = [Dummy("c_{}".format(k)) for k in range(degree)] + + if len(initial) != degree: + raise ValueError("Number of initial terms must equal degree") + + degree = Integer(degree) + start = sympify(start) + + initial = Tuple(*(sympify(x) for x in initial)) + + seq = Basic.__new__(cls, recurrence, yn, n, initial, start) + + seq.cache = {y(start + k): init for k, init in enumerate(initial)} + seq.degree = degree + + return seq + + @property + def _recurrence(self): + """Equation defining recurrence.""" + return self.args[0] + + @property + def recurrence(self): + """Equation defining recurrence.""" + return Eq(self.yn, self.args[0]) + + @property + def yn(self): + """Applied function representing the nth term""" + return self.args[1] + + @property + def y(self): + """Undefined function for the nth term of the sequence""" + return self.yn.func + + @property + def n(self): + """Sequence index symbol""" + return self.args[2] + + @property + def initial(self): + """The initial values of the sequence""" + return self.args[3] + + @property + def start(self): + """The starting point of the sequence. This point is included""" + return self.args[4] + + @property + def stop(self): + """The ending point of the sequence. (oo)""" + return S.Infinity + + @property + def interval(self): + """Interval on which sequence is defined.""" + return (self.start, S.Infinity) + + def _eval_coeff(self, index): + if index - self.start < len(self.cache): + return self.cache[self.y(index)] + + for current in range(len(self.cache), index + 1): + # Use xreplace over subs for performance. + # See issue #10697. + seq_index = self.start + current + current_recurrence = self._recurrence.xreplace({self.n: seq_index}) + new_term = current_recurrence.xreplace(self.cache) + + self.cache[self.y(seq_index)] = new_term + + return self.cache[self.y(self.start + current)] + + def __iter__(self): + index = self.start + while True: + yield self._eval_coeff(index) + index += 1 + + +def sequence(seq, limits=None): + """ + Returns appropriate sequence object. + + Explanation + =========== + + If ``seq`` is a SymPy sequence, returns :class:`SeqPer` object + otherwise returns :class:`SeqFormula` object. + + Examples + ======== + + >>> from sympy import sequence + >>> from sympy.abc import n + >>> sequence(n**2, (n, 0, 5)) + SeqFormula(n**2, (n, 0, 5)) + >>> sequence((1, 2, 3), (n, 0, 5)) + SeqPer((1, 2, 3), (n, 0, 5)) + + See Also + ======== + + sympy.series.sequences.SeqPer + sympy.series.sequences.SeqFormula + """ + seq = sympify(seq) + + if is_sequence(seq, Tuple): + return SeqPer(seq, limits) + else: + return SeqFormula(seq, limits) + + +############################################################################### +# OPERATIONS # +############################################################################### + + +class SeqExprOp(SeqBase): + """ + Base class for operations on sequences. + + Examples + ======== + + >>> from sympy.series.sequences import SeqExprOp, sequence + >>> from sympy.abc import n + >>> s1 = sequence(n**2, (n, 0, 10)) + >>> s2 = sequence((1, 2, 3), (n, 5, 10)) + >>> s = SeqExprOp(s1, s2) + >>> s.gen + (n**2, (1, 2, 3)) + >>> s.interval + Interval(5, 10) + >>> s.length + 6 + + See Also + ======== + + sympy.series.sequences.SeqAdd + sympy.series.sequences.SeqMul + """ + @property + def gen(self): + """Generator for the sequence. + + returns a tuple of generators of all the argument sequences. + """ + return tuple(a.gen for a in self.args) + + @property + def interval(self): + """Sequence is defined on the intersection + of all the intervals of respective sequences + """ + return Intersection(*(a.interval for a in self.args)) + + @property + def start(self): + return self.interval.inf + + @property + def stop(self): + return self.interval.sup + + @property + def variables(self): + """Cumulative of all the bound variables""" + return tuple(flatten([a.variables for a in self.args])) + + @property + def length(self): + return self.stop - self.start + 1 + + +class SeqAdd(SeqExprOp): + """Represents term-wise addition of sequences. + + Rules: + * The interval on which sequence is defined is the intersection + of respective intervals of sequences. + * Anything + :class:`EmptySequence` remains unchanged. + * Other rules are defined in ``_add`` methods of sequence classes. + + Examples + ======== + + >>> from sympy import EmptySequence, oo, SeqAdd, SeqPer, SeqFormula + >>> from sympy.abc import n + >>> SeqAdd(SeqPer((1, 2), (n, 0, oo)), EmptySequence) + SeqPer((1, 2), (n, 0, oo)) + >>> SeqAdd(SeqPer((1, 2), (n, 0, 5)), SeqPer((1, 2), (n, 6, 10))) + EmptySequence + >>> SeqAdd(SeqPer((1, 2), (n, 0, oo)), SeqFormula(n**2, (n, 0, oo))) + SeqAdd(SeqFormula(n**2, (n, 0, oo)), SeqPer((1, 2), (n, 0, oo))) + >>> SeqAdd(SeqFormula(n**3), SeqFormula(n**2)) + SeqFormula(n**3 + n**2, (n, 0, oo)) + + See Also + ======== + + sympy.series.sequences.SeqMul + """ + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + + # flatten inputs + args = list(args) + + # adapted from sympy.sets.sets.Union + def _flatten(arg): + if isinstance(arg, SeqBase): + if isinstance(arg, SeqAdd): + return sum(map(_flatten, arg.args), []) + else: + return [arg] + if iterable(arg): + return sum(map(_flatten, arg), []) + raise TypeError("Input must be Sequences or " + " iterables of Sequences") + args = _flatten(args) + + args = [a for a in args if a is not S.EmptySequence] + + # Addition of no sequences is EmptySequence + if not args: + return S.EmptySequence + + if Intersection(*(a.interval for a in args)) is S.EmptySet: + return S.EmptySequence + + # reduce using known rules + if evaluate: + return SeqAdd.reduce(args) + + args = list(ordered(args, SeqBase._start_key)) + + return Basic.__new__(cls, *args) + + @staticmethod + def reduce(args): + """Simplify :class:`SeqAdd` using known rules. + + Iterates through all pairs and ask the constituent + sequences if they can simplify themselves with any other constituent. + + Notes + ===== + + adapted from ``Union.reduce`` + + """ + new_args = True + while new_args: + for id1, s in enumerate(args): + new_args = False + for id2, t in enumerate(args): + if id1 == id2: + continue + new_seq = s._add(t) + # This returns None if s does not know how to add + # with t. Returns the newly added sequence otherwise + if new_seq is not None: + new_args = [a for a in args if a not in (s, t)] + new_args.append(new_seq) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return SeqAdd(args, evaluate=False) + + def _eval_coeff(self, pt): + """adds up the coefficients of all the sequences at point pt""" + return sum(a.coeff(pt) for a in self.args) + + +class SeqMul(SeqExprOp): + r"""Represents term-wise multiplication of sequences. + + Explanation + =========== + + Handles multiplication of sequences only. For multiplication + with other objects see :func:`SeqBase.coeff_mul`. + + Rules: + * The interval on which sequence is defined is the intersection + of respective intervals of sequences. + * Anything \* :class:`EmptySequence` returns :class:`EmptySequence`. + * Other rules are defined in ``_mul`` methods of sequence classes. + + Examples + ======== + + >>> from sympy import EmptySequence, oo, SeqMul, SeqPer, SeqFormula + >>> from sympy.abc import n + >>> SeqMul(SeqPer((1, 2), (n, 0, oo)), EmptySequence) + EmptySequence + >>> SeqMul(SeqPer((1, 2), (n, 0, 5)), SeqPer((1, 2), (n, 6, 10))) + EmptySequence + >>> SeqMul(SeqPer((1, 2), (n, 0, oo)), SeqFormula(n**2)) + SeqMul(SeqFormula(n**2, (n, 0, oo)), SeqPer((1, 2), (n, 0, oo))) + >>> SeqMul(SeqFormula(n**3), SeqFormula(n**2)) + SeqFormula(n**5, (n, 0, oo)) + + See Also + ======== + + sympy.series.sequences.SeqAdd + """ + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + + # flatten inputs + args = list(args) + + # adapted from sympy.sets.sets.Union + def _flatten(arg): + if isinstance(arg, SeqBase): + if isinstance(arg, SeqMul): + return sum(map(_flatten, arg.args), []) + else: + return [arg] + elif iterable(arg): + return sum(map(_flatten, arg), []) + raise TypeError("Input must be Sequences or " + " iterables of Sequences") + args = _flatten(args) + + # Multiplication of no sequences is EmptySequence + if not args: + return S.EmptySequence + + if Intersection(*(a.interval for a in args)) is S.EmptySet: + return S.EmptySequence + + # reduce using known rules + if evaluate: + return SeqMul.reduce(args) + + args = list(ordered(args, SeqBase._start_key)) + + return Basic.__new__(cls, *args) + + @staticmethod + def reduce(args): + """Simplify a :class:`SeqMul` using known rules. + + Explanation + =========== + + Iterates through all pairs and ask the constituent + sequences if they can simplify themselves with any other constituent. + + Notes + ===== + + adapted from ``Union.reduce`` + + """ + new_args = True + while new_args: + for id1, s in enumerate(args): + new_args = False + for id2, t in enumerate(args): + if id1 == id2: + continue + new_seq = s._mul(t) + # This returns None if s does not know how to multiply + # with t. Returns the newly multiplied sequence otherwise + if new_seq is not None: + new_args = [a for a in args if a not in (s, t)] + new_args.append(new_seq) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return SeqMul(args, evaluate=False) + + def _eval_coeff(self, pt): + """multiplies the coefficients of all the sequences at point pt""" + val = 1 + for a in self.args: + val *= a.coeff(pt) + return val diff --git a/phivenv/Lib/site-packages/sympy/series/series.py b/phivenv/Lib/site-packages/sympy/series/series.py new file mode 100644 index 0000000000000000000000000000000000000000..e9feec7d3b1987bfaa5238969f531e9f98b88b25 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/series.py @@ -0,0 +1,63 @@ +from sympy.core.sympify import sympify + + +def series(expr, x=None, x0=0, n=6, dir="+"): + """Series expansion of expr around point `x = x0`. + + Parameters + ========== + + expr : Expression + The expression whose series is to be expanded. + + x : Symbol + It is the variable of the expression to be calculated. + + x0 : Value + The value around which ``x`` is calculated. Can be any value + from ``-oo`` to ``oo``. + + n : Value + The number of terms upto which the series is to be expanded. + + dir : String, optional + The series-expansion can be bi-directional. If ``dir="+"``, + then (x->x0+). If ``dir="-"``, then (x->x0-). For infinite + ``x0`` (``oo`` or ``-oo``), the ``dir`` argument is determined + from the direction of the infinity (i.e., ``dir="-"`` for + ``oo``). + + Examples + ======== + + >>> from sympy import series, tan, oo + >>> from sympy.abc import x + >>> f = tan(x) + >>> series(f, x, 2, 6, "+") + tan(2) + (1 + tan(2)**2)*(x - 2) + (x - 2)**2*(tan(2)**3 + tan(2)) + + (x - 2)**3*(1/3 + 4*tan(2)**2/3 + tan(2)**4) + (x - 2)**4*(tan(2)**5 + + 5*tan(2)**3/3 + 2*tan(2)/3) + (x - 2)**5*(2/15 + 17*tan(2)**2/15 + + 2*tan(2)**4 + tan(2)**6) + O((x - 2)**6, (x, 2)) + + >>> series(f, x, 2, 3, "-") + tan(2) + (2 - x)*(-tan(2)**2 - 1) + (2 - x)**2*(tan(2)**3 + tan(2)) + + O((x - 2)**3, (x, 2)) + + >>> series(f, x, 2, oo, "+") + Traceback (most recent call last): + ... + TypeError: 'Infinity' object cannot be interpreted as an integer + + Returns + ======= + + Expr + Series expansion of the expression about x0 + + See Also + ======== + + sympy.core.expr.Expr.series: See the docstring of Expr.series() for complete details of this wrapper. + """ + expr = sympify(expr) + return expr.series(x, x0, n, dir) diff --git a/phivenv/Lib/site-packages/sympy/series/series_class.py b/phivenv/Lib/site-packages/sympy/series/series_class.py new file mode 100644 index 0000000000000000000000000000000000000000..ff04993b266a3cbd3f767042d4325fb11edb2168 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/series_class.py @@ -0,0 +1,99 @@ +""" +Contains the base class for series +Made using sequences in mind +""" + +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.core.cache import cacheit + + +class SeriesBase(Expr): + """Base Class for series""" + + @property + def interval(self): + """The interval on which the series is defined""" + raise NotImplementedError("(%s).interval" % self) + + @property + def start(self): + """The starting point of the series. This point is included""" + raise NotImplementedError("(%s).start" % self) + + @property + def stop(self): + """The ending point of the series. This point is included""" + raise NotImplementedError("(%s).stop" % self) + + @property + def length(self): + """Length of the series expansion""" + raise NotImplementedError("(%s).length" % self) + + @property + def variables(self): + """Returns a tuple of variables that are bounded""" + return () + + @property + def free_symbols(self): + """ + This method returns the symbols in the object, excluding those + that take on a specific value (i.e. the dummy symbols). + """ + return ({j for i in self.args for j in i.free_symbols} + .difference(self.variables)) + + @cacheit + def term(self, pt): + """Term at point pt of a series""" + if pt < self.start or pt > self.stop: + raise IndexError("Index %s out of bounds %s" % (pt, self.interval)) + return self._eval_term(pt) + + def _eval_term(self, pt): + raise NotImplementedError("The _eval_term method should be added to" + "%s to return series term so it is available" + "when 'term' calls it." + % self.func) + + def _ith_point(self, i): + """ + Returns the i'th point of a series + If start point is negative infinity, point is returned from the end. + Assumes the first point to be indexed zero. + + Examples + ======== + + TODO + """ + if self.start is S.NegativeInfinity: + initial = self.stop + step = -1 + else: + initial = self.start + step = 1 + + return initial + i*step + + def __iter__(self): + i = 0 + while i < self.length: + pt = self._ith_point(i) + yield self.term(pt) + i += 1 + + def __getitem__(self, index): + if isinstance(index, int): + index = self._ith_point(index) + return self.term(index) + elif isinstance(index, slice): + start, stop = index.start, index.stop + if start is None: + start = 0 + if stop is None: + stop = self.length + return [self.term(self._ith_point(i)) for i in + range(start, stop, index.step or 1)] diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__init__.py b/phivenv/Lib/site-packages/sympy/series/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81e3b49f834afa361c6a164a630a0c74093af741 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba2bc577215615e4d94f9a97923342f4a13ab8e9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28767ec4209fdc7a24f29ad12f716ff32fc411f8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce12f3a08521a5560f886c2d4450ddd3e867d3d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a7d5391260214da7260b76219c265408a0198d4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7f6af86a83719027e12013663a890e269e20aa3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796a9dd57ba05475e7ddc69363842bd62ebff4a6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..099152deb94cc671e457f9107d4e5b9025f6dbe3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_limits.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_limits.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea537977abfe40331f54b2ce683d109794e1473 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_limits.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df30ccec931a7896c4e581e311526292b9eab53d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f7da77340a325b13a990970db19f57669d4972d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfa21a9be8cc5b4c83ee46a0738f4b422d850bb3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_order.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_order.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73e925fc31781b9d19ff67033f771f1afb8a37b3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_order.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..639463fbd8b6daab1e1090ab267a683cad1ce600 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce3d39c8b6c9c277975fcbaf7b3137dec97242a1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_series.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_series.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32c97c2129fcf7d2fb9e6e1952f74ad38146fb61 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/series/tests/__pycache__/test_series.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_approximants.py b/phivenv/Lib/site-packages/sympy/series/tests/test_approximants.py new file mode 100644 index 0000000000000000000000000000000000000000..9c03d2ce38add99b0dce8725b6c8d8844b31f76b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_approximants.py @@ -0,0 +1,23 @@ +from sympy.series import approximants +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.combinatorial.numbers import (fibonacci, lucas) + + +def test_approximants(): + x, t = symbols("x,t") + g = [lucas(k) for k in range(16)] + assert list(approximants(g)) == ( + [2, -4/(x - 2), (5*x - 2)/(3*x - 1), (x - 2)/(x**2 + x - 1)] ) + g = [lucas(k)+fibonacci(k+2) for k in range(16)] + assert list(approximants(g)) == ( + [3, -3/(x - 1), (3*x - 3)/(2*x - 1), -3/(x**2 + x - 1)] ) + g = [lucas(k)**2 for k in range(16)] + assert list(approximants(g)) == ( + [4, -16/(x - 4), (35*x - 4)/(9*x - 1), (37*x - 28)/(13*x**2 + 11*x - 7), + (50*x**2 + 63*x - 52)/(37*x**2 + 19*x - 13), + (-x**2 - 7*x + 4)/(x**3 - 2*x**2 - 2*x + 1)] ) + p = [sum(binomial(k,i)*x**i for i in range(k+1)) for k in range(16)] + y = approximants(p, t, simplify=True) + assert next(y) == 1 + assert next(y) == -1/(t*(x + 1) - 1) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_aseries.py b/phivenv/Lib/site-packages/sympy/series/tests/test_aseries.py new file mode 100644 index 0000000000000000000000000000000000000000..cae0ac0a43f2406dd96e45c6a31939ac6b4cdcaa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_aseries.py @@ -0,0 +1,55 @@ +from sympy.core.function import PoleError +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.order import O +from sympy.abc import x + +from sympy.testing.pytest import raises + +def test_simple(): + # Gruntz' theses pp. 91 to 96 + # 6.6 + e = sin(1/x + exp(-x)) - sin(1/x) + assert e.aseries(x) == (1/(24*x**4) - 1/(2*x**2) + 1 + O(x**(-6), (x, oo)))*exp(-x) + + e = exp(x) * (exp(1/x + exp(-x)) - exp(1/x)) + assert e.aseries(x, n=4) == 1/(6*x**3) + 1/(2*x**2) + 1/x + 1 + O(x**(-4), (x, oo)) + + e = exp(exp(x) / (1 - 1/x)) + assert e.aseries(x) == exp(exp(x) / (1 - 1/x)) + + # The implementation of bound in aseries is incorrect currently. This test + # should be commented out when that is fixed. + # assert e.aseries(x, bound=3) == exp(exp(x) / x**2)*exp(exp(x) / x)*exp(-exp(x) + exp(x)/(1 - 1/x) - \ + # exp(x) / x - exp(x) / x**2) * exp(exp(x)) + + e = exp(sin(1/x + exp(-exp(x)))) - exp(sin(1/x)) + assert e.aseries(x, n=4) == (-1/(2*x**3) + 1/x + 1 + O(x**(-4), (x, oo)))*exp(-exp(x)) + + e3 = lambda x:exp(exp(exp(x))) + e = e3(x)/e3(x - 1/e3(x)) + assert e.aseries(x, n=3) == 1 + exp(2*x + 2*exp(x))*exp(-2*exp(exp(x)))/2\ + - exp(2*x + exp(x))*exp(-2*exp(exp(x)))/2 - exp(x + exp(x))*exp(-2*exp(exp(x)))/2\ + + exp(x + exp(x))*exp(-exp(exp(x))) + O(exp(-3*exp(exp(x))), (x, oo)) + + e = exp(exp(x)) * (exp(sin(1/x + 1/exp(exp(x)))) - exp(sin(1/x))) + assert e.aseries(x, n=4) == -1/(2*x**3) + 1/x + 1 + O(x**(-4), (x, oo)) + + n = Symbol('n', integer=True) + e = (sqrt(n)*log(n)**2*exp(sqrt(log(n))*log(log(n))**2*exp(sqrt(log(log(n)))*log(log(log(n)))**3)))/n + assert e.aseries(n) == \ + exp(exp(sqrt(log(log(n)))*log(log(log(n)))**3)*sqrt(log(n))*log(log(n))**2)*log(n)**2/sqrt(n) + + +def test_hierarchical(): + e = sin(1/x + exp(-x)) + assert e.aseries(x, n=3, hir=True) == -exp(-2*x)*sin(1/x)/2 + \ + exp(-x)*cos(1/x) + sin(1/x) + O(exp(-3*x), (x, oo)) + + e = sin(x) * cos(exp(-x)) + assert e.aseries(x, hir=True) == exp(-4*x)*sin(x)/24 - \ + exp(-2*x)*sin(x)/2 + sin(x) + O(exp(-6*x), (x, oo)) + raises(PoleError, lambda: e.aseries(x)) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_demidovich.py b/phivenv/Lib/site-packages/sympy/series/tests/test_demidovich.py new file mode 100644 index 0000000000000000000000000000000000000000..98cafbae6f019dd3d97d306099d5780ed2f37f04 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_demidovich.py @@ -0,0 +1,143 @@ +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, sin, tan) +from sympy.polys.rationaltools import together +from sympy.series.limits import limit + +# Numbers listed with the tests refer to problem numbers in the book +# "Anti-demidovich, problemas resueltos, Ed. URSS" + +x = Symbol("x") + + +def test_leadterm(): + assert (3 + 2*x**(log(3)/log(2) - 1)).leadterm(x) == (3, 0) + + +def root3(x): + return root(x, 3) + + +def root4(x): + return root(x, 4) + + +def test_Limits_simple_0(): + assert limit((2**(x + 1) + 3**(x + 1))/(2**x + 3**x), x, oo) == 3 # 175 + + +def test_Limits_simple_1(): + assert limit((x + 1)*(x + 2)*(x + 3)/x**3, x, oo) == 1 # 172 + assert limit(sqrt(x + 1) - sqrt(x), x, oo) == 0 # 179 + assert limit((2*x - 3)*(3*x + 5)*(4*x - 6)/(3*x**3 + x - 1), x, oo) == 8 # Primjer 1 + assert limit(x/root3(x**3 + 10), x, oo) == 1 # Primjer 2 + assert limit((x + 1)**2/(x**2 + 1), x, oo) == 1 # 181 + + +def test_Limits_simple_2(): + assert limit(1000*x/(x**2 - 1), x, oo) == 0 # 182 + assert limit((x**2 - 5*x + 1)/(3*x + 7), x, oo) is oo # 183 + assert limit((2*x**2 - x + 3)/(x**3 - 8*x + 5), x, oo) == 0 # 184 + assert limit((2*x**2 - 3*x - 4)/sqrt(x**4 + 1), x, oo) == 2 # 186 + assert limit((2*x + 3)/(x + root3(x)), x, oo) == 2 # 187 + assert limit(x**2/(10 + x*sqrt(x)), x, oo) is oo # 188 + assert limit(root3(x**2 + 1)/(x + 1), x, oo) == 0 # 189 + assert limit(sqrt(x)/sqrt(x + sqrt(x + sqrt(x))), x, oo) == 1 # 190 + + +def test_Limits_simple_3a(): + a = Symbol('a') + #issue 3513 + assert together(limit((x**2 - (a + 1)*x + a)/(x**3 - a**3), x, a)) == \ + (a - 1)/(3*a**2) # 196 + + +def test_Limits_simple_3b(): + h = Symbol("h") + assert limit(((x + h)**3 - x**3)/h, h, 0) == 3*x**2 # 197 + assert limit((1/(1 - x) - 3/(1 - x**3)), x, 1) == -1 # 198 + assert limit((sqrt(1 + x) - 1)/(root3(1 + x) - 1), x, 0) == Rational(3)/2 # Primer 4 + assert limit((sqrt(x) - 1)/(x - 1), x, 1) == Rational(1)/2 # 199 + assert limit((sqrt(x) - 8)/(root3(x) - 4), x, 64) == 3 # 200 + assert limit((root3(x) - 1)/(root4(x) - 1), x, 1) == Rational(4)/3 # 201 + assert limit( + (root3(x**2) - 2*root3(x) + 1)/(x - 1)**2, x, 1) == Rational(1)/9 # 202 + + +def test_Limits_simple_4a(): + a = Symbol('a') + assert limit((sqrt(x) - sqrt(a))/(x - a), x, a) == 1/(2*sqrt(a)) # Primer 5 + assert limit((sqrt(x) - 1)/(root3(x) - 1), x, 1) == Rational(3, 2) # 205 + assert limit((sqrt(1 + x) - sqrt(1 - x))/x, x, 0) == 1 # 207 + assert limit(sqrt(x**2 - 5*x + 6) - x, x, oo) == Rational(-5, 2) # 213 + + +def test_limits_simple_4aa(): + assert limit(x*(sqrt(x**2 + 1) - x), x, oo) == Rational(1)/2 # 214 + + +def test_Limits_simple_4b(): + #issue 3511 + assert limit(x - root3(x**3 - 1), x, oo) == 0 # 215 + + +def test_Limits_simple_4c(): + assert limit(log(1 + exp(x))/x, x, -oo) == 0 # 267a + assert limit(log(1 + exp(x))/x, x, oo) == 1 # 267b + + +def test_bounded(): + assert limit(sin(x)/x, x, oo) == 0 # 216b + assert limit(x*sin(1/x), x, 0) == 0 # 227a + + +def test_f1a(): + #issue 3508: + assert limit((sin(2*x)/x)**(1 + x), x, 0) == 2 # Primer 7 + + +def test_f1a2(): + #issue 3509: + assert limit(((x - 1)/(x + 1))**x, x, oo) == exp(-2) # Primer 9 + + +def test_f1b(): + m = Symbol("m") + n = Symbol("n") + h = Symbol("h") + a = Symbol("a") + assert limit(sin(x)/x, x, 2) == sin(2)/2 # 216a + assert limit(sin(3*x)/x, x, 0) == 3 # 217 + assert limit(sin(5*x)/sin(2*x), x, 0) == Rational(5, 2) # 218 + assert limit(sin(pi*x)/sin(3*pi*x), x, 0) == Rational(1, 3) # 219 + assert limit(x*sin(pi/x), x, oo) == pi # 220 + assert limit((1 - cos(x))/x**2, x, 0) == S.Half # 221 + assert limit(x*sin(1/x), x, oo) == 1 # 227b + assert limit((cos(m*x) - cos(n*x))/x**2, x, 0) == -m**2/2 + n**2/2 # 232 + assert limit((tan(x) - sin(x))/x**3, x, 0) == S.Half # 233 + assert limit((x - sin(2*x))/(x + sin(3*x)), x, 0) == -Rational(1, 4) # 237 + assert limit((1 - sqrt(cos(x)))/x**2, x, 0) == Rational(1, 4) # 239 + assert limit((sqrt(1 + sin(x)) - sqrt(1 - sin(x)))/x, x, 0) == 1 # 240 + + assert limit((1 + h/x)**x, x, oo) == exp(h) # Primer 9 + assert limit((sin(x) - sin(a))/(x - a), x, a) == cos(a) # 222, *176 + assert limit((cos(x) - cos(a))/(x - a), x, a) == -sin(a) # 223 + assert limit((sin(x + h) - sin(x))/h, h, 0) == cos(x) # 225 + + +def test_f2a(): + assert limit(((x + 1)/(2*x + 1))**(x**2), x, oo) == 0 # Primer 8 + + +def test_f2(): + assert limit((sqrt( + cos(x)) - root3(cos(x)))/(sin(x)**2), x, 0) == -Rational(1, 12) # *184 + + +def test_f3(): + a = Symbol('a') + #issue 3504 + assert limit(asin(a*x)/x, x, 0) == a diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_formal.py b/phivenv/Lib/site-packages/sympy/series/tests/test_formal.py new file mode 100644 index 0000000000000000000000000000000000000000..cac60b12534152a5783bb8f0faab2c06da6691fb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_formal.py @@ -0,0 +1,618 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, asech) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin) +from sympy.functions.special.bessel import airyai +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import integrate +from sympy.series.formal import fps +from sympy.series.order import O +from sympy.series.formal import (rational_algorithm, FormalPowerSeries, + FormalPowerSeriesProduct, FormalPowerSeriesCompose, + FormalPowerSeriesInverse, simpleDE, + rational_independent, exp_re, hyper_re) +from sympy.testing.pytest import raises, XFAIL, slow + +x, y, z = symbols('x y z') +n, m, k = symbols('n m k', integer=True) +f, r = Function('f'), Function('r') + + +def test_rational_algorithm(): + f = 1 / ((x - 1)**2 * (x - 2)) + assert rational_algorithm(f, x, k) == \ + (-2**(-k - 1) + 1 - (factorial(k + 1) / factorial(k)), 0, 0) + + f = (1 + x + x**2 + x**3) / ((x - 1) * (x - 2)) + assert rational_algorithm(f, x, k) == \ + (-15*2**(-k - 1) + 4, x + 4, 0) + + f = z / (y*m - m*x - y*x + x**2) + assert rational_algorithm(f, x, k) == \ + (((-y**(-k - 1)*z) / (y - m)) + ((m**(-k - 1)*z) / (y - m)), 0, 0) + + f = x / (1 - x - x**2) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((Rational(-1, 2) + sqrt(5)/2)**(-k - 1) * + (-sqrt(5)/10 + S.Half)) + + ((-sqrt(5)/2 - S.Half)**(-k - 1) * + (sqrt(5)/10 + S.Half)), 0, 0) + + f = 1 / (x**2 + 2*x + 2) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + ((I*(-1 + I)**(-k - 1)) / 2 - (I*(-1 - I)**(-k - 1)) / 2, 0, 0) + + f = log(1 + x) + assert rational_algorithm(f, x, k) == \ + (-(-1)**(-k) / k, 0, 1) + + f = atan(x) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((I*I**(-k)) / 2 - (I*(-I)**(-k)) / 2) / k, 0, 1) + + f = x*atan(x) - log(1 + x**2) / 2 + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((I*I**(-k + 1)) / 2 - (I*(-I)**(-k + 1)) / 2) / + (k*(k - 1)), 0, 2) + + f = log((1 + x) / (1 - x)) / 2 - atan(x) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + ((-(-1)**(-k) / 2 - (I*I**(-k)) / 2 + (I*(-I)**(-k)) / 2 + + S.Half) / k, 0, 1) + + assert rational_algorithm(cos(x), x, k) is None + + +def test_rational_independent(): + ri = rational_independent + assert ri([], x) == [] + assert ri([cos(x), sin(x)], x) == [cos(x), sin(x)] + assert ri([x**2, sin(x), x*sin(x), x**3], x) == \ + [x**3 + x**2, x*sin(x) + sin(x)] + assert ri([S.One, x*log(x), log(x), sin(x)/x, cos(x), sin(x), x], x) == \ + [x + 1, x*log(x) + log(x), sin(x)/x + sin(x), cos(x)] + + +def test_simpleDE(): + # Tests just the first valid DE + for DE in simpleDE(exp(x), x, f): + assert DE == (-f(x) + Derivative(f(x), x), 1) + break + for DE in simpleDE(sin(x), x, f): + assert DE == (f(x) + Derivative(f(x), x, x), 2) + break + for DE in simpleDE(log(1 + x), x, f): + assert DE == ((x + 1)*Derivative(f(x), x, 2) + Derivative(f(x), x), 2) + break + for DE in simpleDE(asin(x), x, f): + assert DE == (x*Derivative(f(x), x) + (x**2 - 1)*Derivative(f(x), x, x), + 2) + break + for DE in simpleDE(exp(x)*sin(x), x, f): + assert DE == (2*f(x) - 2*Derivative(f(x)) + Derivative(f(x), x, x), 2) + break + for DE in simpleDE(((1 + x)/(1 - x))**n, x, f): + assert DE == (2*n*f(x) + (x**2 - 1)*Derivative(f(x), x), 1) + break + for DE in simpleDE(airyai(x), x, f): + assert DE == (-x*f(x) + Derivative(f(x), x, x), 2) + break + + +def test_exp_re(): + d = -f(x) + Derivative(f(x), x) + assert exp_re(d, r, k) == -r(k) + r(k + 1) + + d = f(x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 2) + + d = f(x) + Derivative(f(x), x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 1) + r(k + 2) + + d = Derivative(f(x), x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 1) + + d = Derivative(f(x), x, 3) + Derivative(f(x), x, 4) + Derivative(f(x)) + assert exp_re(d, r, k) == r(k) + r(k + 2) + r(k + 3) + + +def test_hyper_re(): + d = f(x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == r(k) + (k+1)*(k+2)*r(k + 2) + + d = -x*f(x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == (k + 2)*(k + 3)*r(k + 3) - r(k) + + d = 2*f(x) - 2*Derivative(f(x), x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == \ + (-2*k - 2)*r(k + 1) + (k + 1)*(k + 2)*r(k + 2) + 2*r(k) + + d = 2*n*f(x) + (x**2 - 1)*Derivative(f(x), x) + assert hyper_re(d, r, k) == \ + k*r(k) + 2*n*r(k + 1) + (-k - 2)*r(k + 2) + + d = (x**10 + 4)*Derivative(f(x), x) + x*(x**10 - 1)*Derivative(f(x), x, x) + assert hyper_re(d, r, k) == \ + (k*(k - 1) + k)*r(k) + (4*k - (k + 9)*(k + 10) + 40)*r(k + 10) + + d = ((x**2 - 1)*Derivative(f(x), x, 3) + 3*x*Derivative(f(x), x, x) + + Derivative(f(x), x)) + assert hyper_re(d, r, k) == \ + ((k*(k - 2)*(k - 1) + 3*k*(k - 1) + k)*r(k) + + (-k*(k + 1)*(k + 2))*r(k + 2)) + + +def test_fps(): + assert fps(1) == 1 + assert fps(2, x) == 2 + assert fps(2, x, dir='+') == 2 + assert fps(2, x, dir='-') == 2 + assert fps(1/x + 1/x**2) == 1/x + 1/x**2 + assert fps(log(1 + x), hyper=False, rational=False) == log(1 + x) + + f = fps(x**2 + x + 1) + assert isinstance(f, FormalPowerSeries) + assert f.function == x**2 + x + 1 + assert f[0] == 1 + assert f[2] == x**2 + assert f.truncate(4) == x**2 + x + 1 + O(x**4) + assert f.polynomial() == x**2 + x + 1 + + f = fps(log(1 + x)) + assert isinstance(f, FormalPowerSeries) + assert f.function == log(1 + x) + assert f.subs(x, y) == f + assert f[:5] == [0, x, -x**2/2, x**3/3, -x**4/4] + assert f.as_leading_term(x) == x + assert f.polynomial(6) == x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + + k = f.ak.variables[0] + assert f.infinite == Sum((-(-1)**(-k)*x**k)/k, (k, 1, oo)) + + ft, s = f.truncate(n=None), f[:5] + for i, t in enumerate(ft): + if i == 5: + break + assert s[i] == t + + f = sin(x).fps(x) + assert isinstance(f, FormalPowerSeries) + assert f.truncate() == x - x**3/6 + x**5/120 + O(x**6) + + raises(NotImplementedError, lambda: fps(y*x)) + raises(ValueError, lambda: fps(x, dir=0)) + + +@slow +def test_fps__rational(): + assert fps(1/x) == (1/x) + assert fps((x**2 + x + 1) / x**3, dir=-1) == (x**2 + x + 1) / x**3 + + f = 1 / ((x - 1)**2 * (x - 2)) + assert fps(f, x).truncate() == \ + (Rational(-1, 2) - x*Rational(5, 4) - 17*x**2/8 - 49*x**3/16 - 129*x**4/32 - + 321*x**5/64 + O(x**6)) + + f = (1 + x + x**2 + x**3) / ((x - 1) * (x - 2)) + assert fps(f, x).truncate() == \ + (S.Half + x*Rational(5, 4) + 17*x**2/8 + 49*x**3/16 + 113*x**4/32 + + 241*x**5/64 + O(x**6)) + + f = x / (1 - x - x**2) + assert fps(f, x, full=True).truncate() == \ + x + x**2 + 2*x**3 + 3*x**4 + 5*x**5 + O(x**6) + + f = 1 / (x**2 + 2*x + 2) + assert fps(f, x, full=True).truncate() == \ + S.Half - x/2 + x**2/4 - x**4/8 + x**5/8 + O(x**6) + + f = log(1 + x) + assert fps(f, x).truncate() == \ + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + assert fps(f, x, dir=1).truncate() == fps(f, x, dir=-1).truncate() + assert fps(f, x, 2).truncate() == \ + (log(3) - Rational(2, 3) - (x - 2)**2/18 + (x - 2)**3/81 - + (x - 2)**4/324 + (x - 2)**5/1215 + x/3 + O((x - 2)**6, (x, 2))) + assert fps(f, x, 2, dir=-1).truncate() == \ + (log(3) - Rational(2, 3) - (-x + 2)**2/18 - (-x + 2)**3/81 - + (-x + 2)**4/324 - (-x + 2)**5/1215 + x/3 + O((x - 2)**6, (x, 2))) + + f = atan(x) + assert fps(f, x, full=True).truncate() == x - x**3/3 + x**5/5 + O(x**6) + assert fps(f, x, full=True, dir=1).truncate() == \ + fps(f, x, full=True, dir=-1).truncate() + assert fps(f, x, 2, full=True).truncate() == \ + (atan(2) - Rational(2, 5) - 2*(x - 2)**2/25 + 11*(x - 2)**3/375 - + 6*(x - 2)**4/625 + 41*(x - 2)**5/15625 + x/5 + O((x - 2)**6, (x, 2))) + assert fps(f, x, 2, full=True, dir=-1).truncate() == \ + (atan(2) - Rational(2, 5) - 2*(-x + 2)**2/25 - 11*(-x + 2)**3/375 - + 6*(-x + 2)**4/625 - 41*(-x + 2)**5/15625 + x/5 + O((x - 2)**6, (x, 2))) + + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x, full=True).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = log((1 + x) / (1 - x)) / 2 - atan(x) + assert fps(f, x, full=True).truncate(n=10) == 2*x**3/3 + 2*x**7/7 + O(x**10) + + +@slow +def test_fps__hyper(): + f = sin(x) + assert fps(f, x).truncate() == x - x**3/6 + x**5/120 + O(x**6) + + f = cos(x) + assert fps(f, x).truncate() == 1 - x**2/2 + x**4/24 + O(x**6) + + f = exp(x) + assert fps(f, x).truncate() == \ + 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + O(x**6) + + f = atan(x) + assert fps(f, x).truncate() == x - x**3/3 + x**5/5 + O(x**6) + + f = exp(acos(x)) + assert fps(f, x).truncate() == \ + (exp(pi/2) - x*exp(pi/2) + x**2*exp(pi/2)/2 - x**3*exp(pi/2)/3 + + 5*x**4*exp(pi/2)/24 - x**5*exp(pi/2)/6 + O(x**6)) + + f = exp(acosh(x)) + assert fps(f, x).truncate() == I + x - I*x**2/2 - I*x**4/8 + O(x**6) + + f = atan(1/x) + assert fps(f, x).truncate() == pi/2 - x + x**3/3 - x**5/5 + O(x**6) + + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x, rational=False).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = log(1 + x) + assert fps(f, x, rational=False).truncate() == \ + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + f = airyai(x**2) + assert fps(f, x).truncate() == \ + (3**Rational(5, 6)*gamma(Rational(1, 3))/(6*pi) - + 3**Rational(2, 3)*x**2/(3*gamma(Rational(1, 3))) + O(x**6)) + + f = exp(x)*sin(x) + assert fps(f, x).truncate() == x + x**2 + x**3/3 - x**5/30 + O(x**6) + + f = exp(x)*sin(x)/x + assert fps(f, x).truncate() == 1 + x + x**2/3 - x**4/30 - x**5/90 + O(x**6) + + f = sin(x) * cos(x) + assert fps(f, x).truncate() == x - 2*x**3/3 + 2*x**5/15 + O(x**6) + + +def test_fps_shift(): + f = x**-5*sin(x) + assert fps(f, x).truncate() == \ + 1/x**4 - 1/(6*x**2) + Rational(1, 120) - x**2/5040 + x**4/362880 + O(x**6) + + f = x**2*atan(x) + assert fps(f, x, rational=False).truncate() == \ + x**3 - x**5/3 + O(x**6) + + f = cos(sqrt(x))*x + assert fps(f, x).truncate() == \ + x - x**2/2 + x**3/24 - x**4/720 + x**5/40320 + O(x**6) + + f = x**2*cos(sqrt(x)) + assert fps(f, x).truncate() == \ + x**2 - x**3/2 + x**4/24 - x**5/720 + O(x**6) + + +def test_fps__Add_expr(): + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = sin(x) + cos(x) - exp(x) + log(1 + x) + assert fps(f, x).truncate() == x - 3*x**2/2 - x**4/4 + x**5/5 + O(x**6) + + f = 1/x + sin(x) + assert fps(f, x).truncate() == 1/x + x - x**3/6 + x**5/120 + O(x**6) + + f = sin(x) - cos(x) + 1/(x - 1) + assert fps(f, x).truncate() == \ + -2 - x**2/2 - 7*x**3/6 - 25*x**4/24 - 119*x**5/120 + O(x**6) + + +def test_fps__asymptotic(): + f = exp(x) + assert fps(f, x, oo) == f + assert fps(f, x, -oo).truncate() == O(1/x**6, (x, oo)) + + f = erf(x) + assert fps(f, x, oo).truncate() == 1 + O(1/x**6, (x, oo)) + assert fps(f, x, -oo).truncate() == -1 + O(1/x**6, (x, oo)) + + f = atan(x) + assert fps(f, x, oo, full=True).truncate() == \ + -1/(5*x**5) + 1/(3*x**3) - 1/x + pi/2 + O(1/x**6, (x, oo)) + assert fps(f, x, -oo, full=True).truncate() == \ + -1/(5*x**5) + 1/(3*x**3) - 1/x - pi/2 + O(1/x**6, (x, oo)) + + f = log(1 + x) + assert fps(f, x, oo) != \ + (-1/(5*x**5) - 1/(4*x**4) + 1/(3*x**3) - 1/(2*x**2) + 1/x - log(1/x) + + O(1/x**6, (x, oo))) + assert fps(f, x, -oo) != \ + (-1/(5*x**5) - 1/(4*x**4) + 1/(3*x**3) - 1/(2*x**2) + 1/x + I*pi - + log(-1/x) + O(1/x**6, (x, oo))) + + +def test_fps__fractional(): + f = sin(sqrt(x)) / x + assert fps(f, x).truncate() == \ + (1/sqrt(x) - sqrt(x)/6 + x**Rational(3, 2)/120 - + x**Rational(5, 2)/5040 + x**Rational(7, 2)/362880 - + x**Rational(9, 2)/39916800 + x**Rational(11, 2)/6227020800 + O(x**6)) + + f = sin(sqrt(x)) * x + assert fps(f, x).truncate() == \ + (x**Rational(3, 2) - x**Rational(5, 2)/6 + x**Rational(7, 2)/120 - + x**Rational(9, 2)/5040 + x**Rational(11, 2)/362880 + O(x**6)) + + f = atan(sqrt(x)) / x**2 + assert fps(f, x).truncate() == \ + (x**Rational(-3, 2) - x**Rational(-1, 2)/3 + x**S.Half/5 - + x**Rational(3, 2)/7 + x**Rational(5, 2)/9 - x**Rational(7, 2)/11 + + x**Rational(9, 2)/13 - x**Rational(11, 2)/15 + O(x**6)) + + f = exp(sqrt(x)) + assert fps(f, x).truncate().expand() == \ + (1 + x/2 + x**2/24 + x**3/720 + x**4/40320 + x**5/3628800 + sqrt(x) + + x**Rational(3, 2)/6 + x**Rational(5, 2)/120 + x**Rational(7, 2)/5040 + + x**Rational(9, 2)/362880 + x**Rational(11, 2)/39916800 + O(x**6)) + + f = exp(sqrt(x))*x + assert fps(f, x).truncate().expand() == \ + (x + x**2/2 + x**3/24 + x**4/720 + x**5/40320 + x**Rational(3, 2) + + x**Rational(5, 2)/6 + x**Rational(7, 2)/120 + x**Rational(9, 2)/5040 + + x**Rational(11, 2)/362880 + O(x**6)) + + +def test_fps__logarithmic_singularity(): + f = log(1 + 1/x) + assert fps(f, x) != \ + -log(x) + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + assert fps(f, x, rational=False) != \ + -log(x) + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + +@XFAIL +def test_fps__logarithmic_singularity_fail(): + f = asech(x) # Algorithms for computing limits probably needs improvements + assert fps(f, x) == log(2) - log(x) - x**2/4 - 3*x**4/64 + O(x**6) + + +def test_fps_symbolic(): + f = x**n*sin(x**2) + assert fps(f, x).truncate(8) == x**(n + 2) - x**(n + 6)/6 + O(x**(n + 8), x) + + f = x**n*log(1 + x) + fp = fps(f, x) + k = fp.ak.variables[0] + assert fp.infinite == \ + Sum((-(-1)**(-k)*x**(k + n))/k, (k, 1, oo)) + + f = (x - 2)**n*log(1 + x) + assert fps(f, x, 2).truncate() == \ + ((x - 2)**n*log(3) + (x - 2)**(n + 1)/3 - (x - 2)**(n + 2)/18 + (x - 2)**(n + 3)/81 - + (x - 2)**(n + 4)/324 + (x - 2)**(n + 5)/1215 + O((x - 2)**(n + 6), (x, 2))) + + f = x**(n - 2)*cos(x) + assert fps(f, x).truncate() == \ + (x**(n - 2) - x**n/2 + x**(n + 2)/24 + O(x**(n + 4), x)) + + f = x**(n - 2)*sin(x) + x**n*exp(x) + assert fps(f, x).truncate() == \ + (x**(n - 1) + x**(n + 1) + x**(n + 2)/2 + x**n + + x**(n + 4)/24 + x**(n + 5)/60 + O(x**(n + 6), x)) + + f = x**n*atan(x) + assert fps(f, x, oo).truncate() == \ + (-x**(n - 5)/5 + x**(n - 3)/3 + x**n*(pi/2 - 1/x) + + O((1/x)**(-n)/x**6, (x, oo))) + + f = x**(n/2)*cos(x) + assert fps(f, x).truncate() == \ + x**(n/2) - x**(n/2 + 2)/2 + x**(n/2 + 4)/24 + O(x**(n/2 + 6), x) + + f = x**(n + m)*sin(x) + assert fps(f, x).truncate() == \ + x**(m + n + 1) - x**(m + n + 3)/6 + x**(m + n + 5)/120 + O(x**(m + n + 6), x) + + +def test_fps__slow(): + f = x*exp(x)*sin(2*x) # TODO: rsolve needs improvement + assert fps(f, x).truncate() == 2*x**2 + 2*x**3 - x**4/3 - x**5 + O(x**6) + + +def test_fps__operations(): + f1, f2 = fps(sin(x)), fps(cos(x)) + + fsum = f1 + f2 + assert fsum.function == sin(x) + cos(x) + assert fsum.truncate() == \ + 1 + x - x**2/2 - x**3/6 + x**4/24 + x**5/120 + O(x**6) + + fsum = f1 + 1 + assert fsum.function == sin(x) + 1 + assert fsum.truncate() == 1 + x - x**3/6 + x**5/120 + O(x**6) + + fsum = 1 + f2 + assert fsum.function == cos(x) + 1 + assert fsum.truncate() == 2 - x**2/2 + x**4/24 + O(x**6) + + assert (f1 + x) == Add(f1, x) + + assert -f2.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + assert (f1 - f1) is S.Zero + + fsub = f1 - f2 + assert fsub.function == sin(x) - cos(x) + assert fsub.truncate() == \ + -1 + x + x**2/2 - x**3/6 - x**4/24 + x**5/120 + O(x**6) + + fsub = f1 - 1 + assert fsub.function == sin(x) - 1 + assert fsub.truncate() == -1 + x - x**3/6 + x**5/120 + O(x**6) + + fsub = 1 - f2 + assert fsub.function == -cos(x) + 1 + assert fsub.truncate() == x**2/2 - x**4/24 + O(x**6) + + raises(ValueError, lambda: f1 + fps(exp(x), dir=-1)) + raises(ValueError, lambda: f1 + fps(exp(x), x0=1)) + + fm = f1 * 3 + + assert fm.function == 3*sin(x) + assert fm.truncate() == 3*x - x**3/2 + x**5/40 + O(x**6) + + fm = 3 * f2 + + assert fm.function == 3*cos(x) + assert fm.truncate() == 3 - 3*x**2/2 + x**4/8 + O(x**6) + + assert (f1 * f2) == Mul(f1, f2) + assert (f1 * x) == Mul(f1, x) + + fd = f1.diff() + assert fd.function == cos(x) + assert fd.truncate() == 1 - x**2/2 + x**4/24 + O(x**6) + + fd = f2.diff() + assert fd.function == -sin(x) + assert fd.truncate() == -x + x**3/6 - x**5/120 + O(x**6) + + fd = f2.diff().diff() + assert fd.function == -cos(x) + assert fd.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + + f3 = fps(exp(sqrt(x))) + fd = f3.diff() + assert fd.truncate().expand() == \ + (1/(2*sqrt(x)) + S.Half + x/12 + x**2/240 + x**3/10080 + x**4/725760 + + x**5/79833600 + sqrt(x)/4 + x**Rational(3, 2)/48 + x**Rational(5, 2)/1440 + + x**Rational(7, 2)/80640 + x**Rational(9, 2)/7257600 + x**Rational(11, 2)/958003200 + + O(x**6)) + + assert f1.integrate((x, 0, 1)) == -cos(1) + 1 + assert integrate(f1, (x, 0, 1)) == -cos(1) + 1 + + fi = integrate(f1, x) + assert fi.function == -cos(x) + assert fi.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + + fi = f2.integrate(x) + assert fi.function == sin(x) + assert fi.truncate() == x - x**3/6 + x**5/120 + O(x**6) + +def test_fps__product(): + f1, f2, f3 = fps(sin(x)), fps(exp(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.product(exp(x), x)) + raises(ValueError, lambda: f1.product(fps(exp(x), dir=-1), x, 4)) + raises(ValueError, lambda: f1.product(fps(exp(x), x0=1), x, 4)) + raises(ValueError, lambda: f1.product(fps(exp(y)), x, 4)) + + fprod = f1.product(f2, x) + assert isinstance(fprod, FormalPowerSeriesProduct) + assert isinstance(fprod.ffps, FormalPowerSeries) + assert isinstance(fprod.gfps, FormalPowerSeries) + assert fprod.f == sin(x) + assert fprod.g == exp(x) + assert fprod.function == sin(x) * exp(x) + assert fprod._eval_terms(4) == x + x**2 + x**3/3 + assert fprod.truncate(4) == x + x**2 + x**3/3 + O(x**4) + assert fprod.polynomial(4) == x + x**2 + x**3/3 + + raises(NotImplementedError, lambda: fprod._eval_term(5)) + raises(NotImplementedError, lambda: fprod.infinite) + raises(NotImplementedError, lambda: fprod._eval_derivative(x)) + raises(NotImplementedError, lambda: fprod.integrate(x)) + + assert f1.product(f3, x)._eval_terms(4) == x - 2*x**3/3 + assert f1.product(f3, x).truncate(4) == x - 2*x**3/3 + O(x**4) + + +def test_fps__compose(): + f1, f2, f3 = fps(exp(x)), fps(sin(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.compose(sin(x), x)) + raises(ValueError, lambda: f1.compose(fps(sin(x), dir=-1), x, 4)) + raises(ValueError, lambda: f1.compose(fps(sin(x), x0=1), x, 4)) + raises(ValueError, lambda: f1.compose(fps(sin(y)), x, 4)) + + raises(ValueError, lambda: f1.compose(f3, x)) + raises(ValueError, lambda: f2.compose(f3, x)) + + fcomp = f1.compose(f2, x) + assert isinstance(fcomp, FormalPowerSeriesCompose) + assert isinstance(fcomp.ffps, FormalPowerSeries) + assert isinstance(fcomp.gfps, FormalPowerSeries) + assert fcomp.f == exp(x) + assert fcomp.g == sin(x) + assert fcomp.function == exp(sin(x)) + assert fcomp._eval_terms(6) == 1 + x + x**2/2 - x**4/8 - x**5/15 + assert fcomp.truncate() == 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + assert fcomp.truncate(5) == 1 + x + x**2/2 - x**4/8 + O(x**5) + + raises(NotImplementedError, lambda: fcomp._eval_term(5)) + raises(NotImplementedError, lambda: fcomp.infinite) + raises(NotImplementedError, lambda: fcomp._eval_derivative(x)) + raises(NotImplementedError, lambda: fcomp.integrate(x)) + + assert f1.compose(f2, x).truncate(4) == 1 + x + x**2/2 + O(x**4) + assert f1.compose(f2, x).truncate(8) == \ + 1 + x + x**2/2 - x**4/8 - x**5/15 - x**6/240 + x**7/90 + O(x**8) + assert f1.compose(f2, x).truncate(6) == \ + 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + + assert f2.compose(f2, x).truncate(4) == x - x**3/3 + O(x**4) + assert f2.compose(f2, x).truncate(8) == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + assert f2.compose(f2, x).truncate(6) == x - x**3/3 + x**5/10 + O(x**6) + + +def test_fps__inverse(): + f1, f2, f3 = fps(sin(x)), fps(exp(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.inverse(x)) + + finv = f2.inverse(x) + assert isinstance(finv, FormalPowerSeriesInverse) + assert isinstance(finv.ffps, FormalPowerSeries) + raises(ValueError, lambda: finv.gfps) + + assert finv.f == exp(x) + assert finv.function == exp(-x) + assert finv._eval_terms(5) == 1 - x + x**2/2 - x**3/6 + x**4/24 + assert finv.truncate() == 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + O(x**6) + assert finv.truncate(5) == 1 - x + x**2/2 - x**3/6 + x**4/24 + O(x**5) + + raises(NotImplementedError, lambda: finv._eval_term(5)) + raises(ValueError, lambda: finv.g) + raises(NotImplementedError, lambda: finv.infinite) + raises(NotImplementedError, lambda: finv._eval_derivative(x)) + raises(NotImplementedError, lambda: finv.integrate(x)) + + assert f2.inverse(x).truncate(8) == \ + 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + x**6/720 - x**7/5040 + O(x**8) + + assert f3.inverse(x).truncate() == 1 + x**2/2 + 5*x**4/24 + O(x**6) + assert f3.inverse(x).truncate(8) == 1 + x**2/2 + 5*x**4/24 + 61*x**6/720 + O(x**8) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_fourier.py b/phivenv/Lib/site-packages/sympy/series/tests/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..994f182088b09b038e0e1b3885fec1c27f69f2b0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_fourier.py @@ -0,0 +1,165 @@ +from sympy.core.add import Add +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin, sinc, tan) +from sympy.series.fourier import fourier_series +from sympy.series.fourier import FourierSeries +from sympy.testing.pytest import raises +from functools import lru_cache + +x, y, z = symbols('x y z') + +# Don't declare these during import because they are slow +@lru_cache() +def _get_examples(): + fo = fourier_series(x, (x, -pi, pi)) + fe = fourier_series(x**2, (-pi, pi)) + fp = fourier_series(Piecewise((0, x < 0), (pi, True)), (x, -pi, pi)) + return fo, fe, fp + + +def test_FourierSeries(): + fo, fe, fp = _get_examples() + + assert fourier_series(1, (-pi, pi)) == 1 + assert (Piecewise((0, x < 0), (pi, True)). + fourier_series((x, -pi, pi)).truncate()) == fp.truncate() + assert isinstance(fo, FourierSeries) + assert fo.function == x + assert fo.x == x + assert fo.period == (-pi, pi) + + assert fo.term(3) == 2*sin(3*x) / 3 + assert fe.term(3) == -4*cos(3*x) / 9 + assert fp.term(3) == 2*sin(3*x) / 3 + + assert fo.as_leading_term(x) == 2*sin(x) + assert fe.as_leading_term(x) == pi**2 / 3 + assert fp.as_leading_term(x) == pi / 2 + + assert fo.truncate() == 2*sin(x) - sin(2*x) + (2*sin(3*x) / 3) + assert fe.truncate() == -4*cos(x) + cos(2*x) + pi**2 / 3 + assert fp.truncate() == 2*sin(x) + (2*sin(3*x) / 3) + pi / 2 + + fot = fo.truncate(n=None) + s = [0, 2*sin(x), -sin(2*x)] + for i, t in enumerate(fot): + if i == 3: + break + assert s[i] == t + + def _check_iter(f, i): + for ind, t in enumerate(f): + assert t == f[ind] # noqa: PLR1736 + if ind == i: + break + + _check_iter(fo, 3) + _check_iter(fe, 3) + _check_iter(fp, 3) + + assert fo.subs(x, x**2) == fo + + raises(ValueError, lambda: fourier_series(x, (0, 1, 2))) + raises(ValueError, lambda: fourier_series(x, (x, 0, oo))) + raises(ValueError, lambda: fourier_series(x*y, (0, oo))) + + +def test_FourierSeries_2(): + p = Piecewise((0, x < 0), (x, True)) + f = fourier_series(p, (x, -2, 2)) + + assert f.term(3) == (2*sin(3*pi*x / 2) / (3*pi) - + 4*cos(3*pi*x / 2) / (9*pi**2)) + assert f.truncate() == (2*sin(pi*x / 2) / pi - sin(pi*x) / pi - + 4*cos(pi*x / 2) / pi**2 + S.Half) + + +def test_square_wave(): + """Test if fourier_series approximates discontinuous function correctly.""" + square_wave = Piecewise((1, x < pi), (-1, True)) + s = fourier_series(square_wave, (x, 0, 2*pi)) + + assert s.truncate(3) == 4 / pi * sin(x) + 4 / (3 * pi) * sin(3 * x) + \ + 4 / (5 * pi) * sin(5 * x) + assert s.sigma_approximation(4) == 4 / pi * sin(x) * sinc(pi / 4) + \ + 4 / (3 * pi) * sin(3 * x) * sinc(3 * pi / 4) + + +def test_sawtooth_wave(): + s = fourier_series(x, (x, 0, pi)) + assert s.truncate(4) == \ + pi/2 - sin(2*x) - sin(4*x)/2 - sin(6*x)/3 + s = fourier_series(x, (x, 0, 1)) + assert s.truncate(4) == \ + S.Half - sin(2*pi*x)/pi - sin(4*pi*x)/(2*pi) - sin(6*pi*x)/(3*pi) + + +def test_FourierSeries__operations(): + fo, fe, fp = _get_examples() + + fes = fe.scale(-1).shift(pi**2) + assert fes.truncate() == 4*cos(x) - cos(2*x) + 2*pi**2 / 3 + + assert fp.shift(-pi/2).truncate() == (2*sin(x) + (2*sin(3*x) / 3) + + (2*sin(5*x) / 5)) + + fos = fo.scale(3) + assert fos.truncate() == 6*sin(x) - 3*sin(2*x) + 2*sin(3*x) + + fx = fe.scalex(2).shiftx(1) + assert fx.truncate() == -4*cos(2*x + 2) + cos(4*x + 4) + pi**2 / 3 + + fl = fe.scalex(3).shift(-pi).scalex(2).shiftx(1).scale(4) + assert fl.truncate() == (-16*cos(6*x + 6) + 4*cos(12*x + 12) - + 4*pi + 4*pi**2 / 3) + + raises(ValueError, lambda: fo.shift(x)) + raises(ValueError, lambda: fo.shiftx(sin(x))) + raises(ValueError, lambda: fo.scale(x*y)) + raises(ValueError, lambda: fo.scalex(x**2)) + + +def test_FourierSeries__neg(): + fo, fe, fp = _get_examples() + + assert (-fo).truncate() == -2*sin(x) + sin(2*x) - (2*sin(3*x) / 3) + assert (-fe).truncate() == +4*cos(x) - cos(2*x) - pi**2 / 3 + + +def test_FourierSeries__add__sub(): + fo, fe, fp = _get_examples() + + assert fo + fo == fo.scale(2) + assert fo - fo == 0 + assert -fe - fe == fe.scale(-2) + + assert (fo + fe).truncate() == 2*sin(x) - sin(2*x) - 4*cos(x) + cos(2*x) \ + + pi**2 / 3 + assert (fo - fe).truncate() == 2*sin(x) - sin(2*x) + 4*cos(x) - cos(2*x) \ + - pi**2 / 3 + + assert isinstance(fo + 1, Add) + + raises(ValueError, lambda: fo + fourier_series(x, (x, 0, 2))) + + +def test_FourierSeries_finite(): + + assert fourier_series(sin(x)).truncate(1) == sin(x) + # assert type(fourier_series(sin(x)*log(x))).truncate() == FourierSeries + # assert type(fourier_series(sin(x**2+6))).truncate() == FourierSeries + assert fourier_series(sin(x)*log(y)*exp(z),(x,pi,-pi)).truncate() == sin(x)*log(y)*exp(z) + assert fourier_series(sin(x)**6).truncate(oo) == -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 \ + + Rational(5, 16) + assert fourier_series(sin(x) ** 6).truncate() == -15 * cos(2 * x) / 32 + 3 * cos(4 * x) / 16 \ + + Rational(5, 16) + assert fourier_series(sin(4*x+3) + cos(3*x+4)).truncate(oo) == -sin(4)*sin(3*x) + sin(4*x)*cos(3) \ + + cos(4)*cos(3*x) + sin(3)*cos(4*x) + assert fourier_series(sin(x)+cos(x)*tan(x)).truncate(oo) == 2*sin(x) + assert fourier_series(cos(pi*x), (x, -1, 1)).truncate(oo) == cos(pi*x) + assert fourier_series(cos(3*pi*x + 4) - sin(4*pi*x)*log(pi*y), (x, -1, 1)).truncate(oo) == -log(pi*y)*sin(4*pi*x)\ + - sin(4)*sin(3*pi*x) + cos(4)*cos(3*pi*x) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_gruntz.py b/phivenv/Lib/site-packages/sympy/series/tests/test_gruntz.py new file mode 100644 index 0000000000000000000000000000000000000000..4cae15297048bc52a69a3d9ca57a7614cfcdc61c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_gruntz.py @@ -0,0 +1,490 @@ +from sympy.core import EulerGamma +from sympy.core.function import Function +from sympy.core.numbers import (E, I, Integer, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acot, atan, cos, sin) +from sympy.functions.special.error_functions import (Ei, erf) +from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.polys.polytools import cancel +from sympy.functions.elementary.hyperbolic import cosh, coth, sinh, tanh +from sympy.series.gruntz import compare, mrv, rewrite, mrv_leadterm, gruntz, \ + sign +from sympy.testing.pytest import XFAIL, raises, skip, slow + +""" +This test suite is testing the limit algorithm using the bottom up approach. +See the documentation in limits2.py. The algorithm itself is highly recursive +by nature, so "compare" is logically the lowest part of the algorithm, yet in +some sense it's the most complex part, because it needs to calculate a limit +to return the result. + +Nevertheless, the rest of the algorithm depends on compare working correctly. +""" + +x = Symbol('x', real=True) +m = Symbol('m', real=True) + + +runslow = False + + +def _sskip(): + if not runslow: + skip("slow") + + +@slow +def test_gruntz_evaluation(): + # Gruntz' thesis pp. 122 to 123 + # 8.1 + assert gruntz(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x, oo) == -1 + # 8.2 + assert gruntz(exp(x)*(exp(1/x + exp(-x) + exp(-x**2)) + - exp(1/x - exp(-exp(x)))), x, oo) == 1 + # 8.3 + assert gruntz(exp(exp(x - exp(-x))/(1 - 1/x)) - exp(exp(x)), x, oo) is oo + # 8.5 + assert gruntz(exp(exp(exp(x + exp(-x)))) / exp(exp(exp(x))), x, oo) is oo + # 8.6 + assert gruntz(exp(exp(exp(x))) / exp(exp(exp(x - exp(-exp(x))))), + x, oo) is oo + # 8.7 + assert gruntz(exp(exp(exp(x))) / exp(exp(exp(x - exp(-exp(exp(x)))))), + x, oo) == 1 + # 8.8 + assert gruntz(exp(exp(x)) / exp(exp(x - exp(-exp(exp(x))))), x, oo) == 1 + # 8.9 + assert gruntz(log(x)**2 * exp(sqrt(log(x))*(log(log(x)))**2 + * exp(sqrt(log(log(x))) * (log(log(log(x))))**3)) / sqrt(x), + x, oo) == 0 + # 8.10 + assert gruntz((x*log(x)*(log(x*exp(x) - x**2))**2) + / (log(log(x**2 + 2*exp(exp(3*x**3*log(x)))))), x, oo) == Rational(1, 3) + # 8.11 + assert gruntz((exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1)))) - exp(x))/x, + x, oo) == -exp(2) + # 8.12 + assert gruntz((3**x + 5**x)**(1/x), x, oo) == 5 + # 8.13 + assert gruntz(x/log(x**(log(x**(log(2)/log(x))))), x, oo) is oo + # 8.14 + assert gruntz(exp(exp(2*log(x**5 + x)*log(log(x)))) + / exp(exp(10*log(x)*log(log(x)))), x, oo) is oo + # 8.15 + assert gruntz(exp(exp(Rational(5, 2)*x**Rational(-5, 7) + Rational(21, 8)*x**Rational(6, 11) + + 2*x**(-8) + Rational(54, 17)*x**Rational(49, 45)))**8 + / log(log(-log(Rational(4, 3)*x**Rational(-5, 14))))**Rational(7, 6), x, oo) is oo + # 8.16 + assert gruntz((exp(4*x*exp(-x)/(1/exp(x) + 1/exp(2*x**2/(x + 1)))) - exp(x)) + / exp(x)**4, x, oo) == 1 + # 8.17 + assert gruntz(exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1))))/exp(x), x, oo) \ + == 1 + # 8.19 + assert gruntz(log(x)*(log(log(x) + log(log(x))) - log(log(x))) + / (log(log(x) + log(log(log(x))))), x, oo) == 1 + # 8.20 + assert gruntz(exp((log(log(x + exp(log(x)*log(log(x)))))) + / (log(log(log(exp(x) + x + log(x)))))), x, oo) == E + # Another + assert gruntz(exp(exp(exp(x + exp(-x)))) / exp(exp(x)), x, oo) is oo + + +def test_gruntz_evaluation_slow(): + _sskip() + # 8.4 + assert gruntz(exp(exp(exp(x)/(1 - 1/x))) + - exp(exp(exp(x)/(1 - 1/x - log(x)**(-log(x))))), x, oo) is -oo + # 8.18 + assert gruntz((exp(exp(-x/(1 + exp(-x))))*exp(-x/(1 + exp(-x/(1 + exp(-x))))) + *exp(exp(-x + exp(-x/(1 + exp(-x)))))) + / (exp(-x/(1 + exp(-x))))**2 - exp(x) + x, x, oo) == 2 + + +@slow +def test_gruntz_eval_special(): + # Gruntz, p. 126 + assert gruntz(exp(x)*(sin(1/x + exp(-x)) - sin(1/x + exp(-x**2))), x, oo) == 1 + assert gruntz((erf(x - exp(-exp(x))) - erf(x)) * exp(exp(x)) * exp(x**2), + x, oo) == -2/sqrt(pi) + assert gruntz(exp(exp(x)) * (exp(sin(1/x + exp(-exp(x)))) - exp(sin(1/x))), + x, oo) == 1 + assert gruntz(exp(x)*(gamma(x + exp(-x)) - gamma(x)), x, oo) is oo + assert gruntz(exp(exp(digamma(digamma(x))))/x, x, oo) == exp(Rational(-1, 2)) + assert gruntz(exp(exp(digamma(log(x))))/x, x, oo) == exp(Rational(-1, 2)) + assert gruntz(digamma(digamma(digamma(x))), x, oo) is oo + assert gruntz(loggamma(loggamma(x)), x, oo) is oo + assert gruntz(((gamma(x + 1/gamma(x)) - gamma(x))/log(x) - cos(1/x)) + * x*log(x), x, oo) == Rational(-1, 2) + assert gruntz(x * (gamma(x - 1/gamma(x)) - gamma(x) + log(x)), x, oo) \ + == S.Half + assert gruntz((gamma(x + 1/gamma(x)) - gamma(x)) / log(x), x, oo) == 1 + + +def test_gruntz_eval_special_slow(): + _sskip() + assert gruntz(gamma(x + 1)/sqrt(2*pi) + - exp(-x)*(x**(x + S.Half) + x**(x - S.Half)/12), x, oo) is oo + assert gruntz(exp(exp(exp(digamma(digamma(digamma(x))))))/x, x, oo) == 0 + + +@XFAIL +def test_grunts_eval_special_slow_sometimes_fail(): + _sskip() + # XXX This sometimes fails!!! + assert gruntz(exp(gamma(x - exp(-x))*exp(1/x)) - exp(gamma(x)), x, oo) is oo + + +def test_gruntz_Ei(): + assert gruntz((Ei(x - exp(-exp(x))) - Ei(x)) *exp(-x)*exp(exp(x))*x, x, oo) == -1 + + +@XFAIL +def test_gruntz_eval_special_fail(): + # TODO zeta function series + assert gruntz( + exp((log(2) + 1)*x) * (zeta(x + exp(-x)) - zeta(x)), x, oo) == -log(2) + + # TODO 8.35 - 8.37 (bessel, max-min) + + +def test_gruntz_hyperbolic(): + assert gruntz(cosh(x), x, oo) is oo + assert gruntz(cosh(x), x, -oo) is oo + assert gruntz(sinh(x), x, oo) is oo + assert gruntz(sinh(x), x, -oo) is -oo + assert gruntz(2*cosh(x)*exp(x), x, oo) is oo + assert gruntz(2*cosh(x)*exp(x), x, -oo) == 1 + assert gruntz(2*sinh(x)*exp(x), x, oo) is oo + assert gruntz(2*sinh(x)*exp(x), x, -oo) == -1 + assert gruntz(tanh(x), x, oo) == 1 + assert gruntz(tanh(x), x, -oo) == -1 + assert gruntz(coth(x), x, oo) == 1 + assert gruntz(coth(x), x, -oo) == -1 + + +def test_compare1(): + assert compare(2, x, x) == "<" + assert compare(x, exp(x), x) == "<" + assert compare(exp(x), exp(x**2), x) == "<" + assert compare(exp(x**2), exp(exp(x)), x) == "<" + assert compare(1, exp(exp(x)), x) == "<" + + assert compare(x, 2, x) == ">" + assert compare(exp(x), x, x) == ">" + assert compare(exp(x**2), exp(x), x) == ">" + assert compare(exp(exp(x)), exp(x**2), x) == ">" + assert compare(exp(exp(x)), 1, x) == ">" + + assert compare(2, 3, x) == "=" + assert compare(3, -5, x) == "=" + assert compare(2, -5, x) == "=" + + assert compare(x, x**2, x) == "=" + assert compare(x**2, x**3, x) == "=" + assert compare(x**3, 1/x, x) == "=" + assert compare(1/x, x**m, x) == "=" + assert compare(x**m, -x, x) == "=" + + assert compare(exp(x), exp(-x), x) == "=" + assert compare(exp(-x), exp(2*x), x) == "=" + assert compare(exp(2*x), exp(x)**2, x) == "=" + assert compare(exp(x)**2, exp(x + exp(-x)), x) == "=" + assert compare(exp(x), exp(x + exp(-x)), x) == "=" + + assert compare(exp(x**2), 1/exp(x**2), x) == "=" + + +def test_compare2(): + assert compare(exp(x), x**5, x) == ">" + assert compare(exp(x**2), exp(x)**2, x) == ">" + assert compare(exp(x), exp(x + exp(-x)), x) == "=" + assert compare(exp(x + exp(-x)), exp(x), x) == "=" + assert compare(exp(x + exp(-x)), exp(-x), x) == "=" + assert compare(exp(-x), x, x) == ">" + assert compare(x, exp(-x), x) == "<" + assert compare(exp(x + 1/x), x, x) == ">" + assert compare(exp(-exp(x)), exp(x), x) == ">" + assert compare(exp(exp(-exp(x)) + x), exp(-exp(x)), x) == "<" + + +def test_compare3(): + assert compare(exp(exp(x)), exp(x + exp(-exp(x))), x) == ">" + + +def test_sign1(): + assert sign(Rational(0), x) == 0 + assert sign(Rational(3), x) == 1 + assert sign(Rational(-5), x) == -1 + assert sign(log(x), x) == 1 + assert sign(exp(-x), x) == 1 + assert sign(exp(x), x) == 1 + assert sign(-exp(x), x) == -1 + assert sign(3 - 1/x, x) == 1 + assert sign(-3 - 1/x, x) == -1 + assert sign(sin(1/x), x) == 1 + assert sign((x**Integer(2)), x) == 1 + assert sign(x**2, x) == 1 + assert sign(x**5, x) == 1 + + +def test_sign2(): + assert sign(x, x) == 1 + assert sign(-x, x) == -1 + y = Symbol("y", positive=True) + assert sign(y, x) == 1 + assert sign(-y, x) == -1 + assert sign(y*x, x) == 1 + assert sign(-y*x, x) == -1 + + +def mmrv(a, b): + return set(mrv(a, b)[0].keys()) + + +def test_mrv1(): + assert mmrv(x, x) == {x} + assert mmrv(x + 1/x, x) == {x} + assert mmrv(x**2, x) == {x} + assert mmrv(log(x), x) == {x} + assert mmrv(exp(x), x) == {exp(x)} + assert mmrv(exp(-x), x) == {exp(-x)} + assert mmrv(exp(x**2), x) == {exp(x**2)} + assert mmrv(-exp(1/x), x) == {x} + assert mmrv(exp(x + 1/x), x) == {exp(x + 1/x)} + + +def test_mrv2a(): + assert mmrv(exp(x + exp(-exp(x))), x) == {exp(-exp(x))} + assert mmrv(exp(x + exp(-x)), x) == {exp(x + exp(-x)), exp(-x)} + assert mmrv(exp(1/x + exp(-x)), x) == {exp(-x)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv2b(): + assert mmrv(exp(x + exp(-x**2)), x) == {exp(-x**2)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv2c(): + assert mmrv( + exp(-x + 1/x**2) - exp(x + 1/x), x) == {exp(x + 1/x), exp(1/x**2 - x)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv3(): + assert mmrv(exp(x**2) + x*exp(x) + log(x)**x/x, x) == {exp(x**2)} + assert mmrv( + exp(x)*(exp(1/x + exp(-x)) - exp(1/x)), x) == {exp(x), exp(-x)} + assert mmrv(log( + x**2 + 2*exp(exp(3*x**3*log(x)))), x) == {exp(exp(3*x**3*log(x)))} + assert mmrv(log(x - log(x))/log(x), x) == {x} + assert mmrv( + (exp(1/x - exp(-x)) - exp(1/x))*exp(x), x) == {exp(x), exp(-x)} + assert mmrv( + 1/exp(-x + exp(-x)) - exp(x), x) == {exp(x), exp(-x), exp(x - exp(-x))} + assert mmrv(log(log(x*exp(x*exp(x)) + 1)), x) == {exp(x*exp(x))} + assert mmrv(exp(exp(log(log(x) + 1/x))), x) == {x} + + +def test_mrv4(): + ln = log + assert mmrv((ln(ln(x) + ln(ln(x))) - ln(ln(x)))/ln(ln(x) + ln(ln(ln(x))))*ln(x), + x) == {x} + assert mmrv(log(log(x*exp(x*exp(x)) + 1)) - exp(exp(log(log(x) + 1/x))), x) == \ + {exp(x*exp(x))} + + +def mrewrite(a, b, c): + return rewrite(a[1], a[0], b, c) + + +def test_rewrite1(): + e = exp(x) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x) + e = exp(x**2) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x**2) + e = exp(x + 1/x) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x - 1/x) + e = 1/exp(-x + exp(-x)) - exp(x) + assert mrewrite(mrv(e, x), x, m) == ((-m*exp(m) + m)*exp(-m)/m**2, -x) + + +def test_rewrite2(): + e = exp(x)*log(log(exp(x))) + assert mmrv(e, x) == {exp(x)} + assert mrewrite(mrv(e, x), x, m) == (1/m*log(x), -x) + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_rewrite3(): + e = exp(-x + 1/x**2) - exp(x + 1/x) + #both of these are correct and should be equivalent: + assert mrewrite(mrv(e, x), x, m) in [(-1/m + m*exp( + (x**2 + x)/x**3), -x - 1/x), ((m**2 - exp((x**2 + x)/x**3))/m, x**(-2) - x)] + + +def test_mrv_leadterm1(): + assert mrv_leadterm(-exp(1/x), x) == (-1, 0) + assert mrv_leadterm(1/exp(-x + exp(-x)) - exp(x), x) == (-1, 0) + assert mrv_leadterm( + (exp(1/x - exp(-x)) - exp(1/x))*exp(x), x) == (-exp(1/x), 0) + + +def test_mrv_leadterm2(): + #Gruntz: p51, 3.25 + assert mrv_leadterm((log(exp(x) + x) - x)/log(exp(x) + log(x))*exp(x), x) == \ + (1, 0) + + +def test_mrv_leadterm3(): + #Gruntz: p56, 3.27 + assert mmrv(exp(-x + exp(-x)*exp(-x*log(x))), x) == {exp(-x - x*log(x))} + assert mrv_leadterm(exp(-x + exp(-x)*exp(-x*log(x))), x) == (exp(-x), 0) + + +def test_limit1(): + assert gruntz(x, x, oo) is oo + assert gruntz(x, x, -oo) is -oo + assert gruntz(-x, x, oo) is -oo + assert gruntz(x**2, x, -oo) is oo + assert gruntz(-x**2, x, oo) is -oo + assert gruntz(x*log(x), x, 0, dir="+") == 0 + assert gruntz(1/x, x, oo) == 0 + assert gruntz(exp(x), x, oo) is oo + assert gruntz(-exp(x), x, oo) is -oo + assert gruntz(exp(x)/x, x, oo) is oo + assert gruntz(1/x - exp(-x), x, oo) == 0 + assert gruntz(x + 1/x, x, oo) is oo + + +def test_limit2(): + assert gruntz(x**x, x, 0, dir="+") == 1 + assert gruntz((exp(x) - 1)/x, x, 0) == 1 + assert gruntz(1 + 1/x, x, oo) == 1 + assert gruntz(-exp(1/x), x, oo) == -1 + assert gruntz(x + exp(-x), x, oo) is oo + assert gruntz(x + exp(-x**2), x, oo) is oo + assert gruntz(x + exp(-exp(x)), x, oo) is oo + assert gruntz(13 + 1/x - exp(-x), x, oo) == 13 + + +def test_limit3(): + a = Symbol('a') + assert gruntz(x - log(1 + exp(x)), x, oo) == 0 + assert gruntz(x - log(a + exp(x)), x, oo) == 0 + assert gruntz(exp(x)/(1 + exp(x)), x, oo) == 1 + assert gruntz(exp(x)/(a + exp(x)), x, oo) == 1 + + +def test_limit4(): + #issue 3463 + assert gruntz((3**x + 5**x)**(1/x), x, oo) == 5 + #issue 3463 + assert gruntz((3**(1/x) + 5**(1/x))**x, x, 0) == 5 + + +@XFAIL +def test_MrvTestCase_page47_ex3_21(): + h = exp(-x/(1 + exp(-x))) + expr = exp(h)*exp(-x/(1 + h))*exp(exp(-x + h))/h**2 - exp(x) + x + assert mmrv(expr, x) == {1/h, exp(-x), exp(x), exp(x - h), exp(x/(1 + h))} + + +def test_gruntz_I(): + y = Symbol("y") + assert gruntz(I*x, x, oo) == I*oo + assert gruntz(y*I*x, x, oo) == y*I*oo + assert gruntz(y*3*I*x, x, oo) == y*I*oo + assert gruntz(y*3*sin(I)*x, x, oo) == y*I*oo + + +def test_issue_4814(): + assert gruntz((x + 1)**(1/log(x + 1)), x, oo) == E + + +def test_intractable(): + assert gruntz(1/gamma(x), x, oo) == 0 + assert gruntz(1/loggamma(x), x, oo) == 0 + assert gruntz(gamma(x)/loggamma(x), x, oo) is oo + assert gruntz(exp(gamma(x))/gamma(x), x, oo) is oo + assert gruntz(gamma(x), x, 3) == 2 + assert gruntz(gamma(Rational(1, 7) + 1/x), x, oo) == gamma(Rational(1, 7)) + assert gruntz(log(x**x)/log(gamma(x)), x, oo) == 1 + assert gruntz(log(gamma(gamma(x)))/exp(x), x, oo) is oo + + +def test_aseries_trig(): + assert cancel(gruntz(1/log(atan(x)), x, oo) + - 1/(log(pi) + log(S.Half))) == 0 + assert gruntz(1/acot(x), x, -oo) is -oo + + +def test_exp_log_series(): + assert gruntz(x/log(log(x*exp(x))), x, oo) is oo + + +def test_issue_3644(): + assert gruntz(((x**7 + x + 1)/(2**x + x**2))**(-1/x), x, oo) == 2 + + +def test_issue_6843(): + n = Symbol('n', integer=True, positive=True) + r = (n + 1)*x**(n + 1)/(x**(n + 1) - 1) - x/(x - 1) + assert gruntz(r, x, 1).simplify() == n/2 + + +def test_issue_4190(): + assert gruntz(x - gamma(1/x), x, oo) == S.EulerGamma + + +@XFAIL +def test_issue_5172(): + n = Symbol('n') + r = Symbol('r', positive=True) + c = Symbol('c') + p = Symbol('p', positive=True) + m = Symbol('m', negative=True) + expr = ((2*n*(n - r + 1)/(n + r*(n - r + 1)))**c + \ + (r - 1)*(n*(n - r + 2)/(n + r*(n - r + 1)))**c - n)/(n**c - n) + expr = expr.subs(c, c + 1) + assert gruntz(expr.subs(c, m), n, oo) == 1 + # fail: + assert gruntz(expr.subs(c, p), n, oo).simplify() == \ + (2**(p + 1) + r - 1)/(r + 1)**(p + 1) + + +def test_issue_4109(): + assert gruntz(1/gamma(x), x, 0) == 0 + assert gruntz(x*gamma(x), x, 0) == 1 + + +def test_issue_6682(): + assert gruntz(exp(2*Ei(-x))/x**2, x, 0) == exp(2*EulerGamma) + + +def test_issue_7096(): + from sympy.functions import sign + assert gruntz(x**-pi, x, 0, dir='-') == oo*sign((-1)**(-pi)) + + +def test_issue_7391_8166(): + f = Function('f') + # limit should depend on the continuity of the expression at the point passed + raises(ValueError, lambda: gruntz(f(x), x, 4)) + raises(ValueError, lambda: gruntz(x*f(x)**2/(x**2 + f(x)**4), x, 0)) + + +def test_issue_24210_25885(): + eq = exp(x)/(1+1/x)**x**2 + ans = sqrt(E) + assert gruntz(eq, x, oo) == ans + assert gruntz(1/eq, x, oo) == 1/ans diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_kauers.py b/phivenv/Lib/site-packages/sympy/series/tests/test_kauers.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb9044b33416bc38879649b258150ba2906250c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_kauers.py @@ -0,0 +1,23 @@ +from sympy.series.kauers import finite_diff +from sympy.series.kauers import finite_diff_kauers +from sympy.abc import x, y, z, m, n, w +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.concrete.summations import Sum + + +def test_finite_diff(): + assert finite_diff(x**2 + 2*x + 1, x) == 2*x + 3 + assert finite_diff(y**3 + 2*y**2 + 3*y + 5, y) == 3*y**2 + 7*y + 6 + assert finite_diff(z**2 - 2*z + 3, z) == 2*z - 1 + assert finite_diff(w**2 + 3*w - 2, w) == 2*w + 4 + assert finite_diff(sin(x), x, pi/6) == -sin(x) + sin(x + pi/6) + assert finite_diff(cos(y), y, pi/3) == -cos(y) + cos(y + pi/3) + assert finite_diff(x**2 - 2*x + 3, x, 2) == 4*x + assert finite_diff(n**2 - 2*n + 3, n, 3) == 6*n + 3 + +def test_finite_diff_kauers(): + assert finite_diff_kauers(Sum(x**2, (x, 1, n))) == (n + 1)**2 + assert finite_diff_kauers(Sum(y, (y, 1, m))) == (m + 1) + assert finite_diff_kauers(Sum((x*y), (x, 1, m), (y, 1, n))) == (m + 1)*(n + 1) + assert finite_diff_kauers(Sum((x*y**2), (x, 1, m), (y, 1, n))) == (n + 1)**2*(m + 1) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_limits.py b/phivenv/Lib/site-packages/sympy/series/tests/test_limits.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3ab7683f057424f1c3215a06381d27687710dc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_limits.py @@ -0,0 +1,1440 @@ +from itertools import product + +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, diff) +from sympy.core import EulerGamma, GoldenRatio +from sympy.core.mod import Mod +from sympy.core.numbers import (E, I, Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.numbers import fibonacci +from sympy.functions.combinatorial.factorials import (binomial, factorial, subfactorial) +from sympy.functions.elementary.complexes import (Abs, re, sign) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (atanh, asinh, acosh, acoth, acsch, asech, tanh, sinh) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (cbrt, real_root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, acot, acsc, asec, asin, + atan, cos, cot, csc, sec, sin, tan) +from sympy.functions.special.bessel import (besseli, bessely, besselj, besselk) +from sympy.functions.special.error_functions import (Ei, erf, erfc, erfi, fresnelc, fresnels) +from sympy.functions.special.gamma_functions import (digamma, gamma, uppergamma) +from sympy.functions.special.hyper import meijerg +from sympy.integrals.integrals import (Integral, integrate) +from sympy.series.limits import (Limit, limit) +from sympy.simplify.simplify import (logcombine, simplify) +from sympy.simplify.hyperexpand import hyperexpand + +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core.mul import Mul +from sympy.series.limits import heuristics +from sympy.series.order import Order +from sympy.testing.pytest import XFAIL, raises + +from sympy import elliptic_e, elliptic_k + +from sympy.abc import x, y, z, k +n = Symbol('n', integer=True, positive=True) + + +def test_basic1(): + assert limit(x, x, oo) is oo + assert limit(x, x, -oo) is -oo + assert limit(-x, x, oo) is -oo + assert limit(x**2, x, -oo) is oo + assert limit(-x**2, x, oo) is -oo + assert limit(x*log(x), x, 0, dir="+") == 0 + assert limit(1/x, x, oo) == 0 + assert limit(exp(x), x, oo) is oo + assert limit(-exp(x), x, oo) is -oo + assert limit(exp(x)/x, x, oo) is oo + assert limit(1/x - exp(-x), x, oo) == 0 + assert limit(x + 1/x, x, oo) is oo + assert limit(x - x**2, x, oo) is -oo + assert limit((1 + x)**(1 + sqrt(2)), x, 0) == 1 + assert limit((1 + x)**oo, x, 0) == Limit((x + 1)**oo, x, 0) + assert limit((1 + x)**oo, x, 0, dir='-') == Limit((x + 1)**oo, x, 0, dir='-') + assert limit((1 + x + y)**oo, x, 0, dir='-') == Limit((1 + x + y)**oo, x, 0, dir='-') + assert limit(y/x/log(x), x, 0) == -oo*sign(y) + assert limit(cos(x + y)/x, x, 0) == sign(cos(y))*oo + assert limit(gamma(1/x + 3), x, oo) == 2 + assert limit(S.NaN, x, -oo) is S.NaN + assert limit(Order(2)*x, x, S.NaN) is S.NaN + assert limit(1/(x - 1), x, 1, dir="+") is oo + assert limit(1/(x - 1), x, 1, dir="-") is -oo + assert limit(1/(5 - x)**3, x, 5, dir="+") is -oo + assert limit(1/(5 - x)**3, x, 5, dir="-") is oo + assert limit(1/sin(x), x, pi, dir="+") is -oo + assert limit(1/sin(x), x, pi, dir="-") is oo + assert limit(1/cos(x), x, pi/2, dir="+") is -oo + assert limit(1/cos(x), x, pi/2, dir="-") is oo + assert limit(1/tan(x**3), x, (2*pi)**Rational(1, 3), dir="+") is oo + assert limit(1/tan(x**3), x, (2*pi)**Rational(1, 3), dir="-") is -oo + assert limit(1/cot(x)**3, x, (pi*Rational(3, 2)), dir="+") is -oo + assert limit(1/cot(x)**3, x, (pi*Rational(3, 2)), dir="-") is oo + assert limit(tan(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cot(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(sec(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(csc(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + + # test bi-directional limits + assert limit(sin(x)/x, x, 0, dir="+-") == 1 + assert limit(x**2, x, 0, dir="+-") == 0 + assert limit(1/x**2, x, 0, dir="+-") is oo + + # test failing bi-directional limits + assert limit(1/x, x, 0, dir="+-") is zoo + # approaching 0 + # from dir="+" + assert limit(1 + 1/x, x, 0) is oo + # from dir='-' + # Add + assert limit(1 + 1/x, x, 0, dir='-') is -oo + # Pow + assert limit(x**(-2), x, 0, dir='-') is oo + assert limit(x**(-3), x, 0, dir='-') is -oo + assert limit(1/sqrt(x), x, 0, dir='-') == (-oo)*I + assert limit(x**2, x, 0, dir='-') == 0 + assert limit(sqrt(x), x, 0, dir='-') == 0 + assert limit(x**-pi, x, 0, dir='-') == -oo*(-1)**(1 - pi) + assert limit((1 + cos(x))**oo, x, 0) == Limit((cos(x) + 1)**oo, x, 0) + + # test pull request 22491 + assert limit(1/asin(x), x, 0, dir = '+') == oo + assert limit(1/asin(x), x, 0, dir = '-') == -oo + assert limit(1/sinh(x), x, 0, dir = '+') == oo + assert limit(1/sinh(x), x, 0, dir = '-') == -oo + assert limit(log(1/x) + 1/sin(x), x, 0, dir = '+') == oo + assert limit(log(1/x) + 1/x, x, 0, dir = '+') == oo + + +def test_basic2(): + assert limit(x**x, x, 0, dir="+") == 1 + assert limit((exp(x) - 1)/x, x, 0) == 1 + assert limit(1 + 1/x, x, oo) == 1 + assert limit(-exp(1/x), x, oo) == -1 + assert limit(x + exp(-x), x, oo) is oo + assert limit(x + exp(-x**2), x, oo) is oo + assert limit(x + exp(-exp(x)), x, oo) is oo + assert limit(13 + 1/x - exp(-x), x, oo) == 13 + + +def test_basic3(): + assert limit(1/x, x, 0, dir="+") is oo + assert limit(1/x, x, 0, dir="-") is -oo + + +def test_basic4(): + assert limit(2*x + y*x, x, 0) == 0 + assert limit(2*x + y*x, x, 1) == 2 + y + assert limit(2*x**8 + y*x**(-3), x, -2) == 512 - y/8 + assert limit(sqrt(x + 1) - sqrt(x), x, oo) == 0 + assert integrate(1/(x**3 + 1), (x, 0, oo)) == 2*pi*sqrt(3)/9 + + +def test_log(): + # https://github.com/sympy/sympy/issues/21598 + a, b, c = symbols('a b c', positive=True) + A = log(a/b) - (log(a) - log(b)) + assert A.limit(a, oo) == 0 + assert (A * c).limit(a, oo) == 0 + + tau, x = symbols('tau x', positive=True) + # The value of manualintegrate in the issue + expr = tau**2*((tau - 1)*(tau + 1)*log(x + 1)/(tau**2 + 1)**2 + 1/((tau**2\ + + 1)*(x + 1)) - (-2*tau*atan(x/tau) + (tau**2/2 - 1/2)*log(tau**2\ + + x**2))/(tau**2 + 1)**2) + assert limit(expr, x, oo) == pi*tau**3/(tau**2 + 1)**2 + + +def test_piecewise(): + # https://github.com/sympy/sympy/issues/18363 + assert limit((real_root(x - 6, 3) + 2)/(x + 2), x, -2, '+') == Rational(1, 12) + + +def test_piecewise2(): + func1 = 2*sqrt(x)*Piecewise(((4*x - 2)/Abs(sqrt(4 - 4*(2*x - 1)**2)), 4*x - 2\ + >= 0), ((2 - 4*x)/Abs(sqrt(4 - 4*(2*x - 1)**2)), True)) + func2 = Piecewise((x**2/2, x <= 0.5), (x/2 - 0.125, True)) + func3 = Piecewise(((x - 9) / 5, x < -1), ((x - 9) / 5, x > 4), (sqrt(Abs(x - 3)), True)) + assert limit(func1, x, 0) == 1 + assert limit(func2, x, 0) == 0 + assert limit(func3, x, -1) == 2 + + +def test_basic5(): + class my(Function): + @classmethod + def eval(cls, arg): + if arg is S.Infinity: + return S.NaN + assert limit(my(x), x, oo) == Limit(my(x), x, oo) + + +def test_issue_3885(): + assert limit(x*y + x*z, z, 2) == x*y + 2*x + + +def test_Limit(): + assert Limit(sin(x)/x, x, 0) != 1 + assert Limit(sin(x)/x, x, 0).doit() == 1 + assert Limit(x, x, 0, dir='+-').args == (x, x, 0, Symbol('+-')) + + +def test_floor(): + assert limit(floor(x), x, -2, "+") == -2 + assert limit(floor(x), x, -2, "-") == -3 + assert limit(floor(x), x, -1, "+") == -1 + assert limit(floor(x), x, -1, "-") == -2 + assert limit(floor(x), x, 0, "+") == 0 + assert limit(floor(x), x, 0, "-") == -1 + assert limit(floor(x), x, 1, "+") == 1 + assert limit(floor(x), x, 1, "-") == 0 + assert limit(floor(x), x, 2, "+") == 2 + assert limit(floor(x), x, 2, "-") == 1 + assert limit(floor(x), x, 248, "+") == 248 + assert limit(floor(x), x, 248, "-") == 247 + + # https://github.com/sympy/sympy/issues/14478 + assert limit(x*floor(3/x)/2, x, 0, '+') == Rational(3, 2) + assert limit(floor(x + 1/2) - floor(x), x, oo) == AccumBounds(-S.Half, S(3)/2) + + # test issue 9158 + assert limit(floor(atan(x)), x, oo) == 1 + assert limit(floor(atan(x)), x, -oo) == -2 + assert limit(ceiling(atan(x)), x, oo) == 2 + assert limit(ceiling(atan(x)), x, -oo) == -1 + + +def test_floor_requires_robust_assumptions(): + assert limit(floor(sin(x)), x, 0, "+") == 0 + assert limit(floor(sin(x)), x, 0, "-") == -1 + assert limit(floor(cos(x)), x, 0, "+") == 0 + assert limit(floor(cos(x)), x, 0, "-") == 0 + assert limit(floor(5 + sin(x)), x, 0, "+") == 5 + assert limit(floor(5 + sin(x)), x, 0, "-") == 4 + assert limit(floor(5 + cos(x)), x, 0, "+") == 5 + assert limit(floor(5 + cos(x)), x, 0, "-") == 5 + + +def test_ceiling(): + assert limit(ceiling(x), x, -2, "+") == -1 + assert limit(ceiling(x), x, -2, "-") == -2 + assert limit(ceiling(x), x, -1, "+") == 0 + assert limit(ceiling(x), x, -1, "-") == -1 + assert limit(ceiling(x), x, 0, "+") == 1 + assert limit(ceiling(x), x, 0, "-") == 0 + assert limit(ceiling(x), x, 1, "+") == 2 + assert limit(ceiling(x), x, 1, "-") == 1 + assert limit(ceiling(x), x, 2, "+") == 3 + assert limit(ceiling(x), x, 2, "-") == 2 + assert limit(ceiling(x), x, 248, "+") == 249 + assert limit(ceiling(x), x, 248, "-") == 248 + + # https://github.com/sympy/sympy/issues/14478 + assert limit(x*ceiling(3/x)/2, x, 0, '+') == Rational(3, 2) + assert limit(ceiling(x + 1/2) - ceiling(x), x, oo) == AccumBounds(-S.Half, S(3)/2) + + +def test_ceiling_requires_robust_assumptions(): + assert limit(ceiling(sin(x)), x, 0, "+") == 1 + assert limit(ceiling(sin(x)), x, 0, "-") == 0 + assert limit(ceiling(cos(x)), x, 0, "+") == 1 + assert limit(ceiling(cos(x)), x, 0, "-") == 1 + assert limit(ceiling(5 + sin(x)), x, 0, "+") == 6 + assert limit(ceiling(5 + sin(x)), x, 0, "-") == 5 + assert limit(ceiling(5 + cos(x)), x, 0, "+") == 6 + assert limit(ceiling(5 + cos(x)), x, 0, "-") == 6 + + +def test_frac(): + assert limit(frac(x), x, oo) == AccumBounds(0, 1) + assert limit(frac(x)**(1/x), x, oo) == AccumBounds(0, 1) + assert limit(frac(x)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit(frac(x)**x, x, oo) == AccumBounds(0, oo) # wolfram gives (0, 1) + assert limit(frac(sin(x)), x, 0, "+") == 0 + assert limit(frac(sin(x)), x, 0, "-") == 1 + assert limit(frac(cos(x)), x, 0, "+-") == 1 + assert limit(frac(x**2), x, 0, "+-") == 0 + raises(ValueError, lambda: limit(frac(x), x, 0, '+-')) + assert limit(frac(-2*x + 1), x, 0, "+") == 1 + assert limit(frac(-2*x + 1), x, 0, "-") == 0 + assert limit(frac(x + S.Half), x, 0, "+-") == S(1)/2 + assert limit(frac(1/x), x, 0) == AccumBounds(0, 1) + + +def test_issue_14355(): + assert limit(floor(sin(x)/x), x, 0, '+') == 0 + assert limit(floor(sin(x)/x), x, 0, '-') == 0 + # test comment https://github.com/sympy/sympy/issues/14355#issuecomment-372121314 + assert limit(floor(-tan(x)/x), x, 0, '+') == -2 + assert limit(floor(-tan(x)/x), x, 0, '-') == -2 + + +def test_atan(): + x = Symbol("x", real=True) + assert limit(atan(x)*sin(1/x), x, 0) == 0 + assert limit(atan(x) + sqrt(x + 1) - sqrt(x), x, oo) == pi/2 + + +def test_set_signs(): + assert limit(abs(x), x, 0) == 0 + assert limit(abs(sin(x)), x, 0) == 0 + assert limit(abs(cos(x)), x, 0) == 1 + assert limit(abs(sin(x + 1)), x, 0) == sin(1) + + # https://github.com/sympy/sympy/issues/9449 + assert limit((Abs(x + y) - Abs(x - y))/(2*x), x, 0) == sign(y) + + # https://github.com/sympy/sympy/issues/12398 + assert limit(Abs(log(x)/x**3), x, oo) == 0 + assert limit(x*(Abs(log(x)/x**3)/Abs(log(x + 1)/(x + 1)**3) - 1), x, oo) == 3 + + # https://github.com/sympy/sympy/issues/18501 + assert limit(Abs(log(x - 1)**3 - 1), x, 1, '+') == oo + + # https://github.com/sympy/sympy/issues/18997 + assert limit(Abs(log(x)), x, 0) == oo + assert limit(Abs(log(Abs(x))), x, 0) == oo + + # https://github.com/sympy/sympy/issues/19026 + z = Symbol('z', positive=True) + assert limit(Abs(log(z) + 1)/log(z), z, oo) == 1 + + # https://github.com/sympy/sympy/issues/20704 + assert limit(z*(Abs(1/z + y) - Abs(y - 1/z))/2, z, 0) == 0 + + # https://github.com/sympy/sympy/issues/21606 + assert limit(cos(z)/sign(z), z, pi, '-') == -1 + + +def test_heuristic(): + x = Symbol("x", real=True) + assert heuristics(sin(1/x) + atan(x), x, 0, '+') == AccumBounds(-1, 1) + assert limit(log(2 + sqrt(atan(x))*sqrt(sin(1/x))), x, 0) == log(2) + + +def test_issue_3871(): + z = Symbol("z", positive=True) + f = -1/z*exp(-z*x) + assert limit(f, x, oo) == 0 + assert f.limit(x, oo) == 0 + + +def test_exponential(): + n = Symbol('n') + x = Symbol('x', real=True) + assert limit((1 + x/n)**n, n, oo) == exp(x) + assert limit((1 + x/(2*n))**n, n, oo) == exp(x/2) + assert limit((1 + x/(2*n + 1))**n, n, oo) == exp(x/2) + assert limit(((x - 1)/(x + 1))**x, x, oo) == exp(-2) + assert limit(1 + (1 + 1/x)**x, x, oo) == 1 + S.Exp1 + assert limit((2 + 6*x)**x/(6*x)**x, x, oo) == exp(S('1/3')) + + +def test_exponential2(): + n = Symbol('n') + assert limit((1 + x/(n + sin(n)))**n, n, oo) == exp(x) + + +def test_doit(): + f = Integral(2 * x, x) + l = Limit(f, x, oo) + assert l.doit() is oo + + +def test_series_AccumBounds(): + assert limit(sin(k) - sin(k + 1), k, oo) == AccumBounds(-2, 2) + assert limit(cos(k) - cos(k + 1) + 1, k, oo) == AccumBounds(-1, 3) + + # not the exact bound + assert limit(sin(k) - sin(k)*cos(k), k, oo) == AccumBounds(-2, 2) + + # test for issue #9934 + lo = (-3 + cos(1))/2 + hi = (1 + cos(1))/2 + t1 = Mul(AccumBounds(lo, hi), 1/(-1 + cos(1)), evaluate=False) + assert limit(simplify(Sum(cos(n).rewrite(exp), (n, 0, k)).doit().rewrite(sin)), k, oo) == t1 + + t2 = Mul(AccumBounds(-1 + sin(1)/2, sin(1)/2 + 1), 1/(1 - cos(1))) + assert limit(simplify(Sum(sin(n).rewrite(exp), (n, 0, k)).doit().rewrite(sin)), k, oo) == t2 + + assert limit(((sin(x) + 1)/2)**x, x, oo) == AccumBounds(0, oo) # wolfram says 0 + + # https://github.com/sympy/sympy/issues/12312 + e = 2**(-x)*(sin(x) + 1)**x + assert limit(e, x, oo) == AccumBounds(0, oo) + + +def test_bessel_functions_at_infinity(): + # Pull Request 23844 implements limits for all bessel and modified bessel + # functions approaching infinity along any direction i.e. abs(z0) tends to oo + + assert limit(besselj(1, x), x, oo) == 0 + assert limit(besselj(1, x), x, -oo) == 0 + assert limit(besselj(1, x), x, I*oo) == oo*I + assert limit(besselj(1, x), x, -I*oo) == -oo*I + assert limit(bessely(1, x), x, oo) == 0 + assert limit(bessely(1, x), x, -oo) == 0 + assert limit(bessely(1, x), x, I*oo) == -oo + assert limit(bessely(1, x), x, -I*oo) == -oo + assert limit(besseli(1, x), x, oo) == oo + assert limit(besseli(1, x), x, -oo) == -oo + assert limit(besseli(1, x), x, I*oo) == 0 + assert limit(besseli(1, x), x, -I*oo) == 0 + assert limit(besselk(1, x), x, oo) == 0 + assert limit(besselk(1, x), x, -oo) == -oo*I + assert limit(besselk(1, x), x, I*oo) == 0 + assert limit(besselk(1, x), x, -I*oo) == 0 + + # test issue 14874 + assert limit(besselk(0, x), x, oo) == 0 + + +@XFAIL +def test_doit2(): + f = Integral(2 * x, x) + l = Limit(f, x, oo) + # limit() breaks on the contained Integral. + assert l.doit(deep=False) == l + + +def test_issue_2929(): + assert limit((x * exp(x))/(exp(x) - 1), x, -oo) == 0 + + +def test_issue_3792(): + assert limit((1 - cos(x))/x**2, x, S.Half) == 4 - 4*cos(S.Half) + assert limit(sin(sin(x + 1) + 1), x, 0) == sin(1 + sin(1)) + assert limit(abs(sin(x + 1) + 1), x, 0) == 1 + sin(1) + + +def test_issue_4090(): + assert limit(1/(x + 3), x, 2) == Rational(1, 5) + assert limit(1/(x + pi), x, 2) == S.One/(2 + pi) + assert limit(log(x)/(x**2 + 3), x, 2) == log(2)/7 + assert limit(log(x)/(x**2 + pi), x, 2) == log(2)/(4 + pi) + + +def test_issue_4547(): + assert limit(cot(x), x, 0, dir='+') is oo + assert limit(cot(x), x, pi/2, dir='+') == 0 + + +def test_issue_5164(): + assert limit(x**0.5, x, oo) == oo**0.5 is oo + assert limit(x**0.5, x, 16) == 4 # Should this be a float? + assert limit(x**0.5, x, 0) == 0 + assert limit(x**(-0.5), x, oo) == 0 + assert limit(x**(-0.5), x, 4) == S.Half # Should this be a float? + + +def test_issue_5383(): + func = (1.0 * 1 + 1.0 * x)**(1.0 * 1 / x) + assert limit(func, x, 0) == E + + +def test_issue_14793(): + expr = ((x + S(1)/2) * log(x) - x + log(2*pi)/2 - \ + log(factorial(x)) + S(1)/(12*x))*x**3 + assert limit(expr, x, oo) == S(1)/360 + + +def test_issue_5183(): + # using list(...) so py.test can recalculate values + tests = list(product([x, -x], + [-1, 1], + [2, 3, S.Half, Rational(2, 3)], + ['-', '+'])) + results = (oo, oo, -oo, oo, -oo*I, oo, -oo*(-1)**Rational(1, 3), oo, + 0, 0, 0, 0, 0, 0, 0, 0, + oo, oo, oo, -oo, oo, -oo*I, oo, -oo*(-1)**Rational(1, 3), + 0, 0, 0, 0, 0, 0, 0, 0) + assert len(tests) == len(results) + for i, (args, res) in enumerate(zip(tests, results)): + y, s, e, d = args + eq = y**(s*e) + try: + assert limit(eq, x, 0, dir=d) == res + except AssertionError: + if 0: # change to 1 if you want to see the failing tests + print() + print(i, res, eq, d, limit(eq, x, 0, dir=d)) + else: + assert None + + +def test_issue_5184(): + assert limit(sin(x)/x, x, oo) == 0 + assert limit(atan(x), x, oo) == pi/2 + assert limit(gamma(x), x, oo) is oo + assert limit(cos(x)/x, x, oo) == 0 + assert limit(gamma(x), x, S.Half) == sqrt(pi) + + r = Symbol('r', real=True) + assert limit(r*sin(1/r), r, 0) == 0 + + +def test_issue_5229(): + assert limit((1 + y)**(1/y) - S.Exp1, y, 0) == 0 + + +def test_issue_4546(): + # using list(...) so py.test can recalculate values + tests = list(product([cot, tan], + [-pi/2, 0, pi/2, pi, pi*Rational(3, 2)], + ['-', '+'])) + results = (0, 0, -oo, oo, 0, 0, -oo, oo, 0, 0, + oo, -oo, 0, 0, oo, -oo, 0, 0, oo, -oo) + assert len(tests) == len(results) + for i, (args, res) in enumerate(zip(tests, results)): + f, l, d = args + eq = f(x) + try: + assert limit(eq, x, l, dir=d) == res + except AssertionError: + if 0: # change to 1 if you want to see the failing tests + print() + print(i, res, eq, l, d, limit(eq, x, l, dir=d)) + else: + assert None + + +def test_issue_3934(): + assert limit((1 + x**log(3))**(1/x), x, 0) == 1 + assert limit((5**(1/x) + 3**(1/x))**x, x, 0) == 5 + + +def test_issue_5955(): + assert limit((x**16)/(1 + x**16), x, oo) == 1 + assert limit((x**100)/(1 + x**100), x, oo) == 1 + assert limit((x**1885)/(1 + x**1885), x, oo) == 1 + assert limit((x**1000/((x + 1)**1000 + exp(-x))), x, oo) == 1 + + +def test_newissue(): + assert limit(exp(1/sin(x))/exp(cot(x)), x, 0) == 1 + + +def test_extended_real_line(): + assert limit(x - oo, x, oo) == Limit(x - oo, x, oo) + assert limit(1/(x + sin(x)) - oo, x, 0) == Limit(1/(x + sin(x)) - oo, x, 0) + assert limit(oo/x, x, oo) == Limit(oo/x, x, oo) + assert limit(x - oo + 1/x, x, oo) == Limit(x - oo + 1/x, x, oo) + + +@XFAIL +def test_order_oo(): + x = Symbol('x', positive=True) + assert Order(x)*oo != Order(1, x) + assert limit(oo/(x**2 - 4), x, oo) is oo + + +def test_issue_5436(): + raises(NotImplementedError, lambda: limit(exp(x*y), x, oo)) + raises(NotImplementedError, lambda: limit(exp(-x*y), x, oo)) + + +def test_Limit_dir(): + raises(TypeError, lambda: Limit(x, x, 0, dir=0)) + raises(ValueError, lambda: Limit(x, x, 0, dir='0')) + + +def test_polynomial(): + assert limit((x + 1)**1000/((x + 1)**1000 + 1), x, oo) == 1 + assert limit((x + 1)**1000/((x + 1)**1000 + 1), x, -oo) == 1 + assert limit(x ** Rational(77, 3) / (1 + x ** Rational(77, 3)), x, oo) == 1 + assert limit(x ** 101.1 / (1 + x ** 101.1), x, oo) == 1 + + +def test_rational(): + assert limit(1/y - (1/(y + x) + x/(y + x)/y)/z, x, oo) == (z - 1)/(y*z) + assert limit(1/y - (1/(y + x) + x/(y + x)/y)/z, x, -oo) == (z - 1)/(y*z) + + +def test_issue_5740(): + assert limit(log(x)*z - log(2*x)*y, x, 0) == oo*sign(y - z) + + +def test_issue_6366(): + n = Symbol('n', integer=True, positive=True) + r = (n + 1)*x**(n + 1)/(x**(n + 1) - 1) - x/(x - 1) + assert limit(r, x, 1).cancel() == n/2 + + +def test_factorial(): + f = factorial(x) + assert limit(f, x, oo) is oo + assert limit(x/f, x, oo) == 0 + # see Stirling's approximation: + # https://en.wikipedia.org/wiki/Stirling's_approximation + assert limit(f/(sqrt(2*pi*x)*(x/E)**x), x, oo) == 1 + assert limit(f, x, -oo) == gamma(-oo) + + +def test_issue_6560(): + e = (5*x**3/4 - x*Rational(3, 4) + (y*(3*x**2/2 - S.Half) + + 35*x**4/8 - 15*x**2/4 + Rational(3, 8))/(2*(y + 1))) + assert limit(e, y, oo) == 5*x**3/4 + 3*x**2/4 - 3*x/4 - Rational(1, 4) + +@XFAIL +def test_issue_5172(): + n = Symbol('n') + r = Symbol('r', positive=True) + c = Symbol('c') + p = Symbol('p', positive=True) + m = Symbol('m', negative=True) + expr = ((2*n*(n - r + 1)/(n + r*(n - r + 1)))**c + + (r - 1)*(n*(n - r + 2)/(n + r*(n - r + 1)))**c - n)/(n**c - n) + expr = expr.subs(c, c + 1) + raises(NotImplementedError, lambda: limit(expr, n, oo)) + assert limit(expr.subs(c, m), n, oo) == 1 + assert limit(expr.subs(c, p), n, oo).simplify() == \ + (2**(p + 1) + r - 1)/(r + 1)**(p + 1) + + +def test_issue_7088(): + a = Symbol('a') + assert limit(sqrt(x/(x + a)), x, oo) == 1 + + +def test_branch_cuts(): + assert limit(asin(I*x + 2), x, 0) == pi - asin(2) + assert limit(asin(I*x + 2), x, 0, '-') == asin(2) + assert limit(asin(I*x - 2), x, 0) == -asin(2) + assert limit(asin(I*x - 2), x, 0, '-') == -pi + asin(2) + assert limit(acos(I*x + 2), x, 0) == -acos(2) + assert limit(acos(I*x + 2), x, 0, '-') == acos(2) + assert limit(acos(I*x - 2), x, 0) == acos(-2) + assert limit(acos(I*x - 2), x, 0, '-') == 2*pi - acos(-2) + assert limit(atan(x + 2*I), x, 0) == I*atanh(2) + assert limit(atan(x + 2*I), x, 0, '-') == -pi + I*atanh(2) + assert limit(atan(x - 2*I), x, 0) == pi - I*atanh(2) + assert limit(atan(x - 2*I), x, 0, '-') == -I*atanh(2) + assert limit(atan(1/x), x, 0) == pi/2 + assert limit(atan(1/x), x, 0, '-') == -pi/2 + assert limit(atan(x), x, oo) == pi/2 + assert limit(atan(x), x, -oo) == -pi/2 + assert limit(acot(x + S(1)/2*I), x, 0) == pi - I*acoth(S(1)/2) + assert limit(acot(x + S(1)/2*I), x, 0, '-') == -I*acoth(S(1)/2) + assert limit(acot(x - S(1)/2*I), x, 0) == I*acoth(S(1)/2) + assert limit(acot(x - S(1)/2*I), x, 0, '-') == -pi + I*acoth(S(1)/2) + assert limit(acot(x), x, 0) == pi/2 + assert limit(acot(x), x, 0, '-') == -pi/2 + assert limit(asec(I*x + S(1)/2), x, 0) == asec(S(1)/2) + assert limit(asec(I*x + S(1)/2), x, 0, '-') == -asec(S(1)/2) + assert limit(asec(I*x - S(1)/2), x, 0) == 2*pi - asec(-S(1)/2) + assert limit(asec(I*x - S(1)/2), x, 0, '-') == asec(-S(1)/2) + assert limit(acsc(I*x + S(1)/2), x, 0) == acsc(S(1)/2) + assert limit(acsc(I*x + S(1)/2), x, 0, '-') == pi - acsc(S(1)/2) + assert limit(acsc(I*x - S(1)/2), x, 0) == -pi + acsc(S(1)/2) + assert limit(acsc(I*x - S(1)/2), x, 0, '-') == -acsc(S(1)/2) + + assert limit(log(I*x - 1), x, 0) == I*pi + assert limit(log(I*x - 1), x, 0, '-') == -I*pi + assert limit(log(-I*x - 1), x, 0) == -I*pi + assert limit(log(-I*x - 1), x, 0, '-') == I*pi + + assert limit(sqrt(I*x - 1), x, 0) == I + assert limit(sqrt(I*x - 1), x, 0, '-') == -I + assert limit(sqrt(-I*x - 1), x, 0) == -I + assert limit(sqrt(-I*x - 1), x, 0, '-') == I + + assert limit(cbrt(I*x - 1), x, 0) == (-1)**(S(1)/3) + assert limit(cbrt(I*x - 1), x, 0, '-') == -(-1)**(S(2)/3) + assert limit(cbrt(-I*x - 1), x, 0) == -(-1)**(S(2)/3) + assert limit(cbrt(-I*x - 1), x, 0, '-') == (-1)**(S(1)/3) + + +def test_issue_6364(): + a = Symbol('a') + e = z/(1 - sqrt(1 + z)*sin(a)**2 - sqrt(1 - z)*cos(a)**2) + assert limit(e, z, 0) == 1/(cos(a)**2 - S.Half) + + +def test_issue_6682(): + assert limit(exp(2*Ei(-x))/x**2, x, 0) == exp(2*EulerGamma) + + +def test_issue_4099(): + a = Symbol('a') + assert limit(a/x, x, 0) == oo*sign(a) + assert limit(-a/x, x, 0) == -oo*sign(a) + assert limit(-a*x, x, oo) == -oo*sign(a) + assert limit(a*x, x, oo) == oo*sign(a) + + +def test_issue_4503(): + dx = Symbol('dx') + assert limit((sqrt(1 + exp(x + dx)) - sqrt(1 + exp(x)))/dx, dx, 0) == \ + exp(x)/(2*sqrt(exp(x) + 1)) + + +def test_issue_6052(): + G = meijerg((), (), (1,), (0,), -x) + g = hyperexpand(G) + assert limit(g, x, 0, '+-') == 0 + assert limit(g, x, oo) == -oo + + +def test_issue_7224(): + expr = sqrt(x)*besseli(1,sqrt(8*x)) + assert limit(x*diff(expr, x, x)/expr, x, 0) == 2 + assert limit(x*diff(expr, x, x)/expr, x, 1).evalf() == 2.0 + + +def test_issue_7391_8166(): + f = Function('f') + # limit should depend on the continuity of the expression at the point passed + assert limit(f(x), x, 4) == Limit(f(x), x, 4, dir='+') + assert limit(x*f(x)**2/(x**2 + f(x)**4), x, 0) == Limit(x*f(x)**2/(x**2 + f(x)**4), x, 0, dir='+') + + +def test_issue_8208(): + assert limit(n**(Rational(1, 1e9) - 1), n, oo) == 0 + + +def test_issue_8229(): + assert limit((x**Rational(1, 4) - 2)/(sqrt(x) - 4)**Rational(2, 3), x, 16) == 0 + + +def test_issue_8433(): + d, t = symbols('d t', positive=True) + assert limit(erf(1 - t/d), t, oo) == -1 + + +def test_issue_8481(): + k = Symbol('k', integer=True, nonnegative=True) + lamda = Symbol('lamda', positive=True) + assert limit(lamda**k * exp(-lamda) / factorial(k), k, oo) == 0 + + +def test_issue_8462(): + assert limit(binomial(n, n/2), n, oo) == oo + assert limit(binomial(n, n/2) * 3 ** (-n), n, oo) == 0 + + +def test_issue_8634(): + n = Symbol('n', integer=True, positive=True) + x = Symbol('x') + assert limit(x**n, x, -oo) == oo*sign((-1)**n) + + +def test_issue_8635_18176(): + x = Symbol('x', real=True) + k = Symbol('k', positive=True) + assert limit(x**n - x**(n - 0), x, oo) == 0 + assert limit(x**n - x**(n - 5), x, oo) == oo + assert limit(x**n - x**(n - 2.5), x, oo) == oo + assert limit(x**n - x**(n - k - 1), x, oo) == oo + x = Symbol('x', positive=True) + assert limit(x**n - x**(n - 1), x, oo) == oo + assert limit(x**n - x**(n + 2), x, oo) == -oo + + +def test_issue_8730(): + assert limit(subfactorial(x), x, oo) is oo + + +def test_issue_9252(): + n = Symbol('n', integer=True) + c = Symbol('c', positive=True) + assert limit((log(n))**(n/log(n)) / (1 + c)**n, n, oo) == 0 + # limit should depend on the value of c + raises(NotImplementedError, lambda: limit((log(n))**(n/log(n)) / c**n, n, oo)) + + +def test_issue_9558(): + assert limit(sin(x)**15, x, 0, '-') == 0 + + +def test_issue_10801(): + # make sure limits work with binomial + assert limit(16**k / (k * binomial(2*k, k)**2), k, oo) == pi + + +def test_issue_10976(): + s, x = symbols('s x', real=True) + assert limit(erf(s*x)/erf(s), s, 0) == x + + +def test_issue_9041(): + assert limit(factorial(n) / ((n/exp(1))**n * sqrt(2*pi*n)), n, oo) == 1 + + +def test_issue_9205(): + x, y, a = symbols('x, y, a') + assert Limit(x, x, a).free_symbols == {a} + assert Limit(x, x, a, '-').free_symbols == {a} + assert Limit(x + y, x + y, a).free_symbols == {a} + assert Limit(-x**2 + y, x**2, a).free_symbols == {y, a} + + +def test_issue_9471(): + assert limit(((27**(log(n,3)))/n**3),n,oo) == 1 + assert limit(((27**(log(n,3)+1))/n**3),n,oo) == 27 + + +def test_issue_10382(): + assert limit(fibonacci(n + 1)/fibonacci(n), n, oo) == GoldenRatio + + +def test_issue_11496(): + assert limit(erfc(log(1/x)), x, oo) == 2 + + +def test_issue_11879(): + assert simplify(limit(((x+y)**n-x**n)/y, y, 0)) == n*x**(n-1) + + +def test_limit_with_Float(): + k = symbols("k") + assert limit(1.0 ** k, k, oo) == 1 + assert limit(0.3*1.0**k, k, oo) == Rational(3, 10) + + +def test_issue_10610(): + assert limit(3**x*3**(-x - 1)*(x + 1)**2/x**2, x, oo) == Rational(1, 3) + + +def test_issue_10868(): + assert limit(log(x) + asech(x), x, 0, '+') == log(2) + assert limit(log(x) + asech(x), x, 0, '-') == log(2) + 2*I*pi + raises(ValueError, lambda: limit(log(x) + asech(x), x, 0, '+-')) + assert limit(log(x) + asech(x), x, oo) == oo + assert limit(log(x) + acsch(x), x, 0, '+') == log(2) + assert limit(log(x) + acsch(x), x, 0, '-') == -oo + raises(ValueError, lambda: limit(log(x) + acsch(x), x, 0, '+-')) + assert limit(log(x) + acsch(x), x, oo) == oo + + +def test_issue_6599(): + assert limit((n + cos(n))/n, n, oo) == 1 + + +def test_issue_12555(): + assert limit((3**x + 2* x**10) / (x**10 + exp(x)), x, -oo) == 2 + assert limit((3**x + 2* x**10) / (x**10 + exp(x)), x, oo) is oo + + +def test_issue_12769(): + r, z, x = symbols('r z x', real=True) + a, b, s0, K, F0, s, T = symbols('a b s0 K F0 s T', positive=True, real=True) + fx = (F0**b*K**b*r*s0 - sqrt((F0**2*K**(2*b)*a**2*(b - 1) + \ + F0**(2*b)*K**2*a**2*(b - 1) + F0**(2*b)*K**(2*b)*s0**2*(b - 1)*(b**2 - 2*b + 1) - \ + 2*F0**(2*b)*K**(b + 1)*a*r*s0*(b**2 - 2*b + 1) + \ + 2*F0**(b + 1)*K**(2*b)*a*r*s0*(b**2 - 2*b + 1) - \ + 2*F0**(b + 1)*K**(b + 1)*a**2*(b - 1))/((b - 1)*(b**2 - 2*b + 1))))*(b*r - b - r + 1) + + assert fx.subs(K, F0).factor(deep=True) == limit(fx, K, F0).factor(deep=True) + + +def test_issue_13332(): + assert limit(sqrt(30)*5**(-5*x - 1)*(46656*x)**x*(5*x + 2)**(5*x + 5*S.Half) * + (6*x + 2)**(-6*x - 5*S.Half), x, oo) == Rational(25, 36) + + +def test_issue_12564(): + assert limit(x**2 + x*sin(x) + cos(x), x, -oo) is oo + assert limit(x**2 + x*sin(x) + cos(x), x, oo) is oo + assert limit(((x + cos(x))**2).expand(), x, oo) is oo + assert limit(((x + sin(x))**2).expand(), x, oo) is oo + assert limit(((x + cos(x))**2).expand(), x, -oo) is oo + assert limit(((x + sin(x))**2).expand(), x, -oo) is oo + + +def test_issue_14456(): + raises(NotImplementedError, lambda: Limit(exp(x), x, zoo).doit()) + raises(NotImplementedError, lambda: Limit(x**2/(x+1), x, zoo).doit()) + + +def test_issue_14411(): + assert limit(3*sec(4*pi*x - x/3), x, 3*pi/(24*pi - 2)) is -oo + + +def test_issue_13382(): + assert limit(x*(((x + 1)**2 + 1)/(x**2 + 1) - 1), x, oo) == 2 + + +def test_issue_13403(): + assert limit(x*(-1 + (x + log(x + 1) + 1)/(x + log(x))), x, oo) == 1 + + +def test_issue_13416(): + assert limit((-x**3*log(x)**3 + (x - 1)*(x + 1)**2*log(x + 1)**3)/(x**2*log(x)**3), x, oo) == 1 + + +def test_issue_13462(): + assert limit(n**2*(2*n*(-(1 - 1/(2*n))**x + 1) - x - (-x**2/4 + x/4)/n), n, oo) == x**3/24 - x**2/8 + x/12 + + +def test_issue_13750(): + a = Symbol('a') + assert limit(erf(a - x), x, oo) == -1 + assert limit(erf(sqrt(x) - x), x, oo) == -1 + + +def test_issue_14276(): + assert isinstance(limit(sin(x)**log(x), x, oo), Limit) + assert isinstance(limit(sin(x)**cos(x), x, oo), Limit) + assert isinstance(limit(sin(log(cos(x))), x, oo), Limit) + assert limit((1 + 1/(x**2 + cos(x)))**(x**2 + x), x, oo) == E + + +def test_issue_14514(): + assert limit((1/(log(x)**log(x)))**(1/x), x, oo) == 1 + + +def test_issues_14525(): + assert limit(sin(x)**2 - cos(x) + tan(x)*csc(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(sin(x)**2 - cos(x) + sin(x)*cot(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cot(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cos(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(sin(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(cos(x)**2 - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(tan(x)**2 + sin(x)**2 - cos(x), x, oo) == AccumBounds(-S.One, S.Infinity) + + +def test_issue_14574(): + assert limit(sqrt(x)*cos(x - x**2) / (x + 1), x, oo) == 0 + + +def test_issue_10102(): + assert limit(fresnels(x), x, oo) == S.Half + assert limit(3 + fresnels(x), x, oo) == 3 + S.Half + assert limit(5*fresnels(x), x, oo) == Rational(5, 2) + assert limit(fresnelc(x), x, oo) == S.Half + assert limit(fresnels(x), x, -oo) == Rational(-1, 2) + assert limit(4*fresnelc(x), x, -oo) == -2 + + +def test_issue_14377(): + raises(NotImplementedError, lambda: limit(exp(I*x)*sin(pi*x), x, oo)) + + +def test_issue_15146(): + e = (x/2) * (-2*x**3 - 2*(x**3 - 1) * x**2 * digamma(x**3 + 1) + \ + 2*(x**3 - 1) * x**2 * digamma(x**3 + x + 1) + x + 3) + assert limit(e, x, oo) == S(1)/3 + + +def test_issue_15202(): + e = (2**x*(2 + 2**(-x)*(-2*2**x + x + 2))/(x + 1))**(x + 1) + assert limit(e, x, oo) == exp(1) + + e = (log(x, 2)**7 + 10*x*factorial(x) + 5**x) / (factorial(x + 1) + 3*factorial(x) + 10**x) + assert limit(e, x, oo) == 10 + + +def test_issue_15282(): + assert limit((x**2000 - (x + 1)**2000) / x**1999, x, oo) == -2000 + + +def test_issue_15984(): + assert limit((-x + log(exp(x) + 1))/x, x, oo, dir='-') == 0 + + +def test_issue_13571(): + assert limit(uppergamma(x, 1) / gamma(x), x, oo) == 1 + + +def test_issue_13575(): + assert limit(acos(erfi(x)), x, 1) == acos(erfi(S.One)) + + +def test_issue_17325(): + assert Limit(sin(x)/x, x, 0, dir="+-").doit() == 1 + assert Limit(x**2, x, 0, dir="+-").doit() == 0 + assert Limit(1/x**2, x, 0, dir="+-").doit() is oo + assert Limit(1/x, x, 0, dir="+-").doit() is zoo + + +def test_issue_10978(): + assert LambertW(x).limit(x, 0) == 0 + + +def test_issue_14313_comment(): + assert limit(floor(n/2), n, oo) is oo + + +def test_issue_15323(): + d = ((1 - 1/x)**x).diff(x) + assert limit(d, x, 1, dir='+') == 1 + + +def test_issue_12571(): + assert limit(-LambertW(-log(x))/log(x), x, 1) == 1 + + +def test_issue_14590(): + assert limit((x**3*((x + 1)/x)**x)/((x + 1)*(x + 2)*(x + 3)), x, oo) == exp(1) + + +def test_issue_14393(): + a, b = symbols('a b') + assert limit((x**b - y**b)/(x**a - y**a), x, y) == b*y**(-a + b)/a + + +def test_issue_14556(): + assert limit(factorial(n + 1)**(1/(n + 1)) - factorial(n)**(1/n), n, oo) == exp(-1) + + +def test_issue_14811(): + assert limit(((1 + ((S(2)/3)**(x + 1)))**(2**x))/(2**((S(4)/3)**(x - 1))), x, oo) == oo + + +def test_issue_16222(): + assert limit(exp(x), x, 1000000000) == exp(1000000000) + + +def test_issue_16714(): + assert limit(((x**(x + 1) + (x + 1)**x) / x**(x + 1))**x, x, oo) == exp(exp(1)) + + +def test_issue_16722(): + z = symbols('z', positive=True) + assert limit(binomial(n + z, n)*n**-z, n, oo) == 1/gamma(z + 1) + z = symbols('z', positive=True, integer=True) + assert limit(binomial(n + z, n)*n**-z, n, oo) == 1/gamma(z + 1) + + +def test_issue_17431(): + assert limit(((n + 1) + 1) / (((n + 1) + 2) * factorial(n + 1)) * + (n + 2) * factorial(n) / (n + 1), n, oo) == 0 + assert limit((n + 2)**2*factorial(n)/((n + 1)*(n + 3)*factorial(n + 1)) + , n, oo) == 0 + assert limit((n + 1) * factorial(n) / (n * factorial(n + 1)), n, oo) == 0 + + +def test_issue_17671(): + assert limit(Ei(-log(x)) - log(log(x))/x, x, 1) == EulerGamma + + +def test_issue_17751(): + a, b, c, x = symbols('a b c x', positive=True) + assert limit((a + 1)*x - sqrt((a + 1)**2*x**2 + b*x + c), x, oo) == -b/(2*a + 2) + + +def test_issue_17792(): + assert limit(factorial(n)/sqrt(n)*(exp(1)/n)**n, n, oo) == sqrt(2)*sqrt(pi) + + +def test_issue_18118(): + assert limit(sign(sin(x)), x, 0, "-") == -1 + assert limit(sign(sin(x)), x, 0, "+") == 1 + + +def test_issue_18306(): + assert limit(sin(sqrt(x))/sqrt(sin(x)), x, 0, '+') == 1 + + +def test_issue_18378(): + assert limit(log(exp(3*x) + x)/log(exp(x) + x**100), x, oo) == 3 + + +def test_issue_18399(): + assert limit((1 - S(1)/2*x)**(3*x), x, oo) is zoo + assert limit((-x)**x, x, oo) is zoo + + +def test_issue_18442(): + assert limit(tan(x)**(2**(sqrt(pi))), x, oo, dir='-') == Limit(tan(x)**(2**(sqrt(pi))), x, oo, dir='-') + + +def test_issue_18452(): + assert limit(abs(log(x))**x, x, 0) == 1 + assert limit(abs(log(x))**x, x, 0, "-") == 1 + + +def test_issue_18473(): + assert limit(sin(x)**(1/x), x, oo) == Limit(sin(x)**(1/x), x, oo, dir='-') + assert limit(cos(x)**(1/x), x, oo) == Limit(cos(x)**(1/x), x, oo, dir='-') + assert limit(tan(x)**(1/x), x, oo) == Limit(tan(x)**(1/x), x, oo, dir='-') + assert limit((cos(x) + 2)**(1/x), x, oo) == 1 + assert limit((sin(x) + 10)**(1/x), x, oo) == 1 + assert limit((cos(x) - 2)**(1/x), x, oo) == Limit((cos(x) - 2)**(1/x), x, oo, dir='-') + assert limit((cos(x) + 1)**(1/x), x, oo) == AccumBounds(0, 1) + assert limit((tan(x)**2)**(2/x) , x, oo) == AccumBounds(0, oo) + assert limit((sin(x)**2)**(1/x), x, oo) == AccumBounds(0, 1) + # Tests for issue #23751 + assert limit((cos(x) + 1)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit((sin(x)**2)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit((tan(x)**2)**(2/x) , x, -oo) == AccumBounds(0, oo) + + +def test_issue_18482(): + assert limit((2*exp(3*x)/(exp(2*x) + 1))**(1/x), x, oo) == exp(1) + + +def test_issue_18508(): + assert limit(sin(x)/sqrt(1-cos(x)), x, 0) == sqrt(2) + assert limit(sin(x)/sqrt(1-cos(x)), x, 0, dir='+') == sqrt(2) + assert limit(sin(x)/sqrt(1-cos(x)), x, 0, dir='-') == -sqrt(2) + + +def test_issue_18521(): + raises(NotImplementedError, lambda: limit(exp((2 - n) * x), x, oo)) + + +def test_issue_18969(): + a, b = symbols('a b', positive=True) + assert limit(LambertW(a), a, b) == LambertW(b) + assert limit(exp(LambertW(a)), a, b) == exp(LambertW(b)) + + +def test_issue_18992(): + assert limit(n/(factorial(n)**(1/n)), n, oo) == exp(1) + + +def test_issue_19067(): + x = Symbol('x') + assert limit(gamma(x)/(gamma(x - 1)*gamma(x + 2)), x, 0) == -1 + + +def test_issue_19586(): + assert limit(x**(2**x*3**(-x)), x, oo) == 1 + + +def test_issue_13715(): + n = Symbol('n') + p = Symbol('p', zero=True) + assert limit(n + p, n, 0) == 0 + + +def test_issue_15055(): + assert limit(n**3*((-n - 1)*sin(1/n) + (n + 2)*sin(1/(n + 1)))/(-n + 1), n, oo) == 1 + + +def test_issue_16708(): + m, vi = symbols('m vi', positive=True) + B, ti, d = symbols('B ti d') + assert limit((B*ti*vi - sqrt(m)*sqrt(-2*B*d*vi + m*(vi)**2) + m*vi)/(B*vi), B, 0) == (d + ti*vi)/vi + + +def test_issue_19154(): + assert limit(besseli(1, 3 *x)/(x *besseli(1, x)**3), x , oo) == 2*sqrt(3)*pi/3 + assert limit(besseli(1, 3 *x)/(x *besseli(1, x)**3), x , -oo) == -2*sqrt(3)*pi/3 + + +def test_issue_19453(): + beta = Symbol("beta", positive=True) + h = Symbol("h", positive=True) + m = Symbol("m", positive=True) + w = Symbol("omega", positive=True) + g = Symbol("g", positive=True) + + e = exp(1) + q = 3*h**2*beta*g*e**(0.5*h*beta*w) + p = m**2*w**2 + s = e**(h*beta*w) - 1 + Z = -q/(4*p*s) - q/(2*p*s**2) - q*(e**(h*beta*w) + 1)/(2*p*s**3)\ + + e**(0.5*h*beta*w)/s + E = -diff(log(Z), beta) + + assert limit(E - 0.5*h*w, beta, oo) == 0 + assert limit(E.simplify() - 0.5*h*w, beta, oo) == 0 + + +def test_issue_19739(): + assert limit((-S(1)/4)**x, x, oo) == 0 + + +def test_issue_19766(): + assert limit(2**(-x)*sqrt(4**(x + 1) + 1), x, oo) == 2 + + +def test_issue_19770(): + m = Symbol('m') + # the result is not 0 for non-real m + assert limit(cos(m*x)/x, x, oo) == Limit(cos(m*x)/x, x, oo, dir='-') + m = Symbol('m', real=True) + # can be improved to give the correct result 0 + assert limit(cos(m*x)/x, x, oo) == Limit(cos(m*x)/x, x, oo, dir='-') + m = Symbol('m', nonzero=True) + assert limit(cos(m*x), x, oo) == AccumBounds(-1, 1) + assert limit(cos(m*x)/x, x, oo) == 0 + + +def test_issue_7535(): + assert limit(tan(x)/sin(tan(x)), x, pi/2) == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='+') + assert limit(tan(x)/sin(tan(x)), x, pi/2, dir='-') == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='-') + assert limit(tan(x)/sin(tan(x)), x, pi/2, dir='+-') == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='+-') + assert limit(sin(tan(x)),x,pi/2) == AccumBounds(-1, 1) + assert -oo*(1/sin(-oo)) == AccumBounds(-oo, oo) + assert oo*(1/sin(oo)) == AccumBounds(-oo, oo) + assert oo*(1/sin(-oo)) == AccumBounds(-oo, oo) + assert -oo*(1/sin(oo)) == AccumBounds(-oo, oo) + + +def test_issue_20365(): + assert limit(((x + 1)**(1/x) - E)/x, x, 0) == -E/2 + + +def test_issue_21031(): + assert limit(((1 + x)**(1/x) - (1 + 2*x)**(1/(2*x)))/asin(x), x, 0) == E/2 + + +def test_issue_21038(): + assert limit(sin(pi*x)/(3*x - 12), x, 4) == pi/3 + + +def test_issue_20578(): + expr = abs(x) * sin(1/x) + assert limit(expr,x,0,'+') == 0 + assert limit(expr,x,0,'-') == 0 + assert limit(expr,x,0,'+-') == 0 + + +def test_issue_21227(): + f = log(x) + + assert f.nseries(x, logx=y) == y + assert f.nseries(x, logx=-x) == -x + + f = log(-log(x)) + + assert f.nseries(x, logx=y) == log(-y) + assert f.nseries(x, logx=-x) == log(x) + + f = log(log(x)) + + assert f.nseries(x, logx=y) == log(y) + assert f.nseries(x, logx=-x) == log(-x) + assert f.nseries(x, logx=x) == log(x) + + f = log(log(log(1/x))) + + assert f.nseries(x, logx=y) == log(log(-y)) + assert f.nseries(x, logx=-y) == log(log(y)) + assert f.nseries(x, logx=x) == log(log(-x)) + assert f.nseries(x, logx=-x) == log(log(x)) + + +def test_issue_21415(): + exp = (x-1)*cos(1/(x-1)) + assert exp.limit(x,1) == 0 + assert exp.expand().limit(x,1) == 0 + + +def test_issue_21530(): + assert limit(sinh(n + 1)/sinh(n), n, oo) == E + + +def test_issue_21550(): + r = (sqrt(5) - 1)/2 + assert limit((x - r)/(x**2 + x - 1), x, r) == sqrt(5)/5 + + +def test_issue_21661(): + out = limit((x**(x + 1) * (log(x) + 1) + 1) / x, x, 11) + assert out == S(3138428376722)/11 + 285311670611*log(11) + + +def test_issue_21701(): + assert limit((besselj(z, x)/x**z).subs(z, 7), x, 0) == S(1)/645120 + + +def test_issue_21721(): + a = Symbol('a', real=True) + I = integrate(1/(pi*(1 + (x - a)**2)), x) + assert I.limit(x, oo) == S.Half + + +def test_issue_21756(): + term = (1 - exp(-2*I*pi*z))/(1 - exp(-2*I*pi*z/5)) + assert term.limit(z, 0) == 5 + assert re(term).limit(z, 0) == 5 + + +def test_issue_21785(): + a = Symbol('a') + assert sqrt((-a**2 + x**2)/(1 - x**2)).limit(a, 1, '-') == I + + +def test_issue_22181(): + assert limit((-1)**x * 2**(-x), x, oo) == 0 + + +def test_issue_22220(): + e1 = sqrt(30)*atan(sqrt(30)*tan(x/2)/6)/30 + e2 = sqrt(30)*I*(-log(sqrt(2)*tan(x/2) - 2*sqrt(15)*I/5) + + +log(sqrt(2)*tan(x/2) + 2*sqrt(15)*I/5))/60 + + assert limit(e1, x, -pi) == -sqrt(30)*pi/60 + assert limit(e2, x, -pi) == -sqrt(30)*pi/30 + + assert limit(e1, x, -pi, '-') == sqrt(30)*pi/60 + assert limit(e2, x, -pi, '-') == 0 + + # test https://github.com/sympy/sympy/issues/22220#issuecomment-972727694 + expr = log(x - I) - log(-x - I) + expr2 = logcombine(expr, force=True) + assert limit(expr, x, oo) == limit(expr2, x, oo) == I*pi + + # test https://github.com/sympy/sympy/issues/22220#issuecomment-1077618340 + expr = expr = (-log(tan(x/2) - I) +log(tan(x/2) + I)) + assert limit(expr, x, pi, '+') == 2*I*pi + assert limit(expr, x, pi, '-') == 0 + + +def test_issue_22334(): + k, n = symbols('k, n', positive=True) + assert limit((n+1)**k/((n+1)**(k+1) - (n)**(k+1)), n, oo) == 1/(k + 1) + assert limit((n+1)**k/((n+1)**(k+1) - (n)**(k+1)).expand(), n, oo) == 1/(k + 1) + assert limit((n+1)**k/(n*(-n**k + (n + 1)**k) + (n + 1)**k), n, oo) == 1/(k + 1) + + +def test_issue_22836_limit(): + assert limit(2**(1/x)/factorial(1/(x)), x, 0) == S.Zero + + +def test_sympyissue_22986(): + assert limit(acosh(1 + 1/x)*sqrt(x), x, oo) == sqrt(2) + + +def test_issue_23231(): + f = (2**x - 2**(-x))/(2**x + 2**(-x)) + assert limit(f, x, -oo) == -1 + + +def test_issue_23596(): + assert integrate(((1 + x)/x**2)*exp(-1/x), (x, 0, oo)) == oo + + +def test_issue_23752(): + expr1 = sqrt(-I*x**2 + x - 3) + expr2 = sqrt(-I*x**2 + I*x - 3) + assert limit(expr1, x, 0, '+') == -sqrt(3)*I + assert limit(expr1, x, 0, '-') == -sqrt(3)*I + assert limit(expr2, x, 0, '+') == sqrt(3)*I + assert limit(expr2, x, 0, '-') == -sqrt(3)*I + + +def test_issue_24276(): + fx = log(tan(pi/2*tanh(x))).diff(x) + assert fx.limit(x, oo) == 2 + assert fx.simplify().limit(x, oo) == 2 + assert fx.rewrite(sin).limit(x, oo) == 2 + assert fx.rewrite(sin).simplify().limit(x, oo) == 2 + +def test_issue_25230(): + a = Symbol('a', real = True) + b = Symbol('b', positive = True) + c = Symbol('c', negative = True) + n = Symbol('n', integer = True) + raises(NotImplementedError, lambda: limit(Mod(x, a), x, a)) + assert limit(Mod(x, b), x, n*b, '+') == 0 + assert limit(Mod(x, b), x, n*b, '-') == b + assert limit(Mod(x, c), x, n*c, '+') == c + assert limit(Mod(x, c), x, n*c, '-') == 0 + + +def test_issue_25582(): + + assert limit(asin(exp(x)), x, oo, '-') == -oo*I + assert limit(acos(exp(x)), x, oo, '-') == oo*I + assert limit(atan(exp(x)), x, oo, '-') == pi/2 + assert limit(acot(exp(x)), x, oo, '-') == 0 + assert limit(asec(exp(x)), x, oo, '-') == pi/2 + assert limit(acsc(exp(x)), x, oo, '-') == 0 + + +def test_issue_25847(): + #atan + assert limit(atan(sin(x)/x), x, 0, '+-') == pi/4 + assert limit(atan(exp(1/x)), x, 0, '+') == pi/2 + assert limit(atan(exp(1/x)), x, 0, '-') == 0 + + #asin + assert limit(asin(sin(x)/x), x, 0, '+-') == pi/2 + assert limit(asin(exp(1/x)), x, 0, '+') == -oo*I + assert limit(asin(exp(1/x)), x, 0, '-') == 0 + + #acos + assert limit(acos(sin(x)/x), x, 0, '+-') == 0 + assert limit(acos(exp(1/x)), x, 0, '+') == oo*I + assert limit(acos(exp(1/x)), x, 0, '-') == pi/2 + + #acot + assert limit(acot(sin(x)/x), x, 0, '+-') == pi/4 + assert limit(acot(exp(1/x)), x, 0, '+') == 0 + assert limit(acot(exp(1/x)), x, 0, '-') == pi/2 + + #asec + assert limit(asec(sin(x)/x), x, 0, '+-') == 0 + assert limit(asec(exp(1/x)), x, 0, '+') == pi/2 + assert limit(asec(exp(1/x)), x, 0, '-') == oo*I + + #acsc + assert limit(acsc(sin(x)/x), x, 0, '+-') == pi/2 + assert limit(acsc(exp(1/x)), x, 0, '+') == 0 + assert limit(acsc(exp(1/x)), x, 0, '-') == -oo*I + + #atanh + assert limit(atanh(sin(x)/x), x, 0, '+-') == oo + assert limit(atanh(exp(1/x)), x, 0, '+') == -I*pi/2 + assert limit(atanh(exp(1/x)), x, 0, '-') == 0 + + #asinh + assert limit(asinh(sin(x)/x), x, 0, '+-') == log(1 + sqrt(2)) + assert limit(asinh(exp(1/x)), x, 0, '+') == oo + assert limit(asinh(exp(1/x)), x, 0, '-') == 0 + + #acosh + assert limit(acosh(sin(x)/x), x, 0, '+-') == 0 + assert limit(acosh(exp(1/x)), x, 0, '+') == oo + assert limit(acosh(exp(1/x)), x, 0, '-') == I*pi/2 + + #acoth + assert limit(acoth(sin(x)/x), x, 0, '+-') == oo + assert limit(acoth(exp(1/x)), x, 0, '+') == 0 + assert limit(acoth(exp(1/x)), x, 0, '-') == -I*pi/2 + + #asech + assert limit(asech(sin(x)/x), x, 0, '+-') == 0 + assert limit(asech(exp(1/x)), x, 0, '+') == I*pi/2 + assert limit(asech(exp(1/x)), x, 0, '-') == oo + + #acsch + assert limit(acsch(sin(x)/x), x, 0, '+-') == log(1 + sqrt(2)) + assert limit(acsch(exp(1/x)), x, 0, '+') == 0 + assert limit(acsch(exp(1/x)), x, 0, '-') == oo + + +def test_issue_26040(): + assert limit(besseli(0, x + 1)/besseli(0, x), x, oo) == S.Exp1 + + +def test_issue_26250(): + e = elliptic_e(4*x/(x**2 + 2*x + 1)) + k = elliptic_k(4*x/(x**2 + 2*x + 1)) + e1 = ((1-3*x**2)*e**2/2 - (x**2-2*x+1)*e*k/2) + e2 = pi**2*(x**8 - 2*x**7 - x**6 + 4*x**5 - x**4 - 2*x**3 + x**2) + assert limit(e1/e2, x, 0) == -S(1)/8 + + +def test_issue_26513(): + assert limit(abs((-x/(x+1))**x), x ,oo) == exp(-1) + assert limit((x/(x + 1))**x, x, oo) == exp(-1) + raises (NotImplementedError, lambda: limit((-x/(x+1))**x, x, oo)) + + +def test_issue_26916(): + assert limit(Ei(x)*exp(-x), x, +oo) == 0 + assert limit(Ei(x)*exp(-x), x, -oo) == 0 + + +def test_issue_22982_15323(): + assert limit((log(E + 1/x) - 1)**(1 - sqrt(E + 1/x)), x, oo) == oo + assert limit((1 - 1/x)**x*(log(1 - 1/x) + 1/(x*(1 - 1/x))), x, 1, dir='+') == 1 + assert limit((log(E + 1/x) )**(1 - sqrt(E + 1/x)), x, oo) == 1 + assert limit((log(E + 1/x) - 1)**(- sqrt(E + 1/x)), x, oo) == oo + + +def test_issue_26991(): + assert limit(x/((x - 6)*sinh(tanh(0.03*x)) + tanh(x) - 0.5), x, oo) == 1/sinh(1) + +def test_issue_27278(): + expr = (1/(x*log((x + 3)/x)))**x*((x + 1)*log((x + 4)/(x + 1)))**(x + 1)/3 + assert limit(expr, x, oo) == 1 diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_limitseq.py b/phivenv/Lib/site-packages/sympy/series/tests/test_limitseq.py new file mode 100644 index 0000000000000000000000000000000000000000..362bb0397feb0ec63929920855c81279eca0bd6a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_limitseq.py @@ -0,0 +1,177 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (binomial, factorial, subfactorial) +from sympy.functions.combinatorial.numbers import (fibonacci, harmonic) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.series.limitseq import limit_seq +from sympy.series.limitseq import difference_delta as dd +from sympy.testing.pytest import raises, XFAIL +from sympy.calculus.accumulationbounds import AccumulationBounds + +n, m, k = symbols('n m k', integer=True) + + +def test_difference_delta(): + e = n*(n + 1) + e2 = e * k + + assert dd(e) == 2*n + 2 + assert dd(e2, n, 2) == k*(4*n + 6) + + raises(ValueError, lambda: dd(e2)) + raises(ValueError, lambda: dd(e2, n, oo)) + + +def test_difference_delta__Sum(): + e = Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1/(n + 1) + assert dd(e, n, 5) == Add(*[1/(i + n + 1) for i in range(5)]) + + e = Sum(1/k, (k, 1, 3*n)) + assert dd(e, n) == Add(*[1/(i + 3*n + 1) for i in range(3)]) + + e = n * Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1 + Sum(1/k, (k, 1, n)) + + e = Sum(1/k, (k, 1, n), (m, 1, n)) + assert dd(e, n) == harmonic(n) + + +def test_difference_delta__Add(): + e = n + n*(n + 1) + assert dd(e, n) == 2*n + 3 + assert dd(e, n, 2) == 4*n + 8 + + e = n + Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1 + 1/(n + 1) + assert dd(e, n, 5) == 5 + Add(*[1/(i + n + 1) for i in range(5)]) + + +def test_difference_delta__Pow(): + e = 4**n + assert dd(e, n) == 3*4**n + assert dd(e, n, 2) == 15*4**n + + e = 4**(2*n) + assert dd(e, n) == 15*4**(2*n) + assert dd(e, n, 2) == 255*4**(2*n) + + e = n**4 + assert dd(e, n) == (n + 1)**4 - n**4 + + e = n**n + assert dd(e, n) == (n + 1)**(n + 1) - n**n + + +def test_limit_seq(): + e = binomial(2*n, n) / Sum(binomial(2*k, k), (k, 1, n)) + assert limit_seq(e) == S(3) / 4 + assert limit_seq(e, m) == e + + e = (5*n**3 + 3*n**2 + 4) / (3*n**3 + 4*n - 5) + assert limit_seq(e, n) == S(5) / 3 + + e = (harmonic(n) * Sum(harmonic(k), (k, 1, n))) / (n * harmonic(2*n)**2) + assert limit_seq(e, n) == 1 + + e = Sum(k**2 * Sum(2**m/m, (m, 1, k)), (k, 1, n)) / (2**n*n) + assert limit_seq(e, n) == 4 + + e = (Sum(binomial(3*k, k) * binomial(5*k, k), (k, 1, n)) / + (binomial(3*n, n) * binomial(5*n, n))) + assert limit_seq(e, n) == S(84375) / 83351 + + e = Sum(harmonic(k)**2/k, (k, 1, 2*n)) / harmonic(n)**3 + assert limit_seq(e, n) == S.One / 3 + + raises(ValueError, lambda: limit_seq(e * m)) + + +def test_alternating_sign(): + assert limit_seq((-1)**n/n**2, n) == 0 + assert limit_seq((-2)**(n+1)/(n + 3**n), n) == 0 + assert limit_seq((2*n + (-1)**n)/(n + 1), n) == 2 + assert limit_seq(sin(pi*n), n) == 0 + assert limit_seq(cos(2*pi*n), n) == 1 + assert limit_seq((S.NegativeOne/5)**n, n) == 0 + assert limit_seq((Rational(-1, 5))**n, n) == 0 + assert limit_seq((I/3)**n, n) == 0 + assert limit_seq(sqrt(n)*(I/2)**n, n) == 0 + assert limit_seq(n**7*(I/3)**n, n) == 0 + assert limit_seq(n/(n + 1) + (I/2)**n, n) == 1 + + +def test_accum_bounds(): + assert limit_seq((-1)**n, n) == AccumulationBounds(-1, 1) + assert limit_seq(cos(pi*n), n) == AccumulationBounds(-1, 1) + assert limit_seq(sin(pi*n/2)**2, n) == AccumulationBounds(0, 1) + assert limit_seq(2*(-3)**n/(n + 3**n), n) == AccumulationBounds(-2, 2) + assert limit_seq(3*n/(n + 1) + 2*(-1)**n, n) == AccumulationBounds(1, 5) + + +def test_limitseq_sum(): + from sympy.abc import x, y, z + assert limit_seq(Sum(1/x, (x, 1, y)) - log(y), y) == S.EulerGamma + assert limit_seq(Sum(1/x, (x, 1, y)) - 1/y, y) is S.Infinity + assert (limit_seq(binomial(2*x, x) / Sum(binomial(2*y, y), (y, 1, x)), x) == + S(3) / 4) + assert (limit_seq(Sum(y**2 * Sum(2**z/z, (z, 1, y)), (y, 1, x)) / + (2**x*x), x) == 4) + + +def test_issue_9308(): + assert limit_seq(subfactorial(n)/factorial(n), n) == exp(-1) + + +def test_issue_10382(): + n = Symbol('n', integer=True) + assert limit_seq(fibonacci(n+1)/fibonacci(n), n).together() == S.GoldenRatio + + +def test_issue_11672(): + assert limit_seq(Rational(-1, 2)**n, n) == 0 + + +def test_issue_14196(): + k, n = symbols('k, n', positive=True) + m = Symbol('m') + assert limit_seq(Sum(m**k, (m, 1, n)).doit()/(n**(k + 1)), n) == 1/(k + 1) + + +def test_issue_16735(): + assert limit_seq(5**n/factorial(n), n) == 0 + + +def test_issue_19868(): + assert limit_seq(1/gamma(n + S.One/2), n) == 0 + + +@XFAIL +def test_limit_seq_fail(): + # improve Summation algorithm or add ad-hoc criteria + e = (harmonic(n)**3 * Sum(1/harmonic(k), (k, 1, n)) / + (n * Sum(harmonic(k)/k, (k, 1, n)))) + assert limit_seq(e, n) == 2 + + # No unique dominant term + e = (Sum(2**k * binomial(2*k, k) / k**2, (k, 1, n)) / + (Sum(2**k/k*2, (k, 1, n)) * Sum(binomial(2*k, k), (k, 1, n)))) + assert limit_seq(e, n) == S(3) / 7 + + # Simplifications of summations needs to be improved. + e = n**3*Sum(2**k/k**2, (k, 1, n))**2 / (2**n * Sum(2**k/k, (k, 1, n))) + assert limit_seq(e, n) == 2 + + e = (harmonic(n) * Sum(2**k/k, (k, 1, n)) / + (n * Sum(2**k*harmonic(k)/k**2, (k, 1, n)))) + assert limit_seq(e, n) == 1 + + e = (Sum(2**k*factorial(k) / k**2, (k, 1, 2*n)) / + (Sum(4**k/k**2, (k, 1, n)) * Sum(factorial(k), (k, 1, 2*n)))) + assert limit_seq(e, n) == S(3) / 16 diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_lseries.py b/phivenv/Lib/site-packages/sympy/series/tests/test_lseries.py new file mode 100644 index 0000000000000000000000000000000000000000..42d327bf60c76eebdc4570d631efef4bc84b58e3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_lseries.py @@ -0,0 +1,65 @@ +from sympy.core.numbers import E +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.order import Order +from sympy.abc import x, y + + +def test_sin(): + e = sin(x).lseries(x) + assert next(e) == x + assert next(e) == -x**3/6 + assert next(e) == x**5/120 + + +def test_cos(): + e = cos(x).lseries(x) + assert next(e) == 1 + assert next(e) == -x**2/2 + assert next(e) == x**4/24 + + +def test_exp(): + e = exp(x).lseries(x) + assert next(e) == 1 + assert next(e) == x + assert next(e) == x**2/2 + assert next(e) == x**3/6 + + +def test_exp2(): + e = exp(cos(x)).lseries(x) + assert next(e) == E + assert next(e) == -E*x**2/2 + assert next(e) == E*x**4/6 + assert next(e) == -31*E*x**6/720 + + +def test_simple(): + assert list(x.lseries()) == [x] + assert list(S.One.lseries(x)) == [1] + assert not next((x/(x + y)).lseries(y)).has(Order) + + +def test_issue_5183(): + s = (x + 1/x).lseries() + assert list(s) == [1/x, x] + assert next((x + x**2).lseries()) == x + assert next(((1 + x)**7).lseries(x)) == 1 + assert next((sin(x + y)).series(x, n=3).lseries(y)) == x + # it would be nice if all terms were grouped, but in the + # following case that would mean that all the terms would have + # to be known since, for example, every term has a constant in it. + s = ((1 + x)**7).series(x, 1, n=None) + assert [next(s) for i in range(2)] == [128, -448 + 448*x] + + +def test_issue_6999(): + s = tanh(x).lseries(x, 1) + assert next(s) == tanh(1) + assert next(s) == x - (x - 1)*tanh(1)**2 - 1 + assert next(s) == -(x - 1)**2*tanh(1) + (x - 1)**2*tanh(1)**3 + assert next(s) == -(x - 1)**3*tanh(1)**4 - (x - 1)**3/3 + \ + 4*(x - 1)**3*tanh(1)**2/3 diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_nseries.py b/phivenv/Lib/site-packages/sympy/series/tests/test_nseries.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f20add82d3e858e2ce145fc9fcd4a6548a48cc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_nseries.py @@ -0,0 +1,557 @@ +from sympy.calculus.util import AccumBounds +from sympy.core.function import (Derivative, PoleError) +from sympy.core.numbers import (E, I, Integer, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, acoth, asinh, atanh, cosh, coth, sinh, tanh) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, cot, sin, tan) +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.abc import x, y, z + +from sympy.testing.pytest import raises, XFAIL + + +def test_simple_1(): + assert x.nseries(x, n=5) == x + assert y.nseries(x, n=5) == y + assert (1/(x*y)).nseries(y, n=5) == 1/(x*y) + assert Rational(3, 4).nseries(x, n=5) == Rational(3, 4) + assert x.nseries() == x + + +def test_mul_0(): + assert (x*log(x)).nseries(x, n=5) == x*log(x) + + +def test_mul_1(): + assert (x*log(2 + x)).nseries(x, n=5) == x*log(2) + x**2/2 - x**3/8 + \ + x**4/24 + O(x**5) + assert (x*log(1 + x)).nseries( + x, n=5) == x**2 - x**3/2 + x**4/3 + O(x**5) + + +def test_pow_0(): + assert (x**2).nseries(x, n=5) == x**2 + assert (1/x).nseries(x, n=5) == 1/x + assert (1/x**2).nseries(x, n=5) == 1/x**2 + assert (x**Rational(2, 3)).nseries(x, n=5) == (x**Rational(2, 3)) + assert (sqrt(x)**3).nseries(x, n=5) == (sqrt(x)**3) + + +def test_pow_1(): + assert ((1 + x)**2).nseries(x, n=5) == x**2 + 2*x + 1 + + # https://github.com/sympy/sympy/issues/21075 + assert ((sqrt(x) + 1)**2).nseries(x) == 2*sqrt(x) + x + 1 + assert ((sqrt(x) + cbrt(x))**2).nseries(x) == 2*x**Rational(5, 6)\ + + x**Rational(2, 3) + x + + +def test_geometric_1(): + assert (1/(1 - x)).nseries(x, n=5) == 1 + x + x**2 + x**3 + x**4 + O(x**5) + assert (x/(1 - x)).nseries(x, n=6) == x + x**2 + x**3 + x**4 + x**5 + O(x**6) + assert (x**3/(1 - x)).nseries(x, n=8) == x**3 + x**4 + x**5 + x**6 + \ + x**7 + O(x**8) + + +def test_sqrt_1(): + assert sqrt(1 + x).nseries(x, n=5) == 1 + x/2 - x**2/8 + x**3/16 - 5*x**4/128 + O(x**5) + + +def test_exp_1(): + assert exp(x).nseries(x, n=5) == 1 + x + x**2/2 + x**3/6 + x**4/24 + O(x**5) + assert exp(x).nseries(x, n=12) == 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + \ + x**6/720 + x**7/5040 + x**8/40320 + x**9/362880 + x**10/3628800 + \ + x**11/39916800 + O(x**12) + assert exp(1/x).nseries(x, n=5) == exp(1/x) + assert exp(1/(1 + x)).nseries(x, n=4) == \ + (E*(1 - x - 13*x**3/6 + 3*x**2/2)).expand() + O(x**4) + assert exp(2 + x).nseries(x, n=5) == \ + (exp(2)*(1 + x + x**2/2 + x**3/6 + x**4/24)).expand() + O(x**5) + + +def test_exp_sqrt_1(): + assert exp(1 + sqrt(x)).nseries(x, n=3) == \ + (exp(1)*(1 + sqrt(x) + x/2 + sqrt(x)*x/6)).expand() + O(sqrt(x)**3) + + +def test_power_x_x1(): + assert (exp(x*log(x))).nseries(x, n=4) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + x**3*log(x)**3/6 + O(x**4*log(x)**4) + + +def test_power_x_x2(): + assert (x**x).nseries(x, n=4) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + x**3*log(x)**3/6 + O(x**4*log(x)**4) + + +def test_log_singular1(): + assert log(1 + 1/x).nseries(x, n=5) == x - log(x) - x**2/2 + x**3/3 - \ + x**4/4 + O(x**5) + + +def test_log_power1(): + e = 1 / (1/x + x ** (log(3)/log(2))) + assert e.nseries(x, n=5) == -x**(log(3)/log(2) + 2) + x + O(x**5) + + +def test_log_series(): + l = Symbol('l') + e = 1/(1 - log(x)) + assert e.nseries(x, n=5, logx=l) == 1/(1 - l) + + +def test_log2(): + e = log(-1/x) + assert e.nseries(x, n=5) == -log(x) + log(-1) + + +def test_log3(): + l = Symbol('l') + e = 1/log(-1/x) + assert e.nseries(x, n=4, logx=l) == 1/(-l + log(-1)) + + +def test_series1(): + e = sin(x) + assert e.nseries(x, 0, 0) != 0 + assert e.nseries(x, 0, 0) == O(1, x) + assert e.nseries(x, 0, 1) == O(x, x) + assert e.nseries(x, 0, 2) == x + O(x**2, x) + assert e.nseries(x, 0, 3) == x + O(x**3, x) + assert e.nseries(x, 0, 4) == x - x**3/6 + O(x**4, x) + + e = (exp(x) - 1)/x + assert e.nseries(x, 0, 3) == 1 + x/2 + x**2/6 + O(x**3) + + assert x.nseries(x, 0, 2) == x + + +@XFAIL +def test_series1_failing(): + assert x.nseries(x, 0, 0) == O(1, x) + assert x.nseries(x, 0, 1) == O(x, x) + + +def test_seriesbug1(): + assert (1/x).nseries(x, 0, 3) == 1/x + assert (x + 1/x).nseries(x, 0, 3) == x + 1/x + + +def test_series2x(): + assert ((x + 1)**(-2)).nseries(x, 0, 4) == 1 - 2*x + 3*x**2 - 4*x**3 + O(x**4, x) + assert ((x + 1)**(-1)).nseries(x, 0, 4) == 1 - x + x**2 - x**3 + O(x**4, x) + assert ((x + 1)**0).nseries(x, 0, 3) == 1 + assert ((x + 1)**1).nseries(x, 0, 3) == 1 + x + assert ((x + 1)**2).nseries(x, 0, 3) == x**2 + 2*x + 1 + assert ((x + 1)**3).nseries(x, 0, 3) == 1 + 3*x + 3*x**2 + O(x**3) + + assert (1/(1 + x)).nseries(x, 0, 4) == 1 - x + x**2 - x**3 + O(x**4, x) + assert (x + 3/(1 + 2*x)).nseries(x, 0, 4) == 3 - 5*x + 12*x**2 - 24*x**3 + O(x**4, x) + + assert ((1/x + 1)**3).nseries(x, 0, 3) == 1 + 3/x + 3/x**2 + x**(-3) + assert (1/(1 + 1/x)).nseries(x, 0, 4) == x - x**2 + x**3 - O(x**4, x) + assert (1/(1 + 1/x**2)).nseries(x, 0, 6) == x**2 - x**4 + O(x**6, x) + + +def test_bug2(): # 1/log(0)*log(0) problem + w = Symbol("w") + e = (w**(-1) + w**( + -log(3)*log(2)**(-1)))**(-1)*(3*w**(-log(3)*log(2)**(-1)) + 2*w**(-1)) + e = e.expand() + assert e.nseries(w, 0, 4).subs(w, 0) == 3 + + +def test_exp(): + e = (1 + x)**(1/x) + assert e.nseries(x, n=3) == exp(1) - x*exp(1)/2 + 11*exp(1)*x**2/24 + O(x**3) + + +def test_exp2(): + w = Symbol("w") + e = w**(1 - log(x)/(log(2) + log(x))) + logw = Symbol("logw") + assert e.nseries( + w, 0, 1, logx=logw) == exp(logw*log(2)/(log(x) + log(2))) + + +def test_bug3(): + e = (2/x + 3/x**2)/(1/x + 1/x**2) + assert e.nseries(x, n=3) == 3 - x + x**2 + O(x**3) + + +def test_generalexponent(): + p = 2 + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 3) == 3 - x + x**2 + O(x**3) + p = S.Half + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 2) == 2 - x + sqrt(x) + x**(S(3)/2) + O(x**2) + + e = 1 + sqrt(x) + assert e.nseries(x, 0, 4) == 1 + sqrt(x) + +# more complicated example + + +def test_genexp_x(): + e = 1/(1 + sqrt(x)) + assert e.nseries(x, 0, 2) == \ + 1 + x - sqrt(x) - sqrt(x)**3 + O(x**2, x) + +# more complicated example + + +def test_genexp_x2(): + p = Rational(3, 2) + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 3) == 3 + x + x**2 - sqrt(x) - x**(S(3)/2) - x**(S(5)/2) + O(x**3) + + +def test_seriesbug2(): + w = Symbol("w") + #simple case (1): + e = ((2*w)/w)**(1 + w) + assert e.nseries(w, 0, 1) == 2 + O(w, w) + assert e.nseries(w, 0, 1).subs(w, 0) == 2 + + +def test_seriesbug2b(): + w = Symbol("w") + #test sin + e = sin(2*w)/w + assert e.nseries(w, 0, 3) == 2 - 4*w**2/3 + O(w**3) + + +def test_seriesbug2d(): + w = Symbol("w", real=True) + e = log(sin(2*w)/w) + assert e.series(w, n=5) == log(2) - 2*w**2/3 - 4*w**4/45 + O(w**5) + + +def test_seriesbug2c(): + w = Symbol("w", real=True) + #more complicated case, but sin(x)~x, so the result is the same as in (1) + e = (sin(2*w)/w)**(1 + w) + assert e.series(w, 0, 1) == 2 + O(w) + assert e.series(w, 0, 3) == 2 + 2*w*log(2) + \ + w**2*(Rational(-4, 3) + log(2)**2) + O(w**3) + assert e.series(w, 0, 2).subs(w, 0) == 2 + + +def test_expbug4(): + x = Symbol("x", real=True) + assert (log( + sin(2*x)/x)*(1 + x)).series(x, 0, 2) == log(2) + x*log(2) + O(x**2, x) + assert exp( + log(sin(2*x)/x)*(1 + x)).series(x, 0, 2) == 2 + 2*x*log(2) + O(x**2) + + assert exp(log(2) + O(x)).nseries(x, 0, 2) == 2 + O(x) + assert ((2 + O(x))**(1 + x)).nseries(x, 0, 2) == 2 + O(x) + + +def test_logbug4(): + assert log(2 + O(x)).nseries(x, 0, 2) == log(2) + O(x, x) + + +def test_expbug5(): + assert exp(log(1 + x)/x).nseries(x, n=3) == exp(1) + -exp(1)*x/2 + 11*exp(1)*x**2/24 + O(x**3) + + assert exp(O(x)).nseries(x, 0, 2) == 1 + O(x) + + +def test_sinsinbug(): + assert sin(sin(x)).nseries(x, 0, 8) == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + + +def test_issue_3258(): + a = x/(exp(x) - 1) + assert a.nseries(x, 0, 5) == 1 - x/2 - x**4/720 + x**2/12 + O(x**5) + + +def test_issue_3204(): + x = Symbol("x", nonnegative=True) + f = sin(x**3)**Rational(1, 3) + assert f.nseries(x, 0, 17) == x - x**7/18 - x**13/3240 + O(x**17) + + +def test_issue_3224(): + f = sqrt(1 - sqrt(y)) + assert f.nseries(y, 0, 2) == 1 - sqrt(y)/2 - y/8 - sqrt(y)**3/16 + O(y**2) + + +def test_issue_3463(): + w, i = symbols('w,i') + r = log(5)/log(3) + p = w**(-1 + r) + e = 1/x*(-log(w**(1 + r)) + log(w + w**r)) + e_ser = -r*log(w)/x + p/x - p**2/(2*x) + O(w) + assert e.nseries(w, n=1) == e_ser + + +def test_sin(): + assert sin(8*x).nseries(x, n=4) == 8*x - 256*x**3/3 + O(x**4) + assert sin(x + y).nseries(x, n=1) == sin(y) + O(x) + assert sin(x + y).nseries(x, n=2) == sin(y) + cos(y)*x + O(x**2) + assert sin(x + y).nseries(x, n=5) == sin(y) + cos(y)*x - sin(y)*x**2/2 - \ + cos(y)*x**3/6 + sin(y)*x**4/24 + O(x**5) + + +def test_issue_3515(): + e = sin(8*x)/x + assert e.nseries(x, n=6) == 8 - 256*x**2/3 + 4096*x**4/15 + O(x**6) + + +def test_issue_3505(): + e = sin(x)**(-4)*(sqrt(cos(x))*sin(x)**2 - + cos(x)**Rational(1, 3)*sin(x)**2) + assert e.nseries(x, n=9) == Rational(-1, 12) - 7*x**2/288 - \ + 43*x**4/10368 - 1123*x**6/2488320 + 377*x**8/29859840 + O(x**9) + + +def test_issue_3501(): + a = Symbol("a") + e = x**(-2)*(x*sin(a + x) - x*sin(a)) + assert e.nseries(x, n=6) == cos(a) - sin(a)*x/2 - cos(a)*x**2/6 + \ + x**3*sin(a)/24 + x**4*cos(a)/120 - x**5*sin(a)/720 + O(x**6) + e = x**(-2)*(x*cos(a + x) - x*cos(a)) + assert e.nseries(x, n=6) == -sin(a) - cos(a)*x/2 + sin(a)*x**2/6 + \ + cos(a)*x**3/24 - x**4*sin(a)/120 - x**5*cos(a)/720 + O(x**6) + + +def test_issue_3502(): + e = sin(5*x)/sin(2*x) + assert e.nseries(x, n=2) == Rational(5, 2) + O(x**2) + assert e.nseries(x, n=6) == \ + Rational(5, 2) - 35*x**2/4 + 329*x**4/48 + O(x**6) + + +def test_issue_3503(): + e = sin(2 + x)/(2 + x) + assert e.nseries(x, n=2) == sin(2)/2 + x*cos(2)/2 - x*sin(2)/4 + O(x**2) + + +def test_issue_3506(): + e = (x + sin(3*x))**(-2)*(x*(x + sin(3*x)) - (x + sin(3*x))*sin(2*x)) + assert e.nseries(x, n=7) == \ + Rational(-1, 4) + 5*x**2/96 + 91*x**4/768 + 11117*x**6/129024 + O(x**7) + + +def test_issue_3508(): + x = Symbol("x", real=True) + assert log(sin(x)).series(x, n=5) == log(x) - x**2/6 - x**4/180 + O(x**5) + e = -log(x) + x*(-log(x) + log(sin(2*x))) + log(sin(2*x)) + assert e.series(x, n=5) == \ + log(2) + log(2)*x - 2*x**2/3 - 2*x**3/3 - 4*x**4/45 + O(x**5) + + +def test_issue_3507(): + e = x**(-4)*(x**2 - x**2*sqrt(cos(x))) + assert e.nseries(x, n=9) == \ + Rational(1, 4) + x**2/96 + 19*x**4/5760 + 559*x**6/645120 + 29161*x**8/116121600 + O(x**9) + + +def test_issue_3639(): + assert sin(cos(x)).nseries(x, n=5) == \ + sin(1) - x**2*cos(1)/2 - x**4*sin(1)/8 + x**4*cos(1)/24 + O(x**5) + + +def test_hyperbolic(): + assert sinh(x).nseries(x, n=6) == x + x**3/6 + x**5/120 + O(x**6) + assert cosh(x).nseries(x, n=5) == 1 + x**2/2 + x**4/24 + O(x**5) + assert tanh(x).nseries(x, n=6) == x - x**3/3 + 2*x**5/15 + O(x**6) + assert coth(x).nseries(x, n=6) == \ + 1/x - x**3/45 + x/3 + 2*x**5/945 + O(x**6) + assert asinh(x).nseries(x, n=6) == x - x**3/6 + 3*x**5/40 + O(x**6) + assert acosh(x).nseries(x, n=6) == \ + pi*I/2 - I*x - 3*I*x**5/40 - I*x**3/6 + O(x**6) + assert atanh(x).nseries(x, n=6) == x + x**3/3 + x**5/5 + O(x**6) + assert acoth(x).nseries(x, n=6) == -I*pi/2 + x + x**3/3 + x**5/5 + O(x**6) + + +def test_series2(): + w = Symbol("w", real=True) + x = Symbol("x", real=True) + e = w**(-2)*(w*exp(1/x - w) - w*exp(1/x)) + assert e.nseries(w, n=4) == -exp(1/x) + w*exp(1/x)/2 - w**2*exp(1/x)/6 + w**3*exp(1/x)/24 + O(w**4) + + +def test_series3(): + w = Symbol("w", real=True) + e = w**(-6)*(w**3*tan(w) - w**3*sin(w)) + assert e.nseries(w, n=8) == Integer(1)/2 + w**2/8 + 13*w**4/240 + 529*w**6/24192 + O(w**8) + + +def test_bug4(): + w = Symbol("w") + e = x/(w**4 + x**2*w**4 + 2*x*w**4)*w**4 + assert e.nseries(w, n=2).removeO().expand() in [x/(1 + 2*x + x**2), + 1/(1 + x/2 + 1/x/2)/2, 1/x/(1 + 2/x + x**(-2))] + + +def test_bug5(): + w = Symbol("w") + l = Symbol('l') + e = (-log(w) + log(1 + w*log(x)))**(-2)*w**(-2)*((-log(w) + + log(1 + x*w))*(-log(w) + log(1 + w*log(x)))*w - x*(-log(w) + + log(1 + w*log(x)))*w) + assert e.nseries(w, n=0, logx=l) == x/w/l + 1/w + O(1, w) + assert e.nseries(w, n=1, logx=l) == x/w/l + 1/w - x/l + 1/l*log(x) \ + + x*log(x)/l**2 + O(w) + + +def test_issue_4115(): + assert (sin(x)/(1 - cos(x))).nseries(x, n=1) == 2/x + O(x) + assert (sin(x)**2/(1 - cos(x))).nseries(x, n=1) == 2 + O(x) + + +def test_pole(): + raises(PoleError, lambda: sin(1/x).series(x, 0, 5)) + raises(PoleError, lambda: sin(1 + 1/x).series(x, 0, 5)) + raises(PoleError, lambda: (x*sin(1/x)).series(x, 0, 5)) + + +def test_expsinbug(): + assert exp(sin(x)).series(x, 0, 0) == O(1, x) + assert exp(sin(x)).series(x, 0, 1) == 1 + O(x) + assert exp(sin(x)).series(x, 0, 2) == 1 + x + O(x**2) + assert exp(sin(x)).series(x, 0, 3) == 1 + x + x**2/2 + O(x**3) + assert exp(sin(x)).series(x, 0, 4) == 1 + x + x**2/2 + O(x**4) + assert exp(sin(x)).series(x, 0, 5) == 1 + x + x**2/2 - x**4/8 + O(x**5) + + +def test_floor(): + x = Symbol('x') + assert floor(x).series(x) == 0 + assert floor(-x).series(x) == -1 + assert floor(sin(x)).series(x) == 0 + assert floor(sin(-x)).series(x) == -1 + assert floor(x**3).series(x) == 0 + assert floor(-x**3).series(x) == -1 + assert floor(cos(x)).series(x) == 0 + assert floor(cos(-x)).series(x) == 0 + assert floor(5 + sin(x)).series(x) == 5 + assert floor(5 + sin(-x)).series(x) == 4 + + assert floor(x).series(x, 2) == 2 + assert floor(-x).series(x, 2) == -3 + + x = Symbol('x', negative=True) + assert floor(x + 1.5).series(x) == 1 + + +def test_frac(): + assert frac(x).series(x, cdir=1) == x + assert frac(x).series(x, cdir=-1) == 1 + x + assert frac(2*x + 1).series(x, cdir=1) == 2*x + assert frac(2*x + 1).series(x, cdir=-1) == 1 + 2*x + assert frac(x**2).series(x, cdir=1) == x**2 + assert frac(x**2).series(x, cdir=-1) == x**2 + assert frac(sin(x) + 5).series(x, cdir=1) == x - x**3/6 + x**5/120 + O(x**6) + assert frac(sin(x) + 5).series(x, cdir=-1) == 1 + x - x**3/6 + x**5/120 + O(x**6) + assert frac(sin(x) + S.Half).series(x) == S.Half + x - x**3/6 + x**5/120 + O(x**6) + assert frac(x**8).series(x, cdir=1) == O(x**6) + assert frac(1/x).series(x) == AccumBounds(0, 1) + O(x**6) + + +def test_ceiling(): + assert ceiling(x).series(x) == 1 + assert ceiling(-x).series(x) == 0 + assert ceiling(sin(x)).series(x) == 1 + assert ceiling(sin(-x)).series(x) == 0 + assert ceiling(1 - cos(x)).series(x) == 1 + assert ceiling(1 - cos(-x)).series(x) == 1 + assert ceiling(x).series(x, 2) == 3 + assert ceiling(-x).series(x, 2) == -2 + + +def test_abs(): + a = Symbol('a') + assert abs(x).nseries(x, n=4) == x + assert abs(-x).nseries(x, n=4) == x + assert abs(x + 1).nseries(x, n=4) == x + 1 + assert abs(sin(x)).nseries(x, n=4) == x - Rational(1, 6)*x**3 + O(x**4) + assert abs(sin(-x)).nseries(x, n=4) == x - Rational(1, 6)*x**3 + O(x**4) + assert abs(x - a).nseries(x, 1) == -a*sign(1 - a) + (x - 1)*sign(1 - a) + sign(1 - a) + + +def test_dir(): + assert abs(x).series(x, 0, dir="+") == x + assert abs(x).series(x, 0, dir="-") == -x + assert floor(x + 2).series(x, 0, dir='+') == 2 + assert floor(x + 2).series(x, 0, dir='-') == 1 + assert floor(x + 2.2).series(x, 0, dir='-') == 2 + assert ceiling(x + 2.2).series(x, 0, dir='-') == 3 + assert sin(x + y).series(x, 0, dir='-') == sin(x + y).series(x, 0, dir='+') + + +def test_cdir(): + assert abs(x).series(x, 0, cdir=1) == x + assert abs(x).series(x, 0, cdir=-1) == -x + assert floor(x + 2).series(x, 0, cdir=1) == 2 + assert floor(x + 2).series(x, 0, cdir=-1) == 1 + assert floor(x + 2.2).series(x, 0, cdir=1) == 2 + assert ceiling(x + 2.2).series(x, 0, cdir=-1) == 3 + assert sin(x + y).series(x, 0, cdir=-1) == sin(x + y).series(x, 0, cdir=1) + + +def test_issue_3504(): + a = Symbol("a") + e = asin(a*x)/x + assert e.series(x, 4, n=2).removeO() == \ + (x - 4)*(a/(4*sqrt(-16*a**2 + 1)) - asin(4*a)/16) + asin(4*a)/4 + + +def test_issue_4441(): + a, b = symbols('a,b') + f = 1/(1 + a*x) + assert f.series(x, 0, 5) == 1 - a*x + a**2*x**2 - a**3*x**3 + \ + a**4*x**4 + O(x**5) + f = 1/(1 + (a + b)*x) + assert f.series(x, 0, 3) == 1 + x*(-a - b)\ + + x**2*(a + b)**2 + O(x**3) + + +def test_issue_4329(): + assert tan(x).series(x, pi/2, n=3).removeO() == \ + -pi/6 + x/3 - 1/(x - pi/2) + assert cot(x).series(x, pi, n=3).removeO() == \ + -x/3 + pi/3 + 1/(x - pi) + assert limit(tan(x)**tan(2*x), x, pi/4) == exp(-1) + + +def test_issue_5183(): + assert abs(x + x**2).series(n=1) == O(x) + assert abs(x + x**2).series(n=2) == x + O(x**2) + assert ((1 + x)**2).series(x, n=6) == x**2 + 2*x + 1 + assert (1 + 1/x).series() == 1 + 1/x + assert Derivative(exp(x).series(), x).doit() == \ + 1 + x + x**2/2 + x**3/6 + x**4/24 + O(x**5) + + +def test_issue_5654(): + a = Symbol('a') + assert (1/(x**2+a**2)**2).nseries(x, x0=I*a, n=0) == \ + -I/(4*a**3*(-I*a + x)) - 1/(4*a**2*(-I*a + x)**2) + O(1, (x, I*a)) + assert (1/(x**2+a**2)**2).nseries(x, x0=I*a, n=1) == 3/(16*a**4) \ + -I/(4*a**3*(-I*a + x)) - 1/(4*a**2*(-I*a + x)**2) + O(-I*a + x, (x, I*a)) + + +def test_issue_5925(): + sx = sqrt(x + z).series(z, 0, 1) + sxy = sqrt(x + y + z).series(z, 0, 1) + s1, s2 = sx.subs(x, x + y), sxy + assert (s1 - s2).expand().removeO().simplify() == 0 + + sx = sqrt(x + z).series(z, 0, 1) + sxy = sqrt(x + y + z).series(z, 0, 1) + assert sxy.subs({x:1, y:2}) == sx.subs(x, 3) + + +def test_exp_2(): + assert exp(x**3).nseries(x, 0, 14) == 1 + x**3 + x**6/2 + x**9/6 + x**12/24 + O(x**14) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_order.py b/phivenv/Lib/site-packages/sympy/series/tests/test_order.py new file mode 100644 index 0000000000000000000000000000000000000000..50fcb861ee2a76c730baae6d26cc1e7a00347176 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_order.py @@ -0,0 +1,503 @@ +from sympy.core.add import Add +from sympy.core.function import (Function, expand) +from sympy.core.numbers import (I, Rational, nan, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (conjugate, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.series.order import O, Order +from sympy.core.expr import unchanged +from sympy.testing.pytest import raises +from sympy.abc import w, x, y, z +from sympy.testing.pytest import XFAIL + + +def test_caching_bug(): + #needs to be a first test, so that all caches are clean + #cache it + O(w) + #and test that this won't raise an exception + O(w**(-1/x/log(3)*log(5)), w) + + +def test_free_symbols(): + assert Order(1).free_symbols == set() + assert Order(x).free_symbols == {x} + assert Order(1, x).free_symbols == {x} + assert Order(x*y).free_symbols == {x, y} + assert Order(x, x, y).free_symbols == {x, y} + + +def test_simple_1(): + o = Rational(0) + assert Order(2*x) == Order(x) + assert Order(x)*3 == Order(x) + assert -28*Order(x) == Order(x) + assert Order(Order(x)) == Order(x) + assert Order(Order(x), y) == Order(Order(x), x, y) + assert Order(-23) == Order(1) + assert Order(exp(x)) == Order(1, x) + assert Order(exp(1/x)).expr == exp(1/x) + assert Order(x*exp(1/x)).expr == x*exp(1/x) + assert Order(x**(o/3)).expr == x**(o/3) + assert Order(x**(o*Rational(5, 3))).expr == x**(o*Rational(5, 3)) + assert Order(x**2 + x + y, x) == O(1, x) + assert Order(x**2 + x + y, y) == O(1, y) + raises(ValueError, lambda: Order(exp(x), x, x)) + raises(TypeError, lambda: Order(x, 2 - x)) + + +def test_simple_2(): + assert Order(2*x)*x == Order(x**2) + assert Order(2*x)/x == Order(1, x) + assert Order(2*x)*x*exp(1/x) == Order(x**2*exp(1/x)) + assert (Order(2*x)*x*exp(1/x)/log(x)**3).expr == x**2*exp(1/x)*log(x)**-3 + + +def test_simple_3(): + assert Order(x) + x == Order(x) + assert Order(x) + 2 == 2 + Order(x) + assert Order(x) + x**2 == Order(x) + assert Order(x) + 1/x == 1/x + Order(x) + assert Order(1/x) + 1/x**2 == 1/x**2 + Order(1/x) + assert Order(x) + exp(1/x) == Order(x) + exp(1/x) + + +def test_simple_4(): + assert Order(x)**2 == Order(x**2) + + +def test_simple_5(): + assert Order(x) + Order(x**2) == Order(x) + assert Order(x) + Order(x**-2) == Order(x**-2) + assert Order(x) + Order(1/x) == Order(1/x) + + +def test_simple_6(): + assert Order(x) - Order(x) == Order(x) + assert Order(x) + Order(1) == Order(1) + assert Order(x) + Order(x**2) == Order(x) + assert Order(1/x) + Order(1) == Order(1/x) + assert Order(x) + Order(exp(1/x)) == Order(exp(1/x)) + assert Order(x**3) + Order(exp(2/x)) == Order(exp(2/x)) + assert Order(x**-3) + Order(exp(2/x)) == Order(exp(2/x)) + + +def test_simple_7(): + assert 1 + O(1) == O(1) + assert 2 + O(1) == O(1) + assert x + O(1) == O(1) + assert 1/x + O(1) == 1/x + O(1) + + +def test_simple_8(): + assert O(sqrt(-x)) == O(sqrt(x)) + assert O(x**2*sqrt(x)) == O(x**Rational(5, 2)) + assert O(x**3*sqrt(-(-x)**3)) == O(x**Rational(9, 2)) + assert O(x**Rational(3, 2)*sqrt((-x)**3)) == O(x**3) + assert O(x*(-2*x)**(I/2)) == O(x*(-x)**(I/2)) + + +def test_as_expr_variables(): + assert Order(x).as_expr_variables(None) == (x, ((x, 0),)) + assert Order(x).as_expr_variables(((x, 0),)) == (x, ((x, 0),)) + assert Order(y).as_expr_variables(((x, 0),)) == (y, ((x, 0), (y, 0))) + assert Order(y).as_expr_variables(((x, 0), (y, 0))) == (y, ((x, 0), (y, 0))) + + +def test_contains_0(): + assert Order(1, x).contains(Order(1, x)) + assert Order(1, x).contains(Order(1)) + assert Order(1).contains(Order(1, x)) is False + + +def test_contains_1(): + assert Order(x).contains(Order(x)) + assert Order(x).contains(Order(x**2)) + assert not Order(x**2).contains(Order(x)) + assert not Order(x).contains(Order(1/x)) + assert not Order(1/x).contains(Order(exp(1/x))) + assert not Order(x).contains(Order(exp(1/x))) + assert Order(1/x).contains(Order(x)) + assert Order(exp(1/x)).contains(Order(x)) + assert Order(exp(1/x)).contains(Order(1/x)) + assert Order(exp(1/x)).contains(Order(exp(1/x))) + assert Order(exp(2/x)).contains(Order(exp(1/x))) + assert not Order(exp(1/x)).contains(Order(exp(2/x))) + + +def test_contains_2(): + assert Order(x).contains(Order(y)) is None + assert Order(x).contains(Order(y*x)) + assert Order(y*x).contains(Order(x)) + assert Order(y).contains(Order(x*y)) + assert Order(x).contains(Order(y**2*x)) + + +def test_contains_3(): + assert Order(x*y**2).contains(Order(x**2*y)) is None + assert Order(x**2*y).contains(Order(x*y**2)) is None + + +def test_contains_4(): + assert Order(sin(1/x**2)).contains(Order(cos(1/x**2))) is True + assert Order(cos(1/x**2)).contains(Order(sin(1/x**2))) is True + + +def test_contains(): + assert Order(1, x) not in Order(1) + assert Order(1) in Order(1, x) + raises(TypeError, lambda: Order(x*y**2) in Order(x**2*y)) + + +def test_add_1(): + assert Order(x + x) == Order(x) + assert Order(3*x - 2*x**2) == Order(x) + assert Order(1 + x) == Order(1, x) + assert Order(1 + 1/x) == Order(1/x) + # TODO : A better output for Order(log(x) + 1/log(x)) + # could be Order(log(x)). Currently Order for expressions + # where all arguments would involve a log term would fall + # in this category and outputs for these should be improved. + assert Order(log(x) + 1/log(x)) == Order((log(x)**2 + 1)/log(x)) + assert Order(exp(1/x) + x) == Order(exp(1/x)) + assert Order(exp(1/x) + 1/x**20) == Order(exp(1/x)) + + +def test_ln_args(): + assert O(log(x)) + O(log(2*x)) == O(log(x)) + assert O(log(x)) + O(log(x**3)) == O(log(x)) + assert O(log(x*y)) + O(log(x) + log(y)) == O(log(x) + log(y), x, y) + + +def test_multivar_0(): + assert Order(x*y).expr == x*y + assert Order(x*y**2).expr == x*y**2 + assert Order(x*y, x).expr == x + assert Order(x*y**2, y).expr == y**2 + assert Order(x*y*z).expr == x*y*z + assert Order(x/y).expr == x/y + assert Order(x*exp(1/y)).expr == x*exp(1/y) + assert Order(exp(x)*exp(1/y)).expr == exp(x)*exp(1/y) + + +def test_multivar_0a(): + assert Order(exp(1/x)*exp(1/y)).expr == exp(1/x)*exp(1/y) + + +def test_multivar_1(): + assert Order(x + y).expr == x + y + assert Order(x + 2*y).expr == x + y + assert (Order(x + y) + x).expr == (x + y) + assert (Order(x + y) + x**2) == Order(x + y) + assert (Order(x + y) + 1/x) == 1/x + Order(x + y) + assert Order(x**2 + y*x).expr == x**2 + y*x + + +def test_multivar_2(): + assert Order(x**2*y + y**2*x, x, y).expr == x**2*y + y**2*x + + +def test_multivar_mul_1(): + assert Order(x + y)*x == Order(x**2 + y*x, x, y) + + +def test_multivar_3(): + assert (Order(x) + Order(y)).args in [ + (Order(x), Order(y)), + (Order(y), Order(x))] + assert Order(x) + Order(y) + Order(x + y) == Order(x + y) + assert (Order(x**2*y) + Order(y**2*x)).args in [ + (Order(x*y**2), Order(y*x**2)), + (Order(y*x**2), Order(x*y**2))] + assert (Order(x**2*y) + Order(y*x)) == Order(x*y) + + +def test_issue_3468(): + y = Symbol('y', negative=True) + z = Symbol('z', complex=True) + + # check that Order does not modify assumptions about symbols + Order(x) + Order(y) + Order(z) + + assert x.is_positive is None + assert y.is_positive is False + assert z.is_positive is None + + +def test_leading_order(): + assert (x + 1 + 1/x**5).extract_leading_order(x) == ((1/x**5, O(1/x**5)),) + assert (1 + 1/x).extract_leading_order(x) == ((1/x, O(1/x)),) + assert (1 + x).extract_leading_order(x) == ((1, O(1, x)),) + assert (1 + x**2).extract_leading_order(x) == ((1, O(1, x)),) + assert (2 + x**2).extract_leading_order(x) == ((2, O(1, x)),) + assert (x + x**2).extract_leading_order(x) == ((x, O(x)),) + + +def test_leading_order2(): + assert set((2 + pi + x**2).extract_leading_order(x)) == {(pi, O(1, x)), + (S(2), O(1, x))} + assert set((2*x + pi*x + x**2).extract_leading_order(x)) == {(2*x, O(x)), + (x*pi, O(x))} + + +def test_order_leadterm(): + assert O(x**2)._eval_as_leading_term(x, None, 1) == O(x**2) + + +def test_order_symbols(): + e = x*y*sin(x)*Integral(x, (x, 1, 2)) + assert O(e) == O(x**2*y, x, y) + assert O(e, x) == O(x**2) + + +def test_nan(): + assert O(nan) is nan + assert not O(x).contains(nan) + + +def test_O1(): + assert O(1, x) * x == O(x) + assert O(1, y) * x == O(1, y) + + +def test_getn(): + # other lines are tested incidentally by the suite + assert O(x).getn() == 1 + assert O(x/log(x)).getn() == 1 + assert O(x**2/log(x)**2).getn() == 2 + assert O(x*log(x)).getn() == 1 + raises(NotImplementedError, lambda: (O(x) + O(y)).getn()) + + +def test_diff(): + assert O(x**2).diff(x) == O(x) + + +def test_getO(): + assert (x).getO() is None + assert (x).removeO() == x + assert (O(x)).getO() == O(x) + assert (O(x)).removeO() == 0 + assert (z + O(x) + O(y)).getO() == O(x) + O(y) + assert (z + O(x) + O(y)).removeO() == z + raises(NotImplementedError, lambda: (O(x) + O(y)).getn()) + + +def test_leading_term(): + from sympy.functions.special.gamma_functions import digamma + assert O(1/digamma(1/x)) == O(1/log(x)) + + +def test_eval(): + assert Order(x).subs(Order(x), 1) == 1 + assert Order(x).subs(x, y) == Order(y) + assert Order(x).subs(y, x) == Order(x) + assert Order(x).subs(x, x + y) == Order(x + y, (x, -y)) + assert (O(1)**x).is_Pow + + +def test_issue_4279(): + a, b = symbols('a b') + assert O(a, a, b) + O(1, a, b) == O(1, a, b) + assert O(b, a, b) + O(1, a, b) == O(1, a, b) + assert O(a + b, a, b) + O(1, a, b) == O(1, a, b) + assert O(1, a, b) + O(a, a, b) == O(1, a, b) + assert O(1, a, b) + O(b, a, b) == O(1, a, b) + assert O(1, a, b) + O(a + b, a, b) == O(1, a, b) + + +def test_issue_4855(): + assert 1/O(1) != O(1) + assert 1/O(x) != O(1/x) + assert 1/O(x, (x, oo)) != O(1/x, (x, oo)) + + f = Function('f') + assert 1/O(f(x)) != O(1/x) + + +def test_order_conjugate_transpose(): + x = Symbol('x', real=True) + y = Symbol('y', imaginary=True) + assert conjugate(Order(x)) == Order(conjugate(x)) + assert conjugate(Order(y)) == Order(conjugate(y)) + assert conjugate(Order(x**2)) == Order(conjugate(x)**2) + assert conjugate(Order(y**2)) == Order(conjugate(y)**2) + assert transpose(Order(x)) == Order(transpose(x)) + assert transpose(Order(y)) == Order(transpose(y)) + assert transpose(Order(x**2)) == Order(transpose(x)**2) + assert transpose(Order(y**2)) == Order(transpose(y)**2) + + +def test_order_noncommutative(): + A = Symbol('A', commutative=False) + assert Order(A + A*x, x) == Order(1, x) + assert (A + A*x)*Order(x) == Order(x) + assert (A*x)*Order(x) == Order(x**2, x) + assert expand((1 + Order(x))*A*A*x) == A*A*x + Order(x**2, x) + assert expand((A*A + Order(x))*x) == A*A*x + Order(x**2, x) + assert expand((A + Order(x))*A*x) == A*A*x + Order(x**2, x) + + +def test_issue_6753(): + assert (1 + x**2)**10000*O(x) == O(x) + + +def test_order_at_infinity(): + assert Order(1 + x, (x, oo)) == Order(x, (x, oo)) + assert Order(3*x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo))*3 == Order(x, (x, oo)) + assert -28*Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(Order(x, (x, oo)), (x, oo)) == Order(x, (x, oo)) + assert Order(Order(x, (x, oo)), (y, oo)) == Order(x, (x, oo), (y, oo)) + assert Order(3, (x, oo)) == Order(1, (x, oo)) + assert Order(x**2 + x + y, (x, oo)) == O(x**2, (x, oo)) + assert Order(x**2 + x + y, (y, oo)) == O(y, (y, oo)) + + assert Order(2*x, (x, oo))*x == Order(x**2, (x, oo)) + assert Order(2*x, (x, oo))/x == Order(1, (x, oo)) + assert Order(2*x, (x, oo))*x*exp(1/x) == Order(x**2*exp(1/x), (x, oo)) + assert Order(2*x, (x, oo))*x*exp(1/x)/log(x)**3 == Order(x**2*exp(1/x)*log(x)**-3, (x, oo)) + + assert Order(x, (x, oo)) + 1/x == 1/x + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + 1 == 1 + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + x == x + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + x**2 == x**2 + Order(x, (x, oo)) + assert Order(1/x, (x, oo)) + 1/x**2 == 1/x**2 + Order(1/x, (x, oo)) == Order(1/x, (x, oo)) + assert Order(x, (x, oo)) + exp(1/x) == exp(1/x) + Order(x, (x, oo)) + + assert Order(x, (x, oo))**2 == Order(x**2, (x, oo)) + + assert Order(x, (x, oo)) + Order(x**2, (x, oo)) == Order(x**2, (x, oo)) + assert Order(x, (x, oo)) + Order(x**-2, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(1/x, (x, oo)) == Order(x, (x, oo)) + + assert Order(x, (x, oo)) - Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(1, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(x**2, (x, oo)) == Order(x**2, (x, oo)) + assert Order(1/x, (x, oo)) + Order(1, (x, oo)) == Order(1, (x, oo)) + assert Order(x, (x, oo)) + Order(exp(1/x), (x, oo)) == Order(x, (x, oo)) + assert Order(x**3, (x, oo)) + Order(exp(2/x), (x, oo)) == Order(x**3, (x, oo)) + assert Order(x**-3, (x, oo)) + Order(exp(2/x), (x, oo)) == Order(exp(2/x), (x, oo)) + + # issue 7207 + assert Order(exp(x), (x, oo)).expr == Order(2*exp(x), (x, oo)).expr == exp(x) + assert Order(y**x, (x, oo)).expr == Order(2*y**x, (x, oo)).expr == exp(x*log(y)) + + # issue 19545 + assert Order(1/x - 3/(3*x + 2), (x, oo)).expr == x**(-2) + +def test_mixing_order_at_zero_and_infinity(): + assert (Order(x, (x, 0)) + Order(x, (x, oo))).is_Add + assert Order(x, (x, 0)) + Order(x, (x, oo)) == Order(x, (x, oo)) + Order(x, (x, 0)) + assert Order(Order(x, (x, oo))) == Order(x, (x, oo)) + + # not supported (yet) + raises(NotImplementedError, lambda: Order(x, (x, 0))*Order(x, (x, oo))) + raises(NotImplementedError, lambda: Order(x, (x, oo))*Order(x, (x, 0))) + raises(NotImplementedError, lambda: Order(Order(x, (x, oo)), y)) + raises(NotImplementedError, lambda: Order(Order(x), (x, oo))) + + +def test_order_at_some_point(): + assert Order(x, (x, 1)) == Order(1, (x, 1)) + assert Order(2*x - 2, (x, 1)) == Order(x - 1, (x, 1)) + assert Order(-x + 1, (x, 1)) == Order(x - 1, (x, 1)) + assert Order(x - 1, (x, 1))**2 == Order((x - 1)**2, (x, 1)) + assert Order(x - 2, (x, 2)) - O(x - 2, (x, 2)) == Order(x - 2, (x, 2)) + + +def test_order_subs_limits(): + # issue 3333 + assert (1 + Order(x)).subs(x, 1/x) == 1 + Order(1/x, (x, oo)) + assert (1 + Order(x)).limit(x, 0) == 1 + # issue 5769 + assert ((x + Order(x**2))/x).limit(x, 0) == 1 + + assert Order(x**2).subs(x, y - 1) == Order((y - 1)**2, (y, 1)) + assert Order(10*x**2, (x, 2)).subs(x, y - 1) == Order(1, (y, 3)) + + #issue 19120 + assert O(x).subs(x, O(x)) == O(x) + assert O(x**2).subs(x, x + O(x)) == O(x**2) + assert O(x, (x, oo)).subs(x, O(x, (x, oo))) == O(x, (x, oo)) + assert O(x**2, (x, oo)).subs(x, x + O(x, (x, oo))) == O(x**2, (x, oo)) + assert (x + O(x**2)).subs(x, x + O(x**2)) == x + O(x**2) + assert (x**2 + O(x**2) + 1/x**2).subs(x, x + O(x**2)) == (x + O(x**2))**(-2) + O(x**2) + assert (x**2 + O(x**2) + 1).subs(x, x + O(x**2)) == 1 + O(x**2) + assert O(x, (x, oo)).subs(x, x + O(x**2, (x, oo))) == O(x**2, (x, oo)) + assert sin(x).series(n=8).subs(x,sin(x).series(n=8)).expand() == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + assert cos(x).series(n=8).subs(x,sin(x).series(n=8)).expand() == 1 - x**2/2 + 5*x**4/24 - 37*x**6/720 + O(x**8) + assert O(x).subs(x, O(1/x, (x, oo))) == O(1/x, (x, oo)) + +@XFAIL +def test_order_failing_due_to_solveset(): + assert O(x**3).subs(x, exp(-x**2)) == O(exp(-3*x**2), (x, -oo)) + raises(NotImplementedError, lambda: O(x).subs(x, O(1/x))) # mixing of order at different points + + +def test_issue_9351(): + assert exp(x).series(x, 10, 1) == exp(10) + Order(x - 10, (x, 10)) + + +def test_issue_9192(): + assert O(1)*O(1) == O(1) + assert O(1)**O(1) == O(1) + + +def test_issue_9910(): + assert O(x*log(x) + sin(x), (x, oo)) == O(x*log(x), (x, oo)) + + +def test_performance_of_adding_order(): + l = [x**i for i in range(1000)] + l.append(O(x**1001)) + assert Add(*l).subs(x,1) == O(1) + +def test_issue_14622(): + assert (x**(-4) + x**(-3) + x**(-1) + O(x**(-6), (x, oo))).as_numer_denom() == ( + x**4 + x**5 + x**7 + O(x**2, (x, oo)), x**8) + assert (x**3 + O(x**2, (x, oo))).is_Add + assert O(x**2, (x, oo)).contains(x**3) is False + assert O(x, (x, oo)).contains(O(x, (x, 0))) is None + assert O(x, (x, 0)).contains(O(x, (x, oo))) is None + raises(NotImplementedError, lambda: O(x**3).contains(x**w)) + + +def test_issue_15539(): + assert O(1/x**2 + 1/x**4, (x, -oo)) == O(1/x**2, (x, -oo)) + assert O(1/x**4 + exp(x), (x, -oo)) == O(1/x**4, (x, -oo)) + assert O(1/x**4 + exp(-x), (x, -oo)) == O(exp(-x), (x, -oo)) + assert O(1/x, (x, oo)).subs(x, -x) == O(-1/x, (x, -oo)) + +def test_issue_18606(): + assert unchanged(Order, 0) + + +def test_issue_22165(): + assert O(log(x)).contains(2) + + +def test_issue_23231(): + # This test checks Order for expressions having + # arguments containing variables in exponents/powers. + assert O(x**x + 2**x, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x**x + x**2, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x**x + 1/x**2, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(2**x + 3**x , (x, oo)) == O(exp(x*log(3)), (x, oo)) + + +def test_issue_9917(): + assert O(x*sin(x) + 1, (x, oo)) == O(x, (x, oo)) + + +def test_issue_22836(): + assert O(2**x + factorial(x), (x, oo)) == O(factorial(x), (x, oo)) + assert O(2**x + factorial(x) + x**x, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x + factorial(x), (x, oo)) == O(factorial(x), (x, oo)) diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_residues.py b/phivenv/Lib/site-packages/sympy/series/tests/test_residues.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7d075a56500d008e3c8b46c1fda5db890fd76a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_residues.py @@ -0,0 +1,101 @@ +from sympy.core.function import Function +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cot, sin, tan) +from sympy.series.residues import residue +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, z, a, s, k + + +def test_basic1(): + assert residue(1/x, x, 0) == 1 + assert residue(-2/x, x, 0) == -2 + assert residue(81/x, x, 0) == 81 + assert residue(1/x**2, x, 0) == 0 + assert residue(0, x, 0) == 0 + assert residue(5, x, 0) == 0 + assert residue(x, x, 0) == 0 + assert residue(x**2, x, 0) == 0 + + +def test_basic2(): + assert residue(1/x, x, 1) == 0 + assert residue(-2/x, x, 1) == 0 + assert residue(81/x, x, -1) == 0 + assert residue(1/x**2, x, 1) == 0 + assert residue(0, x, 1) == 0 + assert residue(5, x, 1) == 0 + assert residue(x, x, 1) == 0 + assert residue(x**2, x, 5) == 0 + + +def test_f(): + f = Function("f") + assert residue(f(x)/x**5, x, 0) == f(x).diff(x, 4).subs(x, 0)/24 + + +def test_functions(): + assert residue(1/sin(x), x, 0) == 1 + assert residue(2/sin(x), x, 0) == 2 + assert residue(1/sin(x)**2, x, 0) == 0 + assert residue(1/sin(x)**5, x, 0) == Rational(3, 8) + + +def test_expressions(): + assert residue(1/(x + 1), x, 0) == 0 + assert residue(1/(x + 1), x, -1) == 1 + assert residue(1/(x**2 + 1), x, -1) == 0 + assert residue(1/(x**2 + 1), x, I) == -I/2 + assert residue(1/(x**2 + 1), x, -I) == I/2 + assert residue(1/(x**4 + 1), x, 0) == 0 + assert residue(1/(x**4 + 1), x, exp(I*pi/4)).equals(-(Rational(1, 4) + I/4)/sqrt(2)) + assert residue(1/(x**2 + a**2)**2, x, a*I) == -I/4/a**3 + + +@XFAIL +def test_expressions_failing(): + n = Symbol('n', integer=True, positive=True) + assert residue(exp(z)/(z - pi*I/4*a)**n, z, I*pi*a) == \ + exp(I*pi*a/4)/factorial(n - 1) + + +def test_NotImplemented(): + raises(NotImplementedError, lambda: residue(exp(1/z), z, 0)) + + +def test_bug(): + assert residue(2**(z)*(s + z)*(1 - s - z)/z**2, z, 0) == \ + 1 + s*log(2) - s**2*log(2) - 2*s + + +def test_issue_5654(): + assert residue(1/(x**2 + a**2)**2, x, a*I) == -I/(4*a**3) + assert residue(1/s*1/(z - exp(s)), s, 0) == 1/(z - 1) + assert residue((1 + k)/s*1/(z - exp(s)), s, 0) == k/(z - 1) + 1/(z - 1) + + +def test_issue_6499(): + assert residue(1/(exp(z) - 1), z, 0) == 1 + + +def test_issue_14037(): + assert residue(sin(x**50)/x**51, x, 0) == 1 + + +def test_issue_21176(): + f = x**2*cot(pi*x)/(x**4 + 1) + assert residue(f, x, -sqrt(2)/2 - sqrt(2)*I/2).cancel().together(deep=True)\ + == sqrt(2)*(1 - I)/(8*tan(sqrt(2)*pi*(1 + I)/2)) + + +def test_issue_21177(): + r = -sqrt(3)*tanh(sqrt(3)*pi/2)/3 + a = residue(cot(pi*x)/((x - 1)*(x - 2) + 1), x, S(3)/2 - sqrt(3)*I/2) + b = residue(cot(pi*x)/(x**2 - 3*x + 3), x, S(3)/2 - sqrt(3)*I/2) + assert a == r + assert (b - a).cancel() == 0 diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_sequences.py b/phivenv/Lib/site-packages/sympy/series/tests/test_sequences.py new file mode 100644 index 0000000000000000000000000000000000000000..61e276ad67982f0a9877de3548d70238976d28a5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_sequences.py @@ -0,0 +1,312 @@ +from sympy.core.containers import Tuple +from sympy.core.function import Function +from sympy.core.numbers import oo, Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols, Symbol +from sympy.functions.combinatorial.numbers import tribonacci, fibonacci +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.series import EmptySequence +from sympy.series.sequences import (SeqMul, SeqAdd, SeqPer, SeqFormula, + sequence) +from sympy.sets.sets import Interval +from sympy.tensor.indexed import Indexed, Idx +from sympy.series.sequences import SeqExpr, SeqExprOp, RecursiveSeq +from sympy.testing.pytest import raises, slow + +x, y, z = symbols('x y z') +n, m = symbols('n m') + + +def test_EmptySequence(): + assert S.EmptySequence is EmptySequence + + assert S.EmptySequence.interval is S.EmptySet + assert S.EmptySequence.length is S.Zero + + assert list(S.EmptySequence) == [] + + +def test_SeqExpr(): + #SeqExpr is a baseclass and does not take care of + #ensuring all arguments are Basics hence the use of + #Tuple(...) here. + s = SeqExpr(Tuple(1, n, y), Tuple(x, 0, 10)) + + assert isinstance(s, SeqExpr) + assert s.gen == (1, n, y) + assert s.interval == Interval(0, 10) + assert s.start == 0 + assert s.stop == 10 + assert s.length == 11 + assert s.variables == (x,) + + assert SeqExpr(Tuple(1, 2, 3), Tuple(x, 0, oo)).length is oo + + +def test_SeqPer(): + s = SeqPer((1, n, 3), (x, 0, 5)) + + assert isinstance(s, SeqPer) + assert s.periodical == Tuple(1, n, 3) + assert s.period == 3 + assert s.coeff(3) == 1 + assert s.free_symbols == {n} + + assert list(s) == [1, n, 3, 1, n, 3] + assert s[:] == [1, n, 3, 1, n, 3] + assert SeqPer((1, n, 3), (x, -oo, 0))[0:6] == [1, n, 3, 1, n, 3] + + raises(ValueError, lambda: SeqPer((1, 2, 3), (0, 1, 2))) + raises(ValueError, lambda: SeqPer((1, 2, 3), (x, -oo, oo))) + raises(ValueError, lambda: SeqPer(n**2, (0, oo))) + + assert SeqPer((n, n**2, n**3), (m, 0, oo))[:6] == \ + [n, n**2, n**3, n, n**2, n**3] + assert SeqPer((n, n**2, n**3), (n, 0, oo))[:6] == [0, 1, 8, 3, 16, 125] + assert SeqPer((n, m), (n, 0, oo))[:6] == [0, m, 2, m, 4, m] + + +def test_SeqFormula(): + s = SeqFormula(n**2, (n, 0, 5)) + + assert isinstance(s, SeqFormula) + assert s.formula == n**2 + assert s.coeff(3) == 9 + + assert list(s) == [i**2 for i in range(6)] + assert s[:] == [i**2 for i in range(6)] + assert SeqFormula(n**2, (n, -oo, 0))[0:6] == [i**2 for i in range(6)] + + assert SeqFormula(n**2, (0, oo)) == SeqFormula(n**2, (n, 0, oo)) + + assert SeqFormula(n**2, (0, m)).subs(m, x) == SeqFormula(n**2, (0, x)) + assert SeqFormula(m*n**2, (n, 0, oo)).subs(m, x) == \ + SeqFormula(x*n**2, (n, 0, oo)) + + raises(ValueError, lambda: SeqFormula(n**2, (0, 1, 2))) + raises(ValueError, lambda: SeqFormula(n**2, (n, -oo, oo))) + raises(ValueError, lambda: SeqFormula(m*n**2, (0, oo))) + + seq = SeqFormula(x*(y**2 + z), (z, 1, 100)) + assert seq.expand() == SeqFormula(x*y**2 + x*z, (z, 1, 100)) + seq = SeqFormula(sin(x*(y**2 + z)),(z, 1, 100)) + assert seq.expand(trig=True) == SeqFormula(sin(x*y**2)*cos(x*z) + sin(x*z)*cos(x*y**2), (z, 1, 100)) + assert seq.expand() == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100)) + assert seq.expand(trig=False) == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100)) + seq = SeqFormula(exp(x*(y**2 + z)), (z, 1, 100)) + assert seq.expand() == SeqFormula(exp(x*y**2)*exp(x*z), (z, 1, 100)) + assert seq.expand(power_exp=False) == SeqFormula(exp(x*y**2 + x*z), (z, 1, 100)) + assert seq.expand(mul=False, power_exp=False) == SeqFormula(exp(x*(y**2 + z)), (z, 1, 100)) + +def test_sequence(): + form = SeqFormula(n**2, (n, 0, 5)) + per = SeqPer((1, 2, 3), (n, 0, 5)) + inter = SeqFormula(n**2) + + assert sequence(n**2, (n, 0, 5)) == form + assert sequence((1, 2, 3), (n, 0, 5)) == per + assert sequence(n**2) == inter + + +def test_SeqExprOp(): + form = SeqFormula(n**2, (n, 0, 10)) + per = SeqPer((1, 2, 3), (m, 5, 10)) + + s = SeqExprOp(form, per) + assert s.gen == (n**2, (1, 2, 3)) + assert s.interval == Interval(5, 10) + assert s.start == 5 + assert s.stop == 10 + assert s.length == 6 + assert s.variables == (n, m) + + +def test_SeqAdd(): + per = SeqPer((1, 2, 3), (n, 0, oo)) + form = SeqFormula(n**2) + + per_bou = SeqPer((1, 2), (n, 1, 5)) + form_bou = SeqFormula(n**2, (6, 10)) + form_bou2 = SeqFormula(n**2, (1, 5)) + + assert SeqAdd() == S.EmptySequence + assert SeqAdd(S.EmptySequence) == S.EmptySequence + assert SeqAdd(per) == per + assert SeqAdd(per, S.EmptySequence) == per + assert SeqAdd(per_bou, form_bou) == S.EmptySequence + + s = SeqAdd(per_bou, form_bou2, evaluate=False) + assert s.args == (form_bou2, per_bou) + assert s[:] == [2, 6, 10, 18, 26] + assert list(s) == [2, 6, 10, 18, 26] + + assert isinstance(SeqAdd(per, per_bou, evaluate=False), SeqAdd) + + s1 = SeqAdd(per, per_bou) + assert isinstance(s1, SeqPer) + assert s1 == SeqPer((2, 4, 4, 3, 3, 5), (n, 1, 5)) + s2 = SeqAdd(form, form_bou) + assert isinstance(s2, SeqFormula) + assert s2 == SeqFormula(2*n**2, (6, 10)) + + assert SeqAdd(form, form_bou, per) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + assert SeqAdd(form, SeqAdd(form_bou, per)) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + assert SeqAdd(per, SeqAdd(form, form_bou), evaluate=False) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + + assert SeqAdd(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (m, 0, oo))) == \ + SeqPer((2, 4), (n, 0, oo)) + + +def test_SeqMul(): + per = SeqPer((1, 2, 3), (n, 0, oo)) + form = SeqFormula(n**2) + + per_bou = SeqPer((1, 2), (n, 1, 5)) + form_bou = SeqFormula(n**2, (n, 6, 10)) + form_bou2 = SeqFormula(n**2, (1, 5)) + + assert SeqMul() == S.EmptySequence + assert SeqMul(S.EmptySequence) == S.EmptySequence + assert SeqMul(per) == per + assert SeqMul(per, S.EmptySequence) == S.EmptySequence + assert SeqMul(per_bou, form_bou) == S.EmptySequence + + s = SeqMul(per_bou, form_bou2, evaluate=False) + assert s.args == (form_bou2, per_bou) + assert s[:] == [1, 8, 9, 32, 25] + assert list(s) == [1, 8, 9, 32, 25] + + assert isinstance(SeqMul(per, per_bou, evaluate=False), SeqMul) + + s1 = SeqMul(per, per_bou) + assert isinstance(s1, SeqPer) + assert s1 == SeqPer((1, 4, 3, 2, 2, 6), (n, 1, 5)) + s2 = SeqMul(form, form_bou) + assert isinstance(s2, SeqFormula) + assert s2 == SeqFormula(n**4, (6, 10)) + + assert SeqMul(form, form_bou, per) == \ + SeqMul(per, SeqFormula(n**4, (6, 10))) + assert SeqMul(form, SeqMul(form_bou, per)) == \ + SeqMul(per, SeqFormula(n**4, (6, 10))) + assert SeqMul(per, SeqMul(form, form_bou2, + evaluate=False), evaluate=False) == \ + SeqMul(form, per, form_bou2, evaluate=False) + + assert SeqMul(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (n, 0, oo))) == \ + SeqPer((1, 4), (n, 0, oo)) + + +def test_add(): + per = SeqPer((1, 2), (n, 0, oo)) + form = SeqFormula(n**2) + + assert per + (SeqPer((2, 3))) == SeqPer((3, 5), (n, 0, oo)) + assert form + SeqFormula(n**3) == SeqFormula(n**2 + n**3) + + assert per + form == SeqAdd(per, form) + + raises(TypeError, lambda: per + n) + raises(TypeError, lambda: n + per) + + +def test_sub(): + per = SeqPer((1, 2), (n, 0, oo)) + form = SeqFormula(n**2) + + assert per - (SeqPer((2, 3))) == SeqPer((-1, -1), (n, 0, oo)) + assert form - (SeqFormula(n**3)) == SeqFormula(n**2 - n**3) + + assert per - form == SeqAdd(per, -form) + + raises(TypeError, lambda: per - n) + raises(TypeError, lambda: n - per) + + +def test_mul__coeff_mul(): + assert SeqPer((1, 2), (n, 0, oo)).coeff_mul(2) == SeqPer((2, 4), (n, 0, oo)) + assert SeqFormula(n**2).coeff_mul(2) == SeqFormula(2*n**2) + assert S.EmptySequence.coeff_mul(100) == S.EmptySequence + + assert SeqPer((1, 2), (n, 0, oo)) * (SeqPer((2, 3))) == \ + SeqPer((2, 6), (n, 0, oo)) + assert SeqFormula(n**2) * SeqFormula(n**3) == SeqFormula(n**5) + + assert S.EmptySequence * SeqFormula(n**2) == S.EmptySequence + assert SeqFormula(n**2) * S.EmptySequence == S.EmptySequence + + raises(TypeError, lambda: sequence(n**2) * n) + raises(TypeError, lambda: n * sequence(n**2)) + + +def test_neg(): + assert -SeqPer((1, -2), (n, 0, oo)) == SeqPer((-1, 2), (n, 0, oo)) + assert -SeqFormula(n**2) == SeqFormula(-n**2) + + +def test_operations(): + per = SeqPer((1, 2), (n, 0, oo)) + per2 = SeqPer((2, 4), (n, 0, oo)) + form = SeqFormula(n**2) + form2 = SeqFormula(n**3) + + assert per + form + form2 == SeqAdd(per, form, form2) + assert per + form - form2 == SeqAdd(per, form, -form2) + assert per + form - S.EmptySequence == SeqAdd(per, form) + assert per + per2 + form == SeqAdd(SeqPer((3, 6), (n, 0, oo)), form) + assert S.EmptySequence - per == -per + assert form + form == SeqFormula(2*n**2) + + assert per * form * form2 == SeqMul(per, form, form2) + assert form * form == SeqFormula(n**4) + assert form * -form == SeqFormula(-n**4) + + assert form * (per + form2) == SeqMul(form, SeqAdd(per, form2)) + assert form * (per + per) == SeqMul(form, per2) + + assert form.coeff_mul(m) == SeqFormula(m*n**2, (n, 0, oo)) + assert per.coeff_mul(m) == SeqPer((m, 2*m), (n, 0, oo)) + + +def test_Idx_limits(): + i = symbols('i', cls=Idx) + r = Indexed('r', i) + + assert SeqFormula(r, (i, 0, 5))[:] == [r.subs(i, j) for j in range(6)] + assert SeqPer((1, 2), (i, 0, 5))[:] == [1, 2, 1, 2, 1, 2] + + +@slow +def test_find_linear_recurrence(): + assert sequence((0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55), \ + (n, 0, 10)).find_linear_recurrence(11) == [1, 1] + assert sequence((1, 2, 4, 7, 28, 128, 582, 2745, 13021, 61699, 292521, \ + 1387138), (n, 0, 11)).find_linear_recurrence(12) == [5, -2, 6, -11] + assert sequence(x*n**3+y*n, (n, 0, oo)).find_linear_recurrence(10) \ + == [4, -6, 4, -1] + assert sequence(x**n, (n,0,20)).find_linear_recurrence(21) == [x] + assert sequence((1,2,3)).find_linear_recurrence(10, 5) == [0, 0, 1] + assert sequence(((1 + sqrt(5))/2)**n + \ + (-(1 + sqrt(5))/2)**(-n)).find_linear_recurrence(10) == [1, 1] + assert sequence(x*((1 + sqrt(5))/2)**n + y*(-(1 + sqrt(5))/2)**(-n), \ + (n,0,oo)).find_linear_recurrence(10) == [1, 1] + assert sequence((1,2,3,4,6),(n, 0, 4)).find_linear_recurrence(5) == [] + assert sequence((2,3,4,5,6,79),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \ + == ([], None) + assert sequence((2,3,4,5,8,30),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \ + == ([Rational(19, 2), -20, Rational(27, 2)], (-31*x**2 + 32*x - 4)/(27*x**3 - 40*x**2 + 19*x -2)) + assert sequence(fibonacci(n)).find_linear_recurrence(30,gfvar=x) \ + == ([1, 1], -x/(x**2 + x - 1)) + assert sequence(tribonacci(n)).find_linear_recurrence(30,gfvar=x) \ + == ([1, 1, 1], -x/(x**3 + x**2 + x - 1)) + +def test_RecursiveSeq(): + y = Function('y') + n = Symbol('n') + fib = RecursiveSeq(y(n - 1) + y(n - 2), y(n), n, [0, 1]) + assert fib.coeff(3) == 2 diff --git a/phivenv/Lib/site-packages/sympy/series/tests/test_series.py b/phivenv/Lib/site-packages/sympy/series/tests/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f3c122b98c14a58c6d5c6636cbb53e1e66a75d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/series/tests/test_series.py @@ -0,0 +1,421 @@ +from sympy.core.evalf import N +from sympy.core.function import (Derivative, Function, PoleError, Subs) +from sympy.core.numbers import (E, Float, Rational, oo, pi, I) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral, integrate +from sympy.series.order import O +from sympy.series.series import series +from sympy.abc import x, y, n, k +from sympy.testing.pytest import raises +from sympy.core import EulerGamma + + +def test_sin(): + e1 = sin(x).series(x, 0) + e2 = series(sin(x), x, 0) + assert e1 == e2 + + +def test_cos(): + e1 = cos(x).series(x, 0) + e2 = series(cos(x), x, 0) + assert e1 == e2 + + +def test_exp(): + e1 = exp(x).series(x, 0) + e2 = series(exp(x), x, 0) + assert e1 == e2 + + +def test_exp2(): + e1 = exp(cos(x)).series(x, 0) + e2 = series(exp(cos(x)), x, 0) + assert e1 == e2 + + +def test_issue_5223(): + assert series(1, x) == 1 + assert next(S.Zero.lseries(x)) == 0 + assert cos(x).series() == cos(x).series(x) + raises(ValueError, lambda: cos(x + y).series()) + raises(ValueError, lambda: x.series(dir="")) + + assert (cos(x).series(x, 1) - + cos(x + 1).series(x).subs(x, x - 1)).removeO() == 0 + e = cos(x).series(x, 1, n=None) + assert [next(e) for i in range(2)] == [cos(1), -((x - 1)*sin(1))] + e = cos(x).series(x, 1, n=None, dir='-') + assert [next(e) for i in range(2)] == [cos(1), (1 - x)*sin(1)] + # the following test is exact so no need for x -> x - 1 replacement + assert abs(x).series(x, 1, dir='-') == x + assert exp(x).series(x, 1, dir='-', n=3).removeO() == \ + E - E*(-x + 1) + E*(-x + 1)**2/2 + + D = Derivative + assert D(x**2 + x**3*y**2, x, 2, y, 1).series(x).doit() == 12*x*y + assert next(D(cos(x), x).lseries()) == D(1, x) + assert D( + exp(x), x).series(n=3) == D(1, x) + D(x, x) + D(x**2/2, x) + D(x**3/6, x) + O(x**3) + + assert Integral(x, (x, 1, 3), (y, 1, x)).series(x) == -4 + 4*x + + assert (1 + x + O(x**2)).getn() == 2 + assert (1 + x).getn() is None + + raises(PoleError, lambda: ((1/sin(x))**oo).series()) + logx = Symbol('logx') + assert ((sin(x))**y).nseries(x, n=1, logx=logx) == \ + exp(y*logx) + O(x*exp(y*logx), x) + + assert sin(1/x).series(x, oo, n=5) == 1/x - 1/(6*x**3) + O(x**(-5), (x, oo)) + assert abs(x).series(x, oo, n=5, dir='+') == x + assert abs(x).series(x, -oo, n=5, dir='-') == -x + assert abs(-x).series(x, oo, n=5, dir='+') == x + assert abs(-x).series(x, -oo, n=5, dir='-') == -x + + assert exp(x*log(x)).series(n=3) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + O(x**3*log(x)**3) + # XXX is this right? If not, fix "ngot > n" handling in expr. + p = Symbol('p', positive=True) + assert exp(sqrt(p)**3*log(p)).series(n=3) == \ + 1 + p**S('3/2')*log(p) + O(p**3*log(p)**3) + + assert exp(sin(x)*log(x)).series(n=2) == 1 + x*log(x) + O(x**2*log(x)**2) + + +def test_issue_6350(): + expr = integrate(exp(k*(y**3 - 3*y)), (y, 0, oo), conds='none') + assert expr.series(k, 0, 3) == -(-1)**(S(2)/3)*sqrt(3)*gamma(S(1)/3)**2*gamma(S(2)/3)/(6*pi*k**(S(1)/3)) - \ + sqrt(3)*k*gamma(-S(2)/3)*gamma(-S(1)/3)/(6*pi) - \ + (-1)**(S(1)/3)*sqrt(3)*k**(S(1)/3)*gamma(-S(1)/3)*gamma(S(1)/3)*gamma(S(2)/3)/(6*pi) - \ + (-1)**(S(2)/3)*sqrt(3)*k**(S(5)/3)*gamma(S(1)/3)**2*gamma(S(2)/3)/(4*pi) - \ + (-1)**(S(1)/3)*sqrt(3)*k**(S(7)/3)*gamma(-S(1)/3)*gamma(S(1)/3)*gamma(S(2)/3)/(8*pi) + O(k**3) + + +def test_issue_11313(): + assert Integral(cos(x), x).series(x) == sin(x).series(x) + assert Derivative(sin(x), x).series(x, n=3).doit() == cos(x).series(x, n=3) + + assert Derivative(x**3, x).as_leading_term(x) == 3*x**2 + assert Derivative(x**3, y).as_leading_term(x) == 0 + assert Derivative(sin(x), x).as_leading_term(x) == 1 + assert Derivative(cos(x), x).as_leading_term(x) == -x + + # This result is equivalent to zero, zero is not return because + # `Expr.series` doesn't currently detect an `x` in its `free_symbol`s. + assert Derivative(1, x).as_leading_term(x) == Derivative(1, x) + + assert Derivative(exp(x), x).series(x).doit() == exp(x).series(x) + assert 1 + Integral(exp(x), x).series(x) == exp(x).series(x) + + assert Derivative(log(x), x).series(x).doit() == (1/x).series(x) + assert Integral(log(x), x).series(x) == Integral(log(x), x).doit().series(x).removeO() + + +def test_series_of_Subs(): + from sympy.abc import z + + subs1 = Subs(sin(x), x, y) + subs2 = Subs(sin(x) * cos(z), x, y) + subs3 = Subs(sin(x * z), (x, z), (y, x)) + + assert subs1.series(x) == subs1 + subs1_series = (Subs(x, x, y) + Subs(-x**3/6, x, y) + + Subs(x**5/120, x, y) + O(y**6)) + assert subs1.series() == subs1_series + assert subs1.series(y) == subs1_series + assert subs1.series(z) == subs1 + assert subs2.series(z) == (Subs(z**4*sin(x)/24, x, y) + + Subs(-z**2*sin(x)/2, x, y) + Subs(sin(x), x, y) + O(z**6)) + assert subs3.series(x).doit() == subs3.doit().series(x) + assert subs3.series(z).doit() == sin(x*y) + + raises(ValueError, lambda: Subs(x + 2*y, y, z).series()) + assert Subs(x + y, y, z).series(x).doit() == x + z + + +def test_issue_3978(): + f = Function('f') + assert f(x).series(x, 0, 3, dir='-') == \ + f(0) + x*Subs(Derivative(f(x), x), x, 0) + \ + x**2*Subs(Derivative(f(x), x, x), x, 0)/2 + O(x**3) + assert f(x).series(x, 0, 3) == \ + f(0) + x*Subs(Derivative(f(x), x), x, 0) + \ + x**2*Subs(Derivative(f(x), x, x), x, 0)/2 + O(x**3) + assert f(x**2).series(x, 0, 3) == \ + f(0) + x**2*Subs(Derivative(f(x), x), x, 0) + O(x**3) + assert f(x**2+1).series(x, 0, 3) == \ + f(1) + x**2*Subs(Derivative(f(x), x), x, 1) + O(x**3) + + class TestF(Function): + pass + + assert TestF(x).series(x, 0, 3) == TestF(0) + \ + x*Subs(Derivative(TestF(x), x), x, 0) + \ + x**2*Subs(Derivative(TestF(x), x, x), x, 0)/2 + O(x**3) + +from sympy.series.acceleration import richardson, shanks +from sympy.concrete.summations import Sum +from sympy.core.numbers import Integer + + +def test_acceleration(): + e = (1 + 1/n)**n + assert round(richardson(e, n, 10, 20).evalf(), 10) == round(E.evalf(), 10) + + A = Sum(Integer(-1)**(k + 1) / k, (k, 1, n)) + assert round(shanks(A, n, 25).evalf(), 4) == round(log(2).evalf(), 4) + assert round(shanks(A, n, 25, 5).evalf(), 10) == round(log(2).evalf(), 10) + + +def test_issue_5852(): + assert series(1/cos(x/log(x)), x, 0) == 1 + x**2/(2*log(x)**2) + \ + 5*x**4/(24*log(x)**4) + O(x**6) + + +def test_issue_4583(): + assert cos(1 + x + x**2).series(x, 0, 5) == cos(1) - x*sin(1) + \ + x**2*(-sin(1) - cos(1)/2) + x**3*(-cos(1) + sin(1)/6) + \ + x**4*(-11*cos(1)/24 + sin(1)/2) + O(x**5) + + +def test_issue_6318(): + eq = (1/x)**Rational(2, 3) + assert (eq + 1).as_leading_term(x) == eq + + +def test_x_is_base_detection(): + eq = (x**2)**Rational(2, 3) + assert eq.series() == x**Rational(4, 3) + + +def test_issue_7203(): + assert series(cos(x), x, pi, 3) == \ + -1 + (x - pi)**2/2 + O((x - pi)**3, (x, pi)) + + +def test_exp_product_positive_factors(): + a, b = symbols('a, b', positive=True) + x = a * b + assert series(exp(x), x, n=8) == 1 + a*b + a**2*b**2/2 + \ + a**3*b**3/6 + a**4*b**4/24 + a**5*b**5/120 + a**6*b**6/720 + \ + a**7*b**7/5040 + O(a**8*b**8, a, b) + + +def test_issue_8805(): + assert series(1, n=8) == 1 + + +def test_issue_9173(): + p0,p1,p2,p3,b0,b1,b2=symbols('p0 p1 p2 p3 b0 b1 b2') + Q=(p0+(p1+(p2+p3/y)/y)/y)/(1+((p3/(b0*y)+(b0*p2-b1*p3)/b0**2)/y+\ + (b0**2*p1-b0*b1*p2-p3*(b0*b2-b1**2))/b0**3)/y) + + series = Q.series(y,n=3) + + assert series == y*(b0*p2/p3+b0*(-p2/p3+b1/b0))+y**2*(b0*p1/p3+b0*p2*\ + (-p2/p3+b1/b0)/p3+b0*(-p1/p3+(p2/p3-b1/b0)**2+b1*p2/(b0*p3)+\ + b2/b0-b1**2/b0**2))+b0+O(y**3) + assert series.simplify() == b2*y**2 + b1*y + b0 + O(y**3) + + +def test_issue_9549(): + y = (x**2 + x + 1) / (x**3 + x**2) + assert series(y, x, oo) == x**(-5) - 1/x**4 + x**(-3) + 1/x + O(x**(-6), (x, oo)) + + +def test_issue_10761(): + assert series(1/(x**-2 + x**-3), x, 0) == x**3 - x**4 + x**5 + O(x**6) + + +def test_issue_12578(): + y = (1 - 1/(x/2 - 1/(2*x))**4)**(S(1)/8) + assert y.series(x, 0, n=17) == 1 - 2*x**4 - 8*x**6 - 34*x**8 - 152*x**10 - 714*x**12 - \ + 3472*x**14 - 17318*x**16 + O(x**17) + + +def test_issue_12791(): + beta = symbols('beta', positive=True) + theta, varphi = symbols('theta varphi', real=True) + + expr = (-beta**2*varphi*sin(theta) + beta**2*cos(theta) + \ + beta*varphi*sin(theta) - beta*cos(theta) - beta + 1)/(beta*cos(theta) - 1)**2 + + sol = (0.5/(0.5*cos(theta) - 1.0)**2 - 0.25*cos(theta)/(0.5*cos(theta) - 1.0)**2 + + (beta - 0.5)*(-0.25*varphi*sin(2*theta) - 1.5*cos(theta) + + 0.25*cos(2*theta) + 1.25)/((0.5*cos(theta) - 1.0)**2*(0.5*cos(theta) - 1.0)) + + 0.25*varphi*sin(theta)/(0.5*cos(theta) - 1.0)**2 + + O((beta - S.Half)**2, (beta, S.Half))) + + assert expr.series(beta, 0.5, 2).trigsimp() == sol + + +def test_issue_14384(): + x, a = symbols('x a') + assert series(x**a, x) == x**a + assert series(x**(-2*a), x) == x**(-2*a) + assert series(exp(a*log(x)), x) == exp(a*log(x)) + raises(PoleError, lambda: series(x**I, x)) + raises(PoleError, lambda: series(x**(I + 1), x)) + raises(PoleError, lambda: series(exp(I*log(x)), x)) + + +def test_issue_14885(): + assert series(x**Rational(-3, 2)*exp(x), x, 0) == (x**Rational(-3, 2) + 1/sqrt(x) + + sqrt(x)/2 + x**Rational(3, 2)/6 + x**Rational(5, 2)/24 + x**Rational(7, 2)/120 + + x**Rational(9, 2)/720 + x**Rational(11, 2)/5040 + O(x**6)) + + +def test_issue_15539(): + assert series(atan(x), x, -oo) == (-1/(5*x**5) + 1/(3*x**3) - 1/x - pi/2 + + O(x**(-6), (x, -oo))) + assert series(atan(x), x, oo) == (-1/(5*x**5) + 1/(3*x**3) - 1/x + pi/2 + + O(x**(-6), (x, oo))) + + +def test_issue_7259(): + assert series(LambertW(x), x) == x - x**2 + 3*x**3/2 - 8*x**4/3 + 125*x**5/24 + O(x**6) + assert series(LambertW(x**2), x, n=8) == x**2 - x**4 + 3*x**6/2 + O(x**8) + assert series(LambertW(sin(x)), x, n=4) == x - x**2 + 4*x**3/3 + O(x**4) + +def test_issue_11884(): + assert cos(x).series(x, 1, n=1) == cos(1) + O(x - 1, (x, 1)) + + +def test_issue_18008(): + y = x*(1 + x*(1 - x))/((1 + x*(1 - x)) - (1 - x)*(1 - x)) + assert y.series(x, oo, n=4) == -9/(32*x**3) - 3/(16*x**2) - 1/(8*x) + S(1)/4 + x/2 + \ + O(x**(-4), (x, oo)) + + +def test_issue_18842(): + f = log(x/(1 - x)) + assert f.series(x, 0.491, n=1).removeO().nsimplify() == \ + -S(180019443780011)/5000000000000000 + + +def test_issue_19534(): + dt = symbols('dt', real=True) + expr = 16*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0)/45 + \ + 49*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.051640768506639183825*dt + \ + dt*(1/2 - sqrt(21)/14) + 1.0)/180 + 49*dt*(-0.23637909581542530626*dt*(2.0*dt + 1.0) - \ + 0.74817562366625959291*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.88085458023927036857*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + \ + 2.1165151389911680013*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) - \ + 1.1854881643947648988*dt + dt*(sqrt(21)/14 + 1/2) + 1.0)/180 + \ + dt*(0.66666666666666666667*dt*(2.0*dt + 1.0) + \ + 6.0173399699313066769*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 4.1117044797036320069*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) - \ + 7.0189140975801991157*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) + \ + 0.94010945196161777522*dt*(-0.23637909581542530626*dt*(2.0*dt + 1.0) - \ + 0.74817562366625959291*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.88085458023927036857*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + \ + 2.1165151389911680013*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) - \ + 0.35816132904077632692*dt + 1.0) + 5.5065024887242400038*dt + 1.0)/20 + dt/20 + 1 + + assert N(expr.series(dt, 0, 8), 20) == ( + - Float('0.00092592592592592596126289', precision=70) * dt**7 + + Float('0.0027777777777777783174695', precision=70) * dt**6 + + Float('0.016666666666666656027029', precision=70) * dt**5 + + Float('0.083333333333333300951828', precision=70) * dt**4 + + Float('0.33333333333333337034077', precision=70) * dt**3 + + Float('1.0', precision=70) * dt**2 + + Float('1.0', precision=70) * dt + + Float('1.0', precision=70) + ) + + +def test_issue_11407(): + a, b, c, x = symbols('a b c x') + assert series(sqrt(a + b + c*x), x, 0, 1) == sqrt(a + b) + O(x) + assert series(sqrt(a + b + c + c*x), x, 0, 1) == sqrt(a + b + c) + O(x) + + +def test_issue_14037(): + assert (sin(x**50)/x**51).series(x, n=0) == 1/x + O(1, x) + + +def test_issue_20551(): + expr = (exp(x)/x).series(x, n=None) + terms = [ next(expr) for i in range(3) ] + assert terms == [1/x, 1, x/2] + + +def test_issue_20697(): + p_0, p_1, p_2, p_3, b_0, b_1, b_2 = symbols('p_0 p_1 p_2 p_3 b_0 b_1 b_2') + Q = (p_0 + (p_1 + (p_2 + p_3/y)/y)/y)/(1 + ((p_3/(b_0*y) + (b_0*p_2\ + - b_1*p_3)/b_0**2)/y + (b_0**2*p_1 - b_0*b_1*p_2 - p_3*(b_0*b_2\ + - b_1**2))/b_0**3)/y) + assert Q.series(y, n=3).ratsimp() == b_2*y**2 + b_1*y + b_0 + O(y**3) + + +def test_issue_21245(): + fi = (1 + sqrt(5))/2 + assert (1/(1 - x - x**2)).series(x, 1/fi, 1).factor() == \ + (-37*sqrt(5) - 83 + 13*sqrt(5)*x + 29*x + O((x - 2/(1 + sqrt(5)))**2, (x\ + , 2/(1 + sqrt(5)))))/((2*sqrt(5) + 5)**2*(x + sqrt(5)*x - 2)) + + + +def test_issue_21938(): + expr = sin(1/x + exp(-x)) - sin(1/x) + assert expr.series(x, oo) == (1/(24*x**4) - 1/(2*x**2) + 1 + O(x**(-6), (x, oo)))*exp(-x) + + +def test_issue_23432(): + expr = 1/sqrt(1 - x**2) + result = expr.series(x, 0.5) + assert result.is_Add and len(result.args) == 7 + + +def test_issue_23727(): + res = series(sqrt(1 - x**2), x, 0.1) + assert res.is_Add == True + + +def test_issue_24266(): + #type1: exp(f(x)) + assert (exp(-I*pi*(2*x+1))).series(x, 0, 3) == -1 + 2*I*pi*x + 2*pi**2*x**2 + O(x**3) + assert (exp(-I*pi*(2*x+1))*gamma(1+x)).series(x, 0, 3) == -1 + x*(EulerGamma + 2*I*pi) + \ + x**2*(-EulerGamma**2/2 + 23*pi**2/12 - 2*EulerGamma*I*pi) + O(x**3) + + #type2: c**f(x) + assert ((2*I)**(-I*pi*(2*x+1))).series(x, 0, 2) == exp(pi**2/2 - I*pi*log(2)) + \ + x*(pi**2*exp(pi**2/2 - I*pi*log(2)) - 2*I*pi*exp(pi**2/2 - I*pi*log(2))*log(2)) + O(x**2) + assert ((2)**(-I*pi*(2*x+1))).series(x, 0, 2) == exp(-I*pi*log(2)) - 2*I*pi*x*exp(-I*pi*log(2))*log(2) + O(x**2) + + #type3: f(y)**g(x) + assert ((y)**(I*pi*(2*x+1))).series(x, 0, 2) == exp(I*pi*log(y)) + 2*I*pi*x*exp(I*pi*log(y))*log(y) + O(x**2) + assert ((I*y)**(I*pi*(2*x+1))).series(x, 0, 2) == exp(I*pi*log(I*y)) + 2*I*pi*x*exp(I*pi*log(I*y))*log(I*y) + O(x**2) + + +def test_issue_26856(): + raises(ValueError, lambda: (2**x).series(x, oo, -1)) diff --git a/phivenv/Lib/site-packages/sympy/sets/__init__.py b/phivenv/Lib/site-packages/sympy/sets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b909c0b5ef03b1e1e76dfbf4288f61860575da7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/__init__.py @@ -0,0 +1,36 @@ +from .sets import (Set, Interval, Union, FiniteSet, ProductSet, + Intersection, imageset, Complement, SymmetricDifference, + DisjointUnion) + +from .fancysets import ImageSet, Range, ComplexRegion +from .contains import Contains +from .conditionset import ConditionSet +from .ordinals import Ordinal, OmegaPower, ord0 +from .powerset import PowerSet +from ..core.singleton import S +from .handlers.comparison import _eval_is_eq # noqa:F401 +Complexes = S.Complexes +EmptySet = S.EmptySet +Integers = S.Integers +Naturals = S.Naturals +Naturals0 = S.Naturals0 +Rationals = S.Rationals +Reals = S.Reals +UniversalSet = S.UniversalSet + +__all__ = [ + 'Set', 'Interval', 'Union', 'EmptySet', 'FiniteSet', 'ProductSet', + 'Intersection', 'imageset', 'Complement', 'SymmetricDifference', 'DisjointUnion', + + 'ImageSet', 'Range', 'ComplexRegion', 'Reals', + + 'Contains', + + 'ConditionSet', + + 'Ordinal', 'OmegaPower', 'ord0', + + 'PowerSet', + + 'Reals', 'Naturals', 'Naturals0', 'UniversalSet', 'Integers', 'Rationals', +] diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccd5ac34c9722db3c5ae7d3981f69c80e7f35864 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/conditionset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/conditionset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7c50e12522ec0d224b8ab40bd016ac577c9e3a2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/conditionset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/contains.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/contains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd1e78303f0f052aced3760d5b276bc3de56ced1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/contains.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/fancysets.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/fancysets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..921bca17641a2057677010412c4058c3c04e4ee0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/fancysets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/ordinals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/ordinals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..631bf663208908795e5371ea2cc662f690aa1afa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/ordinals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/powerset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/powerset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37e95135a7f45ef3eeabe055899c401b3421f06b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/powerset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/setexpr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/setexpr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee46d206efa5e3e0faaeb00b05756003432715d1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/setexpr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/__pycache__/sets.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/__pycache__/sets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..879a135a0abf7bb50b81c2090f3115084c1d1aeb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/__pycache__/sets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/conditionset.py b/phivenv/Lib/site-packages/sympy/sets/conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..e847e60ce97d7e9922ce907042ace941838b0ab1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/conditionset.py @@ -0,0 +1,246 @@ +from sympy.core.singleton import S +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Lambda, BadSignatureError +from sympy.core.logic import fuzzy_bool +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify +from sympy.logic.boolalg import And, as_Boolean +from sympy.utilities.iterables import sift, flatten, has_dups +from sympy.utilities.exceptions import sympy_deprecation_warning +from .contains import Contains +from .sets import Set, Union, FiniteSet, SetKind + + +adummy = Dummy('conditionset') + + +class ConditionSet(Set): + r""" + Set of elements which satisfies a given condition. + + .. math:: \{x \mid \textrm{condition}(x) = \texttt{True}, x \in S\} + + Examples + ======== + + >>> from sympy import Symbol, S, ConditionSet, pi, Eq, sin, Interval + >>> from sympy.abc import x, y, z + + >>> sin_sols = ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)) + >>> 2*pi in sin_sols + True + >>> pi/2 in sin_sols + False + >>> 3*pi in sin_sols + False + >>> 5 in ConditionSet(x, x**2 > 4, S.Reals) + True + + If the value is not in the base set, the result is false: + + >>> 5 in ConditionSet(x, x**2 > 4, Interval(2, 4)) + False + + Notes + ===== + + Symbols with assumptions should be avoided or else the + condition may evaluate without consideration of the set: + + >>> n = Symbol('n', negative=True) + >>> cond = (n > 0); cond + False + >>> ConditionSet(n, cond, S.Integers) + EmptySet + + Only free symbols can be changed by using `subs`: + + >>> c = ConditionSet(x, x < 1, {x, z}) + >>> c.subs(x, y) + ConditionSet(x, x < 1, {y, z}) + + To check if ``pi`` is in ``c`` use: + + >>> pi in c + False + + If no base set is specified, the universal set is implied: + + >>> ConditionSet(x, x < 1).base_set + UniversalSet + + Only symbols or symbol-like expressions can be used: + + >>> ConditionSet(x + 1, x + 1 < 1, S.Integers) + Traceback (most recent call last): + ... + ValueError: non-symbol dummy not recognized in condition + + When the base set is a ConditionSet, the symbols will be + unified if possible with preference for the outermost symbols: + + >>> ConditionSet(x, x < y, ConditionSet(z, z + y < 2, S.Integers)) + ConditionSet(x, (x < y) & (x + y < 2), Integers) + + """ + def __new__(cls, sym, condition, base_set=S.UniversalSet): + sym = _sympify(sym) + flat = flatten([sym]) + if has_dups(flat): + raise BadSignatureError("Duplicate symbols detected") + base_set = _sympify(base_set) + if not isinstance(base_set, Set): + raise TypeError( + 'base set should be a Set object, not %s' % base_set) + condition = _sympify(condition) + + if isinstance(condition, FiniteSet): + condition_orig = condition + temp = (Eq(lhs, 0) for lhs in condition) + condition = And(*temp) + sympy_deprecation_warning( + f""" +Using a set for the condition in ConditionSet is deprecated. Use a boolean +instead. + +In this case, replace + + {condition_orig} + +with + + {condition} +""", + deprecated_since_version='1.5', + active_deprecations_target="deprecated-conditionset-set", + ) + + condition = as_Boolean(condition) + + if condition is S.true: + return base_set + + if condition is S.false: + return S.EmptySet + + if base_set is S.EmptySet: + return S.EmptySet + + # no simple answers, so now check syms + for i in flat: + if not getattr(i, '_diff_wrt', False): + raise ValueError('`%s` is not symbol-like' % i) + + if base_set.contains(sym) is S.false: + raise TypeError('sym `%s` is not in base_set `%s`' % (sym, base_set)) + + know = None + if isinstance(base_set, FiniteSet): + sifted = sift( + base_set, lambda _: fuzzy_bool(condition.subs(sym, _))) + if sifted[None]: + know = FiniteSet(*sifted[True]) + base_set = FiniteSet(*sifted[None]) + else: + return FiniteSet(*sifted[True]) + + if isinstance(base_set, cls): + s, c, b = base_set.args + def sig(s): + return cls(s, Eq(adummy, 0)).as_dummy().sym + sa, sb = map(sig, (sym, s)) + if sa != sb: + raise BadSignatureError('sym does not match sym of base set') + reps = dict(zip(flatten([sym]), flatten([s]))) + if s == sym: + condition = And(condition, c) + base_set = b + elif not c.free_symbols & sym.free_symbols: + reps = {v: k for k, v in reps.items()} + condition = And(condition, c.xreplace(reps)) + base_set = b + elif not condition.free_symbols & s.free_symbols: + sym = sym.xreplace(reps) + condition = And(condition.xreplace(reps), c) + base_set = b + + # flatten ConditionSet(Contains(ConditionSet())) expressions + if isinstance(condition, Contains) and (sym == condition.args[0]): + if isinstance(condition.args[1], Set): + return condition.args[1].intersect(base_set) + + rv = Basic.__new__(cls, sym, condition, base_set) + return rv if know is None else Union(know, rv) + + sym = property(lambda self: self.args[0]) + condition = property(lambda self: self.args[1]) + base_set = property(lambda self: self.args[2]) + + @property + def free_symbols(self): + cond_syms = self.condition.free_symbols - self.sym.free_symbols + return cond_syms | self.base_set.free_symbols + + @property + def bound_symbols(self): + return flatten([self.sym]) + + def _contains(self, other): + def ok_sig(a, b): + tuples = [isinstance(i, Tuple) for i in (a, b)] + c = tuples.count(True) + if c == 1: + return False + if c == 0: + return True + return len(a) == len(b) and all( + ok_sig(i, j) for i, j in zip(a, b)) + if not ok_sig(self.sym, other): + return S.false + + # try doing base_cond first and return + # False immediately if it is False + base_cond = Contains(other, self.base_set) + if base_cond is S.false: + return S.false + + # Substitute other into condition. This could raise e.g. for + # ConditionSet(x, 1/x >= 0, Reals).contains(0) + lamda = Lambda((self.sym,), self.condition) + try: + lambda_cond = lamda(other) + except TypeError: + return None + else: + return And(base_cond, lambda_cond) + + def as_relational(self, other): + f = Lambda(self.sym, self.condition) + if isinstance(self.sym, Tuple): + f = f(*other) + else: + f = f(other) + return And(f, self.base_set.contains(other)) + + def _eval_subs(self, old, new): + sym, cond, base = self.args + dsym = sym.subs(old, adummy) + insym = dsym.has(adummy) + # prioritize changing a symbol in the base + newbase = base.subs(old, new) + if newbase != base: + if not insym: + cond = cond.subs(old, new) + return self.func(sym, cond, newbase) + if insym: + pass # no change of bound symbols via subs + elif getattr(new, '_diff_wrt', False): + cond = cond.subs(old, new) + else: + pass # let error about the symbol raise from __new__ + return self.func(sym, cond, base) + + def _kind(self): + return SetKind(self.sym.kind) diff --git a/phivenv/Lib/site-packages/sympy/sets/contains.py b/phivenv/Lib/site-packages/sympy/sets/contains.py new file mode 100644 index 0000000000000000000000000000000000000000..403d4875279d718724a898efa5cba41bc7bed6ea --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/contains.py @@ -0,0 +1,63 @@ +from sympy.core import S +from sympy.core.sympify import sympify +from sympy.core.relational import Eq, Ne +from sympy.core.parameters import global_parameters +from sympy.logic.boolalg import Boolean +from sympy.utilities.misc import func_name +from .sets import Set + + +class Contains(Boolean): + """ + Asserts that x is an element of the set S. + + Examples + ======== + + >>> from sympy import Symbol, Integer, S, Contains + >>> Contains(Integer(2), S.Integers) + True + >>> Contains(Integer(-2), S.Naturals) + False + >>> i = Symbol('i', integer=True) + >>> Contains(i, S.Naturals) + Contains(i, Naturals) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Element_%28mathematics%29 + """ + def __new__(cls, x, s, evaluate=None): + x = sympify(x) + s = sympify(s) + + if evaluate is None: + evaluate = global_parameters.evaluate + + if not isinstance(s, Set): + raise TypeError('expecting Set, not %s' % func_name(s)) + + if evaluate: + # _contains can return symbolic booleans that would be returned by + # s.contains(x) but here for Contains(x, s) we only evaluate to + # true, false or return the unevaluated Contains. + result = s._contains(x) + + if isinstance(result, Boolean): + if result in (S.true, S.false): + return result + elif result is not None: + raise TypeError("_contains() should return Boolean or None") + + return super().__new__(cls, x, s) + + @property + def binary_symbols(self): + return set().union(*[i.binary_symbols + for i in self.args[1].args + if i.is_Boolean or i.is_Symbol or + isinstance(i, (Eq, Ne))]) + + def as_set(self): + return self.args[1] diff --git a/phivenv/Lib/site-packages/sympy/sets/fancysets.py b/phivenv/Lib/site-packages/sympy/sets/fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e24a2a864222d16ba1a697558b5211127fb2ad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/fancysets.py @@ -0,0 +1,1523 @@ +from functools import reduce +from itertools import product + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.logic import fuzzy_not, fuzzy_or, fuzzy_and +from sympy.core.mod import Mod +from sympy.core.intfunc import igcd +from sympy.core.numbers import oo, Rational, Integer +from sympy.core.relational import Eq, is_eq +from sympy.core.kind import NumberKind +from sympy.core.singleton import Singleton, S +from sympy.core.symbol import Dummy, symbols, Symbol +from sympy.core.sympify import _sympify, sympify, _sympy_converter +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.logic.boolalg import And, Or +from .sets import tfn, Set, Interval, Union, FiniteSet, ProductSet, SetKind +from sympy.utilities.misc import filldedent + + +class Rationals(Set, metaclass=Singleton): + """ + Represents the rational numbers. This set is also available as + the singleton ``S.Rationals``. + + Examples + ======== + + >>> from sympy import S + >>> S.Half in S.Rationals + True + >>> iterable = iter(S.Rationals) + >>> [next(iterable) for i in range(12)] + [0, 1, -1, 1/2, 2, -1/2, -2, 1/3, 3, -1/3, -3, 2/3] + """ + + is_iterable = True + _inf = S.NegativeInfinity + _sup = S.Infinity + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + return tfn[other.is_rational] + + def __iter__(self): + yield S.Zero + yield S.One + yield S.NegativeOne + d = 2 + while True: + for n in range(d): + if igcd(n, d) == 1: + yield Rational(n, d) + yield Rational(d, n) + yield Rational(-n, d) + yield Rational(-d, n) + d += 1 + + @property + def _boundary(self): + return S.Reals + + def _kind(self): + return SetKind(NumberKind) + + +class Naturals(Set, metaclass=Singleton): + """ + Represents the natural numbers (or counting numbers) which are all + positive integers starting from 1. This set is also available as + the singleton ``S.Naturals``. + + Examples + ======== + + >>> from sympy import S, Interval, pprint + >>> 5 in S.Naturals + True + >>> iterable = iter(S.Naturals) + >>> next(iterable) + 1 + >>> next(iterable) + 2 + >>> next(iterable) + 3 + >>> pprint(S.Naturals.intersect(Interval(0, 10))) + {1, 2, ..., 10} + + See Also + ======== + + Naturals0 : non-negative integers (i.e. includes 0, too) + Integers : also includes negative integers + """ + + is_iterable = True + _inf: Integer = S.One + _sup = S.Infinity + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + elif other.is_positive and other.is_integer: + return S.true + elif other.is_integer is False or other.is_positive is False: + return S.false + + def _eval_is_subset(self, other): + return Range(1, oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(1, oo).is_superset(other) + + def __iter__(self): + i = self._inf + while True: + yield i + i = i + 1 + + @property + def _boundary(self): + return self + + def as_relational(self, x): + return And(Eq(floor(x), x), x >= self.inf, x < oo) + + def _kind(self): + return SetKind(NumberKind) + + +class Naturals0(Naturals): + """Represents the whole numbers which are all the non-negative integers, + inclusive of zero. + + See Also + ======== + + Naturals : positive integers; does not include 0 + Integers : also includes the negative integers + """ + _inf = S.Zero + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + elif other.is_integer and other.is_nonnegative: + return S.true + elif other.is_integer is False or other.is_nonnegative is False: + return S.false + + def _eval_is_subset(self, other): + return Range(oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(oo).is_superset(other) + + +class Integers(Set, metaclass=Singleton): + """ + Represents all integers: positive, negative and zero. This set is also + available as the singleton ``S.Integers``. + + Examples + ======== + + >>> from sympy import S, Interval, pprint + >>> 5 in S.Naturals + True + >>> iterable = iter(S.Integers) + >>> next(iterable) + 0 + >>> next(iterable) + 1 + >>> next(iterable) + -1 + >>> next(iterable) + 2 + + >>> pprint(S.Integers.intersect(Interval(-4, 4))) + {-4, -3, ..., 4} + + See Also + ======== + + Naturals0 : non-negative integers + Integers : positive and negative integers and zero + """ + + is_iterable = True + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + return tfn[other.is_integer] + + def __iter__(self): + yield S.Zero + i = S.One + while True: + yield i + yield -i + i = i + 1 + + @property + def _inf(self): + return S.NegativeInfinity + + @property + def _sup(self): + return S.Infinity + + @property + def _boundary(self): + return self + + def _kind(self): + return SetKind(NumberKind) + + def as_relational(self, x): + return And(Eq(floor(x), x), -oo < x, x < oo) + + def _eval_is_subset(self, other): + return Range(-oo, oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(-oo, oo).is_superset(other) + + +class Reals(Interval, metaclass=Singleton): + """ + Represents all real numbers + from negative infinity to positive infinity, + including all integer, rational and irrational numbers. + This set is also available as the singleton ``S.Reals``. + + + Examples + ======== + + >>> from sympy import S, Rational, pi, I + >>> 5 in S.Reals + True + >>> Rational(-1, 2) in S.Reals + True + >>> pi in S.Reals + True + >>> 3*I in S.Reals + False + >>> S.Reals.contains(pi) + True + + + See Also + ======== + + ComplexRegion + """ + @property + def start(self): + return S.NegativeInfinity + + @property + def end(self): + return S.Infinity + + @property + def left_open(self): + return True + + @property + def right_open(self): + return True + + def __eq__(self, other): + return other == Interval(S.NegativeInfinity, S.Infinity) + + def __hash__(self): + return hash(Interval(S.NegativeInfinity, S.Infinity)) + + +class ImageSet(Set): + """ + Image of a set under a mathematical function. The transformation + must be given as a Lambda function which has as many arguments + as the elements of the set upon which it operates, e.g. 1 argument + when acting on the set of integers or 2 arguments when acting on + a complex region. + + This function is not normally called directly, but is called + from ``imageset``. + + + Examples + ======== + + >>> from sympy import Symbol, S, pi, Dummy, Lambda + >>> from sympy import FiniteSet, ImageSet, Interval + + >>> x = Symbol('x') + >>> N = S.Naturals + >>> squares = ImageSet(Lambda(x, x**2), N) # {x**2 for x in N} + >>> 4 in squares + True + >>> 5 in squares + False + + >>> FiniteSet(0, 1, 2, 3, 4, 5, 6, 7, 9, 10).intersect(squares) + {1, 4, 9} + + >>> square_iterable = iter(squares) + >>> for i in range(4): + ... next(square_iterable) + 1 + 4 + 9 + 16 + + If you want to get value for `x` = 2, 1/2 etc. (Please check whether the + `x` value is in ``base_set`` or not before passing it as args) + + >>> squares.lamda(2) + 4 + >>> squares.lamda(S(1)/2) + 1/4 + + >>> n = Dummy('n') + >>> solutions = ImageSet(Lambda(n, n*pi), S.Integers) # solutions of sin(x) = 0 + >>> dom = Interval(-1, 1) + >>> dom.intersect(solutions) + {0} + + See Also + ======== + + sympy.sets.sets.imageset + """ + def __new__(cls, flambda, *sets): + if not isinstance(flambda, Lambda): + raise ValueError('First argument must be a Lambda') + + signature = flambda.signature + + if len(signature) != len(sets): + raise ValueError('Incompatible signature') + + sets = [_sympify(s) for s in sets] + + if not all(isinstance(s, Set) for s in sets): + raise TypeError("Set arguments to ImageSet should of type Set") + + if not all(cls._check_sig(sg, st) for sg, st in zip(signature, sets)): + raise ValueError("Signature %s does not match sets %s" % (signature, sets)) + + if flambda is S.IdentityFunction and len(sets) == 1: + return sets[0] + + if not set(flambda.variables) & flambda.expr.free_symbols: + is_empty = fuzzy_or(s.is_empty for s in sets) + if is_empty == True: + return S.EmptySet + elif is_empty == False: + return FiniteSet(flambda.expr) + + return Basic.__new__(cls, flambda, *sets) + + lamda = property(lambda self: self.args[0]) + base_sets = property(lambda self: self.args[1:]) + + @property + def base_set(self): + # XXX: Maybe deprecate this? It is poorly defined in handling + # the multivariate case... + sets = self.base_sets + if len(sets) == 1: + return sets[0] + else: + return ProductSet(*sets).flatten() + + @property + def base_pset(self): + return ProductSet(*self.base_sets) + + @classmethod + def _check_sig(cls, sig_i, set_i): + if sig_i.is_symbol: + return True + elif isinstance(set_i, ProductSet): + sets = set_i.sets + if len(sig_i) != len(sets): + return False + # Recurse through the signature for nested tuples: + return all(cls._check_sig(ts, ps) for ts, ps in zip(sig_i, sets)) + else: + # XXX: Need a better way of checking whether a set is a set of + # Tuples or not. For example a FiniteSet can contain Tuples + # but so can an ImageSet or a ConditionSet. Others like + # Integers, Reals etc can not contain Tuples. We could just + # list the possibilities here... Current code for e.g. + # _contains probably only works for ProductSet. + return True # Give the benefit of the doubt + + def __iter__(self): + already_seen = set() + for i in self.base_pset: + val = self.lamda(*i) + if val in already_seen: + continue + else: + already_seen.add(val) + yield val + + def _is_multivariate(self): + return len(self.lamda.variables) > 1 + + def _contains(self, other): + from sympy.solvers.solveset import _solveset_multi + + def get_symsetmap(signature, base_sets): + '''Attempt to get a map of symbols to base_sets''' + queue = list(zip(signature, base_sets)) + symsetmap = {} + for sig, base_set in queue: + if sig.is_symbol: + symsetmap[sig] = base_set + elif base_set.is_ProductSet: + sets = base_set.sets + if len(sig) != len(sets): + raise ValueError("Incompatible signature") + # Recurse + queue.extend(zip(sig, sets)) + else: + # If we get here then we have something like sig = (x, y) and + # base_set = {(1, 2), (3, 4)}. For now we give up. + return None + + return symsetmap + + def get_equations(expr, candidate): + '''Find the equations relating symbols in expr and candidate.''' + queue = [(expr, candidate)] + for e, c in queue: + if not isinstance(e, Tuple): + yield Eq(e, c) + elif not isinstance(c, Tuple) or len(e) != len(c): + yield False + return + else: + queue.extend(zip(e, c)) + + # Get the basic objects together: + other = _sympify(other) + expr = self.lamda.expr + sig = self.lamda.signature + variables = self.lamda.variables + base_sets = self.base_sets + + # Use dummy symbols for ImageSet parameters so they don't match + # anything in other + rep = {v: Dummy(v.name) for v in variables} + variables = [v.subs(rep) for v in variables] + sig = sig.subs(rep) + expr = expr.subs(rep) + + # Map the parts of other to those in the Lambda expr + equations = [] + for eq in get_equations(expr, other): + # Unsatisfiable equation? + if eq is False: + return S.false + equations.append(eq) + + # Map the symbols in the signature to the corresponding domains + symsetmap = get_symsetmap(sig, base_sets) + if symsetmap is None: + # Can't factor the base sets to a ProductSet + return None + + # Which of the variables in the Lambda signature need to be solved for? + symss = (eq.free_symbols for eq in equations) + variables = set(variables) & reduce(set.union, symss, set()) + + # Use internal multivariate solveset + variables = tuple(variables) + base_sets = [symsetmap[v] for v in variables] + solnset = _solveset_multi(equations, variables, base_sets) + if solnset is None: + return None + return tfn[fuzzy_not(solnset.is_empty)] + + @property + def is_iterable(self): + return all(s.is_iterable for s in self.base_sets) + + def doit(self, **hints): + from sympy.sets.setexpr import SetExpr + f = self.lamda + sig = f.signature + if len(sig) == 1 and sig[0].is_symbol and isinstance(f.expr, Expr): + base_set = self.base_sets[0] + return SetExpr(base_set)._eval_func(f).set + if all(s.is_FiniteSet for s in self.base_sets): + return FiniteSet(*(f(*a) for a in product(*self.base_sets))) + return self + + def _kind(self): + return SetKind(self.lamda.expr.kind) + + +class Range(Set): + """ + Represents a range of integers. Can be called as ``Range(stop)``, + ``Range(start, stop)``, or ``Range(start, stop, step)``; when ``step`` is + not given it defaults to 1. + + ``Range(stop)`` is the same as ``Range(0, stop, 1)`` and the stop value + (just as for Python ranges) is not included in the Range values. + + >>> from sympy import Range + >>> list(Range(3)) + [0, 1, 2] + + The step can also be negative: + + >>> list(Range(10, 0, -2)) + [10, 8, 6, 4, 2] + + The stop value is made canonical so equivalent ranges always + have the same args: + + >>> Range(0, 10, 3) + Range(0, 12, 3) + + Infinite ranges are allowed. ``oo`` and ``-oo`` are never included in the + set (``Range`` is always a subset of ``Integers``). If the starting point + is infinite, then the final value is ``stop - step``. To iterate such a + range, it needs to be reversed: + + >>> from sympy import oo + >>> r = Range(-oo, 1) + >>> r[-1] + 0 + >>> next(iter(r)) + Traceback (most recent call last): + ... + TypeError: Cannot iterate over Range with infinite start + >>> next(iter(r.reversed)) + 0 + + Although ``Range`` is a :class:`Set` (and supports the normal set + operations) it maintains the order of the elements and can + be used in contexts where ``range`` would be used. + + >>> from sympy import Interval + >>> Range(0, 10, 2).intersect(Interval(3, 7)) + Range(4, 8, 2) + >>> list(_) + [4, 6] + + Although slicing of a Range will always return a Range -- possibly + empty -- an empty set will be returned from any intersection that + is empty: + + >>> Range(3)[:0] + Range(0, 0, 1) + >>> Range(3).intersect(Interval(4, oo)) + EmptySet + >>> Range(3).intersect(Range(4, oo)) + EmptySet + + Range will accept symbolic arguments but has very limited support + for doing anything other than displaying the Range: + + >>> from sympy import Symbol, pprint + >>> from sympy.abc import i, j, k + >>> Range(i, j, k).start + i + >>> Range(i, j, k).inf + Traceback (most recent call last): + ... + ValueError: invalid method for symbolic range + + Better success will be had when using integer symbols: + + >>> n = Symbol('n', integer=True) + >>> r = Range(n, n + 20, 3) + >>> r.inf + n + >>> pprint(r) + {n, n + 3, ..., n + 18} + """ + + def __new__(cls, *args): + if len(args) == 1: + if isinstance(args[0], range): + raise TypeError( + 'use sympify(%s) to convert range to Range' % args[0]) + + # expand range + slc = slice(*args) + + if slc.step == 0: + raise ValueError("step cannot be 0") + + start, stop, step = slc.start or 0, slc.stop, slc.step or 1 + try: + ok = [] + for w in (start, stop, step): + w = sympify(w) + if w in [S.NegativeInfinity, S.Infinity] or ( + w.has(Symbol) and w.is_integer != False): + ok.append(w) + elif not w.is_Integer: + if w.is_infinite: + raise ValueError('infinite symbols not allowed') + raise ValueError + else: + ok.append(w) + except ValueError: + raise ValueError(filldedent(''' + Finite arguments to Range must be integers; `imageset` can define + other cases, e.g. use `imageset(i, i/10, Range(3))` to give + [0, 1/10, 1/5].''')) + start, stop, step = ok + + null = False + if any(i.has(Symbol) for i in (start, stop, step)): + dif = stop - start + n = dif/step + if n.is_Rational: + if dif == 0: + null = True + else: # (x, x + 5, 2) or (x, 3*x, x) + n = floor(n) + end = start + n*step + if dif.is_Rational: # (x, x + 5, 2) + if (end - stop).is_negative: + end += step + else: # (x, 3*x, x) + if (end/stop - 1).is_negative: + end += step + elif n.is_extended_negative: + null = True + else: + end = stop # other methods like sup and reversed must fail + elif start.is_infinite: + span = step*(stop - start) + if span is S.NaN or span <= 0: + null = True + elif step.is_Integer and stop.is_infinite and abs(step) != 1: + raise ValueError(filldedent(''' + Step size must be %s in this case.''' % (1 if step > 0 else -1))) + else: + end = stop + else: + oostep = step.is_infinite + if oostep: + step = S.One if step > 0 else S.NegativeOne + n = ceiling((stop - start)/step) + if n <= 0: + null = True + elif oostep: + step = S.One # make it canonical + end = start + step + else: + end = start + n*step + if null: + start = end = S.Zero + step = S.One + return Basic.__new__(cls, start, end, step) + + start = property(lambda self: self.args[0]) + stop = property(lambda self: self.args[1]) + step = property(lambda self: self.args[2]) + + @property + def reversed(self): + """Return an equivalent Range in the opposite order. + + Examples + ======== + + >>> from sympy import Range + >>> Range(10).reversed + Range(9, -1, -1) + """ + if self.has(Symbol): + n = (self.stop - self.start)/self.step + if not n.is_extended_positive or not all( + i.is_integer or i.is_infinite for i in self.args): + raise ValueError('invalid method for symbolic range') + if self.start == self.stop: + return self + return self.func( + self.stop - self.step, self.start - self.step, -self.step) + + def _kind(self): + return SetKind(NumberKind) + + def _contains(self, other): + if self.start == self.stop: + return S.false + if other.is_infinite: + return S.false + if not other.is_integer: + return tfn[other.is_integer] + if self.has(Symbol): + n = (self.stop - self.start)/self.step + if not n.is_extended_positive or not all( + i.is_integer or i.is_infinite for i in self.args): + return + else: + n = self.size + if self.start.is_finite: + ref = self.start + elif self.stop.is_finite: + ref = self.stop + else: # both infinite; step is +/- 1 (enforced by __new__) + return S.true + if n == 1: + return Eq(other, self[0]) + res = (ref - other) % self.step + if res == S.Zero: + if self.has(Symbol): + d = Dummy('i') + return self.as_relational(d).subs(d, other) + return And(other >= self.inf, other <= self.sup) + elif res.is_Integer: # off sequence + return S.false + else: # symbolic/unsimplified residue modulo step + return None + + def __iter__(self): + n = self.size # validate + if not (n.has(S.Infinity) or n.has(S.NegativeInfinity) or n.is_Integer): + raise TypeError("Cannot iterate over symbolic Range") + if self.start in [S.NegativeInfinity, S.Infinity]: + raise TypeError("Cannot iterate over Range with infinite start") + elif self.start != self.stop: + i = self.start + if n.is_infinite: + while True: + yield i + i += self.step + else: + for _ in range(n): + yield i + i += self.step + + @property + def is_iterable(self): + # Check that size can be determined, used by __iter__ + dif = self.stop - self.start + n = dif/self.step + if not (n.has(S.Infinity) or n.has(S.NegativeInfinity) or n.is_Integer): + return False + if self.start in [S.NegativeInfinity, S.Infinity]: + return False + if not (n.is_extended_nonnegative and all(i.is_integer for i in self.args)): + return False + return True + + def __len__(self): + rv = self.size + if rv is S.Infinity: + raise ValueError('Use .size to get the length of an infinite Range') + return int(rv) + + @property + def size(self): + if self.start == self.stop: + return S.Zero + dif = self.stop - self.start + n = dif/self.step + if n.is_infinite: + return S.Infinity + if n.is_extended_nonnegative and all(i.is_integer for i in self.args): + return abs(floor(n)) + raise ValueError('Invalid method for symbolic Range') + + @property + def is_finite_set(self): + if self.start.is_integer and self.stop.is_integer: + return True + return self.size.is_finite + + @property + def is_empty(self): + try: + return self.size.is_zero + except ValueError: + return None + + def __bool__(self): + # this only distinguishes between definite null range + # and non-null/unknown null; getting True doesn't mean + # that it actually is not null + b = is_eq(self.start, self.stop) + if b is None: + raise ValueError('cannot tell if Range is null or not') + return not bool(b) + + def __getitem__(self, i): + ooslice = "cannot slice from the end with an infinite value" + zerostep = "slice step cannot be zero" + infinite = "slicing not possible on range with infinite start" + # if we had to take every other element in the following + # oo, ..., 6, 4, 2, 0 + # we might get oo, ..., 4, 0 or oo, ..., 6, 2 + ambiguous = "cannot unambiguously re-stride from the end " + \ + "with an infinite value" + if isinstance(i, slice): + if self.size.is_finite: # validates, too + if self.start == self.stop: + return Range(0) + start, stop, step = i.indices(self.size) + n = ceiling((stop - start)/step) + if n <= 0: + return Range(0) + canonical_stop = start + n*step + end = canonical_stop - step + ss = step*self.step + return Range(self[start], self[end] + ss, ss) + else: # infinite Range + start = i.start + stop = i.stop + if i.step == 0: + raise ValueError(zerostep) + step = i.step or 1 + ss = step*self.step + #--------------------- + # handle infinite Range + # i.e. Range(-oo, oo) or Range(oo, -oo, -1) + # -------------------- + if self.start.is_infinite and self.stop.is_infinite: + raise ValueError(infinite) + #--------------------- + # handle infinite on right + # e.g. Range(0, oo) or Range(0, -oo, -1) + # -------------------- + if self.stop.is_infinite: + # start and stop are not interdependent -- + # they only depend on step --so we use the + # equivalent reversed values + return self.reversed[ + stop if stop is None else -stop + 1: + start if start is None else -start: + step].reversed + #--------------------- + # handle infinite on the left + # e.g. Range(oo, 0, -1) or Range(-oo, 0) + # -------------------- + # consider combinations of + # start/stop {== None, < 0, == 0, > 0} and + # step {< 0, > 0} + if start is None: + if stop is None: + if step < 0: + return Range(self[-1], self.start, ss) + elif step > 1: + raise ValueError(ambiguous) + else: # == 1 + return self + elif stop < 0: + if step < 0: + return Range(self[-1], self[stop], ss) + else: # > 0 + return Range(self.start, self[stop], ss) + elif stop == 0: + if step > 0: + return Range(0) + else: # < 0 + raise ValueError(ooslice) + elif stop == 1: + if step > 0: + raise ValueError(ooslice) # infinite singleton + else: # < 0 + raise ValueError(ooslice) + else: # > 1 + raise ValueError(ooslice) + elif start < 0: + if stop is None: + if step < 0: + return Range(self[start], self.start, ss) + else: # > 0 + return Range(self[start], self.stop, ss) + elif stop < 0: + return Range(self[start], self[stop], ss) + elif stop == 0: + if step < 0: + raise ValueError(ooslice) + else: # > 0 + return Range(0) + elif stop > 0: + raise ValueError(ooslice) + elif start == 0: + if stop is None: + if step < 0: + raise ValueError(ooslice) # infinite singleton + elif step > 1: + raise ValueError(ambiguous) + else: # == 1 + return self + elif stop < 0: + if step > 1: + raise ValueError(ambiguous) + elif step == 1: + return Range(self.start, self[stop], ss) + else: # < 0 + return Range(0) + else: # >= 0 + raise ValueError(ooslice) + elif start > 0: + raise ValueError(ooslice) + else: + if self.start == self.stop: + raise IndexError('Range index out of range') + if not (all(i.is_integer or i.is_infinite + for i in self.args) and ((self.stop - self.start)/ + self.step).is_extended_positive): + raise ValueError('Invalid method for symbolic Range') + if i == 0: + if self.start.is_infinite: + raise ValueError(ooslice) + return self.start + if i == -1: + if self.stop.is_infinite: + raise ValueError(ooslice) + return self.stop - self.step + n = self.size # must be known for any other index + rv = (self.stop if i < 0 else self.start) + i*self.step + if rv.is_infinite: + raise ValueError(ooslice) + val = (rv - self.start)/self.step + rel = fuzzy_or([val.is_infinite, + fuzzy_and([val.is_nonnegative, (n-val).is_nonnegative])]) + if rel: + return rv + if rel is None: + raise ValueError('Invalid method for symbolic Range') + raise IndexError("Range index out of range") + + @property + def _inf(self): + if not self: + return S.EmptySet.inf + if self.has(Symbol): + if all(i.is_integer or i.is_infinite for i in self.args): + dif = self.stop - self.start + if self.step.is_positive and dif.is_positive: + return self.start + elif self.step.is_negative and dif.is_negative: + return self.stop - self.step + raise ValueError('invalid method for symbolic range') + if self.step > 0: + return self.start + else: + return self.stop - self.step + + @property + def _sup(self): + if not self: + return S.EmptySet.sup + if self.has(Symbol): + if all(i.is_integer or i.is_infinite for i in self.args): + dif = self.stop - self.start + if self.step.is_positive and dif.is_positive: + return self.stop - self.step + elif self.step.is_negative and dif.is_negative: + return self.start + raise ValueError('invalid method for symbolic range') + if self.step > 0: + return self.stop - self.step + else: + return self.start + + @property + def _boundary(self): + return self + + def as_relational(self, x): + """Rewrite a Range in terms of equalities and logic operators. """ + if self.start.is_infinite: + assert not self.stop.is_infinite # by instantiation + a = self.reversed.start + else: + a = self.start + step = self.step + in_seq = Eq(Mod(x - a, step), 0) + ints = And(Eq(Mod(a, 1), 0), Eq(Mod(step, 1), 0)) + n = (self.stop - self.start)/self.step + if n == 0: + return S.EmptySet.as_relational(x) + if n == 1: + return And(Eq(x, a), ints) + try: + a, b = self.inf, self.sup + except ValueError: + a = None + if a is not None: + range_cond = And( + x > a if a.is_infinite else x >= a, + x < b if b.is_infinite else x <= b) + else: + a, b = self.start, self.stop - self.step + range_cond = Or( + And(self.step >= 1, x > a if a.is_infinite else x >= a, + x < b if b.is_infinite else x <= b), + And(self.step <= -1, x < a if a.is_infinite else x <= a, + x > b if b.is_infinite else x >= b)) + return And(in_seq, ints, range_cond) + + +_sympy_converter[range] = lambda r: Range(r.start, r.stop, r.step) + +def normalize_theta_set(theta): + r""" + Normalize a Real Set `theta` in the interval `[0, 2\pi)`. It returns + a normalized value of theta in the Set. For Interval, a maximum of + one cycle $[0, 2\pi]$, is returned i.e. for theta equal to $[0, 10\pi]$, + returned normalized value would be $[0, 2\pi)$. As of now intervals + with end points as non-multiples of ``pi`` is not supported. + + Raises + ====== + + NotImplementedError + The algorithms for Normalizing theta Set are not yet + implemented. + ValueError + The input is not valid, i.e. the input is not a real set. + RuntimeError + It is a bug, please report to the github issue tracker. + + Examples + ======== + + >>> from sympy.sets.fancysets import normalize_theta_set + >>> from sympy import Interval, FiniteSet, pi + >>> normalize_theta_set(Interval(9*pi/2, 5*pi)) + Interval(pi/2, pi) + >>> normalize_theta_set(Interval(-3*pi/2, pi/2)) + Interval.Ropen(0, 2*pi) + >>> normalize_theta_set(Interval(-pi/2, pi/2)) + Union(Interval(0, pi/2), Interval.Ropen(3*pi/2, 2*pi)) + >>> normalize_theta_set(Interval(-4*pi, 3*pi)) + Interval.Ropen(0, 2*pi) + >>> normalize_theta_set(Interval(-3*pi/2, -pi/2)) + Interval(pi/2, 3*pi/2) + >>> normalize_theta_set(FiniteSet(0, pi, 3*pi)) + {0, pi} + + """ + from sympy.functions.elementary.trigonometric import _pi_coeff + + if theta.is_Interval: + interval_len = theta.measure + # one complete circle + if interval_len >= 2*S.Pi: + if interval_len == 2*S.Pi and theta.left_open and theta.right_open: + k = _pi_coeff(theta.start) + return Union(Interval(0, k*S.Pi, False, True), + Interval(k*S.Pi, 2*S.Pi, True, True)) + return Interval(0, 2*S.Pi, False, True) + + k_start, k_end = _pi_coeff(theta.start), _pi_coeff(theta.end) + + if k_start is None or k_end is None: + raise NotImplementedError("Normalizing theta without pi as coefficient is " + "not yet implemented") + new_start = k_start*S.Pi + new_end = k_end*S.Pi + + if new_start > new_end: + return Union(Interval(S.Zero, new_end, False, theta.right_open), + Interval(new_start, 2*S.Pi, theta.left_open, True)) + else: + return Interval(new_start, new_end, theta.left_open, theta.right_open) + + elif theta.is_FiniteSet: + new_theta = [] + for element in theta: + k = _pi_coeff(element) + if k is None: + raise NotImplementedError('Normalizing theta without pi as ' + 'coefficient, is not Implemented.') + else: + new_theta.append(k*S.Pi) + return FiniteSet(*new_theta) + + elif theta.is_Union: + return Union(*[normalize_theta_set(interval) for interval in theta.args]) + + elif theta.is_subset(S.Reals): + raise NotImplementedError("Normalizing theta when, it is of type %s is not " + "implemented" % type(theta)) + else: + raise ValueError(" %s is not a real set" % (theta)) + + +class ComplexRegion(Set): + r""" + Represents the Set of all Complex Numbers. It can represent a + region of Complex Plane in both the standard forms Polar and + Rectangular coordinates. + + * Polar Form + Input is in the form of the ProductSet or Union of ProductSets + of the intervals of ``r`` and ``theta``, and use the flag ``polar=True``. + + .. math:: Z = \{z \in \mathbb{C} \mid z = r\times (\cos(\theta) + I\sin(\theta)), r \in [\texttt{r}], \theta \in [\texttt{theta}]\} + + * Rectangular Form + Input is in the form of the ProductSet or Union of ProductSets + of interval of x and y, the real and imaginary parts of the Complex numbers in a plane. + Default input type is in rectangular form. + + .. math:: Z = \{z \in \mathbb{C} \mid z = x + Iy, x \in [\operatorname{re}(z)], y \in [\operatorname{im}(z)]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, Interval, S, I, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 6) + >>> c1 = ComplexRegion(a*b) # Rectangular Form + >>> c1 + CartesianComplexRegion(ProductSet(Interval(2, 3), Interval(4, 6))) + + * c1 represents the rectangular region in complex plane + surrounded by the coordinates (2, 4), (3, 4), (3, 6) and + (2, 6), of the four vertices. + + >>> c = Interval(1, 8) + >>> c2 = ComplexRegion(Union(a*b, b*c)) + >>> c2 + CartesianComplexRegion(Union(ProductSet(Interval(2, 3), Interval(4, 6)), ProductSet(Interval(4, 6), Interval(1, 8)))) + + * c2 represents the Union of two rectangular regions in complex + plane. One of them surrounded by the coordinates of c1 and + other surrounded by the coordinates (4, 1), (6, 1), (6, 8) and + (4, 8). + + >>> 2.5 + 4.5*I in c1 + True + >>> 2.5 + 6.5*I in c1 + False + + >>> r = Interval(0, 1) + >>> theta = Interval(0, 2*S.Pi) + >>> c2 = ComplexRegion(r*theta, polar=True) # Polar Form + >>> c2 # unit Disk + PolarComplexRegion(ProductSet(Interval(0, 1), Interval.Ropen(0, 2*pi))) + + * c2 represents the region in complex plane inside the + Unit Disk centered at the origin. + + >>> 0.5 + 0.5*I in c2 + True + >>> 1 + 2*I in c2 + False + + >>> unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + >>> upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + >>> intersection = unit_disk.intersect(upper_half_unit_disk) + >>> intersection + PolarComplexRegion(ProductSet(Interval(0, 1), Interval(0, pi))) + >>> intersection == upper_half_unit_disk + True + + See Also + ======== + + CartesianComplexRegion + PolarComplexRegion + Complexes + + """ + is_ComplexRegion = True + + def __new__(cls, sets, polar=False): + if polar is False: + return CartesianComplexRegion(sets) + elif polar is True: + return PolarComplexRegion(sets) + else: + raise ValueError("polar should be either True or False") + + @property + def sets(self): + """ + Return raw input sets to the self. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.sets + ProductSet(Interval(2, 3), Interval(4, 5)) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.sets + Union(ProductSet(Interval(2, 3), Interval(4, 5)), ProductSet(Interval(4, 5), Interval(1, 7))) + + """ + return self.args[0] + + @property + def psets(self): + """ + Return a tuple of sets (ProductSets) input of the self. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.psets + (ProductSet(Interval(2, 3), Interval(4, 5)),) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.psets + (ProductSet(Interval(2, 3), Interval(4, 5)), ProductSet(Interval(4, 5), Interval(1, 7))) + + """ + if self.sets.is_ProductSet: + psets = () + psets = psets + (self.sets, ) + else: + psets = self.sets.args + return psets + + @property + def a_interval(self): + """ + Return the union of intervals of `x` when, self is in + rectangular form, or the union of intervals of `r` when + self is in polar form. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.a_interval + Interval(2, 3) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.a_interval + Union(Interval(2, 3), Interval(4, 5)) + + """ + a_interval = [] + for element in self.psets: + a_interval.append(element.args[0]) + + a_interval = Union(*a_interval) + return a_interval + + @property + def b_interval(self): + """ + Return the union of intervals of `y` when, self is in + rectangular form, or the union of intervals of `theta` + when self is in polar form. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.b_interval + Interval(4, 5) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.b_interval + Interval(1, 7) + + """ + b_interval = [] + for element in self.psets: + b_interval.append(element.args[1]) + + b_interval = Union(*b_interval) + return b_interval + + @property + def _measure(self): + """ + The measure of self.sets. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, S + >>> a, b = Interval(2, 5), Interval(4, 8) + >>> c = Interval(0, 2*S.Pi) + >>> c1 = ComplexRegion(a*b) + >>> c1.measure + 12 + >>> c2 = ComplexRegion(a*c, polar=True) + >>> c2.measure + 6*pi + + """ + return self.sets._measure + + def _kind(self): + return self.args[0].kind + + @classmethod + def from_real(cls, sets): + """ + Converts given subset of real numbers to a complex region. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion + >>> unit = Interval(0,1) + >>> ComplexRegion.from_real(unit) + CartesianComplexRegion(ProductSet(Interval(0, 1), {0})) + + """ + if not sets.is_subset(S.Reals): + raise ValueError("sets must be a subset of the real line") + + return CartesianComplexRegion(sets * FiniteSet(0)) + + def _contains(self, other): + from sympy.functions import arg, Abs + + isTuple = isinstance(other, Tuple) + if isTuple and len(other) != 2: + raise ValueError('expecting Tuple of length 2') + + # If the other is not an Expression, and neither a Tuple + if not isinstance(other, (Expr, Tuple)): + return S.false + + # self in rectangular form + if not self.polar: + re, im = other if isTuple else other.as_real_imag() + return tfn[fuzzy_or(fuzzy_and([ + pset.args[0]._contains(re), + pset.args[1]._contains(im)]) + for pset in self.psets)] + + # self in polar form + elif self.polar: + if other.is_zero: + # ignore undefined complex argument + return tfn[fuzzy_or(pset.args[0]._contains(S.Zero) + for pset in self.psets)] + if isTuple: + r, theta = other + else: + r, theta = Abs(other), arg(other) + if theta.is_real and theta.is_number: + # angles in psets are normalized to [0, 2pi) + theta %= 2*S.Pi + return tfn[fuzzy_or(fuzzy_and([ + pset.args[0]._contains(r), + pset.args[1]._contains(theta)]) + for pset in self.psets)] + + +class CartesianComplexRegion(ComplexRegion): + r""" + Set representing a square region of the complex plane. + + .. math:: Z = \{z \in \mathbb{C} \mid z = x + Iy, x \in [\operatorname{re}(z)], y \in [\operatorname{im}(z)]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, I, Interval + >>> region = ComplexRegion(Interval(1, 3) * Interval(4, 6)) + >>> 2 + 5*I in region + True + >>> 5*I in region + False + + See also + ======== + + ComplexRegion + PolarComplexRegion + Complexes + """ + + polar = False + variables = symbols('x, y', cls=Dummy) + + def __new__(cls, sets): + + if sets == S.Reals*S.Reals: + return S.Complexes + + if all(_a.is_FiniteSet for _a in sets.args) and (len(sets.args) == 2): + + # ** ProductSet of FiniteSets in the Complex Plane. ** + # For Cases like ComplexRegion({2, 4}*{3}), It + # would return {2 + 3*I, 4 + 3*I} + + # FIXME: This should probably be handled with something like: + # return ImageSet(Lambda((x, y), x+I*y), sets).rewrite(FiniteSet) + complex_num = [] + for x in sets.args[0]: + for y in sets.args[1]: + complex_num.append(x + S.ImaginaryUnit*y) + return FiniteSet(*complex_num) + else: + return Set.__new__(cls, sets) + + @property + def expr(self): + x, y = self.variables + return x + S.ImaginaryUnit*y + + +class PolarComplexRegion(ComplexRegion): + r""" + Set representing a polar region of the complex plane. + + .. math:: Z = \{z \in \mathbb{C} \mid z = r\times (\cos(\theta) + I\sin(\theta)), r \in [\texttt{r}], \theta \in [\texttt{theta}]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, Interval, oo, pi, I + >>> rset = Interval(0, oo) + >>> thetaset = Interval(0, pi) + >>> upper_half_plane = ComplexRegion(rset * thetaset, polar=True) + >>> 1 + I in upper_half_plane + True + >>> 1 - I in upper_half_plane + False + + See also + ======== + + ComplexRegion + CartesianComplexRegion + Complexes + + """ + + polar = True + variables = symbols('r, theta', cls=Dummy) + + def __new__(cls, sets): + + new_sets = [] + # sets is Union of ProductSets + if not sets.is_ProductSet: + for k in sets.args: + new_sets.append(k) + # sets is ProductSets + else: + new_sets.append(sets) + # Normalize input theta + for k, v in enumerate(new_sets): + new_sets[k] = ProductSet(v.args[0], + normalize_theta_set(v.args[1])) + sets = Union(*new_sets) + return Set.__new__(cls, sets) + + @property + def expr(self): + r, theta = self.variables + return r*(cos(theta) + S.ImaginaryUnit*sin(theta)) + + +class Complexes(CartesianComplexRegion, metaclass=Singleton): + """ + The :class:`Set` of all complex numbers + + Examples + ======== + + >>> from sympy import S, I + >>> S.Complexes + Complexes + >>> 1 + I in S.Complexes + True + + See also + ======== + + Reals + ComplexRegion + + """ + + is_empty = False + is_finite_set = False + + # Override property from superclass since Complexes has no args + @property + def sets(self): + return ProductSet(S.Reals, S.Reals) + + def __new__(cls): + return Set.__new__(cls) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__init__.py b/phivenv/Lib/site-packages/sympy/sets/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4797bede1415cc18d4d9613ba353529c1914b8af Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/add.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/add.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ade096f4ad4b17f6ab957c438329639bbed37eb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/add.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b72034d64c425ed428347d29aa0cdb781eb07a5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7cca024c5707e5bc05d7fe9b285ce712f942260 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04821c4231d426832e6441709017136d4c115186 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d19f4287f028e040f3c238b538c14314e445b7c4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c17afc2ef3137c0b8f900d9772e05b9629b0dd0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/power.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/power.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf1d03f0ea8a421259aadd3930a3cd77982dae60 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/power.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/union.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/union.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90b9697eb78a4072fd07bfdf8bf85915278ad5ad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/handlers/__pycache__/union.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/add.py b/phivenv/Lib/site-packages/sympy/sets/handlers/add.py new file mode 100644 index 0000000000000000000000000000000000000000..8c07b25ed19d21febffd6b23a92b34b787179f44 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/add.py @@ -0,0 +1,79 @@ +from sympy.core.numbers import oo, Infinity, NegativeInfinity +from sympy.core.singleton import S +from sympy.core import Basic, Expr +from sympy.multipledispatch import Dispatcher +from sympy.sets import Interval, FiniteSet + + + +# XXX: The functions in this module are clearly not tested and are broken in a +# number of ways. + +_set_add = Dispatcher('_set_add') +_set_sub = Dispatcher('_set_sub') + + +@_set_add.register(Basic, Basic) +def _(x, y): + return None + + +@_set_add.register(Expr, Expr) +def _(x, y): + return x+y + + +@_set_add.register(Interval, Interval) +def _(x, y): + """ + Additions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start + y.start, x.end + y.end, + x.left_open or y.left_open, x.right_open or y.right_open) + + +@_set_add.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet({S.Infinity}) + +@_set_add.register(Interval, NegativeInfinity) +def _(x, y): + if x.end is S.Infinity: + return Interval(-oo, oo) + return FiniteSet({S.NegativeInfinity}) + + +@_set_sub.register(Basic, Basic) +def _(x, y): + return None + + +@_set_sub.register(Expr, Expr) +def _(x, y): + return x-y + + +@_set_sub.register(Interval, Interval) +def _(x, y): + """ + Subtractions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start - y.end, x.end - y.start, + x.left_open or y.right_open, x.right_open or y.left_open) + + +@_set_sub.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) + +@_set_sub.register(Interval, NegativeInfinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/comparison.py b/phivenv/Lib/site-packages/sympy/sets/handlers/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..b64d1a2a22e15d09f6f10fb4fef730163d468d45 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/comparison.py @@ -0,0 +1,53 @@ +from sympy.core.relational import Eq, is_eq +from sympy.core.basic import Basic +from sympy.core.logic import fuzzy_and, fuzzy_bool +from sympy.logic.boolalg import And +from sympy.multipledispatch import dispatch +from sympy.sets.sets import tfn, ProductSet, Interval, FiniteSet, Set + + +@dispatch(Interval, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(FiniteSet, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Interval, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return And(Eq(lhs.left, rhs.left), + Eq(lhs.right, rhs.right), + lhs.left_open == rhs.left_open, + lhs.right_open == rhs.right_open) + +@dispatch(FiniteSet, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + def all_in_both(): + s_set = set(lhs.args) + o_set = set(rhs.args) + yield fuzzy_and(lhs._contains(e) for e in o_set - s_set) + yield fuzzy_and(rhs._contains(e) for e in s_set - o_set) + + return tfn[fuzzy_and(all_in_both())] + + +@dispatch(ProductSet, ProductSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + if len(lhs.sets) != len(rhs.sets): + return False + + eqs = (is_eq(x, y) for x, y in zip(lhs.sets, rhs.sets)) + return tfn[fuzzy_and(map(fuzzy_bool, eqs))] + + +@dispatch(Set, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Set, Set) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return tfn[fuzzy_and(a.is_subset(b) for a, b in [(lhs, rhs), (rhs, lhs)])] diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/functions.py b/phivenv/Lib/site-packages/sympy/sets/handlers/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..2529dbfd458451d7d09e91c717b170df77b1d9fe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/functions.py @@ -0,0 +1,262 @@ +from sympy.core.singleton import S +from sympy.sets.sets import Set +from sympy.calculus.singularities import singularities +from sympy.core import Expr, Add +from sympy.core.function import Lambda, FunctionClass, diff, expand_mul +from sympy.core.numbers import Float, oo +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import true +from sympy.multipledispatch import Dispatcher +from sympy.sets import (imageset, Interval, FiniteSet, Union, ImageSet, + Intersection, Range, Complement) +from sympy.sets.sets import EmptySet, is_function_invertible_in_set +from sympy.sets.fancysets import Integers, Naturals, Reals +from sympy.functions.elementary.exponential import match_real_imag + + +_x, _y = symbols("x y") + +FunctionUnion = (FunctionClass, Lambda) + +_set_function = Dispatcher('_set_function') + + +@_set_function.register(FunctionClass, Set) +def _(f, x): + return None + +@_set_function.register(FunctionUnion, FiniteSet) +def _(f, x): + return FiniteSet(*map(f, x)) + +@_set_function.register(Lambda, Interval) +def _(f, x): + from sympy.solvers.solveset import solveset + from sympy.series import limit + # TODO: handle functions with infinitely many solutions (eg, sin, tan) + # TODO: handle multivariate functions + + expr = f.expr + if len(expr.free_symbols) > 1 or len(f.variables) != 1: + return + var = f.variables[0] + if not var.is_real: + if expr.subs(var, Dummy(real=True)).is_real is False: + return + + if expr.is_Piecewise: + result = S.EmptySet + domain_set = x + for (p_expr, p_cond) in expr.args: + if p_cond is true: + intrvl = domain_set + else: + intrvl = p_cond.as_set() + intrvl = Intersection(domain_set, intrvl) + + if p_expr.is_Number: + image = FiniteSet(p_expr) + else: + image = imageset(Lambda(var, p_expr), intrvl) + result = Union(result, image) + + # remove the part which has been `imaged` + domain_set = Complement(domain_set, intrvl) + if domain_set is S.EmptySet: + break + return result + + if not x.start.is_comparable or not x.end.is_comparable: + return + + try: + from sympy.polys.polyutils import _nsort + sing = list(singularities(expr, var, x)) + if len(sing) > 1: + sing = _nsort(sing) + except NotImplementedError: + return + + if x.left_open: + _start = limit(expr, var, x.start, dir="+") + elif x.start not in sing: + _start = f(x.start) + if x.right_open: + _end = limit(expr, var, x.end, dir="-") + elif x.end not in sing: + _end = f(x.end) + + if len(sing) == 0: + soln_expr = solveset(diff(expr, var), var) + if not (isinstance(soln_expr, FiniteSet) + or soln_expr is S.EmptySet): + return + solns = list(soln_expr) + + extr = [_start, _end] + [f(i) for i in solns + if i.is_real and i in x] + start, end = Min(*extr), Max(*extr) + + left_open, right_open = False, False + if _start <= _end: + # the minimum or maximum value can occur simultaneously + # on both the edge of the interval and in some interior + # point + if start == _start and start not in solns: + left_open = x.left_open + if end == _end and end not in solns: + right_open = x.right_open + else: + if start == _end and start not in solns: + left_open = x.right_open + if end == _start and end not in solns: + right_open = x.left_open + + return Interval(start, end, left_open, right_open) + else: + return imageset(f, Interval(x.start, sing[0], + x.left_open, True)) + \ + Union(*[imageset(f, Interval(sing[i], sing[i + 1], True, True)) + for i in range(0, len(sing) - 1)]) + \ + imageset(f, Interval(sing[-1], x.end, True, x.right_open)) + +@_set_function.register(FunctionClass, Interval) +def _(f, x): + if f == exp: + return Interval(exp(x.start), exp(x.end), x.left_open, x.right_open) + elif f == log: + return Interval(log(x.start), log(x.end), x.left_open, x.right_open) + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Union) +def _(f, x): + return Union(*(imageset(f, arg) for arg in x.args)) + +@_set_function.register(FunctionUnion, Intersection) +def _(f, x): + # If the function is invertible, intersect the maps of the sets. + if is_function_invertible_in_set(f, x): + return Intersection(*(imageset(f, arg) for arg in x.args)) + else: + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, EmptySet) +def _(f, x): + return x + +@_set_function.register(FunctionUnion, Set) +def _(f, x): + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Range) +def _(f, self): + if not self: + return S.EmptySet + if not isinstance(f.expr, Expr): + return + if self.size == 1: + return FiniteSet(f(self[0])) + if f is S.IdentityFunction: + return self + + x = f.variables[0] + expr = f.expr + # handle f that is linear in f's variable + if x not in expr.free_symbols or x in expr.diff(x).free_symbols: + return + if self.start.is_finite: + F = f(self.step*x + self.start) # for i in range(len(self)) + else: + F = f(-self.step*x + self[-1]) + F = expand_mul(F) + if F != expr: + return imageset(x, F, Range(self.size)) + +@_set_function.register(FunctionUnion, Integers) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + n = f.variables[0] + if expr == abs(n): + return S.Naturals0 + + # f(x) + c and f(-x) + c cover the same integers + # so choose the form that has the fewest negatives + c = f(0) + fx = f(n) - c + f_x = f(-n) - c + neg_count = lambda e: sum(_.could_extract_minus_sign() + for _ in Add.make_args(e)) + if neg_count(f_x) < neg_count(fx): + expr = f_x + c + + a = Wild('a', exclude=[n]) + b = Wild('b', exclude=[n]) + match = expr.match(a*n + b) + if match and match[a] and ( + not match[a].atoms(Float) and + not match[b].atoms(Float)): + # canonical shift + a, b = match[a], match[b] + if a in [1, -1]: + # drop integer addends in b + nonint = [] + for bi in Add.make_args(b): + if not bi.is_integer: + nonint.append(bi) + b = Add(*nonint) + if b.is_number and a.is_real: + # avoid Mod for complex numbers, #11391 + br, bi = match_real_imag(b) + if br and br.is_comparable and a.is_comparable: + br %= a + b = br + S.ImaginaryUnit*bi + elif b.is_number and a.is_imaginary: + br, bi = match_real_imag(b) + ai = a/S.ImaginaryUnit + if bi and bi.is_comparable and ai.is_comparable: + bi %= ai + b = br + S.ImaginaryUnit*bi + expr = a*n + b + + if expr != f.expr: + return ImageSet(Lambda(n, expr), S.Integers) + + +@_set_function.register(FunctionUnion, Naturals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + x = f.variables[0] + if not expr.free_symbols - {x}: + if expr == abs(x): + if self is S.Naturals: + return self + return S.Naturals0 + step = expr.coeff(x) + c = expr.subs(x, 0) + if c.is_Integer and step.is_Integer and expr == step*x + c: + if self is S.Naturals: + c += step + if step > 0: + if step == 1: + if c == 0: + return S.Naturals0 + elif c == 1: + return S.Naturals + return Range(c, oo, step) + return Range(c, -oo, step) + + +@_set_function.register(FunctionUnion, Reals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + return _set_function(f, Interval(-oo, oo)) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/intersection.py b/phivenv/Lib/site-packages/sympy/sets/handlers/intersection.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb9309ef3e9d2722ab1bfe664f1d1644f17da5d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/intersection.py @@ -0,0 +1,533 @@ +from sympy.core.basic import _aresame +from sympy.core.function import Lambda, expand_complex +from sympy.core.mul import Mul +from sympy.core.numbers import ilcm, Float +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sorting import ordered +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor, ceiling +from sympy.sets.fancysets import ComplexRegion +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Set, Union) +from sympy.multipledispatch import Dispatcher +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import (Integers, Naturals, Reals, Range, + ImageSet, Rationals) +from sympy.sets.sets import EmptySet, UniversalSet, imageset, ProductSet +from sympy.simplify.radsimp import numer + + +intersection_sets = Dispatcher('intersection_sets') + + +@intersection_sets.register(ConditionSet, ConditionSet) +def _(a, b): + return None + +@intersection_sets.register(ConditionSet, Set) +def _(a, b): + return ConditionSet(a.sym, a.condition, Intersection(a.base_set, b)) + +@intersection_sets.register(Naturals, Integers) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Naturals) +def _(a, b): + return a if a is S.Naturals else b + +@intersection_sets.register(Interval, Naturals) +def _(a, b): + return intersection_sets(b, a) + +@intersection_sets.register(ComplexRegion, Set) +def _(self, other): + if other.is_ComplexRegion: + # self in rectangular form + if (not self.polar) and (not other.polar): + return ComplexRegion(Intersection(self.sets, other.sets)) + + # self in polar form + elif self.polar and other.polar: + r1, theta1 = self.a_interval, self.b_interval + r2, theta2 = other.a_interval, other.b_interval + new_r_interval = Intersection(r1, r2) + new_theta_interval = Intersection(theta1, theta2) + + # 0 and 2*Pi means the same + if ((2*S.Pi in theta1 and S.Zero in theta2) or + (2*S.Pi in theta2 and S.Zero in theta1)): + new_theta_interval = Union(new_theta_interval, + FiniteSet(0)) + return ComplexRegion(new_r_interval*new_theta_interval, + polar=True) + + + if other.is_subset(S.Reals): + new_interval = [] + x = symbols("x", cls=Dummy, real=True) + + # self in rectangular form + if not self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + + # self in polar form + elif self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + if S.Pi in element.args[1]: + new_interval.append(ImageSet(Lambda(x, -x), element.args[0])) + if S.Zero in element.args[0]: + new_interval.append(FiniteSet(0)) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + +@intersection_sets.register(Integers, Reals) +def _(a, b): + return a + +@intersection_sets.register(Range, Interval) +def _(a, b): + # Check that there are no symbolic arguments + if not all(i.is_number for i in a.args + b.args[:2]): + return + + # In case of null Range, return an EmptySet. + if a.size == 0: + return S.EmptySet + + # trim down to self's size, and represent + # as a Range with step 1. + start = ceiling(max(b.inf, a.inf)) + if start not in b: + start += 1 + end = floor(min(b.sup, a.sup)) + if end not in b: + end -= 1 + return intersection_sets(a, Range(start, end + 1)) + +@intersection_sets.register(Range, Naturals) +def _(a, b): + return intersection_sets(a, Interval(b.inf, S.Infinity)) + +@intersection_sets.register(Range, Range) +def _(a, b): + # Check that there are no symbolic range arguments + if not all(all(v.is_number for v in r.args) for r in [a, b]): + return None + + # non-overlap quick exits + if not b: + return S.EmptySet + if not a: + return S.EmptySet + if b.sup < a.inf: + return S.EmptySet + if b.inf > a.sup: + return S.EmptySet + + # work with finite end at the start + r1 = a + if r1.start.is_infinite: + r1 = r1.reversed + r2 = b + if r2.start.is_infinite: + r2 = r2.reversed + + # If both ends are infinite then it means that one Range is just the set + # of all integers (the step must be 1). + if r1.start.is_infinite: + return b + if r2.start.is_infinite: + return a + + from sympy.solvers.diophantine.diophantine import diop_linear + + # this equation represents the values of the Range; + # it's a linear equation + eq = lambda r, i: r.start + i*r.step + + # we want to know when the two equations might + # have integer solutions so we use the diophantine + # solver + va, vb = diop_linear(eq(r1, Dummy('a')) - eq(r2, Dummy('b'))) + + # check for no solution + no_solution = va is None and vb is None + if no_solution: + return S.EmptySet + + # there is a solution + # ------------------- + + # find the coincident point, c + a0 = va.as_coeff_Add()[0] + c = eq(r1, a0) + + # find the first point, if possible, in each range + # since c may not be that point + def _first_finite_point(r1, c): + if c == r1.start: + return c + # st is the signed step we need to take to + # get from c to r1.start + st = sign(r1.start - c)*step + # use Range to calculate the first point: + # we want to get as close as possible to + # r1.start; the Range will not be null since + # it will at least contain c + s1 = Range(c, r1.start + st, st)[-1] + if s1 == r1.start: + pass + else: + # if we didn't hit r1.start then, if the + # sign of st didn't match the sign of r1.step + # we are off by one and s1 is not in r1 + if sign(r1.step) != sign(st): + s1 -= st + if s1 not in r1: + return + return s1 + + # calculate the step size of the new Range + step = abs(ilcm(r1.step, r2.step)) + s1 = _first_finite_point(r1, c) + if s1 is None: + return S.EmptySet + s2 = _first_finite_point(r2, c) + if s2 is None: + return S.EmptySet + + # replace the corresponding start or stop in + # the original Ranges with these points; the + # result must have at least one point since + # we know that s1 and s2 are in the Ranges + def _updated_range(r, first): + st = sign(r.step)*step + if r.start.is_finite: + rv = Range(first, r.stop, st) + else: + rv = Range(r.start, first + st, st) + return rv + r1 = _updated_range(a, s1) + r2 = _updated_range(b, s2) + + # work with them both in the increasing direction + if sign(r1.step) < 0: + r1 = r1.reversed + if sign(r2.step) < 0: + r2 = r2.reversed + + # return clipped Range with positive step; it + # can't be empty at this point + start = max(r1.start, r2.start) + stop = min(r1.stop, r2.stop) + return Range(start, stop, step) + + +@intersection_sets.register(Range, Integers) +def _(a, b): + return a + + +@intersection_sets.register(Range, Rationals) +def _(a, b): + return a + + +@intersection_sets.register(ImageSet, Set) +def _(self, other): + from sympy.solvers.diophantine import diophantine + + # Only handle the straight-forward univariate case + if (len(self.lamda.variables) > 1 + or self.lamda.signature != self.lamda.variables): + return None + base_set = self.base_sets[0] + + # Intersection between ImageSets with Integers as base set + # For {f(n) : n in Integers} & {g(m) : m in Integers} we solve the + # diophantine equations f(n)=g(m). + # If the solutions for n are {h(t) : t in Integers} then we return + # {f(h(t)) : t in integers}. + # If the solutions for n are {n_1, n_2, ..., n_k} then we return + # {f(n_i) : 1 <= i <= k}. + if base_set is S.Integers: + gm = None + if isinstance(other, ImageSet) and other.base_sets == (S.Integers,): + gm = other.lamda.expr + var = other.lamda.variables[0] + # Symbol of second ImageSet lambda must be distinct from first + m = Dummy('m') + gm = gm.subs(var, m) + elif other is S.Integers: + m = gm = Dummy('m') + if gm is not None: + fn = self.lamda.expr + n = self.lamda.variables[0] + try: + solns = list(diophantine(fn - gm, syms=(n, m), permute=True)) + except (TypeError, NotImplementedError): + # TypeError if equation not polynomial with rational coeff. + # NotImplementedError if correct format but no solver. + return + # 3 cases are possible for solns: + # - empty set, + # - one or more parametric (infinite) solutions, + # - a finite number of (non-parametric) solution couples. + # Among those, there is one type of solution set that is + # not helpful here: multiple parametric solutions. + if len(solns) == 0: + return S.EmptySet + elif any(s.free_symbols for tupl in solns for s in tupl): + if len(solns) == 1: + soln, solm = solns[0] + (t,) = soln.free_symbols + expr = fn.subs(n, soln.subs(t, n)).expand() + return imageset(Lambda(n, expr), S.Integers) + else: + return + else: + return FiniteSet(*(fn.subs(n, s[0]) for s in solns)) + + if other == S.Reals: + from sympy.solvers.solvers import denoms, solve_linear + + def _solution_union(exprs, sym): + # return a union of linear solutions to i in expr; + # if i cannot be solved, use a ConditionSet for solution + sols = [] + for i in exprs: + x, xis = solve_linear(i, 0, [sym]) + if x == sym: + sols.append(FiniteSet(xis)) + else: + sols.append(ConditionSet(sym, Eq(i, 0))) + return Union(*sols) + + f = self.lamda.expr + n = self.lamda.variables[0] + + n_ = Dummy(n.name, real=True) + f_ = f.subs(n, n_) + + re, im = f_.as_real_imag() + im = expand_complex(im) + + re = re.subs(n_, n) + im = im.subs(n_, n) + ifree = im.free_symbols + lam = Lambda(n, re) + if im.is_zero: + # allow re-evaluation + # of self in this case to make + # the result canonical + pass + elif im.is_zero is False: + return S.EmptySet + elif ifree != {n}: + return None + else: + # univarite imaginary part in same variable; + # use numer instead of as_numer_denom to keep + # this as fast as possible while still handling + # simple cases + base_set &= _solution_union( + Mul.make_args(numer(im)), n) + # exclude values that make denominators 0 + base_set -= _solution_union(denoms(f), n) + return imageset(lam, base_set) + + elif isinstance(other, Interval): + from sympy.solvers.solveset import (invert_real, invert_complex, + solveset) + + f = self.lamda.expr + n = self.lamda.variables[0] + new_inf, new_sup = None, None + new_lopen, new_ropen = other.left_open, other.right_open + + if f.is_real: + inverter = invert_real + else: + inverter = invert_complex + + g1, h1 = inverter(f, other.inf, n) + g2, h2 = inverter(f, other.sup, n) + + if all(isinstance(i, FiniteSet) for i in (h1, h2)): + if g1 == n: + if len(h1) == 1: + new_inf = h1.args[0] + if g2 == n: + if len(h2) == 1: + new_sup = h2.args[0] + # TODO: Design a technique to handle multiple-inverse + # functions + + # Any of the new boundary values cannot be determined + if any(i is None for i in (new_sup, new_inf)): + return + + + range_set = S.EmptySet + + if all(i.is_real for i in (new_sup, new_inf)): + # this assumes continuity of underlying function + # however fixes the case when it is decreasing + if new_inf > new_sup: + new_inf, new_sup = new_sup, new_inf + new_interval = Interval(new_inf, new_sup, new_lopen, new_ropen) + range_set = base_set.intersect(new_interval) + else: + if other.is_subset(S.Reals): + solutions = solveset(f, n, S.Reals) + if not isinstance(range_set, (ImageSet, ConditionSet)): + range_set = solutions.intersect(other) + else: + return + + if range_set is S.EmptySet: + return S.EmptySet + elif isinstance(range_set, Range) and range_set.size is not S.Infinity: + range_set = FiniteSet(*list(range_set)) + + if range_set is not None: + return imageset(Lambda(n, f), range_set) + return + else: + return + + +@intersection_sets.register(ProductSet, ProductSet) +def _(a, b): + if len(b.args) != len(a.args): + return S.EmptySet + return ProductSet(*(i.intersect(j) for i, j in zip(a.sets, b.sets))) + + +@intersection_sets.register(Interval, Interval) +def _(a, b): + # handle (-oo, oo) + infty = S.NegativeInfinity, S.Infinity + if a == Interval(*infty): + l, r = a.left, a.right + if l.is_real or l in infty or r.is_real or r in infty: + return b + + # We can't intersect [0,3] with [x,6] -- we don't know if x>0 or x<0 + if not a._is_comparable(b): + return None + + empty = False + + if a.start <= b.end and b.start <= a.end: + # Get topology right. + if a.start < b.start: + start = b.start + left_open = b.left_open + elif a.start > b.start: + start = a.start + left_open = a.left_open + else: + start = a.start + if not _aresame(a.start, b.start): + # For example Integer(2) != Float(2) + # Prefer the Float boundary because Floats should be + # contagious in calculations. + if b.start.has(Float) and not a.start.has(Float): + start = b.start + elif a.start.has(Float) and not b.start.has(Float): + start = a.start + else: + #this is to ensure that if Eq(a.start, b.start) but + #type(a.start) != type(b.start) the order of a and b + #does not matter for the result + start = list(ordered([a,b]))[0].start + left_open = a.left_open or b.left_open + + if a.end < b.end: + end = a.end + right_open = a.right_open + elif a.end > b.end: + end = b.end + right_open = b.right_open + else: + # see above for logic with start + end = a.end + if not _aresame(a.end, b.end): + if b.end.has(Float) and not a.end.has(Float): + end = b.end + elif a.end.has(Float) and not b.end.has(Float): + end = a.end + else: + end = list(ordered([a,b]))[0].end + right_open = a.right_open or b.right_open + + if end - start == 0 and (left_open or right_open): + empty = True + else: + empty = True + + if empty: + return S.EmptySet + + return Interval(start, end, left_open, right_open) + +@intersection_sets.register(EmptySet, Set) +def _(a, b): + return S.EmptySet + +@intersection_sets.register(UniversalSet, Set) +def _(a, b): + return b + +@intersection_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements & b._elements)) + +@intersection_sets.register(FiniteSet, Set) +def _(a, b): + try: + return FiniteSet(*[el for el in a if el in b]) + except TypeError: + return None # could not evaluate `el in b` due to symbolic ranges. + +@intersection_sets.register(Set, Set) +def _(a, b): + return None + +@intersection_sets.register(Integers, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Rationals, Reals) +def _(a, b): + return a + +def _intlike_interval(a, b): + try: + if b._inf is S.NegativeInfinity and b._sup is S.Infinity: + return a + s = Range(max(a.inf, ceiling(b.left)), floor(b.right) + 1) + return intersection_sets(s, b) # take out endpoints if open interval + except ValueError: + return None + +@intersection_sets.register(Integers, Interval) +def _(a, b): + return _intlike_interval(a, b) + +@intersection_sets.register(Naturals, Interval) +def _(a, b): + return _intlike_interval(a, b) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/issubset.py b/phivenv/Lib/site-packages/sympy/sets/handlers/issubset.py new file mode 100644 index 0000000000000000000000000000000000000000..cc23e8bf56f1743cd7f08452dd09a0acf981f5da --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/issubset.py @@ -0,0 +1,144 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.logic import fuzzy_and, fuzzy_bool, fuzzy_not, fuzzy_or +from sympy.core.relational import Eq +from sympy.sets.sets import FiniteSet, Interval, Set, Union, ProductSet +from sympy.sets.fancysets import Complexes, Reals, Range, Rationals +from sympy.multipledispatch import Dispatcher + + +_inf_sets = [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, S.Complexes] + + +is_subset_sets = Dispatcher('is_subset_sets') + + +@is_subset_sets.register(Set, Set) +def _(a, b): + return None + +@is_subset_sets.register(Interval, Interval) +def _(a, b): + # This is correct but can be made more comprehensive... + if fuzzy_bool(a.start < b.start): + return False + if fuzzy_bool(a.end > b.end): + return False + if (b.left_open and not a.left_open and fuzzy_bool(Eq(a.start, b.start))): + return False + if (b.right_open and not a.right_open and fuzzy_bool(Eq(a.end, b.end))): + return False + +@is_subset_sets.register(Interval, FiniteSet) +def _(a_interval, b_fs): + # An Interval can only be a subset of a finite set if it is finite + # which can only happen if it has zero measure. + if fuzzy_not(a_interval.measure.is_zero): + return False + +@is_subset_sets.register(Interval, Union) +def _(a_interval, b_u): + if all(isinstance(s, (Interval, FiniteSet)) for s in b_u.args): + intervals = [s for s in b_u.args if isinstance(s, Interval)] + if all(fuzzy_bool(a_interval.start < s.start) for s in intervals): + return False + if all(fuzzy_bool(a_interval.end > s.end) for s in intervals): + return False + if a_interval.measure.is_nonzero: + no_overlap = lambda s1, s2: fuzzy_or([ + fuzzy_bool(s1.end <= s2.start), + fuzzy_bool(s1.start >= s2.end), + ]) + if all(no_overlap(s, a_interval) for s in intervals): + return False + +@is_subset_sets.register(Range, Range) +def _(a, b): + if a.step == b.step == 1: + return fuzzy_and([fuzzy_bool(a.start >= b.start), + fuzzy_bool(a.stop <= b.stop)]) + +@is_subset_sets.register(Range, Interval) +def _(a_range, b_interval): + if a_range.step.is_positive: + if b_interval.left_open and a_range.inf.is_finite: + cond_left = a_range.inf > b_interval.left + else: + cond_left = a_range.inf >= b_interval.left + if b_interval.right_open and a_range.sup.is_finite: + cond_right = a_range.sup < b_interval.right + else: + cond_right = a_range.sup <= b_interval.right + return fuzzy_and([cond_left, cond_right]) + +@is_subset_sets.register(Range, FiniteSet) +def _(a_range, b_finiteset): + try: + a_size = a_range.size + except ValueError: + # symbolic Range of unknown size + return None + if a_size > len(b_finiteset): + return False + elif any(arg.has(Symbol) for arg in a_range.args): + return fuzzy_and(b_finiteset.contains(x) for x in a_range) + else: + # Checking A \ B == EmptySet is more efficient than repeated naive + # membership checks on an arbitrary FiniteSet. + a_set = set(a_range) + b_remaining = len(b_finiteset) + # Symbolic expressions and numbers of unknown type (integer or not) are + # all counted as "candidates", i.e. *potentially* matching some a in + # a_range. + cnt_candidate = 0 + for b in b_finiteset: + if b.is_Integer: + a_set.discard(b) + elif fuzzy_not(b.is_integer): + pass + else: + cnt_candidate += 1 + b_remaining -= 1 + if len(a_set) > b_remaining + cnt_candidate: + return False + if len(a_set) == 0: + return True + return None + +@is_subset_sets.register(Interval, Range) +def _(a_interval, b_range): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Interval, Rationals) +def _(a_interval, b_rationals): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Range, Complexes) +def _(a, b): + return True + +@is_subset_sets.register(Complexes, Interval) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Range) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Rationals) +def _(a, b): + return False + +@is_subset_sets.register(Rationals, Reals) +def _(a, b): + return True + +@is_subset_sets.register(Rationals, Range) +def _(a, b): + return False + +@is_subset_sets.register(ProductSet, FiniteSet) +def _(a_ps, b_fs): + return fuzzy_and(b_fs.contains(x) for x in a_ps) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/mul.py b/phivenv/Lib/site-packages/sympy/sets/handlers/mul.py new file mode 100644 index 0000000000000000000000000000000000000000..0dedc8068b7973fd4cb6fbf2854e5fa671d188de --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/mul.py @@ -0,0 +1,79 @@ +from sympy.core import Basic, Expr +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.multipledispatch import Dispatcher +from sympy.sets.setexpr import set_mul +from sympy.sets.sets import Interval, Set + + +_x, _y = symbols("x y") + + +_set_mul = Dispatcher('_set_mul') +_set_div = Dispatcher('_set_div') + + +@_set_mul.register(Basic, Basic) +def _(x, y): + return None + +@_set_mul.register(Set, Set) +def _(x, y): + return None + +@_set_mul.register(Expr, Expr) +def _(x, y): + return x*y + +@_set_mul.register(Interval, Interval) +def _(x, y): + """ + Multiplications in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + # TODO: some intervals containing 0 and oo will fail as 0*oo returns nan. + comvals = ( + (x.start * y.start, bool(x.left_open or y.left_open)), + (x.start * y.end, bool(x.left_open or y.right_open)), + (x.end * y.start, bool(x.right_open or y.left_open)), + (x.end * y.end, bool(x.right_open or y.right_open)), + ) + # TODO: handle symbolic intervals + minval, minopen = min(comvals) + maxval, maxopen = max(comvals) + return Interval( + minval, + maxval, + minopen, + maxopen + ) + +@_set_div.register(Basic, Basic) +def _(x, y): + return None + +@_set_div.register(Expr, Expr) +def _(x, y): + return x/y + +@_set_div.register(Set, Set) +def _(x, y): + return None + +@_set_div.register(Interval, Interval) +def _(x, y): + """ + Divisions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + if (y.start*y.end).is_negative: + return Interval(-oo, oo) + if y.start == 0: + s2 = oo + else: + s2 = 1/y.start + if y.end == 0: + s1 = -oo + else: + s1 = 1/y.end + return set_mul(x, Interval(s1, s2, y.right_open, y.left_open)) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/power.py b/phivenv/Lib/site-packages/sympy/sets/handlers/power.py new file mode 100644 index 0000000000000000000000000000000000000000..3cad4ee49ab27770143bc121d1fbcd024bf01548 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/power.py @@ -0,0 +1,107 @@ +from sympy.core import Basic, Expr +from sympy.core.function import Lambda +from sympy.core.numbers import oo, Infinity, NegativeInfinity, Zero, Integer +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import (Max, Min) +from sympy.sets.fancysets import ImageSet +from sympy.sets.setexpr import set_div +from sympy.sets.sets import Set, Interval, FiniteSet, Union +from sympy.multipledispatch import Dispatcher + + +_x, _y = symbols("x y") + + +_set_pow = Dispatcher('_set_pow') + + +@_set_pow.register(Basic, Basic) +def _(x, y): + return None + +@_set_pow.register(Set, Set) +def _(x, y): + return ImageSet(Lambda((_x, _y), (_x ** _y)), x, y) + +@_set_pow.register(Expr, Expr) +def _(x, y): + return x**y + +@_set_pow.register(Interval, Zero) +def _(x, z): + return FiniteSet(S.One) + +@_set_pow.register(Interval, Integer) +def _(x, exponent): + """ + Powers in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + s1 = x.start**exponent + s2 = x.end**exponent + if ((s2 > s1) if exponent > 0 else (x.end > -x.start)) == True: + left_open = x.left_open + right_open = x.right_open + # TODO: handle unevaluated condition. + sleft = s2 + else: + # TODO: `s2 > s1` could be unevaluated. + left_open = x.right_open + right_open = x.left_open + sleft = s1 + + if x.start.is_positive: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + elif x.end.is_negative: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + + # Case where x.start < 0 and x.end > 0: + if exponent.is_odd: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(-oo, s1, True, x.left_open) + return Union(Interval(-oo, s1, True, x.left_open), Interval(s2, oo, x.right_open)) + else: + return Interval(s1, s2, x.left_open, x.right_open) + elif exponent.is_even: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(s1, oo, x.left_open) + return Interval(0, oo) + else: + return Interval(S.Zero, sleft, S.Zero not in x, left_open) + +@_set_pow.register(Interval, Infinity) +def _(b, e): + # TODO: add logic for open intervals? + if b.start.is_nonnegative: + if b.end < 1: + return FiniteSet(S.Zero) + if b.start > 1: + return FiniteSet(S.Infinity) + return Interval(0, oo) + elif b.end.is_negative: + if b.start > -1: + return FiniteSet(S.Zero) + if b.end < -1: + return FiniteSet(-oo, oo) + return Interval(-oo, oo) + else: + if b.start > -1: + if b.end < 1: + return FiniteSet(S.Zero) + return Interval(0, oo) + return Interval(-oo, oo) + +@_set_pow.register(Interval, NegativeInfinity) +def _(b, e): + return _set_pow(set_div(S.One, b), oo) diff --git a/phivenv/Lib/site-packages/sympy/sets/handlers/union.py b/phivenv/Lib/site-packages/sympy/sets/handlers/union.py new file mode 100644 index 0000000000000000000000000000000000000000..75d867b49969ae2aeea76155dbaae7e05c1a6847 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/handlers/union.py @@ -0,0 +1,147 @@ +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.sets.sets import (EmptySet, FiniteSet, Intersection, + Interval, ProductSet, Set, Union, UniversalSet) +from sympy.sets.fancysets import (ComplexRegion, Naturals, Naturals0, + Integers, Rationals, Reals) +from sympy.multipledispatch import Dispatcher + + +union_sets = Dispatcher('union_sets') + + +@union_sets.register(Naturals0, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Rationals) +def _(a, b): + return a + +@union_sets.register(Integers, Set) +def _(a, b): + intersect = Intersection(a, b) + if intersect == a: + return b + elif intersect == b: + return a + +@union_sets.register(ComplexRegion, Set) +def _(a, b): + if b.is_subset(S.Reals): + # treat a subset of reals as a complex region + b = ComplexRegion.from_real(b) + + if b.is_ComplexRegion: + # a in rectangular form + if (not a.polar) and (not b.polar): + return ComplexRegion(Union(a.sets, b.sets)) + # a in polar form + elif a.polar and b.polar: + return ComplexRegion(Union(a.sets, b.sets), polar=True) + return None + +@union_sets.register(EmptySet, Set) +def _(a, b): + return b + + +@union_sets.register(UniversalSet, Set) +def _(a, b): + return a + +@union_sets.register(ProductSet, ProductSet) +def _(a, b): + if b.is_subset(a): + return a + if len(b.sets) != len(a.sets): + return None + if len(a.sets) == 2: + a1, a2 = a.sets + b1, b2 = b.sets + if a1 == b1: + return a1 * Union(a2, b2) + if a2 == b2: + return Union(a1, b1) * a2 + return None + +@union_sets.register(ProductSet, Set) +def _(a, b): + if b.is_subset(a): + return a + return None + +@union_sets.register(Interval, Interval) +def _(a, b): + if a._is_comparable(b): + # Non-overlapping intervals + end = Min(a.end, b.end) + start = Max(a.start, b.start) + if (end < start or + (end == start and (end not in a and end not in b))): + return None + else: + start = Min(a.start, b.start) + end = Max(a.end, b.end) + + left_open = ((a.start != start or a.left_open) and + (b.start != start or b.left_open)) + right_open = ((a.end != end or a.right_open) and + (b.end != end or b.right_open)) + return Interval(start, end, left_open, right_open) + +@union_sets.register(Interval, UniversalSet) +def _(a, b): + return S.UniversalSet + +@union_sets.register(Interval, Set) +def _(a, b): + # If I have open end points and these endpoints are contained in b + # But only in case, when endpoints are finite. Because + # interval does not contain oo or -oo. + open_left_in_b_and_finite = (a.left_open and + sympify(b.contains(a.start)) is S.true and + a.start.is_finite) + open_right_in_b_and_finite = (a.right_open and + sympify(b.contains(a.end)) is S.true and + a.end.is_finite) + if open_left_in_b_and_finite or open_right_in_b_and_finite: + # Fill in my end points and return + open_left = a.left_open and a.start not in b + open_right = a.right_open and a.end not in b + new_a = Interval(a.start, a.end, open_left, open_right) + return {new_a, b} + return None + +@union_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements | b._elements)) + +@union_sets.register(FiniteSet, Set) +def _(a, b): + # If `b` set contains one of my elements, remove it from `a` + if any(b.contains(x) == True for x in a): + return { + FiniteSet(*[x for x in a if b.contains(x) != True]), b} + return None + +@union_sets.register(Set, Set) +def _(a, b): + return None diff --git a/phivenv/Lib/site-packages/sympy/sets/ordinals.py b/phivenv/Lib/site-packages/sympy/sets/ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe062354cfe58a4747998e51fa0d261e67576cc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/ordinals.py @@ -0,0 +1,282 @@ +from sympy.core import Basic, Integer +import operator + + +class OmegaPower(Basic): + """ + Represents ordinal exponential and multiplication terms one of the + building blocks of the :class:`Ordinal` class. + In ``OmegaPower(a, b)``, ``a`` represents exponent and ``b`` represents multiplicity. + """ + def __new__(cls, a, b): + if isinstance(b, int): + b = Integer(b) + if not isinstance(b, Integer) or b <= 0: + raise TypeError("multiplicity must be a positive integer") + + if not isinstance(a, Ordinal): + a = Ordinal.convert(a) + + return Basic.__new__(cls, a, b) + + @property + def exp(self): + return self.args[0] + + @property + def mult(self): + return self.args[1] + + def _compare_term(self, other, op): + if self.exp == other.exp: + return op(self.mult, other.mult) + else: + return op(self.exp, other.exp) + + def __eq__(self, other): + if not isinstance(other, OmegaPower): + try: + other = OmegaPower(0, other) + except TypeError: + return NotImplemented + return self.args == other.args + + def __hash__(self): + return Basic.__hash__(self) + + def __lt__(self, other): + if not isinstance(other, OmegaPower): + try: + other = OmegaPower(0, other) + except TypeError: + return NotImplemented + return self._compare_term(other, operator.lt) + + +class Ordinal(Basic): + """ + Represents ordinals in Cantor normal form. + + Internally, this class is just a list of instances of OmegaPower. + + Examples + ======== + >>> from sympy import Ordinal, OmegaPower + >>> from sympy.sets.ordinals import omega + >>> w = omega + >>> w.is_limit_ordinal + True + >>> Ordinal(OmegaPower(w + 1, 1), OmegaPower(3, 2)) + w**(w + 1) + w**3*2 + >>> 3 + w + w + >>> (w + 1) * w + w**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ordinal_arithmetic + """ + def __new__(cls, *terms): + obj = super().__new__(cls, *terms) + powers = [i.exp for i in obj.args] + if not all(powers[i] >= powers[i+1] for i in range(len(powers) - 1)): + raise ValueError("powers must be in decreasing order") + return obj + + @property + def terms(self): + return self.args + + @property + def leading_term(self): + if self == ord0: + raise ValueError("ordinal zero has no leading term") + return self.terms[0] + + @property + def trailing_term(self): + if self == ord0: + raise ValueError("ordinal zero has no trailing term") + return self.terms[-1] + + @property + def is_successor_ordinal(self): + try: + return self.trailing_term.exp == ord0 + except ValueError: + return False + + @property + def is_limit_ordinal(self): + try: + return not self.trailing_term.exp == ord0 + except ValueError: + return False + + @property + def degree(self): + return self.leading_term.exp + + @classmethod + def convert(cls, integer_value): + if integer_value == 0: + return ord0 + return Ordinal(OmegaPower(0, integer_value)) + + def __eq__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return self.terms == other.terms + + def __hash__(self): + return hash(self.args) + + def __lt__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + for term_self, term_other in zip(self.terms, other.terms): + if term_self != term_other: + return term_self < term_other + return len(self.terms) < len(other.terms) + + def __le__(self, other): + return (self == other or self < other) + + def __gt__(self, other): + return not self <= other + + def __ge__(self, other): + return not self < other + + def __str__(self): + net_str = "" + plus_count = 0 + if self == ord0: + return 'ord0' + for i in self.terms: + if plus_count: + net_str += " + " + + if i.exp == ord0: + net_str += str(i.mult) + elif i.exp == 1: + net_str += 'w' + elif len(i.exp.terms) > 1 or i.exp.is_limit_ordinal: + net_str += 'w**(%s)'%i.exp + else: + net_str += 'w**%s'%i.exp + + if not i.mult == 1 and not i.exp == ord0: + net_str += '*%s'%i.mult + + plus_count += 1 + return(net_str) + + __repr__ = __str__ + + def __add__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + if other == ord0: + return self + a_terms = list(self.terms) + b_terms = list(other.terms) + r = len(a_terms) - 1 + b_exp = other.degree + while r >= 0 and a_terms[r].exp < b_exp: + r -= 1 + if r < 0: + terms = b_terms + elif a_terms[r].exp == b_exp: + sum_term = OmegaPower(b_exp, a_terms[r].mult + other.leading_term.mult) + terms = a_terms[:r] + [sum_term] + b_terms[1:] + else: + terms = a_terms[:r+1] + b_terms + return Ordinal(*terms) + + def __radd__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return other + self + + def __mul__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + if ord0 in (self, other): + return ord0 + a_exp = self.degree + a_mult = self.leading_term.mult + summation = [] + if other.is_limit_ordinal: + for arg in other.terms: + summation.append(OmegaPower(a_exp + arg.exp, arg.mult)) + + else: + for arg in other.terms[:-1]: + summation.append(OmegaPower(a_exp + arg.exp, arg.mult)) + b_mult = other.trailing_term.mult + summation.append(OmegaPower(a_exp, a_mult*b_mult)) + summation += list(self.terms[1:]) + return Ordinal(*summation) + + def __rmul__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return other * self + + def __pow__(self, other): + if not self == omega: + return NotImplemented + return Ordinal(OmegaPower(other, 1)) + + +class OrdinalZero(Ordinal): + """The ordinal zero. + + OrdinalZero can be imported as ``ord0``. + """ + pass + + +class OrdinalOmega(Ordinal): + """The ordinal omega which forms the base of all ordinals in cantor normal form. + + OrdinalOmega can be imported as ``omega``. + + Examples + ======== + + >>> from sympy.sets.ordinals import omega + >>> omega + omega + w*2 + """ + def __new__(cls): + return Ordinal.__new__(cls) + + @property + def terms(self): + return (OmegaPower(1, 1),) + + +ord0 = OrdinalZero() +omega = OrdinalOmega() diff --git a/phivenv/Lib/site-packages/sympy/sets/powerset.py b/phivenv/Lib/site-packages/sympy/sets/powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb3b41b9859281480bc9517a1cad0abe7a5683f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/powerset.py @@ -0,0 +1,119 @@ +from sympy.core.decorators import _sympifyit +from sympy.core.parameters import global_parameters +from sympy.core.logic import fuzzy_bool +from sympy.core.singleton import S +from sympy.core.sympify import _sympify + +from .sets import Set, FiniteSet, SetKind + + +class PowerSet(Set): + r"""A symbolic object representing a power set. + + Parameters + ========== + + arg : Set + The set to take power of. + + evaluate : bool + The flag to control evaluation. + + If the evaluation is disabled for finite sets, it can take + advantage of using subset test as a membership test. + + Notes + ===== + + Power set `\mathcal{P}(S)` is defined as a set containing all the + subsets of `S`. + + If the set `S` is a finite set, its power set would have + `2^{\left| S \right|}` elements, where `\left| S \right|` denotes + the cardinality of `S`. + + Examples + ======== + + >>> from sympy import PowerSet, S, FiniteSet + + A power set of a finite set: + + >>> PowerSet(FiniteSet(1, 2, 3)) + PowerSet({1, 2, 3}) + + A power set of an empty set: + + >>> PowerSet(S.EmptySet) + PowerSet(EmptySet) + >>> PowerSet(PowerSet(S.EmptySet)) + PowerSet(PowerSet(EmptySet)) + + A power set of an infinite set: + + >>> PowerSet(S.Reals) + PowerSet(Reals) + + Evaluating the power set of a finite set to its explicit form: + + >>> PowerSet(FiniteSet(1, 2, 3)).rewrite(FiniteSet) + FiniteSet(EmptySet, {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Power_set + + .. [2] https://en.wikipedia.org/wiki/Axiom_of_power_set + """ + def __new__(cls, arg, evaluate=None): + if evaluate is None: + evaluate=global_parameters.evaluate + + arg = _sympify(arg) + + if not isinstance(arg, Set): + raise ValueError('{} must be a set.'.format(arg)) + + return super().__new__(cls, arg) + + @property + def arg(self): + return self.args[0] + + def _eval_rewrite_as_FiniteSet(self, *args, **kwargs): + arg = self.arg + if arg.is_FiniteSet: + return arg.powerset() + return None + + @_sympifyit('other', NotImplemented) + def _contains(self, other): + if not isinstance(other, Set): + return None + + return fuzzy_bool(self.arg.is_superset(other)) + + def _eval_is_subset(self, other): + if isinstance(other, PowerSet): + return self.arg.is_subset(other.arg) + + def __len__(self): + return 2 ** len(self.arg) + + def __iter__(self): + found = [S.EmptySet] + yield S.EmptySet + + for x in self.arg: + temp = [] + x = FiniteSet(x) + for y in found: + new = x + y + yield new + temp.append(new) + found.extend(temp) + + @property + def kind(self): + return SetKind(self.arg.kind) diff --git a/phivenv/Lib/site-packages/sympy/sets/setexpr.py b/phivenv/Lib/site-packages/sympy/sets/setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..94d77d5293617a620b70a945888987ce6cc61157 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/setexpr.py @@ -0,0 +1,97 @@ +from sympy.core import Expr +from sympy.core.decorators import call_highest_priority, _sympifyit +from .fancysets import ImageSet +from .sets import set_add, set_sub, set_mul, set_div, set_pow, set_function + + +class SetExpr(Expr): + """An expression that can take on values of a set. + + Examples + ======== + + >>> from sympy import Interval, FiniteSet + >>> from sympy.sets.setexpr import SetExpr + + >>> a = SetExpr(Interval(0, 5)) + >>> b = SetExpr(FiniteSet(1, 10)) + >>> (a + b).set + Union(Interval(1, 6), Interval(10, 15)) + >>> (2*a + b).set + Interval(1, 20) + """ + _op_priority = 11.0 + + def __new__(cls, setarg): + return Expr.__new__(cls, setarg) + + set = property(lambda self: self.args[0]) + + def _latex(self, printer): + return r"SetExpr\left({}\right)".format(printer._print(self.set)) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__radd__') + def __add__(self, other): + return _setexpr_apply_operation(set_add, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__add__') + def __radd__(self, other): + return _setexpr_apply_operation(set_add, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return _setexpr_apply_operation(set_mul, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return _setexpr_apply_operation(set_mul, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rsub__') + def __sub__(self, other): + return _setexpr_apply_operation(set_sub, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__sub__') + def __rsub__(self, other): + return _setexpr_apply_operation(set_sub, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rpow__') + def __pow__(self, other): + return _setexpr_apply_operation(set_pow, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__pow__') + def __rpow__(self, other): + return _setexpr_apply_operation(set_pow, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return _setexpr_apply_operation(set_div, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + return _setexpr_apply_operation(set_div, other, self) + + def _eval_func(self, func): + # TODO: this could be implemented straight into `imageset`: + res = set_function(func, self.set) + if res is None: + return SetExpr(ImageSet(func, self.set)) + return SetExpr(res) + + +def _setexpr_apply_operation(op, x, y): + if isinstance(x, SetExpr): + x = x.set + if isinstance(y, SetExpr): + y = y.set + out = op(x, y) + return SetExpr(out) diff --git a/phivenv/Lib/site-packages/sympy/sets/sets.py b/phivenv/Lib/site-packages/sympy/sets/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..3c85ce87c515cfd4520dcc6b9265fe76d8c6163f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/sets.py @@ -0,0 +1,2804 @@ +from __future__ import annotations + +from typing import Any, Callable, TYPE_CHECKING, overload +from functools import reduce +from collections import defaultdict +from collections.abc import Mapping, Iterable +import inspect + +from sympy.core.kind import Kind, UndefinedKind, NumberKind +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.decorators import sympify_method_args, sympify_return +from sympy.core.evalf import EvalfMixin +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.logic import (FuzzyBool, fuzzy_bool, fuzzy_or, fuzzy_and, + fuzzy_not) +from sympy.core.numbers import Float, Integer +from sympy.core.operations import LatticeOp +from sympy.core.parameters import global_parameters +from sympy.core.relational import Eq, Ne, is_lt +from sympy.core.singleton import Singleton, S +from sympy.core.sorting import ordered +from sympy.core.symbol import symbols, Symbol, Dummy, uniquely_named_symbol +from sympy.core.sympify import _sympify, sympify, _sympy_converter +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.logic.boolalg import And, Or, Not, Xor, true, false +from sympy.utilities.decorator import deprecated +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import (iproduct, sift, roundrobin, iterable, + subsets) +from sympy.utilities.misc import func_name, filldedent + +from mpmath import mpi, mpf + +from mpmath.libmp.libmpf import prec_to_dps + + +tfn = defaultdict(lambda: None, { + True: S.true, + S.true: S.true, + False: S.false, + S.false: S.false}) + + +@sympify_method_args +class Set(Basic, EvalfMixin): + """ + The base class for any kind of set. + + Explanation + =========== + + This is not meant to be used directly as a container of items. It does not + behave like the builtin ``set``; see :class:`FiniteSet` for that. + + Real intervals are represented by the :class:`Interval` class and unions of + sets by the :class:`Union` class. The empty set is represented by the + :class:`EmptySet` class and available as a singleton as ``S.EmptySet``. + """ + + __slots__: tuple[()] = () + + is_number = False + is_iterable = False + is_interval = False + + is_FiniteSet = False + is_Interval = False + is_ProductSet = False + is_Union = False + is_Intersection: FuzzyBool = None + is_UniversalSet: FuzzyBool = None + is_Complement: FuzzyBool = None + is_ComplexRegion = False + + is_empty: FuzzyBool = None + is_finite_set: FuzzyBool = None + + @property # type: ignore + @deprecated( + """ + The is_EmptySet attribute of Set objects is deprecated. + Use 's is S.EmptySet" or 's.is_empty' instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-is-emptyset", + ) + def is_EmptySet(self): + return None + + if TYPE_CHECKING: + + def __new__(cls, *args: Basic | complex) -> Set: + ... + + @overload # type: ignore + def subs(self, arg1: Mapping[Basic | complex, Set | complex], arg2: None=None) -> Set: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Set | complex]], arg2: None=None, **kwargs: Any) -> Set: ... + @overload + def subs(self, arg1: Set | complex, arg2: Set | complex) -> Set: ... + @overload + def subs(self, arg1: Mapping[Basic | complex, Basic | complex], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Basic | complex]], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Basic | complex, arg2: Basic | complex, **kwargs: Any) -> Basic: ... + + def subs(self, arg1: Mapping[Basic | complex, Basic | complex] | Basic | complex, # type: ignore + arg2: Basic | complex | None = None, **kwargs: Any) -> Basic: + ... + + def simplify(self, **kwargs) -> Set: + assert False + + def evalf(self, n: int = 15, subs: dict[Basic, Basic | float] | None = None, + maxn: int = 100, chop: bool = False, strict: bool = False, + quad: str | None = None, verbose: bool = False) -> Set: + ... + + n = evalf + + @staticmethod + def _infimum_key(expr): + """ + Return infimum (if possible) else S.Infinity. + """ + try: + infimum = expr.inf + assert infimum.is_comparable + infimum = infimum.evalf() # issue #18505 + except (NotImplementedError, + AttributeError, AssertionError, ValueError): + infimum = S.Infinity + return infimum + + def union(self, other): + """ + Returns the union of ``self`` and ``other``. + + Examples + ======== + + As a shortcut it is possible to use the ``+`` operator: + + >>> from sympy import Interval, FiniteSet + >>> Interval(0, 1).union(Interval(2, 3)) + Union(Interval(0, 1), Interval(2, 3)) + >>> Interval(0, 1) + Interval(2, 3) + Union(Interval(0, 1), Interval(2, 3)) + >>> Interval(1, 2, True, True) + FiniteSet(2, 3) + Union({3}, Interval.Lopen(1, 2)) + + Similarly it is possible to use the ``-`` operator for set differences: + + >>> Interval(0, 2) - Interval(0, 1) + Interval.Lopen(1, 2) + >>> Interval(1, 3) - FiniteSet(2) + Union(Interval.Ropen(1, 2), Interval.Lopen(2, 3)) + + """ + return Union(self, other) + + def intersect(self, other): + """ + Returns the intersection of 'self' and 'other'. + + Examples + ======== + + >>> from sympy import Interval + + >>> Interval(1, 3).intersect(Interval(1, 2)) + Interval(1, 2) + + >>> from sympy import imageset, Lambda, symbols, S + >>> n, m = symbols('n m') + >>> a = imageset(Lambda(n, 2*n), S.Integers) + >>> a.intersect(imageset(Lambda(m, 2*m + 1), S.Integers)) + EmptySet + + """ + return Intersection(self, other) + + def intersection(self, other): + """ + Alias for :meth:`intersect()` + """ + return self.intersect(other) + + def is_disjoint(self, other): + """ + Returns True if ``self`` and ``other`` are disjoint. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 2).is_disjoint(Interval(1, 2)) + False + >>> Interval(0, 2).is_disjoint(Interval(3, 4)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Disjoint_sets + """ + return self.intersect(other) == S.EmptySet + + def isdisjoint(self, other): + """ + Alias for :meth:`is_disjoint()` + """ + return self.is_disjoint(other) + + def complement(self, universe): + r""" + The complement of 'self' w.r.t the given universe. + + Examples + ======== + + >>> from sympy import Interval, S + >>> Interval(0, 1).complement(S.Reals) + Union(Interval.open(-oo, 0), Interval.open(1, oo)) + + >>> Interval(0, 1).complement(S.UniversalSet) + Complement(UniversalSet, Interval(0, 1)) + + """ + return Complement(universe, self) + + def _complement(self, other): + # this behaves as other - self + if isinstance(self, ProductSet) and isinstance(other, ProductSet): + # If self and other are disjoint then other - self == self + if len(self.sets) != len(other.sets): + return other + + # There can be other ways to represent this but this gives: + # (A x B) - (C x D) = ((A - C) x B) U (A x (B - D)) + overlaps = [] + pairs = list(zip(self.sets, other.sets)) + for n in range(len(pairs)): + sets = (o if i != n else o-s for i, (s, o) in enumerate(pairs)) + overlaps.append(ProductSet(*sets)) + return Union(*overlaps) + + elif isinstance(other, Interval): + if isinstance(self, (Interval, FiniteSet)): + return Intersection(other, self.complement(S.Reals)) + + elif isinstance(other, Union): + return Union(*(o - self for o in other.args)) + + elif isinstance(other, Complement): + return Complement(other.args[0], Union(other.args[1], self), evaluate=False) + + elif other is S.EmptySet: + return S.EmptySet + + elif isinstance(other, FiniteSet): + sifted = sift(other, lambda x: fuzzy_bool(self.contains(x))) + # ignore those that are contained in self + return Union(FiniteSet(*(sifted[False])), + Complement(FiniteSet(*(sifted[None])), self, evaluate=False) + if sifted[None] else S.EmptySet) + + def symmetric_difference(self, other): + """ + Returns symmetric difference of ``self`` and ``other``. + + Examples + ======== + + >>> from sympy import Interval, S + >>> Interval(1, 3).symmetric_difference(S.Reals) + Union(Interval.open(-oo, 1), Interval.open(3, oo)) + >>> Interval(1, 10).symmetric_difference(S.Reals) + Union(Interval.open(-oo, 1), Interval.open(10, oo)) + + >>> from sympy import S, EmptySet + >>> S.Reals.symmetric_difference(EmptySet) + Reals + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Symmetric_difference + + """ + return SymmetricDifference(self, other) + + def _symmetric_difference(self, other): + return Union(Complement(self, other), Complement(other, self)) + + @property + def inf(self): + """ + The infimum of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).inf + 0 + >>> Union(Interval(0, 1), Interval(2, 3)).inf + 0 + + """ + return self._inf + + @property + def _inf(self): + raise NotImplementedError("(%s)._inf" % self) + + @property + def sup(self): + """ + The supremum of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).sup + 1 + >>> Union(Interval(0, 1), Interval(2, 3)).sup + 3 + + """ + return self._sup + + @property + def _sup(self): + raise NotImplementedError("(%s)._sup" % self) + + def contains(self, other): + """ + Returns a SymPy value indicating whether ``other`` is contained + in ``self``: ``true`` if it is, ``false`` if it is not, else + an unevaluated ``Contains`` expression (or, as in the case of + ConditionSet and a union of FiniteSet/Intervals, an expression + indicating the conditions for containment). + + Examples + ======== + + >>> from sympy import Interval, S + >>> from sympy.abc import x + + >>> Interval(0, 1).contains(0.5) + True + + As a shortcut it is possible to use the ``in`` operator, but that + will raise an error unless an affirmative true or false is not + obtained. + + >>> Interval(0, 1).contains(x) + (0 <= x) & (x <= 1) + >>> x in Interval(0, 1) + Traceback (most recent call last): + ... + TypeError: did not evaluate to a bool: None + + The result of 'in' is a bool, not a SymPy value + + >>> 1 in Interval(0, 2) + True + >>> _ is S.true + False + """ + from .contains import Contains + other = sympify(other, strict=True) + + c = self._contains(other) + if isinstance(c, Contains): + return c + if c is None: + return Contains(other, self, evaluate=False) + b = tfn[c] + if b is None: + return c + return b + + def _contains(self, other): + """Test if ``other`` is an element of the set ``self``. + + This is an internal method that is expected to be overridden by + subclasses of ``Set`` and will be called by the public + :func:`Set.contains` method or the :class:`Contains` expression. + + Parameters + ========== + + other: Sympified :class:`Basic` instance + The object whose membership in ``self`` is to be tested. + + Returns + ======= + + Symbolic :class:`Boolean` or ``None``. + + A return value of ``None`` indicates that it is unknown whether + ``other`` is contained in ``self``. Returning ``None`` from here + ensures that ``self.contains(other)`` or ``Contains(self, other)`` will + return an unevaluated :class:`Contains` expression. + + If not ``None`` then the returned value is a :class:`Boolean` that is + logically equivalent to the statement that ``other`` is an element of + ``self``. Usually this would be either ``S.true`` or ``S.false`` but + not always. + """ + raise NotImplementedError(f"{type(self).__name__}._contains") + + def is_subset(self, other): + """ + Returns True if ``self`` is a subset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_subset(Interval(0, 1)) + True + >>> Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) + False + + """ + if not isinstance(other, Set): + raise ValueError("Unknown argument '%s'" % other) + + # Handle the trivial cases + if self == other: + return True + is_empty = self.is_empty + if is_empty is True: + return True + elif fuzzy_not(is_empty) and other.is_empty: + return False + if self.is_finite_set is False and other.is_finite_set: + return False + + # Dispatch on subclass rules + ret = self._eval_is_subset(other) + if ret is not None: + return ret + ret = other._eval_is_superset(self) + if ret is not None: + return ret + + # Use pairwise rules from multiple dispatch + from sympy.sets.handlers.issubset import is_subset_sets + ret = is_subset_sets(self, other) + if ret is not None: + return ret + + # Fall back on computing the intersection + # XXX: We shouldn't do this. A query like this should be handled + # without evaluating new Set objects. It should be the other way round + # so that the intersect method uses is_subset for evaluation. + if self.intersect(other) == self: + return True + + def _eval_is_subset(self, other): + '''Returns a fuzzy bool for whether self is a subset of other.''' + return None + + def _eval_is_superset(self, other): + '''Returns a fuzzy bool for whether self is a subset of other.''' + return None + + # This should be deprecated: + def issubset(self, other): + """ + Alias for :meth:`is_subset()` + """ + return self.is_subset(other) + + def is_proper_subset(self, other): + """ + Returns True if ``self`` is a proper subset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_proper_subset(Interval(0, 1)) + True + >>> Interval(0, 1).is_proper_subset(Interval(0, 1)) + False + + """ + if isinstance(other, Set): + return self != other and self.is_subset(other) + else: + raise ValueError("Unknown argument '%s'" % other) + + def is_superset(self, other): + """ + Returns True if ``self`` is a superset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_superset(Interval(0, 1)) + False + >>> Interval(0, 1).is_superset(Interval(0, 1, left_open=True)) + True + + """ + if isinstance(other, Set): + return other.is_subset(self) + else: + raise ValueError("Unknown argument '%s'" % other) + + # This should be deprecated: + def issuperset(self, other): + """ + Alias for :meth:`is_superset()` + """ + return self.is_superset(other) + + def is_proper_superset(self, other): + """ + Returns True if ``self`` is a proper superset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).is_proper_superset(Interval(0, 0.5)) + True + >>> Interval(0, 1).is_proper_superset(Interval(0, 1)) + False + + """ + if isinstance(other, Set): + return self != other and self.is_superset(other) + else: + raise ValueError("Unknown argument '%s'" % other) + + def _eval_powerset(self): + from .powerset import PowerSet + return PowerSet(self) + + def powerset(self): + """ + Find the Power set of ``self``. + + Examples + ======== + + >>> from sympy import EmptySet, FiniteSet, Interval + + A power set of an empty set: + + >>> A = EmptySet + >>> A.powerset() + {EmptySet} + + A power set of a finite set: + + >>> A = FiniteSet(1, 2) + >>> a, b, c = FiniteSet(1), FiniteSet(2), FiniteSet(1, 2) + >>> A.powerset() == FiniteSet(a, b, c, EmptySet) + True + + A power set of an interval: + + >>> Interval(1, 2).powerset() + PowerSet(Interval(1, 2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Power_set + + """ + return self._eval_powerset() + + @property + def measure(self): + """ + The (Lebesgue) measure of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).measure + 1 + >>> Union(Interval(0, 1), Interval(2, 3)).measure + 2 + + """ + return self._measure + + @property + def kind(self): + """ + The kind of a Set + + Explanation + =========== + + Any :class:`Set` will have kind :class:`SetKind` which is + parametrised by the kind of the elements of the set. For example + most sets are sets of numbers and will have kind + ``SetKind(NumberKind)``. If elements of sets are different in kind than + their kind will ``SetKind(UndefinedKind)``. See + :class:`sympy.core.kind.Kind` for an explanation of the kind system. + + Examples + ======== + + >>> from sympy import Interval, Matrix, FiniteSet, EmptySet, ProductSet, PowerSet + + >>> FiniteSet(Matrix([1, 2])).kind + SetKind(MatrixKind(NumberKind)) + + >>> Interval(1, 2).kind + SetKind(NumberKind) + + >>> EmptySet.kind + SetKind() + + A :class:`sympy.sets.powerset.PowerSet` is a set of sets: + + >>> PowerSet({1, 2, 3}).kind + SetKind(SetKind(NumberKind)) + + A :class:`ProductSet` represents the set of tuples of elements of + other sets. Its kind is :class:`sympy.core.containers.TupleKind` + parametrised by the kinds of the elements of those sets: + + >>> p = ProductSet(FiniteSet(1, 2), FiniteSet(3, 4)) + >>> list(p) + [(1, 3), (2, 3), (1, 4), (2, 4)] + >>> p.kind + SetKind(TupleKind(NumberKind, NumberKind)) + + When all elements of the set do not have same kind, the kind + will be returned as ``SetKind(UndefinedKind)``: + + >>> FiniteSet(0, Matrix([1, 2])).kind + SetKind(UndefinedKind) + + The kind of the elements of a set are given by the ``element_kind`` + attribute of ``SetKind``: + + >>> Interval(1, 2).kind.element_kind + NumberKind + + See Also + ======== + + NumberKind + sympy.core.kind.UndefinedKind + sympy.core.containers.TupleKind + MatrixKind + sympy.matrices.expressions.sets.MatrixSet + sympy.sets.conditionset.ConditionSet + Rationals + Naturals + Integers + sympy.sets.fancysets.ImageSet + sympy.sets.fancysets.Range + sympy.sets.fancysets.ComplexRegion + sympy.sets.powerset.PowerSet + sympy.sets.sets.ProductSet + sympy.sets.sets.Interval + sympy.sets.sets.Union + sympy.sets.sets.Intersection + sympy.sets.sets.Complement + sympy.sets.sets.EmptySet + sympy.sets.sets.UniversalSet + sympy.sets.sets.FiniteSet + sympy.sets.sets.SymmetricDifference + sympy.sets.sets.DisjointUnion + """ + return self._kind() + + @property + def boundary(self): + """ + The boundary or frontier of a set. + + Explanation + =========== + + A point x is on the boundary of a set S if + + 1. x is in the closure of S. + I.e. Every neighborhood of x contains a point in S. + 2. x is not in the interior of S. + I.e. There does not exist an open set centered on x contained + entirely within S. + + There are the points on the outer rim of S. If S is open then these + points need not actually be contained within S. + + For example, the boundary of an interval is its start and end points. + This is true regardless of whether or not the interval is open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).boundary + {0, 1} + >>> Interval(0, 1, True, False).boundary + {0, 1} + """ + return self._boundary + + @property + def is_open(self): + """ + Property method to check whether a set is open. + + Explanation + =========== + + A set is open if and only if it has an empty intersection with its + boundary. In particular, a subset A of the reals is open if and only + if each one of its points is contained in an open interval that is a + subset of A. + + Examples + ======== + >>> from sympy import S + >>> S.Reals.is_open + True + >>> S.Rationals.is_open + False + """ + return Intersection(self, self.boundary).is_empty + + @property + def is_closed(self): + """ + A property method to check whether a set is closed. + + Explanation + =========== + + A set is closed if its complement is an open set. The closedness of a + subset of the reals is determined with respect to R and its standard + topology. + + Examples + ======== + >>> from sympy import Interval + >>> Interval(0, 1).is_closed + True + """ + return self.boundary.is_subset(self) + + @property + def closure(self): + """ + Property method which returns the closure of a set. + The closure is defined as the union of the set itself and its + boundary. + + Examples + ======== + >>> from sympy import S, Interval + >>> S.Reals.closure + Reals + >>> Interval(0, 1).closure + Interval(0, 1) + """ + return self + self.boundary + + @property + def interior(self): + """ + Property method which returns the interior of a set. + The interior of a set S consists all points of S that do not + belong to the boundary of S. + + Examples + ======== + >>> from sympy import Interval + >>> Interval(0, 1).interior + Interval.open(0, 1) + >>> Interval(0, 1).boundary.interior + EmptySet + """ + return self - self.boundary + + @property + def _boundary(self): + raise NotImplementedError() + + @property + def _measure(self): + raise NotImplementedError("(%s)._measure" % self) + + def _kind(self): + return SetKind(UndefinedKind) + + def _eval_evalf(self, prec): + dps = prec_to_dps(prec) + return self.func(*[arg.evalf(n=dps) for arg in self.args]) + + @sympify_return([('other', 'Set')], NotImplemented) + def __add__(self, other): + return self.union(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __or__(self, other): + return self.union(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __and__(self, other): + return self.intersect(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __mul__(self, other): + return ProductSet(self, other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __xor__(self, other): + return SymmetricDifference(self, other) + + @sympify_return([('exp', Expr)], NotImplemented) + def __pow__(self, exp): + if not (exp.is_Integer and exp >= 0): + raise ValueError("%s: Exponent must be a positive Integer" % exp) + return ProductSet(*[self]*exp) + + @sympify_return([('other', 'Set')], NotImplemented) + def __sub__(self, other): + return Complement(self, other) + + def __contains__(self, other): + other = _sympify(other) + c = self._contains(other) + b = tfn[c] + if b is None: + # x in y must evaluate to T or F; to entertain a None + # result with Set use y.contains(x) + raise TypeError('did not evaluate to a bool: %r' % c) + return b + + +class ProductSet(Set): + """ + Represents a Cartesian Product of Sets. + + Explanation + =========== + + Returns a Cartesian product given several sets as either an iterable + or individual arguments. + + Can use ``*`` operator on any sets for convenient shorthand. + + Examples + ======== + + >>> from sympy import Interval, FiniteSet, ProductSet + >>> I = Interval(0, 5); S = FiniteSet(1, 2, 3) + >>> ProductSet(I, S) + ProductSet(Interval(0, 5), {1, 2, 3}) + + >>> (2, 2) in ProductSet(I, S) + True + + >>> Interval(0, 1) * Interval(0, 1) # The unit square + ProductSet(Interval(0, 1), Interval(0, 1)) + + >>> coin = FiniteSet('H', 'T') + >>> set(coin**2) + {(H, H), (H, T), (T, H), (T, T)} + + The Cartesian product is not commutative or associative e.g.: + + >>> I*S == S*I + False + >>> (I*I)*I == I*(I*I) + False + + Notes + ===== + + - Passes most operations down to the argument sets + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cartesian_product + """ + is_ProductSet = True + + def __new__(cls, *sets, **assumptions): + if len(sets) == 1 and iterable(sets[0]) and not isinstance(sets[0], (Set, set)): + sympy_deprecation_warning( + """ +ProductSet(iterable) is deprecated. Use ProductSet(*iterable) instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-productset-iterable", + ) + sets = tuple(sets[0]) + + sets = [sympify(s) for s in sets] + + if not all(isinstance(s, Set) for s in sets): + raise TypeError("Arguments to ProductSet should be of type Set") + + # Nullary product of sets is *not* the empty set + if len(sets) == 0: + return FiniteSet(()) + + if S.EmptySet in sets: + return S.EmptySet + + return Basic.__new__(cls, *sets, **assumptions) + + @property + def sets(self): + return self.args + + def flatten(self): + def _flatten(sets): + for s in sets: + if s.is_ProductSet: + yield from _flatten(s.sets) + else: + yield s + return ProductSet(*_flatten(self.sets)) + + + + def _contains(self, element): + """ + ``in`` operator for ProductSets. + + Examples + ======== + + >>> from sympy import Interval + >>> (2, 3) in Interval(0, 5) * Interval(0, 5) + True + + >>> (10, 10) in Interval(0, 5) * Interval(0, 5) + False + + Passes operation on to constituent sets + """ + if element.is_Symbol: + return None + + if not isinstance(element, Tuple) or len(element) != len(self.sets): + return S.false + + return And(*[s.contains(e) for s, e in zip(self.sets, element)]) + + def as_relational(self, *symbols): + symbols = [_sympify(s) for s in symbols] + if len(symbols) != len(self.sets) or not all( + i.is_Symbol for i in symbols): + raise ValueError( + 'number of symbols must match the number of sets') + return And(*[s.as_relational(i) for s, i in zip(self.sets, symbols)]) + + @property + def _boundary(self): + return Union(*(ProductSet(*(b + b.boundary if i != j else b.boundary + for j, b in enumerate(self.sets))) + for i, a in enumerate(self.sets))) + + @property + def is_iterable(self): + """ + A property method which tests whether a set is iterable or not. + Returns True if set is iterable, otherwise returns False. + + Examples + ======== + + >>> from sympy import FiniteSet, Interval + >>> I = Interval(0, 1) + >>> A = FiniteSet(1, 2, 3, 4, 5) + >>> I.is_iterable + False + >>> A.is_iterable + True + + """ + return all(set.is_iterable for set in self.sets) + + def __iter__(self): + """ + A method which implements is_iterable property method. + If self.is_iterable returns True (both constituent sets are iterable), + then return the Cartesian Product. Otherwise, raise TypeError. + """ + return iproduct(*self.sets) + + @property + def is_empty(self): + return fuzzy_or(s.is_empty for s in self.sets) + + @property + def is_finite_set(self): + all_finite = fuzzy_and(s.is_finite_set for s in self.sets) + return fuzzy_or([self.is_empty, all_finite]) + + @property + def _measure(self): + measure = 1 + for s in self.sets: + measure *= s.measure + return measure + + def _kind(self): + return SetKind(TupleKind(*(i.kind.element_kind for i in self.args))) + + def __len__(self): + return reduce(lambda a, b: a*b, (len(s) for s in self.args)) + + def __bool__(self): + return all(self.sets) + + +class Interval(Set): + """ + Represents a real interval as a Set. + + Usage: + Returns an interval with end points ``start`` and ``end``. + + For ``left_open=True`` (default ``left_open`` is ``False``) the interval + will be open on the left. Similarly, for ``right_open=True`` the interval + will be open on the right. + + Examples + ======== + + >>> from sympy import Symbol, Interval + >>> Interval(0, 1) + Interval(0, 1) + >>> Interval.Ropen(0, 1) + Interval.Ropen(0, 1) + >>> Interval.Ropen(0, 1) + Interval.Ropen(0, 1) + >>> Interval.Lopen(0, 1) + Interval.Lopen(0, 1) + >>> Interval.open(0, 1) + Interval.open(0, 1) + + >>> a = Symbol('a', real=True) + >>> Interval(0, a) + Interval(0, a) + + Notes + ===== + - Only real end points are supported + - ``Interval(a, b)`` with $a > b$ will return the empty set + - Use the ``evalf()`` method to turn an Interval into an mpmath + ``mpi`` interval instance + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Interval_%28mathematics%29 + """ + is_Interval = True + + def __new__(cls, start, end, left_open=False, right_open=False): + + start = _sympify(start) + end = _sympify(end) + left_open = _sympify(left_open) + right_open = _sympify(right_open) + + if not all(isinstance(a, (type(true), type(false))) + for a in [left_open, right_open]): + raise NotImplementedError( + "left_open and right_open can have only true/false values, " + "got %s and %s" % (left_open, right_open)) + + # Only allow real intervals + if fuzzy_not(fuzzy_and(i.is_extended_real for i in (start, end, end-start))): + raise ValueError("Non-real intervals are not supported") + + # evaluate if possible + if is_lt(end, start): + return S.EmptySet + elif (end - start).is_negative: + return S.EmptySet + + if end == start and (left_open or right_open): + return S.EmptySet + if end == start and not (left_open or right_open): + if start is S.Infinity or start is S.NegativeInfinity: + return S.EmptySet + return FiniteSet(end) + + # Make sure infinite interval end points are open. + if start is S.NegativeInfinity: + left_open = true + if end is S.Infinity: + right_open = true + if start == S.Infinity or end == S.NegativeInfinity: + return S.EmptySet + + return Basic.__new__(cls, start, end, left_open, right_open) + + @property + def start(self): + """ + The left end point of the interval. + + This property takes the same value as the ``inf`` property. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).start + 0 + + """ + return self._args[0] + + @property + def end(self): + """ + The right end point of the interval. + + This property takes the same value as the ``sup`` property. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).end + 1 + + """ + return self._args[1] + + @property + def left_open(self): + """ + True if interval is left-open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1, left_open=True).left_open + True + >>> Interval(0, 1, left_open=False).left_open + False + + """ + return self._args[2] + + @property + def right_open(self): + """ + True if interval is right-open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1, right_open=True).right_open + True + >>> Interval(0, 1, right_open=False).right_open + False + + """ + return self._args[3] + + @classmethod + def open(cls, a, b): + """Return an interval including neither boundary.""" + return cls(a, b, True, True) + + @classmethod + def Lopen(cls, a, b): + """Return an interval not including the left boundary.""" + return cls(a, b, True, False) + + @classmethod + def Ropen(cls, a, b): + """Return an interval not including the right boundary.""" + return cls(a, b, False, True) + + @property + def _inf(self): + return self.start + + @property + def _sup(self): + return self.end + + @property + def left(self): + return self.start + + @property + def right(self): + return self.end + + @property + def is_empty(self): + if self.left_open or self.right_open: + cond = self.start >= self.end # One/both bounds open + else: + cond = self.start > self.end # Both bounds closed + return fuzzy_bool(cond) + + @property + def is_finite_set(self): + return self.measure.is_zero + + def _complement(self, other): + if other == S.Reals: + a = Interval(S.NegativeInfinity, self.start, + True, not self.left_open) + b = Interval(self.end, S.Infinity, not self.right_open, True) + return Union(a, b) + + if isinstance(other, FiniteSet): + nums = [m for m in other.args if m.is_number] + if nums == []: + return None + + return Set._complement(self, other) + + @property + def _boundary(self): + finite_points = [p for p in (self.start, self.end) + if abs(p) != S.Infinity] + return FiniteSet(*finite_points) + + def _contains(self, other): + if (not isinstance(other, Expr) or other is S.NaN + or other.is_real is False or other.has(S.ComplexInfinity)): + # if an expression has zoo it will be zoo or nan + # and neither of those is real + return false + + if self.start is S.NegativeInfinity and self.end is S.Infinity: + if other.is_real is not None: + return tfn[other.is_real] + + d = Dummy() + return self.as_relational(d).subs(d, other) + + def as_relational(self, x): + """Rewrite an interval in terms of inequalities and logic operators.""" + x = sympify(x) + if self.right_open: + right = x < self.end + else: + right = x <= self.end + if self.left_open: + left = self.start < x + else: + left = self.start <= x + return And(left, right) + + @property + def _measure(self): + return self.end - self.start + + def _kind(self): + return SetKind(NumberKind) + + def to_mpi(self, prec=53): + return mpi(mpf(self.start._eval_evalf(prec)), + mpf(self.end._eval_evalf(prec))) + + def _eval_evalf(self, prec): + return Interval(self.left._evalf(prec), self.right._evalf(prec), + left_open=self.left_open, right_open=self.right_open) + + def _is_comparable(self, other): + is_comparable = self.start.is_comparable + is_comparable &= self.end.is_comparable + is_comparable &= other.start.is_comparable + is_comparable &= other.end.is_comparable + + return is_comparable + + @property + def is_left_unbounded(self): + """Return ``True`` if the left endpoint is negative infinity. """ + return self.left is S.NegativeInfinity or self.left == Float("-inf") + + @property + def is_right_unbounded(self): + """Return ``True`` if the right endpoint is positive infinity. """ + return self.right is S.Infinity or self.right == Float("+inf") + + def _eval_Eq(self, other): + if not isinstance(other, Interval): + if isinstance(other, FiniteSet): + return false + elif isinstance(other, Set): + return None + return false + + +class Union(Set, LatticeOp): + """ + Represents a union of sets as a :class:`Set`. + + Examples + ======== + + >>> from sympy import Union, Interval + >>> Union(Interval(1, 2), Interval(3, 4)) + Union(Interval(1, 2), Interval(3, 4)) + + The Union constructor will always try to merge overlapping intervals, + if possible. For example: + + >>> Union(Interval(1, 2), Interval(2, 3)) + Interval(1, 3) + + See Also + ======== + + Intersection + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Union_%28set_theory%29 + """ + is_Union = True + + @property + def identity(self): + return S.EmptySet + + @property + def zero(self): + return S.UniversalSet + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + + # flatten inputs to merge intersections and iterables + args = _sympify(args) + + # Reduce sets using known rules + if evaluate: + args = list(cls._new_args_filter(args)) + return simplify_union(args) + + args = list(ordered(args, Set._infimum_key)) + + obj = Basic.__new__(cls, *args) + obj._argset = frozenset(args) + return obj + + @property + def args(self): + return self._args + + def _complement(self, universe): + # DeMorgan's Law + return Intersection(s.complement(universe) for s in self.args) + + @property + def _inf(self): + # We use Min so that sup is meaningful in combination with symbolic + # interval end points. + return Min(*[set.inf for set in self.args]) + + @property + def _sup(self): + # We use Max so that sup is meaningful in combination with symbolic + # end points. + return Max(*[set.sup for set in self.args]) + + @property + def is_empty(self): + return fuzzy_and(set.is_empty for set in self.args) + + @property + def is_finite_set(self): + return fuzzy_and(set.is_finite_set for set in self.args) + + @property + def _measure(self): + # Measure of a union is the sum of the measures of the sets minus + # the sum of their pairwise intersections plus the sum of their + # triple-wise intersections minus ... etc... + + # Sets is a collection of intersections and a set of elementary + # sets which made up those intersections (called "sos" for set of sets) + # An example element might of this list might be: + # ( {A,B,C}, A.intersect(B).intersect(C) ) + + # Start with just elementary sets ( ({A}, A), ({B}, B), ... ) + # Then get and subtract ( ({A,B}, (A int B), ... ) while non-zero + sets = [(FiniteSet(s), s) for s in self.args] + measure = 0 + parity = 1 + while sets: + # Add up the measure of these sets and add or subtract it to total + measure += parity * sum(inter.measure for sos, inter in sets) + + # For each intersection in sets, compute the intersection with every + # other set not already part of the intersection. + sets = ((sos + FiniteSet(newset), newset.intersect(intersection)) + for sos, intersection in sets for newset in self.args + if newset not in sos) + + # Clear out sets with no measure + sets = [(sos, inter) for sos, inter in sets if inter.measure != 0] + + # Clear out duplicates + sos_list = [] + sets_list = [] + for _set in sets: + if _set[0] in sos_list: + continue + else: + sos_list.append(_set[0]) + sets_list.append(_set) + sets = sets_list + + # Flip Parity - next time subtract/add if we added/subtracted here + parity *= -1 + return measure + + def _kind(self): + kinds = tuple(arg.kind for arg in self.args if arg is not S.EmptySet) + if not kinds: + return SetKind() + elif all(i == kinds[0] for i in kinds): + return kinds[0] + else: + return SetKind(UndefinedKind) + + @property + def _boundary(self): + def boundary_of_set(i): + """ The boundary of set i minus interior of all other sets """ + b = self.args[i].boundary + for j, a in enumerate(self.args): + if j != i: + b = b - a.interior + return b + return Union(*map(boundary_of_set, range(len(self.args)))) + + def _contains(self, other): + return Or(*[s.contains(other) for s in self.args]) + + def is_subset(self, other): + return fuzzy_and(s.is_subset(other) for s in self.args) + + def as_relational(self, symbol): + """Rewrite a Union in terms of equalities and logic operators. """ + if (len(self.args) == 2 and + all(isinstance(i, Interval) for i in self.args)): + # optimization to give 3 args as (x > 1) & (x < 5) & Ne(x, 3) + # instead of as 4, ((1 <= x) & (x < 3)) | ((x <= 5) & (3 < x)) + # XXX: This should be ideally be improved to handle any number of + # intervals and also not to assume that the intervals are in any + # particular sorted order. + a, b = self.args + if a.sup == b.inf and a.right_open and b.left_open: + mincond = symbol > a.inf if a.left_open else symbol >= a.inf + maxcond = symbol < b.sup if b.right_open else symbol <= b.sup + necond = Ne(symbol, a.sup) + return And(necond, mincond, maxcond) + return Or(*[i.as_relational(symbol) for i in self.args]) + + @property + def is_iterable(self): + return all(arg.is_iterable for arg in self.args) + + def __iter__(self): + return roundrobin(*(iter(arg) for arg in self.args)) + + +class Intersection(Set, LatticeOp): + """ + Represents an intersection of sets as a :class:`Set`. + + Examples + ======== + + >>> from sympy import Intersection, Interval + >>> Intersection(Interval(1, 3), Interval(2, 4)) + Interval(2, 3) + + We often use the .intersect method + + >>> Interval(1,3).intersect(Interval(2,4)) + Interval(2, 3) + + See Also + ======== + + Union + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Intersection_%28set_theory%29 + """ + is_Intersection = True + + @property + def identity(self): + return S.UniversalSet + + @property + def zero(self): + return S.EmptySet + + def __new__(cls, *args , evaluate=None): + if evaluate is None: + evaluate = global_parameters.evaluate + + # flatten inputs to merge intersections and iterables + args = list(ordered(set(_sympify(args)))) + + # Reduce sets using known rules + if evaluate: + args = list(cls._new_args_filter(args)) + return simplify_intersection(args) + + args = list(ordered(args, Set._infimum_key)) + + obj = Basic.__new__(cls, *args) + obj._argset = frozenset(args) + return obj + + @property + def args(self): + return self._args + + @property + def is_iterable(self): + return any(arg.is_iterable for arg in self.args) + + @property + def is_finite_set(self): + if fuzzy_or(arg.is_finite_set for arg in self.args): + return True + + def _kind(self): + kinds = tuple(arg.kind for arg in self.args if arg is not S.UniversalSet) + if not kinds: + return SetKind(UndefinedKind) + elif all(i == kinds[0] for i in kinds): + return kinds[0] + else: + return SetKind() + + @property + def _inf(self): + raise NotImplementedError() + + @property + def _sup(self): + raise NotImplementedError() + + def _contains(self, other): + return And(*[set.contains(other) for set in self.args]) + + def __iter__(self): + sets_sift = sift(self.args, lambda x: x.is_iterable) + + completed = False + candidates = sets_sift[True] + sets_sift[None] + + finite_candidates, others = [], [] + for candidate in candidates: + length = None + try: + length = len(candidate) + except TypeError: + others.append(candidate) + + if length is not None: + finite_candidates.append(candidate) + finite_candidates.sort(key=len) + + for s in finite_candidates + others: + other_sets = set(self.args) - {s} + other = Intersection(*other_sets, evaluate=False) + completed = True + for x in s: + try: + if x in other: + yield x + except TypeError: + completed = False + if completed: + return + + if not completed: + if not candidates: + raise TypeError("None of the constituent sets are iterable") + raise TypeError( + "The computation had not completed because of the " + "undecidable set membership is found in every candidates.") + + @staticmethod + def _handle_finite_sets(args): + '''Simplify intersection of one or more FiniteSets and other sets''' + + # First separate the FiniteSets from the others + fs_args, others = sift(args, lambda x: x.is_FiniteSet, binary=True) + + # Let the caller handle intersection of non-FiniteSets + if not fs_args: + return + + # Convert to Python sets and build the set of all elements + fs_sets = [set(fs) for fs in fs_args] + all_elements = reduce(lambda a, b: a | b, fs_sets, set()) + + # Extract elements that are definitely in or definitely not in the + # intersection. Here we check contains for all of args. + definite = set() + for e in all_elements: + inall = fuzzy_and(s.contains(e) for s in args) + if inall is True: + definite.add(e) + if inall is not None: + for s in fs_sets: + s.discard(e) + + # At this point all elements in all of fs_sets are possibly in the + # intersection. In some cases this is because they are definitely in + # the intersection of the finite sets but it's not clear if they are + # members of others. We might have {m, n}, {m}, and Reals where we + # don't know if m or n is real. We want to remove n here but it is + # possibly in because it might be equal to m. So what we do now is + # extract the elements that are definitely in the remaining finite + # sets iteratively until we end up with {n}, {}. At that point if we + # get any empty set all remaining elements are discarded. + + fs_elements = reduce(lambda a, b: a | b, fs_sets, set()) + + # Need fuzzy containment testing + fs_symsets = [FiniteSet(*s) for s in fs_sets] + + while fs_elements: + for e in fs_elements: + infs = fuzzy_and(s.contains(e) for s in fs_symsets) + if infs is True: + definite.add(e) + if infs is not None: + for n, s in enumerate(fs_sets): + # Update Python set and FiniteSet + if e in s: + s.remove(e) + fs_symsets[n] = FiniteSet(*s) + fs_elements.remove(e) + break + # If we completed the for loop without removing anything we are + # done so quit the outer while loop + else: + break + + # If any of the sets of remainder elements is empty then we discard + # all of them for the intersection. + if not all(fs_sets): + fs_sets = [set()] + + # Here we fold back the definitely included elements into each fs. + # Since they are definitely included they must have been members of + # each FiniteSet to begin with. We could instead fold these in with a + # Union at the end to get e.g. {3}|({x}&{y}) rather than {3,x}&{3,y}. + if definite: + fs_sets = [fs | definite for fs in fs_sets] + + if fs_sets == [set()]: + return S.EmptySet + + sets = [FiniteSet(*s) for s in fs_sets] + + # Any set in others is redundant if it contains all the elements that + # are in the finite sets so we don't need it in the Intersection + all_elements = reduce(lambda a, b: a | b, fs_sets, set()) + is_redundant = lambda o: all(fuzzy_bool(o.contains(e)) for e in all_elements) + others = [o for o in others if not is_redundant(o)] + + if others: + rest = Intersection(*others) + # XXX: Maybe this shortcut should be at the beginning. For large + # FiniteSets it could much more efficient to process the other + # sets first... + if rest is S.EmptySet: + return S.EmptySet + # Flatten the Intersection + if rest.is_Intersection: + sets.extend(rest.args) + else: + sets.append(rest) + + if len(sets) == 1: + return sets[0] + else: + return Intersection(*sets, evaluate=False) + + def as_relational(self, symbol): + """Rewrite an Intersection in terms of equalities and logic operators""" + return And(*[set.as_relational(symbol) for set in self.args]) + + +class Complement(Set): + r"""Represents the set difference or relative complement of a set with + another set. + + $$A - B = \{x \in A \mid x \notin B\}$$ + + + Examples + ======== + + >>> from sympy import Complement, FiniteSet + >>> Complement(FiniteSet(0, 1, 2), FiniteSet(1)) + {0, 2} + + See Also + ========= + + Intersection, Union + + References + ========== + + .. [1] https://mathworld.wolfram.com/ComplementSet.html + """ + + is_Complement = True + + def __new__(cls, a, b, evaluate=True): + a, b = map(_sympify, (a, b)) + if evaluate: + return Complement.reduce(a, b) + + return Basic.__new__(cls, a, b) + + @staticmethod + def reduce(A, B): + """ + Simplify a :class:`Complement`. + + """ + if B == S.UniversalSet or A.is_subset(B): + return S.EmptySet + + if isinstance(B, Union): + return Intersection(*(s.complement(A) for s in B.args)) + + result = B._complement(A) + if result is not None: + return result + else: + return Complement(A, B, evaluate=False) + + def _contains(self, other): + A = self.args[0] + B = self.args[1] + return And(A.contains(other), Not(B.contains(other))) + + def as_relational(self, symbol): + """Rewrite a complement in terms of equalities and logic + operators""" + A, B = self.args + + A_rel = A.as_relational(symbol) + B_rel = Not(B.as_relational(symbol)) + + return And(A_rel, B_rel) + + def _kind(self): + return self.args[0].kind + + @property + def is_iterable(self): + if self.args[0].is_iterable: + return True + + @property + def is_finite_set(self): + A, B = self.args + a_finite = A.is_finite_set + if a_finite is True: + return True + elif a_finite is False and B.is_finite_set: + return False + + def __iter__(self): + A, B = self.args + for a in A: + if a not in B: + yield a + else: + continue + + +class EmptySet(Set, metaclass=Singleton): + """ + Represents the empty set. The empty set is available as a singleton + as ``S.EmptySet``. + + Examples + ======== + + >>> from sympy import S, Interval + >>> S.EmptySet + EmptySet + + >>> Interval(1, 2).intersect(S.EmptySet) + EmptySet + + See Also + ======== + + UniversalSet + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Empty_set + """ + is_empty = True + is_finite_set = True + is_FiniteSet = True + + @property # type: ignore + @deprecated( + """ + The is_EmptySet attribute of Set objects is deprecated. + Use 's is S.EmptySet" or 's.is_empty' instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-is-emptyset", + ) + def is_EmptySet(self): + return True + + @property + def _measure(self): + return 0 + + def _contains(self, other): + return false + + def as_relational(self, symbol): + return false + + def __len__(self): + return 0 + + def __iter__(self): + return iter([]) + + def _eval_powerset(self): + return FiniteSet(self) + + @property + def _boundary(self): + return self + + def _complement(self, other): + return other + + def _kind(self): + return SetKind() + + def _symmetric_difference(self, other): + return other + + +class UniversalSet(Set, metaclass=Singleton): + """ + Represents the set of all things. + The universal set is available as a singleton as ``S.UniversalSet``. + + Examples + ======== + + >>> from sympy import S, Interval + >>> S.UniversalSet + UniversalSet + + >>> Interval(1, 2).intersect(S.UniversalSet) + Interval(1, 2) + + See Also + ======== + + EmptySet + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Universal_set + """ + + is_UniversalSet = True + is_empty = False + is_finite_set = False + + def _complement(self, other): + return S.EmptySet + + def _symmetric_difference(self, other): + return other + + @property + def _measure(self): + return S.Infinity + + def _kind(self): + return SetKind(UndefinedKind) + + def _contains(self, other): + return true + + def as_relational(self, symbol): + return true + + @property + def _boundary(self): + return S.EmptySet + + +class FiniteSet(Set): + """ + Represents a finite set of Sympy expressions. + + Examples + ======== + + >>> from sympy import FiniteSet, Symbol, Interval, Naturals0 + >>> FiniteSet(1, 2, 3, 4) + {1, 2, 3, 4} + >>> 3 in FiniteSet(1, 2, 3, 4) + True + >>> FiniteSet(1, (1, 2), Symbol('x')) + {1, x, (1, 2)} + >>> FiniteSet(Interval(1, 2), Naturals0, {1, 2}) + FiniteSet({1, 2}, Interval(1, 2), Naturals0) + >>> members = [1, 2, 3, 4] + >>> f = FiniteSet(*members) + >>> f + {1, 2, 3, 4} + >>> f - FiniteSet(2) + {1, 3, 4} + >>> f + FiniteSet(2, 5) + {1, 2, 3, 4, 5} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Finite_set + """ + is_FiniteSet = True + is_iterable = True + is_empty = False + is_finite_set = True + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + if evaluate: + args = list(map(sympify, args)) + + if len(args) == 0: + return S.EmptySet + else: + args = list(map(sympify, args)) + + # keep the form of the first canonical arg + dargs = {} + for i in reversed(list(ordered(args))): + if i.is_Symbol: + dargs[i] = i + else: + try: + dargs[i.as_dummy()] = i + except TypeError: + # e.g. i = class without args like `Interval` + dargs[i] = i + _args_set = set(dargs.values()) + args = list(ordered(_args_set, Set._infimum_key)) + obj = Basic.__new__(cls, *args) + obj._args_set = _args_set + return obj + + + def __iter__(self): + return iter(self.args) + + def _complement(self, other): + if isinstance(other, Interval): + # Splitting in sub-intervals is only done for S.Reals; + # other cases that need splitting will first pass through + # Set._complement(). + nums, syms = [], [] + for m in self.args: + if m.is_number and m.is_real: + nums.append(m) + elif m.is_real == False: + pass # drop non-reals + else: + syms.append(m) # various symbolic expressions + if other == S.Reals and nums != []: + nums.sort() + intervals = [] # Build up a list of intervals between the elements + intervals += [Interval(S.NegativeInfinity, nums[0], True, True)] + for a, b in zip(nums[:-1], nums[1:]): + intervals.append(Interval(a, b, True, True)) # both open + intervals.append(Interval(nums[-1], S.Infinity, True, True)) + if syms != []: + return Complement(Union(*intervals, evaluate=False), + FiniteSet(*syms), evaluate=False) + else: + return Union(*intervals, evaluate=False) + elif nums == []: # no splitting necessary or possible: + if syms: + return Complement(other, FiniteSet(*syms), evaluate=False) + else: + return other + + elif isinstance(other, FiniteSet): + unk = [] + for i in self: + c = sympify(other.contains(i)) + if c is not S.true and c is not S.false: + unk.append(i) + unk = FiniteSet(*unk) + if unk == self: + return + not_true = [] + for i in other: + c = sympify(self.contains(i)) + if c is not S.true: + not_true.append(i) + return Complement(FiniteSet(*not_true), unk) + + return Set._complement(self, other) + + def _contains(self, other): + """ + Tests whether an element, other, is in the set. + + Explanation + =========== + + The actual test is for mathematical equality (as opposed to + syntactical equality). In the worst case all elements of the + set must be checked. + + Examples + ======== + + >>> from sympy import FiniteSet + >>> 1 in FiniteSet(1, 2) + True + >>> 5 in FiniteSet(1, 2) + False + + """ + if other in self._args_set: + return S.true + else: + # evaluate=True is needed to override evaluate=False context; + # we need Eq to do the evaluation + return Or(*[Eq(e, other, evaluate=True) for e in self.args]) + + def _eval_is_subset(self, other): + return fuzzy_and(other._contains(e) for e in self.args) + + @property + def _boundary(self): + return self + + @property + def _inf(self): + return Min(*self) + + @property + def _sup(self): + return Max(*self) + + @property + def measure(self): + return 0 + + def _kind(self): + if not self.args: + return SetKind() + elif all(i.kind == self.args[0].kind for i in self.args): + return SetKind(self.args[0].kind) + else: + return SetKind(UndefinedKind) + + def __len__(self): + return len(self.args) + + def as_relational(self, symbol): + """Rewrite a FiniteSet in terms of equalities and logic operators. """ + return Or(*[Eq(symbol, elem) for elem in self]) + + def compare(self, other): + return (hash(self) - hash(other)) + + def _eval_evalf(self, prec): + dps = prec_to_dps(prec) + return FiniteSet(*[elem.evalf(n=dps) for elem in self]) + + def _eval_simplify(self, **kwargs): + from sympy.simplify import simplify + return FiniteSet(*[simplify(elem, **kwargs) for elem in self]) + + @property + def _sorted_args(self): + return self.args + + def _eval_powerset(self): + return self.func(*[self.func(*s) for s in subsets(self.args)]) + + def _eval_rewrite_as_PowerSet(self, *args, **kwargs): + """Rewriting method for a finite set to a power set.""" + from .powerset import PowerSet + + is2pow = lambda n: bool(n and not n & (n - 1)) + if not is2pow(len(self)): + return None + + fs_test = lambda arg: isinstance(arg, Set) and arg.is_FiniteSet + if not all(fs_test(arg) for arg in args): + return None + + biggest = max(args, key=len) + for arg in subsets(biggest.args): + arg_set = FiniteSet(*arg) + if arg_set not in args: + return None + return PowerSet(biggest) + + def __ge__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return other.is_subset(self) + + def __gt__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_proper_superset(other) + + def __le__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_subset(other) + + def __lt__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_proper_subset(other) + + def __eq__(self, other): + if isinstance(other, (set, frozenset)): + return self._args_set == other + return super().__eq__(other) + + __hash__ : Callable[[Basic], Any] = Basic.__hash__ + +_sympy_converter[set] = lambda x: FiniteSet(*x) +_sympy_converter[frozenset] = lambda x: FiniteSet(*x) + + +class SymmetricDifference(Set): + """Represents the set of elements which are in either of the + sets and not in their intersection. + + Examples + ======== + + >>> from sympy import SymmetricDifference, FiniteSet + >>> SymmetricDifference(FiniteSet(1, 2, 3), FiniteSet(3, 4, 5)) + {1, 2, 4, 5} + + See Also + ======== + + Complement, Union + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Symmetric_difference + """ + + is_SymmetricDifference = True + + def __new__(cls, a, b, evaluate=True): + if evaluate: + return SymmetricDifference.reduce(a, b) + + return Basic.__new__(cls, a, b) + + @staticmethod + def reduce(A, B): + result = B._symmetric_difference(A) + if result is not None: + return result + else: + return SymmetricDifference(A, B, evaluate=False) + + def as_relational(self, symbol): + """Rewrite a symmetric_difference in terms of equalities and + logic operators""" + A, B = self.args + + A_rel = A.as_relational(symbol) + B_rel = B.as_relational(symbol) + + return Xor(A_rel, B_rel) + + @property + def is_iterable(self): + if all(arg.is_iterable for arg in self.args): + return True + + def __iter__(self): + + args = self.args + union = roundrobin(*(iter(arg) for arg in args)) + + for item in union: + count = 0 + for s in args: + if item in s: + count += 1 + + if count % 2 == 1: + yield item + + + +class DisjointUnion(Set): + """ Represents the disjoint union (also known as the external disjoint union) + of a finite number of sets. + + Examples + ======== + + >>> from sympy import DisjointUnion, FiniteSet, Interval, Union, Symbol + >>> A = FiniteSet(1, 2, 3) + >>> B = Interval(0, 5) + >>> DisjointUnion(A, B) + DisjointUnion({1, 2, 3}, Interval(0, 5)) + >>> DisjointUnion(A, B).rewrite(Union) + Union(ProductSet({1, 2, 3}, {0}), ProductSet(Interval(0, 5), {1})) + >>> C = FiniteSet(Symbol('x'), Symbol('y'), Symbol('z')) + >>> DisjointUnion(C, C) + DisjointUnion({x, y, z}, {x, y, z}) + >>> DisjointUnion(C, C).rewrite(Union) + ProductSet({x, y, z}, {0, 1}) + + References + ========== + + https://en.wikipedia.org/wiki/Disjoint_union + """ + + def __new__(cls, *sets): + dj_collection = [] + for set_i in sets: + if isinstance(set_i, Set): + dj_collection.append(set_i) + else: + raise TypeError("Invalid input: '%s', input args \ + to DisjointUnion must be Sets" % set_i) + obj = Basic.__new__(cls, *dj_collection) + return obj + + @property + def sets(self): + return self.args + + @property + def is_empty(self): + return fuzzy_and(s.is_empty for s in self.sets) + + @property + def is_finite_set(self): + all_finite = fuzzy_and(s.is_finite_set for s in self.sets) + return fuzzy_or([self.is_empty, all_finite]) + + @property + def is_iterable(self): + if self.is_empty: + return False + iter_flag = True + for set_i in self.sets: + if not set_i.is_empty: + iter_flag = iter_flag and set_i.is_iterable + return iter_flag + + def _eval_rewrite_as_Union(self, *sets, **kwargs): + """ + Rewrites the disjoint union as the union of (``set`` x {``i``}) + where ``set`` is the element in ``sets`` at index = ``i`` + """ + + dj_union = S.EmptySet + index = 0 + for set_i in sets: + if isinstance(set_i, Set): + cross = ProductSet(set_i, FiniteSet(index)) + dj_union = Union(dj_union, cross) + index = index + 1 + return dj_union + + def _contains(self, element): + """ + ``in`` operator for DisjointUnion + + Examples + ======== + + >>> from sympy import Interval, DisjointUnion + >>> D = DisjointUnion(Interval(0, 1), Interval(0, 2)) + >>> (0.5, 0) in D + True + >>> (0.5, 1) in D + True + >>> (1.5, 0) in D + False + >>> (1.5, 1) in D + True + + Passes operation on to constituent sets + """ + if not isinstance(element, Tuple) or len(element) != 2: + return S.false + + if not element[1].is_Integer: + return S.false + + if element[1] >= len(self.sets) or element[1] < 0: + return S.false + + return self.sets[element[1]]._contains(element[0]) + + def _kind(self): + if not self.args: + return SetKind() + elif all(i.kind == self.args[0].kind for i in self.args): + return self.args[0].kind + else: + return SetKind(UndefinedKind) + + def __iter__(self): + if self.is_iterable: + + iters = [] + for i, s in enumerate(self.sets): + iters.append(iproduct(s, {Integer(i)})) + + return iter(roundrobin(*iters)) + else: + raise ValueError("'%s' is not iterable." % self) + + def __len__(self): + """ + Returns the length of the disjoint union, i.e., the number of elements in the set. + + Examples + ======== + + >>> from sympy import FiniteSet, DisjointUnion, EmptySet + >>> D1 = DisjointUnion(FiniteSet(1, 2, 3, 4), EmptySet, FiniteSet(3, 4, 5)) + >>> len(D1) + 7 + >>> D2 = DisjointUnion(FiniteSet(3, 5, 7), EmptySet, FiniteSet(3, 5, 7)) + >>> len(D2) + 6 + >>> D3 = DisjointUnion(EmptySet, EmptySet) + >>> len(D3) + 0 + + Adds up the lengths of the constituent sets. + """ + + if self.is_finite_set: + size = 0 + for set in self.sets: + size += len(set) + return size + else: + raise ValueError("'%s' is not a finite set." % self) + + +def imageset(*args): + r""" + Return an image of the set under transformation ``f``. + + Explanation + =========== + + If this function cannot compute the image, it returns an + unevaluated ImageSet object. + + .. math:: + \{ f(x) \mid x \in \mathrm{self} \} + + Examples + ======== + + >>> from sympy import S, Interval, imageset, sin, Lambda + >>> from sympy.abc import x + + >>> imageset(x, 2*x, Interval(0, 2)) + Interval(0, 4) + + >>> imageset(lambda x: 2*x, Interval(0, 2)) + Interval(0, 4) + + >>> imageset(Lambda(x, sin(x)), Interval(-2, 1)) + ImageSet(Lambda(x, sin(x)), Interval(-2, 1)) + + >>> imageset(sin, Interval(-2, 1)) + ImageSet(Lambda(x, sin(x)), Interval(-2, 1)) + >>> imageset(lambda y: x + y, Interval(-2, 1)) + ImageSet(Lambda(y, x + y), Interval(-2, 1)) + + Expressions applied to the set of Integers are simplified + to show as few negatives as possible and linear expressions + are converted to a canonical form. If this is not desirable + then the unevaluated ImageSet should be used. + + >>> imageset(x, -2*x + 5, S.Integers) + ImageSet(Lambda(x, 2*x + 1), Integers) + + See Also + ======== + + sympy.sets.fancysets.ImageSet + + """ + from .fancysets import ImageSet + from .setexpr import set_function + + if len(args) < 2: + raise ValueError('imageset expects at least 2 args, got: %s' % len(args)) + + if isinstance(args[0], (Symbol, tuple)) and len(args) > 2: + f = Lambda(args[0], args[1]) + set_list = args[2:] + else: + f = args[0] + set_list = args[1:] + + if isinstance(f, Lambda): + pass + elif callable(f): + nargs = getattr(f, 'nargs', {}) + if nargs: + if len(nargs) != 1: + raise NotImplementedError(filldedent(''' + This function can take more than 1 arg + but the potentially complicated set input + has not been analyzed at this point to + know its dimensions. TODO + ''')) + N = nargs.args[0] + if N == 1: + s = 'x' + else: + s = [Symbol('x%i' % i) for i in range(1, N + 1)] + else: + s = inspect.signature(f).parameters + + dexpr = _sympify(f(*[Dummy() for i in s])) + var = tuple(uniquely_named_symbol( + Symbol(i), dexpr) for i in s) + f = Lambda(var, f(*var)) + else: + raise TypeError(filldedent(''' + expecting lambda, Lambda, or FunctionClass, + not \'%s\'.''' % func_name(f))) + + if any(not isinstance(s, Set) for s in set_list): + name = [func_name(s) for s in set_list] + raise ValueError( + 'arguments after mapping should be sets, not %s' % name) + + if len(set_list) == 1: + set = set_list[0] + try: + # TypeError if arg count != set dimensions + r = set_function(f, set) + if r is None: + raise TypeError + if not r: + return r + except TypeError: + r = ImageSet(f, set) + if isinstance(r, ImageSet): + f, set = r.args + + if f.variables[0] == f.expr: + return set + + if isinstance(set, ImageSet): + # XXX: Maybe this should just be: + # f2 = set.lambda + # fun = Lambda(f2.signature, f(*f2.expr)) + # return imageset(fun, *set.base_sets) + if len(set.lamda.variables) == 1 and len(f.variables) == 1: + x = set.lamda.variables[0] + y = f.variables[0] + return imageset( + Lambda(x, f.expr.subs(y, set.lamda.expr)), *set.base_sets) + + if r is not None: + return r + + return ImageSet(f, *set_list) + + +def is_function_invertible_in_set(func, setv): + """ + Checks whether function ``func`` is invertible when the domain is + restricted to set ``setv``. + """ + # Functions known to always be invertible: + if func in (exp, log): + return True + u = Dummy("u") + fdiff = func(u).diff(u) + # monotonous functions: + # TODO: check subsets (`func` in `setv`) + if (fdiff > 0) == True or (fdiff < 0) == True: + return True + # TODO: support more + return None + + +def simplify_union(args): + """ + Simplify a :class:`Union` using known rules. + + Explanation + =========== + + We first start with global rules like 'Merge all FiniteSets' + + Then we iterate through all pairs and ask the constituent sets if they + can simplify themselves with any other constituent. This process depends + on ``union_sets(a, b)`` functions. + """ + from sympy.sets.handlers.union import union_sets + + # ===== Global Rules ===== + if not args: + return S.EmptySet + + for arg in args: + if not isinstance(arg, Set): + raise TypeError("Input args to Union must be Sets") + + # Merge all finite sets + finite_sets = [x for x in args if x.is_FiniteSet] + if len(finite_sets) > 1: + a = (x for set in finite_sets for x in set) + finite_set = FiniteSet(*a) + args = [finite_set] + [x for x in args if not x.is_FiniteSet] + + # ===== Pair-wise Rules ===== + # Here we depend on rules built into the constituent sets + args = set(args) + new_args = True + while new_args: + for s in args: + new_args = False + for t in args - {s}: + new_set = union_sets(s, t) + # This returns None if s does not know how to intersect + # with t. Returns the newly intersected set otherwise + if new_set is not None: + if not isinstance(new_set, set): + new_set = {new_set} + new_args = (args - {s, t}).union(new_set) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return Union(*args, evaluate=False) + + +def simplify_intersection(args): + """ + Simplify an intersection using known rules. + + Explanation + =========== + + We first start with global rules like + 'if any empty sets return empty set' and 'distribute any unions' + + Then we iterate through all pairs and ask the constituent sets if they + can simplify themselves with any other constituent + """ + + # ===== Global Rules ===== + if not args: + return S.UniversalSet + + for arg in args: + if not isinstance(arg, Set): + raise TypeError("Input args to Union must be Sets") + + # If any EmptySets return EmptySet + if S.EmptySet in args: + return S.EmptySet + + # Handle Finite sets + rv = Intersection._handle_finite_sets(args) + + if rv is not None: + return rv + + # If any of the sets are unions, return a Union of Intersections + for s in args: + if s.is_Union: + other_sets = set(args) - {s} + if len(other_sets) > 0: + other = Intersection(*other_sets) + return Union(*(Intersection(arg, other) for arg in s.args)) + else: + return Union(*s.args) + + for s in args: + if s.is_Complement: + args.remove(s) + other_sets = args + [s.args[0]] + return Complement(Intersection(*other_sets), s.args[1]) + + from sympy.sets.handlers.intersection import intersection_sets + + # At this stage we are guaranteed not to have any + # EmptySets, FiniteSets, or Unions in the intersection + + # ===== Pair-wise Rules ===== + # Here we depend on rules built into the constituent sets + args = set(args) + new_args = True + while new_args: + for s in args: + new_args = False + for t in args - {s}: + new_set = intersection_sets(s, t) + # This returns None if s does not know how to intersect + # with t. Returns the newly intersected set otherwise + + if new_set is not None: + new_args = (args - {s, t}).union({new_set}) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return Intersection(*args, evaluate=False) + + +def _handle_finite_sets(op, x, y, commutative): + # Handle finite sets: + fs_args, other = sift([x, y], lambda x: isinstance(x, FiniteSet), binary=True) + if len(fs_args) == 2: + return FiniteSet(*[op(i, j) for i in fs_args[0] for j in fs_args[1]]) + elif len(fs_args) == 1: + sets = [_apply_operation(op, other[0], i, commutative) for i in fs_args[0]] + return Union(*sets) + else: + return None + + +def _apply_operation(op, x, y, commutative): + from .fancysets import ImageSet + d = Dummy('d') + + out = _handle_finite_sets(op, x, y, commutative) + if out is None: + out = op(x, y) + + if out is None and commutative: + out = op(y, x) + if out is None: + _x, _y = symbols("x y") + if isinstance(x, Set) and not isinstance(y, Set): + out = ImageSet(Lambda(d, op(d, y)), x).doit() + elif not isinstance(x, Set) and isinstance(y, Set): + out = ImageSet(Lambda(d, op(x, d)), y).doit() + else: + out = ImageSet(Lambda((_x, _y), op(_x, _y)), x, y) + return out + + +def set_add(x, y): + from sympy.sets.handlers.add import _set_add + return _apply_operation(_set_add, x, y, commutative=True) + + +def set_sub(x, y): + from sympy.sets.handlers.add import _set_sub + return _apply_operation(_set_sub, x, y, commutative=False) + + +def set_mul(x, y): + from sympy.sets.handlers.mul import _set_mul + return _apply_operation(_set_mul, x, y, commutative=True) + + +def set_div(x, y): + from sympy.sets.handlers.mul import _set_div + return _apply_operation(_set_div, x, y, commutative=False) + + +def set_pow(x, y): + from sympy.sets.handlers.power import _set_pow + return _apply_operation(_set_pow, x, y, commutative=False) + + +def set_function(f, x): + from sympy.sets.handlers.functions import _set_function + return _set_function(f, x) + + +class SetKind(Kind): + """ + SetKind is kind for all Sets + + Every instance of Set will have kind ``SetKind`` parametrised by the kind + of the elements of the ``Set``. The kind of the elements might be + ``NumberKind``, or ``TupleKind`` or something else. When not all elements + have the same kind then the kind of the elements will be given as + ``UndefinedKind``. + + Parameters + ========== + + element_kind: Kind (optional) + The kind of the elements of the set. In a well defined set all elements + will have the same kind. Otherwise the kind should + :class:`sympy.core.kind.UndefinedKind`. The ``element_kind`` argument is optional but + should only be omitted in the case of ``EmptySet`` whose kind is simply + ``SetKind()`` + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(1, 2).kind + SetKind(NumberKind) + >>> Interval(1,2).kind.element_kind + NumberKind + + See Also + ======== + + sympy.core.kind.NumberKind + sympy.matrices.kind.MatrixKind + sympy.core.containers.TupleKind + """ + def __new__(cls, element_kind=None): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + if not self.element_kind: + return "SetKind()" + else: + return "SetKind(%s)" % self.element_kind diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__init__.py b/phivenv/Lib/site-packages/sympy/sets/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d75aff9846492de85f039cf1503dde1d6a619545 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b6c871f25e62ff6e1fc5bfac9b09dadcbf8f84 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d944665e63df94277d705411687711aec56b13c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_fancysets.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_fancysets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a70f9e7dae64ff1f3feec1aecf283319f5db46f9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_fancysets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b8363f709ada83e1b37afea3cb592da23324380 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cdaccf6d634c830bc99eef1d2f09033c777a1b8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb65f5efdafbbe9638750818fd1f686672501786 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_sets.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_sets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce2290e22e6cafb1160607ad3cc959be26f2d28 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/sets/tests/__pycache__/test_sets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_conditionset.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..4818246f306afd46a09a2cbea1faab858a9e7806 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_conditionset.py @@ -0,0 +1,294 @@ +from sympy.core.expr import unchanged +from sympy.sets import (ConditionSet, Intersection, FiniteSet, + EmptySet, Union, Contains, ImageSet) +from sympy.sets.sets import SetKind +from sympy.core.function import (Function, Lambda) +from sympy.core.mod import Mod +from sympy.core.kind import NumberKind +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.trigonometric import (asin, sin) +from sympy.logic.boolalg import And +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.sets.sets import Interval +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +f = Function('f') + + +def test_CondSet(): + sin_sols_principal = ConditionSet(x, Eq(sin(x), 0), + Interval(0, 2*pi, False, True)) + assert pi in sin_sols_principal + assert pi/2 not in sin_sols_principal + assert 3*pi not in sin_sols_principal + assert oo not in sin_sols_principal + assert 5 in ConditionSet(x, x**2 > 4, S.Reals) + assert 1 not in ConditionSet(x, x**2 > 4, S.Reals) + # in this case, 0 is not part of the base set so + # it can't be in any subset selected by the condition + assert 0 not in ConditionSet(x, y > 5, Interval(1, 7)) + # since 'in' requires a true/false, the following raises + # an error because the given value provides no information + # for the condition to evaluate (since the condition does + # not depend on the dummy symbol): the result is `y > 5`. + # In this case, ConditionSet is just acting like + # Piecewise((Interval(1, 7), y > 5), (S.EmptySet, True)). + raises(TypeError, lambda: 6 in ConditionSet(x, y > 5, + Interval(1, 7))) + + X = MatrixSymbol('X', 2, 2) + matrix_set = ConditionSet(X, Eq(X*Matrix([[1, 1], [1, 1]]), X)) + Y = Matrix([[0, 0], [0, 0]]) + assert matrix_set.contains(Y).doit() is S.true + Z = Matrix([[1, 2], [3, 4]]) + assert matrix_set.contains(Z).doit() is S.false + + assert isinstance(ConditionSet(x, x < 1, {x, y}).base_set, + FiniteSet) + raises(TypeError, lambda: ConditionSet(x, x + 1, {x, y})) + raises(TypeError, lambda: ConditionSet(x, x, 1)) + + I = S.Integers + U = S.UniversalSet + C = ConditionSet + assert C(x, False, I) is S.EmptySet + assert C(x, True, I) is I + assert C(x, x < 1, C(x, x < 2, I) + ) == C(x, (x < 1) & (x < 2), I) + assert C(y, y < 1, C(x, y < 2, I) + ) == C(x, (x < 1) & (y < 2), I), C(y, y < 1, C(x, y < 2, I)) + assert C(y, y < 1, C(x, x < 2, I) + ) == C(y, (y < 1) & (y < 2), I) + assert C(y, y < 1, C(x, y < x, I) + ) == C(x, (x < 1) & (y < x), I) + assert unchanged(C, y, x < 1, C(x, y < x, I)) + assert ConditionSet(x, x < 1).base_set is U + # arg checking is not done at instantiation but this + # will raise an error when containment is tested + assert ConditionSet((x,), x < 1).base_set is U + + c = ConditionSet((x, y), x < y, I**2) + assert (1, 2) in c + assert (1, pi) not in c + + raises(TypeError, lambda: C(x, x > 1, C((x, y), x > 1, I**2))) + # signature mismatch since only 3 args are accepted + raises(TypeError, lambda: C((x, y), x + y < 2, U, U)) + + +def test_CondSet_intersect(): + input_conditionset = ConditionSet(x, x**2 > 4, Interval(1, 4, False, + False)) + other_domain = Interval(0, 3, False, False) + output_conditionset = ConditionSet(x, x**2 > 4, Interval( + 1, 3, False, False)) + assert Intersection(input_conditionset, other_domain + ) == output_conditionset + + +def test_issue_9849(): + assert ConditionSet(x, Eq(x, x), S.Naturals + ) is S.Naturals + assert ConditionSet(x, Eq(Abs(sin(x)), -1), S.Naturals + ) == S.EmptySet + + +def test_simplified_FiniteSet_in_CondSet(): + assert ConditionSet(x, And(x < 1, x > -3), FiniteSet(0, 1, 2) + ) == FiniteSet(0) + assert ConditionSet(x, x < 0, FiniteSet(0, 1, 2)) == EmptySet + assert ConditionSet(x, And(x < -3), EmptySet) == EmptySet + y = Symbol('y') + assert (ConditionSet(x, And(x > 0), FiniteSet(-1, 0, 1, y)) == + Union(FiniteSet(1), ConditionSet(x, And(x > 0), FiniteSet(y)))) + assert (ConditionSet(x, Eq(Mod(x, 3), 1), FiniteSet(1, 4, 2, y)) == + Union(FiniteSet(1, 4), ConditionSet(x, Eq(Mod(x, 3), 1), + FiniteSet(y)))) + + +def test_free_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).free_symbols == {y, z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(z) + ).free_symbols == {z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, z) + ).free_symbols == {x, z} + assert ConditionSet(x, Eq(x, 0), ImageSet(Lambda(y, y**2), + S.Integers)).free_symbols == set() + + +def test_bound_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).bound_symbols == [x] + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, y) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ImageSet(Lambda(y, y**2), S.Integers) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ConditionSet(y, y > 1, S.Integers) + ).bound_symbols == [x] + + +def test_as_dummy(): + _0, _1 = symbols('_0 _1') + assert ConditionSet(x, x < 1, Interval(y, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(y, oo)) + assert ConditionSet(x, x < 1, Interval(x, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(x, oo)) + assert ConditionSet(x, x < 1, ImageSet(Lambda(y, y**2), S.Integers) + ).as_dummy() == ConditionSet( + _0, _0 < 1, ImageSet(Lambda(_0, _0**2), S.Integers)) + e = ConditionSet((x, y), x <= y, S.Reals**2) + assert e.bound_symbols == [x, y] + assert e.as_dummy() == ConditionSet((_0, _1), _0 <= _1, S.Reals**2) + assert e.as_dummy() == ConditionSet((y, x), y <= x, S.Reals**2 + ).as_dummy() + + +def test_subs_CondSet(): + s = FiniteSet(z, y) + c = ConditionSet(x, x < 2, s) + assert c.subs(x, y) == c + assert c.subs(z, y) == ConditionSet(x, x < 2, FiniteSet(y)) + assert c.xreplace({x: y}) == ConditionSet(y, y < 2, s) + + assert ConditionSet(x, x < y, s + ).subs(y, w) == ConditionSet(x, x < w, s.subs(y, w)) + # if the user uses assumptions that cause the condition + # to evaluate, that can't be helped from SymPy's end + n = Symbol('n', negative=True) + assert ConditionSet(n, 0 < n, S.Integers) is S.EmptySet + p = Symbol('p', positive=True) + assert ConditionSet(n, n < y, S.Integers + ).subs(n, x) == ConditionSet(n, n < y, S.Integers) + raises(ValueError, lambda: ConditionSet( + x + 1, x < 1, S.Integers)) + assert ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) == Interval(-5, 5), ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) + assert ConditionSet( + n, n < x, Interval(-oo, 0)).subs(x, p + ) == Interval(-oo, 0) + + assert ConditionSet(f(x), f(x) < 1, {w, z} + ).subs(f(x), y) == ConditionSet(f(x), f(x) < 1, {w, z}) + + # issue 17341 + k = Symbol('k') + img1 = ImageSet(Lambda(k, 2*k*pi + asin(y)), S.Integers) + img2 = ImageSet(Lambda(k, 2*k*pi + asin(S.One/3)), S.Integers) + assert ConditionSet(x, Contains( + y, Interval(-1,1)), img1).subs(y, S.One/3).dummy_eq(img2) + + assert (0, 1) in ConditionSet((x, y), x + y < 3, S.Integers**2) + + raises(TypeError, lambda: ConditionSet(n, n < -10, Interval(0, 10))) + + +def test_subs_CondSet_tebr(): + with warns_deprecated_sympy(): + assert ConditionSet((x, y), {x + 1, x + y}, S.Reals**2) == \ + ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + + +def test_dummy_eq(): + C = ConditionSet + I = S.Integers + c = C(x, x < 1, I) + assert c.dummy_eq(C(y, y < 1, I)) + assert c.dummy_eq(1) == False + assert c.dummy_eq(C(x, x < 1, S.Reals)) == False + + c1 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c2 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c3 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Complexes**2) + assert c1.dummy_eq(c2) + assert c1.dummy_eq(c3) is False + assert c.dummy_eq(c1) is False + assert c1.dummy_eq(c) is False + + # issue 19496 + m = Symbol('m') + n = Symbol('n') + a = Symbol('a') + d1 = ImageSet(Lambda(m, m*pi), S.Integers) + d2 = ImageSet(Lambda(n, n*pi), S.Integers) + c1 = ConditionSet(x, Ne(a, 0), d1) + c2 = ConditionSet(x, Ne(a, 0), d2) + assert c1.dummy_eq(c2) + + +def test_contains(): + assert 6 in ConditionSet(x, x > 5, Interval(1, 7)) + assert (8 in ConditionSet(x, y > 5, Interval(1, 7))) is False + # `in` should give True or False; in this case there is not + # enough information for that result + raises(TypeError, + lambda: 6 in ConditionSet(x, y > 5, Interval(1, 7))) + # here, there is enough information but the comparison is + # not defined + raises(TypeError, lambda: 0 in ConditionSet(x, 1/x >= 0, S.Reals)) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(6) == (y > 5) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(8) is S.false + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(w) == And(Contains(w, Interval(1, 7)), y > 5) + # This returns an unevaluated Contains object + # because 1/0 should not be defined for 1 and 0 in the context of + # reals. + assert ConditionSet(x, 1/x >= 0, S.Reals).contains(0) == \ + Contains(0, ConditionSet(x, 1/x >= 0, S.Reals), evaluate=False) + c = ConditionSet((x, y), x + y > 1, S.Integers**2) + assert not c.contains(1) + assert c.contains((2, 1)) + assert not c.contains((0, 1)) + c = ConditionSet((w, (x, y)), w + x + y > 1, S.Integers*S.Integers**2) + assert not c.contains(1) + assert not c.contains((1, 2)) + assert not c.contains(((1, 2), 3)) + assert not c.contains(((1, 2), (3, 4))) + assert c.contains((1, (3, 4))) + + +def test_as_relational(): + assert ConditionSet((x, y), x > 1, S.Integers**2).as_relational((x, y) + ) == (x > 1) & Contains(x, S.Integers) & Contains(y, S.Integers) + assert ConditionSet(x, x > 1, S.Integers).as_relational(x + ) == Contains(x, S.Integers) & (x > 1) + + +def test_flatten(): + """Tests whether there is basic denesting functionality""" + inner = ConditionSet(x, sin(x) + x > 0) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(y, sin(y) + y > 0) + outer = ConditionSet(x, Contains(y, inner), S.Reals) + assert outer != ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(x, sin(x) + x > 0).intersect(Interval(-1, 1)) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, Interval(-1, 1)) + + +def test_duplicate(): + from sympy.core.function import BadSignatureError + # test coverage for line 95 in conditionset.py, check for duplicates in symbols + dup = symbols('a,a') + raises(BadSignatureError, lambda: ConditionSet(dup, x < 0)) + + +def test_SetKind_ConditionSet(): + assert ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)).kind is SetKind(NumberKind) + assert ConditionSet(x, x < 0).kind is SetKind(NumberKind) diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_contains.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6b98940946f98bf377aad6810f5b32eb6dd069 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_contains.py @@ -0,0 +1,52 @@ +from sympy.core.expr import unchanged +from sympy.core.numbers import oo +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.testing.pytest import raises + + +def test_contains_basic(): + raises(TypeError, lambda: Contains(S.Integers, 1)) + assert Contains(2, S.Integers) is S.true + assert Contains(-2, S.Naturals) is S.false + + i = Symbol('i', integer=True) + assert Contains(i, S.Naturals) == Contains(i, S.Naturals, evaluate=False) + + +def test_issue_6194(): + x = Symbol('x') + assert unchanged(Contains, x, Interval(0, 1)) + assert Interval(0, 1).contains(x) == (S.Zero <= x) & (x <= 1) + assert Contains(x, FiniteSet(0)) != S.false + assert Contains(x, Interval(1, 1)) != S.false + assert Contains(x, S.Integers) != S.false + + +def test_issue_10326(): + assert Contains(oo, Interval(-oo, oo)) == False + assert Contains(-oo, Interval(-oo, oo)) == False + + +def test_binary_symbols(): + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + assert Contains(x, FiniteSet(y, Eq(z, True)) + ).binary_symbols == {y, z} + + +def test_as_set(): + x = Symbol('x') + y = Symbol('y') + assert Contains(x, FiniteSet(y)).as_set() == FiniteSet(y) + assert Contains(x, S.Integers).as_set() == S.Integers + assert Contains(x, S.Reals).as_set() == S.Reals + + +def test_type_error(): + # Pass in a parameter not of type "set" + raises(TypeError, lambda: Contains(2, None)) diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_fancysets.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..b23c2a99fce0af5bfe7c667185465ee417de19ce --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_fancysets.py @@ -0,0 +1,1313 @@ + +from sympy.core.expr import unchanged +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range, normalize_theta_set, + ComplexRegion) +from sympy.sets.sets import (FiniteSet, Interval, Union, imageset, + Intersection, ProductSet, SetKind) +from sympy.sets.conditionset import ConditionSet +from sympy.simplify.simplify import simplify +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import And +from sympy.matrices.dense import eye +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, y, t, z +from sympy.core.mod import Mod + +import itertools + + +def test_naturals(): + N = S.Naturals + assert 5 in N + assert -5 not in N + assert 5.5 not in N + ni = iter(N) + a, b, c, d = next(ni), next(ni), next(ni), next(ni) + assert (a, b, c, d) == (1, 2, 3, 4) + assert isinstance(a, Basic) + + assert N.intersect(Interval(-5, 5)) == Range(1, 6) + assert N.intersect(Interval(-5, 5, True, True)) == Range(1, 5) + + assert N.boundary == N + assert N.is_open == False + assert N.is_closed == True + + assert N.inf == 1 + assert N.sup is oo + assert not N.contains(oo) + for s in (S.Naturals0, S.Naturals): + assert s.intersection(S.Reals) is s + assert s.is_subset(S.Reals) + + assert N.as_relational(x) == And(Eq(floor(x), x), x >= 1, x < oo) + + +def test_naturals0(): + N = S.Naturals0 + assert 0 in N + assert -1 not in N + assert next(iter(N)) == 0 + assert not N.contains(oo) + assert N.contains(sin(x)) == Contains(sin(x), N) + + +def test_integers(): + Z = S.Integers + assert 5 in Z + assert -5 in Z + assert 5.5 not in Z + assert not Z.contains(oo) + assert not Z.contains(-oo) + + zi = iter(Z) + a, b, c, d = next(zi), next(zi), next(zi), next(zi) + assert (a, b, c, d) == (0, 1, -1, 2) + assert isinstance(a, Basic) + + assert Z.intersect(Interval(-5, 5)) == Range(-5, 6) + assert Z.intersect(Interval(-5, 5, True, True)) == Range(-4, 5) + assert Z.intersect(Interval(5, S.Infinity)) == Range(5, S.Infinity) + assert Z.intersect(Interval.Lopen(5, S.Infinity)) == Range(6, S.Infinity) + + assert Z.inf is -oo + assert Z.sup is oo + + assert Z.boundary == Z + assert Z.is_open == False + assert Z.is_closed == True + + assert Z.as_relational(x) == And(Eq(floor(x), x), -oo < x, x < oo) + + +def test_ImageSet(): + raises(ValueError, lambda: ImageSet(x, S.Integers)) + assert ImageSet(Lambda(x, 1), S.Integers) == FiniteSet(1) + assert ImageSet(Lambda(x, y), S.Integers) == {y} + assert ImageSet(Lambda(x, 1), S.EmptySet) == S.EmptySet + empty = Intersection(FiniteSet(log(2)/pi), S.Integers) + assert unchanged(ImageSet, Lambda(x, 1), empty) # issue #17471 + squares = ImageSet(Lambda(x, x**2), S.Naturals) + assert 4 in squares + assert 5 not in squares + assert FiniteSet(*range(10)).intersect(squares) == FiniteSet(1, 4, 9) + + assert 16 not in squares.intersect(Interval(0, 10)) + + si = iter(squares) + a, b, c, d = next(si), next(si), next(si), next(si) + assert (a, b, c, d) == (1, 4, 9, 16) + + harmonics = ImageSet(Lambda(x, 1/x), S.Naturals) + assert Rational(1, 5) in harmonics + assert Rational(.25) in harmonics + assert harmonics.contains(.25) == Contains( + 0.25, ImageSet(Lambda(x, 1/x), S.Naturals), evaluate=False) + assert Rational(.3) not in harmonics + assert (1, 2) not in harmonics + + assert harmonics.is_iterable + + assert imageset(x, -x, Interval(0, 1)) == Interval(-1, 0) + + assert ImageSet(Lambda(x, x**2), Interval(0, 2)).doit() == Interval(0, 4) + assert ImageSet(Lambda((x, y), 2*x), {4}, {3}).doit() == FiniteSet(8) + assert (ImageSet(Lambda((x, y), x+y), {1, 2, 3}, {10, 20, 30}).doit() == + FiniteSet(11, 12, 13, 21, 22, 23, 31, 32, 33)) + + c = Interval(1, 3) * Interval(1, 3) + assert Tuple(2, 6) in ImageSet(Lambda(((x, y),), (x, 2*y)), c) + assert Tuple(2, S.Half) in ImageSet(Lambda(((x, y),), (x, 1/y)), c) + assert Tuple(2, -2) not in ImageSet(Lambda(((x, y),), (x, y**2)), c) + assert Tuple(2, -2) in ImageSet(Lambda(((x, y),), (x, -2)), c) + c3 = ProductSet(Interval(3, 7), Interval(8, 11), Interval(5, 9)) + assert Tuple(8, 3, 9) in ImageSet(Lambda(((t, y, x),), (y, t, x)), c3) + assert Tuple(Rational(1, 8), 3, 9) in ImageSet(Lambda(((t, y, x),), (1/y, t, x)), c3) + assert 2/pi not in ImageSet(Lambda(((x, y),), 2/x), c) + assert 2/S(100) not in ImageSet(Lambda(((x, y),), 2/x), c) + assert Rational(2, 3) in ImageSet(Lambda(((x, y),), 2/x), c) + + S1 = imageset(lambda x, y: x + y, S.Integers, S.Naturals) + assert S1.base_pset == ProductSet(S.Integers, S.Naturals) + assert S1.base_sets == (S.Integers, S.Naturals) + + # Passing a set instead of a FiniteSet shouldn't raise + assert unchanged(ImageSet, Lambda(x, x**2), {1, 2, 3}) + + S2 = ImageSet(Lambda(((x, y),), x+y), {(1, 2), (3, 4)}) + assert 3 in S2.doit() + # FIXME: This doesn't yet work: + #assert 3 in S2 + assert S2._contains(3) is None + + raises(TypeError, lambda: ImageSet(Lambda(x, x**2), 1)) + + +def test_image_is_ImageSet(): + assert isinstance(imageset(x, sqrt(sin(x)), Range(5)), ImageSet) + + +def test_halfcircle(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + + assert (1, 0) in halfcircle + assert (0, -1) not in halfcircle + assert (0, 0) in halfcircle + assert halfcircle._contains((r, 0)) is None + assert not halfcircle.is_iterable + + +@XFAIL +def test_halfcircle_fail(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + assert (r, 2*pi) not in halfcircle + + +def test_ImageSet_iterator_not_injective(): + L = Lambda(x, x - x % 2) # produces 0, 2, 2, 4, 4, 6, 6, ... + evens = ImageSet(L, S.Naturals) + i = iter(evens) + # No repeats here + assert (next(i), next(i), next(i), next(i)) == (0, 2, 4, 6) + + +def test_inf_Range_len(): + raises(ValueError, lambda: len(Range(0, oo, 2))) + assert Range(0, oo, 2).size is S.Infinity + assert Range(0, -oo, -2).size is S.Infinity + assert Range(oo, 0, -2).size is S.Infinity + assert Range(-oo, 0, 2).size is S.Infinity + + +def test_Range_set(): + empty = Range(0) + + assert Range(5) == Range(0, 5) == Range(0, 5, 1) + + r = Range(10, 20, 2) + assert 12 in r + assert 8 not in r + assert 11 not in r + assert 30 not in r + + assert list(Range(0, 5)) == list(range(5)) + assert list(Range(5, 0, -1)) == list(range(5, 0, -1)) + + + assert Range(5, 15).sup == 14 + assert Range(5, 15).inf == 5 + assert Range(15, 5, -1).sup == 15 + assert Range(15, 5, -1).inf == 6 + assert Range(10, 67, 10).sup == 60 + assert Range(60, 7, -10).inf == 10 + + assert len(Range(10, 38, 10)) == 3 + + assert Range(0, 0, 5) == empty + assert Range(oo, oo, 1) == empty + assert Range(oo, 1, 1) == empty + assert Range(-oo, 1, -1) == empty + assert Range(1, oo, -1) == empty + assert Range(1, -oo, 1) == empty + assert Range(1, -4, oo) == empty + ip = symbols('ip', positive=True) + assert Range(0, ip, -1) == empty + assert Range(0, -ip, 1) == empty + assert Range(1, -4, -oo) == Range(1, 2) + assert Range(1, 4, oo) == Range(1, 2) + assert Range(-oo, oo).size == oo + assert Range(oo, -oo, -1).size == oo + raises(ValueError, lambda: Range(-oo, oo, 2)) + raises(ValueError, lambda: Range(x, pi, y)) + raises(ValueError, lambda: Range(x, y, 0)) + + assert 5 in Range(0, oo, 5) + assert -5 in Range(-oo, 0, 5) + assert oo not in Range(0, oo) + ni = symbols('ni', integer=False) + assert ni not in Range(oo) + u = symbols('u', integer=None) + assert Range(oo).contains(u) is not False + inf = symbols('inf', infinite=True) + assert inf not in Range(-oo, oo) + raises(ValueError, lambda: Range(0, oo, 2)[-1]) + raises(ValueError, lambda: Range(0, -oo, -2)[-1]) + assert Range(-oo, 1, 1)[-1] is S.Zero + assert Range(oo, 1, -1)[-1] == 2 + assert inf not in Range(oo) + assert Range(1, 10, 1)[-1] == 9 + assert all(i.is_Integer for i in Range(0, -1, 1)) + it = iter(Range(-oo, 0, 2)) + raises(TypeError, lambda: next(it)) + + assert empty.intersect(S.Integers) == empty + assert Range(-1, 10, 1).intersect(S.Complexes) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Reals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Rationals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Integers) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals) == Range(1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals0) == Range(0, 10, 1) + + # test slicing + assert Range(1, 10, 1)[5] == 6 + assert Range(1, 12, 2)[5] == 11 + assert Range(1, 10, 1)[-1] == 9 + assert Range(1, 10, 3)[-1] == 7 + raises(ValueError, lambda: Range(oo,0,-1)[1:3:0]) + raises(ValueError, lambda: Range(oo,0,-1)[:1]) + raises(ValueError, lambda: Range(1, oo)[-2]) + raises(ValueError, lambda: Range(-oo, 1)[2]) + raises(IndexError, lambda: Range(10)[-20]) + raises(IndexError, lambda: Range(10)[20]) + raises(ValueError, lambda: Range(2, -oo, -2)[2:2:0]) + assert Range(2, -oo, -2)[2:2:2] == empty + assert Range(2, -oo, -2)[:2:2] == Range(2, -2, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:2]) + assert Range(-oo, 4, 2)[::-2] == Range(2, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[::2]) + assert Range(oo, 2, -2)[::] == Range(oo, 2, -2) + assert Range(-oo, 4, 2)[:-2:-2] == Range(2, 0, -4) + assert Range(-oo, 4, 2)[:-2:2] == Range(-oo, 0, 4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:-2]) + assert Range(-oo, 4, 2)[-2::-2] == Range(0, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[-2:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[0::2]) + assert Range(oo, 2, -2)[0::] == Range(oo, 2, -2) + raises(ValueError, lambda: Range(-oo, 4, 2)[0:-2:2]) + assert Range(oo, 2, -2)[0:-2:] == Range(oo, 6, -2) + raises(ValueError, lambda: Range(oo, 2, -2)[0:2:]) + raises(ValueError, lambda: Range(-oo, 4, 2)[2::-1]) + assert Range(-oo, 4, 2)[-2::2] == Range(0, 4, 4) + assert Range(oo, 0, -2)[-10:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[0]) + raises(ValueError, lambda: Range(oo, 0, -2)[-10:10:2]) + raises(ValueError, lambda: Range(oo, 0, -2)[0::-2]) + assert Range(oo, 0, -2)[0:-4:-2] == empty + assert Range(oo, 0, -2)[:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[:1:-1]) + + # test empty Range + assert Range(x, x, y) == empty + assert empty.reversed == empty + assert 0 not in empty + assert list(empty) == [] + assert len(empty) == 0 + assert empty.size is S.Zero + assert empty.intersect(FiniteSet(0)) is S.EmptySet + assert bool(empty) is False + raises(IndexError, lambda: empty[0]) + assert empty[:0] == empty + raises(NotImplementedError, lambda: empty.inf) + raises(NotImplementedError, lambda: empty.sup) + assert empty.as_relational(x) is S.false + + AB = [None] + list(range(12)) + for R in [ + Range(1, 10), + Range(1, 10, 2), + ]: + r = list(R) + for a, b, c in itertools.product(AB, AB, [-3, -1, None, 1, 3]): + for reverse in range(2): + r = list(reversed(r)) + R = R.reversed + result = list(R[a:b:c]) + ans = r[a:b:c] + txt = ('\n%s[%s:%s:%s] = %s -> %s' % ( + R, a, b, c, result, ans)) + check = ans == result + assert check, txt + + assert Range(1, 10, 1).boundary == Range(1, 10, 1) + + for r in (Range(1, 10, 2), Range(1, oo, 2)): + rev = r.reversed + assert r.inf == rev.inf and r.sup == rev.sup + assert r.step == -rev.step + + builtin_range = range + + raises(TypeError, lambda: Range(builtin_range(1))) + assert S(builtin_range(10)) == Range(10) + assert S(builtin_range(1000000000000)) == Range(1000000000000) + + # test Range.as_relational + assert Range(1, 4).as_relational(x) == (x >= 1) & (x <= 3) & Eq(Mod(x, 1), 0) + assert Range(oo, 1, -2).as_relational(x) == (x >= 3) & (x < oo) & Eq(Mod(x + 1, -2), 0) + + +def test_Range_symbolic(): + # symbolic Range + xr = Range(x, x + 4, 5) + sr = Range(x, y, t) + i = Symbol('i', integer=True) + ip = Symbol('i', integer=True, positive=True) + ipr = Range(ip) + inr = Range(0, -ip, -1) + ir = Range(i, i + 19, 2) + ir2 = Range(i, i*8, 3*i) + i = Symbol('i', integer=True) + inf = symbols('inf', infinite=True) + raises(ValueError, lambda: Range(inf)) + raises(ValueError, lambda: Range(inf, 0, -1)) + raises(ValueError, lambda: Range(inf, inf, 1)) + raises(ValueError, lambda: Range(1, 1, inf)) + # args + assert xr.args == (x, x + 5, 5) + assert sr.args == (x, y, t) + assert ir.args == (i, i + 20, 2) + assert ir2.args == (i, 10*i, 3*i) + # reversed + raises(ValueError, lambda: xr.reversed) + raises(ValueError, lambda: sr.reversed) + assert ipr.reversed.args == (ip - 1, -1, -1) + assert inr.reversed.args == (-ip + 1, 1, 1) + assert ir.reversed.args == (i + 18, i - 2, -2) + assert ir2.reversed.args == (7*i, -2*i, -3*i) + # contains + assert inf not in sr + assert inf not in ir + assert 0 in ipr + assert 0 in inr + raises(TypeError, lambda: 1 in ipr) + raises(TypeError, lambda: -1 in inr) + assert .1 not in sr + assert .1 not in ir + assert i + 1 not in ir + assert i + 2 in ir + raises(TypeError, lambda: x in xr) # XXX is this what contains is supposed to do? + raises(TypeError, lambda: 1 in sr) # XXX is this what contains is supposed to do? + # iter + raises(ValueError, lambda: next(iter(xr))) + raises(ValueError, lambda: next(iter(sr))) + assert next(iter(ir)) == i + assert next(iter(ir2)) == i + assert sr.intersect(S.Integers) == sr + assert sr.intersect(FiniteSet(x)) == Intersection({x}, sr) + raises(ValueError, lambda: sr[:2]) + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + # len + assert len(ir) == ir.size == 10 + assert len(ir2) == ir2.size == 3 + raises(ValueError, lambda: len(xr)) + raises(ValueError, lambda: xr.size) + raises(ValueError, lambda: len(sr)) + raises(ValueError, lambda: sr.size) + # bool + assert bool(Range(0)) == False + assert bool(xr) + assert bool(ir) + assert bool(ipr) + assert bool(inr) + raises(ValueError, lambda: bool(sr)) + raises(ValueError, lambda: bool(ir2)) + # inf + raises(ValueError, lambda: xr.inf) + raises(ValueError, lambda: sr.inf) + assert ipr.inf == 0 + assert inr.inf == -ip + 1 + assert ir.inf == i + raises(ValueError, lambda: ir2.inf) + # sup + raises(ValueError, lambda: xr.sup) + raises(ValueError, lambda: sr.sup) + assert ipr.sup == ip - 1 + assert inr.sup == 0 + assert ir.inf == i + raises(ValueError, lambda: ir2.sup) + # getitem + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + raises(ValueError, lambda: sr[-1]) + raises(ValueError, lambda: sr[:2]) + assert ir[:2] == Range(i, i + 4, 2) + assert ir[0] == i + assert ir[-2] == i + 16 + assert ir[-1] == i + 18 + assert ir2[:2] == Range(i, 7*i, 3*i) + assert ir2[0] == i + assert ir2[-2] == 4*i + assert ir2[-1] == 7*i + raises(ValueError, lambda: Range(i)[-1]) + assert ipr[0] == ipr.inf == 0 + assert ipr[-1] == ipr.sup == ip - 1 + assert inr[0] == inr.sup == 0 + assert inr[-1] == inr.inf == -ip + 1 + raises(ValueError, lambda: ipr[-2]) + assert ir.inf == i + assert ir.sup == i + 18 + raises(ValueError, lambda: Range(i).inf) + # as_relational + assert ir.as_relational(x) == ((x >= i) & (x <= i + 18) & + Eq(Mod(-i + x, 2), 0)) + assert ir2.as_relational(x) == Eq( + Mod(-i + x, 3*i), 0) & (((x >= i) & (x <= 7*i) & (3*i >= 1)) | + ((x <= i) & (x >= 7*i) & (3*i <= -1))) + assert Range(i, i + 1).as_relational(x) == Eq(x, i) + assert sr.as_relational(z) == Eq( + Mod(t, 1), 0) & Eq(Mod(x, 1), 0) & Eq(Mod(-x + z, t), 0 + ) & (((z >= x) & (z <= -t + y) & (t >= 1)) | + ((z <= x) & (z >= -t + y) & (t <= -1))) + assert xr.as_relational(z) == Eq(z, x) & Eq(Mod(x, 1), 0) + # symbols can clash if user wants (but it must be integer) + assert xr.as_relational(x) == Eq(Mod(x, 1), 0) + # contains() for symbolic values (issue #18146) + e = Symbol('e', integer=True, even=True) + o = Symbol('o', integer=True, odd=True) + assert Range(5).contains(i) == And(i >= 0, i <= 4) + assert Range(1).contains(i) == Eq(i, 0) + assert Range(-oo, 5, 1).contains(i) == (i <= 4) + assert Range(-oo, oo).contains(i) == True + assert Range(0, 8, 2).contains(i) == Contains(i, Range(0, 8, 2)) + assert Range(0, 8, 2).contains(e) == And(e >= 0, e <= 6) + assert Range(0, 8, 2).contains(2*i) == And(2*i >= 0, 2*i <= 6) + assert Range(0, 8, 2).contains(o) == False + assert Range(1, 9, 2).contains(e) == False + assert Range(1, 9, 2).contains(o) == And(o >= 1, o <= 7) + assert Range(8, 0, -2).contains(o) == False + assert Range(9, 1, -2).contains(o) == And(o >= 3, o <= 9) + assert Range(-oo, 8, 2).contains(i) == Contains(i, Range(-oo, 8, 2)) + + +def test_range_range_intersection(): + for a, b, r in [ + (Range(0), Range(1), S.EmptySet), + (Range(3), Range(4, oo), S.EmptySet), + (Range(3), Range(-3, -1), S.EmptySet), + (Range(1, 3), Range(0, 3), Range(1, 3)), + (Range(1, 3), Range(1, 4), Range(1, 3)), + (Range(1, oo, 2), Range(2, oo, 2), S.EmptySet), + (Range(0, oo, 2), Range(oo), Range(0, oo, 2)), + (Range(0, oo, 2), Range(100), Range(0, 100, 2)), + (Range(2, oo, 2), Range(oo), Range(2, oo, 2)), + (Range(0, oo, 2), Range(5, 6), S.EmptySet), + (Range(2, 80, 1), Range(55, 71, 4), Range(55, 71, 4)), + (Range(0, 6, 3), Range(-oo, 5, 3), S.EmptySet), + (Range(0, oo, 2), Range(5, oo, 3), Range(8, oo, 6)), + (Range(4, 6, 2), Range(2, 16, 7), S.EmptySet),]: + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + a, b = b, a + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + + +def test_range_interval_intersection(): + p = symbols('p', positive=True) + assert isinstance(Range(3).intersect(Interval(p, p + 2)), Intersection) + assert Range(4).intersect(Interval(0, 3)) == Range(4) + assert Range(4).intersect(Interval(-oo, oo)) == Range(4) + assert Range(4).intersect(Interval(1, oo)) == Range(1, 4) + assert Range(4).intersect(Interval(1.1, oo)) == Range(2, 4) + assert Range(4).intersect(Interval(0.1, 3)) == Range(1, 4) + assert Range(4).intersect(Interval(0.1, 3.1)) == Range(1, 4) + assert Range(4).intersect(Interval.open(0, 3)) == Range(1, 3) + assert Range(4).intersect(Interval.open(0.1, 0.5)) is S.EmptySet + assert Interval(-1, 5).intersect(S.Complexes) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Reals) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Integers) == Range(-1, 6) + assert Interval(-1, 5).intersect(S.Naturals) == Range(1, 6) + assert Interval(-1, 5).intersect(S.Naturals0) == Range(0, 6) + + # Null Range intersections + assert Range(0).intersect(Interval(0.2, 0.8)) is S.EmptySet + assert Range(0).intersect(Interval(-oo, oo)) is S.EmptySet + + +def test_range_is_finite_set(): + assert Range(-100, 100).is_finite_set is True + assert Range(2, oo).is_finite_set is False + assert Range(-oo, 50).is_finite_set is False + assert Range(-oo, oo).is_finite_set is False + assert Range(oo, -oo).is_finite_set is True + assert Range(0, 0).is_finite_set is True + assert Range(oo, oo).is_finite_set is True + assert Range(-oo, -oo).is_finite_set is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + assert Range(n, n + 49).is_finite_set is True + assert Range(n, 0).is_finite_set is True + assert Range(-3, n + 7).is_finite_set is True + assert Range(n, m).is_finite_set is True + assert Range(n + m, m - n).is_finite_set is True + assert Range(n, n + m + n).is_finite_set is True + assert Range(n, oo).is_finite_set is False + assert Range(-oo, n).is_finite_set is False + assert Range(n, -oo).is_finite_set is True + assert Range(oo, n).is_finite_set is True + + +def test_Range_is_iterable(): + assert Range(-100, 100).is_iterable is True + assert Range(2, oo).is_iterable is False + assert Range(-oo, 50).is_iterable is False + assert Range(-oo, oo).is_iterable is False + assert Range(oo, -oo).is_iterable is True + assert Range(0, 0).is_iterable is True + assert Range(oo, oo).is_iterable is True + assert Range(-oo, -oo).is_iterable is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + p = Symbol('p', integer=True, positive=True) + assert Range(n, n + 49).is_iterable is True + assert Range(n, 0).is_iterable is False + assert Range(-3, n + 7).is_iterable is False + assert Range(-3, p + 7).is_iterable is False # Should work with better __iter__ + assert Range(n, m).is_iterable is False + assert Range(n + m, m - n).is_iterable is False + assert Range(n, n + m + n).is_iterable is False + assert Range(n, oo).is_iterable is False + assert Range(-oo, n).is_iterable is False + x = Symbol('x') + assert Range(x, x + 49).is_iterable is False + assert Range(x, 0).is_iterable is False + assert Range(-3, x + 7).is_iterable is False + assert Range(x, m).is_iterable is False + assert Range(x + m, m - x).is_iterable is False + assert Range(x, x + m + x).is_iterable is False + assert Range(x, oo).is_iterable is False + assert Range(-oo, x).is_iterable is False + + +def test_Integers_eval_imageset(): + ans = ImageSet(Lambda(x, 2*x + Rational(3, 7)), S.Integers) + im = imageset(Lambda(x, -2*x + Rational(3, 7)), S.Integers) + assert im == ans + im = imageset(Lambda(x, -2*x - Rational(11, 7)), S.Integers) + assert im == ans + y = Symbol('y') + L = imageset(x, 2*x + y, S.Integers) + assert y + 4 in L + a, b, c = 0.092, 0.433, 0.341 + assert a in imageset(x, a + c*x, S.Integers) + assert b in imageset(x, b + c*x, S.Integers) + + _x = symbols('x', negative=True) + eq = _x**2 - _x + 1 + assert imageset(_x, eq, S.Integers).lamda.expr == _x**2 + _x + 1 + eq = 3*_x - 1 + assert imageset(_x, eq, S.Integers).lamda.expr == 3*_x + 2 + + assert imageset(x, (x, 1/x), S.Integers) == \ + ImageSet(Lambda(x, (x, 1/x)), S.Integers) + + +def test_Range_eval_imageset(): + a, b, c = symbols('a b c') + assert imageset(x, a*(x + b) + c, Range(3)) == \ + imageset(x, a*x + a*b + c, Range(3)) + eq = (x + 1)**2 + assert imageset(x, eq, Range(3)).lamda.expr == eq + eq = a*(x + b) + c + r = Range(3, -3, -2) + imset = imageset(x, eq, r) + assert imset.lamda.expr != eq + assert list(imset) == [eq.subs(x, i).expand() for i in list(r)] + + +def test_fun(): + assert (FiniteSet(*ImageSet(Lambda(x, sin(pi*x/4)), + Range(-10, 11))) == FiniteSet(-1, -sqrt(2)/2, 0, sqrt(2)/2, 1)) + + +def test_Range_is_empty(): + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert Range(0).is_empty + assert not Range(1).is_empty + assert Range(1, 0).is_empty + assert not Range(-1, 0).is_empty + assert Range(i).is_empty is None + assert Range(n).is_empty + assert Range(p).is_empty is False + assert Range(n, 0).is_empty is False + assert Range(n, p).is_empty is False + assert Range(p, n).is_empty + assert Range(n, -1).is_empty is None + assert Range(p, n, -1).is_empty is False + + +def test_Reals(): + assert 5 in S.Reals + assert S.Pi in S.Reals + assert -sqrt(2) in S.Reals + assert (2, 5) not in S.Reals + assert sqrt(-1) not in S.Reals + assert S.Reals == Interval(-oo, oo) + assert S.Reals != Interval(0, oo) + assert S.Reals.is_subset(Interval(-oo, oo)) + assert S.Reals.intersect(Range(-oo, oo)) == Range(-oo, oo) + assert S.ComplexInfinity not in S.Reals + assert S.NaN not in S.Reals + assert x + S.ComplexInfinity not in S.Reals + + +def test_Complex(): + assert 5 in S.Complexes + assert 5 + 4*I in S.Complexes + assert S.Pi in S.Complexes + assert -sqrt(2) in S.Complexes + assert -I in S.Complexes + assert sqrt(-1) in S.Complexes + assert S.Complexes.intersect(S.Reals) == S.Reals + assert S.Complexes.union(S.Reals) == S.Complexes + assert S.Complexes == ComplexRegion(S.Reals*S.Reals) + assert (S.Complexes == ComplexRegion(Interval(1, 2)*Interval(3, 4))) == False + assert str(S.Complexes) == "Complexes" + assert repr(S.Complexes) == "Complexes" + + +def take(n, iterable): + "Return first n items of the iterable as a list" + return list(itertools.islice(iterable, n)) + + +def test_intersections(): + assert S.Integers.intersect(S.Reals) == S.Integers + assert 5 in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(S.Reals) + assert -5 not in S.Naturals.intersect(S.Reals) + assert 5.5 not in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(Interval(3, oo)) + assert -5 in S.Integers.intersect(Interval(-oo, 3)) + assert all(x.is_Integer + for x in take(10, S.Integers.intersect(Interval(3, oo)) )) + + +def test_infinitely_indexed_set_1(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == imageset(Lambda(m, m), S.Integers) + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(m, 2*m + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(n, 2*n + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(m, 2*m), S.Integers).intersect( + imageset(Lambda(n, 3*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*t), S.Integers)) + + assert imageset(x, x/2 + Rational(1, 3), S.Integers).intersect(S.Integers) is S.EmptySet + assert imageset(x, x/2 + S.Half, S.Integers).intersect(S.Integers) is S.Integers + + # https://github.com/sympy/sympy/issues/17355 + S53 = ImageSet(Lambda(n, 5*n + 3), S.Integers) + assert S53.intersect(S.Integers) == S53 + + +def test_infinitely_indexed_set_2(): + from sympy.abc import n + a = Symbol('a', integer=True) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, n + a), S.Integers) + assert imageset(Lambda(n, n + pi), S.Integers) == \ + imageset(Lambda(n, n + a + pi), S.Integers) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, -n + a), S.Integers) + assert imageset(Lambda(n, -6*n), S.Integers) == \ + ImageSet(Lambda(n, 6*n), S.Integers) + assert imageset(Lambda(n, 2*n + pi), S.Integers) == \ + ImageSet(Lambda(n, 2*n + pi - 2), S.Integers) + + +def test_imageset_intersect_real(): + from sympy.abc import n + assert imageset(Lambda(n, n + (n - 1)*(n + 1)*I), S.Integers).intersect(S.Reals) == FiniteSet(-1, 1) + im = (n - 1)*(n + S.Half) + assert imageset(Lambda(n, n + im*I), S.Integers + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n + im*(n + 1)*I), S.Naturals0 + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n/2 + im.expand()*I), S.Integers + ).intersect(S.Reals) == ImageSet(Lambda(x, x/2), ConditionSet( + n, Eq(n**2 - n/2 - S(1)/2, 0), S.Integers)) + assert imageset(Lambda(n, n/(1/n - 1) + im*(n + 1)*I), S.Integers + ).intersect(S.Reals) == FiniteSet(S.Half) + assert imageset(Lambda(n, n/(n - 6) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) == FiniteSet(-1) + assert imageset(Lambda(n, n/(n**2 - 9) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) is S.EmptySet + s = ImageSet( + Lambda(n, -I*(I*(2*pi*n - pi/4) + log(Abs(sqrt(-I))))), + S.Integers) + # s is unevaluated, but after intersection the result + # should be canonical + assert s.intersect(S.Reals) == imageset( + Lambda(n, 2*n*pi - pi/4), S.Integers) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_imageset_intersect_interval(): + from sympy.abc import n + f1 = ImageSet(Lambda(n, n*pi), S.Integers) + f2 = ImageSet(Lambda(n, 2*n), Interval(0, pi)) + f3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + # complex expressions + f4 = ImageSet(Lambda(n, n*I*pi), S.Integers) + f5 = ImageSet(Lambda(n, 2*I*n*pi + pi/2), S.Integers) + # non-linear expressions + f6 = ImageSet(Lambda(n, log(n)), S.Integers) + f7 = ImageSet(Lambda(n, n**2), S.Integers) + f8 = ImageSet(Lambda(n, Abs(n)), S.Integers) + f9 = ImageSet(Lambda(n, exp(n)), S.Naturals0) + + assert f1.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f1.intersect(Interval(0, 2*pi, False, True)) == FiniteSet(0, pi) + assert f2.intersect(Interval(1, 2)) == Interval(1, 2) + assert f3.intersect(Interval(-1, 1)) == S.EmptySet + assert f3.intersect(Interval(-5, 5)) == FiniteSet(pi*Rational(-3, 2), pi/2) + assert f4.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f4.intersect(Interval(1, 2)) == S.EmptySet + assert f5.intersect(Interval(0, 1)) == S.EmptySet + assert f6.intersect(Interval(0, 1)) == FiniteSet(S.Zero, log(2)) + assert f7.intersect(Interval(0, 10)) == Intersection(f7, Interval(0, 10)) + assert f8.intersect(Interval(0, 2)) == Intersection(f8, Interval(0, 2)) + assert f9.intersect(Interval(1, 2)) == Intersection(f9, Interval(1, 2)) + + +def test_imageset_intersect_diophantine(): + from sympy.abc import m, n + # Check that same lambda variable for both ImageSets is handled correctly + img1 = ImageSet(Lambda(n, 2*n + 1), S.Integers) + img2 = ImageSet(Lambda(n, 4*n + 1), S.Integers) + assert img1.intersect(img2) == img2 + # Empty solution set returned by diophantine: + assert ImageSet(Lambda(n, 2*n), S.Integers).intersect( + ImageSet(Lambda(n, 2*n + 1), S.Integers)) == S.EmptySet + # Check intersection with S.Integers: + assert ImageSet(Lambda(n, 9/n + 20*n/3), S.Integers).intersect( + S.Integers) == FiniteSet(-61, -23, 23, 61) + # Single solution (2, 3) for diophantine solution: + assert ImageSet(Lambda(n, (n - 2)**2), S.Integers).intersect( + ImageSet(Lambda(n, -(n - 3)**2), S.Integers)) == FiniteSet(0) + # Single parametric solution for diophantine solution: + assert ImageSet(Lambda(n, n**2 + 5), S.Integers).intersect( + ImageSet(Lambda(m, 2*m), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 4*n**2 + 4*n + 6), S.Integers)) + # 4 non-parametric solution couples for dioph. equation: + assert ImageSet(Lambda(n, n**2 - 9), S.Integers).intersect( + ImageSet(Lambda(m, -m**2), S.Integers)) == FiniteSet(-9, 0) + # Double parametric solution for diophantine solution: + assert ImageSet(Lambda(m, m**2 + 40), S.Integers).intersect( + ImageSet(Lambda(n, 41*n), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(m, m**2 + 40), S.Integers), + ImageSet(Lambda(n, 41*n), S.Integers))) + # Check that diophantine returns *all* (8) solutions (permute=True) + assert ImageSet(Lambda(n, n**4 - 2**4), S.Integers).intersect( + ImageSet(Lambda(m, -m**4 + 3**4), S.Integers)) == FiniteSet(0, 65) + assert ImageSet(Lambda(n, pi/12 + n*5*pi/12), S.Integers).intersect( + ImageSet(Lambda(n, 7*pi/12 + n*11*pi/12), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 55*pi*n/12 + 17*pi/4), S.Integers)) + # TypeError raised by diophantine (#18081) + assert ImageSet(Lambda(n, n*log(2)), S.Integers).intersection( + S.Integers).dummy_eq(Intersection(ImageSet( + Lambda(n, n*log(2)), S.Integers), S.Integers)) + # NotImplementedError raised by diophantine (no solver for cubic_thue) + assert ImageSet(Lambda(n, n**3 + 1), S.Integers).intersect( + ImageSet(Lambda(n, n**3), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(n, n**3 + 1), S.Integers), + ImageSet(Lambda(n, n**3), S.Integers))) + + +def test_infinitely_indexed_set_3(): + from sympy.abc import n, m + assert imageset(Lambda(m, 2*pi*m), S.Integers).intersect( + imageset(Lambda(n, 3*pi*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*pi*t), S.Integers)) + assert imageset(Lambda(n, 2*n + 1), S.Integers) == \ + imageset(Lambda(n, 2*n - 1), S.Integers) + assert imageset(Lambda(n, 3*n + 2), S.Integers) == \ + imageset(Lambda(n, 3*n - 1), S.Integers) + + +def test_ImageSet_simplification(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == S.Integers + assert imageset(Lambda(n, sin(n)), + imageset(Lambda(m, tan(m)), S.Integers)) == \ + imageset(Lambda(m, sin(tan(m))), S.Integers) + assert imageset(n, 1 + 2*n, S.Naturals) == Range(3, oo, 2) + assert imageset(n, 1 + 2*n, S.Naturals0) == Range(1, oo, 2) + assert imageset(n, 1 - 2*n, S.Naturals) == Range(-1, -oo, -2) + + +def test_ImageSet_contains(): + assert (2, S.Half) in imageset(x, (x, 1/x), S.Integers) + assert imageset(x, x + I*3, S.Integers).intersection(S.Reals) is S.EmptySet + i = Dummy(integer=True) + q = imageset(x, x + I*y, S.Integers).intersection(S.Reals) + assert q.subs(y, I*i).intersection(S.Integers) is S.Integers + q = imageset(x, x + I*y/x, S.Integers).intersection(S.Reals) + assert q.subs(y, 0) is S.Integers + assert q.subs(y, I*i*x).intersection(S.Integers) is S.Integers + z = cos(1)**2 + sin(1)**2 - 1 + q = imageset(x, x + I*z, S.Integers).intersection(S.Reals) + assert q is not S.EmptySet + + +def test_ComplexRegion_contains(): + r = Symbol('r', real=True) + # contains in ComplexRegion + a = Interval(2, 3) + b = Interval(4, 6) + c = Interval(7, 9) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*b, c*a)) + assert 2.5 + 4.5*I in c1 + assert 2 + 4*I in c1 + assert 3 + 4*I in c1 + assert 8 + 2.5*I in c2 + assert 2.5 + 6.1*I not in c1 + assert 4.5 + 3.2*I not in c1 + assert c1.contains(x) == Contains(x, c1, evaluate=False) + assert c1.contains(r) == False + assert c2.contains(x) == Contains(x, c2, evaluate=False) + assert c2.contains(r) == False + + r1 = Interval(0, 1) + theta1 = Interval(0, 2*S.Pi) + c3 = ComplexRegion(r1*theta1, polar=True) + assert (0.5 + I*6/10) in c3 + assert (S.Half + I*6/10) in c3 + assert (S.Half + .6*I) in c3 + assert (0.5 + .6*I) in c3 + assert I in c3 + assert 1 in c3 + assert 0 in c3 + assert 1 + I not in c3 + assert 1 - I not in c3 + assert c3.contains(x) == Contains(x, c3, evaluate=False) + assert c3.contains(r + 2*I) == Contains( + r + 2*I, c3, evaluate=False) # is in fact False + assert c3.contains(1/(1 + r**2)) == Contains( + 1/(1 + r**2), c3, evaluate=False) # is in fact True + + r2 = Interval(0, 3) + theta2 = Interval(pi, 2*pi, left_open=True) + c4 = ComplexRegion(r2*theta2, polar=True) + assert c4.contains(0) == True + assert c4.contains(2 + I) == False + assert c4.contains(-2 + I) == False + assert c4.contains(-2 - I) == True + assert c4.contains(2 - I) == True + assert c4.contains(-2) == False + assert c4.contains(2) == True + assert c4.contains(x) == Contains(x, c4, evaluate=False) + assert c4.contains(3/(1 + r**2)) == Contains( + 3/(1 + r**2), c4, evaluate=False) # is in fact True + + raises(ValueError, lambda: ComplexRegion(r1*theta1, polar=2)) + + +def test_symbolic_Range(): + n = Symbol('n') + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + raises(ValueError, lambda: Range(n, n+1)[0]) + raises(ValueError, lambda: Range(n).size) + + n = Symbol('n', integer=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n, n+1)[0] == n + raises(ValueError, lambda: Range(n).size) + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, nonnegative=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n+1)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n+1).size == n+1 + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, positive=True) + assert Range(n)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n, n+1).size == 1 + + m = Symbol('m', integer=True, positive=True) + + assert Range(n, n+m)[0] == n + assert Range(n, n+m).size == m + assert Range(n, n+1).size == 1 + assert Range(n, n+m, 2).size == floor(m/2) + + m = Symbol('m', integer=True, positive=True, even=True) + assert Range(n, n+m, 2).size == m/2 + + +def test_issue_18400(): + n = Symbol('n', integer=True) + raises(ValueError, lambda: imageset(lambda x: x*2, Range(n))) + + n = Symbol('n', integer=True, positive=True) + # No exception + assert imageset(lambda x: x*2, Range(n)) == imageset(lambda x: x*2, Range(n)) + + +def test_ComplexRegion_intersect(): + # Polar form + X_axis = ComplexRegion(Interval(0, oo)*FiniteSet(0, S.Pi), polar=True) + + unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + upper_half_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + lower_half_disk = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + right_half_disk = ComplexRegion(Interval(0, oo)*Interval(-S.Pi/2, S.Pi/2), polar=True) + first_quad_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi/2), polar=True) + + assert upper_half_disk.intersect(unit_disk) == upper_half_unit_disk + assert right_half_disk.intersect(first_quad_disk) == first_quad_disk + assert upper_half_disk.intersect(right_half_disk) == first_quad_disk + assert upper_half_disk.intersect(lower_half_disk) == X_axis + + c1 = ComplexRegion(Interval(0, 4)*Interval(0, 2*S.Pi), polar=True) + assert c1.intersect(Interval(1, 5)) == Interval(1, 4) + assert c1.intersect(Interval(4, 9)) == FiniteSet(4) + assert c1.intersect(Interval(5, 12)) is S.EmptySet + + # Rectangular form + X_axis = ComplexRegion(Interval(-oo, oo)*FiniteSet(0)) + + unit_square = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + upper_half_unit_square = ComplexRegion(Interval(-1, 1)*Interval(0, 1)) + upper_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(0, oo)) + lower_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(-oo, 0)) + right_half_plane = ComplexRegion(Interval(0, oo)*Interval(-oo, oo)) + first_quad_plane = ComplexRegion(Interval(0, oo)*Interval(0, oo)) + + assert upper_half_plane.intersect(unit_square) == upper_half_unit_square + assert right_half_plane.intersect(first_quad_plane) == first_quad_plane + assert upper_half_plane.intersect(right_half_plane) == first_quad_plane + assert upper_half_plane.intersect(lower_half_plane) == X_axis + + c1 = ComplexRegion(Interval(-5, 5)*Interval(-10, 10)) + assert c1.intersect(Interval(2, 7)) == Interval(2, 5) + assert c1.intersect(Interval(5, 7)) == FiniteSet(5) + assert c1.intersect(Interval(6, 9)) is S.EmptySet + + # unevaluated object + C1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + C2 = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + assert C1.intersect(C2) == Intersection(C1, C2, evaluate=False) + + +def test_ComplexRegion_union(): + # Polar form + c1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + c2 = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + c3 = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + c4 = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + + p1 = Union(Interval(0, 1)*Interval(0, 2*S.Pi), Interval(0, 1)*Interval(0, S.Pi)) + p2 = Union(Interval(0, oo)*Interval(0, S.Pi), Interval(0, oo)*Interval(S.Pi, 2*S.Pi)) + + assert c1.union(c2) == ComplexRegion(p1, polar=True) + assert c3.union(c4) == ComplexRegion(p2, polar=True) + + # Rectangular form + c5 = ComplexRegion(Interval(2, 5)*Interval(6, 9)) + c6 = ComplexRegion(Interval(4, 6)*Interval(10, 12)) + c7 = ComplexRegion(Interval(0, 10)*Interval(-10, 0)) + c8 = ComplexRegion(Interval(12, 16)*Interval(14, 20)) + + p3 = Union(Interval(2, 5)*Interval(6, 9), Interval(4, 6)*Interval(10, 12)) + p4 = Union(Interval(0, 10)*Interval(-10, 0), Interval(12, 16)*Interval(14, 20)) + + assert c5.union(c6) == ComplexRegion(p3) + assert c7.union(c8) == ComplexRegion(p4) + + assert c1.union(Interval(2, 4)) == Union(c1, Interval(2, 4), evaluate=False) + assert c5.union(Interval(2, 4)) == Union(c5, ComplexRegion.from_real(Interval(2, 4))) + + +def test_ComplexRegion_from_real(): + c1 = ComplexRegion(Interval(0, 1) * Interval(0, 2 * S.Pi), polar=True) + + raises(ValueError, lambda: c1.from_real(c1)) + assert c1.from_real(Interval(-1, 1)) == ComplexRegion(Interval(-1, 1) * FiniteSet(0), False) + + +def test_ComplexRegion_measure(): + a, b = Interval(2, 5), Interval(4, 8) + theta1, theta2 = Interval(0, 2*S.Pi), Interval(0, S.Pi) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*theta1, b*theta2), polar=True) + + assert c1.measure == 12 + assert c2.measure == 9*pi + + +def test_normalize_theta_set(): + # Interval + assert normalize_theta_set(Interval(pi, 2*pi)) == \ + Union(FiniteSet(0), Interval.Ropen(pi, 2*pi)) + assert normalize_theta_set(Interval(pi*Rational(9, 2), 5*pi)) == Interval(pi/2, pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), pi/2)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval.open(pi*Rational(-3, 2), pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval.open(pi*Rational(-7, 2), pi*Rational(-3, 2))) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-4*pi, 3*pi)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), -pi/2)) == Interval(pi/2, pi*Rational(3, 2)) + assert normalize_theta_set(Interval.open(0, 2*pi)) == Interval.open(0, 2*pi) + assert normalize_theta_set(Interval.Ropen(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.Lopen(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(4*pi, pi*Rational(9, 2))) == Interval.open(0, pi/2) + assert normalize_theta_set(Interval.Lopen(4*pi, pi*Rational(9, 2))) == Interval.Lopen(0, pi/2) + assert normalize_theta_set(Interval.Ropen(4*pi, pi*Rational(9, 2))) == Interval.Ropen(0, pi/2) + assert normalize_theta_set(Interval.open(3*pi, 5*pi)) == \ + Union(Interval.Ropen(0, pi), Interval.open(pi, 2*pi)) + + # FiniteSet + assert normalize_theta_set(FiniteSet(0, pi, 3*pi)) == FiniteSet(0, pi) + assert normalize_theta_set(FiniteSet(0, pi/2, pi, 2*pi)) == FiniteSet(0, pi/2, pi) + assert normalize_theta_set(FiniteSet(0, -pi/2, -pi, -2*pi)) == FiniteSet(0, pi, pi*Rational(3, 2)) + assert normalize_theta_set(FiniteSet(pi*Rational(-3, 2), pi/2)) == \ + FiniteSet(pi/2) + assert normalize_theta_set(FiniteSet(2*pi)) == FiniteSet(0) + + # Unions + assert normalize_theta_set(Union(Interval(0, pi/3), Interval(pi/2, pi))) == \ + Union(Interval(0, pi/3), Interval(pi/2, pi)) + assert normalize_theta_set(Union(Interval(0, pi), Interval(2*pi, pi*Rational(7, 3)))) == \ + Interval(0, pi) + + # ValueError for non-real sets + raises(ValueError, lambda: normalize_theta_set(S.Complexes)) + + # NotImplementedError for subset of reals + raises(NotImplementedError, lambda: normalize_theta_set(Interval(0, 1))) + + # NotImplementedError without pi as coefficient + raises(NotImplementedError, lambda: normalize_theta_set(Interval(1, 2*pi))) + raises(NotImplementedError, lambda: normalize_theta_set(Interval(2*pi, 10))) + raises(NotImplementedError, lambda: normalize_theta_set(FiniteSet(0, 3, 3*pi))) + + +def test_ComplexRegion_FiniteSet(): + x, y, z, a, b, c = symbols('x y z a b c') + + # Issue #9669 + assert ComplexRegion(FiniteSet(a, b, c)*FiniteSet(x, y, z)) == \ + FiniteSet(a + I*x, a + I*y, a + I*z, b + I*x, b + I*y, + b + I*z, c + I*x, c + I*y, c + I*z) + assert ComplexRegion(FiniteSet(2)*FiniteSet(3)) == FiniteSet(2 + 3*I) + + +def test_union_RealSubSet(): + assert (S.Complexes).union(Interval(1, 2)) == S.Complexes + assert (S.Complexes).union(S.Integers) == S.Complexes + + +def test_SetKind_fancySet(): + G = lambda *args: ImageSet(Lambda(x, x ** 2), *args) + assert G(Interval(1, 4)).kind is SetKind(NumberKind) + assert G(FiniteSet(1, 4)).kind is SetKind(NumberKind) + assert S.Rationals.kind is SetKind(NumberKind) + assert S.Naturals.kind is SetKind(NumberKind) + assert S.Integers.kind is SetKind(NumberKind) + assert Range(3).kind is SetKind(NumberKind) + a = Interval(2, 3) + b = Interval(4, 6) + c1 = ComplexRegion(a*b) + assert c1.kind is SetKind(TupleKind(NumberKind, NumberKind)) + + +def test_issue_9980(): + c1 = ComplexRegion(Interval(1, 2)*Interval(2, 3)) + c2 = ComplexRegion(Interval(1, 5)*Interval(1, 3)) + R = Union(c1, c2) + assert simplify(R) == ComplexRegion(Union(Interval(1, 2)*Interval(2, 3), \ + Interval(1, 5)*Interval(1, 3)), False) + assert c1.func(*c1.args) == c1 + assert R.func(*R.args) == R + + +def test_issue_11732(): + interval12 = Interval(1, 2) + finiteset1234 = FiniteSet(1, 2, 3, 4) + pointComplex = Tuple(1, 5) + + assert (interval12 in S.Naturals) == False + assert (interval12 in S.Naturals0) == False + assert (interval12 in S.Integers) == False + assert (interval12 in S.Complexes) == False + + assert (finiteset1234 in S.Naturals) == False + assert (finiteset1234 in S.Naturals0) == False + assert (finiteset1234 in S.Integers) == False + assert (finiteset1234 in S.Complexes) == False + + assert (pointComplex in S.Naturals) == False + assert (pointComplex in S.Naturals0) == False + assert (pointComplex in S.Integers) == False + assert (pointComplex in S.Complexes) == True + + +def test_issue_11730(): + unit = Interval(0, 1) + square = ComplexRegion(unit ** 2) + + assert Union(S.Complexes, FiniteSet(oo)) != S.Complexes + assert Union(S.Complexes, FiniteSet(eye(4))) != S.Complexes + assert Union(unit, square) == square + assert Intersection(S.Reals, square) == unit + + +def test_issue_11938(): + unit = Interval(0, 1) + ival = Interval(1, 2) + cr1 = ComplexRegion(ival * unit) + + assert Intersection(cr1, S.Reals) == ival + assert Intersection(cr1, unit) == FiniteSet(1) + + arg1 = Interval(0, S.Pi) + arg2 = FiniteSet(S.Pi) + arg3 = Interval(S.Pi / 4, 3 * S.Pi / 4) + cp1 = ComplexRegion(unit * arg1, polar=True) + cp2 = ComplexRegion(unit * arg2, polar=True) + cp3 = ComplexRegion(unit * arg3, polar=True) + + assert Intersection(cp1, S.Reals) == Interval(-1, 1) + assert Intersection(cp2, S.Reals) == Interval(-1, 0) + assert Intersection(cp3, S.Reals) == FiniteSet(0) + + +def test_issue_11914(): + a, b = Interval(0, 1), Interval(0, pi) + c, d = Interval(2, 3), Interval(pi, 3 * pi / 2) + cp1 = ComplexRegion(a * b, polar=True) + cp2 = ComplexRegion(c * d, polar=True) + + assert -3 in cp1.union(cp2) + assert -3 in cp2.union(cp1) + assert -5 not in cp1.union(cp2) + + +def test_issue_9543(): + assert ImageSet(Lambda(x, x**2), S.Naturals).is_subset(S.Reals) + + +def test_issue_16871(): + assert ImageSet(Lambda(x, x), FiniteSet(1)) == {1} + assert ImageSet(Lambda(x, x - 3), S.Integers + ).intersection(S.Integers) is S.Integers + + +@XFAIL +def test_issue_16871b(): + assert ImageSet(Lambda(x, x - 3), S.Integers).is_subset(S.Integers) + + +def test_issue_18050(): + assert imageset(Lambda(x, I*x + 1), S.Integers + ) == ImageSet(Lambda(x, I*x + 1), S.Integers) + assert imageset(Lambda(x, 3*I*x + 4 + 8*I), S.Integers + ) == ImageSet(Lambda(x, 3*I*x + 4 + 2*I), S.Integers) + # no 'Mod' for next 2 tests: + assert imageset(Lambda(x, 2*x + 3*I), S.Integers + ) == ImageSet(Lambda(x, 2*x + 3*I), S.Integers) + r = Symbol('r', positive=True) + assert imageset(Lambda(x, r*x + 10), S.Integers + ) == ImageSet(Lambda(x, r*x + 10), S.Integers) + # reduce real part: + assert imageset(Lambda(x, 3*x + 8 + 5*I), S.Integers + ) == ImageSet(Lambda(x, 3*x + 2 + 5*I), S.Integers) + + +def test_Rationals(): + assert S.Integers.is_subset(S.Rationals) + assert S.Naturals.is_subset(S.Rationals) + assert S.Naturals0.is_subset(S.Rationals) + assert S.Rationals.is_subset(S.Reals) + assert S.Rationals.inf is -oo + assert S.Rationals.sup is oo + it = iter(S.Rationals) + assert [next(it) for i in range(12)] == [ + 0, 1, -1, S.Half, 2, Rational(-1, 2), -2, + Rational(1, 3), 3, Rational(-1, 3), -3, Rational(2, 3)] + assert Basic() not in S.Rationals + assert S.Half in S.Rationals + assert S.Rationals.contains(0.5) == Contains( + 0.5, S.Rationals, evaluate=False) + assert 2 in S.Rationals + r = symbols('r', rational=True) + assert r in S.Rationals + raises(TypeError, lambda: x in S.Rationals) + # issue #18134: + assert S.Rationals.boundary == S.Reals + assert S.Rationals.closure == S.Reals + assert S.Rationals.is_open == False + assert S.Rationals.is_closed == False + + +def test_NZQRC_unions(): + # check that all trivial number set unions are simplified: + nbrsets = (S.Naturals, S.Naturals0, S.Integers, S.Rationals, + S.Reals, S.Complexes) + unions = (Union(a, b) for a in nbrsets for b in nbrsets) + assert all(u.is_Union is False for u in unions) + + +def test_imageset_intersection(): + n = Dummy() + s = ImageSet(Lambda(n, -I*(I*(2*pi*n - pi/4) + + log(Abs(sqrt(-I))))), S.Integers) + assert s.intersect(S.Reals) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_issue_17858(): + assert 1 in Range(-oo, oo) + assert 0 in Range(oo, -oo, -1) + assert oo not in Range(-oo, oo) + assert -oo not in Range(-oo, oo) + +def test_issue_17859(): + r = Range(-oo,oo) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) + r = Range(oo,-oo,-1) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_ordinals.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..973ca329586f3e904f9377c44022c266f81c805c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_ordinals.py @@ -0,0 +1,67 @@ +from sympy.sets.ordinals import Ordinal, OmegaPower, ord0, omega +from sympy.testing.pytest import raises + +def test_string_ordinals(): + assert str(omega) == 'w' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(3, 2))) == 'w**5*3 + w**3*2' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(0, 5))) == 'w**5*3 + 5' + assert str(Ordinal(OmegaPower(1, 3), OmegaPower(0, 5))) == 'w*3 + 5' + assert str(Ordinal(OmegaPower(omega + 1, 1), OmegaPower(3, 2))) == 'w**(w + 1) + w**3*2' + +def test_addition_with_integers(): + assert 3 + Ordinal(OmegaPower(5, 3)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(5, 3))+3 == Ordinal(OmegaPower(5, 3), OmegaPower(0, 3)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(0, 2))+3 == \ + Ordinal(OmegaPower(5, 3), OmegaPower(0, 5)) + + +def test_addition_with_ordinals(): + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(3, 3)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 5)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(4, 2)) + assert Ordinal(OmegaPower(omega, 2), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(omega, 2), OmegaPower(4, 2)) + +def test_comparison(): + assert Ordinal(OmegaPower(5, 3)) > Ordinal(OmegaPower(4, 3), OmegaPower(2, 1)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) < Ordinal(OmegaPower(5, 4)) + assert Ordinal(OmegaPower(5, 4)) < Ordinal(OmegaPower(5, 5), OmegaPower(4, 1)) + + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + assert not Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(omega, 3)) > Ordinal(OmegaPower(5, 3)) + +def test_multiplication_with_integers(): + w = omega + assert 3*w == w + assert w*9 == Ordinal(OmegaPower(1, 9)) + +def test_multiplication(): + w = omega + assert w*(w + 1) == w*w + w + assert (w + 1)*(w + 1) == w*w + w + 1 + assert w*1 == w + assert 1*w == w + assert w*ord0 == ord0 + assert ord0*w == ord0 + assert w**w == w * w**w + assert (w**w)*w*w == w**(w + 2) + +def test_exponentiation(): + w = omega + assert w**2 == w*w + assert w**3 == w*w*w + assert w**(w + 1) == Ordinal(OmegaPower(omega + 1, 1)) + assert (w**w)*(w**w) == w**(w*2) + +def test_comapre_not_instance(): + w = OmegaPower(omega + 1, 1) + assert(not (w == None)) + assert(not (w < 5)) + raises(TypeError, lambda: w < 6.66) + +def test_is_successort(): + w = Ordinal(OmegaPower(5, 1)) + assert not w.is_successor_ordinal diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_powerset.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2e3a407d565f6b9537a296af103ec0a4e137cff9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_powerset.py @@ -0,0 +1,141 @@ +from sympy.core.expr import unchanged +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Interval +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import FiniteSet +from sympy.testing.pytest import raises, XFAIL + + +def test_powerset_creation(): + assert unchanged(PowerSet, FiniteSet(1, 2)) + assert unchanged(PowerSet, S.EmptySet) + raises(ValueError, lambda: PowerSet(123)) + assert unchanged(PowerSet, S.Reals) + assert unchanged(PowerSet, S.Integers) + + +def test_powerset_rewrite_FiniteSet(): + assert PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) == \ + FiniteSet(S.EmptySet, FiniteSet(1), FiniteSet(2), FiniteSet(1, 2)) + assert PowerSet(S.EmptySet).rewrite(FiniteSet) == FiniteSet(S.EmptySet) + assert PowerSet(S.Naturals).rewrite(FiniteSet) == PowerSet(S.Naturals) + + +def test_finiteset_rewrite_powerset(): + assert FiniteSet(S.EmptySet).rewrite(PowerSet) == PowerSet(S.EmptySet) + assert FiniteSet( + S.EmptySet, FiniteSet(1), + FiniteSet(2), FiniteSet(1, 2)).rewrite(PowerSet) == \ + PowerSet(FiniteSet(1, 2)) + assert FiniteSet(1, 2, 3).rewrite(PowerSet) == FiniteSet(1, 2, 3) + + +def test_powerset__contains__(): + subset_series = [ + S.EmptySet, + FiniteSet(1, 2), + S.Naturals, + S.Naturals0, + S.Integers, + S.Rationals, + S.Reals, + S.Complexes] + + l = len(subset_series) + for i in range(l): + for j in range(l): + if i <= j: + assert subset_series[i] in \ + PowerSet(subset_series[j], evaluate=False) + else: + assert subset_series[i] not in \ + PowerSet(subset_series[j], evaluate=False) + + +@XFAIL +def test_failing_powerset__contains__(): + # XXX These are failing when evaluate=True, + # but using unevaluated PowerSet works fine. + assert FiniteSet(1, 2) not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Integers not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Integers not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Reals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Reals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + + +def test_powerset__len__(): + A = PowerSet(S.EmptySet, evaluate=False) + assert len(A) == 1 + A = PowerSet(A, evaluate=False) + assert len(A) == 2 + A = PowerSet(A, evaluate=False) + assert len(A) == 4 + A = PowerSet(A, evaluate=False) + assert len(A) == 16 + + +def test_powerset__iter__(): + a = PowerSet(FiniteSet(1, 2)).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + + a = PowerSet(S.Naturals).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + assert next(a) == FiniteSet(3) + assert next(a) == FiniteSet(1, 3) + assert next(a) == FiniteSet(2, 3) + assert next(a) == FiniteSet(1, 2, 3) + + +def test_powerset_contains(): + A = PowerSet(FiniteSet(1), evaluate=False) + assert A.contains(2) == Contains(2, A) + + x = Symbol('x') + + A = PowerSet(FiniteSet(x), evaluate=False) + assert A.contains(FiniteSet(1)) == Contains(FiniteSet(1), A) + + +def test_powerset_method(): + # EmptySet + A = FiniteSet() + pset = A.powerset() + assert len(pset) == 1 + assert pset == FiniteSet(S.EmptySet) + + # FiniteSets + A = FiniteSet(1, 2) + pset = A.powerset() + assert len(pset) == 2**len(A) + assert pset == FiniteSet(FiniteSet(), FiniteSet(1), + FiniteSet(2), A) + # Not finite sets + A = Interval(0, 1) + assert A.powerset() == PowerSet(A) + +def test_is_subset(): + # covers line 101-102 + # initialize powerset(1), which is a subset of powerset(1,2) + subset = PowerSet(FiniteSet(1)) + pset = PowerSet(FiniteSet(1, 2)) + bad_set = PowerSet(FiniteSet(2, 3)) + # assert "subset" is subset of pset == True + assert subset.is_subset(pset) + # assert "bad_set" is subset of pset == False + assert not pset.is_subset(bad_set) diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_setexpr.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..faab1261c8d3e86901b04d30e8bc94de31642b93 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_setexpr.py @@ -0,0 +1,317 @@ +from sympy.sets.setexpr import SetExpr +from sympy.sets import Interval, FiniteSet, Intersection, ImageSet, Union + +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.sets.sets import Set + + +a, x = symbols("a, x") +_d = Dummy("d") + + +def test_setexpr(): + se = SetExpr(Interval(0, 1)) + assert isinstance(se.set, Set) + assert isinstance(se, Expr) + + +def test_scalar_funcs(): + assert SetExpr(Interval(0, 1)).set == Interval(0, 1) + a, b = Symbol('a', real=True), Symbol('b', real=True) + a, b = 1, 2 + # TODO: add support for more functions in the future: + for f in [exp, log]: + input_se = f(SetExpr(Interval(a, b))) + output = input_se.set + expected = Interval(Min(f(a), f(b)), Max(f(a), f(b))) + assert output == expected + + +def test_Add_Mul(): + assert (SetExpr(Interval(0, 1)) + 1).set == Interval(1, 2) + assert (SetExpr(Interval(0, 1))*2).set == Interval(0, 2) + + +def test_Pow(): + assert (SetExpr(Interval(0, 2))**2).set == Interval(0, 4) + + +def test_compound(): + assert (exp(SetExpr(Interval(0, 1))*2 + 1)).set == \ + Interval(exp(1), exp(3)) + + +def test_Interval_Interval(): + assert (SetExpr(Interval(1, 2)) + SetExpr(Interval(10, 20))).set == \ + Interval(11, 22) + assert (SetExpr(Interval(1, 2))*SetExpr(Interval(10, 20))).set == \ + Interval(10, 40) + + +def test_FiniteSet_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2, 3)) + SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(2, 3, 4, 5) + assert (SetExpr(FiniteSet(1, 2, 3))*SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(1, 2, 3, 4, 6) + + +def test_Interval_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2)) + SetExpr(Interval(0, 10))).set == \ + Interval(1, 12) + + +def test_Many_Sets(): + assert (SetExpr(Interval(0, 1)) + + SetExpr(Interval(2, 3)) + + SetExpr(FiniteSet(10, 11, 12))).set == Interval(12, 16) + + +def test_same_setexprs_are_not_identical(): + a = SetExpr(FiniteSet(0, 1)) + b = SetExpr(FiniteSet(0, 1)) + assert (a + b).set == FiniteSet(0, 1, 2) + + # Cannot detect the set being the same: + # assert (a + a).set == FiniteSet(0, 2) + + +def test_Interval_arithmetic(): + i12cc = SetExpr(Interval(1, 2)) + i12lo = SetExpr(Interval.Lopen(1, 2)) + i12ro = SetExpr(Interval.Ropen(1, 2)) + i12o = SetExpr(Interval.open(1, 2)) + + n23cc = SetExpr(Interval(-2, 3)) + n23lo = SetExpr(Interval.Lopen(-2, 3)) + n23ro = SetExpr(Interval.Ropen(-2, 3)) + n23o = SetExpr(Interval.open(-2, 3)) + + n3n2cc = SetExpr(Interval(-3, -2)) + + assert i12cc + i12cc == SetExpr(Interval(2, 4)) + assert i12cc - i12cc == SetExpr(Interval(-1, 1)) + assert i12cc*i12cc == SetExpr(Interval(1, 4)) + assert i12cc/i12cc == SetExpr(Interval(S.Half, 2)) + assert i12cc**2 == SetExpr(Interval(1, 4)) + assert i12cc**3 == SetExpr(Interval(1, 8)) + + assert i12lo + i12ro == SetExpr(Interval.open(2, 4)) + assert i12lo - i12ro == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12ro == SetExpr(Interval.open(1, 4)) + assert i12lo/i12ro == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12lo == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12lo == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12lo + i12cc == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12cc == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12cc == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12cc == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12o == SetExpr(Interval.open(2, 4)) + assert i12lo - i12o == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12o == SetExpr(Interval.open(1, 4)) + assert i12lo/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12lo**2 == SetExpr(Interval.Lopen(1, 4)) + assert i12lo**3 == SetExpr(Interval.Lopen(1, 8)) + + assert i12ro + i12ro == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12ro == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12ro + i12cc == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12cc == SetExpr(Interval.Ropen(-1, 1)) + assert i12ro*i12cc == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12cc == SetExpr(Interval.Ropen(S.Half, 2)) + assert i12ro + i12o == SetExpr(Interval.open(2, 4)) + assert i12ro - i12o == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12o == SetExpr(Interval.open(1, 4)) + assert i12ro/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12ro**2 == SetExpr(Interval.Ropen(1, 4)) + assert i12ro**3 == SetExpr(Interval.Ropen(1, 8)) + + assert i12o + i12lo == SetExpr(Interval.open(2, 4)) + assert i12o - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12o*i12lo == SetExpr(Interval.open(1, 4)) + assert i12o/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12ro == SetExpr(Interval.open(2, 4)) + assert i12o - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12o*i12ro == SetExpr(Interval.open(1, 4)) + assert i12o/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12cc == SetExpr(Interval.open(2, 4)) + assert i12o - i12cc == SetExpr(Interval.open(-1, 1)) + assert i12o*i12cc == SetExpr(Interval.open(1, 4)) + assert i12o/i12cc == SetExpr(Interval.open(S.Half, 2)) + assert i12o**2 == SetExpr(Interval.open(1, 4)) + assert i12o**3 == SetExpr(Interval.open(1, 8)) + + assert n23cc + n23cc == SetExpr(Interval(-4, 6)) + assert n23cc - n23cc == SetExpr(Interval(-5, 5)) + assert n23cc*n23cc == SetExpr(Interval(-6, 9)) + assert n23cc/n23cc == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23ro == SetExpr(Interval.Ropen(-4, 6)) + assert n23cc - n23ro == SetExpr(Interval.Lopen(-5, 5)) + assert n23cc*n23ro == SetExpr(Interval.Ropen(-6, 9)) + assert n23cc/n23ro == SetExpr(Interval.Lopen(-oo, oo)) + assert n23cc + n23lo == SetExpr(Interval.Lopen(-4, 6)) + assert n23cc - n23lo == SetExpr(Interval.Ropen(-5, 5)) + assert n23cc*n23lo == SetExpr(Interval(-6, 9)) + assert n23cc/n23lo == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23o == SetExpr(Interval.open(-4, 6)) + assert n23cc - n23o == SetExpr(Interval.open(-5, 5)) + assert n23cc*n23o == SetExpr(Interval.open(-6, 9)) + assert n23cc/n23o == SetExpr(Interval.open(-oo, oo)) + assert n23cc**2 == SetExpr(Interval(0, 9)) + assert n23cc**3 == SetExpr(Interval(-8, 27)) + + n32cc = SetExpr(Interval(-3, 2)) + n32lo = SetExpr(Interval.Lopen(-3, 2)) + n32ro = SetExpr(Interval.Ropen(-3, 2)) + assert n32cc*n32lo == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32cc == SetExpr(Interval(-6, 9)) + assert n32lo*n32cc == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32ro == SetExpr(Interval(-6, 9)) + assert n32lo*n32ro == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + assert i12cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + + assert n3n2cc**2 == SetExpr(Interval(4, 9)) + assert n3n2cc**3 == SetExpr(Interval(-27, -8)) + + assert n23cc + i12cc == SetExpr(Interval(-1, 5)) + assert n23cc - i12cc == SetExpr(Interval(-4, 2)) + assert n23cc*i12cc == SetExpr(Interval(-4, 6)) + assert n23cc/i12cc == SetExpr(Interval(-2, 3)) + + +def test_SetExpr_Intersection(): + x, y, z, w = symbols("x y z w") + set1 = Interval(x, y) + set2 = Interval(w, z) + inter = Intersection(set1, set2) + se = SetExpr(inter) + assert exp(se).set == Intersection( + ImageSet(Lambda(x, exp(x)), set1), + ImageSet(Lambda(x, exp(x)), set2)) + assert cos(se).set == ImageSet(Lambda(x, cos(x)), inter) + + +def test_SetExpr_Interval_div(): + # TODO: some expressions cannot be calculated due to bugs (currently + # commented): + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(-2, 1)) == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(2, 3))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(0, 4)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(-3, 0)) == SetExpr(Interval(-oo, Rational(-2, 3))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(0, 3)) == SetExpr(Interval(Rational(2, 3), oo)) + + # assert SetExpr(Interval(0, 1))/SetExpr(Interval(0, 1)) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, 0))/SetExpr(Interval(0, 1)) == SetExpr(Interval(-oo, 0)) + assert SetExpr(Interval(-1, 2))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert 1/SetExpr(Interval(-1, 2)) == SetExpr(Union(Interval(-oo, -1), Interval(S.Half, oo))) + + assert 1/SetExpr(Interval(0, 2)) == SetExpr(Interval(S.Half, oo)) + assert (-1)/SetExpr(Interval(0, 2)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert 1/SetExpr(Interval(-oo, 0)) == SetExpr(Interval.open(-oo, 0)) + assert 1/SetExpr(Interval(-1, 0)) == SetExpr(Interval(-oo, -1)) + # assert (-2)/SetExpr(Interval(-oo, 0)) == SetExpr(Interval(0, oo)) + # assert 1/SetExpr(Interval(-oo, -1)) == SetExpr(Interval(-1, 0)) + + # assert SetExpr(Interval(1, 2))/a == Mul(SetExpr(Interval(1, 2)), 1/a, evaluate=False) + + # assert SetExpr(Interval(1, 2))/0 == SetExpr(Interval(1, 2))*zoo + # assert SetExpr(Interval(1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/(-oo) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-oo, oo))/oo == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-oo, oo))/(-oo) == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/(-oo) == SetExpr(Interval(0, oo)) + + +def test_SetExpr_Interval_pow(): + assert SetExpr(Interval(0, 2))**2 == SetExpr(Interval(0, 4)) + assert SetExpr(Interval(-1, 1))**2 == SetExpr(Interval(0, 1)) + assert SetExpr(Interval(1, 2))**2 == SetExpr(Interval(1, 4)) + assert SetExpr(Interval(-1, 2))**3 == SetExpr(Interval(-1, 8)) + assert SetExpr(Interval(-1, 1))**0 == SetExpr(FiniteSet(1)) + + + assert SetExpr(Interval(1, 2))**Rational(5, 2) == SetExpr(Interval(1, 4*sqrt(2))) + #assert SetExpr(Interval(-1, 2))**Rational(1, 3) == SetExpr(Interval(-1, 2**Rational(1, 3))) + #assert SetExpr(Interval(0, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + + #assert SetExpr(Interval(-4, 2))**Rational(2, 3) == SetExpr(Interval(0, 2*2**Rational(1, 3))) + + #assert SetExpr(Interval(-1, 5))**S.Half == SetExpr(Interval(0, sqrt(5))) + #assert SetExpr(Interval(-oo, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 4)) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(1, 5))**(-2) == SetExpr(Interval(Rational(1, 25), 1)) + assert SetExpr(Interval(-1, 3))**(-2) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(0, 2))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + assert SetExpr(Interval(-1, 2))**(-3) == SetExpr(Union(Interval(-oo, -1), Interval(Rational(1, 8), oo))) + assert SetExpr(Interval(-3, -2))**(-3) == SetExpr(Interval(Rational(-1, 8), Rational(-1, 27))) + assert SetExpr(Interval(-3, -2))**(-2) == SetExpr(Interval(Rational(1, 9), Rational(1, 4))) + #assert SetExpr(Interval(0, oo))**S.Half == SetExpr(Interval(0, oo)) + #assert SetExpr(Interval(-oo, -1))**Rational(1, 3) == SetExpr(Interval(-oo, -1)) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 3)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-oo, 0))**(-2) == SetExpr(Interval.open(0, oo)) + assert SetExpr(Interval(-2, 0))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + + assert SetExpr(Interval(Rational(1, 3), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(S.Half, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(0, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(2, 3))**oo == SetExpr(FiniteSet(oo)) + assert SetExpr(Interval(1, 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(S.Half, 3))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-1, 3), Rational(-1, 4)))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(-1, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-3, -2))**oo == SetExpr(FiniteSet(-oo, oo)) + assert SetExpr(Interval(-2, -1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(Rational(-1, 2), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(Rational(-1, 2), 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-2, 3), 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(-1, 1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, S.Half))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, 2))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, S.Half))**oo == SetExpr(Interval(-oo, oo)) + + assert (SetExpr(Interval(1, 2))**x).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**x), Interval(1, 2)))) + + assert SetExpr(Interval(2, 3))**(-oo) == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, 2))**(-oo) == SetExpr(Interval(0, oo)) + assert (SetExpr(Interval(-1, 2))**(-oo)).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**(-oo)), Interval(-1, 2)))) + + +def test_SetExpr_Integers(): + assert SetExpr(S.Integers) + 1 == SetExpr(S.Integers) + assert (SetExpr(S.Integers) + I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, _d + I), S.Integers))) + assert SetExpr(S.Integers)*(-1) == SetExpr(S.Integers) + assert (SetExpr(S.Integers)*2).dummy_eq( + SetExpr(ImageSet(Lambda(_d, 2*_d), S.Integers))) + assert (SetExpr(S.Integers)*I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d), S.Integers))) + # issue #18050: + assert SetExpr(S.Integers)._eval_func(Lambda(x, I*x + 1)).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d + 1), S.Integers))) + # needs improvement: + assert (SetExpr(S.Integers)*I + 1).dummy_eq( + SetExpr(ImageSet(Lambda(x, x + 1), + ImageSet(Lambda(_d, _d*I), S.Integers)))) diff --git a/phivenv/Lib/site-packages/sympy/sets/tests/test_sets.py b/phivenv/Lib/site-packages/sympy/sets/tests/test_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..657ab19a90eb88ca48f266f7a5cf050504caed43 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/sets/tests/test_sets.py @@ -0,0 +1,1753 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.logic.boolalg import (false, true) +from sympy.matrices.kind import MatrixKind +from sympy.matrices.dense import Matrix +from sympy.polys.rootoftools import rootof +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range) +from sympy.sets.sets import (Complement, DisjointUnion, FiniteSet, Intersection, Interval, ProductSet, Set, SymmetricDifference, Union, imageset, SetKind) +from mpmath import mpi + +from sympy.core.expr import unchanged +from sympy.core.relational import Eq, Ne, Le, Lt, LessThan +from sympy.logic import And, Or, Xor +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.utilities.iterables import cartes + +from sympy.abc import x, y, z, m, n + +EmptySet = S.EmptySet + +def test_imageset(): + ints = S.Integers + assert imageset(x, x - 1, S.Naturals) is S.Naturals0 + assert imageset(x, x + 1, S.Naturals0) is S.Naturals + assert imageset(x, abs(x), S.Naturals0) is S.Naturals0 + assert imageset(x, abs(x), S.Naturals) is S.Naturals + assert imageset(x, abs(x), S.Integers) is S.Naturals0 + # issue 16878a + r = symbols('r', real=True) + assert imageset(x, (x, x), S.Reals)._contains((1, r)) == None + assert imageset(x, (x, x), S.Reals)._contains((1, 2)) == False + assert (r, r) in imageset(x, (x, x), S.Reals) + assert 1 + I in imageset(x, x + I, S.Reals) + assert {1} not in imageset(x, (x,), S.Reals) + assert (1, 1) not in imageset(x, (x,), S.Reals) + raises(TypeError, lambda: imageset(x, ints)) + raises(ValueError, lambda: imageset(x, y, z, ints)) + raises(ValueError, lambda: imageset(Lambda(x, cos(x)), y)) + assert (1, 2) in imageset(Lambda((x, y), (x, y)), ints, ints) + raises(ValueError, lambda: imageset(Lambda(x, x), ints, ints)) + assert imageset(cos, ints) == ImageSet(Lambda(x, cos(x)), ints) + def f(x): + return cos(x) + assert imageset(f, ints) == imageset(x, cos(x), ints) + f = lambda x: cos(x) + assert imageset(f, ints) == ImageSet(Lambda(x, cos(x)), ints) + assert imageset(x, 1, ints) == FiniteSet(1) + assert imageset(x, y, ints) == {y} + assert imageset((x, y), (1, z), ints, S.Reals) == {(1, z)} + clash = Symbol('x', integer=true) + assert (str(imageset(lambda x: x + clash, Interval(-2, 1)).lamda.expr) + in ('x0 + x', 'x + x0')) + x1, x2 = symbols("x1, x2") + assert imageset(lambda x, y: + Add(x, y), Interval(1, 2), Interval(2, 3)).dummy_eq( + ImageSet(Lambda((x1, x2), x1 + x2), + Interval(1, 2), Interval(2, 3))) + + +def test_is_empty(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_empty is False + + assert S.EmptySet.is_empty is True + + +def test_is_finiteset(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_finite_set is False + + assert S.EmptySet.is_finite_set is True + + assert FiniteSet(1, 2).is_finite_set is True + assert Interval(1, 2).is_finite_set is False + assert Interval(x, y).is_finite_set is None + assert ProductSet(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert ProductSet(FiniteSet(1), Interval(1, 2)).is_finite_set is False + assert ProductSet(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Union(Interval(0, 1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert Union(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Intersection(Interval(x, y), FiniteSet(1)).is_finite_set is True + assert Intersection(Interval(x, y), Interval(1, 2)).is_finite_set is None + assert Intersection(FiniteSet(x), FiniteSet(y)).is_finite_set is True + assert Complement(FiniteSet(1), Interval(x, y)).is_finite_set is True + assert Complement(Interval(x, y), FiniteSet(1)).is_finite_set is None + assert Complement(Interval(1, 2), FiniteSet(x)).is_finite_set is False + assert DisjointUnion(Interval(-5, 3), FiniteSet(x, y)).is_finite_set is False + assert DisjointUnion(S.EmptySet, FiniteSet(x, y), S.EmptySet).is_finite_set is True + + +def test_deprecated_is_EmptySet(): + with warns_deprecated_sympy(): + S.EmptySet.is_EmptySet + + with warns_deprecated_sympy(): + FiniteSet(1).is_EmptySet + + +def test_interval_arguments(): + assert Interval(0, oo) == Interval(0, oo, False, True) + assert Interval(0, oo).right_open is true + assert Interval(-oo, 0) == Interval(-oo, 0, True, False) + assert Interval(-oo, 0).left_open is true + assert Interval(oo, -oo) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(-oo, -oo) == S.EmptySet + assert Interval(oo, x) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(x, -oo) == S.EmptySet + assert Interval(x, x) == {x} + + assert isinstance(Interval(1, 1), FiniteSet) + e = Sum(x, (x, 1, 3)) + assert isinstance(Interval(e, e), FiniteSet) + + assert Interval(1, 0) == S.EmptySet + assert Interval(1, 1).measure == 0 + + assert Interval(1, 1, False, True) == S.EmptySet + assert Interval(1, 1, True, False) == S.EmptySet + assert Interval(1, 1, True, True) == S.EmptySet + + + assert isinstance(Interval(0, Symbol('a')), Interval) + assert Interval(Symbol('a', positive=True), 0) == S.EmptySet + raises(ValueError, lambda: Interval(0, S.ImaginaryUnit)) + raises(ValueError, lambda: Interval(0, Symbol('z', extended_real=False))) + raises(ValueError, lambda: Interval(x, x + S.ImaginaryUnit)) + + raises(NotImplementedError, lambda: Interval(0, 1, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, False, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, z, And(x, y))) + + +def test_interval_symbolic_end_points(): + a = Symbol('a', real=True) + + assert Union(Interval(0, a), Interval(0, 3)).sup == Max(a, 3) + assert Union(Interval(a, 0), Interval(-3, 0)).inf == Min(-3, a) + + assert Interval(0, a).contains(1) == LessThan(1, a) + + +def test_interval_is_empty(): + x, y = symbols('x, y') + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + n = Symbol('n', negative=True) + nn = Symbol('nn', nonnegative=True) + assert Interval(1, 2).is_empty == False + assert Interval(3, 3).is_empty == False # FiniteSet + assert Interval(r, r).is_empty == False # FiniteSet + assert Interval(r, r + nn).is_empty == False + assert Interval(x, x).is_empty == False + assert Interval(1, oo).is_empty == False + assert Interval(-oo, oo).is_empty == False + assert Interval(-oo, 1).is_empty == False + assert Interval(x, y).is_empty == None + assert Interval(r, oo).is_empty == False # real implies finite + assert Interval(n, 0).is_empty == False + assert Interval(n, 0, left_open=True).is_empty == False + assert Interval(p, 0).is_empty == True # EmptySet + assert Interval(nn, 0).is_empty == None + assert Interval(n, p).is_empty == False + assert Interval(0, p, left_open=True).is_empty == False + assert Interval(0, p, right_open=True).is_empty == False + assert Interval(0, nn, left_open=True).is_empty == None + assert Interval(0, nn, right_open=True).is_empty == None + + +def test_union(): + assert Union(Interval(1, 2), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 2), Interval(2, 3, True)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(2, 4)) == Interval(1, 4) + assert Union(Interval(1, 2), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(1, 2)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(1, 2)) == \ + Interval(1, 3, False, True) + assert Union(Interval(1, 3), Interval(1, 2, False, True)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 2, True), Interval(1, 3, True, True)) == \ + Interval(1, 3, True, True) + assert Union(Interval(1, 2, True, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 3), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(2, 3)) == \ + Interval(1, 3) + assert Union(Interval(1, 2, False, True), Interval(2, 3, True)) != \ + Interval(1, 3) + assert Union(Interval(1, 2), S.EmptySet) == Interval(1, 2) + assert Union(S.EmptySet) == S.EmptySet + + assert Union(Interval(0, 1), *[FiniteSet(1.0/n) for n in range(1, 10)]) == \ + Interval(0, 1) + # issue #18241: + x = Symbol('x') + assert Union(Interval(0, 1), FiniteSet(1, x)) == Union( + Interval(0, 1), FiniteSet(x)) + assert unchanged(Union, Interval(0, 1), FiniteSet(2, x)) + + assert Interval(1, 2).union(Interval(2, 3)) == \ + Interval(1, 2) + Interval(2, 3) + + assert Interval(1, 2).union(Interval(2, 3)) == Interval(1, 3) + + assert Union(Set()) == Set() + + assert FiniteSet(1) + FiniteSet(2) + FiniteSet(3) == FiniteSet(1, 2, 3) + assert FiniteSet('ham') + FiniteSet('eggs') == FiniteSet('ham', 'eggs') + assert FiniteSet(1, 2, 3) + S.EmptySet == FiniteSet(1, 2, 3) + + assert FiniteSet(1, 2, 3) & FiniteSet(2, 3, 4) == FiniteSet(2, 3) + assert FiniteSet(1, 2, 3) | FiniteSet(2, 3, 4) == FiniteSet(1, 2, 3, 4) + + assert FiniteSet(1, 2, 3) & S.EmptySet == S.EmptySet + assert FiniteSet(1, 2, 3) | S.EmptySet == FiniteSet(1, 2, 3) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert S.EmptySet | FiniteSet(x, FiniteSet(y, z)) == \ + FiniteSet(x, FiniteSet(y, z)) + + # Test that Intervals and FiniteSets play nicely + assert Interval(1, 3) + FiniteSet(2) == Interval(1, 3) + assert Interval(1, 3, True, True) + FiniteSet(3) == \ + Interval(1, 3, True, False) + X = Interval(1, 3) + FiniteSet(5) + Y = Interval(1, 2) + FiniteSet(3) + XandY = X.intersect(Y) + assert 2 in X and 3 in X and 3 in XandY + assert XandY.is_subset(X) and XandY.is_subset(Y) + + raises(TypeError, lambda: Union(1, 2, 3)) + + assert X.is_iterable is False + + # issue 7843 + assert Union(S.EmptySet, FiniteSet(-sqrt(-I), sqrt(-I))) == \ + FiniteSet(-sqrt(-I), sqrt(-I)) + + assert Union(S.Reals, S.Integers) == S.Reals + + +def test_union_iter(): + # Use Range because it is ordered + u = Union(Range(3), Range(5), Range(4), evaluate=False) + + # Round robin + assert list(u) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4] + + +def test_union_is_empty(): + assert (Interval(x, y) + FiniteSet(1)).is_empty == False + assert (Interval(x, y) + Interval(-x, y)).is_empty == None + + +def test_difference(): + assert Interval(1, 3) - Interval(1, 2) == Interval(2, 3, True) + assert Interval(1, 3) - Interval(2, 3) == Interval(1, 2, False, True) + assert Interval(1, 3, True) - Interval(2, 3) == Interval(1, 2, True, True) + assert Interval(1, 3, True) - Interval(2, 3, True) == \ + Interval(1, 2, True, False) + assert Interval(0, 2) - FiniteSet(1) == \ + Union(Interval(0, 1, False, True), Interval(1, 2, True, False)) + + # issue #18119 + assert S.Reals - FiniteSet(I) == S.Reals + assert S.Reals - FiniteSet(-I, I) == S.Reals + assert Interval(0, 10) - FiniteSet(-I, I) == Interval(0, 10) + assert Interval(0, 10) - FiniteSet(1, I) == Union( + Interval.Ropen(0, 1), Interval.Lopen(1, 10)) + assert S.Reals - FiniteSet(1, 2 + I, x, y**2) == Complement( + Union(Interval.open(-oo, 1), Interval.open(1, oo)), FiniteSet(x, y**2), + evaluate=False) + + assert FiniteSet(1, 2, 3) - FiniteSet(2) == FiniteSet(1, 3) + assert FiniteSet('ham', 'eggs') - FiniteSet('eggs') == FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4) - Interval(2, 10, True, False) == \ + FiniteSet(1, 2) + assert FiniteSet(1, 2, 3, 4) - S.EmptySet == FiniteSet(1, 2, 3, 4) + assert Union(Interval(0, 2), FiniteSet(2, 3, 4)) - Interval(1, 3) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert -1 in S.Reals - S.Naturals + + +def test_Complement(): + A = FiniteSet(1, 3, 4) + B = FiniteSet(3, 4) + C = Interval(1, 3) + D = Interval(1, 2) + + assert Complement(A, B, evaluate=False).is_iterable is True + assert Complement(A, C, evaluate=False).is_iterable is True + assert Complement(C, D, evaluate=False).is_iterable is None + + assert FiniteSet(*Complement(A, B, evaluate=False)) == FiniteSet(1) + assert FiniteSet(*Complement(A, C, evaluate=False)) == FiniteSet(4) + raises(TypeError, lambda: FiniteSet(*Complement(C, A, evaluate=False))) + + assert Complement(Interval(1, 3), Interval(1, 2)) == Interval(2, 3, True) + assert Complement(FiniteSet(1, 3, 4), FiniteSet(3, 4)) == FiniteSet(1) + assert Complement(Union(Interval(0, 2), FiniteSet(2, 3, 4)), + Interval(1, 3)) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert 3 not in Complement(Interval(0, 5), Interval(1, 4), evaluate=False) + assert -1 in Complement(S.Reals, S.Naturals, evaluate=False) + assert 1 not in Complement(S.Reals, S.Naturals, evaluate=False) + + assert Complement(S.Integers, S.UniversalSet) == EmptySet + assert S.UniversalSet.complement(S.Integers) == EmptySet + + assert (0 not in S.Reals.intersect(S.Integers - FiniteSet(0))) + + assert S.EmptySet - S.Integers == S.EmptySet + + assert (S.Integers - FiniteSet(0)) - FiniteSet(1) == S.Integers - FiniteSet(0, 1) + + assert S.Reals - Union(S.Naturals, FiniteSet(pi)) == \ + Intersection(S.Reals - S.Naturals, S.Reals - FiniteSet(pi)) + # issue 12712 + assert Complement(FiniteSet(x, y, 2), Interval(-10, 10)) == \ + Complement(FiniteSet(x, y), Interval(-10, 10)) + + A = FiniteSet(*symbols('a:c')) + B = FiniteSet(*symbols('d:f')) + assert unchanged(Complement, ProductSet(A, A), B) + + A2 = ProductSet(A, A) + B3 = ProductSet(B, B, B) + assert A2 - B3 == A2 + assert B3 - A2 == B3 + + +def test_set_operations_nonsets(): + '''Tests that e.g. FiniteSet(1) * 2 raises TypeError''' + ops = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + lambda a, b: a / b, + lambda a, b: a // b, + lambda a, b: a | b, + lambda a, b: a & b, + lambda a, b: a ^ b, + # FiniteSet(1) ** 2 gives a ProductSet + #lambda a, b: a ** b, + ] + Sx = FiniteSet(x) + Sy = FiniteSet(y) + sets = [ + {1}, + FiniteSet(1), + Interval(1, 2), + Union(Sx, Interval(1, 2)), + Intersection(Sx, Sy), + Complement(Sx, Sy), + ProductSet(Sx, Sy), + S.EmptySet, + ] + nums = [0, 1, 2, S(0), S(1), S(2)] + + for si in sets: + for ni in nums: + for op in ops: + raises(TypeError, lambda : op(si, ni)) + raises(TypeError, lambda : op(ni, si)) + raises(TypeError, lambda: si ** object()) + raises(TypeError, lambda: si ** {1}) + + +def test_complement(): + assert Complement({1, 2}, {1}) == {2} + assert Interval(0, 1).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, True, True)) + assert Interval(0, 1, True, False).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, True, True)) + assert Interval(0, 1, False, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, False, True)) + assert Interval(0, 1, True, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, False, True)) + + assert S.UniversalSet.complement(S.EmptySet) == S.EmptySet + assert S.UniversalSet.complement(S.Reals) == S.EmptySet + assert S.UniversalSet.complement(S.UniversalSet) == S.EmptySet + + assert S.EmptySet.complement(S.Reals) == S.Reals + + assert Union(Interval(0, 1), Interval(2, 3)).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, 2, True, True), + Interval(3, oo, True, True)) + + assert FiniteSet(0).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(0, oo, True, True)) + + assert (FiniteSet(5) + Interval(S.NegativeInfinity, + 0)).complement(S.Reals) == \ + Interval(0, 5, True, True) + Interval(5, S.Infinity, True, True) + + assert FiniteSet(1, 2, 3).complement(S.Reals) == \ + Interval(S.NegativeInfinity, 1, True, True) + \ + Interval(1, 2, True, True) + Interval(2, 3, True, True) +\ + Interval(3, S.Infinity, True, True) + + assert FiniteSet(x).complement(S.Reals) == Complement(S.Reals, FiniteSet(x)) + + assert FiniteSet(0, x).complement(S.Reals) == Complement(Interval(-oo, 0, True, True) + + Interval(0, oo, True, True) + , FiniteSet(x), evaluate=False) + + square = Interval(0, 1) * Interval(0, 1) + notsquare = square.complement(S.Reals*S.Reals) + + assert all(pt in square for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any( + pt in notsquare for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any(pt in square for pt in [(-1, 0), (1.5, .5), (10, 10)]) + assert all(pt in notsquare for pt in [(-1, 0), (1.5, .5), (10, 10)]) + + +def test_intersect1(): + assert all(S.Integers.intersection(i) is i for i in + (S.Naturals, S.Naturals0)) + assert all(i.intersection(S.Integers) is i for i in + (S.Naturals, S.Naturals0)) + s = S.Naturals0 + assert S.Naturals.intersection(s) is S.Naturals + assert s.intersection(S.Naturals) is S.Naturals + x = Symbol('x') + assert Interval(0, 2).intersect(Interval(1, 2)) == Interval(1, 2) + assert Interval(0, 2).intersect(Interval(1, 2, True)) == \ + Interval(1, 2, True) + assert Interval(0, 2, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, False) + assert Interval(0, 2, True, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, True) + assert Interval(0, 2).intersect(Union(Interval(0, 1), Interval(2, 3))) == \ + Union(Interval(0, 1), Interval(2, 2)) + + assert FiniteSet(1, 2).intersect(FiniteSet(1, 2, 3)) == FiniteSet(1, 2) + assert FiniteSet(1, 2, x).intersect(FiniteSet(x)) == FiniteSet(x) + assert FiniteSet('ham', 'eggs').intersect(FiniteSet('ham')) == \ + FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4, 5).intersect(S.EmptySet) == S.EmptySet + + assert Interval(0, 5).intersect(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersect(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(0, 2)) == \ + Union(Interval(0, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2, True, True)) == \ + S.EmptySet + assert Union(Interval(0, 1), Interval(2, 3)).intersect(S.EmptySet) == \ + S.EmptySet + assert Union(Interval(0, 5), FiniteSet('ham')).intersect(FiniteSet(2, 3, 4, 5, 6)) == \ + Intersection(FiniteSet(2, 3, 4, 5, 6), Union(FiniteSet('ham'), Interval(0, 5))) + assert Intersection(FiniteSet(1, 2, 3), Interval(2, x), Interval(3, y)) == \ + Intersection(FiniteSet(3), Interval(2, x), Interval(3, y), evaluate=False) + assert Intersection(FiniteSet(1, 2), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + assert Intersection(FiniteSet(1, 2, 4), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + # XXX: Is the real=True necessary here? + # https://github.com/sympy/sympy/issues/17532 + m, n = symbols('m, n', real=True) + assert Intersection(FiniteSet(m), FiniteSet(m, n), Interval(m, m+1)) == \ + FiniteSet(m) + + # issue 8217 + assert Intersection(FiniteSet(x), FiniteSet(y)) == \ + Intersection(FiniteSet(x), FiniteSet(y), evaluate=False) + assert FiniteSet(x).intersect(S.Reals) == \ + Intersection(S.Reals, FiniteSet(x), evaluate=False) + + # tests for the intersection alias + assert Interval(0, 5).intersection(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersection(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersection(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + + # canonical boundary selected + a = sqrt(2*sqrt(6) + 5) + b = sqrt(2) + sqrt(3) + assert Interval(a, 4).intersection(Interval(b, 5)) == Interval(b, 4) + assert Interval(1, a).intersection(Interval(0, b)) == Interval(1, b) + + +def test_intersection_interval_float(): + # intersection of Intervals with mixed Rational/Float boundaries should + # lead to Float boundaries in all cases regardless of which Interval is + # open or closed. + typs = [ + (Interval, Interval, Interval), + (Interval, Interval.open, Interval.open), + (Interval, Interval.Lopen, Interval.Lopen), + (Interval, Interval.Ropen, Interval.Ropen), + (Interval.open, Interval.open, Interval.open), + (Interval.open, Interval.Lopen, Interval.open), + (Interval.open, Interval.Ropen, Interval.open), + (Interval.Lopen, Interval.Lopen, Interval.Lopen), + (Interval.Lopen, Interval.Ropen, Interval.open), + (Interval.Ropen, Interval.Ropen, Interval.Ropen), + ] + + as_float = lambda a1, a2: a2 if isinstance(a2, float) else a1 + + for t1, t2, t3 in typs: + for t1i, t2i in [(t1, t2), (t2, t1)]: + for a1, a2, b1, b2 in cartes([2, 2.0], [2, 2.0], [3, 3.0], [3, 3.0]): + I1 = t1(a1, b1) + I2 = t2(a2, b2) + I3 = t3(as_float(a1, a2), as_float(b1, b2)) + assert I1.intersect(I2) == I3 + + +def test_intersection(): + # iterable + i = Intersection(FiniteSet(1, 2, 3), Interval(2, 5), evaluate=False) + assert i.is_iterable + assert set(i) == {S(2), S(3)} + + # challenging intervals + x = Symbol('x', real=True) + i = Intersection(Interval(0, 3), Interval(x, 6)) + assert (5 in i) is False + raises(TypeError, lambda: 2 in i) + + # Singleton special cases + assert Intersection(Interval(0, 1), S.EmptySet) == S.EmptySet + assert Intersection(Interval(-oo, oo), Interval(-oo, x)) == Interval(-oo, x) + + # Products + line = Interval(0, 5) + i = Intersection(line**2, line**3, evaluate=False) + assert (2, 2) not in i + assert (2, 2, 2) not in i + raises(TypeError, lambda: list(i)) + + a = Intersection(Intersection(S.Integers, S.Naturals, evaluate=False), S.Reals, evaluate=False) + assert a._argset == frozenset([Intersection(S.Naturals, S.Integers, evaluate=False), S.Reals]) + + assert Intersection(S.Complexes, FiniteSet(S.ComplexInfinity)) == S.EmptySet + + # issue 12178 + assert Intersection() == S.UniversalSet + + # issue 16987 + assert Intersection({1}, {1}, {x}) == Intersection({1}, {x}) + + +def test_issue_9623(): + n = Symbol('n') + + a = S.Reals + b = Interval(0, oo) + c = FiniteSet(n) + + assert Intersection(a, b, c) == Intersection(b, c) + assert Intersection(Interval(1, 2), Interval(3, 4), FiniteSet(n)) == EmptySet + + +def test_is_disjoint(): + assert Interval(0, 2).is_disjoint(Interval(1, 2)) == False + assert Interval(0, 2).is_disjoint(Interval(3, 4)) == True + + +def test_ProductSet__len__(): + A = FiniteSet(1, 2) + B = FiniteSet(1, 2, 3) + assert ProductSet(A).__len__() == 2 + assert ProductSet(A).__len__() is not S(2) + assert ProductSet(A, B).__len__() == 6 + assert ProductSet(A, B).__len__() is not S(6) + + +def test_ProductSet(): + # ProductSet is always a set of Tuples + assert ProductSet(S.Reals) == S.Reals ** 1 + assert ProductSet(S.Reals, S.Reals) == S.Reals ** 2 + assert ProductSet(S.Reals, S.Reals, S.Reals) == S.Reals ** 3 + + assert ProductSet(S.Reals) != S.Reals + assert ProductSet(S.Reals, S.Reals) == S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) != S.Reals * S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) == (S.Reals * S.Reals * S.Reals).flatten() + + assert 1 not in ProductSet(S.Reals) + assert (1,) in ProductSet(S.Reals) + + assert 1 not in ProductSet(S.Reals, S.Reals) + assert (1, 2) in ProductSet(S.Reals, S.Reals) + assert (1, I) not in ProductSet(S.Reals, S.Reals) + + assert (1, 2, 3) in ProductSet(S.Reals, S.Reals, S.Reals) + assert (1, 2, 3) in S.Reals ** 3 + assert (1, 2, 3) not in S.Reals * S.Reals * S.Reals + assert ((1, 2), 3) in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) not in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) in S.Reals * (S.Reals * S.Reals) + + assert ProductSet() == FiniteSet(()) + assert ProductSet(S.Reals, S.EmptySet) == S.EmptySet + + # See GH-17458 + + for ni in range(5): + Rn = ProductSet(*(S.Reals,) * ni) + assert (1,) * ni in Rn + assert 1 not in Rn + + assert (S.Reals * S.Reals) * S.Reals != S.Reals * (S.Reals * S.Reals) + + S1 = S.Reals + S2 = S.Integers + x1 = pi + x2 = 3 + assert x1 in S1 + assert x2 in S2 + assert (x1, x2) in S1 * S2 + S3 = S1 * S2 + x3 = (x1, x2) + assert x3 in S3 + assert (x3, x3) in S3 * S3 + assert x3 + x3 not in S3 * S3 + + raises(ValueError, lambda: S.Reals**-1) + with warns_deprecated_sympy(): + ProductSet(FiniteSet(s) for s in range(2)) + raises(TypeError, lambda: ProductSet(None)) + + S1 = FiniteSet(1, 2) + S2 = FiniteSet(3, 4) + S3 = ProductSet(S1, S2) + assert (S3.as_relational(x, y) + == And(S1.as_relational(x), S2.as_relational(y)) + == And(Or(Eq(x, 1), Eq(x, 2)), Or(Eq(y, 3), Eq(y, 4)))) + raises(ValueError, lambda: S3.as_relational(x)) + raises(ValueError, lambda: S3.as_relational(x, 1)) + raises(ValueError, lambda: ProductSet(Interval(0, 1)).as_relational(x, y)) + + Z2 = ProductSet(S.Integers, S.Integers) + assert Z2.contains((1, 2)) is S.true + assert Z2.contains((1,)) is S.false + assert Z2.contains(x) == Contains(x, Z2, evaluate=False) + assert Z2.contains(x).subs(x, 1) is S.false + assert Z2.contains((x, 1)).subs(x, 2) is S.true + assert Z2.contains((x, y)) == Contains(x, S.Integers) & Contains(y, S.Integers) + assert unchanged(Contains, (x, y), Z2) + assert Contains((1, 2), Z2) is S.true + + +def test_ProductSet_of_single_arg_is_not_arg(): + assert unchanged(ProductSet, Interval(0, 1)) + assert unchanged(ProductSet, ProductSet(Interval(0, 1))) + + +def test_ProductSet_is_empty(): + assert ProductSet(S.Integers, S.Reals).is_empty == False + assert ProductSet(Interval(x, 1), S.Reals).is_empty == None + + +def test_interval_subs(): + a = Symbol('a', real=True) + + assert Interval(0, a).subs(a, 2) == Interval(0, 2) + assert Interval(a, 0).subs(a, 2) == S.EmptySet + + +def test_interval_to_mpi(): + assert Interval(0, 1).to_mpi() == mpi(0, 1) + assert Interval(0, 1, True, False).to_mpi() == mpi(0, 1) + assert type(Interval(0, 1).to_mpi()) == type(mpi(0, 1)) + + +def test_set_evalf(): + assert Interval(S(11)/64, S.Half).evalf() == Interval( + Float('0.171875'), Float('0.5')) + assert Interval(x, S.Half, right_open=True).evalf() == Interval( + x, Float('0.5'), right_open=True) + assert Interval(-oo, S.Half).evalf() == Interval(-oo, Float('0.5')) + assert FiniteSet(2, x).evalf() == FiniteSet(Float('2.0'), x) + + +def test_measure(): + a = Symbol('a', real=True) + + assert Interval(1, 3).measure == 2 + assert Interval(0, a).measure == a + assert Interval(1, a).measure == a - 1 + + assert Union(Interval(1, 2), Interval(3, 4)).measure == 2 + assert Union(Interval(1, 2), Interval(3, 4), FiniteSet(5, 6, 7)).measure \ + == 2 + + assert FiniteSet(1, 2, oo, a, -oo, -5).measure == 0 + + assert S.EmptySet.measure == 0 + + square = Interval(0, 10) * Interval(0, 10) + offsetsquare = Interval(5, 15) * Interval(5, 15) + band = Interval(-oo, oo) * Interval(2, 4) + + assert square.measure == offsetsquare.measure == 100 + assert (square + offsetsquare).measure == 175 # there is some overlap + assert (square - offsetsquare).measure == 75 + assert (square * FiniteSet(1, 2, 3)).measure == 0 + assert (square.intersect(band)).measure == 20 + assert (square + band).measure is oo + assert (band * FiniteSet(1, 2, 3)).measure is nan + + +def test_is_subset(): + assert Interval(0, 1).is_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_subset(Interval(0, 2)) is False + assert Interval(0, 1).is_subset(FiniteSet(0, 1)) is False + + assert FiniteSet(1, 2).is_subset(FiniteSet(1, 2, 3, 4)) + assert FiniteSet(4, 5).is_subset(FiniteSet(1, 2, 3, 4)) is False + assert FiniteSet(1).is_subset(Interval(0, 2)) + assert FiniteSet(1, 2).is_subset(Interval(0, 2, True, True)) is False + assert (Interval(1, 2) + FiniteSet(3)).is_subset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) + + assert Interval(3, 4).is_subset(Union(Interval(0, 1), Interval(2, 5))) is True + assert Interval(3, 6).is_subset(Union(Interval(0, 1), Interval(2, 5))) is False + + assert FiniteSet(1, 2, 3, 4).is_subset(Interval(0, 5)) is True + assert S.EmptySet.is_subset(FiniteSet(1, 2, 3)) is True + + assert Interval(0, 1).is_subset(S.EmptySet) is False + assert S.EmptySet.is_subset(S.EmptySet) is True + + raises(ValueError, lambda: S.EmptySet.is_subset(1)) + + # tests for the issubset alias + assert FiniteSet(1, 2, 3, 4).issubset(Interval(0, 5)) is True + assert S.EmptySet.issubset(FiniteSet(1, 2, 3)) is True + + assert S.Naturals.is_subset(S.Integers) + assert S.Naturals0.is_subset(S.Integers) + + assert FiniteSet(x).is_subset(FiniteSet(y)) is None + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x)) is True + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x+1)) is False + + assert Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) is False + assert Interval(-2, 3).is_subset(Union(Interval(-oo, -2), Interval(3, oo))) is False + + n = Symbol('n', integer=True) + assert Range(-3, 4, 1).is_subset(FiniteSet(-10, 10)) is False + assert Range(S(10)**100).is_subset(FiniteSet(0, 1, 2)) is False + assert Range(6, 0, -2).is_subset(FiniteSet(2, 4, 6)) is True + assert Range(1, oo).is_subset(FiniteSet(1, 2)) is False + assert Range(-oo, 1).is_subset(FiniteSet(1)) is False + assert Range(3).is_subset(FiniteSet(0, 1, n)) is None + assert Range(n, n + 2).is_subset(FiniteSet(n, n + 1)) is True + assert Range(5).is_subset(Interval(0, 4, right_open=True)) is False + #issue 19513 + assert imageset(Lambda(n, 1/n), S.Integers).is_subset(S.Reals) is None + +def test_is_proper_subset(): + assert Interval(0, 1).is_proper_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_proper_subset(Interval(0, 2)) is False + assert S.EmptySet.is_proper_subset(FiniteSet(1, 2, 3)) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_subset(0)) + + +def test_is_superset(): + assert Interval(0, 1).is_superset(Interval(0, 2)) == False + assert Interval(0, 3).is_superset(Interval(0, 2)) + + assert FiniteSet(1, 2).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(4, 5).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(1).is_superset(Interval(0, 2)) == False + assert FiniteSet(1, 2).is_superset(Interval(0, 2, True, True)) == False + assert (Interval(1, 2) + FiniteSet(3)).is_superset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) == False + + assert Interval(3, 4).is_superset(Union(Interval(0, 1), Interval(2, 5))) == False + + assert FiniteSet(1, 2, 3, 4).is_superset(Interval(0, 5)) == False + assert S.EmptySet.is_superset(FiniteSet(1, 2, 3)) == False + + assert Interval(0, 1).is_superset(S.EmptySet) == True + assert S.EmptySet.is_superset(S.EmptySet) == True + + raises(ValueError, lambda: S.EmptySet.is_superset(1)) + + # tests for the issuperset alias + assert Interval(0, 1).issuperset(S.EmptySet) == True + assert S.EmptySet.issuperset(S.EmptySet) == True + + +def test_is_proper_superset(): + assert Interval(0, 1).is_proper_superset(Interval(0, 2)) is False + assert Interval(0, 3).is_proper_superset(Interval(0, 2)) is True + assert FiniteSet(1, 2, 3).is_proper_superset(S.EmptySet) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_superset(0)) + + +def test_contains(): + assert Interval(0, 2).contains(1) is S.true + assert Interval(0, 2).contains(3) is S.false + assert Interval(0, 2, True, False).contains(0) is S.false + assert Interval(0, 2, True, False).contains(2) is S.true + assert Interval(0, 2, False, True).contains(0) is S.true + assert Interval(0, 2, False, True).contains(2) is S.false + assert Interval(0, 2, True, True).contains(0) is S.false + assert Interval(0, 2, True, True).contains(2) is S.false + + assert (Interval(0, 2) in Interval(0, 2)) is False + + assert FiniteSet(1, 2, 3).contains(2) is S.true + assert FiniteSet(1, 2, Symbol('x')).contains(Symbol('x')) is S.true + + assert FiniteSet(y)._contains(x) == Eq(y, x, evaluate=False) + raises(TypeError, lambda: x in FiniteSet(y)) + assert FiniteSet({x, y})._contains({x}) == Eq({x, y}, {x}, evaluate=False) + assert FiniteSet({x, y}).subs(y, x)._contains({x}) is S.true + assert FiniteSet({x, y}).subs(y, x+1)._contains({x}) is S.false + + # issue 8197 + from sympy.abc import a, b + assert FiniteSet(b).contains(-a) == Eq(b, -a) + assert FiniteSet(b).contains(a) == Eq(b, a) + assert FiniteSet(a).contains(1) == Eq(a, 1) + raises(TypeError, lambda: 1 in FiniteSet(a)) + + # issue 8209 + rad1 = Pow(Pow(2, Rational(1, 3)) - 1, Rational(1, 3)) + rad2 = Pow(Rational(1, 9), Rational(1, 3)) - Pow(Rational(2, 9), Rational(1, 3)) + Pow(Rational(4, 9), Rational(1, 3)) + s1 = FiniteSet(rad1) + s2 = FiniteSet(rad2) + assert s1 - s2 == S.EmptySet + + items = [1, 2, S.Infinity, S('ham'), -1.1] + fset = FiniteSet(*items) + assert all(item in fset for item in items) + assert all(fset.contains(item) is S.true for item in items) + + assert Union(Interval(0, 1), Interval(2, 5)).contains(3) is S.true + assert Union(Interval(0, 1), Interval(2, 5)).contains(6) is S.false + assert Union(Interval(0, 1), FiniteSet(2, 5)).contains(3) is S.false + + assert S.EmptySet.contains(1) is S.false + assert FiniteSet(rootof(x**3 + x - 1, 0)).contains(S.Infinity) is S.false + + assert rootof(x**5 + x**3 + 1, 0) in S.Reals + assert not rootof(x**5 + x**3 + 1, 1) in S.Reals + + # non-bool results + assert Union(Interval(1, 2), Interval(3, 4)).contains(x) == \ + Or(And(S.One <= x, x <= 2), And(S(3) <= x, x <= 4)) + assert Intersection(Interval(1, x), Interval(2, 3)).contains(y) == \ + And(y <= 3, y <= x, S.One <= y, S(2) <= y) + + assert (S.Complexes).contains(S.ComplexInfinity) == S.false + + +def test_interval_symbolic(): + x = Symbol('x') + e = Interval(0, 1) + assert e.contains(x) == And(S.Zero <= x, x <= 1) + raises(TypeError, lambda: x in e) + e = Interval(0, 1, True, True) + assert e.contains(x) == And(S.Zero < x, x < 1) + c = Symbol('c', real=False) + assert Interval(x, x + 1).contains(c) == False + e = Symbol('e', extended_real=True) + assert Interval(-oo, oo).contains(e) == And( + S.NegativeInfinity < e, e < S.Infinity) + + +def test_union_contains(): + x = Symbol('x') + i1 = Interval(0, 1) + i2 = Interval(2, 3) + i3 = Union(i1, i2) + assert i3.as_relational(x) == Or(And(S.Zero <= x, x <= 1), And(S(2) <= x, x <= 3)) + raises(TypeError, lambda: x in i3) + e = i3.contains(x) + assert e == i3.as_relational(x) + assert e.subs(x, -0.5) is false + assert e.subs(x, 0.5) is true + assert e.subs(x, 1.5) is false + assert e.subs(x, 2.5) is true + assert e.subs(x, 3.5) is false + + U = Interval(0, 2, True, True) + Interval(10, oo) + FiniteSet(-1, 2, 5, 6) + assert all(el not in U for el in [0, 4, -oo]) + assert all(el in U for el in [2, 5, 10]) + + +def test_is_number(): + assert Interval(0, 1).is_number is False + assert Set().is_number is False + + +def test_Interval_is_left_unbounded(): + assert Interval(3, 4).is_left_unbounded is False + assert Interval(-oo, 3).is_left_unbounded is True + assert Interval(Float("-inf"), 3).is_left_unbounded is True + + +def test_Interval_is_right_unbounded(): + assert Interval(3, 4).is_right_unbounded is False + assert Interval(3, oo).is_right_unbounded is True + assert Interval(3, Float("+inf")).is_right_unbounded is True + + +def test_Interval_as_relational(): + x = Symbol('x') + + assert Interval(-1, 2, False, False).as_relational(x) == \ + And(Le(-1, x), Le(x, 2)) + assert Interval(-1, 2, True, False).as_relational(x) == \ + And(Lt(-1, x), Le(x, 2)) + assert Interval(-1, 2, False, True).as_relational(x) == \ + And(Le(-1, x), Lt(x, 2)) + assert Interval(-1, 2, True, True).as_relational(x) == \ + And(Lt(-1, x), Lt(x, 2)) + + assert Interval(-oo, 2, right_open=False).as_relational(x) == And(Lt(-oo, x), Le(x, 2)) + assert Interval(-oo, 2, right_open=True).as_relational(x) == And(Lt(-oo, x), Lt(x, 2)) + + assert Interval(-2, oo, left_open=False).as_relational(x) == And(Le(-2, x), Lt(x, oo)) + assert Interval(-2, oo, left_open=True).as_relational(x) == And(Lt(-2, x), Lt(x, oo)) + + assert Interval(-oo, oo).as_relational(x) == And(Lt(-oo, x), Lt(x, oo)) + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert Interval(x, y).as_relational(x) == (x <= y) + assert Interval(y, x).as_relational(x) == (y <= x) + + +def test_Finite_as_relational(): + x = Symbol('x') + y = Symbol('y') + + assert FiniteSet(1, 2).as_relational(x) == Or(Eq(x, 1), Eq(x, 2)) + assert FiniteSet(y, -5).as_relational(x) == Or(Eq(x, y), Eq(x, -5)) + + +def test_Union_as_relational(): + x = Symbol('x') + assert (Interval(0, 1) + FiniteSet(2)).as_relational(x) == \ + Or(And(Le(0, x), Le(x, 1)), Eq(x, 2)) + assert (Interval(0, 1, True, True) + FiniteSet(1)).as_relational(x) == \ + And(Lt(0, x), Le(x, 1)) + assert Or(x < 0, x > 0).as_set().as_relational(x) == \ + And((x > -oo), (x < oo), Ne(x, 0)) + assert (Interval.Ropen(1, 3) + Interval.Lopen(3, 5) + ).as_relational(x) == And(Ne(x,3),(x>=1),(x<=5)) + + +def test_Intersection_as_relational(): + x = Symbol('x') + assert (Intersection(Interval(0, 1), FiniteSet(2), + evaluate=False).as_relational(x) + == And(And(Le(0, x), Le(x, 1)), Eq(x, 2))) + + +def test_Complement_as_relational(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == \ + And(Le(0, x), Le(x, 1), Ne(x, 2)) + + +@XFAIL +def test_Complement_as_relational_fail(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + # XXX This example fails because 0 <= x changes to x >= 0 + # during the evaluation. + assert expr.as_relational(x) == \ + (0 <= x) & (x <= 1) & Ne(x, 2) + + +def test_SymmetricDifference_as_relational(): + x = Symbol('x') + expr = SymmetricDifference(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == Xor(Eq(x, 2), Le(0, x) & Le(x, 1)) + + +def test_EmptySet(): + assert S.EmptySet.as_relational(Symbol('x')) is S.false + assert S.EmptySet.intersect(S.UniversalSet) == S.EmptySet + assert S.EmptySet.boundary == S.EmptySet + + +def test_finite_basic(): + x = Symbol('x') + A = FiniteSet(1, 2, 3) + B = FiniteSet(3, 4, 5) + AorB = Union(A, B) + AandB = A.intersect(B) + assert A.is_subset(AorB) and B.is_subset(AorB) + assert AandB.is_subset(A) + assert AandB == FiniteSet(3) + + assert A.inf == 1 and A.sup == 3 + assert AorB.inf == 1 and AorB.sup == 5 + assert FiniteSet(x, 1, 5).sup == Max(x, 5) + assert FiniteSet(x, 1, 5).inf == Min(x, 1) + + # issue 7335 + assert FiniteSet(S.EmptySet) != S.EmptySet + assert FiniteSet(FiniteSet(1, 2, 3)) != FiniteSet(1, 2, 3) + assert FiniteSet((1, 2, 3)) != FiniteSet(1, 2, 3) + + # Ensure a variety of types can exist in a FiniteSet + assert FiniteSet((1, 2), A, -5, x, 'eggs', x**2) + + assert (A > B) is False + assert (A >= B) is False + assert (A < B) is False + assert (A <= B) is False + assert AorB > A and AorB > B + assert AorB >= A and AorB >= B + assert A >= A and A <= A + assert A >= AandB and B >= AandB + assert A > AandB and B > AandB + + +def test_product_basic(): + H, T = 'H', 'T' + unit_line = Interval(0, 1) + d6 = FiniteSet(1, 2, 3, 4, 5, 6) + d4 = FiniteSet(1, 2, 3, 4) + coin = FiniteSet(H, T) + + square = unit_line * unit_line + + assert (0, 0) in square + assert 0 not in square + assert (H, T) in coin ** 2 + assert (.5, .5, .5) in (square * unit_line).flatten() + assert ((.5, .5), .5) in square * unit_line + assert (H, 3, 3) in (coin * d6 * d6).flatten() + assert ((H, 3), 3) in coin * d6 * d6 + HH, TT = sympify(H), sympify(T) + assert set(coin**2) == {(HH, HH), (HH, TT), (TT, HH), (TT, TT)} + + assert (d4*d4).is_subset(d6*d6) + + assert square.complement(Interval(-oo, oo)*Interval(-oo, oo)) == Union( + (Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))*Interval(-oo, oo), + Interval(-oo, oo)*(Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))) + + assert (Interval(-5, 5)**3).is_subset(Interval(-10, 10)**3) + assert not (Interval(-10, 10)**3).is_subset(Interval(-5, 5)**3) + assert not (Interval(-5, 5)**2).is_subset(Interval(-10, 10)**3) + + assert (Interval(.2, .5)*FiniteSet(.5)).is_subset(square) # segment in square + + assert len(coin*coin*coin) == 8 + assert len(S.EmptySet*S.EmptySet) == 0 + assert len(S.EmptySet*coin) == 0 + raises(TypeError, lambda: len(coin*Interval(0, 2))) + + +def test_real(): + x = Symbol('x', real=True) + + I = Interval(0, 5) + J = Interval(10, 20) + A = FiniteSet(1, 2, 30, x, S.Pi) + B = FiniteSet(-4, 0) + C = FiniteSet(100) + D = FiniteSet('Ham', 'Eggs') + + assert all(s.is_subset(S.Reals) for s in [I, J, A, B, C]) + assert not D.is_subset(S.Reals) + assert all((a + b).is_subset(S.Reals) for a in [I, J, A, B, C] for b in [I, J, A, B, C]) + assert not any((a + D).is_subset(S.Reals) for a in [I, J, A, B, C, D]) + + assert not (I + A + D).is_subset(S.Reals) + + +def test_supinf(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert (Interval(0, 1) + FiniteSet(2)).sup == 2 + assert (Interval(0, 1) + FiniteSet(2)).inf == 0 + assert (Interval(0, 1) + FiniteSet(x)).sup == Max(1, x) + assert (Interval(0, 1) + FiniteSet(x)).inf == Min(0, x) + assert FiniteSet(5, 1, x).sup == Max(5, x) + assert FiniteSet(5, 1, x).inf == Min(1, x) + assert FiniteSet(5, 1, x, y).sup == Max(5, x, y) + assert FiniteSet(5, 1, x, y).inf == Min(1, x, y) + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).sup == \ + S.Infinity + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).inf == \ + S.NegativeInfinity + assert FiniteSet('Ham', 'Eggs').sup == Max('Ham', 'Eggs') + + +def test_universalset(): + U = S.UniversalSet + x = Symbol('x') + assert U.as_relational(x) is S.true + assert U.union(Interval(2, 4)) == U + + assert U.intersect(Interval(2, 4)) == Interval(2, 4) + assert U.measure is S.Infinity + assert U.boundary == S.EmptySet + assert U.contains(0) is S.true + + +def test_Union_of_ProductSets_shares(): + line = Interval(0, 2) + points = FiniteSet(0, 1, 2) + assert Union(line * line, line * points) == line * line + + +def test_Interval_free_symbols(): + # issue 6211 + assert Interval(0, 1).free_symbols == set() + x = Symbol('x', real=True) + assert Interval(0, x).free_symbols == {x} + + +def test_image_interval(): + x = Symbol('x', real=True) + a = Symbol('a', real=True) + assert imageset(x, 2*x, Interval(-2, 1)) == Interval(-4, 2) + assert imageset(x, 2*x, Interval(-2, 1, True, False)) == \ + Interval(-4, 2, True, False) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1)) == Interval(0, 4) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1, True, True)) == \ + Interval(0, 4, False, True) + assert imageset(x, (x - 2)**2, Interval(1, 3)) == Interval(0, 1) + assert imageset(x, 3*x**4 - 26*x**3 + 78*x**2 - 90*x, Interval(0, 4)) == \ + Interval(-35, 0) # Multiple Maxima + assert imageset(x, x + 1/x, Interval(-oo, oo)) == Interval(-oo, -2) \ + + Interval(2, oo) # Single Infinite discontinuity + assert imageset(x, 1/x + 1/(x-1)**2, Interval(0, 2, True, False)) == \ + Interval(Rational(3, 2), oo, False) # Multiple Infinite discontinuities + + # Test for Python lambda + assert imageset(lambda x: 2*x, Interval(-2, 1)) == Interval(-4, 2) + + assert imageset(Lambda(x, a*x), Interval(0, 1)) == \ + ImageSet(Lambda(x, a*x), Interval(0, 1)) + + assert imageset(Lambda(x, sin(cos(x))), Interval(0, 1)) == \ + ImageSet(Lambda(x, sin(cos(x))), Interval(0, 1)) + + +def test_image_piecewise(): + f = Piecewise((x, x <= -1), (1/x**2, x <= 5), (x**3, True)) + f1 = Piecewise((0, x <= 1), (1, x <= 2), (2, True)) + assert imageset(x, f, Interval(-5, 5)) == Union(Interval(-5, -1), Interval(Rational(1, 25), oo)) + assert imageset(x, f1, Interval(1, 2)) == FiniteSet(0, 1) + + +@XFAIL # See: https://github.com/sympy/sympy/pull/2723#discussion_r8659826 +def test_image_Intersection(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert imageset(x, x**2, Interval(-2, 0).intersect(Interval(x, y))) == \ + Interval(0, 4).intersect(Interval(Min(x**2, y**2), Max(x**2, y**2))) + + +def test_image_FiniteSet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, FiniteSet(1, 2, 3)) == FiniteSet(2, 4, 6) + + +def test_image_Union(): + x = Symbol('x', real=True) + assert imageset(x, x**2, Interval(-2, 0) + FiniteSet(1, 2, 3)) == \ + (Interval(0, 4) + FiniteSet(9)) + + +def test_image_EmptySet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, S.EmptySet) == S.EmptySet + + +def test_issue_5724_7680(): + assert I not in S.Reals # issue 7680 + assert Interval(-oo, oo).contains(I) is S.false + + +def test_boundary(): + assert FiniteSet(1).boundary == FiniteSet(1) + assert all(Interval(0, 1, left_open, right_open).boundary == FiniteSet(0, 1) + for left_open in (true, false) for right_open in (true, false)) + + +def test_boundary_Union(): + assert (Interval(0, 1) + Interval(2, 3)).boundary == FiniteSet(0, 1, 2, 3) + assert ((Interval(0, 1, False, True) + + Interval(1, 2, True, False)).boundary == FiniteSet(0, 1, 2)) + + assert (Interval(0, 1) + FiniteSet(2)).boundary == FiniteSet(0, 1, 2) + assert Union(Interval(0, 10), Interval(5, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + assert Union(Interval(0, 10), Interval(0, 1), evaluate=False).boundary \ + == FiniteSet(0, 10) + assert Union(Interval(0, 10, True, True), + Interval(10, 15, True, True), evaluate=False).boundary \ + == FiniteSet(0, 10, 15) + + +@XFAIL +def test_union_boundary_of_joining_sets(): + """ Testing the boundary of unions is a hard problem """ + assert Union(Interval(0, 10), Interval(10, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + +def test_boundary_ProductSet(): + open_square = Interval(0, 1, True, True) ** 2 + assert open_square.boundary == (FiniteSet(0, 1) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1)) + + second_square = Interval(1, 2, True, True) * Interval(0, 1, True, True) + assert (open_square + second_square).boundary == ( + FiniteSet(0, 1) * Interval(0, 1) + + FiniteSet(1, 2) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1) + + Interval(1, 2) * FiniteSet(0, 1)) + + +def test_boundary_ProductSet_line(): + line_in_r2 = Interval(0, 1) * FiniteSet(0) + assert line_in_r2.boundary == line_in_r2 + + +def test_is_open(): + assert Interval(0, 1, False, False).is_open is False + assert Interval(0, 1, True, False).is_open is False + assert Interval(0, 1, True, True).is_open is True + assert FiniteSet(1, 2, 3).is_open is False + + +def test_is_closed(): + assert Interval(0, 1, False, False).is_closed is True + assert Interval(0, 1, True, False).is_closed is False + assert FiniteSet(1, 2, 3).is_closed is True + + +def test_closure(): + assert Interval(0, 1, False, True).closure == Interval(0, 1, False, False) + + +def test_interior(): + assert Interval(0, 1, False, True).interior == Interval(0, 1, True, True) + + +def test_issue_7841(): + raises(TypeError, lambda: x in S.Reals) + + +def test_Eq(): + assert Eq(Interval(0, 1), Interval(0, 1)) + assert Eq(Interval(0, 1), Interval(0, 2)) == False + + s1 = FiniteSet(0, 1) + s2 = FiniteSet(1, 2) + + assert Eq(s1, s1) + assert Eq(s1, s2) == False + + assert Eq(s1*s2, s1*s2) + assert Eq(s1*s2, s2*s1) == False + + assert unchanged(Eq, FiniteSet({x, y}), FiniteSet({x})) + assert Eq(FiniteSet({x, y}).subs(y, x), FiniteSet({x})) is S.true + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x) is S.true + assert Eq(FiniteSet({x, y}).subs(y, x+1), FiniteSet({x})) is S.false + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x+1) is S.false + + assert Eq(ProductSet({1}, {2}), Interval(1, 2)) is S.false + assert Eq(ProductSet({1}), ProductSet({1}, {2})) is S.false + + assert Eq(FiniteSet(()), FiniteSet(1)) is S.false + assert Eq(ProductSet(), FiniteSet(1)) is S.false + + i1 = Interval(0, 1) + i2 = Interval(x, y) + assert unchanged(Eq, ProductSet(i1, i1), ProductSet(i2, i2)) + + +def test_SymmetricDifference(): + A = FiniteSet(0, 1, 2, 3, 4, 5) + B = FiniteSet(2, 4, 6, 8, 10) + C = Interval(8, 10) + + assert SymmetricDifference(A, B, evaluate=False).is_iterable is True + assert SymmetricDifference(A, C, evaluate=False).is_iterable is None + assert FiniteSet(*SymmetricDifference(A, B, evaluate=False)) == \ + FiniteSet(0, 1, 3, 5, 6, 8, 10) + raises(TypeError, + lambda: FiniteSet(*SymmetricDifference(A, C, evaluate=False))) + + assert SymmetricDifference(FiniteSet(0, 1, 2, 3, 4, 5), \ + FiniteSet(2, 4, 6, 8, 10)) == FiniteSet(0, 1, 3, 5, 6, 8, 10) + assert SymmetricDifference(FiniteSet(2, 3, 4), FiniteSet(2, 3, 4 ,5)) \ + == FiniteSet(5) + assert FiniteSet(1, 2, 3, 4, 5) ^ FiniteSet(1, 2, 5, 6) == \ + FiniteSet(3, 4, 6) + assert Set(S(1), S(2), S(3)) ^ Set(S(2), S(3), S(4)) == Union(Set(S(1), S(2), S(3)) - Set(S(2), S(3), S(4)), \ + Set(S(2), S(3), S(4)) - Set(S(1), S(2), S(3))) + assert Interval(0, 4) ^ Interval(2, 5) == Union(Interval(0, 4) - \ + Interval(2, 5), Interval(2, 5) - Interval(0, 4)) + + +def test_issue_9536(): + from sympy.functions.elementary.exponential import log + a = Symbol('a', real=True) + assert FiniteSet(log(a)).intersect(S.Reals) == Intersection(S.Reals, FiniteSet(log(a))) + + +def test_issue_9637(): + n = Symbol('n') + a = FiniteSet(n) + b = FiniteSet(2, n) + assert Complement(S.Reals, a) == Complement(S.Reals, a, evaluate=False) + assert Complement(Interval(1, 3), a) == Complement(Interval(1, 3), a, evaluate=False) + assert Complement(Interval(1, 3), b) == \ + Complement(Union(Interval(1, 2, False, True), Interval(2, 3, True, False)), a) + assert Complement(a, S.Reals) == Complement(a, S.Reals, evaluate=False) + assert Complement(a, Interval(1, 3)) == Complement(a, Interval(1, 3), evaluate=False) + + +def test_issue_9808(): + # See https://github.com/sympy/sympy/issues/16342 + assert Complement(FiniteSet(y), FiniteSet(1)) == Complement(FiniteSet(y), FiniteSet(1), evaluate=False) + assert Complement(FiniteSet(1, 2, x), FiniteSet(x, y, 2, 3)) == \ + Complement(FiniteSet(1), FiniteSet(y), evaluate=False) + + +def test_issue_9956(): + assert Union(Interval(-oo, oo), FiniteSet(1)) == Interval(-oo, oo) + assert Interval(-oo, oo).contains(1) is S.true + + +def test_issue_Symbol_inter(): + i = Interval(0, oo) + r = S.Reals + mat = Matrix([0, 0, 0]) + assert Intersection(r, i, FiniteSet(m), FiniteSet(m, n)) == \ + Intersection(i, FiniteSet(m)) + assert Intersection(FiniteSet(1, m, n), FiniteSet(m, n, 2), i) == \ + Intersection(i, FiniteSet(m, n)) + assert Intersection(FiniteSet(m, n, x), FiniteSet(m, z), r) == \ + Intersection(Intersection({m, z}, {m, n, x}), r) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, x), r) == \ + Intersection(FiniteSet(3, m, n), FiniteSet(m, n, x), r, evaluate=False) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, 2, 3), r) == \ + Intersection(FiniteSet(3, m, n), r) + assert Intersection(r, FiniteSet(mat, 2, n), FiniteSet(0, mat, n)) == \ + Intersection(r, FiniteSet(n)) + assert Intersection(FiniteSet(sin(x), cos(x)), FiniteSet(sin(x), cos(x), 1), r) == \ + Intersection(r, FiniteSet(sin(x), cos(x))) + assert Intersection(FiniteSet(x**2, 1, sin(x)), FiniteSet(x**2, 2, sin(x)), r) == \ + Intersection(r, FiniteSet(x**2, sin(x))) + + +def test_issue_11827(): + assert S.Naturals0**4 + + +def test_issue_10113(): + f = x**2/(x**2 - 4) + assert imageset(x, f, S.Reals) == Union(Interval(-oo, 0), Interval(1, oo, True, True)) + assert imageset(x, f, Interval(-2, 2)) == Interval(-oo, 0) + assert imageset(x, f, Interval(-2, 3)) == Union(Interval(-oo, 0), Interval(Rational(9, 5), oo)) + + +def test_issue_10248(): + raises( + TypeError, lambda: list(Intersection(S.Reals, FiniteSet(x))) + ) + A = Symbol('A', real=True) + assert list(Intersection(S.Reals, FiniteSet(A))) == [A] + + +def test_issue_9447(): + a = Interval(0, 1) + Interval(2, 3) + assert Complement(S.UniversalSet, a) == Complement( + S.UniversalSet, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + assert Complement(S.Naturals, a) == Complement( + S.Naturals, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + + +def test_issue_10337(): + assert (FiniteSet(2) == 3) is False + assert (FiniteSet(2) != 3) is True + raises(TypeError, lambda: FiniteSet(2) < 3) + raises(TypeError, lambda: FiniteSet(2) <= 3) + raises(TypeError, lambda: FiniteSet(2) > 3) + raises(TypeError, lambda: FiniteSet(2) >= 3) + + +def test_issue_10326(): + bad = [ + EmptySet, + FiniteSet(1), + Interval(1, 2), + S.ComplexInfinity, + S.ImaginaryUnit, + S.Infinity, + S.NaN, + S.NegativeInfinity, + ] + interval = Interval(0, 5) + for i in bad: + assert i not in interval + + x = Symbol('x', real=True) + nr = Symbol('nr', extended_real=False) + assert x + 1 in Interval(x, x + 4) + assert nr not in Interval(x, x + 4) + assert Interval(1, 2) in FiniteSet(Interval(0, 5), Interval(1, 2)) + assert Interval(-oo, oo).contains(oo) is S.false + assert Interval(-oo, oo).contains(-oo) is S.false + + +def test_issue_2799(): + U = S.UniversalSet + a = Symbol('a', real=True) + inf_interval = Interval(a, oo) + R = S.Reals + + assert U + inf_interval == inf_interval + U + assert U + R == R + U + assert R + inf_interval == inf_interval + R + + +def test_issue_9706(): + assert Interval(-oo, 0).closure == Interval(-oo, 0, True, False) + assert Interval(0, oo).closure == Interval(0, oo, False, True) + assert Interval(-oo, oo).closure == Interval(-oo, oo) + + +def test_issue_8257(): + reals_plus_infinity = Union(Interval(-oo, oo), FiniteSet(oo)) + reals_plus_negativeinfinity = Union(Interval(-oo, oo), FiniteSet(-oo)) + assert Interval(-oo, oo) + FiniteSet(oo) == reals_plus_infinity + assert FiniteSet(oo) + Interval(-oo, oo) == reals_plus_infinity + assert Interval(-oo, oo) + FiniteSet(-oo) == reals_plus_negativeinfinity + assert FiniteSet(-oo) + Interval(-oo, oo) == reals_plus_negativeinfinity + + +def test_issue_10931(): + assert S.Integers - S.Integers == EmptySet + assert S.Integers - S.Reals == EmptySet + + +def test_issue_11174(): + soln = Intersection(Interval(-oo, oo), FiniteSet(-x), evaluate=False) + assert Intersection(FiniteSet(-x), S.Reals) == soln + + soln = Intersection(S.Reals, FiniteSet(x), evaluate=False) + assert Intersection(FiniteSet(x), S.Reals) == soln + + +def test_issue_18505(): + assert ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers).contains(0) == \ + Contains(0, ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers)) + + +def test_finite_set_intersection(): + # The following should not produce recursion errors + # Note: some of these are not completely correct. See + # https://github.com/sympy/sympy/issues/16342. + assert Intersection(FiniteSet(-oo, x), FiniteSet(x)) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(0, x)]) == FiniteSet(x) + + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(x)]) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(2, 3, x, y), FiniteSet(1, 2, x)]) == \ + Intersection._handle_finite_sets([FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)]) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, x, y)) + + assert FiniteSet(1+x-y) & FiniteSet(1) == \ + FiniteSet(1) & FiniteSet(1+x-y) == \ + Intersection(FiniteSet(1+x-y), FiniteSet(1), evaluate=False) + + assert FiniteSet(1) & FiniteSet(x) == FiniteSet(x) & FiniteSet(1) == \ + Intersection(FiniteSet(1), FiniteSet(x), evaluate=False) + + assert FiniteSet({x}) & FiniteSet({x, y}) == \ + Intersection(FiniteSet({x}), FiniteSet({x, y}), evaluate=False) + + +def test_union_intersection_constructor(): + # The actual exception does not matter here, so long as these fail + sets = [FiniteSet(1), FiniteSet(2)] + raises(Exception, lambda: Union(sets)) + raises(Exception, lambda: Intersection(sets)) + raises(Exception, lambda: Union(tuple(sets))) + raises(Exception, lambda: Intersection(tuple(sets))) + raises(Exception, lambda: Union(i for i in sets)) + raises(Exception, lambda: Intersection(i for i in sets)) + + # Python sets are treated the same as FiniteSet + # The union of a single set (of sets) is the set (of sets) itself + assert Union(set(sets)) == FiniteSet(*sets) + assert Intersection(set(sets)) == FiniteSet(*sets) + + assert Union({1}, {2}) == FiniteSet(1, 2) + assert Intersection({1, 2}, {2, 3}) == FiniteSet(2) + + +def test_Union_contains(): + assert zoo not in Union( + Interval.open(-oo, 0), Interval.open(0, oo)) + + +@XFAIL +def test_issue_16878b(): + # in intersection_sets for (ImageSet, Set) there is no code + # that handles the base_set of S.Reals like there is + # for Integers + assert imageset(x, (x, x), S.Reals).is_subset(S.Reals**2) is True + +def test_DisjointUnion(): + assert DisjointUnion(FiniteSet(1, 2, 3), FiniteSet(1, 2, 3), FiniteSet(1, 2, 3)).rewrite(Union) == (FiniteSet(1, 2, 3) * FiniteSet(0, 1, 2)) + assert DisjointUnion(Interval(1, 3), Interval(2, 4)).rewrite(Union) == Union(Interval(1, 3) * FiniteSet(0), Interval(2, 4) * FiniteSet(1)) + assert DisjointUnion(Interval(0, 5), Interval(0, 5)).rewrite(Union) == Union(Interval(0, 5) * FiniteSet(0), Interval(0, 5) * FiniteSet(1)) + assert DisjointUnion(Interval(-1, 2), S.EmptySet, S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(Interval(-1, 2)).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(S.EmptySet, Interval(-1, 2), S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(1) + assert DisjointUnion(Interval(-oo, oo)).rewrite(Union) == Interval(-oo, oo) * FiniteSet(0) + assert DisjointUnion(S.EmptySet).rewrite(Union) == S.EmptySet + assert DisjointUnion().rewrite(Union) == S.EmptySet + raises(TypeError, lambda: DisjointUnion(Symbol('n'))) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert DisjointUnion(FiniteSet(x), FiniteSet(y, z)).rewrite(Union) == (FiniteSet(x) * FiniteSet(0)) + (FiniteSet(y, z) * FiniteSet(1)) + +def test_DisjointUnion_is_empty(): + assert DisjointUnion(S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, FiniteSet(1, 2, 3)).is_empty is False + +def test_DisjointUnion_is_iterable(): + assert DisjointUnion(S.Integers, S.Naturals, S.Rationals).is_iterable is True + assert DisjointUnion(S.EmptySet, S.Reals).is_iterable is False + assert DisjointUnion(FiniteSet(1, 2, 3), S.EmptySet, FiniteSet(x, y)).is_iterable is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_iterable is False + +def test_DisjointUnion_contains(): + assert (0, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1, 2) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 0.5) not in DisjointUnion(FiniteSet(0.5)) + assert (0, 5) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (x, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (z, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 2) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (0.5, 0) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (0.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 0) not in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + +def test_DisjointUnion_iter(): + D = DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z)) + it = iter(D) + L1 = [(x, 1), (y, 1), (z, 1)] + L2 = [(3, 0), (5, 0), (7, 0), (9, 0)] + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + raises(StopIteration, lambda: next(it)) + + raises(ValueError, lambda: iter(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_DisjointUnion_len(): + assert len(DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z))) == 7 + assert len(DisjointUnion(S.EmptySet, S.EmptySet, FiniteSet(x, y, z), S.EmptySet)) == 3 + raises(ValueError, lambda: len(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_SetKind_ProductSet(): + p = ProductSet(FiniteSet(Matrix([1, 2])), FiniteSet(Matrix([1, 2]))) + mk = MatrixKind(NumberKind) + k = SetKind(TupleKind(mk, mk)) + assert p.kind is k + assert ProductSet(Interval(1, 2), FiniteSet(Matrix([1, 2]))).kind is SetKind(TupleKind(NumberKind, mk)) + +def test_SetKind_Interval(): + assert Interval(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_EmptySet_UniversalSet(): + assert S.UniversalSet.kind is SetKind(UndefinedKind) + assert EmptySet.kind is SetKind() + +def test_SetKind_FiniteSet(): + assert FiniteSet(1, Matrix([1, 2])).kind is SetKind(UndefinedKind) + assert FiniteSet(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_Unions(): + assert Union(FiniteSet(Matrix([1, 2])), Interval(1, 2)).kind is SetKind(UndefinedKind) + assert Union(Interval(1, 2), Interval(1, 7)).kind is SetKind(NumberKind) + +def test_SetKind_DisjointUnion(): + A = FiniteSet(1, 2, 3) + B = Interval(0, 5) + assert DisjointUnion(A, B).kind is SetKind(NumberKind) + +def test_SetKind_evaluate_False(): + U = lambda *args: Union(*args, evaluate=False) + assert U({1}, EmptySet).kind is SetKind(NumberKind) + assert U(Interval(1, 2), EmptySet).kind is SetKind(NumberKind) + assert U({1}, S.UniversalSet).kind is SetKind(UndefinedKind) + assert U(Interval(1, 2), Interval(4, 5), + FiniteSet(1)).kind is SetKind(NumberKind) + I = lambda *args: Intersection(*args, evaluate=False) + assert I({1}, S.UniversalSet).kind is SetKind(NumberKind) + assert I({1}, EmptySet).kind is SetKind() + C = lambda *args: Complement(*args, evaluate=False) + assert C(S.UniversalSet, {1, 2, 4, 5}).kind is SetKind(UndefinedKind) + assert C({1, 2, 3, 4, 5}, EmptySet).kind is SetKind(NumberKind) + assert C(EmptySet, {1, 2, 3, 4, 5}).kind is SetKind() + +def test_SetKind_ImageSet_Special(): + f = ImageSet(Lambda(n, n ** 2), Interval(1, 4)) + assert (f - FiniteSet(3)).kind is SetKind(NumberKind) + assert (f + Interval(16, 17)).kind is SetKind(NumberKind) + assert (f + FiniteSet(17)).kind is SetKind(NumberKind) + +def test_issue_20089(): + B = FiniteSet(FiniteSet(1, 2), FiniteSet(1)) + assert 1 not in B + assert 1.0 not in B + assert not Eq(1, FiniteSet(1, 2)) + assert FiniteSet(1) in B + A = FiniteSet(1, 2) + assert A in B + assert B.issubset(B) + assert not A.issubset(B) + assert 1 in A + C = FiniteSet(FiniteSet(1, 2), FiniteSet(1), 1, 2) + assert A.issubset(C) + assert B.issubset(C) + +def test_issue_19378(): + a = FiniteSet(1, 2) + b = ProductSet(a, a) + c = FiniteSet((1, 1), (1, 2), (2, 1), (2, 2)) + assert b.is_subset(c) is True + d = FiniteSet(1) + assert b.is_subset(d) is False + assert Eq(c, b).simplify() is S.true + assert Eq(a, c).simplify() is S.false + assert Eq({1}, {x}).simplify() == Eq({1}, {x}) + +def test_intersection_symbolic(): + n = Symbol('n') + # These should not throw an error + assert isinstance(Intersection(Range(n), Range(100)), Intersection) + assert isinstance(Intersection(Range(n), Interval(1, 100)), Intersection) + assert isinstance(Intersection(Range(100), Interval(1, n)), Intersection) + + +@XFAIL +def test_intersection_symbolic_failing(): + n = Symbol('n', integer=True, positive=True) + assert Intersection(Range(10, n), Range(4, 500, 5)) == Intersection( + Range(14, n), Range(14, 500, 5)) + assert Intersection(Interval(10, n), Range(4, 500, 5)) == Intersection( + Interval(14, n), Range(14, 500, 5)) + + +def test_issue_20379(): + #https://github.com/sympy/sympy/issues/20379 + x = pi - 3.14159265358979 + assert FiniteSet(x).evalf(2) == FiniteSet(Float('3.23108914886517e-15', 2)) + +def test_finiteset_simplify(): + S = FiniteSet(1, cos(1)**2 + sin(1)**2) + assert S.simplify() == {1} + +def test_issue_14336(): + #https://github.com/sympy/sympy/issues/14336 + U = S.Complexes + x = Symbol("x") + U -= U.intersect(Ne(x, 1).as_set()) + U -= U.intersect(S.true.as_set()) + +def test_issue_9855(): + #https://github.com/sympy/sympy/issues/9855 + x, y, z = symbols('x, y, z', real=True) + s1 = Interval(1, x) & Interval(y, 2) + s2 = Interval(1, 2) + assert s1.is_subset(s2) == None diff --git a/phivenv/Lib/site-packages/sympy/simplify/__init__.py b/phivenv/Lib/site-packages/sympy/simplify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0619d1c3ebbd6c6a7d663093c7ed2202114148af --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/__init__.py @@ -0,0 +1,60 @@ +"""The module helps converting SymPy expressions into shorter forms of them. + +for example: +the expression E**(pi*I) will be converted into -1 +the expression (x+x)**2 will be converted into 4*x**2 +""" +from .simplify import (simplify, hypersimp, hypersimilar, + logcombine, separatevars, posify, besselsimp, kroneckersimp, + signsimp, nsimplify) + +from .fu import FU, fu + +from .sqrtdenest import sqrtdenest + +from .cse_main import cse + +from .epathtools import epath, EPath + +from .hyperexpand import hyperexpand + +from .radsimp import collect, rcollect, radsimp, collect_const, fraction, numer, denom + +from .trigsimp import trigsimp, exptrigsimp + +from .powsimp import powsimp, powdenest + +from .combsimp import combsimp + +from .gammasimp import gammasimp + +from .ratsimp import ratsimp, ratsimpmodprime + +__all__ = [ + 'simplify', 'hypersimp', 'hypersimilar', 'logcombine', 'separatevars', + 'posify', 'besselsimp', 'kroneckersimp', 'signsimp', + 'nsimplify', + + 'FU', 'fu', + + 'sqrtdenest', + + 'cse', + + 'epath', 'EPath', + + 'hyperexpand', + + 'collect', 'rcollect', 'radsimp', 'collect_const', 'fraction', 'numer', + 'denom', + + 'trigsimp', 'exptrigsimp', + + 'powsimp', 'powdenest', + + 'combsimp', + + 'gammasimp', + + 'ratsimp', 'ratsimpmodprime', +] diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2cd471740d71a6972a29cf9595fb6eb7df56577 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/_cse_diff.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/_cse_diff.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..212451648055782137b1f9e941c58c6df2e075c1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/_cse_diff.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/combsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/combsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b5987262f11b89cac431f95d2f3bd201a15f413 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/combsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/cse_main.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/cse_main.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835ce35b35af56213fdbca1df039c988581e3456 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/cse_main.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c0c7cbb4744679c7bbaa917b208e6836d853ebc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/epathtools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/epathtools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e734a75d56735d0a98f5c42840108cafa6172f7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/epathtools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/fu.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/fu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f23982bd69be4e4183432baccb69255bb90f2805 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/fu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5860331263f10dc06ab415feeddc9a5bea26fb6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/hyperexpand.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/hyperexpand.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1d9caf163b3aa887c6a2e3ee00244f1b9daa114 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/hyperexpand.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95abd6eabc0a85f79b61d70e1645e0c3cd1dc2e5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/powsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/powsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca323844f84d3cb1f5f3c9790480ac814f498eb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/powsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/radsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/radsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b0a0c6286d1765474577bd60e8f0c2dc58949de Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/radsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..641fd8ecd43289cc14322d389db053a0069448ae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/simplify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/simplify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ea2747767f390af4222e606666581bf2ed9ae29 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/simplify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c917c7d52389b5ff90b2f145040ebc798f495df Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0a5dbb694202cb86a011f18fdaa4c049eb6e229 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc543948ac2016a6b78a7d3e2da13e8e53836aa2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/_cse_diff.py b/phivenv/Lib/site-packages/sympy/simplify/_cse_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..3496ad3b31a4f45312cac002429be40aa9aa0868 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/_cse_diff.py @@ -0,0 +1,291 @@ +"""Module for differentiation using CSE.""" + +from sympy import cse, Matrix, Derivative, MatrixBase +from sympy.utilities.iterables import iterable + + +def _remove_cse_from_derivative(replacements, reduced_expressions): + """ + This function is designed to postprocess the output of a common subexpression + elimination (CSE) operation. Specifically, it removes any CSE replacement + symbols from the arguments of ``Derivative`` terms in the expression. This + is necessary to ensure that the forward Jacobian function correctly handles + derivative terms. + + Parameters + ========== + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. + + reduced_expressions : list of SymPy expressions + The reduced expressions with all the replacements from the + replacements list above. + + Returns + ======= + + processed_replacements : list of (Symbol, expression) pairs + Processed replacement list, in the same format of the + ``replacements`` input list. + + processed_reduced : list of SymPy expressions + Processed reduced list, in the same format of the + ``reduced_expressions`` input list. + """ + + def traverse(node, repl_dict): + if isinstance(node, Derivative): + return replace_all(node, repl_dict) + if not node.args: + return node + new_args = [traverse(arg, repl_dict) for arg in node.args] + return node.func(*new_args) + + def replace_all(node, repl_dict): + result = node + while True: + free_symbols = result.free_symbols + symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict} + if not symbols_dict: + break + result = result.xreplace(symbols_dict) + return result + + repl_dict = dict(replacements) + processed_replacements = [ + (rep_sym, traverse(sub_exp, repl_dict)) + for rep_sym, sub_exp in replacements + ] + processed_reduced = [ + red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp]) + for red_exp in reduced_expressions + ] + + return processed_replacements, processed_reduced + + +def _forward_jacobian_cse(replacements, reduced_expr, wrt): + """ + Core function to compute the Jacobian of an input Matrix of expressions + through forward accumulation. Takes directly the output of a CSE operation + (replacements and reduced_expr), and an iterable of variables (wrt) with + respect to which to differentiate the reduced expression and returns the + reduced Jacobian matrix and the ``replacements`` list. + + The function also returns a list of precomputed free symbols for each + subexpression, which are useful in the substitution process. + + Parameters + ========== + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. + + reduced_expr : list of SymPy expressions + The reduced expressions with all the replacements from the + replacements list above. + + wrt : iterable + Iterable of expressions with respect to which to compute the + Jacobian matrix. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. Compared to the input replacement list, + the output one doesn't contain replacement symbols inside + ``Derivative``'s arguments. + + jacobian : list of SymPy expressions + The list only contains one element, which is the Jacobian matrix with + elements in reduced form (replacement symbols are present). + + precomputed_fs: list + List of sets, which store the free symbols present in each sub-expression. + Useful in the substitution process. + """ + + if not isinstance(reduced_expr[0], MatrixBase): + raise TypeError("``expr`` must be of matrix type") + + if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1): + raise TypeError("``expr`` must be a row or a column matrix") + + if not iterable(wrt): + raise TypeError("``wrt`` must be an iterable of variables") + + elif not isinstance(wrt, MatrixBase): + wrt = Matrix(wrt) + + if not (wrt.shape[0] == 1 or wrt.shape[1] == 1): + raise TypeError("``wrt`` must be a row or a column matrix") + + replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr) + + if replacements: + rep_sym, sub_expr = map(Matrix, zip(*replacements)) + else: + rep_sym, sub_expr = Matrix([]), Matrix([]) + + l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0]) + + f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt, + { + (i, j): diff_value + for i, r in enumerate(reduced_expr[0]) + for j, w in enumerate(wrt) + if (diff_value := r.diff(w)) != 0 + }, + ) + + if not replacements: + return [], [f1], [] + + f2 = Matrix.from_dok(l_red, l_sub, + { + (i, j): diff_value + for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]]) + for j, s in enumerate(rep_sym) + if s in fs and (diff_value := r.diff(s)) != 0 + }, + ) + + rep_sym_set = set(rep_sym) + precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr ] + + c_matrix = Matrix.from_dok(1, l_wrt, + {(0, j): diff_value for j, w in enumerate(wrt) + if (diff_value := sub_expr[0].diff(w)) != 0}) + + for i in range(1, l_sub): + + bi_matrix = Matrix.from_dok(1, i, + {(0, j): diff_value for j in range(i + 1) + if rep_sym[j] in precomputed_fs[i] + and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0}) + + ai_matrix = Matrix.from_dok(1, l_wrt, + {(0, j): diff_value for j, w in enumerate(wrt) + if (diff_value := sub_expr[i].diff(w)) != 0}) + + if bi_matrix._rep.nnz(): + ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix) + c_matrix = Matrix.vstack(c_matrix, ci_matrix) + else: + c_matrix = Matrix.vstack(c_matrix, ai_matrix) + + jacobian = f2.multiply(c_matrix).add(f1) + jacobian = [reduced_expr[0].__class__(jacobian)] + + return replacements, jacobian, precomputed_fs + + +def _forward_jacobian_norm_in_cse_out(expr, wrt): + """ + Function to compute the Jacobian of an input Matrix of expressions through + forward accumulation. Takes a sympy Matrix of expressions (expr) as input + and an iterable of variables (wrt) with respect to which to compute the + Jacobian matrix. The matrix is returned in reduced form (containing + replacement symbols) along with the ``replacements`` list. + + The function also returns a list of precomputed free symbols for each + subexpression, which are useful in the substitution process. + + Parameters + ========== + + expr : Matrix + The vector to be differentiated. + + wrt : iterable + The vector with respect to which to perform the differentiation. + Can be a matrix or an iterable of variables. + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + Replacement symbols and relative common subexpressions that have been + replaced during a CSE operation. The output replacement list doesn't + contain replacement symbols inside ``Derivative``'s arguments. + + jacobian : list of SymPy expressions + The list only contains one element, which is the Jacobian matrix with + elements in reduced form (replacement symbols are present). + + precomputed_fs: list + List of sets, which store the free symbols present in each + sub-expression. Useful in the substitution process. + """ + + replacements, reduced_expr = cse(expr) + replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) + + return replacements, jacobian, precomputed_fs + + +def _forward_jacobian(expr, wrt): + """ + Function to compute the Jacobian of an input Matrix of expressions through + forward accumulation. Takes a sympy Matrix of expressions (expr) as input + and an iterable of variables (wrt) with respect to which to compute the + Jacobian matrix. + + Explanation + =========== + + Expressions often contain repeated subexpressions. Using a tree structure, + these subexpressions are duplicated and differentiated multiple times, + leading to inefficiency. + + Instead, if a data structure called a directed acyclic graph (DAG) is used + then each of these repeated subexpressions will only exist a single time. + This function uses a combination of representing the expression as a DAG and + a forward accumulation algorithm (repeated application of the chain rule + symbolically) to more efficiently calculate the Jacobian matrix of a target + expression ``expr`` with respect to an expression or set of expressions + ``wrt``. + + Note that this function is intended to improve performance when + differentiating large expressions that contain many common subexpressions. + For small and simple expressions it is likely less performant than using + SymPy's standard differentiation functions and methods. + + Parameters + ========== + + expr : Matrix + The vector to be differentiated. + + wrt : iterable + The vector with respect to which to do the differentiation. + Can be a matrix or an iterable of variables. + + See Also + ======== + + Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph + """ + + replacements, reduced_expr = cse(expr) + + if replacements: + rep_sym, _ = map(Matrix, zip(*replacements)) + else: + rep_sym = Matrix([]) + + replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) + + if not replacements: return jacobian[0] + + sub_rep = dict(replacements) + for i, ik in enumerate(precomputed_fs): + sub_dict = {j: sub_rep[j] for j in ik} + sub_rep[rep_sym[i]] = sub_rep[rep_sym[i]].xreplace(sub_dict) + + return jacobian[0].xreplace(sub_rep) diff --git a/phivenv/Lib/site-packages/sympy/simplify/combsimp.py b/phivenv/Lib/site-packages/sympy/simplify/combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0b3cefcba11b4b7759b7d3ec3c2d4415cfd849 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/combsimp.py @@ -0,0 +1,114 @@ +from sympy.core import Mul +from sympy.core.function import count_ops +from sympy.core.traversal import preorder_traversal, bottom_up +from sympy.functions.combinatorial.factorials import binomial, factorial +from sympy.functions import gamma +from sympy.simplify.gammasimp import gammasimp, _gammasimp + +from sympy.utilities.timeutils import timethis + + +@timethis('combsimp') +def combsimp(expr): + r""" + Simplify combinatorial expressions. + + Explanation + =========== + + This function takes as input an expression containing factorials, + binomials, Pochhammer symbol and other "combinatorial" functions, + and tries to minimize the number of those functions and reduce + the size of their arguments. + + The algorithm works by rewriting all combinatorial functions as + gamma functions and applying gammasimp() except simplification + steps that may make an integer argument non-integer. See docstring + of gammasimp for more information. + + Then it rewrites expression in terms of factorials and binomials by + rewriting gammas as factorials and converting (a+b)!/a!b! into + binomials. + + If expression has gamma functions or combinatorial functions + with non-integer argument, it is automatically passed to gammasimp. + + Examples + ======== + + >>> from sympy.simplify import combsimp + >>> from sympy import factorial, binomial, symbols + >>> n, k = symbols('n k', integer = True) + + >>> combsimp(factorial(n)/factorial(n - 3)) + n*(n - 2)*(n - 1) + >>> combsimp(binomial(n+1, k+1)/binomial(n, k)) + (n + 1)/(k + 1) + + """ + + expr = expr.rewrite(gamma, piecewise=False) + if any(isinstance(node, gamma) and not node.args[0].is_integer + for node in preorder_traversal(expr)): + return gammasimp(expr) + + expr = _gammasimp(expr, as_comb = True) + expr = _gamma_as_comb(expr) + return expr + + +def _gamma_as_comb(expr): + """ + Helper function for combsimp. + + Rewrites expression in terms of factorials and binomials + """ + + expr = expr.rewrite(factorial) + + def f(rv): + if not rv.is_Mul: + return rv + rvd = rv.as_powers_dict() + nd_fact_args = [[], []] # numerator, denominator + + for k in rvd: + if isinstance(k, factorial) and rvd[k].is_Integer: + if rvd[k].is_positive: + nd_fact_args[0].extend([k.args[0]]*rvd[k]) + else: + nd_fact_args[1].extend([k.args[0]]*-rvd[k]) + rvd[k] = 0 + if not nd_fact_args[0] or not nd_fact_args[1]: + return rv + + hit = False + for m in range(2): + i = 0 + while i < len(nd_fact_args[m]): + ai = nd_fact_args[m][i] + for j in range(i + 1, len(nd_fact_args[m])): + aj = nd_fact_args[m][j] + + sum = ai + aj + if sum in nd_fact_args[1 - m]: + hit = True + + nd_fact_args[1 - m].remove(sum) + del nd_fact_args[m][j] + del nd_fact_args[m][i] + + rvd[binomial(sum, ai if count_ops(ai) < + count_ops(aj) else aj)] += ( + -1 if m == 0 else 1) + break + else: + i += 1 + + if hit: + return Mul(*([k**rvd[k] for k in rvd] + [factorial(k) + for k in nd_fact_args[0]]))/Mul(*[factorial(k) + for k in nd_fact_args[1]]) + return rv + + return bottom_up(expr, f) diff --git a/phivenv/Lib/site-packages/sympy/simplify/cse_main.py b/phivenv/Lib/site-packages/sympy/simplify/cse_main.py new file mode 100644 index 0000000000000000000000000000000000000000..bcd1b2e50adae8c3d3400d6c489e63a44ae1e59b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/cse_main.py @@ -0,0 +1,945 @@ +""" Tools for doing common subexpression elimination. +""" +from collections import defaultdict + +from sympy.core import Basic, Mul, Add, Pow, sympify +from sympy.core.containers import Tuple, OrderedSet +from sympy.core.exprtools import factor_terms +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import symbols, Symbol +from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix, + SparseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul, + MatAdd, MatPow, Inverse) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.polys.rootoftools import RootOf +from sympy.utilities.iterables import numbered_symbols, sift, \ + topological_sort, iterable + +from . import cse_opts + +# (preprocessor, postprocessor) pairs which are commonly useful. They should +# each take a SymPy expression and return a possibly transformed expression. +# When used in the function ``cse()``, the target expressions will be transformed +# by each of the preprocessor functions in order. After the common +# subexpressions are eliminated, each resulting expression will have the +# postprocessor functions transform them in *reverse* order in order to undo the +# transformation if necessary. This allows the algorithm to operate on +# a representation of the expressions that allows for more optimization +# opportunities. +# ``None`` can be used to specify no transformation for either the preprocessor or +# postprocessor. + + +basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post), + (factor_terms, None)] + +# sometimes we want the output in a different format; non-trivial +# transformations can be put here for users +# =============================================================== + + +def reps_toposort(r): + """Sort replacements ``r`` so (k1, v1) appears before (k2, v2) + if k2 is in v1's free symbols. This orders items in the + way that cse returns its results (hence, in order to use the + replacements in a substitution option it would make sense + to reverse the order). + + Examples + ======== + + >>> from sympy.simplify.cse_main import reps_toposort + >>> from sympy.abc import x, y + >>> from sympy import Eq + >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]): + ... print(Eq(l, r)) + ... + Eq(y, 2) + Eq(x, y + 1) + + """ + r = sympify(r) + E = [] + for c1, (k1, v1) in enumerate(r): + for c2, (k2, v2) in enumerate(r): + if k1 in v2.free_symbols: + E.append((c1, c2)) + return [r[i] for i in topological_sort((range(len(r)), E))] + + +def cse_separate(r, e): + """Move expressions that are in the form (symbol, expr) out of the + expressions and sort them into the replacements using the reps_toposort. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse_separate + >>> from sympy.abc import x, y, z + >>> from sympy import cos, exp, cse, Eq, symbols + >>> x0, x1 = symbols('x:2') + >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [ + ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)], + ... [x1 + exp(x1/x0) + cos(x0), z - 2]], + ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)], + ... [x0 + exp(x0/x1) + cos(x1), z - 2]]] + ... + True + """ + d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol) + r = r + [w.args for w in d[True]] + e = d[False] + return [reps_toposort(r), e] + + +def cse_release_variables(r, e): + """ + Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is + either an expression or None. The value of None is used when a + symbol is no longer needed for subsequent expressions. + + Use of such output can reduce the memory footprint of lambdified + expressions that contain large, repeated subexpressions. + + Examples + ======== + + >>> from sympy import cse + >>> from sympy.simplify.cse_main import cse_release_variables + >>> from sympy.abc import x, y + >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)] + >>> defs, rvs = cse_release_variables(*cse(eqs)) + >>> for i in defs: + ... print(i) + ... + (x0, x + y) + (x1, (x0 - 1)**2) + (x2, 2*x + 1) + (_3, x0/x2 + x1) + (_4, x2**x0) + (x2, None) + (_0, x1) + (x1, None) + (_2, x0) + (x0, None) + (_1, x) + >>> print(rvs) + (_0, _1, _2, _3, _4) + """ + if not r: + return r, e + + s, p = zip(*r) + esyms = symbols('_:%d' % len(e)) + syms = list(esyms) + s = list(s) + in_use = set(s) + p = list(p) + # sort e so those with most sub-expressions appear first + e = [(e[i], syms[i]) for i in range(len(e))] + e, syms = zip(*sorted(e, + key=lambda x: -sum(p[s.index(i)].count_ops() + for i in x[0].free_symbols & in_use))) + syms = list(syms) + p += e + rv = [] + i = len(p) - 1 + while i >= 0: + _p = p.pop() + c = in_use & _p.free_symbols + if c: # sorting for canonical results + rv.extend([(s, None) for s in sorted(c, key=str)]) + if i >= len(r): + rv.append((syms.pop(), _p)) + else: + rv.append((s[i], _p)) + in_use -= c + i -= 1 + rv.reverse() + return rv, esyms + + +# ====end of cse postprocess idioms=========================== + + +def preprocess_for_cse(expr, optimizations): + """ Preprocess an expression to optimize for common subexpression + elimination. + + Parameters + ========== + + expr : SymPy expression + The target expression to optimize. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs. + + Returns + ======= + + expr : SymPy expression + The transformed expression. + """ + for pre, post in optimizations: + if pre is not None: + expr = pre(expr) + return expr + + +def postprocess_for_cse(expr, optimizations): + """Postprocess an expression after common subexpression elimination to + return the expression to canonical SymPy form. + + Parameters + ========== + + expr : SymPy expression + The target expression to transform. + optimizations : list of (callable, callable) pairs, optional + The (preprocessor, postprocessor) pairs. The postprocessors will be + applied in reversed order to undo the effects of the preprocessors + correctly. + + Returns + ======= + + expr : SymPy expression + The transformed expression. + """ + for pre, post in reversed(optimizations): + if post is not None: + expr = post(expr) + return expr + + +class FuncArgTracker: + """ + A class which manages a mapping from functions to arguments and an inverse + mapping from arguments to functions. + """ + + def __init__(self, funcs): + # To minimize the number of symbolic comparisons, all function arguments + # get assigned a value number. + self.value_numbers = {} + self.value_number_to_value = [] + + # Both of these maps use integer indices for arguments / functions. + self.arg_to_funcset = [] + self.func_to_argset = [] + + for func_i, func in enumerate(funcs): + func_argset = OrderedSet() + + for func_arg in func.args: + arg_number = self.get_or_add_value_number(func_arg) + func_argset.add(arg_number) + self.arg_to_funcset[arg_number].add(func_i) + + self.func_to_argset.append(func_argset) + + def get_args_in_value_order(self, argset): + """ + Return the list of arguments in sorted order according to their value + numbers. + """ + return [self.value_number_to_value[argn] for argn in sorted(argset)] + + def get_or_add_value_number(self, value): + """ + Return the value number for the given argument. + """ + nvalues = len(self.value_numbers) + value_number = self.value_numbers.setdefault(value, nvalues) + if value_number == nvalues: + self.value_number_to_value.append(value) + self.arg_to_funcset.append(OrderedSet()) + return value_number + + def stop_arg_tracking(self, func_i): + """ + Remove the function func_i from the argument to function mapping. + """ + for arg in self.func_to_argset[func_i]: + self.arg_to_funcset[arg].remove(func_i) + + + def get_common_arg_candidates(self, argset, min_func_i=0): + """Return a dict whose keys are function numbers. The entries of the dict are + the number of arguments said function has in common with + ``argset``. Entries have at least 2 items in common. All keys have + value at least ``min_func_i``. + """ + count_map = defaultdict(lambda: 0) + if not argset: + return count_map + + funcsets = [self.arg_to_funcset[arg] for arg in argset] + # As an optimization below, we handle the largest funcset separately from + # the others. + largest_funcset = max(funcsets, key=len) + + for funcset in funcsets: + if largest_funcset is funcset: + continue + for func_i in funcset: + if func_i >= min_func_i: + count_map[func_i] += 1 + + # We pick the smaller of the two containers (count_map, largest_funcset) + # to iterate over to reduce the number of iterations needed. + (smaller_funcs_container, + larger_funcs_container) = sorted( + [largest_funcset, count_map], + key=len) + + for func_i in smaller_funcs_container: + # Not already in count_map? It can't possibly be in the output, so + # skip it. + if count_map[func_i] < 1: + continue + + if func_i in larger_funcs_container: + count_map[func_i] += 1 + + return {k: v for k, v in count_map.items() if v >= 2} + + def get_subset_candidates(self, argset, restrict_to_funcset=None): + """ + Return a set of functions each of which whose argument list contains + ``argset``, optionally filtered only to contain functions in + ``restrict_to_funcset``. + """ + iarg = iter(argset) + + indices = OrderedSet( + fi for fi in self.arg_to_funcset[next(iarg)]) + + if restrict_to_funcset is not None: + indices &= restrict_to_funcset + + for arg in iarg: + indices &= self.arg_to_funcset[arg] + + return indices + + def update_func_argset(self, func_i, new_argset): + """ + Update a function with a new set of arguments. + """ + new_args = OrderedSet(new_argset) + old_args = self.func_to_argset[func_i] + + for deleted_arg in old_args - new_args: + self.arg_to_funcset[deleted_arg].remove(func_i) + for added_arg in new_args - old_args: + self.arg_to_funcset[added_arg].add(func_i) + + self.func_to_argset[func_i].clear() + self.func_to_argset[func_i].update(new_args) + + +class Unevaluated: + + def __init__(self, func, args): + self.func = func + self.args = args + + def __str__(self): + return "Uneval<{}>({})".format( + self.func, ", ".join(str(a) for a in self.args)) + + def as_unevaluated_basic(self): + return self.func(*self.args, evaluate=False) + + @property + def free_symbols(self): + return set().union(*[a.free_symbols for a in self.args]) + + __repr__ = __str__ + + +def match_common_args(func_class, funcs, opt_subs): + """ + Recognize and extract common subexpressions of function arguments within a + set of function calls. For instance, for the following function calls:: + + x + z + y + sin(x + y) + + this will extract a common subexpression of `x + y`:: + + w = x + y + w + z + sin(w) + + The function we work with is assumed to be associative and commutative. + + Parameters + ========== + + func_class: class + The function class (e.g. Add, Mul) + funcs: list of functions + A list of function calls. + opt_subs: dict + A dictionary of substitutions which this function may update. + """ + + # Sort to ensure that whole-function subexpressions come before the items + # that use them. + funcs = sorted(funcs, key=lambda f: len(f.args)) + arg_tracker = FuncArgTracker(funcs) + + changed = OrderedSet() + + for i in range(len(funcs)): + common_arg_candidates_counts = arg_tracker.get_common_arg_candidates( + arg_tracker.func_to_argset[i], min_func_i=i + 1) + + # Sort the candidates in order of match size. + # This makes us try combining smaller matches first. + common_arg_candidates = OrderedSet(sorted( + common_arg_candidates_counts.keys(), + key=lambda k: (common_arg_candidates_counts[k], k))) + + while common_arg_candidates: + j = common_arg_candidates.pop(last=False) + + com_args = arg_tracker.func_to_argset[i].intersection( + arg_tracker.func_to_argset[j]) + + if len(com_args) <= 1: + # This may happen if a set of common arguments was already + # combined in a previous iteration. + continue + + # For all sets, replace the common symbols by the function + # over them, to allow recursive matches. + + diff_i = arg_tracker.func_to_argset[i].difference(com_args) + if diff_i: + # com_func needs to be unevaluated to allow for recursive matches. + com_func = Unevaluated( + func_class, arg_tracker.get_args_in_value_order(com_args)) + com_func_number = arg_tracker.get_or_add_value_number(com_func) + arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number])) + changed.add(i) + else: + # Treat the whole expression as a CSE. + # + # The reason this needs to be done is somewhat subtle. Within + # tree_cse(), to_eliminate only contains expressions that are + # seen more than once. The problem is unevaluated expressions + # do not compare equal to the evaluated equivalent. So + # tree_cse() won't mark funcs[i] as a CSE if we use an + # unevaluated version. + com_func_number = arg_tracker.get_or_add_value_number(funcs[i]) + + diff_j = arg_tracker.func_to_argset[j].difference(com_args) + arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number])) + changed.add(j) + + for k in arg_tracker.get_subset_candidates( + com_args, common_arg_candidates): + diff_k = arg_tracker.func_to_argset[k].difference(com_args) + arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number])) + changed.add(k) + + if i in changed: + opt_subs[funcs[i]] = Unevaluated(func_class, + arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i])) + + arg_tracker.stop_arg_tracking(i) + + +def opt_cse(exprs, order='canonical'): + """Find optimization opportunities in Adds, Muls, Pows and negative + coefficient Muls. + + Parameters + ========== + + exprs : list of SymPy expressions + The expressions to optimize. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + + Returns + ======= + + opt_subs : dictionary of expression substitutions + The expression substitutions which can be useful to optimize CSE. + + Examples + ======== + + >>> from sympy.simplify.cse_main import opt_cse + >>> from sympy.abc import x + >>> opt_subs = opt_cse([x**-2]) + >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0] + >>> print((k, v.as_unevaluated_basic())) + (x**(-2), 1/(x**2)) + """ + opt_subs = {} + + adds = OrderedSet() + muls = OrderedSet() + + seen_subexp = set() + collapsible_subexp = set() + + def _find_opts(expr): + + if not isinstance(expr, (Basic, Unevaluated)): + return + + if expr.is_Atom or expr.is_Order: + return + + if iterable(expr): + list(map(_find_opts, expr)) + return + + if expr in seen_subexp: + return expr + seen_subexp.add(expr) + + list(map(_find_opts, expr.args)) + + if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign(): + # XXX -expr does not always work rigorously for some expressions + # containing UnevaluatedExpr. + # https://github.com/sympy/sympy/issues/24818 + if isinstance(expr, Add): + neg_expr = Add(*(-i for i in expr.args)) + else: + neg_expr = -expr + + if not neg_expr.is_Atom: + opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr)) + seen_subexp.add(neg_expr) + expr = neg_expr + + if isinstance(expr, (Mul, MatMul)): + if len(expr.args) == 1: + collapsible_subexp.add(expr) + else: + muls.add(expr) + + elif isinstance(expr, (Add, MatAdd)): + if len(expr.args) == 1: + collapsible_subexp.add(expr) + else: + adds.add(expr) + + elif isinstance(expr, Inverse): + # Do not want to treat `Inverse` as a `MatPow` + pass + + elif isinstance(expr, (Pow, MatPow)): + base, exp = expr.base, expr.exp + if exp.could_extract_minus_sign(): + opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1)) + + for e in exprs: + if isinstance(e, (Basic, Unevaluated)): + _find_opts(e) + + # Handle collapsing of multinary operations with single arguments + edges = [(s, s.args[0]) for s in collapsible_subexp + if s.args[0] in collapsible_subexp] + for e in reversed(topological_sort((collapsible_subexp, edges))): + opt_subs[e] = opt_subs.get(e.args[0], e.args[0]) + + # split muls into commutative + commutative_muls = OrderedSet() + for m in muls: + c, nc = m.args_cnc(cset=False) + if c: + c_mul = m.func(*c) + if nc: + if c_mul == 1: + new_obj = m.func(*nc) + else: + if isinstance(m, MatMul): + new_obj = m.func(c_mul, *nc, evaluate=False) + else: + new_obj = m.func(c_mul, m.func(*nc), evaluate=False) + opt_subs[m] = new_obj + if len(c) > 1: + commutative_muls.add(c_mul) + + match_common_args(Add, adds, opt_subs) + match_common_args(Mul, commutative_muls, opt_subs) + + return opt_subs + + +def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()): + """Perform raw CSE on expression tree, taking opt_subs into account. + + Parameters + ========== + + exprs : list of SymPy expressions + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. + opt_subs : dictionary of expression substitutions + The expressions to be substituted before any CSE action is performed. + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. For large + expressions where speed is a concern, use the setting order='none'. + ignore : iterable of Symbols + Substitutions containing any Symbol from ``ignore`` will be ignored. + """ + if opt_subs is None: + opt_subs = {} + + ## Find repeated sub-expressions + + to_eliminate = set() + + seen_subexp = set() + excluded_symbols = set() + + def _find_repeated(expr): + if not isinstance(expr, (Basic, Unevaluated)): + return + + if isinstance(expr, RootOf): + return + + if isinstance(expr, Basic) and ( + expr.is_Atom or + expr.is_Order or + isinstance(expr, (MatrixSymbol, MatrixElement))): + if expr.is_Symbol: + excluded_symbols.add(expr.name) + return + + if iterable(expr): + args = expr + + else: + if expr in seen_subexp: + for ign in ignore: + if ign in expr.free_symbols: + break + else: + to_eliminate.add(expr) + return + + seen_subexp.add(expr) + + if expr in opt_subs: + expr = opt_subs[expr] + + args = expr.args + + list(map(_find_repeated, args)) + + for e in exprs: + if isinstance(e, Basic): + _find_repeated(e) + + ## Rebuild tree + + # Remove symbols from the generator that conflict with names in the expressions. + symbols = (_ for _ in symbols if _.name not in excluded_symbols) + + replacements = [] + + subs = {} + + def _rebuild(expr): + if not isinstance(expr, (Basic, Unevaluated)): + return expr + + if not expr.args: + return expr + + if iterable(expr): + new_args = [_rebuild(arg) for arg in expr.args] + return expr.func(*new_args) + + if expr in subs: + return subs[expr] + + orig_expr = expr + if expr in opt_subs: + expr = opt_subs[expr] + + # If enabled, parse Muls and Adds arguments by order to ensure + # replacement order independent from hashes + if order != 'none': + if isinstance(expr, (Mul, MatMul)): + c, nc = expr.args_cnc() + if c == [1]: + args = nc + else: + args = list(ordered(c)) + nc + elif isinstance(expr, (Add, MatAdd)): + args = list(ordered(expr.args)) + else: + args = expr.args + else: + args = expr.args + + new_args = list(map(_rebuild, args)) + if isinstance(expr, Unevaluated) or new_args != args: + new_expr = expr.func(*new_args) + else: + new_expr = expr + + if orig_expr in to_eliminate: + try: + sym = next(symbols) + except StopIteration: + raise ValueError("Symbols iterator ran out of symbols.") + + if isinstance(orig_expr, MatrixExpr): + sym = MatrixSymbol(sym.name, orig_expr.rows, + orig_expr.cols) + + subs[orig_expr] = sym + replacements.append((sym, new_expr)) + return sym + + else: + return new_expr + + reduced_exprs = [] + for e in exprs: + if isinstance(e, Basic): + reduced_e = _rebuild(e) + else: + reduced_e = e + reduced_exprs.append(reduced_e) + return replacements, reduced_exprs + + +def cse(exprs, symbols=None, optimizations=None, postprocess=None, + order='canonical', ignore=(), list=True): + """ Perform common subexpression elimination on an expression. + + Parameters + ========== + + exprs : list of SymPy expressions, or a single SymPy expression + The expressions to reduce. + symbols : infinite iterator yielding unique Symbols + The symbols used to label the common subexpressions which are pulled + out. The ``numbered_symbols`` generator is useful. The default is a + stream of symbols of the form "x0", "x1", etc. This must be an + infinite iterator. + optimizations : list of (callable, callable) pairs + The (preprocessor, postprocessor) pairs of external optimization + functions. Optionally 'basic' can be passed for a set of predefined + basic optimizations. Such 'basic' optimizations were used by default + in old implementation, however they can be really slow on larger + expressions. Now, no pre or post optimizations are made by default. + postprocess : a function which accepts the two return values of cse and + returns the desired form of output from cse, e.g. if you want the + replacements reversed the function might be the following lambda: + lambda r, e: return reversed(r), e + order : string, 'none' or 'canonical' + The order by which Mul and Add arguments are processed. If set to + 'canonical', arguments will be canonically ordered. If set to 'none', + ordering will be faster but dependent on expressions hashes, thus + machine dependent and variable. For large expressions where speed is a + concern, use the setting order='none'. + ignore : iterable of Symbols + Substitutions containing any Symbol from ``ignore`` will be ignored. + list : bool, (default True) + Returns expression in list or else with same type as input (when False). + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of SymPy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy import cse, SparseMatrix + >>> from sympy.abc import x, y, z, w + >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3) + ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3]) + + + List of expressions with recursive substitutions: + + >>> m = SparseMatrix([x + y, x + y + z]) + >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m]) + ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([ + [x0], + [x1]])]) + + Note: the type and mutability of input matrices is retained. + + >>> isinstance(_[1][-1], SparseMatrix) + True + + The user may disallow substitutions containing certain symbols: + + >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,)) + ([(x0, x + 1)], [x0*y**2, 3*x0*y**2]) + + The default return value for the reduced expression(s) is a list, even if there is only + one expression. The `list` flag preserves the type of the input in the output: + + >>> cse(x) + ([], [x]) + >>> cse(x, list=False) + ([], x) + """ + if not list: + return _cse_homogeneous(exprs, + symbols=symbols, optimizations=optimizations, + postprocess=postprocess, order=order, ignore=ignore) + + if isinstance(exprs, (int, float)): + exprs = sympify(exprs) + + # Handle the case if just one expression was passed. + if isinstance(exprs, (Basic, MatrixBase)): + exprs = [exprs] + + copy = exprs + temp = [] + for e in exprs: + if isinstance(e, (Matrix, ImmutableMatrix)): + temp.append(Tuple(*e.flat())) + elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)): + temp.append(Tuple(*e.todok().items())) + else: + temp.append(e) + exprs = temp + del temp + + if optimizations is None: + optimizations = [] + elif optimizations == 'basic': + optimizations = basic_optimizations + + # Preprocess the expressions to give us better optimization opportunities. + reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs] + + if symbols is None: + symbols = numbered_symbols(cls=Symbol) + else: + # In case we get passed an iterable with an __iter__ method instead of + # an actual iterator. + symbols = iter(symbols) + + # Find other optimization opportunities. + opt_subs = opt_cse(reduced_exprs, order) + + # Main CSE algorithm. + replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs, + order, ignore) + + # Postprocess the expressions to return the expressions to canonical form. + exprs = copy + replacements = [(sym, postprocess_for_cse(subtree, optimizations)) + for sym, subtree in replacements] + reduced_exprs = [postprocess_for_cse(e, optimizations) + for e in reduced_exprs] + + # Get the matrices back + for i, e in enumerate(exprs): + if isinstance(e, (Matrix, ImmutableMatrix)): + reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i]) + if isinstance(e, ImmutableMatrix): + reduced_exprs[i] = reduced_exprs[i].as_immutable() + elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)): + m = SparseMatrix(e.rows, e.cols, {}) + for k, v in reduced_exprs[i]: + m[k] = v + if isinstance(e, ImmutableSparseMatrix): + m = m.as_immutable() + reduced_exprs[i] = m + + if postprocess is None: + return replacements, reduced_exprs + + return postprocess(replacements, reduced_exprs) + + +def _cse_homogeneous(exprs, **kwargs): + """ + Same as ``cse`` but the ``reduced_exprs`` are returned + with the same type as ``exprs`` or a sympified version of the same. + + Parameters + ========== + + exprs : an Expr, iterable of Expr or dictionary with Expr values + the expressions in which repeated subexpressions will be identified + kwargs : additional arguments for the ``cse`` function + + Returns + ======= + + replacements : list of (Symbol, expression) pairs + All of the common subexpressions that were replaced. Subexpressions + earlier in this list might show up in subexpressions later in this + list. + reduced_exprs : list of SymPy expressions + The reduced expressions with all of the replacements above. + + Examples + ======== + + >>> from sympy.simplify.cse_main import cse + >>> from sympy import cos, Tuple, Matrix + >>> from sympy.abc import x + >>> output = lambda x: type(cse(x, list=False)[1]) + >>> output(1) + + >>> output('cos(x)') + + >>> output(cos(x)) + cos + >>> output(Tuple(1, x)) + + >>> output(Matrix([[1,0], [0,1]])) + + >>> output([1, x]) + + >>> output((1, x)) + + >>> output({1, x}) + + """ + if isinstance(exprs, str): + replacements, reduced_exprs = _cse_homogeneous( + sympify(exprs), **kwargs) + return replacements, repr(reduced_exprs) + if isinstance(exprs, (list, tuple, set)): + replacements, reduced_exprs = cse(exprs, **kwargs) + return replacements, type(exprs)(reduced_exprs) + if isinstance(exprs, dict): + keys = list(exprs.keys()) # In order to guarantee the order of the elements. + replacements, values = cse([exprs[k] for k in keys], **kwargs) + reduced_exprs = dict(zip(keys, values)) + return replacements, reduced_exprs + + try: + replacements, (reduced_exprs,) = cse(exprs, **kwargs) + except TypeError: # For example 'mpf' objects + return [], exprs + else: + return replacements, reduced_exprs diff --git a/phivenv/Lib/site-packages/sympy/simplify/cse_opts.py b/phivenv/Lib/site-packages/sympy/simplify/cse_opts.py new file mode 100644 index 0000000000000000000000000000000000000000..36a59857411de740ae47423442af88b118a3395d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/cse_opts.py @@ -0,0 +1,52 @@ +""" Optimizations of the expression tree representation for better CSE +opportunities. +""" +from sympy.core import Add, Basic, Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.traversal import preorder_traversal + + +def sub_pre(e): + """ Replace y - x with -(x - y) if -1 can be extracted from y - x. + """ + # replacing Add, A, from which -1 can be extracted with -1*-A + adds = [a for a in e.atoms(Add) if a.could_extract_minus_sign()] + reps = {} + ignore = set() + for a in adds: + na = -a + if na.is_Mul: # e.g. MatExpr + ignore.add(a) + continue + reps[a] = Mul._from_args([S.NegativeOne, na]) + + e = e.xreplace(reps) + + # repeat again for persisting Adds but mark these with a leading 1, -1 + # e.g. y - x -> 1*-1*(x - y) + if isinstance(e, Basic): + negs = {} + for a in sorted(e.atoms(Add), key=default_sort_key): + if a in ignore: + continue + if a in reps: + negs[a] = reps[a] + elif a.could_extract_minus_sign(): + negs[a] = Mul._from_args([S.One, S.NegativeOne, -a]) + e = e.xreplace(negs) + return e + + +def sub_post(e): + """ Replace 1*-1*x with -x. + """ + replacements = [] + for node in preorder_traversal(e): + if isinstance(node, Mul) and \ + node.args[0] is S.One and node.args[1] is S.NegativeOne: + replacements.append((node, -Mul._from_args(node.args[2:]))) + for node, replacement in replacements: + e = e.xreplace({node: replacement}) + + return e diff --git a/phivenv/Lib/site-packages/sympy/simplify/epathtools.py b/phivenv/Lib/site-packages/sympy/simplify/epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..7be983ada63fd39d7d467acf9afd62b3a41a2d85 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/epathtools.py @@ -0,0 +1,352 @@ +"""Tools for manipulation of expressions using paths. """ + +from sympy.core import Basic + + +class EPath: + r""" + Manipulate expressions using paths. + + EPath grammar in EBNF notation:: + + literal ::= /[A-Za-z_][A-Za-z_0-9]*/ + number ::= /-?\d+/ + type ::= literal + attribute ::= literal "?" + all ::= "*" + slice ::= "[" number? (":" number? (":" number?)?)? "]" + range ::= all | slice + query ::= (type | attribute) ("|" (type | attribute))* + selector ::= range | query range? + path ::= "/" selector ("/" selector)* + + See the docstring of the epath() function. + + """ + + __slots__ = ("_path", "_epath") + + def __new__(cls, path): + """Construct new EPath. """ + if isinstance(path, EPath): + return path + + if not path: + raise ValueError("empty EPath") + + _path = path + + if path[0] == '/': + path = path[1:] + else: + raise NotImplementedError("non-root EPath") + + epath = [] + + for selector in path.split('/'): + selector = selector.strip() + + if not selector: + raise ValueError("empty selector") + + index = 0 + + for c in selector: + if c.isalnum() or c in ('_', '|', '?'): + index += 1 + else: + break + + attrs = [] + types = [] + + if index: + elements = selector[:index] + selector = selector[index:] + + for element in elements.split('|'): + element = element.strip() + + if not element: + raise ValueError("empty element") + + if element.endswith('?'): + attrs.append(element[:-1]) + else: + types.append(element) + + span = None + + if selector == '*': + pass + else: + if selector.startswith('['): + try: + i = selector.index(']') + except ValueError: + raise ValueError("expected ']', got EOL") + + _span, span = selector[1:i], [] + + if ':' not in _span: + span = int(_span) + else: + for elt in _span.split(':', 3): + if not elt: + span.append(None) + else: + span.append(int(elt)) + + span = slice(*span) + + selector = selector[i + 1:] + + if selector: + raise ValueError("trailing characters in selector") + + epath.append((attrs, types, span)) + + obj = object.__new__(cls) + + obj._path = _path + obj._epath = epath + + return obj + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._path) + + def _get_ordered_args(self, expr): + """Sort ``expr.args`` using printing order. """ + if expr.is_Add: + return expr.as_ordered_terms() + elif expr.is_Mul: + return expr.as_ordered_factors() + else: + return expr.args + + def _hasattrs(self, expr, attrs) -> bool: + """Check if ``expr`` has any of ``attrs``. """ + return all(hasattr(expr, attr) for attr in attrs) + + def _hastypes(self, expr, types): + """Check if ``expr`` is any of ``types``. """ + _types = [ cls.__name__ for cls in expr.__class__.mro() ] + return bool(set(_types).intersection(types)) + + def _has(self, expr, attrs, types): + """Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """ + if not (attrs or types): + return True + + if attrs and self._hasattrs(expr, attrs): + return True + + if types and self._hastypes(expr, types): + return True + + return False + + def apply(self, expr, func, args=None, kwargs=None): + """ + Modify parts of an expression selected by a path. + + Examples + ======== + + >>> from sympy.simplify.epathtools import EPath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = EPath("/*/[0]/Symbol") + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> path.apply(expr, lambda expr: expr**2) + [((x**2, 1), 2), ((3, y**2), z)] + + >>> path = EPath("/*/*/Symbol") + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> path.apply(expr, lambda expr: 2*expr) + t + sin(2*x + 1) + cos(2*x + 2*y + E) + + """ + def _apply(path, expr, func): + if not path: + return func(expr) + else: + selector, path = path[0], path[1:] + attrs, types, span = selector + + if isinstance(expr, Basic): + if not expr.is_Atom: + args, basic = self._get_ordered_args(expr), True + else: + return expr + elif hasattr(expr, '__iter__'): + args, basic = expr, False + else: + return expr + + args = list(args) + + if span is not None: + if isinstance(span, slice): + indices = range(*span.indices(len(args))) + else: + indices = [span] + else: + indices = range(len(args)) + + for i in indices: + try: + arg = args[i] + except IndexError: + continue + + if self._has(arg, attrs, types): + args[i] = _apply(path, arg, func) + + if basic: + return expr.func(*args) + else: + return expr.__class__(args) + + _args, _kwargs = args or (), kwargs or {} + _func = lambda expr: func(expr, *_args, **_kwargs) + + return _apply(self._epath, expr, _func) + + def select(self, expr): + """ + Retrieve parts of an expression selected by a path. + + Examples + ======== + + >>> from sympy.simplify.epathtools import EPath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = EPath("/*/[0]/Symbol") + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> path.select(expr) + [x, y] + + >>> path = EPath("/*/*/Symbol") + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> path.select(expr) + [x, x, y] + + """ + result = [] + + def _select(path, expr): + if not path: + result.append(expr) + else: + selector, path = path[0], path[1:] + attrs, types, span = selector + + if isinstance(expr, Basic): + args = self._get_ordered_args(expr) + elif hasattr(expr, '__iter__'): + args = expr + else: + return + + if span is not None: + if isinstance(span, slice): + args = args[span] + else: + try: + args = [args[span]] + except IndexError: + return + + for arg in args: + if self._has(arg, attrs, types): + _select(path, arg) + + _select(self._epath, expr) + return result + + +def epath(path, expr=None, func=None, args=None, kwargs=None): + r""" + Manipulate parts of an expression selected by a path. + + Explanation + =========== + + This function allows to manipulate large nested expressions in single + line of code, utilizing techniques to those applied in XML processing + standards (e.g. XPath). + + If ``func`` is ``None``, :func:`epath` retrieves elements selected by + the ``path``. Otherwise it applies ``func`` to each matching element. + + Note that it is more efficient to create an EPath object and use the select + and apply methods of that object, since this will compile the path string + only once. This function should only be used as a convenient shortcut for + interactive use. + + This is the supported syntax: + + * select all: ``/*`` + Equivalent of ``for arg in args:``. + * select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]`` + Supports standard Python's slice syntax. + * select by type: ``/list`` or ``/list|tuple`` + Emulates ``isinstance()``. + * select by attribute: ``/__iter__?`` + Emulates ``hasattr()``. + + Parameters + ========== + + path : str | EPath + A path as a string or a compiled EPath. + expr : Basic | iterable + An expression or a container of expressions. + func : callable (optional) + A callable that will be applied to matching parts. + args : tuple (optional) + Additional positional arguments to ``func``. + kwargs : dict (optional) + Additional keyword arguments to ``func``. + + Examples + ======== + + >>> from sympy.simplify.epathtools import epath + >>> from sympy import sin, cos, E + >>> from sympy.abc import x, y, z, t + + >>> path = "/*/[0]/Symbol" + >>> expr = [((x, 1), 2), ((3, y), z)] + + >>> epath(path, expr) + [x, y] + >>> epath(path, expr, lambda expr: expr**2) + [((x**2, 1), 2), ((3, y**2), z)] + + >>> path = "/*/*/Symbol" + >>> expr = t + sin(x + 1) + cos(x + y + E) + + >>> epath(path, expr) + [x, x, y] + >>> epath(path, expr, lambda expr: 2*expr) + t + sin(2*x + 1) + cos(2*x + 2*y + E) + + """ + _epath = EPath(path) + + if expr is None: + return _epath + if func is None: + return _epath.select(expr) + else: + return _epath.apply(expr, func, args, kwargs) diff --git a/phivenv/Lib/site-packages/sympy/simplify/fu.py b/phivenv/Lib/site-packages/sympy/simplify/fu.py new file mode 100644 index 0000000000000000000000000000000000000000..a26706edca98385df0009a8ee41476a17d36420c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/fu.py @@ -0,0 +1,2112 @@ +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.exprtools import Factors, gcd_terms, factor_terms +from sympy.core.function import expand_mul +from sympy.core.mul import Mul +from sympy.core.numbers import pi, I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.core.traversal import bottom_up +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.hyperbolic import ( + cosh, sinh, tanh, coth, sech, csch, HyperbolicFunction) +from sympy.functions.elementary.trigonometric import ( + cos, sin, tan, cot, sec, csc, sqrt, TrigonometricFunction) +from sympy.ntheory.factor_ import perfect_power +from sympy.polys.polytools import factor +from sympy.strategies.tree import greedy +from sympy.strategies.core import identity, debug + +from sympy import SYMPY_DEBUG + + +# ================== Fu-like tools =========================== + + +def TR0(rv): + """Simplification of rational polynomials, trying to simplify + the expression, e.g. combine things like 3*x + 2*x, etc.... + """ + # although it would be nice to use cancel, it doesn't work + # with noncommutatives + return rv.normal().factor().expand() + + +def TR1(rv): + """Replace sec, csc with 1/cos, 1/sin + + Examples + ======== + + >>> from sympy.simplify.fu import TR1, sec, csc + >>> from sympy.abc import x + >>> TR1(2*csc(x) + sec(x)) + 1/cos(x) + 2/sin(x) + """ + + def f(rv): + if isinstance(rv, sec): + a = rv.args[0] + return S.One/cos(a) + elif isinstance(rv, csc): + a = rv.args[0] + return S.One/sin(a) + return rv + + return bottom_up(rv, f) + + +def TR2(rv): + """Replace tan and cot with sin/cos and cos/sin + + Examples + ======== + + >>> from sympy.simplify.fu import TR2 + >>> from sympy.abc import x + >>> from sympy import tan, cot, sin, cos + >>> TR2(tan(x)) + sin(x)/cos(x) + >>> TR2(cot(x)) + cos(x)/sin(x) + >>> TR2(tan(tan(x) - sin(x)/cos(x))) + 0 + + """ + + def f(rv): + if isinstance(rv, tan): + a = rv.args[0] + return sin(a)/cos(a) + elif isinstance(rv, cot): + a = rv.args[0] + return cos(a)/sin(a) + return rv + + return bottom_up(rv, f) + + +def TR2i(rv, half=False): + """Converts ratios involving sin and cos as follows:: + sin(x)/cos(x) -> tan(x) + sin(x)/(cos(x) + 1) -> tan(x/2) if half=True + + Examples + ======== + + >>> from sympy.simplify.fu import TR2i + >>> from sympy.abc import x, a + >>> from sympy import sin, cos + >>> TR2i(sin(x)/cos(x)) + tan(x) + + Powers of the numerator and denominator are also recognized + + >>> TR2i(sin(x)**2/(cos(x) + 1)**2, half=True) + tan(x/2)**2 + + The transformation does not take place unless assumptions allow + (i.e. the base must be positive or the exponent must be an integer + for both numerator and denominator) + + >>> TR2i(sin(x)**a/(cos(x) + 1)**a) + sin(x)**a/(cos(x) + 1)**a + + """ + + def f(rv): + if not rv.is_Mul: + return rv + + n, d = rv.as_numer_denom() + if n.is_Atom or d.is_Atom: + return rv + + def ok(k, e): + # initial filtering of factors + return ( + (e.is_integer or k.is_positive) and ( + k.func in (sin, cos) or (half and + k.is_Add and + len(k.args) >= 2 and + any(any(isinstance(ai, cos) or ai.is_Pow and ai.base is cos + for ai in Mul.make_args(a)) for a in k.args)))) + + n = n.as_powers_dict() + ndone = [(k, n.pop(k)) for k in list(n.keys()) if not ok(k, n[k])] + if not n: + return rv + + d = d.as_powers_dict() + ddone = [(k, d.pop(k)) for k in list(d.keys()) if not ok(k, d[k])] + if not d: + return rv + + # factoring if necessary + + def factorize(d, ddone): + newk = [] + for k in d: + if k.is_Add and len(k.args) > 1: + knew = factor(k) if half else factor_terms(k) + if knew != k: + newk.append((k, knew)) + if newk: + for i, (k, knew) in enumerate(newk): + del d[k] + newk[i] = knew + newk = Mul(*newk).as_powers_dict() + for k in newk: + v = d[k] + newk[k] + if ok(k, v): + d[k] = v + else: + ddone.append((k, v)) + del newk + factorize(n, ndone) + factorize(d, ddone) + + # joining + t = [] + for k in n: + if isinstance(k, sin): + a = cos(k.args[0], evaluate=False) + if a in d and d[a] == n[k]: + t.append(tan(k.args[0])**n[k]) + n[k] = d[a] = None + elif half: + a1 = 1 + a + if a1 in d and d[a1] == n[k]: + t.append((tan(k.args[0]/2))**n[k]) + n[k] = d[a1] = None + elif isinstance(k, cos): + a = sin(k.args[0], evaluate=False) + if a in d and d[a] == n[k]: + t.append(tan(k.args[0])**-n[k]) + n[k] = d[a] = None + elif half and k.is_Add and k.args[0] is S.One and \ + isinstance(k.args[1], cos): + a = sin(k.args[1].args[0], evaluate=False) + if a in d and d[a] == n[k] and (d[a].is_integer or \ + a.is_positive): + t.append(tan(a.args[0]/2)**-n[k]) + n[k] = d[a] = None + + if t: + rv = Mul(*(t + [b**e for b, e in n.items() if e]))/\ + Mul(*[b**e for b, e in d.items() if e]) + rv *= Mul(*[b**e for b, e in ndone])/Mul(*[b**e for b, e in ddone]) + + return rv + + return bottom_up(rv, f) + + +def TR3(rv): + """Induced formula: example sin(-a) = -sin(a) + + Examples + ======== + + >>> from sympy.simplify.fu import TR3 + >>> from sympy.abc import x, y + >>> from sympy import pi + >>> from sympy import cos + >>> TR3(cos(y - x*(y - x))) + cos(x*(x - y) + y) + >>> cos(pi/2 + x) + -sin(x) + >>> cos(30*pi/2 + x) + -cos(x) + + """ + from sympy.simplify.simplify import signsimp + + # Negative argument (already automatic for funcs like sin(-x) -> -sin(x) + # but more complicated expressions can use it, too). Also, trig angles + # between pi/4 and pi/2 are not reduced to an angle between 0 and pi/4. + # The following are automatically handled: + # Argument of type: pi/2 +/- angle + # Argument of type: pi +/- angle + # Argument of type : 2k*pi +/- angle + + def f(rv): + if not isinstance(rv, TrigonometricFunction): + return rv + rv = rv.func(signsimp(rv.args[0])) + if not isinstance(rv, TrigonometricFunction): + return rv + if (rv.args[0] - S.Pi/4).is_positive is (S.Pi/2 - rv.args[0]).is_positive is True: + fmap = {cos: sin, sin: cos, tan: cot, cot: tan, sec: csc, csc: sec} + rv = fmap[type(rv)](S.Pi/2 - rv.args[0]) + return rv + + # touch numbers iside of trig functions to let them automatically update + rv = rv.replace( + lambda x: isinstance(x, TrigonometricFunction), + lambda x: x.replace( + lambda n: n.is_number and n.is_Mul, + lambda n: n.func(*n.args))) + + return bottom_up(rv, f) + + +def TR4(rv): + """Identify values of special angles. + + a= 0 pi/6 pi/4 pi/3 pi/2 + ---------------------------------------------------- + sin(a) 0 1/2 sqrt(2)/2 sqrt(3)/2 1 + cos(a) 1 sqrt(3)/2 sqrt(2)/2 1/2 0 + tan(a) 0 sqt(3)/3 1 sqrt(3) -- + + Examples + ======== + + >>> from sympy import pi + >>> from sympy import cos, sin, tan, cot + >>> for s in (0, pi/6, pi/4, pi/3, pi/2): + ... print('%s %s %s %s' % (cos(s), sin(s), tan(s), cot(s))) + ... + 1 0 0 zoo + sqrt(3)/2 1/2 sqrt(3)/3 sqrt(3) + sqrt(2)/2 sqrt(2)/2 1 1 + 1/2 sqrt(3)/2 sqrt(3) sqrt(3)/3 + 0 1 zoo 0 + """ + # special values at 0, pi/6, pi/4, pi/3, pi/2 already handled + return rv.replace( + lambda x: + isinstance(x, TrigonometricFunction) and + (r:=x.args[0]/pi).is_Rational and r.q in (1, 2, 3, 4, 6), + lambda x: + x.func(x.args[0].func(*x.args[0].args))) + + +def _TR56(rv, f, g, h, max, pow): + """Helper for TR5 and TR6 to replace f**2 with h(g**2) + + Options + ======= + + max : controls size of exponent that can appear on f + e.g. if max=4 then f**4 will be changed to h(g**2)**2. + pow : controls whether the exponent must be a perfect power of 2 + e.g. if pow=True (and max >= 6) then f**6 will not be changed + but f**8 will be changed to h(g**2)**4 + + >>> from sympy.simplify.fu import _TR56 as T + >>> from sympy.abc import x + >>> from sympy import sin, cos + >>> h = lambda x: 1 - x + >>> T(sin(x)**3, sin, cos, h, 4, False) + (1 - cos(x)**2)*sin(x) + >>> T(sin(x)**6, sin, cos, h, 6, False) + (1 - cos(x)**2)**3 + >>> T(sin(x)**6, sin, cos, h, 6, True) + sin(x)**6 + >>> T(sin(x)**8, sin, cos, h, 10, True) + (1 - cos(x)**2)**4 + """ + + def _f(rv): + # I'm not sure if this transformation should target all even powers + # or only those expressible as powers of 2. Also, should it only + # make the changes in powers that appear in sums -- making an isolated + # change is not going to allow a simplification as far as I can tell. + if not (rv.is_Pow and rv.base.func == f): + return rv + if not rv.exp.is_real: + return rv + + if (rv.exp < 0) == True: + return rv + if (rv.exp > max) == True: + return rv + if rv.exp == 1: + return rv + if rv.exp == 2: + return h(g(rv.base.args[0])**2) + else: + if rv.exp % 2 == 1: + e = rv.exp//2 + return f(rv.base.args[0])*h(g(rv.base.args[0])**2)**e + elif rv.exp == 4: + e = 2 + elif not pow: + if rv.exp % 2: + return rv + e = rv.exp//2 + else: + p = perfect_power(rv.exp) + if not p: + return rv + e = rv.exp//2 + return h(g(rv.base.args[0])**2)**e + + return bottom_up(rv, _f) + + +def TR5(rv, max=4, pow=False): + """Replacement of sin**2 with 1 - cos(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR5 + >>> from sympy.abc import x + >>> from sympy import sin + >>> TR5(sin(x)**2) + 1 - cos(x)**2 + >>> TR5(sin(x)**-2) # unchanged + sin(x)**(-2) + >>> TR5(sin(x)**4) + (1 - cos(x)**2)**2 + """ + return _TR56(rv, sin, cos, lambda x: 1 - x, max=max, pow=pow) + + +def TR6(rv, max=4, pow=False): + """Replacement of cos**2 with 1 - sin(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR6 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR6(cos(x)**2) + 1 - sin(x)**2 + >>> TR6(cos(x)**-2) #unchanged + cos(x)**(-2) + >>> TR6(cos(x)**4) + (1 - sin(x)**2)**2 + """ + return _TR56(rv, cos, sin, lambda x: 1 - x, max=max, pow=pow) + + +def TR7(rv): + """Lowering the degree of cos(x)**2. + + Examples + ======== + + >>> from sympy.simplify.fu import TR7 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR7(cos(x)**2) + cos(2*x)/2 + 1/2 + >>> TR7(cos(x)**2 + 1) + cos(2*x)/2 + 3/2 + + """ + + def f(rv): + if not (rv.is_Pow and rv.base.func == cos and rv.exp == 2): + return rv + return (1 + cos(2*rv.base.args[0]))/2 + + return bottom_up(rv, f) + + +def TR8(rv, first=True): + """Converting products of ``cos`` and/or ``sin`` to a sum or + difference of ``cos`` and or ``sin`` terms. + + Examples + ======== + + >>> from sympy.simplify.fu import TR8 + >>> from sympy import cos, sin + >>> TR8(cos(2)*cos(3)) + cos(5)/2 + cos(1)/2 + >>> TR8(cos(2)*sin(3)) + sin(5)/2 + sin(1)/2 + >>> TR8(sin(2)*sin(3)) + -cos(5)/2 + cos(1)/2 + """ + + def f(rv): + if not ( + rv.is_Mul or + rv.is_Pow and + rv.base.func in (cos, sin) and + (rv.exp.is_integer or rv.base.is_positive)): + return rv + + if first: + n, d = [expand_mul(i) for i in rv.as_numer_denom()] + newn = TR8(n, first=False) + newd = TR8(d, first=False) + if newn != n or newd != d: + rv = gcd_terms(newn/newd) + if rv.is_Mul and rv.args[0].is_Rational and \ + len(rv.args) == 2 and rv.args[1].is_Add: + rv = Mul(*rv.as_coeff_Mul()) + return rv + + args = {cos: [], sin: [], None: []} + for a in Mul.make_args(rv): + if a.func in (cos, sin): + args[type(a)].append(a.args[0]) + elif (a.is_Pow and a.exp.is_Integer and a.exp > 0 and \ + a.base.func in (cos, sin)): + # XXX this is ok but pathological expression could be handled + # more efficiently as in TRmorrie + args[type(a.base)].extend([a.base.args[0]]*a.exp) + else: + args[None].append(a) + c = args[cos] + s = args[sin] + if not (c and s or len(c) > 1 or len(s) > 1): + return rv + + args = args[None] + n = min(len(c), len(s)) + for i in range(n): + a1 = s.pop() + a2 = c.pop() + args.append((sin(a1 + a2) + sin(a1 - a2))/2) + while len(c) > 1: + a1 = c.pop() + a2 = c.pop() + args.append((cos(a1 + a2) + cos(a1 - a2))/2) + if c: + args.append(cos(c.pop())) + while len(s) > 1: + a1 = s.pop() + a2 = s.pop() + args.append((-cos(a1 + a2) + cos(a1 - a2))/2) + if s: + args.append(sin(s.pop())) + return TR8(expand_mul(Mul(*args))) + + return bottom_up(rv, f) + + +def TR9(rv): + """Sum of ``cos`` or ``sin`` terms as a product of ``cos`` or ``sin``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR9 + >>> from sympy import cos, sin + >>> TR9(cos(1) + cos(2)) + 2*cos(1/2)*cos(3/2) + >>> TR9(cos(1) + 2*sin(1) + 2*sin(2)) + cos(1) + 4*sin(3/2)*cos(1/2) + + If no change is made by TR9, no re-arrangement of the + expression will be made. For example, though factoring + of common term is attempted, if the factored expression + was not changed, the original expression will be returned: + + >>> TR9(cos(3) + cos(3)*cos(2)) + cos(3) + cos(2)*cos(3) + + """ + + def f(rv): + if not rv.is_Add: + return rv + + def do(rv, first=True): + # cos(a)+/-cos(b) can be combined into a product of cosines and + # sin(a)+/-sin(b) can be combined into a product of cosine and + # sine. + # + # If there are more than two args, the pairs which "work" will + # have a gcd extractable and the remaining two terms will have + # the above structure -- all pairs must be checked to find the + # ones that work. args that don't have a common set of symbols + # are skipped since this doesn't lead to a simpler formula and + # also has the arbitrariness of combining, for example, the x + # and y term instead of the y and z term in something like + # cos(x) + cos(y) + cos(z). + + if not rv.is_Add: + return rv + + args = list(ordered(rv.args)) + if len(args) != 2: + hit = False + for i in range(len(args)): + ai = args[i] + if ai is None: + continue + for j in range(i + 1, len(args)): + aj = args[j] + if aj is None: + continue + was = ai + aj + new = do(was) + if new != was: + args[i] = new # update in place + args[j] = None + hit = True + break # go to next i + if hit: + rv = Add(*[_f for _f in args if _f]) + if rv.is_Add: + rv = do(rv) + + return rv + + # two-arg Add + split = trig_split(*args) + if not split: + return rv + gcd, n1, n2, a, b, iscos = split + + # application of rule if possible + if iscos: + if n1 == n2: + return gcd*n1*2*cos((a + b)/2)*cos((a - b)/2) + if n1 < 0: + a, b = b, a + return -2*gcd*sin((a + b)/2)*sin((a - b)/2) + else: + if n1 == n2: + return gcd*n1*2*sin((a + b)/2)*cos((a - b)/2) + if n1 < 0: + a, b = b, a + return 2*gcd*cos((a + b)/2)*sin((a - b)/2) + + return process_common_addends(rv, do) # DON'T sift by free symbols + + return bottom_up(rv, f) + + +def TR10(rv, first=True): + """Separate sums in ``cos`` and ``sin``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR10 + >>> from sympy.abc import a, b, c + >>> from sympy import cos, sin + >>> TR10(cos(a + b)) + -sin(a)*sin(b) + cos(a)*cos(b) + >>> TR10(sin(a + b)) + sin(a)*cos(b) + sin(b)*cos(a) + >>> TR10(sin(a + b + c)) + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + """ + + def f(rv): + if rv.func not in (cos, sin): + return rv + + f = rv.func + arg = rv.args[0] + if arg.is_Add: + if first: + args = list(ordered(arg.args)) + else: + args = list(arg.args) + a = args.pop() + b = Add._from_args(args) + if b.is_Add: + if f == sin: + return sin(a)*TR10(cos(b), first=False) + \ + cos(a)*TR10(sin(b), first=False) + else: + return cos(a)*TR10(cos(b), first=False) - \ + sin(a)*TR10(sin(b), first=False) + else: + if f == sin: + return sin(a)*cos(b) + cos(a)*sin(b) + else: + return cos(a)*cos(b) - sin(a)*sin(b) + return rv + + return bottom_up(rv, f) + + +def TR10i(rv): + """Sum of products to function of sum. + + Examples + ======== + + >>> from sympy.simplify.fu import TR10i + >>> from sympy import cos, sin, sqrt + >>> from sympy.abc import x + + >>> TR10i(cos(1)*cos(3) + sin(1)*sin(3)) + cos(2) + >>> TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) + cos(3) + sin(4) + >>> TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) + 2*sqrt(2)*x*sin(x + pi/6) + + """ + def f(rv): + if not rv.is_Add: + return rv + + def do(rv, first=True): + # args which can be expressed as A*(cos(a)*cos(b)+/-sin(a)*sin(b)) + # or B*(cos(a)*sin(b)+/-cos(b)*sin(a)) can be combined into + # A*f(a+/-b) where f is either sin or cos. + # + # If there are more than two args, the pairs which "work" will have + # a gcd extractable and the remaining two terms will have the above + # structure -- all pairs must be checked to find the ones that + # work. + + if not rv.is_Add: + return rv + + args = list(ordered(rv.args)) + if len(args) != 2: + hit = False + for i in range(len(args)): + ai = args[i] + if ai is None: + continue + for j in range(i + 1, len(args)): + aj = args[j] + if aj is None: + continue + was = ai + aj + new = do(was) + if new != was: + args[i] = new # update in place + args[j] = None + hit = True + break # go to next i + if hit: + rv = Add(*[_f for _f in args if _f]) + if rv.is_Add: + rv = do(rv) + + return rv + + # two-arg Add + split = trig_split(*args, two=True) + if not split: + return rv + gcd, n1, n2, a, b, same = split + + # identify and get c1 to be cos then apply rule if possible + if same: # coscos, sinsin + gcd = n1*gcd + if n1 == n2: + return gcd*cos(a - b) + return gcd*cos(a + b) + else: #cossin, cossin + gcd = n1*gcd + if n1 == n2: + return gcd*sin(a + b) + return gcd*sin(b - a) + + rv = process_common_addends( + rv, do, lambda x: tuple(ordered(x.free_symbols))) + + # need to check for inducible pairs in ratio of sqrt(3):1 that + # appeared in different lists when sorting by coefficient + while rv.is_Add: + byrad = defaultdict(list) + for a in rv.args: + hit = 0 + if a.is_Mul: + for ai in a.args: + if ai.is_Pow and ai.exp is S.Half and \ + ai.base.is_Integer: + byrad[ai].append(a) + hit = 1 + break + if not hit: + byrad[S.One].append(a) + + # no need to check all pairs -- just check for the onees + # that have the right ratio + args = [] + for a in byrad: + for b in [_ROOT3()*a, _invROOT3()]: + if b in byrad: + for i in range(len(byrad[a])): + if byrad[a][i] is None: + continue + for j in range(len(byrad[b])): + if byrad[b][j] is None: + continue + was = Add(byrad[a][i] + byrad[b][j]) + new = do(was) + if new != was: + args.append(new) + byrad[a][i] = None + byrad[b][j] = None + break + if args: + rv = Add(*(args + [Add(*[_f for _f in v if _f]) + for v in byrad.values()])) + else: + rv = do(rv) # final pass to resolve any new inducible pairs + break + + return rv + + return bottom_up(rv, f) + + +def TR11(rv, base=None): + """Function of double angle to product. The ``base`` argument can be used + to indicate what is the un-doubled argument, e.g. if 3*pi/7 is the base + then cosine and sine functions with argument 6*pi/7 will be replaced. + + Examples + ======== + + >>> from sympy.simplify.fu import TR11 + >>> from sympy import cos, sin, pi + >>> from sympy.abc import x + >>> TR11(sin(2*x)) + 2*sin(x)*cos(x) + >>> TR11(cos(2*x)) + -sin(x)**2 + cos(x)**2 + >>> TR11(sin(4*x)) + 4*(-sin(x)**2 + cos(x)**2)*sin(x)*cos(x) + >>> TR11(sin(4*x/3)) + 4*(-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3) + + If the arguments are simply integers, no change is made + unless a base is provided: + + >>> TR11(cos(2)) + cos(2) + >>> TR11(cos(4), 2) + -sin(2)**2 + cos(2)**2 + + There is a subtle issue here in that autosimplification will convert + some higher angles to lower angles + + >>> cos(6*pi/7) + cos(3*pi/7) + -cos(pi/7) + cos(3*pi/7) + + The 6*pi/7 angle is now pi/7 but can be targeted with TR11 by supplying + the 3*pi/7 base: + + >>> TR11(_, 3*pi/7) + -sin(3*pi/7)**2 + cos(3*pi/7)**2 + cos(3*pi/7) + + """ + + def f(rv): + if rv.func not in (cos, sin): + return rv + + if base: + f = rv.func + t = f(base*2) + co = S.One + if t.is_Mul: + co, t = t.as_coeff_Mul() + if t.func not in (cos, sin): + return rv + if rv.args[0] == t.args[0]: + c = cos(base) + s = sin(base) + if f is cos: + return (c**2 - s**2)/co + else: + return 2*c*s/co + return rv + + elif not rv.args[0].is_Number: + # make a change if the leading coefficient's numerator is + # divisible by 2 + c, m = rv.args[0].as_coeff_Mul(rational=True) + if c.p % 2 == 0: + arg = c.p//2*m/c.q + c = TR11(cos(arg)) + s = TR11(sin(arg)) + if rv.func == sin: + rv = 2*s*c + else: + rv = c**2 - s**2 + return rv + + return bottom_up(rv, f) + + +def _TR11(rv): + """ + Helper for TR11 to find half-arguments for sin in factors of + num/den that appear in cos or sin factors in the den/num. + + Examples + ======== + + >>> from sympy.simplify.fu import TR11, _TR11 + >>> from sympy import cos, sin + >>> from sympy.abc import x + >>> TR11(sin(x/3)/(cos(x/6))) + sin(x/3)/cos(x/6) + >>> _TR11(sin(x/3)/(cos(x/6))) + 2*sin(x/6) + >>> TR11(sin(x/6)/(sin(x/3))) + sin(x/6)/sin(x/3) + >>> _TR11(sin(x/6)/(sin(x/3))) + 1/(2*cos(x/6)) + + """ + def f(rv): + if not isinstance(rv, Expr): + return rv + + def sincos_args(flat): + # find arguments of sin and cos that + # appears as bases in args of flat + # and have Integer exponents + args = defaultdict(set) + for fi in Mul.make_args(flat): + b, e = fi.as_base_exp() + if e.is_Integer and e > 0: + if b.func in (cos, sin): + args[type(b)].add(b.args[0]) + return args + num_args, den_args = map(sincos_args, rv.as_numer_denom()) + def handle_match(rv, num_args, den_args): + # for arg in sin args of num_args, look for arg/2 + # in den_args and pass this half-angle to TR11 + # for handling in rv + for narg in num_args[sin]: + half = narg/2 + if half in den_args[cos]: + func = cos + elif half in den_args[sin]: + func = sin + else: + continue + rv = TR11(rv, half) + den_args[func].remove(half) + return rv + # sin in num, sin or cos in den + rv = handle_match(rv, num_args, den_args) + # sin in den, sin or cos in num + rv = handle_match(rv, den_args, num_args) + return rv + + return bottom_up(rv, f) + + +def TR12(rv, first=True): + """Separate sums in ``tan``. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import tan + >>> from sympy.simplify.fu import TR12 + >>> TR12(tan(x + y)) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + """ + + def f(rv): + if not rv.func == tan: + return rv + + arg = rv.args[0] + if arg.is_Add: + if first: + args = list(ordered(arg.args)) + else: + args = list(arg.args) + a = args.pop() + b = Add._from_args(args) + if b.is_Add: + tb = TR12(tan(b), first=False) + else: + tb = tan(b) + return (tan(a) + tb)/(1 - tan(a)*tb) + return rv + + return bottom_up(rv, f) + + +def TR12i(rv): + """Combine tan arguments as + (tan(y) + tan(x))/(tan(x)*tan(y) - 1) -> -tan(x + y). + + Examples + ======== + + >>> from sympy.simplify.fu import TR12i + >>> from sympy import tan + >>> from sympy.abc import a, b, c + >>> ta, tb, tc = [tan(i) for i in (a, b, c)] + >>> TR12i((ta + tb)/(-ta*tb + 1)) + tan(a + b) + >>> TR12i((ta + tb)/(ta*tb - 1)) + -tan(a + b) + >>> TR12i((-ta - tb)/(ta*tb - 1)) + tan(a + b) + >>> eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + >>> TR12i(eq.expand()) + -3*tan(a + b)*tan(a + c)/(2*(tan(a) + tan(b) - 1)) + """ + def f(rv): + if not (rv.is_Add or rv.is_Mul or rv.is_Pow): + return rv + + n, d = rv.as_numer_denom() + if not d.args or not n.args: + return rv + + dok = {} + + def ok(di): + m = as_f_sign_1(di) + if m: + g, f, s = m + if s is S.NegativeOne and f.is_Mul and len(f.args) == 2 and \ + all(isinstance(fi, tan) for fi in f.args): + return g, f + + d_args = list(Mul.make_args(d)) + for i, di in enumerate(d_args): + m = ok(di) + if m: + g, t = m + s = Add(*[_.args[0] for _ in t.args]) + dok[s] = S.One + d_args[i] = g + continue + if di.is_Add: + di = factor(di) + if di.is_Mul: + d_args.extend(di.args) + d_args[i] = S.One + elif di.is_Pow and (di.exp.is_integer or di.base.is_positive): + m = ok(di.base) + if m: + g, t = m + s = Add(*[_.args[0] for _ in t.args]) + dok[s] = di.exp + d_args[i] = g**di.exp + else: + di = factor(di) + if di.is_Mul: + d_args.extend(di.args) + d_args[i] = S.One + if not dok: + return rv + + def ok(ni): + if ni.is_Add and len(ni.args) == 2: + a, b = ni.args + if isinstance(a, tan) and isinstance(b, tan): + return a, b + n_args = list(Mul.make_args(factor_terms(n))) + hit = False + for i, ni in enumerate(n_args): + m = ok(ni) + if not m: + m = ok(-ni) + if m: + n_args[i] = S.NegativeOne + else: + if ni.is_Add: + ni = factor(ni) + if ni.is_Mul: + n_args.extend(ni.args) + n_args[i] = S.One + continue + elif ni.is_Pow and ( + ni.exp.is_integer or ni.base.is_positive): + m = ok(ni.base) + if m: + n_args[i] = S.One + else: + ni = factor(ni) + if ni.is_Mul: + n_args.extend(ni.args) + n_args[i] = S.One + continue + else: + continue + else: + n_args[i] = S.One + hit = True + s = Add(*[_.args[0] for _ in m]) + ed = dok[s] + newed = ed.extract_additively(S.One) + if newed is not None: + if newed: + dok[s] = newed + else: + dok.pop(s) + n_args[i] *= -tan(s) + + if hit: + rv = Mul(*n_args)/Mul(*d_args)/Mul(*[(Add(*[ + tan(a) for a in i.args]) - 1)**e for i, e in dok.items()]) + + return rv + + return bottom_up(rv, f) + + +def TR13(rv): + """Change products of ``tan`` or ``cot``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR13 + >>> from sympy import tan, cot + >>> TR13(tan(3)*tan(2)) + -tan(2)/tan(5) - tan(3)/tan(5) + 1 + >>> TR13(cot(3)*cot(2)) + cot(2)*cot(5) + 1 + cot(3)*cot(5) + """ + + def f(rv): + if not rv.is_Mul: + return rv + + # XXX handle products of powers? or let power-reducing handle it? + args = {tan: [], cot: [], None: []} + for a in Mul.make_args(rv): + if a.func in (tan, cot): + args[type(a)].append(a.args[0]) + else: + args[None].append(a) + t = args[tan] + c = args[cot] + if len(t) < 2 and len(c) < 2: + return rv + args = args[None] + while len(t) > 1: + t1 = t.pop() + t2 = t.pop() + args.append(1 - (tan(t1)/tan(t1 + t2) + tan(t2)/tan(t1 + t2))) + if t: + args.append(tan(t.pop())) + while len(c) > 1: + t1 = c.pop() + t2 = c.pop() + args.append(1 + cot(t1)*cot(t1 + t2) + cot(t2)*cot(t1 + t2)) + if c: + args.append(cot(c.pop())) + return Mul(*args) + + return bottom_up(rv, f) + + +def TRmorrie(rv): + """Returns cos(x)*cos(2*x)*...*cos(2**(k-1)*x) -> sin(2**k*x)/(2**k*sin(x)) + + Examples + ======== + + >>> from sympy.simplify.fu import TRmorrie, TR8, TR3 + >>> from sympy.abc import x + >>> from sympy import Mul, cos, pi + >>> TRmorrie(cos(x)*cos(2*x)) + sin(4*x)/(4*sin(x)) + >>> TRmorrie(7*Mul(*[cos(x) for x in range(10)])) + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + + Sometimes autosimplification will cause a power to be + not recognized. e.g. in the following, cos(4*pi/7) automatically + simplifies to -cos(3*pi/7) so only 2 of the 3 terms are + recognized: + + >>> TRmorrie(cos(pi/7)*cos(2*pi/7)*cos(4*pi/7)) + -sin(3*pi/7)*cos(3*pi/7)/(4*sin(pi/7)) + + A touch by TR8 resolves the expression to a Rational + + >>> TR8(_) + -1/8 + + In this case, if eq is unsimplified, the answer is obtained + directly: + + >>> eq = cos(pi/9)*cos(2*pi/9)*cos(3*pi/9)*cos(4*pi/9) + >>> TRmorrie(eq) + 1/16 + + But if angles are made canonical with TR3 then the answer + is not simplified without further work: + + >>> TR3(eq) + sin(pi/18)*cos(pi/9)*cos(2*pi/9)/2 + >>> TRmorrie(_) + sin(pi/18)*sin(4*pi/9)/(8*sin(pi/9)) + >>> TR8(_) + cos(7*pi/18)/(16*sin(pi/9)) + >>> TR3(_) + 1/16 + + The original expression would have resolve to 1/16 directly with TR8, + however: + + >>> TR8(eq) + 1/16 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Morrie%27s_law + + """ + + def f(rv, first=True): + if not rv.is_Mul: + return rv + if first: + n, d = rv.as_numer_denom() + return f(n, 0)/f(d, 0) + + args = defaultdict(list) + coss = {} + other = [] + for c in rv.args: + b, e = c.as_base_exp() + if e.is_Integer and isinstance(b, cos): + co, a = b.args[0].as_coeff_Mul() + args[a].append(co) + coss[b] = e + else: + other.append(c) + + new = [] + for a in args: + c = args[a] + c.sort() + while c: + k = 0 + cc = ci = c[0] + while cc in c: + k += 1 + cc *= 2 + if k > 1: + newarg = sin(2**k*ci*a)/2**k/sin(ci*a) + # see how many times this can be taken + take = None + ccs = [] + for i in range(k): + cc /= 2 + key = cos(a*cc, evaluate=False) + ccs.append(cc) + take = min(coss[key], take or coss[key]) + # update exponent counts + for i in range(k): + cc = ccs.pop() + key = cos(a*cc, evaluate=False) + coss[key] -= take + if not coss[key]: + c.remove(cc) + new.append(newarg**take) + else: + b = cos(c.pop(0)*a) + other.append(b**coss[b]) + + if new: + rv = Mul(*(new + other + [ + cos(k*a, evaluate=False) for a in args for k in args[a]])) + + return rv + + return bottom_up(rv, f) + + +def TR14(rv, first=True): + """Convert factored powers of sin and cos identities into simpler + expressions. + + Examples + ======== + + >>> from sympy.simplify.fu import TR14 + >>> from sympy.abc import x, y + >>> from sympy import cos, sin + >>> TR14((cos(x) - 1)*(cos(x) + 1)) + -sin(x)**2 + >>> TR14((sin(x) - 1)*(sin(x) + 1)) + -cos(x)**2 + >>> p1 = (cos(x) + 1)*(cos(x) - 1) + >>> p2 = (cos(y) - 1)*2*(cos(y) + 1) + >>> p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + >>> TR14(p1*p2*p3*(x - 1)) + -18*(x - 1)*sin(x)**2*sin(y)**4 + + """ + + def f(rv): + if not rv.is_Mul: + return rv + + if first: + # sort them by location in numerator and denominator + # so the code below can just deal with positive exponents + n, d = rv.as_numer_denom() + if d is not S.One: + newn = TR14(n, first=False) + newd = TR14(d, first=False) + if newn != n or newd != d: + rv = newn/newd + return rv + + other = [] + process = [] + for a in rv.args: + if a.is_Pow: + b, e = a.as_base_exp() + if not (e.is_integer or b.is_positive): + other.append(a) + continue + a = b + else: + e = S.One + m = as_f_sign_1(a) + if not m or m[1].func not in (cos, sin): + if e is S.One: + other.append(a) + else: + other.append(a**e) + continue + g, f, si = m + process.append((g, e.is_Number, e, f, si, a)) + + # sort them to get like terms next to each other + process = list(ordered(process)) + + # keep track of whether there was any change + nother = len(other) + + # access keys + keys = (g, t, e, f, si, a) = list(range(6)) + + while process: + A = process.pop(0) + if process: + B = process[0] + + if A[e].is_Number and B[e].is_Number: + # both exponents are numbers + if A[f] == B[f]: + if A[si] != B[si]: + B = process.pop(0) + take = min(A[e], B[e]) + + # reinsert any remainder + # the B will likely sort after A so check it first + if B[e] != take: + rem = [B[i] for i in keys] + rem[e] -= take + process.insert(0, rem) + elif A[e] != take: + rem = [A[i] for i in keys] + rem[e] -= take + process.insert(0, rem) + + if isinstance(A[f], cos): + t = sin + else: + t = cos + other.append((-A[g]*B[g]*t(A[f].args[0])**2)**take) + continue + + elif A[e] == B[e]: + # both exponents are equal symbols + if A[f] == B[f]: + if A[si] != B[si]: + B = process.pop(0) + take = A[e] + if isinstance(A[f], cos): + t = sin + else: + t = cos + other.append((-A[g]*B[g]*t(A[f].args[0])**2)**take) + continue + + # either we are done or neither condition above applied + other.append(A[a]**A[e]) + + if len(other) != nother: + rv = Mul(*other) + + return rv + + return bottom_up(rv, f) + + +def TR15(rv, max=4, pow=False): + """Convert sin(x)**-2 to 1 + cot(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR15 + >>> from sympy.abc import x + >>> from sympy import sin + >>> TR15(1 - 1/sin(x)**2) + -cot(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, sin)): + return rv + + e = rv.exp + if e % 2 == 1: + return TR15(rv.base**(e + 1))/rv.base + + ia = 1/rv + a = _TR56(ia, sin, cot, lambda x: 1 + x, max=max, pow=pow) + if a != ia: + rv = a + return rv + + return bottom_up(rv, f) + + +def TR16(rv, max=4, pow=False): + """Convert cos(x)**-2 to 1 + tan(x)**2. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR16 + >>> from sympy.abc import x + >>> from sympy import cos + >>> TR16(1 - 1/cos(x)**2) + -tan(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, cos)): + return rv + + e = rv.exp + if e % 2 == 1: + return TR15(rv.base**(e + 1))/rv.base + + ia = 1/rv + a = _TR56(ia, cos, tan, lambda x: 1 + x, max=max, pow=pow) + if a != ia: + rv = a + return rv + + return bottom_up(rv, f) + + +def TR111(rv): + """Convert f(x)**-i to g(x)**i where either ``i`` is an integer + or the base is positive and f, g are: tan, cot; sin, csc; or cos, sec. + + Examples + ======== + + >>> from sympy.simplify.fu import TR111 + >>> from sympy.abc import x + >>> from sympy import tan + >>> TR111(1 - 1/tan(x)**2) + 1 - cot(x)**2 + + """ + + def f(rv): + if not ( + isinstance(rv, Pow) and + (rv.base.is_positive or rv.exp.is_integer and rv.exp.is_negative)): + return rv + + if isinstance(rv.base, tan): + return cot(rv.base.args[0])**-rv.exp + elif isinstance(rv.base, sin): + return csc(rv.base.args[0])**-rv.exp + elif isinstance(rv.base, cos): + return sec(rv.base.args[0])**-rv.exp + return rv + + return bottom_up(rv, f) + + +def TR22(rv, max=4, pow=False): + """Convert tan(x)**2 to sec(x)**2 - 1 and cot(x)**2 to csc(x)**2 - 1. + + See _TR56 docstring for advanced use of ``max`` and ``pow``. + + Examples + ======== + + >>> from sympy.simplify.fu import TR22 + >>> from sympy.abc import x + >>> from sympy import tan, cot + >>> TR22(1 + tan(x)**2) + sec(x)**2 + >>> TR22(1 + cot(x)**2) + csc(x)**2 + + """ + + def f(rv): + if not (isinstance(rv, Pow) and rv.base.func in (cot, tan)): + return rv + + rv = _TR56(rv, tan, sec, lambda x: x - 1, max=max, pow=pow) + rv = _TR56(rv, cot, csc, lambda x: x - 1, max=max, pow=pow) + return rv + + return bottom_up(rv, f) + + +def TRpower(rv): + """Convert sin(x)**n and cos(x)**n with positive n to sums. + + Examples + ======== + + >>> from sympy.simplify.fu import TRpower + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> TRpower(sin(x)**6) + -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 + 5/16 + >>> TRpower(sin(x)**3*cos(2*x)**4) + (3*sin(x)/4 - sin(3*x)/4)*(cos(4*x)/2 + cos(8*x)/8 + 3/8) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/List_of_trigonometric_identities#Power-reduction_formulae + + """ + + def f(rv): + if not (isinstance(rv, Pow) and isinstance(rv.base, (sin, cos))): + return rv + b, n = rv.as_base_exp() + x = b.args[0] + if n.is_Integer and n.is_positive: + if n.is_odd and isinstance(b, cos): + rv = 2**(1-n)*Add(*[binomial(n, k)*cos((n - 2*k)*x) + for k in range((n + 1)/2)]) + elif n.is_odd and isinstance(b, sin): + rv = 2**(1-n)*S.NegativeOne**((n-1)/2)*Add(*[binomial(n, k)* + S.NegativeOne**k*sin((n - 2*k)*x) for k in range((n + 1)/2)]) + elif n.is_even and isinstance(b, cos): + rv = 2**(1-n)*Add(*[binomial(n, k)*cos((n - 2*k)*x) + for k in range(n/2)]) + elif n.is_even and isinstance(b, sin): + rv = 2**(1-n)*S.NegativeOne**(n/2)*Add(*[binomial(n, k)* + S.NegativeOne**k*cos((n - 2*k)*x) for k in range(n/2)]) + if n.is_even: + rv += 2**(-n)*binomial(n, n/2) + return rv + + return bottom_up(rv, f) + + +def L(rv): + """Return count of trigonometric functions in expression. + + Examples + ======== + + >>> from sympy.simplify.fu import L + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> L(cos(x)+sin(x)) + 2 + """ + return S(rv.count(TrigonometricFunction)) + + +# ============== end of basic Fu-like tools ===================== + +if SYMPY_DEBUG: + (TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13, + TR2i, TRmorrie, TR14, TR15, TR16, TR12i, TR111, TR22 + )= list(map(debug, + (TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TR10, TR11, TR12, TR13, + TR2i, TRmorrie, TR14, TR15, TR16, TR12i, TR111, TR22))) + + +# tuples are chains -- (f, g) -> lambda x: g(f(x)) +# lists are choices -- [f, g] -> lambda x: min(f(x), g(x), key=objective) + +CTR1 = [(TR5, TR0), (TR6, TR0), identity] + +CTR2 = (TR11, [(TR5, TR0), (TR6, TR0), TR0]) + +CTR3 = [(TRmorrie, TR8, TR0), (TRmorrie, TR8, TR10i, TR0), identity] + +CTR4 = [(TR4, TR10i), identity] + +RL1 = (TR4, TR3, TR4, TR12, TR4, TR13, TR4, TR0) + + +# XXX it's a little unclear how this one is to be implemented +# see Fu paper of reference, page 7. What is the Union symbol referring to? +# The diagram shows all these as one chain of transformations, but the +# text refers to them being applied independently. Also, a break +# if L starts to increase has not been implemented. +RL2 = [ + (TR4, TR3, TR10, TR4, TR3, TR11), + (TR5, TR7, TR11, TR4), + (CTR3, CTR1, TR9, CTR2, TR4, TR9, TR9, CTR4), + identity, + ] + + +def fu(rv, measure=lambda x: (L(x), x.count_ops())): + """Attempt to simplify expression by using transformation rules given + in the algorithm by Fu et al. + + :func:`fu` will try to minimize the objective function ``measure``. + By default this first minimizes the number of trig terms and then minimizes + the number of total operations. + + Examples + ======== + + >>> from sympy.simplify.fu import fu + >>> from sympy import cos, sin, tan, pi, S, sqrt + >>> from sympy.abc import x, y, a, b + + >>> fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) + 3/2 + >>> fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) + 2*sqrt(2)*sin(x + pi/3) + + CTR1 example + + >>> eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + >>> fu(eq) + cos(x)**4 - 2*cos(y)**2 + 2 + + CTR2 example + + >>> fu(S.Half - cos(2*x)/2) + sin(x)**2 + + CTR3 example + + >>> fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) + sqrt(2)*sin(a + b + pi/4) + + CTR4 example + + >>> fu(sqrt(3)*cos(x)/2 + sin(x)/2) + sin(x + pi/3) + + Example 1 + + >>> fu(1-sin(2*x)**2/4-sin(y)**2-cos(x)**4) + -cos(x)**2 + cos(y)**2 + + Example 2 + + >>> fu(cos(4*pi/9)) + sin(pi/18) + >>> fu(cos(pi/9)*cos(2*pi/9)*cos(3*pi/9)*cos(4*pi/9)) + 1/16 + + Example 3 + + >>> fu(tan(7*pi/18)+tan(5*pi/18)-sqrt(3)*tan(5*pi/18)*tan(7*pi/18)) + -sqrt(3) + + Objective function example + + >>> fu(sin(x)/cos(x)) # default objective function + tan(x) + >>> fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) # maximize op count + sin(x)/cos(x) + + References + ========== + + .. [1] https://www.sciencedirect.com/science/article/pii/S0895717706001609 + """ + fRL1 = greedy(RL1, measure) + fRL2 = greedy(RL2, measure) + + was = rv + rv = sympify(rv) + if not isinstance(rv, Expr): + return rv.func(*[fu(a, measure=measure) for a in rv.args]) + rv = TR1(rv) + if rv.has(tan, cot): + rv1 = fRL1(rv) + if (measure(rv1) < measure(rv)): + rv = rv1 + if rv.has(tan, cot): + rv = TR2(rv) + if rv.has(sin, cos): + rv1 = fRL2(rv) + rv2 = TR8(TRmorrie(rv1)) + rv = min([was, rv, rv1, rv2], key=measure) + return min(TR2i(rv), rv, key=measure) + + +def process_common_addends(rv, do, key2=None, key1=True): + """Apply ``do`` to addends of ``rv`` that (if ``key1=True``) share at least + a common absolute value of their coefficient and the value of ``key2`` when + applied to the argument. If ``key1`` is False ``key2`` must be supplied and + will be the only key applied. + """ + + # collect by absolute value of coefficient and key2 + absc = defaultdict(list) + if key1: + for a in rv.args: + c, a = a.as_coeff_Mul() + if c < 0: + c = -c + a = -a # put the sign on `a` + absc[(c, key2(a) if key2 else 1)].append(a) + elif key2: + for a in rv.args: + absc[(S.One, key2(a))].append(a) + else: + raise ValueError('must have at least one key') + + args = [] + hit = False + for k in absc: + v = absc[k] + c, _ = k + if len(v) > 1: + e = Add(*v, evaluate=False) + new = do(e) + if new != e: + e = new + hit = True + args.append(c*e) + else: + args.append(c*v[0]) + if hit: + rv = Add(*args) + + return rv + + +fufuncs = ''' + TR0 TR1 TR2 TR3 TR4 TR5 TR6 TR7 TR8 TR9 TR10 TR10i TR11 + TR12 TR13 L TR2i TRmorrie TR12i + TR14 TR15 TR16 TR111 TR22'''.split() +FU = dict(list(zip(fufuncs, list(map(locals().get, fufuncs))))) + + +@cacheit +def _ROOT2(): + return sqrt(2) + + +@cacheit +def _ROOT3(): + return sqrt(3) + + +@cacheit +def _invROOT3(): + return 1/sqrt(3) + + +def trig_split(a, b, two=False): + """Return the gcd, s1, s2, a1, a2, bool where + + If two is False (default) then:: + a + b = gcd*(s1*f(a1) + s2*f(a2)) where f = cos if bool else sin + else: + if bool, a + b was +/- cos(a1)*cos(a2) +/- sin(a1)*sin(a2) and equals + n1*gcd*cos(a - b) if n1 == n2 else + n1*gcd*cos(a + b) + else a + b was +/- cos(a1)*sin(a2) +/- sin(a1)*cos(a2) and equals + n1*gcd*sin(a + b) if n1 = n2 else + n1*gcd*sin(b - a) + + Examples + ======== + + >>> from sympy.simplify.fu import trig_split + >>> from sympy.abc import x, y, z + >>> from sympy import cos, sin, sqrt + + >>> trig_split(cos(x), cos(y)) + (1, 1, 1, x, y, True) + >>> trig_split(2*cos(x), -2*cos(y)) + (2, 1, -1, x, y, True) + >>> trig_split(cos(x)*sin(y), cos(y)*sin(y)) + (sin(y), 1, 1, x, y, True) + + >>> trig_split(cos(x), -sqrt(3)*sin(x), two=True) + (2, 1, -1, x, pi/6, False) + >>> trig_split(cos(x), sin(x), two=True) + (sqrt(2), 1, 1, x, pi/4, False) + >>> trig_split(cos(x), -sin(x), two=True) + (sqrt(2), 1, -1, x, pi/4, False) + >>> trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) + (2*sqrt(2), 1, -1, x, pi/6, False) + >>> trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) + (-2*sqrt(2), 1, 1, x, pi/3, False) + >>> trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) + (sqrt(6)/3, 1, 1, x, pi/6, False) + >>> trig_split(-sqrt(6)*cos(x)*sin(y), -sqrt(2)*sin(x)*sin(y), two=True) + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + >>> trig_split(cos(x), sin(x)) + >>> trig_split(cos(x), sin(z)) + >>> trig_split(2*cos(x), -sin(x)) + >>> trig_split(cos(x), -sqrt(3)*sin(x)) + >>> trig_split(cos(x)*cos(y), sin(x)*sin(z)) + >>> trig_split(cos(x)*cos(y), sin(x)*sin(y)) + >>> trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) + """ + a, b = [Factors(i) for i in (a, b)] + ua, ub = a.normal(b) + gcd = a.gcd(b).as_expr() + n1 = n2 = 1 + if S.NegativeOne in ua.factors: + ua = ua.quo(S.NegativeOne) + n1 = -n1 + elif S.NegativeOne in ub.factors: + ub = ub.quo(S.NegativeOne) + n2 = -n2 + a, b = [i.as_expr() for i in (ua, ub)] + + def pow_cos_sin(a, two): + """Return ``a`` as a tuple (r, c, s) such that + ``a = (r or 1)*(c or 1)*(s or 1)``. + + Three arguments are returned (radical, c-factor, s-factor) as + long as the conditions set by ``two`` are met; otherwise None is + returned. If ``two`` is True there will be one or two non-None + values in the tuple: c and s or c and r or s and r or s or c with c + being a cosine function (if possible) else a sine, and s being a sine + function (if possible) else oosine. If ``two`` is False then there + will only be a c or s term in the tuple. + + ``two`` also require that either two cos and/or sin be present (with + the condition that if the functions are the same the arguments are + different or vice versa) or that a single cosine or a single sine + be present with an optional radical. + + If the above conditions dictated by ``two`` are not met then None + is returned. + """ + c = s = None + co = S.One + if a.is_Mul: + co, a = a.as_coeff_Mul() + if len(a.args) > 2 or not two: + return None + if a.is_Mul: + args = list(a.args) + else: + args = [a] + a = args.pop(0) + if isinstance(a, cos): + c = a + elif isinstance(a, sin): + s = a + elif a.is_Pow and a.exp is S.Half: # autoeval doesn't allow -1/2 + co *= a + else: + return None + if args: + b = args[0] + if isinstance(b, cos): + if c: + s = b + else: + c = b + elif isinstance(b, sin): + if s: + c = b + else: + s = b + elif b.is_Pow and b.exp is S.Half: + co *= b + else: + return None + return co if co is not S.One else None, c, s + elif isinstance(a, cos): + c = a + elif isinstance(a, sin): + s = a + if c is None and s is None: + return + co = co if co is not S.One else None + return co, c, s + + # get the parts + m = pow_cos_sin(a, two) + if m is None: + return + coa, ca, sa = m + m = pow_cos_sin(b, two) + if m is None: + return + cob, cb, sb = m + + # check them + if (not ca) and cb or ca and isinstance(ca, sin): + coa, ca, sa, cob, cb, sb = cob, cb, sb, coa, ca, sa + n1, n2 = n2, n1 + if not two: # need cos(x) and cos(y) or sin(x) and sin(y) + c = ca or sa + s = cb or sb + if not isinstance(c, s.func): + return None + return gcd, n1, n2, c.args[0], s.args[0], isinstance(c, cos) + else: + if not coa and not cob: + if (ca and cb and sa and sb): + if isinstance(ca, sa.func) is not isinstance(cb, sb.func): + return + args = {j.args for j in (ca, sa)} + if not all(i.args in args for i in (cb, sb)): + return + return gcd, n1, n2, ca.args[0], sa.args[0], isinstance(ca, sa.func) + if ca and sa or cb and sb or \ + two and (ca is None and sa is None or cb is None and sb is None): + return + c = ca or sa + s = cb or sb + if c.args != s.args: + return + if not coa: + coa = S.One + if not cob: + cob = S.One + if coa is cob: + gcd *= _ROOT2() + return gcd, n1, n2, c.args[0], pi/4, False + elif coa/cob == _ROOT3(): + gcd *= 2*cob + return gcd, n1, n2, c.args[0], pi/3, False + elif coa/cob == _invROOT3(): + gcd *= 2*coa + return gcd, n1, n2, c.args[0], pi/6, False + + +def as_f_sign_1(e): + """If ``e`` is a sum that can be written as ``g*(a + s)`` where + ``s`` is ``+/-1``, return ``g``, ``a``, and ``s`` where ``a`` does + not have a leading negative coefficient. + + Examples + ======== + + >>> from sympy.simplify.fu import as_f_sign_1 + >>> from sympy.abc import x + >>> as_f_sign_1(x + 1) + (1, x, 1) + >>> as_f_sign_1(x - 1) + (1, x, -1) + >>> as_f_sign_1(-x + 1) + (-1, x, -1) + >>> as_f_sign_1(-x - 1) + (-1, x, 1) + >>> as_f_sign_1(2*x + 2) + (2, x, 1) + """ + if not e.is_Add or len(e.args) != 2: + return + # exact match + a, b = e.args + if a in (S.NegativeOne, S.One): + g = S.One + if b.is_Mul and b.args[0].is_Number and b.args[0] < 0: + a, b = -a, -b + g = -g + return g, b, a + # gcd match + a, b = [Factors(i) for i in e.args] + ua, ub = a.normal(b) + gcd = a.gcd(b).as_expr() + if S.NegativeOne in ua.factors: + ua = ua.quo(S.NegativeOne) + n1 = -1 + n2 = 1 + elif S.NegativeOne in ub.factors: + ub = ub.quo(S.NegativeOne) + n1 = 1 + n2 = -1 + else: + n1 = n2 = 1 + a, b = [i.as_expr() for i in (ua, ub)] + if a is S.One: + a, b = b, a + n1, n2 = n2, n1 + if n1 == -1: + gcd = -gcd + n2 = -n2 + + if b is S.One: + return gcd, a, n2 + + +def _osborne(e, d): + """Replace all hyperbolic functions with trig functions using + the Osborne rule. + + Notes + ===== + + ``d`` is a dummy variable to prevent automatic evaluation + of trigonometric/hyperbolic functions. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + + def f(rv): + if not isinstance(rv, HyperbolicFunction): + return rv + a = rv.args[0] + a = a*d if not a.is_Add else Add._from_args([i*d for i in a.args]) + if isinstance(rv, sinh): + return I*sin(a) + elif isinstance(rv, cosh): + return cos(a) + elif isinstance(rv, tanh): + return I*tan(a) + elif isinstance(rv, coth): + return cot(a)/I + elif isinstance(rv, sech): + return sec(a) + elif isinstance(rv, csch): + return csc(a)/I + else: + raise NotImplementedError('unhandled %s' % rv.func) + + return bottom_up(e, f) + + +def _osbornei(e, d): + """Replace all trig functions with hyperbolic functions using + the Osborne rule. + + Notes + ===== + + ``d`` is a dummy variable to prevent automatic evaluation + of trigonometric/hyperbolic functions. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + + def f(rv): + if not isinstance(rv, TrigonometricFunction): + return rv + const, x = rv.args[0].as_independent(d, as_Add=True) + a = x.xreplace({d: S.One}) + const*I + if isinstance(rv, sin): + return sinh(a)/I + elif isinstance(rv, cos): + return cosh(a) + elif isinstance(rv, tan): + return tanh(a)/I + elif isinstance(rv, cot): + return coth(a)*I + elif isinstance(rv, sec): + return sech(a) + elif isinstance(rv, csc): + return csch(a)*I + else: + raise NotImplementedError('unhandled %s' % rv.func) + + return bottom_up(e, f) + + +def hyper_as_trig(rv): + """Return an expression containing hyperbolic functions in terms + of trigonometric functions. Any trigonometric functions initially + present are replaced with Dummy symbols and the function to undo + the masking and the conversion back to hyperbolics is also returned. It + should always be true that:: + + t, f = hyper_as_trig(expr) + expr == f(t) + + Examples + ======== + + >>> from sympy.simplify.fu import hyper_as_trig, fu + >>> from sympy.abc import x + >>> from sympy import cosh, sinh + >>> eq = sinh(x)**2 + cosh(x)**2 + >>> t, f = hyper_as_trig(eq) + >>> f(fu(t)) + cosh(2*x) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + """ + from sympy.simplify.simplify import signsimp + from sympy.simplify.radsimp import collect + + # mask off trig functions + trigs = rv.atoms(TrigonometricFunction) + reps = [(t, Dummy()) for t in trigs] + masked = rv.xreplace(dict(reps)) + + # get inversion substitutions in place + reps = [(v, k) for k, v in reps] + + d = Dummy() + + return _osborne(masked, d), lambda x: collect(signsimp( + _osbornei(x, d).xreplace(dict(reps))), S.ImaginaryUnit) + + +def sincos_to_sum(expr): + """Convert products and powers of sin and cos to sums. + + Explanation + =========== + + Applied power reduction TRpower first, then expands products, and + converts products to sums with TR8. + + Examples + ======== + + >>> from sympy.simplify.fu import sincos_to_sum + >>> from sympy.abc import x + >>> from sympy import cos, sin + >>> sincos_to_sum(16*sin(x)**3*cos(2*x)**2) + 7*sin(x) - 5*sin(3*x) + 3*sin(5*x) - sin(7*x) + """ + + if not expr.has(cos, sin): + return expr + else: + return TR8(expand_mul(TRpower(expr))) diff --git a/phivenv/Lib/site-packages/sympy/simplify/gammasimp.py b/phivenv/Lib/site-packages/sympy/simplify/gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..aec20c56eb60efb8e1aadfb5bff3d1ba1ab51869 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/gammasimp.py @@ -0,0 +1,493 @@ +from sympy.core import Function, S, Mul, Pow, Add +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.function import expand_func +from sympy.core.symbol import Dummy +from sympy.functions import gamma, sqrt, sin +from sympy.polys import factor, cancel +from sympy.utilities.iterables import sift, uniq + + +def gammasimp(expr): + r""" + Simplify expressions with gamma functions. + + Explanation + =========== + + This function takes as input an expression containing gamma + functions or functions that can be rewritten in terms of gamma + functions and tries to minimize the number of those functions and + reduce the size of their arguments. + + The algorithm works by rewriting all gamma functions as expressions + involving rising factorials (Pochhammer symbols) and applies + recurrence relations and other transformations applicable to rising + factorials, to reduce their arguments, possibly letting the resulting + rising factorial to cancel. Rising factorials with the second argument + being an integer are expanded into polynomial forms and finally all + other rising factorial are rewritten in terms of gamma functions. + + Then the following two steps are performed. + + 1. Reduce the number of gammas by applying the reflection theorem + gamma(x)*gamma(1-x) == pi/sin(pi*x). + 2. Reduce the number of gammas by applying the multiplication theorem + gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x). + + It then reduces the number of prefactors by absorbing them into gammas + where possible and expands gammas with rational argument. + + All transformation rules can be found (or were derived from) here: + + .. [1] https://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/ + .. [2] https://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/ + + Examples + ======== + + >>> from sympy.simplify import gammasimp + >>> from sympy import gamma, Symbol + >>> from sympy.abc import x + >>> n = Symbol('n', integer = True) + + >>> gammasimp(gamma(x)/gamma(x - 3)) + (x - 3)*(x - 2)*(x - 1) + >>> gammasimp(gamma(n + 3)) + gamma(n + 3) + + """ + + expr = expr.rewrite(gamma) + + # compute_ST will be looking for Functions and we don't want + # it looking for non-gamma functions: issue 22606 + # so we mask free, non-gamma functions + f = expr.atoms(Function) + # take out gammas + gammas = {i for i in f if isinstance(i, gamma)} + if not gammas: + return expr # avoid side effects like factoring + f -= gammas + # keep only those without bound symbols + f = f & expr.as_dummy().atoms(Function) + if f: + dum, fun, simp = zip(*[ + (Dummy(), fi, fi.func(*[ + _gammasimp(a, as_comb=False) for a in fi.args])) + for fi in ordered(f)]) + d = expr.xreplace(dict(zip(fun, dum))) + return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp))) + + return _gammasimp(expr, as_comb=False) + + +def _gammasimp(expr, as_comb): + """ + Helper function for gammasimp and combsimp. + + Explanation + =========== + + Simplifies expressions written in terms of gamma function. If + as_comb is True, it tries to preserve integer arguments. See + docstring of gammasimp for more information. This was part of + combsimp() in combsimp.py. + """ + expr = expr.replace(gamma, + lambda n: _rf(1, (n - 1).expand())) + + if as_comb: + expr = expr.replace(_rf, + lambda a, b: gamma(b + 1)) + else: + expr = expr.replace(_rf, + lambda a, b: gamma(a + b)/gamma(a)) + + def rule_gamma(expr, level=0): + """ Simplify products of gamma functions further. """ + + if expr.is_Atom: + return expr + + def gamma_rat(x): + # helper to simplify ratios of gammas + was = x.count(gamma) + xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand() + ).replace(_rf, lambda a, b: gamma(a + b)/gamma(a))) + if xx.count(gamma) < was: + x = xx + return x + + def gamma_factor(x): + # return True if there is a gamma factor in shallow args + if isinstance(x, gamma): + return True + if x.is_Add or x.is_Mul: + return any(gamma_factor(xi) for xi in x.args) + if x.is_Pow and (x.exp.is_integer or x.base.is_positive): + return gamma_factor(x.base) + return False + + # recursion step + if level == 0: + expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args]) + level += 1 + + if not expr.is_Mul: + return expr + + # non-commutative step + if level == 1: + args, nc = expr.args_cnc() + if not args: + return expr + if nc: + return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc) + level += 1 + + # pure gamma handling, not factor absorption + if level == 2: + T, F = sift(expr.args, gamma_factor, binary=True) + gamma_ind = Mul(*F) + d = Mul(*T) + + nd, dd = d.as_numer_denom() + for ipass in range(2): + args = list(ordered(Mul.make_args(nd))) + for i, ni in enumerate(args): + if ni.is_Add: + ni, dd = Add(*[ + rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args] + ).as_numer_denom() + args[i] = ni + if not dd.has(gamma): + break + nd = Mul(*args) + if ipass == 0 and not gamma_factor(nd): + break + nd, dd = dd, nd # now process in reversed order + expr = gamma_ind*nd/dd + if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))): + return expr + level += 1 + + # iteration until constant + if level == 3: + while True: + was = expr + expr = rule_gamma(expr, 4) + if expr == was: + return expr + + numer_gammas = [] + denom_gammas = [] + numer_others = [] + denom_others = [] + def explicate(p): + if p is S.One: + return None, [] + b, e = p.as_base_exp() + if e.is_Integer: + if isinstance(b, gamma): + return True, [b.args[0]]*e + else: + return False, [b]*e + else: + return False, [p] + + newargs = list(ordered(expr.args)) + while newargs: + n, d = newargs.pop().as_numer_denom() + isg, l = explicate(n) + if isg: + numer_gammas.extend(l) + elif isg is False: + numer_others.extend(l) + isg, l = explicate(d) + if isg: + denom_gammas.extend(l) + elif isg is False: + denom_others.extend(l) + + # =========== level 2 work: pure gamma manipulation ========= + + if not as_comb: + # Try to reduce the number of gamma factors by applying the + # reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x) + for gammas, numer, denom in [( + numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + new = [] + while gammas: + g1 = gammas.pop() + if g1.is_integer: + new.append(g1) + continue + for i, g2 in enumerate(gammas): + n = g1 + g2 - 1 + if not n.is_Integer: + continue + numer.append(S.Pi) + denom.append(sin(S.Pi*g1)) + gammas.pop(i) + if n > 0: + numer.extend(1 - g1 + k for k in range(n)) + elif n < 0: + denom.extend(-g1 - k for k in range(-n)) + break + else: + new.append(g1) + # /!\ updating IN PLACE + gammas[:] = new + + # Try to reduce the number of gammas by using the duplication + # theorem to cancel an upper and lower: gamma(2*s)/gamma(s) = + # 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could + # be done with higher argument ratios like gamma(3*x)/gamma(x), + # this would not reduce the number of gammas as in this case. + for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others, + denom_others), + (denom_gammas, numer_gammas, denom_others, + numer_others)]: + + while True: + for x in ng: + for y in dg: + n = x - 2*y + if n.is_Integer: + break + else: + continue + break + else: + break + ng.remove(x) + dg.remove(y) + if n > 0: + no.extend(2*y + k for k in range(n)) + elif n < 0: + do.extend(2*y - 1 - k for k in range(-n)) + ng.append(y + S.Half) + no.append(2**(2*y - 1)) + do.append(sqrt(S.Pi)) + + # Try to reduce the number of gamma factors by applying the + # multiplication theorem (used when n gammas with args differing + # by 1/n mod 1 are encountered). + # + # run of 2 with args differing by 1/2 + # + # >>> gammasimp(gamma(x)*gamma(x+S.Half)) + # 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x) + # + # run of 3 args differing by 1/3 (mod 1) + # + # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3)) + # 6*3**(-3*x - 1/2)*pi*gamma(3*x) + # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3)) + # 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x) + # + def _run(coeffs): + # find runs in coeffs such that the difference in terms (mod 1) + # of t1, t2, ..., tn is 1/n + u = list(uniq(coeffs)) + for i in range(len(u)): + dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))]) + for one, j in dj: + if one.p == 1 and one.q != 1: + n = one.q + got = [i] + get = list(range(1, n)) + for d, j in dj: + m = n*d + if m.is_Integer and m in get: + get.remove(m) + got.append(j) + if not get: + break + else: + continue + for i, j in enumerate(got): + c = u[j] + coeffs.remove(c) + got[i] = c + return one.q, got[0], got[1:] + + def _mult_thm(gammas, numer, denom): + # pull off and analyze the leading coefficient from each gamma arg + # looking for runs in those Rationals + + # expr -> coeff + resid -> rats[resid] = coeff + rats = {} + for g in gammas: + c, resid = g.as_coeff_Add() + rats.setdefault(resid, []).append(c) + + # look for runs in Rationals for each resid + keys = sorted(rats, key=default_sort_key) + for resid in keys: + coeffs = sorted(rats[resid]) + new = [] + while True: + run = _run(coeffs) + if run is None: + break + + # process the sequence that was found: + # 1) convert all the gamma functions to have the right + # argument (could be off by an integer) + # 2) append the factors corresponding to the theorem + # 3) append the new gamma function + + n, ui, other = run + + # (1) + for u in other: + con = resid + u - 1 + for k in range(int(u - ui)): + numer.append(con - k) + + con = n*(resid + ui) # for (2) and (3) + + # (2) + numer.append((2*S.Pi)**(S(n - 1)/2)* + n**(S.Half - con)) + # (3) + new.append(con) + + # restore resid to coeffs + rats[resid] = [resid + c for c in coeffs] + new + + # rebuild the gamma arguments + g = [] + for resid in keys: + g += rats[resid] + # /!\ updating IN PLACE + gammas[:] = g + + for l, numer, denom in [(numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + _mult_thm(l, numer, denom) + + # =========== level >= 2 work: factor absorption ========= + + if level >= 2: + # Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1) + # and gamma(x)/(x - 1) -> gamma(x - 1) + # This code (in particular repeated calls to find_fuzzy) can be very + # slow. + def find_fuzzy(l, x): + if not l: + return + S1, T1 = compute_ST(x) + for y in l: + S2, T2 = inv[y] + if T1 != T2 or (not S1.intersection(S2) and + (S1 != set() or S2 != set())): + continue + # XXX we want some simplification (e.g. cancel or + # simplify) but no matter what it's slow. + a = len(cancel(x/y).free_symbols) + b = len(x.free_symbols) + c = len(y.free_symbols) + # TODO is there a better heuristic? + if a == 0 and (b > 0 or c > 0): + return y + + # We thus try to avoid expensive calls by building the following + # "invariants": For every factor or gamma function argument + # - the set of free symbols S + # - the set of functional components T + # We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset + # or S1 == S2 == emptyset) + inv = {} + + def compute_ST(expr): + if expr in inv: + return inv[expr] + return (expr.free_symbols, expr.atoms(Function).union( + {e.exp for e in expr.atoms(Pow)})) + + def update_ST(expr): + inv[expr] = compute_ST(expr) + for expr in numer_gammas + denom_gammas + numer_others + denom_others: + update_ST(expr) + + for gammas, numer, denom in [( + numer_gammas, numer_others, denom_others), + (denom_gammas, denom_others, numer_others)]: + new = [] + while gammas: + g = gammas.pop() + cont = True + while cont: + cont = False + y = find_fuzzy(numer, g) + if y is not None: + numer.remove(y) + if y != g: + numer.append(y/g) + update_ST(y/g) + g += 1 + cont = True + y = find_fuzzy(denom, g - 1) + if y is not None: + denom.remove(y) + if y != g - 1: + numer.append((g - 1)/y) + update_ST((g - 1)/y) + g -= 1 + cont = True + new.append(g) + # /!\ updating IN PLACE + gammas[:] = new + + # =========== rebuild expr ================================== + + return Mul(*[gamma(g) for g in numer_gammas]) \ + / Mul(*[gamma(g) for g in denom_gammas]) \ + * Mul(*numer_others) / Mul(*denom_others) + + was = factor(expr) + # (for some reason we cannot use Basic.replace in this case) + expr = rule_gamma(was) + if expr != was: + expr = factor(expr) + + expr = expr.replace(gamma, + lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n)) + + return expr + + +class _rf(Function): + @classmethod + def eval(cls, a, b): + if b.is_Integer: + if not b: + return S.One + + n = int(b) + + if n > 0: + return Mul(*[a + i for i in range(n)]) + elif n < 0: + return 1/Mul(*[a - i for i in range(1, -n + 1)]) + else: + if b.is_Add: + c, _b = b.as_coeff_Add() + + if c.is_Integer: + if c > 0: + return _rf(a, _b)*_rf(a + _b, c) + elif c < 0: + return _rf(a, _b)/_rf(a + _b + c, -c) + + if a.is_Add: + c, _a = a.as_coeff_Add() + + if c.is_Integer: + if c > 0: + return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c) + elif c < 0: + return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c) diff --git a/phivenv/Lib/site-packages/sympy/simplify/hyperexpand.py b/phivenv/Lib/site-packages/sympy/simplify/hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c070aa2e44b92794107b3e33df897813a54307b9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/hyperexpand.py @@ -0,0 +1,2494 @@ +""" +Expand Hypergeometric (and Meijer G) functions into named +special functions. + +The algorithm for doing this uses a collection of lookup tables of +hypergeometric functions, and various of their properties, to expand +many hypergeometric functions in terms of special functions. + +It is based on the following paper: + Kelly B. Roach. Meijer G Function Representations. + In: Proceedings of the 1997 International Symposium on Symbolic and + Algebraic Computation, pages 205-211, New York, 1997. ACM. + +It is described in great(er) detail in the Sphinx documentation. +""" +# SUMMARY OF EXTENSIONS FOR MEIJER G FUNCTIONS +# +# o z**rho G(ap, bq; z) = G(ap + rho, bq + rho; z) +# +# o denote z*d/dz by D +# +# o It is helpful to keep in mind that ap and bq play essentially symmetric +# roles: G(1/z) has slightly altered parameters, with ap and bq interchanged. +# +# o There are four shift operators: +# A_J = b_J - D, J = 1, ..., n +# B_J = 1 - a_j + D, J = 1, ..., m +# C_J = -b_J + D, J = m+1, ..., q +# D_J = a_J - 1 - D, J = n+1, ..., p +# +# A_J, C_J increment b_J +# B_J, D_J decrement a_J +# +# o The corresponding four inverse-shift operators are defined if there +# is no cancellation. Thus e.g. an index a_J (upper or lower) can be +# incremented if a_J != b_i for i = 1, ..., q. +# +# o Order reduction: if b_j - a_i is a non-negative integer, where +# j <= m and i > n, the corresponding quotient of gamma functions reduces +# to a polynomial. Hence the G function can be expressed using a G-function +# of lower order. +# Similarly if j > m and i <= n. +# +# Secondly, there are paired index theorems [Adamchik, The evaluation of +# integrals of Bessel functions via G-function identities]. Suppose there +# are three parameters a, b, c, where a is an a_i, i <= n, b is a b_j, +# j <= m and c is a denominator parameter (i.e. a_i, i > n or b_j, j > m). +# Suppose further all three differ by integers. +# Then the order can be reduced. +# TODO work this out in detail. +# +# o An index quadruple is called suitable if its order cannot be reduced. +# If there exists a sequence of shift operators transforming one index +# quadruple into another, we say one is reachable from the other. +# +# o Deciding if one index quadruple is reachable from another is tricky. For +# this reason, we use hand-built routines to match and instantiate formulas. +# +from collections import defaultdict +from itertools import product +from functools import reduce +from math import prod + +from sympy import SYMPY_DEBUG +from sympy.core import (S, Dummy, symbols, sympify, Tuple, expand, I, pi, Mul, + EulerGamma, oo, zoo, expand_func, Add, nan, Expr, Rational) +from sympy.core.mod import Mod +from sympy.core.sorting import default_sort_key +from sympy.functions import (exp, sqrt, root, log, lowergamma, cos, + besseli, gamma, uppergamma, expint, erf, sin, besselj, Ei, Ci, Si, Shi, + sinh, cosh, Chi, fresnels, fresnelc, polar_lift, exp_polar, floor, ceiling, + rf, factorial, lerchphi, Piecewise, re, elliptic_k, elliptic_e) +from sympy.functions.elementary.complexes import polarify, unpolarify +from sympy.functions.special.hyper import (hyper, HyperRep_atanh, + HyperRep_power1, HyperRep_power2, HyperRep_log1, HyperRep_asin1, + HyperRep_asin2, HyperRep_sqrts1, HyperRep_sqrts2, HyperRep_log2, + HyperRep_cosasin, HyperRep_sinasin, meijerg) +from sympy.matrices import Matrix, eye, zeros +from sympy.polys import apart, poly, Poly +from sympy.series import residue +from sympy.simplify.powsimp import powdenest +from sympy.utilities.iterables import sift + +# function to define "buckets" +def _mod1(x): + # TODO see if this can work as Mod(x, 1); this will require + # different handling of the "buckets" since these need to + # be sorted and that fails when there is a mixture of + # integers and expressions with parameters. With the current + # Mod behavior, Mod(k, 1) == Mod(1, 1) == 0 if k is an integer. + # Although the sorting can be done with Basic.compare, this may + # still require different handling of the sorted buckets. + if x.is_Number: + return Mod(x, 1) + c, x = x.as_coeff_Add() + return Mod(c, 1) + x + + +# leave add formulae at the top for easy reference +def add_formulae(formulae): + """ Create our knowledge base. """ + a, b, c, z = symbols('a b c, z', cls=Dummy) + + def add(ap, bq, res): + func = Hyper_Function(ap, bq) + formulae.append(Formula(func, z, res, (a, b, c))) + + def addb(ap, bq, B, C, M): + func = Hyper_Function(ap, bq) + formulae.append(Formula(func, z, None, (a, b, c), B, C, M)) + + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + # 0F0 + add((), (), exp(z)) + + # 1F0 + add((a, ), (), HyperRep_power1(-a, z)) + + # 2F1 + addb((a, a - S.Half), (2*a, ), + Matrix([HyperRep_power2(a, z), + HyperRep_power2(a + S.Half, z)/2]), + Matrix([[1, 0]]), + Matrix([[(a - S.Half)*z/(1 - z), (S.Half - a)*z/(1 - z)], + [a/(1 - z), a*(z - 2)/(1 - z)]])) + addb((1, 1), (2, ), + Matrix([HyperRep_log1(z), 1]), Matrix([[-1/z, 0]]), + Matrix([[0, z/(z - 1)], [0, 0]])) + addb((S.Half, 1), (S('3/2'), ), + Matrix([HyperRep_atanh(z), 1]), + Matrix([[1, 0]]), + Matrix([[Rational(-1, 2), 1/(1 - z)/2], [0, 0]])) + addb((S.Half, S.Half), (S('3/2'), ), + Matrix([HyperRep_asin1(z), HyperRep_power1(Rational(-1, 2), z)]), + Matrix([[1, 0]]), + Matrix([[Rational(-1, 2), S.Half], [0, z/(1 - z)/2]])) + addb((a, S.Half + a), (S.Half, ), + Matrix([HyperRep_sqrts1(-a, z), -HyperRep_sqrts2(-a - S.Half, z)]), + Matrix([[1, 0]]), + Matrix([[0, -a], + [z*(-2*a - 1)/2/(1 - z), S.Half - z*(-2*a - 1)/(1 - z)]])) + + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + addb([a, -a], [S.Half], + Matrix([HyperRep_cosasin(a, z), HyperRep_sinasin(a, z)]), + Matrix([[1, 0]]), + Matrix([[0, -a], [a*z/(1 - z), 1/(1 - z)/2]])) + addb([1, 1], [3*S.Half], + Matrix([HyperRep_asin2(z), 1]), Matrix([[1, 0]]), + Matrix([[(z - S.Half)/(1 - z), 1/(1 - z)/2], [0, 0]])) + + # Complete elliptic integrals K(z) and E(z), both a 2F1 function + addb([S.Half, S.Half], [S.One], + Matrix([elliptic_k(z), elliptic_e(z)]), + Matrix([[2/pi, 0]]), + Matrix([[Rational(-1, 2), -1/(2*z-2)], + [Rational(-1, 2), S.Half]])) + addb([Rational(-1, 2), S.Half], [S.One], + Matrix([elliptic_k(z), elliptic_e(z)]), + Matrix([[0, 2/pi]]), + Matrix([[Rational(-1, 2), -1/(2*z-2)], + [Rational(-1, 2), S.Half]])) + + # 3F2 + addb([Rational(-1, 2), 1, 1], [S.Half, 2], + Matrix([z*HyperRep_atanh(z), HyperRep_log1(z), 1]), + Matrix([[Rational(-2, 3), -S.One/(3*z), Rational(2, 3)]]), + Matrix([[S.Half, 0, z/(1 - z)/2], + [0, 0, z/(z - 1)], + [0, 0, 0]])) + # actually the formula for 3/2 is much nicer ... + addb([Rational(-1, 2), 1, 1], [2, 2], + Matrix([HyperRep_power1(S.Half, z), HyperRep_log2(z), 1]), + Matrix([[Rational(4, 9) - 16/(9*z), 4/(3*z), 16/(9*z)]]), + Matrix([[z/2/(z - 1), 0, 0], [1/(2*(z - 1)), 0, S.Half], [0, 0, 0]])) + + # 1F1 + addb([1], [b], Matrix([z**(1 - b) * exp(z) * lowergamma(b - 1, z), 1]), + Matrix([[b - 1, 0]]), Matrix([[1 - b + z, 1], [0, 0]])) + addb([a], [2*a], + Matrix([z**(S.Half - a)*exp(z/2)*besseli(a - S.Half, z/2) + * gamma(a + S.Half)/4**(S.Half - a), + z**(S.Half - a)*exp(z/2)*besseli(a + S.Half, z/2) + * gamma(a + S.Half)/4**(S.Half - a)]), + Matrix([[1, 0]]), + Matrix([[z/2, z/2], [z/2, (z/2 - 2*a)]])) + mz = polar_lift(-1)*z + addb([a], [a + 1], + Matrix([mz**(-a)*a*lowergamma(a, mz), a*exp(z)]), + Matrix([[1, 0]]), + Matrix([[-a, 1], [0, z]])) + # This one is redundant. + add([Rational(-1, 2)], [S.Half], exp(z) - sqrt(pi*z)*(-I)*erf(I*sqrt(z))) + + # Added to get nice results for Laplace transform of Fresnel functions + # https://functions.wolfram.com/07.22.03.6437.01 + # Basic rule + #add([1], [Rational(3, 4), Rational(5, 4)], + # sqrt(pi) * (cos(2*sqrt(polar_lift(-1)*z))*fresnelc(2*root(polar_lift(-1)*z,4)/sqrt(pi)) + + # sin(2*sqrt(polar_lift(-1)*z))*fresnels(2*root(polar_lift(-1)*z,4)/sqrt(pi))) + # / (2*root(polar_lift(-1)*z,4))) + # Manually tuned rule + addb([1], [Rational(3, 4), Rational(5, 4)], + Matrix([ sqrt(pi)*(I*sinh(2*sqrt(z))*fresnels(2*root(z, 4)*exp(I*pi/4)/sqrt(pi)) + + cosh(2*sqrt(z))*fresnelc(2*root(z, 4)*exp(I*pi/4)/sqrt(pi))) + * exp(-I*pi/4)/(2*root(z, 4)), + sqrt(pi)*root(z, 4)*(sinh(2*sqrt(z))*fresnelc(2*root(z, 4)*exp(I*pi/4)/sqrt(pi)) + + I*cosh(2*sqrt(z))*fresnels(2*root(z, 4)*exp(I*pi/4)/sqrt(pi))) + *exp(-I*pi/4)/2, + 1 ]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 4), 1, Rational(1, 4)], + [ z, Rational(1, 4), 0], + [ 0, 0, 0]])) + + # 2F2 + addb([S.Half, a], [Rational(3, 2), a + 1], + Matrix([a/(2*a - 1)*(-I)*sqrt(pi/z)*erf(I*sqrt(z)), + a/(2*a - 1)*(polar_lift(-1)*z)**(-a)* + lowergamma(a, polar_lift(-1)*z), + a/(2*a - 1)*exp(z)]), + Matrix([[1, -1, 0]]), + Matrix([[Rational(-1, 2), 0, 1], [0, -a, 1], [0, 0, z]])) + # We make a "basis" of four functions instead of three, and give EulerGamma + # an extra slot (it could just be a coefficient to 1). The advantage is + # that this way Polys will not see multivariate polynomials (it treats + # EulerGamma as an indeterminate), which is *way* faster. + addb([1, 1], [2, 2], + Matrix([Ei(z) - log(z), exp(z), 1, EulerGamma]), + Matrix([[1/z, 0, 0, -1/z]]), + Matrix([[0, 1, -1, 0], [0, z, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])) + + # 0F1 + add((), (S.Half, ), cosh(2*sqrt(z))) + addb([], [b], + Matrix([gamma(b)*z**((1 - b)/2)*besseli(b - 1, 2*sqrt(z)), + gamma(b)*z**(1 - b/2)*besseli(b, 2*sqrt(z))]), + Matrix([[1, 0]]), Matrix([[0, 1], [z, (1 - b)]])) + + # 0F3 + x = 4*z**Rational(1, 4) + + def fp(a, z): + return besseli(a, x) + besselj(a, x) + + def fm(a, z): + return besseli(a, x) - besselj(a, x) + + # TODO branching + addb([], [S.Half, a, a + S.Half], + Matrix([fp(2*a - 1, z), fm(2*a, z)*z**Rational(1, 4), + fm(2*a - 1, z)*sqrt(z), fp(2*a, z)*z**Rational(3, 4)]) + * 2**(-2*a)*gamma(2*a)*z**((1 - 2*a)/4), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, 1, 0, 0], + [0, S.Half - a, 1, 0], + [0, 0, S.Half, 1], + [z, 0, 0, 1 - a]])) + x = 2*(4*z)**Rational(1, 4)*exp_polar(I*pi/4) + addb([], [a, a + S.Half, 2*a], + (2*sqrt(polar_lift(-1)*z))**(1 - 2*a)*gamma(2*a)**2 * + Matrix([besselj(2*a - 1, x)*besseli(2*a - 1, x), + x*(besseli(2*a, x)*besselj(2*a - 1, x) + - besseli(2*a - 1, x)*besselj(2*a, x)), + x**2*besseli(2*a, x)*besselj(2*a, x), + x**3*(besseli(2*a, x)*besselj(2*a - 1, x) + + besseli(2*a - 1, x)*besselj(2*a, x))]), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, Rational(1, 4), 0, 0], + [0, (1 - 2*a)/2, Rational(-1, 2), 0], + [0, 0, 1 - 2*a, Rational(1, 4)], + [-32*z, 0, 0, 1 - a]])) + + # 1F2 + addb([a], [a - S.Half, 2*a], + Matrix([z**(S.Half - a)*besseli(a - S.Half, sqrt(z))**2, + z**(1 - a)*besseli(a - S.Half, sqrt(z)) + *besseli(a - Rational(3, 2), sqrt(z)), + z**(Rational(3, 2) - a)*besseli(a - Rational(3, 2), sqrt(z))**2]), + Matrix([[-gamma(a + S.Half)**2/4**(S.Half - a), + 2*gamma(a - S.Half)*gamma(a + S.Half)/4**(1 - a), + 0]]), + Matrix([[1 - 2*a, 1, 0], [z/2, S.Half - a, S.Half], [0, z, 0]])) + addb([S.Half], [b, 2 - b], + pi*(1 - b)/sin(pi*b)* + Matrix([besseli(1 - b, sqrt(z))*besseli(b - 1, sqrt(z)), + sqrt(z)*(besseli(-b, sqrt(z))*besseli(b - 1, sqrt(z)) + + besseli(1 - b, sqrt(z))*besseli(b, sqrt(z))), + besseli(-b, sqrt(z))*besseli(b, sqrt(z))]), + Matrix([[1, 0, 0]]), + Matrix([[b - 1, S.Half, 0], + [z, 0, z], + [0, S.Half, -b]])) + addb([S.Half], [Rational(3, 2), Rational(3, 2)], + Matrix([Shi(2*sqrt(z))/2/sqrt(z), sinh(2*sqrt(z))/2/sqrt(z), + cosh(2*sqrt(z))]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 2), S.Half, 0], [0, Rational(-1, 2), S.Half], [0, 2*z, 0]])) + + # FresnelS + # Basic rule + #add([Rational(3, 4)], [Rational(3, 2),Rational(7, 4)], 6*fresnels( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) / ( pi * (exp(pi*I/4)*root(z,4)*2/sqrt(pi))**3 ) ) + # Manually tuned rule + addb([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)], + Matrix( + [ fresnels( + exp( + pi*I/4)*root( + z, 4)*2/sqrt( + pi) ) / ( + pi * (exp(pi*I/4)*root(z, 4)*2/sqrt(pi))**3 ), + sinh(2*sqrt(z))/sqrt(z), + cosh(2*sqrt(z)) ]), + Matrix([[6, 0, 0]]), + Matrix([[Rational(-3, 4), Rational(1, 16), 0], + [ 0, Rational(-1, 2), 1], + [ 0, z, 0]])) + + # FresnelC + # Basic rule + #add([Rational(1, 4)], [S.Half,Rational(5, 4)], fresnelc( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) / ( exp(pi*I/4)*root(z,4)*2/sqrt(pi) ) ) + # Manually tuned rule + addb([Rational(1, 4)], [S.Half, Rational(5, 4)], + Matrix( + [ sqrt( + pi)*exp( + -I*pi/4)*fresnelc( + 2*root(z, 4)*exp(I*pi/4)/sqrt(pi))/(2*root(z, 4)), + cosh(2*sqrt(z)), + sinh(2*sqrt(z))*sqrt(z) ]), + Matrix([[1, 0, 0]]), + Matrix([[Rational(-1, 4), Rational(1, 4), 0 ], + [ 0, 0, 1 ], + [ 0, z, S.Half]])) + + # 2F3 + # XXX with this five-parameter formula is pretty slow with the current + # Formula.find_instantiations (creates 2!*3!*3**(2+3) ~ 3000 + # instantiations ... But it's not too bad. + addb([a, a + S.Half], [2*a, b, 2*a - b + 1], + gamma(b)*gamma(2*a - b + 1) * (sqrt(z)/2)**(1 - 2*a) * + Matrix([besseli(b - 1, sqrt(z))*besseli(2*a - b, sqrt(z)), + sqrt(z)*besseli(b, sqrt(z))*besseli(2*a - b, sqrt(z)), + sqrt(z)*besseli(b - 1, sqrt(z))*besseli(2*a - b + 1, sqrt(z)), + besseli(b, sqrt(z))*besseli(2*a - b + 1, sqrt(z))]), + Matrix([[1, 0, 0, 0]]), + Matrix([[0, S.Half, S.Half, 0], + [z/2, 1 - b, 0, z/2], + [z/2, 0, b - 2*a, z/2], + [0, S.Half, S.Half, -2*a]])) + # (C/f above comment about eulergamma in the basis). + addb([1, 1], [2, 2, Rational(3, 2)], + Matrix([Chi(2*sqrt(z)) - log(2*sqrt(z)), + cosh(2*sqrt(z)), sqrt(z)*sinh(2*sqrt(z)), 1, EulerGamma]), + Matrix([[1/z, 0, 0, 0, -1/z]]), + Matrix([[0, S.Half, 0, Rational(-1, 2), 0], + [0, 0, 1, 0, 0], + [0, z, S.Half, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]])) + + # 3F3 + # This is rule: https://functions.wolfram.com/07.31.03.0134.01 + # Initial reason to add it was a nice solution for + # integrate(erf(a*z)/z**2, z) and same for erfc and erfi. + # Basic rule + # add([1, 1, a], [2, 2, a+1], (a/(z*(a-1)**2)) * + # (1 - (-z)**(1-a) * (gamma(a) - uppergamma(a,-z)) + # - (a-1) * (EulerGamma + uppergamma(0,-z) + log(-z)) + # - exp(z))) + # Manually tuned rule + addb([1, 1, a], [2, 2, a+1], + Matrix([a*(log(-z) + expint(1, -z) + EulerGamma)/(z*(a**2 - 2*a + 1)), + a*(-z)**(-a)*(gamma(a) - uppergamma(a, -z))/(a - 1)**2, + a*exp(z)/(a**2 - 2*a + 1), + a/(z*(a**2 - 2*a + 1))]), + Matrix([[1-a, 1, -1/z, 1]]), + Matrix([[-1,0,-1/z,1], + [0,-a,1,0], + [0,0,z,0], + [0,0,0,-1]])) + + +def add_meijerg_formulae(formulae): + a, b, c, z = list(map(Dummy, 'abcz')) + rho = Dummy('rho') + + def add(an, ap, bm, bq, B, C, M, matcher): + formulae.append(MeijerFormula(an, ap, bm, bq, z, [a, b, c, rho], + B, C, M, matcher)) + + def detect_uppergamma(func): + x = func.an[0] + y, z = func.bm + swapped = False + if not _mod1((x - y).simplify()): + swapped = True + (y, z) = (z, y) + if _mod1((x - z).simplify()) or x - z > 0: + return None + l = [y, x] + if swapped: + l = [x, y] + return {rho: y, a: x - y}, G_Function([x], [], l, []) + + add([a + rho], [], [rho, a + rho], [], + Matrix([gamma(1 - a)*z**rho*exp(z)*uppergamma(a, z), + gamma(1 - a)*z**(a + rho)]), + Matrix([[1, 0]]), + Matrix([[rho + z, -1], [0, a + rho]]), + detect_uppergamma) + + def detect_3113(func): + """https://functions.wolfram.com/07.34.03.0984.01""" + x = func.an[0] + u, v, w = func.bm + if _mod1((u - v).simplify()) == 0: + if _mod1((v - w).simplify()) == 0: + return + sig = (S.Half, S.Half, S.Zero) + x1, x2, y = u, v, w + else: + if _mod1((x - u).simplify()) == 0: + sig = (S.Half, S.Zero, S.Half) + x1, y, x2 = u, v, w + else: + sig = (S.Zero, S.Half, S.Half) + y, x1, x2 = u, v, w + + if (_mod1((x - x1).simplify()) != 0 or + _mod1((x - x2).simplify()) != 0 or + _mod1((x - y).simplify()) != S.Half or + x - x1 > 0 or x - x2 > 0): + return + + return {a: x}, G_Function([x], [], [x - S.Half + t for t in sig], []) + + s = sin(2*sqrt(z)) + c_ = cos(2*sqrt(z)) + S_ = Si(2*sqrt(z)) - pi/2 + C = Ci(2*sqrt(z)) + add([a], [], [a, a, a - S.Half], [], + Matrix([sqrt(pi)*z**(a - S.Half)*(c_*S_ - s*C), + sqrt(pi)*z**a*(s*S_ + c_*C), + sqrt(pi)*z**a]), + Matrix([[-2, 0, 0]]), + Matrix([[a - S.Half, -1, 0], [z, a, S.Half], [0, 0, a]]), + detect_3113) + + +def make_simp(z): + """ Create a function that simplifies rational functions in ``z``. """ + + def simp(expr): + """ Efficiently simplify the rational function ``expr``. """ + numer, denom = expr.as_numer_denom() + numer = numer.expand() + # denom = denom.expand() # is this needed? + c, numer, denom = poly(numer, z).cancel(poly(denom, z)) + return c * numer.as_expr() / denom.as_expr() + + return simp + + +def debug(*args): + if SYMPY_DEBUG: + for a in args: + print(a, end="") + print() + + +class Hyper_Function(Expr): + """ A generalized hypergeometric function. """ + + def __new__(cls, ap, bq): + obj = super().__new__(cls) + obj.ap = Tuple(*list(map(expand, ap))) + obj.bq = Tuple(*list(map(expand, bq))) + return obj + + @property + def args(self): + return (self.ap, self.bq) + + @property + def sizes(self): + return (len(self.ap), len(self.bq)) + + @property + def gamma(self): + """ + Number of upper parameters that are negative integers + + This is a transformation invariant. + """ + return sum(bool(x.is_integer and x.is_negative) for x in self.ap) + + def _hashable_content(self): + return super()._hashable_content() + (self.ap, + self.bq) + + def __call__(self, arg): + return hyper(self.ap, self.bq, arg) + + def build_invariants(self): + """ + Compute the invariant vector. + + Explanation + =========== + + The invariant vector is: + (gamma, ((s1, n1), ..., (sk, nk)), ((t1, m1), ..., (tr, mr))) + where gamma is the number of integer a < 0, + s1 < ... < sk + nl is the number of parameters a_i congruent to sl mod 1 + t1 < ... < tr + ml is the number of parameters b_i congruent to tl mod 1 + + If the index pair contains parameters, then this is not truly an + invariant, since the parameters cannot be sorted uniquely mod1. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import Hyper_Function + >>> from sympy import S + >>> ap = (S.Half, S.One/3, S(-1)/2, -2) + >>> bq = (1, 2) + + Here gamma = 1, + k = 3, s1 = 0, s2 = 1/3, s3 = 1/2 + n1 = 1, n2 = 1, n2 = 2 + r = 1, t1 = 0 + m1 = 2: + + >>> Hyper_Function(ap, bq).build_invariants() + (1, ((0, 1), (1/3, 1), (1/2, 2)), ((0, 2),)) + """ + abuckets, bbuckets = sift(self.ap, _mod1), sift(self.bq, _mod1) + + def tr(bucket): + bucket = list(bucket.items()) + if not any(isinstance(x[0], Mod) for x in bucket): + bucket.sort(key=lambda x: default_sort_key(x[0])) + bucket = tuple([(mod, len(values)) for mod, values in bucket if + values]) + return bucket + + return (self.gamma, tr(abuckets), tr(bbuckets)) + + def difficulty(self, func): + """ Estimate how many steps it takes to reach ``func`` from self. + Return -1 if impossible. """ + if self.gamma != func.gamma: + return -1 + oabuckets, obbuckets, abuckets, bbuckets = [sift(params, _mod1) for + params in (self.ap, self.bq, func.ap, func.bq)] + + diff = 0 + for bucket, obucket in [(abuckets, oabuckets), (bbuckets, obbuckets)]: + for mod in set(list(bucket.keys()) + list(obucket.keys())): + if (mod not in bucket) or (mod not in obucket) \ + or len(bucket[mod]) != len(obucket[mod]): + return -1 + l1 = list(bucket[mod]) + l2 = list(obucket[mod]) + l1.sort() + l2.sort() + for i, j in zip(l1, l2): + diff += abs(i - j) + + return diff + + def _is_suitable_origin(self): + """ + Decide if ``self`` is a suitable origin. + + Explanation + =========== + + A function is a suitable origin iff: + * none of the ai equals bj + n, with n a non-negative integer + * none of the ai is zero + * none of the bj is a non-positive integer + + Note that this gives meaningful results only when none of the indices + are symbolic. + + """ + for a in self.ap: + for b in self.bq: + if (a - b).is_integer and (a - b).is_negative is False: + return False + for a in self.ap: + if a == 0: + return False + for b in self.bq: + if b.is_integer and b.is_nonpositive: + return False + return True + + +class G_Function(Expr): + """ A Meijer G-function. """ + + def __new__(cls, an, ap, bm, bq): + obj = super().__new__(cls) + obj.an = Tuple(*list(map(expand, an))) + obj.ap = Tuple(*list(map(expand, ap))) + obj.bm = Tuple(*list(map(expand, bm))) + obj.bq = Tuple(*list(map(expand, bq))) + return obj + + @property + def args(self): + return (self.an, self.ap, self.bm, self.bq) + + def _hashable_content(self): + return super()._hashable_content() + self.args + + def __call__(self, z): + return meijerg(self.an, self.ap, self.bm, self.bq, z) + + def compute_buckets(self): + """ + Compute buckets for the fours sets of parameters. + + Explanation + =========== + + We guarantee that any two equal Mod objects returned are actually the + same, and that the buckets are sorted by real part (an and bq + descendending, bm and ap ascending). + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import G_Function + >>> from sympy.abc import y + >>> from sympy import S + + >>> a, b = [1, 3, 2, S(3)/2], [1 + y, y, 2, y + 3] + >>> G_Function(a, b, [2], [y]).compute_buckets() + ({0: [3, 2, 1], 1/2: [3/2]}, + {0: [2], y: [y, y + 1, y + 3]}, {0: [2]}, {y: [y]}) + + """ + dicts = pan, pap, pbm, pbq = [defaultdict(list) for i in range(4)] + for dic, lis in zip(dicts, (self.an, self.ap, self.bm, self.bq)): + for x in lis: + dic[_mod1(x)].append(x) + + for dic, flip in zip(dicts, (True, False, False, True)): + for m, items in dic.items(): + x0 = items[0] + items.sort(key=lambda x: x - x0, reverse=flip) + dic[m] = items + + return tuple([dict(w) for w in dicts]) + + @property + def signature(self): + return (len(self.an), len(self.ap), len(self.bm), len(self.bq)) + + +# Dummy variable. +_x = Dummy('x') + +class Formula: + """ + This class represents hypergeometric formulae. + + Explanation + =========== + + Its data members are: + - z, the argument + - closed_form, the closed form expression + - symbols, the free symbols (parameters) in the formula + - func, the function + - B, C, M (see _compute_basis) + + Examples + ======== + + >>> from sympy.abc import a, b, z + >>> from sympy.simplify.hyperexpand import Formula, Hyper_Function + >>> func = Hyper_Function((a/2, a/3 + b, (1+a)/2), (a, b, (a+b)/7)) + >>> f = Formula(func, z, None, [a, b]) + + """ + + def _compute_basis(self, closed_form): + """ + Compute a set of functions B=(f1, ..., fn), a nxn matrix M + and a 1xn matrix C such that: + closed_form = C B + z d/dz B = M B. + """ + afactors = [_x + a for a in self.func.ap] + bfactors = [_x + b - 1 for b in self.func.bq] + expr = _x*Mul(*bfactors) - self.z*Mul(*afactors) + poly = Poly(expr, _x) + + n = poly.degree() - 1 + b = [closed_form] + for _ in range(n): + b.append(self.z*b[-1].diff(self.z)) + + self.B = Matrix(b) + self.C = Matrix([[1] + [0]*n]) + + m = eye(n) + m = m.col_insert(0, zeros(n, 1)) + l = poly.all_coeffs()[1:] + l.reverse() + self.M = m.row_insert(n, -Matrix([l])/poly.all_coeffs()[0]) + + def __init__(self, func, z, res, symbols, B=None, C=None, M=None): + z = sympify(z) + res = sympify(res) + symbols = [x for x in sympify(symbols) if func.has(x)] + + self.z = z + self.symbols = symbols + self.B = B + self.C = C + self.M = M + self.func = func + + # TODO with symbolic parameters, it could be advantageous + # (for prettier answers) to compute a basis only *after* + # instantiation + if res is not None: + self._compute_basis(res) + + @property + def closed_form(self): + return reduce(lambda s,m: s+m[0]*m[1], zip(self.C, self.B), S.Zero) + + def find_instantiations(self, func): + """ + Find substitutions of the free symbols that match ``func``. + + Return the substitution dictionaries as a list. Note that the returned + instantiations need not actually match, or be valid! + + """ + from sympy.solvers import solve + ap = func.ap + bq = func.bq + if len(ap) != len(self.func.ap) or len(bq) != len(self.func.bq): + raise TypeError('Cannot instantiate other number of parameters') + symbol_values = [] + for a in self.symbols: + if a in self.func.ap.args: + symbol_values.append(ap) + elif a in self.func.bq.args: + symbol_values.append(bq) + else: + raise ValueError("At least one of the parameters of the " + "formula must be equal to %s" % (a,)) + base_repl = [dict(list(zip(self.symbols, values))) + for values in product(*symbol_values)] + abuckets, bbuckets = [sift(params, _mod1) for params in [ap, bq]] + a_inv, b_inv = [{a: len(vals) for a, vals in bucket.items()} + for bucket in [abuckets, bbuckets]] + critical_values = [[0] for _ in self.symbols] + result = [] + _n = Dummy() + for repl in base_repl: + symb_a, symb_b = [sift(params, lambda x: _mod1(x.xreplace(repl))) + for params in [self.func.ap, self.func.bq]] + for bucket, obucket in [(abuckets, symb_a), (bbuckets, symb_b)]: + for mod in set(list(bucket.keys()) + list(obucket.keys())): + if (mod not in bucket) or (mod not in obucket) \ + or len(bucket[mod]) != len(obucket[mod]): + break + for a, vals in zip(self.symbols, critical_values): + if repl[a].free_symbols: + continue + exprs = [expr for expr in obucket[mod] if expr.has(a)] + repl0 = repl.copy() + repl0[a] += _n + for expr in exprs: + for target in bucket[mod]: + n0, = solve(expr.xreplace(repl0) - target, _n) + if n0.free_symbols: + raise ValueError("Value should not be true") + vals.append(n0) + else: + values = [] + for a, vals in zip(self.symbols, critical_values): + a0 = repl[a] + min_ = floor(min(vals)) + max_ = ceiling(max(vals)) + values.append([a0 + n for n in range(min_, max_ + 1)]) + result.extend(dict(list(zip(self.symbols, l))) for l in product(*values)) + return result + + + + +class FormulaCollection: + """ A collection of formulae to use as origins. """ + + def __init__(self): + """ Doing this globally at module init time is a pain ... """ + self.symbolic_formulae = {} + self.concrete_formulae = {} + self.formulae = [] + + add_formulae(self.formulae) + + # Now process the formulae into a helpful form. + # These dicts are indexed by (p, q). + + for f in self.formulae: + sizes = f.func.sizes + if len(f.symbols) > 0: + self.symbolic_formulae.setdefault(sizes, []).append(f) + else: + inv = f.func.build_invariants() + self.concrete_formulae.setdefault(sizes, {})[inv] = f + + def lookup_origin(self, func): + """ + Given the suitable target ``func``, try to find an origin in our + knowledge base. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (FormulaCollection, + ... Hyper_Function) + >>> f = FormulaCollection() + >>> f.lookup_origin(Hyper_Function((), ())).closed_form + exp(_z) + >>> f.lookup_origin(Hyper_Function([1], ())).closed_form + HyperRep_power1(-1, _z) + + >>> from sympy import S + >>> i = Hyper_Function([S('1/4'), S('3/4 + 4')], [S.Half]) + >>> f.lookup_origin(i).closed_form + HyperRep_sqrts1(-1/4, _z) + """ + inv = func.build_invariants() + sizes = func.sizes + if sizes in self.concrete_formulae and \ + inv in self.concrete_formulae[sizes]: + return self.concrete_formulae[sizes][inv] + + # We don't have a concrete formula. Try to instantiate. + if sizes not in self.symbolic_formulae: + return None # Too bad... + + possible = [] + for f in self.symbolic_formulae[sizes]: + repls = f.find_instantiations(func) + for repl in repls: + func2 = f.func.xreplace(repl) + if not func2._is_suitable_origin(): + continue + diff = func2.difficulty(func) + if diff == -1: + continue + possible.append((diff, repl, f, func2)) + + # find the nearest origin + possible.sort(key=lambda x: x[0]) + for _, repl, f, func2 in possible: + f2 = Formula(func2, f.z, None, [], f.B.subs(repl), + f.C.subs(repl), f.M.subs(repl)) + if not any(e.has(S.NaN, oo, -oo, zoo) for e in [f2.B, f2.M, f2.C]): + return f2 + + return None + + +class MeijerFormula: + """ + This class represents a Meijer G-function formula. + + Its data members are: + - z, the argument + - symbols, the free symbols (parameters) in the formula + - func, the function + - B, C, M (c/f ordinary Formula) + """ + + def __init__(self, an, ap, bm, bq, z, symbols, B, C, M, matcher): + an, ap, bm, bq = [Tuple(*list(map(expand, w))) for w in [an, ap, bm, bq]] + self.func = G_Function(an, ap, bm, bq) + self.z = z + self.symbols = symbols + self._matcher = matcher + self.B = B + self.C = C + self.M = M + + @property + def closed_form(self): + return reduce(lambda s,m: s+m[0]*m[1], zip(self.C, self.B), S.Zero) + + def try_instantiate(self, func): + """ + Try to instantiate the current formula to (almost) match func. + This uses the _matcher passed on init. + """ + if func.signature != self.func.signature: + return None + res = self._matcher(func) + if res is not None: + subs, newfunc = res + return MeijerFormula(newfunc.an, newfunc.ap, newfunc.bm, newfunc.bq, + self.z, [], + self.B.subs(subs), self.C.subs(subs), + self.M.subs(subs), None) + + +class MeijerFormulaCollection: + """ + This class holds a collection of meijer g formulae. + """ + + def __init__(self): + formulae = [] + add_meijerg_formulae(formulae) + self.formulae = defaultdict(list) + for formula in formulae: + self.formulae[formula.func.signature].append(formula) + self.formulae = dict(self.formulae) + + def lookup_origin(self, func): + """ Try to find a formula that matches func. """ + if func.signature not in self.formulae: + return None + for formula in self.formulae[func.signature]: + res = formula.try_instantiate(func) + if res is not None: + return res + + +class Operator: + """ + Base class for operators to be applied to our functions. + + Explanation + =========== + + These operators are differential operators. They are by convention + expressed in the variable D = z*d/dz (although this base class does + not actually care). + Note that when the operator is applied to an object, we typically do + *not* blindly differentiate but instead use a different representation + of the z*d/dz operator (see make_derivative_operator). + + To subclass from this, define a __init__ method that initializes a + self._poly variable. This variable stores a polynomial. By convention + the generator is z*d/dz, and acts to the right of all coefficients. + + Thus this poly + x**2 + 2*z*x + 1 + represents the differential operator + (z*d/dz)**2 + 2*z**2*d/dz. + + This class is used only in the implementation of the hypergeometric + function expansion algorithm. + """ + + def apply(self, obj, op): + """ + Apply ``self`` to the object ``obj``, where the generator is ``op``. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import Operator + >>> from sympy.polys.polytools import Poly + >>> from sympy.abc import x, y, z + >>> op = Operator() + >>> op._poly = Poly(x**2 + z*x + y, x) + >>> op.apply(z**7, lambda f: f.diff(z)) + y*z**7 + 7*z**7 + 42*z**5 + """ + coeffs = self._poly.all_coeffs() + coeffs.reverse() + diffs = [obj] + for c in coeffs[1:]: + diffs.append(op(diffs[-1])) + r = coeffs[0]*diffs[0] + for c, d in zip(coeffs[1:], diffs[1:]): + r += c*d + return r + + +class MultOperator(Operator): + """ Simply multiply by a "constant" """ + + def __init__(self, p): + self._poly = Poly(p, _x) + + +class ShiftA(Operator): + """ Increment an upper index. """ + + def __init__(self, ai): + ai = sympify(ai) + if ai == 0: + raise ValueError('Cannot increment zero upper index.') + self._poly = Poly(_x/ai + 1, _x) + + def __str__(self): + return '' % (1/self._poly.all_coeffs()[0]) + + +class ShiftB(Operator): + """ Decrement a lower index. """ + + def __init__(self, bi): + bi = sympify(bi) + if bi == 1: + raise ValueError('Cannot decrement unit lower index.') + self._poly = Poly(_x/(bi - 1) + 1, _x) + + def __str__(self): + return '' % (1/self._poly.all_coeffs()[0] + 1) + + +class UnShiftA(Operator): + """ Decrement an upper index. """ + + def __init__(self, ap, bq, i, z): + """ Note: i counts from zero! """ + ap, bq, i = list(map(sympify, [ap, bq, i])) + + self._ap = ap + self._bq = bq + self._i = i + + ap = list(ap) + bq = list(bq) + ai = ap.pop(i) - 1 + + if ai == 0: + raise ValueError('Cannot decrement unit upper index.') + + m = Poly(z*ai, _x) + for a in ap: + m *= Poly(_x + a, _x) + + A = Dummy('A') + n = D = Poly(ai*A - ai, A) + for b in bq: + n *= D + (b - 1).as_poly(A) + + b0 = -n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement upper index: ' + 'cancels with lower') + + n = Poly(Poly(n.all_coeffs()[:-1], A).as_expr().subs(A, _x/ai + 1), _x) + + self._poly = Poly((n - m)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._ap, self._bq) + + +class UnShiftB(Operator): + """ Increment a lower index. """ + + def __init__(self, ap, bq, i, z): + """ Note: i counts from zero! """ + ap, bq, i = list(map(sympify, [ap, bq, i])) + + self._ap = ap + self._bq = bq + self._i = i + + ap = list(ap) + bq = list(bq) + bi = bq.pop(i) + 1 + + if bi == 0: + raise ValueError('Cannot increment -1 lower index.') + + m = Poly(_x*(bi - 1), _x) + for b in bq: + m *= Poly(_x + b - 1, _x) + + B = Dummy('B') + D = Poly((bi - 1)*B - bi + 1, B) + n = Poly(z, B) + for a in ap: + n *= (D + a.as_poly(B)) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment index: cancels with upper') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, _x/(bi - 1) + 1), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._ap, self._bq) + + +class MeijerShiftA(Operator): + """ Increment an upper b index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(bi - _x, _x) + + def __str__(self): + return '' % (self._poly.all_coeffs()[1]) + + +class MeijerShiftB(Operator): + """ Decrement an upper a index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(1 - bi + _x, _x) + + def __str__(self): + return '' % (1 - self._poly.all_coeffs()[1]) + + +class MeijerShiftC(Operator): + """ Increment a lower b index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(-bi + _x, _x) + + def __str__(self): + return '' % (-self._poly.all_coeffs()[1]) + + +class MeijerShiftD(Operator): + """ Decrement a lower a index. """ + + def __init__(self, bi): + bi = sympify(bi) + self._poly = Poly(bi - 1 - _x, _x) + + def __str__(self): + return '' % (self._poly.all_coeffs()[1] + 1) + + +class MeijerUnShiftA(Operator): + """ Decrement an upper b index. """ + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + bi = bm.pop(i) - 1 + + m = Poly(1, _x) * prod(Poly(b - _x, _x) for b in bm) * prod(Poly(_x - b, _x) for b in bq) + + A = Dummy('A') + D = Poly(bi - A, A) + n = Poly(z, A) * prod((D + 1 - a) for a in an) * prod((-D + a - 1) for a in ap) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement upper b index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], A).as_expr().subs(A, bi - _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftB(Operator): + """ Increment an upper a index. """ + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + ai = an.pop(i) + 1 + + m = Poly(z, _x) + for a in an: + m *= Poly(1 - a + _x, _x) + for a in ap: + m *= Poly(a - 1 - _x, _x) + + B = Dummy('B') + D = Poly(B + ai - 1, B) + n = Poly(1, B) + for b in bm: + n *= (-D + b) + for b in bq: + n *= (D - b) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment upper a index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, 1 - ai + _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftC(Operator): + """ Decrement a lower b index. """ + # XXX this is "essentially" the same as MeijerUnShiftA. This "essentially" + # can be made rigorous using the functional equation G(1/z) = G'(z), + # where G' denotes a G function of slightly altered parameters. + # However, sorting out the details seems harder than just coding it + # again. + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + bi = bq.pop(i) - 1 + + m = Poly(1, _x) + for b in bm: + m *= Poly(b - _x, _x) + for b in bq: + m *= Poly(_x - b, _x) + + C = Dummy('C') + D = Poly(bi + C, C) + n = Poly(z, C) + for a in an: + n *= (D + 1 - a) + for a in ap: + n *= (-D + a - 1) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot decrement lower b index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], C).as_expr().subs(C, _x - bi), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class MeijerUnShiftD(Operator): + """ Increment a lower a index. """ + # XXX This is essentially the same as MeijerUnShiftA. + # See comment at MeijerUnShiftC. + + def __init__(self, an, ap, bm, bq, i, z): + """ Note: i counts from zero! """ + an, ap, bm, bq, i = list(map(sympify, [an, ap, bm, bq, i])) + + self._an = an + self._ap = ap + self._bm = bm + self._bq = bq + self._i = i + + an = list(an) + ap = list(ap) + bm = list(bm) + bq = list(bq) + ai = ap.pop(i) + 1 + + m = Poly(z, _x) + for a in an: + m *= Poly(1 - a + _x, _x) + for a in ap: + m *= Poly(a - 1 - _x, _x) + + B = Dummy('B') # - this is the shift operator `D_I` + D = Poly(ai - 1 - B, B) + n = Poly(1, B) + for b in bm: + n *= (-D + b) + for b in bq: + n *= (D - b) + + b0 = n.nth(0) + if b0 == 0: + raise ValueError('Cannot increment lower a index (cancels)') + + n = Poly(Poly(n.all_coeffs()[:-1], B).as_expr().subs( + B, ai - 1 - _x), _x) + + self._poly = Poly((m - n)/b0, _x) + + def __str__(self): + return '' % (self._i, + self._an, self._ap, self._bm, self._bq) + + +class ReduceOrder(Operator): + """ Reduce Order by cancelling an upper and a lower index. """ + + def __new__(cls, ai, bj): + """ For convenience if reduction is not possible, return None. """ + ai = sympify(ai) + bj = sympify(bj) + n = ai - bj + if not n.is_Integer or n < 0: + return None + if bj.is_integer and bj.is_nonpositive: + return None + + expr = Operator.__new__(cls) + + p = S.One + for k in range(n): + p *= (_x + bj + k)/(bj + k) + + expr._poly = Poly(p, _x) + expr._a = ai + expr._b = bj + + return expr + + @classmethod + def _meijer(cls, b, a, sign): + """ Cancel b + sign*s and a + sign*s + This is for meijer G functions. """ + b = sympify(b) + a = sympify(a) + n = b - a + if n.is_negative or not n.is_Integer: + return None + + expr = Operator.__new__(cls) + + p = S.One + for k in range(n): + p *= (sign*_x + a + k) + + expr._poly = Poly(p, _x) + if sign == -1: + expr._a = b + expr._b = a + else: + expr._b = Add(1, a - 1, evaluate=False) + expr._a = Add(1, b - 1, evaluate=False) + + return expr + + @classmethod + def meijer_minus(cls, b, a): + return cls._meijer(b, a, -1) + + @classmethod + def meijer_plus(cls, a, b): + return cls._meijer(1 - a, 1 - b, 1) + + def __str__(self): + return '' % \ + (self._a, self._b) + + +def _reduce_order(ap, bq, gen, key): + """ Order reduction algorithm used in Hypergeometric and Meijer G """ + ap = list(ap) + bq = list(bq) + + ap.sort(key=key) + bq.sort(key=key) + + nap = [] + # we will edit bq in place + operators = [] + for a in ap: + op = None + for i in range(len(bq)): + op = gen(a, bq[i]) + if op is not None: + bq.pop(i) + break + if op is None: + nap.append(a) + else: + operators.append(op) + + return nap, bq, operators + + +def reduce_order(func): + """ + Given the hypergeometric function ``func``, find a sequence of operators to + reduces order as much as possible. + + Explanation + =========== + + Return (newfunc, [operators]), where applying the operators to the + hypergeometric function newfunc yields func. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import reduce_order, Hyper_Function + >>> reduce_order(Hyper_Function((1, 2), (3, 4))) + (Hyper_Function((1, 2), (3, 4)), []) + >>> reduce_order(Hyper_Function((1,), (1,))) + (Hyper_Function((), ()), []) + >>> reduce_order(Hyper_Function((2, 4), (3, 3))) + (Hyper_Function((2,), (3,)), []) + """ + nap, nbq, operators = _reduce_order(func.ap, func.bq, ReduceOrder, default_sort_key) + + return Hyper_Function(Tuple(*nap), Tuple(*nbq)), operators + + +def reduce_order_meijer(func): + """ + Given the Meijer G function parameters, ``func``, find a sequence of + operators that reduces order as much as possible. + + Return newfunc, [operators]. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (reduce_order_meijer, + ... G_Function) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [3, 4], [1, 2]))[0] + G_Function((4, 3), (5, 6), (3, 4), (2, 1)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [3, 4], [1, 8]))[0] + G_Function((3,), (5, 6), (3, 4), (1,)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [7, 5], [1, 5]))[0] + G_Function((3,), (), (), (1,)) + >>> reduce_order_meijer(G_Function([3, 4], [5, 6], [7, 5], [5, 3]))[0] + G_Function((), (), (), ()) + """ + + nan, nbq, ops1 = _reduce_order(func.an, func.bq, ReduceOrder.meijer_plus, + lambda x: default_sort_key(-x)) + nbm, nap, ops2 = _reduce_order(func.bm, func.ap, ReduceOrder.meijer_minus, + default_sort_key) + + return G_Function(nan, nap, nbm, nbq), ops1 + ops2 + + +def make_derivative_operator(M, z): + """ Create a derivative operator, to be passed to Operator.apply. """ + def doit(C): + r = z*C.diff(z) + C*M + r = r.applyfunc(make_simp(z)) + return r + return doit + + +def apply_operators(obj, ops, op): + """ + Apply the list of operators ``ops`` to object ``obj``, substituting + ``op`` for the generator. + """ + res = obj + for o in reversed(ops): + res = o.apply(res, op) + return res + + +def devise_plan(target, origin, z): + """ + Devise a plan (consisting of shift and un-shift operators) to be applied + to the hypergeometric function ``target`` to yield ``origin``. + Returns a list of operators. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import devise_plan, Hyper_Function + >>> from sympy.abc import z + + Nothing to do: + + >>> devise_plan(Hyper_Function((1, 2), ()), Hyper_Function((1, 2), ()), z) + [] + >>> devise_plan(Hyper_Function((), (1, 2)), Hyper_Function((), (1, 2)), z) + [] + + Very simple plans: + + >>> devise_plan(Hyper_Function((2,), ()), Hyper_Function((1,), ()), z) + [] + >>> devise_plan(Hyper_Function((), (2,)), Hyper_Function((), (1,)), z) + [] + + Several buckets: + + >>> from sympy import S + >>> devise_plan(Hyper_Function((1, S.Half), ()), + ... Hyper_Function((2, S('3/2')), ()), z) #doctest: +NORMALIZE_WHITESPACE + [, + ] + + A slightly more complicated plan: + + >>> devise_plan(Hyper_Function((1, 3), ()), Hyper_Function((2, 2), ()), z) + [, ] + + Another more complicated plan: (note that the ap have to be shifted first!) + + >>> devise_plan(Hyper_Function((1, -1), (2,)), Hyper_Function((3, -2), (4,)), z) + [, , + , + , ] + """ + abuckets, bbuckets, nabuckets, nbbuckets = [sift(params, _mod1) for + params in (target.ap, target.bq, origin.ap, origin.bq)] + + if len(list(abuckets.keys())) != len(list(nabuckets.keys())) or \ + len(list(bbuckets.keys())) != len(list(nbbuckets.keys())): + raise ValueError('%s not reachable from %s' % (target, origin)) + + ops = [] + + def do_shifts(fro, to, inc, dec): + ops = [] + for i in range(len(fro)): + if to[i] - fro[i] > 0: + sh = inc + ch = 1 + else: + sh = dec + ch = -1 + + while to[i] != fro[i]: + ops += [sh(fro, i)] + fro[i] += ch + + return ops + + def do_shifts_a(nal, nbk, al, aother, bother): + """ Shift us from (nal, nbk) to (al, nbk). """ + return do_shifts(nal, al, lambda p, i: ShiftA(p[i]), + lambda p, i: UnShiftA(p + aother, nbk + bother, i, z)) + + def do_shifts_b(nal, nbk, bk, aother, bother): + """ Shift us from (nal, nbk) to (nal, bk). """ + return do_shifts(nbk, bk, + lambda p, i: UnShiftB(nal + aother, p + bother, i, z), + lambda p, i: ShiftB(p[i])) + + for r in sorted(list(abuckets.keys()) + list(bbuckets.keys()), key=default_sort_key): + al = () + nal = () + bk = () + nbk = () + if r in abuckets: + al = abuckets[r] + nal = nabuckets[r] + if r in bbuckets: + bk = bbuckets[r] + nbk = nbbuckets[r] + if len(al) != len(nal) or len(bk) != len(nbk): + raise ValueError('%s not reachable from %s' % (target, origin)) + + al, nal, bk, nbk = [sorted(w, key=default_sort_key) + for w in [al, nal, bk, nbk]] + + def others(dic, key): + l = [] + for k in dic: + if k != key: + l.extend(dic[k]) + return l + aother = others(nabuckets, r) + bother = others(nbbuckets, r) + + if len(al) == 0: + # there can be no complications, just shift the bs as we please + ops += do_shifts_b([], nbk, bk, aother, bother) + elif len(bk) == 0: + # there can be no complications, just shift the as as we please + ops += do_shifts_a(nal, [], al, aother, bother) + else: + namax = nal[-1] + amax = al[-1] + + if nbk[0] - namax <= 0 or bk[0] - amax <= 0: + raise ValueError('Non-suitable parameters.') + + if namax - amax > 0: + # we are going to shift down - first do the as, then the bs + ops += do_shifts_a(nal, nbk, al, aother, bother) + ops += do_shifts_b(al, nbk, bk, aother, bother) + else: + # we are going to shift up - first do the bs, then the as + ops += do_shifts_b(nal, nbk, bk, aother, bother) + ops += do_shifts_a(nal, bk, al, aother, bother) + + nabuckets[r] = al + nbbuckets[r] = bk + + ops.reverse() + return ops + + +def try_shifted_sum(func, z): + """ Try to recognise a hypergeometric sum that starts from k > 0. """ + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + if len(abuckets[S.Zero]) != 1: + return None + r = abuckets[S.Zero][0] + if r <= 0: + return None + if S.Zero not in bbuckets: + return None + l = list(bbuckets[S.Zero]) + l.sort() + k = l[0] + if k <= 0: + return None + + nap = list(func.ap) + nap.remove(r) + nbq = list(func.bq) + nbq.remove(k) + k -= 1 + nap = [x - k for x in nap] + nbq = [x - k for x in nbq] + + ops = [] + for n in range(r - 1): + ops.append(ShiftA(n + 1)) + ops.reverse() + + fac = factorial(k)/z**k + fac *= Mul(*[rf(b, k) for b in nbq]) + fac /= Mul(*[rf(a, k) for a in nap]) + + ops += [MultOperator(fac)] + + p = 0 + for n in range(k): + m = z**n/factorial(n) + m *= Mul(*[rf(a, n) for a in nap]) + m /= Mul(*[rf(b, n) for b in nbq]) + p += m + + return Hyper_Function(nap, nbq), ops, -p + + +def try_polynomial(func, z): + """ Recognise polynomial cases. Returns None if not such a case. + Requires order to be fully reduced. """ + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + a0 = abuckets[S.Zero] + b0 = bbuckets[S.Zero] + a0.sort() + b0.sort() + al0 = [x for x in a0 if x <= 0] + bl0 = [x for x in b0 if x <= 0] + + if bl0 and all(a < bl0[-1] for a in al0): + return oo + if not al0: + return None + + a = al0[-1] + fac = 1 + res = S.One + for n in Tuple(*list(range(-a))): + fac *= z + fac /= n + 1 + fac *= Mul(*[a + n for a in func.ap]) + fac /= Mul(*[b + n for b in func.bq]) + res += fac + return res + + +def try_lerchphi(func): + """ + Try to find an expression for Hyper_Function ``func`` in terms of Lerch + Transcendents. + + Return None if no such expression can be found. + """ + # This is actually quite simple, and is described in Roach's paper, + # section 18. + # We don't need to implement the reduction to polylog here, this + # is handled by expand_func. + + # First we need to figure out if the summation coefficient is a rational + # function of the summation index, and construct that rational function. + abuckets, bbuckets = sift(func.ap, _mod1), sift(func.bq, _mod1) + + paired = {} + for key, value in abuckets.items(): + if key != 0 and key not in bbuckets: + return None + bvalue = bbuckets[key] + paired[key] = (list(value), list(bvalue)) + bbuckets.pop(key, None) + if bbuckets != {}: + return None + if S.Zero not in abuckets: + return None + aints, bints = paired[S.Zero] + # Account for the additional n! in denominator + paired[S.Zero] = (aints, bints + [1]) + + t = Dummy('t') + numer = S.One + denom = S.One + for key, (avalue, bvalue) in paired.items(): + if len(avalue) != len(bvalue): + return None + # Note that since order has been reduced fully, all the b are + # bigger than all the a they differ from by an integer. In particular + # if there are any negative b left, this function is not well-defined. + for a, b in zip(avalue, bvalue): + if (a - b).is_positive: + k = a - b + numer *= rf(b + t, k) + denom *= rf(b, k) + else: + k = b - a + numer *= rf(a, k) + denom *= rf(a + t, k) + + # Now do a partial fraction decomposition. + # We assemble two structures: a list monomials of pairs (a, b) representing + # a*t**b (b a non-negative integer), and a dict terms, where + # terms[a] = [(b, c)] means that there is a term b/(t-a)**c. + part = apart(numer/denom, t) + args = Add.make_args(part) + monomials = [] + terms = {} + for arg in args: + numer, denom = arg.as_numer_denom() + if not denom.has(t): + p = Poly(numer, t) + if not p.is_monomial: + raise TypeError("p should be monomial") + ((b, ), a) = p.LT() + monomials += [(a/denom, b)] + continue + if numer.has(t): + raise NotImplementedError('Need partial fraction decomposition' + ' with linear denominators') + indep, [dep] = denom.as_coeff_mul(t) + n = 1 + if dep.is_Pow: + n = dep.exp + dep = dep.base + if dep == t: + a = 0 + elif dep.is_Add: + a, tmp = dep.as_independent(t) + b = 1 + if tmp != t: + b, _ = tmp.as_independent(t) + if dep != b*t + a: + raise NotImplementedError('unrecognised form %s' % dep) + a /= b + indep *= b**n + else: + raise NotImplementedError('unrecognised form of partial fraction') + terms.setdefault(a, []).append((numer/indep, n)) + + # Now that we have this information, assemble our formula. All the + # monomials yield rational functions and go into one basis element. + # The terms[a] are related by differentiation. If the largest exponent is + # n, we need lerchphi(z, k, a) for k = 1, 2, ..., n. + # deriv maps a basis to its derivative, expressed as a C(z)-linear + # combination of other basis elements. + deriv = {} + coeffs = {} + z = Dummy('z') + monomials.sort(key=lambda x: x[1]) + mon = {0: 1/(1 - z)} + if monomials: + for k in range(monomials[-1][1]): + mon[k + 1] = z*mon[k].diff(z) + for a, n in monomials: + coeffs.setdefault(S.One, []).append(a*mon[n]) + for a, l in terms.items(): + for c, k in l: + coeffs.setdefault(lerchphi(z, k, a), []).append(c) + l.sort(key=lambda x: x[1]) + for k in range(2, l[-1][1] + 1): + deriv[lerchphi(z, k, a)] = [(-a, lerchphi(z, k, a)), + (1, lerchphi(z, k - 1, a))] + deriv[lerchphi(z, 1, a)] = [(-a, lerchphi(z, 1, a)), + (1/(1 - z), S.One)] + trans = {} + for n, b in enumerate([S.One] + list(deriv.keys())): + trans[b] = n + basis = [expand_func(b) for (b, _) in sorted(trans.items(), + key=lambda x:x[1])] + B = Matrix(basis) + C = Matrix([[0]*len(B)]) + for b, c in coeffs.items(): + C[trans[b]] = Add(*c) + M = zeros(len(B)) + for b, l in deriv.items(): + for c, b2 in l: + M[trans[b], trans[b2]] = c + return Formula(func, z, None, [], B, C, M) + + +def build_hypergeometric_formula(func): + """ + Create a formula object representing the hypergeometric function ``func``. + + """ + # We know that no `ap` are negative integers, otherwise "detect poly" + # would have kicked in. However, `ap` could be empty. In this case we can + # use a different basis. + # I'm not aware of a basis that works in all cases. + z = Dummy('z') + if func.ap: + afactors = [_x + a for a in func.ap] + bfactors = [_x + b - 1 for b in func.bq] + expr = _x*Mul(*bfactors) - z*Mul(*afactors) + poly = Poly(expr, _x) + n = poly.degree() + basis = [] + M = zeros(n) + for k in range(n): + a = func.ap[0] + k + basis += [hyper([a] + list(func.ap[1:]), func.bq, z)] + if k < n - 1: + M[k, k] = -a + M[k, k + 1] = a + B = Matrix(basis) + C = Matrix([[1] + [0]*(n - 1)]) + derivs = [eye(n)] + for k in range(n): + derivs.append(M*derivs[k]) + l = poly.all_coeffs() + l.reverse() + res = [0]*n + for k, c in enumerate(l): + for r, d in enumerate(C*derivs[k]): + res[r] += c*d + for k, c in enumerate(res): + M[n - 1, k] = -c/derivs[n - 1][0, n - 1]/poly.all_coeffs()[0] + return Formula(func, z, None, [], B, C, M) + else: + # Since there are no `ap`, none of the `bq` can be non-positive + # integers. + basis = [] + bq = list(func.bq[:]) + for i in range(len(bq)): + basis += [hyper([], bq, z)] + bq[i] += 1 + basis += [hyper([], bq, z)] + B = Matrix(basis) + n = len(B) + C = Matrix([[1] + [0]*(n - 1)]) + M = zeros(n) + M[0, n - 1] = z/Mul(*func.bq) + for k in range(1, n): + M[k, k - 1] = func.bq[k - 1] + M[k, k] = -func.bq[k - 1] + return Formula(func, z, None, [], B, C, M) + + +def hyperexpand_special(ap, bq, z): + """ + Try to find a closed-form expression for hyper(ap, bq, z), where ``z`` + is supposed to be a "special" value, e.g. 1. + + This function tries various of the classical summation formulae + (Gauss, Saalschuetz, etc). + """ + # This code is very ad-hoc. There are many clever algorithms + # (notably Zeilberger's) related to this problem. + # For now we just want a few simple cases to work. + p, q = len(ap), len(bq) + z_ = z + z = unpolarify(z) + if z == 0: + return S.One + from sympy.simplify.simplify import simplify + if p == 2 and q == 1: + # 2F1 + a, b, c = ap + bq + if z == 1: + # Gauss + return gamma(c - a - b)*gamma(c)/gamma(c - a)/gamma(c - b) + if z == -1 and simplify(b - a + c) == 1: + b, a = a, b + if z == -1 and simplify(a - b + c) == 1: + # Kummer + if b.is_integer and b.is_negative: + return 2*cos(pi*b/2)*gamma(-b)*gamma(b - a + 1) \ + /gamma(-b/2)/gamma(b/2 - a + 1) + else: + return gamma(b/2 + 1)*gamma(b - a + 1) \ + /gamma(b + 1)/gamma(b/2 - a + 1) + # TODO tons of more formulae + # investigate what algorithms exist + return hyper(ap, bq, z_) + +_collection = None + + +def _hyperexpand(func, z, ops0=[], z0=Dummy('z0'), premult=1, prem=0, + rewrite='default'): + """ + Try to find an expression for the hypergeometric function ``func``. + + Explanation + =========== + + The result is expressed in terms of a dummy variable ``z0``. Then it + is multiplied by ``premult``. Then ``ops0`` is applied. + ``premult`` must be a*z**prem for some a independent of ``z``. + """ + + if z.is_zero: + return S.One + + from sympy.simplify.simplify import simplify + + z = polarify(z, subs=False) + if rewrite == 'default': + rewrite = 'nonrepsmall' + + def carryout_plan(f, ops): + C = apply_operators(f.C.subs(f.z, z0), ops, + make_derivative_operator(f.M.subs(f.z, z0), z0)) + C = apply_operators(C, ops0, + make_derivative_operator(f.M.subs(f.z, z0) + + prem*eye(f.M.shape[0]), z0)) + + if premult == 1: + C = C.applyfunc(make_simp(z0)) + r = reduce(lambda s,m: s+m[0]*m[1], zip(C, f.B.subs(f.z, z0)), S.Zero)*premult + res = r.subs(z0, z) + if rewrite: + res = res.rewrite(rewrite) + return res + + # TODO + # The following would be possible: + # *) PFD Duplication (see Kelly Roach's paper) + # *) In a similar spirit, try_lerchphi() can be generalised considerably. + + global _collection + if _collection is None: + _collection = FormulaCollection() + + debug('Trying to expand hypergeometric function ', func) + + # First reduce order as much as possible. + func, ops = reduce_order(func) + if ops: + debug(' Reduced order to ', func) + else: + debug(' Could not reduce order.') + + # Now try polynomial cases + res = try_polynomial(func, z0) + if res is not None: + debug(' Recognised polynomial.') + p = apply_operators(res, ops, lambda f: z0*f.diff(z0)) + p = apply_operators(p*premult, ops0, lambda f: z0*f.diff(z0)) + return unpolarify(simplify(p).subs(z0, z)) + + # Try to recognise a shifted sum. + p = S.Zero + res = try_shifted_sum(func, z0) + if res is not None: + func, nops, p = res + debug(' Recognised shifted sum, reduced order to ', func) + ops += nops + + # apply the plan for poly + p = apply_operators(p, ops, lambda f: z0*f.diff(z0)) + p = apply_operators(p*premult, ops0, lambda f: z0*f.diff(z0)) + p = simplify(p).subs(z0, z) + + # Try special expansions early. + if unpolarify(z) in [1, -1] and (len(func.ap), len(func.bq)) == (2, 1): + f = build_hypergeometric_formula(func) + r = carryout_plan(f, ops).replace(hyper, hyperexpand_special) + if not r.has(hyper): + return r + p + + # Try to find a formula in our collection + formula = _collection.lookup_origin(func) + + # Now try a lerch phi formula + if formula is None: + formula = try_lerchphi(func) + + if formula is None: + debug(' Could not find an origin. ', + 'Will return answer in terms of ' + 'simpler hypergeometric functions.') + formula = build_hypergeometric_formula(func) + + debug(' Found an origin: ', formula.closed_form, ' ', formula.func) + + # We need to find the operators that convert formula into func. + ops += devise_plan(func, formula.func, z0) + + # Now carry out the plan. + r = carryout_plan(formula, ops) + p + + return powdenest(r, polar=True).replace(hyper, hyperexpand_special) + + +def devise_plan_meijer(fro, to, z): + """ + Find operators to convert G-function ``fro`` into G-function ``to``. + + Explanation + =========== + + It is assumed that ``fro`` and ``to`` have the same signatures, and that in fact + any corresponding pair of parameters differs by integers, and a direct path + is possible. I.e. if there are parameters a1 b1 c1 and a2 b2 c2 it is + assumed that a1 can be shifted to a2, etc. The only thing this routine + determines is the order of shifts to apply, nothing clever will be tried. + It is also assumed that ``fro`` is suitable. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import (devise_plan_meijer, + ... G_Function) + >>> from sympy.abc import z + + Empty plan: + + >>> devise_plan_meijer(G_Function([1], [2], [3], [4]), + ... G_Function([1], [2], [3], [4]), z) + [] + + Very simple plans: + + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([1], [], [], []), z) + [] + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([-1], [], [], []), z) + [] + >>> devise_plan_meijer(G_Function([], [1], [], []), + ... G_Function([], [2], [], []), z) + [] + + Slightly more complicated plans: + + >>> devise_plan_meijer(G_Function([0], [], [], []), + ... G_Function([2], [], [], []), z) + [, + ] + >>> devise_plan_meijer(G_Function([0], [], [0], []), + ... G_Function([-1], [], [1], []), z) + [, ] + + Order matters: + + >>> devise_plan_meijer(G_Function([0], [], [0], []), + ... G_Function([1], [], [1], []), z) + [, ] + """ + # TODO for now, we use the following simple heuristic: inverse-shift + # when possible, shift otherwise. Give up if we cannot make progress. + + def try_shift(f, t, shifter, diff, counter): + """ Try to apply ``shifter`` in order to bring some element in ``f`` + nearer to its counterpart in ``to``. ``diff`` is +/- 1 and + determines the effect of ``shifter``. Counter is a list of elements + blocking the shift. + + Return an operator if change was possible, else None. + """ + for idx, (a, b) in enumerate(zip(f, t)): + if ( + (a - b).is_integer and (b - a)/diff > 0 and + all(a != x for x in counter)): + sh = shifter(idx) + f[idx] += diff + return sh + fan = list(fro.an) + fap = list(fro.ap) + fbm = list(fro.bm) + fbq = list(fro.bq) + ops = [] + change = True + while change: + change = False + op = try_shift(fan, to.an, + lambda i: MeijerUnShiftB(fan, fap, fbm, fbq, i, z), + 1, fbm + fbq) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fap, to.ap, + lambda i: MeijerUnShiftD(fan, fap, fbm, fbq, i, z), + 1, fbm + fbq) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbm, to.bm, + lambda i: MeijerUnShiftA(fan, fap, fbm, fbq, i, z), + -1, fan + fap) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbq, to.bq, + lambda i: MeijerUnShiftC(fan, fap, fbm, fbq, i, z), + -1, fan + fap) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fan, to.an, lambda i: MeijerShiftB(fan[i]), -1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fap, to.ap, lambda i: MeijerShiftD(fap[i]), -1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbm, to.bm, lambda i: MeijerShiftA(fbm[i]), 1, []) + if op is not None: + ops += [op] + change = True + continue + op = try_shift(fbq, to.bq, lambda i: MeijerShiftC(fbq[i]), 1, []) + if op is not None: + ops += [op] + change = True + continue + if fan != list(to.an) or fap != list(to.ap) or fbm != list(to.bm) or \ + fbq != list(to.bq): + raise NotImplementedError('Could not devise plan.') + ops.reverse() + return ops + +_meijercollection = None + + +def _meijergexpand(func, z0, allow_hyper=False, rewrite='default', + place=None): + """ + Try to find an expression for the Meijer G function specified + by the G_Function ``func``. If ``allow_hyper`` is True, then returning + an expression in terms of hypergeometric functions is allowed. + + Currently this just does Slater's theorem. + If expansions exist both at zero and at infinity, ``place`` + can be set to ``0`` or ``zoo`` for the preferred choice. + """ + global _meijercollection + if _meijercollection is None: + _meijercollection = MeijerFormulaCollection() + if rewrite == 'default': + rewrite = None + + func0 = func + debug('Try to expand Meijer G function corresponding to ', func) + + # We will play games with analytic continuation - rather use a fresh symbol + z = Dummy('z') + + func, ops = reduce_order_meijer(func) + if ops: + debug(' Reduced order to ', func) + else: + debug(' Could not reduce order.') + + # Try to find a direct formula + f = _meijercollection.lookup_origin(func) + if f is not None: + debug(' Found a Meijer G formula: ', f.func) + ops += devise_plan_meijer(f.func, func, z) + + # Now carry out the plan. + C = apply_operators(f.C.subs(f.z, z), ops, + make_derivative_operator(f.M.subs(f.z, z), z)) + + C = C.applyfunc(make_simp(z)) + r = C*f.B.subs(f.z, z) + r = r[0].subs(z, z0) + return powdenest(r, polar=True) + + debug(" Could not find a direct formula. Trying Slater's theorem.") + + # TODO the following would be possible: + # *) Paired Index Theorems + # *) PFD Duplication + # (See Kelly Roach's paper for details on either.) + # + # TODO Also, we tend to create combinations of gamma functions that can be + # simplified. + + def can_do(pbm, pap): + """ Test if slater applies. """ + for i in pbm: + if len(pbm[i]) > 1: + l = 0 + if i in pap: + l = len(pap[i]) + if l + 1 < len(pbm[i]): + return False + return True + + def do_slater(an, bm, ap, bq, z, zfinal): + # zfinal is the value that will eventually be substituted for z. + # We pass it to _hyperexpand to improve performance. + func = G_Function(an, bm, ap, bq) + _, pbm, pap, _ = func.compute_buckets() + if not can_do(pbm, pap): + return S.Zero, False + + cond = len(an) + len(ap) < len(bm) + len(bq) + if len(an) + len(ap) == len(bm) + len(bq): + cond = abs(z) < 1 + if cond is False: + return S.Zero, False + + res = S.Zero + for m in pbm: + if len(pbm[m]) == 1: + bh = pbm[m][0] + fac = 1 + bo = list(bm) + bo.remove(bh) + for bj in bo: + fac *= gamma(bj - bh) + for aj in an: + fac *= gamma(1 + bh - aj) + for bj in bq: + fac /= gamma(1 + bh - bj) + for aj in ap: + fac /= gamma(aj - bh) + nap = [1 + bh - a for a in list(an) + list(ap)] + nbq = [1 + bh - b for b in list(bo) + list(bq)] + + k = polar_lift(S.NegativeOne**(len(ap) - len(bm))) + harg = k*zfinal + # NOTE even though k "is" +-1, this has to be t/k instead of + # t*k ... we are using polar numbers for consistency! + premult = (t/k)**bh + hyp = _hyperexpand(Hyper_Function(nap, nbq), harg, ops, + t, premult, bh, rewrite=None) + res += fac * hyp + else: + b_ = pbm[m][0] + ki = [bi - b_ for bi in pbm[m][1:]] + u = len(ki) + li = [ai - b_ for ai in pap[m][:u + 1]] + bo = list(bm) + for b in pbm[m]: + bo.remove(b) + ao = list(ap) + for a in pap[m][:u]: + ao.remove(a) + lu = li[-1] + di = [l - k for (l, k) in zip(li, ki)] + + # We first work out the integrand: + s = Dummy('s') + integrand = z**s + for b in bm: + if not Mod(b, 1) and b.is_Number: + b = int(round(b)) + integrand *= gamma(b - s) + for a in an: + integrand *= gamma(1 - a + s) + for b in bq: + integrand /= gamma(1 - b + s) + for a in ap: + integrand /= gamma(a - s) + + # Now sum the finitely many residues: + # XXX This speeds up some cases - is it a good idea? + integrand = expand_func(integrand) + for r in range(int(round(lu))): + resid = residue(integrand, s, b_ + r) + resid = apply_operators(resid, ops, lambda f: z*f.diff(z)) + res -= resid + + # Now the hypergeometric term. + au = b_ + lu + k = polar_lift(S.NegativeOne**(len(ao) + len(bo) + 1)) + harg = k*zfinal + premult = (t/k)**au + nap = [1 + au - a for a in list(an) + list(ap)] + [1] + nbq = [1 + au - b for b in list(bm) + list(bq)] + + hyp = _hyperexpand(Hyper_Function(nap, nbq), harg, ops, + t, premult, au, rewrite=None) + + C = S.NegativeOne**(lu)/factorial(lu) + for i in range(u): + C *= S.NegativeOne**di[i]/rf(lu - li[i] + 1, di[i]) + for a in an: + C *= gamma(1 - a + au) + for b in bo: + C *= gamma(b - au) + for a in ao: + C /= gamma(a - au) + for b in bq: + C /= gamma(1 - b + au) + + res += C*hyp + + return res, cond + + t = Dummy('t') + slater1, cond1 = do_slater(func.an, func.bm, func.ap, func.bq, z, z0) + + def tr(l): + return [1 - x for x in l] + + for op in ops: + op._poly = Poly(op._poly.subs({z: 1/t, _x: -_x}), _x) + slater2, cond2 = do_slater(tr(func.bm), tr(func.an), tr(func.bq), tr(func.ap), + t, 1/z0) + + slater1 = powdenest(slater1.subs(z, z0), polar=True) + slater2 = powdenest(slater2.subs(t, 1/z0), polar=True) + if not isinstance(cond2, bool): + cond2 = cond2.subs(t, 1/z) + + m = func(z) + if m.delta > 0 or \ + (m.delta == 0 and len(m.ap) == len(m.bq) and + (re(m.nu) < -1) is not False and polar_lift(z0) == polar_lift(1)): + # The condition delta > 0 means that the convergence region is + # connected. Any expression we find can be continued analytically + # to the entire convergence region. + # The conditions delta==0, p==q, re(nu) < -1 imply that G is continuous + # on the positive reals, so the values at z=1 agree. + if cond1 is not False: + cond1 = True + if cond2 is not False: + cond2 = True + + if cond1 is True: + slater1 = slater1.rewrite(rewrite or 'nonrep') + else: + slater1 = slater1.rewrite(rewrite or 'nonrepsmall') + if cond2 is True: + slater2 = slater2.rewrite(rewrite or 'nonrep') + else: + slater2 = slater2.rewrite(rewrite or 'nonrepsmall') + + if cond1 is not False and cond2 is not False: + # If one condition is False, there is no choice. + if place == 0: + cond2 = False + if place == zoo: + cond1 = False + + if not isinstance(cond1, bool): + cond1 = cond1.subs(z, z0) + if not isinstance(cond2, bool): + cond2 = cond2.subs(z, z0) + + def weight(expr, cond): + if cond is True: + c0 = 0 + elif cond is False: + c0 = 1 + else: + c0 = 2 + if expr.has(oo, zoo, -oo, nan): + # XXX this actually should not happen, but consider + # S('meijerg(((0, -1/2, 0, -1/2, 1/2), ()), ((0,), + # (-1/2, -1/2, -1/2, -1)), exp_polar(I*pi))/4') + c0 = 3 + return (c0, expr.count(hyper), expr.count_ops()) + + w1 = weight(slater1, cond1) + w2 = weight(slater2, cond2) + if min(w1, w2) <= (0, 1, oo): + if w1 < w2: + return slater1 + else: + return slater2 + if max(w1[0], w2[0]) <= 1 and max(w1[1], w2[1]) <= 1: + return Piecewise((slater1, cond1), (slater2, cond2), (func0(z0), True)) + + # We couldn't find an expression without hypergeometric functions. + # TODO it would be helpful to give conditions under which the integral + # is known to diverge. + r = Piecewise((slater1, cond1), (slater2, cond2), (func0(z0), True)) + if r.has(hyper) and not allow_hyper: + debug(' Could express using hypergeometric functions, ' + 'but not allowed.') + if not r.has(hyper) or allow_hyper: + return r + + return func0(z0) + + +def hyperexpand(f, allow_hyper=False, rewrite='default', place=None): + """ + Expand hypergeometric functions. If allow_hyper is True, allow partial + simplification (that is a result different from input, + but still containing hypergeometric functions). + + If a G-function has expansions both at zero and at infinity, + ``place`` can be set to ``0`` or ``zoo`` to indicate the + preferred choice. + + Examples + ======== + + >>> from sympy.simplify.hyperexpand import hyperexpand + >>> from sympy.functions import hyper + >>> from sympy.abc import z + >>> hyperexpand(hyper([], [], z)) + exp(z) + + Non-hyperegeometric parts of the expression and hypergeometric expressions + that are not recognised are left unchanged: + + >>> hyperexpand(1 + hyper([1, 1, 1], [], z)) + hyper((1, 1, 1), (), z) + 1 + """ + f = sympify(f) + + def do_replace(ap, bq, z): + r = _hyperexpand(Hyper_Function(ap, bq), z, rewrite=rewrite) + if r is None: + return hyper(ap, bq, z) + else: + return r + + def do_meijer(ap, bq, z): + r = _meijergexpand(G_Function(ap[0], ap[1], bq[0], bq[1]), z, + allow_hyper, rewrite=rewrite, place=place) + if not r.has(nan, zoo, oo, -oo): + return r + return f.replace(hyper, do_replace).replace(meijerg, do_meijer) diff --git a/phivenv/Lib/site-packages/sympy/simplify/hyperexpand_doc.py b/phivenv/Lib/site-packages/sympy/simplify/hyperexpand_doc.py new file mode 100644 index 0000000000000000000000000000000000000000..a18377f3aede5214036fbf628825536611001584 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/hyperexpand_doc.py @@ -0,0 +1,18 @@ +""" This module cooks up a docstring when imported. Its only purpose is to + be displayed in the sphinx documentation. """ + +from sympy.core.relational import Eq +from sympy.functions.special.hyper import hyper +from sympy.printing.latex import latex +from sympy.simplify.hyperexpand import FormulaCollection + +c = FormulaCollection() + +doc = "" + +for f in c.formulae: + obj = Eq(hyper(f.func.ap, f.func.bq, f.z), + f.closed_form.rewrite('nonrepsmall')) + doc += ".. math::\n %s\n" % latex(obj) + +__doc__ = doc diff --git a/phivenv/Lib/site-packages/sympy/simplify/powsimp.py b/phivenv/Lib/site-packages/sympy/simplify/powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..f72dfeb072e0d0d4737ace310eda5c2a3a082c16 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/powsimp.py @@ -0,0 +1,718 @@ +from collections import defaultdict +from functools import reduce +from math import prod + +from sympy.core.function import expand_log, count_ops, _coeff_isneg +from sympy.core import sympify, Basic, Dummy, S, Add, Mul, Pow, expand_mul, factor_terms +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.numbers import Integer, Rational, equal_valued +from sympy.core.mul import _keep_coeff +from sympy.core.rules import Transform +from sympy.functions import exp_polar, exp, log, root, polarify, unpolarify +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys import lcm, gcd +from sympy.ntheory.factor_ import multiplicity + + + +def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops): + """ + Reduce expression by combining powers with similar bases and exponents. + + Explanation + =========== + + If ``deep`` is ``True`` then powsimp() will also simplify arguments of + functions. By default ``deep`` is set to ``False``. + + If ``force`` is ``True`` then bases will be combined without checking for + assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true + if x and y are both negative. + + You can make powsimp() only combine bases or only combine exponents by + changing combine='base' or combine='exp'. By default, combine='all', + which does both. combine='base' will only combine:: + + a a a 2x x + x * y => (x*y) as well as things like 2 => 4 + + and combine='exp' will only combine + :: + + a b (a + b) + x * x => x + + combine='exp' will strictly only combine exponents in the way that used + to be automatic. Also use deep=True if you need the old behavior. + + When combine='all', 'exp' is evaluated first. Consider the first + example below for when there could be an ambiguity relating to this. + This is done so things like the second example can be completely + combined. If you want 'base' combined first, do something like + powsimp(powsimp(expr, combine='base'), combine='exp'). + + Examples + ======== + + >>> from sympy import powsimp, exp, log, symbols + >>> from sympy.abc import x, y, z, n + >>> powsimp(x**y*x**z*y**z, combine='all') + x**(y + z)*y**z + >>> powsimp(x**y*x**z*y**z, combine='exp') + x**(y + z)*y**z + >>> powsimp(x**y*x**z*y**z, combine='base', force=True) + x**y*(x*y)**z + + >>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True) + (n*x)**(y + z) + >>> powsimp(x**z*x**y*n**z*n**y, combine='exp') + n**(y + z)*x**(y + z) + >>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True) + (n*x)**y*(n*x)**z + + >>> x, y = symbols('x y', positive=True) + >>> powsimp(log(exp(x)*exp(y))) + log(exp(x)*exp(y)) + >>> powsimp(log(exp(x)*exp(y)), deep=True) + x + y + + Radicals with Mul bases will be combined if combine='exp' + + >>> from sympy import sqrt + >>> x, y = symbols('x y') + + Two radicals are automatically joined through Mul: + + >>> a=sqrt(x*sqrt(y)) + >>> a*a**3 == a**4 + True + + But if an integer power of that radical has been + autoexpanded then Mul does not join the resulting factors: + + >>> a**4 # auto expands to a Mul, no longer a Pow + x**2*y + >>> _*a # so Mul doesn't combine them + x**2*y*sqrt(x*sqrt(y)) + >>> powsimp(_) # but powsimp will + (x*sqrt(y))**(5/2) + >>> powsimp(x*y*a) # but won't when doing so would violate assumptions + x*y*sqrt(x*sqrt(y)) + + """ + def recurse(arg, **kwargs): + _deep = kwargs.get('deep', deep) + _combine = kwargs.get('combine', combine) + _force = kwargs.get('force', force) + _measure = kwargs.get('measure', measure) + return powsimp(arg, _deep, _combine, _force, _measure) + + expr = sympify(expr) + + if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or ( + expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))): + return expr + + if deep or expr.is_Add or expr.is_Mul and _y not in expr.args: + expr = expr.func(*[recurse(w) for w in expr.args]) + + if expr.is_Pow: + return recurse(expr*_y, deep=False)/_y + + if not expr.is_Mul: + return expr + + # handle the Mul + if combine in ('exp', 'all'): + # Collect base/exp data, while maintaining order in the + # non-commutative parts of the product + c_powers = defaultdict(list) + nc_part = [] + newexpr = [] + coeff = S.One + for term in expr.args: + if term.is_Rational: + coeff *= term + continue + if term.is_Pow: + term = _denest_pow(term) + if term.is_commutative: + b, e = term.as_base_exp() + if deep: + b, e = [recurse(i) for i in [b, e]] + if b.is_Pow or isinstance(b, exp): + # don't let smthg like sqrt(x**a) split into x**a, 1/2 + # or else it will be joined as x**(a/2) later + b, e = b**e, S.One + c_powers[b].append(e) + else: + # This is the logic that combines exponents for equal, + # but non-commutative bases: A**x*A**y == A**(x+y). + if nc_part: + b1, e1 = nc_part[-1].as_base_exp() + b2, e2 = term.as_base_exp() + if (b1 == b2 and + e1.is_commutative and e2.is_commutative): + nc_part[-1] = Pow(b1, Add(e1, e2)) + continue + nc_part.append(term) + + # add up exponents of common bases + for b, e in ordered(iter(c_powers.items())): + # allow 2**x/4 -> 2**(x - 2); don't do this when b and e are + # Numbers since autoevaluation will undo it, e.g. + # 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4 + if (b and b.is_Rational and not all(ei.is_Number for ei in e) and \ + coeff is not S.One and + b not in (S.One, S.NegativeOne)): + m = multiplicity(abs(b), abs(coeff)) + if m: + e.append(m) + coeff /= b**m + c_powers[b] = Add(*e) + if coeff is not S.One: + if coeff in c_powers: + c_powers[coeff] += S.One + else: + c_powers[coeff] = S.One + + # convert to plain dictionary + c_powers = dict(c_powers) + + # check for base and inverted base pairs + be = list(c_powers.items()) + skip = set() # skip if we already saw them + for b, e in be: + if b in skip: + continue + bpos = b.is_positive or b.is_polar + if bpos: + binv = 1/b + #Special case for float 1 + if b.is_Float and equal_valued(b, 1): + c_powers[b] = S.One + continue + if b != binv and binv in c_powers: + if b.as_numer_denom()[0] is S.One: + c_powers.pop(b) + c_powers[binv] -= e + else: + skip.add(binv) + e = c_powers.pop(binv) + c_powers[b] -= e + + # check for base and negated base pairs + be = list(c_powers.items()) + _n = S.NegativeOne + for b, e in be: + if (b.is_Symbol or b.is_Add) and -b in c_powers and b in c_powers: + if (b.is_positive is not None or e.is_integer): + if e.is_integer or b.is_negative: + c_powers[-b] += c_powers.pop(b) + else: # (-b).is_positive so use its e + e = c_powers.pop(-b) + c_powers[b] += e + if _n in c_powers: + c_powers[_n] += e + else: + c_powers[_n] = e + + # filter c_powers and convert to a list + c_powers = [(b, e) for b, e in c_powers.items() if e] + + # ============================================================== + # check for Mul bases of Rational powers that can be combined with + # separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) -> + # (x*sqrt(x*y))**(3/2) + # ---------------- helper functions + + def ratq(x): + '''Return Rational part of x's exponent as it appears in the bkey. + ''' + return bkey(x)[0][1] + + def bkey(b, e=None): + '''Return (b**s, c.q), c.p where e -> c*s. If e is not given then + it will be taken by using as_base_exp() on the input b. + e.g. + x**3/2 -> (x, 2), 3 + x**y -> (x**y, 1), 1 + x**(2*y/3) -> (x**y, 3), 2 + exp(x/2) -> (exp(a), 2), 1 + + ''' + if e is not None: # coming from c_powers or from below + if e.is_Integer: + return (b, S.One), e + elif e.is_Rational: + return (b, Integer(e.q)), Integer(e.p) + else: + c, m = e.as_coeff_Mul(rational=True) + if c is not S.One: + if m.is_integer: + return (b, Integer(c.q)), m*Integer(c.p) + return (b**m, Integer(c.q)), Integer(c.p) + else: + return (b**e, S.One), S.One + else: + return bkey(*b.as_base_exp()) + + def update(b): + '''Decide what to do with base, b. If its exponent is now an + integer multiple of the Rational denominator, then remove it + and put the factors of its base in the common_b dictionary or + update the existing bases if necessary. If it has been zeroed + out, simply remove the base. + ''' + newe, r = divmod(common_b[b], b[1]) + if not r: + common_b.pop(b) + if newe: + for m in Mul.make_args(b[0]**newe): + b, e = bkey(m) + if b not in common_b: + common_b[b] = 0 + common_b[b] += e + if b[1] != 1: + bases.append(b) + # ---------------- end of helper functions + + # assemble a dictionary of the factors having a Rational power + common_b = {} + done = [] + bases = [] + for b, e in c_powers: + b, e = bkey(b, e) + if b in common_b: + common_b[b] = common_b[b] + e + else: + common_b[b] = e + if b[1] != 1 and b[0].is_Mul: + bases.append(b) + bases.sort(key=default_sort_key) # this makes tie-breaking canonical + bases.sort(key=measure, reverse=True) # handle longest first + for base in bases: + if base not in common_b: # it may have been removed already + continue + b, exponent = base + last = False # True when no factor of base is a radical + qlcm = 1 # the lcm of the radical denominators + while True: + bstart = b + qstart = qlcm + + bb = [] # list of factors + ee = [] # (factor's expo. and it's current value in common_b) + for bi in Mul.make_args(b): + bib, bie = bkey(bi) + if bib not in common_b or common_b[bib] < bie: + ee = bb = [] # failed + break + ee.append([bie, common_b[bib]]) + bb.append(bib) + if ee: + # find the number of integral extractions possible + # e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1 + min1 = ee[0][1]//ee[0][0] + for i in range(1, len(ee)): + rat = ee[i][1]//ee[i][0] + if rat < 1: + break + min1 = min(min1, rat) + else: + # update base factor counts + # e.g. if ee = [(2, 5), (3, 6)] then min1 = 2 + # and the new base counts will be 5-2*2 and 6-2*3 + for i in range(len(bb)): + common_b[bb[i]] -= min1*ee[i][0] + update(bb[i]) + # update the count of the base + # e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y) + # will increase by 4 to give bkey (x*sqrt(y), 2, 5) + common_b[base] += min1*qstart*exponent + if (last # no more radicals in base + or len(common_b) == 1 # nothing left to join with + or all(k[1] == 1 for k in common_b) # no rad's in common_b + ): + break + # see what we can exponentiate base by to remove any radicals + # so we know what to search for + # e.g. if base were x**(1/2)*y**(1/3) then we should + # exponentiate by 6 and look for powers of x and y in the ratio + # of 2 to 3 + qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)]) + if qlcm == 1: + break # we are done + b = bstart**qlcm + qlcm *= qstart + if all(ratq(bi) == 1 for bi in Mul.make_args(b)): + last = True # we are going to be done after this next pass + # this base no longer can find anything to join with and + # since it was longer than any other we are done with it + b, q = base + done.append((b, common_b.pop(base)*Rational(1, q))) + + # update c_powers and get ready to continue with powsimp + c_powers = done + # there may be terms still in common_b that were bases that were + # identified as needing processing, so remove those, too + for (b, q), e in common_b.items(): + if (b.is_Pow or isinstance(b, exp)) and \ + q is not S.One and not b.exp.is_Rational: + b, be = b.as_base_exp() + b = b**(be/q) + else: + b = root(b, q) + c_powers.append((b, e)) + check = len(c_powers) + c_powers = dict(c_powers) + assert len(c_powers) == check # there should have been no duplicates + # ============================================================== + + # rebuild the expression + newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()])) + if combine == 'exp': + return expr.func(newexpr, expr.func(*nc_part)) + else: + return recurse(expr.func(*nc_part), combine='base') * \ + recurse(newexpr, combine='base') + + elif combine == 'base': + + # Build c_powers and nc_part. These must both be lists not + # dicts because exp's are not combined. + c_powers = [] + nc_part = [] + for term in expr.args: + if term.is_commutative: + c_powers.append(list(term.as_base_exp())) + else: + nc_part.append(term) + + # Pull out numerical coefficients from exponent if assumptions allow + # e.g., 2**(2*x) => 4**x + for i in range(len(c_powers)): + b, e = c_powers[i] + if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar): + continue + exp_c, exp_t = e.as_coeff_Mul(rational=True) + if exp_c is not S.One and exp_t is not S.One: + c_powers[i] = [Pow(b, exp_c), exp_t] + + # Combine bases whenever they have the same exponent and + # assumptions allow + # first gather the potential bases under the common exponent + c_exp = defaultdict(list) + for b, e in c_powers: + if deep: + e = recurse(e) + if e.is_Add and (b.is_positive or e.is_integer): + e = factor_terms(e) + if _coeff_isneg(e): + e = -e + b = 1/b + c_exp[e].append(b) + del c_powers + + # Merge back in the results of the above to form a new product + c_powers = defaultdict(list) + for e in c_exp: + bases = c_exp[e] + + # calculate the new base for e + + if len(bases) == 1: + new_base = bases[0] + elif e.is_integer or force: + new_base = expr.func(*bases) + else: + # see which ones can be joined + unk = [] + nonneg = [] + neg = [] + for bi in bases: + if bi.is_negative: + neg.append(bi) + elif bi.is_nonnegative: + nonneg.append(bi) + elif bi.is_polar: + nonneg.append( + bi) # polar can be treated like non-negative + else: + unk.append(bi) + if len(unk) == 1 and not neg or len(neg) == 1 and not unk: + # a single neg or a single unk can join the rest + nonneg.extend(unk + neg) + unk = neg = [] + elif neg: + # their negative signs cancel in groups of 2*q if we know + # that e = p/q else we have to treat them as unknown + israt = False + if e.is_Rational: + israt = True + else: + p, d = e.as_numer_denom() + if p.is_integer and d.is_integer: + israt = True + if israt: + neg = [-w for w in neg] + unk.extend([S.NegativeOne]*len(neg)) + else: + unk.extend(neg) + neg = [] + del israt + + # these shouldn't be joined + for b in unk: + c_powers[b].append(e) + # here is a new joined base + new_base = expr.func(*(nonneg + neg)) + # if there are positive parts they will just get separated + # again unless some change is made + + def _terms(e): + # return the number of terms of this expression + # when multiplied out -- assuming no joining of terms + if e.is_Add: + return sum(_terms(ai) for ai in e.args) + if e.is_Mul: + return prod([_terms(mi) for mi in e.args]) + return 1 + xnew_base = expand_mul(new_base, deep=False) + if len(Add.make_args(xnew_base)) < _terms(new_base): + new_base = factor_terms(xnew_base) + + c_powers[new_base].append(e) + + # break out the powers from c_powers now + c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e] + + # we're done + return expr.func(*(c_part + nc_part)) + + else: + raise ValueError("combine must be one of ('all', 'exp', 'base').") + + +def powdenest(eq, force=False, polar=False): + r""" + Collect exponents on powers as assumptions allow. + + Explanation + =========== + + Given ``(bb**be)**e``, this can be simplified as follows: + * if ``bb`` is positive, or + * ``e`` is an integer, or + * ``|be| < 1`` then this simplifies to ``bb**(be*e)`` + + Given a product of powers raised to a power, ``(bb1**be1 * + bb2**be2...)**e``, simplification can be done as follows: + + - if e is positive, the gcd of all bei can be joined with e; + - all non-negative bb can be separated from those that are negative + and their gcd can be joined with e; autosimplification already + handles this separation. + - integer factors from powers that have integers in the denominator + of the exponent can be removed from any term and the gcd of such + integers can be joined with e + + Setting ``force`` to ``True`` will make symbols that are not explicitly + negative behave as though they are positive, resulting in more + denesting. + + Setting ``polar`` to ``True`` will do simplifications on the Riemann surface of + the logarithm, also resulting in more denestings. + + When there are sums of logs in exp() then a product of powers may be + obtained e.g. ``exp(3*(log(a) + 2*log(b)))`` - > ``a**3*b**6``. + + Examples + ======== + + >>> from sympy.abc import a, b, x, y, z + >>> from sympy import Symbol, exp, log, sqrt, symbols, powdenest + + >>> powdenest((x**(2*a/3))**(3*x)) + (x**(2*a/3))**(3*x) + >>> powdenest(exp(3*x*log(2))) + 2**(3*x) + + Assumptions may prevent expansion: + + >>> powdenest(sqrt(x**2)) + sqrt(x**2) + + >>> p = symbols('p', positive=True) + >>> powdenest(sqrt(p**2)) + p + + No other expansion is done. + + >>> i, j = symbols('i,j', integer=True) + >>> powdenest((x**x)**(i + j)) # -X-> (x**x)**i*(x**x)**j + x**(x*(i + j)) + + But exp() will be denested by moving all non-log terms outside of + the function; this may result in the collapsing of the exp to a power + with a different base: + + >>> powdenest(exp(3*y*log(x))) + x**(3*y) + >>> powdenest(exp(y*(log(a) + log(b)))) + (a*b)**y + >>> powdenest(exp(3*(log(a) + log(b)))) + a**3*b**3 + + If assumptions allow, symbols can also be moved to the outermost exponent: + + >>> i = Symbol('i', integer=True) + >>> powdenest(((x**(2*i))**(3*y))**x) + ((x**(2*i))**(3*y))**x + >>> powdenest(((x**(2*i))**(3*y))**x, force=True) + x**(6*i*x*y) + + >>> powdenest(((x**(2*a/3))**(3*y/i))**x) + ((x**(2*a/3))**(3*y/i))**x + >>> powdenest((x**(2*i)*y**(4*i))**z, force=True) + (x*y**2)**(2*i*z) + + >>> n = Symbol('n', negative=True) + + >>> powdenest((x**i)**y, force=True) + x**(i*y) + >>> powdenest((n**i)**x, force=True) + (n**i)**x + + """ + from sympy.simplify.simplify import posify + + if force: + def _denest(b, e): + if not isinstance(b, (Pow, exp)): + return b.is_positive, Pow(b, e, evaluate=False) + return _denest(b.base, b.exp*e) + reps = [] + for p in eq.atoms(Pow, exp): + if isinstance(p.base, (Pow, exp)): + ok, dp = _denest(*p.args) + if ok is not False: + reps.append((p, dp)) + if reps: + eq = eq.subs(reps) + eq, reps = posify(eq) + return powdenest(eq, force=False, polar=polar).xreplace(reps) + + if polar: + eq, rep = polarify(eq) + return unpolarify(powdenest(unpolarify(eq, exponents_only=True)), rep) + + new = powsimp(eq) + return new.xreplace(Transform( + _denest_pow, filter=lambda m: m.is_Pow or isinstance(m, exp))) + +_y = Dummy('y') + + +def _denest_pow(eq): + """ + Denest powers. + + This is a helper function for powdenest that performs the actual + transformation. + """ + from sympy.simplify.simplify import logcombine + + b, e = eq.as_base_exp() + if b.is_Pow or isinstance(b, exp) and e != 1: + new = b._eval_power(e) + if new is not None: + eq = new + b, e = new.as_base_exp() + + # denest exp with log terms in exponent + if b is S.Exp1 and e.is_Mul: + logs = [] + other = [] + for ei in e.args: + if any(isinstance(ai, log) for ai in Add.make_args(ei)): + logs.append(ei) + else: + other.append(ei) + logs = logcombine(Mul(*logs)) + return Pow(exp(logs), Mul(*other)) + + _, be = b.as_base_exp() + if be is S.One and not (b.is_Mul or + b.is_Rational and b.q != 1 or + b.is_positive): + return eq + + # denest eq which is either pos**e or Pow**e or Mul**e or + # Mul(b1**e1, b2**e2) + + # handle polar numbers specially + polars, nonpolars = [], [] + for bb in Mul.make_args(b): + if bb.is_polar: + polars.append(bb.as_base_exp()) + else: + nonpolars.append(bb) + if len(polars) == 1 and not polars[0][0].is_Mul: + return Pow(polars[0][0], polars[0][1]*e)*powdenest(Mul(*nonpolars)**e) + elif polars: + return Mul(*[powdenest(bb**(ee*e)) for (bb, ee) in polars]) \ + *powdenest(Mul(*nonpolars)**e) + + if b.is_Integer: + # use log to see if there is a power here + logb = expand_log(log(b)) + if logb.is_Mul: + c, logb = logb.args + e *= c + base = logb.args[0] + return Pow(base, e) + + # if b is not a Mul or any factor is an atom then there is nothing to do + if not b.is_Mul or any(s.is_Atom for s in Mul.make_args(b)): + return eq + + # let log handle the case of the base of the argument being a Mul, e.g. + # sqrt(x**(2*i)*y**(6*i)) -> x**i*y**(3**i) if x and y are positive; we + # will take the log, expand it, and then factor out the common powers that + # now appear as coefficient. We do this manually since terms_gcd pulls out + # fractions, terms_gcd(x+x*y/2) -> x*(y + 2)/2 and we don't want the 1/2; + # gcd won't pull out numerators from a fraction: gcd(3*x, 9*x/2) -> x but + # we want 3*x. Neither work with noncommutatives. + + def nc_gcd(aa, bb): + a, b = [i.as_coeff_Mul() for i in [aa, bb]] + c = gcd(a[0], b[0]).as_numer_denom()[0] + g = Mul(*(a[1].args_cnc(cset=True)[0] & b[1].args_cnc(cset=True)[0])) + return _keep_coeff(c, g) + + glogb = expand_log(log(b)) + if glogb.is_Add: + args = glogb.args + g = reduce(nc_gcd, args) + if g != 1: + cg, rg = g.as_coeff_Mul() + glogb = _keep_coeff(cg, rg*Add(*[a/g for a in args])) + + # now put the log back together again + if isinstance(glogb, log) or not glogb.is_Mul: + if glogb.args[0].is_Pow or isinstance(glogb.args[0], exp): + glogb = _denest_pow(glogb.args[0]) + if (abs(glogb.exp) < 1) == True: + return Pow(glogb.base, glogb.exp*e) + return eq + + # the log(b) was a Mul so join any adds with logcombine + add = [] + other = [] + for a in glogb.args: + if a.is_Add: + add.append(a) + else: + other.append(a) + return Pow(exp(logcombine(Mul(*add))), e*Mul(*other)) diff --git a/phivenv/Lib/site-packages/sympy/simplify/radsimp.py b/phivenv/Lib/site-packages/sympy/simplify/radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..c878168ebfbc29fc632577d6325befc120c26f56 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/radsimp.py @@ -0,0 +1,1234 @@ +from collections import defaultdict + +from sympy.core import sympify, S, Mul, Derivative, Pow +from sympy.core.add import _unevaluated_Add, Add +from sympy.core.assumptions import assumptions +from sympy.core.exprtools import Factors, gcd_terms +from sympy.core.function import _mexpand, expand_mul, expand_power_base +from sympy.core.mul import _keep_coeff, _unevaluated_Mul, _mulsort +from sympy.core.numbers import Rational, zoo, nan +from sympy.core.parameters import global_parameters +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.symbol import Dummy, Wild, symbols +from sympy.functions import exp, sqrt, log +from sympy.functions.elementary.complexes import Abs +from sympy.polys import gcd +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.utilities.iterables import iterable, sift + + + + +def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True): + """ + Collect additive terms of an expression. + + Explanation + =========== + + This function collects additive terms of an expression with respect + to a list of expression up to powers with rational exponents. By the + term symbol here are meant arbitrary expressions, which can contain + powers, products, sums etc. In other words symbol is a pattern which + will be searched for in the expression's terms. + + The input expression is not expanded by :func:`collect`, so user is + expected to provide an expression in an appropriate form. This makes + :func:`collect` more predictable as there is no magic happening behind the + scenes. However, it is important to note, that powers of products are + converted to products of powers using the :func:`~.expand_power_base` + function. + + There are two possible types of output. First, if ``evaluate`` flag is + set, this function will return an expression with collected terms or + else it will return a dictionary with expressions up to rational powers + as keys and collected coefficients as values. + + Examples + ======== + + >>> from sympy import S, collect, expand, factor, Wild + >>> from sympy.abc import a, b, c, x, y + + This function can collect symbolic coefficients in polynomials or + rational expressions. It will manage to find all integer or rational + powers of collection variable:: + + >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x) + c + x**2*(a + b) + x*(a - b) + + The same result can be achieved in dictionary form:: + + >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False) + >>> d[x**2] + a + b + >>> d[x] + a - b + >>> d[S.One] + c + + You can also work with multivariate polynomials. However, remember that + this function is greedy so it will care only about a single symbol at time, + in specification order:: + + >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y]) + x**2*(y + 1) + x*y + y*(a + 1) + + Also more complicated expressions can be used as patterns:: + + >>> from sympy import sin, log + >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x)) + (a + b)*sin(2*x) + + >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x)) + x*(a + b)*log(x) + + You can use wildcards in the pattern:: + + >>> w = Wild('w1') + >>> collect(a*x**y - b*x**y, w**y) + x**y*(a - b) + + It is also possible to work with symbolic powers, although it has more + complicated behavior, because in this case power's base and symbolic part + of the exponent are treated as a single symbol:: + + >>> collect(a*x**c + b*x**c, x) + a*x**c + b*x**c + >>> collect(a*x**c + b*x**c, x**c) + x**c*(a + b) + + However if you incorporate rationals to the exponents, then you will get + well known behavior:: + + >>> collect(a*x**(2*c) + b*x**(2*c), x**c) + x**(2*c)*(a + b) + + Note also that all previously stated facts about :func:`collect` function + apply to the exponential function, so you can get:: + + >>> from sympy import exp + >>> collect(a*exp(2*x) + b*exp(2*x), exp(x)) + (a + b)*exp(2*x) + + If you are interested only in collecting specific powers of some symbols + then set ``exact`` flag to True:: + + >>> collect(a*x**7 + b*x**7, x, exact=True) + a*x**7 + b*x**7 + >>> collect(a*x**7 + b*x**7, x**7, exact=True) + x**7*(a + b) + + If you want to collect on any object containing symbols, set + ``exact`` to None: + + >>> collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None) + x*exp(x) + 3*x + (y + 2)*sin(x) + >>> collect(a*x*y + x*y + b*x + x, [x, y], exact=None) + x*y*(a + 1) + x*(b + 1) + + You can also apply this function to differential equations, where + derivatives of arbitrary order can be collected. Note that if you + collect with respect to a function or a derivative of a function, all + derivatives of that function will also be collected. Use + ``exact=True`` to prevent this from happening:: + + >>> from sympy import Derivative as D, collect, Function + >>> f = Function('f') (x) + + >>> collect(a*D(f,x) + b*D(f,x), D(f,x)) + (a + b)*Derivative(f(x), x) + + >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f) + (a + b)*Derivative(f(x), (x, 2)) + + >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True) + a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2)) + + >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f) + (a + b)*f(x) + (a + b)*Derivative(f(x), x) + + Or you can even match both derivative order and exponent at the same time:: + + >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x)) + (a + b)*Derivative(f(x), (x, 2))**2 + + Finally, you can apply a function to each of the collected coefficients. + For example you can factorize symbolic coefficients of polynomial:: + + >>> f = expand((x + a + 1)**3) + + >>> collect(f, x, factor) + x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3 + + .. note:: Arguments are expected to be in expanded form, so you might have + to call :func:`~.expand` prior to calling this function. + + See Also + ======== + + collect_const, collect_sqrt, rcollect + """ + expr = sympify(expr) + syms = [sympify(i) for i in (syms if iterable(syms) else [syms])] + + # replace syms[i] if it is not x, -x or has Wild symbols + cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool( + x.atoms(Wild)) + _, nonsyms = sift(syms, cond, binary=True) + if nonsyms: + reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms])) + syms = [reps.get(s, s) for s in syms] + rv = collect(expr.subs(reps), syms, + func=func, evaluate=evaluate, exact=exact, + distribute_order_term=distribute_order_term) + urep = {v: k for k, v in reps.items()} + if not isinstance(rv, dict): + return rv.xreplace(urep) + else: + return {urep.get(k, k).xreplace(urep): v.xreplace(urep) + for k, v in rv.items()} + + # see if other expressions should be considered + if exact is None: + _syms = set() + for i in Add.make_args(expr): + if not i.has_free(*syms) or i in syms: + continue + if not i.is_Mul and i not in syms: + _syms.add(i) + else: + # identify compound generators + g = i._new_rawargs(*i.as_coeff_mul(*syms)[1]) + if g not in syms: + _syms.add(g) + simple = all(i.is_Pow and i.base in syms for i in _syms) + syms = syms + list(ordered(_syms)) + if not simple: + return collect(expr, syms, + func=func, evaluate=evaluate, exact=False, + distribute_order_term=distribute_order_term) + + if evaluate is None: + evaluate = global_parameters.evaluate + + def make_expression(terms): + product = [] + + for term, rat, sym, deriv in terms: + if deriv is not None: + var, order = deriv + for _ in range(order): + term = Derivative(term, var) + + if sym is None: + if rat is S.One: + product.append(term) + else: + product.append(Pow(term, rat)) + else: + product.append(Pow(term, rat*sym)) + + return Mul(*product) + + def parse_derivative(deriv): + # scan derivatives tower in the input expression and return + # underlying function and maximal differentiation order + expr, sym, order = deriv.expr, deriv.variables[0], 1 + + for s in deriv.variables[1:]: + if s == sym: + order += 1 + else: + raise NotImplementedError( + 'Improve MV Derivative support in collect') + + while isinstance(expr, Derivative): + s0 = expr.variables[0] + + if any(s != s0 for s in expr.variables): + raise NotImplementedError( + 'Improve MV Derivative support in collect') + + if s0 == sym: + expr, order = expr.expr, order + len(expr.variables) + else: + break + + return expr, (sym, Rational(order)) + + def parse_term(expr): + """Parses expression expr and outputs tuple (sexpr, rat_expo, + sym_expo, deriv) + where: + - sexpr is the base expression + - rat_expo is the rational exponent that sexpr is raised to + - sym_expo is the symbolic exponent that sexpr is raised to + - deriv contains the derivatives of the expression + + For example, the output of x would be (x, 1, None, None) + the output of 2**x would be (2, 1, x, None). + """ + rat_expo, sym_expo = S.One, None + sexpr, deriv = expr, None + + if expr.is_Pow: + if isinstance(expr.base, Derivative): + sexpr, deriv = parse_derivative(expr.base) + else: + sexpr = expr.base + + if expr.base == S.Exp1: + arg = expr.exp + if arg.is_Rational: + sexpr, rat_expo = S.Exp1, arg + elif arg.is_Mul: + coeff, tail = arg.as_coeff_Mul(rational=True) + sexpr, rat_expo = exp(tail), coeff + + elif expr.exp.is_Number: + rat_expo = expr.exp + else: + coeff, tail = expr.exp.as_coeff_Mul() + + if coeff.is_Number: + rat_expo, sym_expo = coeff, tail + else: + sym_expo = expr.exp + elif isinstance(expr, exp): + arg = expr.exp + if arg.is_Rational: + sexpr, rat_expo = S.Exp1, arg + elif arg.is_Mul: + coeff, tail = arg.as_coeff_Mul(rational=True) + sexpr, rat_expo = exp(tail), coeff + elif isinstance(expr, Derivative): + sexpr, deriv = parse_derivative(expr) + + return sexpr, rat_expo, sym_expo, deriv + + def parse_expression(terms, pattern): + """Parse terms searching for a pattern. + Terms is a list of tuples as returned by parse_terms; + Pattern is an expression treated as a product of factors. + """ + pattern = Mul.make_args(pattern) + + if len(terms) < len(pattern): + # pattern is longer than matched product + # so no chance for positive parsing result + return None + else: + pattern = [parse_term(elem) for elem in pattern] + + terms = terms[:] # need a copy + elems, common_expo, has_deriv = [], None, False + + for elem, e_rat, e_sym, e_ord in pattern: + + if elem.is_Number and e_rat == 1 and e_sym is None: + # a constant is a match for everything + continue + + for j in range(len(terms)): + if terms[j] is None: + continue + + term, t_rat, t_sym, t_ord = terms[j] + + # keeping track of whether one of the terms had + # a derivative or not as this will require rebuilding + # the expression later + if t_ord is not None: + has_deriv = True + + if (term.match(elem) is not None and + (t_sym == e_sym or t_sym is not None and + e_sym is not None and + t_sym.match(e_sym) is not None)): + if exact is False: + # we don't have to be exact so find common exponent + # for both expression's term and pattern's element + expo = t_rat / e_rat + + if common_expo is None: + # first time + common_expo = expo + else: + # common exponent was negotiated before so + # there is no chance for a pattern match unless + # common and current exponents are equal + if common_expo != expo: + common_expo = 1 + else: + # we ought to be exact so all fields of + # interest must match in every details + if e_rat != t_rat or e_ord != t_ord: + continue + + # found common term so remove it from the expression + # and try to match next element in the pattern + elems.append(terms[j]) + terms[j] = None + + break + + else: + # pattern element not found + return None + + return [_f for _f in terms if _f], elems, common_expo, has_deriv + + if evaluate: + if expr.is_Add: + o = expr.getO() or 0 + expr = expr.func(*[ + collect(a, syms, func, True, exact, distribute_order_term) + for a in expr.args if a != o]) + o + elif expr.is_Mul: + return expr.func(*[ + collect(term, syms, func, True, exact, distribute_order_term) + for term in expr.args]) + elif expr.is_Pow: + b = collect( + expr.base, syms, func, True, exact, distribute_order_term) + return Pow(b, expr.exp) + + syms = [expand_power_base(i, deep=False) for i in syms] + + order_term = None + + if distribute_order_term: + order_term = expr.getO() + + if order_term is not None: + if order_term.has(*syms): + order_term = None + else: + expr = expr.removeO() + + summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)] + + collected, disliked = defaultdict(list), S.Zero + for product in summa: + c, nc = product.args_cnc(split_1=False) + args = list(ordered(c)) + nc + terms = [parse_term(i) for i in args] + small_first = True + + for symbol in syms: + if isinstance(symbol, Derivative) and small_first: + terms = list(reversed(terms)) + small_first = not small_first + result = parse_expression(terms, symbol) + + if result is not None: + if not symbol.is_commutative: + raise AttributeError("Can not collect noncommutative symbol") + + terms, elems, common_expo, has_deriv = result + + # when there was derivative in current pattern we + # will need to rebuild its expression from scratch + if not has_deriv: + margs = [] + for elem in elems: + if elem[2] is None: + e = elem[1] + else: + e = elem[1]*elem[2] + margs.append(Pow(elem[0], e)) + index = Mul(*margs) + else: + index = make_expression(elems) + terms = expand_power_base(make_expression(terms), deep=False) + index = expand_power_base(index, deep=False) + collected[index].append(terms) + break + else: + # none of the patterns matched + disliked += product + # add terms now for each key + collected = {k: Add(*v) for k, v in collected.items()} + + if disliked is not S.Zero: + collected[S.One] = disliked + + if order_term is not None: + for key, val in collected.items(): + collected[key] = val + order_term + + if func is not None: + collected = { + key: func(val) for key, val in collected.items()} + + if evaluate: + return Add(*[key*val for key, val in collected.items()]) + else: + return collected + + +def rcollect(expr, *vars): + """ + Recursively collect sums in an expression. + + Examples + ======== + + >>> from sympy.simplify import rcollect + >>> from sympy.abc import x, y + + >>> expr = (x**2*y + x*y + x + y)/(x + y) + + >>> rcollect(expr, y) + (x + y*(x**2 + x + 1))/(x + y) + + See Also + ======== + + collect, collect_const, collect_sqrt + """ + if expr.is_Atom or not expr.has(*vars): + return expr + else: + expr = expr.__class__(*[rcollect(arg, *vars) for arg in expr.args]) + + if expr.is_Add: + return collect(expr, vars) + else: + return expr + + +def collect_sqrt(expr, evaluate=None): + """Return expr with terms having common square roots collected together. + If ``evaluate`` is False a count indicating the number of sqrt-containing + terms will be returned and, if non-zero, the terms of the Add will be + returned, else the expression itself will be returned as a single term. + If ``evaluate`` is True, the expression with any collected terms will be + returned. + + Note: since I = sqrt(-1), it is collected, too. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import collect_sqrt + >>> from sympy.abc import a, b + + >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]] + >>> collect_sqrt(a*r2 + b*r2) + sqrt(2)*(a + b) + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3) + sqrt(2)*(a + b) + sqrt(3)*(a + b) + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5) + sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b) + + If evaluate is False then the arguments will be sorted and + returned as a list and a count of the number of sqrt-containing + terms will be returned: + + >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False) + ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3) + >>> collect_sqrt(a*sqrt(2) + b, evaluate=False) + ((b, sqrt(2)*a), 1) + >>> collect_sqrt(a + b, evaluate=False) + ((a + b,), 0) + + See Also + ======== + + collect, collect_const, rcollect + """ + if evaluate is None: + evaluate = global_parameters.evaluate + # this step will help to standardize any complex arguments + # of sqrts + coeff, expr = expr.as_content_primitive() + vars = set() + for a in Add.make_args(expr): + for m in a.args_cnc()[0]: + if m.is_number and ( + m.is_Pow and m.exp.is_Rational and m.exp.q == 2 or + m is S.ImaginaryUnit): + vars.add(m) + + # we only want radicals, so exclude Number handling; in this case + # d will be evaluated + d = collect_const(expr, *vars, Numbers=False) + hit = expr != d + + if not evaluate: + nrad = 0 + # make the evaluated args canonical + args = list(ordered(Add.make_args(d))) + for i, m in enumerate(args): + c, nc = m.args_cnc() + for ci in c: + # XXX should this be restricted to ci.is_number as above? + if ci.is_Pow and ci.exp.is_Rational and ci.exp.q == 2 or \ + ci is S.ImaginaryUnit: + nrad += 1 + break + args[i] *= coeff + if not (hit or nrad): + args = [Add(*args)] + return tuple(args), nrad + + return coeff*d + + +def collect_abs(expr): + """Return ``expr`` with arguments of multiple Abs in a term collected + under a single instance. + + Examples + ======== + + >>> from sympy.simplify.radsimp import collect_abs + >>> from sympy.abc import x + >>> collect_abs(abs(x + 1)/abs(x**2 - 1)) + Abs((x + 1)/(x**2 - 1)) + >>> collect_abs(abs(1/x)) + Abs(1/x) + """ + def _abs(mul): + c, nc = mul.args_cnc() + a = [] + o = [] + for i in c: + if isinstance(i, Abs): + a.append(i.args[0]) + elif isinstance(i, Pow) and isinstance(i.base, Abs) and i.exp.is_real: + a.append(i.base.args[0]**i.exp) + else: + o.append(i) + if len(a) < 2 and not any(i.exp.is_negative for i in a if isinstance(i, Pow)): + return mul + absarg = Mul(*a) + A = Abs(absarg) + args = [A] + args.extend(o) + if not A.has(Abs): + args.extend(nc) + return Mul(*args) + if not isinstance(A, Abs): + # reevaluate and make it unevaluated + A = Abs(absarg, evaluate=False) + args[0] = A + _mulsort(args) + args.extend(nc) # nc always go last + return Mul._from_args(args, is_commutative=not nc) + + return expr.replace( + lambda x: isinstance(x, Mul), + lambda x: _abs(x)).replace( + lambda x: isinstance(x, Pow), + lambda x: _abs(x)) + + +def collect_const(expr, *vars, Numbers=True): + """A non-greedy collection of terms with similar number coefficients in + an Add expr. If ``vars`` is given then only those constants will be + targeted. Although any Number can also be targeted, if this is not + desired set ``Numbers=False`` and no Float or Rational will be collected. + + Parameters + ========== + + expr : SymPy expression + This parameter defines the expression the expression from which + terms with similar coefficients are to be collected. A non-Add + expression is returned as it is. + + vars : variable length collection of Numbers, optional + Specifies the constants to target for collection. Can be multiple in + number. + + Numbers : bool + Specifies to target all instance of + :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then + no Float or Rational will be collected. + + Returns + ======= + + expr : Expr + Returns an expression with similar coefficient terms collected. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.abc import s, x, y, z + >>> from sympy.simplify.radsimp import collect_const + >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2))) + sqrt(3)*(sqrt(2) + 2) + >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7)) + (sqrt(3) + sqrt(7))*(s + 1) + >>> s = sqrt(2) + 2 + >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7)) + (sqrt(2) + 3)*(sqrt(3) + sqrt(7)) + >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3)) + sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2) + + The collection is sign-sensitive, giving higher precedence to the + unsigned values: + + >>> collect_const(x - y - z) + x - (y + z) + >>> collect_const(-y - z) + -(y + z) + >>> collect_const(2*x - 2*y - 2*z, 2) + 2*(x - y - z) + >>> collect_const(2*x - 2*y - 2*z, -2) + 2*x - 2*(y + z) + + See Also + ======== + + collect, collect_sqrt, rcollect + """ + if not expr.is_Add: + return expr + + recurse = False + + if not vars: + recurse = True + vars = set() + for a in expr.args: + for m in Mul.make_args(a): + if m.is_number: + vars.add(m) + else: + vars = sympify(vars) + if not Numbers: + vars = [v for v in vars if not v.is_Number] + + vars = list(ordered(vars)) + for v in vars: + terms = defaultdict(list) + Fv = Factors(v) + for m in Add.make_args(expr): + f = Factors(m) + q, r = f.div(Fv) + if r.is_one: + # only accept this as a true factor if + # it didn't change an exponent from an Integer + # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2) + # -- we aren't looking for this sort of change + fwas = f.factors.copy() + fnow = q.factors + if not any(k in fwas and fwas[k].is_Integer and not + fnow[k].is_Integer for k in fnow): + terms[v].append(q.as_expr()) + continue + terms[S.One].append(m) + + args = [] + hit = False + uneval = False + for k in ordered(terms): + v = terms[k] + if k is S.One: + args.extend(v) + continue + + if len(v) > 1: + v = Add(*v) + hit = True + if recurse and v != expr: + vars.append(v) + else: + v = v[0] + + # be careful not to let uneval become True unless + # it must be because it's going to be more expensive + # to rebuild the expression as an unevaluated one + if Numbers and k.is_Number and v.is_Add: + args.append(_keep_coeff(k, v, sign=True)) + uneval = True + else: + args.append(k*v) + + if hit: + if uneval: + expr = _unevaluated_Add(*args) + else: + expr = Add(*args) + if not expr.is_Add: + break + + return expr + + +def radsimp(expr, symbolic=True, max_terms=4): + r""" + Rationalize the denominator by removing square roots. + + Explanation + =========== + + The expression returned from radsimp must be used with caution + since if the denominator contains symbols, it will be possible to make + substitutions that violate the assumptions of the simplification process: + that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If + there are no symbols, this assumptions is made valid by collecting terms + of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If + you do not want the simplification to occur for symbolic denominators, set + ``symbolic`` to False. + + If there are more than ``max_terms`` radical terms then the expression is + returned unchanged. + + Examples + ======== + + >>> from sympy import radsimp, sqrt, Symbol, pprint + >>> from sympy import factor_terms, fraction, signsimp + >>> from sympy.simplify.radsimp import collect_sqrt + >>> from sympy.abc import a, b, c + + >>> radsimp(1/(2 + sqrt(2))) + (2 - sqrt(2))/2 + >>> x,y = map(Symbol, 'xy') + >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2)) + >>> radsimp(e) + sqrt(2)*(x + y) + + No simplification beyond removal of the gcd is done. One might + want to polish the result a little, however, by collecting + square root terms: + + >>> r2 = sqrt(2) + >>> r5 = sqrt(5) + >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans) + ___ ___ ___ ___ + \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y + ------------------------------------------ + 2 2 2 2 + 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y + + >>> n, d = fraction(ans) + >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True)) + ___ ___ + \/ 5 *(a + b) - \/ 2 *(x + y) + ------------------------------------------ + 2 2 2 2 + 5*a + 10*a*b + 5*b - 2*x - 4*x*y - 2*y + + If radicals in the denominator cannot be removed or there is no denominator, + the original expression will be returned. + + >>> radsimp(sqrt(2)*x + sqrt(2)) + sqrt(2)*x + sqrt(2) + + Results with symbols will not always be valid for all substitutions: + + >>> eq = 1/(a + b*sqrt(c)) + >>> eq.subs(a, b*sqrt(c)) + 1/(2*b*sqrt(c)) + >>> radsimp(eq).subs(a, b*sqrt(c)) + nan + + If ``symbolic=False``, symbolic denominators will not be transformed (but + numeric denominators will still be processed): + + >>> radsimp(eq, symbolic=False) + 1/(a + b*sqrt(c)) + + """ + from sympy.core.expr import Expr + from sympy.simplify.simplify import signsimp + + syms = symbols("a:d A:D") + def _num(rterms): + # return the multiplier that will simplify the expression described + # by rterms [(sqrt arg, coeff), ... ] + a, b, c, d, A, B, C, D = syms + if len(rterms) == 2: + reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i]))) + return ( + sqrt(A)*a - sqrt(B)*b).xreplace(reps) + if len(rterms) == 3: + reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i]))) + return ( + (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 - + B*b**2 + C*c**2)).xreplace(reps) + elif len(rterms) == 4: + reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i]))) + return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b + - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 + + D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 - + 2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 - + 2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 + + D**2*d**4)).xreplace(reps) + elif len(rterms) == 1: + return sqrt(rterms[0][0]) + else: + raise NotImplementedError + + def ispow2(d, log2=False): + if not d.is_Pow: + return False + e = d.exp + if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2: + return True + if log2: + q = 1 + if e.is_Rational: + q = e.q + elif symbolic: + d = denom(e) + if d.is_Integer: + q = d + if q != 1 and log(q, 2).is_Integer: + return True + return False + + def handle(expr): + # Handle first reduces to the case + # expr = 1/d, where d is an add, or d is base**p/2. + # We do this by recursively calling handle on each piece. + from sympy.simplify.simplify import nsimplify + + if expr.is_Atom: + return expr + elif not isinstance(expr, Expr): + return expr.func(*[handle(a) for a in expr.args]) + + n, d = fraction(expr) + + if d.is_Atom and n.is_Atom: + return expr + elif not n.is_Atom: + n = n.func(*[handle(a) for a in n.args]) + return _unevaluated_Mul(n, handle(1/d)) + elif n is not S.One: + return _unevaluated_Mul(n, handle(1/d)) + elif d.is_Mul: + return _unevaluated_Mul(*[handle(1/d) for d in d.args]) + + # By this step, expr is 1/d, and d is not a mul. + if not symbolic and d.free_symbols: + return expr + + if ispow2(d): + d2 = sqrtdenest(sqrt(d.base))**numer(d.exp) + if d2 != d: + return handle(1/d2) + elif d.is_Pow and (d.exp.is_integer or d.base.is_positive): + # (1/d**i) = (1/d)**i + return handle(1/d.base)**d.exp + + if not (d.is_Add or ispow2(d)): + return 1/d.func(*[handle(a) for a in d.args]) + + # handle 1/d treating d as an Add (though it may not be) + + keep = True # keep changes that are made + + # flatten it and collect radicals after checking for special + # conditions + d = _mexpand(d) + + # did it change? + if d.is_Atom: + return 1/d + + # is it a number that might be handled easily? + if d.is_number: + _d = nsimplify(d) + if _d.is_Number and _d.equals(d): + return 1/_d + + while True: + # collect similar terms + collected = defaultdict(list) + for m in Add.make_args(d): # d might have become non-Add + p2 = [] + other = [] + for i in Mul.make_args(m): + if ispow2(i, log2=True): + p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp)) + elif i is S.ImaginaryUnit: + p2.append(S.NegativeOne) + else: + other.append(i) + collected[tuple(ordered(p2))].append(Mul(*other)) + rterms = list(ordered(list(collected.items()))) + rterms = [(Mul(*i), Add(*j)) for i, j in rterms] + nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0) + if nrad < 1: + break + elif nrad > max_terms: + # there may have been invalid operations leading to this point + # so don't keep changes, e.g. this expression is troublesome + # in collecting terms so as not to raise the issue of 2834: + # r = sqrt(sqrt(5) + 5) + # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r) + keep = False + break + if len(rterms) > 4: + # in general, only 4 terms can be removed with repeated squaring + # but other considerations can guide selection of radical terms + # so that radicals are removed + if all(x.is_Integer and (y**2).is_Rational for x, y in rterms): + nd, d = rad_rationalize(S.One, Add._from_args( + [sqrt(x)*y for x, y in rterms])) + n *= nd + else: + # is there anything else that might be attempted? + keep = False + break + from sympy.simplify.powsimp import powsimp, powdenest + + num = powsimp(_num(rterms)) + n *= num + d *= num + d = powdenest(_mexpand(d), force=symbolic) + if d.has(S.Zero, nan, zoo): + return expr + if d.is_Atom: + break + + if not keep: + return expr + return _unevaluated_Mul(n, 1/d) + + if not isinstance(expr, Expr): + return expr.func(*[radsimp(a, symbolic=symbolic, max_terms=max_terms) for a in expr.args]) + + coeff, expr = expr.as_coeff_Add() + expr = expr.normal() + old = fraction(expr) + n, d = fraction(handle(expr)) + if old != (n, d): + if not d.is_Atom: + was = (n, d) + n = signsimp(n, evaluate=False) + d = signsimp(d, evaluate=False) + u = Factors(_unevaluated_Mul(n, 1/d)) + u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()]) + n, d = fraction(u) + if old == (n, d): + n, d = was + n = expand_mul(n) + if d.is_Number or d.is_Add: + n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d))) + if d2.is_Number or (d2.count_ops() <= d.count_ops()): + n, d = [signsimp(i) for i in (n2, d2)] + if n.is_Mul and n.args[0].is_Number: + n = n.func(*n.args) + + return coeff + _unevaluated_Mul(n, 1/d) + + +def rad_rationalize(num, den): + """ + Rationalize ``num/den`` by removing square roots in the denominator; + num and den are sum of terms whose squares are positive rationals. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import rad_rationalize + >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3) + (-sqrt(3) + sqrt(6)/3, -7/9) + """ + if not den.is_Add: + return num, den + g, a, b = split_surds(den) + a = a*sqrt(g) + num = _mexpand((a - b)*num) + den = _mexpand(a**2 - b**2) + return rad_rationalize(num, den) + + +def fraction(expr, exact=False): + """Returns a pair with expression's numerator and denominator. + If the given expression is not a fraction then this function + will return the tuple (expr, 1). + + This function will not make any attempt to simplify nested + fractions or to do any term rewriting at all. + + If only one of the numerator/denominator pair is needed then + use numer(expr) or denom(expr) functions respectively. + + >>> from sympy import fraction, Rational, Symbol + >>> from sympy.abc import x, y + + >>> fraction(x/y) + (x, y) + >>> fraction(x) + (x, 1) + + >>> fraction(1/y**2) + (1, y**2) + + >>> fraction(x*y/2) + (x*y, 2) + >>> fraction(Rational(1, 2)) + (1, 2) + + This function will also work fine with assumptions: + + >>> k = Symbol('k', negative=True) + >>> fraction(x * y**k) + (x, y**(-k)) + + If we know nothing about sign of some exponent and ``exact`` + flag is unset, then the exponent's structure will + be analyzed and pretty fraction will be returned: + + >>> from sympy import exp, Mul + >>> fraction(2*x**(-y)) + (2, x**y) + + >>> fraction(exp(-x)) + (1, exp(x)) + + >>> fraction(exp(-x), exact=True) + (exp(-x), 1) + + The ``exact`` flag will also keep any unevaluated Muls from + being evaluated: + + >>> u = Mul(2, x + 1, evaluate=False) + >>> fraction(u) + (2*x + 2, 1) + >>> fraction(u, exact=True) + (2*(x + 1), 1) + """ + expr = sympify(expr) + + numer, denom = [], [] + + for term in Mul.make_args(expr): + if term.is_commutative and (term.is_Pow or isinstance(term, exp)): + b, ex = term.as_base_exp() + if ex.is_negative: + if ex is S.NegativeOne: + denom.append(b) + elif exact: + if ex.is_constant(): + denom.append(Pow(b, -ex)) + else: + numer.append(term) + else: + denom.append(Pow(b, -ex)) + elif ex.is_positive: + numer.append(term) + elif not exact and ex.is_Mul: + n, d = term.as_numer_denom() # this will cause evaluation + if n != 1: + numer.append(n) + denom.append(d) + else: + numer.append(term) + elif term.is_Rational and not term.is_Integer: + if term.p != 1: + numer.append(term.p) + denom.append(term.q) + else: + numer.append(term) + return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact) + + +def numer(expr, exact=False): # default matches fraction's default + return fraction(expr, exact=exact)[0] + + +def denom(expr, exact=False): # default matches fraction's default + return fraction(expr, exact=exact)[1] + + +def fraction_expand(expr, **hints): + return expr.expand(frac=True, **hints) + + +def numer_expand(expr, **hints): + # default matches fraction's default + a, b = fraction(expr, exact=hints.get('exact', False)) + return a.expand(numer=True, **hints) / b + + +def denom_expand(expr, **hints): + # default matches fraction's default + a, b = fraction(expr, exact=hints.get('exact', False)) + return a / b.expand(denom=True, **hints) + + +expand_numer = numer_expand +expand_denom = denom_expand +expand_fraction = fraction_expand + + +def split_surds(expr): + """ + Split an expression with terms whose squares are positive rationals + into a sum of terms whose surds squared have gcd equal to g + and a sum of terms with surds squared prime with g. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.radsimp import split_surds + >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15)) + (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10)) + """ + args = sorted(expr.args, key=default_sort_key) + coeff_muls = [x.as_coeff_Mul() for x in args] + surds = [x[1]**2 for x in coeff_muls if x[1].is_Pow] + surds.sort(key=default_sort_key) + g, b1, b2 = _split_gcd(*surds) + g2 = g + if not b2 and len(b1) >= 2: + b1n = [x/g for x in b1] + b1n = [x for x in b1n if x != 1] + # only a common factor has been factored; split again + g1, b1n, b2 = _split_gcd(*b1n) + g2 = g*g1 + a1v, a2v = [], [] + for c, s in coeff_muls: + if s.is_Pow and s.exp == S.Half: + s1 = s.base + if s1 in b1: + a1v.append(c*sqrt(s1/g2)) + else: + a2v.append(c*s) + else: + a2v.append(c*s) + a = Add(*a1v) + b = Add(*a2v) + return g2, a, b + + +def _split_gcd(*a): + """ + Split the list of integers ``a`` into a list of integers, ``a1`` having + ``g = gcd(a1)``, and a list ``a2`` whose elements are not divisible by + ``g``. Returns ``g, a1, a2``. + + Examples + ======== + + >>> from sympy.simplify.radsimp import _split_gcd + >>> _split_gcd(55, 35, 22, 14, 77, 10) + (5, [55, 35, 10], [22, 14, 77]) + """ + g = a[0] + b1 = [g] + b2 = [] + for x in a[1:]: + g1 = gcd(g, x) + if g1 == 1: + b2.append(x) + else: + g = g1 + b1.append(x) + return g, b1, b2 diff --git a/phivenv/Lib/site-packages/sympy/simplify/ratsimp.py b/phivenv/Lib/site-packages/sympy/simplify/ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..95751fab47f585d3ae2e1289f014fba0f2708224 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/ratsimp.py @@ -0,0 +1,222 @@ +from itertools import combinations_with_replacement +from sympy.core import symbols, Add, Dummy +from sympy.core.numbers import Rational +from sympy.polys import cancel, ComputationFailed, parallel_poly_from_expr, reduced, Poly +from sympy.polys.monomials import Monomial, monomial_div +from sympy.polys.polyerrors import DomainError, PolificationFailed +from sympy.utilities.misc import debug, debugf + +def ratsimp(expr): + """ + Put an expression over a common denominator, cancel and reduce. + + Examples + ======== + + >>> from sympy import ratsimp + >>> from sympy.abc import x, y + >>> ratsimp(1/x + 1/y) + (x + y)/(x*y) + """ + + f, g = cancel(expr).as_numer_denom() + try: + Q, r = reduced(f, [g], field=True, expand=False) + except ComputationFailed: + return f/g + + return Add(*Q) + cancel(r/g) + + +def ratsimpmodprime(expr, G, *gens, quick=True, polynomial=False, **args): + """ + Simplifies a rational expression ``expr`` modulo the prime ideal + generated by ``G``. ``G`` should be a Groebner basis of the + ideal. + + Examples + ======== + + >>> from sympy.simplify.ratsimp import ratsimpmodprime + >>> from sympy.abc import x, y + >>> eq = (x + y**5 + y)/(x - y) + >>> ratsimpmodprime(eq, [x*y**5 - x - y], x, y, order='lex') + (-x**2 - x*y - x - y)/(-x**2 + x*y) + + If ``polynomial`` is ``False``, the algorithm computes a rational + simplification which minimizes the sum of the total degrees of + the numerator and the denominator. + + If ``polynomial`` is ``True``, this function just brings numerator and + denominator into a canonical form. This is much faster, but has + potentially worse results. + + References + ========== + + .. [1] M. Monagan, R. Pearce, Rational Simplification Modulo a Polynomial + Ideal, https://dl.acm.org/doi/pdf/10.1145/1145768.1145809 + (specifically, the second algorithm) + """ + from sympy.solvers.solvers import solve + + debug('ratsimpmodprime', expr) + + # usual preparation of polynomials: + + num, denom = cancel(expr).as_numer_denom() + + try: + polys, opt = parallel_poly_from_expr([num, denom] + G, *gens, **args) + except PolificationFailed: + return expr + + domain = opt.domain + + if domain.has_assoc_Field: + opt.domain = domain.get_field() + else: + raise DomainError( + "Cannot compute rational simplification over %s" % domain) + + # compute only once + leading_monomials = [g.LM(opt.order) for g in polys[2:]] + tested = set() + + def staircase(n): + """ + Compute all monomials with degree less than ``n`` that are + not divisible by any element of ``leading_monomials``. + """ + if n == 0: + return [1] + S = [] + for mi in combinations_with_replacement(range(len(opt.gens)), n): + m = [0]*len(opt.gens) + for i in mi: + m[i] += 1 + if all(monomial_div(m, lmg) is None for lmg in + leading_monomials): + S.append(m) + + return [Monomial(s).as_expr(*opt.gens) for s in S] + staircase(n - 1) + + def _ratsimpmodprime(a, b, allsol, N=0, D=0): + r""" + Computes a rational simplification of ``a/b`` which minimizes + the sum of the total degrees of the numerator and the denominator. + + Explanation + =========== + + The algorithm proceeds by looking at ``a * d - b * c`` modulo + the ideal generated by ``G`` for some ``c`` and ``d`` with degree + less than ``a`` and ``b`` respectively. + The coefficients of ``c`` and ``d`` are indeterminates and thus + the coefficients of the normalform of ``a * d - b * c`` are + linear polynomials in these indeterminates. + If these linear polynomials, considered as system of + equations, have a nontrivial solution, then `\frac{a}{b} + \equiv \frac{c}{d}` modulo the ideal generated by ``G``. So, + by construction, the degree of ``c`` and ``d`` is less than + the degree of ``a`` and ``b``, so a simpler representation + has been found. + After a simpler representation has been found, the algorithm + tries to reduce the degree of the numerator and denominator + and returns the result afterwards. + + As an extension, if quick=False, we look at all possible degrees such + that the total degree is less than *or equal to* the best current + solution. We retain a list of all solutions of minimal degree, and try + to find the best one at the end. + """ + c, d = a, b + steps = 0 + + maxdeg = a.total_degree() + b.total_degree() + if quick: + bound = maxdeg - 1 + else: + bound = maxdeg + while N + D <= bound: + if (N, D) in tested: + break + tested.add((N, D)) + + M1 = staircase(N) + M2 = staircase(D) + debugf('%s / %s: %s, %s', (N, D, M1, M2)) + + Cs = symbols("c:%d" % len(M1), cls=Dummy) + Ds = symbols("d:%d" % len(M2), cls=Dummy) + ng = Cs + Ds + + c_hat = Poly( + sum(Cs[i] * M1[i] for i in range(len(M1))), opt.gens + ng) + d_hat = Poly( + sum(Ds[i] * M2[i] for i in range(len(M2))), opt.gens + ng) + + r = reduced(a * d_hat - b * c_hat, G, opt.gens + ng, + order=opt.order, polys=True)[1] + + S = Poly(r, gens=opt.gens).coeffs() + sol = solve(S, Cs + Ds, particular=True, quick=True) + + if sol and not all(s == 0 for s in sol.values()): + c = c_hat.subs(sol) + d = d_hat.subs(sol) + + # The "free" variables occurring before as parameters + # might still be in the substituted c, d, so set them + # to the value chosen before: + c = c.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds)))))) + d = d.subs(dict(list(zip(Cs + Ds, [1] * (len(Cs) + len(Ds)))))) + + c = Poly(c, opt.gens) + d = Poly(d, opt.gens) + if d == 0: + raise ValueError('Ideal not prime?') + + allsol.append((c_hat, d_hat, S, Cs + Ds)) + if N + D != maxdeg: + allsol = [allsol[-1]] + + break + + steps += 1 + N += 1 + D += 1 + + if steps > 0: + c, d, allsol = _ratsimpmodprime(c, d, allsol, N, D - steps) + c, d, allsol = _ratsimpmodprime(c, d, allsol, N - steps, D) + + return c, d, allsol + + # preprocessing. this improves performance a bit when deg(num) + # and deg(denom) are large: + num = reduced(num, G, opt.gens, order=opt.order)[1] + denom = reduced(denom, G, opt.gens, order=opt.order)[1] + + if polynomial: + return (num/denom).cancel() + + c, d, allsol = _ratsimpmodprime( + Poly(num, opt.gens, domain=opt.domain), Poly(denom, opt.gens, domain=opt.domain), []) + if not quick and allsol: + debugf('Looking for best minimal solution. Got: %s', len(allsol)) + newsol = [] + for c_hat, d_hat, S, ng in allsol: + sol = solve(S, ng, particular=True, quick=False) + # all values of sol should be numbers; if not, solve is broken + newsol.append((c_hat.subs(sol), d_hat.subs(sol))) + c, d = min(newsol, key=lambda x: len(x[0].terms()) + len(x[1].terms())) + + if not domain.is_Field: + cn, c = c.clear_denoms(convert=True) + dn, d = d.clear_denoms(convert=True) + r = Rational(cn, dn) + else: + r = Rational(1) + + return (c*r.q)/(d*r.p) diff --git a/phivenv/Lib/site-packages/sympy/simplify/simplify.py b/phivenv/Lib/site-packages/sympy/simplify/simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..8b315cc20c19fc10c37b903d16129a7f5579ecd3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/simplify.py @@ -0,0 +1,2164 @@ +from __future__ import annotations + +from typing import overload + +from collections import defaultdict + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core import (Basic, S, Add, Mul, Pow, Symbol, sympify, + expand_func, Function, Dummy, Expr, factor_terms, + expand_power_exp, Eq) +from sympy.core.exprtools import factor_nc +from sympy.core.parameters import global_parameters +from sympy.core.function import (expand_log, count_ops, _mexpand, + nfloat, expand_mul, expand) +from sympy.core.numbers import Float, I, pi, Rational, equal_valued +from sympy.core.relational import Relational +from sympy.core.rules import Transform +from sympy.core.sorting import ordered +from sympy.core.sympify import _sympify +from sympy.core.traversal import bottom_up as _bottom_up, walk as _walk +from sympy.functions import gamma, exp, sqrt, log, exp_polar, re +from sympy.functions.combinatorial.factorials import CombinatorialFunction +from sympy.functions.elementary.complexes import unpolarify, Abs, sign +from sympy.functions.elementary.exponential import ExpBase +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.piecewise import (Piecewise, piecewise_fold, + piecewise_simplify) +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.functions.special.bessel import (BesselBase, besselj, besseli, + besselk, bessely, jn) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import Boolean +from sympy.matrices.expressions import (MatrixExpr, MatAdd, MatMul, + MatPow, MatrixSymbol) +from sympy.polys import together, cancel, factor +from sympy.polys.numberfields.minpoly import _is_sum_surds, _minimal_polynomial_sq +from sympy.sets.sets import Set +from sympy.simplify.combsimp import combsimp +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.simplify.hyperexpand import hyperexpand +from sympy.simplify.powsimp import powsimp +from sympy.simplify.radsimp import radsimp, fraction, collect_abs +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.trigsimp import trigsimp, exptrigsimp +from sympy.utilities.decorator import deprecated +from sympy.utilities.iterables import has_variety, sift, subsets, iterable +from sympy.utilities.misc import as_int + +import mpmath + + +def separatevars(expr, symbols=[], dict=False, force=False): + """ + Separates variables in an expression, if possible. By + default, it separates with respect to all symbols in an + expression and collects constant coefficients that are + independent of symbols. + + Explanation + =========== + + If ``dict=True`` then the separated terms will be returned + in a dictionary keyed to their corresponding symbols. + By default, all symbols in the expression will appear as + keys; if symbols are provided, then all those symbols will + be used as keys, and any terms in the expression containing + other symbols or non-symbols will be returned keyed to the + string 'coeff'. (Passing None for symbols will return the + expression in a dictionary keyed to 'coeff'.) + + If ``force=True``, then bases of powers will be separated regardless + of assumptions on the symbols involved. + + Notes + ===== + + The order of the factors is determined by Mul, so that the + separated expressions may not necessarily be grouped together. + + Although factoring is necessary to separate variables in some + expressions, it is not necessary in all cases, so one should not + count on the returned factors being factored. + + Examples + ======== + + >>> from sympy.abc import x, y, z, alpha + >>> from sympy import separatevars, sin + >>> separatevars((x*y)**y) + (x*y)**y + >>> separatevars((x*y)**y, force=True) + x**y*y**y + + >>> e = 2*x**2*z*sin(y)+2*z*x**2 + >>> separatevars(e) + 2*x**2*z*(sin(y) + 1) + >>> separatevars(e, symbols=(x, y), dict=True) + {'coeff': 2*z, x: x**2, y: sin(y) + 1} + >>> separatevars(e, [x, y, alpha], dict=True) + {'coeff': 2*z, alpha: 1, x: x**2, y: sin(y) + 1} + + If the expression is not really separable, or is only partially + separable, separatevars will do the best it can to separate it + by using factoring. + + >>> separatevars(x + x*y - 3*x**2) + -x*(3*x - y - 1) + + If the expression is not separable then expr is returned unchanged + or (if dict=True) then None is returned. + + >>> eq = 2*x + y*sin(x) + >>> separatevars(eq) == eq + True + >>> separatevars(2*x + y*sin(x), symbols=(x, y), dict=True) is None + True + + """ + expr = sympify(expr) + if dict: + return _separatevars_dict(_separatevars(expr, force), symbols) + else: + return _separatevars(expr, force) + + +def _separatevars(expr, force): + if isinstance(expr, Abs): + arg = expr.args[0] + if arg.is_Mul and not arg.is_number: + s = separatevars(arg, dict=True, force=force) + if s is not None: + return Mul(*map(expr.func, s.values())) + else: + return expr + + if len(expr.free_symbols) < 2: + return expr + + # don't destroy a Mul since much of the work may already be done + if expr.is_Mul: + args = list(expr.args) + changed = False + for i, a in enumerate(args): + args[i] = separatevars(a, force) + changed = changed or args[i] != a + if changed: + expr = expr.func(*args) + return expr + + # get a Pow ready for expansion + if expr.is_Pow and expr.base != S.Exp1: + expr = Pow(separatevars(expr.base, force=force), expr.exp) + + # First try other expansion methods + expr = expr.expand(mul=False, multinomial=False, force=force) + + _expr, reps = posify(expr) if force else (expr, {}) + expr = factor(_expr).subs(reps) + + if not expr.is_Add: + return expr + + # Find any common coefficients to pull out + args = list(expr.args) + commonc = args[0].args_cnc(cset=True, warn=False)[0] + for i in args[1:]: + commonc &= i.args_cnc(cset=True, warn=False)[0] + commonc = Mul(*commonc) + commonc = commonc.as_coeff_Mul()[1] # ignore constants + commonc_set = commonc.args_cnc(cset=True, warn=False)[0] + + # remove them + for i, a in enumerate(args): + c, nc = a.args_cnc(cset=True, warn=False) + c = c - commonc_set + args[i] = Mul(*c)*Mul(*nc) + nonsepar = Add(*args) + + if len(nonsepar.free_symbols) > 1: + _expr = nonsepar + _expr, reps = posify(_expr) if force else (_expr, {}) + _expr = (factor(_expr)).subs(reps) + + if not _expr.is_Add: + nonsepar = _expr + + return commonc*nonsepar + + +def _separatevars_dict(expr, symbols): + if symbols: + if not all(t.is_Atom for t in symbols): + raise ValueError("symbols must be Atoms.") + symbols = list(symbols) + elif symbols is None: + return {'coeff': expr} + else: + symbols = list(expr.free_symbols) + if not symbols: + return None + + ret = {i: [] for i in symbols + ['coeff']} + + for i in Mul.make_args(expr): + expsym = i.free_symbols + intersection = set(symbols).intersection(expsym) + if len(intersection) > 1: + return None + if len(intersection) == 0: + # There are no symbols, so it is part of the coefficient + ret['coeff'].append(i) + else: + ret[intersection.pop()].append(i) + + # rebuild + for k, v in ret.items(): + ret[k] = Mul(*v) + + return ret + + +def posify(eq): + """Return ``eq`` (with generic symbols made positive) and a + dictionary containing the mapping between the old and new + symbols. + + Explanation + =========== + + Any symbol that has positive=None will be replaced with a positive dummy + symbol having the same name. This replacement will allow more symbolic + processing of expressions, especially those involving powers and + logarithms. + + A dictionary that can be sent to subs to restore ``eq`` to its original + symbols is also returned. + + >>> from sympy import posify, Symbol, log, solve + >>> from sympy.abc import x + >>> posify(x + Symbol('p', positive=True) + Symbol('n', negative=True)) + (_x + n + p, {_x: x}) + + >>> eq = 1/x + >>> log(eq).expand() + log(1/x) + >>> log(posify(eq)[0]).expand() + -log(_x) + >>> p, rep = posify(eq) + >>> log(p).expand().subs(rep) + -log(x) + + It is possible to apply the same transformations to an iterable + of expressions: + + >>> eq = x**2 - 4 + >>> solve(eq, x) + [-2, 2] + >>> eq_x, reps = posify([eq, x]); eq_x + [_x**2 - 4, _x] + >>> solve(*eq_x) + [2] + """ + eq = sympify(eq) + if not isinstance(eq, Basic) and iterable(eq): + f = type(eq) + eq = list(eq) + syms = set() + for e in eq: + syms = syms.union(e.atoms(Symbol)) + reps = {} + for s in syms: + reps.update({v: k for k, v in posify(s)[1].items()}) + for i, e in enumerate(eq): + eq[i] = e.subs(reps) + return f(eq), {r: s for s, r in reps.items()} + + reps = {s: Dummy(s.name, positive=True, **s.assumptions0) + for s in eq.free_symbols if s.is_positive is None} + eq = eq.subs(reps) + return eq, {r: s for s, r in reps.items()} + + +def hypersimp(f, k): + """Given combinatorial term f(k) simplify its consecutive term ratio + i.e. f(k+1)/f(k). The input term can be composed of functions and + integer sequences which have equivalent representation in terms + of gamma special function. + + Explanation + =========== + + The algorithm performs three basic steps: + + 1. Rewrite all functions in terms of gamma, if possible. + + 2. Rewrite all occurrences of gamma in terms of products + of gamma and rising factorial with integer, absolute + constant exponent. + + 3. Perform simplification of nested fractions, powers + and if the resulting expression is a quotient of + polynomials, reduce their total degree. + + If f(k) is hypergeometric then as result we arrive with a + quotient of polynomials of minimal degree. Otherwise None + is returned. + + For more information on the implemented algorithm refer to: + + 1. W. Koepf, Algorithms for m-fold Hypergeometric Summation, + Journal of Symbolic Computation (1995) 20, 399-417 + """ + f = sympify(f) + + g = f.subs(k, k + 1) / f + + g = g.rewrite(gamma) + if g.has(Piecewise): + g = piecewise_fold(g) + g = g.args[-1][0] + g = expand_func(g) + g = powsimp(g, deep=True, combine='exp') + + if g.is_rational_function(k): + return simplify(g, ratio=S.Infinity) + else: + return None + + +def hypersimilar(f, g, k): + """ + Returns True if ``f`` and ``g`` are hyper-similar. + + Explanation + =========== + + Similarity in hypergeometric sense means that a quotient of + f(k) and g(k) is a rational function in ``k``. This procedure + is useful in solving recurrence relations. + + For more information see hypersimp(). + + """ + f, g = list(map(sympify, (f, g))) + + h = (f/g).rewrite(gamma) + h = h.expand(func=True, basic=False) + + return h.is_rational_function(k) + + +def signsimp(expr, evaluate=None): + """Make all Add sub-expressions canonical wrt sign. + + Explanation + =========== + + If an Add subexpression, ``a``, can have a sign extracted, + as determined by could_extract_minus_sign, it is replaced + with Mul(-1, a, evaluate=False). This allows signs to be + extracted from powers and products. + + Examples + ======== + + >>> from sympy import signsimp, exp, symbols + >>> from sympy.abc import x, y + >>> i = symbols('i', odd=True) + >>> n = -1 + 1/x + >>> n/x/(-n)**2 - 1/n/x + (-1 + 1/x)/(x*(1 - 1/x)**2) - 1/(x*(-1 + 1/x)) + >>> signsimp(_) + 0 + >>> x*n + x*-n + x*(-1 + 1/x) + x*(1 - 1/x) + >>> signsimp(_) + 0 + + Since powers automatically handle leading signs + + >>> (-2)**i + -2**i + + signsimp can be used to put the base of a power with an integer + exponent into canonical form: + + >>> n**i + (-1 + 1/x)**i + + By default, signsimp does not leave behind any hollow simplification: + if making an Add canonical wrt sign didn't change the expression, the + original Add is restored. If this is not desired then the keyword + ``evaluate`` can be set to False: + + >>> e = exp(y - x) + >>> signsimp(e) == e + True + >>> signsimp(e, evaluate=False) + exp(-(x - y)) + + """ + if evaluate is None: + evaluate = global_parameters.evaluate + expr = sympify(expr) + if not isinstance(expr, (Expr, Relational)) or expr.is_Atom: + return expr + # get rid of an pre-existing unevaluation regarding sign + e = expr.replace(lambda x: x.is_Mul and -(-x) != x, lambda x: -(-x)) + e = sub_post(sub_pre(e)) + if not isinstance(e, (Expr, Relational)) or e.is_Atom: + return e + if e.is_Add: + rv = e.func(*[signsimp(a) for a in e.args]) + if not evaluate and isinstance(rv, Add + ) and rv.could_extract_minus_sign(): + return Mul(S.NegativeOne, -rv, evaluate=False) + return rv + if evaluate: + e = e.replace(lambda x: x.is_Mul and -(-x) != x, lambda x: -(-x)) + return e + + +@overload +def simplify(expr: Expr, **kwargs) -> Expr: ... +@overload +def simplify(expr: Boolean, **kwargs) -> Boolean: ... +@overload +def simplify(expr: Set, **kwargs) -> Set: ... +@overload +def simplify(expr: Basic, **kwargs) -> Basic: ... + +def simplify(expr, ratio=1.7, measure=count_ops, rational=False, inverse=False, doit=True, **kwargs): + """Simplifies the given expression. + + Explanation + =========== + + Simplification is not a well defined term and the exact strategies + this function tries can change in the future versions of SymPy. If + your algorithm relies on "simplification" (whatever it is), try to + determine what you need exactly - is it powsimp()?, radsimp()?, + together()?, logcombine()?, or something else? And use this particular + function directly, because those are well defined and thus your algorithm + will be robust. + + Nonetheless, especially for interactive use, or when you do not know + anything about the structure of the expression, simplify() tries to apply + intelligent heuristics to make the input expression "simpler". For + example: + + >>> from sympy import simplify, cos, sin + >>> from sympy.abc import x, y + >>> a = (x + x**2)/(x*sin(y)**2 + x*cos(y)**2) + >>> a + (x**2 + x)/(x*sin(y)**2 + x*cos(y)**2) + >>> simplify(a) + x + 1 + + Note that we could have obtained the same result by using specific + simplification functions: + + >>> from sympy import trigsimp, cancel + >>> trigsimp(a) + (x**2 + x)/x + >>> cancel(_) + x + 1 + + In some cases, applying :func:`simplify` may actually result in some more + complicated expression. The default ``ratio=1.7`` prevents more extreme + cases: if (result length)/(input length) > ratio, then input is returned + unmodified. The ``measure`` parameter lets you specify the function used + to determine how complex an expression is. The function should take a + single argument as an expression and return a number such that if + expression ``a`` is more complex than expression ``b``, then + ``measure(a) > measure(b)``. The default measure function is + :func:`~.count_ops`, which returns the total number of operations in the + expression. + + For example, if ``ratio=1``, ``simplify`` output cannot be longer + than input. + + :: + + >>> from sympy import sqrt, simplify, count_ops, oo + >>> root = 1/(sqrt(2)+3) + + Since ``simplify(root)`` would result in a slightly longer expression, + root is returned unchanged instead:: + + >>> simplify(root, ratio=1) == root + True + + If ``ratio=oo``, simplify will be applied anyway:: + + >>> count_ops(simplify(root, ratio=oo)) > count_ops(root) + True + + Note that the shortest expression is not necessary the simplest, so + setting ``ratio`` to 1 may not be a good idea. + Heuristically, the default value ``ratio=1.7`` seems like a reasonable + choice. + + You can easily define your own measure function based on what you feel + should represent the "size" or "complexity" of the input expression. Note + that some choices, such as ``lambda expr: len(str(expr))`` may appear to be + good metrics, but have other problems (in this case, the measure function + may slow down simplify too much for very large expressions). If you do not + know what a good metric would be, the default, ``count_ops``, is a good + one. + + For example: + + >>> from sympy import symbols, log + >>> a, b = symbols('a b', positive=True) + >>> g = log(a) + log(b) + log(a)*log(1/b) + >>> h = simplify(g) + >>> h + log(a*b**(1 - log(a))) + >>> count_ops(g) + 8 + >>> count_ops(h) + 5 + + So you can see that ``h`` is simpler than ``g`` using the count_ops metric. + However, we may not like how ``simplify`` (in this case, using + ``logcombine``) has created the ``b**(log(1/a) + 1)`` term. A simple way + to reduce this would be to give more weight to powers as operations in + ``count_ops``. We can do this by using the ``visual=True`` option: + + >>> print(count_ops(g, visual=True)) + 2*ADD + DIV + 4*LOG + MUL + >>> print(count_ops(h, visual=True)) + 2*LOG + MUL + POW + SUB + + >>> from sympy import Symbol, S + >>> def my_measure(expr): + ... POW = Symbol('POW') + ... # Discourage powers by giving POW a weight of 10 + ... count = count_ops(expr, visual=True).subs(POW, 10) + ... # Every other operation gets a weight of 1 (the default) + ... count = count.replace(Symbol, type(S.One)) + ... return count + >>> my_measure(g) + 8 + >>> my_measure(h) + 14 + >>> 15./8 > 1.7 # 1.7 is the default ratio + True + >>> simplify(g, measure=my_measure) + -log(a)*log(b) + log(a) + log(b) + + Note that because ``simplify()`` internally tries many different + simplification strategies and then compares them using the measure + function, we get a completely different result that is still different + from the input expression by doing this. + + If ``rational=True``, Floats will be recast as Rationals before simplification. + If ``rational=None``, Floats will be recast as Rationals but the result will + be recast as Floats. If rational=False(default) then nothing will be done + to the Floats. + + If ``inverse=True``, it will be assumed that a composition of inverse + functions, such as sin and asin, can be cancelled in any order. + For example, ``asin(sin(x))`` will yield ``x`` without checking whether + x belongs to the set where this relation is true. The default is + False. + + Note that ``simplify()`` automatically calls ``doit()`` on the final + expression. You can avoid this behavior by passing ``doit=False`` as + an argument. + + Also, it should be noted that simplifying a boolean expression is not + well defined. If the expression prefers automatic evaluation (such as + :obj:`~.Eq()` or :obj:`~.Or()`), simplification will return ``True`` or + ``False`` if truth value can be determined. If the expression is not + evaluated by default (such as :obj:`~.Predicate()`), simplification will + not reduce it and you should use :func:`~.refine` or :func:`~.ask` + function. This inconsistency will be resolved in future version. + + See Also + ======== + + sympy.assumptions.refine.refine : Simplification using assumptions. + sympy.assumptions.ask.ask : Query for boolean expressions using assumptions. + """ + + def shorter(*choices): + """ + Return the choice that has the fewest ops. In case of a tie, + the expression listed first is selected. + """ + if not has_variety(choices): + return choices[0] + return min(choices, key=measure) + + def done(e): + rv = e.doit() if doit else e + return shorter(rv, collect_abs(rv)) + + expr = sympify(expr, rational=rational) + kwargs = { + "ratio": kwargs.get('ratio', ratio), + "measure": kwargs.get('measure', measure), + "rational": kwargs.get('rational', rational), + "inverse": kwargs.get('inverse', inverse), + "doit": kwargs.get('doit', doit)} + # no routine for Expr needs to check for is_zero + if isinstance(expr, Expr) and expr.is_zero: + return S.Zero if not expr.is_Number else expr + + _eval_simplify = getattr(expr, '_eval_simplify', None) + if _eval_simplify is not None: + return _eval_simplify(**kwargs) + + original_expr = expr = collect_abs(signsimp(expr)) + + if not isinstance(expr, Basic) or not expr.args: # XXX: temporary hack + return expr + + if inverse and expr.has(Function): + expr = inversecombine(expr) + if not expr.args: # simplified to atomic + return expr + + # do deep simplification + handled = Add, Mul, Pow, ExpBase + expr = expr.replace( + # here, checking for x.args is not enough because Basic has + # args but Basic does not always play well with replace, e.g. + # when simultaneous is True found expressions will be masked + # off with a Dummy but not all Basic objects in an expression + # can be replaced with a Dummy + lambda x: isinstance(x, Expr) and x.args and not isinstance( + x, handled), + lambda x: x.func(*[simplify(i, **kwargs) for i in x.args]), + simultaneous=False) + if not isinstance(expr, handled): + return done(expr) + + if not expr.is_commutative: + expr = nc_simplify(expr) + + # TODO: Apply different strategies, considering expression pattern: + # is it a purely rational function? Is there any trigonometric function?... + # See also https://github.com/sympy/sympy/pull/185. + + # rationalize Floats + floats = False + if rational is not False and expr.has(Float): + floats = True + expr = nsimplify(expr, rational=True) + + expr = _bottom_up(expr, lambda w: getattr(w, 'normal', lambda: w)()) + expr = Mul(*powsimp(expr).as_content_primitive()) + _e = cancel(expr) + expr1 = shorter(_e, _mexpand(_e).cancel()) # issue 6829 + expr2 = shorter(together(expr, deep=True), together(expr1, deep=True)) + + if ratio is S.Infinity: + expr = expr2 + else: + expr = shorter(expr2, expr1, expr) + if not isinstance(expr, Basic): # XXX: temporary hack + return expr + + expr = factor_terms(expr, sign=False) + + # must come before `Piecewise` since this introduces more `Piecewise` terms + if expr.has(sign): + expr = expr.rewrite(Abs) + + # Deal with Piecewise separately to avoid recursive growth of expressions + if expr.has(Piecewise): + # Fold into a single Piecewise + expr = piecewise_fold(expr) + # Apply doit, if doit=True + expr = done(expr) + # Still a Piecewise? + if expr.has(Piecewise): + # Fold into a single Piecewise, in case doit lead to some + # expressions being Piecewise + expr = piecewise_fold(expr) + # kroneckersimp also affects Piecewise + if expr.has(KroneckerDelta): + expr = kroneckersimp(expr) + # Still a Piecewise? + if expr.has(Piecewise): + # Do not apply doit on the segments as it has already + # been done above, but simplify + expr = piecewise_simplify(expr, deep=True, doit=False) + # Still a Piecewise? + if expr.has(Piecewise): + # Try factor common terms + expr = shorter(expr, factor_terms(expr)) + # As all expressions have been simplified above with the + # complete simplify, nothing more needs to be done here + return expr + + # hyperexpand automatically only works on hypergeometric terms + # Do this after the Piecewise part to avoid recursive expansion + expr = hyperexpand(expr) + + if expr.has(KroneckerDelta): + expr = kroneckersimp(expr) + + if expr.has(BesselBase): + expr = besselsimp(expr) + + if expr.has(TrigonometricFunction, HyperbolicFunction): + expr = trigsimp(expr, deep=True) + + if expr.has(log): + expr = shorter(expand_log(expr, deep=True), logcombine(expr)) + + if expr.has(CombinatorialFunction, gamma): + # expression with gamma functions or non-integer arguments is + # automatically passed to gammasimp + expr = combsimp(expr) + + if expr.has(Sum): + expr = sum_simplify(expr, **kwargs) + + if expr.has(Integral): + expr = expr.xreplace({ + i: factor_terms(i) for i in expr.atoms(Integral)}) + + if expr.has(Product): + expr = product_simplify(expr, **kwargs) + + from sympy.physics.units import Quantity + + if expr.has(Quantity): + from sympy.physics.units.util import quantity_simplify + expr = quantity_simplify(expr) + + short = shorter(powsimp(expr, combine='exp', deep=True), powsimp(expr), expr) + short = shorter(short, cancel(short)) + short = shorter(short, factor_terms(short), expand_power_exp(expand_mul(short))) + if short.has(TrigonometricFunction, HyperbolicFunction, ExpBase, exp): + short = exptrigsimp(short) + + # get rid of hollow 2-arg Mul factorization + hollow_mul = Transform( + lambda x: Mul(*x.args), + lambda x: + x.is_Mul and + len(x.args) == 2 and + x.args[0].is_Number and + x.args[1].is_Add and + x.is_commutative) + expr = short.xreplace(hollow_mul) + + numer, denom = expr.as_numer_denom() + if denom.is_Add: + n, d = fraction(radsimp(1/denom, symbolic=False, max_terms=1)) + if n is not S.One: + expr = (numer*n).expand()/d + + if expr.could_extract_minus_sign(): + n, d = fraction(expr) + if d != 0: + expr = signsimp(-n/(-d)) + + if measure(expr) > ratio*measure(original_expr): + expr = original_expr + + # restore floats + if floats and rational is None: + expr = nfloat(expr, exponent=False) + + return done(expr) + + +def sum_simplify(s, **kwargs): + """Main function for Sum simplification""" + if not isinstance(s, Add): + s = s.xreplace({a: sum_simplify(a, **kwargs) + for a in s.atoms(Add) if a.has(Sum)}) + s = expand(s) + if not isinstance(s, Add): + return s + + terms = s.args + s_t = [] # Sum Terms + o_t = [] # Other Terms + + for term in terms: + sum_terms, other = sift(Mul.make_args(term), + lambda i: isinstance(i, Sum), binary=True) + if not sum_terms: + o_t.append(term) + continue + other = [Mul(*other)] + s_t.append(Mul(*(other + [s._eval_simplify(**kwargs) for s in sum_terms]))) + + result = Add(sum_combine(s_t), *o_t) + + return result + + +def sum_combine(s_t): + """Helper function for Sum simplification + + Attempts to simplify a list of sums, by combining limits / sum function's + returns the simplified sum + """ + used = [False] * len(s_t) + + for method in range(2): + for i, s_term1 in enumerate(s_t): + if not used[i]: + for j, s_term2 in enumerate(s_t): + if not used[j] and i != j: + temp = sum_add(s_term1, s_term2, method) + if isinstance(temp, (Sum, Mul)): + s_t[i] = temp + s_term1 = s_t[i] + used[j] = True + + result = S.Zero + for i, s_term in enumerate(s_t): + if not used[i]: + result = Add(result, s_term) + + return result + +def factor_sum(self, limits=None, radical=False, clear=False, fraction=False, sign=True): + """Return Sum with constant factors extracted. + + If ``limits`` is specified then ``self`` is the summand; the other + keywords are passed to ``factor_terms``. + + Examples + ======== + + >>> from sympy import Sum + >>> from sympy.abc import x, y + >>> from sympy.simplify.simplify import factor_sum + >>> s = Sum(x*y, (x, 1, 3)) + >>> factor_sum(s) + y*Sum(x, (x, 1, 3)) + >>> factor_sum(s.function, s.limits) + y*Sum(x, (x, 1, 3)) + """ + + # XXX deprecate in favor of direct call to factor_terms + kwargs = {"radical": radical, "clear": clear, + "fraction": fraction, "sign": sign} + expr = Sum(self, *limits) if limits else self + return factor_terms(expr, **kwargs) + + +def sum_add(self, other, method=0): + """Helper function for Sum simplification""" + #we know this is something in terms of a constant * a sum + #so we temporarily put the constants inside for simplification + #then simplify the result + def __refactor(val): + args = Mul.make_args(val) + sumv = next(x for x in args if isinstance(x, Sum)) + constant = Mul(*[x for x in args if x != sumv]) + return Sum(constant * sumv.function, *sumv.limits) + + if isinstance(self, Mul): + rself = __refactor(self) + else: + rself = self + + if isinstance(other, Mul): + rother = __refactor(other) + else: + rother = other + + if type(rself) is type(rother): + if method == 0: + if rself.limits == rother.limits: + return factor_sum(Sum(rself.function + rother.function, *rself.limits)) + elif method == 1: + if simplify(rself.function - rother.function) == 0: + if len(rself.limits) == len(rother.limits) == 1: + i = rself.limits[0][0] + x1 = rself.limits[0][1] + y1 = rself.limits[0][2] + j = rother.limits[0][0] + x2 = rother.limits[0][1] + y2 = rother.limits[0][2] + + if i == j: + if x2 == y1 + 1: + return factor_sum(Sum(rself.function, (i, x1, y2))) + elif x1 == y2 + 1: + return factor_sum(Sum(rself.function, (i, x2, y1))) + + return Add(self, other) + + +def product_simplify(s, **kwargs): + """Main function for Product simplification""" + terms = Mul.make_args(s) + p_t = [] # Product Terms + o_t = [] # Other Terms + + deep = kwargs.get('deep', True) + for term in terms: + if isinstance(term, Product): + if deep: + p_t.append(Product(term.function.simplify(**kwargs), + *term.limits)) + else: + p_t.append(term) + else: + o_t.append(term) + + used = [False] * len(p_t) + + for method in range(2): + for i, p_term1 in enumerate(p_t): + if not used[i]: + for j, p_term2 in enumerate(p_t): + if not used[j] and i != j: + tmp_prod = product_mul(p_term1, p_term2, method) + if isinstance(tmp_prod, Product): + p_t[i] = tmp_prod + used[j] = True + + result = Mul(*o_t) + + for i, p_term in enumerate(p_t): + if not used[i]: + result = Mul(result, p_term) + + return result + + +def product_mul(self, other, method=0): + """Helper function for Product simplification""" + if type(self) is type(other): + if method == 0: + if self.limits == other.limits: + return Product(self.function * other.function, *self.limits) + elif method == 1: + if simplify(self.function - other.function) == 0: + if len(self.limits) == len(other.limits) == 1: + i = self.limits[0][0] + x1 = self.limits[0][1] + y1 = self.limits[0][2] + j = other.limits[0][0] + x2 = other.limits[0][1] + y2 = other.limits[0][2] + + if i == j: + if x2 == y1 + 1: + return Product(self.function, (i, x1, y2)) + elif x1 == y2 + 1: + return Product(self.function, (i, x2, y1)) + + return Mul(self, other) + + +def _nthroot_solve(p, n, prec): + """ + helper function for ``nthroot`` + It denests ``p**Rational(1, n)`` using its minimal polynomial + """ + from sympy.solvers import solve + while n % 2 == 0: + p = sqrtdenest(sqrt(p)) + n = n // 2 + if n == 1: + return p + pn = p**Rational(1, n) + x = Symbol('x') + f = _minimal_polynomial_sq(p, n, x) + if f is None: + return None + sols = solve(f, x) + for sol in sols: + if abs(sol - pn).n() < 1./10**prec: + sol = sqrtdenest(sol) + if _mexpand(sol**n) == p: + return sol + + +def logcombine(expr, force=False): + """ + Takes logarithms and combines them using the following rules: + + - log(x) + log(y) == log(x*y) if both are positive + - a*log(x) == log(x**a) if x is positive and a is real + + If ``force`` is ``True`` then the assumptions above will be assumed to hold if + there is no assumption already in place on a quantity. For example, if + ``a`` is imaginary or the argument negative, force will not perform a + combination but if ``a`` is a symbol with no assumptions the change will + take place. + + Examples + ======== + + >>> from sympy import Symbol, symbols, log, logcombine, I + >>> from sympy.abc import a, x, y, z + >>> logcombine(a*log(x) + log(y) - log(z)) + a*log(x) + log(y) - log(z) + >>> logcombine(a*log(x) + log(y) - log(z), force=True) + log(x**a*y/z) + >>> x,y,z = symbols('x,y,z', positive=True) + >>> a = Symbol('a', real=True) + >>> logcombine(a*log(x) + log(y) - log(z)) + log(x**a*y/z) + + The transformation is limited to factors and/or terms that + contain logs, so the result depends on the initial state of + expansion: + + >>> eq = (2 + 3*I)*log(x) + >>> logcombine(eq, force=True) == eq + True + >>> logcombine(eq.expand(), force=True) + log(x**2) + I*log(x**3) + + See Also + ======== + + posify: replace all symbols with symbols having positive assumptions + sympy.core.function.expand_log: expand the logarithms of products + and powers; the opposite of logcombine + + """ + + def f(rv): + if not (rv.is_Add or rv.is_Mul): + return rv + + def gooda(a): + # bool to tell whether the leading ``a`` in ``a*log(x)`` + # could appear as log(x**a) + return (a is not S.NegativeOne and # -1 *could* go, but we disallow + (a.is_extended_real or force and a.is_extended_real is not False)) + + def goodlog(l): + # bool to tell whether log ``l``'s argument can combine with others + a = l.args[0] + return a.is_positive or force and a.is_nonpositive is not False + + other = [] + logs = [] + log1 = defaultdict(list) + for a in Add.make_args(rv): + if isinstance(a, log) and goodlog(a): + log1[()].append(([], a)) + elif not a.is_Mul: + other.append(a) + else: + ot = [] + co = [] + lo = [] + for ai in a.args: + if ai.is_Rational and ai < 0: + ot.append(S.NegativeOne) + co.append(-ai) + elif isinstance(ai, log) and goodlog(ai): + lo.append(ai) + elif gooda(ai): + co.append(ai) + else: + ot.append(ai) + if len(lo) > 1: + logs.append((ot, co, lo)) + elif lo: + log1[tuple(ot)].append((co, lo[0])) + else: + other.append(a) + + # if there is only one log in other, put it with the + # good logs + if len(other) == 1 and isinstance(other[0], log): + log1[()].append(([], other.pop())) + # if there is only one log at each coefficient and none have + # an exponent to place inside the log then there is nothing to do + if not logs and all(len(log1[k]) == 1 and log1[k][0] == [] for k in log1): + return rv + + # collapse multi-logs as far as possible in a canonical way + # TODO: see if x*log(a)+x*log(a)*log(b) -> x*log(a)*(1+log(b))? + # -- in this case, it's unambiguous, but if it were were a log(c) in + # each term then it's arbitrary whether they are grouped by log(a) or + # by log(c). So for now, just leave this alone; it's probably better to + # let the user decide + for o, e, l in logs: + l = list(ordered(l)) + e = log(l.pop(0).args[0]**Mul(*e)) + while l: + li = l.pop(0) + e = log(li.args[0]**e) + c, l = Mul(*o), e + if isinstance(l, log): # it should be, but check to be sure + log1[(c,)].append(([], l)) + else: + other.append(c*l) + + # logs that have the same coefficient can multiply + for k in list(log1.keys()): + log1[Mul(*k)] = log(logcombine(Mul(*[ + l.args[0]**Mul(*c) for c, l in log1.pop(k)]), + force=force), evaluate=False) + + # logs that have oppositely signed coefficients can divide + for k in ordered(list(log1.keys())): + if k not in log1: # already popped as -k + continue + if -k in log1: + # figure out which has the minus sign; the one with + # more op counts should be the one + num, den = k, -k + if num.count_ops() > den.count_ops(): + num, den = den, num + other.append( + num*log(log1.pop(num).args[0]/log1.pop(den).args[0], + evaluate=False)) + else: + other.append(k*log1.pop(k)) + + return Add(*other) + + return _bottom_up(expr, f) + + +def inversecombine(expr): + """Simplify the composition of a function and its inverse. + + Explanation + =========== + + No attention is paid to whether the inverse is a left inverse or a + right inverse; thus, the result will in general not be equivalent + to the original expression. + + Examples + ======== + + >>> from sympy.simplify.simplify import inversecombine + >>> from sympy import asin, sin, log, exp + >>> from sympy.abc import x + >>> inversecombine(asin(sin(x))) + x + >>> inversecombine(2*log(exp(3*x))) + 6*x + """ + + def f(rv): + if isinstance(rv, log): + if isinstance(rv.args[0], exp) or (rv.args[0].is_Pow and rv.args[0].base == S.Exp1): + rv = rv.args[0].exp + elif rv.is_Function and hasattr(rv, "inverse"): + if (len(rv.args) == 1 and len(rv.args[0].args) == 1 and + isinstance(rv.args[0], rv.inverse(argindex=1))): + rv = rv.args[0].args[0] + if rv.is_Pow and rv.base == S.Exp1: + if isinstance(rv.exp, log): + rv = rv.exp.args[0] + return rv + + return _bottom_up(expr, f) + + +def kroneckersimp(expr): + """ + Simplify expressions with KroneckerDelta. + + The only simplification currently attempted is to identify multiplicative cancellation: + + Examples + ======== + + >>> from sympy import KroneckerDelta, kroneckersimp + >>> from sympy.abc import i + >>> kroneckersimp(1 + KroneckerDelta(0, i) * KroneckerDelta(1, i)) + 1 + """ + def args_cancel(args1, args2): + for i1 in range(2): + for i2 in range(2): + a1 = args1[i1] + a2 = args2[i2] + a3 = args1[(i1 + 1) % 2] + a4 = args2[(i2 + 1) % 2] + if Eq(a1, a2) is S.true and Eq(a3, a4) is S.false: + return True + return False + + def cancel_kronecker_mul(m): + args = m.args + deltas = [a for a in args if isinstance(a, KroneckerDelta)] + for delta1, delta2 in subsets(deltas, 2): + args1 = delta1.args + args2 = delta2.args + if args_cancel(args1, args2): + return S.Zero * m # In case of oo etc + return m + + if not expr.has(KroneckerDelta): + return expr + + if expr.has(Piecewise): + expr = expr.rewrite(KroneckerDelta) + + newexpr = expr + expr = None + + while newexpr != expr: + expr = newexpr + newexpr = expr.replace(lambda e: isinstance(e, Mul), cancel_kronecker_mul) + + return expr + + +def besselsimp(expr): + """ + Simplify bessel-type functions. + + Explanation + =========== + + This routine tries to simplify bessel-type functions. Currently it only + works on the Bessel J and I functions, however. It works by looking at all + such functions in turn, and eliminating factors of "I" and "-1" (actually + their polar equivalents) in front of the argument. Then, functions of + half-integer order are rewritten using trigonometric functions and + functions of integer order (> 1) are rewritten using functions + of low order. Finally, if the expression was changed, compute + factorization of the result with factor(). + + >>> from sympy import besselj, besseli, besselsimp, polar_lift, I, S + >>> from sympy.abc import z, nu + >>> besselsimp(besselj(nu, z*polar_lift(-1))) + exp(I*pi*nu)*besselj(nu, z) + >>> besselsimp(besseli(nu, z*polar_lift(-I))) + exp(-I*pi*nu/2)*besselj(nu, z) + >>> besselsimp(besseli(S(-1)/2, z)) + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + >>> besselsimp(z*besseli(0, z) + z*(besseli(2, z))/2 + besseli(1, z)) + 3*z*besseli(0, z)/2 + """ + # TODO + # - better algorithm? + # - simplify (cos(pi*b)*besselj(b,z) - besselj(-b,z))/sin(pi*b) ... + # - use contiguity relations? + + def replacer(fro, to, factors): + factors = set(factors) + + def repl(nu, z): + if factors.intersection(Mul.make_args(z)): + return to(nu, z) + return fro(nu, z) + return repl + + def torewrite(fro, to): + def tofunc(nu, z): + return fro(nu, z).rewrite(to) + return tofunc + + def tominus(fro): + def tofunc(nu, z): + return exp(I*pi*nu)*fro(nu, exp_polar(-I*pi)*z) + return tofunc + + orig_expr = expr + + ifactors = [I, exp_polar(I*pi/2), exp_polar(-I*pi/2)] + expr = expr.replace( + besselj, replacer(besselj, + torewrite(besselj, besseli), ifactors)) + expr = expr.replace( + besseli, replacer(besseli, + torewrite(besseli, besselj), ifactors)) + + minusfactors = [-1, exp_polar(I*pi)] + expr = expr.replace( + besselj, replacer(besselj, tominus(besselj), minusfactors)) + expr = expr.replace( + besseli, replacer(besseli, tominus(besseli), minusfactors)) + + z0 = Dummy('z') + + def expander(fro): + def repl(nu, z): + if (nu % 1) == S.Half: + return simplify(trigsimp(unpolarify( + fro(nu, z0).rewrite(besselj).rewrite(jn).expand( + func=True)).subs(z0, z))) + elif nu.is_Integer and nu > 1: + return fro(nu, z).expand(func=True) + return fro(nu, z) + return repl + + expr = expr.replace(besselj, expander(besselj)) + expr = expr.replace(bessely, expander(bessely)) + expr = expr.replace(besseli, expander(besseli)) + expr = expr.replace(besselk, expander(besselk)) + + def _bessel_simp_recursion(expr): + + def _use_recursion(bessel, expr): + while True: + bessels = expr.find(lambda x: isinstance(x, bessel)) + try: + for ba in sorted(bessels, key=lambda x: re(x.args[0])): + a, x = ba.args + bap1 = bessel(a+1, x) + bap2 = bessel(a+2, x) + if expr.has(bap1) and expr.has(bap2): + expr = expr.subs(ba, 2*(a+1)/x*bap1 - bap2) + break + else: + return expr + except (ValueError, TypeError): + return expr + if expr.has(besselj): + expr = _use_recursion(besselj, expr) + if expr.has(bessely): + expr = _use_recursion(bessely, expr) + return expr + + expr = _bessel_simp_recursion(expr) + if expr != orig_expr: + expr = expr.factor() + + return expr + + +def nthroot(expr, n, max_len=4, prec=15): + """ + Compute a real nth-root of a sum of surds. + + Parameters + ========== + + expr : sum of surds + n : integer + max_len : maximum number of surds passed as constants to ``nsimplify`` + + Algorithm + ========= + + First ``nsimplify`` is used to get a candidate root; if it is not a + root the minimal polynomial is computed; the answer is one of its + roots. + + Examples + ======== + + >>> from sympy.simplify.simplify import nthroot + >>> from sympy import sqrt + >>> nthroot(90 + 34*sqrt(7), 3) + sqrt(7) + 3 + + """ + expr = sympify(expr) + n = sympify(n) + p = expr**Rational(1, n) + if not n.is_integer: + return p + if not _is_sum_surds(expr): + return p + surds = [] + coeff_muls = [x.as_coeff_Mul() for x in expr.args] + for x, y in coeff_muls: + if not x.is_rational: + return p + if y is S.One: + continue + if not (y.is_Pow and y.exp == S.Half and y.base.is_integer): + return p + surds.append(y) + surds.sort() + surds = surds[:max_len] + if expr < 0 and n % 2 == 1: + p = (-expr)**Rational(1, n) + a = nsimplify(p, constants=surds) + res = a if _mexpand(a**n) == _mexpand(-expr) else p + return -res + a = nsimplify(p, constants=surds) + if _mexpand(a) is not _mexpand(p) and _mexpand(a**n) == _mexpand(expr): + return _mexpand(a) + expr = _nthroot_solve(expr, n, prec) + if expr is None: + return p + return expr + + +def nsimplify(expr, constants=(), tolerance=None, full=False, rational=None, + rational_conversion='base10'): + """ + Find a simple representation for a number or, if there are free symbols or + if ``rational=True``, then replace Floats with their Rational equivalents. If + no change is made and rational is not False then Floats will at least be + converted to Rationals. + + Explanation + =========== + + For numerical expressions, a simple formula that numerically matches the + given numerical expression is sought (and the input should be possible + to evalf to a precision of at least 30 digits). + + Optionally, a list of (rationally independent) constants to + include in the formula may be given. + + A lower tolerance may be set to find less exact matches. If no tolerance + is given then the least precise value will set the tolerance (e.g. Floats + default to 15 digits of precision, so would be tolerance=10**-15). + + With ``full=True``, a more extensive search is performed + (this is useful to find simpler numbers when the tolerance + is set low). + + When converting to rational, if rational_conversion='base10' (the default), then + convert floats to rationals using their base-10 (string) representation. + When rational_conversion='exact' it uses the exact, base-2 representation. + + Examples + ======== + + >>> from sympy import nsimplify, sqrt, GoldenRatio, exp, I, pi + >>> nsimplify(4/(1+sqrt(5)), [GoldenRatio]) + -2 + 2*GoldenRatio + >>> nsimplify((1/(exp(3*pi*I/5)+1))) + 1/2 - I*sqrt(sqrt(5)/10 + 1/4) + >>> nsimplify(I**I, [pi]) + exp(-pi/2) + >>> nsimplify(pi, tolerance=0.01) + 22/7 + + >>> nsimplify(0.333333333333333, rational=True, rational_conversion='exact') + 6004799503160655/18014398509481984 + >>> nsimplify(0.333333333333333, rational=True) + 1/3 + + See Also + ======== + + sympy.core.function.nfloat + + """ + try: + return sympify(as_int(expr)) + except (TypeError, ValueError): + pass + expr = sympify(expr).xreplace({ + Float('inf'): S.Infinity, + Float('-inf'): S.NegativeInfinity, + }) + if expr is S.Infinity or expr is S.NegativeInfinity: + return expr + if rational or expr.free_symbols: + return _real_to_rational(expr, tolerance, rational_conversion) + + # SymPy's default tolerance for Rationals is 15; other numbers may have + # lower tolerances set, so use them to pick the largest tolerance if None + # was given + if tolerance is None: + tolerance = 10**-min([15] + + [mpmath.libmp.libmpf.prec_to_dps(n._prec) + for n in expr.atoms(Float)]) + # XXX should prec be set independent of tolerance or should it be computed + # from tolerance? + prec = 30 + bprec = int(prec*3.33) + + constants_dict = {} + for constant in constants: + constant = sympify(constant) + v = constant.evalf(prec) + if not v.is_Float: + raise ValueError("constants must be real-valued") + constants_dict[str(constant)] = v._to_mpmath(bprec) + + exprval = expr.evalf(prec, chop=True) + re, im = exprval.as_real_imag() + + # safety check to make sure that this evaluated to a number + if not (re.is_Number and im.is_Number): + return expr + + def nsimplify_real(x): + orig = mpmath.mp.dps + xv = x._to_mpmath(bprec) + try: + # We'll be happy with low precision if a simple fraction + if not (tolerance or full): + mpmath.mp.dps = 15 + rat = mpmath.pslq([xv, 1]) + if rat is not None: + return Rational(-int(rat[1]), int(rat[0])) + mpmath.mp.dps = prec + newexpr = mpmath.identify(xv, constants=constants_dict, + tol=tolerance, full=full) + if not newexpr: + raise ValueError + if full: + newexpr = newexpr[0] + expr = sympify(newexpr) + if x and not expr: # don't let x become 0 + raise ValueError + if expr.is_finite is False and xv not in [mpmath.inf, mpmath.ninf]: + raise ValueError + return expr + finally: + # even though there are returns above, this is executed + # before leaving + mpmath.mp.dps = orig + try: + if re: + re = nsimplify_real(re) + if im: + im = nsimplify_real(im) + except ValueError: + if rational is None: + return _real_to_rational(expr, rational_conversion=rational_conversion) + return expr + + rv = re + im*S.ImaginaryUnit + # if there was a change or rational is explicitly not wanted + # return the value, else return the Rational representation + if rv != expr or rational is False: + return rv + return _real_to_rational(expr, rational_conversion=rational_conversion) + + +def _real_to_rational(expr, tolerance=None, rational_conversion='base10'): + """ + Replace all reals in expr with rationals. + + Examples + ======== + + >>> from sympy.simplify.simplify import _real_to_rational + >>> from sympy.abc import x + + >>> _real_to_rational(.76 + .1*x**.5) + sqrt(x)/10 + 19/25 + + If rational_conversion='base10', this uses the base-10 string. If + rational_conversion='exact', the exact, base-2 representation is used. + + >>> _real_to_rational(0.333333333333333, rational_conversion='exact') + 6004799503160655/18014398509481984 + >>> _real_to_rational(0.333333333333333) + 1/3 + + """ + expr = _sympify(expr) + inf = Float('inf') + p = expr + reps = {} + reduce_num = None + if tolerance is not None and tolerance < 1: + reduce_num = ceiling(1/tolerance) + for fl in p.atoms(Float): + key = fl + if reduce_num is not None: + r = Rational(fl).limit_denominator(reduce_num) + elif (tolerance is not None and tolerance >= 1 and + fl.is_Integer is False): + r = Rational(tolerance*round(fl/tolerance) + ).limit_denominator(int(tolerance)) + else: + if rational_conversion == 'exact': + r = Rational(fl) + reps[key] = r + continue + elif rational_conversion != 'base10': + raise ValueError("rational_conversion must be 'base10' or 'exact'") + + r = nsimplify(fl, rational=False) + # e.g. log(3).n() -> log(3) instead of a Rational + if fl and not r: + r = Rational(fl) + elif not r.is_Rational: + if fl in (inf, -inf): + r = S.ComplexInfinity + elif fl < 0: + fl = -fl + d = Pow(10, int(mpmath.log(fl)/mpmath.log(10))) + r = -Rational(str(fl/d))*d + elif fl > 0: + d = Pow(10, int(mpmath.log(fl)/mpmath.log(10))) + r = Rational(str(fl/d))*d + else: + r = S.Zero + reps[key] = r + return p.subs(reps, simultaneous=True) + + +def clear_coefficients(expr, rhs=S.Zero): + """Return `p, r` where `p` is the expression obtained when Rational + additive and multiplicative coefficients of `expr` have been stripped + away in a naive fashion (i.e. without simplification). The operations + needed to remove the coefficients will be applied to `rhs` and returned + as `r`. + + Examples + ======== + + >>> from sympy.simplify.simplify import clear_coefficients + >>> from sympy.abc import x, y + >>> from sympy import Dummy + >>> expr = 4*y*(6*x + 3) + >>> clear_coefficients(expr - 2) + (y*(2*x + 1), 1/6) + + When solving 2 or more expressions like `expr = a`, + `expr = b`, etc..., it is advantageous to provide a Dummy symbol + for `rhs` and simply replace it with `a`, `b`, etc... in `r`. + + >>> rhs = Dummy('rhs') + >>> clear_coefficients(expr, rhs) + (y*(2*x + 1), _rhs/12) + >>> _[1].subs(rhs, 2) + 1/6 + """ + was = None + free = expr.free_symbols + if expr.is_Rational: + return (S.Zero, rhs - expr) + while expr and was != expr: + was = expr + m, expr = ( + expr.as_content_primitive() + if free else + factor_terms(expr).as_coeff_Mul(rational=True)) + rhs /= m + c, expr = expr.as_coeff_Add(rational=True) + rhs -= c + expr = signsimp(expr, evaluate = False) + if expr.could_extract_minus_sign(): + expr = -expr + rhs = -rhs + return expr, rhs + +def nc_simplify(expr, deep=True): + ''' + Simplify a non-commutative expression composed of multiplication + and raising to a power by grouping repeated subterms into one power. + Priority is given to simplifications that give the fewest number + of arguments in the end (for example, in a*b*a*b*c*a*b*c simplifying + to (a*b)**2*c*a*b*c gives 5 arguments while a*b*(a*b*c)**2 has 3). + If ``expr`` is a sum of such terms, the sum of the simplified terms + is returned. + + Keyword argument ``deep`` controls whether or not subexpressions + nested deeper inside the main expression are simplified. See examples + below. Setting `deep` to `False` can save time on nested expressions + that do not need simplifying on all levels. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.simplify.simplify import nc_simplify + >>> a, b, c = symbols("a b c", commutative=False) + >>> nc_simplify(a*b*a*b*c*a*b*c) + a*b*(a*b*c)**2 + >>> expr = a**2*b*a**4*b*a**4 + >>> nc_simplify(expr) + a**2*(b*a**4)**2 + >>> nc_simplify(a*b*a*b*c**2*(a*b)**2*c**2) + ((a*b)**2*c**2)**2 + >>> nc_simplify(a*b*a*b + 2*a*c*a**2*c*a**2*c*a) + (a*b)**2 + 2*(a*c*a)**3 + >>> nc_simplify(b**-1*a**-1*(a*b)**2) + a*b + >>> nc_simplify(a**-1*b**-1*c*a) + (b*a)**(-1)*c*a + >>> expr = (a*b*a*b)**2*a*c*a*c + >>> nc_simplify(expr) + (a*b)**4*(a*c)**2 + >>> nc_simplify(expr, deep=False) + (a*b*a*b)**2*(a*c)**2 + + ''' + if isinstance(expr, MatrixExpr): + expr = expr.doit(inv_expand=False) + _Add, _Mul, _Pow, _Symbol = MatAdd, MatMul, MatPow, MatrixSymbol + else: + _Add, _Mul, _Pow, _Symbol = Add, Mul, Pow, Symbol + + # =========== Auxiliary functions ======================== + def _overlaps(args): + # Calculate a list of lists m such that m[i][j] contains the lengths + # of all possible overlaps between args[:i+1] and args[i+1+j:]. + # An overlap is a suffix of the prefix that matches a prefix + # of the suffix. + # For example, let expr=c*a*b*a*b*a*b*a*b. Then m[3][0] contains + # the lengths of overlaps of c*a*b*a*b with a*b*a*b. The overlaps + # are a*b*a*b, a*b and the empty word so that m[3][0]=[4,2,0]. + # All overlaps rather than only the longest one are recorded + # because this information helps calculate other overlap lengths. + m = [[([1, 0] if a == args[0] else [0]) for a in args[1:]]] + for i in range(1, len(args)): + overlaps = [] + j = 0 + for j in range(len(args) - i - 1): + overlap = [] + for v in m[i-1][j+1]: + if j + i + 1 + v < len(args) and args[i] == args[j+i+1+v]: + overlap.append(v + 1) + overlap += [0] + overlaps.append(overlap) + m.append(overlaps) + return m + + def _reduce_inverses(_args): + # replace consecutive negative powers by an inverse + # of a product of positive powers, e.g. a**-1*b**-1*c + # will simplify to (a*b)**-1*c; + # return that new args list and the number of negative + # powers in it (inv_tot) + inv_tot = 0 # total number of inverses + inverses = [] + args = [] + for arg in _args: + if isinstance(arg, _Pow) and arg.args[1].is_extended_negative: + inverses = [arg**-1] + inverses + inv_tot += 1 + else: + if len(inverses) == 1: + args.append(inverses[0]**-1) + elif len(inverses) > 1: + args.append(_Pow(_Mul(*inverses), -1)) + inv_tot -= len(inverses) - 1 + inverses = [] + args.append(arg) + if inverses: + args.append(_Pow(_Mul(*inverses), -1)) + inv_tot -= len(inverses) - 1 + return inv_tot, tuple(args) + + def get_score(s): + # compute the number of arguments of s + # (including in nested expressions) overall + # but ignore exponents + if isinstance(s, _Pow): + return get_score(s.args[0]) + elif isinstance(s, (_Add, _Mul)): + return sum(get_score(a) for a in s.args) + return 1 + + def compare(s, alt_s): + # compare two possible simplifications and return a + # "better" one + if s != alt_s and get_score(alt_s) < get_score(s): + return alt_s + return s + # ======================================================== + + if not isinstance(expr, (_Add, _Mul, _Pow)) or expr.is_commutative: + return expr + args = expr.args[:] + if isinstance(expr, _Pow): + if deep: + return _Pow(nc_simplify(args[0]), args[1]).doit() + else: + return expr + elif isinstance(expr, _Add): + return _Add(*[nc_simplify(a, deep=deep) for a in args]).doit() + else: + # get the non-commutative part + c_args, args = expr.args_cnc() + com_coeff = Mul(*c_args) + if not equal_valued(com_coeff, 1): + return com_coeff*nc_simplify(expr/com_coeff, deep=deep) + + inv_tot, args = _reduce_inverses(args) + # if most arguments are negative, work with the inverse + # of the expression, e.g. a**-1*b*a**-1*c**-1 will become + # (c*a*b**-1*a)**-1 at the end so can work with c*a*b**-1*a + invert = False + if inv_tot > len(args)/2: + invert = True + args = [a**-1 for a in args[::-1]] + + if deep: + args = tuple(nc_simplify(a) for a in args) + + m = _overlaps(args) + + # simps will be {subterm: end} where `end` is the ending + # index of a sequence of repetitions of subterm; + # this is for not wasting time with subterms that are part + # of longer, already considered sequences + simps = {} + + post = 1 + pre = 1 + + # the simplification coefficient is the number of + # arguments by which contracting a given sequence + # would reduce the word; e.g. in a*b*a*b*c*a*b*c, + # contracting a*b*a*b to (a*b)**2 removes 3 arguments + # while a*b*c*a*b*c to (a*b*c)**2 removes 6. It's + # better to contract the latter so simplification + # with a maximum simplification coefficient will be chosen + max_simp_coeff = 0 + simp = None # information about future simplification + + for i in range(1, len(args)): + simp_coeff = 0 + l = 0 # length of a subterm + p = 0 # the power of a subterm + if i < len(args) - 1: + rep = m[i][0] + start = i # starting index of the repeated sequence + end = i+1 # ending index of the repeated sequence + if i == len(args)-1 or rep == [0]: + # no subterm is repeated at this stage, at least as + # far as the arguments are concerned - there may be + # a repetition if powers are taken into account + if (isinstance(args[i], _Pow) and + not isinstance(args[i].args[0], _Symbol)): + subterm = args[i].args[0].args + l = len(subterm) + if args[i-l:i] == subterm: + # e.g. a*b in a*b*(a*b)**2 is not repeated + # in args (= [a, b, (a*b)**2]) but it + # can be matched here + p += 1 + start -= l + if args[i+1:i+1+l] == subterm: + # e.g. a*b in (a*b)**2*a*b + p += 1 + end += l + if p: + p += args[i].args[1] + else: + continue + else: + l = rep[0] # length of the longest repeated subterm at this point + start -= l - 1 + subterm = args[start:end] + p = 2 + end += l + + if subterm in simps and simps[subterm] >= start: + # the subterm is part of a sequence that + # has already been considered + continue + + # count how many times it's repeated + while end < len(args): + if l in m[end-1][0]: + p += 1 + end += l + elif isinstance(args[end], _Pow) and args[end].args[0].args == subterm: + # for cases like a*b*a*b*(a*b)**2*a*b + p += args[end].args[1] + end += 1 + else: + break + + # see if another match can be made, e.g. + # for b*a**2 in b*a**2*b*a**3 or a*b in + # a**2*b*a*b + + pre_exp = 0 + pre_arg = 1 + if start - l >= 0 and args[start-l+1:start] == subterm[1:]: + if isinstance(subterm[0], _Pow): + pre_arg = subterm[0].args[0] + exp = subterm[0].args[1] + else: + pre_arg = subterm[0] + exp = 1 + if isinstance(args[start-l], _Pow) and args[start-l].args[0] == pre_arg: + pre_exp = args[start-l].args[1] - exp + start -= l + p += 1 + elif args[start-l] == pre_arg: + pre_exp = 1 - exp + start -= l + p += 1 + + post_exp = 0 + post_arg = 1 + if end + l - 1 < len(args) and args[end:end+l-1] == subterm[:-1]: + if isinstance(subterm[-1], _Pow): + post_arg = subterm[-1].args[0] + exp = subterm[-1].args[1] + else: + post_arg = subterm[-1] + exp = 1 + if isinstance(args[end+l-1], _Pow) and args[end+l-1].args[0] == post_arg: + post_exp = args[end+l-1].args[1] - exp + end += l + p += 1 + elif args[end+l-1] == post_arg: + post_exp = 1 - exp + end += l + p += 1 + + # Consider a*b*a**2*b*a**2*b*a: + # b*a**2 is explicitly repeated, but note + # that in this case a*b*a is also repeated + # so there are two possible simplifications: + # a*(b*a**2)**3*a**-1 or (a*b*a)**3 + # The latter is obviously simpler. + # But in a*b*a**2*b**2*a**2 the simplifications are + # a*(b*a**2)**2 and (a*b*a)**3*a in which case + # it's better to stick with the shorter subterm + if post_exp and exp % 2 == 0 and start > 0: + exp = exp/2 + _pre_exp = 1 + _post_exp = 1 + if isinstance(args[start-1], _Pow) and args[start-1].args[0] == post_arg: + _post_exp = post_exp + exp + _pre_exp = args[start-1].args[1] - exp + elif args[start-1] == post_arg: + _post_exp = post_exp + exp + _pre_exp = 1 - exp + if _pre_exp == 0 or _post_exp == 0: + if not pre_exp: + start -= 1 + post_exp = _post_exp + pre_exp = _pre_exp + pre_arg = post_arg + subterm = (post_arg**exp,) + subterm[:-1] + (post_arg**exp,) + + simp_coeff += end-start + + if post_exp: + simp_coeff -= 1 + if pre_exp: + simp_coeff -= 1 + + simps[subterm] = end + + if simp_coeff > max_simp_coeff: + max_simp_coeff = simp_coeff + simp = (start, _Mul(*subterm), p, end, l) + pre = pre_arg**pre_exp + post = post_arg**post_exp + + if simp: + subterm = _Pow(nc_simplify(simp[1], deep=deep), simp[2]) + pre = nc_simplify(_Mul(*args[:simp[0]])*pre, deep=deep) + post = post*nc_simplify(_Mul(*args[simp[3]:]), deep=deep) + simp = pre*subterm*post + if pre != 1 or post != 1: + # new simplifications may be possible but no need + # to recurse over arguments + simp = nc_simplify(simp, deep=False) + else: + simp = _Mul(*args) + + if invert: + simp = _Pow(simp, -1) + + # see if factor_nc(expr) is simplified better + if not isinstance(expr, MatrixExpr): + f_expr = factor_nc(expr) + if f_expr != expr: + alt_simp = nc_simplify(f_expr, deep=deep) + simp = compare(simp, alt_simp) + else: + simp = simp.doit(inv_expand=False) + return simp + + +def dotprodsimp(expr, withsimp=False): + """Simplification for a sum of products targeted at the kind of blowup that + occurs during summation of products. Intended to reduce expression blowup + during matrix multiplication or other similar operations. Only works with + algebraic expressions and does not recurse into non. + + Parameters + ========== + + withsimp : bool, optional + Specifies whether a flag should be returned along with the expression + to indicate roughly whether simplification was successful. It is used + in ``MatrixArithmetic._eval_pow_by_recursion`` to avoid attempting to + simplify an expression repetitively which does not simplify. + """ + + def count_ops_alg(expr): + """Optimized count algebraic operations with no recursion into + non-algebraic args that ``core.function.count_ops`` does. Also returns + whether rational functions may be present according to negative + exponents of powers or non-number fractions. + + Returns + ======= + + ops, ratfunc : int, bool + ``ops`` is the number of algebraic operations starting at the top + level expression (not recursing into non-alg children). ``ratfunc`` + specifies whether the expression MAY contain rational functions + which ``cancel`` MIGHT optimize. + """ + + ops = 0 + args = [expr] + ratfunc = False + + while args: + a = args.pop() + + if not isinstance(a, Basic): + continue + + if a.is_Rational: + if a is not S.One: # -1/3 = NEG + DIV + ops += bool (a.p < 0) + bool (a.q != 1) + + elif a.is_Mul: + if a.could_extract_minus_sign(): + ops += 1 + if a.args[0] is S.NegativeOne: + a = a.as_two_terms()[1] + else: + a = -a + + n, d = fraction(a) + + if n.is_Integer: + ops += 1 + bool (n < 0) + args.append(d) # won't be -Mul but could be Add + + elif d is not S.One: + if not d.is_Integer: + args.append(d) + ratfunc=True + + ops += 1 + args.append(n) # could be -Mul + + else: + ops += len(a.args) - 1 + args.extend(a.args) + + elif a.is_Add: + laargs = len(a.args) + negs = 0 + + for ai in a.args: + if ai.could_extract_minus_sign(): + negs += 1 + ai = -ai + args.append(ai) + + ops += laargs - (negs != laargs) # -x - y = NEG + SUB + + elif a.is_Pow: + ops += 1 + args.append(a.base) + + if not ratfunc: + ratfunc = a.exp.is_negative is not False + + return ops, ratfunc + + def nonalg_subs_dummies(expr, dummies): + """Substitute dummy variables for non-algebraic expressions to avoid + evaluation of non-algebraic terms that ``polys.polytools.cancel`` does. + """ + + if not expr.args: + return expr + + if expr.is_Add or expr.is_Mul or expr.is_Pow: + args = None + + for i, a in enumerate(expr.args): + c = nonalg_subs_dummies(a, dummies) + + if c is a: + continue + + if args is None: + args = list(expr.args) + + args[i] = c + + if args is None: + return expr + + return expr.func(*args) + + return dummies.setdefault(expr, Dummy()) + + simplified = False # doesn't really mean simplified, rather "can simplify again" + + if isinstance(expr, Basic) and (expr.is_Add or expr.is_Mul or expr.is_Pow): + expr2 = expr.expand(deep=True, modulus=None, power_base=False, + power_exp=False, mul=True, log=False, multinomial=True, basic=False) + + if expr2 != expr: + expr = expr2 + simplified = True + + exprops, ratfunc = count_ops_alg(expr) + + if exprops >= 6: # empirically tested cutoff for expensive simplification + if ratfunc: + dummies = {} + expr2 = nonalg_subs_dummies(expr, dummies) + + if expr2 is expr or count_ops_alg(expr2)[0] >= 6: # check again after substitution + expr3 = cancel(expr2) + + if expr3 != expr2: + expr = expr3.subs([(d, e) for e, d in dummies.items()]) + simplified = True + + # very special case: x/(x-1) - 1/(x-1) -> 1 + elif (exprops == 5 and expr.is_Add and expr.args [0].is_Mul and + expr.args [1].is_Mul and expr.args [0].args [-1].is_Pow and + expr.args [1].args [-1].is_Pow and + expr.args [0].args [-1].exp is S.NegativeOne and + expr.args [1].args [-1].exp is S.NegativeOne): + + expr2 = together (expr) + expr2ops = count_ops_alg(expr2)[0] + + if expr2ops < exprops: + expr = expr2 + simplified = True + + else: + simplified = True + + return (expr, simplified) if withsimp else expr + + +bottom_up = deprecated( + """ + Using bottom_up from the sympy.simplify.simplify submodule is + deprecated. + + Instead, use bottom_up from the top-level sympy namespace, like + + sympy.bottom_up + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved", +)(_bottom_up) + + +# XXX: This function really should either be private API or exported in the +# top-level sympy/__init__.py +walk = deprecated( + """ + Using walk from the sympy.simplify.simplify submodule is + deprecated. + + Instead, use walk from sympy.core.traversal.walk + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved", +)(_walk) diff --git a/phivenv/Lib/site-packages/sympy/simplify/sqrtdenest.py b/phivenv/Lib/site-packages/sympy/simplify/sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..d266de7e62a4b7d37a2109f7091ff91e4df7c79d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/sqrtdenest.py @@ -0,0 +1,678 @@ +from sympy.core import Add, Expr, Mul, S, sympify +from sympy.core.function import _mexpand, count_ops, expand_mul +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy +from sympy.functions import root, sign, sqrt +from sympy.polys import Poly, PolynomialError + + +def is_sqrt(expr): + """Return True if expr is a sqrt, otherwise False.""" + + return expr.is_Pow and expr.exp.is_Rational and abs(expr.exp) is S.Half + + +def sqrt_depth(p) -> int: + """Return the maximum depth of any square root argument of p. + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import sqrt_depth + + Neither of these square roots contains any other square roots + so the depth is 1: + + >>> sqrt_depth(1 + sqrt(2)*(1 + sqrt(3))) + 1 + + The sqrt(3) is contained within a square root so the depth is + 2: + + >>> sqrt_depth(1 + sqrt(2)*sqrt(1 + sqrt(3))) + 2 + """ + if p is S.ImaginaryUnit: + return 1 + if p.is_Atom: + return 0 + if p.is_Add or p.is_Mul: + return max(sqrt_depth(x) for x in p.args) + if is_sqrt(p): + return sqrt_depth(p.base) + 1 + return 0 + + +def is_algebraic(p): + """Return True if p is comprised of only Rationals or square roots + of Rationals and algebraic operations. + + Examples + ======== + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import is_algebraic + >>> from sympy import cos + >>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*sqrt(2)))) + True + >>> is_algebraic(sqrt(2)*(3/(sqrt(7) + sqrt(5)*cos(2)))) + False + """ + + if p.is_Rational: + return True + elif p.is_Atom: + return False + elif is_sqrt(p) or p.is_Pow and p.exp.is_Integer: + return is_algebraic(p.base) + elif p.is_Add or p.is_Mul: + return all(is_algebraic(x) for x in p.args) + else: + return False + + +def _subsets(n): + """ + Returns all possible subsets of the set (0, 1, ..., n-1) except the + empty set, listed in reversed lexicographical order according to binary + representation, so that the case of the fourth root is treated last. + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import _subsets + >>> _subsets(2) + [[1, 0], [0, 1], [1, 1]] + + """ + if n == 1: + a = [[1]] + elif n == 2: + a = [[1, 0], [0, 1], [1, 1]] + elif n == 3: + a = [[1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]] + else: + b = _subsets(n - 1) + a0 = [x + [0] for x in b] + a1 = [x + [1] for x in b] + a = a0 + [[0]*(n - 1) + [1]] + a1 + return a + + +def sqrtdenest(expr, max_iter=3): + """Denests sqrts in an expression that contain other square roots + if possible, otherwise returns the expr unchanged. This is based on the + algorithms of [1]. + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import sqrtdenest + >>> from sympy import sqrt + >>> sqrtdenest(sqrt(5 + 2 * sqrt(6))) + sqrt(2) + sqrt(3) + + See Also + ======== + + sympy.solvers.solvers.unrad + + References + ========== + + .. [1] https://web.archive.org/web/20210806201615/https://researcher.watson.ibm.com/researcher/files/us-fagin/symb85.pdf + + .. [2] D. J. Jeffrey and A. D. Rich, 'Symplifying Square Roots of Square Roots + by Denesting' (available at https://www.cybertester.com/data/denest.pdf) + + """ + expr = expand_mul(expr) + for i in range(max_iter): + z = _sqrtdenest0(expr) + if expr == z: + return expr + expr = z + return expr + + +def _sqrt_match(p): + """Return [a, b, r] for p.match(a + b*sqrt(r)) where, in addition to + matching, sqrt(r) also has then maximal sqrt_depth among addends of p. + + Examples + ======== + + >>> from sympy.functions.elementary.miscellaneous import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrt_match + >>> _sqrt_match(1 + sqrt(2) + sqrt(2)*sqrt(3) + 2*sqrt(1+sqrt(5))) + [1 + sqrt(2) + sqrt(6), 2, 1 + sqrt(5)] + """ + from sympy.simplify.radsimp import split_surds + + p = _mexpand(p) + if p.is_Number: + res = (p, S.Zero, S.Zero) + elif p.is_Add: + pargs = sorted(p.args, key=default_sort_key) + sqargs = [x**2 for x in pargs] + if all(sq.is_Rational and sq.is_positive for sq in sqargs): + r, b, a = split_surds(p) + res = a, b, r + return list(res) + # to make the process canonical, the argument is included in the tuple + # so when the max is selected, it will be the largest arg having a + # given depth + v = [(sqrt_depth(x), x, i) for i, x in enumerate(pargs)] + nmax = max(v, key=default_sort_key) + if nmax[0] == 0: + res = [] + else: + # select r + depth, _, i = nmax + r = pargs.pop(i) + v.pop(i) + b = S.One + if r.is_Mul: + bv = [] + rv = [] + for x in r.args: + if sqrt_depth(x) < depth: + bv.append(x) + else: + rv.append(x) + b = Mul._from_args(bv) + r = Mul._from_args(rv) + # collect terms containing r + a1 = [] + b1 = [b] + for x in v: + if x[0] < depth: + a1.append(x[1]) + else: + x1 = x[1] + if x1 == r: + b1.append(1) + else: + if x1.is_Mul: + x1args = list(x1.args) + if r in x1args: + x1args.remove(r) + b1.append(Mul(*x1args)) + else: + a1.append(x[1]) + else: + a1.append(x[1]) + a = Add(*a1) + b = Add(*b1) + res = (a, b, r**2) + else: + b, r = p.as_coeff_Mul() + if is_sqrt(r): + res = (S.Zero, b, r**2) + else: + res = [] + return list(res) + + +class SqrtdenestStopIteration(StopIteration): + pass + + +def _sqrtdenest0(expr): + """Returns expr after denesting its arguments.""" + + if is_sqrt(expr): + n, d = expr.as_numer_denom() + if d is S.One: # n is a square root + if n.base.is_Add: + args = sorted(n.base.args, key=default_sort_key) + if len(args) > 2 and all((x**2).is_Integer for x in args): + try: + return _sqrtdenest_rec(n) + except SqrtdenestStopIteration: + pass + expr = sqrt(_mexpand(Add(*[_sqrtdenest0(x) for x in args]))) + return _sqrtdenest1(expr) + else: + n, d = [_sqrtdenest0(i) for i in (n, d)] + return n/d + + if isinstance(expr, Add): + cs = [] + args = [] + for arg in expr.args: + c, a = arg.as_coeff_Mul() + cs.append(c) + args.append(a) + + if all(c.is_Rational for c in cs) and all(is_sqrt(arg) for arg in args): + return _sqrt_ratcomb(cs, args) + + if isinstance(expr, Expr): + args = expr.args + if args: + return expr.func(*[_sqrtdenest0(a) for a in args]) + return expr + + +def _sqrtdenest_rec(expr): + """Helper that denests the square root of three or more surds. + + Explanation + =========== + + It returns the denested expression; if it cannot be denested it + throws SqrtdenestStopIteration + + Algorithm: expr.base is in the extension Q_m = Q(sqrt(r_1),..,sqrt(r_k)); + split expr.base = a + b*sqrt(r_k), where `a` and `b` are on + Q_(m-1) = Q(sqrt(r_1),..,sqrt(r_(k-1))); then a**2 - b**2*r_k is + on Q_(m-1); denest sqrt(a**2 - b**2*r_k) and so on. + See [1], section 6. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrtdenest_rec + >>> _sqrtdenest_rec(sqrt(-72*sqrt(2) + 158*sqrt(5) + 498)) + -sqrt(10) + sqrt(2) + 9 + 9*sqrt(5) + >>> w=-6*sqrt(55)-6*sqrt(35)-2*sqrt(22)-2*sqrt(14)+2*sqrt(77)+6*sqrt(10)+65 + >>> _sqrtdenest_rec(sqrt(w)) + -sqrt(11) - sqrt(7) + sqrt(2) + 3*sqrt(5) + """ + from sympy.simplify.radsimp import radsimp, rad_rationalize, split_surds + if not expr.is_Pow: + return sqrtdenest(expr) + if expr.base < 0: + return sqrt(-1)*_sqrtdenest_rec(sqrt(-expr.base)) + g, a, b = split_surds(expr.base) + a = a*sqrt(g) + if a < b: + a, b = b, a + c2 = _mexpand(a**2 - b**2) + if len(c2.args) > 2: + g, a1, b1 = split_surds(c2) + a1 = a1*sqrt(g) + if a1 < b1: + a1, b1 = b1, a1 + c2_1 = _mexpand(a1**2 - b1**2) + c_1 = _sqrtdenest_rec(sqrt(c2_1)) + d_1 = _sqrtdenest_rec(sqrt(a1 + c_1)) + num, den = rad_rationalize(b1, d_1) + c = _mexpand(d_1/sqrt(2) + num/(den*sqrt(2))) + else: + c = _sqrtdenest1(sqrt(c2)) + + if sqrt_depth(c) > 1: + raise SqrtdenestStopIteration + ac = a + c + if len(ac.args) >= len(expr.args): + if count_ops(ac) >= count_ops(expr.base): + raise SqrtdenestStopIteration + d = sqrtdenest(sqrt(ac)) + if sqrt_depth(d) > 1: + raise SqrtdenestStopIteration + num, den = rad_rationalize(b, d) + r = d/sqrt(2) + num/(den*sqrt(2)) + r = radsimp(r) + return _mexpand(r) + + +def _sqrtdenest1(expr, denester=True): + """Return denested expr after denesting with simpler methods or, that + failing, using the denester.""" + + from sympy.simplify.simplify import radsimp + + if not is_sqrt(expr): + return expr + + a = expr.base + if a.is_Atom: + return expr + val = _sqrt_match(a) + if not val: + return expr + + a, b, r = val + # try a quick numeric denesting + d2 = _mexpand(a**2 - b**2*r) + if d2.is_Rational: + if d2.is_positive: + z = _sqrt_numeric_denest(a, b, r, d2) + if z is not None: + return z + else: + # fourth root case + # sqrtdenest(sqrt(3 + 2*sqrt(3))) = + # sqrt(2)*3**(1/4)/2 + sqrt(2)*3**(3/4)/2 + dr2 = _mexpand(-d2*r) + dr = sqrt(dr2) + if dr.is_Rational: + z = _sqrt_numeric_denest(_mexpand(b*r), a, r, dr2) + if z is not None: + return z/root(r, 4) + + else: + z = _sqrt_symbolic_denest(a, b, r) + if z is not None: + return z + + if not denester or not is_algebraic(expr): + return expr + + res = sqrt_biquadratic_denest(expr, a, b, r, d2) + if res: + return res + + # now call to the denester + av0 = [a, b, r, d2] + z = _denester([radsimp(expr**2)], av0, 0, sqrt_depth(expr))[0] + if av0[1] is None: + return expr + if z is not None: + if sqrt_depth(z) == sqrt_depth(expr) and count_ops(z) > count_ops(expr): + return expr + return z + return expr + + +def _sqrt_symbolic_denest(a, b, r): + """Given an expression, sqrt(a + b*sqrt(b)), return the denested + expression or None. + + Explanation + =========== + + If r = ra + rb*sqrt(rr), try replacing sqrt(rr) in ``a`` with + (y**2 - ra)/rb, and if the result is a quadratic, ca*y**2 + cb*y + cc, and + (cb + b)**2 - 4*ca*cc is 0, then sqrt(a + b*sqrt(r)) can be rewritten as + sqrt(ca*(sqrt(r) + (cb + b)/(2*ca))**2). + + Examples + ======== + + >>> from sympy.simplify.sqrtdenest import _sqrt_symbolic_denest, sqrtdenest + >>> from sympy import sqrt, Symbol + >>> from sympy.abc import x + + >>> a, b, r = 16 - 2*sqrt(29), 2, -10*sqrt(29) + 55 + >>> _sqrt_symbolic_denest(a, b, r) + sqrt(11 - 2*sqrt(29)) + sqrt(5) + + If the expression is numeric, it will be simplified: + + >>> w = sqrt(sqrt(sqrt(3) + 1) + 1) + 1 + sqrt(2) + >>> sqrtdenest(sqrt((w**2).expand())) + 1 + sqrt(2) + sqrt(1 + sqrt(1 + sqrt(3))) + + Otherwise, it will only be simplified if assumptions allow: + + >>> w = w.subs(sqrt(3), sqrt(x + 3)) + >>> sqrtdenest(sqrt((w**2).expand())) + sqrt((sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2))**2) + + Notice that the argument of the sqrt is a square. If x is made positive + then the sqrt of the square is resolved: + + >>> _.subs(x, Symbol('x', positive=True)) + sqrt(sqrt(sqrt(x + 3) + 1) + 1) + 1 + sqrt(2) + """ + + a, b, r = map(sympify, (a, b, r)) + rval = _sqrt_match(r) + if not rval: + return None + ra, rb, rr = rval + if rb: + y = Dummy('y', positive=True) + try: + newa = Poly(a.subs(sqrt(rr), (y**2 - ra)/rb), y) + except PolynomialError: + return None + if newa.degree() == 2: + ca, cb, cc = newa.all_coeffs() + cb += b + if _mexpand(cb**2 - 4*ca*cc).equals(0): + z = sqrt(ca*(sqrt(r) + cb/(2*ca))**2) + if z.is_number: + z = _mexpand(Mul._from_args(z.as_content_primitive())) + return z + + +def _sqrt_numeric_denest(a, b, r, d2): + r"""Helper that denest + $\sqrt{a + b \sqrt{r}}, d^2 = a^2 - b^2 r > 0$ + + If it cannot be denested, it returns ``None``. + """ + d = sqrt(d2) + s = a + d + # sqrt_depth(res) <= sqrt_depth(s) + 1 + # sqrt_depth(expr) = sqrt_depth(r) + 2 + # there is denesting if sqrt_depth(s) + 1 < sqrt_depth(r) + 2 + # if s**2 is Number there is a fourth root + if sqrt_depth(s) < sqrt_depth(r) + 1 or (s**2).is_Rational: + s1, s2 = sign(s), sign(b) + if s1 == s2 == -1: + s1 = s2 = 1 + res = (s1 * sqrt(a + d) + s2 * sqrt(a - d)) * sqrt(2) / 2 + return res.expand() + + +def sqrt_biquadratic_denest(expr, a, b, r, d2): + """denest expr = sqrt(a + b*sqrt(r)) + where a, b, r are linear combinations of square roots of + positive rationals on the rationals (SQRR) and r > 0, b != 0, + d2 = a**2 - b**2*r > 0 + + If it cannot denest it returns None. + + Explanation + =========== + + Search for a solution A of type SQRR of the biquadratic equation + 4*A**4 - 4*a*A**2 + b**2*r = 0 (1) + sqd = sqrt(a**2 - b**2*r) + Choosing the sqrt to be positive, the possible solutions are + A = sqrt(a/2 +/- sqd/2) + Since a, b, r are SQRR, then a**2 - b**2*r is a SQRR, + so if sqd can be denested, it is done by + _sqrtdenest_rec, and the result is a SQRR. + Similarly for A. + Examples of solutions (in both cases a and sqd are positive): + + Example of expr with solution sqrt(a/2 + sqd/2) but not + solution sqrt(a/2 - sqd/2): + expr = sqrt(-sqrt(15) - sqrt(2)*sqrt(-sqrt(5) + 5) - sqrt(3) + 8) + a = -sqrt(15) - sqrt(3) + 8; sqd = -2*sqrt(5) - 2 + 4*sqrt(3) + + Example of expr with solution sqrt(a/2 - sqd/2) but not + solution sqrt(a/2 + sqd/2): + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + expr = sqrt((w**2).expand()) + a = 4*sqrt(6) + 8*sqrt(2) + 47 + 28*sqrt(3) + sqd = 29 + 20*sqrt(3) + + Define B = b/2*A; eq.(1) implies a = A**2 + B**2*r; then + expr**2 = a + b*sqrt(r) = (A + B*sqrt(r))**2 + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import _sqrt_match, sqrt_biquadratic_denest + >>> z = sqrt((2*sqrt(2) + 4)*sqrt(2 + sqrt(2)) + 5*sqrt(2) + 8) + >>> a, b, r = _sqrt_match(z**2) + >>> d2 = a**2 - b**2*r + >>> sqrt_biquadratic_denest(z, a, b, r, d2) + sqrt(2) + sqrt(sqrt(2) + 2) + 2 + """ + from sympy.simplify.radsimp import radsimp, rad_rationalize + if r <= 0 or d2 < 0 or not b or sqrt_depth(expr.base) < 2: + return None + for x in (a, b, r): + for y in x.args: + y2 = y**2 + if not y2.is_Integer or not y2.is_positive: + return None + sqd = _mexpand(sqrtdenest(sqrt(radsimp(d2)))) + if sqrt_depth(sqd) > 1: + return None + x1, x2 = [a/2 + sqd/2, a/2 - sqd/2] + # look for a solution A with depth 1 + for x in (x1, x2): + A = sqrtdenest(sqrt(x)) + if sqrt_depth(A) > 1: + continue + Bn, Bd = rad_rationalize(b, _mexpand(2*A)) + B = Bn/Bd + z = A + B*sqrt(r) + if z < 0: + z = -z + return _mexpand(z) + return None + + +def _denester(nested, av0, h, max_depth_level): + """Denests a list of expressions that contain nested square roots. + + Explanation + =========== + + Algorithm based on . + + It is assumed that all of the elements of 'nested' share the same + bottom-level radicand. (This is stated in the paper, on page 177, in + the paragraph immediately preceding the algorithm.) + + When evaluating all of the arguments in parallel, the bottom-level + radicand only needs to be denested once. This means that calling + _denester with x arguments results in a recursive invocation with x+1 + arguments; hence _denester has polynomial complexity. + + However, if the arguments were evaluated separately, each call would + result in two recursive invocations, and the algorithm would have + exponential complexity. + + This is discussed in the paper in the middle paragraph of page 179. + """ + from sympy.simplify.simplify import radsimp + if h > max_depth_level: + return None, None + if av0[1] is None: + return None, None + if (av0[0] is None and + all(n.is_Number for n in nested)): # no arguments are nested + for f in _subsets(len(nested)): # test subset 'f' of nested + p = _mexpand(Mul(*[nested[i] for i in range(len(f)) if f[i]])) + if f.count(1) > 1 and f[-1]: + p = -p + sqp = sqrt(p) + if sqp.is_Rational: + return sqp, f # got a perfect square so return its square root. + # Otherwise, return the radicand from the previous invocation. + return sqrt(nested[-1]), [0]*len(nested) + else: + R = None + if av0[0] is not None: + values = [av0[:2]] + R = av0[2] + nested2 = [av0[3], R] + av0[0] = None + else: + values = list(filter(None, [_sqrt_match(expr) for expr in nested])) + for v in values: + if v[2]: # Since if b=0, r is not defined + if R is not None: + if R != v[2]: + av0[1] = None + return None, None + else: + R = v[2] + if R is None: + # return the radicand from the previous invocation + return sqrt(nested[-1]), [0]*len(nested) + nested2 = [_mexpand(v[0]**2) - + _mexpand(R*v[1]**2) for v in values] + [R] + d, f = _denester(nested2, av0, h + 1, max_depth_level) + if not f: + return None, None + if not any(f[i] for i in range(len(nested))): + v = values[-1] + return sqrt(v[0] + _mexpand(v[1]*d)), f + else: + p = Mul(*[nested[i] for i in range(len(nested)) if f[i]]) + v = _sqrt_match(p) + if 1 in f and f.index(1) < len(nested) - 1 and f[len(nested) - 1]: + v[0] = -v[0] + v[1] = -v[1] + if not f[len(nested)]: # Solution denests with square roots + vad = _mexpand(v[0] + d) + if vad <= 0: + # return the radicand from the previous invocation. + return sqrt(nested[-1]), [0]*len(nested) + if not(sqrt_depth(vad) <= sqrt_depth(R) + 1 or + (vad**2).is_Number): + av0[1] = None + return None, None + + sqvad = _sqrtdenest1(sqrt(vad), denester=False) + if not (sqrt_depth(sqvad) <= sqrt_depth(R) + 1): + av0[1] = None + return None, None + sqvad1 = radsimp(1/sqvad) + res = _mexpand(sqvad/sqrt(2) + (v[1]*sqrt(R)*sqvad1/sqrt(2))) + return res, f + + # sign(v[1])*sqrt(_mexpand(v[1]**2*R*vad1/2))), f + else: # Solution requires a fourth root + s2 = _mexpand(v[1]*R) + d + if s2 <= 0: + return sqrt(nested[-1]), [0]*len(nested) + FR, s = root(_mexpand(R), 4), sqrt(s2) + return _mexpand(s/(sqrt(2)*FR) + v[0]*FR/(sqrt(2)*s)), f + + +def _sqrt_ratcomb(cs, args): + """Denest rational combinations of radicals. + + Based on section 5 of [1]. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.simplify.sqrtdenest import sqrtdenest + >>> z = sqrt(1+sqrt(3)) + sqrt(3+3*sqrt(3)) - sqrt(10+6*sqrt(3)) + >>> sqrtdenest(z) + 0 + """ + from sympy.simplify.radsimp import radsimp + + # check if there exists a pair of sqrt that can be denested + def find(a): + n = len(a) + for i in range(n - 1): + for j in range(i + 1, n): + s1 = a[i].base + s2 = a[j].base + p = _mexpand(s1 * s2) + s = sqrtdenest(sqrt(p)) + if s != sqrt(p): + return s, i, j + + indices = find(args) + if indices is None: + return Add(*[c * arg for c, arg in zip(cs, args)]) + + s, i1, i2 = indices + + c2 = cs.pop(i2) + args.pop(i2) + a1 = args[i1] + + # replace a2 by s/a1 + cs[i1] += radsimp(c2 * s / a1.base) + + return _sqrt_ratcomb(cs, args) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__init__.py b/phivenv/Lib/site-packages/sympy/simplify/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148390a3ba1e53cbf9aa0e4c44415bdc819baf33 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a8149c20ff90b0653fe3c46e2c23f721048f0c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e246ab2c43753f6697d1716d0e12a47e1068e7f3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_cse_diff.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_cse_diff.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e92b8bce5eb92c5bd5bf94107732eedf3c85136b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_cse_diff.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12f3039c3c57bc290e34a24eb8c8a7ac72b97323 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2d8ca87e556a52206c7753c86e6fb8c7ef6bd0a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58a8066b5023ac1bebca78de8ee4e98485cff173 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..713456f395d029c51ed5b62590b2e28ef6968edd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45eed5a9b2a35c1f9695e5dc45b2478cc0eebde2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21220790aa690cacf4473546aea3df8861b2ff93 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c21bee36093dc62027ed4d9403824e02c2cb0fa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40546f1093f1fa9c49ccb289acdf37b24f3c6f4d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4e02acdeb3fa3dd36fa0431a7588f3f7cf8749f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11f664df4cb9966d153e1765ce3331350d94c7f2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d864b8cfb85dc6910c2f78664a4aef6a1d0ad4c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebe1c7f1e95792d43a89cec3c029870ada10c0ae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_combsimp.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e56758a005fbb013c2b6ea4121b16c3434a54b03 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_combsimp.py @@ -0,0 +1,75 @@ +from sympy.core.numbers import Rational +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.combsimp import combsimp +from sympy.abc import x + + +def test_combsimp(): + k, m, n = symbols('k m n', integer = True) + + assert combsimp(factorial(n)) == factorial(n) + assert combsimp(binomial(n, k)) == binomial(n, k) + + assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n) + assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k) + + assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \ + Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3))) + + assert combsimp(factorial(n)**2/factorial(n - 3)) == \ + factorial(n)*n*(-1 + n)*(-2 + n) + assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \ + factorial(n + 1)/(1 + k) + + assert combsimp(gamma(n + 3)) == factorial(n + 2) + + assert combsimp(factorial(x)) == gamma(x + 1) + + # issue 9699 + assert combsimp((n + 1)*factorial(n)) == factorial(n + 1) + assert combsimp(factorial(n)/n) == factorial(n-1) + + # issue 6658 + assert combsimp(binomial(n, n - k)) == binomial(n, k) + + # issue 6341, 7135 + assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \ + binomial(n, k) + assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \ + 1/binomial(n, k) + assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n) + assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/ + factorial(n)**3) == binomial(2*n, n)/binomial(n, k) + + assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1 + + assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \ + (-1)**n*(n + 1)*(n + 2)*(n + 3) + assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \ + (-1)**(n - 1)*n*(n + 1)*(n + 2) + assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \ + (-1)**(n - 3)*n*(n - 1)*(n - 2) + assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \ + -(-1)**(-n - 1)*n*(n - 1)*(n - 2) + + assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \ + (n + 1)*(n + 2)*(n + 3) + assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \ + n*(n + 1)*(n + 2) + assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \ + n*(n - 1)*(n - 2) + assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \ + -n*(n - 1)*(n - 2) + + +def test_issue_6878(): + n = symbols('n', integer=True) + assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n) + + +def test_issue_14528(): + p = symbols("p", integer=True, positive=True) + assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p)) + assert combsimp(factorial(2-p)) == factorial(2-p) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_cse.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a34dfb0e227547bd41bed2491284fd7150d0b6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_cse.py @@ -0,0 +1,761 @@ +from functools import reduce +import itertools +from operator import add + +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose +from sympy.polys.rootoftools import CRootOf +from sympy.series.order import O +from sympy.simplify.cse_main import cse +from sympy.simplify.simplify import signsimp +from sympy.tensor.indexed import (Idx, IndexedBase) + +from sympy.core.function import count_ops +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.functions.special.hyper import meijerg +from sympy.simplify import cse_main, cse_opts +from sympy.utilities.iterables import subsets +from sympy.testing.pytest import XFAIL, raises +from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import MatrixSymbol + + +w, x, y, z = symbols('w,x,y,z') +x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') + + +def test_numbered_symbols(): + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] + ns = cse_main.numbered_symbols() + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] + +# Dummy "optimization" functions for testing. + + +def opt1(expr): + return expr + y + + +def opt2(expr): + return expr*z + + +def test_preprocess_for_cse(): + assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y + assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x + assert cse_main.preprocess_for_cse(x, [(None, None)]) == x + assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert cse_main.preprocess_for_cse( + x, [(opt1, None), (opt2, None)]) == (x + y)*z + + +def test_postprocess_for_cse(): + assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x + assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y + assert cse_main.postprocess_for_cse(x, [(None, None)]) == x + assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + # Note the reverse order of application. + assert cse_main.postprocess_for_cse( + x, [(None, opt1), (None, opt2)]) == x*z + y + + +def test_cse_single(): + # Simple substitution. + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + + subst42, (red42,) = cse([42]) # issue_15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse([0.5]) + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_single2(): + # Simple substitution, test for being able to pass the expression directly + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse(e) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + substs, reduced = cse(Matrix([[1]])) + assert isinstance(reduced[0], Matrix) + + subst42, (red42,) = cse(42) # issue 15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse(0.5) # issue 15082 + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_not_possible(): + # No substitution possible. + e = Add(x, y) + substs, reduced = cse([e]) + assert substs == [] + assert reduced == [x + y] + # issue 6329 + eq = (meijerg((1, 2), (y, 4), (5,), [], x) + + meijerg((1, 3), (y, 4), (5,), [], x)) + assert cse(eq) == ([], [eq]) + + +def test_nested_substitution(): + # Substitution within a substitution. + e = Add(Pow(w*x + y, 2), sqrt(w*x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, w*x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_subtraction_opt(): + # Make sure subtraction is optimized. + e = (x - y)*(z - y) + exp((x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [-x0 + exp(-x0)] + e = -(x - y)*(z - y) + exp(-(x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [x0 + exp(x0)] + # issue 4077 + n = -1 + 1/x + e = n/x/(-n)**2 - 1/n/x + assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \ + ([], [0]) + assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \ + ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3]) + + +def test_multiple_expressions(): + e1 = (x + y)*z + e2 = (x + y)*w + substs, reduced = cse([e1, e2]) + assert substs == [(x0, x + y)] + assert reduced == [x0*z, x0*w] + l = [w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [z + x*x0, x0] + l = [w*x*y, w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [x1, x1 + z, x0] + l = [(x - z)*(y - z), x - z, y - z] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] + assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] + assert reduced == [x1*x2, x1, x2] + l = [w*y + w + x + y + z, w*x*y] + assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) + assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) + assert cse([x + y, x + z]) == ([], [x + y, x + z]) + assert cse([x*y, z + x*y, x*y*z + 3]) == \ + ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) + + +@XFAIL # CSE of non-commutative Mul terms is disabled +def test_non_commutative_cse(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([(x0, A*B)], [x0*C, x0]) + + +# Test if CSE of non-commutative Mul terms is disabled +def test_bypass_non_commutatives(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([], l) + l = [B*C, A*B*C] + assert cse(l) == ([], l) + + +@XFAIL # CSE fails when replacing non-commutative sub-expressions +def test_non_commutative_order(): + A, B, C = symbols('A B C', commutative=False) + x0 = symbols('x0', commutative=False) + l = [B+C, A*(B+C)] + assert cse(l) == ([(x0, B+C)], [x0, A*x0]) + + +@XFAIL # Worked in gh-11232, but was reverted due to performance considerations +def test_issue_10228(): + assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0]) + assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0]) + assert cse((w + 2*x + y + z, w + x + 1)) == ( + [(x0, w + x)], [x0 + x + y + z, x0 + 1]) + assert cse(((w + x + y + z)*(w - x))/(w + x)) == ( + [(x0, w + x)], [(x0 + y + z)*(w - x)/x0]) + a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m') + exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2) + assert cse(exprs) == ( + [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1] +) + +@XFAIL +def test_powers(): + assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0]) + + +def test_issue_4498(): + assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ + ([], [(w - z)/(x - y)]) + + +def test_issue_4020(): + assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ + == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) + + +def test_issue_4203(): + assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) + + +def test_issue_6263(): + e = Eq(x*(-x + 1) + x*(x - 1), 0) + assert cse(e, optimizations='basic') == ([], [True]) + + +def test_issue_25043(): + c = symbols("c") + x = symbols("x0", real=True) + cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1] + free = cse_expr.free_symbols + assert len(free) == len({i.name for i in free}) + + +def test_dont_cse_tuples(): + from sympy.core.function import Subs + f = Function("f") + g = Function("g") + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + assert name_val == [] + assert expr == (Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) + + assert name_val == [(x0, x + y)] + assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ + Subs(g(x, y), (x, y), (0, x0)) + + +def test_pow_invpow(): + assert cse(1/x**2 + x**2) == \ + ([(x0, x**2)], [x0 + 1/x0]) + assert cse(x**2 + (1 + 1/x**2)/x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)]) + assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1]) + assert cse(cos(1/x**2) + sin(1/x**2)) == \ + ([(x0, x**(-2))], [sin(x0) + cos(x0)]) + assert cse(cos(x**2) + sin(x**2)) == \ + ([(x0, x**2)], [sin(x0) + cos(x0)]) + assert cse(y/(2 + x**2) + z/x**2/y) == \ + ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)]) + assert cse(exp(x**2) + x**2*cos(1/x**2)) == \ + ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)]) + assert cse((1 + 1/x**2)/x**2) == \ + ([(x0, x**(-2))], [x0*(x0 + 1)]) + assert cse(x**(2*y) + x**(-2*y)) == \ + ([(x0, x**(2*y))], [x0 + 1/x0]) + + +def test_postprocess(): + eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], + postprocess=cse_main.cse_separate) == \ + [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], + [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] + + +def test_issue_4499(): + # previously, this gave 16 constants + from sympy.abc import a, b + B = Function('B') + G = Function('G') + t = Tuple(* + (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - + b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), + sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, + sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, + sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), + (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, + sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b, + -2*a)) + c = cse(t) + ans = ( + [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)), + (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)), + (x8, B(b, x4)), (x9, x6*B(x2, x4))], + [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9, + 1, 0, S.Half, z/2, -x3, -x1, -x0)]) + assert ans == c + + +def test_issue_6169(): + r = CRootOf(x**6 - 4*x**5 - 2, 1) + assert cse(r) == ([], [r]) + # and a check that the right thing is done with the new + # mechanism + assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y + + +def test_cse_Indexed(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + i = Idx('i', len_y-1) + + expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) + expr2 = 1/(x[i+1]-x[i]) + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + +def test_cse_MatrixSymbol(): + # MatrixSymbols have non-Basic args, so make sure that works + A = MatrixSymbol("A", 3, 3) + assert cse(A) == ([], [A]) + + n = symbols('n', integer=True) + B = MatrixSymbol("B", n, n) + assert cse(B) == ([], [B]) + + assert cse(A[0] * A[0]) == ([], [A[0]*A[0]]) + + assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0]) + +def test_cse_MatrixExpr(): + A = MatrixSymbol('A', 3, 3) + y = MatrixSymbol('y', 3, 1) + + expr1 = (A.T*A).I * A * y + expr2 = (A.T*A) * A * y + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + replacements, reduced_exprs = cse([expr1 + expr2, expr1]) + assert replacements + + replacements, reduced_exprs = cse([A**2, A + A**2]) + assert replacements + + +def test_Piecewise(): + f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) + ans = cse(f) + actual_ans = ([(x0, x*y)], + [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))]) + assert ans == actual_ans + + +def test_ignore_order_terms(): + eq = exp(x).series(x,0,3) + sin(y+x**3) - 1 + assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)]) + + +def test_name_conflict(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_name_conflict_cust_symbols(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l, symbols("x:10")) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_symbols_exhausted_error(): + l = cos(x+y)+x+y+cos(w+y)+sin(w+y) + sym = [x, y, z] + with raises(ValueError): + cse(l, symbols=sym) + + +def test_issue_7840(): + # daveknippers' example + C393 = sympify( \ + 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ + C391 > 2.35), (C392, True)), True))' + ) + C391 = sympify( \ + 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' + ) + C393 = C393.subs('C391',C391) + # simple substitution + sub = {} + sub['C390'] = 0.703451854 + sub['C392'] = 1.01417794 + ss_answer = C393.subs(sub) + # cse + substitutions,new_eqn = cse(C393) + for pair in substitutions: + sub[pair[0].name] = pair[1].subs(sub) + cse_answer = new_eqn[0].subs(sub) + # both methods should be the same + assert ss_answer == cse_answer + + # GitRay's example + expr = sympify( + "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \ + (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \ + Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \ + Symbol('AUTO'))), (Symbol('OFF'), true)), true))" + ) + substitutions, new_eqn = cse(expr) + # this Piecewise should be exactly the same + assert new_eqn[0] == expr + # there should not be any replacements + assert len(substitutions) < 1 + + +def test_issue_8891(): + for cls in (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix): + m = cls(2, 2, [x + y, 0, 0, 0]) + res = cse([x + y, m]) + ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])]) + assert res == ans + assert isinstance(res[1][-1], cls) + + +def test_issue_11230(): + # a specific test that always failed + a, b, f, k, l, i = symbols('a b f k l i') + p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l] + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + + # random tests for the issue + from sympy.core.random import choice + from sympy.core.function import expand_mul + s = symbols('a:m') + # 35 Mul tests, none of which should ever fail + ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + assert p == C + # 35 Add tests, none of which should ever fail + ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Add for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + # use expand_mul to handle cases like this: + # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g] + # x0 = 2*(b + e) is identified giving a rebuilt p that + # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]` + assert p == [expand_mul(i) for i in C] + + +@XFAIL +def test_issue_11577(): + def check(eq): + r, c = cse(eq) + assert eq.count_ops() >= \ + len(r) + sum(i[1].count_ops() for i in r) + \ + count_ops(c) + + eq = x**5*y**2 + x**5*y + x**5 + assert cse(eq) == ( + [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1]) + # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or + # ([(x0, x**5)], [x0*y**2 + x0*y + x0]) + check(eq) + + eq = x**2/(y + 1)**2 + x/(y + 1) + assert cse(eq) == ( + [(x0, y + 1)], [x**2/x0**2 + x/x0]) + # ([(x0, x/(y + 1))], [x0**2 + x0]) + check(eq) + + +def test_hollow_rejection(): + eq = [x + 3, x + 4] + assert cse(eq) == ([], eq) + + +def test_cse_ignore(): + exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))] + subst1, red1 = cse(exprs) + assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y" + + subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions + assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored" + assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression" + + +def test_cse_ignore_issue_15002(): + l = [ + w*exp(x)*exp(-z), + exp(y)*exp(x)*exp(-z) + ] + substs, reduced = cse(l, ignore=(x,)) + rl = [e.subs(reversed(substs)) for e in reduced] + assert rl == l + + +def test_cse_unevaluated(): + xp1 = UnevaluatedExpr(x + 1) + # This used to cause RecursionError + [(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)]) + if ue == xp1: + assert red == (-1 - x0) / (1 - x0) + elif ue == -xp1: + assert red == (-1 + x0) / (1 + x0) + else: + msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}' + assert False, msg + + +def test_cse__performance(): + nexprs, nterms = 3, 20 + x = symbols('x:%d' % nterms) + exprs = [ + reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)]) + for i in range(nexprs) + ] + assert (exprs[0] + exprs[1]).simplify() == 0 + subst, red = cse(exprs) + assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE" + for i, e in enumerate(red): + assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0 + + +def test_issue_12070(): + exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z] + subst, red = cse(exprs) + assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) + + count_ops(red)) + + +def test_issue_13000(): + eq = x/(-4*x**2 + y**2) + cse_eq = cse(eq)[1][0] + assert cse_eq == eq + + +def test_issue_18203(): + eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1) + assert cse(eq) == ([], [eq]) + + +def test_unevaluated_mul(): + eq = Mul(x + y, x + y, evaluate=False) + assert cse(eq) == ([(x0, x + y)], [x0**2]) + + +def test_cse_release_variables(): + from sympy.simplify.cse_main import cse_release_variables + _0, _1, _2, _3, _4 = symbols('_:5') + eqs = [(x + y - 1)**2, x, + x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, + (2*x + 1)**(x + y)] + r, e = cse(eqs, postprocess=cse_release_variables) + # this can change in keeping with the intention of the function + assert r, e == ([ + (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1), + (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1), + (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4)) + r.reverse() + r = [(s, v) for s, v in r if v is not None] + assert eqs == [i.subs(r) for i in e] + + +def test_cse_list(): + _cse = lambda x: cse(x, list=False) + assert _cse(x) == ([], x) + assert _cse('x') == ([], 'x') + it = [x] + for c in (list, tuple, set): + assert _cse(c(it)) == ([], c(it)) + #Tuple works different from tuple: + assert _cse(Tuple(*it)) == ([], Tuple(*it)) + d = {x: 1} + assert _cse(d) == ([], d) + +def test_issue_18991(): + A = MatrixSymbol('A', 2, 2) + assert signsimp(-A * A - A) == -A * A - A + + +def test_unevaluated_Mul(): + m = [Mul(1, 2, evaluate=False)] + assert cse(m) == ([], m) + + +def test_cse_matrix_expression_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = Inverse(A) + cse_expr = cse(x) + assert cse_expr == ([], [Inverse(A)]) + + +def test_cse_matrix_expression_matmul_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatMul(Inverse(A), b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matrix(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matmul_not_extracted(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A, B) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +@XFAIL # No simplification rule for nested associative operations +def test_cse_matrix_nested_matmul_collapsed(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, MatMul(A, B)) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)]) + + +def test_cse_matrix_optimize_out_single_argument_mul(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatMul(MatMul(A))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_mul_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_optimize_out_single_argument_add(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatAdd(MatAdd(MatAdd(A)))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_add_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_expression_matrix_solve(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatrixSolve(A, b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_matrix_expression(): + X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2) + y = ImmutableDenseMatrix(symbols('y:2')) + b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y) + cse_expr = cse(b) + x0 = MatrixSymbol('x0', 2, 2) + reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y) + assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected]) + + +def test_cse_matrix_kalman_filter(): + """Kalman Filter example from Matthew Rocklin's SciPy 2013 talk. + + Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation" + + Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html + + Notes + ===== + + Equations are: + + new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data) + = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma + = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma)) + + """ + N = 2 + mu = ImmutableDenseMatrix(symbols(f'mu:{N}')) + Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N) + H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N) + R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N) + data = ImmutableDenseMatrix(symbols(f'data:{N}')) + new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma)) + cse_expr = cse([new_mu, new_Sigma]) + x0 = MatrixSymbol('x0', N, N) + x1 = MatrixSymbol('x1', N, N) + replacements_expected = [ + (x0, Transpose(H)), + (x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))), + ] + reduced_exprs_expected = [ + MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))), + MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)), + ] + assert cse_expr == (replacements_expected, reduced_exprs_expected) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_cse_diff.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_cse_diff.py new file mode 100644 index 0000000000000000000000000000000000000000..92b2d3d6bbaafb838a5e75f32a214511a1d39567 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_cse_diff.py @@ -0,0 +1,206 @@ +"""Tests for the ``sympy.simplify._cse_diff.py`` module.""" + +import pytest + +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.numbers import Integer +from sympy.core.function import Function +from sympy.core import Derivative +from sympy.functions.elementary.exponential import exp +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.physics.mechanics import dynamicsymbols +from sympy.simplify._cse_diff import (_forward_jacobian, + _remove_cse_from_derivative, + _forward_jacobian_cse, + _forward_jacobian_norm_in_cse_out) +from sympy.simplify.simplify import simplify +from sympy.matrices import Matrix, eye + +from sympy.testing.pytest import raises +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.simplify.trigsimp import trigsimp + +from sympy import cse + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + +q1, q2, q3 = dynamicsymbols('q1 q2 q3') + +# Define the custom functions +k = Function('k')(x, y) +f = Function('f')(k, z) + +zero = Integer(0) +one = Integer(1) +two = Integer(2) +neg_one = Integer(-1) + + +@pytest.mark.parametrize( + 'expr, wrt', + [ + ([zero], [x]), + ([one], [x]), + ([two], [x]), + ([neg_one], [x]), + ([x], [x]), + ([y], [x]), + ([x + y], [x]), + ([x*y], [x]), + ([x**2], [x]), + ([x**y], [x]), + ([exp(x)], [x]), + ([sin(x)], [x]), + ([tan(x)], [x]), + ([zero, one, x, y, x*y, x + y], [x, y]), + ([((x/y) + sin(x/y) - exp(y))*((x/y) - exp(y))], [x, y]), + ([w*tan(y*z)/(x - tan(y*z)), w*x*tan(y*z)/(x - tan(y*z))], [w, x, y, z]), + ([q1**2 + q2, q2**2 + q3, q3**2 + q1], [q1, q2, q3]), + ([f + Derivative(f, x) + k + 2*x], [x]) + ] +) + + +def test_forward_jacobian(expr, wrt): + expr = ImmutableDenseMatrix([expr]).T + wrt = ImmutableDenseMatrix([wrt]).T + jacobian = _forward_jacobian(expr, wrt) + zeros = ImmutableDenseMatrix.zeros(*jacobian.shape) + assert simplify(jacobian - expr.jacobian(wrt)) == zeros + + +def test_process_cse(): + x, y, z = symbols('x y z') + f = Function('f') + k = Function('k') + expr = Matrix([f(k(x,y), z) + Derivative(f(k(x,y), z), x) + k(x,y) + 2*x]) + repl, reduced = cse(expr) + p_repl, p_reduced = _remove_cse_from_derivative(repl, reduced) + + x0 = symbols('x0') + x1 = symbols('x1') + + expected_output = ( + [(x0, k(x, y)), (x1, f(x0, z))], + [Matrix([2 * x + x0 + x1 + Derivative(f(k(x, y), z), x)])] + ) + + assert p_repl == expected_output[0], f"Expected {expected_output[0]}, but got {p_repl}" + assert p_reduced == expected_output[1], f"Expected {expected_output[1]}, but got {p_reduced}" + + +def test_io_matrix_type(): + x, y, z = symbols('x y z') + expr = ImmutableDenseMatrix([ + x * y + y * z + x * y * z, + x ** 2 + y ** 2 + z ** 2, + x * y + x * z + y * z + ]) + wrt = ImmutableDenseMatrix([x, y, z]) + + replacements, reduced_expr = cse(expr) + + # Test _forward_jacobian_core + replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt) + assert isinstance(jacobian_core[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input" + + # Test _forward_jacobian_norm_in_dag_out + replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out( + expr, wrt) + assert isinstance(jacobian_norm[0], type(reduced_expr[0])), "Jacobian should be a Matrix of the same type as the input" + + # Test _forward_jacobian + jacobian = _forward_jacobian(expr, wrt) + assert isinstance(jacobian, type(expr)), "Jacobian should be a Matrix of the same type as the input" + + +def test_forward_jacobian_input_output(): + x, y, z = symbols('x y z') + expr = Matrix([ + x * y + y * z + x * y * z, + x ** 2 + y ** 2 + z ** 2, + x * y + x * z + y * z + ]) + wrt = Matrix([x, y, z]) + + replacements, reduced_expr = cse(expr) + + # Test _forward_jacobian_core + replacements_core, jacobian_core, precomputed_fs_core = _forward_jacobian_cse(replacements, reduced_expr, wrt) + assert isinstance(replacements_core, type(replacements)), "Replacements should be a list" + assert isinstance(jacobian_core, type(reduced_expr)), "Jacobian should be a list" + assert isinstance(precomputed_fs_core, list), "Precomputed free symbols should be a list" + assert len(replacements_core) == len(replacements), "Length of replacements does not match" + assert len(jacobian_core) == 1, "Jacobian should have one element" + assert len(precomputed_fs_core) == len(replacements), "Length of precomputed free symbols does not match" + + # Test _forward_jacobian_norm_in_dag_out + replacements_norm, jacobian_norm, precomputed_fs_norm = _forward_jacobian_norm_in_cse_out(expr, wrt) + assert isinstance(replacements_norm, type(replacements)), "Replacements should be a list" + assert isinstance(jacobian_norm, type(reduced_expr)), "Jacobian should be a list" + assert isinstance(precomputed_fs_norm, list), "Precomputed free symbols should be a list" + assert len(replacements_norm) == len(replacements), "Length of replacements does not match" + assert len(jacobian_norm) == 1, "Jacobian should have one element" + assert len(precomputed_fs_norm) == len(replacements), "Length of precomputed free symbols does not match" + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert _forward_jacobian(L, syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert _forward_jacobian(L, syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho * cos(phi), rho * sin(phi)]) + Y = Matrix([rho, phi]) + J = _forward_jacobian(X, Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T * eye(J.shape[0]) * J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho ** 2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho * cos(phi), rho * sin(phi), rho ** 2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho * sin(phi)], + [sin(phi), rho * cos(phi)], + [2 * rho, 0], + ]) + assert _forward_jacobian(X, Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = _forward_jacobian(X_slice, Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: _forward_jacobian(X, Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: _forward_jacobian(X, Y)) + raises(TypeError, lambda: _forward_jacobian(X, Matrix([[x, y], [x, z]]))) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_epathtools.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bb47b2f2ff624077ab9905677b181c587ab5a7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_epathtools.py @@ -0,0 +1,90 @@ +"""Tests for tools for manipulation of expressions using paths. """ + +from sympy.simplify.epathtools import epath, EPath +from sympy.testing.pytest import raises + +from sympy.core.numbers import E +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x, y, z, t + + +def test_epath_select(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + + assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4] + assert epath("/*/*/*/*", expr) == [] + + assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4] + assert epath("/[:]/[:]/[:]/[:]", expr) == [] + + assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/[1]", expr) == [2, z] + assert epath("/*/[2]", expr) == [] + + assert epath("/*/int", expr) == [2] + assert epath("/*/Symbol", expr) == [z] + assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)] + + assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z] + assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z] + assert epath( + "/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]/int", expr) == [1, 3, 4] + assert epath("/*/[0]/Symbol", expr) == [x, t, y] + + assert epath("/*/[0]/int[1:]", expr) == [1, 4] + assert epath("/*/[0]/Symbol[1:]", expr) == [t, y] + + assert epath("/Symbol", x + y + z + 1) == [x, y, z] + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y] + + +def test_epath_apply(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + func = lambda expr: expr**2 + + assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]] + + assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)] + assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)] + assert epath("/*/[2]", expr, list) == expr + + assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)] + assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2), + ((3, y**2, 4), z)] + assert epath( + "/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)] + assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2), + 2), ((3, y**2, 4), z)] + + assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1 + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \ + t + sin(x**2 + 1) + cos(x**2 + y**2 + E) + + +def test_EPath(): + assert EPath("/*/[0]")._path == "/*/[0]" + assert EPath(EPath("/*/[0]"))._path == "/*/[0]" + assert isinstance(epath("/*/[0]"), EPath) is True + + assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')" + + raises(ValueError, lambda: EPath("")) + raises(ValueError, lambda: EPath("/")) + raises(ValueError, lambda: EPath("/|x")) + raises(ValueError, lambda: EPath("/[")) + raises(ValueError, lambda: EPath("/[0]%")) + + raises(NotImplementedError, lambda: EPath("Symbol")) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_fu.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_fu.py new file mode 100644 index 0000000000000000000000000000000000000000..2de2126b7333195fceeffe72dc9cb642e7eba9a9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_fu.py @@ -0,0 +1,492 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan) +from sympy.simplify.powsimp import powsimp +from sympy.simplify.fu import ( + L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16, + TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T, + TRpower, hyper_as_trig, fu, process_common_addends, trig_split, + as_f_sign_1) +from sympy.core.random import verify_numerically +from sympy.abc import a, b, c, x, y, z + + +def test_TR1(): + assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x) + + +def test_TR2(): + assert TR2(tan(x)) == sin(x)/cos(x) + assert TR2(cot(x)) == cos(x)/sin(x) + assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0 + + +def test_TR2i(): + # just a reminder that ratios of powers only simplify if both + # numerator and denominator satisfy the condition that each + # has a positive base or an integer exponent; e.g. the following, + # at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I + assert powsimp(2**x/y**x) != (2/y)**x + + assert TR2i(sin(x)/cos(x)) == tan(x) + assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y) + assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x) + assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y) + assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2 + + assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2 + assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half) + assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1) + assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2) + assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half) + assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half) + assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1) + assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2) + assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half) + assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a + assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a + assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a + assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a + assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a) + assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a) + assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a) + assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a) + + i = symbols('i', integer=True) + assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i) + assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i + + +def test_TR3(): + assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y) + assert cos(pi/2 + x) == -sin(x) + assert cos(30*pi/2 + x) == -cos(x) + + for f in (cos, sin, tan, cot, csc, sec): + i = f(pi*Rational(3, 7)) + j = TR3(i) + assert verify_numerically(i, j) and i.func != j.func + + with evaluate(False): + eq = cos(9*pi/22) + assert eq.has(9*pi) and TR3(eq) == sin(pi/11) + + +def test_TR4(): + for i in [0, pi/6, pi/4, pi/3, pi/2]: + with evaluate(False): + eq = cos(i) + assert isinstance(eq, cos) and TR4(eq) == cos(i) + + +def test__TR56(): + h = lambda x: 1 - x + assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1) + assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10 + assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3 + assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6 + assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4 + + # issue 17137 + assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I + assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1) + + +def test_TR5(): + assert TR5(sin(x)**2) == -cos(x)**2 + 1 + assert TR5(sin(x)**-2) == sin(x)**(-2) + assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2 + + +def test_TR6(): + assert TR6(cos(x)**2) == -sin(x)**2 + 1 + assert TR6(cos(x)**-2) == cos(x)**(-2) + assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2 + + +def test_TR7(): + assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half + assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2) + + +def test_TR8(): + assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2 + assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2 + assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2 + assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4 + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \ + cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \ + cos(6)/8 + Rational(1, 8) + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \ + cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \ + cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8 + assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64) + +def test_TR9(): + a = S.Half + b = 3*a + assert TR9(a) == a + assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b) + assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b) + assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b) + assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a) + assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a) + assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3) + assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2) + assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \ + 4*cos(S.Half)*cos(1)*cos(Rational(9, 2)) + assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3) + assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2) + assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2) + c = cos(x) + s = sin(x) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))): + args = zip(si, a) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR9(ex) + assert not (a[0].func == a[1].func and ( + not verify_numerically(ex, t.expand(trig=True)) or t.is_Add) + or a[1].func != a[0].func and ex != t) + + +def test_TR10(): + assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b) + assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a) + assert TR10(sin(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + assert TR10(cos(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \ + (sin(a)*cos(b) + sin(b)*cos(a))*sin(c) + + +def test_TR10i(): + assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2) + assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4) + assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7 + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4) + assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \ + 2*sin(4) + cos(3) + assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \ + cos(1) + eq = (cos(2)*cos(3) + sin(2)*( + cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5) + assert TR10i(eq) == TR10i(eq.expand()) == cos(4) + assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \ + 2*sqrt(2)*x*sin(x + pi/6) + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9 + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \ + sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9 + assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x) + assert TR10i(cos(x) + sqrt(3)*sin(x) + + 2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4) + assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \ + sin(2)*cos(4) + sin(3)*cos(2) + + A = Symbol('A', commutative=False) + assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \ + 2*sqrt(2)*sin(x + pi/6)*A + + + c = cos(x) + s = sin(x) + h = sin(y) + r = cos(y) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + c = cos(x) + s = sin(x) + h = sin(pi/6) + r = cos(pi/6) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # induced + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + +def test_TR11(): + + assert TR11(sin(2*x)) == 2*sin(x)*cos(x) + assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x)) + assert TR11(sin(x*Rational(4, 3))) == \ + 4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3)) + + assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2 + assert TR11(cos(4*x)) == \ + (-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2 + + assert TR11(cos(2)) == cos(2) + + assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2 + assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2 + assert TR11(cos(6), 2) == cos(6) + assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2) + +def test__TR11(): + + assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \ + 4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) + assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6) + + assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6)) + assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4)) + assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2) + + +def test_TR12(): + assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + assert TR12(tan(x + y + z)) ==\ + (tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/( + 1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1)) + assert TR12(tan(x*y)) == tan(x*y) + + +def test_TR13(): + assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1 + assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5) + assert TR13(tan(1)*tan(2)*tan(3)) == \ + (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1) + assert TR13(tan(1)*tan(2)*cot(3)) == \ + (-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3) + + +def test_L(): + assert L(cos(x) + sin(x)) == 2 + + +def test_fu(): + + assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2) + assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3) + + + eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2 + + assert fu(S.Half - cos(2*x)/2) == sin(x)**2 + + assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \ + sqrt(2)*sin(a + b + pi/4) + + assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3) + + assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \ + -cos(x)**2 + cos(y)**2 + + assert fu(cos(pi*Rational(4, 9))) == sin(pi/18) + assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16) + + assert fu( + tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \ + -sqrt(3) + + assert fu(tan(1)*tan(2)) == tan(1)*tan(2) + + expr = Mul(*[cos(2**i) for i in range(10)]) + assert fu(expr) == sin(1024)/(1024*sin(1)) + + # issue #18059: + assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2) + + assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \ + 7*sin(x) + 3*sqrt(3)*cos(x) + + +def test_objective(): + assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \ + tan(x) + assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \ + sin(x)/cos(x) + + +def test_process_common_addends(): + # this tests that the args are not evaluated as they are given to do + # and that key2 works when key1 is False + do = lambda x: Add(*[i**(i%2) for i in x.args]) + assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do, + key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0 + + +def test_trig_split(): + assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True) + assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True) + assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \ + (sin(y), 1, 1, x, y, True) + + assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \ + (2, 1, -1, x, pi/6, False) + assert trig_split(cos(x), sin(x), two=True) == \ + (sqrt(2), 1, 1, x, pi/4, False) + assert trig_split(cos(x), -sin(x), two=True) == \ + (sqrt(2), 1, -1, x, pi/4, False) + assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \ + (2*sqrt(2), 1, -1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \ + (-2*sqrt(2), 1, 1, x, pi/3, False) + assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \ + (sqrt(6)/3, 1, 1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x)*sin(y), + -sqrt(2)*sin(x)*sin(y), two=True) == \ + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + assert trig_split(cos(x), sin(x)) is None + assert trig_split(cos(x), sin(z)) is None + assert trig_split(2*cos(x), -sin(x)) is None + assert trig_split(cos(x), -sqrt(3)*sin(x)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None + assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \ + None + + assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None + assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None + assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None + + +def test_TRmorrie(): + assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \ + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + assert TRmorrie(x) == x + assert TRmorrie(2*x) == 2*x + e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7)) + assert TR8(TRmorrie(e)) == Rational(-1, 8) + e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)]) + assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536) + # issue 17063 + eq = cos(x)/cos(x/2) + assert TRmorrie(eq) == eq + # issue #20430 + eq = cos(x/2)*sin(x/2)*cos(x)**3 + assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4 + + +def test_TRpower(): + assert TRpower(1/sin(x)**2) == 1/sin(x)**2 + assert TRpower(cos(x)**3*sin(x/2)**4) == \ + (3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8)) + for k in range(2, 8): + assert verify_numerically(sin(x)**k, TRpower(sin(x)**k)) + assert verify_numerically(cos(x)**k, TRpower(cos(x)**k)) + + +def test_hyper_as_trig(): + from sympy.simplify.fu import _osborne, _osbornei + + eq = sinh(x)**2 + cosh(x)**2 + t, f = hyper_as_trig(eq) + assert f(fu(t)) == cosh(2*x) + e, f = hyper_as_trig(tanh(x + y)) + assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1) + + d = Dummy() + assert _osborne(sinh(x), d) == I*sin(x*d) + assert _osborne(tanh(x), d) == I*tan(x*d) + assert _osborne(coth(x), d) == cot(x*d)/I + assert _osborne(cosh(x), d) == cos(x*d) + assert _osborne(sech(x), d) == sec(x*d) + assert _osborne(csch(x), d) == csc(x*d)/I + for func in (sinh, cosh, tanh, coth, sech, csch): + h = func(pi) + assert _osbornei(_osborne(h, d), d) == h + # /!\ the _osborne functions are not meant to work + # in the o(i(trig, d), d) direction so we just check + # that they work as they are supposed to work + assert _osbornei(cos(x*y + z), y) == cosh(x + z*I) + assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I + assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I + assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I + assert _osbornei(sec(x*y + z), y) == sech(x + z*I) + assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I + + +def test_TR12i(): + ta, tb, tc = [tan(i) for i in (a, b, c)] + assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b) + assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b) + assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b) + eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + assert TR12i(eq.expand()) == \ + -3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2 + assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x) + eq = (ta + cos(2))/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (ta + tb + 2)**2/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = ta/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1) + assert TR12i(eq) == -(a + 1)**2*tan(a + b) + + +def test_TR14(): + eq = (cos(x) - 1)*(cos(x) + 1) + ans = -sin(x)**2 + assert TR14(eq) == ans + assert TR14(1/eq) == 1/ans + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2 + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1) + assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1) + eq = (cos(x) - 1)**y*(cos(x) + 1)**y + assert TR14(eq) == eq + eq = (cos(x) - 2)**y*(cos(x) + 1) + assert TR14(eq) == eq + eq = (tan(x) - 2)**2*(cos(x) + 1) + assert TR14(eq) == eq + i = symbols('i', integer=True) + assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i + assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i + # could use extraction in this case + eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i + assert TR14(eq) in [(cos(x) - 1)*ans**i, eq] + + assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2 + p1 = (cos(x) + 1)*(cos(x) - 1) + p2 = (cos(y) - 1)*2*(cos(y) + 1) + p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4) + + +def test_TR15_16_17(): + assert TR15(1 - 1/sin(x)**2) == -cot(x)**2 + assert TR16(1 - 1/cos(x)**2) == -tan(x)**2 + assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2 + + +def test_as_f_sign_1(): + assert as_f_sign_1(x + 1) == (1, x, 1) + assert as_f_sign_1(x - 1) == (1, x, -1) + assert as_f_sign_1(-x + 1) == (-1, x, -1) + assert as_f_sign_1(-x - 1) == (-1, x, 1) + assert as_f_sign_1(2*x + 2) == (2, x, 1) + assert as_f_sign_1(x*y - y) == (y, x, -1) + assert as_f_sign_1(-x*y + y) == (-y, x, -1) + + +def test_issue_25590(): + A = Symbol('A', commutative=False) + B = Symbol('B', commutative=False) + + assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A + assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A + + # XXX The result may not be optimal than + # sin(2*x)*B*A + cos(x)**2 and may change in the future + assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2 diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_function.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_function.py new file mode 100644 index 0000000000000000000000000000000000000000..441b9faf1bb3c5e7f2279b2a61066d050e45f773 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_function.py @@ -0,0 +1,54 @@ +""" Unit tests for Hyper_Function""" +from sympy.core import symbols, Dummy, Tuple, S, Rational +from sympy.functions import hyper + +from sympy.simplify.hyperexpand import Hyper_Function + +def test_attrs(): + a, b = symbols('a, b', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f.ap == Tuple(2, a) + assert f.bq == Tuple(b) + assert f.args == (Tuple(2, a), Tuple(b)) + assert f.sizes == (2, 1) + +def test_call(): + a, b, x = symbols('a, b, x', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f(x) == hyper([2, a], [b], x) + +def test_has(): + a, b, c = symbols('a, b, c', cls=Dummy) + f = Hyper_Function([2, -a], [b]) + assert f.has(a) + assert f.has(Tuple(b)) + assert not f.has(c) + +def test_eq(): + assert Hyper_Function([1], []) == Hyper_Function([1], []) + assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False + assert Hyper_Function([1], []) != Hyper_Function([2], []) + assert Hyper_Function([1], []) != Hyper_Function([1, 2], []) + assert Hyper_Function([1], []) != Hyper_Function([1], [2]) + +def test_gamma(): + assert Hyper_Function([2, 3], [-1]).gamma == 0 + assert Hyper_Function([-2, -3], [-1]).gamma == 2 + n = Dummy(integer=True) + assert Hyper_Function([-1, n, 1], []).gamma == 1 + assert Hyper_Function([-1, -n, 1], []).gamma == 1 + p = Dummy(integer=True, positive=True) + assert Hyper_Function([-1, p, 1], []).gamma == 1 + assert Hyper_Function([-1, -p, 1], []).gamma == 2 + +def test_suitable_origin(): + assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True + assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3)))._is_suitable_origin() is True + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_gammasimp.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c73093250b279510e3c2274db22818a9adffd8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_gammasimp.py @@ -0,0 +1,127 @@ +from sympy.core.function import Function +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.gammasimp import gammasimp +from sympy.simplify.powsimp import powsimp +from sympy.simplify.simplify import simplify + +from sympy.abc import x, y, n, k + + +def test_gammasimp(): + R = Rational + + # was part of test_combsimp_gamma() in test_combsimp.py + assert gammasimp(gamma(x)) == gamma(x) + assert gammasimp(gamma(x + 1)/x) == gamma(x) + assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1) + assert gammasimp(x*gamma(x)) == gamma(x + 1) + assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2) + assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1) + assert gammasimp(x/gamma(x + 1)) == 1/gamma(x) + assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1) + assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \ + (x + 2)*gamma(x + 1) + + assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2 + assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1) + + assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x) + assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x)) + assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \ + sin(pi*x)/(pi*x*(x + 1)*(x + 2)) + + assert gammasimp(factorial(n + 2)) == gamma(n + 3) + assert gammasimp(binomial(n, k)) == \ + gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1)) + + assert powsimp(gammasimp( + gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \ + 2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y) + assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \ + 3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1)) + assert simplify( + gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1 + assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3 + + assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \ + 2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi) + + # issue 6792 + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e) == -k + assert gammasimp(1/e) == -1/k + e = (gamma(x) + gamma(x + 1))/gamma(x) + assert gammasimp(e) == x + 1 + assert gammasimp(1/e) == 1/(x + 1) + e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x) + assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1) + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e**2) == k**2 + assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k) + a = R(1, 2) + R(1, 3) + b = a + R(1, 3) + assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b) + ) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2 + + # issue 9699 + assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y) + assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise( + (gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x), + ((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True)) + + A, B = symbols('A B', commutative=False) + assert gammasimp(e*B*A) == gammasimp(e)*B*A + + # check iteration + assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == ( + -2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k)))) + assert gammasimp( + gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == ( + 3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2) + + # issue 6153 + assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4 + + # was part of test_combsimp() in test_combsimp.py + assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \ + (gamma(k + R(3, 2))*gamma(-k + n + R(5, 2))) + assert gammasimp(binomial(n + 2, k + 2.0)) == \ + gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1)) + + # issue 11548 + assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x) + + e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3)) + assert gammasimp(e) == e + assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \ + 2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi) + + i, m = symbols('i m', integer = True) + e = gamma(exp(i)) + assert gammasimp(e) == e + e = gamma(m + 3) + assert gammasimp(e) == e + e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1)) + assert gammasimp(e) == e + + p = symbols("p", integer=True, positive=True) + assert gammasimp(gamma(-p + 4)) == gamma(-p + 4) + + +def test_issue_22606(): + fx = Function('f')(x) + eq = x + gamma(y) + # seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y` + ans = gammasimp(eq) + assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans + assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans + assert 1/gammasimp(1/eq) == ans + assert gammasimp(fx.subs(x, eq)).args[0] == ans diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_hyperexpand.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c703c228a13201de13cfd4c3413fc75a2cf5bdb6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_hyperexpand.py @@ -0,0 +1,1063 @@ +from sympy.core.random import randrange + +from sympy.simplify.hyperexpand import (ShiftA, ShiftB, UnShiftA, UnShiftB, + MeijerShiftA, MeijerShiftB, MeijerShiftC, MeijerShiftD, + MeijerUnShiftA, MeijerUnShiftB, MeijerUnShiftC, + MeijerUnShiftD, + ReduceOrder, reduce_order, apply_operators, + devise_plan, make_derivative_operator, Formula, + hyperexpand, Hyper_Function, G_Function, + reduce_order_meijer, + build_hypergeometric_formula) +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.abc import z, a, b, c +from sympy.testing.pytest import XFAIL, raises, slow, tooslow +from sympy.core.random import verify_numerically as tn + +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import atanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.functions.special.bessel import besseli +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import (gamma, lowergamma) + + +def test_branch_bug(): + assert hyperexpand(hyper((Rational(-1, 3), S.Half), (Rational(2, 3), Rational(3, 2)), -z)) == \ + -z**S('1/3')*lowergamma(exp_polar(I*pi)/3, z)/5 \ + + sqrt(pi)*erf(sqrt(z))/(5*sqrt(z)) + assert hyperexpand(meijerg([Rational(7, 6), 1], [], [Rational(2, 3)], [Rational(1, 6), 0], z)) == \ + 2*z**S('2/3')*(2*sqrt(pi)*erf(sqrt(z))/sqrt(z) - 2*lowergamma( + Rational(2, 3), z)/z**S('2/3'))*gamma(Rational(2, 3))/gamma(Rational(5, 3)) + + +def test_hyperexpand(): + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + assert hyperexpand(hyper([], [], z)) == exp(z) + assert hyperexpand(hyper([1, 1], [2], -z)*z) == log(1 + z) + assert hyperexpand(hyper([], [S.Half], -z**2/4)) == cos(z) + assert hyperexpand(z*hyper([], [S('3/2')], -z**2/4)) == sin(z) + assert hyperexpand(hyper([S('1/2'), S('1/2')], [S('3/2')], z**2)*z) \ + == asin(z) + assert isinstance(Sum(binomial(2, z)*z**2, (z, 0, a)).doit(), Expr) + + +def can_do(ap, bq, numerical=True, div=1, lowerplane=False): + r = hyperexpand(hyper(ap, bq, z)) + if r.has(hyper): + return False + if not numerical: + return True + repl = {} + randsyms = r.free_symbols - {z} + while randsyms: + # Only randomly generated parameters are checked. + for n, ai in enumerate(randsyms): + repl[ai] = randcplx(n)/div + if not any(b.is_Integer and b <= 0 for b in Tuple(*bq).subs(repl)): + break + [a, b, c, d] = [2, -1, 3, 1] + if lowerplane: + [a, b, c, d] = [2, -2, 3, -1] + return tn( + hyper(ap, bq, z).subs(repl), + r.replace(exp_polar, exp).subs(repl), + z, a=a, b=b, c=c, d=d) + + +def test_roach(): + # Kelly B. Roach. Meijer G Function Representations. + # Section "Gallery" + assert can_do([S.Half], [Rational(9, 2)]) + assert can_do([], [1, Rational(5, 2), 4]) + assert can_do([Rational(-1, 2), 1, 2], [3, 4]) + assert can_do([Rational(1, 3)], [Rational(-2, 3), Rational(-1, 2), S.Half, 1]) + assert can_do([Rational(-3, 2), Rational(-1, 2)], [Rational(-5, 2), 1]) + assert can_do([Rational(-3, 2), ], [Rational(-1, 2), S.Half]) # shine-integral + assert can_do([Rational(-3, 2), Rational(-1, 2)], [2]) # elliptic integrals + + +@XFAIL +def test_roach_fail(): + assert can_do([Rational(-1, 2), 1], [Rational(1, 4), S.Half, Rational(3, 4)]) # PFDD + assert can_do([Rational(3, 2)], [Rational(5, 2), 5]) # struve function + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(5, 2)]) # polylog, pfdd + assert can_do([1, 2, 3], [S.Half, 4]) # XXX ? + assert can_do([S.Half], [Rational(-1, 3), Rational(-1, 2), Rational(-2, 3)]) # PFDD ? + +# For the long table tests, see end of file + + +def test_polynomial(): + from sympy.core.numbers import oo + assert hyperexpand(hyper([], [-1], z)) is oo + assert hyperexpand(hyper([-2], [-1], z)) is oo + assert hyperexpand(hyper([0, 0], [-1], z)) == 1 + assert can_do([-5, -2, randcplx(), randcplx()], [-10, randcplx()]) + assert hyperexpand(hyper((-1, 1), (-2,), z)) == 1 + z/2 + + +def test_hyperexpand_bases(): + assert hyperexpand(hyper([2], [a], z)) == \ + a + z**(-a + 1)*(-a**2 + 3*a + z*(a - 1) - 2)*exp(z)* \ + lowergamma(a - 1, z) - 1 + # TODO [a+1, aRational(-1, 2)], [2*a] + assert hyperexpand(hyper([1, 2], [3], z)) == -2/z - 2*log(-z + 1)/z**2 + assert hyperexpand(hyper([S.Half, 2], [Rational(3, 2)], z)) == \ + -1/(2*z - 2) + atanh(sqrt(z))/sqrt(z)/2 + assert hyperexpand(hyper([S.Half, S.Half], [Rational(5, 2)], z)) == \ + (-3*z + 3)/4/(z*sqrt(-z + 1)) \ + + (6*z - 3)*asin(sqrt(z))/(4*z**Rational(3, 2)) + assert hyperexpand(hyper([1, 2], [Rational(3, 2)], z)) == -1/(2*z - 2) \ + - asin(sqrt(z))/(sqrt(z)*(2*z - 2)*sqrt(-z + 1)) + assert hyperexpand(hyper([Rational(-1, 2) - 1, 1, 2], [S.Half, 3], z)) == \ + sqrt(z)*(z*Rational(6, 7) - Rational(6, 5))*atanh(sqrt(z)) \ + + (-30*z**2 + 32*z - 6)/35/z - 6*log(-z + 1)/(35*z**2) + assert hyperexpand(hyper([1 + S.Half, 1, 1], [2, 2], z)) == \ + -4*log(sqrt(-z + 1)/2 + S.Half)/z + # TODO hyperexpand(hyper([a], [2*a + 1], z)) + # TODO [S.Half, a], [Rational(3, 2), a+1] + assert hyperexpand(hyper([2], [b, 1], z)) == \ + z**(-b/2 + S.Half)*besseli(b - 1, 2*sqrt(z))*gamma(b) \ + + z**(-b/2 + 1)*besseli(b, 2*sqrt(z))*gamma(b) + # TODO [a], [a - S.Half, 2*a] + + +def test_hyperexpand_parametric(): + assert hyperexpand(hyper([a, S.Half + a], [S.Half], z)) \ + == (1 + sqrt(z))**(-2*a)/2 + (1 - sqrt(z))**(-2*a)/2 + assert hyperexpand(hyper([a, Rational(-1, 2) + a], [2*a], z)) \ + == 2**(2*a - 1)*((-z + 1)**S.Half + 1)**(-2*a + 1) + + +def test_shifted_sum(): + from sympy.simplify.simplify import simplify + assert simplify(hyperexpand(z**4*hyper([2], [3, S('3/2')], -z**2))) \ + == z*sin(2*z) + (-z**2 + S.Half)*cos(2*z) - S.Half + + +def _randrat(): + """ Steer clear of integers. """ + return S(randrange(25) + 10)/50 + + +def randcplx(offset=-1): + """ Polys is not good with real coefficients. """ + return _randrat() + I*_randrat() + I*(1 + offset) + + +@slow +def test_formulae(): + from sympy.simplify.hyperexpand import FormulaCollection + formulae = FormulaCollection().formulae + for formula in formulae: + h = formula.func(formula.z) + rep = {} + for n, sym in enumerate(formula.symbols): + rep[sym] = randcplx(n) + + # NOTE hyperexpand returns truly branched functions. We know we are + # on the main sheet, but numerical evaluation can still go wrong + # (e.g. if exp_polar cannot be evalf'd). + # Just replace all exp_polar by exp, this usually works. + + # first test if the closed-form is actually correct + h = h.subs(rep) + closed_form = formula.closed_form.subs(rep).rewrite('nonrepsmall') + z = formula.z + assert tn(h, closed_form.replace(exp_polar, exp), z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep).rewrite('nonrepsmall') + assert tn(closed_form.replace( + exp_polar, exp), cl.replace(exp_polar, exp), z) + deriv1 = z*formula.B.applyfunc(lambda t: t.rewrite( + 'nonrepsmall')).diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep).replace(exp_polar, exp), + d2.subs(rep).rewrite('nonrepsmall').replace(exp_polar, exp), z) + + +def test_meijerg_formulae(): + from sympy.simplify.hyperexpand import MeijerFormulaCollection + formulae = MeijerFormulaCollection().formulae + for sig in formulae: + for formula in formulae[sig]: + g = meijerg(formula.func.an, formula.func.ap, + formula.func.bm, formula.func.bq, + formula.z) + rep = {} + for sym in formula.symbols: + rep[sym] = randcplx() + + # first test if the closed-form is actually correct + g = g.subs(rep) + closed_form = formula.closed_form.subs(rep) + z = formula.z + assert tn(g, closed_form, z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep) + assert tn(closed_form, cl, z) + deriv1 = z*formula.B.diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep), d2.subs(rep), z) + + +def op(f): + return z*f.diff(z) + + +def test_plan(): + assert devise_plan(Hyper_Function([0], ()), + Hyper_Function([0], ()), z) == [] + with raises(ValueError): + devise_plan(Hyper_Function([1], ()), Hyper_Function((), ()), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], [1]), Hyper_Function([2], [2]), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], []), Hyper_Function([S("1/2")], []), z) + + # We cannot use pi/(10000 + n) because polys is insanely slow. + a1, a2, b1 = (randcplx(n) for n in range(3)) + b1 += 2*I + h = hyper([a1, a2], [b1], z) + + h2 = hyper((a1 + 1, a2), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + h2 = hyper((a1 + 1, a2 - 1), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2 - 1), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + +def test_plan_derivatives(): + a1, a2, a3 = 1, 2, S('1/2') + b1, b2 = 3, S('5/2') + h = Hyper_Function((a1, a2, a3), (b1, b2)) + h2 = Hyper_Function((a1 + 1, a2 + 1, a3 + 2), (b1 + 1, b2 + 1)) + ops = devise_plan(h2, h, z) + f = Formula(h, z, h(z), []) + deriv = make_derivative_operator(f.M, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + h2 = Hyper_Function((a1, a2 - 1, a3 - 2), (b1 - 1, b2 - 1)) + ops = devise_plan(h2, h, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + +def test_reduction_operators(): + a1, a2, b1 = (randcplx(n) for n in range(3)) + h = hyper([a1], [b1], z) + + assert ReduceOrder(2, 0) is None + assert ReduceOrder(2, -1) is None + assert ReduceOrder(1, S('1/2')) is None + + h2 = hyper((a1, a2), (b1, a2), z) + assert tn(ReduceOrder(a2, a2).apply(h, op), h2, z) + + h2 = hyper((a1, a2 + 1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 1, a2).apply(h, op), h2, z) + + h2 = hyper((a2 + 4, a1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 4, a2).apply(h, op), h2, z) + + # test several step order reduction + ap = (a2 + 4, a1, b1 + 1) + bq = (a2, b1, b1) + func, ops = reduce_order(Hyper_Function(ap, bq)) + assert func.ap == (a1,) + assert func.bq == (b1,) + assert tn(apply_operators(h, ops, op), hyper(ap, bq, z), z) + + +def test_shift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: ShiftA(0)) + raises(ValueError, lambda: ShiftB(1)) + + assert tn(ShiftA(a1).apply(h, op), hyper((a1 + 1, a2), (b1, b2, b3), z), z) + assert tn(ShiftA(a2).apply(h, op), hyper((a1, a2 + 1), (b1, b2, b3), z), z) + assert tn(ShiftB(b1).apply(h, op), hyper((a1, a2), (b1 - 1, b2, b3), z), z) + assert tn(ShiftB(b2).apply(h, op), hyper((a1, a2), (b1, b2 - 1, b3), z), z) + assert tn(ShiftB(b3).apply(h, op), hyper((a1, a2), (b1, b2, b3 - 1), z), z) + + +def test_ushift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: UnShiftA((1,), (), 0, z)) + raises(ValueError, lambda: UnShiftB((), (-1,), 0, z)) + raises(ValueError, lambda: UnShiftA((1,), (0, -1, 1), 0, z)) + raises(ValueError, lambda: UnShiftB((0, 1), (1,), 0, z)) + + s = UnShiftA((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1 - 1, a2), (b1, b2, b3), z), z) + s = UnShiftA((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2 - 1), (b1, b2, b3), z), z) + + s = UnShiftB((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1 + 1, b2, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2 + 1, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 2, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2, b3 + 1), z), z) + + +def can_do_meijer(a1, a2, b1, b2, numeric=True): + """ + This helper function tries to hyperexpand() the meijer g-function + corresponding to the parameters a1, a2, b1, b2. + It returns False if this expansion still contains g-functions. + If numeric is True, it also tests the so-obtained formula numerically + (at random values) and returns False if the test fails. + Else it returns True. + """ + from sympy.core.function import expand + from sympy.functions.elementary.complexes import unpolarify + r = hyperexpand(meijerg(a1, a2, b1, b2, z)) + if r.has(meijerg): + return False + # NOTE hyperexpand() returns a truly branched function, whereas numerical + # evaluation only works on the main branch. Since we are evaluating on + # the main branch, this should not be a problem, but expressions like + # exp_polar(I*pi/2*x)**a are evaluated incorrectly. We thus have to get + # rid of them. The expand heuristically does this... + r = unpolarify(expand(r, force=True, power_base=True, power_exp=False, + mul=False, log=False, multinomial=False, basic=False)) + + if not numeric: + return True + + repl = {} + for n, ai in enumerate(meijerg(a1, a2, b1, b2, z).free_symbols - {z}): + repl[ai] = randcplx(n) + return tn(meijerg(a1, a2, b1, b2, z).subs(repl), r.subs(repl), z) + + +@slow +def test_meijerg_expand(): + from sympy.simplify.gammasimp import gammasimp + from sympy.simplify.simplify import simplify + # from mpmath docs + assert hyperexpand(meijerg([[], []], [[0], []], -z)) == exp(z) + + assert hyperexpand(meijerg([[1, 1], []], [[1], [0]], z)) == \ + log(z + 1) + assert hyperexpand(meijerg([[1, 1], []], [[1], [1]], z)) == \ + z/(z + 1) + assert hyperexpand(meijerg([[], []], [[S.Half], [0]], (z/2)**2)) \ + == sin(z)/sqrt(pi) + assert hyperexpand(meijerg([[], []], [[0], [S.Half]], (z/2)**2)) \ + == cos(z)/sqrt(pi) + assert can_do_meijer([], [a], [a - 1, a - S.Half], []) + assert can_do_meijer([], [], [a/2], [-a/2], False) # branches... + assert can_do_meijer([a], [b], [a], [b, a - 1]) + + # wikipedia + assert hyperexpand(meijerg([1], [], [], [0], z)) == \ + Piecewise((0, abs(z) < 1), (1, abs(1/z) < 1), + (meijerg([1], [], [], [0], z), True)) + assert hyperexpand(meijerg([], [1], [0], [], z)) == \ + Piecewise((1, abs(z) < 1), (0, abs(1/z) < 1), + (meijerg([], [1], [0], [], z), True)) + + # The Special Functions and their Approximations + assert can_do_meijer([], [], [a + b/2], [a, a - b/2, a + S.Half]) + assert can_do_meijer( + [], [], [a], [b], False) # branches only agree for small z + assert can_do_meijer([], [S.Half], [a], [-a]) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, a + S.Half], [b, b + S.Half]) + assert can_do_meijer([], [], [a, -a], [0, S.Half], False) # dito + assert can_do_meijer([], [], [a, a + S.Half, b, b + S.Half], []) + assert can_do_meijer([S.Half], [], [0], [a, -a]) + assert can_do_meijer([S.Half], [], [a], [0, -a], False) # dito + assert can_do_meijer([], [a - S.Half], [a, b], [a - S.Half], False) + assert can_do_meijer([], [a + S.Half], [a + b, a - b, a], [], False) + assert can_do_meijer([a + S.Half], [], [b, 2*a - b, a], [], False) + + # This for example is actually zero. + assert can_do_meijer([], [], [], [a, b]) + + # Testing a bug: + assert hyperexpand(meijerg([0, 2], [], [], [-1, 1], z)) == \ + Piecewise((0, abs(z) < 1), + (z*(1 - 1/z**2)/2, abs(1/z) < 1), + (meijerg([0, 2], [], [], [-1, 1], z), True)) + + # Test that the simplest possible answer is returned: + assert gammasimp(simplify(hyperexpand( + meijerg([1], [1 - a], [-a/2, -a/2 + S.Half], [], 1/z)))) == \ + -2*sqrt(pi)*(sqrt(z + 1) + 1)**a/a + + # Test that hyper is returned + assert hyperexpand(meijerg([1], [], [a], [0, 0], z)) == hyper( + (a,), (a + 1, a + 1), z*exp_polar(I*pi))*z**a*gamma(a)/gamma(a + 1)**2 + + # Test place option + f = meijerg(((0, 1), ()), ((S.Half,), (0,)), z**2) + assert hyperexpand(f) == sqrt(pi)/sqrt(1 + z**(-2)) + assert hyperexpand(f, place=0) == sqrt(pi)*z/sqrt(z**2 + 1) + + +def test_meijerg_lookup(): + from sympy.functions.special.error_functions import (Ci, Si) + from sympy.functions.special.gamma_functions import uppergamma + assert hyperexpand(meijerg([a], [], [b, a], [], z)) == \ + z**b*exp(z)*gamma(-a + b + 1)*uppergamma(a - b, z) + assert hyperexpand(meijerg([0], [], [0, 0], [], z)) == \ + exp(z)*uppergamma(0, z) + assert can_do_meijer([a], [], [b, a + 1], []) + assert can_do_meijer([a], [], [b + 2, a], []) + assert can_do_meijer([a], [], [b - 2, a], []) + + assert hyperexpand(meijerg([a], [], [a, a, a - S.Half], [], z)) == \ + -sqrt(pi)*z**(a - S.Half)*(2*cos(2*sqrt(z))*(Si(2*sqrt(z)) - pi/2) + - 2*sin(2*sqrt(z))*Ci(2*sqrt(z))) == \ + hyperexpand(meijerg([a], [], [a, a - S.Half, a], [], z)) == \ + hyperexpand(meijerg([a], [], [a - S.Half, a, a], [], z)) + assert can_do_meijer([a - 1], [], [a + 2, a - Rational(3, 2), a + 1], []) + + +@XFAIL +def test_meijerg_expand_fail(): + # These basically test hyper([], [1/2 - a, 1/2 + 1, 1/2], z), + # which is *very* messy. But since the meijer g actually yields a + # sum of bessel functions, things can sometimes be simplified a lot and + # are then put into tables... + assert can_do_meijer([], [], [a + S.Half], [a, a - b/2, a + b/2]) + assert can_do_meijer([], [], [0, S.Half], [a, -a]) + assert can_do_meijer([], [], [3*a - S.Half, a, -a - S.Half], [a - S.Half]) + assert can_do_meijer([], [], [0, a - S.Half, -a - S.Half], [S.Half]) + assert can_do_meijer([], [], [a, b + S.Half, b], [2*b - a]) + assert can_do_meijer([], [], [a, b + S.Half, b, 2*b - a]) + assert can_do_meijer([S.Half], [], [-a, a], [0]) + + +@slow +def test_meijerg(): + # carefully set up the parameters. + # NOTE: this used to fail sometimes. I believe it is fixed, but if you + # hit an inexplicable test failure here, please let me know the seed. + a1, a2 = (randcplx(n) - 5*I - n*I for n in range(2)) + b1, b2 = (randcplx(n) + 5*I + n*I for n in range(2)) + b3, b4, b5, a3, a4, a5 = (randcplx() for n in range(6)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert ReduceOrder.meijer_minus(3, 4) is None + assert ReduceOrder.meijer_plus(4, 3) is None + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2], z) + assert tn(ReduceOrder.meijer_plus(a2, a2).apply(g, op), g2, z) + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2 + 1], z) + assert tn(ReduceOrder.meijer_plus(a2, a2 + 1).apply(g, op), g2, z) + + g2 = meijerg([a1, a2 - 1], [a3, a4], [b1], [b3, b4, a2 + 2], z) + assert tn(ReduceOrder.meijer_plus(a2 - 1, a2 + 2).apply(g, op), g2, z) + + g2 = meijerg([a1], [a3, a4, b2 - 1], [b1, b2 + 2], [b3, b4], z) + assert tn(ReduceOrder.meijer_minus( + b2 + 2, b2 - 1).apply(g, op), g2, z, tol=1e-6) + + # test several-step reduction + an = [a1, a2] + bq = [b3, b4, a2 + 1] + ap = [a3, a4, b2 - 1] + bm = [b1, b2 + 1] + niq, ops = reduce_order_meijer(G_Function(an, ap, bm, bq)) + assert niq.an == (a1,) + assert set(niq.ap) == {a3, a4} + assert niq.bm == (b1,) + assert set(niq.bq) == {b3, b4} + assert tn(apply_operators(g, ops, op), meijerg(an, ap, bm, bq, z), z) + + +def test_meijerg_shift_operators(): + # carefully set up the parameters. XXX this still fails sometimes + a1, a2, a3, a4, a5, b1, b2, b3, b4, b5 = (randcplx(n) for n in range(10)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert tn(MeijerShiftA(b1).apply(g, op), + meijerg([a1], [a3, a4], [b1 + 1], [b3, b4], z), z) + assert tn(MeijerShiftB(a1).apply(g, op), + meijerg([a1 - 1], [a3, a4], [b1], [b3, b4], z), z) + assert tn(MeijerShiftC(b3).apply(g, op), + meijerg([a1], [a3, a4], [b1], [b3 + 1, b4], z), z) + assert tn(MeijerShiftD(a3).apply(g, op), + meijerg([a1], [a3 - 1, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftA([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1 - 1], [b3, b4], z), z) + + s = MeijerUnShiftC([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1], [b3 - 1, b4], z), z) + + s = MeijerUnShiftB([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1 + 1], [a3, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftD([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3 + 1, a4], [b1], [b3, b4], z), z) + + +@slow +def test_meijerg_confluence(): + def t(m, a, b): + from sympy.core.sympify import sympify + a, b = sympify([a, b]) + m_ = m + m = hyperexpand(m) + if not m == Piecewise((a, abs(z) < 1), (b, abs(1/z) < 1), (m_, True)): + return False + if not (m.args[0].args[0] == a and m.args[1].args[0] == b): + return False + z0 = randcplx()/10 + if abs(m.subs(z, z0).n() - a.subs(z, z0).n()).n() > 1e-10: + return False + if abs(m.subs(z, 1/z0).n() - b.subs(z, 1/z0).n()).n() > 1e-10: + return False + return True + + assert t(meijerg([], [1, 1], [0, 0], [], z), -log(z), 0) + assert t(meijerg( + [], [3, 1], [0, 0], [], z), -z**2/4 + z - log(z)/2 - Rational(3, 4), 0) + assert t(meijerg([], [3, 1], [-1, 0], [], z), + z**2/12 - z/2 + log(z)/2 + Rational(1, 4) + 1/(6*z), 0) + assert t(meijerg([], [1, 1, 1, 1], [0, 0, 0, 0], [], z), -log(z)**3/6, 0) + assert t(meijerg([1, 1], [], [], [0, 0], z), 0, -log(1/z)) + assert t(meijerg([1, 1], [2, 2], [1, 1], [0, 0], z), + -z*log(z) + 2*z, -log(1/z) + 2) + assert t(meijerg([S.Half], [1, 1], [0, 0], [Rational(3, 2)], z), log(z)/2 - 1, 0) + + def u(an, ap, bm, bq): + m = meijerg(an, ap, bm, bq, z) + m2 = hyperexpand(m, allow_hyper=True) + if m2.has(meijerg) and not (m2.is_Piecewise and len(m2.args) == 3): + return False + return tn(m, m2, z) + assert u([], [1], [0, 0], []) + assert u([1, 1], [], [], [0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0, 0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0]) + + +def test_meijerg_with_Floats(): + # see issue #10681 + from sympy.polys.domains.realfield import RR + f = meijerg(((3.0, 1), ()), ((Rational(3, 2),), (0,)), z) + a = -2.3632718012073 + g = a*z**Rational(3, 2)*hyper((-0.5, Rational(3, 2)), (Rational(5, 2),), z*exp_polar(I*pi)) + assert RR.almosteq((hyperexpand(f)/g).n(), 1.0, 1e-12) + + +def test_lerchphi(): + from sympy.functions.special.zeta_functions import (lerchphi, polylog) + from sympy.simplify.gammasimp import gammasimp + assert hyperexpand(hyper([1, a], [a + 1], z)/a) == lerchphi(z, 1, a) + assert hyperexpand( + hyper([1, a, a], [a + 1, a + 1], z)/a**2) == lerchphi(z, 2, a) + assert hyperexpand(hyper([1, a, a, a], [a + 1, a + 1, a + 1], z)/a**3) == \ + lerchphi(z, 3, a) + assert hyperexpand(hyper([1] + [a]*10, [a + 1]*10, z)/a**10) == \ + lerchphi(z, 10, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a], [], [0], + [-a], exp_polar(-I*pi)*z))) == lerchphi(z, 1, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a], [], [0], + [-a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 2, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a, 1 - a], [], [0], + [-a, -a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 3, a) + + assert hyperexpand(z*hyper([1, 1], [2], z)) == -log(1 + -z) + assert hyperexpand(z*hyper([1, 1, 1], [2, 2], z)) == polylog(2, z) + assert hyperexpand(z*hyper([1, 1, 1, 1], [2, 2, 2], z)) == polylog(3, z) + + assert hyperexpand(hyper([1, a, 1 + S.Half], [a + 1, S.Half], z)) == \ + -2*a/(z - 1) + (-2*a**2 + a)*lerchphi(z, 1, a) + + # Now numerical tests. These make sure reductions etc are carried out + # correctly + + # a rational function (polylog at negative integer order) + assert can_do([2, 2, 2], [1, 1]) + + # NOTE these contain log(1-x) etc ... better make sure we have |z| < 1 + # reduction of order for polylog + assert can_do([1, 1, 1, b + 5], [2, 2, b], div=10) + + # reduction of order for lerchphi + # XXX lerchphi in mpmath is flaky + assert can_do( + [1, a, a, a, b + 5], [a + 1, a + 1, a + 1, b], numerical=False) + + # test a bug + from sympy.functions.elementary.complexes import Abs + assert hyperexpand(hyper([S.Half, S.Half, S.Half, 1], + [Rational(3, 2), Rational(3, 2), Rational(3, 2)], Rational(1, 4))) == \ + Abs(-polylog(3, exp_polar(I*pi)/2) + polylog(3, S.Half)) + + +def test_partial_simp(): + # First test that hypergeometric function formulae work. + a, b, c, d, e = (randcplx() for _ in range(5)) + for func in [Hyper_Function([a, b, c], [d, e]), + Hyper_Function([], [a, b, c, d, e])]: + f = build_hypergeometric_formula(func) + z = f.z + assert f.closed_form == func(z) + deriv1 = f.B.diff(z)*z + deriv2 = f.M*f.B + for func1, func2 in zip(deriv1, deriv2): + assert tn(func1, func2, z) + + # Now test that formulae are partially simplified. + a, b, z = symbols('a b z') + assert hyperexpand(hyper([3, a], [1, b], z)) == \ + (-a*b/2 + a*z/2 + 2*a)*hyper([a + 1], [b], z) \ + + (a*b/2 - 2*a + 1)*hyper([a], [b], z) + assert tn( + hyperexpand(hyper([3, d], [1, e], z)), hyper([3, d], [1, e], z), z) + assert hyperexpand(hyper([3], [1, a, b], z)) == \ + hyper((), (a, b), z) \ + + z*hyper((), (a + 1, b), z)/(2*a) \ + - z*(b - 4)*hyper((), (a + 1, b + 1), z)/(2*a*b) + assert tn( + hyperexpand(hyper([3], [1, d, e], z)), hyper([3], [1, d, e], z), z) + + +def test_hyperexpand_special(): + assert hyperexpand(hyper([a, b], [c], 1)) == \ + gamma(c)*gamma(c - a - b)/gamma(c - a)/gamma(c - b) + assert hyperexpand(hyper([a, b], [1 + a - b], -1)) == \ + gamma(1 + a/2)*gamma(1 + a - b)/gamma(1 + a)/gamma(1 + a/2 - b) + assert hyperexpand(hyper([a, b], [1 + b - a], -1)) == \ + gamma(1 + b/2)*gamma(1 + b - a)/gamma(1 + b)/gamma(1 + b/2 - a) + assert hyperexpand(meijerg([1 - z - a/2], [1 - z + a/2], [b/2], [-b/2], 1)) == \ + gamma(1 - 2*z)*gamma(z + a/2 + b/2)/gamma(1 - z + a/2 - b/2) \ + /gamma(1 - z - a/2 + b/2)/gamma(1 - z + a/2 + b/2) + assert hyperexpand(hyper([a], [b], 0)) == 1 + assert hyper([a], [b], 0) != 0 + + +def test_Mod1_behavior(): + from sympy.core.symbol import Symbol + from sympy.simplify.simplify import simplify + n = Symbol('n', integer=True) + # Note: this should not hang. + assert simplify(hyperexpand(meijerg([1], [], [n + 1], [0], z))) == \ + lowergamma(n + 1, z) + + +@slow +def test_prudnikov_misc(): + assert can_do([1, (3 + I)/2, (3 - I)/2], [Rational(3, 2), 2]) + assert can_do([S.Half, a - 1], [Rational(3, 2), a + 1], lowerplane=True) + assert can_do([], [b + 1]) + assert can_do([a], [a - 1, b + 1]) + + assert can_do([a], [a - S.Half, 2*a]) + assert can_do([a], [a - S.Half, 2*a + 1]) + assert can_do([a], [a - S.Half, 2*a - 1]) + assert can_do([a], [a + S.Half, 2*a]) + assert can_do([a], [a + S.Half, 2*a + 1]) + assert can_do([a], [a + S.Half, 2*a - 1]) + assert can_do([S.Half], [b, 2 - b]) + assert can_do([S.Half], [b, 3 - b]) + assert can_do([1], [2, b]) + + assert can_do([a, a + S.Half], [2*a, b, 2*a - b + 1]) + assert can_do([a, a + S.Half], [S.Half, 2*a, 2*a + S.Half]) + assert can_do([a], [a + 1], lowerplane=True) # lowergamma + + +def test_prudnikov_1(): + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + + # 7.3.1 + assert can_do([a, -a], [S.Half]) + assert can_do([a, 1 - a], [S.Half]) + assert can_do([a, 1 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [S.Half]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, a + S.Half], [2*a - 1]) + assert can_do([a, a + S.Half], [2*a]) + assert can_do([a, a + S.Half], [2*a + 1]) + assert can_do([a, a + S.Half], [S.Half]) + assert can_do([a, a + S.Half], [Rational(3, 2)]) + assert can_do([a, a/2 + 1], [a/2]) + assert can_do([1, b], [2]) + assert can_do([1, b], [b + 1], numerical=False) # Lerch Phi + # NOTE: branches are complicated for |z| > 1 + + assert can_do([a], [2*a]) + assert can_do([a], [2*a + 1]) + assert can_do([a], [2*a - 1]) + + +@slow +def test_prudnikov_2(): + h = S.Half + assert can_do([-h, -h], [h]) + assert can_do([-h, h], [3*h]) + assert can_do([-h, h], [5*h]) + assert can_do([-h, h], [7*h]) + assert can_do([-h, 1], [h]) + + for p in [-h, h]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [-h, h, 3*h, 5*h, 7*h]: + assert can_do([p, n], [m]) + for n in [1, 2, 3, 4]: + for m in [1, 2, 3, 4]: + assert can_do([p, n], [m]) + + +def test_prudnikov_3(): + h = S.Half + assert can_do([Rational(1, 4), Rational(3, 4)], [h]) + assert can_do([Rational(1, 4), Rational(3, 4)], [3*h]) + assert can_do([Rational(1, 3), Rational(2, 3)], [3*h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [3*h]) + + +@tooslow +def test_prudnikov_3_slow(): + # XXX: This is marked as tooslow and hence skipped in CI. None of the + # individual cases below fails or hangs. Some cases are slow and the loops + # below generate 280 different cases. Is it really necessary to test all + # 280 cases here? + h = S.Half + for p in [1, 2, 3, 4]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4, 9*h]: + for m in [1, 3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_4(): + h = S.Half + for p in [3*h, 5*h, 7*h]: + for n in [-h, h, 3*h, 5*h, 7*h]: + for m in [3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + for n in [1, 2, 3, 4]: + for m in [2, 3, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_5(): + h = S.Half + + for p in [1, 2, 3]: + for q in range(p, 4): + for r in [1, 2, 3]: + for s in range(r, 4): + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [h, 3*h, 5*h]: + for r in [h, 3*h, 5*h]: + for s in [h, 3*h, 5*h]: + if s <= q and s <= r: + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [1, 2, 3]: + for r in [h, 3*h, 5*h]: + for s in [1, 2, 3]: + assert can_do([-h, p, q], [r, s]) + + +@slow +def test_prudnikov_6(): + h = S.Half + + for m in [3*h, 5*h]: + for n in [1, 2, 3]: + for q in [h, 1, 2]: + for p in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + for q in [1, 2, 3]: + for p in [3*h, 5*h]: + assert can_do([h, q, p], [m, n]) + + for q in [1, 2]: + for p in [1, 2, 3]: + for m in [1, 2, 3]: + for n in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + + assert can_do([h, h, 5*h], [3*h, 3*h]) + assert can_do([h, 1, 5*h], [3*h, 3*h]) + assert can_do([h, 2, 2], [1, 3]) + + # pages 435 to 457 contain more PFDD and stuff like this + + +@slow +def test_prudnikov_7(): + assert can_do([3], [6]) + + h = S.Half + for n in [h, 3*h, 5*h, 7*h]: + assert can_do([-h], [n]) + for m in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: # HERE + for n in [-h, h, 3*h, 5*h, 7*h, 1, 2, 3, 4]: + assert can_do([m], [n]) + + +@slow +def test_prudnikov_8(): + h = S.Half + + # 7.12.2 + for ai in [1, 2, 3]: + for bi in [1, 2, 3]: + for ci in range(1, ai + 1): + for di in [h, 1, 3*h, 2, 5*h, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [3*h, 5*h]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + + for ai in [-h, h, 3*h, 5*h]: + for bi in [1, 2, 3]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [h, 3*h, 5*h]: + for ci in [h, 3*h, 5*h, 3]: + for di in [h, 1, 3*h, 2, 5*h, 3]: + if ci <= bi: + assert can_do([ai, bi], [ci, di]) + + +def test_prudnikov_9(): + # 7.13.1 [we have a general formula ... so this is a bit pointless] + for i in range(9): + assert can_do([], [(S(i) + 1)/2]) + for i in range(5): + assert can_do([], [-(2*S(i) + 1)/2]) + + +@slow +def test_prudnikov_10(): + # 7.14.2 + h = S.Half + for p in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [1, 2, 3, 4]: + for n in range(m, 5): + assert can_do([p], [m, n]) + + for p in [1, 2, 3, 4]: + for n in [h, 3*h, 5*h, 7*h]: + for m in [1, 2, 3, 4]: + assert can_do([p], [n, m]) + + for p in [3*h, 5*h, 7*h]: + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([p], [h, m]) + assert can_do([p], [3*h, m]) + + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([7*h], [5*h, m]) + + assert can_do([Rational(-1, 2)], [S.Half, S.Half]) # shine-integral shi + + +def test_prudnikov_11(): + # 7.15 + assert can_do([a, a + S.Half], [2*a, b, 2*a - b]) + assert can_do([a, a + S.Half], [Rational(3, 2), 2*a, 2*a - S.Half]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [S.Half, S.Half, 1]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), S.Half, 2]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), Rational(3, 2), 1]) + assert can_do([Rational(5, 4), Rational(7, 4)], [Rational(3, 2), Rational(5, 2), 2]) + + assert can_do([1, 1], [Rational(3, 2), 2, 2]) # cosh-integral chi + + +def test_prudnikov_12(): + # 7.16 + assert can_do( + [], [a, a + S.Half, 2*a], False) # branches only agree for some z! + assert can_do([], [a, a + S.Half, 2*a + 1], False) # dito + assert can_do([], [S.Half, a, a + S.Half]) + assert can_do([], [Rational(3, 2), a, a + S.Half]) + + assert can_do([], [Rational(1, 4), S.Half, Rational(3, 4)]) + assert can_do([], [S.Half, S.Half, 1]) + assert can_do([], [S.Half, Rational(3, 2), 1]) + assert can_do([], [Rational(3, 4), Rational(3, 2), Rational(5, 4)]) + assert can_do([], [1, 1, Rational(3, 2)]) + assert can_do([], [1, 2, Rational(3, 2)]) + assert can_do([], [1, Rational(3, 2), Rational(3, 2)]) + assert can_do([], [Rational(5, 4), Rational(3, 2), Rational(7, 4)]) + assert can_do([], [2, Rational(3, 2), Rational(3, 2)]) + + +@slow +def test_prudnikov_2F1(): + h = S.Half + # Elliptic integrals + for p in [-h, h]: + for m in [h, 3*h, 5*h, 7*h]: + for n in [1, 2, 3, 4]: + assert can_do([p, m], [n]) + + +@XFAIL +def test_prudnikov_fail_2F1(): + assert can_do([a, b], [b + 1]) # incomplete beta function + assert can_do([-1, b], [c]) # Poly. also -2, -3 etc + + # TODO polys + + # Legendre functions: + assert can_do([a, b], [a + b + S.Half]) + assert can_do([a, b], [a + b - S.Half]) + assert can_do([a, b], [a + b + Rational(3, 2)]) + assert can_do([a, b], [(a + b + 1)/2]) + assert can_do([a, b], [(a + b)/2 + 1]) + assert can_do([a, b], [a - b + 1]) + assert can_do([a, b], [a - b + 2]) + assert can_do([a, b], [2*b]) + assert can_do([a, b], [S.Half]) + assert can_do([a, b], [Rational(3, 2)]) + assert can_do([a, 1 - a], [c]) + assert can_do([a, 2 - a], [c]) + assert can_do([a, 3 - a], [c]) + assert can_do([a, a + S.Half], [c]) + assert can_do([1, b], [c]) + assert can_do([1, b], [Rational(3, 2)]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [1]) + + # PFDD + o = S.One + assert can_do([o/8, 1], [o/8*9]) + assert can_do([o/6, 1], [o/6*7]) + assert can_do([o/6, 1], [o/6*13]) + assert can_do([o/5, 1], [o/5*6]) + assert can_do([o/5, 1], [o/5*11]) + assert can_do([o/4, 1], [o/4*5]) + assert can_do([o/4, 1], [o/4*9]) + assert can_do([o/3, 1], [o/3*4]) + assert can_do([o/3, 1], [o/3*7]) + assert can_do([o/8*3, 1], [o/8*11]) + assert can_do([o/5*2, 1], [o/5*7]) + assert can_do([o/5*2, 1], [o/5*12]) + assert can_do([o/5*3, 1], [o/5*8]) + assert can_do([o/5*3, 1], [o/5*13]) + assert can_do([o/8*5, 1], [o/8*13]) + assert can_do([o/4*3, 1], [o/4*7]) + assert can_do([o/4*3, 1], [o/4*11]) + assert can_do([o/3*2, 1], [o/3*5]) + assert can_do([o/3*2, 1], [o/3*8]) + assert can_do([o/5*4, 1], [o/5*9]) + assert can_do([o/5*4, 1], [o/5*14]) + assert can_do([o/6*5, 1], [o/6*11]) + assert can_do([o/6*5, 1], [o/6*17]) + assert can_do([o/8*7, 1], [o/8*15]) + + +@XFAIL +def test_prudnikov_fail_3F2(): + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(1, 3), Rational(2, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(2, 3), Rational(4, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(4, 3), Rational(5, 3)]) + + # page 421 + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [a*Rational(3, 2), (3*a + 1)/2]) + + # pages 422 ... + assert can_do([Rational(-1, 2), S.Half, S.Half], [1, 1]) # elliptic integrals + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(3, 2)]) + # TODO LOTS more + + # PFDD + assert can_do([Rational(1, 8), Rational(3, 8), 1], [Rational(9, 8), Rational(11, 8)]) + assert can_do([Rational(1, 8), Rational(5, 8), 1], [Rational(9, 8), Rational(13, 8)]) + assert can_do([Rational(1, 8), Rational(7, 8), 1], [Rational(9, 8), Rational(15, 8)]) + assert can_do([Rational(1, 6), Rational(1, 3), 1], [Rational(7, 6), Rational(4, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(7, 6), Rational(5, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(5, 3), Rational(13, 6)]) + assert can_do([S.Half, 1, 1], [Rational(1, 4), Rational(3, 4)]) + # LOTS more + + +@XFAIL +def test_prudnikov_fail_other(): + # 7.11.2 + + # 7.12.1 + assert can_do([1, a], [b, 1 - 2*a + b]) # ??? + + # 7.14.2 + assert can_do([Rational(-1, 2)], [S.Half, 1]) # struve + assert can_do([1], [S.Half, S.Half]) # struve + assert can_do([Rational(1, 4)], [S.Half, Rational(5, 4)]) # PFDD + assert can_do([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)]) # PFDD + assert can_do([1], [Rational(1, 4), Rational(3, 4)]) # PFDD + assert can_do([1], [Rational(3, 4), Rational(5, 4)]) # PFDD + assert can_do([1], [Rational(5, 4), Rational(7, 4)]) # PFDD + # TODO LOTS more + + # 7.15.2 + assert can_do([S.Half, 1], [Rational(3, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + assert can_do([S.Half, 1], [Rational(7, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + + # 7.16.1 + assert can_do([], [Rational(1, 3), S(2/3)]) # PFDD + assert can_do([], [Rational(2, 3), S(4/3)]) # PFDD + assert can_do([], [Rational(5, 3), S(4/3)]) # PFDD + + # XXX this does not *evaluate* right?? + assert can_do([], [a, a + S.Half, 2*a - 1]) + + +def test_bug(): + h = hyper([-1, 1], [z], -1) + assert hyperexpand(h) == (z + 1)/z + + +def test_omgissue_203(): + h = hyper((-5, -3, -4), (-6, -6), 1) + assert hyperexpand(h) == Rational(1, 30) + h = hyper((-6, -7, -5), (-6, -6), 1) + assert hyperexpand(h) == Rational(-1, 6) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_powsimp.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..61bdc93d052baf4b1e80da8f5864cf22b8fa383e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_powsimp.py @@ -0,0 +1,368 @@ +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.simplify.powsimp import (powdenest, powsimp) +from sympy.simplify.simplify import (signsimp, simplify) +from sympy.core.symbol import Str + +from sympy.abc import x, y, z, a, b + + +def test_powsimp(): + x, y, z, n = symbols('x,y,z,n') + f = Function('f') + assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1 + assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1 + + assert powsimp( + f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x)) + assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1) + assert exp(x)*exp(y) == exp(x)*exp(y) + assert powsimp(exp(x)*exp(y)) == exp(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \ + exp(x + y)*2**(x + y) + assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \ + exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y) + assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y)) + assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y)) + assert powsimp(x**2*x**y) == x**(2 + y) + # This should remain factored, because 'exp' with deep=True is supposed + # to act like old automatic exponent combining. + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \ + (1 + E*exp(E))*exp(-E) + x, y = symbols('x,y', nonnegative=True) + n = Symbol('n', real=True) + assert powsimp(y**n * (y/x)**(-n)) == x**n + assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \ + == (x*y)**(x*y)**(x*y) + assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x) + assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z) + assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z + assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \ + exp(x)/(1 + exp(x + y)) + assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y)) + assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x + assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x + p = symbols('p', positive=True) + assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2)) + assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2)) + + # coefficient of exponent can only be simplified for positive bases + assert powsimp(2**(2*x)) == 4**x + assert powsimp((-1)**(2*x)) == (-1)**(2*x) + i = symbols('i', integer=True) + assert powsimp((-1)**(2*i)) == 1 + assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not + # force=True overrides assumptions + assert powsimp((-1)**(2*x), force=True) == 1 + + # rational exponents allow combining of negative terms + w, n, m = symbols('w n m', negative=True) + e = i/a # not a rational exponent if `a` is unknown + ex = w**e*n**e*m**e + assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a) + e = i/3 + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3) + e = (3 + i)/i + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e + + eq = x**(a*Rational(2, 3)) + # eq != (x**a)**(2/3) (try x = -1 and a = 3 to see) + assert powsimp(eq).exp == eq.exp == a*Rational(2, 3) + # powdenest goes the other direction + assert powsimp(2**(2*x)) == 4**x + + assert powsimp(exp(p/2)) == exp(p/2) + + # issue 6368 + eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)]) + assert powsimp(eq) == eq and eq.is_Mul + + assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2))) + + # issue 8836 + assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)' + + # issue 9183 + assert powsimp(-0.1**x) == -0.1**x + + # issue 10095 + assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo + + # PR 13131 + eq = sin(2*x)**2*sin(2.0*x)**2 + assert powsimp(eq) == eq + + # issue 14615 + assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2) + ) == x*y*(x*y**2)**Rational(5, 2) + + #issue 27380 + assert powsimp(1.0**(x+1)/1.0**x) == 1.0 + +def test_powsimp_negated_base(): + assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y) + assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y) + p = symbols('p', positive=True) + reps = {p: 2, a: S.Half} + assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps) + assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps) + n = symbols('n', negative=True) + reps = {p: -2, a: S.Half} + assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half) + assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps) + # if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a + eq = (-x)**a/x**a + assert powsimp(eq) == eq + + +def test_powsimp_nc(): + x, y, z = symbols('x,y,z') + A, B, C = symbols('A B C', commutative=False) + + assert powsimp(A**x*A**y, combine='all') == A**(x + y) + assert powsimp(A**x*A**y, combine='base') == A**x*A**y + assert powsimp(A**x*A**y, combine='exp') == A**(x + y) + + assert powsimp(A**x*B**x, combine='all') == A**x*B**x + assert powsimp(A**x*B**x, combine='base') == A**x*B**x + assert powsimp(A**x*B**x, combine='exp') == A**x*B**x + + assert powsimp(B**x*A**x, combine='all') == B**x*A**x + assert powsimp(B**x*A**x, combine='base') == B**x*A**x + assert powsimp(B**x*A**x, combine='exp') == B**x*A**x + + assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z) + assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z + assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z) + + assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x + + assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x + + +def test_issue_6440(): + assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4) + + +def test_powdenest(): + x, y = symbols('x,y') + p, q = symbols('p q', positive=True) + i, j = symbols('i,j', integer=True) + + assert powdenest(x) == x + assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x)) + assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x) + assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x)) + assert powdenest(exp(3*x*log(2))) == 2**(3*x) + assert powdenest(sqrt(p**2)) == p + eq = p**(2*i)*q**(4*i) + assert powdenest(eq) == (p*q**2)**(2*i) + # -X-> (x**x)**i*(x**x)**j == x**(x*(i + j)) + assert powdenest((x**x)**(i + j)) + assert powdenest(exp(3*y*log(x))) == x**(3*y) + assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y + assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3 + assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x + assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y) + assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \ + (((x**(a*Rational(2, 3)))**(3*y/i))**x) + assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z) + assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j) + e = ((p**(2*a))**(3*y))**x + assert powdenest(e) == e + e = ((x**2*y**4)**a)**(x*y) + assert powdenest(e) == e + e = (((x**2*y**4)**a)**(x*y))**3 + assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \ + (x*y**2)**(2*a*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \ + (x*y**2)**(6*a*x*y) + assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i) + x, y = symbols('x,y', positive=True) + assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i) + + assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2) + assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i + + assert powdenest(4**x) == 2**(2*x) + assert powdenest((4**x)**y) == 2**(2*x*y) + assert powdenest(4**x*y) == 2**(2*x)*y + + +def test_powdenest_polar(): + x, y, z = symbols('x y z', polar=True) + a, b, c = symbols('a b c') + assert powdenest((x*y*z)**a) == x**a*y**a*z**a + assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c) + assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2) + + +def test_issue_5805(): + arg = ((gamma(x)*hyper((), (), x))*pi)**2 + assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2 + assert arg.is_positive is None + + +def test_issue_9324_powsimp_on_matrix_symbol(): + M = MatrixSymbol('M', 10, 10) + expr = powsimp(M, deep=True) + assert expr == M + assert expr.args[0] == Str('M') + + +def test_issue_6367(): + z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half) + assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0 + assert powsimp(z.normal()) == 0 + assert simplify(z) == 0 + assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2 + assert powsimp(z) != 0 + + +def test_powsimp_polar(): + from sympy.functions.elementary.complexes import polar_lift + from sympy.functions.elementary.exponential import exp_polar + x, y, z = symbols('x y z') + p, q, r = symbols('p q r', polar=True) + + assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x) + assert powsimp(p**x * q**x) == (p*q)**x + assert p**x * (1/p)**x == 1 + assert (1/p)**x == p**(-x) + + assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y) + assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \ + (p*exp_polar(1))**(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \ + exp_polar(x + y)*p**(x + y) + assert powsimp( + exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \ + == p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y) + assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \ + sin(exp_polar(x)*exp_polar(y)) + assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \ + sin(exp_polar(x + y)) + + +def test_issue_5728(): + b = x*sqrt(y) + a = sqrt(b) + c = sqrt(sqrt(x)*y) + assert powsimp(a*b) == sqrt(b)**3 + assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5 + assert powsimp(a*x**2*c**3*y) == c**3*a**5 + assert powsimp(a*x*c**3*y**2) == c**7*a + assert powsimp(x*c**3*y**2) == c**7 + assert powsimp(x*c**3*y) == x*y*c**3 + assert powsimp(sqrt(x)*c**3*y) == c**5 + assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3 + assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \ + sqrt(x)*sqrt(y)**3*c**3 + assert powsimp(a**2*a*x**2*y) == a**7 + + # symbolic powers work, too + b = x**y*y + a = b*sqrt(b) + assert a.is_Mul is True + assert powsimp(a) == sqrt(b)**3 + + # as does exp + a = x*exp(y*Rational(2, 3)) + assert powsimp(a*sqrt(a)) == sqrt(a)**3 + assert powsimp(a**2*sqrt(a)) == sqrt(a)**5 + assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) == + -I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3)) + assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) == + -(-1)**Rational(1, 3)* + (-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3)) + + +def test_issue_10195(): + a = Symbol('a', integer=True) + l = Symbol('l', even=True, nonzero=True) + n = Symbol('n', odd=True) + e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half) + assert powsimp((-1)**(l/2)) == I**l + assert powsimp((-1)**(n/2)) == I**n + assert powsimp((-1)**(n*Rational(3, 2))) == -I**n + assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) + + S.Half) + assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a + +def test_issue_15709(): + assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1) + assert powsimp(2*3**x/3) == 2*3**(x-1) + + +def test_issue_11981(): + x, y = symbols('x y', commutative=False) + assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2 + + +def test_issue_17524(): + a = symbols("a", real=True) + e = (-1 - a**2)*sqrt(1 + a**2) + assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2) + + +def test_issue_19627(): + # if you use force the user must verify + assert powdenest(sqrt(sin(x)**2), force=True) == sin(x) + assert powdenest((x**(S.Half/y))**(2*y), force=True) == x + from sympy.core.function import expand_power_base + e = 1 - a + expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e + assert powdenest(expand_power_base(expr, force=True), force=True + ) == x**b*y**(1 - b)*exp(z) + + +def test_issue_22546(): + p1, p2 = symbols('p1, p2', positive=True) + ref = powsimp(p1**z/p2**z) + e = z + 1 + ans = ref.subs(z, e) + assert ans.is_Pow + assert powsimp(p1**e/p2**e) == ans + i = symbols('i', integer=True) + ref = powsimp(x**i/y**i) + e = i + 1 + ans = ref.subs(i, e) + assert ans.is_Pow + assert powsimp(x**e/y**e) == ans diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_radsimp.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ff955e48a34536c1752c565c0864dedae6a214 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_radsimp.py @@ -0,0 +1,498 @@ +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.series.order import O +from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect) + +from sympy.core.expr import unchanged +from sympy.core.mul import _unevaluated_Mul as umul +from sympy.simplify.radsimp import (_unevaluated_Add, + collect_sqrt, fraction_expand, collect_abs) +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z, a, b, c, d + + +def test_radsimp(): + r2 = sqrt(2) + r3 = sqrt(3) + r5 = sqrt(5) + r7 = sqrt(7) + assert fraction(radsimp(1/r2)) == (sqrt(2), 2) + assert radsimp(1/(1 + r2)) == \ + -1 + sqrt(2) + assert radsimp(1/(r2 + r3)) == \ + -sqrt(2) + sqrt(3) + assert fraction(radsimp(1/(1 + r2 + r3))) == \ + (-sqrt(6) + sqrt(2) + 2, 4) + assert fraction(radsimp(1/(r2 + r3 + r5))) == \ + (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12) + assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == ( + (-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) + + 93 + 46*sqrt(6) + 53*sqrt(5), 71)) + assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == ( + (-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105) + + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215)) + z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7)) + assert len((3616791619821680643598*z).args) == 16 + assert radsimp(1/z) == 1/z + assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7 + assert radsimp(1/(r2*3)) == \ + sqrt(2)/6 + assert radsimp(1/(r2*a + r3 + r5 + r7)) == ( + (8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 - + 180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5 + - 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 + + 116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 - + 8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 - + 302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 - + 795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a - + 118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 - + 480*a**6 + 3128*a**4 - 6360*a**2 + 3481)) + assert radsimp(1/(r2*a + r2*b + r3 + r7)) == ( + (sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a + + b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a + + b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 - + 20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8)) + assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \ + sqrt(2)/(2*a + 2*b + 2*c + 2*d) + assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == ( + (sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b + + 4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1)) + assert radsimp((y**2 - x)/(y - sqrt(x))) == \ + sqrt(x) + y + assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \ + -(sqrt(x) + y) + assert radsimp(1/(1 - I + a*I)) == \ + (-I*a + 1 + I)/(a**2 - 2*a + 2) + assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \ + (-x - sqrt(y))/((x - y)*(x**2 - y)) + e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y)) + assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y)) + assert radsimp(1/e) == ( + (-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 - + 9*y))) + assert radsimp(1 + 1/(1 + sqrt(3))) == \ + Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1 + A = symbols("A", commutative=False) + assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \ + x**2 + sqrt(2)*x**2 - sqrt(2)*x*A + assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3) + assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3 + + # issue 6532 + assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x) + assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3) + assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6) + + # issue 5994 + e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/' + '(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))') + assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2) + + # issue 5986 (modifications to radimp didn't initially recognize this so + # the test is included here) + assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1 + + # from issue 5934 + eq = ( + (-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) - + 360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) - + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) + + 120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) + + 120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 - + 7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) + + 24*sqrt(10)*sqrt(-sqrt(5) + 5))**2)) + assert radsimp(eq) is S.NaN # it's 0/0 + + # work with normal form + e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3 + assert radsimp(e) == ( + -sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) + + 35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15) + - 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) + + 8291415*sqrt(21))/1300423175 + 3) + + # obey power rules + base = sqrt(3) - sqrt(2) + assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3 + assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3 + assert radsimp(1/(-base)**x) == (-base)**(-x) + assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x + assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x) + + # recurse + e = cos(1/(1 + sqrt(2))) + assert radsimp(e) == cos(-sqrt(2) + 1) + assert radsimp(e/2) == cos(-sqrt(2) + 1)/2 + assert radsimp(1/e) == 1/cos(-sqrt(2) + 1) + assert radsimp(2/e) == 2/cos(-sqrt(2) + 1) + assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x) + + # test that symbolic denominators are not processed + r = 1 + sqrt(2) + assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1) + assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2)) + assert radsimp(x/(y + r)/r, symbolic=False) == \ + -x*(-sqrt(2) + 1)/(y + 1 + sqrt(2)) + + # issue 7408 + eq = sqrt(x)/sqrt(y) + assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y) + assert radsimp(eq, symbolic=False) == eq + + # issue 7498 + assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3) + + # for coverage + eq = sqrt(x)/y**2 + assert radsimp(eq) == eq + + # handle non-Expr args + from sympy.integrals.integrals import Integral + eq = Integral(x/(sqrt(2) - 1), (x, 0, 1/(sqrt(2) + 1))) + assert radsimp(eq) == Integral((sqrt(2) + 1)*x , (x, 0, sqrt(2) - 1)) + + from sympy.sets import FiniteSet + eq = FiniteSet(x/(sqrt(2) - 1)) + assert radsimp(eq) == FiniteSet((sqrt(2) + 1)*x) + +def test_radsimp_issue_3214(): + c, p = symbols('c p', positive=True) + s = sqrt(c**2 - p**2) + b = (c + I*p - s)/(c + I*p + s) + assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p) + + +def test_collect_1(): + """Collect with respect to Symbol""" + x, y, z, n = symbols('x,y,z,n') + assert collect(1, x) == 1 + assert collect( x + y*x, x ) == x * (1 + y) + assert collect( x + x**2, x ) == x + x**2 + assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y) + assert collect( x**2 + y*x, x ) == x*y + x**2 + assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y + assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x) + + assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \ + x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \ + x**3*(4*(1 + y)).expand() + x**4 + # symbols can be given as any iterable + expr = x + y + assert collect(expr, expr.free_symbols) == expr + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None + ) == x*exp(x) + 3*x + (y + 2)*sin(x) + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x + + y*x*exp(x), x, exact=None + ) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x) + + +def test_collect_2(): + """Collect with respect to a sum""" + a, b, x = symbols('a,b,x') + assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)), + sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x)) + + +def test_collect_3(): + """Collect with respect to a product""" + a, b, c = symbols('a,b,c') + f = Function('f') + x, y, z, n = symbols('x,y,z,n') + + assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8)) + + assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2) + assert collect( x*y + a*x*y, x*y) == x*y*(1 + a) + assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a) + assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x) + + assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x) + assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \ + x**2*log(x)**2*(a + b) + + # with respect to a product of three symbols + assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z + + +def test_collect_4(): + """Collect with respect to a power""" + a, b, c, x = symbols('a,b,c,x') + + assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b) + # issue 6096: 2 stays with c (unless c is integer or x is positive0 + assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b) + + +def test_collect_5(): + """Collect with respect to a tuple""" + a, x, y, z, n = symbols('a,x,y,z,n') + assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [ + z*(1 + a + x**2*y**4) + x**2*y**4, + z*(1 + a) + x**2*y**4*(1 + z) ] + assert collect((1 + (x + y) + (x + y)**2).expand(), + [x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2 + + +def test_collect_pr19431(): + """Unevaluated collect with respect to a product""" + a = symbols('a') + assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1) + + +def test_collect_D(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fx = D(f(x), x) + fxx = D(f(x), x, x) + + assert collect(a*fx + b*fx, fx) == (a + b)*fx + assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x) + assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x) + # issue 4784 + assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \ + (1/f(x) + x/f(x))*D(f(x), x) + 1/f(x) + e = (1 + x*fx + fx)/f(x) + assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x) + + +def test_collect_func(): + f = ((x + a + 1)**3).expand() + + assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \ + x*(3*a**2 + 6*a + 3) + 1 + assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \ + (a + 1)**3 + + assert collect(f, x, evaluate=False) == { + S.One: a**3 + 3*a**2 + 3*a + 1, + x: 3*a**2 + 6*a + 3, x**2: 3*a + 3, + x**3: 1 + } + + assert collect(f, x, factor, evaluate=False) == { + S.One: (a + 1)**3, x: 3*(a + 1)**2, + x**2: umul(S(3), a + 1), x**3: 1} + + +def test_collect_order(): + a, b, x, t = symbols('a,b,x,t') + + assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3)) + assert collect(t + t*x + x**2 + O(x**3), t) == \ + t*(1 + x + O(x**3)) + x**2 + O(x**3) + + f = a*x + b*x + c*x**2 + d*x**2 + O(x**3) + g = x*(a + b) + x**2*(c + d) + O(x**3) + + assert collect(f, x) == g + assert collect(f, x, distribute_order_term=False) == g + + f = sin(a + b).series(b, 0, 10) + + assert collect(f, [sin(a), cos(a)]) == \ + sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10) + assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \ + sin(a)*cos(b).series(b, 0, 10).removeO() + \ + cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10) + + +def test_rcollect(): + assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \ + (x + y*(1 + x + x**2))/(x + y) + assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1))) + + +def test_collect_D_0(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fxx = D(f(x), x, x) + + assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx + + +def test_collect_Wild(): + """Collect with respect to functions with Wild argument""" + a, b, x, y = symbols('a b x y') + f = Function('f') + w1 = Wild('.1') + w2 = Wild('.2') + assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x) + assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x) + assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \ + a*(x + 1)**y + (x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \ + (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y + + +def test_collect_const(): + # coverage not provided by above tests + assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \ + 2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb + assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \ + 2*sqrt(3) + 4*a*sqrt(5) + assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \ + sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3) + + # issue 5290 + assert collect_const(2*x + 2*y + 1, 2) == \ + collect_const(2*x + 2*y + 1) == \ + Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False) + assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, 2) == \ + Mul(2, x - y - z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, -2) == \ + _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False)) + + # this is why the content_primitive is used + eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2 + assert collect_sqrt(eq + 2) == \ + 2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2 + + # issue 16296 + assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False) + + +def test_issue_13143(): + f = Function('f') + fx = f(x).diff(x) + e = f(x) + fx + f(x)*fx + # collect function before derivative + assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx + e = f(x) + f(x)*fx + x*fx*f(x) + assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x) + assert collect(e, f(x)) == (x*fx + fx + 1)*f(x) + e = f(x) + fx + f(x)*fx + assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx + assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x) + + +def test_issue_6097(): + assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0 + assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0 + + +def test_fraction_expand(): + eq = (x + y)*y/x + assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x + assert eq.expand() == y + y**2/x + + +def test_fraction(): + x, y, z = map(Symbol, 'xyz') + A = Symbol('A', commutative=False) + + assert fraction(S.Half) == (1, 2) + + assert fraction(x) == (x, 1) + assert fraction(1/x) == (1, x) + assert fraction(x/y) == (x, y) + assert fraction(x/2) == (x, 2) + + assert fraction(x*y/z) == (x*y, z) + assert fraction(x/(y*z)) == (x, y*z) + + assert fraction(1/y**2) == (1, y**2) + assert fraction(x/y**2) == (x, y**2) + + assert fraction((x**2 + 1)/y) == (x**2 + 1, y) + assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7) + + assert fraction(exp(-x), exact=True) == (exp(-x), 1) + assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False)) + + assert fraction(x*A/y) == (x*A, y) + assert fraction(x*A**-1/y) == (x*A**-1, y) + + n = symbols('n', negative=True) + assert fraction(exp(n)) == (1, exp(-n)) + assert fraction(exp(-n)) == (exp(-n), 1) + + p = symbols('p', positive=True) + assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1) + + m = Mul(1, 1, S.Half, evaluate=False) + assert fraction(m) == (1, 2) + assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2) + + m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False) + assert fraction(m) == (1, 4) + assert fraction(m, exact=True) == \ + (Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False)) + + +def test_issue_5615(): + aA, Re, a, b, D = symbols('aA Re a b D') + e = ((D**3*a + b*aA**3)/Re).expand() + assert collect(e, [aA**3/Re, a]) == e + + +def test_issue_5933(): + from sympy.geometry.polygon import (Polygon, RegularPolygon) + from sympy.simplify.radsimp import denom + x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x + assert abs(denom(x).n()) > 1e-12 + assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it + + +def test_issue_14608(): + a, b = symbols('a b', commutative=False) + x, y = symbols('x y') + raises(AttributeError, lambda: collect(a*b + b*a, a)) + assert collect(x*y + y*(x+1), a) == x*y + y*(x+1) + assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a + + +def test_collect_abs(): + s = abs(x) + abs(y) + assert collect_abs(s) == s + assert unchanged(Mul, abs(x), abs(y)) + ans = Abs(x*y) + assert isinstance(ans, Abs) + assert collect_abs(abs(x)*abs(y)) == ans + assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans) + + # See https://github.com/sympy/sympy/issues/12910 + p = Symbol('p', positive=True) + assert collect_abs(p/abs(1-p)).is_commutative is True + + +def test_issue_19149(): + eq = exp(3*x/4) + assert collect(eq, exp(x)) == eq + +def test_issue_19719(): + a, b = symbols('a, b') + expr = a**2 * (b + 1) + (7 + 1/b)/a + collected = collect(expr, (a**2, 1/a), evaluate=False) + # Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace + assert collected == {a**2: b + 1, 1/a: 7 + 1/b} + + +def test_issue_21355(): + assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2)) + assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2)) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_ratsimp.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..14e84fd2b227518baff1bda4e5b27ecc40a8bcdd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_ratsimp.py @@ -0,0 +1,78 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.error_functions import erf +from sympy.polys.domains import GF +from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime) + +from sympy.abc import x, y, z, t, a, b, c, d, e + + +def test_ratsimp(): + f, g = 1/x + 1/y, (x + y)/(x*y) + + assert f != g and ratsimp(f) == g + + f, g = 1/(1 + 1/x), 1 - 1/(x + 1) + + assert f != g and ratsimp(f) == g + + f, g = x/(x + y) + y/(x + y), 1 + + assert f != g and ratsimp(f) == g + + f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y + + assert f != g and ratsimp(f) == g + + f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z + + e*x)/(x*y + z) + G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z), + a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)] + + assert f != g and ratsimp(f) in G + + A = sqrt(pi) + + B = log(erf(x) - 1) + C = log(erf(x) + 1) + + D = 8 - 8*erf(x) + + f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D + + assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4) + + +def test_ratsimpmodprime(): + a = y**5 + x + y + b = x - y + F = [x*y**5 - x - y] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (-x**2 - x*y - x - y) / (-x**2 + x*y) + + a = x + y**2 - 2 + b = x + y**2 - y - 1 + F = [x*y - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + y - x)/(y - x) + + a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15 + b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21 + F = [x**2 + y**2 - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + 5*y - 5*x)/(8*y - 6*x) + + a = x*y - x - 2*y + 4 + b = x + y**2 - 2*y + F = [x - 2, y - 3] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + Rational(2, 5) + + # Test a bug where denominators would be dropped + assert ratsimpmodprime(x, [y - 2*x], order='lex') == \ + y/2 + + a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2)) + assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1 + assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1 diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_rewrite.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..56d2fb7a85bd959bd4accc2f36127429efbdbe70 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_rewrite.py @@ -0,0 +1,31 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, cot, sin) +from sympy.testing.pytest import _both_exp_pow + +x, y, z, n = symbols('x,y,z,n') + + +@_both_exp_pow +def test_has(): + assert cot(x).has(x) + assert cot(x).has(cot) + assert not cot(x).has(sin) + assert sin(x).has(x) + assert sin(x).has(sin) + assert not sin(x).has(cot) + assert exp(x).has(exp) + + +@_both_exp_pow +def test_sin_exp_rewrite(): + assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x)) + assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x) + assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x) + assert (sin(5*y) - sin( + 2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x) + assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y) + assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y) + # This next test currently passes... not clear whether it should or not? + assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x) diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_simplify.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..a5bf469f68adf5c5dfbdf7559414681e2fb28ba7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_simplify.py @@ -0,0 +1,1093 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import unchanged +from sympy.core.function import (count_ops, diff, expand, expand_multinomial, Function, Derivative) +from sympy.core.mul import Mul, _keep_coeff +from sympy.core import GoldenRatio +from sympy.core.numbers import (E, Float, I, oo, pi, Rational, zoo) +from sympy.core.relational import (Eq, Lt, Gt, Ge, Le) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, sign) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, csch, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, sinc, tan) +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.geometry.polygon import rad +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import (Matrix, eye) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys.polytools import (factor, Poly) +from sympy.simplify.simplify import (besselsimp, hypersimp, inversecombine, logcombine, nsimplify, nthroot, posify, separatevars, signsimp, simplify) +from sympy.solvers.solvers import solve + +from sympy.testing.pytest import XFAIL, slow, _both_exp_pow +from sympy.abc import x, y, z, t, a, b, c, d, e, f, g, h, i, n + + +def test_issue_7263(): + assert abs((simplify(30.8**2 - 82.5**2 * sin(rad(11.6))**2)).evalf() - \ + 673.447451402970) < 1e-12 + + +def test_factorial_simplify(): + # There are more tests in test_factorials.py. + x = Symbol('x') + assert simplify(factorial(x)/x) == gamma(x) + assert simplify(factorial(factorial(x))) == factorial(factorial(x)) + + +def test_simplify_expr(): + x, y, z, k, n, m, w, s, A = symbols('x,y,z,k,n,m,w,s,A') + f = Function('f') + + assert all(simplify(tmp) == tmp for tmp in [I, E, oo, x, -x, -oo, -E, -I]) + + e = 1/x + 1/y + assert e != (x + y)/(x*y) + assert simplify(e) == (x + y)/(x*y) + + e = A**2*s**4/(4*pi*k*m**3) + assert simplify(e) == e + + e = (4 + 4*x - 2*(2 + 2*x))/(2 + 2*x) + assert simplify(e) == 0 + + e = (-4*x*y**2 - 2*y**3 - 2*x**2*y)/(x + y)**2 + assert simplify(e) == -2*y + + e = -x - y - (x + y)**(-1)*y**2 + (x + y)**(-1)*x**2 + assert simplify(e) == -2*y + + e = (x + x*y)/x + assert simplify(e) == 1 + y + + e = (f(x) + y*f(x))/f(x) + assert simplify(e) == 1 + y + + e = (2 * (1/n - cos(n * pi)/n))/pi + assert simplify(e) == (-cos(pi*n) + 1)/(pi*n)*2 + + e = integrate(1/(x**3 + 1), x).diff(x) + assert simplify(e) == 1/(x**3 + 1) + + e = integrate(x/(x**2 + 3*x + 1), x).diff(x) + assert simplify(e) == x/(x**2 + 3*x + 1) + + f = Symbol('f') + A = Matrix([[2*k - m*w**2, -k], [-k, k - m*w**2]]).inv() + assert simplify((A*Matrix([0, f]))[1] - + (-f*(2*k - m*w**2)/(k**2 - (k - m*w**2)*(2*k - m*w**2)))) == 0 + + f = -x + y/(z + t) + z*x/(z + t) + z*a/(z + t) + t*x/(z + t) + assert simplify(f) == (y + a*z)/(z + t) + + # issue 10347 + expr = -x*(y**2 - 1)*(2*y**2*(x**2 - 1)/(a*(x**2 - y**2)**2) + (x**2 - 1) + /(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)* + (y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(x**2 - 1) + sqrt( + (-x**2 + 1)*(y**2 - 1))*(x*(-x*y**2 + x)/sqrt(-x**2*y**2 + x**2 + y**2 - + 1) + sqrt(-x**2*y**2 + x**2 + y**2 - 1))*sin(z))/(a*sqrt((-x**2 + 1)*( + y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a* + (x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2 + *y**2 + x**2 + y**2 - 1)*cos(z)/(x**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - + 1))*(-x*y**2 + x)*cos(z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1) + sqrt((-x**2 + + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z))/(a*sqrt((-x**2 + + 1)*(y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos( + z)/(a*(x**2 - y**2)) - y*sqrt((-x**2 + 1)*(y**2 - 1))*(-x*y*sqrt(-x**2* + y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)*(y**2 - 1)) + 2*x*y*sqrt( + -x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) + (x*y*sqrt(( + -x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(y**2 - + 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*sin(z)/sqrt(-x**2*y**2 + + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2)))*sin( + z)/(a*(x**2 - y**2)) + y*(x**2 - 1)*(-2*x*y*(x**2 - 1)/(a*(x**2 - y**2) + **2) + 2*x*y/(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + y*(x**2 - 1)*(y**2 - + 1)*(-x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2)*(y**2 + - 1)) + 2*x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2) + **2) + (x*y*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - + 1)*cos(z)/(y**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*cos( + z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1) + )*(x**2 - y**2)))*cos(z)/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2) + ) - x*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin( + z)**2/(a**2*(x**2 - 1)*(x**2 - y**2)*(y**2 - 1)) - x*sqrt((-x**2 + 1)*( + y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)**2/(a**2*(x**2 - 1)*( + x**2 - y**2)*(y**2 - 1)) + assert simplify(expr) == 2*x/(a**2*(x**2 - y**2)) + + #issue 17631 + assert simplify('((-1/2)*Boole(True)*Boole(False)-1)*Boole(True)') == \ + Mul(sympify('(2 + Boole(True)*Boole(False))'), sympify('-Boole(True)/2')) + + A, B = symbols('A,B', commutative=False) + + assert simplify(A*B - B*A) == A*B - B*A + assert simplify(A/(1 + y/x)) == x*A/(x + y) + assert simplify(A*(1/x + 1/y)) == A/x + A/y #(x + y)*A/(x*y) + + assert simplify(log(2) + log(3)) == log(6) + assert simplify(log(2*x) - log(2)) == log(x) + + assert simplify(hyper([], [], x)) == exp(x) + + +def test_issue_3557(): + f_1 = x*a + y*b + z*c - 1 + f_2 = x*d + y*e + z*f - 1 + f_3 = x*g + y*h + z*i - 1 + + solutions = solve([f_1, f_2, f_3], x, y, z, simplify=False) + + assert simplify(solutions[y]) == \ + (a*i + c*d + f*g - a*f - c*g - d*i)/ \ + (a*e*i + b*f*g + c*d*h - a*f*h - b*d*i - c*e*g) + + +def test_simplify_other(): + assert simplify(sin(x)**2 + cos(x)**2) == 1 + assert simplify(gamma(x + 1)/gamma(x)) == x + assert simplify(sin(x)**2 + cos(x)**2 + factorial(x)/gamma(x)) == 1 + x + assert simplify( + Eq(sin(x)**2 + cos(x)**2, factorial(x)/gamma(x))) == Eq(x, 1) + nc = symbols('nc', commutative=False) + assert simplify(x + x*nc) == x*(1 + nc) + # issue 6123 + # f = exp(-I*(k*sqrt(t) + x/(2*sqrt(t)))**2) + # ans = integrate(f, (k, -oo, oo), conds='none') + ans = I*(-pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))*erf(x*exp(I*pi*Rational(-3, 4))/ + (2*sqrt(t)))/(2*sqrt(t)) + pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))/ + (2*sqrt(t)))*exp(-I*x**2/(4*t))/(sqrt(pi)*x) - I*sqrt(pi) * \ + (-erf(x*exp(I*pi/4)/(2*sqrt(t))) + 1)*exp(I*pi/4)/(2*sqrt(t)) + assert simplify(ans) == -(-1)**Rational(3, 4)*sqrt(pi)/sqrt(t) + # issue 6370 + assert simplify(2**(2 + x)/4) == 2**x + + +@_both_exp_pow +def test_simplify_complex(): + cosAsExp = cos(x)._eval_rewrite_as_exp(x) + tanAsExp = tan(x)._eval_rewrite_as_exp(x) + assert simplify(cosAsExp*tanAsExp) == sin(x) # issue 4341 + + # issue 10124 + assert simplify(exp(Matrix([[0, -1], [1, 0]]))) == Matrix([[cos(1), + -sin(1)], [sin(1), cos(1)]]) + + +def test_simplify_ratio(): + # roots of x**3-3*x+5 + roots = ['(1/2 - sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3) + 1/((1/2 - ' + 'sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3))', + '1/((1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)) + ' + '(1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)', + '-(sqrt(21)/2 + 5/2)**(1/3) - 1/(sqrt(21)/2 + 5/2)**(1/3)'] + + for r in roots: + r = S(r) + assert count_ops(simplify(r, ratio=1)) <= count_ops(r) + # If ratio=oo, simplify() is always applied: + assert simplify(r, ratio=oo) is not r + + +def test_simplify_measure(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + assert measure1(simplify(expr, measure=measure1)) <= measure1(expr) + assert measure2(simplify(expr, measure=measure2)) <= measure2(expr) + + expr2 = Eq(sin(x)**2 + cos(x)**2, 1) + assert measure1(simplify(expr2, measure=measure1)) <= measure1(expr2) + assert measure2(simplify(expr2, measure=measure2)) <= measure2(expr2) + + +def test_simplify_rational(): + expr = 2**x*2.**y + assert simplify(expr, rational = True) == 2**(x+y) + assert simplify(expr, rational = None) == 2.0**(x+y) + assert simplify(expr, rational = False) == expr + assert simplify('0.9 - 0.8 - 0.1', rational = True) == 0 + + +def test_simplify_issue_1308(): + assert simplify(exp(Rational(-1, 2)) + exp(Rational(-3, 2))) == \ + (1 + E)*exp(Rational(-3, 2)) + + +def test_issue_5652(): + assert simplify(E + exp(-E)) == exp(-E) + E + n = symbols('n', commutative=False) + assert simplify(n + n**(-n)) == n + n**(-n) + +def test_issue_27380(): + assert simplify(1.0**(x+1)/1.0**x) == 1.0 + +def test_simplify_fail1(): + x = Symbol('x') + y = Symbol('y') + e = (x + y)**2/(-4*x*y**2 - 2*y**3 - 2*x**2*y) + assert simplify(e) == 1 / (-2*y) + + +def test_nthroot(): + assert nthroot(90 + 34*sqrt(7), 3) == sqrt(7) + 3 + q = 1 + sqrt(2) - 2*sqrt(3) + sqrt(6) + sqrt(7) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(41 + 29*sqrt(2), 5) == 1 + sqrt(2) + assert nthroot(-41 - 29*sqrt(2), 5) == -1 - sqrt(2) + expr = 1320*sqrt(10) + 4216 + 2576*sqrt(6) + 1640*sqrt(15) + assert nthroot(expr, 5) == 1 + sqrt(6) + sqrt(15) + q = 1 + sqrt(2) + sqrt(3) + sqrt(5) + assert expand_multinomial(nthroot(expand_multinomial(q**5), 5)) == q + q = 1 + sqrt(2) + 7*sqrt(6) + 2*sqrt(10) + assert nthroot(expand_multinomial(q**5), 5, 8) == q + q = 1 + sqrt(2) - 2*sqrt(3) + 1171*sqrt(6) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(expand_multinomial(q**6), 6) == q + + +def test_nthroot1(): + q = 1 + sqrt(2) + sqrt(3) + S.One/10**20 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + q = 1 + sqrt(2) + sqrt(3) + S.One/10**30 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + + +@_both_exp_pow +def test_separatevars(): + x, y, z, n = symbols('x,y,z,n') + assert separatevars(2*n*x*z + 2*x*y*z) == 2*x*z*(n + y) + assert separatevars(x*z + x*y*z) == x*z*(1 + y) + assert separatevars(pi*x*z + pi*x*y*z) == pi*x*z*(1 + y) + assert separatevars(x*y**2*sin(x) + x*sin(x)*sin(y)) == \ + x*(sin(y) + y**2)*sin(x) + assert separatevars(x*exp(x + y) + x*exp(x)) == x*(1 + exp(y))*exp(x) + assert separatevars((x*(y + 1))**z).is_Pow # != x**z*(1 + y)**z + assert separatevars(1 + x + y + x*y) == (x + 1)*(y + 1) + assert separatevars(y/pi*exp(-(z - x)/cos(n))) == \ + y*exp(x/cos(n))*exp(-z/cos(n))/pi + assert separatevars((x + y)*(x - y) + y**2 + 2*x + 1) == (x + 1)**2 + # issue 4858 + p = Symbol('p', positive=True) + assert separatevars(sqrt(p**2 + x*p**2)) == p*sqrt(1 + x) + assert separatevars(sqrt(y*(p**2 + x*p**2))) == p*sqrt(y*(1 + x)) + assert separatevars(sqrt(y*(p**2 + x*p**2)), force=True) == \ + p*sqrt(y)*sqrt(1 + x) + # issue 4865 + assert separatevars(sqrt(x*y)).is_Pow + assert separatevars(sqrt(x*y), force=True) == sqrt(x)*sqrt(y) + # issue 4957 + # any type sequence for symbols is fine + assert separatevars(((2*x + 2)*y), dict=True, symbols=()) == \ + {'coeff': 1, x: 2*x + 2, y: y} + # separable + assert separatevars(((2*x + 2)*y), dict=True, symbols=[x]) == \ + {'coeff': y, x: 2*x + 2} + assert separatevars(((2*x + 2)*y), dict=True, symbols=[]) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True, symbols=None) == \ + {'coeff': y*(2*x + 2)} + # not separable + assert separatevars(3, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=()) is None + assert separatevars(2*x + y, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=None) == {'coeff': 2*x + y} + # issue 4808 + n, m = symbols('n,m', commutative=False) + assert separatevars(m + n*m) == (1 + n)*m + assert separatevars(x + x*n) == x*(1 + n) + # issue 4910 + f = Function('f') + assert separatevars(f(x) + x*f(x)) == f(x) + x*f(x) + # a noncommutable object present + eq = x*(1 + hyper((), (), y*z)) + assert separatevars(eq) == eq + + s = separatevars(abs(x*y)) + assert s == abs(x)*abs(y) and s.is_Mul + z = cos(1)**2 + sin(1)**2 - 1 + a = abs(x*z) + s = separatevars(a) + assert not a.is_Mul and s.is_Mul and s == abs(x)*abs(z) + s = separatevars(abs(x*y*z)) + assert s == abs(x)*abs(y)*abs(z) + + # abs(x+y)/abs(z) would be better but we test this here to + # see that it doesn't raise + assert separatevars(abs((x+y)/z)) == abs((x+y)/z) + + +def test_separatevars_advanced_factor(): + x, y, z = symbols('x,y,z') + assert separatevars(1 + log(x)*log(y) + log(x) + log(y)) == \ + (log(x) + 1)*(log(y) + 1) + assert separatevars(1 + x - log(z) - x*log(z) - exp(y)*log(z) - + x*exp(y)*log(z) + x*exp(y) + exp(y)) == \ + -((x + 1)*(log(z) - 1)*(exp(y) + 1)) + x, y = symbols('x,y', positive=True) + assert separatevars(1 + log(x**log(y)) + log(x*y)) == \ + (log(x) + 1)*(log(y) + 1) + + +def test_hypersimp(): + n, k = symbols('n,k', integer=True) + + assert hypersimp(factorial(k), k) == k + 1 + assert hypersimp(factorial(k**2), k) is None + + assert hypersimp(1/factorial(k), k) == 1/(k + 1) + + assert hypersimp(2**k/factorial(k)**2, k) == 2/(k + 1)**2 + + assert hypersimp(binomial(n, k), k) == (n - k)/(k + 1) + assert hypersimp(binomial(n + 1, k), k) == (n - k + 1)/(k + 1) + + term = (4*k + 1)*factorial(k)/factorial(2*k + 1) + assert hypersimp(term, k) == S.Half*((4*k + 5)/(3 + 14*k + 8*k**2)) + + term = 1/((2*k - 1)*factorial(2*k + 1)) + assert hypersimp(term, k) == (k - S.Half)/((k + 1)*(2*k + 1)*(2*k + 3)) + + term = binomial(n, k)*(-1)**k/factorial(k) + assert hypersimp(term, k) == (k - n)/(k + 1)**2 + + +def test_nsimplify(): + x = Symbol("x") + assert nsimplify(0) == 0 + assert nsimplify(-1) == -1 + assert nsimplify(1) == 1 + assert nsimplify(1 + x) == 1 + x + assert nsimplify(2.7) == Rational(27, 10) + assert nsimplify(1 - GoldenRatio) == (1 - sqrt(5))/2 + assert nsimplify((1 + sqrt(5))/4, [GoldenRatio]) == GoldenRatio/2 + assert nsimplify(2/GoldenRatio, [GoldenRatio]) == 2*GoldenRatio - 2 + assert nsimplify(exp(pi*I*Rational(5, 3), evaluate=False)) == \ + sympify('1/2 - sqrt(3)*I/2') + assert nsimplify(sin(pi*Rational(3, 5), evaluate=False)) == \ + sympify('sqrt(sqrt(5)/8 + 5/8)') + assert nsimplify(sqrt(atan('1', evaluate=False))*(2 + I), [pi]) == \ + sqrt(pi) + sqrt(pi)/2*I + assert nsimplify(2 + exp(2*atan('1/4')*I)) == sympify('49/17 + 8*I/17') + assert nsimplify(pi, tolerance=0.01) == Rational(22, 7) + assert nsimplify(pi, tolerance=0.001) == Rational(355, 113) + assert nsimplify(0.33333, tolerance=1e-4) == Rational(1, 3) + assert nsimplify(2.0**(1/3.), tolerance=0.001) == Rational(635, 504) + assert nsimplify(2.0**(1/3.), tolerance=0.001, full=True) == \ + 2**Rational(1, 3) + assert nsimplify(x + .5, rational=True) == S.Half + x + assert nsimplify(1/.3 + x, rational=True) == Rational(10, 3) + x + assert nsimplify(log(3).n(), rational=True) == \ + sympify('109861228866811/100000000000000') + assert nsimplify(Float(0.272198261287950), [pi, log(2)]) == pi*log(2)/8 + assert nsimplify(Float(0.272198261287950).n(3), [pi, log(2)]) == \ + -pi/4 - log(2) + Rational(7, 4) + assert nsimplify(x/7.0) == x/7 + assert nsimplify(pi/1e2) == pi/100 + assert nsimplify(pi/1e2, rational=False) == pi/100.0 + assert nsimplify(pi/1e-7) == 10000000*pi + assert not nsimplify( + factor(-3.0*z**2*(z**2)**(-2.5) + 3*(z**2)**(-1.5))).atoms(Float) + e = x**0.0 + assert e.is_Pow and nsimplify(x**0.0) == 1 + assert nsimplify(3.333333, tolerance=0.1, rational=True) == Rational(10, 3) + assert nsimplify(3.333333, tolerance=0.01, rational=True) == Rational(10, 3) + assert nsimplify(3.666666, tolerance=0.1, rational=True) == Rational(11, 3) + assert nsimplify(3.666666, tolerance=0.01, rational=True) == Rational(11, 3) + assert nsimplify(33, tolerance=10, rational=True) == Rational(33) + assert nsimplify(33.33, tolerance=10, rational=True) == Rational(30) + assert nsimplify(37.76, tolerance=10, rational=True) == Rational(40) + assert nsimplify(-203.1) == Rational(-2031, 10) + assert nsimplify(.2, tolerance=0) == Rational(1, 5) + assert nsimplify(-.2, tolerance=0) == Rational(-1, 5) + assert nsimplify(.2222, tolerance=0) == Rational(1111, 5000) + assert nsimplify(-.2222, tolerance=0) == Rational(-1111, 5000) + # issue 7211, PR 4112 + assert nsimplify(S(2e-8)) == Rational(1, 50000000) + # issue 7322 direct test + assert nsimplify(1e-42, rational=True) != 0 + # issue 10336 + inf = Float('inf') + infs = (-oo, oo, inf, -inf) + for zi in infs: + ans = sign(zi)*oo + assert nsimplify(zi) == ans + assert nsimplify(zi + x) == x + ans + + assert nsimplify(0.33333333, rational=True, rational_conversion='exact') == Rational(0.33333333) + + # Make sure nsimplify on expressions uses full precision + assert nsimplify(pi.evalf(100)*x, rational_conversion='exact').evalf(100) == pi.evalf(100)*x + + +def test_issue_9448(): + tmp = sympify("1/(1 - (-1)**(2/3) - (-1)**(1/3)) + 1/(1 + (-1)**(2/3) + (-1)**(1/3))") + assert nsimplify(tmp) == S.Half + + +def test_extract_minus_sign(): + x = Symbol("x") + y = Symbol("y") + a = Symbol("a") + b = Symbol("b") + assert simplify(-x/-y) == x/y + assert simplify(-x/y) == -x/y + assert simplify(x/y) == x/y + assert simplify(x/-y) == -x/y + assert simplify(-x/0) == zoo*x + assert simplify(Rational(-5, 0)) is zoo + assert simplify(-a*x/(-y - b)) == a*x/(b + y) + + +def test_diff(): + x = Symbol("x") + y = Symbol("y") + f = Function("f") + g = Function("g") + assert simplify(g(x).diff(x)*f(x).diff(x) - f(x).diff(x)*g(x).diff(x)) == 0 + assert simplify(2*f(x)*f(x).diff(x) - diff(f(x)**2, x)) == 0 + assert simplify(diff(1/f(x), x) + f(x).diff(x)/f(x)**2) == 0 + assert simplify(f(x).diff(x, y) - f(x).diff(y, x)) == 0 + + +def test_logcombine_1(): + x, y = symbols("x,y") + a = Symbol("a") + z, w = symbols("z,w", positive=True) + b = Symbol("b", real=True) + assert logcombine(log(x) + 2*log(y)) == log(x) + 2*log(y) + assert logcombine(log(x) + 2*log(y), force=True) == log(x*y**2) + assert logcombine(a*log(w) + log(z)) == a*log(w) + log(z) + assert logcombine(b*log(z) + b*log(x)) == log(z**b) + b*log(x) + assert logcombine(b*log(z) - log(w)) == log(z**b/w) + assert logcombine(log(x)*log(z)) == log(x)*log(z) + assert logcombine(log(w)*log(x)) == log(w)*log(x) + assert logcombine(cos(-2*log(z) + b*log(w))) in [cos(log(w**b/z**2)), + cos(log(z**2/w**b))] + assert logcombine(log(log(x) - log(y)) - log(z), force=True) == \ + log(log(x/y)/z) + assert logcombine((2 + I)*log(x), force=True) == (2 + I)*log(x) + assert logcombine((x**2 + log(x) - log(y))/(x*y), force=True) == \ + (x**2 + log(x/y))/(x*y) + # the following could also give log(z*x**log(y**2)), what we + # are testing is that a canonical result is obtained + assert logcombine(log(x)*2*log(y) + log(z), force=True) == \ + log(z*y**log(x**2)) + assert logcombine((x*y + sqrt(x**4 + y**4) + log(x) - log(y))/(pi*x**Rational(2, 3)* + sqrt(y)**3), force=True) == ( + x*y + sqrt(x**4 + y**4) + log(x/y))/(pi*x**Rational(2, 3)*y**Rational(3, 2)) + assert logcombine(gamma(-log(x/y))*acos(-log(x/y)), force=True) == \ + acos(-log(x/y))*gamma(-log(x/y)) + + assert logcombine(2*log(z)*log(w)*log(x) + log(z) + log(w)) == \ + log(z**log(w**2))*log(x) + log(w*z) + assert logcombine(3*log(w) + 3*log(z)) == log(w**3*z**3) + assert logcombine(x*(y + 1) + log(2) + log(3)) == x*(y + 1) + log(6) + assert logcombine((x + y)*log(w) + (-x - y)*log(3)) == (x + y)*log(w/3) + # a single unknown can combine + assert logcombine(log(x) + log(2)) == log(2*x) + eq = log(abs(x)) + log(abs(y)) + assert logcombine(eq) == eq + reps = {x: 0, y: 0} + assert log(abs(x)*abs(y)).subs(reps) != eq.subs(reps) + + +def test_logcombine_complex_coeff(): + i = Integral((sin(x**2) + cos(x**3))/x, x) + assert logcombine(i, force=True) == i + assert logcombine(i + 2*log(x), force=True) == \ + i + log(x**2) + + +def test_issue_5950(): + x, y = symbols("x,y", positive=True) + assert logcombine(log(3) - log(2)) == log(Rational(3,2), evaluate=False) + assert logcombine(log(x) - log(y)) == log(x/y) + assert logcombine(log(Rational(3,2), evaluate=False) - log(2)) == \ + log(Rational(3,4), evaluate=False) + + +def test_posify(): + x = symbols('x') + + assert str(posify( + x + + Symbol('p', positive=True) + + Symbol('n', negative=True))) == '(_x + n + p, {_x: x})' + + eq, rep = posify(1/x) + assert log(eq).expand().subs(rep) == -log(x) + assert str(posify([x, 1 + x])) == '([_x, _x + 1], {_x: x})' + + p = symbols('p', positive=True) + n = symbols('n', negative=True) + orig = [x, n, p] + modified, reps = posify(orig) + assert str(modified) == '[_x, n, p]' + assert [w.subs(reps) for w in modified] == orig + + assert str(Integral(posify(1/x + y)[0], (y, 1, 3)).expand()) == \ + 'Integral(1/_x, (y, 1, 3)) + Integral(_y, (y, 1, 3))' + assert str(Sum(posify(1/x**n)[0], (n,1,3)).expand()) == \ + 'Sum(_x**(-n), (n, 1, 3))' + + A = Matrix([[1, 2, 3], [4, 5, 6 * Abs(x)]]) + Ap, rep = posify(A) + assert Ap == A.subs(*reversed(rep.popitem())) + + # issue 16438 + k = Symbol('k', finite=True) + eq, rep = posify(k) + assert eq.assumptions0 == {'positive': True, 'zero': False, 'imaginary': False, + 'nonpositive': False, 'commutative': True, 'hermitian': True, 'real': True, 'nonzero': True, + 'nonnegative': True, 'negative': False, 'complex': True, 'finite': True, + 'infinite': False, 'extended_real':True, 'extended_negative': False, + 'extended_nonnegative': True, 'extended_nonpositive': False, + 'extended_nonzero': True, 'extended_positive': True} + + +def test_issue_4194(): + # simplify should call cancel + f = Function('f') + assert simplify((4*x + 6*f(y))/(2*x + 3*f(y))) == 2 + + +@XFAIL +def test_simplify_float_vs_integer(): + # Test for issue 4473: + # https://github.com/sympy/sympy/issues/4473 + assert simplify(x**2.0 - x**2) == 0 + assert simplify(x**2 - x**2.0) == 0 + + +def test_as_content_primitive(): + assert (x/2 + y).as_content_primitive() == (S.Half, x + 2*y) + assert (x/2 + y).as_content_primitive(clear=False) == (S.One, x/2 + y) + assert (y*(x/2 + y)).as_content_primitive() == (S.Half, y*(x + 2*y)) + assert (y*(x/2 + y)).as_content_primitive(clear=False) == (S.One, y*(x/2 + y)) + + # although the _as_content_primitive methods do not alter the underlying structure, + # the as_content_primitive function will touch up the expression and join + # bases that would otherwise have not been joined. + assert (x*(2 + 2*x)*(3*x + 3)**2).as_content_primitive() == \ + (18, x*(x + 1)**3) + assert (2 + 2*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (2, x + 3*y*(y + 1) + 1) + assert ((2 + 6*x)**2).as_content_primitive() == \ + (4, (3*x + 1)**2) + assert ((2 + 6*x)**(2*y)).as_content_primitive() == \ + (1, (_keep_coeff(S(2), (3*x + 1)))**(2*y)) + assert (5 + 10*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (1, 10*x + 6*y*(y + 1) + 5) + assert (5*(x*(1 + y)) + 2*x*(3 + 3*y)).as_content_primitive() == \ + (11, x*(y + 1)) + assert ((5*(x*(1 + y)) + 2*x*(3 + 3*y))**2).as_content_primitive() == \ + (121, x**2*(y + 1)**2) + assert (y**2).as_content_primitive() == \ + (1, y**2) + assert (S.Infinity).as_content_primitive() == (1, oo) + eq = x**(2 + y) + assert (eq).as_content_primitive() == (1, eq) + assert (S.Half**(2 + x)).as_content_primitive() == (Rational(1, 4), 2**-x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), (Rational(-1, 2))**x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), Rational(-1, 2)**x) + assert (4**((1 + y)/2)).as_content_primitive() == (2, 4**(y/2)) + assert (3**((1 + y)/2)).as_content_primitive() == \ + (1, 3**(Mul(S.Half, 1 + y, evaluate=False))) + assert (5**Rational(3, 4)).as_content_primitive() == (1, 5**Rational(3, 4)) + assert (5**Rational(7, 4)).as_content_primitive() == (5, 5**Rational(3, 4)) + assert Add(z*Rational(5, 7), 0.5*x, y*Rational(3, 2), evaluate=False).as_content_primitive() == \ + (Rational(1, 14), 7.0*x + 21*y + 10*z) + assert (2**Rational(3, 4) + 2**Rational(1, 4)*sqrt(3)).as_content_primitive(radical=True) == \ + (1, 2**Rational(1, 4)*(sqrt(2) + sqrt(3))) + + +def test_signsimp(): + e = x*(-x + 1) + x*(x - 1) + assert signsimp(Eq(e, 0)) is S.true + assert Abs(x - 1) == Abs(1 - x) + assert signsimp(y - x) == y - x + assert signsimp(y - x, evaluate=False) == Mul(-1, x - y, evaluate=False) + + +def test_besselsimp(): + from sympy.functions.special.bessel import (besseli, besselj, bessely) + from sympy.integrals.transforms import cosine_transform + assert besselsimp(exp(-I*pi*y/2)*besseli(y, z*exp_polar(I*pi/2))) == \ + besselj(y, z) + assert besselsimp(exp(-I*pi*a/2)*besseli(a, 2*sqrt(x)*exp_polar(I*pi/2))) == \ + besselj(a, 2*sqrt(x)) + assert besselsimp(sqrt(2)*sqrt(pi)*x**Rational(1, 4)*exp(I*pi/4)*exp(-I*pi*a/2) * + besseli(Rational(-1, 2), sqrt(x)*exp_polar(I*pi/2)) * + besseli(a, sqrt(x)*exp_polar(I*pi/2))/2) == \ + besselj(a, sqrt(x)) * cos(sqrt(x)) + assert besselsimp(besseli(Rational(-1, 2), z)) == \ + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besseli(a, z*exp_polar(-I*pi/2))) == \ + exp(-I*pi*a/2)*besselj(a, z) + assert cosine_transform(1/t*sin(a/t), t, y) == \ + sqrt(2)*sqrt(pi)*besselj(0, 2*sqrt(a)*sqrt(y))/2 + + assert besselsimp(x**2*(a*(-2*besselj(5*I, x) + besselj(-2 + 5*I, x) + + besselj(2 + 5*I, x)) + b*(-2*bessely(5*I, x) + bessely(-2 + 5*I, x) + + bessely(2 + 5*I, x)))/4 + x*(a*(besselj(-1 + 5*I, x)/2 - besselj(1 + 5*I, x)/2) + + b*(bessely(-1 + 5*I, x)/2 - bessely(1 + 5*I, x)/2)) + (x**2 + 25)*(a*besselj(5*I, x) + + b*bessely(5*I, x))) == 0 + + assert besselsimp(81*x**2*(a*(besselj(Rational(-5, 3), 9*x) - 2*besselj(Rational(1, 3), 9*x) + besselj(Rational(7, 3), 9*x)) + + b*(bessely(Rational(-5, 3), 9*x) - 2*bessely(Rational(1, 3), 9*x) + bessely(Rational(7, 3), 9*x)))/4 + x*(a*(9*besselj(Rational(-2, 3), 9*x)/2 + - 9*besselj(Rational(4, 3), 9*x)/2) + b*(9*bessely(Rational(-2, 3), 9*x)/2 - 9*bessely(Rational(4, 3), 9*x)/2)) + + (81*x**2 - Rational(1, 9))*(a*besselj(Rational(1, 3), 9*x) + b*bessely(Rational(1, 3), 9*x))) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) - 2*a*besselj(a, x)/x) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) + besselj(a, x)) == (2*a + x)*besselj(a, x)/x + + assert besselsimp(x**2* besselj(a,x) + x**3*besselj(a+1, x) + besselj(a+2, x)) == \ + 2*a*x*besselj(a + 1, x) + x**3*besselj(a + 1, x) - x**2*besselj(a + 2, x) + 2*x*besselj(a + 1, x) + besselj(a + 2, x) + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + s1 = simplify(e1) + s2 = simplify(e2) + s3 = simplify(e3) + assert simplify(Piecewise((e1, x < e2), (e3, True))) == \ + Piecewise((s1, x < s2), (s3, True)) + + +def test_polymorphism(): + class A(Basic): + def _eval_simplify(x, **kwargs): + return S.One + + a = A(S(5), S(2)) + assert simplify(a) == 1 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert simplify(I*sqrt(n1)) == -sqrt(-n1) + + +def test_issue_6811(): + eq = (x + 2*y)*(2*x + 2) + assert simplify(eq) == (x + 1)*(x + 2*y)*2 + # reject the 2-arg Mul -- these are a headache for test writing + assert simplify(eq.expand()) == \ + 2*x**2 + 4*x*y + 2*x + 4*y + + +def test_issue_6920(): + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + # wrap in f to show that the change happens wherever ei occurs + f = Function('f') + assert [simplify(f(ei)).args[0] for ei in e] == ok + + +def test_issue_7001(): + from sympy.abc import r, R + assert simplify(-(r*Piecewise((pi*Rational(4, 3), r <= R), + (-8*pi*R**3/(3*r**3), True)) + 2*Piecewise((pi*r*Rational(4, 3), r <= R), + (4*pi*R**3/(3*r**2), True)))/(4*pi*r)) == \ + Piecewise((-1, r <= R), (0, True)) + + +def test_inequality_no_auto_simplify(): + # no simplify on creation but can be simplified + lhs = cos(x)**2 + sin(x)**2 + rhs = 2 + e = Lt(lhs, rhs, evaluate=False) + assert e is not S.true + assert simplify(e) + + +def test_issue_9398(): + from sympy.core.numbers import Number + from sympy.polys.polytools import cancel + assert cancel(1e-14) != 0 + assert cancel(1e-14*I) != 0 + + assert simplify(1e-14) != 0 + assert simplify(1e-14*I) != 0 + + assert (I*Number(1.)*Number(10)**Number(-14)).simplify() != 0 + + assert cancel(1e-20) != 0 + assert cancel(1e-20*I) != 0 + + assert simplify(1e-20) != 0 + assert simplify(1e-20*I) != 0 + + assert cancel(1e-100) != 0 + assert cancel(1e-100*I) != 0 + + assert simplify(1e-100) != 0 + assert simplify(1e-100*I) != 0 + + f = Float("1e-1000") + assert cancel(f) != 0 + assert cancel(f*I) != 0 + + assert simplify(f) != 0 + assert simplify(f*I) != 0 + + +def test_issue_9324_simplify(): + M = MatrixSymbol('M', 10, 10) + e = M[0, 0] + M[5, 4] + 1304 + assert simplify(e) == e + + +def test_issue_9817_simplify(): + # simplify on trace of substituted explicit quadratic form of matrix + # expressions (a scalar) should return without errors (AttributeError) + # See issue #9817 and #9190 for the original bug more discussion on this + from sympy.matrices.expressions import Identity, trace + v = MatrixSymbol('v', 3, 1) + A = MatrixSymbol('A', 3, 3) + x = Matrix([i + 1 for i in range(3)]) + X = Identity(3) + quadratic = v.T * A * v + assert simplify((trace(quadratic.as_explicit())).xreplace({v:x, A:X})) == 14 + + +def test_issue_13474(): + x = Symbol('x') + assert simplify(x + csch(sinc(1))) == x + csch(sinc(1)) + + +@_both_exp_pow +def test_simplify_function_inverse(): + # "inverse" attribute does not guarantee that f(g(x)) is x + # so this simplification should not happen automatically. + # See issue #12140 + x, y = symbols('x, y') + g = Function('g') + + class f(Function): + def inverse(self, argindex=1): + return g + + assert simplify(f(g(x))) == f(g(x)) + assert inversecombine(f(g(x))) == x + assert simplify(f(g(x)), inverse=True) == x + assert simplify(f(g(sin(x)**2 + cos(x)**2)), inverse=True) == 1 + assert simplify(f(g(x, y)), inverse=True) == f(g(x, y)) + assert unchanged(asin, sin(x)) + assert simplify(asin(sin(x))) == asin(sin(x)) + assert simplify(2*asin(sin(3*x)), inverse=True) == 6*x + assert simplify(log(exp(x))) == log(exp(x)) + assert simplify(log(exp(x)), inverse=True) == x + assert simplify(exp(log(x)), inverse=True) == x + assert simplify(log(exp(x), 2), inverse=True) == x/log(2) + assert simplify(log(exp(x), 2, evaluate=False), inverse=True) == x/log(2) + + +def test_clear_coefficients(): + from sympy.simplify.simplify import clear_coefficients + assert clear_coefficients(4*y*(6*x + 3)) == (y*(2*x + 1), 0) + assert clear_coefficients(4*y*(6*x + 3) - 2) == (y*(2*x + 1), Rational(1, 6)) + assert clear_coefficients(4*y*(6*x + 3) - 2, x) == (y*(2*x + 1), x/12 + Rational(1, 6)) + assert clear_coefficients(sqrt(2) - 2) == (sqrt(2), 2) + assert clear_coefficients(4*sqrt(2) - 2) == (sqrt(2), S.Half) + assert clear_coefficients(S(3), x) == (0, x - 3) + assert clear_coefficients(S.Infinity, x) == (S.Infinity, x) + assert clear_coefficients(-S.Pi, x) == (S.Pi, -x) + assert clear_coefficients(2 - S.Pi/3, x) == (pi, -3*x + 6) + +def test_nc_simplify(): + from sympy.simplify.simplify import nc_simplify + from sympy.matrices.expressions import MatPow, Identity + from sympy.core import Pow + from functools import reduce + + a, b, c, d = symbols('a b c d', commutative = False) + x = Symbol('x') + A = MatrixSymbol("A", x, x) + B = MatrixSymbol("B", x, x) + C = MatrixSymbol("C", x, x) + D = MatrixSymbol("D", x, x) + subst = {a: A, b: B, c: C, d:D} + funcs = {Add: lambda x,y: x+y, Mul: lambda x,y: x*y } + + def _to_matrix(expr): + if expr in subst: + return subst[expr] + if isinstance(expr, Pow): + return MatPow(_to_matrix(expr.args[0]), expr.args[1]) + elif isinstance(expr, (Add, Mul)): + return reduce(funcs[expr.func],[_to_matrix(a) for a in expr.args]) + else: + return expr*Identity(x) + + def _check(expr, simplified, deep=True, matrix=True): + assert nc_simplify(expr, deep=deep) == simplified + assert expand(expr) == expand(simplified) + if matrix: + m_simp = _to_matrix(simplified).doit(inv_expand=False) + assert nc_simplify(_to_matrix(expr), deep=deep) == m_simp + + _check(a*b*a*b*a*b*c*(a*b)**3*c, ((a*b)**3*c)**2) + _check(a*b*(a*b)**-2*a*b, 1) + _check(a**2*b*a*b*a*b*(a*b)**-1, a*(a*b)**2, matrix=False) + _check(b*a*b**2*a*b**2*a*b**2, b*(a*b**2)**3) + _check(a*b*a**2*b*a**2*b*a**3, (a*b*a)**3*a**2) + _check(a**2*b*a**4*b*a**4*b*a**2, (a**2*b*a**2)**3) + _check(a**3*b*a**4*b*a**4*b*a, a**3*(b*a**4)**3*a**-3) + _check(a*b*a*b + a*b*c*x*a*b*c, (a*b)**2 + x*(a*b*c)**2) + _check(a*b*a*b*c*a*b*a*b*c, ((a*b)**2*c)**2) + _check(b**-1*a**-1*(a*b)**2, a*b) + _check(a**-1*b*c**-1, (c*b**-1*a)**-1) + expr = a**3*b*a**4*b*a**4*b*a**2*b*a**2*(b*a**2)**2*b*a**2*b*a**2 + for _ in range(10): + expr *= a*b + _check(expr, a**3*(b*a**4)**2*(b*a**2)**6*(a*b)**10) + _check((a*b*a*b)**2, (a*b*a*b)**2, deep=False) + _check(a*b*(c*d)**2, a*b*(c*d)**2) + expr = b**-1*(a**-1*b**-1 - a**-1*c*b**-1)**-1*a**-1 + assert nc_simplify(expr) == (1-c)**-1 + # commutative expressions should be returned without an error + assert nc_simplify(2*x**2) == 2*x**2 + +def test_issue_15965(): + A = Sum(z*x**y, (x, 1, a)) + anew = z*Sum(x**y, (x, 1, a)) + B = Integral(x*y, x) + bdo = x**2*y/2 + assert simplify(A + B) == anew + bdo + assert simplify(A) == anew + assert simplify(B) == bdo + assert simplify(B, doit=False) == y*Integral(x, x) + + +def test_issue_17137(): + assert simplify(cos(x)**I) == cos(x)**I + assert simplify(cos(x)**(2 + 3*I)) == cos(x)**(2 + 3*I) + + +def test_issue_21869(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + expr = And(Eq(x**2, 4), Le(x, y)) + assert expr.simplify() == expr + + expr = And(Eq(x**2, 4), Eq(x, 2)) + assert expr.simplify() == Eq(x, 2) + + expr = And(Eq(x**3, x**2), Eq(x, 1)) + assert expr.simplify() == Eq(x, 1) + + expr = And(Eq(sin(x), x**2), Eq(x, 0)) + assert expr.simplify() == Eq(x, 0) + + expr = And(Eq(x**3, x**2), Eq(x, 2)) + assert expr.simplify() == S.false + + expr = And(Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 1), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, 2*x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,2), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == S.false + + +def test_issue_7971_21740(): + z = Integral(x, (x, 1, 1)) + assert z != 0 + assert simplify(z) is S.Zero + assert simplify(S.Zero) is S.Zero + z = simplify(Float(0)) + assert z is not S.Zero and z == 0.0 + + +@slow +def test_issue_17141_slow(): + # Should not give RecursionError + assert simplify((2**acos(I+1)**2).rewrite('log')) == 2**((pi + 2*I*log(-1 + + sqrt(1 - 2*I) + I))**2/4) + + +def test_issue_17141(): + # Check that there is no RecursionError + assert simplify(x**(1 / acos(I))) == x**(2/(pi - 2*I*log(1 + sqrt(2)))) + assert simplify(acos(-I)**2*acos(I)**2) == \ + log(1 + sqrt(2))**4 + pi**2*log(1 + sqrt(2))**2/2 + pi**4/16 + assert simplify(2**acos(I)**2) == 2**((pi - 2*I*log(1 + sqrt(2)))**2/4) + p = 2**acos(I+1)**2 + assert simplify(p) == p + + +def test_simplify_kroneckerdelta(): + i, j = symbols("i j") + K = KroneckerDelta + + assert simplify(K(i, j)) == K(i, j) + assert simplify(K(0, j)) == K(0, j) + assert simplify(K(i, 0)) == K(i, 0) + + assert simplify(K(0, j).rewrite(Piecewise) * K(1, j)) == 0 + assert simplify(K(1, i) + Piecewise((1, Eq(j, 2)), (0, True))) == K(1, i) + K(2, j) + + # issue 17214 + assert simplify(K(0, j) * K(1, j)) == 0 + + n = Symbol('n', integer=True) + assert simplify(K(0, n) * K(1, n)) == 0 + + M = Matrix(4, 4, lambda i, j: K(j - i, n) if i <= j else 0) + assert simplify(M**2) == Matrix([[K(0, n), 0, K(1, n), 0], + [0, K(0, n), 0, K(1, n)], + [0, 0, K(0, n), 0], + [0, 0, 0, K(0, n)]]) + assert simplify(eye(1) * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) == Matrix([[0]]) + + assert simplify(S.Infinity * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) is S.NaN + + +def test_issue_17292(): + assert simplify(abs(x)/abs(x**2)) == 1/abs(x) + # this is bigger than the issue: check that deep processing works + assert simplify(5*abs((x**2 - 1)/(x - 1))) == 5*Abs(x + 1) + + +def test_issue_19822(): + expr = And(Gt(n-2, 1), Gt(n, 1)) + assert simplify(expr) == Gt(n, 3) + + +def test_issue_18645(): + expr = And(Ge(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + expr = And(Eq(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + + +@XFAIL +def test_issue_18642(): + i = Symbol("i", integer=True) + n = Symbol("n", integer=True) + expr = And(Eq(i, 2 * n), Le(i, 2*n -1)) + assert simplify(expr) == S.false + + +@XFAIL +def test_issue_18389(): + n = Symbol("n", integer=True) + expr = Eq(n, 0) | (n >= 1) + assert simplify(expr) == Ge(n, 0) + + +def test_issue_8373(): + x = Symbol('x', real=True) + assert simplify(Or(x < 1, x >= 1)) == S.true + + +def test_issue_7950(): + expr = And(Eq(x, 1), Eq(x, 2)) + assert simplify(expr) == S.false + + +def test_issue_22020(): + expr = I*pi/2 -oo + assert simplify(expr) == expr + # Used to throw an error + + +def test_issue_19484(): + assert simplify(sign(x) * Abs(x)) == x + + e = x + sign(x + x**3) + assert simplify(Abs(x + x**3)*e) == x**3 + x*Abs(x**3 + x) + x + + e = x**2 + sign(x**3 + 1) + assert simplify(Abs(x**3 + 1) * e) == x**3 + x**2*Abs(x**3 + 1) + 1 + + f = Function('f') + e = x + sign(x + f(x)**3) + assert simplify(Abs(x + f(x)**3) * e) == x*Abs(x + f(x)**3) + x + f(x)**3 + + +def test_issue_23543(): + # Used to give an error + x, y, z = symbols("x y z", commutative=False) + assert (x*(y + z/2)).simplify() == x*(2*y + z)/2 + + +def test_issue_11004(): + + def f(n): + return sqrt(2*pi*n) * (n/E)**n + + def m(n, k): + return f(n) / (f(n/k)**k) + + def p(n,k): + return m(n, k) / (k**n) + + N, k = symbols('N k') + half = Float('0.5', 4) + z = log(p(n, k) / p(n, k + 1)).expand(force=True) + r = simplify(z.subs(n, N).n(4)) + assert r == ( + half*k*log(k) + - half*k*log(k + 1) + + half*log(N) + - half*log(k + 1) + + Float(0.9189224, 4) + ) + + +def test_issue_19161(): + polynomial = Poly('x**2').simplify() + assert (polynomial-x**2).simplify() == 0 + + +def test_issue_22210(): + d = Symbol('d', integer=True) + expr = 2*Derivative(sin(x), (x, d)) + assert expr.simplify() == expr + + +def test_reduce_inverses_nc_pow(): + x, y = symbols("x y", commutative=True) + Z = symbols("Z", commutative=False) + assert simplify(2**Z * y**Z) == 2**Z * y**Z + assert simplify(x**Z * y**Z) == x**Z * y**Z + x, y = symbols("x y", positive=True) + assert expand((x*y)**Z) == x**Z * y**Z + assert simplify(x**Z * y**Z) == expand((x*y)**Z) + +def test_nc_recursion_coeff(): + X = symbols("X", commutative = False) + assert (2 * cos(pi/3) * X).simplify() == X + assert (2.0 * cos(pi/3) * X).simplify() == X diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_sqrtdenest.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..41c771bb2055a1199d349ae3649f33927d79313a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_sqrtdenest.py @@ -0,0 +1,204 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.sqrtdenest import ( + _subsets as subsets, _sqrt_numeric_denest) + +r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10, + 15, 29)] + + +def test_sqrtdenest(): + d = {sqrt(5 + 2 * r6): r2 + r3, + sqrt(5. + 2 * r6): sqrt(5. + 2 * r6), + sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3), + sqrt(r2): sqrt(r2), + sqrt(5 + r7): sqrt(5 + r7), + sqrt(3 + sqrt(5 + 2*r7)): + 3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) + + r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)), + sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3} + for i in d: + assert sqrtdenest(i) == d[i], i + + +def test_sqrtdenest2(): + assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \ + r5 + sqrt(11 - 2*r29) + e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + assert sqrtdenest(e) == root(-2*r29 + 11, 4) + r = sqrt(1 + r7) + assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r) + e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand()) + assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3)) + + assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \ + sqrt(2)*root(3, 4) + root(3, 4)**3 + + assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \ + 1 + r5 + sqrt(1 + r3) + + assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \ + 1 + sqrt(1 + r3) + r5 + r7 + + e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand()) + assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3) + + e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14) + assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14) + + # check that the result is not more complicated than the input + z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16) + assert sqrtdenest(z) == z + + assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15)) + + z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(z) == z + + +def test_sqrtdenest_rec(): + assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \ + -r2 + r3 + 2*r7 + assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \ + -7 + r5 + 2*r7 + assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \ + sqrt(11)*(r2 + 3 + sqrt(11))/11 + assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \ + 9*r3 + 26 + 56*r6 + z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107) + assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23)) + z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5 + assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \ + sqrt(-1)*(-r10 + 1 + r2 + r5) + assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \ + -r10/3 + r2 + r5 + 3 + assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \ + sqrt(1 + r2 + r3 + r7) + assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15 + + w = 1 + r2 + r3 + r5 + r7 + assert sqrtdenest(sqrt((w**2).expand())) == w + z = sqrt((w**2).expand() + 1) + assert sqrtdenest(z) == z + + z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3) + assert sqrtdenest(z) == z + + +def test_issue_6241(): + z = sqrt( -320 + 32*sqrt(5) + 64*r15) + assert sqrtdenest(z) == z + + +def test_sqrtdenest3(): + z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11)) + assert sqrtdenest(z) == -1 + r2 + r10 + assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10) + z = sqrt(sqrt(r2 + 2) + 2) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \ + sqrt(-2*r10 - 4*r2 + 8*r5 + 20) + assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \ + r10 + 5 + 4*r2 + 3*r5 + z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + r = sqrt(-2*r29 + 11) + assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5) + + n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2) + d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5), + evaluate=False)) + + +def test_sqrtdenest4(): + # see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192 + z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5)) + z1 = sqrtdenest(z) + c = sqrt(-r5 + 5) + z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand() + assert sqrtdenest(z) == z1 + + z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8) + assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2 + + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + z = sqrt((w**2).expand()) + assert sqrtdenest(z) == w.expand() + + +def test_sqrt_symbolic_denest(): + x = Symbol('x') + z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand()) + assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2) + z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand()) + assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3) + z = ((1 + cos(2))**4 + 1).expand() + assert sqrtdenest(z) == z + z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand()) + assert sqrtdenest(z) == z + c = cos(3) + c2 = c**2 + assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \ + -1 - sqrt(1 + r3)*c + ra = sqrt(1 + r3) + z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112) + assert sqrtdenest(z) == z + + +def test_issue_5857(): + from sympy.abc import x, y + z = sqrt(1/(4*r3 + 7) + 1) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + + +def test_subsets(): + assert subsets(1) == [[1]] + assert subsets(4) == [ + [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0], + [0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], + [1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]] + + +def test_issue_5653(): + assert sqrtdenest( + sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2))) + +def test_issue_12420(): + assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I + e = 3 - sqrt(2)*sqrt(4 + I) + 3*I + assert sqrtdenest(e) == e + +def test_sqrt_ratcomb(): + assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0 + +def test_issue_18041(): + e = -sqrt(-2 + 2*sqrt(3)*I) + assert sqrtdenest(e) == -1 - sqrt(3)*I + +def test_issue_19914(): + a = Integer(-8) + b = Integer(-1) + r = Integer(63) + d2 = a*a - b*b*r + + assert _sqrt_numeric_denest(a, b, r, d2) == \ + sqrt(14)*I/2 + 3*sqrt(2)*I/2 + assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2 diff --git a/phivenv/Lib/site-packages/sympy/simplify/tests/test_trigsimp.py b/phivenv/Lib/site-packages/sympy/simplify/tests/test_trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..ea091ec8a6c7d654405968e3d035c2bbe02ccdf7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/tests/test_trigsimp.py @@ -0,0 +1,520 @@ +from itertools import product +from sympy.core.function import (Subs, count_ops, diff, expand) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan) +from sympy.functions.elementary.trigonometric import (acos, asin, atan2) +from sympy.functions.elementary.trigonometric import (asec, acsc) +from sympy.functions.elementary.trigonometric import (acot, atan) +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import (exptrigsimp, trigsimp) + +from sympy.testing.pytest import XFAIL + +from sympy.abc import x, y + + + +def test_trigsimp1(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2) == cos(x)**2 + assert trigsimp(1 - cos(x)**2) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2) == 1 + assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1 + assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5 + assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2) + + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x)) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y) + assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \ + sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1) + + assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y) + assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \ + sinh(y)/(sinh(y)*tanh(x) + cosh(y)) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0 + e = 2*sin(x)**2 + 2*cos(x)**2 + assert trigsimp(log(e)) == log(2) + + +def test_trigsimp1a(): + assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2) + assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2) + assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2) + assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2) + assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2) + + +def test_trigsimp2(): + x, y = symbols('x,y') + assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2, + recursive=True) == 1 + assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2, + recursive=True) == 1 + assert trigsimp( + Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1) + + +def test_issue_4373(): + x = Symbol("x") + assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10 + + +def test_trigsimp3(): + x, y = symbols('x,y') + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2 + assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3 + assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10 + + assert trigsimp(cos(x)/sin(x)) == 1/tan(x) + assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2 + assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10 + + assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x)) + + +def test_issue_4661(): + a, x, y = symbols('a x y') + eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2 + assert trigsimp(eq) == -4 + n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6 + d = -sin(x)**2 - 2*cos(x)**2 + assert simplify(n/d) == -1 + assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1 + eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8 + assert trigsimp(eq) == 0 + + +def test_issue_4494(): + a, b = symbols('a b') + eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2 + assert trigsimp(eq) == 1 + + +def test_issue_5948(): + a, x, y = symbols('a x y') + assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \ + cos(x)/sin(x)**7 + + +def test_issue_4775(): + a, x, y = symbols('a x y') + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y) + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3 + + +def test_issue_4280(): + a, x, y = symbols('a x y') + assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1 + assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2 + assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2 + + +def test_issue_3210(): + eqs = (sin(2)*cos(3) + sin(3)*cos(2), + -sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*cos(3) - sin(3)*cos(2), + sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*sin(3) + cos(2)*cos(3) + cos(2), + sinh(2)*cosh(3) + sinh(3)*cosh(2), + sinh(2)*sinh(3) + cosh(2)*cosh(3), + ) + assert [trigsimp(e) for e in eqs] == [ + sin(5), + cos(5), + -sin(1), + cos(1), + cos(1) + cos(2), + sinh(5), + cosh(5), + ] + + +def test_trigsimp_issues(): + a, x, y = symbols('a x y') + + # issue 4625 - factor_terms works, too + assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x) + + # issue 5948 + assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \ + cos(x)/sin(x)**3 + assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \ + sin(x)/cos(x)**3 + + # check integer exponents + e = sin(x)**y/cos(x)**y + assert trigsimp(e) == e + assert trigsimp(e.subs(y, 2)) == tan(x)**2 + assert trigsimp(e.subs(x, 1)) == tan(1)**y + + # check for multiple patterns + assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \ + 1/tan(x)**2/tan(y)**2 + assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \ + 1/(tan(x)*tan(x + y)) + + eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2 + assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2 + assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \ + cos(2)*sin(3)**4 + + # issue 6789; this generates an expression that formerly caused + # trigsimp to hang + assert cot(x).equals(tan(x)) is False + + # nan or the unchanged expression is ok, but not sin(1) + z = cos(x)**2 + sin(x)**2 - 1 + z1 = tan(x)**2 - 1/cot(x)**2 + n = (1 + z1/z) + assert trigsimp(sin(n)) != sin(1) + eq = x*(n - 1) - x*n + assert trigsimp(eq) is S.NaN + assert trigsimp(eq, recursive=True) is S.NaN + assert trigsimp(1).is_Integer + + assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1 + + +def test_trigsimp_issue_2515(): + x = Symbol('x') + assert trigsimp(x*cos(x)*tan(x)) == x*sin(x) + assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0 + + +def test_trigsimp_issue_3826(): + assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x) + + +def test_trigsimp_issue_4032(): + n = Symbol('n', integer=True, positive=True) + assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \ + 2**(n/2)*cos(pi*n/4)/2 + 2**n/4 + + +def test_trigsimp_issue_7761(): + assert trigsimp(cosh(pi/4)) == cosh(pi/4) + + +def test_trigsimp_noncommutative(): + x, y = symbols('x,y') + A, B = symbols('A,B', commutative=False) + + assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2 + assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2 + assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A + assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2 + assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2 + assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A + assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2 + assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2 + assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A + + assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A + + assert trigsimp(A*sin(x)/cos(x)) == A*tan(x) + assert trigsimp(A*tan(x)*cos(x)) == A*sin(x) + assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3 + assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2 + assert trigsimp(A*cot(x)/cos(x)) == A/sin(x) + + assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y) + assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x) + assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y) + assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y) + + assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y) + assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x) + assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y) + assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y) + + assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A + + +def test_hyperbolic_simp(): + x, y = symbols('x,y') + + assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2 + assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2 + assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1 + assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2 + assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2 + assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1 + assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2 + assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2 + assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1 + + assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5 + assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2) + + assert trigsimp(sinh(x)/cosh(x)) == tanh(x) + assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x)) + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x) + assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3 + assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2 + assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x) + + for a in (pi/6*I, pi/4*I, pi/3*I): + assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a) + assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a) + + e = 2*cosh(x)**2 - 2*sinh(x)**2 + assert trigsimp(log(e)) == log(2) + + # issue 19535: + assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2) + + assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2, + recursive=True) == 1 + assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2, + recursive=True) == 1 + + assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10 + + assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2 + assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3 + assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10 + assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3 + + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2 + assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10 + + assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x) + assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0 + + assert tan(x) != 1/cot(x) # cot doesn't auto-simplify + + assert trigsimp(tan(x) - 1/cot(x)) == 0 + assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7 + + +def test_trigsimp_groebner(): + from sympy.simplify.trigsimp import trigsimp_groebner + + c = cos(x) + s = sin(x) + ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/( + -s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21) + resnum = (5*s - 5*c + 1) + resdenom = (8*s - 6*c) + results = [resnum/resdenom, (-resnum)/(-resdenom)] + assert trigsimp_groebner(ex) in results + assert trigsimp_groebner(s/c, hints=[tan]) == tan(x) + assert trigsimp_groebner(c*s) == c*s + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner') == 2/c + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner', polynomial=True) == 2/c + + # Test quick=False works + assert trigsimp_groebner(ex, hints=[2]) in results + assert trigsimp_groebner(ex, hints=[int(2)]) in results + + # test "I" + assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x) + + # test hyperbolic / sums + assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)), + hints=[(tanh, x, y)]) == tanh(x + y) + + +def test_issue_2827_trigsimp_methods(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + ans = Matrix([1]) + M = Matrix([expr]) + assert trigsimp(M, method='fu', measure=measure1) == ans + assert trigsimp(M, method='fu', measure=measure2) != ans + # all methods should work with Basic expressions even if they + # aren't Expr + M = Matrix.eye(1) + assert all(trigsimp(M, method=m) == M for m in + 'fu matching groebner old'.split()) + # watch for E in exptrigsimp, not only exp() + eq = 1/sqrt(E) + E + assert exptrigsimp(eq) == eq + +def test_issue_15129_trigsimp_methods(): + t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0]) + t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0]) + t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0]) + r1 = t1.dot(t2) + r2 = t1.dot(t3) + assert trigsimp(r1) == cos(Rational(1, 50)) + assert trigsimp(r2) == sin(Rational(3, 50)) + +def test_exptrigsimp(): + def valid(a, b): + from sympy.core.random import verify_numerically as tn + if not (tn(a, b) and a == b): + return False + return True + + assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x) + assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x) + assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x) + assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x) + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + assert all(valid(i, j) for i, j in zip( + [exptrigsimp(ei) for ei in e], ok)) + + ue = [cos(x) + sin(x), cos(x) - sin(x), + cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)] + assert [exptrigsimp(ei) == ei for ei in ue] + + res = [] + ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)), + y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)), + y*tanh(1 + I), 1/(y*tanh(1 + I))] + for a in (1, I, x, I*x, 1 + I): + w = exp(a) + eq = y*(w - 1/w)/(w + 1/w) + res.append(simplify(eq)) + res.append(simplify(1/eq)) + assert all(valid(i, j) for i, j in zip(res, ok)) + + for a in range(1, 3): + w = exp(a) + e = w + 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*cosh(a)) + e = w - 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*sinh(a)) + +def test_exptrigsimp_noncommutative(): + a,b = symbols('a b', commutative=False) + x = Symbol('x', commutative=True) + assert exp(a + x) == exptrigsimp(exp(a)*exp(x)) + p = exp(a)*exp(b) - exp(b)*exp(a) + assert p == exptrigsimp(p) != 0 + +def test_powsimp_on_numbers(): + assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4 + + +@XFAIL +def test_issue_6811_fail(): + # from doc/src/modules/physics/mechanics/examples.rst, the current `eq` + # at Line 576 (in different variables) was formerly the equivalent and + # shorter expression given below...it would be nice to get the short one + # back again + xp, y, x, z = symbols('xp, y, x, z') + eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x)) + assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x) + + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + # s1 = simplify(e1) + s2 = simplify(e2) + # s3 = simplify(e3) + + # trigsimp tries not to touch non-trig containing args + assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \ + Piecewise((e1, e3 < s2), (e3, True)) + + +def test_issue_21594(): + assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2 + + +def test_trigsimp_old(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2 + assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1 + assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1 + assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5 + + assert trigsimp(sin(x)/cos(x), old=True) == tan(x) + assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y) + + assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0 + + assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x) + + assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2 + + +def test_trigsimp_inverse(): + alpha = symbols('alpha') + s, c = sin(alpha), cos(alpha) + + for finv in [asin, acos, asec, acsc, atan, acot]: + f = finv.inverse(None) + assert alpha == trigsimp(finv(f(alpha)), inverse=True) + + # test atan2(cos, sin), atan2(sin, cos), etc... + for a, b in [[c, s], [s, c]]: + for i, j in product([-1, 1], repeat=2): + angle = atan2(i*b, j*a) + angle_inverted = trigsimp(angle, inverse=True) + assert angle_inverted != angle # assures simplification happened + assert sin(angle_inverted) == trigsimp(sin(angle)) + assert cos(angle_inverted) == trigsimp(cos(angle)) diff --git a/phivenv/Lib/site-packages/sympy/simplify/traversaltools.py b/phivenv/Lib/site-packages/sympy/simplify/traversaltools.py new file mode 100644 index 0000000000000000000000000000000000000000..75b0bd0d8fd198cb12640ab8a0fe63a23c81ed8f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/traversaltools.py @@ -0,0 +1,15 @@ +from sympy.core.traversal import use as _use +from sympy.utilities.decorator import deprecated + +use = deprecated( + """ + Using use from the sympy.simplify.traversaltools submodule is + deprecated. + + Instead, use use from the top-level sympy namespace, like + + sympy.use + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved" +)(_use) diff --git a/phivenv/Lib/site-packages/sympy/simplify/trigsimp.py b/phivenv/Lib/site-packages/sympy/simplify/trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5be1444a4625e4b63b339877e441d12cfbe8de --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/simplify/trigsimp.py @@ -0,0 +1,1252 @@ +from collections import defaultdict +from functools import reduce + +from sympy.core import (sympify, Basic, S, Expr, factor_terms, + Mul, Add, bottom_up) +from sympy.core.cache import cacheit +from sympy.core.function import (count_ops, _mexpand, FunctionClass, expand, + expand_mul, _coeff_isneg, Derivative) +from sympy.core.numbers import I, Integer +from sympy.core.intfunc import igcd +from sympy.core.sorting import _nodes +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.external.gmpy import SYMPY_INTS +from sympy.functions import sin, cos, exp, cosh, tanh, sinh, tan, cot, coth +from sympy.functions import atan2 +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.polys import Poly, factor, cancel, parallel_poly_from_expr +from sympy.polys.domains import ZZ +from sympy.polys.polyerrors import PolificationFailed +from sympy.polys.polytools import groebner +from sympy.simplify.cse_main import cse +from sympy.strategies.core import identity +from sympy.strategies.tree import greedy +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import debug + +def trigsimp_groebner(expr, hints=[], quick=False, order="grlex", + polynomial=False): + """ + Simplify trigonometric expressions using a groebner basis algorithm. + + Explanation + =========== + + This routine takes a fraction involving trigonometric or hyperbolic + expressions, and tries to simplify it. The primary metric is the + total degree. Some attempts are made to choose the simplest possible + expression of the minimal degree, but this is non-rigorous, and also + very slow (see the ``quick=True`` option). + + If ``polynomial`` is set to True, instead of simplifying numerator and + denominator together, this function just brings numerator and denominator + into a canonical form. This is much faster, but has potentially worse + results. However, if the input is a polynomial, then the result is + guaranteed to be an equivalent polynomial of minimal degree. + + The most important option is hints. Its entries can be any of the + following: + + - a natural number + - a function + - an iterable of the form (func, var1, var2, ...) + - anything else, interpreted as a generator + + A number is used to indicate that the search space should be increased. + A function is used to indicate that said function is likely to occur in a + simplified expression. + An iterable is used indicate that func(var1 + var2 + ...) is likely to + occur in a simplified . + An additional generator also indicates that it is likely to occur. + (See examples below). + + This routine carries out various computationally intensive algorithms. + The option ``quick=True`` can be used to suppress one particularly slow + step (at the expense of potentially more complicated results, but never at + the expense of increased total degree). + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import sin, tan, cos, sinh, cosh, tanh + >>> from sympy.simplify.trigsimp import trigsimp_groebner + + Suppose you want to simplify ``sin(x)*cos(x)``. Naively, nothing happens: + + >>> ex = sin(x)*cos(x) + >>> trigsimp_groebner(ex) + sin(x)*cos(x) + + This is because ``trigsimp_groebner`` only looks for a simplification + involving just ``sin(x)`` and ``cos(x)``. You can tell it to also try + ``2*x`` by passing ``hints=[2]``: + + >>> trigsimp_groebner(ex, hints=[2]) + sin(2*x)/2 + >>> trigsimp_groebner(sin(x)**2 - cos(x)**2, hints=[2]) + -cos(2*x) + + Increasing the search space this way can quickly become expensive. A much + faster way is to give a specific expression that is likely to occur: + + >>> trigsimp_groebner(ex, hints=[sin(2*x)]) + sin(2*x)/2 + + Hyperbolic expressions are similarly supported: + + >>> trigsimp_groebner(sinh(2*x)/sinh(x)) + 2*cosh(x) + + Note how no hints had to be passed, since the expression already involved + ``2*x``. + + The tangent function is also supported. You can either pass ``tan`` in the + hints, to indicate that tan should be tried whenever cosine or sine are, + or you can pass a specific generator: + + >>> trigsimp_groebner(sin(x)/cos(x), hints=[tan]) + tan(x) + >>> trigsimp_groebner(sinh(x)/cosh(x), hints=[tanh(x)]) + tanh(x) + + Finally, you can use the iterable form to suggest that angle sum formulae + should be tried: + + >>> ex = (tan(x) + tan(y))/(1 - tan(x)*tan(y)) + >>> trigsimp_groebner(ex, hints=[(tan, x, y)]) + tan(x + y) + """ + # TODO + # - preprocess by replacing everything by funcs we can handle + # - optionally use cot instead of tan + # - more intelligent hinting. + # For example, if the ideal is small, and we have sin(x), sin(y), + # add sin(x + y) automatically... ? + # - algebraic numbers ... + # - expressions of lowest degree are not distinguished properly + # e.g. 1 - sin(x)**2 + # - we could try to order the generators intelligently, so as to influence + # which monomials appear in the quotient basis + + # THEORY + # ------ + # Ratsimpmodprime above can be used to "simplify" a rational function + # modulo a prime ideal. "Simplify" mainly means finding an equivalent + # expression of lower total degree. + # + # We intend to use this to simplify trigonometric functions. To do that, + # we need to decide (a) which ring to use, and (b) modulo which ideal to + # simplify. In practice, (a) means settling on a list of "generators" + # a, b, c, ..., such that the fraction we want to simplify is a rational + # function in a, b, c, ..., with coefficients in ZZ (integers). + # (2) means that we have to decide what relations to impose on the + # generators. There are two practical problems: + # (1) The ideal has to be *prime* (a technical term). + # (2) The relations have to be polynomials in the generators. + # + # We typically have two kinds of generators: + # - trigonometric expressions, like sin(x), cos(5*x), etc + # - "everything else", like gamma(x), pi, etc. + # + # Since this function is trigsimp, we will concentrate on what to do with + # trigonometric expressions. We can also simplify hyperbolic expressions, + # but the extensions should be clear. + # + # One crucial point is that all *other* generators really should behave + # like indeterminates. In particular if (say) "I" is one of them, then + # in fact I**2 + 1 = 0 and we may and will compute non-sensical + # expressions. However, we can work with a dummy and add the relation + # I**2 + 1 = 0 to our ideal, then substitute back in the end. + # + # Now regarding trigonometric generators. We split them into groups, + # according to the argument of the trigonometric functions. We want to + # organise this in such a way that most trigonometric identities apply in + # the same group. For example, given sin(x), cos(2*x) and cos(y), we would + # group as [sin(x), cos(2*x)] and [cos(y)]. + # + # Our prime ideal will be built in three steps: + # (1) For each group, compute a "geometrically prime" ideal of relations. + # Geometrically prime means that it generates a prime ideal in + # CC[gens], not just ZZ[gens]. + # (2) Take the union of all the generators of the ideals for all groups. + # By the geometric primality condition, this is still prime. + # (3) Add further inter-group relations which preserve primality. + # + # Step (1) works as follows. We will isolate common factors in the + # argument, so that all our generators are of the form sin(n*x), cos(n*x) + # or tan(n*x), with n an integer. Suppose first there are no tan terms. + # The ideal [sin(x)**2 + cos(x)**2 - 1] is geometrically prime, since + # X**2 + Y**2 - 1 is irreducible over CC. + # Now, if we have a generator sin(n*x), than we can, using trig identities, + # express sin(n*x) as a polynomial in sin(x) and cos(x). We can add this + # relation to the ideal, preserving geometric primality, since the quotient + # ring is unchanged. + # Thus we have treated all sin and cos terms. + # For tan(n*x), we add a relation tan(n*x)*cos(n*x) - sin(n*x) = 0. + # (This requires of course that we already have relations for cos(n*x) and + # sin(n*x).) It is not obvious, but it seems that this preserves geometric + # primality. + # XXX A real proof would be nice. HELP! + # Sketch that is a prime ideal of + # CC[S, C, T]: + # - it suffices to show that the projective closure in CP**3 is + # irreducible + # - using the half-angle substitutions, we can express sin(x), tan(x), + # cos(x) as rational functions in tan(x/2) + # - from this, we get a rational map from CP**1 to our curve + # - this is a morphism, hence the curve is prime + # + # Step (2) is trivial. + # + # Step (3) works by adding selected relations of the form + # sin(x + y) - sin(x)*cos(y) - sin(y)*cos(x), etc. Geometric primality is + # preserved by the same argument as before. + + def parse_hints(hints): + """Split hints into (n, funcs, iterables, gens).""" + n = 1 + funcs, iterables, gens = [], [], [] + for e in hints: + if isinstance(e, (SYMPY_INTS, Integer)): + n = e + elif isinstance(e, FunctionClass): + funcs.append(e) + elif iterable(e): + iterables.append((e[0], e[1:])) + # XXX sin(x+2y)? + # Note: we go through polys so e.g. + # sin(-x) -> -sin(x) -> sin(x) + gens.extend(parallel_poly_from_expr( + [e[0](x) for x in e[1:]] + [e[0](Add(*e[1:]))])[1].gens) + else: + gens.append(e) + return n, funcs, iterables, gens + + def build_ideal(x, terms): + """ + Build generators for our ideal. ``Terms`` is an iterable with elements of + the form (fn, coeff), indicating that we have a generator fn(coeff*x). + + If any of the terms is trigonometric, sin(x) and cos(x) are guaranteed + to appear in terms. Similarly for hyperbolic functions. For tan(n*x), + sin(n*x) and cos(n*x) are guaranteed. + """ + I = [] + y = Dummy('y') + for fn, coeff in terms: + for c, s, t, rel in ( + [cos, sin, tan, cos(x)**2 + sin(x)**2 - 1], + [cosh, sinh, tanh, cosh(x)**2 - sinh(x)**2 - 1]): + if coeff == 1 and fn in [c, s]: + I.append(rel) + elif fn == t: + I.append(t(coeff*x)*c(coeff*x) - s(coeff*x)) + elif fn in [c, s]: + cn = fn(coeff*y).expand(trig=True).subs(y, x) + I.append(fn(coeff*x) - cn) + return list(set(I)) + + def analyse_gens(gens, hints): + """ + Analyse the generators ``gens``, using the hints ``hints``. + + The meaning of ``hints`` is described in the main docstring. + Return a new list of generators, and also the ideal we should + work with. + """ + # First parse the hints + n, funcs, iterables, extragens = parse_hints(hints) + debug('n=%s funcs: %s iterables: %s extragens: %s', + (funcs, iterables, extragens)) + + # We just add the extragens to gens and analyse them as before + gens = list(gens) + gens.extend(extragens) + + # remove duplicates + funcs = list(set(funcs)) + iterables = list(set(iterables)) + gens = list(set(gens)) + + # all the functions we can do anything with + allfuncs = {sin, cos, tan, sinh, cosh, tanh} + # sin(3*x) -> ((3, x), sin) + trigterms = [(g.args[0].as_coeff_mul(), g.func) for g in gens + if g.func in allfuncs] + # Our list of new generators - start with anything that we cannot + # work with (i.e. is not a trigonometric term) + freegens = [g for g in gens if g.func not in allfuncs] + newgens = [] + trigdict = {} + for (coeff, var), fn in trigterms: + trigdict.setdefault(var, []).append((coeff, fn)) + res = [] # the ideal + + for key, val in trigdict.items(): + # We have now assembeled a dictionary. Its keys are common + # arguments in trigonometric expressions, and values are lists of + # pairs (fn, coeff). x0, (fn, coeff) in trigdict means that we + # need to deal with fn(coeff*x0). We take the rational gcd of the + # coeffs, call it ``gcd``. We then use x = x0/gcd as "base symbol", + # all other arguments are integral multiples thereof. + # We will build an ideal which works with sin(x), cos(x). + # If hint tan is provided, also work with tan(x). Moreover, if + # n > 1, also work with sin(k*x) for k <= n, and similarly for cos + # (and tan if the hint is provided). Finally, any generators which + # the ideal does not work with but we need to accommodate (either + # because it was in expr or because it was provided as a hint) + # we also build into the ideal. + # This selection process is expressed in the list ``terms``. + # build_ideal then generates the actual relations in our ideal, + # from this list. + fns = [x[1] for x in val] + val = [x[0] for x in val] + gcd = reduce(igcd, val) + terms = [(fn, v/gcd) for (fn, v) in zip(fns, val)] + fs = set(funcs + fns) + for c, s, t in ([cos, sin, tan], [cosh, sinh, tanh]): + if any(x in fs for x in (c, s, t)): + fs.add(c) + fs.add(s) + for fn in fs: + terms.extend((fn, k) for k in range(1, n + 1)) + extra = [] + for fn, v in terms: + if fn == tan: + extra.append((sin, v)) + extra.append((cos, v)) + if fn in [sin, cos] and tan in fs: + extra.append((tan, v)) + if fn == tanh: + extra.append((sinh, v)) + extra.append((cosh, v)) + if fn in [sinh, cosh] and tanh in fs: + extra.append((tanh, v)) + terms.extend(extra) + x = gcd*Mul(*key) + r = build_ideal(x, terms) + res.extend(r) + newgens.extend({fn(v*x) for fn, v in terms}) + + # Add generators for compound expressions from iterables + for fn, args in iterables: + if fn == tan: + # Tan expressions are recovered from sin and cos. + iterables.extend([(sin, args), (cos, args)]) + elif fn == tanh: + # Tanh expressions are recovered from sihn and cosh. + iterables.extend([(sinh, args), (cosh, args)]) + else: + dummys = symbols('d:%i' % len(args), cls=Dummy) + expr = fn( Add(*dummys)).expand(trig=True).subs(list(zip(dummys, args))) + res.append(fn(Add(*args)) - expr) + + if myI in gens: + res.append(myI**2 + 1) + freegens.remove(myI) + newgens.append(myI) + + return res, freegens, newgens + + myI = Dummy('I') + expr = expr.subs(S.ImaginaryUnit, myI) + subs = [(myI, S.ImaginaryUnit)] + + num, denom = cancel(expr).as_numer_denom() + try: + (pnum, pdenom), opt = parallel_poly_from_expr([num, denom]) + except PolificationFailed: + return expr + debug('initial gens:', opt.gens) + ideal, freegens, gens = analyse_gens(opt.gens, hints) + debug('ideal:', ideal) + debug('new gens:', gens, " -- len", len(gens)) + debug('free gens:', freegens, " -- len", len(gens)) + # NOTE we force the domain to be ZZ to stop polys from injecting generators + # (which is usually a sign of a bug in the way we build the ideal) + if not gens: + return expr + G = groebner(ideal, order=order, gens=gens, domain=ZZ) + debug('groebner basis:', list(G), " -- len", len(G)) + + # If our fraction is a polynomial in the free generators, simplify all + # coefficients separately: + + from sympy.simplify.ratsimp import ratsimpmodprime + + if freegens and pdenom.has_only_gens(*set(gens).intersection(pdenom.gens)): + num = Poly(num, gens=gens+freegens).eject(*gens) + res = [] + for monom, coeff in num.terms(): + ourgens = set(parallel_poly_from_expr([coeff, denom])[1].gens) + # We compute the transitive closure of all generators that can + # be reached from our generators through relations in the ideal. + changed = True + while changed: + changed = False + for p in ideal: + p = Poly(p) + if not ourgens.issuperset(p.gens) and \ + not p.has_only_gens(*set(p.gens).difference(ourgens)): + changed = True + ourgens.update(p.exclude().gens) + # NOTE preserve order! + realgens = [x for x in gens if x in ourgens] + # The generators of the ideal have now been (implicitly) split + # into two groups: those involving ourgens and those that don't. + # Since we took the transitive closure above, these two groups + # live in subgrings generated by a *disjoint* set of variables. + # Any sensible groebner basis algorithm will preserve this disjoint + # structure (i.e. the elements of the groebner basis can be split + # similarly), and and the two subsets of the groebner basis then + # form groebner bases by themselves. (For the smaller generating + # sets, of course.) + ourG = [g.as_expr() for g in G.polys if + g.has_only_gens(*ourgens.intersection(g.gens))] + res.append(Mul(*[a**b for a, b in zip(freegens, monom)]) * \ + ratsimpmodprime(coeff/denom, ourG, order=order, + gens=realgens, quick=quick, domain=ZZ, + polynomial=polynomial).subs(subs)) + return Add(*res) + # NOTE The following is simpler and has less assumptions on the + # groebner basis algorithm. If the above turns out to be broken, + # use this. + return Add(*[Mul(*[a**b for a, b in zip(freegens, monom)]) * \ + ratsimpmodprime(coeff/denom, list(G), order=order, + gens=gens, quick=quick, domain=ZZ) + for monom, coeff in num.terms()]) + else: + return ratsimpmodprime( + expr, list(G), order=order, gens=freegens+gens, + quick=quick, domain=ZZ, polynomial=polynomial).subs(subs) + + +_trigs = (TrigonometricFunction, HyperbolicFunction) + + +def _trigsimp_inverse(rv): + + def check_args(x, y): + try: + return x.args[0] == y.args[0] + except IndexError: + return False + + def f(rv): + # for simple functions + g = getattr(rv, 'inverse', None) + if (g is not None and isinstance(rv.args[0], g()) and + isinstance(g()(1), TrigonometricFunction)): + return rv.args[0].args[0] + + # for atan2 simplifications, harder because atan2 has 2 args + if isinstance(rv, atan2): + y, x = rv.args + if _coeff_isneg(y): + return -f(atan2(-y, x)) + elif _coeff_isneg(x): + return S.Pi - f(atan2(y, -x)) + + if check_args(x, y): + if isinstance(y, sin) and isinstance(x, cos): + return x.args[0] + if isinstance(y, cos) and isinstance(x, sin): + return S.Pi / 2 - x.args[0] + + return rv + + return bottom_up(rv, f) + + +def trigsimp(expr, inverse=False, **opts): + """Returns a reduced expression by using known trig identities. + + Parameters + ========== + + inverse : bool, optional + If ``inverse=True``, it will be assumed that a composition of inverse + functions, such as sin and asin, can be cancelled in any order. + For example, ``asin(sin(x))`` will yield ``x`` without checking whether + x belongs to the set where this relation is true. The default is False. + Default : True + + method : string, optional + Specifies the method to use. Valid choices are: + + - ``'matching'``, default + - ``'groebner'`` + - ``'combined'`` + - ``'fu'`` + - ``'old'`` + + If ``'matching'``, simplify the expression recursively by targeting + common patterns. If ``'groebner'``, apply an experimental groebner + basis algorithm. In this case further options are forwarded to + ``trigsimp_groebner``, please refer to + its docstring. If ``'combined'``, it first runs the groebner basis + algorithm with small default parameters, then runs the ``'matching'`` + algorithm. If ``'fu'``, run the collection of trigonometric + transformations described by Fu, et al. (see the + :py:func:`~sympy.simplify.fu.fu` docstring). If ``'old'``, the original + SymPy trig simplification function is run. + opts : + Optional keyword arguments passed to the method. See each method's + function docstring for details. + + Examples + ======== + + >>> from sympy import trigsimp, sin, cos, log + >>> from sympy.abc import x + >>> e = 2*sin(x)**2 + 2*cos(x)**2 + >>> trigsimp(e) + 2 + + Simplification occurs wherever trigonometric functions are located. + + >>> trigsimp(log(e)) + log(2) + + Using ``method='groebner'`` (or ``method='combined'``) might lead to + greater simplification. + + The old trigsimp routine can be accessed as with method ``method='old'``. + + >>> from sympy import coth, tanh + >>> t = 3*tanh(x)**7 - 2/coth(x)**7 + >>> trigsimp(t, method='old') == t + True + >>> trigsimp(t) + tanh(x)**7 + + """ + from sympy.simplify.fu import fu + + expr = sympify(expr) + + _eval_trigsimp = getattr(expr, '_eval_trigsimp', None) + if _eval_trigsimp is not None: + return _eval_trigsimp(**opts) + + old = opts.pop('old', False) + if not old: + opts.pop('deep', None) + opts.pop('recursive', None) + method = opts.pop('method', 'matching') + else: + method = 'old' + + def groebnersimp(ex, **opts): + def traverse(e): + if e.is_Atom: + return e + args = [traverse(x) for x in e.args] + if e.is_Function or e.is_Pow: + args = [trigsimp_groebner(x, **opts) for x in args] + return e.func(*args) + new = traverse(ex) + if not isinstance(new, Expr): + return new + return trigsimp_groebner(new, **opts) + + trigsimpfunc = { + 'fu': (lambda x: fu(x, **opts)), + 'matching': (lambda x: futrig(x)), + 'groebner': (lambda x: groebnersimp(x, **opts)), + 'combined': (lambda x: futrig(groebnersimp(x, + polynomial=True, hints=[2, tan]))), + 'old': lambda x: trigsimp_old(x, **opts), + }[method] + + expr_simplified = trigsimpfunc(expr) + if inverse: + expr_simplified = _trigsimp_inverse(expr_simplified) + + return expr_simplified + + +def exptrigsimp(expr): + """ + Simplifies exponential / trigonometric / hyperbolic functions. + + Examples + ======== + + >>> from sympy import exptrigsimp, exp, cosh, sinh + >>> from sympy.abc import z + + >>> exptrigsimp(exp(z) + exp(-z)) + 2*cosh(z) + >>> exptrigsimp(cosh(z) - sinh(z)) + exp(-z) + """ + from sympy.simplify.fu import hyper_as_trig, TR2i + + def exp_trig(e): + # select the better of e, and e rewritten in terms of exp or trig + # functions + choices = [e] + if e.has(*_trigs): + choices.append(e.rewrite(exp)) + choices.append(e.rewrite(cos)) + return min(*choices, key=count_ops) + newexpr = bottom_up(expr, exp_trig) + + def f(rv): + if not rv.is_Mul: + return rv + commutative_part, noncommutative_part = rv.args_cnc() + # Since as_powers_dict loses order information, + # if there is more than one noncommutative factor, + # it should only be used to simplify the commutative part. + if (len(noncommutative_part) > 1): + return f(Mul(*commutative_part))*Mul(*noncommutative_part) + rvd = rv.as_powers_dict() + newd = rvd.copy() + + def signlog(expr, sign=S.One): + if expr is S.Exp1: + return sign, S.One + elif isinstance(expr, exp) or (expr.is_Pow and expr.base == S.Exp1): + return sign, expr.exp + elif sign is S.One: + return signlog(-expr, sign=-S.One) + else: + return None, None + + ee = rvd[S.Exp1] + for k in rvd: + if k.is_Add and len(k.args) == 2: + # k == c*(1 + sign*E**x) + c = k.args[0] + sign, x = signlog(k.args[1]/c) + if not x: + continue + m = rvd[k] + newd[k] -= m + if ee == -x*m/2: + # sinh and cosh + newd[S.Exp1] -= ee + ee = 0 + if sign == 1: + newd[2*c*cosh(x/2)] += m + else: + newd[-2*c*sinh(x/2)] += m + elif newd[1 - sign*S.Exp1**x] == -m: + # tanh + del newd[1 - sign*S.Exp1**x] + if sign == 1: + newd[-c/tanh(x/2)] += m + else: + newd[-c*tanh(x/2)] += m + else: + newd[1 + sign*S.Exp1**x] += m + newd[c] += m + + return Mul(*[k**newd[k] for k in newd]) + newexpr = bottom_up(newexpr, f) + + # sin/cos and sinh/cosh ratios to tan and tanh, respectively + if newexpr.has(HyperbolicFunction): + e, f = hyper_as_trig(newexpr) + newexpr = f(TR2i(e)) + if newexpr.has(TrigonometricFunction): + newexpr = TR2i(newexpr) + + # can we ever generate an I where there was none previously? + if not (newexpr.has(I) and not expr.has(I)): + expr = newexpr + return expr + +#-------------------- the old trigsimp routines --------------------- + +def trigsimp_old(expr, *, first=True, **opts): + """ + Reduces expression by using known trig identities. + + Notes + ===== + + deep: + - Apply trigsimp inside all objects with arguments + + recursive: + - Use common subexpression elimination (cse()) and apply + trigsimp recursively (this is quite expensive if the + expression is large) + + method: + - Determine the method to use. Valid choices are 'matching' (default), + 'groebner', 'combined', 'fu' and 'futrig'. If 'matching', simplify the + expression recursively by pattern matching. If 'groebner', apply an + experimental groebner basis algorithm. In this case further options + are forwarded to ``trigsimp_groebner``, please refer to its docstring. + If 'combined', first run the groebner basis algorithm with small + default parameters, then run the 'matching' algorithm. 'fu' runs the + collection of trigonometric transformations described by Fu, et al. + (see the `fu` docstring) while `futrig` runs a subset of Fu-transforms + that mimic the behavior of `trigsimp`. + + compare: + - show input and output from `trigsimp` and `futrig` when different, + but returns the `trigsimp` value. + + Examples + ======== + + >>> from sympy import trigsimp, sin, cos, log, cot + >>> from sympy.abc import x + >>> e = 2*sin(x)**2 + 2*cos(x)**2 + >>> trigsimp(e, old=True) + 2 + >>> trigsimp(log(e), old=True) + log(2*sin(x)**2 + 2*cos(x)**2) + >>> trigsimp(log(e), deep=True, old=True) + log(2) + + Using `method="groebner"` (or `"combined"`) can sometimes lead to a lot + more simplification: + + >>> e = (-sin(x) + 1)/cos(x) + cos(x)/(-sin(x) + 1) + >>> trigsimp(e, old=True) + (1 - sin(x))/cos(x) + cos(x)/(1 - sin(x)) + >>> trigsimp(e, method="groebner", old=True) + 2/cos(x) + + >>> trigsimp(1/cot(x)**2, compare=True, old=True) + futrig: tan(x)**2 + cot(x)**(-2) + + """ + old = expr + if first: + if not expr.has(*_trigs): + return expr + + trigsyms = set().union(*[t.free_symbols for t in expr.atoms(*_trigs)]) + if len(trigsyms) > 1: + from sympy.simplify.simplify import separatevars + + d = separatevars(expr) + if d.is_Mul: + d = separatevars(d, dict=True) or d + if isinstance(d, dict): + expr = 1 + for v in d.values(): + # remove hollow factoring + was = v + v = expand_mul(v) + opts['first'] = False + vnew = trigsimp(v, **opts) + if vnew == v: + vnew = was + expr *= vnew + old = expr + else: + if d.is_Add: + for s in trigsyms: + r, e = expr.as_independent(s) + if r: + opts['first'] = False + expr = r + trigsimp(e, **opts) + if not expr.is_Add: + break + old = expr + + recursive = opts.pop('recursive', False) + deep = opts.pop('deep', False) + method = opts.pop('method', 'matching') + + def groebnersimp(ex, deep, **opts): + def traverse(e): + if e.is_Atom: + return e + args = [traverse(x) for x in e.args] + if e.is_Function or e.is_Pow: + args = [trigsimp_groebner(x, **opts) for x in args] + return e.func(*args) + if deep: + ex = traverse(ex) + return trigsimp_groebner(ex, **opts) + + trigsimpfunc = { + 'matching': (lambda x, d: _trigsimp(x, d)), + 'groebner': (lambda x, d: groebnersimp(x, d, **opts)), + 'combined': (lambda x, d: _trigsimp(groebnersimp(x, + d, polynomial=True, hints=[2, tan]), + d)) + }[method] + + if recursive: + w, g = cse(expr) + g = trigsimpfunc(g[0], deep) + + for sub in reversed(w): + g = g.subs(sub[0], sub[1]) + g = trigsimpfunc(g, deep) + result = g + else: + result = trigsimpfunc(expr, deep) + + if opts.get('compare', False): + f = futrig(old) + if f != result: + print('\tfutrig:', f) + + return result + + +def _dotrig(a, b): + """Helper to tell whether ``a`` and ``b`` have the same sorts + of symbols in them -- no need to test hyperbolic patterns against + expressions that have no hyperbolics in them.""" + return a.func == b.func and ( + a.has(TrigonometricFunction) and b.has(TrigonometricFunction) or + a.has(HyperbolicFunction) and b.has(HyperbolicFunction)) + + +_trigpat = None +def _trigpats(): + global _trigpat + a, b, c = symbols('a b c', cls=Wild) + d = Wild('d', commutative=False) + + # for the simplifications like sinh/cosh -> tanh: + # DO NOT REORDER THE FIRST 14 since these are assumed to be in this + # order in _match_div_rewrite. + matchers_division = ( + (a*sin(b)**c/cos(b)**c, a*tan(b)**c, sin(b), cos(b)), + (a*tan(b)**c*cos(b)**c, a*sin(b)**c, sin(b), cos(b)), + (a*cot(b)**c*sin(b)**c, a*cos(b)**c, sin(b), cos(b)), + (a*tan(b)**c/sin(b)**c, a/cos(b)**c, sin(b), cos(b)), + (a*cot(b)**c/cos(b)**c, a/sin(b)**c, sin(b), cos(b)), + (a*cot(b)**c*tan(b)**c, a, sin(b), cos(b)), + (a*(cos(b) + 1)**c*(cos(b) - 1)**c, + a*(-sin(b)**2)**c, cos(b) + 1, cos(b) - 1), + (a*(sin(b) + 1)**c*(sin(b) - 1)**c, + a*(-cos(b)**2)**c, sin(b) + 1, sin(b) - 1), + + (a*sinh(b)**c/cosh(b)**c, a*tanh(b)**c, S.One, S.One), + (a*tanh(b)**c*cosh(b)**c, a*sinh(b)**c, S.One, S.One), + (a*coth(b)**c*sinh(b)**c, a*cosh(b)**c, S.One, S.One), + (a*tanh(b)**c/sinh(b)**c, a/cosh(b)**c, S.One, S.One), + (a*coth(b)**c/cosh(b)**c, a/sinh(b)**c, S.One, S.One), + (a*coth(b)**c*tanh(b)**c, a, S.One, S.One), + + (c*(tanh(a) + tanh(b))/(1 + tanh(a)*tanh(b)), + tanh(a + b)*c, S.One, S.One), + ) + + matchers_add = ( + (c*sin(a)*cos(b) + c*cos(a)*sin(b) + d, sin(a + b)*c + d), + (c*cos(a)*cos(b) - c*sin(a)*sin(b) + d, cos(a + b)*c + d), + (c*sin(a)*cos(b) - c*cos(a)*sin(b) + d, sin(a - b)*c + d), + (c*cos(a)*cos(b) + c*sin(a)*sin(b) + d, cos(a - b)*c + d), + (c*sinh(a)*cosh(b) + c*sinh(b)*cosh(a) + d, sinh(a + b)*c + d), + (c*cosh(a)*cosh(b) + c*sinh(a)*sinh(b) + d, cosh(a + b)*c + d), + ) + + # for cos(x)**2 + sin(x)**2 -> 1 + matchers_identity = ( + (a*sin(b)**2, a - a*cos(b)**2), + (a*tan(b)**2, a*(1/cos(b))**2 - a), + (a*cot(b)**2, a*(1/sin(b))**2 - a), + (a*sin(b + c), a*(sin(b)*cos(c) + sin(c)*cos(b))), + (a*cos(b + c), a*(cos(b)*cos(c) - sin(b)*sin(c))), + (a*tan(b + c), a*((tan(b) + tan(c))/(1 - tan(b)*tan(c)))), + + (a*sinh(b)**2, a*cosh(b)**2 - a), + (a*tanh(b)**2, a - a*(1/cosh(b))**2), + (a*coth(b)**2, a + a*(1/sinh(b))**2), + (a*sinh(b + c), a*(sinh(b)*cosh(c) + sinh(c)*cosh(b))), + (a*cosh(b + c), a*(cosh(b)*cosh(c) + sinh(b)*sinh(c))), + (a*tanh(b + c), a*((tanh(b) + tanh(c))/(1 + tanh(b)*tanh(c)))), + + ) + + # Reduce any lingering artifacts, such as sin(x)**2 changing + # to 1-cos(x)**2 when sin(x)**2 was "simpler" + artifacts = ( + (a - a*cos(b)**2 + c, a*sin(b)**2 + c, cos), + (a - a*(1/cos(b))**2 + c, -a*tan(b)**2 + c, cos), + (a - a*(1/sin(b))**2 + c, -a*cot(b)**2 + c, sin), + + (a - a*cosh(b)**2 + c, -a*sinh(b)**2 + c, cosh), + (a - a*(1/cosh(b))**2 + c, a*tanh(b)**2 + c, cosh), + (a + a*(1/sinh(b))**2 + c, a*coth(b)**2 + c, sinh), + + # same as above but with noncommutative prefactor + (a*d - a*d*cos(b)**2 + c, a*d*sin(b)**2 + c, cos), + (a*d - a*d*(1/cos(b))**2 + c, -a*d*tan(b)**2 + c, cos), + (a*d - a*d*(1/sin(b))**2 + c, -a*d*cot(b)**2 + c, sin), + + (a*d - a*d*cosh(b)**2 + c, -a*d*sinh(b)**2 + c, cosh), + (a*d - a*d*(1/cosh(b))**2 + c, a*d*tanh(b)**2 + c, cosh), + (a*d + a*d*(1/sinh(b))**2 + c, a*d*coth(b)**2 + c, sinh), + ) + + _trigpat = (a, b, c, d, matchers_division, matchers_add, + matchers_identity, artifacts) + return _trigpat + + +def _replace_mul_fpowxgpow(expr, f, g, rexp, h, rexph): + """Helper for _match_div_rewrite. + + Replace f(b_)**c_*g(b_)**(rexp(c_)) with h(b)**rexph(c) if f(b_) + and g(b_) are both positive or if c_ is an integer. + """ + # assert expr.is_Mul and expr.is_commutative and f != g + fargs = defaultdict(int) + gargs = defaultdict(int) + args = [] + for x in expr.args: + if x.is_Pow or x.func in (f, g): + b, e = x.as_base_exp() + if b.is_positive or e.is_integer: + if b.func == f: + fargs[b.args[0]] += e + continue + elif b.func == g: + gargs[b.args[0]] += e + continue + args.append(x) + common = set(fargs) & set(gargs) + hit = False + while common: + key = common.pop() + fe = fargs.pop(key) + ge = gargs.pop(key) + if fe == rexp(ge): + args.append(h(key)**rexph(fe)) + hit = True + else: + fargs[key] = fe + gargs[key] = ge + if not hit: + return expr + while fargs: + key, e = fargs.popitem() + args.append(f(key)**e) + while gargs: + key, e = gargs.popitem() + args.append(g(key)**e) + return Mul(*args) + + +_idn = lambda x: x +_midn = lambda x: -x +_one = lambda x: S.One + +def _match_div_rewrite(expr, i): + """helper for __trigsimp""" + if i == 0: + expr = _replace_mul_fpowxgpow(expr, sin, cos, + _midn, tan, _idn) + elif i == 1: + expr = _replace_mul_fpowxgpow(expr, tan, cos, + _idn, sin, _idn) + elif i == 2: + expr = _replace_mul_fpowxgpow(expr, cot, sin, + _idn, cos, _idn) + elif i == 3: + expr = _replace_mul_fpowxgpow(expr, tan, sin, + _midn, cos, _midn) + elif i == 4: + expr = _replace_mul_fpowxgpow(expr, cot, cos, + _midn, sin, _midn) + elif i == 5: + expr = _replace_mul_fpowxgpow(expr, cot, tan, + _idn, _one, _idn) + # i in (6, 7) is skipped + elif i == 8: + expr = _replace_mul_fpowxgpow(expr, sinh, cosh, + _midn, tanh, _idn) + elif i == 9: + expr = _replace_mul_fpowxgpow(expr, tanh, cosh, + _idn, sinh, _idn) + elif i == 10: + expr = _replace_mul_fpowxgpow(expr, coth, sinh, + _idn, cosh, _idn) + elif i == 11: + expr = _replace_mul_fpowxgpow(expr, tanh, sinh, + _midn, cosh, _midn) + elif i == 12: + expr = _replace_mul_fpowxgpow(expr, coth, cosh, + _midn, sinh, _midn) + elif i == 13: + expr = _replace_mul_fpowxgpow(expr, coth, tanh, + _idn, _one, _idn) + else: + return None + return expr + + +def _trigsimp(expr, deep=False): + # protect the cache from non-trig patterns; we only allow + # trig patterns to enter the cache + if expr.has(*_trigs): + return __trigsimp(expr, deep) + return expr + + +@cacheit +def __trigsimp(expr, deep=False): + """recursive helper for trigsimp""" + from sympy.simplify.fu import TR10i + + if _trigpat is None: + _trigpats() + a, b, c, d, matchers_division, matchers_add, \ + matchers_identity, artifacts = _trigpat + + if expr.is_Mul: + # do some simplifications like sin/cos -> tan: + if not expr.is_commutative: + com, nc = expr.args_cnc() + expr = _trigsimp(Mul._from_args(com), deep)*Mul._from_args(nc) + else: + for i, (pattern, simp, ok1, ok2) in enumerate(matchers_division): + if not _dotrig(expr, pattern): + continue + + newexpr = _match_div_rewrite(expr, i) + if newexpr is not None: + if newexpr != expr: + expr = newexpr + break + else: + continue + + # use SymPy matching instead + res = expr.match(pattern) + if res and res.get(c, 0): + if not res[c].is_integer: + ok = ok1.subs(res) + if not ok.is_positive: + continue + ok = ok2.subs(res) + if not ok.is_positive: + continue + # if "a" contains any of trig or hyperbolic funcs with + # argument "b" then skip the simplification + if any(w.args[0] == res[b] for w in res[a].atoms( + TrigonometricFunction, HyperbolicFunction)): + continue + # simplify and finish: + expr = simp.subs(res) + break # process below + + if expr.is_Add: + args = [] + for term in expr.args: + if not term.is_commutative: + com, nc = term.args_cnc() + nc = Mul._from_args(nc) + term = Mul._from_args(com) + else: + nc = S.One + term = _trigsimp(term, deep) + for pattern, result in matchers_identity: + res = term.match(pattern) + if res is not None: + term = result.subs(res) + break + args.append(term*nc) + if args != expr.args: + expr = Add(*args) + expr = min(expr, expand(expr), key=count_ops) + if expr.is_Add: + for pattern, result in matchers_add: + if not _dotrig(expr, pattern): + continue + expr = TR10i(expr) + if expr.has(HyperbolicFunction): + res = expr.match(pattern) + # if "d" contains any trig or hyperbolic funcs with + # argument "a" or "b" then skip the simplification; + # this isn't perfect -- see tests + if res is None or not (a in res and b in res) or any( + w.args[0] in (res[a], res[b]) for w in res[d].atoms( + TrigonometricFunction, HyperbolicFunction)): + continue + expr = result.subs(res) + break + + # Reduce any lingering artifacts, such as sin(x)**2 changing + # to 1 - cos(x)**2 when sin(x)**2 was "simpler" + for pattern, result, ex in artifacts: + if not _dotrig(expr, pattern): + continue + # Substitute a new wild that excludes some function(s) + # to help influence a better match. This is because + # sometimes, for example, 'a' would match sec(x)**2 + a_t = Wild('a', exclude=[ex]) + pattern = pattern.subs(a, a_t) + result = result.subs(a, a_t) + + m = expr.match(pattern) + was = None + while m and was != expr: + was = expr + if m[a_t] == 0 or \ + -m[a_t] in m[c].args or m[a_t] + m[c] == 0: + break + if d in m and m[a_t]*m[d] + m[c] == 0: + break + expr = result.subs(m) + m = expr.match(pattern) + m.setdefault(c, S.Zero) + + elif expr.is_Mul or expr.is_Pow or deep and expr.args: + expr = expr.func(*[_trigsimp(a, deep) for a in expr.args]) + + try: + if not expr.has(*_trigs): + raise TypeError + e = expr.atoms(exp) + new = expr.rewrite(exp, deep=deep) + if new == e: + raise TypeError + fnew = factor(new) + if fnew != new: + new = min([new, factor(new)], key=count_ops) + # if all exp that were introduced disappeared then accept it + if not (new.atoms(exp) - e): + expr = new + except TypeError: + pass + + return expr +#------------------- end of old trigsimp routines -------------------- + + +def futrig(e, *, hyper=True, **kwargs): + """Return simplified ``e`` using Fu-like transformations. + This is not the "Fu" algorithm. This is called by default + from ``trigsimp``. By default, hyperbolics subexpressions + will be simplified, but this can be disabled by setting + ``hyper=False``. + + Examples + ======== + + >>> from sympy import trigsimp, tan, sinh, tanh + >>> from sympy.simplify.trigsimp import futrig + >>> from sympy.abc import x + >>> trigsimp(1/tan(x)**2) + tan(x)**(-2) + + >>> futrig(sinh(x)/tanh(x)) + cosh(x) + + """ + from sympy.simplify.fu import hyper_as_trig + + e = sympify(e) + + if not isinstance(e, Basic): + return e + + if not e.args: + return e + + old = e + e = bottom_up(e, _futrig) + + if hyper and e.has(HyperbolicFunction): + e, f = hyper_as_trig(e) + e = f(bottom_up(e, _futrig)) + + if e != old and e.is_Mul and e.args[0].is_Rational: + # redistribute leading coeff on 2-arg Add + e = Mul(*e.as_coeff_Mul()) + return e + + +def _futrig(e): + """Helper for futrig.""" + from sympy.simplify.fu import ( + TR1, TR2, TR3, TR2i, TR10, L, TR10i, + TR8, TR6, TR15, TR16, TR111, TR5, TRmorrie, TR11, _TR11, TR14, TR22, + TR12) + + if not e.has(TrigonometricFunction): + return e + + if e.is_Mul: + coeff, e = e.as_independent(TrigonometricFunction) + else: + coeff = None + + Lops = lambda x: (L(x), x.count_ops(), _nodes(x), len(x.args), x.is_Add) + trigs = lambda x: x.has(TrigonometricFunction) + + tree = [identity, + ( + TR3, # canonical angles + TR1, # sec-csc -> cos-sin + TR12, # expand tan of sum + lambda x: _eapply(factor, x, trigs), + TR2, # tan-cot -> sin-cos + [identity, lambda x: _eapply(_mexpand, x, trigs)], + TR2i, # sin-cos ratio -> tan + lambda x: _eapply(lambda i: factor(i.normal()), x, trigs), + TR14, # factored identities + TR5, # sin-pow -> cos_pow + TR10, # sin-cos of sums -> sin-cos prod + TR11, _TR11, TR6, # reduce double angles and rewrite cos pows + lambda x: _eapply(factor, x, trigs), + TR14, # factored powers of identities + [identity, lambda x: _eapply(_mexpand, x, trigs)], + TR10i, # sin-cos products > sin-cos of sums + TRmorrie, + [identity, TR8], # sin-cos products -> sin-cos of sums + [identity, lambda x: TR2i(TR2(x))], # tan -> sin-cos -> tan + [ + lambda x: _eapply(expand_mul, TR5(x), trigs), + lambda x: _eapply( + expand_mul, TR15(x), trigs)], # pos/neg powers of sin + [ + lambda x: _eapply(expand_mul, TR6(x), trigs), + lambda x: _eapply( + expand_mul, TR16(x), trigs)], # pos/neg powers of cos + TR111, # tan, sin, cos to neg power -> cot, csc, sec + [identity, TR2i], # sin-cos ratio to tan + [identity, lambda x: _eapply( + expand_mul, TR22(x), trigs)], # tan-cot to sec-csc + TR1, TR2, TR2i, + [identity, lambda x: _eapply( + factor_terms, TR12(x), trigs)], # expand tan of sum + )] + e = greedy(tree, objective=Lops)(e) + + if coeff is not None: + e = coeff * e + + return e + + +def _is_Expr(e): + """_eapply helper to tell whether ``e`` and all its args + are Exprs.""" + if isinstance(e, Derivative): + return _is_Expr(e.expr) + if not isinstance(e, Expr): + return False + return all(_is_Expr(i) for i in e.args) + + +def _eapply(func, e, cond=None): + """Apply ``func`` to ``e`` if all args are Exprs else only + apply it to those args that *are* Exprs.""" + if not isinstance(e, Expr): + return e + if _is_Expr(e) or not e.args: + return func(e) + return e.func(*[ + _eapply(func, ei) if (cond is None or cond(ei)) else ei + for ei in e.args]) diff --git a/phivenv/Lib/site-packages/sympy/solvers/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02cfb35a765c748e70ecc36dd78d8d4432118c64 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/__init__.py @@ -0,0 +1,75 @@ +"""A module for solving all kinds of equations. + + Examples + ======== + + >>> from sympy.solvers import solve + >>> from sympy.abc import x + >>> solve(((x + 1)**5).expand(), x) + [-1] +""" +from sympy.core.assumptions import check_assumptions, failing_assumptions + +from .solvers import solve, solve_linear_system, solve_linear_system_LU, \ + solve_undetermined_coeffs, nsolve, solve_linear, checksol, \ + det_quick, inv_quick + +from sympy.solvers.diophantine.diophantine import diophantine + +from .recurr import rsolve, rsolve_poly, rsolve_ratio, rsolve_hyper + +from .ode import checkodesol, classify_ode, dsolve, \ + homogeneous_order + +from .polysys import solve_poly_system, solve_triangulated, factor_system + +from .pde import pde_separate, pde_separate_add, pde_separate_mul, \ + pdsolve, classify_pde, checkpdesol + +from .deutils import ode_order + +from .inequalities import reduce_inequalities, reduce_abs_inequality, \ + reduce_abs_inequalities, solve_poly_inequality, solve_rational_inequalities, solve_univariate_inequality + +from .decompogen import decompogen + +from .solveset import solveset, linsolve, linear_eq_to_matrix, nonlinsolve, substitution + +from .simplex import lpmin, lpmax, linprog + +# This is here instead of sympy/sets/__init__.py to avoid circular import issues +from ..core.singleton import S +Complexes = S.Complexes + +__all__ = [ + 'solve', 'solve_linear_system', 'solve_linear_system_LU', + 'solve_undetermined_coeffs', 'nsolve', 'solve_linear', 'checksol', + 'det_quick', 'inv_quick', 'check_assumptions', 'failing_assumptions', + + 'diophantine', + + 'rsolve', 'rsolve_poly', 'rsolve_ratio', 'rsolve_hyper', + + 'checkodesol', 'classify_ode', 'dsolve', 'homogeneous_order', + + 'solve_poly_system', 'solve_triangulated', 'factor_system', + + 'pde_separate', 'pde_separate_add', 'pde_separate_mul', 'pdsolve', + 'classify_pde', 'checkpdesol', + + 'ode_order', + + 'reduce_inequalities', 'reduce_abs_inequality', 'reduce_abs_inequalities', + 'solve_poly_inequality', 'solve_rational_inequalities', + 'solve_univariate_inequality', + + 'decompogen', + + 'solveset', 'linsolve', 'linear_eq_to_matrix', 'nonlinsolve', + 'substitution', + + # This is here instead of sympy/sets/__init__.py to avoid circular import issues + 'Complexes', + + 'lpmin', 'lpmax', 'linprog' +] diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23cb2796e18afd9eee53162ae2d4859565129724 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/bivariate.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/bivariate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26fbc9b10152dec37cba8aa38db63143d818bcd5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/bivariate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/decompogen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/decompogen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c31ad046fd55005e6056c60ffc7fa44d2012741b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/decompogen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/deutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/deutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c51dc16451c41cee235ae954625747ba2a7be5f5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/deutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/inequalities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/inequalities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68a4c6d845984d610d9329a2380c29d47a193281 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/inequalities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/pde.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/pde.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fd7b61cd7405d06d7900d270599b98e252a34cc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/pde.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/polysys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/polysys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5b82da51ee4b1e80d6e2436ad393e887a5841a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/polysys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/recurr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/recurr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9ad723d3488299afa219f41924f1fa366e211dc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/recurr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/__pycache__/simplex.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/simplex.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac2f5a7276704d2fc31436406626a33762593250 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/__pycache__/simplex.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd7543159d7dd52d72dd5b52f96fcf8ad63d1aff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b907e9c158efe4ddb5e7e9d2cd5b2d31eff14446 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/benchmarks/bench_solvers.py b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/bench_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..d18102873f7efcde1d111e0e8eca12e208f94663 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/benchmarks/bench_solvers.py @@ -0,0 +1,12 @@ +from sympy.core.symbol import Symbol +from sympy.matrices.dense import (eye, zeros) +from sympy.solvers.solvers import solve_linear_system + +N = 8 +M = zeros(N, N + 1) +M[:, :N] = eye(N) +S = [Symbol('A%i' % i) for i in range(N)] + + +def timeit_linsolve_trivial(): + solve_linear_system(M, *S) diff --git a/phivenv/Lib/site-packages/sympy/solvers/bivariate.py b/phivenv/Lib/site-packages/sympy/solvers/bivariate.py new file mode 100644 index 0000000000000000000000000000000000000000..72f8e266a16634fa65366e1058543dfe2171ba1c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/bivariate.py @@ -0,0 +1,509 @@ +from sympy.core.add import Add +from sympy.core.exprtools import factor_terms +from sympy.core.function import expand_log, _mexpand +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.miscellaneous import root +from sympy.polys.polyroots import roots +from sympy.polys.polytools import Poly, factor +from sympy.simplify.simplify import separatevars +from sympy.simplify.radsimp import collect +from sympy.simplify.simplify import powsimp +from sympy.solvers.solvers import solve, _invert +from sympy.utilities.iterables import uniq + + +def _filtered_gens(poly, symbol): + """process the generators of ``poly``, returning the set of generators that + have ``symbol``. If there are two generators that are inverses of each other, + prefer the one that has no denominator. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _filtered_gens + >>> from sympy import Poly, exp + >>> from sympy.abc import x + >>> _filtered_gens(Poly(x + 1/x + exp(x)), x) + {x, exp(x)} + + """ + # TODO it would be good to pick the smallest divisible power + # instead of the base for something like x**4 + x**2 --> + # return x**2 not x + gens = {g for g in poly.gens if symbol in g.free_symbols} + for g in list(gens): + ag = 1/g + if g in gens and ag in gens: + if ag.as_numer_denom()[1] is not S.One: + g = ag + gens.remove(g) + return gens + + +def _mostfunc(lhs, func, X=None): + """Returns the term in lhs which contains the most of the + func-type things e.g. log(log(x)) wins over log(x) if both terms appear. + + ``func`` can be a function (exp, log, etc...) or any other SymPy object, + like Pow. + + If ``X`` is not ``None``, then the function returns the term composed with the + most ``func`` having the specified variable. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _mostfunc + >>> from sympy import exp + >>> from sympy.abc import x, y + >>> _mostfunc(exp(x) + exp(exp(x) + 2), exp) + exp(exp(x) + 2) + >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp) + exp(exp(y) + 2) + >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x) + exp(x) + >>> _mostfunc(x, exp, x) is None + True + >>> _mostfunc(exp(x) + exp(x*y), exp, x) + exp(x) + """ + fterms = [tmp for tmp in lhs.atoms(func) if (not X or + X.is_Symbol and X in tmp.free_symbols or + not X.is_Symbol and tmp.has(X))] + if len(fterms) == 1: + return fterms[0] + elif fterms: + return max(list(ordered(fterms)), key=lambda x: x.count(func)) + return None + + +def _linab(arg, symbol): + """Return ``a, b, X`` assuming ``arg`` can be written as ``a*X + b`` + where ``X`` is a symbol-dependent factor and ``a`` and ``b`` are + independent of ``symbol``. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _linab + >>> from sympy.abc import x, y + >>> from sympy import exp, S + >>> _linab(S(2), x) + (2, 0, 1) + >>> _linab(2*x, x) + (2, 0, x) + >>> _linab(y + y*x + 2*x, x) + (y + 2, y, x) + >>> _linab(3 + 2*exp(x), x) + (2, 3, exp(x)) + """ + arg = factor_terms(arg.expand()) + ind, dep = arg.as_independent(symbol) + if arg.is_Mul and dep.is_Add: + a, b, x = _linab(dep, symbol) + return ind*a, ind*b, x + if not arg.is_Add: + b = 0 + a, x = ind, dep + else: + b = ind + a, x = separatevars(dep).as_independent(symbol, as_Add=False) + if x.could_extract_minus_sign(): + a = -a + x = -x + return a, b, x + + +def _lambert(eq, x): + """ + Given an expression assumed to be in the form + ``F(X, a..f) = a*log(b*X + c) + d*X + f = 0`` + where X = g(x) and x = g^-1(X), return the Lambert solution, + ``x = g^-1(-c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(-f/a)))``. + """ + eq = _mexpand(expand_log(eq)) + mainlog = _mostfunc(eq, log, x) + if not mainlog: + return [] # violated assumptions + other = eq.subs(mainlog, 0) + if isinstance(-other, log): + eq = (eq - other).subs(mainlog, mainlog.args[0]) + mainlog = mainlog.args[0] + if not isinstance(mainlog, log): + return [] # violated assumptions + other = -(-other).args[0] + eq += other + if x not in other.free_symbols: + return [] # violated assumptions + d, f, X2 = _linab(other, x) + logterm = collect(eq - other, mainlog) + a = logterm.as_coefficient(mainlog) + if a is None or x in a.free_symbols: + return [] # violated assumptions + logarg = mainlog.args[0] + b, c, X1 = _linab(logarg, x) + if X1 != X2: + return [] # violated assumptions + + # invert the generator X1 so we have x(u) + u = Dummy('rhs') + xusolns = solve(X1 - u, x) + + # There are infinitely many branches for LambertW + # but only branches for k = -1 and 0 might be real. The k = 0 + # branch is real and the k = -1 branch is real if the LambertW argument + # in in range [-1/e, 0]. Since `solve` does not return infinite + # solutions we will only include the -1 branch if it tests as real. + # Otherwise, inclusion of any LambertW in the solution indicates to + # the user that there are imaginary solutions corresponding to + # different k values. + lambert_real_branches = [-1, 0] + sol = [] + + # solution of the given Lambert equation is like + # sol = -c/b + (a/d)*LambertW(arg, k), + # where arg = d/(a*b)*exp((c*d-b*f)/a/b) and k in lambert_real_branches. + # Instead of considering the single arg, `d/(a*b)*exp((c*d-b*f)/a/b)`, + # the individual `p` roots obtained when writing `exp((c*d-b*f)/a/b)` + # as `exp(A/p) = exp(A)**(1/p)`, where `p` is an Integer, are used. + + # calculating args for LambertW + num, den = ((c*d-b*f)/a/b).as_numer_denom() + p, den = den.as_coeff_Mul() + e = exp(num/den) + t = Dummy('t') + args = [d/(a*b)*t for t in roots(t**p - e, t).keys()] + + # calculating solutions from args + for arg in args: + for k in lambert_real_branches: + w = LambertW(arg, k) + if k and not w.is_real: + continue + rhs = -c/b + (a/d)*w + + sol.extend(xu.subs(u, rhs) for xu in xusolns) + return sol + + +def _solve_lambert(f, symbol, gens): + """Return solution to ``f`` if it is a Lambert-type expression + else raise NotImplementedError. + + For ``f(X, a..f) = a*log(b*X + c) + d*X - f = 0`` the solution + for ``X`` is ``X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))``. + There are a variety of forms for `f(X, a..f)` as enumerated below: + + 1a1) + if B**B = R for R not in [0, 1] (since those cases would already + be solved before getting here) then log of both sides gives + log(B) + log(log(B)) = log(log(R)) and + X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R)) + 1a2) + if B*(b*log(B) + c)**a = R then log of both sides gives + log(B) + a*log(b*log(B) + c) = log(R) and + X = log(B), d=1, f=log(R) + 1b) + if a*log(b*B + c) + d*B = R and + X = B, f = R + 2a) + if (b*B + c)*exp(d*B + g) = R then log of both sides gives + log(b*B + c) + d*B + g = log(R) and + X = B, a = 1, f = log(R) - g + 2b) + if g*exp(d*B + h) - b*B = c then the log form is + log(g) + d*B + h - log(b*B + c) = 0 and + X = B, a = -1, f = -h - log(g) + 3) + if d*p**(a*B + g) - b*B = c then the log form is + log(d) + (a*B + g)*log(p) - log(b*B + c) = 0 and + X = B, a = -1, d = a*log(p), f = -log(d) - g*log(p) + """ + + def _solve_even_degree_expr(expr, t, symbol): + """Return the unique solutions of equations derived from + ``expr`` by replacing ``t`` with ``+/- symbol``. + + Parameters + ========== + + expr : Expr + The expression which includes a dummy variable t to be + replaced with +symbol and -symbol. + + symbol : Symbol + The symbol for which a solution is being sought. + + Returns + ======= + + List of unique solution of the two equations generated by + replacing ``t`` with positive and negative ``symbol``. + + Notes + ===== + + If ``expr = 2*log(t) + x/2` then solutions for + ``2*log(x) + x/2 = 0`` and ``2*log(-x) + x/2 = 0`` are + returned by this function. Though this may seem + counter-intuitive, one must note that the ``expr`` being + solved here has been derived from a different expression. For + an expression like ``eq = x**2*g(x) = 1``, if we take the + log of both sides we obtain ``log(x**2) + log(g(x)) = 0``. If + x is positive then this simplifies to + ``2*log(x) + log(g(x)) = 0``; the Lambert-solving routines will + return solutions for this, but we must also consider the + solutions for ``2*log(-x) + log(g(x))`` since those must also + be a solution of ``eq`` which has the same value when the ``x`` + in ``x**2`` is negated. If `g(x)` does not have even powers of + symbol then we do not want to replace the ``x`` there with + ``-x``. So the role of the ``t`` in the expression received by + this function is to mark where ``+/-x`` should be inserted + before obtaining the Lambert solutions. + + """ + nlhs, plhs = [ + expr.xreplace({t: sgn*symbol}) for sgn in (-1, 1)] + sols = _solve_lambert(nlhs, symbol, gens) + if plhs != nlhs: + sols.extend(_solve_lambert(plhs, symbol, gens)) + # uniq is needed for a case like + # 2*log(t) - log(-z**2) + log(z + log(x) + log(z)) + # where substituting t with +/-x gives all the same solution; + # uniq, rather than list(set()), is used to maintain canonical + # order + return list(uniq(sols)) + + nrhs, lhs = f.as_independent(symbol, as_Add=True) + rhs = -nrhs + + lamcheck = [tmp for tmp in gens + if (tmp.func in [exp, log] or + (tmp.is_Pow and symbol in tmp.exp.free_symbols))] + if not lamcheck: + raise NotImplementedError() + + if lhs.is_Add or lhs.is_Mul: + # replacing all even_degrees of symbol with dummy variable t + # since these will need special handling; non-Add/Mul do not + # need this handling + t = Dummy('t', **symbol.assumptions0) + lhs = lhs.replace( + lambda i: # find symbol**even + i.is_Pow and i.base == symbol and i.exp.is_even, + lambda i: # replace t**even + t**i.exp) + + if lhs.is_Add and lhs.has(t): + t_indep = lhs.subs(t, 0) + t_term = lhs - t_indep + _rhs = rhs - t_indep + if not t_term.is_Add and _rhs and not ( + t_term.has(S.ComplexInfinity, S.NaN)): + eq = expand_log(log(t_term) - log(_rhs)) + return _solve_even_degree_expr(eq, t, symbol) + elif lhs.is_Mul and rhs: + # this needs to happen whether t is present or not + lhs = expand_log(log(lhs), force=True) + rhs = log(rhs) + if lhs.has(t) and lhs.is_Add: + # it expanded from Mul to Add + eq = lhs - rhs + return _solve_even_degree_expr(eq, t, symbol) + + # restore symbol in lhs + lhs = lhs.xreplace({t: symbol}) + + lhs = powsimp(factor(lhs, deep=True)) + + # make sure we have inverted as completely as possible + r = Dummy() + i, lhs = _invert(lhs - r, symbol) + rhs = i.xreplace({r: rhs}) + + # For the first forms: + # + # 1a1) B**B = R will arrive here as B*log(B) = log(R) + # lhs is Mul so take log of both sides: + # log(B) + log(log(B)) = log(log(R)) + # 1a2) B*(b*log(B) + c)**a = R will arrive unchanged so + # lhs is Mul, so take log of both sides: + # log(B) + a*log(b*log(B) + c) = log(R) + # 1b) d*log(a*B + b) + c*B = R will arrive unchanged so + # lhs is Add, so isolate c*B and expand log of both sides: + # log(c) + log(B) = log(R - d*log(a*B + b)) + + soln = [] + if not soln: + mainlog = _mostfunc(lhs, log, symbol) + if mainlog: + if lhs.is_Mul and rhs != 0: + soln = _lambert(log(lhs) - log(rhs), symbol) + elif lhs.is_Add: + other = lhs.subs(mainlog, 0) + if other and not other.is_Add and [ + tmp for tmp in other.atoms(Pow) + if symbol in tmp.free_symbols]: + if not rhs: + diff = log(other) - log(other - lhs) + else: + diff = log(lhs - other) - log(rhs - other) + soln = _lambert(expand_log(diff), symbol) + else: + #it's ready to go + soln = _lambert(lhs - rhs, symbol) + + # For the next forms, + # + # collect on main exp + # 2a) (b*B + c)*exp(d*B + g) = R + # lhs is mul, so take log of both sides: + # log(b*B + c) + d*B = log(R) - g + # 2b) g*exp(d*B + h) - b*B = R + # lhs is add, so add b*B to both sides, + # take the log of both sides and rearrange to give + # log(R + b*B) - d*B = log(g) + h + + if not soln: + mainexp = _mostfunc(lhs, exp, symbol) + if mainexp: + lhs = collect(lhs, mainexp) + if lhs.is_Mul and rhs != 0: + soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol) + elif lhs.is_Add: + # move all but mainexp-containing term to rhs + other = lhs.subs(mainexp, 0) + mainterm = lhs - other + rhs = rhs - other + if (mainterm.could_extract_minus_sign() and + rhs.could_extract_minus_sign()): + mainterm *= -1 + rhs *= -1 + diff = log(mainterm) - log(rhs) + soln = _lambert(expand_log(diff), symbol) + + # For the last form: + # + # 3) d*p**(a*B + g) - b*B = c + # collect on main pow, add b*B to both sides, + # take log of both sides and rearrange to give + # a*B*log(p) - log(b*B + c) = -log(d) - g*log(p) + if not soln: + mainpow = _mostfunc(lhs, Pow, symbol) + if mainpow and symbol in mainpow.exp.free_symbols: + lhs = collect(lhs, mainpow) + if lhs.is_Mul and rhs != 0: + # b*B = 0 + soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol) + elif lhs.is_Add: + # move all but mainpow-containing term to rhs + other = lhs.subs(mainpow, 0) + mainterm = lhs - other + rhs = rhs - other + diff = log(mainterm) - log(rhs) + soln = _lambert(expand_log(diff), symbol) + + if not soln: + raise NotImplementedError('%s does not appear to have a solution in ' + 'terms of LambertW' % f) + + return list(ordered(soln)) + + +def bivariate_type(f, x, y, *, first=True): + """Given an expression, f, 3 tests will be done to see what type + of composite bivariate it might be, options for u(x, y) are:: + + x*y + x+y + x*y+x + x*y+y + + If it matches one of these types, ``u(x, y)``, ``P(u)`` and dummy + variable ``u`` will be returned. Solving ``P(u)`` for ``u`` and + equating the solutions to ``u(x, y)`` and then solving for ``x`` or + ``y`` is equivalent to solving the original expression for ``x`` or + ``y``. If ``x`` and ``y`` represent two functions in the same + variable, e.g. ``x = g(t)`` and ``y = h(t)``, then if ``u(x, y) - p`` + can be solved for ``t`` then these represent the solutions to + ``P(u) = 0`` when ``p`` are the solutions of ``P(u) = 0``. + + Only positive values of ``u`` are considered. + + Examples + ======== + + >>> from sympy import solve + >>> from sympy.solvers.bivariate import bivariate_type + >>> from sympy.abc import x, y + >>> eq = (x**2 - 3).subs(x, x + y) + >>> bivariate_type(eq, x, y) + (x + y, _u**2 - 3, _u) + >>> uxy, pu, u = _ + >>> usol = solve(pu, u); usol + [sqrt(3)] + >>> [solve(uxy - s) for s in solve(pu, u)] + [[{x: -y + sqrt(3)}]] + >>> all(eq.subs(s).equals(0) for sol in _ for s in sol) + True + + """ + + u = Dummy('u', positive=True) + + if first: + p = Poly(f, x, y) + f = p.as_expr() + _x = Dummy() + _y = Dummy() + rv = bivariate_type(Poly(f.subs({x: _x, y: _y}), _x, _y), _x, _y, first=False) + if rv: + reps = {_x: x, _y: y} + return rv[0].xreplace(reps), rv[1].xreplace(reps), rv[2] + return + + p = f + f = p.as_expr() + + # f(x*y) + args = Add.make_args(p.as_expr()) + new = [] + for a in args: + a = _mexpand(a.subs(x, u/y)) + free = a.free_symbols + if x in free or y in free: + break + new.append(a) + else: + return x*y, Add(*new), u + + def ok(f, v, c): + new = _mexpand(f.subs(v, c)) + free = new.free_symbols + return None if (x in free or y in free) else new + + # f(a*x + b*y) + new = [] + d = p.degree(x) + if p.degree(y) == d: + a = root(p.coeff_monomial(x**d), d) + b = root(p.coeff_monomial(y**d), d) + new = ok(f, x, (u - b*y)/a) + if new is not None: + return a*x + b*y, new, u + + # f(a*x*y + b*y) + new = [] + d = p.degree(x) + if p.degree(y) == d: + for itry in range(2): + a = root(p.coeff_monomial(x**d*y**d), d) + b = root(p.coeff_monomial(y**d), d) + new = ok(f, x, (u - b*y)/a/y) + if new is not None: + return a*x*y + b*y, new, u + x, y = y, x diff --git a/phivenv/Lib/site-packages/sympy/solvers/decompogen.py b/phivenv/Lib/site-packages/sympy/solvers/decompogen.py new file mode 100644 index 0000000000000000000000000000000000000000..ec1b3b683511a34e6f98b9839d112b87517390d8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/decompogen.py @@ -0,0 +1,126 @@ +from sympy.core import (Function, Pow, sympify, Expr) +from sympy.core.relational import Relational +from sympy.core.singleton import S +from sympy.polys import Poly, decompose +from sympy.utilities.misc import func_name +from sympy.functions.elementary.miscellaneous import Min, Max + + +def decompogen(f, symbol): + """ + Computes General functional decomposition of ``f``. + Given an expression ``f``, returns a list ``[f_1, f_2, ..., f_n]``, + where:: + f = f_1 o f_2 o ... f_n = f_1(f_2(... f_n)) + + Note: This is a General decomposition function. It also decomposes + Polynomials. For only Polynomial decomposition see ``decompose`` in polys. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import decompogen, sqrt, sin, cos + >>> decompogen(sin(cos(x)), x) + [sin(x), cos(x)] + >>> decompogen(sin(x)**2 + sin(x) + 1, x) + [x**2 + x + 1, sin(x)] + >>> decompogen(sqrt(6*x**2 - 5), x) + [sqrt(x), 6*x**2 - 5] + >>> decompogen(sin(sqrt(cos(x**2 + 1))), x) + [sin(x), sqrt(x), cos(x), x**2 + 1] + >>> decompogen(x**4 + 2*x**3 - x - 1, x) + [x**2 - x - 1, x**2 + x] + + """ + f = sympify(f) + if not isinstance(f, Expr) or isinstance(f, Relational): + raise TypeError('expecting Expr but got: `%s`' % func_name(f)) + if symbol not in f.free_symbols: + return [f] + + + # ===== Simple Functions ===== # + if isinstance(f, (Function, Pow)): + if f.is_Pow and f.base == S.Exp1: + arg = f.exp + else: + arg = f.args[0] + if arg == symbol: + return [f] + return [f.subs(arg, symbol)] + decompogen(arg, symbol) + + # ===== Min/Max Functions ===== # + if isinstance(f, (Min, Max)): + args = list(f.args) + d0 = None + for i, a in enumerate(args): + if not a.has_free(symbol): + continue + d = decompogen(a, symbol) + if len(d) == 1: + d = [symbol] + d + if d0 is None: + d0 = d[1:] + elif d[1:] != d0: + # decomposition is not the same for each arg: + # mark as having no decomposition + d = [symbol] + break + args[i] = d[0] + if d[0] == symbol: + return [f] + return [f.func(*args)] + d0 + + # ===== Convert to Polynomial ===== # + fp = Poly(f) + gens = list(filter(lambda x: symbol in x.free_symbols, fp.gens)) + + if len(gens) == 1 and gens[0] != symbol: + f1 = f.subs(gens[0], symbol) + f2 = gens[0] + return [f1] + decompogen(f2, symbol) + + # ===== Polynomial decompose() ====== # + try: + return decompose(f) + except ValueError: + return [f] + + +def compogen(g_s, symbol): + """ + Returns the composition of functions. + Given a list of functions ``g_s``, returns their composition ``f``, + where: + f = g_1 o g_2 o .. o g_n + + Note: This is a General composition function. It also composes Polynomials. + For only Polynomial composition see ``compose`` in polys. + + Examples + ======== + + >>> from sympy.solvers.decompogen import compogen + >>> from sympy.abc import x + >>> from sympy import sqrt, sin, cos + >>> compogen([sin(x), cos(x)], x) + sin(cos(x)) + >>> compogen([x**2 + x + 1, sin(x)], x) + sin(x)**2 + sin(x) + 1 + >>> compogen([sqrt(x), 6*x**2 - 5], x) + sqrt(6*x**2 - 5) + >>> compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) + sin(sqrt(cos(x**2 + 1))) + >>> compogen([x**2 - x - 1, x**2 + x], x) + -x**2 - x + (x**2 + x)**2 - 1 + """ + if len(g_s) == 1: + return g_s[0] + + foo = g_s[0].subs(symbol, g_s[1]) + + if len(g_s) == 2: + return foo + + return compogen([foo] + g_s[2:], symbol) diff --git a/phivenv/Lib/site-packages/sympy/solvers/deutils.py b/phivenv/Lib/site-packages/sympy/solvers/deutils.py new file mode 100644 index 0000000000000000000000000000000000000000..b13f37b5004fe7bce2545ad2c788ec3feae56025 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/deutils.py @@ -0,0 +1,273 @@ +"""Utility functions for classifying and solving +ordinary and partial differential equations. + +Contains +======== +_preprocess +ode_order +_desolve + +""" +from sympy.core import Pow +from sympy.core.function import Derivative, AppliedUndef +from sympy.core.relational import Equality +from sympy.core.symbol import Wild + +def _preprocess(expr, func=None, hint='_Integral'): + """Prepare expr for solving by making sure that differentiation + is done so that only func remains in unevaluated derivatives and + (if hint does not end with _Integral) that doit is applied to all + other derivatives. If hint is None, do not do any differentiation. + (Currently this may cause some simple differential equations to + fail.) + + In case func is None, an attempt will be made to autodetect the + function to be solved for. + + >>> from sympy.solvers.deutils import _preprocess + >>> from sympy import Derivative, Function + >>> from sympy.abc import x, y, z + >>> f, g = map(Function, 'fg') + + If f(x)**p == 0 and p>0 then we can solve for f(x)=0 + >>> _preprocess((f(x).diff(x)-4)**5, f(x)) + (Derivative(f(x), x) - 4, f(x)) + + Apply doit to derivatives that contain more than the function + of interest: + + >>> _preprocess(Derivative(f(x) + x, x)) + (Derivative(f(x), x) + 1, f(x)) + + Do others if the differentiation variable(s) intersect with those + of the function of interest or contain the function of interest: + + >>> _preprocess(Derivative(g(x), y, z), f(y)) + (0, f(y)) + >>> _preprocess(Derivative(f(y), z), f(y)) + (0, f(y)) + + Do others if the hint does not end in '_Integral' (the default + assumes that it does): + + >>> _preprocess(Derivative(g(x), y), f(x)) + (Derivative(g(x), y), f(x)) + >>> _preprocess(Derivative(f(x), y), f(x), hint='') + (0, f(x)) + + Do not do any derivatives if hint is None: + + >>> eq = Derivative(f(x) + 1, x) + Derivative(f(x), y) + >>> _preprocess(eq, f(x), hint=None) + (Derivative(f(x) + 1, x) + Derivative(f(x), y), f(x)) + + If it's not clear what the function of interest is, it must be given: + + >>> eq = Derivative(f(x) + g(x), x) + >>> _preprocess(eq, g(x)) + (Derivative(f(x), x) + Derivative(g(x), x), g(x)) + >>> try: _preprocess(eq) + ... except ValueError: print("A ValueError was raised.") + A ValueError was raised. + + """ + if isinstance(expr, Pow): + # if f(x)**p=0 then f(x)=0 (p>0) + if (expr.exp).is_positive: + expr = expr.base + derivs = expr.atoms(Derivative) + if not func: + funcs = set().union(*[d.atoms(AppliedUndef) for d in derivs]) + if len(funcs) != 1: + raise ValueError('The function cannot be ' + 'automatically detected for %s.' % expr) + func = funcs.pop() + fvars = set(func.args) + if hint is None: + return expr, func + reps = [(d, d.doit()) for d in derivs if not hint.endswith('_Integral') or + d.has(func) or set(d.variables) & fvars] + eq = expr.subs(reps) + return eq, func + + +def ode_order(expr, func): + """ + Returns the order of a given differential + equation with respect to func. + + This function is implemented recursively. + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.solvers.deutils import ode_order + >>> from sympy.abc import x + >>> f, g = map(Function, ['f', 'g']) + >>> ode_order(f(x).diff(x, 2) + f(x).diff(x)**2 + + ... f(x).diff(x), f(x)) + 2 + >>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), f(x)) + 2 + >>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), g(x)) + 3 + + """ + a = Wild('a', exclude=[func]) + if expr.match(a): + return 0 + + if isinstance(expr, Derivative): + if expr.args[0] == func: + return len(expr.variables) + else: + args = expr.args[0].args + rv = len(expr.variables) + if args: + rv += max(ode_order(_, func) for _ in args) + return rv + else: + return max(ode_order(_, func) for _ in expr.args) if expr.args else 0 + + +def _desolve(eq, func=None, hint="default", ics=None, simplify=True, *, prep=True, **kwargs): + """This is a helper function to dsolve and pdsolve in the ode + and pde modules. + + If the hint provided to the function is "default", then a dict with + the following keys are returned + + 'func' - It provides the function for which the differential equation + has to be solved. This is useful when the expression has + more than one function in it. + + 'default' - The default key as returned by classifier functions in ode + and pde.py + + 'hint' - The hint given by the user for which the differential equation + is to be solved. If the hint given by the user is 'default', + then the value of 'hint' and 'default' is the same. + + 'order' - The order of the function as returned by ode_order + + 'match' - It returns the match as given by the classifier functions, for + the default hint. + + If the hint provided to the function is not "default" and is not in + ('all', 'all_Integral', 'best'), then a dict with the above mentioned keys + is returned along with the keys which are returned when dict in + classify_ode or classify_pde is set True + + If the hint given is in ('all', 'all_Integral', 'best'), then this function + returns a nested dict, with the keys, being the set of classified hints + returned by classifier functions, and the values being the dict of form + as mentioned above. + + Key 'eq' is a common key to all the above mentioned hints which returns an + expression if eq given by user is an Equality. + + See Also + ======== + classify_ode(ode.py) + classify_pde(pde.py) + """ + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + + # preprocess the equation and find func if not given + if prep or func is None: + eq, func = _preprocess(eq, func) + prep = False + + # type is an argument passed by the solve functions in ode and pde.py + # that identifies whether the function caller is an ordinary + # or partial differential equation. Accordingly corresponding + # changes are made in the function. + type = kwargs.get('type', None) + xi = kwargs.get('xi') + eta = kwargs.get('eta') + x0 = kwargs.get('x0', 0) + terms = kwargs.get('n') + + if type == 'ode': + from sympy.solvers.ode import classify_ode, allhints + classifier = classify_ode + string = 'ODE ' + dummy = '' + + elif type == 'pde': + from sympy.solvers.pde import classify_pde, allhints + classifier = classify_pde + string = 'PDE ' + dummy = 'p' + + # Magic that should only be used internally. Prevents classify_ode from + # being called more than it needs to be by passing its results through + # recursive calls. + if kwargs.get('classify', True): + hints = classifier(eq, func, dict=True, ics=ics, xi=xi, eta=eta, + n=terms, x0=x0, hint=hint, prep=prep) + + else: + # Here is what all this means: + # + # hint: The hint method given to _desolve() by the user. + # hints: The dictionary of hints that match the DE, along with other + # information (including the internal pass-through magic). + # default: The default hint to return, the first hint from allhints + # that matches the hint; obtained from classify_ode(). + # match: Dictionary containing the match dictionary for each hint + # (the parts of the DE for solving). When going through the + # hints in "all", this holds the match string for the current + # hint. + # order: The order of the DE, as determined by ode_order(). + hints = kwargs.get('hint', + {'default': hint, + hint: kwargs['match'], + 'order': kwargs['order']}) + if not hints['default']: + # classify_ode will set hints['default'] to None if no hints match. + if hint not in allhints and hint != 'default': + raise ValueError("Hint not recognized: " + hint) + elif hint not in hints['ordered_hints'] and hint != 'default': + raise ValueError(string + str(eq) + " does not match hint " + hint) + # If dsolve can't solve the purely algebraic equation then dsolve will raise + # ValueError + elif hints['order'] == 0: + raise ValueError( + str(eq) + " is not a solvable differential equation in " + str(func)) + else: + raise NotImplementedError(dummy + "solve" + ": Cannot solve " + str(eq)) + if hint == 'default': + return _desolve(eq, func, ics=ics, hint=hints['default'], simplify=simplify, + prep=prep, x0=x0, classify=False, order=hints['order'], + match=hints[hints['default']], xi=xi, eta=eta, n=terms, type=type) + elif hint in ('all', 'all_Integral', 'best'): + retdict = {} + gethints = set(hints) - {'order', 'default', 'ordered_hints'} + if hint == 'all_Integral': + for i in hints: + if i.endswith('_Integral'): + gethints.remove(i.removesuffix('_Integral')) + # special cases + for k in ["1st_homogeneous_coeff_best", "1st_power_series", + "lie_group", "2nd_power_series_ordinary", "2nd_power_series_regular"]: + if k in gethints: + gethints.remove(k) + for i in gethints: + sol = _desolve(eq, func, ics=ics, hint=i, x0=x0, simplify=simplify, prep=prep, + classify=False, n=terms, order=hints['order'], match=hints[i], type=type) + retdict[i] = sol + retdict['all'] = True + retdict['eq'] = eq + return retdict + elif hint not in allhints: # and hint not in ('default', 'ordered_hints'): + raise ValueError("Hint not recognized: " + hint) + elif hint not in hints: + raise ValueError(string + str(eq) + " does not match hint " + hint) + else: + # Key added to identify the hint needed to solve the equation + hints['hint'] = hint + hints.update({'func': func, 'eq': eq}) + return hints diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/diophantine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23c21242208d6f520c130250ecdce43382b9d868 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/diophantine/__init__.py @@ -0,0 +1,5 @@ +from .diophantine import diophantine, classify_diop, diop_solve + +__all__ = [ + 'diophantine', 'classify_diop', 'diop_solve' +] diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/diophantine/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e18197cb58fcc32d1c792c1938bd771e7e389ef3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/diophantine/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/diophantine.py b/phivenv/Lib/site-packages/sympy/solvers/diophantine/diophantine.py new file mode 100644 index 0000000000000000000000000000000000000000..ffdef6344451c96ed48dff099cf8f02494f4b504 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/diophantine/diophantine.py @@ -0,0 +1,3980 @@ +from __future__ import annotations + +from sympy.core.add import Add +from sympy.core.assumptions import check_assumptions +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.function import _mexpand +from sympy.core.mul import Mul +from sympy.core.numbers import Rational, int_valued +from sympy.core.intfunc import igcdex, ilcm, igcd, integer_nthroot, isqrt +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, symbols +from sympy.core.sympify import _sympify +from sympy.external.gmpy import jacobi, remove, invert, iroot +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.ntheory.factor_ import divisors, factorint, perfect_power +from sympy.ntheory.generate import nextprime +from sympy.ntheory.primetest import is_square, isprime +from sympy.ntheory.modular import symmetric_residue +from sympy.ntheory.residue_ntheory import sqrt_mod, sqrt_mod_iter +from sympy.polys.polyerrors import GeneratorsNeeded +from sympy.polys.polytools import Poly, factor_list +from sympy.simplify.simplify import signsimp +from sympy.solvers.solveset import solveset_real +from sympy.utilities import numbered_symbols +from sympy.utilities.misc import as_int, filldedent +from sympy.utilities.iterables import (is_sequence, subsets, permute_signs, + signed_permutations, ordered_partitions) + + +# these are imported with 'from sympy.solvers.diophantine import * +__all__ = ['diophantine', 'classify_diop'] + + +class DiophantineSolutionSet(set): + """ + Container for a set of solutions to a particular diophantine equation. + + The base representation is a set of tuples representing each of the solutions. + + Parameters + ========== + + symbols : list + List of free symbols in the original equation. + parameters: list + List of parameters to be used in the solution. + + Examples + ======== + + Adding solutions: + + >>> from sympy.solvers.diophantine.diophantine import DiophantineSolutionSet + >>> from sympy.abc import x, y, t, u + >>> s1 = DiophantineSolutionSet([x, y], [t, u]) + >>> s1 + set() + >>> s1.add((2, 3)) + >>> s1.add((-1, u)) + >>> s1 + {(-1, u), (2, 3)} + >>> s2 = DiophantineSolutionSet([x, y], [t, u]) + >>> s2.add((3, 4)) + >>> s1.update(*s2) + >>> s1 + {(-1, u), (2, 3), (3, 4)} + + Conversion of solutions into dicts: + + >>> list(s1.dict_iterator()) + [{x: -1, y: u}, {x: 2, y: 3}, {x: 3, y: 4}] + + Substituting values: + + >>> s3 = DiophantineSolutionSet([x, y], [t, u]) + >>> s3.add((t**2, t + u)) + >>> s3 + {(t**2, t + u)} + >>> s3.subs({t: 2, u: 3}) + {(4, 5)} + >>> s3.subs(t, -1) + {(1, u - 1)} + >>> s3.subs(t, 3) + {(9, u + 3)} + + Evaluation at specific values. Positional arguments are given in the same order as the parameters: + + >>> s3(-2, 3) + {(4, 1)} + >>> s3(5) + {(25, u + 5)} + >>> s3(None, 2) + {(t**2, t + 2)} + """ + + def __init__(self, symbols_seq, parameters): + super().__init__() + + if not is_sequence(symbols_seq): + raise ValueError("Symbols must be given as a sequence.") + + if not is_sequence(parameters): + raise ValueError("Parameters must be given as a sequence.") + + self.symbols = tuple(symbols_seq) + self.parameters = tuple(parameters) + + def add(self, solution): + if len(solution) != len(self.symbols): + raise ValueError("Solution should have a length of %s, not %s" % (len(self.symbols), len(solution))) + # make solution canonical wrt sign (i.e. no -x unless x is also present as an arg) + args = set(solution) + for i in range(len(solution)): + x = solution[i] + if not type(x) is int and (-x).is_Symbol and -x not in args: + solution = [_.subs(-x, x) for _ in solution] + super().add(Tuple(*solution)) + + def update(self, *solutions): + for solution in solutions: + self.add(solution) + + def dict_iterator(self): + for solution in ordered(self): + yield dict(zip(self.symbols, solution)) + + def subs(self, *args, **kwargs): + result = DiophantineSolutionSet(self.symbols, self.parameters) + for solution in self: + result.add(solution.subs(*args, **kwargs)) + return result + + def __call__(self, *args): + if len(args) > len(self.parameters): + raise ValueError("Evaluation should have at most %s values, not %s" % (len(self.parameters), len(args))) + rep = {p: v for p, v in zip(self.parameters, args) if v is not None} + return self.subs(rep) + + +class DiophantineEquationType: + """ + Internal representation of a particular diophantine equation type. + + Parameters + ========== + + equation : + The diophantine equation that is being solved. + free_symbols : list (optional) + The symbols being solved for. + + Attributes + ========== + + total_degree : + The maximum of the degrees of all terms in the equation + homogeneous : + Does the equation contain a term of degree 0 + homogeneous_order : + Does the equation contain any coefficient that is in the symbols being solved for + dimension : + The number of symbols being solved for + """ + name: str + + def __init__(self, equation, free_symbols=None): + self.equation = _sympify(equation).expand(force=True) + + if free_symbols is not None: + self.free_symbols = free_symbols + else: + self.free_symbols = list(self.equation.free_symbols) + self.free_symbols.sort(key=default_sort_key) + + if not self.free_symbols: + raise ValueError('equation should have 1 or more free symbols') + + self.coeff = self.equation.as_coefficients_dict() + if not all(int_valued(c) for c in self.coeff.values()): + raise TypeError("Coefficients should be Integers") + + self.total_degree = Poly(self.equation).total_degree() + self.homogeneous = 1 not in self.coeff + self.homogeneous_order = not (set(self.coeff) & set(self.free_symbols)) + self.dimension = len(self.free_symbols) + self._parameters = None + + def matches(self): + """ + Determine whether the given equation can be matched to the particular equation type. + """ + return False + + @property + def n_parameters(self): + return self.dimension + + @property + def parameters(self): + if self._parameters is None: + self._parameters = symbols('t_:%i' % (self.n_parameters,), integer=True) + return self._parameters + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + raise NotImplementedError('No solver has been written for %s.' % self.name) + + def pre_solve(self, parameters=None): + if not self.matches(): + raise ValueError("This equation does not match the %s equation type." % self.name) + + if parameters is not None: + if len(parameters) != self.n_parameters: + raise ValueError("Expected %s parameter(s) but got %s" % (self.n_parameters, len(parameters))) + + self._parameters = parameters + + +class Univariate(DiophantineEquationType): + """ + Representation of a univariate diophantine equation. + + A univariate diophantine equation is an equation of the form + `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x` is an integer variable. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import Univariate + >>> from sympy.abc import x + >>> Univariate((x - 2)*(x - 3)**2).solve() # solves equation (x - 2)*(x - 3)**2 == 0 + {(2,), (3,)} + + """ + + name = 'univariate' + + def matches(self): + return self.dimension == 1 + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + result = DiophantineSolutionSet(self.free_symbols, parameters=self.parameters) + for i in solveset_real(self.equation, self.free_symbols[0]).intersect(S.Integers): + result.add((i,)) + return result + + +class Linear(DiophantineEquationType): + """ + Representation of a linear diophantine equation. + + A linear diophantine equation is an equation of the form `a_{1}x_{1} + + a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import Linear + >>> from sympy.abc import x, y, z + >>> l1 = Linear(2*x - 3*y - 5) + >>> l1.matches() # is this equation linear + True + >>> l1.solve() # solves equation 2*x - 3*y - 5 == 0 + {(3*t_0 - 5, 2*t_0 - 5)} + + Here x = -3*t_0 - 5 and y = -2*t_0 - 5 + + >>> Linear(2*x - 3*y - 4*z -3).solve() + {(t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3)} + + """ + + name = 'linear' + + def matches(self): + return self.total_degree == 1 + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + coeff = self.coeff + var = self.free_symbols + + if 1 in coeff: + # negate coeff[] because input is of the form: ax + by + c == 0 + # but is used as: ax + by == -c + c = -coeff[1] + else: + c = 0 + + result = DiophantineSolutionSet(var, parameters=self.parameters) + params = result.parameters + + if len(var) == 1: + q, r = divmod(c, coeff[var[0]]) + if not r: + result.add((q,)) + return result + + ''' + base_solution_linear() can solve diophantine equations of the form: + + a*x + b*y == c + + We break down multivariate linear diophantine equations into a + series of bivariate linear diophantine equations which can then + be solved individually by base_solution_linear(). + + Consider the following: + + a_0*x_0 + a_1*x_1 + a_2*x_2 == c + + which can be re-written as: + + a_0*x_0 + g_0*y_0 == c + + where + + g_0 == gcd(a_1, a_2) + + and + + y == (a_1*x_1)/g_0 + (a_2*x_2)/g_0 + + This leaves us with two binary linear diophantine equations. + For the first equation: + + a == a_0 + b == g_0 + c == c + + For the second: + + a == a_1/g_0 + b == a_2/g_0 + c == the solution we find for y_0 in the first equation. + + The arrays A and B are the arrays of integers used for + 'a' and 'b' in each of the n-1 bivariate equations we solve. + ''' + + A = [coeff[v] for v in var] + B = [] + if len(var) > 2: + B.append(igcd(A[-2], A[-1])) + A[-2] = A[-2] // B[0] + A[-1] = A[-1] // B[0] + for i in range(len(A) - 3, 0, -1): + gcd = igcd(B[0], A[i]) + B[0] = B[0] // gcd + A[i] = A[i] // gcd + B.insert(0, gcd) + B.append(A[-1]) + + ''' + Consider the trivariate linear equation: + + 4*x_0 + 6*x_1 + 3*x_2 == 2 + + This can be re-written as: + + 4*x_0 + 3*y_0 == 2 + + where + + y_0 == 2*x_1 + x_2 + (Note that gcd(3, 6) == 3) + + The complete integral solution to this equation is: + + x_0 == 2 + 3*t_0 + y_0 == -2 - 4*t_0 + + where 't_0' is any integer. + + Now that we have a solution for 'x_0', find 'x_1' and 'x_2': + + 2*x_1 + x_2 == -2 - 4*t_0 + + We can then solve for '-2' and '-4' independently, + and combine the results: + + 2*x_1a + x_2a == -2 + x_1a == 0 + t_0 + x_2a == -2 - 2*t_0 + + 2*x_1b + x_2b == -4*t_0 + x_1b == 0*t_0 + t_1 + x_2b == -4*t_0 - 2*t_1 + + ==> + + x_1 == t_0 + t_1 + x_2 == -2 - 6*t_0 - 2*t_1 + + where 't_0' and 't_1' are any integers. + + Note that: + + 4*(2 + 3*t_0) + 6*(t_0 + t_1) + 3*(-2 - 6*t_0 - 2*t_1) == 2 + + for any integral values of 't_0', 't_1'; as required. + + This method is generalised for many variables, below. + + ''' + solutions = [] + for Ai, Bi in zip(A, B): + tot_x, tot_y = [], [] + + for arg in Add.make_args(c): + if arg.is_Integer: + # example: 5 -> k = 5 + k, p = arg, S.One + pnew = params[0] + else: # arg is a Mul or Symbol + # example: 3*t_1 -> k = 3 + # example: t_0 -> k = 1 + k, p = arg.as_coeff_Mul() + pnew = params[params.index(p) + 1] + + sol = sol_x, sol_y = base_solution_linear(k, Ai, Bi, pnew) + + if p is S.One: + if None in sol: + return result + else: + # convert a + b*pnew -> a*p + b*pnew + if isinstance(sol_x, Add): + sol_x = sol_x.args[0]*p + sol_x.args[1] + if isinstance(sol_y, Add): + sol_y = sol_y.args[0]*p + sol_y.args[1] + + tot_x.append(sol_x) + tot_y.append(sol_y) + + solutions.append(Add(*tot_x)) + c = Add(*tot_y) + + solutions.append(c) + result.add(solutions) + return result + + +class BinaryQuadratic(DiophantineEquationType): + """ + Representation of a binary quadratic diophantine equation. + + A binary quadratic diophantine equation is an equation of the + form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`, where `A, B, C, D, E, + F` are integer constants and `x` and `y` are integer variables. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import BinaryQuadratic + >>> b1 = BinaryQuadratic(x**3 + y**2 + 1) + >>> b1.matches() + False + >>> b2 = BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2) + >>> b2.matches() + True + >>> b2.solve() + {(-1, -1)} + + References + ========== + + .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online], + Available: https://www.alpertron.com.ar/METHODS.HTM + .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online], + Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + """ + + name = 'binary_quadratic' + + def matches(self): + return self.total_degree == 2 and self.dimension == 2 + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + x, y = var + + A = coeff[x**2] + B = coeff[x*y] + C = coeff[y**2] + D = coeff[x] + E = coeff[y] + F = coeff[S.One] + + A, B, C, D, E, F = [as_int(i) for i in _remove_gcd(A, B, C, D, E, F)] + + # (1) Simple-Hyperbolic case: A = C = 0, B != 0 + # In this case equation can be converted to (Bx + E)(By + D) = DE - BF + # We consider two cases; DE - BF = 0 and DE - BF != 0 + # More details, https://www.alpertron.com.ar/METHODS.HTM#SHyperb + + result = DiophantineSolutionSet(var, self.parameters) + t, u = result.parameters + + discr = B**2 - 4*A*C + if A == 0 and C == 0 and B != 0: + + if D*E - B*F == 0: + q, r = divmod(E, B) + if not r: + result.add((-q, t)) + q, r = divmod(D, B) + if not r: + result.add((t, -q)) + else: + div = divisors(D*E - B*F) + div = div + [-term for term in div] + for d in div: + x0, r = divmod(d - E, B) + if not r: + q, r = divmod(D*E - B*F, d) + if not r: + y0, r = divmod(q - D, B) + if not r: + result.add((x0, y0)) + + # (2) Parabolic case: B**2 - 4*A*C = 0 + # There are two subcases to be considered in this case. + # sqrt(c)D - sqrt(a)E = 0 and sqrt(c)D - sqrt(a)E != 0 + # More Details, https://www.alpertron.com.ar/METHODS.HTM#Parabol + + elif discr == 0: + + if A == 0: + s = BinaryQuadratic(self.equation, free_symbols=[y, x]).solve(parameters=[t, u]) + for soln in s: + result.add((soln[1], soln[0])) + + else: + g = sign(A)*igcd(A, C) + a = A // g + c = C // g + e = sign(B / A) + + sqa = isqrt(a) + sqc = isqrt(c) + _c = e*sqc*D - sqa*E + if not _c: + z = Symbol("z", real=True) + eq = sqa*g*z**2 + D*z + sqa*F + roots = solveset_real(eq, z).intersect(S.Integers) + for root in roots: + ans = diop_solve(sqa*x + e*sqc*y - root) + result.add((ans[0], ans[1])) + + elif int_valued(c): + solve_x = lambda u: -e*sqc*g*_c*t**2 - (E + 2*e*sqc*g*u)*t \ + - (e*sqc*g*u**2 + E*u + e*sqc*F) // _c + + solve_y = lambda u: sqa*g*_c*t**2 + (D + 2*sqa*g*u)*t \ + + (sqa*g*u**2 + D*u + sqa*F) // _c + + for z0 in range(0, abs(_c)): + # Check if the coefficients of y and x obtained are integers or not + if (divisible(sqa*g*z0**2 + D*z0 + sqa*F, _c) and + divisible(e*sqc*g*z0**2 + E*z0 + e*sqc*F, _c)): + result.add((solve_x(z0), solve_y(z0))) + + # (3) Method used when B**2 - 4*A*C is a square, is described in p. 6 of the below paper + # by John P. Robertson. + # https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + elif is_square(discr): + if A != 0: + r = sqrt(discr) + u, v = symbols("u, v", integer=True) + eq = _mexpand( + 4*A*r*u*v + 4*A*D*(B*v + r*u + r*v - B*u) + + 2*A*4*A*E*(u - v) + 4*A*r*4*A*F) + + solution = diop_solve(eq, t) + + for s0, t0 in solution: + + num = B*t0 + r*s0 + r*t0 - B*s0 + x_0 = S(num) / (4*A*r) + y_0 = S(s0 - t0) / (2*r) + if isinstance(s0, Symbol) or isinstance(t0, Symbol): + if len(check_param(x_0, y_0, 4*A*r, parameters)) > 0: + ans = check_param(x_0, y_0, 4*A*r, parameters) + result.update(*ans) + elif x_0.is_Integer and y_0.is_Integer: + if is_solution_quad(var, coeff, x_0, y_0): + result.add((x_0, y_0)) + + else: + s = BinaryQuadratic(self.equation, free_symbols=var[::-1]).solve(parameters=[t, u]) # Interchange x and y + while s: + result.add(s.pop()[::-1]) # and solution <--------+ + + # (4) B**2 - 4*A*C > 0 and B**2 - 4*A*C not a square or B**2 - 4*A*C < 0 + + else: + + P, Q = _transformation_to_DN(var, coeff) + D, N = _find_DN(var, coeff) + solns_pell = diop_DN(D, N) + + if D < 0: + for x0, y0 in solns_pell: + for x in [-x0, x0]: + for y in [-y0, y0]: + s = P*Matrix([x, y]) + Q + try: + result.add([as_int(_) for _ in s]) + except ValueError: + pass + else: + # In this case equation can be transformed into a Pell equation + + solns_pell = set(solns_pell) + solns_pell.update((-X, -Y) for X, Y in list(solns_pell)) + + a = diop_DN(D, 1) + T = a[0][0] + U = a[0][1] + + if all(int_valued(_) for _ in P[:4] + Q[:2]): + for r, s in solns_pell: + _a = (r + s*sqrt(D))*(T + U*sqrt(D))**t + _b = (r - s*sqrt(D))*(T - U*sqrt(D))**t + x_n = _mexpand(S(_a + _b) / 2) + y_n = _mexpand(S(_a - _b) / (2*sqrt(D))) + s = P*Matrix([x_n, y_n]) + Q + result.add(s) + + else: + L = ilcm(*[_.q for _ in P[:4] + Q[:2]]) + + k = 1 + + T_k = T + U_k = U + + while (T_k - 1) % L != 0 or U_k % L != 0: + T_k, U_k = T_k*T + D*U_k*U, T_k*U + U_k*T + k += 1 + + for X, Y in solns_pell: + + for i in range(k): + if all(int_valued(_) for _ in P*Matrix([X, Y]) + Q): + _a = (X + sqrt(D)*Y)*(T_k + sqrt(D)*U_k)**t + _b = (X - sqrt(D)*Y)*(T_k - sqrt(D)*U_k)**t + Xt = S(_a + _b) / 2 + Yt = S(_a - _b) / (2*sqrt(D)) + s = P*Matrix([Xt, Yt]) + Q + result.add(s) + + X, Y = X*T + D*U*Y, X*U + Y*T + + return result + + +class InhomogeneousTernaryQuadratic(DiophantineEquationType): + """ + + Representation of an inhomogeneous ternary quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'inhomogeneous_ternary_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + return not self.homogeneous_order + + +class HomogeneousTernaryQuadraticNormal(DiophantineEquationType): + """ + Representation of a homogeneous ternary quadratic normal diophantine equation. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadraticNormal + >>> HomogeneousTernaryQuadraticNormal(4*x**2 - 5*y**2 + z**2).solve() + {(1, 2, 4)} + + """ + + name = 'homogeneous_ternary_quadratic_normal' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + if not self.homogeneous_order: + return False + + nonzero = [k for k in self.coeff if self.coeff[k]] + return len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols) + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + x, y, z = var + + a = coeff[x**2] + b = coeff[y**2] + c = coeff[z**2] + + (sqf_of_a, sqf_of_b, sqf_of_c), (a_1, b_1, c_1), (a_2, b_2, c_2) = \ + sqf_normal(a, b, c, steps=True) + + A = -a_2*c_2 + B = -b_2*c_2 + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + # If following two conditions are satisfied then there are no solutions + if A < 0 and B < 0: + return result + + if ( + sqrt_mod(-b_2*c_2, a_2) is None or + sqrt_mod(-c_2*a_2, b_2) is None or + sqrt_mod(-a_2*b_2, c_2) is None): + return result + + z_0, x_0, y_0 = descent(A, B) + + z_0, q = _rational_pq(z_0, abs(c_2)) + x_0 *= q + y_0 *= q + + x_0, y_0, z_0 = _remove_gcd(x_0, y_0, z_0) + + # Holzer reduction + if sign(a) == sign(b): + x_0, y_0, z_0 = holzer(x_0, y_0, z_0, abs(a_2), abs(b_2), abs(c_2)) + elif sign(a) == sign(c): + x_0, z_0, y_0 = holzer(x_0, z_0, y_0, abs(a_2), abs(c_2), abs(b_2)) + else: + y_0, z_0, x_0 = holzer(y_0, z_0, x_0, abs(b_2), abs(c_2), abs(a_2)) + + x_0 = reconstruct(b_1, c_1, x_0) + y_0 = reconstruct(a_1, c_1, y_0) + z_0 = reconstruct(a_1, b_1, z_0) + + sq_lcm = ilcm(sqf_of_a, sqf_of_b, sqf_of_c) + + x_0 = abs(x_0*sq_lcm // sqf_of_a) + y_0 = abs(y_0*sq_lcm // sqf_of_b) + z_0 = abs(z_0*sq_lcm // sqf_of_c) + + result.add(_remove_gcd(x_0, y_0, z_0)) + return result + + +class HomogeneousTernaryQuadratic(DiophantineEquationType): + """ + Representation of a homogeneous ternary quadratic diophantine equation. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadratic + >>> HomogeneousTernaryQuadratic(x**2 + y**2 - 3*z**2 + x*y).solve() + {(-1, 2, 1)} + >>> HomogeneousTernaryQuadratic(3*x**2 + y**2 - 3*z**2 + 5*x*y + y*z).solve() + {(3, 12, 13)} + + """ + + name = 'homogeneous_ternary_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + if not self.homogeneous_order: + return False + + nonzero = [k for k in self.coeff if self.coeff[k]] + return not (len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols)) + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + _var = self.free_symbols + coeff = self.coeff + + x, y, z = _var + var = [x, y, z] + + # Equations of the form B*x*y + C*z*x + E*y*z = 0 and At least two of the + # coefficients A, B, C are non-zero. + # There are infinitely many solutions for the equation. + # Ex: (0, 0, t), (0, t, 0), (t, 0, 0) + # Equation can be re-written as y*(B*x + E*z) = -C*x*z and we can find rather + # unobvious solutions. Set y = -C and B*x + E*z = x*z. The latter can be solved by + # using methods for binary quadratic diophantine equations. Let's select the + # solution which minimizes |x| + |z| + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + def unpack_sol(sol): + if len(sol) > 0: + return list(sol)[0] + return None, None, None + + if not any(coeff[i**2] for i in var): + if coeff[x*z]: + sols = diophantine(coeff[x*y]*x + coeff[y*z]*z - x*z) + s = min(sols, key=lambda r: abs(r[0]) + abs(r[1])) + result.add(_remove_gcd(s[0], -coeff[x*z], s[1])) + return result + + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + if x_0 is not None: + result.add((x_0, y_0, z_0)) + return result + + if coeff[x**2] == 0: + # If the coefficient of x is zero change the variables + if coeff[y**2] == 0: + var[0], var[2] = _var[2], _var[0] + z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + if coeff[x*y] or coeff[x*z]: + # Apply the transformation x --> X - (B*y + C*z)/(2*A) + A = coeff[x**2] + B = coeff[x*y] + C = coeff[x*z] + D = coeff[y**2] + E = coeff[y*z] + F = coeff[z**2] + + _coeff = {} + + _coeff[x**2] = 4*A**2 + _coeff[y**2] = 4*A*D - B**2 + _coeff[z**2] = 4*A*F - C**2 + _coeff[y*z] = 4*A*E - 2*B*C + _coeff[x*y] = 0 + _coeff[x*z] = 0 + + x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, _coeff)) + + if x_0 is None: + return result + + p, q = _rational_pq(B*y_0 + C*z_0, 2*A) + x_0, y_0, z_0 = x_0*q - p, y_0*q, z_0*q + + elif coeff[z*y] != 0: + if coeff[y**2] == 0: + if coeff[z**2] == 0: + # Equations of the form A*x**2 + E*yz = 0. + A = coeff[x**2] + E = coeff[y*z] + + b, a = _rational_pq(-E, A) + + x_0, y_0, z_0 = b, a, b + + else: + # Ax**2 + E*y*z + F*z**2 = 0 + var[0], var[2] = _var[2], _var[0] + z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, C may be zero + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + # Ax**2 + D*y**2 + F*z**2 = 0, C may be zero + x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic_normal(var, coeff)) + + if x_0 is None: + return result + + result.add(_remove_gcd(x_0, y_0, z_0)) + return result + + +class InhomogeneousGeneralQuadratic(DiophantineEquationType): + """ + + Representation of an inhomogeneous general quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'inhomogeneous_general_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return True + # there may be Pow keys like x**2 or Mul keys like x*y + return any(k.is_Mul for k in self.coeff) and not self.homogeneous + + +class HomogeneousGeneralQuadratic(DiophantineEquationType): + """ + + Representation of a homogeneous general quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'homogeneous_general_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + # there may be Pow keys like x**2 or Mul keys like x*y + return any(k.is_Mul for k in self.coeff) and self.homogeneous + + +class GeneralSumOfSquares(DiophantineEquationType): + r""" + Representation of the diophantine equation + + `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Details + ======= + + When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be + no solutions. Refer [1]_ for more details. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfSquares + >>> from sympy.abc import a, b, c, d, e + >>> GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve() + {(15, 22, 22, 24, 24)} + + By default only 1 solution is returned. Use the `limit` keyword for more: + + >>> sorted(GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve(limit=3)) + [(15, 22, 22, 24, 24), (16, 19, 24, 24, 24), (16, 20, 22, 23, 26)] + + References + ========== + + .. [1] Representing an integer as a sum of three squares, [online], + Available: + https://proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares + """ + + name = 'general_sum_of_squares' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + if any(k.is_Mul for k in self.coeff): + return False + return all(self.coeff[k] == 1 for k in self.coeff if k != 1) + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + var = self.free_symbols + k = -int(self.coeff[1]) + n = self.dimension + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + if k < 0 or limit < 1: + return result + + signs = [-1 if x.is_nonpositive else 1 for x in var] + negs = signs.count(-1) != 0 + + took = 0 + for t in sum_of_squares(k, n, zeros=True): + if negs: + result.add([signs[i]*j for i, j in enumerate(t)]) + else: + result.add(t) + took += 1 + if took == limit: + break + return result + + +class GeneralPythagorean(DiophantineEquationType): + """ + Representation of the general pythagorean equation, + `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralPythagorean + >>> from sympy.abc import a, b, c, d, e, x, y, z, t + >>> GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve() + {(t_0**2 + t_1**2 - t_2**2, 2*t_0*t_2, 2*t_1*t_2, t_0**2 + t_1**2 + t_2**2)} + >>> GeneralPythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2).solve(parameters=[x, y, z, t]) + {(-10*t**2 + 10*x**2 + 10*y**2 + 10*z**2, 15*t**2 + 15*x**2 + 15*y**2 + 15*z**2, 15*t*x, 12*t*y, 60*t*z)} + """ + + name = 'general_pythagorean' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + if any(k.is_Mul for k in self.coeff): + return False + if all(self.coeff[k] == 1 for k in self.coeff if k != 1): + return False + if not all(is_square(abs(self.coeff[k])) for k in self.coeff): + return False + # all but one has the same sign + # e.g. 4*x**2 + y**2 - 4*z**2 + return abs(sum(sign(self.coeff[k]) for k in self.coeff)) == self.dimension - 2 + + @property + def n_parameters(self): + return self.dimension - 1 + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + coeff = self.coeff + var = self.free_symbols + n = self.dimension + + if sign(coeff[var[0] ** 2]) + sign(coeff[var[1] ** 2]) + sign(coeff[var[2] ** 2]) < 0: + for key in coeff.keys(): + coeff[key] = -coeff[key] + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + index = 0 + + for i, v in enumerate(var): + if sign(coeff[v ** 2]) == -1: + index = i + + m = result.parameters + + ith = sum(m_i ** 2 for m_i in m) + L = [ith - 2 * m[n - 2] ** 2] + L.extend([2 * m[i] * m[n - 2] for i in range(n - 2)]) + sol = L[:index] + [ith] + L[index:] + + lcm = 1 + for i, v in enumerate(var): + if i == index or (index > 0 and i == 0) or (index == 0 and i == 1): + lcm = ilcm(lcm, sqrt(abs(coeff[v ** 2]))) + else: + s = sqrt(coeff[v ** 2]) + lcm = ilcm(lcm, s if _odd(s) else s // 2) + + for i, v in enumerate(var): + sol[i] = (lcm * sol[i]) / sqrt(abs(coeff[v ** 2])) + + result.add(sol) + return result + + +class CubicThue(DiophantineEquationType): + """ + Representation of a cubic Thue diophantine equation. + + A cubic Thue diophantine equation is a polynomial of the form + `f(x, y) = r` of degree 3, where `x` and `y` are integers + and `r` is a rational number. + + No solver is currently implemented for this equation type. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import CubicThue + >>> c1 = CubicThue(x**3 + y**2 + 1) + >>> c1.matches() + True + + """ + + name = 'cubic_thue' + + def matches(self): + return self.total_degree == 3 and self.dimension == 2 + + +class GeneralSumOfEvenPowers(DiophantineEquationType): + """ + Representation of the diophantine equation + + `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0` + + where `e` is an even, integer power. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfEvenPowers + >>> from sympy.abc import a, b + >>> GeneralSumOfEvenPowers(a**4 + b**4 - (2**4 + 3**4)).solve() + {(2, 3)} + + """ + + name = 'general_sum_of_even_powers' + + def matches(self): + if not self.total_degree > 3: + return False + if self.total_degree % 2 != 0: + return False + if not all(k.is_Pow and k.exp == self.total_degree for k in self.coeff if k != 1): + return False + return all(self.coeff[k] == 1 for k in self.coeff if k != 1) + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + p = None + for q in coeff.keys(): + if q.is_Pow and coeff[q]: + p = q.exp + + k = len(var) + n = -coeff[1] + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + if n < 0 or limit < 1: + return result + + sign = [-1 if x.is_nonpositive else 1 for x in var] + negs = sign.count(-1) != 0 + + took = 0 + for t in power_representation(n, p, k): + if negs: + result.add([sign[i]*j for i, j in enumerate(t)]) + else: + result.add(t) + took += 1 + if took == limit: + break + return result + +# these types are known (but not necessarily handled) +# note that order is important here (in the current solver state) +all_diop_classes = [ + Linear, + Univariate, + BinaryQuadratic, + InhomogeneousTernaryQuadratic, + HomogeneousTernaryQuadraticNormal, + HomogeneousTernaryQuadratic, + InhomogeneousGeneralQuadratic, + HomogeneousGeneralQuadratic, + GeneralSumOfSquares, + GeneralPythagorean, + CubicThue, + GeneralSumOfEvenPowers, +] + +diop_known = {diop_class.name for diop_class in all_diop_classes} + + +def _remove_gcd(*x): + try: + g = igcd(*x) + except ValueError: + fx = list(filter(None, x)) + if len(fx) < 2: + return x + g = igcd(*[i.as_content_primitive()[0] for i in fx]) + except TypeError: + raise TypeError('_remove_gcd(a,b,c) or _remove_gcd(*container)') + if g == 1: + return x + return tuple([i//g for i in x]) + + +def _rational_pq(a, b): + # return `(numer, denom)` for a/b; sign in numer and gcd removed + return _remove_gcd(sign(b)*a, abs(b)) + + +def _nint_or_floor(p, q): + # return nearest int to p/q; in case of tie return floor(p/q) + w, r = divmod(p, q) + if abs(r) <= abs(q)//2: + return w + return w + 1 + + +def _odd(i): + return i % 2 != 0 + + +def _even(i): + return i % 2 == 0 + + +def diophantine(eq, param=symbols("t", integer=True), syms=None, + permute=False): + """ + Simplify the solution procedure of diophantine equation ``eq`` by + converting it into a product of terms which should equal zero. + + Explanation + =========== + + For example, when solving, `x^2 - y^2 = 0` this is treated as + `(x + y)(x - y) = 0` and `x + y = 0` and `x - y = 0` are solved + independently and combined. Each term is solved by calling + ``diop_solve()``. (Although it is possible to call ``diop_solve()`` + directly, one must be careful to pass an equation in the correct + form and to interpret the output correctly; ``diophantine()`` is + the public-facing function to use in general.) + + Output of ``diophantine()`` is a set of tuples. The elements of the + tuple are the solutions for each variable in the equation and + are arranged according to the alphabetic ordering of the variables. + e.g. For an equation with two variables, `a` and `b`, the first + element of the tuple is the solution for `a` and the second for `b`. + + Usage + ===== + + ``diophantine(eq, t, syms)``: Solve the diophantine + equation ``eq``. + ``t`` is the optional parameter to be used by ``diop_solve()``. + ``syms`` is an optional list of symbols which determines the + order of the elements in the returned tuple. + + By default, only the base solution is returned. If ``permute`` is set to + True then permutations of the base solution and/or permutations of the + signs of the values will be returned when applicable. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``t`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy import diophantine + >>> from sympy.abc import a, b + >>> eq = a**4 + b**4 - (2**4 + 3**4) + >>> diophantine(eq) + {(2, 3)} + >>> diophantine(eq, permute=True) + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + + >>> from sympy.abc import x, y, z + >>> diophantine(x**2 - y**2) + {(t_0, -t_0), (t_0, t_0)} + + >>> diophantine(x*(2*x + 3*y - z)) + {(0, n1, n2), (t_0, t_1, 2*t_0 + 3*t_1)} + >>> diophantine(x**2 + 3*x*y + 4*x) + {(0, n1), (-3*t_0 - 4, t_0)} + + See Also + ======== + + diop_solve + sympy.utilities.iterables.permute_signs + sympy.utilities.iterables.signed_permutations + """ + + eq = _sympify(eq) + + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + + try: + var = list(eq.expand(force=True).free_symbols) + var.sort(key=default_sort_key) + if syms: + if not is_sequence(syms): + raise TypeError( + 'syms should be given as a sequence, e.g. a list') + syms = [i for i in syms if i in var] + if syms != var: + dict_sym_index = dict(zip(syms, range(len(syms)))) + return {tuple([t[dict_sym_index[i]] for i in var]) + for t in diophantine(eq, param, permute=permute)} + n, d = eq.as_numer_denom() + if n.is_number: + return set() + if not d.is_number: + dsol = diophantine(d) + good = diophantine(n) - dsol + return {s for s in good if _mexpand(d.subs(zip(var, s)))} + eq = factor_terms(n) + assert not eq.is_number + eq = eq.as_independent(*var, as_Add=False)[1] + p = Poly(eq) + assert not any(g.is_number for g in p.gens) + eq = p.as_expr() + assert eq.is_polynomial() + except (GeneratorsNeeded, AssertionError): + raise TypeError(filldedent(''' + Equation should be a polynomial with Rational coefficients.''')) + + # permute only sign + do_permute_signs = False + # permute sign and values + do_permute_signs_var = False + # permute few signs + permute_few_signs = False + try: + # if we know that factoring should not be attempted, skip + # the factoring step + v, c, t = classify_diop(eq) + + # check for permute sign + if permute: + len_var = len(v) + permute_signs_for = [ + GeneralSumOfSquares.name, + GeneralSumOfEvenPowers.name] + permute_signs_check = [ + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name, + BinaryQuadratic.name] + if t in permute_signs_for: + do_permute_signs_var = True + elif t in permute_signs_check: + # if all the variables in eq have even powers + # then do_permute_sign = True + if len_var == 3: + var_mul = list(subsets(v, 2)) + # here var_mul is like [(x, y), (x, z), (y, z)] + xy_coeff = True + x_coeff = True + var1_mul_var2 = (a[0]*a[1] for a in var_mul) + # if coeff(y*z), coeff(y*x), coeff(x*z) is not 0 then + # `xy_coeff` => True and do_permute_sign => False. + # Means no permuted solution. + for v1_mul_v2 in var1_mul_var2: + try: + coeff = c[v1_mul_v2] + except KeyError: + coeff = 0 + xy_coeff = bool(xy_coeff) and bool(coeff) + var_mul = list(subsets(v, 1)) + # here var_mul is like [(x,), (y, )] + for v1 in var_mul: + try: + coeff = c[v1[0]] + except KeyError: + coeff = 0 + x_coeff = bool(x_coeff) and bool(coeff) + if not any((xy_coeff, x_coeff)): + # means only x**2, y**2, z**2, const is present + do_permute_signs = True + elif not x_coeff: + permute_few_signs = True + elif len_var == 2: + var_mul = list(subsets(v, 2)) + # here var_mul is like [(x, y)] + xy_coeff = True + x_coeff = True + var1_mul_var2 = (x[0]*x[1] for x in var_mul) + for v1_mul_v2 in var1_mul_var2: + try: + coeff = c[v1_mul_v2] + except KeyError: + coeff = 0 + xy_coeff = bool(xy_coeff) and bool(coeff) + var_mul = list(subsets(v, 1)) + # here var_mul is like [(x,), (y, )] + for v1 in var_mul: + try: + coeff = c[v1[0]] + except KeyError: + coeff = 0 + x_coeff = bool(x_coeff) and bool(coeff) + if not any((xy_coeff, x_coeff)): + # means only x**2, y**2 and const is present + # so we can get more soln by permuting this soln. + do_permute_signs = True + elif not x_coeff: + # when coeff(x), coeff(y) is not present then signs of + # x, y can be permuted such that their sign are same + # as sign of x*y. + # e.g 1. (x_val,y_val)=> (x_val,y_val), (-x_val,-y_val) + # 2. (-x_vall, y_val)=> (-x_val,y_val), (x_val,-y_val) + permute_few_signs = True + if t == 'general_sum_of_squares': + # trying to factor such expressions will sometimes hang + terms = [(eq, 1)] + else: + raise TypeError + except (TypeError, NotImplementedError): + fl = factor_list(eq) + if fl[0].is_Rational and fl[0] != 1: + return diophantine(eq/fl[0], param=param, syms=syms, permute=permute) + terms = fl[1] + + sols = set() + + for term in terms: + + base, _ = term + var_t, _, eq_type = classify_diop(base, _dict=False) + _, base = signsimp(base, evaluate=False).as_coeff_Mul() + solution = diop_solve(base, param) + + if eq_type in [ + Linear.name, + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name, + GeneralPythagorean.name]: + sols.add(merge_solution(var, var_t, solution)) + + elif eq_type in [ + BinaryQuadratic.name, + GeneralSumOfSquares.name, + GeneralSumOfEvenPowers.name, + Univariate.name]: + sols.update(merge_solution(var, var_t, sol) for sol in solution) + + else: + raise NotImplementedError('unhandled type: %s' % eq_type) + + sols.discard(()) + null = tuple([0]*len(var)) + # if there is no solution, return trivial solution + if not sols and eq.subs(zip(var, null)).is_zero: + if all(check_assumptions(val, **s.assumptions0) is not False for val, s in zip(null, var)): + sols.add(null) + + final_soln = set() + for sol in sols: + if all(int_valued(s) for s in sol): + if do_permute_signs: + permuted_sign = set(permute_signs(sol)) + final_soln.update(permuted_sign) + elif permute_few_signs: + lst = list(permute_signs(sol)) + lst = list(filter(lambda x: x[0]*x[1] == sol[1]*sol[0], lst)) + permuted_sign = set(lst) + final_soln.update(permuted_sign) + elif do_permute_signs_var: + permuted_sign_var = set(signed_permutations(sol)) + final_soln.update(permuted_sign_var) + else: + final_soln.add(sol) + else: + final_soln.add(sol) + return final_soln + + +def merge_solution(var, var_t, solution): + """ + This is used to construct the full solution from the solutions of sub + equations. + + Explanation + =========== + + For example when solving the equation `(x - y)(x^2 + y^2 - z^2) = 0`, + solutions for each of the equations `x - y = 0` and `x^2 + y^2 - z^2` are + found independently. Solutions for `x - y = 0` are `(x, y) = (t, t)`. But + we should introduce a value for z when we output the solution for the + original equation. This function converts `(t, t)` into `(t, t, n_{1})` + where `n_{1}` is an integer parameter. + """ + sol = [] + + if None in solution: + return () + + solution = iter(solution) + params = numbered_symbols("n", integer=True, start=1) + for v in var: + if v in var_t: + sol.append(next(solution)) + else: + sol.append(next(params)) + + for val, symb in zip(sol, var): + if check_assumptions(val, **symb.assumptions0) is False: + return () + + return tuple(sol) + + +def _diop_solve(eq, params=None): + for diop_type in all_diop_classes: + if diop_type(eq).matches(): + return diop_type(eq).solve(parameters=params) + + +def diop_solve(eq, param=symbols("t", integer=True)): + """ + Solves the diophantine equation ``eq``. + + Explanation + =========== + + Unlike ``diophantine()``, factoring of ``eq`` is not attempted. Uses + ``classify_diop()`` to determine the type of the equation and calls + the appropriate solver function. + + Use of ``diophantine()`` is recommended over other helper functions. + ``diop_solve()`` can return either a set or a tuple depending on the + nature of the equation. All non-trivial solutions are returned: assumptions + on symbols are ignored. + + Usage + ===== + + ``diop_solve(eq, t)``: Solve diophantine equation, ``eq`` using ``t`` + as a parameter if needed. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``t`` is a parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine import diop_solve + >>> from sympy.abc import x, y, z, w + >>> diop_solve(2*x + 3*y - 5) + (3*t_0 - 5, 5 - 2*t_0) + >>> diop_solve(4*x + 3*y - 4*z + 5) + (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5) + >>> diop_solve(x + 3*y - 4*z + w - 6) + (t_0, t_0 + t_1, 6*t_0 + 5*t_1 + 4*t_2 - 6, 5*t_0 + 4*t_1 + 3*t_2 - 6) + >>> diop_solve(x**2 + y**2 - 5) + {(-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1)} + + + See Also + ======== + + diophantine() + """ + var, coeff, eq_type = classify_diop(eq, _dict=False) + + if eq_type == Linear.name: + return diop_linear(eq, param) + + elif eq_type == BinaryQuadratic.name: + return diop_quadratic(eq, param) + + elif eq_type == HomogeneousTernaryQuadratic.name: + return diop_ternary_quadratic(eq, parameterize=True) + + elif eq_type == HomogeneousTernaryQuadraticNormal.name: + return diop_ternary_quadratic_normal(eq, parameterize=True) + + elif eq_type == GeneralPythagorean.name: + return diop_general_pythagorean(eq, param) + + elif eq_type == Univariate.name: + return diop_univariate(eq) + + elif eq_type == GeneralSumOfSquares.name: + return diop_general_sum_of_squares(eq, limit=S.Infinity) + + elif eq_type == GeneralSumOfEvenPowers.name: + return diop_general_sum_of_even_powers(eq, limit=S.Infinity) + + if eq_type is not None and eq_type not in diop_known: + raise ValueError(filldedent(''' + Although this type of equation was identified, it is not yet + handled. It should, however, be listed in `diop_known` at the + top of this file. Developers should see comments at the end of + `classify_diop`. + ''')) # pragma: no cover + else: + raise NotImplementedError( + 'No solver has been written for %s.' % eq_type) + + +def classify_diop(eq, _dict=True): + # docstring supplied externally + + matched = False + diop_type = None + for diop_class in all_diop_classes: + diop_type = diop_class(eq) + if diop_type.matches(): + matched = True + break + + if matched: + return diop_type.free_symbols, dict(diop_type.coeff) if _dict else diop_type.coeff, diop_type.name + + # new diop type instructions + # -------------------------- + # if this error raises and the equation *can* be classified, + # * it should be identified in the if-block above + # * the type should be added to the diop_known + # if a solver can be written for it, + # * a dedicated handler should be written (e.g. diop_linear) + # * it should be passed to that handler in diop_solve + raise NotImplementedError(filldedent(''' + This equation is not yet recognized or else has not been + simplified sufficiently to put it in a form recognized by + diop_classify().''')) + + +classify_diop.func_doc = ( # type: ignore + ''' + Helper routine used by diop_solve() to find information about ``eq``. + + Explanation + =========== + + Returns a tuple containing the type of the diophantine equation + along with the variables (free symbols) and their coefficients. + Variables are returned as a list and coefficients are returned + as a dict with the key being the respective term and the constant + term is keyed to 1. The type is one of the following: + + * %s + + Usage + ===== + + ``classify_diop(eq)``: Return variables, coefficients and type of the + ``eq``. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``_dict`` is for internal use: when True (default) a dict is returned, + otherwise a defaultdict which supplies 0 for missing keys is returned. + + Examples + ======== + + >>> from sympy.solvers.diophantine import classify_diop + >>> from sympy.abc import x, y, z, w, t + >>> classify_diop(4*x + 6*y - 4) + ([x, y], {1: -4, x: 4, y: 6}, 'linear') + >>> classify_diop(x + 3*y -4*z + 5) + ([x, y, z], {1: 5, x: 1, y: 3, z: -4}, 'linear') + >>> classify_diop(x**2 + y**2 - x*y + x + 5) + ([x, y], {1: 5, x: 1, x**2: 1, y**2: 1, x*y: -1}, 'binary_quadratic') + ''' % ('\n * '.join(sorted(diop_known)))) + + +def diop_linear(eq, param=symbols("t", integer=True)): + """ + Solves linear diophantine equations. + + A linear diophantine equation is an equation of the form `a_{1}x_{1} + + a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables. + + Usage + ===== + + ``diop_linear(eq)``: Returns a tuple containing solutions to the + diophantine equation ``eq``. Values in the tuple is arranged in the same + order as the sorted variables. + + Details + ======= + + ``eq`` is a linear diophantine equation which is assumed to be zero. + ``param`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_linear + >>> from sympy.abc import x, y, z + >>> diop_linear(2*x - 3*y - 5) # solves equation 2*x - 3*y - 5 == 0 + (3*t_0 - 5, 2*t_0 - 5) + + Here x = -3*t_0 - 5 and y = -2*t_0 - 5 + + >>> diop_linear(2*x - 3*y - 4*z -3) + (t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3) + + See Also + ======== + + diop_quadratic(), diop_ternary_quadratic(), diop_general_pythagorean(), + diop_general_sum_of_squares() + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == Linear.name: + parameters = None + if param is not None: + parameters = symbols('%s_0:%i' % (param, len(var)), integer=True) + + result = Linear(eq).solve(parameters=parameters) + + if param is None: + result = result(*[0]*len(result.parameters)) + + if len(result) > 0: + return list(result)[0] + else: + return tuple([None]*len(result.parameters)) + + +def base_solution_linear(c, a, b, t=None): + """ + Return the base solution for the linear equation, `ax + by = c`. + + Explanation + =========== + + Used by ``diop_linear()`` to find the base solution of a linear + Diophantine equation. If ``t`` is given then the parametrized solution is + returned. + + Usage + ===== + + ``base_solution_linear(c, a, b, t)``: ``a``, ``b``, ``c`` are coefficients + in `ax + by = c` and ``t`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import base_solution_linear + >>> from sympy.abc import t + >>> base_solution_linear(5, 2, 3) # equation 2*x + 3*y = 5 + (-5, 5) + >>> base_solution_linear(0, 5, 7) # equation 5*x + 7*y = 0 + (0, 0) + >>> base_solution_linear(5, 2, 3, t) # equation 2*x + 3*y = 5 + (3*t - 5, 5 - 2*t) + >>> base_solution_linear(0, 5, 7, t) # equation 5*x + 7*y = 0 + (7*t, -5*t) + """ + a, b, c = _remove_gcd(a, b, c) + + if c == 0: + if t is None: + return (0, 0) + if b < 0: + t = -t + return (b*t, -a*t) + + x0, y0, d = igcdex(abs(a), abs(b)) + x0 *= sign(a) + y0 *= sign(b) + if c % d: + return (None, None) + if t is None: + return (c*x0, c*y0) + if b < 0: + t = -t + return (c*x0 + b*t, c*y0 - a*t) + + +def diop_univariate(eq): + """ + Solves a univariate diophantine equations. + + Explanation + =========== + + A univariate diophantine equation is an equation of the form + `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x` is an integer variable. + + Usage + ===== + + ``diop_univariate(eq)``: Returns a set containing solutions to the + diophantine equation ``eq``. + + Details + ======= + + ``eq`` is a univariate diophantine equation which is assumed to be zero. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_univariate + >>> from sympy.abc import x + >>> diop_univariate((x - 2)*(x - 3)**2) # solves equation (x - 2)*(x - 3)**2 == 0 + {(2,), (3,)} + + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == Univariate.name: + return {(int(i),) for i in solveset_real( + eq, var[0]).intersect(S.Integers)} + + +def divisible(a, b): + """ + Returns `True` if ``a`` is divisible by ``b`` and `False` otherwise. + """ + return not a % b + + +def diop_quadratic(eq, param=symbols("t", integer=True)): + """ + Solves quadratic diophantine equations. + + i.e. equations of the form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`. Returns a + set containing the tuples `(x, y)` which contains the solutions. If there + are no solutions then `(None, None)` is returned. + + Usage + ===== + + ``diop_quadratic(eq, param)``: ``eq`` is a quadratic binary diophantine + equation. ``param`` is used to indicate the parameter to be used in the + solution. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``param`` is a parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.abc import x, y, t + >>> from sympy.solvers.diophantine.diophantine import diop_quadratic + >>> diop_quadratic(x**2 + y**2 + 2*x + 2*y + 2, t) + {(-1, -1)} + + References + ========== + + .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online], + Available: https://www.alpertron.com.ar/METHODS.HTM + .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online], + Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + See Also + ======== + + diop_linear(), diop_ternary_quadratic(), diop_general_sum_of_squares(), + diop_general_pythagorean() + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == BinaryQuadratic.name: + if param is not None: + parameters = [param, Symbol("u", integer=True)] + else: + parameters = None + return set(BinaryQuadratic(eq).solve(parameters=parameters)) + + +def is_solution_quad(var, coeff, u, v): + """ + Check whether `(u, v)` is solution to the quadratic binary diophantine + equation with the variable list ``var`` and coefficient dictionary + ``coeff``. + + Not intended for use by normal users. + """ + reps = dict(zip(var, (u, v))) + eq = Add(*[j*i.xreplace(reps) for i, j in coeff.items()]) + return _mexpand(eq) == 0 + + +def diop_DN(D, N, t=symbols("t", integer=True)): + """ + Solves the equation `x^2 - Dy^2 = N`. + + Explanation + =========== + + Mainly concerned with the case `D > 0, D` is not a perfect square, + which is the same as the generalized Pell equation. The LMM + algorithm [1]_ is used to solve this equation. + + Returns one solution tuple, (`x, y)` for each class of the solutions. + Other solutions of the class can be constructed according to the + values of ``D`` and ``N``. + + Usage + ===== + + ``diop_DN(D, N, t)``: D and N are integers as in `x^2 - Dy^2 = N` and + ``t`` is the parameter to be used in the solutions. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + ``t`` is the parameter to be used in the solutions. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_DN + >>> diop_DN(13, -4) # Solves equation x**2 - 13*y**2 = -4 + [(3, 1), (393, 109), (36, 10)] + + The output can be interpreted as follows: There are three fundamental + solutions to the equation `x^2 - 13y^2 = -4` given by (3, 1), (393, 109) + and (36, 10). Each tuple is in the form (x, y), i.e. solution (3, 1) means + that `x = 3` and `y = 1`. + + >>> diop_DN(986, 1) # Solves equation x**2 - 986*y**2 = 1 + [(49299, 1570)] + + See Also + ======== + + find_DN(), diop_bf_DN() + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Pages 16 - 17. [online], Available: + https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + if D < 0: + if N == 0: + return [(0, 0)] + if N < 0: + return [] + # N > 0: + sol = [] + for d in divisors(square_factor(N), generator=True): + for x, y in cornacchia(1, int(-D), int(N // d**2)): + sol.append((d*x, d*y)) + if D == -1: + sol.append((d*y, d*x)) + return sol + + if D == 0: + if N < 0: + return [] + if N == 0: + return [(0, t)] + sN, _exact = integer_nthroot(N, 2) + if _exact: + return [(sN, t)] + return [] + + # D > 0 + sD, _exact = integer_nthroot(D, 2) + if _exact: + if N == 0: + return [(sD*t, t)] + + sol = [] + for y in range(floor(sign(N)*(N - 1)/(2*sD)) + 1): + try: + sq, _exact = integer_nthroot(D*y**2 + N, 2) + except ValueError: + _exact = False + if _exact: + sol.append((sq, y)) + return sol + + if 1 < N**2 < D: + # It is much faster to call `_special_diop_DN`. + return _special_diop_DN(D, N) + + if N == 0: + return [(0, 0)] + + sol = [] + if abs(N) == 1: + pqa = PQa(0, 1, D) + *_, prev_B, prev_G = next(pqa) + for j, (*_, a, _, _B, _G) in enumerate(pqa): + if a == 2*sD: + break + prev_B, prev_G = _B, _G + if j % 2: + if N == 1: + sol.append((prev_G, prev_B)) + return sol + if N == -1: + return [(prev_G, prev_B)] + for _ in range(j): + *_, _B, _G = next(pqa) + return [(_G, _B)] + + for f in divisors(square_factor(N), generator=True): + m = N // f**2 + am = abs(m) + for sqm in sqrt_mod(D, am, all_roots=True): + z = symmetric_residue(sqm, am) + pqa = PQa(z, am, D) + *_, prev_B, prev_G = next(pqa) + for _ in range(length(z, am, D) - 1): + _, q, *_, _B, _G = next(pqa) + if abs(q) == 1: + if prev_G**2 - D*prev_B**2 == m: + sol.append((f*prev_G, f*prev_B)) + elif a := diop_DN(D, -1): + sol.append((f*(prev_G*a[0][0] + prev_B*D*a[0][1]), + f*(prev_G*a[0][1] + prev_B*a[0][0]))) + break + prev_B, prev_G = _B, _G + return sol + + +def _special_diop_DN(D, N): + """ + Solves the equation `x^2 - Dy^2 = N` for the special case where + `1 < N**2 < D` and `D` is not a perfect square. + It is better to call `diop_DN` rather than this function, as + the former checks the condition `1 < N**2 < D`, and calls the latter only + if appropriate. + + Usage + ===== + + WARNING: Internal method. Do not call directly! + + ``_special_diop_DN(D, N)``: D and N are integers as in `x^2 - Dy^2 = N`. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import _special_diop_DN + >>> _special_diop_DN(13, -3) # Solves equation x**2 - 13*y**2 = -3 + [(7, 2), (137, 38)] + + The output can be interpreted as follows: There are two fundamental + solutions to the equation `x^2 - 13y^2 = -3` given by (7, 2) and + (137, 38). Each tuple is in the form (x, y), i.e. solution (7, 2) means + that `x = 7` and `y = 2`. + + >>> _special_diop_DN(2445, -20) # Solves equation x**2 - 2445*y**2 = -20 + [(445, 9), (17625560, 356454), (698095554475, 14118073569)] + + See Also + ======== + + diop_DN() + + References + ========== + + .. [1] Section 4.4.4 of the following book: + Quadratic Diophantine Equations, T. Andreescu and D. Andrica, + Springer, 2015. + """ + + # The following assertion was removed for efficiency, with the understanding + # that this method is not called directly. The parent method, `diop_DN` + # is responsible for performing the appropriate checks. + # + # assert (1 < N**2 < D) and (not integer_nthroot(D, 2)[1]) + + sqrt_D = isqrt(D) + F = {N // f**2: f for f in divisors(square_factor(abs(N)), generator=True)} + P = 0 + Q = 1 + G0, G1 = 0, 1 + B0, B1 = 1, 0 + + solutions = [] + while True: + for _ in range(2): + a = (P + sqrt_D) // Q + P = a*Q - P + Q = (D - P**2) // Q + G0, G1 = G1, a*G1 + G0 + B0, B1 = B1, a*B1 + B0 + if (s := G1**2 - D*B1**2) in F: + f = F[s] + solutions.append((f*G1, f*B1)) + if Q == 1: + break + return solutions + + +def cornacchia(a:int, b:int, m:int) -> set[tuple[int, int]]: + r""" + Solves `ax^2 + by^2 = m` where `\gcd(a, b) = 1 = gcd(a, m)` and `a, b > 0`. + + Explanation + =========== + + Uses the algorithm due to Cornacchia. The method only finds primitive + solutions, i.e. ones with `\gcd(x, y) = 1`. So this method cannot be used to + find the solutions of `x^2 + y^2 = 20` since the only solution to former is + `(x, y) = (4, 2)` and it is not primitive. When `a = b`, only the + solutions with `x \leq y` are found. For more details, see the References. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import cornacchia + >>> cornacchia(2, 3, 35) # equation 2x**2 + 3y**2 = 35 + {(2, 3), (4, 1)} + >>> cornacchia(1, 1, 25) # equation x**2 + y**2 = 25 + {(4, 3)} + + References + =========== + + .. [1] A. Nitaj, "L'algorithme de Cornacchia" + .. [2] Solving the diophantine equation ax**2 + by**2 = m by Cornacchia's + method, [online], Available: + http://www.numbertheory.org/php/cornacchia.html + + See Also + ======== + + sympy.utilities.iterables.signed_permutations + """ + # Assume gcd(a, b) = gcd(a, m) = 1 and a, b > 0 but no error checking + sols = set() + + if a + b > m: + # xy = 0 must hold if there exists a solution + if a == 1: + # y = 0 + s, _exact = iroot(m // a, 2) + if _exact: + sols.add((int(s), 0)) + if a == b: + # only keep one solution + return sols + if m % b == 0: + # x = 0 + s, _exact = iroot(m // b, 2) + if _exact: + sols.add((0, int(s))) + return sols + + # the original cornacchia + for t in sqrt_mod_iter(-b*invert(a, m), m): + if t < m // 2: + continue + u, r = m, t + while (m1 := m - a*r**2) <= 0: + u, r = r, u % r + m1, _r = divmod(m1, b) + if _r: + continue + s, _exact = iroot(m1, 2) + if _exact: + if a == b and r < s: + r, s = s, r + sols.add((int(r), int(s))) + return sols + + +def PQa(P_0, Q_0, D): + r""" + Returns useful information needed to solve the Pell equation. + + Explanation + =========== + + There are six sequences of integers defined related to the continued + fraction representation of `\\frac{P + \sqrt{D}}{Q}`, namely {`P_{i}`}, + {`Q_{i}`}, {`a_{i}`},{`A_{i}`}, {`B_{i}`}, {`G_{i}`}. ``PQa()`` Returns + these values as a 6-tuple in the same order as mentioned above. Refer [1]_ + for more detailed information. + + Usage + ===== + + ``PQa(P_0, Q_0, D)``: ``P_0``, ``Q_0`` and ``D`` are integers corresponding + to `P_{0}`, `Q_{0}` and `D` in the continued fraction + `\\frac{P_{0} + \sqrt{D}}{Q_{0}}`. + Also it's assumed that `P_{0}^2 == D mod(|Q_{0}|)` and `D` is square free. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import PQa + >>> pqa = PQa(13, 4, 5) # (13 + sqrt(5))/4 + >>> next(pqa) # (P_0, Q_0, a_0, A_0, B_0, G_0) + (13, 4, 3, 3, 1, -1) + >>> next(pqa) # (P_1, Q_1, a_1, A_1, B_1, G_1) + (-1, 1, 1, 4, 1, 3) + + References + ========== + + .. [1] Solving the generalized Pell equation x^2 - Dy^2 = N, John P. + Robertson, July 31, 2004, Pages 4 - 8. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + sqD = isqrt(D) + A2 = B1 = 0 + A1 = B2 = 1 + G1 = Q_0 + G2 = -P_0 + P_i = P_0 + Q_i = Q_0 + + while True: + a_i = (P_i + sqD) // Q_i + A1, A2 = a_i*A1 + A2, A1 + B1, B2 = a_i*B1 + B2, B1 + G1, G2 = a_i*G1 + G2, G1 + yield P_i, Q_i, a_i, A1, B1, G1 + + P_i = a_i*Q_i - P_i + Q_i = (D - P_i**2) // Q_i + + +def diop_bf_DN(D, N, t=symbols("t", integer=True)): + r""" + Uses brute force to solve the equation, `x^2 - Dy^2 = N`. + + Explanation + =========== + + Mainly concerned with the generalized Pell equation which is the case when + `D > 0, D` is not a perfect square. For more information on the case refer + [1]_. Let `(t, u)` be the minimal positive solution of the equation + `x^2 - Dy^2 = 1`. Then this method requires + `\sqrt{\\frac{\mid N \mid (t \pm 1)}{2D}}` to be small. + + Usage + ===== + + ``diop_bf_DN(D, N, t)``: ``D`` and ``N`` are coefficients in + `x^2 - Dy^2 = N` and ``t`` is the parameter to be used in the solutions. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + ``t`` is the parameter to be used in the solutions. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_bf_DN + >>> diop_bf_DN(13, -4) + [(3, 1), (-3, 1), (36, 10)] + >>> diop_bf_DN(986, 1) + [(49299, 1570)] + + See Also + ======== + + diop_DN() + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Page 15. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + D = as_int(D) + N = as_int(N) + + sol = [] + a = diop_DN(D, 1) + u = a[0][0] + + if N == 0: + if D < 0: + return [(0, 0)] + if D == 0: + return [(0, t)] + sD, _exact = integer_nthroot(D, 2) + if _exact: + return [(sD*t, t), (-sD*t, t)] + return [(0, 0)] + + if abs(N) == 1: + return diop_DN(D, N) + + if N > 1: + L1 = 0 + L2 = integer_nthroot(int(N*(u - 1)/(2*D)), 2)[0] + 1 + else: # N < -1 + L1, _exact = integer_nthroot(-int(N/D), 2) + if not _exact: + L1 += 1 + L2 = integer_nthroot(-int(N*(u + 1)/(2*D)), 2)[0] + 1 + + for y in range(L1, L2): + try: + x, _exact = integer_nthroot(N + D*y**2, 2) + except ValueError: + _exact = False + if _exact: + sol.append((x, y)) + if not equivalent(x, y, -x, y, D, N): + sol.append((-x, y)) + + return sol + + +def equivalent(u, v, r, s, D, N): + """ + Returns True if two solutions `(u, v)` and `(r, s)` of `x^2 - Dy^2 = N` + belongs to the same equivalence class and False otherwise. + + Explanation + =========== + + Two solutions `(u, v)` and `(r, s)` to the above equation fall to the same + equivalence class iff both `(ur - Dvs)` and `(us - vr)` are divisible by + `N`. See reference [1]_. No test is performed to test whether `(u, v)` and + `(r, s)` are actually solutions to the equation. User should take care of + this. + + Usage + ===== + + ``equivalent(u, v, r, s, D, N)``: `(u, v)` and `(r, s)` are two solutions + of the equation `x^2 - Dy^2 = N` and all parameters involved are integers. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import equivalent + >>> equivalent(18, 5, -18, -5, 13, -1) + True + >>> equivalent(3, 1, -18, 393, 109, -4) + False + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Page 12. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + + """ + return divisible(u*r - D*v*s, N) and divisible(u*s - v*r, N) + + +def length(P, Q, D): + r""" + Returns the (length of aperiodic part + length of periodic part) of + continued fraction representation of `\\frac{P + \sqrt{D}}{Q}`. + + It is important to remember that this does NOT return the length of the + periodic part but the sum of the lengths of the two parts as mentioned + above. + + Usage + ===== + + ``length(P, Q, D)``: ``P``, ``Q`` and ``D`` are integers corresponding to + the continued fraction `\\frac{P + \sqrt{D}}{Q}`. + + Details + ======= + + ``P``, ``D`` and ``Q`` corresponds to P, D and Q in the continued fraction, + `\\frac{P + \sqrt{D}}{Q}`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import length + >>> length(-2, 4, 5) # (-2 + sqrt(5))/4 + 3 + >>> length(-5, 4, 17) # (-5 + sqrt(17))/4 + 4 + + See Also + ======== + sympy.ntheory.continued_fraction.continued_fraction_periodic + """ + from sympy.ntheory.continued_fraction import continued_fraction_periodic + v = continued_fraction_periodic(P, Q, D) + if isinstance(v[-1], list): + rpt = len(v[-1]) + nonrpt = len(v) - 1 + else: + rpt = 0 + nonrpt = len(v) + return rpt + nonrpt + + +def transformation_to_DN(eq): + """ + This function transforms general quadratic, + `ax^2 + bxy + cy^2 + dx + ey + f = 0` + to more easy to deal with `X^2 - DY^2 = N` form. + + Explanation + =========== + + This is used to solve the general quadratic equation by transforming it to + the latter form. Refer to [1]_ for more detailed information on the + transformation. This function returns a tuple (A, B) where A is a 2 X 2 + matrix and B is a 2 X 1 matrix such that, + + Transpose([x y]) = A * Transpose([X Y]) + B + + Usage + ===== + + ``transformation_to_DN(eq)``: where ``eq`` is the quadratic to be + transformed. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import transformation_to_DN + >>> A, B = transformation_to_DN(x**2 - 3*x*y - y**2 - 2*y + 1) + >>> A + Matrix([ + [1/26, 3/26], + [ 0, 1/13]]) + >>> B + Matrix([ + [-6/13], + [-4/13]]) + + A, B returned are such that Transpose((x y)) = A * Transpose((X Y)) + B. + Substituting these values for `x` and `y` and a bit of simplifying work + will give an equation of the form `x^2 - Dy^2 = N`. + + >>> from sympy.abc import X, Y + >>> from sympy import Matrix, simplify + >>> u = (A*Matrix([X, Y]) + B)[0] # Transformation for x + >>> u + X/26 + 3*Y/26 - 6/13 + >>> v = (A*Matrix([X, Y]) + B)[1] # Transformation for y + >>> v + Y/13 - 4/13 + + Next we will substitute these formulas for `x` and `y` and do + ``simplify()``. + + >>> eq = simplify((x**2 - 3*x*y - y**2 - 2*y + 1).subs(zip((x, y), (u, v)))) + >>> eq + X**2/676 - Y**2/52 + 17/13 + + By multiplying the denominator appropriately, we can get a Pell equation + in the standard form. + + >>> eq * 676 + X**2 - 13*Y**2 + 884 + + If only the final equation is needed, ``find_DN()`` can be used. + + See Also + ======== + + find_DN() + + References + ========== + + .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0, + John P.Robertson, May 8, 2003, Page 7 - 11. + https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + """ + + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == BinaryQuadratic.name: + return _transformation_to_DN(var, coeff) + + +def _transformation_to_DN(var, coeff): + + x, y = var + + a = coeff[x**2] + b = coeff[x*y] + c = coeff[y**2] + d = coeff[x] + e = coeff[y] + f = coeff[1] + + a, b, c, d, e, f = [as_int(i) for i in _remove_gcd(a, b, c, d, e, f)] + + X, Y = symbols("X, Y", integer=True) + + if b: + B, C = _rational_pq(2*a, b) + A, T = _rational_pq(a, B**2) + + # eq_1 = A*B*X**2 + B*(c*T - A*C**2)*Y**2 + d*T*X + (B*e*T - d*T*C)*Y + f*T*B + coeff = {X**2: A*B, X*Y: 0, Y**2: B*(c*T - A*C**2), X: d*T, Y: B*e*T - d*T*C, 1: f*T*B} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*A_0, Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*B_0 + + if d: + B, C = _rational_pq(2*a, d) + A, T = _rational_pq(a, B**2) + + # eq_2 = A*X**2 + c*T*Y**2 + e*T*Y + f*T - A*C**2 + coeff = {X**2: A, X*Y: 0, Y**2: c*T, X: 0, Y: e*T, 1: f*T - A*C**2} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [S.One/B, 0, 0, 1])*A_0, Matrix(2, 2, [S.One/B, 0, 0, 1])*B_0 + Matrix([-S(C)/B, 0]) + + if e: + B, C = _rational_pq(2*c, e) + A, T = _rational_pq(c, B**2) + + # eq_3 = a*T*X**2 + A*Y**2 + f*T - A*C**2 + coeff = {X**2: a*T, X*Y: 0, Y**2: A, X: 0, Y: 0, 1: f*T - A*C**2} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [1, 0, 0, S.One/B])*A_0, Matrix(2, 2, [1, 0, 0, S.One/B])*B_0 + Matrix([0, -S(C)/B]) + + # TODO: pre-simplification: Not necessary but may simplify + # the equation. + return Matrix(2, 2, [S.One/a, 0, 0, 1]), Matrix([0, 0]) + + +def find_DN(eq): + """ + This function returns a tuple, `(D, N)` of the simplified form, + `x^2 - Dy^2 = N`, corresponding to the general quadratic, + `ax^2 + bxy + cy^2 + dx + ey + f = 0`. + + Solving the general quadratic is then equivalent to solving the equation + `X^2 - DY^2 = N` and transforming the solutions by using the transformation + matrices returned by ``transformation_to_DN()``. + + Usage + ===== + + ``find_DN(eq)``: where ``eq`` is the quadratic to be transformed. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import find_DN + >>> find_DN(x**2 - 3*x*y - y**2 - 2*y + 1) + (13, -884) + + Interpretation of the output is that we get `X^2 -13Y^2 = -884` after + transforming `x^2 - 3xy - y^2 - 2y + 1` using the transformation returned + by ``transformation_to_DN()``. + + See Also + ======== + + transformation_to_DN() + + References + ========== + + .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0, + John P.Robertson, May 8, 2003, Page 7 - 11. + https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == BinaryQuadratic.name: + return _find_DN(var, coeff) + + +def _find_DN(var, coeff): + + x, y = var + X, Y = symbols("X, Y", integer=True) + A, B = _transformation_to_DN(var, coeff) + + u = (A*Matrix([X, Y]) + B)[0] + v = (A*Matrix([X, Y]) + B)[1] + eq = x**2*coeff[x**2] + x*y*coeff[x*y] + y**2*coeff[y**2] + x*coeff[x] + y*coeff[y] + coeff[1] + + simplified = _mexpand(eq.subs(zip((x, y), (u, v)))) + + coeff = simplified.as_coefficients_dict() + + return -coeff[Y**2]/coeff[X**2], -coeff[1]/coeff[X**2] + + +def check_param(x, y, a, params): + """ + If there is a number modulo ``a`` such that ``x`` and ``y`` are both + integers, then return a parametric representation for ``x`` and ``y`` + else return (None, None). + + Here ``x`` and ``y`` are functions of ``t``. + """ + from sympy.simplify.simplify import clear_coefficients + + if x.is_number and not x.is_Integer: + return DiophantineSolutionSet([x, y], parameters=params) + + if y.is_number and not y.is_Integer: + return DiophantineSolutionSet([x, y], parameters=params) + + m, n = symbols("m, n", integer=True) + c, p = (m*x + n*y).as_content_primitive() + if a % c.q: + return DiophantineSolutionSet([x, y], parameters=params) + + # clear_coefficients(mx + b, R)[1] -> (R - b)/m + eq = clear_coefficients(x, m)[1] - clear_coefficients(y, n)[1] + junk, eq = eq.as_content_primitive() + + return _diop_solve(eq, params=params) + + +def diop_ternary_quadratic(eq, parameterize=False): + """ + Solves the general quadratic ternary form, + `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`. + + Returns a tuple `(x, y, z)` which is a base solution for the above + equation. If there are no solutions, `(None, None, None)` is returned. + + Usage + ===== + + ``diop_ternary_quadratic(eq)``: Return a tuple containing a basic solution + to ``eq``. + + Details + ======= + + ``eq`` should be an homogeneous expression of degree two in three variables + and it is assumed to be zero. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic + >>> diop_ternary_quadratic(x**2 + 3*y**2 - z**2) + (1, 0, 1) + >>> diop_ternary_quadratic(4*x**2 + 5*y**2 - z**2) + (1, 0, 2) + >>> diop_ternary_quadratic(45*x**2 - 7*y**2 - 8*x*y - z**2) + (28, 45, 105) + >>> diop_ternary_quadratic(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y) + (9, 1, 5) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name): + sol = _diop_ternary_quadratic(var, coeff) + if len(sol) > 0: + x_0, y_0, z_0 = list(sol)[0] + else: + x_0, y_0, z_0 = None, None, None + + if parameterize: + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + return x_0, y_0, z_0 + + +def _diop_ternary_quadratic(_var, coeff): + eq = sum(i*coeff[i] for i in coeff) + if HomogeneousTernaryQuadratic(eq).matches(): + return HomogeneousTernaryQuadratic(eq, free_symbols=_var).solve() + elif HomogeneousTernaryQuadraticNormal(eq).matches(): + return HomogeneousTernaryQuadraticNormal(eq, free_symbols=_var).solve() + + +def transformation_to_normal(eq): + """ + Returns the transformation Matrix that converts a general ternary + quadratic equation ``eq`` (`ax^2 + by^2 + cz^2 + dxy + eyz + fxz`) + to a form without cross terms: `ax^2 + by^2 + cz^2 = 0`. This is + not used in solving ternary quadratics; it is only implemented for + the sake of completeness. + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + "homogeneous_ternary_quadratic", + "homogeneous_ternary_quadratic_normal"): + return _transformation_to_normal(var, coeff) + + +def _transformation_to_normal(var, coeff): + + _var = list(var) # copy + x, y, z = var + + if not any(coeff[i**2] for i in var): + # https://math.stackexchange.com/questions/448051/transform-quadratic-ternary-form-to-normal-form/448065#448065 + a = coeff[x*y] + b = coeff[y*z] + c = coeff[x*z] + swap = False + if not a: # b can't be 0 or else there aren't 3 vars + swap = True + a, b = b, a + T = Matrix(((1, 1, -b/a), (1, -1, -c/a), (0, 0, 1))) + if swap: + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + if coeff[x**2] == 0: + # If the coefficient of x is zero change the variables + if coeff[y**2] == 0: + _var[0], _var[2] = var[2], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 2) + T.col_swap(0, 2) + return T + + _var[0], _var[1] = var[1], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + # Apply the transformation x --> X - (B*Y + C*Z)/(2*A) + if coeff[x*y] != 0 or coeff[x*z] != 0: + A = coeff[x**2] + B = coeff[x*y] + C = coeff[x*z] + D = coeff[y**2] + E = coeff[y*z] + F = coeff[z**2] + + _coeff = {} + + _coeff[x**2] = 4*A**2 + _coeff[y**2] = 4*A*D - B**2 + _coeff[z**2] = 4*A*F - C**2 + _coeff[y*z] = 4*A*E - 2*B*C + _coeff[x*y] = 0 + _coeff[x*z] = 0 + + T_0 = _transformation_to_normal(_var, _coeff) + return Matrix(3, 3, [1, S(-B)/(2*A), S(-C)/(2*A), 0, 1, 0, 0, 0, 1])*T_0 + + elif coeff[y*z] != 0: + if coeff[y**2] == 0: + if coeff[z**2] == 0: + # Equations of the form A*x**2 + E*yz = 0. + # Apply transformation y -> Y + Z ans z -> Y - Z + return Matrix(3, 3, [1, 0, 0, 0, 1, 1, 0, 1, -1]) + + # Ax**2 + E*y*z + F*z**2 = 0 + _var[0], _var[2] = var[2], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 2) + T.col_swap(0, 2) + return T + + # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, F may be zero + _var[0], _var[1] = var[1], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + return Matrix.eye(3) + + +def parametrize_ternary_quadratic(eq): + """ + Returns the parametrized general solution for the ternary quadratic + equation ``eq`` which has the form + `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`. + + Examples + ======== + + >>> from sympy import Tuple, ordered + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import parametrize_ternary_quadratic + + The parametrized solution may be returned with three parameters: + + >>> parametrize_ternary_quadratic(2*x**2 + y**2 - 2*z**2) + (p**2 - 2*q**2, -2*p**2 + 4*p*q - 4*p*r - 4*q**2, p**2 - 4*p*q + 2*q**2 - 4*q*r) + + There might also be only two parameters: + + >>> parametrize_ternary_quadratic(4*x**2 + 2*y**2 - 3*z**2) + (2*p**2 - 3*q**2, -4*p**2 + 12*p*q - 6*q**2, 4*p**2 - 8*p*q + 6*q**2) + + Notes + ===== + + Consider ``p`` and ``q`` in the previous 2-parameter + solution and observe that more than one solution can be represented + by a given pair of parameters. If `p` and ``q`` are not coprime, this is + trivially true since the common factor will also be a common factor of the + solution values. But it may also be true even when ``p`` and + ``q`` are coprime: + + >>> sol = Tuple(*_) + >>> p, q = ordered(sol.free_symbols) + >>> sol.subs([(p, 3), (q, 2)]) + (6, 12, 12) + >>> sol.subs([(q, 1), (p, 1)]) + (-1, 2, 2) + >>> sol.subs([(q, 0), (p, 1)]) + (2, -4, 4) + >>> sol.subs([(q, 1), (p, 0)]) + (-3, -6, 6) + + Except for sign and a common factor, these are equivalent to + the solution of (1, 2, 2). + + References + ========== + + .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart, + London Mathematical Society Student Texts 41, Cambridge University + Press, Cambridge, 1998. + + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + "homogeneous_ternary_quadratic", + "homogeneous_ternary_quadratic_normal"): + x_0, y_0, z_0 = list(_diop_ternary_quadratic(var, coeff))[0] + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + + +def _parametrize_ternary_quadratic(solution, _var, coeff): + # called for a*x**2 + b*y**2 + c*z**2 + d*x*y + e*y*z + f*x*z = 0 + assert 1 not in coeff + + x_0, y_0, z_0 = solution + + v = list(_var) # copy + + if x_0 is None: + return (None, None, None) + + if solution.count(0) >= 2: + # if there are 2 zeros the equation reduces + # to k*X**2 == 0 where X is x, y, or z so X must + # be zero, too. So there is only the trivial + # solution. + return (None, None, None) + + if x_0 == 0: + v[0], v[1] = v[1], v[0] + y_p, x_p, z_p = _parametrize_ternary_quadratic( + (y_0, x_0, z_0), v, coeff) + return x_p, y_p, z_p + + x, y, z = v + r, p, q = symbols("r, p, q", integer=True) + + eq = sum(k*v for k, v in coeff.items()) + eq_1 = _mexpand(eq.subs(zip( + (x, y, z), (r*x_0, r*y_0 + p, r*z_0 + q)))) + A, B = eq_1.as_independent(r, as_Add=True) + + + x = A*x_0 + y = (A*y_0 - _mexpand(B/r*p)) + z = (A*z_0 - _mexpand(B/r*q)) + + return _remove_gcd(x, y, z) + + +def diop_ternary_quadratic_normal(eq, parameterize=False): + """ + Solves the quadratic ternary diophantine equation, + `ax^2 + by^2 + cz^2 = 0`. + + Explanation + =========== + + Here the coefficients `a`, `b`, and `c` should be non zero. Otherwise the + equation will be a quadratic binary or univariate equation. If solvable, + returns a tuple `(x, y, z)` that satisfies the given equation. If the + equation does not have integer solutions, `(None, None, None)` is returned. + + Usage + ===== + + ``diop_ternary_quadratic_normal(eq)``: where ``eq`` is an equation of the form + `ax^2 + by^2 + cz^2 = 0`. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic_normal + >>> diop_ternary_quadratic_normal(x**2 + 3*y**2 - z**2) + (1, 0, 1) + >>> diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2) + (1, 0, 2) + >>> diop_ternary_quadratic_normal(34*x**2 - 3*y**2 - 301*z**2) + (4, 9, 1) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == HomogeneousTernaryQuadraticNormal.name: + sol = _diop_ternary_quadratic_normal(var, coeff) + if len(sol) > 0: + x_0, y_0, z_0 = list(sol)[0] + else: + x_0, y_0, z_0 = None, None, None + if parameterize: + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + return x_0, y_0, z_0 + + +def _diop_ternary_quadratic_normal(var, coeff): + eq = sum(i * coeff[i] for i in coeff) + return HomogeneousTernaryQuadraticNormal(eq, free_symbols=var).solve() + + +def sqf_normal(a, b, c, steps=False): + """ + Return `a', b', c'`, the coefficients of the square-free normal + form of `ax^2 + by^2 + cz^2 = 0`, where `a', b', c'` are pairwise + prime. If `steps` is True then also return three tuples: + `sq`, `sqf`, and `(a', b', c')` where `sq` contains the square + factors of `a`, `b` and `c` after removing the `gcd(a, b, c)`; + `sqf` contains the values of `a`, `b` and `c` after removing + both the `gcd(a, b, c)` and the square factors. + + The solutions for `ax^2 + by^2 + cz^2 = 0` can be + recovered from the solutions of `a'x^2 + b'y^2 + c'z^2 = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sqf_normal + >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11) + (11, 1, 5) + >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11, True) + ((3, 1, 7), (5, 55, 11), (11, 1, 5)) + + References + ========== + + .. [1] Legendre's Theorem, Legrange's Descent, + https://public.csusm.edu/aitken_html/notes/legendre.pdf + + + See Also + ======== + + reconstruct() + """ + ABC = _remove_gcd(a, b, c) + sq = tuple(square_factor(i) for i in ABC) + sqf = A, B, C = tuple([i//j**2 for i,j in zip(ABC, sq)]) + pc = igcd(A, B) + A /= pc + B /= pc + pa = igcd(B, C) + B /= pa + C /= pa + pb = igcd(A, C) + A /= pb + B /= pb + + A *= pa + B *= pb + C *= pc + + if steps: + return (sq, sqf, (A, B, C)) + else: + return A, B, C + + +def square_factor(a): + r""" + Returns an integer `c` s.t. `a = c^2k, \ c,k \in Z`. Here `k` is square + free. `a` can be given as an integer or a dictionary of factors. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import square_factor + >>> square_factor(24) + 2 + >>> square_factor(-36*3) + 6 + >>> square_factor(1) + 1 + >>> square_factor({3: 2, 2: 1, -1: 1}) # -18 + 3 + + See Also + ======== + sympy.ntheory.factor_.core + """ + f = a if isinstance(a, dict) else factorint(a) + return Mul(*[p**(e//2) for p, e in f.items()]) + + +def reconstruct(A, B, z): + """ + Reconstruct the `z` value of an equivalent solution of `ax^2 + by^2 + cz^2` + from the `z` value of a solution of the square-free normal form of the + equation, `a'*x^2 + b'*y^2 + c'*z^2`, where `a'`, `b'` and `c'` are square + free and `gcd(a', b', c') == 1`. + """ + f = factorint(igcd(A, B)) + for p, e in f.items(): + if e != 1: + raise ValueError('a and b should be square-free') + z *= p + return z + + +def ldescent(A, B): + """ + Return a non-trivial solution to `w^2 = Ax^2 + By^2` using + Lagrange's method; return None if there is no such solution. + + Parameters + ========== + + A : Integer + B : Integer + non-zero integer + + Returns + ======= + + (int, int, int) | None : a tuple `(w_0, x_0, y_0)` which is a solution to the above equation. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import ldescent + >>> ldescent(1, 1) # w^2 = x^2 + y^2 + (1, 1, 0) + >>> ldescent(4, -7) # w^2 = 4x^2 - 7y^2 + (2, -1, 0) + + This means that `x = -1, y = 0` and `w = 2` is a solution to the equation + `w^2 = 4x^2 - 7y^2` + + >>> ldescent(5, -1) # w^2 = 5x^2 - y^2 + (2, 1, -1) + + References + ========== + + .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart, + London Mathematical Society Student Texts 41, Cambridge University + Press, Cambridge, 1998. + .. [2] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + if A == 0 or B == 0: + raise ValueError("A and B must be non-zero integers") + if abs(A) > abs(B): + w, y, x = ldescent(B, A) + return w, x, y + if A == 1: + return (1, 1, 0) + if B == 1: + return (1, 0, 1) + if B == -1: # and A == -1 + return + + r = sqrt_mod(A, B) + if r is None: + return + Q = (r**2 - A) // B + if Q == 0: + return r, -1, 0 + for i in divisors(Q): + d, _exact = integer_nthroot(abs(Q) // i, 2) + if _exact: + B_0 = sign(Q)*i + W, X, Y = ldescent(A, B_0) + return _remove_gcd(-A*X + r*W, r*X - W, Y*B_0*d) + + +def descent(A, B): + """ + Returns a non-trivial solution, (x, y, z), to `x^2 = Ay^2 + Bz^2` + using Lagrange's descent method with lattice-reduction. `A` and `B` + are assumed to be valid for such a solution to exist. + + This is faster than the normal Lagrange's descent algorithm because + the Gaussian reduction is used. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import descent + >>> descent(3, 1) # x**2 = 3*y**2 + z**2 + (1, 0, 1) + + `(x, y, z) = (1, 0, 1)` is a solution to the above equation. + + >>> descent(41, -113) + (-16, -3, 1) + + References + ========== + + .. [1] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + if abs(A) > abs(B): + x, y, z = descent(B, A) + return x, z, y + + if B == 1: + return (1, 0, 1) + if A == 1: + return (1, 1, 0) + if B == -A: + return (0, 1, 1) + if B == A: + x, z, y = descent(-1, A) + return (A*y, z, x) + + w = sqrt_mod(A, B) + x_0, z_0 = gaussian_reduce(w, A, B) + + t = (x_0**2 - A*z_0**2) // B + t_2 = square_factor(t) + t_1 = t // t_2**2 + + x_1, z_1, y_1 = descent(A, t_1) + + return _remove_gcd(x_0*x_1 + A*z_0*z_1, z_0*x_1 + x_0*z_1, t_1*t_2*y_1) + + +def gaussian_reduce(w:int, a:int, b:int) -> tuple[int, int]: + r""" + Returns a reduced solution `(x, z)` to the congruence + `X^2 - aZ^2 \equiv 0 \pmod{b}` so that `x^2 + |a|z^2` is as small as possible. + Here ``w`` is a solution of the congruence `x^2 \equiv a \pmod{b}`. + + This function is intended to be used only for ``descent()``. + + Explanation + =========== + + The Gaussian reduction can find the shortest vector for any norm. + So we define the special norm for the vectors `u = (u_1, u_2)` and `v = (v_1, v_2)` as follows. + + .. math :: + u \cdot v := (wu_1 + bu_2)(wv_1 + bv_2) + |a|u_1v_1 + + Note that, given the mapping `f: (u_1, u_2) \to (wu_1 + bu_2, u_1)`, + `f((u_1,u_2))` is the solution to `X^2 - aZ^2 \equiv 0 \pmod{b}`. + In other words, finding the shortest vector in this norm will yield a solution with smaller `X^2 + |a|Z^2`. + The algorithm starts from basis vectors `(0, 1)` and `(1, 0)` + (corresponding to solutions `(b, 0)` and `(w, 1)`, respectively) and finds the shortest vector. + The shortest vector does not necessarily correspond to the smallest solution, + but since ``descent()`` only wants the smallest possible solution, it is sufficient. + + Parameters + ========== + + w : int + ``w`` s.t. `w^2 \equiv a \pmod{b}` + a : int + square-free nonzero integer + b : int + square-free nonzero integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import gaussian_reduce + >>> from sympy.ntheory.residue_ntheory import sqrt_mod + >>> a, b = 19, 101 + >>> gaussian_reduce(sqrt_mod(a, b), a, b) # 1**2 - 19*(-4)**2 = -303 + (1, -4) + >>> a, b = 11, 14 + >>> x, z = gaussian_reduce(sqrt_mod(a, b), a, b) + >>> (x**2 - a*z**2) % b == 0 + True + + It does not always return the smallest solution. + + >>> a, b = 6, 95 + >>> min_x, min_z = 1, 4 + >>> x, z = gaussian_reduce(sqrt_mod(a, b), a, b) + >>> (x**2 - a*z**2) % b == 0 and (min_x**2 - a*min_z**2) % b == 0 + True + >>> min_x**2 + abs(a)*min_z**2 < x**2 + abs(a)*z**2 + True + + References + ========== + + .. [1] Gaussian lattice Reduction [online]. Available: + https://web.archive.org/web/20201021115213/http://home.ie.cuhk.edu.hk/~wkshum/wordpress/?p=404 + .. [2] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + a = abs(a) + def _dot(u, v): + return u[0]*v[0] + a*u[1]*v[1] + + u = (b, 0) + v = (w, 1) if b*w >= 0 else (-w, -1) + # i.e., _dot(u, v) >= 0 + + if b**2 < w**2 + a: + u, v = v, u + # i.e., norm(u) >= norm(v), where norm(u) := sqrt(_dot(u, u)) + + while _dot(u, u) > (dv := _dot(v, v)): + k = _dot(u, v) // dv + u, v = v, (u[0] - k*v[0], u[1] - k*v[1]) + c = (v[0] - u[0], v[1] - u[1]) + if _dot(c, c) <= _dot(u, u) <= 2*_dot(u, v): + return c + return u + + +def holzer(x, y, z, a, b, c): + r""" + Simplify the solution `(x, y, z)` of the equation + `ax^2 + by^2 = cz^2` with `a, b, c > 0` and `z^2 \geq \mid ab \mid` to + a new reduced solution `(x', y', z')` such that `z'^2 \leq \mid ab \mid`. + + The algorithm is an interpretation of Mordell's reduction as described + on page 8 of Cremona and Rusin's paper [1]_ and the work of Mordell in + reference [2]_. + + References + ========== + + .. [1] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + .. [2] Diophantine Equations, L. J. Mordell, page 48. + + """ + + if _odd(c): + k = 2*c + else: + k = c//2 + + small = a*b*c + step = 0 + while True: + t1, t2, t3 = a*x**2, b*y**2, c*z**2 + # check that it's a solution + if t1 + t2 != t3: + if step == 0: + raise ValueError('bad starting solution') + break + x_0, y_0, z_0 = x, y, z + if max(t1, t2, t3) <= small: + # Holzer condition + break + + uv = u, v = base_solution_linear(k, y_0, -x_0) + if None in uv: + break + + p, q = -(a*u*x_0 + b*v*y_0), c*z_0 + r = Rational(p, q) + if _even(c): + w = _nint_or_floor(p, q) + assert abs(w - r) <= S.Half + else: + w = p//q # floor + if _odd(a*u + b*v + c*w): + w += 1 + assert abs(w - r) <= S.One + + A = (a*u**2 + b*v**2 + c*w**2) + B = (a*u*x_0 + b*v*y_0 + c*w*z_0) + x = Rational(x_0*A - 2*u*B, k) + y = Rational(y_0*A - 2*v*B, k) + z = Rational(z_0*A - 2*w*B, k) + assert all(i.is_Integer for i in (x, y, z)) + step += 1 + + return tuple([int(i) for i in (x_0, y_0, z_0)]) + + +def diop_general_pythagorean(eq, param=symbols("m", integer=True)): + """ + Solves the general pythagorean equation, + `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`. + + Returns a tuple which contains a parametrized solution to the equation, + sorted in the same order as the input variables. + + Usage + ===== + + ``diop_general_pythagorean(eq, param)``: where ``eq`` is a general + pythagorean equation which is assumed to be zero and ``param`` is the base + parameter used to construct other parameters by subscripting. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_pythagorean + >>> from sympy.abc import a, b, c, d, e + >>> diop_general_pythagorean(a**2 + b**2 + c**2 - d**2) + (m1**2 + m2**2 - m3**2, 2*m1*m3, 2*m2*m3, m1**2 + m2**2 + m3**2) + >>> diop_general_pythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2) + (10*m1**2 + 10*m2**2 + 10*m3**2 - 10*m4**2, 15*m1**2 + 15*m2**2 + 15*m3**2 + 15*m4**2, 15*m1*m4, 12*m2*m4, 60*m3*m4) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralPythagorean.name: + if param is None: + params = None + else: + params = symbols('%s1:%i' % (param, len(var)), integer=True) + return list(GeneralPythagorean(eq).solve(parameters=params))[0] + + +def diop_general_sum_of_squares(eq, limit=1): + r""" + Solves the equation `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Returns at most ``limit`` number of solutions. + + Usage + ===== + + ``general_sum_of_squares(eq, limit)`` : Here ``eq`` is an expression which + is assumed to be zero. Also, ``eq`` should be in the form, + `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Details + ======= + + When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be + no solutions. Refer to [1]_ for more details. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_squares + >>> from sympy.abc import a, b, c, d, e + >>> diop_general_sum_of_squares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345) + {(15, 22, 22, 24, 24)} + + Reference + ========= + + .. [1] Representing an integer as a sum of three squares, [online], + Available: + https://proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralSumOfSquares.name: + return set(GeneralSumOfSquares(eq).solve(limit=limit)) + + +def diop_general_sum_of_even_powers(eq, limit=1): + """ + Solves the equation `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0` + where `e` is an even, integer power. + + Returns at most ``limit`` number of solutions. + + Usage + ===== + + ``general_sum_of_even_powers(eq, limit)`` : Here ``eq`` is an expression which + is assumed to be zero. Also, ``eq`` should be in the form, + `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_even_powers + >>> from sympy.abc import a, b + >>> diop_general_sum_of_even_powers(a**4 + b**4 - (2**4 + 3**4)) + {(2, 3)} + + See Also + ======== + + power_representation + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralSumOfEvenPowers.name: + return set(GeneralSumOfEvenPowers(eq).solve(limit=limit)) + + +## Functions below this comment can be more suitably grouped under +## an Additive number theory module rather than the Diophantine +## equation module. + + +def partition(n, k=None, zeros=False): + """ + Returns a generator that can be used to generate partitions of an integer + `n`. + + Explanation + =========== + + A partition of `n` is a set of positive integers which add up to `n`. For + example, partitions of 3 are 3, 1 + 2, 1 + 1 + 1. A partition is returned + as a tuple. If ``k`` equals None, then all possible partitions are returned + irrespective of their size, otherwise only the partitions of size ``k`` are + returned. If the ``zero`` parameter is set to True then a suitable + number of zeros are added at the end of every partition of size less than + ``k``. + + ``zero`` parameter is considered only if ``k`` is not None. When the + partitions are over, the last `next()` call throws the ``StopIteration`` + exception, so this function should always be used inside a try - except + block. + + Details + ======= + + ``partition(n, k)``: Here ``n`` is a positive integer and ``k`` is the size + of the partition which is also positive integer. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import partition + >>> f = partition(5) + >>> next(f) + (1, 1, 1, 1, 1) + >>> next(f) + (1, 1, 1, 2) + >>> g = partition(5, 3) + >>> next(g) + (1, 1, 3) + >>> next(g) + (1, 2, 2) + >>> g = partition(5, 3, zeros=True) + >>> next(g) + (0, 0, 5) + + """ + if not zeros or k is None: + for i in ordered_partitions(n, k): + yield tuple(i) + else: + for m in range(1, k + 1): + for i in ordered_partitions(n, m): + i = tuple(i) + yield (0,)*(k - len(i)) + i + + +def prime_as_sum_of_two_squares(p): + """ + Represent a prime `p` as a unique sum of two squares; this can + only be done if the prime is congruent to 1 mod 4. + + Parameters + ========== + + p : Integer + A prime that is congruent to 1 mod 4 + + Returns + ======= + + (int, int) | None : Pair of positive integers ``(x, y)`` satisfying ``x**2 + y**2 = p``. + None if ``p`` is not congruent to 1 mod 4. + + Raises + ====== + + ValueError + If ``p`` is not prime number + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import prime_as_sum_of_two_squares + >>> prime_as_sum_of_two_squares(7) # can't be done + >>> prime_as_sum_of_two_squares(5) + (1, 2) + + Reference + ========= + + .. [1] Representing a number as a sum of four squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + sum_of_squares + + """ + p = as_int(p) + if p % 4 != 1: + return + if not isprime(p): + raise ValueError("p should be a prime number") + + if p % 8 == 5: + # Legendre symbol (2/p) == -1 if p % 8 in [3, 5] + b = 2 + elif p % 12 == 5: + # Legendre symbol (3/p) == -1 if p % 12 in [5, 7] + b = 3 + elif p % 5 in [2, 3]: + # Legendre symbol (5/p) == -1 if p % 5 in [2, 3] + b = 5 + else: + b = 7 + while jacobi(b, p) == 1: + b = nextprime(b) + + b = pow(b, p >> 2, p) + a = p + while b**2 > p: + a, b = b, a % b + return (int(a % b), int(b)) # convert from long + + +def sum_of_three_squares(n): + r""" + Returns a 3-tuple $(a, b, c)$ such that $a^2 + b^2 + c^2 = n$ and + $a, b, c \geq 0$. + + Returns None if $n = 4^a(8m + 7)$ for some `a, m \in \mathbb{Z}`. See + [1]_ for more details. + + Parameters + ========== + + n : Integer + non-negative integer + + Returns + ======= + + (int, int, int) | None : 3-tuple non-negative integers ``(a, b, c)`` satisfying ``a**2 + b**2 + c**2 = n``. + a,b,c are sorted in ascending order. ``None`` if no such ``(a,b,c)``. + + Raises + ====== + + ValueError + If ``n`` is a negative integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_three_squares + >>> sum_of_three_squares(44542) + (18, 37, 207) + + References + ========== + + .. [1] Representing a number as a sum of three squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + power_representation : + ``sum_of_three_squares(n)`` is one of the solutions output by ``power_representation(n, 2, 3, zeros=True)`` + + """ + # https://math.stackexchange.com/questions/483101/rabin-and-shallit-algorithm/651425#651425 + # discusses these numbers (except for 1, 2, 3) as the exceptions of H&L's conjecture that + # Every sufficiently large number n is either a square or the sum of a prime and a square. + special = {1: (0, 0, 1), 2: (0, 1, 1), 3: (1, 1, 1), 10: (0, 1, 3), 34: (3, 3, 4), + 58: (0, 3, 7), 85: (0, 6, 7), 130: (0, 3, 11), 214: (3, 6, 13), 226: (8, 9, 9), + 370: (8, 9, 15), 526: (6, 7, 21), 706: (15, 15, 16), 730: (0, 1, 27), + 1414: (6, 17, 33), 1906: (13, 21, 36), 2986: (21, 32, 39), 9634: (56, 57, 57)} + n = as_int(n) + if n < 0: + raise ValueError("n should be a non-negative integer") + if n == 0: + return (0, 0, 0) + n, v = remove(n, 4) + v = 1 << v + if n % 8 == 7: + return + if n in special: + return tuple([v*i for i in special[n]]) + + s, _exact = integer_nthroot(n, 2) + if _exact: + return (0, 0, v*s) + if n % 8 == 3: + if not s % 2: + s -= 1 + for x in range(s, -1, -2): + N = (n - x**2) // 2 + if isprime(N): + # n % 8 == 3 and x % 2 == 1 => N % 4 == 1 + y, z = prime_as_sum_of_two_squares(N) + return tuple(sorted([v*x, v*(y + z), v*abs(y - z)])) + # We will never reach this point because there must be a solution. + assert False + + # assert n % 4 in [1, 2] + if not((n % 2) ^ (s % 2)): + s -= 1 + for x in range(s, -1, -2): + N = n - x**2 + if isprime(N): + # assert N % 4 == 1 + y, z = prime_as_sum_of_two_squares(N) + return tuple(sorted([v*x, v*y, v*z])) + # We will never reach this point because there must be a solution. + assert False + + +def sum_of_four_squares(n): + r""" + Returns a 4-tuple `(a, b, c, d)` such that `a^2 + b^2 + c^2 + d^2 = n`. + Here `a, b, c, d \geq 0`. + + Parameters + ========== + + n : Integer + non-negative integer + + Returns + ======= + + (int, int, int, int) : 4-tuple non-negative integers ``(a, b, c, d)`` satisfying ``a**2 + b**2 + c**2 + d**2 = n``. + a,b,c,d are sorted in ascending order. + + Raises + ====== + + ValueError + If ``n`` is a negative integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_four_squares + >>> sum_of_four_squares(3456) + (8, 8, 32, 48) + >>> sum_of_four_squares(1294585930293) + (0, 1234, 2161, 1137796) + + References + ========== + + .. [1] Representing a number as a sum of four squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + power_representation : + ``sum_of_four_squares(n)`` is one of the solutions output by ``power_representation(n, 2, 4, zeros=True)`` + + """ + n = as_int(n) + if n < 0: + raise ValueError("n should be a non-negative integer") + if n == 0: + return (0, 0, 0, 0) + # remove factors of 4 since a solution in terms of 3 squares is + # going to be returned; this is also done in sum_of_three_squares, + # but it needs to be done here to select d + n, v = remove(n, 4) + v = 1 << v + if n % 8 == 7: + d = 2 + n = n - 4 + elif n % 8 in (2, 6): + d = 1 + n = n - 1 + else: + d = 0 + x, y, z = sum_of_three_squares(n) # sorted + return tuple(sorted([v*d, v*x, v*y, v*z])) + + +def power_representation(n, p, k, zeros=False): + r""" + Returns a generator for finding k-tuples of integers, + `(n_{1}, n_{2}, . . . n_{k})`, such that + `n = n_{1}^p + n_{2}^p + . . . n_{k}^p`. + + Usage + ===== + + ``power_representation(n, p, k, zeros)``: Represent non-negative number + ``n`` as a sum of ``k`` ``p``\ th powers. If ``zeros`` is true, then the + solutions is allowed to contain zeros. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import power_representation + + Represent 1729 as a sum of two cubes: + + >>> f = power_representation(1729, 3, 2) + >>> next(f) + (9, 10) + >>> next(f) + (1, 12) + + If the flag `zeros` is True, the solution may contain tuples with + zeros; any such solutions will be generated after the solutions + without zeros: + + >>> list(power_representation(125, 2, 3, zeros=True)) + [(5, 6, 8), (3, 4, 10), (0, 5, 10), (0, 2, 11)] + + For even `p` the `permute_sign` function can be used to get all + signed values: + + >>> from sympy.utilities.iterables import permute_signs + >>> list(permute_signs((1, 12))) + [(1, 12), (-1, 12), (1, -12), (-1, -12)] + + All possible signed permutations can also be obtained: + + >>> from sympy.utilities.iterables import signed_permutations + >>> list(signed_permutations((1, 12))) + [(1, 12), (-1, 12), (1, -12), (-1, -12), (12, 1), (-12, 1), (12, -1), (-12, -1)] + """ + n, p, k = [as_int(i) for i in (n, p, k)] + + if n < 0: + if p % 2: + for t in power_representation(-n, p, k, zeros): + yield tuple(-i for i in t) + return + + if p < 1 or k < 1: + raise ValueError(filldedent(''' + Expecting positive integers for `(p, k)`, but got `(%s, %s)`''' + % (p, k))) + + if n == 0: + if zeros: + yield (0,)*k + return + + if k == 1: + if p == 1: + yield (n,) + elif n == 1: + yield (1,) + else: + be = perfect_power(n) + if be: + b, e = be + d, r = divmod(e, p) + if not r: + yield (b**d,) + return + + if p == 1: + yield from partition(n, k, zeros=zeros) + return + + if p == 2: + if k == 3: + n, v = remove(n, 4) + if v: + v = 1 << v + for t in power_representation(n, p, k, zeros): + yield tuple(i*v for i in t) + return + feasible = _can_do_sum_of_squares(n, k) + if not feasible: + return + if not zeros: + if n > 33 and k >= 5 and k <= n and n - k in ( + 13, 10, 7, 5, 4, 2, 1): + '''Todd G. Will, "When Is n^2 a Sum of k Squares?", [online]. + Available: https://www.maa.org/sites/default/files/Will-MMz-201037918.pdf''' + return + # quick tests since feasibility includes the possibility of 0 + if k == 4 and (n in (1, 3, 5, 9, 11, 17, 29, 41) or remove(n, 4)[0] in (2, 6, 14)): + # A000534 + return + if k == 3 and n in (1, 2, 5, 10, 13, 25, 37, 58, 85, 130): # or n = some number >= 5*10**10 + # A051952 + return + if feasible is not True: # it's prime and k == 2 + yield prime_as_sum_of_two_squares(n) + return + + if k == 2 and p > 2: + be = perfect_power(n) + if be and be[1] % p == 0: + return # Fermat: a**n + b**n = c**n has no solution for n > 2 + + if n >= k: + a = integer_nthroot(n - (k - 1), p)[0] + for t in pow_rep_recursive(a, k, n, [], p): + yield tuple(reversed(t)) + + if zeros: + a = integer_nthroot(n, p)[0] + for i in range(1, k): + for t in pow_rep_recursive(a, i, n, [], p): + yield tuple(reversed(t + (0,)*(k - i))) + + +sum_of_powers = power_representation + + +def pow_rep_recursive(n_i, k, n_remaining, terms, p): + # Invalid arguments + if n_i <= 0 or k <= 0: + return + + # No solutions may exist + if n_remaining < k: + return + if k * pow(n_i, p) < n_remaining: + return + + if k == 0 and n_remaining == 0: + yield tuple(terms) + + elif k == 1: + # next_term^p must equal to n_remaining + next_term, exact = integer_nthroot(n_remaining, p) + if exact and next_term <= n_i: + yield tuple(terms + [next_term]) + return + + else: + # TODO: Fall back to diop_DN when k = 2 + if n_i >= 1 and k > 0: + for next_term in range(1, n_i + 1): + residual = n_remaining - pow(next_term, p) + if residual < 0: + break + yield from pow_rep_recursive(next_term, k - 1, residual, terms + [next_term], p) + + +def sum_of_squares(n, k, zeros=False): + """Return a generator that yields the k-tuples of nonnegative + values, the squares of which sum to n. If zeros is False (default) + then the solution will not contain zeros. The nonnegative + elements of a tuple are sorted. + + * If k == 1 and n is square, (n,) is returned. + + * If k == 2 then n can only be written as a sum of squares if + every prime in the factorization of n that has the form + 4*k + 3 has an even multiplicity. If n is prime then + it can only be written as a sum of two squares if it is + in the form 4*k + 1. + + * if k == 3 then n can be written as a sum of squares if it does + not have the form 4**m*(8*k + 7). + + * all integers can be written as the sum of 4 squares. + + * if k > 4 then n can be partitioned and each partition can + be written as a sum of 4 squares; if n is not evenly divisible + by 4 then n can be written as a sum of squares only if the + an additional partition can be written as sum of squares. + For example, if k = 6 then n is partitioned into two parts, + the first being written as a sum of 4 squares and the second + being written as a sum of 2 squares -- which can only be + done if the condition above for k = 2 can be met, so this will + automatically reject certain partitions of n. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_squares + >>> list(sum_of_squares(25, 2)) + [(3, 4)] + >>> list(sum_of_squares(25, 2, True)) + [(3, 4), (0, 5)] + >>> list(sum_of_squares(25, 4)) + [(1, 2, 2, 4)] + + See Also + ======== + + sympy.utilities.iterables.signed_permutations + """ + yield from power_representation(n, 2, k, zeros) + + +def _can_do_sum_of_squares(n, k): + """Return True if n can be written as the sum of k squares, + False if it cannot, or 1 if ``k == 2`` and ``n`` is prime (in which + case it *can* be written as a sum of two squares). A False + is returned only if it cannot be written as ``k``-squares, even + if 0s are allowed. + """ + if k < 1: + return False + if n < 0: + return False + if n == 0: + return True + if k == 1: + return is_square(n) + if k == 2: + if n in (1, 2): + return True + if isprime(n): + if n % 4 == 1: + return 1 # signal that it was prime + return False + # n is a composite number + # we can proceed iff no prime factor in the form 4*k + 3 + # has an odd multiplicity + return all(p % 4 !=3 or m % 2 == 0 for p, m in factorint(n).items()) + if k == 3: + return remove(n, 4)[0] % 8 != 7 + # every number can be written as a sum of 4 squares; for k > 4 partitions + # can be 0 + return True diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3f2aacf5b64127924e18c979e878056e72e4bd7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8115d3a9fbd7592db5b8acaac989524cb42aef8b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b031a1e63fa445e3dbb0b425f84cfe88888667 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py @@ -0,0 +1,1071 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.ntheory.factor_ import factorint +from sympy.simplify.powsimp import powsimp +from sympy.core.function import _mexpand +from sympy.core.sorting import default_sort_key, ordered +from sympy.functions.elementary.trigonometric import sin +from sympy.solvers.diophantine import diophantine +from sympy.solvers.diophantine.diophantine import (diop_DN, + diop_solve, diop_ternary_quadratic_normal, + diop_general_pythagorean, diop_ternary_quadratic, diop_linear, + diop_quadratic, diop_general_sum_of_squares, diop_general_sum_of_even_powers, + descent, diop_bf_DN, divisible, equivalent, find_DN, ldescent, length, + reconstruct, partition, power_representation, + prime_as_sum_of_two_squares, square_factor, sum_of_four_squares, + sum_of_three_squares, transformation_to_DN, transformation_to_normal, + classify_diop, base_solution_linear, cornacchia, sqf_normal, gaussian_reduce, holzer, + check_param, parametrize_ternary_quadratic, sum_of_powers, sum_of_squares, + _diop_ternary_quadratic_normal, _nint_or_floor, + _odd, _even, _remove_gcd, _can_do_sum_of_squares, DiophantineSolutionSet, GeneralPythagorean, + BinaryQuadratic) + +from sympy.testing.pytest import slow, raises, XFAIL +from sympy.utilities.iterables import ( + signed_permutations) + +a, b, c, d, p, q, x, y, z, w, t, u, v, X, Y, Z = symbols( + "a, b, c, d, p, q, x, y, z, w, t, u, v, X, Y, Z", integer=True) +t_0, t_1, t_2, t_3, t_4, t_5, t_6 = symbols("t_:7", integer=True) +m1, m2, m3 = symbols('m1:4', integer=True) +n1 = symbols('n1', integer=True) + + +def diop_simplify(eq): + return _mexpand(powsimp(_mexpand(eq))) + + +def test_input_format(): + raises(TypeError, lambda: diophantine(sin(x))) + raises(TypeError, lambda: diophantine(x/pi - 3)) + + +def test_nosols(): + # diophantine should sympify eq so that these are equivalent + assert diophantine(3) == set() + assert diophantine(S(3)) == set() + + +def test_univariate(): + assert diop_solve((x - 1)*(x - 2)**2) == {(1,), (2,)} + assert diop_solve((x - 1)*(x - 2)) == {(1,), (2,)} + + +def test_classify_diop(): + raises(TypeError, lambda: classify_diop(x**2/3 - 1)) + raises(ValueError, lambda: classify_diop(1)) + raises(NotImplementedError, lambda: classify_diop(w*x*y*z - 1)) + raises(NotImplementedError, lambda: classify_diop(x**3 + y**3 + z**4 - 90)) + assert classify_diop(14*x**2 + 15*x - 42) == ( + [x], {1: -42, x: 15, x**2: 14}, 'univariate') + assert classify_diop(x*y + z) == ( + [x, y, z], {x*y: 1, z: 1}, 'inhomogeneous_ternary_quadratic') + assert classify_diop(x*y + z + w + x**2) == ( + [w, x, y, z], {x*y: 1, w: 1, x**2: 1, z: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + x*z + x**2 + 1) == ( + [x, y, z], {x*y: 1, x*z: 1, x**2: 1, 1: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + z + w + 42) == ( + [w, x, y, z], {x*y: 1, w: 1, 1: 42, z: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + z*w) == ( + [w, x, y, z], {x*y: 1, w*z: 1}, 'homogeneous_general_quadratic') + assert classify_diop(x*y**2 + 1) == ( + [x, y], {x*y**2: 1, 1: 1}, 'cubic_thue') + assert classify_diop(x**4 + y**4 + z**4 - (1 + 16 + 81)) == ( + [x, y, z], {1: -98, x**4: 1, z**4: 1, y**4: 1}, 'general_sum_of_even_powers') + assert classify_diop(x**2 + y**2 + z**2) == ( + [x, y, z], {x**2: 1, y**2: 1, z**2: 1}, 'homogeneous_ternary_quadratic_normal') + + +def test_linear(): + assert diop_solve(x) == (0,) + assert diop_solve(1*x) == (0,) + assert diop_solve(3*x) == (0,) + assert diop_solve(x + 1) == (-1,) + assert diop_solve(2*x + 1) == (None,) + assert diop_solve(2*x + 4) == (-2,) + assert diop_solve(y + x) == (t_0, -t_0) + assert diop_solve(y + x + 0) == (t_0, -t_0) + assert diop_solve(y + x - 0) == (t_0, -t_0) + assert diop_solve(0*x - y - 5) == (-5,) + assert diop_solve(3*y + 2*x - 5) == (3*t_0 - 5, -2*t_0 + 5) + assert diop_solve(2*x - 3*y - 5) == (3*t_0 - 5, 2*t_0 - 5) + assert diop_solve(-2*x - 3*y - 5) == (3*t_0 + 5, -2*t_0 - 5) + assert diop_solve(7*x + 5*y) == (5*t_0, -7*t_0) + assert diop_solve(2*x + 4*y) == (-2*t_0, t_0) + assert diop_solve(4*x + 6*y - 4) == (3*t_0 - 2, -2*t_0 + 2) + assert diop_solve(4*x + 6*y - 3) == (None, None) + assert diop_solve(0*x + 3*y - 4*z + 5) == (4*t_0 + 5, 3*t_0 + 5) + assert diop_solve(4*x + 3*y - 4*z + 5) == (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5) + assert diop_solve(4*x + 3*y - 4*z + 5, None) == (0, 5, 5) + assert diop_solve(4*x + 2*y + 8*z - 5) == (None, None, None) + assert diop_solve(5*x + 7*y - 2*z - 6) == (t_0, -3*t_0 + 2*t_1 + 6, -8*t_0 + 7*t_1 + 18) + assert diop_solve(3*x - 6*y + 12*z - 9) == (2*t_0 + 3, t_0 + 2*t_1, t_1) + assert diop_solve(6*w + 9*x + 20*y - z) == (t_0, t_1, t_1 + t_2, 6*t_0 + 29*t_1 + 20*t_2) + + # to ignore constant factors, use diophantine + raises(TypeError, lambda: diop_solve(x/2)) + + +def test_quadratic_simple_hyperbolic_case(): + # Simple Hyperbolic case: A = C = 0 and B != 0 + assert diop_solve(3*x*y + 34*x - 12*y + 1) == \ + {(-133, -11), (5, -57)} + assert diop_solve(6*x*y + 2*x + 3*y + 1) == set() + assert diop_solve(-13*x*y + 2*x - 4*y - 54) == {(27, 0)} + assert diop_solve(-27*x*y - 30*x - 12*y - 54) == {(-14, -1)} + assert diop_solve(2*x*y + 5*x + 56*y + 7) == {(-161, -3), (-47, -6), (-35, -12), + (-29, -69), (-27, 64), (-21, 7), + (-9, 1), (105, -2)} + assert diop_solve(6*x*y + 9*x + 2*y + 3) == set() + assert diop_solve(x*y + x + y + 1) == {(-1, t), (t, -1)} + assert diophantine(48*x*y) + + +def test_quadratic_elliptical_case(): + # Elliptical case: B**2 - 4AC < 0 + + assert diop_solve(42*x**2 + 8*x*y + 15*y**2 + 23*x + 17*y - 4915) == {(-11, -1)} + assert diop_solve(4*x**2 + 3*y**2 + 5*x - 11*y + 12) == set() + assert diop_solve(x**2 + y**2 + 2*x + 2*y + 2) == {(-1, -1)} + assert diop_solve(15*x**2 - 9*x*y + 14*y**2 - 23*x - 14*y - 4950) == {(-15, 6)} + assert diop_solve(10*x**2 + 12*x*y + 12*y**2 - 34) == \ + {(-1, -1), (-1, 2), (1, -2), (1, 1)} + + +def test_quadratic_parabolic_case(): + # Parabolic case: B**2 - 4AC = 0 + assert check_solutions(8*x**2 - 24*x*y + 18*y**2 + 5*x + 7*y + 16) + assert check_solutions(8*x**2 - 24*x*y + 18*y**2 + 6*x + 12*y - 6) + assert check_solutions(8*x**2 + 24*x*y + 18*y**2 + 4*x + 6*y - 7) + assert check_solutions(-4*x**2 + 4*x*y - y**2 + 2*x - 3) + assert check_solutions(x**2 + 2*x*y + y**2 + 2*x + 2*y + 1) + assert check_solutions(x**2 - 2*x*y + y**2 + 2*x + 2*y + 1) + assert check_solutions(y**2 - 41*x + 40) + + +def test_quadratic_perfect_square(): + # B**2 - 4*A*C > 0 + # B**2 - 4*A*C is a perfect square + assert check_solutions(48*x*y) + assert check_solutions(4*x**2 - 5*x*y + y**2 + 2) + assert check_solutions(-2*x**2 - 3*x*y + 2*y**2 -2*x - 17*y + 25) + assert check_solutions(12*x**2 + 13*x*y + 3*y**2 - 2*x + 3*y - 12) + assert check_solutions(8*x**2 + 10*x*y + 2*y**2 - 32*x - 13*y - 23) + assert check_solutions(4*x**2 - 4*x*y - 3*y- 8*x - 3) + assert check_solutions(- 4*x*y - 4*y**2 - 3*y- 5*x - 10) + assert check_solutions(x**2 - y**2 - 2*x - 2*y) + assert check_solutions(x**2 - 9*y**2 - 2*x - 6*y) + assert check_solutions(4*x**2 - 9*y**2 - 4*x - 12*y - 3) + + +def test_quadratic_non_perfect_square(): + # B**2 - 4*A*C is not a perfect square + # Used check_solutions() since the solutions are complex expressions involving + # square roots and exponents + assert check_solutions(x**2 - 2*x - 5*y**2) + assert check_solutions(3*x**2 - 2*y**2 - 2*x - 2*y) + assert check_solutions(x**2 - x*y - y**2 - 3*y) + assert check_solutions(x**2 - 9*y**2 - 2*x - 6*y) + assert BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2).solve() == {(-1, -1)} + + +def test_issue_9106(): + eq = -48 - 2*x*(3*x - 1) + y*(3*y - 1) + v = (x, y) + for sol in diophantine(eq): + assert not diop_simplify(eq.xreplace(dict(zip(v, sol)))) + + +def test_issue_18138(): + eq = x**2 - x - y**2 + v = (x, y) + for sol in diophantine(eq): + assert not diop_simplify(eq.xreplace(dict(zip(v, sol)))) + + +@slow +def test_quadratic_non_perfect_slow(): + assert check_solutions(8*x**2 + 10*x*y - 2*y**2 - 32*x - 13*y - 23) + # This leads to very large numbers. + # assert check_solutions(5*x**2 - 13*x*y + y**2 - 4*x - 4*y - 15) + assert check_solutions(-3*x**2 - 2*x*y + 7*y**2 - 5*x - 7) + assert check_solutions(-4 - x + 4*x**2 - y - 3*x*y - 4*y**2) + assert check_solutions(1 + 2*x + 2*x**2 + 2*y + x*y - 2*y**2) + + +def test_DN(): + # Most of the test cases were adapted from, + # Solving the generalized Pell equation x**2 - D*y**2 = N, John P. Robertson, July 31, 2004. + # https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + # others are verified using Wolfram Alpha. + + # Covers cases where D <= 0 or D > 0 and D is a square or N = 0 + # Solutions are straightforward in these cases. + assert diop_DN(3, 0) == [(0, 0)] + assert diop_DN(-17, -5) == [] + assert diop_DN(-19, 23) == [(2, 1)] + assert diop_DN(-13, 17) == [(2, 1)] + assert diop_DN(-15, 13) == [] + assert diop_DN(0, 5) == [] + assert diop_DN(0, 9) == [(3, t)] + assert diop_DN(9, 0) == [(3*t, t)] + assert diop_DN(16, 24) == [] + assert diop_DN(9, 180) == [(18, 4)] + assert diop_DN(9, -180) == [(12, 6)] + assert diop_DN(7, 0) == [(0, 0)] + + # When equation is x**2 + y**2 = N + # Solutions are interchangeable + assert diop_DN(-1, 5) == [(2, 1), (1, 2)] + assert diop_DN(-1, 169) == [(12, 5), (5, 12), (13, 0), (0, 13)] + + # D > 0 and D is not a square + + # N = 1 + assert diop_DN(13, 1) == [(649, 180)] + assert diop_DN(980, 1) == [(51841, 1656)] + assert diop_DN(981, 1) == [(158070671986249, 5046808151700)] + assert diop_DN(986, 1) == [(49299, 1570)] + assert diop_DN(991, 1) == [(379516400906811930638014896080, 12055735790331359447442538767)] + assert diop_DN(17, 1) == [(33, 8)] + assert diop_DN(19, 1) == [(170, 39)] + + # N = -1 + assert diop_DN(13, -1) == [(18, 5)] + assert diop_DN(991, -1) == [] + assert diop_DN(41, -1) == [(32, 5)] + assert diop_DN(290, -1) == [(17, 1)] + assert diop_DN(21257, -1) == [(13913102721304, 95427381109)] + assert diop_DN(32, -1) == [] + + # |N| > 1 + # Some tests were created using calculator at + # http://www.numbertheory.org/php/patz.html + + assert diop_DN(13, -4) == [(3, 1), (393, 109), (36, 10)] + # Source I referred returned (3, 1), (393, 109) and (-3, 1) as fundamental solutions + # So (-3, 1) and (393, 109) should be in the same equivalent class + assert equivalent(-3, 1, 393, 109, 13, -4) == True + + assert diop_DN(13, 27) == [(220, 61), (40, 11), (768, 213), (12, 3)] + assert set(diop_DN(157, 12)) == {(13, 1), (10663, 851), (579160, 46222), + (483790960, 38610722), (26277068347, 2097138361), + (21950079635497, 1751807067011)} + assert diop_DN(13, 25) == [(3245, 900)] + assert diop_DN(192, 18) == [] + assert diop_DN(23, 13) == [(-6, 1), (6, 1)] + assert diop_DN(167, 2) == [(13, 1)] + assert diop_DN(167, -2) == [] + + assert diop_DN(123, -2) == [(11, 1)] + # One calculator returned [(11, 1), (-11, 1)] but both of these are in + # the same equivalence class + assert equivalent(11, 1, -11, 1, 123, -2) + + assert diop_DN(123, -23) == [(-10, 1), (10, 1)] + + assert diop_DN(0, 0, t) == [(0, t)] + assert diop_DN(0, -1, t) == [] + + +def test_bf_pell(): + assert diop_bf_DN(13, -4) == [(3, 1), (-3, 1), (36, 10)] + assert diop_bf_DN(13, 27) == [(12, 3), (-12, 3), (40, 11), (-40, 11)] + assert diop_bf_DN(167, -2) == [] + assert diop_bf_DN(1729, 1) == [(44611924489705, 1072885712316)] + assert diop_bf_DN(89, -8) == [(9, 1), (-9, 1)] + assert diop_bf_DN(21257, -1) == [(13913102721304, 95427381109)] + assert diop_bf_DN(340, -4) == [(756, 41)] + assert diop_bf_DN(-1, 0, t) == [(0, 0)] + assert diop_bf_DN(0, 0, t) == [(0, t)] + assert diop_bf_DN(4, 0, t) == [(2*t, t), (-2*t, t)] + assert diop_bf_DN(3, 0, t) == [(0, 0)] + assert diop_bf_DN(1, -2, t) == [] + + +def test_length(): + assert length(2, 1, 0) == 1 + assert length(-2, 4, 5) == 3 + assert length(-5, 4, 17) == 4 + assert length(0, 4, 13) == 6 + assert length(7, 13, 11) == 23 + assert length(1, 6, 4) == 2 + + +def is_pell_transformation_ok(eq): + """ + Test whether X*Y, X, or Y terms are present in the equation + after transforming the equation using the transformation returned + by transformation_to_pell(). If they are not present we are good. + Moreover, coefficient of X**2 should be a divisor of coefficient of + Y**2 and the constant term. + """ + A, B = transformation_to_DN(eq) + u = (A*Matrix([X, Y]) + B)[0] + v = (A*Matrix([X, Y]) + B)[1] + simplified = diop_simplify(eq.subs(zip((x, y), (u, v)))) + + coeff = dict([reversed(t.as_independent(*[X, Y])) for t in simplified.args]) + + for term in [X*Y, X, Y]: + if term in coeff.keys(): + return False + + for term in [X**2, Y**2, 1]: + if term not in coeff.keys(): + coeff[term] = 0 + + if coeff[X**2] != 0: + return divisible(coeff[Y**2], coeff[X**2]) and \ + divisible(coeff[1], coeff[X**2]) + + return True + + +def test_transformation_to_pell(): + assert is_pell_transformation_ok(-13*x**2 - 7*x*y + y**2 + 2*x - 2*y - 14) + assert is_pell_transformation_ok(-17*x**2 + 19*x*y - 7*y**2 - 5*x - 13*y - 23) + assert is_pell_transformation_ok(x**2 - y**2 + 17) + assert is_pell_transformation_ok(-x**2 + 7*y**2 - 23) + assert is_pell_transformation_ok(25*x**2 - 45*x*y + 5*y**2 - 5*x - 10*y + 5) + assert is_pell_transformation_ok(190*x**2 + 30*x*y + y**2 - 3*y - 170*x - 130) + assert is_pell_transformation_ok(x**2 - 2*x*y -190*y**2 - 7*y - 23*x - 89) + assert is_pell_transformation_ok(15*x**2 - 9*x*y + 14*y**2 - 23*x - 14*y - 4950) + + +def test_find_DN(): + assert find_DN(x**2 - 2*x - y**2) == (1, 1) + assert find_DN(x**2 - 3*y**2 - 5) == (3, 5) + assert find_DN(x**2 - 2*x*y - 4*y**2 - 7) == (5, 7) + assert find_DN(4*x**2 - 8*x*y - y**2 - 9) == (20, 36) + assert find_DN(7*x**2 - 2*x*y - y**2 - 12) == (8, 84) + assert find_DN(-3*x**2 + 4*x*y -y**2) == (1, 0) + assert find_DN(-13*x**2 - 7*x*y + y**2 + 2*x - 2*y -14) == (101, -7825480) + + +def test_ldescent(): + # Equations which have solutions + u = ([(13, 23), (3, -11), (41, -113), (4, -7), (-7, 4), (91, -3), (1, 1), (1, -1), + (4, 32), (17, 13), (123689, 1), (19, -570)]) + for a, b in u: + w, x, y = ldescent(a, b) + assert a*x**2 + b*y**2 == w**2 + assert ldescent(-1, -1) is None + assert ldescent(2, 6) is None + + +def test_diop_ternary_quadratic_normal(): + assert check_solutions(234*x**2 - 65601*y**2 - z**2) + assert check_solutions(23*x**2 + 616*y**2 - z**2) + assert check_solutions(5*x**2 + 4*y**2 - z**2) + assert check_solutions(3*x**2 + 6*y**2 - 3*z**2) + assert check_solutions(x**2 + 3*y**2 - z**2) + assert check_solutions(4*x**2 + 5*y**2 - z**2) + assert check_solutions(x**2 + y**2 - z**2) + assert check_solutions(16*x**2 + y**2 - 25*z**2) + assert check_solutions(6*x**2 - y**2 + 10*z**2) + assert check_solutions(213*x**2 + 12*y**2 - 9*z**2) + assert check_solutions(34*x**2 - 3*y**2 - 301*z**2) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + + +def is_normal_transformation_ok(eq): + A = transformation_to_normal(eq) + X, Y, Z = A*Matrix([x, y, z]) + simplified = diop_simplify(eq.subs(zip((x, y, z), (X, Y, Z)))) + + coeff = dict([reversed(t.as_independent(*[X, Y, Z])) for t in simplified.args]) + for term in [X*Y, Y*Z, X*Z]: + if term in coeff.keys(): + return False + + return True + + +def test_transformation_to_normal(): + assert is_normal_transformation_ok(x**2 + 3*y**2 + z**2 - 13*x*y - 16*y*z + 12*x*z) + assert is_normal_transformation_ok(x**2 + 3*y**2 - 100*z**2) + assert is_normal_transformation_ok(x**2 + 23*y*z) + assert is_normal_transformation_ok(3*y**2 - 100*z**2 - 12*x*y) + assert is_normal_transformation_ok(x**2 + 23*x*y - 34*y*z + 12*x*z) + assert is_normal_transformation_ok(z**2 + 34*x*y - 23*y*z + x*z) + assert is_normal_transformation_ok(x**2 + y**2 + z**2 - x*y - y*z - x*z) + assert is_normal_transformation_ok(x**2 + 2*y*z + 3*z**2) + assert is_normal_transformation_ok(x*y + 2*x*z + 3*y*z) + assert is_normal_transformation_ok(2*x*z + 3*y*z) + + +def test_diop_ternary_quadratic(): + assert check_solutions(2*x**2 + z**2 + y**2 - 4*x*y) + assert check_solutions(x**2 - y**2 - z**2 - x*y - y*z) + assert check_solutions(3*x**2 - x*y - y*z - x*z) + assert check_solutions(x**2 - y*z - x*z) + assert check_solutions(5*x**2 - 3*x*y - x*z) + assert check_solutions(4*x**2 - 5*y**2 - x*z) + assert check_solutions(3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) + assert check_solutions(8*x**2 - 12*y*z) + assert check_solutions(45*x**2 - 7*y**2 - 8*x*y - z**2) + assert check_solutions(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y) + assert check_solutions(90*x**2 + 3*y**2 + 5*x*y + 2*z*y + 5*x*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - x*y - 17*y*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - x*y - 16*y*z + 12*x*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - 13*x*y - 16*y*z + 12*x*z) + assert check_solutions(x*y - 7*y*z + 13*x*z) + + assert diop_ternary_quadratic_normal(x**2 + y**2 + z**2) == (None, None, None) + assert diop_ternary_quadratic_normal(x**2 + y**2) is None + raises(ValueError, lambda: + _diop_ternary_quadratic_normal((x, y, z), + {x*y: 1, x**2: 2, y**2: 3, z**2: 0})) + eq = -2*x*y - 6*x*z + 7*y**2 - 3*y*z + 4*z**2 + assert diop_ternary_quadratic(eq) == (7, 2, 0) + assert diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2) == \ + (1, 0, 2) + assert diop_ternary_quadratic(x*y + 2*y*z) == \ + (-2, 0, n1) + eq = -5*x*y - 8*x*z - 3*y*z + 8*z**2 + assert parametrize_ternary_quadratic(eq) == \ + (8*p**2 - 3*p*q, -8*p*q + 8*q**2, 5*p*q) + # this cannot be tested with diophantine because it will + # factor into a product + assert diop_solve(x*y + 2*y*z) == (-2*p*q, -n1*p**2 + p**2, p*q) + + +def test_square_factor(): + assert square_factor(1) == square_factor(-1) == 1 + assert square_factor(0) == 1 + assert square_factor(5) == square_factor(-5) == 1 + assert square_factor(4) == square_factor(-4) == 2 + assert square_factor(12) == square_factor(-12) == 2 + assert square_factor(6) == 1 + assert square_factor(18) == 3 + assert square_factor(52) == 2 + assert square_factor(49) == 7 + assert square_factor(392) == 14 + assert square_factor(factorint(-12)) == 2 + + +def test_parametrize_ternary_quadratic(): + assert check_solutions(x**2 + y**2 - z**2) + assert check_solutions(x**2 + 2*x*y + z**2) + assert check_solutions(234*x**2 - 65601*y**2 - z**2) + assert check_solutions(3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) + assert check_solutions(x**2 - y**2 - z**2) + assert check_solutions(x**2 - 49*y**2 - z**2 + 13*z*y - 8*x*y) + assert check_solutions(8*x*y + z**2) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + assert check_solutions(236*x**2 - 225*y**2 - 11*x*y - 13*y*z - 17*x*z) + assert check_solutions(90*x**2 + 3*y**2 + 5*x*y + 2*z*y + 5*x*z) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + + +def test_no_square_ternary_quadratic(): + assert check_solutions(2*x*y + y*z - 3*x*z) + assert check_solutions(189*x*y - 345*y*z - 12*x*z) + assert check_solutions(23*x*y + 34*y*z) + assert check_solutions(x*y + y*z + z*x) + assert check_solutions(23*x*y + 23*y*z + 23*x*z) + + +def test_descent(): + + u = ([(13, 23), (3, -11), (41, -113), (91, -3), (1, 1), (1, -1), (17, 13), (123689, 1), (19, -570)]) + for a, b in u: + w, x, y = descent(a, b) + assert a*x**2 + b*y**2 == w**2 + # the docstring warns against bad input, so these are expected results + # - can't both be negative + raises(TypeError, lambda: descent(-1, -3)) + # A can't be zero unless B != 1 + raises(ZeroDivisionError, lambda: descent(0, 3)) + # supposed to be square-free + raises(TypeError, lambda: descent(4, 3)) + + +def test_diophantine(): + assert check_solutions((x - y)*(y - z)*(z - x)) + assert check_solutions((x - y)*(x**2 + y**2 - z**2)) + assert check_solutions((x - 3*y + 7*z)*(x**2 + y**2 - z**2)) + assert check_solutions(x**2 - 3*y**2 - 1) + assert check_solutions(y**2 + 7*x*y) + assert check_solutions(x**2 - 3*x*y + y**2) + assert check_solutions(z*(x**2 - y**2 - 15)) + assert check_solutions(x*(2*y - 2*z + 5)) + assert check_solutions((x**2 - 3*y**2 - 1)*(x**2 - y**2 - 15)) + assert check_solutions((x**2 - 3*y**2 - 1)*(y - 7*z)) + assert check_solutions((x**2 + y**2 - z**2)*(x - 7*y - 3*z + 4*w)) + # Following test case caused problems in parametric representation + # But this can be solved by factoring out y. + # No need to use methods for ternary quadratic equations. + assert check_solutions(y**2 - 7*x*y + 4*y*z) + assert check_solutions(x**2 - 2*x + 1) + + assert diophantine(x - y) == diophantine(Eq(x, y)) + # 18196 + eq = x**4 + y**4 - 97 + assert diophantine(eq, permute=True) == diophantine(-eq, permute=True) + assert diophantine(3*x*pi - 2*y*pi) == {(2*t_0, 3*t_0)} + eq = x**2 + y**2 + z**2 - 14 + base_sol = {(1, 2, 3)} + assert diophantine(eq) == base_sol + complete_soln = set(signed_permutations(base_sol.pop())) + assert diophantine(eq, permute=True) == complete_soln + + assert diophantine(x**2 + x*Rational(15, 14) - 3) == set() + # test issue 11049 + eq = 92*x**2 - 99*y**2 - z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(9, 7, 51)} + assert diophantine(eq) == {( + 891*p**2 + 9*q**2, -693*p**2 - 102*p*q + 7*q**2, + 5049*p**2 - 1386*p*q - 51*q**2)} + eq = 2*x**2 + 2*y**2 - z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(1, 1, 2)} + assert diophantine(eq) == {( + 2*p**2 - q**2, -2*p**2 + 4*p*q - q**2, + 4*p**2 - 4*p*q + 2*q**2)} + eq = 411*x**2+57*y**2-221*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(2021, 2645, 3066)} + assert diophantine(eq) == \ + {(115197*p**2 - 446641*q**2, -150765*p**2 + 1355172*p*q - + 584545*q**2, 174762*p**2 - 301530*p*q + 677586*q**2)} + eq = 573*x**2+267*y**2-984*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(49, 233, 127)} + assert diophantine(eq) == \ + {(4361*p**2 - 16072*q**2, -20737*p**2 + 83312*p*q - 76424*q**2, + 11303*p**2 - 41474*p*q + 41656*q**2)} + # this produces factors during reconstruction + eq = x**2 + 3*y**2 - 12*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(0, 2, 1)} + assert diophantine(eq) == \ + {(24*p*q, 2*p**2 - 24*q**2, p**2 + 12*q**2)} + # solvers have not been written for every type + raises(NotImplementedError, lambda: diophantine(x*y**2 + 1)) + + # rational expressions + assert diophantine(1/x) == set() + assert diophantine(1/x + 1/y - S.Half) == {(6, 3), (-2, 1), (4, 4), (1, -2), (3, 6)} + assert diophantine(x**2 + y**2 +3*x- 5, permute=True) == \ + {(-1, 1), (-4, -1), (1, -1), (1, 1), (-4, 1), (-1, -1), (4, 1), (4, -1)} + + + #test issue 18186 + assert diophantine(y**4 + x**4 - 2**4 - 3**4, syms=(x, y), permute=True) == \ + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + assert diophantine(y**4 + x**4 - 2**4 - 3**4, syms=(y, x), permute=True) == \ + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + + # issue 18122 + assert check_solutions(x**2 - y) + assert check_solutions(y**2 - x) + assert diophantine((x**2 - y), t) == {(t, t**2)} + assert diophantine((y**2 - x), t) == {(t**2, t)} + + +def test_general_pythagorean(): + from sympy.abc import a, b, c, d, e + + assert check_solutions(a**2 + b**2 + c**2 - d**2) + assert check_solutions(a**2 + 4*b**2 + 4*c**2 - d**2) + assert check_solutions(9*a**2 + 4*b**2 + 4*c**2 - d**2) + assert check_solutions(9*a**2 + 4*b**2 - 25*d**2 + 4*c**2 ) + assert check_solutions(9*a**2 - 16*d**2 + 4*b**2 + 4*c**2) + assert check_solutions(-e**2 + 9*a**2 + 4*b**2 + 4*c**2 + 25*d**2) + assert check_solutions(16*a**2 - b**2 + 9*c**2 + d**2 + 25*e**2) + + assert GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve(parameters=[x, y, z]) == \ + {(x**2 + y**2 - z**2, 2*x*z, 2*y*z, x**2 + y**2 + z**2)} + + +def test_diop_general_sum_of_squares_quick(): + for i in range(3, 10): + assert check_solutions(sum(i**2 for i in symbols(':%i' % i)) - i) + + assert diop_general_sum_of_squares(x**2 + y**2 - 2) is None + assert diop_general_sum_of_squares(x**2 + y**2 + z**2 + 2) == set() + eq = x**2 + y**2 + z**2 - (1 + 4 + 9) + assert diop_general_sum_of_squares(eq) == \ + {(1, 2, 3)} + eq = u**2 + v**2 + x**2 + y**2 + z**2 - 1313 + assert len(diop_general_sum_of_squares(eq, 3)) == 3 + # issue 11016 + var = symbols(':5') + (symbols('6', negative=True),) + eq = Add(*[i**2 for i in var]) - 112 + + base_soln = {(0, 1, 1, 5, 6, -7), (1, 1, 1, 3, 6, -8), (2, 3, 3, 4, 5, -7), (0, 1, 1, 1, 3, -10), + (0, 0, 4, 4, 4, -8), (1, 2, 3, 3, 5, -8), (0, 1, 2, 3, 7, -7), (2, 2, 4, 4, 6, -6), + (1, 1, 3, 4, 6, -7), (0, 2, 3, 3, 3, -9), (0, 0, 2, 2, 2, -10), (1, 1, 2, 3, 4, -9), + (0, 1, 1, 2, 5, -9), (0, 0, 2, 6, 6, -6), (1, 3, 4, 5, 5, -6), (0, 2, 2, 2, 6, -8), + (0, 3, 3, 3, 6, -7), (0, 2, 3, 5, 5, -7), (0, 1, 5, 5, 5, -6)} + assert diophantine(eq) == base_soln + assert len(diophantine(eq, permute=True)) == 196800 + + # handle negated squares with signsimp + assert diophantine(12 - x**2 - y**2 - z**2) == {(2, 2, 2)} + # diophantine handles simplification, so classify_diop should + # not have to look for additional patterns that are removed + # by diophantine + eq = a**2 + b**2 + c**2 + d**2 - 4 + raises(NotImplementedError, lambda: classify_diop(-eq)) + + +def test_issue_23807(): + # fixes recursion error + eq = x**2 + y**2 + z**2 - 1000000 + base_soln = {(0, 0, 1000), (0, 352, 936), (480, 600, 640), (24, 640, 768), (192, 640, 744), + (192, 480, 856), (168, 224, 960), (0, 600, 800), (280, 576, 768), (152, 480, 864), + (0, 280, 960), (352, 360, 864), (424, 480, 768), (360, 480, 800), (224, 600, 768), + (96, 360, 928), (168, 576, 800), (96, 480, 872)} + + assert diophantine(eq) == base_soln + + +def test_diop_partition(): + for n in [8, 10]: + for k in range(1, 8): + for p in partition(n, k): + assert len(p) == k + assert list(partition(3, 5)) == [] + assert [list(p) for p in partition(3, 5, 1)] == [ + [0, 0, 0, 0, 3], [0, 0, 0, 1, 2], [0, 0, 1, 1, 1]] + assert list(partition(0)) == [()] + assert list(partition(1, 0)) == [()] + assert [list(i) for i in partition(3)] == [[1, 1, 1], [1, 2], [3]] + + +def test_prime_as_sum_of_two_squares(): + for i in [5, 13, 17, 29, 37, 41, 2341, 3557, 34841, 64601]: + a, b = prime_as_sum_of_two_squares(i) + assert a**2 + b**2 == i + assert prime_as_sum_of_two_squares(7) is None + ans = prime_as_sum_of_two_squares(800029) + assert ans == (450, 773) and type(ans[0]) is int + + +def test_sum_of_three_squares(): + for i in [0, 1, 2, 34, 123, 34304595905, 34304595905394941, 343045959052344, + 800, 801, 802, 803, 804, 805, 806]: + a, b, c = sum_of_three_squares(i) + assert a**2 + b**2 + c**2 == i + assert a >= 0 + + # error + raises(ValueError, lambda: sum_of_three_squares(-1)) + + assert sum_of_three_squares(7) is None + assert sum_of_three_squares((4**5)*15) is None + # if there are two zeros, there might be a solution + # with only one zero, e.g. 25 => (0, 3, 4) or + # with no zeros, e.g. 49 => (2, 3, 6) + assert sum_of_three_squares(25) == (0, 0, 5) + assert sum_of_three_squares(4) == (0, 0, 2) + + +def test_sum_of_four_squares(): + from sympy.core.random import randint + + # this should never fail + n = randint(1, 100000000000000) + assert sum(i**2 for i in sum_of_four_squares(n)) == n + + # error + raises(ValueError, lambda: sum_of_four_squares(-1)) + + for n in range(1000): + result = sum_of_four_squares(n) + assert len(result) == 4 + assert all(r >= 0 for r in result) + assert sum(r**2 for r in result) == n + assert list(result) == sorted(result) + + +def test_power_representation(): + tests = [(1729, 3, 2), (234, 2, 4), (2, 1, 2), (3, 1, 3), (5, 2, 2), (12352, 2, 4), + (32760, 2, 3)] + + for test in tests: + n, p, k = test + f = power_representation(n, p, k) + + while True: + try: + l = next(f) + assert len(l) == k + + chk_sum = 0 + for l_i in l: + chk_sum = chk_sum + l_i**p + assert chk_sum == n + + except StopIteration: + break + + assert list(power_representation(20, 2, 4, True)) == \ + [(1, 1, 3, 3), (0, 0, 2, 4)] + raises(ValueError, lambda: list(power_representation(1.2, 2, 2))) + raises(ValueError, lambda: list(power_representation(2, 0, 2))) + raises(ValueError, lambda: list(power_representation(2, 2, 0))) + assert list(power_representation(-1, 2, 2)) == [] + assert list(power_representation(1, 1, 1)) == [(1,)] + assert list(power_representation(3, 2, 1)) == [] + assert list(power_representation(4, 2, 1)) == [(2,)] + assert list(power_representation(3**4, 4, 6, zeros=True)) == \ + [(1, 2, 2, 2, 2, 2), (0, 0, 0, 0, 0, 3)] + assert list(power_representation(3**4, 4, 5, zeros=False)) == [] + assert list(power_representation(-2, 3, 2)) == [(-1, -1)] + assert list(power_representation(-2, 4, 2)) == [] + assert list(power_representation(0, 3, 2, True)) == [(0, 0)] + assert list(power_representation(0, 3, 2, False)) == [] + # when we are dealing with squares, do feasibility checks + assert len(list(power_representation(4**10*(8*10 + 7), 2, 3))) == 0 + # there will be a recursion error if these aren't recognized + big = 2**30 + for i in [13, 10, 7, 5, 4, 2, 1]: + assert list(sum_of_powers(big, 2, big - i)) == [] + + +def test_assumptions(): + """ + Test whether diophantine respects the assumptions. + """ + #Test case taken from the below so question regarding assumptions in diophantine module + #https://stackoverflow.com/questions/23301941/how-can-i-declare-natural-symbols-with-sympy + m, n = symbols('m n', integer=True, positive=True) + diof = diophantine(n**2 + m*n - 500) + assert diof == {(5, 20), (40, 10), (95, 5), (121, 4), (248, 2), (499, 1)} + + a, b = symbols('a b', integer=True, positive=False) + diof = diophantine(a*b + 2*a + 3*b - 6) + assert diof == {(-15, -3), (-9, -4), (-7, -5), (-6, -6), (-5, -8), (-4, -14)} + + x, y = symbols('x y', integer=True) + diof = diophantine(10*x**2 + 5*x*y - 3*y) + assert diof == {(1, -5), (-3, 5), (0, 0)} + + x, y = symbols('x y', integer=True, positive=True) + diof = diophantine(10*x**2 + 5*x*y - 3*y) + assert diof == set() + + x, y = symbols('x y', integer=True, negative=False) + diof = diophantine(10*x**2 + 5*x*y - 3*y) + assert diof == {(0, 0)} + + +def check_solutions(eq): + """ + Determines whether solutions returned by diophantine() satisfy the original + equation. Hope to generalize this so we can remove functions like check_ternay_quadratic, + check_solutions_normal, check_solutions() + """ + s = diophantine(eq) + + factors = Mul.make_args(eq) + + var = list(eq.free_symbols) + var.sort(key=default_sort_key) + + while s: + solution = s.pop() + for f in factors: + if diop_simplify(f.subs(zip(var, solution))) == 0: + break + else: + return False + return True + + +def test_diopcoverage(): + eq = (2*x + y + 1)**2 + assert diop_solve(eq) == {(t_0, -2*t_0 - 1)} + eq = 2*x**2 + 6*x*y + 12*x + 4*y**2 + 18*y + 18 + assert diop_solve(eq) == {(t, -t - 3), (-2*t - 3, t)} + assert diop_quadratic(x + y**2 - 3) == {(-t**2 + 3, t)} + + assert diop_linear(x + y - 3) == (t_0, 3 - t_0) + + assert base_solution_linear(0, 1, 2, t=None) == (0, 0) + ans = (3*t - 1, -2*t + 1) + assert base_solution_linear(4, 8, 12, t) == ans + assert base_solution_linear(4, 8, 12, t=None) == tuple(_.subs(t, 0) for _ in ans) + + assert cornacchia(1, 1, 20) == set() + assert cornacchia(1, 1, 5) == {(2, 1)} + assert cornacchia(1, 2, 17) == {(3, 2)} + + raises(ValueError, lambda: reconstruct(4, 20, 1)) + + assert gaussian_reduce(4, 1, 3) == (1, 1) + eq = -w**2 - x**2 - y**2 + z**2 + + assert diop_general_pythagorean(eq) == \ + diop_general_pythagorean(-eq) == \ + (m1**2 + m2**2 - m3**2, 2*m1*m3, + 2*m2*m3, m1**2 + m2**2 + m3**2) + + assert len(check_param(S(3) + x/3, S(4) + x/2, S(2), [x])) == 0 + assert len(check_param(Rational(3, 2), S(4) + x, S(2), [x])) == 0 + assert len(check_param(S(4) + x, Rational(3, 2), S(2), [x])) == 0 + + assert _nint_or_floor(16, 10) == 2 + assert _odd(1) == (not _even(1)) == True + assert _odd(0) == (not _even(0)) == False + assert _remove_gcd(2, 4, 6) == (1, 2, 3) + raises(TypeError, lambda: _remove_gcd((2, 4, 6))) + assert sqf_normal(2*3**2*5, 2*5*11, 2*7**2*11) == \ + (11, 1, 5) + + # it's ok if these pass some day when the solvers are implemented + raises(NotImplementedError, lambda: diophantine(x**2 + y**2 + x*y + 2*y*z - 12)) + raises(NotImplementedError, lambda: diophantine(x**3 + y**2)) + assert diop_quadratic(x**2 + y**2 - 1**2 - 3**4) == \ + {(-9, -1), (-9, 1), (-1, -9), (-1, 9), (1, -9), (1, 9), (9, -1), (9, 1)} + + +def test_holzer(): + # if the input is good, don't let it diverge in holzer() + # (but see test_fail_holzer below) + assert holzer(2, 7, 13, 4, 79, 23) == (2, 7, 13) + + # None in uv condition met; solution is not Holzer reduced + # so this will hopefully change but is here for coverage + assert holzer(2, 6, 2, 1, 1, 10) == (2, 6, 2) + + raises(ValueError, lambda: holzer(2, 7, 14, 4, 79, 23)) + + +@XFAIL +def test_fail_holzer(): + eq = lambda x, y, z: a*x**2 + b*y**2 - c*z**2 + a, b, c = 4, 79, 23 + x, y, z = xyz = 26, 1, 11 + X, Y, Z = ans = 2, 7, 13 + assert eq(*xyz) == 0 + assert eq(*ans) == 0 + assert max(a*x**2, b*y**2, c*z**2) <= a*b*c + assert max(a*X**2, b*Y**2, c*Z**2) <= a*b*c + h = holzer(x, y, z, a, b, c) + assert h == ans # it would be nice to get the smaller soln + + +def test_issue_9539(): + assert diophantine(6*w + 9*y + 20*x - z) == \ + {(t_0, t_1, t_1 + t_2, 6*t_0 + 29*t_1 + 9*t_2)} + + +def test_issue_8943(): + assert diophantine( + 3*(x**2 + y**2 + z**2) - 14*(x*y + y*z + z*x)) == \ + {(0, 0, 0)} + + +def test_diop_sum_of_even_powers(): + eq = x**4 + y**4 + z**4 - 2673 + assert diop_solve(eq) == {(3, 6, 6), (2, 4, 7)} + assert diop_general_sum_of_even_powers(eq, 2) == {(3, 6, 6), (2, 4, 7)} + raises(NotImplementedError, lambda: diop_general_sum_of_even_powers(-eq, 2)) + neg = symbols('neg', negative=True) + eq = x**4 + y**4 + neg**4 - 2673 + assert diop_general_sum_of_even_powers(eq) == {(-3, 6, 6)} + assert diophantine(x**4 + y**4 + 2) == set() + assert diop_general_sum_of_even_powers(x**4 + y**4 - 2, limit=0) == set() + + +def test_sum_of_squares_powers(): + tru = {(0, 0, 1, 1, 11), (0, 0, 5, 7, 7), (0, 1, 3, 7, 8), (0, 1, 4, 5, 9), (0, 3, 4, 7, 7), (0, 3, 5, 5, 8), + (1, 1, 2, 6, 9), (1, 1, 6, 6, 7), (1, 2, 3, 3, 10), (1, 3, 4, 4, 9), (1, 5, 5, 6, 6), (2, 2, 3, 5, 9), + (2, 3, 5, 6, 7), (3, 3, 4, 5, 8)} + eq = u**2 + v**2 + x**2 + y**2 + z**2 - 123 + ans = diop_general_sum_of_squares(eq, oo) # allow oo to be used + assert len(ans) == 14 + assert ans == tru + + raises(ValueError, lambda: list(sum_of_squares(10, -1))) + assert list(sum_of_squares(1, 1)) == [(1,)] + assert list(sum_of_squares(1, 2)) == [] + assert list(sum_of_squares(1, 2, True)) == [(0, 1)] + assert list(sum_of_squares(-10, 2)) == [] + assert list(sum_of_squares(2, 3)) == [] + assert list(sum_of_squares(0, 3, True)) == [(0, 0, 0)] + assert list(sum_of_squares(0, 3)) == [] + assert list(sum_of_squares(4, 1)) == [(2,)] + assert list(sum_of_squares(5, 1)) == [] + assert list(sum_of_squares(50, 2)) == [(5, 5), (1, 7)] + assert list(sum_of_squares(11, 5, True)) == [ + (1, 1, 1, 2, 2), (0, 0, 1, 1, 3)] + assert list(sum_of_squares(8, 8)) == [(1, 1, 1, 1, 1, 1, 1, 1)] + + assert [len(list(sum_of_squares(i, 5, True))) for i in range(30)] == [ + 1, 1, 1, 1, 2, + 2, 1, 1, 2, 2, + 2, 2, 2, 3, 2, + 1, 3, 3, 3, 3, + 4, 3, 3, 2, 2, + 4, 4, 4, 4, 5] + assert [len(list(sum_of_squares(i, 5))) for i in range(30)] == [ + 0, 0, 0, 0, 0, + 1, 0, 0, 1, 0, + 0, 1, 0, 1, 1, + 0, 1, 1, 0, 1, + 2, 1, 1, 1, 1, + 1, 1, 1, 1, 3] + for i in range(30): + s1 = set(sum_of_squares(i, 5, True)) + assert not s1 or all(sum(j**2 for j in t) == i for t in s1) + s2 = set(sum_of_squares(i, 5)) + assert all(sum(j**2 for j in t) == i for t in s2) + + raises(ValueError, lambda: list(sum_of_powers(2, -1, 1))) + raises(ValueError, lambda: list(sum_of_powers(2, 1, -1))) + assert list(sum_of_powers(-2, 3, 2)) == [(-1, -1)] + assert list(sum_of_powers(-2, 4, 2)) == [] + assert list(sum_of_powers(2, 1, 1)) == [(2,)] + assert list(sum_of_powers(2, 1, 3, True)) == [(0, 0, 2), (0, 1, 1)] + assert list(sum_of_powers(5, 1, 2, True)) == [(0, 5), (1, 4), (2, 3)] + assert list(sum_of_powers(6, 2, 2)) == [] + assert list(sum_of_powers(3**5, 3, 1)) == [] + assert list(sum_of_powers(3**6, 3, 1)) == [(9,)] and (9**3 == 3**6) + assert list(sum_of_powers(2**1000, 5, 2)) == [] + + +def test__can_do_sum_of_squares(): + assert _can_do_sum_of_squares(3, -1) is False + assert _can_do_sum_of_squares(-3, 1) is False + assert _can_do_sum_of_squares(0, 1) + assert _can_do_sum_of_squares(4, 1) + assert _can_do_sum_of_squares(1, 2) + assert _can_do_sum_of_squares(2, 2) + assert _can_do_sum_of_squares(3, 2) is False + + +def test_diophantine_permute_sign(): + from sympy.abc import a, b, c, d, e + eq = a**4 + b**4 - (2**4 + 3**4) + base_sol = {(2, 3)} + assert diophantine(eq) == base_sol + complete_soln = set(signed_permutations(base_sol.pop())) + assert diophantine(eq, permute=True) == complete_soln + + eq = a**2 + b**2 + c**2 + d**2 + e**2 - 234 + assert len(diophantine(eq)) == 35 + assert len(diophantine(eq, permute=True)) == 62000 + soln = {(-1, -1), (-1, 2), (1, -2), (1, 1)} + assert diophantine(10*x**2 + 12*x*y + 12*y**2 - 34, permute=True) == soln + + +@XFAIL +def test_not_implemented(): + eq = x**2 + y**4 - 1**2 - 3**4 + assert diophantine(eq, syms=[x, y]) == {(9, 1), (1, 3)} + + +def test_issue_9538(): + eq = x - 3*y + 2 + assert diophantine(eq, syms=[y,x]) == {(t_0, 3*t_0 - 2)} + raises(TypeError, lambda: diophantine(eq, syms={y, x})) + + +def test_ternary_quadratic(): + # solution with 3 parameters + s = diophantine(2*x**2 + y**2 - 2*z**2) + p, q, r = ordered(S(s).free_symbols) + assert s == {( + p**2 - 2*q**2, + -2*p**2 + 4*p*q - 4*p*r - 4*q**2, + p**2 - 4*p*q + 2*q**2 - 4*q*r)} + # solution with Mul in solution + s = diophantine(x**2 + 2*y**2 - 2*z**2) + assert s == {(4*p*q, p**2 - 2*q**2, p**2 + 2*q**2)} + # solution with no Mul in solution + s = diophantine(2*x**2 + 2*y**2 - z**2) + assert s == {(2*p**2 - q**2, -2*p**2 + 4*p*q - q**2, + 4*p**2 - 4*p*q + 2*q**2)} + # reduced form when parametrized + s = diophantine(3*x**2 + 72*y**2 - 27*z**2) + assert s == {(24*p**2 - 9*q**2, 6*p*q, 8*p**2 + 3*q**2)} + assert parametrize_ternary_quadratic( + 3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) == ( + 2*p**2 - 2*p*q - q**2, 2*p**2 + 2*p*q - q**2, 2*p**2 - + 2*p*q + 3*q**2) + assert parametrize_ternary_quadratic( + 124*x**2 - 30*y**2 - 7729*z**2) == ( + -1410*p**2 - 363263*q**2, 2700*p**2 + 30916*p*q - + 695610*q**2, -60*p**2 + 5400*p*q + 15458*q**2) + + +def test_diophantine_solution_set(): + s1 = DiophantineSolutionSet([], []) + assert set(s1) == set() + assert s1.symbols == () + assert s1.parameters == () + raises(ValueError, lambda: s1.add((x,))) + assert list(s1.dict_iterator()) == [] + + s2 = DiophantineSolutionSet([x, y], [t, u]) + assert s2.symbols == (x, y) + assert s2.parameters == (t, u) + raises(ValueError, lambda: s2.add((1,))) + s2.add((3, 4)) + assert set(s2) == {(3, 4)} + s2.update((3, 4), (-1, u)) + assert set(s2) == {(3, 4), (-1, u)} + raises(ValueError, lambda: s1.update(s2)) + assert list(s2.dict_iterator()) == [{x: -1, y: u}, {x: 3, y: 4}] + + s3 = DiophantineSolutionSet([x, y, z], [t, u]) + assert len(s3.parameters) == 2 + s3.add((t**2 + u, t - u, 1)) + assert set(s3) == {(t**2 + u, t - u, 1)} + assert s3.subs(t, 2) == {(u + 4, 2 - u, 1)} + assert s3(2) == {(u + 4, 2 - u, 1)} + assert s3.subs({t: 7, u: 8}) == {(57, -1, 1)} + assert s3(7, 8) == {(57, -1, 1)} + assert s3.subs({t: 5}) == {(u + 25, 5 - u, 1)} + assert s3(5) == {(u + 25, 5 - u, 1)} + assert s3.subs(u, -3) == {(t**2 - 3, t + 3, 1)} + assert s3(None, -3) == {(t**2 - 3, t + 3, 1)} + assert s3.subs({t: 2, u: 8}) == {(12, -6, 1)} + assert s3(2, 8) == {(12, -6, 1)} + assert s3.subs({t: 5, u: -3}) == {(22, 8, 1)} + assert s3(5, -3) == {(22, 8, 1)} + raises(TypeError, lambda: s3.subs(x=1)) + raises(TypeError, lambda: s3.subs(1, 2, 3)) + raises(ValueError, lambda: s3.add(())) + raises(ValueError, lambda: s3.add((1, 2, 3, 4))) + raises(ValueError, lambda: s3.add((1, 2))) + raises(ValueError, lambda: s3(1, 2, 3)) + raises(TypeError, lambda: s3(t=1)) + + s4 = DiophantineSolutionSet([x, y], [t, u]) + s4.add((t, 11*t)) + s4.add((-t, 22*t)) + assert s4(0, 0) == {(0, 0)} + + +def test_quadratic_parameter_passing(): + eq = -33*x*y + 3*y**2 + solution = BinaryQuadratic(eq).solve(parameters=[t, u]) + # test that parameters are passed all the way to the final solution + assert solution == {(t, 11*t), (t, -22*t)} + assert solution(0, 0) == {(0, 0)} + +def test_issue_18628(): + eq1 = x**2 - 15*x + y**2 - 8*y + sol = diophantine(eq1) + assert sol == {(15, 0), (15, 8), (-1, 4), (0, 0), (0, 8), (16, 4)} + eq2 = 2*x**2 - 9*x + 4*y**2 - 8*y + 14 + sol = diophantine(eq2) + assert sol == {(2, 1)} diff --git a/phivenv/Lib/site-packages/sympy/solvers/inequalities.py b/phivenv/Lib/site-packages/sympy/solvers/inequalities.py new file mode 100644 index 0000000000000000000000000000000000000000..f50f7a7572ff25e8f48a4214d7b0c5ec6b5b35f3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/inequalities.py @@ -0,0 +1,986 @@ +"""Tools for solving inequalities and systems of inequalities. """ +import itertools + +from sympy.calculus.util import (continuous_domain, periodicity, + function_range) +from sympy.core import sympify +from sympy.core.exprtools import factor_terms +from sympy.core.relational import Relational, Lt, Ge, Eq +from sympy.core.symbol import Symbol, Dummy +from sympy.sets.sets import Interval, FiniteSet, Union, Intersection +from sympy.core.singleton import S +from sympy.core.function import expand_mul +from sympy.functions.elementary.complexes import Abs +from sympy.logic import And +from sympy.polys import Poly, PolynomialError, parallel_poly_from_expr +from sympy.polys.polyutils import _nsort +from sympy.solvers.solveset import solvify, solveset +from sympy.utilities.iterables import sift, iterable +from sympy.utilities.misc import filldedent + + +def solve_poly_inequality(poly, rel): + """Solve a polynomial inequality with rational coefficients. + + Examples + ======== + + >>> from sympy import solve_poly_inequality, Poly + >>> from sympy.abc import x + + >>> solve_poly_inequality(Poly(x, x, domain='ZZ'), '==') + [{0}] + + >>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '!=') + [Interval.open(-oo, -1), Interval.open(-1, 1), Interval.open(1, oo)] + + >>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '==') + [{-1}, {1}] + + See Also + ======== + solve_poly_inequalities + """ + if not isinstance(poly, Poly): + raise ValueError( + 'For efficiency reasons, `poly` should be a Poly instance') + if poly.as_expr().is_number: + t = Relational(poly.as_expr(), 0, rel) + if t is S.true: + return [S.Reals] + elif t is S.false: + return [S.EmptySet] + else: + raise NotImplementedError( + "could not determine truth value of %s" % t) + + reals, intervals = poly.real_roots(multiple=False), [] + + if rel == '==': + for root, _ in reals: + interval = Interval(root, root) + intervals.append(interval) + elif rel == '!=': + left = S.NegativeInfinity + + for right, _ in reals + [(S.Infinity, 1)]: + interval = Interval(left, right, True, True) + intervals.append(interval) + left = right + else: + if poly.LC() > 0: + sign = +1 + else: + sign = -1 + + eq_sign, equal = None, False + + if rel == '>': + eq_sign = +1 + elif rel == '<': + eq_sign = -1 + elif rel == '>=': + eq_sign, equal = +1, True + elif rel == '<=': + eq_sign, equal = -1, True + else: + raise ValueError("'%s' is not a valid relation" % rel) + + right, right_open = S.Infinity, True + + for left, multiplicity in reversed(reals): + if multiplicity % 2: + if sign == eq_sign: + intervals.insert( + 0, Interval(left, right, not equal, right_open)) + + sign, right, right_open = -sign, left, not equal + else: + if sign == eq_sign and not equal: + intervals.insert( + 0, Interval(left, right, True, right_open)) + right, right_open = left, True + elif sign != eq_sign and equal: + intervals.insert(0, Interval(left, left)) + + if sign == eq_sign: + intervals.insert( + 0, Interval(S.NegativeInfinity, right, True, right_open)) + + return intervals + + +def solve_poly_inequalities(polys): + """Solve polynomial inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.solvers.inequalities import solve_poly_inequalities + >>> from sympy.abc import x + >>> solve_poly_inequalities((( + ... Poly(x**2 - 3), ">"), ( + ... Poly(-x**2 + 1), ">"))) + Union(Interval.open(-oo, -sqrt(3)), Interval.open(-1, 1), Interval.open(sqrt(3), oo)) + """ + return Union(*[s for p in polys for s in solve_poly_inequality(*p)]) + + +def solve_rational_inequalities(eqs): + """Solve a system of rational inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import solve_rational_inequalities, Poly + + >>> solve_rational_inequalities([[ + ... ((Poly(-x + 1), Poly(1, x)), '>='), + ... ((Poly(-x + 1), Poly(1, x)), '<=')]]) + {1} + + >>> solve_rational_inequalities([[ + ... ((Poly(x), Poly(1, x)), '!='), + ... ((Poly(-x + 1), Poly(1, x)), '>=')]]) + Union(Interval.open(-oo, 0), Interval.Lopen(0, 1)) + + See Also + ======== + solve_poly_inequality + """ + result = S.EmptySet + + for _eqs in eqs: + if not _eqs: + continue + + global_intervals = [Interval(S.NegativeInfinity, S.Infinity)] + + for (numer, denom), rel in _eqs: + numer_intervals = solve_poly_inequality(numer*denom, rel) + denom_intervals = solve_poly_inequality(denom, '==') + + intervals = [] + + for numer_interval, global_interval in itertools.product( + numer_intervals, global_intervals): + interval = numer_interval.intersect(global_interval) + + if interval is not S.EmptySet: + intervals.append(interval) + + global_intervals = intervals + + intervals = [] + + for global_interval in global_intervals: + for denom_interval in denom_intervals: + global_interval -= denom_interval + + if global_interval is not S.EmptySet: + intervals.append(global_interval) + + global_intervals = intervals + + if not global_intervals: + break + + for interval in global_intervals: + result = result.union(interval) + + return result + + +def reduce_rational_inequalities(exprs, gen, relational=True): + """Reduce a system of rational inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.solvers.inequalities import reduce_rational_inequalities + + >>> x = Symbol('x', real=True) + + >>> reduce_rational_inequalities([[x**2 <= 0]], x) + Eq(x, 0) + + >>> reduce_rational_inequalities([[x + 2 > 0]], x) + -2 < x + >>> reduce_rational_inequalities([[(x + 2, ">")]], x) + -2 < x + >>> reduce_rational_inequalities([[x + 2]], x) + Eq(x, -2) + + This function find the non-infinite solution set so if the unknown symbol + is declared as extended real rather than real then the result may include + finiteness conditions: + + >>> y = Symbol('y', extended_real=True) + >>> reduce_rational_inequalities([[y + 2 > 0]], y) + (-2 < y) & (y < oo) + """ + exact = True + eqs = [] + solution = S.EmptySet # add pieces for each group + for _exprs in exprs: + if not _exprs: + continue + _eqs = [] + _sol = S.Reals + for expr in _exprs: + if isinstance(expr, tuple): + expr, rel = expr + else: + if expr.is_Relational: + expr, rel = expr.lhs - expr.rhs, expr.rel_op + else: + rel = '==' + + if expr is S.true: + numer, denom, rel = S.Zero, S.One, '==' + elif expr is S.false: + numer, denom, rel = S.One, S.One, '==' + else: + numer, denom = expr.together().as_numer_denom() + + try: + (numer, denom), opt = parallel_poly_from_expr( + (numer, denom), gen) + except PolynomialError: + raise PolynomialError(filldedent(''' + only polynomials and rational functions are + supported in this context. + ''')) + + if not opt.domain.is_Exact: + numer, denom, exact = numer.to_exact(), denom.to_exact(), False + + domain = opt.domain.get_exact() + + if not (domain.is_ZZ or domain.is_QQ): + expr = numer/denom + expr = Relational(expr, 0, rel) + _sol &= solve_univariate_inequality(expr, gen, relational=False) + else: + _eqs.append(((numer, denom), rel)) + + if _eqs: + _sol &= solve_rational_inequalities([_eqs]) + exclude = solve_rational_inequalities([[((d, d.one), '==') + for i in eqs for ((n, d), _) in i if d.has(gen)]]) + _sol -= exclude + + solution |= _sol + + if not exact and solution: + solution = solution.evalf() + + if relational: + solution = solution.as_relational(gen) + + return solution + + +def reduce_abs_inequality(expr, rel, gen): + """Reduce an inequality with nested absolute values. + + Examples + ======== + + >>> from sympy import reduce_abs_inequality, Abs, Symbol + >>> x = Symbol('x', real=True) + + >>> reduce_abs_inequality(Abs(x - 5) - 3, '<', x) + (2 < x) & (x < 8) + + >>> reduce_abs_inequality(Abs(x + 2)*3 - 13, '<', x) + (-19/3 < x) & (x < 7/3) + + See Also + ======== + + reduce_abs_inequalities + """ + if gen.is_extended_real is False: + raise TypeError(filldedent(''' + Cannot solve inequalities with absolute values containing + non-real variables. + ''')) + + def _bottom_up_scan(expr): + exprs = [] + + if expr.is_Add or expr.is_Mul: + op = expr.func + + for arg in expr.args: + _exprs = _bottom_up_scan(arg) + + if not exprs: + exprs = _exprs + else: + exprs = [(op(expr, _expr), conds + _conds) for (expr, conds), (_expr, _conds) in + itertools.product(exprs, _exprs)] + elif expr.is_Pow: + n = expr.exp + if not n.is_Integer: + raise ValueError("Only Integer Powers are allowed on Abs.") + + exprs.extend((expr**n, conds) for expr, conds in _bottom_up_scan(expr.base)) + elif isinstance(expr, Abs): + _exprs = _bottom_up_scan(expr.args[0]) + + for expr, conds in _exprs: + exprs.append(( expr, conds + [Ge(expr, 0)])) + exprs.append((-expr, conds + [Lt(expr, 0)])) + else: + exprs = [(expr, [])] + + return exprs + + mapping = {'<': '>', '<=': '>='} + inequalities = [] + + for expr, conds in _bottom_up_scan(expr): + if rel not in mapping.keys(): + expr = Relational( expr, 0, rel) + else: + expr = Relational(-expr, 0, mapping[rel]) + + inequalities.append([expr] + conds) + + return reduce_rational_inequalities(inequalities, gen) + + +def reduce_abs_inequalities(exprs, gen): + """Reduce a system of inequalities with nested absolute values. + + Examples + ======== + + >>> from sympy import reduce_abs_inequalities, Abs, Symbol + >>> x = Symbol('x', extended_real=True) + + >>> reduce_abs_inequalities([(Abs(3*x - 5) - 7, '<'), + ... (Abs(x + 25) - 13, '>')], x) + (-2/3 < x) & (x < 4) & (((-oo < x) & (x < -38)) | ((-12 < x) & (x < oo))) + + >>> reduce_abs_inequalities([(Abs(x - 4) + Abs(3*x - 5) - 7, '<')], x) + (1/2 < x) & (x < 4) + + See Also + ======== + + reduce_abs_inequality + """ + return And(*[ reduce_abs_inequality(expr, rel, gen) + for expr, rel in exprs ]) + + +def solve_univariate_inequality(expr, gen, relational=True, domain=S.Reals, continuous=False): + """Solves a real univariate inequality. + + Parameters + ========== + + expr : Relational + The target inequality + gen : Symbol + The variable for which the inequality is solved + relational : bool + A Relational type output is expected or not + domain : Set + The domain over which the equation is solved + continuous: bool + True if expr is known to be continuous over the given domain + (and so continuous_domain() does not need to be called on it) + + Raises + ====== + + NotImplementedError + The solution of the inequality cannot be determined due to limitation + in :func:`sympy.solvers.solveset.solvify`. + + Notes + ===== + + Currently, we cannot solve all the inequalities due to limitations in + :func:`sympy.solvers.solveset.solvify`. Also, the solution returned for trigonometric inequalities + are restricted in its periodic interval. + + See Also + ======== + + sympy.solvers.solveset.solvify: solver returning solveset solutions with solve's output API + + Examples + ======== + + >>> from sympy import solve_univariate_inequality, Symbol, sin, Interval, S + >>> x = Symbol('x') + + >>> solve_univariate_inequality(x**2 >= 4, x) + ((2 <= x) & (x < oo)) | ((-oo < x) & (x <= -2)) + + >>> solve_univariate_inequality(x**2 >= 4, x, relational=False) + Union(Interval(-oo, -2), Interval(2, oo)) + + >>> domain = Interval(0, S.Infinity) + >>> solve_univariate_inequality(x**2 >= 4, x, False, domain) + Interval(2, oo) + + >>> solve_univariate_inequality(sin(x) > 0, x, relational=False) + Interval.open(0, pi) + + """ + from sympy.solvers.solvers import denoms + + if domain.is_subset(S.Reals) is False: + raise NotImplementedError(filldedent(''' + Inequalities in the complex domain are + not supported. Try the real domain by + setting domain=S.Reals''')) + elif domain is not S.Reals: + rv = solve_univariate_inequality( + expr, gen, relational=False, continuous=continuous).intersection(domain) + if relational: + rv = rv.as_relational(gen) + return rv + else: + pass # continue with attempt to solve in Real domain + + # This keeps the function independent of the assumptions about `gen`. + # `solveset` makes sure this function is called only when the domain is + # real. + _gen = gen + _domain = domain + if gen.is_extended_real is False: + rv = S.EmptySet + return rv if not relational else rv.as_relational(_gen) + elif gen.is_extended_real is None: + gen = Dummy('gen', extended_real=True) + try: + expr = expr.xreplace({_gen: gen}) + except TypeError: + raise TypeError(filldedent(''' + When gen is real, the relational has a complex part + which leads to an invalid comparison like I < 0. + ''')) + + rv = None + + if expr is S.true: + rv = domain + + elif expr is S.false: + rv = S.EmptySet + + else: + e = expr.lhs - expr.rhs + period = periodicity(e, gen) + if period == S.Zero: + e = expand_mul(e) + const = expr.func(e, 0) + if const is S.true: + rv = domain + elif const is S.false: + rv = S.EmptySet + elif period is not None: + frange = function_range(e, gen, domain) + + rel = expr.rel_op + if rel in ('<', '<='): + if expr.func(frange.sup, 0): + rv = domain + elif not expr.func(frange.inf, 0): + rv = S.EmptySet + + elif rel in ('>', '>='): + if expr.func(frange.inf, 0): + rv = domain + elif not expr.func(frange.sup, 0): + rv = S.EmptySet + + inf, sup = domain.inf, domain.sup + if sup - inf is S.Infinity: + domain = Interval(0, period, False, True).intersect(_domain) + _domain = domain + + if rv is None: + n, d = e.as_numer_denom() + try: + if gen not in n.free_symbols and len(e.free_symbols) > 1: + raise ValueError + # this might raise ValueError on its own + # or it might give None... + solns = solvify(e, gen, domain) + if solns is None: + # in which case we raise ValueError + raise ValueError + except (ValueError, NotImplementedError): + # replace gen with generic x since it's + # univariate anyway + raise NotImplementedError(filldedent(''' + The inequality, %s, cannot be solved using + solve_univariate_inequality. + ''' % expr.subs(gen, Symbol('x')))) + + expanded_e = expand_mul(e) + def valid(x): + # this is used to see if gen=x satisfies the + # relational by substituting it into the + # expanded form and testing against 0, e.g. + # if expr = x*(x + 1) < 2 then e = x*(x + 1) - 2 + # and expanded_e = x**2 + x - 2; the test is + # whether a given value of x satisfies + # x**2 + x - 2 < 0 + # + # expanded_e, expr and gen used from enclosing scope + v = expanded_e.subs(gen, expand_mul(x)) + try: + r = expr.func(v, 0) + except TypeError: + r = S.false + if r in (S.true, S.false): + return r + if v.is_extended_real is False: + return S.false + else: + v = v.n(2) + if v.is_comparable: + return expr.func(v, 0) + # not comparable or couldn't be evaluated + raise NotImplementedError( + 'relationship did not evaluate: %s' % r) + + singularities = [] + for d in denoms(expr, gen): + singularities.extend(solvify(d, gen, domain)) + if not continuous: + domain = continuous_domain(expanded_e, gen, domain) + + include_x = '=' in expr.rel_op and expr.rel_op != '!=' + + try: + discontinuities = set(domain.boundary - + FiniteSet(domain.inf, domain.sup)) + # remove points that are not between inf and sup of domain + critical_points = FiniteSet(*(solns + singularities + list( + discontinuities))).intersection( + Interval(domain.inf, domain.sup, + domain.inf not in domain, domain.sup not in domain)) + if all(r.is_number for r in critical_points): + reals = _nsort(critical_points, separated=True)[0] + else: + sifted = sift(critical_points, lambda x: x.is_extended_real) + if sifted[None]: + # there were some roots that weren't known + # to be real + raise NotImplementedError + try: + reals = sifted[True] + if len(reals) > 1: + reals = sorted(reals) + except TypeError: + raise NotImplementedError + except NotImplementedError: + raise NotImplementedError('sorting of these roots is not supported') + + # If expr contains imaginary coefficients, only take real + # values of x for which the imaginary part is 0 + make_real = S.Reals + if (coeffI := expanded_e.coeff(S.ImaginaryUnit)) != S.Zero: + check = True + im_sol = FiniteSet() + try: + a = solveset(coeffI, gen, domain) + if not isinstance(a, Interval): + for z in a: + if z not in singularities and valid(z) and z.is_extended_real: + im_sol += FiniteSet(z) + else: + start, end = a.inf, a.sup + for z in _nsort(critical_points + FiniteSet(end)): + valid_start = valid(start) + if start != end: + valid_z = valid(z) + pt = _pt(start, z) + if pt not in singularities and pt.is_extended_real and valid(pt): + if valid_start and valid_z: + im_sol += Interval(start, z) + elif valid_start: + im_sol += Interval.Ropen(start, z) + elif valid_z: + im_sol += Interval.Lopen(start, z) + else: + im_sol += Interval.open(start, z) + start = z + for s in singularities: + im_sol -= FiniteSet(s) + except (TypeError): + im_sol = S.Reals + check = False + + if im_sol is S.EmptySet: + raise ValueError(filldedent(''' + %s contains imaginary parts which cannot be + made 0 for any value of %s satisfying the + inequality, leading to relations like I < 0. + ''' % (expr.subs(gen, _gen), _gen))) + + make_real = make_real.intersect(im_sol) + + sol_sets = [S.EmptySet] + + start = domain.inf + if start in domain and valid(start) and start.is_finite: + sol_sets.append(FiniteSet(start)) + + for x in reals: + end = x + + if valid(_pt(start, end)): + sol_sets.append(Interval(start, end, True, True)) + + if x in singularities: + singularities.remove(x) + else: + if x in discontinuities: + discontinuities.remove(x) + _valid = valid(x) + else: # it's a solution + _valid = include_x + if _valid: + sol_sets.append(FiniteSet(x)) + + start = end + + end = domain.sup + if end in domain and valid(end) and end.is_finite: + sol_sets.append(FiniteSet(end)) + + if valid(_pt(start, end)): + sol_sets.append(Interval.open(start, end)) + + if coeffI != S.Zero and check: + rv = (make_real).intersect(_domain) + else: + rv = Intersection( + (Union(*sol_sets)), make_real, _domain).subs(gen, _gen) + + return rv if not relational else rv.as_relational(_gen) + + +def _pt(start, end): + """Return a point between start and end""" + if not start.is_infinite and not end.is_infinite: + pt = (start + end)/2 + elif start.is_infinite and end.is_infinite: + pt = S.Zero + else: + if (start.is_infinite and start.is_extended_positive is None or + end.is_infinite and end.is_extended_positive is None): + raise ValueError('cannot proceed with unsigned infinite values') + if (end.is_infinite and end.is_extended_negative or + start.is_infinite and start.is_extended_positive): + start, end = end, start + # if possible, use a multiple of self which has + # better behavior when checking assumptions than + # an expression obtained by adding or subtracting 1 + if end.is_infinite: + if start.is_extended_positive: + pt = start*2 + elif start.is_extended_negative: + pt = start*S.Half + else: + pt = start + 1 + elif start.is_infinite: + if end.is_extended_positive: + pt = end*S.Half + elif end.is_extended_negative: + pt = end*2 + else: + pt = end - 1 + return pt + + +def _solve_inequality(ie, s, linear=False): + """Return the inequality with s isolated on the left, if possible. + If the relationship is non-linear, a solution involving And or Or + may be returned. False or True are returned if the relationship + is never True or always True, respectively. + + If `linear` is True (default is False) an `s`-dependent expression + will be isolated on the left, if possible + but it will not be solved for `s` unless the expression is linear + in `s`. Furthermore, only "safe" operations which do not change the + sense of the relationship are applied: no division by an unsigned + value is attempted unless the relationship involves Eq or Ne and + no division by a value not known to be nonzero is ever attempted. + + Examples + ======== + + >>> from sympy import Eq, Symbol + >>> from sympy.solvers.inequalities import _solve_inequality as f + >>> from sympy.abc import x, y + + For linear expressions, the symbol can be isolated: + + >>> f(x - 2 < 0, x) + x < 2 + >>> f(-x - 6 < x, x) + x > -3 + + Sometimes nonlinear relationships will be False + + >>> f(x**2 + 4 < 0, x) + False + + Or they may involve more than one region of values: + + >>> f(x**2 - 4 < 0, x) + (-2 < x) & (x < 2) + + To restrict the solution to a relational, set linear=True + and only the x-dependent portion will be isolated on the left: + + >>> f(x**2 - 4 < 0, x, linear=True) + x**2 < 4 + + Division of only nonzero quantities is allowed, so x cannot + be isolated by dividing by y: + + >>> y.is_nonzero is None # it is unknown whether it is 0 or not + True + >>> f(x*y < 1, x) + x*y < 1 + + And while an equality (or inequality) still holds after dividing by a + non-zero quantity + + >>> nz = Symbol('nz', nonzero=True) + >>> f(Eq(x*nz, 1), x) + Eq(x, 1/nz) + + the sign must be known for other inequalities involving > or <: + + >>> f(x*nz <= 1, x) + nz*x <= 1 + >>> p = Symbol('p', positive=True) + >>> f(x*p <= 1, x) + x <= 1/p + + When there are denominators in the original expression that + are removed by expansion, conditions for them will be returned + as part of the result: + + >>> f(x < x*(2/x - 1), x) + (x < 1) & Ne(x, 0) + """ + from sympy.solvers.solvers import denoms + if s not in ie.free_symbols: + return ie + if ie.rhs == s: + ie = ie.reversed + if ie.lhs == s and s not in ie.rhs.free_symbols: + return ie + + def classify(ie, s, i): + # return True or False if ie evaluates when substituting s with + # i else None (if unevaluated) or NaN (when there is an error + # in evaluating) + try: + v = ie.subs(s, i) + if v is S.NaN: + return v + elif v not in (True, False): + return + return v + except TypeError: + return S.NaN + + rv = None + oo = S.Infinity + expr = ie.lhs - ie.rhs + try: + p = Poly(expr, s) + if p.degree() == 0: + rv = ie.func(p.as_expr(), 0) + elif not linear and p.degree() > 1: + # handle in except clause + raise NotImplementedError + except (PolynomialError, NotImplementedError): + if not linear: + try: + rv = reduce_rational_inequalities([[ie]], s) + except PolynomialError: + rv = solve_univariate_inequality(ie, s) + # remove restrictions wrt +/-oo that may have been + # applied when using sets to simplify the relationship + okoo = classify(ie, s, oo) + if okoo is S.true and classify(rv, s, oo) is S.false: + rv = rv.subs(s < oo, True) + oknoo = classify(ie, s, -oo) + if (oknoo is S.true and + classify(rv, s, -oo) is S.false): + rv = rv.subs(-oo < s, True) + rv = rv.subs(s > -oo, True) + if rv is S.true: + rv = (s <= oo) if okoo is S.true else (s < oo) + if oknoo is not S.true: + rv = And(-oo < s, rv) + else: + p = Poly(expr) + + conds = [] + if rv is None: + e = p.as_expr() # this is in expanded form + # Do a safe inversion of e, moving non-s terms + # to the rhs and dividing by a nonzero factor if + # the relational is Eq/Ne; for other relationals + # the sign must also be positive or negative + rhs = 0 + b, ax = e.as_independent(s, as_Add=True) + e -= b + rhs -= b + ef = factor_terms(e) + a, e = ef.as_independent(s, as_Add=False) + if (a.is_zero != False or # don't divide by potential 0 + a.is_negative == + a.is_positive is None and # if sign is not known then + ie.rel_op not in ('!=', '==')): # reject if not Eq/Ne + e = ef + a = S.One + rhs /= a + if a.is_positive: + rv = ie.func(e, rhs) + else: + rv = ie.reversed.func(e, rhs) + + # return conditions under which the value is + # valid, too. + beginning_denoms = denoms(ie.lhs) | denoms(ie.rhs) + current_denoms = denoms(rv) + for d in beginning_denoms - current_denoms: + c = _solve_inequality(Eq(d, 0), s, linear=linear) + if isinstance(c, Eq) and c.lhs == s: + if classify(rv, s, c.rhs) is S.true: + # rv is permitting this value but it shouldn't + conds.append(~c) + for i in (-oo, oo): + if (classify(rv, s, i) is S.true and + classify(ie, s, i) is not S.true): + conds.append(s < i if i is oo else i < s) + + conds.append(rv) + return And(*conds) + + +def _reduce_inequalities(inequalities, symbols): + # helper for reduce_inequalities + + poly_part, abs_part = {}, {} + other = [] + + for inequality in inequalities: + + expr, rel = inequality.lhs, inequality.rel_op # rhs is 0 + + # check for gens using atoms which is more strict than free_symbols to + # guard against EX domain which won't be handled by + # reduce_rational_inequalities + gens = expr.atoms(Symbol) + + if len(gens) == 1: + gen = gens.pop() + else: + common = expr.free_symbols & symbols + if len(common) == 1: + gen = common.pop() + other.append(_solve_inequality(Relational(expr, 0, rel), gen)) + continue + else: + raise NotImplementedError(filldedent(''' + inequality has more than one symbol of interest. + ''')) + + if expr.is_polynomial(gen): + poly_part.setdefault(gen, []).append((expr, rel)) + else: + components = expr.find(lambda u: + u.has(gen) and ( + u.is_Function or u.is_Pow and not u.exp.is_Integer)) + if components and all(isinstance(i, Abs) for i in components): + abs_part.setdefault(gen, []).append((expr, rel)) + else: + other.append(_solve_inequality(Relational(expr, 0, rel), gen)) + + poly_reduced = [reduce_rational_inequalities([exprs], gen) for gen, exprs in poly_part.items()] + abs_reduced = [reduce_abs_inequalities(exprs, gen) for gen, exprs in abs_part.items()] + + return And(*(poly_reduced + abs_reduced + other)) + + +def reduce_inequalities(inequalities, symbols=[]): + """Reduce a system of inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import reduce_inequalities + + >>> reduce_inequalities(0 <= x + 3, []) + (-3 <= x) & (x < oo) + + >>> reduce_inequalities(0 <= x + y*2 - 1, [x]) + (x < oo) & (x >= 1 - 2*y) + """ + if not iterable(inequalities): + inequalities = [inequalities] + inequalities = [sympify(i) for i in inequalities] + + gens = set().union(*[i.free_symbols for i in inequalities]) + + if not iterable(symbols): + symbols = [symbols] + symbols = (set(symbols) or gens) & gens + if any(i.is_extended_real is False for i in symbols): + raise TypeError(filldedent(''' + inequalities cannot contain symbols that are not real. + ''')) + + # make vanilla symbol real + recast = {i: Dummy(i.name, extended_real=True) + for i in gens if i.is_extended_real is None} + inequalities = [i.xreplace(recast) for i in inequalities] + symbols = {i.xreplace(recast) for i in symbols} + + # prefilter + keep = [] + for i in inequalities: + if isinstance(i, Relational): + i = i.func(i.lhs.as_expr() - i.rhs.as_expr(), 0) + elif i not in (True, False): + i = Eq(i, 0) + if i == True: + continue + elif i == False: + return S.false + if i.lhs.is_number: + raise NotImplementedError( + "could not determine truth value of %s" % i) + keep.append(i) + inequalities = keep + del keep + + # solve system + rv = _reduce_inequalities(inequalities, symbols) + + # restore original symbols and return + return rv.xreplace({v: k for k, v in recast.items()}) diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/ode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b543425251dea6380a1860279cb6d636f3dd629 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/__init__.py @@ -0,0 +1,16 @@ +from .ode import (allhints, checkinfsol, classify_ode, + constantsimp, dsolve, homogeneous_order) + +from .lie_group import infinitesimals + +from .subscheck import checkodesol + +from .systems import (canonical_odes, linear_ode_to_matrix, + linodesolve) + + +__all__ = [ + 'allhints', 'checkinfsol', 'checkodesol', 'classify_ode', 'constantsimp', + 'dsolve', 'homogeneous_order', 'infinitesimals', 'canonical_odes', 'linear_ode_to_matrix', + 'linodesolve' +] diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa4124c8ae89fb9dc50d0a2710c6f8201f57b3dc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/hypergeometric.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/hypergeometric.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f2fef5e00be2707e6152d6b466bfa3247ac731c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/hypergeometric.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/lie_group.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/lie_group.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0777bb2e4f4b0d406ce7f85745f9f1d97403b593 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/lie_group.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd30392b018d6654da5fcf5b11b57aa8970c04f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/riccati.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/riccati.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10170e4fbe2cd56d6da18ea9c699b0dd8ce000b7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/riccati.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/subscheck.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/subscheck.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..376bafd6a7871f3023d7c5fc2aaddda8ddc27719 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/subscheck.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/systems.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/systems.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee30af26769ef000829976236dd9b04b1d645856 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/__pycache__/systems.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/hypergeometric.py b/phivenv/Lib/site-packages/sympy/solvers/ode/hypergeometric.py new file mode 100644 index 0000000000000000000000000000000000000000..5699d2418058acaf48d1e4f87f6635a7b1f7284c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/hypergeometric.py @@ -0,0 +1,272 @@ +r''' +This module contains the implementation of the 2nd_hypergeometric hint for +dsolve. This is an incomplete implementation of the algorithm described in [1]. +The algorithm solves 2nd order linear ODEs of the form + +.. math:: y'' + A(x) y' + B(x) y = 0\text{,} + +where `A` and `B` are rational functions. The algorithm should find any +solution of the form + +.. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,} + +where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function". +Currently only the 2F1 case is implemented in SymPy but the other cases are +described in the paper and could be implemented in future (contributions +welcome!). + +References +========== + +.. [1] L. Chan, E.S. Cheb-Terrab, Non-Liouvillian solutions for second order + linear ODEs, (2004). + https://arxiv.org/abs/math-ph/0402063 +''' + +from sympy.core import S, Pow +from sympy.core.function import expand +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol, Wild +from sympy.functions import exp, sqrt, hyper +from sympy.integrals import Integral +from sympy.polys import roots, gcd +from sympy.polys.polytools import cancel, factor +from sympy.simplify import collect, simplify, logcombine # type: ignore +from sympy.simplify.powsimp import powdenest +from sympy.solvers.ode.ode import get_numbered_constants + + +def match_2nd_hypergeometric(eq, func): + x = func.args[0] + df = func.diff(x) + a3 = Wild('a3', exclude=[func, func.diff(x), func.diff(x, 2)]) + b3 = Wild('b3', exclude=[func, func.diff(x), func.diff(x, 2)]) + c3 = Wild('c3', exclude=[func, func.diff(x), func.diff(x, 2)]) + deq = a3*(func.diff(x, 2)) + b3*df + c3*func + r = collect(eq, + [func.diff(x, 2), func.diff(x), func]).match(deq) + if r: + if not all(val.is_polynomial() for val in r.values()): + n, d = eq.as_numer_denom() + eq = expand(n) + r = collect(eq, [func.diff(x, 2), func.diff(x), func]).match(deq) + + if r and r[a3]!=0: + A = cancel(r[b3]/r[a3]) + B = cancel(r[c3]/r[a3]) + return [A, B] + else: + return [] + + +def equivalence_hypergeometric(A, B, func): + # This method for finding the equivalence is only for 2F1 type. + # We can extend it for 1F1 and 0F1 type also. + x = func.args[0] + + # making given equation in normal form + I1 = factor(cancel(A.diff(x)/2 + A**2/4 - B)) + + # computing shifted invariant(J1) of the equation + J1 = factor(cancel(x**2*I1 + S(1)/4)) + num, dem = J1.as_numer_denom() + num = powdenest(expand(num)) + dem = powdenest(expand(dem)) + # this function will compute the different powers of variable(x) in J1. + # then it will help in finding value of k. k is power of x such that we can express + # J1 = x**k * J0(x**k) then all the powers in J0 become integers. + def _power_counting(num): + _pow = {0} + for val in num: + if val.has(x): + if isinstance(val, Pow) and val.as_base_exp()[0] == x: + _pow.add(val.as_base_exp()[1]) + elif val == x: + _pow.add(val.as_base_exp()[1]) + else: + _pow.update(_power_counting(val.args)) + return _pow + + pow_num = _power_counting((num, )) + pow_dem = _power_counting((dem, )) + pow_dem.update(pow_num) + + _pow = pow_dem + k = gcd(_pow) + + # computing I0 of the given equation + I0 = powdenest(simplify(factor(((J1/k**2) - S(1)/4)/((x**k)**2))), force=True) + I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1)/k)), force=True))) + + # Before this point I0, J1 might be functions of e.g. sqrt(x) but replacing + # x with x**(1/k) should result in I0 being a rational function of x or + # otherwise the hypergeometric solver cannot be used. Note that k can be a + # non-integer rational such as 2/7. + if not I0.is_rational_function(x): + return None + + num, dem = I0.as_numer_denom() + + max_num_pow = max(_power_counting((num, ))) + dem_args = dem.args + sing_point = [] + dem_pow = [] + # calculating singular point of I0. + for arg in dem_args: + if arg.has(x): + if isinstance(arg, Pow): + # (x-a)**n + dem_pow.append(arg.as_base_exp()[1]) + sing_point.append(list(roots(arg.as_base_exp()[0], x).keys())[0]) + else: + # (x-a) type + dem_pow.append(arg.as_base_exp()[1]) + sing_point.append(list(roots(arg, x).keys())[0]) + + dem_pow.sort() + # checking if equivalence is exists or not. + + if equivalence(max_num_pow, dem_pow) == "2F1": + return {'I0':I0, 'k':k, 'sing_point':sing_point, 'type':"2F1"} + else: + return None + + +def match_2nd_2F1_hypergeometric(I, k, sing_point, func): + x = func.args[0] + a = Wild("a") + b = Wild("b") + c = Wild("c") + t = Wild("t") + s = Wild("s") + r = Wild("r") + alpha = Wild("alpha") + beta = Wild("beta") + gamma = Wild("gamma") + delta = Wild("delta") + # I0 of the standard 2F1 equation. + I0 = ((a-b+1)*(a-b-1)*x**2 + 2*((1-a-b)*c + 2*a*b)*x + c*(c-2))/(4*x**2*(x-1)**2) + if sing_point != [0, 1]: + # If singular point is [0, 1] then we have standard equation. + eqs = [] + sing_eqs = [-beta/alpha, -delta/gamma, (delta-beta)/(alpha-gamma)] + # making equations for the finding the mobius transformation + for i in range(3): + if i>> from sympy import Function, Eq, pprint + >>> from sympy.abc import x, y + >>> xi, eta, h = map(Function, ['xi', 'eta', 'h']) + >>> h = h(x, y) # dy/dx = h + >>> eta = eta(x, y) + >>> xi = xi(x, y) + >>> genform = Eq(eta.diff(x) + (eta.diff(y) - xi.diff(x))*h + ... - (xi.diff(y))*h**2 - xi*(h.diff(x)) - eta*(h.diff(y)), 0) + >>> pprint(genform) + /d d \ d 2 d d d + |--(eta(x, y)) - --(xi(x, y))|*h(x, y) - eta(x, y)*--(h(x, y)) - h (x, y)*--(xi(x, y)) - xi(x, y)*--(h(x, y)) + --(eta(x, y)) = 0 + \dy dx / dy dy dx dx + + Solving the above mentioned PDE is not trivial, and can be solved only by + making intelligent assumptions for `\xi` and `\eta` (heuristics). Once an + infinitesimal is found, the attempt to find more heuristics stops. This is done to + optimise the speed of solving the differential equation. If a list of all the + infinitesimals is needed, ``hint`` should be flagged as ``all``, which gives + the complete list of infinitesimals. If the infinitesimals for a particular + heuristic needs to be found, it can be passed as a flag to ``hint``. + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.solvers.ode.lie_group import infinitesimals + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = f(x).diff(x) - x**2*f(x) + >>> infinitesimals(eq) + [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0}] + + References + ========== + + - Solving differential equations by Symmetry Groups, + John Starrett, pp. 1 - pp. 14 + + """ + + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + if not func: + eq, func = _preprocess(eq) + variables = func.args + if len(variables) != 1: + raise ValueError("ODE's have only one independent variable") + else: + x = variables[0] + if not order: + order = ode_order(eq, func) + if order != 1: + raise NotImplementedError("Infinitesimals for only " + "first order ODE's have been implemented") + else: + df = func.diff(x) + # Matching differential equation of the form a*df + b + a = Wild('a', exclude = [df]) + b = Wild('b', exclude = [df]) + if match: # Used by lie_group hint + h = match['h'] + y = match['y'] + else: + match = collect(expand(eq), df).match(a*df + b) + if match: + h = -simplify(match[b]/match[a]) + else: + try: + sol = solve(eq, df) + except NotImplementedError: + raise NotImplementedError("Infinitesimals for the " + "first order ODE could not be found") + else: + h = sol[0] # Find infinitesimals for one solution + y = Dummy("y") + h = h.subs(func, y) + + u = Dummy("u") + hx = h.diff(x) + hy = h.diff(y) + hinv = ((1/h).subs([(x, u), (y, x)])).subs(u, y) # Inverse ODE + match = {'h': h, 'func': func, 'hx': hx, 'hy': hy, 'y': y, 'hinv': hinv} + if hint == 'all': + xieta = [] + for heuristic in lie_heuristics: + function = globals()['lie_heuristic_' + heuristic] + inflist = function(match, comp=True) + if inflist: + xieta.extend([inf for inf in inflist if inf not in xieta]) + if xieta: + return xieta + else: + raise NotImplementedError("Infinitesimals could not be found for " + "the given ODE") + + elif hint == 'default': + for heuristic in lie_heuristics: + function = globals()['lie_heuristic_' + heuristic] + xieta = function(match, comp=False) + if xieta: + return xieta + + raise NotImplementedError("Infinitesimals could not be found for" + " the given ODE") + + elif hint not in lie_heuristics: + raise ValueError("Heuristic not recognized: " + hint) + + else: + function = globals()['lie_heuristic_' + hint] + xieta = function(match, comp=True) + if xieta: + return xieta + else: + raise ValueError("Infinitesimals could not be found using the" + " given heuristic") + + +def lie_heuristic_abaco1_simple(match, comp=False): + r""" + The first heuristic uses the following four sets of + assumptions on `\xi` and `\eta` + + .. math:: \xi = 0, \eta = f(x) + + .. math:: \xi = 0, \eta = f(y) + + .. math:: \xi = f(x), \eta = 0 + + .. math:: \xi = f(y), \eta = 0 + + The success of this heuristic is determined by algebraic factorisation. + For the first assumption `\xi = 0` and `\eta` to be a function of `x`, the PDE + + .. math:: \frac{\partial \eta}{\partial x} + (\frac{\partial \eta}{\partial y} + - \frac{\partial \xi}{\partial x})*h + - \frac{\partial \xi}{\partial y}*h^{2} + - \xi*\frac{\partial h}{\partial x} - \eta*\frac{\partial h}{\partial y} = 0 + + reduces to `f'(x) - f\frac{\partial h}{\partial y} = 0` + If `\frac{\partial h}{\partial y}` is a function of `x`, then this can usually + be integrated easily. A similar idea is applied to the other 3 assumptions as well. + + + References + ========== + + - E.S Cheb-Terrab, L.G.S Duarte and L.A,C.P da Mota, Computer Algebra + Solving of First Order ODEs Using Symmetry Methods, pp. 8 + + + """ + + xieta = [] + y = match['y'] + h = match['h'] + func = match['func'] + x = func.args[0] + hx = match['hx'] + hy = match['hy'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + hysym = hy.free_symbols + if y not in hysym: + try: + fx = exp(integrate(hy, x)) + except NotImplementedError: + pass + else: + inf = {xi: S.Zero, eta: fx} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = hy/h + facsym = factor.free_symbols + if x not in facsym: + try: + fy = exp(integrate(factor, y)) + except NotImplementedError: + pass + else: + inf = {xi: S.Zero, eta: fy.subs(y, func)} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = -hx/h + facsym = factor.free_symbols + if y not in facsym: + try: + fx = exp(integrate(factor, x)) + except NotImplementedError: + pass + else: + inf = {xi: fx, eta: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = -hx/(h**2) + facsym = factor.free_symbols + if x not in facsym: + try: + fy = exp(integrate(factor, y)) + except NotImplementedError: + pass + else: + inf = {xi: fy.subs(y, func), eta: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_abaco1_product(match, comp=False): + r""" + The second heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = 0, \xi = f(x)*g(y) + + .. math:: \eta = f(x)*g(y), \xi = 0 + + The first assumption of this heuristic holds good if + `\frac{1}{h^{2}}\frac{\partial^2}{\partial x \partial y}\log(h)` is + separable in `x` and `y`, then the separated factors containing `x` + is `f(x)`, and `g(y)` is obtained by + + .. math:: e^{\int f\frac{\partial}{\partial x}\left(\frac{1}{f*h}\right)\,dy} + + provided `f\frac{\partial}{\partial x}\left(\frac{1}{f*h}\right)` is a function + of `y` only. + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first assumption + satisfies. After obtaining `f(x)` and `g(y)`, the coordinates are again + interchanged, to get `\eta` as `f(x)*g(y)` + + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 7 - pp. 8 + + """ + + xieta = [] + y = match['y'] + h = match['h'] + hinv = match['hinv'] + func = match['func'] + x = func.args[0] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + + inf = separatevars(((log(h).diff(y)).diff(x))/h**2, dict=True, symbols=[x, y]) + if inf and inf['coeff']: + fx = inf[x] + gy = simplify(fx*((1/(fx*h)).diff(x))) + gysyms = gy.free_symbols + if x not in gysyms: + gy = exp(integrate(gy, y)) + inf = {eta: S.Zero, xi: (fx*gy).subs(y, func)} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + u1 = Dummy("u1") + inf = separatevars(((log(hinv).diff(y)).diff(x))/hinv**2, dict=True, symbols=[x, y]) + if inf and inf['coeff']: + fx = inf[x] + gy = simplify(fx*((1/(fx*hinv)).diff(x))) + gysyms = gy.free_symbols + if x not in gysyms: + gy = exp(integrate(gy, y)) + etaval = fx*gy + etaval = (etaval.subs([(x, u1), (y, x)])).subs(u1, y) + inf = {eta: etaval.subs(y, func), xi: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_bivariate(match, comp=False): + r""" + The third heuristic assumes the infinitesimals `\xi` and `\eta` + to be bi-variate polynomials in `x` and `y`. The assumption made here + for the logic below is that `h` is a rational function in `x` and `y` + though that may not be necessary for the infinitesimals to be + bivariate polynomials. The coefficients of the infinitesimals + are found out by substituting them in the PDE and grouping similar terms + that are polynomials and since they form a linear system, solve and check + for non trivial solutions. The degree of the assumed bivariates + are increased till a certain maximum value. + + References + ========== + - Lie Groups and Differential Equations + pp. 327 - pp. 329 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + if h.is_rational_function(): + # The maximum degree that the infinitesimals can take is + # calculated by this technique. + etax, etay, etad, xix, xiy, xid = symbols("etax etay etad xix xiy xid") + ipde = etax + (etay - xix)*h - xiy*h**2 - xid*hx - etad*hy + num, denom = cancel(ipde).as_numer_denom() + deg = Poly(num, x, y).total_degree() + deta = Function('deta')(x, y) + dxi = Function('dxi')(x, y) + ipde = (deta.diff(x) + (deta.diff(y) - dxi.diff(x))*h - (dxi.diff(y))*h**2 + - dxi*hx - deta*hy) + xieq = Symbol("xi0") + etaeq = Symbol("eta0") + + for i in range(deg + 1): + if i: + xieq += Add(*[ + Symbol("xi_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + etaeq += Add(*[ + Symbol("eta_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + pden, denom = (ipde.subs({dxi: xieq, deta: etaeq}).doit()).as_numer_denom() + pden = expand(pden) + + # If the individual terms are monomials, the coefficients + # are grouped + if pden.is_polynomial(x, y) and pden.is_Add: + polyy = Poly(pden, x, y).as_dict() + if polyy: + symset = xieq.free_symbols.union(etaeq.free_symbols) - {x, y} + soldict = solve(polyy.values(), *symset) + if isinstance(soldict, list): + soldict = soldict[0] + if any(soldict.values()): + xired = xieq.subs(soldict) + etared = etaeq.subs(soldict) + # Scaling is done by substituting one for the parameters + # This can be any number except zero. + dict_ = dict.fromkeys(symset, 1) + inf = {eta: etared.subs(dict_).subs(y, func), + xi: xired.subs(dict_).subs(y, func)} + return [inf] + +def lie_heuristic_chi(match, comp=False): + r""" + The aim of the fourth heuristic is to find the function `\chi(x, y)` + that satisfies the PDE `\frac{d\chi}{dx} + h\frac{d\chi}{dx} + - \frac{\partial h}{\partial y}\chi = 0`. + + This assumes `\chi` to be a bivariate polynomial in `x` and `y`. By intuition, + `h` should be a rational function in `x` and `y`. The method used here is + to substitute a general binomial for `\chi` up to a certain maximum degree + is reached. The coefficients of the polynomials, are calculated by by collecting + terms of the same order in `x` and `y`. + + After finding `\chi`, the next step is to use `\eta = \xi*h + \chi`, to + determine `\xi` and `\eta`. This can be done by dividing `\chi` by `h` + which would give `-\xi` as the quotient and `\eta` as the remainder. + + + References + ========== + - E.S Cheb-Terrab, L.G.S Duarte and L.A,C.P da Mota, Computer Algebra + Solving of First Order ODEs Using Symmetry Methods, pp. 8 + + """ + + h = match['h'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + if h.is_rational_function(): + schi, schix, schiy = symbols("schi, schix, schiy") + cpde = schix + h*schiy - hy*schi + num, denom = cancel(cpde).as_numer_denom() + deg = Poly(num, x, y).total_degree() + + chi = Function('chi')(x, y) + chix = chi.diff(x) + chiy = chi.diff(y) + cpde = chix + h*chiy - hy*chi + chieq = Symbol("chi") + for i in range(1, deg + 1): + chieq += Add(*[ + Symbol("chi_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + cnum, cden = cancel(cpde.subs({chi : chieq}).doit()).as_numer_denom() + cnum = expand(cnum) + if cnum.is_polynomial(x, y) and cnum.is_Add: + cpoly = Poly(cnum, x, y).as_dict() + if cpoly: + solsyms = chieq.free_symbols - {x, y} + soldict = solve(cpoly.values(), *solsyms) + if isinstance(soldict, list): + soldict = soldict[0] + if any(soldict.values()): + chieq = chieq.subs(soldict) + dict_ = dict.fromkeys(solsyms, 1) + chieq = chieq.subs(dict_) + # After finding chi, the main aim is to find out + # eta, xi by the equation eta = xi*h + chi + # One method to set xi, would be rearranging it to + # (eta/h) - xi = (chi/h). This would mean dividing + # chi by h would give -xi as the quotient and eta + # as the remainder. Thanks to Sean Vig for suggesting + # this method. + xic, etac = div(chieq, h) + inf = {eta: etac.subs(y, func), xi: -xic.subs(y, func)} + return [inf] + +def lie_heuristic_function_sum(match, comp=False): + r""" + This heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = 0, \xi = f(x) + g(y) + + .. math:: \eta = f(x) + g(y), \xi = 0 + + The first assumption of this heuristic holds good if + + .. math:: \frac{\partial}{\partial y}[(h\frac{\partial^{2}}{ + \partial x^{2}}(h^{-1}))^{-1}] + + is separable in `x` and `y`, + + 1. The separated factors containing `y` is `\frac{\partial g}{\partial y}`. + From this `g(y)` can be determined. + 2. The separated factors containing `x` is `f''(x)`. + 3. `h\frac{\partial^{2}}{\partial x^{2}}(h^{-1})` equals + `\frac{f''(x)}{f(x) + g(y)}`. From this `f(x)` can be determined. + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first + assumption satisfies. After obtaining `f(x)` and `g(y)`, the coordinates + are again interchanged, to get `\eta` as `f(x) + g(y)`. + + For both assumptions, the constant factors are separated among `g(y)` + and `f''(x)`, such that `f''(x)` obtained from 3] is the same as that + obtained from 2]. If not possible, then this heuristic fails. + + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 7 - pp. 8 + + """ + + xieta = [] + h = match['h'] + func = match['func'] + hinv = match['hinv'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + for odefac in [h, hinv]: + factor = odefac*((1/odefac).diff(x, 2)) + sep = separatevars((1/factor).diff(y), dict=True, symbols=[x, y]) + if sep and sep['coeff'] and sep[x].has(x) and sep[y].has(y): + k = Dummy("k") + try: + gy = k*integrate(sep[y], y) + except NotImplementedError: + pass + else: + fdd = 1/(k*sep[x]*sep['coeff']) + fx = simplify(fdd/factor - gy) + check = simplify(fx.diff(x, 2) - fdd) + if fx: + if not check: + fx = fx.subs(k, 1) + gy = (gy/k) + else: + sol = solve(check, k) + if sol: + sol = sol[0] + fx = fx.subs(k, sol) + gy = (gy/k)*sol + else: + continue + if odefac == hinv: # Inverse ODE + fx = fx.subs(x, y) + gy = gy.subs(y, x) + etaval = factor_terms(fx + gy) + if etaval.is_Mul: + etaval = Mul(*[arg for arg in etaval.args if arg.has(x, y)]) + if odefac == hinv: # Inverse ODE + inf = {eta: etaval.subs(y, func), xi : S.Zero} + else: + inf = {xi: etaval.subs(y, func), eta : S.Zero} + if not comp: + return [inf] + else: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_abaco2_similar(match, comp=False): + r""" + This heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = g(x), \xi = f(x) + + .. math:: \eta = f(y), \xi = g(y) + + For the first assumption, + + 1. First `\frac{\frac{\partial h}{\partial y}}{\frac{\partial^{2} h}{ + \partial yy}}` is calculated. Let us say this value is A + + 2. If this is constant, then `h` is matched to the form `A(x) + B(x)e^{ + \frac{y}{C}}` then, `\frac{e^{\int \frac{A(x)}{C} \,dx}}{B(x)}` gives `f(x)` + and `A(x)*f(x)` gives `g(x)` + + 3. Otherwise `\frac{\frac{\partial A}{\partial X}}{\frac{\partial A}{ + \partial Y}} = \gamma` is calculated. If + + a] `\gamma` is a function of `x` alone + + b] `\frac{\gamma\frac{\partial h}{\partial y} - \gamma'(x) - \frac{ + \partial h}{\partial x}}{h + \gamma} = G` is a function of `x` alone. + then, `e^{\int G \,dx}` gives `f(x)` and `-\gamma*f(x)` gives `g(x)` + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first assumption + satisfies. After obtaining `f(x)` and `g(x)`, the coordinates are again + interchanged, to get `\xi` as `f(x^*)` and `\eta` as `g(y^*)` + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + hinv = match['hinv'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + factor = cancel(h.diff(y)/h.diff(y, 2)) + factorx = factor.diff(x) + factory = factor.diff(y) + if not factor.has(x) and not factor.has(y): + A = Wild('A', exclude=[y]) + B = Wild('B', exclude=[y]) + C = Wild('C', exclude=[x, y]) + match = h.match(A + B*exp(y/C)) + try: + tau = exp(-integrate(match[A]/match[C]), x)/match[B] + except NotImplementedError: + pass + else: + gx = match[A]*tau + return [{xi: tau, eta: gx}] + + else: + gamma = cancel(factorx/factory) + if not gamma.has(y): + tauint = cancel((gamma*hy - gamma.diff(x) - hx)/(h + gamma)) + if not tauint.has(y): + try: + tau = exp(integrate(tauint, x)) + except NotImplementedError: + pass + else: + gx = -tau*gamma + return [{xi: tau, eta: gx}] + + factor = cancel(hinv.diff(y)/hinv.diff(y, 2)) + factorx = factor.diff(x) + factory = factor.diff(y) + if not factor.has(x) and not factor.has(y): + A = Wild('A', exclude=[y]) + B = Wild('B', exclude=[y]) + C = Wild('C', exclude=[x, y]) + match = h.match(A + B*exp(y/C)) + try: + tau = exp(-integrate(match[A]/match[C]), x)/match[B] + except NotImplementedError: + pass + else: + gx = match[A]*tau + return [{eta: tau.subs(x, func), xi: gx.subs(x, func)}] + + else: + gamma = cancel(factorx/factory) + if not gamma.has(y): + tauint = cancel((gamma*hinv.diff(y) - gamma.diff(x) - hinv.diff(x))/( + hinv + gamma)) + if not tauint.has(y): + try: + tau = exp(integrate(tauint, x)) + except NotImplementedError: + pass + else: + gx = -tau*gamma + return [{eta: tau.subs(x, func), xi: gx.subs(x, func)}] + + +def lie_heuristic_abaco2_unique_unknown(match, comp=False): + r""" + This heuristic assumes the presence of unknown functions or known functions + with non-integer powers. + + 1. A list of all functions and non-integer powers containing x and y + 2. Loop over each element `f` in the list, find `\frac{\frac{\partial f}{\partial x}}{ + \frac{\partial f}{\partial x}} = R` + + If it is separable in `x` and `y`, let `X` be the factors containing `x`. Then + + a] Check if `\xi = X` and `\eta = -\frac{X}{R}` satisfy the PDE. If yes, then return + `\xi` and `\eta` + b] Check if `\xi = \frac{-R}{X}` and `\eta = -\frac{1}{X}` satisfy the PDE. + If yes, then return `\xi` and `\eta` + + If not, then check if + + a] :math:`\xi = -R,\eta = 1` + + b] :math:`\xi = 1, \eta = -\frac{1}{R}` + + are solutions. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + funclist = [] + for atom in h.atoms(Pow): + base, exp = atom.as_base_exp() + if base.has(x) and base.has(y): + if not exp.is_Integer: + funclist.append(atom) + + for function in h.atoms(AppliedUndef): + syms = function.free_symbols + if x in syms and y in syms: + funclist.append(function) + + for f in funclist: + frac = cancel(f.diff(y)/f.diff(x)) + sep = separatevars(frac, dict=True, symbols=[x, y]) + if sep and sep['coeff']: + xitry1 = sep[x] + etatry1 = -1/(sep[y]*sep['coeff']) + pde1 = etatry1.diff(y)*h - xitry1.diff(x)*h - xitry1*hx - etatry1*hy + if not simplify(pde1): + return [{xi: xitry1, eta: etatry1.subs(y, func)}] + xitry2 = 1/etatry1 + etatry2 = 1/xitry1 + pde2 = etatry2.diff(x) - (xitry2.diff(y))*h**2 - xitry2*hx - etatry2*hy + if not simplify(expand(pde2)): + return [{xi: xitry2.subs(y, func), eta: etatry2}] + + else: + etatry = -1/frac + pde = etatry.diff(x) + etatry.diff(y)*h - hx - etatry*hy + if not simplify(pde): + return [{xi: S.One, eta: etatry.subs(y, func)}] + xitry = -frac + pde = -xitry.diff(x)*h -xitry.diff(y)*h**2 - xitry*hx -hy + if not simplify(expand(pde)): + return [{xi: xitry.subs(y, func), eta: S.One}] + + +def lie_heuristic_abaco2_unique_general(match, comp=False): + r""" + This heuristic finds if infinitesimals of the form `\eta = f(x)`, `\xi = g(y)` + without making any assumptions on `h`. + + The complete sequence of steps is given in the paper mentioned below. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + A = hx.diff(y) + B = hy.diff(y) + hy**2 + C = hx.diff(x) - hx**2 + + if not (A and B and C): + return + + Ax = A.diff(x) + Ay = A.diff(y) + Axy = Ax.diff(y) + Axx = Ax.diff(x) + Ayy = Ay.diff(y) + D = simplify(2*Axy + hx*Ay - Ax*hy + (hx*hy + 2*A)*A)*A - 3*Ax*Ay + if not D: + E1 = simplify(3*Ax**2 + ((hx**2 + 2*C)*A - 2*Axx)*A) + if E1: + E2 = simplify((2*Ayy + (2*B - hy**2)*A)*A - 3*Ay**2) + if not E2: + E3 = simplify( + E1*((28*Ax + 4*hx*A)*A**3 - E1*(hy*A + Ay)) - E1.diff(x)*8*A**4) + if not E3: + etaval = cancel((4*A**3*(Ax - hx*A) + E1*(hy*A - Ay))/(S(2)*A*E1)) + if x not in etaval: + try: + etaval = exp(integrate(etaval, y)) + except NotImplementedError: + pass + else: + xival = -4*A**3*etaval/E1 + if y not in xival: + return [{xi: xival, eta: etaval.subs(y, func)}] + + else: + E1 = simplify((2*Ayy + (2*B - hy**2)*A)*A - 3*Ay**2) + if E1: + E2 = simplify( + 4*A**3*D - D**2 + E1*((2*Axx - (hx**2 + 2*C)*A)*A - 3*Ax**2)) + if not E2: + E3 = simplify( + -(A*D)*E1.diff(y) + ((E1.diff(x) - hy*D)*A + 3*Ay*D + + (A*hx - 3*Ax)*E1)*E1) + if not E3: + etaval = cancel(((A*hx - Ax)*E1 - (Ay + A*hy)*D)/(S(2)*A*D)) + if x not in etaval: + try: + etaval = exp(integrate(etaval, y)) + except NotImplementedError: + pass + else: + xival = -E1*etaval/D + if y not in xival: + return [{xi: xival, eta: etaval.subs(y, func)}] + + +def lie_heuristic_linear(match, comp=False): + r""" + This heuristic assumes + + 1. `\xi = ax + by + c` and + 2. `\eta = fx + gy + h` + + After substituting the following assumptions in the determining PDE, it + reduces to + + .. math:: f + (g - a)h - bh^{2} - (ax + by + c)\frac{\partial h}{\partial x} + - (fx + gy + c)\frac{\partial h}{\partial y} + + Solving the reduced PDE obtained, using the method of characteristics, becomes + impractical. The method followed is grouping similar terms and solving the system + of linear equations obtained. The difference between the bivariate heuristic is that + `h` need not be a rational function in this case. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + coeffdict = {} + symbols = numbered_symbols("c", cls=Dummy) + symlist = [next(symbols) for _ in islice(symbols, 6)] + C0, C1, C2, C3, C4, C5 = symlist + pde = C3 + (C4 - C0)*h - (C0*x + C1*y + C2)*hx - (C3*x + C4*y + C5)*hy - C1*h**2 + pde, denom = pde.as_numer_denom() + pde = powsimp(expand(pde)) + if pde.is_Add: + terms = pde.args + for term in terms: + if term.is_Mul: + rem = Mul(*[m for m in term.args if not m.has(x, y)]) + xypart = term/rem + if xypart not in coeffdict: + coeffdict[xypart] = rem + else: + coeffdict[xypart] += rem + else: + if term not in coeffdict: + coeffdict[term] = S.One + else: + coeffdict[term] += S.One + + sollist = coeffdict.values() + soldict = solve(sollist, symlist) + if soldict: + if isinstance(soldict, list): + soldict = soldict[0] + subval = soldict.values() + if any(t for t in subval): + onedict = dict(zip(symlist, [1]*6)) + xival = C0*x + C1*func + C2 + etaval = C3*x + C4*func + C5 + xival = xival.subs(soldict) + etaval = etaval.subs(soldict) + xival = xival.subs(onedict) + etaval = etaval.subs(onedict) + return [{xi: xival, eta: etaval}] + + +def _lie_group_remove(coords): + r""" + This function is strictly meant for internal use by the Lie group ODE solving + method. It replaces arbitrary functions returned by pdsolve as follows: + + 1] If coords is an arbitrary function, then its argument is returned. + 2] An arbitrary function in an Add object is replaced by zero. + 3] An arbitrary function in a Mul object is replaced by one. + 4] If there is no arbitrary function coords is returned unchanged. + + Examples + ======== + + >>> from sympy.solvers.ode.lie_group import _lie_group_remove + >>> from sympy import Function + >>> from sympy.abc import x, y + >>> F = Function("F") + >>> eq = x**2*y + >>> _lie_group_remove(eq) + x**2*y + >>> eq = F(x**2*y) + >>> _lie_group_remove(eq) + x**2*y + >>> eq = x*y**2 + F(x**3) + >>> _lie_group_remove(eq) + x*y**2 + >>> eq = (F(x**3) + y)*x**4 + >>> _lie_group_remove(eq) + x**4*y + + """ + if isinstance(coords, AppliedUndef): + return coords.args[0] + elif coords.is_Add: + subfunc = coords.atoms(AppliedUndef) + if subfunc: + for func in subfunc: + coords = coords.subs(func, 0) + return coords + elif coords.is_Pow: + base, expr = coords.as_base_exp() + base = _lie_group_remove(base) + expr = _lie_group_remove(expr) + return base**expr + elif coords.is_Mul: + mulargs = [] + coordargs = coords.args + for arg in coordargs: + if not isinstance(coords, AppliedUndef): + mulargs.append(_lie_group_remove(arg)) + return Mul(*mulargs) + return coords diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/nonhomogeneous.py b/phivenv/Lib/site-packages/sympy/solvers/ode/nonhomogeneous.py new file mode 100644 index 0000000000000000000000000000000000000000..ae39d55664e4850168ca7d68f65cf02171979957 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/nonhomogeneous.py @@ -0,0 +1,484 @@ +r""" +This File contains helper functions for nth_linear_constant_coeff_undetermined_coefficients, +nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients, +nth_linear_constant_coeff_variation_of_parameters, +and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters. + +All the functions in this file are used by more than one solvers so, instead of creating +instances in other classes for using them it is better to keep it here as separate helpers. + +""" +from collections import Counter +from sympy.core import Add, S +from sympy.core.function import diff, expand, _mexpand, expand_mul +from sympy.core.relational import Eq +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy, Wild +from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \ + atan2, conjugate +from sympy.integrals import Integral +from sympy.polys import (Poly, RootOf, rootof, roots) +from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore +from sympy.utilities import numbered_symbols +from sympy.solvers.solvers import solve +from sympy.matrices import wronskian +from .subscheck import sub_func_doit +from sympy.solvers.ode.ode import get_numbered_constants + + +def _test_term(coeff, func, order): + r""" + Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x), + where K is independent of x and y(x), order>= 0. + So we need to check that for each term, coeff == K*x**order from + some K. We have a few cases, since coeff may have several + different types. + """ + x = func.args[0] + f = func.func + if order < 0: + raise ValueError("order should be greater than 0") + if coeff == 0: + return True + if order == 0: + if x in coeff.free_symbols: + return False + return True + if coeff.is_Mul: + if coeff.has(f(x)): + return False + return x**order in coeff.args + elif coeff.is_Pow: + return coeff.as_base_exp() == (x, order) + elif order == 1: + return x == coeff + return False + + +def _get_euler_characteristic_eq_sols(eq, func, match_obj): + r""" + Returns the solution of homogeneous part of the linear euler ODE and + the list of roots of characteristic equation. + + The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order + of the derivative on each term, and coeff is the coefficient of that derivative. + + """ + x = func.args[0] + f = func.func + + # First, set up characteristic equation. + chareq, symbol = S.Zero, Dummy('x') + + for i in match_obj: + if i >= 0: + chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand() + + chareq = Poly(chareq, symbol) + chareqroots = [rootof(chareq, k) for k in range(chareq.degree())] + collectterms = [] + + # A generator of constants + constants = list(get_numbered_constants(eq, num=chareq.degree()*2)) + constants.reverse() + + # Create a dict root: multiplicity or charroots + charroots = Counter(chareqroots) + gsol = S.Zero + ln = log + for root, multiplicity in charroots.items(): + for i in range(multiplicity): + if isinstance(root, RootOf): + gsol += (x**root) * constants.pop() + if multiplicity != 1: + raise ValueError("Value should be 1") + collectterms = [(0, root, 0)] + collectterms + elif root.is_real: + gsol += ln(x)**i*(x**root) * constants.pop() + collectterms = [(i, root, 0)] + collectterms + else: + reroot = re(root) + imroot = im(root) + gsol += ln(x)**i * (x**reroot) * ( + constants.pop() * sin(abs(imroot)*ln(x)) + + constants.pop() * cos(imroot*ln(x))) + collectterms = [(i, reroot, imroot)] + collectterms + + gsol = Eq(f(x), gsol) + + gensols = [] + # Keep track of when to use sin or cos for nonzero imroot + for i, reroot, imroot in collectterms: + if imroot == 0: + gensols.append(ln(x)**i*x**reroot) + else: + sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x)) + if sin_form in gensols: + cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x)) + gensols.append(cos_form) + else: + gensols.append(sin_form) + return gsol, gensols + + +def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True): + r""" + Helper function for the method of variation of parameters and nonhomogeneous euler eq. + + See the + :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters` + docstring for more information on this method. + + The parameter are ``match_obj`` should be a dictionary that has the following + keys: + + ``list`` + A list of solutions to the homogeneous equation. + + ``sol`` + The general solution. + + """ + f = func.func + x = func.args[0] + r = match_obj + psol = 0 + wr = wronskian(roots, x) + + if simplify_flag: + wr = simplify(wr) # We need much better simplification for + # some ODEs. See issue 4662, for example. + # To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1 + wr = trigsimp(wr, deep=True, recursive=True) + if not wr: + # The wronskian will be 0 iff the solutions are not linearly + # independent. + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply " + + "variation of parameters to " + str(eq) + " (Wronskian == 0)") + if len(roots) != order: + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply " + + "variation of parameters to " + + str(eq) + " (number of terms != order)") + negoneterm = S.NegativeOne**(order) + for i in roots: + psol += negoneterm*Integral(wronskian([sol for sol in roots if sol != i], x)*r[-1]/wr, x)*i/r[order] + negoneterm *= -1 + + if simplify_flag: + psol = simplify(psol) + psol = trigsimp(psol, deep=True) + return Eq(f(x), homogen_sol.rhs + psol) + + +def _get_const_characteristic_eq_sols(r, func, order): + r""" + Returns the roots of characteristic equation of constant coefficient + linear ODE and list of collectterms which is later on used by simplification + to use collect on solution. + + The parameter `r` is a dict of order:coeff terms, where order is the order of the + derivative on each term, and coeff is the coefficient of that derivative. + + """ + x = func.args[0] + # First, set up characteristic equation. + chareq, symbol = S.Zero, Dummy('x') + + for i in r.keys(): + if isinstance(i, str) or i < 0: + pass + else: + chareq += r[i]*symbol**i + + chareq = Poly(chareq, symbol) + # Can't just call roots because it doesn't return rootof for unsolveable + # polynomials. + chareqroots = roots(chareq, multiple=True) + if len(chareqroots) != order: + chareqroots = [rootof(chareq, k) for k in range(chareq.degree())] + + chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs()) + + # Create a dict root: multiplicity or charroots + charroots = Counter(chareqroots) + # We need to keep track of terms so we can run collect() at the end. + # This is necessary for constantsimp to work properly. + collectterms = [] + gensols = [] + conjugate_roots = [] # used to prevent double-use of conjugate roots + # Loop over roots in theorder provided by roots/rootof... + for root in chareqroots: + # but don't repoeat multiple roots. + if root not in charroots: + continue + multiplicity = charroots.pop(root) + for i in range(multiplicity): + if chareq_is_complex: + gensols.append(x**i*exp(root*x)) + collectterms = [(i, root, 0)] + collectterms + continue + reroot = re(root) + imroot = im(root) + if imroot.has(atan2) and reroot.has(atan2): + # Remove this condition when re and im stop returning + # circular atan2 usages. + gensols.append(x**i*exp(root*x)) + collectterms = [(i, root, 0)] + collectterms + else: + if root in conjugate_roots: + collectterms = [(i, reroot, imroot)] + collectterms + continue + if imroot == 0: + gensols.append(x**i*exp(reroot*x)) + collectterms = [(i, reroot, 0)] + collectterms + continue + conjugate_roots.append(conjugate(root)) + gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x)) + gensols.append(x**i*exp(reroot*x) * cos( imroot * x)) + + # This ordering is important + collectterms = [(i, reroot, imroot)] + collectterms + return gensols, collectterms + + +# Ideally these kind of simplification functions shouldn't be part of solvers. +# odesimp should be improved to handle these kind of specific simplifications. +def _get_simplified_sol(sol, func, collectterms): + r""" + Helper function which collects the solution on + collectterms. Ideally this should be handled by odesimp.It is used + only when the simplify is set to True in dsolve. + + The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is + the multiplicity of the root, reroot is real part and imroot being the imaginary part. + + """ + f = func.func + x = func.args[0] + collectterms.sort(key=default_sort_key) + collectterms.reverse() + assert len(sol) == 1 and sol[0].lhs == f(x) + sol = sol[0].rhs + sol = expand_mul(sol) + for i, reroot, imroot in collectterms: + sol = collect(sol, x**i*exp(reroot*x)*sin(abs(imroot)*x)) + sol = collect(sol, x**i*exp(reroot*x)*cos(imroot*x)) + for i, reroot, imroot in collectterms: + sol = collect(sol, x**i*exp(reroot*x)) + sol = powsimp(sol) + return Eq(f(x), sol) + + +def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero): + r""" + Returns a trial function match if undetermined coefficients can be applied + to ``expr``, and ``None`` otherwise. + + A trial expression can be found for an expression for use with the method + of undetermined coefficients if the expression is an + additive/multiplicative combination of constants, polynomials in `x` (the + independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and + `e^{a x}` terms (in other words, it has a finite number of linearly + independent derivatives). + + Note that you may still need to multiply each term returned here by + sufficient `x` to make it linearly independent with the solutions to the + homogeneous equation. + + This is intended for internal use by ``undetermined_coefficients`` hints. + + SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of + only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So, + for example, you will need to manually convert `\sin^2(x)` into `[1 + + \cos(2 x)]/2` to properly apply the method of undetermined coefficients on + it. + + Examples + ======== + + >>> from sympy import log, exp + >>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match + >>> from sympy.abc import x + >>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x) + {'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}} + >>> _undetermined_coefficients_match(log(x), x) + {'test': False} + + """ + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x]) + expr = powsimp(expr, combine='exp') # exp(x)*exp(2*x + 1) => exp(3*x + 1) + retdict = {} + + def _test_term(expr, x) -> bool: + r""" + Test if ``expr`` fits the proper form for undetermined coefficients. + """ + if not expr.has(x): + return True + if expr.is_Add: + return all(_test_term(i, x) for i in expr.args) + if expr.is_Mul: + if expr.has(sin, cos): + foundtrig = False + # Make sure that there is only one trig function in the args. + # See the docstring. + for i in expr.args: + if i.has(sin, cos): + if foundtrig: + return False + else: + foundtrig = True + return all(_test_term(i, x) for i in expr.args) + if expr.is_Function: + return expr.func in (sin, cos, exp, sinh, cosh) and \ + bool(expr.args[0].match(a*x + b)) + if expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \ + expr.exp >= 0: + return True + if expr.is_Pow and expr.base.is_number: + return bool(expr.exp.match(a*x + b)) + return expr.is_Symbol or bool(expr.is_number) + + def _get_trial_set(expr, x, exprs=set()): + r""" + Returns a set of trial terms for undetermined coefficients. + + The idea behind undetermined coefficients is that the terms expression + repeat themselves after a finite number of derivatives, except for the + coefficients (they are linearly dependent). So if we collect these, + we should have the terms of our trial function. + """ + def _remove_coefficient(expr, x): + r""" + Returns the expression without a coefficient. + + Similar to expr.as_independent(x)[1], except it only works + multiplicatively. + """ + term = S.One + if expr.is_Mul: + for i in expr.args: + if i.has(x): + term *= i + elif expr.has(x): + term = expr + return term + + expr = expand_mul(expr) + if expr.is_Add: + for term in expr.args: + if _remove_coefficient(term, x) in exprs: + pass + else: + exprs.add(_remove_coefficient(term, x)) + exprs = exprs.union(_get_trial_set(term, x, exprs)) + else: + term = _remove_coefficient(expr, x) + tmpset = exprs.union({term}) + oldset = set() + while tmpset != oldset: + # If you get stuck in this loop, then _test_term is probably + # broken + oldset = tmpset.copy() + expr = expr.diff(x) + term = _remove_coefficient(expr, x) + if term.is_Add: + tmpset = tmpset.union(_get_trial_set(term, x, tmpset)) + else: + tmpset.add(term) + exprs = tmpset + return exprs + + def is_homogeneous_solution(term): + r""" This function checks whether the given trialset contains any root + of homogeneous equation""" + return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero + + retdict['test'] = _test_term(expr, x) + if retdict['test']: + # Try to generate a list of trial solutions that will have the + # undetermined coefficients. Note that if any of these are not linearly + # independent with any of the solutions to the homogeneous equation, + # then they will need to be multiplied by sufficient x to make them so. + # This function DOES NOT do that (it doesn't even look at the + # homogeneous equation). + temp_set = set() + for i in Add.make_args(expr): + act = _get_trial_set(i, x) + if eq_homogeneous is not S.Zero: + while any(is_homogeneous_solution(ts) for ts in act): + act = {x*ts for ts in act} + temp_set = temp_set.union(act) + + retdict['trialset'] = temp_set + return retdict + + +def _solve_undetermined_coefficients(eq, func, order, match, trialset): + r""" + Helper function for the method of undetermined coefficients. + + See the + :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients` + docstring for more information on this method. + + The parameter ``trialset`` is the set of trial functions as returned by + ``_undetermined_coefficients_match()['trialset']``. + + The parameter ``match`` should be a dictionary that has the following + keys: + + ``list`` + A list of solutions to the homogeneous equation. + + ``sol`` + The general solution. + + """ + r = match + coeffs = numbered_symbols('a', cls=Dummy) + coefflist = [] + gensols = r['list'] + gsol = r['sol'] + f = func.func + x = func.args[0] + + if len(gensols) != order: + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply" + + " undetermined coefficients to " + str(eq) + + " (number of terms != order)") + + trialfunc = 0 + for i in trialset: + c = next(coeffs) + coefflist.append(c) + trialfunc += c*i + + eqs = sub_func_doit(eq, f(x), trialfunc) + + coeffsdict = dict(list(zip(trialset, [0]*(len(trialset) + 1)))) + + eqs = _mexpand(eqs) + + for i in Add.make_args(eqs): + s = separatevars(i, dict=True, symbols=[x]) + if coeffsdict.get(s[x]): + coeffsdict[s[x]] += s['coeff'] + else: + coeffsdict[s[x]] = s['coeff'] + + coeffvals = solve(list(coeffsdict.values()), coefflist) + + if not coeffvals: + raise NotImplementedError( + "Could not solve `%s` using the " + "method of undetermined coefficients " + "(unable to solve for coefficients)." % eq) + + psol = trialfunc.subs(coeffvals) + + return Eq(f(x), gsol.rhs + psol) diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/ode.py b/phivenv/Lib/site-packages/sympy/solvers/ode/ode.py new file mode 100644 index 0000000000000000000000000000000000000000..6a28b2162b38a2dd8612c14e91fd588912f6756a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/ode.py @@ -0,0 +1,3572 @@ +r""" +This module contains :py:meth:`~sympy.solvers.ode.dsolve` and different helper +functions that it uses. + +:py:meth:`~sympy.solvers.ode.dsolve` solves ordinary differential equations. +See the docstring on the various functions for their uses. Note that partial +differential equations support is in ``pde.py``. Note that hint functions +have docstrings describing their various methods, but they are intended for +internal use. Use ``dsolve(ode, func, hint=hint)`` to solve an ODE using a +specific hint. See also the docstring on +:py:meth:`~sympy.solvers.ode.dsolve`. + +**Functions in this module** + + These are the user functions in this module: + + - :py:meth:`~sympy.solvers.ode.dsolve` - Solves ODEs. + - :py:meth:`~sympy.solvers.ode.classify_ode` - Classifies ODEs into + possible hints for :py:meth:`~sympy.solvers.ode.dsolve`. + - :py:meth:`~sympy.solvers.ode.checkodesol` - Checks if an equation is the + solution to an ODE. + - :py:meth:`~sympy.solvers.ode.homogeneous_order` - Returns the + homogeneous order of an expression. + - :py:meth:`~sympy.solvers.ode.infinitesimals` - Returns the infinitesimals + of the Lie group of point transformations of an ODE, such that it is + invariant. + - :py:meth:`~sympy.solvers.ode.checkinfsol` - Checks if the given infinitesimals + are the actual infinitesimals of a first order ODE. + + These are the non-solver helper functions that are for internal use. The + user should use the various options to + :py:meth:`~sympy.solvers.ode.dsolve` to obtain the functionality provided + by these functions: + + - :py:meth:`~sympy.solvers.ode.ode.odesimp` - Does all forms of ODE + simplification. + - :py:meth:`~sympy.solvers.ode.ode.ode_sol_simplicity` - A key function for + comparing solutions by simplicity. + - :py:meth:`~sympy.solvers.ode.constantsimp` - Simplifies arbitrary + constants. + - :py:meth:`~sympy.solvers.ode.ode.constant_renumber` - Renumber arbitrary + constants. + - :py:meth:`~sympy.solvers.ode.ode._handle_Integral` - Evaluate unevaluated + Integrals. + + See also the docstrings of these functions. + +**Currently implemented solver methods** + +The following methods are implemented for solving ordinary differential +equations. See the docstrings of the various hint functions for more +information on each (run ``help(ode)``): + + - 1st order separable differential equations. + - 1st order differential equations whose coefficients or `dx` and `dy` are + functions homogeneous of the same order. + - 1st order exact differential equations. + - 1st order linear differential equations. + - 1st order Bernoulli differential equations. + - Power series solutions for first order differential equations. + - Lie Group method of solving first order differential equations. + - 2nd order Liouville differential equations. + - Power series solutions for second order differential equations + at ordinary and regular singular points. + - `n`\th order differential equation that can be solved with algebraic + rearrangement and integration. + - `n`\th order linear homogeneous differential equation with constant + coefficients. + - `n`\th order linear inhomogeneous differential equation with constant + coefficients using the method of undetermined coefficients. + - `n`\th order linear inhomogeneous differential equation with constant + coefficients using the method of variation of parameters. + +**Philosophy behind this module** + +This module is designed to make it easy to add new ODE solving methods without +having to mess with the solving code for other methods. The idea is that +there is a :py:meth:`~sympy.solvers.ode.classify_ode` function, which takes in +an ODE and tells you what hints, if any, will solve the ODE. It does this +without attempting to solve the ODE, so it is fast. Each solving method is a +hint, and it has its own function, named ``ode_``. That function takes +in the ODE and any match expression gathered by +:py:meth:`~sympy.solvers.ode.classify_ode` and returns a solved result. If +this result has any integrals in it, the hint function will return an +unevaluated :py:class:`~sympy.integrals.integrals.Integral` class. +:py:meth:`~sympy.solvers.ode.dsolve`, which is the user wrapper function +around all of this, will then call :py:meth:`~sympy.solvers.ode.ode.odesimp` on +the result, which, among other things, will attempt to solve the equation for +the dependent variable (the function we are solving for), simplify the +arbitrary constants in the expression, and evaluate any integrals, if the hint +allows it. + +**How to add new solution methods** + +If you have an ODE that you want :py:meth:`~sympy.solvers.ode.dsolve` to be +able to solve, try to avoid adding special case code here. Instead, try +finding a general method that will solve your ODE, as well as others. This +way, the :py:mod:`~sympy.solvers.ode` module will become more robust, and +unhindered by special case hacks. WolphramAlpha and Maple's +DETools[odeadvisor] function are two resources you can use to classify a +specific ODE. It is also better for a method to work with an `n`\th order ODE +instead of only with specific orders, if possible. + +To add a new method, there are a few things that you need to do. First, you +need a hint name for your method. Try to name your hint so that it is +unambiguous with all other methods, including ones that may not be implemented +yet. If your method uses integrals, also include a ``hint_Integral`` hint. +If there is more than one way to solve ODEs with your method, include a hint +for each one, as well as a ``_best`` hint. Your ``ode__best()`` +function should choose the best using min with ``ode_sol_simplicity`` as the +key argument. See +:obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest`, for example. +The function that uses your method will be called ``ode_()``, so the +hint must only use characters that are allowed in a Python function name +(alphanumeric characters and the underscore '``_``' character). Include a +function for every hint, except for ``_Integral`` hints +(:py:meth:`~sympy.solvers.ode.dsolve` takes care of those automatically). +Hint names should be all lowercase, unless a word is commonly capitalized +(such as Integral or Bernoulli). If you have a hint that you do not want to +run with ``all_Integral`` that does not have an ``_Integral`` counterpart (such +as a best hint that would defeat the purpose of ``all_Integral``), you will +need to remove it manually in the :py:meth:`~sympy.solvers.ode.dsolve` code. +See also the :py:meth:`~sympy.solvers.ode.classify_ode` docstring for +guidelines on writing a hint name. + +Determine *in general* how the solutions returned by your method compare with +other methods that can potentially solve the same ODEs. Then, put your hints +in the :py:data:`~sympy.solvers.ode.allhints` tuple in the order that they +should be called. The ordering of this tuple determines which hints are +default. Note that exceptions are ok, because it is easy for the user to +choose individual hints with :py:meth:`~sympy.solvers.ode.dsolve`. In +general, ``_Integral`` variants should go at the end of the list, and +``_best`` variants should go before the various hints they apply to. For +example, the ``undetermined_coefficients`` hint comes before the +``variation_of_parameters`` hint because, even though variation of parameters +is more general than undetermined coefficients, undetermined coefficients +generally returns cleaner results for the ODEs that it can solve than +variation of parameters does, and it does not require integration, so it is +much faster. + +Next, you need to have a match expression or a function that matches the type +of the ODE, which you should put in :py:meth:`~sympy.solvers.ode.classify_ode` +(if the match function is more than just a few lines. It should match the +ODE without solving for it as much as possible, so that +:py:meth:`~sympy.solvers.ode.classify_ode` remains fast and is not hindered by +bugs in solving code. Be sure to consider corner cases. For example, if your +solution method involves dividing by something, make sure you exclude the case +where that division will be 0. + +In most cases, the matching of the ODE will also give you the various parts +that you need to solve it. You should put that in a dictionary (``.match()`` +will do this for you), and add that as ``matching_hints['hint'] = matchdict`` +in the relevant part of :py:meth:`~sympy.solvers.ode.classify_ode`. +:py:meth:`~sympy.solvers.ode.classify_ode` will then send this to +:py:meth:`~sympy.solvers.ode.dsolve`, which will send it to your function as +the ``match`` argument. Your function should be named ``ode_(eq, func, +order, match)`. If you need to send more information, put it in the ``match`` +dictionary. For example, if you had to substitute in a dummy variable in +:py:meth:`~sympy.solvers.ode.classify_ode` to match the ODE, you will need to +pass it to your function using the `match` dict to access it. You can access +the independent variable using ``func.args[0]``, and the dependent variable +(the function you are trying to solve for) as ``func.func``. If, while trying +to solve the ODE, you find that you cannot, raise ``NotImplementedError``. +:py:meth:`~sympy.solvers.ode.dsolve` will catch this error with the ``all`` +meta-hint, rather than causing the whole routine to fail. + +Add a docstring to your function that describes the method employed. Like +with anything else in SymPy, you will need to add a doctest to the docstring, +in addition to real tests in ``test_ode.py``. Try to maintain consistency +with the other hint functions' docstrings. Add your method to the list at the +top of this docstring. Also, add your method to ``ode.rst`` in the +``docs/src`` directory, so that the Sphinx docs will pull its docstring into +the main SymPy documentation. Be sure to make the Sphinx documentation by +running ``make html`` from within the doc directory to verify that the +docstring formats correctly. + +If your solution method involves integrating, use :py:obj:`~.Integral` instead of +:py:meth:`~sympy.core.expr.Expr.integrate`. This allows the user to bypass +hard/slow integration by using the ``_Integral`` variant of your hint. In +most cases, calling :py:meth:`sympy.core.basic.Basic.doit` will integrate your +solution. If this is not the case, you will need to write special code in +:py:meth:`~sympy.solvers.ode.ode._handle_Integral`. Arbitrary constants should be +symbols named ``C1``, ``C2``, and so on. All solution methods should return +an equality instance. If you need an arbitrary number of arbitrary constants, +you can use ``constants = numbered_symbols(prefix='C', cls=Symbol, start=1)``. +If it is possible to solve for the dependent function in a general way, do so. +Otherwise, do as best as you can, but do not call solve in your +``ode_()`` function. :py:meth:`~sympy.solvers.ode.ode.odesimp` will attempt +to solve the solution for you, so you do not need to do that. Lastly, if your +ODE has a common simplification that can be applied to your solutions, you can +add a special case in :py:meth:`~sympy.solvers.ode.ode.odesimp` for it. For +example, solutions returned from the ``1st_homogeneous_coeff`` hints often +have many :obj:`~sympy.functions.elementary.exponential.log` terms, so +:py:meth:`~sympy.solvers.ode.ode.odesimp` calls +:py:meth:`~sympy.simplify.simplify.logcombine` on them (it also helps to write +the arbitrary constant as ``log(C1)`` instead of ``C1`` in this case). Also +consider common ways that you can rearrange your solution to have +:py:meth:`~sympy.solvers.ode.constantsimp` take better advantage of it. It is +better to put simplification in :py:meth:`~sympy.solvers.ode.ode.odesimp` than in +your method, because it can then be turned off with the simplify flag in +:py:meth:`~sympy.solvers.ode.dsolve`. If you have any extraneous +simplification in your function, be sure to only run it using ``if +match.get('simplify', True):``, especially if it can be slow or if it can +reduce the domain of the solution. + +Finally, as with every contribution to SymPy, your method will need to be +tested. Add a test for each method in ``test_ode.py``. Follow the +conventions there, i.e., test the solver using ``dsolve(eq, f(x), +hint=your_hint)``, and also test the solution using +:py:meth:`~sympy.solvers.ode.checkodesol` (you can put these in a separate +tests and skip/XFAIL if it runs too slow/does not work). Be sure to call your +hint specifically in :py:meth:`~sympy.solvers.ode.dsolve`, that way the test +will not be broken simply by the introduction of another matching hint. If your +method works for higher order (>1) ODEs, you will need to run ``sol = +constant_renumber(sol, 'C', 1, order)`` for each solution, where ``order`` is +the order of the ODE. This is because ``constant_renumber`` renumbers the +arbitrary constants by printing order, which is platform dependent. Try to +test every corner case of your solver, including a range of orders if it is a +`n`\th order solver, but if your solver is slow, such as if it involves hard +integration, try to keep the test run time down. + +Feel free to refactor existing hints to avoid duplicating code or creating +inconsistencies. If you can show that your method exactly duplicates an +existing method, including in the simplicity and speed of obtaining the +solutions, then you can remove the old, less general method. The existing +code is tested extensively in ``test_ode.py``, so if anything is broken, one +of those tests will surely fail. + +""" + +from sympy.core import Add, S, Mul, Pow, oo +from sympy.core.containers import Tuple +from sympy.core.expr import AtomicExpr, Expr +from sympy.core.function import (Function, Derivative, AppliedUndef, diff, + expand, expand_mul, Subs) +from sympy.core.multidimensional import vectorize +from sympy.core.numbers import nan, zoo, Number +from sympy.core.relational import Equality, Eq +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, Wild, Dummy, symbols +from sympy.core.sympify import sympify +from sympy.core.traversal import preorder_traversal + +from sympy.logic.boolalg import (BooleanAtom, BooleanTrue, + BooleanFalse) +from sympy.functions import exp, log, sqrt +from sympy.functions.combinatorial.factorials import factorial +from sympy.integrals.integrals import Integral +from sympy.polys import (Poly, terms_gcd, PolynomialError, lcm) +from sympy.polys.polytools import cancel +from sympy.series import Order +from sympy.series.series import series +from sympy.simplify import (collect, logcombine, powsimp, # type: ignore + separatevars, simplify, cse) +from sympy.simplify.radsimp import collect_const +from sympy.solvers import checksol, solve + +from sympy.utilities import numbered_symbols +from sympy.utilities.iterables import uniq, sift, iterable +from sympy.solvers.deutils import _preprocess, ode_order, _desolve + + +#: This is a list of hints in the order that they should be preferred by +#: :py:meth:`~sympy.solvers.ode.classify_ode`. In general, hints earlier in the +#: list should produce simpler solutions than those later in the list (for +#: ODEs that fit both). For now, the order of this list is based on empirical +#: observations by the developers of SymPy. +#: +#: The hint used by :py:meth:`~sympy.solvers.ode.dsolve` for a specific ODE +#: can be overridden (see the docstring). +#: +#: In general, ``_Integral`` hints are grouped at the end of the list, unless +#: there is a method that returns an unevaluable integral most of the time +#: (which go near the end of the list anyway). ``default``, ``all``, +#: ``best``, and ``all_Integral`` meta-hints should not be included in this +#: list, but ``_best`` and ``_Integral`` hints should be included. +allhints = ( + "factorable", + "nth_algebraic", + "separable", + "1st_exact", + "1st_linear", + "Bernoulli", + "1st_rational_riccati", + "Riccati_special_minus2", + "1st_homogeneous_coeff_best", + "1st_homogeneous_coeff_subs_indep_div_dep", + "1st_homogeneous_coeff_subs_dep_div_indep", + "almost_linear", + "linear_coefficients", + "separable_reduced", + "1st_power_series", + "lie_group", + "nth_linear_constant_coeff_homogeneous", + "nth_linear_euler_eq_homogeneous", + "nth_linear_constant_coeff_undetermined_coefficients", + "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients", + "nth_linear_constant_coeff_variation_of_parameters", + "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters", + "Liouville", + "2nd_linear_airy", + "2nd_linear_bessel", + "2nd_hypergeometric", + "2nd_hypergeometric_Integral", + "nth_order_reducible", + "2nd_power_series_ordinary", + "2nd_power_series_regular", + "nth_algebraic_Integral", + "separable_Integral", + "1st_exact_Integral", + "1st_linear_Integral", + "Bernoulli_Integral", + "1st_homogeneous_coeff_subs_indep_div_dep_Integral", + "1st_homogeneous_coeff_subs_dep_div_indep_Integral", + "almost_linear_Integral", + "linear_coefficients_Integral", + "separable_reduced_Integral", + "nth_linear_constant_coeff_variation_of_parameters_Integral", + "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral", + "Liouville_Integral", + "2nd_nonlinear_autonomous_conserved", + "2nd_nonlinear_autonomous_conserved_Integral", + ) + + + +def get_numbered_constants(eq, num=1, start=1, prefix='C'): + """ + Returns a list of constants that do not occur + in eq already. + """ + + ncs = iter_numbered_constants(eq, start, prefix) + Cs = [next(ncs) for i in range(num)] + return (Cs[0] if num == 1 else tuple(Cs)) + + +def iter_numbered_constants(eq, start=1, prefix='C'): + """ + Returns an iterator of constants that do not occur + in eq already. + """ + + if isinstance(eq, (Expr, Eq)): + eq = [eq] + elif not iterable(eq): + raise ValueError("Expected Expr or iterable but got %s" % eq) + + atom_set = set().union(*[i.free_symbols for i in eq]) + func_set = set().union(*[i.atoms(Function) for i in eq]) + if func_set: + atom_set |= {Symbol(str(f.func)) for f in func_set} + return numbered_symbols(start=start, prefix=prefix, exclude=atom_set) + + +def dsolve(eq, func=None, hint="default", simplify=True, + ics= None, xi=None, eta=None, x0=0, n=6, **kwargs): + r""" + Solves any (supported) kind of ordinary differential equation and + system of ordinary differential equations. + + For single ordinary differential equation + ========================================= + + It is classified under this when number of equation in ``eq`` is one. + **Usage** + + ``dsolve(eq, f(x), hint)`` -> Solve ordinary differential equation + ``eq`` for function ``f(x)``, using method ``hint``. + + **Details** + + ``eq`` can be any supported ordinary differential equation (see the + :py:mod:`~sympy.solvers.ode` docstring for supported methods). + This can either be an :py:class:`~sympy.core.relational.Equality`, + or an expression, which is assumed to be equal to ``0``. + + ``f(x)`` is a function of one variable whose derivatives in that + variable make up the ordinary differential equation ``eq``. In + many cases it is not necessary to provide this; it will be + autodetected (and an error raised if it could not be detected). + + ``hint`` is the solving method that you want dsolve to use. Use + ``classify_ode(eq, f(x))`` to get all of the possible hints for an + ODE. The default hint, ``default``, will use whatever hint is + returned first by :py:meth:`~sympy.solvers.ode.classify_ode`. See + Hints below for more options that you can use for hint. + + ``simplify`` enables simplification by + :py:meth:`~sympy.solvers.ode.ode.odesimp`. See its docstring for more + information. Turn this off, for example, to disable solving of + solutions for ``func`` or simplification of arbitrary constants. + It will still integrate with this hint. Note that the solution may + contain more arbitrary constants than the order of the ODE with + this option enabled. + + ``xi`` and ``eta`` are the infinitesimal functions of an ordinary + differential equation. They are the infinitesimals of the Lie group + of point transformations for which the differential equation is + invariant. The user can specify values for the infinitesimals. If + nothing is specified, ``xi`` and ``eta`` are calculated using + :py:meth:`~sympy.solvers.ode.infinitesimals` with the help of various + heuristics. + + ``ics`` is the set of initial/boundary conditions for the differential equation. + It should be given in the form of ``{f(x0): x1, f(x).diff(x).subs(x, x2): + x3}`` and so on. For power series solutions, if no initial + conditions are specified ``f(0)`` is assumed to be ``C0`` and the power + series solution is calculated about 0. + + ``x0`` is the point about which the power series solution of a differential + equation is to be evaluated. + + ``n`` gives the exponent of the dependent variable up to which the power series + solution of a differential equation is to be evaluated. + + **Hints** + + Aside from the various solving methods, there are also some meta-hints + that you can pass to :py:meth:`~sympy.solvers.ode.dsolve`: + + ``default``: + This uses whatever hint is returned first by + :py:meth:`~sympy.solvers.ode.classify_ode`. This is the + default argument to :py:meth:`~sympy.solvers.ode.dsolve`. + + ``all``: + To make :py:meth:`~sympy.solvers.ode.dsolve` apply all + relevant classification hints, use ``dsolve(ODE, func, + hint="all")``. This will return a dictionary of + ``hint:solution`` terms. If a hint causes dsolve to raise the + ``NotImplementedError``, value of that hint's key will be the + exception object raised. The dictionary will also include + some special keys: + + - ``order``: The order of the ODE. See also + :py:meth:`~sympy.solvers.deutils.ode_order` in + ``deutils.py``. + - ``best``: The simplest hint; what would be returned by + ``best`` below. + - ``best_hint``: The hint that would produce the solution + given by ``best``. If more than one hint produces the best + solution, the first one in the tuple returned by + :py:meth:`~sympy.solvers.ode.classify_ode` is chosen. + - ``default``: The solution that would be returned by default. + This is the one produced by the hint that appears first in + the tuple returned by + :py:meth:`~sympy.solvers.ode.classify_ode`. + + ``all_Integral``: + This is the same as ``all``, except if a hint also has a + corresponding ``_Integral`` hint, it only returns the + ``_Integral`` hint. This is useful if ``all`` causes + :py:meth:`~sympy.solvers.ode.dsolve` to hang because of a + difficult or impossible integral. This meta-hint will also be + much faster than ``all``, because + :py:meth:`~sympy.core.expr.Expr.integrate` is an expensive + routine. + + ``best``: + To have :py:meth:`~sympy.solvers.ode.dsolve` try all methods + and return the simplest one. This takes into account whether + the solution is solvable in the function, whether it contains + any Integral classes (i.e. unevaluatable integrals), and + which one is the shortest in size. + + See also the :py:meth:`~sympy.solvers.ode.classify_ode` docstring for + more info on hints, and the :py:mod:`~sympy.solvers.ode` docstring for + a list of all supported hints. + + **Tips** + + - You can declare the derivative of an unknown function this way: + + >>> from sympy import Function, Derivative + >>> from sympy.abc import x # x is the independent variable + >>> f = Function("f")(x) # f is a function of x + >>> # f_ will be the derivative of f with respect to x + >>> f_ = Derivative(f, x) + + - See ``test_ode.py`` for many tests, which serves also as a set of + examples for how to use :py:meth:`~sympy.solvers.ode.dsolve`. + - :py:meth:`~sympy.solvers.ode.dsolve` always returns an + :py:class:`~sympy.core.relational.Equality` class (except for the + case when the hint is ``all`` or ``all_Integral``). If possible, it + solves the solution explicitly for the function being solved for. + Otherwise, it returns an implicit solution. + - Arbitrary constants are symbols named ``C1``, ``C2``, and so on. + - Because all solutions should be mathematically equivalent, some + hints may return the exact same result for an ODE. Often, though, + two different hints will return the same solution formatted + differently. The two should be equivalent. Also note that sometimes + the values of the arbitrary constants in two different solutions may + not be the same, because one constant may have "absorbed" other + constants into it. + - Do ``help(ode.ode_)`` to get help more information on a + specific hint, where ```` is the name of a hint without + ``_Integral``. + + For system of ordinary differential equations + ============================================= + + **Usage** + ``dsolve(eq, func)`` -> Solve a system of ordinary differential + equations ``eq`` for ``func`` being list of functions including + `x(t)`, `y(t)`, `z(t)` where number of functions in the list depends + upon the number of equations provided in ``eq``. + + **Details** + + ``eq`` can be any supported system of ordinary differential equations + This can either be an :py:class:`~sympy.core.relational.Equality`, + or an expression, which is assumed to be equal to ``0``. + + ``func`` holds ``x(t)`` and ``y(t)`` being functions of one variable which + together with some of their derivatives make up the system of ordinary + differential equation ``eq``. It is not necessary to provide this; it + will be autodetected (and an error raised if it could not be detected). + + **Hints** + + The hints are formed by parameters returned by classify_sysode, combining + them give hints name used later for forming method name. + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, Derivative, sin, cos, symbols + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(Derivative(f(x), x, x) + 9*f(x), f(x)) + Eq(f(x), C1*sin(3*x) + C2*cos(3*x)) + + >>> eq = sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f(x).diff(x) + >>> dsolve(eq, hint='1st_exact') + [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))] + >>> dsolve(eq, hint='almost_linear') + [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))] + >>> t = symbols('t') + >>> x, y = symbols('x, y', cls=Function) + >>> eq = (Eq(Derivative(x(t),t), 12*t*x(t) + 8*y(t)), Eq(Derivative(y(t),t), 21*x(t) + 7*t*y(t))) + >>> dsolve(eq) + [Eq(x(t), C1*x0(t) + C2*x0(t)*Integral(8*exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)**2, t)), + Eq(y(t), C1*y0(t) + C2*(y0(t)*Integral(8*exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)**2, t) + + exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)))] + >>> eq = (Eq(Derivative(x(t),t),x(t)*y(t)*sin(t)), Eq(Derivative(y(t),t),y(t)**2*sin(t))) + >>> dsolve(eq) + {Eq(x(t), -exp(C1)/(C2*exp(C1) - cos(t))), Eq(y(t), -1/(C1 - cos(t)))} + """ + if iterable(eq): + from sympy.solvers.ode.systems import dsolve_system + + # This may have to be changed in future + # when we have weakly and strongly + # connected components. This have to + # changed to show the systems that haven't + # been solved. + try: + sol = dsolve_system(eq, funcs=func, ics=ics, doit=True) + return sol[0] if len(sol) == 1 else sol + except NotImplementedError: + pass + + match = classify_sysode(eq, func) + + eq = match['eq'] + order = match['order'] + func = match['func'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + + # keep highest order term coefficient positive + for i in range(len(eq)): + for func_ in func: + if isinstance(func_, list): + pass + else: + if eq[i].coeff(diff(func[i],t,ode_order(eq[i], func[i]))).is_negative: + eq[i] = -eq[i] + match['eq'] = eq + if len(set(order.values()))!=1: + raise ValueError("It solves only those systems of equations whose orders are equal") + match['order'] = list(order.values())[0] + def recur_len(l): + return sum(recur_len(item) if isinstance(item,list) else 1 for item in l) + if recur_len(func) != len(eq): + raise ValueError("dsolve() and classify_sysode() work with " + "number of functions being equal to number of equations") + if match['type_of_equation'] is None: + raise NotImplementedError + else: + if match['is_linear'] == True: + solvefunc = globals()['sysode_linear_%(no_of_equation)seq_order%(order)s' % match] + else: + solvefunc = globals()['sysode_nonlinear_%(no_of_equation)seq_order%(order)s' % match] + sols = solvefunc(match) + if ics: + constants = Tuple(*sols).free_symbols - Tuple(*eq).free_symbols + solved_constants = solve_ics(sols, func, constants, ics) + return [sol.subs(solved_constants) for sol in sols] + return sols + else: + given_hint = hint # hint given by the user + + # See the docstring of _desolve for more details. + hints = _desolve(eq, func=func, + hint=hint, simplify=True, xi=xi, eta=eta, type='ode', ics=ics, + x0=x0, n=n, **kwargs) + eq = hints.pop('eq', eq) + all_ = hints.pop('all', False) + if all_: + retdict = {} + failed_hints = {} + gethints = classify_ode(eq, dict=True, hint='all') + orderedhints = gethints['ordered_hints'] + for hint in hints: + try: + rv = _helper_simplify(eq, hint, hints[hint], simplify) + except NotImplementedError as detail: + failed_hints[hint] = detail + else: + retdict[hint] = rv + func = hints[hint]['func'] + + retdict['best'] = min(list(retdict.values()), key=lambda x: + ode_sol_simplicity(x, func, trysolving=not simplify)) + if given_hint == 'best': + return retdict['best'] + for i in orderedhints: + if retdict['best'] == retdict.get(i, None): + retdict['best_hint'] = i + break + retdict['default'] = gethints['default'] + retdict['order'] = gethints['order'] + retdict.update(failed_hints) + return retdict + + else: + # The key 'hint' stores the hint needed to be solved for. + hint = hints['hint'] + return _helper_simplify(eq, hint, hints, simplify, ics=ics) + + +def _helper_simplify(eq, hint, match, simplify=True, ics=None, **kwargs): + r""" + Helper function of dsolve that calls the respective + :py:mod:`~sympy.solvers.ode` functions to solve for the ordinary + differential equations. This minimizes the computation in calling + :py:meth:`~sympy.solvers.deutils._desolve` multiple times. + """ + r = match + func = r['func'] + order = r['order'] + match = r[hint] + + if isinstance(match, SingleODESolver): + solvefunc = match + else: + solvefunc = globals()['ode_' + hint.removesuffix('_Integral')] + + free = eq.free_symbols + cons = lambda s: s.free_symbols.difference(free) + + if simplify: + # odesimp() will attempt to integrate, if necessary, apply constantsimp(), + # attempt to solve for func, and apply any other hint specific + # simplifications + if isinstance(solvefunc, SingleODESolver): + sols = solvefunc.get_general_solution() + else: + sols = solvefunc(eq, func, order, match) + if iterable(sols): + rv = [] + for s in sols: + simp = odesimp(eq, s, func, hint) + if iterable(simp): + rv.extend(simp) + else: + rv.append(simp) + else: + rv = odesimp(eq, sols, func, hint) + else: + # We still want to integrate (you can disable it separately with the hint) + if isinstance(solvefunc, SingleODESolver): + exprs = solvefunc.get_general_solution(simplify=False) + else: + match['simplify'] = False # Some hints can take advantage of this option + exprs = solvefunc(eq, func, order, match) + if isinstance(exprs, list): + rv = [_handle_Integral(expr, func, hint) for expr in exprs] + else: + rv = _handle_Integral(exprs, func, hint) + + if isinstance(rv, list): + assert all(isinstance(i, Eq) for i in rv), rv # if not => internal error + if simplify: + rv = _remove_redundant_solutions(eq, rv, order, func.args[0]) + if len(rv) == 1: + rv = rv[0] + if ics and 'power_series' not in hint: + if isinstance(rv, (Expr, Eq)): + solved_constants = solve_ics([rv], [r['func']], cons(rv), ics) + rv = rv.subs(solved_constants) + else: + rv1 = [] + for s in rv: + try: + solved_constants = solve_ics([s], [r['func']], cons(s), ics) + except ValueError: + continue + rv1.append(s.subs(solved_constants)) + if len(rv1) == 1: + return rv1[0] + rv = rv1 + return rv + + +def solve_ics(sols, funcs, constants, ics): + """ + Solve for the constants given initial conditions + + ``sols`` is a list of solutions. + + ``funcs`` is a list of functions. + + ``constants`` is a list of constants. + + ``ics`` is the set of initial/boundary conditions for the differential + equation. It should be given in the form of ``{f(x0): x1, + f(x).diff(x).subs(x, x2): x3}`` and so on. + + Returns a dictionary mapping constants to values. + ``solution.subs(constants)`` will replace the constants in ``solution``. + + Example + ======= + >>> # From dsolve(f(x).diff(x) - f(x), f(x)) + >>> from sympy import symbols, Eq, exp, Function + >>> from sympy.solvers.ode.ode import solve_ics + >>> f = Function('f') + >>> x, C1 = symbols('x C1') + >>> sols = [Eq(f(x), C1*exp(x))] + >>> funcs = [f(x)] + >>> constants = [C1] + >>> ics = {f(0): 2} + >>> solved_constants = solve_ics(sols, funcs, constants, ics) + >>> solved_constants + {C1: 2} + >>> sols[0].subs(solved_constants) + Eq(f(x), 2*exp(x)) + + """ + # Assume ics are of the form f(x0): value or Subs(diff(f(x), x, n), (x, + # x0)): value (currently checked by classify_ode). To solve, replace x + # with x0, f(x0) with value, then solve for constants. For f^(n)(x0), + # differentiate the solution n times, so that f^(n)(x) appears. + x = funcs[0].args[0] + diff_sols = [] + subs_sols = [] + diff_variables = set() + for funcarg, value in ics.items(): + if isinstance(funcarg, AppliedUndef): + x0 = funcarg.args[0] + matching_func = [f for f in funcs if f.func == funcarg.func][0] + S = sols + elif isinstance(funcarg, (Subs, Derivative)): + if isinstance(funcarg, Subs): + # Make sure it stays a subs. Otherwise subs below will produce + # a different looking term. + funcarg = funcarg.doit() + if isinstance(funcarg, Subs): + deriv = funcarg.expr + x0 = funcarg.point[0] + variables = funcarg.expr.variables + matching_func = deriv + elif isinstance(funcarg, Derivative): + deriv = funcarg + x0 = funcarg.variables[0] + variables = (x,)*len(funcarg.variables) + matching_func = deriv.subs(x0, x) + for sol in sols: + if sol.has(deriv.expr.func): + diff_sols.append(Eq(sol.lhs.diff(*variables), sol.rhs.diff(*variables))) + diff_variables.add(variables) + S = diff_sols + else: + raise NotImplementedError("Unrecognized initial condition") + + for sol in S: + if sol.has(matching_func): + sol2 = sol + sol2 = sol2.subs(x, x0) + sol2 = sol2.subs(funcarg, value) + # This check is necessary because of issue #15724 + if not isinstance(sol2, BooleanAtom) or not subs_sols: + subs_sols = [s for s in subs_sols if not isinstance(s, BooleanAtom)] + subs_sols.append(sol2) + + # TODO: Use solveset here + try: + solved_constants = solve(subs_sols, constants, dict=True) + except NotImplementedError: + solved_constants = [] + + # XXX: We can't differentiate between the solution not existing because of + # invalid initial conditions, and not existing because solve is not smart + # enough. If we could use solveset, this might be improvable, but for now, + # we use NotImplementedError in this case. + if not solved_constants: + raise ValueError("Couldn't solve for initial conditions") + + if solved_constants == True: + raise ValueError("Initial conditions did not produce any solutions for constants. Perhaps they are degenerate.") + + if len(solved_constants) > 1: + raise NotImplementedError("Initial conditions produced too many solutions for constants") + + return solved_constants[0] + +def classify_ode(eq, func=None, dict=False, ics=None, *, prep=True, xi=None, eta=None, n=None, **kwargs): + r""" + Returns a tuple of possible :py:meth:`~sympy.solvers.ode.dsolve` + classifications for an ODE. + + The tuple is ordered so that first item is the classification that + :py:meth:`~sympy.solvers.ode.dsolve` uses to solve the ODE by default. In + general, classifications at the near the beginning of the list will + produce better solutions faster than those near the end, thought there are + always exceptions. To make :py:meth:`~sympy.solvers.ode.dsolve` use a + different classification, use ``dsolve(ODE, func, + hint=)``. See also the + :py:meth:`~sympy.solvers.ode.dsolve` docstring for different meta-hints + you can use. + + If ``dict`` is true, :py:meth:`~sympy.solvers.ode.classify_ode` will + return a dictionary of ``hint:match`` expression terms. This is intended + for internal use by :py:meth:`~sympy.solvers.ode.dsolve`. Note that + because dictionaries are ordered arbitrarily, this will most likely not be + in the same order as the tuple. + + You can get help on different hints by executing + ``help(ode.ode_hintname)``, where ``hintname`` is the name of the hint + without ``_Integral``. + + See :py:data:`~sympy.solvers.ode.allhints` or the + :py:mod:`~sympy.solvers.ode` docstring for a list of all supported hints + that can be returned from :py:meth:`~sympy.solvers.ode.classify_ode`. + + Notes + ===== + + These are remarks on hint names. + + ``_Integral`` + + If a classification has ``_Integral`` at the end, it will return the + expression with an unevaluated :py:class:`~.Integral` + class in it. Note that a hint may do this anyway if + :py:meth:`~sympy.core.expr.Expr.integrate` cannot do the integral, + though just using an ``_Integral`` will do so much faster. Indeed, an + ``_Integral`` hint will always be faster than its corresponding hint + without ``_Integral`` because + :py:meth:`~sympy.core.expr.Expr.integrate` is an expensive routine. + If :py:meth:`~sympy.solvers.ode.dsolve` hangs, it is probably because + :py:meth:`~sympy.core.expr.Expr.integrate` is hanging on a tough or + impossible integral. Try using an ``_Integral`` hint or + ``all_Integral`` to get it return something. + + Note that some hints do not have ``_Integral`` counterparts. This is + because :py:func:`~sympy.integrals.integrals.integrate` is not used in + solving the ODE for those method. For example, `n`\th order linear + homogeneous ODEs with constant coefficients do not require integration + to solve, so there is no + ``nth_linear_homogeneous_constant_coeff_Integrate`` hint. You can + easily evaluate any unevaluated + :py:class:`~sympy.integrals.integrals.Integral`\s in an expression by + doing ``expr.doit()``. + + Ordinals + + Some hints contain an ordinal such as ``1st_linear``. This is to help + differentiate them from other hints, as well as from other methods + that may not be implemented yet. If a hint has ``nth`` in it, such as + the ``nth_linear`` hints, this means that the method used to applies + to ODEs of any order. + + ``indep`` and ``dep`` + + Some hints contain the words ``indep`` or ``dep``. These reference + the independent variable and the dependent function, respectively. For + example, if an ODE is in terms of `f(x)`, then ``indep`` will refer to + `x` and ``dep`` will refer to `f`. + + ``subs`` + + If a hints has the word ``subs`` in it, it means that the ODE is solved + by substituting the expression given after the word ``subs`` for a + single dummy variable. This is usually in terms of ``indep`` and + ``dep`` as above. The substituted expression will be written only in + characters allowed for names of Python objects, meaning operators will + be spelled out. For example, ``indep``/``dep`` will be written as + ``indep_div_dep``. + + ``coeff`` + + The word ``coeff`` in a hint refers to the coefficients of something + in the ODE, usually of the derivative terms. See the docstring for + the individual methods for more info (``help(ode)``). This is + contrast to ``coefficients``, as in ``undetermined_coefficients``, + which refers to the common name of a method. + + ``_best`` + + Methods that have more than one fundamental way to solve will have a + hint for each sub-method and a ``_best`` meta-classification. This + will evaluate all hints and return the best, using the same + considerations as the normal ``best`` meta-hint. + + + Examples + ======== + + >>> from sympy import Function, classify_ode, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> classify_ode(Eq(f(x).diff(x), 0), f(x)) + ('nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + >>> classify_ode(f(x).diff(x, 2) + 3*f(x).diff(x) + 2*f(x) - 4) + ('factorable', 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + + """ + ics = sympify(ics) + + if func and len(func.args) != 1: + raise ValueError("dsolve() and classify_ode() only " + "work with functions of one variable, not %s" % func) + + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + + # Some methods want the unprocessed equation + eq_orig = eq + + if prep or func is None: + eq, func_ = _preprocess(eq, func) + if func is None: + func = func_ + x = func.args[0] + f = func.func + y = Dummy('y') + terms = 5 if n is None else n + + order = ode_order(eq, f(x)) + # hint:matchdict or hint:(tuple of matchdicts) + # Also will contain "default": and "order":order items. + matching_hints = {"order": order} + + df = f(x).diff(x) + a = Wild('a', exclude=[f(x)]) + d = Wild('d', exclude=[df, f(x).diff(x, 2)]) + e = Wild('e', exclude=[df]) + n = Wild('n', exclude=[x, f(x), df]) + c1 = Wild('c1', exclude=[x]) + a3 = Wild('a3', exclude=[f(x), df, f(x).diff(x, 2)]) + b3 = Wild('b3', exclude=[f(x), df, f(x).diff(x, 2)]) + c3 = Wild('c3', exclude=[f(x), df, f(x).diff(x, 2)]) + boundary = {} # Used to extract initial conditions + C1 = Symbol("C1") + + # Preprocessing to get the initial conditions out + if ics is not None: + for funcarg in ics: + # Separating derivatives + if isinstance(funcarg, (Subs, Derivative)): + # f(x).diff(x).subs(x, 0) is a Subs, but f(x).diff(x).subs(x, + # y) is a Derivative + if isinstance(funcarg, Subs): + deriv = funcarg.expr + old = funcarg.variables[0] + new = funcarg.point[0] + elif isinstance(funcarg, Derivative): + deriv = funcarg + # No information on this. Just assume it was x + old = x + new = funcarg.variables[0] + + if (isinstance(deriv, Derivative) and isinstance(deriv.args[0], + AppliedUndef) and deriv.args[0].func == f and + len(deriv.args[0].args) == 1 and old == x and not + new.has(x) and all(i == deriv.variables[0] for i in + deriv.variables) and x not in ics[funcarg].free_symbols): + + dorder = ode_order(deriv, x) + temp = 'f' + str(dorder) + boundary.update({temp: new, temp + 'val': ics[funcarg]}) + else: + raise ValueError("Invalid boundary conditions for Derivatives") + + + # Separating functions + elif isinstance(funcarg, AppliedUndef): + if (funcarg.func == f and len(funcarg.args) == 1 and + not funcarg.args[0].has(x) and x not in ics[funcarg].free_symbols): + boundary.update({'f0': funcarg.args[0], 'f0val': ics[funcarg]}) + else: + raise ValueError("Invalid boundary conditions for Function") + + else: + raise ValueError("Enter boundary conditions of the form ics={f(point): value, f(x).diff(x, order).subs(x, point): value}") + + ode = SingleODEProblem(eq_orig, func, x, prep=prep, xi=xi, eta=eta) + user_hint = kwargs.get('hint', 'default') + # Used when dsolve is called without an explicit hint. + # We exit early to return the first valid match + early_exit = (user_hint=='default') + user_hint = user_hint.removesuffix('_Integral') + user_map = solver_map + # An explicit hint has been given to dsolve + # Skip matching code for other hints + if user_hint not in ['default', 'all', 'all_Integral', 'best'] and user_hint in solver_map: + user_map = {user_hint: solver_map[user_hint]} + + for hint in user_map: + solver = user_map[hint](ode) + if solver.matches(): + matching_hints[hint] = solver + if user_map[hint].has_integral: + matching_hints[hint + "_Integral"] = solver + if dict and early_exit: + matching_hints["default"] = hint + return matching_hints + + eq = expand(eq) + # Precondition to try remove f(x) from highest order derivative + reduced_eq = None + if eq.is_Add: + deriv_coef = eq.coeff(f(x).diff(x, order)) + if deriv_coef not in (1, 0): + r = deriv_coef.match(a*f(x)**c1) + if r and r[c1]: + den = f(x)**r[c1] + reduced_eq = Add(*[arg/den for arg in eq.args]) + if not reduced_eq: + reduced_eq = eq + + if order == 1: + + # NON-REDUCED FORM OF EQUATION matches + r = collect(eq, df, exact=True).match(d + e * df) + if r: + r['d'] = d + r['e'] = e + r['y'] = y + r[d] = r[d].subs(f(x), y) + r[e] = r[e].subs(f(x), y) + + # FIRST ORDER POWER SERIES WHICH NEEDS INITIAL CONDITIONS + # TODO: Hint first order series should match only if d/e is analytic. + # For now, only d/e and (d/e).diff(arg) is checked for existence at + # at a given point. + # This is currently done internally in ode_1st_power_series. + point = boundary.get('f0', 0) + value = boundary.get('f0val', C1) + check = cancel(r[d]/r[e]) + check1 = check.subs({x: point, y: value}) + if not check1.has(oo) and not check1.has(zoo) and \ + not check1.has(nan) and not check1.has(-oo): + check2 = (check1.diff(x)).subs({x: point, y: value}) + if not check2.has(oo) and not check2.has(zoo) and \ + not check2.has(nan) and not check2.has(-oo): + rseries = r.copy() + rseries.update({'terms': terms, 'f0': point, 'f0val': value}) + matching_hints["1st_power_series"] = rseries + + elif order == 2: + # Homogeneous second order differential equation of the form + # a3*f(x).diff(x, 2) + b3*f(x).diff(x) + c3 + # It has a definite power series solution at point x0 if, b3/a3 and c3/a3 + # are analytic at x0. + deq = a3*(f(x).diff(x, 2)) + b3*df + c3*f(x) + r = collect(reduced_eq, + [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq) + ordinary = False + if r: + if not all(r[key].is_polynomial() for key in r): + n, d = reduced_eq.as_numer_denom() + reduced_eq = expand(n) + r = collect(reduced_eq, + [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq) + if r and r[a3] != 0: + p = cancel(r[b3]/r[a3]) # Used below + q = cancel(r[c3]/r[a3]) # Used below + point = kwargs.get('x0', 0) + check = p.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + check = q.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + ordinary = True + r.update({'a3': a3, 'b3': b3, 'c3': c3, 'x0': point, 'terms': terms}) + matching_hints["2nd_power_series_ordinary"] = r + + # Checking if the differential equation has a regular singular point + # at x0. It has a regular singular point at x0, if (b3/a3)*(x - x0) + # and (c3/a3)*((x - x0)**2) are analytic at x0. + if not ordinary: + p = cancel((x - point)*p) + check = p.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + q = cancel(((x - point)**2)*q) + check = q.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + coeff_dict = {'p': p, 'q': q, 'x0': point, 'terms': terms} + matching_hints["2nd_power_series_regular"] = coeff_dict + + + # Order keys based on allhints. + retlist = [i for i in allhints if i in matching_hints] + if dict: + # Dictionaries are ordered arbitrarily, so make note of which + # hint would come first for dsolve(). Use an ordered dict in Py 3. + matching_hints["default"] = retlist[0] if retlist else None + matching_hints["ordered_hints"] = tuple(retlist) + return matching_hints + else: + return tuple(retlist) + + +def classify_sysode(eq, funcs=None, **kwargs): + r""" + Returns a dictionary of parameter names and values that define the system + of ordinary differential equations in ``eq``. + The parameters are further used in + :py:meth:`~sympy.solvers.ode.dsolve` for solving that system. + + Some parameter names and values are: + + 'is_linear' (boolean), which tells whether the given system is linear. + Note that "linear" here refers to the operator: terms such as ``x*diff(x,t)`` are + nonlinear, whereas terms like ``sin(t)*diff(x,t)`` are still linear operators. + + 'func' (list) contains the :py:class:`~sympy.core.function.Function`s that + appear with a derivative in the ODE, i.e. those that we are trying to solve + the ODE for. + + 'order' (dict) with the maximum derivative for each element of the 'func' + parameter. + + 'func_coeff' (dict or Matrix) with the coefficient for each triple ``(equation number, + function, order)```. The coefficients are those subexpressions that do not + appear in 'func', and hence can be considered constant for purposes of ODE + solving. The value of this parameter can also be a Matrix if the system of ODEs are + linear first order of the form X' = AX where X is the vector of dependent variables. + Here, this function returns the coefficient matrix A. + + 'eq' (list) with the equations from ``eq``, sympified and transformed into + expressions (we are solving for these expressions to be zero). + + 'no_of_equations' (int) is the number of equations (same as ``len(eq)``). + + 'type_of_equation' (string) is an internal classification of the type of + ODE. + + 'is_constant' (boolean), which tells if the system of ODEs is constant coefficient + or not. This key is temporary addition for now and is in the match dict only when + the system of ODEs is linear first order constant coefficient homogeneous. So, this + key's value is True for now if it is available else it does not exist. + + 'is_homogeneous' (boolean), which tells if the system of ODEs is homogeneous. Like the + key 'is_constant', this key is a temporary addition and it is True since this key value + is available only when the system is linear first order constant coefficient homogeneous. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode-toc1.htm + -A. D. Polyanin and A. V. Manzhirov, Handbook of Mathematics for Engineers and Scientists + + Examples + ======== + + >>> from sympy import Function, Eq, symbols, diff + >>> from sympy.solvers.ode.ode import classify_sysode + >>> from sympy.abc import t + >>> f, x, y = symbols('f, x, y', cls=Function) + >>> k, l, m, n = symbols('k, l, m, n', Integer=True) + >>> x1 = diff(x(t), t) ; y1 = diff(y(t), t) + >>> x2 = diff(x(t), t, t) ; y2 = diff(y(t), t, t) + >>> eq = (Eq(x1, 12*x(t) - 6*y(t)), Eq(y1, 11*x(t) + 3*y(t))) + >>> classify_sysode(eq) + {'eq': [-12*x(t) + 6*y(t) + Derivative(x(t), t), -11*x(t) - 3*y(t) + Derivative(y(t), t)], 'func': [x(t), y(t)], + 'func_coeff': {(0, x(t), 0): -12, (0, x(t), 1): 1, (0, y(t), 0): 6, (0, y(t), 1): 0, (1, x(t), 0): -11, (1, x(t), 1): 0, (1, y(t), 0): -3, (1, y(t), 1): 1}, 'is_linear': True, 'no_of_equation': 2, 'order': {x(t): 1, y(t): 1}, 'type_of_equation': None} + >>> eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t) + 2), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + >>> classify_sysode(eq) + {'eq': [-t**2*y(t) - 5*t*x(t) + Derivative(x(t), t) - 2, t**2*x(t) - 5*t*y(t) + Derivative(y(t), t)], + 'func': [x(t), y(t)], 'func_coeff': {(0, x(t), 0): -5*t, (0, x(t), 1): 1, (0, y(t), 0): -t**2, (0, y(t), 1): 0, + (1, x(t), 0): t**2, (1, x(t), 1): 0, (1, y(t), 0): -5*t, (1, y(t), 1): 1}, 'is_linear': True, 'no_of_equation': 2, + 'order': {x(t): 1, y(t): 1}, 'type_of_equation': None} + + """ + + # Sympify equations and convert iterables of equations into + # a list of equations + def _sympify(eq): + return list(map(sympify, eq if iterable(eq) else [eq])) + + eq, funcs = (_sympify(w) for w in [eq, funcs]) + for i, fi in enumerate(eq): + if isinstance(fi, Equality): + eq[i] = fi.lhs - fi.rhs + + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + matching_hints = {"no_of_equation":i+1} + matching_hints['eq'] = eq + if i==0: + raise ValueError("classify_sysode() works for systems of ODEs. " + "For scalar ODEs, classify_ode should be used") + + # find all the functions if not given + order = {} + if funcs==[None]: + funcs = _extract_funcs(eq) + + funcs = list(set(funcs)) + if len(funcs) != len(eq): + raise ValueError("Number of functions given is not equal to the number of equations %s" % funcs) + + # This logic of list of lists in funcs to + # be replaced later. + func_dict = {} + for func in funcs: + if not order.get(func, False): + max_order = 0 + for i, eqs_ in enumerate(eq): + order_ = ode_order(eqs_,func) + if max_order < order_: + max_order = order_ + eq_no = i + if eq_no in func_dict: + func_dict[eq_no] = [func_dict[eq_no], func] + else: + func_dict[eq_no] = func + order[func] = max_order + + funcs = [func_dict[i] for i in range(len(func_dict))] + matching_hints['func'] = funcs + for func in funcs: + if isinstance(func, list): + for func_elem in func: + if len(func_elem.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + else: + if func and len(func.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + + # find the order of all equation in system of odes + matching_hints["order"] = order + + # find coefficients of terms f(t), diff(f(t),t) and higher derivatives + # and similarly for other functions g(t), diff(g(t),t) in all equations. + # Here j denotes the equation number, funcs[l] denotes the function about + # which we are talking about and k denotes the order of function funcs[l] + # whose coefficient we are calculating. + def linearity_check(eqs, j, func, is_linear_): + for k in range(order[func] + 1): + func_coef[j, func, k] = collect(eqs.expand(), [diff(func, t, k)]).coeff(diff(func, t, k)) + if is_linear_ == True: + if func_coef[j, func, k] == 0: + if k == 0: + coef = eqs.as_independent(func, as_Add=True)[1] + for xr in range(1, ode_order(eqs,func) + 1): + coef -= eqs.as_independent(diff(func, t, xr), as_Add=True)[1] + if coef != 0: + is_linear_ = False + else: + if eqs.as_independent(diff(func, t, k), as_Add=True)[1]: + is_linear_ = False + else: + for func_ in funcs: + if isinstance(func_, list): + for elem_func_ in func_: + dep = func_coef[j, func, k].as_independent(elem_func_, as_Add=True)[1] + if dep != 0: + is_linear_ = False + else: + dep = func_coef[j, func, k].as_independent(func_, as_Add=True)[1] + if dep != 0: + is_linear_ = False + return is_linear_ + + func_coef = {} + is_linear = True + for j, eqs in enumerate(eq): + for func in funcs: + if isinstance(func, list): + for func_elem in func: + is_linear = linearity_check(eqs, j, func_elem, is_linear) + else: + is_linear = linearity_check(eqs, j, func, is_linear) + matching_hints['func_coeff'] = func_coef + matching_hints['is_linear'] = is_linear + + + if len(set(order.values())) == 1: + order_eq = list(matching_hints['order'].values())[0] + if matching_hints['is_linear'] == True: + if matching_hints['no_of_equation'] == 2: + if order_eq == 1: + type_of_equation = check_linear_2eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + # If the equation does not match up with any of the + # general case solvers in systems.py and the number + # of equations is greater than 2, then NotImplementedError + # should be raised. + else: + type_of_equation = None + + else: + if matching_hints['no_of_equation'] == 2: + if order_eq == 1: + type_of_equation = check_nonlinear_2eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + elif matching_hints['no_of_equation'] == 3: + if order_eq == 1: + type_of_equation = check_nonlinear_3eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + else: + type_of_equation = None + else: + type_of_equation = None + + matching_hints['type_of_equation'] = type_of_equation + + return matching_hints + + +def check_linear_2eq_order1(eq, func, func_coef): + x = func[0].func + y = func[1].func + fc = func_coef + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + r = {} + # for equations Eq(a1*diff(x(t),t), b1*x(t) + c1*y(t) + d1) + # and Eq(a2*diff(y(t),t), b2*x(t) + c2*y(t) + d2) + r['a1'] = fc[0,x(t),1] ; r['a2'] = fc[1,y(t),1] + r['b1'] = -fc[0,x(t),0]/fc[0,x(t),1] ; r['b2'] = -fc[1,x(t),0]/fc[1,y(t),1] + r['c1'] = -fc[0,y(t),0]/fc[0,x(t),1] ; r['c2'] = -fc[1,y(t),0]/fc[1,y(t),1] + forcing = [S.Zero,S.Zero] + for i in range(2): + for j in Add.make_args(eq[i]): + if not j.has(x(t), y(t)): + forcing[i] += j + if not (forcing[0].has(t) or forcing[1].has(t)): + # We can handle homogeneous case and simple constant forcings + r['d1'] = forcing[0] + r['d2'] = forcing[1] + else: + # Issue #9244: nonhomogeneous linear systems are not supported + return None + + # Conditions to check for type 6 whose equations are Eq(diff(x(t),t), f(t)*x(t) + g(t)*y(t)) and + # Eq(diff(y(t),t), a*[f(t) + a*h(t)]x(t) + a*[g(t) - h(t)]*y(t)) + p = 0 + q = 0 + p1 = cancel(r['b2']/(cancel(r['b2']/r['c2']).as_numer_denom()[0])) + p2 = cancel(r['b1']/(cancel(r['b1']/r['c1']).as_numer_denom()[0])) + for n, i in enumerate([p1, p2]): + for j in Mul.make_args(collect_const(i)): + if not j.has(t): + q = j + if q and n==0: + if ((r['b2']/j - r['b1'])/(r['c1'] - r['c2']/j)) == j: + p = 1 + elif q and n==1: + if ((r['b1']/j - r['b2'])/(r['c2'] - r['c1']/j)) == j: + p = 2 + # End of condition for type 6 + + if r['d1']!=0 or r['d2']!=0: + return None + else: + if not any(r[k].has(t) for k in 'a1 a2 b1 b2 c1 c2'.split()): + return None + else: + r['b1'] = r['b1']/r['a1'] ; r['b2'] = r['b2']/r['a2'] + r['c1'] = r['c1']/r['a1'] ; r['c2'] = r['c2']/r['a2'] + if p: + return "type6" + else: + # Equations for type 7 are Eq(diff(x(t),t), f(t)*x(t) + g(t)*y(t)) and Eq(diff(y(t),t), h(t)*x(t) + p(t)*y(t)) + return "type7" +def check_nonlinear_2eq_order1(eq, func, func_coef): + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + f = Wild('f') + g = Wild('g') + u, v = symbols('u, v', cls=Dummy) + def check_type(x, y): + r1 = eq[0].match(t*diff(x(t),t) - x(t) + f) + r2 = eq[1].match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = eq[0].match(diff(x(t),t) - x(t)/t + f/t) + r2 = eq[1].match(diff(y(t),t) - y(t)/t + g/t) + if not (r1 and r2): + r1 = (-eq[0]).match(t*diff(x(t),t) - x(t) + f) + r2 = (-eq[1]).match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = (-eq[0]).match(diff(x(t),t) - x(t)/t + f/t) + r2 = (-eq[1]).match(diff(y(t),t) - y(t)/t + g/t) + if r1 and r2 and not (r1[f].subs(diff(x(t),t),u).subs(diff(y(t),t),v).has(t) \ + or r2[g].subs(diff(x(t),t),u).subs(diff(y(t),t),v).has(t)): + return 'type5' + else: + return None + for func_ in func: + if isinstance(func_, list): + x = func[0][0].func + y = func[0][1].func + eq_type = check_type(x, y) + if not eq_type: + eq_type = check_type(y, x) + return eq_type + x = func[0].func + y = func[1].func + fc = func_coef + n = Wild('n', exclude=[x(t),y(t)]) + f1 = Wild('f1', exclude=[v,t]) + f2 = Wild('f2', exclude=[v,t]) + g1 = Wild('g1', exclude=[u,t]) + g2 = Wild('g2', exclude=[u,t]) + for i in range(2): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + r = eq[0].match(diff(x(t),t) - x(t)**n*f) + if r: + g = (diff(y(t),t) - eq[1])/r[f] + if r and not (g.has(x(t)) or g.subs(y(t),v).has(t) or r[f].subs(x(t),u).subs(y(t),v).has(t)): + return 'type1' + r = eq[0].match(diff(x(t),t) - exp(n*x(t))*f) + if r: + g = (diff(y(t),t) - eq[1])/r[f] + if r and not (g.has(x(t)) or g.subs(y(t),v).has(t) or r[f].subs(x(t),u).subs(y(t),v).has(t)): + return 'type2' + g = Wild('g') + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + if r1 and r2 and not (r1[f].subs(x(t),u).subs(y(t),v).has(t) or \ + r2[g].subs(x(t),u).subs(y(t),v).has(t)): + return 'type3' + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + num, den = ( + (r1[f].subs(x(t),u).subs(y(t),v))/ + (r2[g].subs(x(t),u).subs(y(t),v))).as_numer_denom() + R1 = num.match(f1*g1) + R2 = den.match(f2*g2) + # phi = (r1[f].subs(x(t),u).subs(y(t),v))/num + if R1 and R2: + return 'type4' + return None + + +def check_nonlinear_2eq_order2(eq, func, func_coef): + return None + +def check_nonlinear_3eq_order1(eq, func, func_coef): + x = func[0].func + y = func[1].func + z = func[2].func + fc = func_coef + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + u, v, w = symbols('u, v, w', cls=Dummy) + a = Wild('a', exclude=[x(t), y(t), z(t), t]) + b = Wild('b', exclude=[x(t), y(t), z(t), t]) + c = Wild('c', exclude=[x(t), y(t), z(t), t]) + f = Wild('f') + F1 = Wild('F1') + F2 = Wild('F2') + F3 = Wild('F3') + for i in range(3): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + r1 = eq[0].match(diff(x(t),t) - a*y(t)*z(t)) + r2 = eq[1].match(diff(y(t),t) - b*z(t)*x(t)) + r3 = eq[2].match(diff(z(t),t) - c*x(t)*y(t)) + if r1 and r2 and r3: + num1, den1 = r1[a].as_numer_denom() + num2, den2 = r2[b].as_numer_denom() + num3, den3 = r3[c].as_numer_denom() + if solve([num1*u-den1*(v-w), num2*v-den2*(w-u), num3*w-den3*(u-v)],[u, v]): + return 'type1' + r = eq[0].match(diff(x(t),t) - y(t)*z(t)*f) + if r: + r1 = collect_const(r[f]).match(a*f) + r2 = ((diff(y(t),t) - eq[1])/r1[f]).match(b*z(t)*x(t)) + r3 = ((diff(z(t),t) - eq[2])/r1[f]).match(c*x(t)*y(t)) + if r1 and r2 and r3: + num1, den1 = r1[a].as_numer_denom() + num2, den2 = r2[b].as_numer_denom() + num3, den3 = r3[c].as_numer_denom() + if solve([num1*u-den1*(v-w), num2*v-den2*(w-u), num3*w-den3*(u-v)],[u, v]): + return 'type2' + r = eq[0].match(diff(x(t),t) - (F2-F3)) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = eq[1].match(diff(y(t),t) - a*r1[F3] + r1[c]*F1) + if r2: + r3 = (eq[2] == diff(z(t),t) - r1[b]*r2[F1] + r2[a]*r1[F2]) + if r1 and r2 and r3: + return 'type3' + r = eq[0].match(diff(x(t),t) - z(t)*F2 + y(t)*F3) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = (diff(y(t),t) - eq[1]).match(a*x(t)*r1[F3] - r1[c]*z(t)*F1) + if r2: + r3 = (diff(z(t),t) - eq[2] == r1[b]*y(t)*r2[F1] - r2[a]*x(t)*r1[F2]) + if r1 and r2 and r3: + return 'type4' + r = (diff(x(t),t) - eq[0]).match(x(t)*(F2 - F3)) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = (diff(y(t),t) - eq[1]).match(y(t)*(a*r1[F3] - r1[c]*F1)) + if r2: + r3 = (diff(z(t),t) - eq[2] == z(t)*(r1[b]*r2[F1] - r2[a]*r1[F2])) + if r1 and r2 and r3: + return 'type5' + return None + + +def check_nonlinear_3eq_order2(eq, func, func_coef): + return None + + +@vectorize(0) +def odesimp(ode, eq, func, hint): + r""" + Simplifies solutions of ODEs, including trying to solve for ``func`` and + running :py:meth:`~sympy.solvers.ode.constantsimp`. + + It may use knowledge of the type of solution that the hint returns to + apply additional simplifications. + + It also attempts to integrate any :py:class:`~sympy.integrals.integrals.Integral`\s + in the expression, if the hint is not an ``_Integral`` hint. + + This function should have no effect on expressions returned by + :py:meth:`~sympy.solvers.ode.dsolve`, as + :py:meth:`~sympy.solvers.ode.dsolve` already calls + :py:meth:`~sympy.solvers.ode.ode.odesimp`, but the individual hint functions + do not call :py:meth:`~sympy.solvers.ode.ode.odesimp` (because the + :py:meth:`~sympy.solvers.ode.dsolve` wrapper does). Therefore, this + function is designed for mainly internal use. + + Examples + ======== + + >>> from sympy import sin, symbols, dsolve, pprint, Function + >>> from sympy.solvers.ode.ode import odesimp + >>> x, u2, C1= symbols('x,u2,C1') + >>> f = Function('f') + + >>> eq = dsolve(x*f(x).diff(x) - f(x) - x*sin(f(x)/x), f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep_Integral', + ... simplify=False) + >>> pprint(eq, wrap_line=False) + x + ---- + f(x) + / + | + | / 1 \ + | -|u1 + -------| + | | /1 \| + | | sin|--|| + | \ \u1// + log(f(x)) = log(C1) + | ---------------- d(u1) + | 2 + | u1 + | + / + + >>> pprint(odesimp(eq, f(x), 1, {C1}, + ... hint='1st_homogeneous_coeff_subs_indep_div_dep' + ... )) #doctest: +SKIP + x + --------- = C1 + /f(x)\ + tan|----| + \2*x / + + """ + x = func.args[0] + f = func.func + C1 = get_numbered_constants(eq, num=1) + constants = eq.free_symbols - ode.free_symbols + + # First, integrate if the hint allows it. + eq = _handle_Integral(eq, func, hint) + if hint.startswith("nth_linear_euler_eq_nonhomogeneous"): + eq = simplify(eq) + if not isinstance(eq, Equality): + raise TypeError("eq should be an instance of Equality") + + # allow simplifications under assumption that symbols are nonzero + eq = eq.xreplace((_:={i: Dummy(nonzero=True) for i in constants})).xreplace({_[i]: i for i in _}) + + # Second, clean up the arbitrary constants. + # Right now, nth linear hints can put as many as 2*order constants in an + # expression. If that number grows with another hint, the third argument + # here should be raised accordingly, or constantsimp() rewritten to handle + # an arbitrary number of constants. + eq = constantsimp(eq, constants) + + # Lastly, now that we have cleaned up the expression, try solving for func. + # When CRootOf is implemented in solve(), we will want to return a CRootOf + # every time instead of an Equality. + + # Get the f(x) on the left if possible. + if eq.rhs == func and not eq.lhs.has(func): + eq = [Eq(eq.rhs, eq.lhs)] + + # make sure we are working with lists of solutions in simplified form. + if eq.lhs == func and not eq.rhs.has(func): + # The solution is already solved + eq = [eq] + + else: + # The solution is not solved, so try to solve it + try: + floats = any(i.is_Float for i in eq.atoms(Number)) + eqsol = solve(eq, func, force=True, rational=False if floats else None) + if not eqsol: + raise NotImplementedError + except (NotImplementedError, PolynomialError): + eq = [eq] + else: + def _expand(expr): + numer, denom = expr.as_numer_denom() + + if denom.is_Add: + return expr + else: + return powsimp(expr.expand(), combine='exp', deep=True) + + # XXX: the rest of odesimp() expects each ``t`` to be in a + # specific normal form: rational expression with numerator + # expanded, but with combined exponential functions (at + # least in this setup all tests pass). + eq = [Eq(f(x), _expand(t)) for t in eqsol] + + # special simplification of the lhs. + if hint.startswith("1st_homogeneous_coeff"): + for j, eqi in enumerate(eq): + newi = logcombine(eqi, force=True) + if isinstance(newi.lhs, log) and newi.rhs == 0: + newi = Eq(newi.lhs.args[0]/C1, C1) + eq[j] = newi + + # We cleaned up the constants before solving to help the solve engine with + # a simpler expression, but the solved expression could have introduced + # things like -C1, so rerun constantsimp() one last time before returning. + for i, eqi in enumerate(eq): + eq[i] = constantsimp(eqi, constants) + eq[i] = constant_renumber(eq[i], ode.free_symbols) + + # If there is only 1 solution, return it; + # otherwise return the list of solutions. + if len(eq) == 1: + eq = eq[0] + return eq + + +def ode_sol_simplicity(sol, func, trysolving=True): + r""" + Returns an extended integer representing how simple a solution to an ODE + is. + + The following things are considered, in order from most simple to least: + + - ``sol`` is solved for ``func``. + - ``sol`` is not solved for ``func``, but can be if passed to solve (e.g., + a solution returned by ``dsolve(ode, func, simplify=False``). + - If ``sol`` is not solved for ``func``, then base the result on the + length of ``sol``, as computed by ``len(str(sol))``. + - If ``sol`` has any unevaluated :py:class:`~sympy.integrals.integrals.Integral`\s, + this will automatically be considered less simple than any of the above. + + This function returns an integer such that if solution A is simpler than + solution B by above metric, then ``ode_sol_simplicity(sola, func) < + ode_sol_simplicity(solb, func)``. + + Currently, the following are the numbers returned, but if the heuristic is + ever improved, this may change. Only the ordering is guaranteed. + + +----------------------------------------------+-------------------+ + | Simplicity | Return | + +==============================================+===================+ + | ``sol`` solved for ``func`` | ``-2`` | + +----------------------------------------------+-------------------+ + | ``sol`` not solved for ``func`` but can be | ``-1`` | + +----------------------------------------------+-------------------+ + | ``sol`` is not solved nor solvable for | ``len(str(sol))`` | + | ``func`` | | + +----------------------------------------------+-------------------+ + | ``sol`` contains an | ``oo`` | + | :obj:`~sympy.integrals.integrals.Integral` | | + +----------------------------------------------+-------------------+ + + ``oo`` here means the SymPy infinity, which should compare greater than + any integer. + + If you already know :py:meth:`~sympy.solvers.solvers.solve` cannot solve + ``sol``, you can use ``trysolving=False`` to skip that step, which is the + only potentially slow step. For example, + :py:meth:`~sympy.solvers.ode.dsolve` with the ``simplify=False`` flag + should do this. + + If ``sol`` is a list of solutions, if the worst solution in the list + returns ``oo`` it returns that, otherwise it returns ``len(str(sol))``, + that is, the length of the string representation of the whole list. + + Examples + ======== + + This function is designed to be passed to ``min`` as the key argument, + such as ``min(listofsolutions, key=lambda i: ode_sol_simplicity(i, + f(x)))``. + + >>> from sympy import symbols, Function, Eq, tan, Integral + >>> from sympy.solvers.ode.ode import ode_sol_simplicity + >>> x, C1, C2 = symbols('x, C1, C2') + >>> f = Function('f') + + >>> ode_sol_simplicity(Eq(f(x), C1*x**2), f(x)) + -2 + >>> ode_sol_simplicity(Eq(x**2 + f(x), C1), f(x)) + -1 + >>> ode_sol_simplicity(Eq(f(x), C1*Integral(2*x, x)), f(x)) + oo + >>> eq1 = Eq(f(x)/tan(f(x)/(2*x)), C1) + >>> eq2 = Eq(f(x)/tan(f(x)/(2*x) + f(x)), C2) + >>> [ode_sol_simplicity(eq, f(x)) for eq in [eq1, eq2]] + [28, 35] + >>> min([eq1, eq2], key=lambda i: ode_sol_simplicity(i, f(x))) + Eq(f(x)/tan(f(x)/(2*x)), C1) + + """ + # TODO: if two solutions are solved for f(x), we still want to be + # able to get the simpler of the two + + # See the docstring for the coercion rules. We check easier (faster) + # things here first, to save time. + + if iterable(sol): + # See if there are Integrals + for i in sol: + if ode_sol_simplicity(i, func, trysolving=trysolving) == oo: + return oo + + return len(str(sol)) + + if sol.has(Integral): + return oo + + # Next, try to solve for func. This code will change slightly when CRootOf + # is implemented in solve(). Probably a CRootOf solution should fall + # somewhere between a normal solution and an unsolvable expression. + + # First, see if they are already solved + if sol.lhs == func and not sol.rhs.has(func) or \ + sol.rhs == func and not sol.lhs.has(func): + return -2 + # We are not so lucky, try solving manually + if trysolving: + try: + sols = solve(sol, func) + if not sols: + raise NotImplementedError + except NotImplementedError: + pass + else: + return -1 + + # Finally, a naive computation based on the length of the string version + # of the expression. This may favor combined fractions because they + # will not have duplicate denominators, and may slightly favor expressions + # with fewer additions and subtractions, as those are separated by spaces + # by the printer. + + # Additional ideas for simplicity heuristics are welcome, like maybe + # checking if a equation has a larger domain, or if constantsimp has + # introduced arbitrary constants numbered higher than the order of a + # given ODE that sol is a solution of. + return len(str(sol)) + + +def _extract_funcs(eqs): + funcs = [] + for eq in eqs: + derivs = [node for node in preorder_traversal(eq) if isinstance(node, Derivative)] + func = [] + for d in derivs: + func += list(d.atoms(AppliedUndef)) + for func_ in func: + funcs.append(func_) + funcs = list(uniq(funcs)) + + return funcs + + +def _get_constant_subexpressions(expr, Cs): + Cs = set(Cs) + Ces = [] + def _recursive_walk(expr): + expr_syms = expr.free_symbols + if expr_syms and expr_syms.issubset(Cs): + Ces.append(expr) + else: + if expr.func == exp: + expr = expr.expand(mul=True) + if expr.func in (Add, Mul): + d = sift(expr.args, lambda i : i.free_symbols.issubset(Cs)) + if len(d[True]) > 1: + x = expr.func(*d[True]) + if not x.is_number: + Ces.append(x) + elif isinstance(expr, Integral): + if expr.free_symbols.issubset(Cs) and \ + all(len(x) == 3 for x in expr.limits): + Ces.append(expr) + for i in expr.args: + _recursive_walk(i) + return + _recursive_walk(expr) + return Ces + +def __remove_linear_redundancies(expr, Cs): + cnts = {i: expr.count(i) for i in Cs} + Cs = [i for i in Cs if cnts[i] > 0] + + def _linear(expr): + if isinstance(expr, Add): + xs = [i for i in Cs if expr.count(i)==cnts[i] \ + and 0 == expr.diff(i, 2)] + d = {} + for x in xs: + y = expr.diff(x) + if y not in d: + d[y]=[] + d[y].append(x) + for y in d: + if len(d[y]) > 1: + d[y].sort(key=str) + for x in d[y][1:]: + expr = expr.subs(x, 0) + return expr + + def _recursive_walk(expr): + if len(expr.args) != 0: + expr = expr.func(*[_recursive_walk(i) for i in expr.args]) + expr = _linear(expr) + return expr + + if isinstance(expr, Equality): + lhs, rhs = [_recursive_walk(i) for i in expr.args] + f = lambda i: isinstance(i, Number) or i in Cs + if isinstance(lhs, Symbol) and lhs in Cs: + rhs, lhs = lhs, rhs + if lhs.func in (Add, Symbol) and rhs.func in (Add, Symbol): + dlhs = sift([lhs] if isinstance(lhs, AtomicExpr) else lhs.args, f) + drhs = sift([rhs] if isinstance(rhs, AtomicExpr) else rhs.args, f) + for i in [True, False]: + for hs in [dlhs, drhs]: + if i not in hs: + hs[i] = [0] + # this calculation can be simplified + lhs = Add(*dlhs[False]) - Add(*drhs[False]) + rhs = Add(*drhs[True]) - Add(*dlhs[True]) + elif lhs.func in (Mul, Symbol) and rhs.func in (Mul, Symbol): + dlhs = sift([lhs] if isinstance(lhs, AtomicExpr) else lhs.args, f) + if True in dlhs: + if False not in dlhs: + dlhs[False] = [1] + lhs = Mul(*dlhs[False]) + rhs = rhs/Mul(*dlhs[True]) + return Eq(lhs, rhs) + else: + return _recursive_walk(expr) + +@vectorize(0) +def constantsimp(expr, constants): + r""" + Simplifies an expression with arbitrary constants in it. + + This function is written specifically to work with + :py:meth:`~sympy.solvers.ode.dsolve`, and is not intended for general use. + + Simplification is done by "absorbing" the arbitrary constants into other + arbitrary constants, numbers, and symbols that they are not independent + of. + + The symbols must all have the same name with numbers after it, for + example, ``C1``, ``C2``, ``C3``. The ``symbolname`` here would be + '``C``', the ``startnumber`` would be 1, and the ``endnumber`` would be 3. + If the arbitrary constants are independent of the variable ``x``, then the + independent symbol would be ``x``. There is no need to specify the + dependent function, such as ``f(x)``, because it already has the + independent symbol, ``x``, in it. + + Because terms are "absorbed" into arbitrary constants and because + constants are renumbered after simplifying, the arbitrary constants in + expr are not necessarily equal to the ones of the same name in the + returned result. + + If two or more arbitrary constants are added, multiplied, or raised to the + power of each other, they are first absorbed together into a single + arbitrary constant. Then the new constant is combined into other terms if + necessary. + + Absorption of constants is done with limited assistance: + + 1. terms of :py:class:`~sympy.core.add.Add`\s are collected to try join + constants so `e^x (C_1 \cos(x) + C_2 \cos(x))` will simplify to `e^x + C_1 \cos(x)`; + + 2. powers with exponents that are :py:class:`~sympy.core.add.Add`\s are + expanded so `e^{C_1 + x}` will be simplified to `C_1 e^x`. + + Use :py:meth:`~sympy.solvers.ode.ode.constant_renumber` to renumber constants + after simplification or else arbitrary numbers on constants may appear, + e.g. `C_1 + C_3 x`. + + In rare cases, a single constant can be "simplified" into two constants. + Every differential equation solution should have as many arbitrary + constants as the order of the differential equation. The result here will + be technically correct, but it may, for example, have `C_1` and `C_2` in + an expression, when `C_1` is actually equal to `C_2`. Use your discretion + in such situations, and also take advantage of the ability to use hints in + :py:meth:`~sympy.solvers.ode.dsolve`. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.ode.ode import constantsimp + >>> C1, C2, C3, x, y = symbols('C1, C2, C3, x, y') + >>> constantsimp(2*C1*x, {C1, C2, C3}) + C1*x + >>> constantsimp(C1 + 2 + x, {C1, C2, C3}) + C1 + x + >>> constantsimp(C1*C2 + 2 + C2 + C3*x, {C1, C2, C3}) + C1 + C3*x + + """ + # This function works recursively. The idea is that, for Mul, + # Add, Pow, and Function, if the class has a constant in it, then + # we can simplify it, which we do by recursing down and + # simplifying up. Otherwise, we can skip that part of the + # expression. + + Cs = constants + + orig_expr = expr + + constant_subexprs = _get_constant_subexpressions(expr, Cs) + for xe in constant_subexprs: + xes = list(xe.free_symbols) + if not xes: + continue + if all(expr.count(c) == xe.count(c) for c in xes): + xes.sort(key=str) + expr = expr.subs(xe, xes[0]) + + # try to perform common sub-expression elimination of constant terms + try: + commons, rexpr = cse(expr) + commons.reverse() + rexpr = rexpr[0] + for s in commons: + cs = list(s[1].atoms(Symbol)) + if len(cs) == 1 and cs[0] in Cs and \ + cs[0] not in rexpr.atoms(Symbol) and \ + not any(cs[0] in ex for ex in commons if ex != s): + rexpr = rexpr.subs(s[0], cs[0]) + else: + rexpr = rexpr.subs(*s) + expr = rexpr + except IndexError: + pass + expr = __remove_linear_redundancies(expr, Cs) + + def _conditional_term_factoring(expr): + new_expr = terms_gcd(expr, clear=False, deep=True, expand=False) + + # we do not want to factor exponentials, so handle this separately + if new_expr.is_Mul: + infac = False + asfac = False + for m in new_expr.args: + if isinstance(m, exp): + asfac = True + elif m.is_Add: + infac = any(isinstance(fi, exp) for t in m.args + for fi in Mul.make_args(t)) + if asfac and infac: + new_expr = expr + break + return new_expr + + expr = _conditional_term_factoring(expr) + + # call recursively if more simplification is possible + if orig_expr != expr: + return constantsimp(expr, Cs) + return expr + + +def constant_renumber(expr, variables=None, newconstants=None): + r""" + Renumber arbitrary constants in ``expr`` to use the symbol names as given + in ``newconstants``. In the process, this reorders expression terms in a + standard way. + + If ``newconstants`` is not provided then the new constant names will be + ``C1``, ``C2`` etc. Otherwise ``newconstants`` should be an iterable + giving the new symbols to use for the constants in order. + + The ``variables`` argument is a list of non-constant symbols. All other + free symbols found in ``expr`` are assumed to be constants and will be + renumbered. If ``variables`` is not given then any numbered symbol + beginning with ``C`` (e.g. ``C1``) is assumed to be a constant. + + Symbols are renumbered based on ``.sort_key()``, so they should be + numbered roughly in the order that they appear in the final, printed + expression. Note that this ordering is based in part on hashes, so it can + produce different results on different machines. + + The structure of this function is very similar to that of + :py:meth:`~sympy.solvers.ode.constantsimp`. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.ode.ode import constant_renumber + >>> x, C1, C2, C3 = symbols('x,C1:4') + >>> expr = C3 + C2*x + C1*x**2 + >>> expr + C1*x**2 + C2*x + C3 + >>> constant_renumber(expr) + C1 + C2*x + C3*x**2 + + The ``variables`` argument specifies which are constants so that the + other symbols will not be renumbered: + + >>> constant_renumber(expr, [C1, x]) + C1*x**2 + C2 + C3*x + + The ``newconstants`` argument is used to specify what symbols to use when + replacing the constants: + + >>> constant_renumber(expr, [x], newconstants=symbols('E1:4')) + E1 + E2*x + E3*x**2 + + """ + + # System of expressions + if isinstance(expr, (set, list, tuple)): + return type(expr)(constant_renumber(Tuple(*expr), + variables=variables, newconstants=newconstants)) + + # Symbols in solution but not ODE are constants + if variables is not None: + variables = set(variables) + free_symbols = expr.free_symbols + constantsymbols = list(free_symbols - variables) + # Any Cn is a constant... + else: + variables = set() + isconstant = lambda s: s.startswith('C') and s[1:].isdigit() + constantsymbols = [sym for sym in expr.free_symbols if isconstant(sym.name)] + + # Find new constants checking that they aren't already in the ODE + if newconstants is None: + iter_constants = numbered_symbols(start=1, prefix='C', exclude=variables) + else: + iter_constants = (sym for sym in newconstants if sym not in variables) + + constants_found = [] + + # make a mapping to send all constantsymbols to S.One and use + # that to make sure that term ordering is not dependent on + # the indexed value of C + C_1 = [(ci, S.One) for ci in constantsymbols] + sort_key=lambda arg: default_sort_key(arg.subs(C_1)) + + def _constant_renumber(expr): + r""" + We need to have an internal recursive function + """ + + # For system of expressions + if isinstance(expr, Tuple): + renumbered = [_constant_renumber(e) for e in expr] + return Tuple(*renumbered) + + if isinstance(expr, Equality): + return Eq( + _constant_renumber(expr.lhs), + _constant_renumber(expr.rhs)) + + if type(expr) not in (Mul, Add, Pow) and not expr.is_Function and \ + not expr.has(*constantsymbols): + # Base case, as above. Hope there aren't constants inside + # of some other class, because they won't be renumbered. + return expr + elif expr.is_Piecewise: + return expr + elif expr in constantsymbols: + if expr not in constants_found: + constants_found.append(expr) + return expr + elif expr.is_Function or expr.is_Pow: + return expr.func( + *[_constant_renumber(x) for x in expr.args]) + else: + sortedargs = list(expr.args) + sortedargs.sort(key=sort_key) + return expr.func(*[_constant_renumber(x) for x in sortedargs]) + expr = _constant_renumber(expr) + + # Don't renumber symbols present in the ODE. + constants_found = [c for c in constants_found if c not in variables] + + # Renumbering happens here + subs_dict = dict(zip(constants_found, iter_constants)) + expr = expr.subs(subs_dict, simultaneous=True) + + return expr + + +def _handle_Integral(expr, func, hint): + r""" + Converts a solution with Integrals in it into an actual solution. + + For most hints, this simply runs ``expr.doit()``. + + """ + if hint == "nth_linear_constant_coeff_homogeneous": + sol = expr + elif not hint.endswith("_Integral"): + sol = expr.doit() + else: + sol = expr + return sol + + +# XXX: Should this function maybe go somewhere else? + + +def homogeneous_order(eq, *symbols): + r""" + Returns the order `n` if `g` is homogeneous and ``None`` if it is not + homogeneous. + + Determines if a function is homogeneous and if so of what order. A + function `f(x, y, \cdots)` is homogeneous of order `n` if `f(t x, t y, + \cdots) = t^n f(x, y, \cdots)`. + + If the function is of two variables, `F(x, y)`, then `f` being homogeneous + of any order is equivalent to being able to rewrite `F(x, y)` as `G(x/y)` + or `H(y/x)`. This fact is used to solve 1st order ordinary differential + equations whose coefficients are homogeneous of the same order (see the + docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep`). + + Symbols can be functions, but every argument of the function must be a + symbol, and the arguments of the function that appear in the expression + must match those given in the list of symbols. If a declared function + appears with different arguments than given in the list of symbols, + ``None`` is returned. + + Examples + ======== + + >>> from sympy import Function, homogeneous_order, sqrt + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> homogeneous_order(f(x), f(x)) is None + True + >>> homogeneous_order(f(x,y), f(y, x), x, y) is None + True + >>> homogeneous_order(f(x), f(x), x) + 1 + >>> homogeneous_order(x**2*f(x)/sqrt(x**2+f(x)**2), x, f(x)) + 2 + >>> homogeneous_order(x**2+f(x), x, f(x)) is None + True + + """ + + if not symbols: + raise ValueError("homogeneous_order: no symbols were given.") + symset = set(symbols) + eq = sympify(eq) + + # The following are not supported + if eq.has(Order, Derivative): + return None + + # These are all constants + if (eq.is_Number or + eq.is_NumberSymbol or + eq.is_number + ): + return S.Zero + + # Replace all functions with dummy variables + dum = numbered_symbols(prefix='d', cls=Dummy) + newsyms = set() + for i in [j for j in symset if getattr(j, 'is_Function')]: + iargs = set(i.args) + if iargs.difference(symset): + return None + else: + dummyvar = next(dum) + eq = eq.subs(i, dummyvar) + symset.remove(i) + newsyms.add(dummyvar) + symset.update(newsyms) + + if not eq.free_symbols & symset: + return None + + # assuming order of a nested function can only be equal to zero + if isinstance(eq, Function): + return None if homogeneous_order( + eq.args[0], *tuple(symset)) != 0 else S.Zero + + # make the replacement of x with x*t and see if t can be factored out + t = Dummy('t', positive=True) # It is sufficient that t > 0 + eqs = separatevars(eq.subs([(i, t*i) for i in symset]), [t], dict=True)[t] + if eqs is S.One: + return S.Zero # there was no term with only t + i, d = eqs.as_independent(t, as_Add=False) + b, e = d.as_base_exp() + if b == t: + return e + + +def ode_2nd_power_series_ordinary(eq, func, order, match): + r""" + Gives a power series solution to a second order homogeneous differential + equation with polynomial coefficients at an ordinary point. A homogeneous + differential equation is of the form + + .. math :: P(x)\frac{d^2y}{dx^2} + Q(x)\frac{dy}{dx} + R(x) y(x) = 0 + + For simplicity it is assumed that `P(x)`, `Q(x)` and `R(x)` are polynomials, + it is sufficient that `\frac{Q(x)}{P(x)}` and `\frac{R(x)}{P(x)}` exists at + `x_{0}`. A recurrence relation is obtained by substituting `y` as `\sum_{n=0}^\infty a_{n}x^{n}`, + in the differential equation, and equating the nth term. Using this relation + various terms can be generated. + + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = f(x).diff(x, 2) + f(x) + >>> pprint(dsolve(eq, hint='2nd_power_series_ordinary')) + / 4 2 \ / 2\ + |x x | | x | / 6\ + f(x) = C2*|-- - -- + 1| + C1*x*|1 - --| + O\x / + \24 2 / \ 6 / + + + References + ========== + - https://tutorial.math.lamar.edu/Classes/DE/SeriesSolutions.aspx + - George E. Simmons, "Differential Equations with Applications and + Historical Notes", p.p 176 - 184 + + """ + x = func.args[0] + f = func.func + C0, C1 = get_numbered_constants(eq, num=2) + n = Dummy("n", integer=True) + s = Wild("s") + k = Wild("k", exclude=[x]) + x0 = match['x0'] + terms = match['terms'] + p = match[match['a3']] + q = match[match['b3']] + r = match[match['c3']] + seriesdict = {} + recurr = Function("r") + + # Generating the recurrence relation which works this way: + # for the second order term the summation begins at n = 2. The coefficients + # p is multiplied with an*(n - 1)*(n - 2)*x**n-2 and a substitution is made such that + # the exponent of x becomes n. + # For example, if p is x, then the second degree recurrence term is + # an*(n - 1)*(n - 2)*x**n-1, substituting (n - 1) as n, it transforms to + # an+1*n*(n - 1)*x**n. + # A similar process is done with the first order and zeroth order term. + + coefflist = [(recurr(n), r), (n*recurr(n), q), (n*(n - 1)*recurr(n), p)] + for index, coeff in enumerate(coefflist): + if coeff[1]: + f2 = powsimp(expand((coeff[1]*(x - x0)**(n - index)).subs(x, x + x0))) + if f2.is_Add: + addargs = f2.args + else: + addargs = [f2] + for arg in addargs: + powm = arg.match(s*x**k) + term = coeff[0]*powm[s] + if not powm[k].is_Symbol: + term = term.subs(n, n - powm[k].as_independent(n)[0]) + startind = powm[k].subs(n, index) + # Seeing if the startterm can be reduced further. + # If it vanishes for n lesser than startind, it is + # equal to summation from n. + if startind: + for i in reversed(range(startind)): + if not term.subs(n, i): + seriesdict[term] = i + else: + seriesdict[term] = i + 1 + break + else: + seriesdict[term] = S.Zero + + # Stripping of terms so that the sum starts with the same number. + teq = S.Zero + suminit = seriesdict.values() + rkeys = seriesdict.keys() + req = Add(*rkeys) + if any(suminit): + maxval = max(suminit) + for term in seriesdict: + val = seriesdict[term] + if val != maxval: + for i in range(val, maxval): + teq += term.subs(n, val) + + finaldict = {} + if teq: + fargs = teq.atoms(AppliedUndef) + if len(fargs) == 1: + finaldict[fargs.pop()] = 0 + else: + maxf = max(fargs, key = lambda x: x.args[0]) + sol = solve(teq, maxf) + if isinstance(sol, list): + sol = sol[0] + finaldict[maxf] = sol + + # Finding the recurrence relation in terms of the largest term. + fargs = req.atoms(AppliedUndef) + maxf = max(fargs, key = lambda x: x.args[0]) + minf = min(fargs, key = lambda x: x.args[0]) + if minf.args[0].is_Symbol: + startiter = 0 + else: + startiter = -minf.args[0].as_independent(n)[0] + lhs = maxf + rhs = solve(req, maxf) + if isinstance(rhs, list): + rhs = rhs[0] + + # Checking how many values are already present + tcounter = len([t for t in finaldict.values() if t]) + + for _ in range(tcounter, terms - 3): # Assuming c0 and c1 to be arbitrary + check = rhs.subs(n, startiter) + nlhs = lhs.subs(n, startiter) + nrhs = check.subs(finaldict) + finaldict[nlhs] = nrhs + startiter += 1 + + # Post processing + series = C0 + C1*(x - x0) + for term in finaldict: + if finaldict[term]: + fact = term.args[0] + series += (finaldict[term].subs([(recurr(0), C0), (recurr(1), C1)])*( + x - x0)**fact) + series = collect(expand_mul(series), [C0, C1]) + Order(x**terms) + return Eq(f(x), series) + + +def ode_2nd_power_series_regular(eq, func, order, match): + r""" + Gives a power series solution to a second order homogeneous differential + equation with polynomial coefficients at a regular point. A second order + homogeneous differential equation is of the form + + .. math :: P(x)\frac{d^2y}{dx^2} + Q(x)\frac{dy}{dx} + R(x) y(x) = 0 + + A point is said to regular singular at `x0` if `x - x0\frac{Q(x)}{P(x)}` + and `(x - x0)^{2}\frac{R(x)}{P(x)}` are analytic at `x0`. For simplicity + `P(x)`, `Q(x)` and `R(x)` are assumed to be polynomials. The algorithm for + finding the power series solutions is: + + 1. Try expressing `(x - x0)P(x)` and `((x - x0)^{2})Q(x)` as power series + solutions about x0. Find `p0` and `q0` which are the constants of the + power series expansions. + 2. Solve the indicial equation `f(m) = m(m - 1) + m*p0 + q0`, to obtain the + roots `m1` and `m2` of the indicial equation. + 3. If `m1 - m2` is a non integer there exists two series solutions. If + `m1 = m2`, there exists only one solution. If `m1 - m2` is an integer, + then the existence of one solution is confirmed. The other solution may + or may not exist. + + The power series solution is of the form `x^{m}\sum_{n=0}^\infty a_{n}x^{n}`. The + coefficients are determined by the following recurrence relation. + `a_{n} = -\frac{\sum_{k=0}^{n-1} q_{n-k} + (m + k)p_{n-k}}{f(m + n)}`. For the case + in which `m1 - m2` is an integer, it can be seen from the recurrence relation + that for the lower root `m`, when `n` equals the difference of both the + roots, the denominator becomes zero. So if the numerator is not equal to zero, + a second series solution exists. + + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = x*(f(x).diff(x, 2)) + 2*(f(x).diff(x)) + x*f(x) + >>> pprint(dsolve(eq, hint='2nd_power_series_regular')) + / 6 4 2 \ + | x x x | + / 4 2 \ C1*|- --- + -- - -- + 1| + |x x | \ 720 24 2 / / 6\ + f(x) = C2*|--- - -- + 1| + ------------------------ + O\x / + \120 6 / x + + + References + ========== + - George E. Simmons, "Differential Equations with Applications and + Historical Notes", p.p 176 - 184 + + """ + x = func.args[0] + f = func.func + C0, C1 = get_numbered_constants(eq, num=2) + m = Dummy("m") # for solving the indicial equation + x0 = match['x0'] + terms = match['terms'] + p = match['p'] + q = match['q'] + + # Generating the indicial equation + indicial = [] + for term in [p, q]: + if not term.has(x): + indicial.append(term) + else: + term = series(term, x=x, n=1, x0=x0) + if isinstance(term, Order): + indicial.append(S.Zero) + else: + for arg in term.args: + if not arg.has(x): + indicial.append(arg) + break + + p0, q0 = indicial + sollist = solve(m*(m - 1) + m*p0 + q0, m) + if sollist and isinstance(sollist, list) and all( + sol.is_real for sol in sollist): + serdict1 = {} + serdict2 = {} + if len(sollist) == 1: + # Only one series solution exists in this case. + m1 = m2 = sollist.pop() + if terms-m1-1 <= 0: + return Eq(f(x), Order(terms)) + serdict1 = _frobenius(terms-m1-1, m1, p0, q0, p, q, x0, x, C0) + + else: + m1 = sollist[0] + m2 = sollist[1] + if m1 < m2: + m1, m2 = m2, m1 + # Irrespective of whether m1 - m2 is an integer or not, one + # Frobenius series solution exists. + serdict1 = _frobenius(terms-m1-1, m1, p0, q0, p, q, x0, x, C0) + if not (m1 - m2).is_integer: + # Second frobenius series solution exists. + serdict2 = _frobenius(terms-m2-1, m2, p0, q0, p, q, x0, x, C1) + else: + # Check if second frobenius series solution exists. + serdict2 = _frobenius(terms-m2-1, m2, p0, q0, p, q, x0, x, C1, check=m1) + + if serdict1: + finalseries1 = C0 + for key in serdict1: + power = int(key.name[1:]) + finalseries1 += serdict1[key]*(x - x0)**power + finalseries1 = (x - x0)**m1*finalseries1 + finalseries2 = S.Zero + if serdict2: + for key in serdict2: + power = int(key.name[1:]) + finalseries2 += serdict2[key]*(x - x0)**power + finalseries2 += C1 + finalseries2 = (x - x0)**m2*finalseries2 + return Eq(f(x), collect(finalseries1 + finalseries2, + [C0, C1]) + Order(x**terms)) + + +def _frobenius(n, m, p0, q0, p, q, x0, x, c, check=None): + r""" + Returns a dict with keys as coefficients and values as their values in terms of C0 + """ + n = int(n) + # In cases where m1 - m2 is not an integer + m2 = check + + d = Dummy("d") + numsyms = numbered_symbols("C", start=0) + numsyms = [next(numsyms) for i in range(n + 1)] + serlist = [] + for ser in [p, q]: + # Order term not present + if ser.is_polynomial(x) and Poly(ser, x).degree() <= n: + if x0: + ser = ser.subs(x, x + x0) + dict_ = Poly(ser, x).as_dict() + # Order term present + else: + tseries = series(ser, x=x0, n=n+1) + # Removing order + dict_ = Poly(list(ordered(tseries.args))[: -1], x).as_dict() + # Fill in with zeros, if coefficients are zero. + for i in range(n + 1): + if (i,) not in dict_: + dict_[(i,)] = S.Zero + serlist.append(dict_) + + pseries = serlist[0] + qseries = serlist[1] + indicial = d*(d - 1) + d*p0 + q0 + frobdict = {} + for i in range(1, n + 1): + num = c*(m*pseries[(i,)] + qseries[(i,)]) + for j in range(1, i): + sym = Symbol("C" + str(j)) + num += frobdict[sym]*((m + j)*pseries[(i - j,)] + qseries[(i - j,)]) + + # Checking for cases when m1 - m2 is an integer. If num equals zero + # then a second Frobenius series solution cannot be found. If num is not zero + # then set constant as zero and proceed. + if m2 is not None and i == m2 - m: + if num: + return False + else: + frobdict[numsyms[i]] = S.Zero + else: + frobdict[numsyms[i]] = -num/(indicial.subs(d, m+i)) + + return frobdict + +def _remove_redundant_solutions(eq, solns, order, var): + r""" + Remove redundant solutions from the set of solutions. + + This function is needed because otherwise dsolve can return + redundant solutions. As an example consider: + + eq = Eq((f(x).diff(x, 2))*f(x).diff(x), 0) + + There are two ways to find solutions to eq. The first is to solve f(x).diff(x, 2) = 0 + leading to solution f(x)=C1 + C2*x. The second is to solve the equation f(x).diff(x) = 0 + leading to the solution f(x) = C1. In this particular case we then see + that the second solution is a special case of the first and we do not + want to return it. + + This does not always happen. If we have + + eq = Eq((f(x)**2-4)*(f(x).diff(x)-4), 0) + + then we get the algebraic solution f(x) = [-2, 2] and the integral solution + f(x) = x + C1 and in this case the two solutions are not equivalent wrt + initial conditions so both should be returned. + """ + def is_special_case_of(soln1, soln2): + return _is_special_case_of(soln1, soln2, eq, order, var) + + unique_solns = [] + for soln1 in solns: + for soln2 in unique_solns.copy(): + if is_special_case_of(soln1, soln2): + break + elif is_special_case_of(soln2, soln1): + unique_solns.remove(soln2) + else: + unique_solns.append(soln1) + + return unique_solns + +def _is_special_case_of(soln1, soln2, eq, order, var): + r""" + True if soln1 is found to be a special case of soln2 wrt some value of the + constants that appear in soln2. False otherwise. + """ + # The solutions returned by dsolve may be given explicitly or implicitly. + # We will equate the sol1=(soln1.rhs - soln1.lhs), sol2=(soln2.rhs - soln2.lhs) + # of the two solutions. + # + # Since this is supposed to hold for all x it also holds for derivatives. + # For an order n ode we should be able to differentiate + # each solution n times to get n+1 equations. + # + # We then try to solve those n+1 equations for the integrations constants + # in sol2. If we can find a solution that does not depend on x then it + # means that some value of the constants in sol1 is a special case of + # sol2 corresponding to a particular choice of the integration constants. + + # In case the solution is in implicit form we subtract the sides + soln1 = soln1.rhs - soln1.lhs + soln2 = soln2.rhs - soln2.lhs + + # Work for the series solution + if soln1.has(Order) and soln2.has(Order): + if soln1.getO() == soln2.getO(): + soln1 = soln1.removeO() + soln2 = soln2.removeO() + else: + return False + elif soln1.has(Order) or soln2.has(Order): + return False + + constants1 = soln1.free_symbols.difference(eq.free_symbols) + constants2 = soln2.free_symbols.difference(eq.free_symbols) + + constants1_new = get_numbered_constants(Tuple(soln1, soln2), len(constants1)) + if len(constants1) == 1: + constants1_new = {constants1_new} + for c_old, c_new in zip(constants1, constants1_new): + soln1 = soln1.subs(c_old, c_new) + + # n equations for sol1 = sol2, sol1'=sol2', ... + lhs = soln1 + rhs = soln2 + eqns = [Eq(lhs, rhs)] + for n in range(1, order): + lhs = lhs.diff(var) + rhs = rhs.diff(var) + eq = Eq(lhs, rhs) + eqns.append(eq) + + # BooleanTrue/False awkwardly show up for trivial equations + if any(isinstance(eq, BooleanFalse) for eq in eqns): + return False + eqns = [eq for eq in eqns if not isinstance(eq, BooleanTrue)] + + try: + constant_solns = solve(eqns, constants2) + except NotImplementedError: + return False + + # Sometimes returns a dict and sometimes a list of dicts + if isinstance(constant_solns, dict): + constant_solns = [constant_solns] + + # after solving the issue 17418, maybe we don't need the following checksol code. + for constant_soln in constant_solns: + for eq in eqns: + eq=eq.rhs-eq.lhs + if checksol(eq, constant_soln) is not True: + return False + + # If any solution gives all constants as expressions that don't depend on + # x then there exists constants for soln2 that give soln1 + for constant_soln in constant_solns: + if not any(c.has(var) for c in constant_soln.values()): + return True + + return False + + +def ode_1st_power_series(eq, func, order, match): + r""" + The power series solution is a method which gives the Taylor series expansion + to the solution of a differential equation. + + For a first order differential equation `\frac{dy}{dx} = h(x, y)`, a power + series solution exists at a point `x = x_{0}` if `h(x, y)` is analytic at `x_{0}`. + The solution is given by + + .. math:: y(x) = y(x_{0}) + \sum_{n = 1}^{\infty} \frac{F_{n}(x_{0},b)(x - x_{0})^n}{n!}, + + where `y(x_{0}) = b` is the value of y at the initial value of `x_{0}`. + To compute the values of the `F_{n}(x_{0},b)` the following algorithm is + followed, until the required number of terms are generated. + + 1. `F_1 = h(x_{0}, b)` + 2. `F_{n+1} = \frac{\partial F_{n}}{\partial x} + \frac{\partial F_{n}}{\partial y}F_{1}` + + Examples + ======== + + >>> from sympy import Function, pprint, exp, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = exp(x)*(f(x).diff(x)) - f(x) + >>> pprint(dsolve(eq, hint='1st_power_series')) + 3 4 5 + C1*x C1*x C1*x / 6\ + f(x) = C1 + C1*x - ----- + ----- + ----- + O\x / + 6 24 60 + + + References + ========== + + - Travis W. Walker, Analytic power series technique for solving first-order + differential equations, p.p 17, 18 + + """ + x = func.args[0] + y = match['y'] + f = func.func + h = -match[match['d']]/match[match['e']] + point = match['f0'] + value = match['f0val'] + terms = match['terms'] + + # First term + F = h + if not h: + return Eq(f(x), value) + + # Initialization + series = value + if terms > 1: + hc = h.subs({x: point, y: value}) + if hc.has(oo) or hc.has(nan) or hc.has(zoo): + # Derivative does not exist, not analytic + return Eq(f(x), oo) + elif hc: + series += hc*(x - point) + + for factcount in range(2, terms): + Fnew = F.diff(x) + F.diff(y)*h + Fnewc = Fnew.subs({x: point, y: value}) + # Same logic as above + if Fnewc.has(oo) or Fnewc.has(nan) or Fnewc.has(-oo) or Fnewc.has(zoo): + return Eq(f(x), oo) + series += Fnewc*((x - point)**factcount)/factorial(factcount) + F = Fnew + series += Order(x**terms) + return Eq(f(x), series) + + +def checkinfsol(eq, infinitesimals, func=None, order=None): + r""" + This function is used to check if the given infinitesimals are the + actual infinitesimals of the given first order differential equation. + This method is specific to the Lie Group Solver of ODEs. + + As of now, it simply checks, by substituting the infinitesimals in the + partial differential equation. + + + .. math:: \frac{\partial \eta}{\partial x} + \left(\frac{\partial \eta}{\partial y} + - \frac{\partial \xi}{\partial x}\right)*h + - \frac{\partial \xi}{\partial y}*h^{2} + - \xi\frac{\partial h}{\partial x} - \eta\frac{\partial h}{\partial y} = 0 + + + where `\eta`, and `\xi` are the infinitesimals and `h(x,y) = \frac{dy}{dx}` + + The infinitesimals should be given in the form of a list of dicts + ``[{xi(x, y): inf, eta(x, y): inf}]``, corresponding to the + output of the function infinitesimals. It returns a list + of values of the form ``[(True/False, sol)]`` where ``sol`` is the value + obtained after substituting the infinitesimals in the PDE. If it + is ``True``, then ``sol`` would be 0. + + """ + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + if not func: + eq, func = _preprocess(eq) + variables = func.args + if len(variables) != 1: + raise ValueError("ODE's have only one independent variable") + else: + x = variables[0] + if not order: + order = ode_order(eq, func) + if order != 1: + raise NotImplementedError("Lie groups solver has been implemented " + "only for first order differential equations") + else: + df = func.diff(x) + a = Wild('a', exclude = [df]) + b = Wild('b', exclude = [df]) + match = collect(expand(eq), df).match(a*df + b) + + if match: + h = -simplify(match[b]/match[a]) + else: + try: + sol = solve(eq, df) + except NotImplementedError: + raise NotImplementedError("Infinitesimals for the " + "first order ODE could not be found") + else: + h = sol[0] # Find infinitesimals for one solution + + y = Dummy('y') + h = h.subs(func, y) + xi = Function('xi')(x, y) + eta = Function('eta')(x, y) + dxi = Function('xi')(x, func) + deta = Function('eta')(x, func) + pde = (eta.diff(x) + (eta.diff(y) - xi.diff(x))*h - + (xi.diff(y))*h**2 - xi*(h.diff(x)) - eta*(h.diff(y))) + soltup = [] + for sol in infinitesimals: + tsol = {xi: S(sol[dxi]).subs(func, y), + eta: S(sol[deta]).subs(func, y)} + sol = simplify(pde.subs(tsol).doit()) + if sol: + soltup.append((False, sol.subs(y, func))) + else: + soltup.append((True, 0)) + return soltup + + +def sysode_linear_2eq_order1(match_): + x = match_['func'][0].func + y = match_['func'][1].func + func = match_['func'] + fc = match_['func_coeff'] + eq = match_['eq'] + r = {} + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + for i in range(2): + eq[i] = Add(*[terms/fc[i,func[i],1] for terms in Add.make_args(eq[i])]) + + # for equations Eq(a1*diff(x(t),t), a*x(t) + b*y(t) + k1) + # and Eq(a2*diff(x(t),t), c*x(t) + d*y(t) + k2) + r['a'] = -fc[0,x(t),0]/fc[0,x(t),1] + r['c'] = -fc[1,x(t),0]/fc[1,y(t),1] + r['b'] = -fc[0,y(t),0]/fc[0,x(t),1] + r['d'] = -fc[1,y(t),0]/fc[1,y(t),1] + forcing = [S.Zero,S.Zero] + for i in range(2): + for j in Add.make_args(eq[i]): + if not j.has(x(t), y(t)): + forcing[i] += j + if not (forcing[0].has(t) or forcing[1].has(t)): + r['k1'] = forcing[0] + r['k2'] = forcing[1] + else: + raise NotImplementedError("Only homogeneous problems are supported" + + " (and constant inhomogeneity)") + + if match_['type_of_equation'] == 'type6': + sol = _linear_2eq_order1_type6(x, y, t, r, eq) + if match_['type_of_equation'] == 'type7': + sol = _linear_2eq_order1_type7(x, y, t, r, eq) + return sol + +def _linear_2eq_order1_type6(x, y, t, r, eq): + r""" + The equations of this type of ode are . + + .. math:: x' = f(t) x + g(t) y + + .. math:: y' = a [f(t) + a h(t)] x + a [g(t) - h(t)] y + + This is solved by first multiplying the first equation by `-a` and adding + it to the second equation to obtain + + .. math:: y' - a x' = -a h(t) (y - a x) + + Setting `U = y - ax` and integrating the equation we arrive at + + .. math:: y - ax = C_1 e^{-a \int h(t) \,dt} + + and on substituting the value of y in first equation give rise to first order ODEs. After solving for + `x`, we can obtain `y` by substituting the value of `x` in second equation. + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + p = 0 + q = 0 + p1 = cancel(r['c']/cancel(r['c']/r['d']).as_numer_denom()[0]) + p2 = cancel(r['a']/cancel(r['a']/r['b']).as_numer_denom()[0]) + for n, i in enumerate([p1, p2]): + for j in Mul.make_args(collect_const(i)): + if not j.has(t): + q = j + if q!=0 and n==0: + if ((r['c']/j - r['a'])/(r['b'] - r['d']/j)) == j: + p = 1 + s = j + break + if q!=0 and n==1: + if ((r['a']/j - r['c'])/(r['d'] - r['b']/j)) == j: + p = 2 + s = j + break + + if p == 1: + equ = diff(x(t),t) - r['a']*x(t) - r['b']*(s*x(t) + C1*exp(-s*Integral(r['b'] - r['d']/s, t))) + hint1 = classify_ode(equ)[1] + sol1 = dsolve(equ, hint=hint1+'_Integral').rhs + sol2 = s*sol1 + C1*exp(-s*Integral(r['b'] - r['d']/s, t)) + elif p ==2: + equ = diff(y(t),t) - r['c']*y(t) - r['d']*s*y(t) + C1*exp(-s*Integral(r['d'] - r['b']/s, t)) + hint1 = classify_ode(equ)[1] + sol2 = dsolve(equ, hint=hint1+'_Integral').rhs + sol1 = s*sol2 + C1*exp(-s*Integral(r['d'] - r['b']/s, t)) + return [Eq(x(t), sol1), Eq(y(t), sol2)] + +def _linear_2eq_order1_type7(x, y, t, r, eq): + r""" + The equations of this type of ode are . + + .. math:: x' = f(t) x + g(t) y + + .. math:: y' = h(t) x + p(t) y + + Differentiating the first equation and substituting the value of `y` + from second equation will give a second-order linear equation + + .. math:: g x'' - (fg + gp + g') x' + (fgp - g^{2} h + f g' - f' g) x = 0 + + This above equation can be easily integrated if following conditions are satisfied. + + 1. `fgp - g^{2} h + f g' - f' g = 0` + + 2. `fgp - g^{2} h + f g' - f' g = ag, fg + gp + g' = bg` + + If first condition is satisfied then it is solved by current dsolve solver and in second case it becomes + a constant coefficient differential equation which is also solved by current solver. + + Otherwise if the above condition fails then, + a particular solution is assumed as `x = x_0(t)` and `y = y_0(t)` + Then the general solution is expressed as + + .. math:: x = C_1 x_0(t) + C_2 x_0(t) \int \frac{g(t) F(t) P(t)}{x_0^{2}(t)} \,dt + + .. math:: y = C_1 y_0(t) + C_2 [\frac{F(t) P(t)}{x_0(t)} + y_0(t) \int \frac{g(t) F(t) P(t)}{x_0^{2}(t)} \,dt] + + where C1 and C2 are arbitrary constants and + + .. math:: F(t) = e^{\int f(t) \,dt}, P(t) = e^{\int p(t) \,dt} + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + e1 = r['a']*r['b']*r['c'] - r['b']**2*r['c'] + r['a']*diff(r['b'],t) - diff(r['a'],t)*r['b'] + e2 = r['a']*r['c']*r['d'] - r['b']*r['c']**2 + diff(r['c'],t)*r['d'] - r['c']*diff(r['d'],t) + m1 = r['a']*r['b'] + r['b']*r['d'] + diff(r['b'],t) + m2 = r['a']*r['c'] + r['c']*r['d'] + diff(r['c'],t) + if e1 == 0: + sol1 = dsolve(r['b']*diff(x(t),t,t) - m1*diff(x(t),t)).rhs + sol2 = dsolve(diff(y(t),t) - r['c']*sol1 - r['d']*y(t)).rhs + elif e2 == 0: + sol2 = dsolve(r['c']*diff(y(t),t,t) - m2*diff(y(t),t)).rhs + sol1 = dsolve(diff(x(t),t) - r['a']*x(t) - r['b']*sol2).rhs + elif not (e1/r['b']).has(t) and not (m1/r['b']).has(t): + sol1 = dsolve(diff(x(t),t,t) - (m1/r['b'])*diff(x(t),t) - (e1/r['b'])*x(t)).rhs + sol2 = dsolve(diff(y(t),t) - r['c']*sol1 - r['d']*y(t)).rhs + elif not (e2/r['c']).has(t) and not (m2/r['c']).has(t): + sol2 = dsolve(diff(y(t),t,t) - (m2/r['c'])*diff(y(t),t) - (e2/r['c'])*y(t)).rhs + sol1 = dsolve(diff(x(t),t) - r['a']*x(t) - r['b']*sol2).rhs + else: + x0 = Function('x0')(t) # x0 and y0 being particular solutions + y0 = Function('y0')(t) + F = exp(Integral(r['a'],t)) + P = exp(Integral(r['d'],t)) + sol1 = C1*x0 + C2*x0*Integral(r['b']*F*P/x0**2, t) + sol2 = C1*y0 + C2*(F*P/x0 + y0*Integral(r['b']*F*P/x0**2, t)) + return [Eq(x(t), sol1), Eq(y(t), sol2)] + + +def sysode_nonlinear_2eq_order1(match_): + func = match_['func'] + eq = match_['eq'] + fc = match_['func_coeff'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + if match_['type_of_equation'] == 'type5': + sol = _nonlinear_2eq_order1_type5(func, t, eq) + return sol + x = func[0].func + y = func[1].func + for i in range(2): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + if match_['type_of_equation'] == 'type1': + sol = _nonlinear_2eq_order1_type1(x, y, t, eq) + elif match_['type_of_equation'] == 'type2': + sol = _nonlinear_2eq_order1_type2(x, y, t, eq) + elif match_['type_of_equation'] == 'type3': + sol = _nonlinear_2eq_order1_type3(x, y, t, eq) + elif match_['type_of_equation'] == 'type4': + sol = _nonlinear_2eq_order1_type4(x, y, t, eq) + return sol + + +def _nonlinear_2eq_order1_type1(x, y, t, eq): + r""" + Equations: + + .. math:: x' = x^n F(x,y) + + .. math:: y' = g(y) F(x,y) + + Solution: + + .. math:: x = \varphi(y), \int \frac{1}{g(y) F(\varphi(y),y)} \,dy = t + C_2 + + where + + if `n \neq 1` + + .. math:: \varphi = [C_1 + (1-n) \int \frac{1}{g(y)} \,dy]^{\frac{1}{1-n}} + + if `n = 1` + + .. math:: \varphi = C_1 e^{\int \frac{1}{g(y)} \,dy} + + where `C_1` and `C_2` are arbitrary constants. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + n = Wild('n', exclude=[x(t),y(t)]) + f = Wild('f') + u, v = symbols('u, v') + r = eq[0].match(diff(x(t),t) - x(t)**n*f) + g = ((diff(y(t),t) - eq[1])/r[f]).subs(y(t),v) + F = r[f].subs(x(t),u).subs(y(t),v) + n = r[n] + if n!=1: + phi = (C1 + (1-n)*Integral(1/g, v))**(1/(1-n)) + else: + phi = C1*exp(Integral(1/g, v)) + phi = phi.doit() + sol2 = solve(Integral(1/(g*F.subs(u,phi)), v).doit() - t - C2, v) + sol = [] + for sols in sol2: + sol.append(Eq(x(t),phi.subs(v, sols))) + sol.append(Eq(y(t), sols)) + return sol + +def _nonlinear_2eq_order1_type2(x, y, t, eq): + r""" + Equations: + + .. math:: x' = e^{\lambda x} F(x,y) + + .. math:: y' = g(y) F(x,y) + + Solution: + + .. math:: x = \varphi(y), \int \frac{1}{g(y) F(\varphi(y),y)} \,dy = t + C_2 + + where + + if `\lambda \neq 0` + + .. math:: \varphi = -\frac{1}{\lambda} log(C_1 - \lambda \int \frac{1}{g(y)} \,dy) + + if `\lambda = 0` + + .. math:: \varphi = C_1 + \int \frac{1}{g(y)} \,dy + + where `C_1` and `C_2` are arbitrary constants. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + n = Wild('n', exclude=[x(t),y(t)]) + f = Wild('f') + u, v = symbols('u, v') + r = eq[0].match(diff(x(t),t) - exp(n*x(t))*f) + g = ((diff(y(t),t) - eq[1])/r[f]).subs(y(t),v) + F = r[f].subs(x(t),u).subs(y(t),v) + n = r[n] + if n: + phi = -1/n*log(C1 - n*Integral(1/g, v)) + else: + phi = C1 + Integral(1/g, v) + phi = phi.doit() + sol2 = solve(Integral(1/(g*F.subs(u,phi)), v).doit() - t - C2, v) + sol = [] + for sols in sol2: + sol.append(Eq(x(t),phi.subs(v, sols))) + sol.append(Eq(y(t), sols)) + return sol + +def _nonlinear_2eq_order1_type3(x, y, t, eq): + r""" + Autonomous system of general form + + .. math:: x' = F(x,y) + + .. math:: y' = G(x,y) + + Assuming `y = y(x, C_1)` where `C_1` is an arbitrary constant is the general + solution of the first-order equation + + .. math:: F(x,y) y'_x = G(x,y) + + Then the general solution of the original system of equations has the form + + .. math:: \int \frac{1}{F(x,y(x,C_1))} \,dx = t + C_1 + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + v = Function('v') + u = Symbol('u') + f = Wild('f') + g = Wild('g') + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + F = r1[f].subs(x(t), u).subs(y(t), v(u)) + G = r2[g].subs(x(t), u).subs(y(t), v(u)) + sol2r = dsolve(Eq(diff(v(u), u), G/F)) + if isinstance(sol2r, Equality): + sol2r = [sol2r] + for sol2s in sol2r: + sol1 = solve(Integral(1/F.subs(v(u), sol2s.rhs), u).doit() - t - C2, u) + sol = [] + for sols in sol1: + sol.append(Eq(x(t), sols)) + sol.append(Eq(y(t), (sol2s.rhs).subs(u, sols))) + return sol + +def _nonlinear_2eq_order1_type4(x, y, t, eq): + r""" + Equation: + + .. math:: x' = f_1(x) g_1(y) \phi(x,y,t) + + .. math:: y' = f_2(x) g_2(y) \phi(x,y,t) + + First integral: + + .. math:: \int \frac{f_2(x)}{f_1(x)} \,dx - \int \frac{g_1(y)}{g_2(y)} \,dy = C + + where `C` is an arbitrary constant. + + On solving the first integral for `x` (resp., `y` ) and on substituting the + resulting expression into either equation of the original solution, one + arrives at a first-order equation for determining `y` (resp., `x` ). + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v = symbols('u, v') + U, V = symbols('U, V', cls=Function) + f = Wild('f') + g = Wild('g') + f1 = Wild('f1', exclude=[v,t]) + f2 = Wild('f2', exclude=[v,t]) + g1 = Wild('g1', exclude=[u,t]) + g2 = Wild('g2', exclude=[u,t]) + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + num, den = ( + (r1[f].subs(x(t),u).subs(y(t),v))/ + (r2[g].subs(x(t),u).subs(y(t),v))).as_numer_denom() + R1 = num.match(f1*g1) + R2 = den.match(f2*g2) + phi = (r1[f].subs(x(t),u).subs(y(t),v))/num + F1 = R1[f1]; F2 = R2[f2] + G1 = R1[g1]; G2 = R2[g2] + sol1r = solve(Integral(F2/F1, u).doit() - Integral(G1/G2,v).doit() - C1, u) + sol2r = solve(Integral(F2/F1, u).doit() - Integral(G1/G2,v).doit() - C1, v) + sol = [] + for sols in sol1r: + sol.append(Eq(y(t), dsolve(diff(V(t),t) - F2.subs(u,sols).subs(v,V(t))*G2.subs(v,V(t))*phi.subs(u,sols).subs(v,V(t))).rhs)) + for sols in sol2r: + sol.append(Eq(x(t), dsolve(diff(U(t),t) - F1.subs(u,U(t))*G1.subs(v,sols).subs(u,U(t))*phi.subs(v,sols).subs(u,U(t))).rhs)) + return set(sol) + +def _nonlinear_2eq_order1_type5(func, t, eq): + r""" + Clairaut system of ODEs + + .. math:: x = t x' + F(x',y') + + .. math:: y = t y' + G(x',y') + + The following are solutions of the system + + `(i)` straight lines: + + .. math:: x = C_1 t + F(C_1, C_2), y = C_2 t + G(C_1, C_2) + + where `C_1` and `C_2` are arbitrary constants; + + `(ii)` envelopes of the above lines; + + `(iii)` continuously differentiable lines made up from segments of the lines + `(i)` and `(ii)`. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + f = Wild('f') + g = Wild('g') + def check_type(x, y): + r1 = eq[0].match(t*diff(x(t),t) - x(t) + f) + r2 = eq[1].match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = eq[0].match(diff(x(t),t) - x(t)/t + f/t) + r2 = eq[1].match(diff(y(t),t) - y(t)/t + g/t) + if not (r1 and r2): + r1 = (-eq[0]).match(t*diff(x(t),t) - x(t) + f) + r2 = (-eq[1]).match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = (-eq[0]).match(diff(x(t),t) - x(t)/t + f/t) + r2 = (-eq[1]).match(diff(y(t),t) - y(t)/t + g/t) + return [r1, r2] + for func_ in func: + if isinstance(func_, list): + x = func[0][0].func + y = func[0][1].func + [r1, r2] = check_type(x, y) + if not (r1 and r2): + [r1, r2] = check_type(y, x) + x, y = y, x + x1 = diff(x(t),t); y1 = diff(y(t),t) + return {Eq(x(t), C1*t + r1[f].subs(x1,C1).subs(y1,C2)), Eq(y(t), C2*t + r2[g].subs(x1,C1).subs(y1,C2))} + +def sysode_nonlinear_3eq_order1(match_): + x = match_['func'][0].func + y = match_['func'][1].func + z = match_['func'][2].func + eq = match_['eq'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + if match_['type_of_equation'] == 'type1': + sol = _nonlinear_3eq_order1_type1(x, y, z, t, eq) + if match_['type_of_equation'] == 'type2': + sol = _nonlinear_3eq_order1_type2(x, y, z, t, eq) + if match_['type_of_equation'] == 'type3': + sol = _nonlinear_3eq_order1_type3(x, y, z, t, eq) + if match_['type_of_equation'] == 'type4': + sol = _nonlinear_3eq_order1_type4(x, y, z, t, eq) + if match_['type_of_equation'] == 'type5': + sol = _nonlinear_3eq_order1_type5(x, y, z, t, eq) + return sol + +def _nonlinear_3eq_order1_type1(x, y, z, t, eq): + r""" + Equations: + + .. math:: a x' = (b - c) y z, \enspace b y' = (c - a) z x, \enspace c z' = (a - b) x y + + First Integrals: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + .. math:: a^{2} x^{2} + b^{2} y^{2} + c^{2} z^{2} = C_2 + + where `C_1` and `C_2` are arbitrary constants. On solving the integrals for `y` and + `z` and on substituting the resulting expressions into the first equation of the + system, we arrives at a separable first-order equation on `x`. Similarly doing that + for other two equations, we will arrive at first order equation on `y` and `z` too. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0401.pdf + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + r = (diff(x(t),t) - eq[0]).match(p*y(t)*z(t)) + r.update((diff(y(t),t) - eq[1]).match(q*z(t)*x(t))) + r.update((diff(z(t),t) - eq[2]).match(s*x(t)*y(t))) + n1, d1 = r[p].as_numer_denom() + n2, d2 = r[q].as_numer_denom() + n3, d3 = r[s].as_numer_denom() + val = solve([n1*u-d1*v+d1*w, d2*u+n2*v-d2*w, d3*u-d3*v-n3*w],[u,v]) + vals = [val[v], val[u]] + c = lcm(vals[0].as_numer_denom()[1], vals[1].as_numer_denom()[1]) + b = vals[0].subs(w, c) + a = vals[1].subs(w, c) + y_x = sqrt(((c*C1-C2) - a*(c-a)*x(t)**2)/(b*(c-b))) + z_x = sqrt(((b*C1-C2) - a*(b-a)*x(t)**2)/(c*(b-c))) + z_y = sqrt(((a*C1-C2) - b*(a-b)*y(t)**2)/(c*(a-c))) + x_y = sqrt(((c*C1-C2) - b*(c-b)*y(t)**2)/(a*(c-a))) + x_z = sqrt(((b*C1-C2) - c*(b-c)*z(t)**2)/(a*(b-a))) + y_z = sqrt(((a*C1-C2) - c*(a-c)*z(t)**2)/(b*(a-b))) + sol1 = dsolve(a*diff(x(t),t) - (b-c)*y_x*z_x) + sol2 = dsolve(b*diff(y(t),t) - (c-a)*z_y*x_y) + sol3 = dsolve(c*diff(z(t),t) - (a-b)*x_z*y_z) + return [sol1, sol2, sol3] + + +def _nonlinear_3eq_order1_type2(x, y, z, t, eq): + r""" + Equations: + + .. math:: a x' = (b - c) y z f(x, y, z, t) + + .. math:: b y' = (c - a) z x f(x, y, z, t) + + .. math:: c z' = (a - b) x y f(x, y, z, t) + + First Integrals: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + .. math:: a^{2} x^{2} + b^{2} y^{2} + c^{2} z^{2} = C_2 + + where `C_1` and `C_2` are arbitrary constants. On solving the integrals for `y` and + `z` and on substituting the resulting expressions into the first equation of the + system, we arrives at a first-order differential equations on `x`. Similarly doing + that for other two equations we will arrive at first order equation on `y` and `z`. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0402.pdf + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + f = Wild('f') + r1 = (diff(x(t),t) - eq[0]).match(y(t)*z(t)*f) + r = collect_const(r1[f]).match(p*f) + r.update(((diff(y(t),t) - eq[1])/r[f]).match(q*z(t)*x(t))) + r.update(((diff(z(t),t) - eq[2])/r[f]).match(s*x(t)*y(t))) + n1, d1 = r[p].as_numer_denom() + n2, d2 = r[q].as_numer_denom() + n3, d3 = r[s].as_numer_denom() + val = solve([n1*u-d1*v+d1*w, d2*u+n2*v-d2*w, -d3*u+d3*v+n3*w],[u,v]) + vals = [val[v], val[u]] + c = lcm(vals[0].as_numer_denom()[1], vals[1].as_numer_denom()[1]) + a = vals[0].subs(w, c) + b = vals[1].subs(w, c) + y_x = sqrt(((c*C1-C2) - a*(c-a)*x(t)**2)/(b*(c-b))) + z_x = sqrt(((b*C1-C2) - a*(b-a)*x(t)**2)/(c*(b-c))) + z_y = sqrt(((a*C1-C2) - b*(a-b)*y(t)**2)/(c*(a-c))) + x_y = sqrt(((c*C1-C2) - b*(c-b)*y(t)**2)/(a*(c-a))) + x_z = sqrt(((b*C1-C2) - c*(b-c)*z(t)**2)/(a*(b-a))) + y_z = sqrt(((a*C1-C2) - c*(a-c)*z(t)**2)/(b*(a-b))) + sol1 = dsolve(a*diff(x(t),t) - (b-c)*y_x*z_x*r[f]) + sol2 = dsolve(b*diff(y(t),t) - (c-a)*z_y*x_y*r[f]) + sol3 = dsolve(c*diff(z(t),t) - (a-b)*x_z*y_z*r[f]) + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type3(x, y, z, t, eq): + r""" + Equations: + + .. math:: x' = c F_2 - b F_3, \enspace y' = a F_3 - c F_1, \enspace z' = b F_1 - a F_2 + + where `F_n = F_n(x, y, z, t)`. + + 1. First Integral: + + .. math:: a x + b y + c z = C_1, + + where C is an arbitrary constant. + + 2. If we assume function `F_n` to be independent of `t`,i.e, `F_n` = `F_n (x, y, z)` + Then, on eliminating `t` and `z` from the first two equation of the system, one + arrives at the first-order equation + + .. math:: \frac{dy}{dx} = \frac{a F_3 (x, y, z) - c F_1 (x, y, z)}{c F_2 (x, y, z) - + b F_3 (x, y, z)} + + where `z = \frac{1}{c} (C_1 - a x - b y)` + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0404.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + fu, fv, fw = symbols('u, v, w', cls=Function) + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = (diff(x(t), t) - eq[0]).match(F2-F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t), t) - eq[1]).match(p*r[F3] - r[s]*F1)) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t), u).subs(y(t),v).subs(z(t), w) + F2 = r[F2].subs(x(t), u).subs(y(t),v).subs(z(t), w) + F3 = r[F3].subs(x(t), u).subs(y(t),v).subs(z(t), w) + z_xy = (C1-a*u-b*v)/c + y_zx = (C1-a*u-c*w)/b + x_yz = (C1-b*v-c*w)/a + y_x = dsolve(diff(fv(u),u) - ((a*F3-c*F1)/(c*F2-b*F3)).subs(w,z_xy).subs(v,fv(u))).rhs + z_x = dsolve(diff(fw(u),u) - ((b*F1-a*F2)/(c*F2-b*F3)).subs(v,y_zx).subs(w,fw(u))).rhs + z_y = dsolve(diff(fw(v),v) - ((b*F1-a*F2)/(a*F3-c*F1)).subs(u,x_yz).subs(w,fw(v))).rhs + x_y = dsolve(diff(fu(v),v) - ((c*F2-b*F3)/(a*F3-c*F1)).subs(w,z_xy).subs(u,fu(v))).rhs + y_z = dsolve(diff(fv(w),w) - ((a*F3-c*F1)/(b*F1-a*F2)).subs(u,x_yz).subs(v,fv(w))).rhs + x_z = dsolve(diff(fu(w),w) - ((c*F2-b*F3)/(b*F1-a*F2)).subs(v,y_zx).subs(u,fu(w))).rhs + sol1 = dsolve(diff(fu(t),t) - (c*F2 - b*F3).subs(v,y_x).subs(w,z_x).subs(u,fu(t))).rhs + sol2 = dsolve(diff(fv(t),t) - (a*F3 - c*F1).subs(u,x_y).subs(w,z_y).subs(v,fv(t))).rhs + sol3 = dsolve(diff(fw(t),t) - (b*F1 - a*F2).subs(u,x_z).subs(v,y_z).subs(w,fw(t))).rhs + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type4(x, y, z, t, eq): + r""" + Equations: + + .. math:: x' = c z F_2 - b y F_3, \enspace y' = a x F_3 - c z F_1, \enspace z' = b y F_1 - a x F_2 + + where `F_n = F_n (x, y, z, t)` + + 1. First integral: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + where `C` is an arbitrary constant. + + 2. Assuming the function `F_n` is independent of `t`: `F_n = F_n (x, y, z)`. Then on + eliminating `t` and `z` from the first two equations of the system, one arrives at + the first-order equation + + .. math:: \frac{dy}{dx} = \frac{a x F_3 (x, y, z) - c z F_1 (x, y, z)} + {c z F_2 (x, y, z) - b y F_3 (x, y, z)} + + where `z = \pm \sqrt{\frac{1}{c} (C_1 - a x^{2} - b y^{2})}` + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0405.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = eq[0].match(diff(x(t),t) - z(t)*F2 + y(t)*F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t),t) - eq[1]).match(p*x(t)*r[F3] - r[s]*z(t)*F1)) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t),u).subs(y(t),v).subs(z(t),w) + F2 = r[F2].subs(x(t),u).subs(y(t),v).subs(z(t),w) + F3 = r[F3].subs(x(t),u).subs(y(t),v).subs(z(t),w) + x_yz = sqrt((C1 - b*v**2 - c*w**2)/a) + y_zx = sqrt((C1 - c*w**2 - a*u**2)/b) + z_xy = sqrt((C1 - a*u**2 - b*v**2)/c) + y_x = dsolve(diff(v(u),u) - ((a*u*F3-c*w*F1)/(c*w*F2-b*v*F3)).subs(w,z_xy).subs(v,v(u))).rhs + z_x = dsolve(diff(w(u),u) - ((b*v*F1-a*u*F2)/(c*w*F2-b*v*F3)).subs(v,y_zx).subs(w,w(u))).rhs + z_y = dsolve(diff(w(v),v) - ((b*v*F1-a*u*F2)/(a*u*F3-c*w*F1)).subs(u,x_yz).subs(w,w(v))).rhs + x_y = dsolve(diff(u(v),v) - ((c*w*F2-b*v*F3)/(a*u*F3-c*w*F1)).subs(w,z_xy).subs(u,u(v))).rhs + y_z = dsolve(diff(v(w),w) - ((a*u*F3-c*w*F1)/(b*v*F1-a*u*F2)).subs(u,x_yz).subs(v,v(w))).rhs + x_z = dsolve(diff(u(w),w) - ((c*w*F2-b*v*F3)/(b*v*F1-a*u*F2)).subs(v,y_zx).subs(u,u(w))).rhs + sol1 = dsolve(diff(u(t),t) - (c*w*F2 - b*v*F3).subs(v,y_x).subs(w,z_x).subs(u,u(t))).rhs + sol2 = dsolve(diff(v(t),t) - (a*u*F3 - c*w*F1).subs(u,x_y).subs(w,z_y).subs(v,v(t))).rhs + sol3 = dsolve(diff(w(t),t) - (b*v*F1 - a*u*F2).subs(u,x_z).subs(v,y_z).subs(w,w(t))).rhs + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type5(x, y, z, t, eq): + r""" + .. math:: x' = x (c F_2 - b F_3), \enspace y' = y (a F_3 - c F_1), \enspace z' = z (b F_1 - a F_2) + + where `F_n = F_n (x, y, z, t)` and are arbitrary functions. + + First Integral: + + .. math:: \left|x\right|^{a} \left|y\right|^{b} \left|z\right|^{c} = C_1 + + where `C` is an arbitrary constant. If the function `F_n` is independent of `t`, + then, by eliminating `t` and `z` from the first two equations of the system, one + arrives at a first-order equation. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0406.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + fu, fv, fw = symbols('u, v, w', cls=Function) + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = eq[0].match(diff(x(t), t) - x(t)*F2 + x(t)*F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t), t) - eq[1]).match(y(t)*(p*r[F3] - r[s]*F1))) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t), u).subs(y(t), v).subs(z(t), w) + F2 = r[F2].subs(x(t), u).subs(y(t), v).subs(z(t), w) + F3 = r[F3].subs(x(t), u).subs(y(t), v).subs(z(t), w) + x_yz = (C1*v**-b*w**-c)**-a + y_zx = (C1*w**-c*u**-a)**-b + z_xy = (C1*u**-a*v**-b)**-c + y_x = dsolve(diff(fv(u), u) - ((v*(a*F3 - c*F1))/(u*(c*F2 - b*F3))).subs(w, z_xy).subs(v, fv(u))).rhs + z_x = dsolve(diff(fw(u), u) - ((w*(b*F1 - a*F2))/(u*(c*F2 - b*F3))).subs(v, y_zx).subs(w, fw(u))).rhs + z_y = dsolve(diff(fw(v), v) - ((w*(b*F1 - a*F2))/(v*(a*F3 - c*F1))).subs(u, x_yz).subs(w, fw(v))).rhs + x_y = dsolve(diff(fu(v), v) - ((u*(c*F2 - b*F3))/(v*(a*F3 - c*F1))).subs(w, z_xy).subs(u, fu(v))).rhs + y_z = dsolve(diff(fv(w), w) - ((v*(a*F3 - c*F1))/(w*(b*F1 - a*F2))).subs(u, x_yz).subs(v, fv(w))).rhs + x_z = dsolve(diff(fu(w), w) - ((u*(c*F2 - b*F3))/(w*(b*F1 - a*F2))).subs(v, y_zx).subs(u, fu(w))).rhs + sol1 = dsolve(diff(fu(t), t) - (u*(c*F2 - b*F3)).subs(v, y_x).subs(w, z_x).subs(u, fu(t))).rhs + sol2 = dsolve(diff(fv(t), t) - (v*(a*F3 - c*F1)).subs(u, x_y).subs(w, z_y).subs(v, fv(t))).rhs + sol3 = dsolve(diff(fw(t), t) - (w*(b*F1 - a*F2)).subs(u, x_z).subs(v, y_z).subs(w, fw(t))).rhs + return [sol1, sol2, sol3] + + +#This import is written at the bottom to avoid circular imports. +from .single import SingleODEProblem, SingleODESolver, solver_map diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/riccati.py b/phivenv/Lib/site-packages/sympy/solvers/ode/riccati.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef66ed0896d39bee8fba1b74a0c93734742fc1f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/riccati.py @@ -0,0 +1,893 @@ +r""" +This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`, +a function which gives all rational particular solutions to first order +Riccati ODEs. A general first order Riccati ODE is given by - + +.. math:: y' = b_0(x) + b_1(x)w + b_2(x)w^2 + +where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x` +with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE +anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation +is a Bernoulli ODE. The algorithm presented below can find rational +solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution, +or prove that no rational solution exists for the equation. + +Background +========== + +A Riccati equation can be transformed to its normal form + +.. math:: y' + y^2 = a(x) + +using the transformation + +.. math:: y = -b_2(x) - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2} + +where `a(x)` is given by + +.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2 + +Thus, we can develop an algorithm to solve for the Riccati equation +in its normal form, which would in turn give us the solution for +the original Riccati equation. + +Algorithm +========= + +The algorithm implemented here is presented in the Ph.D thesis +"Rational and Algebraic Solutions of First-Order Algebraic ODEs" +by N. Thieu Vo. The entire thesis can be found here - +https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf + +We have only implemented the Rational Riccati solver (Algorithm 11, +Pg 78-82 in Thesis). Before we proceed towards the implementation +of the algorithm, a few definitions to understand are - + +1. Valuation of a Rational Function at `\infty`: + The valuation of a rational function `p(x)` at `\infty` is equal + to the difference between the degree of the denominator and the + numerator of `p(x)`. + + NOTE: A general definition of valuation of a rational function + at any value of `x` can be found in Pg 63 of the thesis, but + is not of any interest for this algorithm. + +2. Zeros and Poles of a Rational Function: + Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function + of `x`. Then - + + a. The Zeros of `a(x)` are the roots of `S(x)`. + b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty` + can also be a pole of a(x). We say that `a(x)` has a pole at + `\infty` if `a(\frac{1}{x})` has a pole at 0. + +Every pole is associated with an order that is equal to the multiplicity +of its appearance as a root of `T(x)`. A pole is called a simple pole if +it has an order 1. Similarly, a pole is called a multiple pole if it has +an order `\ge` 2. + +Necessary Conditions +==================== + +For a Riccati equation in its normal form, + +.. math:: y' + y^2 = a(x) + +we can define + +a. A pole is called a movable pole if it is a pole of `y(x)` and is not +a pole of `a(x)`. +b. Similarly, a pole is called a non-movable pole if it is a pole of both +`y(x)` and `a(x)`. + +Then, the algorithm states that a rational solution exists only if - + +a. Every pole of `a(x)` must be either a simple pole or a multiple pole +of even order. +b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2. + +This algorithm finds all possible rational solutions for the Riccati ODE. +If no rational solutions are found, it means that no rational solutions +exist. + +The algorithm works for Riccati ODEs where the coefficients are rational +functions in the independent variable `x` with rational number coefficients +i.e. in `Q(x)`. The coefficients in the rational function cannot be floats, +irrational numbers, symbols or any other kind of expression. The reasons +for this are - + +1. When using symbols, different symbols could take the same value and this +would affect the multiplicity of poles if symbols are present here. + +2. An integer degree bound is required to calculate a polynomial solution +to an auxiliary differential equation, which in turn gives the particular +solution for the original ODE. If symbols/floats/irrational numbers are +present, we cannot determine if the expression for the degree bound is an +integer or not. + +Solution +======== + +With these definitions, we can state a general form for the solution of +the equation. `y(x)` must have the form - + +.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i + +where `x_1, x_2, \dots, x_n` are non-movable poles of `a(x)`, +`\chi_1, \chi_2, \dots, \chi_m` are movable poles of `a(x)`, and the values +of `N, n, r_1, r_2, \dots, r_n` can be determined from `a(x)`. The +coefficient vectors `(d_0, d_1, \dots, d_N)` and `(c_{i1}, c_{i2}, \dots, c_{i r_i})` +can be determined from `a(x)`. We will have 2 choices each of these vectors +and part of the procedure is figuring out which of the 2 should be used +to get the solution correctly. + +Implementation +============== + +In this implementation, we use ``Poly`` to represent a rational function +rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot +represent rational functions directly using ``Poly``, we instead represent +a rational function with 2 ``Poly`` objects - one for its numerator and +the other for its denominator. + +The code is written to match the steps given in the thesis (Pg 82) + +Step 0 : Match the equation - +Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise +an error + +Step 1 : Transform the equation to its normal form as explained in the +theory section. + +Step 2 : Initialize an empty set of solutions, ``sol``. + +Step 3 : If `a(x) = 0`, append `\frac{1}/{(x - C1)}` to ``sol``. + +Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}` +to ``sol``. + +Step 5 : Find the poles and their multiplicities of `a(x)`. Let +the number of poles be `n`. Also find the valuation of `a(x)` at +`\infty` using ``val_at_inf``. + +NOTE: Although the algorithm considers `\infty` as a pole, it is +not mentioned if it a part of the set of finite poles. `\infty` +is NOT a part of the set of finite poles. If a pole exists at +`\infty`, we use its multiplicity to find the laurent series of +`a(x)` about `\infty`. + +Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using +``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)`` +combinations of choosing between 2 choices for each of the `n` c-vectors +and 1 d-vector. + +NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printinig +mistake. The term `- d_N` must be replaced with `-N d_N`. The same +has been explained in the code as well. + +For each of these above combinations, do + +Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of +the polynomial solution we must find for the auxiliary equation. + +Step 9 : In ``compute_m_ybar``, compute ybar as well where ``ybar`` is +one part of y(x) - + +.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i + +Step 10 : If `m` is a non-negative integer - + +Step 11: Find a polynomial solution of degree `m` for the auxiliary equation. + +There are 2 cases possible - + + a. `m` is a non-negative integer: We can solve for the coefficients + in `p(x)` using Undetermined Coefficients. + + b. `m` is not a non-negative integer: In this case, we cannot find + a polynomial solution to the auxiliary equation, and hence, we ignore + this value of `m`. + +Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}` +to ``sol``. + +Step 13 : For each solution in ``sol``, apply an inverse transformation, +so that the solutions of the original equation are found using the +solutions of the equation in its normal form. +""" + + +from itertools import product +from sympy.core import S +from sympy.core.add import Add +from sympy.core.numbers import oo, Float +from sympy.core.function import count_ops +from sympy.core.relational import Eq +from sympy.core.symbol import symbols, Symbol, Dummy +from sympy.functions import sqrt, exp +from sympy.functions.elementary.complexes import sign +from sympy.integrals.integrals import Integral +from sympy.polys.domains import ZZ +from sympy.polys.polytools import Poly +from sympy.polys.polyroots import roots +from sympy.solvers.solveset import linsolve + + +def riccati_normal(w, x, b1, b2): + """ + Given a solution `w(x)` to the equation + + .. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2 + + and rational function coefficients `b_1(x)` and + `b_2(x)`, this function transforms the solution to + give a solution `y(x)` for its corresponding normal + Riccati ODE + + .. math:: y'(x) + y(x)^2 = a(x) + + using the transformation + + .. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2 + """ + return -b2*w - b2.diff(x)/(2*b2) - b1/2 + + +def riccati_inverse_normal(y, x, b1, b2, bp=None): + """ + Inverse transforming the solution to the normal + Riccati ODE to get the solution to the Riccati ODE. + """ + # bp is the expression which is independent of the solution + # and hence, it need not be computed again + if bp is None: + bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2) + # w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x)) + return -y/b2 + bp + + +def riccati_reduced(eq, f, x): + """ + Convert a Riccati ODE into its corresponding + normal Riccati ODE. + """ + match, funcs = match_riccati(eq, f, x) + # If equation is not a Riccati ODE, exit + if not match: + return False + # Using the rational functions, find the expression for a(x) + b0, b1, b2 = funcs + a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \ + b2.diff(x, 2)/(2*b2) + # Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x) + return f(x).diff(x) + f(x)**2 - a + +def linsolve_dict(eq, syms): + """ + Get the output of linsolve as a dict + """ + # Convert tuple type return value of linsolve + # to a dictionary for ease of use + sol = linsolve(eq, syms) + if not sol: + return {} + return dict(zip(syms, list(sol)[0])) + + +def match_riccati(eq, f, x): + """ + A function that matches and returns the coefficients + if an equation is a Riccati ODE + + Parameters + ========== + + eq: Equation to be matched + f: Dependent variable + x: Independent variable + + Returns + ======= + + match: True if equation is a Riccati ODE, False otherwise + funcs: [b0, b1, b2] if match is True, [] otherwise. Here, + b0, b1 and b2 are rational functions which match the equation. + """ + # Group terms based on f(x) + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + eq = eq.expand().collect(f(x)) + cf = eq.coeff(f(x).diff(x)) + + # There must be an f(x).diff(x) term. + # eq must be an Add object since we are using the expanded + # equation and it must have atleast 2 terms (b2 != 0) + if cf != 0 and isinstance(eq, Add): + + # Divide all coefficients by the coefficient of f(x).diff(x) + # and add the terms again to get the same equation + eq = Add(*((x/cf).cancel() for x in eq.args)).collect(f(x)) + + # Match the equation with the pattern + b1 = -eq.coeff(f(x)) + b2 = -eq.coeff(f(x)**2) + b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand() + funcs = [b0, b1, b2] + + # Check if coefficients are not symbols and floats + if any(len(x.atoms(Symbol)) > 1 or len(x.atoms(Float)) for x in funcs): + return False, [] + + # If b_0(x) contains f(x), it is not a Riccati ODE + if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x), + b1.is_rational_function(x), b2.is_rational_function(x))): + return False, [] + return True, funcs + return False, [] + + +def val_at_inf(num, den, x): + # Valuation of a rational function at oo = deg(denom) - deg(numer) + return den.degree(x) - num.degree(x) + + +def check_necessary_conds(val_inf, muls): + """ + The necessary conditions for a rational solution + to exist are as follows - + + i) Every pole of a(x) must be either a simple pole + or a multiple pole of even order. + + ii) The valuation of a(x) at infinity must be even + or be greater than or equal to 2. + + Here, a simple pole is a pole with multiplicity 1 + and a multiple pole is a pole with multiplicity + greater than 1. + """ + return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \ + all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls) + + +def inverse_transform_poly(num, den, x): + """ + A function to make the substitution + x -> 1/x in a rational function that + is represented using Poly objects for + numerator and denominator. + """ + # Declare for reuse + one = Poly(1, x) + xpoly = Poly(x, x) + + # Check if degree of numerator is same as denominator + pwr = val_at_inf(num, den, x) + if pwr >= 0: + # Denominator has greater degree. Substituting x with + # 1/x would make the extra power go to the numerator + if num.expr != 0: + num = num.transform(one, xpoly) * x**pwr + den = den.transform(one, xpoly) + else: + # Numerator has greater degree. Substituting x with + # 1/x would make the extra power go to the denominator + num = num.transform(one, xpoly) + den = den.transform(one, xpoly) * x**(-pwr) + return num.cancel(den, include=True) + + +def limit_at_inf(num, den, x): + """ + Find the limit of a rational function + at oo + """ + # pwr = degree(num) - degree(den) + pwr = -val_at_inf(num, den, x) + # Numerator has a greater degree than denominator + # Limit at infinity would depend on the sign of the + # leading coefficients of numerator and denominator + if pwr > 0: + return oo*sign(num.LC()/den.LC()) + # Degree of numerator is equal to that of denominator + # Limit at infinity is just the ratio of leading coeffs + elif pwr == 0: + return num.LC()/den.LC() + # Degree of numerator is less than that of denominator + # Limit at infinity is just 0 + else: + return 0 + + +def construct_c_case_1(num, den, x, pole): + # Find the coefficient of 1/(x - pole)**2 in the + # Laurent series expansion of a(x) about pole. + num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True) + r = (num1.subs(x, pole))/(den1.subs(x, pole)) + + # If multiplicity is 2, the coefficient to be added + # in the c-vector is c = (1 +- sqrt(1 + 4*r))/2 + if r != -S(1)/4: + return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]] + return [[S.Half]] + + +def construct_c_case_2(num, den, x, pole, mul): + # Generate the coefficients using the recurrence + # relation mentioned in (5.14) in the thesis (Pg 80) + + # r_i = mul/2 + ri = mul//2 + + # Find the Laurent series coefficients about the pole + ser = rational_laurent_series(num, den, x, pole, mul, 6) + + # Start with an empty memo to store the coefficients + # This is for the plus case + cplus = [0 for i in range(ri)] + + # Base Case + cplus[ri-1] = sqrt(ser[2*ri]) + + # Iterate backwards to find all coefficients + s = ri - 1 + sm = 0 + for s in range(ri-1, 0, -1): + sm = 0 + for j in range(s+1, ri): + sm += cplus[j-1]*cplus[ri+s-j-1] + if s!= 1: + cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1]) + + # Memo for the minus case + cminus = [-x for x in cplus] + + # Find the 0th coefficient in the recurrence + cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1]) + cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1]) + + # Add both the plus and minus cases' coefficients + if cplus != cminus: + return [cplus, cminus] + return cplus + + +def construct_c_case_3(): + # If multiplicity is 1, the coefficient to be added + # in the c-vector is 1 (no choice) + return [[1]] + + +def construct_c(num, den, x, poles, muls): + """ + Helper function to calculate the coefficients + in the c-vector for each pole. + """ + c = [] + for pole, mul in zip(poles, muls): + c.append([]) + + # Case 3 + if mul == 1: + # Add the coefficients from Case 3 + c[-1].extend(construct_c_case_3()) + + # Case 1 + elif mul == 2: + # Add the coefficients from Case 1 + c[-1].extend(construct_c_case_1(num, den, x, pole)) + + # Case 2 + else: + # Add the coefficients from Case 2 + c[-1].extend(construct_c_case_2(num, den, x, pole, mul)) + + return c + + +def construct_d_case_4(ser, N): + # Initialize an empty vector + dplus = [0 for i in range(N+2)] + # d_N = sqrt(a_{2*N}) + dplus[N] = sqrt(ser[2*N]) + + # Use the recurrence relations to find + # the value of d_s + for s in range(N-1, -2, -1): + sm = 0 + for j in range(s+1, N): + sm += dplus[j]*dplus[N+s-j] + if s != -1: + dplus[s] = (ser[N+s] - sm)/(2*dplus[N]) + + # Coefficients for the case of d_N = -sqrt(a_{2*N}) + dminus = [-x for x in dplus] + + # The third equation in Eq 5.15 of the thesis is WRONG! + # d_N must be replaced with N*d_N in that equation. + dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N]) + dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N]) + + if dplus != dminus: + return [dplus, dminus] + return dplus + + +def construct_d_case_5(ser): + # List to store coefficients for plus case + dplus = [0, 0] + + # d_0 = sqrt(a_0) + dplus[0] = sqrt(ser[0]) + + # d_(-1) = a_(-1)/(2*d_0) + dplus[-1] = ser[-1]/(2*dplus[0]) + + # Coefficients for the minus case are just the negative + # of the coefficients for the positive case. + dminus = [-x for x in dplus] + + if dplus != dminus: + return [dplus, dminus] + return dplus + + +def construct_d_case_6(num, den, x): + # s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to + # s_oo = lim x->oo x**2 * a(x) + s_inf = limit_at_inf(Poly(x**2, x)*num, den, x) + + # d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2 + if s_inf != -S(1)/4: + return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]] + return [[S.Half]] + + +def construct_d(num, den, x, val_inf): + """ + Helper function to calculate the coefficients + in the d-vector based on the valuation of the + function at oo. + """ + N = -val_inf//2 + # Multiplicity of oo as a pole + mul = -val_inf if val_inf < 0 else 0 + ser = rational_laurent_series(num, den, x, oo, mul, 1) + + # Case 4 + if val_inf < 0: + d = construct_d_case_4(ser, N) + + # Case 5 + elif val_inf == 0: + d = construct_d_case_5(ser) + + # Case 6 + else: + d = construct_d_case_6(num, den, x) + + return d + + +def rational_laurent_series(num, den, x, r, m, n): + r""" + The function computes the Laurent series coefficients + of a rational function. + + Parameters + ========== + + num: A Poly object that is the numerator of `f(x)`. + den: A Poly object that is the denominator of `f(x)`. + x: The variable of expansion of the series. + r: The point of expansion of the series. + m: Multiplicity of r if r is a pole of `f(x)`. Should + be zero otherwise. + n: Order of the term upto which the series is expanded. + + Returns + ======= + + series: A dictionary that has power of the term as key + and coefficient of that term as value. + + Below is a basic outline of how the Laurent series of a + rational function `f(x)` about `x_0` is being calculated - + + 1. Substitute `x + x_0` in place of `x`. If `x_0` + is a pole of `f(x)`, multiply the expression by `x^m` + where `m` is the multiplicity of `x_0`. Denote the + the resulting expression as g(x). We do this substitution + so that we can now find the Laurent series of g(x) about + `x = 0`. + + 2. We can then assume that the Laurent series of `g(x)` + takes the following form - + + .. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{m = 0}^{\infty} a_m x^m + + where `a_m` denotes the Laurent series coefficients. + + 3. Multiply the denominator to the RHS of the equation + and form a recurrence relation for the coefficients `a_m`. + """ + one = Poly(1, x, extension=True) + + if r == oo: + # Series at x = oo is equal to first transforming + # the function from x -> 1/x and finding the + # series at x = 0 + num, den = inverse_transform_poly(num, den, x) + r = S(0) + + if r: + # For an expansion about a non-zero point, a + # transformation from x -> x + r must be made + num = num.transform(Poly(x + r, x, extension=True), one) + den = den.transform(Poly(x + r, x, extension=True), one) + + # Remove the pole from the denominator if the series + # expansion is about one of the poles + num, den = (num*x**m).cancel(den, include=True) + + # Equate coefficients for the first terms (base case) + maxdegree = 1 + max(num.degree(), den.degree()) + syms = symbols(f'a:{maxdegree}', cls=Dummy) + diff = num - den * Poly(syms[::-1], x) + coeff_diffs = diff.all_coeffs()[::-1][:maxdegree] + (coeffs, ) = linsolve(coeff_diffs, syms) + + # Use the recursion relation for the rest + recursion = den.all_coeffs()[::-1] + div, rec_rhs = recursion[0], recursion[1:] + series = list(coeffs) + while len(series) < n: + next_coeff = Add(*(c*series[-1-n] for n, c in enumerate(rec_rhs))) / div + series.append(-next_coeff) + series = {m - i: val for i, val in enumerate(series)} + return series + +def compute_m_ybar(x, poles, choice, N): + """ + Helper function to calculate - + + 1. m - The degree bound for the polynomial + solution that must be found for the auxiliary + differential equation. + + 2. ybar - Part of the solution which can be + computed using the poles, c and d vectors. + """ + ybar = 0 + m = Poly(choice[-1][-1], x, extension=True) + + # Calculate the first (nested) summation for ybar + # as given in Step 9 of the Thesis (Pg 82) + dybar = [] + for i, polei in enumerate(poles): + for j, cij in enumerate(choice[i]): + dybar.append(cij/(x - polei)**(j + 1)) + m -=Poly(choice[i][0], x, extension=True) # can't accumulate Poly and use with Add + ybar += Add(*dybar) + + # Calculate the second summation for ybar + for i in range(N+1): + ybar += choice[-1][i]*x**i + return (m.expr, ybar) + + +def solve_aux_eq(numa, dena, numy, deny, x, m): + """ + Helper function to find a polynomial solution + of degree m for the auxiliary differential + equation. + """ + # Assume that the solution is of the type + # p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m + psyms = symbols(f'C0:{m}', cls=Dummy) + K = ZZ[psyms] + psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K) + + # Eq (5.16) in Thesis - Pg 81 + auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol + if m >= 1: + px = psol.diff(x) + auxeq += px*(2*numy*deny*dena) + if m >= 2: + auxeq += px.diff(x)*(deny**2*dena) + if m != 0: + # m is a non-zero integer. Find the constant terms using undetermined coefficients + return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True + else: + # m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation + return S.One, auxeq, auxeq == 0 + + +def remove_redundant_sols(sol1, sol2, x): + """ + Helper function to remove redundant + solutions to the differential equation. + """ + # If y1 and y2 are redundant solutions, there is + # some value of the arbitrary constant for which + # they will be equal + + syms1 = sol1.atoms(Symbol, Dummy) + syms2 = sol2.atoms(Symbol, Dummy) + num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()] + num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()] + # Cross multiply + e = num1*den2 - den1*num2 + # Check if there are any constants + syms = list(e.atoms(Symbol, Dummy)) + if len(syms): + # Find values of constants for which solutions are equal + redn = linsolve(e.all_coeffs(), syms) + if len(redn): + # Return the general solution over a particular solution + if len(syms1) > len(syms2): + return sol2 + # If both have constants, return the lesser complex solution + elif len(syms1) == len(syms2): + return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2 + else: + return sol1 + + +def get_gen_sol_from_part_sol(part_sols, a, x): + """" + Helper function which computes the general + solution for a Riccati ODE from its particular + solutions. + + There are 3 cases to find the general solution + from the particular solutions for a Riccati ODE + depending on the number of particular solution(s) + we have - 1, 2 or 3. + + For more information, see Section 6 of + "Methods of Solution of the Riccati Differential Equation" + by D. R. Haaheim and F. M. Stein + """ + + # If no particular solutions are found, a general + # solution cannot be found + if len(part_sols) == 0: + return [] + + # In case of a single particular solution, the general + # solution can be found by using the substitution + # y = y1 + 1/z and solving a Bernoulli ODE to find z. + elif len(part_sols) == 1: + y1 = part_sols[0] + i = exp(Integral(2*y1, x)) + z = i * Integral(a/i, x) + z = z.doit() + if a == 0 or z == 0: + return y1 + return y1 + 1/z + + # In case of 2 particular solutions, the general solution + # can be found by solving a separable equation. This is + # the most common case, i.e. most Riccati ODEs have 2 + # rational particular solutions. + elif len(part_sols) == 2: + y1, y2 = part_sols + # One of them already has a constant + if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0: + u = exp(Integral(y2 - y1, x)).doit() + # Introduce a constant + else: + C1 = Dummy('C1') + u = C1*exp(Integral(y2 - y1, x)).doit() + if u == 1: + return y2 + return (y2*u - y1)/(u - 1) + + # In case of 3 particular solutions, a closed form + # of the general solution can be obtained directly + else: + y1, y2, y3 = part_sols[:3] + C1 = Dummy('C1') + return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3) + + +def solve_riccati(fx, x, b0, b1, b2, gensol=False): + """ + The main function that gives particular/general + solutions to Riccati ODEs that have atleast 1 + rational particular solution. + """ + # Step 1 : Convert to Normal Form + a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \ + b2.diff(x, 2)/(2*b2) + a_t = a.together() + num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()] + num, den = num.cancel(den, include=True) + + # Step 2 + presol = [] + + # Step 3 : a(x) is 0 + if num == 0: + presol.append(1/(x + Dummy('C1'))) + + # Step 4 : a(x) is a non-zero constant + elif x not in num.free_symbols.union(den.free_symbols): + presol.extend([sqrt(a), -sqrt(a)]) + + # Step 5 : Find poles and valuation at infinity + poles = roots(den, x) + poles, muls = list(poles.keys()), list(poles.values()) + val_inf = val_at_inf(num, den, x) + + if len(poles): + # Check necessary conditions (outlined in the module docstring) + if not check_necessary_conds(val_inf, muls): + raise ValueError("Rational Solution doesn't exist") + + # Step 6 + # Construct c-vectors for each singular point + c = construct_c(num, den, x, poles, muls) + + # Construct d vectors for each singular point + d = construct_d(num, den, x, val_inf) + + # Step 7 : Iterate over all possible combinations and return solutions + # For each possible combination, generate an array of 0's and 1's + # where 0 means pick 1st choice and 1 means pick the second choice. + + # NOTE: We could exit from the loop if we find 3 particular solutions, + # but it is not implemented here as - + # a. Finding 3 particular solutions is very rare. Most of the time, + # only 2 particular solutions are found. + # b. In case we exit after finding 3 particular solutions, it might + # happen that 1 or 2 of them are redundant solutions. So, instead of + # spending some more time in computing the particular solutions, + # we will end up computing the general solution from a single + # particular solution which is usually slower than computing the + # general solution from 2 or 3 particular solutions. + c.append(d) + choices = product(*c) + for choice in choices: + m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2) + numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()] + # Step 10 : Check if a valid solution exists. If yes, also check + # if m is a non-negative integer + if m.is_nonnegative == True and m.is_integer == True: + + # Step 11 : Find polynomial solutions of degree m for the auxiliary equation + psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m) + + # Step 12 : If valid polynomial solution exists, append solution. + if exists: + # m == 0 case + if psol == 1 and coeffs == 0: + # p(x) = 1, so p'(x)/p(x) term need not be added + presol.append(ybar) + # m is a positive integer and there are valid coefficients + elif len(coeffs): + # Substitute the valid coefficients to get p(x) + psol = psol.xreplace(coeffs) + # y(x) = ybar(x) + p'(x)/p(x) + presol.append(ybar + psol.diff(x)/psol) + + # Remove redundant solutions from the list of existing solutions + remove = set() + for i in range(len(presol)): + for j in range(i+1, len(presol)): + rem = remove_redundant_sols(presol[i], presol[j], x) + if rem is not None: + remove.add(rem) + sols = [x for x in presol if x not in remove] + + # Step 15 : Inverse transform the solutions of the equation in normal form + bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2) + + # If general solution is required, compute it from the particular solutions + if gensol: + sols = [get_gen_sol_from_part_sol(sols, a, x)] + + # Inverse transform the particular solutions + presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols] + return presol diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/single.py b/phivenv/Lib/site-packages/sympy/solvers/ode/single.py new file mode 100644 index 0000000000000000000000000000000000000000..c4829acf41293c2f6f20af3e9c45d37457802102 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/single.py @@ -0,0 +1,2977 @@ +# +# This is the module for ODE solver classes for single ODEs. +# + +from __future__ import annotations +from typing import ClassVar, Iterator + +from .riccati import match_riccati, solve_riccati +from sympy.core import Add, S, Pow, Rational +from sympy.core.cache import cached_property +from sympy.core.exprtools import factor_terms +from sympy.core.expr import Expr +from sympy.core.function import AppliedUndef, Derivative, diff, Function, expand, Subs, _mexpand +from sympy.core.numbers import zoo +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Symbol, Dummy, Wild +from sympy.core.mul import Mul +from sympy.functions import exp, tan, log, sqrt, besselj, bessely, cbrt, airyai, airybi +from sympy.integrals import Integral +from sympy.polys import Poly +from sympy.polys.polytools import cancel, factor, degree +from sympy.simplify import collect, simplify, separatevars, logcombine, posify # type: ignore +from sympy.simplify.radsimp import fraction +from sympy.utilities import numbered_symbols +from sympy.solvers.solvers import solve +from sympy.solvers.deutils import ode_order, _preprocess +from sympy.polys.matrices.linsolve import _lin_eq2dict +from sympy.polys.solvers import PolyNonlinearError +from .hypergeometric import equivalence_hypergeometric, match_2nd_2F1_hypergeometric, \ + get_sol_2F1_hypergeometric, match_2nd_hypergeometric +from .nonhomogeneous import _get_euler_characteristic_eq_sols, _get_const_characteristic_eq_sols, \ + _solve_undetermined_coefficients, _solve_variation_of_parameters, _test_term, _undetermined_coefficients_match, \ + _get_simplified_sol +from .lie_group import _ode_lie_group + + +class ODEMatchError(NotImplementedError): + """Raised if a SingleODESolver is asked to solve an ODE it does not match""" + pass + + +class SingleODEProblem: + """Represents an ordinary differential equation (ODE) + + This class is used internally in the by dsolve and related + functions/classes so that properties of an ODE can be computed + efficiently. + + Examples + ======== + + This class is used internally by dsolve. To instantiate an instance + directly first define an ODE problem: + + >>> from sympy import Function, Symbol + >>> x = Symbol('x') + >>> f = Function('f') + >>> eq = f(x).diff(x, 2) + + Now you can create a SingleODEProblem instance and query its properties: + + >>> from sympy.solvers.ode.single import SingleODEProblem + >>> problem = SingleODEProblem(f(x).diff(x), f(x), x) + >>> problem.eq + Derivative(f(x), x) + >>> problem.func + f(x) + >>> problem.sym + x + """ + + # Instance attributes: + eq: Expr + func: AppliedUndef + sym: Symbol + _order: int + _eq_expanded: Expr + _eq_preprocessed: Expr + _eq_high_order_free = None + + def __init__(self, eq, func, sym, prep=True, **kwargs): + assert isinstance(eq, Expr) + assert isinstance(func, AppliedUndef) + assert isinstance(sym, Symbol) + assert isinstance(prep, bool) + self.eq = eq + self.func = func + self.sym = sym + self.prep = prep + self.params = kwargs + + @cached_property + def order(self) -> int: + return ode_order(self.eq, self.func) + + @cached_property + def eq_preprocessed(self) -> Expr: + return self._get_eq_preprocessed() + + @cached_property + def eq_high_order_free(self) -> Expr: + a = Wild('a', exclude=[self.func]) + c1 = Wild('c1', exclude=[self.sym]) + # Precondition to try remove f(x) from highest order derivative + reduced_eq = None + if self.eq.is_Add: + deriv_coef = self.eq.coeff(self.func.diff(self.sym, self.order)) + if deriv_coef not in (1, 0): + r = deriv_coef.match(a*self.func**c1) + if r and r[c1]: + den = self.func**r[c1] + reduced_eq = Add(*[arg/den for arg in self.eq.args]) + if reduced_eq is None: + reduced_eq = expand(self.eq) + return reduced_eq + + @cached_property + def eq_expanded(self) -> Expr: + return expand(self.eq_preprocessed) + + def _get_eq_preprocessed(self) -> Expr: + if self.prep: + process_eq, process_func = _preprocess(self.eq, self.func) + if process_func != self.func: + raise ValueError + else: + process_eq = self.eq + return process_eq + + def get_numbered_constants(self, num=1, start=1, prefix='C') -> list[Symbol]: + """ + Returns a list of constants that do not occur + in eq already. + """ + ncs = self.iter_numbered_constants(start, prefix) + Cs = [next(ncs) for i in range(num)] + return Cs + + def iter_numbered_constants(self, start=1, prefix='C') -> Iterator[Symbol]: + """ + Returns an iterator of constants that do not occur + in eq already. + """ + atom_set = self.eq.free_symbols + func_set = self.eq.atoms(Function) + if func_set: + atom_set |= {Symbol(str(f.func)) for f in func_set} + return numbered_symbols(start=start, prefix=prefix, exclude=atom_set) + + @cached_property + def is_autonomous(self): + u = Dummy('u') + x = self.sym + syms = self.eq.subs(self.func, u).free_symbols + return x not in syms + + def get_linear_coefficients(self, eq, func, order): + r""" + Matches a differential equation to the linear form: + + .. math:: a_n(x) y^{(n)} + \cdots + a_1(x)y' + a_0(x) y + B(x) = 0 + + Returns a dict of order:coeff terms, where order is the order of the + derivative on each term, and coeff is the coefficient of that derivative. + The key ``-1`` holds the function `B(x)`. Returns ``None`` if the ODE is + not linear. This function assumes that ``func`` has already been checked + to be good. + + Examples + ======== + + >>> from sympy import Function, cos, sin + >>> from sympy.abc import x + >>> from sympy.solvers.ode.single import SingleODEProblem + >>> f = Function('f') + >>> eq = f(x).diff(x, 3) + 2*f(x).diff(x) + \ + ... x*f(x).diff(x, 2) + cos(x)*f(x).diff(x) + x - f(x) - \ + ... sin(x) + >>> obj = SingleODEProblem(eq, f(x), x) + >>> obj.get_linear_coefficients(eq, f(x), 3) + {-1: x - sin(x), 0: -1, 1: cos(x) + 2, 2: x, 3: 1} + >>> eq = f(x).diff(x, 3) + 2*f(x).diff(x) + \ + ... x*f(x).diff(x, 2) + cos(x)*f(x).diff(x) + x - f(x) - \ + ... sin(f(x)) + >>> obj = SingleODEProblem(eq, f(x), x) + >>> obj.get_linear_coefficients(eq, f(x), 3) == None + True + + """ + f = func.func + x = func.args[0] + symset = {Derivative(f(x), x, i) for i in range(order+1)} + try: + rhs, lhs_terms = _lin_eq2dict(eq, symset) + except PolyNonlinearError: + return None + + if rhs.has(func) or any(c.has(func) for c in lhs_terms.values()): + return None + terms = {i: lhs_terms.get(f(x).diff(x, i), S.Zero) for i in range(order+1)} + terms[-1] = rhs + return terms + + # TODO: Add methods that can be used by many ODE solvers: + # order + # is_linear() + # get_linear_coefficients() + # eq_prepared (the ODE in prepared form) + + +class SingleODESolver: + """ + Base class for Single ODE solvers. + + Subclasses should implement the _matches and _get_general_solution + methods. This class is not intended to be instantiated directly but its + subclasses are as part of dsolve. + + Examples + ======== + + You can use a subclass of SingleODEProblem to solve a particular type of + ODE. We first define a particular ODE problem: + + >>> from sympy import Function, Symbol + >>> x = Symbol('x') + >>> f = Function('f') + >>> eq = f(x).diff(x, 2) + + Now we solve this problem using the NthAlgebraic solver which is a + subclass of SingleODESolver: + + >>> from sympy.solvers.ode.single import NthAlgebraic, SingleODEProblem + >>> problem = SingleODEProblem(eq, f(x), x) + >>> solver = NthAlgebraic(problem) + >>> solver.get_general_solution() + [Eq(f(x), _C*x + _C)] + + The normal way to solve an ODE is to use dsolve (which would use + NthAlgebraic and other solvers internally). When using dsolve a number of + other things are done such as evaluating integrals, simplifying the + solution and renumbering the constants: + + >>> from sympy import dsolve + >>> dsolve(eq, hint='nth_algebraic') + Eq(f(x), C1 + C2*x) + """ + + # Subclasses should store the hint name (the argument to dsolve) in this + # attribute + hint: ClassVar[str] + + # Subclasses should define this to indicate if they support an _Integral + # hint. + has_integral: ClassVar[bool] + + # The ODE to be solved + ode_problem: SingleODEProblem + + # Cache whether or not the equation has matched the method + _matched: bool | None = None + + # Subclasses should store in this attribute the list of order(s) of ODE + # that subclass can solve or leave it to None if not specific to any order + order: list | None = None + + def __init__(self, ode_problem): + self.ode_problem = ode_problem + + def matches(self) -> bool: + if self.order is not None and self.ode_problem.order not in self.order: + self._matched = False + return self._matched + + if self._matched is None: + self._matched = self._matches() + return self._matched + + def get_general_solution(self, *, simplify: bool = True) -> list[Equality]: + if not self.matches(): + msg = "%s solver cannot solve:\n%s" + raise ODEMatchError(msg % (self.hint, self.ode_problem.eq)) + return self._get_general_solution(simplify_flag=simplify) + + def _matches(self) -> bool: + msg = "Subclasses of SingleODESolver should implement matches." + raise NotImplementedError(msg) + + def _get_general_solution(self, *, simplify_flag: bool = True) -> list[Equality]: + msg = "Subclasses of SingleODESolver should implement get_general_solution." + raise NotImplementedError(msg) + + +class SinglePatternODESolver(SingleODESolver): + '''Superclass for ODE solvers based on pattern matching''' + + def wilds(self): + prob = self.ode_problem + f = prob.func.func + x = prob.sym + order = prob.order + return self._wilds(f, x, order) + + def wilds_match(self): + match = self._wilds_match + return [match.get(w, S.Zero) for w in self.wilds()] + + def _matches(self): + eq = self.ode_problem.eq_expanded + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + df = f(x).diff(x, order) + + if order not in [1, 2]: + return False + + pattern = self._equation(f(x), x, order) + + if not pattern.coeff(df).has(Wild): + eq = expand(eq / eq.coeff(df)) + eq = eq.collect([f(x).diff(x), f(x)], func = cancel) + + self._wilds_match = match = eq.match(pattern) + if match is not None: + return self._verify(f(x)) + return False + + def _verify(self, fx) -> bool: + return True + + def _wilds(self, f, x, order): + msg = "Subclasses of SingleODESolver should implement _wilds" + raise NotImplementedError(msg) + + def _equation(self, fx, x, order): + msg = "Subclasses of SingleODESolver should implement _equation" + raise NotImplementedError(msg) + + +class NthAlgebraic(SingleODESolver): + r""" + Solves an `n`\th order ordinary differential equation using algebra and + integrals. + + There is no general form for the kind of equation that this can solve. The + the equation is solved algebraically treating differentiation as an + invertible algebraic function. + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = Eq(f(x) * (f(x).diff(x)**2 - 1), 0) + >>> dsolve(eq, f(x), hint='nth_algebraic') + [Eq(f(x), 0), Eq(f(x), C1 - x), Eq(f(x), C1 + x)] + + Note that this solver can return algebraic solutions that do not have any + integration constants (f(x) = 0 in the above example). + """ + + hint = 'nth_algebraic' + has_integral = True # nth_algebraic_Integral hint + + def _matches(self): + r""" + Matches any differential equation that nth_algebraic can solve. Uses + `sympy.solve` but teaches it how to integrate derivatives. + + This involves calling `sympy.solve` and does most of the work of finding a + solution (apart from evaluating the integrals). + """ + eq = self.ode_problem.eq + func = self.ode_problem.func + var = self.ode_problem.sym + + # Derivative that solve can handle: + diffx = self._get_diffx(var) + + # Replace derivatives wrt the independent variable with diffx + def replace(eq, var): + def expand_diffx(*args): + differand, diffs = args[0], args[1:] + toreplace = differand + for v, n in diffs: + for _ in range(n): + if v == var: + toreplace = diffx(toreplace) + else: + toreplace = Derivative(toreplace, v) + return toreplace + return eq.replace(Derivative, expand_diffx) + + # Restore derivatives in solution afterwards + def unreplace(eq, var): + return eq.replace(diffx, lambda e: Derivative(e, var)) + + subs_eqn = replace(eq, var) + try: + # turn off simplification to protect Integrals that have + # _t instead of fx in them and would otherwise factor + # as t_*Integral(1, x) + solns = solve(subs_eqn, func, simplify=False) + except NotImplementedError: + solns = [] + + solns = [simplify(unreplace(soln, var)) for soln in solns] + solns = [Equality(func, soln) for soln in solns] + + self.solutions = solns + return len(solns) != 0 + + def _get_general_solution(self, *, simplify_flag: bool = True): + return self.solutions + + # This needs to produce an invertible function but the inverse depends + # which variable we are integrating with respect to. Since the class can + # be stored in cached results we need to ensure that we always get the + # same class back for each particular integration variable so we store these + # classes in a global dict: + _diffx_stored: dict[Symbol, type[Function]] = {} + + @staticmethod + def _get_diffx(var): + diffcls = NthAlgebraic._diffx_stored.get(var, None) + + if diffcls is None: + # A class that behaves like Derivative wrt var but is "invertible". + class diffx(Function): + def inverse(self): + # don't use integrate here because fx has been replaced by _t + # in the equation; integrals will not be correct while solve + # is at work. + return lambda expr: Integral(expr, var) + Dummy('C') + + diffcls = NthAlgebraic._diffx_stored.setdefault(var, diffx) + + return diffcls + + +class FirstExact(SinglePatternODESolver): + r""" + Solves 1st order exact ordinary differential equations. + + A 1st order differential equation is called exact if it is the total + differential of a function. That is, the differential equation + + .. math:: P(x, y) \,\partial{}x + Q(x, y) \,\partial{}y = 0 + + is exact if there is some function `F(x, y)` such that `P(x, y) = + \partial{}F/\partial{}x` and `Q(x, y) = \partial{}F/\partial{}y`. It can + be shown that a necessary and sufficient condition for a first order ODE + to be exact is that `\partial{}P/\partial{}y = \partial{}Q/\partial{}x`. + Then, the solution will be as given below:: + + >>> from sympy import Function, Eq, Integral, symbols, pprint + >>> x, y, t, x0, y0, C1= symbols('x,y,t,x0,y0,C1') + >>> P, Q, F= map(Function, ['P', 'Q', 'F']) + >>> pprint(Eq(Eq(F(x, y), Integral(P(t, y), (t, x0, x)) + + ... Integral(Q(x0, t), (t, y0, y))), C1)) + x y + / / + | | + F(x, y) = | P(t, y) dt + | Q(x0, t) dt = C1 + | | + / / + x0 y0 + + Where the first partials of `P` and `Q` exist and are continuous in a + simply connected region. + + A note: SymPy currently has no way to represent inert substitution on an + expression, so the hint ``1st_exact_Integral`` will return an integral + with `dy`. This is supposed to represent the function that you are + solving for. + + Examples + ======== + + >>> from sympy import Function, dsolve, cos, sin + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + ... f(x), hint='1st_exact') + Eq(x*cos(f(x)) + f(x)**3/3, C1) + + References + ========== + + - https://en.wikipedia.org/wiki/Exact_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 73 + + # indirect doctest + + """ + hint = "1st_exact" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x).diff(x)]) + Q = Wild('Q', exclude=[f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return P + Q*fx.diff(x) + + def _verify(self, fx) -> bool: + P, Q = self.wilds() + x = self.ode_problem.sym + y = Dummy('y') + + m, n = self.wilds_match() + + m = m.subs(fx, y) + n = n.subs(fx, y) + numerator = cancel(m.diff(y) - n.diff(x)) + + if numerator.is_zero: + # Is exact + return True + else: + # The following few conditions try to convert a non-exact + # differential equation into an exact one. + # References: + # 1. Differential equations with applications + # and historical notes - George E. Simmons + # 2. https://math.okstate.edu/people/binegar/2233-S99/2233-l12.pdf + + factor_n = cancel(numerator/n) + factor_m = cancel(-numerator/m) + if y not in factor_n.free_symbols: + # If (dP/dy - dQ/dx) / Q = f(x) + # then exp(integral(f(x))*equation becomes exact + factor = factor_n + integration_variable = x + elif x not in factor_m.free_symbols: + # If (dP/dy - dQ/dx) / -P = f(y) + # then exp(integral(f(y))*equation becomes exact + factor = factor_m + integration_variable = y + else: + # Couldn't convert to exact + return False + + factor = exp(Integral(factor, integration_variable)) + m *= factor + n *= factor + self._wilds_match[P] = m.subs(y, fx) + self._wilds_match[Q] = n.subs(y, fx) + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + m, n = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + y = Dummy('y') + + m = m.subs(fx, y) + n = n.subs(fx, y) + + gen_sol = Eq(Subs(Integral(m, x) + + Integral(n - Integral(m, x).diff(y), y), y, fx), C1) + return [gen_sol] + + +class FirstLinear(SinglePatternODESolver): + r""" + Solves 1st order linear differential equations. + + These are differential equations of the form + + .. math:: dy/dx + P(x) y = Q(x)\text{.} + + These kinds of differential equations can be solved in a general way. The + integrating factor `e^{\int P(x) \,dx}` will turn the equation into a + separable equation. The general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint, diff, sin + >>> from sympy.abc import x + >>> f, P, Q = map(Function, ['f', 'P', 'Q']) + >>> genform = Eq(f(x).diff(x) + P(x)*f(x), Q(x)) + >>> pprint(genform) + d + P(x)*f(x) + --(f(x)) = Q(x) + dx + >>> pprint(dsolve(genform, f(x), hint='1st_linear_Integral')) + / / \ + | | | + | | / | / + | | | | | + | | | P(x) dx | - | P(x) dx + | | | | | + | | / | / + f(x) = |C1 + | Q(x)*e dx|*e + | | | + \ / / + + + Examples + ======== + + >>> f = Function('f') + >>> pprint(dsolve(Eq(x*diff(f(x), x) - f(x), x**2*sin(x)), + ... f(x), '1st_linear')) + f(x) = x*(C1 - cos(x)) + + References + ========== + + - https://en.wikipedia.org/wiki/Linear_differential_equation#First-order_equation_with_variable_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 92 + + # indirect doctest + + """ + hint = '1st_linear' + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x)]) + Q = Wild('Q', exclude=[f(x), f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return fx.diff(x) + P*fx - Q + + def _get_general_solution(self, *, simplify_flag: bool = True): + P, Q = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + gensol = Eq(fx, ((C1 + Integral(Q*exp(Integral(P, x)), x)) + * exp(-Integral(P, x)))) + return [gensol] + + +class AlmostLinear(SinglePatternODESolver): + r""" + Solves an almost-linear differential equation. + + The general form of an almost linear differential equation is + + .. math:: a(x) g'(f(x)) f'(x) + b(x) g(f(x)) + c(x) + + Here `f(x)` is the function to be solved for (the dependent variable). + The substitution `g(f(x)) = u(x)` leads to a linear differential equation + for `u(x)` of the form `a(x) u' + b(x) u + c(x) = 0`. This can be solved + for `u(x)` by the `first_linear` hint and then `f(x)` is found by solving + `g(f(x)) = u(x)`. + + See Also + ======== + :obj:`sympy.solvers.ode.single.FirstLinear` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint, sin, cos + >>> from sympy.abc import x + >>> f = Function('f') + >>> d = f(x).diff(x) + >>> eq = x*d + x*f(x) + 1 + >>> dsolve(eq, f(x), hint='almost_linear') + Eq(f(x), (C1 - Ei(x))*exp(-x)) + >>> pprint(dsolve(eq, f(x), hint='almost_linear')) + -x + f(x) = (C1 - Ei(x))*e + >>> example = cos(f(x))*f(x).diff(x) + sin(f(x)) + 1 + >>> pprint(example) + d + sin(f(x)) + cos(f(x))*--(f(x)) + 1 + dx + >>> pprint(dsolve(example, f(x), hint='almost_linear')) + / -x \ / -x \ + [f(x) = pi - asin\C1*e - 1/, f(x) = asin\C1*e - 1/] + + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "almost_linear" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x).diff(x)]) + Q = Wild('Q', exclude=[f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return P*fx.diff(x) + Q + + def _verify(self, fx): + a, b = self.wilds_match() + c, b = b.as_independent(fx) if b.is_Add else (S.Zero, b) + # a, b and c are the function a(x), b(x) and c(x) respectively. + # c(x) is obtained by separating out b as terms with and without fx i.e, l(y) + # The following conditions checks if the given equation is an almost-linear differential equation using the fact that + # a(x)*(l(y))' / l(y)' is independent of l(y) + + if b.diff(fx) != 0 and not simplify(b.diff(fx)/a).has(fx): + self.ly = factor_terms(b).as_independent(fx, as_Add=False)[1] # Gives the term containing fx i.e., l(y) + self.ax = a / self.ly.diff(fx) + self.cx = -c # cx is taken as -c(x) to simplify expression in the solution integral + self.bx = factor_terms(b) / self.ly + return True + + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + gensol = Eq(self.ly, ((C1 + Integral((self.cx/self.ax)*exp(Integral(self.bx/self.ax, x)), x)) + * exp(-Integral(self.bx/self.ax, x)))) + + return [gensol] + + +class Bernoulli(SinglePatternODESolver): + r""" + Solves Bernoulli differential equations. + + These are equations of the form + + .. math:: dy/dx + P(x) y = Q(x) y^n\text{, }n \ne 1`\text{.} + + The substitution `w = 1/y^{1-n}` will transform an equation of this form + into one that is linear (see the docstring of + :obj:`~sympy.solvers.ode.single.FirstLinear`). The general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x, n + >>> f, P, Q = map(Function, ['f', 'P', 'Q']) + >>> genform = Eq(f(x).diff(x) + P(x)*f(x), Q(x)*f(x)**n) + >>> pprint(genform) + d n + P(x)*f(x) + --(f(x)) = Q(x)*f (x) + dx + >>> pprint(dsolve(genform, f(x), hint='Bernoulli_Integral'), num_columns=110) + -1 + ----- + n - 1 + // / / \ \ + || | | | | + || | / | / | / | + || | | | | | | | + || | -(n - 1)* | P(x) dx | -(n - 1)* | P(x) dx | (n - 1)* | P(x) dx| + || | | | | | | | + || | / | / | / | + f(x) = ||C1 - n* | Q(x)*e dx + | Q(x)*e dx|*e | + || | | | | + \\ / / / / + + + Note that the equation is separable when `n = 1` (see the docstring of + :obj:`~sympy.solvers.ode.single.Separable`). + + >>> pprint(dsolve(Eq(f(x).diff(x) + P(x)*f(x), Q(x)*f(x)), f(x), + ... hint='separable_Integral')) + f(x) + / + | / + | 1 | + | - dy = C1 + | (-P(x) + Q(x)) dx + | y | + | / + / + + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, pprint, log + >>> from sympy.abc import x + >>> f = Function('f') + + >>> pprint(dsolve(Eq(x*f(x).diff(x) + f(x), log(x)*f(x)**2), + ... f(x), hint='Bernoulli')) + 1 + f(x) = ----------------- + C1*x + log(x) + 1 + + References + ========== + + - https://en.wikipedia.org/wiki/Bernoulli_differential_equation + + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 95 + + # indirect doctest + + """ + hint = "Bernoulli" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x)]) + Q = Wild('Q', exclude=[f(x)]) + n = Wild('n', exclude=[x, f(x), f(x).diff(x)]) + return P, Q, n + + def _equation(self, fx, x, order): + P, Q, n = self.wilds() + return fx.diff(x) + P*fx - Q*fx**n + + def _get_general_solution(self, *, simplify_flag: bool = True): + P, Q, n = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + if n==1: + gensol = Eq(log(fx), ( + C1 + Integral((-P + Q), x) + )) + else: + gensol = Eq(fx**(1-n), ( + (C1 - (n - 1) * Integral(Q*exp(-n*Integral(P, x)) + * exp(Integral(P, x)), x) + ) * exp(-(1 - n)*Integral(P, x))) + ) + return [gensol] + + +class Factorable(SingleODESolver): + r""" + Solves equations having a solvable factor. + + This function is used to solve the equation having factors. Factors may be of type algebraic or ode. It + will try to solve each factor independently. Factors will be solved by calling dsolve. We will return the + list of solutions. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = (f(x)**2-4)*(f(x).diff(x)+f(x)) + >>> pprint(dsolve(eq, f(x))) + -x + [f(x) = 2, f(x) = -2, f(x) = C1*e ] + + + """ + hint = "factorable" + has_integral = False + + def _matches(self): + eq_orig = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + df = f(x).diff(x) + self.eqs = [] + eq = eq_orig.collect(f(x), func = cancel) + eq = fraction(factor(eq))[0] + factors = Mul.make_args(factor(eq)) + roots = [fac.as_base_exp() for fac in factors if len(fac.args)!=0] + if len(roots)>1 or roots[0][1]>1: + for base, expo in roots: + if base.has(f(x)): + self.eqs.append(base) + if len(self.eqs)>0: + return True + roots = solve(eq, df) + if len(roots)>0: + self.eqs = [(df - root) for root in roots] + # Avoid infinite recursion + matches = self.eqs != [eq_orig] + return matches + for i in factors: + if i.has(f(x)): + self.eqs.append(i) + return len(self.eqs)>0 and len(factors)>1 + + def _get_general_solution(self, *, simplify_flag: bool = True): + func = self.ode_problem.func.func + x = self.ode_problem.sym + eqns = self.eqs + sols = [] + for eq in eqns: + try: + sol = dsolve(eq, func(x)) + except NotImplementedError: + continue + else: + if isinstance(sol, list): + sols.extend(sol) + else: + sols.append(sol) + + if sols == []: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the factorable group method") + return sols + + +class RiccatiSpecial(SinglePatternODESolver): + r""" + The general Riccati equation has the form + + .. math:: dy/dx = f(x) y^2 + g(x) y + h(x)\text{.} + + While it does not have a general solution [1], the "special" form, `dy/dx + = a y^2 - b x^c`, does have solutions in many cases [2]. This routine + returns a solution for `a(dy/dx) = b y^2 + c y/x + d/x^2` that is obtained + by using a suitable change of variables to reduce it to the special form + and is valid when neither `a` nor `b` are zero and either `c` or `d` is + zero. + + >>> from sympy.abc import x, a, b, c, d + >>> from sympy import dsolve, checkodesol, pprint, Function + >>> f = Function('f') + >>> y = f(x) + >>> genform = a*y.diff(x) - (b*y**2 + c*y/x + d/x**2) + >>> sol = dsolve(genform, y, hint="Riccati_special_minus2") + >>> pprint(sol, wrap_line=False) + / / __________________ \\ + | __________________ | / 2 || + | / 2 | \/ 4*b*d - (a + c) *log(x)|| + -|a + c - \/ 4*b*d - (a + c) *tan|C1 + ----------------------------|| + \ \ 2*a // + f(x) = ------------------------------------------------------------------------ + 2*b*x + + >>> checkodesol(genform, sol, order=1)[0] + True + + References + ========== + + - https://www.maplesoft.com/support/help/Maple/view.aspx?path=odeadvisor/Riccati + - https://eqworld.ipmnet.ru/en/solutions/ode/ode0106.pdf - + https://eqworld.ipmnet.ru/en/solutions/ode/ode0123.pdf + """ + hint = "Riccati_special_minus2" + has_integral = False + order = [1] + + def _wilds(self, f, x, order): + a = Wild('a', exclude=[x, f(x), f(x).diff(x), 0]) + b = Wild('b', exclude=[x, f(x), f(x).diff(x), 0]) + c = Wild('c', exclude=[x, f(x), f(x).diff(x)]) + d = Wild('d', exclude=[x, f(x), f(x).diff(x)]) + return a, b, c, d + + def _equation(self, fx, x, order): + a, b, c, d = self.wilds() + return a*fx.diff(x) + b*fx**2 + c*fx/x + d/x**2 + + def _get_general_solution(self, *, simplify_flag: bool = True): + a, b, c, d = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + mu = sqrt(4*d*b - (a - c)**2) + + gensol = Eq(fx, (a - c - mu*tan(mu/(2*a)*log(x) + C1))/(2*b*x)) + return [gensol] + + +class RationalRiccati(SinglePatternODESolver): + r""" + Gives general solutions to the first order Riccati differential + equations that have atleast one rational particular solution. + + .. math :: y' = b_0(x) + b_1(x) y + b_2(x) y^2 + + where `b_0`, `b_1` and `b_2` are rational functions of `x` + with `b_2 \ne 0` (`b_2 = 0` would make it a Bernoulli equation). + + Examples + ======== + + >>> from sympy import Symbol, Function, dsolve, checkodesol + >>> f = Function('f') + >>> x = Symbol('x') + + >>> eq = -x**4*f(x)**2 + x**3*f(x).diff(x) + x**2*f(x) + 20 + >>> sol = dsolve(eq, hint="1st_rational_riccati") + >>> sol + Eq(f(x), (4*C1 - 5*x**9 - 4)/(x**2*(C1 + x**9 - 1))) + >>> checkodesol(eq, sol) + (True, 0) + + References + ========== + + - Riccati ODE: https://en.wikipedia.org/wiki/Riccati_equation + - N. Thieu Vo - Rational and Algebraic Solutions of First-Order Algebraic ODEs: + Algorithm 11, pp. 78 - https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf + """ + has_integral = False + hint = "1st_rational_riccati" + order = [1] + + def _wilds(self, f, x, order): + b0 = Wild('b0', exclude=[f(x), f(x).diff(x)]) + b1 = Wild('b1', exclude=[f(x), f(x).diff(x)]) + b2 = Wild('b2', exclude=[f(x), f(x).diff(x)]) + return (b0, b1, b2) + + def _equation(self, fx, x, order): + b0, b1, b2 = self.wilds() + return fx.diff(x) - b0 - b1*fx - b2*fx**2 + + def _matches(self): + eq = self.ode_problem.eq_expanded + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + + if order != 1: + return False + + match, funcs = match_riccati(eq, f, x) + if not match: + return False + _b0, _b1, _b2 = funcs + b0, b1, b2 = self.wilds() + self._wilds_match = match = {b0: _b0, b1: _b1, b2: _b2} + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + # Match the equation + b0, b1, b2 = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + return solve_riccati(fx, x, b0, b1, b2, gensol=True) + + +class SecondNonlinearAutonomousConserved(SinglePatternODESolver): + r""" + Gives solution for the autonomous second order nonlinear + differential equation of the form + + .. math :: f''(x) = g(f(x)) + + The solution for this differential equation can be computed + by multiplying by `f'(x)` and integrating on both sides, + converting it into a first order differential equation. + + Examples + ======== + + >>> from sympy import Function, symbols, dsolve + >>> f, g = symbols('f g', cls=Function) + >>> x = symbols('x') + + >>> eq = f(x).diff(x, 2) - g(f(x)) + >>> dsolve(eq, simplify=False) + [Eq(Integral(1/sqrt(C1 + 2*Integral(g(_u), _u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 + 2*Integral(g(_u), _u)), (_u, f(x))), C2 - x)] + + >>> from sympy import exp, log + >>> eq = f(x).diff(x, 2) - exp(f(x)) + log(f(x)) + >>> dsolve(eq, simplify=False) + [Eq(Integral(1/sqrt(-2*_u*log(_u) + 2*_u + C1 + 2*exp(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(-2*_u*log(_u) + 2*_u + C1 + 2*exp(_u)), (_u, f(x))), C2 - x)] + + References + ========== + + - https://eqworld.ipmnet.ru/en/solutions/ode/ode0301.pdf + """ + hint = "2nd_nonlinear_autonomous_conserved" + has_integral = True + order = [2] + + def _wilds(self, f, x, order): + fy = Wild('fy', exclude=[0, f(x).diff(x), f(x).diff(x, 2)]) + return (fy, ) + + def _equation(self, fx, x, order): + fy = self.wilds()[0] + return fx.diff(x, 2) + fy + + def _verify(self, fx): + return self.ode_problem.is_autonomous + + def _get_general_solution(self, *, simplify_flag: bool = True): + g = self.wilds_match()[0] + fx = self.ode_problem.func + x = self.ode_problem.sym + u = Dummy('u') + g = g.subs(fx, u) + C1, C2 = self.ode_problem.get_numbered_constants(num=2) + inside = -2*Integral(g, u) + C1 + lhs = Integral(1/sqrt(inside), (u, fx)) + return [Eq(lhs, C2 + x), Eq(lhs, C2 - x)] + + +class Liouville(SinglePatternODESolver): + r""" + Solves 2nd order Liouville differential equations. + + The general form of a Liouville ODE is + + .. math:: \frac{d^2 y}{dx^2} + g(y) \left(\! + \frac{dy}{dx}\!\right)^2 + h(x) + \frac{dy}{dx}\text{.} + + The general solution is: + + >>> from sympy import Function, dsolve, Eq, pprint, diff + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = Eq(diff(f(x),x,x) + g(f(x))*diff(f(x),x)**2 + + ... h(x)*diff(f(x),x), 0) + >>> pprint(genform) + 2 2 + /d \ d d + g(f(x))*|--(f(x))| + h(x)*--(f(x)) + ---(f(x)) = 0 + \dx / dx 2 + dx + >>> pprint(dsolve(genform, f(x), hint='Liouville_Integral')) + f(x) + / / + | | + | / | / + | | | | + | - | h(x) dx | | g(y) dy + | | | | + | / | / + C1 + C2* | e dx + | e dy = 0 + | | + / / + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(diff(f(x), x, x) + diff(f(x), x)**2/f(x) + + ... diff(f(x), x)/x, f(x), hint='Liouville')) + ________________ ________________ + [f(x) = -\/ C1 + C2*log(x) , f(x) = \/ C1 + C2*log(x) ] + + References + ========== + + - Goldstein and Braun, "Advanced Methods for the Solution of Differential + Equations", pp. 98 + - https://www.maplesoft.com/support/help/Maple/view.aspx?path=odeadvisor/Liouville + + # indirect doctest + + """ + hint = "Liouville" + has_integral = True + order = [2] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + k = Wild('k', exclude=[f(x).diff(x)]) + return d, e, k + + def _equation(self, fx, x, order): + # Liouville ODE in the form + # f(x).diff(x, 2) + g(f(x))*(f(x).diff(x))**2 + h(x)*f(x).diff(x) + # See Goldstein and Braun, "Advanced Methods for the Solution of + # Differential Equations", pg. 98 + d, e, k = self.wilds() + return d*fx.diff(x, 2) + e*fx.diff(x)**2 + k*fx.diff(x) + + def _verify(self, fx): + d, e, k = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.g = simplify(e/d).subs(fx, self.y) + self.h = simplify(k/d).subs(fx, self.y) + if self.y in self.h.free_symbols or x in self.g.free_symbols: + return False + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, k = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + C1, C2 = self.ode_problem.get_numbered_constants(num=2) + int = Integral(exp(Integral(self.g, self.y)), (self.y, None, fx)) + gen_sol = Eq(int + C1*Integral(exp(-Integral(self.h, x)), x) + C2, 0) + + return [gen_sol] + + +class Separable(SinglePatternODESolver): + r""" + Solves separable 1st order differential equations. + + This is any differential equation that can be written as `P(y) + \tfrac{dy}{dx} = Q(x)`. The solution can then just be found by + rearranging terms and integrating: `\int P(y) \,dy = \int Q(x) \,dx`. + This hint uses :py:meth:`sympy.simplify.simplify.separatevars` as its back + end, so if a separable equation is not caught by this solver, it is most + likely the fault of that function. + :py:meth:`~sympy.simplify.simplify.separatevars` is + smart enough to do most expansion and factoring necessary to convert a + separable equation `F(x, y)` into the proper form `P(x)\cdot{}Q(y)`. The + general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x + >>> a, b, c, d, f = map(Function, ['a', 'b', 'c', 'd', 'f']) + >>> genform = Eq(a(x)*b(f(x))*f(x).diff(x), c(x)*d(f(x))) + >>> pprint(genform) + d + a(x)*b(f(x))*--(f(x)) = c(x)*d(f(x)) + dx + >>> pprint(dsolve(genform, f(x), hint='separable_Integral')) + f(x) + / / + | | + | b(y) | c(x) + | ---- dy = C1 + | ---- dx + | d(y) | a(x) + | | + / / + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(Eq(f(x)*f(x).diff(x) + x, 3*x*f(x)**2), f(x), + ... hint='separable', simplify=False)) + / 2 \ 2 + log\3*f (x) - 1/ x + ---------------- = C1 + -- + 6 2 + + References + ========== + + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 52 + + # indirect doctest + + """ + hint = "separable" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + d, e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + d = separatevars(d.subs(fx, self.y)) + e = separatevars(e.subs(fx, self.y)) + # m1[coeff]*m1[x]*m1[y] + m2[coeff]*m2[x]*m2[y]*y' + self.m1 = separatevars(d, dict=True, symbols=(x, self.y)) + self.m2 = separatevars(e, dict=True, symbols=(x, self.y)) + return bool(self.m1 and self.m2) + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + return self.m1, self.m2, x, fx + + def _get_general_solution(self, *, simplify_flag: bool = True): + m1, m2, x, fx = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral(m2['coeff']*m2[self.y]/m1[self.y], + (self.y, None, fx)) + gen_sol = Eq(int, Integral(-m1['coeff']*m1[x]/ + m2[x], x) + C1) + return [gen_sol] + + +class SeparableReduced(Separable): + r""" + Solves a differential equation that can be reduced to the separable form. + + The general form of this equation is + + .. math:: y' + (y/x) H(x^n y) = 0\text{}. + + This can be solved by substituting `u(y) = x^n y`. The equation then + reduces to the separable form `\frac{u'}{u (\mathrm{power} - H(u))} - + \frac{1}{x} = 0`. + + The general solution is: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x, n + >>> f, g = map(Function, ['f', 'g']) + >>> genform = f(x).diff(x) + (f(x)/x)*g(x**n*f(x)) + >>> pprint(genform) + / n \ + d f(x)*g\x *f(x)/ + --(f(x)) + --------------- + dx x + >>> pprint(dsolve(genform, hint='separable_reduced')) + n + x *f(x) + / + | + | 1 + | ------------ dy = C1 + log(x) + | y*(n - g(y)) + | + / + + See Also + ======== + :obj:`sympy.solvers.ode.single.Separable` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> d = f(x).diff(x) + >>> eq = (x - x**2*f(x))*d - f(x) + >>> dsolve(eq, hint='separable_reduced') + [Eq(f(x), (1 - sqrt(C1*x**2 + 1))/x), Eq(f(x), (sqrt(C1*x**2 + 1) + 1)/x)] + >>> pprint(dsolve(eq, hint='separable_reduced')) + ___________ ___________ + / 2 / 2 + 1 - \/ C1*x + 1 \/ C1*x + 1 + 1 + [f(x) = ------------------, f(x) = ------------------] + x x + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "separable_reduced" + has_integral = True + order = [1] + + def _degree(self, expr, x): + # Made this function to calculate the degree of + # x in an expression. If expr will be of form + # x**p*y, (wheare p can be variables/rationals) then it + # will return p. + for val in expr: + if val.has(x): + if isinstance(val, Pow) and val.as_base_exp()[0] == x: + return (val.as_base_exp()[1]) + elif val == x: + return (val.as_base_exp()[1]) + else: + return self._degree(val.args, x) + return 0 + + def _powers(self, expr): + # this function will return all the different relative power of x w.r.t f(x). + # expr = x**p * f(x)**q then it will return {p/q}. + pows = set() + fx = self.ode_problem.func + x = self.ode_problem.sym + self.y = Dummy('y') + if isinstance(expr, Add): + exprs = expr.atoms(Add) + elif isinstance(expr, Mul): + exprs = expr.atoms(Mul) + elif isinstance(expr, Pow): + exprs = expr.atoms(Pow) + else: + exprs = {expr} + + for arg in exprs: + if arg.has(x): + _, u = arg.as_independent(x, fx) + pow = self._degree((u.subs(fx, self.y), ), x)/self._degree((u.subs(fx, self.y), ), self.y) + pows.add(pow) + return pows + + def _verify(self, fx): + num, den = self.wilds_match() + x = self.ode_problem.sym + factor = simplify(x/fx*num/den) + # Try representing factor in terms of x^n*y + # where n is lowest power of x in factor; + # first remove terms like sqrt(2)*3 from factor.atoms(Mul) + num, dem = factor.as_numer_denom() + num = expand(num) + dem = expand(dem) + pows = self._powers(num) + pows.update(self._powers(dem)) + pows = list(pows) + if(len(pows)==1) and pows[0]!=zoo: + self.t = Dummy('t') + self.r2 = {'t': self.t} + num = num.subs(x**pows[0]*fx, self.t) + dem = dem.subs(x**pows[0]*fx, self.t) + test = num/dem + free = test.free_symbols + if len(free) == 1 and free.pop() == self.t: + self.r2.update({'power' : pows[0], 'u' : test}) + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + u = self.r2['u'].subs(self.r2['t'], self.y) + ycoeff = 1/(self.y*(self.r2['power'] - u)) + m1 = {self.y: 1, x: -1/x, 'coeff': 1} + m2 = {self.y: ycoeff, x: 1, 'coeff': 1} + return m1, m2, x, x**self.r2['power']*fx + + +class HomogeneousCoeffSubsDepDivIndep(SinglePatternODESolver): + r""" + Solves a 1st order differential equation with homogeneous coefficients + using the substitution `u_1 = \frac{\text{}}{\text{}}`. + + This is a differential equation + + .. math:: P(x, y) + Q(x, y) dy/dx = 0 + + such that `P` and `Q` are homogeneous and of the same order. A function + `F(x, y)` is homogeneous of order `n` if `F(x t, y t) = t^n F(x, y)`. + Equivalently, `F(x, y)` can be rewritten as `G(y/x)` or `H(x/y)`. See + also the docstring of :py:meth:`~sympy.solvers.ode.homogeneous_order`. + + If the coefficients `P` and `Q` in the differential equation above are + homogeneous functions of the same order, then it can be shown that the + substitution `y = u_1 x` (i.e. `u_1 = y/x`) will turn the differential + equation into an equation separable in the variables `x` and `u`. If + `h(u_1)` is the function that results from making the substitution `u_1 = + f(x)/x` on `P(x, f(x))` and `g(u_2)` is the function that results from the + substitution on `Q(x, f(x))` in the differential equation `P(x, f(x)) + + Q(x, f(x)) f'(x) = 0`, then the general solution is:: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = g(f(x)/x) + h(f(x)/x)*f(x).diff(x) + >>> pprint(genform) + /f(x)\ /f(x)\ d + g|----| + h|----|*--(f(x)) + \ x / \ x / dx + >>> pprint(dsolve(genform, f(x), + ... hint='1st_homogeneous_coeff_subs_dep_div_indep_Integral')) + f(x) + ---- + x + / + | + | -h(u1) + log(x) = C1 + | ---------------- d(u1) + | u1*h(u1) + g(u1) + | + / + + Where `u_1 h(u_1) + g(u_1) \ne 0` and `x \ne 0`. + + See also the docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep`. + + Examples + ======== + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_subs_dep_div_indep', simplify=False)) + / 3 \ + |3*f(x) f (x)| + log|------ + -----| + | x 3 | + \ x / + log(x) = log(C1) - ------------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_subs_dep_div_indep" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.d = separatevars(self.d.subs(fx, self.y)) + self.e = separatevars(self.e.subs(fx, self.y)) + ordera = homogeneous_order(self.d, x, self.y) + orderb = homogeneous_order(self.e, x, self.y) + if ordera == orderb and ordera is not None: + self.u = Dummy('u') + if simplify((self.d + self.u*self.e).subs({x: 1, self.y: self.u})) != 0: + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + xarg = 0 + yarg = 0 + return [self.d, self.e, fx, x, self.u, self.u1, self.y, xarg, yarg] + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, fx, x, u, u1, y, xarg, yarg = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral( + (-e/(d + u1*e)).subs({x: 1, y: u1}), + (u1, None, fx/x)) + sol = logcombine(Eq(log(x), int + log(C1)), force=True) + gen_sol = sol.subs(fx, u).subs(((u, u - yarg), (x, x - xarg), (u, fx))) + return [gen_sol] + + +class HomogeneousCoeffSubsIndepDivDep(SinglePatternODESolver): + r""" + Solves a 1st order differential equation with homogeneous coefficients + using the substitution `u_2 = \frac{\text{}}{\text{}}`. + + This is a differential equation + + .. math:: P(x, y) + Q(x, y) dy/dx = 0 + + such that `P` and `Q` are homogeneous and of the same order. A function + `F(x, y)` is homogeneous of order `n` if `F(x t, y t) = t^n F(x, y)`. + Equivalently, `F(x, y)` can be rewritten as `G(y/x)` or `H(x/y)`. See + also the docstring of :py:meth:`~sympy.solvers.ode.homogeneous_order`. + + If the coefficients `P` and `Q` in the differential equation above are + homogeneous functions of the same order, then it can be shown that the + substitution `x = u_2 y` (i.e. `u_2 = x/y`) will turn the differential + equation into an equation separable in the variables `y` and `u_2`. If + `h(u_2)` is the function that results from making the substitution `u_2 = + x/f(x)` on `P(x, f(x))` and `g(u_2)` is the function that results from the + substitution on `Q(x, f(x))` in the differential equation `P(x, f(x)) + + Q(x, f(x)) f'(x) = 0`, then the general solution is: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = g(x/f(x)) + h(x/f(x))*f(x).diff(x) + >>> pprint(genform) + / x \ / x \ d + g|----| + h|----|*--(f(x)) + \f(x)/ \f(x)/ dx + >>> pprint(dsolve(genform, f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep_Integral')) + x + ---- + f(x) + / + | + | -g(u1) + | ---------------- d(u1) + | u1*g(u1) + h(u1) + | + / + + f(x) = C1*e + + Where `u_1 g(u_1) + h(u_1) \ne 0` and `f(x) \ne 0`. + + See also the docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep`. + + Examples + ======== + + >>> from sympy import Function, pprint, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep', + ... simplify=False)) + / 2 \ + |3*x | + log|----- + 1| + | 2 | + \f (x) / + log(f(x)) = log(C1) - -------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_subs_indep_div_dep" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.d = separatevars(self.d.subs(fx, self.y)) + self.e = separatevars(self.e.subs(fx, self.y)) + ordera = homogeneous_order(self.d, x, self.y) + orderb = homogeneous_order(self.e, x, self.y) + if ordera == orderb and ordera is not None: + self.u = Dummy('u') + if simplify((self.e + self.u*self.d).subs({x: self.u, self.y: 1})) != 0: + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + xarg = 0 + yarg = 0 + return [self.d, self.e, fx, x, self.u, self.u1, self.y, xarg, yarg] + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, fx, x, u, u1, y, xarg, yarg = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral(simplify((-d/(e + u1*d)).subs({x: u1, y: 1})), (u1, None, x/fx)) # type: ignore + sol = logcombine(Eq(log(fx), int + log(C1)), force=True) + gen_sol = sol.subs(fx, u).subs(((u, u - yarg), (x, x - xarg), (u, fx))) + return [gen_sol] + + +class HomogeneousCoeffBest(HomogeneousCoeffSubsIndepDivDep, HomogeneousCoeffSubsDepDivIndep): + r""" + Returns the best solution to an ODE from the two hints + ``1st_homogeneous_coeff_subs_dep_div_indep`` and + ``1st_homogeneous_coeff_subs_indep_div_dep``. + + This is as determined by :py:meth:`~sympy.solvers.ode.ode.ode_sol_simplicity`. + + See the + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep` + and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` + docstrings for more information on these hints. Note that there is no + ``ode_1st_homogeneous_coeff_best_Integral`` hint. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_best', simplify=False)) + / 2 \ + |3*x | + log|----- + 1| + | 2 | + \f (x) / + log(f(x)) = log(C1) - -------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_best" + has_integral = False + order = [1] + + def _verify(self, fx): + return HomogeneousCoeffSubsIndepDivDep._verify(self, fx) and \ + HomogeneousCoeffSubsDepDivIndep._verify(self, fx) + + def _get_general_solution(self, *, simplify_flag: bool = True): + # There are two substitutions that solve the equation, u1=y/x and u2=x/y + # # They produce different integrals, so try them both and see which + # # one is easier + sol1 = HomogeneousCoeffSubsIndepDivDep._get_general_solution(self) + sol2 = HomogeneousCoeffSubsDepDivIndep._get_general_solution(self) + fx = self.ode_problem.func + if simplify_flag: + sol1 = odesimp(self.ode_problem.eq, *sol1, fx, "1st_homogeneous_coeff_subs_indep_div_dep") + sol2 = odesimp(self.ode_problem.eq, *sol2, fx, "1st_homogeneous_coeff_subs_dep_div_indep") + # XXX: not simplify should be not simplify_flag. mypy correctly complains + return min([sol1, sol2], key=lambda x: ode_sol_simplicity(x, fx, trysolving=not simplify)) # type: ignore + + +class LinearCoefficients(HomogeneousCoeffBest): + r""" + Solves a differential equation with linear coefficients. + + The general form of a differential equation with linear coefficients is + + .. math:: y' + F\left(\!\frac{a_1 x + b_1 y + c_1}{a_2 x + b_2 y + + c_2}\!\right) = 0\text{,} + + where `a_1`, `b_1`, `c_1`, `a_2`, `b_2`, `c_2` are constants and `a_1 b_2 + - a_2 b_1 \ne 0`. + + This can be solved by substituting: + + .. math:: x = x' + \frac{b_2 c_1 - b_1 c_2}{a_2 b_1 - a_1 b_2} + + y = y' + \frac{a_1 c_2 - a_2 c_1}{a_2 b_1 - a_1 + b_2}\text{.} + + This substitution reduces the equation to a homogeneous differential + equation. + + See Also + ======== + :obj:`sympy.solvers.ode.single.HomogeneousCoeffBest` + :obj:`sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep` + :obj:`sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> df = f(x).diff(x) + >>> eq = (x + f(x) + 1)*df + (f(x) - 6*x + 1) + >>> dsolve(eq, hint='linear_coefficients') + [Eq(f(x), -x - sqrt(C1 + 7*x**2) - 1), Eq(f(x), -x + sqrt(C1 + 7*x**2) - 1)] + >>> pprint(dsolve(eq, hint='linear_coefficients')) + ___________ ___________ + / 2 / 2 + [f(x) = -x - \/ C1 + 7*x - 1, f(x) = -x + \/ C1 + 7*x - 1] + + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "linear_coefficients" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + a, b = self.wilds() + F = self.d/self.e + x = self.ode_problem.sym + params = self._linear_coeff_match(F, fx) + if params: + self.xarg, self.yarg = params + u = Dummy('u') + t = Dummy('t') + self.y = Dummy('y') + # Dummy substitution for df and f(x). + dummy_eq = self.ode_problem.eq.subs(((fx.diff(x), t), (fx, u))) + reps = ((x, x + self.xarg), (u, u + self.yarg), (t, fx.diff(x)), (u, fx)) + dummy_eq = simplify(dummy_eq.subs(reps)) + # get the re-cast values for e and d + r2 = collect(expand(dummy_eq), [fx.diff(x), fx]).match(a*fx.diff(x) + b) + if r2: + self.d, self.e = r2[b], r2[a] + orderd = homogeneous_order(self.d, x, fx) + ordere = homogeneous_order(self.e, x, fx) + if orderd == ordere and orderd is not None: + self.d = self.d.subs(fx, self.y) + self.e = self.e.subs(fx, self.y) + return True + return False + return False + + def _linear_coeff_match(self, expr, func): + r""" + Helper function to match hint ``linear_coefficients``. + + Matches the expression to the form `(a_1 x + b_1 f(x) + c_1)/(a_2 x + b_2 + f(x) + c_2)` where the following conditions hold: + + 1. `a_1`, `b_1`, `c_1`, `a_2`, `b_2`, `c_2` are Rationals; + 2. `c_1` or `c_2` are not equal to zero; + 3. `a_2 b_1 - a_1 b_2` is not equal to zero. + + Return ``xarg``, ``yarg`` where + + 1. ``xarg`` = `(b_2 c_1 - b_1 c_2)/(a_2 b_1 - a_1 b_2)` + 2. ``yarg`` = `(a_1 c_2 - a_2 c_1)/(a_2 b_1 - a_1 b_2)` + + + Examples + ======== + + >>> from sympy import Function, sin + >>> from sympy.abc import x + >>> from sympy.solvers.ode.single import LinearCoefficients + >>> f = Function('f') + >>> eq = (-25*f(x) - 8*x + 62)/(4*f(x) + 11*x - 11) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + (1/9, 22/9) + >>> eq = sin((-5*f(x) - 8*x + 6)/(4*f(x) + x - 1)) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + (19/27, 2/27) + >>> eq = sin(f(x)/x) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + + """ + f = func.func + x = func.args[0] + def abc(eq): + r''' + Internal function of _linear_coeff_match + that returns Rationals a, b, c + if eq is a*x + b*f(x) + c, else None. + ''' + eq = _mexpand(eq) + c = eq.as_independent(x, f(x), as_Add=True)[0] + if not c.is_Rational: + return + a = eq.coeff(x) + if not a.is_Rational: + return + b = eq.coeff(f(x)) + if not b.is_Rational: + return + if eq == a*x + b*f(x) + c: + return a, b, c + + def match(arg): + r''' + Internal function of _linear_coeff_match that returns Rationals a1, + b1, c1, a2, b2, c2 and a2*b1 - a1*b2 of the expression (a1*x + b1*f(x) + + c1)/(a2*x + b2*f(x) + c2) if one of c1 or c2 and a2*b1 - a1*b2 is + non-zero, else None. + ''' + n, d = arg.together().as_numer_denom() + m = abc(n) + if m is not None: + a1, b1, c1 = m + m = abc(d) + if m is not None: + a2, b2, c2 = m + d = a2*b1 - a1*b2 + if (c1 or c2) and d: + return a1, b1, c1, a2, b2, c2, d + + m = [fi.args[0] for fi in expr.atoms(Function) if fi.func != f and + len(fi.args) == 1 and not fi.args[0].is_Function] or {expr} + m1 = match(m.pop()) + if m1 and all(match(mi) == m1 for mi in m): + a1, b1, c1, a2, b2, c2, denom = m1 + return (b2*c1 - b1*c2)/denom, (a1*c2 - a2*c1)/denom + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + u = Dummy('u') + return [self.d, self.e, fx, x, u, self.u1, self.y, self.xarg, self.yarg] + + +class NthOrderReducible(SingleODESolver): + r""" + Solves ODEs that only involve derivatives of the dependent variable using + a substitution of the form `f^n(x) = g(x)`. + + For example any second order ODE of the form `f''(x) = h(f'(x), x)` can be + transformed into a pair of 1st order ODEs `g'(x) = h(g(x), x)` and + `f'(x) = g(x)`. Usually the 1st order ODE for `g` is easier to solve. If + that gives an explicit solution for `g` then `f` is found simply by + integration. + + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = Eq(x*f(x).diff(x)**2 + f(x).diff(x, 2), 0) + >>> dsolve(eq, f(x), hint='nth_order_reducible') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), C1 - sqrt(-1/C2)*log(-C2*sqrt(-1/C2) + x) + sqrt(-1/C2)*log(C2*sqrt(-1/C2) + x)) + + """ + hint = "nth_order_reducible" + has_integral = False + + def _matches(self): + # Any ODE that can be solved with a substitution and + # repeated integration e.g.: + # `d^2/dx^2(y) + x*d/dx(y) = constant + #f'(x) must be finite for this to work + eq = self.ode_problem.eq_preprocessed + func = self.ode_problem.func + x = self.ode_problem.sym + r""" + Matches any differential equation that can be rewritten with a smaller + order. Only derivatives of ``func`` alone, wrt a single variable, + are considered, and only in them should ``func`` appear. + """ + # ODE only handles functions of 1 variable so this affirms that state + assert len(func.args) == 1 + vc = [d.variable_count[0] for d in eq.atoms(Derivative) + if d.expr == func and len(d.variable_count) == 1] + ords = [c for v, c in vc if v == x] + if len(ords) < 2: + return False + self.smallest = min(ords) + # make sure func does not appear outside of derivatives + D = Dummy() + if eq.subs(func.diff(x, self.smallest), D).has(func): + return False + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + n = self.smallest + # get a unique function name for g + names = [a.name for a in eq.atoms(AppliedUndef)] + while True: + name = Dummy().name + if name not in names: + g = Function(name) + break + w = f(x).diff(x, n) + geq = eq.subs(w, g(x)) + gsol = dsolve(geq, g(x)) + + if not isinstance(gsol, list): + gsol = [gsol] + + # Might be multiple solutions to the reduced ODE: + fsol = [] + for gsoli in gsol: + fsoli = dsolve(gsoli.subs(g(x), w), f(x)) # or do integration n times + fsol.append(fsoli) + + return fsol + + +class SecondHypergeometric(SingleODESolver): + r""" + Solves 2nd order linear differential equations. + + It computes special function solutions which can be expressed using the + 2F1, 1F1 or 0F1 hypergeometric functions. + + .. math:: y'' + A(x) y' + B(x) y = 0\text{,} + + where `A` and `B` are rational functions. + + These kinds of differential equations have solution of non-Liouvillian form. + + Given linear ODE can be obtained from 2F1 given by + + .. math:: (x^2 - x) y'' + ((a + b + 1) x - c) y' + b a y = 0\text{,} + + where {a, b, c} are arbitrary constants. + + Notes + ===== + + The algorithm should find any solution of the form + + .. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,} + + where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function". + Currently only the 2F1 case is implemented in SymPy but the other cases are + described in the paper and could be implemented in future (contributions + welcome!). + + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = (x*x - x)*f(x).diff(x,2) + (5*x - 1)*f(x).diff(x) + 4*f(x) + >>> pprint(dsolve(eq, f(x), '2nd_hypergeometric')) + _ + / / 4 \\ |_ /-1, -1 | \ + |C1 + C2*|log(x) + -----||* | | | x| + \ \ x + 1// 2 1 \ 1 | / + f(x) = -------------------------------------------- + 3 + (x - 1) + + + References + ========== + + - "Non-Liouvillian solutions for second order linear ODEs" by L. Chan, E.S. Cheb-Terrab + + """ + hint = "2nd_hypergeometric" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + func = self.ode_problem.func + r = match_2nd_hypergeometric(eq, func) + self.match_object = None + if r: + A, B = r + d = equivalence_hypergeometric(A, B, func) + if d: + if d['type'] == "2F1": + self.match_object = match_2nd_2F1_hypergeometric(d['I0'], d['k'], d['sing_point'], func) + if self.match_object is not None: + self.match_object.update({'A':A, 'B':B}) + # We can extend it for 1F1 and 0F1 type also. + return self.match_object is not None + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + func = self.ode_problem.func + if self.match_object['type'] == "2F1": + sol = get_sol_2F1_hypergeometric(eq, func, self.match_object) + if sol is None: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the hypergeometric method") + + return [sol] + + +class NthLinearConstantCoeffHomogeneous(SingleODESolver): + r""" + Solves an `n`\th order linear homogeneous differential equation with + constant coefficients. + + This is an equation of the form + + .. math:: a_n f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + + a_0 f(x) = 0\text{.} + + These equations can be solved in a general manner, by taking the roots of + the characteristic equation `a_n m^n + a_{n-1} m^{n-1} + \cdots + a_1 m + + a_0 = 0`. The solution will then be the sum of `C_n x^i e^{r x}` terms, + for each where `C_n` is an arbitrary constant, `r` is a root of the + characteristic equation and `i` is one of each from 0 to the multiplicity + of the root - 1 (for example, a root 3 of multiplicity 2 would create the + terms `C_1 e^{3 x} + C_2 x e^{3 x}`). The exponential is usually expanded + for complex roots using Euler's equation `e^{I x} = \cos(x) + I \sin(x)`. + Complex roots always come in conjugate pairs in polynomials with real + coefficients, so the two roots will be represented (after simplifying the + constants) as `e^{a x} \left(C_1 \cos(b x) + C_2 \sin(b x)\right)`. + + If SymPy cannot find exact roots to the characteristic equation, a + :py:class:`~sympy.polys.rootoftools.ComplexRootOf` instance will be return + instead. + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(f(x).diff(x, 5) + 10*f(x).diff(x) - 2*f(x), f(x), + ... hint='nth_linear_constant_coeff_homogeneous') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), C5*exp(x*CRootOf(_x**5 + 10*_x - 2, 0)) + + (C1*sin(x*im(CRootOf(_x**5 + 10*_x - 2, 1))) + + C2*cos(x*im(CRootOf(_x**5 + 10*_x - 2, 1))))*exp(x*re(CRootOf(_x**5 + 10*_x - 2, 1))) + + (C3*sin(x*im(CRootOf(_x**5 + 10*_x - 2, 3))) + + C4*cos(x*im(CRootOf(_x**5 + 10*_x - 2, 3))))*exp(x*re(CRootOf(_x**5 + 10*_x - 2, 3)))) + + Note that because this method does not involve integration, there is no + ``nth_linear_constant_coeff_homogeneous_Integral`` hint. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 4) + 2*f(x).diff(x, 3) - + ... 2*f(x).diff(x, 2) - 6*f(x).diff(x) + 5*f(x), f(x), + ... hint='nth_linear_constant_coeff_homogeneous')) + x -2*x + f(x) = (C1 + C2*x)*e + (C3*sin(x) + C4*cos(x))*e + + References + ========== + + - https://en.wikipedia.org/wiki/Linear_differential_equation section: + Nonhomogeneous_equation_with_constant_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 211 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_homogeneous" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if not self.r[-1]: + return True + else: + return False + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + fx = self.ode_problem.func + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, fx, order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + gsol_rhs = Add(*[i*j for (i, j) in zip(constants, roots)]) + gsol = Eq(fx, gsol_rhs) + if simplify_flag: + gsol = _get_simplified_sol([gsol], fx, collectterms) + + return [gsol] + + +class NthLinearConstantCoeffVariationOfParameters(SingleODESolver): + r""" + Solves an `n`\th order linear differential equation with constant + coefficients using the method of variation of parameters. + + This method works on any differential equations of the form + + .. math:: f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + a_0 + f(x) = P(x)\text{.} + + This method works by assuming that the particular solution takes the form + + .. math:: \sum_{x=1}^{n} c_i(x) y_i(x)\text{,} + + where `y_i` is the `i`\th solution to the homogeneous equation. The + solution is then solved using Wronskian's and Cramer's Rule. The + particular solution is given by + + .. math:: \sum_{x=1}^n \left( \int \frac{W_i(x)}{W(x)} \,dx + \right) y_i(x) \text{,} + + where `W(x)` is the Wronskian of the fundamental system (the system of `n` + linearly independent solutions to the homogeneous equation), and `W_i(x)` + is the Wronskian of the fundamental system with the `i`\th column replaced + with `[0, 0, \cdots, 0, P(x)]`. + + This method is general enough to solve any `n`\th order inhomogeneous + linear differential equation with constant coefficients, but sometimes + SymPy cannot simplify the Wronskian well enough to integrate it. If this + method hangs, try using the + ``nth_linear_constant_coeff_variation_of_parameters_Integral`` hint and + simplifying the integrals manually. Also, prefer using + ``nth_linear_constant_coeff_undetermined_coefficients`` when it + applies, because it does not use integration, making it faster and more + reliable. + + Warning, using simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters' in + :py:meth:`~sympy.solvers.ode.dsolve` may cause it to hang, because it will + not attempt to simplify the Wronskian before integrating. It is + recommended that you only use simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters_Integral' for this + method, especially if the solution to the homogeneous equation has + trigonometric functions in it. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint, exp, log + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 3) - 3*f(x).diff(x, 2) + + ... 3*f(x).diff(x) - f(x) - exp(x)*log(x), f(x), + ... hint='nth_linear_constant_coeff_variation_of_parameters')) + / / / x*log(x) 11*x\\\ x + f(x) = |C1 + x*|C2 + x*|C3 + -------- - ----|||*e + \ \ \ 6 36 /// + + References + ========== + + - https://en.wikipedia.org/wiki/Variation_of_parameters + - https://planetmath.org/VariationOfParameters + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 233 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_variation_of_parameters" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if self.r[-1]: + return True + else: + return False + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, f(x), order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + homogen_sol_rhs = Add(*[i*j for (i, j) in zip(constants, roots)]) + homogen_sol = Eq(f(x), homogen_sol_rhs) + homogen_sol = _solve_variation_of_parameters(eq, f(x), roots, homogen_sol, order, self.r, simplify_flag) + if simplify_flag: + homogen_sol = _get_simplified_sol([homogen_sol], f(x), collectterms) + return [homogen_sol] + + +class NthLinearConstantCoeffUndeterminedCoefficients(SingleODESolver): + r""" + Solves an `n`\th order linear differential equation with constant + coefficients using the method of undetermined coefficients. + + This method works on differential equations of the form + + .. math:: a_n f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + + a_0 f(x) = P(x)\text{,} + + where `P(x)` is a function that has a finite number of linearly + independent derivatives. + + Functions that fit this requirement are finite sums functions of the form + `a x^i e^{b x} \sin(c x + d)` or `a x^i e^{b x} \cos(c x + d)`, where `i` + is a non-negative integer and `a`, `b`, `c`, and `d` are constants. For + example any polynomial in `x`, functions like `x^2 e^{2 x}`, `x \sin(x)`, + and `e^x \cos(x)` can all be used. Products of `\sin`'s and `\cos`'s have + a finite number of derivatives, because they can be expanded into `\sin(a + x)` and `\cos(b x)` terms. However, SymPy currently cannot do that + expansion, so you will need to manually rewrite the expression in terms of + the above to use this method. So, for example, you will need to manually + convert `\sin^2(x)` into `(1 + \cos(2 x))/2` to properly apply the method + of undetermined coefficients on it. + + This method works by creating a trial function from the expression and all + of its linear independent derivatives and substituting them into the + original ODE. The coefficients for each term will be a system of linear + equations, which are be solved for and substituted, giving the solution. + If any of the trial functions are linearly dependent on the solution to + the homogeneous equation, they are multiplied by sufficient `x` to make + them linearly independent. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint, exp, cos + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 2) + 2*f(x).diff(x) + f(x) - + ... 4*exp(-x)*x**2 + cos(2*x), f(x), + ... hint='nth_linear_constant_coeff_undetermined_coefficients')) + / / 3\\ + | | x || -x 4*sin(2*x) 3*cos(2*x) + f(x) = |C1 + x*|C2 + --||*e - ---------- + ---------- + \ \ 3 // 25 25 + + References + ========== + + - https://en.wikipedia.org/wiki/Method_of_undetermined_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 221 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_undetermined_coefficients" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + does_match = False + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if self.r[-1]: + eq_homogeneous = Add(eq, -self.r[-1]) + undetcoeff = _undetermined_coefficients_match(self.r[-1], x, func, eq_homogeneous) + if undetcoeff['test']: + self.trialset = undetcoeff['trialset'] + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, f(x), order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + homogen_sol_rhs = Add(*[i*j for (i, j) in zip(constants, roots)]) + homogen_sol = Eq(f(x), homogen_sol_rhs) + self.r.update({'list': roots, 'sol': homogen_sol, 'simpliy_flag': simplify_flag}) + gsol = _solve_undetermined_coefficients(eq, f(x), order, self.r, self.trialset) + if simplify_flag: + gsol = _get_simplified_sol([gsol], f(x), collectterms) + return [gsol] + + +class NthLinearEulerEqHomogeneous(SingleODESolver): + r""" + Solves an `n`\th order linear homogeneous variable-coefficient + Cauchy-Euler equidimensional ordinary differential equation. + + This is an equation with form `0 = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + These equations can be solved in a general manner, by substituting + solutions of the form `f(x) = x^r`, and deriving a characteristic equation + for `r`. When there are repeated roots, we include extra terms of the + form `C_{r k} \ln^k(x) x^r`, where `C_{r k}` is an arbitrary integration + constant, `r` is a root of the characteristic equation, and `k` ranges + over the multiplicity of `r`. In the cases where the roots are complex, + solutions of the form `C_1 x^a \sin(b \log(x)) + C_2 x^a \cos(b \log(x))` + are returned, based on expansions with Euler's formula. The general + solution is the sum of the terms found. If SymPy cannot find exact roots + to the characteristic equation, a + :py:obj:`~.ComplexRootOf` instance will be returned + instead. + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(4*x**2*f(x).diff(x, 2) + f(x), f(x), + ... hint='nth_linear_euler_eq_homogeneous') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), sqrt(x)*(C1 + C2*log(x))) + + Note that because this method does not involve integration, there is no + ``nth_linear_euler_eq_homogeneous_Integral`` hint. + + The following is for internal use: + + - ``returns = 'sol'`` returns the solution to the ODE. + - ``returns = 'list'`` returns a list of linearly independent solutions, + corresponding to the fundamental solution set, for use with non + homogeneous solution methods like variation of parameters and + undetermined coefficients. Note that, though the solutions should be + linearly independent, this function does not explicitly check that. You + can do ``assert simplify(wronskian(sollist)) != 0`` to check for linear + independence. Also, ``assert len(sollist) == order`` will need to pass. + - ``returns = 'both'``, return a dictionary ``{'sol': , + 'list': }``. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = f(x).diff(x, 2)*x**2 - 4*f(x).diff(x)*x + 6*f(x) + >>> pprint(dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_homogeneous')) + 2 + f(x) = x *(C1 + C2*x) + + References + ========== + + - https://en.wikipedia.org/wiki/Cauchy%E2%80%93Euler_equation + - C. Bender & S. Orszag, "Advanced Mathematical Methods for Scientists and + Engineers", Springer 1999, pp. 12 + + # indirect doctest + + """ + hint = "nth_linear_euler_eq_homogeneous" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if not self.r[-1]: + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + fx = self.ode_problem.func + eq = self.ode_problem.eq + homogen_sol = _get_euler_characteristic_eq_sols(eq, fx, self.r)[0] + return [homogen_sol] + + +class NthLinearEulerEqNonhomogeneousVariationOfParameters(SingleODESolver): + r""" + Solves an `n`\th order linear non homogeneous Cauchy-Euler equidimensional + ordinary differential equation using variation of parameters. + + This is an equation with form `g(x) = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + This method works by assuming that the particular solution takes the form + + .. math:: \sum_{x=1}^{n} c_i(x) y_i(x) {a_n} {x^n} \text{, } + + where `y_i` is the `i`\th solution to the homogeneous equation. The + solution is then solved using Wronskian's and Cramer's Rule. The + particular solution is given by multiplying eq given below with `a_n x^{n}` + + .. math:: \sum_{x=1}^n \left( \int \frac{W_i(x)}{W(x)} \, dx + \right) y_i(x) \text{, } + + where `W(x)` is the Wronskian of the fundamental system (the system of `n` + linearly independent solutions to the homogeneous equation), and `W_i(x)` + is the Wronskian of the fundamental system with the `i`\th column replaced + with `[0, 0, \cdots, 0, \frac{x^{- n}}{a_n} g{\left(x \right)}]`. + + This method is general enough to solve any `n`\th order inhomogeneous + linear differential equation, but sometimes SymPy cannot simplify the + Wronskian well enough to integrate it. If this method hangs, try using the + ``nth_linear_constant_coeff_variation_of_parameters_Integral`` hint and + simplifying the integrals manually. Also, prefer using + ``nth_linear_constant_coeff_undetermined_coefficients`` when it + applies, because it does not use integration, making it faster and more + reliable. + + Warning, using simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters' in + :py:meth:`~sympy.solvers.ode.dsolve` may cause it to hang, because it will + not attempt to simplify the Wronskian before integrating. It is + recommended that you only use simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters_Integral' for this + method, especially if the solution to the homogeneous equation has + trigonometric functions in it. + + Examples + ======== + + >>> from sympy import Function, dsolve, Derivative + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - x**4 + >>> dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_nonhomogeneous_variation_of_parameters').expand() + Eq(f(x), C1*x + C2*x**2 + x**4/6) + + """ + hint = "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if self.r[-1]: + does_match = True + + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + homogen_sol, roots = _get_euler_characteristic_eq_sols(eq, f(x), self.r) + self.r[-1] = self.r[-1]/self.r[order] + sol = _solve_variation_of_parameters(eq, f(x), roots, homogen_sol, order, self.r, simplify_flag) + + return [Eq(f(x), homogen_sol.rhs + (sol.rhs - homogen_sol.rhs)*self.r[order])] + + +class NthLinearEulerEqNonhomogeneousUndeterminedCoefficients(SingleODESolver): + r""" + Solves an `n`\th order linear non homogeneous Cauchy-Euler equidimensional + ordinary differential equation using undetermined coefficients. + + This is an equation with form `g(x) = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + These equations can be solved in a general manner, by substituting + solutions of the form `x = exp(t)`, and deriving a characteristic equation + of form `g(exp(t)) = b_0 f(t) + b_1 f'(t) + b_2 f''(t) \cdots` which can + be then solved by nth_linear_constant_coeff_undetermined_coefficients if + g(exp(t)) has finite number of linearly independent derivatives. + + Functions that fit this requirement are finite sums functions of the form + `a x^i e^{b x} \sin(c x + d)` or `a x^i e^{b x} \cos(c x + d)`, where `i` + is a non-negative integer and `a`, `b`, `c`, and `d` are constants. For + example any polynomial in `x`, functions like `x^2 e^{2 x}`, `x \sin(x)`, + and `e^x \cos(x)` can all be used. Products of `\sin`'s and `\cos`'s have + a finite number of derivatives, because they can be expanded into `\sin(a + x)` and `\cos(b x)` terms. However, SymPy currently cannot do that + expansion, so you will need to manually rewrite the expression in terms of + the above to use this method. So, for example, you will need to manually + convert `\sin^2(x)` into `(1 + \cos(2 x))/2` to properly apply the method + of undetermined coefficients on it. + + After replacement of x by exp(t), this method works by creating a trial function + from the expression and all of its linear independent derivatives and + substituting them into the original ODE. The coefficients for each term + will be a system of linear equations, which are be solved for and + substituted, giving the solution. If any of the trial functions are linearly + dependent on the solution to the homogeneous equation, they are multiplied + by sufficient `x` to make them linearly independent. + + Examples + ======== + + >>> from sympy import dsolve, Function, Derivative, log + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - log(x) + >>> dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients').expand() + Eq(f(x), C1*x + C2*x**2 + log(x)/2 + 3/4) + + """ + hint = "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if self.r[-1]: + e, re = posify(self.r[-1].subs(x, exp(x))) + undetcoeff = _undetermined_coefficients_match(e.subs(re), x) + if undetcoeff['test']: + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + chareq, eq, symbol = S.Zero, S.Zero, Dummy('x') + for i in self.r.keys(): + if i >= 0: + chareq += (self.r[i]*diff(x**symbol, x, i)*x**-symbol).expand() + + for i in range(1, degree(Poly(chareq, symbol))+1): + eq += chareq.coeff(symbol**i)*diff(f(x), x, i) + + if chareq.as_coeff_add(symbol)[0]: + eq += chareq.as_coeff_add(symbol)[0]*f(x) + e, re = posify(self.r[-1].subs(x, exp(x))) + eq += e.subs(re) + + self.const_undet_instance = NthLinearConstantCoeffUndeterminedCoefficients(SingleODEProblem(eq, f(x), x)) + sol = self.const_undet_instance.get_general_solution(simplify = simplify_flag)[0] + sol = sol.subs(x, log(x)) # type: ignore + sol = sol.subs(f(log(x)), f(x)).expand() # type: ignore + + return [sol] + + +class SecondLinearBessel(SingleODESolver): + r""" + Gives solution of the Bessel differential equation + + .. math :: x^2 \frac{d^2y}{dx^2} + x \frac{dy}{dx} y(x) + (x^2-n^2) y(x) + + if `n` is integer then the solution is of the form ``Eq(f(x), C0 besselj(n,x) + + C1 bessely(n,x))`` as both the solutions are linearly independent else if + `n` is a fraction then the solution is of the form ``Eq(f(x), C0 besselj(n,x) + + C1 besselj(-n,x))`` which can also transform into ``Eq(f(x), C0 besselj(n,x) + + C1 bessely(n,x))``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Symbol + >>> v = Symbol('v', positive=True) + >>> from sympy import dsolve, Function + >>> f = Function('f') + >>> y = f(x) + >>> genform = x**2*y.diff(x, 2) + x*y.diff(x) + (x**2 - v**2)*y + >>> dsolve(genform) + Eq(f(x), C1*besselj(v, x) + C2*bessely(v, x)) + + References + ========== + + https://math24.net/bessel-differential-equation.html + + """ + hint = "2nd_linear_bessel" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f.diff(x) + a = Wild('a', exclude=[f,df]) + b = Wild('b', exclude=[x, f,df]) + a4 = Wild('a4', exclude=[x,f,df]) + b4 = Wild('b4', exclude=[x,f,df]) + c4 = Wild('c4', exclude=[x,f,df]) + d4 = Wild('d4', exclude=[x,f,df]) + a3 = Wild('a3', exclude=[f, df, f.diff(x, 2)]) + b3 = Wild('b3', exclude=[f, df, f.diff(x, 2)]) + c3 = Wild('c3', exclude=[f, df, f.diff(x, 2)]) + deq = a3*(f.diff(x, 2)) + b3*df + c3*f + r = collect(eq, + [f.diff(x, 2), df, f]).match(deq) + if order == 2 and r: + if not all(r[key].is_polynomial() for key in r): + n, d = eq.as_numer_denom() + eq = expand(n) + r = collect(eq, + [f.diff(x, 2), df, f]).match(deq) + + if r and r[a3] != 0: + # leading coeff of f(x).diff(x, 2) + coeff = factor(r[a3]).match(a4*(x-b)**b4) + + if coeff: + # if coeff[b4] = 0 means constant coefficient + if coeff[b4] == 0: + return False + point = coeff[b] + else: + return False + + if point: + r[a3] = simplify(r[a3].subs(x, x+point)) + r[b3] = simplify(r[b3].subs(x, x+point)) + r[c3] = simplify(r[c3].subs(x, x+point)) + + # making a3 in the form of x**2 + r[a3] = cancel(r[a3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + r[b3] = cancel(r[b3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + r[c3] = cancel(r[c3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + # checking if b3 is of form c*(x-b) + coeff1 = factor(r[b3]).match(a4*(x)) + if coeff1 is None: + return False + # c3 maybe of very complex form so I am simply checking (a - b) form + # if yes later I will match with the standard form of bessel in a and b + # a, b are wild variable defined above. + _coeff2 = expand(r[c3]).match(a - b) + if _coeff2 is None: + return False + # matching with standard form for c3 + coeff2 = factor(_coeff2[a]).match(c4**2*(x)**(2*a4)) + if coeff2 is None: + return False + + if _coeff2[b] == 0: + coeff2[d4] = 0 + else: + coeff2[d4] = factor(_coeff2[b]).match(d4**2)[d4] + + self.rn = {'n':coeff2[d4], 'a4':coeff2[c4], 'd4':coeff2[a4]} + self.rn['c4'] = coeff1[a4] + self.rn['b4'] = point + return True + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + n = self.rn['n'] + a4 = self.rn['a4'] + c4 = self.rn['c4'] + d4 = self.rn['d4'] + b4 = self.rn['b4'] + n = sqrt(n**2 + Rational(1, 4)*(c4 - 1)**2) + (C1, C2) = self.ode_problem.get_numbered_constants(num=2) + return [Eq(f(x), ((x**(Rational(1-c4,2)))*(C1*besselj(n/d4,a4*x**d4/d4) + + C2*bessely(n/d4,a4*x**d4/d4))).subs(x, x-b4))] + + +class SecondLinearAiry(SingleODESolver): + r""" + Gives solution of the Airy differential equation + + .. math :: \frac{d^2y}{dx^2} + (a + b x) y(x) = 0 + + in terms of Airy special functions airyai and airybi. + + Examples + ======== + + >>> from sympy import dsolve, Function + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = f(x).diff(x, 2) - x*f(x) + >>> dsolve(eq) + Eq(f(x), C1*airyai(x) + C2*airybi(x)) + """ + hint = "2nd_linear_airy" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f.diff(x) + a4 = Wild('a4', exclude=[x,f,df]) + b4 = Wild('b4', exclude=[x,f,df]) + match = self.ode_problem.get_linear_coefficients(eq, f, order) + does_match = False + if order == 2 and match and match[2] != 0: + if match[1].is_zero: + self.rn = cancel(match[0]/match[2]).match(a4+b4*x) + if self.rn and self.rn[b4] != 0: + self.rn = {'b':self.rn[a4],'m':self.rn[b4]} + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + (C1, C2) = self.ode_problem.get_numbered_constants(num=2) + b = self.rn['b'] + m = self.rn['m'] + if m.is_positive: + arg = - b/cbrt(m)**2 - cbrt(m)*x + elif m.is_negative: + arg = - b/cbrt(-m)**2 + cbrt(-m)*x + else: + arg = - b/cbrt(-m)**2 + cbrt(-m)*x + + return [Eq(f(x), C1*airyai(arg) + C2*airybi(arg))] + + +class LieGroup(SingleODESolver): + r""" + This hint implements the Lie group method of solving first order differential + equations. The aim is to convert the given differential equation from the + given coordinate system into another coordinate system where it becomes + invariant under the one-parameter Lie group of translations. The converted + ODE can be easily solved by quadrature. It makes use of the + :py:meth:`sympy.solvers.ode.infinitesimals` function which returns the + infinitesimals of the transformation. + + The coordinates `r` and `s` can be found by solving the following Partial + Differential Equations. + + .. math :: \xi\frac{\partial r}{\partial x} + \eta\frac{\partial r}{\partial y} + = 0 + + .. math :: \xi\frac{\partial s}{\partial x} + \eta\frac{\partial s}{\partial y} + = 1 + + The differential equation becomes separable in the new coordinate system + + .. math :: \frac{ds}{dr} = \frac{\frac{\partial s}{\partial x} + + h(x, y)\frac{\partial s}{\partial y}}{ + \frac{\partial r}{\partial x} + h(x, y)\frac{\partial r}{\partial y}} + + After finding the solution by integration, it is then converted back to the original + coordinate system by substituting `r` and `s` in terms of `x` and `y` again. + + Examples + ======== + + >>> from sympy import Function, dsolve, exp, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x) + 2*x*f(x) - x*exp(-x**2), f(x), + ... hint='lie_group')) + / 2\ 2 + | x | -x + f(x) = |C1 + --|*e + \ 2 / + + + References + ========== + + - Solving differential equations by Symmetry Groups, + John Starrett, pp. 1 - pp. 14 + + """ + hint = "lie_group" + has_integral = False + + def _has_additional_params(self): + return 'xi' in self.ode_problem.params and 'eta' in self.ode_problem.params + + def _matches(self): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f(x).diff(x) + y = Dummy('y') + d = Wild('d', exclude=[df, f(x).diff(x, 2)]) + e = Wild('e', exclude=[df]) + does_match = False + if self._has_additional_params() and order == 1: + xi = self.ode_problem.params['xi'] + eta = self.ode_problem.params['eta'] + self.r3 = {'xi': xi, 'eta': eta} + r = collect(eq, df, exact=True).match(d + e * df) + if r: + r['d'] = d + r['e'] = e + r['y'] = y + r[d] = r[d].subs(f(x), y) + r[e] = r[e].subs(f(x), y) + self.r3.update(r) + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + x = self.ode_problem.sym + func = self.ode_problem.func + order = self.ode_problem.order + df = func.diff(x) + + try: + eqsol = solve(eq, df) + except NotImplementedError: + eqsol = [] + + desols = [] + for s in eqsol: + sol = _ode_lie_group(s, func, order, match=self.r3) + if sol: + desols.extend(sol) + + if desols == []: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the lie group method") + return desols + + +solver_map = { + 'factorable': Factorable, + 'nth_linear_constant_coeff_homogeneous': NthLinearConstantCoeffHomogeneous, + 'nth_linear_euler_eq_homogeneous': NthLinearEulerEqHomogeneous, + 'nth_linear_constant_coeff_undetermined_coefficients': NthLinearConstantCoeffUndeterminedCoefficients, + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients': NthLinearEulerEqNonhomogeneousUndeterminedCoefficients, + 'separable': Separable, + '1st_exact': FirstExact, + '1st_linear': FirstLinear, + 'Bernoulli': Bernoulli, + 'Riccati_special_minus2': RiccatiSpecial, + '1st_rational_riccati': RationalRiccati, + '1st_homogeneous_coeff_best': HomogeneousCoeffBest, + '1st_homogeneous_coeff_subs_indep_div_dep': HomogeneousCoeffSubsIndepDivDep, + '1st_homogeneous_coeff_subs_dep_div_indep': HomogeneousCoeffSubsDepDivIndep, + 'almost_linear': AlmostLinear, + 'linear_coefficients': LinearCoefficients, + 'separable_reduced': SeparableReduced, + 'nth_linear_constant_coeff_variation_of_parameters': NthLinearConstantCoeffVariationOfParameters, + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters': NthLinearEulerEqNonhomogeneousVariationOfParameters, + 'Liouville': Liouville, + '2nd_linear_airy': SecondLinearAiry, + '2nd_linear_bessel': SecondLinearBessel, + '2nd_hypergeometric': SecondHypergeometric, + 'nth_order_reducible': NthOrderReducible, + '2nd_nonlinear_autonomous_conserved': SecondNonlinearAutonomousConserved, + 'nth_algebraic': NthAlgebraic, + 'lie_group': LieGroup, + } + +# Avoid circular import: +from .ode import dsolve, ode_sol_simplicity, odesimp, homogeneous_order diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/subscheck.py b/phivenv/Lib/site-packages/sympy/solvers/ode/subscheck.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac7fba7d364bf599e928ccf591b5bef096576d0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/subscheck.py @@ -0,0 +1,392 @@ +from sympy.core import S, Pow +from sympy.core.function import (Derivative, AppliedUndef, diff) +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify + +from sympy.logic.boolalg import BooleanAtom +from sympy.functions import exp +from sympy.series import Order +from sympy.simplify.simplify import simplify, posify, besselsimp +from sympy.simplify.trigsimp import trigsimp +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.solvers import solve +from sympy.solvers.deutils import _preprocess, ode_order +from sympy.utilities.iterables import iterable, is_sequence + + +def sub_func_doit(eq, func, new): + r""" + When replacing the func with something else, we usually want the + derivative evaluated, so this function helps in making that happen. + + Examples + ======== + + >>> from sympy import Derivative, symbols, Function + >>> from sympy.solvers.ode.subscheck import sub_func_doit + >>> x, z = symbols('x, z') + >>> y = Function('y') + + >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x) + 2 + + >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x), + ... 1/(x*(z + 1/x))) + x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x)) + ...- 1/(x**2*(z + 1/x)**2) + """ + reps= {func: new} + for d in eq.atoms(Derivative): + if d.expr == func: + reps[d] = new.diff(*d.variable_count) + else: + reps[d] = d.xreplace({func: new}).doit(deep=False) + return eq.xreplace(reps) + + +def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True): + r""" + Substitutes ``sol`` into ``ode`` and checks that the result is ``0``. + + This works when ``func`` is one function, like `f(x)` or a list of + functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can + be a single solution or a list of solutions. Each solution may be an + :py:class:`~sympy.core.relational.Equality` that the solution satisfies, + e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an + :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it + will not be necessary to explicitly identify the function, but if the + function cannot be inferred from the original equation it can be supplied + through the ``func`` argument. + + If a sequence of solutions is passed, the same sort of container will be + used to return the result for each solution. + + It tries the following methods, in order, until it finds zero equivalence: + + 1. Substitute the solution for `f` in the original equation. This only + works if ``ode`` is solved for `f`. It will attempt to solve it first + unless ``solve_for_func == False``. + 2. Take `n` derivatives of the solution, where `n` is the order of + ``ode``, and check to see if that is equal to the solution. This only + works on exact ODEs. + 3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time + solving for the derivative of `f` of that order (this will always be + possible because `f` is a linear operator). Then back substitute each + derivative into ``ode`` in reverse order. + + This function returns a tuple. The first item in the tuple is ``True`` if + the substitution results in ``0``, and ``False`` otherwise. The second + item in the tuple is what the substitution results in. It should always + be ``0`` if the first item is ``True``. Sometimes this function will + return ``False`` even when an expression is identically equal to ``0``. + This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not + reduce the expression to ``0``. If an expression returned by this + function vanishes identically, then ``sol`` really is a solution to + the ``ode``. + + If this function seems to hang, it is probably because of a hard + simplification. + + To use this function to test, test the first item of the tuple. + + Examples + ======== + + >>> from sympy import (Eq, Function, checkodesol, symbols, + ... Derivative, exp) + >>> x, C1, C2 = symbols('x,C1,C2') + >>> f, g = symbols('f g', cls=Function) + >>> checkodesol(f(x).diff(x), Eq(f(x), C1)) + (True, 0) + >>> assert checkodesol(f(x).diff(x), C1)[0] + >>> assert not checkodesol(f(x).diff(x), x)[0] + >>> checkodesol(f(x).diff(x, 2), x**2) + (False, 2) + + >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))] + >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + """ + if iterable(ode): + return checksysodesol(ode, sol, func=func) + + if not isinstance(ode, Equality): + ode = Eq(ode, 0) + if func is None: + try: + _, func = _preprocess(ode.lhs) + except ValueError: + funcs = [s.atoms(AppliedUndef) for s in ( + sol if is_sequence(sol, set) else [sol])] + funcs = set().union(*funcs) + if len(funcs) != 1: + raise ValueError( + 'must pass func arg to checkodesol for this case.') + func = funcs.pop() + if not isinstance(func, AppliedUndef) or len(func.args) != 1: + raise ValueError( + "func must be a function of one variable, not %s" % func) + if is_sequence(sol, set): + return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol]) + + if not isinstance(sol, Equality): + sol = Eq(func, sol) + elif sol.rhs == func: + sol = sol.reversed + + if order == 'auto': + order = ode_order(ode, func) + solved = sol.lhs == func and not sol.rhs.has(func) + if solve_for_func and not solved: + rhs = solve(sol, func) + if rhs: + eqs = [Eq(func, t) for t in rhs] + if len(rhs) == 1: + eqs = eqs[0] + return checkodesol(ode, eqs, order=order, + solve_for_func=False) + + x = func.args[0] + + # Handle series solutions here + if sol.has(Order): + assert sol.lhs == func + Oterm = sol.rhs.getO() + solrhs = sol.rhs.removeO() + + Oexpr = Oterm.expr + assert isinstance(Oexpr, Pow) + sorder = Oexpr.exp + assert Oterm == Order(x**sorder) + + odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand() + + neworder = Order(x**(sorder - order)) + odesubs = odesubs + neworder + assert odesubs.getO() == neworder + residual = odesubs.removeO() + + return (residual == 0, residual) + + s = True + testnum = 0 + while s: + if testnum == 0: + # First pass, try substituting a solved solution directly into the + # ODE. This has the highest chance of succeeding. + ode_diff = ode.lhs - ode.rhs + + if sol.lhs == func: + s = sub_func_doit(ode_diff, func, sol.rhs) + s = besselsimp(s) + else: + testnum += 1 + continue + ss = simplify(s.rewrite(exp)) + if ss: + # with the new numer_denom in power.py, if we do a simple + # expansion then testnum == 0 verifies all solutions. + s = ss.expand(force=True) + else: + s = 0 + testnum += 1 + elif testnum == 1: + # Second pass. If we cannot substitute f, try seeing if the nth + # derivative is equal, this will only work for odes that are exact, + # by definition. + s = simplify( + trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) - + trigsimp(ode.lhs) + trigsimp(ode.rhs)) + # s2 = simplify( + # diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \ + # ode.lhs + ode.rhs) + testnum += 1 + elif testnum == 2: + # Third pass. Try solving for df/dx and substituting that into the + # ODE. Thanks to Chris Smith for suggesting this method. Many of + # the comments below are his, too. + # The method: + # - Take each of 1..n derivatives of the solution. + # - Solve each nth derivative for d^(n)f/dx^(n) + # (the differential of that order) + # - Back substitute into the ODE in decreasing order + # (i.e., n, n-1, ...) + # - Check the result for zero equivalence + if sol.lhs == func and not sol.rhs.has(func): + diffsols = {0: sol.rhs} + elif sol.rhs == func and not sol.lhs.has(func): + diffsols = {0: sol.lhs} + else: + diffsols = {} + sol = sol.lhs - sol.rhs + for i in range(1, order + 1): + # Differentiation is a linear operator, so there should always + # be 1 solution. Nonetheless, we test just to make sure. + # We only need to solve once. After that, we automatically + # have the solution to the differential in the order we want. + if i == 1: + ds = sol.diff(x) + try: + sdf = solve(ds, func.diff(x, i)) + if not sdf: + raise NotImplementedError + except NotImplementedError: + testnum += 1 + break + else: + diffsols[i] = sdf[0] + else: + # This is what the solution says df/dx should be. + diffsols[i] = diffsols[i - 1].diff(x) + + # Make sure the above didn't fail. + if testnum > 2: + continue + else: + # Substitute it into ODE to check for self consistency. + lhs, rhs = ode.lhs, ode.rhs + for i in range(order, -1, -1): + if i == 0 and 0 not in diffsols: + # We can only substitute f(x) if the solution was + # solved for f(x). + break + lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i]) + rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i]) + ode_or_bool = Eq(lhs, rhs) + ode_or_bool = simplify(ode_or_bool) + + if isinstance(ode_or_bool, (bool, BooleanAtom)): + if ode_or_bool: + lhs = rhs = S.Zero + else: + lhs = ode_or_bool.lhs + rhs = ode_or_bool.rhs + # No sense in overworking simplify -- just prove that the + # numerator goes to zero + num = trigsimp((lhs - rhs).as_numer_denom()[0]) + # since solutions are obtained using force=True we test + # using the same level of assumptions + ## replace function with dummy so assumptions will work + _func = Dummy('func') + num = num.subs(func, _func) + ## posify the expression + num, reps = posify(num) + s = simplify(num).xreplace(reps).xreplace({_func: func}) + testnum += 1 + else: + break + + if not s: + return (True, s) + elif s is True: # The code above never was able to change s + raise NotImplementedError("Unable to test if " + str(sol) + + " is a solution to " + str(ode) + ".") + else: + return (False, s) + + +def checksysodesol(eqs, sols, func=None): + r""" + Substitutes corresponding ``sols`` for each functions into each ``eqs`` and + checks that the result of substitutions for each equation is ``0``. The + equations and solutions passed can be any iterable. + + This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`. + For each function, ``sols`` can have a single solution or a list of solutions. + In most cases it will not be necessary to explicitly identify the function, + but if the function cannot be inferred from the original equation it + can be supplied through the ``func`` argument. + + When a sequence of equations is passed, the same sequence is used to return + the result for each equation with each function substituted with corresponding + solutions. + + It tries the following method to find zero equivalence for each equation: + + Substitute the solutions for functions, like `x(t)` and `y(t)` into the + original equations containing those functions. + This function returns a tuple. The first item in the tuple is ``True`` if + the substitution results for each equation is ``0``, and ``False`` otherwise. + The second item in the tuple is what the substitution results in. Each element + of the ``list`` should always be ``0`` corresponding to each equation if the + first item is ``True``. Note that sometimes this function may return ``False``, + but with an expression that is identically equal to ``0``, instead of returning + ``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot + reduce the expression to ``0``. If an expression returned by each function + vanishes identically, then ``sols`` really is a solution to ``eqs``. + + If this function seems to hang, it is probably because of a difficult simplification. + + Examples + ======== + + >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function + >>> from sympy.solvers.ode.subscheck import checksysodesol + >>> C1, C2 = symbols('C1:3') + >>> t = symbols('t') + >>> x, y = symbols('x, y', cls=Function) + >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12)) + >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3), + ... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)] + >>> checksysodesol(eq, sol) + (True, [0, 0]) + >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3)) + >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2), + ... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)] + >>> checksysodesol(eq, sol) + (True, [0, 0]) + + """ + def _sympify(eq): + return list(map(sympify, eq if iterable(eq) else [eq])) + eqs = _sympify(eqs) + for i in range(len(eqs)): + if isinstance(eqs[i], Equality): + eqs[i] = eqs[i].lhs - eqs[i].rhs + if func is None: + funcs = [] + for eq in eqs: + derivs = eq.atoms(Derivative) + func = set().union(*[d.atoms(AppliedUndef) for d in derivs]) + funcs.extend(func) + funcs = list(set(funcs)) + if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\ + and len({func.args for func in funcs})!=1: + raise ValueError("func must be a function of one variable, not %s" % func) + for sol in sols: + if len(sol.atoms(AppliedUndef)) != 1: + raise ValueError("solutions should have one function only") + if len(funcs) != len({sol.lhs for sol in sols}): + raise ValueError("number of solutions provided does not match the number of equations") + dictsol = {} + for sol in sols: + func = list(sol.atoms(AppliedUndef))[0] + if sol.rhs == func: + sol = sol.reversed + solved = sol.lhs == func and not sol.rhs.has(func) + if not solved: + rhs = solve(sol, func) + if not rhs: + raise NotImplementedError + else: + rhs = sol.rhs + dictsol[func] = rhs + checkeq = [] + for eq in eqs: + for func in funcs: + eq = sub_func_doit(eq, func, dictsol[func]) + ss = simplify(eq) + if ss != 0: + eq = ss.expand(force=True) + if eq != 0: + eq = sqrtdenest(eq).simplify() + else: + eq = 0 + checkeq.append(eq) + if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0: + return (True, checkeq) + else: + return (False, checkeq) diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/systems.py b/phivenv/Lib/site-packages/sympy/solvers/ode/systems.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2c9b57a969c7fb5c67c06ce952fa398e22a48d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/systems.py @@ -0,0 +1,2135 @@ +from sympy.core import Add, Mul, S +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.numbers import I +from sympy.core.relational import Eq, Equality +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Dummy, Symbol +from sympy.core.function import (expand_mul, expand, Derivative, + AppliedUndef, Function, Subs) +from sympy.functions import (exp, im, cos, sin, re, Piecewise, + piecewise_fold, sqrt, log) +from sympy.functions.combinatorial.factorials import factorial +from sympy.matrices import zeros, Matrix, NonSquareMatrixError, MatrixBase, eye +from sympy.polys import Poly, together +from sympy.simplify import collect, radsimp, signsimp # type: ignore +from sympy.simplify.powsimp import powdenest, powsimp +from sympy.simplify.ratsimp import ratsimp +from sympy.simplify.simplify import simplify +from sympy.sets.sets import FiniteSet +from sympy.solvers.deutils import ode_order +from sympy.solvers.solveset import NonlinearError, solveset +from sympy.utilities.iterables import (connected_components, iterable, + strongly_connected_components) +from sympy.utilities.misc import filldedent +from sympy.integrals.integrals import Integral, integrate + + +def _get_func_order(eqs, funcs): + return {func: max(ode_order(eq, func) for eq in eqs) for func in funcs} + + +class ODEOrderError(ValueError): + """Raised by linear_ode_to_matrix if the system has the wrong order""" + pass + + +class ODENonlinearError(NonlinearError): + """Raised by linear_ode_to_matrix if the system is nonlinear""" + pass + + +def _simpsol(soleq): + lhs = soleq.lhs + sol = soleq.rhs + sol = powsimp(sol) + gens = list(sol.atoms(exp)) + p = Poly(sol, *gens, expand=False) + gens = [factor_terms(g) for g in gens] + if not gens: + gens = p.gens + syms = [Symbol('C1'), Symbol('C2')] + terms = [] + for coeff, monom in zip(p.coeffs(), p.monoms()): + coeff = piecewise_fold(coeff) + if isinstance(coeff, Piecewise): + coeff = Piecewise(*((ratsimp(coef).collect(syms), cond) for coef, cond in coeff.args)) + else: + coeff = ratsimp(coeff).collect(syms) + monom = Mul(*(g ** i for g, i in zip(gens, monom))) + terms.append(coeff * monom) + return Eq(lhs, Add(*terms)) + + +def _solsimp(e, t): + no_t, has_t = powsimp(expand_mul(e)).as_independent(t) + + no_t = ratsimp(no_t) + has_t = has_t.replace(exp, lambda a: exp(factor_terms(a))) + + return no_t + has_t + + +def simpsol(sol, wrt1, wrt2, doit=True): + """Simplify solutions from dsolve_system.""" + + # The parameter sol is the solution as returned by dsolve (list of Eq). + # + # The parameters wrt1 and wrt2 are lists of symbols to be collected for + # with those in wrt1 being collected for first. This allows for collecting + # on any factors involving the independent variable before collecting on + # the integration constants or vice versa using e.g.: + # + # sol = simpsol(sol, [t], [C1, C2]) # t first, constants after + # sol = simpsol(sol, [C1, C2], [t]) # constants first, t after + # + # If doit=True (default) then simpsol will begin by evaluating any + # unevaluated integrals. Since many integrals will appear multiple times + # in the solutions this is done intelligently by computing each integral + # only once. + # + # The strategy is to first perform simple cancellation with factor_terms + # and then multiply out all brackets with expand_mul. This gives an Add + # with many terms. + # + # We split each term into two multiplicative factors dep and coeff where + # all factors that involve wrt1 are in dep and any constant factors are in + # coeff e.g. + # sqrt(2)*C1*exp(t) -> ( exp(t), sqrt(2)*C1 ) + # + # The dep factors are simplified using powsimp to combine expanded + # exponential factors e.g. + # exp(a*t)*exp(b*t) -> exp(t*(a+b)) + # + # We then collect coefficients for all terms having the same (simplified) + # dep. The coefficients are then simplified using together and ratsimp and + # lastly by recursively applying the same transformation to the + # coefficients to collect on wrt2. + # + # Finally the result is recombined into an Add and signsimp is used to + # normalise any minus signs. + + def simprhs(rhs, rep, wrt1, wrt2): + """Simplify the rhs of an ODE solution""" + if rep: + rhs = rhs.subs(rep) + rhs = factor_terms(rhs) + rhs = simp_coeff_dep(rhs, wrt1, wrt2) + rhs = signsimp(rhs) + return rhs + + def simp_coeff_dep(expr, wrt1, wrt2=None): + """Split rhs into terms, split terms into dep and coeff and collect on dep""" + add_dep_terms = lambda e: e.is_Add and e.has(*wrt1) + expandable = lambda e: e.is_Mul and any(map(add_dep_terms, e.args)) + expand_func = lambda e: expand_mul(e, deep=False) + expand_mul_mod = lambda e: e.replace(expandable, expand_func) + terms = Add.make_args(expand_mul_mod(expr)) + dc = {} + for term in terms: + coeff, dep = term.as_independent(*wrt1, as_Add=False) + # Collect together the coefficients for terms that have the same + # dependence on wrt1 (after dep is normalised using simpdep). + dep = simpdep(dep, wrt1) + + # See if the dependence on t cancels out... + if dep is not S.One: + dep2 = factor_terms(dep) + if not dep2.has(*wrt1): + coeff *= dep2 + dep = S.One + + if dep not in dc: + dc[dep] = coeff + else: + dc[dep] += coeff + # Apply the method recursively to the coefficients but this time + # collecting on wrt2 rather than wrt2. + termpairs = ((simpcoeff(c, wrt2), d) for d, c in dc.items()) + if wrt2 is not None: + termpairs = ((simp_coeff_dep(c, wrt2), d) for c, d in termpairs) + return Add(*(c * d for c, d in termpairs)) + + def simpdep(term, wrt1): + """Normalise factors involving t with powsimp and recombine exp""" + def canonicalise(a): + # Using factor_terms here isn't quite right because it leads to things + # like exp(t*(1+t)) that we don't want. We do want to cancel factors + # and pull out a common denominator but ideally the numerator would be + # expressed as a standard form polynomial in t so we expand_mul + # and collect afterwards. + a = factor_terms(a) + num, den = a.as_numer_denom() + num = expand_mul(num) + num = collect(num, wrt1) + return num / den + + term = powsimp(term) + rep = {e: exp(canonicalise(e.args[0])) for e in term.atoms(exp)} + term = term.subs(rep) + return term + + def simpcoeff(coeff, wrt2): + """Bring to a common fraction and cancel with ratsimp""" + coeff = together(coeff) + if coeff.is_polynomial(): + # Calling ratsimp can be expensive. The main reason is to simplify + # sums of terms with irrational denominators so we limit ourselves + # to the case where the expression is polynomial in any symbols. + # Maybe there's a better approach... + coeff = ratsimp(radsimp(coeff)) + # collect on secondary variables first and any remaining symbols after + if wrt2 is not None: + syms = list(wrt2) + list(ordered(coeff.free_symbols - set(wrt2))) + else: + syms = list(ordered(coeff.free_symbols)) + coeff = collect(coeff, syms) + coeff = together(coeff) + return coeff + + # There are often repeated integrals. Collect unique integrals and + # evaluate each once and then substitute into the final result to replace + # all occurrences in each of the solution equations. + if doit: + integrals = set().union(*(s.atoms(Integral) for s in sol)) + rep = {i: factor_terms(i).doit() for i in integrals} + else: + rep = {} + + sol = [Eq(s.lhs, simprhs(s.rhs, rep, wrt1, wrt2)) for s in sol] + return sol + + +def linodesolve_type(A, t, b=None): + r""" + Helper function that determines the type of the system of ODEs for solving with :obj:`sympy.solvers.ode.systems.linodesolve()` + + Explanation + =========== + + This function takes in the coefficient matrix and/or the non-homogeneous term + and returns the type of the equation that can be solved by :obj:`sympy.solvers.ode.systems.linodesolve()`. + + If the system is constant coefficient homogeneous, then "type1" is returned + + If the system is constant coefficient non-homogeneous, then "type2" is returned + + If the system is non-constant coefficient homogeneous, then "type3" is returned + + If the system is non-constant coefficient non-homogeneous, then "type4" is returned + + If the system has a non-constant coefficient matrix which can be factorized into constant + coefficient matrix, then "type5" or "type6" is returned for when the system is homogeneous or + non-homogeneous respectively. + + Note that, if the system of ODEs is of "type3" or "type4", then along with the type, + the commutative antiderivative of the coefficient matrix is also returned. + + If the system cannot be solved by :obj:`sympy.solvers.ode.systems.linodesolve()`, then + NotImplementedError is raised. + + Parameters + ========== + + A : Matrix + Coefficient matrix of the system of ODEs + b : Matrix or None + Non-homogeneous term of the system. The default value is None. + If this argument is None, then the system is assumed to be homogeneous. + + Examples + ======== + + >>> from sympy import symbols, Matrix + >>> from sympy.solvers.ode.systems import linodesolve_type + >>> t = symbols("t") + >>> A = Matrix([[1, 1], [2, 3]]) + >>> b = Matrix([t, 1]) + + >>> linodesolve_type(A, t) + {'antiderivative': None, 'type_of_equation': 'type1'} + + >>> linodesolve_type(A, t, b=b) + {'antiderivative': None, 'type_of_equation': 'type2'} + + >>> A_t = Matrix([[1, t], [-t, 1]]) + + >>> linodesolve_type(A_t, t) + {'antiderivative': Matrix([ + [ t, t**2/2], + [-t**2/2, t]]), 'type_of_equation': 'type3'} + + >>> linodesolve_type(A_t, t, b=b) + {'antiderivative': Matrix([ + [ t, t**2/2], + [-t**2/2, t]]), 'type_of_equation': 'type4'} + + >>> A_non_commutative = Matrix([[1, t], [t, -1]]) + >>> linodesolve_type(A_non_commutative, t) + Traceback (most recent call last): + ... + NotImplementedError: + The system does not have a commutative antiderivative, it cannot be + solved by linodesolve. + + Returns + ======= + + Dict + + Raises + ====== + + NotImplementedError + When the coefficient matrix does not have a commutative antiderivative + + See Also + ======== + + linodesolve: Function for which linodesolve_type gets the information + + """ + + match = {} + is_non_constant = not _matrix_is_constant(A, t) + is_non_homogeneous = not (b is None or b.is_zero_matrix) + type = "type{}".format(int("{}{}".format(int(is_non_constant), int(is_non_homogeneous)), 2) + 1) + + B = None + match.update({"type_of_equation": type, "antiderivative": B}) + + if is_non_constant: + B, is_commuting = _is_commutative_anti_derivative(A, t) + if not is_commuting: + raise NotImplementedError(filldedent(''' + The system does not have a commutative antiderivative, it cannot be solved + by linodesolve. + ''')) + + match['antiderivative'] = B + match.update(_first_order_type5_6_subs(A, t, b=b)) + + return match + + +def _first_order_type5_6_subs(A, t, b=None): + match = {} + + factor_terms = _factor_matrix(A, t) + is_homogeneous = b is None or b.is_zero_matrix + + if factor_terms is not None: + t_ = Symbol("{}_".format(t)) + F_t = integrate(factor_terms[0], t) + inverse = solveset(Eq(t_, F_t), t) + + # Note: A simple way to check if a function is invertible + # or not. + if isinstance(inverse, FiniteSet) and not inverse.has(Piecewise)\ + and len(inverse) == 1: + + A = factor_terms[1] + if not is_homogeneous: + b = b / factor_terms[0] + b = b.subs(t, list(inverse)[0]) + type = "type{}".format(5 + (not is_homogeneous)) + match.update({'func_coeff': A, 'tau': F_t, + 't_': t_, 'type_of_equation': type, 'rhs': b}) + + return match + + +def linear_ode_to_matrix(eqs, funcs, t, order): + r""" + Convert a linear system of ODEs to matrix form + + Explanation + =========== + + Express a system of linear ordinary differential equations as a single + matrix differential equation [1]. For example the system $x' = x + y + 1$ + and $y' = x - y$ can be represented as + + .. math:: A_1 X' = A_0 X + b + + where $A_1$ and $A_0$ are $2 \times 2$ matrices and $b$, $X$ and $X'$ are + $2 \times 1$ matrices with $X = [x, y]^T$. + + Higher-order systems are represented with additional matrices e.g. a + second-order system would look like + + .. math:: A_2 X'' = A_1 X' + A_0 X + b + + Examples + ======== + + >>> from sympy import Function, Symbol, Matrix, Eq + >>> from sympy.solvers.ode.systems import linear_ode_to_matrix + >>> t = Symbol('t') + >>> x = Function('x') + >>> y = Function('y') + + We can create a system of linear ODEs like + + >>> eqs = [ + ... Eq(x(t).diff(t), x(t) + y(t) + 1), + ... Eq(y(t).diff(t), x(t) - y(t)), + ... ] + >>> funcs = [x(t), y(t)] + >>> order = 1 # 1st order system + + Now ``linear_ode_to_matrix`` can represent this as a matrix + differential equation. + + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, order) + >>> A1 + Matrix([ + [1, 0], + [0, 1]]) + >>> A0 + Matrix([ + [1, 1], + [1, -1]]) + >>> b + Matrix([ + [1], + [0]]) + + The original equations can be recovered from these matrices: + + >>> eqs_mat = Matrix([eq.lhs - eq.rhs for eq in eqs]) + >>> X = Matrix(funcs) + >>> A1 * X.diff(t) - A0 * X - b == eqs_mat + True + + If the system of equations has a maximum order greater than the + order of the system specified, a ODEOrderError exception is raised. + + >>> eqs = [Eq(x(t).diff(t, 2), x(t).diff(t) + x(t)), Eq(y(t).diff(t), y(t) + x(t))] + >>> linear_ode_to_matrix(eqs, funcs, t, 1) + Traceback (most recent call last): + ... + ODEOrderError: Cannot represent system in 1-order form + + If the system of equations is nonlinear, then ODENonlinearError is + raised. + + >>> eqs = [Eq(x(t).diff(t), x(t) + y(t)), Eq(y(t).diff(t), y(t)**2 + x(t))] + >>> linear_ode_to_matrix(eqs, funcs, t, 1) + Traceback (most recent call last): + ... + ODENonlinearError: The system of ODEs is nonlinear. + + Parameters + ========== + + eqs : list of SymPy expressions or equalities + The equations as expressions (assumed equal to zero). + funcs : list of applied functions + The dependent variables of the system of ODEs. + t : symbol + The independent variable. + order : int + The order of the system of ODEs. + + Returns + ======= + + The tuple ``(As, b)`` where ``As`` is a tuple of matrices and ``b`` is the + the matrix representing the rhs of the matrix equation. + + Raises + ====== + + ODEOrderError + When the system of ODEs have an order greater than what was specified + ODENonlinearError + When the system of ODEs is nonlinear + + See Also + ======== + + linear_eq_to_matrix: for systems of linear algebraic equations. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_differential_equation + + """ + from sympy.solvers.solveset import linear_eq_to_matrix + + if any(ode_order(eq, func) > order for eq in eqs for func in funcs): + msg = "Cannot represent system in {}-order form" + raise ODEOrderError(msg.format(order)) + + As = [] + + for o in range(order, -1, -1): + # Work from the highest derivative down + syms = [func.diff(t, o) for func in funcs] + + # Ai is the matrix for X(t).diff(t, o) + # eqs is minus the remainder of the equations. + try: + Ai, b = linear_eq_to_matrix(eqs, syms) + except NonlinearError: + raise ODENonlinearError("The system of ODEs is nonlinear.") + + Ai = Ai.applyfunc(expand_mul) + + As.append(Ai if o == order else -Ai) + + if o: + eqs = [-eq for eq in b] + else: + rhs = b + + return As, rhs + + +def matrix_exp(A, t): + r""" + Matrix exponential $\exp(A*t)$ for the matrix ``A`` and scalar ``t``. + + Explanation + =========== + + This functions returns the $\exp(A*t)$ by doing a simple + matrix multiplication: + + .. math:: \exp(A*t) = P * expJ * P^{-1} + + where $expJ$ is $\exp(J*t)$. $J$ is the Jordan normal + form of $A$ and $P$ is matrix such that: + + .. math:: A = P * J * P^{-1} + + The matrix exponential $\exp(A*t)$ appears in the solution of linear + differential equations. For example if $x$ is a vector and $A$ is a matrix + then the initial value problem + + .. math:: \frac{dx(t)}{dt} = A \times x(t), x(0) = x0 + + has the unique solution + + .. math:: x(t) = \exp(A t) x0 + + Examples + ======== + + >>> from sympy import Symbol, Matrix, pprint + >>> from sympy.solvers.ode.systems import matrix_exp + >>> t = Symbol('t') + + We will consider a 2x2 matrix for comupting the exponential + + >>> A = Matrix([[2, -5], [2, -4]]) + >>> pprint(A) + [2 -5] + [ ] + [2 -4] + + Now, exp(A*t) is given as follows: + + >>> pprint(matrix_exp(A, t)) + [ -t -t -t ] + [3*e *sin(t) + e *cos(t) -5*e *sin(t) ] + [ ] + [ -t -t -t ] + [ 2*e *sin(t) - 3*e *sin(t) + e *cos(t)] + + Parameters + ========== + + A : Matrix + The matrix $A$ in the expression $\exp(A*t)$ + t : Symbol + The independent variable + + See Also + ======== + + matrix_exp_jordan_form: For exponential of Jordan normal form + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_normal_form + .. [2] https://en.wikipedia.org/wiki/Matrix_exponential + + """ + P, expJ = matrix_exp_jordan_form(A, t) + return P * expJ * P.inv() + + +def matrix_exp_jordan_form(A, t): + r""" + Matrix exponential $\exp(A*t)$ for the matrix *A* and scalar *t*. + + Explanation + =========== + + Returns the Jordan form of the $\exp(A*t)$ along with the matrix $P$ such that: + + .. math:: + \exp(A*t) = P * expJ * P^{-1} + + Examples + ======== + + >>> from sympy import Matrix, Symbol + >>> from sympy.solvers.ode.systems import matrix_exp, matrix_exp_jordan_form + >>> t = Symbol('t') + + We will consider a 2x2 defective matrix. This shows that our method + works even for defective matrices. + + >>> A = Matrix([[1, 1], [0, 1]]) + + It can be observed that this function gives us the Jordan normal form + and the required invertible matrix P. + + >>> P, expJ = matrix_exp_jordan_form(A, t) + + Here, it is shown that P and expJ returned by this function is correct + as they satisfy the formula: P * expJ * P_inverse = exp(A*t). + + >>> P * expJ * P.inv() == matrix_exp(A, t) + True + + Parameters + ========== + + A : Matrix + The matrix $A$ in the expression $\exp(A*t)$ + t : Symbol + The independent variable + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Defective_matrix + .. [2] https://en.wikipedia.org/wiki/Jordan_matrix + .. [3] https://en.wikipedia.org/wiki/Jordan_normal_form + + """ + + N, M = A.shape + if N != M: + raise ValueError('Needed square matrix but got shape (%s, %s)' % (N, M)) + elif A.has(t): + raise ValueError('Matrix A should not depend on t') + + def jordan_chains(A): + '''Chains from Jordan normal form analogous to M.eigenvects(). + Returns a dict with eignevalues as keys like: + {e1: [[v111,v112,...], [v121, v122,...]], e2:...} + where vijk is the kth vector in the jth chain for eigenvalue i. + ''' + P, blocks = A.jordan_cells() + basis = [P[:,i] for i in range(P.shape[1])] + n = 0 + chains = {} + for b in blocks: + eigval = b[0, 0] + size = b.shape[0] + if eigval not in chains: + chains[eigval] = [] + chains[eigval].append(basis[n:n+size]) + n += size + return chains + + eigenchains = jordan_chains(A) + + # Needed for consistency across Python versions + eigenchains_iter = sorted(eigenchains.items(), key=default_sort_key) + isreal = not A.has(I) + + blocks = [] + vectors = [] + seen_conjugate = set() + for e, chains in eigenchains_iter: + for chain in chains: + n = len(chain) + if isreal and e != e.conjugate() and e.conjugate() in eigenchains: + if e in seen_conjugate: + continue + seen_conjugate.add(e.conjugate()) + exprt = exp(re(e) * t) + imrt = im(e) * t + imblock = Matrix([[cos(imrt), sin(imrt)], + [-sin(imrt), cos(imrt)]]) + expJblock2 = Matrix(n, n, lambda i,j: + imblock * t**(j-i) / factorial(j-i) if j >= i + else zeros(2, 2)) + expJblock = Matrix(2*n, 2*n, lambda i,j: expJblock2[i//2,j//2][i%2,j%2]) + + blocks.append(exprt * expJblock) + for i in range(n): + vectors.append(re(chain[i])) + vectors.append(im(chain[i])) + else: + vectors.extend(chain) + fun = lambda i,j: t**(j-i)/factorial(j-i) if j >= i else 0 + expJblock = Matrix(n, n, fun) + blocks.append(exp(e * t) * expJblock) + + expJ = Matrix.diag(*blocks) + P = Matrix(N, N, lambda i,j: vectors[j][i]) + + return P, expJ + + +# Note: To add a docstring example with tau +def linodesolve(A, t, b=None, B=None, type="auto", doit=False, + tau=None): + r""" + System of n equations linear first-order differential equations + + Explanation + =========== + + This solver solves the system of ODEs of the following form: + + .. math:: + X'(t) = A(t) X(t) + b(t) + + Here, $A(t)$ is the coefficient matrix, $X(t)$ is the vector of n independent variables, + $b(t)$ is the non-homogeneous term and $X'(t)$ is the derivative of $X(t)$ + + Depending on the properties of $A(t)$ and $b(t)$, this solver evaluates the solution + differently. + + When $A(t)$ is constant coefficient matrix and $b(t)$ is zero vector i.e. system is homogeneous, + the system is "type1". The solution is: + + .. math:: + X(t) = \exp(A t) C + + Here, $C$ is a vector of constants and $A$ is the constant coefficient matrix. + + When $A(t)$ is constant coefficient matrix and $b(t)$ is non-zero i.e. system is non-homogeneous, + the system is "type2". The solution is: + + .. math:: + X(t) = e^{A t} ( \int e^{- A t} b \,dt + C) + + When $A(t)$ is coefficient matrix such that its commutative with its antiderivative $B(t)$ and + $b(t)$ is a zero vector i.e. system is homogeneous, the system is "type3". The solution is: + + .. math:: + X(t) = \exp(B(t)) C + + When $A(t)$ is commutative with its antiderivative $B(t)$ and $b(t)$ is non-zero i.e. system is + non-homogeneous, the system is "type4". The solution is: + + .. math:: + X(t) = e^{B(t)} ( \int e^{-B(t)} b(t) \,dt + C) + + When $A(t)$ is a coefficient matrix such that it can be factorized into a scalar and a constant + coefficient matrix: + + .. math:: + A(t) = f(t) * A + + Where $f(t)$ is a scalar expression in the independent variable $t$ and $A$ is a constant matrix, + then we can do the following substitutions: + + .. math:: + tau = \int f(t) dt, X(t) = Y(tau), b(t) = b(f^{-1}(tau)) + + Here, the substitution for the non-homogeneous term is done only when its non-zero. + Using these substitutions, our original system becomes: + + .. math:: + Y'(tau) = A * Y(tau) + b(tau)/f(tau) + + The above system can be easily solved using the solution for "type1" or "type2" depending + on the homogeneity of the system. After we get the solution for $Y(tau)$, we substitute the + solution for $tau$ as $t$ to get back $X(t)$ + + .. math:: + X(t) = Y(tau) + + Systems of "type5" and "type6" have a commutative antiderivative but we use this solution + because its faster to compute. + + The final solution is the general solution for all the four equations since a constant coefficient + matrix is always commutative with its antidervative. + + An additional feature of this function is, if someone wants to substitute for value of the independent + variable, they can pass the substitution `tau` and the solution will have the independent variable + substituted with the passed expression(`tau`). + + Parameters + ========== + + A : Matrix + Coefficient matrix of the system of linear first order ODEs. + t : Symbol + Independent variable in the system of ODEs. + b : Matrix or None + Non-homogeneous term in the system of ODEs. If None is passed, + a homogeneous system of ODEs is assumed. + B : Matrix or None + Antiderivative of the coefficient matrix. If the antiderivative + is not passed and the solution requires the term, then the solver + would compute it internally. + type : String + Type of the system of ODEs passed. Depending on the type, the + solution is evaluated. The type values allowed and the corresponding + system it solves are: "type1" for constant coefficient homogeneous + "type2" for constant coefficient non-homogeneous, "type3" for non-constant + coefficient homogeneous, "type4" for non-constant coefficient non-homogeneous, + "type5" and "type6" for non-constant coefficient homogeneous and non-homogeneous + systems respectively where the coefficient matrix can be factorized to a constant + coefficient matrix. + The default value is "auto" which will let the solver decide the correct type of + the system passed. + doit : Boolean + Evaluate the solution if True, default value is False + tau: Expression + Used to substitute for the value of `t` after we get the solution of the system. + + Examples + ======== + + To solve the system of ODEs using this function directly, several things must be + done in the right order. Wrong inputs to the function will lead to incorrect results. + + >>> from sympy import symbols, Function, Eq + >>> from sympy.solvers.ode.systems import canonical_odes, linear_ode_to_matrix, linodesolve, linodesolve_type + >>> from sympy.solvers.ode.subscheck import checkodesol + >>> f, g = symbols("f, g", cls=Function) + >>> x, a = symbols("x, a") + >>> funcs = [f(x), g(x)] + >>> eqs = [Eq(f(x).diff(x) - f(x), a*g(x) + 1), Eq(g(x).diff(x) + g(x), a*f(x))] + + Here, it is important to note that before we derive the coefficient matrix, it is + important to get the system of ODEs into the desired form. For that we will use + :obj:`sympy.solvers.ode.systems.canonical_odes()`. + + >>> eqs = canonical_odes(eqs, funcs, x) + >>> eqs + [[Eq(Derivative(f(x), x), a*g(x) + f(x) + 1), Eq(Derivative(g(x), x), a*f(x) - g(x))]] + + Now, we will use :obj:`sympy.solvers.ode.systems.linear_ode_to_matrix()` to get the coefficient matrix and the + non-homogeneous term if it is there. + + >>> eqs = eqs[0] + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, x, 1) + >>> A = A0 + + We have the coefficient matrices and the non-homogeneous term ready. Now, we can use + :obj:`sympy.solvers.ode.systems.linodesolve_type()` to get the information for the system of ODEs + to finally pass it to the solver. + + >>> system_info = linodesolve_type(A, x, b=b) + >>> sol_vector = linodesolve(A, x, b=b, B=system_info['antiderivative'], type=system_info['type_of_equation']) + + Now, we can prove if the solution is correct or not by using :obj:`sympy.solvers.ode.checkodesol()` + + >>> sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + We can also use the doit method to evaluate the solutions passed by the function. + + >>> sol_vector_evaluated = linodesolve(A, x, b=b, type="type2", doit=True) + + Now, we will look at a system of ODEs which is non-constant. + + >>> eqs = [Eq(f(x).diff(x), f(x) + x*g(x)), Eq(g(x).diff(x), -x*f(x) + g(x))] + + The system defined above is already in the desired form, so we do not have to convert it. + + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, x, 1) + >>> A = A0 + + A user can also pass the commutative antiderivative required for type3 and type4 system of ODEs. + Passing an incorrect one will lead to incorrect results. If the coefficient matrix is not commutative + with its antiderivative, then :obj:`sympy.solvers.ode.systems.linodesolve_type()` raises a NotImplementedError. + If it does have a commutative antiderivative, then the function just returns the information about the system. + + >>> system_info = linodesolve_type(A, x, b=b) + + Now, we can pass the antiderivative as an argument to get the solution. If the system information is not + passed, then the solver will compute the required arguments internally. + + >>> sol_vector = linodesolve(A, x, b=b) + + Once again, we can verify the solution obtained. + + >>> sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + Returns + ======= + + List + + Raises + ====== + + ValueError + This error is raised when the coefficient matrix, non-homogeneous term + or the antiderivative, if passed, are not a matrix or + do not have correct dimensions + NonSquareMatrixError + When the coefficient matrix or its antiderivative, if passed is not a + square matrix + NotImplementedError + If the coefficient matrix does not have a commutative antiderivative + + See Also + ======== + + linear_ode_to_matrix: Coefficient matrix computation function + canonical_odes: System of ODEs representation change + linodesolve_type: Getting information about systems of ODEs to pass in this solver + + """ + + if not isinstance(A, MatrixBase): + raise ValueError(filldedent('''\ + The coefficients of the system of ODEs should be of type Matrix + ''')) + + if not A.is_square: + raise NonSquareMatrixError(filldedent('''\ + The coefficient matrix must be a square + ''')) + + if b is not None: + if not isinstance(b, MatrixBase): + raise ValueError(filldedent('''\ + The non-homogeneous terms of the system of ODEs should be of type Matrix + ''')) + + if A.rows != b.rows: + raise ValueError(filldedent('''\ + The system of ODEs should have the same number of non-homogeneous terms and the number of + equations + ''')) + + if B is not None: + if not isinstance(B, MatrixBase): + raise ValueError(filldedent('''\ + The antiderivative of coefficients of the system of ODEs should be of type Matrix + ''')) + + if not B.is_square: + raise NonSquareMatrixError(filldedent('''\ + The antiderivative of the coefficient matrix must be a square + ''')) + + if A.rows != B.rows: + raise ValueError(filldedent('''\ + The coefficient matrix and its antiderivative should have same dimensions + ''')) + + if not any(type == "type{}".format(i) for i in range(1, 7)) and not type == "auto": + raise ValueError(filldedent('''\ + The input type should be a valid one + ''')) + + n = A.rows + + # constants = numbered_symbols(prefix='C', cls=Dummy, start=const_idx+1) + Cvect = Matrix([Dummy() for _ in range(n)]) + + if b is None and any(type == typ for typ in ["type2", "type4", "type6"]): + b = zeros(n, 1) + + is_transformed = tau is not None + passed_type = type + + if type == "auto": + system_info = linodesolve_type(A, t, b=b) + type = system_info["type_of_equation"] + B = system_info["antiderivative"] + + if type in ("type5", "type6"): + is_transformed = True + if passed_type != "auto": + if tau is None: + system_info = _first_order_type5_6_subs(A, t, b=b) + if not system_info: + raise ValueError(filldedent(''' + The system passed isn't {}. + '''.format(type))) + + tau = system_info['tau'] + t = system_info['t_'] + A = system_info['A'] + b = system_info['b'] + + intx_wrtt = lambda x: Integral(x, t) if x else 0 + if type in ("type1", "type2", "type5", "type6"): + P, J = matrix_exp_jordan_form(A, t) + P = simplify(P) + + if type in ("type1", "type5"): + sol_vector = P * (J * Cvect) + else: + Jinv = J.subs(t, -t) + sol_vector = P * J * ((Jinv * P.inv() * b).applyfunc(intx_wrtt) + Cvect) + else: + if B is None: + B, _ = _is_commutative_anti_derivative(A, t) + + if type == "type3": + sol_vector = B.exp() * Cvect + else: + sol_vector = B.exp() * (((-B).exp() * b).applyfunc(intx_wrtt) + Cvect) + + if is_transformed: + sol_vector = sol_vector.subs(t, tau) + + gens = sol_vector.atoms(exp) + + if type != "type1": + sol_vector = [expand_mul(s) for s in sol_vector] + + sol_vector = [collect(s, ordered(gens), exact=True) for s in sol_vector] + + if doit: + sol_vector = [s.doit() for s in sol_vector] + + return sol_vector + + +def _matrix_is_constant(M, t): + """Checks if the matrix M is independent of t or not.""" + return all(coef.as_independent(t, as_Add=True)[1] == 0 for coef in M) + + +def canonical_odes(eqs, funcs, t): + r""" + Function that solves for highest order derivatives in a system + + Explanation + =========== + + This function inputs a system of ODEs and based on the system, + the dependent variables and their highest order, returns the system + in the following form: + + .. math:: + X'(t) = A(t) X(t) + b(t) + + Here, $X(t)$ is the vector of dependent variables of lower order, $A(t)$ is + the coefficient matrix, $b(t)$ is the non-homogeneous term and $X'(t)$ is the + vector of dependent variables in their respective highest order. We use the term + canonical form to imply the system of ODEs which is of the above form. + + If the system passed has a non-linear term with multiple solutions, then a list of + systems is returned in its canonical form. + + Parameters + ========== + + eqs : List + List of the ODEs + funcs : List + List of dependent variables + t : Symbol + Independent variable + + Examples + ======== + + >>> from sympy import symbols, Function, Eq, Derivative + >>> from sympy.solvers.ode.systems import canonical_odes + >>> f, g = symbols("f g", cls=Function) + >>> x, y = symbols("x y") + >>> funcs = [f(x), g(x)] + >>> eqs = [Eq(f(x).diff(x) - 7*f(x), 12*g(x)), Eq(g(x).diff(x) + g(x), 20*f(x))] + + >>> canonical_eqs = canonical_odes(eqs, funcs, x) + >>> canonical_eqs + [[Eq(Derivative(f(x), x), 7*f(x) + 12*g(x)), Eq(Derivative(g(x), x), 20*f(x) - g(x))]] + + >>> system = [Eq(Derivative(f(x), x)**2 - 2*Derivative(f(x), x) + 1, 4), Eq(-y*f(x) + Derivative(g(x), x), 0)] + + >>> canonical_system = canonical_odes(system, funcs, x) + >>> canonical_system + [[Eq(Derivative(f(x), x), -1), Eq(Derivative(g(x), x), y*f(x))], [Eq(Derivative(f(x), x), 3), Eq(Derivative(g(x), x), y*f(x))]] + + Returns + ======= + + List + + """ + from sympy.solvers.solvers import solve + + order = _get_func_order(eqs, funcs) + + canon_eqs = solve(eqs, *[func.diff(t, order[func]) for func in funcs], dict=True) + + systems = [] + for eq in canon_eqs: + system = [Eq(func.diff(t, order[func]), eq[func.diff(t, order[func])]) for func in funcs] + systems.append(system) + + return systems + + +def _is_commutative_anti_derivative(A, t): + r""" + Helper function for determining if the Matrix passed is commutative with its antiderivative + + Explanation + =========== + + This function checks if the Matrix $A$ passed is commutative with its antiderivative with respect + to the independent variable $t$. + + .. math:: + B(t) = \int A(t) dt + + The function outputs two values, first one being the antiderivative $B(t)$, second one being a + boolean value, if True, then the matrix $A(t)$ passed is commutative with $B(t)$, else the matrix + passed isn't commutative with $B(t)$. + + Parameters + ========== + + A : Matrix + The matrix which has to be checked + t : Symbol + Independent variable + + Examples + ======== + + >>> from sympy import symbols, Matrix + >>> from sympy.solvers.ode.systems import _is_commutative_anti_derivative + >>> t = symbols("t") + >>> A = Matrix([[1, t], [-t, 1]]) + + >>> B, is_commuting = _is_commutative_anti_derivative(A, t) + >>> is_commuting + True + + Returns + ======= + + Matrix, Boolean + + """ + B = integrate(A, t) + is_commuting = (B*A - A*B).applyfunc(expand).applyfunc(factor_terms).is_zero_matrix + + is_commuting = False if is_commuting is None else is_commuting + + return B, is_commuting + + +def _factor_matrix(A, t): + term = None + for element in A: + temp_term = element.as_independent(t)[1] + if temp_term.has(t): + term = temp_term + break + + if term is not None: + A_factored = (A/term).applyfunc(ratsimp) + can_factor = _matrix_is_constant(A_factored, t) + term = (term, A_factored) if can_factor else None + + return term + + +def _is_second_order_type2(A, t): + term = _factor_matrix(A, t) + is_type2 = False + + if term is not None: + term = 1/term[0] + is_type2 = term.is_polynomial() + + if is_type2: + poly = Poly(term.expand(), t) + monoms = poly.monoms() + + if monoms[0][0] in (2, 4): + cs = _get_poly_coeffs(poly, 4) + a, b, c, d, e = cs + + a1 = powdenest(sqrt(a), force=True) + c1 = powdenest(sqrt(e), force=True) + b1 = powdenest(sqrt(c - 2*a1*c1), force=True) + + is_type2 = (b == 2*a1*b1) and (d == 2*b1*c1) + term = a1*t**2 + b1*t + c1 + + else: + is_type2 = False + + return is_type2, term + + +def _get_poly_coeffs(poly, order): + cs = [0 for _ in range(order+1)] + for c, m in zip(poly.coeffs(), poly.monoms()): + cs[-1-m[0]] = c + return cs + + +def _match_second_order_type(A1, A0, t, b=None): + r""" + Works only for second order system in its canonical form. + + Type 0: Constant coefficient matrix, can be simply solved by + introducing dummy variables. + Type 1: When the substitution: $U = t*X' - X$ works for reducing + the second order system to first order system. + Type 2: When the system is of the form: $poly * X'' = A*X$ where + $poly$ is square of a quadratic polynomial with respect to + *t* and $A$ is a constant coefficient matrix. + + """ + match = {"type_of_equation": "type0"} + n = A1.shape[0] + + if _matrix_is_constant(A1, t) and _matrix_is_constant(A0, t): + return match + + if (A1 + A0*t).applyfunc(expand_mul).is_zero_matrix: + match.update({"type_of_equation": "type1", "A1": A1}) + + elif A1.is_zero_matrix and (b is None or b.is_zero_matrix): + is_type2, term = _is_second_order_type2(A0, t) + if is_type2: + a, b, c = _get_poly_coeffs(Poly(term, t), 2) + A = (A0*(term**2).expand()).applyfunc(ratsimp) + (b**2/4 - a*c)*eye(n, n) + tau = integrate(1/term, t) + t_ = Symbol("{}_".format(t)) + match.update({"type_of_equation": "type2", "A0": A, + "g(t)": sqrt(term), "tau": tau, "is_transformed": True, + "t_": t_}) + + return match + + +def _second_order_subs_type1(A, b, funcs, t): + r""" + For a linear, second order system of ODEs, a particular substitution. + + A system of the below form can be reduced to a linear first order system of + ODEs: + .. math:: + X'' = A(t) * (t*X' - X) + b(t) + + By substituting: + .. math:: U = t*X' - X + + To get the system: + .. math:: U' = t*(A(t)*U + b(t)) + + Where $U$ is the vector of dependent variables, $X$ is the vector of dependent + variables in `funcs` and $X'$ is the first order derivative of $X$ with respect to + $t$. It may or may not reduce the system into linear first order system of ODEs. + + Then a check is made to determine if the system passed can be reduced or not, if + this substitution works, then the system is reduced and its solved for the new + substitution. After we get the solution for $U$: + + .. math:: U = a(t) + + We substitute and return the reduced system: + + .. math:: + a(t) = t*X' - X + + Parameters + ========== + + A: Matrix + Coefficient matrix($A(t)*t$) of the second order system of this form. + b: Matrix + Non-homogeneous term($b(t)$) of the system of ODEs. + funcs: List + List of dependent variables + t: Symbol + Independent variable of the system of ODEs. + + Returns + ======= + + List + + """ + + U = Matrix([t*func.diff(t) - func for func in funcs]) + + sol = linodesolve(A, t, t*b) + reduced_eqs = [Eq(u, s) for s, u in zip(sol, U)] + reduced_eqs = canonical_odes(reduced_eqs, funcs, t)[0] + + return reduced_eqs + + +def _second_order_subs_type2(A, funcs, t_): + r""" + Returns a second order system based on the coefficient matrix passed. + + Explanation + =========== + + This function returns a system of second order ODE of the following form: + + .. math:: + X'' = A * X + + Here, $X$ is the vector of dependent variables, but a bit modified, $A$ is the + coefficient matrix passed. + + Along with returning the second order system, this function also returns the new + dependent variables with the new independent variable `t_` passed. + + Parameters + ========== + + A: Matrix + Coefficient matrix of the system + funcs: List + List of old dependent variables + t_: Symbol + New independent variable + + Returns + ======= + + List, List + + """ + func_names = [func.func.__name__ for func in funcs] + new_funcs = [Function(Dummy("{}_".format(name)))(t_) for name in func_names] + rhss = A * Matrix(new_funcs) + new_eqs = [Eq(func.diff(t_, 2), rhs) for func, rhs in zip(new_funcs, rhss)] + + return new_eqs, new_funcs + + +def _is_euler_system(As, t): + return all(_matrix_is_constant((A*t**i).applyfunc(ratsimp), t) for i, A in enumerate(As)) + + +def _classify_linear_system(eqs, funcs, t, is_canon=False): + r""" + Returns a dictionary with details of the eqs if the system passed is linear + and can be classified by this function else returns None + + Explanation + =========== + + This function takes the eqs, converts it into a form Ax = b where x is a vector of terms + containing dependent variables and their derivatives till their maximum order. If it is + possible to convert eqs into Ax = b, then all the equations in eqs are linear otherwise + they are non-linear. + + To check if the equations are constant coefficient, we need to check if all the terms in + A obtained above are constant or not. + + To check if the equations are homogeneous or not, we need to check if b is a zero matrix + or not. + + Parameters + ========== + + eqs: List + List of ODEs + funcs: List + List of dependent variables + t: Symbol + Independent variable of the equations in eqs + is_canon: Boolean + If True, then this function will not try to get the + system in canonical form. Default value is False + + Returns + ======= + + match = { + 'no_of_equation': len(eqs), + 'eq': eqs, + 'func': funcs, + 'order': order, + 'is_linear': is_linear, + 'is_constant': is_constant, + 'is_homogeneous': is_homogeneous, + } + + Dict or list of Dicts or None + Dict with values for keys: + 1. no_of_equation: Number of equations + 2. eq: The set of equations + 3. func: List of dependent variables + 4. order: A dictionary that gives the order of the + dependent variable in eqs + 5. is_linear: Boolean value indicating if the set of + equations are linear or not. + 6. is_constant: Boolean value indicating if the set of + equations have constant coefficients or not. + 7. is_homogeneous: Boolean value indicating if the set of + equations are homogeneous or not. + 8. commutative_antiderivative: Antiderivative of the coefficient + matrix if the coefficient matrix is non-constant + and commutative with its antiderivative. This key + may or may not exist. + 9. is_general: Boolean value indicating if the system of ODEs is + solvable using one of the general case solvers or not. + 10. rhs: rhs of the non-homogeneous system of ODEs in Matrix form. This + key may or may not exist. + 11. is_higher_order: True if the system passed has an order greater than 1. + This key may or may not exist. + 12. is_second_order: True if the system passed is a second order ODE. This + key may or may not exist. + This Dict is the answer returned if the eqs are linear and constant + coefficient. Otherwise, None is returned. + + """ + + # Error for i == 0 can be added but isn't for now + + # Check for len(funcs) == len(eqs) + if len(funcs) != len(eqs): + raise ValueError("Number of functions given is not equal to the number of equations %s" % funcs) + + # ValueError when functions have more than one arguments + for func in funcs: + if len(func.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + + # Getting the func_dict and order using the helper + # function + order = _get_func_order(eqs, funcs) + system_order = max(order[func] for func in funcs) + is_higher_order = system_order > 1 + is_second_order = system_order == 2 and all(order[func] == 2 for func in funcs) + + # Not adding the check if the len(func.args) for + # every func in funcs is 1 + + # Linearity check + try: + + canon_eqs = canonical_odes(eqs, funcs, t) if not is_canon else [eqs] + if len(canon_eqs) == 1: + As, b = linear_ode_to_matrix(canon_eqs[0], funcs, t, system_order) + else: + + match = { + 'is_implicit': True, + 'canon_eqs': canon_eqs + } + + return match + + # When the system of ODEs is non-linear, an ODENonlinearError is raised. + # This function catches the error and None is returned. + except ODENonlinearError: + return None + + is_linear = True + + # Homogeneous check + is_homogeneous = True if b.is_zero_matrix else False + + # Is general key is used to identify if the system of ODEs can be solved by + # one of the general case solvers or not. + match = { + 'no_of_equation': len(eqs), + 'eq': eqs, + 'func': funcs, + 'order': order, + 'is_linear': is_linear, + 'is_homogeneous': is_homogeneous, + 'is_general': True + } + + if not is_homogeneous: + match['rhs'] = b + + is_constant = all(_matrix_is_constant(A_, t) for A_ in As) + + # The match['is_linear'] check will be added in the future when this + # function becomes ready to deal with non-linear systems of ODEs + + if not is_higher_order: + A = As[1] + match['func_coeff'] = A + + # Constant coefficient check + is_constant = _matrix_is_constant(A, t) + match['is_constant'] = is_constant + + try: + system_info = linodesolve_type(A, t, b=b) + except NotImplementedError: + return None + + match.update(system_info) + antiderivative = match.pop("antiderivative") + + if not is_constant: + match['commutative_antiderivative'] = antiderivative + + return match + else: + match['type_of_equation'] = "type0" + + if is_second_order: + A1, A0 = As[1:] + + match_second_order = _match_second_order_type(A1, A0, t) + match.update(match_second_order) + + match['is_second_order'] = True + + # If system is constant, then no need to check if its in euler + # form or not. It will be easier and faster to directly proceed + # to solve it. + if match['type_of_equation'] == "type0" and not is_constant: + is_euler = _is_euler_system(As, t) + if is_euler: + t_ = Symbol('{}_'.format(t)) + match.update({'is_transformed': True, 'type_of_equation': 'type1', + 't_': t_}) + else: + is_jordan = lambda M: M == Matrix.jordan_block(M.shape[0], M[0, 0]) + terms = _factor_matrix(As[-1], t) + if all(A.is_zero_matrix for A in As[1:-1]) and terms is not None and not is_jordan(terms[1]): + P, J = terms[1].jordan_form() + match.update({'type_of_equation': 'type2', 'J': J, + 'f(t)': terms[0], 'P': P, 'is_transformed': True}) + + if match['type_of_equation'] != 'type0' and is_second_order: + match.pop('is_second_order', None) + + match['is_higher_order'] = is_higher_order + + return match + +def _preprocess_eqs(eqs): + processed_eqs = [] + for eq in eqs: + processed_eqs.append(eq if isinstance(eq, Equality) else Eq(eq, 0)) + + return processed_eqs + + +def _eqs2dict(eqs, funcs): + eqsorig = {} + eqsmap = {} + funcset = set(funcs) + for eq in eqs: + f1, = eq.lhs.atoms(AppliedUndef) + f2s = (eq.rhs.atoms(AppliedUndef) - {f1}) & funcset + eqsmap[f1] = f2s + eqsorig[f1] = eq + return eqsmap, eqsorig + + +def _dict2graph(d): + nodes = list(d) + edges = [(f1, f2) for f1, f2s in d.items() for f2 in f2s] + G = (nodes, edges) + return G + + +def _is_type1(scc, t): + eqs, funcs = scc + + try: + (A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, 1) + except (ODENonlinearError, ODEOrderError): + return False + + if _matrix_is_constant(A0, t) and b.is_zero_matrix: + return True + + return False + + +def _combine_type1_subsystems(subsystem, funcs, t): + indices = [i for i, sys in enumerate(zip(subsystem, funcs)) if _is_type1(sys, t)] + remove = set() + for ip, i in enumerate(indices): + for j in indices[ip+1:]: + if any(eq2.has(funcs[i]) for eq2 in subsystem[j]): + subsystem[j] = subsystem[i] + subsystem[j] + remove.add(i) + subsystem = [sys for i, sys in enumerate(subsystem) if i not in remove] + return subsystem + + +def _component_division(eqs, funcs, t): + + # Assuming that each eq in eqs is in canonical form, + # that is, [f(x).diff(x) = .., g(x).diff(x) = .., etc] + # and that the system passed is in its first order + eqsmap, eqsorig = _eqs2dict(eqs, funcs) + + subsystems = [] + for cc in connected_components(_dict2graph(eqsmap)): + eqsmap_c = {f: eqsmap[f] for f in cc} + sccs = strongly_connected_components(_dict2graph(eqsmap_c)) + subsystem = [[eqsorig[f] for f in scc] for scc in sccs] + subsystem = _combine_type1_subsystems(subsystem, sccs, t) + subsystems.append(subsystem) + + return subsystems + + +# Returns: List of equations +def _linear_ode_solver(match): + t = match['t'] + funcs = match['func'] + + rhs = match.get('rhs', None) + tau = match.get('tau', None) + t = match['t_'] if 't_' in match else t + A = match['func_coeff'] + + # Note: To make B None when the matrix has constant + # coefficient + B = match.get('commutative_antiderivative', None) + type = match['type_of_equation'] + + sol_vector = linodesolve(A, t, b=rhs, B=B, + type=type, tau=tau) + + sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + + return sol + + +def _select_equations(eqs, funcs, key=lambda x: x): + eq_dict = {e.lhs: e.rhs for e in eqs} + return [Eq(f, eq_dict[key(f)]) for f in funcs] + + +def _higher_order_ode_solver(match): + eqs = match["eq"] + funcs = match["func"] + t = match["t"] + sysorder = match['order'] + type = match.get('type_of_equation', "type0") + + is_second_order = match.get('is_second_order', False) + is_transformed = match.get('is_transformed', False) + is_euler = is_transformed and type == "type1" + is_higher_order_type2 = is_transformed and type == "type2" and 'P' in match + + if is_second_order: + new_eqs, new_funcs = _second_order_to_first_order(eqs, funcs, t, + A1=match.get("A1", None), A0=match.get("A0", None), + b=match.get("rhs", None), type=type, + t_=match.get("t_", None)) + else: + new_eqs, new_funcs = _higher_order_to_first_order(eqs, sysorder, t, funcs=funcs, + type=type, J=match.get('J', None), + f_t=match.get('f(t)', None), + P=match.get('P', None), b=match.get('rhs', None)) + + if is_transformed: + t = match.get('t_', t) + + if not is_higher_order_type2: + new_eqs = _select_equations(new_eqs, [f.diff(t) for f in new_funcs]) + + sol = None + + # NotImplementedError may be raised when the system may be actually + # solvable if it can be just divided into sub-systems + try: + if not is_higher_order_type2: + sol = _strong_component_solver(new_eqs, new_funcs, t) + except NotImplementedError: + sol = None + + # Dividing the system only when it becomes essential + if sol is None: + try: + sol = _component_solver(new_eqs, new_funcs, t) + except NotImplementedError: + sol = None + + if sol is None: + return sol + + is_second_order_type2 = is_second_order and type == "type2" + + underscores = '__' if is_transformed else '_' + + sol = _select_equations(sol, funcs, + key=lambda x: Function(Dummy('{}{}0'.format(x.func.__name__, underscores)))(t)) + + if match.get("is_transformed", False): + if is_second_order_type2: + g_t = match["g(t)"] + tau = match["tau"] + sol = [Eq(s.lhs, s.rhs.subs(t, tau) * g_t) for s in sol] + elif is_euler: + t = match['t'] + tau = match['t_'] + sol = [s.subs(tau, log(t)) for s in sol] + elif is_higher_order_type2: + P = match['P'] + sol_vector = P * Matrix([s.rhs for s in sol]) + sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + + return sol + + +# Returns: List of equations or None +# If None is returned by this solver, then the system +# of ODEs cannot be solved directly by dsolve_system. +def _strong_component_solver(eqs, funcs, t): + from sympy.solvers.ode.ode import dsolve, constant_renumber + + match = _classify_linear_system(eqs, funcs, t, is_canon=True) + sol = None + + # Assuming that we can't get an implicit system + # since we are already canonical equations from + # dsolve_system + if match: + match['t'] = t + + if match.get('is_higher_order', False): + sol = _higher_order_ode_solver(match) + + elif match.get('is_linear', False): + sol = _linear_ode_solver(match) + + # Note: For now, only linear systems are handled by this function + # hence, the match condition is added. This can be removed later. + if sol is None and len(eqs) == 1: + sol = dsolve(eqs[0], func=funcs[0]) + variables = Tuple(eqs[0]).free_symbols + new_constants = [Dummy() for _ in range(ode_order(eqs[0], funcs[0]))] + sol = constant_renumber(sol, variables=variables, newconstants=new_constants) + sol = [sol] + + # To add non-linear case here in future + + return sol + + +def _get_funcs_from_canon(eqs): + return [eq.lhs.args[0] for eq in eqs] + + +# Returns: List of Equations(a solution) +def _weak_component_solver(wcc, t): + + # We will divide the systems into sccs + # only when the wcc cannot be solved as + # a whole + eqs = [] + for scc in wcc: + eqs += scc + funcs = _get_funcs_from_canon(eqs) + + sol = _strong_component_solver(eqs, funcs, t) + if sol: + return sol + + sol = [] + + for scc in wcc: + eqs = scc + funcs = _get_funcs_from_canon(eqs) + + # Substituting solutions for the dependent + # variables solved in previous SCC, if any solved. + comp_eqs = [eq.subs({s.lhs: s.rhs for s in sol}) for eq in eqs] + scc_sol = _strong_component_solver(comp_eqs, funcs, t) + + if scc_sol is None: + raise NotImplementedError(filldedent(''' + The system of ODEs passed cannot be solved by dsolve_system. + ''')) + + # scc_sol: List of equations + # scc_sol is a solution + sol += scc_sol + + return sol + + +# Returns: List of Equations(a solution) +def _component_solver(eqs, funcs, t): + components = _component_division(eqs, funcs, t) + sol = [] + + for wcc in components: + + # wcc_sol: List of Equations + sol += _weak_component_solver(wcc, t) + + # sol: List of Equations + return sol + + +def _second_order_to_first_order(eqs, funcs, t, type="auto", A1=None, + A0=None, b=None, t_=None): + r""" + Expects the system to be in second order and in canonical form + + Explanation + =========== + + Reduces a second order system into a first order one depending on the type of second + order system. + 1. "type0": If this is passed, then the system will be reduced to first order by + introducing dummy variables. + 2. "type1": If this is passed, then a particular substitution will be used to reduce the + the system into first order. + 3. "type2": If this is passed, then the system will be transformed with new dependent + variables and independent variables. This transformation is a part of solving + the corresponding system of ODEs. + + `A1` and `A0` are the coefficient matrices from the system and it is assumed that the + second order system has the form given below: + + .. math:: + A2 * X'' = A1 * X' + A0 * X + b + + Here, $A2$ is the coefficient matrix for the vector $X''$ and $b$ is the non-homogeneous + term. + + Default value for `b` is None but if `A1` and `A0` are passed and `b` is not passed, then the + system will be assumed homogeneous. + + """ + is_a1 = A1 is None + is_a0 = A0 is None + + if (type == "type1" and is_a1) or (type == "type2" and is_a0)\ + or (type == "auto" and (is_a1 or is_a0)): + (A2, A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, 2) + + if not A2.is_Identity: + raise ValueError(filldedent(''' + The system must be in its canonical form. + ''')) + + if type == "auto": + match = _match_second_order_type(A1, A0, t) + type = match["type_of_equation"] + A1 = match.get("A1", None) + A0 = match.get("A0", None) + + sys_order = dict.fromkeys(funcs, 2) + + if type == "type1": + if b is None: + b = zeros(len(eqs)) + eqs = _second_order_subs_type1(A1, b, funcs, t) + sys_order = dict.fromkeys(funcs, 1) + + if type == "type2": + if t_ is None: + t_ = Symbol("{}_".format(t)) + t = t_ + eqs, funcs = _second_order_subs_type2(A0, funcs, t_) + sys_order = dict.fromkeys(funcs, 2) + + return _higher_order_to_first_order(eqs, sys_order, t, funcs=funcs) + + +def _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, b=None, P=None): + + # Note: To add a test for this ValueError + if J is None or f_t is None or not _matrix_is_constant(J, t): + raise ValueError(filldedent(''' + Correctly input for args 'A' and 'f_t' for Linear, Higher Order, + Type 2 + ''')) + + if P is None and b is not None and not b.is_zero_matrix: + raise ValueError(filldedent(''' + Provide the keyword 'P' for matrix P in A = P * J * P-1. + ''')) + + new_funcs = Matrix([Function(Dummy('{}__0'.format(f.func.__name__)))(t) for f in funcs]) + new_eqs = new_funcs.diff(t, max_order) - f_t * J * new_funcs + + if b is not None and not b.is_zero_matrix: + new_eqs -= P.inv() * b + + new_eqs = canonical_odes(new_eqs, new_funcs, t)[0] + + return new_eqs, new_funcs + + +def _higher_order_to_first_order(eqs, sys_order, t, funcs=None, type="type0", **kwargs): + if funcs is None: + funcs = sys_order.keys() + + # Standard Cauchy Euler system + if type == "type1": + t_ = Symbol('{}_'.format(t)) + new_funcs = [Function(Dummy('{}_'.format(f.func.__name__)))(t_) for f in funcs] + max_order = max(sys_order[func] for func in funcs) + subs_dict = dict(zip(funcs, new_funcs)) + subs_dict[t] = exp(t_) + + free_function = Function(Dummy()) + + def _get_coeffs_from_subs_expression(expr): + if isinstance(expr, Subs): + free_symbol = expr.args[1][0] + term = expr.args[0] + return {ode_order(term, free_symbol): 1} + + if isinstance(expr, Mul): + coeff = expr.args[0] + order = list(_get_coeffs_from_subs_expression(expr.args[1]).keys())[0] + return {order: coeff} + + if isinstance(expr, Add): + coeffs = {} + for arg in expr.args: + + if isinstance(arg, Mul): + coeffs.update(_get_coeffs_from_subs_expression(arg)) + + else: + order = list(_get_coeffs_from_subs_expression(arg).keys())[0] + coeffs[order] = 1 + + return coeffs + + for o in range(1, max_order + 1): + expr = free_function(log(t_)).diff(t_, o)*t_**o + coeff_dict = _get_coeffs_from_subs_expression(expr) + coeffs = [coeff_dict[order] if order in coeff_dict else 0 for order in range(o + 1)] + expr_to_subs = sum(free_function(t_).diff(t_, i) * c for i, c in + enumerate(coeffs)) / t**o + subs_dict.update({f.diff(t, o): expr_to_subs.subs(free_function(t_), nf) + for f, nf in zip(funcs, new_funcs)}) + + new_eqs = [eq.subs(subs_dict) for eq in eqs] + new_sys_order = {nf: sys_order[f] for f, nf in zip(funcs, new_funcs)} + + new_eqs = canonical_odes(new_eqs, new_funcs, t_)[0] + + return _higher_order_to_first_order(new_eqs, new_sys_order, t_, funcs=new_funcs) + + # Systems of the form: X(n)(t) = f(t)*A*X + b + # where X(n)(t) is the nth derivative of the vector of dependent variables + # with respect to the independent variable and A is a constant matrix. + if type == "type2": + J = kwargs.get('J', None) + f_t = kwargs.get('f_t', None) + b = kwargs.get('b', None) + P = kwargs.get('P', None) + max_order = max(sys_order[func] for func in funcs) + + return _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, P=P, b=b) + + # Note: To be changed to this after doit option is disabled for default cases + # new_sysorder = _get_func_order(new_eqs, new_funcs) + # + # return _higher_order_to_first_order(new_eqs, new_sysorder, t, funcs=new_funcs) + + new_funcs = [] + + for prev_func in funcs: + func_name = prev_func.func.__name__ + func = Function(Dummy('{}_0'.format(func_name)))(t) + new_funcs.append(func) + subs_dict = {prev_func: func} + new_eqs = [] + + for i in range(1, sys_order[prev_func]): + new_func = Function(Dummy('{}_{}'.format(func_name, i)))(t) + subs_dict[prev_func.diff(t, i)] = new_func + new_funcs.append(new_func) + + prev_f = subs_dict[prev_func.diff(t, i-1)] + new_eq = Eq(prev_f.diff(t), new_func) + new_eqs.append(new_eq) + + eqs = [eq.subs(subs_dict) for eq in eqs] + new_eqs + + return eqs, new_funcs + + +def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False, simplify=True): + r""" + Solves any(supported) system of Ordinary Differential Equations + + Explanation + =========== + + This function takes a system of ODEs as an input, determines if the + it is solvable by this function, and returns the solution if found any. + + This function can handle: + 1. Linear, First Order, Constant coefficient homogeneous system of ODEs + 2. Linear, First Order, Constant coefficient non-homogeneous system of ODEs + 3. Linear, First Order, non-constant coefficient homogeneous system of ODEs + 4. Linear, First Order, non-constant coefficient non-homogeneous system of ODEs + 5. Any implicit system which can be divided into system of ODEs which is of the above 4 forms + 6. Any higher order linear system of ODEs that can be reduced to one of the 5 forms of systems described above. + + The types of systems described above are not limited by the number of equations, i.e. this + function can solve the above types irrespective of the number of equations in the system passed. + But, the bigger the system, the more time it will take to solve the system. + + This function returns a list of solutions. Each solution is a list of equations where LHS is + the dependent variable and RHS is an expression in terms of the independent variable. + + Among the non constant coefficient types, not all the systems are solvable by this function. Only + those which have either a coefficient matrix with a commutative antiderivative or those systems which + may be divided further so that the divided systems may have coefficient matrix with commutative antiderivative. + + Parameters + ========== + + eqs : List + system of ODEs to be solved + funcs : List or None + List of dependent variables that make up the system of ODEs + t : Symbol or None + Independent variable in the system of ODEs + ics : Dict or None + Set of initial boundary/conditions for the system of ODEs + doit : Boolean + Evaluate the solutions if True. Default value is True. Can be + set to false if the integral evaluation takes too much time and/or + is not required. + simplify: Boolean + Simplify the solutions for the systems. Default value is True. + Can be set to false if simplification takes too much time and/or + is not required. + + Examples + ======== + + >>> from sympy import symbols, Eq, Function + >>> from sympy.solvers.ode.systems import dsolve_system + >>> f, g = symbols("f g", cls=Function) + >>> x = symbols("x") + + >>> eqs = [Eq(f(x).diff(x), g(x)), Eq(g(x).diff(x), f(x))] + >>> dsolve_system(eqs) + [[Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), C1*exp(-x) + C2*exp(x))]] + + You can also pass the initial conditions for the system of ODEs: + + >>> dsolve_system(eqs, ics={f(0): 1, g(0): 0}) + [[Eq(f(x), exp(x)/2 + exp(-x)/2), Eq(g(x), exp(x)/2 - exp(-x)/2)]] + + Optionally, you can pass the dependent variables and the independent + variable for which the system is to be solved: + + >>> funcs = [f(x), g(x)] + >>> dsolve_system(eqs, funcs=funcs, t=x) + [[Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), C1*exp(-x) + C2*exp(x))]] + + Lets look at an implicit system of ODEs: + + >>> eqs = [Eq(f(x).diff(x)**2, g(x)**2), Eq(g(x).diff(x), g(x))] + >>> dsolve_system(eqs) + [[Eq(f(x), C1 - C2*exp(x)), Eq(g(x), C2*exp(x))], [Eq(f(x), C1 + C2*exp(x)), Eq(g(x), C2*exp(x))]] + + Returns + ======= + + List of List of Equations + + Raises + ====== + + NotImplementedError + When the system of ODEs is not solvable by this function. + ValueError + When the parameters passed are not in the required form. + + """ + from sympy.solvers.ode.ode import solve_ics, _extract_funcs, constant_renumber + + if not iterable(eqs): + raise ValueError(filldedent(''' + List of equations should be passed. The input is not valid. + ''')) + + eqs = _preprocess_eqs(eqs) + + if funcs is not None and not isinstance(funcs, list): + raise ValueError(filldedent(''' + Input to the funcs should be a list of functions. + ''')) + + if funcs is None: + funcs = _extract_funcs(eqs) + + if any(len(func.args) != 1 for func in funcs): + raise ValueError(filldedent(''' + dsolve_system can solve a system of ODEs with only one independent + variable. + ''')) + + if len(eqs) != len(funcs): + raise ValueError(filldedent(''' + Number of equations and number of functions do not match + ''')) + + if t is not None and not isinstance(t, Symbol): + raise ValueError(filldedent(''' + The independent variable must be of type Symbol + ''')) + + if t is None: + t = list(list(eqs[0].atoms(Derivative))[0].atoms(Symbol))[0] + + sols = [] + canon_eqs = canonical_odes(eqs, funcs, t) + + for canon_eq in canon_eqs: + try: + sol = _strong_component_solver(canon_eq, funcs, t) + except NotImplementedError: + sol = None + + if sol is None: + sol = _component_solver(canon_eq, funcs, t) + + sols.append(sol) + + if sols: + final_sols = [] + variables = Tuple(*eqs).free_symbols + + for sol in sols: + + sol = _select_equations(sol, funcs) + sol = constant_renumber(sol, variables=variables) + + if ics: + constants = Tuple(*sol).free_symbols - variables + solved_constants = solve_ics(sol, funcs, constants, ics) + sol = [s.subs(solved_constants) for s in sol] + + if simplify: + constants = Tuple(*sol).free_symbols - variables + sol = simpsol(sol, [t], constants, doit=doit) + + final_sols.append(sol) + + sols = final_sols + + return sols diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..451582a43aeaef7c4c689b36fa47f21d6bc46521 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e2e0703be47a1b18debd9a95106fd5b2c6f3dd2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_ode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_ode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60cd14820f04befc98351c049feb027df7c9b789 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_ode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d11b12202bfadd23df2e7b4cb3efdeb3a2d9cafd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_single.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_single.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f99c303860cff45163a96168e30579b9ec12a2ae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_single.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6c1a8fe3d3615ff49856b524f561a6d51d3106 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_lie_group.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_lie_group.py new file mode 100644 index 0000000000000000000000000000000000000000..153d30ff563773819e49c989f447c1ec7962169b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_lie_group.py @@ -0,0 +1,152 @@ +from sympy.core.function import Function +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, sin, tan) + +from sympy.solvers.ode import (classify_ode, checkinfsol, dsolve, infinitesimals) + +from sympy.solvers.ode.subscheck import checkodesol + +from sympy.testing.pytest import XFAIL + + +C1 = Symbol('C1') +x, y = symbols("x y") +f = Function('f') +xi = Function('xi') +eta = Function('eta') + + +def test_heuristic1(): + a, b, c, a4, a3, a2, a1, a0 = symbols("a b c a4 a3 a2 a1 a0") + df = f(x).diff(x) + eq = Eq(df, x**2*f(x)) + eq1 = f(x).diff(x) + a*f(x) - c*exp(b*x) + eq2 = f(x).diff(x) + 2*x*f(x) - x*exp(-x**2) + eq3 = (1 + 2*x)*df + 2 - 4*exp(-f(x)) + eq4 = f(x).diff(x) - (a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**Rational(-1, 2) + eq5 = x**2*df - f(x) + x**2*exp(x - (1/x)) + eqlist = [eq, eq1, eq2, eq3, eq4, eq5] + + i = infinitesimals(eq, hint='abaco1_simple') + assert i == [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0}, + {eta(x, f(x)): f(x), xi(x, f(x)): 0}, + {eta(x, f(x)): 0, xi(x, f(x)): x**(-2)}] + i1 = infinitesimals(eq1, hint='abaco1_simple') + assert i1 == [{eta(x, f(x)): exp(-a*x), xi(x, f(x)): 0}] + i2 = infinitesimals(eq2, hint='abaco1_simple') + assert i2 == [{eta(x, f(x)): exp(-x**2), xi(x, f(x)): 0}] + i3 = infinitesimals(eq3, hint='abaco1_simple') + assert i3 == [{eta(x, f(x)): 0, xi(x, f(x)): 2*x + 1}, + {eta(x, f(x)): 0, xi(x, f(x)): 1/(exp(f(x)) - 2)}] + i4 = infinitesimals(eq4, hint='abaco1_simple') + assert i4 == [{eta(x, f(x)): 1, xi(x, f(x)): 0}, + {eta(x, f(x)): 0, + xi(x, f(x)): sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)}] + i5 = infinitesimals(eq5, hint='abaco1_simple') + assert i5 == [{xi(x, f(x)): 0, eta(x, f(x)): exp(-1/x)}] + + ilist = [i, i1, i2, i3, i4, i5] + for eq, i in (zip(eqlist, ilist)): + check = checkinfsol(eq, i) + assert check[0] + + # This ODE can be solved by the Lie Group method, when there are + # better assumptions + eq6 = df - (f(x)/x)*(x*log(x**2/f(x)) + 2) + i = infinitesimals(eq6, hint='abaco1_product') + assert i == [{eta(x, f(x)): f(x)*exp(-x), xi(x, f(x)): 0}] + assert checkinfsol(eq6, i)[0] + + eq7 = x*(f(x).diff(x)) + 1 - f(x)**2 + i = infinitesimals(eq7, hint='chi') + assert checkinfsol(eq7, i)[0] + + +def test_heuristic3(): + a, b = symbols("a b") + df = f(x).diff(x) + + eq = x**2*df + x*f(x) + f(x)**2 + x**2 + i = infinitesimals(eq, hint='bivariate') + assert i == [{eta(x, f(x)): f(x), xi(x, f(x)): x}] + assert checkinfsol(eq, i)[0] + + eq = x**2*(-f(x)**2 + df)- a*x**2*f(x) + 2 - a*x + i = infinitesimals(eq, hint='bivariate') + assert checkinfsol(eq, i)[0] + + +def test_heuristic_function_sum(): + eq = f(x).diff(x) - (3*(1 + x**2/f(x)**2)*atan(f(x)/x) + (1 - 2*f(x))/x + + (1 - 3*f(x))*(x/f(x)**2)) + i = infinitesimals(eq, hint='function_sum') + assert i == [{eta(x, f(x)): f(x)**(-2) + x**(-2), xi(x, f(x)): 0}] + assert checkinfsol(eq, i)[0] + + +def test_heuristic_abaco2_similar(): + a, b = symbols("a b") + F = Function('F') + eq = f(x).diff(x) - F(a*x + b*f(x)) + i = infinitesimals(eq, hint='abaco2_similar') + assert i == [{eta(x, f(x)): -a/b, xi(x, f(x)): 1}] + assert checkinfsol(eq, i)[0] + + eq = f(x).diff(x) - (f(x)**2 / (sin(f(x) - x) - x**2 + 2*x*f(x))) + i = infinitesimals(eq, hint='abaco2_similar') + assert i == [{eta(x, f(x)): f(x)**2, xi(x, f(x)): f(x)**2}] + assert checkinfsol(eq, i)[0] + + +def test_heuristic_abaco2_unique_unknown(): + + a, b = symbols("a b") + F = Function('F') + eq = f(x).diff(x) - x**(a - 1)*(f(x)**(1 - b))*F(x**a/a + f(x)**b/b) + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert i == [{eta(x, f(x)): -f(x)*f(x)**(-b), xi(x, f(x)): x*x**(-a)}] + assert checkinfsol(eq, i)[0] + + eq = f(x).diff(x) + tan(F(x**2 + f(x)**2) + atan(x/f(x))) + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert i == [{eta(x, f(x)): x, xi(x, f(x)): -f(x)}] + assert checkinfsol(eq, i)[0] + + eq = (x*f(x).diff(x) + f(x) + 2*x)**2 -4*x*f(x) -4*x**2 -4*a + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert checkinfsol(eq, i)[0] + + +def test_heuristic_linear(): + a, b, m, n = symbols("a b m n") + + eq = x**(n*(m + 1) - m)*(f(x).diff(x)) - a*f(x)**n -b*x**(n*(m + 1)) + i = infinitesimals(eq, hint='linear') + assert checkinfsol(eq, i)[0] + + +@XFAIL +def test_kamke(): + a, b, alpha, c = symbols("a b alpha c") + eq = x**2*(a*f(x)**2+(f(x).diff(x))) + b*x**alpha + c + i = infinitesimals(eq, hint='sum_function') # XFAIL + assert checkinfsol(eq, i)[0] + + +def test_user_infinitesimals(): + x = Symbol("x") # assuming x is real generates an error + eq = x*(f(x).diff(x)) + 1 - f(x)**2 + sol = Eq(f(x), (C1 + x**2)/(C1 - x**2)) + infinitesimals = {'xi':sqrt(f(x) - 1)/sqrt(f(x) + 1), 'eta':0} + assert dsolve(eq, hint='lie_group', **infinitesimals) == sol + assert checkodesol(eq, sol) == (True, 0) + + +@XFAIL +def test_lie_group_issue15219(): + eqn = exp(f(x).diff(x)-f(x)) + assert 'lie_group' not in classify_ode(eqn, f(x)) diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_ode.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_ode.py new file mode 100644 index 0000000000000000000000000000000000000000..65e0fa62d52445a4669f3cdc5ef278dbf9c88ea4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_ode.py @@ -0,0 +1,1105 @@ +from sympy.core.function import (Derivative, Function, Subs, diff) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import acosh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin, tan) +from sympy.integrals.integrals import Integral +from sympy.polys.polytools import Poly +from sympy.series.order import O +from sympy.simplify.radsimp import collect + +from sympy.solvers.ode import (classify_ode, + homogeneous_order, dsolve) + +from sympy.solvers.ode.subscheck import checkodesol +from sympy.solvers.ode.ode import (classify_sysode, + constant_renumber, constantsimp, get_numbered_constants, solve_ics) + +from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match +from sympy.solvers.ode.single import LinearCoefficients +from sympy.solvers.deutils import ode_order +from sympy.testing.pytest import XFAIL, raises, slow, SKIP +from sympy.utilities.misc import filldedent + + +C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C0:11') +u, x, y, z = symbols('u,x:z', real=True) +f = Function('f') +g = Function('g') +h = Function('h') + +# Note: Examples which were specifically testing Single ODE solver are moved to test_single.py +# and all the system of ode examples are moved to test_systems.py +# Note: the tests below may fail (but still be correct) if ODE solver, +# the integral engine, solve(), or even simplify() changes. Also, in +# differently formatted solutions, the arbitrary constants might not be +# equal. Using specific hints in tests can help to avoid this. + +# Tests of order higher than 1 should run the solutions through +# constant_renumber because it will normalize it (constant_renumber causes +# dsolve() to return different results on different machines) + + +def test_get_numbered_constants(): + with raises(ValueError): + get_numbered_constants(None) + + +def test_dsolve_all_hint(): + eq = f(x).diff(x) + output = dsolve(eq, hint='all') + + # Match the Dummy variables: + sol1 = output['separable_Integral'] + _y = sol1.lhs.args[1][0] + sol1 = output['1st_homogeneous_coeff_subs_dep_div_indep_Integral'] + _u1 = sol1.rhs.args[1].args[1][0] + + expected = {'Bernoulli_Integral': Eq(f(x), C1 + Integral(0, x)), + '1st_homogeneous_coeff_best': Eq(f(x), C1), + 'Bernoulli': Eq(f(x), C1), + 'nth_algebraic': Eq(f(x), C1), + 'nth_linear_euler_eq_homogeneous': Eq(f(x), C1), + 'nth_linear_constant_coeff_homogeneous': Eq(f(x), C1), + 'separable': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_indep_div_dep': Eq(f(x), C1), + 'nth_algebraic_Integral': Eq(f(x), C1), + '1st_linear': Eq(f(x), C1), + '1st_linear_Integral': Eq(f(x), C1 + Integral(0, x)), + '1st_exact': Eq(f(x), C1), + '1st_exact_Integral': Eq(Subs(Integral(0, x) + Integral(1, _y), _y, f(x)), C1), + 'lie_group': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_dep_div_indep': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_dep_div_indep_Integral': Eq(log(x), C1 + Integral(-1/_u1, (_u1, f(x)/x))), + '1st_power_series': Eq(f(x), C1), + 'separable_Integral': Eq(Integral(1, (_y, f(x))), C1 + Integral(0, x)), + '1st_homogeneous_coeff_subs_indep_div_dep_Integral': Eq(f(x), C1), + 'best': Eq(f(x), C1), + 'best_hint': 'nth_algebraic', + 'default': 'nth_algebraic', + 'order': 1} + assert output == expected + + assert dsolve(eq, hint='best') == Eq(f(x), C1) + + +def test_dsolve_ics(): + # Maybe this should just use one of the solutions instead of raising... + with raises(NotImplementedError): + dsolve(f(x).diff(x) - sqrt(f(x)), ics={f(1):1}) + + +@slow +def test_dsolve_options(): + eq = x*f(x).diff(x) + f(x) + a = dsolve(eq, hint='all') + b = dsolve(eq, hint='all', simplify=False) + c = dsolve(eq, hint='all_Integral') + keys = ['1st_exact', '1st_exact_Integral', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', '1st_linear', + '1st_linear_Integral', 'Bernoulli', 'Bernoulli_Integral', + 'almost_linear', 'almost_linear_Integral', 'best', 'best_hint', + 'default', 'factorable', 'lie_group', + 'nth_linear_euler_eq_homogeneous', 'order', + 'separable', 'separable_Integral'] + Integral_keys = ['1st_exact_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', '1st_linear_Integral', + 'Bernoulli_Integral', 'almost_linear_Integral', 'best', 'best_hint', 'default', + 'factorable', 'nth_linear_euler_eq_homogeneous', + 'order', 'separable_Integral'] + assert sorted(a.keys()) == keys + assert a['order'] == ode_order(eq, f(x)) + assert a['best'] == Eq(f(x), C1/x) + assert dsolve(eq, hint='best') == Eq(f(x), C1/x) + assert a['default'] == 'factorable' + assert a['best_hint'] == 'factorable' + assert not a['1st_exact'].has(Integral) + assert not a['separable'].has(Integral) + assert not a['1st_homogeneous_coeff_best'].has(Integral) + assert not a['1st_homogeneous_coeff_subs_dep_div_indep'].has(Integral) + assert not a['1st_homogeneous_coeff_subs_indep_div_dep'].has(Integral) + assert not a['1st_linear'].has(Integral) + assert a['1st_linear_Integral'].has(Integral) + assert a['1st_exact_Integral'].has(Integral) + assert a['1st_homogeneous_coeff_subs_dep_div_indep_Integral'].has(Integral) + assert a['1st_homogeneous_coeff_subs_indep_div_dep_Integral'].has(Integral) + assert a['separable_Integral'].has(Integral) + assert sorted(b.keys()) == keys + assert b['order'] == ode_order(eq, f(x)) + assert b['best'] == Eq(f(x), C1/x) + assert dsolve(eq, hint='best', simplify=False) == Eq(f(x), C1/x) + assert b['default'] == 'factorable' + assert b['best_hint'] == 'factorable' + assert a['separable'] != b['separable'] + assert a['1st_homogeneous_coeff_subs_dep_div_indep'] != \ + b['1st_homogeneous_coeff_subs_dep_div_indep'] + assert a['1st_homogeneous_coeff_subs_indep_div_dep'] != \ + b['1st_homogeneous_coeff_subs_indep_div_dep'] + assert not b['1st_exact'].has(Integral) + assert not b['separable'].has(Integral) + assert not b['1st_homogeneous_coeff_best'].has(Integral) + assert not b['1st_homogeneous_coeff_subs_dep_div_indep'].has(Integral) + assert not b['1st_homogeneous_coeff_subs_indep_div_dep'].has(Integral) + assert not b['1st_linear'].has(Integral) + assert b['1st_linear_Integral'].has(Integral) + assert b['1st_exact_Integral'].has(Integral) + assert b['1st_homogeneous_coeff_subs_dep_div_indep_Integral'].has(Integral) + assert b['1st_homogeneous_coeff_subs_indep_div_dep_Integral'].has(Integral) + assert b['separable_Integral'].has(Integral) + assert sorted(c.keys()) == Integral_keys + raises(ValueError, lambda: dsolve(eq, hint='notarealhint')) + raises(ValueError, lambda: dsolve(eq, hint='Liouville')) + assert dsolve(f(x).diff(x) - 1/f(x)**2, hint='all')['best'] == \ + dsolve(f(x).diff(x) - 1/f(x)**2, hint='best') + assert dsolve(f(x) + f(x).diff(x) + sin(x).diff(x) + 1, f(x), + hint="1st_linear_Integral") == \ + Eq(f(x), (C1 + Integral((-sin(x).diff(x) - 1)* + exp(Integral(1, x)), x))*exp(-Integral(1, x))) + + +def test_classify_ode(): + assert classify_ode(f(x).diff(x, 2), f(x)) == \ + ( + 'nth_algebraic', + 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'Liouville', + '2nd_power_series_ordinary', + 'nth_algebraic_Integral', + 'Liouville_Integral', + ) + assert classify_ode(f(x), f(x)) == ('nth_algebraic', 'nth_algebraic_Integral') + assert classify_ode(Eq(f(x).diff(x), 0), f(x)) == ( + 'nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', + 'separable_Integral', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + assert classify_ode(f(x).diff(x)**2, f(x)) == ('factorable', + 'nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', + 'lie_group', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', + 'separable_Integral', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + # issue 4749: f(x) should be cleared from highest derivative before classifying + a = classify_ode(Eq(f(x).diff(x) + f(x), x), f(x)) + b = classify_ode(f(x).diff(x)*f(x) + f(x)*f(x) - x*f(x), f(x)) + c = classify_ode(f(x).diff(x)/f(x) + f(x)/f(x) - x/f(x), f(x)) + assert a == ('1st_exact', + '1st_linear', + 'Bernoulli', + 'almost_linear', + '1st_power_series', "lie_group", + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'almost_linear_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + assert b == ('factorable', + '1st_linear', + 'Bernoulli', + '1st_power_series', + 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + assert c == ('factorable', + '1st_linear', + 'Bernoulli', + '1st_power_series', + 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + + assert classify_ode( + 2*x*f(x)*f(x).diff(x) + (1 + x)*f(x)**2 - exp(x), f(x) + ) == ('factorable', '1st_exact', 'Bernoulli', 'almost_linear', 'lie_group', + '1st_exact_Integral', 'Bernoulli_Integral', 'almost_linear_Integral') + assert 'Riccati_special_minus2' in \ + classify_ode(2*f(x).diff(x) + f(x)**2 - f(x)/x + 3*x**(-2), f(x)) + raises(ValueError, lambda: classify_ode(x + f(x, y).diff(x).diff( + y), f(x, y))) + # issue 5176 + k = Symbol('k') + assert classify_ode(f(x).diff(x)/(k*f(x) + k*x*f(x)) + 2*f(x)/(k*f(x) + + k*x*f(x)) + x*f(x).diff(x)/(k*f(x) + k*x*f(x)) + z, f(x)) == \ + ('factorable', 'separable', '1st_exact', '1st_linear', 'Bernoulli', + '1st_power_series', 'lie_group', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral') + # preprocessing + ans = ('factorable', 'nth_algebraic', 'separable', '1st_exact', '1st_linear', 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters', + 'nth_algebraic_Integral', + 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral') + # w/o f(x) given + assert classify_ode(diff(f(x) + x, x) + diff(f(x), x)) == ans + # w/ f(x) and prep=True + assert classify_ode(diff(f(x) + x, x) + diff(f(x), x), f(x), + prep=True) == ans + + assert classify_ode(Eq(2*x**3*f(x).diff(x), 0), f(x)) == \ + ('factorable', 'nth_algebraic', 'separable', '1st_exact', + '1st_linear', 'Bernoulli', '1st_power_series', + 'lie_group', 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral') + + + assert classify_ode(Eq(2*f(x)**3*f(x).diff(x), 0), f(x)) == \ + ('factorable', 'nth_algebraic', 'separable', '1st_exact', '1st_linear', + 'Bernoulli', '1st_power_series', 'lie_group', 'nth_algebraic_Integral', + 'separable_Integral', '1st_exact_Integral', '1st_linear_Integral', + 'Bernoulli_Integral') + # test issue 13864 + assert classify_ode(Eq(diff(f(x), x) - f(x)**x, 0), f(x)) == \ + ('1st_power_series', 'lie_group') + assert isinstance(classify_ode(Eq(f(x), 5), f(x), dict=True), dict) + + #This is for new behavior of classify_ode when called internally with default, It should + # return the first hint which matches therefore, 'ordered_hints' key will not be there. + assert sorted(classify_ode(Eq(f(x).diff(x), 0), f(x), dict=True).keys()) == \ + ['default', 'nth_linear_constant_coeff_homogeneous', 'order'] + a = classify_ode(2*x*f(x)*f(x).diff(x) + (1 + x)*f(x)**2 - exp(x), f(x), dict=True, hint='Bernoulli') + assert sorted(a.keys()) == ['Bernoulli', 'Bernoulli_Integral', 'default', 'order', 'ordered_hints'] + + # test issue 22155 + a = classify_ode(f(x).diff(x) - exp(f(x) - x), f(x)) + assert a == ('separable', + '1st_exact', '1st_power_series', + 'lie_group', 'separable_Integral', + '1st_exact_Integral') + + +def test_classify_ode_ics(): + # Dummy + eq = f(x).diff(x, x) - f(x) + + # Not f(0) or f'(0) + ics = {x: 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + + ############################ + # f(0) type (AppliedUndef) # + ############################ + + + # Wrong function + ics = {g(0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(0, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(0): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(0): f(0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(0): 1} + classify_ode(eq, f(x), ics=ics) + + + ##################### + # f'(0) type (Subs) # + ##################### + + # Wrong function + ics = {g(x).diff(x).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(y).diff(y).subs(y, x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Wrong variable + ics = {f(y).diff(y).subs(y, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(x, y).diff(x).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Derivative wrt wrong vars + ics = {Derivative(f(x), x, y).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(x).diff(x).subs(x, 0): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): f(x).diff(x).subs(x, 0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): 1} + classify_ode(eq, f(x), ics=ics) + + ########################### + # f'(y) type (Derivative) # + ########################### + + # Wrong function + ics = {g(x).diff(x).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(y).diff(y).subs(y, x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(x, y).diff(x).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Derivative wrt wrong vars + ics = {Derivative(f(x), x, z).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(x).diff(x).subs(x, y): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): f(0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(x).diff(x).subs(x, y): 1} + classify_ode(eq, f(x), ics=ics) + +def test_classify_sysode(): + # Here x is assumed to be x(t) and y as y(t) for simplicity. + # Similarly diff(x,t) and diff(y,y) is assumed to be x1 and y1 respectively. + k, l, m, n = symbols('k, l, m, n', Integer=True) + k1, k2, k3, l1, l2, l3, m1, m2, m3 = symbols('k1, k2, k3, l1, l2, l3, m1, m2, m3', Integer=True) + P, Q, R, p, q, r = symbols('P, Q, R, p, q, r', cls=Function) + P1, P2, P3, Q1, Q2, R1, R2 = symbols('P1, P2, P3, Q1, Q2, R1, R2', cls=Function) + x, y, z = symbols('x, y, z', cls=Function) + t = symbols('t') + x1 = diff(x(t),t) + y1 = diff(y(t),t) + + eq6 = (Eq(x1, exp(k*x(t))*P(x(t),y(t))), Eq(y1,r(y(t))*P(x(t),y(t)))) + sol6 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): 0, (1, x(t), 1): 0, (0, x(t), 1): 1, (1, y(t), 0): 0, \ + (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): 0, (1, y(t), 1): 1}, 'type_of_equation': 'type2', 'func': \ + [x(t), y(t)], 'is_linear': False, 'eq': [-P(x(t), y(t))*exp(k*x(t)) + Derivative(x(t), t), -P(x(t), \ + y(t))*r(y(t)) + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq6) == sol6 + + eq7 = (Eq(x1, x(t)**2+y(t)/x(t)), Eq(y1, x(t)/y(t))) + sol7 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): 0, (1, x(t), 1): 0, (0, x(t), 1): 1, (1, y(t), 0): 0, \ + (1, x(t), 0): -1/y(t), (0, y(t), 1): 0, (0, y(t), 0): -1/x(t), (1, y(t), 1): 1}, 'type_of_equation': 'type3', \ + 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)**2 + Derivative(x(t), t) - y(t)/x(t), -x(t)/y(t) + \ + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq7) == sol7 + + eq8 = (Eq(x1, P1(x(t))*Q1(y(t))*R(x(t),y(t),t)), Eq(y1, P1(x(t))*Q1(y(t))*R(x(t),y(t),t))) + sol8 = {'func': [x(t), y(t)], 'is_linear': False, 'type_of_equation': 'type4', 'eq': \ + [-P1(x(t))*Q1(y(t))*R(x(t), y(t), t) + Derivative(x(t), t), -P1(x(t))*Q1(y(t))*R(x(t), y(t), t) + \ + Derivative(y(t), t)], 'func_coeff': {(0, y(t), 1): 0, (1, y(t), 1): 1, (1, x(t), 1): 0, (0, y(t), 0): 0, \ + (1, x(t), 0): 0, (0, x(t), 0): 0, (1, y(t), 0): 0, (0, x(t), 1): 1}, 'order': {y(t): 1, x(t): 1}, 'no_of_equation': 2} + assert classify_sysode(eq8) == sol8 + + eq11 = (Eq(x1,x(t)*y(t)**3), Eq(y1,y(t)**5)) + sol11 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): -y(t)**3, (1, x(t), 1): 0, (0, x(t), 1): 1, \ + (1, y(t), 0): 0, (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): 0, (1, y(t), 1): 1}, 'type_of_equation': \ + 'type1', 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)*y(t)**3 + Derivative(x(t), t), \ + -y(t)**5 + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq11) == sol11 + + eq13 = (Eq(x1,x(t)*y(t)*sin(t)**2), Eq(y1,y(t)**2*sin(t)**2)) + sol13 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): -y(t)*sin(t)**2, (1, x(t), 1): 0, (0, x(t), 1): 1, \ + (1, y(t), 0): 0, (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): -x(t)*sin(t)**2, (1, y(t), 1): 1}, \ + 'type_of_equation': 'type4', 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)*y(t)*sin(t)**2 + \ + Derivative(x(t), t), -y(t)**2*sin(t)**2 + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq13) == sol13 + + +def test_solve_ics(): + # Basic tests that things work from dsolve. + assert dsolve(f(x).diff(x) - 1/f(x), f(x), ics={f(1): 2}) == \ + Eq(f(x), sqrt(2 * x + 2)) + assert dsolve(f(x).diff(x) - f(x), f(x), ics={f(0): 1}) == Eq(f(x), exp(x)) + assert dsolve(f(x).diff(x) - f(x), f(x), ics={f(x).diff(x).subs(x, 0): 1}) == Eq(f(x), exp(x)) + assert dsolve(f(x).diff(x, x) + f(x), f(x), ics={f(0): 1, + f(x).diff(x).subs(x, 0): 1}) == Eq(f(x), sin(x) + cos(x)) + assert dsolve([f(x).diff(x) - f(x) + g(x), g(x).diff(x) - g(x) - f(x)], + [f(x), g(x)], ics={f(0): 1, g(0): 0}) == [Eq(f(x), exp(x)*cos(x)), Eq(g(x), exp(x)*sin(x))] + + # Test cases where dsolve returns two solutions. + eq = (x**2*f(x)**2 - x).diff(x) + assert dsolve(eq, f(x), ics={f(1): 0}) == [Eq(f(x), + -sqrt(x - 1)/x), Eq(f(x), sqrt(x - 1)/x)] + assert dsolve(eq, f(x), ics={f(x).diff(x).subs(x, 1): 0}) == [Eq(f(x), + -sqrt(x - S.Half)/x), Eq(f(x), sqrt(x - S.Half)/x)] + + eq = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) + assert dsolve(eq, f(x), + ics={f(0):1}, hint='1st_exact', simplify=False) == Eq(x*cos(f(x)) + f(x)**3/3, Rational(1, 3)) + assert dsolve(eq, f(x), + ics={f(0):1}, hint='1st_exact', simplify=True) == Eq(x*cos(f(x)) + f(x)**3/3, Rational(1, 3)) + + assert solve_ics([Eq(f(x), C1*exp(x))], [f(x)], [C1], {f(0): 1}) == {C1: 1} + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], + {f(0): 1, f(pi/2): 1}) == {C1: 1, C2: 1} + + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], + {f(0): 1, f(x).diff(x).subs(x, 0): 1}) == {C1: 1, C2: 1} + + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], {f(0): 1}) == \ + {C2: 1} + + # Some more complicated tests Refer to PR #16098 + + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0, f(x).diff(x).subs(x, 1):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6 - x / 2)} + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0})) == \ + {Eq(f(x), 0), Eq(f(x), C2*x + x**3/6)} + + K, r, f0 = symbols('K r f0') + sol = Eq(f(x), K*f0*exp(r*x)/((-K + f0)*(f0*exp(r*x)/(-K + f0) - 1))) + assert (dsolve(Eq(f(x).diff(x), r * f(x) * (1 - f(x) / K)), f(x), ics={f(0): f0})) == sol + + + #Order dependent issues Refer to PR #16098 + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(x).diff(x).subs(x,0):0, f(0):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6)} + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0, f(x).diff(x).subs(x,0):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6)} + + # XXX: Ought to be ValueError + raises(ValueError, lambda: solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], {f(0): 1, f(pi): 1})) + + # Degenerate case. f'(0) is identically 0. + raises(ValueError, lambda: solve_ics([Eq(f(x), sqrt(C1 - x**2))], [f(x)], [C1], {f(x).diff(x).subs(x, 0): 0})) + + EI, q, L = symbols('EI q L') + + # eq = Eq(EI*diff(f(x), x, 4), q) + sols = [Eq(f(x), C1 + C2*x + C3*x**2 + C4*x**3 + q*x**4/(24*EI))] + funcs = [f(x)] + constants = [C1, C2, C3, C4] + # Test both cases, Derivative (the default from f(x).diff(x).subs(x, L)), + # and Subs + ics1 = {f(0): 0, + f(x).diff(x).subs(x, 0): 0, + f(L).diff(L, 2): 0, + f(L).diff(L, 3): 0} + ics2 = {f(0): 0, + f(x).diff(x).subs(x, 0): 0, + Subs(f(x).diff(x, 2), x, L): 0, + Subs(f(x).diff(x, 3), x, L): 0} + + solved_constants1 = solve_ics(sols, funcs, constants, ics1) + solved_constants2 = solve_ics(sols, funcs, constants, ics2) + assert solved_constants1 == solved_constants2 == { + C1: 0, + C2: 0, + C3: L**2*q/(4*EI), + C4: -L*q/(6*EI)} + + # Allow the ics to refer to f + ics = {f(0): f(0)} + assert dsolve(f(x).diff(x) - f(x), f(x), ics=ics) == Eq(f(x), f(0)*exp(x)) + + ics = {f(x).diff(x).subs(x, 0): f(x).diff(x).subs(x, 0), f(0): f(0)} + assert dsolve(f(x).diff(x, x) + f(x), f(x), ics=ics) == \ + Eq(f(x), f(0)*cos(x) + f(x).diff(x).subs(x, 0)*sin(x)) + +def test_ode_order(): + f = Function('f') + g = Function('g') + x = Symbol('x') + assert ode_order(3*x*exp(f(x)), f(x)) == 0 + assert ode_order(x*diff(f(x), x) + 3*x*f(x) - sin(x)/x, f(x)) == 1 + assert ode_order(x**2*f(x).diff(x, x) + x*diff(f(x), x) - f(x), f(x)) == 2 + assert ode_order(diff(x*exp(f(x)), x, x), f(x)) == 2 + assert ode_order(diff(x*diff(x*exp(f(x)), x, x), x), f(x)) == 3 + assert ode_order(diff(f(x), x, x), g(x)) == 0 + assert ode_order(diff(f(x), x, x)*diff(g(x), x), f(x)) == 2 + assert ode_order(diff(f(x), x, x)*diff(g(x), x), g(x)) == 1 + assert ode_order(diff(x*diff(x*exp(f(x)), x, x), x), g(x)) == 0 + # issue 5835: ode_order has to also work for unevaluated derivatives + # (ie, without using doit()). + assert ode_order(Derivative(x*f(x), x), f(x)) == 1 + assert ode_order(x*sin(Derivative(x*f(x)**2, x, x)), f(x)) == 2 + assert ode_order(Derivative(x*Derivative(x*exp(f(x)), x, x), x), g(x)) == 0 + assert ode_order(Derivative(f(x), x, x), g(x)) == 0 + assert ode_order(Derivative(x*exp(f(x)), x, x), f(x)) == 2 + assert ode_order(Derivative(f(x), x, x)*Derivative(g(x), x), g(x)) == 1 + assert ode_order(Derivative(x*Derivative(f(x), x, x), x), f(x)) == 3 + assert ode_order( + x*sin(Derivative(x*Derivative(f(x), x)**2, x, x)), f(x)) == 3 + + +def test_homogeneous_order(): + assert homogeneous_order(exp(y/x) + tan(y/x), x, y) == 0 + assert homogeneous_order(x**2 + sin(x)*cos(y), x, y) is None + assert homogeneous_order(x - y - x*sin(y/x), x, y) == 1 + assert homogeneous_order((x*y + sqrt(x**4 + y**4) + x**2*(log(x) - log(y)))/ + (pi*x**Rational(2, 3)*sqrt(y)**3), x, y) == Rational(-1, 6) + assert homogeneous_order(y/x*cos(y/x) - x/y*sin(y/x) + cos(y/x), x, y) == 0 + assert homogeneous_order(f(x), x, f(x)) == 1 + assert homogeneous_order(f(x)**2, x, f(x)) == 2 + assert homogeneous_order(x*y*z, x, y) == 2 + assert homogeneous_order(x*y*z, x, y, z) == 3 + assert homogeneous_order(x**2*f(x)/sqrt(x**2 + f(x)**2), f(x)) is None + assert homogeneous_order(f(x, y)**2, x, f(x, y), y) == 2 + assert homogeneous_order(f(x, y)**2, x, f(x), y) is None + assert homogeneous_order(f(x, y)**2, x, f(x, y)) is None + assert homogeneous_order(f(y, x)**2, x, y, f(x, y)) is None + assert homogeneous_order(f(y), f(x), x) is None + assert homogeneous_order(-f(x)/x + 1/sin(f(x)/ x), f(x), x) == 0 + assert homogeneous_order(log(1/y) + log(x**2), x, y) is None + assert homogeneous_order(log(1/y) + log(x), x, y) == 0 + assert homogeneous_order(log(x/y), x, y) == 0 + assert homogeneous_order(2*log(1/y) + 2*log(x), x, y) == 0 + a = Symbol('a') + assert homogeneous_order(a*log(1/y) + a*log(x), x, y) == 0 + assert homogeneous_order(f(x).diff(x), x, y) is None + assert homogeneous_order(-f(x).diff(x) + x, x, y) is None + assert homogeneous_order(O(x), x, y) is None + assert homogeneous_order(x + O(x**2), x, y) is None + assert homogeneous_order(x**pi, x) == pi + assert homogeneous_order(x**x, x) is None + raises(ValueError, lambda: homogeneous_order(x*y)) + + +@XFAIL +def test_noncircularized_real_imaginary_parts(): + # If this passes, lines numbered 3878-3882 (at the time of this commit) + # of sympy/solvers/ode.py for nth_linear_constant_coeff_homogeneous + # should be removed. + y = sqrt(1+x) + i, r = im(y), re(y) + assert not (i.has(atan2) and r.has(atan2)) + + +def test_collect_respecting_exponentials(): + # If this test passes, lines 1306-1311 (at the time of this commit) + # of sympy/solvers/ode.py should be removed. + sol = 1 + exp(x/2) + assert sol == collect( sol, exp(x/3)) + + +def test_undetermined_coefficients_match(): + assert _undetermined_coefficients_match(g(x), x) == {'test': False} + assert _undetermined_coefficients_match(sin(2*x + sqrt(5)), x) == \ + {'test': True, 'trialset': + {cos(2*x + sqrt(5)), sin(2*x + sqrt(5))}} + assert _undetermined_coefficients_match(sin(x)*cos(x), x) == \ + {'test': False} + s = {cos(x), x*cos(x), x**2*cos(x), x**2*sin(x), x*sin(x), sin(x)} + assert _undetermined_coefficients_match(sin(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': s} + assert _undetermined_coefficients_match( + sin(x)*x**2 + sin(x)*x + sin(x), x) == {'test': True, 'trialset': s} + assert _undetermined_coefficients_match( + exp(2*x)*sin(x)*(x**2 + x + 1), x + ) == { + 'test': True, 'trialset': {exp(2*x)*sin(x), x**2*exp(2*x)*sin(x), + cos(x)*exp(2*x), x**2*cos(x)*exp(2*x), x*cos(x)*exp(2*x), + x*exp(2*x)*sin(x)}} + assert _undetermined_coefficients_match(1/sin(x), x) == {'test': False} + assert _undetermined_coefficients_match(log(x), x) == {'test': False} + assert _undetermined_coefficients_match(2**(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': {2**x, x*2**x, x**2*2**x}} + assert _undetermined_coefficients_match(x**y, x) == {'test': False} + assert _undetermined_coefficients_match(exp(x)*exp(2*x + 1), x) == \ + {'test': True, 'trialset': {exp(1 + 3*x)}} + assert _undetermined_coefficients_match(sin(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': {x*cos(x), x*sin(x), x**2*cos(x), + x**2*sin(x), cos(x), sin(x)}} + assert _undetermined_coefficients_match(sin(x)*(x + sin(x)), x) == \ + {'test': False} + assert _undetermined_coefficients_match(sin(x)*(x + sin(2*x)), x) == \ + {'test': False} + assert _undetermined_coefficients_match(sin(x)*tan(x), x) == \ + {'test': False} + assert _undetermined_coefficients_match( + x**2*sin(x)*exp(x) + x*sin(x) + x, x + ) == { + 'test': True, 'trialset': {x**2*cos(x)*exp(x), x, cos(x), S.One, + exp(x)*sin(x), sin(x), x*exp(x)*sin(x), x*cos(x), x*cos(x)*exp(x), + x*sin(x), cos(x)*exp(x), x**2*exp(x)*sin(x)}} + assert _undetermined_coefficients_match(4*x*sin(x - 2), x) == { + 'trialset': {x*cos(x - 2), x*sin(x - 2), cos(x - 2), sin(x - 2)}, + 'test': True, + } + assert _undetermined_coefficients_match(2**x*x, x) == \ + {'test': True, 'trialset': {2**x, x*2**x}} + assert _undetermined_coefficients_match(2**x*exp(2*x), x) == \ + {'test': True, 'trialset': {2**x*exp(2*x)}} + assert _undetermined_coefficients_match(exp(-x)/x, x) == \ + {'test': False} + # Below are from Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 231 + assert _undetermined_coefficients_match(S(4), x) == \ + {'test': True, 'trialset': {S.One}} + assert _undetermined_coefficients_match(12*exp(x), x) == \ + {'test': True, 'trialset': {exp(x)}} + assert _undetermined_coefficients_match(exp(I*x), x) == \ + {'test': True, 'trialset': {exp(I*x)}} + assert _undetermined_coefficients_match(sin(x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x)}} + assert _undetermined_coefficients_match(cos(x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x)}} + assert _undetermined_coefficients_match(8 + 6*exp(x) + 2*sin(x), x) == \ + {'test': True, 'trialset': {S.One, cos(x), sin(x), exp(x)}} + assert _undetermined_coefficients_match(x**2, x) == \ + {'test': True, 'trialset': {S.One, x, x**2}} + assert _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(x), exp(x), exp(-x)}} + assert _undetermined_coefficients_match(2*exp(2*x)*sin(x), x) == \ + {'test': True, 'trialset': {exp(2*x)*sin(x), cos(x)*exp(2*x)}} + assert _undetermined_coefficients_match(x - sin(x), x) == \ + {'test': True, 'trialset': {S.One, x, cos(x), sin(x)}} + assert _undetermined_coefficients_match(x**2 + 2*x, x) == \ + {'test': True, 'trialset': {S.One, x, x**2}} + assert _undetermined_coefficients_match(4*x*sin(x), x) == \ + {'test': True, 'trialset': {x*cos(x), x*sin(x), cos(x), sin(x)}} + assert _undetermined_coefficients_match(x*sin(2*x), x) == \ + {'test': True, 'trialset': + {x*cos(2*x), x*sin(2*x), cos(2*x), sin(2*x)}} + assert _undetermined_coefficients_match(x**2*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), x**2*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(2*exp(-x) - x**2*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), x**2*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(exp(-2*x) + x**2, x) == \ + {'test': True, 'trialset': {S.One, x, x**2, exp(-2*x)}} + assert _undetermined_coefficients_match(x*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(x + exp(2*x), x) == \ + {'test': True, 'trialset': {S.One, x, exp(2*x)}} + assert _undetermined_coefficients_match(sin(x) + exp(-x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x), exp(-x)}} + assert _undetermined_coefficients_match(exp(x), x) == \ + {'test': True, 'trialset': {exp(x)}} + # converted from sin(x)**2 + assert _undetermined_coefficients_match(S.Half - cos(2*x)/2, x) == \ + {'test': True, 'trialset': {S.One, cos(2*x), sin(2*x)}} + # converted from exp(2*x)*sin(x)**2 + assert _undetermined_coefficients_match( + exp(2*x)*(S.Half + cos(2*x)/2), x + ) == { + 'test': True, 'trialset': {exp(2*x)*sin(2*x), cos(2*x)*exp(2*x), + exp(2*x)}} + assert _undetermined_coefficients_match(2*x + sin(x) + cos(x), x) == \ + {'test': True, 'trialset': {S.One, x, cos(x), sin(x)}} + # converted from sin(2*x)*sin(x) + assert _undetermined_coefficients_match(cos(x)/2 - cos(3*x)/2, x) == \ + {'test': True, 'trialset': {cos(x), cos(3*x), sin(x), sin(3*x)}} + assert _undetermined_coefficients_match(cos(x**2), x) == {'test': False} + assert _undetermined_coefficients_match(2**(x**2), x) == {'test': False} + + +def test_issue_4785_22462(): + from sympy.abc import A + eq = x + A*(x + diff(f(x), x) + f(x)) + diff(f(x), x) + f(x) + 2 + assert classify_ode(eq, f(x)) == ('factorable', '1st_exact', '1st_linear', + 'Bernoulli', 'almost_linear', '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_exact_Integral', '1st_linear_Integral', 'Bernoulli_Integral', + 'almost_linear_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + # issue 4864 + eq = (x**2 + f(x)**2)*f(x).diff(x) - 2*x*f(x) + assert classify_ode(eq, f(x)) == ('factorable', '1st_exact', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', + 'lie_group', '1st_exact_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + + +def test_issue_4825(): + raises(ValueError, lambda: dsolve(f(x, y).diff(x) - y*f(x, y), f(x))) + assert classify_ode(f(x, y).diff(x) - y*f(x, y), f(x), dict=True) == \ + {'order': 0, 'default': None, 'ordered_hints': ()} + # See also issue 3793, test Z13. + raises(ValueError, lambda: dsolve(f(x).diff(x), f(y))) + assert classify_ode(f(x).diff(x), f(y), dict=True) == \ + {'order': 0, 'default': None, 'ordered_hints': ()} + + +def test_constant_renumber_order_issue_5308(): + from sympy.utilities.iterables import variations + + assert constant_renumber(C1*x + C2*y) == \ + constant_renumber(C1*y + C2*x) == \ + C1*x + C2*y + e = C1*(C2 + x)*(C3 + y) + for a, b, c in variations([C1, C2, C3], 3): + assert constant_renumber(a*(b + x)*(c + y)) == e + + +def test_constant_renumber(): + e1, e2, x, y = symbols("e1:3 x y") + exprs = [e2*x, e1*x + e2*y] + + assert constant_renumber(exprs[0]) == e2*x + assert constant_renumber(exprs[0], variables=[x]) == C1*x + assert constant_renumber(exprs[0], variables=[x], newconstants=[C2]) == C2*x + assert constant_renumber(exprs, variables=[x, y]) == [C1*x, C1*y + C2*x] + assert constant_renumber(exprs, variables=[x, y], newconstants=symbols("C3:5")) == [C3*x, C3*y + C4*x] + + +def test_issue_5770(): + k = Symbol("k", real=True) + t = Symbol('t') + w = Function('w') + sol = dsolve(w(t).diff(t, 6) - k**6*w(t), w(t)) + assert len([s for s in sol.free_symbols if s.name.startswith('C')]) == 6 + assert constantsimp((C1*cos(x) + C2*cos(x))*exp(x), {C1, C2}) == \ + C1*cos(x)*exp(x) + assert constantsimp(C1*cos(x) + C2*cos(x) + C3*sin(x), {C1, C2, C3}) == \ + C1*cos(x) + C3*sin(x) + assert constantsimp(exp(C1 + x), {C1}) == C1*exp(x) + assert constantsimp(x + C1 + y, {C1, y}) == C1 + x + assert constantsimp(x + C1 + Integral(x, (x, 1, 2)), {C1}) == C1 + x + + +def test_issue_5112_5430(): + assert homogeneous_order(-log(x) + acosh(x), x) is None + assert homogeneous_order(y - log(x), x, y) is None + + +def test_issue_5095(): + f = Function('f') + raises(ValueError, lambda: dsolve(f(x).diff(x)**2, f(x), 'fdsjf')) + + +def test_homogeneous_function(): + f = Function('f') + eq1 = tan(x + f(x)) + eq2 = sin((3*x)/(4*f(x))) + eq3 = cos(x*f(x)*Rational(3, 4)) + eq4 = log((3*x + 4*f(x))/(5*f(x) + 7*x)) + eq5 = exp((2*x**2)/(3*f(x)**2)) + eq6 = log((3*x + 4*f(x))/(5*f(x) + 7*x) + exp((2*x**2)/(3*f(x)**2))) + eq7 = sin((3*x)/(5*f(x) + x**2)) + assert homogeneous_order(eq1, x, f(x)) == None + assert homogeneous_order(eq2, x, f(x)) == 0 + assert homogeneous_order(eq3, x, f(x)) == None + assert homogeneous_order(eq4, x, f(x)) == 0 + assert homogeneous_order(eq5, x, f(x)) == 0 + assert homogeneous_order(eq6, x, f(x)) == 0 + assert homogeneous_order(eq7, x, f(x)) == None + + +def test_linear_coeff_match(): + n, d = z*(2*x + 3*f(x) + 5), z*(7*x + 9*f(x) + 11) + rat = n/d + eq1 = sin(rat) + cos(rat.expand()) + obj1 = LinearCoefficients(eq1) + eq2 = rat + obj2 = LinearCoefficients(eq2) + eq3 = log(sin(rat)) + obj3 = LinearCoefficients(eq3) + ans = (4, Rational(-13, 3)) + assert obj1._linear_coeff_match(eq1, f(x)) == ans + assert obj2._linear_coeff_match(eq2, f(x)) == ans + assert obj3._linear_coeff_match(eq3, f(x)) == ans + + # no c + eq4 = (3*x)/f(x) + obj4 = LinearCoefficients(eq4) + # not x and f(x) + eq5 = (3*x + 2)/x + obj5 = LinearCoefficients(eq5) + # denom will be zero + eq6 = (3*x + 2*f(x) + 1)/(3*x + 2*f(x) + 5) + obj6 = LinearCoefficients(eq6) + # not rational coefficient + eq7 = (3*x + 2*f(x) + sqrt(2))/(3*x + 2*f(x) + 5) + obj7 = LinearCoefficients(eq7) + assert obj4._linear_coeff_match(eq4, f(x)) is None + assert obj5._linear_coeff_match(eq5, f(x)) is None + assert obj6._linear_coeff_match(eq6, f(x)) is None + assert obj7._linear_coeff_match(eq7, f(x)) is None + + +def test_constantsimp_take_problem(): + c = exp(C1) + 2 + assert len(Poly(constantsimp(exp(C1) + c + c*x, [C1])).gens) == 2 + + +def test_series(): + C1 = Symbol("C1") + eq = f(x).diff(x) - f(x) + sol = Eq(f(x), C1 + C1*x + C1*x**2/2 + C1*x**3/6 + C1*x**4/24 + + C1*x**5/120 + O(x**6)) + assert dsolve(eq, hint='1st_power_series') == sol + assert checkodesol(eq, sol, order=1)[0] + + eq = f(x).diff(x) - x*f(x) + sol = Eq(f(x), C1*x**4/8 + C1*x**2/2 + C1 + O(x**6)) + assert dsolve(eq, hint='1st_power_series') == sol + assert checkodesol(eq, sol, order=1)[0] + + eq = f(x).diff(x) - sin(x*f(x)) + sol = Eq(f(x), (x - 2)**2*(1+ sin(4))*cos(4) + (x - 2)*sin(4) + 2 + O(x**3)) + assert dsolve(eq, hint='1st_power_series', ics={f(2): 2}, n=3) == sol + # FIXME: The solution here should be O((x-2)**3) so is incorrect + #assert checkodesol(eq, sol, order=1)[0] + + +@slow +def test_2nd_power_series_ordinary(): + C1, C2 = symbols("C1 C2") + + eq = f(x).diff(x, 2) - x*f(x) + assert classify_ode(eq) == ('2nd_linear_airy', '2nd_power_series_ordinary') + sol = Eq(f(x), C2*(x**3/6 + 1) + C1*x*(x**3/12 + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary') == sol + assert checkodesol(eq, sol) == (True, 0) + + sol = Eq(f(x), C2*((x + 2)**4/6 + (x + 2)**3/6 - (x + 2)**2 + 1) + + C1*(x + (x + 2)**4/12 - (x + 2)**3/3 + S(2)) + + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary', x0=-2) == sol + # FIXME: Solution should be O((x+2)**6) + # assert checkodesol(eq, sol) == (True, 0) + + sol = Eq(f(x), C2*x + C1 + O(x**2)) + assert dsolve(eq, hint='2nd_power_series_ordinary', n=2) == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = (1 + x**2)*(f(x).diff(x, 2)) + 2*x*(f(x).diff(x)) -2*f(x) + assert classify_ode(eq) == ('factorable', '2nd_hypergeometric', '2nd_hypergeometric_Integral', + '2nd_power_series_ordinary') + + sol = Eq(f(x), C2*(-x**4/3 + x**2 + 1) + C1*x + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + x*(f(x).diff(x)) + f(x) + assert classify_ode(eq) == ('factorable', '2nd_power_series_ordinary',) + sol = Eq(f(x), C2*(x**4/8 - x**2/2 + 1) + C1*x*(-x**2/3 + 1) + O(x**6)) + assert dsolve(eq) == sol + # FIXME: checkodesol fails for this solution... + # assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + f(x).diff(x) - x*f(x) + assert classify_ode(eq) == ('2nd_power_series_ordinary',) + sol = Eq(f(x), C2*(-x**4/24 + x**3/6 + 1) + + C1*x*(x**3/24 + x**2/6 - x/2 + 1) + O(x**6)) + assert dsolve(eq) == sol + # FIXME: checkodesol fails for this solution... + # assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + x*f(x) + assert classify_ode(eq) == ('2nd_linear_airy', '2nd_power_series_ordinary') + sol = Eq(f(x), C2*(x**6/180 - x**3/6 + 1) + C1*x*(-x**3/12 + 1) + O(x**7)) + assert dsolve(eq, hint='2nd_power_series_ordinary', n=7) == sol + assert checkodesol(eq, sol) == (True, 0) + + +def test_2nd_power_series_regular(): + C1, C2, a = symbols("C1 C2 a") + eq = x**2*(f(x).diff(x, 2)) - 3*x*(f(x).diff(x)) + (4*x + 4)*f(x) + sol = Eq(f(x), C1*x**2*(-16*x**3/9 + 4*x**2 - 4*x + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = 4*x**2*(f(x).diff(x, 2)) -8*x**2*(f(x).diff(x)) + (4*x**2 + + 1)*f(x) + sol = Eq(f(x), C1*sqrt(x)*(x**4/24 + x**3/6 + x**2/2 + x + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x**2*(f(x).diff(x, 2)) - x**2*(f(x).diff(x)) + ( + x**2 - 2)*f(x) + sol = Eq(f(x), C1*(-x**6/720 - 3*x**5/80 - x**4/8 + x**2/2 + x/2 + 1)/x + + C2*x**2*(-x**3/60 + x**2/20 + x/2 + 1) + O(x**6)) + assert dsolve(eq) == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - Rational(1, 4))*f(x) + sol = Eq(f(x), C1*(x**4/24 - x**2/2 + 1)/sqrt(x) + + C2*sqrt(x)*(x**4/120 - x**2/6 + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x*f(x).diff(x, 2) + f(x).diff(x) - a*x*f(x) + sol = Eq(f(x), C1*(a**2*x**4/64 + a*x**2/4 + 1) + O(x**6)) + assert dsolve(eq, f(x), hint="2nd_power_series_regular") == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + ((1 - x)/x)*f(x).diff(x) + (a/x)*f(x) + sol = Eq(f(x), C1*(-a*x**5*(a - 4)*(a - 3)*(a - 2)*(a - 1)/14400 + \ + a*x**4*(a - 3)*(a - 2)*(a - 1)/576 - a*x**3*(a - 2)*(a - 1)/36 + \ + a*x**2*(a - 1)/4 - a*x + 1) + O(x**6)) + assert dsolve(eq, f(x), hint="2nd_power_series_regular") == sol + assert checkodesol(eq, sol) == (True, 0) + + +def test_issue_15056(): + t = Symbol('t') + C3 = Symbol('C3') + assert get_numbered_constants(Symbol('C1') * Function('C2')(t)) == C3 + + +def test_issue_15913(): + eq = -C1/x - 2*x*f(x) - f(x) + Derivative(f(x), x) + sol = C2*exp(x**2 + x) + exp(x**2 + x)*Integral(C1*exp(-x**2 - x)/x, x) + assert checkodesol(eq, sol) == (True, 0) + sol = C1 + C2*exp(-x*y) + eq = Derivative(y*f(x), x) + f(x).diff(x, 2) + assert checkodesol(eq, sol, f(x)) == (True, 0) + + +def test_issue_16146(): + raises(ValueError, lambda: dsolve([f(x).diff(x), g(x).diff(x)], [f(x), g(x), h(x)])) + raises(ValueError, lambda: dsolve([f(x).diff(x), g(x).diff(x)], [f(x)])) + + +def test_dsolve_remove_redundant_solutions(): + + eq = (f(x)-2)*f(x).diff(x) + sol = Eq(f(x), C1) + assert dsolve(eq) == sol + + eq = (f(x)-sin(x))*(f(x).diff(x, 2)) + sol = {Eq(f(x), C1 + C2*x), Eq(f(x), sin(x))} + assert set(dsolve(eq)) == sol + + eq = (f(x)**2-2*f(x)+1)*f(x).diff(x, 3) + sol = Eq(f(x), C1 + C2*x + C3*x**2) + assert dsolve(eq) == sol + + +def test_issue_13060(): + A, B = symbols("A B", cls=Function) + t = Symbol("t") + eq = [Eq(Derivative(A(t), t), A(t)*B(t)), Eq(Derivative(B(t), t), A(t)*B(t))] + sol = dsolve(eq) + assert checkodesol(eq, sol) == (True, [0, 0]) + + +def test_issue_22523(): + N, s = symbols('N s') + rho = Function('rho') + # intentionally use 4.0 to confirm issue with nfloat + # works here + eqn = 4.0*N*sqrt(N - 1)*rho(s) + (4*s**2*(N - 1) + (N - 2*s*(N - 1))**2 + )*Derivative(rho(s), (s, 2)) + match = classify_ode(eqn, dict=True, hint='all') + assert match['2nd_power_series_ordinary']['terms'] == 5 + C1, C2 = symbols('C1,C2') + sol = dsolve(eqn, hint='2nd_power_series_ordinary') + # there is no r(2.0) in this result + assert filldedent(sol) == filldedent(str(''' + Eq(rho(s), C2*(1 - 4.0*s**4*sqrt(N - 1.0)/N + 0.666666666666667*s**4/N + - 2.66666666666667*s**3*sqrt(N - 1.0)/N - 2.0*s**2*sqrt(N - 1.0)/N + + 9.33333333333333*s**4*sqrt(N - 1.0)/N**2 - 0.666666666666667*s**4/N**2 + + 2.66666666666667*s**3*sqrt(N - 1.0)/N**2 - + 5.33333333333333*s**4*sqrt(N - 1.0)/N**3) + C1*s*(1.0 - + 1.33333333333333*s**3*sqrt(N - 1.0)/N - 0.666666666666667*s**2*sqrt(N + - 1.0)/N + 1.33333333333333*s**3*sqrt(N - 1.0)/N**2) + O(s**6))''')) + + +def test_issue_22604(): + x1, x2 = symbols('x1, x2', cls = Function) + t, k1, k2, m1, m2 = symbols('t k1 k2 m1 m2', real = True) + k1, k2, m1, m2 = 1, 1, 1, 1 + eq1 = Eq(m1*diff(x1(t), t, 2) + k1*x1(t) - k2*(x2(t) - x1(t)), 0) + eq2 = Eq(m2*diff(x2(t), t, 2) + k2*(x2(t) - x1(t)), 0) + eqs = [eq1, eq2] + [x1sol, x2sol] = dsolve(eqs, [x1(t), x2(t)], ics = {x1(0):0, x1(t).diff().subs(t,0):0, \ + x2(0):1, x2(t).diff().subs(t,0):0}) + assert x1sol == Eq(x1(t), sqrt(3 - sqrt(5))*(sqrt(10) + 5*sqrt(2))*cos(sqrt(2)*t*sqrt(3 - sqrt(5))/2)/20 + \ + (-5*sqrt(2) + sqrt(10))*sqrt(sqrt(5) + 3)*cos(sqrt(2)*t*sqrt(sqrt(5) + 3)/2)/20) + assert x2sol == Eq(x2(t), (sqrt(5) + 5)*cos(sqrt(2)*t*sqrt(3 - sqrt(5))/2)/10 + (5 - sqrt(5))*cos(sqrt(2)*t*sqrt(sqrt(5) + 3)/2)/10) + + +def test_issue_22462(): + for de in [ + Eq(f(x).diff(x), -20*f(x)**2 - 500*f(x)/7200), + Eq(f(x).diff(x), -2*f(x)**2 - 5*f(x)/7)]: + assert 'Bernoulli' in classify_ode(de, f(x)) + + +def test_issue_23425(): + x = symbols('x') + y = Function('y') + eq = Eq(-E**x*y(x).diff().diff() + y(x).diff(), 0) + assert classify_ode(eq) == \ + ('Liouville', 'nth_order_reducible', \ + '2nd_power_series_ordinary', 'Liouville_Integral') + + +@SKIP("too slow for @slow") +def test_issue_25820(): + x = Symbol('x') + y = Function('y') + eq = y(x)**3*y(x).diff(x, 2) + 49 + assert dsolve(eq, y(x)) is not None # doesn't raise diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_riccati.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_riccati.py new file mode 100644 index 0000000000000000000000000000000000000000..548a1ee5b5e82d88f1b0aa319af09b8b9d1d9bfe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_riccati.py @@ -0,0 +1,877 @@ +from sympy.core.random import randint +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import Poly +from sympy.simplify.ratsimp import ratsimp +from sympy.solvers.ode.subscheck import checkodesol +from sympy.testing.pytest import slow +from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal, + riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf, + check_necessary_conds, val_at_inf, construct_c_case_1, + construct_c_case_2, construct_c_case_3, construct_d_case_4, + construct_d_case_5, construct_d_case_6, rational_laurent_series, + solve_riccati) + +f = Function('f') +x = symbols('x') + +# These are the functions used to generate the tests +# SHOULD NOT BE USED DIRECTLY IN TESTS + +def rand_rational(maxint): + return Rational(randint(-maxint, maxint), randint(1, maxint)) + + +def rand_poly(x, degree, maxint): + return Poly([rand_rational(maxint) for _ in range(degree+1)], x) + + +def rand_rational_function(x, degree, maxint): + degnum = randint(1, degree) + degden = randint(1, degree) + num = rand_poly(x, degnum, maxint) + den = rand_poly(x, degden, maxint) + while den == Poly(0, x): + den = rand_poly(x, degden, maxint) + return num / den + + +def find_riccati_ode(ratfunc, x, yf): + y = ratfunc + yp = y.diff(x) + q1 = rand_rational_function(x, 1, 3) + q2 = rand_rational_function(x, 1, 3) + while q2 == 0: + q2 = rand_rational_function(x, 1, 3) + q0 = ratsimp(yp - q1*y - q2*y**2) + eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2) + sol = Eq(yf, y) + assert checkodesol(eq, sol) == (True, 0) + return eq, q0, q1, q2 + + +# Testing functions start + +def test_riccati_transformation(): + """ + This function tests the transformation of the + solution of a Riccati ODE to the solution of + its corresponding normal Riccati ODE. + + Each test case 4 values - + + 1. w - The solution to be transformed + 2. b1 - The coefficient of f(x) in the ODE. + 3. b2 - The coefficient of f(x)**2 in the ODE. + 4. y - The solution to the normal Riccati ODE. + """ + tests = [ + ( + x/(x - 1), + (x**2 + 7)/3*x, + x, + -x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x) + ), + ( + (2*x + 3)/(2*x + 2), + (3 - 3*x)/(x + 1), + 5*x, + -5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x) + ), + ( + -1/(2*x**2 - 1), + 0, + (2 - x)/(4*x - 2), + (2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \ + 2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False)) + ), + ( + x, + (8*x - 12)/(12*x + 9), + x**3/(6*x - 9), + -x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \ + - 9)**2 + 3*x**2/(6*x - 9))/(2*x**3) + )] + for w, b1, b2, y in tests: + assert y == riccati_normal(w, x, b1, b2) + assert w == riccati_inverse_normal(y, x, b1, b2).cancel() + + # Test bp parameter in riccati_inverse_normal + tests = [ + ( + (-2*x - 1)/(2*x**2 + 2*x - 2), + -2/x, + (-x - 1)/(4*x), + 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1), + -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \ + - 2)) + 1/x + ), + ( + 3/(2*x**2), + -2/x, + (-x - 1)/(4*x), + 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1), + -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3) + )] + for w, b1, b2, bp, y in tests: + assert y == riccati_normal(w, x, b1, b2) + assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel() + + +def test_riccati_reduced(): + """ + This function tests the transformation of a + Riccati ODE to its normal Riccati ODE. + + Each test case 2 values - + + 1. eq - A Riccati ODE. + 2. normal_eq - The normal Riccati ODE of eq. + """ + tests = [ + ( + f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2, + + f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2) + ), + ( + 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x, + + -3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \ + -x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \ + - (-1 + (x + 1)/x)/(x*(-x - 1)) + ), + ( + f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2), + + -(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \ + f(x)**2 + f(x).diff(x) - 1/(2*x + 1) + ), + ( + f(x).diff(x) - f(x)**2/x, + + f(x)**2 + f(x).diff(x) + 1/(4*x**2) + ), + ( + -3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x, + + f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \ + + 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2) + ), + ( + 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x, + + False + ), + ( + f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2), + + False + )] + for eq, normal_eq in tests: + assert normal_eq == riccati_reduced(eq, f, x) + + +def test_match_riccati(): + """ + This function tests if an ODE is Riccati or not. + + Each test case has 5 values - + + 1. eq - The Riccati ODE. + 2. match - Boolean indicating if eq is a Riccati ODE. + 3. b0 - + 4. b1 - Coefficient of f(x) in eq. + 5. b2 - Coefficient of f(x)**2 in eq. + """ + tests = [ + # Test Rational Riccati ODEs + ( + f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \ + - 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \ + - (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2), + + True, + + 45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \ + (27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \ + 315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \ + + 846*x**2 + 180*x - 72), + + Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2), + + 1/(3*x + 1) + ), + ( + f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \ + 1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \ + /(324*x**3 + 216*x**2), + + True, + + -4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \ + + 24*x) + 265/(324*x + 216), + + Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6), + + x/3 - 1 + ), + ( + f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \ + + 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \ + 360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \ + S(3)/2) - (x/3 - 3)*f(x)**2/(3*x), + + True, + + 304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \ + 108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \ + 360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \ + x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \ + 189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \ + 53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \ + + 53*x**3 - 63*x**2 + 40*x - 12), + + Mul(-1, 3 - 2*x, evaluate=False)/(x - 3), + + Mul(-1, 9 - x, evaluate=False)/(9*x) + ), + # Test Non-Rational Riccati ODEs + ( + f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \ + x*f(x)**2/(x**(S(3)/4)), + False, 0, 0, 0 + ), + ( + f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2, + False, 0, 0, 0 + ), + ( + f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2, + False, 0, 0, 0 + ), + # Test Non-Riccati ODEs + ( + (1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x), + False, 0, 0, 0 + ), + ( + f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3, + False, 0, 0, 0 + ), + ( + f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \ + + 3) + f(x)**2, + False, 0, 0, 0 + )] + for eq, res, b0, b1, b2 in tests: + match, funcs = match_riccati(eq, f, x) + assert match == res + if res: + assert [b0, b1, b2] == funcs + + +def test_val_at_inf(): + """ + This function tests the valuation of rational + function at oo. + + Each test case has 3 values - + + 1. num - Numerator of rational function. + 2. den - Denominator of rational function. + 3. val_inf - Valuation of rational function at oo + """ + tests = [ + # degree(denom) > degree(numer) + ( + Poly(10*x**3 + 8*x**2 - 13*x + 6, x), + Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x), + 7 + ), + ( + Poly(1, x), + Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x), + 4 + ), + # degree(denom) == degree(numer) + ( + Poly(-6*x**3 - 8*x**2 + 8*x - 6, x), + Poly(-5*x**3 + 12*x**2 - 6*x - 9, x), + 0 + ), + # degree(denom) < degree(numer) + ( + Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x), + Poly(-14*x**2 + x, x), + -6 + ), + ( + Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x), + Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x), + -2 + )] + for num, den, val in tests: + assert val_at_inf(num, den, x) == val + + +def test_necessary_conds(): + """ + This function tests the necessary conditions for + a Riccati ODE to have a rational particular solution. + """ + # Valuation at Infinity is an odd negative integer + assert check_necessary_conds(-3, [1, 2, 4]) == False + # Valuation at Infinity is a positive integer lesser than 2 + assert check_necessary_conds(1, [1, 2, 4]) == False + # Multiplicity of a pole is an odd integer greater than 1 + assert check_necessary_conds(2, [3, 1, 6]) == False + # All values are correct + assert check_necessary_conds(-10, [1, 2, 8, 12]) == True + + +def test_inverse_transform_poly(): + """ + This function tests the substitution x -> 1/x + in rational functions represented using Poly. + """ + fns = [ + (15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6), + + (180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12), + + (-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80), + + (60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30), + + (30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \ + + 15*x**2 + 45*x - 15) + ] + for f in fns: + num, den = [Poly(e, x) for e in f.as_numer_denom()] + num, den = inverse_transform_poly(num, den, x) + assert f.subs(x, 1/x).cancel() == num/den + + +def test_limit_at_inf(): + """ + This function tests the limit at oo of a + rational function. + + Each test case has 3 values - + + 1. num - Numerator of rational function. + 2. den - Denominator of rational function. + 3. limit_at_inf - Limit of rational function at oo + """ + tests = [ + # deg(denom) > deg(numer) + ( + Poly(-12*x**2 + 20*x + 32, x), + Poly(32*x**3 + 72*x**2 + 3*x - 32, x), + 0 + ), + # deg(denom) < deg(numer) + ( + Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x), + Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x), + oo + ), + # deg(denom) < deg(numer), one of the leading coefficients is negative + ( + Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \ + + 630*x + 3780, x), + Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x), + -oo + ), + # deg(denom) == deg(numer) + ( + Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x), + Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x), + S(1)/7 + ), + ( + Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x), + Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x), + S(288)/607 + )] + for num, den, lim in tests: + assert limit_at_inf(num, den, x) == lim + + +def test_construct_c_case_1(): + """ + This function tests the Case 1 in the step + to calculate coefficients of c-vectors. + + Each test case has 4 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. pole - Pole of a(x) for which c-vector is being + calculated. + 4. c - The c-vector for the pole. + """ + tests = [ + ( + Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True), + Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True), + S(0), + [[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]] + ), + ( + Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True), + Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True), + S(7)/4, + [[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]] + ), + ( + Poly(4*x + 2, x, extension=True), + Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \ + + 8*sqrt(3) + 16, x, domain='QQ'), + (S(1) + sqrt(3))/2, + [[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \ + [S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]] + )] + for num, den, pole, c in tests: + assert construct_c_case_1(num, den, x, pole) == c + + +def test_construct_c_case_2(): + """ + This function tests the Case 2 in the step + to calculate coefficients of c-vectors. + + Each test case has 5 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. pole - Pole of a(x) for which c-vector is being + calculated. + 4. mul - The multiplicity of the pole. + 5. c - The c-vector for the pole. + """ + tests = [ + # Testing poles with multiplicity 2 + ( + Poly(1, x, extension=True), + Poly((x - 1)**2*(x - 2), x, extension=True), + 1, 2, + [[-I*(-1 - I)/2], [I*(-1 + I)/2]] + ), + ( + Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True), + Poly((3*x - 1)**2*(x + 2)**2, x, extension=True), + S(1)/3, 2, + [[-S(89)/98], [-S(9)/98]] + ), + # Testing poles with multiplicity 4 + ( + Poly(x**3 - x**2 + 4*x, x, extension=True), + Poly((x - 2)**4*(x + 5)**2, x, extension=True), + 2, 4, + [[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \ + [-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]] + ), + ( + Poly(3*x**5 + x**4 + 3, x, extension=True), + Poly((4*x + 1)**4*(x + 2), x, extension=True), + -S(1)/4, 4, + [[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \ + [-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]] + ), + # Testing poles with multiplicity 6 + ( + Poly(x**3 + 2, x, extension=True), + Poly((3*x - 1)**6*(x**2 + 1), x, extension=True), + S(1)/3, 6, + [[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \ + [-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]] + ), + ( + Poly(x**2 + 12, x, extension=True), + Poly((x - sqrt(2))**6, x, extension=True), + sqrt(2), 6, + [[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \ + [-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]] + )] + for num, den, pole, mul, c in tests: + assert construct_c_case_2(num, den, x, pole, mul) == c + + +def test_construct_c_case_3(): + """ + This function tests the Case 3 in the step + to calculate coefficients of c-vectors. + """ + assert construct_c_case_3() == [[1]] + + +def test_construct_d_case_4(): + """ + This function tests the Case 4 in the step + to calculate coefficients of the d-vector. + + Each test case has 4 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. mul - Multiplicity of oo as a pole. + 4. d - The d-vector. + """ + tests = [ + # Tests with multiplicity at oo = 2 + ( + Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True), + Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True), + 2, + [[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \ + [-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]] + ), + ( + Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True), + Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True), + 2, + [[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]] + ), + # Tests with multiplicity at oo = 4 + ( + Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True), + Poly(3*x**2 + 10*x + 7, x, extension=True), + 4, + [[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \ + - 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \ + sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]] + ), + ( + Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True), + Poly(12*x - 2, x, extension=True), + 4, + [[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \ + [-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]] + ), + # Tests with multiplicity at oo = 4 + ( + Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True), + Poly(x + 2, x, extension=True), + 6, + [[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]] + ), + ( + Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True), + Poly(2*x - 2, x, extension=True), + 6, + [[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \ + 3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \ + sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]] + )] + for num, den, mul, d in tests: + ser = rational_laurent_series(num, den, x, oo, mul, 1) + assert construct_d_case_4(ser, mul//2) == d + + +def test_construct_d_case_5(): + """ + This function tests the Case 5 in the step + to calculate coefficients of the d-vector. + + Each test case has 3 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. d - The d-vector. + """ + tests = [ + ( + Poly(2*x**3 + x**2 + x - 2, x, extension=True), + Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True), + [[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]] + ), + ( + Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'), + Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'), + [[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]] + ), + ( + Poly(x**2 - x + 1, x, domain='ZZ'), + Poly(3*x**2 + 7*x + 3, x, domain='ZZ'), + [[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]] + )] + for num, den, d in tests: + # Multiplicity of oo is 0 + ser = rational_laurent_series(num, den, x, oo, 0, 1) + assert construct_d_case_5(ser) == d + + +def test_construct_d_case_6(): + """ + This function tests the Case 6 in the step + to calculate coefficients of the d-vector. + + Each test case has 3 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. d - The d-vector. + """ + tests = [ + ( + Poly(-2*x**2 - 5, x, domain='ZZ'), + Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'), + [[S(1)/2 + I/2], [S(1)/2 - I/2]] + ), + ( + Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'), + Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'), + [[1], [0]] + ), + ( + Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'), + Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \ + - 27*x - 42, x, domain='ZZ'), + [[1], [0]] + )] + for num, den, d in tests: + assert construct_d_case_6(num, den, x) == d + + +def test_rational_laurent_series(): + """ + This function tests the computation of coefficients + of Laurent series of a rational function. + + Each test case has 5 values - + + 1. num - Numerator of the rational function. + 2. den - Denominator of the rational function. + 3. x0 - Point about which Laurent series is to + be calculated. + 4. mul - Multiplicity of x0 if x0 is a pole of + the rational function (0 otherwise). + 5. n - Number of terms upto which the series + is to be calculated. + """ + tests = [ + # Laurent series about simple pole (Multiplicity = 1) + ( + Poly(x**2 - 3*x + 9, x, extension=True), + Poly(x**2 - x, x, extension=True), + S(1), 1, 6, + {1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9} + ), + # Laurent series about multiple pole (Multiplicity > 1) + ( + Poly(64*x**3 - 1728*x + 1216, x, extension=True), + Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True), + S(9)/8, 2, 3, + {0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \ + 1: S(209149)/75645} + ), + ( + Poly(1, x, extension=True), + Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \ + + (4 + 8*sqrt(2))*x - 4, x, extension=True), + sqrt(2), 4, 6, + {4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \ + + sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \ + )/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4} + ), + # Laurent series about oo + ( + Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True), + Poly(x**2 - 5, x, extension=True), + oo, 3, 6, + {3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17} + ), + # Laurent series at x0 where x0 is not a pole of the function + # Using multiplicity as 0 (as x0 will not be a pole) + ( + Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True), + Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True), + S(2)/5, 0, 1, + {0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \ + -3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016} + ), + ( + Poly(-7*x**2 + 2*x - 4, x, extension=True), + Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True), + oo, 0, 6, + {0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7} + )] + for num, den, x0, mul, n, ser in tests: + assert ser == rational_laurent_series(num, den, x, x0, mul, n) + + +def check_dummy_sol(eq, solse, dummy_sym): + """ + Helper function to check if actual solution + matches expected solution if actual solution + contains dummy symbols. + """ + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + _, funcs = match_riccati(eq, f, x) + + sols = solve_riccati(f(x), x, *funcs) + C1 = Dummy('C1') + sols = [sol.subs(C1, dummy_sym) for sol in sols] + + assert all(x[0] for x in checkodesol(eq, sols)) + assert all(s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse)) + + +def test_solve_riccati(): + """ + This function tests the computation of rational + particular solutions for a Riccati ODE. + + Each test case has 2 values - + + 1. eq - Riccati ODE to be solved. + 2. sol - Expected solution to the equation. + + Some examples have been taken from the paper - "Statistical Investigation of + First-Order Algebraic ODEs and their Rational General Solutions" by + Georg Grasegger, N. Thieu Vo, Franz Winkler + + https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf + """ + C0 = Dummy('C0') + # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2, + # a, b, c are rational functions of x + + tests = [ + # a(x) is a constant + ( + Eq(f(x).diff(x) + f(x)**2 - 2, 0), + [Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))] + ), + # a(x) is a constant + ( + f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2, + [Eq(f(x), (-2*C0 - x)/(C0*x + x**2))] + ), + # a(x) is a constant + ( + 2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x), + [Eq(f(x), (C0 + 2*x**2)/(C0 + x))] + ), + # Pole with multiplicity 1 + ( + Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)), + [Eq(f(x), 1/(x**2 - x))] + ), + # One pole of multiplicity 2 + ( + x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x), + [Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)] + ), + ( + x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x), + [Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)] + ), + # Multiple poles of multiplicity 2 + ( + -f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \ + - 1)**2), + [Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \ + - 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \ + 57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \ + - 3*x + 1))] + ), + # Regression: Poles with even multiplicity > 2 fixed + ( + f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \ + 7*x**2 - 20*x + 4)/(4*x**4), + [Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \ + - 2*x**2))] + ), + # Regression: Poles with even multiplicity > 2 fixed + ( + Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\ + (x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \ + 792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)), + [Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))] + ), + # More than 2 poles with multiplicity 2 + # Regression: Fixed mistake in necessary conditions + ( + Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \ + (8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2), + [Eq(f(x), (1 - 4*x)/(2*x - 2))] + ), + # Regression: Fixed mistake in necessary conditions + ( + Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \ + + 3*f(x)**2/(6*x + 2)), + [Eq(f(x), (2*x + 1)/(2*x - 2))] + ), + # Imaginary poles + ( + f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \ + - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2), + [Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \ + - x)), Eq(f(x), -1/(x - 1))], + ), + # Imaginary coefficients in equation + ( + f(x).diff(x) - 2*I*(f(x)**2 + 1)/x, + [Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)] + ), + # Regression: linsolve returning empty solution + # Large value of m (> 10) + ( + Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\ + (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)), + [Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \ + 18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\ + x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\ + ((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \ + 415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\ + *x**3 + 32744250*x**2 + 16372125*x)))] + ), + # Regression: Fixed bug due to a typo in paper + ( + Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6), + [Eq(f(x), 6*x)] + ), + # Regression: Fixed bug due to a typo in paper + ( + Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \ + + 9 + (1 - x)*f(x)/x + 3/x), + [Eq(f(x), -3*x/2 - 3)] + )] + for eq, sol in tests: + check_dummy_sol(eq, sol, C0) + + +@slow +def test_solve_riccati_slow(): + """ + This function tests the computation of rational + particular solutions for a Riccati ODE. + + Each test case has 2 values - + + 1. eq - Riccati ODE to be solved. + 2. sol - Expected solution to the equation. + """ + C0 = Dummy('C0') + tests = [ + # Very large values of m (989 and 991) + ( + Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \ + (54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \ + 2187*x + 2187) + 495), + [Eq(f(x), (18*x + 6)/(2*x - 9))] + )] + for eq, sol in tests: + check_dummy_sol(eq, sol, C0) diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_single.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_single.py new file mode 100644 index 0000000000000000000000000000000000000000..45b38029e97b9e74236c45f2f3efb6aa87c26e5c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_single.py @@ -0,0 +1,2902 @@ +# +# The main tests for the code in single.py are currently located in +# sympy/solvers/tests/test_ode.py +# +r""" +This File contains test functions for the individual hints used for solving ODEs. + +Examples of each solver will be returned by _get_examples_ode_sol_name_of_solver. + +Examples should have a key 'XFAIL' which stores the list of hints if they are +expected to fail for that hint. + +Functions that are for internal use: + +1) _ode_solver_test(ode_examples) - It takes a dictionary of examples returned by + _get_examples method and tests them with their respective hints. + +2) _test_particular_example(our_hint, example_name) - It tests the ODE example corresponding + to the hint provided. + +3) _test_all_hints(runxfail=False) - It is used to test all the examples with all the hints + currently implemented. It calls _test_all_examples_for_one_hint() which outputs whether the + given hint functions properly if it classifies the ODE example. + If runxfail flag is set to True then it will only test the examples which are expected to fail. + + Everytime the ODE of a particular solver is added, _test_all_hints() is to be executed to find + the possible failures of different solver hints. + +4) _test_all_examples_for_one_hint(our_hint, all_examples) - It takes hint as argument and checks + this hint against all the ODE examples and gives output as the number of ODEs matched, number + of ODEs which were solved correctly, list of ODEs which gives incorrect solution and list of + ODEs which raises exception. + +""" +from sympy.core.function import (Derivative, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sec, sin, tan) +from sympy.functions.special.error_functions import (Ei, erfi) +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import (Integral, integrate) +from sympy.polys.rootoftools import rootof + +from sympy.core import Function, Symbol +from sympy.functions import airyai, airybi, besselj, bessely, lowergamma +from sympy.integrals.risch import NonElementaryIntegral +from sympy.solvers.ode import classify_ode, dsolve +from sympy.solvers.ode.ode import allhints, _remove_redundant_solutions +from sympy.solvers.ode.single import (FirstLinear, ODEMatchError, + SingleODEProblem, SingleODESolver, NthOrderReducible) + +from sympy.solvers.ode.subscheck import checkodesol + +from sympy.testing.pytest import raises, slow +import traceback + + +x = Symbol('x') +u = Symbol('u') +_u = Dummy('u') +y = Symbol('y') +f = Function('f') +g = Function('g') +C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C1:11') +a, b, c = symbols('a b c') + + +hint_message = """\ +Hint did not match the example {example}. + +The ODE is: +{eq}. + +The expected hint was +{our_hint}\ +""" + +expected_sol_message = """\ +Different solution found from dsolve for example {example}. + +The ODE is: +{eq} + +The expected solution was +{sol} + +What dsolve returned is: +{dsolve_sol}\ +""" + +checkodesol_msg = """\ +solution found is not correct for example {example}. + +The ODE is: +{eq}\ +""" + +dsol_incorrect_msg = """\ +solution returned by dsolve is incorrect when using {hint}. + +The ODE is: +{eq} + +The expected solution was +{sol} + +what dsolve returned is: +{dsolve_sol} + +You can test this with: + +eq = {eq} +sol = dsolve(eq, hint='{hint}') +print(sol) +print(checkodesol(eq, sol)) + +""" + +exception_msg = """\ +dsolve raised exception : {e} + +when using {hint} for the example {example} + +You can test this with: + +from sympy.solvers.ode.tests.test_single import _test_an_example + +_test_an_example('{hint}', example_name = '{example}') + +The ODE is: +{eq} + +\ +""" + +check_hint_msg = """\ +Tested hint was : {hint} + +Total of {matched} examples matched with this hint. + +Out of which {solve} gave correct results. + +Examples which gave incorrect results are {unsolve}. + +Examples which raised exceptions are {exceptions} +\ +""" + + +def _add_example_keys(func): + def inner(): + solver=func() + examples=[] + for example in solver['examples']: + temp={ + 'eq': solver['examples'][example]['eq'], + 'sol': solver['examples'][example]['sol'], + 'XFAIL': solver['examples'][example].get('XFAIL', []), + 'func': solver['examples'][example].get('func',solver['func']), + 'example_name': example, + 'slow': solver['examples'][example].get('slow', False), + 'simplify_flag':solver['examples'][example].get('simplify_flag',True), + 'checkodesol_XFAIL': solver['examples'][example].get('checkodesol_XFAIL', False), + 'dsolve_too_slow':solver['examples'][example].get('dsolve_too_slow',False), + 'checkodesol_too_slow':solver['examples'][example].get('checkodesol_too_slow',False), + 'hint': solver['hint'] + } + examples.append(temp) + return examples + return inner() + + +def _ode_solver_test(ode_examples, run_slow_test=False): + for example in ode_examples: + if ((not run_slow_test) and example['slow']) or (run_slow_test and (not example['slow'])): + continue + + result = _test_particular_example(example['hint'], example, solver_flag=True) + if result['xpass_msg'] != "": + print(result['xpass_msg']) + + +def _test_all_hints(runxfail=False): + all_hints = list(allhints)+["default"] + all_examples = _get_all_examples() + + for our_hint in all_hints: + if our_hint.endswith('_Integral') or 'series' in our_hint: + continue + _test_all_examples_for_one_hint(our_hint, all_examples, runxfail) + + +def _test_dummy_sol(expected_sol,dsolve_sol): + if type(dsolve_sol)==list: + return any(expected_sol.dummy_eq(sub_dsol) for sub_dsol in dsolve_sol) + else: + return expected_sol.dummy_eq(dsolve_sol) + + +def _test_an_example(our_hint, example_name): + all_examples = _get_all_examples() + for example in all_examples: + if example['example_name'] == example_name: + _test_particular_example(our_hint, example) + + +def _test_particular_example(our_hint, ode_example, solver_flag=False): + eq = ode_example['eq'] + expected_sol = ode_example['sol'] + example = ode_example['example_name'] + xfail = our_hint in ode_example['XFAIL'] + func = ode_example['func'] + result = {'msg': '', 'xpass_msg': ''} + simplify_flag=ode_example['simplify_flag'] + checkodesol_XFAIL = ode_example['checkodesol_XFAIL'] + dsolve_too_slow = ode_example['dsolve_too_slow'] + checkodesol_too_slow = ode_example['checkodesol_too_slow'] + xpass = True + if solver_flag: + if our_hint not in classify_ode(eq, func): + message = hint_message.format(example=example, eq=eq, our_hint=our_hint) + raise AssertionError(message) + + if our_hint in classify_ode(eq, func): + result['match_list'] = example + try: + if not (dsolve_too_slow): + dsolve_sol = dsolve(eq, func, simplify=simplify_flag,hint=our_hint) + else: + if len(expected_sol)==1: + dsolve_sol = expected_sol[0] + else: + dsolve_sol = expected_sol + + except Exception as e: + dsolve_sol = [] + result['exception_list'] = example + if not solver_flag: + traceback.print_exc() + result['msg'] = exception_msg.format(e=str(e), hint=our_hint, example=example, eq=eq) + if solver_flag and not xfail: + print(result['msg']) + raise + xpass = False + + if solver_flag and dsolve_sol!=[]: + expect_sol_check = False + if type(dsolve_sol)==list: + for sub_sol in expected_sol: + if sub_sol.has(Dummy): + expect_sol_check = not _test_dummy_sol(sub_sol, dsolve_sol) + else: + expect_sol_check = sub_sol not in dsolve_sol + if expect_sol_check: + break + else: + expect_sol_check = dsolve_sol not in expected_sol + for sub_sol in expected_sol: + if sub_sol.has(Dummy): + expect_sol_check = not _test_dummy_sol(sub_sol, dsolve_sol) + + if expect_sol_check: + message = expected_sol_message.format(example=example, eq=eq, sol=expected_sol, dsolve_sol=dsolve_sol) + raise AssertionError(message) + + expected_checkodesol = [(True, 0) for i in range(len(expected_sol))] + if len(expected_sol) == 1: + expected_checkodesol = (True, 0) + + if not checkodesol_too_slow: + if not checkodesol_XFAIL: + if checkodesol(eq, dsolve_sol, func, solve_for_func=False) != expected_checkodesol: + result['unsolve_list'] = example + xpass = False + message = dsol_incorrect_msg.format(hint=our_hint, eq=eq, sol=expected_sol,dsolve_sol=dsolve_sol) + if solver_flag: + message = checkodesol_msg.format(example=example, eq=eq) + raise AssertionError(message) + else: + result['msg'] = 'AssertionError: ' + message + + if xpass and xfail: + result['xpass_msg'] = example + "is now passing for the hint" + our_hint + return result + + +def _test_all_examples_for_one_hint(our_hint, all_examples=[], runxfail=None): + if all_examples == []: + all_examples = _get_all_examples() + match_list, unsolve_list, exception_list = [], [], [] + for ode_example in all_examples: + xfail = our_hint in ode_example['XFAIL'] + if runxfail and not xfail: + continue + if xfail: + continue + result = _test_particular_example(our_hint, ode_example) + match_list += result.get('match_list',[]) + unsolve_list += result.get('unsolve_list',[]) + exception_list += result.get('exception_list',[]) + if runxfail is not None: + msg = result['msg'] + if msg!='': + print(result['msg']) + # print(result.get('xpass_msg','')) + if runxfail is None: + match_count = len(match_list) + solved = len(match_list)-len(unsolve_list)-len(exception_list) + msg = check_hint_msg.format(hint=our_hint, matched=match_count, solve=solved, unsolve=unsolve_list, exceptions=exception_list) + print(msg) + + +def test_SingleODESolver(): + # Test that not implemented methods give NotImplementedError + # Subclasses should override these methods. + problem = SingleODEProblem(f(x).diff(x), f(x), x) + solver = SingleODESolver(problem) + raises(NotImplementedError, lambda: solver.matches()) + raises(NotImplementedError, lambda: solver.get_general_solution()) + raises(NotImplementedError, lambda: solver._matches()) + raises(NotImplementedError, lambda: solver._get_general_solution()) + + # This ODE can not be solved by the FirstLinear solver. Here we test that + # it does not match and the asking for a general solution gives + # ODEMatchError + + problem = SingleODEProblem(f(x).diff(x) + f(x)*f(x), f(x), x) + + solver = FirstLinear(problem) + raises(ODEMatchError, lambda: solver.get_general_solution()) + + solver = FirstLinear(problem) + assert solver.matches() is False + + #These are just test for order of ODE + + problem = SingleODEProblem(f(x).diff(x) + f(x), f(x), x) + assert problem.order == 1 + + problem = SingleODEProblem(f(x).diff(x,4) + f(x).diff(x,2) - f(x).diff(x,3), f(x), x) + assert problem.order == 4 + + problem = SingleODEProblem(f(x).diff(x, 3) + f(x).diff(x, 2) - f(x)**2, f(x), x) + assert problem.is_autonomous == True + + problem = SingleODEProblem(f(x).diff(x, 3) + x*f(x).diff(x, 2) - f(x)**2, f(x), x) + assert problem.is_autonomous == False + + +def test_linear_coefficients(): + _ode_solver_test(_get_examples_ode_sol_linear_coefficients) + + +@slow +def test_1st_homogeneous_coeff_ode(): + #These were marked as test_1st_homogeneous_coeff_corner_case + eq1 = f(x).diff(x) - f(x)/x + c1 = classify_ode(eq1, f(x)) + eq2 = x*f(x).diff(x) - f(x) + c2 = classify_ode(eq2, f(x)) + sdi = "1st_homogeneous_coeff_subs_dep_div_indep" + sid = "1st_homogeneous_coeff_subs_indep_div_dep" + assert sid not in c1 and sdi not in c1 + assert sid not in c2 and sdi not in c2 + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep) + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_best) + + +@slow +def test_slow_examples_1st_homogeneous_coeff_ode(): + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep, run_slow_test=True) + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_best, run_slow_test=True) + + +@slow +def test_nth_linear_constant_coeff_homogeneous(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_constant_coeff_homogeneous) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_homogeneous(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_constant_coeff_homogeneous, run_slow_test=True) + + +def test_Airy_equation(): + _ode_solver_test(_get_examples_ode_sol_2nd_linear_airy) + + +@slow +def test_lie_group(): + _ode_solver_test(_get_examples_ode_sol_lie_group) + + +@slow +def test_separable_reduced(): + df = f(x).diff(x) + eq = (x / f(x))*df + tan(x**2*f(x) / (x**2*f(x) - 1)) + assert classify_ode(eq) == ('factorable', 'separable_reduced', 'lie_group', + 'separable_reduced_Integral') + _ode_solver_test(_get_examples_ode_sol_separable_reduced) + + +@slow +def test_slow_examples_separable_reduced(): + _ode_solver_test(_get_examples_ode_sol_separable_reduced, run_slow_test=True) + + +@slow +def test_2nd_2F1_hypergeometric(): + _ode_solver_test(_get_examples_ode_sol_2nd_2F1_hypergeometric) + + +def test_2nd_2F1_hypergeometric_integral(): + eq = x*(x-1)*f(x).diff(x, 2) + (-1+ S(7)/2*x)*f(x).diff(x) + f(x) + sol = Eq(f(x), (C1 + C2*Integral(exp(Integral((1 - x/2)/(x*(x - 1)), x))/(1 - + x/2)**2, x))*exp(Integral(1/(x - 1), x)/4)*exp(-Integral(7/(x - + 1), x)/4)*hyper((S(1)/2, -1), (1,), x)) + assert sol == dsolve(eq, hint='2nd_hypergeometric_Integral') + assert checkodesol(eq, sol) == (True, 0) + + +@slow +def test_2nd_nonlinear_autonomous_conserved(): + _ode_solver_test(_get_examples_ode_sol_2nd_nonlinear_autonomous_conserved) + + +def test_2nd_nonlinear_autonomous_conserved_integral(): + eq = f(x).diff(x, 2) + asin(f(x)) + actual = [Eq(Integral(1/sqrt(C1 - 2*Integral(asin(_u), _u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*Integral(asin(_u), _u)), (_u, f(x))), C2 - x)] + solved = dsolve(eq, hint='2nd_nonlinear_autonomous_conserved_Integral', simplify=False) + for a,s in zip(actual, solved): + assert a.dummy_eq(s) + # checkodesol unable to simplify solutions with f(x) in an integral equation + assert checkodesol(eq, [s.doit() for s in solved]) == [(True, 0), (True, 0)] + + +@slow +def test_2nd_linear_bessel_equation(): + _ode_solver_test(_get_examples_ode_sol_2nd_linear_bessel) + + +@slow +def test_nth_algebraic(): + eqn = f(x) + f(x)*f(x).diff(x) + solns = [Eq(f(x), exp(x)), + Eq(f(x), C1*exp(C2*x))] + solns_final = _remove_redundant_solutions(eqn, solns, 2, x) + assert solns_final == [Eq(f(x), C1*exp(C2*x))] + + _ode_solver_test(_get_examples_ode_sol_nth_algebraic) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_var_of_parameters(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_var_of_parameters, run_slow_test=True) + + +def test_nth_linear_constant_coeff_var_of_parameters(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_var_of_parameters) + + +@slow +def test_nth_linear_constant_coeff_variation_of_parameters__integral(): + # solve_variation_of_parameters shouldn't attempt to simplify the + # Wronskian if simplify=False. If wronskian() ever gets good enough + # to simplify the result itself, this test might fail. + our_hint = 'nth_linear_constant_coeff_variation_of_parameters_Integral' + eq = f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x) + sol_simp = dsolve(eq, f(x), hint=our_hint, simplify=True) + sol_nsimp = dsolve(eq, f(x), hint=our_hint, simplify=False) + assert sol_simp != sol_nsimp + assert checkodesol(eq, sol_simp, order=5, solve_for_func=False) == (True, 0) + assert checkodesol(eq, sol_simp, order=5, solve_for_func=False) == (True, 0) + + +@slow +def test_slow_examples_1st_exact(): + _ode_solver_test(_get_examples_ode_sol_1st_exact, run_slow_test=True) + + +@slow +def test_1st_exact(): + _ode_solver_test(_get_examples_ode_sol_1st_exact) + + +def test_1st_exact_integral(): + eq = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) + sol_1 = dsolve(eq, f(x), simplify=False, hint='1st_exact_Integral') + assert checkodesol(eq, sol_1, order=1, solve_for_func=False) + + +@slow +def test_slow_examples_nth_order_reducible(): + _ode_solver_test(_get_examples_ode_sol_nth_order_reducible, run_slow_test=True) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_undetermined_coefficients(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_undetermined_coefficients, run_slow_test=True) + + +@slow +def test_slow_examples_separable(): + _ode_solver_test(_get_examples_ode_sol_separable, run_slow_test=True) + + +@slow +def test_nth_linear_constant_coeff_undetermined_coefficients(): + #issue-https://github.com/sympy/sympy/issues/5787 + # This test case is to show the classification of imaginary constants under + # nth_linear_constant_coeff_undetermined_coefficients + eq = Eq(diff(f(x), x), I*f(x) + S.Half - I) + our_hint = 'nth_linear_constant_coeff_undetermined_coefficients' + assert our_hint in classify_ode(eq) + _ode_solver_test(_get_examples_ode_sol_nth_linear_undetermined_coefficients) + + +def test_nth_order_reducible(): + F = lambda eq: NthOrderReducible(SingleODEProblem(eq, f(x), x))._matches() + D = Derivative + assert F(D(y*f(x), x, y) + D(f(x), x)) == False + assert F(D(y*f(y), y, y) + D(f(y), y)) == False + assert F(f(x)*D(f(x), x) + D(f(x), x, 2))== False + assert F(D(x*f(y), y, 2) + D(u*y*f(x), x, 3)) == False # no simplification by design + assert F(D(f(y), y, 2) + D(f(y), y, 3) + D(f(x), x, 4)) == False + assert F(D(f(x), x, 2) + D(f(x), x, 3)) == True + _ode_solver_test(_get_examples_ode_sol_nth_order_reducible) + + +@slow +def test_separable(): + _ode_solver_test(_get_examples_ode_sol_separable) + + +@slow +def test_factorable(): + assert integrate(-asin(f(2*x)+pi), x) == -Integral(asin(pi + f(2*x)), x) + _ode_solver_test(_get_examples_ode_sol_factorable) + + +@slow +def test_slow_examples_factorable(): + _ode_solver_test(_get_examples_ode_sol_factorable, run_slow_test=True) + + +def test_Riccati_special_minus2(): + _ode_solver_test(_get_examples_ode_sol_riccati) + + +@slow +def test_1st_rational_riccati(): + _ode_solver_test(_get_examples_ode_sol_1st_rational_riccati) + + +def test_Bernoulli(): + _ode_solver_test(_get_examples_ode_sol_bernoulli) + + +def test_1st_linear(): + _ode_solver_test(_get_examples_ode_sol_1st_linear) + + +def test_almost_linear(): + _ode_solver_test(_get_examples_ode_sol_almost_linear) + + +@slow +def test_Liouville_ODE(): + hint = 'Liouville' + not_Liouville1 = classify_ode(diff(f(x), x)/x + f(x)*diff(f(x), x, x)/2 - + diff(f(x), x)**2/2, f(x)) + not_Liouville2 = classify_ode(diff(f(x), x)/x + diff(f(x), x, x)/2 - + x*diff(f(x), x)**2/2, f(x)) + assert hint not in not_Liouville1 + assert hint not in not_Liouville2 + assert hint + '_Integral' not in not_Liouville1 + assert hint + '_Integral' not in not_Liouville2 + + _ode_solver_test(_get_examples_ode_sol_liouville) + + +def test_nth_order_linear_euler_eq_homogeneous(): + x, t, a, b, c = symbols('x t a b c') + y = Function('y') + our_hint = "nth_linear_euler_eq_homogeneous" + + eq = diff(f(t), t, 4)*t**4 - 13*diff(f(t), t, 2)*t**2 + 36*f(t) + assert our_hint in classify_ode(eq) + + eq = a*y(t) + b*t*diff(y(t), t) + c*t**2*diff(y(t), t, 2) + assert our_hint in classify_ode(eq) + + _ode_solver_test(_get_examples_ode_sol_euler_homogeneous) + + +def test_nth_order_linear_euler_eq_nonhomogeneous_undetermined_coefficients(): + x, t = symbols('x t') + a, b, c, d = symbols('a b c d', integer=True) + our_hint = "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients" + + eq = x**4*diff(f(x), x, 4) - 13*x**2*diff(f(x), x, 2) + 36*f(x) + x + assert our_hint in classify_ode(eq, f(x)) + + eq = a*x**2*diff(f(x), x, 2) + b*x*diff(f(x), x) + c*f(x) + d*log(x) + assert our_hint in classify_ode(eq, f(x)) + + _ode_solver_test(_get_examples_ode_sol_euler_undetermined_coeff) + + +@slow +def test_nth_order_linear_euler_eq_nonhomogeneous_variation_of_parameters(): + x, t = symbols('x, t') + a, b, c, d = symbols('a, b, c, d', integer=True) + our_hint = "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters" + + eq = Eq(x**2*diff(f(x),x,2) - 8*x*diff(f(x),x) + 12*f(x), x**2) + assert our_hint in classify_ode(eq, f(x)) + + eq = Eq(a*x**3*diff(f(x),x,3) + b*x**2*diff(f(x),x,2) + c*x*diff(f(x),x) + d*f(x), x*log(x)) + assert our_hint in classify_ode(eq, f(x)) + + _ode_solver_test(_get_examples_ode_sol_euler_var_para) + + +@_add_example_keys +def _get_examples_ode_sol_euler_homogeneous(): + r1, r2, r3, r4, r5 = [rootof(x**5 - 14*x**4 + 71*x**3 - 154*x**2 + 120*x - 1, n) for n in range(5)] + return { + 'hint': "nth_linear_euler_eq_homogeneous", + 'func': f(x), + 'examples':{ + 'euler_hom_01': { + 'eq': Eq(-3*diff(f(x), x)*x + 2*x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), C1 + C2*x**Rational(5, 2))], + }, + + 'euler_hom_02': { + 'eq': Eq(3*f(x) - 5*diff(f(x), x)*x + 2*x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), C1*sqrt(x) + C2*x**3)] + }, + + 'euler_hom_03': { + 'eq': Eq(4*f(x) + 5*diff(f(x), x)*x + x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), (C1 + C2*log(x))/x**2)] + }, + + 'euler_hom_04': { + 'eq': Eq(6*f(x) - 6*diff(f(x), x)*x + 1*x**2*diff(f(x), x, x) + x**3*diff(f(x), x, x, x), 0), + 'sol': [Eq(f(x), C1/x**2 + C2*x + C3*x**3)] + }, + + 'euler_hom_05': { + 'eq': Eq(-125*f(x) + 61*diff(f(x), x)*x - 12*x**2*diff(f(x), x, x) + x**3*diff(f(x), x, x, x), 0), + 'sol': [Eq(f(x), x**5*(C1 + C2*log(x) + C3*log(x)**2))] + }, + + 'euler_hom_06': { + 'eq': x**2*diff(f(x), x, 2) + x*diff(f(x), x) - 9*f(x), + 'sol': [Eq(f(x), C1*x**-3 + C2*x**3)] + }, + + 'euler_hom_07': { + 'eq': sin(x)*x**2*f(x).diff(x, 2) + sin(x)*x*f(x).diff(x) + sin(x)*f(x), + 'sol': [Eq(f(x), C1*sin(log(x)) + C2*cos(log(x)))], + 'XFAIL': ['2nd_power_series_regular','nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients'] + }, + + 'euler_hom_08': { + 'eq': x**6 * f(x).diff(x, 6) - x*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*x + C2*x**r1 + C3*x**r2 + C4*x**r3 + C5*x**r4 + C6*x**r5)], + 'checkodesol_XFAIL':True + }, + + #This example is from issue: https://github.com/sympy/sympy/issues/15237 #This example is from issue: + # https://github.com/sympy/sympy/issues/15237 + 'euler_hom_09': { + 'eq': Derivative(x*f(x), x, x, x), + 'sol': [Eq(f(x), C1 + C2/x + C3*x)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_euler_undetermined_coeff(): + return { + 'hint': "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients", + 'func': f(x), + 'examples':{ + 'euler_undet_01': { + 'eq': Eq(x**2*diff(f(x), x, x) + x*diff(f(x), x), 1), + 'sol': [Eq(f(x), C1 + C2*log(x) + log(x)**2/2)] + }, + + 'euler_undet_02': { + 'eq': Eq(x**2*diff(f(x), x, x) - 2*x*diff(f(x), x) + 2*f(x), x**3), + 'sol': [Eq(f(x), x*(C1 + C2*x + Rational(1, 2)*x**2))] + }, + + 'euler_undet_03': { + 'eq': Eq(x**2*diff(f(x), x, x) - x*diff(f(x), x) - 3*f(x), log(x)/x), + 'sol': [Eq(f(x), (C1 + C2*x**4 - log(x)**2/8 - log(x)/16)/x)] + }, + + 'euler_undet_04': { + 'eq': Eq(x**2*diff(f(x), x, x) + 3*x*diff(f(x), x) - 8*f(x), log(x)**3 - log(x)), + 'sol': [Eq(f(x), C1/x**4 + C2*x**2 - Rational(1,8)*log(x)**3 - Rational(3,32)*log(x)**2 - Rational(1,64)*log(x) - Rational(7, 256))] + }, + + 'euler_undet_05': { + 'eq': Eq(x**3*diff(f(x), x, x, x) - 3*x**2*diff(f(x), x, x) + 6*x*diff(f(x), x) - 6*f(x), log(x)), + 'sol': [Eq(f(x), C1*x + C2*x**2 + C3*x**3 - Rational(1, 6)*log(x) - Rational(11, 36))] + }, + + #Below examples were added for the issue: https://github.com/sympy/sympy/issues/5096 + 'euler_undet_06': { + 'eq': 2*x**2*f(x).diff(x, 2) + f(x) + sqrt(2*x)*sin(log(2*x)/2), + 'sol': [Eq(f(x), sqrt(x)*(C1*sin(log(x)/2) + C2*cos(log(x)/2) + sqrt(2)*log(x)*cos(log(2*x)/2)/2))] + }, + + 'euler_undet_07': { + 'eq': 2*x**2*f(x).diff(x, 2) + f(x) + sin(log(2*x)/2), + 'sol': [Eq(f(x), C1*sqrt(x)*sin(log(x)/2) + C2*sqrt(x)*cos(log(x)/2) - 2*sin(log(2*x)/2)/5 - 4*cos(log(2*x)/2)/5)] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_euler_var_para(): + return { + 'hint': "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters", + 'func': f(x), + 'examples':{ + 'euler_var_01': { + 'eq': Eq(x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x), x**4), + 'sol': [Eq(f(x), x*(C1 + C2*x + x**3/6))] + }, + + 'euler_var_02': { + 'eq': Eq(3*x**2*diff(f(x), x, x) + 6*x*diff(f(x), x) - 6*f(x), x**3*exp(x)), + 'sol': [Eq(f(x), C1/x**2 + C2*x + x*exp(x)/3 - 4*exp(x)/3 + 8*exp(x)/(3*x) - 8*exp(x)/(3*x**2))] + }, + + 'euler_var_03': { + 'eq': Eq(x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x), x**4*exp(x)), + 'sol': [Eq(f(x), x*(C1 + C2*x + x*exp(x) - 2*exp(x)))] + }, + + 'euler_var_04': { + 'eq': x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - log(x), + 'sol': [Eq(f(x), C1*x + C2*x**2 + log(x)/2 + Rational(3, 4))] + }, + + 'euler_var_05': { + 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))] + }, + + 'euler_var_06': { + 'eq': x**2 * f(x).diff(x, 2) + x * f(x).diff(x) + 4 * f(x) - 1/x, + 'sol': [Eq(f(x), C1*sin(2*log(x)) + C2*cos(2*log(x)) + 1/(5*x))] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_bernoulli(): + # Type: Bernoulli, f'(x) + p(x)*f(x) == q(x)*f(x)**n + return { + 'hint': "Bernoulli", + 'func': f(x), + 'examples':{ + 'bernoulli_01': { + 'eq': Eq(x*f(x).diff(x) + f(x) - f(x)**2, 0), + 'sol': [Eq(f(x), 1/(C1*x + 1))], + 'XFAIL': ['separable_reduced'] + }, + + 'bernoulli_02': { + 'eq': f(x).diff(x) - y*f(x), + 'sol': [Eq(f(x), C1*exp(x*y))] + }, + + 'bernoulli_03': { + 'eq': f(x)*f(x).diff(x) - 1, + 'sol': [Eq(f(x), -sqrt(C1 + 2*x)), Eq(f(x), sqrt(C1 + 2*x))] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_riccati(): + # Type: Riccati special alpha = -2, a*dy/dx + b*y**2 + c*y/x +d/x**2 + return { + 'hint': "Riccati_special_minus2", + 'func': f(x), + 'examples':{ + 'riccati_01': { + 'eq': 2*f(x).diff(x) + f(x)**2 - f(x)/x + 3*x**(-2), + 'sol': [Eq(f(x), (-sqrt(3)*tan(C1 + sqrt(3)*log(x)/4) + 3)/(2*x))], + }, + }, + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_rational_riccati(): + # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2, + # a, b, c are rational functions of x + return { + 'hint': "1st_rational_riccati", + 'func': f(x), + 'examples':{ + # a(x) is a constant + "rational_riccati_01": { + "eq": Eq(f(x).diff(x) + f(x)**2 - 2, 0), + "sol": [Eq(f(x), sqrt(2)*(-C1 - exp(2*sqrt(2)*x))/(C1 - exp(2*sqrt(2)*x)))] + }, + # a(x) is a constant + "rational_riccati_02": { + "eq": f(x)**2 + Derivative(f(x), x) + 4*f(x)/x + 2/x**2, + "sol": [Eq(f(x), (-2*C1 - x)/(x*(C1 + x)))] + }, + # a(x) is a constant + "rational_riccati_03": { + "eq": 2*x**2*Derivative(f(x), x) - x*(4*f(x) + Derivative(f(x), x) - 4) + (f(x) - 1)*f(x), + "sol": [Eq(f(x), (C1 + 2*x**2)/(C1 + x))] + }, + # Constant coefficients + "rational_riccati_04": { + "eq": f(x).diff(x) - 6 - 5*f(x) - f(x)**2, + "sol": [Eq(f(x), (-2*C1 + 3*exp(x))/(C1 - exp(x)))] + }, + # One pole of multiplicity 2 + "rational_riccati_05": { + "eq": x**2 - (2*x + 1/x)*f(x) + f(x)**2 + Derivative(f(x), x), + "sol": [Eq(f(x), x*(C1 + x**2 + 1)/(C1 + x**2 - 1))] + }, + # One pole of multiplicity 2 + "rational_riccati_06": { + "eq": x**4*Derivative(f(x), x) + x**2 - x*(2*f(x)**2 + Derivative(f(x), x)) + f(x), + "sol": [Eq(f(x), x*(C1*x - x + 1)/(C1 + x**2 - 1))] + }, + # Multiple poles of multiplicity 2 + "rational_riccati_07": { + "eq": -f(x)**2 + Derivative(f(x), x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \ + - 1)**2), + "sol": [Eq(f(x), (9*C1*x - 6*C1 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 - \ + 33*x + 8)/(6*C1*x**2 - 9*C1*x + 3*C1 + 6*x**6 - 29*x**5 + 57*x**4 - \ + 58*x**3 + 28*x**2 - 3*x - 1))] + }, + # Imaginary poles + "rational_riccati_08": { + "eq": Derivative(f(x), x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \ + - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2), + "sol": [Eq(f(x), (-C1 - x**3 + x**2 - 2*x + 1)/(C1*x - C1 + x**4 - x**3 + x**2 - \ + 2*x + 1))], + }, + # Imaginary coefficients in equation + "rational_riccati_09": { + "eq": Derivative(f(x), x) - 2*I*(f(x)**2 + 1)/x, + "sol": [Eq(f(x), (-I*C1 + I*x**4 + I)/(C1 + x**4 - 1))] + }, + # Regression: linsolve returning empty solution + # Large value of m (> 10) + "rational_riccati_10": { + "eq": Eq(Derivative(f(x), x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\ + (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)), + "sol": [Eq(f(x), (40*C1*x**14 + 28*C1*x**13 + 420*C1*x**12 + 2940*C1*x**11 + \ + 18480*C1*x**10 + 103950*C1*x**9 + 519750*C1*x**8 + 2286900*C1*x**7 + \ + 8731800*C1*x**6 + 28378350*C1*x**5 + 76403250*C1*x**4 + 163721250*C1*x**3 \ + + 261954000*C1*x**2 + 278326125*C1*x + 147349125*C1 + x*exp(2*x) - 9*exp(2*x) \ + )/(x*(24*C1*x**13 + 140*C1*x**12 + 840*C1*x**11 + 4620*C1*x**10 + 23100*C1*x**9 \ + + 103950*C1*x**8 + 415800*C1*x**7 + 1455300*C1*x**6 + 4365900*C1*x**5 + \ + 10914750*C1*x**4 + 21829500*C1*x**3 + 32744250*C1*x**2 + 32744250*C1*x + \ + 16372125*C1 - exp(2*x))))] + } + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_1st_linear(): + # Type: first order linear form f'(x)+p(x)f(x)=q(x) + return { + 'hint': "1st_linear", + 'func': f(x), + 'examples':{ + 'linear_01': { + 'eq': Eq(f(x).diff(x) + x*f(x), x**2), + 'sol': [Eq(f(x), (C1 + x*exp(x**2/2)- sqrt(2)*sqrt(pi)*erfi(sqrt(2)*x/2)/2)*exp(-x**2/2))], + }, + }, + } + + +@_add_example_keys +def _get_examples_ode_sol_factorable(): + """ some hints are marked as xfail for examples because they missed additional algebraic solution + which could be found by Factorable hint. Fact_01 raise exception for + nth_linear_constant_coeff_undetermined_coefficients""" + + y = Dummy('y') + a0,a1,a2,a3,a4 = symbols('a0, a1, a2, a3, a4') + return { + 'hint': "factorable", + 'func': f(x), + 'examples':{ + 'fact_01': { + 'eq': f(x) + f(x)*f(x).diff(x), + 'sol': [Eq(f(x), 0), Eq(f(x), C1 - x)], + 'XFAIL': ['separable', '1st_exact', '1st_linear', 'Bernoulli', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', '1st_homogeneous_coeff_subs_dep_div_indep', + 'lie_group', 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters', + 'nth_linear_constant_coeff_undetermined_coefficients'] + }, + + 'fact_02': { + 'eq': f(x)*(f(x).diff(x)+f(x)*x+2), + 'sol': [Eq(f(x), (C1 - sqrt(2)*sqrt(pi)*erfi(sqrt(2)*x/2))*exp(-x**2/2)), Eq(f(x), 0)], + 'XFAIL': ['Bernoulli', '1st_linear', 'lie_group'] + }, + + 'fact_03': { + 'eq': (f(x).diff(x)+f(x)*x**2)*(f(x).diff(x, 2) + x*f(x)), + 'sol': [Eq(f(x), C1*airyai(-x) + C2*airybi(-x)),Eq(f(x), C1*exp(-x**3/3))] + }, + + 'fact_04': { + 'eq': (f(x).diff(x)+f(x)*x**2)*(f(x).diff(x, 2) + f(x)), + 'sol': [Eq(f(x), C1*exp(-x**3/3)), Eq(f(x), C1*sin(x) + C2*cos(x))] + }, + + 'fact_05': { + 'eq': (f(x).diff(x)**2-1)*(f(x).diff(x)**2-4), + 'sol': [Eq(f(x), C1 - x), Eq(f(x), C1 + x), Eq(f(x), C1 + 2*x), Eq(f(x), C1 - 2*x)] + }, + + 'fact_06': { + 'eq': (f(x).diff(x, 2)-exp(f(x)))*f(x).diff(x), + 'sol': [ + Eq(f(x), log(-C1/(cos(sqrt(-C1)*(C2 + x)) + 1))), + Eq(f(x), log(-C1/(cos(sqrt(-C1)*(C2 - x)) + 1))), + Eq(f(x), C1) + ], + 'slow': True, + }, + + 'fact_07': { + 'eq': (f(x).diff(x)**2-1)*(f(x)*f(x).diff(x)-1), + 'sol': [Eq(f(x), C1 - x), Eq(f(x), -sqrt(C1 + 2*x)),Eq(f(x), sqrt(C1 + 2*x)), Eq(f(x), C1 + x)] + }, + + 'fact_08': { + 'eq': Derivative(f(x), x)**4 - 2*Derivative(f(x), x)**2 + 1, + 'sol': [Eq(f(x), C1 - x), Eq(f(x), C1 + x)] + }, + + 'fact_09': { + 'eq': f(x)**2*Derivative(f(x), x)**6 - 2*f(x)**2*Derivative(f(x), + x)**4 + f(x)**2*Derivative(f(x), x)**2 - 2*f(x)*Derivative(f(x), + x)**5 + 4*f(x)*Derivative(f(x), x)**3 - 2*f(x)*Derivative(f(x), + x) + Derivative(f(x), x)**4 - 2*Derivative(f(x), x)**2 + 1, + 'sol': [ + Eq(f(x), C1 - x), Eq(f(x), -sqrt(C1 + 2*x)), + Eq(f(x), sqrt(C1 + 2*x)), Eq(f(x), C1 + x) + ] + }, + + 'fact_10': { + 'eq': x**4*f(x)**2 + 2*x**4*f(x)*Derivative(f(x), (x, 2)) + x**4*Derivative(f(x), + (x, 2))**2 + 2*x**3*f(x)*Derivative(f(x), x) + 2*x**3*Derivative(f(x), + x)*Derivative(f(x), (x, 2)) - 7*x**2*f(x)**2 - 7*x**2*f(x)*Derivative(f(x), + (x, 2)) + x**2*Derivative(f(x), x)**2 - 7*x*f(x)*Derivative(f(x), x) + 12*f(x)**2, + 'sol': [ + Eq(f(x), C1*besselj(2, x) + C2*bessely(2, x)), + Eq(f(x), C1*besselj(sqrt(3), x) + C2*bessely(sqrt(3), x)) + ], + 'slow': True, + }, + + 'fact_11': { + 'eq': (f(x).diff(x, 2)-exp(f(x)))*(f(x).diff(x, 2)+exp(f(x))), + 'sol': [ + Eq(f(x), log(C1/(cos(C1*sqrt(-1/C1)*(C2 + x)) - 1))), + Eq(f(x), log(C1/(cos(C1*sqrt(-1/C1)*(C2 - x)) - 1))), + Eq(f(x), log(C1/(1 - cos(C1*sqrt(-1/C1)*(C2 + x))))), + Eq(f(x), log(C1/(1 - cos(C1*sqrt(-1/C1)*(C2 - x))))) + ], + 'dsolve_too_slow': True, + }, + + #Below examples were added for the issue: https://github.com/sympy/sympy/issues/15889 + 'fact_12': { + 'eq': exp(f(x).diff(x))-f(x)**2, + 'sol': [Eq(NonElementaryIntegral(1/log(y**2), (y, f(x))), C1 + x)], + 'XFAIL': ['lie_group'] #It shows not implemented error for lie_group. + }, + + 'fact_13': { + 'eq': f(x).diff(x)**2 - f(x)**3, + 'sol': [Eq(f(x), 4/(C1**2 - 2*C1*x + x**2))], + 'XFAIL': ['lie_group'] #It shows not implemented error for lie_group. + }, + + 'fact_14': { + 'eq': f(x).diff(x)**2 - f(x), + 'sol': [Eq(f(x), C1**2/4 - C1*x/2 + x**2/4)] + }, + + 'fact_15': { + 'eq': f(x).diff(x)**2 - f(x)**2, + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1*exp(-x))] + }, + + 'fact_16': { + 'eq': f(x).diff(x)**2 - f(x)**3, + 'sol': [Eq(f(x), 4/(C1**2 - 2*C1*x + x**2))], + }, + + # kamke ode 1.1 + 'fact_17': { + 'eq': f(x).diff(x)-(a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**(-1/2), + 'sol': [Eq(f(x), C1 + Integral(1/sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4), x))], + 'slow': True + }, + + # This is from issue: https://github.com/sympy/sympy/issues/9446 + 'fact_18':{ + 'eq': Eq(f(2 * x), sin(Derivative(f(x)))), + 'sol': [Eq(f(x), C1 + Integral(pi - asin(f(2*x)), x)), Eq(f(x), C1 + Integral(asin(f(2*x)), x))], + 'checkodesol_XFAIL':True + }, + + # This is from issue: https://github.com/sympy/sympy/issues/7093 + 'fact_19': { + 'eq': Derivative(f(x), x)**2 - x**3, + 'sol': [Eq(f(x), C1 - 2*x**Rational(5,2)/5), Eq(f(x), C1 + 2*x**Rational(5,2)/5)], + }, + + 'fact_20': { + 'eq': x*f(x).diff(x, 2) - x*f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(x))], + }, + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_almost_linear(): + from sympy.functions.special.error_functions import Ei + A = Symbol('A', positive=True) + f = Function('f') + d = f(x).diff(x) + + return { + 'hint': "almost_linear", + 'func': f(x), + 'examples':{ + 'almost_lin_01': { + 'eq': x**2*f(x)**2*d + f(x)**3 + 1, + 'sol': [Eq(f(x), (C1*exp(3/x) - 1)**Rational(1, 3)), + Eq(f(x), (-1 - sqrt(3)*I)*(C1*exp(3/x) - 1)**Rational(1, 3)/2), + Eq(f(x), (-1 + sqrt(3)*I)*(C1*exp(3/x) - 1)**Rational(1, 3)/2)], + + }, + + 'almost_lin_02': { + 'eq': x*f(x)*d + 2*x*f(x)**2 + 1, + 'sol': [Eq(f(x), -sqrt((C1 - 2*Ei(4*x))*exp(-4*x))), Eq(f(x), sqrt((C1 - 2*Ei(4*x))*exp(-4*x)))] + }, + + 'almost_lin_03': { + 'eq': x*d + x*f(x) + 1, + 'sol': [Eq(f(x), (C1 - Ei(x))*exp(-x))] + }, + + 'almost_lin_04': { + 'eq': x*exp(f(x))*d + exp(f(x)) + 3*x, + 'sol': [Eq(f(x), log(C1/x - x*Rational(3, 2)))], + }, + + 'almost_lin_05': { + 'eq': x + A*(x + diff(f(x), x) + f(x)) + diff(f(x), x) + f(x) + 2, + 'sol': [Eq(f(x), (C1 + Piecewise( + (x, Eq(A + 1, 0)), ((-A*x + A - x - 1)*exp(x)/(A + 1), True)))*exp(-x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_liouville(): + n = Symbol('n') + _y = Dummy('y') + return { + 'hint': "Liouville", + 'func': f(x), + 'examples':{ + 'liouville_01': { + 'eq': diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2, + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))], + + }, + + 'liouville_02': { + 'eq': diff(x*exp(-f(x)), x, x), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))] + }, + + 'liouville_03': { + 'eq': ((diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2)*exp(-f(x))/exp(f(x))).expand(), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))] + }, + + 'liouville_04': { + 'eq': diff(f(x), x, x) + 1/f(x)*(diff(f(x), x))**2 + 1/x*diff(f(x), x), + 'sol': [Eq(f(x), -sqrt(C1 + C2*log(x))), Eq(f(x), sqrt(C1 + C2*log(x)))], + }, + + 'liouville_05': { + 'eq': x*diff(f(x), x, x) + x/f(x)*diff(f(x), x)**2 + x*diff(f(x), x), + 'sol': [Eq(f(x), -sqrt(C1 + C2*exp(-x))), Eq(f(x), sqrt(C1 + C2*exp(-x)))], + }, + + 'liouville_06': { + 'eq': Eq((x*exp(f(x))).diff(x, x), 0), + 'sol': [Eq(f(x), log(C1 + C2/x))], + }, + + 'liouville_07': { + 'eq': (diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2)*exp(-f(x))/exp(f(x)), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))], + }, + + 'liouville_08': { + 'eq': x**2*diff(f(x),x) + (n*f(x) + f(x)**2)*diff(f(x),x)**2 + diff(f(x), (x, 2)), + 'sol': [Eq(C1 + C2*lowergamma(Rational(1,3), x**3/3) + NonElementaryIntegral(exp(_y**3/3)*exp(_y**2*n/2), (_y, f(x))), 0)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_algebraic(): + M, m, r, t = symbols('M m r t') + phi = Function('phi') + k = Symbol('k') + # This one needs a substitution f' = g. + # 'algeb_12': { + # 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + # 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + # }, + return { + 'hint': "nth_algebraic", + 'func': f(x), + 'examples':{ + 'algeb_01': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1) * (f(x).diff(x) - x), + 'sol': [Eq(f(x), C1 + x**2/2), Eq(f(x), C1 + C2*x)] + }, + + 'algeb_02': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1), + 'sol': [Eq(f(x), C1 + C2*x)] + }, + + 'algeb_03': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x), + 'sol': [Eq(f(x), C1 + C2*x)] + }, + + 'algeb_04': { + 'eq': Eq(-M * phi(t).diff(t), + Rational(3, 2) * m * r**2 * phi(t).diff(t) * phi(t).diff(t,t)), + 'sol': [Eq(phi(t), C1), Eq(phi(t), C1 + C2*t - M*t**2/(3*m*r**2))], + 'func': phi(t) + }, + + 'algeb_05': { + 'eq': (1 - sin(f(x))) * f(x).diff(x), + 'sol': [Eq(f(x), C1)], + 'XFAIL': ['separable'] #It raised exception. + }, + + 'algeb_06': { + 'eq': (diff(f(x)) - x)*(diff(f(x)) + x), + 'sol': [Eq(f(x), C1 - x**2/2), Eq(f(x), C1 + x**2/2)] + }, + + 'algeb_07': { + 'eq': Eq(Derivative(f(x), x), Derivative(g(x), x)), + 'sol': [Eq(f(x), C1 + g(x))], + }, + + 'algeb_08': { + 'eq': f(x).diff(x) - C1, #this example is from issue 15999 + 'sol': [Eq(f(x), C1*x + C2)], + }, + + 'algeb_09': { + 'eq': f(x)*f(x).diff(x), + 'sol': [Eq(f(x), C1)], + }, + + 'algeb_10': { + 'eq': (diff(f(x)) - x)*(diff(f(x)) + x), + 'sol': [Eq(f(x), C1 - x**2/2), Eq(f(x), C1 + x**2/2)], + }, + + 'algeb_11': { + 'eq': f(x) + f(x)*f(x).diff(x), + 'sol': [Eq(f(x), 0), Eq(f(x), C1 - x)], + 'XFAIL': ['separable', '1st_exact', '1st_linear', 'Bernoulli', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', '1st_homogeneous_coeff_subs_dep_div_indep', + 'lie_group', 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters'] + #nth_linear_constant_coeff_undetermined_coefficients raises exception rest all of them misses a solution. + }, + + 'algeb_12': { + 'eq': Derivative(x*f(x), x, x, x), + 'sol': [Eq(f(x), (C1 + C2*x + C3*x**2) / x)], + 'XFAIL': ['nth_algebraic'] # It passes only when prep=False is set in dsolve. + }, + + 'algeb_13': { + 'eq': Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x)), + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + 'XFAIL': ['nth_algebraic'] # It passes only when prep=False is set in dsolve. + }, + + # These are simple tests from the old ode module example 14-18 + 'algeb_14': { + 'eq': Eq(f(x).diff(x), 0), + 'sol': [Eq(f(x), C1)], + }, + + 'algeb_15': { + 'eq': Eq(3*f(x).diff(x) - 5, 0), + 'sol': [Eq(f(x), C1 + x*Rational(5, 3))], + }, + + 'algeb_16': { + 'eq': Eq(3*f(x).diff(x), 5), + 'sol': [Eq(f(x), C1 + x*Rational(5, 3))], + }, + + # Type: 2nd order, constant coefficients (two complex roots) + 'algeb_17': { + 'eq': Eq(3*f(x).diff(x) - 1, 0), + 'sol': [Eq(f(x), C1 + x/3)], + }, + + 'algeb_18': { + 'eq': Eq(x*f(x).diff(x) - 1, 0), + 'sol': [Eq(f(x), C1 + log(x))], + }, + + # https://github.com/sympy/sympy/issues/6989 + 'algeb_19': { + 'eq': f(x).diff(x) - x*exp(-k*x), + 'sol': [Eq(f(x), C1 + Piecewise(((-k*x - 1)*exp(-k*x)/k**2, Ne(k**2, 0)),(x**2/2, True)))], + }, + + 'algeb_20': { + 'eq': -f(x).diff(x) + x*exp(-k*x), + 'sol': [Eq(f(x), C1 + Piecewise(((-k*x - 1)*exp(-k*x)/k**2, Ne(k**2, 0)),(x**2/2, True)))], + }, + + # https://github.com/sympy/sympy/issues/10867 + 'algeb_21': { + 'eq': Eq(g(x).diff(x).diff(x), (x-2)**2 + (x-3)**3), + 'sol': [Eq(g(x), C1 + C2*x + x**5/20 - 2*x**4/3 + 23*x**3/6 - 23*x**2/2)], + 'func': g(x), + }, + + # https://github.com/sympy/sympy/issues/13691 + 'algeb_22': { + 'eq': f(x).diff(x) - C1*g(x).diff(x), + 'sol': [Eq(f(x), C2 + C1*g(x))], + 'func': f(x), + }, + + # https://github.com/sympy/sympy/issues/4838 + 'algeb_23': { + 'eq': f(x).diff(x) - 3*C1 - 3*x**2, + 'sol': [Eq(f(x), C2 + 3*C1*x + x**3)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_order_reducible(): + return { + 'hint': "nth_order_reducible", + 'func': f(x), + 'examples':{ + 'reducible_01': { + 'eq': Eq(x*Derivative(f(x), x)**2 + Derivative(f(x), x, 2), 0), + 'sol': [Eq(f(x),C1 - sqrt(-1/C2)*log(-C2*sqrt(-1/C2) + x) + + sqrt(-1/C2)*log(C2*sqrt(-1/C2) + x))], + 'slow': True, + }, + + 'reducible_02': { + 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + 'slow': True, + }, + + 'reducible_03': { + 'eq': Eq(sqrt(2) * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(2**Rational(3, 4)*x/2) + C3*cos(2**Rational(3, 4)*x/2))], + 'slow': True, + }, + + 'reducible_04': { + 'eq': f(x).diff(x, 2) + 2*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x))], + }, + + 'reducible_05': { + 'eq': f(x).diff(x, 3) + f(x).diff(x, 2) - 6*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))], + 'slow': True, + }, + + 'reducible_06': { + 'eq': f(x).diff(x, 4) - f(x).diff(x, 3) - 4*f(x).diff(x, 2) + \ + 4*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x) + C3*exp(x) + C4*exp(2*x))], + 'slow': True, + }, + + 'reducible_07': { + 'eq': f(x).diff(x, 4) + 3*f(x).diff(x, 3), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*exp(-3*x))], + 'slow': True, + }, + + 'reducible_08': { + 'eq': f(x).diff(x, 4) - 2*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*exp(-sqrt(2)*x) + C4*exp(sqrt(2)*x))], + 'slow': True, + }, + + 'reducible_09': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*sin(2*x) + C4*cos(2*x))], + 'slow': True, + }, + + 'reducible_10': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*x*sin(x) + C2*cos(x) - C3*x*cos(x) + C3*sin(x) + C4*sin(x) + C5*cos(x))], + 'slow': True, + }, + + 'reducible_11': { + 'eq': f(x).diff(x, 2) - f(x).diff(x)**3, + 'sol': [Eq(f(x), C1 - sqrt(2)*sqrt(-1/(C2 + x))*(C2 + x)), + Eq(f(x), C1 + sqrt(2)*sqrt(-1/(C2 + x))*(C2 + x))], + 'slow': True, + }, + + # Needs to be a way to know how to combine derivatives in the expression + 'reducible_12': { + 'eq': Derivative(x*f(x), x, x, x) + Derivative(f(x), x, x, x), + 'sol': [Eq(f(x), C1 + C3/Mul(2, (x**2 + 2*x + 1), evaluate=False) + + x*(C2 + C3/Mul(2, (x**2 + 2*x + 1), evaluate=False)))], # 2-arg Mul! + 'slow': True, + }, + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_undetermined_coefficients(): + # examples 3-27 below are from Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 231 + g = exp(-x) + f2 = f(x).diff(x, 2) + c = 3*f(x).diff(x, 3) + 5*f2 + f(x).diff(x) - f(x) - x + t = symbols("t") + u = symbols("u",cls=Function) + R, L, C, E_0, alpha = symbols("R L C E_0 alpha",positive=True) + omega = Symbol('omega') + return { + 'hint': "nth_linear_constant_coeff_undetermined_coefficients", + 'func': f(x), + 'examples':{ + 'undet_01': { + 'eq': c - x*g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x**2/24 - 3*x/32))*exp(-x) - 1)], + 'slow': True, + }, + + 'undet_02': { + 'eq': c - g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x/8))*exp(-x) - 1)], + 'slow': True, + }, + + 'undet_03': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 4, + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2)], + 'slow': True, + }, + + 'undet_04': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 12*exp(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2*exp(x))], + 'slow': True, + }, + + 'undet_05': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - exp(I*x), + 'sol': [Eq(f(x), (S(3)/10 + I/10)*(C1*exp(-2*x) + C2*exp(-x) - I*exp(I*x)))], + 'slow': True, + }, + + 'undet_06': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - sin(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + sin(x)/10 - 3*cos(x)/10)], + 'slow': True, + }, + + 'undet_07': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - cos(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 3*sin(x)/10 + cos(x)/10)], + 'slow': True, + }, + + 'undet_08': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - (8 + 6*exp(x) + 2*sin(x)), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + exp(x) + sin(x)/5 - 3*cos(x)/5 + 4)], + 'slow': True, + }, + + 'undet_09': { + 'eq': f2 + f(x).diff(x) + f(x) - x**2, + 'sol': [Eq(f(x), -2*x + x**2 + (C1*sin(x*sqrt(3)/2) + C2*cos(x*sqrt(3)/2))*exp(-x/2))], + 'slow': True, + }, + + 'undet_10': { + 'eq': f2 - 2*f(x).diff(x) - 8*f(x) - 9*x*exp(x) - 10*exp(-x), + 'sol': [Eq(f(x), -x*exp(x) - 2*exp(-x) + C1*exp(-2*x) + C2*exp(4*x))], + 'slow': True, + }, + + 'undet_11': { + 'eq': f2 - 3*f(x).diff(x) - 2*exp(2*x)*sin(x), + 'sol': [Eq(f(x), C1 + C2*exp(3*x) - 3*exp(2*x)*sin(x)/5 - exp(2*x)*cos(x)/5)], + 'slow': True, + }, + + 'undet_12': { + 'eq': f(x).diff(x, 4) - 2*f2 + f(x) - x + sin(x), + 'sol': [Eq(f(x), x - sin(x)/4 + (C1 + C2*x)*exp(-x) + (C3 + C4*x)*exp(x))], + 'slow': True, + }, + + 'undet_13': { + 'eq': f2 + f(x).diff(x) - x**2 - 2*x, + 'sol': [Eq(f(x), C1 + x**3/3 + C2*exp(-x))], + 'slow': True, + }, + + 'undet_14': { + 'eq': f2 + f(x).diff(x) - x - sin(2*x), + 'sol': [Eq(f(x), C1 - x - sin(2*x)/5 - cos(2*x)/10 + x**2/2 + C2*exp(-x))], + 'slow': True, + }, + + 'undet_15': { + 'eq': f2 + f(x) - 4*x*sin(x), + 'sol': [Eq(f(x), (C1 - x**2)*cos(x) + (C2 + x)*sin(x))], + 'slow': True, + }, + + 'undet_16': { + 'eq': f2 + 4*f(x) - x*sin(2*x), + 'sol': [Eq(f(x), (C1 - x**2/8)*cos(2*x) + (C2 + x/16)*sin(2*x))], + 'slow': True, + }, + + 'undet_17': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x**3/12))*exp(-x))], + 'slow': True, + }, + + 'undet_18': { + 'eq': f(x).diff(x, 3) + 3*f2 + 3*f(x).diff(x) + f(x) - 2*exp(-x) + \ + x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 - x**3/60 + x/3)))*exp(-x))], + 'slow': True, + }, + + 'undet_19': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - exp(-2*x) - x**2, + 'sol': [Eq(f(x), C2*exp(-x) + x**2/2 - x*Rational(3,2) + (C1 - x)*exp(-2*x) + Rational(7,4))], + 'slow': True, + }, + + 'undet_20': { + 'eq': f2 - 3*f(x).diff(x) + 2*f(x) - x*exp(-x), + 'sol': [Eq(f(x), C1*exp(x) + C2*exp(2*x) + (6*x + 5)*exp(-x)/36)], + 'slow': True, + }, + + 'undet_21': { + 'eq': f2 + f(x).diff(x) - 6*f(x) - x - exp(2*x), + 'sol': [Eq(f(x), Rational(-1, 36) - x/6 + C2*exp(-3*x) + (C1 + x/5)*exp(2*x))], + 'slow': True, + }, + + 'undet_22': { + 'eq': f2 + f(x) - sin(x) - exp(-x), + 'sol': [Eq(f(x), C2*sin(x) + (C1 - x/2)*cos(x) + exp(-x)/2)], + 'slow': True, + }, + + 'undet_23': { + 'eq': f(x).diff(x, 3) - 3*f2 + 3*f(x).diff(x) - f(x) - exp(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 + x/6)))*exp(x))], + 'slow': True, + }, + + 'undet_24': { + 'eq': f2 + f(x) - S.Half - cos(2*x)/2, + 'sol': [Eq(f(x), S.Half - cos(2*x)/6 + C1*sin(x) + C2*cos(x))], + 'slow': True, + }, + + 'undet_25': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - exp(2*x)*(S.Half - cos(2*x)/2), + 'sol': [Eq(f(x), C1 + C2*exp(-x) + C3*exp(x) + (-21*sin(2*x) + 27*cos(2*x) + 130)*exp(2*x)/1560)], + 'slow': True, + }, + + #Note: 'undet_26' is referred in 'undet_37' + 'undet_26': { + 'eq': (f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - + sin(x) - cos(x)), + 'sol': [Eq(f(x), C1 + x**2 + (C2 + x*(C3 - x/8))*sin(x) + (C4 + x*(C5 + x/8))*cos(x))], + 'slow': True, + }, + + 'undet_27': { + 'eq': f2 + f(x) - cos(x)/2 + cos(3*x)/2, + 'sol': [Eq(f(x), cos(3*x)/16 + C2*cos(x) + (C1 + x/4)*sin(x))], + 'slow': True, + }, + + 'undet_28': { + 'eq': f(x).diff(x) - 1, + 'sol': [Eq(f(x), C1 + x)], + 'slow': True, + }, + + # https://github.com/sympy/sympy/issues/19358 + 'undet_29': { + 'eq': f2 + f(x).diff(x) + exp(x-C1), + 'sol': [Eq(f(x), C2 + C3*exp(-x) - exp(-C1 + x)/2)], + 'slow': True, + }, + + # https://github.com/sympy/sympy/issues/18408 + 'undet_30': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - sinh(x), + 'sol': [Eq(f(x), C1 + C2*exp(-x) + C3*exp(x) + x*sinh(x)/2)], + }, + + 'undet_31': { + 'eq': f(x).diff(x, 2) - 49*f(x) - sinh(3*x), + 'sol': [Eq(f(x), C1*exp(-7*x) + C2*exp(7*x) - sinh(3*x)/40)], + }, + + 'undet_32': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - sinh(x) - exp(x), + 'sol': [Eq(f(x), C1 + C3*exp(-x) + x*sinh(x)/2 + (C2 + x/2)*exp(x))], + }, + + # https://github.com/sympy/sympy/issues/5096 + 'undet_33': { + 'eq': f(x).diff(x, x) + f(x) - x*sin(x - 2), + 'sol': [Eq(f(x), C1*sin(x) + C2*cos(x) - x**2*cos(x - 2)/4 + x*sin(x - 2)/4)], + }, + + 'undet_34': { + 'eq': f(x).diff(x, 2) + f(x) - x**4*sin(x-1), + 'sol': [ Eq(f(x), C1*sin(x) + C2*cos(x) - x**5*cos(x - 1)/10 + x**4*sin(x - 1)/4 + x**3*cos(x - 1)/2 - 3*x**2*sin(x - 1)/4 - 3*x*cos(x - 1)/4)], + }, + + 'undet_35': { + 'eq': f(x).diff(x, 2) - f(x) - exp(x - 1), + 'sol': [Eq(f(x), C2*exp(-x) + (C1 + x*exp(-1)/2)*exp(x))], + }, + + 'undet_36': { + 'eq': f(x).diff(x, 2)+f(x)-(sin(x-2)+1), + 'sol': [Eq(f(x), C1*sin(x) + C2*cos(x) - x*cos(x - 2)/2 + 1)], + }, + + # Equivalent to example_name 'undet_26'. + # This previously failed because the algorithm for undetermined coefficients + # didn't know to multiply exp(I*x) by sufficient x because it is linearly + # dependent on sin(x) and cos(x). + 'undet_37': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x), + 'sol': [Eq(f(x), C1 + x**2*(I*exp(I*x)/8 + 1) + (C2 + C3*x)*sin(x) + (C4 + C5*x)*cos(x))], + }, + + # https://github.com/sympy/sympy/issues/12623 + 'undet_38': { + 'eq': Eq( u(t).diff(t,t) + R /L*u(t).diff(t) + 1/(L*C)*u(t), alpha), + 'sol': [Eq(u(t), C*L*alpha + C2*exp(-t*(R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + + C1*exp(t*(-R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)))], + 'func': u(t) + }, + + 'undet_39': { + 'eq': Eq( L*C*u(t).diff(t,t) + R*C*u(t).diff(t) + u(t), E_0*exp(I*omega*t) ), + 'sol': [Eq(u(t), C2*exp(-t*(R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + + C1*exp(t*(-R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + - E_0*exp(I*omega*t)/(C*L*omega**2 - I*C*R*omega - 1))], + 'func': u(t), + }, + + # https://github.com/sympy/sympy/issues/6879 + 'undet_40': { + 'eq': Eq(Derivative(f(x), x, 2) - 2*Derivative(f(x), x) + f(x), sin(x)), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(x) + cos(x)/2)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_separable(): + # test_separable1-5 are from Ordinary Differential Equations, Tenenbaum and + # Pollard, pg. 55 + t,a = symbols('a,t') + m = 96 + g = 9.8 + k = .2 + f1 = g * m + v = Function('v') + return { + 'hint': "separable", + 'func': f(x), + 'examples':{ + 'separable_01': { + 'eq': f(x).diff(x) - f(x), + 'sol': [Eq(f(x), C1*exp(x))], + }, + + 'separable_02': { + 'eq': x*f(x).diff(x) - f(x), + 'sol': [Eq(f(x), C1*x)], + }, + + 'separable_03': { + 'eq': f(x).diff(x) + sin(x), + 'sol': [Eq(f(x), C1 + cos(x))], + }, + + 'separable_04': { + 'eq': f(x)**2 + 1 - (x**2 + 1)*f(x).diff(x), + 'sol': [Eq(f(x), tan(C1 + atan(x)))], + }, + + 'separable_05': { + 'eq': f(x).diff(x)/tan(x) - f(x) - 2, + 'sol': [Eq(f(x), C1/cos(x) - 2)], + }, + + 'separable_06': { + 'eq': f(x).diff(x) * (1 - sin(f(x))) - 1, + 'sol': [Eq(-x + f(x) + cos(f(x)), C1)], + }, + + 'separable_07': { + 'eq': f(x)*x**2*f(x).diff(x) - f(x)**3 - 2*x**2*f(x).diff(x), + 'sol': [Eq(f(x), (-x - sqrt(x*(4*C1*x + x - 4)))/(C1*x - 1)/2), + Eq(f(x), (-x + sqrt(x*(4*C1*x + x - 4)))/(C1*x - 1)/2)], + 'slow': True, + }, + + 'separable_08': { + 'eq': f(x)**2 - 1 - (2*f(x) + x*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -sqrt(C1*x**2 + 4*C1*x + 4*C1 + 1)), + Eq(f(x), sqrt(C1*x**2 + 4*C1*x + 4*C1 + 1))], + 'slow': True, + }, + + 'separable_09': { + 'eq': x*log(x)*f(x).diff(x) + sqrt(1 + f(x)**2), + 'sol': [Eq(f(x), sinh(C1 - log(log(x))))], #One more solution is f(x)=I + 'slow': True, + 'checkodesol_XFAIL': True, + }, + + 'separable_10': { + 'eq': exp(x + 1)*tan(f(x)) + cos(f(x))*f(x).diff(x), + 'sol': [Eq(E*exp(x) + log(cos(f(x)) - 1)/2 - log(cos(f(x)) + 1)/2 + cos(f(x)), C1)], + 'slow': True, + }, + + 'separable_11': { + 'eq': (x*cos(f(x)) + x**2*sin(f(x))*f(x).diff(x) - a**2*sin(f(x))*f(x).diff(x)), + 'sol': [ + Eq(f(x), -acos(C1*sqrt(-a**2 + x**2)) + 2*pi), + Eq(f(x), acos(C1*sqrt(-a**2 + x**2))) + ], + 'slow': True, + }, + + 'separable_12': { + 'eq': f(x).diff(x) - f(x)*tan(x), + 'sol': [Eq(f(x), C1/cos(x))], + }, + + 'separable_13': { + 'eq': (x - 1)*cos(f(x))*f(x).diff(x) - 2*x*sin(f(x)), + 'sol': [ + Eq(f(x), pi - asin(C1*(x**2 - 2*x + 1)*exp(2*x))), + Eq(f(x), asin(C1*(x**2 - 2*x + 1)*exp(2*x))) + ], + }, + + 'separable_14': { + 'eq': f(x).diff(x) - f(x)*log(f(x))/tan(x), + 'sol': [Eq(f(x), exp(C1*sin(x)))], + }, + + 'separable_15': { + 'eq': x*f(x).diff(x) + (1 + f(x)**2)*atan(f(x)), + 'sol': [Eq(f(x), tan(C1/x))], #Two more solutions are f(x)=0 and f(x)=I + 'slow': True, + 'checkodesol_XFAIL': True, + }, + + 'separable_16': { + 'eq': f(x).diff(x) + x*(f(x) + 1), + 'sol': [Eq(f(x), -1 + C1*exp(-x**2/2))], + }, + + 'separable_17': { + 'eq': exp(f(x)**2)*(x**2 + 2*x + 1) + (x*f(x) + f(x))*f(x).diff(x), + 'sol': [ + Eq(f(x), -sqrt(log(1/(C1 + x**2 + 2*x)))), + Eq(f(x), sqrt(log(1/(C1 + x**2 + 2*x)))) + ], + }, + + 'separable_18': { + 'eq': f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*exp(-x))], + }, + + 'separable_19': { + 'eq': sin(x)*cos(2*f(x)) + cos(x)*sin(2*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), pi - acos(C1/cos(x)**2)/2), Eq(f(x), acos(C1/cos(x)**2)/2)], + }, + + 'separable_20': { + 'eq': (1 - x)*f(x).diff(x) - x*(f(x) + 1), + 'sol': [Eq(f(x), (C1*exp(-x) - x + 1)/(x - 1))], + }, + + 'separable_21': { + 'eq': f(x)*diff(f(x), x) + x - 3*x*f(x)**2, + 'sol': [Eq(f(x), -sqrt(3)*sqrt(C1*exp(3*x**2) + 1)/3), + Eq(f(x), sqrt(3)*sqrt(C1*exp(3*x**2) + 1)/3)], + }, + + 'separable_22': { + 'eq': f(x).diff(x) - exp(x + f(x)), + 'sol': [Eq(f(x), log(-1/(C1 + exp(x))))], + 'XFAIL': ['lie_group'] #It shows 'NoneType' object is not subscriptable for lie_group. + }, + + # https://github.com/sympy/sympy/issues/7081 + 'separable_23': { + 'eq': x*(f(x).diff(x)) + 1 - f(x)**2, + 'sol': [Eq(f(x), (-C1 - x**2)/(-C1 + x**2))], + }, + + # https://github.com/sympy/sympy/issues/10379 + 'separable_24': { + 'eq': f(t).diff(t)-(1-51.05*y*f(t)), + 'sol': [Eq(f(t), (0.019588638589618023*exp(y*(C1 - 51.049999999999997*t)) + 0.019588638589618023)/y)], + 'func': f(t), + }, + + # https://github.com/sympy/sympy/issues/15999 + 'separable_25': { + 'eq': f(x).diff(x) - C1*f(x), + 'sol': [Eq(f(x), C2*exp(C1*x))], + }, + + 'separable_26': { + 'eq': f1 - k * (v(t) ** 2) - m * Derivative(v(t)), + 'sol': [Eq(v(t), -68.585712797928991/tanh(C1 - 0.14288690166235204*t))], + 'func': v(t), + 'checkodesol_XFAIL': True, + }, + + #https://github.com/sympy/sympy/issues/22155 + 'separable_27': { + 'eq': f(x).diff(x) - exp(f(x) - x), + 'sol': [Eq(f(x), log(-exp(x)/(C1*exp(x) - 1)))], + } + } + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_exact(): + # Type: Exact differential equation, p(x,f) + q(x,f)*f' == 0, + # where dp/df == dq/dx + ''' + Example 7 is an exact equation that fails under the exact engine. It is caught + by first order homogeneous albeit with a much contorted solution. The + exact engine fails because of a poorly simplified integral of q(0,y)dy, + where q is the function multiplying f'. The solutions should be + Eq(sqrt(x**2+f(x)**2)**3+y**3, C1). The equation below is + equivalent, but it is so complex that checkodesol fails, and takes a long + time to do so. + ''' + return { + 'hint': "1st_exact", + 'func': f(x), + 'examples':{ + '1st_exact_01': { + 'eq': sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))], + 'slow': True, + }, + + '1st_exact_02': { + 'eq': (2*x*f(x) + 1)/f(x) + (f(x) - x)/f(x)**2*f(x).diff(x), + 'sol': [Eq(f(x), exp(C1 - x**2 + LambertW(-x*exp(-C1 + x**2))))], + 'XFAIL': ['lie_group'], #It shows dsolve raises an exception: List index out of range for lie_group + 'slow': True, + 'checkodesol_XFAIL':True + }, + + '1st_exact_03': { + 'eq': 2*x + f(x)*cos(x) + (2*f(x) + sin(x) - sin(f(x)))*f(x).diff(x), + 'sol': [Eq(f(x)*sin(x) + cos(f(x)) + x**2 + f(x)**2, C1)], + 'XFAIL': ['lie_group'], #It goes into infinite loop for lie_group. + 'slow': True, + }, + + '1st_exact_04': { + 'eq': cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + 'sol': [Eq(x*cos(f(x)) + f(x)**3/3, C1)], + 'slow': True, + }, + + '1st_exact_05': { + 'eq': 2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), + 'sol': [Eq(x**2*f(x) + f(x)**3/3, C1)], + 'slow': True, + 'simplify_flag':False + }, + + # This was from issue: https://github.com/sympy/sympy/issues/11290 + '1st_exact_06': { + 'eq': cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + 'sol': [Eq(x*cos(f(x)) + f(x)**3/3, C1)], + 'simplify_flag':False + }, + + '1st_exact_07': { + 'eq': x*sqrt(x**2 + f(x)**2) - (x**2*f(x)/(f(x) - sqrt(x**2 + f(x)**2)))*f(x).diff(x), + 'sol': [Eq(log(x), + C1 - 9*sqrt(1 + f(x)**2/x**2)*asinh(f(x)/x)/(-27*f(x)/x + + 27*sqrt(1 + f(x)**2/x**2)) - 9*sqrt(1 + f(x)**2/x**2)* + log(1 - sqrt(1 + f(x)**2/x**2)*f(x)/x + 2*f(x)**2/x**2)/ + (-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2)) + + 9*asinh(f(x)/x)*f(x)/(x*(-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2))) + + 9*f(x)*log(1 - sqrt(1 + f(x)**2/x**2)*f(x)/x + 2*f(x)**2/x**2)/ + (x*(-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2))))], + 'slow': True, + 'dsolve_too_slow':True + }, + + # Type: a(x)f'(x)+b(x)*f(x)+c(x)=0 + '1st_exact_08': { + 'eq': Eq(x**2*f(x).diff(x) + 3*x*f(x) - sin(x)/x, 0), + 'sol': [Eq(f(x), (C1 - cos(x))/x**3)], + }, + + # these examples are from test_exact_enhancement + '1st_exact_09': { + 'eq': f(x)/x**2 + ((f(x)*x - 1)/x)*f(x).diff(x), + 'sol': [Eq(f(x), (i*sqrt(C1*x**2 + 1) + 1)/x) for i in (-1, 1)], + }, + + '1st_exact_10': { + 'eq': (x*f(x) - 1) + f(x).diff(x)*(x**2 - x*f(x)), + 'sol': [Eq(f(x), x - sqrt(C1 + x**2 - 2*log(x))), Eq(f(x), x + sqrt(C1 + x**2 - 2*log(x)))], + }, + + '1st_exact_11': { + 'eq': (x + 2)*sin(f(x)) + f(x).diff(x)*x*cos(f(x)), + 'sol': [Eq(f(x), -asin(C1*exp(-x)/x**2) + pi), Eq(f(x), asin(C1*exp(-x)/x**2))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_var_of_parameters(): + g = exp(-x) + f2 = f(x).diff(x, 2) + c = 3*f(x).diff(x, 3) + 5*f2 + f(x).diff(x) - f(x) - x + return { + 'hint': "nth_linear_constant_coeff_variation_of_parameters", + 'func': f(x), + 'examples':{ + 'var_of_parameters_01': { + 'eq': c - x*g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x**2/24 - 3*x/32))*exp(-x) - 1)], + 'slow': True, + }, + + 'var_of_parameters_02': { + 'eq': c - g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x/8))*exp(-x) - 1)], + 'slow': True, + }, + + 'var_of_parameters_03': { + 'eq': f(x).diff(x) - 1, + 'sol': [Eq(f(x), C1 + x)], + 'slow': True, + }, + + 'var_of_parameters_04': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 4, + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2)], + 'slow': True, + }, + + 'var_of_parameters_05': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 12*exp(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2*exp(x))], + 'slow': True, + }, + + 'var_of_parameters_06': { + 'eq': f2 - 2*f(x).diff(x) - 8*f(x) - 9*x*exp(x) - 10*exp(-x), + 'sol': [Eq(f(x), -x*exp(x) - 2*exp(-x) + C1*exp(-2*x) + C2*exp(4*x))], + 'slow': True, + }, + + 'var_of_parameters_07': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x**3/12))*exp(-x))], + 'slow': True, + }, + + 'var_of_parameters_08': { + 'eq': f2 - 3*f(x).diff(x) + 2*f(x) - x*exp(-x), + 'sol': [Eq(f(x), C1*exp(x) + C2*exp(2*x) + (6*x + 5)*exp(-x)/36)], + 'slow': True, + }, + + 'var_of_parameters_09': { + 'eq': f(x).diff(x, 3) - 3*f2 + 3*f(x).diff(x) - f(x) - exp(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 + x/6)))*exp(x))], + 'slow': True, + }, + + 'var_of_parameters_10': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - exp(-x)/x, + 'sol': [Eq(f(x), (C1 + x*(C2 + log(x)))*exp(-x))], + 'slow': True, + }, + + 'var_of_parameters_11': { + 'eq': f2 + f(x) - 1/sin(x)*1/cos(x), + 'sol': [Eq(f(x), (C1 + log(sin(x) - 1)/2 - log(sin(x) + 1)/2 + )*cos(x) + (C2 + log(cos(x) - 1)/2 - log(cos(x) + 1)/2)*sin(x))], + 'slow': True, + }, + + 'var_of_parameters_12': { + 'eq': f(x).diff(x, 4) - 1/x, + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + x**3*(C4 + log(x)/6))], + 'slow': True, + }, + + # These were from issue: https://github.com/sympy/sympy/issues/15996 + 'var_of_parameters_13': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x), + 'sol': [Eq(f(x), C1 + x**2 + (C2 + x*(C3 - x/8 + 3*exp(I*x)/2 + 3*exp(-I*x)/2) + 5*exp(2*I*x)/16 + 2*I*exp(I*x) - 2*I*exp(-I*x))*sin(x) + (C4 + x*(C5 + I*x/8 + 3*I*exp(I*x)/2 - 3*I*exp(-I*x)/2) + + 5*I*exp(2*I*x)/16 - 2*exp(I*x) - 2*exp(-I*x))*cos(x) - I*exp(I*x))], + }, + + 'var_of_parameters_14': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - exp(I*x), + 'sol': [Eq(f(x), C1 + (C2 + x*(C3 - x/8) + 5*exp(2*I*x)/16)*sin(x) + (C4 + x*(C5 + I*x/8) + 5*I*exp(2*I*x)/16)*cos(x) - I*exp(I*x))], + }, + + # https://github.com/sympy/sympy/issues/14395 + 'var_of_parameters_15': { + 'eq': Derivative(f(x), x, x) + 9*f(x) - sec(x), + 'sol': [Eq(f(x), (C1 - x/3 + sin(2*x)/3)*sin(3*x) + (C2 + log(cos(x)) + - 2*log(cos(x)**2)/3 + 2*cos(x)**2/3)*cos(3*x))], + 'slow': True, + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_linear_bessel(): + return { + 'hint': "2nd_linear_bessel", + 'func': f(x), + 'examples':{ + '2nd_lin_bessel_01': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - 4)*f(x), + 'sol': [Eq(f(x), C1*besselj(2, x) + C2*bessely(2, x))], + }, + + '2nd_lin_bessel_02': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 +25)*f(x), + 'sol': [Eq(f(x), C1*besselj(5*I, x) + C2*bessely(5*I, x))], + }, + + '2nd_lin_bessel_03': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2)*f(x), + 'sol': [Eq(f(x), C1*besselj(0, x) + C2*bessely(0, x))], + }, + + '2nd_lin_bessel_04': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (81*x**2 -S(1)/9)*f(x), + 'sol': [Eq(f(x), C1*besselj(S(1)/3, 9*x) + C2*bessely(S(1)/3, 9*x))], + }, + + '2nd_lin_bessel_05': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**4 - 4)*f(x), + 'sol': [Eq(f(x), C1*besselj(1, x**2/2) + C2*bessely(1, x**2/2))], + }, + + '2nd_lin_bessel_06': { + 'eq': x**2*(f(x).diff(x, 2)) + 2*x*(f(x).diff(x)) + (x**4 - 4)*f(x), + 'sol': [Eq(f(x), (C1*besselj(sqrt(17)/4, x**2/2) + C2*bessely(sqrt(17)/4, x**2/2))/sqrt(x))], + }, + + '2nd_lin_bessel_07': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - S(1)/4)*f(x), + 'sol': [Eq(f(x), C1*besselj(S(1)/2, x) + C2*bessely(S(1)/2, x))], + }, + + '2nd_lin_bessel_08': { + 'eq': x**2*(f(x).diff(x, 2)) - 3*x*(f(x).diff(x)) + (4*x + 4)*f(x), + 'sol': [Eq(f(x), x**2*(C1*besselj(0, 4*sqrt(x)) + C2*bessely(0, 4*sqrt(x))))], + }, + + '2nd_lin_bessel_09': { + 'eq': x*(f(x).diff(x, 2)) - f(x).diff(x) + 4*x**3*f(x), + 'sol': [Eq(f(x), x*(C1*besselj(S(1)/2, x**2) + C2*bessely(S(1)/2, x**2)))], + }, + + '2nd_lin_bessel_10': { + 'eq': (x-2)**2*(f(x).diff(x, 2)) - (x-2)*f(x).diff(x) + 4*(x-2)**2*f(x), + 'sol': [Eq(f(x), (x - 2)*(C1*besselj(1, 2*x - 4) + C2*bessely(1, 2*x - 4)))], + }, + + # https://github.com/sympy/sympy/issues/4414 + '2nd_lin_bessel_11': { + 'eq': f(x).diff(x, x) + 2/x*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*besselj(S(1)/2, x) + C2*bessely(S(1)/2, x))/sqrt(x))], + }, + '2nd_lin_bessel_12': { + 'eq': x**2*f(x).diff(x, 2) + x*f(x).diff(x) + (a**2*x**2/c**2 - b**2)*f(x), + 'sol': [Eq(f(x), C1*besselj(sqrt(b**2), x*sqrt(a**2/c**2)) + C2*bessely(sqrt(b**2), x*sqrt(a**2/c**2)))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_2F1_hypergeometric(): + return { + 'hint': "2nd_hypergeometric", + 'func': f(x), + 'examples':{ + '2nd_2F1_hyper_01': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(3)/2 -2*x)*f(x).diff(x) + 2*f(x), + 'sol': [Eq(f(x), C1*x**(S(5)/2)*hyper((S(3)/2, S(1)/2), (S(7)/2,), x) + C2*hyper((-1, -2), (-S(3)/2,), x))], + }, + + '2nd_2F1_hyper_02': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(7)/2*x)*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*(1 - x)**(S(5)/2)*hyper((S(1)/2, 2), (S(7)/2,), 1 - x) + + C2*hyper((-S(1)/2, -2), (-S(3)/2,), 1 - x))/(x - 1)**(S(5)/2))], + }, + + '2nd_2F1_hyper_03': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(3)+ S(7)/2*x)*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*(1 - x)**(S(11)/2)*hyper((S(1)/2, 2), (S(13)/2,), 1 - x) + + C2*hyper((-S(7)/2, -5), (-S(9)/2,), 1 - x))/(x - 1)**(S(11)/2))], + }, + + '2nd_2F1_hyper_04': { + 'eq': -x**(S(5)/7)*(-416*x**(S(9)/7)/9 - 2385*x**(S(5)/7)/49 + S(298)*x/3)*f(x)/(196*(-x**(S(6)/7) + + x)**2*(x**(S(6)/7) + x)**2) + Derivative(f(x), (x, 2)), + 'sol': [Eq(f(x), x**(S(45)/98)*(C1*x**(S(4)/49)*hyper((S(1)/3, -S(1)/2), (S(9)/7,), x**(S(2)/7)) + + C2*hyper((S(1)/21, -S(11)/14), (S(5)/7,), x**(S(2)/7)))/(x**(S(2)/7) - 1)**(S(19)/84))], + 'checkodesol_XFAIL':True, + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_2nd_nonlinear_autonomous_conserved(): + return { + 'hint': "2nd_nonlinear_autonomous_conserved", + 'func': f(x), + 'examples': { + '2nd_nonlinear_autonomous_conserved_01': { + 'eq': f(x).diff(x, 2) + exp(f(x)) + log(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*_u*log(_u) + 2*_u - 2*exp(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*_u*log(_u) + 2*_u - 2*exp(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_02': { + 'eq': f(x).diff(x, 2) + cbrt(f(x)) + 1/f(x), + 'sol': [ + Eq(sqrt(2)*Integral(1/sqrt(2*C1 - 3*_u**Rational(4, 3) - 4*log(_u)), (_u, f(x))), C2 + x), + Eq(sqrt(2)*Integral(1/sqrt(2*C1 - 3*_u**Rational(4, 3) - 4*log(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_03': { + 'eq': f(x).diff(x, 2) + sin(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 + 2*cos(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 + 2*cos(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_04': { + 'eq': f(x).diff(x, 2) + cosh(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*sinh(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*sinh(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_05': { + 'eq': f(x).diff(x, 2) + asin(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*_u*asin(_u) - 2*sqrt(1 - _u**2)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*_u*asin(_u) - 2*sqrt(1 - _u**2)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + 'XFAIL': ['2nd_nonlinear_autonomous_conserved_Integral'] + } + } + } + + +@_add_example_keys +def _get_examples_ode_sol_separable_reduced(): + df = f(x).diff(x) + return { + 'hint': "separable_reduced", + 'func': f(x), + 'examples':{ + 'separable_reduced_01': { + 'eq': x* df + f(x)* (1 / (x**2*f(x) - 1)), + 'sol': [Eq(log(x**2*f(x))/3 + log(x**2*f(x) - Rational(3, 2))/6, C1 + log(x))], + 'simplify_flag': False, + 'XFAIL': ['lie_group'], #It hangs. + }, + + #Note: 'separable_reduced_02' is referred in 'separable_reduced_11' + 'separable_reduced_02': { + 'eq': f(x).diff(x) + (f(x) / (x**4*f(x) - x)), + 'sol': [Eq(log(x**3*f(x))/4 + log(x**3*f(x) - Rational(4,3))/12, C1 + log(x))], + 'simplify_flag': False, + 'checkodesol_XFAIL':True, #It hangs for this. + }, + + 'separable_reduced_03': { + 'eq': x*df + f(x)*(x**2*f(x)), + 'sol': [Eq(log(x**2*f(x))/2 - log(x**2*f(x) - 2)/2, C1 + log(x))], + 'simplify_flag': False, + }, + + 'separable_reduced_04': { + 'eq': Eq(f(x).diff(x) + f(x)/x * (1 + (x**(S(2)/3)*f(x))**2), 0), + 'sol': [Eq(-3*log(x**(S(2)/3)*f(x)) + 3*log(3*x**(S(4)/3)*f(x)**2 + 1)/2, C1 + log(x))], + 'simplify_flag': False, + }, + + 'separable_reduced_05': { + 'eq': Eq(f(x).diff(x) + f(x)/x * (1 + (x*f(x))**2), 0), + 'sol': [Eq(f(x), -sqrt(2)*sqrt(1/(C1 + log(x)))/(2*x)),\ + Eq(f(x), sqrt(2)*sqrt(1/(C1 + log(x)))/(2*x))], + }, + + 'separable_reduced_06': { + 'eq': Eq(f(x).diff(x) + (x**4*f(x)**2 + x**2*f(x))*f(x)/(x*(x**6*f(x)**3 + x**4*f(x)**2)), 0), + 'sol': [Eq(f(x), C1 + 1/(2*x**2))], + }, + + 'separable_reduced_07': { + 'eq': Eq(f(x).diff(x) + (f(x)**2)*f(x)/(x), 0), + 'sol': [ + Eq(f(x), -sqrt(2)*sqrt(1/(C1 + log(x)))/2), + Eq(f(x), sqrt(2)*sqrt(1/(C1 + log(x)))/2) + ], + }, + + 'separable_reduced_08': { + 'eq': Eq(f(x).diff(x) + (f(x)+3)*f(x)/(x*(f(x)+2)), 0), + 'sol': [Eq(-log(f(x) + 3)/3 - 2*log(f(x))/3, C1 + log(x))], + 'simplify_flag': False, + 'XFAIL': ['lie_group'], #It hangs. + }, + + 'separable_reduced_09': { + 'eq': Eq(f(x).diff(x) + (f(x)+3)*f(x)/x, 0), + 'sol': [Eq(f(x), 3/(C1*x**3 - 1))], + }, + + 'separable_reduced_10': { + 'eq': Eq(f(x).diff(x) + (f(x)**2+f(x))*f(x)/(x), 0), + 'sol': [Eq(- log(x) - log(f(x) + 1) + log(f(x)) + 1/f(x), C1)], + 'XFAIL': ['lie_group'],#No algorithms are implemented to solve equation -C1 + x*(_y + 1)*exp(-1/_y)/_y + + }, + + # Equivalent to example_name 'separable_reduced_02'. Only difference is testing with simplify=True + 'separable_reduced_11': { + 'eq': f(x).diff(x) + (f(x) / (x**4*f(x) - x)), + 'sol': [Eq(f(x), -sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 +- sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 +- 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), -sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 ++ sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 +- 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 +- sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 4/x**6 + 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 ++ sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) ++ x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 + 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) +- exp(12*C1)/x**6)**Rational(1,3) - 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3))], + 'checkodesol_XFAIL':True, #It hangs for this. + 'slow': True, + }, + + #These were from issue: https://github.com/sympy/sympy/issues/6247 + 'separable_reduced_12': { + 'eq': x**2*f(x)**2 + x*Derivative(f(x), x), + 'sol': [Eq(f(x), 2*C1/(C1*x**2 - 1))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_lie_group(): + a, b, c = symbols("a b c") + return { + 'hint': "lie_group", + 'func': f(x), + 'examples':{ + #Example 1-4 and 19-20 were from issue: https://github.com/sympy/sympy/issues/17322 + 'lie_group_01': { + 'eq': x*f(x).diff(x)*(f(x)+4) + (f(x)**2) -2*f(x)-2*x, + 'sol': [], + 'dsolve_too_slow': True, + 'checkodesol_too_slow': True, + }, + + 'lie_group_02': { + 'eq': x*f(x).diff(x)*(f(x)+4) + (f(x)**2) -2*f(x)-2*x, + 'sol': [], + 'dsolve_too_slow': True, + }, + + 'lie_group_03': { + 'eq': Eq(x**7*Derivative(f(x), x) + 5*x**3*f(x)**2 - (2*x**2 + 2)*f(x)**3, 0), + 'sol': [], + 'dsolve_too_slow': True, + }, + + 'lie_group_04': { + 'eq': f(x).diff(x) - (f(x) - x*log(x))**2/x**2 + log(x), + 'sol': [], + 'XFAIL': ['lie_group'], + }, + + 'lie_group_05': { + 'eq': f(x).diff(x)**2, + 'sol': [Eq(f(x), C1)], + 'XFAIL': ['factorable'], #It raises Not Implemented error + }, + + 'lie_group_06': { + 'eq': Eq(f(x).diff(x), x**2*f(x)), + 'sol': [Eq(f(x), C1*exp(x**3)**Rational(1, 3))], + }, + + 'lie_group_07': { + 'eq': f(x).diff(x) + a*f(x) - c*exp(b*x), + 'sol': [Eq(f(x), Piecewise(((-C1*(a + b) + c*exp(x*(a + b)))*exp(-a*x)/(a + b),\ + Ne(a, -b)), ((-C1 + c*x)*exp(-a*x), True)))], + }, + + 'lie_group_08': { + 'eq': f(x).diff(x) + 2*x*f(x) - x*exp(-x**2), + 'sol': [Eq(f(x), (C1 + x**2/2)*exp(-x**2))], + }, + + 'lie_group_09': { + 'eq': (1 + 2*x)*(f(x).diff(x)) + 2 - 4*exp(-f(x)), + 'sol': [Eq(f(x), log(C1/(2*x + 1) + 2))], + }, + + 'lie_group_10': { + 'eq': x**2*(f(x).diff(x)) - f(x) + x**2*exp(x - (1/x)), + 'sol': [Eq(f(x), (C1 - exp(x))*exp(-1/x))], + 'XFAIL': ['factorable'], #It raises Recursion Error (maixmum depth exceeded) + }, + + 'lie_group_11': { + 'eq': x**2*f(x)**2 + x*Derivative(f(x), x), + 'sol': [Eq(f(x), 2/(C1 + x**2))], + }, + + 'lie_group_12': { + 'eq': diff(f(x),x) + 2*x*f(x) - x*exp(-x**2), + 'sol': [Eq(f(x), exp(-x**2)*(C1 + x**2/2))], + }, + + 'lie_group_13': { + 'eq': diff(f(x),x) + f(x)*cos(x) - exp(2*x), + 'sol': [Eq(f(x), exp(-sin(x))*(C1 + Integral(exp(2*x)*exp(sin(x)), x)))], + }, + + 'lie_group_14': { + 'eq': diff(f(x),x) + f(x)*cos(x) - sin(2*x)/2, + 'sol': [Eq(f(x), C1*exp(-sin(x)) + sin(x) - 1)], + }, + + 'lie_group_15': { + 'eq': x*diff(f(x),x) + f(x) - x*sin(x), + 'sol': [Eq(f(x), (C1 - x*cos(x) + sin(x))/x)], + }, + + 'lie_group_16': { + 'eq': x*diff(f(x),x) - f(x) - x/log(x), + 'sol': [Eq(f(x), x*(C1 + log(log(x))))], + }, + + 'lie_group_17': { + 'eq': (f(x).diff(x)-f(x)) * (f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1*exp(-x))], + }, + + 'lie_group_18': { + 'eq': f(x).diff(x) * (f(x).diff(x) - f(x)), + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1)], + }, + + 'lie_group_19': { + 'eq': (f(x).diff(x)-f(x)) * (f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1*exp(-x)), Eq(f(x), C1*exp(x))], + }, + + 'lie_group_20': { + 'eq': f(x).diff(x)*(f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1), Eq(f(x), C1*exp(-x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_linear_airy(): + return { + 'hint': "2nd_linear_airy", + 'func': f(x), + 'examples':{ + '2nd_lin_airy_01': { + 'eq': f(x).diff(x, 2) - x*f(x), + 'sol': [Eq(f(x), C1*airyai(x) + C2*airybi(x))], + }, + + '2nd_lin_airy_02': { + 'eq': f(x).diff(x, 2) + 2*x*f(x), + 'sol': [Eq(f(x), C1*airyai(-2**(S(1)/3)*x) + C2*airybi(-2**(S(1)/3)*x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_constant_coeff_homogeneous(): + # From Exercise 20, in Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 220 + a = Symbol('a', positive=True) + k = Symbol('k', real=True) + r1, r2, r3, r4, r5 = [rootof(x**5 + 11*x - 2, n) for n in range(5)] + r6, r7, r8, r9, r10 = [rootof(x**5 - 3*x + 1, n) for n in range(5)] + r11, r12, r13, r14, r15 = [rootof(x**5 - 100*x**3 + 1000*x + 1, n) for n in range(5)] + r16, r17, r18, r19, r20 = [rootof(x**5 - x**4 + 10, n) for n in range(5)] + r21, r22, r23, r24, r25 = [rootof(x**5 - x + 1, n) for n in range(5)] + E = exp(1) + return { + 'hint': "nth_linear_constant_coeff_homogeneous", + 'func': f(x), + 'examples':{ + 'lin_const_coeff_hom_01': { + 'eq': f(x).diff(x, 2) + 2*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x))], + }, + + 'lin_const_coeff_hom_02': { + 'eq': f(x).diff(x, 2) - 3*f(x).diff(x) + 2*f(x), + 'sol': [Eq(f(x), (C1 + C2*exp(x))*exp(x))], + }, + + 'lin_const_coeff_hom_03': { + 'eq': f(x).diff(x, 2) - f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(x))], + }, + + 'lin_const_coeff_hom_04': { + 'eq': f(x).diff(x, 3) + f(x).diff(x, 2) - 6*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_05': { + 'eq': 6*f(x).diff(x, 2) - 11*f(x).diff(x) + 4*f(x), + 'sol': [Eq(f(x), C1*exp(x/2) + C2*exp(x*Rational(4, 3)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_06': { + 'eq': Eq(f(x).diff(x, 2) + 2*f(x).diff(x) - f(x), 0), + 'sol': [Eq(f(x), C1*exp(x*(-1 + sqrt(2))) + C2*exp(-x*(sqrt(2) + 1)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_07': { + 'eq': diff(f(x), x, 3) + diff(f(x), x, 2) - 10*diff(f(x), x) - 6*f(x), + 'sol': [Eq(f(x), C1*exp(3*x) + C3*exp(-x*(2 + sqrt(2))) + C2*exp(x*(-2 + sqrt(2))))], + 'slow': True, + }, + + 'lin_const_coeff_hom_08': { + 'eq': f(x).diff(x, 4) - f(x).diff(x, 3) - 4*f(x).diff(x, 2) + \ + 4*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x) + C3*exp(x) + C4*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_09': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 3) + f(x).diff(x, 2) - \ + 4*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), C3*exp(-x) + C4*exp(x) + (C1*exp(-sqrt(2)*x) + C2*exp(sqrt(2)*x))*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_10': { + 'eq': f(x).diff(x, 4) - a**2*f(x), + 'sol': [Eq(f(x), C1*exp(-sqrt(a)*x) + C2*exp(sqrt(a)*x) + C3*sin(sqrt(a)*x) + C4*cos(sqrt(a)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_11': { + 'eq': f(x).diff(x, 2) - 2*k*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), C1*exp(x*(k - sqrt(k**2 + 2))) + C2*exp(x*(k + sqrt(k**2 + 2))))], + 'slow': True, + }, + + 'lin_const_coeff_hom_12': { + 'eq': f(x).diff(x, 2) + 4*k*f(x).diff(x) - 12*k**2*f(x), + 'sol': [Eq(f(x), C1*exp(-6*k*x) + C2*exp(2*k*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_13': { + 'eq': f(x).diff(x, 4), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*x**3)], + 'slow': True, + }, + + 'lin_const_coeff_hom_14': { + 'eq': f(x).diff(x, 2) + 4*f(x).diff(x) + 4*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_15': { + 'eq': 3*f(x).diff(x, 3) + 5*f(x).diff(x, 2) + f(x).diff(x) - f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-x) + C3*exp(x/3))], + 'slow': True, + }, + + 'lin_const_coeff_hom_16': { + 'eq': f(x).diff(x, 3) - 6*f(x).diff(x, 2) + 12*f(x).diff(x) - 8*f(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + C3*x))*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_17': { + 'eq': f(x).diff(x, 2) - 2*a*f(x).diff(x) + a**2*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(a*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_18': { + 'eq': f(x).diff(x, 4) + 3*f(x).diff(x, 3), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*exp(-3*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_19': { + 'eq': f(x).diff(x, 4) - 2*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*exp(-sqrt(2)*x) + C4*exp(sqrt(2)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_20': { + 'eq': f(x).diff(x, 4) + 2*f(x).diff(x, 3) - 11*f(x).diff(x, 2) - \ + 12*f(x).diff(x) + 36*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-3*x) + (C3 + C4*x)*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_21': { + 'eq': 36*f(x).diff(x, 4) - 37*f(x).diff(x, 2) + 4*f(x).diff(x) + 5*f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(-x/3) + C3*exp(x/2) + C4*exp(x*Rational(5, 6)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_22': { + 'eq': f(x).diff(x, 4) - 8*f(x).diff(x, 2) + 16*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-2*x) + (C3 + C4*x)*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_23': { + 'eq': f(x).diff(x, 2) - 2*f(x).diff(x) + 5*f(x), + 'sol': [Eq(f(x), (C1*sin(2*x) + C2*cos(2*x))*exp(x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_24': { + 'eq': f(x).diff(x, 2) - f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(3)/2) + C2*cos(x*sqrt(3)/2))*exp(x/2))], + 'slow': True, + }, + + 'lin_const_coeff_hom_25': { + 'eq': f(x).diff(x, 4) + 5*f(x).diff(x, 2) + 6*f(x), + 'sol': [Eq(f(x), + C1*sin(sqrt(2)*x) + C2*sin(sqrt(3)*x) + C3*cos(sqrt(2)*x) + C4*cos(sqrt(3)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_26': { + 'eq': f(x).diff(x, 2) - 4*f(x).diff(x) + 20*f(x), + 'sol': [Eq(f(x), (C1*sin(4*x) + C2*cos(4*x))*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_27': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2) + 4*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*sin(x*sqrt(2)) + (C3 + C4*x)*cos(x*sqrt(2)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_28': { + 'eq': f(x).diff(x, 3) + 8*f(x), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(3)) + C2*cos(x*sqrt(3)))*exp(x) + C3*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_29': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*sin(2*x) + C4*cos(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_30': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x), + 'sol': [Eq(f(x), C1 + (C2 + C3*x)*sin(x) + (C4 + C5*x)*cos(x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_31': { + 'eq': f(x).diff(x, 4) + f(x).diff(x, 2) + f(x), + 'sol': [Eq(f(x), (C1*sin(sqrt(3)*x/2) + C2*cos(sqrt(3)*x/2))*exp(-x/2) + + (C3*sin(sqrt(3)*x/2) + C4*cos(sqrt(3)*x/2))*exp(x/2))], + 'slow': True, + }, + + 'lin_const_coeff_hom_32': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2) + f(x), + 'sol': [Eq(f(x), C1*sin(x*sqrt(-sqrt(3) + 2)) + C2*sin(x*sqrt(sqrt(3) + 2)) + + C3*cos(x*sqrt(-sqrt(3) + 2)) + C4*cos(x*sqrt(sqrt(3) + 2)))], + 'slow': True, + }, + + # One real root, two complex conjugate pairs + 'lin_const_coeff_hom_33': { + 'eq': f(x).diff(x, 5) + 11*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), + C5*exp(r1*x) + exp(re(r2)*x) * (C1*sin(im(r2)*x) + C2*cos(im(r2)*x)) + + exp(re(r4)*x) * (C3*sin(im(r4)*x) + C4*cos(im(r4)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Three real roots, one complex conjugate pair + 'lin_const_coeff_hom_34': { + 'eq': f(x).diff(x,5) - 3*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), + C3*exp(r6*x) + C4*exp(r7*x) + C5*exp(r8*x) + + exp(re(r9)*x) * (C1*sin(im(r9)*x) + C2*cos(im(r9)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Five distinct real roots + 'lin_const_coeff_hom_35': { + 'eq': f(x).diff(x,5) - 100*f(x).diff(x,3) + 1000*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*exp(r11*x) + C2*exp(r12*x) + C3*exp(r13*x) + C4*exp(r14*x) + C5*exp(r15*x))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Rational root and unsolvable quintic + 'lin_const_coeff_hom_36': { + 'eq': f(x).diff(x, 6) - 6*f(x).diff(x, 5) + 5*f(x).diff(x, 4) + 10*f(x).diff(x) - 50 * f(x), + 'sol': [Eq(f(x), + C5*exp(5*x) + + C6*exp(x*r16) + + exp(re(r17)*x) * (C1*sin(im(r17)*x) + C2*cos(im(r17)*x)) + + exp(re(r19)*x) * (C3*sin(im(r19)*x) + C4*cos(im(r19)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Five double roots (this is (x**5 - x + 1)**2) + 'lin_const_coeff_hom_37': { + 'eq': f(x).diff(x, 10) - 2*f(x).diff(x, 6) + 2*f(x).diff(x, 5) + + f(x).diff(x, 2) - 2*f(x).diff(x, 1) + f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(x*r21) + (-((C3 + C4*x)*sin(x*im(r22))) + + (C5 + C6*x)*cos(x*im(r22)))*exp(x*re(r22)) + (-((C7 + C8*x)*sin(x*im(r24))) + + (C10*x + C9)*cos(x*im(r24)))*exp(x*re(r24)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + 'lin_const_coeff_hom_38': { + 'eq': Eq(sqrt(2) * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(2**Rational(3, 4)*x/2) + C3*cos(2**Rational(3, 4)*x/2))], + }, + + 'lin_const_coeff_hom_39': { + 'eq': Eq(E * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(x/sqrt(E)) + C3*cos(x/sqrt(E)))], + }, + + 'lin_const_coeff_hom_40': { + 'eq': Eq(pi * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(x/sqrt(pi)) + C3*cos(x/sqrt(pi)))], + }, + + 'lin_const_coeff_hom_41': { + 'eq': Eq(I * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*exp(-sqrt(I)*x) + C3*exp(sqrt(I)*x))], + }, + + 'lin_const_coeff_hom_42': { + 'eq': f(x).diff(x, x) + y*f(x), + 'sol': [Eq(f(x), C1*exp(-x*sqrt(-y)) + C2*exp(x*sqrt(-y)))], + }, + + 'lin_const_coeff_hom_43': { + 'eq': Eq(9*f(x).diff(x, x) + f(x), 0), + 'sol': [Eq(f(x), C1*sin(x/3) + C2*cos(x/3))], + }, + + 'lin_const_coeff_hom_44': { + 'eq': Eq(9*f(x).diff(x, x), f(x)), + 'sol': [Eq(f(x), C1*exp(-x/3) + C2*exp(x/3))], + }, + + 'lin_const_coeff_hom_45': { + 'eq': Eq(f(x).diff(x, x) - 3*diff(f(x), x) + 2*f(x), 0), + 'sol': [Eq(f(x), (C1 + C2*exp(x))*exp(x))], + }, + + 'lin_const_coeff_hom_46': { + 'eq': Eq(f(x).diff(x, x) - 4*diff(f(x), x) + 4*f(x), 0), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(2*x))], + }, + + # Type: 2nd order, constant coefficients (two real equal roots) + 'lin_const_coeff_hom_47': { + 'eq': Eq(f(x).diff(x, x) + 2*diff(f(x), x) + 3*f(x), 0), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(2)) + C2*cos(x*sqrt(2)))*exp(-x))], + }, + + #These were from issue: https://github.com/sympy/sympy/issues/6247 + 'lin_const_coeff_hom_48': { + 'eq': f(x).diff(x, x) + 4*f(x), + 'sol': [Eq(f(x), C1*sin(2*x) + C2*cos(2*x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep(): + return { + 'hint': "1st_homogeneous_coeff_subs_dep_div_indep", + 'func': f(x), + 'examples':{ + 'dep_div_indep_01': { + 'eq': f(x)/x*cos(f(x)/x) - (x/f(x)*sin(f(x)/x) + cos(f(x)/x))*f(x).diff(x), + 'sol': [Eq(log(x), C1 - log(f(x)*sin(f(x)/x)/x))], + 'slow': True + }, + + #indep_div_dep actually has a simpler solution for example 2 but it runs too slow. + 'dep_div_indep_02': { + 'eq': x*f(x).diff(x) - f(x) - x*sin(f(x)/x), + 'sol': [Eq(log(x), log(C1) + log(cos(f(x)/x) - 1)/2 - log(cos(f(x)/x) + 1)/2)], + 'simplify_flag':False, + }, + + 'dep_div_indep_03': { + 'eq': x*exp(f(x)/x) - f(x)*sin(f(x)/x) + x*sin(f(x)/x)*f(x).diff(x), + 'sol': [Eq(log(x), C1 + exp(-f(x)/x)*sin(f(x)/x)/2 + exp(-f(x)/x)*cos(f(x)/x)/2)], + 'slow': True + }, + + 'dep_div_indep_04': { + 'eq': f(x).diff(x) - f(x)/x + 1/sin(f(x)/x), + 'sol': [Eq(f(x), x*(-acos(C1 + log(x)) + 2*pi)), Eq(f(x), x*acos(C1 + log(x)))], + 'slow': True + }, + + # previous code was testing with these other solution: + # example5_solb = Eq(f(x), log(log(C1/x)**(-x))) + 'dep_div_indep_05': { + 'eq': x*exp(f(x)/x) + f(x) - x*f(x).diff(x), + 'sol': [Eq(f(x), log((1/(C1 - log(x)))**x))], + 'checkodesol_XFAIL':True, #(because of **x?) + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_linear_coefficients(): + return { + 'hint': "linear_coefficients", + 'func': f(x), + 'examples':{ + 'linear_coeff_01': { + 'eq': f(x).diff(x) + (3 + 2*f(x))/(x + 3), + 'sol': [Eq(f(x), C1/(x**2 + 6*x + 9) - Rational(3, 2))], + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_1st_homogeneous_coeff_best(): + return { + 'hint': "1st_homogeneous_coeff_best", + 'func': f(x), + 'examples':{ + # previous code was testing this with other solution: + # example1_solb = Eq(-f(x)/(1 + log(x/f(x))), C1) + '1st_homogeneous_coeff_best_01': { + 'eq': f(x) + (x*log(f(x)/x) - 2*x)*diff(f(x), x), + 'sol': [Eq(f(x), -exp(C1)*LambertW(-x*exp(-C1 + 1)))], + 'checkodesol_XFAIL':True, #(because of LambertW?) + }, + + '1st_homogeneous_coeff_best_02': { + 'eq': 2*f(x)*exp(x/f(x)) + f(x)*f(x).diff(x) - 2*x*exp(x/f(x))*f(x).diff(x), + 'sol': [Eq(log(f(x)), C1 - 2*exp(x/f(x)))], + }, + + # previous code was testing this with other solution: + # example3_solb = Eq(log(C1*x*sqrt(1/x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0) + '1st_homogeneous_coeff_best_03': { + 'eq': 2*x**2*f(x) + f(x)**3 + (x*f(x)**2 - 2*x**3)*f(x).diff(x), + 'sol': [Eq(f(x), exp(2*C1 + LambertW(-2*x**4*exp(-4*C1))/2)/x)], + 'checkodesol_XFAIL':True, #(because of LambertW?) + }, + + '1st_homogeneous_coeff_best_04': { + 'eq': (x + sqrt(f(x)**2 - x*f(x)))*f(x).diff(x) - f(x), + 'sol': [Eq(log(f(x)), C1 - 2*sqrt(-x/f(x) + 1))], + 'slow': True, + }, + + '1st_homogeneous_coeff_best_05': { + 'eq': x + f(x) - (x - f(x))*f(x).diff(x), + 'sol': [Eq(log(x), C1 - log(sqrt(1 + f(x)**2/x**2)) + atan(f(x)/x))], + }, + + '1st_homogeneous_coeff_best_06': { + 'eq': x*f(x).diff(x) - f(x) - x*sin(f(x)/x), + 'sol': [Eq(f(x), 2*x*atan(C1*x))], + }, + + '1st_homogeneous_coeff_best_07': { + 'eq': x**2 + f(x)**2 - 2*x*f(x)*f(x).diff(x), + 'sol': [Eq(f(x), -sqrt(x*(C1 + x))), Eq(f(x), sqrt(x*(C1 + x)))], + }, + + '1st_homogeneous_coeff_best_08': { + 'eq': f(x)**2 + (x*sqrt(f(x)**2 - x**2) - x*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -C1*sqrt(-x/(x - 2*C1))), Eq(f(x), C1*sqrt(-x/(x - 2*C1)))], + 'checkodesol_XFAIL': True # solutions are valid in a range + }, + } + } + + +def _get_all_examples(): + all_examples = _get_examples_ode_sol_euler_homogeneous + \ + _get_examples_ode_sol_euler_undetermined_coeff + \ + _get_examples_ode_sol_euler_var_para + \ + _get_examples_ode_sol_factorable + \ + _get_examples_ode_sol_bernoulli + \ + _get_examples_ode_sol_nth_algebraic + \ + _get_examples_ode_sol_riccati + \ + _get_examples_ode_sol_1st_linear + \ + _get_examples_ode_sol_1st_exact + \ + _get_examples_ode_sol_almost_linear + \ + _get_examples_ode_sol_nth_order_reducible + \ + _get_examples_ode_sol_nth_linear_undetermined_coefficients + \ + _get_examples_ode_sol_liouville + \ + _get_examples_ode_sol_separable + \ + _get_examples_ode_sol_1st_rational_riccati + \ + _get_examples_ode_sol_nth_linear_var_of_parameters + \ + _get_examples_ode_sol_2nd_linear_bessel + \ + _get_examples_ode_sol_2nd_2F1_hypergeometric + \ + _get_examples_ode_sol_2nd_nonlinear_autonomous_conserved + \ + _get_examples_ode_sol_separable_reduced + \ + _get_examples_ode_sol_lie_group + \ + _get_examples_ode_sol_2nd_linear_airy + \ + _get_examples_ode_sol_nth_linear_constant_coeff_homogeneous +\ + _get_examples_ode_sol_1st_homogeneous_coeff_best +\ + _get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep +\ + _get_examples_ode_sol_linear_coefficients + + return all_examples diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_subscheck.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_subscheck.py new file mode 100644 index 0000000000000000000000000000000000000000..799c2854e878208721b600767de350cda08cd7e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_subscheck.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.error_functions import (Ei, erf, erfi) +from sympy.integrals.integrals import Integral + +from sympy.solvers.ode.subscheck import checkodesol, checksysodesol + +from sympy.functions import besselj, bessely + +from sympy.testing.pytest import raises, slow + + +C0, C1, C2, C3, C4 = symbols('C0:5') +u, x, y, z = symbols('u,x:z', real=True) +f = Function('f') +g = Function('g') +h = Function('h') + + +@slow +def test_checkodesol(): + # For the most part, checkodesol is well tested in the tests below. + # These tests only handle cases not checked below. + raises(ValueError, lambda: checkodesol(f(x, y).diff(x), Eq(f(x, y), x))) + raises(ValueError, lambda: checkodesol(f(x).diff(x), Eq(f(x, y), + x), f(x, y))) + assert checkodesol(f(x).diff(x), Eq(f(x, y), x)) == \ + (False, -f(x).diff(x) + f(x, y).diff(x) - 1) + assert checkodesol(f(x).diff(x), Eq(f(x), x)) is not True + assert checkodesol(f(x).diff(x), Eq(f(x), x)) == (False, 1) + sol1 = Eq(f(x)**5 + 11*f(x) - 2*f(x) + x, 0) + assert checkodesol(diff(sol1.lhs, x), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 2), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 2)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3), Eq(f(x), x*log(x))) == \ + (False, 60*x**4*((log(x) + 1)**2 + log(x))*( + log(x) + 1)*log(x)**2 - 5*x**4*log(x)**4 - 9) + assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0)) == \ + (True, 0) + assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0), + solve_for_func=False) == (True, 0) + assert checkodesol(f(x).diff(x, 2), [Eq(f(x), C1 + C2*x), + Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)]) == \ + [(True, 0), (True, 0), (False, C2)] + assert checkodesol(f(x).diff(x, 2), {Eq(f(x), C1 + C2*x), + Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)}) == \ + {(True, 0), (True, 0), (False, C2)} + assert checkodesol(f(x).diff(x) - 1/f(x)/2, Eq(f(x)**2, x)) == \ + [(True, 0), (True, 0)] + assert checkodesol(f(x).diff(x) - f(x), Eq(C1*exp(x), f(x))) == (True, 0) + # Based on test_1st_homogeneous_coeff_ode2_eq3sol. Make sure that + # checkodesol tries back substituting f(x) when it can. + eq3 = x*exp(f(x)/x) + f(x) - x*f(x).diff(x) + sol3 = Eq(f(x), log(log(C1/x)**(-x))) + assert not checkodesol(eq3, sol3)[1].has(f(x)) + # This case was failing intermittently depending on hash-seed: + eqn = Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x)) + sol = Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x)) + assert checkodesol(eqn, sol, order=2, solve_for_func=False)[0] + eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (2*x**2 +25)*f(x) + sol = Eq(f(x), C1*besselj(5*I, sqrt(2)*x) + C2*bessely(5*I, sqrt(2)*x)) + assert checkodesol(eq, sol) == (True, 0) + + eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))] + sol = [Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))] + assert checkodesol(eqs, sol) == (True, [0, 0]) + + +def test_checksysodesol(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + eq = (Eq(diff(x(t),t), 9*y(t)), Eq(diff(y(t),t), 12*x(t))) + sol = [Eq(x(t), 9*C1*exp(-6*sqrt(3)*t) + 9*C2*exp(6*sqrt(3)*t)), \ + Eq(y(t), -6*sqrt(3)*C1*exp(-6*sqrt(3)*t) + 6*sqrt(3)*C2*exp(6*sqrt(3)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 2*x(t) + 4*y(t)), Eq(diff(y(t),t), 12*x(t) + 41*y(t))) + sol = [Eq(x(t), 4*C1*exp(t*(-sqrt(1713)/2 + Rational(43, 2))) + 4*C2*exp(t*(sqrt(1713)/2 + \ + Rational(43, 2)))), Eq(y(t), C1*(-sqrt(1713)/2 + Rational(39, 2))*exp(t*(-sqrt(1713)/2 + \ + Rational(43, 2))) + C2*(Rational(39, 2) + sqrt(1713)/2)*exp(t*(sqrt(1713)/2 + Rational(43, 2))))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t)), Eq(diff(y(t),t), -2*x(t) + 2*y(t))) + sol = [Eq(x(t), (C1*sin(sqrt(7)*t/2) + C2*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2))), \ + Eq(y(t), ((C1/2 - sqrt(7)*C2/2)*sin(sqrt(7)*t/2) + (sqrt(7)*C1/2 + \ + C2/2)*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2)))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23)) + sol = [Eq(x(t), C1*exp(t*(-sqrt(6) + 3)) + C2*exp(t*(sqrt(6) + 3)) - \ + Rational(22, 3)), Eq(y(t), C1*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) + C2*(2 + \ + sqrt(6))*exp(t*(sqrt(6) + 3)) - Rational(5, 3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23)) + sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - Rational(58, 3)), \ + Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - Rational(185, 3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t))) + sol = [Eq(x(t), (C1*exp(Integral(2, t).doit()) + C2*exp(-(Integral(2, t)).doit()))*\ + exp((Integral(5*t, t)).doit())), Eq(y(t), (C1*exp((Integral(2, t)).doit()) - \ + C2*exp(-(Integral(2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + sol = [Eq(x(t), (C1*cos((Integral(t**2, t)).doit()) + C2*sin((Integral(t**2, t)).doit()))*\ + exp((Integral(5*t, t)).doit())), Eq(y(t), (-C1*sin((Integral(t**2, t)).doit()) + \ + C2*cos((Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t))) + sol = [Eq(x(t), (C1*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \ + C2*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit())), \ + Eq(y(t), (C1*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \ + C2*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 5*x(t) + 43*y(t)), Eq(diff(y(t),t,t), x(t) + 9*y(t))) + root0 = -sqrt(-sqrt(47) + 7) + root1 = sqrt(-sqrt(47) + 7) + root2 = -sqrt(sqrt(47) + 7) + root3 = sqrt(sqrt(47) + 7) + sol = [Eq(x(t), 43*C1*exp(t*root0) + 43*C2*exp(t*root1) + 43*C3*exp(t*root2) + 43*C4*exp(t*root3)), \ + Eq(y(t), C1*(root0**2 - 5)*exp(t*root0) + C2*(root1**2 - 5)*exp(t*root1) + \ + C3*(root2**2 - 5)*exp(t*root2) + C4*(root3**2 - 5)*exp(t*root3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 8*x(t)+3*y(t)+31), Eq(diff(y(t),t,t), 9*x(t)+7*y(t)+12)) + root0 = -sqrt(-sqrt(109)/2 + Rational(15, 2)) + root1 = sqrt(-sqrt(109)/2 + Rational(15, 2)) + root2 = -sqrt(sqrt(109)/2 + Rational(15, 2)) + root3 = sqrt(sqrt(109)/2 + Rational(15, 2)) + sol = [Eq(x(t), 3*C1*exp(t*root0) + 3*C2*exp(t*root1) + 3*C3*exp(t*root2) + 3*C4*exp(t*root3) - Rational(181, 29)), \ + Eq(y(t), C1*(root0**2 - 8)*exp(t*root0) + C2*(root1**2 - 8)*exp(t*root1) + \ + C3*(root2**2 - 8)*exp(t*root2) + C4*(root3**2 - 8)*exp(t*root3) + Rational(183, 29))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t) - 9*diff(y(t),t) + 7*x(t),0), Eq(diff(y(t),t,t) + 9*diff(x(t),t) + 7*y(t),0)) + sol = [Eq(x(t), C1*cos(t*(Rational(9, 2) + sqrt(109)/2)) + C2*sin(t*(Rational(9, 2) + sqrt(109)/2)) + \ + C3*cos(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*sin(t*(-sqrt(109)/2 + Rational(9, 2)))), Eq(y(t), -C1*sin(t*(Rational(9, 2) + sqrt(109)/2)) \ + + C2*cos(t*(Rational(9, 2) + sqrt(109)/2)) - C3*sin(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*cos(t*(-sqrt(109)/2 + Rational(9, 2))))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 9*t*diff(y(t),t)-9*y(t)), Eq(diff(y(t),t,t),7*t*diff(x(t),t)-7*x(t))) + I1 = sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erfi(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(3*sqrt(7)*t**2/2)/t + I2 = -sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erf(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(-3*sqrt(7)*t**2/2)/t + sol = [Eq(x(t), C3*t + t*(9*C1*I1 + 9*C2*I2)), Eq(y(t), C4*t + t*(3*sqrt(7)*C1*I1 - 3*sqrt(7)*C2*I2))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 21*x(t)), Eq(diff(y(t),t), 17*x(t)+3*y(t)), Eq(diff(z(t),t), 5*x(t)+7*y(t)+9*z(t))) + sol = [Eq(x(t), C1*exp(21*t)), Eq(y(t), 17*C1*exp(21*t)/18 + C2*exp(3*t)), \ + Eq(z(t), 209*C1*exp(21*t)/216 - 7*C2*exp(3*t)/6 + C3*exp(9*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),3*y(t)-11*z(t)),Eq(diff(y(t),t),7*z(t)-3*x(t)),Eq(diff(z(t),t),11*x(t)-7*y(t))) + sol = [Eq(x(t), 7*C0 + sqrt(179)*C1*cos(sqrt(179)*t) + (77*C1/3 + 130*C2/3)*sin(sqrt(179)*t)), \ + Eq(y(t), 11*C0 + sqrt(179)*C2*cos(sqrt(179)*t) + (-58*C1/3 - 77*C2/3)*sin(sqrt(179)*t)), \ + Eq(z(t), 3*C0 + sqrt(179)*(-7*C1/3 - 11*C2/3)*cos(sqrt(179)*t) + (11*C1 - 7*C2)*sin(sqrt(179)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(3*diff(x(t),t),4*5*(y(t)-z(t))),Eq(4*diff(y(t),t),3*5*(z(t)-x(t))),Eq(5*diff(z(t),t),3*4*(x(t)-y(t)))) + sol = [Eq(x(t), C0 + 5*sqrt(2)*C1*cos(5*sqrt(2)*t) + (12*C1/5 + 164*C2/15)*sin(5*sqrt(2)*t)), \ + Eq(y(t), C0 + 5*sqrt(2)*C2*cos(5*sqrt(2)*t) + (-51*C1/10 - 12*C2/5)*sin(5*sqrt(2)*t)), \ + Eq(z(t), C0 + 5*sqrt(2)*(-9*C1/25 - 16*C2/25)*cos(5*sqrt(2)*t) + (12*C1/5 - 12*C2/5)*sin(5*sqrt(2)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),4*x(t) - z(t)),Eq(diff(y(t),t),2*x(t)+2*y(t)-z(t)),Eq(diff(z(t),t),3*x(t)+y(t))) + sol = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t) + C3*exp(2*t)), \ + Eq(y(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t)), \ + Eq(z(t), 2*C1*exp(2*t) + 2*C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t) + C3*t*exp(2*t) + C3*exp(2*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),4*x(t) - y(t) - 2*z(t)),Eq(diff(y(t),t),2*x(t) + y(t)- 2*z(t)),Eq(diff(z(t),t),5*x(t)-3*z(t))) + sol = [Eq(x(t), C1*exp(2*t) + C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), \ + Eq(y(t), C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), Eq(z(t), C1*exp(2*t) + 5*C2*cos(t) + 5*C3*sin(t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol = [Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5)) + sol = [Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2)) + sol = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)} + assert checksysodesol(eq, sol) == (True, [0, 0]) diff --git a/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_systems.py b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_systems.py new file mode 100644 index 0000000000000000000000000000000000000000..9d206129dfcf38c7b8c2e0ab42bd875003253f35 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/ode/tests/test_systems.py @@ -0,0 +1,2544 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.hyperbolic import sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.core.containers import Tuple +from sympy.functions import exp, cos, sin, log, Ci, Si, erf, erfi +from sympy.matrices import dotprodsimp, NonSquareMatrixError +from sympy.solvers.ode import dsolve +from sympy.solvers.ode.ode import constant_renumber +from sympy.solvers.ode.subscheck import checksysodesol +from sympy.solvers.ode.systems import (_classify_linear_system, linear_ode_to_matrix, + ODEOrderError, ODENonlinearError, _simpsol, + _is_commutative_anti_derivative, linodesolve, + canonical_odes, dsolve_system, _component_division, + _eqs2dict, _dict2graph) +from sympy.functions import airyai, airybi +from sympy.integrals.integrals import Integral +from sympy.simplify.ratsimp import ratsimp +from sympy.testing.pytest import raises, slow, tooslow, XFAIL + + +C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C0:11') +x = symbols('x') +f = Function('f') +g = Function('g') +h = Function('h') + + +def test_linear_ode_to_matrix(): + f, g, h = symbols("f, g, h", cls=Function) + t = Symbol("t") + funcs = [f(t), g(t), h(t)] + f1 = f(t).diff(t) + g1 = g(t).diff(t) + h1 = h(t).diff(t) + f2 = f(t).diff(t, 2) + g2 = g(t).diff(t, 2) + h2 = h(t).diff(t, 2) + + eqs_1 = [Eq(f1, g(t)), Eq(g1, f(t))] + sol_1 = ([Matrix([[1, 0], [0, 1]]), Matrix([[ 0, 1], [1, 0]])], Matrix([[0],[0]])) + assert linear_ode_to_matrix(eqs_1, funcs[:-1], t, 1) == sol_1 + + eqs_2 = [Eq(f1, f(t) + 2*g(t)), Eq(g1, h(t)), Eq(h1, g(t) + h(t) + f(t))] + sol_2 = ([Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Matrix([[1, 2, 0], [ 0, 0, 1], [1, 1, 1]])], + Matrix([[0], [0], [0]])) + assert linear_ode_to_matrix(eqs_2, funcs, t, 1) == sol_2 + + eqs_3 = [Eq(2*f1 + 3*h1, f(t) + g(t)), Eq(4*h1 + 5*g1, f(t) + h(t)), Eq(5*f1 + 4*g1, g(t) + h(t))] + sol_3 = ([Matrix([[2, 0, 3], [0, 5, 4], [5, 4, 0]]), Matrix([[1, 1, 0], [1, 0, 1], [0, 1, 1]])], + Matrix([[0], [0], [0]])) + assert linear_ode_to_matrix(eqs_3, funcs, t, 1) == sol_3 + + eqs_4 = [Eq(f2 + h(t), f1 + g(t)), Eq(2*h2 + g2 + g1 + g(t), 0), Eq(3*h1, 4)] + sol_4 = ([Matrix([[1, 0, 0], [0, 1, 2], [0, 0, 0]]), Matrix([[1, 0, 0], [0, -1, 0], [0, 0, -3]]), + Matrix([[0, 1, -1], [0, -1, 0], [0, 0, 0]])], Matrix([[0], [0], [4]])) + assert linear_ode_to_matrix(eqs_4, funcs, t, 2) == sol_4 + + eqs_5 = [Eq(f2, g(t)), Eq(f1 + g1, f(t))] + raises(ODEOrderError, lambda: linear_ode_to_matrix(eqs_5, funcs[:-1], t, 1)) + + eqs_6 = [Eq(f1, f(t)**2), Eq(g1, f(t) + g(t))] + raises(ODENonlinearError, lambda: linear_ode_to_matrix(eqs_6, funcs[:-1], t, 1)) + + +def test__classify_linear_system(): + x, y, z, w = symbols('x, y, z, w', cls=Function) + t, k, l = symbols('t k l') + x1 = diff(x(t), t) + y1 = diff(y(t), t) + z1 = diff(z(t), t) + w1 = diff(w(t), t) + x2 = diff(x(t), t, t) + y2 = diff(y(t), t, t) + funcs = [x(t), y(t)] + funcs_2 = funcs + [z(t), w(t)] + + eqs_1 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * t * x(t) + 3 * y(t) + t)) + assert _classify_linear_system(eqs_1, funcs, t) is None + + eqs_2 = (5 * (x1**2) + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * t * x(t) + 3 * y(t) + t)) + sol2 = {'is_implicit': True, + 'canon_eqs': [[Eq(Derivative(x(t), t), -sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)], + [Eq(Derivative(x(t), t), sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)]]} + assert _classify_linear_system(eqs_2, funcs, t) == sol2 + + eqs_2_1 = [Eq(Derivative(x(t), t), -sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)] + assert _classify_linear_system(eqs_2_1, funcs, t) is None + + eqs_2_2 = [Eq(Derivative(x(t), t), sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)] + assert _classify_linear_system(eqs_2_2, funcs, t) is None + + eqs_3 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * x(t) + 3 * y(t)), (5 * w1 + z(t)), (z1 + w(t))) + answer_3 = {'no_of_equation': 4, + 'eq': (12*x(t) - 6*y(t) + 5*Derivative(x(t), t), + -11*x(t) + 3*y(t) + 2*Derivative(y(t), t), + z(t) + 5*Derivative(w(t), t), + w(t) + Derivative(z(t), t)), + 'func': [x(t), y(t), z(t), w(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, + 'is_linear': True, + 'is_constant': True, + 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [Rational(12, 5), Rational(-6, 5), 0, 0], + [Rational(-11, 2), Rational(3, 2), 0, 0], + [0, 0, 0, 1], + [0, 0, Rational(1, 5), 0]]), + 'type_of_equation': 'type1', + 'is_general': True} + assert _classify_linear_system(eqs_3, funcs_2, t) == answer_3 + + eqs_4 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * x(t) + 3 * y(t)), (z1 - w(t)), (w1 - z(t))) + answer_4 = {'no_of_equation': 4, + 'eq': (12 * x(t) - 6 * y(t) + 5 * Derivative(x(t), t), + -11 * x(t) + 3 * y(t) + 2 * Derivative(y(t), t), + -w(t) + Derivative(z(t), t), + -z(t) + Derivative(w(t), t)), + 'func': [x(t), y(t), z(t), w(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, + 'is_linear': True, + 'is_constant': True, + 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [Rational(12, 5), Rational(-6, 5), 0, 0], + [Rational(-11, 2), Rational(3, 2), 0, 0], + [0, 0, 0, -1], + [0, 0, -1, 0]]), + 'type_of_equation': 'type1', + 'is_general': True} + assert _classify_linear_system(eqs_4, funcs_2, t) == answer_4 + + eqs_5 = (5*x1 + 12*x(t) - 6*(y(t)) + x2, (2*y1 - 11*x(t) + 3*y(t)), (z1 - w(t)), (w1 - z(t))) + answer_5 = {'no_of_equation': 4, 'eq': (12*x(t) - 6*y(t) + 5*Derivative(x(t), t) + Derivative(x(t), (t, 2)), + -11*x(t) + 3*y(t) + 2*Derivative(y(t), t), -w(t) + Derivative(z(t), t), -z(t) + Derivative(w(t), + t)), 'func': [x(t), y(t), z(t), w(t)], 'order': {x(t): 2, y(t): 1, z(t): 1, w(t): 1}, 'is_linear': + True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': 'type0', 'is_higher_order': True} + assert _classify_linear_system(eqs_5, funcs_2, t) == answer_5 + + eqs_6 = (Eq(x1, 3*y(t) - 11*z(t)), Eq(y1, 7*z(t) - 3*x(t)), Eq(z1, 11*x(t) - 7*y(t))) + answer_6 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ 0, -3, 11], + [ 3, 0, -7], + [-11, 7, 0]]), + 'type_of_equation': 'type1', 'is_general': True} + + assert _classify_linear_system(eqs_6, funcs_2[:-1], t) == answer_6 + + eqs_7 = (Eq(x1, y(t)), Eq(y1, x(t))) + answer_7 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, + 'is_homogeneous': True, 'func_coeff': -Matrix([ + [ 0, -1], + [-1, 0]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_7, funcs, t) == answer_7 + + eqs_8 = (Eq(x1, 21*x(t)), Eq(y1, 17*x(t) + 3*y(t)), Eq(z1, 5*x(t) + 7*y(t) + 9*z(t))) + answer_8 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 21*x(t)), Eq(Derivative(y(t), t), 17*x(t) + 3*y(t)), + Eq(Derivative(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-21, 0, 0], + [-17, -3, 0], + [ -5, -7, -9]]), + 'type_of_equation': 'type1', 'is_general': True} + + assert _classify_linear_system(eqs_8, funcs_2[:-1], t) == answer_8 + + eqs_9 = (Eq(x1, 4*x(t) + 5*y(t) + 2*z(t)), Eq(y1, x(t) + 13*y(t) + 9*z(t)), Eq(z1, 32*x(t) + 41*y(t) + 11*z(t))) + answer_9 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 4*x(t) + 5*y(t) + 2*z(t)), + Eq(Derivative(y(t), t), x(t) + 13*y(t) + 9*z(t)), Eq(Derivative(z(t), t), 32*x(t) + 41*y(t) + 11*z(t))), + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, + 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ -4, -5, -2], + [ -1, -13, -9], + [-32, -41, -11]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_9, funcs_2[:-1], t) == answer_9 + + eqs_10 = (Eq(3*x1, 4*5*(y(t) - z(t))), Eq(4*y1, 3*5*(z(t) - x(t))), Eq(5*z1, 3*4*(x(t) - y(t)))) + answer_10 = {'no_of_equation': 3, 'eq': (Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), + Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))), + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, + 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ 0, Rational(-20, 3), Rational(20, 3)], + [Rational(15, 4), 0, Rational(-15, 4)], + [Rational(-12, 5), Rational(12, 5), 0]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_10, funcs_2[:-1], t) == answer_10 + + eq11 = (Eq(x1, 3*y(t) - 11*z(t)), Eq(y1, 7*z(t) - 3*x(t)), Eq(z1, 11*x(t) - 7*y(t))) + sol11 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, 'func_coeff': -Matrix([ + [ 0, -3, 11], [ 3, 0, -7], [-11, 7, 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq11, funcs_2[:-1], t) == sol11 + + eq12 = (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))) + sol12 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, + 'is_homogeneous': True, 'func_coeff': -Matrix([ + [0, -1], + [-1, 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq12, [x(t), y(t)], t) == sol12 + + eq13 = (Eq(Derivative(x(t), t), 21*x(t)), Eq(Derivative(y(t), t), 17*x(t) + 3*y(t)), + Eq(Derivative(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))) + sol13 = {'no_of_equation': 3, 'eq': ( + Eq(Derivative(x(t), t), 21 * x(t)), Eq(Derivative(y(t), t), 17 * x(t) + 3 * y(t)), + Eq(Derivative(z(t), t), 5 * x(t) + 7 * y(t) + 9 * z(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-21, 0, 0], + [-17, -3, 0], + [-5, -7, -9]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq13, [x(t), y(t), z(t)], t) == sol13 + + eq14 = ( + Eq(Derivative(x(t), t), 4*x(t) + 5*y(t) + 2*z(t)), Eq(Derivative(y(t), t), x(t) + 13*y(t) + 9*z(t)), + Eq(Derivative(z(t), t), 32*x(t) + 41*y(t) + 11*z(t))) + sol14 = {'no_of_equation': 3, 'eq': ( + Eq(Derivative(x(t), t), 4 * x(t) + 5 * y(t) + 2 * z(t)), Eq(Derivative(y(t), t), x(t) + 13 * y(t) + 9 * z(t)), + Eq(Derivative(z(t), t), 32 * x(t) + 41 * y(t) + 11 * z(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-4, -5, -2], + [-1, -13, -9], + [-32, -41, -11]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq14, [x(t), y(t), z(t)], t) == sol14 + + eq15 = (Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), + Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))) + sol15 = {'no_of_equation': 3, 'eq': ( + Eq(3 * Derivative(x(t), t), 20 * y(t) - 20 * z(t)), Eq(4 * Derivative(y(t), t), -15 * x(t) + 15 * z(t)), + Eq(5 * Derivative(z(t), t), 12 * x(t) - 12 * y(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [0, Rational(-20, 3), Rational(20, 3)], + [Rational(15, 4), 0, Rational(-15, 4)], + [Rational(-12, 5), Rational(12, 5), 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq15, [x(t), y(t), z(t)], t) == sol15 + + # Constant coefficient homogeneous ODEs + eq1 = (Eq(diff(x(t), t), x(t) + y(t) + 9), Eq(diff(y(t), t), 2*x(t) + 5*y(t) + 23)) + sol1 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), x(t) + y(t) + 9), + Eq(Derivative(y(t), t), 2*x(t) + 5*y(t) + 23)), 'func': [x(t), y(t)], + 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': False, 'is_general': True, + 'func_coeff': -Matrix([[-1, -1], [-2, -5]]), 'rhs': Matrix([[ 9], [23]]), 'type_of_equation': 'type2'} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + # Non constant coefficient homogeneous ODEs + eq1 = (Eq(diff(x(t), t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t), t), 2*x(t) + 5*t*y(t))) + sol1 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), 5*t*x(t) + 2*y(t)), Eq(Derivative(y(t), t), 5*t*y(t) + 2*x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': True, 'func_coeff': -Matrix([ [-5*t, -2], [ -2, -5*t]]), 'commutative_antiderivative': Matrix([ + [5*t**2/2, 2*t], [ 2*t, 5*t**2/2]]), 'type_of_equation': 'type3', 'is_general': True} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + # Non constant coefficient non-homogeneous ODEs + eq1 = [Eq(x1, x(t) + t*y(t) + t), Eq(y1, t*x(t) + y(t))] + sol1 = {'no_of_equation': 2, 'eq': [Eq(Derivative(x(t), t), t*y(t) + t + x(t)), Eq(Derivative(y(t), t), + t*x(t) + y(t))], 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, + 'is_constant': False, 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-1, -t], + [-t, -1]]), 'commutative_antiderivative': Matrix([ [ t, t**2/2], [t**2/2, t]]), 'rhs': + Matrix([ [t], [0]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + eq2 = [Eq(x1, t*x(t) + t*y(t) + t), Eq(y1, t*x(t) + t*y(t) + cos(t))] + sol2 = {'no_of_equation': 2, 'eq': [Eq(Derivative(x(t), t), t*x(t) + t*y(t) + t), Eq(Derivative(y(t), t), + t*x(t) + t*y(t) + cos(t))], 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, + 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ [ t], [cos(t)]]), 'func_coeff': + Matrix([ [t, t], [t, t]]), 'is_constant': False, 'type_of_equation': 'type4', + 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2], [t**2/2, t**2/2]])} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = [Eq(x1, t*(x(t) + y(t) + z(t) + 1)), Eq(y1, t*(x(t) + y(t) + z(t))), Eq(z1, t*(x(t) + y(t) + z(t)))] + sol3 = {'no_of_equation': 3, 'eq': [Eq(Derivative(x(t), t), t*(x(t) + y(t) + z(t) + 1)), + Eq(Derivative(y(t), t), t*(x(t) + y(t) + z(t))), Eq(Derivative(z(t), t), t*(x(t) + y(t) + z(t)))], + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': + False, 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-t, -t, -t], [-t, -t, + -t], [-t, -t, -t]]), 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2, t**2/2], [t**2/2, + t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2]]), 'rhs': Matrix([ [t], [0], [0]]), 'type_of_equation': + 'type4'} + assert _classify_linear_system(eq3, funcs_2[:-1], t) == sol3 + + eq4 = [Eq(x1, x(t) + y(t) + t*z(t) + 1), Eq(y1, x(t) + t*y(t) + z(t) + 10), Eq(z1, t*x(t) + y(t) + z(t) + t)] + sol4 = {'no_of_equation': 3, 'eq': [Eq(Derivative(x(t), t), t*z(t) + x(t) + y(t) + 1), Eq(Derivative(y(t), + t), t*y(t) + x(t) + z(t) + 10), Eq(Derivative(z(t), t), t*x(t) + t + y(t) + z(t))], 'func': [x(t), + y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-1, -1, -t], [-1, -t, -1], [-t, + -1, -1]]), 'commutative_antiderivative': Matrix([ [ t, t, t**2/2], [ t, t**2/2, + t], [t**2/2, t, t]]), 'rhs': Matrix([ [ 1], [10], [ t]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq4, funcs_2[:-1], t) == sol4 + + sum_terms = t*(x(t) + y(t) + z(t) + w(t)) + eq5 = [Eq(x1, sum_terms), Eq(y1, sum_terms), Eq(z1, sum_terms + 1), Eq(w1, sum_terms)] + sol5 = {'no_of_equation': 4, 'eq': [Eq(Derivative(x(t), t), t*(w(t) + x(t) + y(t) + z(t))), + Eq(Derivative(y(t), t), t*(w(t) + x(t) + y(t) + z(t))), Eq(Derivative(z(t), t), t*(w(t) + x(t) + + y(t) + z(t)) + 1), Eq(Derivative(w(t), t), t*(w(t) + x(t) + y(t) + z(t)))], 'func': [x(t), y(t), + z(t), w(t)], 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-t, -t, -t, -t], [-t, -t, -t, + -t], [-t, -t, -t, -t], [-t, -t, -t, -t]]), 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2, + t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2, t**2/2], [t**2/2, + t**2/2, t**2/2, t**2/2]]), 'rhs': Matrix([ [0], [0], [1], [0]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq5, funcs_2, t) == sol5 + + # Second Order + t_ = symbols("t_") + + eq1 = (Eq(9*x(t) + 7*y(t) + 4*Derivative(x(t), t) + Derivative(x(t), (t, 2)) + 3*Derivative(y(t), t), 11*exp(I*t)), + Eq(3*x(t) + 12*y(t) + 5*Derivative(x(t), t) + 8*Derivative(y(t), t) + Derivative(y(t), (t, 2)), 2*exp(I*t))) + sol1 = {'no_of_equation': 2, 'eq': (Eq(9*x(t) + 7*y(t) + 4*Derivative(x(t), t) + Derivative(x(t), (t, 2)) + + 3*Derivative(y(t), t), 11*exp(I*t)), Eq(3*x(t) + 12*y(t) + 5*Derivative(x(t), t) + + 8*Derivative(y(t), t) + Derivative(y(t), (t, 2)), 2*exp(I*t))), 'func': [x(t), y(t)], 'order': + {x(t): 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ + [11*exp(I*t)], [ 2*exp(I*t)]]), 'type_of_equation': 'type0', 'is_second_order': True, + 'is_higher_order': True} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + eq2 = (Eq((4*t**2 + 7*t + 1)**2*Derivative(x(t), (t, 2)), 5*x(t) + 35*y(t)), + Eq((4*t**2 + 7*t + 1)**2*Derivative(y(t), (t, 2)), x(t) + 9*y(t))) + sol2 = {'no_of_equation': 2, 'eq': (Eq((4*t**2 + 7*t + 1)**2*Derivative(x(t), (t, 2)), 5*x(t) + 35*y(t)), + Eq((4*t**2 + 7*t + 1)**2*Derivative(y(t), (t, 2)), x(t) + 9*y(t))), 'func': [x(t), y(t)], 'order': + {x(t): 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': True, 'is_general': True, + 'type_of_equation': 'type2', 'A0': Matrix([ [Rational(53, 4), 35], [ 1, Rational(69, 4)]]), 'g(t)': sqrt(4*t**2 + 7*t + + 1), 'tau': sqrt(33)*log(t - sqrt(33)/8 + Rational(7, 8))/33 - sqrt(33)*log(t + sqrt(33)/8 + Rational(7, 8))/33, + 'is_transformed': True, 't_': t_, 'is_second_order': True, 'is_higher_order': True} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = ((t*Derivative(x(t), t) - x(t))*log(t) + (t*Derivative(y(t), t) - y(t))*exp(t) + Derivative(x(t), (t, 2)), + t**2*(t*Derivative(x(t), t) - x(t)) + t*(t*Derivative(y(t), t) - y(t)) + Derivative(y(t), (t, 2))) + sol3 = {'no_of_equation': 2, 'eq': ((t*Derivative(x(t), t) - x(t))*log(t) + (t*Derivative(y(t), t) - + y(t))*exp(t) + Derivative(x(t), (t, 2)), t**2*(t*Derivative(x(t), t) - x(t)) + t*(t*Derivative(y(t), + t) - y(t)) + Derivative(y(t), (t, 2))), 'func': [x(t), y(t)], 'order': {x(t): 2, y(t): 2}, + 'is_linear': True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': 'type1', 'A1': + Matrix([ [-t*log(t), -t*exp(t)], [ -t**3, -t**2]]), 'is_second_order': True, + 'is_higher_order': True} + assert _classify_linear_system(eq3, funcs, t) == sol3 + + eq4 = (Eq(x2, k*x(t) - l*y1), Eq(y2, l*x1 + k*y(t))) + sol4 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), (t, 2)), k*x(t) - l*Derivative(y(t), t)), + Eq(Derivative(y(t), (t, 2)), k*y(t) + l*Derivative(x(t), t))), 'func': [x(t), y(t)], 'order': {x(t): + 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': + 'type0', 'is_second_order': True, 'is_higher_order': True} + assert _classify_linear_system(eq4, funcs, t) == sol4 + + + # Multiple matches + + f, g = symbols("f g", cls=Function) + y, t_ = symbols("y t_") + funcs = [f(t), g(t)] + + eq1 = [Eq(Derivative(f(t), t)**2 - 2*Derivative(f(t), t) + 1, 4), + Eq(-y*f(t) + Derivative(g(t), t), 0)] + sol1 = {'is_implicit': True, + 'canon_eqs': [[Eq(Derivative(f(t), t), -1), Eq(Derivative(g(t), t), y*f(t))], + [Eq(Derivative(f(t), t), 3), Eq(Derivative(g(t), t), y*f(t))]]} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + raises(ValueError, lambda: _classify_linear_system(eq1, funcs[:1], t)) + + eq2 = [Eq(Derivative(f(t), t), (2*f(t) + g(t) + 1)/t), Eq(Derivative(g(t), t), (f(t) + 2*g(t))/t)] + sol2 = {'no_of_equation': 2, 'eq': [Eq(Derivative(f(t), t), (2*f(t) + g(t) + 1)/t), Eq(Derivative(g(t), t), + (f(t) + 2*g(t))/t)], 'func': [f(t), g(t)], 'order': {f(t): 1, g(t): 1}, 'is_linear': True, + 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ [1], [0]]), 'func_coeff': Matrix([ [2, + 1], [1, 2]]), 'is_constant': False, 'type_of_equation': 'type6', 't_': t_, 'tau': log(t), + 'commutative_antiderivative': Matrix([ [2*log(t), log(t)], [ log(t), 2*log(t)]])} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = [Eq(Derivative(f(t), t), (2*f(t) + g(t))/t), Eq(Derivative(g(t), t), (f(t) + 2*g(t))/t)] + sol3 = {'no_of_equation': 2, 'eq': [Eq(Derivative(f(t), t), (2*f(t) + g(t))/t), Eq(Derivative(g(t), t), + (f(t) + 2*g(t))/t)], 'func': [f(t), g(t)], 'order': {f(t): 1, g(t): 1}, 'is_linear': True, + 'is_homogeneous': True, 'is_general': True, 'func_coeff': Matrix([ [2, 1], [1, 2]]), 'is_constant': + False, 'type_of_equation': 'type5', 't_': t_, 'rhs': Matrix([ [0], [0]]), 'tau': log(t), + 'commutative_antiderivative': Matrix([ [2*log(t), log(t)], [ log(t), 2*log(t)]])} + assert _classify_linear_system(eq3, funcs, t) == sol3 + + +def test_matrix_exp(): + from sympy.matrices.dense import Matrix, eye, zeros + from sympy.solvers.ode.systems import matrix_exp + t = Symbol('t') + + for n in range(1, 6+1): + assert matrix_exp(zeros(n), t) == eye(n) + + for n in range(1, 6+1): + A = eye(n) + expAt = exp(t) * eye(n) + assert matrix_exp(A, t) == expAt + + for n in range(1, 6+1): + A = Matrix(n, n, lambda i,j: i+1 if i==j else 0) + expAt = Matrix(n, n, lambda i,j: exp((i+1)*t) if i==j else 0) + assert matrix_exp(A, t) == expAt + + A = Matrix([[0, 1], [-1, 0]]) + expAt = Matrix([[cos(t), sin(t)], [-sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[2, -5], [2, -4]]) + expAt = Matrix([ + [3*exp(-t)*sin(t) + exp(-t)*cos(t), -5*exp(-t)*sin(t)], + [2*exp(-t)*sin(t), -3*exp(-t)*sin(t) + exp(-t)*cos(t)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[21, 17, 6], [-5, -1, -6], [4, 4, 16]]) + # TO update this. + # expAt = Matrix([ + # [(8*t*exp(12*t) + 5*exp(12*t) - 1)*exp(4*t)/4, + # (8*t*exp(12*t) + 5*exp(12*t) - 5)*exp(4*t)/4, + # (exp(12*t) - 1)*exp(4*t)/2], + # [(-8*t*exp(12*t) - exp(12*t) + 1)*exp(4*t)/4, + # (-8*t*exp(12*t) - exp(12*t) + 5)*exp(4*t)/4, + # (-exp(12*t) + 1)*exp(4*t)/2], + # [4*t*exp(16*t), 4*t*exp(16*t), exp(16*t)]]) + expAt = Matrix([ + [2*t*exp(16*t) + 5*exp(16*t)/4 - exp(4*t)/4, 2*t*exp(16*t) + 5*exp(16*t)/4 - 5*exp(4*t)/4, exp(16*t)/2 - exp(4*t)/2], + [ -2*t*exp(16*t) - exp(16*t)/4 + exp(4*t)/4, -2*t*exp(16*t) - exp(16*t)/4 + 5*exp(4*t)/4, -exp(16*t)/2 + exp(4*t)/2], + [ 4*t*exp(16*t), 4*t*exp(16*t), exp(16*t)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, -S(1)/8], + [0, 0, S(1)/2, S(1)/2]]) + expAt = Matrix([ + [exp(t), t*exp(t), 4*t*exp(3*t/4) + 8*t*exp(t) + 48*exp(3*t/4) - 48*exp(t), + -2*t*exp(3*t/4) - 2*t*exp(t) - 16*exp(3*t/4) + 16*exp(t)], + [0, exp(t), -t*exp(3*t/4) - 8*exp(3*t/4) + 8*exp(t), t*exp(3*t/4)/2 + 2*exp(3*t/4) - 2*exp(t)], + [0, 0, t*exp(3*t/4)/4 + exp(3*t/4), -t*exp(3*t/4)/8], + [0, 0, t*exp(3*t/4)/2, -t*exp(3*t/4)/4 + exp(3*t/4)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([ + [ 0, 1, 0, 0], + [-1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0, -1, 0]]) + + expAt = Matrix([ + [ cos(t), sin(t), 0, 0], + [-sin(t), cos(t), 0, 0], + [ 0, 0, cos(t), sin(t)], + [ 0, 0, -sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + A = Matrix([ + [ 0, 1, 1, 0], + [-1, 0, 0, 1], + [ 0, 0, 0, 1], + [ 0, 0, -1, 0]]) + + expAt = Matrix([ + [ cos(t), sin(t), t*cos(t), t*sin(t)], + [-sin(t), cos(t), -t*sin(t), t*cos(t)], + [ 0, 0, cos(t), sin(t)], + [ 0, 0, -sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + # This case is unacceptably slow right now but should be solvable... + #a, b, c, d, e, f = symbols('a b c d e f') + #A = Matrix([ + #[-a, b, c, d], + #[ a, -b, e, 0], + #[ 0, 0, -c - e - f, 0], + #[ 0, 0, f, -d]]) + + A = Matrix([[0, I], [I, 0]]) + expAt = Matrix([ + [exp(I*t)/2 + exp(-I*t)/2, exp(I*t)/2 - exp(-I*t)/2], + [exp(I*t)/2 - exp(-I*t)/2, exp(I*t)/2 + exp(-I*t)/2]]) + assert matrix_exp(A, t) == expAt + + # Testing Errors + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 7, 7]]) + M1 = Matrix([[t, 1], [1, 1]]) + + raises(ValueError, lambda: matrix_exp(M[:, :2], t)) + raises(ValueError, lambda: matrix_exp(M[:2, :], t)) + raises(ValueError, lambda: matrix_exp(M1, t)) + raises(ValueError, lambda: matrix_exp(M1[:1, :1], t)) + + +def test_canonical_odes(): + f, g, h = symbols('f g h', cls=Function) + x = symbols('x') + funcs = [f(x), g(x), h(x)] + + eqs1 = [Eq(f(x).diff(x, x), f(x) + 2*g(x)), Eq(g(x) + 1, g(x).diff(x) + f(x))] + sol1 = [[Eq(Derivative(f(x), (x, 2)), f(x) + 2*g(x)), Eq(Derivative(g(x), x), -f(x) + g(x) + 1)]] + assert canonical_odes(eqs1, funcs[:2], x) == sol1 + + eqs2 = [Eq(f(x).diff(x), h(x).diff(x) + f(x)), Eq(g(x).diff(x)**2, f(x) + h(x)), Eq(h(x).diff(x), f(x))] + sol2 = [[Eq(Derivative(f(x), x), 2*f(x)), Eq(Derivative(g(x), x), -sqrt(f(x) + h(x))), Eq(Derivative(h(x), x), f(x))], + [Eq(Derivative(f(x), x), 2*f(x)), Eq(Derivative(g(x), x), sqrt(f(x) + h(x))), Eq(Derivative(h(x), x), f(x))]] + assert canonical_odes(eqs2, funcs, x) == sol2 + + +def test_sysode_linear_neq_order1_type1(): + + f, g, x, y, h = symbols('f g x y h', cls=Function) + a, b, c, t = symbols('a b c t') + + eqs1 = [Eq(Derivative(x(t), t), x(t)), + Eq(Derivative(y(t), t), y(t))] + sol1 = [Eq(x(t), C1*exp(t)), + Eq(y(t), C2*exp(t))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(x(t), t), 2*x(t)), + Eq(Derivative(y(t), t), 3*y(t))] + sol2 = [Eq(x(t), C1*exp(2*t)), + Eq(y(t), C2*exp(3*t))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(x(t), t), a*x(t)), + Eq(Derivative(y(t), t), a*y(t))] + sol3 = [Eq(x(t), C1*exp(a*t)), + Eq(y(t), C2*exp(a*t))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + eqs4 = [Eq(Derivative(x(t), t), a*x(t)), + Eq(Derivative(y(t), t), b*y(t))] + sol4 = [Eq(x(t), C1*exp(a*t)), + Eq(y(t), C2*exp(b*t))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(Derivative(x(t), t), -y(t)), + Eq(Derivative(y(t), t), x(t))] + sol5 = [Eq(x(t), -C1*sin(t) - C2*cos(t)), + Eq(y(t), C1*cos(t) - C2*sin(t))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(x(t), t), -2*y(t)), + Eq(Derivative(y(t), t), 2*x(t))] + sol6 = [Eq(x(t), -C1*sin(2*t) - C2*cos(2*t)), + Eq(y(t), C1*cos(2*t) - C2*sin(2*t))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + eqs7 = [Eq(Derivative(x(t), t), I*y(t)), + Eq(Derivative(y(t), t), I*x(t))] + sol7 = [Eq(x(t), -C1*exp(-I*t) + C2*exp(I*t)), + Eq(y(t), C1*exp(-I*t) + C2*exp(I*t))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + eqs8 = [Eq(Derivative(x(t), t), -a*y(t)), + Eq(Derivative(y(t), t), a*x(t))] + sol8 = [Eq(x(t), -I*C1*exp(-I*a*t) + I*C2*exp(I*a*t)), + Eq(y(t), C1*exp(-I*a*t) + C2*exp(I*a*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + eqs9 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), x(t) - y(t))] + sol9 = [Eq(x(t), C1*(1 - sqrt(2))*exp(-sqrt(2)*t) + C2*(1 + sqrt(2))*exp(sqrt(2)*t)), + Eq(y(t), C1*exp(-sqrt(2)*t) + C2*exp(sqrt(2)*t))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0]) + + eqs10 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol10 = [Eq(x(t), -C1 + C2*exp(2*t)), + Eq(y(t), C1 + C2*exp(2*t))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + eqs11 = [Eq(Derivative(x(t), t), 2*x(t) + y(t)), + Eq(Derivative(y(t), t), -x(t) + 2*y(t))] + sol11 = [Eq(x(t), C1*exp(2*t)*sin(t) + C2*exp(2*t)*cos(t)), + Eq(y(t), C1*exp(2*t)*cos(t) - C2*exp(2*t)*sin(t))] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0]) + + eqs12 = [Eq(Derivative(x(t), t), x(t) + 2*y(t)), + Eq(Derivative(y(t), t), 2*x(t) + y(t))] + sol12 = [Eq(x(t), -C1*exp(-t) + C2*exp(3*t)), + Eq(y(t), C1*exp(-t) + C2*exp(3*t))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + eqs13 = [Eq(Derivative(x(t), t), 4*x(t) + y(t)), + Eq(Derivative(y(t), t), -x(t) + 2*y(t))] + sol13 = [Eq(x(t), C2*t*exp(3*t) + (C1 + C2)*exp(3*t)), + Eq(y(t), -C1*exp(3*t) - C2*t*exp(3*t))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + eqs14 = [Eq(Derivative(x(t), t), a*y(t)), + Eq(Derivative(y(t), t), a*x(t))] + sol14 = [Eq(x(t), -C1*exp(-a*t) + C2*exp(a*t)), + Eq(y(t), C1*exp(-a*t) + C2*exp(a*t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0]) + + eqs15 = [Eq(Derivative(x(t), t), a*y(t)), + Eq(Derivative(y(t), t), b*x(t))] + sol15 = [Eq(x(t), -C1*a*exp(-t*sqrt(a*b))/sqrt(a*b) + C2*a*exp(t*sqrt(a*b))/sqrt(a*b)), + Eq(y(t), C1*exp(-t*sqrt(a*b)) + C2*exp(t*sqrt(a*b)))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0]) + + eqs16 = [Eq(Derivative(x(t), t), a*x(t) + b*y(t)), + Eq(Derivative(y(t), t), c*x(t))] + sol16 = [Eq(x(t), -2*C1*b*exp(t*(a + sqrt(a**2 + 4*b*c))/2)/(a - sqrt(a**2 + 4*b*c)) - 2*C2*b*exp(t*(a - + sqrt(a**2 + 4*b*c))/2)/(a + sqrt(a**2 + 4*b*c))), + Eq(y(t), C1*exp(t*(a + sqrt(a**2 + 4*b*c))/2) + C2*exp(t*(a - sqrt(a**2 + 4*b*c))/2))] + assert dsolve(eqs16) == sol16 + assert checksysodesol(eqs16, sol16) == (True, [0, 0]) + + # Regression test case for issue #18562 + # https://github.com/sympy/sympy/issues/18562 + eqs17 = [Eq(Derivative(x(t), t), a*y(t) + x(t)), + Eq(Derivative(y(t), t), a*x(t) - y(t))] + sol17 = [Eq(x(t), C1*a*exp(t*sqrt(a**2 + 1))/(sqrt(a**2 + 1) - 1) - C2*a*exp(-t*sqrt(a**2 + 1))/(sqrt(a**2 + + 1) + 1)), + Eq(y(t), C1*exp(t*sqrt(a**2 + 1)) + C2*exp(-t*sqrt(a**2 + 1)))] + assert dsolve(eqs17) == sol17 + assert checksysodesol(eqs17, sol17) == (True, [0, 0]) + + eqs18 = [Eq(Derivative(x(t), t), 0), + Eq(Derivative(y(t), t), 0)] + sol18 = [Eq(x(t), C1), + Eq(y(t), C2)] + assert dsolve(eqs18) == sol18 + assert checksysodesol(eqs18, sol18) == (True, [0, 0]) + + eqs19 = [Eq(Derivative(x(t), t), 2*x(t) - y(t)), + Eq(Derivative(y(t), t), x(t))] + sol19 = [Eq(x(t), C2*t*exp(t) + (C1 + C2)*exp(t)), + Eq(y(t), C1*exp(t) + C2*t*exp(t))] + assert dsolve(eqs19) == sol19 + assert checksysodesol(eqs19, sol19) == (True, [0, 0]) + + eqs20 = [Eq(Derivative(x(t), t), x(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol20 = [Eq(x(t), C1*exp(t)), + Eq(y(t), C1*t*exp(t) + C2*exp(t))] + assert dsolve(eqs20) == sol20 + assert checksysodesol(eqs20, sol20) == (True, [0, 0]) + + eqs21 = [Eq(Derivative(x(t), t), 3*x(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol21 = [Eq(x(t), 2*C1*exp(3*t)), + Eq(y(t), C1*exp(3*t) + C2*exp(t))] + assert dsolve(eqs21) == sol21 + assert checksysodesol(eqs21, sol21) == (True, [0, 0]) + + eqs22 = [Eq(Derivative(x(t), t), 3*x(t)), + Eq(Derivative(y(t), t), y(t))] + sol22 = [Eq(x(t), C1*exp(3*t)), + Eq(y(t), C2*exp(t))] + assert dsolve(eqs22) == sol22 + assert checksysodesol(eqs22, sol22) == (True, [0, 0]) + + +@slow +def test_sysode_linear_neq_order1_type1_slow(): + + t = Symbol('t') + Z0 = Function('Z0') + Z1 = Function('Z1') + Z2 = Function('Z2') + Z3 = Function('Z3') + + k01, k10, k20, k21, k23, k30 = symbols('k01 k10 k20 k21 k23 k30') + + eqs1 = [Eq(Derivative(Z0(t), t), -k01*Z0(t) + k10*Z1(t) + k20*Z2(t) + k30*Z3(t)), + Eq(Derivative(Z1(t), t), k01*Z0(t) - k10*Z1(t) + k21*Z2(t)), + Eq(Derivative(Z2(t), t), (-k20 - k21 - k23)*Z2(t)), + Eq(Derivative(Z3(t), t), k23*Z2(t) - k30*Z3(t))] + sol1 = [Eq(Z0(t), C1*k10/k01 - C2*(k10 - k30)*exp(-k30*t)/(k01 + k10 - k30) - C3*(k10*(k20 + k21 - k30) - + k20**2 - k20*(k21 + k23 - k30) + k23*k30)*exp(-t*(k20 + k21 + k23))/(k23*(-k01 - k10 + k20 + k21 + + k23)) - C4*exp(-t*(k01 + k10))), + Eq(Z1(t), C1 - C2*k01*exp(-k30*t)/(k01 + k10 - k30) + C3*(-k01*(k20 + k21 - k30) + k20*k21 + k21**2 + + k21*(k23 - k30))*exp(-t*(k20 + k21 + k23))/(k23*(-k01 - k10 + k20 + k21 + k23)) + C4*exp(-t*(k01 + + k10))), + Eq(Z2(t), -C3*(k20 + k21 + k23 - k30)*exp(-t*(k20 + k21 + k23))/k23), + Eq(Z3(t), C2*exp(-k30*t) + C3*exp(-t*(k20 + k21 + k23)))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0, 0, 0]) + + x, y, z, u, v, w = symbols('x y z u v w', cls=Function) + k2, k3 = symbols('k2 k3') + a_b, a_c = symbols('a_b a_c', real=True) + + eqs2 = [Eq(Derivative(z(t), t), k2*y(t)), + Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t))] + sol2 = [Eq(z(t), C1 - C2*k2*exp(-t*(k2 + k3))/(k2 + k3)), + Eq(x(t), -C2*k3*exp(-t*(k2 + k3))/(k2 + k3) + C3), + Eq(y(t), C2*exp(-t*(k2 + k3)))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0, 0]) + + eqs3 = [4*u(t) - v(t) - 2*w(t) + Derivative(u(t), t), + 2*u(t) + v(t) - 2*w(t) + Derivative(v(t), t), + 5*u(t) + v(t) - 3*w(t) + Derivative(w(t), t)] + sol3 = [Eq(u(t), C3*exp(-2*t) + (C1/2 + sqrt(3)*C2/6)*cos(sqrt(3)*t) + sin(sqrt(3)*t)*(sqrt(3)*C1/6 + + C2*Rational(-1, 2))), + Eq(v(t), (C1/2 + sqrt(3)*C2/6)*cos(sqrt(3)*t) + sin(sqrt(3)*t)*(sqrt(3)*C1/6 + C2*Rational(-1, 2))), + Eq(w(t), C1*cos(sqrt(3)*t) - C2*sin(sqrt(3)*t) + C3*exp(-2*t))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0]) + + eqs4 = [Eq(Derivative(x(t), t), w(t)*Rational(-2, 9) + 2*x(t) + y(t) + z(t)*Rational(-8, 9)), + Eq(Derivative(y(t), t), w(t)*Rational(4, 9) + 2*y(t) + z(t)*Rational(16, 9)), + Eq(Derivative(z(t), t), w(t)*Rational(-2, 9) + z(t)*Rational(37, 9)), + Eq(Derivative(w(t), t), w(t)*Rational(44, 9) + z(t)*Rational(-4, 9))] + sol4 = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t)), + Eq(y(t), C2*exp(2*t) + 2*C3*exp(4*t)), + Eq(z(t), 2*C3*exp(4*t) + C4*exp(5*t)*Rational(-1, 4)), + Eq(w(t), C3*exp(4*t) + C4*exp(5*t))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq5 = [Eq(x(t).diff(t), x(t)), Eq(y(t).diff(t), y(t)), Eq(z(t).diff(t), z(t)), Eq(w(t).diff(t), w(t))] + sol5 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t)), Eq(w(t), C4*exp(t))] + assert dsolve(eq5) == sol5 + assert checksysodesol(eq5, sol5) == (True, [0, 0, 0, 0]) + + eqs6 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), y(t) + z(t)), + Eq(Derivative(z(t), t), w(t)*Rational(-1, 8) + z(t)), + Eq(Derivative(w(t), t), w(t)/2 + z(t)/2)] + sol6 = [Eq(x(t), C1*exp(t) + C2*t*exp(t) + 4*C4*t*exp(t*Rational(3, 4)) + (4*C3 + 48*C4)*exp(t*Rational(3, + 4))), + Eq(y(t), C2*exp(t) - C4*t*exp(t*Rational(3, 4)) - (C3 + 8*C4)*exp(t*Rational(3, 4))), + Eq(z(t), C4*t*exp(t*Rational(3, 4))/4 + (C3/4 + C4)*exp(t*Rational(3, 4))), + Eq(w(t), C3*exp(t*Rational(3, 4))/2 + C4*t*exp(t*Rational(3, 4))/2)] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq7 = [Eq(Derivative(x(t), t), x(t)), Eq(Derivative(y(t), t), y(t)), Eq(Derivative(z(t), t), z(t)), + Eq(Derivative(w(t), t), w(t)), Eq(Derivative(u(t), t), u(t))] + sol7 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t)), Eq(w(t), C4*exp(t)), + Eq(u(t), C5*exp(t))] + assert dsolve(eq7) == sol7 + assert checksysodesol(eq7, sol7) == (True, [0, 0, 0, 0, 0]) + + eqs8 = [Eq(Derivative(x(t), t), 2*x(t) + y(t)), + Eq(Derivative(y(t), t), 2*y(t)), + Eq(Derivative(z(t), t), 4*z(t)), + Eq(Derivative(w(t), t), u(t) + 5*w(t)), + Eq(Derivative(u(t), t), 5*u(t))] + sol8 = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t)), + Eq(y(t), C2*exp(2*t)), + Eq(z(t), C3*exp(4*t)), + Eq(w(t), C4*exp(5*t) + C5*t*exp(5*t)), + Eq(u(t), C5*exp(5*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq9 = [Eq(Derivative(x(t), t), x(t)), Eq(Derivative(y(t), t), y(t)), Eq(Derivative(z(t), t), z(t))] + sol9 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t))] + assert dsolve(eq9) == sol9 + assert checksysodesol(eq9, sol9) == (True, [0, 0, 0]) + + # Regression test case for issue #15407 + # https://github.com/sympy/sympy/issues/15407 + eqs10 = [Eq(Derivative(x(t), t), (-a_b - a_c)*x(t)), + Eq(Derivative(y(t), t), a_b*y(t)), + Eq(Derivative(z(t), t), a_c*x(t))] + sol10 = [Eq(x(t), -C1*(a_b + a_c)*exp(-t*(a_b + a_c))/a_c), + Eq(y(t), C2*exp(a_b*t)), + Eq(z(t), C1*exp(-t*(a_b + a_c)) + C3)] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0, 0]) + + # Regression test case for issue #14312 + # https://github.com/sympy/sympy/issues/14312 + eqs11 = [Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t)), + Eq(Derivative(z(t), t), k2*y(t))] + sol11 = [Eq(x(t), C1 + C2*k3*exp(-t*(k2 + k3))/k2), + Eq(y(t), -C2*(k2 + k3)*exp(-t*(k2 + k3))/k2), + Eq(z(t), C2*exp(-t*(k2 + k3)) + C3)] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0, 0]) + + # Regression test case for issue #14312 + # https://github.com/sympy/sympy/issues/14312 + eqs12 = [Eq(Derivative(z(t), t), k2*y(t)), + Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t))] + sol12 = [Eq(z(t), C1 - C2*k2*exp(-t*(k2 + k3))/(k2 + k3)), + Eq(x(t), -C2*k3*exp(-t*(k2 + k3))/(k2 + k3) + C3), + Eq(y(t), C2*exp(-t*(k2 + k3)))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0, 0]) + + f, g, h = symbols('f, g, h', cls=Function) + a, b, c = symbols('a, b, c') + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + eqs13 = [Eq(Derivative(f(t), t), 2*f(t) + g(t)), + Eq(Derivative(g(t), t), a*f(t))] + sol13 = [Eq(f(t), C1*exp(t*(sqrt(a + 1) + 1))/(sqrt(a + 1) - 1) - C2*exp(-t*(sqrt(a + 1) - 1))/(sqrt(a + 1) + + 1)), + Eq(g(t), C1*exp(t*(sqrt(a + 1) + 1)) + C2*exp(-t*(sqrt(a + 1) - 1)))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + eqs14 = [Eq(Derivative(f(t), t), 2*g(t) - 3*h(t)), + Eq(Derivative(g(t), t), -2*f(t) + 4*h(t)), + Eq(Derivative(h(t), t), 3*f(t) - 4*g(t))] + sol14 = [Eq(f(t), 2*C1 - sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(3, 25) + C3*Rational(-8, 25)) - + cos(sqrt(29)*t)*(C2*Rational(8, 25) + sqrt(29)*C3*Rational(3, 25))), + Eq(g(t), C1*Rational(3, 2) + sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(4, 25) + C3*Rational(6, 25)) - + cos(sqrt(29)*t)*(C2*Rational(6, 25) + sqrt(29)*C3*Rational(-4, 25))), + Eq(h(t), C1 + C2*cos(sqrt(29)*t) - C3*sin(sqrt(29)*t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0, 0]) + + eqs15 = [Eq(2*Derivative(f(t), t), 12*g(t) - 12*h(t)), + Eq(3*Derivative(g(t), t), -8*f(t) + 8*h(t)), + Eq(4*Derivative(h(t), t), 6*f(t) - 6*g(t))] + sol15 = [Eq(f(t), C1 - sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(6, 13) + C3*Rational(-16, 13)) - + cos(sqrt(29)*t)*(C2*Rational(16, 13) + sqrt(29)*C3*Rational(6, 13))), + Eq(g(t), C1 + sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(8, 39) + C3*Rational(16, 13)) - + cos(sqrt(29)*t)*(C2*Rational(16, 13) + sqrt(29)*C3*Rational(-8, 39))), + Eq(h(t), C1 + C2*cos(sqrt(29)*t) - C3*sin(sqrt(29)*t))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0, 0]) + + eq16 = (Eq(diff(x(t), t), 21*x(t)), Eq(diff(y(t), t), 17*x(t) + 3*y(t)), + Eq(diff(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))) + sol16 = [Eq(x(t), 216*C1*exp(21*t)/209), + Eq(y(t), 204*C1*exp(21*t)/209 - 6*C2*exp(3*t)/7), + Eq(z(t), C1*exp(21*t) + C2*exp(3*t) + C3*exp(9*t))] + assert dsolve(eq16) == sol16 + assert checksysodesol(eq16, sol16) == (True, [0, 0, 0]) + + eqs17 = [Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), + Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))] + sol17 = [Eq(x(t), C1*Rational(7, 3) - sin(sqrt(179)*t)*(sqrt(179)*C2*Rational(11, 170) + C3*Rational(-21, + 170)) - cos(sqrt(179)*t)*(C2*Rational(21, 170) + sqrt(179)*C3*Rational(11, 170))), + Eq(y(t), C1*Rational(11, 3) + sin(sqrt(179)*t)*(sqrt(179)*C2*Rational(7, 170) + C3*Rational(33, + 170)) - cos(sqrt(179)*t)*(C2*Rational(33, 170) + sqrt(179)*C3*Rational(-7, 170))), + Eq(z(t), C1 + C2*cos(sqrt(179)*t) - C3*sin(sqrt(179)*t))] + assert dsolve(eqs17) == sol17 + assert checksysodesol(eqs17, sol17) == (True, [0, 0, 0]) + + eqs18 = [Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), + Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), + Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))] + sol18 = [Eq(x(t), C1 - sin(5*sqrt(2)*t)*(sqrt(2)*C2*Rational(4, 3) - C3) - cos(5*sqrt(2)*t)*(C2 + + sqrt(2)*C3*Rational(4, 3))), + Eq(y(t), C1 + sin(5*sqrt(2)*t)*(sqrt(2)*C2*Rational(3, 4) + C3) - cos(5*sqrt(2)*t)*(C2 + + sqrt(2)*C3*Rational(-3, 4))), + Eq(z(t), C1 + C2*cos(5*sqrt(2)*t) - C3*sin(5*sqrt(2)*t))] + assert dsolve(eqs18) == sol18 + assert checksysodesol(eqs18, sol18) == (True, [0, 0, 0]) + + eqs19 = [Eq(Derivative(x(t), t), 4*x(t) - z(t)), + Eq(Derivative(y(t), t), 2*x(t) + 2*y(t) - z(t)), + Eq(Derivative(z(t), t), 3*x(t) + y(t))] + sol19 = [Eq(x(t), C2*t**2*exp(2*t)/2 + t*(2*C2 + C3)*exp(2*t) + (C1 + C2 + 2*C3)*exp(2*t)), + Eq(y(t), C2*t**2*exp(2*t)/2 + t*(2*C2 + C3)*exp(2*t) + (C1 + 2*C3)*exp(2*t)), + Eq(z(t), C2*t**2*exp(2*t) + t*(3*C2 + 2*C3)*exp(2*t) + (2*C1 + 3*C3)*exp(2*t))] + assert dsolve(eqs19) == sol19 + assert checksysodesol(eqs19, sol19) == (True, [0, 0, 0]) + + eqs20 = [Eq(Derivative(x(t), t), 4*x(t) - y(t) - 2*z(t)), + Eq(Derivative(y(t), t), 2*x(t) + y(t) - 2*z(t)), + Eq(Derivative(z(t), t), 5*x(t) - 3*z(t))] + sol20 = [Eq(x(t), C1*exp(2*t) - sin(t)*(C2*Rational(3, 5) + C3/5) - cos(t)*(C2/5 + C3*Rational(-3, 5))), + Eq(y(t), -sin(t)*(C2*Rational(3, 5) + C3/5) - cos(t)*(C2/5 + C3*Rational(-3, 5))), + Eq(z(t), C1*exp(2*t) - C2*sin(t) + C3*cos(t))] + assert dsolve(eqs20) == sol20 + assert checksysodesol(eqs20, sol20) == (True, [0, 0, 0]) + + eq21 = (Eq(diff(x(t), t), 9*y(t)), Eq(diff(y(t), t), 12*x(t))) + sol21 = [Eq(x(t), -sqrt(3)*C1*exp(-6*sqrt(3)*t)/2 + sqrt(3)*C2*exp(6*sqrt(3)*t)/2), + Eq(y(t), C1*exp(-6*sqrt(3)*t) + C2*exp(6*sqrt(3)*t))] + + assert dsolve(eq21) == sol21 + assert checksysodesol(eq21, sol21) == (True, [0, 0]) + + eqs22 = [Eq(Derivative(x(t), t), 2*x(t) + 4*y(t)), + Eq(Derivative(y(t), t), 12*x(t) + 41*y(t))] + sol22 = [Eq(x(t), C1*(39 - sqrt(1713))*exp(t*(sqrt(1713) + 43)/2)*Rational(-1, 24) + C2*(39 + + sqrt(1713))*exp(t*(43 - sqrt(1713))/2)*Rational(-1, 24)), + Eq(y(t), C1*exp(t*(sqrt(1713) + 43)/2) + C2*exp(t*(43 - sqrt(1713))/2))] + assert dsolve(eqs22) == sol22 + assert checksysodesol(eqs22, sol22) == (True, [0, 0]) + + eqs23 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), -2*x(t) + 2*y(t))] + sol23 = [Eq(x(t), (C1/4 + sqrt(7)*C2/4)*cos(sqrt(7)*t/2)*exp(t*Rational(3, 2)) + + sin(sqrt(7)*t/2)*(sqrt(7)*C1/4 + C2*Rational(-1, 4))*exp(t*Rational(3, 2))), + Eq(y(t), C1*cos(sqrt(7)*t/2)*exp(t*Rational(3, 2)) - C2*sin(sqrt(7)*t/2)*exp(t*Rational(3, 2)))] + assert dsolve(eqs23) == sol23 + assert checksysodesol(eqs23, sol23) == (True, [0, 0]) + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + a = Symbol("a", real=True) + eq24 = [x(t).diff(t) - a*y(t), y(t).diff(t) + a*x(t)] + sol24 = [Eq(x(t), C1*sin(a*t) + C2*cos(a*t)), Eq(y(t), C1*cos(a*t) - C2*sin(a*t))] + assert dsolve(eq24) == sol24 + assert checksysodesol(eq24, sol24) == (True, [0, 0]) + + # Regression test case for issue #19150 + # https://github.com/sympy/sympy/issues/19150 + eqs25 = [Eq(Derivative(f(t), t), 0), + Eq(Derivative(g(t), t), (f(t) - 2*g(t) + x(t))/(b*c)), + Eq(Derivative(x(t), t), (g(t) - 2*x(t) + y(t))/(b*c)), + Eq(Derivative(y(t), t), (h(t) + x(t) - 2*y(t))/(b*c)), + Eq(Derivative(h(t), t), 0)] + sol25 = [Eq(f(t), -3*C1 + 4*C2), + Eq(g(t), -2*C1 + 3*C2 - C3*exp(-2*t/(b*c)) + C4*exp(-t*(sqrt(2) + 2)/(b*c)) + C5*exp(-t*(2 - + sqrt(2))/(b*c))), + Eq(x(t), -C1 + 2*C2 - sqrt(2)*C4*exp(-t*(sqrt(2) + 2)/(b*c)) + sqrt(2)*C5*exp(-t*(2 - + sqrt(2))/(b*c))), + Eq(y(t), C2 + C3*exp(-2*t/(b*c)) + C4*exp(-t*(sqrt(2) + 2)/(b*c)) + C5*exp(-t*(2 - sqrt(2))/(b*c))), + Eq(h(t), C1)] + assert dsolve(eqs25) == sol25 + assert checksysodesol(eqs25, sol25) == (True, [0, 0, 0, 0, 0]) + + eq26 = [Eq(Derivative(f(t), t), 2*f(t)), Eq(Derivative(g(t), t), 3*f(t) + 7*g(t))] + sol26 = [Eq(f(t), -5*C1*exp(2*t)/3), Eq(g(t), C1*exp(2*t) + C2*exp(7*t))] + assert dsolve(eq26) == sol26 + assert checksysodesol(eq26, sol26) == (True, [0, 0]) + + eq27 = [Eq(Derivative(f(t), t), -9*I*f(t) - 4*g(t)), Eq(Derivative(g(t), t), -4*I*g(t))] + sol27 = [Eq(f(t), 4*I*C1*exp(-4*I*t)/5 + C2*exp(-9*I*t)), Eq(g(t), C1*exp(-4*I*t))] + assert dsolve(eq27) == sol27 + assert checksysodesol(eq27, sol27) == (True, [0, 0]) + + eq28 = [Eq(Derivative(f(t), t), -9*I*f(t)), Eq(Derivative(g(t), t), -4*I*g(t))] + sol28 = [Eq(f(t), C1*exp(-9*I*t)), Eq(g(t), C2*exp(-4*I*t))] + assert dsolve(eq28) == sol28 + assert checksysodesol(eq28, sol28) == (True, [0, 0]) + + eq29 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), 0)] + sol29 = [Eq(f(t), C1), Eq(g(t), C2)] + assert dsolve(eq29) == sol29 + assert checksysodesol(eq29, sol29) == (True, [0, 0]) + + eq30 = [Eq(Derivative(f(t), t), f(t)), Eq(Derivative(g(t), t), 0)] + sol30 = [Eq(f(t), C1*exp(t)), Eq(g(t), C2)] + assert dsolve(eq30) == sol30 + assert checksysodesol(eq30, sol30) == (True, [0, 0]) + + eq31 = [Eq(Derivative(f(t), t), g(t)), Eq(Derivative(g(t), t), 0)] + sol31 = [Eq(f(t), C1 + C2*t), Eq(g(t), C2)] + assert dsolve(eq31) == sol31 + assert checksysodesol(eq31, sol31) == (True, [0, 0]) + + eq32 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), f(t))] + sol32 = [Eq(f(t), C1), Eq(g(t), C1*t + C2)] + assert dsolve(eq32) == sol32 + assert checksysodesol(eq32, sol32) == (True, [0, 0]) + + eq33 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), g(t))] + sol33 = [Eq(f(t), C1), Eq(g(t), C2*exp(t))] + assert dsolve(eq33) == sol33 + assert checksysodesol(eq33, sol33) == (True, [0, 0]) + + eq34 = [Eq(Derivative(f(t), t), f(t)), Eq(Derivative(g(t), t), I*g(t))] + sol34 = [Eq(f(t), C1*exp(t)), Eq(g(t), C2*exp(I*t))] + assert dsolve(eq34) == sol34 + assert checksysodesol(eq34, sol34) == (True, [0, 0]) + + eq35 = [Eq(Derivative(f(t), t), I*f(t)), Eq(Derivative(g(t), t), -I*g(t))] + sol35 = [Eq(f(t), C1*exp(I*t)), Eq(g(t), C2*exp(-I*t))] + assert dsolve(eq35) == sol35 + assert checksysodesol(eq35, sol35) == (True, [0, 0]) + + eq36 = [Eq(Derivative(f(t), t), I*g(t)), Eq(Derivative(g(t), t), 0)] + sol36 = [Eq(f(t), I*C1 + I*C2*t), Eq(g(t), C2)] + assert dsolve(eq36) == sol36 + assert checksysodesol(eq36, sol36) == (True, [0, 0]) + + eq37 = [Eq(Derivative(f(t), t), I*g(t)), Eq(Derivative(g(t), t), I*f(t))] + sol37 = [Eq(f(t), -C1*exp(-I*t) + C2*exp(I*t)), Eq(g(t), C1*exp(-I*t) + C2*exp(I*t))] + assert dsolve(eq37) == sol37 + assert checksysodesol(eq37, sol37) == (True, [0, 0]) + + # Multiple systems + eq1 = [Eq(Derivative(f(t), t)**2, g(t)**2), Eq(-f(t) + Derivative(g(t), t), 0)] + sol1 = [[Eq(f(t), -C1*sin(t) - C2*cos(t)), + Eq(g(t), C1*cos(t) - C2*sin(t))], + [Eq(f(t), -C1*exp(-t) + C2*exp(t)), + Eq(g(t), C1*exp(-t) + C2*exp(t))]] + assert dsolve(eq1) == sol1 + for sol in sol1: + assert checksysodesol(eq1, sol) == (True, [0, 0]) + + +def test_sysode_linear_neq_order1_type2(): + + f, g, h, k = symbols('f g h k', cls=Function) + x, t, a, b, c, d, y = symbols('x t a b c d y') + k1, k2 = symbols('k1 k2') + + + eqs1 = [Eq(Derivative(f(x), x), f(x) + g(x) + 5), + Eq(Derivative(g(x), x), -f(x) - g(x) + 7)] + sol1 = [Eq(f(x), C1 + C2 + 6*x**2 + x*(C2 + 5)), + Eq(g(x), -C1 - 6*x**2 - x*(C2 - 7))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(f(x), x), f(x) + g(x) + 5), + Eq(Derivative(g(x), x), f(x) + g(x) + 7)] + sol2 = [Eq(f(x), -C1 + C2*exp(2*x) - x - 3), + Eq(g(x), C1 + C2*exp(2*x) + x - 3)] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), x), f(x) + 5), + Eq(Derivative(g(x), x), f(x) + 7)] + sol3 = [Eq(f(x), C1*exp(x) - 5), + Eq(g(x), C1*exp(x) + C2 + 2*x - 5)] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(Derivative(f(x), x), f(x) + exp(x)), + Eq(Derivative(g(x), x), x*exp(x) + f(x) + g(x))] + sol4 = [Eq(f(x), C1*exp(x) + x*exp(x)), + Eq(g(x), C1*x*exp(x) + C2*exp(x) + x**2*exp(x))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), 5*x + f(x) + g(x)), + Eq(Derivative(g(x), x), f(x) - g(x))] + sol5 = [Eq(f(x), C1*(1 + sqrt(2))*exp(sqrt(2)*x) + C2*(1 - sqrt(2))*exp(-sqrt(2)*x) + x*Rational(-5, 2) + + Rational(-5, 2)), + Eq(g(x), C1*exp(sqrt(2)*x) + C2*exp(-sqrt(2)*x) + x*Rational(-5, 2))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), -9*f(x) - 4*g(x)), + Eq(Derivative(g(x), x), -4*g(x)), + Eq(Derivative(h(x), x), h(x) + exp(x))] + sol6 = [Eq(f(x), C2*exp(-4*x)*Rational(-4, 5) + C1*exp(-9*x)), + Eq(g(x), C2*exp(-4*x)), + Eq(h(x), C3*exp(x) + x*exp(x))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0]) + + # Regression test case for issue #8859 + # https://github.com/sympy/sympy/issues/8859 + eqs7 = [Eq(Derivative(f(t), t), 3*t + f(t)), + Eq(Derivative(g(t), t), g(t))] + sol7 = [Eq(f(t), C1*exp(t) - 3*t - 3), + Eq(g(t), C2*exp(t))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + # Regression test case for issue #8567 + # https://github.com/sympy/sympy/issues/8567 + eqs8 = [Eq(Derivative(f(t), t), f(t) + 2*g(t)), + Eq(Derivative(g(t), t), -2*f(t) + g(t) + 2*exp(t))] + sol8 = [Eq(f(t), C1*exp(t)*sin(2*t) + C2*exp(t)*cos(2*t) + + exp(t)*sin(2*t)**2 + exp(t)*cos(2*t)**2), + Eq(g(t), C1*exp(t)*cos(2*t) - C2*exp(t)*sin(2*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + # Regression test case for issue #19150 + # https://github.com/sympy/sympy/issues/19150 + eqs9 = [Eq(Derivative(f(t), t), (c - 2*f(t) + g(t))/(a*b)), + Eq(Derivative(g(t), t), (f(t) - 2*g(t) + h(t))/(a*b)), + Eq(Derivative(h(t), t), (d + g(t) - 2*h(t))/(a*b))] + sol9 = [Eq(f(t), -C1*exp(-2*t/(a*b)) + C2*exp(-t*(sqrt(2) + 2)/(a*b)) + C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 4), 3*c + d, evaluate=False)), + Eq(g(t), -sqrt(2)*C2*exp(-t*(sqrt(2) + 2)/(a*b)) + sqrt(2)*C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 2), c + d, evaluate=False)), + Eq(h(t), C1*exp(-2*t/(a*b)) + C2*exp(-t*(sqrt(2) + 2)/(a*b)) + C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 4), c + 3*d, evaluate=False))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0, 0]) + + # Regression test case for issue #16635 + # https://github.com/sympy/sympy/issues/16635 + eqs10 = [Eq(Derivative(f(t), t), 15*t + f(t) - g(t) - 10), + Eq(Derivative(g(t), t), -15*t + f(t) - g(t) - 5)] + sol10 = [Eq(f(t), C1 + C2 + 5*t**3 + 5*t**2 + t*(C2 - 10)), + Eq(g(t), C1 + 5*t**3 - 10*t**2 + t*(C2 - 5))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + # Multiple solutions + eqs11 = [Eq(Derivative(f(t), t)**2 - 2*Derivative(f(t), t) + 1, 4), + Eq(-y*f(t) + Derivative(g(t), t), 0)] + sol11 = [[Eq(f(t), C1 - t), Eq(g(t), C1*t*y + C2*y + t**2*y*Rational(-1, 2))], + [Eq(f(t), C1 + 3*t), Eq(g(t), C1*t*y + C2*y + t**2*y*Rational(3, 2))]] + assert dsolve(eqs11) == sol11 + for s11 in sol11: + assert checksysodesol(eqs11, s11) == (True, [0, 0]) + + # test case for issue #19831 + # https://github.com/sympy/sympy/issues/19831 + n = symbols('n', positive=True) + x0 = symbols('x_0') + t0 = symbols('t_0') + x_0 = symbols('x_0') + t_0 = symbols('t_0') + t = symbols('t') + x = Function('x') + y = Function('y') + T = symbols('T') + + eqs12 = [Eq(Derivative(y(t), t), x(t)), + Eq(Derivative(x(t), t), n*(y(t) + 1))] + sol12 = [Eq(y(t), C1*exp(sqrt(n)*t)*n**Rational(-1, 2) - C2*exp(-sqrt(n)*t)*n**Rational(-1, 2) - 1), + Eq(x(t), C1*exp(sqrt(n)*t) + C2*exp(-sqrt(n)*t))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + sol12b = [ + Eq(y(t), (T*exp(-sqrt(n)*t_0)/2 + exp(-sqrt(n)*t_0)/2 + + x_0*exp(-sqrt(n)*t_0)/(2*sqrt(n)))*exp(sqrt(n)*t) + + (T*exp(sqrt(n)*t_0)/2 + exp(sqrt(n)*t_0)/2 - + x_0*exp(sqrt(n)*t_0)/(2*sqrt(n)))*exp(-sqrt(n)*t) - 1), + Eq(x(t), (T*sqrt(n)*exp(-sqrt(n)*t_0)/2 + sqrt(n)*exp(-sqrt(n)*t_0)/2 + + x_0*exp(-sqrt(n)*t_0)/2)*exp(sqrt(n)*t) + - (T*sqrt(n)*exp(sqrt(n)*t_0)/2 + sqrt(n)*exp(sqrt(n)*t_0)/2 - + x_0*exp(sqrt(n)*t_0)/2)*exp(-sqrt(n)*t)) + ] + assert dsolve(eqs12, ics={y(t0): T, x(t0): x0}) == sol12b + assert checksysodesol(eqs12, sol12b) == (True, [0, 0]) + + #Test cases added for the issue 19763 + #https://github.com/sympy/sympy/issues/19763 + + eq13 = [Eq(Derivative(f(t), t), f(t) + g(t) + 9), + Eq(Derivative(g(t), t), 2*f(t) + 5*g(t) + 23)] + sol13 = [Eq(f(t), -C1*(2 + sqrt(6))*exp(t*(3 - sqrt(6)))/2 - C2*(2 - sqrt(6))*exp(t*(sqrt(6) + 3))/2 - + Rational(22,3)), + Eq(g(t), C1*exp(t*(3 - sqrt(6))) + C2*exp(t*(sqrt(6) + 3)) - Rational(5,3))] + assert dsolve(eq13) == sol13 + assert checksysodesol(eq13, sol13) == (True, [0, 0]) + + eq14 = [Eq(Derivative(f(t), t), f(t) + g(t) + 81), + Eq(Derivative(g(t), t), -2*f(t) + g(t) + 23)] + sol14 = [Eq(f(t), sqrt(2)*C1*exp(t)*sin(sqrt(2)*t)/2 + + sqrt(2)*C2*exp(t)*cos(sqrt(2)*t)/2 + - 58*sin(sqrt(2)*t)**2/3 - 58*cos(sqrt(2)*t)**2/3), + Eq(g(t), C1*exp(t)*cos(sqrt(2)*t) - C2*exp(t)*sin(sqrt(2)*t) + - 185*sin(sqrt(2)*t)**2/3 - 185*cos(sqrt(2)*t)**2/3)] + assert dsolve(eq14) == sol14 + assert checksysodesol(eq14, sol14) == (True, [0,0]) + + eq15 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), 3*f(t) + 4*g(t) + k2)] + sol15 = [Eq(f(t), -C1*(3 - sqrt(33))*exp(t*(5 + sqrt(33))/2)/6 - + C2*(3 + sqrt(33))*exp(t*(5 - sqrt(33))/2)/6 + 2*k1 - k2), + Eq(g(t), C1*exp(t*(5 + sqrt(33))/2) + C2*exp(t*(5 - sqrt(33))/2) - + Mul(Rational(1,2), 3*k1 - k2, evaluate = False))] + assert dsolve(eq15) == sol15 + assert checksysodesol(eq15, sol15) == (True, [0,0]) + + eq16 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), k2)] + sol16 = [Eq(f(t), C1 + k1*t), + Eq(g(t), C2 + k2*t)] + assert dsolve(eq16) == sol16 + assert checksysodesol(eq16, sol16) == (True, [0,0]) + + eq17 = [Eq(Derivative(f(t), t), 0), + Eq(Derivative(g(t), t), c*f(t) + k2)] + sol17 = [Eq(f(t), C1), + Eq(g(t), C2*c + t*(C1*c + k2))] + assert dsolve(eq17) == sol17 + assert checksysodesol(eq17 , sol17) == (True , [0,0]) + + eq18 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), f(t) + k2)] + sol18 = [Eq(f(t), C1 + k1*t), + Eq(g(t), C2 + k1*t**2/2 + t*(C1 + k2))] + assert dsolve(eq18) == sol18 + assert checksysodesol(eq18 , sol18) == (True , [0,0]) + + eq19 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), f(t) + 2*g(t) + k2)] + sol19 = [Eq(f(t), -2*C1 + k1*t), + Eq(g(t), C1 + C2*exp(2*t) - k1*t/2 - Mul(Rational(1,4), k1 + 2*k2 , evaluate = False))] + assert dsolve(eq19) == sol19 + assert checksysodesol(eq19 , sol19) == (True , [0,0]) + + eq20 = [Eq(diff(f(t), t), f(t) + k1), + Eq(diff(g(t), t), k2)] + sol20 = [Eq(f(t), C1*exp(t) - k1), + Eq(g(t), C2 + k2*t)] + assert dsolve(eq20) == sol20 + assert checksysodesol(eq20 , sol20) == (True , [0,0]) + + eq21 = [Eq(diff(f(t), t), g(t) + k1), + Eq(diff(g(t), t), 0)] + sol21 = [Eq(f(t), C1 + t*(C2 + k1)), + Eq(g(t), C2)] + assert dsolve(eq21) == sol21 + assert checksysodesol(eq21 , sol21) == (True , [0,0]) + + eq22 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), k2)] + sol22 = [Eq(f(t), -2*C1 + C2*exp(t) - k1 - 2*k2*t - 2*k2), + Eq(g(t), C1 + k2*t)] + assert dsolve(eq22) == sol22 + assert checksysodesol(eq22 , sol22) == (True , [0,0]) + + eq23 = [Eq(Derivative(f(t), t), g(t) + k1), + Eq(Derivative(g(t), t), 2*g(t) + k2)] + sol23 = [Eq(f(t), C1 + C2*exp(2*t)/2 - k2/4 + t*(2*k1 - k2)/2), + Eq(g(t), C2*exp(2*t) - k2/2)] + assert dsolve(eq23) == sol23 + assert checksysodesol(eq23 , sol23) == (True , [0,0]) + + eq24 = [Eq(Derivative(f(t), t), f(t) + k1), + Eq(Derivative(g(t), t), 2*f(t) + k2)] + sol24 = [Eq(f(t), C1*exp(t)/2 - k1), + Eq(g(t), C1*exp(t) + C2 - 2*k1 - t*(2*k1 - k2))] + assert dsolve(eq24) == sol24 + assert checksysodesol(eq24 , sol24) == (True , [0,0]) + + eq25 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), 3*f(t) + 6*g(t) + k2)] + sol25 = [Eq(f(t), -2*C1 + C2*exp(7*t)/3 + 2*t*(3*k1 - k2)/7 - + Mul(Rational(1,49), k1 + 2*k2 , evaluate = False)), + Eq(g(t), C1 + C2*exp(7*t) - t*(3*k1 - k2)/7 - + Mul(Rational(3,49), k1 + 2*k2 , evaluate = False))] + assert dsolve(eq25) == sol25 + assert checksysodesol(eq25 , sol25) == (True , [0,0]) + + eq26 = [Eq(Derivative(f(t), t), 2*f(t) - g(t) + k1), + Eq(Derivative(g(t), t), 4*f(t) - 2*g(t) + 2*k1)] + sol26 = [Eq(f(t), C1 + 2*C2 + t*(2*C1 + k1)), + Eq(g(t), 4*C2 + t*(4*C1 + 2*k1))] + assert dsolve(eq26) == sol26 + assert checksysodesol(eq26 , sol26) == (True , [0,0]) + + # Test Case added for issue #22715 + # https://github.com/sympy/sympy/issues/22715 + + eq27 = [Eq(diff(x(t),t),-1*y(t)+10), Eq(diff(y(t),t),5*x(t)-2*y(t)+3)] + sol27 = [Eq(x(t), (C1/5 - 2*C2/5)*exp(-t)*cos(2*t) + - (2*C1/5 + C2/5)*exp(-t)*sin(2*t) + + 17*sin(2*t)**2/5 + 17*cos(2*t)**2/5), + Eq(y(t), C1*exp(-t)*cos(2*t) - C2*exp(-t)*sin(2*t) + + 10*sin(2*t)**2 + 10*cos(2*t)**2)] + assert dsolve(eq27) == sol27 + assert checksysodesol(eq27 , sol27) == (True , [0,0]) + + +def test_sysode_linear_neq_order1_type3(): + + f, g, h, k, x0 , y0 = symbols('f g h k x0 y0', cls=Function) + x, t, a = symbols('x t a') + r = symbols('r', real=True) + + eqs1 = [Eq(Derivative(f(r), r), r*g(r) + f(r)), + Eq(Derivative(g(r), r), -r*f(r) + g(r))] + sol1 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(f(x), x), x**2*g(x) + x*f(x)), + Eq(Derivative(g(x), x), 2*x**2*f(x) + (3*x**2 + x)*g(x))] + sol2 = [Eq(f(x), (sqrt(17)*C1/17 + C2*(17 - 3*sqrt(17))/34)*exp(x**3*(3 + sqrt(17))/6 + x**2/2) - + exp(x**3*(3 - sqrt(17))/6 + x**2/2)*(sqrt(17)*C1/17 + C2*(3*sqrt(17) + 17)*Rational(-1, 34))), + Eq(g(x), exp(x**3*(3 - sqrt(17))/6 + x**2/2)*(C1*(17 - 3*sqrt(17))/34 + sqrt(17)*C2*Rational(-2, + 17)) + exp(x**3*(3 + sqrt(17))/6 + x**2/2)*(C1*(3*sqrt(17) + 17)/34 + sqrt(17)*C2*Rational(2, 17)))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(f(x).diff(x), x*f(x) + g(x)), + Eq(g(x).diff(x), -f(x) + x*g(x))] + sol3 = [Eq(f(x), (C1/2 + I*C2/2)*exp(x**2/2 - I*x) + exp(x**2/2 + I*x)*(C1/2 + I*C2*Rational(-1, 2))), + Eq(g(x), (I*C1/2 + C2/2)*exp(x**2/2 + I*x) - exp(x**2/2 - I*x)*(I*C1/2 + C2*Rational(-1, 2)))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(f(x).diff(x), x*(f(x) + g(x) + h(x))), Eq(g(x).diff(x), x*(f(x) + g(x) + h(x))), + Eq(h(x).diff(x), x*(f(x) + g(x) + h(x)))] + sol4 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0]) + + eqs5 = [Eq(f(x).diff(x), x**2*(f(x) + g(x) + h(x))), Eq(g(x).diff(x), x**2*(f(x) + g(x) + h(x))), + Eq(h(x).diff(x), x**2*(f(x) + g(x) + h(x)))] + sol5 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(k(x), x), x*(f(x) + g(x) + h(x) + k(x)))] + sol6 = [Eq(f(x), -C1/4 - C2/4 - C3/4 + 3*C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(g(x), 3*C1/4 - C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(h(x), -C1/4 + 3*C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(k(x), -C1/4 - C2/4 + 3*C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + y = symbols("y", real=True) + + eqs7 = [Eq(Derivative(f(y), y), y*f(y) + g(y)), + Eq(Derivative(g(y), y), y*g(y) - f(y))] + sol7 = [Eq(f(y), C1*exp(y**2/2)*sin(y) + C2*exp(y**2/2)*cos(y)), + Eq(g(y), C1*exp(y**2/2)*cos(y) - C2*exp(y**2/2)*sin(y))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + #Test cases added for the issue 19763 + #https://github.com/sympy/sympy/issues/19763 + + eqs8 = [Eq(Derivative(f(t), t), 5*t*f(t) + 2*h(t)), + Eq(Derivative(h(t), t), 2*f(t) + 5*t*h(t))] + sol8 = [Eq(f(t), Mul(-1, (C1/2 - C2/2), evaluate = False)*exp(5*t**2/2 - 2*t) + (C1/2 + C2/2)*exp(5*t**2/2 + 2*t)), + Eq(h(t), (C1/2 - C2/2)*exp(5*t**2/2 - 2*t) + (C1/2 + C2/2)*exp(5*t**2/2 + 2*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + eqs9 = [Eq(diff(f(t), t), 5*t*f(t) + t**2*g(t)), + Eq(diff(g(t), t), -t**2*f(t) + 5*t*g(t))] + sol9 = [Eq(f(t), (C1/2 - I*C2/2)*exp(I*t**3/3 + 5*t**2/2) + (C1/2 + I*C2/2)*exp(-I*t**3/3 + 5*t**2/2)), + Eq(g(t), Mul(-1, (I*C1/2 - C2/2) , evaluate = False)*exp(-I*t**3/3 + 5*t**2/2) + (I*C1/2 + C2/2)*exp(I*t**3/3 + 5*t**2/2))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9 , sol9) == (True , [0,0]) + + eqs10 = [Eq(diff(f(t), t), t**2*g(t) + 5*t*f(t)), + Eq(diff(g(t), t), -t**2*f(t) + (9*t**2 + 5*t)*g(t))] + sol10 = [Eq(f(t), (C1*(77 - 9*sqrt(77))/154 + sqrt(77)*C2/77)*exp(t**3*(sqrt(77) + 9)/6 + 5*t**2/2) + (C1*(77 + 9*sqrt(77))/154 - sqrt(77)*C2/77)*exp(t**3*(9 - sqrt(77))/6 + 5*t**2/2)), + Eq(g(t), (sqrt(77)*C1/77 + C2*(77 - 9*sqrt(77))/154)*exp(t**3*(9 - sqrt(77))/6 + 5*t**2/2) - (sqrt(77)*C1/77 - C2*(77 + 9*sqrt(77))/154)*exp(t**3*(sqrt(77) + 9)/6 + 5*t**2/2))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10 , sol10) == (True , [0,0]) + + eqs11 = [Eq(diff(f(t), t), 5*t*f(t) + t**2*g(t)), + Eq(diff(g(t), t), (1-t**2)*f(t) + (5*t + 9*t**2)*g(t))] + sol11 = [Eq(f(t), C1*x0(t) + C2*x0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t)), + Eq(g(t), C1*y0(t) + C2*(y0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t) + exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)))] + assert dsolve(eqs11) == sol11 + +@slow +def test_sysode_linear_neq_order1_type4(): + + f, g, h, k = symbols('f g h k', cls=Function) + x, t, a = symbols('x t a') + r = symbols('r', real=True) + + eqs1 = [Eq(diff(f(r), r), f(r) + r*g(r) + r**2), Eq(diff(g(r), r), -r*f(r) + g(r) + r)] + sol1 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2) + exp(r)*sin(r**2/2)*Integral(r**2*exp(-r)*sin(r**2/2) + + r*exp(-r)*cos(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r**2*exp(-r)*cos(r**2/2) - r*exp(-r)*sin(r**2/2), r)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2) - exp(r)*sin(r**2/2)*Integral(r**2*exp(-r)*cos(r**2/2) - + r*exp(-r)*sin(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r**2*exp(-r)*sin(r**2/2) + r*exp(-r)*cos(r**2/2), r))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(diff(f(r), r), f(r) + r*g(r) + r), Eq(diff(g(r), r), -r*f(r) + g(r) + log(r))] + sol2 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2) + exp(r)*sin(r**2/2)*Integral(r*exp(-r)*sin(r**2/2) + + exp(-r)*log(r)*cos(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r*exp(-r)*cos(r**2/2) - exp(-r)*log(r)*sin( + r**2/2), r)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2) - exp(r)*sin(r**2/2)*Integral(r*exp(-r)*cos(r**2/2) - + exp(-r)*log(r)*sin(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r*exp(-r)*sin(r**2/2) + exp(-r)*log(r)*cos( + r**2/2), r))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs2, simplify=False, doit=False) == [sol2] + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x)) + x), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x)) + x), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x)) + 1)] + sol3 = [Eq(f(x), C1*Rational(-1, 3) + C2*Rational(-1, 3) + C3*Rational(2, 3) + x**2/6 + x*Rational(-1, 3) + + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9)), + Eq(g(x), C1*Rational(2, 3) + C2*Rational(-1, 3) + C3*Rational(-1, 3) + x**2/6 + x*Rational(-1, 3) + + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9)), + Eq(h(x), C1*Rational(-1, 3) + C2*Rational(2, 3) + C3*Rational(-1, 3) + x**2*Rational(-1, 3) + + x*Rational(2, 3) + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0]) + + eqs4 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x)) + sin(x)), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x)) + sin(x)), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x)) + sin(x))] + sol4 = [Eq(f(x), C1*Rational(-1, 3) + C2*Rational(-1, 3) + C3*Rational(2, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2))), + Eq(g(x), C1*Rational(2, 3) + C2*Rational(-1, 3) + C3*Rational(-1, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2))), + Eq(h(x), C1*Rational(-1, 3) + C2*Rational(2, 3) + C3*Rational(-1, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2)))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(k(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1))] + sol5 = [Eq(f(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(3, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(g(x), C1*Rational(3, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(h(x), C1*Rational(-1, 4) + C2*Rational(3, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(k(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(3, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(g(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(h(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(k(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1))] + sol6 = [Eq(f(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(3, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(g(x), C1*Rational(3, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(h(x), C1*Rational(-1, 4) + C2*Rational(3, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(k(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(3, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + eqs7 = [Eq(Derivative(f(x), x), (f(x) + g(x) + h(x))*log(x) + sin(x)), Eq(Derivative(g(x), x), (f(x) + g(x) + + h(x))*log(x) + sin(x)), Eq(Derivative(h(x), x), (f(x) + g(x) + h(x))*log(x) + sin(x))] + sol7 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x))] + with dotprodsimp(True): + assert dsolve(eqs7, simplify=False, doit=False) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0, 0]) + + eqs8 = [Eq(Derivative(f(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + sin(x)), Eq(Derivative(g(x), x), (f(x) + + g(x) + h(x) + k(x))*log(x) + sin(x)), Eq(Derivative(h(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + + sin(x)), Eq(Derivative(k(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + sin(x))] + sol8 = [Eq(f(x), -C1/4 - C2/4 - C3/4 + 3*C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(g(x), 3*C1/4 - C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(h(x), -C1/4 + 3*C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(k(x), -C1/4 - C2/4 + 3*C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x))] + with dotprodsimp(True): + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0, 0, 0]) + + +def test_sysode_linear_neq_order1_type5_type6(): + f, g = symbols("f g", cls=Function) + x, x_ = symbols("x x_") + + # Type 5 + eqs1 = [Eq(Derivative(f(x), x), (2*f(x) + g(x))/x), Eq(Derivative(g(x), x), (f(x) + 2*g(x))/x)] + sol1 = [Eq(f(x), -C1*x + C2*x**3), Eq(g(x), C1*x + C2*x**3)] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + # Type 6 + eqs2 = [Eq(Derivative(f(x), x), (2*f(x) + g(x) + 1)/x), + Eq(Derivative(g(x), x), (x + f(x) + 2*g(x))/x)] + sol2 = [Eq(f(x), C2*x**3 - x*(C1 + Rational(1, 4)) + x*log(x)*Rational(-1, 2) + Rational(-2, 3)), + Eq(g(x), C2*x**3 + x*log(x)/2 + x*(C1 + Rational(-1, 4)) + Rational(1, 3))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + +def test_higher_order_to_first_order(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + eqs1 = [Eq(Derivative(f(x), (x, 2)), 2*f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), -f(x))] + sol1 = [Eq(f(x), -C2*x*exp(-x) + C3*x*exp(x) - (C1 - C2)*exp(-x) + (C3 + C4)*exp(x)), + Eq(g(x), C2*x*exp(-x) - C3*x*exp(x) + (C1 + C2)*exp(-x) + (C3 - C4)*exp(x))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(f(x).diff(x, 2), 0), Eq(g(x).diff(x, 2), f(x))] + sol2 = [Eq(f(x), C1 + C2*x), Eq(g(x), C1*x**2/2 + C2*x**3/6 + C3 + C4*x)] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), (x, 2)), 2*f(x)), + Eq(Derivative(g(x), (x, 2)), -f(x) + 2*g(x))] + sol3 = [Eq(f(x), 4*C1*exp(-sqrt(2)*x) + 4*C2*exp(sqrt(2)*x)), + Eq(g(x), sqrt(2)*C1*x*exp(-sqrt(2)*x) - sqrt(2)*C2*x*exp(sqrt(2)*x) + (C1 + + sqrt(2)*C4)*exp(-sqrt(2)*x) + (C2 - sqrt(2)*C3)*exp(sqrt(2)*x))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(Derivative(f(x), (x, 2)), 2*f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), 2*g(x))] + sol4 = [Eq(f(x), C1*x*exp(sqrt(2)*x)/4 + C3*x*exp(-sqrt(2)*x)/4 + (C2/4 + sqrt(2)*C3/8)*exp(-sqrt(2)*x) - + exp(sqrt(2)*x)*(sqrt(2)*C1/8 + C4*Rational(-1, 4))), + Eq(g(x), sqrt(2)*C1*exp(sqrt(2)*x)/2 + sqrt(2)*C3*exp(-sqrt(2)*x)*Rational(-1, 2))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(f(x).diff(x, 2), f(x)), Eq(g(x).diff(x, 2), f(x))] + sol5 = [Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), -C1*exp(-x) + C2*exp(x) + C3 + C4*x)] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), -f(x) - g(x))] + sol6 = [Eq(f(x), C1 + C2*x**2/2 + C2 + C4*x**3/6 + x*(C3 + C4)), + Eq(g(x), -C1 + C2*x**2*Rational(-1, 2) - C3*x + C4*x**3*Rational(-1, 6))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + eqs7 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x) + 1), + Eq(Derivative(g(x), (x, 2)), f(x) + g(x) + 1)] + sol7 = [Eq(f(x), -C1 - C2*x + sqrt(2)*C3*exp(sqrt(2)*x)/2 + sqrt(2)*C4*exp(-sqrt(2)*x)*Rational(-1, 2) + + Rational(-1, 2)), + Eq(g(x), C1 + C2*x + sqrt(2)*C3*exp(sqrt(2)*x)/2 + sqrt(2)*C4*exp(-sqrt(2)*x)*Rational(-1, 2) + + Rational(-1, 2))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + eqs8 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x) + 1), + Eq(Derivative(g(x), (x, 2)), -f(x) - g(x) + 1)] + sol8 = [Eq(f(x), C1 + C2 + C4*x**3/6 + x**4/12 + x**2*(C2/2 + Rational(1, 2)) + x*(C3 + C4)), + Eq(g(x), -C1 - C3*x + C4*x**3*Rational(-1, 6) + x**4*Rational(-1, 12) - x**2*(C2/2 + Rational(-1, + 2)))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + x, y = symbols('x, y', cls=Function) + t, l = symbols('t, l') + + eqs10 = [Eq(Derivative(x(t), (t, 2)), 5*x(t) + 43*y(t)), + Eq(Derivative(y(t), (t, 2)), x(t) + 9*y(t))] + sol10 = [Eq(x(t), C1*(61 - 9*sqrt(47))*sqrt(sqrt(47) + 7)*exp(-t*sqrt(sqrt(47) + 7))/2 + C2*sqrt(7 - + sqrt(47))*(61 + 9*sqrt(47))*exp(-t*sqrt(7 - sqrt(47)))/2 + C3*(61 - 9*sqrt(47))*sqrt(sqrt(47) + + 7)*exp(t*sqrt(sqrt(47) + 7))*Rational(-1, 2) + C4*sqrt(7 - sqrt(47))*(61 + 9*sqrt(47))*exp(t*sqrt(7 + - sqrt(47)))*Rational(-1, 2)), + Eq(y(t), C1*(7 - sqrt(47))*sqrt(sqrt(47) + 7)*exp(-t*sqrt(sqrt(47) + 7))*Rational(-1, 2) + C2*sqrt(7 + - sqrt(47))*(sqrt(47) + 7)*exp(-t*sqrt(7 - sqrt(47)))*Rational(-1, 2) + C3*(7 - + sqrt(47))*sqrt(sqrt(47) + 7)*exp(t*sqrt(sqrt(47) + 7))/2 + C4*sqrt(7 - sqrt(47))*(sqrt(47) + + 7)*exp(t*sqrt(7 - sqrt(47)))/2)] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + eqs11 = [Eq(7*x(t) + Derivative(x(t), (t, 2)) - 9*Derivative(y(t), t), 0), + Eq(7*y(t) + 9*Derivative(x(t), t) + Derivative(y(t), (t, 2)), 0)] + sol11 = [Eq(y(t), C1*(9 - sqrt(109))*sin(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)/14 + C2*(9 - + sqrt(109))*cos(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C3*(9 + + sqrt(109))*sin(sqrt(2)*t*sqrt(95 - 9*sqrt(109))/2)/14 + C4*(9 + sqrt(109))*cos(sqrt(2)*t*sqrt(95 - + 9*sqrt(109))/2)*Rational(-1, 14)), + Eq(x(t), C1*(9 - sqrt(109))*cos(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C2*(9 - + sqrt(109))*sin(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C3*(9 + + sqrt(109))*cos(sqrt(2)*t*sqrt(95 - 9*sqrt(109))/2)/14 + C4*(9 + sqrt(109))*sin(sqrt(2)*t*sqrt(95 - + 9*sqrt(109))/2)/14)] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0]) + + # Euler Systems + # Note: To add examples of euler systems solver with non-homogeneous term. + eqs13 = [Eq(Derivative(f(t), (t, 2)), Derivative(f(t), t)/t + f(t)/t**2 + g(t)/t**2), + Eq(Derivative(g(t), (t, 2)), g(t)/t**2)] + sol13 = [Eq(f(t), C1*(sqrt(5) + 3)*Rational(-1, 2)*t**(Rational(1, 2) + + sqrt(5)*Rational(-1, 2)) + C2*t**(Rational(1, 2) + + sqrt(5)/2)*(3 - sqrt(5))*Rational(-1, 2) - C3*t**(1 - + sqrt(2))*(1 + sqrt(2)) - C4*t**(1 + sqrt(2))*(1 - sqrt(2))), + Eq(g(t), C1*(1 + sqrt(5))*Rational(-1, 2)*t**(Rational(1, 2) + + sqrt(5)*Rational(-1, 2)) + C2*t**(Rational(1, 2) + + sqrt(5)/2)*(1 - sqrt(5))*Rational(-1, 2))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + # Solving systems using dsolve separately + eqs14 = [Eq(Derivative(f(t), (t, 2)), t*f(t)), + Eq(Derivative(g(t), (t, 2)), t*g(t))] + sol14 = [Eq(f(t), C1*airyai(t) + C2*airybi(t)), + Eq(g(t), C3*airyai(t) + C4*airybi(t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0]) + + + eqs15 = [Eq(Derivative(x(t), (t, 2)), t*(4*Derivative(x(t), t) + 8*Derivative(y(t), t))), + Eq(Derivative(y(t), (t, 2)), t*(12*Derivative(x(t), t) - 6*Derivative(y(t), t)))] + sol15 = [Eq(x(t), C1 - erf(sqrt(6)*t)*(sqrt(6)*sqrt(pi)*C2/33 + sqrt(6)*sqrt(pi)*C3*Rational(-1, 44)) + + erfi(sqrt(5)*t)*(sqrt(5)*sqrt(pi)*C2*Rational(2, 55) + sqrt(5)*sqrt(pi)*C3*Rational(4, 55))), + Eq(y(t), C4 + erf(sqrt(6)*t)*(sqrt(6)*sqrt(pi)*C2*Rational(2, 33) + sqrt(6)*sqrt(pi)*C3*Rational(-1, + 22)) + erfi(sqrt(5)*t)*(sqrt(5)*sqrt(pi)*C2*Rational(3, 110) + sqrt(5)*sqrt(pi)*C3*Rational(3, 55)))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0]) + + +@slow +def test_higher_order_to_first_order_9(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + eqs9 = [f(x) + g(x) - 2*exp(I*x) + 2*Derivative(f(x), x) + Derivative(f(x), (x, 2)), + f(x) + g(x) - 2*exp(I*x) + 2*Derivative(g(x), x) + Derivative(g(x), (x, 2))] + sol9 = [Eq(f(x), -C1 + C4*exp(-2*x)/2 - (C2/2 - C3/2)*exp(-x)*cos(x) + + (C2/2 + C3/2)*exp(-x)*sin(x) + 2*((1 - 2*I)*exp(I*x)*sin(x)**2/5) + + 2*((1 - 2*I)*exp(I*x)*cos(x)**2/5)), + Eq(g(x), C1 - C4*exp(-2*x)/2 - (C2/2 - C3/2)*exp(-x)*cos(x) + + (C2/2 + C3/2)*exp(-x)*sin(x) + 2*((1 - 2*I)*exp(I*x)*sin(x)**2/5) + + 2*((1 - 2*I)*exp(I*x)*cos(x)**2/5))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0]) + + +def test_higher_order_to_first_order_12(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + x, y = symbols('x, y', cls=Function) + t, l = symbols('t, l') + + eqs12 = [Eq(4*x(t) + Derivative(x(t), (t, 2)) + 8*Derivative(y(t), t), 0), + Eq(4*y(t) - 8*Derivative(x(t), t) + Derivative(y(t), (t, 2)), 0)] + sol12 = [Eq(y(t), C1*(2 - sqrt(5))*sin(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C2*(2 - + sqrt(5))*cos(2*t*sqrt(4*sqrt(5) + 9))/2 + C3*(2 + sqrt(5))*sin(2*t*sqrt(9 - 4*sqrt(5)))*Rational(-1, + 2) + C4*(2 + sqrt(5))*cos(2*t*sqrt(9 - 4*sqrt(5)))/2), + Eq(x(t), C1*(2 - sqrt(5))*cos(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C2*(2 - + sqrt(5))*sin(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C3*(2 + sqrt(5))*cos(2*t*sqrt(9 - + 4*sqrt(5)))/2 + C4*(2 + sqrt(5))*sin(2*t*sqrt(9 - 4*sqrt(5)))/2)] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + +def test_second_order_to_first_order_2(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + eqs2 = [Eq(f(x).diff(x, 2), 2*(x*g(x).diff(x) - g(x))), + Eq(g(x).diff(x, 2),-2*(x*f(x).diff(x) - f(x)))] + sol2 = [Eq(f(x), C1*x + x*Integral(C2*exp(-x_)*exp(I*exp(2*x_))/2 + C2*exp(-x_)*exp(-I*exp(2*x_))/2 - + I*C3*exp(-x_)*exp(I*exp(2*x_))/2 + I*C3*exp(-x_)*exp(-I*exp(2*x_))/2, (x_, log(x)))), + Eq(g(x), C4*x + x*Integral(I*C2*exp(-x_)*exp(I*exp(2*x_))/2 - I*C2*exp(-x_)*exp(-I*exp(2*x_))/2 + + C3*exp(-x_)*exp(I*exp(2*x_))/2 + C3*exp(-x_)*exp(-I*exp(2*x_))/2, (x_, log(x))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs2, simplify=False, doit=False) == [sol2] + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = (Eq(diff(f(t),t,t), 9*t*diff(g(t),t)-9*g(t)), Eq(diff(g(t),t,t),7*t*diff(f(t),t)-7*f(t))) + sol3 = [Eq(f(t), C1*t + t*Integral(C2*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/2 + C2*exp(-t_)* + exp(-3*sqrt(7)*exp(2*t_)/2)/2 + 3*sqrt(7)*C3*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/14 - + 3*sqrt(7)*C3*exp(-t_)*exp(-3*sqrt(7)*exp(2*t_)/2)/14, (t_, log(t)))), + Eq(g(t), C4*t + t*Integral(sqrt(7)*C2*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/6 - sqrt(7)*C2*exp(-t_)* + exp(-3*sqrt(7)*exp(2*t_)/2)/6 + C3*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/2 + C3*exp(-t_)*exp(-3*sqrt(7)* + exp(2*t_)/2)/2, (t_, log(t))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs3, simplify=False, doit=False) == [sol3] + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + # Regression Test case for sympy#19238 + # https://github.com/sympy/sympy/issues/19238 + # Note: When the doit method is removed, these particular types of systems + # can be divided first so that we have lesser number of big matrices. + eqs5 = [Eq(Derivative(g(t), (t, 2)), a*m), + Eq(Derivative(f(t), (t, 2)), 0)] + sol5 = [Eq(g(t), C1 + C2*t + a*m*t**2/2), + Eq(f(t), C3 + C4*t)] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + # Type 2 + eqs6 = [Eq(Derivative(f(t), (t, 2)), f(t)/t**4), + Eq(Derivative(g(t), (t, 2)), d*g(t)/t**4)] + sol6 = [Eq(f(t), C1*sqrt(t**2)*exp(-1/t) - C2*sqrt(t**2)*exp(1/t)), + Eq(g(t), C3*sqrt(t**2)*exp(-sqrt(d)/t)*d**Rational(-1, 2) - + C4*sqrt(t**2)*exp(sqrt(d)/t)*d**Rational(-1, 2))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + +@slow +def test_second_order_to_first_order_slow1(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + # Type 1 + + eqs1 = [Eq(f(x).diff(x, 2), 2/x *(x*g(x).diff(x) - g(x))), + Eq(g(x).diff(x, 2),-2/x *(x*f(x).diff(x) - f(x)))] + sol1 = [Eq(f(x), C1*x + 2*C2*x*Ci(2*x) - C2*sin(2*x) - 2*C3*x*Si(2*x) - C3*cos(2*x)), + Eq(g(x), -2*C2*x*Si(2*x) - C2*cos(2*x) - 2*C3*x*Ci(2*x) + C3*sin(2*x) + C4*x)] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + +def test_second_order_to_first_order_slow4(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + eqs4 = [Eq(Derivative(f(t), (t, 2)), t*sin(t)*Derivative(g(t), t) - g(t)*sin(t)), + Eq(Derivative(g(t), (t, 2)), t*sin(t)*Derivative(f(t), t) - f(t)*sin(t))] + sol4 = [Eq(f(t), C1*t + t*Integral(C2*exp(-t_)*exp(exp(t_)*cos(exp(t_)))*exp(-sin(exp(t_)))/2 + + C2*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2 - C3*exp(-t_)*exp(exp(t_)*cos(exp(t_)))* + exp(-sin(exp(t_)))/2 + + C3*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2, (t_, log(t)))), + Eq(g(t), C4*t + t*Integral(-C2*exp(-t_)*exp(exp(t_)*cos(exp(t_)))*exp(-sin(exp(t_)))/2 + + C2*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2 + C3*exp(-t_)*exp(exp(t_)*cos(exp(t_)))* + exp(-sin(exp(t_)))/2 + C3*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2, (t_, log(t))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs4, simplify=False, doit=False) == [sol4] + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + +def test_component_division(): + f, g, h, k = symbols('f g h k', cls=Function) + x = symbols("x") + funcs = [f(x), g(x), h(x), k(x)] + + eqs1 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), h(x)**4 + k(x))] + sol1 = [Eq(f(x), 2*C1*exp(2*x)), + Eq(g(x), C1*exp(2*x) + C2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C3**4*exp(4*x)/3 + C4*exp(x))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0, 0, 0]) + + components1 = {((Eq(Derivative(f(x), x), 2*f(x)),), (Eq(Derivative(g(x), x), f(x)),)), + ((Eq(Derivative(h(x), x), h(x)),), (Eq(Derivative(k(x), x), h(x)**4 + k(x)),))} + eqsdict1 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {h(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), h(x)**4 + k(x))}) + graph1 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), h(x))}] + assert {tuple(tuple(scc) for scc in wcc) for wcc in _component_division(eqs1, funcs, x)} == components1 + assert _eqs2dict(eqs1, funcs) == eqsdict1 + assert [set(element) for element in _dict2graph(eqsdict1[0])] == graph1 + + eqs2 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol2 = [Eq(f(x), C1*exp(2*x)), + Eq(g(x), C1*exp(2*x)/2 + C2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C1**4*exp(8*x)/7 + C4*exp(x))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0, 0, 0]) + + components2 = {frozenset([(Eq(Derivative(f(x), x), 2*f(x)),), + (Eq(Derivative(g(x), x), f(x)),), + (Eq(Derivative(k(x), x), f(x)**4 + k(x)),)]), + frozenset([(Eq(Derivative(h(x), x), h(x)),)])} + eqsdict2 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph2 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), f(x))}] + assert {frozenset(tuple(scc) for scc in wcc) for wcc in _component_division(eqs2, funcs, x)} == components2 + assert _eqs2dict(eqs2, funcs) == eqsdict2 + assert [set(element) for element in _dict2graph(eqsdict2[0])] == graph2 + + eqs3 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), x + f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol3 = [Eq(f(x), C1*exp(2*x)), + Eq(g(x), C1*exp(2*x)/2 + C2 + x**2/2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C1**4*exp(8*x)/7 + C4*exp(x))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0, 0]) + + components3 = {frozenset([(Eq(Derivative(f(x), x), 2*f(x)),), + (Eq(Derivative(g(x), x), x + f(x)),), + (Eq(Derivative(k(x), x), f(x)**4 + k(x)),)]), + frozenset([(Eq(Derivative(h(x), x), h(x)),),])} + eqsdict3 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), x + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph3 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), f(x))}] + assert {frozenset(tuple(scc) for scc in wcc) for wcc in _component_division(eqs3, funcs, x)} == components3 + assert _eqs2dict(eqs3, funcs) == eqsdict3 + assert [set(l) for l in _dict2graph(eqsdict3[0])] == graph3 + + # Note: To be uncommented when the default option to call dsolve first for + # single ODE system can be rearranged. This can be done after the doit + # option in dsolve is made False by default. + + eqs4 = [Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), f(x) + x*g(x) + x), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol4 = [Eq(f(x), (C1/2 - sqrt(2)*C2/2 - sqrt(2)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 +\ + sqrt(2)*x)/2, x)/2 + Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 +\ + sqrt(2)*x)/2, x)/2)*exp(x**2/2 - sqrt(2)*x) + (C1/2 + sqrt(2)*C2/2 + sqrt(2)*Integral(x*exp(-x**2/2 + - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 + - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(g(x), (-sqrt(2)*C1/4 + C2/2 + Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 -\ + sqrt(2)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, + x)/4)*exp(x**2/2 - sqrt(2)*x) + (sqrt(2)*C1/4 + C2/2 + Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + sqrt(2)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - + sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/4)*exp(x**2/2 + sqrt(2)*x)), + Eq(h(x), C3*exp(x)), + Eq(k(x), C4*exp(x) + exp(x)*Integral((C1*exp(x**2/2 - sqrt(2)*x)/2 + C1*exp(x**2/2 + sqrt(2)*x)/2 - + sqrt(2)*C2*exp(x**2/2 - sqrt(2)*x)/2 + sqrt(2)*C2*exp(x**2/2 + sqrt(2)*x)/2 - sqrt(2)*exp(x**2/2 - + sqrt(2)*x)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + exp(x**2/2 - + sqrt(2)*x)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, + x)/2 + sqrt(2)*exp(x**2/2 + sqrt(2)*x)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + + sqrt(2)*x)/2, x)/2 + exp(x**2/2 + sqrt(2)*x)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - + sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2)**4*exp(-x), x))] + components4 = {(frozenset([Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + x + f(x))]), + frozenset([Eq(Derivative(k(x), x), f(x)**4 + k(x)),])), + (frozenset([Eq(Derivative(h(x), x), h(x)),]),)} + eqsdict4 = ({f(x): {g(x)}, g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + g(x): Eq(Derivative(g(x), x), x*g(x) + x + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph4 = [{f(x), g(x), h(x), k(x)}, {(f(x), g(x)), (g(x), f(x)), (k(x), f(x))}] + assert {tuple(frozenset(scc) for scc in wcc) for wcc in _component_division(eqs4, funcs, x)} == components4 + assert _eqs2dict(eqs4, funcs) == eqsdict4 + assert [set(element) for element in _dict2graph(eqsdict4[0])] == graph4 + # XXX: dsolve hangs in integration here: + assert dsolve_system(eqs4, simplify=False, doit=False) == [sol4] + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol5 = [Eq(f(x), (C1/2 - sqrt(2)*C2/2)*exp(x**2/2 - sqrt(2)*x) + (C1/2 + sqrt(2)*C2/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(g(x), (-sqrt(2)*C1/4 + C2/2)*exp(x**2/2 - sqrt(2)*x) + (sqrt(2)*C1/4 + C2/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(h(x), C3*exp(x)), + Eq(k(x), C4*exp(x) + exp(x)*Integral((C1*exp(x**2/2 - sqrt(2)*x)/2 + C1*exp(x**2/2 + sqrt(2)*x)/2 - + sqrt(2)*C2*exp(x**2/2 - sqrt(2)*x)/2 + sqrt(2)*C2*exp(x**2/2 + sqrt(2)*x)/2)**4*exp(-x), x))] + components5 = {(frozenset([Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + f(x))]), + frozenset([Eq(Derivative(k(x), x), f(x)**4 + k(x)),])), + (frozenset([Eq(Derivative(h(x), x), h(x)),]),)} + eqsdict5 = ({f(x): {g(x)}, g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + g(x): Eq(Derivative(g(x), x), x*g(x) + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph5 = [{f(x), g(x), h(x), k(x)}, {(f(x), g(x)), (g(x), f(x)), (k(x), f(x))}] + assert {tuple(frozenset(scc) for scc in wcc) for wcc in _component_division(eqs5, funcs, x)} == components5 + assert _eqs2dict(eqs5, funcs) == eqsdict5 + assert [set(element) for element in _dict2graph(eqsdict5[0])] == graph5 + # XXX: dsolve hangs in integration here: + assert dsolve_system(eqs5, simplify=False, doit=False) == [sol5] + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0, 0]) + + +def test_linodesolve(): + t, x, a = symbols("t x a") + f, g, h = symbols("f g h", cls=Function) + + # Testing the Errors + raises(ValueError, lambda: linodesolve(1, t)) + raises(ValueError, lambda: linodesolve(a, t)) + + A1 = Matrix([[1, 2], [2, 4], [4, 6]]) + raises(NonSquareMatrixError, lambda: linodesolve(A1, t)) + + A2 = Matrix([[1, 2, 1], [3, 1, 2]]) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t)) + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t) + g(t).diff(t), g(t)), Eq(g(t).diff(t), f(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [C1*(-Rational(1, 2) + sqrt(5)/2)*exp(t*(-Rational(1, 2) + sqrt(5)/2)) + C2*(-sqrt(5)/2 - Rational(1, 2))* + exp(t*(-sqrt(5)/2 - Rational(1, 2))), + C1*exp(t*(-Rational(1, 2) + sqrt(5)/2)) + C2*exp(t*(-sqrt(5)/2 - Rational(1, 2)))] + assert constant_renumber(linodesolve(A, t), variables=Tuple(*eq).free_symbols) == sol + + # Testing the Errors + raises(ValueError, lambda: linodesolve(1, t, b=Matrix([t+1]))) + raises(ValueError, lambda: linodesolve(a, t, b=Matrix([log(t) + sin(t)]))) + + raises(ValueError, lambda: linodesolve(Matrix([7]), t, b=t**2)) + raises(ValueError, lambda: linodesolve(Matrix([a+10]), t, b=log(t)*cos(t))) + + raises(ValueError, lambda: linodesolve(7, t, b=t**2)) + raises(ValueError, lambda: linodesolve(a, t, b=log(t) + sin(t))) + + A1 = Matrix([[1, 2], [2, 4], [4, 6]]) + b1 = Matrix([t, 1, t**2]) + raises(NonSquareMatrixError, lambda: linodesolve(A1, t, b=b1)) + + A2 = Matrix([[1, 2, 1], [3, 1, 2]]) + b2 = Matrix([t, t**2]) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t, b=b2)) + + raises(ValueError, lambda: linodesolve(A1[:2, :], t, b=b1)) + raises(ValueError, lambda: linodesolve(A1[:2, :], t, b=b1[:1])) + + # DOIT check + A1 = Matrix([[1, -1], [1, -1]]) + b1 = Matrix([15*t - 10, -15*t - 5]) + sol1 = [C1 + C2*t + C2 - 10*t**3 + 10*t**2 + t*(15*t**2 - 5*t) - 10*t, + C1 + C2*t - 10*t**3 - 5*t**2 + t*(15*t**2 - 5*t) - 5*t] + assert constant_renumber(linodesolve(A1, t, b=b1, type="type2", doit=True), + variables=[t]) == sol1 + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t) + g(t).diff(t), g(t) + t), Eq(g(t).diff(t), f(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [-C1*exp(-t/2 + sqrt(5)*t/2)/2 + sqrt(5)*C1*exp(-t/2 + sqrt(5)*t/2)/2 - sqrt(5)*C2*exp(-sqrt(5)*t/2 - + t/2)/2 - C2*exp(-sqrt(5)*t/2 - t/2)/2 - exp(-t/2 + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)), t)/2 + sqrt(5)*exp(-t/2 + + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)), t)/2 - sqrt(5)*exp(-sqrt(5)*t/2 - t/2)*Integral(-sqrt(5)*t*exp(t/2 + + sqrt(5)*t/2)/5, t)/2 - exp(-sqrt(5)*t/2 - t/2)*Integral(-sqrt(5)*t*exp(t/2 + sqrt(5)*t/2)/5, t)/2, + C1*exp(-t/2 + sqrt(5)*t/2) + C2*exp(-sqrt(5)*t/2 - t/2) + exp(-t/2 + + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)), t) + exp(-sqrt(5)*t/2 - + t/2)*Integral(-sqrt(5)*t*exp(t/2 + sqrt(5)*t/2)/5, t)] + assert constant_renumber(linodesolve(A, t, b=b), variables=[t]) == sol + + # non-homogeneous term assumed to be 0 + sol1 = [-C1*exp(-t/2 + sqrt(5)*t/2)/2 + sqrt(5)*C1*exp(-t/2 + sqrt(5)*t/2)/2 - sqrt(5)*C2*exp(-sqrt(5)*t/2 + - t/2)/2 - C2*exp(-sqrt(5)*t/2 - t/2)/2, + C1*exp(-t/2 + sqrt(5)*t/2) + C2*exp(-sqrt(5)*t/2 - t/2)] + assert constant_renumber(linodesolve(A, t, type="type2"), variables=[t]) == sol1 + + # Testing the Errors + raises(ValueError, lambda: linodesolve(t+10, t)) + raises(ValueError, lambda: linodesolve(a*t, t)) + + A1 = Matrix([[1, t], [-t, 1]]) + B1, _ = _is_commutative_anti_derivative(A1, t) + raises(NonSquareMatrixError, lambda: linodesolve(A1[:, :1], t, B=B1)) + raises(ValueError, lambda: linodesolve(A1, t, B=1)) + + A2 = Matrix([[t, t, t], [t, t, t], [t, t, t]]) + B2, _ = _is_commutative_anti_derivative(A2, t) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t, B=B2[:2, :])) + raises(ValueError, lambda: linodesolve(A2, t, B=2)) + raises(ValueError, lambda: linodesolve(A2, t, B=B2, type="type31")) + + raises(ValueError, lambda: linodesolve(A1, t, B=B2)) + raises(ValueError, lambda: linodesolve(A2, t, B=B1)) + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t), f(t) + t*g(t)), Eq(g(t).diff(t), -t*f(t) + g(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [(C1/2 - I*C2/2)*exp(I*t**2/2 + t) + (C1/2 + I*C2/2)*exp(-I*t**2/2 + t), + (-I*C1/2 + C2/2)*exp(-I*t**2/2 + t) + (I*C1/2 + C2/2)*exp(I*t**2/2 + t)] + assert constant_renumber(linodesolve(A, t), variables=Tuple(*eq).free_symbols) == sol + assert constant_renumber(linodesolve(A, t, type="type3"), variables=Tuple(*eq).free_symbols) == sol + + A1 = Matrix([[t, 1], [t, -1]]) + raises(NotImplementedError, lambda: linodesolve(A1, t)) + + # Testing the Errors + raises(ValueError, lambda: linodesolve(t+10, t, b=Matrix([t+1]))) + raises(ValueError, lambda: linodesolve(a*t, t, b=Matrix([log(t) + sin(t)]))) + + raises(ValueError, lambda: linodesolve(Matrix([7*t]), t, b=t**2)) + raises(ValueError, lambda: linodesolve(Matrix([a + 10*log(t)]), t, b=log(t)*cos(t))) + + raises(ValueError, lambda: linodesolve(7*t, t, b=t**2)) + raises(ValueError, lambda: linodesolve(a*t**2, t, b=log(t) + sin(t))) + + A1 = Matrix([[1, t], [-t, 1]]) + b1 = Matrix([t, t ** 2]) + B1, _ = _is_commutative_anti_derivative(A1, t) + raises(NonSquareMatrixError, lambda: linodesolve(A1[:, :1], t, b=b1)) + + A2 = Matrix([[t, t, t], [t, t, t], [t, t, t]]) + b2 = Matrix([t, 1, t**2]) + B2, _ = _is_commutative_anti_derivative(A2, t) + raises(NonSquareMatrixError, lambda: linodesolve(A2[:2, :], t, b=b2)) + + raises(ValueError, lambda: linodesolve(A1, t, b=b2)) + raises(ValueError, lambda: linodesolve(A2, t, b=b1)) + + raises(ValueError, lambda: linodesolve(A1, t, b=b1, B=B2)) + raises(ValueError, lambda: linodesolve(A2, t, b=b2, B=B1)) + + # Testing auto functionality + func = [f(x), g(x), h(x)] + eq = [Eq(f(x).diff(x), x*(f(x) + g(x) + h(x)) + x), + Eq(g(x).diff(x), x*(f(x) + g(x) + h(x)) + x), + Eq(h(x).diff(x), x*(f(x) + g(x) + h(x)) + 1)] + ceq = canonical_odes(eq, func, x) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, x, 1) + A = A0 + _x1 = exp(-3*x**2/2) + _x2 = exp(3*x**2/2) + _x3 = Integral(2*_x1*x/3 + _x1/3 + x/3 - Rational(1, 3), x) + _x4 = 2*_x2*_x3/3 + _x5 = Integral(2*_x1*x/3 + _x1/3 - 2*x/3 + Rational(2, 3), x) + sol = [ + C1*_x2/3 - C1/3 + C2*_x2/3 - C2/3 + C3*_x2/3 + 2*C3/3 + _x2*_x5/3 + _x3/3 + _x4 - _x5/3, + C1*_x2/3 + 2*C1/3 + C2*_x2/3 - C2/3 + C3*_x2/3 - C3/3 + _x2*_x5/3 + _x3/3 + _x4 - _x5/3, + C1*_x2/3 - C1/3 + C2*_x2/3 + 2*C2/3 + C3*_x2/3 - C3/3 + _x2*_x5/3 - 2*_x3/3 + _x4 + 2*_x5/3, + ] + assert constant_renumber(linodesolve(A, x, b=b), variables=Tuple(*eq).free_symbols) == sol + assert constant_renumber(linodesolve(A, x, b=b, type="type4"), + variables=Tuple(*eq).free_symbols) == sol + + A1 = Matrix([[t, 1], [t, -1]]) + raises(NotImplementedError, lambda: linodesolve(A1, t, b=b1)) + + # non-homogeneous term not passed + sol1 = [-C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2), + -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)] + assert constant_renumber(linodesolve(A, x, type="type4", doit=True), variables=Tuple(*eq).free_symbols) == sol1 + + +@slow +def test_linear_3eq_order1_type4_slow(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + + f = t ** 3 + log(t) + g = t ** 2 + sin(t) + eq1 = (Eq(diff(x(t), t), (4 * f + g) * x(t) - f * y(t) - 2 * f * z(t)), + Eq(diff(y(t), t), 2 * f * x(t) + (f + g) * y(t) - 2 * f * z(t)), Eq(diff(z(t), t), 5 * f * x(t) + f * y( + t) + (-3 * f + g) * z(t))) + with dotprodsimp(True): + dsolve(eq1) + + +@slow +def test_linear_neq_order1_type2_slow1(): + i, r1, c1, r2, c2, t = symbols('i, r1, c1, r2, c2, t') + x1 = Function('x1') + x2 = Function('x2') + + eq1 = r1*c1*Derivative(x1(t), t) + x1(t) - x2(t) - r1*i + eq2 = r2*c1*Derivative(x1(t), t) + r2*c2*Derivative(x2(t), t) + x2(t) - r2*i + eq = [eq1, eq2] + + # XXX: Solution is too complicated + [sol] = dsolve_system(eq, simplify=False, doit=False) + assert checksysodesol(eq, sol) == (True, [0, 0]) + + +# Regression test case for issue #9204 +# https://github.com/sympy/sympy/issues/9204 +@tooslow +def test_linear_new_order1_type2_de_lorentz_slow_check(): + m = Symbol("m", real=True) + q = Symbol("q", real=True) + t = Symbol("t", real=True) + + e1, e2, e3 = symbols("e1:4", real=True) + b1, b2, b3 = symbols("b1:4", real=True) + v1, v2, v3 = symbols("v1:4", cls=Function, real=True) + + eqs = [ + -e1*q + m*Derivative(v1(t), t) - q*(-b2*v3(t) + b3*v2(t)), + -e2*q + m*Derivative(v2(t), t) - q*(b1*v3(t) - b3*v1(t)), + -e3*q + m*Derivative(v3(t), t) - q*(-b1*v2(t) + b2*v1(t)) + ] + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0, 0]) + + +# Regression test case for issue #14001 +# https://github.com/sympy/sympy/issues/14001 +@slow +def test_linear_neq_order1_type2_slow_check(): + RC, t, C, Vs, L, R1, V0, I0 = symbols("RC t C Vs L R1 V0 I0") + V = Function("V") + I = Function("I") + system = [Eq(V(t).diff(t), -1/RC*V(t) + I(t)/C), Eq(I(t).diff(t), -R1/L*I(t) - 1/L*V(t) + Vs/L)] + [sol] = dsolve_system(system, simplify=False, doit=False) + + assert checksysodesol(system, sol) == (True, [0, 0]) + + +def _linear_3eq_order1_type4_long(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + + f = t ** 3 + log(t) + g = t ** 2 + sin(t) + + eq1 = (Eq(diff(x(t), t), (4*f + g)*x(t) - f*y(t) - 2*f*z(t)), + Eq(diff(y(t), t), 2*f*x(t) + (f + g)*y(t) - 2*f*z(t)), Eq(diff(z(t), t), 5*f*x(t) + f*y( + t) + (-3*f + g)*z(t))) + + dsolve_sol = dsolve(eq1) + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + + x_1 = sqrt(-t**6 - 8*t**3*log(t) + 8*t**3 - 16*log(t)**2 + 32*log(t) - 16) + x_2 = sqrt(3) + x_3 = 8324372644*C1*x_1*x_2 + 4162186322*C2*x_1*x_2 - 8324372644*C3*x_1*x_2 + x_4 = 1 / (1903457163*t**3 + 3825881643*x_1*x_2 + 7613828652*log(t) - 7613828652) + x_5 = exp(t**3/3 + t*x_1*x_2/4 - cos(t)) + x_6 = exp(t**3/3 - t*x_1*x_2/4 - cos(t)) + x_7 = exp(t**4/2 + t**3/3 + 2*t*log(t) - 2*t - cos(t)) + x_8 = 91238*C1*x_1*x_2 + 91238*C2*x_1*x_2 - 91238*C3*x_1*x_2 + x_9 = 1 / (66049*t**3 - 50629*x_1*x_2 + 264196*log(t) - 264196) + x_10 = 50629 * C1 / 25189 + 37909*C2/25189 - 50629*C3/25189 - x_3*x_4 + x_11 = -50629*C1/25189 - 12720*C2/25189 + 50629*C3/25189 + x_3*x_4 + sol = [Eq(x(t), x_10*x_5 + x_11*x_6 + x_7*(C1 - C2)), Eq(y(t), x_10*x_5 + x_11*x_6), Eq(z(t), x_5*( + -424*C1/257 - 167*C2/257 + 424*C3/257 - x_8*x_9) + x_6*(167*C1/257 + 424*C2/257 - + 167*C3/257 + x_8*x_9) + x_7*(C1 - C2))] + + assert dsolve_sol1 == sol + assert checksysodesol(eq1, dsolve_sol1) == (True, [0, 0, 0]) + + +@slow +def test_neq_order1_type4_slow_check1(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + + eqs = [Eq(diff(f(x), x), x*f(x) + x**2*g(x) + x), + Eq(diff(g(x), x), 2*x**2*f(x) + (x + 3*x**2)*g(x) + 1)] + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0]) + + +@slow +def test_neq_order1_type4_slow_check2(): + f, g, h = symbols("f, g, h", cls=Function) + x = Symbol("x") + + eqs = [ + Eq(Derivative(f(x), x), x*h(x) + f(x) + g(x) + 1), + Eq(Derivative(g(x), x), x*g(x) + f(x) + h(x) + 10), + Eq(Derivative(h(x), x), x*f(x) + x + g(x) + h(x)) + ] + with dotprodsimp(True): + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0, 0]) + + +def _neq_order1_type4_slow3(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + + eqs = [ + Eq(Derivative(f(x), x), x*f(x) + g(x) + sin(x)), + Eq(Derivative(g(x), x), x**2 + x*g(x) - f(x)) + ] + sol = [ + Eq(f(x), (C1/2 - I*C2/2 - I*Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 + Integral(-I*x**2*exp(-x**2/2 + - I*x)/2 + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - + I*x)*sin(x)/2 + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 + + I*x) + (C1/2 + I*C2/2 + I*Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 + Integral(-I*x**2*exp(-x**2/2 + - I*x)/2 + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - + I*x)*sin(x)/2 + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 - + I*x)), + Eq(g(x), (-I*C1/2 + C2/2 + Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 - + I*Integral(-I*x**2*exp(-x**2/2 - I*x)/2 + I*x**2*exp(-x**2/2 + + I*x)/2 + exp(-x**2/2 - I*x)*sin(x)/2 + exp(-x**2/2 + + I*x)*sin(x)/2, x)/2)*exp(x**2/2 - I*x) + (I*C1/2 + C2/2 + + Integral(x**2*exp(-x**2/2 - I*x)/2 + x**2*exp(-x**2/2 + I*x)/2 + + I*exp(-x**2/2 - I*x)*sin(x)/2 - I*exp(-x**2/2 + I*x)*sin(x)/2, + x)/2 + I*Integral(-I*x**2*exp(-x**2/2 - I*x)/2 + + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - I*x)*sin(x)/2 + + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 + I*x)) + ] + + return eqs, sol + + +def test_neq_order1_type4_slow3(): + eqs, sol = _neq_order1_type4_slow3() + assert dsolve_system(eqs, simplify=False, doit=False) == [sol] + # XXX: dsolve gives an error in integration: + # assert dsolve(eqs) == sol + # https://github.com/sympy/sympy/issues/20155 + + +@slow +def test_neq_order1_type4_slow_check3(): + eqs, sol = _neq_order1_type4_slow3() + assert checksysodesol(eqs, sol) == (True, [0, 0]) + + +@tooslow +@XFAIL +def test_linear_3eq_order1_type4_long_dsolve_slow_xfail(): + eq, sol = _linear_3eq_order1_type4_long() + + dsolve_sol = dsolve(eq) + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + + assert dsolve_sol1 == sol + + +@tooslow +def test_linear_3eq_order1_type4_long_dsolve_dotprodsimp(): + eq, sol = _linear_3eq_order1_type4_long() + + # XXX: Only works with dotprodsimp see + # test_linear_3eq_order1_type4_long_dsolve_slow_xfail which is too slow + with dotprodsimp(True): + dsolve_sol = dsolve(eq) + + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + assert dsolve_sol1 == sol + + +@tooslow +def test_linear_3eq_order1_type4_long_check(): + eq, sol = _linear_3eq_order1_type4_long() + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + +def test_dsolve_system(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))] + funcs = [f(x), g(x)] + + sol = [[Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))]] + assert dsolve_system(eqs, funcs=funcs, t=x, doit=True) == sol + + raises(ValueError, lambda: dsolve_system(1)) + raises(ValueError, lambda: dsolve_system(eqs, 1)) + raises(ValueError, lambda: dsolve_system(eqs, funcs, 1)) + raises(ValueError, lambda: dsolve_system(eqs, funcs[:1], x)) + + eq = (Eq(f(x).diff(x), 12 * f(x) - 6 * g(x)), Eq(g(x).diff(x) ** 2, 11 * f(x) + 3 * g(x))) + raises(NotImplementedError, lambda: dsolve_system(eq) == ([], [])) + + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)]) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], t=x) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], t=x, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, t=x, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], ics={f(0): 1, g(0): 1}) == ([], [])) + +def test_dsolve(): + + f, g = symbols('f g', cls=Function) + x, y = symbols('x y') + + eqs = [f(x).diff(x) - x, f(x).diff(x) + x] + with raises(ValueError): + dsolve(eqs) + + eqs = [f(x, y).diff(x)] + with raises(ValueError): + dsolve(eqs) + + eqs = [f(x, y).diff(x)+g(x).diff(x), g(x).diff(x)] + with raises(ValueError): + dsolve(eqs) + + +@slow +def test_higher_order1_slow1(): + x, y = symbols("x y", cls=Function) + t = symbols("t") + + eq = [ + Eq(diff(x(t),t,t), (log(t)+t**2)*diff(x(t),t)+(log(t)+t**2)*3*diff(y(t),t)), + Eq(diff(y(t),t,t), (log(t)+t**2)*2*diff(x(t),t)+(log(t)+t**2)*9*diff(y(t),t)) + ] + sol, = dsolve_system(eq, simplify=False, doit=False) + # The solution is too long to write out explicitly and checkodesol is too + # slow so we test for particular values of t: + for e in eq: + res = (e.lhs - e.rhs).subs({sol[0].lhs:sol[0].rhs, sol[1].lhs:sol[1].rhs}) + res = res.subs({d: d.doit(deep=False) for d in res.atoms(Derivative)}) + assert ratsimp(res.subs(t, 1)) == 0 + + +def test_second_order_type2_slow1(): + x, y, z = symbols('x, y, z', cls=Function) + t, l = symbols('t, l') + + eqs1 = [Eq(Derivative(x(t), (t, 2)), t*(2*x(t) + y(t))), + Eq(Derivative(y(t), (t, 2)), t*(-x(t) + 2*y(t)))] + sol1 = [Eq(x(t), I*C1*airyai(t*(2 - I)**(S(1)/3)) + I*C2*airybi(t*(2 - I)**(S(1)/3)) - I*C3*airyai(t*(2 + + I)**(S(1)/3)) - I*C4*airybi(t*(2 + I)**(S(1)/3))), + Eq(y(t), C1*airyai(t*(2 - I)**(S(1)/3)) + C2*airybi(t*(2 - I)**(S(1)/3)) + C3*airyai(t*(2 + I)**(S(1)/3)) + + C4*airybi(t*(2 + I)**(S(1)/3)))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + +@tooslow +@XFAIL +def test_nonlinear_3eq_order1_type1(): + a, b, c = symbols('a b c') + + eqs = [ + a * f(x).diff(x) - (b - c) * g(x) * h(x), + b * g(x).diff(x) - (c - a) * h(x) * f(x), + c * h(x).diff(x) - (a - b) * f(x) * g(x), + ] + + assert dsolve(eqs) # NotImplementedError + + +@XFAIL +def test_nonlinear_3eq_order1_type4(): + eqs = [ + Eq(f(x).diff(x), (2*h(x)*g(x) - 3*g(x)*h(x))), + Eq(g(x).diff(x), (4*f(x)*h(x) - 2*h(x)*f(x))), + Eq(h(x).diff(x), (3*g(x)*f(x) - 4*f(x)*g(x))), + ] + dsolve(eqs) # KeyError when matching + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +@tooslow +@XFAIL +def test_nonlinear_3eq_order1_type3(): + eqs = [ + Eq(f(x).diff(x), (2*f(x)**2 - 3 )), + Eq(g(x).diff(x), (4 - 2*h(x) )), + Eq(h(x).diff(x), (3*h(x) - 4*f(x)**2)), + ] + dsolve(eqs) # Not sure if this finishes... + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +@XFAIL +def test_nonlinear_3eq_order1_type5(): + eqs = [ + Eq(f(x).diff(x), f(x)*(2*f(x) - 3*g(x))), + Eq(g(x).diff(x), g(x)*(4*g(x) - 2*h(x))), + Eq(h(x).diff(x), h(x)*(3*h(x) - 4*f(x))), + ] + dsolve(eqs) # KeyError + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +def test_linear_2eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + k, l, m, n = symbols('k, l, m, n', Integer=True) + t = Symbol('t') + x0, y0 = symbols('x0, y0', cls=Function) + + eq1 = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23)) + sol1 = [Eq(x(t), C1*exp(t*(sqrt(6) + 3)) + C2*exp(t*(-sqrt(6) + 3)) - Rational(22, 3)), \ + Eq(y(t), C1*(2 + sqrt(6))*exp(t*(sqrt(6) + 3)) + C2*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) - Rational(5, 3))] + assert checksysodesol(eq1, sol1) == (True, [0, 0]) + + eq2 = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23)) + sol2 = [Eq(x(t), (C1*cos(sqrt(2)*t) + C2*sin(sqrt(2)*t))*exp(t) - Rational(58, 3)), \ + Eq(y(t), (-sqrt(2)*C1*sin(sqrt(2)*t) + sqrt(2)*C2*cos(sqrt(2)*t))*exp(t) - Rational(185, 3))] + assert checksysodesol(eq2, sol2) == (True, [0, 0]) + + eq3 = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t))) + sol3 = [Eq(x(t), (C1*exp(2*t) + C2*exp(-2*t))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (C1*exp(2*t) - C2*exp(-2*t))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq3, sol3) == (True, [0, 0]) + + eq4 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + sol4 = [Eq(x(t), (C1*cos((t**3)/3) + C2*sin((t**3)/3))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (-C1*sin((t**3)/3) + C2*cos((t**3)/3))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq4, sol4) == (True, [0, 0]) + + eq5 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t))) + sol5 = [Eq(x(t), (C1*exp((sqrt(77)/2 + Rational(9, 2))*(t**3)/3) + \ + C2*exp((-sqrt(77)/2 + Rational(9, 2))*(t**3)/3))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (C1*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(t**3)/3) + \ + C2*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(t**3)/3))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq5, sol5) == (True, [0, 0]) + + eq6 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), (1-t**2)*x(t) + (5*t+9*t**2)*y(t))) + sol6 = [Eq(x(t), C1*x0(t) + C2*x0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t)), \ + Eq(y(t), C1*y0(t) + C2*(y0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t) + \ + exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)))] + s = dsolve(eq6) + assert s == sol6 # too complicated to test with subs and simplify + # assert checksysodesol(eq10, sol10) == (True, [0, 0]) # this one fails + + +def test_nonlinear_2eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + eq1 = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol1 = [ + Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq1) == sol1 + assert checksysodesol(eq1, sol1) == (True, [0, 0]) + + eq2 = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5)) + sol2 = [ + Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq2) == sol2 + assert checksysodesol(eq2, sol2) == (True, [0, 0]) + + eq3 = (Eq(diff(x(t),t), y(t)*x(t)), Eq(diff(y(t),t), x(t)**3)) + tt = Rational(2, 3) + sol3 = [ + Eq(x(t), 6**tt/(6*(-sinh(sqrt(C1)*(C2 + t)/2)/sqrt(C1))**tt)), + Eq(y(t), sqrt(C1 + C1/sinh(sqrt(C1)*(C2 + t)/2)**2)/3)] + assert dsolve(eq3) == sol3 + # FIXME: assert checksysodesol(eq3, sol3) == (True, [0, 0]) + + eq4 = (Eq(diff(x(t),t),x(t)*y(t)*sin(t)**2), Eq(diff(y(t),t),y(t)**2*sin(t)**2)) + sol4 = {Eq(x(t), -2*exp(C1)/(C2*exp(C1) + t - sin(2*t)/2)), Eq(y(t), -2/(C1 + t - sin(2*t)/2))} + assert dsolve(eq4) == sol4 + # FIXME: assert checksysodesol(eq4, sol4) == (True, [0, 0]) + + eq5 = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2)) + sol5 = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)} + assert dsolve(eq5) == sol5 + assert checksysodesol(eq5, sol5) == (True, [0, 0]) + + eq6 = (Eq(diff(x(t),t),x(t)**2*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol6 = [ + Eq(x(t), 1/(C1 - 1/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 + (-1/(4*C2 + 4*t))**(Rational(-1, 4)))), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 + I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 - I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq6) == sol6 + assert checksysodesol(eq6, sol6) == (True, [0, 0]) + + +@slow +def test_nonlinear_3eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + t, u = symbols('t u') + eq1 = (4*diff(x(t),t) + 2*y(t)*z(t), 3*diff(y(t),t) - z(t)*x(t), 5*diff(z(t),t) - x(t)*y(t)) + sol1 = [Eq(4*Integral(1/(sqrt(-4*u**2 - 3*C1 + C2)*sqrt(-4*u**2 + 5*C1 - C2)), (u, x(t))), + C3 - sqrt(15)*t/15), Eq(3*Integral(1/(sqrt(-6*u**2 - C1 + 5*C2)*sqrt(3*u**2 + C1 - 4*C2)), + (u, y(t))), C3 + sqrt(5)*t/10), Eq(5*Integral(1/(sqrt(-10*u**2 - 3*C1 + C2)* + sqrt(5*u**2 + 4*C1 - C2)), (u, z(t))), C3 + sqrt(3)*t/6)] + assert [i.dummy_eq(j) for i, j in zip(dsolve(eq1), sol1)] + # FIXME: assert checksysodesol(eq1, sol1) == (True, [0, 0, 0]) + + eq2 = (4*diff(x(t),t) + 2*y(t)*z(t)*sin(t), 3*diff(y(t),t) - z(t)*x(t)*sin(t), 5*diff(z(t),t) - x(t)*y(t)*sin(t)) + sol2 = [Eq(3*Integral(1/(sqrt(-6*u**2 - C1 + 5*C2)*sqrt(3*u**2 + C1 - 4*C2)), (u, x(t))), C3 + + sqrt(5)*cos(t)/10), Eq(4*Integral(1/(sqrt(-4*u**2 - 3*C1 + C2)*sqrt(-4*u**2 + 5*C1 - C2)), + (u, y(t))), C3 - sqrt(15)*cos(t)/15), Eq(5*Integral(1/(sqrt(-10*u**2 - 3*C1 + C2)* + sqrt(5*u**2 + 4*C1 - C2)), (u, z(t))), C3 + sqrt(3)*cos(t)/6)] + assert [i.dummy_eq(j) for i, j in zip(dsolve(eq2), sol2)] + # FIXME: assert checksysodesol(eq2, sol2) == (True, [0, 0, 0]) + + +def test_C1_function_9239(): + t = Symbol('t') + C1 = Function('C1') + C2 = Function('C2') + C3 = Symbol('C3') + C4 = Symbol('C4') + eq = (Eq(diff(C1(t), t), 9*C2(t)), Eq(diff(C2(t), t), 12*C1(t))) + sol = [Eq(C1(t), 9*C3*exp(6*sqrt(3)*t) + 9*C4*exp(-6*sqrt(3)*t)), + Eq(C2(t), 6*sqrt(3)*C3*exp(6*sqrt(3)*t) - 6*sqrt(3)*C4*exp(-6*sqrt(3)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + +def test_dsolve_linsystem_symbol(): + eps = Symbol('epsilon', positive=True) + eq1 = (Eq(diff(f(x), x), -eps*g(x)), Eq(diff(g(x), x), eps*f(x))) + sol1 = [Eq(f(x), -C1*eps*cos(eps*x) - C2*eps*sin(eps*x)), + Eq(g(x), -C1*eps*sin(eps*x) + C2*eps*cos(eps*x))] + assert checksysodesol(eq1, sol1) == (True, [0, 0]) diff --git a/phivenv/Lib/site-packages/sympy/solvers/pde.py b/phivenv/Lib/site-packages/sympy/solvers/pde.py new file mode 100644 index 0000000000000000000000000000000000000000..791ac67ae681ea952ea6e1dabacb7220d1843ebc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/pde.py @@ -0,0 +1,966 @@ +""" +This module contains pdsolve() and different helper functions that it +uses. It is heavily inspired by the ode module and hence the basic +infrastructure remains the same. + +**Functions in this module** + + These are the user functions in this module: + + - pdsolve() - Solves PDE's + - classify_pde() - Classifies PDEs into possible hints for dsolve(). + - pde_separate() - Separate variables in partial differential equation either by + additive or multiplicative separation approach. + + These are the helper functions in this module: + + - pde_separate_add() - Helper function for searching additive separable solutions. + - pde_separate_mul() - Helper function for searching multiplicative + separable solutions. + +**Currently implemented solver methods** + +The following methods are implemented for solving partial differential +equations. See the docstrings of the various pde_hint() functions for +more information on each (run help(pde)): + + - 1st order linear homogeneous partial differential equations + with constant coefficients. + - 1st order linear general partial differential equations + with constant coefficients. + - 1st order linear partial differential equations with + variable coefficients. + +""" +from functools import reduce + +from itertools import combinations_with_replacement +from sympy.simplify import simplify # type: ignore +from sympy.core import Add, S +from sympy.core.function import Function, expand, AppliedUndef, Subs +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Symbol, Wild, symbols +from sympy.functions import exp +from sympy.integrals.integrals import Integral, integrate +from sympy.utilities.iterables import has_dups, is_sequence +from sympy.utilities.misc import filldedent + +from sympy.solvers.deutils import _preprocess, ode_order, _desolve +from sympy.solvers.solvers import solve +from sympy.simplify.radsimp import collect + +import operator + + +allhints = ( + "1st_linear_constant_coeff_homogeneous", + "1st_linear_constant_coeff", + "1st_linear_constant_coeff_Integral", + "1st_linear_variable_coeff" + ) + + +def pdsolve(eq, func=None, hint='default', dict=False, solvefun=None, **kwargs): + """ + Solves any (supported) kind of partial differential equation. + + **Usage** + + pdsolve(eq, f(x,y), hint) -> Solve partial differential equation + eq for function f(x,y), using method hint. + + **Details** + + ``eq`` can be any supported partial differential equation (see + the pde docstring for supported methods). This can either + be an Equality, or an expression, which is assumed to be + equal to 0. + + ``f(x,y)`` is a function of two variables whose derivatives in that + variable make up the partial differential equation. In many + cases it is not necessary to provide this; it will be autodetected + (and an error raised if it could not be detected). + + ``hint`` is the solving method that you want pdsolve to use. Use + classify_pde(eq, f(x,y)) to get all of the possible hints for + a PDE. The default hint, 'default', will use whatever hint + is returned first by classify_pde(). See Hints below for + more options that you can use for hint. + + ``solvefun`` is the convention used for arbitrary functions returned + by the PDE solver. If not set by the user, it is set by default + to be F. + + **Hints** + + Aside from the various solving methods, there are also some + meta-hints that you can pass to pdsolve(): + + "default": + This uses whatever hint is returned first by + classify_pde(). This is the default argument to + pdsolve(). + + "all": + To make pdsolve apply all relevant classification hints, + use pdsolve(PDE, func, hint="all"). This will return a + dictionary of hint:solution terms. If a hint causes + pdsolve to raise the NotImplementedError, value of that + hint's key will be the exception object raised. The + dictionary will also include some special keys: + + - order: The order of the PDE. See also ode_order() in + deutils.py + - default: The solution that would be returned by + default. This is the one produced by the hint that + appears first in the tuple returned by classify_pde(). + + "all_Integral": + This is the same as "all", except if a hint also has a + corresponding "_Integral" hint, it only returns the + "_Integral" hint. This is useful if "all" causes + pdsolve() to hang because of a difficult or impossible + integral. This meta-hint will also be much faster than + "all", because integrate() is an expensive routine. + + See also the classify_pde() docstring for more info on hints, + and the pde docstring for a list of all supported hints. + + **Tips** + - You can declare the derivative of an unknown function this way: + + >>> from sympy import Function, Derivative + >>> from sympy.abc import x, y # x and y are the independent variables + >>> f = Function("f")(x, y) # f is a function of x and y + >>> # fx will be the partial derivative of f with respect to x + >>> fx = Derivative(f, x) + >>> # fy will be the partial derivative of f with respect to y + >>> fy = Derivative(f, y) + + - See test_pde.py for many tests, which serves also as a set of + examples for how to use pdsolve(). + - pdsolve always returns an Equality class (except for the case + when the hint is "all" or "all_Integral"). Note that it is not possible + to get an explicit solution for f(x, y) as in the case of ODE's + - Do help(pde.pde_hintname) to get help more information on a + specific hint + + + Examples + ======== + + >>> from sympy.solvers.pde import pdsolve + >>> from sympy import Function, Eq + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> u = f(x, y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0) + >>> pdsolve(eq) + Eq(f(x, y), F(3*x - 2*y)*exp(-2*x/13 - 3*y/13)) + + """ + + if not solvefun: + solvefun = Function('F') + + # See the docstring of _desolve for more details. + hints = _desolve(eq, func=func, hint=hint, simplify=True, + type='pde', **kwargs) + eq = hints.pop('eq', False) + all_ = hints.pop('all', False) + + if all_: + # TODO : 'best' hint should be implemented when adequate + # number of hints are added. + pdedict = {} + failed_hints = {} + gethints = classify_pde(eq, dict=True) + pdedict.update({'order': gethints['order'], + 'default': gethints['default']}) + for hint in hints: + try: + rv = _helper_simplify(eq, hint, hints[hint]['func'], + hints[hint]['order'], hints[hint][hint], solvefun) + except NotImplementedError as detail: + failed_hints[hint] = detail + else: + pdedict[hint] = rv + pdedict.update(failed_hints) + return pdedict + + else: + return _helper_simplify(eq, hints['hint'], hints['func'], + hints['order'], hints[hints['hint']], solvefun) + + +def _helper_simplify(eq, hint, func, order, match, solvefun): + """Helper function of pdsolve that calls the respective + pde functions to solve for the partial differential + equations. This minimizes the computation in + calling _desolve multiple times. + """ + solvefunc = globals()["pde_" + hint.removesuffix("_Integral")] + return _handle_Integral(solvefunc(eq, func, order, + match, solvefun), func, order, hint) + + +def _handle_Integral(expr, func, order, hint): + r""" + Converts a solution with integrals in it into an actual solution. + + Simplifies the integral mainly using doit() + """ + if hint.endswith("_Integral"): + return expr + + elif hint == "1st_linear_constant_coeff": + return simplify(expr.doit()) + + else: + return expr + + +def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs): + """ + Returns a tuple of possible pdsolve() classifications for a PDE. + + The tuple is ordered so that first item is the classification that + pdsolve() uses to solve the PDE by default. In general, + classifications near the beginning of the list will produce + better solutions faster than those near the end, though there are + always exceptions. To make pdsolve use a different classification, + use pdsolve(PDE, func, hint=). See also the pdsolve() + docstring for different meta-hints you can use. + + If ``dict`` is true, classify_pde() will return a dictionary of + hint:match expression terms. This is intended for internal use by + pdsolve(). Note that because dictionaries are ordered arbitrarily, + this will most likely not be in the same order as the tuple. + + You can get help on different hints by doing help(pde.pde_hintname), + where hintname is the name of the hint without "_Integral". + + See sympy.pde.allhints or the sympy.pde docstring for a list of all + supported hints that can be returned from classify_pde. + + + Examples + ======== + + >>> from sympy.solvers.pde import classify_pde + >>> from sympy import Function, Eq + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> u = f(x, y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0) + >>> classify_pde(eq) + ('1st_linear_constant_coeff_homogeneous',) + """ + + if func and len(func.args) != 2: + raise NotImplementedError("Right now only partial " + "differential equations of two variables are supported") + + if prep or func is None: + prep, func_ = _preprocess(eq, func) + if func is None: + func = func_ + + if isinstance(eq, Equality): + if eq.rhs != 0: + return classify_pde(eq.lhs - eq.rhs, func) + eq = eq.lhs + + f = func.func + x = func.args[0] + y = func.args[1] + fx = f(x,y).diff(x) + fy = f(x,y).diff(y) + + # TODO : For now pde.py uses support offered by the ode_order function + # to find the order with respect to a multi-variable function. An + # improvement could be to classify the order of the PDE on the basis of + # individual variables. + order = ode_order(eq, f(x,y)) + + # hint:matchdict or hint:(tuple of matchdicts) + # Also will contain "default": and "order":order items. + matching_hints = {'order': order} + + if not order: + if dict: + matching_hints["default"] = None + return matching_hints + return () + + eq = expand(eq) + + a = Wild('a', exclude = [f(x,y)]) + b = Wild('b', exclude = [f(x,y), fx, fy, x, y]) + c = Wild('c', exclude = [f(x,y), fx, fy, x, y]) + d = Wild('d', exclude = [f(x,y), fx, fy, x, y]) + e = Wild('e', exclude = [f(x,y), fx, fy]) + n = Wild('n', exclude = [x, y]) + # Try removing the smallest power of f(x,y) + # from the highest partial derivatives of f(x,y) + reduced_eq = eq + if eq.is_Add: + power = None + for i in set(combinations_with_replacement((x,y), order)): + coeff = eq.coeff(f(x,y).diff(*i)) + if coeff == 1: + continue + match = coeff.match(a*f(x,y)**n) + if match and match[a]: + if power is None or match[n] < power: + power = match[n] + if power: + den = f(x,y)**power + reduced_eq = Add(*[arg/den for arg in eq.args]) + + if order == 1: + reduced_eq = collect(reduced_eq, f(x, y)) + r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e) + if r: + if not r[e]: + ## Linear first-order homogeneous partial-differential + ## equation with constant coefficients + r.update({'b': b, 'c': c, 'd': d}) + matching_hints["1st_linear_constant_coeff_homogeneous"] = r + elif r[b]**2 + r[c]**2 != 0: + ## Linear first-order general partial-differential + ## equation with constant coefficients + r.update({'b': b, 'c': c, 'd': d, 'e': e}) + matching_hints["1st_linear_constant_coeff"] = r + matching_hints["1st_linear_constant_coeff_Integral"] = r + + else: + b = Wild('b', exclude=[f(x, y), fx, fy]) + c = Wild('c', exclude=[f(x, y), fx, fy]) + d = Wild('d', exclude=[f(x, y), fx, fy]) + r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e) + if r: + r.update({'b': b, 'c': c, 'd': d, 'e': e}) + matching_hints["1st_linear_variable_coeff"] = r + + # Order keys based on allhints. + rettuple = tuple(i for i in allhints if i in matching_hints) + + if dict: + # Dictionaries are ordered arbitrarily, so make note of which + # hint would come first for pdsolve(). Use an ordered dict in Py 3. + matching_hints["default"] = None + matching_hints["ordered_hints"] = rettuple + for i in allhints: + if i in matching_hints: + matching_hints["default"] = i + break + return matching_hints + return rettuple + + +def checkpdesol(pde, sol, func=None, solve_for_func=True): + """ + Checks if the given solution satisfies the partial differential + equation. + + pde is the partial differential equation which can be given in the + form of an equation or an expression. sol is the solution for which + the pde is to be checked. This can also be given in an equation or + an expression form. If the function is not provided, the helper + function _preprocess from deutils is used to identify the function. + + If a sequence of solutions is passed, the same sort of container will be + used to return the result for each solution. + + The following methods are currently being implemented to check if the + solution satisfies the PDE: + + 1. Directly substitute the solution in the PDE and check. If the + solution has not been solved for f, then it will solve for f + provided solve_for_func has not been set to False. + + If the solution satisfies the PDE, then a tuple (True, 0) is returned. + Otherwise a tuple (False, expr) where expr is the value obtained + after substituting the solution in the PDE. However if a known solution + returns False, it may be due to the inability of doit() to simplify it to zero. + + Examples + ======== + + >>> from sympy import Function, symbols + >>> from sympy.solvers.pde import checkpdesol, pdsolve + >>> x, y = symbols('x y') + >>> f = Function('f') + >>> eq = 2*f(x,y) + 3*f(x,y).diff(x) + 4*f(x,y).diff(y) + >>> sol = pdsolve(eq) + >>> assert checkpdesol(eq, sol)[0] + >>> eq = x*f(x,y) + f(x,y).diff(x) + >>> checkpdesol(eq, sol) + (False, (x*F(4*x - 3*y) - 6*F(4*x - 3*y)/25 + 4*Subs(Derivative(F(_xi_1), _xi_1), _xi_1, 4*x - 3*y))*exp(-6*x/25 - 8*y/25)) + """ + + # Converting the pde into an equation + if not isinstance(pde, Equality): + pde = Eq(pde, 0) + + # If no function is given, try finding the function present. + if func is None: + try: + _, func = _preprocess(pde.lhs) + except ValueError: + funcs = [s.atoms(AppliedUndef) for s in ( + sol if is_sequence(sol, set) else [sol])] + funcs = set().union(funcs) + if len(funcs) != 1: + raise ValueError( + 'must pass func arg to checkpdesol for this case.') + func = funcs.pop() + + # If the given solution is in the form of a list or a set + # then return a list or set of tuples. + if is_sequence(sol, set): + return type(sol)([checkpdesol( + pde, i, func=func, + solve_for_func=solve_for_func) for i in sol]) + + # Convert solution into an equation + if not isinstance(sol, Equality): + sol = Eq(func, sol) + elif sol.rhs == func: + sol = sol.reversed + + # Try solving for the function + solved = sol.lhs == func and not sol.rhs.has(func) + if solve_for_func and not solved: + solved = solve(sol, func) + if solved: + if len(solved) == 1: + return checkpdesol(pde, Eq(func, solved[0]), + func=func, solve_for_func=False) + else: + return checkpdesol(pde, [Eq(func, t) for t in solved], + func=func, solve_for_func=False) + + # try direct substitution of the solution into the PDE and simplify + if sol.lhs == func: + pde = pde.lhs - pde.rhs + s = simplify(pde.subs(func, sol.rhs).doit()) + return s is S.Zero, s + + raise NotImplementedError(filldedent(''' + Unable to test if %s is a solution to %s.''' % (sol, pde))) + + + +def pde_1st_linear_constant_coeff_homogeneous(eq, func, order, match, solvefun): + r""" + Solves a first order linear homogeneous + partial differential equation with constant coefficients. + + The general form of this partial differential equation is + + .. math:: a \frac{\partial f(x,y)}{\partial x} + + b \frac{\partial f(x,y)}{\partial y} + c f(x,y) = 0 + + where `a`, `b` and `c` are constants. + + The general solution is of the form: + + .. math:: + f(x, y) = F(- a y + b x ) e^{- \frac{c (a x + b y)}{a^2 + b^2}} + + and can be found in SymPy with ``pdsolve``:: + + >>> from sympy.solvers import pdsolve + >>> from sympy.abc import x, y, a, b, c + >>> from sympy import Function, pprint + >>> f = Function('f') + >>> u = f(x,y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> genform = a*ux + b*uy + c*u + >>> pprint(genform) + d d + a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y) + dx dy + + >>> pprint(pdsolve(genform)) + -c*(a*x + b*y) + --------------- + 2 2 + a + b + f(x, y) = F(-a*y + b*x)*e + + Examples + ======== + + >>> from sympy import pdsolve + >>> from sympy import Function, pprint + >>> from sympy.abc import x,y + >>> f = Function('f') + >>> pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y)) + Eq(f(x, y), F(x - y)*exp(-x/2 - y/2)) + >>> pprint(pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y))) + x y + - - - - + 2 2 + f(x, y) = F(x - y)*e + + References + ========== + + - Viktor Grigoryan, "Partial Differential Equations" + Math 124A - Fall 2010, pp.7 + + """ + # TODO : For now homogeneous first order linear PDE's having + # two variables are implemented. Once there is support for + # solving systems of ODE's, this can be extended to n variables. + + f = func.func + x = func.args[0] + y = func.args[1] + b = match[match['b']] + c = match[match['c']] + d = match[match['d']] + return Eq(f(x,y), exp(-S(d)/(b**2 + c**2)*(b*x + c*y))*solvefun(c*x - b*y)) + + +def pde_1st_linear_constant_coeff(eq, func, order, match, solvefun): + r""" + Solves a first order linear partial differential equation + with constant coefficients. + + The general form of this partial differential equation is + + .. math:: a \frac{\partial f(x,y)}{\partial x} + + b \frac{\partial f(x,y)}{\partial y} + + c f(x,y) = G(x,y) + + where `a`, `b` and `c` are constants and `G(x, y)` can be an arbitrary + function in `x` and `y`. + + The general solution of the PDE is: + + .. math:: + f(x, y) = \left. \left[F(\eta) + \frac{1}{a^2 + b^2} + \int\limits^{a x + b y} G\left(\frac{a \xi + b \eta}{a^2 + b^2}, + \frac{- a \eta + b \xi}{a^2 + b^2} \right) + e^{\frac{c \xi}{a^2 + b^2}}\, d\xi\right] + e^{- \frac{c \xi}{a^2 + b^2}} + \right|_{\substack{\eta=- a y + b x\\ \xi=a x + b y }}\, , + + where `F(\eta)` is an arbitrary single-valued function. The solution + can be found in SymPy with ``pdsolve``:: + + >>> from sympy.solvers import pdsolve + >>> from sympy.abc import x, y, a, b, c + >>> from sympy import Function, pprint + >>> f = Function('f') + >>> G = Function('G') + >>> u = f(x, y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> genform = a*ux + b*uy + c*u - G(x,y) + >>> pprint(genform) + d d + a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y) - G(x, y) + dx dy + >>> pprint(pdsolve(genform, hint='1st_linear_constant_coeff_Integral')) + // a*x + b*y \ \| + || / | || + || | | || + || | c*xi | || + || | ------- | || + || | 2 2 | || + || | /a*xi + b*eta -a*eta + b*xi\ a + b | || + || | G|------------, -------------|*e d(xi)| || + || | | 2 2 2 2 | | || + || | \ a + b a + b / | -c*xi || + || | | -------|| + || / | 2 2|| + || | a + b || + f(x, y) = ||F(eta) + -------------------------------------------------------|*e || + || 2 2 | || + \\ a + b / /|eta=-a*y + b*x, xi=a*x + b*y + + Examples + ======== + + >>> from sympy.solvers.pde import pdsolve + >>> from sympy import Function, pprint, exp + >>> from sympy.abc import x,y + >>> f = Function('f') + >>> eq = -2*f(x,y).diff(x) + 4*f(x,y).diff(y) + 5*f(x,y) - exp(x + 3*y) + >>> pdsolve(eq) + Eq(f(x, y), (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y)) + + References + ========== + + - Viktor Grigoryan, "Partial Differential Equations" + Math 124A - Fall 2010, pp.7 + + """ + + # TODO : For now homogeneous first order linear PDE's having + # two variables are implemented. Once there is support for + # solving systems of ODE's, this can be extended to n variables. + xi, eta = symbols("xi eta") + f = func.func + x = func.args[0] + y = func.args[1] + b = match[match['b']] + c = match[match['c']] + d = match[match['d']] + e = -match[match['e']] + expterm = exp(-S(d)/(b**2 + c**2)*xi) + functerm = solvefun(eta) + solvedict = solve((b*x + c*y - xi, c*x - b*y - eta), x, y) + # Integral should remain as it is in terms of xi, + # doit() should be done in _handle_Integral. + genterm = (1/S(b**2 + c**2))*Integral( + (1/expterm*e).subs(solvedict), (xi, b*x + c*y)) + return Eq(f(x,y), Subs(expterm*(functerm + genterm), + (eta, xi), (c*x - b*y, b*x + c*y))) + + +def pde_1st_linear_variable_coeff(eq, func, order, match, solvefun): + r""" + Solves a first order linear partial differential equation + with variable coefficients. The general form of this partial + differential equation is + + .. math:: a(x, y) \frac{\partial f(x, y)}{\partial x} + + b(x, y) \frac{\partial f(x, y)}{\partial y} + + c(x, y) f(x, y) = G(x, y) + + where `a(x, y)`, `b(x, y)`, `c(x, y)` and `G(x, y)` are arbitrary + functions in `x` and `y`. This PDE is converted into an ODE by + making the following transformation: + + 1. `\xi` as `x` + + 2. `\eta` as the constant in the solution to the differential + equation `\frac{dy}{dx} = -\frac{b}{a}` + + Making the previous substitutions reduces it to the linear ODE + + .. math:: a(\xi, \eta)\frac{du}{d\xi} + c(\xi, \eta)u - G(\xi, \eta) = 0 + + which can be solved using ``dsolve``. + + >>> from sympy.abc import x, y + >>> from sympy import Function, pprint + >>> a, b, c, G, f= [Function(i) for i in ['a', 'b', 'c', 'G', 'f']] + >>> u = f(x,y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> genform = a(x, y)*u + b(x, y)*ux + c(x, y)*uy - G(x,y) + >>> pprint(genform) + d d + -G(x, y) + a(x, y)*f(x, y) + b(x, y)*--(f(x, y)) + c(x, y)*--(f(x, y)) + dx dy + + + Examples + ======== + + >>> from sympy.solvers.pde import pdsolve + >>> from sympy import Function, pprint + >>> from sympy.abc import x,y + >>> f = Function('f') + >>> eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2 + >>> pdsolve(eq) + Eq(f(x, y), F(x*y)*exp(y**2/2) + 1) + + References + ========== + + - Viktor Grigoryan, "Partial Differential Equations" + Math 124A - Fall 2010, pp.7 + + """ + from sympy.solvers.ode import dsolve + + eta = symbols("eta") + f = func.func + x = func.args[0] + y = func.args[1] + b = match[match['b']] + c = match[match['c']] + d = match[match['d']] + e = -match[match['e']] + + + if not d: + # To deal with cases like b*ux = e or c*uy = e + if not (b and c): + if c: + try: + tsol = integrate(e/c, y) + except NotImplementedError: + raise NotImplementedError("Unable to find a solution" + " due to inability of integrate") + else: + return Eq(f(x,y), solvefun(x) + tsol) + if b: + try: + tsol = integrate(e/b, x) + except NotImplementedError: + raise NotImplementedError("Unable to find a solution" + " due to inability of integrate") + else: + return Eq(f(x,y), solvefun(y) + tsol) + + if not c: + # To deal with cases when c is 0, a simpler method is used. + # The PDE reduces to b*(u.diff(x)) + d*u = e, which is a linear ODE in x + plode = f(x).diff(x)*b + d*f(x) - e + sol = dsolve(plode, f(x)) + syms = sol.free_symbols - plode.free_symbols - {x, y} + rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, y) + return Eq(f(x, y), rhs) + + if not b: + # To deal with cases when b is 0, a simpler method is used. + # The PDE reduces to c*(u.diff(y)) + d*u = e, which is a linear ODE in y + plode = f(y).diff(y)*c + d*f(y) - e + sol = dsolve(plode, f(y)) + syms = sol.free_symbols - plode.free_symbols - {x, y} + rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, x) + return Eq(f(x, y), rhs) + + dummy = Function('d') + h = (c/b).subs(y, dummy(x)) + sol = dsolve(dummy(x).diff(x) - h, dummy(x)) + if isinstance(sol, list): + sol = sol[0] + solsym = sol.free_symbols - h.free_symbols - {x, y} + if len(solsym) == 1: + solsym = solsym.pop() + etat = (solve(sol, solsym)[0]).subs(dummy(x), y) + ysub = solve(eta - etat, y)[0] + deq = (b*(f(x).diff(x)) + d*f(x) - e).subs(y, ysub) + final = (dsolve(deq, f(x), hint='1st_linear')).rhs + if isinstance(final, list): + final = final[0] + finsyms = final.free_symbols - deq.free_symbols - {x, y} + rhs = _simplify_variable_coeff(final, finsyms, solvefun, etat) + return Eq(f(x, y), rhs) + + else: + raise NotImplementedError("Cannot solve the partial differential equation due" + " to inability of constantsimp") + + +def _simplify_variable_coeff(sol, syms, func, funcarg): + r""" + Helper function to replace constants by functions in 1st_linear_variable_coeff + """ + eta = Symbol("eta") + if len(syms) == 1: + sym = syms.pop() + final = sol.subs(sym, func(funcarg)) + + else: + for sym in syms: + final = sol.subs(sym, func(funcarg)) + + return simplify(final.subs(eta, funcarg)) + + +def pde_separate(eq, fun, sep, strategy='mul'): + """Separate variables in partial differential equation either by additive + or multiplicative separation approach. It tries to rewrite an equation so + that one of the specified variables occurs on a different side of the + equation than the others. + + :param eq: Partial differential equation + + :param fun: Original function F(x, y, z) + + :param sep: List of separated functions [X(x), u(y, z)] + + :param strategy: Separation strategy. You can choose between additive + separation ('add') and multiplicative separation ('mul') which is + default. + + Examples + ======== + + >>> from sympy import E, Eq, Function, pde_separate, Derivative as D + >>> from sympy.abc import x, t + >>> u, X, T = map(Function, 'uXT') + + >>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t)) + >>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='add') + [exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)] + + >>> eq = Eq(D(u(x, t), x, 2), D(u(x, t), t, 2)) + >>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='mul') + [Derivative(X(x), (x, 2))/X(x), Derivative(T(t), (t, 2))/T(t)] + + See Also + ======== + pde_separate_add, pde_separate_mul + """ + + do_add = False + if strategy == 'add': + do_add = True + elif strategy == 'mul': + do_add = False + else: + raise ValueError('Unknown strategy: %s' % strategy) + + if isinstance(eq, Equality): + if eq.rhs != 0: + return pde_separate(Eq(eq.lhs - eq.rhs, 0), fun, sep, strategy) + else: + return pde_separate(Eq(eq, 0), fun, sep, strategy) + + if eq.rhs != 0: + raise ValueError("Value should be 0") + + # Handle arguments + orig_args = list(fun.args) + subs_args = [arg for s in sep for arg in s.args] + + if do_add: + functions = reduce(operator.add, sep) + else: + functions = reduce(operator.mul, sep) + + # Check whether variables match + if len(subs_args) != len(orig_args): + raise ValueError("Variable counts do not match") + # Check for duplicate arguments like [X(x), u(x, y)] + if has_dups(subs_args): + raise ValueError("Duplicate substitution arguments detected") + # Check whether the variables match + if set(orig_args) != set(subs_args): + raise ValueError("Arguments do not match") + + # Substitute original function with separated... + result = eq.lhs.subs(fun, functions).doit() + + # Divide by terms when doing multiplicative separation + if not do_add: + eq = 0 + for i in result.args: + eq += i/functions + result = eq + + svar = subs_args[0] + dvar = subs_args[1:] + return _separate(result, svar, dvar) + + +def pde_separate_add(eq, fun, sep): + """ + Helper function for searching additive separable solutions. + + Consider an equation of two independent variables x, y and a dependent + variable w, we look for the product of two functions depending on different + arguments: + + `w(x, y, z) = X(x) + y(y, z)` + + Examples + ======== + + >>> from sympy import E, Eq, Function, pde_separate_add, Derivative as D + >>> from sympy.abc import x, t + >>> u, X, T = map(Function, 'uXT') + + >>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t)) + >>> pde_separate_add(eq, u(x, t), [X(x), T(t)]) + [exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)] + + """ + return pde_separate(eq, fun, sep, strategy='add') + + +def pde_separate_mul(eq, fun, sep): + """ + Helper function for searching multiplicative separable solutions. + + Consider an equation of two independent variables x, y and a dependent + variable w, we look for the product of two functions depending on different + arguments: + + `w(x, y, z) = X(x)*u(y, z)` + + Examples + ======== + + >>> from sympy import Function, Eq, pde_separate_mul, Derivative as D + >>> from sympy.abc import x, y + >>> u, X, Y = map(Function, 'uXY') + + >>> eq = Eq(D(u(x, y), x, 2), D(u(x, y), y, 2)) + >>> pde_separate_mul(eq, u(x, y), [X(x), Y(y)]) + [Derivative(X(x), (x, 2))/X(x), Derivative(Y(y), (y, 2))/Y(y)] + + """ + return pde_separate(eq, fun, sep, strategy='mul') + + +def _separate(eq, dep, others): + """Separate expression into two parts based on dependencies of variables.""" + + # FIRST PASS + # Extract derivatives depending our separable variable... + terms = set() + for term in eq.args: + if term.is_Mul: + for i in term.args: + if i.is_Derivative and not i.has(*others): + terms.add(term) + continue + elif term.is_Derivative and not term.has(*others): + terms.add(term) + # Find the factor that we need to divide by + div = set() + for term in terms: + ext, sep = term.expand().as_independent(dep) + # Failed? + if sep.has(*others): + return None + div.add(ext) + # FIXME: Find lcm() of all the divisors and divide with it, instead of + # current hack :( + # https://github.com/sympy/sympy/issues/4597 + if len(div) > 0: + # double sum required or some tests will fail + eq = Add(*[simplify(Add(*[term/i for i in div])) for term in eq.args]) + # SECOND PASS - separate the derivatives + div = set() + lhs = rhs = 0 + for term in eq.args: + # Check, whether we have already term with independent variable... + if not term.has(*others): + lhs += term + continue + # ...otherwise, try to separate + temp, sep = term.expand().as_independent(dep) + # Failed? + if sep.has(*others): + return None + # Extract the divisors + div.add(sep) + rhs -= term.expand() + # Do the division + fulldiv = reduce(operator.add, div) + lhs = simplify(lhs/fulldiv).expand() + rhs = simplify(rhs/fulldiv).expand() + # ...and check whether we were successful :) + if lhs.has(*others) or rhs.has(dep): + return None + return [lhs, rhs] diff --git a/phivenv/Lib/site-packages/sympy/solvers/polysys.py b/phivenv/Lib/site-packages/sympy/solvers/polysys.py new file mode 100644 index 0000000000000000000000000000000000000000..2edc70b36c25b986c975c33fcc57535ef0b31df2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/polysys.py @@ -0,0 +1,872 @@ +"""Solvers of systems of polynomial equations. """ + +from __future__ import annotations + +from typing import Any +from collections.abc import Sequence, Iterable + +import itertools + +from sympy import Dummy +from sympy.core import S +from sympy.core.expr import Expr +from sympy.core.exprtools import factor_terms +from sympy.core.sorting import default_sort_key +from sympy.logic.boolalg import Boolean +from sympy.polys import Poly, groebner, roots +from sympy.polys.domains import ZZ +from sympy.polys.polyoptions import build_options +from sympy.polys.polytools import parallel_poly_from_expr, sqf_part +from sympy.polys.polyerrors import ( + ComputationFailed, + PolificationFailed, + CoercionFailed, + GeneratorsNeeded, + DomainError +) +from sympy.simplify import rcollect +from sympy.utilities import postfixes +from sympy.utilities.iterables import cartes +from sympy.utilities.misc import filldedent +from sympy.logic.boolalg import Or, And +from sympy.core.relational import Eq + + +class SolveFailed(Exception): + """Raised when solver's conditions were not met. """ + + +def solve_poly_system(seq, *gens, strict=False, **args): + """ + Return a list of solutions for the system of polynomial equations + or else None. + + Parameters + ========== + + seq: a list/tuple/set + Listing all the equations that are needed to be solved + gens: generators + generators of the equations in seq for which we want the + solutions + strict: a boolean (default is False) + if strict is True, NotImplementedError will be raised if + the solution is known to be incomplete (which can occur if + not all solutions are expressible in radicals) + args: Keyword arguments + Special options for solving the equations. + + + Returns + ======= + + List[Tuple] + a list of tuples with elements being solutions for the + symbols in the order they were passed as gens + None + None is returned when the computed basis contains only the ground. + + Examples + ======== + + >>> from sympy import solve_poly_system + >>> from sympy.abc import x, y + + >>> solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) + [(0, 0), (2, -sqrt(2)), (2, sqrt(2))] + + >>> solve_poly_system([x**5 - x + y**3, y**2 - 1], x, y, strict=True) + Traceback (most recent call last): + ... + UnsolvableFactorError + + """ + try: + polys, opt = parallel_poly_from_expr(seq, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('solve_poly_system', len(seq), exc) + + if len(polys) == len(opt.gens) == 2: + f, g = polys + + if all(i <= 2 for i in f.degree_list() + g.degree_list()): + try: + return solve_biquadratic(f, g, opt) + except SolveFailed: + pass + + return solve_generic(polys, opt, strict=strict) + + +def solve_biquadratic(f, g, opt): + """Solve a system of two bivariate quadratic polynomial equations. + + Parameters + ========== + + f: a single Expr or Poly + First equation + g: a single Expr or Poly + Second Equation + opt: an Options object + For specifying keyword arguments and generators + + Returns + ======= + + List[Tuple] + a list of tuples with elements being solutions for the + symbols in the order they were passed as gens + None + None is returned when the computed basis contains only the ground. + + Examples + ======== + + >>> from sympy import Options, Poly + >>> from sympy.abc import x, y + >>> from sympy.solvers.polysys import solve_biquadratic + >>> NewOption = Options((x, y), {'domain': 'ZZ'}) + + >>> a = Poly(y**2 - 4 + x, y, x, domain='ZZ') + >>> b = Poly(y*2 + 3*x - 7, y, x, domain='ZZ') + >>> solve_biquadratic(a, b, NewOption) + [(1/3, 3), (41/27, 11/9)] + + >>> a = Poly(y + x**2 - 3, y, x, domain='ZZ') + >>> b = Poly(-y + x - 4, y, x, domain='ZZ') + >>> solve_biquadratic(a, b, NewOption) + [(7/2 - sqrt(29)/2, -sqrt(29)/2 - 1/2), (sqrt(29)/2 + 7/2, -1/2 + \ + sqrt(29)/2)] + """ + G = groebner([f, g]) + + if len(G) == 1 and G[0].is_ground: + return None + + if len(G) != 2: + raise SolveFailed + + x, y = opt.gens + p, q = G + if not p.gcd(q).is_ground: + # not 0-dimensional + raise SolveFailed + + p = Poly(p, x, expand=False) + p_roots = [rcollect(expr, y) for expr in roots(p).keys()] + + q = q.ltrim(-1) + q_roots = list(roots(q).keys()) + + solutions = [(p_root.subs(y, q_root), q_root) for q_root, p_root in + itertools.product(q_roots, p_roots)] + + return sorted(solutions, key=default_sort_key) + + +def solve_generic(polys, opt, strict=False): + """ + Solve a generic system of polynomial equations. + + Returns all possible solutions over C[x_1, x_2, ..., x_m] of a + set F = { f_1, f_2, ..., f_n } of polynomial equations, using + Groebner basis approach. For now only zero-dimensional systems + are supported, which means F can have at most a finite number + of solutions. If the basis contains only the ground, None is + returned. + + The algorithm works by the fact that, supposing G is the basis + of F with respect to an elimination order (here lexicographic + order is used), G and F generate the same ideal, they have the + same set of solutions. By the elimination property, if G is a + reduced, zero-dimensional Groebner basis, then there exists an + univariate polynomial in G (in its last variable). This can be + solved by computing its roots. Substituting all computed roots + for the last (eliminated) variable in other elements of G, new + polynomial system is generated. Applying the above procedure + recursively, a finite number of solutions can be found. + + The ability of finding all solutions by this procedure depends + on the root finding algorithms. If no solutions were found, it + means only that roots() failed, but the system is solvable. To + overcome this difficulty use numerical algorithms instead. + + Parameters + ========== + + polys: a list/tuple/set + Listing all the polynomial equations that are needed to be solved + opt: an Options object + For specifying keyword arguments and generators + strict: a boolean + If strict is True, NotImplementedError will be raised if the solution + is known to be incomplete + + Returns + ======= + + List[Tuple] + a list of tuples with elements being solutions for the + symbols in the order they were passed as gens + None + None is returned when the computed basis contains only the ground. + + References + ========== + + .. [Buchberger01] B. Buchberger, Groebner Bases: A Short + Introduction for Systems Theorists, In: R. Moreno-Diaz, + B. Buchberger, J.L. Freire, Proceedings of EUROCAST'01, + February, 2001 + + .. [Cox97] D. Cox, J. Little, D. O'Shea, Ideals, Varieties + and Algorithms, Springer, Second Edition, 1997, pp. 112 + + Raises + ======== + + NotImplementedError + If the system is not zero-dimensional (does not have a finite + number of solutions) + + UnsolvableFactorError + If ``strict`` is True and not all solution components are + expressible in radicals + + Examples + ======== + + >>> from sympy import Poly, Options + >>> from sympy.solvers.polysys import solve_generic + >>> from sympy.abc import x, y + >>> NewOption = Options((x, y), {'domain': 'ZZ'}) + + >>> a = Poly(x - y + 5, x, y, domain='ZZ') + >>> b = Poly(x + y - 3, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption) + [(-1, 4)] + + >>> a = Poly(x - 2*y + 5, x, y, domain='ZZ') + >>> b = Poly(2*x - y - 3, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption) + [(11/3, 13/3)] + + >>> a = Poly(x**2 + y, x, y, domain='ZZ') + >>> b = Poly(x + y*4, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption) + [(0, 0), (1/4, -1/16)] + + >>> a = Poly(x**5 - x + y**3, x, y, domain='ZZ') + >>> b = Poly(y**2 - 1, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption, strict=True) + Traceback (most recent call last): + ... + UnsolvableFactorError + + """ + def _is_univariate(f): + """Returns True if 'f' is univariate in its last variable. """ + for monom in f.monoms(): + if any(monom[:-1]): + return False + + return True + + def _subs_root(f, gen, zero): + """Replace generator with a root so that the result is nice. """ + p = f.as_expr({gen: zero}) + + if f.degree(gen) >= 2: + p = p.expand(deep=False) + + return p + + def _solve_reduced_system(system, gens, entry=False): + """Recursively solves reduced polynomial systems. """ + if len(system) == len(gens) == 1: + # the below line will produce UnsolvableFactorError if + # strict=True and the solution from `roots` is incomplete + zeros = list(roots(system[0], gens[-1], strict=strict).keys()) + return [(zero,) for zero in zeros] + + basis = groebner(system, gens, polys=True) + + if len(basis) == 1 and basis[0].is_ground: + if not entry: + return [] + else: + return None + + univariate = list(filter(_is_univariate, basis)) + + if len(basis) < len(gens): + raise NotImplementedError(filldedent(''' + only zero-dimensional systems supported + (finite number of solutions) + ''')) + + if len(univariate) == 1: + f = univariate.pop() + else: + raise NotImplementedError(filldedent(''' + only zero-dimensional systems supported + (finite number of solutions) + ''')) + + gens = f.gens + gen = gens[-1] + + # the below line will produce UnsolvableFactorError if + # strict=True and the solution from `roots` is incomplete + zeros = list(roots(f.ltrim(gen), strict=strict).keys()) + + if not zeros: + return [] + + if len(basis) == 1: + return [(zero,) for zero in zeros] + + solutions = [] + + for zero in zeros: + new_system = [] + new_gens = gens[:-1] + + for b in basis[:-1]: + eq = _subs_root(b, gen, zero) + + if eq is not S.Zero: + new_system.append(eq) + + for solution in _solve_reduced_system(new_system, new_gens): + solutions.append(solution + (zero,)) + + if solutions and len(solutions[0]) != len(gens): + raise NotImplementedError(filldedent(''' + only zero-dimensional systems supported + (finite number of solutions) + ''')) + return solutions + + try: + result = _solve_reduced_system(polys, opt.gens, entry=True) + except CoercionFailed: + raise NotImplementedError + + if result is not None: + return sorted(result, key=default_sort_key) + + +def solve_triangulated(polys, *gens, **args): + """ + Solve a polynomial system using Gianni-Kalkbrenner algorithm. + + The algorithm proceeds by computing one Groebner basis in the ground + domain and then by iteratively computing polynomial factorizations in + appropriately constructed algebraic extensions of the ground domain. + + Parameters + ========== + + polys: a list/tuple/set + Listing all the equations that are needed to be solved + gens: generators + generators of the equations in polys for which we want the + solutions + args: Keyword arguments + Special options for solving the equations + + Returns + ======= + + List[Tuple] + A List of tuples. Solutions for symbols that satisfy the + equations listed in polys + + Examples + ======== + + >>> from sympy import solve_triangulated + >>> from sympy.abc import x, y, z + + >>> F = [x**2 + y + z - 1, x + y**2 + z - 1, x + y + z**2 - 1] + + >>> solve_triangulated(F, x, y, z) + [(0, 0, 1), (0, 1, 0), (1, 0, 0)] + + Using extension for algebraic solutions. + + >>> solve_triangulated(F, x, y, z, extension=True) #doctest: +NORMALIZE_WHITESPACE + [(0, 0, 1), (0, 1, 0), (1, 0, 0), + (CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0)), + (CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1))] + + References + ========== + + 1. Patrizia Gianni, Teo Mora, Algebraic Solution of System of + Polynomial Equations using Groebner Bases, AAECC-5 on Applied Algebra, + Algebraic Algorithms and Error-Correcting Codes, LNCS 356 247--257, 1989 + + """ + opt = build_options(gens, args) + + G = groebner(polys, gens, polys=True) + G = list(reversed(G)) + + extension = opt.get('extension', False) + if extension: + def _solve_univariate(f): + return [r for r, _ in f.all_roots(multiple=False, radicals=False)] + else: + domain = opt.get('domain') + + if domain is not None: + for i, g in enumerate(G): + G[i] = g.set_domain(domain) + + def _solve_univariate(f): + return list(f.ground_roots().keys()) + + f, G = G[0].ltrim(-1), G[1:] + dom = f.get_domain() + + zeros = _solve_univariate(f) + + if extension: + solutions = {((zero,), dom.algebraic_field(zero)) for zero in zeros} + else: + solutions = {((zero,), dom) for zero in zeros} + + var_seq = reversed(gens[:-1]) + vars_seq = postfixes(gens[1:]) + + for var, vars in zip(var_seq, vars_seq): + _solutions = set() + + for values, dom in solutions: + H, mapping = [], list(zip(vars, values)) + + for g in G: + _vars = (var,) + vars + + if g.has_only_gens(*_vars) and g.degree(var) != 0: + if extension: + g = g.set_domain(g.domain.unify(dom)) + h = g.ltrim(var).eval(dict(mapping)) + + if g.degree(var) == h.degree(): + H.append(h) + + p = min(H, key=lambda h: h.degree()) + zeros = _solve_univariate(p) + + for zero in zeros: + if not (zero in dom): + dom_zero = dom.algebraic_field(zero) + else: + dom_zero = dom + + _solutions.add(((zero,) + values, dom_zero)) + + solutions = _solutions + return sorted((s for s, _ in solutions), key=default_sort_key) + + +def factor_system(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]: + """ + Factorizes a system of polynomial equations into + irreducible subsystems. + + Parameters + ========== + + eqs : list + List of expressions to be factored. + Each expression is assumed to be equal to zero. + + gens : list, optional + Generator(s) of the polynomial ring. + If not provided, all free symbols will be used. + + **kwargs : dict, optional + Same optional arguments taken by ``factor`` + + Returns + ======= + + list[list[Expr]] + A list of lists of expressions, where each sublist represents + an irreducible subsystem. When solved, each subsystem gives + one component of the solution. Only generic solutions are + returned (cases not requiring parameters to be zero). + + Examples + ======== + + >>> from sympy.solvers.polysys import factor_system, factor_system_cond + >>> from sympy.abc import x, y, a, b, c + + A simple system with multiple solutions: + + >>> factor_system([x**2 - 1, y - 1]) + [[x + 1, y - 1], [x - 1, y - 1]] + + A system with no solution: + + >>> factor_system([x, 1]) + [] + + A system where any value of the symbol(s) is a solution: + + >>> factor_system([x - x, (x + 1)**2 - (x**2 + 2*x + 1)]) + [[]] + + A system with no generic solution: + + >>> factor_system([a*x*(x-1), b*y, c], [x, y]) + [] + + If c is added to the unknowns then the system has a generic solution: + + >>> factor_system([a*x*(x-1), b*y, c], [x, y, c]) + [[x - 1, y, c], [x, y, c]] + + Alternatively :func:`factor_system_cond` can be used to get degenerate + cases as well: + + >>> factor_system_cond([a*x*(x-1), b*y, c], [x, y]) + [[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]] + + Each of the above cases is only satisfiable in the degenerate case `c = 0`. + + The solution set of the original system represented + by eqs is the union of the solution sets of the + factorized systems. + + An empty list [] means no generic solution exists. + A list containing an empty list [[]] means any value of + the symbol(s) is a solution. + + See Also + ======== + + factor_system_cond : Returns both generic and degenerate solutions + factor_system_bool : Returns a Boolean combination representing all solutions + sympy.polys.polytools.factor : Factors a polynomial into irreducible factors + over the rational numbers + """ + + systems = _factor_system_poly_from_expr(eqs, gens, **kwargs) + systems_generic = [sys for sys in systems if not _is_degenerate(sys)] + systems_expr = [[p.as_expr() for p in system] for system in systems_generic] + return systems_expr + + +def _is_degenerate(system: list[Poly]) -> bool: + """Helper function to check if a system is degenerate""" + return any(p.is_ground for p in system) + + +def factor_system_bool(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> Boolean: + """ + Factorizes a system of polynomial equations into irreducible DNF. + + The system of expressions(eqs) is taken and a Boolean combination + of equations is returned that represents the same solution set. + The result is in disjunctive normal form (OR of ANDs). + + Parameters + ========== + + eqs : list + List of expressions to be factored. + Each expression is assumed to be equal to zero. + + gens : list, optional + Generator(s) of the polynomial ring. + If not provided, all free symbols will be used. + + **kwargs : dict, optional + Optional keyword arguments + + + Returns + ======= + + Boolean: + A Boolean combination of equations. The result is typically in + the form of a conjunction (AND) of a disjunctive normal form + with additional conditions. + + Examples + ======== + + >>> from sympy.solvers.polysys import factor_system_bool + >>> from sympy.abc import x, y, a, b, c + >>> factor_system_bool([x**2 - 1]) + Eq(x - 1, 0) | Eq(x + 1, 0) + + >>> factor_system_bool([x**2 - 1, y - 1]) + (Eq(x - 1, 0) & Eq(y - 1, 0)) | (Eq(x + 1, 0) & Eq(y - 1, 0)) + + >>> eqs = [a * (x - 1), b] + >>> factor_system_bool([a*(x - 1), b]) + (Eq(a, 0) & Eq(b, 0)) | (Eq(b, 0) & Eq(x - 1, 0)) + + >>> factor_system_bool([a*x**2 - a, b*(x + 1), c], [x]) + (Eq(c, 0) & Eq(x + 1, 0)) | (Eq(a, 0) & Eq(b, 0) & Eq(c, 0)) | (Eq(b, 0) & Eq(c, 0) & Eq(x - 1, 0)) + + >>> factor_system_bool([x**2 + 2*x + 1 - (x + 1)**2]) + True + + The result is logically equivalent to the system of equations + i.e. eqs. The function returns ``True`` when all values of + the symbol(s) is a solution and ``False`` when the system + cannot be solved. + + See Also + ======== + + factor_system : Returns factors and solvability condition separately + factor_system_cond : Returns both factors and conditions + + """ + + systems = factor_system_cond(eqs, gens, **kwargs) + return Or(*[And(*[Eq(eq, 0) for eq in sys]) for sys in systems]) + + +def factor_system_cond(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]: + """ + Factorizes a polynomial system into irreducible components and returns + both generic and degenerate solutions. + + Parameters + ========== + + eqs : list + List of expressions to be factored. + Each expression is assumed to be equal to zero. + + gens : list, optional + Generator(s) of the polynomial ring. + If not provided, all free symbols will be used. + + **kwargs : dict, optional + Optional keyword arguments. + + Returns + ======= + + list[list[Expr]] + A list of lists of expressions, where each sublist represents + an irreducible subsystem. Includes both generic solutions and + degenerate cases requiring equality conditions on parameters. + + Examples + ======== + + >>> from sympy.solvers.polysys import factor_system_cond + >>> from sympy.abc import x, y, a, b, c + + >>> factor_system_cond([x**2 - 4, a*y, b], [x, y]) + [[x + 2, y, b], [x - 2, y, b], [x + 2, a, b], [x - 2, a, b]] + + >>> factor_system_cond([a*x*(x-1), b*y, c], [x, y]) + [[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]] + + An empty list [] means no solution exists. + A list containing an empty list [[]] means any value of + the symbol(s) is a solution. + + See Also + ======== + + factor_system : Returns only generic solutions + factor_system_bool : Returns a Boolean combination representing all solutions + sympy.polys.polytools.factor : Factors a polynomial into irreducible factors + over the rational numbers + """ + systems_poly = _factor_system_poly_from_expr(eqs, gens, **kwargs) + systems = [[p.as_expr() for p in system] for system in systems_poly] + return systems + + +def _factor_system_poly_from_expr( + eqs: Sequence[Expr | complex], gens: Sequence[Expr], **kwargs: Any +) -> list[list[Poly]]: + """ + Convert expressions to polynomials and factor the system. + + Takes a sequence of expressions, converts them to + polynomials, and factors the resulting system. Handles both regular + polynomial systems and purely numerical cases. + """ + try: + polys, opts = parallel_poly_from_expr(eqs, *gens, **kwargs) + only_numbers = False + except (GeneratorsNeeded, PolificationFailed): + _u = Dummy('u') + polys, opts = parallel_poly_from_expr(eqs, [_u], **kwargs) + assert opts['domain'].is_Numerical + only_numbers = True + + if only_numbers: + return [[]] if all(p == 0 for p in polys) else [] + + return factor_system_poly(polys) + + +def factor_system_poly(polys: list[Poly]) -> list[list[Poly]]: + """ + Factors a system of polynomial equations into irreducible subsystems + + Core implementation that works directly with Poly instances. + + Parameters + ========== + + polys : list[Poly] + A list of Poly instances to be factored. + + Returns + ======= + + list[list[Poly]] + A list of lists of polynomials, where each sublist represents + an irreducible component of the solution. Includes both + generic and degenerate cases. + + Examples + ======== + + >>> from sympy import symbols, Poly, ZZ + >>> from sympy.solvers.polysys import factor_system_poly + >>> a, b, c, x = symbols('a b c x') + >>> p1 = Poly((a - 1)*(x - 2), x, domain=ZZ[a,b,c]) + >>> p2 = Poly((b - 3)*(x - 2), x, domain=ZZ[a,b,c]) + >>> p3 = Poly(c, x, domain=ZZ[a,b,c]) + + The equation to be solved for x is ``x - 2 = 0`` provided either + of the two conditions on the parameters ``a`` and ``b`` is nonzero + and the constant parameter ``c`` should be zero. + + >>> sys1, sys2 = factor_system_poly([p1, p2, p3]) + >>> sys1 + [Poly(x - 2, x, domain='ZZ[a,b,c]'), + Poly(c, x, domain='ZZ[a,b,c]')] + >>> sys2 + [Poly(a - 1, x, domain='ZZ[a,b,c]'), + Poly(b - 3, x, domain='ZZ[a,b,c]'), + Poly(c, x, domain='ZZ[a,b,c]')] + + An empty list [] when returned means no solution exists. + Whereas a list containing an empty list [[]] means any value is a solution. + + See Also + ======== + + factor_system : Returns only generic solutions + factor_system_bool : Returns a Boolean combination representing the solutions + factor_system_cond : Returns both generic and degenerate solutions + sympy.polys.polytools.factor : Factors a polynomial into irreducible factors + over the rational numbers + """ + if not all(isinstance(poly, Poly) for poly in polys): + raise TypeError("polys should be a list of Poly instances") + if not polys: + return [[]] + + base_domain = polys[0].domain + base_gens = polys[0].gens + if not all(poly.domain == base_domain and poly.gens == base_gens for poly in polys[1:]): + raise DomainError("All polynomials must have the same domain and generators") + + factor_sets = [] + for poly in polys: + constant, factors_mult = poly.factor_list() + + if constant.is_zero is True: + continue + elif constant.is_zero is False: + if not factors_mult: + return [] + factor_sets.append([f for f, _ in factors_mult]) + else: + constant = sqf_part(factor_terms(constant).as_coeff_Mul()[1]) + constp = Poly(constant, base_gens, domain=base_domain) + factors = [f for f, _ in factors_mult] + factors.append(constp) + factor_sets.append(factors) + + if not factor_sets: + return [[]] + + result = _factor_sets(factor_sets) + return _sort_systems(result) + + +def _factor_sets_slow(eqs: list[list]) -> set[frozenset]: + """ + Helper to find the minimal set of factorised subsystems that is + equivalent to the original system. + + The result is in DNF. + """ + if not eqs: + return {frozenset()} + systems_set = {frozenset(sys) for sys in cartes(*eqs)} + return {s1 for s1 in systems_set if not any(s1 > s2 for s2 in systems_set)} + + +def _factor_sets(eqs: list[list]) -> set[frozenset]: + """ + Helper that builds factor combinations. + """ + if not eqs: + return {frozenset()} + + current_set = min(eqs, key=len) + other_sets = [s for s in eqs if s is not current_set] + + stack = [(factor, [s for s in other_sets if factor not in s], {factor}) + for factor in current_set] + + result = set() + + while stack: + factor, remaining_sets, current_solution = stack.pop() + + if not remaining_sets: + result.add(frozenset(current_solution)) + continue + + next_set = min(remaining_sets, key=len) + next_remaining = [s for s in remaining_sets if s is not next_set] + + for next_factor in next_set: + valid_remaining = [s for s in next_remaining if next_factor not in s] + new_solution = current_solution | {next_factor} + stack.append((next_factor, valid_remaining, new_solution)) + + return {s1 for s1 in result if not any(s1 > s2 for s2 in result)} + + +def _sort_systems(systems: Iterable[Iterable[Poly]]) -> list[list[Poly]]: + """Sorts a list of lists of polynomials""" + systems_list = [sorted(s, key=_poly_sort_key, reverse=True) for s in systems] + return sorted(systems_list, key=_sys_sort_key, reverse=True) + + +def _poly_sort_key(poly): + """Sort key for polynomials""" + if poly.domain.is_FF: + poly = poly.set_domain(ZZ) + return poly.degree_list(), poly.rep.to_list() + + +def _sys_sort_key(sys): + """Sort key for lists of polynomials""" + return list(zip(*map(_poly_sort_key, sys))) diff --git a/phivenv/Lib/site-packages/sympy/solvers/recurr.py b/phivenv/Lib/site-packages/sympy/solvers/recurr.py new file mode 100644 index 0000000000000000000000000000000000000000..ba627bbd4cb0844f11a8743634f5f10328aadca8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/recurr.py @@ -0,0 +1,843 @@ +r""" +This module is intended for solving recurrences or, in other words, +difference equations. Currently supported are linear, inhomogeneous +equations with polynomial or rational coefficients. + +The solutions are obtained among polynomials, rational functions, +hypergeometric terms, or combinations of hypergeometric term which +are pairwise dissimilar. + +``rsolve_X`` functions were meant as a low level interface +for ``rsolve`` which would use Mathematica's syntax. + +Given a recurrence relation: + + .. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) + + ... + a_{0}(n) y(n) = f(n) + +where `k > 0` and `a_{i}(n)` are polynomials in `n`. To use +``rsolve_X`` we need to put all coefficients in to a list ``L`` of +`k+1` elements the following way: + + ``L = [a_{0}(n), ..., a_{k-1}(n), a_{k}(n)]`` + +where ``L[i]``, for `i=0, \ldots, k`, maps to +`a_{i}(n) y(n+i)` (`y(n+i)` is implicit). + +For example if we would like to compute `m`-th Bernoulli polynomial +up to a constant (example was taken from rsolve_poly docstring), +then we would use `b(n+1) - b(n) = m n^{m-1}` recurrence, which +has solution `b(n) = B_m + C`. + +Then ``L = [-1, 1]`` and `f(n) = m n^(m-1)` and finally for `m=4`: + +>>> from sympy import Symbol, bernoulli, rsolve_poly +>>> n = Symbol('n', integer=True) + +>>> rsolve_poly([-1, 1], 4*n**3, n) +C0 + n**4 - 2*n**3 + n**2 + +>>> bernoulli(4, n) +n**4 - 2*n**3 + n**2 - 1/30 + +For the sake of completeness, `f(n)` can be: + + [1] a polynomial -> rsolve_poly + [2] a rational function -> rsolve_ratio + [3] a hypergeometric function -> rsolve_hyper +""" +from collections import defaultdict + +from sympy.concrete import product +from sympy.core.singleton import S +from sympy.core.numbers import Rational, I +from sympy.core.symbol import Symbol, Wild, Dummy +from sympy.core.relational import Equality +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify + +from sympy.simplify import simplify, hypersimp, hypersimilar # type: ignore +from sympy.solvers import solve, solve_undetermined_coeffs +from sympy.polys import Poly, quo, gcd, lcm, roots, resultant +from sympy.functions import binomial, factorial, FallingFactorial, RisingFactorial +from sympy.matrices import Matrix, casoratian +from sympy.utilities.iterables import numbered_symbols + + +def rsolve_poly(coeffs, f, n, shift=0, **hints): + r""" + Given linear recurrence operator `\operatorname{L}` of order + `k` with polynomial coefficients and inhomogeneous equation + `\operatorname{L} y = f`, where `f` is a polynomial, we seek for + all polynomial solutions over field `K` of characteristic zero. + + The algorithm performs two basic steps: + + (1) Compute degree `N` of the general polynomial solution. + (2) Find all polynomials of degree `N` or less + of `\operatorname{L} y = f`. + + There are two methods for computing the polynomial solutions. + If the degree bound is relatively small, i.e. it's smaller than + or equal to the order of the recurrence, then naive method of + undetermined coefficients is being used. This gives a system + of algebraic equations with `N+1` unknowns. + + In the other case, the algorithm performs transformation of the + initial equation to an equivalent one for which the system of + algebraic equations has only `r` indeterminates. This method is + quite sophisticated (in comparison with the naive one) and was + invented together by Abramov, Bronstein and Petkovsek. + + It is possible to generalize the algorithm implemented here to + the case of linear q-difference and differential equations. + + Lets say that we would like to compute `m`-th Bernoulli polynomial + up to a constant. For this we can use `b(n+1) - b(n) = m n^{m-1}` + recurrence, which has solution `b(n) = B_m + C`. For example: + + >>> from sympy import Symbol, rsolve_poly + >>> n = Symbol('n', integer=True) + + >>> rsolve_poly([-1, 1], 4*n**3, n) + C0 + n**4 - 2*n**3 + n**2 + + References + ========== + + .. [1] S. A. Abramov, M. Bronstein and M. Petkovsek, On polynomial + solutions of linear operator equations, in: T. Levelt, ed., + Proc. ISSAC '95, ACM Press, New York, 1995, 290-296. + + .. [2] M. Petkovsek, Hypergeometric solutions of linear recurrences + with polynomial coefficients, J. Symbolic Computation, + 14 (1992), 243-264. + + .. [3] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996. + + """ + f = sympify(f) + + if not f.is_polynomial(n): + return None + + homogeneous = f.is_zero + + r = len(coeffs) - 1 + + coeffs = [Poly(coeff, n) for coeff in coeffs] + + polys = [Poly(0, n)]*(r + 1) + terms = [(S.Zero, S.NegativeInfinity)]*(r + 1) + + for i in range(r + 1): + for j in range(i, r + 1): + polys[i] += coeffs[j]*(binomial(j, i).as_poly(n)) + + if not polys[i].is_zero: + (exp,), coeff = polys[i].LT() + terms[i] = (coeff, exp) + + d = b = terms[0][1] + + for i in range(1, r + 1): + if terms[i][1] > d: + d = terms[i][1] + + if terms[i][1] - i > b: + b = terms[i][1] - i + + d, b = int(d), int(b) + + x = Dummy('x') + + degree_poly = S.Zero + + for i in range(r + 1): + if terms[i][1] - i == b: + degree_poly += terms[i][0]*FallingFactorial(x, i) + + nni_roots = list(roots(degree_poly, x, filter='Z', + predicate=lambda r: r >= 0).keys()) + + if nni_roots: + N = [max(nni_roots)] + else: + N = [] + + if homogeneous: + N += [-b - 1] + else: + N += [f.as_poly(n).degree() - b, -b - 1] + + N = int(max(N)) + + if N < 0: + if homogeneous: + if hints.get('symbols', False): + return (S.Zero, []) + else: + return S.Zero + else: + return None + + if N <= r: + C = [] + y = E = S.Zero + + for i in range(N + 1): + C.append(Symbol('C' + str(i + shift))) + y += C[i] * n**i + + for i in range(r + 1): + E += coeffs[i].as_expr()*y.subs(n, n + i) + + solutions = solve_undetermined_coeffs(E - f, C, n) + + if solutions is not None: + _C = C + C = [c for c in C if (c not in solutions)] + result = y.subs(solutions) + else: + return None # TBD + else: + A = r + U = N + A + b + 1 + + nni_roots = list(roots(polys[r], filter='Z', + predicate=lambda r: r >= 0).keys()) + + if nni_roots != []: + a = max(nni_roots) + 1 + else: + a = S.Zero + + def _zero_vector(k): + return [S.Zero] * k + + def _one_vector(k): + return [S.One] * k + + def _delta(p, k): + B = S.One + D = p.subs(n, a + k) + + for i in range(1, k + 1): + B *= Rational(i - k - 1, i) + D += B * p.subs(n, a + k - i) + + return D + + alpha = {} + + for i in range(-A, d + 1): + I = _one_vector(d + 1) + + for k in range(1, d + 1): + I[k] = I[k - 1] * (x + i - k + 1)/k + + alpha[i] = S.Zero + + for j in range(A + 1): + for k in range(d + 1): + B = binomial(k, i + j) + D = _delta(polys[j].as_expr(), k) + + alpha[i] += I[k]*B*D + + V = Matrix(U, A, lambda i, j: int(i == j)) + + if homogeneous: + for i in range(A, U): + v = _zero_vector(A) + + for k in range(1, A + b + 1): + if i - k < 0: + break + + B = alpha[k - A].subs(x, i - k) + + for j in range(A): + v[j] += B * V[i - k, j] + + denom = alpha[-A].subs(x, i) + + for j in range(A): + V[i, j] = -v[j] / denom + else: + G = _zero_vector(U) + + for i in range(A, U): + v = _zero_vector(A) + g = S.Zero + + for k in range(1, A + b + 1): + if i - k < 0: + break + + B = alpha[k - A].subs(x, i - k) + + for j in range(A): + v[j] += B * V[i - k, j] + + g += B * G[i - k] + + denom = alpha[-A].subs(x, i) + + for j in range(A): + V[i, j] = -v[j] / denom + + G[i] = (_delta(f, i - A) - g) / denom + + P, Q = _one_vector(U), _zero_vector(A) + + for i in range(1, U): + P[i] = (P[i - 1] * (n - a - i + 1)/i).expand() + + for i in range(A): + Q[i] = Add(*[(v*p).expand() for v, p in zip(V[:, i], P)]) + + if not homogeneous: + h = Add(*[(g*p).expand() for g, p in zip(G, P)]) + + C = [Symbol('C' + str(i + shift)) for i in range(A)] + + g = lambda i: Add(*[c*_delta(q, i) for c, q in zip(C, Q)]) + + if homogeneous: + E = [g(i) for i in range(N + 1, U)] + else: + E = [g(i) + _delta(h, i) for i in range(N + 1, U)] + + if E != []: + solutions = solve(E, *C) + + if not solutions: + if homogeneous: + if hints.get('symbols', False): + return (S.Zero, []) + else: + return S.Zero + else: + return None + else: + solutions = {} + + if homogeneous: + result = S.Zero + else: + result = h + + _C = C[:] + for c, q in list(zip(C, Q)): + if c in solutions: + s = solutions[c]*q + C.remove(c) + else: + s = c*q + + result += s.expand() + + if C != _C: + # renumber so they are contiguous + result = result.xreplace(dict(zip(C, _C))) + C = _C[:len(C)] + + if hints.get('symbols', False): + return (result, C) + else: + return result + + +def rsolve_ratio(coeffs, f, n, **hints): + r""" + Given linear recurrence operator `\operatorname{L}` of order `k` + with polynomial coefficients and inhomogeneous equation + `\operatorname{L} y = f`, where `f` is a polynomial, we seek + for all rational solutions over field `K` of characteristic zero. + + This procedure accepts only polynomials, however if you are + interested in solving recurrence with rational coefficients + then use ``rsolve`` which will pre-process the given equation + and run this procedure with polynomial arguments. + + The algorithm performs two basic steps: + + (1) Compute polynomial `v(n)` which can be used as universal + denominator of any rational solution of equation + `\operatorname{L} y = f`. + + (2) Construct new linear difference equation by substitution + `y(n) = u(n)/v(n)` and solve it for `u(n)` finding all its + polynomial solutions. Return ``None`` if none were found. + + The algorithm implemented here is a revised version of the original + Abramov's algorithm, developed in 1989. The new approach is much + simpler to implement and has better overall efficiency. This + method can be easily adapted to the q-difference equations case. + + Besides finding rational solutions alone, this functions is + an important part of Hyper algorithm where it is used to find + a particular solution for the inhomogeneous part of a recurrence. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.solvers.recurr import rsolve_ratio + >>> rsolve_ratio([-2*x**3 + x**2 + 2*x - 1, 2*x**3 + x**2 - 6*x, + ... - 2*x**3 - 11*x**2 - 18*x - 9, 2*x**3 + 13*x**2 + 22*x + 8], 0, x) + C0*(2*x - 3)/(2*(x**2 - 1)) + + References + ========== + + .. [1] S. A. Abramov, Rational solutions of linear difference + and q-difference equations with polynomial coefficients, + in: T. Levelt, ed., Proc. ISSAC '95, ACM Press, New York, + 1995, 285-289 + + See Also + ======== + + rsolve_hyper + """ + f = sympify(f) + + if not f.is_polynomial(n): + return None + + coeffs = list(map(sympify, coeffs)) + + r = len(coeffs) - 1 + + A, B = coeffs[r], coeffs[0] + A = A.subs(n, n - r).expand() + + h = Dummy('h') + + res = resultant(A, B.subs(n, n + h), n) + + if not res.is_polynomial(h): + p, q = res.as_numer_denom() + res = quo(p, q, h) + + nni_roots = list(roots(res, h, filter='Z', + predicate=lambda r: r >= 0).keys()) + + if not nni_roots: + return rsolve_poly(coeffs, f, n, **hints) + else: + C, numers = S.One, [S.Zero]*(r + 1) + + for i in range(int(max(nni_roots)), -1, -1): + d = gcd(A, B.subs(n, n + i), n) + + A = quo(A, d, n) + B = quo(B, d.subs(n, n - i), n) + + C *= Mul(*[d.subs(n, n - j) for j in range(i + 1)]) + + denoms = [C.subs(n, n + i) for i in range(r + 1)] + + for i in range(r + 1): + g = gcd(coeffs[i], denoms[i], n) + + numers[i] = quo(coeffs[i], g, n) + denoms[i] = quo(denoms[i], g, n) + + for i in range(r + 1): + numers[i] *= Mul(*(denoms[:i] + denoms[i + 1:])) + + result = rsolve_poly(numers, f * Mul(*denoms), n, **hints) + + if result is not None: + if hints.get('symbols', False): + return (simplify(result[0] / C), result[1]) + else: + return simplify(result / C) + else: + return None + + +def rsolve_hyper(coeffs, f, n, **hints): + r""" + Given linear recurrence operator `\operatorname{L}` of order `k` + with polynomial coefficients and inhomogeneous equation + `\operatorname{L} y = f` we seek for all hypergeometric solutions + over field `K` of characteristic zero. + + The inhomogeneous part can be either hypergeometric or a sum + of a fixed number of pairwise dissimilar hypergeometric terms. + + The algorithm performs three basic steps: + + (1) Group together similar hypergeometric terms in the + inhomogeneous part of `\operatorname{L} y = f`, and find + particular solution using Abramov's algorithm. + + (2) Compute generating set of `\operatorname{L}` and find basis + in it, so that all solutions are linearly independent. + + (3) Form final solution with the number of arbitrary + constants equal to dimension of basis of `\operatorname{L}`. + + Term `a(n)` is hypergeometric if it is annihilated by first order + linear difference equations with polynomial coefficients or, in + simpler words, if consecutive term ratio is a rational function. + + The output of this procedure is a linear combination of fixed + number of hypergeometric terms. However the underlying method + can generate larger class of solutions - D'Alembertian terms. + + Note also that this method not only computes the kernel of the + inhomogeneous equation, but also reduces in to a basis so that + solutions generated by this procedure are linearly independent + + Examples + ======== + + >>> from sympy.solvers import rsolve_hyper + >>> from sympy.abc import x + + >>> rsolve_hyper([-1, -1, 1], 0, x) + C0*(1/2 - sqrt(5)/2)**x + C1*(1/2 + sqrt(5)/2)**x + + >>> rsolve_hyper([-1, 1], 1 + x, x) + C0 + x*(x + 1)/2 + + References + ========== + + .. [1] M. Petkovsek, Hypergeometric solutions of linear recurrences + with polynomial coefficients, J. Symbolic Computation, + 14 (1992), 243-264. + + .. [2] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996. + """ + coeffs = list(map(sympify, coeffs)) + + f = sympify(f) + + r, kernel, symbols = len(coeffs) - 1, [], set() + + if not f.is_zero: + if f.is_Add: + similar = {} + + for g in f.expand().args: + if not g.is_hypergeometric(n): + return None + + for h in similar.keys(): + if hypersimilar(g, h, n): + similar[h] += g + break + else: + similar[g] = S.Zero + + inhomogeneous = [g + h for g, h in similar.items()] + elif f.is_hypergeometric(n): + inhomogeneous = [f] + else: + return None + + for i, g in enumerate(inhomogeneous): + coeff, polys = S.One, coeffs[:] + denoms = [S.One]*(r + 1) + + s = hypersimp(g, n) + + for j in range(1, r + 1): + coeff *= s.subs(n, n + j - 1) + + p, q = coeff.as_numer_denom() + + polys[j] *= p + denoms[j] = q + + for j in range(r + 1): + polys[j] *= Mul(*(denoms[:j] + denoms[j + 1:])) + + # FIXME: The call to rsolve_ratio below should suffice (rsolve_poly + # call can be removed) but the XFAIL test_rsolve_ratio_missed must + # be fixed first. + R = rsolve_ratio(polys, Mul(*denoms), n, symbols=True) + if R is not None: + R, syms = R + if syms: + R = R.subs(zip(syms, [0]*len(syms))) + else: + R = rsolve_poly(polys, Mul(*denoms), n) + + if R: + inhomogeneous[i] *= R + else: + return None + + result = Add(*inhomogeneous) + result = simplify(result) + else: + result = S.Zero + + Z = Dummy('Z') + + p, q = coeffs[0], coeffs[r].subs(n, n - r + 1) + + p_factors = list(roots(p, n).keys()) + q_factors = list(roots(q, n).keys()) + + factors = [(S.One, S.One)] + + for p in p_factors: + for q in q_factors: + if p.is_integer and q.is_integer and p <= q: + continue + else: + factors += [(n - p, n - q)] + + p = [(n - p, S.One) for p in p_factors] + q = [(S.One, n - q) for q in q_factors] + + factors = p + factors + q + + for A, B in factors: + polys, degrees = [], [] + D = A*B.subs(n, n + r - 1) + + for i in range(r + 1): + a = Mul(*[A.subs(n, n + j) for j in range(i)]) + b = Mul(*[B.subs(n, n + j) for j in range(i, r)]) + + poly = quo(coeffs[i]*a*b, D, n) + polys.append(poly.as_poly(n)) + + if not poly.is_zero: + degrees.append(polys[i].degree()) + + if degrees: + d, poly = max(degrees), S.Zero + else: + return None + + for i in range(r + 1): + coeff = polys[i].nth(d) + + if coeff is not S.Zero: + poly += coeff * Z**i + + for z in roots(poly, Z).keys(): + if z.is_zero: + continue + + recurr_coeffs = [polys[i].as_expr()*z**i for i in range(r + 1)] + if d == 0 and 0 != Add(*[recurr_coeffs[j]*j for j in range(1, r + 1)]): + # faster inline check (than calling rsolve_poly) for a + # constant solution to a constant coefficient recurrence. + sol = [Symbol("C" + str(len(symbols)))] + else: + sol, syms = rsolve_poly(recurr_coeffs, 0, n, len(symbols), symbols=True) + sol = sol.collect(syms) + sol = [sol.coeff(s) for s in syms] + + for C in sol: + ratio = z * A * C.subs(n, n + 1) / B / C + ratio = simplify(ratio) + # If there is a nonnegative root in the denominator of the ratio, + # this indicates that the term y(n_root) is zero, and one should + # start the product with the term y(n_root + 1). + n0 = 0 + for n_root in roots(ratio.as_numer_denom()[1], n).keys(): + if n_root.has(I): + return None + elif (n0 < (n_root + 1)) == True: + n0 = n_root + 1 + K = product(ratio, (n, n0, n - 1)) + if K.has(factorial, FallingFactorial, RisingFactorial): + K = simplify(K) + + if casoratian(kernel + [K], n, zero=False) != 0: + kernel.append(K) + + kernel.sort(key=default_sort_key) + sk = list(zip(numbered_symbols('C'), kernel)) + + for C, ker in sk: + result += C * ker + + if hints.get('symbols', False): + # XXX: This returns the symbols in a non-deterministic order + symbols |= {s for s, k in sk} + return (result, list(symbols)) + else: + return result + + +def rsolve(f, y, init=None): + r""" + Solve univariate recurrence with rational coefficients. + + Given `k`-th order linear recurrence `\operatorname{L} y = f`, + or equivalently: + + .. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) + + \cdots + a_{0}(n) y(n) = f(n) + + where `a_{i}(n)`, for `i=0, \ldots, k`, are polynomials or rational + functions in `n`, and `f` is a hypergeometric function or a sum + of a fixed number of pairwise dissimilar hypergeometric terms in + `n`, finds all solutions or returns ``None``, if none were found. + + Initial conditions can be given as a dictionary in two forms: + + (1) ``{ n_0 : v_0, n_1 : v_1, ..., n_m : v_m}`` + (2) ``{y(n_0) : v_0, y(n_1) : v_1, ..., y(n_m) : v_m}`` + + or as a list ``L`` of values: + + ``L = [v_0, v_1, ..., v_m]`` + + where ``L[i] = v_i``, for `i=0, \ldots, m`, maps to `y(n_i)`. + + Examples + ======== + + Lets consider the following recurrence: + + .. math:: (n - 1) y(n + 2) - (n^2 + 3 n - 2) y(n + 1) + + 2 n (n + 1) y(n) = 0 + + >>> from sympy import Function, rsolve + >>> from sympy.abc import n + >>> y = Function('y') + + >>> f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n) + + >>> rsolve(f, y(n)) + 2**n*C0 + C1*factorial(n) + + >>> rsolve(f, y(n), {y(0):0, y(1):3}) + 3*2**n - 3*factorial(n) + + See Also + ======== + + rsolve_poly, rsolve_ratio, rsolve_hyper + + """ + if isinstance(f, Equality): + f = f.lhs - f.rhs + + n = y.args[0] + k = Wild('k', exclude=(n,)) + + # Preprocess user input to allow things like + # y(n) + a*(y(n + 1) + y(n - 1))/2 + f = f.expand().collect(y.func(Wild('m', integer=True))) + + h_part = defaultdict(list) + i_part = [] + for g in Add.make_args(f): + coeff, dep = g.as_coeff_mul(y.func) + if not dep: + i_part.append(coeff) + continue + for h in dep: + if h.is_Function and h.func == y.func: + result = h.args[0].match(n + k) + if result is not None: + h_part[int(result[k])].append(coeff) + continue + raise ValueError( + "'%s(%s + k)' expected, got '%s'" % (y.func, n, h)) + for k in h_part: + h_part[k] = Add(*h_part[k]) + h_part.default_factory = lambda: 0 + i_part = Add(*i_part) + + for k, coeff in h_part.items(): + h_part[k] = simplify(coeff) + + common = S.One + + if not i_part.is_zero and not i_part.is_hypergeometric(n) and \ + not (i_part.is_Add and all((x.is_hypergeometric(n) for x in i_part.expand().args))): + raise ValueError("The independent term should be a sum of hypergeometric functions, got '%s'" % i_part) + + for coeff in h_part.values(): + if coeff.is_rational_function(n): + if not coeff.is_polynomial(n): + common = lcm(common, coeff.as_numer_denom()[1], n) + else: + raise ValueError( + "Polynomial or rational function expected, got '%s'" % coeff) + + i_numer, i_denom = i_part.as_numer_denom() + + if i_denom.is_polynomial(n): + common = lcm(common, i_denom, n) + + if common is not S.One: + for k, coeff in h_part.items(): + numer, denom = coeff.as_numer_denom() + h_part[k] = numer*quo(common, denom, n) + + i_part = i_numer*quo(common, i_denom, n) + + K_min = min(h_part.keys()) + + if K_min < 0: + K = abs(K_min) + + H_part = defaultdict(lambda: S.Zero) + i_part = i_part.subs(n, n + K).expand() + common = common.subs(n, n + K).expand() + + for k, coeff in h_part.items(): + H_part[k + K] = coeff.subs(n, n + K).expand() + else: + H_part = h_part + + K_max = max(H_part.keys()) + coeffs = [H_part[i] for i in range(K_max + 1)] + + result = rsolve_hyper(coeffs, -i_part, n, symbols=True) + + if result is None: + return None + + solution, symbols = result + + if init in ({}, []): + init = None + + if symbols and init is not None: + if isinstance(init, list): + init = {i: init[i] for i in range(len(init))} + + equations = [] + + for k, v in init.items(): + try: + i = int(k) + except TypeError: + if k.is_Function and k.func == y.func: + i = int(k.args[0]) + else: + raise ValueError("Integer or term expected, got '%s'" % k) + + eq = solution.subs(n, i) - v + if eq.has(S.NaN): + eq = solution.limit(n, i) - v + equations.append(eq) + + result = solve(equations, *symbols) + + if not result: + return None + else: + solution = solution.subs(result) + + return solution diff --git a/phivenv/Lib/site-packages/sympy/solvers/simplex.py b/phivenv/Lib/site-packages/sympy/solvers/simplex.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e652cb626507d7829f9bc1c78fc6f49809865f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/simplex.py @@ -0,0 +1,1104 @@ +"""Tools for optimizing a linear function for a given simplex. + +For the linear objective function ``f`` with linear constraints +expressed using `Le`, `Ge` or `Eq` can be found with ``lpmin`` or +``lpmax``. The symbols are **unbounded** unless specifically +constrained. + +As an alternative, the matrices describing the objective and the +constraints, and an optional list of bounds can be passed to +``linprog`` which will solve for the minimization of ``C*x`` +under constraints ``A*x <= b`` and/or ``Aeq*x = beq``, and +individual bounds for variables given as ``(lo, hi)``. The values +returned are **nonnegative** unless bounds are provided that +indicate otherwise. + +Errors that might be raised are UnboundedLPError when there is no +finite solution for the system or InfeasibleLPError when the +constraints represent impossible conditions (i.e. a non-existent + simplex). + +Here is a simple 1-D system: minimize `x` given that ``x >= 1``. + + >>> from sympy.solvers.simplex import lpmin, linprog + >>> from sympy.abc import x + + The function and a list with the constraint is passed directly + to `lpmin`: + + >>> lpmin(x, [x >= 1]) + (1, {x: 1}) + + For `linprog` the matrix for the objective is `[1]` and the + uivariate constraint can be passed as a bound with None acting + as infinity: + + >>> linprog([1], bounds=(1, None)) + (1, [1]) + + Or the matrices, corresponding to ``x >= 1`` expressed as + ``-x <= -1`` as required by the routine, can be passed: + + >>> linprog([1], [-1], [-1]) + (1, [1]) + + If there is no limit for the objective, an error is raised. + In this case there is a valid region of interest (simplex) + but no limit to how small ``x`` can be: + + >>> lpmin(x, []) + Traceback (most recent call last): + ... + sympy.solvers.simplex.UnboundedLPError: + Objective function can assume arbitrarily large values! + + An error is raised if there is no possible solution: + + >>> lpmin(x,[x<=1,x>=2]) + Traceback (most recent call last): + ... + sympy.solvers.simplex.InfeasibleLPError: + Inconsistent/False constraint +""" + +from sympy.core import sympify +from sympy.core.exprtools import factor_terms +from sympy.core.relational import Le, Ge, Eq +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sorting import ordered +from sympy.functions.elementary.complexes import sign +from sympy.matrices.dense import Matrix, zeros +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.utilities.iterables import numbered_symbols +from sympy.utilities.misc import filldedent + + +class UnboundedLPError(Exception): + """ + A linear programming problem is said to be unbounded if its objective + function can assume arbitrarily large values. + + Example + ======= + + Suppose you want to maximize + 2x + subject to + x >= 0 + + There's no upper limit that 2x can take. + """ + + pass + + +class InfeasibleLPError(Exception): + """ + A linear programming problem is considered infeasible if its + constraint set is empty. That is, if the set of all vectors + satisfying the constraints is empty, then the problem is infeasible. + + Example + ======= + + Suppose you want to maximize + x + subject to + x >= 10 + x <= 9 + + No x can satisfy those constraints. + """ + + pass + + +def _pivot(M, i, j): + """ + The pivot element `M[i, j]` is inverted and the rest of the matrix + modified and returned as a new matrix; original is left unmodified. + + Example + ======= + + >>> from sympy.matrices.dense import Matrix + >>> from sympy.solvers.simplex import _pivot + >>> from sympy import var + >>> Matrix(3, 3, var('a:i')) + Matrix([ + [a, b, c], + [d, e, f], + [g, h, i]]) + >>> _pivot(_, 1, 0) + Matrix([ + [-a/d, -a*e/d + b, -a*f/d + c], + [ 1/d, e/d, f/d], + [-g/d, h - e*g/d, i - f*g/d]]) + """ + Mi, Mj, Mij = M[i, :], M[:, j], M[i, j] + if Mij == 0: + raise ZeroDivisionError( + "Tried to pivot about zero-valued entry.") + A = M - Mj * (Mi / Mij) + A[i, :] = Mi / Mij + A[:, j] = -Mj / Mij + A[i, j] = 1 / Mij + return A + + +def _choose_pivot_row(A, B, candidate_rows, pivot_col, Y): + # Choose row with smallest ratio + # If there are ties, pick using Bland's rule + return min(candidate_rows, key=lambda i: (B[i] / A[i, pivot_col], Y[i])) + + +def _simplex(A, B, C, D=None, dual=False): + """Return ``(o, x, y)`` obtained from the two-phase simplex method + using Bland's rule: ``o`` is the minimum value of primal, + ``Cx - D``, under constraints ``Ax <= B`` (with ``x >= 0``) and + the maximum of the dual, ``y^{T}B - D``, under constraints + ``A^{T}*y >= C^{T}`` (with ``y >= 0``). To compute the dual of + the system, pass `dual=True` and ``(o, y, x)`` will be returned. + + Note: the nonnegative constraints for ``x`` and ``y`` supercede + any values of ``A`` and ``B`` that are inconsistent with that + assumption, so if a constraint of ``x >= -1`` is represented + in ``A`` and ``B``, no value will be obtained that is negative; if + a constraint of ``x <= -1`` is represented, an error will be + raised since no solution is possible. + + This routine relies on the ability of determining whether an + expression is 0 or not. This is guaranteed if the input contains + only Float or Rational entries. It will raise a TypeError if + a relationship does not evaluate to True or False. + + Examples + ======== + + >>> from sympy.solvers.simplex import _simplex + >>> from sympy import Matrix + + Consider the simple minimization of ``f = x + y + 1`` under the + constraint that ``y + 2*x >= 4``. This is the "standard form" of + a minimization. + + In the nonnegative quadrant, this inequality describes a area above + a triangle with vertices at (0, 4), (0, 0) and (2, 0). The minimum + of ``f`` occurs at (2, 0). Define A, B, C, D for the standard + minimization: + + >>> A = Matrix([[2, 1]]) + >>> B = Matrix([4]) + >>> C = Matrix([[1, 1]]) + >>> D = Matrix([-1]) + + Confirm that this is the system of interest: + + >>> from sympy.abc import x, y + >>> X = Matrix([x, y]) + >>> (C*X - D)[0] + x + y + 1 + >>> [i >= j for i, j in zip(A*X, B)] + [2*x + y >= 4] + + Since `_simplex` will do a minimization for constraints given as + ``A*x <= B``, the signs of ``A`` and ``B`` must be negated since + the currently correspond to a greater-than inequality: + + >>> _simplex(-A, -B, C, D) + (3, [2, 0], [1/2]) + + The dual of minimizing ``f`` is maximizing ``F = c*y - d`` for + ``a*y <= b`` where ``a``, ``b``, ``c``, ``d`` are derived from the + transpose of the matrix representation of the standard minimization: + + >>> tr = lambda a, b, c, d: [i.T for i in (a, c, b, d)] + >>> a, b, c, d = tr(A, B, C, D) + + This time ``a*x <= b`` is the expected inequality for the `_simplex` + method, but to maximize ``F``, the sign of ``c`` and ``d`` must be + changed (so that minimizing the negative will give the negative of + the maximum of ``F``): + + >>> _simplex(a, b, -c, -d) + (-3, [1/2], [2, 0]) + + The negative of ``F`` and the min of ``f`` are the same. The dual + point `[1/2]` is the value of ``y`` that minimized ``F = c*y - d`` + under constraints a*x <= b``: + + >>> y = Matrix(['y']) + >>> (c*y - d)[0] + 4*y + 1 + >>> [i <= j for i, j in zip(a*y,b)] + [2*y <= 1, y <= 1] + + In this 1-dimensional dual system, the more restrictive constraint is + the first which limits ``y`` between 0 and 1/2 and the maximum of + ``F`` is attained at the nonzero value, hence is ``4*(1/2) + 1 = 3``. + + In this case the values for ``x`` and ``y`` were the same when the + dual representation was solved. This is not always the case (though + the value of the function will be the same). + + >>> l = [[1, 1], [-1, 1], [0, 1], [-1, 0]], [5, 1, 2, -1], [[1, 1]], [-1] + >>> A, B, C, D = [Matrix(i) for i in l] + >>> _simplex(A, B, -C, -D) + (-6, [3, 2], [1, 0, 0, 0]) + >>> _simplex(A, B, -C, -D, dual=True) # [5, 0] != [3, 2] + (-6, [1, 0, 0, 0], [5, 0]) + + In both cases the function has the same value: + + >>> Matrix(C)*Matrix([3, 2]) == Matrix(C)*Matrix([5, 0]) + True + + See Also + ======== + _lp - poses min/max problem in form compatible with _simplex + lpmin - minimization which calls _lp + lpmax - maximimzation which calls _lp + + References + ========== + + .. [1] Thomas S. Ferguson, LINEAR PROGRAMMING: A Concise Introduction + web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_lp.pdf + + """ + A, B, C, D = [Matrix(i) for i in (A, B, C, D or [0])] + if dual: + _o, d, p = _simplex(-A.T, C.T, B.T, -D) + return -_o, d, p + + if A and B: + M = Matrix([[A, B], [C, D]]) + else: + if A or B: + raise ValueError("must give A and B") + # no constraints given + M = Matrix([[C, D]]) + n = M.cols - 1 + m = M.rows - 1 + + if not all(i.is_Float or i.is_Rational for i in M): + # with literal Float and Rational we are guaranteed the + # ability of determining whether an expression is 0 or not + raise TypeError(filldedent(""" + Only rationals and floats are allowed. + """ + ) + ) + + # x variables have priority over y variables during Bland's rule + # since False < True + X = [(False, j) for j in range(n)] + Y = [(True, i) for i in range(m)] + + # Phase 1: find a feasible solution or determine none exist + + ## keep track of last pivot row and column + last = None + + while True: + B = M[:-1, -1] + A = M[:-1, :-1] + if all(B[i] >= 0 for i in range(B.rows)): + # We have found a feasible solution + break + + # Find k: first row with a negative rightmost entry + for k in range(B.rows): + if B[k] < 0: + break # use current value of k below + else: + pass # error will raise below + + # Choose pivot column, c + piv_cols = [_ for _ in range(A.cols) if A[k, _] < 0] + if not piv_cols: + raise InfeasibleLPError(filldedent(""" + The constraint set is empty!""")) + _, c = min((X[i], i) for i in piv_cols) # Bland's rule + + # Choose pivot row, r + piv_rows = [_ for _ in range(A.rows) if A[_, c] > 0 and B[_] > 0] + piv_rows.append(k) + r = _choose_pivot_row(A, B, piv_rows, c, Y) + + # check for oscillation + if (r, c) == last: + # Not sure what to do here; it looks like there will be + # oscillations; see o1 test added at this commit to + # see a system with no solution and the o2 for one + # with a solution. In the case of o2, the solution + # from linprog is the same as the one from lpmin, but + # the matrices created in the lpmin case are different + # than those created without replacements in linprog and + # the matrices in the linprog case lead to oscillations. + # If the matrices could be re-written in linprog like + # lpmin does, this behavior could be avoided and then + # perhaps the oscillating case would only occur when + # there is no solution. For now, the output is checked + # before exit if oscillations were detected and an + # error is raised there if the solution was invalid. + # + # cf section 6 of Ferguson for a non-cycling modification + last = True + break + last = r, c + + M = _pivot(M, r, c) + X[c], Y[r] = Y[r], X[c] + + # Phase 2: from a feasible solution, pivot to optimal + while True: + B = M[:-1, -1] + A = M[:-1, :-1] + C = M[-1, :-1] + + # Choose a pivot column, c + piv_cols = [_ for _ in range(n) if C[_] < 0] + if not piv_cols: + break + _, c = min((X[i], i) for i in piv_cols) # Bland's rule + + # Choose a pivot row, r + piv_rows = [_ for _ in range(m) if A[_, c] > 0] + if not piv_rows: + raise UnboundedLPError(filldedent(""" + Objective function can assume + arbitrarily large values!""")) + r = _choose_pivot_row(A, B, piv_rows, c, Y) + + M = _pivot(M, r, c) + X[c], Y[r] = Y[r], X[c] + + argmax = [None] * n + argmin_dual = [None] * m + + for i, (v, n) in enumerate(X): + if v == False: + argmax[n] = 0 + else: + argmin_dual[n] = M[-1, i] + + for i, (v, n) in enumerate(Y): + if v == True: + argmin_dual[n] = 0 + else: + argmax[n] = M[i, -1] + + if last and not all(i >= 0 for i in argmax + argmin_dual): + raise InfeasibleLPError(filldedent(""" + Oscillating system led to invalid solution. + If you believe there was a valid solution, please + report this as a bug.""")) + return -M[-1, -1], argmax, argmin_dual + + +## routines that use _simplex or support those that do + + +def _abcd(M, list=False): + """return parts of M as matrices or lists + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.solvers.simplex import _abcd + + >>> m = Matrix(3, 3, range(9)); m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> a, b, c, d = _abcd(m) + >>> a + Matrix([ + [0, 1], + [3, 4]]) + >>> b + Matrix([ + [2], + [5]]) + >>> c + Matrix([[6, 7]]) + >>> d + Matrix([[8]]) + + The matrices can be returned as compact lists, too: + + >>> L = a, b, c, d = _abcd(m, list=True); L + ([[0, 1], [3, 4]], [2, 5], [[6, 7]], [8]) + """ + + def aslist(i): + l = i.tolist() + if len(l[0]) == 1: # col vector + return [i[0] for i in l] + return l + + m = M[:-1, :-1], M[:-1, -1], M[-1, :-1], M[-1:, -1:] + if not list: + return m + return tuple([aslist(i) for i in m]) + + +def _m(a, b, c, d=None): + """return Matrix([[a, b], [c, d]]) from matrices + in Matrix or list form. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.solvers.simplex import _abcd, _m + >>> m = Matrix(3, 3, range(9)) + >>> L = _abcd(m, list=True); L + ([[0, 1], [3, 4]], [2, 5], [[6, 7]], [8]) + >>> _abcd(m) + (Matrix([ + [0, 1], + [3, 4]]), Matrix([ + [2], + [5]]), Matrix([[6, 7]]), Matrix([[8]])) + >>> assert m == _m(*L) == _m(*_) + """ + a, b, c, d = [Matrix(i) for i in (a, b, c, d or [0])] + return Matrix([[a, b], [c, d]]) + + +def _primal_dual(M, factor=True): + """return primal and dual function and constraints + assuming that ``M = Matrix([[A, b], [c, d]])`` and the + function ``c*x - d`` is being minimized with ``Ax >= b`` + for nonnegative values of ``x``. The dual and its + constraints will be for maximizing `b.T*y - d` subject + to ``A.T*y <= c.T``. + + Examples + ======== + + >>> from sympy.solvers.simplex import _primal_dual, lpmin, lpmax + >>> from sympy import Matrix + + The following matrix represents the primal task of + minimizing x + y + 7 for y >= x + 1 and y >= -2*x + 3. + The dual task seeks to maximize x + 3*y + 7 with + 2*y - x <= 1 and and x + y <= 1: + + >>> M = Matrix([ + ... [-1, 1, 1], + ... [ 2, 1, 3], + ... [ 1, 1, -7]]) + >>> p, d = _primal_dual(M) + + The minimum of the primal and maximum of the dual are the same + (though they occur at different points): + + >>> lpmin(*p) + (28/3, {x1: 2/3, x2: 5/3}) + >>> lpmax(*d) + (28/3, {y1: 1/3, y2: 2/3}) + + If the equivalent (but canonical) inequalities are + desired, leave `factor=True`, otherwise the unmodified + inequalities for M will be returned. + + >>> m = Matrix([ + ... [-3, -2, 4, -2], + ... [ 2, 0, 0, -2], + ... [ 0, 1, -3, 0]]) + + >>> _primal_dual(m, False) # last condition is 2*x1 >= -2 + ((x2 - 3*x3, + [-3*x1 - 2*x2 + 4*x3 >= -2, 2*x1 >= -2]), + (-2*y1 - 2*y2, + [-3*y1 + 2*y2 <= 0, -2*y1 <= 1, 4*y1 <= -3])) + + >>> _primal_dual(m) # condition now x1 >= -1 + ((x2 - 3*x3, + [-3*x1 - 2*x2 + 4*x3 >= -2, x1 >= -1]), + (-2*y1 - 2*y2, + [-3*y1 + 2*y2 <= 0, -2*y1 <= 1, 4*y1 <= -3])) + + If you pass the transpose of the matrix, the primal will be + identified as the standard minimization problem and the + dual as the standard maximization: + + >>> _primal_dual(m.T) + ((-2*x1 - 2*x2, + [-3*x1 + 2*x2 >= 0, -2*x1 >= 1, 4*x1 >= -3]), + (y2 - 3*y3, + [-3*y1 - 2*y2 + 4*y3 <= -2, y1 <= -1])) + + A matrix must have some size or else None will be returned for + the functions: + + >>> _primal_dual(Matrix([[1, 2]])) + ((x1 - 2, []), (-2, [])) + + >>> _primal_dual(Matrix([])) + ((None, []), (None, [])) + + References + ========== + + .. [1] David Galvin, Relations between Primal and Dual + www3.nd.edu/~dgalvin1/30210/30210_F07/presentations/dual_opt.pdf + """ + if not M: + return (None, []), (None, []) + if not hasattr(M, "shape"): + if len(M) not in (3, 4): + raise ValueError("expecting Matrix or 3 or 4 lists") + M = _m(*M) + m, n = [i - 1 for i in M.shape] + A, b, c, d = _abcd(M) + d = d[0] + _ = lambda x: numbered_symbols(x, start=1) + x = Matrix([i for i, j in zip(_("x"), range(n))]) + yT = Matrix([i for i, j in zip(_("y"), range(m))]).T + + def ineq(L, r, op): + rv = [] + for r in (op(i, j) for i, j in zip(L, r)): + if r == True: + continue + elif r == False: + return [False] + if factor: + f = factor_terms(r) + if f.lhs.is_Mul and f.rhs % f.lhs.args[0] == 0: + assert len(f.lhs.args) == 2, f.lhs + k = f.lhs.args[0] + r = r.func(sign(k) * f.lhs.args[1], f.rhs // abs(k)) + rv.append(r) + return rv + + eq = lambda x, d: x[0] - d if x else -d + F = eq(c * x, d) + f = eq(yT * b, d) + return (F, ineq(A * x, b, Ge)), (f, ineq(yT * A, c, Le)) + + +def _rel_as_nonpos(constr, syms): + """return `(np, d, aux)` where `np` is a list of nonpositive + expressions that represent the given constraints (possibly + rewritten in terms of auxilliary variables) expressible with + nonnegative symbols, and `d` is a dictionary mapping a given + symbols to an expression with an auxilliary variable. In some + cases a symbol will be used as part of the change of variables, + e.g. x: x - z1 instead of x: z1 - z2. + + If any constraint is False/empty, return None. All variables in + ``constr`` are assumed to be unbounded unless explicitly indicated + otherwise with a univariate constraint, e.g. ``x >= 0`` will + restrict ``x`` to nonnegative values. + + The ``syms`` must be included so all symbols can be given an + unbounded assumption if they are not otherwise bound with + univariate conditions like ``x <= 3``. + + Examples + ======== + + >>> from sympy.solvers.simplex import _rel_as_nonpos + >>> from sympy.abc import x, y + >>> _rel_as_nonpos([x >= y, x >= 0, y >= 0], (x, y)) + ([-x + y], {}, []) + >>> _rel_as_nonpos([x >= 3, x <= 5], [x]) + ([_z1 - 2], {x: _z1 + 3}, [_z1]) + >>> _rel_as_nonpos([x <= 5], [x]) + ([], {x: 5 - _z1}, [_z1]) + >>> _rel_as_nonpos([x >= 1], [x]) + ([], {x: _z1 + 1}, [_z1]) + """ + r = {} # replacements to handle change of variables + np = [] # nonpositive expressions + aux = [] # auxilliary symbols added + ui = numbered_symbols("z", start=1, cls=Dummy) # auxilliary symbols + univariate = {} # {x: interval} for univariate constraints + unbound = [] # symbols designated as unbound + syms = set(syms) # the expected syms of the system + + # separate out univariates + for i in constr: + if i == True: + continue # ignore + if i == False: + return # no solution + if i.has(S.Infinity, S.NegativeInfinity): + raise ValueError("only finite bounds are permitted") + if isinstance(i, (Le, Ge)): + i = i.lts - i.gts + freei = i.free_symbols + if freei - syms: + raise ValueError( + "unexpected symbol(s) in constraint: %s" % (freei - syms) + ) + if len(freei) > 1: + np.append(i) + elif freei: + x = freei.pop() + if x in unbound: + continue # will handle later + ivl = Le(i, 0, evaluate=False).as_set() + if x not in univariate: + univariate[x] = ivl + else: + univariate[x] &= ivl + elif i: + return False + else: + raise TypeError(filldedent(""" + only equalities like Eq(x, y) or non-strict + inequalities like x >= y are allowed in lp, not %s""" % i)) + + # introduce auxilliary variables as needed for univariate + # inequalities + for x in syms: + i = univariate.get(x, True) + if not i: + return None # no solution possible + if i == True: + unbound.append(x) + continue + a, b = i.inf, i.sup + if a.is_infinite: + u = next(ui) + r[x] = b - u + aux.append(u) + elif b.is_infinite: + if a: + u = next(ui) + r[x] = a + u + aux.append(u) + else: + # standard nonnegative relationship + pass + else: + u = next(ui) + aux.append(u) + # shift so u = x - a => x = u + a + r[x] = u + a + # add constraint for u <= b - a + # since when u = b-a then x = u + a = b - a + a = b: + # the upper limit for x + np.append(u - (b - a)) + + # make change of variables for unbound variables + for x in unbound: + u = next(ui) + r[x] = u - x # reusing x + aux.append(u) + + return np, r, aux + + +def _lp_matrices(objective, constraints): + """return A, B, C, D, r, x+X, X for maximizing + objective = Cx - D with constraints Ax <= B, introducing + introducing auxilliary variables, X, as necessary to make + replacements of symbols as given in r, {xi: expression with Xj}, + so all variables in x+X will take on nonnegative values. + + Every univariate condition creates a semi-infinite + condition, e.g. a single ``x <= 3`` creates the + interval ``[-oo, 3]`` while ``x <= 3`` and ``x >= 2`` + create an interval ``[2, 3]``. Variables not in a univariate + expression will take on nonnegative values. + """ + + # sympify input and collect free symbols + F = sympify(objective) + np = [sympify(i) for i in constraints] + syms = set.union(*[i.free_symbols for i in [F] + np], set()) + + # change Eq(x, y) to x - y <= 0 and y - x <= 0 + for i in range(len(np)): + if isinstance(np[i], Eq): + np[i] = np[i].lhs - np[i].rhs <= 0 + np.append(-np[i].lhs <= 0) + + # convert constraints to nonpositive expressions + _ = _rel_as_nonpos(np, syms) + if _ is None: + raise InfeasibleLPError(filldedent(""" + Inconsistent/False constraint""")) + np, r, aux = _ + + # do change of variables + F = F.xreplace(r) + np = [i.xreplace(r) for i in np] + + # convert to matrices + xx = list(ordered(syms)) + aux + A, B = linear_eq_to_matrix(np, xx) + C, D = linear_eq_to_matrix([F], xx) + return A, B, C, D, r, xx, aux + + +def _lp(min_max, f, constr): + """Return the optimization (min or max) of ``f`` with the given + constraints. All variables are unbounded unless constrained. + + If `min_max` is 'max' then the results corresponding to the + maximization of ``f`` will be returned, else the minimization. + The constraints can be given as Le, Ge or Eq expressions. + + Examples + ======== + + >>> from sympy.solvers.simplex import _lp as lp + >>> from sympy import Eq + >>> from sympy.abc import x, y, z + >>> f = x + y - 2*z + >>> c = [7*x + 4*y - 7*z <= 3, 3*x - y + 10*z <= 6] + >>> c += [i >= 0 for i in (x, y, z)] + >>> lp(min, f, c) + (-6/5, {x: 0, y: 0, z: 3/5}) + + By passing max, the maximum value for f under the constraints + is returned (if possible): + + >>> lp(max, f, c) + (3/4, {x: 0, y: 3/4, z: 0}) + + Constraints that are equalities will require that the solution + also satisfy them: + + >>> lp(max, f, c + [Eq(y - 9*x, 1)]) + (5/7, {x: 0, y: 1, z: 1/7}) + + All symbols are reported, even if they are not in the objective + function: + + >>> lp(min, x, [y + x >= 3, x >= 0]) + (0, {x: 0, y: 3}) + """ + # get the matrix components for the system expressed + # in terms of only nonnegative variables + A, B, C, D, r, xx, aux = _lp_matrices(f, constr) + + how = str(min_max).lower() + if "max" in how: + # _simplex minimizes for Ax <= B so we + # have to change the sign of the function + # and negate the optimal value returned + _o, p, d = _simplex(A, B, -C, -D) + o = -_o + elif "min" in how: + o, p, d = _simplex(A, B, C, D) + else: + raise ValueError("expecting min or max") + + # restore original variables and remove aux from p + p = dict(zip(xx, p)) + if r: # p has original symbols and auxilliary symbols + # if r has x: x - z1 use values from p to update + r = {k: v.xreplace(p) for k, v in r.items()} + # then use the actual value of x (= x - z1) in p + p.update(r) + # don't show aux + p = {k: p[k] for k in ordered(p) if k not in aux} + + # not returning dual since there may be extra constraints + # when a variable has finite bounds + return o, p + + +def lpmin(f, constr): + """return minimum of linear equation ``f`` under + linear constraints expressed using Ge, Le or Eq. + + All variables are unbounded unless constrained. + + Examples + ======== + + >>> from sympy.solvers.simplex import lpmin + >>> from sympy import Eq + >>> from sympy.abc import x, y + >>> lpmin(x, [2*x - 3*y >= -1, Eq(x + 3*y, 2), x <= 2*y]) + (1/3, {x: 1/3, y: 5/9}) + + Negative values for variables are permitted unless explicitly + excluding, so minimizing ``x`` for ``x <= 3`` is an + unbounded problem while the following has a bounded solution: + + >>> lpmin(x, [x >= 0, x <= 3]) + (0, {x: 0}) + + Without indicating that ``x`` is nonnegative, there + is no minimum for this objective: + + >>> lpmin(x, [x <= 3]) + Traceback (most recent call last): + ... + sympy.solvers.simplex.UnboundedLPError: + Objective function can assume arbitrarily large values! + + See Also + ======== + linprog, lpmax + """ + return _lp(min, f, constr) + + +def lpmax(f, constr): + """return maximum of linear equation ``f`` under + linear constraints expressed using Ge, Le or Eq. + + All variables are unbounded unless constrained. + + Examples + ======== + + >>> from sympy.solvers.simplex import lpmax + >>> from sympy import Eq + >>> from sympy.abc import x, y + >>> lpmax(x, [2*x - 3*y >= -1, Eq(x+ 3*y,2), x <= 2*y]) + (4/5, {x: 4/5, y: 2/5}) + + Negative values for variables are permitted unless explicitly + excluding: + + >>> lpmax(x, [x <= -1]) + (-1, {x: -1}) + + If a non-negative constraint is added for x, there is no + possible solution: + + >>> lpmax(x, [x <= -1, x >= 0]) + Traceback (most recent call last): + ... + sympy.solvers.simplex.InfeasibleLPError: inconsistent/False constraint + + See Also + ======== + linprog, lpmin + """ + return _lp(max, f, constr) + + +def _handle_bounds(bounds): + # introduce auxiliary variables as needed for univariate + # inequalities + + def _make_list(length: int, index_value_pairs): + li = [0] * length + for idx, val in index_value_pairs: + li[idx] = val + return li + + unbound = [] + row = [] + row2 = [] + b_len = len(bounds) + for x, (a, b) in enumerate(bounds): + if a is None and b is None: + unbound.append(x) + elif a is None: + # r[x] = b - u + b_len += 1 + row.append(_make_list(b_len, [(x, 1), (-1, 1)])) + row.append(_make_list(b_len, [(x, -1), (-1, -1)])) + row2.extend([[b], [-b]]) + elif b is None: + if a: + # r[x] = a + u + b_len += 1 + row.append(_make_list(b_len, [(x, 1), (-1, -1)])) + row.append(_make_list(b_len, [(x, -1), (-1, 1)])) + row2.extend([[a], [-a]]) + else: + # standard nonnegative relationship + pass + else: + # r[x] = u + a + b_len += 1 + row.append(_make_list(b_len, [(x, 1), (-1, -1)])) + row.append(_make_list(b_len, [(x, -1), (-1, 1)])) + # u <= b - a + row.append(_make_list(b_len, [(-1, 1)])) + row2.extend([[a], [-a], [b - a]]) + + # make change of variables for unbound variables + for x in unbound: + # r[x] = u - v + b_len += 2 + row.append(_make_list(b_len, [(x, 1), (-1, 1), (-2, -1)])) + row.append(_make_list(b_len, [(x, -1), (-1, -1), (-2, 1)])) + row2.extend([[0], [0]]) + + return Matrix([r + [0]*(b_len - len(r)) for r in row]), Matrix(row2) + + +def linprog(c, A=None, b=None, A_eq=None, b_eq=None, bounds=None): + """Return the minimization of ``c*x`` with the given + constraints ``A*x <= b`` and ``A_eq*x = b_eq``. Unless bounds + are given, variables will have nonnegative values in the solution. + + If ``A`` is not given, then the dimension of the system will + be determined by the length of ``C``. + + By default, all variables will be nonnegative. If ``bounds`` + is given as a single tuple, ``(lo, hi)``, then all variables + will be constrained to be between ``lo`` and ``hi``. Use + None for a ``lo`` or ``hi`` if it is unconstrained in the + negative or positive direction, respectively, e.g. + ``(None, 0)`` indicates nonpositive values. To set + individual ranges, pass a list with length equal to the + number of columns in ``A``, each element being a tuple; if + only a few variables take on non-default values they can be + passed as a dictionary with keys giving the corresponding + column to which the variable is assigned, e.g. ``bounds={2: + (1, 4)}`` would limit the 3rd variable to have a value in + range ``[1, 4]``. + + Examples + ======== + + >>> from sympy.solvers.simplex import linprog + >>> from sympy import symbols, Eq, linear_eq_to_matrix as M, Matrix + >>> x = x1, x2, x3, x4 = symbols('x1:5') + >>> X = Matrix(x) + >>> c, d = M(5*x2 + x3 + 4*x4 - x1, x) + >>> a, b = M([5*x2 + 2*x3 + 5*x4 - (x1 + 5)], x) + >>> aeq, beq = M([Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)], x) + >>> constr = [i <= j for i,j in zip(a*X, b)] + >>> constr += [Eq(i, j) for i,j in zip(aeq*X, beq)] + >>> linprog(c, a, b, aeq, beq) + (9/2, [0, 1/2, 0, 1/2]) + >>> assert all(i.subs(dict(zip(x, _[1]))) for i in constr) + + See Also + ======== + lpmin, lpmax + """ + + ## the objective + C = Matrix(c) + if C.rows != 1 and C.cols == 1: + C = C.T + if C.rows != 1: + raise ValueError("C must be a single row.") + + ## the inequalities + if not A: + if b: + raise ValueError("A and b must both be given") + # the governing equations will be simple constraints + # on variables + A, b = zeros(0, C.cols), zeros(C.cols, 1) + else: + A, b = [Matrix(i) for i in (A, b)] + + if A.cols != C.cols: + raise ValueError("number of columns in A and C must match") + + ## the equalities + if A_eq is None: + if not b_eq is None: + raise ValueError("A_eq and b_eq must both be given") + else: + A_eq, b_eq = [Matrix(i) for i in (A_eq, b_eq)] + # if x == y then x <= y and x >= y (-x <= -y) + A = A.col_join(A_eq) + A = A.col_join(-A_eq) + b = b.col_join(b_eq) + b = b.col_join(-b_eq) + + if not (bounds is None or bounds == {} or bounds == (0, None)): + ## the bounds are interpreted + if type(bounds) is tuple and len(bounds) == 2: + bounds = [bounds] * A.cols + elif len(bounds) == A.cols and all( + type(i) is tuple and len(i) == 2 for i in bounds): + pass # individual bounds + elif type(bounds) is dict and all( + type(i) is tuple and len(i) == 2 + for i in bounds.values()): + # sparse bounds + db = bounds + bounds = [(0, None)] * A.cols + while db: + i, j = db.popitem() + bounds[i] = j # IndexError if out-of-bounds indices + else: + raise ValueError("unexpected bounds %s" % bounds) + A_, b_ = _handle_bounds(bounds) + aux = A_.cols - A.cols + if A: + A = Matrix([[A, zeros(A.rows, aux)], [A_]]) + b = b.col_join(b_) + else: + A = A_ + b = b_ + C = C.row_join(zeros(1, aux)) + else: + aux = -A.cols # set so -aux will give all cols below + + o, p, d = _simplex(A, b, C) + return o, p[:-aux] # don't include aux values + +def show_linprog(c, A=None, b=None, A_eq=None, b_eq=None, bounds=None): + from sympy import symbols + ## the objective + C = Matrix(c) + if C.rows != 1 and C.cols == 1: + C = C.T + if C.rows != 1: + raise ValueError("C must be a single row.") + + ## the inequalities + if not A: + if b: + raise ValueError("A and b must both be given") + # the governing equations will be simple constraints + # on variables + A, b = zeros(0, C.cols), zeros(C.cols, 1) + else: + A, b = [Matrix(i) for i in (A, b)] + + if A.cols != C.cols: + raise ValueError("number of columns in A and C must match") + + ## the equalities + if A_eq is None: + if not b_eq is None: + raise ValueError("A_eq and b_eq must both be given") + else: + A_eq, b_eq = [Matrix(i) for i in (A_eq, b_eq)] + + if not (bounds is None or bounds == {} or bounds == (0, None)): + ## the bounds are interpreted + if type(bounds) is tuple and len(bounds) == 2: + bounds = [bounds] * A.cols + elif len(bounds) == A.cols and all( + type(i) is tuple and len(i) == 2 for i in bounds): + pass # individual bounds + elif type(bounds) is dict and all( + type(i) is tuple and len(i) == 2 + for i in bounds.values()): + # sparse bounds + db = bounds + bounds = [(0, None)] * A.cols + while db: + i, j = db.popitem() + bounds[i] = j # IndexError if out-of-bounds indices + else: + raise ValueError("unexpected bounds %s" % bounds) + + x = Matrix(symbols('x1:%s' % (A.cols+1))) + f,c = (C*x)[0], [i<=j for i,j in zip(A*x, b)] + [Eq(i,j) for i,j in zip(A_eq*x,b_eq)] + for i, (lo, hi) in enumerate(bounds): + if lo is not None: + c.append(x[i]>=lo) + if hi is not None: + c.append(x[i]<=hi) + return f,c diff --git a/phivenv/Lib/site-packages/sympy/solvers/solvers.py b/phivenv/Lib/site-packages/sympy/solvers/solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..ef621a84e34bde43a3181d2fd90e26fa7b05e968 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/solvers.py @@ -0,0 +1,3674 @@ +""" +This module contain solvers for all kinds of equations: + + - algebraic or transcendental, use solve() + + - recurrence, use rsolve() + + - differential, use dsolve() + + - nonlinear (numerically), use nsolve() + (you will need a good starting point) + +""" +from __future__ import annotations + +from sympy.core import (S, Add, Symbol, Dummy, Expr, Mul) +from sympy.core.assumptions import check_assumptions +from sympy.core.exprtools import factor_terms +from sympy.core.function import (expand_mul, expand_log, Derivative, + AppliedUndef, UndefinedFunction, nfloat, + Function, expand_power_exp, _mexpand, expand, + expand_func) +from sympy.core.logic import fuzzy_not, fuzzy_and +from sympy.core.numbers import Float, Rational, _illegal +from sympy.core.intfunc import integer_log, ilcm +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.sympify import sympify, _sympify +from sympy.core.traversal import preorder_traversal +from sympy.logic.boolalg import And, BooleanAtom + +from sympy.functions import (log, exp, LambertW, cos, sin, tan, acos, asin, atan, + Abs, re, im, arg, sqrt, atan2) +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.piecewise import piecewise_fold, Piecewise +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.integrals.integrals import Integral +from sympy.ntheory.factor_ import divisors +from sympy.simplify import (simplify, collect, powsimp, posify, # type: ignore + powdenest, nsimplify, denom, logcombine, sqrtdenest, fraction, + separatevars) +from sympy.simplify.sqrtdenest import sqrt_depth +from sympy.simplify.fu import TR1, TR2i, TR10, TR11 +from sympy.strategies.rl import rebuild +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices import Matrix, zeros +from sympy.polys import roots, cancel, factor, Poly +from sympy.polys.solvers import sympy_eqs_to_ring, solve_lin_sys +from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError +from sympy.polys.polytools import gcd +from sympy.utilities.lambdify import lambdify +from sympy.utilities.misc import filldedent, debugf +from sympy.utilities.iterables import (connected_components, + generate_bell, uniq, iterable, is_sequence, subsets, flatten, sift) +from sympy.utilities.decorator import conserve_mpmath_dps + +from mpmath import findroot + +from sympy.solvers.polysys import solve_poly_system + +from types import GeneratorType +from collections import defaultdict +from itertools import combinations, product + +import warnings + + +def recast_to_symbols(eqs, symbols): + """ + Return (e, s, d) where e and s are versions of *eqs* and + *symbols* in which any non-Symbol objects in *symbols* have + been replaced with generic Dummy symbols and d is a dictionary + that can be used to restore the original expressions. + + Examples + ======== + + >>> from sympy.solvers.solvers import recast_to_symbols + >>> from sympy import symbols, Function + >>> x, y = symbols('x y') + >>> fx = Function('f')(x) + >>> eqs, syms = [fx + 1, x, y], [fx, y] + >>> e, s, d = recast_to_symbols(eqs, syms); (e, s, d) + ([_X0 + 1, x, y], [_X0, y], {_X0: f(x)}) + + The original equations and symbols can be restored using d: + + >>> assert [i.xreplace(d) for i in eqs] == eqs + >>> assert [d.get(i, i) for i in s] == syms + + """ + if not iterable(eqs) and iterable(symbols): + raise ValueError('Both eqs and symbols must be iterable') + orig = list(symbols) + symbols = list(ordered(symbols)) + swap_sym = {} + i = 0 + for s in symbols: + if not isinstance(s, Symbol) and s not in swap_sym: + swap_sym[s] = Dummy('X%d' % i) + i += 1 + new_f = [] + for i in eqs: + isubs = getattr(i, 'subs', None) + if isubs is not None: + new_f.append(isubs(swap_sym)) + else: + new_f.append(i) + restore = {v: k for k, v in swap_sym.items()} + return new_f, [swap_sym.get(i, i) for i in orig], restore + + +def _ispow(e): + """Return True if e is a Pow or is exp.""" + return isinstance(e, Expr) and (e.is_Pow or isinstance(e, exp)) + + +def _simple_dens(f, symbols): + # when checking if a denominator is zero, we can just check the + # base of powers with nonzero exponents since if the base is zero + # the power will be zero, too. To keep it simple and fast, we + # limit simplification to exponents that are Numbers + dens = set() + for d in denoms(f, symbols): + if d.is_Pow and d.exp.is_Number: + if d.exp.is_zero: + continue # foo**0 is never 0 + d = d.base + dens.add(d) + return dens + + +def denoms(eq, *symbols): + """ + Return (recursively) set of all denominators that appear in *eq* + that contain any symbol in *symbols*; if *symbols* are not + provided then all denominators will be returned. + + Examples + ======== + + >>> from sympy.solvers.solvers import denoms + >>> from sympy.abc import x, y, z + + >>> denoms(x/y) + {y} + + >>> denoms(x/(y*z)) + {y, z} + + >>> denoms(3/x + y/z) + {x, z} + + >>> denoms(x/2 + y/z) + {2, z} + + If *symbols* are provided then only denominators containing + those symbols will be returned: + + >>> denoms(1/x + 1/y + 1/z, y, z) + {y, z} + + """ + + pot = preorder_traversal(eq) + dens = set() + for p in pot: + # Here p might be Tuple or Relational + # Expr subtrees (e.g. lhs and rhs) will be traversed after by pot + if not isinstance(p, Expr): + continue + den = denom(p) + if den is S.One: + continue + dens.update(Mul.make_args(den)) + if not symbols: + return dens + elif len(symbols) == 1: + if iterable(symbols[0]): + symbols = symbols[0] + return {d for d in dens if any(s in d.free_symbols for s in symbols)} + + +def checksol(f, symbol, sol=None, **flags): + """ + Checks whether sol is a solution of equation f == 0. + + Explanation + =========== + + Input can be either a single symbol and corresponding value + or a dictionary of symbols and values. When given as a dictionary + and flag ``simplify=True``, the values in the dictionary will be + simplified. *f* can be a single equation or an iterable of equations. + A solution must satisfy all equations in *f* to be considered valid; + if a solution does not satisfy any equation, False is returned; if one or + more checks are inconclusive (and none are False) then None is returned. + + Examples + ======== + + >>> from sympy import checksol, symbols + >>> x, y = symbols('x,y') + >>> checksol(x**4 - 1, x, 1) + True + >>> checksol(x**4 - 1, x, 0) + False + >>> checksol(x**2 + y**2 - 5**2, {x: 3, y: 4}) + True + + To check if an expression is zero using ``checksol()``, pass it + as *f* and send an empty dictionary for *symbol*: + + >>> checksol(x**2 + x - x*(x + 1), {}) + True + + None is returned if ``checksol()`` could not conclude. + + flags: + 'numerical=True (default)' + do a fast numerical check if ``f`` has only one symbol. + 'minimal=True (default is False)' + a very fast, minimal testing. + 'warn=True (default is False)' + show a warning if checksol() could not conclude. + 'simplify=True (default)' + simplify solution before substituting into function and + simplify the function before trying specific simplifications + 'force=True (default is False)' + make positive all symbols without assumptions regarding sign. + + """ + from sympy.physics.units import Unit + + minimal = flags.get('minimal', False) + + if sol is not None: + sol = {symbol: sol} + elif isinstance(symbol, dict): + sol = symbol + else: + msg = 'Expecting (sym, val) or ({sym: val}, None) but got (%s, %s)' + raise ValueError(msg % (symbol, sol)) + + if iterable(f): + if not f: + raise ValueError('no functions to check') + return fuzzy_and(checksol(fi, sol, **flags) for fi in f) + + f = _sympify(f) + + if f.is_number: + return f.is_zero + + if isinstance(f, Poly): + f = f.as_expr() + elif isinstance(f, (Eq, Ne)): + if f.rhs in (S.true, S.false): + f = f.reversed + B, E = f.args + if isinstance(B, BooleanAtom): + f = f.subs(sol) + if not f.is_Boolean: + return + elif isinstance(f, Eq): + f = Add(f.lhs, -f.rhs, evaluate=False) + + if isinstance(f, BooleanAtom): + return bool(f) + elif not f.is_Relational and not f: + return True + + illegal = set(_illegal) + if any(sympify(v).atoms() & illegal for k, v in sol.items()): + return False + + attempt = -1 + numerical = flags.get('numerical', True) + while 1: + attempt += 1 + if attempt == 0: + val = f.subs(sol) + if isinstance(val, Mul): + val = val.as_independent(Unit)[0] + if val.atoms() & illegal: + return False + elif attempt == 1: + if not val.is_number: + if not val.is_constant(*list(sol.keys()), simplify=not minimal): + return False + # there are free symbols -- simple expansion might work + _, val = val.as_content_primitive() + val = _mexpand(val.as_numer_denom()[0], recursive=True) + elif attempt == 2: + if minimal: + return + if flags.get('simplify', True): + for k in sol: + sol[k] = simplify(sol[k]) + # start over without the failed expanded form, possibly + # with a simplified solution + val = simplify(f.subs(sol)) + if flags.get('force', True): + val, reps = posify(val) + # expansion may work now, so try again and check + exval = _mexpand(val, recursive=True) + if exval.is_number: + # we can decide now + val = exval + else: + # if there are no radicals and no functions then this can't be + # zero anymore -- can it? + pot = preorder_traversal(expand_mul(val)) + seen = set() + saw_pow_func = False + for p in pot: + if p in seen: + continue + seen.add(p) + if p.is_Pow and not p.exp.is_Integer: + saw_pow_func = True + elif p.is_Function: + saw_pow_func = True + elif isinstance(p, UndefinedFunction): + saw_pow_func = True + if saw_pow_func: + break + if saw_pow_func is False: + return False + if flags.get('force', True): + # don't do a zero check with the positive assumptions in place + val = val.subs(reps) + nz = fuzzy_not(val.is_zero) + if nz is not None: + # issue 5673: nz may be True even when False + # so these are just hacks to keep a false positive + # from being returned + + # HACK 1: LambertW (issue 5673) + if val.is_number and val.has(LambertW): + # don't eval this to verify solution since if we got here, + # numerical must be False + return None + + # add other HACKs here if necessary, otherwise we assume + # the nz value is correct + return not nz + break + if val.is_Rational: + return val == 0 + if numerical and val.is_number: + return (abs(val.n(18).n(12, chop=True)) < 1e-9) is S.true + + if flags.get('warn', False): + warnings.warn("\n\tWarning: could not verify solution %s." % sol) + # returns None if it can't conclude + # TODO: improve solution testing + + +def solve(f, *symbols, **flags): + r""" + Algebraically solves equations and systems of equations. + + Explanation + =========== + + Currently supported: + - polynomial + - transcendental + - piecewise combinations of the above + - systems of linear and polynomial equations + - systems containing relational expressions + - systems implied by undetermined coefficients + + Examples + ======== + + The default output varies according to the input and might + be a list (possibly empty), a dictionary, a list of + dictionaries or tuples, or an expression involving relationals. + For specifics regarding different forms of output that may appear, see :ref:`solve_output`. + Let it suffice here to say that to obtain a uniform output from + `solve` use ``dict=True`` or ``set=True`` (see below). + + >>> from sympy import solve, Poly, Eq, Matrix, Symbol + >>> from sympy.abc import x, y, z, a, b + + The expressions that are passed can be Expr, Equality, or Poly + classes (or lists of the same); a Matrix is considered to be a + list of all the elements of the matrix: + + >>> solve(x - 3, x) + [3] + >>> solve(Eq(x, 3), x) + [3] + >>> solve(Poly(x - 3), x) + [3] + >>> solve(Matrix([[x, x + y]]), x, y) == solve([x, x + y], x, y) + True + + If no symbols are indicated to be of interest and the equation is + univariate, a list of values is returned; otherwise, the keys in + a dictionary will indicate which (of all the variables used in + the expression(s)) variables and solutions were found: + + >>> solve(x**2 - 4) + [-2, 2] + >>> solve((x - a)*(y - b)) + [{a: x}, {b: y}] + >>> solve([x - 3, y - 1]) + {x: 3, y: 1} + >>> solve([x - 3, y**2 - 1]) + [{x: 3, y: -1}, {x: 3, y: 1}] + + If you pass symbols for which solutions are sought, the output will vary + depending on the number of symbols you passed, whether you are passing + a list of expressions or not, and whether a linear system was solved. + Uniform output is attained by using ``dict=True`` or ``set=True``. + + >>> #### *** feel free to skip to the stars below *** #### + >>> from sympy import TableForm + >>> h = [None, ';|;'.join(['e', 's', 'solve(e, s)', 'solve(e, s, dict=True)', + ... 'solve(e, s, set=True)']).split(';')] + >>> t = [] + >>> for e, s in [ + ... (x - y, y), + ... (x - y, [x, y]), + ... (x**2 - y, [x, y]), + ... ([x - 3, y -1], [x, y]), + ... ]: + ... how = [{}, dict(dict=True), dict(set=True)] + ... res = [solve(e, s, **f) for f in how] + ... t.append([e, '|', s, '|'] + [res[0], '|', res[1], '|', res[2]]) + ... + >>> # ******************************************************* # + >>> TableForm(t, headings=h, alignments="<") + e | s | solve(e, s) | solve(e, s, dict=True) | solve(e, s, set=True) + --------------------------------------------------------------------------------------- + x - y | y | [x] | [{y: x}] | ([y], {(x,)}) + x - y | [x, y] | [(y, y)] | [{x: y}] | ([x, y], {(y, y)}) + x**2 - y | [x, y] | [(x, x**2)] | [{y: x**2}] | ([x, y], {(x, x**2)}) + [x - 3, y - 1] | [x, y] | {x: 3, y: 1} | [{x: 3, y: 1}] | ([x, y], {(3, 1)}) + + * If any equation does not depend on the symbol(s) given, it will be + eliminated from the equation set and an answer may be given + implicitly in terms of variables that were not of interest: + + >>> solve([x - y, y - 3], x) + {x: y} + + When you pass all but one of the free symbols, an attempt + is made to find a single solution based on the method of + undetermined coefficients. If it succeeds, a dictionary of values + is returned. If you want an algebraic solutions for one + or more of the symbols, pass the expression to be solved in a list: + + >>> e = a*x + b - 2*x - 3 + >>> solve(e, [a, b]) + {a: 2, b: 3} + >>> solve([e], [a, b]) + {a: -b/x + (2*x + 3)/x} + + When there is no solution for any given symbol which will make all + expressions zero, the empty list is returned (or an empty set in + the tuple when ``set=True``): + + >>> from sympy import sqrt + >>> solve(3, x) + [] + >>> solve(x - 3, y) + [] + >>> solve(sqrt(x) + 1, x, set=True) + ([x], set()) + + When an object other than a Symbol is given as a symbol, it is + isolated algebraically and an implicit solution may be obtained. + This is mostly provided as a convenience to save you from replacing + the object with a Symbol and solving for that Symbol. It will only + work if the specified object can be replaced with a Symbol using the + subs method: + + >>> from sympy import exp, Function + >>> f = Function('f') + + >>> solve(f(x) - x, f(x)) + [x] + >>> solve(f(x).diff(x) - f(x) - x, f(x).diff(x)) + [x + f(x)] + >>> solve(f(x).diff(x) - f(x) - x, f(x)) + [-x + Derivative(f(x), x)] + >>> solve(x + exp(x)**2, exp(x), set=True) + ([exp(x)], {(-sqrt(-x),), (sqrt(-x),)}) + + >>> from sympy import Indexed, IndexedBase, Tuple + >>> A = IndexedBase('A') + >>> eqs = Tuple(A[1] + A[2] - 3, A[1] - A[2] + 1) + >>> solve(eqs, eqs.atoms(Indexed)) + {A[1]: 1, A[2]: 2} + + * To solve for a function within a derivative, use :func:`~.dsolve`. + + To solve for a symbol implicitly, use implicit=True: + + >>> solve(x + exp(x), x) + [-LambertW(1)] + >>> solve(x + exp(x), x, implicit=True) + [-exp(x)] + + It is possible to solve for anything in an expression that can be + replaced with a symbol using :obj:`~sympy.core.basic.Basic.subs`: + + >>> solve(x + 2 + sqrt(3), x + 2) + [-sqrt(3)] + >>> solve((x + 2 + sqrt(3), x + 4 + y), y, x + 2) + {y: -2 + sqrt(3), x + 2: -sqrt(3)} + + * Nothing heroic is done in this implicit solving so you may end up + with a symbol still in the solution: + + >>> eqs = (x*y + 3*y + sqrt(3), x + 4 + y) + >>> solve(eqs, y, x + 2) + {y: -sqrt(3)/(x + 3), x + 2: -2*x/(x + 3) - 6/(x + 3) + sqrt(3)/(x + 3)} + >>> solve(eqs, y*x, x) + {x: -y - 4, x*y: -3*y - sqrt(3)} + + * If you attempt to solve for a number, remember that the number + you have obtained does not necessarily mean that the value is + equivalent to the expression obtained: + + >>> solve(sqrt(2) - 1, 1) + [sqrt(2)] + >>> solve(x - y + 1, 1) # /!\ -1 is targeted, too + [x/(y - 1)] + >>> [_.subs(z, -1) for _ in solve((x - y + 1).subs(-1, z), 1)] + [-x + y] + + **Additional Examples** + + ``solve()`` with check=True (default) will run through the symbol tags to + eliminate unwanted solutions. If no assumptions are included, all possible + solutions will be returned: + + >>> x = Symbol("x") + >>> solve(x**2 - 1) + [-1, 1] + + By setting the ``positive`` flag, only one solution will be returned: + + >>> pos = Symbol("pos", positive=True) + >>> solve(pos**2 - 1) + [1] + + When the solutions are checked, those that make any denominator zero + are automatically excluded. If you do not want to exclude such solutions, + then use the check=False option: + + >>> from sympy import sin, limit + >>> solve(sin(x)/x) # 0 is excluded + [pi] + + If ``check=False``, then a solution to the numerator being zero is found + but the value of $x = 0$ is a spurious solution since $\sin(x)/x$ has the well + known limit (without discontinuity) of 1 at $x = 0$: + + >>> solve(sin(x)/x, check=False) + [0, pi] + + In the following case, however, the limit exists and is equal to the + value of $x = 0$ that is excluded when check=True: + + >>> eq = x**2*(1/x - z**2/x) + >>> solve(eq, x) + [] + >>> solve(eq, x, check=False) + [0] + >>> limit(eq, x, 0, '-') + 0 + >>> limit(eq, x, 0, '+') + 0 + + **Solving Relationships** + + When one or more expressions passed to ``solve`` is a relational, + a relational result is returned (and the ``dict`` and ``set`` flags + are ignored): + + >>> solve(x < 3) + (-oo < x) & (x < 3) + >>> solve([x < 3, x**2 > 4], x) + ((-oo < x) & (x < -2)) | ((2 < x) & (x < 3)) + >>> solve([x + y - 3, x > 3], x) + (3 < x) & (x < oo) & Eq(x, 3 - y) + + Although checking of assumptions on symbols in relationals + is not done, setting assumptions will affect how certain + relationals might automatically simplify: + + >>> solve(x**2 > 4) + ((-oo < x) & (x < -2)) | ((2 < x) & (x < oo)) + + >>> r = Symbol('r', real=True) + >>> solve(r**2 > 4) + (2 < r) | (r < -2) + + There is currently no algorithm in SymPy that allows you to use + relationships to resolve more than one variable. So the following + does not determine that ``q < 0`` (and trying to solve for ``r`` + and ``q`` will raise an error): + + >>> from sympy import symbols + >>> r, q = symbols('r, q', real=True) + >>> solve([r + q - 3, r > 3], r) + (3 < r) & Eq(r, 3 - q) + + You can directly call the routine that ``solve`` calls + when it encounters a relational: :func:`~.reduce_inequalities`. + It treats Expr like Equality. + + >>> from sympy import reduce_inequalities + >>> reduce_inequalities([x**2 - 4]) + Eq(x, -2) | Eq(x, 2) + + If each relationship contains only one symbol of interest, + the expressions can be processed for multiple symbols: + + >>> reduce_inequalities([0 <= x - 1, y < 3], [x, y]) + (-oo < y) & (1 <= x) & (x < oo) & (y < 3) + + But an error is raised if any relationship has more than one + symbol of interest: + + >>> reduce_inequalities([0 <= x*y - 1, y < 3], [x, y]) + Traceback (most recent call last): + ... + NotImplementedError: + inequality has more than one symbol of interest. + + **Disabling High-Order Explicit Solutions** + + When solving polynomial expressions, you might not want explicit solutions + (which can be quite long). If the expression is univariate, ``CRootOf`` + instances will be returned instead: + + >>> solve(x**3 - x + 1) + [-1/((-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)) - + (-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3, + -(-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3 - + 1/((-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)), + -(3*sqrt(69)/2 + 27/2)**(1/3)/3 - + 1/(3*sqrt(69)/2 + 27/2)**(1/3)] + >>> solve(x**3 - x + 1, cubics=False) + [CRootOf(x**3 - x + 1, 0), + CRootOf(x**3 - x + 1, 1), + CRootOf(x**3 - x + 1, 2)] + + If the expression is multivariate, no solution might be returned: + + >>> solve(x**3 - x + a, x, cubics=False) + [] + + Sometimes solutions will be obtained even when a flag is False because the + expression could be factored. In the following example, the equation can + be factored as the product of a linear and a quadratic factor so explicit + solutions (which did not require solving a cubic expression) are obtained: + + >>> eq = x**3 + 3*x**2 + x - 1 + >>> solve(eq, cubics=False) + [-1, -1 + sqrt(2), -sqrt(2) - 1] + + **Solving Equations Involving Radicals** + + Because of SymPy's use of the principle root, some solutions + to radical equations will be missed unless check=False: + + >>> from sympy import root + >>> eq = root(x**3 - 3*x**2, 3) + 1 - x + >>> solve(eq) + [] + >>> solve(eq, check=False) + [1/3] + + In the above example, there is only a single solution to the + equation. Other expressions will yield spurious roots which + must be checked manually; roots which give a negative argument + to odd-powered radicals will also need special checking: + + >>> from sympy import real_root, S + >>> eq = root(x, 3) - root(x, 5) + S(1)/7 + >>> solve(eq) # this gives 2 solutions but misses a 3rd + [CRootOf(7*x**5 - 7*x**3 + 1, 1)**15, + CRootOf(7*x**5 - 7*x**3 + 1, 2)**15] + >>> sol = solve(eq, check=False) + >>> [abs(eq.subs(x,i).n(2)) for i in sol] + [0.48, 0.e-110, 0.e-110, 0.052, 0.052] + + The first solution is negative so ``real_root`` must be used to see that it + satisfies the expression: + + >>> abs(real_root(eq.subs(x, sol[0])).n(2)) + 0.e-110 + + If the roots of the equation are not real then more care will be + necessary to find the roots, especially for higher order equations. + Consider the following expression: + + >>> expr = root(x, 3) - root(x, 5) + + We will construct a known value for this expression at x = 3 by selecting + the 1-th root for each radical: + + >>> expr1 = root(x, 3, 1) - root(x, 5, 1) + >>> v = expr1.subs(x, -3) + + The ``solve`` function is unable to find any exact roots to this equation: + + >>> eq = Eq(expr, v); eq1 = Eq(expr1, v) + >>> solve(eq, check=False), solve(eq1, check=False) + ([], []) + + The function ``unrad``, however, can be used to get a form of the equation + for which numerical roots can be found: + + >>> from sympy.solvers.solvers import unrad + >>> from sympy import nroots + >>> e, (p, cov) = unrad(eq) + >>> pvals = nroots(e) + >>> inversion = solve(cov, x)[0] + >>> xvals = [inversion.subs(p, i) for i in pvals] + + Although ``eq`` or ``eq1`` could have been used to find ``xvals``, the + solution can only be verified with ``expr1``: + + >>> z = expr - v + >>> [xi.n(chop=1e-9) for xi in xvals if abs(z.subs(x, xi).n()) < 1e-9] + [] + >>> z1 = expr1 - v + >>> [xi.n(chop=1e-9) for xi in xvals if abs(z1.subs(x, xi).n()) < 1e-9] + [-3.0] + + Parameters + ========== + + f : + - a single Expr or Poly that must be zero + - an Equality + - a Relational expression + - a Boolean + - iterable of one or more of the above + + symbols : (object(s) to solve for) specified as + - none given (other non-numeric objects will be used) + - single symbol + - denested list of symbols + (e.g., ``solve(f, x, y)``) + - ordered iterable of symbols + (e.g., ``solve(f, [x, y])``) + + flags : + dict=True (default is False) + Return list (perhaps empty) of solution mappings. + set=True (default is False) + Return list of symbols and set of tuple(s) of solution(s). + exclude=[] (default) + Do not try to solve for any of the free symbols in exclude; + if expressions are given, the free symbols in them will + be extracted automatically. + check=True (default) + If False, do not do any testing of solutions. This can be + useful if you want to include solutions that make any + denominator zero. + numerical=True (default) + Do a fast numerical check if *f* has only one symbol. + minimal=True (default is False) + A very fast, minimal testing. + warn=True (default is False) + Show a warning if ``checksol()`` could not conclude. + simplify=True (default) + Simplify all but polynomials of order 3 or greater before + returning them and (if check is not False) use the + general simplify function on the solutions and the + expression obtained when they are substituted into the + function which should be zero. + force=True (default is False) + Make positive all symbols without assumptions regarding sign. + rational=True (default) + Recast Floats as Rational; if this option is not used, the + system containing Floats may fail to solve because of issues + with polys. If rational=None, Floats will be recast as + rationals but the answer will be recast as Floats. If the + flag is False then nothing will be done to the Floats. + manual=True (default is False) + Do not use the polys/matrix method to solve a system of + equations, solve them one at a time as you might "manually." + implicit=True (default is False) + Allows ``solve`` to return a solution for a pattern in terms of + other functions that contain that pattern; this is only + needed if the pattern is inside of some invertible function + like cos, exp, etc. + particular=True (default is False) + Instructs ``solve`` to try to find a particular solution to + a linear system with as many zeros as possible; this is very + expensive. + quick=True (default is False; ``particular`` must be True) + Selects a fast heuristic to find a solution with many zeros + whereas a value of False uses the very slow method guaranteed + to find the largest number of zeros possible. + cubics=True (default) + Return explicit solutions when cubic expressions are encountered. + When False, quartics and quintics are disabled, too. + quartics=True (default) + Return explicit solutions when quartic expressions are encountered. + When False, quintics are disabled, too. + quintics=True (default) + Return explicit solutions (if possible) when quintic expressions + are encountered. + + See Also + ======== + + rsolve: For solving recurrence relationships + sympy.solvers.ode.dsolve: For solving differential equations + + """ + from .inequalities import reduce_inequalities + + # checking/recording flags + ########################################################################### + + # set solver types explicitly; as soon as one is False + # all the rest will be False + hints = ('cubics', 'quartics', 'quintics') + default = True + for k in hints: + default = flags.setdefault(k, bool(flags.get(k, default))) + + # allow solution to contain symbol if True: + implicit = flags.get('implicit', False) + + # record desire to see warnings + warn = flags.get('warn', False) + + # this flag will be needed for quick exits below, so record + # now -- but don't record `dict` yet since it might change + as_set = flags.get('set', False) + + # keeping track of how f was passed + bare_f = not iterable(f) + + # check flag usage for particular/quick which should only be used + # with systems of equations + if flags.get('quick', None) is not None: + if not flags.get('particular', None): + raise ValueError('when using `quick`, `particular` should be True') + if flags.get('particular', False) and bare_f: + raise ValueError(filldedent(""" + The 'particular/quick' flag is usually used with systems of + equations. Either pass your equation in a list or + consider using a solver like `diophantine` if you are + looking for a solution in integers.""")) + + # sympify everything, creating list of expressions and list of symbols + ########################################################################### + + def _sympified_list(w): + return list(map(sympify, w if iterable(w) else [w])) + f, symbols = (_sympified_list(w) for w in [f, symbols]) + + # preprocess symbol(s) + ########################################################################### + + ordered_symbols = None # were the symbols in a well defined order? + if not symbols: + # get symbols from equations + symbols = set().union(*[fi.free_symbols for fi in f]) + if len(symbols) < len(f): + for fi in f: + pot = preorder_traversal(fi) + for p in pot: + if isinstance(p, AppliedUndef): + if not as_set: + flags['dict'] = True # better show symbols + symbols.add(p) + pot.skip() # don't go any deeper + ordered_symbols = False + symbols = list(ordered(symbols)) # to make it canonical + else: + if len(symbols) == 1 and iterable(symbols[0]): + symbols = symbols[0] + ordered_symbols = symbols and is_sequence(symbols, + include=GeneratorType) + _symbols = list(uniq(symbols)) + if len(_symbols) != len(symbols): + ordered_symbols = False + symbols = list(ordered(symbols)) + else: + symbols = _symbols + + # check for duplicates + if len(symbols) != len(set(symbols)): + raise ValueError('duplicate symbols given') + # remove those not of interest + exclude = flags.pop('exclude', set()) + if exclude: + if isinstance(exclude, Expr): + exclude = [exclude] + exclude = set().union(*[e.free_symbols for e in sympify(exclude)]) + symbols = [s for s in symbols if s not in exclude] + + # preprocess equation(s) + ########################################################################### + + # automatically ignore True values + if isinstance(f, list): + f = [s for s in f if s is not S.true] + + # handle canonicalization of equation types + for i, fi in enumerate(f): + if isinstance(fi, (Eq, Ne)): + if 'ImmutableDenseMatrix' in [type(a).__name__ for a in fi.args]: + fi = fi.lhs - fi.rhs + else: + L, R = fi.args + if isinstance(R, BooleanAtom): + L, R = R, L + if isinstance(L, BooleanAtom): + if isinstance(fi, Ne): + L = ~L + if R.is_Relational: + fi = ~R if L is S.false else R + elif R.is_Symbol: + return L + elif R.is_Boolean and (~R).is_Symbol: + return ~L + else: + raise NotImplementedError(filldedent(''' + Unanticipated argument of Eq when other arg + is True or False. + ''')) + elif isinstance(fi, Eq): + fi = Add(fi.lhs, -fi.rhs, evaluate=False) + f[i] = fi + + # *** dispatch and handle as a system of relationals + # ************************************************** + if fi.is_Relational: + if len(symbols) != 1: + raise ValueError("can only solve for one symbol at a time") + if warn and symbols[0].assumptions0: + warnings.warn(filldedent(""" + \tWarning: assumptions about variable '%s' are + not handled currently.""" % symbols[0])) + return reduce_inequalities(f, symbols=symbols) + + # convert Poly to expression + if isinstance(fi, Poly): + f[i] = fi.as_expr() + + # rewrite hyperbolics in terms of exp if they have symbols of + # interest + f[i] = f[i].replace(lambda w: isinstance(w, HyperbolicFunction) and \ + w.has_free(*symbols), lambda w: w.rewrite(exp)) + + # if we have a Matrix, we need to iterate over its elements again + if f[i].is_Matrix: + try: + f[i] = f[i].as_explicit() + except ValueError: + raise ValueError( + "solve cannot handle matrices with symbolic shape." + ) + bare_f = False + f.extend(list(f[i])) + f[i] = S.Zero + + # if we can split it into real and imaginary parts then do so + freei = f[i].free_symbols + if freei and all(s.is_extended_real or s.is_imaginary for s in freei): + fr, fi = f[i].as_real_imag() + # accept as long as new re, im, arg or atan2 are not introduced + had = f[i].atoms(re, im, arg, atan2) + if fr and fi and fr != fi and not any( + i.atoms(re, im, arg, atan2) - had for i in (fr, fi)): + if bare_f: + bare_f = False + f[i: i + 1] = [fr, fi] + + # real/imag handling ----------------------------- + if any(isinstance(fi, (bool, BooleanAtom)) for fi in f): + if as_set: + return [], set() + return [] + + for i, fi in enumerate(f): + # Abs + while True: + was = fi + fi = fi.replace(Abs, lambda arg: + separatevars(Abs(arg)).rewrite(Piecewise) if arg.has(*symbols) + else Abs(arg)) + if was == fi: + break + + for e in fi.find(Abs): + if e.has(*symbols): + raise NotImplementedError('solving %s when the argument ' + 'is not real or imaginary.' % e) + + # arg + fi = fi.replace(arg, lambda a: arg(a).rewrite(atan2).rewrite(atan)) + + # save changes + f[i] = fi + + # see if re(s) or im(s) appear + freim = [fi for fi in f if fi.has(re, im)] + if freim: + irf = [] + for s in symbols: + if s.is_real or s.is_imaginary: + continue # neither re(x) nor im(x) will appear + # if re(s) or im(s) appear, the auxiliary equation must be present + if any(fi.has(re(s), im(s)) for fi in freim): + irf.append((s, re(s) + S.ImaginaryUnit*im(s))) + if irf: + for s, rhs in irf: + f = [fi.xreplace({s: rhs}) for fi in f] + [s - rhs] + symbols.extend([re(s), im(s)]) + if bare_f: + bare_f = False + flags['dict'] = True + # end of real/imag handling ----------------------------- + + # we can solve for non-symbol entities by replacing them with Dummy symbols + f, symbols, swap_sym = recast_to_symbols(f, symbols) + # this set of symbols (perhaps recast) is needed below + symset = set(symbols) + + # get rid of equations that have no symbols of interest; we don't + # try to solve them because the user didn't ask and they might be + # hard to solve; this means that solutions may be given in terms + # of the eliminated equations e.g. solve((x-y, y-3), x) -> {x: y} + newf = [] + for fi in f: + # let the solver handle equations that.. + # - have no symbols but are expressions + # - have symbols of interest + # - have no symbols of interest but are constant + # but when an expression is not constant and has no symbols of + # interest, it can't change what we obtain for a solution from + # the remaining equations so we don't include it; and if it's + # zero it can be removed and if it's not zero, there is no + # solution for the equation set as a whole + # + # The reason for doing this filtering is to allow an answer + # to be obtained to queries like solve((x - y, y), x); without + # this mod the return value is [] + ok = False + if fi.free_symbols & symset: + ok = True + else: + if fi.is_number: + if fi.is_Number: + if fi.is_zero: + continue + return [] + ok = True + else: + if fi.is_constant(): + ok = True + if ok: + newf.append(fi) + if not newf: + if as_set: + return symbols, set() + return [] + f = newf + del newf + + # mask off any Object that we aren't going to invert: Derivative, + # Integral, etc... so that solving for anything that they contain will + # give an implicit solution + seen = set() + non_inverts = set() + for fi in f: + pot = preorder_traversal(fi) + for p in pot: + if not isinstance(p, Expr) or isinstance(p, Piecewise): + pass + elif (isinstance(p, bool) or + not p.args or + p in symset or + p.is_Add or p.is_Mul or + p.is_Pow and not implicit or + p.is_Function and not implicit) and p.func not in (re, im): + continue + elif p not in seen: + seen.add(p) + if p.free_symbols & symset: + non_inverts.add(p) + else: + continue + pot.skip() + del seen + non_inverts = dict(list(zip(non_inverts, [Dummy() for _ in non_inverts]))) + f = [fi.subs(non_inverts) for fi in f] + + # Both xreplace and subs are needed below: xreplace to force substitution + # inside Derivative, subs to handle non-straightforward substitutions + non_inverts = [(v, k.xreplace(swap_sym).subs(swap_sym)) for k, v in non_inverts.items()] + + # rationalize Floats + floats = False + if flags.get('rational', True) is not False: + for i, fi in enumerate(f): + if fi.has(Float): + floats = True + f[i] = nsimplify(fi, rational=True) + + # capture any denominators before rewriting since + # they may disappear after the rewrite, e.g. issue 14779 + flags['_denominators'] = _simple_dens(f[0], symbols) + + # Any embedded piecewise functions need to be brought out to the + # top level so that the appropriate strategy gets selected. + # However, this is necessary only if one of the piecewise + # functions depends on one of the symbols we are solving for. + def _has_piecewise(e): + if e.is_Piecewise: + return e.has(*symbols) + return any(_has_piecewise(a) for a in e.args) + for i, fi in enumerate(f): + if _has_piecewise(fi): + f[i] = piecewise_fold(fi) + + # expand angles of sums; in general, expand_trig will allow + # more roots to be found but this is not a great solultion + # to not returning a parametric solution, otherwise + # many values can be returned that have a simple + # relationship between values + targs = {t for fi in f for t in fi.atoms(TrigonometricFunction)} + if len(targs) > 1: + add, other = sift(targs, lambda x: x.args[0].is_Add, binary=True) + add, other = [[i for i in l if i.has_free(*symbols)] for l in (add, other)] + trep = {} + for t in add: + a = t.args[0] + ind, dep = a.as_independent(*symbols) + if dep in symbols or -dep in symbols: + # don't let expansion expand wrt anything in ind + n = Dummy() if not ind.is_Number else ind + trep[t] = TR10(t.func(dep + n)).xreplace({n: ind}) + if other and len(other) <= 2: + base = gcd(*[i.args[0] for i in other]) if len(other) > 1 else other[0].args[0] + for i in other: + trep[i] = TR11(i, base) + f = [fi.xreplace(trep) for fi in f] + + # + # try to get a solution + ########################################################################### + if bare_f: + solution = None + if len(symbols) != 1: + solution = _solve_undetermined(f[0], symbols, flags) + if not solution: + solution = _solve(f[0], *symbols, **flags) + else: + linear, solution = _solve_system(f, symbols, **flags) + assert type(solution) is list + assert not solution or type(solution[0]) is dict, solution + # + # postprocessing + ########################################################################### + # capture as_dict flag now (as_set already captured) + as_dict = flags.get('dict', False) + + # define how solution will get unpacked + tuple_format = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s] + if as_dict or as_set: + unpack = None + elif bare_f: + if len(symbols) == 1: + unpack = lambda s: [i[symbols[0]] for i in s] + elif len(solution) == 1 and len(solution[0]) == len(symbols): + # undetermined linear coeffs solution + unpack = lambda s: s[0] + elif ordered_symbols: + unpack = tuple_format + else: + unpack = lambda s: s + else: + if solution: + if linear and len(solution) == 1: + # if you want the tuple solution for the linear + # case, use `set=True` + unpack = lambda s: s[0] + elif ordered_symbols: + unpack = tuple_format + else: + unpack = lambda s: s + else: + unpack = None + + # Restore masked-off objects + if non_inverts and type(solution) is list: + solution = [{k: v.subs(non_inverts) for k, v in s.items()} + for s in solution] + + # Restore original "symbols" if a dictionary is returned. + # This is not necessary for + # - the single univariate equation case + # since the symbol will have been removed from the solution; + # - the nonlinear poly_system since that only supports zero-dimensional + # systems and those results come back as a list + # + # ** unless there were Derivatives with the symbols, but those were handled + # above. + if swap_sym: + symbols = [swap_sym.get(k, k) for k in symbols] + for i, sol in enumerate(solution): + solution[i] = {swap_sym.get(k, k): v.subs(swap_sym) + for k, v in sol.items()} + + # Get assumptions about symbols, to filter solutions. + # Note that if assumptions about a solution can't be verified, it is still + # returned. + check = flags.get('check', True) + + # restore floats + if floats and solution and flags.get('rational', None) is None: + solution = nfloat(solution, exponent=False) + # nfloat might reveal more duplicates + solution = _remove_duplicate_solutions(solution) + + if check and solution: # assumption checking + warn = flags.get('warn', False) + got_None = [] # solutions for which one or more symbols gave None + no_False = [] # solutions for which no symbols gave False + for sol in solution: + v = fuzzy_and(check_assumptions(val, **symb.assumptions0) + for symb, val in sol.items()) + if v is False: + continue + no_False.append(sol) + if v is None: + got_None.append(sol) + + solution = no_False + if warn and got_None: + warnings.warn(filldedent(""" + \tWarning: assumptions concerning following solution(s) + cannot be checked:""" + '\n\t' + + ', '.join(str(s) for s in got_None))) + + # + # done + ########################################################################### + + if not solution: + if as_set: + return symbols, set() + return [] + + # make orderings canonical for list of dictionaries + if not as_set: # for set, no point in ordering + solution = [{k: s[k] for k in ordered(s)} for s in solution] + solution.sort(key=default_sort_key) + + if not (as_set or as_dict): + return unpack(solution) + + if as_dict: + return solution + + # set output: (symbols, {t1, t2, ...}) from list of dictionaries; + # include all symbols for those that like a verbose solution + # and to resolve any differences in dictionary keys. + # + # The set results can easily be used to make a verbose dict as + # k, v = solve(eqs, syms, set=True) + # sol = [dict(zip(k,i)) for i in v] + # + if ordered_symbols: + k = symbols # keep preferred order + else: + # just unify the symbols for which solutions were found + k = list(ordered(set(flatten(tuple(i.keys()) for i in solution)))) + return k, {tuple([s.get(ki, ki) for ki in k]) for s in solution} + + +def _solve_undetermined(g, symbols, flags): + """solve helper to return a list with one dict (solution) else None + + A direct call to solve_undetermined_coeffs is more flexible and + can return both multiple solutions and handle more than one independent + variable. Here, we have to be more cautious to keep from solving + something that does not look like an undetermined coeffs system -- + to minimize the surprise factor since singularities that cancel are not + prohibited in solve_undetermined_coeffs. + """ + if g.free_symbols - set(symbols): + sol = solve_undetermined_coeffs(g, symbols, **dict(flags, dict=True, set=None)) + if len(sol) == 1: + return sol + + +def _solve(f, *symbols, **flags): + """Return a checked solution for *f* in terms of one or more of the + symbols in the form of a list of dictionaries. + + If no method is implemented to solve the equation, a NotImplementedError + will be raised. In the case that conversion of an expression to a Poly + gives None a ValueError will be raised. + """ + + not_impl_msg = "No algorithms are implemented to solve equation %s" + + if len(symbols) != 1: + # look for solutions for desired symbols that are independent + # of symbols already solved for, e.g. if we solve for x = y + # then no symbol having x in its solution will be returned. + + # First solve for linear symbols (since that is easier and limits + # solution size) and then proceed with symbols appearing + # in a non-linear fashion. Ideally, if one is solving a single + # expression for several symbols, they would have to be + # appear in factors of an expression, but we do not here + # attempt factorization. XXX perhaps handling a Mul + # should come first in this routine whether there is + # one or several symbols. + nonlin_s = [] + got_s = set() + rhs_s = set() + result = [] + for s in symbols: + xi, v = solve_linear(f, symbols=[s]) + if xi == s: + # no need to check but we should simplify if desired + if flags.get('simplify', True): + v = simplify(v) + vfree = v.free_symbols + if vfree & got_s: + # was linear, but has redundant relationship + # e.g. x - y = 0 has y == x is redundant for x == y + # so ignore + continue + rhs_s |= vfree + got_s.add(xi) + result.append({xi: v}) + elif xi: # there might be a non-linear solution if xi is not 0 + nonlin_s.append(s) + if not nonlin_s: + return result + for s in nonlin_s: + try: + soln = _solve(f, s, **flags) + for sol in soln: + if sol[s].free_symbols & got_s: + # depends on previously solved symbols: ignore + continue + got_s.add(s) + result.append(sol) + except NotImplementedError: + continue + if got_s: + return result + else: + raise NotImplementedError(not_impl_msg % f) + + # solve f for a single variable + + symbol = symbols[0] + + # expand binomials only if it has the unknown symbol + f = f.replace(lambda e: isinstance(e, binomial) and e.has(symbol), + lambda e: expand_func(e)) + + # checking will be done unless it is turned off before making a + # recursive call; the variables `checkdens` and `check` are + # captured here (for reference below) in case flag value changes + flags['check'] = checkdens = check = flags.pop('check', True) + + # build up solutions if f is a Mul + if f.is_Mul: + result = set() + for m in f.args: + if m in {S.NegativeInfinity, S.ComplexInfinity, S.Infinity}: + result = set() + break + soln = _vsolve(m, symbol, **flags) + result.update(set(soln)) + result = [{symbol: v} for v in result] + if check: + # all solutions have been checked but now we must + # check that the solutions do not set denominators + # in any factor to zero + dens = flags.get('_denominators', _simple_dens(f, symbols)) + result = [s for s in result if + not any(checksol(den, s, **flags) for den in + dens)] + # set flags for quick exit at end; solutions for each + # factor were already checked and simplified + check = False + flags['simplify'] = False + + elif f.is_Piecewise: + result = set() + if any(e.is_zero for e, c in f.args): + f = f.simplify() # failure imminent w/o help + + cond = neg = True + for expr, cnd in f.args: + # the explicit condition for this expr is the current cond + # and none of the previous conditions + cond = And(neg, cnd) + neg = And(neg, ~cond) + + if expr.is_zero and cond.simplify() != False: + raise NotImplementedError(filldedent(''' + An expression is already zero when %s. + This means that in this *region* the solution + is zero but solve can only represent discrete, + not interval, solutions. If this is a spurious + interval it might be resolved with simplification + of the Piecewise conditions.''' % cond)) + candidates = _vsolve(expr, symbol, **flags) + + for candidate in candidates: + if candidate in result: + # an unconditional value was already there + continue + try: + v = cond.subs(symbol, candidate) + _eval_simplify = getattr(v, '_eval_simplify', None) + if _eval_simplify is not None: + # unconditionally take the simplification of v + v = _eval_simplify(ratio=2, measure=lambda x: 1) + except TypeError: + # incompatible type with condition(s) + continue + if v == False: + continue + if v == True: + result.add(candidate) + else: + result.add(Piecewise( + (candidate, v), + (S.NaN, True))) + # solutions already checked and simplified + # **************************************** + return [{symbol: r} for r in result] + else: + # first see if it really depends on symbol and whether there + # is only a linear solution + f_num, sol = solve_linear(f, symbols=symbols) + if f_num.is_zero or sol is S.NaN: + return [] + elif f_num.is_Symbol: + # no need to check but simplify if desired + if flags.get('simplify', True): + sol = simplify(sol) + return [{f_num: sol}] + + poly = None + # check for a single Add generator + if not f_num.is_Add: + add_args = [i for i in f_num.atoms(Add) + if symbol in i.free_symbols] + if len(add_args) == 1: + gen = add_args[0] + spart = gen.as_independent(symbol)[1].as_base_exp()[0] + if spart == symbol: + try: + poly = Poly(f_num, spart) + except PolynomialError: + pass + + result = False # no solution was obtained + msg = '' # there is no failure message + + # Poly is generally robust enough to convert anything to + # a polynomial and tell us the different generators that it + # contains, so we will inspect the generators identified by + # polys to figure out what to do. + + # try to identify a single generator that will allow us to solve this + # as a polynomial, followed (perhaps) by a change of variables if the + # generator is not a symbol + + try: + if poly is None: + poly = Poly(f_num) + if poly is None: + raise ValueError('could not convert %s to Poly' % f_num) + except GeneratorsNeeded: + simplified_f = simplify(f_num) + if simplified_f != f_num: + return _solve(simplified_f, symbol, **flags) + raise ValueError('expression appears to be a constant') + + gens = [g for g in poly.gens if g.has(symbol)] + + def _as_base_q(x): + """Return (b**e, q) for x = b**(p*e/q) where p/q is the leading + Rational of the exponent of x, e.g. exp(-2*x/3) -> (exp(x), 3) + """ + b, e = x.as_base_exp() + if e.is_Rational: + return b, e.q + if not e.is_Mul: + return x, 1 + c, ee = e.as_coeff_Mul() + if c.is_Rational and c is not S.One: # c could be a Float + return b**ee, c.q + return x, 1 + + if len(gens) > 1: + # If there is more than one generator, it could be that the + # generators have the same base but different powers, e.g. + # >>> Poly(exp(x) + 1/exp(x)) + # Poly(exp(-x) + exp(x), exp(-x), exp(x), domain='ZZ') + # + # If unrad was not disabled then there should be no rational + # exponents appearing as in + # >>> Poly(sqrt(x) + sqrt(sqrt(x))) + # Poly(sqrt(x) + x**(1/4), sqrt(x), x**(1/4), domain='ZZ') + + bases, qs = list(zip(*[_as_base_q(g) for g in gens])) + bases = set(bases) + + if len(bases) > 1 or not all(q == 1 for q in qs): + funcs = {b for b in bases if b.is_Function} + + trig = {_ for _ in funcs if + isinstance(_, TrigonometricFunction)} + other = funcs - trig + if not other and len(funcs.intersection(trig)) > 1: + newf = None + if f_num.is_Add and len(f_num.args) == 2: + # check for sin(x)**p = cos(x)**p + _args = f_num.args + t = a, b = [i.atoms(Function).intersection( + trig) for i in _args] + if all(len(i) == 1 for i in t): + a, b = [i.pop() for i in t] + if isinstance(a, cos): + a, b = b, a + _args = _args[::-1] + if isinstance(a, sin) and isinstance(b, cos + ) and a.args[0] == b.args[0]: + # sin(x) + cos(x) = 0 -> tan(x) + 1 = 0 + newf, _d = (TR2i(_args[0]/_args[1]) + 1 + ).as_numer_denom() + if not _d.is_Number: + newf = None + if newf is None: + newf = TR1(f_num).rewrite(tan) + if newf != f_num: + # don't check the rewritten form --check + # solutions in the un-rewritten form below + flags['check'] = False + result = _solve(newf, symbol, **flags) + flags['check'] = check + + # just a simple case - see if replacement of single function + # clears all symbol-dependent functions, e.g. + # log(x) - log(log(x) - 1) - 3 can be solved even though it has + # two generators. + + if result is False and funcs: + funcs = list(ordered(funcs)) # put shallowest function first + f1 = funcs[0] + t = Dummy('t') + # perform the substitution + ftry = f_num.subs(f1, t) + + # if no Functions left, we can proceed with usual solve + if not ftry.has(symbol): + cv_sols = _solve(ftry, t, **flags) + cv_inv = list(ordered(_vsolve(t - f1, symbol, **flags)))[0] + result = [{symbol: cv_inv.subs(sol)} for sol in cv_sols] + + if result is False: + msg = 'multiple generators %s' % gens + + else: + # e.g. case where gens are exp(x), exp(-x) + u = bases.pop() + t = Dummy('t') + inv = _vsolve(u - t, symbol, **flags) + if isinstance(u, (Pow, exp)): + # this will be resolved by factor in _tsolve but we might + # as well try a simple expansion here to get things in + # order so something like the following will work now without + # having to factor: + # + # >>> eq = (exp(I*(-x-2))+exp(I*(x+2))) + # >>> eq.subs(exp(x),y) # fails + # exp(I*(-x - 2)) + exp(I*(x + 2)) + # >>> eq.expand().subs(exp(x),y) # works + # y**I*exp(2*I) + y**(-I)*exp(-2*I) + def _expand(p): + b, e = p.as_base_exp() + e = expand_mul(e) + return expand_power_exp(b**e) + ftry = f_num.replace( + lambda w: w.is_Pow or isinstance(w, exp), + _expand).subs(u, t) + if not ftry.has(symbol): + soln = _solve(ftry, t, **flags) + result = [{symbol: i.subs(s)} for i in inv for s in soln] + + elif len(gens) == 1: + + # There is only one generator that we are interested in, but + # there may have been more than one generator identified by + # polys (e.g. for symbols other than the one we are interested + # in) so recast the poly in terms of our generator of interest. + # Also use composite=True with f_num since Poly won't update + # poly as documented in issue 8810. + + poly = Poly(f_num, gens[0], composite=True) + + # if we aren't on the tsolve-pass, use roots + if not flags.pop('tsolve', False): + soln = None + deg = poly.degree() + flags['tsolve'] = True + hints = ('cubics', 'quartics', 'quintics') + solvers = {h: flags.get(h) for h in hints} + soln = roots(poly, **solvers) + if sum(soln.values()) < deg: + # e.g. roots(32*x**5 + 400*x**4 + 2032*x**3 + + # 5000*x**2 + 6250*x + 3189) -> {} + # so all_roots is used and RootOf instances are + # returned *unless* the system is multivariate + # or high-order EX domain. + try: + soln = poly.all_roots() + except NotImplementedError: + if not flags.get('incomplete', True): + raise NotImplementedError( + filldedent(''' + Neither high-order multivariate polynomials + nor sorting of EX-domain polynomials is supported. + If you want to see any results, pass keyword incomplete=True to + solve; to see numerical values of roots + for univariate expressions, use nroots. + ''')) + else: + pass + else: + soln = list(soln.keys()) + + if soln is not None: + u = poly.gen + if u != symbol: + try: + t = Dummy('t') + inv = _vsolve(u - t, symbol, **flags) + soln = {i.subs(t, s) for i in inv for s in soln} + except NotImplementedError: + # perhaps _tsolve can handle f_num + soln = None + else: + check = False # only dens need to be checked + if soln is not None: + if len(soln) > 2: + # if the flag wasn't set then unset it since high-order + # results are quite long. Perhaps one could base this + # decision on a certain critical length of the + # roots. In addition, wester test M2 has an expression + # whose roots can be shown to be real with the + # unsimplified form of the solution whereas only one of + # the simplified forms appears to be real. + flags['simplify'] = flags.get('simplify', False) + if soln is not None: + result = [{symbol: v} for v in soln] + + # fallback if above fails + # ----------------------- + if result is False: + # try unrad + if flags.pop('_unrad', True): + try: + u = unrad(f_num, symbol) + except (ValueError, NotImplementedError): + u = False + if u: + eq, cov = u + if cov: + isym, ieq = cov + inv = _vsolve(ieq, symbol, **flags)[0] + rv = {inv.subs(xi) for xi in _solve(eq, isym, **flags)} + else: + try: + rv = set(_vsolve(eq, symbol, **flags)) + except NotImplementedError: + rv = None + if rv is not None: + result = [{symbol: v} for v in rv] + # if the flag wasn't set then unset it since unrad results + # can be quite long or of very high order + flags['simplify'] = flags.get('simplify', False) + else: + pass # for coverage + + # try _tsolve + if result is False: + flags.pop('tsolve', None) # allow tsolve to be used on next pass + try: + soln = _tsolve(f_num, symbol, **flags) + if soln is not None: + result = [{symbol: v} for v in soln] + except PolynomialError: + pass + # ----------- end of fallback ---------------------------- + + if result is False: + raise NotImplementedError('\n'.join([msg, not_impl_msg % f])) + + result = _remove_duplicate_solutions(result) + + if flags.get('simplify', True): + result = [{k: d[k].simplify() for k in d} for d in result] + # Simplification might reveal more duplicates + result = _remove_duplicate_solutions(result) + # we just simplified the solution so we now set the flag to + # False so the simplification doesn't happen again in checksol() + flags['simplify'] = False + + if checkdens: + # reject any result that makes any denom. affirmatively 0; + # if in doubt, keep it + dens = _simple_dens(f, symbols) + result = [r for r in result if + not any(checksol(d, r, **flags) + for d in dens)] + if check: + # keep only results if the check is not False + result = [r for r in result if + checksol(f_num, r, **flags) is not False] + return result + + +def _remove_duplicate_solutions(solutions: list[dict[Expr, Expr]] + ) -> list[dict[Expr, Expr]]: + """Remove duplicates from a list of dicts""" + solutions_set = set() + solutions_new = [] + + for sol in solutions: + solset = frozenset(sol.items()) + if solset not in solutions_set: + solutions_new.append(sol) + solutions_set.add(solset) + + return solutions_new + + +def _solve_system(exprs, symbols, **flags): + """return ``(linear, solution)`` where ``linear`` is True + if the system was linear, else False; ``solution`` + is a list of dictionaries giving solutions for the symbols + """ + if not exprs: + return False, [] + + if flags.pop('_split', True): + # Split the system into connected components + V = exprs + symsset = set(symbols) + exprsyms = {e: e.free_symbols & symsset for e in exprs} + E = [] + sym_indices = {sym: i for i, sym in enumerate(symbols)} + for n, e1 in enumerate(exprs): + for e2 in exprs[:n]: + # Equations are connected if they share a symbol + if exprsyms[e1] & exprsyms[e2]: + E.append((e1, e2)) + G = V, E + subexprs = connected_components(G) + if len(subexprs) > 1: + subsols = [] + linear = True + for subexpr in subexprs: + subsyms = set() + for e in subexpr: + subsyms |= exprsyms[e] + subsyms = sorted(subsyms, key = lambda x: sym_indices[x]) + flags['_split'] = False # skip split step + _linear, subsol = _solve_system(subexpr, subsyms, **flags) + if linear: + linear = linear and _linear + if not isinstance(subsol, list): + subsol = [subsol] + subsols.append(subsol) + # Full solution is cartesian product of subsystems + sols = [] + for soldicts in product(*subsols): + sols.append(dict(item for sd in soldicts + for item in sd.items())) + return linear, sols + + polys = [] + dens = set() + failed = [] + result = [] + solved_syms = [] + linear = True + manual = flags.get('manual', False) + checkdens = check = flags.get('check', True) + + for j, g in enumerate(exprs): + dens.update(_simple_dens(g, symbols)) + i, d = _invert(g, *symbols) + if d in symbols: + if linear: + linear = solve_linear(g, 0, [d])[0] == d + g = d - i + g = g.as_numer_denom()[0] + if manual: + failed.append(g) + continue + + poly = g.as_poly(*symbols, extension=True) + + if poly is not None: + polys.append(poly) + else: + failed.append(g) + + if polys: + if all(p.is_linear for p in polys): + n, m = len(polys), len(symbols) + matrix = zeros(n, m + 1) + + for i, poly in enumerate(polys): + for monom, coeff in poly.terms(): + try: + j = monom.index(1) + matrix[i, j] = coeff + except ValueError: + matrix[i, m] = -coeff + + # returns a dictionary ({symbols: values}) or None + if flags.pop('particular', False): + result = minsolve_linear_system(matrix, *symbols, **flags) + else: + result = solve_linear_system(matrix, *symbols, **flags) + result = [result] if result else [] + if failed: + if result: + solved_syms = list(result[0].keys()) # there is only one result dict + else: + solved_syms = [] + # linear doesn't change + else: + linear = False + if len(symbols) > len(polys): + + free = set().union(*[p.free_symbols for p in polys]) + free = list(ordered(free.intersection(symbols))) + got_s = set() + result = [] + for syms in subsets(free, min(len(free), len(polys))): + try: + # returns [], None or list of tuples + res = solve_poly_system(polys, *syms) + if res: + for r in set(res): + skip = False + for r1 in r: + if got_s and any(ss in r1.free_symbols + for ss in got_s): + # sol depends on previously + # solved symbols: discard it + skip = True + if not skip: + got_s.update(syms) + result.append(dict(list(zip(syms, r)))) + except NotImplementedError: + pass + if got_s: + solved_syms = list(got_s) + else: + failed.extend([g.as_expr() for g in polys]) + else: + try: + result = solve_poly_system(polys, *symbols) + if result: + solved_syms = symbols + result = [dict(list(zip(solved_syms, r))) for r in set(result)] + except NotImplementedError: + failed.extend([g.as_expr() for g in polys]) + solved_syms = [] + + # convert None or [] to [{}] + result = result or [{}] + + if failed: + linear = False + # For each failed equation, see if we can solve for one of the + # remaining symbols from that equation. If so, we update the + # solution set and continue with the next failed equation, + # repeating until we are done or we get an equation that can't + # be solved. + def _ok_syms(e, sort=False): + rv = e.free_symbols & legal + + # Solve first for symbols that have lower degree in the equation. + # Ideally we want to solve firstly for symbols that appear linearly + # with rational coefficients e.g. if e = x*y + z then we should + # solve for z first. + def key(sym): + ep = e.as_poly(sym) + if ep is None: + complexity = (S.Infinity, S.Infinity, S.Infinity) + else: + coeff_syms = ep.LC().free_symbols + complexity = (ep.degree(), len(coeff_syms & rv), len(coeff_syms)) + return complexity + (default_sort_key(sym),) + + if sort: + rv = sorted(rv, key=key) + return rv + + legal = set(symbols) # what we are interested in + # sort so equation with the fewest potential symbols is first + u = Dummy() # used in solution checking + for eq in ordered(failed, lambda _: len(_ok_syms(_))): + newresult = [] + bad_results = [] + hit = False + for r in result: + got_s = set() + # update eq with everything that is known so far + eq2 = eq.subs(r) + # if check is True then we see if it satisfies this + # equation, otherwise we just accept it + if check and r: + b = checksol(u, u, eq2, minimal=True) + if b is not None: + # this solution is sufficient to know whether + # it is valid or not so we either accept or + # reject it, then continue + if b: + newresult.append(r) + else: + bad_results.append(r) + continue + # search for a symbol amongst those available that + # can be solved for + ok_syms = _ok_syms(eq2, sort=True) + if not ok_syms: + if r: + newresult.append(r) + break # skip as it's independent of desired symbols + for s in ok_syms: + try: + soln = _vsolve(eq2, s, **flags) + except NotImplementedError: + continue + # put each solution in r and append the now-expanded + # result in the new result list; use copy since the + # solution for s is being added in-place + for sol in soln: + if got_s and any(ss in sol.free_symbols for ss in got_s): + # sol depends on previously solved symbols: discard it + continue + rnew = r.copy() + for k, v in r.items(): + rnew[k] = v.subs(s, sol) + # and add this new solution + rnew[s] = sol + # check that it is independent of previous solutions + iset = set(rnew.items()) + for i in newresult: + if len(i) < len(iset): + # update i with what is known + i_items_updated = {(k, v.xreplace(rnew)) for k, v in i.items()} + if not i_items_updated - iset: + # this is a superset of a known solution that + # is smaller + break + else: + # keep it + newresult.append(rnew) + hit = True + got_s.add(s) + if not hit: + raise NotImplementedError('could not solve %s' % eq2) + else: + result = newresult + for b in bad_results: + if b in result: + result.remove(b) + + if not result: + return False, [] + + # rely on linear/polynomial system solvers to simplify + # XXX the following tests show that the expressions + # returned are not the same as they would be if simplify + # were applied to this: + # sympy/solvers/ode/tests/test_systems/test__classify_linear_system + # sympy/solvers/tests/test_solvers/test_issue_4886 + # so the docs should be updated to reflect that or else + # the following should be `bool(failed) or not linear` + default_simplify = bool(failed) + if flags.get('simplify', default_simplify): + for r in result: + for k in r: + r[k] = simplify(r[k]) + flags['simplify'] = False # don't need to do so in checksol now + + if checkdens: + result = [r for r in result + if not any(checksol(d, r, **flags) for d in dens)] + + if check and not linear: + result = [r for r in result + if not any(checksol(e, r, **flags) is False for e in exprs)] + + result = [r for r in result if r] + return linear, result + + +def solve_linear(lhs, rhs=0, symbols=[], exclude=[]): + r""" + Return a tuple derived from ``f = lhs - rhs`` that is one of + the following: ``(0, 1)``, ``(0, 0)``, ``(symbol, solution)``, ``(n, d)``. + + Explanation + =========== + + ``(0, 1)`` meaning that ``f`` is independent of the symbols in *symbols* + that are not in *exclude*. + + ``(0, 0)`` meaning that there is no solution to the equation amongst the + symbols given. If the first element of the tuple is not zero, then the + function is guaranteed to be dependent on a symbol in *symbols*. + + ``(symbol, solution)`` where symbol appears linearly in the numerator of + ``f``, is in *symbols* (if given), and is not in *exclude* (if given). No + simplification is done to ``f`` other than a ``mul=True`` expansion, so the + solution will correspond strictly to a unique solution. + + ``(n, d)`` where ``n`` and ``d`` are the numerator and denominator of ``f`` + when the numerator was not linear in any symbol of interest; ``n`` will + never be a symbol unless a solution for that symbol was found (in which case + the second element is the solution, not the denominator). + + Examples + ======== + + >>> from sympy import cancel, Pow + + ``f`` is independent of the symbols in *symbols* that are not in + *exclude*: + + >>> from sympy import cos, sin, solve_linear + >>> from sympy.abc import x, y, z + >>> eq = y*cos(x)**2 + y*sin(x)**2 - y # = y*(1 - 1) = 0 + >>> solve_linear(eq) + (0, 1) + >>> eq = cos(x)**2 + sin(x)**2 # = 1 + >>> solve_linear(eq) + (0, 1) + >>> solve_linear(x, exclude=[x]) + (0, 1) + + The variable ``x`` appears as a linear variable in each of the + following: + + >>> solve_linear(x + y**2) + (x, -y**2) + >>> solve_linear(1/x - y**2) + (x, y**(-2)) + + When not linear in ``x`` or ``y`` then the numerator and denominator are + returned: + + >>> solve_linear(x**2/y**2 - 3) + (x**2 - 3*y**2, y**2) + + If the numerator of the expression is a symbol, then ``(0, 0)`` is + returned if the solution for that symbol would have set any + denominator to 0: + + >>> eq = 1/(1/x - 2) + >>> eq.as_numer_denom() + (x, 1 - 2*x) + >>> solve_linear(eq) + (0, 0) + + But automatic rewriting may cause a symbol in the denominator to + appear in the numerator so a solution will be returned: + + >>> (1/x)**-1 + x + >>> solve_linear((1/x)**-1) + (x, 0) + + Use an unevaluated expression to avoid this: + + >>> solve_linear(Pow(1/x, -1, evaluate=False)) + (0, 0) + + If ``x`` is allowed to cancel in the following expression, then it + appears to be linear in ``x``, but this sort of cancellation is not + done by ``solve_linear`` so the solution will always satisfy the + original expression without causing a division by zero error. + + >>> eq = x**2*(1/x - z**2/x) + >>> solve_linear(cancel(eq)) + (x, 0) + >>> solve_linear(eq) + (x**2*(1 - z**2), x) + + A list of symbols for which a solution is desired may be given: + + >>> solve_linear(x + y + z, symbols=[y]) + (y, -x - z) + + A list of symbols to ignore may also be given: + + >>> solve_linear(x + y + z, exclude=[x]) + (y, -x - z) + + (A solution for ``y`` is obtained because it is the first variable + from the canonically sorted list of symbols that had a linear + solution.) + + """ + if isinstance(lhs, Eq): + if rhs: + raise ValueError(filldedent(''' + If lhs is an Equality, rhs must be 0 but was %s''' % rhs)) + rhs = lhs.rhs + lhs = lhs.lhs + dens = None + eq = lhs - rhs + n, d = eq.as_numer_denom() + if not n: + return S.Zero, S.One + + free = n.free_symbols + if not symbols: + symbols = free + else: + bad = [s for s in symbols if not s.is_Symbol] + if bad: + if len(bad) == 1: + bad = bad[0] + if len(symbols) == 1: + eg = 'solve(%s, %s)' % (eq, symbols[0]) + else: + eg = 'solve(%s, *%s)' % (eq, list(symbols)) + raise ValueError(filldedent(''' + solve_linear only handles symbols, not %s. To isolate + non-symbols use solve, e.g. >>> %s <<<. + ''' % (bad, eg))) + symbols = free.intersection(symbols) + symbols = symbols.difference(exclude) + if not symbols: + return S.Zero, S.One + + # derivatives are easy to do but tricky to analyze to see if they + # are going to disallow a linear solution, so for simplicity we + # just evaluate the ones that have the symbols of interest + derivs = defaultdict(list) + for der in n.atoms(Derivative): + csym = der.free_symbols & symbols + for c in csym: + derivs[c].append(der) + + all_zero = True + for xi in sorted(symbols, key=default_sort_key): # canonical order + # if there are derivatives in this var, calculate them now + if isinstance(derivs[xi], list): + derivs[xi] = {der: der.doit() for der in derivs[xi]} + newn = n.subs(derivs[xi]) + dnewn_dxi = newn.diff(xi) + # dnewn_dxi can be nonzero if it survives differentation by any + # of its free symbols + free = dnewn_dxi.free_symbols + if dnewn_dxi and (not free or any(dnewn_dxi.diff(s) for s in free) or free == symbols): + all_zero = False + if dnewn_dxi is S.NaN: + break + if xi not in dnewn_dxi.free_symbols: + vi = -1/dnewn_dxi*(newn.subs(xi, 0)) + if dens is None: + dens = _simple_dens(eq, symbols) + if not any(checksol(di, {xi: vi}, minimal=True) is True + for di in dens): + # simplify any trivial integral + irep = [(i, i.doit()) for i in vi.atoms(Integral) if + i.function.is_number] + # do a slight bit of simplification + vi = expand_mul(vi.subs(irep)) + return xi, vi + if all_zero: + return S.Zero, S.One + if n.is_Symbol: # no solution for this symbol was found + return S.Zero, S.Zero + return n, d + + +def minsolve_linear_system(system, *symbols, **flags): + r""" + Find a particular solution to a linear system. + + Explanation + =========== + + In particular, try to find a solution with the minimal possible number + of non-zero variables using a naive algorithm with exponential complexity. + If ``quick=True``, a heuristic is used. + + """ + quick = flags.get('quick', False) + # Check if there are any non-zero solutions at all + s0 = solve_linear_system(system, *symbols, **flags) + if not s0 or all(v == 0 for v in s0.values()): + return s0 + if quick: + # We just solve the system and try to heuristically find a nice + # solution. + s = solve_linear_system(system, *symbols) + def update(determined, solution): + delete = [] + for k, v in solution.items(): + solution[k] = v.subs(determined) + if not solution[k].free_symbols: + delete.append(k) + determined[k] = solution[k] + for k in delete: + del solution[k] + determined = {} + update(determined, s) + while s: + # NOTE sort by default_sort_key to get deterministic result + k = max((k for k in s.values()), + key=lambda x: (len(x.free_symbols), default_sort_key(x))) + kfree = k.free_symbols + x = next(reversed(list(ordered(kfree)))) + if len(kfree) != 1: + determined[x] = S.Zero + else: + val = _vsolve(k, x, check=False)[0] + if not val and not any(v.subs(x, val) for v in s.values()): + determined[x] = S.One + else: + determined[x] = val + update(determined, s) + return determined + else: + # We try to select n variables which we want to be non-zero. + # All others will be assumed zero. We try to solve the modified system. + # If there is a non-trivial solution, just set the free variables to + # one. If we do this for increasing n, trying all combinations of + # variables, we will find an optimal solution. + # We speed up slightly by starting at one less than the number of + # variables the quick method manages. + N = len(symbols) + bestsol = minsolve_linear_system(system, *symbols, quick=True) + n0 = len([x for x in bestsol.values() if x != 0]) + for n in range(n0 - 1, 1, -1): + debugf('minsolve: %s', n) + thissol = None + for nonzeros in combinations(range(N), n): + subm = Matrix([system.col(i).T for i in nonzeros] + [system.col(-1).T]).T + s = solve_linear_system(subm, *[symbols[i] for i in nonzeros]) + if s and not all(v == 0 for v in s.values()): + subs = [(symbols[v], S.One) for v in nonzeros] + for k, v in s.items(): + s[k] = v.subs(subs) + for sym in symbols: + if sym not in s: + if symbols.index(sym) in nonzeros: + s[sym] = S.One + else: + s[sym] = S.Zero + thissol = s + break + if thissol is None: + break + bestsol = thissol + return bestsol + + +def solve_linear_system(system, *symbols, **flags): + r""" + Solve system of $N$ linear equations with $M$ variables, which means + both under- and overdetermined systems are supported. + + Explanation + =========== + + The possible number of solutions is zero, one, or infinite. Respectively, + this procedure will return None or a dictionary with solutions. In the + case of underdetermined systems, all arbitrary parameters are skipped. + This may cause a situation in which an empty dictionary is returned. + In that case, all symbols can be assigned arbitrary values. + + Input to this function is a $N\times M + 1$ matrix, which means it has + to be in augmented form. If you prefer to enter $N$ equations and $M$ + unknowns then use ``solve(Neqs, *Msymbols)`` instead. Note: a local + copy of the matrix is made by this routine so the matrix that is + passed will not be modified. + + The algorithm used here is fraction-free Gaussian elimination, + which results, after elimination, in an upper-triangular matrix. + Then solutions are found using back-substitution. This approach + is more efficient and compact than the Gauss-Jordan method. + + Examples + ======== + + >>> from sympy import Matrix, solve_linear_system + >>> from sympy.abc import x, y + + Solve the following system:: + + x + 4 y == 2 + -2 x + y == 14 + + >>> system = Matrix(( (1, 4, 2), (-2, 1, 14))) + >>> solve_linear_system(system, x, y) + {x: -6, y: 2} + + A degenerate system returns an empty dictionary: + + >>> system = Matrix(( (0,0,0), (0,0,0) )) + >>> solve_linear_system(system, x, y) + {} + + """ + assert system.shape[1] == len(symbols) + 1 + + # This is just a wrapper for solve_lin_sys + eqs = list(system * Matrix(symbols + (-1,))) + eqs, ring = sympy_eqs_to_ring(eqs, symbols) + sol = solve_lin_sys(eqs, ring, _raw=False) + if sol is not None: + sol = {sym:val for sym, val in sol.items() if sym != val} + return sol + + +def solve_undetermined_coeffs(equ, coeffs, *syms, **flags): + r""" + Solve a system of equations in $k$ parameters that is formed by + matching coefficients in variables ``coeffs`` that are on + factors dependent on the remaining variables (or those given + explicitly by ``syms``. + + Explanation + =========== + + The result of this function is a dictionary with symbolic values of those + parameters with respect to coefficients in $q$ -- empty if there + is no solution or coefficients do not appear in the equation -- else + None (if the system was not recognized). If there is more than one + solution, the solutions are passed as a list. The output can be modified using + the same semantics as for `solve` since the flags that are passed are sent + directly to `solve` so, for example the flag ``dict=True`` will always return a list + of solutions as dictionaries. + + This function accepts both Equality and Expr class instances. + The solving process is most efficient when symbols are specified + in addition to parameters to be determined, but an attempt to + determine them (if absent) will be made. If an expected solution is not + obtained (and symbols were not specified) try specifying them. + + Examples + ======== + + >>> from sympy import Eq, solve_undetermined_coeffs + >>> from sympy.abc import a, b, c, h, p, k, x, y + + >>> solve_undetermined_coeffs(Eq(a*x + a + b, x/2), [a, b], x) + {a: 1/2, b: -1/2} + >>> solve_undetermined_coeffs(a - 2, [a]) + {a: 2} + + The equation can be nonlinear in the symbols: + + >>> X, Y, Z = y, x**y, y*x**y + >>> eq = a*X + b*Y + c*Z - X - 2*Y - 3*Z + >>> coeffs = a, b, c + >>> syms = x, y + >>> solve_undetermined_coeffs(eq, coeffs, syms) + {a: 1, b: 2, c: 3} + + And the system can be nonlinear in coefficients, too, but if + there is only a single solution, it will be returned as a + dictionary: + + >>> eq = a*x**2 + b*x + c - ((x - h)**2 + 4*p*k)/4/p + >>> solve_undetermined_coeffs(eq, (h, p, k), x) + {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + + Multiple solutions are always returned in a list: + + >>> solve_undetermined_coeffs(a**2*x + b - x, [a, b], x) + [{a: -1, b: 0}, {a: 1, b: 0}] + + Using flag ``dict=True`` (in keeping with semantics in :func:`~.solve`) + will force the result to always be a list with any solutions + as elements in that list. + + >>> solve_undetermined_coeffs(a*x - 2*x, [a], dict=True) + [{a: 2}] + """ + if not (coeffs and all(i.is_Symbol for i in coeffs)): + raise ValueError('must provide symbols for coeffs') + + if isinstance(equ, Eq): + eq = equ.lhs - equ.rhs + else: + eq = equ + + ceq = cancel(eq) + xeq = _mexpand(ceq.as_numer_denom()[0], recursive=True) + + free = xeq.free_symbols + coeffs = free & set(coeffs) + if not coeffs: + return ([], {}) if flags.get('set', None) else [] # solve(0, x) -> [] + + if not syms: + # e.g. A*exp(x) + B - (exp(x) + y) separated into parts that + # don't/do depend on coeffs gives + # -(exp(x) + y), A*exp(x) + B + # then see what symbols are common to both + # {x} = {x, A, B} - {x, y} + ind, dep = xeq.as_independent(*coeffs, as_Add=True) + dfree = dep.free_symbols + syms = dfree & ind.free_symbols + if not syms: + # but if the system looks like (a + b)*x + b - c + # then {} = {a, b, x} - c + # so calculate {x} = {a, b, x} - {a, b} + syms = dfree - set(coeffs) + if not syms: + syms = [Dummy()] + else: + if len(syms) == 1 and iterable(syms[0]): + syms = syms[0] + e, s, _ = recast_to_symbols([xeq], syms) + xeq = e[0] + syms = s + + # find the functional forms in which symbols appear + + gens = set(xeq.as_coefficients_dict(*syms).keys()) - {1} + cset = set(coeffs) + if any(g.has_xfree(cset) for g in gens): + return # a generator contained a coefficient symbol + + # make sure we are working with symbols for generators + + e, gens, _ = recast_to_symbols([xeq], list(gens)) + xeq = e[0] + + # collect coefficients in front of generators + + system = list(collect(xeq, gens, evaluate=False).values()) + + # get a solution + + soln = solve(system, coeffs, **flags) + + # unpack unless told otherwise if length is 1 + + settings = flags.get('dict', None) or flags.get('set', None) + if type(soln) is dict or settings or len(soln) != 1: + return soln + return soln[0] + + +def solve_linear_system_LU(matrix, syms): + """ + Solves the augmented matrix system using ``LUsolve`` and returns a + dictionary in which solutions are keyed to the symbols of *syms* as ordered. + + Explanation + =========== + + The matrix must be invertible. + + Examples + ======== + + >>> from sympy import Matrix, solve_linear_system_LU + >>> from sympy.abc import x, y, z + + >>> solve_linear_system_LU(Matrix([ + ... [1, 2, 0, 1], + ... [3, 2, 2, 1], + ... [2, 0, 0, 1]]), [x, y, z]) + {x: 1/2, y: 1/4, z: -1/2} + + See Also + ======== + + LUsolve + + """ + if matrix.rows != matrix.cols - 1: + raise ValueError("Rows should be equal to columns - 1") + A = matrix[:matrix.rows, :matrix.rows] + b = matrix[:, matrix.cols - 1:] + soln = A.LUsolve(b) + solutions = {} + for i in range(soln.rows): + solutions[syms[i]] = soln[i, 0] + return solutions + + +def det_perm(M): + """ + Return the determinant of *M* by using permutations to select factors. + + Explanation + =========== + + For sizes larger than 8 the number of permutations becomes prohibitively + large, or if there are no symbols in the matrix, it is better to use the + standard determinant routines (e.g., ``M.det()``.) + + See Also + ======== + + det_minor + det_quick + + """ + args = [] + s = True + n = M.rows + list_ = M.flat() + for perm in generate_bell(n): + fac = [] + idx = 0 + for j in perm: + fac.append(list_[idx + j]) + idx += n + term = Mul(*fac) # disaster with unevaluated Mul -- takes forever for n=7 + args.append(term if s else -term) + s = not s + return Add(*args) + + +def det_minor(M): + """ + Return the ``det(M)`` computed from minors without + introducing new nesting in products. + + See Also + ======== + + det_perm + det_quick + + """ + n = M.rows + if n == 2: + return M[0, 0]*M[1, 1] - M[1, 0]*M[0, 1] + else: + return sum((1, -1)[i % 2]*Add(*[M[0, i]*d for d in + Add.make_args(det_minor(M.minor_submatrix(0, i)))]) + if M[0, i] else S.Zero for i in range(n)) + + +def det_quick(M, method=None): + """ + Return ``det(M)`` assuming that either + there are lots of zeros or the size of the matrix + is small. If this assumption is not met, then the normal + Matrix.det function will be used with method = ``method``. + + See Also + ======== + + det_minor + det_perm + + """ + if any(i.has(Symbol) for i in M): + if M.rows < 8 and all(i.has(Symbol) for i in M): + return det_perm(M) + return det_minor(M) + else: + return M.det(method=method) if method else M.det() + + +def inv_quick(M): + """Return the inverse of ``M``, assuming that either + there are lots of zeros or the size of the matrix + is small. + """ + if not all(i.is_Number for i in M): + if not any(i.is_Number for i in M): + det = lambda _: det_perm(_) + else: + det = lambda _: det_minor(_) + else: + return M.inv() + n = M.rows + d = det(M) + if d == S.Zero: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible") + ret = zeros(n) + s1 = -1 + for i in range(n): + s = s1 = -s1 + for j in range(n): + di = det(M.minor_submatrix(i, j)) + ret[j, i] = s*di/d + s = -s + return ret + + +# these are functions that have multiple inverse values per period +multi_inverses = { + sin: lambda x: (asin(x), S.Pi - asin(x)), + cos: lambda x: (acos(x), 2*S.Pi - acos(x)), +} + + +def _vsolve(e, s, **flags): + """return list of scalar values for the solution of e for symbol s""" + return [i[s] for i in _solve(e, s, **flags)] + + +def _tsolve(eq, sym, **flags): + """ + Helper for ``_solve`` that solves a transcendental equation with respect + to the given symbol. Various equations containing powers and logarithms, + can be solved. + + There is currently no guarantee that all solutions will be returned or + that a real solution will be favored over a complex one. + + Either a list of potential solutions will be returned or None will be + returned (in the case that no method was known to get a solution + for the equation). All other errors (like the inability to cast an + expression as a Poly) are unhandled. + + Examples + ======== + + >>> from sympy import log, ordered + >>> from sympy.solvers.solvers import _tsolve as tsolve + >>> from sympy.abc import x + + >>> list(ordered(tsolve(3**(2*x + 5) - 4, x))) + [-5/2 + log(2)/log(3), (-5*log(3)/2 + log(2) + I*pi)/log(3)] + + >>> tsolve(log(x) + 2*x, x) + [LambertW(2)/2] + + """ + if 'tsolve_saw' not in flags: + flags['tsolve_saw'] = [] + if eq in flags['tsolve_saw']: + return None + else: + flags['tsolve_saw'].append(eq) + + rhs, lhs = _invert(eq, sym) + + if lhs == sym: + return [rhs] + try: + if lhs.is_Add: + # it's time to try factoring; powdenest is used + # to try get powers in standard form for better factoring + f = factor(powdenest(lhs - rhs)) + if f.is_Mul: + return _vsolve(f, sym, **flags) + if rhs: + f = logcombine(lhs, force=flags.get('force', True)) + if f.count(log) != lhs.count(log): + if isinstance(f, log): + return _vsolve(f.args[0] - exp(rhs), sym, **flags) + return _tsolve(f - rhs, sym, **flags) + + elif lhs.is_Pow: + if lhs.exp.is_Integer: + if lhs - rhs != eq: + return _vsolve(lhs - rhs, sym, **flags) + + if sym not in lhs.exp.free_symbols: + return _vsolve(lhs.base - rhs**(1/lhs.exp), sym, **flags) + + # _tsolve calls this with Dummy before passing the actual number in. + if any(t.is_Dummy for t in rhs.free_symbols): + raise NotImplementedError # _tsolve will call here again... + + # a ** g(x) == 0 + if not rhs: + # f(x)**g(x) only has solutions where f(x) == 0 and g(x) != 0 at + # the same place + sol_base = _vsolve(lhs.base, sym, **flags) + return [s for s in sol_base if lhs.exp.subs(sym, s) != 0] # XXX use checksol here? + + # a ** g(x) == b + if not lhs.base.has(sym): + if lhs.base == 0: + return _vsolve(lhs.exp, sym, **flags) if rhs != 0 else [] + + # Gets most solutions... + if lhs.base == rhs.as_base_exp()[0]: + # handles case when bases are equal + sol = _vsolve(lhs.exp - rhs.as_base_exp()[1], sym, **flags) + else: + # handles cases when bases are not equal and exp + # may or may not be equal + f = exp(log(lhs.base)*lhs.exp) - exp(log(rhs)) + sol = _vsolve(f, sym, **flags) + + # Check for duplicate solutions + def equal(expr1, expr2): + _ = Dummy() + eq = checksol(expr1 - _, _, expr2) + if eq is None: + if nsimplify(expr1) != nsimplify(expr2): + return False + # they might be coincidentally the same + # so check more rigorously + eq = expr1.equals(expr2) # XXX expensive but necessary? + return eq + + # Guess a rational exponent + e_rat = nsimplify(log(abs(rhs))/log(abs(lhs.base))) + e_rat = simplify(posify(e_rat)[0]) + n, d = fraction(e_rat) + if expand(lhs.base**n - rhs**d) == 0: + sol = [s for s in sol if not equal(lhs.exp.subs(sym, s), e_rat)] + sol.extend(_vsolve(lhs.exp - e_rat, sym, **flags)) + + return list(set(sol)) + + # f(x) ** g(x) == c + else: + sol = [] + logform = lhs.exp*log(lhs.base) - log(rhs) + if logform != lhs - rhs: + try: + sol.extend(_vsolve(logform, sym, **flags)) + except NotImplementedError: + pass + + # Collect possible solutions and check with substitution later. + check = [] + if rhs == 1: + # f(x) ** g(x) = 1 -- g(x)=0 or f(x)=+-1 + check.extend(_vsolve(lhs.exp, sym, **flags)) + check.extend(_vsolve(lhs.base - 1, sym, **flags)) + check.extend(_vsolve(lhs.base + 1, sym, **flags)) + elif rhs.is_Rational: + for d in (i for i in divisors(abs(rhs.p)) if i != 1): + e, t = integer_log(rhs.p, d) + if not t: + continue # rhs.p != d**b + for s in divisors(abs(rhs.q)): + if s**e== rhs.q: + r = Rational(d, s) + check.extend(_vsolve(lhs.base - r, sym, **flags)) + check.extend(_vsolve(lhs.base + r, sym, **flags)) + check.extend(_vsolve(lhs.exp - e, sym, **flags)) + elif rhs.is_irrational: + b_l, e_l = lhs.base.as_base_exp() + n, d = (e_l*lhs.exp).as_numer_denom() + b, e = sqrtdenest(rhs).as_base_exp() + check = [sqrtdenest(i) for i in (_vsolve(lhs.base - b, sym, **flags))] + check.extend([sqrtdenest(i) for i in (_vsolve(lhs.exp - e, sym, **flags))]) + if e_l*d != 1: + check.extend(_vsolve(b_l**n - rhs**(e_l*d), sym, **flags)) + for s in check: + ok = checksol(eq, sym, s) + if ok is None: + ok = eq.subs(sym, s).equals(0) + if ok: + sol.append(s) + return list(set(sol)) + + elif lhs.is_Function and len(lhs.args) == 1: + if lhs.func in multi_inverses: + # sin(x) = 1/3 -> x - asin(1/3) & x - (pi - asin(1/3)) + soln = [] + for i in multi_inverses[type(lhs)](rhs): + soln.extend(_vsolve(lhs.args[0] - i, sym, **flags)) + return list(set(soln)) + elif lhs.func == LambertW: + return _vsolve(lhs.args[0] - rhs*exp(rhs), sym, **flags) + + rewrite = lhs.rewrite(exp) + rewrite = rebuild(rewrite) # avoid rewrites involving evaluate=False + if rewrite != lhs: + return _vsolve(rewrite - rhs, sym, **flags) + except NotImplementedError: + pass + + # maybe it is a lambert pattern + if flags.pop('bivariate', True): + # lambert forms may need some help being recognized, e.g. changing + # 2**(3*x) + x**3*log(2)**3 + 3*x**2*log(2)**2 + 3*x*log(2) + 1 + # to 2**(3*x) + (x*log(2) + 1)**3 + + # make generator in log have exponent of 1 + logs = eq.atoms(log) + spow = min( + {i.exp for j in logs for i in j.atoms(Pow) + if i.base == sym} or {1}) + if spow != 1: + p = sym**spow + u = Dummy('bivariate-cov') + ueq = eq.subs(p, u) + if not ueq.has_free(sym): + sol = _vsolve(ueq, u, **flags) + inv = _vsolve(p - u, sym) + return [i.subs(u, s) for i in inv for s in sol] + + g = _filtered_gens(eq.as_poly(), sym) + up_or_log = set() + for gi in g: + if isinstance(gi, (exp, log)) or (gi.is_Pow and gi.base == S.Exp1): + up_or_log.add(gi) + elif gi.is_Pow: + gisimp = powdenest(expand_power_exp(gi)) + if gisimp.is_Pow and sym in gisimp.exp.free_symbols: + up_or_log.add(gi) + eq_down = expand_log(expand_power_exp(eq)).subs( + dict(list(zip(up_or_log, [0]*len(up_or_log))))) + eq = expand_power_exp(factor(eq_down, deep=True) + (eq - eq_down)) + rhs, lhs = _invert(eq, sym) + if lhs.has(sym): + try: + poly = lhs.as_poly() + g = _filtered_gens(poly, sym) + _eq = lhs - rhs + sols = _solve_lambert(_eq, sym, g) + # use a simplified form if it satisfies eq + # and has fewer operations + for n, s in enumerate(sols): + ns = nsimplify(s) + if ns != s and ns.count_ops() <= s.count_ops(): + ok = checksol(_eq, sym, ns) + if ok is None: + ok = _eq.subs(sym, ns).equals(0) + if ok: + sols[n] = ns + return sols + except NotImplementedError: + # maybe it's a convoluted function + if len(g) == 2: + try: + gpu = bivariate_type(lhs - rhs, *g) + if gpu is None: + raise NotImplementedError + g, p, u = gpu + flags['bivariate'] = False + inversion = _tsolve(g - u, sym, **flags) + if inversion: + sol = _vsolve(p, u, **flags) + return list({i.subs(u, s) + for i in inversion for s in sol}) + except NotImplementedError: + pass + else: + pass + + if flags.pop('force', True): + flags['force'] = False + pos, reps = posify(lhs - rhs) + if rhs == S.ComplexInfinity: + return [] + for u, s in reps.items(): + if s == sym: + break + else: + u = sym + if pos.has(u): + try: + soln = _vsolve(pos, u, **flags) + return [s.subs(reps) for s in soln] + except NotImplementedError: + pass + else: + pass # here for coverage + + return # here for coverage + + +# TODO: option for calculating J numerically + +@conserve_mpmath_dps +def nsolve(*args, dict=False, **kwargs): + r""" + Solve a nonlinear equation system numerically: ``nsolve(f, [args,] x0, + modules=['mpmath'], **kwargs)``. + + Explanation + =========== + + ``f`` is a vector function of symbolic expressions representing the system. + *args* are the variables. If there is only one variable, this argument can + be omitted. ``x0`` is a starting vector close to a solution. + + Use the modules keyword to specify which modules should be used to + evaluate the function and the Jacobian matrix. Make sure to use a module + that supports matrices. For more information on the syntax, please see the + docstring of ``lambdify``. + + If the keyword arguments contain ``dict=True`` (default is False) ``nsolve`` + will return a list (perhaps empty) of solution mappings. This might be + especially useful if you want to use ``nsolve`` as a fallback to solve since + using the dict argument for both methods produces return values of + consistent type structure. Please note: to keep this consistent with + ``solve``, the solution will be returned in a list even though ``nsolve`` + (currently at least) only finds one solution at a time. + + Overdetermined systems are supported. + + Examples + ======== + + >>> from sympy import Symbol, nsolve + >>> import mpmath + >>> mpmath.mp.dps = 15 + >>> x1 = Symbol('x1') + >>> x2 = Symbol('x2') + >>> f1 = 3 * x1**2 - 2 * x2**2 - 1 + >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8 + >>> print(nsolve((f1, f2), (x1, x2), (-1, 1))) + Matrix([[-1.19287309935246], [1.27844411169911]]) + + For one-dimensional functions the syntax is simplified: + + >>> from sympy import sin, nsolve + >>> from sympy.abc import x + >>> nsolve(sin(x), x, 2) + 3.14159265358979 + >>> nsolve(sin(x), 2) + 3.14159265358979 + + To solve with higher precision than the default, use the prec argument: + + >>> from sympy import cos + >>> nsolve(cos(x) - x, 1) + 0.739085133215161 + >>> nsolve(cos(x) - x, 1, prec=50) + 0.73908513321516064165531208767387340401341175890076 + >>> cos(_) + 0.73908513321516064165531208767387340401341175890076 + + To solve for complex roots of real functions, a nonreal initial point + must be specified: + + >>> from sympy import I + >>> nsolve(x**2 + 2, I) + 1.4142135623731*I + + ``mpmath.findroot`` is used and you can find their more extensive + documentation, especially concerning keyword parameters and + available solvers. Note, however, that functions which are very + steep near the root, the verification of the solution may fail. In + this case you should use the flag ``verify=False`` and + independently verify the solution. + + >>> from sympy import cos, cosh + >>> f = cos(x)*cosh(x) - 1 + >>> nsolve(f, 3.14*100) + Traceback (most recent call last): + ... + ValueError: Could not find root within given tolerance. (1.39267e+230 > 2.1684e-19) + >>> ans = nsolve(f, 3.14*100, verify=False); ans + 312.588469032184 + >>> f.subs(x, ans).n(2) + 2.1e+121 + >>> (f/f.diff(x)).subs(x, ans).n(2) + 7.4e-15 + + One might safely skip the verification if bounds of the root are known + and a bisection method is used: + + >>> bounds = lambda i: (3.14*i, 3.14*(i + 1)) + >>> nsolve(f, bounds(100), solver='bisect', verify=False) + 315.730061685774 + + Alternatively, a function may be better behaved when the + denominator is ignored. Since this is not always the case, however, + the decision of what function to use is left to the discretion of + the user. + + >>> eq = x**2/(1 - x)/(1 - 2*x)**2 - 100 + >>> nsolve(eq, 0.46) + Traceback (most recent call last): + ... + ValueError: Could not find root within given tolerance. (10000 > 2.1684e-19) + Try another starting point or tweak arguments. + >>> nsolve(eq.as_numer_denom()[0], 0.46) + 0.46792545969349058 + + """ + # there are several other SymPy functions that use method= so + # guard against that here + if 'method' in kwargs: + raise ValueError(filldedent(''' + Keyword "method" should not be used in this context. When using + some mpmath solvers directly, the keyword "method" is + used, but when using nsolve (and findroot) the keyword to use is + "solver".''')) + + if 'prec' in kwargs: + import mpmath + mpmath.mp.dps = kwargs.pop('prec') + + # keyword argument to return result as a dictionary + as_dict = dict + from builtins import dict # to unhide the builtin + + # interpret arguments + if len(args) == 3: + f = args[0] + fargs = args[1] + x0 = args[2] + if iterable(fargs) and iterable(x0): + if len(x0) != len(fargs): + raise TypeError('nsolve expected exactly %i guess vectors, got %i' + % (len(fargs), len(x0))) + elif len(args) == 2: + f = args[0] + fargs = None + x0 = args[1] + if iterable(f): + raise TypeError('nsolve expected 3 arguments, got 2') + elif len(args) < 2: + raise TypeError('nsolve expected at least 2 arguments, got %i' + % len(args)) + else: + raise TypeError('nsolve expected at most 3 arguments, got %i' + % len(args)) + modules = kwargs.get('modules', ['mpmath']) + if iterable(f): + f = list(f) + for i, fi in enumerate(f): + if isinstance(fi, Eq): + f[i] = fi.lhs - fi.rhs + f = Matrix(f).T + if iterable(x0): + x0 = list(x0) + if not isinstance(f, Matrix): + # assume it's a SymPy expression + if isinstance(f, Eq): + f = f.lhs - f.rhs + elif f.is_Relational: + raise TypeError('nsolve cannot accept inequalities') + syms = f.free_symbols + if fargs is None: + fargs = syms.copy().pop() + if not (len(syms) == 1 and (fargs in syms or fargs[0] in syms)): + raise ValueError(filldedent(''' + expected a one-dimensional and numerical function''')) + + # the function is much better behaved if there is no denominator + # but sending the numerator is left to the user since sometimes + # the function is better behaved when the denominator is present + # e.g., issue 11768 + + f = lambdify(fargs, f, modules) + x = sympify(findroot(f, x0, **kwargs)) + if as_dict: + return [{fargs: x}] + return x + + if len(fargs) > f.cols: + raise NotImplementedError(filldedent(''' + need at least as many equations as variables''')) + verbose = kwargs.get('verbose', False) + if verbose: + print('f(x):') + print(f) + # derive Jacobian + J = f.jacobian(fargs) + if verbose: + print('J(x):') + print(J) + # create functions + f = lambdify(fargs, f.T, modules) + J = lambdify(fargs, J, modules) + # solve the system numerically + x = findroot(f, x0, J=J, **kwargs) + if as_dict: + return [dict(zip(fargs, [sympify(xi) for xi in x]))] + return Matrix(x) + + +def _invert(eq, *symbols, **kwargs): + """ + Return tuple (i, d) where ``i`` is independent of *symbols* and ``d`` + contains symbols. + + Explanation + =========== + + ``i`` and ``d`` are obtained after recursively using algebraic inversion + until an uninvertible ``d`` remains. If there are no free symbols then + ``d`` will be zero. Some (but not necessarily all) solutions to the + expression ``i - d`` will be related to the solutions of the original + expression. + + Examples + ======== + + >>> from sympy.solvers.solvers import _invert as invert + >>> from sympy import sqrt, cos + >>> from sympy.abc import x, y + >>> invert(x - 3) + (3, x) + >>> invert(3) + (3, 0) + >>> invert(2*cos(x) - 1) + (1/2, cos(x)) + >>> invert(sqrt(x) - 3) + (3, sqrt(x)) + >>> invert(sqrt(x) + y, x) + (-y, sqrt(x)) + >>> invert(sqrt(x) + y, y) + (-sqrt(x), y) + >>> invert(sqrt(x) + y, x, y) + (0, sqrt(x) + y) + + If there is more than one symbol in a power's base and the exponent + is not an Integer, then the principal root will be used for the + inversion: + + >>> invert(sqrt(x + y) - 2) + (4, x + y) + >>> invert(sqrt(x + y) + 2) # note +2 instead of -2 + (4, x + y) + + If the exponent is an Integer, setting ``integer_power`` to True + will force the principal root to be selected: + + >>> invert(x**2 - 4, integer_power=True) + (2, x) + + """ + eq = sympify(eq) + if eq.args: + # make sure we are working with flat eq + eq = eq.func(*eq.args) + free = eq.free_symbols + if not symbols: + symbols = free + if not free & set(symbols): + return eq, S.Zero + + dointpow = bool(kwargs.get('integer_power', False)) + + lhs = eq + rhs = S.Zero + while True: + was = lhs + while True: + indep, dep = lhs.as_independent(*symbols) + + # dep + indep == rhs + if lhs.is_Add: + # this indicates we have done it all + if indep.is_zero: + break + + lhs = dep + rhs -= indep + + # dep * indep == rhs + else: + # this indicates we have done it all + if indep is S.One: + break + + lhs = dep + rhs /= indep + + # collect like-terms in symbols + if lhs.is_Add: + terms = {} + for a in lhs.args: + i, d = a.as_independent(*symbols) + terms.setdefault(d, []).append(i) + if any(len(v) > 1 for v in terms.values()): + args = [] + for d, i in terms.items(): + if len(i) > 1: + args.append(Add(*i)*d) + else: + args.append(i[0]*d) + lhs = Add(*args) + + # if it's a two-term Add with rhs = 0 and two powers we can get the + # dependent terms together, e.g. 3*f(x) + 2*g(x) -> f(x)/g(x) = -2/3 + if lhs.is_Add and not rhs and len(lhs.args) == 2 and \ + not lhs.is_polynomial(*symbols): + a, b = ordered(lhs.args) + ai, ad = a.as_independent(*symbols) + bi, bd = b.as_independent(*symbols) + if any(_ispow(i) for i in (ad, bd)): + a_base, a_exp = ad.as_base_exp() + b_base, b_exp = bd.as_base_exp() + if a_base == b_base and a_exp.extract_additively(b_exp) is None: + # a = -b and exponents do not have canceling terms/factors + # e.g. if exponents were 3*x and x then the ratio would have + # an exponent of 2*x: one of the roots would be lost + rat = powsimp(powdenest(ad/bd)) + lhs = rat + rhs = -bi/ai + else: + rat = ad/bd + _lhs = powsimp(ad/bd) + if _lhs != rat: + lhs = _lhs + rhs = -bi/ai + elif ai == -bi: + if isinstance(ad, Function) and ad.func == bd.func: + if len(ad.args) == len(bd.args) == 1: + lhs = ad.args[0] - bd.args[0] + elif len(ad.args) == len(bd.args): + # should be able to solve + # f(x, y) - f(2 - x, 0) == 0 -> x == 1 + raise NotImplementedError( + 'equal function with more than 1 argument') + else: + raise ValueError( + 'function with different numbers of args') + + elif lhs.is_Mul and any(_ispow(a) for a in lhs.args): + lhs = powsimp(powdenest(lhs)) + + if lhs.is_Function: + if hasattr(lhs, 'inverse') and lhs.inverse() is not None and len(lhs.args) == 1: + # -1 + # f(x) = g -> x = f (g) + # + # /!\ inverse should not be defined if there are multiple values + # for the function -- these are handled in _tsolve + # + rhs = lhs.inverse()(rhs) + lhs = lhs.args[0] + elif isinstance(lhs, atan2): + y, x = lhs.args + lhs = 2*atan(y/(sqrt(x**2 + y**2) + x)) + elif lhs.func == rhs.func: + if len(lhs.args) == len(rhs.args) == 1: + lhs = lhs.args[0] + rhs = rhs.args[0] + elif len(lhs.args) == len(rhs.args): + # should be able to solve + # f(x, y) == f(2, 3) -> x == 2 + # f(x, x + y) == f(2, 3) -> x == 2 + raise NotImplementedError( + 'equal function with more than 1 argument') + else: + raise ValueError( + 'function with different numbers of args') + + + if rhs and lhs.is_Pow and lhs.exp.is_Integer and lhs.exp < 0: + lhs = 1/lhs + rhs = 1/rhs + + # base**a = b -> base = b**(1/a) if + # a is an Integer and dointpow=True (this gives real branch of root) + # a is not an Integer and the equation is multivariate and the + # base has more than 1 symbol in it + # The rationale for this is that right now the multi-system solvers + # doesn't try to resolve generators to see, for example, if the whole + # system is written in terms of sqrt(x + y) so it will just fail, so we + # do that step here. + if lhs.is_Pow and ( + lhs.exp.is_Integer and dointpow or not lhs.exp.is_Integer and + len(symbols) > 1 and len(lhs.base.free_symbols & set(symbols)) > 1): + rhs = rhs**(1/lhs.exp) + lhs = lhs.base + + if lhs == was: + break + return rhs, lhs + + +def unrad(eq, *syms, **flags): + """ + Remove radicals with symbolic arguments and return (eq, cov), + None, or raise an error. + + Explanation + =========== + + None is returned if there are no radicals to remove. + + NotImplementedError is raised if there are radicals and they cannot be + removed or if the relationship between the original symbols and the + change of variable needed to rewrite the system as a polynomial cannot + be solved. + + Otherwise the tuple, ``(eq, cov)``, is returned where: + + *eq*, ``cov`` + *eq* is an equation without radicals (in the symbol(s) of + interest) whose solutions are a superset of the solutions to the + original expression. *eq* might be rewritten in terms of a new + variable; the relationship to the original variables is given by + ``cov`` which is a list containing ``v`` and ``v**p - b`` where + ``p`` is the power needed to clear the radical and ``b`` is the + radical now expressed as a polynomial in the symbols of interest. + For example, for sqrt(2 - x) the tuple would be + ``(c, c**2 - 2 + x)``. The solutions of *eq* will contain + solutions to the original equation (if there are any). + + *syms* + An iterable of symbols which, if provided, will limit the focus of + radical removal: only radicals with one or more of the symbols of + interest will be cleared. All free symbols are used if *syms* is not + set. + + *flags* are used internally for communication during recursive calls. + Two options are also recognized: + + ``take``, when defined, is interpreted as a single-argument function + that returns True if a given Pow should be handled. + + Radicals can be removed from an expression if: + + * All bases of the radicals are the same; a change of variables is + done in this case. + * If all radicals appear in one term of the expression. + * There are only four terms with sqrt() factors or there are less than + four terms having sqrt() factors. + * There are only two terms with radicals. + + Examples + ======== + + >>> from sympy.solvers.solvers import unrad + >>> from sympy.abc import x + >>> from sympy import sqrt, Rational, root + + >>> unrad(sqrt(x)*x**Rational(1, 3) + 2) + (x**5 - 64, []) + >>> unrad(sqrt(x) + root(x + 1, 3)) + (-x**3 + x**2 + 2*x + 1, []) + >>> eq = sqrt(x) + root(x, 3) - 2 + >>> unrad(eq) + (_p**3 + _p**2 - 2, [_p, _p**6 - x]) + + """ + + uflags = {"check": False, "simplify": False} + + def _cov(p, e): + if cov: + # XXX - uncovered + oldp, olde = cov + if Poly(e, p).degree(p) in (1, 2): + cov[:] = [p, olde.subs(oldp, _vsolve(e, p, **uflags)[0])] + else: + raise NotImplementedError + else: + cov[:] = [p, e] + + def _canonical(eq, cov): + if cov: + # change symbol to vanilla so no solutions are eliminated + p, e = cov + rep = {p: Dummy(p.name)} + eq = eq.xreplace(rep) + cov = [p.xreplace(rep), e.xreplace(rep)] + + # remove constants and powers of factors since these don't change + # the location of the root; XXX should factor or factor_terms be used? + eq = factor_terms(_mexpand(eq.as_numer_denom()[0], recursive=True), clear=True) + if eq.is_Mul: + args = [] + for f in eq.args: + if f.is_number: + continue + if f.is_Pow: + args.append(f.base) + else: + args.append(f) + eq = Mul(*args) # leave as Mul for more efficient solving + + # make the sign canonical + margs = list(Mul.make_args(eq)) + changed = False + for i, m in enumerate(margs): + if m.could_extract_minus_sign(): + margs[i] = -m + changed = True + if changed: + eq = Mul(*margs, evaluate=False) + + return eq, cov + + def _Q(pow): + # return leading Rational of denominator of Pow's exponent + c = pow.as_base_exp()[1].as_coeff_Mul()[0] + if not c.is_Rational: + return S.One + return c.q + + # define the _take method that will determine whether a term is of interest + def _take(d): + # return True if coefficient of any factor's exponent's den is not 1 + for pow in Mul.make_args(d): + if not pow.is_Pow: + continue + if _Q(pow) == 1: + continue + if pow.free_symbols & syms: + return True + return False + _take = flags.setdefault('_take', _take) + + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs # XXX legacy Eq as Eqn support + elif not isinstance(eq, Expr): + return + + cov, nwas, rpt = [flags.setdefault(k, v) for k, v in + sorted({"cov": [], "n": None, "rpt": 0}.items())] + + # preconditioning + eq = powdenest(factor_terms(eq, radical=True, clear=True)) + eq = eq.as_numer_denom()[0] + eq = _mexpand(eq, recursive=True) + if eq.is_number: + return + + # see if there are radicals in symbols of interest + syms = set(syms) or eq.free_symbols # _take uses this + poly = eq.as_poly() + gens = [g for g in poly.gens if _take(g)] + if not gens: + return + + # recast poly in terms of eigen-gens + poly = eq.as_poly(*gens) + + # not a polynomial e.g. 1 + sqrt(x)*exp(sqrt(x)) with gen sqrt(x) + if poly is None: + return + + # - an exponent has a symbol of interest (don't handle) + if any(g.exp.has(*syms) for g in gens): + return + + def _rads_bases_lcm(poly): + # if all the bases are the same or all the radicals are in one + # term, `lcm` will be the lcm of the denominators of the + # exponents of the radicals + lcm = 1 + rads = set() + bases = set() + for g in poly.gens: + q = _Q(g) + if q != 1: + rads.add(g) + lcm = ilcm(lcm, q) + bases.add(g.base) + return rads, bases, lcm + rads, bases, lcm = _rads_bases_lcm(poly) + + covsym = Dummy('p', nonnegative=True) + + # only keep in syms symbols that actually appear in radicals; + # and update gens + newsyms = set() + for r in rads: + newsyms.update(syms & r.free_symbols) + if newsyms != syms: + syms = newsyms + # get terms together that have common generators + drad = dict(zip(rads, range(len(rads)))) + rterms = {(): []} + args = Add.make_args(poly.as_expr()) + for t in args: + if _take(t): + common = set(t.as_poly().gens).intersection(rads) + key = tuple(sorted([drad[i] for i in common])) + else: + key = () + rterms.setdefault(key, []).append(t) + others = Add(*rterms.pop(())) + rterms = [Add(*rterms[k]) for k in rterms.keys()] + + # the output will depend on the order terms are processed, so + # make it canonical quickly + rterms = list(reversed(list(ordered(rterms)))) + + ok = False # we don't have a solution yet + depth = sqrt_depth(eq) + + if len(rterms) == 1 and not (rterms[0].is_Add and lcm > 2): + eq = rterms[0]**lcm - ((-others)**lcm) + ok = True + else: + if len(rterms) == 1 and rterms[0].is_Add: + rterms = list(rterms[0].args) + if len(bases) == 1: + b = bases.pop() + if len(syms) > 1: + x = b.free_symbols + else: + x = syms + x = list(ordered(x))[0] + try: + inv = _vsolve(covsym**lcm - b, x, **uflags) + if not inv: + raise NotImplementedError + eq = poly.as_expr().subs(b, covsym**lcm).subs(x, inv[0]) + _cov(covsym, covsym**lcm - b) + return _canonical(eq, cov) + except NotImplementedError: + pass + + if len(rterms) == 2: + if not others: + eq = rterms[0]**lcm - (-rterms[1])**lcm + ok = True + elif not log(lcm, 2).is_Integer: + # the lcm-is-power-of-two case is handled below + r0, r1 = rterms + if flags.get('_reverse', False): + r1, r0 = r0, r1 + i0 = _rads0, _bases0, lcm0 = _rads_bases_lcm(r0.as_poly()) + i1 = _rads1, _bases1, lcm1 = _rads_bases_lcm(r1.as_poly()) + for reverse in range(2): + if reverse: + i0, i1 = i1, i0 + r0, r1 = r1, r0 + _rads1, _, lcm1 = i1 + _rads1 = Mul(*_rads1) + t1 = _rads1**lcm1 + c = covsym**lcm1 - t1 + for x in syms: + try: + sol = _vsolve(c, x, **uflags) + if not sol: + raise NotImplementedError + neweq = r0.subs(x, sol[0]) + covsym*r1/_rads1 + \ + others + tmp = unrad(neweq, covsym) + if tmp: + eq, newcov = tmp + if newcov: + newp, newc = newcov + _cov(newp, c.subs(covsym, + _vsolve(newc, covsym, **uflags)[0])) + else: + _cov(covsym, c) + else: + eq = neweq + _cov(covsym, c) + ok = True + break + except NotImplementedError: + if reverse: + raise NotImplementedError( + 'no successful change of variable found') + else: + pass + if ok: + break + elif len(rterms) == 3: + # two cube roots and another with order less than 5 + # (so an analytical solution can be found) or a base + # that matches one of the cube root bases + info = [_rads_bases_lcm(i.as_poly()) for i in rterms] + RAD = 0 + BASES = 1 + LCM = 2 + if info[0][LCM] != 3: + info.append(info.pop(0)) + rterms.append(rterms.pop(0)) + elif info[1][LCM] != 3: + info.append(info.pop(1)) + rterms.append(rterms.pop(1)) + if info[0][LCM] == info[1][LCM] == 3: + if info[1][BASES] != info[2][BASES]: + info[0], info[1] = info[1], info[0] + rterms[0], rterms[1] = rterms[1], rterms[0] + if info[1][BASES] == info[2][BASES]: + eq = rterms[0]**3 + (rterms[1] + rterms[2] + others)**3 + ok = True + elif info[2][LCM] < 5: + # a*root(A, 3) + b*root(B, 3) + others = c + a, b, c, d, A, B = [Dummy(i) for i in 'abcdAB'] + # zz represents the unraded expression into which the + # specifics for this case are substituted + zz = (c - d)*(A**3*a**9 + 3*A**2*B*a**6*b**3 - + 3*A**2*a**6*c**3 + 9*A**2*a**6*c**2*d - 9*A**2*a**6*c*d**2 + + 3*A**2*a**6*d**3 + 3*A*B**2*a**3*b**6 + 21*A*B*a**3*b**3*c**3 - + 63*A*B*a**3*b**3*c**2*d + 63*A*B*a**3*b**3*c*d**2 - + 21*A*B*a**3*b**3*d**3 + 3*A*a**3*c**6 - 18*A*a**3*c**5*d + + 45*A*a**3*c**4*d**2 - 60*A*a**3*c**3*d**3 + 45*A*a**3*c**2*d**4 - + 18*A*a**3*c*d**5 + 3*A*a**3*d**6 + B**3*b**9 - 3*B**2*b**6*c**3 + + 9*B**2*b**6*c**2*d - 9*B**2*b**6*c*d**2 + 3*B**2*b**6*d**3 + + 3*B*b**3*c**6 - 18*B*b**3*c**5*d + 45*B*b**3*c**4*d**2 - + 60*B*b**3*c**3*d**3 + 45*B*b**3*c**2*d**4 - 18*B*b**3*c*d**5 + + 3*B*b**3*d**6 - c**9 + 9*c**8*d - 36*c**7*d**2 + 84*c**6*d**3 - + 126*c**5*d**4 + 126*c**4*d**5 - 84*c**3*d**6 + 36*c**2*d**7 - + 9*c*d**8 + d**9) + def _t(i): + b = Mul(*info[i][RAD]) + return cancel(rterms[i]/b), Mul(*info[i][BASES]) + aa, AA = _t(0) + bb, BB = _t(1) + cc = -rterms[2] + dd = others + eq = zz.xreplace(dict(zip( + (a, A, b, B, c, d), + (aa, AA, bb, BB, cc, dd)))) + ok = True + # handle power-of-2 cases + if not ok: + if log(lcm, 2).is_Integer and (not others and + len(rterms) == 4 or len(rterms) < 4): + def _norm2(a, b): + return a**2 + b**2 + 2*a*b + + if len(rterms) == 4: + # (r0+r1)**2 - (r2+r3)**2 + r0, r1, r2, r3 = rterms + eq = _norm2(r0, r1) - _norm2(r2, r3) + ok = True + elif len(rterms) == 3: + # (r1+r2)**2 - (r0+others)**2 + r0, r1, r2 = rterms + eq = _norm2(r1, r2) - _norm2(r0, others) + ok = True + elif len(rterms) == 2: + # r0**2 - (r1+others)**2 + r0, r1 = rterms + eq = r0**2 - _norm2(r1, others) + ok = True + + new_depth = sqrt_depth(eq) if ok else depth + rpt += 1 # XXX how many repeats with others unchanging is enough? + if not ok or ( + nwas is not None and len(rterms) == nwas and + new_depth is not None and new_depth == depth and + rpt > 3): + raise NotImplementedError('Cannot remove all radicals') + + flags.update({"cov": cov, "n": len(rterms), "rpt": rpt}) + neq = unrad(eq, *syms, **flags) + if neq: + eq, cov = neq + eq, cov = _canonical(eq, cov) + return eq, cov + + +# delayed imports +from sympy.solvers.bivariate import ( + bivariate_type, _solve_lambert, _filtered_gens) diff --git a/phivenv/Lib/site-packages/sympy/solvers/solveset.py b/phivenv/Lib/site-packages/sympy/solvers/solveset.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae242d9c8c4c0d1c1c46cd968a0c5e547ff0f66 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/solveset.py @@ -0,0 +1,4131 @@ +""" +This module contains functions to: + + - solve a single equation for a single variable, in any domain either real or complex. + + - solve a single transcendental equation for a single variable in any domain either real or complex. + (currently supports solving in real domain only) + + - solve a system of linear equations with N variables and M equations. + + - solve a system of Non Linear Equations with N variables and M equations +""" +from sympy.core.sympify import sympify +from sympy.core import (S, Pow, Dummy, pi, Expr, Wild, Mul, + Add, Basic) +from sympy.core.containers import Tuple +from sympy.core.function import (Lambda, expand_complex, AppliedUndef, + expand_log, _mexpand, expand_trig, nfloat) +from sympy.core.mod import Mod +from sympy.core.numbers import I, Number, Rational, oo +from sympy.core.intfunc import integer_log +from sympy.core.relational import Eq, Ne, Relational +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, _uniquely_named_symbol +from sympy.core.sympify import _sympify +from sympy.core.traversal import preorder_traversal +from sympy.external.gmpy import gcd as number_gcd, lcm as number_lcm +from sympy.polys.matrices.linsolve import _linear_eq_to_dict +from sympy.polys.polyroots import UnsolvableFactorError +from sympy.simplify.simplify import simplify, fraction, trigsimp, nsimplify +from sympy.simplify import powdenest, logcombine +from sympy.functions import (log, tan, cot, sin, cos, sec, csc, exp, + acos, asin, atan, acot, acsc, asec, + piecewise_fold, Piecewise) +from sympy.functions.combinatorial.numbers import totient +from sympy.functions.elementary.complexes import Abs, arg, re, im +from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, + sinh, cosh, tanh, coth, sech, csch, + asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.elementary.miscellaneous import real_root +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.logic.boolalg import And, BooleanTrue +from sympy.sets import (FiniteSet, imageset, Interval, Intersection, + Union, ConditionSet, ImageSet, Complement, Contains) +from sympy.sets.sets import Set, ProductSet +from sympy.matrices import zeros, Matrix, MatrixBase +from sympy.ntheory.factor_ import divisors +from sympy.ntheory.residue_ntheory import discrete_log, nthroot_mod +from sympy.polys import (roots, Poly, degree, together, PolynomialError, + RootOf, factor, lcm, gcd) +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polytools import invert, groebner, poly +from sympy.polys.solvers import (sympy_eqs_to_ring, solve_lin_sys, + PolyNonlinearError) +from sympy.polys.matrices.linsolve import _linsolve +from sympy.solvers.solvers import (checksol, denoms, unrad, + _simple_dens, recast_to_symbols) +from sympy.solvers.polysys import solve_poly_system +from sympy.utilities import filldedent +from sympy.utilities.iterables import (numbered_symbols, has_dups, + is_sequence, iterable) +from sympy.calculus.util import periodicity, continuous_domain, function_range + + +from types import GeneratorType + + +class NonlinearError(ValueError): + """Raised when unexpectedly encountering nonlinear equations""" + pass + + +def _masked(f, *atoms): + """Return ``f``, with all objects given by ``atoms`` replaced with + Dummy symbols, ``d``, and the list of replacements, ``(d, e)``, + where ``e`` is an object of type given by ``atoms`` in which + any other instances of atoms have been recursively replaced with + Dummy symbols, too. The tuples are ordered so that if they are + applied in sequence, the origin ``f`` will be restored. + + Examples + ======== + + >>> from sympy import cos + >>> from sympy.abc import x + >>> from sympy.solvers.solveset import _masked + + >>> f = cos(cos(x) + 1) + >>> f, reps = _masked(cos(1 + cos(x)), cos) + >>> f + _a1 + >>> reps + [(_a1, cos(_a0 + 1)), (_a0, cos(x))] + >>> for d, e in reps: + ... f = f.xreplace({d: e}) + >>> f + cos(cos(x) + 1) + """ + sym = numbered_symbols('a', cls=Dummy, real=True) + mask = [] + for a in ordered(f.atoms(*atoms)): + for i in mask: + a = a.replace(*i) + mask.append((a, next(sym))) + for i, (o, n) in enumerate(mask): + f = f.replace(o, n) + mask[i] = (n, o) + mask = list(reversed(mask)) + return f, mask + + +def _invert(f_x, y, x, domain=S.Complexes): + r""" + Reduce the complex valued equation $f(x) = y$ to a set of equations + + $$\left\{g(x) = h_1(y),\ g(x) = h_2(y),\ \dots,\ g(x) = h_n(y) \right\}$$ + + where $g(x)$ is a simpler function than $f(x)$. The return value is a tuple + $(g(x), \mathrm{set}_h)$, where $g(x)$ is a function of $x$ and $\mathrm{set}_h$ is + the set of function $\left\{h_1(y), h_2(y), \dots, h_n(y)\right\}$. + Here, $y$ is not necessarily a symbol. + + $\mathrm{set}_h$ contains the functions, along with the information + about the domain in which they are valid, through set + operations. For instance, if :math:`y = |x| - n` is inverted + in the real domain, then $\mathrm{set}_h$ is not simply + $\{-n, n\}$ as the nature of `n` is unknown; rather, it is: + + $$ \left(\left[0, \infty\right) \cap \left\{n\right\}\right) \cup + \left(\left(-\infty, 0\right] \cap \left\{- n\right\}\right)$$ + + By default, the complex domain is used which means that inverting even + seemingly simple functions like $\exp(x)$ will give very different + results from those obtained in the real domain. + (In the case of $\exp(x)$, the inversion via $\log$ is multi-valued + in the complex domain, having infinitely many branches.) + + If you are working with real values only (or you are not sure which + function to use) you should probably set the domain to + ``S.Reals`` (or use ``invert_real`` which does that automatically). + + + Examples + ======== + + >>> from sympy.solvers.solveset import invert_complex, invert_real + >>> from sympy.abc import x, y + >>> from sympy import exp + + When does exp(x) == y? + + >>> invert_complex(exp(x), y, x) + (x, ImageSet(Lambda(_n, I*(2*_n*pi + arg(y)) + log(Abs(y))), Integers)) + >>> invert_real(exp(x), y, x) + (x, Intersection({log(y)}, Reals)) + + When does exp(x) == 1? + + >>> invert_complex(exp(x), 1, x) + (x, ImageSet(Lambda(_n, 2*_n*I*pi), Integers)) + >>> invert_real(exp(x), 1, x) + (x, {0}) + + See Also + ======== + invert_real, invert_complex + """ + x = sympify(x) + if not x.is_Symbol: + raise ValueError("x must be a symbol") + f_x = sympify(f_x) + if x not in f_x.free_symbols: + raise ValueError("Inverse of constant function doesn't exist") + y = sympify(y) + if x in y.free_symbols: + raise ValueError("y should be independent of x ") + + if domain.is_subset(S.Reals): + x1, s = _invert_real(f_x, FiniteSet(y), x) + else: + x1, s = _invert_complex(f_x, FiniteSet(y), x) + + # f couldn't be inverted completely; return unmodified. + if x1 != x: + return x1, s + + # Avoid adding gratuitous intersections with S.Complexes. Actual + # conditions should be handled by the respective inverters. + if domain is S.Complexes: + return x1, s + + if isinstance(s, FiniteSet): + return x1, s.intersect(domain) + + # "Fancier" solution sets like those obtained by inversion of trigonometric + # functions already include general validity conditions (i.e. conditions on + # the domain of the respective inverse functions), so we should avoid adding + # blanket intersections with S.Reals. But subsets of R (or C) must still be + # accounted for. + if domain is S.Reals: + return x1, s + else: + return x1, s.intersect(domain) + + +invert_complex = _invert + + +def invert_real(f_x, y, x): + """ + Inverts a real-valued function. Same as :func:`invert_complex`, but sets + the domain to ``S.Reals`` before inverting. + """ + return _invert(f_x, y, x, S.Reals) + + +def _invert_real(f, g_ys, symbol): + """Helper function for _invert.""" + + if f == symbol or g_ys is S.EmptySet: + return (symbol, g_ys) + + n = Dummy('n', real=True) + + if isinstance(f, exp) or (f.is_Pow and f.base == S.Exp1): + return _invert_real(f.exp, + imageset(Lambda(n, log(n)), g_ys), + symbol) + + if hasattr(f, 'inverse') and f.inverse() is not None and not isinstance(f, ( + TrigonometricFunction, + HyperbolicFunction, + )): + if len(f.args) > 1: + raise ValueError("Only functions with one argument are supported.") + return _invert_real(f.args[0], + imageset(Lambda(n, f.inverse()(n)), g_ys), + symbol) + + if isinstance(f, Abs): + return _invert_abs(f.args[0], g_ys, symbol) + + if f.is_Add: + # f = g + h + g, h = f.as_independent(symbol) + if g is not S.Zero: + return _invert_real(h, imageset(Lambda(n, n - g), g_ys), symbol) + + if f.is_Mul: + # f = g*h + g, h = f.as_independent(symbol) + + if g is not S.One: + return _invert_real(h, imageset(Lambda(n, n/g), g_ys), symbol) + + if f.is_Pow: + base, expo = f.args + base_has_sym = base.has(symbol) + expo_has_sym = expo.has(symbol) + + if not expo_has_sym: + + if expo.is_rational: + num, den = expo.as_numer_denom() + + if den % 2 == 0 and num % 2 == 1 and den.is_zero is False: + # Here we have f(x)**(num/den) = y + # where den is nonzero and even and y is an element + # of the set g_ys. + # den is even, so we are only interested in the cases + # where both f(x) and y are positive. + # Restricting y to be positive (using the set g_ys_pos) + # means that y**(den/num) is always positive. + # Therefore it isn't necessary to also constrain f(x) + # to be positive because we are only going to + # find solutions of f(x) = y**(d/n) + # where the rhs is already required to be positive. + root = Lambda(n, real_root(n, expo)) + g_ys_pos = g_ys & Interval(0, oo) + res = imageset(root, g_ys_pos) + _inv, _set = _invert_real(base, res, symbol) + return (_inv, _set) + + if den % 2 == 1: + root = Lambda(n, real_root(n, expo)) + res = imageset(root, g_ys) + if num % 2 == 0: + neg_res = imageset(Lambda(n, -n), res) + return _invert_real(base, res + neg_res, symbol) + if num % 2 == 1: + return _invert_real(base, res, symbol) + + elif expo.is_irrational: + root = Lambda(n, real_root(n, expo)) + g_ys_pos = g_ys & Interval(0, oo) + res = imageset(root, g_ys_pos) + return _invert_real(base, res, symbol) + + else: + # indeterminate exponent, e.g. Float or parity of + # num, den of rational could not be determined + pass # use default return + + if not base_has_sym: + rhs = g_ys.args[0] + if base.is_positive: + return _invert_real(expo, + imageset(Lambda(n, log(n, base, evaluate=False)), g_ys), symbol) + elif base.is_negative: + s, b = integer_log(rhs, base) + if b: + return _invert_real(expo, FiniteSet(s), symbol) + else: + return (expo, S.EmptySet) + elif base.is_zero: + one = Eq(rhs, 1) + if one == S.true: + # special case: 0**x - 1 + return _invert_real(expo, FiniteSet(0), symbol) + elif one == S.false: + return (expo, S.EmptySet) + + if isinstance(f, (TrigonometricFunction, HyperbolicFunction)): + return _invert_trig_hyp_real(f, g_ys, symbol) + + return (f, g_ys) + + +# Dictionaries of inverses will be cached after first use. +_trig_inverses = None +_hyp_inverses = None + +def _invert_trig_hyp_real(f, g_ys, symbol): + """Helper function for inverting trigonometric and hyperbolic functions. + + This helper only handles inversion over the reals. + + For trigonometric functions only finite `g_ys` sets are implemented. + + For hyperbolic functions the set `g_ys` is checked against the domain of the + respective inverse functions. Infinite `g_ys` sets are also supported. + """ + + if isinstance(f, HyperbolicFunction): + n = Dummy('n', real=True) + + if isinstance(f, sinh): + # asinh is defined over R. + return _invert_real(f.args[0], imageset(n, asinh(n), g_ys), symbol) + + if isinstance(f, cosh): + g_ys_dom = g_ys.intersect(Interval(1, oo)) + if isinstance(g_ys_dom, Intersection): + # could not properly resolve domain check + if isinstance(g_ys, FiniteSet): + # If g_ys is a `FiniteSet`` it should be sufficient to just + # let the calling `_invert_real()` add an intersection with + # `S.Reals` (or a subset `domain`) to ensure that only valid + # (real) solutions are returned. + # This avoids adding "too many" Intersections or + # ConditionSets in the returned set. + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], Union( + imageset(n, acosh(n), g_ys_dom), + imageset(n, -acosh(n), g_ys_dom)), symbol) + + if isinstance(f, sech): + g_ys_dom = g_ys.intersect(Interval.Lopen(0, 1)) + if isinstance(g_ys_dom, Intersection): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], Union( + imageset(n, asech(n), g_ys_dom), + imageset(n, -asech(n), g_ys_dom)), symbol) + + if isinstance(f, tanh): + g_ys_dom = g_ys.intersect(Interval.open(-1, 1)) + if isinstance(g_ys_dom, Intersection): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], + imageset(n, atanh(n), g_ys_dom), symbol) + + if isinstance(f, coth): + g_ys_dom = g_ys - Interval(-1, 1) + if isinstance(g_ys_dom, Complement): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], + imageset(n, acoth(n), g_ys_dom), symbol) + + if isinstance(f, csch): + g_ys_dom = g_ys - FiniteSet(0) + if isinstance(g_ys_dom, Complement): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], + imageset(n, acsch(n), g_ys_dom), symbol) + + elif isinstance(f, TrigonometricFunction) and isinstance(g_ys, FiniteSet): + def _get_trig_inverses(func): + global _trig_inverses + if _trig_inverses is None: + _trig_inverses = { + sin : ((asin, lambda y: pi-asin(y)), 2*pi, Interval(-1, 1)), + cos : ((acos, lambda y: -acos(y)), 2*pi, Interval(-1, 1)), + tan : ((atan,), pi, S.Reals), + cot : ((acot,), pi, S.Reals), + sec : ((asec, lambda y: -asec(y)), 2*pi, + Union(Interval(-oo, -1), Interval(1, oo))), + csc : ((acsc, lambda y: pi-acsc(y)), 2*pi, + Union(Interval(-oo, -1), Interval(1, oo)))} + return _trig_inverses[func] + + invs, period, rng = _get_trig_inverses(f.func) + n = Dummy('n', integer=True) + def create_return_set(g): + # returns ConditionSet that will be part of the final (x, set) tuple + invsimg = Union(*[ + imageset(n, period*n + inv(g), S.Integers) for inv in invs]) + inv_f, inv_g_ys = _invert_real(f.args[0], invsimg, symbol) + if inv_f == symbol: # inversion successful + conds = rng.contains(g) + return ConditionSet(symbol, conds, inv_g_ys) + else: + return ConditionSet(symbol, Eq(f, g), S.Reals) + + retset = Union(*[create_return_set(g) for g in g_ys]) + return (symbol, retset) + + else: + return (f, g_ys) + + +def _invert_trig_hyp_complex(f, g_ys, symbol): + """Helper function for inverting trigonometric and hyperbolic functions. + + This helper only handles inversion over the complex numbers. + Only finite `g_ys` sets are implemented. + + Handling of singularities is only implemented for hyperbolic equations. + In case of a symbolic element g in g_ys a ConditionSet may be returned. + """ + + if isinstance(f, TrigonometricFunction) and isinstance(g_ys, FiniteSet): + def inv(trig): + if isinstance(trig, (sin, csc)): + F = asin if isinstance(trig, sin) else acsc + return ( + lambda a: 2*n*pi + F(a), + lambda a: 2*n*pi + pi - F(a)) + if isinstance(trig, (cos, sec)): + F = acos if isinstance(trig, cos) else asec + return ( + lambda a: 2*n*pi + F(a), + lambda a: 2*n*pi - F(a)) + if isinstance(trig, (tan, cot)): + return (lambda a: n*pi + trig.inverse()(a),) + + n = Dummy('n', integer=True) + invs = S.EmptySet + for L in inv(f): + invs += Union(*[imageset(Lambda(n, L(g)), S.Integers) for g in g_ys]) + return _invert_complex(f.args[0], invs, symbol) + + elif isinstance(f, HyperbolicFunction) and isinstance(g_ys, FiniteSet): + # There are two main options regarding singularities / domain checking + # for symbolic elements in g_ys: + # 1. Add a "catch-all" intersection with S.Complexes. + # 2. ConditionSets. + # At present ConditionSets seem to work better and have the additional + # benefit of representing the precise conditions that must be satisfied. + # The conditions are also rather straightforward. (At most two isolated + # points.) + def _get_hyp_inverses(func): + global _hyp_inverses + if _hyp_inverses is None: + _hyp_inverses = { + sinh : ((asinh, lambda y: I*pi-asinh(y)), 2*I*pi, ()), + cosh : ((acosh, lambda y: -acosh(y)), 2*I*pi, ()), + tanh : ((atanh,), I*pi, (-1, 1)), + coth : ((acoth,), I*pi, (-1, 1)), + sech : ((asech, lambda y: -asech(y)), 2*I*pi, (0, )), + csch : ((acsch, lambda y: I*pi-acsch(y)), 2*I*pi, (0, ))} + return _hyp_inverses[func] + + # invs: iterable of main inverses, e.g. (acosh, -acosh). + # excl: iterable of singularities to be checked for. + invs, period, excl = _get_hyp_inverses(f.func) + n = Dummy('n', integer=True) + def create_return_set(g): + # returns ConditionSet that will be part of the final (x, set) tuple + invsimg = Union(*[ + imageset(n, period*n + inv(g), S.Integers) for inv in invs]) + inv_f, inv_g_ys = _invert_complex(f.args[0], invsimg, symbol) + if inv_f == symbol: # inversion successful + conds = And(*[Ne(g, e) for e in excl]) + return ConditionSet(symbol, conds, inv_g_ys) + else: + return ConditionSet(symbol, Eq(f, g), S.Complexes) + + retset = Union(*[create_return_set(g) for g in g_ys]) + return (symbol, retset) + + else: + return (f, g_ys) + + +def _invert_complex(f, g_ys, symbol): + """Helper function for _invert.""" + + if f == symbol or g_ys is S.EmptySet: + return (symbol, g_ys) + + n = Dummy('n') + + if f.is_Add: + # f = g + h + g, h = f.as_independent(symbol) + if g is not S.Zero: + return _invert_complex(h, imageset(Lambda(n, n - g), g_ys), symbol) + + if f.is_Mul: + # f = g*h + g, h = f.as_independent(symbol) + + if g is not S.One: + if g in {S.NegativeInfinity, S.ComplexInfinity, S.Infinity}: + return (h, S.EmptySet) + return _invert_complex(h, imageset(Lambda(n, n/g), g_ys), symbol) + + if f.is_Pow: + base, expo = f.args + # special case: g**r = 0 + # Could be improved like `_invert_real` to handle more general cases. + if expo.is_Rational and g_ys == FiniteSet(0): + if expo.is_positive: + return _invert_complex(base, g_ys, symbol) + + if hasattr(f, 'inverse') and f.inverse() is not None and \ + not isinstance(f, TrigonometricFunction) and \ + not isinstance(f, HyperbolicFunction) and \ + not isinstance(f, exp): + if len(f.args) > 1: + raise ValueError("Only functions with one argument are supported.") + return _invert_complex(f.args[0], + imageset(Lambda(n, f.inverse()(n)), g_ys), symbol) + + if isinstance(f, exp) or (f.is_Pow and f.base == S.Exp1): + if isinstance(g_ys, ImageSet): + # can solve up to `(d*exp(exp(...(exp(a*x + b))...) + c)` format. + # Further can be improved to `(d*exp(exp(...(exp(a*x**n + b*x**(n-1) + ... + f))...) + c)`. + g_ys_expr = g_ys.lamda.expr + g_ys_vars = g_ys.lamda.variables + k = Dummy('k{}'.format(len(g_ys_vars))) + g_ys_vars_1 = (k,) + g_ys_vars + exp_invs = Union(*[imageset(Lambda((g_ys_vars_1,), (I*(2*k*pi + arg(g_ys_expr)) + + log(Abs(g_ys_expr)))), S.Integers**(len(g_ys_vars_1)))]) + return _invert_complex(f.exp, exp_invs, symbol) + + elif isinstance(g_ys, FiniteSet): + exp_invs = Union(*[imageset(Lambda(n, I*(2*n*pi + arg(g_y)) + + log(Abs(g_y))), S.Integers) + for g_y in g_ys if g_y != 0]) + return _invert_complex(f.exp, exp_invs, symbol) + + if isinstance(f, (TrigonometricFunction, HyperbolicFunction)): + return _invert_trig_hyp_complex(f, g_ys, symbol) + + return (f, g_ys) + + +def _invert_abs(f, g_ys, symbol): + """Helper function for inverting absolute value functions. + + Returns the complete result of inverting an absolute value + function along with the conditions which must also be satisfied. + + If it is certain that all these conditions are met, a :class:`~.FiniteSet` + of all possible solutions is returned. If any condition cannot be + satisfied, an :class:`~.EmptySet` is returned. Otherwise, a + :class:`~.ConditionSet` of the solutions, with all the required conditions + specified, is returned. + + """ + if not g_ys.is_FiniteSet: + # this could be used for FiniteSet, but the + # results are more compact if they aren't, e.g. + # ConditionSet(x, Contains(n, Interval(0, oo)), {-n, n}) vs + # Union(Intersection(Interval(0, oo), {n}), Intersection(Interval(-oo, 0), {-n})) + # for the solution of abs(x) - n + pos = Intersection(g_ys, Interval(0, S.Infinity)) + parg = _invert_real(f, pos, symbol) + narg = _invert_real(-f, pos, symbol) + if parg[0] != narg[0]: + raise NotImplementedError + return parg[0], Union(narg[1], parg[1]) + + # check conditions: all these must be true. If any are unknown + # then return them as conditions which must be satisfied + unknown = [] + for a in g_ys.args: + ok = a.is_nonnegative if a.is_Number else a.is_positive + if ok is None: + unknown.append(a) + elif not ok: + return symbol, S.EmptySet + if unknown: + conditions = And(*[Contains(i, Interval(0, oo)) + for i in unknown]) + else: + conditions = True + n = Dummy('n', real=True) + # this is slightly different than above: instead of solving + # +/-f on positive values, here we solve for f on +/- g_ys + g_x, values = _invert_real(f, Union( + imageset(Lambda(n, n), g_ys), + imageset(Lambda(n, -n), g_ys)), symbol) + return g_x, ConditionSet(g_x, conditions, values) + + +def domain_check(f, symbol, p): + """Returns False if point p is infinite or any subexpression of f + is infinite or becomes so after replacing symbol with p. If none of + these conditions is met then True will be returned. + + Examples + ======== + + >>> from sympy import Mul, oo + >>> from sympy.abc import x + >>> from sympy.solvers.solveset import domain_check + >>> g = 1/(1 + (1/(x + 1))**2) + >>> domain_check(g, x, -1) + False + >>> domain_check(x**2, x, 0) + True + >>> domain_check(1/x, x, oo) + False + + * The function relies on the assumption that the original form + of the equation has not been changed by automatic simplification. + + >>> domain_check(x/x, x, 0) # x/x is automatically simplified to 1 + True + + * To deal with automatic evaluations use evaluate=False: + + >>> domain_check(Mul(x, 1/x, evaluate=False), x, 0) + False + """ + f, p = sympify(f), sympify(p) + if p.is_infinite: + return False + return _domain_check(f, symbol, p) + + +def _domain_check(f, symbol, p): + # helper for domain check + if f.is_Atom and f.is_finite: + return True + elif f.subs(symbol, p).is_infinite: + return False + elif isinstance(f, Piecewise): + # Check the cases of the Piecewise in turn. There might be invalid + # expressions in later cases that don't apply e.g. + # solveset(Piecewise((0, Eq(x, 0)), (1/x, True)), x) + for expr, cond in f.args: + condsubs = cond.subs(symbol, p) + if condsubs is S.false: + continue + elif condsubs is S.true: + return _domain_check(expr, symbol, p) + else: + # We don't know which case of the Piecewise holds. On this + # basis we cannot decide whether any solution is in or out of + # the domain. Ideally this function would allow returning a + # symbolic condition for the validity of the solution that + # could be handled in the calling code. In the mean time we'll + # give this particular solution the benefit of the doubt and + # let it pass. + return True + else: + # TODO : We should not blindly recurse through all args of arbitrary expressions like this + return all(_domain_check(g, symbol, p) + for g in f.args) + + +def _is_finite_with_finite_vars(f, domain=S.Complexes): + """ + Return True if the given expression is finite. For symbols that + do not assign a value for `complex` and/or `real`, the domain will + be used to assign a value; symbols that do not assign a value + for `finite` will be made finite. All other assumptions are + left unmodified. + """ + def assumptions(s): + A = s.assumptions0 + A.setdefault('finite', A.get('finite', True)) + if domain.is_subset(S.Reals): + # if this gets set it will make complex=True, too + A.setdefault('real', True) + else: + # don't change 'real' because being complex implies + # nothing about being real + A.setdefault('complex', True) + return A + + reps = {s: Dummy(**assumptions(s)) for s in f.free_symbols} + return f.xreplace(reps).is_finite + + +def _is_function_class_equation(func_class, f, symbol): + """ Tests whether the equation is an equation of the given function class. + + The given equation belongs to the given function class if it is + comprised of functions of the function class which are multiplied by + or added to expressions independent of the symbol. In addition, the + arguments of all such functions must be linear in the symbol as well. + + Examples + ======== + + >>> from sympy.solvers.solveset import _is_function_class_equation + >>> from sympy import tan, sin, tanh, sinh, exp + >>> from sympy.abc import x + >>> from sympy.functions.elementary.trigonometric import TrigonometricFunction + >>> from sympy.functions.elementary.hyperbolic import HyperbolicFunction + >>> _is_function_class_equation(TrigonometricFunction, exp(x) + tan(x), x) + False + >>> _is_function_class_equation(TrigonometricFunction, tan(x) + sin(x), x) + True + >>> _is_function_class_equation(TrigonometricFunction, tan(x**2), x) + False + >>> _is_function_class_equation(TrigonometricFunction, tan(x + 2), x) + True + >>> _is_function_class_equation(HyperbolicFunction, tanh(x) + sinh(x), x) + True + """ + if f.is_Mul or f.is_Add: + return all(_is_function_class_equation(func_class, arg, symbol) + for arg in f.args) + + if f.is_Pow: + if not f.exp.has(symbol): + return _is_function_class_equation(func_class, f.base, symbol) + else: + return False + + if not f.has(symbol): + return True + + if isinstance(f, func_class): + try: + g = Poly(f.args[0], symbol) + return g.degree() <= 1 + except PolynomialError: + return False + else: + return False + + +def _solve_as_rational(f, symbol, domain): + """ solve rational functions""" + f = together(_mexpand(f, recursive=True), deep=True) + g, h = fraction(f) + if not h.has(symbol): + try: + return _solve_as_poly(g, symbol, domain) + except NotImplementedError: + # The polynomial formed from g could end up having + # coefficients in a ring over which finding roots + # isn't implemented yet, e.g. ZZ[a] for some symbol a + return ConditionSet(symbol, Eq(f, 0), domain) + except CoercionFailed: + # contained oo, zoo or nan + return S.EmptySet + else: + valid_solns = _solveset(g, symbol, domain) + invalid_solns = _solveset(h, symbol, domain) + return valid_solns - invalid_solns + + +class _SolveTrig1Error(Exception): + """Raised when _solve_trig1 heuristics do not apply""" + +def _solve_trig(f, symbol, domain): + """Function to call other helpers to solve trigonometric equations """ + # If f is composed of a single trig function (potentially appearing multiple + # times) we should solve by either inverting directly or inverting after a + # suitable change of variable. + # + # _solve_trig is currently only called by _solveset for trig/hyperbolic + # functions of an argument linear in x. Inverting a symbolic argument should + # include a guard against division by zero in order to have a result that is + # consistent with similar processing done by _solve_trig1. + # (Ideally _invert should add these conditions by itself.) + trig_expr, count = None, 0 + for expr in preorder_traversal(f): + if isinstance(expr, (TrigonometricFunction, + HyperbolicFunction)) and expr.has(symbol): + if not trig_expr: + trig_expr, count = expr, 1 + elif expr == trig_expr: + count += 1 + else: + trig_expr, count = False, 0 + break + if count == 1: + # direct inversion + x, sol = _invert(f, 0, symbol, domain) + if x == symbol: + cond = True + if trig_expr.free_symbols - {symbol}: + a, h = trig_expr.args[0].as_independent(symbol, as_Add=True) + m, h = h.as_independent(symbol, as_Add=False) + num, den = m.as_numer_denom() + cond = Ne(num, 0) & Ne(den, 0) + return ConditionSet(symbol, cond, sol) + else: + return ConditionSet(symbol, Eq(f, 0), domain) + elif count: + # solve by change of variable + y = Dummy('y') + f_cov = f.subs(trig_expr, y) + sol_cov = solveset(f_cov, y, domain) + if isinstance(sol_cov, FiniteSet): + return Union( + *[_solve_trig(trig_expr-s, symbol, domain) for s in sol_cov]) + + sol = None + try: + # multiple trig/hyp functions; solve by rewriting to exp + sol = _solve_trig1(f, symbol, domain) + except _SolveTrig1Error: + try: + # multiple trig/hyp functions; solve by rewriting to tan(x/2) + sol = _solve_trig2(f, symbol, domain) + except ValueError: + raise NotImplementedError(filldedent(''' + Solution to this kind of trigonometric equations + is yet to be implemented''')) + return sol + + +def _solve_trig1(f, symbol, domain): + """Primary solver for trigonometric and hyperbolic equations + + Returns either the solution set as a ConditionSet (auto-evaluated to a + union of ImageSets if no variables besides 'symbol' are involved) or + raises _SolveTrig1Error if f == 0 cannot be solved. + + Notes + ===== + Algorithm: + 1. Do a change of variable x -> mu*x in arguments to trigonometric and + hyperbolic functions, in order to reduce them to small integers. (This + step is crucial to keep the degrees of the polynomials of step 4 low.) + 2. Rewrite trigonometric/hyperbolic functions as exponentials. + 3. Proceed to a 2nd change of variable, replacing exp(I*x) or exp(x) by y. + 4. Solve the resulting rational equation. + 5. Use invert_complex or invert_real to return to the original variable. + 6. If the coefficients of 'symbol' were symbolic in nature, add the + necessary consistency conditions in a ConditionSet. + + """ + # Prepare change of variable + x = Dummy('x') + if _is_function_class_equation(HyperbolicFunction, f, symbol): + cov = exp(x) + inverter = invert_real if domain.is_subset(S.Reals) else invert_complex + else: + cov = exp(I*x) + inverter = invert_complex + + f = trigsimp(f) + f_original = f + trig_functions = f.atoms(TrigonometricFunction, HyperbolicFunction) + trig_arguments = [e.args[0] for e in trig_functions] + # trigsimp may have reduced the equation to an expression + # that is independent of 'symbol' (e.g. cos**2+sin**2) + if not any(a.has(symbol) for a in trig_arguments): + return solveset(f_original, symbol, domain) + + denominators = [] + numerators = [] + for ar in trig_arguments: + try: + poly_ar = Poly(ar, symbol) + except PolynomialError: + raise _SolveTrig1Error("trig argument is not a polynomial") + if poly_ar.degree() > 1: # degree >1 still bad + raise _SolveTrig1Error("degree of variable must not exceed one") + if poly_ar.degree() == 0: # degree 0, don't care + continue + c = poly_ar.all_coeffs()[0] # got the coefficient of 'symbol' + numerators.append(fraction(c)[0]) + denominators.append(fraction(c)[1]) + + mu = lcm(denominators)/gcd(numerators) + f = f.subs(symbol, mu*x) + f = f.rewrite(exp) + f = together(f) + g, h = fraction(f) + y = Dummy('y') + g, h = g.expand(), h.expand() + g, h = g.subs(cov, y), h.subs(cov, y) + if g.has(x) or h.has(x): + raise _SolveTrig1Error("change of variable not possible") + + solns = solveset_complex(g, y) - solveset_complex(h, y) + if isinstance(solns, ConditionSet): + raise _SolveTrig1Error("polynomial has ConditionSet solution") + + if isinstance(solns, FiniteSet): + if any(isinstance(s, RootOf) for s in solns): + raise _SolveTrig1Error("polynomial results in RootOf object") + # revert the change of variable + cov = cov.subs(x, symbol/mu) + result = Union(*[inverter(cov, s, symbol)[1] for s in solns]) + # In case of symbolic coefficients, the solution set is only valid + # if numerator and denominator of mu are non-zero. + if mu.has(Symbol): + syms = (mu).atoms(Symbol) + munum, muden = fraction(mu) + condnum = munum.as_independent(*syms, as_Add=False)[1] + condden = muden.as_independent(*syms, as_Add=False)[1] + cond = And(Ne(condnum, 0), Ne(condden, 0)) + else: + cond = True + # Actual conditions are returned as part of the ConditionSet. Adding an + # intersection with C would only complicate some solution sets due to + # current limitations of intersection code. (e.g. #19154) + if domain is S.Complexes: + # This is a slight abuse of ConditionSet. Ideally this should + # be some kind of "PiecewiseSet". (See #19507 discussion) + return ConditionSet(symbol, cond, result) + else: + return ConditionSet(symbol, cond, Intersection(result, domain)) + elif solns is S.EmptySet: + return S.EmptySet + else: + raise _SolveTrig1Error("polynomial solutions must form FiniteSet") + + +def _solve_trig2(f, symbol, domain): + """Secondary helper to solve trigonometric equations, + called when first helper fails """ + f = trigsimp(f) + f_original = f + trig_functions = f.atoms(sin, cos, tan, sec, cot, csc) + trig_arguments = [e.args[0] for e in trig_functions] + denominators = [] + numerators = [] + + # todo: This solver can be extended to hyperbolics if the + # analogous change of variable to tanh (instead of tan) + # is used. + if not trig_functions: + return ConditionSet(symbol, Eq(f_original, 0), domain) + + # todo: The pre-processing below (extraction of numerators, denominators, + # gcd, lcm, mu, etc.) should be updated to the enhanced version in + # _solve_trig1. (See #19507) + for ar in trig_arguments: + try: + poly_ar = Poly(ar, symbol) + except PolynomialError: + raise ValueError("give up, we cannot solve if this is not a polynomial in x") + if poly_ar.degree() > 1: # degree >1 still bad + raise ValueError("degree of variable inside polynomial should not exceed one") + if poly_ar.degree() == 0: # degree 0, don't care + continue + c = poly_ar.all_coeffs()[0] # got the coefficient of 'symbol' + try: + numerators.append(Rational(c).p) + denominators.append(Rational(c).q) + except TypeError: + return ConditionSet(symbol, Eq(f_original, 0), domain) + + x = Dummy('x') + + mu = Rational(2)*number_lcm(*denominators)/number_gcd(*numerators) + f = f.subs(symbol, mu*x) + f = f.rewrite(tan) + f = expand_trig(f) + f = together(f) + + g, h = fraction(f) + y = Dummy('y') + g, h = g.expand(), h.expand() + g, h = g.subs(tan(x), y), h.subs(tan(x), y) + + if g.has(x) or h.has(x): + return ConditionSet(symbol, Eq(f_original, 0), domain) + solns = solveset(g, y, S.Reals) - solveset(h, y, S.Reals) + + if isinstance(solns, FiniteSet): + result = Union(*[invert_real(tan(symbol/mu), s, symbol)[1] + for s in solns]) + dsol = invert_real(tan(symbol/mu), oo, symbol)[1] + if degree(h) > degree(g): # If degree(denom)>degree(num) then there + result = Union(result, dsol) # would be another sol at Lim(denom-->oo) + return Intersection(result, domain) + elif solns is S.EmptySet: + return S.EmptySet + else: + return ConditionSet(symbol, Eq(f_original, 0), S.Reals) + + +def _solve_as_poly(f, symbol, domain=S.Complexes): + """ + Solve the equation using polynomial techniques if it already is a + polynomial equation or, with a change of variables, can be made so. + """ + result = None + if f.is_polynomial(symbol): + solns = roots(f, symbol, cubics=True, quartics=True, + quintics=True, domain='EX') + num_roots = sum(solns.values()) + if degree(f, symbol) <= num_roots: + result = FiniteSet(*solns.keys()) + else: + poly = Poly(f, symbol) + solns = poly.all_roots() + if poly.degree() <= len(solns): + result = FiniteSet(*solns) + else: + result = ConditionSet(symbol, Eq(f, 0), domain) + else: + poly = Poly(f) + if poly is None: + result = ConditionSet(symbol, Eq(f, 0), domain) + gens = [g for g in poly.gens if g.has(symbol)] + + if len(gens) == 1: + poly = Poly(poly, gens[0]) + gen = poly.gen + deg = poly.degree() + poly = Poly(poly.as_expr(), poly.gen, composite=True) + poly_solns = FiniteSet(*roots(poly, cubics=True, quartics=True, + quintics=True).keys()) + + if len(poly_solns) < deg: + result = ConditionSet(symbol, Eq(f, 0), domain) + + if gen != symbol: + y = Dummy('y') + inverter = invert_real if domain.is_subset(S.Reals) else invert_complex + lhs, rhs_s = inverter(gen, y, symbol) + if lhs == symbol: + result = Union(*[rhs_s.subs(y, s) for s in poly_solns]) + if isinstance(result, FiniteSet) and isinstance(gen, Pow + ) and gen.base.is_Rational: + result = FiniteSet(*[expand_log(i) for i in result]) + else: + result = ConditionSet(symbol, Eq(f, 0), domain) + else: + result = ConditionSet(symbol, Eq(f, 0), domain) + + if result is not None: + if isinstance(result, FiniteSet): + # this is to simplify solutions like -sqrt(-I) to sqrt(2)/2 + # - sqrt(2)*I/2. We are not expanding for solution with symbols + # or undefined functions because that makes the solution more complicated. + # For example, expand_complex(a) returns re(a) + I*im(a) + if all(s.atoms(Symbol, AppliedUndef) == set() and not isinstance(s, RootOf) + for s in result): + s = Dummy('s') + result = imageset(Lambda(s, expand_complex(s)), result) + if isinstance(result, FiniteSet) and domain != S.Complexes: + # Avoid adding gratuitous intersections with S.Complexes. Actual + # conditions should be handled elsewhere. + result = result.intersection(domain) + return result + else: + return ConditionSet(symbol, Eq(f, 0), domain) + + +def _solve_radical(f, unradf, symbol, solveset_solver): + """ Helper function to solve equations with radicals """ + res = unradf + eq, cov = res if res else (f, []) + if not cov: + result = solveset_solver(eq, symbol) - \ + Union(*[solveset_solver(g, symbol) for g in denoms(f, symbol)]) + else: + y, yeq = cov + if not solveset_solver(y - I, y): + yreal = Dummy('yreal', real=True) + yeq = yeq.xreplace({y: yreal}) + eq = eq.xreplace({y: yreal}) + y = yreal + g_y_s = solveset_solver(yeq, symbol) + f_y_sols = solveset_solver(eq, y) + result = Union(*[imageset(Lambda(y, g_y), f_y_sols) + for g_y in g_y_s]) + + def check_finiteset(solutions): + f_set = [] # solutions for FiniteSet + c_set = [] # solutions for ConditionSet + for s in solutions: + if checksol(f, symbol, s): + f_set.append(s) + else: + c_set.append(s) + return FiniteSet(*f_set) + ConditionSet(symbol, Eq(f, 0), FiniteSet(*c_set)) + + def check_set(solutions): + if solutions is S.EmptySet: + return solutions + elif isinstance(solutions, ConditionSet): + # XXX: Maybe the base set should be checked? + return solutions + elif isinstance(solutions, FiniteSet): + return check_finiteset(solutions) + elif isinstance(solutions, Complement): + A, B = solutions.args + return Complement(check_set(A), B) + elif isinstance(solutions, Union): + return Union(*[check_set(s) for s in solutions.args]) + else: + # XXX: There should be more cases checked here. The cases above + # are all those that come up in the test suite for now. + return solutions + + solution_set = check_set(result) + + return solution_set + + +def _solve_abs(f, symbol, domain): + """ Helper function to solve equation involving absolute value function """ + if not domain.is_subset(S.Reals): + raise ValueError(filldedent(''' + Absolute values cannot be inverted in the + complex domain.''')) + p, q, r = Wild('p'), Wild('q'), Wild('r') + pattern_match = f.match(p*Abs(q) + r) or {} + f_p, f_q, f_r = [pattern_match.get(i, S.Zero) for i in (p, q, r)] + + if not (f_p.is_zero or f_q.is_zero): + domain = continuous_domain(f_q, symbol, domain) + from .inequalities import solve_univariate_inequality + q_pos_cond = solve_univariate_inequality(f_q >= 0, symbol, + relational=False, domain=domain, continuous=True) + q_neg_cond = q_pos_cond.complement(domain) + + sols_q_pos = solveset_real(f_p*f_q + f_r, + symbol).intersect(q_pos_cond) + sols_q_neg = solveset_real(f_p*(-f_q) + f_r, + symbol).intersect(q_neg_cond) + return Union(sols_q_pos, sols_q_neg) + else: + return ConditionSet(symbol, Eq(f, 0), domain) + + +def solve_decomposition(f, symbol, domain): + """ + Function to solve equations via the principle of "Decomposition + and Rewriting". + + Examples + ======== + >>> from sympy import exp, sin, Symbol, pprint, S + >>> from sympy.solvers.solveset import solve_decomposition as sd + >>> x = Symbol('x') + >>> f1 = exp(2*x) - 3*exp(x) + 2 + >>> sd(f1, x, S.Reals) + {0, log(2)} + >>> f2 = sin(x)**2 + 2*sin(x) + 1 + >>> pprint(sd(f2, x, S.Reals), use_unicode=False) + 3*pi + {2*n*pi + ---- | n in Integers} + 2 + >>> f3 = sin(x + 2) + >>> pprint(sd(f3, x, S.Reals), use_unicode=False) + {2*n*pi - 2 | n in Integers} U {2*n*pi - 2 + pi | n in Integers} + + """ + from sympy.solvers.decompogen import decompogen + # decompose the given function + g_s = decompogen(f, symbol) + # `y_s` represents the set of values for which the function `g` is to be + # solved. + # `solutions` represent the solutions of the equations `g = y_s` or + # `g = 0` depending on the type of `y_s`. + # As we are interested in solving the equation: f = 0 + y_s = FiniteSet(0) + for g in g_s: + frange = function_range(g, symbol, domain) + y_s = Intersection(frange, y_s) + result = S.EmptySet + if isinstance(y_s, FiniteSet): + for y in y_s: + solutions = solveset(Eq(g, y), symbol, domain) + if not isinstance(solutions, ConditionSet): + result += solutions + + else: + if isinstance(y_s, ImageSet): + iter_iset = (y_s,) + + elif isinstance(y_s, Union): + iter_iset = y_s.args + + elif y_s is S.EmptySet: + # y_s is not in the range of g in g_s, so no solution exists + #in the given domain + return S.EmptySet + + for iset in iter_iset: + new_solutions = solveset(Eq(iset.lamda.expr, g), symbol, domain) + dummy_var = tuple(iset.lamda.expr.free_symbols)[0] + (base_set,) = iset.base_sets + if isinstance(new_solutions, FiniteSet): + new_exprs = new_solutions + + elif isinstance(new_solutions, Intersection): + if isinstance(new_solutions.args[1], FiniteSet): + new_exprs = new_solutions.args[1] + + for new_expr in new_exprs: + result += ImageSet(Lambda(dummy_var, new_expr), base_set) + + if result is S.EmptySet: + return ConditionSet(symbol, Eq(f, 0), domain) + + y_s = result + + return y_s + + +def _solveset(f, symbol, domain, _check=False): + """Helper for solveset to return a result from an expression + that has already been sympify'ed and is known to contain the + given symbol.""" + # _check controls whether the answer is checked or not + from sympy.simplify.simplify import signsimp + + if isinstance(f, BooleanTrue): + return domain + + orig_f = f + if f.is_Mul: + coeff, f = f.as_independent(symbol, as_Add=False) + if coeff in {S.ComplexInfinity, S.NegativeInfinity, S.Infinity}: + f = together(orig_f) + elif f.is_Add: + a, h = f.as_independent(symbol) + m, h = h.as_independent(symbol, as_Add=False) + if m not in {S.ComplexInfinity, S.Zero, S.Infinity, + S.NegativeInfinity}: + f = a/m + h # XXX condition `m != 0` should be added to soln + + # assign the solvers to use + solver = lambda f, x, domain=domain: _solveset(f, x, domain) + inverter = lambda f, rhs, symbol: _invert(f, rhs, symbol, domain) + + result = S.EmptySet + + if f.expand().is_zero: + return domain + elif not f.has(symbol): + return S.EmptySet + elif f.is_Mul and all(_is_finite_with_finite_vars(m, domain) + for m in f.args): + # if f(x) and g(x) are both finite we can say that the solution of + # f(x)*g(x) == 0 is same as Union(f(x) == 0, g(x) == 0) is not true in + # general. g(x) can grow to infinitely large for the values where + # f(x) == 0. To be sure that we are not silently allowing any + # wrong solutions we are using this technique only if both f and g are + # finite for a finite input. + result = Union(*[solver(m, symbol) for m in f.args]) + elif (_is_function_class_equation(TrigonometricFunction, f, symbol) or \ + _is_function_class_equation(HyperbolicFunction, f, symbol)): + result = _solve_trig(f, symbol, domain) + elif isinstance(f, arg): + a = f.args[0] + result = Intersection(_solveset(re(a) > 0, symbol, domain), + _solveset(im(a), symbol, domain)) + elif f.is_Piecewise: + expr_set_pairs = f.as_expr_set_pairs(domain) + for (expr, in_set) in expr_set_pairs: + if in_set.is_Relational: + in_set = in_set.as_set() + solns = solver(expr, symbol, in_set) + result += solns + elif isinstance(f, Eq): + result = solver(Add(f.lhs, -f.rhs, evaluate=False), symbol, domain) + + elif f.is_Relational: + from .inequalities import solve_univariate_inequality + try: + result = solve_univariate_inequality( + f, symbol, domain=domain, relational=False) + except NotImplementedError: + result = ConditionSet(symbol, f, domain) + return result + elif _is_modular(f, symbol): + result = _solve_modular(f, symbol, domain) + else: + lhs, rhs_s = inverter(f, 0, symbol) + if lhs == symbol: + # do some very minimal simplification since + # repeated inversion may have left the result + # in a state that other solvers (e.g. poly) + # would have simplified; this is done here + # rather than in the inverter since here it + # is only done once whereas there it would + # be repeated for each step of the inversion + if isinstance(rhs_s, FiniteSet): + rhs_s = FiniteSet(*[Mul(* + signsimp(i).as_content_primitive()) + for i in rhs_s]) + result = rhs_s + + elif isinstance(rhs_s, FiniteSet): + for equation in [lhs - rhs for rhs in rhs_s]: + if equation == f: + u = unrad(f, symbol) + if u: + result += _solve_radical(equation, u, + symbol, + solver) + elif equation.has(Abs): + result += _solve_abs(f, symbol, domain) + else: + result_rational = _solve_as_rational(equation, symbol, domain) + if not isinstance(result_rational, ConditionSet): + result += result_rational + else: + # may be a transcendental type equation + t_result = _transolve(equation, symbol, domain) + if isinstance(t_result, ConditionSet): + # might need factoring; this is expensive so we + # have delayed until now. To avoid recursion + # errors look for a non-trivial factoring into + # a product of symbol dependent terms; I think + # that something that factors as a Pow would + # have already been recognized by now. + factored = equation.factor() + if factored.is_Mul and equation != factored: + _, dep = factored.as_independent(symbol) + if not dep.is_Add: + # non-trivial factoring of equation + # but use form with constants + # in case they need special handling + t_results = [] + for fac in Mul.make_args(factored): + if fac.has(symbol): + t_results.append(solver(fac, symbol)) + t_result = Union(*t_results) + result += t_result + else: + result += solver(equation, symbol) + + elif rhs_s is not S.EmptySet: + result = ConditionSet(symbol, Eq(f, 0), domain) + + if isinstance(result, ConditionSet): + if isinstance(f, Expr): + num, den = f.as_numer_denom() + if den.has(symbol): + _result = _solveset(num, symbol, domain) + if not isinstance(_result, ConditionSet): + singularities = _solveset(den, symbol, domain) + result = _result - singularities + + if _check: + if isinstance(result, ConditionSet): + # it wasn't solved or has enumerated all conditions + # -- leave it alone + return result + + # whittle away all but the symbol-containing core + # to use this for testing + if isinstance(orig_f, Expr): + fx = orig_f.as_independent(symbol, as_Add=True)[1] + fx = fx.as_independent(symbol, as_Add=False)[1] + else: + fx = orig_f + + if isinstance(result, FiniteSet): + # check the result for invalid solutions + result = FiniteSet(*[s for s in result + if isinstance(s, RootOf) + or domain_check(fx, symbol, s)]) + + return result + + +def _is_modular(f, symbol): + """ + Helper function to check below mentioned types of modular equations. + ``A - Mod(B, C) = 0`` + + A -> This can or cannot be a function of symbol. + B -> This is surely a function of symbol. + C -> It is an integer. + + Parameters + ========== + + f : Expr + The equation to be checked. + + symbol : Symbol + The concerned variable for which the equation is to be checked. + + Examples + ======== + + >>> from sympy import symbols, exp, Mod + >>> from sympy.solvers.solveset import _is_modular as check + >>> x, y = symbols('x y') + >>> check(Mod(x, 3) - 1, x) + True + >>> check(Mod(x, 3) - 1, y) + False + >>> check(Mod(x, 3)**2 - 5, x) + False + >>> check(Mod(x, 3)**2 - y, x) + False + >>> check(exp(Mod(x, 3)) - 1, x) + False + >>> check(Mod(3, y) - 1, y) + False + """ + + if not f.has(Mod): + return False + + # extract modterms from f. + modterms = list(f.atoms(Mod)) + + return (len(modterms) == 1 and # only one Mod should be present + modterms[0].args[0].has(symbol) and # B-> function of symbol + modterms[0].args[1].is_integer and # C-> to be an integer. + any(isinstance(term, Mod) + for term in list(_term_factors(f))) # free from other funcs + ) + + +def _invert_modular(modterm, rhs, n, symbol): + """ + Helper function to invert modular equation. + ``Mod(a, m) - rhs = 0`` + + Generally it is inverted as (a, ImageSet(Lambda(n, m*n + rhs), S.Integers)). + More simplified form will be returned if possible. + + If it is not invertible then (modterm, rhs) is returned. + + The following cases arise while inverting equation ``Mod(a, m) - rhs = 0``: + + 1. If a is symbol then m*n + rhs is the required solution. + + 2. If a is an instance of ``Add`` then we try to find two symbol independent + parts of a and the symbol independent part gets transferred to the other + side and again the ``_invert_modular`` is called on the symbol + dependent part. + + 3. If a is an instance of ``Mul`` then same as we done in ``Add`` we separate + out the symbol dependent and symbol independent parts and transfer the + symbol independent part to the rhs with the help of invert and again the + ``_invert_modular`` is called on the symbol dependent part. + + 4. If a is an instance of ``Pow`` then two cases arise as following: + + - If a is of type (symbol_indep)**(symbol_dep) then the remainder is + evaluated with the help of discrete_log function and then the least + period is being found out with the help of totient function. + period*n + remainder is the required solution in this case. + For reference: (https://en.wikipedia.org/wiki/Euler's_theorem) + + - If a is of type (symbol_dep)**(symbol_indep) then we try to find all + primitive solutions list with the help of nthroot_mod function. + m*n + rem is the general solution where rem belongs to solutions list + from nthroot_mod function. + + Parameters + ========== + + modterm, rhs : Expr + The modular equation to be inverted, ``modterm - rhs = 0`` + + symbol : Symbol + The variable in the equation to be inverted. + + n : Dummy + Dummy variable for output g_n. + + Returns + ======= + + A tuple (f_x, g_n) is being returned where f_x is modular independent function + of symbol and g_n being set of values f_x can have. + + Examples + ======== + + >>> from sympy import symbols, exp, Mod, Dummy, S + >>> from sympy.solvers.solveset import _invert_modular as invert_modular + >>> x, y = symbols('x y') + >>> n = Dummy('n') + >>> invert_modular(Mod(exp(x), 7), S(5), n, x) + (Mod(exp(x), 7), 5) + >>> invert_modular(Mod(x, 7), S(5), n, x) + (x, ImageSet(Lambda(_n, 7*_n + 5), Integers)) + >>> invert_modular(Mod(3*x + 8, 7), S(5), n, x) + (x, ImageSet(Lambda(_n, 7*_n + 6), Integers)) + >>> invert_modular(Mod(x**4, 7), S(5), n, x) + (x, EmptySet) + >>> invert_modular(Mod(2**(x**2 + x + 1), 7), S(2), n, x) + (x**2 + x + 1, ImageSet(Lambda(_n, 3*_n + 1), Naturals0)) + + """ + a, m = modterm.args + + if rhs.is_integer is False: + return symbol, S.EmptySet + + if rhs.is_real is False or any(term.is_real is False + for term in list(_term_factors(a))): + # Check for complex arguments + return modterm, rhs + + if abs(rhs) >= abs(m): + # if rhs has value greater than value of m. + return symbol, S.EmptySet + + if a == symbol: + return symbol, ImageSet(Lambda(n, m*n + rhs), S.Integers) + + if a.is_Add: + # g + h = a + g, h = a.as_independent(symbol) + if g is not S.Zero: + x_indep_term = rhs - Mod(g, m) + return _invert_modular(Mod(h, m), Mod(x_indep_term, m), n, symbol) + + if a.is_Mul: + # g*h = a + g, h = a.as_independent(symbol) + if g is not S.One: + x_indep_term = rhs*invert(g, m) + return _invert_modular(Mod(h, m), Mod(x_indep_term, m), n, symbol) + + if a.is_Pow: + # base**expo = a + base, expo = a.args + if expo.has(symbol) and not base.has(symbol): + # remainder -> solution independent of n of equation. + # m, rhs are made coprime by dividing number_gcd(m, rhs) + if not m.is_Integer and rhs.is_Integer and a.base.is_Integer: + return modterm, rhs + + mdiv = m.p // number_gcd(m.p, rhs.p) + try: + remainder = discrete_log(mdiv, rhs.p, a.base.p) + except ValueError: # log does not exist + return modterm, rhs + # period -> coefficient of n in the solution and also referred as + # the least period of expo in which it is repeats itself. + # (a**(totient(m)) - 1) divides m. Here is link of theorem: + # (https://en.wikipedia.org/wiki/Euler's_theorem) + period = totient(m) + for p in divisors(period): + # there might a lesser period exist than totient(m). + if pow(a.base, p, m / number_gcd(m.p, a.base.p)) == 1: + period = p + break + # recursion is not applied here since _invert_modular is currently + # not smart enough to handle infinite rhs as here expo has infinite + # rhs = ImageSet(Lambda(n, period*n + remainder), S.Naturals0). + return expo, ImageSet(Lambda(n, period*n + remainder), S.Naturals0) + elif base.has(symbol) and not expo.has(symbol): + try: + remainder_list = nthroot_mod(rhs, expo, m, all_roots=True) + if remainder_list == []: + return symbol, S.EmptySet + except (ValueError, NotImplementedError): + return modterm, rhs + g_n = S.EmptySet + for rem in remainder_list: + g_n += ImageSet(Lambda(n, m*n + rem), S.Integers) + return base, g_n + + return modterm, rhs + + +def _solve_modular(f, symbol, domain): + r""" + Helper function for solving modular equations of type ``A - Mod(B, C) = 0``, + where A can or cannot be a function of symbol, B is surely a function of + symbol and C is an integer. + + Currently ``_solve_modular`` is only able to solve cases + where A is not a function of symbol. + + Parameters + ========== + + f : Expr + The modular equation to be solved, ``f = 0`` + + symbol : Symbol + The variable in the equation to be solved. + + domain : Set + A set over which the equation is solved. It has to be a subset of + Integers. + + Returns + ======= + + A set of integer solutions satisfying the given modular equation. + A ``ConditionSet`` if the equation is unsolvable. + + Examples + ======== + + >>> from sympy.solvers.solveset import _solve_modular as solve_modulo + >>> from sympy import S, Symbol, sin, Intersection, Interval, Mod + >>> x = Symbol('x') + >>> solve_modulo(Mod(5*x - 8, 7) - 3, x, S.Integers) + ImageSet(Lambda(_n, 7*_n + 5), Integers) + >>> solve_modulo(Mod(5*x - 8, 7) - 3, x, S.Reals) # domain should be subset of integers. + ConditionSet(x, Eq(Mod(5*x + 6, 7) - 3, 0), Reals) + >>> solve_modulo(-7 + Mod(x, 5), x, S.Integers) + EmptySet + >>> solve_modulo(Mod(12**x, 21) - 18, x, S.Integers) + ImageSet(Lambda(_n, 6*_n + 2), Naturals0) + >>> solve_modulo(Mod(sin(x), 7) - 3, x, S.Integers) # not solvable + ConditionSet(x, Eq(Mod(sin(x), 7) - 3, 0), Integers) + >>> solve_modulo(3 - Mod(x, 5), x, Intersection(S.Integers, Interval(0, 100))) + Intersection(ImageSet(Lambda(_n, 5*_n + 3), Integers), Range(0, 101, 1)) + """ + # extract modterm and g_y from f + unsolved_result = ConditionSet(symbol, Eq(f, 0), domain) + modterm = list(f.atoms(Mod))[0] + rhs = -S.One*(f.subs(modterm, S.Zero)) + if f.as_coefficients_dict()[modterm].is_negative: + # checks if coefficient of modterm is negative in main equation. + rhs *= -S.One + + if not domain.is_subset(S.Integers): + return unsolved_result + + if rhs.has(symbol): + # TODO Case: A-> function of symbol, can be extended here + # in future. + return unsolved_result + + n = Dummy('n', integer=True) + f_x, g_n = _invert_modular(modterm, rhs, n, symbol) + + if f_x == modterm and g_n == rhs: + return unsolved_result + + if f_x == symbol: + if domain is not S.Integers: + return domain.intersect(g_n) + return g_n + + if isinstance(g_n, ImageSet): + lamda_expr = g_n.lamda.expr + lamda_vars = g_n.lamda.variables + base_sets = g_n.base_sets + sol_set = _solveset(f_x - lamda_expr, symbol, S.Integers) + if isinstance(sol_set, FiniteSet): + tmp_sol = S.EmptySet + for sol in sol_set: + tmp_sol += ImageSet(Lambda(lamda_vars, sol), *base_sets) + sol_set = tmp_sol + else: + sol_set = ImageSet(Lambda(lamda_vars, sol_set), *base_sets) + return domain.intersect(sol_set) + + return unsolved_result + + +def _term_factors(f): + """ + Iterator to get the factors of all terms present + in the given equation. + + Parameters + ========== + f : Expr + Equation that needs to be addressed + + Returns + ======= + Factors of all terms present in the equation. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.solveset import _term_factors + >>> x = symbols('x') + >>> list(_term_factors(-2 - x**2 + x*(x + 1))) + [-2, -1, x**2, x, x + 1] + """ + for add_arg in Add.make_args(f): + yield from Mul.make_args(add_arg) + + +def _solve_exponential(lhs, rhs, symbol, domain): + r""" + Helper function for solving (supported) exponential equations. + + Exponential equations are the sum of (currently) at most + two terms with one or both of them having a power with a + symbol-dependent exponent. + + For example + + .. math:: 5^{2x + 3} - 5^{3x - 1} + + .. math:: 4^{5 - 9x} - e^{2 - x} + + Parameters + ========== + + lhs, rhs : Expr + The exponential equation to be solved, `lhs = rhs` + + symbol : Symbol + The variable in which the equation is solved + + domain : Set + A set over which the equation is solved. + + Returns + ======= + + A set of solutions satisfying the given equation. + A ``ConditionSet`` if the equation is unsolvable or + if the assumptions are not properly defined, in that case + a different style of ``ConditionSet`` is returned having the + solution(s) of the equation with the desired assumptions. + + Examples + ======== + + >>> from sympy.solvers.solveset import _solve_exponential as solve_expo + >>> from sympy import symbols, S + >>> x = symbols('x', real=True) + >>> a, b = symbols('a b') + >>> solve_expo(2**x + 3**x - 5**x, 0, x, S.Reals) # not solvable + ConditionSet(x, Eq(2**x + 3**x - 5**x, 0), Reals) + >>> solve_expo(a**x - b**x, 0, x, S.Reals) # solvable but incorrect assumptions + ConditionSet(x, (a > 0) & (b > 0), {0}) + >>> solve_expo(3**(2*x) - 2**(x + 3), 0, x, S.Reals) + {-3*log(2)/(-2*log(3) + log(2))} + >>> solve_expo(2**x - 4**x, 0, x, S.Reals) + {0} + + * Proof of correctness of the method + + The logarithm function is the inverse of the exponential function. + The defining relation between exponentiation and logarithm is: + + .. math:: {\log_b x} = y \enspace if \enspace b^y = x + + Therefore if we are given an equation with exponent terms, we can + convert every term to its corresponding logarithmic form. This is + achieved by taking logarithms and expanding the equation using + logarithmic identities so that it can easily be handled by ``solveset``. + + For example: + + .. math:: 3^{2x} = 2^{x + 3} + + Taking log both sides will reduce the equation to + + .. math:: (2x)\log(3) = (x + 3)\log(2) + + This form can be easily handed by ``solveset``. + """ + unsolved_result = ConditionSet(symbol, Eq(lhs - rhs, 0), domain) + newlhs = powdenest(lhs) + if lhs != newlhs: + # it may also be advantageous to factor the new expr + neweq = factor(newlhs - rhs) + if neweq != (lhs - rhs): + return _solveset(neweq, symbol, domain) # try again with _solveset + + if not (isinstance(lhs, Add) and len(lhs.args) == 2): + # solving for the sum of more than two powers is possible + # but not yet implemented + return unsolved_result + + if rhs != 0: + return unsolved_result + + a, b = list(ordered(lhs.args)) + a_term = a.as_independent(symbol)[1] + b_term = b.as_independent(symbol)[1] + + a_base, a_exp = a_term.as_base_exp() + b_base, b_exp = b_term.as_base_exp() + + if domain.is_subset(S.Reals): + conditions = And( + a_base > 0, + b_base > 0, + Eq(im(a_exp), 0), + Eq(im(b_exp), 0)) + else: + conditions = And( + Ne(a_base, 0), + Ne(b_base, 0)) + + L, R = (expand_log(log(i), force=True) for i in (a, -b)) + solutions = _solveset(L - R, symbol, domain) + + return ConditionSet(symbol, conditions, solutions) + + +def _is_exponential(f, symbol): + r""" + Return ``True`` if one or more terms contain ``symbol`` only in + exponents, else ``False``. + + Parameters + ========== + + f : Expr + The equation to be checked + + symbol : Symbol + The variable in which the equation is checked + + Examples + ======== + + >>> from sympy import symbols, cos, exp + >>> from sympy.solvers.solveset import _is_exponential as check + >>> x, y = symbols('x y') + >>> check(y, y) + False + >>> check(x**y - 1, y) + True + >>> check(x**y*2**y - 1, y) + True + >>> check(exp(x + 3) + 3**x, x) + True + >>> check(cos(2**x), x) + False + + * Philosophy behind the helper + + The function extracts each term of the equation and checks if it is + of exponential form w.r.t ``symbol``. + """ + rv = False + for expr_arg in _term_factors(f): + if symbol not in expr_arg.free_symbols: + continue + if (isinstance(expr_arg, Pow) and + symbol not in expr_arg.base.free_symbols or + isinstance(expr_arg, exp)): + rv = True # symbol in exponent + else: + return False # dependent on symbol in non-exponential way + return rv + + +def _solve_logarithm(lhs, rhs, symbol, domain): + r""" + Helper to solve logarithmic equations which are reducible + to a single instance of `\log`. + + Logarithmic equations are (currently) the equations that contains + `\log` terms which can be reduced to a single `\log` term or + a constant using various logarithmic identities. + + For example: + + .. math:: \log(x) + \log(x - 4) + + can be reduced to: + + .. math:: \log(x(x - 4)) + + Parameters + ========== + + lhs, rhs : Expr + The logarithmic equation to be solved, `lhs = rhs` + + symbol : Symbol + The variable in which the equation is solved + + domain : Set + A set over which the equation is solved. + + Returns + ======= + + A set of solutions satisfying the given equation. + A ``ConditionSet`` if the equation is unsolvable. + + Examples + ======== + + >>> from sympy import symbols, log, S + >>> from sympy.solvers.solveset import _solve_logarithm as solve_log + >>> x = symbols('x') + >>> f = log(x - 3) + log(x + 3) + >>> solve_log(f, 0, x, S.Reals) + {-sqrt(10), sqrt(10)} + + * Proof of correctness + + A logarithm is another way to write exponent and is defined by + + .. math:: {\log_b x} = y \enspace if \enspace b^y = x + + When one side of the equation contains a single logarithm, the + equation can be solved by rewriting the equation as an equivalent + exponential equation as defined above. But if one side contains + more than one logarithm, we need to use the properties of logarithm + to condense it into a single logarithm. + + Take for example + + .. math:: \log(2x) - 15 = 0 + + contains single logarithm, therefore we can directly rewrite it to + exponential form as + + .. math:: x = \frac{e^{15}}{2} + + But if the equation has more than one logarithm as + + .. math:: \log(x - 3) + \log(x + 3) = 0 + + we use logarithmic identities to convert it into a reduced form + + Using, + + .. math:: \log(a) + \log(b) = \log(ab) + + the equation becomes, + + .. math:: \log((x - 3)(x + 3)) + + This equation contains one logarithm and can be solved by rewriting + to exponents. + """ + new_lhs = logcombine(lhs, force=True) + new_f = new_lhs - rhs + + return _solveset(new_f, symbol, domain) + + +def _is_logarithmic(f, symbol): + r""" + Return ``True`` if the equation is in the form + `a\log(f(x)) + b\log(g(x)) + ... + c` else ``False``. + + Parameters + ========== + + f : Expr + The equation to be checked + + symbol : Symbol + The variable in which the equation is checked + + Returns + ======= + + ``True`` if the equation is logarithmic otherwise ``False``. + + Examples + ======== + + >>> from sympy import symbols, tan, log + >>> from sympy.solvers.solveset import _is_logarithmic as check + >>> x, y = symbols('x y') + >>> check(log(x + 2) - log(x + 3), x) + True + >>> check(tan(log(2*x)), x) + False + >>> check(x*log(x), x) + False + >>> check(x + log(x), x) + False + >>> check(y + log(x), x) + True + + * Philosophy behind the helper + + The function extracts each term and checks whether it is + logarithmic w.r.t ``symbol``. + """ + rv = False + for term in Add.make_args(f): + saw_log = False + for term_arg in Mul.make_args(term): + if symbol not in term_arg.free_symbols: + continue + if isinstance(term_arg, log): + if saw_log: + return False # more than one log in term + saw_log = True + else: + return False # dependent on symbol in non-log way + if saw_log: + rv = True + return rv + + +def _is_lambert(f, symbol): + r""" + If this returns ``False`` then the Lambert solver (``_solve_lambert``) will not be called. + + Explanation + =========== + + Quick check for cases that the Lambert solver might be able to handle. + + 1. Equations containing more than two operands and `symbol`s involving any of + `Pow`, `exp`, `HyperbolicFunction`,`TrigonometricFunction`, `log` terms. + + 2. In `Pow`, `exp` the exponent should have `symbol` whereas for + `HyperbolicFunction`,`TrigonometricFunction`, `log` should contain `symbol`. + + 3. For `HyperbolicFunction`,`TrigonometricFunction` the number of trigonometric functions in + equation should be less than number of symbols. (since `A*cos(x) + B*sin(x) - c` + is not the Lambert type). + + Some forms of lambert equations are: + 1. X**X = C + 2. X*(B*log(X) + D)**A = C + 3. A*log(B*X + A) + d*X = C + 4. (B*X + A)*exp(d*X + g) = C + 5. g*exp(B*X + h) - B*X = C + 6. A*D**(E*X + g) - B*X = C + 7. A*cos(X) + B*sin(X) - D*X = C + 8. A*cosh(X) + B*sinh(X) - D*X = C + + Where X is any variable, + A, B, C, D, E are any constants, + g, h are linear functions or log terms. + + Parameters + ========== + + f : Expr + The equation to be checked + + symbol : Symbol + The variable in which the equation is checked + + Returns + ======= + + If this returns ``False`` then the Lambert solver (``_solve_lambert``) will not be called. + + Examples + ======== + + >>> from sympy.solvers.solveset import _is_lambert + >>> from sympy import symbols, cosh, sinh, log + >>> x = symbols('x') + + >>> _is_lambert(3*log(x) - x*log(3), x) + True + >>> _is_lambert(log(log(x - 3)) + log(x-3), x) + True + >>> _is_lambert(cosh(x) - sinh(x), x) + False + >>> _is_lambert((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) + True + + See Also + ======== + + _solve_lambert + + """ + term_factors = list(_term_factors(f.expand())) + + # total number of symbols in equation + no_of_symbols = len([arg for arg in term_factors if arg.has(symbol)]) + # total number of trigonometric terms in equation + no_of_trig = len([arg for arg in term_factors \ + if arg.has(HyperbolicFunction, TrigonometricFunction)]) + + if f.is_Add and no_of_symbols >= 2: + # `log`, `HyperbolicFunction`, `TrigonometricFunction` should have symbols + # and no_of_trig < no_of_symbols + lambert_funcs = (log, HyperbolicFunction, TrigonometricFunction) + if any(isinstance(arg, lambert_funcs)\ + for arg in term_factors if arg.has(symbol)): + if no_of_trig < no_of_symbols: + return True + # here, `Pow`, `exp` exponent should have symbols + elif any(isinstance(arg, (Pow, exp)) \ + for arg in term_factors if (arg.as_base_exp()[1]).has(symbol)): + return True + return False + + +def _transolve(f, symbol, domain): + r""" + Function to solve transcendental equations. It is a helper to + ``solveset`` and should be used internally. ``_transolve`` + currently supports the following class of equations: + + - Exponential equations + - Logarithmic equations + + Parameters + ========== + + f : Any transcendental equation that needs to be solved. + This needs to be an expression, which is assumed + to be equal to ``0``. + + symbol : The variable for which the equation is solved. + This needs to be of class ``Symbol``. + + domain : A set over which the equation is solved. + This needs to be of class ``Set``. + + Returns + ======= + + Set + A set of values for ``symbol`` for which ``f`` is equal to + zero. An ``EmptySet`` is returned if ``f`` does not have solutions + in respective domain. A ``ConditionSet`` is returned as unsolved + object if algorithms to evaluate complete solution are not + yet implemented. + + How to use ``_transolve`` + ========================= + + ``_transolve`` should not be used as an independent function, because + it assumes that the equation (``f``) and the ``symbol`` comes from + ``solveset`` and might have undergone a few modification(s). + To use ``_transolve`` as an independent function the equation (``f``) + and the ``symbol`` should be passed as they would have been by + ``solveset``. + + Examples + ======== + + >>> from sympy.solvers.solveset import _transolve as transolve + >>> from sympy.solvers.solvers import _tsolve as tsolve + >>> from sympy import symbols, S, pprint + >>> x = symbols('x', real=True) # assumption added + >>> transolve(5**(x - 3) - 3**(2*x + 1), x, S.Reals) + {-(log(3) + 3*log(5))/(-log(5) + 2*log(3))} + + How ``_transolve`` works + ======================== + + ``_transolve`` uses two types of helper functions to solve equations + of a particular class: + + Identifying helpers: To determine whether a given equation + belongs to a certain class of equation or not. Returns either + ``True`` or ``False``. + + Solving helpers: Once an equation is identified, a corresponding + helper either solves the equation or returns a form of the equation + that ``solveset`` might better be able to handle. + + * Philosophy behind the module + + The purpose of ``_transolve`` is to take equations which are not + already polynomial in their generator(s) and to either recast them + as such through a valid transformation or to solve them outright. + A pair of helper functions for each class of supported + transcendental functions are employed for this purpose. One + identifies the transcendental form of an equation and the other + either solves it or recasts it into a tractable form that can be + solved by ``solveset``. + For example, an equation in the form `ab^{f(x)} - cd^{g(x)} = 0` + can be transformed to + `\log(a) + f(x)\log(b) - \log(c) - g(x)\log(d) = 0` + (under certain assumptions) and this can be solved with ``solveset`` + if `f(x)` and `g(x)` are in polynomial form. + + How ``_transolve`` is better than ``_tsolve`` + ============================================= + + 1) Better output + + ``_transolve`` provides expressions in a more simplified form. + + Consider a simple exponential equation + + >>> f = 3**(2*x) - 2**(x + 3) + >>> pprint(transolve(f, x, S.Reals), use_unicode=False) + -3*log(2) + {------------------} + -2*log(3) + log(2) + >>> pprint(tsolve(f, x), use_unicode=False) + / 3 \ + | --------| + | log(2/9)| + [-log\2 /] + + 2) Extensible + + The API of ``_transolve`` is designed such that it is easily + extensible, i.e. the code that solves a given class of + equations is encapsulated in a helper and not mixed in with + the code of ``_transolve`` itself. + + 3) Modular + + ``_transolve`` is designed to be modular i.e, for every class of + equation a separate helper for identification and solving is + implemented. This makes it easy to change or modify any of the + method implemented directly in the helpers without interfering + with the actual structure of the API. + + 4) Faster Computation + + Solving equation via ``_transolve`` is much faster as compared to + ``_tsolve``. In ``solve``, attempts are made computing every possibility + to get the solutions. This series of attempts makes solving a bit + slow. In ``_transolve``, computation begins only after a particular + type of equation is identified. + + How to add new class of equations + ================================= + + Adding a new class of equation solver is a three-step procedure: + + - Identify the type of the equations + + Determine the type of the class of equations to which they belong: + it could be of ``Add``, ``Pow``, etc. types. Separate internal functions + are used for each type. Write identification and solving helpers + and use them from within the routine for the given type of equation + (after adding it, if necessary). Something like: + + .. code-block:: python + + def add_type(lhs, rhs, x): + .... + if _is_exponential(lhs, x): + new_eq = _solve_exponential(lhs, rhs, x) + .... + rhs, lhs = eq.as_independent(x) + if lhs.is_Add: + result = add_type(lhs, rhs, x) + + - Define the identification helper. + + - Define the solving helper. + + Apart from this, a few other things needs to be taken care while + adding an equation solver: + + - Naming conventions: + Name of the identification helper should be as + ``_is_class`` where class will be the name or abbreviation + of the class of equation. The solving helper will be named as + ``_solve_class``. + For example: for exponential equations it becomes + ``_is_exponential`` and ``_solve_expo``. + - The identifying helpers should take two input parameters, + the equation to be checked and the variable for which a solution + is being sought, while solving helpers would require an additional + domain parameter. + - Be sure to consider corner cases. + - Add tests for each helper. + - Add a docstring to your helper that describes the method + implemented. + The documentation of the helpers should identify: + + - the purpose of the helper, + - the method used to identify and solve the equation, + - a proof of correctness + - the return values of the helpers + """ + + def add_type(lhs, rhs, symbol, domain): + """ + Helper for ``_transolve`` to handle equations of + ``Add`` type, i.e. equations taking the form as + ``a*f(x) + b*g(x) + .... = c``. + For example: 4**x + 8**x = 0 + """ + result = ConditionSet(symbol, Eq(lhs - rhs, 0), domain) + + # check if it is exponential type equation + if _is_exponential(lhs, symbol): + result = _solve_exponential(lhs, rhs, symbol, domain) + # check if it is logarithmic type equation + elif _is_logarithmic(lhs, symbol): + result = _solve_logarithm(lhs, rhs, symbol, domain) + + return result + + result = ConditionSet(symbol, Eq(f, 0), domain) + + # invert_complex handles the call to the desired inverter based + # on the domain specified. + lhs, rhs_s = invert_complex(f, 0, symbol, domain) + + if isinstance(rhs_s, FiniteSet): + assert (len(rhs_s.args)) == 1 + rhs = rhs_s.args[0] + + if lhs.is_Add: + result = add_type(lhs, rhs, symbol, domain) + else: + result = rhs_s + + return result + + +def solveset(f, symbol=None, domain=S.Complexes): + r"""Solves a given inequality or equation with set as output + + Parameters + ========== + + f : Expr or a relational. + The target equation or inequality + symbol : Symbol + The variable for which the equation is solved + domain : Set + The domain over which the equation is solved + + Returns + ======= + + Set + A set of values for `symbol` for which `f` is True or is equal to + zero. An :class:`~.EmptySet` is returned if `f` is False or nonzero. + A :class:`~.ConditionSet` is returned as unsolved object if algorithms + to evaluate complete solution are not yet implemented. + + ``solveset`` claims to be complete in the solution set that it returns. + + Raises + ====== + + NotImplementedError + The algorithms to solve inequalities in complex domain are + not yet implemented. + ValueError + The input is not valid. + RuntimeError + It is a bug, please report to the github issue tracker. + + + Notes + ===== + + Python interprets 0 and 1 as False and True, respectively, but + in this function they refer to solutions of an expression. So 0 and 1 + return the domain and EmptySet, respectively, while True and False + return the opposite (as they are assumed to be solutions of relational + expressions). + + + See Also + ======== + + solveset_real: solver for real domain + solveset_complex: solver for complex domain + + Examples + ======== + + >>> from sympy import exp, sin, Symbol, pprint, S, Eq + >>> from sympy.solvers.solveset import solveset, solveset_real + + * The default domain is complex. Not specifying a domain will lead + to the solving of the equation in the complex domain (and this + is not affected by the assumptions on the symbol): + + >>> x = Symbol('x') + >>> pprint(solveset(exp(x) - 1, x), use_unicode=False) + {2*n*I*pi | n in Integers} + + >>> x = Symbol('x', real=True) + >>> pprint(solveset(exp(x) - 1, x), use_unicode=False) + {2*n*I*pi | n in Integers} + + * If you want to use ``solveset`` to solve the equation in the + real domain, provide a real domain. (Using ``solveset_real`` + does this automatically.) + + >>> R = S.Reals + >>> x = Symbol('x') + >>> solveset(exp(x) - 1, x, R) + {0} + >>> solveset_real(exp(x) - 1, x) + {0} + + The solution is unaffected by assumptions on the symbol: + + >>> p = Symbol('p', positive=True) + >>> pprint(solveset(p**2 - 4)) + {-2, 2} + + When a :class:`~.ConditionSet` is returned, symbols with assumptions that + would alter the set are replaced with more generic symbols: + + >>> i = Symbol('i', imaginary=True) + >>> solveset(Eq(i**2 + i*sin(i), 1), i, domain=S.Reals) + ConditionSet(_R, Eq(_R**2 + _R*sin(_R) - 1, 0), Reals) + + * Inequalities can be solved over the real domain only. Use of a complex + domain leads to a NotImplementedError. + + >>> solveset(exp(x) > 1, x, R) + Interval.open(0, oo) + + """ + f = sympify(f) + symbol = sympify(symbol) + + if f is S.true: + return domain + + if f is S.false: + return S.EmptySet + + if not isinstance(f, (Expr, Relational, Number)): + raise ValueError("%s is not a valid SymPy expression" % f) + + if not isinstance(symbol, (Expr, Relational)) and symbol is not None: + raise ValueError("%s is not a valid SymPy symbol" % (symbol,)) + + if not isinstance(domain, Set): + raise ValueError("%s is not a valid domain" %(domain)) + + free_symbols = f.free_symbols + + if f.has(Piecewise): + f = piecewise_fold(f) + + if symbol is None and not free_symbols: + b = Eq(f, 0) + if b is S.true: + return domain + elif b is S.false: + return S.EmptySet + else: + raise NotImplementedError(filldedent(''' + relationship between value and 0 is unknown: %s''' % b)) + + if symbol is None: + if len(free_symbols) == 1: + symbol = free_symbols.pop() + elif free_symbols: + raise ValueError(filldedent(''' + The independent variable must be specified for a + multivariate equation.''')) + elif not isinstance(symbol, Symbol): + f, s, swap = recast_to_symbols([f], [symbol]) + # the xreplace will be needed if a ConditionSet is returned + return solveset(f[0], s[0], domain).xreplace(swap) + + # solveset should ignore assumptions on symbols + newsym = None + if domain.is_subset(S.Reals): + if symbol._assumptions_orig != {'real': True}: + newsym = Dummy('R', real=True) + elif domain.is_subset(S.Complexes): + if symbol._assumptions_orig != {'complex': True}: + newsym = Dummy('C', complex=True) + + if newsym is not None: + rv = solveset(f.xreplace({symbol: newsym}), newsym, domain) + # try to use the original symbol if possible + try: + _rv = rv.xreplace({newsym: symbol}) + except TypeError: + _rv = rv + if rv.dummy_eq(_rv): + rv = _rv + return rv + + # Abs has its own handling method which avoids the + # rewriting property that the first piece of abs(x) + # is for x >= 0 and the 2nd piece for x < 0 -- solutions + # can look better if the 2nd condition is x <= 0. Since + # the solution is a set, duplication of results is not + # an issue, e.g. {y, -y} when y is 0 will be {0} + f, mask = _masked(f, Abs) + f = f.rewrite(Piecewise) # everything that's not an Abs + for d, e in mask: + # everything *in* an Abs + e = e.func(e.args[0].rewrite(Piecewise)) + f = f.xreplace({d: e}) + f = piecewise_fold(f) + + return _solveset(f, symbol, domain, _check=True) + + +def solveset_real(f, symbol): + return solveset(f, symbol, S.Reals) + + +def solveset_complex(f, symbol): + return solveset(f, symbol, S.Complexes) + + +def _solveset_multi(eqs, syms, domains): + '''Basic implementation of a multivariate solveset. + + For internal use (not ready for public consumption)''' + + rep = {} + for sym, dom in zip(syms, domains): + if dom is S.Reals: + rep[sym] = Symbol(sym.name, real=True) + eqs = [eq.subs(rep) for eq in eqs] + syms = [sym.subs(rep) for sym in syms] + + syms = tuple(syms) + + if len(eqs) == 0: + return ProductSet(*domains) + + if len(syms) == 1: + sym = syms[0] + domain = domains[0] + solsets = [solveset(eq, sym, domain) for eq in eqs] + solset = Intersection(*solsets) + return ImageSet(Lambda((sym,), (sym,)), solset).doit() + + eqs = sorted(eqs, key=lambda eq: len(eq.free_symbols & set(syms))) + + for n, eq in enumerate(eqs): + sols = [] + all_handled = True + for sym in syms: + if sym not in eq.free_symbols: + continue + sol = solveset(eq, sym, domains[syms.index(sym)]) + + if isinstance(sol, FiniteSet): + i = syms.index(sym) + symsp = syms[:i] + syms[i+1:] + domainsp = domains[:i] + domains[i+1:] + eqsp = eqs[:n] + eqs[n+1:] + for s in sol: + eqsp_sub = [eq.subs(sym, s) for eq in eqsp] + sol_others = _solveset_multi(eqsp_sub, symsp, domainsp) + fun = Lambda((symsp,), symsp[:i] + (s,) + symsp[i:]) + sols.append(ImageSet(fun, sol_others).doit()) + else: + all_handled = False + if all_handled: + return Union(*sols) + + +def solvify(f, symbol, domain): + """Solves an equation using solveset and returns the solution in accordance + with the `solve` output API. + + Returns + ======= + + We classify the output based on the type of solution returned by `solveset`. + + Solution | Output + ---------------------------------------- + FiniteSet | list + + ImageSet, | list (if `f` is periodic) + Union | + + Union | list (with FiniteSet) + + EmptySet | empty list + + Others | None + + + Raises + ====== + + NotImplementedError + A ConditionSet is the input. + + Examples + ======== + + >>> from sympy.solvers.solveset import solvify + >>> from sympy.abc import x + >>> from sympy import S, tan, sin, exp + >>> solvify(x**2 - 9, x, S.Reals) + [-3, 3] + >>> solvify(sin(x) - 1, x, S.Reals) + [pi/2] + >>> solvify(tan(x), x, S.Reals) + [0] + >>> solvify(exp(x) - 1, x, S.Complexes) + + >>> solvify(exp(x) - 1, x, S.Reals) + [0] + + """ + solution_set = solveset(f, symbol, domain) + result = None + if solution_set is S.EmptySet: + result = [] + + elif isinstance(solution_set, ConditionSet): + raise NotImplementedError('solveset is unable to solve this equation.') + + elif isinstance(solution_set, FiniteSet): + result = list(solution_set) + + else: + period = periodicity(f, symbol) + if period is not None: + solutions = S.EmptySet + iter_solutions = () + if isinstance(solution_set, ImageSet): + iter_solutions = (solution_set,) + elif isinstance(solution_set, Union): + if all(isinstance(i, ImageSet) for i in solution_set.args): + iter_solutions = solution_set.args + + for solution in iter_solutions: + solutions += solution.intersect(Interval(0, period, False, True)) + + if isinstance(solutions, FiniteSet): + result = list(solutions) + + else: + solution = solution_set.intersect(domain) + if isinstance(solution, Union): + # concerned about only FiniteSet with Union but not about ImageSet + # if required could be extend + if any(isinstance(i, FiniteSet) for i in solution.args): + result = [sol for soln in solution.args \ + for sol in soln.args if isinstance(soln,FiniteSet)] + else: + return None + + elif isinstance(solution, FiniteSet): + result += solution + + return result + + +############################################################################### +################################ LINSOLVE ##################################### +############################################################################### + + +def linear_coeffs(eq, *syms, dict=False): + """Return a list whose elements are the coefficients of the + corresponding symbols in the sum of terms in ``eq``. + The additive constant is returned as the last element of the + list. + + Raises + ====== + + NonlinearError + The equation contains a nonlinear term + ValueError + duplicate or unordered symbols are passed + + Parameters + ========== + + dict - (default False) when True, return coefficients as a + dictionary with coefficients keyed to syms that were present; + key 1 gives the constant term + + Examples + ======== + + >>> from sympy.solvers.solveset import linear_coeffs + >>> from sympy.abc import x, y, z + >>> linear_coeffs(3*x + 2*y - 1, x, y) + [3, 2, -1] + + It is not necessary to expand the expression: + + >>> linear_coeffs(x + y*(z*(x*3 + 2) + 3), x) + [3*y*z + 1, y*(2*z + 3)] + + When nonlinear is detected, an error will be raised: + + * even if they would cancel after expansion (so the + situation does not pass silently past the caller's + attention) + + >>> eq = 1/x*(x - 1) + 1/x + >>> linear_coeffs(eq.expand(), x) + [0, 1] + >>> linear_coeffs(eq, x) + Traceback (most recent call last): + ... + NonlinearError: + nonlinear in given generators + + * when there are cross terms + + >>> linear_coeffs(x*(y + 1), x, y) + Traceback (most recent call last): + ... + NonlinearError: + symbol-dependent cross-terms encountered + + * when there are terms that contain an expression + dependent on the symbols that is not linear + + >>> linear_coeffs(x**2, x) + Traceback (most recent call last): + ... + NonlinearError: + nonlinear in given generators + """ + eq = _sympify(eq) + if len(syms) == 1 and iterable(syms[0]) and not isinstance(syms[0], Basic): + raise ValueError('expecting unpacked symbols, *syms') + symset = set(syms) + if len(symset) != len(syms): + raise ValueError('duplicate symbols given') + try: + d, c = _linear_eq_to_dict([eq], symset) + d = d[0] + c = c[0] + except PolyNonlinearError as err: + raise NonlinearError(str(err)) + if dict: + if c: + d[S.One] = c + return d + rv = [S.Zero]*(len(syms) + 1) + rv[-1] = c + for i, k in enumerate(syms): + if k not in d: + continue + rv[i] = d[k] + return rv + + +def linear_eq_to_matrix(equations, *symbols): + r""" + Converts a given System of Equations into Matrix form. Here ``equations`` + must be a linear system of equations in ``symbols``. Element ``M[i, j]`` + corresponds to the coefficient of the jth symbol in the ith equation. + + The Matrix form corresponds to the augmented matrix form. For example: + + .. math:: + + 4x + 2y + 3z & = 1 \\ + 3x + y + z & = -6 \\ + 2x + 4y + 9z & = 2 + + This system will return :math:`A` and :math:`b` as: + + .. math:: + + A = \left[\begin{array}{ccc} + 4 & 2 & 3 \\ + 3 & 1 & 1 \\ + 2 & 4 & 9 + \end{array}\right] \\ + + .. math:: + + b = \left[\begin{array}{c} + 1 \\ -6 \\ 2 + \end{array}\right] + + The only simplification performed is to convert + ``Eq(a, b)`` :math:`\Rightarrow a - b`. + + Raises + ====== + + NonlinearError + The equations contain a nonlinear term. + ValueError + The symbols are not given or are not unique. + + Examples + ======== + + >>> from sympy import linear_eq_to_matrix, symbols + >>> c, x, y, z = symbols('c, x, y, z') + + The coefficients (numerical or symbolic) of the symbols will + be returned as matrices: + + >>> eqns = [c*x + z - 1 - c, y + z, x - y] + >>> A, b = linear_eq_to_matrix(eqns, [x, y, z]) + >>> A + Matrix([ + [c, 0, 1], + [0, 1, 1], + [1, -1, 0]]) + >>> b + Matrix([ + [c + 1], + [ 0], + [ 0]]) + + This routine does not simplify expressions and will raise an error + if nonlinearity is encountered: + + >>> eqns = [ + ... (x**2 - 3*x)/(x - 3) - 3, + ... y**2 - 3*y - y*(y - 4) + x - 4] + >>> linear_eq_to_matrix(eqns, [x, y]) + Traceback (most recent call last): + ... + NonlinearError: + symbol-dependent term can be ignored using `strict=False` + + Simplifying these equations will discard the removable singularity in the + first and reveal the linear structure of the second: + + >>> [e.simplify() for e in eqns] + [x - 3, x + y - 4] + + Any such simplification needed to eliminate nonlinear terms must be done + *before* calling this routine. + + """ + if not symbols: + raise ValueError(filldedent(''' + Symbols must be given, for which coefficients + are to be found. + ''')) + + # Check if 'symbols' is a set and raise an error if it is + if isinstance(symbols[0], set): + raise TypeError( + "Unordered 'set' type is not supported as input for symbols.") + + if hasattr(symbols[0], '__iter__'): + symbols = symbols[0] + + if has_dups(symbols): + raise ValueError('Symbols must be unique') + + equations = sympify(equations) + if isinstance(equations, MatrixBase): + equations = list(equations) + elif isinstance(equations, (Expr, Eq)): + equations = [equations] + elif not is_sequence(equations): + raise ValueError(filldedent(''' + Equation(s) must be given as a sequence, Expr, + Eq or Matrix. + ''')) + + # construct the dictionaries + try: + eq, c = _linear_eq_to_dict(equations, symbols) + except PolyNonlinearError as err: + raise NonlinearError(str(err)) + # prepare output matrices + n, m = shape = len(eq), len(symbols) + ix = dict(zip(symbols, range(m))) + A = zeros(*shape) + for row, d in enumerate(eq): + for k in d: + col = ix[k] + A[row, col] = d[k] + b = Matrix(n, 1, [-i for i in c]) + return A, b + + +def linsolve(system, *symbols): + r""" + Solve system of $N$ linear equations with $M$ variables; both + underdetermined and overdetermined systems are supported. + The possible number of solutions is zero, one or infinite. + Zero solutions throws a ValueError, whereas infinite + solutions are represented parametrically in terms of the given + symbols. For unique solution a :class:`~.FiniteSet` of ordered tuples + is returned. + + All standard input formats are supported: + For the given set of equations, the respective input types + are given below: + + .. math:: 3x + 2y - z = 1 + .. math:: 2x - 2y + 4z = -2 + .. math:: 2x - y + 2z = 0 + + * Augmented matrix form, ``system`` given below: + + $$ \text{system} = \left[{array}{cccc} + 3 & 2 & -1 & 1\\ + 2 & -2 & 4 & -2\\ + 2 & -1 & 2 & 0 + \end{array}\right] $$ + + :: + + system = Matrix([[3, 2, -1, 1], [2, -2, 4, -2], [2, -1, 2, 0]]) + + * List of equations form + + :: + + system = [3x + 2y - z - 1, 2x - 2y + 4z + 2, 2x - y + 2z] + + * Input $A$ and $b$ in matrix form (from $Ax = b$) are given as: + + $$ A = \left[\begin{array}{ccc} + 3 & 2 & -1 \\ + 2 & -2 & 4 \\ + 2 & -1 & 2 + \end{array}\right] \ \ b = \left[\begin{array}{c} + 1 \\ -2 \\ 0 + \end{array}\right] $$ + + :: + + A = Matrix([[3, 2, -1], [2, -2, 4], [2, -1, 2]]) + b = Matrix([[1], [-2], [0]]) + system = (A, b) + + Symbols can always be passed but are actually only needed + when 1) a system of equations is being passed and 2) the + system is passed as an underdetermined matrix and one wants + to control the name of the free variables in the result. + An error is raised if no symbols are used for case 1, but if + no symbols are provided for case 2, internally generated symbols + will be provided. When providing symbols for case 2, there should + be at least as many symbols are there are columns in matrix A. + + The algorithm used here is Gauss-Jordan elimination, which + results, after elimination, in a row echelon form matrix. + + Returns + ======= + + A FiniteSet containing an ordered tuple of values for the + unknowns for which the `system` has a solution. (Wrapping + the tuple in FiniteSet is used to maintain a consistent + output format throughout solveset.) + + Returns EmptySet, if the linear system is inconsistent. + + Raises + ====== + + ValueError + The input is not valid. + The symbols are not given. + + Examples + ======== + + >>> from sympy import Matrix, linsolve, symbols + >>> x, y, z = symbols("x, y, z") + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + >>> b = Matrix([3, 6, 9]) + >>> A + Matrix([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 10]]) + >>> b + Matrix([ + [3], + [6], + [9]]) + >>> linsolve((A, b), [x, y, z]) + {(-1, 2, 0)} + + * Parametric Solution: In case the system is underdetermined, the + function will return a parametric solution in terms of the given + symbols. Those that are free will be returned unchanged. e.g. in + the system below, `z` is returned as the solution for variable z; + it can take on any value. + + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> b = Matrix([3, 6, 9]) + >>> linsolve((A, b), x, y, z) + {(z - 1, 2 - 2*z, z)} + + If no symbols are given, internally generated symbols will be used. + The ``tau0`` in the third position indicates (as before) that the third + variable -- whatever it is named -- can take on any value: + + >>> linsolve((A, b)) + {(tau0 - 1, 2 - 2*tau0, tau0)} + + * List of equations as input + + >>> Eqns = [3*x + 2*y - z - 1, 2*x - 2*y + 4*z + 2, - x + y/2 - z] + >>> linsolve(Eqns, x, y, z) + {(1, -2, -2)} + + * Augmented matrix as input + + >>> aug = Matrix([[2, 1, 3, 1], [2, 6, 8, 3], [6, 8, 18, 5]]) + >>> aug + Matrix([ + [2, 1, 3, 1], + [2, 6, 8, 3], + [6, 8, 18, 5]]) + >>> linsolve(aug, x, y, z) + {(3/10, 2/5, 0)} + + * Solve for symbolic coefficients + + >>> a, b, c, d, e, f = symbols('a, b, c, d, e, f') + >>> eqns = [a*x + b*y - c, d*x + e*y - f] + >>> linsolve(eqns, x, y) + {((-b*f + c*e)/(a*e - b*d), (a*f - c*d)/(a*e - b*d))} + + * A degenerate system returns solution as set of given + symbols. + + >>> system = Matrix(([0, 0, 0], [0, 0, 0], [0, 0, 0])) + >>> linsolve(system, x, y) + {(x, y)} + + * For an empty system linsolve returns empty set + + >>> linsolve([], x) + EmptySet + + * An error is raised if any nonlinearity is detected, even + if it could be removed with expansion + + >>> linsolve([x*(1/x - 1)], x) + Traceback (most recent call last): + ... + NonlinearError: nonlinear term: 1/x + + >>> linsolve([x*(y + 1)], x, y) + Traceback (most recent call last): + ... + NonlinearError: nonlinear cross-term: x*(y + 1) + + >>> linsolve([x**2 - 1], x) + Traceback (most recent call last): + ... + NonlinearError: nonlinear term: x**2 + """ + if not system: + return S.EmptySet + + # If second argument is an iterable + if symbols and hasattr(symbols[0], '__iter__'): + symbols = symbols[0] + sym_gen = isinstance(symbols, GeneratorType) + dup_msg = 'duplicate symbols given' + + + b = None # if we don't get b the input was bad + # unpack system + + if hasattr(system, '__iter__'): + + # 1). (A, b) + if len(system) == 2 and isinstance(system[0], MatrixBase): + A, b = system + + # 2). (eq1, eq2, ...) + if not isinstance(system[0], MatrixBase): + if sym_gen or not symbols: + raise ValueError(filldedent(''' + When passing a system of equations, the explicit + symbols for which a solution is being sought must + be given as a sequence, too. + ''')) + if len(set(symbols)) != len(symbols): + raise ValueError(dup_msg) + + # + # Pass to the sparse solver implemented in polys. It is important + # that we do not attempt to convert the equations to a matrix + # because that would be very inefficient for large sparse systems + # of equations. + # + eqs = system + eqs = [sympify(eq) for eq in eqs] + try: + sol = _linsolve(eqs, symbols) + except PolyNonlinearError as exc: + # e.g. cos(x) contains an element of the set of generators + raise NonlinearError(str(exc)) + + if sol is None: + return S.EmptySet + + sol = FiniteSet(Tuple(*(sol.get(sym, sym) for sym in symbols))) + return sol + + elif isinstance(system, MatrixBase) and not ( + symbols and not isinstance(symbols, GeneratorType) and + isinstance(symbols[0], MatrixBase)): + # 3). A augmented with b + A, b = system[:, :-1], system[:, -1:] + + if b is None: + raise ValueError("Invalid arguments") + if sym_gen: + symbols = [next(symbols) for i in range(A.cols)] + symset = set(symbols) + if any(symset & (A.free_symbols | b.free_symbols)): + raise ValueError(filldedent(''' + At least one of the symbols provided + already appears in the system to be solved. + One way to avoid this is to use Dummy symbols in + the generator, e.g. numbered_symbols('%s', cls=Dummy) + ''' % symbols[0].name.rstrip('1234567890'))) + elif len(symset) != len(symbols): + raise ValueError(dup_msg) + + if not symbols: + symbols = [Dummy() for _ in range(A.cols)] + name = _uniquely_named_symbol('tau', (A, b), + compare=lambda i: str(i).rstrip('1234567890')).name + gen = numbered_symbols(name) + else: + gen = None + + # This is just a wrapper for solve_lin_sys + eqs = [] + rows = A.tolist() + for rowi, bi in zip(rows, b): + terms = [elem * sym for elem, sym in zip(rowi, symbols) if elem] + terms.append(-bi) + eqs.append(Add(*terms)) + + eqs, ring = sympy_eqs_to_ring(eqs, symbols) + sol = solve_lin_sys(eqs, ring, _raw=False) + if sol is None: + return S.EmptySet + #sol = {sym:val for sym, val in sol.items() if sym != val} + sol = FiniteSet(Tuple(*(sol.get(sym, sym) for sym in symbols))) + + if gen is not None: + solsym = sol.free_symbols + rep = {sym: next(gen) for sym in symbols if sym in solsym} + sol = sol.subs(rep) + + return sol + + +############################################################################## +# ------------------------------nonlinsolve ---------------------------------# +############################################################################## + + +def _return_conditionset(eqs, symbols): + # return conditionset + eqs = (Eq(lhs, 0) for lhs in eqs) + condition_set = ConditionSet( + Tuple(*symbols), And(*eqs), S.Complexes**len(symbols)) + return condition_set + + +def substitution(system, symbols, result=[{}], known_symbols=[], + exclude=[], all_symbols=None): + r""" + Solves the `system` using substitution method. It is used in + :func:`~.nonlinsolve`. This will be called from :func:`~.nonlinsolve` when any + equation(s) is non polynomial equation. + + Parameters + ========== + + system : list of equations + The target system of equations + symbols : list of symbols to be solved. + The variable(s) for which the system is solved + known_symbols : list of solved symbols + Values are known for these variable(s) + result : An empty list or list of dict + If No symbol values is known then empty list otherwise + symbol as keys and corresponding value in dict. + exclude : Set of expression. + Mostly denominator expression(s) of the equations of the system. + Final solution should not satisfy these expressions. + all_symbols : known_symbols + symbols(unsolved). + + Returns + ======= + + A FiniteSet of ordered tuple of values of `all_symbols` for which the + `system` has solution. Order of values in the tuple is same as symbols + present in the parameter `all_symbols`. If parameter `all_symbols` is None + then same as symbols present in the parameter `symbols`. + + Please note that general FiniteSet is unordered, the solution returned + here is not simply a FiniteSet of solutions, rather it is a FiniteSet of + ordered tuple, i.e. the first & only argument to FiniteSet is a tuple of + solutions, which is ordered, & hence the returned solution is ordered. + + Also note that solution could also have been returned as an ordered tuple, + FiniteSet is just a wrapper `{}` around the tuple. It has no other + significance except for the fact it is just used to maintain a consistent + output format throughout the solveset. + + Raises + ====== + + ValueError + The input is not valid. + The symbols are not given. + AttributeError + The input symbols are not :class:`~.Symbol` type. + + Examples + ======== + + >>> from sympy import symbols, substitution + >>> x, y = symbols('x, y', real=True) + >>> substitution([x + y], [x], [{y: 1}], [y], set([]), [x, y]) + {(-1, 1)} + + * When you want a soln not satisfying $x + 1 = 0$ + + >>> substitution([x + y], [x], [{y: 1}], [y], set([x + 1]), [y, x]) + EmptySet + >>> substitution([x + y], [x], [{y: 1}], [y], set([x - 1]), [y, x]) + {(1, -1)} + >>> substitution([x + y - 1, y - x**2 + 5], [x, y]) + {(-3, 4), (2, -1)} + + * Returns both real and complex solution + + >>> x, y, z = symbols('x, y, z') + >>> from sympy import exp, sin + >>> substitution([exp(x) - sin(y), y**2 - 4], [x, y]) + {(ImageSet(Lambda(_n, I*(2*_n*pi + pi) + log(sin(2))), Integers), -2), + (ImageSet(Lambda(_n, 2*_n*I*pi + log(sin(2))), Integers), 2)} + + >>> eqs = [z**2 + exp(2*x) - sin(y), -3 + exp(-y)] + >>> substitution(eqs, [y, z]) + {(-log(3), -sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), sqrt(-exp(2*x) - sin(log(3)))), + (ImageSet(Lambda(_n, 2*_n*I*pi - log(3)), Integers), + ImageSet(Lambda(_n, -sqrt(-exp(2*x) + sin(2*_n*I*pi - log(3)))), Integers)), + (ImageSet(Lambda(_n, 2*_n*I*pi - log(3)), Integers), + ImageSet(Lambda(_n, sqrt(-exp(2*x) + sin(2*_n*I*pi - log(3)))), Integers))} + + """ + + if not system: + return S.EmptySet + + for i, e in enumerate(system): + if isinstance(e, Eq): + system[i] = e.lhs - e.rhs + + if not symbols: + msg = ('Symbols must be given, for which solution of the ' + 'system is to be found.') + raise ValueError(filldedent(msg)) + + if not is_sequence(symbols): + msg = ('symbols should be given as a sequence, e.g. a list.' + 'Not type %s: %s') + raise TypeError(filldedent(msg % (type(symbols), symbols))) + + if not getattr(symbols[0], 'is_Symbol', False): + msg = ('Iterable of symbols must be given as ' + 'second argument, not type %s: %s') + raise ValueError(filldedent(msg % (type(symbols[0]), symbols[0]))) + + # By default `all_symbols` will be same as `symbols` + if all_symbols is None: + all_symbols = symbols + + old_result = result + # storing complements and intersection for particular symbol + complements = {} + intersections = {} + + # when total_solveset_call equals total_conditionset + # it means that solveset failed to solve all eqs. + total_conditionset = -1 + total_solveset_call = -1 + + def _unsolved_syms(eq, sort=False): + """Returns the unsolved symbol present + in the equation `eq`. + """ + free = eq.free_symbols + unsolved = (free - set(known_symbols)) & set(all_symbols) + if sort: + unsolved = list(unsolved) + unsolved.sort(key=default_sort_key) + return unsolved + + # sort such that equation with the fewest potential symbols is first. + # means eq with less number of variable first in the list. + eqs_in_better_order = list( + ordered(system, lambda _: len(_unsolved_syms(_)))) + + def add_intersection_complement(result, intersection_dict, complement_dict): + # If solveset has returned some intersection/complement + # for any symbol, it will be added in the final solution. + final_result = [] + for res in result: + res_copy = res + for key_res, value_res in res.items(): + intersect_set, complement_set = None, None + for key_sym, value_sym in intersection_dict.items(): + if key_sym == key_res: + intersect_set = value_sym + for key_sym, value_sym in complement_dict.items(): + if key_sym == key_res: + complement_set = value_sym + if intersect_set or complement_set: + new_value = FiniteSet(value_res) + if intersect_set and intersect_set != S.Complexes: + new_value = Intersection(new_value, intersect_set) + if complement_set: + new_value = Complement(new_value, complement_set) + if new_value is S.EmptySet: + res_copy = None + break + elif new_value.is_FiniteSet and len(new_value) == 1: + res_copy[key_res] = set(new_value).pop() + else: + res_copy[key_res] = new_value + + if res_copy is not None: + final_result.append(res_copy) + return final_result + + def _extract_main_soln(sym, sol, soln_imageset): + """Separate the Complements, Intersections, ImageSet lambda expr and + its base_set. This function returns the unmasked sol from different classes + of sets and also returns the appended ImageSet elements in a + soln_imageset dict: `{unmasked element: ImageSet}`. + """ + # if there is union, then need to check + # Complement, Intersection, Imageset. + # Order should not be changed. + if isinstance(sol, ConditionSet): + # extracts any solution in ConditionSet + sol = sol.base_set + + if isinstance(sol, Complement): + # extract solution and complement + complements[sym] = sol.args[1] + sol = sol.args[0] + # complement will be added at the end + # using `add_intersection_complement` method + + # if there is union of Imageset or other in soln. + # no testcase is written for this if block + if isinstance(sol, Union): + sol_args = sol.args + sol = S.EmptySet + # We need in sequence so append finteset elements + # and then imageset or other. + for sol_arg2 in sol_args: + if isinstance(sol_arg2, FiniteSet): + sol += sol_arg2 + else: + # ImageSet, Intersection, complement then + # append them directly + sol += FiniteSet(sol_arg2) + + if isinstance(sol, Intersection): + # Interval/Set will be at 0th index always + if sol.args[0] not in (S.Reals, S.Complexes): + # Sometimes solveset returns soln with intersection + # S.Reals or S.Complexes. We don't consider that + # intersection. + intersections[sym] = sol.args[0] + sol = sol.args[1] + # after intersection and complement Imageset should + # be checked. + if isinstance(sol, ImageSet): + soln_imagest = sol + expr2 = sol.lamda.expr + sol = FiniteSet(expr2) + soln_imageset[expr2] = soln_imagest + + if not isinstance(sol, FiniteSet): + sol = FiniteSet(sol) + return sol, soln_imageset + + def _check_exclude(rnew, imgset_yes): + rnew_ = rnew + if imgset_yes: + # replace all dummy variables (Imageset lambda variables) + # with zero before `checksol`. Considering fundamental soln + # for `checksol`. + rnew_copy = rnew.copy() + dummy_n = imgset_yes[0] + for key_res, value_res in rnew_copy.items(): + rnew_copy[key_res] = value_res.subs(dummy_n, 0) + rnew_ = rnew_copy + # satisfy_exclude == true if it satisfies the expr of `exclude` list. + try: + # something like : `Mod(-log(3), 2*I*pi)` can't be + # simplified right now, so `checksol` returns `TypeError`. + # when this issue is fixed this try block should be + # removed. Mod(-log(3), 2*I*pi) == -log(3) + satisfy_exclude = any( + checksol(d, rnew_) for d in exclude) + except TypeError: + satisfy_exclude = None + return satisfy_exclude + + def _restore_imgset(rnew, original_imageset, newresult): + restore_sym = set(rnew.keys()) & \ + set(original_imageset.keys()) + for key_sym in restore_sym: + img = original_imageset[key_sym] + rnew[key_sym] = img + if rnew not in newresult: + newresult.append(rnew) + + def _append_eq(eq, result, res, delete_soln, n=None): + u = Dummy('u') + if n: + eq = eq.subs(n, 0) + satisfy = eq if eq in (True, False) else checksol(u, u, eq, minimal=True) + if satisfy is False: + delete_soln = True + res = {} + else: + result.append(res) + return result, res, delete_soln + + def _append_new_soln(rnew, sym, sol, imgset_yes, soln_imageset, + original_imageset, newresult, eq=None): + """If `rnew` (A dict ) contains valid soln + append it to `newresult` list. + `imgset_yes` is (base, dummy_var) if there was imageset in previously + calculated result(otherwise empty tuple). `original_imageset` is dict + of imageset expr and imageset from this result. + `soln_imageset` dict of imageset expr and imageset of new soln. + """ + satisfy_exclude = _check_exclude(rnew, imgset_yes) + delete_soln = False + # soln should not satisfy expr present in `exclude` list. + if not satisfy_exclude: + local_n = None + # if it is imageset + if imgset_yes: + local_n = imgset_yes[0] + base = imgset_yes[1] + if sym and sol: + # when `sym` and `sol` is `None` means no new + # soln. In that case we will append rnew directly after + # substituting original imagesets in rnew values if present + # (second last line of this function using _restore_imgset) + dummy_list = list(sol.atoms(Dummy)) + # use one dummy `n` which is in + # previous imageset + local_n_list = [ + local_n for i in range( + 0, len(dummy_list))] + + dummy_zip = zip(dummy_list, local_n_list) + lam = Lambda(local_n, sol.subs(dummy_zip)) + rnew[sym] = ImageSet(lam, base) + if eq is not None: + newresult, rnew, delete_soln = _append_eq( + eq, newresult, rnew, delete_soln, local_n) + elif eq is not None: + newresult, rnew, delete_soln = _append_eq( + eq, newresult, rnew, delete_soln) + elif sol in soln_imageset.keys(): + rnew[sym] = soln_imageset[sol] + # restore original imageset + _restore_imgset(rnew, original_imageset, newresult) + else: + newresult.append(rnew) + elif satisfy_exclude: + delete_soln = True + rnew = {} + _restore_imgset(rnew, original_imageset, newresult) + return newresult, delete_soln + + def _new_order_result(result, eq): + # separate first, second priority. `res` that makes `eq` value equals + # to zero, should be used first then other result(second priority). + # If it is not done then we may miss some soln. + first_priority = [] + second_priority = [] + for res in result: + if not any(isinstance(val, ImageSet) for val in res.values()): + if eq.subs(res) == 0: + first_priority.append(res) + else: + second_priority.append(res) + if first_priority or second_priority: + return first_priority + second_priority + return result + + def _solve_using_known_values(result, solver): + """Solves the system using already known solution + (result contains the dict ). + solver is :func:`~.solveset_complex` or :func:`~.solveset_real`. + """ + # stores imageset . + soln_imageset = {} + total_solvest_call = 0 + total_conditionst = 0 + + # sort equations so the one with the fewest potential + # symbols appears first + for index, eq in enumerate(eqs_in_better_order): + newresult = [] + # if imageset, expr is used to solve for other symbol + imgset_yes = False + for res in result: + original_imageset = {} + got_symbol = set() # symbols solved in one iteration + # find the imageset and use its expr. + for k, v in res.items(): + if isinstance(v, ImageSet): + res[k] = v.lamda.expr + original_imageset[k] = v + dummy_n = v.lamda.expr.atoms(Dummy).pop() + (base,) = v.base_sets + imgset_yes = (dummy_n, base) + assert not isinstance(v, FiniteSet) # if so, internal error + # update eq with everything that is known so far + eq2 = eq.subs(res).expand() + if imgset_yes and not eq2.has(imgset_yes[0]): + # The substituted equation simplified in such a way that + # it's no longer necessary to encapsulate a potential new + # solution in an ImageSet. (E.g. at the previous step some + # {n*2*pi} was found as partial solution for one of the + # unknowns, but its main solution expression n*2*pi has now + # been substituted in a trigonometric function.) + imgset_yes = False + + unsolved_syms = _unsolved_syms(eq2, sort=True) + if not unsolved_syms: + if res: + newresult, delete_res = _append_new_soln( + res, None, None, imgset_yes, soln_imageset, + original_imageset, newresult, eq2) + if delete_res: + # `delete_res` is true, means substituting `res` in + # eq2 doesn't return `zero` or deleting the `res` + # (a soln) since it satisfies expr of `exclude` + # list. + result.remove(res) + continue # skip as it's independent of desired symbols + depen1, depen2 = eq2.as_independent(*unsolved_syms) + if (depen1.has(Abs) or depen2.has(Abs)) and solver == solveset_complex: + # Absolute values cannot be inverted in the + # complex domain + continue + soln_imageset = {} + for sym in unsolved_syms: + not_solvable = False + try: + soln = solver(eq2, sym) + total_solvest_call += 1 + soln_new = S.EmptySet + if isinstance(soln, Complement): + # separate solution and complement + complements[sym] = soln.args[1] + soln = soln.args[0] + # complement will be added at the end + if isinstance(soln, Intersection): + # Interval will be at 0th index always + if soln.args[0] != Interval(-oo, oo): + # sometimes solveset returns soln + # with intersection S.Reals, to confirm that + # soln is in domain=S.Reals + intersections[sym] = soln.args[0] + soln_new += soln.args[1] + soln = soln_new if soln_new else soln + if index > 0 and solver == solveset_real: + # one symbol's real soln, another symbol may have + # corresponding complex soln. + if not isinstance(soln, (ImageSet, ConditionSet)): + soln += solveset_complex(eq2, sym) # might give ValueError with Abs + except (NotImplementedError, ValueError): + # If solveset is not able to solve equation `eq2`. Next + # time we may get soln using next equation `eq2` + continue + if isinstance(soln, ConditionSet): + if soln.base_set in (S.Reals, S.Complexes): + soln = S.EmptySet + # don't do `continue` we may get soln + # in terms of other symbol(s) + not_solvable = True + total_conditionst += 1 + else: + soln = soln.base_set + + if soln is not S.EmptySet: + soln, soln_imageset = _extract_main_soln( + sym, soln, soln_imageset) + + for sol in soln: + # sol is not a `Union` since we checked it + # before this loop + sol, soln_imageset = _extract_main_soln( + sym, sol, soln_imageset) + sol = set(sol).pop() # XXX what if there are more solutions? + free = sol.free_symbols + if got_symbol and any( + ss in free for ss in got_symbol + ): + # sol depends on previously solved symbols + # then continue + continue + rnew = res.copy() + # put each solution in res and append the new result + # in the new result list (solution for symbol `s`) + # along with old results. + for k, v in res.items(): + if isinstance(v, Expr) and isinstance(sol, Expr): + # if any unsolved symbol is present + # Then subs known value + rnew[k] = v.subs(sym, sol) + # and add this new solution + if sol in soln_imageset.keys(): + # replace all lambda variables with 0. + imgst = soln_imageset[sol] + rnew[sym] = imgst.lamda( + *[0 for i in range(0, len( + imgst.lamda.variables))]) + else: + rnew[sym] = sol + newresult, delete_res = _append_new_soln( + rnew, sym, sol, imgset_yes, soln_imageset, + original_imageset, newresult) + if delete_res: + # deleting the `res` (a soln) since it satisfies + # eq of `exclude` list + result.remove(res) + # solution got for sym + if not not_solvable: + got_symbol.add(sym) + # next time use this new soln + if newresult: + result = newresult + return result, total_solvest_call, total_conditionst + + new_result_real, solve_call1, cnd_call1 = _solve_using_known_values( + old_result, solveset_real) + new_result_complex, solve_call2, cnd_call2 = _solve_using_known_values( + old_result, solveset_complex) + + # If total_solveset_call is equal to total_conditionset + # then solveset failed to solve all of the equations. + # In this case we return a ConditionSet here. + total_conditionset += (cnd_call1 + cnd_call2) + total_solveset_call += (solve_call1 + solve_call2) + + if total_conditionset == total_solveset_call and total_solveset_call != -1: + return _return_conditionset(eqs_in_better_order, all_symbols) + + # don't keep duplicate solutions + filtered_complex = [] + for i in list(new_result_complex): + for j in list(new_result_real): + if i.keys() != j.keys(): + continue + if all(a.dummy_eq(b) for a, b in zip(i.values(), j.values()) \ + if not (isinstance(a, int) and isinstance(b, int))): + break + else: + filtered_complex.append(i) + # overall result + result = new_result_real + filtered_complex + + result_all_variables = [] + result_infinite = [] + for res in result: + if not res: + # means {None : None} + continue + # If length < len(all_symbols) means infinite soln. + # Some or all the soln is dependent on 1 symbol. + # eg. {x: y+2} then final soln {x: y+2, y: y} + if len(res) < len(all_symbols): + solved_symbols = res.keys() + unsolved = list(filter( + lambda x: x not in solved_symbols, all_symbols)) + for unsolved_sym in unsolved: + res[unsolved_sym] = unsolved_sym + result_infinite.append(res) + if res not in result_all_variables: + result_all_variables.append(res) + + if result_infinite: + # we have general soln + # eg : [{x: -1, y : 1}, {x : -y, y: y}] then + # return [{x : -y, y : y}] + result_all_variables = result_infinite + if intersections or complements: + result_all_variables = add_intersection_complement( + result_all_variables, intersections, complements) + + # convert to ordered tuple + result = S.EmptySet + for r in result_all_variables: + temp = [r[symb] for symb in all_symbols] + result += FiniteSet(tuple(temp)) + return result + + +def _solveset_work(system, symbols): + soln = solveset(system[0], symbols[0]) + if isinstance(soln, FiniteSet): + _soln = FiniteSet(*[(s,) for s in soln]) + return _soln + else: + return FiniteSet(tuple(FiniteSet(soln))) + + +def _handle_positive_dimensional(polys, symbols, denominators): + from sympy.polys.polytools import groebner + # substitution method where new system is groebner basis of the system + _symbols = list(symbols) + _symbols.sort(key=default_sort_key) + basis = groebner(polys, _symbols, polys=True) + new_system = [] + for poly_eq in basis: + new_system.append(poly_eq.as_expr()) + result = [{}] + result = substitution( + new_system, symbols, result, [], + denominators) + return result + + +def _handle_zero_dimensional(polys, symbols, system): + # solve 0 dimensional poly system using `solve_poly_system` + result = solve_poly_system(polys, *symbols) + # May be some extra soln is added because + # we used `unrad` in `_separate_poly_nonpoly`, so + # need to check and remove if it is not a soln. + result_update = S.EmptySet + for res in result: + dict_sym_value = dict(list(zip(symbols, res))) + if all(checksol(eq, dict_sym_value) for eq in system): + result_update += FiniteSet(res) + return result_update + + +def _separate_poly_nonpoly(system, symbols): + polys = [] + polys_expr = [] + nonpolys = [] + # unrad_changed stores a list of expressions containing + # radicals that were processed using unrad + # this is useful if solutions need to be checked later. + unrad_changed = [] + denominators = set() + poly = None + for eq in system: + # Store denom expressions that contain symbols + denominators.update(_simple_dens(eq, symbols)) + # Convert equality to expression + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + # try to remove sqrt and rational power + without_radicals = unrad(simplify(eq), *symbols) + if without_radicals: + unrad_changed.append(eq) + eq_unrad, cov = without_radicals + if not cov: + eq = eq_unrad + if isinstance(eq, Expr): + eq = eq.as_numer_denom()[0] + poly = eq.as_poly(*symbols, extension=True) + elif simplify(eq).is_number: + continue + if poly is not None: + polys.append(poly) + polys_expr.append(poly.as_expr()) + else: + nonpolys.append(eq) + return polys, polys_expr, nonpolys, denominators, unrad_changed + + +def _handle_poly(polys, symbols): + # _handle_poly(polys, symbols) -> (poly_sol, poly_eqs) + # + # We will return possible solution information to nonlinsolve as well as a + # new system of polynomial equations to be solved if we cannot solve + # everything directly here. The new system of polynomial equations will be + # a lex-order Groebner basis for the original system. The lex basis + # hopefully separate some of the variables and equations and give something + # easier for substitution to work with. + + # The format for representing solution sets in nonlinsolve and substitution + # is a list of dicts. These are the special cases: + no_information = [{}] # No equations solved yet + no_solutions = [] # The system is inconsistent and has no solutions. + + # If there is no need to attempt further solution of these equations then + # we return no equations: + no_equations = [] + + inexact = any(not p.domain.is_Exact for p in polys) + if inexact: + # The use of Groebner over RR is likely to result incorrectly in an + # inconsistent Groebner basis. So, convert any float coefficients to + # Rational before computing the Groebner basis. + polys = [poly(nsimplify(p, rational=True)) for p in polys] + + # Compute a Groebner basis in grevlex order wrt the ordering given. We will + # try to convert this to lex order later. Usually it seems to be more + # efficient to compute a lex order basis by computing a grevlex basis and + # converting to lex with fglm. + basis = groebner(polys, symbols, order='grevlex', polys=False) + + # + # No solutions (inconsistent equations)? + # + if 1 in basis: + + # No solutions: + poly_sol = no_solutions + poly_eqs = no_equations + + # + # Finite number of solutions (zero-dimensional case) + # + elif basis.is_zero_dimensional: + + # Convert Groebner basis to lex ordering + basis = basis.fglm('lex') + + # Convert polynomial coefficients back to float before calling + # solve_poly_system + if inexact: + basis = [nfloat(p) for p in basis] + + # Solve the zero-dimensional case using solve_poly_system if possible. + # If some polynomials have factors that cannot be solved in radicals + # then this will fail. Using solve_poly_system(..., strict=True) + # ensures that we either get a complete solution set in radicals or + # UnsolvableFactorError will be raised. + try: + result = solve_poly_system(basis, *symbols, strict=True) + except UnsolvableFactorError: + # Failure... not fully solvable in radicals. Return the lex-order + # basis for substitution to handle. + poly_sol = no_information + poly_eqs = list(basis) + else: + # Success! We have a finite solution set and solve_poly_system has + # succeeded in finding all solutions. Return the solutions and also + # an empty list of remaining equations to be solved. + poly_sol = [dict(zip(symbols, res)) for res in result] + poly_eqs = no_equations + + # + # Infinite families of solutions (positive-dimensional case) + # + else: + # In this case the grevlex basis cannot be converted to lex using the + # fglm method and also solve_poly_system cannot solve the equations. We + # would like to return a lex basis but since we can't use fglm we + # compute the lex basis directly here. The time required to recompute + # the basis is generally significantly less than the time required by + # substitution to solve the new system. + poly_sol = no_information + poly_eqs = list(groebner(polys, symbols, order='lex', polys=False)) + + if inexact: + poly_eqs = [nfloat(p) for p in poly_eqs] + + return poly_sol, poly_eqs + + +def nonlinsolve(system, *symbols): + r""" + Solve system of $N$ nonlinear equations with $M$ variables, which means both + under and overdetermined systems are supported. Positive dimensional + system is also supported (A system with infinitely many solutions is said + to be positive-dimensional). In a positive dimensional system the solution will + be dependent on at least one symbol. Returns both real solution + and complex solution (if they exist). + + Parameters + ========== + + system : list of equations + The target system of equations + symbols : list of Symbols + symbols should be given as a sequence eg. list + + Returns + ======= + + A :class:`~.FiniteSet` of ordered tuple of values of `symbols` for which the `system` + has solution. Order of values in the tuple is same as symbols present in + the parameter `symbols`. + + Please note that general :class:`~.FiniteSet` is unordered, the solution + returned here is not simply a :class:`~.FiniteSet` of solutions, rather it + is a :class:`~.FiniteSet` of ordered tuple, i.e. the first and only + argument to :class:`~.FiniteSet` is a tuple of solutions, which is + ordered, and, hence ,the returned solution is ordered. + + Also note that solution could also have been returned as an ordered tuple, + FiniteSet is just a wrapper ``{}`` around the tuple. It has no other + significance except for the fact it is just used to maintain a consistent + output format throughout the solveset. + + For the given set of equations, the respective input types + are given below: + + .. math:: xy - 1 = 0 + .. math:: 4x^2 + y^2 - 5 = 0 + + :: + + system = [x*y - 1, 4*x**2 + y**2 - 5] + symbols = [x, y] + + Raises + ====== + + ValueError + The input is not valid. + The symbols are not given. + AttributeError + The input symbols are not `Symbol` type. + + Examples + ======== + + >>> from sympy import symbols, nonlinsolve + >>> x, y, z = symbols('x, y, z', real=True) + >>> nonlinsolve([x*y - 1, 4*x**2 + y**2 - 5], [x, y]) + {(-1, -1), (-1/2, -2), (1/2, 2), (1, 1)} + + 1. Positive dimensional system and complements: + + >>> from sympy import pprint + >>> from sympy.polys.polytools import is_zero_dimensional + >>> a, b, c, d = symbols('a, b, c, d', extended_real=True) + >>> eq1 = a + b + c + d + >>> eq2 = a*b + b*c + c*d + d*a + >>> eq3 = a*b*c + b*c*d + c*d*a + d*a*b + >>> eq4 = a*b*c*d - 1 + >>> system = [eq1, eq2, eq3, eq4] + >>> is_zero_dimensional(system) + False + >>> pprint(nonlinsolve(system, [a, b, c, d]), use_unicode=False) + -1 1 1 -1 + {(---, -d, -, {d} \ {0}), (-, -d, ---, {d} \ {0})} + d d d d + >>> nonlinsolve([(x+y)**2 - 4, x + y - 2], [x, y]) + {(2 - y, y)} + + 2. If some of the equations are non-polynomial then `nonlinsolve` + will call the ``substitution`` function and return real and complex solutions, + if present. + + >>> from sympy import exp, sin + >>> nonlinsolve([exp(x) - sin(y), y**2 - 4], [x, y]) + {(ImageSet(Lambda(_n, I*(2*_n*pi + pi) + log(sin(2))), Integers), -2), + (ImageSet(Lambda(_n, 2*_n*I*pi + log(sin(2))), Integers), 2)} + + 3. If system is non-linear polynomial and zero-dimensional then it + returns both solution (real and complex solutions, if present) using + :func:`~.solve_poly_system`: + + >>> from sympy import sqrt + >>> nonlinsolve([x**2 - 2*y**2 -2, x*y - 2], [x, y]) + {(-2, -1), (2, 1), (-sqrt(2)*I, sqrt(2)*I), (sqrt(2)*I, -sqrt(2)*I)} + + 4. ``nonlinsolve`` can solve some linear (zero or positive dimensional) + system (because it uses the :func:`sympy.polys.polytools.groebner` function to get the + groebner basis and then uses the ``substitution`` function basis as the + new `system`). But it is not recommended to solve linear system using + ``nonlinsolve``, because :func:`~.linsolve` is better for general linear systems. + + >>> nonlinsolve([x + 2*y -z - 3, x - y - 4*z + 9, y + z - 4], [x, y, z]) + {(3*z - 5, 4 - z, z)} + + 5. System having polynomial equations and only real solution is + solved using :func:`~.solve_poly_system`: + + >>> e1 = sqrt(x**2 + y**2) - 10 + >>> e2 = sqrt(y**2 + (-x + 10)**2) - 3 + >>> nonlinsolve((e1, e2), (x, y)) + {(191/20, -3*sqrt(391)/20), (191/20, 3*sqrt(391)/20)} + >>> nonlinsolve([x**2 + 2/y - 2, x + y - 3], [x, y]) + {(1, 2), (1 - sqrt(5), 2 + sqrt(5)), (1 + sqrt(5), 2 - sqrt(5))} + >>> nonlinsolve([x**2 + 2/y - 2, x + y - 3], [y, x]) + {(2, 1), (2 - sqrt(5), 1 + sqrt(5)), (2 + sqrt(5), 1 - sqrt(5))} + + 6. It is better to use symbols instead of trigonometric functions or + :class:`~.Function`. For example, replace $\sin(x)$ with a symbol, replace + $f(x)$ with a symbol and so on. Get a solution from ``nonlinsolve`` and then + use :func:`~.solveset` to get the value of $x$. + + How nonlinsolve is better than old solver ``_solve_system`` : + ============================================================= + + 1. A positive dimensional system solver: nonlinsolve can return + solution for positive dimensional system. It finds the + Groebner Basis of the positive dimensional system(calling it as + basis) then we can start solving equation(having least number of + variable first in the basis) using solveset and substituting that + solved solutions into other equation(of basis) to get solution in + terms of minimum variables. Here the important thing is how we + are substituting the known values and in which equations. + + 2. Real and complex solutions: nonlinsolve returns both real + and complex solution. If all the equations in the system are polynomial + then using :func:`~.solve_poly_system` both real and complex solution is returned. + If all the equations in the system are not polynomial equation then goes to + ``substitution`` method with this polynomial and non polynomial equation(s), + to solve for unsolved variables. Here to solve for particular variable + solveset_real and solveset_complex is used. For both real and complex + solution ``_solve_using_known_values`` is used inside ``substitution`` + (``substitution`` will be called when any non-polynomial equation is present). + If a solution is valid its general solution is added to the final result. + + 3. :class:`~.Complement` and :class:`~.Intersection` will be added: + nonlinsolve maintains dict for complements and intersections. If solveset + find complements or/and intersections with any interval or set during the + execution of ``substitution`` function, then complement or/and + intersection for that variable is added before returning final solution. + + """ + if not system: + return S.EmptySet + + if not symbols: + msg = ('Symbols must be given, for which solution of the ' + 'system is to be found.') + raise ValueError(filldedent(msg)) + + if hasattr(symbols[0], '__iter__'): + symbols = symbols[0] + + if not is_sequence(symbols) or not symbols: + msg = ('Symbols must be given, for which solution of the ' + 'system is to be found.') + raise IndexError(filldedent(msg)) + + symbols = list(map(_sympify, symbols)) + system, symbols, swap = recast_to_symbols(system, symbols) + if swap: + soln = nonlinsolve(system, symbols) + return FiniteSet(*[tuple(i.xreplace(swap) for i in s) for s in soln]) + + if len(system) == 1 and len(symbols) == 1: + return _solveset_work(system, symbols) + + # main code of def nonlinsolve() starts from here + + polys, polys_expr, nonpolys, denominators, unrad_changed = \ + _separate_poly_nonpoly(system, symbols) + + poly_eqs = [] + poly_sol = [{}] + + if polys: + poly_sol, poly_eqs = _handle_poly(polys, symbols) + if poly_sol and poly_sol[0]: + poly_syms = set().union(*(eq.free_symbols for eq in polys)) + unrad_syms = set().union(*(eq.free_symbols for eq in unrad_changed)) + if unrad_syms == poly_syms and unrad_changed: + # if all the symbols have been solved by _handle_poly + # and unrad has been used then check solutions + poly_sol = [sol for sol in poly_sol if checksol(unrad_changed, sol)] + + # Collect together the unsolved polynomials with the non-polynomial + # equations. + remaining = poly_eqs + nonpolys + + # to_tuple converts a solution dictionary to a tuple containing the + # value for each symbol + to_tuple = lambda sol: tuple(sol[s] for s in symbols) + + if not remaining: + # If there is nothing left to solve then return the solution from + # solve_poly_system directly. + return FiniteSet(*map(to_tuple, poly_sol)) + else: + # Here we handle: + # + # 1. The Groebner basis if solve_poly_system failed. + # 2. The Groebner basis in the positive-dimensional case. + # 3. Any non-polynomial equations + # + # If solve_poly_system did succeed then we pass those solutions in as + # preliminary results. + subs_res = substitution(remaining, symbols, result=poly_sol, exclude=denominators) + + if not isinstance(subs_res, FiniteSet): + return subs_res + + # check solutions produced by substitution. Currently, checking is done for + # only those solutions which have non-Set variable values. + if unrad_changed: + result = [dict(zip(symbols, sol)) for sol in subs_res.args] + correct_sols = [sol for sol in result if any(isinstance(v, Set) for v in sol) + or checksol(unrad_changed, sol) != False] + return FiniteSet(*map(to_tuple, correct_sols)) + else: + return subs_res diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__init__.py b/phivenv/Lib/site-packages/sympy/solvers/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd869ff35a1ce3578b1cb9b14cfd425020e48a68 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_constantsimp.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_constantsimp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bfb66c4063ed889ee4d6df10654517c999f6e8c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_constantsimp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_decompogen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_decompogen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3068a9334a92fc3b6f485c7649cbae4f59c42f88 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_decompogen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_inequalities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_inequalities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6b705c526586f00df58f54551554840d7c6f7f0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_inequalities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_numeric.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_numeric.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5c50d77267f5186d01664f1d7957fb86373351c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_numeric.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_pde.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_pde.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a3649b4bafec9fc61265a5aab2b941b48db8978 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_pde.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_polysys.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_polysys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4614e7679a77f11e4f3be8760952d80e37b10d70 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_polysys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_recurr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_recurr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7861c9a1243049e475f6653a047c8a2ea8a011e4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_recurr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_simplex.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_simplex.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f4159ab3d0f9931460e7cc6473a367efd75e41 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/solvers/tests/__pycache__/test_simplex.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_constantsimp.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_constantsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..efb966a4c8c2f93558d05e7c330f06530e69180c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_constantsimp.py @@ -0,0 +1,179 @@ +""" +If the arbitrary constant class from issue 4435 is ever implemented, this +should serve as a set of test cases. +""" + +from sympy.core.function import Function +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.integrals.integrals import Integral +from sympy.solvers.ode.ode import constantsimp, constant_renumber +from sympy.testing.pytest import XFAIL + + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +u2 = Symbol('u2') +_a = Symbol('_a') +C1 = Symbol('C1') +C2 = Symbol('C2') +C3 = Symbol('C3') +f = Function('f') + + +def test_constant_mul(): + # We want C1 (Constant) below to absorb the y's, but not the x's + assert constant_renumber(constantsimp(y*C1, [C1])) == C1*y + assert constant_renumber(constantsimp(C1*y, [C1])) == C1*y + assert constant_renumber(constantsimp(x*C1, [C1])) == x*C1 + assert constant_renumber(constantsimp(C1*x, [C1])) == x*C1 + assert constant_renumber(constantsimp(2*C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1*2, [C1])) == C1 + assert constant_renumber(constantsimp(y*C1*x, [C1, y])) == C1*x + assert constant_renumber(constantsimp(x*y*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(y*x*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*x*y, [C1, y])) == C1*x + assert constant_renumber(constantsimp(x*C1*y, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*y*(y + 1), [C1])) == C1*y*(y+1) + assert constant_renumber(constantsimp(y*C1*(y + 1), [C1])) == C1*y*(y+1) + assert constant_renumber(constantsimp(x*(y*C1), [C1])) == x*y*C1 + assert constant_renumber(constantsimp(x*(C1*y), [C1])) == x*y*C1 + assert constant_renumber(constantsimp(C1*(x*y), [C1, y])) == C1*x + assert constant_renumber(constantsimp((x*y)*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp((y*x)*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y])) == C1 + assert constant_renumber(constantsimp((C1*x)*y, [C1, y])) == C1*x + assert constant_renumber(constantsimp(y*(x*C1), [C1, y])) == x*C1 + assert constant_renumber(constantsimp((x*C1)*y, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*x*y*x*y*2, [C1, y])) == C1*x**2 + assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z])) == C1*x + assert constant_renumber(constantsimp(C1*x*y**2*sin(z), [C1, y, z])) == C1*x + assert constant_renumber(constantsimp(C1*C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1*x*2**x, [C1])) == C1*x*2**x + +def test_constant_add(): + assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + 2, [C1])) == C1 + assert constant_renumber(constantsimp(2 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + y, [C1, y])) == C1 + assert constant_renumber(constantsimp(C1 + x, [C1])) == C1 + x + assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2 + C1, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2])) == C1 + + +def test_constant_power_as_base(): + assert constant_renumber(constantsimp(C1**C1, [C1])) == C1 + assert constant_renumber(constantsimp(Pow(C1, C1), [C1])) == C1 + assert constant_renumber(constantsimp(C1**C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1**C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2**C1, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2**C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1**y, [C1, y])) == C1 + assert constant_renumber(constantsimp(C1**x, [C1])) == C1**x + assert constant_renumber(constantsimp(C1**2, [C1])) == C1 + assert constant_renumber( + constantsimp(C1**(x*y), [C1])) == C1**(x*y) + + +def test_constant_power_as_exp(): + assert constant_renumber(constantsimp(x**C1, [C1])) == x**C1 + assert constant_renumber(constantsimp(y**C1, [C1, y])) == C1 + assert constant_renumber(constantsimp(x**y**C1, [C1, y])) == x**C1 + assert constant_renumber( + constantsimp((x**y)**C1, [C1])) == (x**y)**C1 + assert constant_renumber( + constantsimp(x**(y**C1), [C1, y])) == x**C1 + assert constant_renumber(constantsimp(x**C1**y, [C1, y])) == x**C1 + assert constant_renumber( + constantsimp(x**(C1**y), [C1, y])) == x**C1 + assert constant_renumber( + constantsimp((x**C1)**y, [C1])) == (x**C1)**y + assert constant_renumber(constantsimp(2**C1, [C1])) == C1 + assert constant_renumber(constantsimp(S(2)**C1, [C1])) == C1 + assert constant_renumber(constantsimp(exp(C1), [C1])) == C1 + assert constant_renumber( + constantsimp(exp(C1 + x), [C1])) == C1*exp(x) + assert constant_renumber(constantsimp(Pow(2, C1), [C1])) == C1 + + +def test_constant_function(): + assert constant_renumber(constantsimp(sin(C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1, C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1, C2), [C1, C2])) == C1 + assert constant_renumber(constantsimp(f(C2, C1), [C1, C2])) == C1 + assert constant_renumber(constantsimp(f(C2, C2), [C1, C2])) == C1 + assert constant_renumber( + constantsimp(f(C1, x), [C1])) == f(C1, x) + assert constant_renumber(constantsimp(f(C1, y), [C1, y])) == C1 + assert constant_renumber(constantsimp(f(y, C1), [C1, y])) == C1 + assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y])) == C1 + + +def test_constant_function_multiple(): + # The rules to not renumber in this case would be too complicated, and + # dsolve is not likely to ever encounter anything remotely like this. + assert constant_renumber( + constantsimp(f(C1, C1, x), [C1])) == f(C1, C1, x) + + +def test_constant_multiple(): + assert constant_renumber(constantsimp(C1*2 + 2, [C1])) == C1 + assert constant_renumber(constantsimp(x*2/C1, [C1])) == C1*x + assert constant_renumber(constantsimp(C1**2*2 + 2, [C1])) == C1 + assert constant_renumber( + constantsimp(sin(2*C1) + x + sqrt(2), [C1])) == C1 + x + assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2])) == C1 + +def test_constant_repeated(): + assert C1 + C1*x == constant_renumber( C1 + C1*x) + +def test_ode_solutions(): + # only a few examples here, the rest will be tested in the actual dsolve tests + assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3])) == \ + constant_renumber(C1*exp(x) + C2*exp(2*x)) + assert constant_renumber( + constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2]) + ) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3))) + assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1])) == \ + Eq(f(x), acos(C1/cos(x))) + assert constant_renumber( + constantsimp(Eq(log(f(x)/C1) + 2*exp(x/f(x)), 0), [C1]) + ) == Eq(log(C1*f(x)) + 2*exp(x/f(x)), 0) + assert constant_renumber(constantsimp(Eq(log(x*sqrt(2)*sqrt(1/x)*sqrt(f(x)) + /C1) + x**2/(2*f(x)**2), 0), [C1])) == \ + Eq(log(C1*sqrt(x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0) + assert constant_renumber(constantsimp(Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(x/C1) - + cos(f(x)/x)*exp(-f(x)/x)/2, 0), [C1])) == \ + Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(C1*x) - cos(f(x)/x)* + exp(-f(x)/x)/2, 0) + assert constant_renumber(constantsimp(Eq(-Integral(-1/(sqrt(1 - u2**2)*u2), + (u2, _a, x/f(x))) + log(f(x)/C1), 0), [C1])) == \ + Eq(-Integral(-1/(u2*sqrt(1 - u2**2)), (u2, _a, x/f(x))) + + log(C1*f(x)), 0) + assert [constantsimp(i, [C1]) for i in [Eq(f(x), sqrt(-C1*x + x**2)), Eq(f(x), -sqrt(-C1*x + x**2))]] == \ + [Eq(f(x), sqrt(x*(C1 + x))), Eq(f(x), -sqrt(x*(C1 + x)))] + + +@XFAIL +def test_nonlocal_simplification(): + assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x + + +def test_constant_Eq(): + # C1 on the rhs is well-tested, but the lhs is only tested here + assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1) + assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_decompogen.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_decompogen.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba03f4b42558231b626b6ed169f8b0a81a72bf9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_decompogen.py @@ -0,0 +1,59 @@ +from sympy.solvers.decompogen import decompogen, compogen +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt, Max +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import XFAIL, raises + +x, y = symbols('x y') + + +def test_decompogen(): + assert decompogen(sin(cos(x)), x) == [sin(x), cos(x)] + assert decompogen(sin(x)**2 + sin(x) + 1, x) == [x**2 + x + 1, sin(x)] + assert decompogen(sqrt(6*x**2 - 5), x) == [sqrt(x), 6*x**2 - 5] + assert decompogen(sin(sqrt(cos(x**2 + 1))), x) == [sin(x), sqrt(x), cos(x), x**2 + 1] + assert decompogen(Abs(cos(x)**2 + 3*cos(x) - 4), x) == [Abs(x), x**2 + 3*x - 4, cos(x)] + assert decompogen(sin(x)**2 + sin(x) - sqrt(3)/2, x) == [x**2 + x - sqrt(3)/2, sin(x)] + assert decompogen(Abs(cos(y)**2 + 3*cos(x) - 4), x) == [Abs(x), 3*x + cos(y)**2 - 4, cos(x)] + assert decompogen(x, y) == [x] + assert decompogen(1, x) == [1] + assert decompogen(Max(3, x), x) == [Max(3, x)] + raises(TypeError, lambda: decompogen(x < 5, x)) + u = 2*x + 3 + assert decompogen(Max(sqrt(u),(u)**2), x) == [Max(sqrt(x), x**2), u] + assert decompogen(Max(u, u**2, y), x) == [Max(x, x**2, y), u] + assert decompogen(Max(sin(x), u), x) == [Max(2*x + 3, sin(x))] + + +def test_decompogen_poly(): + assert decompogen(x**4 + 2*x**2 + 1, x) == [x**2 + 2*x + 1, x**2] + assert decompogen(x**4 + 2*x**3 - x - 1, x) == [x**2 - x - 1, x**2 + x] + + +@XFAIL +def test_decompogen_fails(): + A = lambda x: x**2 + 2*x + 3 + B = lambda x: 4*x**2 + 5*x + 6 + assert decompogen(A(x*exp(x)), x) == [x**2 + 2*x + 3, x*exp(x)] + assert decompogen(A(B(x)), x) == [x**2 + 2*x + 3, 4*x**2 + 5*x + 6] + assert decompogen(A(1/x + 1/x**2), x) == [x**2 + 2*x + 3, 1/x + 1/x**2] + assert decompogen(A(1/x + 2/(x + 1)), x) == [x**2 + 2*x + 3, 1/x + 2/(x + 1)] + + +def test_compogen(): + assert compogen([sin(x), cos(x)], x) == sin(cos(x)) + assert compogen([x**2 + x + 1, sin(x)], x) == sin(x)**2 + sin(x) + 1 + assert compogen([sqrt(x), 6*x**2 - 5], x) == sqrt(6*x**2 - 5) + assert compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) == sin(sqrt( + cos(x**2 + 1))) + assert compogen([Abs(x), x**2 + 3*x - 4, cos(x)], x) == Abs(cos(x)**2 + + 3*cos(x) - 4) + assert compogen([x**2 + x - sqrt(3)/2, sin(x)], x) == (sin(x)**2 + sin(x) - + sqrt(3)/2) + assert compogen([Abs(x), 3*x + cos(y)**2 - 4, cos(x)], x) == \ + Abs(3*cos(x) + cos(y)**2 - 4) + assert compogen([x**2 + 2*x + 1, x**2], x) == x**4 + 2*x**2 + 1 + # the result is in unsimplified form + assert compogen([x**2 - x - 1, x**2 + x], x) == -x**2 - x + (x**2 + x)**2 - 1 diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_inequalities.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_inequalities.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce6f4520b52d8714102c95457c90d44543c685c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_inequalities.py @@ -0,0 +1,500 @@ +"""Tests for tools for solving inequalities and systems of inequalities. """ + +from sympy.concrete.summations import Sum +from sympy.core.function import Function +from sympy.core.numbers import I, Rational, oo, pi +from sympy.core.relational import Eq, Ge, Gt, Le, Lt, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import root, sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin, tan +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import And, Or +from sympy.polys.polytools import Poly, PurePoly +from sympy.sets.sets import FiniteSet, Interval, Union +from sympy.solvers.inequalities import (reduce_inequalities, + solve_poly_inequality as psolve, + reduce_rational_inequalities, + solve_univariate_inequality as isolve, + reduce_abs_inequality, + _solve_inequality) +from sympy.polys.rootoftools import rootof +from sympy.solvers.solvers import solve +from sympy.solvers.solveset import solveset +from sympy.core.mod import Mod +from sympy.abc import x, y + +from sympy.testing.pytest import raises, XFAIL + + +inf = oo.evalf() + + +def test_solve_poly_inequality(): + assert psolve(Poly(0, x), '==') == [S.Reals] + assert psolve(Poly(1, x), '==') == [S.EmptySet] + assert psolve(PurePoly(x + 1, x), ">") == [Interval(-1, oo, True, False)] + + +def test_reduce_poly_inequalities_real_interval(): + assert reduce_rational_inequalities( + [[Eq(x**2, 0)]], x, relational=False) == FiniteSet(0) + assert reduce_rational_inequalities( + [[Le(x**2, 0)]], x, relational=False) == FiniteSet(0) + assert reduce_rational_inequalities( + [[Lt(x**2, 0)]], x, relational=False) == S.EmptySet + assert reduce_rational_inequalities( + [[Ge(x**2, 0)]], x, relational=False) == \ + S.Reals if x.is_real else Interval(-oo, oo) + assert reduce_rational_inequalities( + [[Gt(x**2, 0)]], x, relational=False) == \ + FiniteSet(0).complement(S.Reals) + assert reduce_rational_inequalities( + [[Ne(x**2, 0)]], x, relational=False) == \ + FiniteSet(0).complement(S.Reals) + + assert reduce_rational_inequalities( + [[Eq(x**2, 1)]], x, relational=False) == FiniteSet(-1, 1) + assert reduce_rational_inequalities( + [[Le(x**2, 1)]], x, relational=False) == Interval(-1, 1) + assert reduce_rational_inequalities( + [[Lt(x**2, 1)]], x, relational=False) == Interval(-1, 1, True, True) + assert reduce_rational_inequalities( + [[Ge(x**2, 1)]], x, relational=False) == \ + Union(Interval(-oo, -1), Interval(1, oo)) + assert reduce_rational_inequalities( + [[Gt(x**2, 1)]], x, relational=False) == \ + Interval(-1, 1).complement(S.Reals) + assert reduce_rational_inequalities( + [[Ne(x**2, 1)]], x, relational=False) == \ + FiniteSet(-1, 1).complement(S.Reals) + assert reduce_rational_inequalities([[Eq( + x**2, 1.0)]], x, relational=False) == FiniteSet(-1.0, 1.0).evalf() + assert reduce_rational_inequalities( + [[Le(x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0) + assert reduce_rational_inequalities([[Lt( + x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0, True, True) + assert reduce_rational_inequalities( + [[Ge(x**2, 1.0)]], x, relational=False) == \ + Union(Interval(-inf, -1.0), Interval(1.0, inf)) + assert reduce_rational_inequalities( + [[Gt(x**2, 1.0)]], x, relational=False) == \ + Union(Interval(-inf, -1.0, right_open=True), + Interval(1.0, inf, left_open=True)) + assert reduce_rational_inequalities([[Ne( + x**2, 1.0)]], x, relational=False) == \ + FiniteSet(-1.0, 1.0).complement(S.Reals) + + s = sqrt(2) + + assert reduce_rational_inequalities([[Lt( + x**2 - 1, 0), Gt(x**2 - 1, 0)]], x, relational=False) == S.EmptySet + assert reduce_rational_inequalities([[Le(x**2 - 1, 0), Ge( + x**2 - 1, 0)]], x, relational=False) == FiniteSet(-1, 1) + assert reduce_rational_inequalities( + [[Le(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, False, False), Interval(1, s, False, False)) + assert reduce_rational_inequalities( + [[Le(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, False, True), Interval(1, s, True, False)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, False), Interval(1, s, False, True)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, True), Interval(1, s, True, True)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Ne(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, True), Interval(-1, 1, True, True), + Interval(1, s, True, True)) + + assert reduce_rational_inequalities([[Lt(x**2, -1.)]], x) is S.false + + +def test_reduce_poly_inequalities_complex_relational(): + assert reduce_rational_inequalities( + [[Eq(x**2, 0)]], x, relational=True) == Eq(x, 0) + assert reduce_rational_inequalities( + [[Le(x**2, 0)]], x, relational=True) == Eq(x, 0) + assert reduce_rational_inequalities( + [[Lt(x**2, 0)]], x, relational=True) == False + assert reduce_rational_inequalities( + [[Ge(x**2, 0)]], x, relational=True) == And(Lt(-oo, x), Lt(x, oo)) + assert reduce_rational_inequalities( + [[Gt(x**2, 0)]], x, relational=True) == \ + And(Gt(x, -oo), Lt(x, oo), Ne(x, 0)) + assert reduce_rational_inequalities( + [[Ne(x**2, 0)]], x, relational=True) == \ + And(Gt(x, -oo), Lt(x, oo), Ne(x, 0)) + + for one in (S.One, S(1.0)): + inf = one*oo + assert reduce_rational_inequalities( + [[Eq(x**2, one)]], x, relational=True) == \ + Or(Eq(x, -one), Eq(x, one)) + assert reduce_rational_inequalities( + [[Le(x**2, one)]], x, relational=True) == \ + And(And(Le(-one, x), Le(x, one))) + assert reduce_rational_inequalities( + [[Lt(x**2, one)]], x, relational=True) == \ + And(And(Lt(-one, x), Lt(x, one))) + assert reduce_rational_inequalities( + [[Ge(x**2, one)]], x, relational=True) == \ + And(Or(And(Le(one, x), Lt(x, inf)), And(Le(x, -one), Lt(-inf, x)))) + assert reduce_rational_inequalities( + [[Gt(x**2, one)]], x, relational=True) == \ + And(Or(And(Lt(-inf, x), Lt(x, -one)), And(Lt(one, x), Lt(x, inf)))) + assert reduce_rational_inequalities( + [[Ne(x**2, one)]], x, relational=True) == \ + Or(And(Lt(-inf, x), Lt(x, -one)), + And(Lt(-one, x), Lt(x, one)), + And(Lt(one, x), Lt(x, inf))) + + +def test_reduce_rational_inequalities_real_relational(): + assert reduce_rational_inequalities([], x) == False + assert reduce_rational_inequalities( + [[(x**2 + 3*x + 2)/(x**2 - 16) >= 0]], x, relational=False) == \ + Union(Interval.open(-oo, -4), Interval(-2, -1), Interval.open(4, oo)) + + assert reduce_rational_inequalities( + [[((-2*x - 10)*(3 - x))/((x**2 + 5)*(x - 2)**2) < 0]], x, + relational=False) == \ + Union(Interval.open(-5, 2), Interval.open(2, 3)) + + assert reduce_rational_inequalities([[(x + 1)/(x - 5) <= 0]], x, + relational=False) == \ + Interval.Ropen(-1, 5) + + assert reduce_rational_inequalities([[(x**2 + 4*x + 3)/(x - 1) > 0]], x, + relational=False) == \ + Union(Interval.open(-3, -1), Interval.open(1, oo)) + + assert reduce_rational_inequalities([[(x**2 - 16)/(x - 1)**2 < 0]], x, + relational=False) == \ + Union(Interval.open(-4, 1), Interval.open(1, 4)) + + assert reduce_rational_inequalities([[(3*x + 1)/(x + 4) >= 1]], x, + relational=False) == \ + Union(Interval.open(-oo, -4), Interval.Ropen(Rational(3, 2), oo)) + + assert reduce_rational_inequalities([[(x - 8)/x <= 3 - x]], x, + relational=False) == \ + Union(Interval.Lopen(-oo, -2), Interval.Lopen(0, 4)) + + # issue sympy/sympy#10237 + assert reduce_rational_inequalities( + [[x < oo, x >= 0, -oo < x]], x, relational=False) == Interval(0, oo) + + +def test_reduce_abs_inequalities(): + e = abs(x - 5) < 3 + ans = And(Lt(2, x), Lt(x, 8)) + assert reduce_inequalities(e) == ans + assert reduce_inequalities(e, x) == ans + assert reduce_inequalities(abs(x - 5)) == Eq(x, 5) + assert reduce_inequalities( + abs(2*x + 3) >= 8) == Or(And(Le(Rational(5, 2), x), Lt(x, oo)), + And(Le(x, Rational(-11, 2)), Lt(-oo, x))) + assert reduce_inequalities(abs(x - 4) + abs( + 3*x - 5) < 7) == And(Lt(S.Half, x), Lt(x, 4)) + assert reduce_inequalities(abs(x - 4) + abs(3*abs(x) - 5) < 7) == \ + Or(And(S(-2) < x, x < -1), And(S.Half < x, x < 4)) + + nr = Symbol('nr', extended_real=False) + raises(TypeError, lambda: reduce_inequalities(abs(nr - 5) < 3)) + assert reduce_inequalities(x < 3, symbols=[x, nr]) == And(-oo < x, x < 3) + + +def test_reduce_inequalities_general(): + assert reduce_inequalities(Ge(sqrt(2)*x, 1)) == And(sqrt(2)/2 <= x, x < oo) + assert reduce_inequalities(x + 1 > 0) == And(S.NegativeOne < x, x < oo) + + +def test_reduce_inequalities_boolean(): + assert reduce_inequalities( + [Eq(x**2, 0), True]) == Eq(x, 0) + assert reduce_inequalities([Eq(x**2, 0), False]) == False + assert reduce_inequalities(x**2 >= 0) is S.true # issue 10196 + + +def test_reduce_inequalities_multivariate(): + assert reduce_inequalities([Ge(x**2, 1), Ge(y**2, 1)]) == And( + Or(And(Le(S.One, x), Lt(x, oo)), And(Le(x, -1), Lt(-oo, x))), + Or(And(Le(S.One, y), Lt(y, oo)), And(Le(y, -1), Lt(-oo, y)))) + + +def test_reduce_inequalities_errors(): + raises(NotImplementedError, lambda: reduce_inequalities(Ge(sin(x) + x, 1))) + raises(NotImplementedError, lambda: reduce_inequalities(Ge(x**2*y + y, 1))) + + +def test__solve_inequalities(): + assert reduce_inequalities(x + y < 1, symbols=[x]) == (x < 1 - y) + assert reduce_inequalities(x + y >= 1, symbols=[x]) == (x < oo) & (x >= -y + 1) + assert reduce_inequalities(Eq(0, x - y), symbols=[x]) == Eq(x, y) + assert reduce_inequalities(Ne(0, x - y), symbols=[x]) == Ne(x, y) + + +def test_issue_6343(): + eq = -3*x**2/2 - x*Rational(45, 4) + Rational(33, 2) > 0 + assert reduce_inequalities(eq) == \ + And(x < Rational(-15, 4) + sqrt(401)/4, -sqrt(401)/4 - Rational(15, 4) < x) + + +def test_issue_8235(): + assert reduce_inequalities(x**2 - 1 < 0) == \ + And(S.NegativeOne < x, x < 1) + assert reduce_inequalities(x**2 - 1 <= 0) == \ + And(S.NegativeOne <= x, x <= 1) + assert reduce_inequalities(x**2 - 1 > 0) == \ + Or(And(-oo < x, x < -1), And(x < oo, S.One < x)) + assert reduce_inequalities(x**2 - 1 >= 0) == \ + Or(And(-oo < x, x <= -1), And(S.One <= x, x < oo)) + + eq = x**8 + x - 9 # we want CRootOf solns here + sol = solve(eq >= 0) + tru = Or(And(rootof(eq, 1) <= x, x < oo), And(-oo < x, x <= rootof(eq, 0))) + assert sol == tru + + # recast vanilla as real + assert solve(sqrt((-x + 1)**2) < 1) == And(S.Zero < x, x < 2) + + +def test_issue_5526(): + assert reduce_inequalities(0 <= + x + Integral(y**2, (y, 1, 3)) - 1, [x]) == \ + (x >= -Integral(y**2, (y, 1, 3)) + 1) + f = Function('f') + e = Sum(f(x), (x, 1, 3)) + assert reduce_inequalities(0 <= x + e + y**2, [x]) == \ + (x >= -y**2 - Sum(f(x), (x, 1, 3))) + + +def test_solve_univariate_inequality(): + assert isolve(x**2 >= 4, x, relational=False) == Union(Interval(-oo, -2), + Interval(2, oo)) + assert isolve(x**2 >= 4, x) == Or(And(Le(2, x), Lt(x, oo)), And(Le(x, -2), + Lt(-oo, x))) + assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x, relational=False) == \ + Union(Interval(1, 2), Interval(3, oo)) + assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x) == \ + Or(And(Le(1, x), Le(x, 2)), And(Le(3, x), Lt(x, oo))) + assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain = FiniteSet(0, 3)) == \ + Or(Eq(x, 0), Eq(x, 3)) + # issue 2785: + assert isolve(x**3 - 2*x - 1 > 0, x, relational=False) == \ + Union(Interval(-1, -sqrt(5)/2 + S.Half, True, True), + Interval(S.Half + sqrt(5)/2, oo, True, True)) + # issue 2794: + assert isolve(x**3 - x**2 + x - 1 > 0, x, relational=False) == \ + Interval(1, oo, True) + #issue 13105 + assert isolve((x + I)*(x + 2*I) < 0, x) == Eq(x, 0) + assert isolve(((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I) < 0, x) == Or(Eq(x, 1), Eq(x, 2)) + assert isolve((((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I))/(x - 2) > 0, x) == Eq(x, 1) + raises (ValueError, lambda: isolve((x**2 - 3*x*I + 2)/x < 0, x)) + + # numerical testing in valid() is needed + assert isolve(x**7 - x - 2 > 0, x) == \ + And(rootof(x**7 - x - 2, 0) < x, x < oo) + + # handle numerator and denominator; although these would be handled as + # rational inequalities, these test confirm that the right thing is done + # when the domain is EX (e.g. when 2 is replaced with sqrt(2)) + assert isolve(1/(x - 2) > 0, x) == And(S(2) < x, x < oo) + den = ((x - 1)*(x - 2)).expand() + assert isolve((x - 1)/den <= 0, x) == \ + (x > -oo) & (x < 2) & Ne(x, 1) + + n = Dummy('n') + raises(NotImplementedError, lambda: isolve(Abs(x) <= n, x, relational=False)) + c1 = Dummy("c1", positive=True) + raises(NotImplementedError, lambda: isolve(n/c1 < 0, c1)) + n = Dummy('n', negative=True) + assert isolve(n/c1 > -2, c1) == (-n/2 < c1) + assert isolve(n/c1 < 0, c1) == True + assert isolve(n/c1 > 0, c1) == False + + zero = cos(1)**2 + sin(1)**2 - 1 + raises(NotImplementedError, lambda: isolve(x**2 < zero, x)) + raises(NotImplementedError, lambda: isolve( + x**2 < zero*I, x)) + raises(NotImplementedError, lambda: isolve(1/(x - y) < 2, x)) + raises(NotImplementedError, lambda: isolve(1/(x - y) < 0, x)) + raises(TypeError, lambda: isolve(x - I < 0, x)) + + zero = x**2 + x - x*(x + 1) + assert isolve(zero < 0, x, relational=False) is S.EmptySet + assert isolve(zero <= 0, x, relational=False) is S.Reals + + # make sure iter_solutions gets a default value + raises(NotImplementedError, lambda: isolve( + Eq(cos(x)**2 + sin(x)**2, 1), x)) + + +def test_trig_inequalities(): + # all the inequalities are solved in a periodic interval. + assert isolve(sin(x) < S.Half, x, relational=False) == \ + Union(Interval(0, pi/6, False, True), Interval.open(pi*Rational(5, 6), 2*pi)) + assert isolve(sin(x) > S.Half, x, relational=False) == \ + Interval(pi/6, pi*Rational(5, 6), True, True) + assert isolve(cos(x) < S.Zero, x, relational=False) == \ + Interval(pi/2, pi*Rational(3, 2), True, True) + assert isolve(cos(x) >= S.Zero, x, relational=False) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + + assert isolve(tan(x) < S.One, x, relational=False) == \ + Union(Interval.Ropen(0, pi/4), Interval.open(pi/2, pi)) + + assert isolve(sin(x) <= S.Zero, x, relational=False) == \ + Union(FiniteSet(S.Zero), Interval.Ropen(pi, 2*pi)) + + assert isolve(sin(x) <= S.One, x, relational=False) == S.Reals + assert isolve(cos(x) < S(-2), x, relational=False) == S.EmptySet + assert isolve(sin(x) >= S.NegativeOne, x, relational=False) == S.Reals + assert isolve(cos(x) > S.One, x, relational=False) == S.EmptySet + + +def test_issue_9954(): + assert isolve(x**2 >= 0, x, relational=False) == S.Reals + assert isolve(x**2 >= 0, x, relational=True) == S.Reals.as_relational(x) + assert isolve(x**2 < 0, x, relational=False) == S.EmptySet + assert isolve(x**2 < 0, x, relational=True) == S.EmptySet.as_relational(x) + + +@XFAIL +def test_slow_general_univariate(): + r = rootof(x**5 - x**2 + 1, 0) + assert solve(sqrt(x) + 1/root(x, 3) > 1) == \ + Or(And(0 < x, x < r**6), And(r**6 < x, x < oo)) + + +def test_issue_8545(): + eq = 1 - x - abs(1 - x) + ans = And(Lt(1, x), Lt(x, oo)) + assert reduce_abs_inequality(eq, '<', x) == ans + eq = 1 - x - sqrt((1 - x)**2) + assert reduce_inequalities(eq < 0) == ans + + +def test_issue_8974(): + assert isolve(-oo < x, x) == And(-oo < x, x < oo) + assert isolve(oo > x, x) == And(-oo < x, x < oo) + + +def test_issue_10198(): + assert reduce_inequalities( + -1 + 1/abs(1/x - 1) < 0) == (x > -oo) & (x < S(1)/2) & Ne(x, 0) + + assert reduce_inequalities(abs(1/sqrt(x)) - 1, x) == Eq(x, 1) + assert reduce_abs_inequality(-3 + 1/abs(1 - 1/x), '<', x) == \ + Or(And(-oo < x, x < 0), + And(S.Zero < x, x < Rational(3, 4)), And(Rational(3, 2) < x, x < oo)) + raises(ValueError,lambda: reduce_abs_inequality(-3 + 1/abs( + 1 - 1/sqrt(x)), '<', x)) + + +def test_issue_10047(): + # issue 10047: this must remain an inequality, not True, since if x + # is not real the inequality is invalid + # assert solve(sin(x) < 2) == (x <= oo) + + # with PR 16956, (x <= oo) autoevaluates when x is extended_real + # which is assumed in the current implementation of inequality solvers + assert solve(sin(x) < 2) == True + assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals + + +def test_issue_10268(): + assert solve(log(x) < 1000) == And(S.Zero < x, x < exp(1000)) + + +@XFAIL +def test_isolve_Sets(): + n = Dummy('n') + assert isolve(Abs(x) <= n, x, relational=False) == \ + Piecewise((S.EmptySet, n < 0), (Interval(-n, n), True)) + + +def test_integer_domain_relational_isolve(): + + dom = FiniteSet(0, 3) + x = Symbol('x',zero=False) + assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain=dom) == Eq(x, 3) + + x = Symbol('x') + assert isolve(x + 2 < 0, x, domain=S.Integers) == \ + (x <= -3) & (x > -oo) & Eq(Mod(x, 1), 0) + assert isolve(2 * x + 3 > 0, x, domain=S.Integers) == \ + (x >= -1) & (x < oo) & Eq(Mod(x, 1), 0) + assert isolve((x ** 2 + 3 * x - 2) < 0, x, domain=S.Integers) == \ + (x >= -3) & (x <= 0) & Eq(Mod(x, 1), 0) + assert isolve((x ** 2 + 3 * x - 2) > 0, x, domain=S.Integers) == \ + ((x >= 1) & (x < oo) & Eq(Mod(x, 1), 0)) | ( + (x <= -4) & (x > -oo) & Eq(Mod(x, 1), 0)) + + +def test_issue_10671_12466(): + assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi) + i = Interval(1, 10) + assert solveset((1/x).diff(x) < 0, x, i) == i + assert solveset((log(x - 6)/x) <= 0, x, S.Reals) == \ + Interval.Lopen(6, 7) + + +def test__solve_inequality(): + for op in (Gt, Lt, Le, Ge, Eq, Ne): + assert _solve_inequality(op(x, 1), x).lhs == x + assert _solve_inequality(op(S.One, x), x).lhs == x + # don't get tricked by symbol on right: solve it + assert _solve_inequality(Eq(2*x - 1, x), x) == Eq(x, 1) + ie = Eq(S.One, y) + assert _solve_inequality(ie, x) == ie + for fx in (x**2, exp(x), sin(x) + cos(x), x*(1 + x)): + for c in (0, 1): + e = 2*fx - c > 0 + assert _solve_inequality(e, x, linear=True) == ( + fx > c/S(2)) + assert _solve_inequality(2*x**2 + 2*x - 1 < 0, x, linear=True) == ( + x*(x + 1) < S.Half) + assert _solve_inequality(Eq(x*y, 1), x) == Eq(x*y, 1) + nz = Symbol('nz', nonzero=True) + assert _solve_inequality(Eq(x*nz, 1), x) == Eq(x, 1/nz) + assert _solve_inequality(x*nz < 1, x) == (x*nz < 1) + a = Symbol('a', positive=True) + assert _solve_inequality(a/x > 1, x) == (S.Zero < x) & (x < a) + assert _solve_inequality(a/x > 1, x, linear=True) == (1/x > 1/a) + # make sure to include conditions under which solution is valid + e = Eq(1 - x, x*(1/x - 1)) + assert _solve_inequality(e, x) == Ne(x, 0) + assert _solve_inequality(x < x*(1/x - 1), x) == (x < S.Half) & Ne(x, 0) + + +def test__pt(): + from sympy.solvers.inequalities import _pt + assert _pt(-oo, oo) == 0 + assert _pt(S.One, S(3)) == 2 + assert _pt(S.One, oo) == _pt(oo, S.One) == 2 + assert _pt(S.One, -oo) == _pt(-oo, S.One) == S.Half + assert _pt(S.NegativeOne, oo) == _pt(oo, S.NegativeOne) == Rational(-1, 2) + assert _pt(S.NegativeOne, -oo) == _pt(-oo, S.NegativeOne) == -2 + assert _pt(x, oo) == _pt(oo, x) == x + 1 + assert _pt(x, -oo) == _pt(-oo, x) == x - 1 + raises(ValueError, lambda: _pt(Dummy('i', infinite=True), S.One)) + + +def test_issue_25697(): + assert _solve_inequality(log(x, 3) <= 2, x) == (x <= 9) & (S.Zero < x) + + +def test_issue_25738(): + assert reduce_inequalities(3 < abs(x) + ) == reduce_inequalities(pi < abs(x)).subs(pi, 3) + + +def test_issue_25983(): + assert(reduce_inequalities(pi/Abs(x) <= 1) == ((pi <= x) & (x < oo)) | ((-oo < x) & (x <= -pi))) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_numeric.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..12abd38c80f07279ed41aefc7952762da0f9f430 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_numeric.py @@ -0,0 +1,139 @@ +from sympy.core.function import nfloat +from sympy.core.numbers import (Float, I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from mpmath import mnorm, mpf +from sympy.solvers import nsolve +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.decorator import conserve_mpmath_dps + +@XFAIL +def test_nsolve_fail(): + x = symbols('x') + # Sometimes it is better to use the numerator (issue 4829) + # but sometimes it is not (issue 11768) so leave this to + # the discretion of the user + ans = nsolve(x**2/(1 - x)/(1 - 2*x)**2 - 100, x, 0) + assert ans > 0.46 and ans < 0.47 + + +def test_nsolve_denominator(): + x = symbols('x') + # Test that nsolve uses the full expression (numerator and denominator). + ans = nsolve((x**2 + 3*x + 2)/(x + 2), -2.1) + # The root -2 was divided out, so make sure we don't find it. + assert ans == -1.0 + +def test_nsolve(): + # onedimensional + x = Symbol('x') + assert nsolve(sin(x), 2) - pi.evalf() < 1e-15 + assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10) + # Testing checks on number of inputs + raises(TypeError, lambda: nsolve(Eq(2*x, 2))) + raises(TypeError, lambda: nsolve(Eq(2*x, 2), x, 1, 2)) + # multidimensional + x1 = Symbol('x1') + x2 = Symbol('x2') + f1 = 3 * x1**2 - 2 * x2**2 - 1 + f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8 + f = Matrix((f1, f2)).T + F = lambdify((x1, x2), f.T, modules='mpmath') + for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]: + x = nsolve(f, (x1, x2), x0, tol=1.e-8) + assert mnorm(F(*x), 1) <= 1.e-10 + # The Chinese mathematician Zhu Shijie was the very first to solve this + # nonlinear system 700 years ago (z was added to make it 3-dimensional) + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + f1 = -x + 2*y + f2 = (x**2 + x*(y**2 - 2) - 4*y) / (x + 4) + f3 = sqrt(x**2 + y**2)*z + f = Matrix((f1, f2, f3)).T + F = lambdify((x, y, z), f.T, modules='mpmath') + + def getroot(x0): + root = nsolve(f, (x, y, z), x0) + assert mnorm(F(*root), 1) <= 1.e-8 + return root + assert list(map(round, getroot((1, 1, 1)))) == [2, 1, 0] + assert nsolve([Eq( + f1, 0), Eq(f2, 0), Eq(f3, 0)], [x, y, z], (1, 1, 1)) # just see that it works + a = Symbol('a') + assert abs(nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3) - + mpf('0.31883011387318591')) < 1e-15 + + +def test_issue_6408(): + x = Symbol('x') + assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0 + + +def test_issue_6408_integral(): + x, y = symbols('x y') + assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0 + + +@conserve_mpmath_dps +def test_increased_dps(): + # Issue 8564 + import mpmath + mpmath.mp.dps = 128 + x = Symbol('x') + e1 = x**2 - pi + q = nsolve(e1, x, 3.0) + + assert abs(sqrt(pi).evalf(128) - q) < 1e-128 + +def test_nsolve_precision(): + x, y = symbols('x y') + sol = nsolve(x**2 - pi, x, 3, prec=128) + assert abs(sqrt(pi).evalf(128) - sol) < 1e-128 + assert isinstance(sol, Float) + + sols = nsolve((y**2 - x, x**2 - pi), (x, y), (3, 3), prec=128) + assert isinstance(sols, Matrix) + assert sols.shape == (2, 1) + assert abs(sqrt(pi).evalf(128) - sols[0]) < 1e-128 + assert abs(sqrt(sqrt(pi)).evalf(128) - sols[1]) < 1e-128 + assert all(isinstance(i, Float) for i in sols) + +def test_nsolve_complex(): + x, y = symbols('x y') + + assert nsolve(x**2 + 2, 1j) == sqrt(2.)*I + assert nsolve(x**2 + 2, I) == sqrt(2.)*I + + assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I]) + assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I]) + +def test_nsolve_dict_kwarg(): + x, y = symbols('x y') + # one variable + assert nsolve(x**2 - 2, 1, dict = True) == \ + [{x: sqrt(2.)}] + # one variable with complex solution + assert nsolve(x**2 + 2, I, dict = True) == \ + [{x: sqrt(2.)*I}] + # two variables + assert nsolve([x**2 + y**2 - 5, x**2 - y**2 + 1], [x, y], [1, 1], dict = True) == \ + [{x: sqrt(2.), y: sqrt(3.)}] + +def test_nsolve_rational(): + x = symbols('x') + assert nsolve(x - Rational(1, 3), 0, prec=100) == Rational(1, 3).evalf(100) + + +def test_issue_14950(): + x = Matrix(symbols('t s')) + x0 = Matrix([17, 23]) + eqn = x + x0 + assert nsolve(eqn, x, x0) == nfloat(-x0) + assert nsolve(eqn.T, x.T, x0.T) == nfloat(-x0) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_pde.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_pde.py new file mode 100644 index 0000000000000000000000000000000000000000..948d90c7be21a9e0e03753e723ef04f1fb08a5d6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_pde.py @@ -0,0 +1,239 @@ +from sympy.core.function import (Derivative as D, Function) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.core import S +from sympy.solvers.pde import (pde_separate, pde_separate_add, pde_separate_mul, + pdsolve, classify_pde, checkpdesol) +from sympy.testing.pytest import raises + + +a, b, c, x, y = symbols('a b c x y') + +def test_pde_separate_add(): + x, y, z, t = symbols("x,y,z,t") + F, T, X, Y, Z, u = map(Function, 'FTXYZu') + + eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t))) + res = pde_separate_add(eq, u(x, t), [X(x), T(t)]) + assert res == [D(X(x), x)*exp(-X(x)), D(T(t), t)*exp(T(t))] + + +def test_pde_separate(): + x, y, z, t = symbols("x,y,z,t") + F, T, X, Y, Z, u = map(Function, 'FTXYZu') + + eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t))) + raises(ValueError, lambda: pde_separate(eq, u(x, t), [X(x), T(t)], 'div')) + + +def test_pde_separate_mul(): + x, y, z, t = symbols("x,y,z,t") + c = Symbol("C", real=True) + Phi = Function('Phi') + F, R, T, X, Y, Z, u = map(Function, 'FRTXYZu') + r, theta, z = symbols('r,theta,z') + + # Something simple :) + eq = Eq(D(F(x, y, z), x) + D(F(x, y, z), y) + D(F(x, y, z), z), 0) + + # Duplicate arguments in functions + raises( + ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), u(z, z)])) + # Wrong number of arguments + raises(ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), Y(y)])) + # Wrong variables: [x, y] -> [x, z] + raises( + ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(t), Y(x, y)])) + + assert pde_separate_mul(eq, F(x, y, z), [Y(y), u(x, z)]) == \ + [D(Y(y), y)/Y(y), -D(u(x, z), x)/u(x, z) - D(u(x, z), z)/u(x, z)] + assert pde_separate_mul(eq, F(x, y, z), [X(x), Y(y), Z(z)]) == \ + [D(X(x), x)/X(x), -D(Z(z), z)/Z(z) - D(Y(y), y)/Y(y)] + + # wave equation + wave = Eq(D(u(x, t), t, t), c**2*D(u(x, t), x, x)) + res = pde_separate_mul(wave, u(x, t), [X(x), T(t)]) + assert res == [D(X(x), x, x)/X(x), D(T(t), t, t)/(c**2*T(t))] + + # Laplace equation in cylindrical coords + eq = Eq(1/r * D(Phi(r, theta, z), r) + D(Phi(r, theta, z), r, 2) + + 1/r**2 * D(Phi(r, theta, z), theta, 2) + D(Phi(r, theta, z), z, 2), 0) + # Separate z + res = pde_separate_mul(eq, Phi(r, theta, z), [Z(z), u(theta, r)]) + assert res == [D(Z(z), z, z)/Z(z), + -D(u(theta, r), r, r)/u(theta, r) - + D(u(theta, r), r)/(r*u(theta, r)) - + D(u(theta, r), theta, theta)/(r**2*u(theta, r))] + # Lets use the result to create a new equation... + eq = Eq(res[1], c) + # ...and separate theta... + res = pde_separate_mul(eq, u(theta, r), [T(theta), R(r)]) + assert res == [D(T(theta), theta, theta)/T(theta), + -r*D(R(r), r)/R(r) - r**2*D(R(r), r, r)/R(r) - c*r**2] + # ...or r... + res = pde_separate_mul(eq, u(theta, r), [R(r), T(theta)]) + assert res == [r*D(R(r), r)/R(r) + r**2*D(R(r), r, r)/R(r) + c*r**2, + -D(T(theta), theta, theta)/T(theta)] + + +def test_issue_11726(): + x, t = symbols("x t") + f = symbols("f", cls=Function) + X, T = symbols("X T", cls=Function) + + u = f(x, t) + eq = u.diff(x, 2) - u.diff(t, 2) + res = pde_separate(eq, u, [T(x), X(t)]) + assert res == [D(T(x), x, x)/T(x),D(X(t), t, t)/X(t)] + + +def test_pde_classify(): + # When more number of hints are added, add tests for classifying here. + f = Function('f') + eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y) + eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y) + eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y) + eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y) + eq5 = x**2*f(x,y) + x*f(x,y).diff(x) + x*y*f(x,y).diff(y) + eq6 = y*x**2*f(x,y) + y*f(x,y).diff(x) + f(x,y).diff(y) + for eq in [eq1, eq2, eq3]: + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + for eq in [eq4, eq5, eq6]: + assert classify_pde(eq) == ('1st_linear_variable_coeff',) + + +def test_checkpdesol(): + f, F = map(Function, ['f', 'F']) + eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y) + eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y) + eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y) + for eq in [eq1, eq2, eq3]: + assert checkpdesol(eq, pdsolve(eq))[0] + eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y) + eq5 = 2*f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y) + eq6 = f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y) + assert checkpdesol(eq4, [pdsolve(eq5), pdsolve(eq6)]) == [ + (False, (x - 2)*F(3*x - y)*exp(-x/S(5) - 3*y/S(5))), + (False, (x - 1)*F(3*x - y)*exp(-x/S(10) - 3*y/S(10)))] + for eq in [eq4, eq5, eq6]: + assert checkpdesol(eq, pdsolve(eq))[0] + sol = pdsolve(eq4) + sol4 = Eq(sol.lhs - sol.rhs, 0) + raises(NotImplementedError, lambda: + checkpdesol(eq4, sol4, solve_for_func=False)) + + +def test_solvefun(): + f, F, G, H = map(Function, ['f', 'F', 'G', 'H']) + eq1 = f(x,y) + f(x,y).diff(x) + f(x,y).diff(y) + assert pdsolve(eq1) == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2)) + assert pdsolve(eq1, solvefun=G) == Eq(f(x, y), G(x - y)*exp(-x/2 - y/2)) + assert pdsolve(eq1, solvefun=H) == Eq(f(x, y), H(x - y)*exp(-x/2 - y/2)) + + +def test_pde_1st_linear_constant_coeff_homogeneous(): + f, F = map(Function, ['f', 'F']) + u = f(x, y) + eq = 2*u + u.diff(x) + u.diff(y) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(x - y)*exp(-x - y)) + assert checkpdesol(eq, sol)[0] + + eq = 4 + (3*u.diff(x)/u) + (2*u.diff(y)/u) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(2*x - 3*y)*exp(-S(12)*x/13 - S(8)*y/13)) + assert checkpdesol(eq, sol)[0] + + eq = u + (6*u.diff(x)) + (7*u.diff(y)) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(7*x - 6*y)*exp(-6*x/S(85) - 7*y/S(85))) + assert checkpdesol(eq, sol)[0] + + eq = a*u + b*u.diff(x) + c*u.diff(y) + sol = pdsolve(eq) + assert checkpdesol(eq, sol)[0] + + +def test_pde_1st_linear_constant_coeff(): + f, F = map(Function, ['f', 'F']) + u = f(x,y) + eq = -2*u.diff(x) + 4*u.diff(y) + 5*u - exp(x + 3*y) + sol = pdsolve(eq) + assert sol == Eq(f(x,y), + (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y)) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = (u.diff(x)/u) + (u.diff(y)/u) + 1 - (exp(x + y)/u) + sol = pdsolve(eq) + assert sol == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2) + exp(x + y)/3) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = 2*u + -u.diff(x) + 3*u.diff(y) + sin(x) + sol = pdsolve(eq) + assert sol == Eq(f(x, y), + F(3*x + y)*exp(x/5 - 3*y/5) - 2*sin(x)/5 - cos(x)/5) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = u + u.diff(x) + u.diff(y) + x*y + sol = pdsolve(eq) + assert sol.expand() == Eq(f(x, y), + x + y + (x - y)**2/4 - (x + y)**2/4 + F(x - y)*exp(-x/2 - y/2) - 2).expand() + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + eq = u + u.diff(x) + u.diff(y) + log(x) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + + +def test_pdsolve_all(): + f, F = map(Function, ['f', 'F']) + u = f(x,y) + eq = u + u.diff(x) + u.diff(y) + x**2*y + sol = pdsolve(eq, hint = 'all') + keys = ['1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral', 'default', 'order'] + assert sorted(sol.keys()) == keys + assert sol['order'] == 1 + assert sol['default'] == '1st_linear_constant_coeff' + assert sol['1st_linear_constant_coeff'].expand() == Eq(f(x, y), + -x**2*y + x**2 + 2*x*y - 4*x - 2*y + F(x - y)*exp(-x/2 - y/2) + 6).expand() + + +def test_pdsolve_variable_coeff(): + f, F = map(Function, ['f', 'F']) + u = f(x, y) + eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2 + sol = pdsolve(eq, hint="1st_linear_variable_coeff") + assert sol == Eq(u, F(x*y)*exp(y**2/2) + 1) + assert checkpdesol(eq, sol)[0] + + eq = x**2*u + x*u.diff(x) + x*y*u.diff(y) + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(y*exp(-x))*exp(-x**2/2)) + assert checkpdesol(eq, sol)[0] + + eq = y*x**2*u + y*u.diff(x) + u.diff(y) + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(-2*x + y**2)*exp(-x**3/3)) + assert checkpdesol(eq, sol)[0] + + eq = exp(x)**2*(u.diff(x)) + y + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, y*exp(-2*x)/2 + F(y)) + assert checkpdesol(eq, sol)[0] + + eq = exp(2*x)*(u.diff(y)) + y*u - u + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(x)*exp(-y*(y - 2)*exp(-2*x)/2)) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_polysys.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_polysys.py new file mode 100644 index 0000000000000000000000000000000000000000..a119591a0354ba377a18767eae6e8b7812810a0d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_polysys.py @@ -0,0 +1,462 @@ +"""Tests for solvers of systems of polynomial equations. """ +from sympy.polys.domains import ZZ, QQ_I +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.polyerrors import UnsolvableFactorError +from sympy.polys.polyoptions import Options +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.solvers.solvers import solve +from sympy.utilities.iterables import flatten +from sympy.abc import a, b, c, x, y, z +from sympy.polys import PolynomialError +from sympy.solvers.polysys import (solve_poly_system, + solve_triangulated, + solve_biquadratic, SolveFailed, + solve_generic, factor_system_bool, + factor_system_cond, factor_system_poly, + factor_system, _factor_sets, _factor_sets_slow) +from sympy.polys.polytools import parallel_poly_from_expr +from sympy.testing.pytest import raises +from sympy.core.relational import Eq +from sympy.functions.elementary.trigonometric import sin, cos + +from sympy.functions.elementary.exponential import exp + + +def test_solve_poly_system(): + assert solve_poly_system([x - 1], x) == [(S.One,)] + + assert solve_poly_system([y - x, y - x - 1], x, y) is None + + assert solve_poly_system([y - x**2, y + x**2], x, y) == [(S.Zero, S.Zero)] + + assert solve_poly_system([2*x - 3, y*Rational(3, 2) - 2*x, z - 5*y], x, y, z) == \ + [(Rational(3, 2), Integer(2), Integer(10))] + + assert solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) == \ + [(0, 0), (2, -sqrt(2)), (2, sqrt(2))] + + assert solve_poly_system([y - x**2, y + x**2 + 1], x, y) == \ + [(-I*sqrt(S.Half), Rational(-1, 2)), (I*sqrt(S.Half), Rational(-1, 2))] + + f_1 = x**2 + y + z - 1 + f_2 = x + y**2 + z - 1 + f_3 = x + y + z**2 - 1 + + a, b = sqrt(2) - 1, -sqrt(2) - 1 + + assert solve_poly_system([f_1, f_2, f_3], x, y, z) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + solution = [(1, -1), (1, 1)] + + assert solve_poly_system([Poly(x**2 - y**2), Poly(x - 1)]) == solution + assert solve_poly_system([x**2 - y**2, x - 1], x, y) == solution + assert solve_poly_system([x**2 - y**2, x - 1]) == solution + + assert solve_poly_system( + [x + x*y - 3, y + x*y - 4], x, y) == [(-3, -2), (1, 2)] + + raises(NotImplementedError, lambda: solve_poly_system([x**3 - y**3], x, y)) + raises(NotImplementedError, lambda: solve_poly_system( + [z, -2*x*y**2 + x + y**2*z, y**2*(-z - 4) + 2])) + raises(PolynomialError, lambda: solve_poly_system([1/x], x)) + + raises(NotImplementedError, lambda: solve_poly_system( + [x-1,], (x, y))) + raises(NotImplementedError, lambda: solve_poly_system( + [y-1,], (x, y))) + + # solve_poly_system should ideally construct solutions using + # CRootOf for the following four tests + assert solve_poly_system([x**5 - x + 1], [x], strict=False) == [] + raises(UnsolvableFactorError, lambda: solve_poly_system( + [x**5 - x + 1], [x], strict=True)) + + assert solve_poly_system([(x - 1)*(x**5 - x + 1), y**2 - 1], [x, y], + strict=False) == [(1, -1), (1, 1)] + raises(UnsolvableFactorError, + lambda: solve_poly_system([(x - 1)*(x**5 - x + 1), y**2-1], + [x, y], strict=True)) + + +def test_solve_generic(): + NewOption = Options((x, y), {'domain': 'ZZ'}) + assert solve_generic([x**2 - 2*y**2, y**2 - y + 1], NewOption) == \ + [(-sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2), + (sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2), + (-sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2), + (sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2)] + + # solve_generic should ideally construct solutions using + # CRootOf for the following two tests + assert solve_generic( + [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=False) == \ + [(Rational(1, 2), 1)] + raises(UnsolvableFactorError, lambda: solve_generic( + [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=True)) + + +def test_solve_biquadratic(): + x0, y0, x1, y1, r = symbols('x0 y0 x1 y1 r') + + f_1 = (x - 1)**2 + (y - 1)**2 - r**2 + f_2 = (x - 2)**2 + (y - 2)**2 - r**2 + s = sqrt(2*r**2 - 1) + a = (3 - s)/2 + b = (3 + s)/2 + assert solve_poly_system([f_1, f_2], x, y) == [(a, b), (b, a)] + + f_1 = (x - 1)**2 + (y - 2)**2 - r**2 + f_2 = (x - 1)**2 + (y - 1)**2 - r**2 + + assert solve_poly_system([f_1, f_2], x, y) == \ + [(1 - sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2)), + (1 + sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2))] + + query = lambda expr: expr.is_Pow and expr.exp is S.Half + + f_1 = (x - 1 )**2 + (y - 2)**2 - r**2 + f_2 = (x - x1)**2 + (y - 1)**2 - r**2 + + result = solve_poly_system([f_1, f_2], x, y) + + assert len(result) == 2 and all(len(r) == 2 for r in result) + assert all(r.count(query) == 1 for r in flatten(result)) + + f_1 = (x - x0)**2 + (y - y0)**2 - r**2 + f_2 = (x - x1)**2 + (y - y1)**2 - r**2 + + result = solve_poly_system([f_1, f_2], x, y) + + assert len(result) == 2 and all(len(r) == 2 for r in result) + assert all(len(r.find(query)) == 1 for r in flatten(result)) + + s1 = (x*y - y, x**2 - x) + assert solve(s1) == [{x: 1}, {x: 0, y: 0}] + s2 = (x*y - x, y**2 - y) + assert solve(s2) == [{y: 1}, {x: 0, y: 0}] + gens = (x, y) + for seq in (s1, s2): + (f, g), opt = parallel_poly_from_expr(seq, *gens) + raises(SolveFailed, lambda: solve_biquadratic(f, g, opt)) + seq = (x**2 + y**2 - 2, y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == [ + (-1, -1), (-1, 1), (1, -1), (1, 1)] + ans = [(0, -1), (0, 1)] + seq = (x**2 + y**2 - 1, y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == ans + seq = (x**2 + y**2 - 1, x**2 - x + y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == ans + + +def test_solve_triangulated(): + f_1 = x**2 + y + z - 1 + f_2 = x + y**2 + z - 1 + f_3 = x + y + z**2 - 1 + + a, b = sqrt(2) - 1, -sqrt(2) - 1 + + assert solve_triangulated([f_1, f_2, f_3], x, y, z) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0)] + + dom = QQ.algebraic_field(sqrt(2)) + + assert solve_triangulated([f_1, f_2, f_3], x, y, z, domain=dom) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + a, b = CRootOf(z**2 + 2*z - 1, 0), CRootOf(z**2 + 2*z - 1, 1) + assert solve_triangulated([f_1, f_2, f_3], x, y, z, extension=True) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + +def test_solve_issue_3686(): + roots = solve_poly_system([((x - 5)**2/250000 + (y - Rational(5, 10))**2/250000) - 1, x], x, y) + assert roots == [(0, S.Half - 15*sqrt(1111)), (0, S.Half + 15*sqrt(1111))] + + roots = solve_poly_system([((x - 5)**2/250000 + (y - 5.0/10)**2/250000) - 1, x], x, y) + # TODO: does this really have to be so complicated?! + assert len(roots) == 2 + assert roots[0][0] == 0 + assert roots[0][1].epsilon_eq(-499.474999374969, 1e12) + assert roots[1][0] == 0 + assert roots[1][1].epsilon_eq(500.474999374969, 1e12) + + +def test_factor_system(): + + assert factor_system([x**2 + 2*x + 1]) == [[x + 1]] + assert factor_system([x**2 + 2*x + 1, y**2 + 2*y + 1]) == [[x + 1, y + 1]] + assert factor_system([x**2 + 1]) == [[x**2 + 1]] + assert factor_system([]) == [[]] + + assert factor_system([x**2 + y**2 + 2*x*y, x**2 - 2], extension=sqrt(2)) == [ + [x + y, x + sqrt(2)], + [x + y, x - sqrt(2)], + ] + + assert factor_system([x**2 + 1, y**2 + 1], gaussian=True) == [ + [x + I, y + I], + [x + I, y - I], + [x - I, y + I], + [x - I, y - I], + ] + + assert factor_system([x**2 + 1, y**2 + 1], domain=QQ_I) == [ + [x + I, y + I], + [x + I, y - I], + [x - I, y + I], + [x - I, y - I], + ] + + assert factor_system([0]) == [[]] + assert factor_system([1]) == [] + assert factor_system([0 , x]) == [[x]] + assert factor_system([1, 0, x]) == [] + + assert factor_system([x**4 - 1, y**6 - 1]) == [ + [x**2 + 1, y**2 + y + 1], + [x**2 + 1, y**2 - y + 1], + [x**2 + 1, y + 1], + [x**2 + 1, y - 1], + [x + 1, y**2 + y + 1], + [x + 1, y**2 - y + 1], + [x - 1, y**2 + y + 1], + [x - 1, y**2 - y + 1], + [x + 1, y + 1], + [x + 1, y - 1], + [x - 1, y + 1], + [x - 1, y - 1], + ] + + assert factor_system([(x - 1)*(y - 2), (y - 2)*(z - 3)]) == [ + [x - 1, z - 3], + [y - 2] + ] + + assert factor_system([sin(x)**2 + cos(x)**2 - 1, x]) == [ + [x, sin(x)**2 + cos(x)**2 - 1], + ] + + assert factor_system([sin(x)**2 + cos(x)**2 - 1]) == [ + [sin(x)**2 + cos(x)**2 - 1] + ] + + assert factor_system([sin(x)**2 + cos(x)**2]) == [ + [sin(x)**2 + cos(x)**2] + ] + + assert factor_system([a*x, y, a]) == [[y, a]] + + assert factor_system([a*x, y, a], [x, y]) == [] + + assert factor_system([a ** 2 * x, y], [x, y]) == [[x, y]] + + assert factor_system([a*x*(x - 1), b*y, c], [x, y]) == [] + + assert factor_system([a*x*(x - 1), b*y, c], [x, y, c]) == [ + [x - 1, y, c], + [x, y, c], + ] + + assert factor_system([a*x*(x - 1), b*y, c]) == [ + [x - 1, y, c], + [x, y, c], + [x - 1, b, c], + [x, b, c], + [y, a, c], + [a, b, c], + ] + + assert factor_system([x**2 - 2], [y]) == [] + + assert factor_system([x**2 - 2], [x]) == [[x**2 - 2]] + + assert factor_system([cos(x)**2 - sin(x)**2, cos(x)**2 + sin(x)**2 - 1]) == [ + [sin(x)**2 + cos(x)**2 - 1, sin(x) + cos(x)], + [sin(x)**2 + cos(x)**2 - 1, -sin(x) + cos(x)], + ] + + assert factor_system([(cos(x) + sin(x))**2 - 1, cos(x)**2 - sin(x)**2 - cos(2*x)]) == [ + [sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) + 1], + [sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) - 1], + ] + + assert factor_system([(cos(x) + sin(x))*exp(y) - 1, (cos(x) - sin(x))*exp(y) - 1]) == [ + [exp(y)*sin(x) + exp(y)*cos(x) - 1, -exp(y)*sin(x) + exp(y)*cos(x) - 1] + ] + + +def test_factor_system_poly(): + + px = lambda e: Poly(e, x) + pxab = lambda e: Poly(e, x, domain=ZZ[a, b]) + pxI = lambda e: Poly(e, x, domain=QQ_I) + pxyz = lambda e: Poly(e, (x, y, z)) + + assert factor_system_poly([px(x**2 - 1), px(x**2 - 4)]) == [ + [px(x + 2), px(x + 1)], + [px(x + 2), px(x - 1)], + [px(x + 1), px(x - 2)], + [px(x - 1), px(x - 2)], + ] + + assert factor_system_poly([px(x**2 - 1)]) == [[px(x + 1)], [px(x - 1)]] + + assert factor_system_poly([pxyz(x**2*y - y), pxyz(x**2*z - z)]) == [ + [pxyz(x + 1)], + [pxyz(x - 1)], + [pxyz(y), pxyz(z)], + ] + + assert factor_system_poly([px(x**2*(x - 1)**2), px(x*(x - 1))]) == [ + [px(x)], + [px(x - 1)], + ] + + assert factor_system_poly([pxyz(x**2 + y*x), pxyz(x**2 + z*x)]) == [ + [pxyz(x + y), pxyz(x + z)], + [pxyz(x)], + ] + + assert factor_system_poly([pxab((a - 1)*(x - 2)), pxab((b - 3)*(x - 2))]) == [ + [pxab(x - 2)], + [pxab(a - 1), pxab(b - 3)], + ] + + assert factor_system_poly([pxI(x**2 + 1)]) == [[pxI(x + I)], [pxI(x - I)]] + + assert factor_system_poly([]) == [[]] + + assert factor_system_poly([px(1)]) == [] + assert factor_system_poly([px(0), px(x)]) == [[px(x)]] + + +def test_factor_system_cond(): + + assert factor_system_cond([x ** 2 - 1, x ** 2 - 4]) == [ + [x + 2, x + 1], + [x + 2, x - 1], + [x + 1, x - 2], + [x - 1, x - 2], + ] + + assert factor_system_cond([1]) == [] + assert factor_system_cond([0]) == [[]] + assert factor_system_cond([1, x]) == [] + assert factor_system_cond([0, x]) == [[x]] + assert factor_system_cond([]) == [[]] + + assert factor_system_cond([x**2 + y*x]) == [[x + y], [x]] + + assert factor_system_cond([(a - 1)*(x - 2), (b - 3)*(x - 2)], [x]) == [ + [x - 2], + [a - 1, b - 3], + ] + + assert factor_system_cond([a * (x - 1), b], [x]) == [[x - 1, b], [a, b]] + + assert factor_system_cond([a*x*(x-1), b*y, c], [x, y]) == [ + [x - 1, y, c], + [x, y, c], + [x - 1, b, c], + [x, b, c], + [y, a, c], + [a, b, c], + ] + + assert factor_system_cond([x*(x-1), y], [x, y]) == [[x - 1, y], [x, y]] + + assert factor_system_cond([a*x, y, a], [x, y]) == [[y, a]] + + assert factor_system_cond([a*x, b*x], [x, y]) == [[x], [a, b]] + + assert factor_system_cond([a*b*x, y], [x, y]) == [[x, y], [y, a*b]] + + assert factor_system_cond([a*b*x, y]) == [[x, y], [y, a], [y, b]] + + assert factor_system_cond([a**2*x, y], [x, y]) == [[x, y], [y, a]] + +def test_factor_system_bool(): + + eqs = [a*(x - 1)*(y - 1), b*(x - 2)*(y - 1)*(y - 2)] + assert factor_system_bool(eqs, [x, y]) == ( + Eq(y - 1, 0) + | (Eq(a, 0) & Eq(b, 0)) + | (Eq(a, 0) & Eq(x - 2, 0)) + | (Eq(a, 0) & Eq(y - 2, 0)) + | (Eq(b, 0) & Eq(x - 1, 0)) + | (Eq(x - 2, 0) & Eq(x - 1, 0)) + | (Eq(x - 1, 0) & Eq(y - 2, 0)) + ) + + assert factor_system_bool([x - 1], [x]) == Eq(x - 1, 0) + + assert factor_system_bool([(x - 1)*(x - 2)], [x]) == Eq(x - 2, 0) | Eq(x - 1, 0) + + assert factor_system_bool([], [x]) == True + assert factor_system_bool([0], [x]) == True + assert factor_system_bool([1], [x]) == False + assert factor_system_bool([a], [x]) == Eq(a, 0) + + assert factor_system_bool([a * x, y, a], [x, y]) == Eq(a, 0) & Eq(y, 0) + + assert (factor_system_bool([a*x, b*y*x, a], [x, y]) == ( + Eq(a, 0) & Eq(b, 0)) + | (Eq(a, 0) & Eq(x, 0)) + | (Eq(a, 0) & Eq(y, 0))) + + assert (factor_system_bool([a*x, b*x], [x, y]) == Eq(x, 0) | + (Eq(a, 0) & Eq(b, 0))) + + assert (factor_system_bool([a*b*x, y], [x, y]) == ( + Eq(x, 0) & Eq(y, 0)) | + (Eq(y, 0) & Eq(a*b, 0))) + + assert (factor_system_bool([a**2*x, y], [x, y]) == ( + Eq(a, 0) & Eq(y, 0)) | + (Eq(x, 0) & Eq(y, 0))) + + assert factor_system_bool([a*x*y, b*y*z], [x, y, z]) == ( + Eq(y, 0) + | (Eq(a, 0) & Eq(b, 0)) + | (Eq(a, 0) & Eq(z, 0)) + | (Eq(b, 0) & Eq(x, 0)) + | (Eq(x, 0) & Eq(z, 0)) + ) + + assert factor_system_bool([a*(x - 1), b], [x]) == ( + (Eq(a, 0) & Eq(b, 0)) + | (Eq(x - 1, 0) & Eq(b, 0)) + ) + + +def test_factor_sets(): + # + from random import randint + + def generate_random_system(n_eqs=3, n_factors=2, max_val=10): + return [ + [randint(0, max_val) for _ in range(randint(1, n_factors))] + for _ in range(n_eqs) + ] + + test_cases = [ + [[1, 2], [1, 3]], + [[1, 2], [3, 4]], + [[1], [1, 2], [2]], + ] + + for case in test_cases: + assert _factor_sets(case) == _factor_sets_slow(case) + + for _ in range(100): + system = generate_random_system() + assert _factor_sets(system) == _factor_sets_slow(system) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_recurr.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_recurr.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6306b51a5cf33ccd9fae131430a24690d540a7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_recurr.py @@ -0,0 +1,295 @@ +from sympy.core.function import (Function, Lambda, expand) +from sympy.core.numbers import (I, Rational) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.solvers.recurr import rsolve, rsolve_hyper, rsolve_poly, rsolve_ratio +from sympy.testing.pytest import raises, slow, XFAIL +from sympy.abc import a, b + +y = Function('y') +n, k = symbols('n,k', integer=True) +C0, C1, C2 = symbols('C0,C1,C2') + + +def test_rsolve_poly(): + assert rsolve_poly([-1, -1, 1], 0, n) == 0 + assert rsolve_poly([-1, -1, 1], 1, n) == -1 + + assert rsolve_poly([-1, n + 1], n, n) == 1 + assert rsolve_poly([-1, 1], n, n) == C0 + (n**2 - n)/2 + assert rsolve_poly([-n - 1, n], 1, n) == C0*n - 1 + assert rsolve_poly([-4*n - 2, 1], 4*n + 1, n) == -1 + + assert rsolve_poly([-1, 1], n**5 + n**3, n) == \ + C0 - n**3 / 2 - n**5 / 2 + n**2 / 6 + n**6 / 6 + 2*n**4 / 3 + + +def test_rsolve_ratio(): + solution = rsolve_ratio([-2*n**3 + n**2 + 2*n - 1, 2*n**3 + n**2 - 6*n, + -2*n**3 - 11*n**2 - 18*n - 9, 2*n**3 + 13*n**2 + 22*n + 8], 0, n) + assert solution == C0*(2*n - 3)/(n**2 - 1)/2 + + +def test_rsolve_hyper(): + assert rsolve_hyper([-1, -1, 1], 0, n) in [ + C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n, + C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n, + ] + + assert rsolve_hyper([n**2 - 2, -2*n - 1, 1], 0, n) in [ + C0*rf(sqrt(2), n) + C1*rf(-sqrt(2), n), + C1*rf(sqrt(2), n) + C0*rf(-sqrt(2), n), + ] + + assert rsolve_hyper([n**2 - k, -2*n - 1, 1], 0, n) in [ + C0*rf(sqrt(k), n) + C1*rf(-sqrt(k), n), + C1*rf(sqrt(k), n) + C0*rf(-sqrt(k), n), + ] + + assert rsolve_hyper( + [2*n*(n + 1), -n**2 - 3*n + 2, n - 1], 0, n) == C1*factorial(n) + C0*2**n + + assert rsolve_hyper( + [n + 2, -(2*n + 3)*(17*n**2 + 51*n + 39), n + 1], 0, n) == 0 + + assert rsolve_hyper([-n - 1, -1, 1], 0, n) == 0 + + assert rsolve_hyper([-1, 1], n, n).expand() == C0 + n**2/2 - n/2 + + assert rsolve_hyper([-1, 1], 1 + n, n).expand() == C0 + n**2/2 + n/2 + + assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n + + assert rsolve_hyper([-a, 1],0,n).expand() == C0*a**n + + assert rsolve_hyper([-a, 0, 1], 0, n).expand() == (-1)**n*C1*a**(n/2) + C0*a**(n/2) + + assert rsolve_hyper([1, 1, 1], 0, n).expand() == \ + C0*(Rational(-1, 2) - sqrt(3)*I/2)**n + C1*(Rational(-1, 2) + sqrt(3)*I/2)**n + + assert rsolve_hyper([1, -2*n/a - 2/a, 1], 0, n) == 0 + + +@XFAIL +def test_rsolve_ratio_missed(): + # this arises during computation + # assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n + assert rsolve_ratio([-n, n + 2], n, n) is not None + + +def recurrence_term(c, f): + """Compute RHS of recurrence in f(n) with coefficients in c.""" + return sum(c[i]*f.subs(n, n + i) for i in range(len(c))) + + +def test_rsolve_bulk(): + """Some bulk-generated tests.""" + funcs = [ n, n + 1, n**2, n**3, n**4, n + n**2, 27*n + 52*n**2 - 3* + n**3 + 12*n**4 - 52*n**5 ] + coeffs = [ [-2, 1], [-2, -1, 1], [-1, 1, 1, -1, 1], [-n, 1], [n**2 - + n + 12, 1] ] + for p in funcs: + # compute difference + for c in coeffs: + q = recurrence_term(c, p) + if p.is_polynomial(n): + assert rsolve_poly(c, q, n) == p + # See issue 3956: + if p.is_hypergeometric(n) and len(c) <= 3: + assert rsolve_hyper(c, q, n).subs(zip(symbols('C:3'), [0, 0, 0])).expand() == p + + +def test_rsolve_0_sol_homogeneous(): + # fixed by cherry-pick from + # https://github.com/diofant/diofant/commit/e1d2e52125199eb3df59f12e8944f8a5f24b00a5 + assert rsolve_hyper([n**2 - n + 12, 1], n*(n**2 - n + 12) + n + 1, n) == n + + +def test_rsolve(): + f = y(n + 2) - y(n + 1) - y(n) + h = sqrt(5)*(S.Half + S.Half*sqrt(5))**n \ + - sqrt(5)*(S.Half - S.Half*sqrt(5))**n + + assert rsolve(f, y(n)) in [ + C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n, + C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n, + ] + + assert rsolve(f, y(n), [0, 5]) == h + assert rsolve(f, y(n), {0: 0, 1: 5}) == h + assert rsolve(f, y(n), {y(0): 0, y(1): 5}) == h + assert rsolve(y(n) - y(n - 1) - y(n - 2), y(n), [0, 5]) == h + assert rsolve(Eq(y(n), y(n - 1) + y(n - 2)), y(n), [0, 5]) == h + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n) + g = C1*factorial(n) + C0*2**n + h = -3*factorial(n) + 3*2**n + + assert rsolve(f, y(n)) == g + assert rsolve(f, y(n), []) == g + assert rsolve(f, y(n), {}) == g + + assert rsolve(f, y(n), [0, 3]) == h + assert rsolve(f, y(n), {0: 0, 1: 3}) == h + assert rsolve(f, y(n), {y(0): 0, y(1): 3}) == h + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - y(n - 1) - 2 + + assert rsolve(f, y(n), {y(0): 0}) == 2*n + assert rsolve(f, y(n), {y(0): 1}) == 2*n + 1 + assert rsolve(f, y(n), {y(0): 0, y(1): 1}) is None + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = 3*y(n - 1) - y(n) - 1 + + assert rsolve(f, y(n), {y(0): 0}) == -3**n/2 + S.Half + assert rsolve(f, y(n), {y(0): 1}) == 3**n/2 + S.Half + assert rsolve(f, y(n), {y(0): 2}) == 3*3**n/2 + S.Half + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - 1/n*y(n - 1) + assert rsolve(f, y(n)) == C0/factorial(n) + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - 1/n*y(n - 1) - 1 + assert rsolve(f, y(n)) is None + + f = 2*y(n - 1) + (1 - n)*y(n)/n + + assert rsolve(f, y(n), {y(1): 1}) == 2**(n - 1)*n + assert rsolve(f, y(n), {y(1): 2}) == 2**(n - 1)*n*2 + assert rsolve(f, y(n), {y(1): 3}) == 2**(n - 1)*n*3 + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = (n - 1)*(n - 2)*y(n + 2) - (n + 1)*(n + 2)*y(n) + + assert rsolve(f, y(n), {y(3): 6, y(4): 24}) == n*(n - 1)*(n - 2) + assert rsolve( + f, y(n), {y(3): 6, y(4): -24}) == -n*(n - 1)*(n - 2)*(-1)**(n) + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + assert rsolve(Eq(y(n + 1), a*y(n)), y(n), {y(1): a}).simplify() == a**n + + assert rsolve(y(n) - a*y(n-2),y(n), \ + {y(1): sqrt(a)*(a + b), y(2): a*(a - b)}).simplify() == \ + a**(n/2 + 1) - b*(-sqrt(a))**n + + f = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n) + + yn = rsolve(f, y(n), {y(1): binomial(2*n + 1, 3)}) + sol = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12 + assert factor(expand(yn, func=True)) == sol + + sol = rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n)) + assert str(sol) == 'C0*((-sqrt(1 - a**2) - 1)/a)**n + C1*((sqrt(1 - a**2) - 1)/a)**n' + + assert rsolve((k + 1)*y(k), y(k)) is None + assert (rsolve((k + 1)*y(k) + (k + 3)*y(k + 1) + (k + 5)*y(k + 2), y(k)) + is None) + + assert rsolve(y(n) + y(n + 1) + 2**n + 3**n, y(n)) == (-1)**n*C0 - 2**n/3 - 3**n/4 + + +def test_rsolve_raises(): + x = Function('x') + raises(ValueError, lambda: rsolve(y(n) - y(k + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - y(n + 1), x(n))) + raises(ValueError, lambda: rsolve(y(n) - x(n + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - sqrt(n)*y(n + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - y(n + 1), y(n), {x(0): 0})) + raises(ValueError, lambda: rsolve(y(n) + y(n + 1) + 2**n + cos(n), y(n))) + + +def test_issue_6844(): + f = y(n + 2) - y(n + 1) + y(n)/4 + assert rsolve(f, y(n)) == 2**(-n + 1)*C1*n + 2**(-n)*C0 + assert rsolve(f, y(n), {y(0): 0, y(1): 1}) == 2**(1 - n)*n + + +def test_issue_18751(): + r = Symbol('r', positive=True) + theta = Symbol('theta', real=True) + f = y(n) - 2 * r * cos(theta) * y(n - 1) + r**2 * y(n - 2) + assert rsolve(f, y(n)) == \ + C0*(r*(cos(theta) - I*Abs(sin(theta))))**n + C1*(r*(cos(theta) + I*Abs(sin(theta))))**n + + +def test_constant_naming(): + #issue 8697 + assert rsolve(y(n+3) - y(n+2) - y(n+1) + y(n), y(n)) == (-1)**n*C1 + C0 + C2*n + assert rsolve(y(n+3)+3*y(n+2)+3*y(n+1)+y(n), y(n)).expand() == (-1)**n*C0 - (-1)**n*C1*n - (-1)**n*C2*n**2 + assert rsolve(y(n) - 2*y(n - 3) + 5*y(n - 2) - 4*y(n - 1),y(n),[1,3,8]) == 3*2**n - n - 2 + + #issue 19630 + assert rsolve(y(n+3) - 3*y(n+1) + 2*y(n), y(n), {y(1):0, y(2):8, y(3):-2}) == (-2)**n + 2*n + + +@slow +def test_issue_15751(): + f = y(n) + 21*y(n + 1) - 273*y(n + 2) - 1092*y(n + 3) + 1820*y(n + 4) + 1092*y(n + 5) - 273*y(n + 6) - 21*y(n + 7) + y(n + 8) + assert rsolve(f, y(n)) is not None + + +def test_issue_17990(): + f = -10*y(n) + 4*y(n + 1) + 6*y(n + 2) + 46*y(n + 3) + sol = rsolve(f, y(n)) + expected = C0*((86*18**(S(1)/3)/69 + (-12 + (-1 + sqrt(3)*I)*(290412 + + 3036*sqrt(9165))**(S(1)/3))*(1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))** + (S(1)/3)/276)/((1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3)) + )**n + C1*((86*18**(S(1)/3)/69 + (-12 + (-1 - sqrt(3)*I)*(290412 + 3036 + *sqrt(9165))**(S(1)/3))*(1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))** + (S(1)/3)/276)/((1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3)) + )**n + C2*(-43*18**(S(1)/3)/(69*(24201 + 253*sqrt(9165))**(S(1)/3)) - + S(1)/23 + (290412 + 3036*sqrt(9165))**(S(1)/3)/138)**n + assert sol == expected + e = sol.subs({C0: 1, C1: 1, C2: 1, n: 1}).evalf() + assert abs(e + 0.130434782608696) < 1e-13 + + +def test_issue_8697(): + a = Function('a') + eq = a(n + 3) - a(n + 2) - a(n + 1) + a(n) + assert rsolve(eq, a(n)) == (-1)**n*C1 + C0 + C2*n + eq2 = a(n + 3) + 3*a(n + 2) + 3*a(n + 1) + a(n) + assert (rsolve(eq2, a(n)) == + (-1)**n*C0 + (-1)**(n + 1)*C1*n + (-1)**(n + 1)*C2*n**2) + + assert rsolve(a(n) - 2*a(n - 3) + 5*a(n - 2) - 4*a(n - 1), + a(n), {a(0): 1, a(1): 3, a(2): 8}) == 3*2**n - n - 2 + + # From issue thread (but fixed by https://github.com/diofant/diofant/commit/da9789c6cd7d0c2ceeea19fbf59645987125b289): + assert rsolve(a(n) - 2*a(n - 1) - n, a(n), {a(0): 1}) == 3*2**n - n - 2 + + +def test_diofantissue_294(): + f = y(n) - y(n - 1) - 2*y(n - 2) - 2*n + assert rsolve(f, y(n)) == (-1)**n*C0 + 2**n*C1 - n - Rational(5, 2) + # issue sympy/sympy#11261 + assert rsolve(f, y(n), {y(0): -1, y(1): 1}) == (-(-1)**n/2 + 2*2**n - + n - Rational(5, 2)) + # issue sympy/sympy#7055 + assert rsolve(-2*y(n) + y(n + 1) + n - 1, y(n)) == 2**n*C0 + n + + +def test_issue_15553(): + f = Function("f") + assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n)) == 2**n*C0 - n - 2 + assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n)) == 2**n*C0 - n**2 - 2*n - 4 + assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n), {f(1): 0}) == 7*2**n/2 - n**2 - 2*n - 4 + assert rsolve(Eq(f(n), 2*f(n - 1) + 3*n**2), f(n)) == 2**n*C0 - 3*n**2 - 12*n - 18 + assert rsolve(Eq(f(n), 2*f(n - 1) + n**2), f(n)) == 2**n*C0 - n**2 - 4*n - 6 + assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n), {f(0): 1}) == 3*2**n - n - 2 diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_simplex.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_simplex.py new file mode 100644 index 0000000000000000000000000000000000000000..611205f5df009a6d0de6e687501695b63bb932c9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_simplex.py @@ -0,0 +1,254 @@ +from sympy.core.numbers import Rational +from sympy.core.relational import Eq, Ne +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.core.singleton import S +from sympy.core.random import random, choice +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory.generate import randprime +from sympy.matrices.dense import Matrix +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.solvers.simplex import (_lp as lp, _primal_dual, + UnboundedLPError, InfeasibleLPError, lpmin, lpmax, + _m, _abcd, _simplex, linprog) + +from sympy.external.importtools import import_module + +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z + + +np = import_module("numpy") +scipy = import_module("scipy") + + +def test_lp(): + r1 = y + 2*z <= 3 + r2 = -x - 3*z <= -2 + r3 = 2*x + y + 7*z <= 5 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + objective = -x - y - 5 * z + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + r1 = x - y + 2*z <= 3 + r2 = -x + 2*y - 3*z <= -2 + r3 = 2*x + y - 7*z <= -5 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + objective = -x - y - 5*z + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + r1 = x - y + 2*z <= -4 + r2 = -x + 2*y - 3*z <= 8 + r3 = 2*x + y - 7*z <= 10 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + const = 2 + objective = -x-y-5*z+const # has constant term + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + # Section 4 Problem 1 from + # http://web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_LP.pdf + # answer on page 55 + v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4') + r1 = x1 - x2 - 2*x3 - x4 <= 4 + r2 = 2*x1 + x3 -4*x4 <= 2 + r3 = -2*x1 + x2 + x4 <= 1 + objective, constraints = x1 - 2*x2 - 3*x3 - x4, [r1, r2, r3] + [ + i >= 0 for i in v] + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert ans == (4, {x1: 7, x2: 0, x3: 0, x4: 3}) + + # input contains Floats + r1 = x - y + 2.0*z <= -4 + r2 = -x + 2*y - 3.0*z <= 8 + r3 = 2*x + y - 7*z <= 10 + constraints = [r1, r2, r3] + [i >= 0 for i in (x, y, z)] + objective = -x-y-5*z + optimum, argmax = lp(max, objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + # input contains non-float or non-Rational + r1 = x - y + sqrt(2) * z <= -4 + r2 = -x + 2*y - 3*z <= 8 + r3 = 2*x + y - 7*z <= 10 + raises(TypeError, lambda: lp(max, -x-y-5*z, [r1, r2, r3])) + + r1 = x >= 0 + raises(UnboundedLPError, lambda: lp(max, x, [r1])) + r2 = x <= -1 + raises(InfeasibleLPError, lambda: lp(max, x, [r1, r2])) + + # strict inequalities are not allowed + r1 = x > 0 + raises(TypeError, lambda: lp(max, x, [r1])) + + # not equals not allowed + r1 = Ne(x, 0) + raises(TypeError, lambda: lp(max, x, [r1])) + + def make_random_problem(nvar=2, num_constraints=2, sparsity=.1): + def rand(): + if random() < sparsity: + return sympify(0) + int1, int2 = [randprime(0, 200) for _ in range(2)] + return Rational(int1, int2)*choice([-1, 1]) + variables = symbols('x1:%s' % (nvar + 1)) + constraints = [(sum(rand()*x for x in variables) <= rand()) + for _ in range(num_constraints)] + objective = sum(rand() * x for x in variables) + return objective, constraints, variables + + # equality + r1 = Eq(x, y) + r2 = Eq(y, z) + r3 = z <= 3 + constraints = [r1, r2, r3] + objective = x + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + +def test_simplex(): + L = [ + [[1, 1], [-1, 1], [0, 1], [-1, 0]], + [5, 1, 2, -1], + [[1, 1]], + [-1]] + A, B, C, D = _abcd(_m(*L), list=False) + assert _simplex(A, B, -C, -D) == (-6, [3, 2], [1, 0, 0, 0]) + assert _simplex(A, B, -C, -D, dual=True) == (-6, + [1, 0, 0, 0], [5, 0]) + + assert _simplex([[]],[],[[1]],[0]) == (0, [0], []) + + # handling of Eq (or Eq-like x<=y, x>=y conditions) + assert lpmax(x - y, [x <= y + 2, x >= y + 2, x >= 0, y >= 0] + ) == (2, {x: 2, y: 0}) + assert lpmax(x - y, [x <= y + 2, Eq(x, y + 2), x >= 0, y >= 0] + ) == (2, {x: 2, y: 0}) + assert lpmax(x - y, [x <= y + 2, Eq(x, 2)]) == (2, {x: 2, y: 0}) + assert lpmax(y, [Eq(y, 2)]) == (2, {y: 2}) + + # the conditions are equivalent to Eq(x, y + 2) + assert lpmin(y, [x <= y + 2, x >= y + 2, y >= 0] + ) == (0, {x: 2, y: 0}) + # equivalent to Eq(y, -2) + assert lpmax(y, [0 <= y + 2, 0 >= y + 2]) == (-2, {y: -2}) + assert lpmax(y, [0 <= y + 2, 0 >= y + 2, y <= 0] + ) == (-2, {y: -2}) + + # extra symbols symbols + assert lpmin(x, [y >= 1, x >= y]) == (1, {x: 1, y: 1}) + assert lpmin(x, [y >= 1, x >= y + z, x >= 0, z >= 0] + ) == (1, {x: 1, y: 1, z: 0}) + + # detect oscillation + # o1 + v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4') + raises(InfeasibleLPError, lambda: lpmin( + 9*x2 - 8*x3 + 3*x4 + 6, + [5*x2 - 2*x3 <= 0, + -x1 - 8*x2 + 9*x3 <= -3, + 10*x1 - x2+ 9*x4 <= -4] + [i >= 0 for i in v])) + # o2 - equations fed to lpmin are changed into a matrix + # system that doesn't oscillate and has the same solution + # as below + M = linear_eq_to_matrix + f = 5*x2 + x3 + 4*x4 - x1 + L = 5*x2 + 2*x3 + 5*x4 - (x1 + 5) + cond = [L <= 0] + [Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)] + c, d = M(f, v) + a, b = M(L, v) + aeq, beq = M(cond[1:], v) + ans = (S(9)/2, [0, S(1)/2, 0, S(1)/2]) + assert linprog(c, a, b, aeq, beq, bounds=(0, 1)) == ans + lpans = lpmin(f, cond + [x1 >= 0, x1 <= 1, + x2 >= 0, x2 <= 1, x3 >= 0, x3 <= 1, x4 >= 0, x4 <= 1]) + assert (lpans[0], list(lpans[1].values())) == ans + + +def test_lpmin_lpmax(): + v = x1, x2, y1, y2 = symbols('x1 x2 y1 y2') + L = [[1, -1]], [1], [[1, 1]], [2] + a, b, c, d = [Matrix(i) for i in L] + m = Matrix([[a, b], [c, d]]) + f, constr = _primal_dual(m)[0] + ans = lpmin(f, constr + [i >= 0 for i in v[:2]]) + assert ans == (-1, {x1: 1, x2: 0}),ans + + L = [[1, -1], [1, 1]], [1, 1], [[1, 1]], [2] + a, b, c, d = [Matrix(i) for i in L] + m = Matrix([[a, b], [c, d]]) + f, constr = _primal_dual(m)[1] + ans = lpmax(f, constr + [i >= 0 for i in v[-2:]]) + assert ans == (-1, {y1: 1, y2: 0}) + + +def test_linprog(): + for do in range(2): + if not do: + M = lambda a, b: linear_eq_to_matrix(a, b) + else: + # check matrices as list + M = lambda a, b: tuple([ + i.tolist() for i in linear_eq_to_matrix(a, b)]) + + v = x, y, z = symbols('x1:4') + f = x + y - 2*z + c = M(f, v)[0] + ineq = [7*x + 4*y - 7*z <= 3, + 3*x - y + 10*z <= 6, + x >= 0, y >= 0, z >= 0] + ab = M([i.lts - i.gts for i in ineq], v) + ans = (-S(6)/5, [0, 0, S(3)/5]) + assert lpmin(f, ineq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab) == ans + + f += 1 + c = M(f, v)[0] + eq = [Eq(y - 9*x, 1)] + abeq = M([i.lhs - i.rhs for i in eq], v) + ans = (1 - S(2)/5, [0, 1, S(7)/10]) + assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1]) + + eq = [z - y <= S.Half] + abeq = M([i.lhs - i.rhs for i in eq], v) + ans = (1 - S(10)/9, [0, S(1)/9, S(11)/18]) + assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1]) + + bounds = [(0, None), (0, None), (None, S.Half)] + ans = (0, [0, 0, S.Half]) + assert lpmin(f, ineq + [z <= S.Half]) == ( + ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, bounds=bounds) == (ans[0] - 1, ans[1]) + assert linprog(c, *ab, bounds={v.index(z): bounds[-1]} + ) == (ans[0] - 1, ans[1]) + eq = [z - y <= S.Half] + + assert linprog([[1]], [], [], bounds=(2, 3)) == (2, [2]) + assert linprog([1], [], [], bounds=(2, 3)) == (2, [2]) + assert linprog([1], bounds=(2, 3)) == (2, [2]) + assert linprog([1, -1], [[1, 1]], [2], bounds={1:(None, None)} + ) == (-2, [0, 2]) + assert linprog([1, -1], [[1, 1]], [5], bounds={1:(3, None)} + ) == (-5, [0, 5]) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_solvers.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..ac9550ad404c2ec7592caf6afd2910f425138987 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_solvers.py @@ -0,0 +1,2725 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core import (GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Rational, oo, pi) +from sympy.core.relational import (Eq, Gt, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (atanh, cosh, sinh, tanh) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import (cbrt, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, atan2, cos, sec, sin, tan) +from sympy.functions.special.error_functions import (erf, erfc, erfcinv, erfinv) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import Matrix +from sympy.matrices import MatrixSymbol, SparseMatrix +from sympy.polys.polytools import Poly, groebner +from sympy.printing.str import sstr +from sympy.simplify.radsimp import denom +from sympy.solvers.solvers import (nsolve, solve, solve_linear) + +from sympy.core.function import nfloat +from sympy.solvers import solve_linear_system, solve_linear_system_LU, \ + solve_undetermined_coeffs +from sympy.solvers.bivariate import _filtered_gens, _solve_lambert, _lambert +from sympy.solvers.solvers import _invert, unrad, checksol, posify, _ispow, \ + det_quick, det_perm, det_minor, _simple_dens, denoms + +from sympy.physics.units import cm +from sympy.polys.rootoftools import CRootOf + +from sympy.testing.pytest import slow, XFAIL, SKIP, raises +from sympy.core.random import verify_numerically as tn + +from sympy.abc import a, b, c, d, e, k, h, p, x, y, z, t, q, m, R + + +def NS(e, n=15, **options): + return sstr(sympify(e).evalf(n, **options), full_prec=True) + + +def test_swap_back(): + f, g = map(Function, 'fg') + fx, gx = f(x), g(x) + assert solve([fx + y - 2, fx - gx - 5], fx, y, gx) == \ + {fx: gx + 5, y: -gx - 3} + assert solve(fx + gx*x - 2, [fx, gx], dict=True) == [{fx: 2, gx: 0}] + assert solve(fx + gx**2*x - y, [fx, gx], dict=True) == [{fx: y, gx: 0}] + assert solve([f(1) - 2, x + 2], dict=True) == [{x: -2, f(1): 2}] + + +def guess_solve_strategy(eq, symbol): + try: + solve(eq, symbol) + return True + except (TypeError, NotImplementedError): + return False + + +def test_guess_poly(): + # polynomial equations + assert guess_solve_strategy( S(4), x ) # == GS_POLY + assert guess_solve_strategy( x, x ) # == GS_POLY + assert guess_solve_strategy( x + a, x ) # == GS_POLY + assert guess_solve_strategy( 2*x, x ) # == GS_POLY + assert guess_solve_strategy( x + sqrt(2), x) # == GS_POLY + assert guess_solve_strategy( x + 2**Rational(1, 4), x) # == GS_POLY + assert guess_solve_strategy( x**2 + 1, x ) # == GS_POLY + assert guess_solve_strategy( x**2 - 1, x ) # == GS_POLY + assert guess_solve_strategy( x*y + y, x ) # == GS_POLY + assert guess_solve_strategy( x*exp(y) + y, x) # == GS_POLY + assert guess_solve_strategy( + (x - y**3)/(y**2*sqrt(1 - y**2)), x) # == GS_POLY + + +def test_guess_poly_cv(): + # polynomial equations via a change of variable + assert guess_solve_strategy( sqrt(x) + 1, x ) # == GS_POLY_CV_1 + assert guess_solve_strategy( + x**Rational(1, 3) + sqrt(x) + 1, x ) # == GS_POLY_CV_1 + assert guess_solve_strategy( 4*x*(1 - sqrt(x)), x ) # == GS_POLY_CV_1 + + # polynomial equation multiplying both sides by x**n + assert guess_solve_strategy( x + 1/x + y, x ) # == GS_POLY_CV_2 + + +def test_guess_rational_cv(): + # rational functions + assert guess_solve_strategy( (x + 1)/(x**2 + 2), x) # == GS_RATIONAL + assert guess_solve_strategy( + (x - y**3)/(y**2*sqrt(1 - y**2)), y) # == GS_RATIONAL_CV_1 + + # rational functions via the change of variable y -> x**n + assert guess_solve_strategy( (sqrt(x) + 1)/(x**Rational(1, 3) + sqrt(x) + 1), x ) \ + #== GS_RATIONAL_CV_1 + + +def test_guess_transcendental(): + #transcendental functions + assert guess_solve_strategy( exp(x) + 1, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy( 2*cos(x) - y, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy( + exp(x) + exp(-x) - y, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy(3**x - 10, x) # == GS_TRANSCENDENTAL + assert guess_solve_strategy(-3**x + 10, x) # == GS_TRANSCENDENTAL + + assert guess_solve_strategy(a*x**b - y, x) # == GS_TRANSCENDENTAL + + +def test_solve_args(): + # equation container, issue 5113 + ans = {x: -3, y: 1} + eqs = (x + 5*y - 2, -3*x + 6*y - 15) + assert all(solve(container(eqs), x, y) == ans for container in + (tuple, list, set, frozenset)) + assert solve(Tuple(*eqs), x, y) == ans + # implicit symbol to solve for + assert set(solve(x**2 - 4)) == {S(2), -S(2)} + assert solve([x + y - 3, x - y - 5]) == {x: 4, y: -1} + assert solve(x - exp(x), x, implicit=True) == [exp(x)] + # no symbol to solve for + assert solve(42) == solve(42, x) == [] + assert solve([1, 2]) == [] + assert solve([sqrt(2)],[x]) == [] + # duplicate symbols raises + raises(ValueError, lambda: solve((x - 3, y + 2), x, y, x)) + raises(ValueError, lambda: solve(x, x, x)) + # no error in exclude + assert solve(x, x, exclude=[y, y]) == [0] + # duplicate symbols raises + raises(ValueError, lambda: solve((x - 3, y + 2), x, y, x)) + raises(ValueError, lambda: solve(x, x, x)) + # no error in exclude + assert solve(x, x, exclude=[y, y]) == [0] + # unordered symbols + # only 1 + assert solve(y - 3, {y}) == [3] + # more than 1 + assert solve(y - 3, {x, y}) == [{y: 3}] + # multiple symbols: take the first linear solution+ + # - return as tuple with values for all requested symbols + assert solve(x + y - 3, [x, y]) == [(3 - y, y)] + # - unless dict is True + assert solve(x + y - 3, [x, y], dict=True) == [{x: 3 - y}] + # - or no symbols are given + assert solve(x + y - 3) == [{x: 3 - y}] + # multiple symbols might represent an undetermined coefficients system + assert solve(a + b*x - 2, [a, b]) == {a: 2, b: 0} + assert solve((a + b)*x + b - c, [a, b]) == {a: -c, b: c} + eq = a*x**2 + b*x + c - ((x - h)**2 + 4*p*k)/4/p + # - check that flags are obeyed + sol = solve(eq, [h, p, k], exclude=[a, b, c]) + assert sol == {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + assert solve(eq, [h, p, k], dict=True) == [sol] + assert solve(eq, [h, p, k], set=True) == \ + ([h, p, k], {(-b/(2*a), 1/(4*a), (4*a*c - b**2)/(4*a))}) + # issue 23889 - polysys not simplified + assert solve(eq, [h, p, k], exclude=[a, b, c], simplify=False) == \ + {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + # but this only happens when system has a single solution + args = (a + b)*x - b**2 + 2, a, b + assert solve(*args) == [((b**2 - b*x - 2)/x, b)] + # and if the system has a solution; the following doesn't so + # an algebraic solution is returned + assert solve(a*x + b**2/(x + 4) - 3*x - 4/x, a, b, dict=True) == \ + [{a: (-b**2*x + 3*x**3 + 12*x**2 + 4*x + 16)/(x**2*(x + 4))}] + # failed single equation + assert solve(1/(1/x - y + exp(y))) == [] + raises( + NotImplementedError, lambda: solve(exp(x) + sin(x) + exp(y) + sin(y))) + # failed system + # -- when no symbols given, 1 fails + assert solve([y, exp(x) + x]) == [{x: -LambertW(1), y: 0}] + # both fail + assert solve( + (exp(x) - x, exp(y) - y)) == [{x: -LambertW(-1), y: -LambertW(-1)}] + # -- when symbols given + assert solve([y, exp(x) + x], x, y) == [(-LambertW(1), 0)] + # symbol is a number + assert solve(x**2 - pi, pi) == [x**2] + # no equations + assert solve([], [x]) == [] + # nonlinear system + assert solve((x**2 - 4, y - 2), x, y) == [(-2, 2), (2, 2)] + assert solve((x**2 - 4, y - 2), y, x) == [(2, -2), (2, 2)] + assert solve((x**2 - 4 + z, y - 2 - z), a, z, y, x, set=True + ) == ([a, z, y, x], { + (a, z, z + 2, -sqrt(4 - z)), + (a, z, z + 2, sqrt(4 - z))}) + # overdetermined system + # - nonlinear + assert solve([(x + y)**2 - 4, x + y - 2]) == [{x: -y + 2}] + # - linear + assert solve((x + y - 2, 2*x + 2*y - 4)) == {x: -y + 2} + # When one or more args are Boolean + assert solve(Eq(x**2, 0.0)) == [0.0] # issue 19048 + assert solve([True, Eq(x, 0)], [x], dict=True) == [{x: 0}] + assert solve([Eq(x, x), Eq(x, 0), Eq(x, x+1)], [x], dict=True) == [] + assert not solve([Eq(x, x+1), x < 2], x) + assert solve([Eq(x, 0), x+1<2]) == Eq(x, 0) + assert solve([Eq(x, x), Eq(x, x+1)], x) == [] + assert solve(True, x) == [] + assert solve([x - 1, False], [x], set=True) == ([], set()) + assert solve([-y*(x + y - 1)/2, (y - 1)/x/y + 1/y], + set=True, check=False) == ([x, y], {(1 - y, y), (x, 0)}) + # ordering should be canonical, fastest to order by keys instead + # of by size + assert list(solve((y - 1, x - sqrt(3)*z)).keys()) == [x, y] + # as set always returns as symbols, set even if no solution + assert solve([x - 1, x], (y, x), set=True) == ([y, x], set()) + assert solve([x - 1, x], {y, x}, set=True) == ([x, y], set()) + + +def test_solve_polynomial1(): + assert solve(3*x - 2, x) == [Rational(2, 3)] + assert solve(Eq(3*x, 2), x) == [Rational(2, 3)] + + assert set(solve(x**2 - 1, x)) == {-S.One, S.One} + assert set(solve(Eq(x**2, 1), x)) == {-S.One, S.One} + + assert solve(x - y**3, x) == [y**3] + rx = root(x, 3) + assert solve(x - y**3, y) == [ + rx, -rx/2 - sqrt(3)*I*rx/2, -rx/2 + sqrt(3)*I*rx/2] + a11, a12, a21, a22, b1, b2 = symbols('a11,a12,a21,a22,b1,b2') + + assert solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) == \ + { + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21), + y: (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + } + + solution = {x: S.Zero, y: S.Zero} + + assert solve((x - y, x + y), x, y ) == solution + assert solve((x - y, x + y), (x, y)) == solution + assert solve((x - y, x + y), [x, y]) == solution + + assert set(solve(x**3 - 15*x - 4, x)) == { + -2 + 3**S.Half, + S(4), + -2 - 3**S.Half + } + + assert set(solve((x**2 - 1)**2 - a, x)) == \ + {sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), + sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))} + + +def test_solve_polynomial2(): + assert solve(4, x) == [] + + +def test_solve_polynomial_cv_1a(): + """ + Test for solving on equations that can be converted to a polynomial equation + using the change of variable y -> x**Rational(p, q) + """ + assert solve( sqrt(x) - 1, x) == [1] + assert solve( sqrt(x) - 2, x) == [4] + assert solve( x**Rational(1, 4) - 2, x) == [16] + assert solve( x**Rational(1, 3) - 3, x) == [27] + assert solve(sqrt(x) + x**Rational(1, 3) + x**Rational(1, 4), x) == [0] + + +def test_solve_polynomial_cv_1b(): + assert set(solve(4*x*(1 - a*sqrt(x)), x)) == {S.Zero, 1/a**2} + assert set(solve(x*(root(x, 3) - 3), x)) == {S.Zero, S(27)} + + +def test_solve_polynomial_cv_2(): + """ + Test for solving on equations that can be converted to a polynomial equation + multiplying both sides of the equation by x**m + """ + assert solve(x + 1/x - 1, x) in \ + [[ S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2], + [ S.Half - I*sqrt(3)/2, S.Half + I*sqrt(3)/2]] + + +def test_quintics_1(): + f = x**5 - 110*x**3 - 55*x**2 + 2310*x + 979 + s = solve(f, check=False) + for r in s: + res = f.subs(x, r.n()).n() + assert tn(res, 0) + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = solve(f) + for r in s: + assert r.func == CRootOf + + # if one uses solve to get the roots of a polynomial that has a CRootOf + # solution, make sure that the use of nfloat during the solve process + # doesn't fail. Note: if you want numerical solutions to a polynomial + # it is *much* faster to use nroots to get them than to solve the + # equation only to get RootOf solutions which are then numerically + # evaluated. So for eq = x**5 + 3*x + 7 do Poly(eq).nroots() rather + # than [i.n() for i in solve(eq)] to get the numerical roots of eq. + assert nfloat(solve(x**5 + 3*x**3 + 7)[0], exponent=False) == \ + CRootOf(x**5 + 3*x**3 + 7, 0).n() + + +def test_quintics_2(): + f = x**5 + 15*x + 12 + s = solve(f, check=False) + for r in s: + res = f.subs(x, r.n()).n() + assert tn(res, 0) + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = solve(f) + for r in s: + assert r.func == CRootOf + + assert solve(x**5 - 6*x**3 - 6*x**2 + x - 6) == [ + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 0), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 1), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 2), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 3), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 4)] + +def test_quintics_3(): + y = x**5 + x**3 - 2**Rational(1, 3) + assert solve(y) == solve(-y) == [] + + +def test_highorder_poly(): + # just testing that the uniq generator is unpacked + sol = solve(x**6 - 2*x + 2) + assert all(isinstance(i, CRootOf) for i in sol) and len(sol) == 6 + + +def test_solve_rational(): + """Test solve for rational functions""" + assert solve( ( x - y**3 )/( (y**2)*sqrt(1 - y**2) ), x) == [y**3] + + +def test_solve_conjugate(): + """Test solve for simple conjugate functions""" + assert solve(conjugate(x) -3 + I) == [3 + I] + + +def test_solve_nonlinear(): + assert solve(x**2 - y**2, x, y, dict=True) == [{x: -y}, {x: y}] + assert solve(x**2 - y**2/exp(x), y, x, dict=True) == [{y: -x*sqrt(exp(x))}, + {y: x*sqrt(exp(x))}] + + +def test_issue_8666(): + x = symbols('x') + assert solve(Eq(x**2 - 1/(x**2 - 4), 4 - 1/(x**2 - 4)), x) == [] + assert solve(Eq(x + 1/x, 1/x), x) == [] + + +def test_issue_7228(): + assert solve(4**(2*(x**2) + 2*x) - 8, x) == [Rational(-3, 2), S.Half] + + +def test_issue_7190(): + assert solve(log(x-3) + log(x+3), x) == [sqrt(10)] + + +def test_issue_21004(): + x = symbols('x') + f = x/sqrt(x**2+1) + f_diff = f.diff(x) + assert solve(f_diff, x) == [] + + +def test_issue_24650(): + x = symbols('x') + r = solve(Eq(Piecewise((x, Eq(x, 0) | (x > 1))), 0)) + assert r == [0] + r = checksol(Eq(Piecewise((x, Eq(x, 0) | (x > 1))), 0), x, sol=0) + assert r is True + + +def test_linear_system(): + x, y, z, t, n = symbols('x, y, z, t, n') + + assert solve([x - 1, x - y, x - 2*y, y - 1], [x, y]) == [] + + assert solve([x - 1, x - y, x - 2*y, x - 1], [x, y]) == [] + assert solve([x - 1, x - 1, x - y, x - 2*y], [x, y]) == [] + + assert solve([x + 5*y - 2, -3*x + 6*y - 15], x, y) == {x: -3, y: 1} + + M = Matrix([[0, 0, n*(n + 1), (n + 1)**2, 0], + [n + 1, n + 1, -2*n - 1, -(n + 1), 0], + [-1, 0, 1, 0, 0]]) + + assert solve_linear_system(M, x, y, z, t) == \ + {x: t*(-n-1)/n, y: 0, z: t*(-n-1)/n} + + assert solve([x + y + z + t, -z - t], x, y, z, t) == {x: -y, z: -t} + + +@XFAIL +def test_linear_system_xfail(): + # https://github.com/sympy/sympy/issues/6420 + M = Matrix([[0, 15.0, 10.0, 700.0], + [1, 1, 1, 100.0], + [0, 10.0, 5.0, 200.0], + [-5.0, 0, 0, 0 ]]) + + assert solve_linear_system(M, x, y, z) == {x: 0, y: -60.0, z: 160.0} + + +def test_linear_system_function(): + a = Function('a') + assert solve([a(0, 0) + a(0, 1) + a(1, 0) + a(1, 1), -a(1, 0) - a(1, 1)], + a(0, 0), a(0, 1), a(1, 0), a(1, 1)) == {a(1, 0): -a(1, 1), a(0, 0): -a(0, 1)} + + +def test_linear_system_symbols_doesnt_hang_1(): + + def _mk_eqs(wy): + # Equations for fitting a wy*2 - 1 degree polynomial between two points, + # at end points derivatives are known up to order: wy - 1 + order = 2*wy - 1 + x, x0, x1 = symbols('x, x0, x1', real=True) + y0s = symbols('y0_:{}'.format(wy), real=True) + y1s = symbols('y1_:{}'.format(wy), real=True) + c = symbols('c_:{}'.format(order+1), real=True) + + expr = sum(coeff*x**o for o, coeff in enumerate(c)) + eqs = [] + for i in range(wy): + eqs.append(expr.diff(x, i).subs({x: x0}) - y0s[i]) + eqs.append(expr.diff(x, i).subs({x: x1}) - y1s[i]) + return eqs, c + + # + # The purpose of this test is just to see that these calls don't hang. The + # expressions returned are complicated so are not included here. Testing + # their correctness takes longer than solving the system. + # + + for n in range(1, 7+1): + eqs, c = _mk_eqs(n) + solve(eqs, c) + + +def test_linear_system_symbols_doesnt_hang_2(): + + M = Matrix([ + [66, 24, 39, 50, 88, 40, 37, 96, 16, 65, 31, 11, 37, 72, 16, 19, 55, 37, 28, 76], + [10, 93, 34, 98, 59, 44, 67, 74, 74, 94, 71, 61, 60, 23, 6, 2, 57, 8, 29, 78], + [19, 91, 57, 13, 64, 65, 24, 53, 77, 34, 85, 58, 87, 39, 39, 7, 36, 67, 91, 3], + [74, 70, 15, 53, 68, 43, 86, 83, 81, 72, 25, 46, 67, 17, 59, 25, 78, 39, 63, 6], + [69, 40, 67, 21, 67, 40, 17, 13, 93, 44, 46, 89, 62, 31, 30, 38, 18, 20, 12, 81], + [50, 22, 74, 76, 34, 45, 19, 76, 28, 28, 11, 99, 97, 82, 8, 46, 99, 57, 68, 35], + [58, 18, 45, 88, 10, 64, 9, 34, 90, 82, 17, 41, 43, 81, 45, 83, 22, 88, 24, 39], + [42, 21, 70, 68, 6, 33, 64, 81, 83, 15, 86, 75, 86, 17, 77, 34, 62, 72, 20, 24], + [ 7, 8, 2, 72, 71, 52, 96, 5, 32, 51, 31, 36, 79, 88, 25, 77, 29, 26, 33, 13], + [19, 31, 30, 85, 81, 39, 63, 28, 19, 12, 16, 49, 37, 66, 38, 13, 3, 71, 61, 51], + [29, 82, 80, 49, 26, 85, 1, 37, 2, 74, 54, 82, 26, 47, 54, 9, 35, 0, 99, 40], + [15, 49, 82, 91, 93, 57, 45, 25, 45, 97, 15, 98, 48, 52, 66, 24, 62, 54, 97, 37], + [62, 23, 73, 53, 52, 86, 28, 38, 0, 74, 92, 38, 97, 70, 71, 29, 26, 90, 67, 45], + [ 2, 32, 23, 24, 71, 37, 25, 71, 5, 41, 97, 65, 93, 13, 65, 45, 25, 88, 69, 50], + [40, 56, 1, 29, 79, 98, 79, 62, 37, 28, 45, 47, 3, 1, 32, 74, 98, 35, 84, 32], + [33, 15, 87, 79, 65, 9, 14, 63, 24, 19, 46, 28, 74, 20, 29, 96, 84, 91, 93, 1], + [97, 18, 12, 52, 1, 2, 50, 14, 52, 76, 19, 82, 41, 73, 51, 79, 13, 3, 82, 96], + [40, 28, 52, 10, 10, 71, 56, 78, 82, 5, 29, 48, 1, 26, 16, 18, 50, 76, 86, 52], + [38, 89, 83, 43, 29, 52, 90, 77, 57, 0, 67, 20, 81, 88, 48, 96, 88, 58, 14, 3]]) + + syms = x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18 = symbols('x:19') + + sol = { + x0: -S(1967374186044955317099186851240896179)/3166636564687820453598895768302256588, + x1: -S(84268280268757263347292368432053826)/791659141171955113399723942075564147, + x2: -S(229962957341664730974463872411844965)/1583318282343910226799447884151128294, + x3: S(990156781744251750886760432229180537)/6333273129375640907197791536604513176, + x4: -S(2169830351210066092046760299593096265)/18999819388126922721593374609813539528, + x5: S(4680868883477577389628494526618745355)/9499909694063461360796687304906769764, + x6: -S(1590820774344371990683178396480879213)/3166636564687820453598895768302256588, + x7: -S(54104723404825537735226491634383072)/339282489073695048599881689460956063, + x8: S(3182076494196560075964847771774733847)/6333273129375640907197791536604513176, + x9: -S(10870817431029210431989147852497539675)/18999819388126922721593374609813539528, + x10: -S(13118019242576506476316318268573312603)/18999819388126922721593374609813539528, + x11: -S(5173852969886775824855781403820641259)/4749954847031730680398343652453384882, + x12: S(4261112042731942783763341580651820563)/4749954847031730680398343652453384882, + x13: -S(821833082694661608993818117038209051)/6333273129375640907197791536604513176, + x14: S(906881575107250690508618713632090559)/904753304196520129599684505229216168, + x15: -S(732162528717458388995329317371283987)/6333273129375640907197791536604513176, + x16: S(4524215476705983545537087360959896817)/9499909694063461360796687304906769764, + x17: -S(3898571347562055611881270844646055217)/6333273129375640907197791536604513176, + x18: S(7513502486176995632751685137907442269)/18999819388126922721593374609813539528 + } + + eqs = list(M * Matrix(syms + (1,))) + assert solve(eqs, syms) == sol + + y = Symbol('y') + eqs = list(y * M * Matrix(syms + (1,))) + assert solve(eqs, syms) == sol + + +def test_linear_systemLU(): + n = Symbol('n') + + M = Matrix([[1, 2, 0, 1], [1, 3, 2*n, 1], [4, -1, n**2, 1]]) + + assert solve_linear_system_LU(M, [x, y, z]) == {z: -3/(n**2 + 18*n), + x: 1 - 12*n/(n**2 + 18*n), + y: 6*n/(n**2 + 18*n)} + +# Note: multiple solutions exist for some of these equations, so the tests +# should be expected to break if the implementation of the solver changes +# in such a way that a different branch is chosen + +@slow +def test_solve_transcendental(): + from sympy.abc import a, b + + assert solve(exp(x) - 3, x) == [log(3)] + assert set(solve((a*x + b)*(exp(x) - 3), x)) == {-b/a, log(3)} + assert solve(cos(x) - y, x) == [-acos(y) + 2*pi, acos(y)] + assert solve(2*cos(x) - y, x) == [-acos(y/2) + 2*pi, acos(y/2)] + assert solve(Eq(cos(x), sin(x)), x) == [pi/4] + + assert set(solve(exp(x) + exp(-x) - y, x)) in [{ + log(y/2 - sqrt(y**2 - 4)/2), + log(y/2 + sqrt(y**2 - 4)/2), + }, { + log(y - sqrt(y**2 - 4)) - log(2), + log(y + sqrt(y**2 - 4)) - log(2)}, + { + log(y/2 - sqrt((y - 2)*(y + 2))/2), + log(y/2 + sqrt((y - 2)*(y + 2))/2)}] + assert solve(exp(x) - 3, x) == [log(3)] + assert solve(Eq(exp(x), 3), x) == [log(3)] + assert solve(log(x) - 3, x) == [exp(3)] + assert solve(sqrt(3*x) - 4, x) == [Rational(16, 3)] + assert solve(3**(x + 2), x) == [] + assert solve(3**(2 - x), x) == [] + assert solve(x + 2**x, x) == [-LambertW(log(2))/log(2)] + assert solve(2*x + 5 + log(3*x - 2), x) == \ + [Rational(2, 3) + LambertW(2*exp(Rational(-19, 3))/3)/2] + assert solve(3*x + log(4*x), x) == [LambertW(Rational(3, 4))/3] + assert set(solve((2*x + 8)*(8 + exp(x)), x)) == {S(-4), log(8) + pi*I} + eq = 2*exp(3*x + 4) - 3 + ans = solve(eq, x) # this generated a failure in flatten + assert len(ans) == 3 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) + assert solve(2*log(3*x + 4) - 3, x) == [(exp(Rational(3, 2)) - 4)/3] + assert solve(exp(x) + 1, x) == [pi*I] + + eq = 2*(3*x + 4)**5 - 6*7**(3*x + 9) + result = solve(eq, x) + x0 = -log(2401) + x1 = 3**Rational(1, 5) + x2 = log(7**(7*x1/20)) + x3 = sqrt(2) + x4 = sqrt(5) + x5 = x3*sqrt(x4 - 5) + x6 = x4 + 1 + x7 = 1/(3*log(7)) + x8 = -x4 + x9 = x3*sqrt(x8 - 5) + x10 = x8 + 1 + ans = [x7*(x0 - 5*LambertW(x2*(-x5 + x6))), + x7*(x0 - 5*LambertW(x2*(x5 + x6))), + x7*(x0 - 5*LambertW(x2*(x10 - x9))), + x7*(x0 - 5*LambertW(x2*(x10 + x9))), + x7*(x0 - 5*LambertW(-log(7**(7*x1/5))))] + assert result == ans, result + # it works if expanded, too + assert solve(eq.expand(), x) == result + + assert solve(z*cos(x) - y, x) == [-acos(y/z) + 2*pi, acos(y/z)] + assert solve(z*cos(2*x) - y, x) == [-acos(y/z)/2 + pi, acos(y/z)/2] + assert solve(z*cos(sin(x)) - y, x) == [ + pi - asin(acos(y/z)), asin(acos(y/z) - 2*pi) + pi, + -asin(acos(y/z) - 2*pi), asin(acos(y/z))] + + assert solve(z*cos(x), x) == [pi/2, pi*Rational(3, 2)] + + # issue 4508 + assert solve(y - b*x/(a + x), x) in [[-a*y/(y - b)], [a*y/(b - y)]] + assert solve(y - b*exp(a/x), x) == [a/log(y/b)] + # issue 4507 + assert solve(y - b/(1 + a*x), x) in [[(b - y)/(a*y)], [-((y - b)/(a*y))]] + # issue 4506 + assert solve(y - a*x**b, x) == [(y/a)**(1/b)] + # issue 4505 + assert solve(z**x - y, x) == [log(y)/log(z)] + # issue 4504 + assert solve(2**x - 10, x) == [1 + log(5)/log(2)] + # issue 6744 + assert solve(x*y) == [{x: 0}, {y: 0}] + assert solve([x*y]) == [{x: 0}, {y: 0}] + assert solve(x**y - 1) == [{x: 1}, {y: 0}] + assert solve([x**y - 1]) == [{x: 1}, {y: 0}] + assert solve(x*y*(x**2 - y**2)) == [{x: 0}, {x: -y}, {x: y}, {y: 0}] + assert solve([x*y*(x**2 - y**2)]) == [{x: 0}, {x: -y}, {x: y}, {y: 0}] + # issue 4739 + assert solve(exp(log(5)*x) - 2**x, x) == [0] + # issue 14791 + assert solve(exp(log(5)*x) - exp(log(2)*x), x) == [0] + f = Function('f') + assert solve(y*f(log(5)*x) - y*f(log(2)*x), x) == [0] + assert solve(f(x) - f(0), x) == [0] + assert solve(f(x) - f(2 - x), x) == [1] + raises(NotImplementedError, lambda: solve(f(x, y) - f(1, 2), x)) + raises(NotImplementedError, lambda: solve(f(x, y) - f(2 - x, 2), x)) + raises(ValueError, lambda: solve(f(x, y) - f(1 - x), x)) + raises(ValueError, lambda: solve(f(x, y) - f(1), x)) + + # misc + # make sure that the right variables is picked up in tsolve + # shouldn't generate a GeneratorsNeeded error in _tsolve when the NaN is generated + # for eq_down. Actual answers, as determined numerically are approx. +/- 0.83 + raises(NotImplementedError, lambda: + solve(sinh(x)*sinh(sinh(x)) + cosh(x)*cosh(sinh(x)) - 3)) + + # watch out for recursive loop in tsolve + raises(NotImplementedError, lambda: solve((x + 2)**y*x - 3, x)) + + # issue 7245 + assert solve(sin(sqrt(x))) == [0, pi**2] + + # issue 7602 + a, b = symbols('a, b', real=True, negative=False) + assert str(solve(Eq(a, 0.5 - cos(pi*b)/2), b)) == \ + '[2.0 - 0.318309886183791*acos(1.0 - 2.0*a), 0.318309886183791*acos(1.0 - 2.0*a)]' + + # issue 15325 + assert solve(y**(1/x) - z, x) == [log(y)/log(z)] + + # issue 25685 (basic trig identities should give simple solutions) + for yi in [cos(2*x),sin(2*x),cos(x - pi/3)]: + sol = solve([cos(x) - S(3)/5, yi - y]) + assert (sol[0][y] + sol[1][y]).is_Rational, (yi,sol) + # don't allow massive expansion + assert solve(cos(1000*x) - S.Half) == [pi/3000, pi/600] + assert solve(cos(x - 1000*y) - 1, x) == [1000*y, 1000*y + 2*pi] + assert solve(cos(x + y + z) - 1, x) == [-y - z, -y - z + 2*pi] + + # issue 26008 + assert solve(sin(x + pi/6)) == [-pi/6, 5*pi/6] + + +def test_solve_for_functions_derivatives(): + t = Symbol('t') + x = Function('x')(t) + y = Function('y')(t) + a11, a12, a21, a22, b1, b2 = symbols('a11,a12,a21,a22,b1,b2') + + soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) + assert soln == { + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21), + y: (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + } + + assert solve(x - 1, x) == [1] + assert solve(3*x - 2, x) == [Rational(2, 3)] + + soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) + + a22*y.diff(t) - b2], x.diff(t), y.diff(t)) + assert soln == { y.diff(t): (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x.diff(t): (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + + assert solve(x.diff(t) - 1, x.diff(t)) == [1] + assert solve(3*x.diff(t) - 2, x.diff(t)) == [Rational(2, 3)] + + eqns = {3*x - 1, 2*y - 4} + assert solve(eqns, {x, y}) == { x: Rational(1, 3), y: 2 } + x = Symbol('x') + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + assert solve(F.diff(x), diff(f(x), x)) == [(-x + 2)/f(x)] + + # Mixed cased with a Symbol and a Function + x = Symbol('x') + y = Function('y')(t) + + soln = solve([a11*x + a12*y.diff(t) - b1, a21*x + + a22*y.diff(t) - b2], x, y.diff(t)) + assert soln == { y.diff(t): (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + + # issue 13263 + x = Symbol('x') + f = Function('f') + soln = solve([f(x).diff(x) + f(x).diff(x, 2) - 1, f(x).diff(x) - f(x).diff(x, 2)], + f(x).diff(x), f(x).diff(x, 2)) + assert soln == { f(x).diff(x, 2): S(1)/2, f(x).diff(x): S(1)/2 } + + soln = solve([f(x).diff(x, 2) + f(x).diff(x, 3) - 1, 1 - f(x).diff(x, 2) - + f(x).diff(x, 3), 1 - f(x).diff(x,3)], f(x).diff(x, 2), f(x).diff(x, 3)) + assert soln == { f(x).diff(x, 2): 0, f(x).diff(x, 3): 1 } + + +def test_issue_3725(): + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + e = F.diff(x) + assert solve(e, f(x).diff(x)) in [[(2 - x)/f(x)], [-((x - 2)/f(x))]] + + +def test_solve_Matrix(): + # https://github.com/sympy/sympy/issues/3870 + a, b, c, d = symbols('a b c d') + A = Matrix(2, 2, [a, b, c, d]) + B = Matrix(2, 2, [0, 2, -3, 0]) + C = Matrix(2, 2, [1, 2, 3, 4]) + + assert solve(A*B - C, [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + assert solve([A*B - C], [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + assert solve(Eq(A*B, C), [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + + assert solve([A*B - B*A], [a, b, c, d]) == {a: d, b: Rational(-2, 3)*c} + assert solve([A*C - C*A], [a, b, c, d]) == {a: d - c, b: Rational(2, 3)*c} + assert solve([A*B - B*A, A*C - C*A], [a, b, c, d]) == {a: d, b: 0, c: 0} + + assert solve([Eq(A*B, B*A)], [a, b, c, d]) == {a: d, b: Rational(-2, 3)*c} + assert solve([Eq(A*C, C*A)], [a, b, c, d]) == {a: d - c, b: Rational(2, 3)*c} + assert solve([Eq(A*B, B*A), Eq(A*C, C*A)], [a, b, c, d]) == {a: d, b: 0, c: 0} + + # https://github.com/sympy/sympy/issues/27854 + m, n = symbols("m n") + A = MatrixSymbol("A", m, n) + x = MatrixSymbol("x", n, 1) + b = MatrixSymbol('b', m, 1) + r = A * x - b + f = r.T * r + grad_f = f.diff(x) + raises(ValueError, lambda: solve(grad_f, x)) + + +def test_solve_linear(): + w = Wild('w') + assert solve_linear(x, x) == (0, 1) + assert solve_linear(x, exclude=[x]) == (0, 1) + assert solve_linear(x, symbols=[w]) == (0, 1) + assert solve_linear(x, y - 2*x) in [(x, y/3), (y, 3*x)] + assert solve_linear(x, y - 2*x, exclude=[x]) == (y, 3*x) + assert solve_linear(3*x - y, 0) in [(x, y/3), (y, 3*x)] + assert solve_linear(3*x - y, 0, [x]) == (x, y/3) + assert solve_linear(3*x - y, 0, [y]) == (y, 3*x) + assert solve_linear(x**2/y, 1) == (y, x**2) + assert solve_linear(w, x) in [(w, x), (x, w)] + assert solve_linear(cos(x)**2 + sin(x)**2 + 2 + y) == \ + (y, -2 - cos(x)**2 - sin(x)**2) + assert solve_linear(cos(x)**2 + sin(x)**2 + 2 + y, symbols=[x]) == (0, 1) + assert solve_linear(Eq(x, 3)) == (x, 3) + assert solve_linear(1/(1/x - 2)) == (0, 0) + assert solve_linear((x + 1)*exp(-x), symbols=[x]) == (x, -1) + assert solve_linear((x + 1)*exp(x), symbols=[x]) == ((x + 1)*exp(x), 1) + assert solve_linear(x*exp(-x**2), symbols=[x]) == (x, 0) + assert solve_linear(0**x - 1) == (0**x - 1, 1) + assert solve_linear(1 + 1/(x - 1)) == (x, 0) + eq = y*cos(x)**2 + y*sin(x)**2 - y # = y*(1 - 1) = 0 + assert solve_linear(eq) == (0, 1) + eq = cos(x)**2 + sin(x)**2 # = 1 + assert solve_linear(eq) == (0, 1) + raises(ValueError, lambda: solve_linear(Eq(x, 3), 3)) + + +def test_solve_undetermined_coeffs(): + assert solve_undetermined_coeffs( + a*x**2 + b*x**2 + b*x + 2*c*x + c + 1, [a, b, c], x + ) == {a: -2, b: 2, c: -1} + # Test that rational functions work + assert solve_undetermined_coeffs(a/x + b/(x + 1) + - (2*x + 1)/(x**2 + x), [a, b], x) == {a: 1, b: 1} + # Test cancellation in rational functions + assert solve_undetermined_coeffs( + ((c + 1)*a*x**2 + (c + 1)*b*x**2 + + (c + 1)*b*x + (c + 1)*2*c*x + (c + 1)**2)/(c + 1), + [a, b, c], x) == \ + {a: -2, b: 2, c: -1} + # multivariate + X, Y, Z = y, x**y, y*x**y + eq = a*X + b*Y + c*Z - X - 2*Y - 3*Z + coeffs = a, b, c + syms = x, y + assert solve_undetermined_coeffs(eq, coeffs) == { + a: 1, b: 2, c: 3} + assert solve_undetermined_coeffs(eq, coeffs, syms) == { + a: 1, b: 2, c: 3} + assert solve_undetermined_coeffs(eq, coeffs, *syms) == { + a: 1, b: 2, c: 3} + # check output format + assert solve_undetermined_coeffs(a*x + a - 2, [a]) == [] + assert solve_undetermined_coeffs(a**2*x - 4*x, [a]) == [ + {a: -2}, {a: 2}] + assert solve_undetermined_coeffs(0, [a]) == [] + assert solve_undetermined_coeffs(0, [a], dict=True) == [] + assert solve_undetermined_coeffs(0, [a], set=True) == ([], {}) + assert solve_undetermined_coeffs(1, [a]) == [] + abeq = a*x - 2*x + b - 3 + s = {b, a} + assert solve_undetermined_coeffs(abeq, s, x) == {a: 2, b: 3} + assert solve_undetermined_coeffs(abeq, s, x, set=True) == ([a, b], {(2, 3)}) + assert solve_undetermined_coeffs(sin(a*x) - sin(2*x), (a,)) is None + assert solve_undetermined_coeffs(a*x + b*x - 2*x, (a, b)) == {a: 2 - b} + + +def test_solve_inequalities(): + x = Symbol('x') + sol = And(S.Zero < x, x < oo) + assert solve(x + 1 > 1) == sol + assert solve([x + 1 > 1]) == sol + assert solve([x + 1 > 1], x) == sol + assert solve([x + 1 > 1], [x]) == sol + + system = [Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)] + assert solve(system) == \ + And(Or(And(Lt(-sqrt(2), x), Lt(x, -1)), + And(Lt(1, x), Lt(x, sqrt(2)))), Eq(0, 0)) + + x = Symbol('x', real=True) + system = [Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)] + assert solve(system) == \ + Or(And(Lt(-sqrt(2), x), Lt(x, -1)), And(Lt(1, x), Lt(x, sqrt(2)))) + + # issues 6627, 3448 + assert solve((x - 3)/(x - 2) < 0, x) == And(Lt(2, x), Lt(x, 3)) + assert solve(x/(x + 1) > 1, x) == And(Lt(-oo, x), Lt(x, -1)) + + assert solve(sin(x) > S.Half) == And(pi/6 < x, x < pi*Rational(5, 6)) + + assert solve(Eq(False, x < 1)) == (S.One <= x) & (x < oo) + assert solve(Eq(True, x < 1)) == (-oo < x) & (x < 1) + assert solve(Eq(x < 1, False)) == (S.One <= x) & (x < oo) + assert solve(Eq(x < 1, True)) == (-oo < x) & (x < 1) + + assert solve(Eq(False, x)) == False + assert solve(Eq(0, x)) == [0] + assert solve(Eq(True, x)) == True + assert solve(Eq(1, x)) == [1] + assert solve(Eq(False, ~x)) == True + assert solve(Eq(True, ~x)) == False + assert solve(Ne(True, x)) == False + assert solve(Ne(1, x)) == (x > -oo) & (x < oo) & Ne(x, 1) + + +def test_issue_4793(): + assert solve(1/x) == [] + assert solve(x*(1 - 5/x)) == [5] + assert solve(x + sqrt(x) - 2) == [1] + assert solve(-(1 + x)/(2 + x)**2 + 1/(2 + x)) == [] + assert solve(-x**2 - 2*x + (x + 1)**2 - 1) == [] + assert solve((x/(x + 1) + 3)**(-2)) == [] + assert solve(x/sqrt(x**2 + 1), x) == [0] + assert solve(exp(x) - y, x) == [log(y)] + assert solve(exp(x)) == [] + assert solve(x**2 + x + sin(y)**2 + cos(y)**2 - 1, x) in [[0, -1], [-1, 0]] + eq = 4*3**(5*x + 2) - 7 + ans = solve(eq, x) + assert len(ans) == 5 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) + assert solve(log(x**2) - y**2/exp(x), x, y, set=True) == ( + [x, y], + {(x, sqrt(exp(x) * log(x ** 2))), (x, -sqrt(exp(x) * log(x ** 2)))}) + assert solve(x**2*z**2 - z**2*y**2) == [{x: -y}, {x: y}, {z: 0}] + assert solve((x - 1)/(1 + 1/(x - 1))) == [] + assert solve(x**(y*z) - x, x) == [1] + raises(NotImplementedError, lambda: solve(log(x) - exp(x), x)) + raises(NotImplementedError, lambda: solve(2**x - exp(x) - 3)) + + +def test_PR1964(): + # issue 5171 + assert solve(sqrt(x)) == solve(sqrt(x**3)) == [0] + assert solve(sqrt(x - 1)) == [1] + # issue 4462 + a = Symbol('a') + assert solve(-3*a/sqrt(x), x) == [] + # issue 4486 + assert solve(2*x/(x + 2) - 1, x) == [2] + # issue 4496 + assert set(solve((x**2/(7 - x)).diff(x))) == {S.Zero, S(14)} + # issue 4695 + f = Function('f') + assert solve((3 - 5*x/f(x))*f(x), f(x)) == [x*Rational(5, 3)] + # issue 4497 + assert solve(1/root(5 + x, 5) - 9, x) == [Rational(-295244, 59049)] + + assert solve(sqrt(x) + sqrt(sqrt(x)) - 4) == [(Rational(-1, 2) + sqrt(17)/2)**4] + assert set(solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4))) in \ + [ + {log((-sqrt(3) + 2)**2), log((sqrt(3) + 2)**2)}, + {2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)}, + {log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)}, + ] + assert set(solve(Poly(exp(x) + exp(-x) - 4))) == \ + {log(-sqrt(3) + 2), log(sqrt(3) + 2)} + assert set(solve(x**y + x**(2*y) - 1, x)) == \ + {(Rational(-1, 2) + sqrt(5)/2)**(1/y), (Rational(-1, 2) - sqrt(5)/2)**(1/y)} + + assert solve(exp(x/y)*exp(-z/y) - 2, y) == [(x - z)/log(2)] + assert solve( + x**z*y**z - 2, z) in [[log(2)/(log(x) + log(y))], [log(2)/(log(x*y))]] + # if you do inversion too soon then multiple roots (as for the following) + # will be missed, e.g. if exp(3*x) = exp(3) -> 3*x = 3 + E = S.Exp1 + assert solve(exp(3*x) - exp(3), x) in [ + [1, log(E*(Rational(-1, 2) - sqrt(3)*I/2)), log(E*(Rational(-1, 2) + sqrt(3)*I/2))], + [1, log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)], + ] + + # coverage test + p = Symbol('p', positive=True) + assert solve((1/p + 1)**(p + 1)) == [] + + +def test_issue_5197(): + x = Symbol('x', real=True) + assert solve(x**2 + 1, x) == [] + n = Symbol('n', integer=True, positive=True) + assert solve((n - 1)*(n + 2)*(2*n - 1), n) == [1] + x = Symbol('x', positive=True) + y = Symbol('y') + assert solve([x + 5*y - 2, -3*x + 6*y - 15], x, y) == [] + # not {x: -3, y: 1} b/c x is positive + # The solution following should not contain (-sqrt(2), sqrt(2)) + assert solve([(x + y), 2 - y**2], x, y) == [(sqrt(2), -sqrt(2))] + y = Symbol('y', positive=True) + # The solution following should not contain {y: -x*exp(x/2)} + assert solve(x**2 - y**2/exp(x), y, x, dict=True) == [{y: x*exp(x/2)}] + x, y, z = symbols('x y z', positive=True) + assert solve(z**2*x**2 - z**2*y**2/exp(x), y, x, z, dict=True) == [{y: x*exp(x/2)}] + + +def test_checking(): + assert set( + solve(x*(x - y/x), x, check=False)) == {sqrt(y), S.Zero, -sqrt(y)} + assert set(solve(x*(x - y/x), x, check=True)) == {sqrt(y), -sqrt(y)} + # {x: 0, y: 4} sets denominator to 0 in the following so system should return None + assert solve((1/(1/x + 2), 1/(y - 3) - 1)) == [] + # 0 sets denominator of 1/x to zero so None is returned + assert solve(1/(1/x + 2)) == [] + + +def test_issue_4671_4463_4467(): + assert solve(sqrt(x**2 - 1) - 2) in ([sqrt(5), -sqrt(5)], + [-sqrt(5), sqrt(5)]) + assert solve((2**exp(y**2/x) + 2)/(x**2 + 15), y) == [ + -sqrt(x*log(1 + I*pi/log(2))), sqrt(x*log(1 + I*pi/log(2)))] + + C1, C2 = symbols('C1 C2') + f = Function('f') + assert solve(C1 + C2/x**2 - exp(-f(x)), f(x)) == [log(x**2/(C1*x**2 + C2))] + a = Symbol('a') + E = S.Exp1 + assert solve(1 - log(a + 4*x**2), x) in ( + [-sqrt(-a + E)/2, sqrt(-a + E)/2], + [sqrt(-a + E)/2, -sqrt(-a + E)/2] + ) + assert solve(log(a**(-3) - x**2)/a, x) in ( + [-sqrt(-1 + a**(-3)), sqrt(-1 + a**(-3))], + [sqrt(-1 + a**(-3)), -sqrt(-1 + a**(-3))],) + assert solve(1 - log(a + 4*x**2), x) in ( + [-sqrt(-a + E)/2, sqrt(-a + E)/2], + [sqrt(-a + E)/2, -sqrt(-a + E)/2],) + assert solve((a**2 + 1)*(sin(a*x) + cos(a*x)), x) == [-pi/(4*a)] + assert solve(3 - (sinh(a*x) + cosh(a*x)), x) == [log(3)/a] + assert set(solve(3 - (sinh(a*x) + cosh(a*x)**2), x)) == \ + {log(-2 + sqrt(5))/a, log(-sqrt(2) + 1)/a, + log(-sqrt(5) - 2)/a, log(1 + sqrt(2))/a} + assert solve(atan(x) - 1) == [tan(1)] + + +def test_issue_5132(): + r, t = symbols('r,t') + assert set(solve([r - x**2 - y**2, tan(t) - y/x], [x, y])) == \ + {( + -sqrt(r*cos(t)**2), -1*sqrt(r*cos(t)**2)*tan(t)), + (sqrt(r*cos(t)**2), sqrt(r*cos(t)**2)*tan(t))} + assert solve([exp(x) - sin(y), 1/y - 3], [x, y]) == \ + [(log(sin(Rational(1, 3))), Rational(1, 3))] + assert solve([exp(x) - sin(y), 1/exp(y) - 3], [x, y]) == \ + [(log(-sin(log(3))), -log(3))] + assert set(solve([exp(x) - sin(y), y**2 - 4], [x, y])) == \ + {(log(-sin(2)), -S(2)), (log(sin(2)), S(2))} + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + assert solve(eqs, set=True) == \ + ([y, z], { + (-log(3), sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), -sqrt(-exp(2*x) - sin(log(3))))}) + assert solve(eqs, x, z, set=True) == ( + [x, z], + {(x, sqrt(-exp(2*x) + sin(y))), (x, -sqrt(-exp(2*x) + sin(y)))}) + assert set(solve(eqs, x, y)) == \ + { + (log(-sqrt(-z**2 - sin(log(3)))), -log(3)), + (log(-z**2 - sin(log(3)))/2, -log(3))} + assert set(solve(eqs, y, z)) == \ + { + (-log(3), -sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), sqrt(-exp(2*x) - sin(log(3))))} + eqs = [exp(x)**2 - sin(y) + z, 1/exp(y) - 3] + assert solve(eqs, set=True) == ([y, z], { + (-log(3), -exp(2*x) - sin(log(3)))}) + assert solve(eqs, x, z, set=True) == ( + [x, z], {(x, -exp(2*x) + sin(y))}) + assert set(solve(eqs, x, y)) == { + (log(-sqrt(-z - sin(log(3)))), -log(3)), + (log(-z - sin(log(3)))/2, -log(3))} + assert solve(eqs, z, y) == \ + [(-exp(2*x) - sin(log(3)), -log(3))] + assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), set=True) == ( + [x, y], {(S.One, S(3)), (S(3), S.One)}) + assert set(solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y)) == \ + {(S.One, S(3)), (S(3), S.One)} + + +def test_issue_5335(): + lam, a0, conc = symbols('lam a0 conc') + a = 0.005 + b = 0.743436700916726 + eqs = [lam + 2*y - a0*(1 - x/2)*x - a*x/2*x, + a0*(1 - x/2)*x - 1*y - b*y, + x + y - conc] + sym = [x, y, a0] + # there are 4 solutions obtained manually but only two are valid + assert len(solve(eqs, sym, manual=True, minimal=True)) == 2 + assert len(solve(eqs, sym)) == 2 # cf below with rational=False + + +@SKIP("Hangs") +def _test_issue_5335_float(): + # gives ZeroDivisionError: polynomial division + lam, a0, conc = symbols('lam a0 conc') + a = 0.005 + b = 0.743436700916726 + eqs = [lam + 2*y - a0*(1 - x/2)*x - a*x/2*x, + a0*(1 - x/2)*x - 1*y - b*y, + x + y - conc] + sym = [x, y, a0] + assert len(solve(eqs, sym, rational=False)) == 2 + + +def test_issue_5767(): + assert set(solve([x**2 + y + 4], [x])) == \ + {(-sqrt(-y - 4),), (sqrt(-y - 4),)} + + +def _make_example_24609(): + D, R, H, B_g, V, D_c = symbols("D, R, H, B_g, V, D_c", real=True, positive=True) + Sigma_f, Sigma_a, nu = symbols("Sigma_f, Sigma_a, nu", real=True, positive=True) + x = symbols("x", real=True, positive=True) + eq = ( + 2**(S(2)/3)*pi**(S(2)/3)*D_c*(S(231361)/10000 + pi**2/x**2) + /(6*V**(S(2)/3)*x**(S(1)/3)) + - 2**(S(2)/3)*pi**(S(8)/3)*D_c/(2*V**(S(2)/3)*x**(S(7)/3)) + ) + expected = 100*sqrt(2)*pi/481 + return eq, expected, x + + +def test_issue_24609(): + # https://github.com/sympy/sympy/issues/24609 + eq, expected, x = _make_example_24609() + assert solve(eq, x, simplify=True) == [expected] + [solapprox] = solve(eq.n(), x) + assert abs(solapprox - expected.n()) < 1e-14 + + +@XFAIL +def test_issue_24609_xfail(): + # + # This returns 5 solutions when it should be 1 (with x positive). + # Simplification reveals all solutions to be equivalent. It is expected + # that solve without simplify=True returns duplicate solutions in some + # cases but the core of this equation is a simple quadratic that can easily + # be solved without introducing any redundant solutions: + # + # >>> print(factor_terms(eq.as_numer_denom()[0])) + # 2**(2/3)*pi**(2/3)*D_c*V**(2/3)*x**(7/3)*(231361*x**2 - 20000*pi**2) + # + eq, expected, x = _make_example_24609() + assert len(solve(eq, x)) == [expected] + # + # We do not want to pass this test just by using simplify so if the above + # passes then uncomment the additional test below: + # + # assert len(solve(eq, x, simplify=False)) == 1 + + +def test_polysys(): + assert set(solve([x**2 + 2/y - 2, x + y - 3], [x, y])) == \ + {(S.One, S(2)), (1 + sqrt(5), 2 - sqrt(5)), + (1 - sqrt(5), 2 + sqrt(5))} + assert solve([x**2 + y - 2, x**2 + y]) == [] + # the ordering should be whatever the user requested + assert solve([x**2 + y - 3, x - y - 4], (x, y)) != solve([x**2 + + y - 3, x - y - 4], (y, x)) + + +@slow +def test_unrad1(): + raises(NotImplementedError, lambda: + unrad(sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) + 3)) + raises(NotImplementedError, lambda: + unrad(sqrt(x) + (x + 1)**Rational(1, 3) + 2*sqrt(y))) + + s = symbols('s', cls=Dummy) + + # checkers to deal with possibility of answer coming + # back with a sign change (cf issue 5203) + def check(rv, ans): + assert bool(rv[1]) == bool(ans[1]) + if ans[1]: + return s_check(rv, ans) + e = rv[0].expand() + a = ans[0].expand() + return e in [a, -a] and rv[1] == ans[1] + + def s_check(rv, ans): + # get the dummy + rv = list(rv) + d = rv[0].atoms(Dummy) + reps = list(zip(d, [s]*len(d))) + # replace s with this dummy + rv = (rv[0].subs(reps).expand(), [rv[1][0].subs(reps), rv[1][1].subs(reps)]) + ans = (ans[0].subs(reps).expand(), [ans[1][0].subs(reps), ans[1][1].subs(reps)]) + return str(rv[0]) in [str(ans[0]), str(-ans[0])] and \ + str(rv[1]) == str(ans[1]) + + assert unrad(1) is None + assert check(unrad(sqrt(x)), + (x, [])) + assert check(unrad(sqrt(x) + 1), + (x - 1, [])) + assert check(unrad(sqrt(x) + root(x, 3) + 2), + (s**3 + s**2 + 2, [s, s**6 - x])) + assert check(unrad(sqrt(x)*root(x, 3) + 2), + (x**5 - 64, [])) + assert check(unrad(sqrt(x) + (x + 1)**Rational(1, 3)), + (x**3 - (x + 1)**2, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + sqrt(2*x)), + (-2*sqrt(2)*x - 2*x + 1, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + 2), + (16*x - 9, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + sqrt(1 - x)), + (5*x**2 - 4*x, [])) + assert check(unrad(a*sqrt(x) + b*sqrt(x) + c*sqrt(y) + d*sqrt(y)), + ((a*sqrt(x) + b*sqrt(x))**2 - (c*sqrt(y) + d*sqrt(y))**2, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x)), + (2*x - 1, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x) - 3), + (x**2 - x + 16, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x)), + (5*x**2 - 2*x + 1, [])) + assert unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x) - 3) in [ + (25*x**4 + 376*x**3 + 1256*x**2 - 2272*x + 784, []), + (25*x**8 - 476*x**6 + 2534*x**4 - 1468*x**2 + 169, [])] + assert unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x) - sqrt(1 - 2*x)) == \ + (41*x**4 + 40*x**3 + 232*x**2 - 160*x + 16, []) # orig root at 0.487 + assert check(unrad(sqrt(x) + sqrt(x + 1)), (S.One, [])) + + eq = sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) + assert check(unrad(eq), + (16*x**2 - 9*x, [])) + assert set(solve(eq, check=False)) == {S.Zero, Rational(9, 16)} + assert solve(eq) == [] + # but this one really does have those solutions + assert set(solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x)))) == \ + {S.Zero, Rational(9, 16)} + + assert check(unrad(sqrt(x) + root(x + 1, 3) + 2*sqrt(y), y), + (S('2*sqrt(x)*(x + 1)**(1/3) + x - 4*y + (x + 1)**(2/3)'), [])) + assert check(unrad(sqrt(x/(1 - x)) + (x + 1)**Rational(1, 3)), + (x**5 - x**4 - x**3 + 2*x**2 + x - 1, [])) + assert check(unrad(sqrt(x/(1 - x)) + 2*sqrt(y), y), + (4*x*y + x - 4*y, [])) + assert check(unrad(sqrt(x)*sqrt(1 - x) + 2, x), + (x**2 - x + 4, [])) + + # http://tutorial.math.lamar.edu/ + # Classes/Alg/SolveRadicalEqns.aspx#Solve_Rad_Ex2_a + assert solve(Eq(x, sqrt(x + 6))) == [3] + assert solve(Eq(x + sqrt(x - 4), 4)) == [4] + assert solve(Eq(1, x + sqrt(2*x - 3))) == [] + assert set(solve(Eq(sqrt(5*x + 6) - 2, x))) == {-S.One, S(2)} + assert set(solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2))) == {S(5), S(13)} + assert solve(Eq(sqrt(x + 7) + 2, sqrt(3 - x))) == [-6] + # http://www.purplemath.com/modules/solverad.htm + assert solve((2*x - 5)**Rational(1, 3) - 3) == [16] + assert set(solve(x + 1 - root(x**4 + 4*x**3 - x, 4))) == \ + {Rational(-1, 2), Rational(-1, 3)} + assert set(solve(sqrt(2*x**2 - 7) - (3 - x))) == {-S(8), S(2)} + assert solve(sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4)) == [0] + assert solve(sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1)) == [5] + assert solve(sqrt(x)*sqrt(x - 7) - 12) == [16] + assert solve(sqrt(x - 3) + sqrt(x) - 3) == [4] + assert solve(sqrt(9*x**2 + 4) - (3*x + 2)) == [0] + assert solve(sqrt(x) - 2 - 5) == [49] + assert solve(sqrt(x - 3) - sqrt(x) - 3) == [] + assert solve(sqrt(x - 1) - x + 7) == [10] + assert solve(sqrt(x - 2) - 5) == [27] + assert solve(sqrt(17*x - sqrt(x**2 - 5)) - 7) == [3] + assert solve(sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x))) == [] + + # don't posify the expression in unrad and do use _mexpand + z = sqrt(2*x + 1)/sqrt(x) - sqrt(2 + 1/x) + p = posify(z)[0] + assert solve(p) == [] + assert solve(z) == [] + assert solve(z + 6*I) == [Rational(-1, 11)] + assert solve(p + 6*I) == [] + # issue 8622 + assert unrad(root(x + 1, 5) - root(x, 3)) == ( + -(x**5 - x**3 - 3*x**2 - 3*x - 1), []) + # issue #8679 + assert check(unrad(x + root(x, 3) + root(x, 3)**2 + sqrt(y), x), + (s**3 + s**2 + s + sqrt(y), [s, s**3 - x])) + + # for coverage + assert check(unrad(sqrt(x) + root(x, 3) + y), + (s**3 + s**2 + y, [s, s**6 - x])) + assert solve(sqrt(x) + root(x, 3) - 2) == [1] + raises(NotImplementedError, lambda: + solve(sqrt(x) + root(x, 3) + root(x + 1, 5) - 2)) + # fails through a different code path + raises(NotImplementedError, lambda: solve(-sqrt(2) + cosh(x)/x)) + # unrad some + assert solve(sqrt(x + root(x, 3))+root(x - y, 5), y) == [ + x + (x**Rational(1, 3) + x)**Rational(5, 2)] + assert check(unrad(sqrt(x) - root(x + 1, 3)*sqrt(x + 2) + 2), + (s**10 + 8*s**8 + 24*s**6 - 12*s**5 - 22*s**4 - 160*s**3 - 212*s**2 - + 192*s - 56, [s, s**2 - x])) + e = root(x + 1, 3) + root(x, 3) + assert unrad(e) == (2*x + 1, []) + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + assert check(unrad(eq), + (15625*x**4 + 173000*x**3 + 355600*x**2 - 817920*x + 331776, [])) + assert check(unrad(root(x, 4) + root(x, 4)**3 - 1), + (s**3 + s - 1, [s, s**4 - x])) + assert check(unrad(root(x, 2) + root(x, 2)**3 - 1), + (x**3 + 2*x**2 + x - 1, [])) + assert unrad(x**0.5) is None + assert check(unrad(t + root(x + y, 5) + root(x + y, 5)**3), + (s**3 + s + t, [s, s**5 - x - y])) + assert check(unrad(x + root(x + y, 5) + root(x + y, 5)**3, y), + (s**3 + s + x, [s, s**5 - x - y])) + assert check(unrad(x + root(x + y, 5) + root(x + y, 5)**3, x), + (s**5 + s**3 + s - y, [s, s**5 - x - y])) + assert check(unrad(root(x - 1, 3) + root(x + 1, 5) + root(2, 5)), + (s**5 + 5*2**Rational(1, 5)*s**4 + s**3 + 10*2**Rational(2, 5)*s**3 + + 10*2**Rational(3, 5)*s**2 + 5*2**Rational(4, 5)*s + 4, [s, s**3 - x + 1])) + raises(NotImplementedError, lambda: + unrad((root(x, 2) + root(x, 3) + root(x, 4)).subs(x, x**5 - x + 1))) + + # the simplify flag should be reset to False for unrad results; + # if it's not then this next test will take a long time + assert solve(root(x, 3) + root(x, 5) - 2) == [1] + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + assert check(unrad(eq), + ((5*x - 4)*(3125*x**3 + 37100*x**2 + 100800*x - 82944), [])) + ans = S(''' + [4/5, -1484/375 + 172564/(140625*(114*sqrt(12657)/78125 + + 12459439/52734375)**(1/3)) + + 4*(114*sqrt(12657)/78125 + 12459439/52734375)**(1/3)]''') + assert solve(eq) == ans + # duplicate radical handling + assert check(unrad(sqrt(x + root(x + 1, 3)) - root(x + 1, 3) - 2), + (s**3 - s**2 - 3*s - 5, [s, s**3 - x - 1])) + # cov post-processing + e = root(x**2 + 1, 3) - root(x**2 - 1, 5) - 2 + assert check(unrad(e), + (s**5 - 10*s**4 + 39*s**3 - 80*s**2 + 80*s - 30, + [s, s**3 - x**2 - 1])) + + e = sqrt(x + root(x + 1, 2)) - root(x + 1, 3) - 2 + assert check(unrad(e), + (s**6 - 2*s**5 - 7*s**4 - 3*s**3 + 26*s**2 + 40*s + 25, + [s, s**3 - x - 1])) + assert check(unrad(e, _reverse=True), + (s**6 - 14*s**5 + 73*s**4 - 187*s**3 + 276*s**2 - 228*s + 89, + [s, s**2 - x - sqrt(x + 1)])) + # this one needs r0, r1 reversal to work + assert check(unrad(sqrt(x + sqrt(root(x, 3) - 1)) - root(x, 6) - 2), + (s**12 - 2*s**8 - 8*s**7 - 8*s**6 + s**4 + 8*s**3 + 23*s**2 + + 32*s + 17, [s, s**6 - x])) + + # why does this pass + assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == ( + -(x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5 + - cosh(x)**5), []) + # and this fail? + #assert unrad(sqrt(cosh(x)/x) + root(x + 1, 3)*sqrt(x) - 1) == ( + # -s**6 + 6*s**5 - 15*s**4 + 20*s**3 - 15*s**2 + 6*s + x**5 + + # 2*x**4 + x**3 - 1, [s, s**2 - cosh(x)/x]) + + # watch for symbols in exponents + assert unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1')) is None + assert check(unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1'), x), + (s**(2*y) + s + 1, [s, s**3 - x - y])) + # should _Q be so lenient? + assert unrad(x**(S.Half/y) + y, x) == (x**(1/y) - y**2, []) + + # This tests two things: that if full unrad is attempted and fails + # the solution should still be found; also it tests that the use of + # composite + assert len(solve(sqrt(y)*x + x**3 - 1, x)) == 3 + assert len(solve(-512*y**3 + 1344*(x + 2)**Rational(1, 3)*y**2 - + 1176*(x + 2)**Rational(2, 3)*y - 169*x + 686, y, _unrad=False)) == 3 + + # watch out for when the cov doesn't involve the symbol of interest + eq = S('-x + (7*y/8 - (27*x/2 + 27*sqrt(x**2)/2)**(1/3)/3)**3 - 1') + assert solve(eq, y) == [ + 2**(S(2)/3)*(27*x + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + + S(512)/343)**(S(1)/3)*(-S(1)/2 - sqrt(3)*I/2), 2**(S(2)/3)*(27*x + + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + + S(512)/343)**(S(1)/3)*(-S(1)/2 + sqrt(3)*I/2), 2**(S(2)/3)*(27*x + + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + S(512)/343)**(S(1)/3)] + + eq = root(x + 1, 3) - (root(x, 3) + root(x, 5)) + assert check(unrad(eq), + (3*s**13 + 3*s**11 + s**9 - 1, [s, s**15 - x])) + assert check(unrad(eq - 2), + (3*s**13 + 3*s**11 + 6*s**10 + s**9 + 12*s**8 + 6*s**6 + 12*s**5 + + 12*s**3 + 7, [s, s**15 - x])) + assert check(unrad(root(x, 3) - root(x + 1, 4)/2 + root(x + 2, 3)), + (s*(4096*s**9 + 960*s**8 + 48*s**7 - s**6 - 1728), + [s, s**4 - x - 1])) # orig expr has two real roots: -1, -.389 + assert check(unrad(root(x, 3) + root(x + 1, 4) - root(x + 2, 3)/2), + (343*s**13 + 2904*s**12 + 1344*s**11 + 512*s**10 - 1323*s**9 - + 3024*s**8 - 1728*s**7 + 1701*s**5 + 216*s**4 - 729*s, [s, s**4 - x - + 1])) # orig expr has one real root: -0.048 + assert check(unrad(root(x, 3)/2 - root(x + 1, 4) + root(x + 2, 3)), + (729*s**13 - 216*s**12 + 1728*s**11 - 512*s**10 + 1701*s**9 - + 3024*s**8 + 1344*s**7 + 1323*s**5 - 2904*s**4 + 343*s, [s, s**4 - x - + 1])) # orig expr has 2 real roots: -0.91, -0.15 + assert check(unrad(root(x, 3)/2 - root(x + 1, 4) + root(x + 2, 3) - 2), + (729*s**13 + 1242*s**12 + 18496*s**10 + 129701*s**9 + 388602*s**8 + + 453312*s**7 - 612864*s**6 - 3337173*s**5 - 6332418*s**4 - 7134912*s**3 + - 5064768*s**2 - 2111913*s - 398034, [s, s**4 - x - 1])) + # orig expr has 1 real root: 19.53 + + ans = solve(sqrt(x) + sqrt(x + 1) - + sqrt(1 - x) - sqrt(2 + x)) + assert len(ans) == 1 and NS(ans[0])[:4] == '0.73' + # the fence optimization problem + # https://github.com/sympy/sympy/issues/4793#issuecomment-36994519 + F = Symbol('F') + eq = F - (2*x + 2*y + sqrt(x**2 + y**2)) + ans = F*Rational(2, 7) - sqrt(2)*F/14 + X = solve(eq, x, check=False) + for xi in reversed(X): # reverse since currently, ans is the 2nd one + Y = solve((x*y).subs(x, xi).diff(y), y, simplify=False, check=False) + if any((a - ans).expand().is_zero for a in Y): + break + else: + assert None # no answer was found + assert solve(sqrt(x + 1) + root(x, 3) - 2) == S(''' + [(-11/(9*(47/54 + sqrt(93)/6)**(1/3)) + 1/3 + (47/54 + + sqrt(93)/6)**(1/3))**3]''') + assert solve(sqrt(sqrt(x + 1)) + x**Rational(1, 3) - 2) == S(''' + [(-sqrt(-2*(-1/16 + sqrt(6913)/16)**(1/3) + 6/(-1/16 + + sqrt(6913)/16)**(1/3) + 17/2 + 121/(4*sqrt(-6/(-1/16 + + sqrt(6913)/16)**(1/3) + 2*(-1/16 + sqrt(6913)/16)**(1/3) + 17/4)))/2 + + sqrt(-6/(-1/16 + sqrt(6913)/16)**(1/3) + 2*(-1/16 + + sqrt(6913)/16)**(1/3) + 17/4)/2 + 9/4)**3]''') + assert solve(sqrt(x) + root(sqrt(x) + 1, 3) - 2) == S(''' + [(-(81/2 + 3*sqrt(741)/2)**(1/3)/3 + (81/2 + 3*sqrt(741)/2)**(-1/3) + + 2)**2]''') + eq = S(''' + -x + (1/2 - sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + + x*(3*x**2 - 34) + 90)**2/4 - 39304/27) - 45)**(1/3) + 34/(3*(1/2 - + sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + x*(3*x**2 + - 34) + 90)**2/4 - 39304/27) - 45)**(1/3))''') + assert check(unrad(eq), + (s*-(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 + + 51*12**Rational(1, 3)*s**4 - 102*2**Rational(2, 3)*3**Rational(5, 6)*s**4*I - 1620*s**3 + + 1620*sqrt(3)*s**3*I + 13872*18**Rational(1, 3)*s**2 - 471648 + + 471648*sqrt(3)*I), [s, s**3 - 306*x - sqrt(3)*sqrt(31212*x**2 - + 165240*x + 61484) + 810])) + + assert solve(eq) == [] # not other code errors + eq = root(x, 3) - root(y, 3) + root(x, 5) + assert check(unrad(eq), + (s**15 + 3*s**13 + 3*s**11 + s**9 - y, [s, s**15 - x])) + eq = root(x, 3) + root(y, 3) + root(x*y, 4) + assert check(unrad(eq), + (s*y*(-s**12 - 3*s**11*y - 3*s**10*y**2 - s**9*y**3 - + 3*s**8*y**2 + 21*s**7*y**3 - 3*s**6*y**4 - 3*s**4*y**4 - + 3*s**3*y**5 - y**6), [s, s**4 - x*y])) + raises(NotImplementedError, + lambda: unrad(root(x, 3) + root(y, 3) + root(x*y, 5))) + + # Test unrad with an Equality + eq = Eq(-x**(S(1)/5) + x**(S(1)/3), -3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5)) + assert check(unrad(eq), + (-s**5 + s**3 - 3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5), [s, s**15 - x])) + + # make sure buried radicals are exposed + s = sqrt(x) - 1 + assert unrad(s**2 - s**3) == (x**3 - 6*x**2 + 9*x - 4, []) + # make sure numerators which are already polynomial are rejected + assert unrad((x/(x + 1) + 3)**(-2), x) is None + + # https://github.com/sympy/sympy/issues/23707 + eq = sqrt(x - y)*exp(t*sqrt(x - y)) - exp(t*sqrt(x - y)) + assert solve(eq, y) == [x - 1] + assert unrad(eq) is None + + +@slow +def test_unrad_slow(): + # this has roots with multiplicity > 1; there should be no + # repeats in roots obtained, however + eq = (sqrt(1 + sqrt(1 - 4*x**2)) - x*(1 + sqrt(1 + 2*sqrt(1 - 4*x**2)))) + assert solve(eq) == [S.Half] + + +@XFAIL +def test_unrad_fail(): + # this only works if we check real_root(eq.subs(x, Rational(1, 3))) + # but checksol doesn't work like that + assert solve(root(x**3 - 3*x**2, 3) + 1 - x) == [Rational(1, 3)] + assert solve(root(x + 1, 3) + root(x**2 - 2, 5) + 1) == [ + -1, -1 + CRootOf(x**5 + x**4 + 5*x**3 + 8*x**2 + 10*x + 5, 0)**3] + + +def test_checksol(): + x, y, r, t = symbols('x, y, r, t') + eq = r - x**2 - y**2 + dict_var_soln = {y: - sqrt(r) / sqrt(tan(t)**2 + 1), + x: -sqrt(r)*tan(t)/sqrt(tan(t)**2 + 1)} + assert checksol(eq, dict_var_soln) == True + assert checksol(Eq(x, False), {x: False}) is True + assert checksol(Ne(x, False), {x: False}) is False + assert checksol(Eq(x < 1, True), {x: 0}) is True + assert checksol(Eq(x < 1, True), {x: 1}) is False + assert checksol(Eq(x < 1, False), {x: 1}) is True + assert checksol(Eq(x < 1, False), {x: 0}) is False + assert checksol(Eq(x + 1, x**2 + 1), {x: 1}) is True + assert checksol([x - 1, x**2 - 1], x, 1) is True + assert checksol([x - 1, x**2 - 2], x, 1) is False + assert checksol(Poly(x**2 - 1), x, 1) is True + assert checksol(0, {}) is True + assert checksol([1e-10, x - 2], x, 2) is False + assert checksol([0.5, 0, x], x, 0) is False + assert checksol(y, x, 2) is False + assert checksol(x+1e-10, x, 0, numerical=True) is True + assert checksol(x+1e-10, x, 0, numerical=False) is False + assert checksol(exp(92*x), {x: log(sqrt(2)/2)}) is False + assert checksol(exp(92*x), {x: log(sqrt(2)/2) + I*pi}) is False + assert checksol(1/x**5, x, 1000) is False + raises(ValueError, lambda: checksol(x, 1)) + raises(ValueError, lambda: checksol([], x, 1)) + + +def test__invert(): + assert _invert(x - 2) == (2, x) + assert _invert(2) == (2, 0) + assert _invert(exp(1/x) - 3, x) == (1/log(3), x) + assert _invert(exp(1/x + a/x) - 3, x) == ((a + 1)/log(3), x) + assert _invert(a, x) == (a, 0) + + +def test_issue_4463(): + assert solve(-a*x + 2*x*log(x), x) == [exp(a/2)] + assert solve(x**x) == [] + assert solve(x**x - 2) == [exp(LambertW(log(2)))] + assert solve(((x - 3)*(x - 2))**((x - 3)*(x - 4))) == [2] + +@slow +def test_issue_5114_solvers(): + a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r = symbols('a:r') + + # there is no 'a' in the equation set but this is how the + # problem was originally posed + syms = a, b, c, f, h, k, n + eqs = [b + r/d - c/d, + c*(1/d + 1/e + 1/g) - f/g - r/d, + f*(1/g + 1/i + 1/j) - c/g - h/i, + h*(1/i + 1/l + 1/m) - f/i - k/m, + k*(1/m + 1/o + 1/p) - h/m - n/p, + n*(1/p + 1/q) - k/p] + assert len(solve(eqs, syms, manual=True, check=False, simplify=False)) == 1 + + +def test_issue_5849(): + # + # XXX: This system does not have a solution for most values of the + # parameters. Generally solve returns the empty set for systems that are + # generically inconsistent. + # + I1, I2, I3, I4, I5, I6 = symbols('I1:7') + dI1, dI4, dQ2, dQ4, Q2, Q4 = symbols('dI1,dI4,dQ2,dQ4,Q2,Q4') + + e = ( + I1 - I2 - I3, + I3 - I4 - I5, + I4 + I5 - I6, + -I1 + I2 + I6, + -2*I1 - 2*I3 - 2*I5 - 3*I6 - dI1/2 + 12, + -I4 + dQ4, + -I2 + dQ2, + 2*I3 + 2*I5 + 3*I6 - Q2, + I4 - 2*I5 + 2*Q4 + dI4 + ) + + ans = [{ + I1: I2 + I3, + dI1: -4*I2 - 8*I3 - 4*I5 - 6*I6 + 24, + I4: I3 - I5, + dQ4: I3 - I5, + Q4: -I3/2 + 3*I5/2 - dI4/2, + dQ2: I2, + Q2: 2*I3 + 2*I5 + 3*I6}] + + v = I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4 + assert solve(e, *v, manual=True, check=False, dict=True) == ans + assert solve(e, *v, manual=True, check=False) == [ + tuple([a.get(i, i) for i in v]) for a in ans] + assert solve(e, *v, manual=True) == [] + assert solve(e, *v) == [] + + # the matrix solver (tested below) doesn't like this because it produces + # a zero row in the matrix. Is this related to issue 4551? + assert [ei.subs( + ans[0]) for ei in e] == [0, 0, I3 - I6, -I3 + I6, 0, 0, 0, 0, 0] + + +def test_issue_5849_matrix(): + '''Same as test_issue_5849 but solved with the matrix solver. + + A solution only exists if I3 == I6 which is not generically true, + but `solve` does not return conditions under which the solution is + valid, only a solution that is canonical and consistent with the input. + ''' + # a simple example with the same issue + # assert solve([x+y+z, x+y], [x, y]) == {x: y} + # the longer example + I1, I2, I3, I4, I5, I6 = symbols('I1:7') + dI1, dI4, dQ2, dQ4, Q2, Q4 = symbols('dI1,dI4,dQ2,dQ4,Q2,Q4') + + e = ( + I1 - I2 - I3, + I3 - I4 - I5, + I4 + I5 - I6, + -I1 + I2 + I6, + -2*I1 - 2*I3 - 2*I5 - 3*I6 - dI1/2 + 12, + -I4 + dQ4, + -I2 + dQ2, + 2*I3 + 2*I5 + 3*I6 - Q2, + I4 - 2*I5 + 2*Q4 + dI4 + ) + assert solve(e, I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4) == [] + + +def test_issue_21882(): + + a, b, c, d, f, g, k = unknowns = symbols('a, b, c, d, f, g, k') + + equations = [ + -k*a + b + 5*f/6 + 2*c/9 + 5*d/6 + 4*a/3, + -k*f + 4*f/3 + d/2, + -k*d + f/6 + d, + 13*b/18 + 13*c/18 + 13*a/18, + -k*c + b/2 + 20*c/9 + a, + -k*b + b + c/18 + a/6, + 5*b/3 + c/3 + a, + 2*b/3 + 2*c + 4*a/3, + -g, + ] + + answer = [ + {a: 0, f: 0, b: 0, d: 0, c: 0, g: 0}, + {a: 0, f: -d, b: 0, k: S(5)/6, c: 0, g: 0}, + {a: -2*c, f: 0, b: c, d: 0, k: S(13)/18, g: 0}] + # but not {a: 0, f: 0, b: 0, k: S(3)/2, c: 0, d: 0, g: 0} + # since this is already covered by the first solution + got = solve(equations, unknowns, dict=True) + assert got == answer, (got,answer) + + +def test_issue_5901(): + f, g, h = map(Function, 'fgh') + a = Symbol('a') + D = Derivative(f(x), x) + G = Derivative(g(a), a) + assert solve(f(x) + f(x).diff(x), f(x)) == \ + [-D] + assert solve(f(x) - 3, f(x)) == \ + [3] + assert solve(f(x) - 3*f(x).diff(x), f(x)) == \ + [3*D] + assert solve([f(x) - 3*f(x).diff(x)], f(x)) == \ + {f(x): 3*D} + assert solve([f(x) - 3*f(x).diff(x), f(x)**2 - y + 4], f(x), y) == \ + [(3*D, 9*D**2 + 4)] + assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), + h(a), g(a), set=True) == \ + ([h(a), g(a)], { + (-sqrt(f(a)**2*g(a)**2 - G)/f(a), g(a)), + (sqrt(f(a)**2*g(a)**2 - G)/f(a), g(a))}), solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), + h(a), g(a), set=True) + args = [[f(x).diff(x, 2)*(f(x) + g(x)), 2 - g(x)**2], f(x), g(x)] + assert solve(*args, set=True)[1] == \ + {(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))} + eqs = [f(x)**2 + g(x) - 2*f(x).diff(x), g(x)**2 - 4] + assert solve(eqs, f(x), g(x), set=True) == \ + ([f(x), g(x)], { + (-sqrt(2*D - 2), S(2)), + (sqrt(2*D - 2), S(2)), + (-sqrt(2*D + 2), -S(2)), + (sqrt(2*D + 2), -S(2))}) + + # the underlying problem was in solve_linear that was not masking off + # anything but a Mul or Add; it now raises an error if it gets anything + # but a symbol and solve handles the substitutions necessary so solve_linear + # won't make this error + raises( + ValueError, lambda: solve_linear(f(x) + f(x).diff(x), symbols=[f(x)])) + assert solve_linear(f(x) + f(x).diff(x), symbols=[x]) == \ + (f(x) + Derivative(f(x), x), 1) + assert solve_linear(f(x) + Integral(x, (x, y)), symbols=[x]) == \ + (f(x) + Integral(x, (x, y)), 1) + assert solve_linear(f(x) + Integral(x, (x, y)) + x, symbols=[x]) == \ + (x + f(x) + Integral(x, (x, y)), 1) + assert solve_linear(f(y) + Integral(x, (x, y)) + x, symbols=[x]) == \ + (x, -f(y) - Integral(x, (x, y))) + assert solve_linear(x - f(x)/a + (f(x) - 1)/a, symbols=[x]) == \ + (x, 1/a) + assert solve_linear(x + Derivative(2*x, x)) == \ + (x, -2) + assert solve_linear(x + Integral(x, y), symbols=[x]) == \ + (x, 0) + assert solve_linear(x + Integral(x, y) - 2, symbols=[x]) == \ + (x, 2/(y + 1)) + + assert set(solve(x + exp(x)**2, exp(x))) == \ + {-sqrt(-x), sqrt(-x)} + assert solve(x + exp(x), x, implicit=True) == \ + [-exp(x)] + assert solve(cos(x) - sin(x), x, implicit=True) == [] + assert solve(x - sin(x), x, implicit=True) == \ + [sin(x)] + assert solve(x**2 + x - 3, x, implicit=True) == \ + [-x**2 + 3] + assert solve(x**2 + x - 3, x**2, implicit=True) == \ + [-x + 3] + + +def test_issue_5912(): + assert set(solve(x**2 - x - 0.1, rational=True)) == \ + {S.Half + sqrt(35)/10, -sqrt(35)/10 + S.Half} + ans = solve(x**2 - x - 0.1, rational=False) + assert len(ans) == 2 and all(a.is_Number for a in ans) + ans = solve(x**2 - x - 0.1) + assert len(ans) == 2 and all(a.is_Number for a in ans) + + +def test_float_handling(): + def test(e1, e2): + return len(e1.atoms(Float)) == len(e2.atoms(Float)) + assert solve(x - 0.5, rational=True)[0].is_Rational + assert solve(x - 0.5, rational=False)[0].is_Float + assert solve(x - S.Half, rational=False)[0].is_Rational + assert solve(x - 0.5, rational=None)[0].is_Float + assert solve(x - S.Half, rational=None)[0].is_Rational + assert test(nfloat(1 + 2*x), 1.0 + 2.0*x) + for contain in [list, tuple, set]: + ans = nfloat(contain([1 + 2*x])) + assert type(ans) is contain and test(list(ans)[0], 1.0 + 2.0*x) + k, v = list(nfloat({2*x: [1 + 2*x]}).items())[0] + assert test(k, 2*x) and test(v[0], 1.0 + 2.0*x) + assert test(nfloat(cos(2*x)), cos(2.0*x)) + assert test(nfloat(3*x**2), 3.0*x**2) + assert test(nfloat(3*x**2, exponent=True), 3.0*x**2.0) + assert test(nfloat(exp(2*x)), exp(2.0*x)) + assert test(nfloat(x/3), x/3.0) + assert test(nfloat(x**4 + 2*x + cos(Rational(1, 3)) + 1), + x**4 + 2.0*x + 1.94495694631474) + # don't call nfloat if there is no solution + tot = 100 + c + z + t + assert solve(((.7 + c)/tot - .6, (.2 + z)/tot - .3, t/tot - .1)) == [] + + +def test_check_assumptions(): + x = symbols('x', positive=True) + assert solve(x**2 - 1) == [1] + + +def test_issue_6056(): + assert solve(tanh(x + 3)*tanh(x - 3) - 1) == [] + assert solve(tanh(x - 1)*tanh(x + 1) + 1) == \ + [I*pi*Rational(-3, 4), -I*pi/4, I*pi/4, I*pi*Rational(3, 4)] + assert solve((tanh(x + 3)*tanh(x - 3) + 1)**2) == \ + [I*pi*Rational(-3, 4), -I*pi/4, I*pi/4, I*pi*Rational(3, 4)] + + +def test_issue_5673(): + eq = -x + exp(exp(LambertW(log(x)))*LambertW(log(x))) + assert checksol(eq, x, 2) is True + assert checksol(eq, x, 2, numerical=False) is None + + +def test_exclude(): + R, C, Ri, Vout, V1, Vminus, Vplus, s = \ + symbols('R, C, Ri, Vout, V1, Vminus, Vplus, s') + Rf = symbols('Rf', positive=True) # to eliminate Rf = 0 soln + eqs = [C*V1*s + Vplus*(-2*C*s - 1/R), + Vminus*(-1/Ri - 1/Rf) + Vout/Rf, + C*Vplus*s + V1*(-C*s - 1/R) + Vout/R, + -Vminus + Vplus] + assert solve(eqs, exclude=s*C*R) == [ + { + Rf: Ri*(C*R*s + 1)**2/(C*R*s), + Vminus: Vplus, + V1: 2*Vplus + Vplus/(C*R*s), + Vout: C*R*Vplus*s + 3*Vplus + Vplus/(C*R*s)}, + { + Vplus: 0, + Vminus: 0, + V1: 0, + Vout: 0}, + ] + + # TODO: Investigate why currently solution [0] is preferred over [1]. + assert solve(eqs, exclude=[Vplus, s, C]) in [[{ + Vminus: Vplus, + V1: Vout/2 + Vplus/2 + sqrt((Vout - 5*Vplus)*(Vout - Vplus))/2, + R: (Vout - 3*Vplus - sqrt(Vout**2 - 6*Vout*Vplus + 5*Vplus**2))/(2*C*Vplus*s), + Rf: Ri*(Vout - Vplus)/Vplus, + }, { + Vminus: Vplus, + V1: Vout/2 + Vplus/2 - sqrt((Vout - 5*Vplus)*(Vout - Vplus))/2, + R: (Vout - 3*Vplus + sqrt(Vout**2 - 6*Vout*Vplus + 5*Vplus**2))/(2*C*Vplus*s), + Rf: Ri*(Vout - Vplus)/Vplus, + }], [{ + Vminus: Vplus, + Vout: (V1**2 - V1*Vplus - Vplus**2)/(V1 - 2*Vplus), + Rf: Ri*(V1 - Vplus)**2/(Vplus*(V1 - 2*Vplus)), + R: Vplus/(C*s*(V1 - 2*Vplus)), + }]] + + +def test_high_order_roots(): + s = x**5 + 4*x**3 + 3*x**2 + Rational(7, 4) + assert set(solve(s)) == set(Poly(s*4, domain='ZZ').all_roots()) + + +def test_minsolve_linear_system(): + pqt = {"quick": True, "particular": True} + pqf = {"quick": False, "particular": True} + assert solve([x + y - 5, 2*x - y - 1], **pqt) == {x: 2, y: 3} + assert solve([x + y - 5, 2*x - y - 1], **pqf) == {x: 2, y: 3} + def count(dic): + return len([x for x in dic.values() if x == 0]) + assert count(solve([x + y + z, y + z + a + t], **pqt)) == 3 + assert count(solve([x + y + z, y + z + a + t], **pqf)) == 3 + assert count(solve([x + y + z, y + z + a], **pqt)) == 1 + assert count(solve([x + y + z, y + z + a], **pqf)) == 2 + # issue 22718 + A = Matrix([ + [ 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0], + [ 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, -1, -1, 0, 0], + [-1, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, 1, 0, 1], + [ 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, -1, 0, -1, 0], + [-1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 1, 1], + [-1, 0, 0, -1, 0, 0, -1, 0, 0, 0, -1, 0, 0, -1], + [ 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, -1, -1, 0], + [ 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, 1, 1], + [ 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, -1, 0, -1], + [ 0, 0, -1, -1, 0, 0, 0, 0, 0, -1, 0, 0, -1, -1], + [ 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [ 0, 0, 0, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0]]) + v = Matrix(symbols("v:14", integer=True)) + B = Matrix([[2], [-2], [0], [0], [0], [0], [0], [0], [0], + [0], [0], [0]]) + eqs = A@v-B + assert solve(eqs) == [] + assert solve(eqs, particular=True) == [] # assumption violated + assert all(v for v in solve([x + y + z, y + z + a]).values()) + for _q in (True, False): + assert not all(v for v in solve( + [x + y + z, y + z + a], quick=_q, + particular=True).values()) + # raise error if quick used w/o particular=True + raises(ValueError, lambda: solve([x + 1], quick=_q)) + raises(ValueError, lambda: solve([x + 1], quick=_q, particular=False)) + # and give a good error message if someone tries to use + # particular with a single equation + raises(ValueError, lambda: solve(x + 1, particular=True)) + + +def test_real_roots(): + # cf. issue 6650 + x = Symbol('x', real=True) + assert len(solve(x**5 + x**3 + 1)) == 1 + + +def test_issue_6528(): + eqs = [ + 327600995*x**2 - 37869137*x + 1809975124*y**2 - 9998905626, + 895613949*x**2 - 273830224*x*y + 530506983*y**2 - 10000000000] + # two expressions encountered are > 1400 ops long so if this hangs + # it is likely because simplification is being done + assert len(solve(eqs, y, x, check=False)) == 4 + + +def test_overdetermined(): + x = symbols('x', real=True) + eqs = [Abs(4*x - 7) - 5, Abs(3 - 8*x) - 1] + assert solve(eqs, x) == [(S.Half,)] + assert solve(eqs, x, manual=True) == [(S.Half,)] + assert solve(eqs, x, manual=True, check=False) == [(S.Half,), (S(3),)] + + +def test_issue_6605(): + x = symbols('x') + assert solve(4**(x/2) - 2**(x/3)) == [0, 3*I*pi/log(2)] + # while the first one passed, this one failed + x = symbols('x', real=True) + assert solve(5**(x/2) - 2**(x/3)) == [0] + b = sqrt(6)*sqrt(log(2))/sqrt(log(5)) + assert solve(5**(x/2) - 2**(3/x)) == [-b, b] + + +def test__ispow(): + assert _ispow(x**2) + assert not _ispow(x) + assert not _ispow(True) + + +def test_issue_6644(): + eq = -sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) + sqrt((-m**2/2 - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2) + sol = solve(eq, q, simplify=False, check=False) + assert len(sol) == 5 + + +def test_issue_6752(): + assert solve([a**2 + a, a - b], [a, b]) == [(-1, -1), (0, 0)] + assert solve([a**2 + a*c, a - b], [a, b]) == [(0, 0), (-c, -c)] + + +def test_issue_6792(): + assert solve(x*(x - 1)**2*(x + 1)*(x**6 - x + 1)) == [ + -1, 0, 1, CRootOf(x**6 - x + 1, 0), CRootOf(x**6 - x + 1, 1), + CRootOf(x**6 - x + 1, 2), CRootOf(x**6 - x + 1, 3), + CRootOf(x**6 - x + 1, 4), CRootOf(x**6 - x + 1, 5)] + + +def test_issues_6819_6820_6821_6248_8692_25777_25779(): + # issue 6821 + x, y = symbols('x y', real=True) + assert solve(abs(x + 3) - 2*abs(x - 3)) == [1, 9] + assert solve([abs(x) - 2, arg(x) - pi], x) == [(-2,)] + assert set(solve(abs(x - 7) - 8)) == {-S.One, S(15)} + + # issue 8692 + assert solve(Eq(Abs(x + 1) + Abs(x**2 - 7), 9), x) == [ + Rational(-1, 2) + sqrt(61)/2, -sqrt(69)/2 + S.Half] + + # issue 7145 + assert solve(2*abs(x) - abs(x - 1)) == [-1, Rational(1, 3)] + + # 25777 + assert solve(abs(x**3 + x + 2)/(x + 1)) == [] + + # 25779 + assert solve(abs(x)) == [0] + assert solve(Eq(abs(x**2 - 2*x), 4), x) == [ + 1 - sqrt(5), 1 + sqrt(5)] + nn = symbols('nn', nonnegative=True) + assert solve(abs(sqrt(nn))) == [0] + nz = symbols('nz', nonzero=True) + assert solve(Eq(Abs(4 + 1 / (4*nz)), 0)) == [-Rational(1, 16)] + + x = symbols('x') + assert solve([re(x) - 1, im(x) - 2], x) == [ + {x: 1 + 2*I, re(x): 1, im(x): 2}] + + # check for 'dict' handling of solution + eq = sqrt(re(x)**2 + im(x)**2) - 3 + assert solve(eq) == solve(eq, x) + + i = symbols('i', imaginary=True) + assert solve(abs(i) - 3) == [-3*I, 3*I] + raises(NotImplementedError, lambda: solve(abs(x) - 3)) + + w = symbols('w', integer=True) + assert solve(2*x**w - 4*y**w, w) == solve((x/y)**w - 2, w) + + x, y = symbols('x y', real=True) + assert solve(x + y*I + 3) == {y: 0, x: -3} + # issue 2642 + assert solve(x*(1 + I)) == [0] + + x, y = symbols('x y', imaginary=True) + assert solve(x + y*I + 3 + 2*I) == {x: -2*I, y: 3*I} + + x = symbols('x', real=True) + assert solve(x + y + 3 + 2*I) == {x: -3, y: -2*I} + + # issue 6248 + f = Function('f') + assert solve(f(x + 1) - f(2*x - 1)) == [2] + assert solve(log(x + 1) - log(2*x - 1)) == [2] + + x = symbols('x') + assert solve(2**x + 4**x) == [I*pi/log(2)] + +def test_issue_17638(): + + assert solve(((2-exp(2*x))*exp(x))/(exp(2*x)+2)**2 > 0, x) == (-oo < x) & (x < log(2)/2) + assert solve(((2-exp(2*x)+2)*exp(x+2))/(exp(x)+2)**2 > 0, x) == (-oo < x) & (x < log(4)/2) + assert solve((exp(x)+2+x**2)*exp(2*x+2)/(exp(x)+2)**2 > 0, x) == (-oo < x) & (x < oo) + + + +def test_issue_14607(): + # issue 14607 + s, tau_c, tau_1, tau_2, phi, K = symbols( + 's, tau_c, tau_1, tau_2, phi, K') + + target = (s**2*tau_1*tau_2 + s*tau_1 + s*tau_2 + 1)/(K*s*(-phi + tau_c)) + + K_C, tau_I, tau_D = symbols('K_C, tau_I, tau_D', + positive=True, nonzero=True) + PID = K_C*(1 + 1/(tau_I*s) + tau_D*s) + + eq = (target - PID).together() + eq *= denom(eq).simplify() + eq = Poly(eq, s) + c = eq.coeffs() + + vars = [K_C, tau_I, tau_D] + s = solve(c, vars, dict=True) + + assert len(s) == 1 + + knownsolution = {K_C: -(tau_1 + tau_2)/(K*(phi - tau_c)), + tau_I: tau_1 + tau_2, + tau_D: tau_1*tau_2/(tau_1 + tau_2)} + + for var in vars: + assert s[0][var].simplify() == knownsolution[var].simplify() + + +def test_lambert_multivariate(): + from sympy.abc import x, y + assert _filtered_gens(Poly(x + 1/x + exp(x) + y), x) == {x, exp(x)} + assert _lambert(x, x) == [] + assert solve((x**2 - 2*x + 1).subs(x, log(x) + 3*x)) == [LambertW(3*S.Exp1)/3] + assert solve((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1)) == \ + [LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3] + assert solve((x**2 - 2*x - 2).subs(x, log(x) + 3*x)) == \ + [LambertW(3*exp(1 - sqrt(3)))/3, LambertW(3*exp(1 + sqrt(3)))/3] + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solve(eq) == [LambertW(3*exp(-LambertW(3)))] + # coverage test + raises(NotImplementedError, lambda: solve(x - sin(x)*log(y - x), x)) + ans = [3, -3*LambertW(-log(3)/3)/log(3)] # 3 and 2.478... + assert solve(x**3 - 3**x, x) == ans + assert set(solve(3*log(x) - x*log(3))) == set(ans) + assert solve(LambertW(2*x) - y, x) == [y*exp(y)/2] + + +@XFAIL +def test_other_lambert(): + assert solve(3*sin(x) - x*sin(3), x) == [3] + assert set(solve(x**a - a**x), x) == { + a, -a*LambertW(-log(a)/a)/log(a)} + + +@slow +def test_lambert_bivariate(): + # tests passing current implementation + assert solve((x**2 + x)*exp(x**2 + x) - 1) == [ + Rational(-1, 2) + sqrt(1 + 4*LambertW(1))/2, + Rational(-1, 2) - sqrt(1 + 4*LambertW(1))/2] + assert solve((x**2 + x)*exp((x**2 + x)*2) - 1) == [ + Rational(-1, 2) + sqrt(1 + 2*LambertW(2))/2, + Rational(-1, 2) - sqrt(1 + 2*LambertW(2))/2] + assert solve(a/x + exp(x/2), x) == [2*LambertW(-a/2)] + assert solve((a/x + exp(x/2)).diff(x), x) == \ + [4*LambertW(-sqrt(2)*sqrt(a)/4), 4*LambertW(sqrt(2)*sqrt(a)/4)] + assert solve((1/x + exp(x/2)).diff(x), x) == \ + [4*LambertW(-sqrt(2)/4), + 4*LambertW(sqrt(2)/4), # nsimplifies as 2*2**(141/299)*3**(206/299)*5**(205/299)*7**(37/299)/21 + 4*LambertW(-sqrt(2)/4, -1)] + assert solve(x*log(x) + 3*x + 1, x) == \ + [exp(-3 + LambertW(-exp(3)))] + assert solve(-x**2 + 2**x, x) == [2, 4, -2*LambertW(log(2)/2)/log(2)] + assert solve(x**2 - 2**x, x) == [2, 4, -2*LambertW(log(2)/2)/log(2)] + ans = solve(3*x + 5 + 2**(-5*x + 3), x) + assert len(ans) == 1 and ans[0].expand() == \ + Rational(-5, 3) + LambertW(-10240*root(2, 3)*log(2)/3)/(5*log(2)) + assert solve(5*x - 1 + 3*exp(2 - 7*x), x) == \ + [Rational(1, 5) + LambertW(-21*exp(Rational(3, 5))/5)/7] + assert solve((log(x) + x).subs(x, x**2 + 1)) == [ + -I*sqrt(-LambertW(1) + 1), sqrt(-1 + LambertW(1))] + # check collection + ax = a**(3*x + 5) + ans = solve(3*log(ax) + b*log(ax) + ax, x) + x0 = 1/log(a) + x1 = sqrt(3)*I + x2 = b + 3 + x3 = x2*LambertW(1/x2)/a**5 + x4 = x3**Rational(1, 3)/2 + assert ans == [ + x0*log(x4*(-x1 - 1)), + x0*log(x4*(x1 - 1)), + x0*log(x3)/3] + x1 = LambertW(Rational(1, 3)) + x2 = a**(-5) + x3 = -3**Rational(1, 3) + x4 = 3**Rational(5, 6)*I + x5 = x1**Rational(1, 3)*x2**Rational(1, 3)/2 + ans = solve(3*log(ax) + ax, x) + assert ans == [ + x0*log(3*x1*x2)/3, + x0*log(x5*(x3 - x4)), + x0*log(x5*(x3 + x4))] + # coverage + p = symbols('p', positive=True) + eq = 4*2**(2*p + 3) - 2*p - 3 + assert _solve_lambert(eq, p, _filtered_gens(Poly(eq), p)) == [ + Rational(-3, 2) - LambertW(-4*log(2))/(2*log(2))] + assert set(solve(3**cos(x) - cos(x)**3)) == { + acos(3), acos(-3*LambertW(-log(3)/3)/log(3))} + # should give only one solution after using `uniq` + assert solve(2*log(x) - 2*log(z) + log(z + log(x) + log(z)), x) == [ + exp(-z + LambertW(2*z**4*exp(2*z))/2)/z] + # cases when p != S.One + # issue 4271 + ans = solve((a/x + exp(x/2)).diff(x, 2), x) + x0 = (-a)**Rational(1, 3) + x1 = sqrt(3)*I + x2 = x0/6 + assert ans == [ + 6*LambertW(x0/3), + 6*LambertW(x2*(-x1 - 1)), + 6*LambertW(x2*(x1 - 1))] + assert solve((1/x + exp(x/2)).diff(x, 2), x) == \ + [6*LambertW(Rational(-1, 3)), 6*LambertW(Rational(1, 6) - sqrt(3)*I/6), \ + 6*LambertW(Rational(1, 6) + sqrt(3)*I/6), 6*LambertW(Rational(-1, 3), -1)] + assert solve(x**2 - y**2/exp(x), x, y, dict=True) == \ + [{x: 2*LambertW(-y/2)}, {x: 2*LambertW(y/2)}] + # this is slow but not exceedingly slow + assert solve((x**3)**(x/2) + pi/2, x) == [ + exp(LambertW(-2*log(2)/3 + 2*log(pi)/3 + I*pi*Rational(2, 3)))] + + # issue 23253 + assert solve((1/log(sqrt(x) + 2)**2 - 1/x)) == [ + (LambertW(-exp(-2), -1) + 2)**2] + assert solve((1/log(1/sqrt(x) + 2)**2 - x)) == [ + (LambertW(-exp(-2), -1) + 2)**-2] + assert solve((1/log(x**2 + 2)**2 - x**-4)) == [ + -I*sqrt(2 - LambertW(exp(2))), + -I*sqrt(LambertW(-exp(-2)) + 2), + sqrt(-2 - LambertW(-exp(-2))), + sqrt(-2 + LambertW(exp(2))), + -sqrt(-2 - LambertW(-exp(-2), -1)), + sqrt(-2 - LambertW(-exp(-2), -1))] + + +def test_rewrite_trig(): + assert solve(sin(x) + tan(x)) == [0, -pi, pi, 2*pi] + assert solve(sin(x) + sec(x)) == [ + -2*atan(Rational(-1, 2) + sqrt(2)*sqrt(1 - sqrt(3)*I)/2 + sqrt(3)*I/2), + 2*atan(S.Half - sqrt(2)*sqrt(1 + sqrt(3)*I)/2 + sqrt(3)*I/2), 2*atan(S.Half + + sqrt(2)*sqrt(1 + sqrt(3)*I)/2 + sqrt(3)*I/2), 2*atan(S.Half - + sqrt(3)*I/2 + sqrt(2)*sqrt(1 - sqrt(3)*I)/2)] + assert solve(sinh(x) + tanh(x)) == [0, I*pi] + + # issue 6157 + assert solve(2*sin(x) - cos(x), x) == [atan(S.Half)] + + +@XFAIL +def test_rewrite_trigh(): + # if this import passes then the test below should also pass + from sympy.functions.elementary.hyperbolic import sech + assert solve(sinh(x) + sech(x)) == [ + 2*atanh(Rational(-1, 2) + sqrt(5)/2 - sqrt(-2*sqrt(5) + 2)/2), + 2*atanh(Rational(-1, 2) + sqrt(5)/2 + sqrt(-2*sqrt(5) + 2)/2), + 2*atanh(-sqrt(5)/2 - S.Half + sqrt(2 + 2*sqrt(5))/2), + 2*atanh(-sqrt(2 + 2*sqrt(5))/2 - sqrt(5)/2 - S.Half)] + + +def test_uselogcombine(): + eq = z - log(x) + log(y/(x*(-1 + y**2/x**2))) + assert solve(eq, x, force=True) == [-sqrt(y*(y - exp(z))), sqrt(y*(y - exp(z)))] + assert solve(log(x + 3) + log(1 + 3/x) - 3) in [ + [-3 + sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 + exp(3)/2, + -sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 - 3 + exp(3)/2], + [-3 + sqrt(-36 + (-exp(3) + 6)**2)/2 + exp(3)/2, + -3 - sqrt(-36 + (-exp(3) + 6)**2)/2 + exp(3)/2], + ] + assert solve(log(exp(2*x) + 1) + log(-tanh(x) + 1) - log(2)) == [] + + +def test_atan2(): + assert solve(atan2(x, 2) - pi/3, x) == [2*sqrt(3)] + + +def test_errorinverses(): + assert solve(erf(x) - y, x) == [erfinv(y)] + assert solve(erfinv(x) - y, x) == [erf(y)] + assert solve(erfc(x) - y, x) == [erfcinv(y)] + assert solve(erfcinv(x) - y, x) == [erfc(y)] + + +def test_issue_2725(): + R = Symbol('R') + eq = sqrt(2)*R*sqrt(1/(R + 1)) + (R + 1)*(sqrt(2)*sqrt(1/(R + 1)) - 1) + sol = solve(eq, R, set=True)[1] + assert sol == {(Rational(5, 3) + (Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3) + 40/(9*((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3))),), (Rational(5, 3) + 40/(9*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3)) + (Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3),)} + + +def test_issue_5114_6611(): + # See that it doesn't hang; this solves in about 2 seconds. + # Also check that the solution is relatively small. + # Note: the system in issue 6611 solves in about 5 seconds and has + # an op-count of 138336 (with simplify=False). + b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r = symbols('b:r') + eqs = Matrix([ + [b - c/d + r/d], [c*(1/g + 1/e + 1/d) - f/g - r/d], + [-c/g + f*(1/j + 1/i + 1/g) - h/i], [-f/i + h*(1/m + 1/l + 1/i) - k/m], + [-h/m + k*(1/p + 1/o + 1/m) - n/p], [-k/p + n*(1/q + 1/p)]]) + v = Matrix([f, h, k, n, b, c]) + ans = solve(list(eqs), list(v), simplify=False) + # If time is taken to simplify then then 2617 below becomes + # 1168 and the time is about 50 seconds instead of 2. + assert sum(s.count_ops() for s in ans.values()) <= 3270 + + +def test_det_quick(): + m = Matrix(3, 3, symbols('a:9')) + assert m.det() == det_quick(m) # calls det_perm + m[0, 0] = 1 + assert m.det() == det_quick(m) # calls det_minor + m = Matrix(3, 3, list(range(9))) + assert m.det() == det_quick(m) # defaults to .det() + # make sure they work with Sparse + s = SparseMatrix(2, 2, (1, 2, 1, 4)) + assert det_perm(s) == det_minor(s) == s.det() + + +def test_real_imag_splitting(): + a, b = symbols('a b', real=True) + assert solve(sqrt(a**2 + b**2) - 3, a) == \ + [-sqrt(-b**2 + 9), sqrt(-b**2 + 9)] + a, b = symbols('a b', imaginary=True) + assert solve(sqrt(a**2 + b**2) - 3, a) == [] + + +def test_issue_7110(): + y = -2*x**3 + 4*x**2 - 2*x + 5 + assert any(ask(Q.real(i)) for i in solve(y)) + + +def test_units(): + assert solve(1/x - 1/(2*cm)) == [2*cm] + + +def test_issue_7547(): + A, B, V = symbols('A,B,V') + eq1 = Eq(630.26*(V - 39.0)*V*(V + 39) - A + B, 0) + eq2 = Eq(B, 1.36*10**8*(V - 39)) + eq3 = Eq(A, 5.75*10**5*V*(V + 39.0)) + sol = Matrix(nsolve(Tuple(eq1, eq2, eq3), [A, B, V], (0, 0, 0))) + assert str(sol) == str(Matrix( + [['4442890172.68209'], + ['4289299466.1432'], + ['70.5389666628177']])) + + +def test_issue_7895(): + r = symbols('r', real=True) + assert solve(sqrt(r) - 2) == [4] + + +def test_issue_2777(): + # the equations represent two circles + x, y = symbols('x y', real=True) + e1, e2 = sqrt(x**2 + y**2) - 10, sqrt(y**2 + (-x + 10)**2) - 3 + a, b = Rational(191, 20), 3*sqrt(391)/20 + ans = [(a, -b), (a, b)] + assert solve((e1, e2), (x, y)) == ans + assert solve((e1, e2/(x - a)), (x, y)) == [] + # make the 2nd circle's radius be -3 + e2 += 6 + assert solve((e1, e2), (x, y)) == [] + assert solve((e1, e2), (x, y), check=False) == ans + + +def test_issue_7322(): + number = 5.62527e-35 + assert solve(x - number, x)[0] == number + + +def test_nsolve(): + raises(ValueError, lambda: nsolve(x, (-1, 1), method='bisect')) + raises(TypeError, lambda: nsolve((x - y + 3,x + y,z - y),(x,y,z),(-50,50))) + raises(TypeError, lambda: nsolve((x + y, x - y), (0, 1))) + raises(TypeError, lambda: nsolve(x < 0.5, x, 1)) + + +@slow +def test_high_order_multivariate(): + assert len(solve(a*x**3 - x + 1, x)) == 3 + assert len(solve(a*x**4 - x + 1, x)) == 4 + assert solve(a*x**5 - x + 1, x) == [] # incomplete solution allowed + raises(NotImplementedError, lambda: + solve(a*x**5 - x + 1, x, incomplete=False)) + + # result checking must always consider the denominator and CRootOf + # must be checked, too + d = x**5 - x + 1 + assert solve(d*(1 + 1/d)) == [CRootOf(d + 1, i) for i in range(5)] + d = x - 1 + assert solve(d*(2 + 1/d)) == [S.Half] + + +def test_base_0_exp_0(): + assert solve(0**x - 1) == [0] + assert solve(0**(x - 2) - 1) == [2] + assert solve(S('x*(1/x**0 - x)', evaluate=False)) == \ + [0, 1] + + +def test__simple_dens(): + assert _simple_dens(1/x**0, [x]) == set() + assert _simple_dens(1/x**y, [x]) == {x**y} + assert _simple_dens(1/root(x, 3), [x]) == {x} + + +def test_issue_8755(): + # This tests two things: that if full unrad is attempted and fails + # the solution should still be found; also it tests the use of + # keyword `composite`. + assert len(solve(sqrt(y)*x + x**3 - 1, x)) == 3 + assert len(solve(-512*y**3 + 1344*(x + 2)**Rational(1, 3)*y**2 - + 1176*(x + 2)**Rational(2, 3)*y - 169*x + 686, y, _unrad=False)) == 3 + + +@slow +def test_issue_8828(): + x1 = 0 + y1 = -620 + r1 = 920 + x2 = 126 + y2 = 276 + x3 = 51 + y3 = 205 + r3 = 104 + v = x, y, z + + f1 = (x - x1)**2 + (y - y1)**2 - (r1 - z)**2 + f2 = (x - x2)**2 + (y - y2)**2 - z**2 + f3 = (x - x3)**2 + (y - y3)**2 - (r3 - z)**2 + F = f1,f2,f3 + + g1 = sqrt((x - x1)**2 + (y - y1)**2) + z - r1 + g2 = f2 + g3 = sqrt((x - x3)**2 + (y - y3)**2) + z - r3 + G = g1,g2,g3 + + A = solve(F, v) + B = solve(G, v) + C = solve(G, v, manual=True) + + p, q, r = [{tuple(i.evalf(2) for i in j) for j in R} for R in [A, B, C]] + assert p == q == r + + +def test_issue_2840_8155(): + # with parameter-free solutions (i.e. no `n`), we want to avoid + # excessive periodic solutions + assert solve(sin(3*x) + sin(6*x)) == [0, -2*pi/9, 2*pi/9] + assert solve(sin(300*x) + sin(600*x)) == [0, -pi/450, pi/450] + assert solve(2*sin(x) - 2*sin(2*x)) == [0, -pi/3, pi/3] + + +def test_issue_9567(): + assert solve(1 + 1/(x - 1)) == [0] + + +def test_issue_11538(): + assert solve(x + E) == [-E] + assert solve(x**2 + E) == [-I*sqrt(E), I*sqrt(E)] + assert solve(x**3 + 2*E) == [ + -cbrt(2 * E), + cbrt(2)*cbrt(E)/2 - cbrt(2)*sqrt(3)*I*cbrt(E)/2, + cbrt(2)*cbrt(E)/2 + cbrt(2)*sqrt(3)*I*cbrt(E)/2] + assert solve([x + 4, y + E], x, y) == {x: -4, y: -E} + assert solve([x**2 + 4, y + E], x, y) == [ + (-2*I, -E), (2*I, -E)] + + e1 = x - y**3 + 4 + e2 = x + y + 4 + 4 * E + assert len(solve([e1, e2], x, y)) == 3 + + +@slow +def test_issue_12114(): + a, b, c, d, e, f, g = symbols('a,b,c,d,e,f,g') + terms = [1 + a*b + d*e, 1 + a*c + d*f, 1 + b*c + e*f, + g - a**2 - d**2, g - b**2 - e**2, g - c**2 - f**2] + sol = solve(terms, [a, b, c, d, e, f, g], dict=True) + s = sqrt(-f**2 - 1) + s2 = sqrt(2 - f**2) + s3 = sqrt(6 - 3*f**2) + s4 = sqrt(3)*f + s5 = sqrt(3)*s2 + assert sol == [ + {a: -s, b: -s, c: -s, d: f, e: f, g: -1}, + {a: s, b: s, c: s, d: f, e: f, g: -1}, + {a: -s4/2 - s2/2, b: s4/2 - s2/2, c: s2, + d: -f/2 + s3/2, e: -f/2 - s5/2, g: 2}, + {a: -s4/2 + s2/2, b: s4/2 + s2/2, c: -s2, + d: -f/2 - s3/2, e: -f/2 + s5/2, g: 2}, + {a: s4/2 - s2/2, b: -s4/2 - s2/2, c: s2, + d: -f/2 - s3/2, e: -f/2 + s5/2, g: 2}, + {a: s4/2 + s2/2, b: -s4/2 + s2/2, c: -s2, + d: -f/2 + s3/2, e: -f/2 - s5/2, g: 2}] + + +def test_inf(): + assert solve(1 - oo*x) == [] + assert solve(oo*x, x) == [] + assert solve(oo*x - oo, x) == [] + + +def test_issue_12448(): + f = Function('f') + fun = [f(i) for i in range(15)] + sym = symbols('x:15') + reps = dict(zip(fun, sym)) + + (x, y, z), c = sym[:3], sym[3:] + ssym = solve([c[4*i]*x + c[4*i + 1]*y + c[4*i + 2]*z + c[4*i + 3] + for i in range(3)], (x, y, z)) + + (x, y, z), c = fun[:3], fun[3:] + sfun = solve([c[4*i]*x + c[4*i + 1]*y + c[4*i + 2]*z + c[4*i + 3] + for i in range(3)], (x, y, z)) + + assert sfun[fun[0]].xreplace(reps).count_ops() == \ + ssym[sym[0]].count_ops() + + +def test_denoms(): + assert denoms(x/2 + 1/y) == {2, y} + assert denoms(x/2 + 1/y, y) == {y} + assert denoms(x/2 + 1/y, [y]) == {y} + assert denoms(1/x + 1/y + 1/z, [x, y]) == {x, y} + assert denoms(1/x + 1/y + 1/z, x, y) == {x, y} + assert denoms(1/x + 1/y + 1/z, {x, y}) == {x, y} + + +def test_issue_12476(): + x0, x1, x2, x3, x4, x5 = symbols('x0 x1 x2 x3 x4 x5') + eqns = [x0**2 - x0, x0*x1 - x1, x0*x2 - x2, x0*x3 - x3, x0*x4 - x4, x0*x5 - x5, + x0*x1 - x1, -x0/3 + x1**2 - 2*x2/3, x1*x2 - x1/3 - x2/3 - x3/3, + x1*x3 - x2/3 - x3/3 - x4/3, x1*x4 - 2*x3/3 - x5/3, x1*x5 - x4, x0*x2 - x2, + x1*x2 - x1/3 - x2/3 - x3/3, -x0/6 - x1/6 + x2**2 - x2/6 - x3/3 - x4/6, + -x1/6 + x2*x3 - x2/3 - x3/6 - x4/6 - x5/6, x2*x4 - x2/3 - x3/3 - x4/3, + x2*x5 - x3, x0*x3 - x3, x1*x3 - x2/3 - x3/3 - x4/3, + -x1/6 + x2*x3 - x2/3 - x3/6 - x4/6 - x5/6, + -x0/6 - x1/6 - x2/6 + x3**2 - x3/3 - x4/6, -x1/3 - x2/3 + x3*x4 - x3/3, + -x2 + x3*x5, x0*x4 - x4, x1*x4 - 2*x3/3 - x5/3, x2*x4 - x2/3 - x3/3 - x4/3, + -x1/3 - x2/3 + x3*x4 - x3/3, -x0/3 - 2*x2/3 + x4**2, -x1 + x4*x5, x0*x5 - x5, + x1*x5 - x4, x2*x5 - x3, -x2 + x3*x5, -x1 + x4*x5, -x0 + x5**2, x0 - 1] + sols = [{x0: 1, x3: Rational(1, 6), x2: Rational(1, 6), x4: Rational(-2, 3), x1: Rational(-2, 3), x5: 1}, + {x0: 1, x3: S.Half, x2: Rational(-1, 2), x4: 0, x1: 0, x5: -1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(-1, 3), x4: Rational(1, 3), x1: Rational(1, 3), x5: 1}, + {x0: 1, x3: 1, x2: 1, x4: 1, x1: 1, x5: 1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(1, 3), x4: sqrt(5)/3, x1: -sqrt(5)/3, x5: -1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(1, 3), x4: -sqrt(5)/3, x1: sqrt(5)/3, x5: -1}] + + assert solve(eqns) == sols + + +def test_issue_13849(): + t = symbols('t') + assert solve((t*(sqrt(5) + sqrt(2)) - sqrt(2), t), t) == [] + + +def test_issue_14860(): + from sympy.physics.units import newton, kilo + assert solve(8*kilo*newton + x + y, x) == [-8000*newton - y] + + +def test_issue_14721(): + k, h, a, b = symbols(':4') + assert solve([ + -1 + (-k + 1)**2/b**2 + (-h - 1)**2/a**2, + -1 + (-k + 1)**2/b**2 + (-h + 1)**2/a**2, + h, k + 2], h, k, a, b) == [ + (0, -2, -b*sqrt(1/(b**2 - 9)), b), + (0, -2, b*sqrt(1/(b**2 - 9)), b)] + assert solve([ + h, h/a + 1/b**2 - 2, -h/2 + 1/b**2 - 2], a, h, b) == [ + (a, 0, -sqrt(2)/2), (a, 0, sqrt(2)/2)] + assert solve((a + b**2 - 1, a + b**2 - 2)) == [] + + +def test_issue_14779(): + x = symbols('x', real=True) + assert solve(sqrt(x**4 - 130*x**2 + 1089) + sqrt(x**4 - 130*x**2 + + 3969) - 96*Abs(x)/x,x) == [sqrt(130)] + + +def test_issue_15307(): + assert solve((y - 2, Mul(x + 3,x - 2, evaluate=False))) == \ + [{x: -3, y: 2}, {x: 2, y: 2}] + assert solve((y - 2, Mul(3, x - 2, evaluate=False))) == \ + {x: 2, y: 2} + assert solve((y - 2, Add(x + 4, x - 2, evaluate=False))) == \ + {x: -1, y: 2} + eq1 = Eq(12513*x + 2*y - 219093, -5726*x - y) + eq2 = Eq(-2*x + 8, 2*x - 40) + assert solve([eq1, eq2]) == {x:12, y:75} + + +def test_issue_15415(): + assert solve(x - 3, x) == [3] + assert solve([x - 3], x) == {x:3} + assert solve(Eq(y + 3*x**2/2, y + 3*x), y) == [] + assert solve([Eq(y + 3*x**2/2, y + 3*x)], y) == [] + assert solve([Eq(y + 3*x**2/2, y + 3*x), Eq(x, 1)], y) == [] + + +@slow +def test_issue_15731(): + # f(x)**g(x)=c + assert solve(Eq((x**2 - 7*x + 11)**(x**2 - 13*x + 42), 1)) == [2, 3, 4, 5, 6, 7] + assert solve((x)**(x + 4) - 4) == [-2] + assert solve((-x)**(-x + 4) - 4) == [2] + assert solve((x**2 - 6)**(x**2 - 2) - 4) == [-2, 2] + assert solve((x**2 - 2*x - 1)**(x**2 - 3) - 1/(1 - 2*sqrt(2))) == [sqrt(2)] + assert solve(x**(x + S.Half) - 4*sqrt(2)) == [S(2)] + assert solve((x**2 + 1)**x - 25) == [2] + assert solve(x**(2/x) - 2) == [2, 4] + assert solve((x/2)**(2/x) - sqrt(2)) == [4, 8] + assert solve(x**(x + S.Half) - Rational(9, 4)) == [Rational(3, 2)] + # a**g(x)=c + assert solve((-sqrt(sqrt(2)))**x - 2) == [4, log(2)/(log(2**Rational(1, 4)) + I*pi)] + assert solve((sqrt(2))**x - sqrt(sqrt(2))) == [S.Half] + assert solve((-sqrt(2))**x + 2*(sqrt(2))) == [3, + (3*log(2)**2 + 4*pi**2 - 4*I*pi*log(2))/(log(2)**2 + 4*pi**2)] + assert solve((sqrt(2))**x - 2*(sqrt(2))) == [3] + assert solve(I**x + 1) == [2] + assert solve((1 + I)**x - 2*I) == [2] + assert solve((sqrt(2) + sqrt(3))**x - (2*sqrt(6) + 5)**Rational(1, 3)) == [Rational(2, 3)] + # bases of both sides are equal + b = Symbol('b') + assert solve(b**x - b**2, x) == [2] + assert solve(b**x - 1/b, x) == [-1] + assert solve(b**x - b, x) == [1] + b = Symbol('b', positive=True) + assert solve(b**x - b**2, x) == [2] + assert solve(b**x - 1/b, x) == [-1] + + +def test_issue_10933(): + assert solve(x**4 + y*(x + 0.1), x) # doesn't fail + assert solve(I*x**4 + x**3 + x**2 + 1.) # doesn't fail + + +def test_Abs_handling(): + x = symbols('x', real=True) + assert solve(abs(x/y), x) == [0] + + +def test_issue_7982(): + x = Symbol('x') + # Test that no exception happens + assert solve([2*x**2 + 5*x + 20 <= 0, x >= 1.5], x) is S.false + # From #8040 + assert solve([x**3 - 8.08*x**2 - 56.48*x/5 - 106 >= 0, x - 1 <= 0], [x]) is S.false + + +def test_issue_14645(): + x, y = symbols('x y') + assert solve([x*y - x - y, x*y - x - y], [x, y]) == [(y/(y - 1), y)] + + +def test_issue_12024(): + x, y = symbols('x y') + assert solve(Piecewise((0.0, x < 0.1), (x, x >= 0.1)) - y) == \ + [{y: Piecewise((0.0, x < 0.1), (x, True))}] + + +def test_issue_17452(): + assert solve((7**x)**x + pi, x) == [-sqrt(log(pi) + I*pi)/sqrt(log(7)), + sqrt(log(pi) + I*pi)/sqrt(log(7))] + assert solve(x**(x/11) + pi/11, x) == [exp(LambertW(-11*log(11) + 11*log(pi) + 11*I*pi))] + + +def test_issue_17799(): + assert solve(-erf(x**(S(1)/3))**pi + I, x) == [] + + +def test_issue_17650(): + x = Symbol('x', real=True) + assert solve(abs(abs(x**2 - 1) - x) - x) == [1, -1 + sqrt(2), 1 + sqrt(2)] + + +def test_issue_17882(): + eq = -8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)) + assert unrad(eq) is None + + +def test_issue_17949(): + assert solve(exp(+x+x**2), x) == [] + assert solve(exp(-x+x**2), x) == [] + assert solve(exp(+x-x**2), x) == [] + assert solve(exp(-x-x**2), x) == [] + + +def test_issue_10993(): + assert solve(Eq(binomial(x, 2), 3)) == [-2, 3] + assert solve(Eq(pow(x, 2) + binomial(x, 3), x)) == [-4, 0, 1] + assert solve(Eq(binomial(x, 2), 0)) == [0, 1] + assert solve(a+binomial(x, 3), a) == [-binomial(x, 3)] + assert solve(x-binomial(a, 3) + binomial(y, 2) + sin(a), x) == [-sin(a) + binomial(a, 3) - binomial(y, 2)] + assert solve((x+1)-binomial(x+1, 3), x) == [-2, -1, 3] + + +def test_issue_11553(): + eq1 = x + y + 1 + eq2 = x + GoldenRatio + assert solve([eq1, eq2], x, y) == {x: -GoldenRatio, y: -1 + GoldenRatio} + eq3 = x + 2 + TribonacciConstant + assert solve([eq1, eq3], x, y) == {x: -2 - TribonacciConstant, y: 1 + TribonacciConstant} + + +def test_issue_19113_19102(): + t = S(1)/3 + solve(cos(x)**5-sin(x)**5) + assert solve(4*cos(x)**3 - 2*sin(x)**3) == [ + atan(2**(t)), -atan(2**(t)*(1 - sqrt(3)*I)/2), + -atan(2**(t)*(1 + sqrt(3)*I)/2)] + h = S.Half + assert solve(cos(x)**2 + sin(x)) == [ + 2*atan(-h + sqrt(5)/2 + sqrt(2)*sqrt(1 - sqrt(5))/2), + -2*atan(h + sqrt(5)/2 + sqrt(2)*sqrt(1 + sqrt(5))/2), + -2*atan(-sqrt(5)/2 + h + sqrt(2)*sqrt(1 - sqrt(5))/2), + -2*atan(-sqrt(2)*sqrt(1 + sqrt(5))/2 + h + sqrt(5)/2)] + assert solve(3*cos(x) - sin(x)) == [atan(3)] + + +def test_issue_19509(): + a = S(3)/4 + b = S(5)/8 + c = sqrt(5)/8 + d = sqrt(5)/4 + assert solve(1/(x -1)**5 - 1) == [2, + -d + a - sqrt(-b + c), + -d + a + sqrt(-b + c), + d + a - sqrt(-b - c), + d + a + sqrt(-b - c)] + +def test_issue_20747(): + THT, HT, DBH, dib, c0, c1, c2, c3, c4 = symbols('THT HT DBH dib c0 c1 c2 c3 c4') + f = DBH*c3 + THT*c4 + c2 + rhs = 1 - ((HT - 1)/(THT - 1))**c1*(1 - exp(c0/f)) + eq = dib - DBH*(c0 - f*log(rhs)) + term = ((1 - exp((DBH*c0 - dib)/(DBH*(DBH*c3 + THT*c4 + c2)))) + / (1 - exp(c0/(DBH*c3 + THT*c4 + c2)))) + sol = [THT*term**(1/c1) - term**(1/c1) + 1] + assert solve(eq, HT) == sol + + +def test_issue_27001(): + assert solve((x, x**2), (x, y, z), dict=True) == [{x: 0}] + s = a1, a2, a3, a4, a5 = symbols('a1:6') + eqs = [8*a1**4*a2 + 4*a1**2*a2**3 - 8*a1**2*a2*a4 + a2**5/2 - 2*a2**3*a4 + + 8*a2*a3**2 + 2*a2*a4**2 + 8*a2*a5, 12*a1**4 + 6*a1**2*a2**2 - + 8*a1**2*a4 + 3*a2**4/4 - 2*a2**2*a4 + 4*a3**2 + a4**2 + 4*a5, 16*a1**3 + + 4*a1*a2**2 - 8*a1*a4, -8*a1**2*a2 - 2*a2**3 + 4*a2*a4] + sol = [{a4: 2*a1**2 + a2**2/2, a5: -a3**2}, {a1: 0, a2: 0, a5: -a3**2 - a4**2/4}] + assert solve(eqs, s, dict=True) == sol + assert (g:=solve(groebner(eqs, s), dict=True)) == sol, g + + +def test_issue_20902(): + f = (t / ((1 + t) ** 2)) + assert solve(f.subs({t: 3 * x + 2}).diff(x) > 0, x) == (S(-1) < x) & (x < S(-1)/3) + assert solve(f.subs({t: 3 * x + 3}).diff(x) > 0, x) == (S(-4)/3 < x) & (x < S(-2)/3) + assert solve(f.subs({t: 3 * x + 4}).diff(x) > 0, x) == (S(-5)/3 < x) & (x < S(-1)) + assert solve(f.subs({t: 3 * x + 2}).diff(x) > 0, x) == (S(-1) < x) & (x < S(-1)/3) + + +def test_issue_21034(): + a = symbols('a', real=True) + system = [x - cosh(cos(4)), y - sinh(cos(a)), z - tanh(x)] + # constants inside hyperbolic functions should not be rewritten in terms of exp + assert solve(system, x, y, z) == [(cosh(cos(4)), sinh(cos(a)), tanh(cosh(cos(4))))] + # but if the variable of interest is present in a hyperbolic function, + # then it should be rewritten in terms of exp and solved further + newsystem = [(exp(x) - exp(-x)) - tanh(x)*(exp(x) + exp(-x)) + x - 5] + assert solve(newsystem, x) == {x: 5} + + +def test_issue_4886(): + z = a*sqrt(R**2*a**2 + R**2*b**2 - c**2)/(a**2 + b**2) + t = b*c/(a**2 + b**2) + sol = [((b*(t - z) - c)/(-a), t - z), ((b*(t + z) - c)/(-a), t + z)] + assert solve([x**2 + y**2 - R**2, a*x + b*y - c], x, y) == sol + + +def test_issue_6819(): + a, b, c, d = symbols('a b c d', positive=True) + assert solve(a*b**x - c*d**x, x) == [log(c/a)/log(b/d)] + + +def test_issue_17454(): + x = Symbol('x') + assert solve((1 - x - I)**4, x) == [1 - I] + + +def test_issue_21852(): + solution = [21 - 21*sqrt(2)/2] + assert solve(2*x + sqrt(2*x**2) - 21) == solution + + +def test_issue_21942(): + eq = -d + (a*c**(1 - e) + b**(1 - e)*(1 - a))**(1/(1 - e)) + sol = solve(eq, c, simplify=False, check=False) + assert sol == [((a*b**(1 - e) - b**(1 - e) + + d**(1 - e))/a)**(1/(1 - e))] + + +def test_solver_flags(): + root = solve(x**5 + x**2 - x - 1, cubics=False) + rad = solve(x**5 + x**2 - x - 1, cubics=True) + assert root != rad + + +def test_issue_22768(): + eq = 2*x**3 - 16*(y - 1)**6*z**3 + assert solve(eq.expand(), x, simplify=False + ) == [2*z*(y - 1)**2, z*(-1 + sqrt(3)*I)*(y - 1)**2, + -z*(1 + sqrt(3)*I)*(y - 1)**2] + + +def test_issue_22717(): + assert solve((-y**2 + log(y**2/x) + 2, -2*x*y + 2*x/y)) == [ + {y: -1, x: E}, {y: 1, x: E}] + + +def test_issue_25176(): + eq = (x - 5)**-8 - 3 + sol = solve(eq) + assert not any(eq.subs(x, i) for i in sol) + + +def test_issue_10169(): + eq = S(-8*a - x**5*(a + b + c + e) - x**4*(4*a - 2**Rational(3,4)*c + 4*c + + d + 2**Rational(3,4)*e + 4*e + k) - x**3*(-4*2**Rational(3,4)*c + sqrt(2)*c - + 2**Rational(3,4)*d + 4*d + sqrt(2)*e + 4*2**Rational(3,4)*e + 2**Rational(3,4)*k + 4*k) - + x**2*(4*sqrt(2)*c - 4*2**Rational(3,4)*d + sqrt(2)*d + 4*sqrt(2)*e + + sqrt(2)*k + 4*2**Rational(3,4)*k) - x*(2*a + 2*b + 4*sqrt(2)*d + + 4*sqrt(2)*k) + 5) + assert solve_undetermined_coeffs(eq, [a, b, c, d, e, k], x) == { + a: Rational(5,8), + b: Rational(-5,1032), + c: Rational(-40,129) - 5*2**Rational(3,4)/129 + 5*2**Rational(1,4)/1032, + d: -20*2**Rational(3,4)/129 - 10*sqrt(2)/129 - 5*2**Rational(1,4)/258, + e: Rational(-40,129) - 5*2**Rational(1,4)/1032 + 5*2**Rational(3,4)/129, + k: -10*sqrt(2)/129 + 5*2**Rational(1,4)/258 + 20*2**Rational(3,4)/129 + } + + +def test_solve_undetermined_coeffs_issue_23927(): + A, B, r, phi = symbols('A, B, r, phi') + e = Eq(A*sin(t) + B*cos(t), r*sin(t - phi)) + eq = (e.lhs - e.rhs).expand(trig=True) + soln = solve_undetermined_coeffs(eq, (r, phi), t) + assert soln == [{ + phi: 2*atan((A - sqrt(A**2 + B**2))/B), + r: (-A**2 + A*sqrt(A**2 + B**2) - B**2)/(A - sqrt(A**2 + B**2)) + }, { + phi: 2*atan((A + sqrt(A**2 + B**2))/B), + r: (A**2 + A*sqrt(A**2 + B**2) + B**2)/(A + sqrt(A**2 + B**2))/-1 + }] + +def test_issue_24368(): + # Ideally these would produce a solution, but for now just check that they + # don't fail with a RuntimeError + raises(NotImplementedError, lambda: solve(Mod(x**2, 49), x)) + s2 = Symbol('s2', integer=True, positive=True) + f = floor(s2/2 - S(1)/2) + raises(NotImplementedError, lambda: solve((Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1))*f + Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1), s2)) + + +def test_solve_Piecewise(): + assert [S(10)/3] == solve(3*Piecewise( + (S.NaN, x <= 0), + (20*x - 3*(x - 6)**2/2 - 176, (x >= 0) & (x >= 2) & (x>= 4) & (x >= 6) & (x < 10)), + (100 - 26*x, (x >= 0) & (x >= 2) & (x >= 4) & (x < 10)), + (16*x - 3*(x - 6)**2/2 - 176, (x >= 2) & (x >= 4) & (x >= 6) & (x < 10)), + (100 - 30*x, (x >= 2) & (x >= 4) & (x < 10)), + (30*x - 3*(x - 6)**2/2 - 196, (x>= 0) & (x >= 4) & (x >= 6) & (x < 10)), + (80 - 16*x, (x >= 0) & (x >= 4) & (x < 10)), + (26*x - 3*(x - 6)**2/2 - 196, (x >= 4) & (x >= 6) & (x < 10)), + (80 - 20*x, (x >= 4) & (x < 10)), + (40*x - 3*(x - 6)**2/2 - 256, (x >= 0) & (x >= 2) & (x >= 6) & (x < 10)), + (20 - 6*x, (x >= 0) & (x >= 2) & (x < 10)), + (36*x - 3*(x - 6)**2/2 - 256, (x >= 2) & (x >= 6) & (x < 10)), + (20 - 10*x, (x >= 2) & (x < 10)), + (50*x - 3*(x - 6)**2/2 - 276, (x >= 0) & (x >= 6) & (x < 10)), + (4*x, (x >= 0) & (x < 10)), + (46*x - 3*(x - 6)**2/2 - 276, (x >= 6) & (x < 10)), + (0, x < 10), # this will simplify away + (S.NaN,True))) diff --git a/phivenv/Lib/site-packages/sympy/solvers/tests/test_solveset.py b/phivenv/Lib/site-packages/sympy/solvers/tests/test_solveset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ba7a11e68ed518c4d83c050947b78756ade181 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/solvers/tests/test_solveset.py @@ -0,0 +1,3548 @@ +from math import isclose + +from sympy.calculus.util import stationary_points +from sympy.core.containers import Tuple +from sympy.core.function import (Function, Lambda, nfloat, diff) +from sympy.core.mod import Mod +from sympy.core.numbers import (E, I, Rational, oo, pi, Integer, all_close) +from sympy.core.relational import (Eq, Gt, Ne, Ge) +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import (Abs, arg, im, re, sign, conjugate) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, + sinh, cosh, tanh, coth, sech, csch, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.elementary.miscellaneous import sqrt, Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import ( + TrigonometricFunction, acos, acot, acsc, asec, asin, atan, atan2, + cos, cot, csc, sec, sin, tan) +from sympy.functions.special.error_functions import (erf, erfc, + erfcinv, erfinv) +from sympy.logic.boolalg import And +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.sets.contains import Contains +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import ImageSet, Range +from sympy.sets.sets import (Complement, FiniteSet, + Intersection, Interval, Union, imageset, ProductSet) +from sympy.simplify import simplify +from sympy.tensor.indexed import Indexed +from sympy.utilities.iterables import numbered_symbols + +from sympy.testing.pytest import (XFAIL, raises, skip, slow, SKIP, _both_exp_pow) +from sympy.core.random import verify_numerically as tn +from sympy.physics.units import cm + +from sympy.solvers import solve +from sympy.solvers.solveset import ( + solveset_real, domain_check, solveset_complex, linear_eq_to_matrix, + linsolve, _is_function_class_equation, invert_real, invert_complex, + _invert_trig_hyp_real, solveset, solve_decomposition, substitution, + nonlinsolve, solvify, + _is_finite_with_finite_vars, _transolve, _is_exponential, + _solve_exponential, _is_logarithmic, _is_lambert, + _solve_logarithm, _term_factors, _is_modular, NonlinearError) + +from sympy.abc import (a, b, c, d, e, f, g, h, i, j, k, l, m, n, q, r, + t, w, x, y, z) + + +def dumeq(i, j): + if type(i) in (list, tuple): + return all(dumeq(i, j) for i, j in zip(i, j)) + return i == j or i.dummy_eq(j) + + +def assert_close_ss(sol1, sol2): + """Test solutions with floats from solveset are close""" + sol1 = sympify(sol1) + sol2 = sympify(sol2) + assert isinstance(sol1, FiniteSet) + assert isinstance(sol2, FiniteSet) + assert len(sol1) == len(sol2) + assert all(isclose(v1, v2) for v1, v2 in zip(sol1, sol2)) + + +def assert_close_nl(sol1, sol2): + """Test solutions with floats from nonlinsolve are close""" + sol1 = sympify(sol1) + sol2 = sympify(sol2) + assert isinstance(sol1, FiniteSet) + assert isinstance(sol2, FiniteSet) + assert len(sol1) == len(sol2) + for s1, s2 in zip(sol1, sol2): + assert len(s1) == len(s2) + assert all(isclose(v1, v2) for v1, v2 in zip(s1, s2)) + + +@_both_exp_pow +def test_invert_real(): + x = Symbol('x', real=True) + + def ireal(x, s=S.Reals): + return Intersection(s, x) + + assert invert_real(exp(x), z, x) == (x, ireal(FiniteSet(log(z)))) + + y = Symbol('y', positive=True) + n = Symbol('n', real=True) + assert invert_real(x + 3, y, x) == (x, FiniteSet(y - 3)) + assert invert_real(x*3, y, x) == (x, FiniteSet(y / 3)) + + assert invert_real(exp(x), y, x) == (x, FiniteSet(log(y))) + assert invert_real(exp(3*x), y, x) == (x, FiniteSet(log(y) / 3)) + assert invert_real(exp(x + 3), y, x) == (x, FiniteSet(log(y) - 3)) + + assert invert_real(exp(x) + 3, y, x) == (x, ireal(FiniteSet(log(y - 3)))) + assert invert_real(exp(x)*3, y, x) == (x, FiniteSet(log(y / 3))) + + assert invert_real(log(x), y, x) == (x, FiniteSet(exp(y))) + assert invert_real(log(3*x), y, x) == (x, FiniteSet(exp(y) / 3)) + assert invert_real(log(x + 3), y, x) == (x, FiniteSet(exp(y) - 3)) + + assert invert_real(Abs(x), y, x) == (x, FiniteSet(y, -y)) + + assert invert_real(2**x, y, x) == (x, FiniteSet(log(y)/log(2))) + assert invert_real(2**exp(x), y, x) == (x, ireal(FiniteSet(log(log(y)/log(2))))) + + assert invert_real(x**2, y, x) == (x, FiniteSet(sqrt(y), -sqrt(y))) + assert invert_real(x**S.Half, y, x) == (x, FiniteSet(y**2)) + + raises(ValueError, lambda: invert_real(x, x, x)) + + # issue 21236 + assert invert_real(x**pi, y, x) == (x, FiniteSet(y**(1/pi))) + assert invert_real(x**pi, -E, x) == (x, S.EmptySet) + assert invert_real(x**Rational(3/2), 1000, x) == (x, FiniteSet(100)) + assert invert_real(x**1.0, 1, x) == (x**1.0, FiniteSet(1)) + + raises(ValueError, lambda: invert_real(S.One, y, x)) + + assert invert_real(x**31 + x, y, x) == (x**31 + x, FiniteSet(y)) + + lhs = x**31 + x + base_values = FiniteSet(y - 1, -y - 1) + assert invert_real(Abs(x**31 + x + 1), y, x) == (lhs, base_values) + + assert dumeq(invert_real(sin(x), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + asin(y)), S.Integers), + ImageSet(Lambda(n, pi*2*n + pi - asin(y)), S.Integers))))) + + assert dumeq(invert_real(sin(exp(x)), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, log(2*n*pi + asin(y))), S.Integers), + ImageSet(Lambda(n, log(pi*2*n + pi - asin(y))), S.Integers))))) + + assert dumeq(invert_real(csc(x), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, 2*n*pi + acsc(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - acsc(y) + pi), S.Integers))))) + + assert dumeq(invert_real(csc(exp(x)), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, log(2*n*pi + acsc(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi - acsc(y) + pi)), S.Integers))))) + + assert dumeq(invert_real(cos(x), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + acos(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - acos(y)), S.Integers))))) + + assert dumeq(invert_real(cos(exp(x)), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, log(2*n*pi + acos(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi - acos(y))), S.Integers))))) + + assert dumeq(invert_real(sec(x), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, 2*n*pi + asec(y)), S.Integers), \ + ImageSet(Lambda(n, 2*n*pi - asec(y)), S.Integers))))) + + assert dumeq(invert_real(sec(exp(x)), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, log(2*n*pi - asec(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi + asec(y))), S.Integers))))) + + assert dumeq(invert_real(tan(x), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + atan(y)), S.Integers)))) + + assert dumeq(invert_real(tan(exp(x)), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, log(n*pi + atan(y))), S.Integers)))) + + assert dumeq(invert_real(cot(x), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + acot(y)), S.Integers)))) + + assert dumeq(invert_real(cot(exp(x)), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, log(n*pi + acot(y))), S.Integers)))) + + assert dumeq(invert_real(tan(tan(x)), y, x), + (x, ConditionSet(x, Eq(tan(tan(x)), y), S.Reals))) + # slight regression compared to previous result: + # (tan(x), imageset(Lambda(n, n*pi + atan(y)), S.Integers))) + + x = Symbol('x', positive=True) + assert invert_real(x**pi, y, x) == (x, FiniteSet(y**(1/pi))) + + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + assert invert_real(sinh(x), r, x) == (x, FiniteSet(asinh(r))) + assert invert_real(sinh(log(x)), p, x) == (x, FiniteSet(exp(asinh(p)))) + + assert invert_real(cosh(x), r, x) == (x, Intersection( + FiniteSet(-acosh(r), acosh(r)), S.Reals)) + assert invert_real(cosh(x), p + 1, x) == (x, + FiniteSet(-acosh(p + 1), acosh(p + 1))) + + assert invert_real(tanh(x), r, x) == (x, Intersection(FiniteSet(atanh(r)), S.Reals)) + assert invert_real(coth(x), p+1, x) == (x, FiniteSet(acoth(p+1))) + assert invert_real(sech(x), r, x) == (x, Intersection( + FiniteSet(-asech(r), asech(r)), S.Reals)) + assert invert_real(csch(x), p, x) == (x, FiniteSet(acsch(p))) + + assert dumeq(invert_real(tanh(sin(x)), r, x), (x, + ConditionSet(x, (S(-1) <= atanh(r)) & (atanh(r) <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + asin(atanh(r))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - asin(atanh(r)) + pi), S.Integers))))) + + +def test_invert_trig_hyp_real(): + # check some codepaths that are not as easily reached otherwise + n = Dummy('n') + assert _invert_trig_hyp_real(cosh(x), Range(-5, 10, 1), x)[1].dummy_eq(Union( + ImageSet(Lambda(n, -acosh(n)), Range(1, 10, 1)), + ImageSet(Lambda(n, acosh(n)), Range(1, 10, 1)))) + assert _invert_trig_hyp_real(coth(x), Interval(-3, 2), x) == (x, Union( + Interval(-oo, -acoth(3)), Interval(acoth(2), oo))) + assert _invert_trig_hyp_real(tanh(x), Interval(-S.Half, 1), x) == (x, + Interval(-atanh(S.Half), oo)) + assert _invert_trig_hyp_real(sech(x), imageset(n, S.Half + n/3, S.Naturals0), x) == \ + (x, FiniteSet(-asech(S(1)/2), asech(S(1)/2), -asech(S(5)/6), asech(S(5)/6))) + assert _invert_trig_hyp_real(csch(x), S.Reals, x) == (x, + Union(Interval.open(-oo, 0), Interval.open(0, oo))) + + +def test_invert_complex(): + assert invert_complex(x + 3, y, x) == (x, FiniteSet(y - 3)) + assert invert_complex(x*3, y, x) == (x, FiniteSet(y / 3)) + assert invert_complex((x - 1)**3, 0, x) == (x, FiniteSet(1)) + + assert dumeq(invert_complex(exp(x), y, x), + (x, imageset(Lambda(n, I*(2*pi*n + arg(y)) + log(Abs(y))), S.Integers))) + + assert invert_complex(log(x), y, x) == (x, FiniteSet(exp(y))) + + raises(ValueError, lambda: invert_real(1, y, x)) + raises(ValueError, lambda: invert_complex(x, x, x)) + raises(ValueError, lambda: invert_complex(x, x, 1)) + + assert dumeq(invert_complex(sin(x), I, x), (x, Union( + ImageSet(Lambda(n, 2*n*pi + I*log(1 + sqrt(2))), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi - I*log(1 + sqrt(2))), S.Integers)))) + assert dumeq(invert_complex(cos(x), 1+I, x), (x, Union( + ImageSet(Lambda(n, 2*n*pi - acos(1 + I)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(1 + I)), S.Integers)))) + assert dumeq(invert_complex(tan(2*x), 1, x), (x, + ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers))) + assert dumeq(invert_complex(cot(x), 2*I, x), (x, + ImageSet(Lambda(n, n*pi - I*acoth(2)), S.Integers))) + + assert dumeq(invert_complex(sinh(x), 0, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + I*pi), S.Integers)))) + assert dumeq(invert_complex(cosh(x), 0, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi + I*pi/2), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + 3*I*pi/2), S.Integers)))) + assert invert_complex(tanh(x), 1, x) == (x, S.EmptySet) + assert dumeq(invert_complex(tanh(x), a, x), (x, + ConditionSet(x, Ne(a, -1) & Ne(a, 1), + ImageSet(Lambda(n, n*I*pi + atanh(a)), S.Integers)))) + assert invert_complex(coth(x), 1, x) == (x, S.EmptySet) + assert dumeq(invert_complex(coth(x), a, x), (x, + ConditionSet(x, Ne(a, -1) & Ne(a, 1), + ImageSet(Lambda(n, n*I*pi + acoth(a)), S.Integers)))) + assert dumeq(invert_complex(sech(x), 2, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi + I*pi/3), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + 5*I*pi/3), S.Integers)))) + + +def test_domain_check(): + assert domain_check(1/(1 + (1/(x+1))**2), x, -1) is False + assert domain_check(x**2, x, 0) is True + assert domain_check(x, x, oo) is False + assert domain_check(0, x, oo) is False + + +def test_issue_11536(): + assert solveset(0**x - 100, x, S.Reals) == S.EmptySet + assert solveset(0**x - 1, x, S.Reals) == FiniteSet(0) + + +def test_issue_17479(): + f = (x**2 + y**2)**2 + (x**2 + z**2)**2 - 2*(2*x**2 + y**2 + z**2) + fx = f.diff(x) + fy = f.diff(y) + fz = f.diff(z) + sol = nonlinsolve([fx, fy, fz], [x, y, z]) + assert len(sol) >= 4 and len(sol) <= 20 + # nonlinsolve has been giving a varying number of solutions + # (originally 18, then 20, now 19) due to various internal changes. + # Unfortunately not all the solutions are actually valid and some are + # redundant. Since the original issue was that an exception was raised, + # this first test only checks that nonlinsolve returns a "plausible" + # solution set. The next test checks the result for correctness. + + +@XFAIL +def test_issue_18449(): + x, y, z = symbols("x, y, z") + f = (x**2 + y**2)**2 + (x**2 + z**2)**2 - 2*(2*x**2 + y**2 + z**2) + fx = diff(f, x) + fy = diff(f, y) + fz = diff(f, z) + sol = nonlinsolve([fx, fy, fz], [x, y, z]) + for (xs, ys, zs) in sol: + d = {x: xs, y: ys, z: zs} + assert tuple(_.subs(d).simplify() for _ in (fx, fy, fz)) == (0, 0, 0) + # After simplification and removal of duplicate elements, there should + # only be 4 parametric solutions left: + # simplifiedsolutions = FiniteSet((sqrt(1 - z**2), z, z), + # (-sqrt(1 - z**2), z, z), + # (sqrt(1 - z**2), -z, z), + # (-sqrt(1 - z**2), -z, z)) + # TODO: Is the above solution set definitely complete? + + +def test_issue_21047(): + f = (2 - x)**2 + (sqrt(x - 1) - 1)**6 + assert solveset(f, x, S.Reals) == FiniteSet(2) + + f = (sqrt(x)-1)**2 + (sqrt(x)+1)**2 -2*x**2 + sqrt(2) + assert solveset(f, x, S.Reals) == FiniteSet( + S.Half - sqrt(2*sqrt(2) + 5)/2, S.Half + sqrt(2*sqrt(2) + 5)/2) + + +def test_is_function_class_equation(): + assert _is_function_class_equation(TrigonometricFunction, + tan(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + sin(x) - a, x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x + a) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x*a) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + a*tan(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x)**2 + sin(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + x, x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x**2), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x**2) + sin(x), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x)**sin(x), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(sin(x)) + sin(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + sinh(x) - a, x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x + a) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x*a) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + a*tanh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x)**2 + sinh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + x, x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x**2), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x**2) + sinh(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x)**sinh(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(sinh(x)) + sinh(x), x) is False + + +def test_garbage_input(): + raises(ValueError, lambda: solveset_real([y], y)) + x = Symbol('x', real=True) + assert solveset_real(x, 1) == S.EmptySet + assert solveset_real(x - 1, 1) == FiniteSet(x) + assert solveset_real(x, pi) == S.EmptySet + assert solveset_real(x, x**2) == S.EmptySet + + raises(ValueError, lambda: solveset_complex([x], x)) + assert solveset_complex(x, pi) == S.EmptySet + + raises(ValueError, lambda: solveset((x, y), x)) + raises(ValueError, lambda: solveset(x + 1, S.Reals)) + raises(ValueError, lambda: solveset(x + 1, x, 2)) + + +def test_solve_mul(): + assert solveset_real((a*x + b)*(exp(x) - 3), x) == \ + Union({log(3)}, Intersection({-b/a}, S.Reals)) + anz = Symbol('anz', nonzero=True) + bb = Symbol('bb', real=True) + assert solveset_real((anz*x + bb)*(exp(x) - 3), x) == \ + FiniteSet(-bb/anz, log(3)) + assert solveset_real((2*x + 8)*(8 + exp(x)), x) == FiniteSet(S(-4)) + assert solveset_real(x/log(x), x) is S.EmptySet + + +def test_solve_invert(): + assert solveset_real(exp(x) - 3, x) == FiniteSet(log(3)) + assert solveset_real(log(x) - 3, x) == FiniteSet(exp(3)) + + assert solveset_real(3**(x + 2), x) == FiniteSet() + assert solveset_real(3**(2 - x), x) == FiniteSet() + + assert solveset_real(y - b*exp(a/x), x) == Intersection( + S.Reals, FiniteSet(a/log(y/b))) + + # issue 4504 + assert solveset_real(2**x - 10, x) == FiniteSet(1 + log(5)/log(2)) + + +def test_issue_25768(): + assert dumeq(solveset_real(sin(x) - S.Half, x), Union( + ImageSet(Lambda(n, pi*2*n + pi/6), S.Integers), + ImageSet(Lambda(n, pi*2*n + pi*5/6), S.Integers))) + n1 = solveset_real(sin(x) - 0.5, x).n(5) + n2 = solveset_real(sin(x) - S.Half, x).n(5) + # help pass despite fp differences + eq = [i.replace( + lambda x:x.is_Float, + lambda x:Rational(x).limit_denominator(1000)) for i in (n1, n2)] + assert dumeq(*eq),(n1,n2) + + +def test_errorinverses(): + assert solveset_real(erf(x) - S.Half, x) == \ + FiniteSet(erfinv(S.Half)) + assert solveset_real(erfinv(x) - 2, x) == \ + FiniteSet(erf(2)) + assert solveset_real(erfc(x) - S.One, x) == \ + FiniteSet(erfcinv(S.One)) + assert solveset_real(erfcinv(x) - 2, x) == FiniteSet(erfc(2)) + + +def test_solve_polynomial(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset_real(3*x - 2, x) == FiniteSet(Rational(2, 3)) + + assert solveset_real(x**2 - 1, x) == FiniteSet(-S.One, S.One) + assert solveset_real(x - y**3, x) == FiniteSet(y ** 3) + + assert solveset_real(x**3 - 15*x - 4, x) == FiniteSet( + -2 + 3 ** S.Half, + S(4), + -2 - 3 ** S.Half) + + assert solveset_real(sqrt(x) - 1, x) == FiniteSet(1) + assert solveset_real(sqrt(x) - 2, x) == FiniteSet(4) + assert solveset_real(x**Rational(1, 4) - 2, x) == FiniteSet(16) + assert solveset_real(x**Rational(1, 3) - 3, x) == FiniteSet(27) + assert len(solveset_real(x**5 + x**3 + 1, x)) == 1 + assert len(solveset_real(-2*x**3 + 4*x**2 - 2*x + 6, x)) > 0 + assert solveset_real(x**6 + x**4 + I, x) is S.EmptySet + + +def test_return_root_of(): + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = list(solveset_complex(f, x)) + for root in s: + assert root.func == CRootOf + + # if one uses solve to get the roots of a polynomial that has a CRootOf + # solution, make sure that the use of nfloat during the solve process + # doesn't fail. Note: if you want numerical solutions to a polynomial + # it is *much* faster to use nroots to get them than to solve the + # equation only to get CRootOf solutions which are then numerically + # evaluated. So for eq = x**5 + 3*x + 7 do Poly(eq).nroots() rather + # than [i.n() for i in solve(eq)] to get the numerical roots of eq. + assert nfloat(list(solveset_complex(x**5 + 3*x**3 + 7, x))[0], + exponent=False) == CRootOf(x**5 + 3*x**3 + 7, 0).n() + + sol = list(solveset_complex(x**6 - 2*x + 2, x)) + assert all(isinstance(i, CRootOf) for i in sol) and len(sol) == 6 + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = list(solveset_complex(f, x)) + for root in s: + assert root.func == CRootOf + + s = x**5 + 4*x**3 + 3*x**2 + Rational(7, 4) + assert solveset_complex(s, x) == \ + FiniteSet(*Poly(s*4, domain='ZZ').all_roots()) + + # Refer issue #7876 + eq = x*(x - 1)**2*(x + 1)*(x**6 - x + 1) + assert solveset_complex(eq, x) == \ + FiniteSet(-1, 0, 1, CRootOf(x**6 - x + 1, 0), + CRootOf(x**6 - x + 1, 1), + CRootOf(x**6 - x + 1, 2), + CRootOf(x**6 - x + 1, 3), + CRootOf(x**6 - x + 1, 4), + CRootOf(x**6 - x + 1, 5)) + + +def test_solveset_sqrt_1(): + assert solveset_real(sqrt(5*x + 6) - 2 - x, x) == \ + FiniteSet(-S.One, S(2)) + assert solveset_real(sqrt(x - 1) - x + 7, x) == FiniteSet(10) + assert solveset_real(sqrt(x - 2) - 5, x) == FiniteSet(27) + assert solveset_real(sqrt(x) - 2 - 5, x) == FiniteSet(49) + assert solveset_real(sqrt(x**3), x) == FiniteSet(0) + assert solveset_real(sqrt(x - 1), x) == FiniteSet(1) + assert solveset_real(sqrt((x-3)/x), x) == FiniteSet(3) + assert solveset_real(sqrt((x-3)/x)-Rational(1, 2), x) == \ + FiniteSet(4) + +def test_solveset_sqrt_2(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + # http://tutorial.math.lamar.edu/Classes/Alg/SolveRadicalEqns.aspx#Solve_Rad_Ex2_a + assert solveset_real(sqrt(2*x - 1) - sqrt(x - 4) - 2, x) == \ + FiniteSet(S(5), S(13)) + assert solveset_real(sqrt(x + 7) + 2 - sqrt(3 - x), x) == \ + FiniteSet(-6) + + # http://www.purplemath.com/modules/solverad.htm + assert solveset_real(sqrt(17*x - sqrt(x**2 - 5)) - 7, x) == \ + FiniteSet(3) + + eq = x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4) + assert solveset_real(eq, x) == FiniteSet(Rational(-1, 2), Rational(-1, 3)) + + eq = sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4) + assert solveset_real(eq, x) == FiniteSet(0) + + eq = sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1) + assert solveset_real(eq, x) == FiniteSet(5) + + eq = sqrt(x)*sqrt(x - 7) - 12 + assert solveset_real(eq, x) == FiniteSet(16) + + eq = sqrt(x - 3) + sqrt(x) - 3 + assert solveset_real(eq, x) == FiniteSet(4) + + eq = sqrt(2*x**2 - 7) - (3 - x) + assert solveset_real(eq, x) == FiniteSet(-S(8), S(2)) + + # others + eq = sqrt(9*x**2 + 4) - (3*x + 2) + assert solveset_real(eq, x) == FiniteSet(0) + + assert solveset_real(sqrt(x - 3) - sqrt(x) - 3, x) == FiniteSet() + + eq = (2*x - 5)**Rational(1, 3) - 3 + assert solveset_real(eq, x) == FiniteSet(16) + + assert solveset_real(sqrt(x) + sqrt(sqrt(x)) - 4, x) == \ + FiniteSet((Rational(-1, 2) + sqrt(17)/2)**4) + + eq = sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x)) + assert solveset_real(eq, x) == FiniteSet() + + eq = (x - 4)**2 + (sqrt(x) - 2)**4 + assert solveset_real(eq, x) == FiniteSet(-4, 4) + + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + ans = solveset_real(eq, x) + ra = S('''-1484/375 - 4*(-S(1)/2 + sqrt(3)*I/2)*(-12459439/52734375 + + 114*sqrt(12657)/78125)**(S(1)/3) - 172564/(140625*(-S(1)/2 + + sqrt(3)*I/2)*(-12459439/52734375 + 114*sqrt(12657)/78125)**(S(1)/3))''') + rb = Rational(4, 5) + assert all(abs(eq.subs(x, i).n()) < 1e-10 for i in (ra, rb)) and \ + len(ans) == 2 and \ + {i.n(chop=True) for i in ans} == \ + {i.n(chop=True) for i in (ra, rb)} + + assert solveset_real(sqrt(x) + x**Rational(1, 3) + + x**Rational(1, 4), x) == FiniteSet(0) + + assert solveset_real(x/sqrt(x**2 + 1), x) == FiniteSet(0) + + eq = (x - y**3)/((y**2)*sqrt(1 - y**2)) + assert solveset_real(eq, x) == FiniteSet(y**3) + + # issue 4497 + assert solveset_real(1/(5 + x)**Rational(1, 5) - 9, x) == \ + FiniteSet(Rational(-295244, 59049)) + + +@XFAIL +def test_solve_sqrt_fail(): + # this only works if we check real_root(eq.subs(x, Rational(1, 3))) + # but checksol doesn't work like that + eq = (x**3 - 3*x**2)**Rational(1, 3) + 1 - x + assert solveset_real(eq, x) == FiniteSet(Rational(1, 3)) + + +@slow +def test_solve_sqrt_3(): + R = Symbol('R') + eq = sqrt(2)*R*sqrt(1/(R + 1)) + (R + 1)*(sqrt(2)*sqrt(1/(R + 1)) - 1) + sol = solveset_complex(eq, R) + fset = [Rational(5, 3) + 4*sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3, + -sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3 + + 40*re(1/((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 + + sqrt(30)*sin(atan(3*sqrt(111)/251)/3)/3 + Rational(5, 3) + + I*(-sqrt(30)*cos(atan(3*sqrt(111)/251)/3)/3 - + sqrt(10)*sin(atan(3*sqrt(111)/251)/3)/3 + + 40*im(1/((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9)] + cset = [40*re(1/((Rational(-1, 2) + sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 - + sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3 - sqrt(30)*sin(atan(3*sqrt(111)/251)/3)/3 + + Rational(5, 3) + + I*(40*im(1/((Rational(-1, 2) + sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 - + sqrt(10)*sin(atan(3*sqrt(111)/251)/3)/3 + + sqrt(30)*cos(atan(3*sqrt(111)/251)/3)/3)] + + fs = FiniteSet(*fset) + cs = ConditionSet(R, Eq(eq, 0), FiniteSet(*cset)) + assert sol == (fs - {-1}) | (cs - {-1}) + + # the number of real roots will depend on the value of m: for m=1 there are 4 + # and for m=-1 there are none. + eq = -sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) + sqrt((-m**2/2 - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2) + unsolved_object = ConditionSet(q, Eq(sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) - + sqrt((-m**2/2 - sqrt(4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - + sqrt(4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2), 0), S.Reals) + assert solveset_real(eq, q) == unsolved_object + + +def test_solve_polynomial_symbolic_param(): + assert solveset_complex((x**2 - 1)**2 - a, x) == \ + FiniteSet(sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), + sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))) + + # issue 4507 + assert solveset_complex(y - b/(1 + a*x), x) == \ + FiniteSet((b/y - 1)/a) - FiniteSet(-1/a) + + # issue 4508 + assert solveset_complex(y - b*x/(a + x), x) == \ + FiniteSet(-a*y/(y - b)) - FiniteSet(-a) + + +def test_solve_rational(): + assert solveset_real(1/x + 1, x) == FiniteSet(-S.One) + assert solveset_real(1/exp(x) - 1, x) == FiniteSet(0) + assert solveset_real(x*(1 - 5/x), x) == FiniteSet(5) + assert solveset_real(2*x/(x + 2) - 1, x) == FiniteSet(2) + assert solveset_real((x**2/(7 - x)).diff(x), x) == \ + FiniteSet(S.Zero, S(14)) + + +def test_solveset_real_gen_is_pow(): + assert solveset_real(sqrt(1) + 1, x) is S.EmptySet + + +def test_no_sol(): + assert solveset(1 - oo*x) is S.EmptySet + assert solveset(oo*x, x) is S.EmptySet + assert solveset(oo*x - oo, x) is S.EmptySet + assert solveset_real(4, x) is S.EmptySet + assert solveset_real(exp(x), x) is S.EmptySet + assert solveset_real(x**2 + 1, x) is S.EmptySet + assert solveset_real(-3*a/sqrt(x), x) is S.EmptySet + assert solveset_real(1/x, x) is S.EmptySet + assert solveset_real(-(1 + x)/(2 + x)**2 + 1/(2 + x), x + ) is S.EmptySet + + +def test_sol_zero_real(): + assert solveset_real(0, x) == S.Reals + assert solveset(0, x, Interval(1, 2)) == Interval(1, 2) + assert solveset_real(-x**2 - 2*x + (x + 1)**2 - 1, x) == S.Reals + + +def test_no_sol_rational_extragenous(): + assert solveset_real((x/(x + 1) + 3)**(-2), x) is S.EmptySet + assert solveset_real((x - 1)/(1 + 1/(x - 1)), x) is S.EmptySet + + +def test_solve_polynomial_cv_1a(): + """ + Test for solving on equations that can be converted to + a polynomial equation using the change of variable y -> x**Rational(p, q) + """ + assert solveset_real(sqrt(x) - 1, x) == FiniteSet(1) + assert solveset_real(sqrt(x) - 2, x) == FiniteSet(4) + assert solveset_real(x**Rational(1, 4) - 2, x) == FiniteSet(16) + assert solveset_real(x**Rational(1, 3) - 3, x) == FiniteSet(27) + assert solveset_real(x*(x**(S.One / 3) - 3), x) == \ + FiniteSet(S.Zero, S(27)) + + +def test_solveset_real_rational(): + """Test solveset_real for rational functions""" + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset_real((x - y**3) / ((y**2)*sqrt(1 - y**2)), x) \ + == FiniteSet(y**3) + # issue 4486 + assert solveset_real(2*x/(x + 2) - 1, x) == FiniteSet(2) + + +def test_solveset_real_log(): + assert solveset_real(log((x-1)*(x+1)), x) == \ + FiniteSet(sqrt(2), -sqrt(2)) + + +def test_poly_gens(): + assert solveset_real(4**(2*(x**2) + 2*x) - 8, x) == \ + FiniteSet(Rational(-3, 2), S.Half) + + +def test_solve_abs(): + n = Dummy('n') + raises(ValueError, lambda: solveset(Abs(x) - 1, x)) + assert solveset(Abs(x) - n, x, S.Reals).dummy_eq( + ConditionSet(x, Contains(n, Interval(0, oo)), {-n, n})) + assert solveset_real(Abs(x) - 2, x) == FiniteSet(-2, 2) + assert solveset_real(Abs(x) + 2, x) is S.EmptySet + assert solveset_real(Abs(x + 3) - 2*Abs(x - 3), x) == \ + FiniteSet(1, 9) + assert solveset_real(2*Abs(x) - Abs(x - 1), x) == \ + FiniteSet(-1, Rational(1, 3)) + + sol = ConditionSet( + x, + And( + Contains(b, Interval(0, oo)), + Contains(a + b, Interval(0, oo)), + Contains(a - b, Interval(0, oo))), + FiniteSet(-a - b - 3, -a + b - 3, a - b - 3, a + b - 3)) + eq = Abs(Abs(x + 3) - a) - b + assert invert_real(eq, 0, x)[1] == sol + reps = {a: 3, b: 1} + eqab = eq.subs(reps) + for si in sol.subs(reps): + assert not eqab.subs(x, si) + assert dumeq(solveset(Eq(sin(Abs(x)), 1), x, domain=S.Reals), Union( + Intersection(Interval(0, oo), Union( + Intersection(ImageSet(Lambda(n, 2*n*pi + 3*pi/2), S.Integers), + Interval(-oo, 0)), + Intersection(ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers), + Interval(0, oo)))))) + + +def test_issue_9824(): + assert dumeq(solveset(sin(x)**2 - 2*sin(x) + 1, x), ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers)) + assert dumeq(solveset(cos(x)**2 - 2*cos(x) + 1, x), ImageSet(Lambda(n, 2*n*pi), S.Integers)) + + +def test_issue_9565(): + assert solveset_real(Abs((x - 1)/(x - 5)) <= Rational(1, 3), x) == Interval(-1, 2) + + +def test_issue_10069(): + eq = abs(1/(x - 1)) - 1 > 0 + assert solveset_real(eq, x) == Union( + Interval.open(0, 1), Interval.open(1, 2)) + + +def test_real_imag_splitting(): + a, b = symbols('a b', real=True) + assert solveset_real(sqrt(a**2 - b**2) - 3, a) == \ + FiniteSet(-sqrt(b**2 + 9), sqrt(b**2 + 9)) + assert solveset_real(sqrt(a**2 + b**2) - 3, a) != \ + S.EmptySet + + +def test_units(): + assert solveset_real(1/x - 1/(2*cm), x) == FiniteSet(2*cm) + + +def test_solve_only_exp_1(): + y = Symbol('y', positive=True) + assert solveset_real(exp(x) - y, x) == FiniteSet(log(y)) + assert solveset_real(exp(x) + exp(-x) - 4, x) == \ + FiniteSet(log(-sqrt(3) + 2), log(sqrt(3) + 2)) + assert solveset_real(exp(x) + exp(-x) - y, x) != S.EmptySet + + +def test_atan2(): + # The .inverse() method on atan2 works only if x.is_real is True and the + # second argument is a real constant + assert solveset_real(atan2(x, 2) - pi/3, x) == FiniteSet(2*sqrt(3)) + + +def test_piecewise_solveset(): + eq = Piecewise((x - 2, Gt(x, 2)), (2 - x, True)) - 3 + assert set(solveset_real(eq, x)) == set(FiniteSet(-1, 5)) + + absxm3 = Piecewise( + (x - 3, 0 <= x - 3), + (3 - x, 0 > x - 3)) + y = Symbol('y', positive=True) + assert solveset_real(absxm3 - y, x) == FiniteSet(-y + 3, y + 3) + + f = Piecewise(((x - 2)**2, x >= 0), (0, True)) + assert solveset(f, x, domain=S.Reals) == Union(FiniteSet(2), Interval(-oo, 0, True, True)) + + assert solveset( + Piecewise((x + 1, x > 0), (I, True)) - I, x, S.Reals + ) == Interval(-oo, 0) + + assert solveset(Piecewise((x - 1, Ne(x, I)), (x, True)), x) == FiniteSet(1) + + # issue 19718 + g = Piecewise((1, x > 10), (0, True)) + assert solveset(g > 0, x, S.Reals) == Interval.open(10, oo) + + from sympy.logic.boolalg import BooleanTrue + f = BooleanTrue() + assert solveset(f, x, domain=Interval(-3, 10)) == Interval(-3, 10) + + # issue 20552 + f = Piecewise((0, Eq(x, 0)), (x**2/Abs(x), True)) + g = Piecewise((0, Eq(x, pi)), ((x - pi)/sin(x), True)) + assert solveset(f, x, domain=S.Reals) == FiniteSet(0) + assert solveset(g) == FiniteSet(pi) + + +def test_solveset_complex_polynomial(): + assert solveset_complex(a*x**2 + b*x + c, x) == \ + FiniteSet(-b/(2*a) - sqrt(-4*a*c + b**2)/(2*a), + -b/(2*a) + sqrt(-4*a*c + b**2)/(2*a)) + + assert solveset_complex(x - y**3, y) == FiniteSet( + (-x**Rational(1, 3))/2 + I*sqrt(3)*x**Rational(1, 3)/2, + x**Rational(1, 3), + (-x**Rational(1, 3))/2 - I*sqrt(3)*x**Rational(1, 3)/2) + + assert solveset_complex(x + 1/x - 1, x) == \ + FiniteSet(S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2) + + +def test_sol_zero_complex(): + assert solveset_complex(0, x) is S.Complexes + + +def test_solveset_complex_rational(): + assert solveset_complex((x - 1)*(x - I)/(x - 3), x) == \ + FiniteSet(1, I) + + assert solveset_complex((x - y**3)/((y**2)*sqrt(1 - y**2)), x) == \ + FiniteSet(y**3) + assert solveset_complex(-x**2 - I, x) == \ + FiniteSet(-sqrt(2)/2 + sqrt(2)*I/2, sqrt(2)/2 - sqrt(2)*I/2) + + +def test_solve_quintics(): + skip("This test is too slow") + f = x**5 - 110*x**3 - 55*x**2 + 2310*x + 979 + s = solveset_complex(f, x) + for root in s: + res = f.subs(x, root.n()).n() + assert tn(res, 0) + + f = x**5 + 15*x + 12 + s = solveset_complex(f, x) + for root in s: + res = f.subs(x, root.n()).n() + assert tn(res, 0) + + +def test_solveset_complex_exp(): + assert dumeq(solveset_complex(exp(x) - 1, x), + imageset(Lambda(n, I*2*n*pi), S.Integers)) + assert dumeq(solveset_complex(exp(x) - I, x), + imageset(Lambda(n, I*(2*n*pi + pi/2)), S.Integers)) + assert solveset_complex(1/exp(x), x) == S.EmptySet + assert dumeq(solveset_complex(sinh(x).rewrite(exp), x), + imageset(Lambda(n, n*pi*I), S.Integers)) + + +def test_solveset_real_exp(): + assert solveset(Eq((-2)**x, 4), x, S.Reals) == FiniteSet(2) + assert solveset(Eq(-2**x, 4), x, S.Reals) == S.EmptySet + assert solveset(Eq((-3)**x, 27), x, S.Reals) == S.EmptySet + assert solveset(Eq((-5)**(x+1), 625), x, S.Reals) == FiniteSet(3) + assert solveset(Eq(2**(x-3), -16), x, S.Reals) == S.EmptySet + assert solveset(Eq((-3)**(x - 3), -3**39), x, S.Reals) == FiniteSet(42) + assert solveset(Eq(2**x, y), x, S.Reals) == Intersection(S.Reals, FiniteSet(log(y)/log(2))) + + assert invert_real((-2)**(2*x) - 16, 0, x) == (x, FiniteSet(2)) + + +def test_solve_complex_log(): + assert solveset_complex(log(x), x) == FiniteSet(1) + assert solveset_complex(1 - log(a + 4*x**2), x) == \ + FiniteSet(-sqrt(-a + E)/2, sqrt(-a + E)/2) + + +def test_solve_complex_sqrt(): + assert solveset_complex(sqrt(5*x + 6) - 2 - x, x) == \ + FiniteSet(-S.One, S(2)) + assert solveset_complex(sqrt(5*x + 6) - (2 + 2*I) - x, x) == \ + FiniteSet(-S(2), 3 - 4*I) + assert solveset_complex(4*x*(1 - a * sqrt(x)), x) == \ + FiniteSet(S.Zero, 1 / a ** 2) + + +def test_solveset_complex_tan(): + s = solveset_complex(tan(x).rewrite(exp), x) + assert dumeq(s, imageset(Lambda(n, pi*n), S.Integers) - \ + imageset(Lambda(n, pi*n + pi/2), S.Integers)) + + +@_both_exp_pow +def test_solve_trig(): + assert dumeq(solveset_real(sin(x), x), + Union(imageset(Lambda(n, 2*pi*n), S.Integers), + imageset(Lambda(n, 2*pi*n + pi), S.Integers))) + + assert dumeq(solveset_real(sin(x) - 1, x), + imageset(Lambda(n, 2*pi*n + pi/2), S.Integers)) + + assert dumeq(solveset_real(cos(x), x), + Union(imageset(Lambda(n, 2*pi*n + pi/2), S.Integers), + imageset(Lambda(n, 2*pi*n + pi*Rational(3, 2)), S.Integers))) + + assert dumeq(solveset_real(sin(x) + cos(x), x), + Union(imageset(Lambda(n, 2*n*pi + pi*Rational(3, 4)), S.Integers), + imageset(Lambda(n, 2*n*pi + pi*Rational(7, 4)), S.Integers))) + + assert solveset_real(sin(x)**2 + cos(x)**2, x) == S.EmptySet + + assert dumeq(solveset_complex(cos(x) - S.Half, x), + Union(imageset(Lambda(n, 2*n*pi + pi*Rational(5, 3)), S.Integers), + imageset(Lambda(n, 2*n*pi + pi/3), S.Integers))) + + assert dumeq(solveset(sin(y + a) - sin(y), a, domain=S.Reals), + ConditionSet(a, (S(-1) <= sin(y)) & (sin(y) <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi - y + asin(sin(y))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - y - asin(sin(y)) + pi), S.Integers)))) + + assert dumeq(solveset_real(sin(2*x)*cos(x) + cos(2*x)*sin(x)-1, x), + ImageSet(Lambda(n, n*pi*Rational(2, 3) + pi/6), S.Integers)) + + assert dumeq(solveset_real(2*tan(x)*sin(x) + 1, x), Union( + ImageSet(Lambda(n, 2*n*pi + atan(sqrt(2)*sqrt(-1 + sqrt(17))/ + (1 - sqrt(17))) + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi - atan(sqrt(2)*sqrt(-1 + sqrt(17))/ + (1 - sqrt(17))) + pi), S.Integers))) + + assert dumeq(solveset_real(cos(2*x)*cos(4*x) - 1, x), + ImageSet(Lambda(n, n*pi), S.Integers)) + + assert dumeq(solveset(sin(x/10) + Rational(3, 4)), Union( + ImageSet(Lambda(n, 20*n*pi - 10*asin(S(3)/4) + 20*pi), S.Integers), + ImageSet(Lambda(n, 20*n*pi + 10*asin(S(3)/4) + 10*pi), S.Integers))) + + assert dumeq(solveset(cos(x/15) + cos(x/5)), Union( + ImageSet(Lambda(n, 30*n*pi + 15*pi/2), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 45*pi/2), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 75*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 45*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 105*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 15*pi/4), S.Integers))) + + assert dumeq(solveset(sec(sqrt(2)*x/3) + 5), Union( + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi - asec(-5))/2), S.Integers), + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi + asec(-5))/2), S.Integers))) + + assert dumeq(simplify(solveset(tan(pi*x) - cot(pi/2*x))), Union( + ImageSet(Lambda(n, 4*n + 1), S.Integers), + ImageSet(Lambda(n, 4*n + 3), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(7, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(5, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(11, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(1, 3)), S.Integers))) + + assert dumeq(solveset(cos(9*x)), Union( + ImageSet(Lambda(n, 2*n*pi/9 + pi/18), S.Integers), + ImageSet(Lambda(n, 2*n*pi/9 + pi/6), S.Integers))) + + assert dumeq(solveset(sin(8*x) + cot(12*x), x, S.Reals), Union( + ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 3*pi/8), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 5*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 3*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 7*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + pi/16), S.Integers))) + + # This is the only remaining solveset test that actually ends up being solved + # by _solve_trig2(). All others are handled by the improved _solve_trig1. + assert dumeq(solveset_real(2*cos(x)*cos(2*x) - 1, x), + Union(ImageSet(Lambda(n, 2*n*pi + 2*atan(sqrt(-2*2**Rational(1, 3)*(67 + + 9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 + + 9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6)))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - 2*atan(sqrt(-2*2**Rational(1, 3)*(67 + + 9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 + + 9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6))) + + 2*pi), S.Integers))) + + # issue #16870 + assert dumeq(simplify(solveset(sin(x/180*pi) - S.Half, x, S.Reals)), Union( + ImageSet(Lambda(n, 360*n + 150), S.Integers), + ImageSet(Lambda(n, 360*n + 30), S.Integers))) + + +def test_solve_trig_hyp_by_inversion(): + n = Dummy('n') + assert solveset_real(sin(2*x + 3) - S(1)/2, x).dummy_eq(Union( + ImageSet(Lambda(n, n*pi - S(3)/2 + 13*pi/12), S.Integers), + ImageSet(Lambda(n, n*pi - S(3)/2 + 17*pi/12), S.Integers))) + assert solveset_complex(sin(2*x + 3) - S(1)/2, x).dummy_eq(Union( + ImageSet(Lambda(n, n*pi - S(3)/2 + 13*pi/12), S.Integers), + ImageSet(Lambda(n, n*pi - S(3)/2 + 17*pi/12), S.Integers))) + assert solveset_real(tan(x) - tan(pi/10), x).dummy_eq( + ImageSet(Lambda(n, n*pi + pi/10), S.Integers)) + assert solveset_complex(tan(x) - tan(pi/10), x).dummy_eq( + ImageSet(Lambda(n, n*pi + pi/10), S.Integers)) + + assert solveset_real(3*cosh(2*x) - 5, x) == FiniteSet( + -acosh(S(5)/3)/2, acosh(S(5)/3)/2) + assert solveset_complex(3*cosh(2*x) - 5, x).dummy_eq(Union( + ImageSet(Lambda(n, n*I*pi - acosh(S(5)/3)/2), S.Integers), + ImageSet(Lambda(n, n*I*pi + acosh(S(5)/3)/2), S.Integers))) + assert solveset_real(sinh(x - 3) - 2, x) == FiniteSet( + asinh(2) + 3) + assert solveset_complex(sinh(x - 3) - 2, x).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi + asinh(2) + 3), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi - asinh(2) + 3 + I*pi), S.Integers))) + + assert solveset_real(cos(sinh(x))-cos(pi/12), x).dummy_eq(Union( + ImageSet(Lambda(n, asinh(2*n*pi + pi/12)), S.Integers), + ImageSet(Lambda(n, asinh(2*n*pi + 23*pi/12)), S.Integers))) + assert solveset(cos(sinh(x))-cos(pi/12), x, Interval(2,3)) == \ + FiniteSet(asinh(23*pi/12), asinh(25*pi/12)) + assert solveset_real(cosh(x**2-1)-2, x) == FiniteSet( + -sqrt(1 + acosh(2)), sqrt(1 + acosh(2))) + + assert solveset_real(sin(x) - 2, x) == S.EmptySet # issue #17334 + assert solveset_real(cos(x) + 2, x) == S.EmptySet + assert solveset_real(sec(x), x) == S.EmptySet + assert solveset_real(csc(x), x) == S.EmptySet + assert solveset_real(cosh(x) + 1, x) == S.EmptySet + assert solveset_real(coth(x), x) == S.EmptySet + assert solveset_real(sech(x) - 2, x) == S.EmptySet + assert solveset_real(sech(x), x) == S.EmptySet + assert solveset_real(tanh(x) + 1, x) == S.EmptySet + assert solveset_complex(tanh(x), 1) == S.EmptySet + assert solveset_complex(coth(x), -1) == S.EmptySet + assert solveset_complex(sech(x), 0) == S.EmptySet + assert solveset_complex(csch(x), 0) == S.EmptySet + + assert solveset_real(abs(csch(x)) - 3, x) == FiniteSet(-acsch(3), acsch(3)) + + assert solveset_real(tanh(x**2 - 1) - exp(-9), x) == FiniteSet( + -sqrt(atanh(exp(-9)) + 1), sqrt(atanh(exp(-9)) + 1)) + + assert solveset_real(coth(log(x)) + 2, x) == FiniteSet(exp(-acoth(2))) + assert solveset_real(coth(exp(x)) + 2, x) == S.EmptySet + + assert solveset_complex(sinh(x) - I/2, x).dummy_eq(Union( + ImageSet(Lambda(n, 2*I*pi*n + 5*I*pi/6), S.Integers), + ImageSet(Lambda(n, 2*I*pi*n + I*pi/6), S.Integers))) + assert solveset_complex(sinh(x/10) + Rational(3, 4), x).dummy_eq(Union( + ImageSet(Lambda(n, 20*n*I*pi - 10*asinh(S(3)/4)), S.Integers), + ImageSet(Lambda(n, 20*n*I*pi + 10*asinh(S(3)/4) + 10*I*pi), S.Integers))) + assert solveset_complex(sech(sqrt(2)*x/3) + 5, x).dummy_eq(Union( + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*I*pi - asech(-5))/2), S.Integers), + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*I*pi + asech(-5))/2), S.Integers))) + assert solveset_complex(cosh(9*x), x).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi/9 + I*pi/18), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi/9 + I*pi/6), S.Integers))) + + eq = (x**5 -4*x + 1).subs(x, coth(z)) + assert solveset(eq, z, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 0))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 1))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 2))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 3))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 4))), S.Integers))) + assert solveset(eq, z, S.Reals) == FiniteSet( + acoth(CRootOf(x**5 - 4*x + 1, 0)), acoth(CRootOf(x**5 - 4*x + 1, 2))) + + eq = ((x-sqrt(3)/2)*(x+2)).expand().subs(x, cos(x)) + assert solveset(eq, x, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi - acos(-2)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(-2)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + 11*pi/6), S.Integers))) + assert solveset(eq, x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + 11*pi/6), S.Integers))) + + assert solveset((1+sec(sqrt(3)*x+4)**2)/(1-sec(sqrt(3)*x+4))).dummy_eq(Union( + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 - asec(I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 + asec(I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 - asec(-I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 + asec(-I))/3), S.Integers))) + + assert all_close(solveset(tan(3.14*x)**(S(3)/2)-5.678, x, Interval(0, 3)), + FiniteSet(0.403301114561067, 0.403301114561067 + 0.318471337579618*pi, + 0.403301114561067 + 0.636942675159236*pi)) + + +def test_old_trig_issues(): + # issues #9606 / #9531: + assert solveset(sinh(x), x, S.Reals) == FiniteSet(0) + assert solveset(sinh(x), x, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + I*pi), S.Integers))) + + # issues #11218 / #18427 + assert solveset(sin(pi*x), x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers), + ImageSet(Lambda(n, 2*n), S.Integers))) + assert solveset(sin(pi*x), x).dummy_eq(Union( + ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers), + ImageSet(Lambda(n, 2*n), S.Integers))) + + # issue #17543 + assert solveset(I*cot(8*x - 8*E), x).dummy_eq( + ImageSet(Lambda(n, pi*n/8 - 13*pi/16 + E), S.Integers)) + + # issue #20798 + assert all_close(solveset(cos(2*x) - 0.5, x, Interval(0, 2*pi)), FiniteSet( + 0.523598775598299, -0.523598775598299 + pi, + -0.523598775598299 + 2*pi, 0.523598775598299 + pi)) + sol = Union(ImageSet(Lambda(n, n*pi - 0.523598775598299), S.Integers), + ImageSet(Lambda(n, n*pi + 0.523598775598299), S.Integers)) + ret = solveset(cos(2*x) - 0.5, x, S.Reals) + # replace Dummy n by the regular Symbol n to allow all_close comparison. + ret = ret.subs(ret.atoms(Dummy).pop(), n) + assert all_close(ret, sol) + ret = solveset(cos(2*x) - 0.5, x, S.Complexes) + ret = ret.subs(ret.atoms(Dummy).pop(), n) + assert all_close(ret, sol) + + # issue #21296 / #17667 + assert solveset(tan(x)-sqrt(2), x, Interval(0, pi/2)) == FiniteSet(atan(sqrt(2))) + assert solveset(tan(x)-pi, x, Interval(0, pi/2)) == FiniteSet(atan(pi)) + + # issue #17667 + # not yet working properly: + # solveset(cos(x)-y, x, Interval(0, pi)) + assert solveset(cos(x)-y, x, S.Reals).dummy_eq( + ConditionSet(x,(S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi - acos(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(y)), S.Integers)))) + + # issue #17579 + # Valid result, but the intersection could potentially be simplified. + assert solveset(sin(log(x)), x, Interval(0,1, True, False)).dummy_eq( + Union(Intersection(ImageSet(Lambda(n, exp(2*n*pi)), S.Integers), Interval.Lopen(0, 1)), + Intersection(ImageSet(Lambda(n, exp(2*n*pi + pi)), S.Integers), Interval.Lopen(0, 1)))) + + # issue #17334 + assert solveset(sin(x) - sin(1), x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + 1), S.Integers), + ImageSet(Lambda(n, 2*n*pi - 1 + pi), S.Integers))) + assert solveset(sin(x) - sqrt(5)/3, x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + asin(sqrt(5)/3)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - asin(sqrt(5)/3) + pi), S.Integers))) + assert solveset(sinh(x)-cosh(2), x, S.Reals) == FiniteSet(asinh(cosh(2))) + + # issue 9825 + assert solveset(Eq(tan(x), y), x, domain=S.Reals).dummy_eq( + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + atan(y)), S.Integers))) + r = Symbol('r', real=True) + assert solveset(Eq(tan(x), r), x, domain=S.Reals).dummy_eq( + ImageSet(Lambda(n, n*pi + atan(r)), S.Integers)) + + +def test_solve_hyperbolic(): + # actual solver: _solve_trig1 + n = Dummy('n') + assert solveset(sinh(x) + cosh(x), x) == S.EmptySet + assert solveset(sinh(x) + cos(x), x) == ConditionSet(x, + Eq(cos(x) + sinh(x), 0), S.Complexes) + assert solveset_real(sinh(x) + sech(x), x) == FiniteSet( + log(sqrt(sqrt(5) - 2))) + assert solveset_real(cosh(2*x) + 2*sinh(x) - 5, x) == FiniteSet( + log(-2 + sqrt(5)), log(1 + sqrt(2))) + assert solveset_real((coth(x) + sinh(2*x))/cosh(x) - 3, x) == FiniteSet( + log(S.Half + sqrt(5)/2), log(1 + sqrt(2))) + assert solveset_real(cosh(x)*sinh(x) - 2, x) == FiniteSet( + log(4 + sqrt(17))/2) + assert solveset_real(sinh(x) + tanh(x) - 1, x) == FiniteSet( + log(sqrt(2)/2 + sqrt(-S(1)/2 + sqrt(2)))) + + assert dumeq(solveset_complex(sinh(x) + sech(x), x), Union( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(-2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sqrt(-2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers))) + + assert dumeq(solveset(cosh(x/15) + cosh(x/5)), Union( + ImageSet(Lambda(n, 15*I*(2*n*pi + pi/2)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - pi/2)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - 3*pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi + 3*pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi + pi/4)), S.Integers))) + + assert dumeq(solveset(tanh(pi*x) - coth(pi/2*x)), Union( + ImageSet(Lambda(n, 2*I*(2*n*pi + pi/2)/pi), S.Integers), + ImageSet(Lambda(n, 2*I*(2*n*pi - pi/2)/pi), S.Integers))) + + # issues #18490 / #19489 + assert solveset(cosh(x) + cosh(3*x) - cosh(5*x), x, S.Reals + ).dummy_eq(ConditionSet(x, + Eq(cosh(x) + cosh(3*x) - cosh(5*x), 0), S.Reals)) + assert solveset(sinh(8*x) + coth(12*x)).dummy_eq( + ConditionSet(x, Eq(sinh(8*x) + coth(12*x), 0), S.Complexes)) + + +def test_solve_trig_hyp_symbolic(): + # actual solver: invert_trig_hyp + assert dumeq(solveset(sin(a*x), x), ConditionSet(x, Ne(a, 0), Union( + ImageSet(Lambda(n, (2*n*pi + pi)/a), S.Integers), + ImageSet(Lambda(n, 2*n*pi/a), S.Integers)))) + + assert dumeq(solveset(cosh(x/a), x), ConditionSet(x, Ne(a, 0), Union( + ImageSet(Lambda(n, a*(2*n*I*pi + I*pi/2)), S.Integers), + ImageSet(Lambda(n, a*(2*n*I*pi + 3*I*pi/2)), S.Integers)))) + + assert dumeq(solveset(sin(2*sqrt(3)/3*a**2/(b*pi)*x) + + cos(4*sqrt(3)/3*a**2/(b*pi)*x), x), + ConditionSet(x, Ne(b, 0) & Ne(a**2, 0), Union( + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi + pi/2)/(2*a**2)), S.Integers), + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - 5*pi/6)/(2*a**2)), S.Integers), + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - pi/6)/(2*a**2)), S.Integers)))) + + assert dumeq(solveset(cosh((a**2 + 1)*x) - 3, x), ConditionSet( + x, Ne(a**2 + 1, 0), Union( + ImageSet(Lambda(n, (2*n*I*pi - acosh(3))/(a**2 + 1)), S.Integers), + ImageSet(Lambda(n, (2*n*I*pi + acosh(3))/(a**2 + 1)), S.Integers)))) + + ar = Symbol('ar', real=True) + assert solveset(cosh((ar**2 + 1)*x) - 2, x, S.Reals) == FiniteSet( + -acosh(2)/(ar**2 + 1), acosh(2)/(ar**2 + 1)) + + # actual solver: _solve_trig1 + assert dumeq(simplify(solveset(cot((1 + I)*x) - cot((3 + 3*I)*x), x)), Union( + ImageSet(Lambda(n, pi*(1 - I)*(4*n + 1)/4), S.Integers), + ImageSet(Lambda(n, pi*(1 - I)*(4*n - 1)/4), S.Integers))) + + +def test_issue_9616(): + assert dumeq(solveset(sinh(x) + tanh(x) - 1, x), Union( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2)))) + + log(sqrt(1 + sqrt(2)))), S.Integers))) + f1 = (sinh(x)).rewrite(exp) + f2 = (tanh(x)).rewrite(exp) + assert dumeq(solveset(f1 + f2 - 1, x), Union( + Complement(ImageSet( + Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement(ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2)))) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement(ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)))) + + +def test_solve_invalid_sol(): + assert 0 not in solveset_real(sin(x)/x, x) + assert 0 not in solveset_complex((exp(x) - 1)/x, x) + + +@XFAIL +def test_solve_trig_simplified(): + n = Dummy('n') + assert dumeq(solveset_real(sin(x), x), + imageset(Lambda(n, n*pi), S.Integers)) + + assert dumeq(solveset_real(cos(x), x), + imageset(Lambda(n, n*pi + pi/2), S.Integers)) + + assert dumeq(solveset_real(cos(x) + sin(x), x), + imageset(Lambda(n, n*pi - pi/4), S.Integers)) + + +@XFAIL +def test_solve_lambert(): + assert solveset_real(x*exp(x) - 1, x) == FiniteSet(LambertW(1)) + assert solveset_real(exp(x) + x, x) == FiniteSet(-LambertW(1)) + assert solveset_real(x + 2**x, x) == \ + FiniteSet(-LambertW(log(2))/log(2)) + + # issue 4739 + ans = solveset_real(3*x + 5 + 2**(-5*x + 3), x) + assert ans == FiniteSet(Rational(-5, 3) + + LambertW(-10240*2**Rational(1, 3)*log(2)/3)/(5*log(2))) + + eq = 2*(3*x + 4)**5 - 6*7**(3*x + 9) + result = solveset_real(eq, x) + ans = FiniteSet((log(2401) + + 5*LambertW(-log(7**(7*3**Rational(1, 5)/5))))/(3*log(7))/-1) + assert result == ans + assert solveset_real(eq.expand(), x) == result + + assert solveset_real(5*x - 1 + 3*exp(2 - 7*x), x) == \ + FiniteSet(Rational(1, 5) + LambertW(-21*exp(Rational(3, 5))/5)/7) + + assert solveset_real(2*x + 5 + log(3*x - 2), x) == \ + FiniteSet(Rational(2, 3) + LambertW(2*exp(Rational(-19, 3))/3)/2) + + assert solveset_real(3*x + log(4*x), x) == \ + FiniteSet(LambertW(Rational(3, 4))/3) + + assert solveset_real(x**x - 2) == FiniteSet(exp(LambertW(log(2)))) + + a = Symbol('a') + assert solveset_real(-a*x + 2*x*log(x), x) == FiniteSet(exp(a/2)) + a = Symbol('a', real=True) + assert solveset_real(a/x + exp(x/2), x) == \ + FiniteSet(2*LambertW(-a/2)) + assert solveset_real((a/x + exp(x/2)).diff(x), x) == \ + FiniteSet(4*LambertW(sqrt(2)*sqrt(a)/4)) + + # coverage test + assert solveset_real(tanh(x + 3)*tanh(x - 3) - 1, x) is S.EmptySet + + assert solveset_real((x**2 - 2*x + 1).subs(x, log(x) + 3*x), x) == \ + FiniteSet(LambertW(3*S.Exp1)/3) + assert solveset_real((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) == \ + FiniteSet(LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3) + assert solveset_real((x**2 - 2*x - 2).subs(x, log(x) + 3*x), x) == \ + FiniteSet(LambertW(3*exp(1 + sqrt(3)))/3, LambertW(3*exp(-sqrt(3) + 1))/3) + assert solveset_real(x*log(x) + 3*x + 1, x) == \ + FiniteSet(exp(-3 + LambertW(-exp(3)))) + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solveset_real(eq, x) == \ + FiniteSet(LambertW(3*exp(-LambertW(3)))) + + assert solveset_real(3*log(a**(3*x + 5)) + a**(3*x + 5), x) == \ + FiniteSet(-((log(a**5) + LambertW(Rational(1, 3)))/(3*log(a)))) + p = symbols('p', positive=True) + assert solveset_real(3*log(p**(3*x + 5)) + p**(3*x + 5), x) == \ + FiniteSet( + log((-3**Rational(1, 3) - 3**Rational(5, 6)*I)*LambertW(Rational(1, 3))**Rational(1, 3)/(2*p**Rational(5, 3)))/log(p), + log((-3**Rational(1, 3) + 3**Rational(5, 6)*I)*LambertW(Rational(1, 3))**Rational(1, 3)/(2*p**Rational(5, 3)))/log(p), + log((3*LambertW(Rational(1, 3))/p**5)**(1/(3*log(p)))),) # checked numerically + # check collection + b = Symbol('b') + eq = 3*log(a**(3*x + 5)) + b*log(a**(3*x + 5)) + a**(3*x + 5) + assert solveset_real(eq, x) == FiniteSet( + -((log(a**5) + LambertW(1/(b + 3)))/(3*log(a)))) + + # issue 4271 + assert solveset_real((a/x + exp(x/2)).diff(x, 2), x) == FiniteSet( + 6*LambertW((-1)**Rational(1, 3)*a**Rational(1, 3)/3)) + + assert solveset_real(x**3 - 3**x, x) == \ + FiniteSet(-3/log(3)*LambertW(-log(3)/3)) + assert solveset_real(3**cos(x) - cos(x)**3) == FiniteSet( + acos(-3*LambertW(-log(3)/3)/log(3))) + + assert solveset_real(x**2 - 2**x, x) == \ + solveset_real(-x**2 + 2**x, x) + + assert solveset_real(3*log(x) - x*log(3)) == FiniteSet( + -3*LambertW(-log(3)/3)/log(3), + -3*LambertW(-log(3)/3, -1)/log(3)) + + assert solveset_real(LambertW(2*x) - y) == FiniteSet( + y*exp(y)/2) + + +@XFAIL +def test_other_lambert(): + a = Rational(6, 5) + assert solveset_real(x**a - a**x, x) == FiniteSet( + a, -a*LambertW(-log(a)/a)/log(a)) + + +@_both_exp_pow +def test_solveset(): + f = Function('f') + raises(ValueError, lambda: solveset(x + y)) + assert solveset(x, 1) == S.EmptySet + assert solveset(f(1)**2 + y + 1, f(1) + ) == FiniteSet(-sqrt(-y - 1), sqrt(-y - 1)) + assert solveset(f(1)**2 - 1, f(1), S.Reals) == FiniteSet(-1, 1) + assert solveset(f(1)**2 + 1, f(1)) == FiniteSet(-I, I) + assert solveset(x - 1, 1) == FiniteSet(x) + assert solveset(sin(x) - cos(x), sin(x)) == FiniteSet(cos(x)) + + assert solveset(0, domain=S.Reals) == S.Reals + assert solveset(1) == S.EmptySet + assert solveset(True, domain=S.Reals) == S.Reals # issue 10197 + assert solveset(False, domain=S.Reals) == S.EmptySet + + assert solveset(exp(x) - 1, domain=S.Reals) == FiniteSet(0) + assert solveset(exp(x) - 1, x, S.Reals) == FiniteSet(0) + assert solveset(Eq(exp(x), 1), x, S.Reals) == FiniteSet(0) + assert solveset(exp(x) - 1, exp(x), S.Reals) == FiniteSet(1) + A = Indexed('A', x) + assert solveset(A - 1, A, S.Reals) == FiniteSet(1) + + assert solveset(x - 1 >= 0, x, S.Reals) == Interval(1, oo) + assert solveset(exp(x) - 1 >= 0, x, S.Reals) == Interval(0, oo) + + assert dumeq(solveset(exp(x) - 1, x), imageset(Lambda(n, 2*I*pi*n), S.Integers)) + assert dumeq(solveset(Eq(exp(x), 1), x), imageset(Lambda(n, 2*I*pi*n), + S.Integers)) + # issue 13825 + assert solveset(x**2 + f(0) + 1, x) == {-sqrt(-f(0) - 1), sqrt(-f(0) - 1)} + + # issue 19977 + assert solveset(atan(log(x)) > 0, x, domain=Interval.open(0, oo)) == Interval.open(1, oo) + + +@_both_exp_pow +def test_multi_exp(): + k1, k2, k3 = symbols('k1, k2, k3') + assert dumeq(solveset(exp(exp(x)) - 5, x),\ + imageset(Lambda(((k1, n),), I*(2*k1*pi + arg(2*n*I*pi + log(5))) + log(Abs(2*n*I*pi + log(5)))),\ + ProductSet(S.Integers, S.Integers))) + assert dumeq(solveset((d*exp(exp(a*x + b)) + c), x),\ + imageset(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k1, n),), \ + I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))), \ + ProductSet(S.Integers, S.Integers)))) + + assert dumeq(solveset((d*exp(exp(exp(a*x + b))) + c), x),\ + imageset(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k2, k1, n),), \ + I*(2*k2*pi + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + \ + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + \ + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))))), \ + ProductSet(S.Integers, S.Integers, S.Integers)))) + + assert dumeq(solveset((d*exp(exp(exp(exp(a*x + b)))) + c), x),\ + ImageSet(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k3, k2, k1, n),), \ + I*(2*k3*pi + arg(I*(2*k2*pi + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + \ + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + \ + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))))) + log(Abs(I*(2*k2*pi + \ + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + \ + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))))))), \ + ProductSet(S.Integers, S.Integers, S.Integers, S.Integers)))) + + +def test__solveset_multi(): + from sympy.solvers.solveset import _solveset_multi + from sympy.sets import Reals + + # Basic univariate case: + assert _solveset_multi([x**2-1], [x], [S.Reals]) == FiniteSet((1,), (-1,)) + + # Linear systems of two equations + assert _solveset_multi([x+y, x+1], [x, y], [Reals, Reals]) == FiniteSet((-1, 1)) + assert _solveset_multi([x+y, x+1], [y, x], [Reals, Reals]) == FiniteSet((1, -1)) + assert _solveset_multi([x+y, x-y-1], [x, y], [Reals, Reals]) == FiniteSet((S(1)/2, -S(1)/2)) + assert _solveset_multi([x-1, y-2], [x, y], [Reals, Reals]) == FiniteSet((1, 2)) + # assert dumeq(_solveset_multi([x+y], [x, y], [Reals, Reals]), ImageSet(Lambda(x, (x, -x)), Reals)) + assert dumeq(_solveset_multi([x+y], [x, y], [Reals, Reals]), Union( + ImageSet(Lambda(((x,),), (x, -x)), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (-y, y)), ProductSet(Reals)))) + assert _solveset_multi([x+y, x+y+1], [x, y], [Reals, Reals]) == S.EmptySet + assert _solveset_multi([x+y, x-y, x-1], [x, y], [Reals, Reals]) == S.EmptySet + assert _solveset_multi([x+y, x-y, x-1], [y, x], [Reals, Reals]) == S.EmptySet + + # Systems of three equations: + assert _solveset_multi([x+y+z-1, x+y-z-2, x-y-z-3], [x, y, z], [Reals, + Reals, Reals]) == FiniteSet((2, -S.Half, -S.Half)) + + # Nonlinear systems: + from sympy.abc import theta + assert _solveset_multi([x**2+y**2-2, x+y], [x, y], [Reals, Reals]) == FiniteSet((-1, 1), (1, -1)) + assert _solveset_multi([x**2-1, y], [x, y], [Reals, Reals]) == FiniteSet((1, 0), (-1, 0)) + #assert _solveset_multi([x**2-y**2], [x, y], [Reals, Reals]) == Union( + # ImageSet(Lambda(x, (x, -x)), Reals), ImageSet(Lambda(x, (x, x)), Reals)) + assert dumeq(_solveset_multi([x**2-y**2], [x, y], [Reals, Reals]), Union( + ImageSet(Lambda(((x,),), (x, -Abs(x))), ProductSet(Reals)), + ImageSet(Lambda(((x,),), (x, Abs(x))), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (-Abs(y), y)), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (Abs(y), y)), ProductSet(Reals)))) + assert _solveset_multi([r*cos(theta)-1, r*sin(theta)], [theta, r], + [Interval(0, pi), Interval(-1, 1)]) == FiniteSet((0, 1), (pi, -1)) + assert _solveset_multi([r*cos(theta)-1, r*sin(theta)], [r, theta], + [Interval(0, 1), Interval(0, pi)]) == FiniteSet((1, 0)) + assert _solveset_multi([r*cos(theta)-r, r*sin(theta)], [r, theta], + [Interval(0, 1), Interval(0, pi)]) == Union( + ImageSet(Lambda(((r,),), (r, 0)), + ImageSet(Lambda(r, (r,)), Interval(0, 1))), + ImageSet(Lambda(((theta,),), (0, theta)), + ImageSet(Lambda(theta, (theta,)), Interval(0, pi)))) + + +def test_conditionset(): + assert solveset(Eq(sin(x)**2 + cos(x)**2, 1), x, domain=S.Reals + ) is S.Reals + + assert solveset(Eq(x**2 + x*sin(x), 1), x, domain=S.Reals + ).dummy_eq(ConditionSet(x, Eq(x**2 + x*sin(x) - 1, 0), S.Reals)) + + assert dumeq(solveset(Eq(-I*(exp(I*x) - exp(-I*x))/2, 1), x + ), imageset(Lambda(n, 2*n*pi + pi/2), S.Integers)) + + assert solveset(x + sin(x) > 1, x, domain=S.Reals + ).dummy_eq(ConditionSet(x, x + sin(x) > 1, S.Reals)) + + assert solveset(Eq(sin(Abs(x)), x), x, domain=S.Reals + ).dummy_eq(ConditionSet(x, Eq(-x + sin(Abs(x)), 0), S.Reals)) + + assert solveset(y**x-z, x, S.Reals + ).dummy_eq(ConditionSet(x, Eq(y**x - z, 0), S.Reals)) + + +@XFAIL +def test_conditionset_equality(): + ''' Checking equality of different representations of ConditionSet''' + assert solveset(Eq(tan(x), y), x) == ConditionSet(x, Eq(tan(x), y), S.Complexes) + + +def test_solveset_domain(): + assert solveset(x**2 - x - 6, x, Interval(0, oo)) == FiniteSet(3) + assert solveset(x**2 - 1, x, Interval(0, oo)) == FiniteSet(1) + assert solveset(x**4 - 16, x, Interval(0, 10)) == FiniteSet(2) + + +def test_improve_coverage(): + solution = solveset(exp(x) + sin(x), x, S.Reals) + unsolved_object = ConditionSet(x, Eq(exp(x) + sin(x), 0), S.Reals) + assert solution.dummy_eq(unsolved_object) + + +def test_issue_9522(): + expr1 = Eq(1/(x**2 - 4) + x, 1/(x**2 - 4) + 2) + expr2 = Eq(1/x + x, 1/x) + + assert solveset(expr1, x, S.Reals) is S.EmptySet + assert solveset(expr2, x, S.Reals) is S.EmptySet + + +def test_solvify(): + assert solvify(x**2 + 10, x, S.Reals) == [] + assert solvify(x**3 + 1, x, S.Complexes) == [-1, S.Half - sqrt(3)*I/2, + S.Half + sqrt(3)*I/2] + assert solvify(log(x), x, S.Reals) == [1] + assert solvify(cos(x), x, S.Reals) == [pi/2, pi*Rational(3, 2)] + assert solvify(sin(x) + 1, x, S.Reals) == [pi*Rational(3, 2)] + raises(NotImplementedError, lambda: solvify(sin(exp(x)), x, S.Complexes)) + + +def test_solvify_piecewise(): + p1 = Piecewise((0, x < -1), (x**2, x <= 1), (log(x), True)) + p2 = Piecewise((0, x < -10), (x**2 + 5*x - 6, x >= -9)) + p3 = Piecewise((0, Eq(x, 0)), (x**2/Abs(x), True)) + p4 = Piecewise((0, Eq(x, pi)), ((x - pi)/sin(x), True)) + + # issue 21079 + assert solvify(p1, x, S.Reals) == [0] + assert solvify(p2, x, S.Reals) == [-6, 1] + assert solvify(p3, x, S.Reals) == [0] + assert solvify(p4, x, S.Reals) == [pi] + + +def test_abs_invert_solvify(): + + x = Symbol('x',positive=True) + assert solvify(sin(Abs(x)), x, S.Reals) == [0, pi] + x = Symbol('x') + assert solvify(sin(Abs(x)), x, S.Reals) is None + + +def test_linear_eq_to_matrix(): + assert linear_eq_to_matrix(0, x) == (Matrix([[0]]), Matrix([[0]])) + assert linear_eq_to_matrix(1, x) == (Matrix([[0]]), Matrix([[-1]])) + + # integer coefficients + eqns1 = [2*x + y - 2*z - 3, x - y - z, x + y + 3*z - 12] + eqns2 = [Eq(3*x + 2*y - z, 1), Eq(2*x - 2*y + 4*z, -2), -2*x + y - 2*z] + + A, B = linear_eq_to_matrix(eqns1, x, y, z) + assert A == Matrix([[2, 1, -2], [1, -1, -1], [1, 1, 3]]) + assert B == Matrix([[3], [0], [12]]) + + A, B = linear_eq_to_matrix(eqns2, x, y, z) + assert A == Matrix([[3, 2, -1], [2, -2, 4], [-2, 1, -2]]) + assert B == Matrix([[1], [-2], [0]]) + + # Pure symbolic coefficients + eqns3 = [a*b*x + b*y + c*z - d, e*x + d*x + f*y + g*z - h, i*x + j*y + k*z - l] + A, B = linear_eq_to_matrix(eqns3, x, y, z) + assert A == Matrix([[a*b, b, c], [d + e, f, g], [i, j, k]]) + assert B == Matrix([[d], [h], [l]]) + + # raise Errors if + # 1) no symbols are given + raises(ValueError, lambda: linear_eq_to_matrix(eqns3)) + # 2) there are duplicates + raises(ValueError, lambda: linear_eq_to_matrix(eqns3, [x, x, y])) + # 3) a nonlinear term is detected in the original expression + raises(NonlinearError, lambda: linear_eq_to_matrix(Eq(1/x + x, 1/x), [x])) + raises(NonlinearError, lambda: linear_eq_to_matrix([x**2], [x])) + raises(NonlinearError, lambda: linear_eq_to_matrix([x*y], [x, y])) + # 4) Eq being used to represent equations autoevaluates + # (use unevaluated Eq instead) + raises(ValueError, lambda: linear_eq_to_matrix(Eq(x, x), x)) + raises(ValueError, lambda: linear_eq_to_matrix(Eq(x, x + 1), x)) + + + # if non-symbols are passed, the user is responsible for interpreting + assert linear_eq_to_matrix([x], [1/x]) == (Matrix([[0]]), Matrix([[-x]])) + + # issue 15195 + assert linear_eq_to_matrix(x + y*(z*(3*x + 2) + 3), x) == ( + Matrix([[3*y*z + 1]]), Matrix([[-y*(2*z + 3)]])) + assert linear_eq_to_matrix(Matrix( + [[a*x + b*y - 7], [5*x + 6*y - c]]), x, y) == ( + Matrix([[a, b], [5, 6]]), Matrix([[7], [c]])) + + # issue 15312 + assert linear_eq_to_matrix(Eq(x + 2, 1), x) == ( + Matrix([[1]]), Matrix([[-1]])) + + # issue 25423 + raises(TypeError, lambda: linear_eq_to_matrix([], {x, y})) + raises(TypeError, lambda: linear_eq_to_matrix([x + y], {x, y})) + raises(ValueError, lambda: linear_eq_to_matrix({x + y}, (x, y))) + + +def test_issue_16577(): + assert linear_eq_to_matrix(Eq(a*(2*x + 3*y) + 4*y, 5), x, y) == ( + Matrix([[2*a, 3*a + 4]]), Matrix([[5]])) + + +def test_issue_10085(): + assert invert_real(exp(x),0,x) == (x, S.EmptySet) + + +def test_linsolve(): + x1, x2, x3, x4 = symbols('x1, x2, x3, x4') + + # Test for different input forms + + M = Matrix([[1, 2, 1, 1, 7], [1, 2, 2, -1, 12], [2, 4, 0, 6, 4]]) + system1 = A, B = M[:, :-1], M[:, -1] + Eqns = [x1 + 2*x2 + x3 + x4 - 7, x1 + 2*x2 + 2*x3 - x4 - 12, + 2*x1 + 4*x2 + 6*x4 - 4] + + sol = FiniteSet((-2*x2 - 3*x4 + 2, x2, 2*x4 + 5, x4)) + assert linsolve(Eqns, (x1, x2, x3, x4)) == sol + assert linsolve(Eqns, *(x1, x2, x3, x4)) == sol + assert linsolve(system1, (x1, x2, x3, x4)) == sol + assert linsolve(system1, *(x1, x2, x3, x4)) == sol + # issue 9667 - symbols can be Dummy symbols + x1, x2, x3, x4 = symbols('x:4', cls=Dummy) + assert linsolve(system1, x1, x2, x3, x4) == FiniteSet( + (-2*x2 - 3*x4 + 2, x2, 2*x4 + 5, x4)) + + # raise ValueError for garbage value + raises(ValueError, lambda: linsolve(Eqns)) + raises(ValueError, lambda: linsolve(x1)) + raises(ValueError, lambda: linsolve(x1, x2)) + raises(ValueError, lambda: linsolve((A,), x1, x2)) + raises(ValueError, lambda: linsolve(A, B, x1, x2)) + raises(ValueError, lambda: linsolve([x1], x1, x1)) + raises(ValueError, lambda: linsolve([x1], (i for i in (x1, x1)))) + + #raise ValueError if equations are non-linear in given variables + raises(NonlinearError, lambda: linsolve([x + y - 1, x ** 2 + y - 3], [x, y])) + raises(NonlinearError, lambda: linsolve([cos(x) + y, x + y], [x, y])) + assert linsolve([x + z - 1, x ** 2 + y - 3], [z, y]) == {(-x + 1, -x**2 + 3)} + + # Fully symbolic test + A = Matrix([[a, b], [c, d]]) + B = Matrix([[e], [g]]) + system2 = (A, B) + sol = FiniteSet(((-b*g + d*e)/(a*d - b*c), (a*g - c*e)/(a*d - b*c))) + assert linsolve(system2, [x, y]) == sol + + # No solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + B = Matrix([0, 0, 1]) + assert linsolve((A, B), (x, y, z)) is S.EmptySet + + # Issue #10056 + A, B, J1, J2 = symbols('A B J1 J2') + Augmatrix = Matrix([ + [2*I*J1, 2*I*J2, -2/J1], + [-2*I*J2, -2*I*J1, 2/J2], + [0, 2, 2*I/(J1*J2)], + [2, 0, 0], + ]) + + assert linsolve(Augmatrix, A, B) == FiniteSet((0, I/(J1*J2))) + + # Issue #10121 - Assignment of free variables + Augmatrix = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]]) + assert linsolve(Augmatrix, a, b, c, d, e) == FiniteSet((a, 0, c, 0, e)) + #raises(IndexError, lambda: linsolve(Augmatrix, a, b, c)) + + x0, x1, x2, _x0 = symbols('tau0 tau1 tau2 _tau0') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + x0, x1, x2, _x0 = symbols('tau00 tau01 tau02 tau0') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + x0, x1, x2, _x0 = symbols('tau00 tau01 tau02 tau1') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + # symbols can be given as generators + x0, x2, x4 = symbols('x0, x2, x4') + assert linsolve(Augmatrix, numbered_symbols('x') + ) == FiniteSet((x0, 0, x2, 0, x4)) + Augmatrix[-1, -1] = x0 + # use Dummy to avoid clash; the names may clash but the symbols + # will not + Augmatrix[-1, -1] = symbols('_x0') + assert len(linsolve( + Augmatrix, numbered_symbols('x', cls=Dummy)).free_symbols) == 4 + + # Issue #12604 + f = Function('f') + assert linsolve([f(x) - 5], f(x)) == FiniteSet((5,)) + + # Issue #14860 + from sympy.physics.units import meter, newton, kilo + kN = kilo*newton + Eqns = [8*kN + x + y, 28*kN*meter + 3*x*meter] + assert linsolve(Eqns, x, y) == { + (kilo*newton*Rational(-28, 3), kN*Rational(4, 3))} + + # linsolve does not allow expansion (real or implemented) + # to remove singularities, but it will cancel linear terms + assert linsolve([Eq(x, x + y)], [x, y]) == {(x, 0)} + assert linsolve([Eq(x + x*y, 1 + y)], [x]) == {(1,)} + assert linsolve([Eq(1 + y, x + x*y)], [x]) == {(1,)} + raises(NonlinearError, lambda: + linsolve([Eq(x**2, x**2 + y)], [x, y])) + + # corner cases + # + # XXX: The case below should give the same as for [0] + # assert linsolve([], [x]) == {(x,)} + assert linsolve([], [x]) is S.EmptySet + assert linsolve([0], [x]) == {(x,)} + assert linsolve([x], [x, y]) == {(0, y)} + assert linsolve([x, 0], [x, y]) == {(0, y)} + + +def test_linsolve_large_sparse(): + # + # This is mainly a performance test + # + + def _mk_eqs_sol(n): + xs = symbols('x:{}'.format(n)) + ys = symbols('y:{}'.format(n)) + syms = xs + ys + eqs = [] + sol = (-S.Half,) * n + (S.Half,) * n + for xi, yi in zip(xs, ys): + eqs.extend([xi + yi, xi - yi + 1]) + return eqs, syms, FiniteSet(sol) + + n = 500 + eqs, syms, sol = _mk_eqs_sol(n) + assert linsolve(eqs, syms) == sol + + +def test_linsolve_immutable(): + A = ImmutableDenseMatrix([[1, 1, 2], [0, 1, 2], [0, 0, 1]]) + B = ImmutableDenseMatrix([2, 1, -1]) + assert linsolve([A, B], (x, y, z)) == FiniteSet((1, 3, -1)) + + A = ImmutableDenseMatrix([[1, 1, 7], [1, -1, 3]]) + assert linsolve(A) == FiniteSet((5, 2)) + + +def test_solve_decomposition(): + n = Dummy('n') + + f1 = exp(3*x) - 6*exp(2*x) + 11*exp(x) - 6 + f2 = sin(x)**2 - 2*sin(x) + 1 + f3 = sin(x)**2 - sin(x) + f4 = sin(x + 1) + f5 = exp(x + 2) - 1 + f6 = 1/log(x) + f7 = 1/x + + s1 = ImageSet(Lambda(n, 2*n*pi), S.Integers) + s2 = ImageSet(Lambda(n, 2*n*pi + pi), S.Integers) + s3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + s4 = ImageSet(Lambda(n, 2*n*pi - 1), S.Integers) + s5 = ImageSet(Lambda(n, 2*n*pi - 1 + pi), S.Integers) + + assert solve_decomposition(f1, x, S.Reals) == FiniteSet(0, log(2), log(3)) + assert dumeq(solve_decomposition(f2, x, S.Reals), s3) + assert dumeq(solve_decomposition(f3, x, S.Reals), Union(s1, s2, s3)) + assert dumeq(solve_decomposition(f4, x, S.Reals), Union(s4, s5)) + assert solve_decomposition(f5, x, S.Reals) == FiniteSet(-2) + assert solve_decomposition(f6, x, S.Reals) == S.EmptySet + assert solve_decomposition(f7, x, S.Reals) == S.EmptySet + assert solve_decomposition(x, x, Interval(1, 2)) == S.EmptySet + + +# nonlinsolve testcases +def test_nonlinsolve_basic(): + assert nonlinsolve([],[]) == S.EmptySet + assert nonlinsolve([],[x, y]) == S.EmptySet + + system = [x, y - x - 5] + assert nonlinsolve([x],[x, y]) == FiniteSet((0, y)) + assert nonlinsolve(system, [y]) == S.EmptySet + soln = (ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers),) + assert dumeq(nonlinsolve([sin(x) - 1], [x]), FiniteSet(tuple(soln))) + soln = ((ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), 1), + (ImageSet(Lambda(n, 2*n*pi), S.Integers), 1)) + assert dumeq(nonlinsolve([sin(x), y - 1], [x, y]), FiniteSet(*soln)) + assert nonlinsolve([x**2 - 1], [x]) == FiniteSet((-1,), (1,)) + + soln = FiniteSet((y, y)) + assert nonlinsolve([x - y, 0], x, y) == soln + assert nonlinsolve([0, x - y], x, y) == soln + assert nonlinsolve([x - y, x - y], x, y) == soln + assert nonlinsolve([x, 0], x, y) == FiniteSet((0, y)) + f = Function('f') + assert nonlinsolve([f(x), 0], f(x), y) == FiniteSet((0, y)) + assert nonlinsolve([f(x), 0], f(x), f(y)) == FiniteSet((0, f(y))) + A = Indexed('A', x) + assert nonlinsolve([A, 0], A, y) == FiniteSet((0, y)) + assert nonlinsolve([x**2 -1], [sin(x)]) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([x**2 -1], sin(x)) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([x**2 -1], 1) == FiniteSet((x**2,)) + assert nonlinsolve([x**2 -1], x + y) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([Eq(1, x + y), Eq(1, -x + y - 1), Eq(1, -x + y - 1)], x, y) == FiniteSet( + (-S.Half, 3*S.Half)) + + +def test_nonlinsolve_abs(): + soln = FiniteSet((y, y), (-y, y)) + assert nonlinsolve([Abs(x) - y], x, y) == soln + + +def test_raise_exception_nonlinsolve(): + raises(IndexError, lambda: nonlinsolve([x**2 -1], [])) + raises(ValueError, lambda: nonlinsolve([x**2 -1])) + + +def test_trig_system(): + # TODO: add more simple testcases when solveset returns + # simplified soln for Trig eq + assert nonlinsolve([sin(x) - 1, cos(x) -1 ], x) == S.EmptySet + soln1 = (ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers),) + soln = FiniteSet(soln1) + assert dumeq(nonlinsolve([sin(x) - 1, cos(x)], x), soln) + + +@XFAIL +def test_trig_system_fail(): + # fails because solveset trig solver is not much smart. + sys = [x + y - pi/2, sin(x) + sin(y) - 1] + # solveset returns conditionset for sin(x) + sin(y) - 1 + soln_1 = (ImageSet(Lambda(n, n*pi + pi/2), S.Integers), + ImageSet(Lambda(n, n*pi), S.Integers)) + soln_1 = FiniteSet(soln_1) + soln_2 = (ImageSet(Lambda(n, n*pi), S.Integers), + ImageSet(Lambda(n, n*pi+ pi/2), S.Integers)) + soln_2 = FiniteSet(soln_2) + soln = soln_1 + soln_2 + assert dumeq(nonlinsolve(sys, [x, y]), soln) + + # Add more cases from here + # http://www.vitutor.com/geometry/trigonometry/equations_systems.html#uno + sys = [sin(x) + sin(y) - (sqrt(3)+1)/2, sin(x) - sin(y) - (sqrt(3) - 1)/2] + soln_x = Union(ImageSet(Lambda(n, 2*n*pi + pi/3), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi*Rational(2, 3)), S.Integers)) + soln_y = Union(ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi*Rational(5, 6)), S.Integers)) + assert dumeq(nonlinsolve(sys, [x, y]), FiniteSet((soln_x, soln_y))) + + +def test_nonlinsolve_positive_dimensional(): + x, y, a, b, c, d = symbols('x, y, a, b, c, d', extended_real=True) + assert nonlinsolve([x*y, x*y - x], [x, y]) == FiniteSet((0, y)) + + system = [a**2 + a*c, a - b] + assert nonlinsolve(system, [a, b]) == FiniteSet((0, 0), (-c, -c)) + # here (a= 0, b = 0) is independent soln so both is printed. + # if symbols = [a, b, c] then only {a : -c ,b : -c} + + eq1 = a + b + c + d + eq2 = a*b + b*c + c*d + d*a + eq3 = a*b*c + b*c*d + c*d*a + d*a*b + eq4 = a*b*c*d - 1 + system = [eq1, eq2, eq3, eq4] + sol1 = (-1/d, -d, 1/d, FiniteSet(d) - FiniteSet(0)) + sol2 = (1/d, -d, -1/d, FiniteSet(d) - FiniteSet(0)) + soln = FiniteSet(sol1, sol2) + assert nonlinsolve(system, [a, b, c, d]) == soln + + assert nonlinsolve([x**4 - 3*x**2 + y*x, x*z**2, y*z - 1], [x, y, z]) == \ + {(0, 1/z, z)} + + +def test_nonlinsolve_polysys(): + x, y, z = symbols('x, y, z', real=True) + assert nonlinsolve([x**2 + y - 2, x**2 + y], [x, y]) == S.EmptySet + + s = (-y + 2, y) + assert nonlinsolve([(x + y)**2 - 4, x + y - 2], [x, y]) == FiniteSet(s) + + system = [x**2 - y**2] + soln_real = FiniteSet((-y, y), (y, y)) + soln_complex = FiniteSet((-Abs(y), y), (Abs(y), y)) + soln =soln_real + soln_complex + assert nonlinsolve(system, [x, y]) == soln + + system = [x**2 - y**2] + soln_real= FiniteSet((y, -y), (y, y)) + soln_complex = FiniteSet((y, -Abs(y)), (y, Abs(y))) + soln = soln_real + soln_complex + assert nonlinsolve(system, [y, x]) == soln + + system = [x**2 + y - 3, x - y - 4] + assert nonlinsolve(system, (x, y)) != nonlinsolve(system, (y, x)) + + assert nonlinsolve([-x**2 - y**2 + z, -2*x, -2*y, S.One], [x, y, z]) == S.EmptySet + assert nonlinsolve([x + y + z, S.One, S.One, S.One], [x, y, z]) == S.EmptySet + + system = [-x**2*z**2 + x*y*z + y**4, -2*x*z**2 + y*z, x*z + 4*y**3, -2*x**2*z + x*y] + assert nonlinsolve(system, [x, y, z]) == FiniteSet((0, 0, z), (x, 0, 0)) + + +def test_nonlinsolve_using_substitution(): + x, y, z, n = symbols('x, y, z, n', real = True) + system = [(x + y)*n - y**2 + 2] + s_x = (n*y - y**2 + 2)/n + soln = (-s_x, y) + assert nonlinsolve(system, [x, y]) == FiniteSet(soln) + + system = [z**2*x**2 - z**2*y**2/exp(x)] + soln_real_1 = (y, x, 0) + soln_real_2 = (-exp(x/2)*Abs(x), x, z) + soln_real_3 = (exp(x/2)*Abs(x), x, z) + soln_complex_1 = (-x*exp(x/2), x, z) + soln_complex_2 = (x*exp(x/2), x, z) + syms = [y, x, z] + soln = FiniteSet(soln_real_1, soln_complex_1, soln_complex_2,\ + soln_real_2, soln_real_3) + assert nonlinsolve(system,syms) == soln + + +def test_nonlinsolve_complex(): + n = Dummy('n') + assert dumeq(nonlinsolve([exp(x) - sin(y), 1/y - 3], [x, y]), { + (ImageSet(Lambda(n, 2*n*I*pi + log(sin(Rational(1, 3)))), S.Integers), Rational(1, 3))}) + + system = [exp(x) - sin(y), 1/exp(y) - 3] + assert dumeq(nonlinsolve(system, [x, y]), { + (ImageSet(Lambda(n, I*(2*n*pi + pi) + + log(sin(log(3)))), S.Integers), -log(3)), + (ImageSet(Lambda(n, I*(2*n*pi + arg(sin(2*n*I*pi - log(3)))) + + log(Abs(sin(2*n*I*pi - log(3))))), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi - log(3)), S.Integers))}) + + system = [exp(x) - sin(y), y**2 - 4] + assert dumeq(nonlinsolve(system, [x, y]), { + (ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sin(2))), S.Integers), -2), + (ImageSet(Lambda(n, 2*n*I*pi + log(sin(2))), S.Integers), 2)}) + + system = [exp(x) - 2, y ** 2 - 2] + assert dumeq(nonlinsolve(system, [x, y]), { + (log(2), -sqrt(2)), (log(2), sqrt(2)), + (ImageSet(Lambda(n, 2*n*I*pi + log(2)), S.Integers), -sqrt(2)), + (ImageSet(Lambda(n, 2 * n * I * pi + log(2)), S.Integers), sqrt(2))}) + + +def test_nonlinsolve_radical(): + assert nonlinsolve([sqrt(y) - x - z, y - 1], [x, y, z]) == {(1 - z, 1, z)} + + +def test_nonlinsolve_inexact(): + sol = [(-1.625, -1.375), (1.625, 1.375)] + res = nonlinsolve([(x + y)**2 - 9, x**2 - y**2 - 0.75], [x, y]) + assert all(abs(res.args[i][j]-sol[i][j]) < 1e-9 + for i in range(2) for j in range(2)) + + assert nonlinsolve([(x + y)**2 - 9, (x + y)**2 - 0.75], [x, y]) == S.EmptySet + + assert nonlinsolve([y**2 + (x - 0.5)**2 - 0.0625, 2*x - 1.0, 2*y], [x, y]) == \ + S.EmptySet + + res = nonlinsolve([x**2 + y - 0.5, (x + y)**2, log(z)], [x, y, z]) + sol = [(-0.366025403784439, 0.366025403784439, 1), + (-0.366025403784439, 0.366025403784439, 1), + (1.36602540378444, -1.36602540378444, 1)] + assert all(abs(res.args[i][j]-sol[i][j]) < 1e-9 + for i in range(3) for j in range(3)) + + res = nonlinsolve([y - x**2, x**5 - x + 1.0], [x, y]) + sol = [(-1.16730397826142, 1.36259857766493), + (-0.181232444469876 - 1.08395410131771*I, + -1.14211129483496 + 0.392895302949911*I), + (-0.181232444469876 + 1.08395410131771*I, + -1.14211129483496 - 0.392895302949911*I), + (0.764884433600585 - 0.352471546031726*I, + 0.460812006002492 - 0.539199997693599*I), + (0.764884433600585 + 0.352471546031726*I, + 0.460812006002492 + 0.539199997693599*I)] + assert all(abs(res.args[i][j] - sol[i][j]) < 1e-9 + for i in range(5) for j in range(2)) + +@XFAIL +def test_solve_nonlinear_trans(): + # After the transcendental equation solver these will work + x, y = symbols('x, y', real=True) + soln1 = FiniteSet((2*LambertW(y/2), y)) + soln2 = FiniteSet((-x*sqrt(exp(x)), y), (x*sqrt(exp(x)), y)) + soln3 = FiniteSet((x*exp(x/2), x)) + soln4 = FiniteSet(2*LambertW(y/2), y) + assert nonlinsolve([x**2 - y**2/exp(x)], [x, y]) == soln1 + assert nonlinsolve([x**2 - y**2/exp(x)], [y, x]) == soln2 + assert nonlinsolve([x**2 - y**2/exp(x)], [y, x]) == soln3 + assert nonlinsolve([x**2 - y**2/exp(x)], [x, y]) == soln4 + + +def test_nonlinsolve_issue_25182(): + a1, b1, c1, ca, cb, cg = symbols('a1, b1, c1, ca, cb, cg') + eq1 = a1*a1 + b1*b1 - 2.*a1*b1*cg - c1*c1 + eq2 = a1*a1 + c1*c1 - 2.*a1*c1*cb - b1*b1 + eq3 = b1*b1 + c1*c1 - 2.*b1*c1*ca - a1*a1 + assert nonlinsolve([eq1, eq2, eq3], [c1, cb, cg]) == FiniteSet( + (1.0*b1*ca - 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2), + -1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1, + -1.0*b1*(ca - 1)*(ca + 1)/a1 + 1.0*ca*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1), + (1.0*b1*ca + 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2), + 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1, + -1.0*b1*(ca - 1)*(ca + 1)/a1 - 1.0*ca*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1)) + + +def test_issue_14642(): + x = Symbol('x') + n1 = 0.5*x**3+x**2+0.5+I #add I in the Polynomials + solution = solveset(n1, x) + assert abs(solution.args[0] - (-2.28267560928153 - 0.312325580497716*I)) <= 1e-9 + assert abs(solution.args[1] - (-0.297354141679308 + 1.01904778618762*I)) <= 1e-9 + assert abs(solution.args[2] - (0.580029750960839 - 0.706722205689907*I)) <= 1e-9 + + # Symbolic + n1 = S.Half*x**3+x**2+S.Half+I + res = FiniteSet(-((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49) + /2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2))/3)/3 - S(2)/3 - 4*cos(atan((27 + + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)* + 31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2))/3)/(3*((3* + sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)**2)**(S(1)/ + 6)) + I*(-((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/ + 2)/2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos( + atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49) + /2)/2 + S(43)/2))/3)/3 + 4*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172) + /49)/2)/2 + S(43)/2))/3)/(3*((3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6))), -S(2)/3 - sqrt(3)*((3* + sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)**2)**(S(1) + /6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)) + /3)/6 - 4*re(1/((-S(1)/2 - sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + + (43 + 54*I)**2)/2)**(S(1)/3)))/3 + ((3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)* + 31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)* + sin(atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 + I*(-4*im(1/((-S(1)/2 - + sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1)/ + 3)))/3 + ((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/ + 49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/ + 4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2))/3)/6), -S(2)/3 - 4*re(1/((-S(1)/2 + + sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1) + /3)))/3 + sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + ((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + I*(-sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos( + atan(S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**( + S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 + ((3*sqrt(3)*31985**(S(1)/4)* + sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**( + S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 - 4*im(1/((-S(1)/2 + sqrt(3)*I/2)* + (S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1)/3)))/3)) + + assert solveset(n1, x) == res + + +def test_issue_13961(): + V = (ax, bx, cx, gx, jx, lx, mx, nx, q) = symbols('ax bx cx gx jx lx mx nx q') + S = (ax*q - lx*q - mx, ax - gx*q - lx, bx*q**2 + cx*q - jx*q - nx, q*(-ax*q + lx*q + mx), q*(-ax + gx*q + lx)) + + sol = FiniteSet((lx + mx/q, (-cx*q + jx*q + nx)/q**2, cx, mx/q**2, jx, lx, mx, nx, Complement({q}, {0})), + (lx + mx/q, (cx*q - jx*q - nx)/q**2*-1, cx, mx/q**2, jx, lx, mx, nx, Complement({q}, {0}))) + assert nonlinsolve(S, *V) == sol + # The two solutions are in fact identical, so even better if only one is returned + + +def test_issue_14541(): + solutions = solveset(sqrt(-x**2 - 2.0), x) + assert abs(solutions.args[0]+1.4142135623731*I) <= 1e-9 + assert abs(solutions.args[1]-1.4142135623731*I) <= 1e-9 + + +def test_issue_13396(): + expr = -2*y*exp(-x**2 - y**2)*Abs(x) + sol = FiniteSet(0) + + assert solveset(expr, y, domain=S.Reals) == sol + + # Related type of equation also solved here + assert solveset(atan(x**2 - y**2)-pi/2, y, S.Reals) is S.EmptySet + + +def test_issue_12032(): + sol = FiniteSet(-sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 + + sqrt(Abs(-2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2, + -sqrt(Abs(-2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2 - + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2, + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 - + I*sqrt(Abs(-2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) - + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2, + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 + + I*sqrt(Abs(-2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) - + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1,3)))))/2) + assert solveset(x**4 + x - 1, x) == sol + + +def test_issue_10876(): + assert solveset(1/sqrt(x), x) == S.EmptySet + + +def test_issue_19050(): + # test_issue_19050 --> TypeError removed + assert dumeq(nonlinsolve([x + y, sin(y)], [x, y]), + FiniteSet((ImageSet(Lambda(n, -2*n*pi), S.Integers), ImageSet(Lambda(n, 2*n*pi), S.Integers)),\ + (ImageSet(Lambda(n, -2*n*pi - pi), S.Integers), ImageSet(Lambda(n, 2*n*pi + pi), S.Integers)))) + assert dumeq(nonlinsolve([x + y, sin(y) + cos(y)], [x, y]), + FiniteSet((ImageSet(Lambda(n, -2*n*pi - 3*pi/4), S.Integers), ImageSet(Lambda(n, 2*n*pi + 3*pi/4), S.Integers)), \ + (ImageSet(Lambda(n, -2*n*pi - 7*pi/4), S.Integers), ImageSet(Lambda(n, 2*n*pi + 7*pi/4), S.Integers)))) + + +def test_issue_16618(): + eqn = [sin(x)*sin(y), cos(x)*cos(y) - 1] + # nonlinsolve's answer is still suspicious since it contains only three + # distinct Dummys instead of 4. (Both 'x' ImageSets share the same Dummy.) + ans = FiniteSet((ImageSet(Lambda(n, 2*n*pi), S.Integers), ImageSet(Lambda(n, 2*n*pi), S.Integers)), + (ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), ImageSet(Lambda(n, 2*n*pi + pi), S.Integers))) + sol = nonlinsolve(eqn, [x, y]) + + for i0, j0 in zip(ordered(sol), ordered(ans)): + assert len(i0) == len(j0) == 2 + assert all(a.dummy_eq(b) for a, b in zip(i0, j0)) + assert len(sol) == len(ans) + + +def test_issue_17566(): + assert nonlinsolve([32*(2**x)/2**(-y) - 4**y, 27*(3**x) - S(1)/3**y], x, y) ==\ + FiniteSet((-log(81)/log(3), 1)) + + +def test_issue_16643(): + n = Dummy('n') + assert solveset(x**2*sin(x), x).dummy_eq(Union(ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi), S.Integers))) + + +def test_issue_19587(): + n,m = symbols('n m') + assert nonlinsolve([32*2**m*2**n - 4**n, 27*3**m - 3**(-n)], m, n) ==\ + FiniteSet((-log(81)/log(3), 1)) + + +def test_issue_5132_1(): + system = [sqrt(x**2 + y**2) - sqrt(10), x + y - 4] + assert nonlinsolve(system, [x, y]) == FiniteSet((1, 3), (3, 1)) + + n = Dummy('n') + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + s_real_y = -log(3) + s_real_z = sqrt(-exp(2*x) - sin(log(3))) + soln_real = FiniteSet((s_real_y, s_real_z), (s_real_y, -s_real_z)) + lam = Lambda(n, 2*n*I*pi + -log(3)) + s_complex_y = ImageSet(lam, S.Integers) + lam = Lambda(n, sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_1 = ImageSet(lam, S.Integers) + lam = Lambda(n, -sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_2 = ImageSet(lam, S.Integers) + soln_complex = FiniteSet( + (s_complex_y, s_complex_z_1), + (s_complex_y, s_complex_z_2) + ) + soln = soln_real + soln_complex + assert dumeq(nonlinsolve(eqs, [y, z]), soln) + + +def test_issue_5132_2(): + x, y = symbols('x, y', real=True) + eqs = [exp(x)**2 - sin(y) + z**2] + n = Dummy('n') + soln_real = (log(-z**2 + sin(y))/2, z) + lam = Lambda( n, I*(2*n*pi + arg(-z**2 + sin(y)))/2 + log(Abs(z**2 - sin(y)))/2) + img = ImageSet(lam, S.Integers) + # not sure about the complex soln. But it looks correct. + soln_complex = (img, z) + soln = FiniteSet(soln_real, soln_complex) + assert dumeq(nonlinsolve(eqs, [x, z]), soln) + + system = [r - x**2 - y**2, tan(t) - y/x] + s_x = sqrt(r/(tan(t)**2 + 1)) + s_y = sqrt(r/(tan(t)**2 + 1))*tan(t) + soln = FiniteSet((s_x, s_y), (-s_x, -s_y)) + assert nonlinsolve(system, [x, y]) == soln + + +def test_issue_6752(): + a, b = symbols('a, b', real=True) + assert nonlinsolve([a**2 + a, a - b], [a, b]) == {(-1, -1), (0, 0)} + + +@SKIP("slow") +def test_issue_5114_solveset(): + # slow testcase + from sympy.abc import o, p + + # there is no 'a' in the equation set but this is how the + # problem was originally posed + syms = [a, b, c, f, h, k, n] + eqs = [b + r/d - c/d, + c*(1/d + 1/e + 1/g) - f/g - r/d, + f*(1/g + 1/i + 1/j) - c/g - h/i, + h*(1/i + 1/l + 1/m) - f/i - k/m, + k*(1/m + 1/o + 1/p) - h/m - n/p, + n*(1/p + 1/q) - k/p] + assert len(nonlinsolve(eqs, syms)) == 1 + + +@SKIP("Hangs") +def _test_issue_5335(): + # Not able to check zero dimensional system. + # is_zero_dimensional Hangs + lam, a0, conc = symbols('lam a0 conc') + eqs = [lam + 2*y - a0*(1 - x/2)*x - 0.005*x/2*x, + a0*(1 - x/2)*x - 1*y - 0.743436700916726*y, + x + y - conc] + sym = [x, y, a0] + # there are 4 solutions but only two are valid + assert len(nonlinsolve(eqs, sym)) == 2 + # float + eqs = [lam + 2*y - a0*(1 - x/2)*x - 0.005*x/2*x, + a0*(1 - x/2)*x - 1*y - 0.743436700916726*y, + x + y - conc] + sym = [x, y, a0] + assert len(nonlinsolve(eqs, sym)) == 2 + + +def test_issue_2777(): + # the equations represent two circles + x, y = symbols('x y', real=True) + e1, e2 = sqrt(x**2 + y**2) - 10, sqrt(y**2 + (-x + 10)**2) - 3 + a, b = Rational(191, 20), 3*sqrt(391)/20 + ans = {(a, -b), (a, b)} + assert nonlinsolve((e1, e2), (x, y)) == ans + assert nonlinsolve((e1, e2/(x - a)), (x, y)) == S.EmptySet + # make the 2nd circle's radius be -3 + e2 += 6 + assert nonlinsolve((e1, e2), (x, y)) == S.EmptySet + + +def test_issue_8828(): + x1 = 0 + y1 = -620 + r1 = 920 + x2 = 126 + y2 = 276 + x3 = 51 + y3 = 205 + r3 = 104 + v = [x, y, z] + + f1 = (x - x1)**2 + (y - y1)**2 - (r1 - z)**2 + f2 = (x2 - x)**2 + (y2 - y)**2 - z**2 + f3 = (x - x3)**2 + (y - y3)**2 - (r3 - z)**2 + F = [f1, f2, f3] + + g1 = sqrt((x - x1)**2 + (y - y1)**2) + z - r1 + g2 = f2 + g3 = sqrt((x - x3)**2 + (y - y3)**2) + z - r3 + G = [g1, g2, g3] + + # both soln same + A = nonlinsolve(F, v) + B = nonlinsolve(G, v) + assert A == B + + +def test_nonlinsolve_conditionset(): + # when solveset failed to solve all the eq + # return conditionset + f = Function('f') + f1 = f(x) - pi/2 + f2 = f(y) - pi*Rational(3, 2) + intermediate_system = Eq(2*f(x) - pi, 0) & Eq(2*f(y) - 3*pi, 0) + syms = Tuple(x, y) + soln = ConditionSet( + syms, + intermediate_system, + S.Complexes**2) + assert nonlinsolve([f1, f2], [x, y]) == soln + + +def test_substitution_basic(): + assert substitution([], [x, y]) == S.EmptySet + assert substitution([], []) == S.EmptySet + system = [2*x**2 + 3*y**2 - 30, 3*x**2 - 2*y**2 - 19] + soln = FiniteSet((-3, -2), (-3, 2), (3, -2), (3, 2)) + assert substitution(system, [x, y]) == soln + + soln = FiniteSet((-1, 1)) + assert substitution([x + y], [x], [{y: 1}], [y], set(), [x, y]) == soln + assert substitution( + [x + y], [x], [{y: 1}], [y], + {x + 1}, [y, x]) == S.EmptySet + + +def test_substitution_incorrect(): + # the solutions in the following two tests are incorrect. The + # correct result is EmptySet in both cases. + assert substitution([h - 1, k - 1, f - 2, f - 4, -2 * k], + [h, k, f]) == {(1, 1, f)} + assert substitution([x + y + z, S.One, S.One, S.One], [x, y, z]) == \ + {(-y - z, y, z)} + + # the correct result in the test below is {(-I, I, I, -I), + # (I, -I, -I, I)} + assert substitution([a - d, b + d, c + d, d**2 + 1], [a, b, c, d]) == \ + {(d, -d, -d, d)} + + # the result in the test below is incomplete. The complete result + # is {(0, b), (log(2), 2)} + assert substitution([a*(a - log(b)), a*(b - 2)], [a, b]) == \ + {(0, b)} + + # The system in the test below is zero-dimensional, so the result + # should have no free symbols + assert substitution([-k*y + 6*x - 4*y, -81*k + 49*y**2 - 270, + -3*k*z + k + z**3, k**2 - 2*k + 4], + [x, y, z, k]).free_symbols == {z} + + +def test_substitution_redundant(): + # the third and fourth solutions are redundant in the test below + assert substitution([x**2 - y**2, z - 1], [x, z]) == \ + {(-y, 1), (y, 1), (-sqrt(y**2), 1), (sqrt(y**2), 1)} + + # the system below has three solutions. Two of the solutions + # returned by substitution are redundant. + res = substitution([x - y, y**3 - 3*y**2 + 1], [x, y]) + assert len(res) == 5 + + +def test_issue_5132_substitution(): + x, y, z, r, t = symbols('x, y, z, r, t', real=True) + system = [r - x**2 - y**2, tan(t) - y/x] + s_x_1 = Complement(FiniteSet(-sqrt(r/(tan(t)**2 + 1))), FiniteSet(0)) + s_x_2 = Complement(FiniteSet(sqrt(r/(tan(t)**2 + 1))), FiniteSet(0)) + s_y = sqrt(r/(tan(t)**2 + 1))*tan(t) + soln = FiniteSet((s_x_2, s_y)) + FiniteSet((s_x_1, -s_y)) + assert substitution(system, [x, y]) == soln + + n = Dummy('n') + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + s_real_y = -log(3) + s_real_z = sqrt(-exp(2*x) - sin(log(3))) + soln_real = FiniteSet((s_real_y, s_real_z), (s_real_y, -s_real_z)) + lam = Lambda(n, 2*n*I*pi + -log(3)) + s_complex_y = ImageSet(lam, S.Integers) + lam = Lambda(n, sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_1 = ImageSet(lam, S.Integers) + lam = Lambda(n, -sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_2 = ImageSet(lam, S.Integers) + soln_complex = FiniteSet( + (s_complex_y, s_complex_z_1), + (s_complex_y, s_complex_z_2)) + soln = soln_real + soln_complex + assert dumeq(substitution(eqs, [y, z]), soln) + + +def test_raises_substitution(): + raises(ValueError, lambda: substitution([x**2 -1], [])) + raises(TypeError, lambda: substitution([x**2 -1])) + raises(ValueError, lambda: substitution([x**2 -1], [sin(x)])) + raises(TypeError, lambda: substitution([x**2 -1], x)) + raises(TypeError, lambda: substitution([x**2 -1], 1)) + + +def test_issue_21022(): + from sympy.core.sympify import sympify + + eqs = [ + 'k-16', + 'p-8', + 'y*y+z*z-x*x', + 'd - x + p', + 'd*d+k*k-y*y', + 'z*z-p*p-k*k', + 'abc-efg', + ] + efg = Symbol('efg') + eqs = [sympify(x) for x in eqs] + + syb = list(ordered(set.union(*[x.free_symbols for x in eqs]))) + res = nonlinsolve(eqs, syb) + + ans = FiniteSet( + (efg, 32, efg, 16, 8, 40, -16*sqrt(5), -8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, -16*sqrt(5), 8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, 16*sqrt(5), -8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, 16*sqrt(5), 8*sqrt(5)), + ) + assert len(res) == len(ans) == 4 + assert res == ans + for result in res.args: + assert len(result) == 8 + + +def test_issue_17940(): + n = Dummy('n') + k1 = Dummy('k1') + sol = ImageSet(Lambda(((k1, n),), I*(2*k1*pi + arg(2*n*I*pi + log(5))) + + log(Abs(2*n*I*pi + log(5)))), + ProductSet(S.Integers, S.Integers)) + assert solveset(exp(exp(x)) - 5, x).dummy_eq(sol) + + +def test_issue_17906(): + assert solveset(7**(x**2 - 80) - 49**x, x) == FiniteSet(-8, 10) + + +@XFAIL +def test_issue_17933(): + eq1 = x*sin(45) - y*cos(q) + eq2 = x*cos(45) - y*sin(q) + eq3 = 9*x*sin(45)/10 + y*cos(q) + eq4 = 9*x*cos(45)/10 + y*sin(z) - z + assert nonlinsolve([eq1, eq2, eq3, eq4], x, y, z, q) ==\ + FiniteSet((0, 0, 0, q)) + +def test_issue_17933_bis(): + # nonlinsolve's result depends on the 'default_sort_key' ordering of + # the unknowns. + eq1 = x*sin(45) - y*cos(q) + eq2 = x*cos(45) - y*sin(q) + eq3 = 9*x*sin(45)/10 + y*cos(q) + eq4 = 9*x*cos(45)/10 + y*sin(z) - z + zz = Symbol('zz') + eqs = [e.subs(q, zz) for e in (eq1, eq2, eq3, eq4)] + assert nonlinsolve(eqs, x, y, z, zz) == FiniteSet((0, 0, 0, zz)) + + +def test_issue_14565(): + # removed redundancy + assert dumeq(nonlinsolve([k + m, k + m*exp(-2*pi*k)], [k, m]) , + FiniteSet((-n*I, ImageSet(Lambda(n, n*I), S.Integers)))) + + +# end of tests for nonlinsolve + + +def test_issue_9556(): + b = Symbol('b', positive=True) + + assert solveset(Abs(x) + 1, x, S.Reals) is S.EmptySet + assert solveset(Abs(x) + b, x, S.Reals) is S.EmptySet + assert solveset(Eq(b, -1), b, S.Reals) is S.EmptySet + + +def test_issue_9611(): + assert solveset(Eq(x - x + a, a), x, S.Reals) == S.Reals + assert solveset(Eq(y - y + a, a), y) == S.Complexes + + +def test_issue_9557(): + assert solveset(x**2 + a, x, S.Reals) == Intersection(S.Reals, + FiniteSet(-sqrt(-a), sqrt(-a))) + + +def test_issue_9778(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset(x**3 + 1, x, S.Reals) == FiniteSet(-1) + assert solveset(x**Rational(3, 5) + 1, x, S.Reals) == S.EmptySet + assert solveset(x**3 + y, x, S.Reals) == \ + FiniteSet(-Abs(y)**Rational(1, 3)*sign(y)) + + +def test_issue_10214(): + assert solveset(x**Rational(3, 2) + 4, x, S.Reals) == S.EmptySet + assert solveset(x**(Rational(-3, 2)) + 4, x, S.Reals) == S.EmptySet + + ans = FiniteSet(-2**Rational(2, 3)) + assert solveset(x**(S(3)) + 4, x, S.Reals) == ans + assert (x**(S(3)) + 4).subs(x,list(ans)[0]) == 0 # substituting ans and verifying the result. + assert (x**(S(3)) + 4).subs(x,-(-2)**Rational(2, 3)) == 0 + + +def test_issue_9849(): + assert solveset(Abs(sin(x)) + 1, x, S.Reals) == S.EmptySet + + +def test_issue_9953(): + assert linsolve([ ], x) == S.EmptySet + + +def test_issue_9913(): + assert solveset(2*x + 1/(x - 10)**2, x, S.Reals) == \ + FiniteSet(-(3*sqrt(24081)/4 + Rational(4027, 4))**Rational(1, 3)/3 - 100/ + (3*(3*sqrt(24081)/4 + Rational(4027, 4))**Rational(1, 3)) + Rational(20, 3)) + + +def test_issue_10397(): + assert solveset(sqrt(x), x, S.Complexes) == FiniteSet(0) + + +def test_issue_14987(): + raises(ValueError, lambda: linear_eq_to_matrix( + [x**2], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [x*(-3/x + 1) + 2*y - a], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x**2 - 3*x)/(x - 3) - 3], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x + 1)**3 - x**3 - 3*x**2 + 7], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [x*(1/x + 1) + y], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x + 1)*y], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(1/x, 1/x + y)], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(y/x, y/x + y)], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(x*(x + 1), x**2 + y)], [x, y])) + + +def test_simplification(): + eq = x + (a - b)/(-2*a + 2*b) + assert solveset(eq, x) == FiniteSet(S.Half) + assert solveset(eq, x, S.Reals) == Intersection({-((a - b)/(-2*a + 2*b))}, S.Reals) + # So that ap - bn is not zero: + ap = Symbol('ap', positive=True) + bn = Symbol('bn', negative=True) + eq = x + (ap - bn)/(-2*ap + 2*bn) + assert solveset(eq, x) == FiniteSet(S.Half) + assert solveset(eq, x, S.Reals) == FiniteSet(S.Half) + + +def test_integer_domain_relational(): + eq1 = 2*x + 3 > 0 + eq2 = x**2 + 3*x - 2 >= 0 + eq3 = x + 1/x > -2 + 1/x + eq4 = x + sqrt(x**2 - 5) > 0 + eq = x + 1/x > -2 + 1/x + eq5 = eq.subs(x,log(x)) + eq6 = log(x)/x <= 0 + eq7 = log(x)/x < 0 + eq8 = x/(x-3) < 3 + eq9 = x/(x**2-3) < 3 + + assert solveset(eq1, x, S.Integers) == Range(-1, oo, 1) + assert solveset(eq2, x, S.Integers) == Union(Range(-oo, -3, 1), Range(1, oo, 1)) + assert solveset(eq3, x, S.Integers) == Union(Range(-1, 0, 1), Range(1, oo, 1)) + assert solveset(eq4, x, S.Integers) == Range(3, oo, 1) + assert solveset(eq5, x, S.Integers) == Range(2, oo, 1) + assert solveset(eq6, x, S.Integers) == Range(1, 2, 1) + assert solveset(eq7, x, S.Integers) == S.EmptySet + assert solveset(eq8, x, domain=Range(0,5)) == Range(0, 3, 1) + assert solveset(eq9, x, domain=Range(0,5)) == Union(Range(0, 2, 1), Range(2, 5, 1)) + + # test_issue_19794 + assert solveset(x + 2 < 0, x, S.Integers) == Range(-oo, -2, 1) + + +def test_issue_10555(): + f = Function('f') + g = Function('g') + assert solveset(f(x) - pi/2, x, S.Reals).dummy_eq( + ConditionSet(x, Eq(f(x) - pi/2, 0), S.Reals)) + assert solveset(f(g(x)) - pi/2, g(x), S.Reals).dummy_eq( + ConditionSet(g(x), Eq(f(g(x)) - pi/2, 0), S.Reals)) + + +def test_issue_8715(): + eq = x + 1/x > -2 + 1/x + assert solveset(eq, x, S.Reals) == \ + (Interval.open(-2, oo) - FiniteSet(0)) + assert solveset(eq.subs(x,log(x)), x, S.Reals) == \ + Interval.open(exp(-2), oo) - FiniteSet(1) + + +def test_issue_11174(): + eq = z**2 + exp(2*x) - sin(y) + soln = Intersection(S.Reals, FiniteSet(log(-z**2 + sin(y))/2)) + assert solveset(eq, x, S.Reals) == soln + + eq = sqrt(r)*Abs(tan(t))/sqrt(tan(t)**2 + 1) + x*tan(t) + s = -sqrt(r)*Abs(tan(t))/(sqrt(tan(t)**2 + 1)*tan(t)) + soln = Intersection(S.Reals, FiniteSet(s)) + assert solveset(eq, x, S.Reals) == soln + + +def test_issue_11534(): + # eq1 and eq2 should not have the same solutions because squaring both + # sides of the radical equation introduces a spurious solution branch. + # The equations have a symbolic parameter y and it is easy to see that for + # y != 0 the solution s1 will not be valid for eq1. + x = Symbol('x', real=True) + y = Symbol('y', real=True) + eq1 = -y + x/sqrt(-x**2 + 1) + eq2 = -y**2 + x**2/(-x**2 + 1) + + # We get a ConditionSet here because s1 works in eq1 if y is equal to zero + # although not for any other value of y. That case is redundant though + # because if y=0 then s1=s2 so the solution for eq1 could just be returned + # as s2 - {-1, 1}. In fact we have + # |y/sqrt(y**2 + 1)| < 1 + # So the complements are not needed either. The ideal output here would be + # sol1 = s2 + # sol2 = s1 | s2. + s1, s2 = FiniteSet(-y/sqrt(y**2 + 1)), FiniteSet(y/sqrt(y**2 + 1)) + cset = ConditionSet(x, Eq(eq1, 0), s1) + sol1 = (s2 - {-1, 1}) | (cset - {-1, 1}) + sol2 = (s1 | s2) - {-1, 1} + + assert solveset(eq1, x, S.Reals) == sol1 + assert solveset(eq2, x, S.Reals) == sol2 + + +def test_issue_10477(): + assert solveset((x**2 + 4*x - 3)/x < 2, x, S.Reals) == \ + Union(Interval.open(-oo, -3), Interval.open(0, 1)) + + +def test_issue_10671(): + assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi) + i = Interval(1, 10) + assert solveset((1/x).diff(x) < 0, x, i) == i + + +def test_issue_11064(): + eq = x + sqrt(x**2 - 5) + assert solveset(eq > 0, x, S.Reals) == \ + Interval(sqrt(5), oo) + assert solveset(eq < 0, x, S.Reals) == \ + Interval(-oo, -sqrt(5)) + assert solveset(eq > sqrt(5), x, S.Reals) == \ + Interval.Lopen(sqrt(5), oo) + + +def test_issue_12478(): + eq = sqrt(x - 2) + 2 + soln = solveset_real(eq, x) + assert soln is S.EmptySet + assert solveset(eq < 0, x, S.Reals) is S.EmptySet + assert solveset(eq > 0, x, S.Reals) == Interval(2, oo) + + +def test_issue_12429(): + eq = solveset(log(x)/x <= 0, x, S.Reals) + sol = Interval.Lopen(0, 1) + assert eq == sol + + +def test_issue_19506(): + eq = arg(x + I) + C = Dummy('C') + assert solveset(eq).dummy_eq(Intersection(ConditionSet(C, Eq(im(C) + 1, 0), S.Complexes), + ConditionSet(C, re(C) > 0, S.Complexes))) + + +def test_solveset_arg(): + assert solveset(arg(x), x, S.Reals) == Interval.open(0, oo) + assert solveset(arg(4*x -3), x, S.Reals) == Interval.open(Rational(3, 4), oo) + + +def test__is_finite_with_finite_vars(): + f = _is_finite_with_finite_vars + # issue 12482 + assert all(f(1/x) is None for x in ( + Dummy(), Dummy(real=True), Dummy(complex=True))) + assert f(1/Dummy(real=False)) is True # b/c it's finite but not 0 + + +def test_issue_13550(): + assert solveset(x**2 - 2*x - 15, symbol = x, domain = Interval(-oo, 0)) == FiniteSet(-3) + + +def test_issue_13849(): + assert nonlinsolve((t*(sqrt(5) + sqrt(2)) - sqrt(2), t), t) is S.EmptySet + + +def test_issue_14223(): + assert solveset((Abs(x + Min(x, 2)) - 2).rewrite(Piecewise), x, + S.Reals) == FiniteSet(-1, 1) + assert solveset((Abs(x + Min(x, 2)) - 2).rewrite(Piecewise), x, + Interval(0, 2)) == FiniteSet(1) + assert solveset(x, x, FiniteSet(1, 2)) is S.EmptySet + + +def test_issue_10158(): + dom = S.Reals + assert solveset(x*Max(x, 15) - 10, x, dom) == FiniteSet(Rational(2, 3)) + assert solveset(x*Min(x, 15) - 10, x, dom) == FiniteSet(-sqrt(10), sqrt(10)) + assert solveset(Max(Abs(x - 3) - 1, x + 2) - 3, x, dom) == FiniteSet(-1, 1) + assert solveset(Abs(x - 1) - Abs(y), x, dom) == FiniteSet(-Abs(y) + 1, Abs(y) + 1) + assert solveset(Abs(x + 4*Abs(x + 1)), x, dom) == FiniteSet(Rational(-4, 3), Rational(-4, 5)) + assert solveset(2*Abs(x + Abs(x + Max(3, x))) - 2, x, S.Reals) == FiniteSet(-1, -2) + dom = S.Complexes + raises(ValueError, lambda: solveset(x*Max(x, 15) - 10, x, dom)) + raises(ValueError, lambda: solveset(x*Min(x, 15) - 10, x, dom)) + raises(ValueError, lambda: solveset(Max(Abs(x - 3) - 1, x + 2) - 3, x, dom)) + raises(ValueError, lambda: solveset(Abs(x - 1) - Abs(y), x, dom)) + raises(ValueError, lambda: solveset(Abs(x + 4*Abs(x + 1)), x, dom)) + + +def test_issue_14300(): + f = 1 - exp(-18000000*x) - y + a1 = FiniteSet(-log(-y + 1)/18000000) + + assert solveset(f, x, S.Reals) == \ + Intersection(S.Reals, a1) + assert dumeq(solveset(f, x), + ImageSet(Lambda(n, -I*(2*n*pi + arg(-y + 1))/18000000 - + log(Abs(y - 1))/18000000), S.Integers)) + + +def test_issue_14454(): + number = CRootOf(x**4 + x - 1, 2) + raises(ValueError, lambda: invert_real(number, 0, x)) + assert invert_real(x**2, number, x) # no error + + +def test_issue_17882(): + assert solveset(-8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)), x, S.Complexes) == \ + FiniteSet(sqrt(3), -sqrt(3)) + + +def test_term_factors(): + assert list(_term_factors(3**x - 2)) == [-2, 3**x] + expr = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + assert set(_term_factors(expr)) == { + 3**(x + 2), 4**(x + 2), 3**(x + 3), 4**(x - 1), -1, 4**(x + 1)} + + +#################### tests for transolve and its helpers ############### + +def test_transolve(): + + assert _transolve(3**x, x, S.Reals) == S.EmptySet + assert _transolve(3**x - 9**(x + 5), x, S.Reals) == FiniteSet(-10) + + +def test_issue_21276(): + eq = (2*x*(y - z) - y*erf(y - z) - y + z*erf(y - z) + z)**2 + assert solveset(eq.expand(), y) == FiniteSet(z, z + erfinv(2*x - 1)) + + +# exponential tests +def test_exponential_real(): + from sympy.abc import y + + e1 = 3**(2*x) - 2**(x + 3) + e2 = 4**(5 - 9*x) - 8**(2 - x) + e3 = 2**x + 4**x + e4 = exp(log(5)*x) - 2**x + e5 = exp(x/y)*exp(-z/y) - 2 + e6 = 5**(x/2) - 2**(x/3) + e7 = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + e8 = -9*exp(-2*x + 5) + 4*exp(3*x + 1) + e9 = 2**x + 4**x + 8**x - 84 + e10 = 29*2**(x + 1)*615**(x) - 123*2726**(x) + + assert solveset(e1, x, S.Reals) == FiniteSet( + -3*log(2)/(-2*log(3) + log(2))) + assert solveset(e2, x, S.Reals) == FiniteSet(Rational(4, 15)) + assert solveset(e3, x, S.Reals) == S.EmptySet + assert solveset(e4, x, S.Reals) == FiniteSet(0) + assert solveset(e5, x, S.Reals) == Intersection( + S.Reals, FiniteSet(y*log(2*exp(z/y)))) + assert solveset(e6, x, S.Reals) == FiniteSet(0) + assert solveset(e7, x, S.Reals) == FiniteSet(2) + assert solveset(e8, x, S.Reals) == FiniteSet(-2*log(2)/5 + 2*log(3)/5 + Rational(4, 5)) + assert solveset(e9, x, S.Reals) == FiniteSet(2) + assert solveset(e10,x, S.Reals) == FiniteSet((-log(29) - log(2) + log(123))/(-log(2726) + log(2) + log(615))) + + assert solveset_real(-9*exp(-2*x + 5) + 2**(x + 1), x) == FiniteSet( + -((-5 - 2*log(3) + log(2))/(log(2) + 2))) + assert solveset_real(4**(x/2) - 2**(x/3), x) == FiniteSet(0) + b = sqrt(6)*sqrt(log(2))/sqrt(log(5)) + assert solveset_real(5**(x/2) - 2**(3/x), x) == FiniteSet(-b, b) + + # coverage test + C1, C2 = symbols('C1 C2') + f = Function('f') + assert solveset_real(C1 + C2/x**2 - exp(-f(x)), f(x)) == Intersection( + S.Reals, FiniteSet(-log(C1 + C2/x**2))) + y = symbols('y', positive=True) + assert solveset_real(x**2 - y**2/exp(x), y) == Intersection( + S.Reals, FiniteSet(-sqrt(x**2*exp(x)), sqrt(x**2*exp(x)))) + p = Symbol('p', positive=True) + assert solveset_real((1/p + 1)**(p + 1), p).dummy_eq( + ConditionSet(x, Eq((1 + 1/x)**(x + 1), 0), S.Reals)) + assert solveset(2**x - 4**x + 12, x, S.Reals) == {2} + assert solveset(2**x - 2**(2*x) + 12, x, S.Reals) == {2} + + +@XFAIL +def test_exponential_complex(): + n = Dummy('n') + + assert dumeq(solveset_complex(2**x + 4**x, x),imageset( + Lambda(n, I*(2*n*pi + pi)/log(2)), S.Integers)) + assert solveset_complex(x**z*y**z - 2, z) == FiniteSet( + log(2)/(log(x) + log(y))) + assert dumeq(solveset_complex(4**(x/2) - 2**(x/3), x), imageset( + Lambda(n, 3*n*I*pi/log(2)), S.Integers)) + assert dumeq(solveset(2**x + 32, x), imageset( + Lambda(n, (I*(2*n*pi + pi) + 5*log(2))/log(2)), S.Integers)) + + eq = (2**exp(y**2/x) + 2)/(x**2 + 15) + a = sqrt(x)*sqrt(-log(log(2)) + log(log(2) + 2*n*I*pi)) + assert solveset_complex(eq, y) == FiniteSet(-a, a) + + union1 = imageset(Lambda(n, I*(2*n*pi - pi*Rational(2, 3))/log(2)), S.Integers) + union2 = imageset(Lambda(n, I*(2*n*pi + pi*Rational(2, 3))/log(2)), S.Integers) + assert dumeq(solveset(2**x + 4**x + 8**x, x), Union(union1, union2)) + + eq = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + res = solveset(eq, x) + num = 2*n*I*pi - 4*log(2) + 2*log(3) + den = -2*log(2) + log(3) + ans = imageset(Lambda(n, num/den), S.Integers) + assert dumeq(res, ans) + + +def test_expo_conditionset(): + + f1 = (exp(x) + 1)**x - 2 + f2 = (x + 2)**y*x - 3 + f3 = 2**x - exp(x) - 3 + f4 = log(x) - exp(x) + f5 = 2**x + 3**x - 5**x + + assert solveset(f1, x, S.Reals).dummy_eq(ConditionSet( + x, Eq((exp(x) + 1)**x - 2, 0), S.Reals)) + assert solveset(f2, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(x*(x + 2)**y - 3, 0), S.Reals)) + assert solveset(f3, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(2**x - exp(x) - 3, 0), S.Reals)) + assert solveset(f4, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(-exp(x) + log(x), 0), S.Reals)) + assert solveset(f5, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(2**x + 3**x - 5**x, 0), S.Reals)) + + +def test_exponential_symbols(): + x, y, z = symbols('x y z', positive=True) + xr, zr = symbols('xr, zr', real=True) + + assert solveset(z**x - y, x, S.Reals) == Intersection( + S.Reals, FiniteSet(log(y)/log(z))) + + f1 = 2*x**w - 4*y**w + f2 = (x/y)**w - 2 + sol1 = Intersection({log(2)/(log(x) - log(y))}, S.Reals) + sol2 = Intersection({log(2)/log(x/y)}, S.Reals) + assert solveset(f1, w, S.Reals) == sol1, solveset(f1, w, S.Reals) + assert solveset(f2, w, S.Reals) == sol2, solveset(f2, w, S.Reals) + + assert solveset(x**x, x, Interval.Lopen(0,oo)).dummy_eq( + ConditionSet(w, Eq(w**w, 0), Interval.open(0, oo))) + assert solveset(x**y - 1, y, S.Reals) == FiniteSet(0) + assert solveset(exp(x/y)*exp(-z/y) - 2, y, S.Reals) == \ + Complement(ConditionSet(y, Eq(im(x)/y, 0) & Eq(im(z)/y, 0), \ + Complement(Intersection(FiniteSet((x - z)/log(2)), S.Reals), FiniteSet(0))), FiniteSet(0)) + assert solveset(exp(xr/y)*exp(-zr/y) - 2, y, S.Reals) == \ + Complement(FiniteSet((xr - zr)/log(2)), FiniteSet(0)) + + assert solveset(a**x - b**x, x).dummy_eq(ConditionSet( + w, Ne(a, 0) & Ne(b, 0), FiniteSet(0))) + + +def test_ignore_assumptions(): + # make sure assumptions are ignored + xpos = symbols('x', positive=True) + x = symbols('x') + assert solveset_complex(xpos**2 - 4, xpos + ) == solveset_complex(x**2 - 4, x) + + +@XFAIL +def test_issue_10864(): + assert solveset(x**(y*z) - x, x, S.Reals) == FiniteSet(1) + + +@XFAIL +def test_solve_only_exp_2(): + assert solveset_real(sqrt(exp(x)) + sqrt(exp(-x)) - 4, x) == \ + FiniteSet(2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)) + + +def test_is_exponential(): + assert _is_exponential(y, x) is False + assert _is_exponential(3**x - 2, x) is True + assert _is_exponential(5**x - 7**(2 - x), x) is True + assert _is_exponential(sin(2**x) - 4*x, x) is False + assert _is_exponential(x**y - z, y) is True + assert _is_exponential(x**y - z, x) is False + assert _is_exponential(2**x + 4**x - 1, x) is True + assert _is_exponential(x**(y*z) - x, x) is False + assert _is_exponential(x**(2*x) - 3**x, x) is False + assert _is_exponential(x**y - y*z, y) is False + assert _is_exponential(x**y - x*z, y) is True + + +def test_solve_exponential(): + assert _solve_exponential(3**(2*x) - 2**(x + 3), 0, x, S.Reals) == \ + FiniteSet(-3*log(2)/(-2*log(3) + log(2))) + assert _solve_exponential(2**y + 4**y, 1, y, S.Reals) == \ + FiniteSet(log(Rational(-1, 2) + sqrt(5)/2)/log(2)) + assert _solve_exponential(2**y + 4**y, 0, y, S.Reals) == \ + S.EmptySet + assert _solve_exponential(2**x + 3**x - 5**x, 0, x, S.Reals) == \ + ConditionSet(x, Eq(2**x + 3**x - 5**x, 0), S.Reals) + +# end of exponential tests + + +# logarithmic tests +def test_logarithmic(): + assert solveset_real(log(x - 3) + log(x + 3), x) == FiniteSet( + -sqrt(10), sqrt(10)) + assert solveset_real(log(x + 1) - log(2*x - 1), x) == FiniteSet(2) + assert solveset_real(log(x + 3) + log(1 + 3/x) - 3, x) == FiniteSet( + -3 + sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 + exp(3)/2, + -sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 - 3 + exp(3)/2) + + eq = z - log(x) + log(y/(x*(-1 + y**2/x**2))) + assert solveset_real(eq, x) == \ + Intersection(S.Reals, FiniteSet(-sqrt(y**2 - y*exp(z)), + sqrt(y**2 - y*exp(z)))) - \ + Intersection(S.Reals, FiniteSet(-sqrt(y**2), sqrt(y**2))) + assert solveset_real( + log(3*x) - log(-x + 1) - log(4*x + 1), x) == FiniteSet(Rational(-1, 2), S.Half) + assert solveset(log(x**y) - y*log(x), x, S.Reals) == S.Reals + +@XFAIL +def test_uselogcombine_2(): + eq = log(exp(2*x) + 1) + log(-tanh(x) + 1) - log(2) + assert solveset_real(eq, x) is S.EmptySet + eq = log(8*x) - log(sqrt(x) + 1) - 2 + assert solveset_real(eq, x) is S.EmptySet + + +def test_is_logarithmic(): + assert _is_logarithmic(y, x) is False + assert _is_logarithmic(log(x), x) is True + assert _is_logarithmic(log(x) - 3, x) is True + assert _is_logarithmic(log(x)*log(y), x) is True + assert _is_logarithmic(log(x)**2, x) is False + assert _is_logarithmic(log(x - 3) + log(x + 3), x) is True + assert _is_logarithmic(log(x**y) - y*log(x), x) is True + assert _is_logarithmic(sin(log(x)), x) is False + assert _is_logarithmic(x + y, x) is False + assert _is_logarithmic(log(3*x) - log(1 - x) + 4, x) is True + assert _is_logarithmic(log(x) + log(y) + x, x) is False + assert _is_logarithmic(log(log(x - 3)) + log(x - 3), x) is True + assert _is_logarithmic(log(log(3) + x) + log(x), x) is True + assert _is_logarithmic(log(x)*(y + 3) + log(x), y) is False + + +def test_solve_logarithm(): + y = Symbol('y') + assert _solve_logarithm(log(x**y) - y*log(x), 0, x, S.Reals) == S.Reals + y = Symbol('y', positive=True) + assert _solve_logarithm(log(x)*log(y), 0, x, S.Reals) == FiniteSet(1) + +# end of logarithmic tests + + +# lambert tests +def test_is_lambert(): + a, b, c = symbols('a,b,c') + assert _is_lambert(x**2, x) is False + assert _is_lambert(a**x**2+b*x+c, x) is True + assert _is_lambert(E**2, x) is False + assert _is_lambert(x*E**2, x) is False + assert _is_lambert(3*log(x) - x*log(3), x) is True + assert _is_lambert(log(log(x - 3)) + log(x-3), x) is True + assert _is_lambert(5*x - 1 + 3*exp(2 - 7*x), x) is True + assert _is_lambert((a/x + exp(x/2)).diff(x, 2), x) is True + assert _is_lambert((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) is True + assert _is_lambert(x*sinh(x) - 1, x) is True + assert _is_lambert(x*cos(x) - 5, x) is True + assert _is_lambert(tanh(x) - 5*x, x) is True + assert _is_lambert(cosh(x) - sinh(x), x) is False + +# end of lambert tests + + +def test_linear_coeffs(): + from sympy.solvers.solveset import linear_coeffs + assert linear_coeffs(0, x) == [0, 0] + assert all(i is S.Zero for i in linear_coeffs(0, x)) + assert linear_coeffs(x + 2*y + 3, x, y) == [1, 2, 3] + assert linear_coeffs(x + 2*y + 3, y, x) == [2, 1, 3] + assert linear_coeffs(x + 2*x**2 + 3, x, x**2) == [1, 2, 3] + raises(ValueError, lambda: + linear_coeffs(x + 2*x**2 + x**3, x, x**2)) + raises(ValueError, lambda: + linear_coeffs(1/x*(x - 1) + 1/x, x)) + raises(ValueError, lambda: + linear_coeffs(x, x, x)) + assert linear_coeffs(a*(x + y), x, y) == [a, a, 0] + assert linear_coeffs(1.0, x, y) == [0, 0, 1.0] + # don't include coefficients of 0 + assert linear_coeffs(Eq(x, x + y), x, y, dict=True) == {y: -1} + assert linear_coeffs(0, x, y, dict=True) == {} + + +def test_is_modular(): + assert _is_modular(y, x) is False + assert _is_modular(Mod(x, 3) - 1, x) is True + assert _is_modular(Mod(x**3 - 3*x**2 - x + 1, 3) - 1, x) is True + assert _is_modular(Mod(exp(x + y), 3) - 2, x) is True + assert _is_modular(Mod(exp(x + y), 3) - log(x), x) is True + assert _is_modular(Mod(x, 3) - 1, y) is False + assert _is_modular(Mod(x, 3)**2 - 5, x) is False + assert _is_modular(Mod(x, 3)**2 - y, x) is False + assert _is_modular(exp(Mod(x, 3)) - 1, x) is False + assert _is_modular(Mod(3, y) - 1, y) is False + + +def test_invert_modular(): + n = Dummy('n', integer=True) + from sympy.solvers.solveset import _invert_modular as invert_modular + + # no solutions + assert invert_modular(Mod(x, 12), S(1)/2, n, x) == (x, S.EmptySet) + # non invertible cases + assert invert_modular(Mod(sin(x), 7), S(5), n, x) == (Mod(sin(x), 7), 5) + assert invert_modular(Mod(exp(x), 7), S(5), n, x) == (Mod(exp(x), 7), 5) + assert invert_modular(Mod(log(x), 7), S(5), n, x) == (Mod(log(x), 7), 5) + # a is symbol + assert dumeq(invert_modular(Mod(x, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 5), S.Integers))) + # a.is_Add + assert dumeq(invert_modular(Mod(x + 8, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 4), S.Integers))) + assert invert_modular(Mod(x**2 + x, 7), S(5), n, x) == \ + (Mod(x**2 + x, 7), 5) + # a.is_Mul + assert dumeq(invert_modular(Mod(3*x, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 4), S.Integers))) + assert invert_modular(Mod((x + 1)*(x + 2), 7), S(5), n, x) == \ + (Mod((x + 1)*(x + 2), 7), 5) + # a.is_Pow + assert invert_modular(Mod(x**4, 7), S(5), n, x) == \ + (x, S.EmptySet) + assert dumeq(invert_modular(Mod(3**x, 4), S(3), n, x), + (x, ImageSet(Lambda(n, 2*n + 1), S.Naturals0))) + assert dumeq(invert_modular(Mod(2**(x**2 + x + 1), 7), S(2), n, x), + (x**2 + x + 1, ImageSet(Lambda(n, 3*n + 1), S.Naturals0))) + assert invert_modular(Mod(sin(x)**4, 7), S(5), n, x) == (x, S.EmptySet) + + +def test_solve_modular(): + n = Dummy('n', integer=True) + # if rhs has symbol (need to be implemented in future). + assert solveset(Mod(x, 4) - x, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(-x + Mod(x, 4), 0), + S.Integers)) + # when _invert_modular fails to invert + assert solveset(3 - Mod(sin(x), 7), x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(sin(x), 7) - 3, 0), S.Integers)) + assert solveset(3 - Mod(log(x), 7), x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(log(x), 7) - 3, 0), S.Integers)) + assert solveset(3 - Mod(exp(x), 7), x, S.Integers + ).dummy_eq(ConditionSet(x, Eq(Mod(exp(x), 7) - 3, 0), + S.Integers)) + # EmptySet solution definitely + assert solveset(7 - Mod(x, 5), x, S.Integers) is S.EmptySet + assert solveset(5 - Mod(x, 5), x, S.Integers) is S.EmptySet + # Negative m + assert dumeq(solveset(2 + Mod(x, -3), x, S.Integers), + ImageSet(Lambda(n, -3*n - 2), S.Integers)) + assert solveset(4 + Mod(x, -3), x, S.Integers) is S.EmptySet + # linear expression in Mod + assert dumeq(solveset(3 - Mod(x, 5), x, S.Integers), + ImageSet(Lambda(n, 5*n + 3), S.Integers)) + assert dumeq(solveset(3 - Mod(5*x - 8, 7), x, S.Integers), + ImageSet(Lambda(n, 7*n + 5), S.Integers)) + assert dumeq(solveset(3 - Mod(5*x, 7), x, S.Integers), + ImageSet(Lambda(n, 7*n + 2), S.Integers)) + # higher degree expression in Mod + assert dumeq(solveset(Mod(x**2, 160) - 9, x, S.Integers), + Union(ImageSet(Lambda(n, 160*n + 3), S.Integers), + ImageSet(Lambda(n, 160*n + 13), S.Integers), + ImageSet(Lambda(n, 160*n + 67), S.Integers), + ImageSet(Lambda(n, 160*n + 77), S.Integers), + ImageSet(Lambda(n, 160*n + 83), S.Integers), + ImageSet(Lambda(n, 160*n + 93), S.Integers), + ImageSet(Lambda(n, 160*n + 147), S.Integers), + ImageSet(Lambda(n, 160*n + 157), S.Integers))) + assert solveset(3 - Mod(x**4, 7), x, S.Integers) is S.EmptySet + assert dumeq(solveset(Mod(x**4, 17) - 13, x, S.Integers), + Union(ImageSet(Lambda(n, 17*n + 3), S.Integers), + ImageSet(Lambda(n, 17*n + 5), S.Integers), + ImageSet(Lambda(n, 17*n + 12), S.Integers), + ImageSet(Lambda(n, 17*n + 14), S.Integers))) + # a.is_Pow tests + assert dumeq(solveset(Mod(7**x, 41) - 15, x, S.Integers), + ImageSet(Lambda(n, 40*n + 3), S.Naturals0)) + assert dumeq(solveset(Mod(12**x, 21) - 18, x, S.Integers), + ImageSet(Lambda(n, 6*n + 2), S.Naturals0)) + assert dumeq(solveset(Mod(3**x, 4) - 3, x, S.Integers), + ImageSet(Lambda(n, 2*n + 1), S.Naturals0)) + assert dumeq(solveset(Mod(2**x, 7) - 2 , x, S.Integers), + ImageSet(Lambda(n, 3*n + 1), S.Naturals0)) + assert dumeq(solveset(Mod(3**(3**x), 4) - 3, x, S.Integers), + Intersection(ImageSet(Lambda(n, Intersection({log(2*n + 1)/log(3)}, + S.Integers)), S.Naturals0), S.Integers)) + # Implemented for m without primitive root + assert solveset(Mod(x**3, 7) - 2, x, S.Integers) is S.EmptySet + assert dumeq(solveset(Mod(x**3, 8) - 1, x, S.Integers), + ImageSet(Lambda(n, 8*n + 1), S.Integers)) + assert dumeq(solveset(Mod(x**4, 9) - 4, x, S.Integers), + Union(ImageSet(Lambda(n, 9*n + 4), S.Integers), + ImageSet(Lambda(n, 9*n + 5), S.Integers))) + # domain intersection + assert dumeq(solveset(3 - Mod(5*x - 8, 7), x, S.Naturals0), + Intersection(ImageSet(Lambda(n, 7*n + 5), S.Integers), S.Naturals0)) + # Complex args + assert solveset(Mod(x, 3) - I, x, S.Integers) == \ + S.EmptySet + assert solveset(Mod(I*x, 3) - 2, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(I*x, 3) - 2, 0), S.Integers)) + assert solveset(Mod(I + x, 3) - 2, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(x + I, 3) - 2, 0), S.Integers)) + + # issue 17373 (https://github.com/sympy/sympy/issues/17373) + assert dumeq(solveset(Mod(x**4, 14) - 11, x, S.Integers), + Union(ImageSet(Lambda(n, 14*n + 3), S.Integers), + ImageSet(Lambda(n, 14*n + 11), S.Integers))) + assert dumeq(solveset(Mod(x**31, 74) - 43, x, S.Integers), + ImageSet(Lambda(n, 74*n + 31), S.Integers)) + + # issue 13178 + n = symbols('n', integer=True) + a = 742938285 + b = 1898888478 + m = 2**31 - 1 + c = 20170816 + assert dumeq(solveset(c - Mod(a**n*b, m), n, S.Integers), + ImageSet(Lambda(n, 2147483646*n + 100), S.Naturals0)) + assert dumeq(solveset(c - Mod(a**n*b, m), n, S.Naturals0), + Intersection(ImageSet(Lambda(n, 2147483646*n + 100), S.Naturals0), + S.Naturals0)) + assert dumeq(solveset(c - Mod(a**(2*n)*b, m), n, S.Integers), + Intersection(ImageSet(Lambda(n, 1073741823*n + 50), S.Naturals0), + S.Integers)) + assert solveset(c - Mod(a**(2*n + 7)*b, m), n, S.Integers) is S.EmptySet + assert dumeq(solveset(c - Mod(a**(n - 4)*b, m), n, S.Integers), + Intersection(ImageSet(Lambda(n, 2147483646*n + 104), S.Naturals0), + S.Integers)) + +# end of modular tests + +def test_issue_17276(): + assert nonlinsolve([Eq(x, 5**(S(1)/5)), Eq(x*y, 25*sqrt(5))], x, y) == \ + FiniteSet((5**(S(1)/5), 25*5**(S(3)/10))) + + +def test_issue_10426(): + x = Dummy('x') + a = Symbol('a') + n = Dummy('n') + assert (solveset(sin(x + a) - sin(x), a)).dummy_eq(Dummy('x')) == (Union( + ImageSet(Lambda(n, 2*n*pi), S.Integers), + Intersection(S.Complexes, ImageSet(Lambda(n, -I*(I*(2*n*pi + arg(-exp(-2*I*x))) + 2*im(x))), + S.Integers)))).dummy_eq(Dummy('x,n')) + + +def test_solveset_conjugate(): + """Test solveset for simple conjugate functions""" + assert solveset(conjugate(x) -3 + I) == FiniteSet(3 + I) + + +def test_issue_18208(): + variables = symbols('x0:16') + symbols('y0:12') + x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15,\ + y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11 = variables + + eqs = [x0 + x1 + x2 + x3 - 51, + x0 + x1 + x4 + x5 - 46, + x2 + x3 + x6 + x7 - 39, + x0 + x3 + x4 + x7 - 50, + x1 + x2 + x5 + x6 - 35, + x4 + x5 + x6 + x7 - 34, + x4 + x5 + x8 + x9 - 46, + x10 + x11 + x6 + x7 - 23, + x11 + x4 + x7 + x8 - 25, + x10 + x5 + x6 + x9 - 44, + x10 + x11 + x8 + x9 - 35, + x12 + x13 + x8 + x9 - 35, + x10 + x11 + x14 + x15 - 29, + x11 + x12 + x15 + x8 - 35, + x10 + x13 + x14 + x9 - 29, + x12 + x13 + x14 + x15 - 29, + y0 + y1 + y2 + y3 - 55, + y0 + y1 + y4 + y5 - 53, + y2 + y3 + y6 + y7 - 56, + y0 + y3 + y4 + y7 - 57, + y1 + y2 + y5 + y6 - 52, + y4 + y5 + y6 + y7 - 54, + y4 + y5 + y8 + y9 - 48, + y10 + y11 + y6 + y7 - 60, + y11 + y4 + y7 + y8 - 51, + y10 + y5 + y6 + y9 - 57, + y10 + y11 + y8 + y9 - 54, + x10 - 2, + x11 - 5, + x12 - 1, + x13 - 6, + x14 - 1, + x15 - 21, + y0 - 12, + y1 - 20] + + expected = [38 - x3, x3 - 10, 23 - x3, x3, 12 - x7, x7 + 6, 16 - x7, x7, + 8, 20, 2, 5, 1, 6, 1, 21, 12, 20, -y11 + y9 + 2, y11 - y9 + 21, + -y11 - y7 + y9 + 24, y11 + y7 - y9 - 3, 33 - y7, y7, 27 - y9, y9, + 27 - y11, y11] + + A, b = linear_eq_to_matrix(eqs, variables) + + # solve + solve_expected = {v:eq for v, eq in zip(variables, expected) if v != eq} + + assert solve(eqs, variables) == solve_expected + + # linsolve + linsolve_expected = FiniteSet(Tuple(*expected)) + + assert linsolve(eqs, variables) == linsolve_expected + assert linsolve((A, b), variables) == linsolve_expected + + # gauss_jordan_solve + gj_solve, new_vars = A.gauss_jordan_solve(b) + gj_solve = list(gj_solve) + + gj_expected = linsolve_expected.subs(zip([x3, x7, y7, y9, y11], new_vars)) + + assert FiniteSet(Tuple(*gj_solve)) == gj_expected + + # nonlinsolve + # The solution set of nonlinsolve is currently equivalent to linsolve and is + # also correct. However, we would prefer to use the same symbols as parameters + # for the solution to the underdetermined system in all cases if possible. + # We want a solution that is not just equivalent but also given in the same form. + # This test may be changed should nonlinsolve be modified in this way. + + nonlinsolve_expected = FiniteSet((38 - x3, x3 - 10, 23 - x3, x3, 12 - x7, x7 + 6, + 16 - x7, x7, 8, 20, 2, 5, 1, 6, 1, 21, 12, 20, + -y5 + y7 - 1, y5 - y7 + 24, 21 - y5, y5, 33 - y7, + y7, 27 - y9, y9, -y5 + y7 - y9 + 24, y5 - y7 + y9 + 3)) + + assert nonlinsolve(eqs, variables) == nonlinsolve_expected + + +def test_substitution_with_infeasible_solution(): + a00, a01, a10, a11, l0, l1, l2, l3, m0, m1, m2, m3, m4, m5, m6, m7, c00, c01, c10, c11, p00, p01, p10, p11 = symbols( + 'a00, a01, a10, a11, l0, l1, l2, l3, m0, m1, m2, m3, m4, m5, m6, m7, c00, c01, c10, c11, p00, p01, p10, p11' + ) + solvefor = [p00, p01, p10, p11, c00, c01, c10, c11, m0, m1, m3, l0, l1, l2, l3] + system = [ + -l0 * c00 - l1 * c01 + m0 + c00 + c01, + -l0 * c10 - l1 * c11 + m1, + -l2 * c00 - l3 * c01 + c00 + c01, + -l2 * c10 - l3 * c11 + m3, + -l0 * p00 - l2 * p10 + p00 + p10, + -l1 * p00 - l3 * p10 + p00 + p10, + -l0 * p01 - l2 * p11, + -l1 * p01 - l3 * p11, + -a00 + c00 * p00 + c10 * p01, + -a01 + c01 * p00 + c11 * p01, + -a10 + c00 * p10 + c10 * p11, + -a11 + c01 * p10 + c11 * p11, + -m0 * p00, + -m1 * p01, + -m2 * p10, + -m3 * p11, + -m4 * c00, + -m5 * c01, + -m6 * c10, + -m7 * c11, + m2, + m4, + m5, + m6, + m7 + ] + sol = FiniteSet( + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, l2, l3), + (p00, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, 1, 1, -p01/p11, -p01/p11), + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, 1, -l3*p11/p01, -p01/p11, l3), + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, -l2*p11/p01, -l3*p11/p01, l2, l3), + ) + assert sol != nonlinsolve(system, solvefor) + + +def test_issue_20097(): + assert solveset(1/sqrt(x)) is S.EmptySet + + +def test_issue_15350(): + assert solveset(diff(sqrt(1/x+x))) == FiniteSet(-1, 1) + + +def test_issue_18359(): + c1 = Piecewise((0, x < 0), (Min(1, x)/2 - Min(2, x)/2 + Min(3, x)/2, True)) + c2 = Piecewise((Piecewise((0, x < 0), (Min(1, x)/2 - Min(2, x)/2 + Min(3, x)/2, True)), x >= 0), (0, True)) + correct_result = Interval(1, 2) + result1 = solveset(c1 - Rational(1, 2), x, Interval(0, 3)) + result2 = solveset(c2 - Rational(1, 2), x, Interval(0, 3)) + assert result1 == correct_result + assert result2 == correct_result + + +def test_issue_17604(): + lhs = -2**(3*x/11)*exp(x/11) + pi**(x/11) + assert _is_exponential(lhs, x) + assert _solve_exponential(lhs, 0, x, S.Complexes) == FiniteSet(0) + + +def test_issue_17580(): + assert solveset(1/(1 - x**3)**2, x, S.Reals) is S.EmptySet + + +def test_issue_17566_actual(): + sys = [2**x + 2**y - 3, 4**x + 9**y - 5] + # Not clear this is the correct result, but at least no recursion error + assert nonlinsolve(sys, x, y) == FiniteSet((log(3 - 2**y)/log(2), y)) + + +def test_issue_17565(): + eq = Ge(2*(x - 2)**2/(3*(x + 1)**(Integer(1)/3)) + 2*(x - 2)*(x + 1)**(Integer(2)/3), 0) + res = Union(Interval.Lopen(-1, -Rational(1, 4)), Interval(2, oo)) + assert solveset(eq, x, S.Reals) == res + + +def test_issue_15024(): + function = (x + 5)/sqrt(-x**2 - 10*x) + assert solveset(function, x, S.Reals) == FiniteSet(Integer(-5)) + + +def test_issue_16877(): + assert dumeq(nonlinsolve([x - 1, sin(y)], x, y), + FiniteSet((1, ImageSet(Lambda(n, 2*n*pi), S.Integers)), + (1, ImageSet(Lambda(n, 2*n*pi + pi), S.Integers)))) + # Even better if (1, ImageSet(Lambda(n, n*pi), S.Integers)) is obtained + + +def test_issue_16876(): + assert dumeq(nonlinsolve([sin(x), 2*x - 4*y], x, y), + FiniteSet((ImageSet(Lambda(n, 2*n*pi), S.Integers), + ImageSet(Lambda(n, n*pi), S.Integers)), + (ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, n*pi + pi/2), S.Integers)))) + # Even better if (ImageSet(Lambda(n, n*pi), S.Integers), + # ImageSet(Lambda(n, n*pi/2), S.Integers)) is obtained + +def test_issue_21236(): + x, z = symbols("x z") + y = symbols('y', rational=True) + assert solveset(x**y - z, x, S.Reals) == ConditionSet(x, Eq(x**y - z, 0), S.Reals) + e1, e2 = symbols('e1 e2', even=True) + y = e1/e2 # don't know if num or den will be odd and the other even + assert solveset(x**y - z, x, S.Reals) == ConditionSet(x, Eq(x**y - z, 0), S.Reals) + + +def test_issue_21908(): + assert nonlinsolve([(x**2 + 2*x - y**2)*exp(x), -2*y*exp(x)], x, y + ) == {(-2, 0), (0, 0)} + + +def test_issue_19144(): + # test case 1 + expr1 = [x + y - 1, y**2 + 1] + eq1 = [Eq(i, 0) for i in expr1] + soln1 = {(1 - I, I), (1 + I, -I)} + soln_expr1 = nonlinsolve(expr1, [x, y]) + soln_eq1 = nonlinsolve(eq1, [x, y]) + assert soln_eq1 == soln_expr1 == soln1 + # test case 2 - with denoms + expr2 = [x/y - 1, y**2 + 1] + eq2 = [Eq(i, 0) for i in expr2] + soln2 = {(-I, -I), (I, I)} + soln_expr2 = nonlinsolve(expr2, [x, y]) + soln_eq2 = nonlinsolve(eq2, [x, y]) + assert soln_eq2 == soln_expr2 == soln2 + # denominators that cancel in expression + assert nonlinsolve([Eq(x + 1/x, 1/x)], [x]) == FiniteSet((S.EmptySet,)) + + +def test_issue_22413(): + res = nonlinsolve((4*y*(2*x + 2*exp(y) + 1)*exp(2*x), + 4*x*exp(2*x) + 4*y*exp(2*x + y) + 4*exp(2*x + y) + 1), + x, y) + # First solution is not correct, but the issue was an exception + sols = FiniteSet((x, S.Zero), (-exp(y) - S.Half, y)) + assert res == sols + + +def test_issue_23318(): + eqs_eq = [ + Eq(53.5780461486929, x * log(y / (5.0 - y) + 1) / y), + Eq(x, 0.0015 * z), + Eq(0.0015, 7845.32 * y / z), + ] + eqs_expr = [eq.lhs - eq.rhs for eq in eqs_eq] + + sol = {(266.97755814852, 0.0340301680681629, 177985.03876568)} + + assert_close_nl(nonlinsolve(eqs_eq, [x, y, z]), sol) + assert_close_nl(nonlinsolve(eqs_expr, [x, y, z]), sol) + + logterm = log(1.91196789933362e-7*z/(5.0 - 1.91196789933362e-7*z) + 1) + eq = -0.0015*z*logterm + 1.02439504345316e-5*z + assert_close_ss(solveset(eq, z), {0, 177985.038765679}) + + +def test_issue_19814(): + assert nonlinsolve([ 2**m - 2**(2*n), 4*2**m - 2**(4*n)], m, n + ) == FiniteSet((log(2**(2*n))/log(2), S.Complexes)) + + +def test_issue_22058(): + sol = solveset(-sqrt(t)*x**2 + 2*x + sqrt(t), x, S.Reals) + # doesn't fail (and following numerical check) + assert sol.xreplace({t: 1}) == {1 - sqrt(2), 1 + sqrt(2)}, sol.xreplace({t: 1}) + + +def test_issue_11184(): + assert solveset(20*sqrt(y**2 + (sqrt(-(y - 10)*(y + 10)) + 10)**2) - 60, y, S.Reals) is S.EmptySet + + +def test_issue_21890(): + e = S(2)/3 + assert nonlinsolve([4*x**3*y**4 - 2*y, 4*x**4*y**3 - 2*x], x, y) == { + (2**e/(2*y), y), ((-2**e/4 - 2**e*sqrt(3)*I/4)/y, y), + ((-2**e/4 + 2**e*sqrt(3)*I/4)/y, y)} + assert nonlinsolve([(1 - 4*x**2)*exp(-2*x**2 - 2*y**2), + -4*x*y*exp(-2*x**2)*exp(-2*y**2)], x, y) == {(-S(1)/2, 0), (S(1)/2, 0)} + rx, ry = symbols('x y', real=True) + sol = nonlinsolve([4*rx**3*ry**4 - 2*ry, 4*rx**4*ry**3 - 2*rx], rx, ry) + ans = {(2**(S(2)/3)/(2*ry), ry), + ((-2**(S(2)/3)/4 - 2**(S(2)/3)*sqrt(3)*I/4)/ry, ry), + ((-2**(S(2)/3)/4 + 2**(S(2)/3)*sqrt(3)*I/4)/ry, ry)} + assert sol == ans + + +def test_issue_22628(): + assert nonlinsolve([h - 1, k - 1, f - 2, f - 4, -2*k], h, k, f) == S.EmptySet + assert nonlinsolve([x**3 - 1, x + y, x**2 - 4], [x, y]) == S.EmptySet + + +def test_issue_25781(): + assert solve(sqrt(x/2) - x) == [0, S.Half] + + +def test_issue_26077(): + _n = Symbol('_n') + function = x*cot(5*x) + critical_points = stationary_points(function, x, S.Reals) + excluded_points = Union( + ImageSet(Lambda(_n, 2*_n*pi/5), S.Integers), + ImageSet(Lambda(_n, 2*_n*pi/5 + pi/5), S.Integers) + ) + solution = ConditionSet(x, + Eq(x*(-5*cot(5*x)**2 - 5) + cot(5*x), 0), + Complement(S.Reals, excluded_points) + ) + assert solution.as_dummy() == critical_points.as_dummy() diff --git a/phivenv/Lib/site-packages/sympy/stats/__init__.py b/phivenv/Lib/site-packages/sympy/stats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adb79261954924305c1837555d7d47cd53b8430b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/__init__.py @@ -0,0 +1,202 @@ +""" +SymPy statistics module + +Introduces a random variable type into the SymPy language. + +Random variables may be declared using prebuilt functions such as +Normal, Exponential, Coin, Die, etc... or built with functions like FiniteRV. + +Queries on random expressions can be made using the functions + +========================= ============================= + Expression Meaning +------------------------- ----------------------------- + ``P(condition)`` Probability + ``E(expression)`` Expected value + ``H(expression)`` Entropy + ``variance(expression)`` Variance + ``density(expression)`` Probability Density Function + ``sample(expression)`` Produce a realization + ``where(condition)`` Where the condition is true +========================= ============================= + +Examples +======== + +>>> from sympy.stats import P, E, variance, Die, Normal +>>> from sympy import simplify +>>> X, Y = Die('X', 6), Die('Y', 6) # Define two six sided dice +>>> Z = Normal('Z', 0, 1) # Declare a Normal random variable with mean 0, std 1 +>>> P(X>3) # Probability X is greater than 3 +1/2 +>>> E(X+Y) # Expectation of the sum of two dice +7 +>>> variance(X+Y) # Variance of the sum of two dice +35/6 +>>> simplify(P(Z>1)) # Probability of Z being greater than 1 +1/2 - erf(sqrt(2)/2)/2 + + +One could also create custom distribution and define custom random variables +as follows: + +1. If you want to create a Continuous Random Variable: + +>>> from sympy.stats import ContinuousRV, P, E +>>> from sympy import exp, Symbol, Interval, oo +>>> x = Symbol('x') +>>> pdf = exp(-x) # pdf of the Continuous Distribution +>>> Z = ContinuousRV(x, pdf, set=Interval(0, oo)) +>>> E(Z) +1 +>>> P(Z > 5) +exp(-5) + +1.1 To create an instance of Continuous Distribution: + +>>> from sympy.stats import ContinuousDistributionHandmade +>>> from sympy import Lambda +>>> dist = ContinuousDistributionHandmade(Lambda(x, pdf), set=Interval(0, oo)) +>>> dist.pdf(x) +exp(-x) + +2. If you want to create a Discrete Random Variable: + +>>> from sympy.stats import DiscreteRV, P, E +>>> from sympy import Symbol, S +>>> p = S(1)/2 +>>> x = Symbol('x', integer=True, positive=True) +>>> pdf = p*(1 - p)**(x - 1) +>>> D = DiscreteRV(x, pdf, set=S.Naturals) +>>> E(D) +2 +>>> P(D > 3) +1/8 + +2.1 To create an instance of Discrete Distribution: + +>>> from sympy.stats import DiscreteDistributionHandmade +>>> from sympy import Lambda +>>> dist = DiscreteDistributionHandmade(Lambda(x, pdf), set=S.Naturals) +>>> dist.pdf(x) +2**(1 - x)/2 + +3. If you want to create a Finite Random Variable: + +>>> from sympy.stats import FiniteRV, P, E +>>> from sympy import Rational, Eq +>>> pmf = {1: Rational(1, 3), 2: Rational(1, 6), 3: Rational(1, 4), 4: Rational(1, 4)} +>>> X = FiniteRV('X', pmf) +>>> E(X) +29/12 +>>> P(X > 3) +1/4 + +3.1 To create an instance of Finite Distribution: + +>>> from sympy.stats import FiniteDistributionHandmade +>>> dist = FiniteDistributionHandmade(pmf) +>>> dist.pmf(x) +Lambda(x, Piecewise((1/3, Eq(x, 1)), (1/6, Eq(x, 2)), (1/4, Eq(x, 3) | Eq(x, 4)), (0, True))) +""" + +__all__ = [ + 'P', 'E', 'H', 'density', 'where', 'given', 'sample', 'cdf','median', + 'characteristic_function', 'pspace', 'sample_iter', 'variance', 'std', + 'skewness', 'kurtosis', 'covariance', 'dependent', 'entropy', 'independent', + 'random_symbols', 'correlation', 'factorial_moment', 'moment', 'cmoment', + 'sampling_density', 'moment_generating_function', 'smoment', 'quantile', + 'coskewness', 'sample_stochastic_process', + + 'FiniteRV', 'DiscreteUniform', 'Die', 'Bernoulli', 'Coin', 'Binomial', + 'BetaBinomial', 'Hypergeometric', 'Rademacher', 'IdealSoliton', 'RobustSoliton', + 'FiniteDistributionHandmade', + + 'ContinuousRV', 'Arcsin', 'Benini', 'Beta', 'BetaNoncentral', 'BetaPrime', + 'BoundedPareto', 'Cauchy', 'Chi', 'ChiNoncentral', 'ChiSquared', 'Dagum', 'Davis', 'Erlang', + 'ExGaussian', 'Exponential', 'ExponentialPower', 'FDistribution', + 'FisherZ', 'Frechet', 'Gamma', 'GammaInverse', 'Gompertz', 'Gumbel', + 'Kumaraswamy', 'Laplace', 'Levy', 'Logistic','LogCauchy', 'LogLogistic', 'LogitNormal', 'LogNormal', 'Lomax', + 'Moyal', 'Maxwell', 'Nakagami', 'Normal', 'GaussianInverse', 'Pareto', 'PowerFunction', + 'QuadraticU', 'RaisedCosine', 'Rayleigh','Reciprocal', 'StudentT', 'ShiftedGompertz', + 'Trapezoidal', 'Triangular', 'Uniform', 'UniformSum', 'VonMises', 'Wald', + 'Weibull', 'WignerSemicircle', 'ContinuousDistributionHandmade', + + 'FlorySchulz', 'Geometric','Hermite', 'Logarithmic', 'NegativeBinomial', 'Poisson', 'Skellam', + 'YuleSimon', 'Zeta', 'DiscreteRV', 'DiscreteDistributionHandmade', + + 'JointRV', 'Dirichlet', 'GeneralizedMultivariateLogGamma', + 'GeneralizedMultivariateLogGammaOmega', 'Multinomial', 'MultivariateBeta', + 'MultivariateEwens', 'MultivariateT', 'NegativeMultinomial', + 'NormalGamma', 'MultivariateNormal', 'MultivariateLaplace', 'marginal_distribution', + + 'StochasticProcess', 'DiscreteTimeStochasticProcess', + 'DiscreteMarkovChain', 'TransitionMatrixOf', 'StochasticStateSpaceOf', + 'GeneratorMatrixOf', 'ContinuousMarkovChain', 'BernoulliProcess', + 'PoissonProcess', 'WienerProcess', 'GammaProcess', + + 'CircularEnsemble', 'CircularUnitaryEnsemble', + 'CircularOrthogonalEnsemble', 'CircularSymplecticEnsemble', + 'GaussianEnsemble', 'GaussianUnitaryEnsemble', + 'GaussianOrthogonalEnsemble', 'GaussianSymplecticEnsemble', + 'joint_eigen_distribution', 'JointEigenDistribution', + 'level_spacing_distribution', + + 'MatrixGamma', 'Wishart', 'MatrixNormal', 'MatrixStudentT', + + 'Probability', 'Expectation', 'Variance', 'Covariance', 'Moment', + 'CentralMoment', + + 'ExpectationMatrix', 'VarianceMatrix', 'CrossCovarianceMatrix' + +] +from .rv_interface import (P, E, H, density, where, given, sample, cdf, median, + characteristic_function, pspace, sample_iter, variance, std, skewness, + kurtosis, covariance, dependent, entropy, independent, random_symbols, + correlation, factorial_moment, moment, cmoment, sampling_density, + moment_generating_function, smoment, quantile, coskewness, + sample_stochastic_process) + +from .frv_types import (FiniteRV, DiscreteUniform, Die, Bernoulli, Coin, + Binomial, BetaBinomial, Hypergeometric, Rademacher, + FiniteDistributionHandmade, IdealSoliton, RobustSoliton) + +from .crv_types import (ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, + BetaPrime, BoundedPareto, Cauchy, Chi, ChiNoncentral, ChiSquared, + Dagum, Davis, Erlang, ExGaussian, Exponential, ExponentialPower, + FDistribution, FisherZ, Frechet, Gamma, GammaInverse, GaussianInverse, + Gompertz, Gumbel, Kumaraswamy, Laplace, Levy, Logistic, LogCauchy, + LogLogistic, LogitNormal, LogNormal, Lomax, Maxwell, Moyal, Nakagami, + Normal, Pareto, QuadraticU, RaisedCosine, Rayleigh, Reciprocal, + StudentT, PowerFunction, ShiftedGompertz, Trapezoidal, Triangular, + Uniform, UniformSum, VonMises, Wald, Weibull, WignerSemicircle, + ContinuousDistributionHandmade) + +from .drv_types import (FlorySchulz, Geometric, Hermite, Logarithmic, NegativeBinomial, Poisson, + Skellam, YuleSimon, Zeta, DiscreteRV, DiscreteDistributionHandmade) + +from .joint_rv_types import (JointRV, Dirichlet, + GeneralizedMultivariateLogGamma, GeneralizedMultivariateLogGammaOmega, + Multinomial, MultivariateBeta, MultivariateEwens, MultivariateT, + NegativeMultinomial, NormalGamma, MultivariateNormal, MultivariateLaplace, + marginal_distribution) + +from .stochastic_process_types import (StochasticProcess, + DiscreteTimeStochasticProcess, DiscreteMarkovChain, + TransitionMatrixOf, StochasticStateSpaceOf, GeneratorMatrixOf, + ContinuousMarkovChain, BernoulliProcess, PoissonProcess, WienerProcess, + GammaProcess) + +from .random_matrix_models import (CircularEnsemble, CircularUnitaryEnsemble, + CircularOrthogonalEnsemble, CircularSymplecticEnsemble, + GaussianEnsemble, GaussianUnitaryEnsemble, GaussianOrthogonalEnsemble, + GaussianSymplecticEnsemble, joint_eigen_distribution, + JointEigenDistribution, level_spacing_distribution) + +from .matrix_distributions import MatrixGamma, Wishart, MatrixNormal, MatrixStudentT + +from .symbolic_probability import (Probability, Expectation, Variance, + Covariance, Moment, CentralMoment) + +from .symbolic_multivariate_probability import (ExpectationMatrix, VarianceMatrix, + CrossCovarianceMatrix) diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbc8ffd63445d510bca0277b7cfc3e4c4f1aa6ba Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/compound_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/compound_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a76fdfc66fcfe8eaa3b1587fc6380b931766c0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/compound_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/crv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/crv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d4cb4e0109125bfe759b69223f9b0d8f08f07a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/crv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/drv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/drv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c727067afc71aa6cb873901f3c1fb283dce2bf6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/drv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/drv_types.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/drv_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..965552a4ee4e689904c82360b1d1fd6a90872b8e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/drv_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/error_prop.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/error_prop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8730478fb33e02695ac6921bb30a9aeff6b8fe82 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/error_prop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/frv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/frv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f176115028883cf4af9a4b3895673c1444ab1b00 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/frv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/frv_types.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/frv_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe7ce8ee102ccbd5daaa89485a395a30ed50d34 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/frv_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/joint_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/joint_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..284446fbce0c62331cc905fa56f67af8f18dafa5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/joint_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/joint_rv_types.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/joint_rv_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b76fcb315a0be8f279f6713802071f80abf1b14 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/joint_rv_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/matrix_distributions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/matrix_distributions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45dd07fa01ba8e9cd4edc20e8bc17b2c900b9432 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/matrix_distributions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/random_matrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/random_matrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a7c78963c91257e7de12ffe3df1a673ad153b8d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/random_matrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/random_matrix_models.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/random_matrix_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24d644900ba56d130e51aa9d0901f5913fd8a991 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/random_matrix_models.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf04676c463a449d9cc399dcaa1bacc2b53521b7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/rv_interface.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/rv_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a0d949791e3ffc71d248955b3f196c247ec1e0f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/rv_interface.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/stochastic_process.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/stochastic_process.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27030fe46fed09e8ebb53fc57e39188fb399750e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/stochastic_process.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/stochastic_process_types.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/stochastic_process_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e04a1167ca0a28eaac5dd50039c915401b41a836 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/stochastic_process_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f764580e3da6e3ee7fd8a2a70ee469acd4d9be Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/__pycache__/symbolic_probability.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/__pycache__/symbolic_probability.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76ce90ac50cc7887988e5659ce237a0e59eeedc2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/__pycache__/symbolic_probability.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/compound_rv.py b/phivenv/Lib/site-packages/sympy/stats/compound_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..27555f4233fe691bac303800a87736205acbdee6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/compound_rv.py @@ -0,0 +1,223 @@ +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.symbol import Dummy +from sympy.integrals.integrals import Integral +from sympy.stats.rv import (NamedArgsMixin, random_symbols, _symbol_converter, + PSpace, RandomSymbol, is_random, Distribution) +from sympy.stats.crv import ContinuousDistribution, SingleContinuousPSpace +from sympy.stats.drv import DiscreteDistribution, SingleDiscretePSpace +from sympy.stats.frv import SingleFiniteDistribution, SingleFinitePSpace +from sympy.stats.crv_types import ContinuousDistributionHandmade +from sympy.stats.drv_types import DiscreteDistributionHandmade +from sympy.stats.frv_types import FiniteDistributionHandmade + + +class CompoundPSpace(PSpace): + """ + A temporary Probability Space for the Compound Distribution. After + Marginalization, this returns the corresponding Probability Space of the + parent distribution. + """ + + def __new__(cls, s, distribution): + s = _symbol_converter(s) + if isinstance(distribution, ContinuousDistribution): + return SingleContinuousPSpace(s, distribution) + if isinstance(distribution, DiscreteDistribution): + return SingleDiscretePSpace(s, distribution) + if isinstance(distribution, SingleFiniteDistribution): + return SingleFinitePSpace(s, distribution) + if not isinstance(distribution, CompoundDistribution): + raise ValueError("%s should be an isinstance of " + "CompoundDistribution"%(distribution)) + return Basic.__new__(cls, s, distribution) + + @property + def value(self): + return RandomSymbol(self.symbol, self) + + @property + def symbol(self): + return self.args[0] + + @property + def is_Continuous(self): + return self.distribution.is_Continuous + + @property + def is_Finite(self): + return self.distribution.is_Finite + + @property + def is_Discrete(self): + return self.distribution.is_Discrete + + @property + def distribution(self): + return self.args[1] + + @property + def pdf(self): + return self.distribution.pdf(self.symbol) + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return self._get_newpspace().domain + + def _get_newpspace(self, evaluate=False): + x = Dummy('x') + parent_dist = self.distribution.args[0] + func = Lambda(x, self.distribution.pdf(x, evaluate)) + new_pspace = self._transform_pspace(self.symbol, parent_dist, func) + if new_pspace is not None: + return new_pspace + message = ("Compound Distribution for %s is not implemented yet" % str(parent_dist)) + raise NotImplementedError(message) + + def _transform_pspace(self, sym, dist, pdf): + """ + This function returns the new pspace of the distribution using handmade + Distributions and their corresponding pspace. + """ + pdf = Lambda(sym, pdf(sym)) + _set = dist.set + if isinstance(dist, ContinuousDistribution): + return SingleContinuousPSpace(sym, ContinuousDistributionHandmade(pdf, _set)) + elif isinstance(dist, DiscreteDistribution): + return SingleDiscretePSpace(sym, DiscreteDistributionHandmade(pdf, _set)) + elif isinstance(dist, SingleFiniteDistribution): + dens = {k: pdf(k) for k in _set} + return SingleFinitePSpace(sym, FiniteDistributionHandmade(dens)) + + def compute_density(self, expr, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + expr = expr.subs({self.value: new_pspace.value}) + return new_pspace.compute_density(expr, **kwargs) + + def compute_cdf(self, expr, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + expr = expr.subs({self.value: new_pspace.value}) + return new_pspace.compute_cdf(expr, **kwargs) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + new_pspace = self._get_newpspace(evaluate) + expr = expr.subs({self.value: new_pspace.value}) + if rvs: + rvs = rvs.subs({self.value: new_pspace.value}) + if isinstance(new_pspace, SingleFinitePSpace): + return new_pspace.compute_expectation(expr, rvs, **kwargs) + return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs) + + def probability(self, condition, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + condition = condition.subs({self.value: new_pspace.value}) + return new_pspace.probability(condition) + + def conditional_space(self, condition, *, compound_evaluate=True, **kwargs): + new_pspace = self._get_newpspace(compound_evaluate) + condition = condition.subs({self.value: new_pspace.value}) + return new_pspace.conditional_space(condition) + + +class CompoundDistribution(Distribution, NamedArgsMixin): + """ + Class for Compound Distributions. + + Parameters + ========== + + dist : Distribution + Distribution must contain a random parameter + + Examples + ======== + + >>> from sympy.stats.compound_rv import CompoundDistribution + >>> from sympy.stats.crv_types import NormalDistribution + >>> from sympy.stats import Normal + >>> from sympy.abc import x + >>> X = Normal('X', 2, 4) + >>> N = NormalDistribution(X, 4) + >>> C = CompoundDistribution(N) + >>> C.set + Interval(-oo, oo) + >>> C.pdf(x, evaluate=True).simplify() + exp(-x**2/64 + x/16 - 1/16)/(8*sqrt(pi)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Compound_probability_distribution + + """ + + def __new__(cls, dist): + if not isinstance(dist, (ContinuousDistribution, + SingleFiniteDistribution, DiscreteDistribution)): + message = "Compound Distribution for %s is not implemented yet" % str(dist) + raise NotImplementedError(message) + if not cls._compound_check(dist): + return dist + return Basic.__new__(cls, dist) + + @property + def set(self): + return self.args[0].set + + @property + def is_Continuous(self): + return isinstance(self.args[0], ContinuousDistribution) + + @property + def is_Finite(self): + return isinstance(self.args[0], SingleFiniteDistribution) + + @property + def is_Discrete(self): + return isinstance(self.args[0], DiscreteDistribution) + + def pdf(self, x, evaluate=False): + dist = self.args[0] + randoms = [rv for rv in dist.args if is_random(rv)] + if isinstance(dist, SingleFiniteDistribution): + y = Dummy('y', integer=True, negative=False) + expr = dist.pmf(y) + else: + y = Dummy('y') + expr = dist.pdf(y) + for rv in randoms: + expr = self._marginalise(expr, rv, evaluate) + return Lambda(y, expr)(x) + + def _marginalise(self, expr, rv, evaluate): + if isinstance(rv.pspace.distribution, SingleFiniteDistribution): + rv_dens = rv.pspace.distribution.pmf(rv) + else: + rv_dens = rv.pspace.distribution.pdf(rv) + rv_dom = rv.pspace.domain.set + if rv.pspace.is_Discrete or rv.pspace.is_Finite: + expr = Sum(expr*rv_dens, (rv, rv_dom._inf, + rv_dom._sup)) + else: + expr = Integral(expr*rv_dens, (rv, rv_dom._inf, + rv_dom._sup)) + if evaluate: + return expr.doit() + return expr + + @classmethod + def _compound_check(self, dist): + """ + Checks if the given distribution contains random parameters. + """ + randoms = [] + for arg in dist.args: + randoms.extend(random_symbols(arg)) + if len(randoms) == 0: + return False + return True diff --git a/phivenv/Lib/site-packages/sympy/stats/crv.py b/phivenv/Lib/site-packages/sympy/stats/crv.py new file mode 100644 index 0000000000000000000000000000000000000000..0a5184029679f663c83d81aa6c1b6ca4d948c70f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/crv.py @@ -0,0 +1,570 @@ +""" +Continuous Random Variables Module + +See Also +======== +sympy.stats.crv_types +sympy.stats.rv +sympy.stats.frv +""" + + +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.function import Lambda, PoleError +from sympy.core.numbers import (I, nan, oo) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import _sympify, sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import DiracDelta +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, Or) +from sympy.polys.polyerrors import PolynomialError +from sympy.polys.polytools import poly +from sympy.series.series import series +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Union) +from sympy.solvers.solveset import solveset +from sympy.solvers.inequalities import reduce_rational_inequalities +from sympy.stats.rv import (RandomDomain, SingleDomain, ConditionalDomain, is_random, + ProductDomain, PSpace, SinglePSpace, random_symbols, NamedArgsMixin, Distribution) + + +class ContinuousDomain(RandomDomain): + """ + A domain with continuous support + + Represented using symbols and Intervals. + """ + is_Continuous = True + + def as_boolean(self): + raise NotImplementedError("Not Implemented for generic Domains") + + +class SingleContinuousDomain(ContinuousDomain, SingleDomain): + """ + A univariate domain with continuous support + + Represented using a single symbol and interval. + """ + def compute_expectation(self, expr, variables=None, **kwargs): + if variables is None: + variables = self.symbols + if not variables: + return expr + if frozenset(variables) != frozenset(self.symbols): + raise ValueError("Values should be equal") + # assumes only intervals + return Integral(expr, (self.symbol, self.set), **kwargs) + + def as_boolean(self): + return self.set.as_relational(self.symbol) + + +class ProductContinuousDomain(ProductDomain, ContinuousDomain): + """ + A collection of independent domains with continuous support + """ + + def compute_expectation(self, expr, variables=None, **kwargs): + if variables is None: + variables = self.symbols + for domain in self.domains: + domain_vars = frozenset(variables) & frozenset(domain.symbols) + if domain_vars: + expr = domain.compute_expectation(expr, domain_vars, **kwargs) + return expr + + def as_boolean(self): + return And(*[domain.as_boolean() for domain in self.domains]) + + +class ConditionalContinuousDomain(ContinuousDomain, ConditionalDomain): + """ + A domain with continuous support that has been further restricted by a + condition such as $x > 3$. + """ + + def compute_expectation(self, expr, variables=None, **kwargs): + if variables is None: + variables = self.symbols + if not variables: + return expr + # Extract the full integral + fullintgrl = self.fulldomain.compute_expectation(expr, variables) + # separate into integrand and limits + integrand, limits = fullintgrl.function, list(fullintgrl.limits) + + conditions = [self.condition] + while conditions: + cond = conditions.pop() + if cond.is_Boolean: + if isinstance(cond, And): + conditions.extend(cond.args) + elif isinstance(cond, Or): + raise NotImplementedError("Or not implemented here") + elif cond.is_Relational: + if cond.is_Equality: + # Add the appropriate Delta to the integrand + integrand *= DiracDelta(cond.lhs - cond.rhs) + else: + symbols = cond.free_symbols & set(self.symbols) + if len(symbols) != 1: # Can't handle x > y + raise NotImplementedError( + "Multivariate Inequalities not yet implemented") + # Can handle x > 0 + symbol = symbols.pop() + # Find the limit with x, such as (x, -oo, oo) + for i, limit in enumerate(limits): + if limit[0] == symbol: + # Make condition into an Interval like [0, oo] + cintvl = reduce_rational_inequalities_wrap( + cond, symbol) + # Make limit into an Interval like [-oo, oo] + lintvl = Interval(limit[1], limit[2]) + # Intersect them to get [0, oo] + intvl = cintvl.intersect(lintvl) + # Put back into limits list + limits[i] = (symbol, intvl.left, intvl.right) + else: + raise TypeError( + "Condition %s is not a relational or Boolean" % cond) + + return Integral(integrand, *limits, **kwargs) + + def as_boolean(self): + return And(self.fulldomain.as_boolean(), self.condition) + + @property + def set(self): + if len(self.symbols) == 1: + return (self.fulldomain.set & reduce_rational_inequalities_wrap( + self.condition, tuple(self.symbols)[0])) + else: + raise NotImplementedError( + "Set of Conditional Domain not Implemented") + + +class ContinuousDistribution(Distribution): + def __call__(self, *args): + return self.pdf(*args) + + +class SingleContinuousDistribution(ContinuousDistribution, NamedArgsMixin): + """ Continuous distribution of a single variable. + + Explanation + =========== + + Serves as superclass for Normal/Exponential/UniformDistribution etc.... + + Represented by parameters for each of the specific classes. E.g + NormalDistribution is represented by a mean and standard deviation. + + Provides methods for pdf, cdf, and sampling. + + See Also + ======== + + sympy.stats.crv_types.* + """ + + set = Interval(-oo, oo) + + def __new__(cls, *args): + args = list(map(sympify, args)) + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + @cacheit + def compute_cdf(self, **kwargs): + """ Compute the CDF from the PDF. + + Returns a Lambda. + """ + x, z = symbols('x, z', real=True, cls=Dummy) + left_bound = self.set.start + + # CDF is integral of PDF from left bound to z + pdf = self.pdf(x) + cdf = integrate(pdf.doit(), (x, left_bound, z), **kwargs) + # CDF Ensure that CDF left of left_bound is zero + cdf = Piecewise((cdf, z >= left_bound), (0, True)) + return Lambda(z, cdf) + + def _cdf(self, x): + return None + + def cdf(self, x, **kwargs): + """ Cumulative density function """ + if len(kwargs) == 0: + cdf = self._cdf(x) + if cdf is not None: + return cdf + return self.compute_cdf(**kwargs)(x) + + @cacheit + def compute_characteristic_function(self, **kwargs): + """ Compute the characteristic function from the PDF. + + Returns a Lambda. + """ + x, t = symbols('x, t', real=True, cls=Dummy) + pdf = self.pdf(x) + cf = integrate(exp(I*t*x)*pdf, (x, self.set)) + return Lambda(t, cf) + + def _characteristic_function(self, t): + return None + + def characteristic_function(self, t, **kwargs): + """ Characteristic function """ + if len(kwargs) == 0: + cf = self._characteristic_function(t) + if cf is not None: + return cf + return self.compute_characteristic_function(**kwargs)(t) + + @cacheit + def compute_moment_generating_function(self, **kwargs): + """ Compute the moment generating function from the PDF. + + Returns a Lambda. + """ + x, t = symbols('x, t', real=True, cls=Dummy) + pdf = self.pdf(x) + mgf = integrate(exp(t * x) * pdf, (x, self.set)) + return Lambda(t, mgf) + + def _moment_generating_function(self, t): + return None + + def moment_generating_function(self, t, **kwargs): + """ Moment generating function """ + if not kwargs: + mgf = self._moment_generating_function(t) + if mgf is not None: + return mgf + return self.compute_moment_generating_function(**kwargs)(t) + + def expectation(self, expr, var, evaluate=True, **kwargs): + """ Expectation of expression over distribution """ + if evaluate: + try: + p = poly(expr, var) + if p.is_zero: + return S.Zero + t = Dummy('t', real=True) + mgf = self._moment_generating_function(t) + if mgf is None: + return integrate(expr * self.pdf(var), (var, self.set), **kwargs) + deg = p.degree() + taylor = poly(series(mgf, t, 0, deg + 1).removeO(), t) + result = 0 + for k in range(deg+1): + result += p.coeff_monomial(var ** k) * taylor.coeff_monomial(t ** k) * factorial(k) + return result + except PolynomialError: + return integrate(expr * self.pdf(var), (var, self.set), **kwargs) + else: + return Integral(expr * self.pdf(var), (var, self.set), **kwargs) + + @cacheit + def compute_quantile(self, **kwargs): + """ Compute the Quantile from the PDF. + + Returns a Lambda. + """ + x, p = symbols('x, p', real=True, cls=Dummy) + left_bound = self.set.start + + pdf = self.pdf(x) + cdf = integrate(pdf, (x, left_bound, x), **kwargs) + quantile = solveset(cdf - p, x, self.set) + return Lambda(p, Piecewise((quantile, (p >= 0) & (p <= 1) ), (nan, True))) + + def _quantile(self, x): + return None + + def quantile(self, x, **kwargs): + """ Cumulative density function """ + if len(kwargs) == 0: + quantile = self._quantile(x) + if quantile is not None: + return quantile + return self.compute_quantile(**kwargs)(x) + + +class ContinuousPSpace(PSpace): + """ Continuous Probability Space + + Represents the likelihood of an event space defined over a continuum. + + Represented with a ContinuousDomain and a PDF (Lambda-Like) + """ + + is_Continuous = True + is_real = True + + @property + def pdf(self): + return self.density(*self.domain.symbols) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + if rvs is None: + rvs = self.values + else: + rvs = frozenset(rvs) + + expr = expr.xreplace({rv: rv.symbol for rv in rvs}) + + domain_symbols = frozenset(rv.symbol for rv in rvs) + + return self.domain.compute_expectation(self.pdf * expr, + domain_symbols, **kwargs) + + def compute_density(self, expr, **kwargs): + # Common case Density(X) where X in self.values + if expr in self.values: + # Marginalize all other random symbols out of the density + randomsymbols = tuple(set(self.values) - frozenset([expr])) + symbols = tuple(rs.symbol for rs in randomsymbols) + pdf = self.domain.compute_expectation(self.pdf, symbols, **kwargs) + return Lambda(expr.symbol, pdf) + + z = Dummy('z', real=True) + return Lambda(z, self.compute_expectation(DiracDelta(expr - z), **kwargs)) + + @cacheit + def compute_cdf(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise ValueError( + "CDF not well defined on multivariate expressions") + + d = self.compute_density(expr, **kwargs) + x, z = symbols('x, z', real=True, cls=Dummy) + left_bound = self.domain.set.start + + # CDF is integral of PDF from left bound to z + cdf = integrate(d(x), (x, left_bound, z), **kwargs) + # CDF Ensure that CDF left of left_bound is zero + cdf = Piecewise((cdf, z >= left_bound), (0, True)) + return Lambda(z, cdf) + + @cacheit + def compute_characteristic_function(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise NotImplementedError("Characteristic function of multivariate expressions not implemented") + + d = self.compute_density(expr, **kwargs) + x, t = symbols('x, t', real=True, cls=Dummy) + cf = integrate(exp(I*t*x)*d(x), (x, -oo, oo), **kwargs) + return Lambda(t, cf) + + @cacheit + def compute_moment_generating_function(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise NotImplementedError("Moment generating function of multivariate expressions not implemented") + + d = self.compute_density(expr, **kwargs) + x, t = symbols('x, t', real=True, cls=Dummy) + mgf = integrate(exp(t * x) * d(x), (x, -oo, oo), **kwargs) + return Lambda(t, mgf) + + @cacheit + def compute_quantile(self, expr, **kwargs): + if not self.domain.set.is_Interval: + raise ValueError( + "Quantile not well defined on multivariate expressions") + + d = self.compute_cdf(expr, **kwargs) + x = Dummy('x', real=True) + p = Dummy('p', positive=True) + + quantile = solveset(d(x) - p, x, self.set) + + return Lambda(p, quantile) + + def probability(self, condition, **kwargs): + z = Dummy('z', real=True) + cond_inv = False + if isinstance(condition, Ne): + condition = Eq(condition.args[0], condition.args[1]) + cond_inv = True + # Univariate case can be handled by where + try: + domain = self.where(condition) + rv = [rv for rv in self.values if rv.symbol == domain.symbol][0] + # Integrate out all other random variables + pdf = self.compute_density(rv, **kwargs) + # return S.Zero if `domain` is empty set + if domain.set is S.EmptySet or isinstance(domain.set, FiniteSet): + return S.Zero if not cond_inv else S.One + if isinstance(domain.set, Union): + return sum( + Integral(pdf(z), (z, subset), **kwargs) for subset in + domain.set.args if isinstance(subset, Interval)) + # Integrate out the last variable over the special domain + return Integral(pdf(z), (z, domain.set), **kwargs) + + # Other cases can be turned into univariate case + # by computing a density handled by density computation + except NotImplementedError: + from sympy.stats.rv import density + expr = condition.lhs - condition.rhs + if not is_random(expr): + dens = self.density + comp = condition.rhs + else: + dens = density(expr, **kwargs) + comp = 0 + if not isinstance(dens, ContinuousDistribution): + from sympy.stats.crv_types import ContinuousDistributionHandmade + dens = ContinuousDistributionHandmade(dens, set=self.domain.set) + # Turn problem into univariate case + space = SingleContinuousPSpace(z, dens) + result = space.probability(condition.__class__(space.value, comp)) + return result if not cond_inv else S.One - result + + def where(self, condition): + rvs = frozenset(random_symbols(condition)) + if not (len(rvs) == 1 and rvs.issubset(self.values)): + raise NotImplementedError( + "Multiple continuous random variables not supported") + rv = tuple(rvs)[0] + interval = reduce_rational_inequalities_wrap(condition, rv) + interval = interval.intersect(self.domain.set) + return SingleContinuousDomain(rv.symbol, interval) + + def conditional_space(self, condition, normalize=True, **kwargs): + condition = condition.xreplace({rv: rv.symbol for rv in self.values}) + domain = ConditionalContinuousDomain(self.domain, condition) + if normalize: + # create a clone of the variable to + # make sure that variables in nested integrals are different + # from the variables outside the integral + # this makes sure that they are evaluated separately + # and in the correct order + replacement = {rv: Dummy(str(rv)) for rv in self.symbols} + norm = domain.compute_expectation(self.pdf, **kwargs) + pdf = self.pdf / norm.xreplace(replacement) + # XXX: Converting set to tuple. The order matters to Lambda though + # so we shouldn't be starting with a set here... + density = Lambda(tuple(domain.symbols), pdf) + + return ContinuousPSpace(domain, density) + + +class SingleContinuousPSpace(ContinuousPSpace, SinglePSpace): + """ + A continuous probability space over a single univariate variable. + + These consist of a Symbol and a SingleContinuousDistribution + + This class is normally accessed through the various random variable + functions, Normal, Exponential, Uniform, etc.... + """ + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return SingleContinuousDomain(sympify(self.symbol), self.set) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method. + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library=library, seed=seed)} + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + rvs = rvs or (self.value,) + if self.value not in rvs: + return expr + + expr = _sympify(expr) + expr = expr.xreplace({rv: rv.symbol for rv in rvs}) + + x = self.value.symbol + try: + return self.distribution.expectation(expr, x, evaluate=evaluate, **kwargs) + except PoleError: + return Integral(expr * self.pdf, (x, self.set), **kwargs) + + def compute_cdf(self, expr, **kwargs): + if expr == self.value: + z = Dummy("z", real=True) + return Lambda(z, self.distribution.cdf(z, **kwargs)) + else: + return ContinuousPSpace.compute_cdf(self, expr, **kwargs) + + def compute_characteristic_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.characteristic_function(t, **kwargs)) + else: + return ContinuousPSpace.compute_characteristic_function(self, expr, **kwargs) + + def compute_moment_generating_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.moment_generating_function(t, **kwargs)) + else: + return ContinuousPSpace.compute_moment_generating_function(self, expr, **kwargs) + + def compute_density(self, expr, **kwargs): + # https://en.wikipedia.org/wiki/Random_variable#Functions_of_random_variables + if expr == self.value: + return self.density + y = Dummy('y', real=True) + + gs = solveset(expr - y, self.value, S.Reals) + + if isinstance(gs, Intersection): + if len(gs.args) == 2 and gs.args[0] is S.Reals: + gs = gs.args[1] + if not gs.is_FiniteSet: + raise ValueError("Can not solve %s for %s" % (expr, self.value)) + fx = self.compute_density(self.value) + fy = sum(fx(g) * abs(g.diff(y)) for g in gs) + return Lambda(y, fy) + + def compute_quantile(self, expr, **kwargs): + + if expr == self.value: + p = Dummy("p", real=True) + return Lambda(p, self.distribution.quantile(p, **kwargs)) + else: + return ContinuousPSpace.compute_quantile(self, expr, **kwargs) + +def _reduce_inequalities(conditions, var, **kwargs): + try: + return reduce_rational_inequalities(conditions, var, **kwargs) + except PolynomialError: + raise ValueError("Reduction of condition failed %s\n" % conditions[0]) + + +def reduce_rational_inequalities_wrap(condition, var): + if condition.is_Relational: + return _reduce_inequalities([[condition]], var, relational=False) + if isinstance(condition, Or): + return Union(*[_reduce_inequalities([[arg]], var, relational=False) + for arg in condition.args]) + if isinstance(condition, And): + intervals = [_reduce_inequalities([[arg]], var, relational=False) + for arg in condition.args] + I = intervals[0] + for i in intervals: + I = I.intersect(i) + return I diff --git a/phivenv/Lib/site-packages/sympy/stats/crv_types.py b/phivenv/Lib/site-packages/sympy/stats/crv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..073e7350fdf80aac39ecd1dd607488a8b76187e3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/crv_types.py @@ -0,0 +1,4732 @@ +""" +Continuous Random Variables - Prebuilt variables + +Contains +======== +Arcsin +Benini +Beta +BetaNoncentral +BetaPrime +BoundedPareto +Cauchy +Chi +ChiNoncentral +ChiSquared +Dagum +Davis +Erlang +ExGaussian +Exponential +ExponentialPower +FDistribution +FisherZ +Frechet +Gamma +GammaInverse +Gumbel +Gompertz +Kumaraswamy +Laplace +Levy +LogCauchy +Logistic +LogLogistic +LogitNormal +LogNormal +Lomax +Maxwell +Moyal +Nakagami +Normal +Pareto +PowerFunction +QuadraticU +RaisedCosine +Rayleigh +Reciprocal +ShiftedGompertz +StudentT +Trapezoidal +Triangular +Uniform +UniformSum +VonMises +Wald +Weibull +WignerSemicircle +""" + + + +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (atan, cos, sin, tan) +from sympy.functions.special.bessel import (besseli, besselj, besselk) +from sympy.functions.special.beta_functions import beta as beta_fn +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, sign) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.hyperbolic import sinh +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt, Max, Min +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import asin +from sympy.functions.special.error_functions import (erf, erfc, erfi, erfinv, expint) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, uppergamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import integrate +from sympy.logic.boolalg import And +from sympy.sets.sets import Interval +from sympy.matrices import MatrixBase +from sympy.stats.crv import SingleContinuousPSpace, SingleContinuousDistribution +from sympy.stats.rv import _value_check, is_random + +oo = S.Infinity + +__all__ = ['ContinuousRV', +'Arcsin', +'Benini', +'Beta', +'BetaNoncentral', +'BetaPrime', +'BoundedPareto', +'Cauchy', +'Chi', +'ChiNoncentral', +'ChiSquared', +'Dagum', +'Davis', +'Erlang', +'ExGaussian', +'Exponential', +'ExponentialPower', +'FDistribution', +'FisherZ', +'Frechet', +'Gamma', +'GammaInverse', +'Gompertz', +'Gumbel', +'Kumaraswamy', +'Laplace', +'Levy', +'LogCauchy', +'Logistic', +'LogLogistic', +'LogitNormal', +'LogNormal', +'Lomax', +'Maxwell', +'Moyal', +'Nakagami', +'Normal', +'GaussianInverse', +'Pareto', +'PowerFunction', +'QuadraticU', +'RaisedCosine', +'Rayleigh', +'Reciprocal', +'StudentT', +'ShiftedGompertz', +'Trapezoidal', +'Triangular', +'Uniform', +'UniformSum', +'VonMises', +'Wald', +'Weibull', +'WignerSemicircle', +] + + +@is_random.register(MatrixBase) +def _(x): + return any(is_random(i) for i in x) + +def rv(symbol, cls, args, **kwargs): + args = list(map(sympify, args)) + dist = cls(*args) + if kwargs.pop('check', True): + dist.check(*args) + pspace = SingleContinuousPSpace(symbol, dist) + if any(is_random(arg) for arg in args): + from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution + pspace = CompoundPSpace(symbol, CompoundDistribution(dist)) + return pspace.value + + +class ContinuousDistributionHandmade(SingleContinuousDistribution): + _argnames = ('pdf',) + + def __new__(cls, pdf, set=Interval(-oo, oo)): + return Basic.__new__(cls, pdf, set) + + @property + def set(self): + return self.args[1] + + @staticmethod + def check(pdf, set): + x = Dummy('x') + val = integrate(pdf(x), (x, set)) + _value_check(Eq(val, 1) != S.false, "The pdf on the given set is incorrect.") + + +def ContinuousRV(symbol, density, set=Interval(-oo, oo), **kwargs): + """ + Create a Continuous Random Variable given the following: + + Parameters + ========== + + symbol : Symbol + Represents name of the random variable. + density : Expression containing symbol + Represents probability density function. + set : set/Interval + Represents the region where the pdf is valid, by default is real line. + check : bool + If True, it will check whether the given density + integrates to 1 over the given set. If False, it + will not perform this check. Default is False. + + + Returns + ======= + + RandomSymbol + + Many common continuous random variable types are already implemented. + This function should be necessary only very rarely. + + + Examples + ======== + + >>> from sympy import Symbol, sqrt, exp, pi + >>> from sympy.stats import ContinuousRV, P, E + + >>> x = Symbol("x") + + >>> pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution + >>> X = ContinuousRV(x, pdf) + + >>> E(X) + 0 + >>> P(X>0) + 1/2 + """ + pdf = Piecewise((density, set.as_relational(symbol)), (0, True)) + pdf = Lambda(symbol, pdf) + # have a default of False while `rv` should have a default of True + kwargs['check'] = kwargs.pop('check', False) + return rv(symbol.name, ContinuousDistributionHandmade, (pdf, set), **kwargs) + +######################################## +# Continuous Probability Distributions # +######################################## + +#------------------------------------------------------------------------------- +# Arcsin distribution ---------------------------------------------------------- + + +class ArcsinDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + @property + def set(self): + return Interval(self.a, self.b) + + def pdf(self, x): + a, b = self.a, self.b + return 1/(pi*sqrt((x - a)*(b - x))) + + def _cdf(self, x): + a, b = self.a, self.b + return Piecewise( + (S.Zero, x < a), + (2*asin(sqrt((x - a)/(b - a)))/pi, x <= b), + (S.One, True)) + + +def Arcsin(name, a=0, b=1): + r""" + Create a Continuous Random Variable with an arcsin distribution. + + The density of the arcsin distribution is given by + + .. math:: + f(x) := \frac{1}{\pi\sqrt{(x-a)(b-x)}} + + with :math:`x \in (a,b)`. It must hold that :math:`-\infty < a < b < \infty`. + + Parameters + ========== + + a : Real number, the left interval boundary + b : Real number, the right interval boundary + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Arcsin, density, cdf + >>> from sympy import Symbol + + >>> a = Symbol("a", real=True) + >>> b = Symbol("b", real=True) + >>> z = Symbol("z") + + >>> X = Arcsin("x", a, b) + + >>> density(X)(z) + 1/(pi*sqrt((-a + z)*(b - z))) + + >>> cdf(X)(z) + Piecewise((0, a > z), + (2*asin(sqrt((-a + z)/(-a + b)))/pi, b >= z), + (1, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Arcsine_distribution + + """ + + return rv(name, ArcsinDistribution, (a, b)) + +#------------------------------------------------------------------------------- +# Benini distribution ---------------------------------------------------------- + + +class BeniniDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta', 'sigma') + + @staticmethod + def check(alpha, beta, sigma): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + _value_check(sigma > 0, "Scale parameter Sigma must be positive.") + + @property + def set(self): + return Interval(self.sigma, oo) + + def pdf(self, x): + alpha, beta, sigma = self.alpha, self.beta, self.sigma + return (exp(-alpha*log(x/sigma) - beta*log(x/sigma)**2) + *(alpha/x + 2*beta*log(x/sigma)/x)) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function of the ' + 'Benini distribution does not exist.') + +def Benini(name, alpha, beta, sigma): + r""" + Create a Continuous Random Variable with a Benini distribution. + + The density of the Benini distribution is given by + + .. math:: + f(x) := e^{-\alpha\log{\frac{x}{\sigma}} + -\beta\log^2\left[{\frac{x}{\sigma}}\right]} + \left(\frac{\alpha}{x}+\frac{2\beta\log{\frac{x}{\sigma}}}{x}\right) + + This is a heavy-tailed distribution and is also known as the log-Rayleigh + distribution. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + sigma : Real number, `\sigma > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Benini, density, cdf + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + + >>> X = Benini("x", alpha, beta, sigma) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + / / z \\ / z \ 2/ z \ + | 2*beta*log|-----|| - alpha*log|-----| - beta*log |-----| + |alpha \sigma/| \sigma/ \sigma/ + |----- + -----------------|*e + \ z z / + + >>> cdf(X)(z) + Piecewise((1 - exp(-alpha*log(z/sigma) - beta*log(z/sigma)**2), sigma <= z), + (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Benini_distribution + .. [2] https://reference.wolfram.com/legacy/v8/ref/BeniniDistribution.html + + """ + + return rv(name, BeniniDistribution, (alpha, beta, sigma)) + +#------------------------------------------------------------------------------- +# Beta distribution ------------------------------------------------------------ + + +class BetaDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + set = Interval(0, 1) + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + + def pdf(self, x): + alpha, beta = self.alpha, self.beta + return x**(alpha - 1) * (1 - x)**(beta - 1) / beta_fn(alpha, beta) + + def _characteristic_function(self, t): + return hyper((self.alpha,), (self.alpha + self.beta,), I*t) + + def _moment_generating_function(self, t): + return hyper((self.alpha,), (self.alpha + self.beta,), t) + + +def Beta(name, alpha, beta): + r""" + Create a Continuous Random Variable with a Beta distribution. + + The density of the Beta distribution is given by + + .. math:: + f(x) := \frac{x^{\alpha-1}(1-x)^{\beta-1}} {\mathrm{B}(\alpha,\beta)} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Beta, density, E, variance + >>> from sympy import Symbol, simplify, pprint, factor + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> z = Symbol("z") + + >>> X = Beta("x", alpha, beta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + alpha - 1 beta - 1 + z *(1 - z) + -------------------------- + B(alpha, beta) + + >>> simplify(E(X)) + alpha/(alpha + beta) + + >>> factor(simplify(variance(X))) + alpha*beta/((alpha + beta)**2*(alpha + beta + 1)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_distribution + .. [2] https://mathworld.wolfram.com/BetaDistribution.html + + """ + + return rv(name, BetaDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +# Noncentral Beta distribution ------------------------------------------------------------ + + +class BetaNoncentralDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta', 'lamda') + + set = Interval(0, 1) + + @staticmethod + def check(alpha, beta, lamda): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + _value_check(lamda >= 0, "Noncentrality parameter Lambda must be positive") + + def pdf(self, x): + alpha, beta, lamda = self.alpha, self.beta, self.lamda + k = Dummy("k") + return Sum(exp(-lamda / 2) * (lamda / 2)**k * x**(alpha + k - 1) *( + 1 - x)**(beta - 1) / (factorial(k) * beta_fn(alpha + k, beta)), (k, 0, oo)) + +def BetaNoncentral(name, alpha, beta, lamda): + r""" + Create a Continuous Random Variable with a Type I Noncentral Beta distribution. + + The density of the Noncentral Beta distribution is given by + + .. math:: + f(x) := \sum_{k=0}^\infty e^{-\lambda/2}\frac{(\lambda/2)^k}{k!} + \frac{x^{\alpha+k-1}(1-x)^{\beta-1}}{\mathrm{B}(\alpha+k,\beta)} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + lamda : Real number, `\lambda \geq 0`, noncentrality parameter + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import BetaNoncentral, density, cdf + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> lamda = Symbol("lamda", nonnegative=True) + >>> z = Symbol("z") + + >>> X = BetaNoncentral("x", alpha, beta, lamda) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + oo + _____ + \ ` + \ -lamda + \ k ------- + \ k + alpha - 1 /lamda\ beta - 1 2 + ) z *|-----| *(1 - z) *e + / \ 2 / + / ------------------------------------------------ + / B(k + alpha, beta)*k! + /____, + k = 0 + + Compute cdf with specific 'x', 'alpha', 'beta' and 'lamda' values as follows: + + >>> cdf(BetaNoncentral("x", 1, 1, 1), evaluate=False)(2).doit() + 2*exp(1/2) + + The argument evaluate=False prevents an attempt at evaluation + of the sum for general x, before the argument 2 is passed. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Noncentral_beta_distribution + .. [2] https://reference.wolfram.com/language/ref/NoncentralBetaDistribution.html + + """ + + return rv(name, BetaNoncentralDistribution, (alpha, beta, lamda)) + + +#------------------------------------------------------------------------------- +# Beta prime distribution ------------------------------------------------------ + + +class BetaPrimeDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Shape parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + + set = Interval(0, oo) + + def pdf(self, x): + alpha, beta = self.alpha, self.beta + return x**(alpha - 1)*(1 + x)**(-alpha - beta)/beta_fn(alpha, beta) + +def BetaPrime(name, alpha, beta): + r""" + Create a continuous random variable with a Beta prime distribution. + + The density of the Beta prime distribution is given by + + .. math:: + f(x) := \frac{x^{\alpha-1} (1+x)^{-\alpha -\beta}}{B(\alpha,\beta)} + + with :math:`x > 0`. + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, a shape + beta : Real number, `\beta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import BetaPrime, density + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> z = Symbol("z") + + >>> X = BetaPrime("x", alpha, beta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + alpha - 1 -alpha - beta + z *(z + 1) + ------------------------------- + B(alpha, beta) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_prime_distribution + .. [2] https://mathworld.wolfram.com/BetaPrimeDistribution.html + + """ + + return rv(name, BetaPrimeDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +# Bounded Pareto Distribution -------------------------------------------------- +class BoundedParetoDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'left', 'right') + + @property + def set(self): + return Interval(self.left, self.right) + + @staticmethod + def check(alpha, left, right): + _value_check (alpha.is_positive, "Shape must be positive.") + _value_check (left.is_positive, "Left value should be positive.") + _value_check (right > left, "Right should be greater than left.") + + def pdf(self, x): + alpha, left, right = self.alpha, self.left, self.right + num = alpha * (left**alpha) * x**(- alpha -1) + den = 1 - (left/right)**alpha + return num/den + +def BoundedPareto(name, alpha, left, right): + r""" + Create a continuous random variable with a Bounded Pareto distribution. + + The density of the Bounded Pareto distribution is given by + + .. math:: + f(x) := \frac{\alpha L^{\alpha}x^{-\alpha-1}}{1-(\frac{L}{H})^{\alpha}} + + Parameters + ========== + + alpha : Real Number, `\alpha > 0` + Shape parameter + left : Real Number, `left > 0` + Location parameter + right : Real Number, `right > left` + Location parameter + + Examples + ======== + + >>> from sympy.stats import BoundedPareto, density, cdf, E + >>> from sympy import symbols + >>> L, H = symbols('L, H', positive=True) + >>> X = BoundedPareto('X', 2, L, H) + >>> x = symbols('x') + >>> density(X)(x) + 2*L**2/(x**3*(1 - L**2/H**2)) + >>> cdf(X)(x) + Piecewise((-H**2*L**2/(x**2*(H**2 - L**2)) + H**2/(H**2 - L**2), L <= x), (0, True)) + >>> E(X).simplify() + 2*H*L/(H + L) + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pareto_distribution#Bounded_Pareto_distribution + + """ + return rv (name, BoundedParetoDistribution, (alpha, left, right)) + +# ------------------------------------------------------------------------------ +# Cauchy distribution ---------------------------------------------------------- + + +class CauchyDistribution(SingleContinuousDistribution): + _argnames = ('x0', 'gamma') + + @staticmethod + def check(x0, gamma): + _value_check(gamma > 0, "Scale parameter Gamma must be positive.") + _value_check(x0.is_real, "Location parameter must be real.") + + def pdf(self, x): + return 1/(pi*self.gamma*(1 + ((x - self.x0)/self.gamma)**2)) + + def _cdf(self, x): + x0, gamma = self.x0, self.gamma + return (1/pi)*atan((x - x0)/gamma) + S.Half + + def _characteristic_function(self, t): + return exp(self.x0 * I * t - self.gamma * Abs(t)) + + def _moment_generating_function(self, t): + raise NotImplementedError("The moment generating function for the " + "Cauchy distribution does not exist.") + + def _quantile(self, p): + return self.x0 + self.gamma*tan(pi*(p - S.Half)) + + +def Cauchy(name, x0, gamma): + r""" + Create a continuous random variable with a Cauchy distribution. + + The density of the Cauchy distribution is given by + + .. math:: + f(x) := \frac{1}{\pi \gamma [1 + {(\frac{x-x_0}{\gamma})}^2]} + + Parameters + ========== + + x0 : Real number, the location + gamma : Real number, `\gamma > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Cauchy, density + >>> from sympy import Symbol + + >>> x0 = Symbol("x0") + >>> gamma = Symbol("gamma", positive=True) + >>> z = Symbol("z") + + >>> X = Cauchy("x", x0, gamma) + + >>> density(X)(z) + 1/(pi*gamma*(1 + (-x0 + z)**2/gamma**2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cauchy_distribution + .. [2] https://mathworld.wolfram.com/CauchyDistribution.html + + """ + + return rv(name, CauchyDistribution, (x0, gamma)) + +#------------------------------------------------------------------------------- +# Chi distribution ------------------------------------------------------------- + + +class ChiDistribution(SingleContinuousDistribution): + _argnames = ('k',) + + @staticmethod + def check(k): + _value_check(k > 0, "Number of degrees of freedom (k) must be positive.") + _value_check(k.is_integer, "Number of degrees of freedom (k) must be an integer.") + + set = Interval(0, oo) + + def pdf(self, x): + return 2**(1 - self.k/2)*x**(self.k - 1)*exp(-x**2/2)/gamma(self.k/2) + + def _characteristic_function(self, t): + k = self.k + + part_1 = hyper((k/2,), (S.Half,), -t**2/2) + part_2 = I*t*sqrt(2)*gamma((k+1)/2)/gamma(k/2) + part_3 = hyper(((k+1)/2,), (Rational(3, 2),), -t**2/2) + return part_1 + part_2*part_3 + + def _moment_generating_function(self, t): + k = self.k + + part_1 = hyper((k / 2,), (S.Half,), t ** 2 / 2) + part_2 = t * sqrt(2) * gamma((k + 1) / 2) / gamma(k / 2) + part_3 = hyper(((k + 1) / 2,), (S(3) / 2,), t ** 2 / 2) + return part_1 + part_2 * part_3 + +def Chi(name, k): + r""" + Create a continuous random variable with a Chi distribution. + + The density of the Chi distribution is given by + + .. math:: + f(x) := \frac{2^{1-k/2}x^{k-1}e^{-x^2/2}}{\Gamma(k/2)} + + with :math:`x \geq 0`. + + Parameters + ========== + + k : Positive integer, The number of degrees of freedom + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Chi, density, E + >>> from sympy import Symbol, simplify + + >>> k = Symbol("k", integer=True) + >>> z = Symbol("z") + + >>> X = Chi("x", k) + + >>> density(X)(z) + 2**(1 - k/2)*z**(k - 1)*exp(-z**2/2)/gamma(k/2) + + >>> simplify(E(X)) + sqrt(2)*gamma(k/2 + 1/2)/gamma(k/2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chi_distribution + .. [2] https://mathworld.wolfram.com/ChiDistribution.html + + """ + + return rv(name, ChiDistribution, (k,)) + +#------------------------------------------------------------------------------- +# Non-central Chi distribution ------------------------------------------------- + + +class ChiNoncentralDistribution(SingleContinuousDistribution): + _argnames = ('k', 'l') + + @staticmethod + def check(k, l): + _value_check(k > 0, "Number of degrees of freedom (k) must be positive.") + _value_check(k.is_integer, "Number of degrees of freedom (k) must be an integer.") + _value_check(l > 0, "Shift parameter Lambda must be positive.") + + set = Interval(0, oo) + + def pdf(self, x): + k, l = self.k, self.l + return exp(-(x**2+l**2)/2)*x**k*l / (l*x)**(k/2) * besseli(k/2-1, l*x) + +def ChiNoncentral(name, k, l): + r""" + Create a continuous random variable with a non-central Chi distribution. + + Explanation + =========== + + The density of the non-central Chi distribution is given by + + .. math:: + f(x) := \frac{e^{-(x^2+\lambda^2)/2} x^k\lambda} + {(\lambda x)^{k/2}} I_{k/2-1}(\lambda x) + + with `x \geq 0`. Here, `I_\nu (x)` is the + :ref:`modified Bessel function of the first kind `. + + Parameters + ========== + + k : A positive Integer, $k > 0$ + The number of degrees of freedom. + lambda : Real number, `\lambda > 0` + Shift parameter. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ChiNoncentral, density + >>> from sympy import Symbol + + >>> k = Symbol("k", integer=True) + >>> l = Symbol("l") + >>> z = Symbol("z") + + >>> X = ChiNoncentral("x", k, l) + + >>> density(X)(z) + l*z**k*exp(-l**2/2 - z**2/2)*besseli(k/2 - 1, l*z)/(l*z)**(k/2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Noncentral_chi_distribution + """ + + return rv(name, ChiNoncentralDistribution, (k, l)) + +#------------------------------------------------------------------------------- +# Chi squared distribution ----------------------------------------------------- + + +class ChiSquaredDistribution(SingleContinuousDistribution): + _argnames = ('k',) + + @staticmethod + def check(k): + _value_check(k > 0, "Number of degrees of freedom (k) must be positive.") + _value_check(k.is_integer, "Number of degrees of freedom (k) must be an integer.") + + set = Interval(0, oo) + + def pdf(self, x): + k = self.k + return 1/(2**(k/2)*gamma(k/2))*x**(k/2 - 1)*exp(-x/2) + + def _cdf(self, x): + k = self.k + return Piecewise( + (S.One/gamma(k/2)*lowergamma(k/2, x/2), x >= 0), + (0, True) + ) + + def _characteristic_function(self, t): + return (1 - 2*I*t)**(-self.k/2) + + def _moment_generating_function(self, t): + return (1 - 2*t)**(-self.k/2) + + +def ChiSquared(name, k): + r""" + Create a continuous random variable with a Chi-squared distribution. + + Explanation + =========== + + The density of the Chi-squared distribution is given by + + .. math:: + f(x) := \frac{1}{2^{\frac{k}{2}}\Gamma\left(\frac{k}{2}\right)} + x^{\frac{k}{2}-1} e^{-\frac{x}{2}} + + with :math:`x \geq 0`. + + Parameters + ========== + + k : Positive integer + The number of degrees of freedom. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ChiSquared, density, E, variance, moment + >>> from sympy import Symbol + + >>> k = Symbol("k", integer=True, positive=True) + >>> z = Symbol("z") + + >>> X = ChiSquared("x", k) + + >>> density(X)(z) + z**(k/2 - 1)*exp(-z/2)/(2**(k/2)*gamma(k/2)) + + >>> E(X) + k + + >>> variance(X) + 2*k + + >>> moment(X, 3) + k**3 + 6*k**2 + 8*k + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chi_squared_distribution + .. [2] https://mathworld.wolfram.com/Chi-SquaredDistribution.html + """ + + return rv(name, ChiSquaredDistribution, (k, )) + +#------------------------------------------------------------------------------- +# Dagum distribution ----------------------------------------------------------- + + +class DagumDistribution(SingleContinuousDistribution): + _argnames = ('p', 'a', 'b') + + set = Interval(0, oo) + + @staticmethod + def check(p, a, b): + _value_check(p > 0, "Shape parameter p must be positive.") + _value_check(a > 0, "Shape parameter a must be positive.") + _value_check(b > 0, "Scale parameter b must be positive.") + + def pdf(self, x): + p, a, b = self.p, self.a, self.b + return a*p/x*((x/b)**(a*p)/(((x/b)**a + 1)**(p + 1))) + + def _cdf(self, x): + p, a, b = self.p, self.a, self.b + return Piecewise(((S.One + (S(x)/b)**-a)**-p, x>=0), + (S.Zero, True)) + +def Dagum(name, p, a, b): + r""" + Create a continuous random variable with a Dagum distribution. + + Explanation + =========== + + The density of the Dagum distribution is given by + + .. math:: + f(x) := \frac{a p}{x} \left( \frac{\left(\tfrac{x}{b}\right)^{a p}} + {\left(\left(\tfrac{x}{b}\right)^a + 1 \right)^{p+1}} \right) + + with :math:`x > 0`. + + Parameters + ========== + + p : Real number + `p > 0`, a shape. + a : Real number + `a > 0`, a shape. + b : Real number + `b > 0`, a scale. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Dagum, density, cdf + >>> from sympy import Symbol + + >>> p = Symbol("p", positive=True) + >>> a = Symbol("a", positive=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Dagum("x", p, a, b) + + >>> density(X)(z) + a*p*(z/b)**(a*p)*((z/b)**a + 1)**(-p - 1)/z + + >>> cdf(X)(z) + Piecewise(((1 + (z/b)**(-a))**(-p), z >= 0), (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dagum_distribution + + """ + + return rv(name, DagumDistribution, (p, a, b)) + +#------------------------------------------------------------------------------- +# Davis distribution ----------------------------------------------------------- + +class DavisDistribution(SingleContinuousDistribution): + _argnames = ('b', 'n', 'mu') + + set = Interval(0, oo) + + @staticmethod + def check(b, n, mu): + _value_check(b > 0, "Scale parameter b must be positive.") + _value_check(n > 1, "Shape parameter n must be above 1.") + _value_check(mu > 0, "Location parameter mu must be positive.") + + def pdf(self, x): + b, n, mu = self.b, self.n, self.mu + dividend = b**n*(x - mu)**(-1-n) + divisor = (exp(b/(x-mu))-1)*(gamma(n)*zeta(n)) + return dividend/divisor + + +def Davis(name, b, n, mu): + r""" Create a continuous random variable with Davis distribution. + + Explanation + =========== + + The density of Davis distribution is given by + + .. math:: + f(x; \mu; b, n) := \frac{b^{n}(x - \mu)^{1-n}}{ \left( e^{\frac{b}{x-\mu}} - 1 \right) \Gamma(n)\zeta(n)} + + with :math:`x \in [0,\infty]`. + + Davis distribution is a generalization of the Planck's law of radiation from statistical physics. It is used for modeling income distribution. + + Parameters + ========== + b : Real number + `p > 0`, a scale. + n : Real number + `n > 1`, a shape. + mu : Real number + `mu > 0`, a location. + + Returns + ======= + + RandomSymbol + + Examples + ======== + >>> from sympy.stats import Davis, density + >>> from sympy import Symbol + >>> b = Symbol("b", positive=True) + >>> n = Symbol("n", positive=True) + >>> mu = Symbol("mu", positive=True) + >>> z = Symbol("z") + >>> X = Davis("x", b, n, mu) + >>> density(X)(z) + b**n*(-mu + z)**(-n - 1)/((exp(b/(-mu + z)) - 1)*gamma(n)*zeta(n)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Davis_distribution + .. [2] https://reference.wolfram.com/language/ref/DavisDistribution.html + + """ + return rv(name, DavisDistribution, (b, n, mu)) + + +#------------------------------------------------------------------------------- +# Erlang distribution ---------------------------------------------------------- + + +def Erlang(name, k, l): + r""" + Create a continuous random variable with an Erlang distribution. + + Explanation + =========== + + The density of the Erlang distribution is given by + + .. math:: + f(x) := \frac{\lambda^k x^{k-1} e^{-\lambda x}}{(k-1)!} + + with :math:`x \in [0,\infty]`. + + Parameters + ========== + + k : Positive integer + l : Real number, `\lambda > 0`, the rate + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Erlang, density, cdf, E, variance + >>> from sympy import Symbol, simplify, pprint + + >>> k = Symbol("k", integer=True, positive=True) + >>> l = Symbol("l", positive=True) + >>> z = Symbol("z") + + >>> X = Erlang("x", k, l) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + k k - 1 -l*z + l *z *e + --------------- + Gamma(k) + + >>> C = cdf(X)(z) + >>> pprint(C, use_unicode=False) + /lowergamma(k, l*z) + |------------------ for z > 0 + < Gamma(k) + | + \ 0 otherwise + + + >>> E(X) + k/l + + >>> simplify(variance(X)) + k/l**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Erlang_distribution + .. [2] https://mathworld.wolfram.com/ErlangDistribution.html + + """ + + return rv(name, GammaDistribution, (k, S.One/l)) + +# ------------------------------------------------------------------------------- +# ExGaussian distribution ----------------------------------------------------- + + +class ExGaussianDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'std', 'rate') + + set = Interval(-oo, oo) + + @staticmethod + def check(mean, std, rate): + _value_check( + std > 0, "Standard deviation of ExGaussian must be positive.") + _value_check(rate > 0, "Rate of ExGaussian must be positive.") + + def pdf(self, x): + mean, std, rate = self.mean, self.std, self.rate + term1 = rate/2 + term2 = exp(rate * (2 * mean + rate * std**2 - 2*x)/2) + term3 = erfc((mean + rate*std**2 - x)/(sqrt(2)*std)) + return term1*term2*term3 + + def _cdf(self, x): + from sympy.stats import cdf + mean, std, rate = self.mean, self.std, self.rate + u = rate*(x - mean) + v = rate*std + GaussianCDF1 = cdf(Normal('x', 0, v))(u) + GaussianCDF2 = cdf(Normal('x', v**2, v))(u) + + return GaussianCDF1 - exp(-u + (v**2/2) + log(GaussianCDF2)) + + def _characteristic_function(self, t): + mean, std, rate = self.mean, self.std, self.rate + term1 = (1 - I*t/rate)**(-1) + term2 = exp(I*mean*t - std**2*t**2/2) + return term1 * term2 + + def _moment_generating_function(self, t): + mean, std, rate = self.mean, self.std, self.rate + term1 = (1 - t/rate)**(-1) + term2 = exp(mean*t + std**2*t**2/2) + return term1*term2 + + +def ExGaussian(name, mean, std, rate): + r""" + Create a continuous random variable with an Exponentially modified + Gaussian (EMG) distribution. + + Explanation + =========== + + The density of the exponentially modified Gaussian distribution is given by + + .. math:: + f(x) := \frac{\lambda}{2}e^{\frac{\lambda}{2}(2\mu+\lambda\sigma^2-2x)} + \text{erfc}(\frac{\mu + \lambda\sigma^2 - x}{\sqrt{2}\sigma}) + + with $x > 0$. Note that the expected value is `1/\lambda`. + + Parameters + ========== + + name : A string giving a name for this distribution + mean : A Real number, the mean of Gaussian component + std : A positive Real number, + :math: `\sigma^2 > 0` the variance of Gaussian component + rate : A positive Real number, + :math: `\lambda > 0` the rate of Exponential component + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ExGaussian, density, cdf, E + >>> from sympy.stats import variance, skewness + >>> from sympy import Symbol, pprint, simplify + + >>> mean = Symbol("mu") + >>> std = Symbol("sigma", positive=True) + >>> rate = Symbol("lamda", positive=True) + >>> z = Symbol("z") + >>> X = ExGaussian("x", mean, std, rate) + + >>> pprint(density(X)(z), use_unicode=False) + / 2 \ + lamda*\lamda*sigma + 2*mu - 2*z/ + --------------------------------- / ___ / 2 \\ + 2 |\/ 2 *\lamda*sigma + mu - z/| + lamda*e *erfc|-----------------------------| + \ 2*sigma / + ---------------------------------------------------------------------------- + 2 + + >>> cdf(X)(z) + -(erf(sqrt(2)*(-lamda**2*sigma**2 + lamda*(-mu + z))/(2*lamda*sigma))/2 + 1/2)*exp(lamda**2*sigma**2/2 - lamda*(-mu + z)) + erf(sqrt(2)*(-mu + z)/(2*sigma))/2 + 1/2 + + >>> E(X) + (lamda*mu + 1)/lamda + + >>> simplify(variance(X)) + sigma**2 + lamda**(-2) + + >>> simplify(skewness(X)) + 2/(lamda**2*sigma**2 + 1)**(3/2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Exponentially_modified_Gaussian_distribution + """ + return rv(name, ExGaussianDistribution, (mean, std, rate)) + +#------------------------------------------------------------------------------- +# Exponential distribution ----------------------------------------------------- + + +class ExponentialDistribution(SingleContinuousDistribution): + _argnames = ('rate',) + + set = Interval(0, oo) + + @staticmethod + def check(rate): + _value_check(rate > 0, "Rate must be positive.") + + def pdf(self, x): + return self.rate * exp(-self.rate*x) + + def _cdf(self, x): + return Piecewise( + (S.One - exp(-self.rate*x), x >= 0), + (0, True), + ) + + def _characteristic_function(self, t): + rate = self.rate + return rate / (rate - I*t) + + def _moment_generating_function(self, t): + rate = self.rate + return rate / (rate - t) + + def _quantile(self, p): + return -log(1-p)/self.rate + + +def Exponential(name, rate): + r""" + Create a continuous random variable with an Exponential distribution. + + Explanation + =========== + + The density of the exponential distribution is given by + + .. math:: + f(x) := \lambda \exp(-\lambda x) + + with $x > 0$. Note that the expected value is `1/\lambda`. + + Parameters + ========== + + rate : A positive Real number, `\lambda > 0`, the rate (or inverse scale/inverse mean) + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Exponential, density, cdf, E + >>> from sympy.stats import variance, std, skewness, quantile + >>> from sympy import Symbol + + >>> l = Symbol("lambda", positive=True) + >>> z = Symbol("z") + >>> p = Symbol("p") + >>> X = Exponential("x", l) + + >>> density(X)(z) + lambda*exp(-lambda*z) + + >>> cdf(X)(z) + Piecewise((1 - exp(-lambda*z), z >= 0), (0, True)) + + >>> quantile(X)(p) + -log(1 - p)/lambda + + >>> E(X) + 1/lambda + + >>> variance(X) + lambda**(-2) + + >>> skewness(X) + 2 + + >>> X = Exponential('x', 10) + + >>> density(X)(z) + 10*exp(-10*z) + + >>> E(X) + 1/10 + + >>> std(X) + 1/10 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Exponential_distribution + .. [2] https://mathworld.wolfram.com/ExponentialDistribution.html + + """ + + return rv(name, ExponentialDistribution, (rate, )) + + +# ------------------------------------------------------------------------------- +# Exponential Power distribution ----------------------------------------------------- + +class ExponentialPowerDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'alpha', 'beta') + + set = Interval(-oo, oo) + + @staticmethod + def check(mu, alpha, beta): + _value_check(alpha > 0, "Scale parameter alpha must be positive.") + _value_check(beta > 0, "Shape parameter beta must be positive.") + + def pdf(self, x): + mu, alpha, beta = self.mu, self.alpha, self.beta + num = beta*exp(-(Abs(x - mu)/alpha)**beta) + den = 2*alpha*gamma(1/beta) + return num/den + + def _cdf(self, x): + mu, alpha, beta = self.mu, self.alpha, self.beta + num = lowergamma(1/beta, (Abs(x - mu) / alpha)**beta) + den = 2*gamma(1/beta) + return sign(x - mu)*num/den + S.Half + + +def ExponentialPower(name, mu, alpha, beta): + r""" + Create a Continuous Random Variable with Exponential Power distribution. + This distribution is known also as Generalized Normal + distribution version 1. + + Explanation + =========== + + The density of the Exponential Power distribution is given by + + .. math:: + f(x) := \frac{\beta}{2\alpha\Gamma(\frac{1}{\beta})} + e^{{-(\frac{|x - \mu|}{\alpha})^{\beta}}} + + with :math:`x \in [ - \infty, \infty ]`. + + Parameters + ========== + + mu : Real number + A location. + alpha : Real number,`\alpha > 0` + A scale. + beta : Real number, `\beta > 0` + A shape. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import ExponentialPower, density, cdf + >>> from sympy import Symbol, pprint + >>> z = Symbol("z") + >>> mu = Symbol("mu") + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> X = ExponentialPower("x", mu, alpha, beta) + >>> pprint(density(X)(z), use_unicode=False) + beta + /|mu - z|\ + -|--------| + \ alpha / + beta*e + --------------------- + / 1 \ + 2*alpha*Gamma|----| + \beta/ + >>> cdf(X)(z) + 1/2 + lowergamma(1/beta, (Abs(mu - z)/alpha)**beta)*sign(-mu + z)/(2*gamma(1/beta)) + + References + ========== + + .. [1] https://reference.wolfram.com/language/ref/ExponentialPowerDistribution.html + .. [2] https://en.wikipedia.org/wiki/Generalized_normal_distribution#Version_1 + + """ + return rv(name, ExponentialPowerDistribution, (mu, alpha, beta)) + + +#------------------------------------------------------------------------------- +# F distribution --------------------------------------------------------------- + + +class FDistributionDistribution(SingleContinuousDistribution): + _argnames = ('d1', 'd2') + + set = Interval(0, oo) + + @staticmethod + def check(d1, d2): + _value_check((d1 > 0, d1.is_integer), + "Degrees of freedom d1 must be positive integer.") + _value_check((d2 > 0, d2.is_integer), + "Degrees of freedom d2 must be positive integer.") + + def pdf(self, x): + d1, d2 = self.d1, self.d2 + return (sqrt((d1*x)**d1*d2**d2 / (d1*x+d2)**(d1+d2)) + / (x * beta_fn(d1/2, d2/2))) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function for the ' + 'F-distribution does not exist.') + +def FDistribution(name, d1, d2): + r""" + Create a continuous random variable with a F distribution. + + Explanation + =========== + + The density of the F distribution is given by + + .. math:: + f(x) := \frac{\sqrt{\frac{(d_1 x)^{d_1} d_2^{d_2}} + {(d_1 x + d_2)^{d_1 + d_2}}}} + {x \mathrm{B} \left(\frac{d_1}{2}, \frac{d_2}{2}\right)} + + with :math:`x > 0`. + + Parameters + ========== + + d1 : `d_1 > 0`, where `d_1` is the degrees of freedom (`n_1 - 1`) + d2 : `d_2 > 0`, where `d_2` is the degrees of freedom (`n_2 - 1`) + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import FDistribution, density + >>> from sympy import Symbol, pprint + + >>> d1 = Symbol("d1", positive=True) + >>> d2 = Symbol("d2", positive=True) + >>> z = Symbol("z") + + >>> X = FDistribution("x", d1, d2) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + d2 + -- ______________________________ + 2 / d1 -d1 - d2 + d2 *\/ (d1*z) *(d1*z + d2) + -------------------------------------- + /d1 d2\ + z*B|--, --| + \2 2 / + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/F-distribution + .. [2] https://mathworld.wolfram.com/F-Distribution.html + + """ + + return rv(name, FDistributionDistribution, (d1, d2)) + +#------------------------------------------------------------------------------- +# Fisher Z distribution -------------------------------------------------------- + +class FisherZDistribution(SingleContinuousDistribution): + _argnames = ('d1', 'd2') + + set = Interval(-oo, oo) + + @staticmethod + def check(d1, d2): + _value_check(d1 > 0, "Degree of freedom d1 must be positive.") + _value_check(d2 > 0, "Degree of freedom d2 must be positive.") + + def pdf(self, x): + d1, d2 = self.d1, self.d2 + return (2*d1**(d1/2)*d2**(d2/2) / beta_fn(d1/2, d2/2) * + exp(d1*x) / (d1*exp(2*x)+d2)**((d1+d2)/2)) + +def FisherZ(name, d1, d2): + r""" + Create a Continuous Random Variable with an Fisher's Z distribution. + + Explanation + =========== + + The density of the Fisher's Z distribution is given by + + .. math:: + f(x) := \frac{2d_1^{d_1/2} d_2^{d_2/2}} {\mathrm{B}(d_1/2, d_2/2)} + \frac{e^{d_1z}}{\left(d_1e^{2z}+d_2\right)^{\left(d_1+d_2\right)/2}} + + + .. TODO - What is the difference between these degrees of freedom? + + Parameters + ========== + + d1 : `d_1 > 0` + Degree of freedom. + d2 : `d_2 > 0` + Degree of freedom. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import FisherZ, density + >>> from sympy import Symbol, pprint + + >>> d1 = Symbol("d1", positive=True) + >>> d2 = Symbol("d2", positive=True) + >>> z = Symbol("z") + + >>> X = FisherZ("x", d1, d2) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + d1 d2 + d1 d2 - -- - -- + -- -- 2 2 + 2 2 / 2*z \ d1*z + 2*d1 *d2 *\d1*e + d2/ *e + ----------------------------------------- + /d1 d2\ + B|--, --| + \2 2 / + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fisher%27s_z-distribution + .. [2] https://mathworld.wolfram.com/Fishersz-Distribution.html + + """ + + return rv(name, FisherZDistribution, (d1, d2)) + +#------------------------------------------------------------------------------- +# Frechet distribution --------------------------------------------------------- + +class FrechetDistribution(SingleContinuousDistribution): + _argnames = ('a', 's', 'm') + + set = Interval(0, oo) + + @staticmethod + def check(a, s, m): + _value_check(a > 0, "Shape parameter alpha must be positive.") + _value_check(s > 0, "Scale parameter s must be positive.") + + def __new__(cls, a, s=1, m=0): + a, s, m = list(map(sympify, (a, s, m))) + return Basic.__new__(cls, a, s, m) + + def pdf(self, x): + a, s, m = self.a, self.s, self.m + return a/s * ((x-m)/s)**(-1-a) * exp(-((x-m)/s)**(-a)) + + def _cdf(self, x): + a, s, m = self.a, self.s, self.m + return Piecewise((exp(-((x-m)/s)**(-a)), x >= m), + (S.Zero, True)) + +def Frechet(name, a, s=1, m=0): + r""" + Create a continuous random variable with a Frechet distribution. + + Explanation + =========== + + The density of the Frechet distribution is given by + + .. math:: + f(x) := \frac{\alpha}{s} \left(\frac{x-m}{s}\right)^{-1-\alpha} + e^{-(\frac{x-m}{s})^{-\alpha}} + + with :math:`x \geq m`. + + Parameters + ========== + + a : Real number, :math:`a \in \left(0, \infty\right)` the shape + s : Real number, :math:`s \in \left(0, \infty\right)` the scale + m : Real number, :math:`m \in \left(-\infty, \infty\right)` the minimum + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Frechet, density, cdf + >>> from sympy import Symbol + + >>> a = Symbol("a", positive=True) + >>> s = Symbol("s", positive=True) + >>> m = Symbol("m", real=True) + >>> z = Symbol("z") + + >>> X = Frechet("x", a, s, m) + + >>> density(X)(z) + a*((-m + z)/s)**(-a - 1)*exp(-1/((-m + z)/s)**a)/s + + >>> cdf(X)(z) + Piecewise((exp(-1/((-m + z)/s)**a), m <= z), (0, True)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution + + """ + + return rv(name, FrechetDistribution, (a, s, m)) + +#------------------------------------------------------------------------------- +# Gamma distribution ----------------------------------------------------------- + + +class GammaDistribution(SingleContinuousDistribution): + _argnames = ('k', 'theta') + + set = Interval(0, oo) + + @staticmethod + def check(k, theta): + _value_check(k > 0, "k must be positive") + _value_check(theta > 0, "Theta must be positive") + + def pdf(self, x): + k, theta = self.k, self.theta + return x**(k - 1) * exp(-x/theta) / (gamma(k)*theta**k) + + def _cdf(self, x): + k, theta = self.k, self.theta + return Piecewise( + (lowergamma(k, S(x)/theta)/gamma(k), x > 0), + (S.Zero, True)) + + def _characteristic_function(self, t): + return (1 - self.theta*I*t)**(-self.k) + + def _moment_generating_function(self, t): + return (1- self.theta*t)**(-self.k) + + +def Gamma(name, k, theta): + r""" + Create a continuous random variable with a Gamma distribution. + + Explanation + =========== + + The density of the Gamma distribution is given by + + .. math:: + f(x) := \frac{1}{\Gamma(k) \theta^k} x^{k - 1} e^{-\frac{x}{\theta}} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + k : Real number, `k > 0`, a shape + theta : Real number, `\theta > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Gamma, density, cdf, E, variance + >>> from sympy import Symbol, pprint, simplify + + >>> k = Symbol("k", positive=True) + >>> theta = Symbol("theta", positive=True) + >>> z = Symbol("z") + + >>> X = Gamma("x", k, theta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + -z + ----- + -k k - 1 theta + theta *z *e + --------------------- + Gamma(k) + + >>> C = cdf(X, meijerg=True)(z) + >>> pprint(C, use_unicode=False) + / / z \ + |k*lowergamma|k, -----| + | \ theta/ + <---------------------- for z >= 0 + | Gamma(k + 1) + | + \ 0 otherwise + + >>> E(X) + k*theta + + >>> V = simplify(variance(X)) + >>> pprint(V, use_unicode=False) + 2 + k*theta + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_distribution + .. [2] https://mathworld.wolfram.com/GammaDistribution.html + + """ + + return rv(name, GammaDistribution, (k, theta)) + +#------------------------------------------------------------------------------- +# Inverse Gamma distribution --------------------------------------------------- + + +class GammaInverseDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + set = Interval(0, oo) + + @staticmethod + def check(a, b): + _value_check(a > 0, "alpha must be positive") + _value_check(b > 0, "beta must be positive") + + def pdf(self, x): + a, b = self.a, self.b + return b**a/gamma(a) * x**(-a-1) * exp(-b/x) + + def _cdf(self, x): + a, b = self.a, self.b + return Piecewise((uppergamma(a,b/x)/gamma(a), x > 0), + (S.Zero, True)) + + def _characteristic_function(self, t): + a, b = self.a, self.b + return 2 * (-I*b*t)**(a/2) * besselk(a, sqrt(-4*I*b*t)) / gamma(a) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function for the ' + 'gamma inverse distribution does not exist.') + +def GammaInverse(name, a, b): + r""" + Create a continuous random variable with an inverse Gamma distribution. + + Explanation + =========== + + The density of the inverse Gamma distribution is given by + + .. math:: + f(x) := \frac{\beta^\alpha}{\Gamma(\alpha)} x^{-\alpha - 1} + \exp\left(\frac{-\beta}{x}\right) + + with :math:`x > 0`. + + Parameters + ========== + + a : Real number, `a > 0`, a shape + b : Real number, `b > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import GammaInverse, density, cdf + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a", positive=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = GammaInverse("x", a, b) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + -b + --- + a -a - 1 z + b *z *e + --------------- + Gamma(a) + + >>> cdf(X)(z) + Piecewise((uppergamma(a, b/z)/gamma(a), z > 0), (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse-gamma_distribution + + """ + + return rv(name, GammaInverseDistribution, (a, b)) + + +#------------------------------------------------------------------------------- +# Gumbel distribution (Maximum and Minimum) -------------------------------------------------------- + + +class GumbelDistribution(SingleContinuousDistribution): + _argnames = ('beta', 'mu', 'minimum') + + set = Interval(-oo, oo) + + @staticmethod + def check(beta, mu, minimum): + _value_check(beta > 0, "Scale parameter beta must be positive.") + + def pdf(self, x): + beta, mu = self.beta, self.mu + z = (x - mu)/beta + f_max = (1/beta)*exp(-z - exp(-z)) + f_min = (1/beta)*exp(z - exp(z)) + return Piecewise((f_min, self.minimum), (f_max, not self.minimum)) + + def _cdf(self, x): + beta, mu = self.beta, self.mu + z = (x - mu)/beta + F_max = exp(-exp(-z)) + F_min = 1 - exp(-exp(z)) + return Piecewise((F_min, self.minimum), (F_max, not self.minimum)) + + def _characteristic_function(self, t): + cf_max = gamma(1 - I*self.beta*t) * exp(I*self.mu*t) + cf_min = gamma(1 + I*self.beta*t) * exp(I*self.mu*t) + return Piecewise((cf_min, self.minimum), (cf_max, not self.minimum)) + + def _moment_generating_function(self, t): + mgf_max = gamma(1 - self.beta*t) * exp(self.mu*t) + mgf_min = gamma(1 + self.beta*t) * exp(self.mu*t) + return Piecewise((mgf_min, self.minimum), (mgf_max, not self.minimum)) + +def Gumbel(name, beta, mu, minimum=False): + r""" + Create a Continuous Random Variable with Gumbel distribution. + + Explanation + =========== + + The density of the Gumbel distribution is given by + + For Maximum + + .. math:: + f(x) := \dfrac{1}{\beta} \exp \left( -\dfrac{x-\mu}{\beta} + - \exp \left( -\dfrac{x - \mu}{\beta} \right) \right) + + with :math:`x \in [ - \infty, \infty ]`. + + For Minimum + + .. math:: + f(x) := \frac{e^{- e^{\frac{- \mu + x}{\beta}} + \frac{- \mu + x}{\beta}}}{\beta} + + with :math:`x \in [ - \infty, \infty ]`. + + Parameters + ========== + + mu : Real number, `\mu`, a location + beta : Real number, `\beta > 0`, a scale + minimum : Boolean, by default ``False``, set to ``True`` for enabling minimum distribution + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Gumbel, density, cdf + >>> from sympy import Symbol + >>> x = Symbol("x") + >>> mu = Symbol("mu") + >>> beta = Symbol("beta", positive=True) + >>> X = Gumbel("x", beta, mu) + >>> density(X)(x) + exp(-exp(-(-mu + x)/beta) - (-mu + x)/beta)/beta + >>> cdf(X)(x) + exp(-exp(-(-mu + x)/beta)) + + References + ========== + + .. [1] https://mathworld.wolfram.com/GumbelDistribution.html + .. [2] https://en.wikipedia.org/wiki/Gumbel_distribution + .. [3] https://web.archive.org/web/20200628222206/http://www.mathwave.com/help/easyfit/html/analyses/distributions/gumbel_max.html + .. [4] https://web.archive.org/web/20200628222212/http://www.mathwave.com/help/easyfit/html/analyses/distributions/gumbel_min.html + + """ + return rv(name, GumbelDistribution, (beta, mu, minimum)) + +#------------------------------------------------------------------------------- +# Gompertz distribution -------------------------------------------------------- + +class GompertzDistribution(SingleContinuousDistribution): + _argnames = ('b', 'eta') + + set = Interval(0, oo) + + @staticmethod + def check(b, eta): + _value_check(b > 0, "b must be positive") + _value_check(eta > 0, "eta must be positive") + + def pdf(self, x): + eta, b = self.eta, self.b + return b*eta*exp(b*x)*exp(eta)*exp(-eta*exp(b*x)) + + def _cdf(self, x): + eta, b = self.eta, self.b + return 1 - exp(eta)*exp(-eta*exp(b*x)) + + def _moment_generating_function(self, t): + eta, b = self.eta, self.b + return eta * exp(eta) * expint(t/b, eta) + +def Gompertz(name, b, eta): + r""" + Create a Continuous Random Variable with Gompertz distribution. + + Explanation + =========== + + The density of the Gompertz distribution is given by + + .. math:: + f(x) := b \eta e^{b x} e^{\eta} \exp \left(-\eta e^{bx} \right) + + with :math:`x \in [0, \infty)`. + + Parameters + ========== + + b : Real number, `b > 0`, a scale + eta : Real number, `\eta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Gompertz, density + >>> from sympy import Symbol + + >>> b = Symbol("b", positive=True) + >>> eta = Symbol("eta", positive=True) + >>> z = Symbol("z") + + >>> X = Gompertz("x", b, eta) + + >>> density(X)(z) + b*eta*exp(eta)*exp(b*z)*exp(-eta*exp(b*z)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gompertz_distribution + + """ + return rv(name, GompertzDistribution, (b, eta)) + +#------------------------------------------------------------------------------- +# Kumaraswamy distribution ----------------------------------------------------- + + +class KumaraswamyDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + set = Interval(0, oo) + + @staticmethod + def check(a, b): + _value_check(a > 0, "a must be positive") + _value_check(b > 0, "b must be positive") + + def pdf(self, x): + a, b = self.a, self.b + return a * b * x**(a-1) * (1-x**a)**(b-1) + + def _cdf(self, x): + a, b = self.a, self.b + return Piecewise( + (S.Zero, x < S.Zero), + (1 - (1 - x**a)**b, x <= S.One), + (S.One, True)) + +def Kumaraswamy(name, a, b): + r""" + Create a Continuous Random Variable with a Kumaraswamy distribution. + + Explanation + =========== + + The density of the Kumaraswamy distribution is given by + + .. math:: + f(x) := a b x^{a-1} (1-x^a)^{b-1} + + with :math:`x \in [0,1]`. + + Parameters + ========== + + a : Real number, `a > 0`, a shape + b : Real number, `b > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Kumaraswamy, density, cdf + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a", positive=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Kumaraswamy("x", a, b) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + b - 1 + a - 1 / a\ + a*b*z *\1 - z / + + >>> cdf(X)(z) + Piecewise((0, z < 0), (1 - (1 - z**a)**b, z <= 1), (1, True)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kumaraswamy_distribution + + """ + + return rv(name, KumaraswamyDistribution, (a, b)) + +#------------------------------------------------------------------------------- +# Laplace distribution --------------------------------------------------------- + + +class LaplaceDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'b') + + set = Interval(-oo, oo) + + @staticmethod + def check(mu, b): + _value_check(b > 0, "Scale parameter b must be positive.") + _value_check(mu.is_real, "Location parameter mu should be real") + + def pdf(self, x): + mu, b = self.mu, self.b + return 1/(2*b)*exp(-Abs(x - mu)/b) + + def _cdf(self, x): + mu, b = self.mu, self.b + return Piecewise( + (S.Half*exp((x - mu)/b), x < mu), + (S.One - S.Half*exp(-(x - mu)/b), x >= mu) + ) + + def _characteristic_function(self, t): + return exp(self.mu*I*t) / (1 + self.b**2*t**2) + + def _moment_generating_function(self, t): + return exp(self.mu*t) / (1 - self.b**2*t**2) + +def Laplace(name, mu, b): + r""" + Create a continuous random variable with a Laplace distribution. + + Explanation + =========== + + The density of the Laplace distribution is given by + + .. math:: + f(x) := \frac{1}{2 b} \exp \left(-\frac{|x-\mu|}b \right) + + Parameters + ========== + + mu : Real number or a list/matrix, the location (mean) or the + location vector + b : Real number or a positive definite matrix, representing a scale + or the covariance matrix. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Laplace, density, cdf + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu") + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Laplace("x", mu, b) + + >>> density(X)(z) + exp(-Abs(mu - z)/b)/(2*b) + + >>> cdf(X)(z) + Piecewise((exp((-mu + z)/b)/2, mu > z), (1 - exp((mu - z)/b)/2, True)) + + >>> L = Laplace('L', [1, 2], [[1, 0], [0, 1]]) + >>> pprint(density(L)(1, 2), use_unicode=False) + 5 / ____\ + e *besselk\0, \/ 35 / + --------------------- + pi + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Laplace_distribution + .. [2] https://mathworld.wolfram.com/LaplaceDistribution.html + + """ + + if isinstance(mu, (list, MatrixBase)) and\ + isinstance(b, (list, MatrixBase)): + from sympy.stats.joint_rv_types import MultivariateLaplace + return MultivariateLaplace(name, mu, b) + + return rv(name, LaplaceDistribution, (mu, b)) + +#------------------------------------------------------------------------------- +# Levy distribution --------------------------------------------------------- + + +class LevyDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'c') + + @property + def set(self): + return Interval(self.mu, oo) + + @staticmethod + def check(mu, c): + _value_check(c > 0, "c (scale parameter) must be positive") + _value_check(mu.is_real, "mu (location parameter) must be real") + + def pdf(self, x): + mu, c = self.mu, self.c + return sqrt(c/(2*pi))*exp(-c/(2*(x - mu)))/((x - mu)**(S.One + S.Half)) + + def _cdf(self, x): + mu, c = self.mu, self.c + return erfc(sqrt(c/(2*(x - mu)))) + + def _characteristic_function(self, t): + mu, c = self.mu, self.c + return exp(I * mu * t - sqrt(-2 * I * c * t)) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function of Levy distribution does not exist.') + +def Levy(name, mu, c): + r""" + Create a continuous random variable with a Levy distribution. + + The density of the Levy distribution is given by + + .. math:: + f(x) := \sqrt(\frac{c}{2 \pi}) \frac{\exp -\frac{c}{2 (x - \mu)}}{(x - \mu)^{3/2}} + + Parameters + ========== + + mu : Real number + The location parameter. + c : Real number, `c > 0` + A scale parameter. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Levy, density, cdf + >>> from sympy import Symbol + + >>> mu = Symbol("mu", real=True) + >>> c = Symbol("c", positive=True) + >>> z = Symbol("z") + + >>> X = Levy("x", mu, c) + + >>> density(X)(z) + sqrt(2)*sqrt(c)*exp(-c/(-2*mu + 2*z))/(2*sqrt(pi)*(-mu + z)**(3/2)) + + >>> cdf(X)(z) + erfc(sqrt(c)*sqrt(1/(-2*mu + 2*z))) + + References + ========== + .. [1] https://en.wikipedia.org/wiki/L%C3%A9vy_distribution + .. [2] https://mathworld.wolfram.com/LevyDistribution.html + """ + + return rv(name, LevyDistribution, (mu, c)) + +#------------------------------------------------------------------------------- +# Log-Cauchy distribution -------------------------------------------------------- + + +class LogCauchyDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'sigma') + + set = Interval.open(0, oo) + + @staticmethod + def check(mu, sigma): + _value_check((sigma > 0) != False, "Scale parameter Gamma must be positive.") + _value_check(mu.is_real != False, "Location parameter must be real.") + + def pdf(self, x): + mu, sigma = self.mu, self.sigma + return 1/(x*pi)*(sigma/((log(x) - mu)**2 + sigma**2)) + + def _cdf(self, x): + mu, sigma = self.mu, self.sigma + return (1/pi)*atan((log(x) - mu)/sigma) + S.Half + + def _characteristic_function(self, t): + raise NotImplementedError("The characteristic function for the " + "Log-Cauchy distribution does not exist.") + + def _moment_generating_function(self, t): + raise NotImplementedError("The moment generating function for the " + "Log-Cauchy distribution does not exist.") + +def LogCauchy(name, mu, sigma): + r""" + Create a continuous random variable with a Log-Cauchy distribution. + The density of the Log-Cauchy distribution is given by + + .. math:: + f(x) := \frac{1}{\pi x} \frac{\sigma}{(log(x)-\mu^2) + \sigma^2} + + Parameters + ========== + + mu : Real number, the location + + sigma : Real number, `\sigma > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogCauchy, density, cdf + >>> from sympy import Symbol, S + + >>> mu = 2 + >>> sigma = S.One / 5 + >>> z = Symbol("z") + + >>> X = LogCauchy("x", mu, sigma) + + >>> density(X)(z) + 1/(5*pi*z*((log(z) - 2)**2 + 1/25)) + + >>> cdf(X)(z) + atan(5*log(z) - 10)/pi + 1/2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Log-Cauchy_distribution + """ + + return rv(name, LogCauchyDistribution, (mu, sigma)) + + +#------------------------------------------------------------------------------- +# Logistic distribution -------------------------------------------------------- + + +class LogisticDistribution(SingleContinuousDistribution): + _argnames = ('mu', 's') + + set = Interval(-oo, oo) + + @staticmethod + def check(mu, s): + _value_check(s > 0, "Scale parameter s must be positive.") + + def pdf(self, x): + mu, s = self.mu, self.s + return exp(-(x - mu)/s)/(s*(1 + exp(-(x - mu)/s))**2) + + def _cdf(self, x): + mu, s = self.mu, self.s + return S.One/(1 + exp(-(x - mu)/s)) + + def _characteristic_function(self, t): + return Piecewise((exp(I*t*self.mu) * pi*self.s*t / sinh(pi*self.s*t), Ne(t, 0)), (S.One, True)) + + def _moment_generating_function(self, t): + return exp(self.mu*t) * beta_fn(1 - self.s*t, 1 + self.s*t) + + def _quantile(self, p): + return self.mu - self.s*log(-S.One + S.One/p) + +def Logistic(name, mu, s): + r""" + Create a continuous random variable with a logistic distribution. + + Explanation + =========== + + The density of the logistic distribution is given by + + .. math:: + f(x) := \frac{e^{-(x-\mu)/s}} {s\left(1+e^{-(x-\mu)/s}\right)^2} + + Parameters + ========== + + mu : Real number, the location (mean) + s : Real number, `s > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Logistic, density, cdf + >>> from sympy import Symbol + + >>> mu = Symbol("mu", real=True) + >>> s = Symbol("s", positive=True) + >>> z = Symbol("z") + + >>> X = Logistic("x", mu, s) + + >>> density(X)(z) + exp((mu - z)/s)/(s*(exp((mu - z)/s) + 1)**2) + + >>> cdf(X)(z) + 1/(exp((mu - z)/s) + 1) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logistic_distribution + .. [2] https://mathworld.wolfram.com/LogisticDistribution.html + + """ + + return rv(name, LogisticDistribution, (mu, s)) + +#------------------------------------------------------------------------------- +# Log-logistic distribution -------------------------------------------------------- + + +class LogLogisticDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + set = Interval(0, oo) + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Scale parameter Alpha must be positive.") + _value_check(beta > 0, "Shape parameter Beta must be positive.") + + def pdf(self, x): + a, b = self.alpha, self.beta + return ((b/a)*(x/a)**(b - 1))/(1 + (x/a)**b)**2 + + def _cdf(self, x): + a, b = self.alpha, self.beta + return 1/(1 + (x/a)**(-b)) + + def _quantile(self, p): + a, b = self.alpha, self.beta + return a*((p/(1 - p))**(1/b)) + + def expectation(self, expr, var, **kwargs): + a, b = self.args + return Piecewise((S.NaN, b <= 1), (pi*a/(b*sin(pi/b)), True)) + +def LogLogistic(name, alpha, beta): + r""" + Create a continuous random variable with a log-logistic distribution. + The distribution is unimodal when ``beta > 1``. + + Explanation + =========== + + The density of the log-logistic distribution is given by + + .. math:: + f(x) := \frac{(\frac{\beta}{\alpha})(\frac{x}{\alpha})^{\beta - 1}} + {(1 + (\frac{x}{\alpha})^{\beta})^2} + + Parameters + ========== + + alpha : Real number, `\alpha > 0`, scale parameter and median of distribution + beta : Real number, `\beta > 0`, a shape parameter + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogLogistic, density, cdf, quantile + >>> from sympy import Symbol, pprint + + >>> alpha = Symbol("alpha", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> p = Symbol("p") + >>> z = Symbol("z", positive=True) + + >>> X = LogLogistic("x", alpha, beta) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + beta - 1 + / z \ + beta*|-----| + \alpha/ + ------------------------ + 2 + / beta \ + |/ z \ | + alpha*||-----| + 1| + \\alpha/ / + + >>> cdf(X)(z) + 1/(1 + (z/alpha)**(-beta)) + + >>> quantile(X)(p) + alpha*(p/(1 - p))**(1/beta) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Log-logistic_distribution + + """ + + return rv(name, LogLogisticDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +#Logit-Normal distribution------------------------------------------------------ + +class LogitNormalDistribution(SingleContinuousDistribution): + _argnames = ('mu', 's') + set = Interval.open(0, 1) + + @staticmethod + def check(mu, s): + _value_check((s ** 2).is_real is not False and s ** 2 > 0, "Squared scale parameter s must be positive.") + _value_check(mu.is_real is not False, "Location parameter must be real") + + def _logit(self, x): + return log(x / (1 - x)) + + def pdf(self, x): + mu, s = self.mu, self.s + return exp(-(self._logit(x) - mu)**2/(2*s**2))*(S.One/sqrt(2*pi*(s**2)))*(1/(x*(1 - x))) + + def _cdf(self, x): + mu, s = self.mu, self.s + return (S.One/2)*(1 + erf((self._logit(x) - mu)/(sqrt(2*s**2)))) + + +def LogitNormal(name, mu, s): + r""" + Create a continuous random variable with a Logit-Normal distribution. + + The density of the logistic distribution is given by + + .. math:: + f(x) := \frac{1}{s \sqrt{2 \pi}} \frac{1}{x(1 - x)} e^{- \frac{(logit(x) - \mu)^2}{s^2}} + where logit(x) = \log(\frac{x}{1 - x}) + Parameters + ========== + + mu : Real number, the location (mean) + s : Real number, `s > 0`, a scale + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogitNormal, density, cdf + >>> from sympy import Symbol,pprint + + >>> mu = Symbol("mu", real=True) + >>> s = Symbol("s", positive=True) + >>> z = Symbol("z") + >>> X = LogitNormal("x",mu,s) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + / / z \\ + -|-mu + log|-----|| + \ \1 - z// + --------------------- + 2 + ___ 2*s + \/ 2 *e + ---------------------------- + ____ + 2*\/ pi *s*z*(1 - z) + + >>> density(X)(z) + sqrt(2)*exp(-(-mu + log(z/(1 - z)))**2/(2*s**2))/(2*sqrt(pi)*s*z*(1 - z)) + + >>> cdf(X)(z) + erf(sqrt(2)*(-mu + log(z/(1 - z)))/(2*s))/2 + 1/2 + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logit-normal_distribution + + """ + + return rv(name, LogitNormalDistribution, (mu, s)) + +#------------------------------------------------------------------------------- +# Log Normal distribution ------------------------------------------------------ + + +class LogNormalDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'std') + + set = Interval(0, oo) + + @staticmethod + def check(mean, std): + _value_check(std > 0, "Parameter std must be positive.") + + def pdf(self, x): + mean, std = self.mean, self.std + return exp(-(log(x) - mean)**2 / (2*std**2)) / (x*sqrt(2*pi)*std) + + def _cdf(self, x): + mean, std = self.mean, self.std + return Piecewise( + (S.Half + S.Half*erf((log(x) - mean)/sqrt(2)/std), x > 0), + (S.Zero, True) + ) + + def _moment_generating_function(self, t): + raise NotImplementedError('Moment generating function of the log-normal distribution is not defined.') + + +def LogNormal(name, mean, std): + r""" + Create a continuous random variable with a log-normal distribution. + + Explanation + =========== + + The density of the log-normal distribution is given by + + .. math:: + f(x) := \frac{1}{x\sqrt{2\pi\sigma^2}} + e^{-\frac{\left(\ln x-\mu\right)^2}{2\sigma^2}} + + with :math:`x \geq 0`. + + Parameters + ========== + + mu : Real number + The log-scale. + sigma : Real number + A shape. ($\sigma^2 > 0$) + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import LogNormal, density + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu", real=True) + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + + >>> X = LogNormal("x", mu, sigma) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + -(-mu + log(z)) + ----------------- + 2 + ___ 2*sigma + \/ 2 *e + ------------------------ + ____ + 2*\/ pi *sigma*z + + + >>> X = LogNormal('x', 0, 1) # Mean 0, standard deviation 1 + + >>> density(X)(z) + sqrt(2)*exp(-log(z)**2/2)/(2*sqrt(pi)*z) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lognormal + .. [2] https://mathworld.wolfram.com/LogNormalDistribution.html + + """ + + return rv(name, LogNormalDistribution, (mean, std)) + +#------------------------------------------------------------------------------- +# Lomax Distribution ----------------------------------------------------------- + +class LomaxDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'lamda',) + set = Interval(0, oo) + + @staticmethod + def check(alpha, lamda): + _value_check(alpha.is_real, "Shape parameter should be real.") + _value_check(lamda.is_real, "Scale parameter should be real.") + _value_check(alpha.is_positive, "Shape parameter should be positive.") + _value_check(lamda.is_positive, "Scale parameter should be positive.") + + def pdf(self, x): + lamba, alpha = self.lamda, self.alpha + return (alpha/lamba) * (S.One + x/lamba)**(-alpha-1) + +def Lomax(name, alpha, lamda): + r""" + Create a continuous random variable with a Lomax distribution. + + Explanation + =========== + + The density of the Lomax distribution is given by + + .. math:: + f(x) := \frac{\alpha}{\lambda}\left[1+\frac{x}{\lambda}\right]^{-(\alpha+1)} + + Parameters + ========== + + alpha : Real Number, `\alpha > 0` + Shape parameter + lamda : Real Number, `\lambda > 0` + Scale parameter + + Examples + ======== + + >>> from sympy.stats import Lomax, density, cdf, E + >>> from sympy import symbols + >>> a, l = symbols('a, l', positive=True) + >>> X = Lomax('X', a, l) + >>> x = symbols('x') + >>> density(X)(x) + a*(1 + x/l)**(-a - 1)/l + >>> cdf(X)(x) + Piecewise((1 - 1/(1 + x/l)**a, x >= 0), (0, True)) + >>> a = 2 + >>> X = Lomax('X', a, l) + >>> E(X) + l + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lomax_distribution + + """ + return rv(name, LomaxDistribution, (alpha, lamda)) + +#------------------------------------------------------------------------------- +# Maxwell distribution --------------------------------------------------------- + + +class MaxwellDistribution(SingleContinuousDistribution): + _argnames = ('a',) + + set = Interval(0, oo) + + @staticmethod + def check(a): + _value_check(a > 0, "Parameter a must be positive.") + + def pdf(self, x): + a = self.a + return sqrt(2/pi)*x**2*exp(-x**2/(2*a**2))/a**3 + + def _cdf(self, x): + a = self.a + return erf(sqrt(2)*x/(2*a)) - sqrt(2)*x*exp(-x**2/(2*a**2))/(sqrt(pi)*a) + +def Maxwell(name, a): + r""" + Create a continuous random variable with a Maxwell distribution. + + Explanation + =========== + + The density of the Maxwell distribution is given by + + .. math:: + f(x) := \sqrt{\frac{2}{\pi}} \frac{x^2 e^{-x^2/(2a^2)}}{a^3} + + with :math:`x \geq 0`. + + .. TODO - what does the parameter mean? + + Parameters + ========== + + a : Real number, `a > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Maxwell, density, E, variance + >>> from sympy import Symbol, simplify + + >>> a = Symbol("a", positive=True) + >>> z = Symbol("z") + + >>> X = Maxwell("x", a) + + >>> density(X)(z) + sqrt(2)*z**2*exp(-z**2/(2*a**2))/(sqrt(pi)*a**3) + + >>> E(X) + 2*sqrt(2)*a/sqrt(pi) + + >>> simplify(variance(X)) + a**2*(-8 + 3*pi)/pi + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Maxwell_distribution + .. [2] https://mathworld.wolfram.com/MaxwellDistribution.html + + """ + + return rv(name, MaxwellDistribution, (a, )) + +#------------------------------------------------------------------------------- +# Moyal Distribution ----------------------------------------------------------- +class MoyalDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'sigma') + + @staticmethod + def check(mu, sigma): + _value_check(mu.is_real, "Location parameter must be real.") + _value_check(sigma.is_real and sigma > 0, "Scale parameter must be real\ + and positive.") + + def pdf(self, x): + mu, sigma = self.mu, self.sigma + num = exp(-(exp(-(x - mu)/sigma) + (x - mu)/(sigma))/2) + den = (sqrt(2*pi) * sigma) + return num/den + + def _characteristic_function(self, t): + mu, sigma = self.mu, self.sigma + term1 = exp(I*t*mu) + term2 = (2**(-I*sigma*t) * gamma(Rational(1, 2) - I*t*sigma)) + return (term1 * term2)/sqrt(pi) + + def _moment_generating_function(self, t): + mu, sigma = self.mu, self.sigma + term1 = exp(t*mu) + term2 = (2**(-1*sigma*t) * gamma(Rational(1, 2) - t*sigma)) + return (term1 * term2)/sqrt(pi) + +def Moyal(name, mu, sigma): + r""" + Create a continuous random variable with a Moyal distribution. + + Explanation + =========== + + The density of the Moyal distribution is given by + + .. math:: + f(x) := \frac{\exp-\frac{1}{2}\exp-\frac{x-\mu}{\sigma}-\frac{x-\mu}{2\sigma}}{\sqrt{2\pi}\sigma} + + with :math:`x \in \mathbb{R}`. + + Parameters + ========== + + mu : Real number + Location parameter + sigma : Real positive number + Scale parameter + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Moyal, density, cdf + >>> from sympy import Symbol, simplify + >>> mu = Symbol("mu", real=True) + >>> sigma = Symbol("sigma", positive=True, real=True) + >>> z = Symbol("z") + >>> X = Moyal("x", mu, sigma) + >>> density(X)(z) + sqrt(2)*exp(-exp((mu - z)/sigma)/2 - (-mu + z)/(2*sigma))/(2*sqrt(pi)*sigma) + >>> simplify(cdf(X)(z)) + 1 - erf(sqrt(2)*exp((mu - z)/(2*sigma))/2) + + References + ========== + + .. [1] https://reference.wolfram.com/language/ref/MoyalDistribution.html + .. [2] https://www.stat.rice.edu/~dobelman/textfiles/DistributionsHandbook.pdf + + """ + + return rv(name, MoyalDistribution, (mu, sigma)) + +#------------------------------------------------------------------------------- +# Nakagami distribution -------------------------------------------------------- + + +class NakagamiDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'omega') + + set = Interval(0, oo) + + @staticmethod + def check(mu, omega): + _value_check(mu >= S.Half, "Shape parameter mu must be greater than equal to 1/2.") + _value_check(omega > 0, "Spread parameter omega must be positive.") + + def pdf(self, x): + mu, omega = self.mu, self.omega + return 2*mu**mu/(gamma(mu)*omega**mu)*x**(2*mu - 1)*exp(-mu/omega*x**2) + + def _cdf(self, x): + mu, omega = self.mu, self.omega + return Piecewise( + (lowergamma(mu, (mu/omega)*x**2)/gamma(mu), x > 0), + (S.Zero, True)) + +def Nakagami(name, mu, omega): + r""" + Create a continuous random variable with a Nakagami distribution. + + Explanation + =========== + + The density of the Nakagami distribution is given by + + .. math:: + f(x) := \frac{2\mu^\mu}{\Gamma(\mu)\omega^\mu} x^{2\mu-1} + \exp\left(-\frac{\mu}{\omega}x^2 \right) + + with :math:`x > 0`. + + Parameters + ========== + + mu : Real number, `\mu \geq \frac{1}{2}`, a shape + omega : Real number, `\omega > 0`, the spread + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Nakagami, density, E, variance, cdf + >>> from sympy import Symbol, simplify, pprint + + >>> mu = Symbol("mu", positive=True) + >>> omega = Symbol("omega", positive=True) + >>> z = Symbol("z") + + >>> X = Nakagami("x", mu, omega) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + -mu*z + ------- + mu -mu 2*mu - 1 omega + 2*mu *omega *z *e + ---------------------------------- + Gamma(mu) + + >>> simplify(E(X)) + sqrt(mu)*sqrt(omega)*gamma(mu + 1/2)/gamma(mu + 1) + + >>> V = simplify(variance(X)) + >>> pprint(V, use_unicode=False) + 2 + omega*Gamma (mu + 1/2) + omega - ----------------------- + Gamma(mu)*Gamma(mu + 1) + + >>> cdf(X)(z) + Piecewise((lowergamma(mu, mu*z**2/omega)/gamma(mu), z > 0), + (0, True)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Nakagami_distribution + + """ + + return rv(name, NakagamiDistribution, (mu, omega)) + +#------------------------------------------------------------------------------- +# Normal distribution ---------------------------------------------------------- + + +class NormalDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'std') + + @staticmethod + def check(mean, std): + _value_check(std > 0, "Standard deviation must be positive") + + def pdf(self, x): + return exp(-(x - self.mean)**2 / (2*self.std**2)) / (sqrt(2*pi)*self.std) + + def _cdf(self, x): + mean, std = self.mean, self.std + return erf(sqrt(2)*(-mean + x)/(2*std))/2 + S.Half + + def _characteristic_function(self, t): + mean, std = self.mean, self.std + return exp(I*mean*t - std**2*t**2/2) + + def _moment_generating_function(self, t): + mean, std = self.mean, self.std + return exp(mean*t + std**2*t**2/2) + + def _quantile(self, p): + mean, std = self.mean, self.std + return mean + std*sqrt(2)*erfinv(2*p - 1) + + +def Normal(name, mean, std): + r""" + Create a continuous random variable with a Normal distribution. + + Explanation + =========== + + The density of the Normal distribution is given by + + .. math:: + f(x) := \frac{1}{\sigma\sqrt{2\pi}} e^{ -\frac{(x-\mu)^2}{2\sigma^2} } + + Parameters + ========== + + mu : Real number or a list representing the mean or the mean vector + sigma : Real number or a positive definite square matrix, + :math:`\sigma^2 > 0`, the variance + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Normal, density, E, std, cdf, skewness, quantile, marginal_distribution + >>> from sympy import Symbol, simplify, pprint + + >>> mu = Symbol("mu") + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + >>> y = Symbol("y") + >>> p = Symbol("p") + >>> X = Normal("x", mu, sigma) + + >>> density(X)(z) + sqrt(2)*exp(-(-mu + z)**2/(2*sigma**2))/(2*sqrt(pi)*sigma) + + >>> C = simplify(cdf(X))(z) # it needs a little more help... + >>> pprint(C, use_unicode=False) + / ___ \ + |\/ 2 *(-mu + z)| + erf|---------------| + \ 2*sigma / 1 + -------------------- + - + 2 2 + + >>> quantile(X)(p) + mu + sqrt(2)*sigma*erfinv(2*p - 1) + + >>> simplify(skewness(X)) + 0 + + >>> X = Normal("x", 0, 1) # Mean 0, standard deviation 1 + >>> density(X)(z) + sqrt(2)*exp(-z**2/2)/(2*sqrt(pi)) + + >>> E(2*X + 1) + 1 + + >>> simplify(std(2*X + 1)) + 2 + + >>> m = Normal('X', [1, 2], [[2, 1], [1, 2]]) + >>> pprint(density(m)(y, z), use_unicode=False) + 2 2 + y y*z z + - -- + --- - -- + z - 1 + ___ 3 3 3 + \/ 3 *e + ------------------------------ + 6*pi + + >>> marginal_distribution(m, m[0])(1) + 1/(2*sqrt(pi)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Normal_distribution + .. [2] https://mathworld.wolfram.com/NormalDistributionFunction.html + + """ + + if isinstance(mean, list) or getattr(mean, 'is_Matrix', False) and\ + isinstance(std, list) or getattr(std, 'is_Matrix', False): + from sympy.stats.joint_rv_types import MultivariateNormal + return MultivariateNormal(name, mean, std) + return rv(name, NormalDistribution, (mean, std)) + + +#------------------------------------------------------------------------------- +# Inverse Gaussian distribution ---------------------------------------------------------- + + +class GaussianInverseDistribution(SingleContinuousDistribution): + _argnames = ('mean', 'shape') + + @property + def set(self): + return Interval(0, oo) + + @staticmethod + def check(mean, shape): + _value_check(shape > 0, "Shape parameter must be positive") + _value_check(mean > 0, "Mean must be positive") + + def pdf(self, x): + mu, s = self.mean, self.shape + return exp(-s*(x - mu)**2 / (2*x*mu**2)) * sqrt(s/(2*pi*x**3)) + + def _cdf(self, x): + from sympy.stats import cdf + mu, s = self.mean, self.shape + stdNormalcdf = cdf(Normal('x', 0, 1)) + + first_term = stdNormalcdf(sqrt(s/x) * ((x/mu) - S.One)) + second_term = exp(2*s/mu) * stdNormalcdf(-sqrt(s/x)*(x/mu + S.One)) + + return first_term + second_term + + def _characteristic_function(self, t): + mu, s = self.mean, self.shape + return exp((s/mu)*(1 - sqrt(1 - (2*mu**2*I*t)/s))) + + def _moment_generating_function(self, t): + mu, s = self.mean, self.shape + return exp((s/mu)*(1 - sqrt(1 - (2*mu**2*t)/s))) + + +def GaussianInverse(name, mean, shape): + r""" + Create a continuous random variable with an Inverse Gaussian distribution. + Inverse Gaussian distribution is also known as Wald distribution. + + Explanation + =========== + + The density of the Inverse Gaussian distribution is given by + + .. math:: + f(x) := \sqrt{\frac{\lambda}{2\pi x^3}} e^{-\frac{\lambda(x-\mu)^2}{2x\mu^2}} + + Parameters + ========== + + mu : + Positive number representing the mean. + lambda : + Positive number representing the shape parameter. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import GaussianInverse, density, E, std, skewness + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu", positive=True) + >>> lamda = Symbol("lambda", positive=True) + >>> z = Symbol("z", positive=True) + >>> X = GaussianInverse("x", mu, lamda) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + 2 + -lambda*(-mu + z) + ------------------- + 2 + ___ ________ 2*mu *z + \/ 2 *\/ lambda *e + ------------------------------------- + ____ 3/2 + 2*\/ pi *z + + >>> E(X) + mu + + >>> std(X).expand() + mu**(3/2)/sqrt(lambda) + + >>> skewness(X).expand() + 3*sqrt(mu)/sqrt(lambda) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution + .. [2] https://mathworld.wolfram.com/InverseGaussianDistribution.html + + """ + + return rv(name, GaussianInverseDistribution, (mean, shape)) + +Wald = GaussianInverse + +#------------------------------------------------------------------------------- +# Pareto distribution ---------------------------------------------------------- + + +class ParetoDistribution(SingleContinuousDistribution): + _argnames = ('xm', 'alpha') + + @property + def set(self): + return Interval(self.xm, oo) + + @staticmethod + def check(xm, alpha): + _value_check(xm > 0, "Xm must be positive") + _value_check(alpha > 0, "Alpha must be positive") + + def pdf(self, x): + xm, alpha = self.xm, self.alpha + return alpha * xm**alpha / x**(alpha + 1) + + def _cdf(self, x): + xm, alpha = self.xm, self.alpha + return Piecewise( + (S.One - xm**alpha/x**alpha, x>=xm), + (0, True), + ) + + def _moment_generating_function(self, t): + xm, alpha = self.xm, self.alpha + return alpha * (-xm*t)**alpha * uppergamma(-alpha, -xm*t) + + def _characteristic_function(self, t): + xm, alpha = self.xm, self.alpha + return alpha * (-I * xm * t) ** alpha * uppergamma(-alpha, -I * xm * t) + + +def Pareto(name, xm, alpha): + r""" + Create a continuous random variable with the Pareto distribution. + + Explanation + =========== + + The density of the Pareto distribution is given by + + .. math:: + f(x) := \frac{\alpha\,x_m^\alpha}{x^{\alpha+1}} + + with :math:`x \in [x_m,\infty]`. + + Parameters + ========== + + xm : Real number, `x_m > 0`, a scale + alpha : Real number, `\alpha > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Pareto, density + >>> from sympy import Symbol + + >>> xm = Symbol("xm", positive=True) + >>> beta = Symbol("beta", positive=True) + >>> z = Symbol("z") + + >>> X = Pareto("x", xm, beta) + + >>> density(X)(z) + beta*xm**beta*z**(-beta - 1) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pareto_distribution + .. [2] https://mathworld.wolfram.com/ParetoDistribution.html + + """ + + return rv(name, ParetoDistribution, (xm, alpha)) + +#------------------------------------------------------------------------------- +# PowerFunction distribution --------------------------------------------------- + + +class PowerFunctionDistribution(SingleContinuousDistribution): + _argnames=('alpha','a','b') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(alpha, a, b): + _value_check(a.is_real, "Continuous Boundary parameter should be real.") + _value_check(b.is_real, "Continuous Boundary parameter should be real.") + _value_check(a < b, " 'a' the left Boundary must be smaller than 'b' the right Boundary." ) + _value_check(alpha.is_positive, "Continuous Shape parameter should be positive.") + + def pdf(self, x): + alpha, a, b = self.alpha, self.a, self.b + num = alpha*(x - a)**(alpha - 1) + den = (b - a)**alpha + return num/den + +def PowerFunction(name, alpha, a, b): + r""" + Creates a continuous random variable with a Power Function Distribution. + + Explanation + =========== + + The density of PowerFunction distribution is given by + + .. math:: + f(x) := \frac{{\alpha}(x - a)^{\alpha - 1}}{(b - a)^{\alpha}} + + with :math:`x \in [a,b]`. + + Parameters + ========== + + alpha : Positive number, `0 < \alpha`, the shape parameter + a : Real number, :math:`-\infty < a`, the left boundary + b : Real number, :math:`a < b < \infty`, the right boundary + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import PowerFunction, density, cdf, E, variance + >>> from sympy import Symbol + >>> alpha = Symbol("alpha", positive=True) + >>> a = Symbol("a", real=True) + >>> b = Symbol("b", real=True) + >>> z = Symbol("z") + + >>> X = PowerFunction("X", 2, a, b) + + >>> density(X)(z) + (-2*a + 2*z)/(-a + b)**2 + + >>> cdf(X)(z) + Piecewise((a**2/(a**2 - 2*a*b + b**2) - 2*a*z/(a**2 - 2*a*b + b**2) + + z**2/(a**2 - 2*a*b + b**2), a <= z), (0, True)) + + >>> alpha = 2 + >>> a = 0 + >>> b = 1 + >>> Y = PowerFunction("Y", alpha, a, b) + + >>> E(Y) + 2/3 + + >>> variance(Y) + 1/18 + + References + ========== + + .. [1] https://web.archive.org/web/20200204081320/http://www.mathwave.com/help/easyfit/html/analyses/distributions/power_func.html + + """ + return rv(name, PowerFunctionDistribution, (alpha, a, b)) + +#------------------------------------------------------------------------------- +# QuadraticU distribution ------------------------------------------------------ + + +class QuadraticUDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(a, b): + _value_check(b > a, "Parameter b must be in range (%s, oo)."%(a)) + + def pdf(self, x): + a, b = self.a, self.b + alpha = 12 / (b-a)**3 + beta = (a+b) / 2 + return Piecewise( + (alpha * (x-beta)**2, And(a<=x, x<=b)), + (S.Zero, True)) + + def _moment_generating_function(self, t): + a, b = self.a, self.b + return -3 * (exp(a*t) * (4 + (a**2 + 2*a*(-2 + b) + b**2) * t) \ + - exp(b*t) * (4 + (-4*b + (a + b)**2) * t)) / ((a-b)**3 * t**2) + + def _characteristic_function(self, t): + a, b = self.a, self.b + return -3*I*(exp(I*a*t*exp(I*b*t)) * (4*I - (-4*b + (a+b)**2)*t)) \ + / ((a-b)**3 * t**2) + + +def QuadraticU(name, a, b): + r""" + Create a Continuous Random Variable with a U-quadratic distribution. + + Explanation + =========== + + The density of the U-quadratic distribution is given by + + .. math:: + f(x) := \alpha (x-\beta)^2 + + with :math:`x \in [a,b]`. + + Parameters + ========== + + a : Real number + b : Real number, :math:`a < b` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import QuadraticU, density + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a", real=True) + >>> b = Symbol("b", real=True) + >>> z = Symbol("z") + + >>> X = QuadraticU("x", a, b) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + / 2 + | / a b \ + |12*|- - - - + z| + | \ 2 2 / + <----------------- for And(b >= z, a <= z) + | 3 + | (-a + b) + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/U-quadratic_distribution + + """ + + return rv(name, QuadraticUDistribution, (a, b)) + +#------------------------------------------------------------------------------- +# RaisedCosine distribution ---------------------------------------------------- + + +class RaisedCosineDistribution(SingleContinuousDistribution): + _argnames = ('mu', 's') + + @property + def set(self): + return Interval(self.mu - self.s, self.mu + self.s) + + @staticmethod + def check(mu, s): + _value_check(s > 0, "s must be positive") + + def pdf(self, x): + mu, s = self.mu, self.s + return Piecewise( + ((1+cos(pi*(x-mu)/s)) / (2*s), And(mu-s<=x, x<=mu+s)), + (S.Zero, True)) + + def _characteristic_function(self, t): + mu, s = self.mu, self.s + return Piecewise((exp(-I*pi*mu/s)/2, Eq(t, -pi/s)), + (exp(I*pi*mu/s)/2, Eq(t, pi/s)), + (pi**2*sin(s*t)*exp(I*mu*t) / (s*t*(pi**2 - s**2*t**2)), True)) + + def _moment_generating_function(self, t): + mu, s = self.mu, self.s + return pi**2 * sinh(s*t) * exp(mu*t) / (s*t*(pi**2 + s**2*t**2)) + +def RaisedCosine(name, mu, s): + r""" + Create a Continuous Random Variable with a raised cosine distribution. + + Explanation + =========== + + The density of the raised cosine distribution is given by + + .. math:: + f(x) := \frac{1}{2s}\left(1+\cos\left(\frac{x-\mu}{s}\pi\right)\right) + + with :math:`x \in [\mu-s,\mu+s]`. + + Parameters + ========== + + mu : Real number + s : Real number, `s > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import RaisedCosine, density + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu", real=True) + >>> s = Symbol("s", positive=True) + >>> z = Symbol("z") + + >>> X = RaisedCosine("x", mu, s) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + / /pi*(-mu + z)\ + |cos|------------| + 1 + | \ s / + <--------------------- for And(z >= mu - s, z <= mu + s) + | 2*s + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Raised_cosine_distribution + + """ + + return rv(name, RaisedCosineDistribution, (mu, s)) + +#------------------------------------------------------------------------------- +# Rayleigh distribution -------------------------------------------------------- + + +class RayleighDistribution(SingleContinuousDistribution): + _argnames = ('sigma',) + + set = Interval(0, oo) + + @staticmethod + def check(sigma): + _value_check(sigma > 0, "Scale parameter sigma must be positive.") + + def pdf(self, x): + sigma = self.sigma + return x/sigma**2*exp(-x**2/(2*sigma**2)) + + def _cdf(self, x): + sigma = self.sigma + return 1 - exp(-(x**2/(2*sigma**2))) + + def _characteristic_function(self, t): + sigma = self.sigma + return 1 - sigma*t*exp(-sigma**2*t**2/2) * sqrt(pi/2) * (erfi(sigma*t/sqrt(2)) - I) + + def _moment_generating_function(self, t): + sigma = self.sigma + return 1 + sigma*t*exp(sigma**2*t**2/2) * sqrt(pi/2) * (erf(sigma*t/sqrt(2)) + 1) + + +def Rayleigh(name, sigma): + r""" + Create a continuous random variable with a Rayleigh distribution. + + Explanation + =========== + + The density of the Rayleigh distribution is given by + + .. math :: + f(x) := \frac{x}{\sigma^2} e^{-x^2/2\sigma^2} + + with :math:`x > 0`. + + Parameters + ========== + + sigma : Real number, `\sigma > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Rayleigh, density, E, variance + >>> from sympy import Symbol + + >>> sigma = Symbol("sigma", positive=True) + >>> z = Symbol("z") + + >>> X = Rayleigh("x", sigma) + + >>> density(X)(z) + z*exp(-z**2/(2*sigma**2))/sigma**2 + + >>> E(X) + sqrt(2)*sqrt(pi)*sigma/2 + + >>> variance(X) + -pi*sigma**2/2 + 2*sigma**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rayleigh_distribution + .. [2] https://mathworld.wolfram.com/RayleighDistribution.html + + """ + + return rv(name, RayleighDistribution, (sigma, )) + +#------------------------------------------------------------------------------- +# Reciprocal distribution -------------------------------------------------------- + +class ReciprocalDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(a, b): + _value_check(a > 0, "Parameter > 0. a = %s"%a) + _value_check((a < b), + "Parameter b must be in range (%s, +oo]. b = %s"%(a, b)) + + def pdf(self, x): + a, b = self.a, self.b + return 1/(x*(log(b) - log(a))) + + +def Reciprocal(name, a, b): + r"""Creates a continuous random variable with a reciprocal distribution. + + + Parameters + ========== + + a : Real number, :math:`0 < a` + b : Real number, :math:`a < b` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Reciprocal, density, cdf + >>> from sympy import symbols + >>> a, b, x = symbols('a, b, x', positive=True) + >>> R = Reciprocal('R', a, b) + + >>> density(R)(x) + 1/(x*(-log(a) + log(b))) + >>> cdf(R)(x) + Piecewise((log(a)/(log(a) - log(b)) - log(x)/(log(a) - log(b)), a <= x), (0, True)) + + Reference + ========= + + .. [1] https://en.wikipedia.org/wiki/Reciprocal_distribution + + """ + return rv(name, ReciprocalDistribution, (a, b)) + + +#------------------------------------------------------------------------------- +# Shifted Gompertz distribution ------------------------------------------------ + + +class ShiftedGompertzDistribution(SingleContinuousDistribution): + _argnames = ('b', 'eta') + + set = Interval(0, oo) + + @staticmethod + def check(b, eta): + _value_check(b > 0, "b must be positive") + _value_check(eta > 0, "eta must be positive") + + def pdf(self, x): + b, eta = self.b, self.eta + return b*exp(-b*x)*exp(-eta*exp(-b*x))*(1+eta*(1-exp(-b*x))) + +def ShiftedGompertz(name, b, eta): + r""" + Create a continuous random variable with a Shifted Gompertz distribution. + + Explanation + =========== + + The density of the Shifted Gompertz distribution is given by + + .. math:: + f(x) := b e^{-b x} e^{-\eta \exp(-b x)} \left[1 + \eta(1 - e^(-bx)) \right] + + with :math:`x \in [0, \infty)`. + + Parameters + ========== + + b : Real number, `b > 0`, a scale + eta : Real number, `\eta > 0`, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + >>> from sympy.stats import ShiftedGompertz, density + >>> from sympy import Symbol + + >>> b = Symbol("b", positive=True) + >>> eta = Symbol("eta", positive=True) + >>> x = Symbol("x") + + >>> X = ShiftedGompertz("x", b, eta) + + >>> density(X)(x) + b*(eta*(1 - exp(-b*x)) + 1)*exp(-b*x)*exp(-eta*exp(-b*x)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Shifted_Gompertz_distribution + + """ + return rv(name, ShiftedGompertzDistribution, (b, eta)) + +#------------------------------------------------------------------------------- +# StudentT distribution -------------------------------------------------------- + + +class StudentTDistribution(SingleContinuousDistribution): + _argnames = ('nu',) + + set = Interval(-oo, oo) + + @staticmethod + def check(nu): + _value_check(nu > 0, "Degrees of freedom nu must be positive.") + + def pdf(self, x): + nu = self.nu + return 1/(sqrt(nu)*beta_fn(S.Half, nu/2))*(1 + x**2/nu)**(-(nu + 1)/2) + + def _cdf(self, x): + nu = self.nu + return S.Half + x*gamma((nu+1)/2)*hyper((S.Half, (nu+1)/2), + (Rational(3, 2),), -x**2/nu)/(sqrt(pi*nu)*gamma(nu/2)) + + def _moment_generating_function(self, t): + raise NotImplementedError('The moment generating function for the Student-T distribution is undefined.') + + +def StudentT(name, nu): + r""" + Create a continuous random variable with a student's t distribution. + + Explanation + =========== + + The density of the student's t distribution is given by + + .. math:: + f(x) := \frac{\Gamma \left(\frac{\nu+1}{2} \right)} + {\sqrt{\nu\pi}\Gamma \left(\frac{\nu}{2} \right)} + \left(1+\frac{x^2}{\nu} \right)^{-\frac{\nu+1}{2}} + + Parameters + ========== + + nu : Real number, `\nu > 0`, the degrees of freedom + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import StudentT, density, cdf + >>> from sympy import Symbol, pprint + + >>> nu = Symbol("nu", positive=True) + >>> z = Symbol("z") + + >>> X = StudentT("x", nu) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + nu 1 + - -- - - + 2 2 + / 2\ + | z | + |1 + --| + \ nu/ + ----------------- + ____ / nu\ + \/ nu *B|1/2, --| + \ 2 / + + >>> cdf(X)(z) + 1/2 + z*gamma(nu/2 + 1/2)*hyper((1/2, nu/2 + 1/2), (3/2,), + -z**2/nu)/(sqrt(pi)*sqrt(nu)*gamma(nu/2)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Student_t-distribution + .. [2] https://mathworld.wolfram.com/Studentst-Distribution.html + + """ + + return rv(name, StudentTDistribution, (nu, )) + +#------------------------------------------------------------------------------- +# Trapezoidal distribution ------------------------------------------------------ + + +class TrapezoidalDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b', 'c', 'd') + + @property + def set(self): + return Interval(self.a, self.d) + + @staticmethod + def check(a, b, c, d): + _value_check(a < d, "Lower bound parameter a < %s. a = %s"%(d, a)) + _value_check((a <= b, b < c), + "Level start parameter b must be in range [%s, %s). b = %s"%(a, c, b)) + _value_check((b < c, c <= d), + "Level end parameter c must be in range (%s, %s]. c = %s"%(b, d, c)) + _value_check(d >= c, "Upper bound parameter d > %s. d = %s"%(c, d)) + + def pdf(self, x): + a, b, c, d = self.a, self.b, self.c, self.d + return Piecewise( + (2*(x-a) / ((b-a)*(d+c-a-b)), And(a <= x, x < b)), + (2 / (d+c-a-b), And(b <= x, x < c)), + (2*(d-x) / ((d-c)*(d+c-a-b)), And(c <= x, x <= d)), + (S.Zero, True)) + +def Trapezoidal(name, a, b, c, d): + r""" + Create a continuous random variable with a trapezoidal distribution. + + Explanation + =========== + + The density of the trapezoidal distribution is given by + + .. math:: + f(x) := \begin{cases} + 0 & \mathrm{for\ } x < a, \\ + \frac{2(x-a)}{(b-a)(d+c-a-b)} & \mathrm{for\ } a \le x < b, \\ + \frac{2}{d+c-a-b} & \mathrm{for\ } b \le x < c, \\ + \frac{2(d-x)}{(d-c)(d+c-a-b)} & \mathrm{for\ } c \le x < d, \\ + 0 & \mathrm{for\ } d < x. + \end{cases} + + Parameters + ========== + + a : Real number, :math:`a < d` + b : Real number, :math:`a \le b < c` + c : Real number, :math:`b < c \le d` + d : Real number + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Trapezoidal, density + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a") + >>> b = Symbol("b") + >>> c = Symbol("c") + >>> d = Symbol("d") + >>> z = Symbol("z") + + >>> X = Trapezoidal("x", a,b,c,d) + + >>> pprint(density(X)(z), use_unicode=False) + / -2*a + 2*z + |------------------------- for And(a <= z, b > z) + |(-a + b)*(-a - b + c + d) + | + | 2 + | -------------- for And(b <= z, c > z) + < -a - b + c + d + | + | 2*d - 2*z + |------------------------- for And(d >= z, c <= z) + |(-c + d)*(-a - b + c + d) + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trapezoidal_distribution + + """ + return rv(name, TrapezoidalDistribution, (a, b, c, d)) + +#------------------------------------------------------------------------------- +# Triangular distribution ------------------------------------------------------ + + +class TriangularDistribution(SingleContinuousDistribution): + _argnames = ('a', 'b', 'c') + + @property + def set(self): + return Interval(self.a, self.b) + + @staticmethod + def check(a, b, c): + _value_check(b > a, "Parameter b > %s. b = %s"%(a, b)) + _value_check((a <= c, c <= b), + "Parameter c must be in range [%s, %s]. c = %s"%(a, b, c)) + + def pdf(self, x): + a, b, c = self.a, self.b, self.c + return Piecewise( + (2*(x - a)/((b - a)*(c - a)), And(a <= x, x < c)), + (2/(b - a), Eq(x, c)), + (2*(b - x)/((b - a)*(b - c)), And(c < x, x <= b)), + (S.Zero, True)) + + def _characteristic_function(self, t): + a, b, c = self.a, self.b, self.c + return -2 *((b-c) * exp(I*a*t) - (b-a) * exp(I*c*t) + (c-a) * exp(I*b*t)) / ((b-a)*(c-a)*(b-c)*t**2) + + def _moment_generating_function(self, t): + a, b, c = self.a, self.b, self.c + return 2 * ((b - c) * exp(a * t) - (b - a) * exp(c * t) + (c - a) * exp(b * t)) / ( + (b - a) * (c - a) * (b - c) * t ** 2) + + +def Triangular(name, a, b, c): + r""" + Create a continuous random variable with a triangular distribution. + + Explanation + =========== + + The density of the triangular distribution is given by + + .. math:: + f(x) := \begin{cases} + 0 & \mathrm{for\ } x < a, \\ + \frac{2(x-a)}{(b-a)(c-a)} & \mathrm{for\ } a \le x < c, \\ + \frac{2}{b-a} & \mathrm{for\ } x = c, \\ + \frac{2(b-x)}{(b-a)(b-c)} & \mathrm{for\ } c < x \le b, \\ + 0 & \mathrm{for\ } b < x. + \end{cases} + + Parameters + ========== + + a : Real number, :math:`a \in \left(-\infty, \infty\right)` + b : Real number, :math:`a < b` + c : Real number, :math:`a \leq c \leq b` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Triangular, density + >>> from sympy import Symbol, pprint + + >>> a = Symbol("a") + >>> b = Symbol("b") + >>> c = Symbol("c") + >>> z = Symbol("z") + + >>> X = Triangular("x", a,b,c) + + >>> pprint(density(X)(z), use_unicode=False) + / -2*a + 2*z + |----------------- for And(a <= z, c > z) + |(-a + b)*(-a + c) + | + | 2 + | ------ for c = z + < -a + b + | + | 2*b - 2*z + |---------------- for And(b >= z, c < z) + |(-a + b)*(b - c) + | + \ 0 otherwise + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Triangular_distribution + .. [2] https://mathworld.wolfram.com/TriangularDistribution.html + + """ + + return rv(name, TriangularDistribution, (a, b, c)) + +#------------------------------------------------------------------------------- +# Uniform distribution --------------------------------------------------------- + + +class UniformDistribution(SingleContinuousDistribution): + _argnames = ('left', 'right') + + @property + def set(self): + return Interval(self.left, self.right) + + @staticmethod + def check(left, right): + _value_check(left < right, "Lower limit should be less than Upper limit.") + + def pdf(self, x): + left, right = self.left, self.right + return Piecewise( + (S.One/(right - left), And(left <= x, x <= right)), + (S.Zero, True) + ) + + def _cdf(self, x): + left, right = self.left, self.right + return Piecewise( + (S.Zero, x < left), + ((x - left)/(right - left), x <= right), + (S.One, True) + ) + + def _characteristic_function(self, t): + left, right = self.left, self.right + return Piecewise(((exp(I*t*right) - exp(I*t*left)) / (I*t*(right - left)), Ne(t, 0)), + (S.One, True)) + + def _moment_generating_function(self, t): + left, right = self.left, self.right + return Piecewise(((exp(t*right) - exp(t*left)) / (t * (right - left)), Ne(t, 0)), + (S.One, True)) + + def expectation(self, expr, var, **kwargs): + kwargs['evaluate'] = True + result = SingleContinuousDistribution.expectation(self, expr, var, **kwargs) + result = result.subs({Max(self.left, self.right): self.right, + Min(self.left, self.right): self.left}) + return result + + +def Uniform(name, left, right): + r""" + Create a continuous random variable with a uniform distribution. + + Explanation + =========== + + The density of the uniform distribution is given by + + .. math:: + f(x) := \begin{cases} + \frac{1}{b - a} & \text{for } x \in [a,b] \\ + 0 & \text{otherwise} + \end{cases} + + with :math:`x \in [a,b]`. + + Parameters + ========== + + a : Real number, :math:`-\infty < a`, the left boundary + b : Real number, :math:`a < b < \infty`, the right boundary + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Uniform, density, cdf, E, variance + >>> from sympy import Symbol, simplify + + >>> a = Symbol("a", negative=True) + >>> b = Symbol("b", positive=True) + >>> z = Symbol("z") + + >>> X = Uniform("x", a, b) + + >>> density(X)(z) + Piecewise((1/(-a + b), (b >= z) & (a <= z)), (0, True)) + + >>> cdf(X)(z) + Piecewise((0, a > z), ((-a + z)/(-a + b), b >= z), (1, True)) + + >>> E(X) + a/2 + b/2 + + >>> simplify(variance(X)) + a**2/12 - a*b/6 + b**2/12 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Uniform_distribution_%28continuous%29 + .. [2] https://mathworld.wolfram.com/UniformDistribution.html + + """ + + return rv(name, UniformDistribution, (left, right)) + +#------------------------------------------------------------------------------- +# UniformSum distribution ------------------------------------------------------ + + +class UniformSumDistribution(SingleContinuousDistribution): + _argnames = ('n',) + + @property + def set(self): + return Interval(0, self.n) + + @staticmethod + def check(n): + _value_check((n > 0, n.is_integer), + "Parameter n must be positive integer.") + + def pdf(self, x): + n = self.n + k = Dummy("k") + return 1/factorial( + n - 1)*Sum((-1)**k*binomial(n, k)*(x - k)**(n - 1), (k, 0, floor(x))) + + def _cdf(self, x): + n = self.n + k = Dummy("k") + return Piecewise((S.Zero, x < 0), + (1/factorial(n)*Sum((-1)**k*binomial(n, k)*(x - k)**(n), + (k, 0, floor(x))), x <= n), + (S.One, True)) + + def _characteristic_function(self, t): + return ((exp(I*t) - 1) / (I*t))**self.n + + def _moment_generating_function(self, t): + return ((exp(t) - 1) / t)**self.n + +def UniformSum(name, n): + r""" + Create a continuous random variable with an Irwin-Hall distribution. + + Explanation + =========== + + The probability distribution function depends on a single parameter + $n$ which is an integer. + + The density of the Irwin-Hall distribution is given by + + .. math :: + f(x) := \frac{1}{(n-1)!}\sum_{k=0}^{\left\lfloor x\right\rfloor}(-1)^k + \binom{n}{k}(x-k)^{n-1} + + Parameters + ========== + + n : A positive integer, `n > 0` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import UniformSum, density, cdf + >>> from sympy import Symbol, pprint + + >>> n = Symbol("n", integer=True) + >>> z = Symbol("z") + + >>> X = UniformSum("x", n) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + floor(z) + ___ + \ ` + \ k n - 1 /n\ + ) (-1) *(-k + z) *| | + / \k/ + /__, + k = 0 + -------------------------------- + (n - 1)! + + >>> cdf(X)(z) + Piecewise((0, z < 0), (Sum((-1)**_k*(-_k + z)**n*binomial(n, _k), + (_k, 0, floor(z)))/factorial(n), n >= z), (1, True)) + + + Compute cdf with specific 'x' and 'n' values as follows : + >>> cdf(UniformSum("x", 5), evaluate=False)(2).doit() + 9/40 + + The argument evaluate=False prevents an attempt at evaluation + of the sum for general n, before the argument 2 is passed. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Uniform_sum_distribution + .. [2] https://mathworld.wolfram.com/UniformSumDistribution.html + + """ + + return rv(name, UniformSumDistribution, (n, )) + +#------------------------------------------------------------------------------- +# VonMises distribution -------------------------------------------------------- + + +class VonMisesDistribution(SingleContinuousDistribution): + _argnames = ('mu', 'k') + + set = Interval(0, 2*pi) + + @staticmethod + def check(mu, k): + _value_check(k > 0, "k must be positive") + + def pdf(self, x): + mu, k = self.mu, self.k + return exp(k*cos(x-mu)) / (2*pi*besseli(0, k)) + +def VonMises(name, mu, k): + r""" + Create a Continuous Random Variable with a von Mises distribution. + + Explanation + =========== + + The density of the von Mises distribution is given by + + .. math:: + f(x) := \frac{e^{\kappa\cos(x-\mu)}}{2\pi I_0(\kappa)} + + with :math:`x \in [0,2\pi]`. + + Parameters + ========== + + mu : Real number + Measure of location. + k : Real number + Measure of concentration. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import VonMises, density + >>> from sympy import Symbol, pprint + + >>> mu = Symbol("mu") + >>> k = Symbol("k", positive=True) + >>> z = Symbol("z") + + >>> X = VonMises("x", mu, k) + + >>> D = density(X)(z) + >>> pprint(D, use_unicode=False) + k*cos(mu - z) + e + ------------------ + 2*pi*besseli(0, k) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Von_Mises_distribution + .. [2] https://mathworld.wolfram.com/vonMisesDistribution.html + + """ + + return rv(name, VonMisesDistribution, (mu, k)) + +#------------------------------------------------------------------------------- +# Weibull distribution --------------------------------------------------------- + + +class WeibullDistribution(SingleContinuousDistribution): + _argnames = ('alpha', 'beta') + + set = Interval(0, oo) + + @staticmethod + def check(alpha, beta): + _value_check(alpha > 0, "Alpha must be positive") + _value_check(beta > 0, "Beta must be positive") + + def pdf(self, x): + alpha, beta = self.alpha, self.beta + return beta * (x/alpha)**(beta - 1) * exp(-(x/alpha)**beta) / alpha + + +def Weibull(name, alpha, beta): + r""" + Create a continuous random variable with a Weibull distribution. + + Explanation + =========== + + The density of the Weibull distribution is given by + + .. math:: + f(x) := \begin{cases} + \frac{k}{\lambda}\left(\frac{x}{\lambda}\right)^{k-1} + e^{-(x/\lambda)^{k}} & x\geq0\\ + 0 & x<0 + \end{cases} + + Parameters + ========== + + lambda : Real number, $\lambda > 0$, a scale + k : Real number, $k > 0$, a shape + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Weibull, density, E, variance + >>> from sympy import Symbol, simplify + + >>> l = Symbol("lambda", positive=True) + >>> k = Symbol("k", positive=True) + >>> z = Symbol("z") + + >>> X = Weibull("x", l, k) + + >>> density(X)(z) + k*(z/lambda)**(k - 1)*exp(-(z/lambda)**k)/lambda + + >>> simplify(E(X)) + lambda*gamma(1 + 1/k) + + >>> simplify(variance(X)) + lambda**2*(-gamma(1 + 1/k)**2 + gamma(1 + 2/k)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Weibull_distribution + .. [2] https://mathworld.wolfram.com/WeibullDistribution.html + + """ + + return rv(name, WeibullDistribution, (alpha, beta)) + +#------------------------------------------------------------------------------- +# Wigner semicircle distribution ----------------------------------------------- + + +class WignerSemicircleDistribution(SingleContinuousDistribution): + _argnames = ('R',) + + @property + def set(self): + return Interval(-self.R, self.R) + + @staticmethod + def check(R): + _value_check(R > 0, "Radius R must be positive.") + + def pdf(self, x): + R = self.R + return 2/(pi*R**2)*sqrt(R**2 - x**2) + + def _characteristic_function(self, t): + return Piecewise((2 * besselj(1, self.R*t) / (self.R*t), Ne(t, 0)), + (S.One, True)) + + def _moment_generating_function(self, t): + return Piecewise((2 * besseli(1, self.R*t) / (self.R*t), Ne(t, 0)), + (S.One, True)) + +def WignerSemicircle(name, R): + r""" + Create a continuous random variable with a Wigner semicircle distribution. + + Explanation + =========== + + The density of the Wigner semicircle distribution is given by + + .. math:: + f(x) := \frac2{\pi R^2}\,\sqrt{R^2-x^2} + + with :math:`x \in [-R,R]`. + + Parameters + ========== + + R : Real number, `R > 0`, the radius + + Returns + ======= + + A RandomSymbol. + + Examples + ======== + + >>> from sympy.stats import WignerSemicircle, density, E + >>> from sympy import Symbol + + >>> R = Symbol("R", positive=True) + >>> z = Symbol("z") + + >>> X = WignerSemicircle("x", R) + + >>> density(X)(z) + 2*sqrt(R**2 - z**2)/(pi*R**2) + + >>> E(X) + 0 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Wigner_semicircle_distribution + .. [2] https://mathworld.wolfram.com/WignersSemicircleLaw.html + + """ + + return rv(name, WignerSemicircleDistribution, (R,)) diff --git a/phivenv/Lib/site-packages/sympy/stats/drv.py b/phivenv/Lib/site-packages/sympy/stats/drv.py new file mode 100644 index 0000000000000000000000000000000000000000..dea14f2dfd1078c223c61bb5cd1373105e72ea28 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/drv.py @@ -0,0 +1,350 @@ +from sympy.concrete.summations import (Sum, summation) +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.numbers import I +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import And +from sympy.polys.polytools import poly +from sympy.series.series import series + +from sympy.polys.polyerrors import PolynomialError +from sympy.stats.crv import reduce_rational_inequalities_wrap +from sympy.stats.rv import (NamedArgsMixin, SinglePSpace, SingleDomain, + random_symbols, PSpace, ConditionalDomain, RandomDomain, + ProductDomain, Distribution) +from sympy.stats.symbolic_probability import Probability +from sympy.sets.fancysets import Range, FiniteSet +from sympy.sets.sets import Union +from sympy.sets.contains import Contains +from sympy.utilities import filldedent +from sympy.core.sympify import _sympify + + +class DiscreteDistribution(Distribution): + def __call__(self, *args): + return self.pdf(*args) + + +class SingleDiscreteDistribution(DiscreteDistribution, NamedArgsMixin): + """ Discrete distribution of a single variable. + + Serves as superclass for PoissonDistribution etc.... + + Provides methods for pdf, cdf, and sampling + + See Also: + sympy.stats.crv_types.* + """ + + set = S.Integers + + def __new__(cls, *args): + args = list(map(sympify, args)) + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + @cacheit + def compute_cdf(self, **kwargs): + """ Compute the CDF from the PDF. + + Returns a Lambda. + """ + x = symbols('x', integer=True, cls=Dummy) + z = symbols('z', real=True, cls=Dummy) + left_bound = self.set.inf + + # CDF is integral of PDF from left bound to z + pdf = self.pdf(x) + cdf = summation(pdf, (x, left_bound, floor(z)), **kwargs) + # CDF Ensure that CDF left of left_bound is zero + cdf = Piecewise((cdf, z >= left_bound), (0, True)) + return Lambda(z, cdf) + + def _cdf(self, x): + return None + + def cdf(self, x, **kwargs): + """ Cumulative density function """ + if not kwargs: + cdf = self._cdf(x) + if cdf is not None: + return cdf + return self.compute_cdf(**kwargs)(x) + + @cacheit + def compute_characteristic_function(self, **kwargs): + """ Compute the characteristic function from the PDF. + + Returns a Lambda. + """ + x, t = symbols('x, t', real=True, cls=Dummy) + pdf = self.pdf(x) + cf = summation(exp(I*t*x)*pdf, (x, self.set.inf, self.set.sup)) + return Lambda(t, cf) + + def _characteristic_function(self, t): + return None + + def characteristic_function(self, t, **kwargs): + """ Characteristic function """ + if not kwargs: + cf = self._characteristic_function(t) + if cf is not None: + return cf + return self.compute_characteristic_function(**kwargs)(t) + + @cacheit + def compute_moment_generating_function(self, **kwargs): + t = Dummy('t', real=True) + x = Dummy('x', integer=True) + pdf = self.pdf(x) + mgf = summation(exp(t*x)*pdf, (x, self.set.inf, self.set.sup)) + return Lambda(t, mgf) + + def _moment_generating_function(self, t): + return None + + def moment_generating_function(self, t, **kwargs): + if not kwargs: + mgf = self._moment_generating_function(t) + if mgf is not None: + return mgf + return self.compute_moment_generating_function(**kwargs)(t) + + @cacheit + def compute_quantile(self, **kwargs): + """ Compute the Quantile from the PDF. + + Returns a Lambda. + """ + x = Dummy('x', integer=True) + p = Dummy('p', real=True) + left_bound = self.set.inf + pdf = self.pdf(x) + cdf = summation(pdf, (x, left_bound, x), **kwargs) + set = ((x, p <= cdf), ) + return Lambda(p, Piecewise(*set)) + + def _quantile(self, x): + return None + + def quantile(self, x, **kwargs): + """ Cumulative density function """ + if not kwargs: + quantile = self._quantile(x) + if quantile is not None: + return quantile + return self.compute_quantile(**kwargs)(x) + + def expectation(self, expr, var, evaluate=True, **kwargs): + """ Expectation of expression over distribution """ + # TODO: support discrete sets with non integer stepsizes + + if evaluate: + try: + p = poly(expr, var) + + t = Dummy('t', real=True) + + mgf = self.moment_generating_function(t) + deg = p.degree() + taylor = poly(series(mgf, t, 0, deg + 1).removeO(), t) + result = 0 + for k in range(deg+1): + result += p.coeff_monomial(var ** k) * taylor.coeff_monomial(t ** k) * factorial(k) + + return result + + except PolynomialError: + return summation(expr * self.pdf(var), + (var, self.set.inf, self.set.sup), **kwargs) + + else: + return Sum(expr * self.pdf(var), + (var, self.set.inf, self.set.sup), **kwargs) + + def __call__(self, *args): + return self.pdf(*args) + + +class DiscreteDomain(RandomDomain): + """ + A domain with discrete support with step size one. + Represented using symbols and Range. + """ + is_Discrete = True + +class SingleDiscreteDomain(DiscreteDomain, SingleDomain): + def as_boolean(self): + return Contains(self.symbol, self.set) + + +class ConditionalDiscreteDomain(DiscreteDomain, ConditionalDomain): + """ + Domain with discrete support of step size one, that is restricted by + some condition. + """ + @property + def set(self): + rv = self.symbols + if len(self.symbols) > 1: + raise NotImplementedError(filldedent(''' + Multivariate conditional domains are not yet implemented.''')) + rv = list(rv)[0] + return reduce_rational_inequalities_wrap(self.condition, + rv).intersect(self.fulldomain.set) + + +class DiscretePSpace(PSpace): + is_real = True + is_Discrete = True + + @property + def pdf(self): + return self.density(*self.symbols) + + def where(self, condition): + rvs = random_symbols(condition) + assert all(r.symbol in self.symbols for r in rvs) + if len(rvs) > 1: + raise NotImplementedError(filldedent('''Multivariate discrete + random variables are not yet supported.''')) + conditional_domain = reduce_rational_inequalities_wrap(condition, + rvs[0]) + conditional_domain = conditional_domain.intersect(self.domain.set) + return SingleDiscreteDomain(rvs[0].symbol, conditional_domain) + + def probability(self, condition): + complement = isinstance(condition, Ne) + if complement: + condition = Eq(condition.args[0], condition.args[1]) + try: + _domain = self.where(condition).set + if condition == False or _domain is S.EmptySet: + return S.Zero + if condition == True or _domain == self.domain.set: + return S.One + prob = self.eval_prob(_domain) + except NotImplementedError: + from sympy.stats.rv import density + expr = condition.lhs - condition.rhs + dens = density(expr) + if not isinstance(dens, DiscreteDistribution): + from sympy.stats.drv_types import DiscreteDistributionHandmade + dens = DiscreteDistributionHandmade(dens) + z = Dummy('z', real=True) + space = SingleDiscretePSpace(z, dens) + prob = space.probability(condition.__class__(space.value, 0)) + if prob is None: + prob = Probability(condition) + return prob if not complement else S.One - prob + + def eval_prob(self, _domain): + sym = list(self.symbols)[0] + if isinstance(_domain, Range): + n = symbols('n', integer=True) + inf, sup, step = (r for r in _domain.args) + summand = ((self.pdf).replace( + sym, n*step)) + rv = summation(summand, + (n, inf/step, (sup)/step - 1)).doit() + return rv + elif isinstance(_domain, FiniteSet): + pdf = Lambda(sym, self.pdf) + rv = sum(pdf(x) for x in _domain) + return rv + elif isinstance(_domain, Union): + rv = sum(self.eval_prob(x) for x in _domain.args) + return rv + + def conditional_space(self, condition): + # XXX: Converting from set to tuple. The order matters to Lambda + # though so we should be starting with a set... + density = Lambda(tuple(self.symbols), self.pdf/self.probability(condition)) + condition = condition.xreplace({rv: rv.symbol for rv in self.values}) + domain = ConditionalDiscreteDomain(self.domain, condition) + return DiscretePSpace(domain, density) + +class ProductDiscreteDomain(ProductDomain, DiscreteDomain): + def as_boolean(self): + return And(*[domain.as_boolean for domain in self.domains]) + +class SingleDiscretePSpace(DiscretePSpace, SinglePSpace): + """ Discrete probability space over a single univariate variable """ + is_real = True + + @property + def set(self): + return self.distribution.set + + @property + def domain(self): + return SingleDiscreteDomain(self.symbol, self.set) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method. + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library=library, seed=seed)} + + def compute_expectation(self, expr, rvs=None, evaluate=True, **kwargs): + rvs = rvs or (self.value,) + if self.value not in rvs: + return expr + + expr = _sympify(expr) + expr = expr.xreplace({rv: rv.symbol for rv in rvs}) + + x = self.value.symbol + try: + return self.distribution.expectation(expr, x, evaluate=evaluate, + **kwargs) + except NotImplementedError: + return Sum(expr * self.pdf, (x, self.set.inf, self.set.sup), + **kwargs) + + def compute_cdf(self, expr, **kwargs): + if expr == self.value: + x = Dummy("x", real=True) + return Lambda(x, self.distribution.cdf(x, **kwargs)) + else: + raise NotImplementedError() + + def compute_density(self, expr, **kwargs): + if expr == self.value: + return self.distribution + raise NotImplementedError() + + def compute_characteristic_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.characteristic_function(t, **kwargs)) + else: + raise NotImplementedError() + + def compute_moment_generating_function(self, expr, **kwargs): + if expr == self.value: + t = Dummy("t", real=True) + return Lambda(t, self.distribution.moment_generating_function(t, **kwargs)) + else: + raise NotImplementedError() + + def compute_quantile(self, expr, **kwargs): + if expr == self.value: + p = Dummy("p", real=True) + return Lambda(p, self.distribution.quantile(p, **kwargs)) + else: + raise NotImplementedError() diff --git a/phivenv/Lib/site-packages/sympy/stats/drv_types.py b/phivenv/Lib/site-packages/sympy/stats/drv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..84920d31c0083828efc2cd3f752d2c48f5430102 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/drv_types.py @@ -0,0 +1,849 @@ +""" + +Contains +======== +FlorySchulz +Geometric +Hermite +Logarithmic +NegativeBinomial +Poisson +Skellam +YuleSimon +Zeta +""" + + + +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.numbers import I +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial, FallingFactorial) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besseli +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.hyper import hyper +from sympy.functions.special.zeta_functions import (polylog, zeta) +from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace +from sympy.stats.rv import _value_check, is_random + + +__all__ = ['FlorySchulz', +'Geometric', +'Hermite', +'Logarithmic', +'NegativeBinomial', +'Poisson', +'Skellam', +'YuleSimon', +'Zeta' +] + + +def rv(symbol, cls, *args, **kwargs): + args = list(map(sympify, args)) + dist = cls(*args) + if kwargs.pop('check', True): + dist.check(*args) + pspace = SingleDiscretePSpace(symbol, dist) + if any(is_random(arg) for arg in args): + from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution + pspace = CompoundPSpace(symbol, CompoundDistribution(dist)) + return pspace.value + + +class DiscreteDistributionHandmade(SingleDiscreteDistribution): + _argnames = ('pdf',) + + def __new__(cls, pdf, set=S.Integers): + return Basic.__new__(cls, pdf, set) + + @property + def set(self): + return self.args[1] + + @staticmethod + def check(pdf, set): + x = Dummy('x') + val = Sum(pdf(x), (x, set._inf, set._sup)).doit() + _value_check(Eq(val, 1) != S.false, "The pdf is incorrect on the given set.") + + + +def DiscreteRV(symbol, density, set=S.Integers, **kwargs): + """ + Create a Discrete Random Variable given the following: + + Parameters + ========== + + symbol : Symbol + Represents name of the random variable. + density : Expression containing symbol + Represents probability density function. + set : set + Represents the region where the pdf is valid, by default is real line. + check : bool + If True, it will check whether the given density + integrates to 1 over the given set. If False, it + will not perform this check. Default is False. + + Examples + ======== + + >>> from sympy.stats import DiscreteRV, P, E + >>> from sympy import Rational, Symbol + >>> x = Symbol('x') + >>> n = 10 + >>> density = Rational(1, 10) + >>> X = DiscreteRV(x, density, set=set(range(n))) + >>> E(X) + 9/2 + >>> P(X>3) + 3/5 + + Returns + ======= + + RandomSymbol + + """ + set = sympify(set) + pdf = Piecewise((density, set.as_relational(symbol)), (0, True)) + pdf = Lambda(symbol, pdf) + # have a default of False while `rv` should have a default of True + kwargs['check'] = kwargs.pop('check', False) + return rv(symbol.name, DiscreteDistributionHandmade, pdf, set, **kwargs) + + +#------------------------------------------------------------------------------- +# Flory-Schulz distribution ------------------------------------------------------------ + +class FlorySchulzDistribution(SingleDiscreteDistribution): + _argnames = ('a',) + set = S.Naturals + + @staticmethod + def check(a): + _value_check((0 < a, a < 1), "a must be between 0 and 1") + + def pdf(self, k): + a = self.a + return (a**2 * k * (1 - a)**(k - 1)) + + def _characteristic_function(self, t): + a = self.a + return a**2*exp(I*t)/((1 + (a - 1)*exp(I*t))**2) + + def _moment_generating_function(self, t): + a = self.a + return a**2*exp(t)/((1 + (a - 1)*exp(t))**2) + + +def FlorySchulz(name, a): + r""" + Create a discrete random variable with a FlorySchulz distribution. + + The density of the FlorySchulz distribution is given by + + .. math:: + f(k) := (a^2) k (1 - a)^{k-1} + + Parameters + ========== + + a : A real number between 0 and 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, E, variance, FlorySchulz + >>> from sympy import Symbol, S + + >>> a = S.One / 5 + >>> z = Symbol("z") + + >>> X = FlorySchulz("x", a) + + >>> density(X)(z) + (4/5)**(z - 1)*z/25 + + >>> E(X) + 9 + + >>> variance(X) + 40 + + References + ========== + + https://en.wikipedia.org/wiki/Flory%E2%80%93Schulz_distribution + """ + return rv(name, FlorySchulzDistribution, a) + + +#------------------------------------------------------------------------------- +# Geometric distribution ------------------------------------------------------------ + +class GeometricDistribution(SingleDiscreteDistribution): + _argnames = ('p',) + set = S.Naturals + + @staticmethod + def check(p): + _value_check((0 < p, p <= 1), "p must be between 0 and 1") + + def pdf(self, k): + return (1 - self.p)**(k - 1) * self.p + + def _characteristic_function(self, t): + p = self.p + return p * exp(I*t) / (1 - (1 - p)*exp(I*t)) + + def _moment_generating_function(self, t): + p = self.p + return p * exp(t) / (1 - (1 - p) * exp(t)) + + +def Geometric(name, p): + r""" + Create a discrete random variable with a Geometric distribution. + + Explanation + =========== + + The density of the Geometric distribution is given by + + .. math:: + f(k) := p (1 - p)^{k - 1} + + Parameters + ========== + + p : A probability between 0 and 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Geometric, density, E, variance + >>> from sympy import Symbol, S + + >>> p = S.One / 5 + >>> z = Symbol("z") + + >>> X = Geometric("x", p) + + >>> density(X)(z) + (4/5)**(z - 1)/5 + + >>> E(X) + 5 + + >>> variance(X) + 20 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Geometric_distribution + .. [2] https://mathworld.wolfram.com/GeometricDistribution.html + + """ + return rv(name, GeometricDistribution, p) + + +#------------------------------------------------------------------------------- +# Hermite distribution --------------------------------------------------------- + + +class HermiteDistribution(SingleDiscreteDistribution): + _argnames = ('a1', 'a2') + set = S.Naturals0 + + @staticmethod + def check(a1, a2): + _value_check(a1.is_nonnegative, 'Parameter a1 must be >= 0.') + _value_check(a2.is_nonnegative, 'Parameter a2 must be >= 0.') + + def pdf(self, k): + a1, a2 = self.a1, self.a2 + term1 = exp(-(a1 + a2)) + j = Dummy("j", integer=True) + num = a1**(k - 2*j) * a2**j + den = factorial(k - 2*j) * factorial(j) + return term1 * Sum(num/den, (j, 0, k//2)).doit() + + def _moment_generating_function(self, t): + a1, a2 = self.a1, self.a2 + term1 = a1 * (exp(t) - 1) + term2 = a2 * (exp(2*t) - 1) + return exp(term1 + term2) + + def _characteristic_function(self, t): + a1, a2 = self.a1, self.a2 + term1 = a1 * (exp(I*t) - 1) + term2 = a2 * (exp(2*I*t) - 1) + return exp(term1 + term2) + +def Hermite(name, a1, a2): + r""" + Create a discrete random variable with a Hermite distribution. + + Explanation + =========== + + The density of the Hermite distribution is given by + + .. math:: + f(x):= e^{-a_1 -a_2}\sum_{j=0}^{\left \lfloor x/2 \right \rfloor} + \frac{a_{1}^{x-2j}a_{2}^{j}}{(x-2j)!j!} + + Parameters + ========== + + a1 : A Positive number greater than equal to 0. + a2 : A Positive number greater than equal to 0. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Hermite, density, E, variance + >>> from sympy import Symbol + + >>> a1 = Symbol("a1", positive=True) + >>> a2 = Symbol("a2", positive=True) + >>> x = Symbol("x") + + >>> H = Hermite("H", a1=5, a2=4) + + >>> density(H)(2) + 33*exp(-9)/2 + + >>> E(H) + 13 + + >>> variance(H) + 21 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermite_distribution + + """ + + return rv(name, HermiteDistribution, a1, a2) + + +#------------------------------------------------------------------------------- +# Logarithmic distribution ------------------------------------------------------------ + +class LogarithmicDistribution(SingleDiscreteDistribution): + _argnames = ('p',) + + set = S.Naturals + + @staticmethod + def check(p): + _value_check((p > 0, p < 1), "p should be between 0 and 1") + + def pdf(self, k): + p = self.p + return (-1) * p**k / (k * log(1 - p)) + + def _characteristic_function(self, t): + p = self.p + return log(1 - p * exp(I*t)) / log(1 - p) + + def _moment_generating_function(self, t): + p = self.p + return log(1 - p * exp(t)) / log(1 - p) + + +def Logarithmic(name, p): + r""" + Create a discrete random variable with a Logarithmic distribution. + + Explanation + =========== + + The density of the Logarithmic distribution is given by + + .. math:: + f(k) := \frac{-p^k}{k \ln{(1 - p)}} + + Parameters + ========== + + p : A value between 0 and 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Logarithmic, density, E, variance + >>> from sympy import Symbol, S + + >>> p = S.One / 5 + >>> z = Symbol("z") + + >>> X = Logarithmic("x", p) + + >>> density(X)(z) + -1/(5**z*z*log(4/5)) + + >>> E(X) + -1/(-4*log(5) + 8*log(2)) + + >>> variance(X) + -1/((-4*log(5) + 8*log(2))*(-2*log(5) + 4*log(2))) + 1/(-64*log(2)*log(5) + 64*log(2)**2 + 16*log(5)**2) - 10/(-32*log(5) + 64*log(2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logarithmic_distribution + .. [2] https://mathworld.wolfram.com/LogarithmicDistribution.html + + """ + return rv(name, LogarithmicDistribution, p) + + +#------------------------------------------------------------------------------- +# Negative binomial distribution ------------------------------------------------------------ + +class NegativeBinomialDistribution(SingleDiscreteDistribution): + _argnames = ('r', 'p') + set = S.Naturals0 + + @staticmethod + def check(r, p): + _value_check(r > 0, 'r should be positive') + _value_check((p > 0, p < 1), 'p should be between 0 and 1') + + def pdf(self, k): + r = self.r + p = self.p + + return binomial(k + r - 1, k) * (1 - p)**k * p**r + + def _characteristic_function(self, t): + r = self.r + p = self.p + + return (p / (1 - (1 - p) * exp(I*t)))**r + + def _moment_generating_function(self, t): + r = self.r + p = self.p + + return (p / (1 - (1 - p) * exp(t)))**r + +def NegativeBinomial(name, r, p): + r""" + Create a discrete random variable with a Negative Binomial distribution. + + Explanation + =========== + + The density of the Negative Binomial distribution is given by + + .. math:: + f(k) := \binom{k + r - 1}{k} (1 - p)^k p^r + + Parameters + ========== + + r : A positive value + Number of successes until the experiment is stopped. + p : A value between 0 and 1 + Probability of success. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import NegativeBinomial, density, E, variance + >>> from sympy import Symbol, S + + >>> r = 5 + >>> p = S.One / 3 + >>> z = Symbol("z") + + >>> X = NegativeBinomial("x", r, p) + + >>> density(X)(z) + (2/3)**z*binomial(z + 4, z)/243 + + >>> E(X) + 10 + + >>> variance(X) + 30 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Negative_binomial_distribution + .. [2] https://mathworld.wolfram.com/NegativeBinomialDistribution.html + + """ + return rv(name, NegativeBinomialDistribution, r, p) + + +#------------------------------------------------------------------------------- +# Poisson distribution ------------------------------------------------------------ + +class PoissonDistribution(SingleDiscreteDistribution): + _argnames = ('lamda',) + + set = S.Naturals0 + + @staticmethod + def check(lamda): + _value_check(lamda > 0, "Lambda must be positive") + + def pdf(self, k): + return self.lamda**k / factorial(k) * exp(-self.lamda) + + def _characteristic_function(self, t): + return exp(self.lamda * (exp(I*t) - 1)) + + def _moment_generating_function(self, t): + return exp(self.lamda * (exp(t) - 1)) + + def expectation(self, expr, var, evaluate=True, **kwargs): + if evaluate: + if expr == var: + return self.lamda + if ( + isinstance(expr, FallingFactorial) + and expr.args[1].is_integer + and expr.args[1].is_positive + and expr.args[0] == var + ): + return self.lamda ** expr.args[1] + return super().expectation(expr, var, evaluate, **kwargs) + +def Poisson(name, lamda): + r""" + Create a discrete random variable with a Poisson distribution. + + Explanation + =========== + + The density of the Poisson distribution is given by + + .. math:: + f(k) := \frac{\lambda^{k} e^{- \lambda}}{k!} + + Parameters + ========== + + lamda : Positive number, a rate + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Poisson, density, E, variance + >>> from sympy import Symbol, simplify + + >>> rate = Symbol("lambda", positive=True) + >>> z = Symbol("z") + + >>> X = Poisson("x", rate) + + >>> density(X)(z) + lambda**z*exp(-lambda)/factorial(z) + + >>> E(X) + lambda + + >>> simplify(variance(X)) + lambda + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Poisson_distribution + .. [2] https://mathworld.wolfram.com/PoissonDistribution.html + + """ + return rv(name, PoissonDistribution, lamda) + + +# ----------------------------------------------------------------------------- +# Skellam distribution -------------------------------------------------------- + + +class SkellamDistribution(SingleDiscreteDistribution): + _argnames = ('mu1', 'mu2') + set = S.Integers + + @staticmethod + def check(mu1, mu2): + _value_check(mu1 >= 0, 'Parameter mu1 must be >= 0') + _value_check(mu2 >= 0, 'Parameter mu2 must be >= 0') + + def pdf(self, k): + (mu1, mu2) = (self.mu1, self.mu2) + term1 = exp(-(mu1 + mu2)) * (mu1 / mu2) ** (k / 2) + term2 = besseli(k, 2 * sqrt(mu1 * mu2)) + return term1 * term2 + + def _cdf(self, x): + raise NotImplementedError( + "Skellam doesn't have closed form for the CDF.") + + def _characteristic_function(self, t): + (mu1, mu2) = (self.mu1, self.mu2) + return exp(-(mu1 + mu2) + mu1 * exp(I * t) + mu2 * exp(-I * t)) + + def _moment_generating_function(self, t): + (mu1, mu2) = (self.mu1, self.mu2) + return exp(-(mu1 + mu2) + mu1 * exp(t) + mu2 * exp(-t)) + + +def Skellam(name, mu1, mu2): + r""" + Create a discrete random variable with a Skellam distribution. + + Explanation + =========== + + The Skellam is the distribution of the difference N1 - N2 + of two statistically independent random variables N1 and N2 + each Poisson-distributed with respective expected values mu1 and mu2. + + The density of the Skellam distribution is given by + + .. math:: + f(k) := e^{-(\mu_1+\mu_2)}(\frac{\mu_1}{\mu_2})^{k/2}I_k(2\sqrt{\mu_1\mu_2}) + + Parameters + ========== + + mu1 : A non-negative value + mu2 : A non-negative value + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Skellam, density, E, variance + >>> from sympy import Symbol, pprint + + >>> z = Symbol("z", integer=True) + >>> mu1 = Symbol("mu1", positive=True) + >>> mu2 = Symbol("mu2", positive=True) + >>> X = Skellam("x", mu1, mu2) + + >>> pprint(density(X)(z), use_unicode=False) + z + - + 2 + /mu1\ -mu1 - mu2 / _____ _____\ + |---| *e *besseli\z, 2*\/ mu1 *\/ mu2 / + \mu2/ + >>> E(X) + mu1 - mu2 + >>> variance(X).expand() + mu1 + mu2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Skellam_distribution + + """ + return rv(name, SkellamDistribution, mu1, mu2) + + +#------------------------------------------------------------------------------- +# Yule-Simon distribution ------------------------------------------------------------ + +class YuleSimonDistribution(SingleDiscreteDistribution): + _argnames = ('rho',) + set = S.Naturals + + @staticmethod + def check(rho): + _value_check(rho > 0, 'rho should be positive') + + def pdf(self, k): + rho = self.rho + return rho * beta(k, rho + 1) + + def _cdf(self, x): + return Piecewise((1 - floor(x) * beta(floor(x), self.rho + 1), x >= 1), (0, True)) + + def _characteristic_function(self, t): + rho = self.rho + return rho * hyper((1, 1), (rho + 2,), exp(I*t)) * exp(I*t) / (rho + 1) + + def _moment_generating_function(self, t): + rho = self.rho + return rho * hyper((1, 1), (rho + 2,), exp(t)) * exp(t) / (rho + 1) + + +def YuleSimon(name, rho): + r""" + Create a discrete random variable with a Yule-Simon distribution. + + Explanation + =========== + + The density of the Yule-Simon distribution is given by + + .. math:: + f(k) := \rho B(k, \rho + 1) + + Parameters + ========== + + rho : A positive value + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import YuleSimon, density, E, variance + >>> from sympy import Symbol, simplify + + >>> p = 5 + >>> z = Symbol("z") + + >>> X = YuleSimon("x", p) + + >>> density(X)(z) + 5*beta(z, 6) + + >>> simplify(E(X)) + 5/4 + + >>> simplify(variance(X)) + 25/48 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Yule%E2%80%93Simon_distribution + + """ + return rv(name, YuleSimonDistribution, rho) + + +#------------------------------------------------------------------------------- +# Zeta distribution ------------------------------------------------------------ + +class ZetaDistribution(SingleDiscreteDistribution): + _argnames = ('s',) + set = S.Naturals + + @staticmethod + def check(s): + _value_check(s > 1, 's should be greater than 1') + + def pdf(self, k): + s = self.s + return 1 / (k**s * zeta(s)) + + def _characteristic_function(self, t): + return polylog(self.s, exp(I*t)) / zeta(self.s) + + def _moment_generating_function(self, t): + return polylog(self.s, exp(t)) / zeta(self.s) + + +def Zeta(name, s): + r""" + Create a discrete random variable with a Zeta distribution. + + Explanation + =========== + + The density of the Zeta distribution is given by + + .. math:: + f(k) := \frac{1}{k^s \zeta{(s)}} + + Parameters + ========== + + s : A value greater than 1 + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import Zeta, density, E, variance + >>> from sympy import Symbol + + >>> s = 5 + >>> z = Symbol("z") + + >>> X = Zeta("x", s) + + >>> density(X)(z) + 1/(z**5*zeta(5)) + + >>> E(X) + pi**4/(90*zeta(5)) + + >>> variance(X) + -pi**8/(8100*zeta(5)**2) + zeta(3)/zeta(5) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Zeta_distribution + + """ + return rv(name, ZetaDistribution, s) diff --git a/phivenv/Lib/site-packages/sympy/stats/error_prop.py b/phivenv/Lib/site-packages/sympy/stats/error_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..e6cacb894307fe60cbf096c7760e6ed57f385a91 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/error_prop.py @@ -0,0 +1,100 @@ +"""Tools for arithmetic error propagation.""" + +from itertools import repeat, combinations + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.simplify.simplify import simplify +from sympy.stats.symbolic_probability import RandomSymbol, Variance, Covariance +from sympy.stats.rv import is_random + +_arg0_or_var = lambda var: var.args[0] if len(var.args) > 0 else var + + +def variance_prop(expr, consts=(), include_covar=False): + r"""Symbolically propagates variance (`\sigma^2`) for expressions. + This is computed as as seen in [1]_. + + Parameters + ========== + + expr : Expr + A SymPy expression to compute the variance for. + consts : sequence of Symbols, optional + Represents symbols that are known constants in the expr, + and thus have zero variance. All symbols not in consts are + assumed to be variant. + include_covar : bool, optional + Flag for whether or not to include covariances, default=False. + + Returns + ======= + + var_expr : Expr + An expression for the total variance of the expr. + The variance for the original symbols (e.g. x) are represented + via instance of the Variance symbol (e.g. Variance(x)). + + Examples + ======== + + >>> from sympy import symbols, exp + >>> from sympy.stats.error_prop import variance_prop + >>> x, y = symbols('x y') + + >>> variance_prop(x + y) + Variance(x) + Variance(y) + + >>> variance_prop(x * y) + x**2*Variance(y) + y**2*Variance(x) + + >>> variance_prop(exp(2*x)) + 4*exp(4*x)*Variance(x) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Propagation_of_uncertainty + + """ + args = expr.args + if len(args) == 0: + if expr in consts: + return S.Zero + elif is_random(expr): + return Variance(expr).doit() + elif isinstance(expr, Symbol): + return Variance(RandomSymbol(expr)).doit() + else: + return S.Zero + nargs = len(args) + var_args = list(map(variance_prop, args, repeat(consts, nargs), + repeat(include_covar, nargs))) + if isinstance(expr, Add): + var_expr = Add(*var_args) + if include_covar: + terms = [2 * Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand() \ + for x, y in combinations(var_args, 2)] + var_expr += Add(*terms) + elif isinstance(expr, Mul): + terms = [v/a**2 for a, v in zip(args, var_args)] + var_expr = simplify(expr**2 * Add(*terms)) + if include_covar: + terms = [2*Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand()/(a*b) \ + for (a, b), (x, y) in zip(combinations(args, 2), + combinations(var_args, 2))] + var_expr += Add(*terms) + elif isinstance(expr, Pow): + b = args[1] + v = var_args[0] * (expr * b / args[0])**2 + var_expr = simplify(v) + elif isinstance(expr, exp): + var_expr = simplify(var_args[0] * expr**2) + else: + # unknown how to proceed, return variance of whole expr. + var_expr = Variance(expr) + return var_expr diff --git a/phivenv/Lib/site-packages/sympy/stats/frv.py b/phivenv/Lib/site-packages/sympy/stats/frv.py new file mode 100644 index 0000000000000000000000000000000000000000..498d7e4006b2b8db306a0905ed67578021e220a8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/frv.py @@ -0,0 +1,512 @@ +""" +Finite Discrete Random Variables Module + +See Also +======== +sympy.stats.frv_types +sympy.stats.rv +sympy.stats.crv +""" +from itertools import product + +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import (I, nan) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import (And, Or) +from sympy.sets.sets import Intersection +from sympy.core.containers import Dict +from sympy.core.logic import Logic +from sympy.core.relational import Relational +from sympy.core.sympify import _sympify +from sympy.sets.sets import FiniteSet +from sympy.stats.rv import (RandomDomain, ProductDomain, ConditionalDomain, + PSpace, IndependentProductPSpace, SinglePSpace, random_symbols, + sumsets, rv_subs, NamedArgsMixin, Density, Distribution) + + +class FiniteDensity(dict): + """ + A domain with Finite Density. + """ + def __call__(self, item): + """ + Make instance of a class callable. + + If item belongs to current instance of a class, return it. + + Otherwise, return 0. + """ + item = sympify(item) + if item in self: + return self[item] + else: + return 0 + + @property + def dict(self): + """ + Return item as dictionary. + """ + return dict(self) + +class FiniteDomain(RandomDomain): + """ + A domain with discrete finite support + + Represented using a FiniteSet. + """ + is_Finite = True + + @property + def symbols(self): + return FiniteSet(sym for sym, val in self.elements) + + @property + def elements(self): + return self.args[0] + + @property + def dict(self): + return FiniteSet(*[Dict(dict(el)) for el in self.elements]) + + def __contains__(self, other): + return other in self.elements + + def __iter__(self): + return self.elements.__iter__() + + def as_boolean(self): + return Or(*[And(*[Eq(sym, val) for sym, val in item]) for item in self]) + + +class SingleFiniteDomain(FiniteDomain): + """ + A FiniteDomain over a single symbol/set + + Example: The possibilities of a *single* die roll. + """ + + def __new__(cls, symbol, set): + if not isinstance(set, FiniteSet) and \ + not isinstance(set, Intersection): + set = FiniteSet(*set) + return Basic.__new__(cls, symbol, set) + + @property + def symbol(self): + return self.args[0] + + @property + def symbols(self): + return FiniteSet(self.symbol) + + @property + def set(self): + return self.args[1] + + @property + def elements(self): + return FiniteSet(*[frozenset(((self.symbol, elem), )) for elem in self.set]) + + def __iter__(self): + return (frozenset(((self.symbol, elem),)) for elem in self.set) + + def __contains__(self, other): + sym, val = tuple(other)[0] + return sym == self.symbol and val in self.set + + +class ProductFiniteDomain(ProductDomain, FiniteDomain): + """ + A Finite domain consisting of several other FiniteDomains + + Example: The possibilities of the rolls of three independent dice + """ + + def __iter__(self): + proditer = product(*self.domains) + return (sumsets(items) for items in proditer) + + @property + def elements(self): + return FiniteSet(*self) + + +class ConditionalFiniteDomain(ConditionalDomain, ProductFiniteDomain): + """ + A FiniteDomain that has been restricted by a condition + + Example: The possibilities of a die roll under the condition that the + roll is even. + """ + + def __new__(cls, domain, condition): + """ + Create a new instance of ConditionalFiniteDomain class + """ + if condition is True: + return domain + cond = rv_subs(condition) + return Basic.__new__(cls, domain, cond) + + def _test(self, elem): + """ + Test the value. If value is boolean, return it. If value is equality + relational (two objects are equal), return it with left-hand side + being equal to right-hand side. Otherwise, raise ValueError exception. + """ + val = self.condition.xreplace(dict(elem)) + if val in [True, False]: + return val + elif val.is_Equality: + return val.lhs == val.rhs + raise ValueError("Undecidable if %s" % str(val)) + + def __contains__(self, other): + return other in self.fulldomain and self._test(other) + + def __iter__(self): + return (elem for elem in self.fulldomain if self._test(elem)) + + @property + def set(self): + if isinstance(self.fulldomain, SingleFiniteDomain): + return FiniteSet(*[elem for elem in self.fulldomain.set + if frozenset(((self.fulldomain.symbol, elem),)) in self]) + else: + raise NotImplementedError( + "Not implemented on multi-dimensional conditional domain") + + def as_boolean(self): + return FiniteDomain.as_boolean(self) + + +class SingleFiniteDistribution(Distribution, NamedArgsMixin): + def __new__(cls, *args): + args = list(map(sympify, args)) + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + @property # type: ignore + @cacheit + def dict(self): + if self.is_symbolic: + return Density(self) + return {k: self.pmf(k) for k in self.set} + + def pmf(self, *args): # to be overridden by specific distribution + raise NotImplementedError() + + @property + def set(self): # to be overridden by specific distribution + raise NotImplementedError() + + values = property(lambda self: self.dict.values) + items = property(lambda self: self.dict.items) + is_symbolic = property(lambda self: False) + __iter__ = property(lambda self: self.dict.__iter__) + __getitem__ = property(lambda self: self.dict.__getitem__) + + def __call__(self, *args): + return self.pmf(*args) + + def __contains__(self, other): + return other in self.set + + +#============================================= +#========= Probability Space =============== +#============================================= + + +class FinitePSpace(PSpace): + """ + A Finite Probability Space + + Represents the probabilities of a finite number of events. + """ + is_Finite = True + + def __new__(cls, domain, density): + density = {sympify(key): sympify(val) + for key, val in density.items()} + public_density = Dict(density) + + obj = PSpace.__new__(cls, domain, public_density) + obj._density = density + return obj + + def prob_of(self, elem): + elem = sympify(elem) + density = self._density + if isinstance(list(density.keys())[0], FiniteSet): + return density.get(elem, S.Zero) + return density.get(tuple(elem)[0][1], S.Zero) + + def where(self, condition): + assert all(r.symbol in self.symbols for r in random_symbols(condition)) + return ConditionalFiniteDomain(self.domain, condition) + + def compute_density(self, expr): + expr = rv_subs(expr, self.values) + d = FiniteDensity() + for elem in self.domain: + val = expr.xreplace(dict(elem)) + prob = self.prob_of(elem) + d[val] = d.get(val, S.Zero) + prob + return d + + @cacheit + def compute_cdf(self, expr): + d = self.compute_density(expr) + cum_prob = S.Zero + cdf = [] + for key in sorted(d): + prob = d[key] + cum_prob += prob + cdf.append((key, cum_prob)) + + return dict(cdf) + + @cacheit + def sorted_cdf(self, expr, python_float=False): + cdf = self.compute_cdf(expr) + items = list(cdf.items()) + sorted_items = sorted(items, key=lambda val_cumprob: val_cumprob[1]) + if python_float: + sorted_items = [(v, float(cum_prob)) + for v, cum_prob in sorted_items] + return sorted_items + + @cacheit + def compute_characteristic_function(self, expr): + d = self.compute_density(expr) + t = Dummy('t', real=True) + + return Lambda(t, sum(exp(I*k*t)*v for k,v in d.items())) + + @cacheit + def compute_moment_generating_function(self, expr): + d = self.compute_density(expr) + t = Dummy('t', real=True) + + return Lambda(t, sum(exp(k*t)*v for k,v in d.items())) + + def compute_expectation(self, expr, rvs=None, **kwargs): + rvs = rvs or self.values + expr = rv_subs(expr, rvs) + probs = [self.prob_of(elem) for elem in self.domain] + if isinstance(expr, (Logic, Relational)): + parse_domain = [tuple(elem)[0][1] for elem in self.domain] + bools = [expr.xreplace(dict(elem)) for elem in self.domain] + else: + parse_domain = [expr.xreplace(dict(elem)) for elem in self.domain] + bools = [True for elem in self.domain] + return sum(Piecewise((prob * elem, blv), (S.Zero, True)) + for prob, elem, blv in zip(probs, parse_domain, bools)) + + def compute_quantile(self, expr): + cdf = self.compute_cdf(expr) + p = Dummy('p', real=True) + set = ((nan, (p < 0) | (p > 1)),) + for key, value in cdf.items(): + set = set + ((key, p <= value), ) + return Lambda(p, Piecewise(*set)) + + def probability(self, condition): + cond_symbols = frozenset(rs.symbol for rs in random_symbols(condition)) + cond = rv_subs(condition) + if not cond_symbols.issubset(self.symbols): + raise ValueError("Cannot compare foreign random symbols, %s" + %(str(cond_symbols - self.symbols))) + if isinstance(condition, Relational) and \ + (not cond.free_symbols.issubset(self.domain.free_symbols)): + rv = condition.lhs if isinstance(condition.rhs, Symbol) else condition.rhs + return sum(Piecewise( + (self.prob_of(elem), condition.subs(rv, list(elem)[0][1])), + (S.Zero, True)) for elem in self.domain) + return sympify(sum(self.prob_of(elem) for elem in self.where(condition))) + + def conditional_space(self, condition): + domain = self.where(condition) + prob = self.probability(condition) + density = {key: val / prob + for key, val in self._density.items() if domain._test(key)} + return FinitePSpace(domain, density) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library, seed)} + + +class SingleFinitePSpace(SinglePSpace, FinitePSpace): + """ + A single finite probability space + + Represents the probabilities of a set of random events that can be + attributed to a single variable/symbol. + + This class is implemented by many of the standard FiniteRV types such as + Die, Bernoulli, Coin, etc.... + """ + @property + def domain(self): + return SingleFiniteDomain(self.symbol, self.distribution.set) + + @property + def _is_symbolic(self): + """ + Helper property to check if the distribution + of the random variable is having symbolic + dimension. + """ + return self.distribution.is_symbolic + + @property + def distribution(self): + return self.args[1] + + def pmf(self, expr): + return self.distribution.pmf(expr) + + @property # type: ignore + @cacheit + def _density(self): + return {FiniteSet((self.symbol, val)): prob + for val, prob in self.distribution.dict.items()} + + @cacheit + def compute_characteristic_function(self, expr): + if self._is_symbolic: + d = self.compute_density(expr) + t = Dummy('t', real=True) + ki = Dummy('ki') + return Lambda(t, Sum(d(ki)*exp(I*ki*t), (ki, self.args[1].low, self.args[1].high))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_characteristic_function(expr) + + @cacheit + def compute_moment_generating_function(self, expr): + if self._is_symbolic: + d = self.compute_density(expr) + t = Dummy('t', real=True) + ki = Dummy('ki') + return Lambda(t, Sum(d(ki)*exp(ki*t), (ki, self.args[1].low, self.args[1].high))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_moment_generating_function(expr) + + def compute_quantile(self, expr): + if self._is_symbolic: + raise NotImplementedError("Computing quantile for random variables " + "with symbolic dimension because the bounds of searching the required " + "value is undetermined.") + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_quantile(expr) + + def compute_density(self, expr): + if self._is_symbolic: + rv = list(random_symbols(expr))[0] + k = Dummy('k', integer=True) + cond = True if not isinstance(expr, (Relational, Logic)) \ + else expr.subs(rv, k) + return Lambda(k, + Piecewise((self.pmf(k), And(k >= self.args[1].low, + k <= self.args[1].high, cond)), (S.Zero, True))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_density(expr) + + def compute_cdf(self, expr): + if self._is_symbolic: + d = self.compute_density(expr) + k = Dummy('k') + ki = Dummy('ki') + return Lambda(k, Sum(d(ki), (ki, self.args[1].low, k))) + expr = rv_subs(expr, self.values) + return FinitePSpace(self.domain, self.distribution).compute_cdf(expr) + + def compute_expectation(self, expr, rvs=None, **kwargs): + if self._is_symbolic: + rv = random_symbols(expr)[0] + k = Dummy('k', integer=True) + expr = expr.subs(rv, k) + cond = True if not isinstance(expr, (Relational, Logic)) \ + else expr + func = self.pmf(k) * k if cond != True else self.pmf(k) * expr + return Sum(Piecewise((func, cond), (S.Zero, True)), + (k, self.distribution.low, self.distribution.high)).doit() + + expr = _sympify(expr) + expr = rv_subs(expr, rvs) + return FinitePSpace(self.domain, self.distribution).compute_expectation(expr, rvs, **kwargs) + + def probability(self, condition): + if self._is_symbolic: + #TODO: Implement the mechanism for handling queries for symbolic sized distributions. + raise NotImplementedError("Currently, probability queries are not " + "supported for random variables with symbolic sized distributions.") + condition = rv_subs(condition) + return FinitePSpace(self.domain, self.distribution).probability(condition) + + def conditional_space(self, condition): + """ + This method is used for transferring the + computation to probability method because + conditional space of random variables with + symbolic dimensions is currently not possible. + """ + if self._is_symbolic: + self + domain = self.where(condition) + prob = self.probability(condition) + density = {key: val / prob + for key, val in self._density.items() if domain._test(key)} + return FinitePSpace(domain, density) + + +class ProductFinitePSpace(IndependentProductPSpace, FinitePSpace): + """ + A collection of several independent finite probability spaces + """ + @property + def domain(self): + return ProductFiniteDomain(*[space.domain for space in self.spaces]) + + @property # type: ignore + @cacheit + def _density(self): + proditer = product(*[iter(space._density.items()) + for space in self.spaces]) + d = {} + for items in proditer: + elems, probs = list(zip(*items)) + elem = sumsets(elems) + prob = Mul(*probs) + d[elem] = d.get(elem, S.Zero) + prob + return Dict(d) + + @property # type: ignore + @cacheit + def density(self): + return Dict(self._density) + + def probability(self, condition): + return FinitePSpace.probability(self, condition) + + def compute_density(self, expr): + return FinitePSpace.compute_density(self, expr) diff --git a/phivenv/Lib/site-packages/sympy/stats/frv_types.py b/phivenv/Lib/site-packages/sympy/stats/frv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..bde656c219791c287ff445d5d215e3759271e923 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/frv_types.py @@ -0,0 +1,873 @@ +""" +Finite Discrete Random Variables - Prebuilt variable types + +Contains +======== +FiniteRV +DiscreteUniform +Die +Bernoulli +Coin +Binomial +BetaBinomial +Hypergeometric +Rademacher +IdealSoliton +RobustSoliton +""" + + +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.numbers import (Integer, Rational) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import Or +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Intersection, Interval) +from sympy.functions.special.beta_functions import beta as beta_fn +from sympy.stats.frv import (SingleFiniteDistribution, + SingleFinitePSpace) +from sympy.stats.rv import _value_check, Density, is_random +from sympy.utilities.iterables import multiset +from sympy.utilities.misc import filldedent + + +__all__ = ['FiniteRV', +'DiscreteUniform', +'Die', +'Bernoulli', +'Coin', +'Binomial', +'BetaBinomial', +'Hypergeometric', +'Rademacher', +'IdealSoliton', +'RobustSoliton', +] + +def rv(name, cls, *args, **kwargs): + args = list(map(sympify, args)) + dist = cls(*args) + if kwargs.pop('check', True): + dist.check(*args) + pspace = SingleFinitePSpace(name, dist) + if any(is_random(arg) for arg in args): + from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution + pspace = CompoundPSpace(name, CompoundDistribution(dist)) + return pspace.value + +class FiniteDistributionHandmade(SingleFiniteDistribution): + + @property + def dict(self): + return self.args[0] + + def pmf(self, x): + x = Symbol('x') + return Lambda(x, Piecewise(*( + [(v, Eq(k, x)) for k, v in self.dict.items()] + [(S.Zero, True)]))) + + @property + def set(self): + return set(self.dict.keys()) + + @staticmethod + def check(density): + for p in density.values(): + _value_check((p >= 0, p <= 1), + "Probability at a point must be between 0 and 1.") + val = sum(density.values()) + _value_check(Eq(val, 1) != S.false, "Total Probability must be 1.") + +def FiniteRV(name, density, **kwargs): + r""" + Create a Finite Random Variable given a dict representing the density. + + Parameters + ========== + + name : Symbol + Represents name of the random variable. + density : dict + Dictionary containing the pdf of finite distribution + check : bool + If True, it will check whether the given density + integrates to 1 over the given set. If False, it + will not perform this check. Default is False. + + Examples + ======== + + >>> from sympy.stats import FiniteRV, P, E + + >>> density = {0: .1, 1: .2, 2: .3, 3: .4} + >>> X = FiniteRV('X', density) + + >>> E(X) + 2.00000000000000 + >>> P(X >= 2) + 0.700000000000000 + + Returns + ======= + + RandomSymbol + + """ + # have a default of False while `rv` should have a default of True + kwargs['check'] = kwargs.pop('check', False) + return rv(name, FiniteDistributionHandmade, density, **kwargs) + +class DiscreteUniformDistribution(SingleFiniteDistribution): + + @staticmethod + def check(*args): + # not using _value_check since there is a + # suggestion for the user + if len(set(args)) != len(args): + weights = multiset(args) + n = Integer(len(args)) + for k in weights: + weights[k] /= n + raise ValueError(filldedent(""" + Repeated args detected but set expected. For a + distribution having different weights for each + item use the following:""") + ( + '\nS("FiniteRV(%s, %s)")' % ("'X'", weights))) + + @property + def p(self): + return Rational(1, len(self.args)) + + @property # type: ignore + @cacheit + def dict(self): + return dict.fromkeys(self.set, self.p) + + @property + def set(self): + return set(self.args) + + def pmf(self, x): + if x in self.args: + return self.p + else: + return S.Zero + + +def DiscreteUniform(name, items): + r""" + Create a Finite Random Variable representing a uniform distribution over + the input set. + + Parameters + ========== + + items : list/tuple + Items over which Uniform distribution is to be made + + Examples + ======== + + >>> from sympy.stats import DiscreteUniform, density + >>> from sympy import symbols + + >>> X = DiscreteUniform('X', symbols('a b c')) # equally likely over a, b, c + >>> density(X).dict + {a: 1/3, b: 1/3, c: 1/3} + + >>> Y = DiscreteUniform('Y', list(range(5))) # distribution over a range + >>> density(Y).dict + {0: 1/5, 1: 1/5, 2: 1/5, 3: 1/5, 4: 1/5} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Discrete_uniform_distribution + .. [2] https://mathworld.wolfram.com/DiscreteUniformDistribution.html + + """ + return rv(name, DiscreteUniformDistribution, *items) + + +class DieDistribution(SingleFiniteDistribution): + _argnames = ('sides',) + + @staticmethod + def check(sides): + _value_check((sides.is_positive, sides.is_integer), + "number of sides must be a positive integer.") + + @property + def is_symbolic(self): + return not self.sides.is_number + + @property + def high(self): + return self.sides + + @property + def low(self): + return S.One + + @property + def set(self): + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(0, self.sides)) + return set(map(Integer, range(1, self.sides + 1))) + + def pmf(self, x): + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + cond = Ge(x, 1) & Le(x, self.sides) & Contains(x, S.Integers) + return Piecewise((S.One/self.sides, cond), (S.Zero, True)) + +def Die(name, sides=6): + r""" + Create a Finite Random Variable representing a fair die. + + Parameters + ========== + + sides : Integer + Represents the number of sides of the Die, by default is 6 + + Examples + ======== + + >>> from sympy.stats import Die, density + >>> from sympy import Symbol + + >>> D6 = Die('D6', 6) # Six sided Die + >>> density(D6).dict + {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6} + + >>> D4 = Die('D4', 4) # Four sided Die + >>> density(D4).dict + {1: 1/4, 2: 1/4, 3: 1/4, 4: 1/4} + + >>> n = Symbol('n', positive=True, integer=True) + >>> Dn = Die('Dn', n) # n sided Die + >>> density(Dn).dict + Density(DieDistribution(n)) + >>> density(Dn).dict.subs(n, 4).doit() + {1: 1/4, 2: 1/4, 3: 1/4, 4: 1/4} + + Returns + ======= + + RandomSymbol + """ + + return rv(name, DieDistribution, sides) + + +class BernoulliDistribution(SingleFiniteDistribution): + _argnames = ('p', 'succ', 'fail') + + @staticmethod + def check(p, succ, fail): + _value_check((p >= 0, p <= 1), + "p should be in range [0, 1].") + + @property + def set(self): + return {self.succ, self.fail} + + def pmf(self, x): + if isinstance(self.succ, Symbol) and isinstance(self.fail, Symbol): + return Piecewise((self.p, x == self.succ), + (1 - self.p, x == self.fail), + (S.Zero, True)) + return Piecewise((self.p, Eq(x, self.succ)), + (1 - self.p, Eq(x, self.fail)), + (S.Zero, True)) + + +def Bernoulli(name, p, succ=1, fail=0): + r""" + Create a Finite Random Variable representing a Bernoulli process. + + Parameters + ========== + + p : Rational number between 0 and 1 + Represents probability of success + succ : Integer/symbol/string + Represents event of success + fail : Integer/symbol/string + Represents event of failure + + Examples + ======== + + >>> from sympy.stats import Bernoulli, density + >>> from sympy import S + + >>> X = Bernoulli('X', S(3)/4) # 1-0 Bernoulli variable, probability = 3/4 + >>> density(X).dict + {0: 1/4, 1: 3/4} + + >>> X = Bernoulli('X', S.Half, 'Heads', 'Tails') # A fair coin toss + >>> density(X).dict + {Heads: 1/2, Tails: 1/2} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bernoulli_distribution + .. [2] https://mathworld.wolfram.com/BernoulliDistribution.html + + """ + + return rv(name, BernoulliDistribution, p, succ, fail) + + +def Coin(name, p=S.Half): + r""" + Create a Finite Random Variable representing a Coin toss. + + This is an equivalent of a Bernoulli random variable with + "H" and "T" as success and failure events respectively. + + Parameters + ========== + + p : Rational Number between 0 and 1 + Represents probability of getting "Heads", by default is Half + + Examples + ======== + + >>> from sympy.stats import Coin, density + >>> from sympy import Rational + + >>> C = Coin('C') # A fair coin toss + >>> density(C).dict + {H: 1/2, T: 1/2} + + >>> C2 = Coin('C2', Rational(3, 5)) # An unfair coin + >>> density(C2).dict + {H: 3/5, T: 2/5} + + Returns + ======= + + RandomSymbol + + See Also + ======== + + sympy.stats.Binomial + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Coin_flipping + + """ + return rv(name, BernoulliDistribution, p, 'H', 'T') + + +class BinomialDistribution(SingleFiniteDistribution): + _argnames = ('n', 'p', 'succ', 'fail') + + @staticmethod + def check(n, p, succ, fail): + _value_check((n.is_integer, n.is_nonnegative), + "'n' must be nonnegative integer.") + _value_check((p <= 1, p >= 0), + "p should be in range [0, 1].") + + @property + def high(self): + return self.n + + @property + def low(self): + return S.Zero + + @property + def is_symbolic(self): + return not self.n.is_number + + @property + def set(self): + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(0, self.n)) + return set(self.dict.keys()) + + def pmf(self, x): + n, p = self.n, self.p + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + cond = Ge(x, 0) & Le(x, n) & Contains(x, S.Integers) + return Piecewise((binomial(n, x) * p**x * (1 - p)**(n - x), cond), (S.Zero, True)) + + @property # type: ignore + @cacheit + def dict(self): + if self.is_symbolic: + return Density(self) + return {k*self.succ + (self.n-k)*self.fail: self.pmf(k) + for k in range(0, self.n + 1)} + + +def Binomial(name, n, p, succ=1, fail=0): + r""" + Create a Finite Random Variable representing a binomial distribution. + + Parameters + ========== + + n : Positive Integer + Represents number of trials + p : Rational Number between 0 and 1 + Represents probability of success + succ : Integer/symbol/string + Represents event of success, by default is 1 + fail : Integer/symbol/string + Represents event of failure, by default is 0 + + Examples + ======== + + >>> from sympy.stats import Binomial, density + >>> from sympy import S, Symbol + + >>> X = Binomial('X', 4, S.Half) # Four "coin flips" + >>> density(X).dict + {0: 1/16, 1: 1/4, 2: 3/8, 3: 1/4, 4: 1/16} + + >>> n = Symbol('n', positive=True, integer=True) + >>> p = Symbol('p', positive=True) + >>> X = Binomial('X', n, S.Half) # n "coin flips" + >>> density(X).dict + Density(BinomialDistribution(n, 1/2, 1, 0)) + >>> density(X).dict.subs(n, 4).doit() + {0: 1/16, 1: 1/4, 2: 3/8, 3: 1/4, 4: 1/16} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Binomial_distribution + .. [2] https://mathworld.wolfram.com/BinomialDistribution.html + + """ + + return rv(name, BinomialDistribution, n, p, succ, fail) + +#------------------------------------------------------------------------------- +# Beta-binomial distribution ---------------------------------------------------------- + +class BetaBinomialDistribution(SingleFiniteDistribution): + _argnames = ('n', 'alpha', 'beta') + + @staticmethod + def check(n, alpha, beta): + _value_check((n.is_integer, n.is_nonnegative), + "'n' must be nonnegative integer. n = %s." % str(n)) + _value_check((alpha > 0), + "'alpha' must be: alpha > 0 . alpha = %s" % str(alpha)) + _value_check((beta > 0), + "'beta' must be: beta > 0 . beta = %s" % str(beta)) + + @property + def high(self): + return self.n + + @property + def low(self): + return S.Zero + + @property + def is_symbolic(self): + return not self.n.is_number + + @property + def set(self): + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(0, self.n)) + return set(map(Integer, range(self.n + 1))) + + def pmf(self, k): + n, a, b = self.n, self.alpha, self.beta + return binomial(n, k) * beta_fn(k + a, n - k + b) / beta_fn(a, b) + + +def BetaBinomial(name, n, alpha, beta): + r""" + Create a Finite Random Variable representing a Beta-binomial distribution. + + Parameters + ========== + + n : Positive Integer + Represents number of trials + alpha : Real positive number + beta : Real positive number + + Examples + ======== + + >>> from sympy.stats import BetaBinomial, density + + >>> X = BetaBinomial('X', 2, 1, 1) + >>> density(X).dict + {0: 1/3, 1: 2*beta(2, 2), 2: 1/3} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta-binomial_distribution + .. [2] https://mathworld.wolfram.com/BetaBinomialDistribution.html + + """ + + return rv(name, BetaBinomialDistribution, n, alpha, beta) + + +class HypergeometricDistribution(SingleFiniteDistribution): + _argnames = ('N', 'm', 'n') + + @staticmethod + def check(n, N, m): + _value_check((N.is_integer, N.is_nonnegative), + "'N' must be nonnegative integer. N = %s." % str(N)) + _value_check((n.is_integer, n.is_nonnegative), + "'n' must be nonnegative integer. n = %s." % str(n)) + _value_check((m.is_integer, m.is_nonnegative), + "'m' must be nonnegative integer. m = %s." % str(m)) + + @property + def is_symbolic(self): + return not all(x.is_number for x in (self.N, self.m, self.n)) + + @property + def high(self): + return Piecewise((self.n, Lt(self.n, self.m) != False), (self.m, True)) + + @property + def low(self): + return Piecewise((0, Gt(0, self.n + self.m - self.N) != False), (self.n + self.m - self.N, True)) + + @property + def set(self): + N, m, n = self.N, self.m, self.n + if self.is_symbolic: + return Intersection(S.Naturals0, Interval(self.low, self.high)) + return set(range(max(0, n + m - N), min(n, m) + 1)) + + def pmf(self, k): + N, m, n = self.N, self.m, self.n + return S(binomial(m, k) * binomial(N - m, n - k))/binomial(N, n) + + +def Hypergeometric(name, N, m, n): + r""" + Create a Finite Random Variable representing a hypergeometric distribution. + + Parameters + ========== + + N : Positive Integer + Represents finite population of size N. + m : Positive Integer + Represents number of trials with required feature. + n : Positive Integer + Represents numbers of draws. + + + Examples + ======== + + >>> from sympy.stats import Hypergeometric, density + + >>> X = Hypergeometric('X', 10, 5, 3) # 10 marbles, 5 white (success), 3 draws + >>> density(X).dict + {0: 1/12, 1: 5/12, 2: 5/12, 3: 1/12} + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hypergeometric_distribution + .. [2] https://mathworld.wolfram.com/HypergeometricDistribution.html + + """ + return rv(name, HypergeometricDistribution, N, m, n) + + +class RademacherDistribution(SingleFiniteDistribution): + + @property + def set(self): + return {-1, 1} + + @property + def pmf(self): + k = Dummy('k') + return Lambda(k, Piecewise((S.Half, Or(Eq(k, -1), Eq(k, 1))), (S.Zero, True))) + +def Rademacher(name): + r""" + Create a Finite Random Variable representing a Rademacher distribution. + + Examples + ======== + + >>> from sympy.stats import Rademacher, density + + >>> X = Rademacher('X') + >>> density(X).dict + {-1: 1/2, 1: 1/2} + + Returns + ======= + + RandomSymbol + + See Also + ======== + + sympy.stats.Bernoulli + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rademacher_distribution + + """ + return rv(name, RademacherDistribution) + +class IdealSolitonDistribution(SingleFiniteDistribution): + _argnames = ('k',) + + @staticmethod + def check(k): + _value_check(k.is_integer and k.is_positive, + "'k' must be a positive integer.") + + @property + def low(self): + return S.One + + @property + def high(self): + return self.k + + @property + def set(self): + return set(map(Integer, range(1, self.k + 1))) + + @property # type: ignore + @cacheit + def dict(self): + if self.k.is_Symbol: + return Density(self) + d = {1: Rational(1, self.k)} + d.update({i: Rational(1, i*(i - 1)) for i in range(2, self.k + 1)}) + return d + + def pmf(self, x): + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + cond1 = Eq(x, 1) & x.is_integer + cond2 = Ge(x, 1) & Le(x, self.k) & x.is_integer + return Piecewise((1/self.k, cond1), (1/(x*(x - 1)), cond2), (S.Zero, True)) + +def IdealSoliton(name, k): + r""" + Create a Finite Random Variable of Ideal Soliton Distribution + + Parameters + ========== + + k : Positive Integer + Represents the number of input symbols in an LT (Luby Transform) code. + + Examples + ======== + + >>> from sympy.stats import IdealSoliton, density, P, E + >>> sol = IdealSoliton('sol', 5) + >>> density(sol).dict + {1: 1/5, 2: 1/2, 3: 1/6, 4: 1/12, 5: 1/20} + >>> density(sol).set + {1, 2, 3, 4, 5} + + >>> from sympy import Symbol + >>> k = Symbol('k', positive=True, integer=True) + >>> sol = IdealSoliton('sol', k) + >>> density(sol).dict + Density(IdealSolitonDistribution(k)) + >>> density(sol).dict.subs(k, 10).doit() + {1: 1/10, 2: 1/2, 3: 1/6, 4: 1/12, 5: 1/20, 6: 1/30, 7: 1/42, 8: 1/56, 9: 1/72, 10: 1/90} + + >>> E(sol.subs(k, 10)) + 7381/2520 + + >>> P(sol.subs(k, 4) > 2) + 1/4 + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Soliton_distribution#Ideal_distribution + .. [2] https://pages.cs.wisc.edu/~suman/courses/740/papers/luby02lt.pdf + + """ + return rv(name, IdealSolitonDistribution, k) + +class RobustSolitonDistribution(SingleFiniteDistribution): + _argnames= ('k', 'delta', 'c') + + @staticmethod + def check(k, delta, c): + _value_check(k.is_integer and k.is_positive, + "'k' must be a positive integer") + _value_check(Gt(delta, 0) and Le(delta, 1), + "'delta' must be a real number in the interval (0,1)") + _value_check(c.is_positive, + "'c' must be a positive real number.") + + @property + def R(self): + return self.c * log(self.k/self.delta) * self.k**0.5 + + @property + def Z(self): + z = 0 + for i in Range(1, round(self.k/self.R)): + z += (1/i) + z += log(self.R/self.delta) + return 1 + z * self.R/self.k + + @property + def low(self): + return S.One + + @property + def high(self): + return self.k + + @property + def set(self): + return set(map(Integer, range(1, self.k + 1))) + + @property + def is_symbolic(self): + return not (self.k.is_number and self.c.is_number and self.delta.is_number) + + def pmf(self, x): + x = sympify(x) + if not (x.is_number or x.is_Symbol or is_random(x)): + raise ValueError("'x' expected as an argument of type 'number', 'Symbol', or " + "'RandomSymbol' not %s" % (type(x))) + + cond1 = Eq(x, 1) & x.is_integer + cond2 = Ge(x, 1) & Le(x, self.k) & x.is_integer + rho = Piecewise((Rational(1, self.k), cond1), (Rational(1, x*(x-1)), cond2), (S.Zero, True)) + + cond1 = Ge(x, 1) & Le(x, round(self.k/self.R)-1) + cond2 = Eq(x, round(self.k/self.R)) + tau = Piecewise((self.R/(self.k * x), cond1), (self.R * log(self.R/self.delta)/self.k, cond2), (S.Zero, True)) + + return (rho + tau)/self.Z + +def RobustSoliton(name, k, delta, c): + r''' + Create a Finite Random Variable of Robust Soliton Distribution + + Parameters + ========== + + k : Positive Integer + Represents the number of input symbols in an LT (Luby Transform) code. + delta : Positive Rational Number + Represents the failure probability. Must be in the interval (0,1). + c : Positive Rational Number + Constant of proportionality. Values close to 1 are recommended + + Examples + ======== + + >>> from sympy.stats import RobustSoliton, density, P, E + >>> robSol = RobustSoliton('robSol', 5, 0.5, 0.01) + >>> density(robSol).dict + {1: 0.204253668152708, 2: 0.490631107897393, 3: 0.165210624506162, 4: 0.0834387731899302, 5: 0.0505633404760675} + >>> density(robSol).set + {1, 2, 3, 4, 5} + + >>> from sympy import Symbol + >>> k = Symbol('k', positive=True, integer=True) + >>> c = Symbol('c', positive=True) + >>> robSol = RobustSoliton('robSol', k, 0.5, c) + >>> density(robSol).dict + Density(RobustSolitonDistribution(k, 0.5, c)) + >>> density(robSol).dict.subs(k, 10).subs(c, 0.03).doit() + {1: 0.116641095387194, 2: 0.467045731687165, 3: 0.159984123349381, 4: 0.0821431680681869, 5: 0.0505765646770100, + 6: 0.0345781523420719, 7: 0.0253132820710503, 8: 0.0194459129233227, 9: 0.0154831166726115, 10: 0.0126733075238887} + + >>> E(robSol.subs(k, 10).subs(c, 0.05)) + 2.91358846104106 + + >>> P(robSol.subs(k, 4).subs(c, 0.1) > 2) + 0.243650614389834 + + Returns + ======= + + RandomSymbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Soliton_distribution#Robust_distribution + .. [2] https://www.inference.org.uk/mackay/itprnn/ps/588.596.pdf + .. [3] https://pages.cs.wisc.edu/~suman/courses/740/papers/luby02lt.pdf + + ''' + return rv(name, RobustSolitonDistribution, k, delta, c) diff --git a/phivenv/Lib/site-packages/sympy/stats/joint_rv.py b/phivenv/Lib/site-packages/sympy/stats/joint_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..d147942f08b998e167b246628360fa27fc8ef348 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/joint_rv.py @@ -0,0 +1,426 @@ +""" +Joint Random Variables Module + +See Also +======== +sympy.stats.rv +sympy.stats.frv +sympy.stats.crv +sympy.stats.drv +""" +from math import prod + +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.sets.sets import ProductSet +from sympy.tensor.indexed import Indexed +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum, summation +from sympy.core.containers import Tuple +from sympy.integrals.integrals import Integral, integrate +from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy +from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace +from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace +from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution, + ProductDomain, RandomSymbol, random_symbols, + SingleDomain, _symbol_converter) +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import filldedent +from sympy.external import import_module + +# __all__ = ['marginal_distribution'] + +class JointPSpace(ProductPSpace): + """ + Represents a joint probability space. Represented using symbols for + each component and a distribution. + """ + def __new__(cls, sym, dist): + if isinstance(dist, SingleContinuousDistribution): + return SingleContinuousPSpace(sym, dist) + if isinstance(dist, SingleDiscreteDistribution): + return SingleDiscretePSpace(sym, dist) + sym = _symbol_converter(sym) + return Basic.__new__(cls, sym, dist) + + @property + def set(self): + return self.domain.set + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[1] + + @property + def value(self): + return JointRandomSymbol(self.symbol, self) + + @property + def component_count(self): + _set = self.distribution.set + if isinstance(_set, ProductSet): + return S(len(_set.args)) + elif isinstance(_set, Product): + return _set.limits[0][-1] + return S.One + + @property + def pdf(self): + sym = [Indexed(self.symbol, i) for i in range(self.component_count)] + return self.distribution(*sym) + + @property + def domain(self): + rvs = random_symbols(self.distribution) + if not rvs: + return SingleDomain(self.symbol, self.distribution.set) + return ProductDomain(*[rv.pspace.domain for rv in rvs]) + + def component_domain(self, index): + return self.set.args[index] + + def marginal_distribution(self, *indices): + count = self.component_count + if count.atoms(Symbol): + raise ValueError("Marginal distributions cannot be computed " + "for symbolic dimensions. It is a work under progress.") + orig = [Indexed(self.symbol, i) for i in range(count)] + all_syms = [Symbol(str(i)) for i in orig] + replace_dict = dict(zip(all_syms, orig)) + sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices) + limits = [[i,] for i in all_syms if i not in sym] + index = 0 + for i in range(count): + if i not in indices: + limits[index].append(self.distribution.set.args[i]) + limits[index] = tuple(limits[index]) + index += 1 + if self.distribution.is_Continuous: + f = Lambda(sym, integrate(self.distribution(*all_syms), *limits)) + elif self.distribution.is_Discrete: + f = Lambda(sym, summation(self.distribution(*all_syms), *limits)) + return f.xreplace(replace_dict) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + syms = tuple(self.value[i] for i in range(self.component_count)) + rvs = rvs or syms + if not any(i in rvs for i in syms): + return expr + expr = expr*self.pdf + for rv in rvs: + if isinstance(rv, Indexed): + expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])}) + elif isinstance(rv, RandomSymbol): + expr = expr.xreplace({rv: rv.symbol}) + if self.value in random_symbols(expr): + raise NotImplementedError(filldedent(''' + Expectations of expression with unindexed joint random symbols + cannot be calculated yet.''')) + limits = tuple((Indexed(str(rv.base),rv.args[1]), + self.distribution.set.args[rv.args[1]]) for rv in syms) + return Integral(expr, *limits) + + def where(self, condition): + raise NotImplementedError() + + def compute_density(self, expr): + raise NotImplementedError() + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomSymbol to realization value. + """ + return {RandomSymbol(self.symbol, self): self.distribution.sample(size, + library=library, seed=seed)} + + def probability(self, condition): + raise NotImplementedError() + + +class SampleJointScipy: + """Returns the sample from scipy of the given distribution""" + def __new__(cls, dist, size, seed=None): + return cls._sample_scipy(dist, size, seed) + + @classmethod + def _sample_scipy(cls, dist, size, seed): + """Sample from SciPy.""" + + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + from scipy import stats as scipy_stats + scipy_rv_map = { + 'MultivariateNormalDistribution': lambda dist, size: scipy_stats.multivariate_normal.rvs( + mean=matrix2numpy(dist.mu).flatten(), + cov=matrix2numpy(dist.sigma), size=size, random_state=rand_state), + 'MultivariateBetaDistribution': lambda dist, size: scipy_stats.dirichlet.rvs( + alpha=list2numpy(dist.alpha, float).flatten(), size=size, random_state=rand_state), + 'MultinomialDistribution': lambda dist, size: scipy_stats.multinomial.rvs( + n=int(dist.n), p=list2numpy(dist.p, float).flatten(), size=size, random_state=rand_state) + } + + sample_shape = { + 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, + 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, + 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape + } + + dist_list = scipy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + samples = scipy_rv_map[dist.__class__.__name__](dist, size) + return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) + +class SampleJointNumpy: + """Returns the sample from numpy of the given distribution""" + + def __new__(cls, dist, size, seed=None): + return cls._sample_numpy(dist, size, seed) + + @classmethod + def _sample_numpy(cls, dist, size, seed): + """Sample from NumPy.""" + + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + numpy_rv_map = { + 'MultivariateNormalDistribution': lambda dist, size: rand_state.multivariate_normal( + mean=matrix2numpy(dist.mu, float).flatten(), + cov=matrix2numpy(dist.sigma, float), size=size), + 'MultivariateBetaDistribution': lambda dist, size: rand_state.dirichlet( + alpha=list2numpy(dist.alpha, float).flatten(), size=size), + 'MultinomialDistribution': lambda dist, size: rand_state.multinomial( + n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(), size=size) + } + + sample_shape = { + 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, + 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, + 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape + } + + dist_list = numpy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + samples = numpy_rv_map[dist.__class__.__name__](dist, prod(size)) + return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) + +class SampleJointPymc: + """Returns the sample from pymc of the given distribution""" + + def __new__(cls, dist, size, seed=None): + return cls._sample_pymc(dist, size, seed) + + @classmethod + def _sample_pymc(cls, dist, size, seed): + """Sample from PyMC.""" + + try: + import pymc + except ImportError: + import pymc3 as pymc + pymc_rv_map = { + 'MultivariateNormalDistribution': lambda dist: + pymc.MvNormal('X', mu=matrix2numpy(dist.mu, float).flatten(), + cov=matrix2numpy(dist.sigma, float), shape=(1, dist.mu.shape[0])), + 'MultivariateBetaDistribution': lambda dist: + pymc.Dirichlet('X', a=list2numpy(dist.alpha, float).flatten()), + 'MultinomialDistribution': lambda dist: + pymc.Multinomial('X', n=int(dist.n), + p=list2numpy(dist.p, float).flatten(), shape=(1, len(dist.p))) + } + + sample_shape = { + 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, + 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, + 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape + } + + dist_list = pymc_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + import logging + logging.getLogger("pymc3").setLevel(logging.ERROR) + with pymc.Model(): + pymc_rv_map[dist.__class__.__name__](dist) + samples = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)[:]['X'] + return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) + + +_get_sample_class_jrv = { + 'scipy': SampleJointScipy, + 'pymc3': SampleJointPymc, + 'pymc': SampleJointPymc, + 'numpy': SampleJointNumpy +} + +class JointDistribution(Distribution, NamedArgsMixin): + """ + Represented by the random variables part of the joint distribution. + Contains methods for PDF, CDF, sampling, marginal densities, etc. + """ + + _argnames = ('pdf', ) + + def __new__(cls, *args): + args = list(map(sympify, args)) + for i in range(len(args)): + if isinstance(args[i], list): + args[i] = ImmutableMatrix(args[i]) + return Basic.__new__(cls, *args) + + @property + def domain(self): + return ProductDomain(self.symbols) + + @property + def pdf(self): + return self.density.args[1] + + def cdf(self, other): + if not isinstance(other, dict): + raise ValueError("%s should be of type dict, got %s"%(other, type(other))) + rvs = other.keys() + _set = self.domain.set.sets + expr = self.pdf(tuple(i.args[0] for i in self.symbols)) + for i in range(len(other)): + if rvs[i].is_Continuous: + density = Integral(expr, (rvs[i], _set[i].inf, + other[rvs[i]])) + elif rvs[i].is_Discrete: + density = Sum(expr, (rvs[i], _set[i].inf, + other[rvs[i]])) + return density + + def sample(self, size=(), library='scipy', seed=None): + """ A random realization from the distribution """ + + libraries = ('scipy', 'numpy', 'pymc3', 'pymc') + if library not in libraries: + raise NotImplementedError("Sampling from %s is not supported yet." + % str(library)) + if not import_module(library): + raise ValueError("Failed to import %s" % library) + + samps = _get_sample_class_jrv[library](self, size, seed=seed) + + if samps is not None: + return samps + raise NotImplementedError( + "Sampling for %s is not currently implemented from %s" + % (self.__class__.__name__, library) + ) + + def __call__(self, *args): + return self.pdf(*args) + +class JointRandomSymbol(RandomSymbol): + """ + Representation of random symbols with joint probability distributions + to allow indexing." + """ + def __getitem__(self, key): + if isinstance(self.pspace, JointPSpace): + if (self.pspace.component_count <= key) == True: + raise ValueError("Index keys for %s can only up to %s." % + (self.name, self.pspace.component_count - 1)) + return Indexed(self, key) + + + +class MarginalDistribution(Distribution): + """ + Represents the marginal distribution of a joint probability space. + + Initialised using a probability distribution and random variables(or + their indexed components) which should be a part of the resultant + distribution. + """ + + def __new__(cls, dist, *rvs): + if len(rvs) == 1 and iterable(rvs[0]): + rvs = tuple(rvs[0]) + if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs): + raise ValueError(filldedent('''Marginal distribution can be + intitialised only in terms of random variables or indexed random + variables''')) + rvs = Tuple.fromiter(rv for rv in rvs) + if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0: + return dist + return Basic.__new__(cls, dist, rvs) + + def check(self): + pass + + @property + def set(self): + rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)] + return ProductSet(*[rv.pspace.set for rv in rvs]) + + @property + def symbols(self): + rvs = self.args[1] + return {rv.pspace.symbol for rv in rvs} + + def pdf(self, *x): + expr, rvs = self.args[0], self.args[1] + marginalise_out = [i for i in random_symbols(expr) if i not in rvs] + if isinstance(expr, JointDistribution): + count = len(expr.domain.args) + x = Dummy('x', real=True) + syms = tuple(Indexed(x, i) for i in count) + expr = expr.pdf(syms) + else: + syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs) + return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x) + + def compute_pdf(self, expr, rvs): + for rv in rvs: + lpdf = 1 + if isinstance(rv, RandomSymbol): + lpdf = rv.pspace.pdf + expr = self.marginalise_out(expr*lpdf, rv) + return expr + + def marginalise_out(self, expr, rv): + from sympy.concrete.summations import Sum + if isinstance(rv, RandomSymbol): + dom = rv.pspace.set + elif isinstance(rv, Indexed): + dom = rv.base.component_domain( + rv.pspace.component_domain(rv.args[1])) + expr = expr.xreplace({rv: rv.pspace.symbol}) + if rv.pspace.is_Continuous: + #TODO: Modify to support integration + #for all kinds of sets. + expr = Integral(expr, (rv.pspace.symbol, dom)) + elif rv.pspace.is_Discrete: + #incorporate this into `Sum`/`summation` + if dom in (S.Integers, S.Naturals, S.Naturals0): + dom = (dom.inf, dom.sup) + expr = Sum(expr, (rv.pspace.symbol, dom)) + return expr + + def __call__(self, *args): + return self.pdf(*args) diff --git a/phivenv/Lib/site-packages/sympy/stats/joint_rv_types.py b/phivenv/Lib/site-packages/sympy/stats/joint_rv_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6cee9f9aa30897593ffb7c7b930a55a38f0c518a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/joint_rv_types.py @@ -0,0 +1,945 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, Rational, pi) +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (rf, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besselk +from sympy.functions.special.gamma_functions import gamma +from sympy.matrices.dense import (Matrix, ones) +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Intersection, Interval) +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices import ImmutableMatrix, MatrixSymbol +from sympy.matrices.expressions.determinant import det +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.stats.joint_rv import JointDistribution, JointPSpace, MarginalDistribution +from sympy.stats.rv import _value_check, random_symbols + +__all__ = ['JointRV', +'MultivariateNormal', +'MultivariateLaplace', +'Dirichlet', +'GeneralizedMultivariateLogGamma', +'GeneralizedMultivariateLogGammaOmega', +'Multinomial', +'MultivariateBeta', +'MultivariateEwens', +'MultivariateT', +'NegativeMultinomial', +'NormalGamma' +] + +def multivariate_rv(cls, sym, *args): + args = list(map(sympify, args)) + dist = cls(*args) + args = dist.args + dist.check(*args) + return JointPSpace(sym, dist).value + + +def marginal_distribution(rv, *indices): + """ + Marginal distribution function of a joint random variable. + + Parameters + ========== + + rv : A random variable with a joint probability distribution. + indices : Component indices or the indexed random symbol + for which the joint distribution is to be calculated + + Returns + ======= + + A Lambda expression in `sym`. + + Examples + ======== + + >>> from sympy.stats import MultivariateNormal, marginal_distribution + >>> m = MultivariateNormal('X', [1, 2], [[2, 1], [1, 2]]) + >>> marginal_distribution(m, m[0])(1) + 1/(2*sqrt(pi)) + + """ + indices = list(indices) + for i in range(len(indices)): + if isinstance(indices[i], Indexed): + indices[i] = indices[i].args[1] + prob_space = rv.pspace + if not indices: + raise ValueError( + "At least one component for marginal density is needed.") + if hasattr(prob_space.distribution, '_marginal_distribution'): + return prob_space.distribution._marginal_distribution(indices, rv.symbol) + return prob_space.marginal_distribution(*indices) + + +class JointDistributionHandmade(JointDistribution): + + _argnames = ('pdf',) + is_Continuous = True + + @property + def set(self): + return self.args[1] + + +def JointRV(symbol, pdf, _set=None): + """ + Create a Joint Random Variable where each of its component is continuous, + given the following: + + Parameters + ========== + + symbol : Symbol + Represents name of the random variable. + pdf : A PDF in terms of indexed symbols of the symbol given + as the first argument + + NOTE + ==== + + As of now, the set for each component for a ``JointRV`` is + equal to the set of all integers, which cannot be changed. + + Examples + ======== + + >>> from sympy import exp, pi, Indexed, S + >>> from sympy.stats import density, JointRV + >>> x1, x2 = (Indexed('x', i) for i in (1, 2)) + >>> pdf = exp(-x1**2/2 + x1 - x2**2/2 - S(1)/2)/(2*pi) + >>> N1 = JointRV('x', pdf) #Multivariate Normal distribution + >>> density(N1)(1, 2) + exp(-2)/(2*pi) + + Returns + ======= + + RandomSymbol + + """ + #TODO: Add support for sets provided by the user + symbol = sympify(symbol) + syms = [i for i in pdf.free_symbols if isinstance(i, Indexed) + and i.base == IndexedBase(symbol)] + syms = tuple(sorted(syms, key = lambda index: index.args[1])) + _set = S.Reals**len(syms) + pdf = Lambda(syms, pdf) + dist = JointDistributionHandmade(pdf, _set) + jrv = JointPSpace(symbol, dist).value + rvs = random_symbols(pdf) + if len(rvs) != 0: + dist = MarginalDistribution(dist, (jrv,)) + return JointPSpace(symbol, dist).value + return jrv + +#------------------------------------------------------------------------------- +# Multivariate Normal distribution --------------------------------------------- + +class MultivariateNormalDistribution(JointDistribution): + _argnames = ('mu', 'sigma') + + is_Continuous=True + + @property + def set(self): + k = self.mu.shape[0] + return S.Reals**k + + @staticmethod + def check(mu, sigma): + _value_check(mu.shape[0] == sigma.shape[0], + "Size of the mean vector and covariance matrix are incorrect.") + #check if covariance matrix is positive semi definite or not. + if not isinstance(sigma, MatrixSymbol): + _value_check(sigma.is_positive_semidefinite, + "The covariance matrix must be positive semi definite. ") + + def pdf(self, *args): + mu, sigma = self.mu, self.sigma + k = mu.shape[0] + if len(args) == 1 and args[0].is_Matrix: + args = args[0] + else: + args = ImmutableMatrix(args) + x = args - mu + density = S.One/sqrt((2*pi)**(k)*det(sigma))*exp( + Rational(-1, 2)*x.transpose()*(sigma.inv()*x)) + return MatrixElement(density, 0, 0) + + def _marginal_distribution(self, indices, sym): + sym = ImmutableMatrix([Indexed(sym, i) for i in indices]) + _mu, _sigma = self.mu, self.sigma + k = self.mu.shape[0] + for i in range(k): + if i not in indices: + _mu = _mu.row_del(i) + _sigma = _sigma.col_del(i) + _sigma = _sigma.row_del(i) + return Lambda(tuple(sym), S.One/sqrt((2*pi)**(len(_mu))*det(_sigma))*exp( + Rational(-1, 2)*(_mu - sym).transpose()*(_sigma.inv()*\ + (_mu - sym)))[0]) + +def MultivariateNormal(name, mu, sigma): + r""" + Creates a continuous random variable with Multivariate Normal + Distribution. + + The density of the multivariate normal distribution can be found at [1]. + + Parameters + ========== + + mu : List representing the mean or the mean vector + sigma : Positive semidefinite square matrix + Represents covariance Matrix. + If `\sigma` is noninvertible then only sampling is supported currently + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import MultivariateNormal, density, marginal_distribution + >>> from sympy import symbols, MatrixSymbol + >>> X = MultivariateNormal('X', [3, 4], [[2, 1], [1, 2]]) + >>> y, z = symbols('y z') + >>> density(X)(y, z) + sqrt(3)*exp(-y**2/3 + y*z/3 + 2*y/3 - z**2/3 + 5*z/3 - 13/3)/(6*pi) + >>> density(X)(1, 2) + sqrt(3)*exp(-4/3)/(6*pi) + >>> marginal_distribution(X, X[1])(y) + exp(-(y - 4)**2/4)/(2*sqrt(pi)) + >>> marginal_distribution(X, X[0])(y) + exp(-(y - 3)**2/4)/(2*sqrt(pi)) + + The example below shows that it is also possible to use + symbolic parameters to define the MultivariateNormal class. + + >>> n = symbols('n', integer=True, positive=True) + >>> Sg = MatrixSymbol('Sg', n, n) + >>> mu = MatrixSymbol('mu', n, 1) + >>> obs = MatrixSymbol('obs', n, 1) + >>> X = MultivariateNormal('X', mu, Sg) + + The density of a multivariate normal can be + calculated using a matrix argument, as shown below. + + >>> density(X)(obs) + (exp(((1/2)*mu.T - (1/2)*obs.T)*Sg**(-1)*(-mu + obs))/sqrt((2*pi)**n*Determinant(Sg)))[0, 0] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multivariate_normal_distribution + + """ + return multivariate_rv(MultivariateNormalDistribution, name, mu, sigma) + +#------------------------------------------------------------------------------- +# Multivariate Laplace distribution -------------------------------------------- + +class MultivariateLaplaceDistribution(JointDistribution): + _argnames = ('mu', 'sigma') + is_Continuous=True + + @property + def set(self): + k = self.mu.shape[0] + return S.Reals**k + + @staticmethod + def check(mu, sigma): + _value_check(mu.shape[0] == sigma.shape[0], + "Size of the mean vector and covariance matrix are incorrect.") + # check if covariance matrix is positive definite or not. + if not isinstance(sigma, MatrixSymbol): + _value_check(sigma.is_positive_definite, + "The covariance matrix must be positive definite. ") + + def pdf(self, *args): + mu, sigma = self.mu, self.sigma + mu_T = mu.transpose() + k = S(mu.shape[0]) + sigma_inv = sigma.inv() + args = ImmutableMatrix(args) + args_T = args.transpose() + x = (mu_T*sigma_inv*mu)[0] + y = (args_T*sigma_inv*args)[0] + v = 1 - k/2 + return (2 * (y/(2 + x))**(v/2) * besselk(v, sqrt((2 + x)*y)) * + exp((args_T * sigma_inv * mu)[0]) / + ((2 * pi)**(k/2) * sqrt(det(sigma)))) + + +def MultivariateLaplace(name, mu, sigma): + """ + Creates a continuous random variable with Multivariate Laplace + Distribution. + + The density of the multivariate Laplace distribution can be found at [1]. + + Parameters + ========== + + mu : List representing the mean or the mean vector + sigma : Positive definite square matrix + Represents covariance Matrix + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import MultivariateLaplace, density + >>> from sympy import symbols + >>> y, z = symbols('y z') + >>> X = MultivariateLaplace('X', [2, 4], [[3, 1], [1, 3]]) + >>> density(X)(y, z) + sqrt(2)*exp(y/4 + 5*z/4)*besselk(0, sqrt(15*y*(3*y/8 - z/8)/2 + 15*z*(-y/8 + 3*z/8)/2))/(4*pi) + >>> density(X)(1, 2) + sqrt(2)*exp(11/4)*besselk(0, sqrt(165)/4)/(4*pi) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multivariate_Laplace_distribution + + """ + return multivariate_rv(MultivariateLaplaceDistribution, name, mu, sigma) + +#------------------------------------------------------------------------------- +# Multivariate StudentT distribution ------------------------------------------- + +class MultivariateTDistribution(JointDistribution): + _argnames = ('mu', 'shape_mat', 'dof') + is_Continuous=True + + @property + def set(self): + k = self.mu.shape[0] + return S.Reals**k + + @staticmethod + def check(mu, sigma, v): + _value_check(mu.shape[0] == sigma.shape[0], + "Size of the location vector and shape matrix are incorrect.") + # check if covariance matrix is positive definite or not. + if not isinstance(sigma, MatrixSymbol): + _value_check(sigma.is_positive_definite, + "The shape matrix must be positive definite. ") + + def pdf(self, *args): + mu, sigma = self.mu, self.shape_mat + v = S(self.dof) + k = S(mu.shape[0]) + sigma_inv = sigma.inv() + args = ImmutableMatrix(args) + x = args - mu + return gamma((k + v)/2)/(gamma(v/2)*(v*pi)**(k/2)*sqrt(det(sigma)))\ + *(1 + 1/v*(x.transpose()*sigma_inv*x)[0])**((-v - k)/2) + +def MultivariateT(syms, mu, sigma, v): + """ + Creates a joint random variable with multivariate T-distribution. + + Parameters + ========== + + syms : A symbol/str + For identifying the random variable. + mu : A list/matrix + Representing the location vector + sigma : The shape matrix for the distribution + + Examples + ======== + + >>> from sympy.stats import density, MultivariateT + >>> from sympy import Symbol + + >>> x = Symbol("x") + >>> X = MultivariateT("x", [1, 1], [[1, 0], [0, 1]], 2) + + >>> density(X)(1, 2) + 2/(9*pi) + + Returns + ======= + + RandomSymbol + + """ + return multivariate_rv(MultivariateTDistribution, syms, mu, sigma, v) + + +#------------------------------------------------------------------------------- +# Multivariate Normal Gamma distribution --------------------------------------- + +class NormalGammaDistribution(JointDistribution): + + _argnames = ('mu', 'lamda', 'alpha', 'beta') + is_Continuous=True + + @staticmethod + def check(mu, lamda, alpha, beta): + _value_check(mu.is_real, "Location must be real.") + _value_check(lamda > 0, "Lambda must be positive") + _value_check(alpha > 0, "alpha must be positive") + _value_check(beta > 0, "beta must be positive") + + @property + def set(self): + return S.Reals*Interval(0, S.Infinity) + + def pdf(self, x, tau): + beta, alpha, lamda = self.beta, self.alpha, self.lamda + mu = self.mu + + return beta**alpha*sqrt(lamda)/(gamma(alpha)*sqrt(2*pi))*\ + tau**(alpha - S.Half)*exp(-1*beta*tau)*\ + exp(-1*(lamda*tau*(x - mu)**2)/S(2)) + + def _marginal_distribution(self, indices, *sym): + if len(indices) == 2: + return self.pdf(*sym) + if indices[0] == 0: + #For marginal over `x`, return non-standardized Student-T's + #distribution + x = sym[0] + v, mu, sigma = self.alpha - S.Half, self.mu, \ + S(self.beta)/(self.lamda * self.alpha) + return Lambda(sym, gamma((v + 1)/2)/(gamma(v/2)*sqrt(pi*v)*sigma)*\ + (1 + 1/v*((x - mu)/sigma)**2)**((-v -1)/2)) + #For marginal over `tau`, return Gamma distribution as per construction + from sympy.stats.crv_types import GammaDistribution + return Lambda(sym, GammaDistribution(self.alpha, self.beta)(sym[0])) + +def NormalGamma(sym, mu, lamda, alpha, beta): + """ + Creates a bivariate joint random variable with multivariate Normal gamma + distribution. + + Parameters + ========== + + sym : A symbol/str + For identifying the random variable. + mu : A real number + The mean of the normal distribution + lamda : A positive integer + Parameter of joint distribution + alpha : A positive integer + Parameter of joint distribution + beta : A positive integer + Parameter of joint distribution + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, NormalGamma + >>> from sympy import symbols + + >>> X = NormalGamma('x', 0, 1, 2, 3) + >>> y, z = symbols('y z') + + >>> density(X)(y, z) + 9*sqrt(2)*z**(3/2)*exp(-3*z)*exp(-y**2*z/2)/(2*sqrt(pi)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Normal-gamma_distribution + + """ + return multivariate_rv(NormalGammaDistribution, sym, mu, lamda, alpha, beta) + +#------------------------------------------------------------------------------- +# Multivariate Beta/Dirichlet distribution ------------------------------------- + +class MultivariateBetaDistribution(JointDistribution): + + _argnames = ('alpha',) + is_Continuous = True + + @staticmethod + def check(alpha): + _value_check(len(alpha) >= 2, "At least two categories should be passed.") + for a_k in alpha: + _value_check((a_k > 0) != False, "Each concentration parameter" + " should be positive.") + + @property + def set(self): + k = len(self.alpha) + return Interval(0, 1)**k + + def pdf(self, *syms): + alpha = self.alpha + B = Mul.fromiter(map(gamma, alpha))/gamma(Add(*alpha)) + return Mul.fromiter(sym**(a_k - 1) for a_k, sym in zip(alpha, syms))/B + +def MultivariateBeta(syms, *alpha): + """ + Creates a continuous random variable with Dirichlet/Multivariate Beta + Distribution. + + The density of the Dirichlet distribution can be found at [1]. + + Parameters + ========== + + alpha : Positive real numbers + Signifies concentration numbers. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, MultivariateBeta, marginal_distribution + >>> from sympy import Symbol + >>> a1 = Symbol('a1', positive=True) + >>> a2 = Symbol('a2', positive=True) + >>> B = MultivariateBeta('B', [a1, a2]) + >>> C = MultivariateBeta('C', a1, a2) + >>> x = Symbol('x') + >>> y = Symbol('y') + >>> density(B)(x, y) + x**(a1 - 1)*y**(a2 - 1)*gamma(a1 + a2)/(gamma(a1)*gamma(a2)) + >>> marginal_distribution(C, C[0])(x) + x**(a1 - 1)*gamma(a1 + a2)/(a2*gamma(a1)*gamma(a2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dirichlet_distribution + .. [2] https://mathworld.wolfram.com/DirichletDistribution.html + + """ + if not isinstance(alpha[0], list): + alpha = (list(alpha),) + return multivariate_rv(MultivariateBetaDistribution, syms, alpha[0]) + +Dirichlet = MultivariateBeta + +#------------------------------------------------------------------------------- +# Multivariate Ewens distribution ---------------------------------------------- + +class MultivariateEwensDistribution(JointDistribution): + + _argnames = ('n', 'theta') + is_Discrete = True + is_Continuous = False + + @staticmethod + def check(n, theta): + _value_check((n > 0), + "sample size should be positive integer.") + _value_check(theta.is_positive, "mutation rate should be positive.") + + @property + def set(self): + if not isinstance(self.n, Integer): + i = Symbol('i', integer=True, positive=True) + return Product(Intersection(S.Naturals0, Interval(0, self.n//i)), + (i, 1, self.n)) + prod_set = Range(0, self.n + 1) + for i in range(2, self.n + 1): + prod_set *= Range(0, self.n//i + 1) + return prod_set.flatten() + + def pdf(self, *syms): + n, theta = self.n, self.theta + condi = isinstance(self.n, Integer) + if not (isinstance(syms[0], IndexedBase) or condi): + raise ValueError("Please use IndexedBase object for syms as " + "the dimension is symbolic") + term_1 = factorial(n)/rf(theta, n) + if condi: + term_2 = Mul.fromiter(theta**syms[j]/((j+1)**syms[j]*factorial(syms[j])) + for j in range(n)) + cond = Eq(sum((k + 1)*syms[k] for k in range(n)), n) + return Piecewise((term_1 * term_2, cond), (0, True)) + syms = syms[0] + j, k = symbols('j, k', positive=True, integer=True) + term_2 = Product(theta**syms[j]/((j+1)**syms[j]*factorial(syms[j])), + (j, 0, n - 1)) + cond = Eq(Sum((k + 1)*syms[k], (k, 0, n - 1)), n) + return Piecewise((term_1 * term_2, cond), (0, True)) + + +def MultivariateEwens(syms, n, theta): + """ + Creates a discrete random variable with Multivariate Ewens + Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + n : Positive integer + Size of the sample or the integer whose partitions are considered + theta : Positive real number + Denotes Mutation rate + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, marginal_distribution, MultivariateEwens + >>> from sympy import Symbol + >>> a1 = Symbol('a1', positive=True) + >>> a2 = Symbol('a2', positive=True) + >>> ed = MultivariateEwens('E', 2, 1) + >>> density(ed)(a1, a2) + Piecewise((1/(2**a2*factorial(a1)*factorial(a2)), Eq(a1 + 2*a2, 2)), (0, True)) + >>> marginal_distribution(ed, ed[0])(a1) + Piecewise((1/factorial(a1), Eq(a1, 2)), (0, True)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ewens%27s_sampling_formula + .. [2] https://www.jstor.org/stable/24780825 + """ + return multivariate_rv(MultivariateEwensDistribution, syms, n, theta) + +#------------------------------------------------------------------------------- +# Generalized Multivariate Log Gamma distribution ------------------------------ + +class GeneralizedMultivariateLogGammaDistribution(JointDistribution): + + _argnames = ('delta', 'v', 'lamda', 'mu') + is_Continuous=True + + def check(self, delta, v, l, mu): + _value_check((delta >= 0, delta <= 1), "delta must be in range [0, 1].") + _value_check((v > 0), "v must be positive") + for lk in l: + _value_check((lk > 0), "lamda must be a positive vector.") + for muk in mu: + _value_check((muk > 0), "mu must be a positive vector.") + _value_check(len(l) > 1,"the distribution should have at least" + " two random variables.") + + @property + def set(self): + return S.Reals**len(self.lamda) + + def pdf(self, *y): + d, v, l, mu = self.delta, self.v, self.lamda, self.mu + n = Symbol('n', negative=False, integer=True) + k = len(l) + sterm1 = Pow((1 - d), n)/\ + ((gamma(v + n)**(k - 1))*gamma(v)*gamma(n + 1)) + sterm2 = Mul.fromiter(mui*li**(-v - n) for mui, li in zip(mu, l)) + term1 = sterm1 * sterm2 + sterm3 = (v + n) * sum(mui * yi for mui, yi in zip(mu, y)) + sterm4 = sum(exp(mui * yi)/li for (mui, yi, li) in zip(mu, y, l)) + term2 = exp(sterm3 - sterm4) + return Pow(d, v) * Sum(term1 * term2, (n, 0, S.Infinity)) + +def GeneralizedMultivariateLogGamma(syms, delta, v, lamda, mu): + """ + Creates a joint random variable with generalized multivariate log gamma + distribution. + + The joint pdf can be found at [1]. + + Parameters + ========== + + syms : list/tuple/set of symbols for identifying each component + delta : A constant in range $[0, 1]$ + v : Positive real number + lamda : List of positive real numbers + mu : List of positive real numbers + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGamma + >>> from sympy import symbols, S + >>> v = 1 + >>> l, mu = [1, 1, 1], [1, 1, 1] + >>> d = S.Half + >>> y = symbols('y_1:4', positive=True) + >>> Gd = GeneralizedMultivariateLogGamma('G', d, v, l, mu) + >>> density(Gd)(y[0], y[1], y[2]) + Sum(exp((n + 1)*(y_1 + y_2 + y_3) - exp(y_1) - exp(y_2) - + exp(y_3))/(2**n*gamma(n + 1)**3), (n, 0, oo))/2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Generalized_multivariate_log-gamma_distribution + .. [2] https://www.researchgate.net/publication/234137346_On_a_multivariate_log-gamma_distribution_and_the_use_of_the_distribution_in_the_Bayesian_analysis + + Note + ==== + + If the GeneralizedMultivariateLogGamma is too long to type use, + + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGamma as GMVLG + >>> Gd = GMVLG('G', d, v, l, mu) + + If you want to pass the matrix omega instead of the constant delta, then use + ``GeneralizedMultivariateLogGammaOmega``. + + """ + return multivariate_rv(GeneralizedMultivariateLogGammaDistribution, + syms, delta, v, lamda, mu) + +def GeneralizedMultivariateLogGammaOmega(syms, omega, v, lamda, mu): + """ + Extends GeneralizedMultivariateLogGamma. + + Parameters + ========== + + syms : list/tuple/set of symbols + For identifying each component + omega : A square matrix + Every element of square matrix must be absolute value of + square root of correlation coefficient + v : Positive real number + lamda : List of positive real numbers + mu : List of positive real numbers + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGammaOmega + >>> from sympy import Matrix, symbols, S + >>> omega = Matrix([[1, S.Half, S.Half], [S.Half, 1, S.Half], [S.Half, S.Half, 1]]) + >>> v = 1 + >>> l, mu = [1, 1, 1], [1, 1, 1] + >>> G = GeneralizedMultivariateLogGammaOmega('G', omega, v, l, mu) + >>> y = symbols('y_1:4', positive=True) + >>> density(G)(y[0], y[1], y[2]) + sqrt(2)*Sum((1 - sqrt(2)/2)**n*exp((n + 1)*(y_1 + y_2 + y_3) - exp(y_1) - + exp(y_2) - exp(y_3))/gamma(n + 1)**3, (n, 0, oo))/2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Generalized_multivariate_log-gamma_distribution + .. [2] https://www.researchgate.net/publication/234137346_On_a_multivariate_log-gamma_distribution_and_the_use_of_the_distribution_in_the_Bayesian_analysis + + Notes + ===== + + If the GeneralizedMultivariateLogGammaOmega is too long to type use, + + >>> from sympy.stats.joint_rv_types import GeneralizedMultivariateLogGammaOmega as GMVLGO + >>> G = GMVLGO('G', omega, v, l, mu) + + """ + _value_check((omega.is_square, isinstance(omega, Matrix)), "omega must be a" + " square matrix") + for val in omega.values(): + _value_check((val >= 0, val <= 1), + "all values in matrix must be between 0 and 1(both inclusive).") + _value_check(omega.diagonal().equals(ones(1, omega.shape[0])), + "all the elements of diagonal should be 1.") + _value_check((omega.shape[0] == len(lamda), len(lamda) == len(mu)), + "lamda, mu should be of same length and omega should " + " be of shape (length of lamda, length of mu)") + _value_check(len(lamda) > 1,"the distribution should have at least" + " two random variables.") + delta = Pow(Rational(omega.det()), Rational(1, len(lamda) - 1)) + return GeneralizedMultivariateLogGamma(syms, delta, v, lamda, mu) + + +#------------------------------------------------------------------------------- +# Multinomial distribution ----------------------------------------------------- + +class MultinomialDistribution(JointDistribution): + + _argnames = ('n', 'p') + is_Continuous=False + is_Discrete = True + + @staticmethod + def check(n, p): + _value_check(n > 0, + "number of trials must be a positive integer") + for p_k in p: + _value_check((p_k >= 0, p_k <= 1), + "probability must be in range [0, 1]") + _value_check(Eq(sum(p), 1), + "probabilities must sum to 1") + + @property + def set(self): + return Intersection(S.Naturals0, Interval(0, self.n))**len(self.p) + + def pdf(self, *x): + n, p = self.n, self.p + term_1 = factorial(n)/Mul.fromiter(factorial(x_k) for x_k in x) + term_2 = Mul.fromiter(p_k**x_k for p_k, x_k in zip(p, x)) + return Piecewise((term_1 * term_2, Eq(sum(x), n)), (0, True)) + +def Multinomial(syms, n, *p): + """ + Creates a discrete random variable with Multinomial Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + n : Positive integer + Represents number of trials + p : List of event probabilities + Must be in the range of $[0, 1]$. + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, Multinomial, marginal_distribution + >>> from sympy import symbols + >>> x1, x2, x3 = symbols('x1, x2, x3', nonnegative=True, integer=True) + >>> p1, p2, p3 = symbols('p1, p2, p3', positive=True) + >>> M = Multinomial('M', 3, p1, p2, p3) + >>> density(M)(x1, x2, x3) + Piecewise((6*p1**x1*p2**x2*p3**x3/(factorial(x1)*factorial(x2)*factorial(x3)), + Eq(x1 + x2 + x3, 3)), (0, True)) + >>> marginal_distribution(M, M[0])(x1).subs(x1, 1) + 3*p1*p2**2 + 6*p1*p2*p3 + 3*p1*p3**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multinomial_distribution + .. [2] https://mathworld.wolfram.com/MultinomialDistribution.html + + """ + if not isinstance(p[0], list): + p = (list(p), ) + return multivariate_rv(MultinomialDistribution, syms, n, p[0]) + +#------------------------------------------------------------------------------- +# Negative Multinomial Distribution -------------------------------------------- + +class NegativeMultinomialDistribution(JointDistribution): + + _argnames = ('k0', 'p') + is_Continuous=False + is_Discrete = True + + @staticmethod + def check(k0, p): + _value_check(k0 > 0, + "number of failures must be a positive integer") + for p_k in p: + _value_check((p_k >= 0, p_k <= 1), + "probability must be in range [0, 1].") + _value_check(sum(p) <= 1, + "success probabilities must not be greater than 1.") + + @property + def set(self): + return Range(0, S.Infinity)**len(self.p) + + def pdf(self, *k): + k0, p = self.k0, self.p + term_1 = (gamma(k0 + sum(k))*(1 - sum(p))**k0)/gamma(k0) + term_2 = Mul.fromiter(pi**ki/factorial(ki) for pi, ki in zip(p, k)) + return term_1 * term_2 + +def NegativeMultinomial(syms, k0, *p): + """ + Creates a discrete random variable with Negative Multinomial Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + k0 : positive integer + Represents number of failures before the experiment is stopped + p : List of event probabilities + Must be in the range of $[0, 1]$ + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, NegativeMultinomial, marginal_distribution + >>> from sympy import symbols + >>> x1, x2, x3 = symbols('x1, x2, x3', nonnegative=True, integer=True) + >>> p1, p2, p3 = symbols('p1, p2, p3', positive=True) + >>> N = NegativeMultinomial('M', 3, p1, p2, p3) + >>> N_c = NegativeMultinomial('M', 3, 0.1, 0.1, 0.1) + >>> density(N)(x1, x2, x3) + p1**x1*p2**x2*p3**x3*(-p1 - p2 - p3 + 1)**3*gamma(x1 + x2 + + x3 + 3)/(2*factorial(x1)*factorial(x2)*factorial(x3)) + >>> marginal_distribution(N_c, N_c[0])(1).evalf().round(2) + 0.25 + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Negative_multinomial_distribution + .. [2] https://mathworld.wolfram.com/NegativeBinomialDistribution.html + + """ + if not isinstance(p[0], list): + p = (list(p), ) + return multivariate_rv(NegativeMultinomialDistribution, syms, k0, p[0]) diff --git a/phivenv/Lib/site-packages/sympy/stats/matrix_distributions.py b/phivenv/Lib/site-packages/sympy/stats/matrix_distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..9a43c0226bc25702211a910ebbe30e280ad0cf50 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/matrix_distributions.py @@ -0,0 +1,610 @@ +from math import prod + +from sympy.core.basic import Basic +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.special.gamma_functions import multigamma +from sympy.core.sympify import sympify, _sympify +from sympy.matrices import (ImmutableMatrix, Inverse, Trace, Determinant, + MatrixSymbol, MatrixBase, Transpose, MatrixSet, + matrix2numpy) +from sympy.stats.rv import (_value_check, RandomMatrixSymbol, NamedArgsMixin, PSpace, + _symbol_converter, MatrixDomain, Distribution) +from sympy.external import import_module + + +################################################################################ +#------------------------Matrix Probability Space------------------------------# +################################################################################ +class MatrixPSpace(PSpace): + """ + Represents probability space for + Matrix Distributions. + """ + def __new__(cls, sym, distribution, dim_n, dim_m): + sym = _symbol_converter(sym) + dim_n, dim_m = _sympify(dim_n), _sympify(dim_m) + if not (dim_n.is_integer and dim_m.is_integer): + raise ValueError("Dimensions should be integers") + return Basic.__new__(cls, sym, distribution, dim_n, dim_m) + + distribution = property(lambda self: self.args[1]) + symbol = property(lambda self: self.args[0]) + + @property + def domain(self): + return MatrixDomain(self.symbol, self.distribution.set) + + @property + def value(self): + return RandomMatrixSymbol(self.symbol, self.args[2], self.args[3], self) + + @property + def values(self): + return {self.value} + + def compute_density(self, expr, *args): + rms = expr.atoms(RandomMatrixSymbol) + if len(rms) > 1 or (not isinstance(expr, RandomMatrixSymbol)): + raise NotImplementedError("Currently, no algorithm has been " + "implemented to handle general expressions containing " + "multiple matrix distributions.") + return self.distribution.pdf(expr) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomMatrixSymbol to realization value. + """ + return {self.value: self.distribution.sample(size, library=library, seed=seed)} + + +def rv(symbol, cls, args): + args = list(map(sympify, args)) + dist = cls(*args) + dist.check(*args) + dim = dist.dimension + pspace = MatrixPSpace(symbol, dist, dim[0], dim[1]) + return pspace.value + + +class SampleMatrixScipy: + """Returns the sample from scipy of the given distribution""" + def __new__(cls, dist, size, seed=None): + return cls._sample_scipy(dist, size, seed) + + @classmethod + def _sample_scipy(cls, dist, size, seed): + """Sample from SciPy.""" + + from scipy import stats as scipy_stats + import numpy + scipy_rv_map = { + 'WishartDistribution': lambda dist, size, rand_state: scipy_stats.wishart.rvs( + df=int(dist.n), scale=matrix2numpy(dist.scale_matrix, float), size=size), + 'MatrixNormalDistribution': lambda dist, size, rand_state: scipy_stats.matrix_normal.rvs( + mean=matrix2numpy(dist.location_matrix, float), + rowcov=matrix2numpy(dist.scale_matrix_1, float), + colcov=matrix2numpy(dist.scale_matrix_2, float), size=size, random_state=rand_state) + } + + sample_shape = { + 'WishartDistribution': lambda dist: dist.scale_matrix.shape, + 'MatrixNormalDistribution' : lambda dist: dist.location_matrix.shape + } + + dist_list = scipy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + samp = scipy_rv_map[dist.__class__.__name__](dist, prod(size), rand_state) + return samp.reshape(size + sample_shape[dist.__class__.__name__](dist)) + + +class SampleMatrixNumpy: + """Returns the sample from numpy of the given distribution""" + + ### TODO: Add tests after adding matrix distributions in numpy_rv_map + def __new__(cls, dist, size, seed=None): + return cls._sample_numpy(dist, size, seed) + + @classmethod + def _sample_numpy(cls, dist, size, seed): + """Sample from NumPy.""" + + numpy_rv_map = { + } + + sample_shape = { + } + + dist_list = numpy_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + samp = numpy_rv_map[dist.__class__.__name__](dist, prod(size), rand_state) + return samp.reshape(size + sample_shape[dist.__class__.__name__](dist)) + + +class SampleMatrixPymc: + """Returns the sample from pymc of the given distribution""" + + def __new__(cls, dist, size, seed=None): + return cls._sample_pymc(dist, size, seed) + + @classmethod + def _sample_pymc(cls, dist, size, seed): + """Sample from PyMC.""" + + try: + import pymc + except ImportError: + import pymc3 as pymc + pymc_rv_map = { + 'MatrixNormalDistribution': lambda dist: pymc.MatrixNormal('X', + mu=matrix2numpy(dist.location_matrix, float), + rowcov=matrix2numpy(dist.scale_matrix_1, float), + colcov=matrix2numpy(dist.scale_matrix_2, float), + shape=dist.location_matrix.shape), + 'WishartDistribution': lambda dist: pymc.WishartBartlett('X', + nu=int(dist.n), S=matrix2numpy(dist.scale_matrix, float)) + } + + sample_shape = { + 'WishartDistribution': lambda dist: dist.scale_matrix.shape, + 'MatrixNormalDistribution' : lambda dist: dist.location_matrix.shape + } + + dist_list = pymc_rv_map.keys() + + if dist.__class__.__name__ not in dist_list: + return None + import logging + logging.getLogger("pymc").setLevel(logging.ERROR) + with pymc.Model(): + pymc_rv_map[dist.__class__.__name__](dist) + samps = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)['X'] + return samps.reshape(size + sample_shape[dist.__class__.__name__](dist)) + +_get_sample_class_matrixrv = { + 'scipy': SampleMatrixScipy, + 'pymc3': SampleMatrixPymc, + 'pymc': SampleMatrixPymc, + 'numpy': SampleMatrixNumpy +} + +################################################################################ +#-------------------------Matrix Distribution----------------------------------# +################################################################################ + +class MatrixDistribution(Distribution, NamedArgsMixin): + """ + Abstract class for Matrix Distribution. + """ + def __new__(cls, *args): + args = [ImmutableMatrix(arg) if isinstance(arg, list) + else _sympify(arg) for arg in args] + return Basic.__new__(cls, *args) + + @staticmethod + def check(*args): + pass + + def __call__(self, expr): + if isinstance(expr, list): + expr = ImmutableMatrix(expr) + return self.pdf(expr) + + def sample(self, size=(), library='scipy', seed=None): + """ + Internal sample method + + Returns dictionary mapping RandomSymbol to realization value. + """ + + libraries = ['scipy', 'numpy', 'pymc3', 'pymc'] + if library not in libraries: + raise NotImplementedError("Sampling from %s is not supported yet." + % str(library)) + if not import_module(library): + raise ValueError("Failed to import %s" % library) + + samps = _get_sample_class_matrixrv[library](self, size, seed) + + if samps is not None: + return samps + raise NotImplementedError( + "Sampling for %s is not currently implemented from %s" + % (self.__class__.__name__, library) + ) + +################################################################################ +#------------------------Matrix Distribution Types-----------------------------# +################################################################################ + +#------------------------------------------------------------------------------- +# Matrix Gamma distribution ---------------------------------------------------- + +class MatrixGammaDistribution(MatrixDistribution): + + _argnames = ('alpha', 'beta', 'scale_matrix') + + @staticmethod + def check(alpha, beta, scale_matrix): + if not isinstance(scale_matrix, MatrixSymbol): + _value_check(scale_matrix.is_positive_definite, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix.is_square, "Should " + "be square matrix") + _value_check(alpha.is_positive, "Shape parameter should be positive.") + _value_check(beta.is_positive, "Scale parameter should be positive.") + + @property + def set(self): + k = self.scale_matrix.shape[0] + return MatrixSet(k, k, S.Reals) + + @property + def dimension(self): + return self.scale_matrix.shape + + def pdf(self, x): + alpha, beta, scale_matrix = self.alpha, self.beta, self.scale_matrix + p = scale_matrix.shape[0] + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + sigma_inv_x = - Inverse(scale_matrix)*x / beta + term1 = exp(Trace(sigma_inv_x))/((beta**(p*alpha)) * multigamma(alpha, p)) + term2 = (Determinant(scale_matrix))**(-alpha) + term3 = (Determinant(x))**(alpha - S(p + 1)/2) + return term1 * term2 * term3 + +def MatrixGamma(symbol, alpha, beta, scale_matrix): + """ + Creates a random variable with Matrix Gamma Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + alpha: Positive Real number + Shape Parameter + beta: Positive Real number + Scale Parameter + scale_matrix: Positive definite real square matrix + Scale Matrix + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, MatrixGamma + >>> from sympy import MatrixSymbol, symbols + >>> a, b = symbols('a b', positive=True) + >>> M = MatrixGamma('M', a, b, [[2, 1], [1, 2]]) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(M)(X).doit() + exp(Trace(Matrix([ + [-2/3, 1/3], + [ 1/3, -2/3]])*X)/b)*Determinant(X)**(a - 3/2)/(3**a*sqrt(pi)*b**(2*a)*gamma(a)*gamma(a - 1/2)) + >>> density(M)([[1, 0], [0, 1]]).doit() + exp(-4/(3*b))/(3**a*sqrt(pi)*b**(2*a)*gamma(a)*gamma(a - 1/2)) + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_gamma_distribution + + """ + if isinstance(scale_matrix, list): + scale_matrix = ImmutableMatrix(scale_matrix) + return rv(symbol, MatrixGammaDistribution, (alpha, beta, scale_matrix)) + +#------------------------------------------------------------------------------- +# Wishart Distribution --------------------------------------------------------- + +class WishartDistribution(MatrixDistribution): + + _argnames = ('n', 'scale_matrix') + + @staticmethod + def check(n, scale_matrix): + if not isinstance(scale_matrix, MatrixSymbol): + _value_check(scale_matrix.is_positive_definite, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix.is_square, "Should " + "be square matrix") + _value_check(n.is_positive, "Shape parameter should be positive.") + + @property + def set(self): + k = self.scale_matrix.shape[0] + return MatrixSet(k, k, S.Reals) + + @property + def dimension(self): + return self.scale_matrix.shape + + def pdf(self, x): + n, scale_matrix = self.n, self.scale_matrix + p = scale_matrix.shape[0] + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + sigma_inv_x = - Inverse(scale_matrix)*x / S(2) + term1 = exp(Trace(sigma_inv_x))/((2**(p*n/S(2))) * multigamma(n/S(2), p)) + term2 = (Determinant(scale_matrix))**(-n/S(2)) + term3 = (Determinant(x))**(S(n - p - 1)/2) + return term1 * term2 * term3 + +def Wishart(symbol, n, scale_matrix): + """ + Creates a random variable with Wishart Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + n: Positive Real number + Represents degrees of freedom + scale_matrix: Positive definite real square matrix + Scale Matrix + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy.stats import density, Wishart + >>> from sympy import MatrixSymbol, symbols + >>> n = symbols('n', positive=True) + >>> W = Wishart('W', n, [[2, 1], [1, 2]]) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(W)(X).doit() + exp(Trace(Matrix([ + [-1/3, 1/6], + [ 1/6, -1/3]])*X))*Determinant(X)**(n/2 - 3/2)/(2**n*3**(n/2)*sqrt(pi)*gamma(n/2)*gamma(n/2 - 1/2)) + >>> density(W)([[1, 0], [0, 1]]).doit() + exp(-2/3)/(2**n*3**(n/2)*sqrt(pi)*gamma(n/2)*gamma(n/2 - 1/2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Wishart_distribution + + """ + if isinstance(scale_matrix, list): + scale_matrix = ImmutableMatrix(scale_matrix) + return rv(symbol, WishartDistribution, (n, scale_matrix)) + +#------------------------------------------------------------------------------- +# Matrix Normal distribution --------------------------------------------------- + +class MatrixNormalDistribution(MatrixDistribution): + + _argnames = ('location_matrix', 'scale_matrix_1', 'scale_matrix_2') + + @staticmethod + def check(location_matrix, scale_matrix_1, scale_matrix_2): + if not isinstance(scale_matrix_1, MatrixSymbol): + _value_check(scale_matrix_1.is_positive_definite, "The shape " + "matrix must be positive definite.") + if not isinstance(scale_matrix_2, MatrixSymbol): + _value_check(scale_matrix_2.is_positive_definite, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix_1.is_square, "Scale matrix 1 should be " + "be square matrix") + _value_check(scale_matrix_2.is_square, "Scale matrix 2 should be " + "be square matrix") + n = location_matrix.shape[0] + p = location_matrix.shape[1] + _value_check(scale_matrix_1.shape[0] == n, "Scale matrix 1 should be" + " of shape %s x %s"% (str(n), str(n))) + _value_check(scale_matrix_2.shape[0] == p, "Scale matrix 2 should be" + " of shape %s x %s"% (str(p), str(p))) + + @property + def set(self): + n, p = self.location_matrix.shape + return MatrixSet(n, p, S.Reals) + + @property + def dimension(self): + return self.location_matrix.shape + + def pdf(self, x): + M, U, V = self.location_matrix, self.scale_matrix_1, self.scale_matrix_2 + n, p = M.shape + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + term1 = Inverse(V)*Transpose(x - M)*Inverse(U)*(x - M) + num = exp(-Trace(term1)/S(2)) + den = (2*pi)**(S(n*p)/2) * Determinant(U)**(S(p)/2) * Determinant(V)**(S(n)/2) + return num/den + +def MatrixNormal(symbol, location_matrix, scale_matrix_1, scale_matrix_2): + """ + Creates a random variable with Matrix Normal Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + location_matrix: Real ``n x p`` matrix + Represents degrees of freedom + scale_matrix_1: Positive definite matrix + Scale Matrix of shape ``n x n`` + scale_matrix_2: Positive definite matrix + Scale Matrix of shape ``p x p`` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.stats import density, MatrixNormal + >>> M = MatrixNormal('M', [[1, 2]], [1], [[1, 0], [0, 1]]) + >>> X = MatrixSymbol('X', 1, 2) + >>> density(M)(X).doit() + exp(-Trace((Matrix([ + [-1], + [-2]]) + X.T)*(Matrix([[-1, -2]]) + X))/2)/(2*pi) + >>> density(M)([[3, 4]]).doit() + exp(-4)/(2*pi) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_normal_distribution + + """ + if isinstance(location_matrix, list): + location_matrix = ImmutableMatrix(location_matrix) + if isinstance(scale_matrix_1, list): + scale_matrix_1 = ImmutableMatrix(scale_matrix_1) + if isinstance(scale_matrix_2, list): + scale_matrix_2 = ImmutableMatrix(scale_matrix_2) + args = (location_matrix, scale_matrix_1, scale_matrix_2) + return rv(symbol, MatrixNormalDistribution, args) + +#------------------------------------------------------------------------------- +# Matrix Student's T distribution --------------------------------------------------- + +class MatrixStudentTDistribution(MatrixDistribution): + + _argnames = ('nu', 'location_matrix', 'scale_matrix_1', 'scale_matrix_2') + + @staticmethod + def check(nu, location_matrix, scale_matrix_1, scale_matrix_2): + if not isinstance(scale_matrix_1, MatrixSymbol): + _value_check(scale_matrix_1.is_positive_definite != False, "The shape " + "matrix must be positive definite.") + if not isinstance(scale_matrix_2, MatrixSymbol): + _value_check(scale_matrix_2.is_positive_definite != False, "The shape " + "matrix must be positive definite.") + _value_check(scale_matrix_1.is_square != False, "Scale matrix 1 should be " + "be square matrix") + _value_check(scale_matrix_2.is_square != False, "Scale matrix 2 should be " + "be square matrix") + n = location_matrix.shape[0] + p = location_matrix.shape[1] + _value_check(scale_matrix_1.shape[0] == p, "Scale matrix 1 should be" + " of shape %s x %s" % (str(p), str(p))) + _value_check(scale_matrix_2.shape[0] == n, "Scale matrix 2 should be" + " of shape %s x %s" % (str(n), str(n))) + _value_check(nu.is_positive != False, "Degrees of freedom must be positive") + + @property + def set(self): + n, p = self.location_matrix.shape + return MatrixSet(n, p, S.Reals) + + @property + def dimension(self): + return self.location_matrix.shape + + def pdf(self, x): + from sympy.matrices.dense import eye + if isinstance(x, list): + x = ImmutableMatrix(x) + if not isinstance(x, (MatrixBase, MatrixSymbol)): + raise ValueError("%s should be an isinstance of Matrix " + "or MatrixSymbol" % str(x)) + nu, M, Omega, Sigma = self.nu, self.location_matrix, self.scale_matrix_1, self.scale_matrix_2 + n, p = M.shape + + K = multigamma((nu + n + p - 1)/2, p) * Determinant(Omega)**(-n/2) * Determinant(Sigma)**(-p/2) \ + / ((pi)**(n*p/2) * multigamma((nu + p - 1)/2, p)) + return K * (Determinant(eye(n) + Inverse(Sigma)*(x - M)*Inverse(Omega)*Transpose(x - M))) \ + **(-(nu + n + p -1)/2) + + + +def MatrixStudentT(symbol, nu, location_matrix, scale_matrix_1, scale_matrix_2): + """ + Creates a random variable with Matrix Gamma Distribution. + + The density of the said distribution can be found at [1]. + + Parameters + ========== + + nu: Positive Real number + degrees of freedom + location_matrix: Positive definite real square matrix + Location Matrix of shape ``n x p`` + scale_matrix_1: Positive definite real square matrix + Scale Matrix of shape ``p x p`` + scale_matrix_2: Positive definite real square matrix + Scale Matrix of shape ``n x n`` + + Returns + ======= + + RandomSymbol + + Examples + ======== + + >>> from sympy import MatrixSymbol,symbols + >>> from sympy.stats import density, MatrixStudentT + >>> v = symbols('v',positive=True) + >>> M = MatrixStudentT('M', v, [[1, 2]], [[1, 0], [0, 1]], [1]) + >>> X = MatrixSymbol('X', 1, 2) + >>> density(M)(X) + gamma(v/2 + 1)*Determinant((Matrix([[-1, -2]]) + X)*(Matrix([ + [-1], + [-2]]) + X.T) + Matrix([[1]]))**(-v/2 - 1)/(pi**1.0*gamma(v/2)*Determinant(Matrix([[1]]))**1.0*Determinant(Matrix([ + [1, 0], + [0, 1]]))**0.5) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_t-distribution + + """ + if isinstance(location_matrix, list): + location_matrix = ImmutableMatrix(location_matrix) + if isinstance(scale_matrix_1, list): + scale_matrix_1 = ImmutableMatrix(scale_matrix_1) + if isinstance(scale_matrix_2, list): + scale_matrix_2 = ImmutableMatrix(scale_matrix_2) + args = (nu, location_matrix, scale_matrix_1, scale_matrix_2) + return rv(symbol, MatrixStudentTDistribution, args) diff --git a/phivenv/Lib/site-packages/sympy/stats/random_matrix.py b/phivenv/Lib/site-packages/sympy/stats/random_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd25cb9ad23fed9d3a85982b24bef33d04928f0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/random_matrix.py @@ -0,0 +1,30 @@ +from sympy.core.basic import Basic +from sympy.stats.rv import PSpace, _symbol_converter, RandomMatrixSymbol + +class RandomMatrixPSpace(PSpace): + """ + Represents probability space for + random matrices. It contains the mechanics + for handling the API calls for random matrices. + """ + def __new__(cls, sym, model=None): + sym = _symbol_converter(sym) + if model: + return Basic.__new__(cls, sym, model) + else: + return Basic.__new__(cls, sym) + + @property + def model(self): + try: + return self.args[1] + except IndexError: + return None + + def compute_density(self, expr, *args): + rms = expr.atoms(RandomMatrixSymbol) + if len(rms) > 2 or (not isinstance(expr, RandomMatrixSymbol)): + raise NotImplementedError("Currently, no algorithm has been " + "implemented to handle general expressions containing " + "multiple random matrices.") + return self.model.density(expr) diff --git a/phivenv/Lib/site-packages/sympy/stats/random_matrix_models.py b/phivenv/Lib/site-packages/sympy/stats/random_matrix_models.py new file mode 100644 index 0000000000000000000000000000000000000000..6327a248ea5919c0bbb0ffc2c984105e04fe20e9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/random_matrix_models.py @@ -0,0 +1,457 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.trace import Trace +from sympy.tensor.indexed import IndexedBase +from sympy.core.sympify import _sympify +from sympy.stats.rv import _symbol_converter, Density, RandomMatrixSymbol, is_random +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.random_matrix import RandomMatrixPSpace +from sympy.tensor.array import ArrayComprehension + +__all__ = [ + 'CircularEnsemble', + 'CircularUnitaryEnsemble', + 'CircularOrthogonalEnsemble', + 'CircularSymplecticEnsemble', + 'GaussianEnsemble', + 'GaussianUnitaryEnsemble', + 'GaussianOrthogonalEnsemble', + 'GaussianSymplecticEnsemble', + 'joint_eigen_distribution', + 'JointEigenDistribution', + 'level_spacing_distribution' +] + +@is_random.register(RandomMatrixSymbol) +def _(x): + return True + + +class RandomMatrixEnsembleModel(Basic): + """ + Base class for random matrix ensembles. + It acts as an umbrella and contains + the methods common to all the ensembles + defined in sympy.stats.random_matrix_models. + """ + def __new__(cls, sym, dim=None): + sym, dim = _symbol_converter(sym), _sympify(dim) + if dim.is_integer == False: + raise ValueError("Dimension of the random matrices must be " + "integers, received %s instead."%(dim)) + return Basic.__new__(cls, sym, dim) + + symbol = property(lambda self: self.args[0]) + dimension = property(lambda self: self.args[1]) + + def density(self, expr): + return Density(expr) + + def __call__(self, expr): + return self.density(expr) + +class GaussianEnsembleModel(RandomMatrixEnsembleModel): + """ + Abstract class for Gaussian ensembles. + Contains the properties common to all the + gaussian ensembles. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Random_matrix#Gaussian_ensembles + .. [2] https://arxiv.org/pdf/1712.07903.pdf + """ + def _compute_normalization_constant(self, beta, n): + """ + Helper function for computing normalization + constant for joint probability density of eigen + values of Gaussian ensembles. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Selberg_integral#Mehta's_integral + """ + n = S(n) + prod_term = lambda j: gamma(1 + beta*S(j)/2)/gamma(S.One + beta/S(2)) + j = Dummy('j', integer=True, positive=True) + term1 = Product(prod_term(j), (j, 1, n)).doit() + term2 = (2/(beta*n))**(beta*n*(n - 1)/4 + n/2) + term3 = (2*pi)**(n/2) + return term1 * term2 * term3 + + def _compute_joint_eigen_distribution(self, beta): + """ + Helper function for computing the joint + probability distribution of eigen values + of the random matrix. + """ + n = self.dimension + Zbn = self._compute_normalization_constant(beta, n) + l = IndexedBase('l') + i = Dummy('i', integer=True, positive=True) + j = Dummy('j', integer=True, positive=True) + k = Dummy('k', integer=True, positive=True) + term1 = exp((-S(n)/2) * Sum(l[k]**2, (k, 1, n)).doit()) + sub_term = Lambda(i, Product(Abs(l[j] - l[i])**beta, (j, i + 1, n))) + term2 = Product(sub_term(i).doit(), (i, 1, n - 1)).doit() + syms = ArrayComprehension(l[k], (k, 1, n)).doit() + return Lambda(tuple(syms), (term1 * term2)/Zbn) + +class GaussianUnitaryEnsembleModel(GaussianEnsembleModel): + @property + def normalization_constant(self): + n = self.dimension + return 2**(S(n)/2) * pi**(S(n**2)/2) + + def density(self, expr): + n, ZGUE = self.dimension, self.normalization_constant + h_pspace = RandomMatrixPSpace('P', model=self) + H = RandomMatrixSymbol('H', n, n, pspace=h_pspace) + return Lambda(H, exp(-S(n)/2 * Trace(H**2))/ZGUE)(expr) + + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(2)) + + def level_spacing_distribution(self): + s = Dummy('s') + f = (32/pi**2)*(s**2)*exp((-4/pi)*s**2) + return Lambda(s, f) + +class GaussianOrthogonalEnsembleModel(GaussianEnsembleModel): + @property + def normalization_constant(self): + n = self.dimension + _H = MatrixSymbol('_H', n, n) + return Integral(exp(-S(n)/4 * Trace(_H**2))) + + def density(self, expr): + n, ZGOE = self.dimension, self.normalization_constant + h_pspace = RandomMatrixPSpace('P', model=self) + H = RandomMatrixSymbol('H', n, n, pspace=h_pspace) + return Lambda(H, exp(-S(n)/4 * Trace(H**2))/ZGOE)(expr) + + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S.One) + + def level_spacing_distribution(self): + s = Dummy('s') + f = (pi/2)*s*exp((-pi/4)*s**2) + return Lambda(s, f) + +class GaussianSymplecticEnsembleModel(GaussianEnsembleModel): + @property + def normalization_constant(self): + n = self.dimension + _H = MatrixSymbol('_H', n, n) + return Integral(exp(-S(n) * Trace(_H**2))) + + def density(self, expr): + n, ZGSE = self.dimension, self.normalization_constant + h_pspace = RandomMatrixPSpace('P', model=self) + H = RandomMatrixSymbol('H', n, n, pspace=h_pspace) + return Lambda(H, exp(-S(n) * Trace(H**2))/ZGSE)(expr) + + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(4)) + + def level_spacing_distribution(self): + s = Dummy('s') + f = ((S(2)**18)/((S(3)**6)*(pi**3)))*(s**4)*exp((-64/(9*pi))*s**2) + return Lambda(s, f) + +def GaussianEnsemble(sym, dim): + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def GaussianUnitaryEnsemble(sym, dim): + """ + Represents Gaussian Unitary Ensembles. + + Examples + ======== + + >>> from sympy.stats import GaussianUnitaryEnsemble as GUE, density + >>> from sympy import MatrixSymbol + >>> G = GUE('U', 2) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(G)(X) + exp(-Trace(X**2))/(2*pi**2) + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianUnitaryEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def GaussianOrthogonalEnsemble(sym, dim): + """ + Represents Gaussian Orthogonal Ensembles. + + Examples + ======== + + >>> from sympy.stats import GaussianOrthogonalEnsemble as GOE, density + >>> from sympy import MatrixSymbol + >>> G = GOE('U', 2) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(G)(X) + exp(-Trace(X**2)/2)/Integral(exp(-Trace(_H**2)/2), _H) + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianOrthogonalEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def GaussianSymplecticEnsemble(sym, dim): + """ + Represents Gaussian Symplectic Ensembles. + + Examples + ======== + + >>> from sympy.stats import GaussianSymplecticEnsemble as GSE, density + >>> from sympy import MatrixSymbol + >>> G = GSE('U', 2) + >>> X = MatrixSymbol('X', 2, 2) + >>> density(G)(X) + exp(-2*Trace(X**2))/Integral(exp(-2*Trace(_H**2)), _H) + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = GaussianSymplecticEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +class CircularEnsembleModel(RandomMatrixEnsembleModel): + """ + Abstract class for Circular ensembles. + Contains the properties and methods + common to all the circular ensembles. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Circular_ensemble + """ + def density(self, expr): + # TODO : Add support for Lie groups(as extensions of sympy.diffgeom) + # and define measures on them + raise NotImplementedError("Support for Haar measure hasn't been " + "implemented yet, therefore the density of " + "%s cannot be computed."%(self)) + + def _compute_joint_eigen_distribution(self, beta): + """ + Helper function to compute the joint distribution of phases + of the complex eigen values of matrices belonging to any + circular ensembles. + """ + n = self.dimension + Zbn = ((2*pi)**n)*(gamma(beta*n/2 + 1)/S(gamma(beta/2 + 1))**n) + t = IndexedBase('t') + i, j, k = (Dummy('i', integer=True), Dummy('j', integer=True), + Dummy('k', integer=True)) + syms = ArrayComprehension(t[i], (i, 1, n)).doit() + f = Product(Product(Abs(exp(I*t[k]) - exp(I*t[j]))**beta, (j, k + 1, n)).doit(), + (k, 1, n - 1)).doit() + return Lambda(tuple(syms), f/Zbn) + +class CircularUnitaryEnsembleModel(CircularEnsembleModel): + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(2)) + +class CircularOrthogonalEnsembleModel(CircularEnsembleModel): + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S.One) + +class CircularSymplecticEnsembleModel(CircularEnsembleModel): + def joint_eigen_distribution(self): + return self._compute_joint_eigen_distribution(S(4)) + +def CircularEnsemble(sym, dim): + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def CircularUnitaryEnsemble(sym, dim): + """ + Represents Circular Unitary Ensembles. + + Examples + ======== + + >>> from sympy.stats import CircularUnitaryEnsemble as CUE + >>> from sympy.stats import joint_eigen_distribution + >>> C = CUE('U', 1) + >>> joint_eigen_distribution(C) + Lambda(t[1], Product(Abs(exp(I*t[_j]) - exp(I*t[_k]))**2, (_j, _k + 1, 1), (_k, 1, 0))/(2*pi)) + + Note + ==== + + As can be seen above in the example, density of CiruclarUnitaryEnsemble + is not evaluated because the exact definition is based on haar measure of + unitary group which is not unique. + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularUnitaryEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def CircularOrthogonalEnsemble(sym, dim): + """ + Represents Circular Orthogonal Ensembles. + + Examples + ======== + + >>> from sympy.stats import CircularOrthogonalEnsemble as COE + >>> from sympy.stats import joint_eigen_distribution + >>> C = COE('O', 1) + >>> joint_eigen_distribution(C) + Lambda(t[1], Product(Abs(exp(I*t[_j]) - exp(I*t[_k])), (_j, _k + 1, 1), (_k, 1, 0))/(2*pi)) + + Note + ==== + + As can be seen above in the example, density of CiruclarOrthogonalEnsemble + is not evaluated because the exact definition is based on haar measure of + unitary group which is not unique. + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularOrthogonalEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def CircularSymplecticEnsemble(sym, dim): + """ + Represents Circular Symplectic Ensembles. + + Examples + ======== + + >>> from sympy.stats import CircularSymplecticEnsemble as CSE + >>> from sympy.stats import joint_eigen_distribution + >>> C = CSE('S', 1) + >>> joint_eigen_distribution(C) + Lambda(t[1], Product(Abs(exp(I*t[_j]) - exp(I*t[_k]))**4, (_j, _k + 1, 1), (_k, 1, 0))/(2*pi)) + + Note + ==== + + As can be seen above in the example, density of CiruclarSymplecticEnsemble + is not evaluated because the exact definition is based on haar measure of + unitary group which is not unique. + """ + sym, dim = _symbol_converter(sym), _sympify(dim) + model = CircularSymplecticEnsembleModel(sym, dim) + rmp = RandomMatrixPSpace(sym, model=model) + return RandomMatrixSymbol(sym, dim, dim, pspace=rmp) + +def joint_eigen_distribution(mat): + """ + For obtaining joint probability distribution + of eigen values of random matrix. + + Parameters + ========== + + mat: RandomMatrixSymbol + The matrix symbol whose eigen values are to be considered. + + Returns + ======= + + Lambda + + Examples + ======== + + >>> from sympy.stats import GaussianUnitaryEnsemble as GUE + >>> from sympy.stats import joint_eigen_distribution + >>> U = GUE('U', 2) + >>> joint_eigen_distribution(U) + Lambda((l[1], l[2]), exp(-l[1]**2 - l[2]**2)*Product(Abs(l[_i] - l[_j])**2, (_j, _i + 1, 2), (_i, 1, 1))/pi) + """ + if not isinstance(mat, RandomMatrixSymbol): + raise ValueError("%s is not of type, RandomMatrixSymbol."%(mat)) + return mat.pspace.model.joint_eigen_distribution() + +def JointEigenDistribution(mat): + """ + Creates joint distribution of eigen values of matrices with random + expressions. + + Parameters + ========== + + mat: Matrix + The matrix under consideration. + + Returns + ======= + + JointDistributionHandmade + + Examples + ======== + + >>> from sympy.stats import Normal, JointEigenDistribution + >>> from sympy import Matrix + >>> A = [[Normal('A00', 0, 1), Normal('A01', 0, 1)], + ... [Normal('A10', 0, 1), Normal('A11', 0, 1)]] + >>> JointEigenDistribution(Matrix(A)) + JointDistributionHandmade(-sqrt(A00**2 - 2*A00*A11 + 4*A01*A10 + A11**2)/2 + + A00/2 + A11/2, sqrt(A00**2 - 2*A00*A11 + 4*A01*A10 + A11**2)/2 + A00/2 + A11/2) + + """ + eigenvals = mat.eigenvals(multiple=True) + if not all(is_random(eigenval) for eigenval in set(eigenvals)): + raise ValueError("Eigen values do not have any random expression, " + "joint distribution cannot be generated.") + return JointDistributionHandmade(*eigenvals) + +def level_spacing_distribution(mat): + """ + For obtaining distribution of level spacings. + + Parameters + ========== + + mat: RandomMatrixSymbol + The random matrix symbol whose eigen values are + to be considered for finding the level spacings. + + Returns + ======= + + Lambda + + Examples + ======== + + >>> from sympy.stats import GaussianUnitaryEnsemble as GUE + >>> from sympy.stats import level_spacing_distribution + >>> U = GUE('U', 2) + >>> level_spacing_distribution(U) + Lambda(_s, 32*_s**2*exp(-4*_s**2/pi)/pi**2) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Random_matrix#Distribution_of_level_spacings + """ + return mat.pspace.model.level_spacing_distribution() diff --git a/phivenv/Lib/site-packages/sympy/stats/rv.py b/phivenv/Lib/site-packages/sympy/stats/rv.py new file mode 100644 index 0000000000000000000000000000000000000000..75ab54deb551b7ff3d4d06f37482a1f16a789ba6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/rv.py @@ -0,0 +1,1798 @@ +""" +Main Random Variables Module + +Defines abstract random variable type. +Contains interfaces for probability space object (PSpace) as well as standard +operators, P, E, sample, density, where, quantile + +See Also +======== + +sympy.stats.crv +sympy.stats.frv +sympy.stats.rv_interface +""" + +from __future__ import annotations +from functools import singledispatch +from math import prod + +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.logic import fuzzy_and +from sympy.core.mul import Mul +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.core.sympify import sympify +from sympy.functions.special.delta_functions import DiracDelta +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.indexed import Indexed +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Relational +from sympy.core.sympify import _sympify +from sympy.sets.sets import FiniteSet, ProductSet, Intersection +from sympy.solvers.solveset import solveset +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import iterable + + +__doctest_requires__ = {('sample',): ['scipy']} + + +x = Symbol('x') + + +@singledispatch +def is_random(x): + return False + + +@is_random.register(Basic) +def _(x): + atoms = x.free_symbols + return any(is_random(i) for i in atoms) + + +class RandomDomain(Basic): + """ + Represents a set of variables and the values which they can take. + + See Also + ======== + + sympy.stats.crv.ContinuousDomain + sympy.stats.frv.FiniteDomain + """ + + is_ProductDomain = False + is_Finite = False + is_Continuous = False + is_Discrete = False + + def __new__(cls, symbols, *args): + symbols = FiniteSet(*symbols) + return Basic.__new__(cls, symbols, *args) + + @property + def symbols(self): + return self.args[0] + + @property + def set(self): + return self.args[1] + + def __contains__(self, other): + raise NotImplementedError() + + def compute_expectation(self, expr): + raise NotImplementedError() + + +class SingleDomain(RandomDomain): + """ + A single variable and its domain. + + See Also + ======== + + sympy.stats.crv.SingleContinuousDomain + sympy.stats.frv.SingleFiniteDomain + """ + def __new__(cls, symbol, set): + assert symbol.is_Symbol + return Basic.__new__(cls, symbol, set) + + @property + def symbol(self): + return self.args[0] + + @property + def symbols(self): + return FiniteSet(self.symbol) + + def __contains__(self, other): + if len(other) != 1: + return False + sym, val = tuple(other)[0] + return self.symbol == sym and val in self.set + + +class MatrixDomain(RandomDomain): + """ + A Random Matrix variable and its domain. + + """ + def __new__(cls, symbol, set): + symbol, set = _symbol_converter(symbol), _sympify(set) + return Basic.__new__(cls, symbol, set) + + @property + def symbol(self): + return self.args[0] + + @property + def symbols(self): + return FiniteSet(self.symbol) + + +class ConditionalDomain(RandomDomain): + """ + A RandomDomain with an attached condition. + + See Also + ======== + + sympy.stats.crv.ConditionalContinuousDomain + sympy.stats.frv.ConditionalFiniteDomain + """ + def __new__(cls, fulldomain, condition): + condition = condition.xreplace({rs: rs.symbol + for rs in random_symbols(condition)}) + return Basic.__new__(cls, fulldomain, condition) + + @property + def symbols(self): + return self.fulldomain.symbols + + @property + def fulldomain(self): + return self.args[0] + + @property + def condition(self): + return self.args[1] + + @property + def set(self): + raise NotImplementedError("Set of Conditional Domain not Implemented") + + def as_boolean(self): + return And(self.fulldomain.as_boolean(), self.condition) + + +class PSpace(Basic): + """ + A Probability Space. + + Explanation + =========== + + Probability Spaces encode processes that equal different values + probabilistically. These underly Random Symbols which occur in SymPy + expressions and contain the mechanics to evaluate statistical statements. + + See Also + ======== + + sympy.stats.crv.ContinuousPSpace + sympy.stats.frv.FinitePSpace + """ + + is_Finite: bool | None = None # Fails test if not set to None + is_Continuous: bool | None = None # Fails test if not set to None + is_Discrete: bool | None = None # Fails test if not set to None + is_real: bool | None + + @property + def domain(self): + return self.args[0] + + @property + def density(self): + return self.args[1] + + @property + def values(self): + return frozenset(RandomSymbol(sym, self) for sym in self.symbols) + + @property + def symbols(self): + return self.domain.symbols + + def where(self, condition): + raise NotImplementedError() + + def compute_density(self, expr): + raise NotImplementedError() + + def sample(self, size=(), library='scipy', seed=None): + raise NotImplementedError() + + def probability(self, condition): + raise NotImplementedError() + + def compute_expectation(self, expr): + raise NotImplementedError() + + +class SinglePSpace(PSpace): + """ + Represents the probabilities of a set of random events that can be + attributed to a single variable/symbol. + """ + def __new__(cls, s, distribution): + s = _symbol_converter(s) + return Basic.__new__(cls, s, distribution) + + @property + def value(self): + return RandomSymbol(self.symbol, self) + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[1] + + @property + def pdf(self): + return self.distribution.pdf(self.symbol) + + +class RandomSymbol(Expr): + """ + Random Symbols represent ProbabilitySpaces in SymPy Expressions. + In principle they can take on any value that their symbol can take on + within the associated PSpace with probability determined by the PSpace + Density. + + Explanation + =========== + + Random Symbols contain pspace and symbol properties. + The pspace property points to the represented Probability Space + The symbol is a standard SymPy Symbol that is used in that probability space + for example in defining a density. + + You can form normal SymPy expressions using RandomSymbols and operate on + those expressions with the Functions + + E - Expectation of a random expression + P - Probability of a condition + density - Probability Density of an expression + given - A new random expression (with new random symbols) given a condition + + An object of the RandomSymbol type should almost never be created by the + user. They tend to be created instead by the PSpace class's value method. + Traditionally a user does not even do this but instead calls one of the + convenience functions Normal, Exponential, Coin, Die, FiniteRV, etc.... + """ + + def __new__(cls, symbol, pspace=None): + from sympy.stats.joint_rv import JointRandomSymbol + if pspace is None: + # Allow single arg, representing pspace == PSpace() + pspace = PSpace() + symbol = _symbol_converter(symbol) + if not isinstance(pspace, PSpace): + raise TypeError("pspace variable should be of type PSpace") + if cls == JointRandomSymbol and isinstance(pspace, SinglePSpace): + cls = RandomSymbol + return Basic.__new__(cls, symbol, pspace) + + is_finite = True + is_symbol = True + is_Atom = True + + _diff_wrt = True + + pspace = property(lambda self: self.args[1]) + symbol = property(lambda self: self.args[0]) + name = property(lambda self: self.symbol.name) + + def _eval_is_positive(self): + return self.symbol.is_positive + + def _eval_is_integer(self): + return self.symbol.is_integer + + def _eval_is_real(self): + return self.symbol.is_real or self.pspace.is_real + + @property + def is_commutative(self): + return self.symbol.is_commutative + + @property + def free_symbols(self): + return {self} + +class RandomIndexedSymbol(RandomSymbol): + + def __new__(cls, idx_obj, pspace=None): + if pspace is None: + # Allow single arg, representing pspace == PSpace() + pspace = PSpace() + if not isinstance(idx_obj, (Indexed, Function)): + raise TypeError("An Function or Indexed object is expected not %s"%(idx_obj)) + return Basic.__new__(cls, idx_obj, pspace) + + symbol = property(lambda self: self.args[0]) + name = property(lambda self: str(self.args[0])) + + @property + def key(self): + if isinstance(self.symbol, Indexed): + return self.symbol.args[1] + elif isinstance(self.symbol, Function): + return self.symbol.args[0] + + @property + def free_symbols(self): + if self.key.free_symbols: + free_syms = self.key.free_symbols + free_syms.add(self) + return free_syms + return {self} + + @property + def pspace(self): + return self.args[1] + +class RandomMatrixSymbol(RandomSymbol, MatrixSymbol): # type: ignore + def __new__(cls, symbol, n, m, pspace=None): + n, m = _sympify(n), _sympify(m) + symbol = _symbol_converter(symbol) + if pspace is None: + # Allow single arg, representing pspace == PSpace() + pspace = PSpace() + return Basic.__new__(cls, symbol, n, m, pspace) + + symbol = property(lambda self: self.args[0]) + pspace = property(lambda self: self.args[3]) + + +class ProductPSpace(PSpace): + """ + Abstract class for representing probability spaces with multiple random + variables. + + See Also + ======== + + sympy.stats.rv.IndependentProductPSpace + sympy.stats.joint_rv.JointPSpace + """ + pass + +class IndependentProductPSpace(ProductPSpace): + """ + A probability space resulting from the merger of two independent probability + spaces. + + Often created using the function, pspace. + """ + + def __new__(cls, *spaces): + rs_space_dict = {} + for space in spaces: + for value in space.values: + rs_space_dict[value] = space + + symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()]) + + # Overlapping symbols + from sympy.stats.joint_rv import MarginalDistribution + from sympy.stats.compound_rv import CompoundDistribution + if len(symbols) < sum(len(space.symbols) for space in spaces if not + isinstance(space.distribution, ( + CompoundDistribution, MarginalDistribution))): + raise ValueError("Overlapping Random Variables") + + if all(space.is_Finite for space in spaces): + from sympy.stats.frv import ProductFinitePSpace + cls = ProductFinitePSpace + + obj = Basic.__new__(cls, *FiniteSet(*spaces)) + + return obj + + @property + def pdf(self): + p = Mul(*[space.pdf for space in self.spaces]) + return p.subs({rv: rv.symbol for rv in self.values}) + + @property + def rs_space_dict(self): + d = {} + for space in self.spaces: + for value in space.values: + d[value] = space + return d + + @property + def symbols(self): + return FiniteSet(*[val.symbol for val in self.rs_space_dict.keys()]) + + @property + def spaces(self): + return FiniteSet(*self.args) + + @property + def values(self): + return sumsets(space.values for space in self.spaces) + + def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): + rvs = rvs or self.values + rvs = frozenset(rvs) + for space in self.spaces: + expr = space.compute_expectation(expr, rvs & space.values, evaluate=False, **kwargs) + if evaluate and hasattr(expr, 'doit'): + return expr.doit(**kwargs) + return expr + + @property + def domain(self): + return ProductDomain(*[space.domain for space in self.spaces]) + + @property + def density(self): + raise NotImplementedError("Density not available for ProductSpaces") + + def sample(self, size=(), library='scipy', seed=None): + return {k: v for space in self.spaces + for k, v in space.sample(size=size, library=library, seed=seed).items()} + + + def probability(self, condition, **kwargs): + cond_inv = False + if isinstance(condition, Ne): + condition = Eq(condition.args[0], condition.args[1]) + cond_inv = True + elif isinstance(condition, And): # they are independent + return Mul(*[self.probability(arg) for arg in condition.args]) + elif isinstance(condition, Or): # they are independent + return Add(*[self.probability(arg) for arg in condition.args]) + expr = condition.lhs - condition.rhs + rvs = random_symbols(expr) + dens = self.compute_density(expr) + if any(pspace(rv).is_Continuous for rv in rvs): + from sympy.stats.crv import SingleContinuousPSpace + from sympy.stats.crv_types import ContinuousDistributionHandmade + if expr in self.values: + # Marginalize all other random symbols out of the density + randomsymbols = tuple(set(self.values) - frozenset([expr])) + symbols = tuple(rs.symbol for rs in randomsymbols) + pdf = self.domain.integrate(self.pdf, symbols, **kwargs) + return Lambda(expr.symbol, pdf) + dens = ContinuousDistributionHandmade(dens) + z = Dummy('z', real=True) + space = SingleContinuousPSpace(z, dens) + result = space.probability(condition.__class__(space.value, 0)) + else: + from sympy.stats.drv import SingleDiscretePSpace + from sympy.stats.drv_types import DiscreteDistributionHandmade + dens = DiscreteDistributionHandmade(dens) + z = Dummy('z', integer=True) + space = SingleDiscretePSpace(z, dens) + result = space.probability(condition.__class__(space.value, 0)) + return result if not cond_inv else S.One - result + + def compute_density(self, expr, **kwargs): + rvs = random_symbols(expr) + if any(pspace(rv).is_Continuous for rv in rvs): + z = Dummy('z', real=True) + expr = self.compute_expectation(DiracDelta(expr - z), + **kwargs) + else: + z = Dummy('z', integer=True) + expr = self.compute_expectation(KroneckerDelta(expr, z), + **kwargs) + return Lambda(z, expr) + + def compute_cdf(self, expr, **kwargs): + raise ValueError("CDF not well defined on multivariate expressions") + + def conditional_space(self, condition, normalize=True, **kwargs): + rvs = random_symbols(condition) + condition = condition.xreplace({rv: rv.symbol for rv in self.values}) + pspaces = [pspace(rv) for rv in rvs] + if any(ps.is_Continuous for ps in pspaces): + from sympy.stats.crv import (ConditionalContinuousDomain, + ContinuousPSpace) + space = ContinuousPSpace + domain = ConditionalContinuousDomain(self.domain, condition) + elif any(ps.is_Discrete for ps in pspaces): + from sympy.stats.drv import (ConditionalDiscreteDomain, + DiscretePSpace) + space = DiscretePSpace + domain = ConditionalDiscreteDomain(self.domain, condition) + elif all(ps.is_Finite for ps in pspaces): + from sympy.stats.frv import FinitePSpace + return FinitePSpace.conditional_space(self, condition) + if normalize: + replacement = {rv: Dummy(str(rv)) for rv in self.symbols} + norm = domain.compute_expectation(self.pdf, **kwargs) + pdf = self.pdf / norm.xreplace(replacement) + # XXX: Converting symbols from set to tuple. The order matters to + # Lambda though so we shouldn't be starting with a set here... + density = Lambda(tuple(domain.symbols), pdf) + + return space(domain, density) + +class ProductDomain(RandomDomain): + """ + A domain resulting from the merger of two independent domains. + + See Also + ======== + sympy.stats.crv.ProductContinuousDomain + sympy.stats.frv.ProductFiniteDomain + """ + is_ProductDomain = True + + def __new__(cls, *domains): + # Flatten any product of products + domains2 = [] + for domain in domains: + if not domain.is_ProductDomain: + domains2.append(domain) + else: + domains2.extend(domain.domains) + domains2 = FiniteSet(*domains2) + + if all(domain.is_Finite for domain in domains2): + from sympy.stats.frv import ProductFiniteDomain + cls = ProductFiniteDomain + if all(domain.is_Continuous for domain in domains2): + from sympy.stats.crv import ProductContinuousDomain + cls = ProductContinuousDomain + if all(domain.is_Discrete for domain in domains2): + from sympy.stats.drv import ProductDiscreteDomain + cls = ProductDiscreteDomain + + return Basic.__new__(cls, *domains2) + + @property + def sym_domain_dict(self): + return {symbol: domain for domain in self.domains + for symbol in domain.symbols} + + @property + def symbols(self): + return FiniteSet(*[sym for domain in self.domains + for sym in domain.symbols]) + + @property + def domains(self): + return self.args + + @property + def set(self): + return ProductSet(*(domain.set for domain in self.domains)) + + def __contains__(self, other): + # Split event into each subdomain + for domain in self.domains: + # Collect the parts of this event which associate to this domain + elem = frozenset([item for item in other + if sympify(domain.symbols.contains(item[0])) + is S.true]) + # Test this sub-event + if elem not in domain: + return False + # All subevents passed + return True + + def as_boolean(self): + return And(*[domain.as_boolean() for domain in self.domains]) + + +def random_symbols(expr): + """ + Returns all RandomSymbols within a SymPy Expression. + """ + atoms = getattr(expr, 'atoms', None) + if atoms is not None: + comp = lambda rv: rv.symbol.name + l = list(atoms(RandomSymbol)) + return sorted(l, key=comp) + else: + return [] + + +def pspace(expr): + """ + Returns the underlying Probability Space of a random expression. + + For internal use. + + Examples + ======== + + >>> from sympy.stats import pspace, Normal + >>> X = Normal('X', 0, 1) + >>> pspace(2*X + 1) == X.pspace + True + """ + expr = sympify(expr) + if isinstance(expr, RandomSymbol) and expr.pspace is not None: + return expr.pspace + if expr.has(RandomMatrixSymbol): + rm = list(expr.atoms(RandomMatrixSymbol))[0] + return rm.pspace + + rvs = random_symbols(expr) + if not rvs: + raise ValueError("Expression containing Random Variable expected, not %s" % (expr)) + # If only one space present + if all(rv.pspace == rvs[0].pspace for rv in rvs): + return rvs[0].pspace + from sympy.stats.compound_rv import CompoundPSpace + from sympy.stats.stochastic_process import StochasticPSpace + for rv in rvs: + if isinstance(rv.pspace, (CompoundPSpace, StochasticPSpace)): + return rv.pspace + # Otherwise make a product space + return IndependentProductPSpace(*[rv.pspace for rv in rvs]) + + +def sumsets(sets): + """ + Union of sets + """ + return frozenset().union(*sets) + + +def rs_swap(a, b): + """ + Build a dictionary to swap RandomSymbols based on their underlying symbol. + + i.e. + if ``X = ('x', pspace1)`` + and ``Y = ('x', pspace2)`` + then ``X`` and ``Y`` match and the key, value pair + ``{X:Y}`` will appear in the result + + Inputs: collections a and b of random variables which share common symbols + Output: dict mapping RVs in a to RVs in b + """ + d = {} + for rsa in a: + d[rsa] = [rsb for rsb in b if rsa.symbol == rsb.symbol][0] + return d + + +def given(expr, condition=None, **kwargs): + r""" Conditional Random Expression. + + Explanation + =========== + + From a random expression and a condition on that expression creates a new + probability space from the condition and returns the same expression on that + conditional probability space. + + Examples + ======== + + >>> from sympy.stats import given, density, Die + >>> X = Die('X', 6) + >>> Y = given(X, X > 3) + >>> density(Y).dict + {4: 1/3, 5: 1/3, 6: 1/3} + + Following convention, if the condition is a random symbol then that symbol + is considered fixed. + + >>> from sympy.stats import Normal + >>> from sympy import pprint + >>> from sympy.abc import z + + >>> X = Normal('X', 0, 1) + >>> Y = Normal('Y', 0, 1) + >>> pprint(density(X + Y, Y)(z), use_unicode=False) + 2 + -(-Y + z) + ----------- + ___ 2 + \/ 2 *e + ------------------ + ____ + 2*\/ pi + """ + + if not is_random(condition) or pspace_independent(expr, condition): + return expr + + if isinstance(condition, RandomSymbol): + condition = Eq(condition, condition.symbol) + + condsymbols = random_symbols(condition) + if (isinstance(condition, Eq) and len(condsymbols) == 1 and + not isinstance(pspace(expr).domain, ConditionalDomain)): + rv = tuple(condsymbols)[0] + + results = solveset(condition, rv) + if isinstance(results, Intersection) and S.Reals in results.args: + results = list(results.args[1]) + + sums = 0 + for res in results: + temp = expr.subs(rv, res) + if temp == True: + return True + if temp != False: + # XXX: This seems nonsensical but preserves existing behaviour + # after the change that Relational is no longer a subclass of + # Expr. Here expr is sometimes Relational and sometimes Expr + # but we are trying to add them with +=. This needs to be + # fixed somehow. + if sums == 0 and isinstance(expr, Relational): + sums = expr.subs(rv, res) + else: + sums += expr.subs(rv, res) + if sums == 0: + return False + return sums + + # Get full probability space of both the expression and the condition + fullspace = pspace(Tuple(expr, condition)) + # Build new space given the condition + space = fullspace.conditional_space(condition, **kwargs) + # Dictionary to swap out RandomSymbols in expr with new RandomSymbols + # That point to the new conditional space + swapdict = rs_swap(fullspace.values, space.values) + # Swap random variables in the expression + expr = expr.xreplace(swapdict) + return expr + + +def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs): + """ + Returns the expected value of a random expression. + + Parameters + ========== + + expr : Expr containing RandomSymbols + The expression of which you want to compute the expectation value + given : Expr containing RandomSymbols + A conditional expression. E(X, X>0) is expectation of X given X > 0 + numsamples : int + Enables sampling and approximates the expectation with this many samples + evalf : Bool (defaults to True) + If sampling return a number rather than a complex expression + evaluate : Bool (defaults to True) + In case of continuous systems return unevaluated integral + + Examples + ======== + + >>> from sympy.stats import E, Die + >>> X = Die('X', 6) + >>> E(X) + 7/2 + >>> E(2*X + 1) + 8 + + >>> E(X, X > 3) # Expectation of X given that it is above 3 + 5 + """ + + if not is_random(expr): # expr isn't random? + return expr + kwargs['numsamples'] = numsamples + from sympy.stats.symbolic_probability import Expectation + if evaluate: + return Expectation(expr, condition).doit(**kwargs) + return Expectation(expr, condition) + + +def probability(condition, given_condition=None, numsamples=None, + evaluate=True, **kwargs): + """ + Probability that a condition is true, optionally given a second condition. + + Parameters + ========== + + condition : Combination of Relationals containing RandomSymbols + The condition of which you want to compute the probability + given_condition : Combination of Relationals containing RandomSymbols + A conditional expression. P(X > 1, X > 0) is expectation of X > 1 + given X > 0 + numsamples : int + Enables sampling and approximates the probability with this many samples + evaluate : Bool (defaults to True) + In case of continuous systems return unevaluated integral + + Examples + ======== + + >>> from sympy.stats import P, Die + >>> from sympy import Eq + >>> X, Y = Die('X', 6), Die('Y', 6) + >>> P(X > 3) + 1/2 + >>> P(Eq(X, 5), X > 2) # Probability that X == 5 given that X > 2 + 1/4 + >>> P(X > Y) + 5/12 + """ + + kwargs['numsamples'] = numsamples + from sympy.stats.symbolic_probability import Probability + if evaluate: + return Probability(condition, given_condition).doit(**kwargs) + return Probability(condition, given_condition) + + +class Density(Basic): + expr = property(lambda self: self.args[0]) + + def __new__(cls, expr, condition = None): + expr = _sympify(expr) + if condition is None: + obj = Basic.__new__(cls, expr) + else: + condition = _sympify(condition) + obj = Basic.__new__(cls, expr, condition) + return obj + + @property + def condition(self): + if len(self.args) > 1: + return self.args[1] + else: + return None + + def doit(self, evaluate=True, **kwargs): + from sympy.stats.random_matrix import RandomMatrixPSpace + from sympy.stats.joint_rv import JointPSpace + from sympy.stats.matrix_distributions import MatrixPSpace + from sympy.stats.compound_rv import CompoundPSpace + from sympy.stats.frv import SingleFiniteDistribution + expr, condition = self.expr, self.condition + + if isinstance(expr, SingleFiniteDistribution): + return expr.dict + if condition is not None: + # Recompute on new conditional expr + expr = given(expr, condition, **kwargs) + if not random_symbols(expr): + return Lambda(x, DiracDelta(x - expr)) + if isinstance(expr, RandomSymbol): + if isinstance(expr.pspace, (SinglePSpace, JointPSpace, MatrixPSpace)) and \ + hasattr(expr.pspace, 'distribution'): + return expr.pspace.distribution + elif isinstance(expr.pspace, RandomMatrixPSpace): + return expr.pspace.model + if isinstance(pspace(expr), CompoundPSpace): + kwargs['compound_evaluate'] = evaluate + result = pspace(expr).compute_density(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + + +def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs): + """ + Probability density of a random expression, optionally given a second + condition. + + Explanation + =========== + + This density will take on different forms for different types of + probability spaces. Discrete variables produce Dicts. Continuous + variables produce Lambdas. + + Parameters + ========== + + expr : Expr containing RandomSymbols + The expression of which you want to compute the density value + condition : Relational containing RandomSymbols + A conditional expression. density(X > 1, X > 0) is density of X > 1 + given X > 0 + numsamples : int + Enables sampling and approximates the density with this many samples + + Examples + ======== + + >>> from sympy.stats import density, Die, Normal + >>> from sympy import Symbol + + >>> x = Symbol('x') + >>> D = Die('D', 6) + >>> X = Normal(x, 0, 1) + + >>> density(D).dict + {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6} + >>> density(2*D).dict + {2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6} + >>> density(X)(x) + sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) + """ + + if numsamples: + return sampling_density(expr, condition, numsamples=numsamples, + **kwargs) + + return Density(expr, condition).doit(evaluate=evaluate, **kwargs) + + +def cdf(expr, condition=None, evaluate=True, **kwargs): + """ + Cumulative Distribution Function of a random expression. + + optionally given a second condition. + + Explanation + =========== + + This density will take on different forms for different types of + probability spaces. + Discrete variables produce Dicts. + Continuous variables produce Lambdas. + + Examples + ======== + + >>> from sympy.stats import density, Die, Normal, cdf + + >>> D = Die('D', 6) + >>> X = Normal('X', 0, 1) + + >>> density(D).dict + {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6} + >>> cdf(D) + {1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1} + >>> cdf(3*D, D > 2) + {9: 1/4, 12: 1/2, 15: 3/4, 18: 1} + + >>> cdf(X) + Lambda(_z, erf(sqrt(2)*_z/2)/2 + 1/2) + """ + if condition is not None: # If there is a condition + # Recompute on new conditional expr + return cdf(given(expr, condition, **kwargs), **kwargs) + + # Otherwise pass work off to the ProbabilitySpace + result = pspace(expr).compute_cdf(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + + +def characteristic_function(expr, condition=None, evaluate=True, **kwargs): + """ + Characteristic function of a random expression, optionally given a second condition. + + Returns a Lambda. + + Examples + ======== + + >>> from sympy.stats import Normal, DiscreteUniform, Poisson, characteristic_function + + >>> X = Normal('X', 0, 1) + >>> characteristic_function(X) + Lambda(_t, exp(-_t**2/2)) + + >>> Y = DiscreteUniform('Y', [1, 2, 7]) + >>> characteristic_function(Y) + Lambda(_t, exp(7*_t*I)/3 + exp(2*_t*I)/3 + exp(_t*I)/3) + + >>> Z = Poisson('Z', 2) + >>> characteristic_function(Z) + Lambda(_t, exp(2*exp(_t*I) - 2)) + """ + if condition is not None: + return characteristic_function(given(expr, condition, **kwargs), **kwargs) + + result = pspace(expr).compute_characteristic_function(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def moment_generating_function(expr, condition=None, evaluate=True, **kwargs): + if condition is not None: + return moment_generating_function(given(expr, condition, **kwargs), **kwargs) + + result = pspace(expr).compute_moment_generating_function(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def where(condition, given_condition=None, **kwargs): + """ + Returns the domain where a condition is True. + + Examples + ======== + + >>> from sympy.stats import where, Die, Normal + >>> from sympy import And + + >>> D1, D2 = Die('a', 6), Die('b', 6) + >>> a, b = D1.symbol, D2.symbol + >>> X = Normal('x', 0, 1) + + >>> where(X**2<1) + Domain: (-1 < x) & (x < 1) + + >>> where(X**2<1).set + Interval.open(-1, 1) + + >>> where(And(D1<=D2, D2<3)) + Domain: (Eq(a, 1) & Eq(b, 1)) | (Eq(a, 1) & Eq(b, 2)) | (Eq(a, 2) & Eq(b, 2)) + """ + if given_condition is not None: # If there is a condition + # Recompute on new conditional expr + return where(given(condition, given_condition, **kwargs), **kwargs) + + # Otherwise pass work off to the ProbabilitySpace + return pspace(condition).where(condition, **kwargs) + + +@doctest_depends_on(modules=('scipy',)) +def sample(expr, condition=None, size=(), library='scipy', + numsamples=1, seed=None, **kwargs): + """ + A realization of the random expression. + + Parameters + ========== + + expr : Expression of random variables + Expression from which sample is extracted + condition : Expr containing RandomSymbols + A conditional expression + size : int, tuple + Represents size of each sample in numsamples + library : str + - 'scipy' : Sample using scipy + - 'numpy' : Sample using numpy + - 'pymc' : Sample using PyMC + + Choose any of the available options to sample from as string, + by default is 'scipy' + numsamples : int + Number of samples, each with size as ``size``. + + .. deprecated:: 1.9 + + The ``numsamples`` parameter is deprecated and is only provided for + compatibility with v1.8. Use a list comprehension or an additional + dimension in ``size`` instead. See + :ref:`deprecated-sympy-stats-numsamples` for details. + + seed : + An object to be used as seed by the given external library for sampling `expr`. + Following is the list of possible types of object for the supported libraries, + + - 'scipy': int, numpy.random.RandomState, numpy.random.Generator + - 'numpy': int, numpy.random.RandomState, numpy.random.Generator + - 'pymc': int + + Optional, by default None, in which case seed settings + related to the given library will be used. + No modifications to environment's global seed settings + are done by this argument. + + Returns + ======= + + sample: float/list/numpy.ndarray + one sample or a collection of samples of the random expression. + + - sample(X) returns float/numpy.float64/numpy.int64 object. + - sample(X, size=int/tuple) returns numpy.ndarray object. + + Examples + ======== + + >>> from sympy.stats import Die, sample, Normal, Geometric + >>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) # Finite Random Variable + >>> die_roll = sample(X + Y + Z) + >>> die_roll # doctest: +SKIP + 3 + >>> N = Normal('N', 3, 4) # Continuous Random Variable + >>> samp = sample(N) + >>> samp in N.pspace.domain.set + True + >>> samp = sample(N, N>0) + >>> samp > 0 + True + >>> samp_list = sample(N, size=4) + >>> [sam in N.pspace.domain.set for sam in samp_list] + [True, True, True, True] + >>> sample(N, size = (2,3)) # doctest: +SKIP + array([[5.42519758, 6.40207856, 4.94991743], + [1.85819627, 6.83403519, 1.9412172 ]]) + >>> G = Geometric('G', 0.5) # Discrete Random Variable + >>> samp_list = sample(G, size=3) + >>> samp_list # doctest: +SKIP + [1, 3, 2] + >>> [sam in G.pspace.domain.set for sam in samp_list] + [True, True, True] + >>> MN = Normal("MN", [3, 4], [[2, 1], [1, 2]]) # Joint Random Variable + >>> samp_list = sample(MN, size=4) + >>> samp_list # doctest: +SKIP + [array([2.85768055, 3.38954165]), + array([4.11163337, 4.3176591 ]), + array([0.79115232, 1.63232916]), + array([4.01747268, 3.96716083])] + >>> [tuple(sam) in MN.pspace.domain.set for sam in samp_list] + [True, True, True, True] + + .. versionchanged:: 1.7.0 + sample used to return an iterator containing the samples instead of value. + + .. versionchanged:: 1.9.0 + sample returns values or array of values instead of an iterator and numsamples is deprecated. + + """ + + iterator = sample_iter(expr, condition, size=size, library=library, + numsamples=numsamples, seed=seed) + + if numsamples != 1: + sympy_deprecation_warning( + f""" + The numsamples parameter to sympy.stats.sample() is deprecated. + Either use a list comprehension, like + + [sample(...) for i in range({numsamples})] + + or add a dimension to size, like + + sample(..., size={(numsamples,) + size}) + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-sympy-stats-numsamples", + ) + return [next(iterator) for i in range(numsamples)] + + return next(iterator) + + +def quantile(expr, evaluate=True, **kwargs): + r""" + Return the :math:`p^{th}` order quantile of a probability distribution. + + Explanation + =========== + + Quantile is defined as the value at which the probability of the random + variable is less than or equal to the given probability. + + .. math:: + Q(p) = \inf\{x \in (-\infty, \infty) : p \le F(x)\} + + Examples + ======== + + >>> from sympy.stats import quantile, Die, Exponential + >>> from sympy import Symbol, pprint + >>> p = Symbol("p") + + >>> l = Symbol("lambda", positive=True) + >>> X = Exponential("x", l) + >>> quantile(X)(p) + -log(1 - p)/lambda + + >>> D = Die("d", 6) + >>> pprint(quantile(D)(p), use_unicode=False) + /nan for Or(p > 1, p < 0) + | + | 1 for p <= 1/6 + | + | 2 for p <= 1/3 + | + < 3 for p <= 1/2 + | + | 4 for p <= 2/3 + | + | 5 for p <= 5/6 + | + \ 6 for p <= 1 + + """ + result = pspace(expr).compute_quantile(expr, **kwargs) + + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def sample_iter(expr, condition=None, size=(), library='scipy', + numsamples=S.Infinity, seed=None, **kwargs): + + """ + Returns an iterator of realizations from the expression given a condition. + + Parameters + ========== + + expr: Expr + Random expression to be realized + condition: Expr, optional + A conditional expression + size : int, tuple + Represents size of each sample in numsamples + numsamples: integer, optional + Length of the iterator (defaults to infinity) + seed : + An object to be used as seed by the given external library for sampling `expr`. + Following is the list of possible types of object for the supported libraries, + + - 'scipy': int, numpy.random.RandomState, numpy.random.Generator + - 'numpy': int, numpy.random.RandomState, numpy.random.Generator + - 'pymc': int + + Optional, by default None, in which case seed settings + related to the given library will be used. + No modifications to environment's global seed settings + are done by this argument. + + Examples + ======== + + >>> from sympy.stats import Normal, sample_iter + >>> X = Normal('X', 0, 1) + >>> expr = X*X + 3 + >>> iterator = sample_iter(expr, numsamples=3) # doctest: +SKIP + >>> list(iterator) # doctest: +SKIP + [12, 4, 7] + + Returns + ======= + + sample_iter: iterator object + iterator object containing the sample/samples of given expr + + See Also + ======== + + sample + sampling_P + sampling_E + + """ + from sympy.stats.joint_rv import JointRandomSymbol + if not import_module(library): + raise ValueError("Failed to import %s" % library) + + if condition is not None: + ps = pspace(Tuple(expr, condition)) + else: + ps = pspace(expr) + + rvs = list(ps.values) + if isinstance(expr, JointRandomSymbol): + expr = expr.subs({expr: RandomSymbol(expr.symbol, expr.pspace)}) + else: + sub = {} + for arg in expr.args: + if isinstance(arg, JointRandomSymbol): + sub[arg] = RandomSymbol(arg.symbol, arg.pspace) + expr = expr.subs(sub) + + def fn_subs(*args): + return expr.subs(dict(zip(rvs, args))) + + def given_fn_subs(*args): + if condition is not None: + return condition.subs(dict(zip(rvs, args))) + return False + + if library in ('pymc', 'pymc3'): + # Currently unable to lambdify in pymc + # TODO : Remove when lambdify accepts 'pymc' as module + fn = lambdify(rvs, expr, **kwargs) + else: + fn = lambdify(rvs, expr, modules=library, **kwargs) + + + if condition is not None: + given_fn = lambdify(rvs, condition, **kwargs) + + def return_generator_infinite(): + count = 0 + _size = (1,)+((size,) if isinstance(size, int) else size) + while count < numsamples: + d = ps.sample(size=_size, library=library, seed=seed) # a dictionary that maps RVs to values + args = [d[rv][0] for rv in rvs] + + if condition is not None: # Check that these values satisfy the condition + # TODO: Replace the try-except block with only given_fn(*args) + # once lambdify works with unevaluated SymPy objects. + try: + gd = given_fn(*args) + except (NameError, TypeError): + gd = given_fn_subs(*args) + if gd != True and gd != False: + raise ValueError( + "Conditions must not contain free symbols") + if not gd: # If the values don't satisfy then try again + continue + + yield fn(*args) + count += 1 + + def return_generator_finite(): + faulty = True + while faulty: + d = ps.sample(size=(numsamples,) + ((size,) if isinstance(size, int) else size), + library=library, seed=seed) # a dictionary that maps RVs to values + + faulty = False + count = 0 + while count < numsamples and not faulty: + args = [d[rv][count] for rv in rvs] + if condition is not None: # Check that these values satisfy the condition + # TODO: Replace the try-except block with only given_fn(*args) + # once lambdify works with unevaluated SymPy objects. + try: + gd = given_fn(*args) + except (NameError, TypeError): + gd = given_fn_subs(*args) + if gd != True and gd != False: + raise ValueError( + "Conditions must not contain free symbols") + if not gd: # If the values don't satisfy then try again + faulty = True + + count += 1 + + count = 0 + while count < numsamples: + args = [d[rv][count] for rv in rvs] + # TODO: Replace the try-except block with only fn(*args) + # once lambdify works with unevaluated SymPy objects. + try: + yield fn(*args) + except (NameError, TypeError): + yield fn_subs(*args) + count += 1 + + if numsamples is S.Infinity: + return return_generator_infinite() + + return return_generator_finite() + +def sample_iter_lambdify(expr, condition=None, size=(), + numsamples=S.Infinity, seed=None, **kwargs): + + return sample_iter(expr, condition=condition, size=size, + numsamples=numsamples, seed=seed, **kwargs) + +def sample_iter_subs(expr, condition=None, size=(), + numsamples=S.Infinity, seed=None, **kwargs): + + return sample_iter(expr, condition=condition, size=size, + numsamples=numsamples, seed=seed, **kwargs) + + +def sampling_P(condition, given_condition=None, library='scipy', numsamples=1, + evalf=True, seed=None, **kwargs): + """ + Sampling version of P. + + See Also + ======== + + P + sampling_E + sampling_density + + """ + + count_true = 0 + count_false = 0 + samples = sample_iter(condition, given_condition, library=library, + numsamples=numsamples, seed=seed, **kwargs) + + for sample in samples: + if sample: + count_true += 1 + else: + count_false += 1 + + result = S(count_true) / numsamples + if evalf: + return result.evalf() + else: + return result + + +def sampling_E(expr, given_condition=None, library='scipy', numsamples=1, + evalf=True, seed=None, **kwargs): + """ + Sampling version of E. + + See Also + ======== + + P + sampling_P + sampling_density + """ + samples = list(sample_iter(expr, given_condition, library=library, + numsamples=numsamples, seed=seed, **kwargs)) + result = Add(*samples) / numsamples + + if evalf: + return result.evalf() + else: + return result + +def sampling_density(expr, given_condition=None, library='scipy', + numsamples=1, seed=None, **kwargs): + """ + Sampling version of density. + + See Also + ======== + density + sampling_P + sampling_E + """ + + results = {} + for result in sample_iter(expr, given_condition, library=library, + numsamples=numsamples, seed=seed, **kwargs): + results[result] = results.get(result, 0) + 1 + + return results + + +def dependent(a, b): + """ + Dependence of two random expressions. + + Two expressions are independent if knowledge of one does not change + computations on the other. + + Examples + ======== + + >>> from sympy.stats import Normal, dependent, given + >>> from sympy import Tuple, Eq + + >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + >>> dependent(X, Y) + False + >>> dependent(2*X + Y, -Y) + True + >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3)) + >>> dependent(X, Y) + True + + See Also + ======== + + independent + """ + if pspace_independent(a, b): + return False + + z = Symbol('z', real=True) + # Dependent if density is unchanged when one is given information about + # the other + return (density(a, Eq(b, z)) != density(a) or + density(b, Eq(a, z)) != density(b)) + + +def independent(a, b): + """ + Independence of two random expressions. + + Two expressions are independent if knowledge of one does not change + computations on the other. + + Examples + ======== + + >>> from sympy.stats import Normal, independent, given + >>> from sympy import Tuple, Eq + + >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + >>> independent(X, Y) + True + >>> independent(2*X + Y, -Y) + False + >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3)) + >>> independent(X, Y) + False + + See Also + ======== + + dependent + """ + return not dependent(a, b) + + +def pspace_independent(a, b): + """ + Tests for independence between a and b by checking if their PSpaces have + overlapping symbols. This is a sufficient but not necessary condition for + independence and is intended to be used internally. + + Notes + ===== + + pspace_independent(a, b) implies independent(a, b) + independent(a, b) does not imply pspace_independent(a, b) + """ + a_symbols = set(pspace(b).symbols) + b_symbols = set(pspace(a).symbols) + + if len(set(random_symbols(a)).intersection(random_symbols(b))) != 0: + return False + + if len(a_symbols.intersection(b_symbols)) == 0: + return True + return None + + +def rv_subs(expr, symbols=None): + """ + Given a random expression replace all random variables with their symbols. + + If symbols keyword is given restrict the swap to only the symbols listed. + """ + if symbols is None: + symbols = random_symbols(expr) + if not symbols: + return expr + swapdict = {rv: rv.symbol for rv in symbols} + return expr.subs(swapdict) + + +class NamedArgsMixin: + _argnames: tuple[str, ...] = () + + def __getattr__(self, attr): + try: + return self.args[self._argnames.index(attr)] + except ValueError: + raise AttributeError("'%s' object has no attribute '%s'" % ( + type(self).__name__, attr)) + + +class Distribution(Basic): + + def sample(self, size=(), library='scipy', seed=None): + """ A random realization from the distribution """ + + module = import_module(library) + if library in {'scipy', 'numpy', 'pymc3', 'pymc'} and module is None: + raise ValueError("Failed to import %s" % library) + + if library == 'scipy': + # scipy does not require map as it can handle using custom distributions. + # However, we will still use a map where we can. + + # TODO: do this for drv.py and frv.py if necessary. + # TODO: add more distributions here if there are more + # See links below referring to sections beginning with "A common parametrization..." + # I will remove all these comments if everything is ok. + + from sympy.stats.sampling.sample_scipy import do_sample_scipy + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + samps = do_sample_scipy(self, size, rand_state) + + elif library == 'numpy': + from sympy.stats.sampling.sample_numpy import do_sample_numpy + import numpy + if seed is None or isinstance(seed, int): + rand_state = numpy.random.default_rng(seed=seed) + else: + rand_state = seed + _size = None if size == () else size + samps = do_sample_numpy(self, _size, rand_state) + elif library in ('pymc', 'pymc3'): + from sympy.stats.sampling.sample_pymc import do_sample_pymc + import logging + logging.getLogger("pymc").setLevel(logging.ERROR) + try: + import pymc + except ImportError: + import pymc3 as pymc + + with pymc.Model(): + if do_sample_pymc(self) is not None: + samps = pymc.sample(draws=prod(size), chains=1, compute_convergence_checks=False, + progressbar=False, random_seed=seed, return_inferencedata=False)[:]['X'] + samps = samps.reshape(size) + else: + samps = None + + else: + raise NotImplementedError("Sampling from %s is not supported yet." + % str(library)) + + if samps is not None: + return samps + raise NotImplementedError( + "Sampling for %s is not currently implemented from %s" + % (self, library)) + + +def _value_check(condition, message): + """ + Raise a ValueError with message if condition is False, else + return True if all conditions were True, else False. + + Examples + ======== + + >>> from sympy.stats.rv import _value_check + >>> from sympy.abc import a, b, c + >>> from sympy import And, Dummy + + >>> _value_check(2 < 3, '') + True + + Here, the condition is not False, but it does not evaluate to True + so False is returned (but no error is raised). So checking if the + return value is True or False will tell you if all conditions were + evaluated. + + >>> _value_check(a < b, '') + False + + In this case the condition is False so an error is raised: + + >>> r = Dummy(real=True) + >>> _value_check(r < r - 1, 'condition is not true') + Traceback (most recent call last): + ... + ValueError: condition is not true + + If no condition of many conditions must be False, they can be + checked by passing them as an iterable: + + >>> _value_check((a < 0, b < 0, c < 0), '') + False + + The iterable can be a generator, too: + + >>> _value_check((i < 0 for i in (a, b, c)), '') + False + + The following are equivalent to the above but do not pass + an iterable: + + >>> all(_value_check(i < 0, '') for i in (a, b, c)) + False + >>> _value_check(And(a < 0, b < 0, c < 0), '') + False + """ + if not iterable(condition): + condition = [condition] + truth = fuzzy_and(condition) + if truth == False: + raise ValueError(message) + return truth == True + +def _symbol_converter(sym): + """ + Casts the parameter to Symbol if it is 'str' + otherwise no operation is performed on it. + + Parameters + ========== + + sym + The parameter to be converted. + + Returns + ======= + + Symbol + the parameter converted to Symbol. + + Raises + ====== + + TypeError + If the parameter is not an instance of both str and + Symbol. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.stats.rv import _symbol_converter + >>> s = _symbol_converter('s') + >>> isinstance(s, Symbol) + True + >>> _symbol_converter(1) + Traceback (most recent call last): + ... + TypeError: 1 is neither a Symbol nor a string + >>> r = Symbol('r') + >>> isinstance(r, Symbol) + True + """ + if isinstance(sym, str): + sym = Symbol(sym) + if not isinstance(sym, Symbol): + raise TypeError("%s is neither a Symbol nor a string"%(sym)) + return sym + +def sample_stochastic_process(process): + """ + This function is used to sample from stochastic process. + + Parameters + ========== + + process: StochasticProcess + Process used to extract the samples. It must be an instance of + StochasticProcess + + Examples + ======== + + >>> from sympy.stats import sample_stochastic_process, DiscreteMarkovChain + >>> from sympy import Matrix + >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> next(sample_stochastic_process(Y)) in Y.state_space + True + >>> next(sample_stochastic_process(Y)) # doctest: +SKIP + 0 + >>> next(sample_stochastic_process(Y)) # doctest: +SKIP + 2 + + Returns + ======= + + sample: iterator object + iterator object containing the sample of given process + + """ + from sympy.stats.stochastic_process_types import StochasticProcess + if not isinstance(process, StochasticProcess): + raise ValueError("Process must be an instance of Stochastic Process") + return process.sample() diff --git a/phivenv/Lib/site-packages/sympy/stats/rv_interface.py b/phivenv/Lib/site-packages/sympy/stats/rv_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..16d65b83634cdb04ef7e5046175848cdf380434b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/rv_interface.py @@ -0,0 +1,519 @@ +from sympy.sets import FiniteSet +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy +from sympy.functions.combinatorial.factorials import FallingFactorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import piecewise_fold +from sympy.integrals.integrals import Integral +from sympy.solvers.solveset import solveset +from .rv import (probability, expectation, density, where, given, pspace, cdf, PSpace, + characteristic_function, sample, sample_iter, random_symbols, independent, dependent, + sampling_density, moment_generating_function, quantile, is_random, + sample_stochastic_process) + + +__all__ = ['P', 'E', 'H', 'density', 'where', 'given', 'sample', 'cdf', + 'characteristic_function', 'pspace', 'sample_iter', 'variance', 'std', + 'skewness', 'kurtosis', 'covariance', 'dependent', 'entropy', 'median', + 'independent', 'random_symbols', 'correlation', 'factorial_moment', + 'moment', 'cmoment', 'sampling_density', 'moment_generating_function', + 'smoment', 'quantile', 'sample_stochastic_process'] + + + +def moment(X, n, c=0, condition=None, *, evaluate=True, **kwargs): + """ + Return the nth moment of a random expression about c. + + .. math:: + moment(X, c, n) = E((X-c)^{n}) + + Default value of c is 0. + + Examples + ======== + + >>> from sympy.stats import Die, moment, E + >>> X = Die('X', 6) + >>> moment(X, 1, 6) + -5/2 + >>> moment(X, 2) + 91/6 + >>> moment(X, 1) == E(X) + True + """ + from sympy.stats.symbolic_probability import Moment + if evaluate: + return Moment(X, n, c, condition).doit() + return Moment(X, n, c, condition).rewrite(Integral) + + +def variance(X, condition=None, **kwargs): + """ + Variance of a random expression. + + .. math:: + variance(X) = E((X-E(X))^{2}) + + Examples + ======== + + >>> from sympy.stats import Die, Bernoulli, variance + >>> from sympy import simplify, Symbol + + >>> X = Die('X', 6) + >>> p = Symbol('p') + >>> B = Bernoulli('B', p, 1, 0) + + >>> variance(2*X) + 35/3 + + >>> simplify(variance(B)) + p*(1 - p) + """ + if is_random(X) and pspace(X) == PSpace(): + from sympy.stats.symbolic_probability import Variance + return Variance(X, condition) + + return cmoment(X, 2, condition, **kwargs) + + +def standard_deviation(X, condition=None, **kwargs): + r""" + Standard Deviation of a random expression + + .. math:: + std(X) = \sqrt(E((X-E(X))^{2})) + + Examples + ======== + + >>> from sympy.stats import Bernoulli, std + >>> from sympy import Symbol, simplify + + >>> p = Symbol('p') + >>> B = Bernoulli('B', p, 1, 0) + + >>> simplify(std(B)) + sqrt(p*(1 - p)) + """ + return sqrt(variance(X, condition, **kwargs)) +std = standard_deviation + +def entropy(expr, condition=None, **kwargs): + """ + Calculates entropy of a probability distribution. + + Parameters + ========== + + expression : the random expression whose entropy is to be calculated + condition : optional, to specify conditions on random expression + b: base of the logarithm, optional + By default, it is taken as Euler's number + + Returns + ======= + + result : Entropy of the expression, a constant + + Examples + ======== + + >>> from sympy.stats import Normal, Die, entropy + >>> X = Normal('X', 0, 1) + >>> entropy(X) + log(2)/2 + 1/2 + log(pi)/2 + + >>> D = Die('D', 4) + >>> entropy(D) + log(4) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Entropy_%28information_theory%29 + .. [2] https://www.crmarsh.com/static/pdf/Charles_Marsh_Continuous_Entropy.pdf + .. [3] https://kconrad.math.uconn.edu/blurbs/analysis/entropypost.pdf + """ + pdf = density(expr, condition, **kwargs) + base = kwargs.get('b', exp(1)) + if isinstance(pdf, dict): + return sum(-prob*log(prob, base) for prob in pdf.values()) + return expectation(-log(pdf(expr), base)) + +def covariance(X, Y, condition=None, **kwargs): + """ + Covariance of two random expressions. + + Explanation + =========== + + The expectation that the two variables will rise and fall together + + .. math:: + covariance(X,Y) = E((X-E(X)) (Y-E(Y))) + + Examples + ======== + + >>> from sympy.stats import Exponential, covariance + >>> from sympy import Symbol + + >>> rate = Symbol('lambda', positive=True, real=True) + >>> X = Exponential('X', rate) + >>> Y = Exponential('Y', rate) + + >>> covariance(X, X) + lambda**(-2) + >>> covariance(X, Y) + 0 + >>> covariance(X, Y + rate*X) + 1/lambda + """ + if (is_random(X) and pspace(X) == PSpace()) or (is_random(Y) and pspace(Y) == PSpace()): + from sympy.stats.symbolic_probability import Covariance + return Covariance(X, Y, condition) + + return expectation( + (X - expectation(X, condition, **kwargs)) * + (Y - expectation(Y, condition, **kwargs)), + condition, **kwargs) + + +def correlation(X, Y, condition=None, **kwargs): + r""" + Correlation of two random expressions, also known as correlation + coefficient or Pearson's correlation. + + Explanation + =========== + + The normalized expectation that the two variables will rise + and fall together + + .. math:: + correlation(X,Y) = E((X-E(X))(Y-E(Y)) / (\sigma_x \sigma_y)) + + Examples + ======== + + >>> from sympy.stats import Exponential, correlation + >>> from sympy import Symbol + + >>> rate = Symbol('lambda', positive=True, real=True) + >>> X = Exponential('X', rate) + >>> Y = Exponential('Y', rate) + + >>> correlation(X, X) + 1 + >>> correlation(X, Y) + 0 + >>> correlation(X, Y + rate*X) + 1/sqrt(1 + lambda**(-2)) + """ + return covariance(X, Y, condition, **kwargs)/(std(X, condition, **kwargs) + * std(Y, condition, **kwargs)) + + +def cmoment(X, n, condition=None, *, evaluate=True, **kwargs): + """ + Return the nth central moment of a random expression about its mean. + + .. math:: + cmoment(X, n) = E((X - E(X))^{n}) + + Examples + ======== + + >>> from sympy.stats import Die, cmoment, variance + >>> X = Die('X', 6) + >>> cmoment(X, 3) + 0 + >>> cmoment(X, 2) + 35/12 + >>> cmoment(X, 2) == variance(X) + True + """ + from sympy.stats.symbolic_probability import CentralMoment + if evaluate: + return CentralMoment(X, n, condition).doit() + return CentralMoment(X, n, condition).rewrite(Integral) + + +def smoment(X, n, condition=None, **kwargs): + r""" + Return the nth Standardized moment of a random expression. + + .. math:: + smoment(X, n) = E(((X - \mu)/\sigma_X)^{n}) + + Examples + ======== + + >>> from sympy.stats import skewness, Exponential, smoment + >>> from sympy import Symbol + >>> rate = Symbol('lambda', positive=True, real=True) + >>> Y = Exponential('Y', rate) + >>> smoment(Y, 4) + 9 + >>> smoment(Y, 4) == smoment(3*Y, 4) + True + >>> smoment(Y, 3) == skewness(Y) + True + """ + sigma = std(X, condition, **kwargs) + return (1/sigma)**n*cmoment(X, n, condition, **kwargs) + +def skewness(X, condition=None, **kwargs): + r""" + Measure of the asymmetry of the probability distribution. + + Explanation + =========== + + Positive skew indicates that most of the values lie to the right of + the mean. + + .. math:: + skewness(X) = E(((X - E(X))/\sigma_X)^{3}) + + Parameters + ========== + + condition : Expr containing RandomSymbols + A conditional expression. skewness(X, X>0) is skewness of X given X > 0 + + Examples + ======== + + >>> from sympy.stats import skewness, Exponential, Normal + >>> from sympy import Symbol + >>> X = Normal('X', 0, 1) + >>> skewness(X) + 0 + >>> skewness(X, X > 0) # find skewness given X > 0 + (-sqrt(2)/sqrt(pi) + 4*sqrt(2)/pi**(3/2))/(1 - 2/pi)**(3/2) + + >>> rate = Symbol('lambda', positive=True, real=True) + >>> Y = Exponential('Y', rate) + >>> skewness(Y) + 2 + """ + return smoment(X, 3, condition=condition, **kwargs) + +def kurtosis(X, condition=None, **kwargs): + r""" + Characterizes the tails/outliers of a probability distribution. + + Explanation + =========== + + Kurtosis of any univariate normal distribution is 3. Kurtosis less than + 3 means that the distribution produces fewer and less extreme outliers + than the normal distribution. + + .. math:: + kurtosis(X) = E(((X - E(X))/\sigma_X)^{4}) + + Parameters + ========== + + condition : Expr containing RandomSymbols + A conditional expression. kurtosis(X, X>0) is kurtosis of X given X > 0 + + Examples + ======== + + >>> from sympy.stats import kurtosis, Exponential, Normal + >>> from sympy import Symbol + >>> X = Normal('X', 0, 1) + >>> kurtosis(X) + 3 + >>> kurtosis(X, X > 0) # find kurtosis given X > 0 + (-4/pi - 12/pi**2 + 3)/(1 - 2/pi)**2 + + >>> rate = Symbol('lamda', positive=True, real=True) + >>> Y = Exponential('Y', rate) + >>> kurtosis(Y) + 9 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kurtosis + .. [2] https://mathworld.wolfram.com/Kurtosis.html + """ + return smoment(X, 4, condition=condition, **kwargs) + + +def factorial_moment(X, n, condition=None, **kwargs): + """ + The factorial moment is a mathematical quantity defined as the expectation + or average of the falling factorial of a random variable. + + .. math:: + factorial-moment(X, n) = E(X(X - 1)(X - 2)...(X - n + 1)) + + Parameters + ========== + + n: A natural number, n-th factorial moment. + + condition : Expr containing RandomSymbols + A conditional expression. + + Examples + ======== + + >>> from sympy.stats import factorial_moment, Poisson, Binomial + >>> from sympy import Symbol, S + >>> lamda = Symbol('lamda') + >>> X = Poisson('X', lamda) + >>> factorial_moment(X, 2) + lamda**2 + >>> Y = Binomial('Y', 2, S.Half) + >>> factorial_moment(Y, 2) + 1/2 + >>> factorial_moment(Y, 2, Y > 1) # find factorial moment for Y > 1 + 2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Factorial_moment + .. [2] https://mathworld.wolfram.com/FactorialMoment.html + """ + return expectation(FallingFactorial(X, n), condition=condition, **kwargs) + +def median(X, evaluate=True, **kwargs): + r""" + Calculates the median of the probability distribution. + + Explanation + =========== + + Mathematically, median of Probability distribution is defined as all those + values of `m` for which the following condition is satisfied + + .. math:: + P(X\leq m) \geq \frac{1}{2} \text{ and} \text{ } P(X\geq m)\geq \frac{1}{2} + + Parameters + ========== + + X: The random expression whose median is to be calculated. + + Returns + ======= + + The FiniteSet or an Interval which contains the median of the + random expression. + + Examples + ======== + + >>> from sympy.stats import Normal, Die, median + >>> N = Normal('N', 3, 1) + >>> median(N) + {3} + >>> D = Die('D') + >>> median(D) + {3, 4} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Median#Probability_distributions + + """ + if not is_random(X): + return X + + from sympy.stats.crv import ContinuousPSpace + from sympy.stats.drv import DiscretePSpace + from sympy.stats.frv import FinitePSpace + + if isinstance(pspace(X), FinitePSpace): + cdf = pspace(X).compute_cdf(X) + result = [] + for key, value in cdf.items(): + if value>= Rational(1, 2) and (1 - value) + \ + pspace(X).probability(Eq(X, key)) >= Rational(1, 2): + result.append(key) + return FiniteSet(*result) + if isinstance(pspace(X), (ContinuousPSpace, DiscretePSpace)): + cdf = pspace(X).compute_cdf(X) + x = Dummy('x') + result = solveset(piecewise_fold(cdf(x) - Rational(1, 2)), x, pspace(X).set) + return result + raise NotImplementedError("The median of %s is not implemented."%str(pspace(X))) + + +def coskewness(X, Y, Z, condition=None, **kwargs): + r""" + Calculates the co-skewness of three random variables. + + Explanation + =========== + + Mathematically Coskewness is defined as + + .. math:: + coskewness(X,Y,Z)=\frac{E[(X-E[X]) * (Y-E[Y]) * (Z-E[Z])]} {\sigma_{X}\sigma_{Y}\sigma_{Z}} + + Parameters + ========== + + X : RandomSymbol + Random Variable used to calculate coskewness + Y : RandomSymbol + Random Variable used to calculate coskewness + Z : RandomSymbol + Random Variable used to calculate coskewness + condition : Expr containing RandomSymbols + A conditional expression + + Examples + ======== + + >>> from sympy.stats import coskewness, Exponential, skewness + >>> from sympy import symbols + >>> p = symbols('p', positive=True) + >>> X = Exponential('X', p) + >>> Y = Exponential('Y', 2*p) + >>> coskewness(X, Y, Y) + 0 + >>> coskewness(X, Y + X, Y + 2*X) + 16*sqrt(85)/85 + >>> coskewness(X + 2*Y, Y + X, Y + 2*X, X > 3) + 9*sqrt(170)/85 + >>> coskewness(Y, Y, Y) == skewness(Y) + True + >>> coskewness(X, Y + p*X, Y + 2*p*X) + 4/(sqrt(1 + 1/(4*p**2))*sqrt(4 + 1/(4*p**2))) + + Returns + ======= + + coskewness : The coskewness of the three random variables + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Coskewness + + """ + num = expectation((X - expectation(X, condition, **kwargs)) \ + * (Y - expectation(Y, condition, **kwargs)) \ + * (Z - expectation(Z, condition, **kwargs)), condition, **kwargs) + den = std(X, condition, **kwargs) * std(Y, condition, **kwargs) \ + * std(Z, condition, **kwargs) + return num/den + + +P = probability +E = expectation +H = entropy diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/__init__.py b/phivenv/Lib/site-packages/sympy/stats/sampling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb433ebc9697d35a48096bd2ee969ec998580847 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_numpy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_numpy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b19fd7553cfae2281392d2435062c4e79a41a054 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_numpy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_pymc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_pymc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..926d648c4dcdd9b323fa32496fca85a5bd35cb4c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_pymc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_scipy.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_scipy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65f14915ffdea644d7aed00c388e42b644aa9cf9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/__pycache__/sample_scipy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/sample_numpy.py b/phivenv/Lib/site-packages/sympy/stats/sampling/sample_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..d65417945449ed8b62d2547215d8908f84820a9b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/sampling/sample_numpy.py @@ -0,0 +1,105 @@ +from functools import singledispatch + +from sympy.external import import_module +from sympy.stats.crv_types import BetaDistribution, ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \ + LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, FDistributionDistribution, GumbelDistribution, LaplaceDistribution, \ + LogisticDistribution, RayleighDistribution, TriangularDistribution +from sympy.stats.drv_types import GeometricDistribution, PoissonDistribution, ZetaDistribution +from sympy.stats.frv_types import BinomialDistribution, HypergeometricDistribution + + +numpy = import_module('numpy') + + +@singledispatch +def do_sample_numpy(dist, size, rand_state): + return None + + +# CRV: + +@do_sample_numpy.register(BetaDistribution) +def _(dist: BetaDistribution, size, rand_state): + return rand_state.beta(a=float(dist.alpha), b=float(dist.beta), size=size) + + +@do_sample_numpy.register(ChiSquaredDistribution) +def _(dist: ChiSquaredDistribution, size, rand_state): + return rand_state.chisquare(df=float(dist.k), size=size) + + +@do_sample_numpy.register(ExponentialDistribution) +def _(dist: ExponentialDistribution, size, rand_state): + return rand_state.exponential(1 / float(dist.rate), size=size) + +@do_sample_numpy.register(FDistributionDistribution) +def _(dist: FDistributionDistribution, size, rand_state): + return rand_state.f(dfnum = float(dist.d1), dfden = float(dist.d2), size=size) + +@do_sample_numpy.register(GammaDistribution) +def _(dist: GammaDistribution, size, rand_state): + return rand_state.gamma(shape = float(dist.k), scale = float(dist.theta), size=size) + +@do_sample_numpy.register(GumbelDistribution) +def _(dist: GumbelDistribution, size, rand_state): + return rand_state.gumbel(loc = float(dist.mu), scale = float(dist.beta), size=size) + +@do_sample_numpy.register(LaplaceDistribution) +def _(dist: LaplaceDistribution, size, rand_state): + return rand_state.laplace(loc = float(dist.mu), scale = float(dist.b), size=size) + +@do_sample_numpy.register(LogisticDistribution) +def _(dist: LogisticDistribution, size, rand_state): + return rand_state.logistic(loc = float(dist.mu), scale = float(dist.s), size=size) + +@do_sample_numpy.register(LogNormalDistribution) +def _(dist: LogNormalDistribution, size, rand_state): + return rand_state.lognormal(mean = float(dist.mean), sigma = float(dist.std), size=size) + +@do_sample_numpy.register(NormalDistribution) +def _(dist: NormalDistribution, size, rand_state): + return rand_state.normal(loc = float(dist.mean), scale = float(dist.std), size=size) + +@do_sample_numpy.register(RayleighDistribution) +def _(dist: RayleighDistribution, size, rand_state): + return rand_state.rayleigh(scale = float(dist.sigma), size=size) + +@do_sample_numpy.register(ParetoDistribution) +def _(dist: ParetoDistribution, size, rand_state): + return (numpy.random.pareto(a=float(dist.alpha), size=size) + 1) * float(dist.xm) + +@do_sample_numpy.register(TriangularDistribution) +def _(dist: TriangularDistribution, size, rand_state): + return rand_state.triangular(left = float(dist.a), mode = float(dist.b), right = float(dist.c), size=size) + +@do_sample_numpy.register(UniformDistribution) +def _(dist: UniformDistribution, size, rand_state): + return rand_state.uniform(low=float(dist.left), high=float(dist.right), size=size) + + +# DRV: + +@do_sample_numpy.register(GeometricDistribution) +def _(dist: GeometricDistribution, size, rand_state): + return rand_state.geometric(p=float(dist.p), size=size) + + +@do_sample_numpy.register(PoissonDistribution) +def _(dist: PoissonDistribution, size, rand_state): + return rand_state.poisson(lam=float(dist.lamda), size=size) + + +@do_sample_numpy.register(ZetaDistribution) +def _(dist: ZetaDistribution, size, rand_state): + return rand_state.zipf(a=float(dist.s), size=size) + + +# FRV: + +@do_sample_numpy.register(BinomialDistribution) +def _(dist: BinomialDistribution, size, rand_state): + return rand_state.binomial(n=int(dist.n), p=float(dist.p), size=size) + +@do_sample_numpy.register(HypergeometricDistribution) +def _(dist: HypergeometricDistribution, size, rand_state): + return rand_state.hypergeometric(ngood = int(dist.N), nbad = int(dist.m), nsample = int(dist.n), size=size) diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/sample_pymc.py b/phivenv/Lib/site-packages/sympy/stats/sampling/sample_pymc.py new file mode 100644 index 0000000000000000000000000000000000000000..546f02a3092815af2e54b4a164463f62ece7a024 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/sampling/sample_pymc.py @@ -0,0 +1,99 @@ +from functools import singledispatch +from sympy.external import import_module +from sympy.stats.crv_types import BetaDistribution, CauchyDistribution, ChiSquaredDistribution, ExponentialDistribution, \ + GammaDistribution, LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, \ + GaussianInverseDistribution +from sympy.stats.drv_types import PoissonDistribution, GeometricDistribution, NegativeBinomialDistribution +from sympy.stats.frv_types import BinomialDistribution, BernoulliDistribution + + +try: + import pymc +except ImportError: + pymc = import_module('pymc3') + +@singledispatch +def do_sample_pymc(dist): + return None + + +# CRV: + +@do_sample_pymc.register(BetaDistribution) +def _(dist: BetaDistribution): + return pymc.Beta('X', alpha=float(dist.alpha), beta=float(dist.beta)) + + +@do_sample_pymc.register(CauchyDistribution) +def _(dist: CauchyDistribution): + return pymc.Cauchy('X', alpha=float(dist.x0), beta=float(dist.gamma)) + + +@do_sample_pymc.register(ChiSquaredDistribution) +def _(dist: ChiSquaredDistribution): + return pymc.ChiSquared('X', nu=float(dist.k)) + + +@do_sample_pymc.register(ExponentialDistribution) +def _(dist: ExponentialDistribution): + return pymc.Exponential('X', lam=float(dist.rate)) + + +@do_sample_pymc.register(GammaDistribution) +def _(dist: GammaDistribution): + return pymc.Gamma('X', alpha=float(dist.k), beta=1 / float(dist.theta)) + + +@do_sample_pymc.register(LogNormalDistribution) +def _(dist: LogNormalDistribution): + return pymc.Lognormal('X', mu=float(dist.mean), sigma=float(dist.std)) + + +@do_sample_pymc.register(NormalDistribution) +def _(dist: NormalDistribution): + return pymc.Normal('X', float(dist.mean), float(dist.std)) + + +@do_sample_pymc.register(GaussianInverseDistribution) +def _(dist: GaussianInverseDistribution): + return pymc.Wald('X', mu=float(dist.mean), lam=float(dist.shape)) + + +@do_sample_pymc.register(ParetoDistribution) +def _(dist: ParetoDistribution): + return pymc.Pareto('X', alpha=float(dist.alpha), m=float(dist.xm)) + + +@do_sample_pymc.register(UniformDistribution) +def _(dist: UniformDistribution): + return pymc.Uniform('X', lower=float(dist.left), upper=float(dist.right)) + + +# DRV: + +@do_sample_pymc.register(GeometricDistribution) +def _(dist: GeometricDistribution): + return pymc.Geometric('X', p=float(dist.p)) + + +@do_sample_pymc.register(NegativeBinomialDistribution) +def _(dist: NegativeBinomialDistribution): + return pymc.NegativeBinomial('X', mu=float((dist.p * dist.r) / (1 - dist.p)), + alpha=float(dist.r)) + + +@do_sample_pymc.register(PoissonDistribution) +def _(dist: PoissonDistribution): + return pymc.Poisson('X', mu=float(dist.lamda)) + + +# FRV: + +@do_sample_pymc.register(BernoulliDistribution) +def _(dist: BernoulliDistribution): + return pymc.Bernoulli('X', p=float(dist.p)) + + +@do_sample_pymc.register(BinomialDistribution) +def _(dist: BinomialDistribution): + return pymc.Binomial('X', n=int(dist.n), p=float(dist.p)) diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/sample_scipy.py b/phivenv/Lib/site-packages/sympy/stats/sampling/sample_scipy.py new file mode 100644 index 0000000000000000000000000000000000000000..f12508f68844488e9a14b1476005164eb422796e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/sampling/sample_scipy.py @@ -0,0 +1,167 @@ +from functools import singledispatch + +from sympy.core.symbol import Dummy +from sympy.functions.elementary.exponential import exp +from sympy.utilities.lambdify import lambdify +from sympy.external import import_module +from sympy.stats import DiscreteDistributionHandmade +from sympy.stats.crv import SingleContinuousDistribution +from sympy.stats.crv_types import ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \ + LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, BetaDistribution, \ + StudentTDistribution, CauchyDistribution +from sympy.stats.drv_types import GeometricDistribution, LogarithmicDistribution, NegativeBinomialDistribution, \ + PoissonDistribution, SkellamDistribution, YuleSimonDistribution, ZetaDistribution +from sympy.stats.frv import SingleFiniteDistribution + + +scipy = import_module("scipy", import_kwargs={'fromlist':['stats']}) + + +@singledispatch +def do_sample_scipy(dist, size, seed): + return None + + +# CRV + +@do_sample_scipy.register(SingleContinuousDistribution) +def _(dist: SingleContinuousDistribution, size, seed): + # if we don't need to make a handmade pdf, we won't + import scipy.stats + + z = Dummy('z') + handmade_pdf = lambdify(z, dist.pdf(z), ['numpy', 'scipy']) + + class scipy_pdf(scipy.stats.rv_continuous): + def _pdf(dist, x): + return handmade_pdf(x) + + scipy_rv = scipy_pdf(a=float(dist.set._inf), + b=float(dist.set._sup), name='scipy_pdf') + return scipy_rv.rvs(size=size, random_state=seed) + + +@do_sample_scipy.register(ChiSquaredDistribution) +def _(dist: ChiSquaredDistribution, size, seed): + # same parametrisation + return scipy.stats.chi2.rvs(df=float(dist.k), size=size, random_state=seed) + + +@do_sample_scipy.register(ExponentialDistribution) +def _(dist: ExponentialDistribution, size, seed): + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.expon.html#scipy.stats.expon + return scipy.stats.expon.rvs(scale=1 / float(dist.rate), size=size, random_state=seed) + + +@do_sample_scipy.register(GammaDistribution) +def _(dist: GammaDistribution, size, seed): + # https://stackoverflow.com/questions/42150965/how-to-plot-gamma-distribution-with-alpha-and-beta-parameters-in-python + return scipy.stats.gamma.rvs(a=float(dist.k), scale=float(dist.theta), size=size, random_state=seed) + + +@do_sample_scipy.register(LogNormalDistribution) +def _(dist: LogNormalDistribution, size, seed): + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.lognorm.html + return scipy.stats.lognorm.rvs(scale=float(exp(dist.mean)), s=float(dist.std), size=size, random_state=seed) + + +@do_sample_scipy.register(NormalDistribution) +def _(dist: NormalDistribution, size, seed): + return scipy.stats.norm.rvs(loc=float(dist.mean), scale=float(dist.std), size=size, random_state=seed) + + +@do_sample_scipy.register(ParetoDistribution) +def _(dist: ParetoDistribution, size, seed): + # https://stackoverflow.com/questions/42260519/defining-pareto-distribution-in-python-scipy + return scipy.stats.pareto.rvs(b=float(dist.alpha), scale=float(dist.xm), size=size, random_state=seed) + + +@do_sample_scipy.register(StudentTDistribution) +def _(dist: StudentTDistribution, size, seed): + return scipy.stats.t.rvs(df=float(dist.nu), size=size, random_state=seed) + + +@do_sample_scipy.register(UniformDistribution) +def _(dist: UniformDistribution, size, seed): + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.uniform.html + return scipy.stats.uniform.rvs(loc=float(dist.left), scale=float(dist.right - dist.left), size=size, random_state=seed) + + +@do_sample_scipy.register(BetaDistribution) +def _(dist: BetaDistribution, size, seed): + # same parametrisation + return scipy.stats.beta.rvs(a=float(dist.alpha), b=float(dist.beta), size=size, random_state=seed) + + +@do_sample_scipy.register(CauchyDistribution) +def _(dist: CauchyDistribution, size, seed): + return scipy.stats.cauchy.rvs(loc=float(dist.x0), scale=float(dist.gamma), size=size, random_state=seed) + + +# DRV: + +@do_sample_scipy.register(DiscreteDistributionHandmade) +def _(dist: DiscreteDistributionHandmade, size, seed): + from scipy.stats import rv_discrete + + z = Dummy('z') + handmade_pmf = lambdify(z, dist.pdf(z), ['numpy', 'scipy']) + + class scipy_pmf(rv_discrete): + def _pmf(dist, x): + return handmade_pmf(x) + + scipy_rv = scipy_pmf(a=float(dist.set._inf), b=float(dist.set._sup), + name='scipy_pmf') + return scipy_rv.rvs(size=size, random_state=seed) + + +@do_sample_scipy.register(GeometricDistribution) +def _(dist: GeometricDistribution, size, seed): + return scipy.stats.geom.rvs(p=float(dist.p), size=size, random_state=seed) + + +@do_sample_scipy.register(LogarithmicDistribution) +def _(dist: LogarithmicDistribution, size, seed): + return scipy.stats.logser.rvs(p=float(dist.p), size=size, random_state=seed) + + +@do_sample_scipy.register(NegativeBinomialDistribution) +def _(dist: NegativeBinomialDistribution, size, seed): + return scipy.stats.nbinom.rvs(n=float(dist.r), p=float(dist.p), size=size, random_state=seed) + + +@do_sample_scipy.register(PoissonDistribution) +def _(dist: PoissonDistribution, size, seed): + return scipy.stats.poisson.rvs(mu=float(dist.lamda), size=size, random_state=seed) + + +@do_sample_scipy.register(SkellamDistribution) +def _(dist: SkellamDistribution, size, seed): + return scipy.stats.skellam.rvs(mu1=float(dist.mu1), mu2=float(dist.mu2), size=size, random_state=seed) + + +@do_sample_scipy.register(YuleSimonDistribution) +def _(dist: YuleSimonDistribution, size, seed): + return scipy.stats.yulesimon.rvs(alpha=float(dist.rho), size=size, random_state=seed) + + +@do_sample_scipy.register(ZetaDistribution) +def _(dist: ZetaDistribution, size, seed): + return scipy.stats.zipf.rvs(a=float(dist.s), size=size, random_state=seed) + + +# FRV: + +@do_sample_scipy.register(SingleFiniteDistribution) +def _(dist: SingleFiniteDistribution, size, seed): + # scipy can handle with custom distributions + + from scipy.stats import rv_discrete + density_ = dist.dict + x, y = [], [] + for k, v in density_.items(): + x.append(int(k)) + y.append(float(v)) + scipy_rv = rv_discrete(name='scipy_rv', values=(x, y)) + return scipy_rv.rvs(size=size, random_state=seed) diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__init__.py b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e223fc82cf6ff0160bbc544138107b16c66646ee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed019f181e0774e887ad9a40574f468fcbb57726 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd506f80bdbef1aa16a0b16ebb7d440487fff1bc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1532c85a9261fed523c1b5b1eb1fdbf7b1a25a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..953bb602df5e63da2882ee118de9dbf24b6f7804 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py @@ -0,0 +1,181 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.sets.sets import Interval +from sympy.external import import_module +from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \ + BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV, FDistribution, \ + Gumbel, Laplace, Logistic, Rayleigh, Triangular +from sympy.testing.pytest import skip, raises + + +def test_sample_numpy(): + distribs_numpy = [ + Beta("B", 1, 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + ChiSquared("CS", 2), + Uniform("U", 0, 1), + FDistribution("FD", 1, 2), + Gumbel("GB", 1, 2), + Laplace("L", 1, 2), + Logistic("LO", 1, 2), + Rayleigh("R", 1), + Triangular("T", 1, 2, 2), + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Chi("C", 1), library='numpy')) + raises(NotImplementedError, + lambda: Chi("C", 1).pspace.distribution.sample(library='tensorflow')) + + +def test_sample_scipy(): + distribs_scipy = [ + Beta("B", 1, 1), + BetaPrime("BP", 1, 1), + Cauchy("C", 1, 1), + Chi("C", 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + GammaInverse("GI", 1, 1), + GaussianInverse("GUI", 1, 1), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + StudentT("S", 2), + ChiSquared("CS", 2), + Uniform("U", 0, 1) + ] + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size, library='scipy') + samps2 = sample(X, size=(2, 2), library='scipy') + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Beta("B", 1, 1), + Cauchy("C", 1, 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + GaussianInverse("GI", 1, 1), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + ChiSquared("CS", 2), + Uniform("U", 0, 1) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Chi("C", 1), library='pymc')) + + +def test_sampling_gamma_inverse(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for sampling of gamma inverse.') + X = GammaInverse("x", 1, 1) + assert sample(X) in X.pspace.domain.set + + +def test_lognormal_sampling(): + # Right now, only density function and sampling works + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + for i in range(3): + X = LogNormal('x', i, 1) + assert sample(X) in X.pspace.domain.set + + size = 5 + samps = sample(X, size=size) + for samp in samps: + assert samp in X.pspace.domain.set + + +def test_sampling_gaussian_inverse(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.') + X = GaussianInverse("x", 1, 1) + assert sample(X, library='scipy') in X.pspace.domain.set + + +def test_prefab_sampling(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + N = Normal('X', 0, 1) + L = LogNormal('L', 0, 1) + E = Exponential('Ex', 1) + P = Pareto('P', 1, 3) + W = Weibull('W', 1, 1) + U = Uniform('U', 0, 1) + B = Beta('B', 2, 5) + G = Gamma('G', 1, 3) + + variables = [N, L, E, P, W, U, B, G] + niter = 10 + size = 5 + for var in variables: + for _ in range(niter): + assert sample(var) in var.pspace.domain.set + samps = sample(var, size=size) + for samp in samps: + assert samp in var.pspace.domain.set + + +def test_sample_continuous(): + z = Symbol('z') + Z = ContinuousRV(z, exp(-z), set=Interval(0, oo)) + assert density(Z)(-1) == 0 + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(Z) in Z.pspace.domain.set + sym, val = list(Z.pspace.sample().items())[0] + assert sym == Z and val in Interval(0, oo) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(Z, size=10, library=lib, seed=0) + s1 = sample(Z, size=10, library=lib, seed=0) + s2 = sample(Z, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert all(s1 != s2) + except NotImplementedError: + continue diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..90d385cd599222fd7da7c1559b619bafbeb01831 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py @@ -0,0 +1,109 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.external import import_module +from sympy.stats import ( + Geometric, + Poisson, + Zeta, + sample, + Skellam, + Logarithmic, + NegativeBinomial, + YuleSimon, + DiscreteRV, +) +from sympy.testing.pytest import skip, raises, slow + + +def test_sample_numpy(): + distribs_numpy = [ + Geometric('G', 0.5), + Poisson('P', 1), + Zeta('Z', 2) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Skellam('S', 1, 1), library='numpy')) + raises(NotImplementedError, + lambda: Skellam('S', 1, 1).pspace.distribution.sample(library='tensorflow')) + + +def test_sample_scipy(): + p = S(2)/3 + x = Symbol('x', integer=True, positive=True) + pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution + distribs_scipy = [ + DiscreteRV(x, pdf, set=S.Naturals), + Geometric('G', 0.5), + Logarithmic('L', 0.5), + NegativeBinomial('N', 5, 0.4), + Poisson('P', 1), + Skellam('S', 1, 1), + YuleSimon('Y', 1), + Zeta('Z', 2) + ] + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size, library='scipy') + samps2 = sample(X, size=(2, 2), library='scipy') + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Geometric('G', 0.5), + Poisson('P', 1), + NegativeBinomial('N', 5, 0.4) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Skellam('S', 1, 1), library='pymc')) + +@slow +def test_sample_discrete(): + X = Geometric('X', S.Half) + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests') + assert sample(X) in X.pspace.domain.set + samps = sample(X, size=2) # This takes long time if ran without scipy + for samp in samps: + assert samp in X.pspace.domain.set + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert not all(s1 == s2) + except NotImplementedError: + continue diff --git a/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..96cabe0ff4aaa5977e16600217fbbdeb08b962ae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py @@ -0,0 +1,94 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.external import import_module +from sympy.stats import Binomial, sample, Die, FiniteRV, DiscreteUniform, Bernoulli, BetaBinomial, Hypergeometric, \ + Rademacher +from sympy.testing.pytest import skip, raises + +def test_given_sample(): + X = Die('X', 6) + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(X, X > 5) == 6 + +def test_sample_numpy(): + distribs_numpy = [ + Binomial("B", 5, 0.4), + Hypergeometric("H", 2, 1, 1) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Die("D"), library='numpy')) + raises(NotImplementedError, + lambda: Die("D").pspace.sample(library='tensorflow')) + + +def test_sample_scipy(): + distribs_scipy = [ + FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}), + DiscreteUniform("Y", list(range(5))), + Die("D"), + Bernoulli("Be", 0.3), + Binomial("Bi", 5, 0.4), + BetaBinomial("Bb", 2, 1, 1), + Hypergeometric("H", 1, 1, 1), + Rademacher("R") + ] + + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + samps2 = sample(X, size=(2, 2)) + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Bernoulli('B', 0.2), + Binomial('N', 5, 0.4) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: (sample(Die("D"), library='pymc'))) + + +def test_sample_seed(): + F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}) + size = 10 + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0 = sample(F, size=size, library=lib, seed=0) + s1 = sample(F, size=size, library=lib, seed=0) + s2 = sample(F, size=size, library=lib, seed=1) + assert all(s0 == s1) + assert not all(s1 == s2) + except NotImplementedError: + continue diff --git a/phivenv/Lib/site-packages/sympy/stats/stochastic_process.py b/phivenv/Lib/site-packages/sympy/stats/stochastic_process.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb0e759c66be892ae38ddda004dfe928f683fee --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/stochastic_process.py @@ -0,0 +1,66 @@ +from sympy.core.basic import Basic +from sympy.stats.joint_rv import ProductPSpace +from sympy.stats.rv import ProductDomain, _symbol_converter, Distribution + + +class StochasticPSpace(ProductPSpace): + """ + Represents probability space of stochastic processes + and their random variables. Contains mechanics to do + computations for queries of stochastic processes. + + Explanation + =========== + + Initialized by symbol, the specific process and + distribution(optional) if the random indexed symbols + of the process follows any specific distribution, like, + in Bernoulli Process, each random indexed symbol follows + Bernoulli distribution. For processes with memory, this + parameter should not be passed. + """ + + def __new__(cls, sym, process, distribution=None): + sym = _symbol_converter(sym) + from sympy.stats.stochastic_process_types import StochasticProcess + if not isinstance(process, StochasticProcess): + raise TypeError("`process` must be an instance of StochasticProcess.") + if distribution is None: + distribution = Distribution() + return Basic.__new__(cls, sym, process, distribution) + + @property + def process(self): + """ + The associated stochastic process. + """ + return self.args[1] + + @property + def domain(self): + return ProductDomain(self.process.index_set, + self.process.state_space) + + @property + def symbol(self): + return self.args[0] + + @property + def distribution(self): + return self.args[2] + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Transfers the task of handling queries to the specific stochastic + process because every process has their own logic of handling such + queries. + """ + return self.process.probability(condition, given_condition, evaluate, **kwargs) + + def compute_expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Transfers the task of handling queries to the specific stochastic + process because every process has their own logic of handling such + queries. + """ + return self.process.expectation(expr, condition, evaluate, **kwargs) diff --git a/phivenv/Lib/site-packages/sympy/stats/stochastic_process_types.py b/phivenv/Lib/site-packages/sympy/stats/stochastic_process_types.py new file mode 100644 index 0000000000000000000000000000000000000000..7387cd3dbcf6defb3b7e475f542d04ecef6fecf6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/stochastic_process_types.py @@ -0,0 +1,2383 @@ +from __future__ import annotations +import random +import itertools +from typing import Sequence as tSequence +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.mul import Mul +from sympy.core.intfunc import igcd +from sympy.core.numbers import (Integer, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import gamma +from sympy.logic.boolalg import (And, Not, Or) +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.dense import (Matrix, eye, ones, zeros) +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.matrices.immutable import ImmutableMatrix +from sympy.sets.conditionset import ConditionSet +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Set, Union) +from sympy.solvers.solveset import linsolve +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.core.relational import Relational +from sympy.logic.boolalg import Boolean +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import strongly_connected_components +from sympy.stats.joint_rv import JointDistribution +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.rv import (RandomIndexedSymbol, random_symbols, RandomSymbol, + _symbol_converter, _value_check, pspace, given, + dependent, is_random, sample_iter, Distribution, + Density) +from sympy.stats.stochastic_process import StochasticPSpace +from sympy.stats.symbolic_probability import Probability, Expectation +from sympy.stats.frv_types import Bernoulli, BernoulliDistribution, FiniteRV +from sympy.stats.drv_types import Poisson, PoissonDistribution +from sympy.stats.crv_types import Normal, NormalDistribution, Gamma, GammaDistribution +from sympy.core.sympify import _sympify, sympify + +EmptySet = S.EmptySet + +__all__ = [ + 'StochasticProcess', + 'DiscreteTimeStochasticProcess', + 'DiscreteMarkovChain', + 'TransitionMatrixOf', + 'StochasticStateSpaceOf', + 'GeneratorMatrixOf', + 'ContinuousMarkovChain', + 'BernoulliProcess', + 'PoissonProcess', + 'WienerProcess', + 'GammaProcess' +] + + +@is_random.register(Indexed) +def _(x): + return is_random(x.base) + +@is_random.register(RandomIndexedSymbol) # type: ignore +def _(x): + return True + +def _set_converter(itr): + """ + Helper function for converting list/tuple/set to Set. + If parameter is not an instance of list/tuple/set then + no operation is performed. + + Returns + ======= + + Set + The argument converted to Set. + + + Raises + ====== + + TypeError + If the argument is not an instance of list/tuple/set. + """ + if isinstance(itr, (list, tuple, set)): + itr = FiniteSet(*itr) + if not isinstance(itr, Set): + raise TypeError("%s is not an instance of list/tuple/set."%(itr)) + return itr + +def _state_converter(itr: tSequence) -> Tuple | Range: + """ + Helper function for converting list/tuple/set/Range/Tuple/FiniteSet + to tuple/Range. + """ + itr_ret: Tuple | Range + + if isinstance(itr, (Tuple, set, FiniteSet)): + itr_ret = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) + + elif isinstance(itr, (list, tuple)): + # check if states are unique + if len(set(itr)) != len(itr): + raise ValueError('The state space must have unique elements.') + itr_ret = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) + + elif isinstance(itr, Range): + # the only ordered set in SymPy I know of + # try to convert to tuple + try: + itr_ret = Tuple(*(sympify(i) if isinstance(i, str) else i for i in itr)) + except (TypeError, ValueError): + itr_ret = itr + + else: + raise TypeError("%s is not an instance of list/tuple/set/Range/Tuple/FiniteSet." % (itr)) + return itr_ret + +def _sym_sympify(arg): + """ + Converts an arbitrary expression to a type that can be used inside SymPy. + As generally strings are unwise to use in the expressions, + it returns the Symbol of argument if the string type argument is passed. + + Parameters + ========= + + arg: The parameter to be converted to be used in SymPy. + + Returns + ======= + + The converted parameter. + + """ + if isinstance(arg, str): + return Symbol(arg) + else: + return _sympify(arg) + +def _matrix_checks(matrix): + if not isinstance(matrix, (Matrix, MatrixSymbol, ImmutableMatrix)): + raise TypeError("Transition probabilities either should " + "be a Matrix or a MatrixSymbol.") + if matrix.shape[0] != matrix.shape[1]: + raise NonSquareMatrixError("%s is not a square matrix"%(matrix)) + if isinstance(matrix, Matrix): + matrix = ImmutableMatrix(matrix.tolist()) + return matrix + +class StochasticProcess(Basic): + """ + Base class for all the stochastic processes whether + discrete or continuous. + + Parameters + ========== + + sym: Symbol or str + state_space: Set + The state space of the stochastic process, by default S.Reals. + For discrete sets it is zero indexed. + + See Also + ======== + + DiscreteTimeStochasticProcess + """ + + index_set = S.Reals + + def __new__(cls, sym, state_space=S.Reals, **kwargs): + sym = _symbol_converter(sym) + state_space = _set_converter(state_space) + return Basic.__new__(cls, sym, state_space) + + @property + def symbol(self): + return self.args[0] + + @property + def state_space(self) -> FiniteSet | Range: + if not isinstance(self.args[1], (FiniteSet, Range)): + assert isinstance(self.args[1], Tuple) + return FiniteSet(*self.args[1]) + return self.args[1] + + def _deprecation_warn_distribution(self): + sympy_deprecation_warning( + """ + Calling the distribution method with a RandomIndexedSymbol + argument, like X.distribution(X(t)) is deprecated. Instead, call + distribution() with the given timestamp, like + + X.distribution(t) + """, + deprecated_since_version="1.7.1", + active_deprecations_target="deprecated-distribution-randomindexedsymbol", + stacklevel=4, + ) + + def distribution(self, key=None): + if key is None: + self._deprecation_warn_distribution() + return Distribution() + + def density(self, x): + return Density() + + def __call__(self, time): + """ + Overridden in ContinuousTimeStochasticProcess. + """ + raise NotImplementedError("Use [] for indexing discrete time stochastic process.") + + def __getitem__(self, time): + """ + Overridden in DiscreteTimeStochasticProcess. + """ + raise NotImplementedError("Use () for indexing continuous time stochastic process.") + + def probability(self, condition): + raise NotImplementedError() + + def joint_distribution(self, *args): + """ + Computes the joint distribution of the random indexed variables. + + Parameters + ========== + + args: iterable + The finite list of random indexed variables/the key of a stochastic + process whose joint distribution has to be computed. + + Returns + ======= + + JointDistribution + The joint distribution of the list of random indexed variables. + An unevaluated object is returned if it is not possible to + compute the joint distribution. + + Raises + ====== + + ValueError: When the arguments passed are not of type RandomIndexSymbol + or Number. + """ + args = list(args) + for i, arg in enumerate(args): + if S(arg).is_Number: + if self.index_set.is_subset(S.Integers): + args[i] = self.__getitem__(arg) + else: + args[i] = self.__call__(arg) + elif not isinstance(arg, RandomIndexedSymbol): + raise ValueError("Expected a RandomIndexedSymbol or " + "key not %s"%(type(arg))) + + if args[0].pspace.distribution == Distribution(): + return JointDistribution(*args) + density = Lambda(tuple(args), + expr=Mul.fromiter(arg.pspace.process.density(arg) for arg in args)) + return JointDistributionHandmade(density) + + def expectation(self, condition, given_condition): + raise NotImplementedError("Abstract method for expectation queries.") + + def sample(self): + raise NotImplementedError("Abstract method for sampling queries.") + +class DiscreteTimeStochasticProcess(StochasticProcess): + """ + Base class for all discrete stochastic processes. + """ + def __getitem__(self, time): + """ + For indexing discrete time stochastic processes. + + Returns + ======= + + RandomIndexedSymbol + """ + time = sympify(time) + if not time.is_symbol and time not in self.index_set: + raise IndexError("%s is not in the index set of %s"%(time, self.symbol)) + idx_obj = Indexed(self.symbol, time) + pspace_obj = StochasticPSpace(self.symbol, self, self.distribution(time)) + return RandomIndexedSymbol(idx_obj, pspace_obj) + +class ContinuousTimeStochasticProcess(StochasticProcess): + """ + Base class for all continuous time stochastic process. + """ + def __call__(self, time): + """ + For indexing continuous time stochastic processes. + + Returns + ======= + + RandomIndexedSymbol + """ + time = sympify(time) + if not time.is_symbol and time not in self.index_set: + raise IndexError("%s is not in the index set of %s"%(time, self.symbol)) + func_obj = Function(self.symbol)(time) + pspace_obj = StochasticPSpace(self.symbol, self, self.distribution(time)) + return RandomIndexedSymbol(func_obj, pspace_obj) + +class TransitionMatrixOf(Boolean): + """ + Assumes that the matrix is the transition matrix + of the process. + """ + + def __new__(cls, process, matrix): + if not isinstance(process, DiscreteMarkovChain): + raise ValueError("Currently only DiscreteMarkovChain " + "support TransitionMatrixOf.") + matrix = _matrix_checks(matrix) + return Basic.__new__(cls, process, matrix) + + process = property(lambda self: self.args[0]) + matrix = property(lambda self: self.args[1]) + +class GeneratorMatrixOf(TransitionMatrixOf): + """ + Assumes that the matrix is the generator matrix + of the process. + """ + + def __new__(cls, process, matrix): + if not isinstance(process, ContinuousMarkovChain): + raise ValueError("Currently only ContinuousMarkovChain " + "support GeneratorMatrixOf.") + matrix = _matrix_checks(matrix) + return Basic.__new__(cls, process, matrix) + +class StochasticStateSpaceOf(Boolean): + + def __new__(cls, process, state_space): + if not isinstance(process, (DiscreteMarkovChain, ContinuousMarkovChain)): + raise ValueError("Currently only DiscreteMarkovChain and ContinuousMarkovChain " + "support StochasticStateSpaceOf.") + state_space = _state_converter(state_space) + if isinstance(state_space, Range): + ss_size = ceiling((state_space.stop - state_space.start) / state_space.step) + else: + ss_size = len(state_space) + state_index = Range(ss_size) + return Basic.__new__(cls, process, state_index) + + process = property(lambda self: self.args[0]) + state_index = property(lambda self: self.args[1]) + +class MarkovProcess(StochasticProcess): + """ + Contains methods that handle queries + common to Markov processes. + """ + + @property + def number_of_states(self) -> Integer | Symbol: + """ + The number of states in the Markov Chain. + """ + return _sympify(self.args[2].shape[0]) # type: ignore + + @property + def _state_index(self): + """ + Returns state index as Range. + """ + return self.args[1] + + @classmethod + def _sanity_checks(cls, state_space, trans_probs): + # Try to never have None as state_space or trans_probs. + # This helps a lot if we get it done at the start. + if (state_space is None) and (trans_probs is None): + _n = Dummy('n', integer=True, nonnegative=True) + state_space = _state_converter(Range(_n)) + trans_probs = _matrix_checks(MatrixSymbol('_T', _n, _n)) + + elif state_space is None: + trans_probs = _matrix_checks(trans_probs) + state_space = _state_converter(Range(trans_probs.shape[0])) + + elif trans_probs is None: + state_space = _state_converter(state_space) + if isinstance(state_space, Range): + _n = ceiling((state_space.stop - state_space.start) / state_space.step) + else: + _n = len(state_space) + trans_probs = MatrixSymbol('_T', _n, _n) + + else: + state_space = _state_converter(state_space) + trans_probs = _matrix_checks(trans_probs) + # Range object doesn't want to give a symbolic size + # so we do it ourselves. + if isinstance(state_space, Range): + ss_size = ceiling((state_space.stop - state_space.start) / state_space.step) + else: + ss_size = len(state_space) + if ss_size != trans_probs.shape[0]: + raise ValueError('The size of the state space and the number of ' + 'rows of the transition matrix must be the same.') + + return state_space, trans_probs + + def _extract_information(self, given_condition): + """ + Helper function to extract information, like, + transition matrix/generator matrix, state space, etc. + """ + if isinstance(self, DiscreteMarkovChain): + trans_probs = self.transition_probabilities + state_index = self._state_index + elif isinstance(self, ContinuousMarkovChain): + trans_probs = self.generator_matrix + state_index = self._state_index + if isinstance(given_condition, And): + gcs = given_condition.args + given_condition = S.true + for gc in gcs: + if isinstance(gc, TransitionMatrixOf): + trans_probs = gc.matrix + if isinstance(gc, StochasticStateSpaceOf): + state_index = gc.state_index + if isinstance(gc, Relational): + given_condition = given_condition & gc + if isinstance(given_condition, TransitionMatrixOf): + trans_probs = given_condition.matrix + given_condition = S.true + if isinstance(given_condition, StochasticStateSpaceOf): + state_index = given_condition.state_index + given_condition = S.true + return trans_probs, state_index, given_condition + + def _check_trans_probs(self, trans_probs, row_sum=1): + """ + Helper function for checking the validity of transition + probabilities. + """ + if not isinstance(trans_probs, MatrixSymbol): + rows = trans_probs.tolist() + for row in rows: + if (sum(row) - row_sum) != 0: + raise ValueError("Values in a row must sum to %s. " + "If you are using Float or floats then please use Rational."%(row_sum)) + + def _work_out_state_index(self, state_index, given_condition, trans_probs): + """ + Helper function to extract state space if there + is a random symbol in the given condition. + """ + # if given condition is None, then there is no need to work out + # state_space from random variables + if given_condition != None: + rand_var = list(given_condition.atoms(RandomSymbol) - + given_condition.atoms(RandomIndexedSymbol)) + if len(rand_var) == 1: + state_index = rand_var[0].pspace.set + + # `not None` is `True`. So the old test fails for symbolic sizes. + # Need to build the statement differently. + sym_cond = not self.number_of_states.is_Integer + cond1 = not sym_cond and len(state_index) != trans_probs.shape[0] + if cond1: + raise ValueError("state space is not compatible with the transition probabilities.") + if not isinstance(trans_probs.shape[0], Symbol): + state_index = FiniteSet(*range(trans_probs.shape[0])) + return state_index + + @cacheit + def _preprocess(self, given_condition, evaluate): + """ + Helper function for pre-processing the information. + """ + is_insufficient = False + + if not evaluate: # avoid pre-processing if the result is not to be evaluated + return (True, None, None, None) + + # extracting transition matrix and state space + trans_probs, state_index, given_condition = self._extract_information(given_condition) + + # given_condition does not have sufficient information + # for computations + if trans_probs is None or \ + given_condition is None: + is_insufficient = True + else: + # checking transition probabilities + if isinstance(self, DiscreteMarkovChain): + self._check_trans_probs(trans_probs, row_sum=1) + elif isinstance(self, ContinuousMarkovChain): + self._check_trans_probs(trans_probs, row_sum=0) + + # working out state space + state_index = self._work_out_state_index(state_index, given_condition, trans_probs) + + return is_insufficient, trans_probs, state_index, given_condition + + def replace_with_index(self, condition): + if isinstance(condition, Relational): + lhs, rhs = condition.lhs, condition.rhs + if not isinstance(lhs, RandomIndexedSymbol): + lhs, rhs = rhs, lhs + condition = type(condition)(self.index_of.get(lhs, lhs), + self.index_of.get(rhs, rhs)) + return condition + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Handles probability queries for Markov process. + + Parameters + ========== + + condition: Relational + given_condition: Relational/And + + Returns + ======= + Probability + If the information is not sufficient. + Expr + In all other cases. + + Note + ==== + Any information passed at the time of query overrides + any information passed at the time of object creation like + transition probabilities, state space. + Pass the transition matrix using TransitionMatrixOf, + generator matrix using GeneratorMatrixOf and state space + using StochasticStateSpaceOf in given_condition using & or And. + """ + check, mat, state_index, new_given_condition = \ + self._preprocess(given_condition, evaluate) + + rv = list(condition.atoms(RandomIndexedSymbol)) + symbolic = False + for sym in rv: + if sym.key.is_symbol: + symbolic = True + break + + if check: + return Probability(condition, new_given_condition) + + if isinstance(self, ContinuousMarkovChain): + trans_probs = self.transition_probabilities(mat) + elif isinstance(self, DiscreteMarkovChain): + trans_probs = mat + condition = self.replace_with_index(condition) + given_condition = self.replace_with_index(given_condition) + new_given_condition = self.replace_with_index(new_given_condition) + + if isinstance(condition, Relational): + if isinstance(new_given_condition, And): + gcs = new_given_condition.args + else: + gcs = (new_given_condition, ) + min_key_rv = list(new_given_condition.atoms(RandomIndexedSymbol)) + + if len(min_key_rv): + min_key_rv = min_key_rv[0] + for r in rv: + if min_key_rv.key.is_symbol or r.key.is_symbol: + continue + if min_key_rv.key > r.key: + return Probability(condition) + else: + min_key_rv = None + return Probability(condition) + + if symbolic: + return self._symbolic_probability(condition, new_given_condition, rv, min_key_rv) + + if len(rv) > 1: + rv[0] = condition.lhs + rv[1] = condition.rhs + if rv[0].key < rv[1].key: + rv[0], rv[1] = rv[1], rv[0] + if isinstance(condition, Gt): + condition = Lt(condition.lhs, condition.rhs) + elif isinstance(condition, Lt): + condition = Gt(condition.lhs, condition.rhs) + elif isinstance(condition, Ge): + condition = Le(condition.lhs, condition.rhs) + elif isinstance(condition, Le): + condition = Ge(condition.lhs, condition.rhs) + s = Rational(0, 1) + n = len(self.state_space) + + if isinstance(condition, (Eq, Ne)): + for i in range(0, n): + s += self.probability(Eq(rv[0], i), Eq(rv[1], i)) * self.probability(Eq(rv[1], i), new_given_condition) + return s if isinstance(condition, Eq) else 1 - s + else: + upper = 0 + greater = False + if isinstance(condition, (Ge, Lt)): + upper = 1 + if isinstance(condition, (Ge, Gt)): + greater = True + + for i in range(0, n): + if i <= n//2: + for j in range(0, i + upper): + s += self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition) + else: + s += self.probability(Eq(rv[0], i), new_given_condition) + for j in range(i + upper, n): + s -= self.probability(Eq(rv[0], i), Eq(rv[1], j)) * self.probability(Eq(rv[1], j), new_given_condition) + return s if greater else 1 - s + + rv = rv[0] + states = condition.as_set() + prob, gstate = {}, None + for gc in gcs: + if gc.has(min_key_rv): + if gc.has(Probability): + p, gp = (gc.rhs, gc.lhs) if isinstance(gc.lhs, Probability) \ + else (gc.lhs, gc.rhs) + gr = gp.args[0] + gset = Intersection(gr.as_set(), state_index) + gstate = list(gset)[0] + prob[gset] = p + else: + _, gstate = (gc.lhs.key, gc.rhs) if isinstance(gc.lhs, RandomIndexedSymbol) \ + else (gc.rhs.key, gc.lhs) + + if not all(k in self.index_set for k in (rv.key, min_key_rv.key)): + raise IndexError("The timestamps of the process are not in it's index set.") + states = Intersection(states, state_index) if not isinstance(self.number_of_states, Symbol) else states + for state in Union(states, FiniteSet(gstate)): + if not state.is_Integer or Ge(state, mat.shape[0]) is True: + raise IndexError("No information is available for (%s, %s) in " + "transition probabilities of shape, (%s, %s). " + "State space is zero indexed." + %(gstate, state, mat.shape[0], mat.shape[1])) + if prob: + gstates = Union(*prob.keys()) + if len(gstates) == 1: + gstate = list(gstates)[0] + gprob = list(prob.values())[0] + prob[gstates] = gprob + elif len(gstates) == len(state_index) - 1: + gstate = list(state_index - gstates)[0] + gprob = S.One - sum(prob.values()) + prob[state_index - gstates] = gprob + else: + raise ValueError("Conflicting information.") + else: + gprob = S.One + + if min_key_rv == rv: + return sum(prob[FiniteSet(state)] for state in states) + if isinstance(self, ContinuousMarkovChain): + return gprob * sum(trans_probs(rv.key - min_key_rv.key).__getitem__((gstate, state)) + for state in states) + if isinstance(self, DiscreteMarkovChain): + return gprob * sum((trans_probs**(rv.key - min_key_rv.key)).__getitem__((gstate, state)) + for state in states) + + if isinstance(condition, Not): + expr = condition.args[0] + return S.One - self.probability(expr, given_condition, evaluate, **kwargs) + + if isinstance(condition, And): + compute_later, state2cond, conds = [], {}, condition.args + for expr in conds: + if isinstance(expr, Relational): + ris = list(expr.atoms(RandomIndexedSymbol))[0] + if state2cond.get(ris, None) is None: + state2cond[ris] = S.true + state2cond[ris] &= expr + else: + compute_later.append(expr) + ris = [] + for ri in state2cond: + ris.append(ri) + cset = Intersection(state2cond[ri].as_set(), state_index) + if len(cset) == 0: + return S.Zero + state2cond[ri] = cset.as_relational(ri) + sorted_ris = sorted(ris, key=lambda ri: ri.key) + prod = self.probability(state2cond[sorted_ris[0]], given_condition, evaluate, **kwargs) + for i in range(1, len(sorted_ris)): + ri, prev_ri = sorted_ris[i], sorted_ris[i-1] + if not isinstance(state2cond[ri], Eq): + raise ValueError("The process is in multiple states at %s, unable to determine the probability."%(ri)) + mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat) + prod *= self.probability(state2cond[ri], state2cond[prev_ri] + & mat_of + & StochasticStateSpaceOf(self, state_index), + evaluate, **kwargs) + for expr in compute_later: + prod *= self.probability(expr, given_condition, evaluate, **kwargs) + return prod + + if isinstance(condition, Or): + return sum(self.probability(expr, given_condition, evaluate, **kwargs) + for expr in condition.args) + + raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been " + "implemented yet."%(condition, given_condition)) + + def _symbolic_probability(self, condition, new_given_condition, rv, min_key_rv): + #Function to calculate probability for queries with symbols + if isinstance(condition, Relational): + curr_state = new_given_condition.rhs if isinstance(new_given_condition.lhs, RandomIndexedSymbol) \ + else new_given_condition.lhs + next_state = condition.rhs if isinstance(condition.lhs, RandomIndexedSymbol) \ + else condition.lhs + + if isinstance(condition, (Eq, Ne)): + if isinstance(self, DiscreteMarkovChain): + P = self.transition_probabilities**(rv[0].key - min_key_rv.key) + else: + P = exp(self.generator_matrix*(rv[0].key - min_key_rv.key)) + prob = P[curr_state, next_state] if isinstance(condition, Eq) else 1 - P[curr_state, next_state] + return Piecewise((prob, rv[0].key > min_key_rv.key), (Probability(condition), True)) + else: + upper = 1 + greater = False + if isinstance(condition, (Ge, Lt)): + upper = 0 + if isinstance(condition, (Ge, Gt)): + greater = True + k = Dummy('k') + condition = Eq(condition.lhs, k) if isinstance(condition.lhs, RandomIndexedSymbol)\ + else Eq(condition.rhs, k) + total = Sum(self.probability(condition, new_given_condition), (k, next_state + upper, self.state_space._sup)) + return Piecewise((total, rv[0].key > min_key_rv.key), (Probability(condition), True)) if greater\ + else Piecewise((1 - total, rv[0].key > min_key_rv.key), (Probability(condition), True)) + else: + return Probability(condition, new_given_condition) + + def expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Handles expectation queries for markov process. + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Expectation + Unevaluated object if computations cannot be done due to + insufficient information. + Expr + In all other cases when the computations are successful. + + Note + ==== + + Any information passed at the time of query overrides + any information passed at the time of object creation like + transition probabilities, state space. + + Pass the transition matrix using TransitionMatrixOf, + generator matrix using GeneratorMatrixOf and state space + using StochasticStateSpaceOf in given_condition using & or And. + """ + + check, mat, state_index, condition = \ + self._preprocess(condition, evaluate) + + if check: + return Expectation(expr, condition) + + rvs = random_symbols(expr) + if isinstance(expr, Expr) and isinstance(condition, Eq) \ + and len(rvs) == 1: + # handle queries similar to E(f(X[i]), Eq(X[i-m], )) + condition=self.replace_with_index(condition) + state_index=self.replace_with_index(state_index) + rv = list(rvs)[0] + lhsg, rhsg = condition.lhs, condition.rhs + if not isinstance(lhsg, RandomIndexedSymbol): + lhsg, rhsg = (rhsg, lhsg) + if rhsg not in state_index: + raise ValueError("%s state is not in the state space."%(rhsg)) + if rv.key < lhsg.key: + raise ValueError("Incorrect given condition is given, expectation " + "time %s < time %s"%(rv.key, rv.key)) + mat_of = TransitionMatrixOf(self, mat) if isinstance(self, DiscreteMarkovChain) else GeneratorMatrixOf(self, mat) + cond = condition & mat_of & \ + StochasticStateSpaceOf(self, state_index) + func = lambda s: self.probability(Eq(rv, s), cond) * expr.subs(rv, self._state_index[s]) + return sum(func(s) for s in state_index) + + raise NotImplementedError("Mechanism for handling (%s, %s) queries hasn't been " + "implemented yet."%(expr, condition)) + +class DiscreteMarkovChain(DiscreteTimeStochasticProcess, MarkovProcess): + """ + Represents a finite discrete time-homogeneous Markov chain. + + This type of Markov Chain can be uniquely characterised by + its (ordered) state space and its one-step transition probability + matrix. + + Parameters + ========== + + sym: + The name given to the Markov Chain + state_space: + Optional, by default, Range(n) + trans_probs: + Optional, by default, MatrixSymbol('_T', n, n) + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain, TransitionMatrixOf, P, E + >>> from sympy import Matrix, MatrixSymbol, Eq, symbols + >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> YS = DiscreteMarkovChain("Y") + + >>> Y.state_space + {0, 1, 2} + >>> Y.transition_probabilities + Matrix([ + [0.5, 0.2, 0.3], + [0.2, 0.5, 0.3], + [0.2, 0.3, 0.5]]) + >>> TS = MatrixSymbol('T', 3, 3) + >>> P(Eq(YS[3], 2), Eq(YS[1], 1) & TransitionMatrixOf(YS, TS)) + T[0, 2]*T[1, 0] + T[1, 1]*T[1, 2] + T[1, 2]*T[2, 2] + >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) + 0.36 + + Probabilities will be calculated based on indexes rather + than state names. For example, with the Sunny-Cloudy-Rainy + model with string state names: + + >>> from sympy.core.symbol import Str + >>> Y = DiscreteMarkovChain("Y", [Str('Sunny'), Str('Cloudy'), Str('Rainy')], T) + >>> P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) + 0.36 + + This gives the same answer as the ``[0, 1, 2]`` state space. + Currently, there is no support for state names within probability + and expectation statements. Here is a work-around using ``Str``: + + >>> P(Eq(Str('Rainy'), Y[3]), Eq(Y[1], Str('Cloudy'))).round(2) + 0.36 + + Symbol state names can also be used: + + >>> sunny, cloudy, rainy = symbols('Sunny, Cloudy, Rainy') + >>> Y = DiscreteMarkovChain("Y", [sunny, cloudy, rainy], T) + >>> P(Eq(Y[3], rainy), Eq(Y[1], cloudy)).round(2) + 0.36 + + Expectations will be calculated as follows: + + >>> E(Y[3], Eq(Y[1], cloudy)) + 0.38*Cloudy + 0.36*Rainy + 0.26*Sunny + + Probability of expressions with multiple RandomIndexedSymbols + can also be calculated provided there is only 1 RandomIndexedSymbol + in the given condition. It is always better to use Rational instead + of floating point numbers for the probabilities in the + transition matrix to avoid errors. + + >>> from sympy import Gt, Le, Rational + >>> T = Matrix([[Rational(5, 10), Rational(3, 10), Rational(2, 10)], [Rational(2, 10), Rational(7, 10), Rational(1, 10)], [Rational(3, 10), Rational(3, 10), Rational(4, 10)]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> P(Eq(Y[3], Y[1]), Eq(Y[0], 0)).round(3) + 0.409 + >>> P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2) + 0.36 + >>> P(Le(Y[15], Y[10]), Eq(Y[8], 2)).round(7) + 0.6963328 + + Symbolic probability queries are also supported + + >>> a, b, c, d = symbols('a b c d') + >>> T = Matrix([[Rational(1, 10), Rational(4, 10), Rational(5, 10)], [Rational(3, 10), Rational(4, 10), Rational(3, 10)], [Rational(7, 10), Rational(2, 10), Rational(1, 10)]]) + >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + >>> query = P(Eq(Y[a], b), Eq(Y[c], d)) + >>> query.subs({a:10, b:2, c:5, d:1}).round(4) + 0.3096 + >>> P(Eq(Y[10], 2), Eq(Y[5], 1)).evalf().round(4) + 0.3096 + >>> query_gt = P(Gt(Y[a], b), Eq(Y[c], d)) + >>> query_gt.subs({a:21, b:0, c:5, d:0}).evalf().round(5) + 0.64705 + >>> P(Gt(Y[21], 0), Eq(Y[5], 0)).round(5) + 0.64705 + + There is limited support for arbitrarily sized states: + + >>> n = symbols('n', nonnegative=True, integer=True) + >>> T = MatrixSymbol('T', n, n) + >>> Y = DiscreteMarkovChain("Y", trans_probs=T) + >>> Y.state_space + Range(0, n, 1) + >>> query = P(Eq(Y[a], b), Eq(Y[c], d)) + >>> query.subs({a:10, b:2, c:5, d:1}) + (T**5)[1, 2] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Markov_chain#Discrete-time_Markov_chain + .. [2] https://web.archive.org/web/20201230182007/https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf + """ + index_set = S.Naturals0 + + def __new__(cls, sym, state_space=None, trans_probs=None): + sym = _symbol_converter(sym) + + state_space, trans_probs = MarkovProcess._sanity_checks(state_space, trans_probs) + + obj = Basic.__new__(cls, sym, state_space, trans_probs) # type: ignore + indices = {} + if isinstance(obj.number_of_states, Integer): + for index, state in enumerate(obj._state_index): + indices[state] = index + obj.index_of = indices + return obj + + @property + def transition_probabilities(self): + """ + Transition probabilities of discrete Markov chain, + either an instance of Matrix or MatrixSymbol. + """ + return self.args[2] + + def communication_classes(self) -> list[tuple[list[Basic], Boolean, Integer]]: + """ + Returns the list of communication classes that partition + the states of the markov chain. + + A communication class is defined to be a set of states + such that every state in that set is reachable from + every other state in that set. Due to its properties + this forms a class in the mathematical sense. + Communication classes are also known as recurrence + classes. + + Returns + ======= + + classes + The ``classes`` are a list of tuples. Each + tuple represents a single communication class + with its properties. The first element in the + tuple is the list of states in the class, the + second element is whether the class is recurrent + and the third element is the period of the + communication class. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix + >>> T = Matrix([[0, 1, 0], + ... [1, 0, 0], + ... [1, 0, 0]]) + >>> X = DiscreteMarkovChain('X', [1, 2, 3], T) + >>> classes = X.communication_classes() + >>> for states, is_recurrent, period in classes: + ... states, is_recurrent, period + ([1, 2], True, 2) + ([3], False, 1) + + From this we can see that states ``1`` and ``2`` + communicate, are recurrent and have a period + of 2. We can also see state ``3`` is transient + with a period of 1. + + Notes + ===== + + The algorithm used is of order ``O(n**2)`` where + ``n`` is the number of states in the markov chain. + It uses Tarjan's algorithm to find the classes + themselves and then it uses a breadth-first search + algorithm to find each class's periodicity. + Most of the algorithm's components approach ``O(n)`` + as the matrix becomes more and more sparse. + + References + ========== + + .. [1] https://web.archive.org/web/20220207032113/https://www.columbia.edu/~ww2040/4701Sum07/4701-06-Notes-MCII.pdf + .. [2] https://cecas.clemson.edu/~shierd/Shier/markov.pdf + .. [3] https://www.proquest.com/openview/4adc6a51d8371be5b0e4c7dff287fc70/1?pq-origsite=gscholar&cbl=2026366&diss=y + .. [4] https://www.mathworks.com/help/econ/dtmc.classify.html + """ + n = self.number_of_states + T = self.transition_probabilities + + if isinstance(T, MatrixSymbol): + raise NotImplementedError("Cannot perform the operation with a symbolic matrix.") + + # begin Tarjan's algorithm + V = Range(n) + # don't use state names. Rather use state + # indexes since we use them for matrix + # indexing here and later onward + E = [(i, j) for i in V for j in V if T[i, j] != 0] + classes = strongly_connected_components((V, E)) + # end Tarjan's algorithm + + recurrence = [] + periods = [] + for class_ in classes: + # begin recurrent check (similar to self._check_trans_probs()) + submatrix = T[class_, class_] # get the submatrix with those states + is_recurrent = S.true + rows = submatrix.tolist() + for row in rows: + if (sum(row) - 1) != 0: + is_recurrent = S.false + break + recurrence.append(is_recurrent) + # end recurrent check + + # begin breadth-first search + non_tree_edge_values: set[int] = set() + visited = {class_[0]} + newly_visited = {class_[0]} + level = {class_[0]: 0} + current_level = 0 + done = False # imitate a do-while loop + while not done: # runs at most len(class_) times + done = len(visited) == len(class_) + current_level += 1 + + # this loop and the while loop above run a combined len(class_) number of times. + # so this triple nested loop runs through each of the n states once. + for i in newly_visited: + + # the loop below runs len(class_) number of times + # complexity is around about O(n * avg(len(class_))) + newly_visited = {j for j in class_ if T[i, j] != 0} + + new_tree_edges = newly_visited.difference(visited) + for j in new_tree_edges: + level[j] = current_level + + new_non_tree_edges = newly_visited.intersection(visited) + new_non_tree_edge_values = {level[i]-level[j]+1 for j in new_non_tree_edges} + + non_tree_edge_values = non_tree_edge_values.union(new_non_tree_edge_values) + visited = visited.union(new_tree_edges) + + # igcd needs at least 2 arguments + positive_ntev = {val_e for val_e in non_tree_edge_values if val_e > 0} + if len(positive_ntev) == 0: + periods.append(len(class_)) + elif len(positive_ntev) == 1: + periods.append(positive_ntev.pop()) + else: + periods.append(igcd(*positive_ntev)) + # end breadth-first search + + # convert back to the user's state names + classes = [[_sympify(self._state_index[i]) for i in class_] for class_ in classes] + return list(zip(classes, recurrence, map(Integer,periods))) + + def fundamental_matrix(self): + """ + Each entry fundamental matrix can be interpreted as + the expected number of times the chains is in state j + if it started in state i. + + References + ========== + + .. [1] https://lips.cs.princeton.edu/the-fundamental-matrix-of-a-finite-markov-chain/ + + """ + _, _, _, Q = self.decompose() + + if Q.shape[0] > 0: # if non-ergodic + I = eye(Q.shape[0]) + if (I - Q).det() == 0: + raise ValueError("The fundamental matrix doesn't exist.") + return (I - Q).inv().as_immutable() + else: # if ergodic + P = self.transition_probabilities + I = eye(P.shape[0]) + w = self.fixed_row_vector() + W = Matrix([list(w) for i in range(0, P.shape[0])]) + if (I - P + W).det() == 0: + raise ValueError("The fundamental matrix doesn't exist.") + return (I - P + W).inv().as_immutable() + + def absorbing_probabilities(self): + """ + Computes the absorbing probabilities, i.e. + the ij-th entry of the matrix denotes the + probability of Markov chain being absorbed + in state j starting from state i. + """ + _, _, R, _ = self.decompose() + N = self.fundamental_matrix() + if R is None or N is None: + return None + return N*R + + def absorbing_probabilites(self): + sympy_deprecation_warning( + """ + DiscreteMarkovChain.absorbing_probabilites() is deprecated. Use + absorbing_probabilities() instead (note the spelling difference). + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-absorbing_probabilites", + ) + return self.absorbing_probabilities() + + def is_regular(self): + tuples = self.communication_classes() + if len(tuples) == 0: + return S.false # not defined for a 0x0 matrix + classes, _, periods = list(zip(*tuples)) + return And(len(classes) == 1, periods[0] == 1) + + def is_ergodic(self): + tuples = self.communication_classes() + if len(tuples) == 0: + return S.false # not defined for a 0x0 matrix + classes, _, _ = list(zip(*tuples)) + return S(len(classes) == 1) + + def is_absorbing_state(self, state): + trans_probs = self.transition_probabilities + if isinstance(trans_probs, ImmutableMatrix) and \ + state < trans_probs.shape[0]: + return S(trans_probs[state, state]) is S.One + + def is_absorbing_chain(self): + states, A, B, C = self.decompose() + r = A.shape[0] + return And(r > 0, A == Identity(r).as_explicit()) + + def stationary_distribution(self, condition_set=False) -> ImmutableMatrix | ConditionSet | Lambda: + r""" + The stationary distribution is any row vector, p, that solves p = pP, + is row stochastic and each element in p must be nonnegative. + That means in matrix form: :math:`(P-I)^T p^T = 0` and + :math:`(1, \dots, 1) p = 1` + where ``P`` is the one-step transition matrix. + + All time-homogeneous Markov Chains with a finite state space + have at least one stationary distribution. In addition, if + a finite time-homogeneous Markov Chain is irreducible, the + stationary distribution is unique. + + Parameters + ========== + + condition_set : bool + If the chain has a symbolic size or transition matrix, + it will return a ``Lambda`` if ``False`` and return a + ``ConditionSet`` if ``True``. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix, S + + An irreducible Markov Chain + + >>> T = Matrix([[S(1)/2, S(1)/2, 0], + ... [S(4)/5, S(1)/5, 0], + ... [1, 0, 0]]) + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> X.stationary_distribution() + Matrix([[8/13, 5/13, 0]]) + + A reducible Markov Chain + + >>> T = Matrix([[S(1)/2, S(1)/2, 0], + ... [S(4)/5, S(1)/5, 0], + ... [0, 0, 1]]) + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> X.stationary_distribution() + Matrix([[8/13 - 8*tau0/13, 5/13 - 5*tau0/13, tau0]]) + + >>> Y = DiscreteMarkovChain('Y') + >>> Y.stationary_distribution() + Lambda((wm, _T), Eq(wm*_T, wm)) + + >>> Y.stationary_distribution(condition_set=True) + ConditionSet(wm, Eq(wm*_T, wm)) + + References + ========== + + .. [1] https://www.probabilitycourse.com/chapter11/11_2_6_stationary_and_limiting_distributions.php + .. [2] https://web.archive.org/web/20210508104430/https://galton.uchicago.edu/~yibi/teaching/stat317/2014/Lectures/Lecture4_6up.pdf + + See Also + ======== + + sympy.stats.DiscreteMarkovChain.limiting_distribution + """ + trans_probs = self.transition_probabilities + n = self.number_of_states + + if n == 0: + return ImmutableMatrix(Matrix([[]])) + + # symbolic matrix version + if isinstance(trans_probs, MatrixSymbol): + wm = MatrixSymbol('wm', 1, n) + if condition_set: + return ConditionSet(wm, Eq(wm * trans_probs, wm)) + else: + return Lambda((wm, trans_probs), Eq(wm * trans_probs, wm)) + + # numeric matrix version + a = Matrix(trans_probs - Identity(n)).T + a[0, 0:n] = ones(1, n) # type: ignore + b = zeros(n, 1) + b[0, 0] = 1 + + soln = list(linsolve((a, b)))[0] + return ImmutableMatrix([soln]) + + def fixed_row_vector(self): + """ + A wrapper for ``stationary_distribution()``. + """ + return self.stationary_distribution() + + @property + def limiting_distribution(self): + """ + The fixed row vector is the limiting + distribution of a discrete Markov chain. + """ + return self.fixed_row_vector() + + def decompose(self) -> tuple[list[Basic], ImmutableMatrix, ImmutableMatrix, ImmutableMatrix]: + """ + Decomposes the transition matrix into submatrices with + special properties. + + The transition matrix can be decomposed into 4 submatrices: + - A - the submatrix from recurrent states to recurrent states. + - B - the submatrix from transient to recurrent states. + - C - the submatrix from transient to transient states. + - O - the submatrix of zeros for recurrent to transient states. + + Returns + ======= + + states, A, B, C + ``states`` - a list of state names with the first being + the recurrent states and the last being + the transient states in the order + of the row names of A and then the row names of C. + ``A`` - the submatrix from recurrent states to recurrent states. + ``B`` - the submatrix from transient to recurrent states. + ``C`` - the submatrix from transient to transient states. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix, S + + One can decompose this chain for example: + + >>> T = Matrix([[S(1)/2, S(1)/2, 0, 0, 0], + ... [S(2)/5, S(1)/5, S(2)/5, 0, 0], + ... [0, 0, 1, 0, 0], + ... [0, 0, S(1)/2, S(1)/2, 0], + ... [S(1)/2, 0, 0, 0, S(1)/2]]) + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> states, A, B, C = X.decompose() + >>> states + [2, 0, 1, 3, 4] + + >>> A # recurrent to recurrent + Matrix([[1]]) + + >>> B # transient to recurrent + Matrix([ + [ 0], + [2/5], + [1/2], + [ 0]]) + + >>> C # transient to transient + Matrix([ + [1/2, 1/2, 0, 0], + [2/5, 1/5, 0, 0], + [ 0, 0, 1/2, 0], + [1/2, 0, 0, 1/2]]) + + This means that state 2 is the only absorbing state + (since A is a 1x1 matrix). B is a 4x1 matrix since + the 4 remaining transient states all merge into recurrent + state 2. And C is the 4x4 matrix that shows how the + transient states 0, 1, 3, 4 all interact. + + See Also + ======== + + sympy.stats.DiscreteMarkovChain.communication_classes + sympy.stats.DiscreteMarkovChain.canonical_form + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Absorbing_Markov_chain + .. [2] https://people.brandeis.edu/~igusa/Math56aS08/Math56a_S08_notes015.pdf + """ + trans_probs = self.transition_probabilities + + classes = self.communication_classes() + r_states = [] + t_states = [] + + for states, recurrent, period in classes: + if recurrent: + r_states += states + else: + t_states += states + + states = r_states + t_states + indexes = [self.index_of[state] for state in states] # type: ignore + + A = Matrix(len(r_states), len(r_states), + lambda i, j: trans_probs[indexes[i], indexes[j]]) + + B = Matrix(len(t_states), len(r_states), + lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[j]]) + + C = Matrix(len(t_states), len(t_states), + lambda i, j: trans_probs[indexes[len(r_states) + i], indexes[len(r_states) + j]]) + + return states, A.as_immutable(), B.as_immutable(), C.as_immutable() + + def canonical_form(self) -> tuple[list[Basic], ImmutableMatrix]: + """ + Reorders the one-step transition matrix + so that recurrent states appear first and transient + states appear last. Other representations include inserting + transient states first and recurrent states last. + + Returns + ======= + + states, P_new + ``states`` is the list that describes the order of the + new states in the matrix + so that the ith element in ``states`` is the state of the + ith row of A. + ``P_new`` is the new transition matrix in canonical form. + + Examples + ======== + + >>> from sympy.stats import DiscreteMarkovChain + >>> from sympy import Matrix, S + + You can convert your chain into canonical form: + + >>> T = Matrix([[S(1)/2, S(1)/2, 0, 0, 0], + ... [S(2)/5, S(1)/5, S(2)/5, 0, 0], + ... [0, 0, 1, 0, 0], + ... [0, 0, S(1)/2, S(1)/2, 0], + ... [S(1)/2, 0, 0, 0, S(1)/2]]) + >>> X = DiscreteMarkovChain('X', list(range(1, 6)), trans_probs=T) + >>> states, new_matrix = X.canonical_form() + >>> states + [3, 1, 2, 4, 5] + + >>> new_matrix + Matrix([ + [ 1, 0, 0, 0, 0], + [ 0, 1/2, 1/2, 0, 0], + [2/5, 2/5, 1/5, 0, 0], + [1/2, 0, 0, 1/2, 0], + [ 0, 1/2, 0, 0, 1/2]]) + + The new states are [3, 1, 2, 4, 5] and you can + create a new chain with this and its canonical + form will remain the same (since it is already + in canonical form). + + >>> X = DiscreteMarkovChain('X', states, new_matrix) + >>> states, new_matrix = X.canonical_form() + >>> states + [3, 1, 2, 4, 5] + + >>> new_matrix + Matrix([ + [ 1, 0, 0, 0, 0], + [ 0, 1/2, 1/2, 0, 0], + [2/5, 2/5, 1/5, 0, 0], + [1/2, 0, 0, 1/2, 0], + [ 0, 1/2, 0, 0, 1/2]]) + + This is not limited to absorbing chains: + + >>> T = Matrix([[0, 5, 5, 0, 0], + ... [0, 0, 0, 10, 0], + ... [5, 0, 5, 0, 0], + ... [0, 10, 0, 0, 0], + ... [0, 3, 0, 3, 4]])/10 + >>> X = DiscreteMarkovChain('X', trans_probs=T) + >>> states, new_matrix = X.canonical_form() + >>> states + [1, 3, 0, 2, 4] + + >>> new_matrix + Matrix([ + [ 0, 1, 0, 0, 0], + [ 1, 0, 0, 0, 0], + [ 1/2, 0, 0, 1/2, 0], + [ 0, 0, 1/2, 1/2, 0], + [3/10, 3/10, 0, 0, 2/5]]) + + See Also + ======== + + sympy.stats.DiscreteMarkovChain.communication_classes + sympy.stats.DiscreteMarkovChain.decompose + + References + ========== + + .. [1] https://onlinelibrary.wiley.com/doi/pdf/10.1002/9780470316887.app1 + .. [2] http://www.columbia.edu/~ww2040/6711F12/lect1023big.pdf + """ + states, A, B, C = self.decompose() + O = zeros(A.shape[0], C.shape[1]) + return states, BlockMatrix([[A, O], [B, C]]).as_explicit() + + def sample(self): + """ + Returns + ======= + + sample: iterator object + iterator object containing the sample + + """ + if not isinstance(self.transition_probabilities, (Matrix, ImmutableMatrix)): + raise ValueError("Transition Matrix must be provided for sampling") + Tlist = self.transition_probabilities.tolist() + samps = [random.choice(list(self.state_space))] + yield samps[0] + time = 1 + densities = {} + for state in self.state_space: + states = list(self.state_space) + densities[state] = {states[i]: Tlist[state][i] + for i in range(len(states))} + while time < S.Infinity: + samps.append((next(sample_iter(FiniteRV("_", densities[samps[time - 1]]))))) + yield samps[time] + time += 1 + +class ContinuousMarkovChain(ContinuousTimeStochasticProcess, MarkovProcess): + """ + Represents continuous time Markov chain. + + Parameters + ========== + + sym : Symbol/str + state_space : Set + Optional, by default, S.Reals + gen_mat : Matrix/ImmutableMatrix/MatrixSymbol + Optional, by default, None + + Examples + ======== + + >>> from sympy.stats import ContinuousMarkovChain, P + >>> from sympy import Matrix, S, Eq, Gt + >>> G = Matrix([[-S(1), S(1)], [S(1), -S(1)]]) + >>> C = ContinuousMarkovChain('C', state_space=[0, 1], gen_mat=G) + >>> C.limiting_distribution() + Matrix([[1/2, 1/2]]) + >>> C.state_space + {0, 1} + >>> C.generator_matrix + Matrix([ + [-1, 1], + [ 1, -1]]) + + Probability queries are supported + + >>> P(Eq(C(1.96), 0), Eq(C(0.78), 1)).round(5) + 0.45279 + >>> P(Gt(C(1.7), 0), Eq(C(0.82), 1)).round(5) + 0.58602 + + Probability of expressions with multiple RandomIndexedSymbols + can also be calculated provided there is only 1 RandomIndexedSymbol + in the given condition. It is always better to use Rational instead + of floating point numbers for the probabilities in the + generator matrix to avoid errors. + + >>> from sympy import Gt, Le, Rational + >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) + >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) + >>> P(Eq(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5) + 0.37933 + >>> P(Gt(C(3.92), C(1.75)), Eq(C(0.46), 0)).round(5) + 0.34211 + >>> P(Le(C(1.57), C(3.14)), Eq(C(1.22), 1)).round(4) + 0.7143 + + Symbolic probability queries are also supported + + >>> from sympy import symbols + >>> a,b,c,d = symbols('a b c d') + >>> G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) + >>> C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) + >>> query = P(Eq(C(a), b), Eq(C(c), d)) + >>> query.subs({a:3.65, b:2, c:1.78, d:1}).evalf().round(10) + 0.4002723175 + >>> P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10) + 0.4002723175 + >>> query_gt = P(Gt(C(a), b), Eq(C(c), d)) + >>> query_gt.subs({a:43.2, b:0, c:3.29, d:2}).evalf().round(10) + 0.6832579186 + >>> P(Gt(C(43.2), 0), Eq(C(3.29), 2)).round(10) + 0.6832579186 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Markov_chain#Continuous-time_Markov_chain + .. [2] https://u.math.biu.ac.il/~amirgi/CTMCnotes.pdf + """ + index_set = S.Reals + + def __new__(cls, sym, state_space=None, gen_mat=None): + sym = _symbol_converter(sym) + state_space, gen_mat = MarkovProcess._sanity_checks(state_space, gen_mat) + obj = Basic.__new__(cls, sym, state_space, gen_mat) + indices = {} + if isinstance(obj.number_of_states, Integer): + for index, state in enumerate(obj.state_space): + indices[state] = index + obj.index_of = indices + return obj + + @property + def generator_matrix(self): + return self.args[2] + + @cacheit + def transition_probabilities(self, gen_mat=None): + t = Dummy('t') + if isinstance(gen_mat, (Matrix, ImmutableMatrix)) and \ + gen_mat.is_diagonalizable(): + # for faster computation use diagonalized generator matrix + Q, D = gen_mat.diagonalize() + return Lambda(t, Q*exp(t*D)*Q.inv()) + if gen_mat != None: + return Lambda(t, exp(t*gen_mat)) + + def limiting_distribution(self): + gen_mat = self.generator_matrix + if gen_mat is None: + return None + if isinstance(gen_mat, MatrixSymbol): + wm = MatrixSymbol('wm', 1, gen_mat.shape[0]) + return Lambda((wm, gen_mat), Eq(wm*gen_mat, wm)) + w = IndexedBase('w') + wi = [w[i] for i in range(gen_mat.shape[0])] + wm = Matrix([wi]) + eqs = (wm*gen_mat).tolist()[0] + eqs.append(sum(wi) - 1) + soln = list(linsolve(eqs, wi))[0] + return ImmutableMatrix([soln]) + + +class BernoulliProcess(DiscreteTimeStochasticProcess): + """ + The Bernoulli process consists of repeated + independent Bernoulli process trials with the same parameter `p`. + It's assumed that the probability `p` applies to every + trial and that the outcomes of each trial + are independent of all the rest. Therefore Bernoulli Process + is Discrete State and Discrete Time Stochastic Process. + + Parameters + ========== + + sym : Symbol/str + success : Integer/str + The event which is considered to be success. Default: 1. + failure: Integer/str + The event which is considered to be failure. Default: 0. + p : Real Number between 0 and 1 + Represents the probability of getting success. + + Examples + ======== + + >>> from sympy.stats import BernoulliProcess, P, E + >>> from sympy import Eq, Gt + >>> B = BernoulliProcess("B", p=0.7, success=1, failure=0) + >>> B.state_space + {0, 1} + >>> B.p.round(2) + 0.70 + >>> B.success + 1 + >>> B.failure + 0 + >>> X = B[1] + B[2] + B[3] + >>> P(Eq(X, 0)).round(2) + 0.03 + >>> P(Eq(X, 2)).round(2) + 0.44 + >>> P(Eq(X, 4)).round(2) + 0 + >>> P(Gt(X, 1)).round(2) + 0.78 + >>> P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0) & Eq(B[4], 1)).round(2) + 0.04 + >>> B.joint_distribution(B[1], B[2]) + JointDistributionHandmade(Lambda((B[1], B[2]), Piecewise((0.7, Eq(B[1], 1)), + (0.3, Eq(B[1], 0)), (0, True))*Piecewise((0.7, Eq(B[2], 1)), (0.3, Eq(B[2], 0)), + (0, True)))) + >>> E(2*B[1] + B[2]).round(2) + 2.10 + >>> P(B[1] < 1).round(2) + 0.30 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bernoulli_process + .. [2] https://mathcs.clarku.edu/~djoyce/ma217/bernoulli.pdf + + """ + + index_set = S.Naturals0 + + def __new__(cls, sym, p, success=1, failure=0): + _value_check(p >= 0 and p <= 1, 'Value of p must be between 0 and 1.') + sym = _symbol_converter(sym) + p = _sympify(p) + success = _sym_sympify(success) + failure = _sym_sympify(failure) + return Basic.__new__(cls, sym, p, success, failure) + + @property + def symbol(self): + return self.args[0] + + @property + def p(self): + return self.args[1] + + @property + def success(self): + return self.args[2] + + @property + def failure(self): + return self.args[3] + + @property + def state_space(self): + return _set_converter([self.success, self.failure]) + + def distribution(self, key=None): + if key is None: + self._deprecation_warn_distribution() + return BernoulliDistribution(self.p) + return BernoulliDistribution(self.p, self.success, self.failure) + + def simple_rv(self, rv): + return Bernoulli(rv.name, p=self.p, + succ=self.success, fail=self.failure) + + def expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Computes expectation. + + Parameters + ========== + + expr : RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition : Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Expectation of the RandomIndexedSymbol. + + """ + + return _SubstituteRV._expectation(expr, condition, evaluate, **kwargs) + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Computes probability. + + Parameters + ========== + + condition : Relational + Condition for which probability has to be computed. Must + contain a RandomIndexedSymbol of the process. + given_condition : Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Probability of the condition. + + """ + + return _SubstituteRV._probability(condition, given_condition, evaluate, **kwargs) + + def density(self, x): + return Piecewise((self.p, Eq(x, self.success)), + (1 - self.p, Eq(x, self.failure)), + (S.Zero, True)) + +class _SubstituteRV: + """ + Internal class to handle the queries of expectation and probability + by substitution. + """ + + @staticmethod + def _rvindexed_subs(expr, condition=None): + """ + Substitutes the RandomIndexedSymbol with the RandomSymbol with + same name, distribution and probability as RandomIndexedSymbol. + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Logic + The given conditions under which computations should be done. + + """ + + rvs_expr = random_symbols(expr) + if len(rvs_expr) != 0: + swapdict_expr = {} + for rv in rvs_expr: + if isinstance(rv, RandomIndexedSymbol): + newrv = rv.pspace.process.simple_rv(rv) # substitute with equivalent simple rv + swapdict_expr[rv] = newrv + expr = expr.subs(swapdict_expr) + rvs_cond = random_symbols(condition) + if len(rvs_cond)!=0: + swapdict_cond = {} + for rv in rvs_cond: + if isinstance(rv, RandomIndexedSymbol): + newrv = rv.pspace.process.simple_rv(rv) + swapdict_cond[rv] = newrv + condition = condition.subs(swapdict_cond) + return expr, condition + + @classmethod + def _expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Internal method for computing expectation of indexed RV. + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Logic + The given conditions under which computations should be done. + + Returns + ======= + + Expectation of the RandomIndexedSymbol. + + """ + new_expr, new_condition = self._rvindexed_subs(expr, condition) + + if not is_random(new_expr): + return new_expr + new_pspace = pspace(new_expr) + if new_condition is not None: + new_expr = given(new_expr, new_condition) + if new_expr.is_Add: # As E is Linear + return Add(*[new_pspace.compute_expectation( + expr=arg, evaluate=evaluate, **kwargs) + for arg in new_expr.args]) + return new_pspace.compute_expectation( + new_expr, evaluate=evaluate, **kwargs) + + @classmethod + def _probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Internal method for computing probability of indexed RV + + Parameters + ========== + + condition: Relational + Condition for which probability has to be computed. Must + contain a RandomIndexedSymbol of the process. + given_condition: Relational/And + The given conditions under which computations should be done. + + Returns + ======= + + Probability of the condition. + + """ + new_condition, new_givencondition = self._rvindexed_subs(condition, given_condition) + + if isinstance(new_givencondition, RandomSymbol): + condrv = random_symbols(new_condition) + if len(condrv) == 1 and condrv[0] == new_givencondition: + return BernoulliDistribution(self._probability(new_condition), 0, 1) + + if any(dependent(rv, new_givencondition) for rv in condrv): + return Probability(new_condition, new_givencondition) + else: + return self._probability(new_condition) + + if new_givencondition is not None and \ + not isinstance(new_givencondition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (new_givencondition)) + if new_givencondition == False or new_condition == False: + return S.Zero + if new_condition == True: + return S.One + if not isinstance(new_condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (new_condition)) + + if new_givencondition is not None: # If there is a condition + # Recompute on new conditional expr + return self._probability(given(new_condition, new_givencondition, **kwargs), **kwargs) + result = pspace(new_condition).probability(new_condition, **kwargs) + if evaluate and hasattr(result, 'doit'): + return result.doit() + else: + return result + +def get_timerv_swaps(expr, condition): + """ + Finds the appropriate interval for each time stamp in expr by parsing + the given condition and returns intervals for each timestamp and + dictionary that maps variable time-stamped Random Indexed Symbol to its + corresponding Random Indexed variable with fixed time stamp. + + Parameters + ========== + + expr: SymPy Expression + Expression containing Random Indexed Symbols with variable time stamps + condition: Relational/Boolean Expression + Expression containing time bounds of variable time stamps in expr + + Examples + ======== + + >>> from sympy.stats.stochastic_process_types import get_timerv_swaps, PoissonProcess + >>> from sympy import symbols, Contains, Interval + >>> x, t, d = symbols('x t d', positive=True) + >>> X = PoissonProcess("X", 3) + >>> get_timerv_swaps(x*X(t), Contains(t, Interval.Lopen(0, 1))) + ([Interval.Lopen(0, 1)], {X(t): X(1)}) + >>> get_timerv_swaps((X(t)**2 + X(d)**2), Contains(t, Interval.Lopen(0, 1)) + ... & Contains(d, Interval.Ropen(1, 4))) # doctest: +SKIP + ([Interval.Ropen(1, 4), Interval.Lopen(0, 1)], {X(d): X(3), X(t): X(1)}) + + Returns + ======= + + intervals: list + List of Intervals/FiniteSet on which each time stamp is defined + rv_swap: dict + Dictionary mapping variable time Random Indexed Symbol to constant time + Random Indexed Variable + + """ + + if not isinstance(condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (condition)) + expr_syms = list(expr.atoms(RandomIndexedSymbol)) + if isinstance(condition, (And, Or)): + given_cond_args = condition.args + else: # single condition + given_cond_args = (condition, ) + rv_swap = {} + intervals = [] + for expr_sym in expr_syms: + for arg in given_cond_args: + if arg.has(expr_sym.key) and isinstance(expr_sym.key, Symbol): + intv = _set_converter(arg.args[1]) + diff_key = intv._sup - intv._inf + if diff_key == oo: + raise ValueError("%s should have finite bounds" % str(expr_sym.name)) + elif diff_key == S.Zero: # has singleton set + diff_key = intv._sup + rv_swap[expr_sym] = expr_sym.subs({expr_sym.key: diff_key}) + intervals.append(intv) + return intervals, rv_swap + + +class CountingProcess(ContinuousTimeStochasticProcess): + """ + This class handles the common methods of the Counting Processes + such as Poisson, Wiener and Gamma Processes + """ + index_set = _set_converter(Interval(0, oo)) + + @property + def symbol(self): + return self.args[0] + + def expectation(self, expr, condition=None, evaluate=True, **kwargs): + """ + Computes expectation + + Parameters + ========== + + expr: RandomIndexedSymbol, Relational, Logic + Condition for which expectation has to be computed. Must + contain a RandomIndexedSymbol of the process. + condition: Relational, Boolean + The given conditions under which computations should be done, i.e, + the intervals on which each variable time stamp in expr is defined + + Returns + ======= + + Expectation of the given expr + + """ + if condition is not None: + intervals, rv_swap = get_timerv_swaps(expr, condition) + # they are independent when they have non-overlapping intervals + if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet + for intv_comb in itertools.combinations(intervals, 2)): + if expr.is_Add: + return Add.fromiter(self.expectation(arg, condition) + for arg in expr.args) + expr = expr.subs(rv_swap) + else: + return Expectation(expr, condition) + + return _SubstituteRV._expectation(expr, evaluate=evaluate, **kwargs) + + def _solve_argwith_tworvs(self, arg): + if arg.args[0].key >= arg.args[1].key or isinstance(arg, Eq): + diff_key = abs(arg.args[0].key - arg.args[1].key) + rv = arg.args[0] + arg = arg.__class__(rv.pspace.process(diff_key), 0) + else: + diff_key = arg.args[1].key - arg.args[0].key + rv = arg.args[1] + arg = arg.__class__(rv.pspace.process(diff_key), 0) + return arg + + def _solve_numerical(self, condition, given_condition=None): + if isinstance(condition, And): + args_list = list(condition.args) + else: + args_list = [condition] + if given_condition is not None: + if isinstance(given_condition, And): + args_list.extend(list(given_condition.args)) + else: + args_list.extend([given_condition]) + # sort the args based on timestamp to get the independent increments in + # each segment using all the condition args as well as given_condition args + args_list = sorted(args_list, key=lambda x: x.args[0].key) + result = [] + cond_args = list(condition.args) if isinstance(condition, And) else [condition] + if args_list[0] in cond_args and not (is_random(args_list[0].args[0]) + and is_random(args_list[0].args[1])): + result.append(_SubstituteRV._probability(args_list[0])) + + if is_random(args_list[0].args[0]) and is_random(args_list[0].args[1]): + arg = self._solve_argwith_tworvs(args_list[0]) + result.append(_SubstituteRV._probability(arg)) + + for i in range(len(args_list) - 1): + curr, nex = args_list[i], args_list[i + 1] + diff_key = nex.args[0].key - curr.args[0].key + working_set = curr.args[0].pspace.process.state_space + if curr.args[1] > nex.args[1]: #impossible condition so return 0 + result.append(0) + break + if isinstance(curr, Eq): + working_set = Intersection(working_set, Interval.Lopen(curr.args[1], oo)) + else: + working_set = Intersection(working_set, curr.as_set()) + if isinstance(nex, Eq): + working_set = Intersection(working_set, Interval(-oo, nex.args[1])) + else: + working_set = Intersection(working_set, nex.as_set()) + if working_set == EmptySet: + rv = Eq(curr.args[0].pspace.process(diff_key), 0) + result.append(_SubstituteRV._probability(rv)) + else: + if working_set.is_finite_set: + if isinstance(curr, Eq) and isinstance(nex, Eq): + rv = Eq(curr.args[0].pspace.process(diff_key), len(working_set)) + result.append(_SubstituteRV._probability(rv)) + elif isinstance(curr, Eq) ^ isinstance(nex, Eq): + result.append(Add.fromiter(_SubstituteRV._probability(Eq( + curr.args[0].pspace.process(diff_key), x)) + for x in range(len(working_set)))) + else: + n = len(working_set) + result.append(Add.fromiter((n - x)*_SubstituteRV._probability(Eq( + curr.args[0].pspace.process(diff_key), x)) for x in range(n))) + else: + result.append(_SubstituteRV._probability( + curr.args[0].pspace.process(diff_key) <= working_set._sup - working_set._inf)) + return Mul.fromiter(result) + + + def probability(self, condition, given_condition=None, evaluate=True, **kwargs): + """ + Computes probability. + + Parameters + ========== + + condition: Relational + Condition for which probability has to be computed. Must + contain a RandomIndexedSymbol of the process. + given_condition: Relational, Boolean + The given conditions under which computations should be done, i.e, + the intervals on which each variable time stamp in expr is defined + + Returns + ======= + + Probability of the condition + + """ + check_numeric = True + if isinstance(condition, (And, Or)): + cond_args = condition.args + else: + cond_args = (condition, ) + # check that condition args are numeric or not + if not all(arg.args[0].key.is_number for arg in cond_args): + check_numeric = False + if given_condition is not None: + check_given_numeric = True + if isinstance(given_condition, (And, Or)): + given_cond_args = given_condition.args + else: + given_cond_args = (given_condition, ) + # check that given condition args are numeric or not + if given_condition.has(Contains): + check_given_numeric = False + # Handle numerical queries + if check_numeric and check_given_numeric: + res = [] + if isinstance(condition, Or): + res.append(Add.fromiter(self._solve_numerical(arg, given_condition) + for arg in condition.args)) + if isinstance(given_condition, Or): + res.append(Add.fromiter(self._solve_numerical(condition, arg) + for arg in given_condition.args)) + if res: + return Add.fromiter(res) + return self._solve_numerical(condition, given_condition) + + # No numeric queries, go by Contains?... then check that all the + # given condition are in form of `Contains` + if not all(arg.has(Contains) for arg in given_cond_args): + raise ValueError("If given condition is passed with `Contains`, then " + "please pass the evaluated condition with its corresponding information " + "in terms of intervals of each time stamp to be passed in given condition.") + + intervals, rv_swap = get_timerv_swaps(condition, given_condition) + # they are independent when they have non-overlapping intervals + if len(intervals) == 1 or all(Intersection(*intv_comb) == EmptySet + for intv_comb in itertools.combinations(intervals, 2)): + if isinstance(condition, And): + return Mul.fromiter(self.probability(arg, given_condition) + for arg in condition.args) + elif isinstance(condition, Or): + return Add.fromiter(self.probability(arg, given_condition) + for arg in condition.args) + condition = condition.subs(rv_swap) + else: + return Probability(condition, given_condition) + if check_numeric: + return self._solve_numerical(condition) + return _SubstituteRV._probability(condition, evaluate=evaluate, **kwargs) + +class PoissonProcess(CountingProcess): + """ + The Poisson process is a counting process. It is usually used in scenarios + where we are counting the occurrences of certain events that appear + to happen at a certain rate, but completely at random. + + Parameters + ========== + + sym : Symbol/str + lamda : Positive number + Rate of the process, ``lambda > 0`` + + Examples + ======== + + >>> from sympy.stats import PoissonProcess, P, E + >>> from sympy import symbols, Eq, Ne, Contains, Interval + >>> X = PoissonProcess("X", lamda=3) + >>> X.state_space + Naturals0 + >>> X.lamda + 3 + >>> t1, t2 = symbols('t1 t2', positive=True) + >>> P(X(t1) < 4) + (9*t1**3/2 + 9*t1**2/2 + 3*t1 + 1)*exp(-3*t1) + >>> P(Eq(X(t1), 2) | Ne(X(t1), 4), Contains(t1, Interval.Ropen(2, 4))) + 1 - 36*exp(-6) + >>> P(Eq(X(t1), 2) & Eq(X(t2), 3), Contains(t1, Interval.Lopen(0, 2)) + ... & Contains(t2, Interval.Lopen(2, 4))) + 648*exp(-12) + >>> E(X(t1)) + 3*t1 + >>> E(X(t1)**2 + 2*X(t2), Contains(t1, Interval.Lopen(0, 1)) + ... & Contains(t2, Interval.Lopen(1, 2))) + 18 + >>> P(X(3) < 1, Eq(X(1), 0)) + exp(-6) + >>> P(Eq(X(4), 3), Eq(X(2), 3)) + exp(-6) + >>> P(X(2) <= 3, X(1) > 1) + 5*exp(-3) + + Merging two Poisson Processes + + >>> Y = PoissonProcess("Y", lamda=4) + >>> Z = X + Y + >>> Z.lamda + 7 + + Splitting a Poisson Process into two independent Poisson Processes + + >>> N, M = Z.split(l1=2, l2=5) + >>> N.lamda, M.lamda + (2, 5) + + References + ========== + + .. [1] https://www.probabilitycourse.com/chapter11/11_0_0_intro.php + .. [2] https://en.wikipedia.org/wiki/Poisson_point_process + + """ + + def __new__(cls, sym, lamda): + _value_check(lamda > 0, 'lamda should be a positive number.') + sym = _symbol_converter(sym) + lamda = _sympify(lamda) + return Basic.__new__(cls, sym, lamda) + + @property + def lamda(self): + return self.args[1] + + @property + def state_space(self): + return S.Naturals0 + + def distribution(self, key): + if isinstance(key, RandomIndexedSymbol): + self._deprecation_warn_distribution() + return PoissonDistribution(self.lamda*key.key) + return PoissonDistribution(self.lamda*key) + + def density(self, x): + return (self.lamda*x.key)**x / factorial(x) * exp(-(self.lamda*x.key)) + + def simple_rv(self, rv): + return Poisson(rv.name, lamda=self.lamda*rv.key) + + def __add__(self, other): + if not isinstance(other, PoissonProcess): + raise ValueError("Only instances of Poisson Process can be merged") + return PoissonProcess(Dummy(self.symbol.name + other.symbol.name), + self.lamda + other.lamda) + + def split(self, l1, l2): + if _sympify(l1 + l2) != self.lamda: + raise ValueError("Sum of l1 and l2 should be %s" % str(self.lamda)) + return PoissonProcess(Dummy("l1"), l1), PoissonProcess(Dummy("l2"), l2) + +class WienerProcess(CountingProcess): + """ + The Wiener process is a real valued continuous-time stochastic process. + In physics it is used to study Brownian motion and it is often also called + Brownian motion due to its historical connection with physical process of the + same name originally observed by Scottish botanist Robert Brown. + + Parameters + ========== + + sym : Symbol/str + + Examples + ======== + + >>> from sympy.stats import WienerProcess, P, E + >>> from sympy import symbols, Contains, Interval + >>> X = WienerProcess("X") + >>> X.state_space + Reals + >>> t1, t2 = symbols('t1 t2', positive=True) + >>> P(X(t1) < 7).simplify() + erf(7*sqrt(2)/(2*sqrt(t1)))/2 + 1/2 + >>> P((X(t1) > 2) | (X(t1) < 4), Contains(t1, Interval.Ropen(2, 4))).simplify() + -erf(1)/2 + erf(2)/2 + 1 + >>> E(X(t1)) + 0 + >>> E(X(t1) + 2*X(t2), Contains(t1, Interval.Lopen(0, 1)) + ... & Contains(t2, Interval.Lopen(1, 2))) + 0 + + References + ========== + + .. [1] https://www.probabilitycourse.com/chapter11/11_4_0_brownian_motion_wiener_process.php + .. [2] https://en.wikipedia.org/wiki/Wiener_process + + """ + def __new__(cls, sym): + sym = _symbol_converter(sym) + return Basic.__new__(cls, sym) + + @property + def state_space(self): + return S.Reals + + def distribution(self, key): + if isinstance(key, RandomIndexedSymbol): + self._deprecation_warn_distribution() + return NormalDistribution(0, sqrt(key.key)) + return NormalDistribution(0, sqrt(key)) + + def density(self, x): + return exp(-x**2/(2*x.key)) / (sqrt(2*pi)*sqrt(x.key)) + + def simple_rv(self, rv): + return Normal(rv.name, 0, sqrt(rv.key)) + + +class GammaProcess(CountingProcess): + r""" + A Gamma process is a random process with independent gamma distributed + increments. It is a pure-jump increasing Levy process. + + Parameters + ========== + + sym : Symbol/str + lamda : Positive number + Jump size of the process, ``lamda > 0`` + gamma : Positive number + Rate of jump arrivals, `\gamma > 0` + + Examples + ======== + + >>> from sympy.stats import GammaProcess, E, P, variance + >>> from sympy import symbols, Contains, Interval, Not + >>> t, d, x, l, g = symbols('t d x l g', positive=True) + >>> X = GammaProcess("X", l, g) + >>> E(X(t)) + g*t/l + >>> variance(X(t)).simplify() + g*t/l**2 + >>> X = GammaProcess('X', 1, 2) + >>> P(X(t) < 1).simplify() + lowergamma(2*t, 1)/gamma(2*t) + >>> P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & + ... Contains(d, Interval.Lopen(7, 8))).simplify() + -4*exp(-3) + 472*exp(-8)/3 + 1 + >>> E(X(2) + x*E(X(5))) + 10*x + 4 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_process + + """ + def __new__(cls, sym, lamda, gamma): + _value_check(lamda > 0, 'lamda should be a positive number') + _value_check(gamma > 0, 'gamma should be a positive number') + sym = _symbol_converter(sym) + gamma = _sympify(gamma) + lamda = _sympify(lamda) + return Basic.__new__(cls, sym, lamda, gamma) + + @property + def lamda(self): + return self.args[1] + + @property + def gamma(self): + return self.args[2] + + @property + def state_space(self): + return _set_converter(Interval(0, oo)) + + def distribution(self, key): + if isinstance(key, RandomIndexedSymbol): + self._deprecation_warn_distribution() + return GammaDistribution(self.gamma*key.key, 1/self.lamda) + return GammaDistribution(self.gamma*key, 1/self.lamda) + + def density(self, x): + k = self.gamma*x.key + theta = 1/self.lamda + return x**(k - 1) * exp(-x/theta) / (gamma(k)*theta**k) + + def simple_rv(self, rv): + return Gamma(rv.name, self.gamma*rv.key, 1/self.lamda) diff --git a/phivenv/Lib/site-packages/sympy/stats/symbolic_multivariate_probability.py b/phivenv/Lib/site-packages/sympy/stats/symbolic_multivariate_probability.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe8776e58e82489e29734cea48c9138bc512f34 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/symbolic_multivariate_probability.py @@ -0,0 +1,308 @@ +import itertools + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand as _expand +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.stats.rv import RandomSymbol, is_random +from sympy.core.sympify import _sympify +from sympy.stats.symbolic_probability import Variance, Covariance, Expectation + + +class ExpectationMatrix(Expectation, MatrixExpr): + """ + Expectation of a random matrix expression. + + Examples + ======== + + >>> from sympy.stats import ExpectationMatrix, Normal + >>> from sympy.stats.rv import RandomMatrixSymbol + >>> from sympy import symbols, MatrixSymbol, Matrix + >>> k = symbols("k") + >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k) + >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1) + >>> ExpectationMatrix(X) + ExpectationMatrix(X) + >>> ExpectationMatrix(A*X).shape + (k, 1) + + To expand the expectation in its expression, use ``expand()``: + + >>> ExpectationMatrix(A*X + B*Y).expand() + A*ExpectationMatrix(X) + B*ExpectationMatrix(Y) + >>> ExpectationMatrix((X + Y)*(X - Y).T).expand() + ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) + ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T) + + To evaluate the ``ExpectationMatrix``, use ``doit()``: + + >>> N11, N12 = Normal('N11', 11, 1), Normal('N12', 12, 1) + >>> N21, N22 = Normal('N21', 21, 1), Normal('N22', 22, 1) + >>> M11, M12 = Normal('M11', 1, 1), Normal('M12', 2, 1) + >>> M21, M22 = Normal('M21', 3, 1), Normal('M22', 4, 1) + >>> x1 = Matrix([[N11, N12], [N21, N22]]) + >>> x2 = Matrix([[M11, M12], [M21, M22]]) + >>> ExpectationMatrix(x1 + x2).doit() + Matrix([ + [12, 14], + [24, 26]]) + + """ + def __new__(cls, expr, condition=None): + expr = _sympify(expr) + if condition is None: + if not is_random(expr): + return expr + obj = Expr.__new__(cls, expr) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, expr, condition) + + obj._shape = expr.shape + obj._condition = condition + return obj + + @property + def shape(self): + return self._shape + + def expand(self, **hints): + expr = self.args[0] + condition = self._condition + if not is_random(expr): + return expr + + if isinstance(expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expr.args) + + expand_expr = _expand(expr) + if isinstance(expand_expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expand_expr.args) + + elif isinstance(expr, (Mul, MatMul)): + rv = [] + nonrv = [] + postnon = [] + + for a in expr.args: + if is_random(a): + if rv: + rv.extend(postnon) + else: + nonrv.extend(postnon) + postnon = [] + rv.append(a) + elif a.is_Matrix: + postnon.append(a) + else: + nonrv.append(a) + + # In order to avoid infinite-looping (MatMul may call .doit() again), + # do not rebuild + if len(nonrv) == 0: + return self + return Mul.fromiter(nonrv)*Expectation(Mul.fromiter(rv), + condition=condition)*Mul.fromiter(postnon) + + return self + +class VarianceMatrix(Variance, MatrixExpr): + """ + Variance of a random matrix probability expression. Also known as + Covariance matrix, auto-covariance matrix, dispersion matrix, + or variance-covariance matrix. + + Examples + ======== + + >>> from sympy.stats import VarianceMatrix + >>> from sympy.stats.rv import RandomMatrixSymbol + >>> from sympy import symbols, MatrixSymbol + >>> k = symbols("k") + >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k) + >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1) + >>> VarianceMatrix(X) + VarianceMatrix(X) + >>> VarianceMatrix(X).shape + (k, k) + + To expand the variance in its expression, use ``expand()``: + + >>> VarianceMatrix(A*X).expand() + A*VarianceMatrix(X)*A.T + >>> VarianceMatrix(A*X + B*Y).expand() + 2*A*CrossCovarianceMatrix(X, Y)*B.T + A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T + """ + def __new__(cls, arg, condition=None): + arg = _sympify(arg) + + if 1 not in arg.shape: + raise ShapeError("Expression is not a vector") + + shape = (arg.shape[0], arg.shape[0]) if arg.shape[1] == 1 else (arg.shape[1], arg.shape[1]) + + if condition: + obj = Expr.__new__(cls, arg, condition) + else: + obj = Expr.__new__(cls, arg) + + obj._shape = shape + obj._condition = condition + return obj + + @property + def shape(self): + return self._shape + + def expand(self, **hints): + arg = self.args[0] + condition = self._condition + + if not is_random(arg): + return ZeroMatrix(*self.shape) + + if isinstance(arg, RandomSymbol): + return self + elif isinstance(arg, Add): + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + variances = Add(*(Variance(xv, condition).expand() for xv in rv)) + map_to_covar = lambda x: 2*Covariance(*x, condition=condition).expand() + covariances = Add(*map(map_to_covar, itertools.combinations(rv, 2))) + return variances + covariances + elif isinstance(arg, (Mul, MatMul)): + nonrv = [] + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + if len(rv) == 0: + return ZeroMatrix(*self.shape) + # Avoid possible infinite loops with MatMul: + if len(nonrv) == 0: + return self + # Variance of many multiple matrix products is not implemented: + if len(rv) > 1: + return self + return Mul.fromiter(nonrv)*Variance(Mul.fromiter(rv), + condition)*(Mul.fromiter(nonrv)).transpose() + + # this expression contains a RandomSymbol somehow: + return self + +class CrossCovarianceMatrix(Covariance, MatrixExpr): + """ + Covariance of a random matrix probability expression. + + Examples + ======== + + >>> from sympy.stats import CrossCovarianceMatrix + >>> from sympy.stats.rv import RandomMatrixSymbol + >>> from sympy import symbols, MatrixSymbol + >>> k = symbols("k") + >>> A, B = MatrixSymbol("A", k, k), MatrixSymbol("B", k, k) + >>> C, D = MatrixSymbol("C", k, k), MatrixSymbol("D", k, k) + >>> X, Y = RandomMatrixSymbol("X", k, 1), RandomMatrixSymbol("Y", k, 1) + >>> Z, W = RandomMatrixSymbol("Z", k, 1), RandomMatrixSymbol("W", k, 1) + >>> CrossCovarianceMatrix(X, Y) + CrossCovarianceMatrix(X, Y) + >>> CrossCovarianceMatrix(X, Y).shape + (k, k) + + To expand the covariance in its expression, use ``expand()``: + + >>> CrossCovarianceMatrix(X + Y, Z).expand() + CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z) + >>> CrossCovarianceMatrix(A*X, Y).expand() + A*CrossCovarianceMatrix(X, Y) + >>> CrossCovarianceMatrix(A*X, B.T*Y).expand() + A*CrossCovarianceMatrix(X, Y)*B + >>> CrossCovarianceMatrix(A*X + B*Y, C.T*Z + D.T*W).expand() + A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C + + """ + def __new__(cls, arg1, arg2, condition=None): + arg1 = _sympify(arg1) + arg2 = _sympify(arg2) + + if (1 not in arg1.shape) or (1 not in arg2.shape) or (arg1.shape[1] != arg2.shape[1]): + raise ShapeError("Expression is not a vector") + + shape = (arg1.shape[0], arg2.shape[0]) if arg1.shape[1] == 1 and arg2.shape[1] == 1 \ + else (1, 1) + + if condition: + obj = Expr.__new__(cls, arg1, arg2, condition) + else: + obj = Expr.__new__(cls, arg1, arg2) + + obj._shape = shape + obj._condition = condition + return obj + + @property + def shape(self): + return self._shape + + def expand(self, **hints): + arg1 = self.args[0] + arg2 = self.args[1] + condition = self._condition + + if arg1 == arg2: + return VarianceMatrix(arg1, condition).expand() + + if not is_random(arg1) or not is_random(arg2): + return ZeroMatrix(*self.shape) + + if isinstance(arg1, RandomSymbol) and isinstance(arg2, RandomSymbol): + return CrossCovarianceMatrix(arg1, arg2, condition) + + coeff_rv_list1 = self._expand_single_argument(arg1.expand()) + coeff_rv_list2 = self._expand_single_argument(arg2.expand()) + + addends = [a*CrossCovarianceMatrix(r1, r2, condition=condition)*b.transpose() + for (a, r1) in coeff_rv_list1 for (b, r2) in coeff_rv_list2] + return Add.fromiter(addends) + + @classmethod + def _expand_single_argument(cls, expr): + # return (coefficient, random_symbol) pairs: + if isinstance(expr, RandomSymbol): + return [(S.One, expr)] + elif isinstance(expr, Add): + outval = [] + for a in expr.args: + if isinstance(a, (Mul, MatMul)): + outval.append(cls._get_mul_nonrv_rv_tuple(a)) + elif is_random(a): + outval.append((S.One, a)) + + return outval + elif isinstance(expr, (Mul, MatMul)): + return [cls._get_mul_nonrv_rv_tuple(expr)] + elif is_random(expr): + return [(S.One, expr)] + + @classmethod + def _get_mul_nonrv_rv_tuple(cls, m): + rv = [] + nonrv = [] + for a in m.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + return (Mul.fromiter(nonrv), Mul.fromiter(rv)) diff --git a/phivenv/Lib/site-packages/sympy/stats/symbolic_probability.py b/phivenv/Lib/site-packages/sympy/stats/symbolic_probability.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0b971a8691f82de15258d4c460129059eaf436 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/symbolic_probability.py @@ -0,0 +1,698 @@ +import itertools +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand as _expand +from sympy.core.mul import Mul +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import Not +from sympy.core.parameters import global_parameters +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import _sympify +from sympy.core.relational import Relational +from sympy.logic.boolalg import Boolean +from sympy.stats import variance, covariance +from sympy.stats.rv import (RandomSymbol, pspace, dependent, + given, sampling_E, RandomIndexedSymbol, is_random, + PSpace, sampling_P, random_symbols) + +__all__ = ['Probability', 'Expectation', 'Variance', 'Covariance'] + + +@is_random.register(Expr) +def _(x): + atoms = x.free_symbols + if len(atoms) == 1 and next(iter(atoms)) == x: + return False + return any(is_random(i) for i in atoms) + +@is_random.register(RandomSymbol) # type: ignore +def _(x): + return True + + +class Probability(Expr): + """ + Symbolic expression for the probability. + + Examples + ======== + + >>> from sympy.stats import Probability, Normal + >>> from sympy import Integral + >>> X = Normal("X", 0, 1) + >>> prob = Probability(X > 1) + >>> prob + Probability(X > 1) + + Integral representation: + + >>> prob.rewrite(Integral) + Integral(sqrt(2)*exp(-_z**2/2)/(2*sqrt(pi)), (_z, 1, oo)) + + Evaluation of the integral: + + >>> prob.evaluate_integral() + sqrt(2)*(-sqrt(2)*sqrt(pi)*erf(sqrt(2)/2) + sqrt(2)*sqrt(pi))/(4*sqrt(pi)) + """ + + is_commutative = True + + def __new__(cls, prob, condition=None, **kwargs): + prob = _sympify(prob) + if condition is None: + obj = Expr.__new__(cls, prob) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, prob, condition) + obj._condition = condition + return obj + + def doit(self, **hints): + condition = self.args[0] + given_condition = self._condition + numsamples = hints.get('numsamples', False) + evaluate = hints.get('evaluate', True) + + if isinstance(condition, Not): + return S.One - self.func(condition.args[0], given_condition, + evaluate=evaluate).doit(**hints) + + if condition.has(RandomIndexedSymbol): + return pspace(condition).probability(condition, given_condition, + evaluate=evaluate) + + if isinstance(given_condition, RandomSymbol): + condrv = random_symbols(condition) + if len(condrv) == 1 and condrv[0] == given_condition: + from sympy.stats.frv_types import BernoulliDistribution + return BernoulliDistribution(self.func(condition).doit(**hints), 0, 1) + if any(dependent(rv, given_condition) for rv in condrv): + return Probability(condition, given_condition) + else: + return Probability(condition).doit() + + if given_condition is not None and \ + not isinstance(given_condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (given_condition)) + + if given_condition == False or condition is S.false: + return S.Zero + if not isinstance(condition, (Relational, Boolean)): + raise ValueError("%s is not a relational or combination of relationals" + % (condition)) + if condition is S.true: + return S.One + + if numsamples: + return sampling_P(condition, given_condition, numsamples=numsamples) + if given_condition is not None: # If there is a condition + # Recompute on new conditional expr + return Probability(given(condition, given_condition)).doit() + + # Otherwise pass work off to the ProbabilitySpace + if pspace(condition) == PSpace(): + return Probability(condition, given_condition) + + result = pspace(condition).probability(condition) + if hasattr(result, 'doit') and evaluate: + return result.doit() + else: + return result + + def _eval_rewrite_as_Integral(self, arg, condition=None, **kwargs): + return self.func(arg, condition=condition).doit(evaluate=False) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + +class Expectation(Expr): + """ + Symbolic expression for the expectation. + + Examples + ======== + + >>> from sympy.stats import Expectation, Normal, Probability, Poisson + >>> from sympy import symbols, Integral, Sum + >>> mu = symbols("mu") + >>> sigma = symbols("sigma", positive=True) + >>> X = Normal("X", mu, sigma) + >>> Expectation(X) + Expectation(X) + >>> Expectation(X).evaluate_integral().simplify() + mu + + To get the integral expression of the expectation: + + >>> Expectation(X).rewrite(Integral) + Integral(sqrt(2)*X*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + The same integral expression, in more abstract terms: + + >>> Expectation(X).rewrite(Probability) + Integral(x*Probability(Eq(X, x)), (x, -oo, oo)) + + To get the Summation expression of the expectation for discrete random variables: + + >>> lamda = symbols('lamda', positive=True) + >>> Z = Poisson('Z', lamda) + >>> Expectation(Z).rewrite(Sum) + Sum(Z*lamda**Z*exp(-lamda)/factorial(Z), (Z, 0, oo)) + + This class is aware of some properties of the expectation: + + >>> from sympy.abc import a + >>> Expectation(a*X) + Expectation(a*X) + >>> Y = Normal("Y", 1, 2) + >>> Expectation(X + Y) + Expectation(X + Y) + + To expand the ``Expectation`` into its expression, use ``expand()``: + + >>> Expectation(X + Y).expand() + Expectation(X) + Expectation(Y) + >>> Expectation(a*X + Y).expand() + a*Expectation(X) + Expectation(Y) + >>> Expectation(a*X + Y) + Expectation(a*X + Y) + >>> Expectation((X + Y)*(X - Y)).expand() + Expectation(X**2) - Expectation(Y**2) + + To evaluate the ``Expectation``, use ``doit()``: + + >>> Expectation(X + Y).doit() + mu + 1 + >>> Expectation(X + Expectation(Y + Expectation(2*X))).doit() + 3*mu + 1 + + To prevent evaluating nested ``Expectation``, use ``doit(deep=False)`` + + >>> Expectation(X + Expectation(Y)).doit(deep=False) + mu + Expectation(Expectation(Y)) + >>> Expectation(X + Expectation(Y + Expectation(2*X))).doit(deep=False) + mu + Expectation(Expectation(Expectation(2*X) + Y)) + + """ + + def __new__(cls, expr, condition=None, **kwargs): + expr = _sympify(expr) + if expr.is_Matrix: + from sympy.stats.symbolic_multivariate_probability import ExpectationMatrix + return ExpectationMatrix(expr, condition) + if condition is None: + if not is_random(expr): + return expr + obj = Expr.__new__(cls, expr) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, expr, condition) + obj._condition = condition + return obj + + def _eval_is_commutative(self): + return(self.args[0].is_commutative) + + def expand(self, **hints): + expr = self.args[0] + condition = self._condition + + if not is_random(expr): + return expr + + if isinstance(expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expr.args) + + expand_expr = _expand(expr) + if isinstance(expand_expr, Add): + return Add.fromiter(Expectation(a, condition=condition).expand() + for a in expand_expr.args) + + elif isinstance(expr, Mul): + rv = [] + nonrv = [] + for a in expr.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + return Mul.fromiter(nonrv)*Expectation(Mul.fromiter(rv), condition=condition) + + return self + + def doit(self, **hints): + deep = hints.get('deep', True) + condition = self._condition + expr = self.args[0] + numsamples = hints.get('numsamples', False) + evaluate = hints.get('evaluate', True) + + if deep: + expr = expr.doit(**hints) + + if not is_random(expr) or isinstance(expr, Expectation): # expr isn't random? + return expr + if numsamples: # Computing by monte carlo sampling? + evalf = hints.get('evalf', True) + return sampling_E(expr, condition, numsamples=numsamples, evalf=evalf) + + if expr.has(RandomIndexedSymbol): + return pspace(expr).compute_expectation(expr, condition) + + # Create new expr and recompute E + if condition is not None: # If there is a condition + return self.func(given(expr, condition)).doit(**hints) + + # A few known statements for efficiency + + if expr.is_Add: # We know that E is Linear + return Add(*[self.func(arg, condition).doit(**hints) + if not isinstance(arg, Expectation) else self.func(arg, condition) + for arg in expr.args]) + if expr.is_Mul: + if expr.atoms(Expectation): + return expr + + if pspace(expr) == PSpace(): + return self.func(expr) + # Otherwise case is simple, pass work off to the ProbabilitySpace + result = pspace(expr).compute_expectation(expr, evaluate=evaluate) + if hasattr(result, 'doit') and evaluate: + return result.doit(**hints) + else: + return result + + + def _eval_rewrite_as_Probability(self, arg, condition=None, **kwargs): + rvs = arg.atoms(RandomSymbol) + if len(rvs) > 1: + raise NotImplementedError() + if len(rvs) == 0: + return arg + + rv = rvs.pop() + if rv.pspace is None: + raise ValueError("Probability space not known") + + symbol = rv.symbol + if symbol.name[0].isupper(): + symbol = Symbol(symbol.name.lower()) + else : + symbol = Symbol(symbol.name + "_1") + + if rv.pspace.is_Continuous: + return Integral(arg.replace(rv, symbol)*Probability(Eq(rv, symbol), condition), (symbol, rv.pspace.domain.set.inf, rv.pspace.domain.set.sup)) + else: + if rv.pspace.is_Finite: + raise NotImplementedError + else: + return Sum(arg.replace(rv, symbol)*Probability(Eq(rv, symbol), condition), (symbol, rv.pspace.domain.set.inf, rv.pspace.set.sup)) + + def _eval_rewrite_as_Integral(self, arg, condition=None, evaluate=False, **kwargs): + return self.func(arg, condition=condition).doit(deep=False, evaluate=evaluate) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral # For discrete this will be Sum + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + evaluate_sum = evaluate_integral + +class Variance(Expr): + """ + Symbolic expression for the variance. + + Examples + ======== + + >>> from sympy import symbols, Integral + >>> from sympy.stats import Normal, Expectation, Variance, Probability + >>> mu = symbols("mu", positive=True) + >>> sigma = symbols("sigma", positive=True) + >>> X = Normal("X", mu, sigma) + >>> Variance(X) + Variance(X) + >>> Variance(X).evaluate_integral() + sigma**2 + + Integral representation of the underlying calculations: + + >>> Variance(X).rewrite(Integral) + Integral(sqrt(2)*(X - Integral(sqrt(2)*X*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)))**2*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + Integral representation, without expanding the PDF: + + >>> Variance(X).rewrite(Probability) + -Integral(x*Probability(Eq(X, x)), (x, -oo, oo))**2 + Integral(x**2*Probability(Eq(X, x)), (x, -oo, oo)) + + Rewrite the variance in terms of the expectation + + >>> Variance(X).rewrite(Expectation) + -Expectation(X)**2 + Expectation(X**2) + + Some transformations based on the properties of the variance may happen: + + >>> from sympy.abc import a + >>> Y = Normal("Y", 0, 1) + >>> Variance(a*X) + Variance(a*X) + + To expand the variance in its expression, use ``expand()``: + + >>> Variance(a*X).expand() + a**2*Variance(X) + >>> Variance(X + Y) + Variance(X + Y) + >>> Variance(X + Y).expand() + 2*Covariance(X, Y) + Variance(X) + Variance(Y) + + """ + def __new__(cls, arg, condition=None, **kwargs): + arg = _sympify(arg) + + if arg.is_Matrix: + from sympy.stats.symbolic_multivariate_probability import VarianceMatrix + return VarianceMatrix(arg, condition) + if condition is None: + obj = Expr.__new__(cls, arg) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, arg, condition) + obj._condition = condition + return obj + + def _eval_is_commutative(self): + return self.args[0].is_commutative + + def expand(self, **hints): + arg = self.args[0] + condition = self._condition + + if not is_random(arg): + return S.Zero + + if isinstance(arg, RandomSymbol): + return self + elif isinstance(arg, Add): + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + variances = Add(*(Variance(xv, condition).expand() for xv in rv)) + map_to_covar = lambda x: 2*Covariance(*x, condition=condition).expand() + covariances = Add(*map(map_to_covar, itertools.combinations(rv, 2))) + return variances + covariances + elif isinstance(arg, Mul): + nonrv = [] + rv = [] + for a in arg.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a**2) + if len(rv) == 0: + return S.Zero + return Mul.fromiter(nonrv)*Variance(Mul.fromiter(rv), condition) + + # this expression contains a RandomSymbol somehow: + return self + + def _eval_rewrite_as_Expectation(self, arg, condition=None, **kwargs): + e1 = Expectation(arg**2, condition) + e2 = Expectation(arg, condition)**2 + return e1 - e2 + + def _eval_rewrite_as_Probability(self, arg, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, arg, condition=None, **kwargs): + return variance(self.args[0], self._condition, evaluate=False) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + +class Covariance(Expr): + """ + Symbolic expression for the covariance. + + Examples + ======== + + >>> from sympy.stats import Covariance + >>> from sympy.stats import Normal + >>> X = Normal("X", 3, 2) + >>> Y = Normal("Y", 0, 1) + >>> Z = Normal("Z", 0, 1) + >>> W = Normal("W", 0, 1) + >>> cexpr = Covariance(X, Y) + >>> cexpr + Covariance(X, Y) + + Evaluate the covariance, `X` and `Y` are independent, + therefore zero is the result: + + >>> cexpr.evaluate_integral() + 0 + + Rewrite the covariance expression in terms of expectations: + + >>> from sympy.stats import Expectation + >>> cexpr.rewrite(Expectation) + Expectation(X*Y) - Expectation(X)*Expectation(Y) + + In order to expand the argument, use ``expand()``: + + >>> from sympy.abc import a, b, c, d + >>> Covariance(a*X + b*Y, c*Z + d*W) + Covariance(a*X + b*Y, c*Z + d*W) + >>> Covariance(a*X + b*Y, c*Z + d*W).expand() + a*c*Covariance(X, Z) + a*d*Covariance(W, X) + b*c*Covariance(Y, Z) + b*d*Covariance(W, Y) + + This class is aware of some properties of the covariance: + + >>> Covariance(X, X).expand() + Variance(X) + >>> Covariance(a*X, b*Y).expand() + a*b*Covariance(X, Y) + """ + + def __new__(cls, arg1, arg2, condition=None, **kwargs): + arg1 = _sympify(arg1) + arg2 = _sympify(arg2) + + if arg1.is_Matrix or arg2.is_Matrix: + from sympy.stats.symbolic_multivariate_probability import CrossCovarianceMatrix + return CrossCovarianceMatrix(arg1, arg2, condition) + + if kwargs.pop('evaluate', global_parameters.evaluate): + arg1, arg2 = sorted([arg1, arg2], key=default_sort_key) + + if condition is None: + obj = Expr.__new__(cls, arg1, arg2) + else: + condition = _sympify(condition) + obj = Expr.__new__(cls, arg1, arg2, condition) + obj._condition = condition + return obj + + def _eval_is_commutative(self): + return self.args[0].is_commutative + + def expand(self, **hints): + arg1 = self.args[0] + arg2 = self.args[1] + condition = self._condition + + if arg1 == arg2: + return Variance(arg1, condition).expand() + + if not is_random(arg1): + return S.Zero + if not is_random(arg2): + return S.Zero + + arg1, arg2 = sorted([arg1, arg2], key=default_sort_key) + + if isinstance(arg1, RandomSymbol) and isinstance(arg2, RandomSymbol): + return Covariance(arg1, arg2, condition) + + coeff_rv_list1 = self._expand_single_argument(arg1.expand()) + coeff_rv_list2 = self._expand_single_argument(arg2.expand()) + + addends = [a*b*Covariance(*sorted([r1, r2], key=default_sort_key), condition=condition) + for (a, r1) in coeff_rv_list1 for (b, r2) in coeff_rv_list2] + return Add.fromiter(addends) + + @classmethod + def _expand_single_argument(cls, expr): + # return (coefficient, random_symbol) pairs: + if isinstance(expr, RandomSymbol): + return [(S.One, expr)] + elif isinstance(expr, Add): + outval = [] + for a in expr.args: + if isinstance(a, Mul): + outval.append(cls._get_mul_nonrv_rv_tuple(a)) + elif is_random(a): + outval.append((S.One, a)) + + return outval + elif isinstance(expr, Mul): + return [cls._get_mul_nonrv_rv_tuple(expr)] + elif is_random(expr): + return [(S.One, expr)] + + @classmethod + def _get_mul_nonrv_rv_tuple(cls, m): + rv = [] + nonrv = [] + for a in m.args: + if is_random(a): + rv.append(a) + else: + nonrv.append(a) + return (Mul.fromiter(nonrv), Mul.fromiter(rv)) + + def _eval_rewrite_as_Expectation(self, arg1, arg2, condition=None, **kwargs): + e1 = Expectation(arg1*arg2, condition) + e2 = Expectation(arg1, condition)*Expectation(arg2, condition) + return e1 - e2 + + def _eval_rewrite_as_Probability(self, arg1, arg2, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, arg1, arg2, condition=None, **kwargs): + return covariance(self.args[0], self.args[1], self._condition, evaluate=False) + + _eval_rewrite_as_Sum = _eval_rewrite_as_Integral + + def evaluate_integral(self): + return self.rewrite(Integral).doit() + + +class Moment(Expr): + """ + Symbolic class for Moment + + Examples + ======== + + >>> from sympy import Symbol, Integral + >>> from sympy.stats import Normal, Expectation, Probability, Moment + >>> mu = Symbol('mu', real=True) + >>> sigma = Symbol('sigma', positive=True) + >>> X = Normal('X', mu, sigma) + >>> M = Moment(X, 3, 1) + + To evaluate the result of Moment use `doit`: + + >>> M.doit() + mu**3 - 3*mu**2 + 3*mu*sigma**2 + 3*mu - 3*sigma**2 - 1 + + Rewrite the Moment expression in terms of Expectation: + + >>> M.rewrite(Expectation) + Expectation((X - 1)**3) + + Rewrite the Moment expression in terms of Probability: + + >>> M.rewrite(Probability) + Integral((x - 1)**3*Probability(Eq(X, x)), (x, -oo, oo)) + + Rewrite the Moment expression in terms of Integral: + + >>> M.rewrite(Integral) + Integral(sqrt(2)*(X - 1)**3*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + """ + def __new__(cls, X, n, c=0, condition=None, **kwargs): + X = _sympify(X) + n = _sympify(n) + c = _sympify(c) + if condition is not None: + condition = _sympify(condition) + return super().__new__(cls, X, n, c, condition) + else: + return super().__new__(cls, X, n, c) + + def doit(self, **hints): + return self.rewrite(Expectation).doit(**hints) + + def _eval_rewrite_as_Expectation(self, X, n, c=0, condition=None, **kwargs): + return Expectation((X - c)**n, condition) + + def _eval_rewrite_as_Probability(self, X, n, c=0, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, X, n, c=0, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Integral) + + +class CentralMoment(Expr): + """ + Symbolic class Central Moment + + Examples + ======== + + >>> from sympy import Symbol, Integral + >>> from sympy.stats import Normal, Expectation, Probability, CentralMoment + >>> mu = Symbol('mu', real=True) + >>> sigma = Symbol('sigma', positive=True) + >>> X = Normal('X', mu, sigma) + >>> CM = CentralMoment(X, 4) + + To evaluate the result of CentralMoment use `doit`: + + >>> CM.doit().simplify() + 3*sigma**4 + + Rewrite the CentralMoment expression in terms of Expectation: + + >>> CM.rewrite(Expectation) + Expectation((-Expectation(X) + X)**4) + + Rewrite the CentralMoment expression in terms of Probability: + + >>> CM.rewrite(Probability) + Integral((x - Integral(x*Probability(True), (x, -oo, oo)))**4*Probability(Eq(X, x)), (x, -oo, oo)) + + Rewrite the CentralMoment expression in terms of Integral: + + >>> CM.rewrite(Integral) + Integral(sqrt(2)*(X - Integral(sqrt(2)*X*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)))**4*exp(-(X - mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (X, -oo, oo)) + + """ + def __new__(cls, X, n, condition=None, **kwargs): + X = _sympify(X) + n = _sympify(n) + if condition is not None: + condition = _sympify(condition) + return super().__new__(cls, X, n, condition) + else: + return super().__new__(cls, X, n) + + def doit(self, **hints): + return self.rewrite(Expectation).doit(**hints) + + def _eval_rewrite_as_Expectation(self, X, n, condition=None, **kwargs): + mu = Expectation(X, condition, **kwargs) + return Moment(X, n, mu, condition, **kwargs).rewrite(Expectation) + + def _eval_rewrite_as_Probability(self, X, n, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Probability) + + def _eval_rewrite_as_Integral(self, X, n, condition=None, **kwargs): + return self.rewrite(Expectation).rewrite(Integral) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__init__.py b/phivenv/Lib/site-packages/sympy/stats/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4312743b9ceebe46c5a1559bef6b80f19619708d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_compound_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_compound_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f58b4a2c58a28ca035f9e3afafaa40c11e6cb5c3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_compound_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_continuous_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_continuous_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbedb58a33364af3375c0f0d390336b0f976627d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_continuous_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_discrete_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_discrete_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db598468f6e0c30ec56e8dc7ee78bc450626000 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_discrete_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_error_prop.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_error_prop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90329677f9eb84eaf6f7b95b90cf4c2e5d0aee91 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_error_prop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_finite_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_finite_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e40b4b18c6765664097ca1b4482fa6bf3055761 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_finite_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_joint_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_joint_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df62ac8330dc23b6b6e2d1897eecab10ddfada10 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_joint_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbda6718eecd746cc4396c9212a275c0b9724e5e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_mix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_mix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c5652039545573802414c75ee0b341643b9630 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_mix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_random_matrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_random_matrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d140c678e353736617bd21a8f43be0e16f75b096 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_random_matrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_rv.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_rv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efb8961e0df64cf61e90b15d84be15d84342b0ef Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_rv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_stochastic_process.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_stochastic_process.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63f3878f42a2a618ecbe0a40a67b3060f2a73048 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_stochastic_process.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e8bd9331a49df9d70e015b41fb99bb1bfdacf47 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63477552168517e81be7760d46ffc43902d3af39 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_compound_rv.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_compound_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..573ba364b686738e56bb1c4615acd2a9bc8bf3ae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_compound_rv.py @@ -0,0 +1,159 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.sets.sets import Interval +from sympy.stats import (Normal, P, E, density, Gamma, Poisson, Rayleigh, + variance, Bernoulli, Beta, Uniform, cdf) +from sympy.stats.compound_rv import CompoundDistribution, CompoundPSpace +from sympy.stats.crv_types import NormalDistribution +from sympy.stats.drv_types import PoissonDistribution +from sympy.stats.frv_types import BernoulliDistribution +from sympy.testing.pytest import raises, ignore_warnings +from sympy.stats.joint_rv_types import MultivariateNormalDistribution + +from sympy.abc import x + + +# helpers for testing troublesome unevaluated expressions +flat = lambda s: ''.join(str(s).split()) +streq = lambda *a: len(set(map(flat, a))) == 1 +assert streq(x, x) +assert streq(x, 'x') +assert not streq(x, x + 1) + + +def test_normal_CompoundDist(): + X = Normal('X', 1, 2) + Y = Normal('X', X, 4) + assert density(Y)(x).simplify() == sqrt(10)*exp(-x**2/40 + x/20 - S(1)/40)/(20*sqrt(pi)) + assert E(Y) == 1 # it is always equal to mean of X + assert P(Y > 1) == S(1)/2 # as 1 is the mean + assert P(Y > 5).simplify() == S(1)/2 - erf(sqrt(10)/5)/2 + assert variance(Y) == variance(X) + 4**2 # 2**2 + 4**2 + # https://math.stackexchange.com/questions/1484451/ + # (Contains proof of E and variance computation) + + +def test_poisson_CompoundDist(): + k, t, y = symbols('k t y', positive=True, real=True) + G = Gamma('G', k, t) + D = Poisson('P', G) + assert density(D)(y).simplify() == t**y*(t + 1)**(-k - y)*gamma(k + y)/(gamma(k)*gamma(y + 1)) + # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture + assert E(D).simplify() == k*t # mean of NegativeBinomialDistribution + + +def test_bernoulli_CompoundDist(): + X = Beta('X', 1, 2) + Y = Bernoulli('Y', X) + assert density(Y).dict == {0: S(2)/3, 1: S(1)/3} + assert E(Y) == P(Eq(Y, 1)) == S(1)/3 + assert variance(Y) == S(2)/9 + assert cdf(Y) == {0: S(2)/3, 1: 1} + + # test issue 8128 + a = Bernoulli('a', S(1)/2) + b = Bernoulli('b', a) + assert density(b).dict == {0: S(1)/2, 1: S(1)/2} + assert P(b > 0.5) == S(1)/2 + + X = Uniform('X', 0, 1) + Y = Bernoulli('Y', X) + assert E(Y) == S(1)/2 + assert P(Eq(Y, 1)) == E(Y) + + +def test_unevaluated_CompoundDist(): + # these tests need to be removed once they work with evaluation as they are currently not + # evaluated completely in sympy. + R = Rayleigh('R', 4) + X = Normal('X', 3, R) + ans = ''' + Piecewise(((-sqrt(pi)*sinh(x/4 - 3/4) + sqrt(pi)*cosh(x/4 - 3/4))/( + 8*sqrt(pi)), Abs(arg(x - 3)) <= pi/4), (Integral(sqrt(2)*exp(-(x - 3) + **2/(2*R**2))*exp(-R**2/32)/(32*sqrt(pi)), (R, 0, oo)), True))''' + assert streq(density(X)(x), ans) + + expre = ''' + Integral(X*Integral(sqrt(2)*exp(-(X-3)**2/(2*R**2))*exp(-R**2/32)/(32* + sqrt(pi)),(R,0,oo)),(X,-oo,oo))''' + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert streq(E(X, evaluate=False).rewrite(Integral), expre) + + X = Poisson('X', 1) + Y = Poisson('Y', X) + Z = Poisson('Z', Y) + exprd = Sum(exp(-Y)*Y**x*Sum(exp(-1)*exp(-X)*X**Y/(factorial(X)*factorial(Y) + ), (X, 0, oo))/factorial(x), (Y, 0, oo)) + assert density(Z)(x) == exprd + + N = Normal('N', 1, 2) + M = Normal('M', 3, 4) + D = Normal('D', M, N) + exprd = ''' + Integral(sqrt(2)*exp(-(N-1)**2/8)*Integral(exp(-(x-M)**2/(2*N**2))*exp + (-(M-3)**2/32)/(8*pi*N),(M,-oo,oo))/(4*sqrt(pi)),(N,-oo,oo))''' + assert streq(density(D, evaluate=False)(x), exprd) + + +def test_Compound_Distribution(): + X = Normal('X', 2, 4) + N = NormalDistribution(X, 4) + C = CompoundDistribution(N) + assert C.is_Continuous + assert C.set == Interval(-oo, oo) + assert C.pdf(x, evaluate=True).simplify() == exp(-x**2/64 + x/16 - S(1)/16)/(8*sqrt(pi)) + + assert not isinstance(CompoundDistribution(NormalDistribution(2, 3)), + CompoundDistribution) + M = MultivariateNormalDistribution([1, 2], [[2, 1], [1, 2]]) + raises(NotImplementedError, lambda: CompoundDistribution(M)) + + X = Beta('X', 2, 4) + B = BernoulliDistribution(X, 1, 0) + C = CompoundDistribution(B) + assert C.is_Finite + assert C.set == {0, 1} + y = symbols('y', negative=False, integer=True) + assert C.pdf(y, evaluate=True) == Piecewise((S(1)/(30*beta(2, 4)), Eq(y, 0)), + (S(1)/(60*beta(2, 4)), Eq(y, 1)), (0, True)) + + k, t, z = symbols('k t z', positive=True, real=True) + G = Gamma('G', k, t) + X = PoissonDistribution(G) + C = CompoundDistribution(X) + assert C.is_Discrete + assert C.set == S.Naturals0 + assert C.pdf(z, evaluate=True).simplify() == t**z*(t + 1)**(-k - z)*gamma(k \ + + z)/(gamma(k)*gamma(z + 1)) + + +def test_compound_pspace(): + X = Normal('X', 2, 4) + Y = Normal('Y', 3, 6) + assert not isinstance(Y.pspace, CompoundPSpace) + N = NormalDistribution(1, 2) + D = PoissonDistribution(3) + B = BernoulliDistribution(0.2, 1, 0) + pspace1 = CompoundPSpace('N', N) + pspace2 = CompoundPSpace('D', D) + pspace3 = CompoundPSpace('B', B) + assert not isinstance(pspace1, CompoundPSpace) + assert not isinstance(pspace2, CompoundPSpace) + assert not isinstance(pspace3, CompoundPSpace) + M = MultivariateNormalDistribution([1, 2], [[2, 1], [1, 2]]) + raises(ValueError, lambda: CompoundPSpace('M', M)) + Y = Normal('Y', X, 6) + assert isinstance(Y.pspace, CompoundPSpace) + assert Y.pspace.distribution == CompoundDistribution(NormalDistribution(X, 6)) + assert Y.pspace.domain.set == Interval(-oo, oo) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_continuous_rv.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_continuous_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c4206b5c29ffd3194d1ae05e57c51c9c1b6d78 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_continuous_rv.py @@ -0,0 +1,1583 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import (Lambda, diff, expand_func) +from sympy.core.mul import Mul +from sympy.core import EulerGamma +from sympy.core.numbers import (E as e, I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, im, re, sign) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (asin, atan, cos, sin, tan) +from sympy.functions.special.bessel import (besseli, besselj, besselk) +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.error_functions import (erf, erfc, erfi, expint) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, uppergamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Or) +from sympy.sets.sets import Interval +from sympy.simplify.simplify import simplify +from sympy.utilities.lambdify import lambdify +from sympy.functions.special.error_functions import erfinv +from sympy.functions.special.hyper import meijerg +from sympy.sets.sets import FiniteSet, Complement, Intersection +from sympy.stats import (P, E, where, density, variance, covariance, skewness, kurtosis, median, + given, pspace, cdf, characteristic_function, moment_generating_function, + ContinuousRV, Arcsin, Benini, Beta, BetaNoncentral, BetaPrime, + Cauchy, Chi, ChiSquared, ChiNoncentral, Dagum, Davis, Erlang, ExGaussian, + Exponential, ExponentialPower, FDistribution, FisherZ, Frechet, Gamma, + GammaInverse, Gompertz, Gumbel, Kumaraswamy, Laplace, Levy, Logistic, LogCauchy, + LogLogistic, LogitNormal, LogNormal, Maxwell, Moyal, Nakagami, Normal, GaussianInverse, + Pareto, PowerFunction, QuadraticU, RaisedCosine, Rayleigh, Reciprocal, ShiftedGompertz, StudentT, + Trapezoidal, Triangular, Uniform, UniformSum, VonMises, Weibull, coskewness, + WignerSemicircle, Wald, correlation, moment, cmoment, smoment, quantile, + Lomax, BoundedPareto) + +from sympy.stats.crv_types import NormalDistribution, ExponentialDistribution, ContinuousDistributionHandmade +from sympy.stats.joint_rv_types import MultivariateLaplaceDistribution, MultivariateNormalDistribution +from sympy.stats.crv import SingleContinuousPSpace, SingleContinuousDomain +from sympy.stats.compound_rv import CompoundPSpace +from sympy.stats.symbolic_probability import Probability +from sympy.testing.pytest import raises, XFAIL, slow, ignore_warnings +from sympy.core.random import verify_numerically as tn + +oo = S.Infinity + +x, y, z = map(Symbol, 'xyz') + +def test_single_normal(): + mu = Symbol('mu', real=True) + sigma = Symbol('sigma', positive=True) + X = Normal('x', 0, 1) + Y = X*sigma + mu + + assert E(Y) == mu + assert variance(Y) == sigma**2 + pdf = density(Y) + x = Symbol('x', real=True) + assert (pdf(x) == + 2**S.Half*exp(-(x - mu)**2/(2*sigma**2))/(2*pi**S.Half*sigma)) + + assert P(X**2 < 1) == erf(2**S.Half/2) + ans = quantile(Y)(x) + assert ans == Complement(Intersection(FiniteSet( + sqrt(2)*sigma*(sqrt(2)*mu/(2*sigma)+ erfinv(2*x - 1))), + Interval(-oo, oo)), FiniteSet(mu)) + assert E(X, Eq(X, mu)) == mu + + assert median(X) == FiniteSet(0) + # issue 8248 + assert X.pspace.compute_expectation(1).doit() == 1 + + +def test_conditional_1d(): + X = Normal('x', 0, 1) + Y = given(X, X >= 0) + z = Symbol('z') + + assert density(Y)(z) == 2 * density(X)(z) + + assert Y.pspace.domain.set == Interval(0, oo) + assert E(Y) == sqrt(2) / sqrt(pi) + + assert E(X**2) == E(Y**2) + + +def test_ContinuousDomain(): + X = Normal('x', 0, 1) + assert where(X**2 <= 1).set == Interval(-1, 1) + assert where(X**2 <= 1).symbol == X.symbol + assert where(And(X**2 <= 1, X >= 0)).set == Interval(0, 1) + raises(ValueError, lambda: where(sin(X) > 1)) + + Y = given(X, X >= 0) + + assert Y.pspace.domain.set == Interval(0, oo) + + +def test_multiple_normal(): + X, Y = Normal('x', 0, 1), Normal('y', 0, 1) + p = Symbol("p", positive=True) + + assert E(X + Y) == 0 + assert variance(X + Y) == 2 + assert variance(X + X) == 4 + assert covariance(X, Y) == 0 + assert covariance(2*X + Y, -X) == -2*variance(X) + assert skewness(X) == 0 + assert skewness(X + Y) == 0 + assert kurtosis(X) == 3 + assert kurtosis(X+Y) == 3 + assert correlation(X, Y) == 0 + assert correlation(X, X + Y) == correlation(X, X - Y) + assert moment(X, 2) == 1 + assert cmoment(X, 3) == 0 + assert moment(X + Y, 4) == 12 + assert cmoment(X, 2) == variance(X) + assert smoment(X*X, 2) == 1 + assert smoment(X + Y, 3) == skewness(X + Y) + assert smoment(X + Y, 4) == kurtosis(X + Y) + assert E(X, Eq(X + Y, 0)) == 0 + assert variance(X, Eq(X + Y, 0)) == S.Half + assert quantile(X)(p) == sqrt(2)*erfinv(2*p - S.One) + + +def test_symbolic(): + mu1, mu2 = symbols('mu1 mu2', real=True) + s1, s2 = symbols('sigma1 sigma2', positive=True) + rate = Symbol('lambda', positive=True) + X = Normal('x', mu1, s1) + Y = Normal('y', mu2, s2) + Z = Exponential('z', rate) + a, b, c = symbols('a b c', real=True) + + assert E(X) == mu1 + assert E(X + Y) == mu1 + mu2 + assert E(a*X + b) == a*E(X) + b + assert variance(X) == s1**2 + assert variance(X + a*Y + b) == variance(X) + a**2*variance(Y) + + assert E(Z) == 1/rate + assert E(a*Z + b) == a*E(Z) + b + assert E(X + a*Z + b) == mu1 + a/rate + b + assert median(X) == FiniteSet(mu1) + + +def test_cdf(): + X = Normal('x', 0, 1) + + d = cdf(X) + assert P(X < 1) == d(1).rewrite(erfc) + assert d(0) == S.Half + + d = cdf(X, X > 0) # given X>0 + assert d(0) == 0 + + Y = Exponential('y', 10) + d = cdf(Y) + assert d(-5) == 0 + assert P(Y > 3) == 1 - d(3) + + raises(ValueError, lambda: cdf(X + Y)) + + Z = Exponential('z', 1) + f = cdf(Z) + assert f(z) == Piecewise((1 - exp(-z), z >= 0), (0, True)) + + +def test_characteristic_function(): + X = Uniform('x', 0, 1) + + cf = characteristic_function(X) + assert cf(1) == -I*(-1 + exp(I)) + + Y = Normal('y', 1, 1) + cf = characteristic_function(Y) + assert cf(0) == 1 + assert cf(1) == exp(I - S.Half) + + Z = Exponential('z', 5) + cf = characteristic_function(Z) + assert cf(0) == 1 + assert cf(1).expand() == Rational(25, 26) + I*5/26 + + X = GaussianInverse('x', 1, 1) + cf = characteristic_function(X) + assert cf(0) == 1 + assert cf(1) == exp(1 - sqrt(1 - 2*I)) + + X = ExGaussian('x', 0, 1, 1) + cf = characteristic_function(X) + assert cf(0) == 1 + assert cf(1) == (1 + I)*exp(Rational(-1, 2))/2 + + L = Levy('x', 0, 1) + cf = characteristic_function(L) + assert cf(0) == 1 + assert cf(1) == exp(-sqrt(2)*sqrt(-I)) + + +def test_moment_generating_function(): + t = symbols('t', positive=True) + + # Symbolic tests + a, b, c = symbols('a b c') + + mgf = moment_generating_function(Beta('x', a, b))(t) + assert mgf == hyper((a,), (a + b,), t) + + mgf = moment_generating_function(Chi('x', a))(t) + assert mgf == sqrt(2)*t*gamma(a/2 + S.Half)*\ + hyper((a/2 + S.Half,), (Rational(3, 2),), t**2/2)/gamma(a/2) +\ + hyper((a/2,), (S.Half,), t**2/2) + + mgf = moment_generating_function(ChiSquared('x', a))(t) + assert mgf == (1 - 2*t)**(-a/2) + + mgf = moment_generating_function(Erlang('x', a, b))(t) + assert mgf == (1 - t/b)**(-a) + + mgf = moment_generating_function(ExGaussian("x", a, b, c))(t) + assert mgf == exp(a*t + b**2*t**2/2)/(1 - t/c) + + mgf = moment_generating_function(Exponential('x', a))(t) + assert mgf == a/(a - t) + + mgf = moment_generating_function(Gamma('x', a, b))(t) + assert mgf == (-b*t + 1)**(-a) + + mgf = moment_generating_function(Gumbel('x', a, b))(t) + assert mgf == exp(b*t)*gamma(-a*t + 1) + + mgf = moment_generating_function(Gompertz('x', a, b))(t) + assert mgf == b*exp(b)*expint(t/a, b) + + mgf = moment_generating_function(Laplace('x', a, b))(t) + assert mgf == exp(a*t)/(-b**2*t**2 + 1) + + mgf = moment_generating_function(Logistic('x', a, b))(t) + assert mgf == exp(a*t)*beta(-b*t + 1, b*t + 1) + + mgf = moment_generating_function(Normal('x', a, b))(t) + assert mgf == exp(a*t + b**2*t**2/2) + + mgf = moment_generating_function(Pareto('x', a, b))(t) + assert mgf == b*(-a*t)**b*uppergamma(-b, -a*t) + + mgf = moment_generating_function(QuadraticU('x', a, b))(t) + assert str(mgf) == ("(3*(t*(-4*b + (a + b)**2) + 4)*exp(b*t) - " + "3*(t*(a**2 + 2*a*(b - 2) + b**2) + 4)*exp(a*t))/(t**2*(a - b)**3)") + + mgf = moment_generating_function(RaisedCosine('x', a, b))(t) + assert mgf == pi**2*exp(a*t)*sinh(b*t)/(b*t*(b**2*t**2 + pi**2)) + + mgf = moment_generating_function(Rayleigh('x', a))(t) + assert mgf == sqrt(2)*sqrt(pi)*a*t*(erf(sqrt(2)*a*t/2) + 1)\ + *exp(a**2*t**2/2)/2 + 1 + + mgf = moment_generating_function(Triangular('x', a, b, c))(t) + assert str(mgf) == ("(-2*(-a + b)*exp(c*t) + 2*(-a + c)*exp(b*t) + " + "2*(b - c)*exp(a*t))/(t**2*(-a + b)*(-a + c)*(b - c))") + + mgf = moment_generating_function(Uniform('x', a, b))(t) + assert mgf == (-exp(a*t) + exp(b*t))/(t*(-a + b)) + + mgf = moment_generating_function(UniformSum('x', a))(t) + assert mgf == ((exp(t) - 1)/t)**a + + mgf = moment_generating_function(WignerSemicircle('x', a))(t) + assert mgf == 2*besseli(1, a*t)/(a*t) + + # Numeric tests + + mgf = moment_generating_function(Beta('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 1) == hyper((2,), (3,), 1)/2 + + mgf = moment_generating_function(Chi('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == sqrt(2)*hyper((1,), (Rational(3, 2),), S.Half + )/sqrt(pi) + hyper((Rational(3, 2),), (Rational(3, 2),), S.Half) + 2*sqrt(2)*hyper((2,), + (Rational(5, 2),), S.Half)/(3*sqrt(pi)) + + mgf = moment_generating_function(ChiSquared('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == I + + mgf = moment_generating_function(Erlang('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(ExGaussian("x", 0, 1, 1))(t) + assert mgf.diff(t).subs(t, 2) == -exp(2) + + mgf = moment_generating_function(Exponential('x', 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(Gamma('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(Gumbel('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == EulerGamma + 1 + + mgf = moment_generating_function(Gompertz('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 1) == -e*meijerg(((), (1, 1)), + ((0, 0, 0), ()), 1) + + mgf = moment_generating_function(Laplace('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == 1 + + mgf = moment_generating_function(Logistic('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == beta(1, 1) + + mgf = moment_generating_function(Normal('x', 0, 1))(t) + assert mgf.diff(t).subs(t, 1) == exp(S.Half) + + mgf = moment_generating_function(Pareto('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 0) == expint(1, 0) + + mgf = moment_generating_function(QuadraticU('x', 1, 2))(t) + assert mgf.diff(t).subs(t, 1) == -12*e - 3*exp(2) + + mgf = moment_generating_function(RaisedCosine('x', 1, 1))(t) + assert mgf.diff(t).subs(t, 1) == -2*e*pi**2*sinh(1)/\ + (1 + pi**2)**2 + e*pi**2*cosh(1)/(1 + pi**2) + + mgf = moment_generating_function(Rayleigh('x', 1))(t) + assert mgf.diff(t).subs(t, 0) == sqrt(2)*sqrt(pi)/2 + + mgf = moment_generating_function(Triangular('x', 1, 3, 2))(t) + assert mgf.diff(t).subs(t, 1) == -e + exp(3) + + mgf = moment_generating_function(Uniform('x', 0, 1))(t) + assert mgf.diff(t).subs(t, 1) == 1 + + mgf = moment_generating_function(UniformSum('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == 1 + + mgf = moment_generating_function(WignerSemicircle('x', 1))(t) + assert mgf.diff(t).subs(t, 1) == -2*besseli(1, 1) + besseli(2, 1) +\ + besseli(0, 1) + + +def test_ContinuousRV(): + pdf = sqrt(2)*exp(-x**2/2)/(2*sqrt(pi)) # Normal distribution + # X and Y should be equivalent + X = ContinuousRV(x, pdf, check=True) + Y = Normal('y', 0, 1) + + assert variance(X) == variance(Y) + assert P(X > 0) == P(Y > 0) + Z = ContinuousRV(z, exp(-z), set=Interval(0, oo)) + assert Z.pspace.domain.set == Interval(0, oo) + assert E(Z) == 1 + assert P(Z > 5) == exp(-5) + raises(ValueError, lambda: ContinuousRV(z, exp(-z), set=Interval(0, 10), check=True)) + + # the correct pdf for Gamma(k, theta) but the integral in `check` + # integrates to something equivalent to 1 and not to 1 exactly + _x, k, theta = symbols("x k theta", positive=True) + pdf = 1/(gamma(k)*theta**k)*_x**(k-1)*exp(-_x/theta) + X = ContinuousRV(_x, pdf, set=Interval(0, oo)) + Y = Gamma('y', k, theta) + assert (E(X) - E(Y)).simplify() == 0 + assert (variance(X) - variance(Y)).simplify() == 0 + + +def test_arcsin(): + + a = Symbol("a", real=True) + b = Symbol("b", real=True) + + X = Arcsin('x', a, b) + assert density(X)(x) == 1/(pi*sqrt((-x + b)*(x - a))) + assert cdf(X)(x) == Piecewise((0, a > x), + (2*asin(sqrt((-a + x)/(-a + b)))/pi, b >= x), + (1, True)) + assert pspace(X).domain.set == Interval(a, b) + +def test_benini(): + alpha = Symbol("alpha", positive=True) + beta = Symbol("beta", positive=True) + sigma = Symbol("sigma", positive=True) + X = Benini('x', alpha, beta, sigma) + + assert density(X)(x) == ((alpha/x + 2*beta*log(x/sigma)/x) + *exp(-alpha*log(x/sigma) - beta*log(x/sigma)**2)) + + assert pspace(X).domain.set == Interval(sigma, oo) + raises(NotImplementedError, lambda: moment_generating_function(X)) + alpha = Symbol("alpha", nonpositive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + + beta = Symbol("beta", nonpositive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + + alpha = Symbol("alpha", positive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + + beta = Symbol("beta", positive=True) + sigma = Symbol("sigma", nonpositive=True) + raises(ValueError, lambda: Benini('x', alpha, beta, sigma)) + +def test_beta(): + a, b = symbols('alpha beta', positive=True) + B = Beta('x', a, b) + + assert pspace(B).domain.set == Interval(0, 1) + assert characteristic_function(B)(x) == hyper((a,), (a + b,), I*x) + assert density(B)(x) == x**(a - 1)*(1 - x)**(b - 1)/beta(a, b) + + assert simplify(E(B)) == a / (a + b) + assert simplify(variance(B)) == a*b / (a**3 + 3*a**2*b + a**2 + 3*a*b**2 + 2*a*b + b**3 + b**2) + + # Full symbolic solution is too much, test with numeric version + a, b = 1, 2 + B = Beta('x', a, b) + assert expand_func(E(B)) == a / S(a + b) + assert expand_func(variance(B)) == (a*b) / S((a + b)**2 * (a + b + 1)) + assert median(B) == FiniteSet(1 - 1/sqrt(2)) + +def test_beta_noncentral(): + a, b = symbols('a b', positive=True) + c = Symbol('c', nonnegative=True) + _k = Dummy('k') + + X = BetaNoncentral('x', a, b, c) + + assert pspace(X).domain.set == Interval(0, 1) + + dens = density(X) + z = Symbol('z') + + res = Sum( z**(_k + a - 1)*(c/2)**_k*(1 - z)**(b - 1)*exp(-c/2)/ + (beta(_k + a, b)*factorial(_k)), (_k, 0, oo)) + assert dens(z).dummy_eq(res) + + # BetaCentral should not raise if the assumptions + # on the symbols can not be determined + a, b, c = symbols('a b c') + assert BetaNoncentral('x', a, b, c) + + a = Symbol('a', positive=False, real=True) + raises(ValueError, lambda: BetaNoncentral('x', a, b, c)) + + a = Symbol('a', positive=True) + b = Symbol('b', positive=False, real=True) + raises(ValueError, lambda: BetaNoncentral('x', a, b, c)) + + a = Symbol('a', positive=True) + b = Symbol('b', positive=True) + c = Symbol('c', nonnegative=False, real=True) + raises(ValueError, lambda: BetaNoncentral('x', a, b, c)) + +def test_betaprime(): + alpha = Symbol("alpha", positive=True) + + betap = Symbol("beta", positive=True) + + X = BetaPrime('x', alpha, betap) + assert density(X)(x) == x**(alpha - 1)*(x + 1)**(-alpha - betap)/beta(alpha, betap) + + alpha = Symbol("alpha", nonpositive=True) + raises(ValueError, lambda: BetaPrime('x', alpha, betap)) + + alpha = Symbol("alpha", positive=True) + betap = Symbol("beta", nonpositive=True) + raises(ValueError, lambda: BetaPrime('x', alpha, betap)) + X = BetaPrime('x', 1, 1) + assert median(X) == FiniteSet(1) + + +def test_BoundedPareto(): + L, H = symbols('L, H', negative=True) + raises(ValueError, lambda: BoundedPareto('X', 1, L, H)) + L, H = symbols('L, H', real=False) + raises(ValueError, lambda: BoundedPareto('X', 1, L, H)) + L, H = symbols('L, H', positive=True) + raises(ValueError, lambda: BoundedPareto('X', -1, L, H)) + + X = BoundedPareto('X', 2, L, H) + assert X.pspace.domain.set == Interval(L, H) + assert density(X)(x) == 2*L**2/(x**3*(1 - L**2/H**2)) + assert cdf(X)(x) == Piecewise((-H**2*L**2/(x**2*(H**2 - L**2)) \ + + H**2/(H**2 - L**2), L <= x), (0, True)) + assert E(X).simplify() == 2*H*L/(H + L) + X = BoundedPareto('X', 1, 2, 4) + assert E(X).simplify() == log(16) + assert median(X) == FiniteSet(Rational(8, 3)) + assert variance(X).simplify() == 8 - 16*log(2)**2 + + +def test_cauchy(): + x0 = Symbol("x0", real=True) + gamma = Symbol("gamma", positive=True) + p = Symbol("p", positive=True) + + X = Cauchy('x', x0, gamma) + # Tests the characteristic function + assert characteristic_function(X)(x) == exp(-gamma*Abs(x) + I*x*x0) + raises(NotImplementedError, lambda: moment_generating_function(X)) + assert density(X)(x) == 1/(pi*gamma*(1 + (x - x0)**2/gamma**2)) + assert diff(cdf(X)(x), x) == density(X)(x) + assert quantile(X)(p) == gamma*tan(pi*(p - S.Half)) + x0 + + x1 = Symbol("x1", real=False) + raises(ValueError, lambda: Cauchy('x', x1, gamma)) + gamma = Symbol("gamma", nonpositive=True) + raises(ValueError, lambda: Cauchy('x', x0, gamma)) + assert median(X) == FiniteSet(x0) + +def test_chi(): + from sympy.core.numbers import I + k = Symbol("k", integer=True) + + X = Chi('x', k) + assert density(X)(x) == 2**(-k/2 + 1)*x**(k - 1)*exp(-x**2/2)/gamma(k/2) + + # Tests the characteristic function + assert characteristic_function(X)(x) == sqrt(2)*I*x*gamma(k/2 + S(1)/2)*hyper((k/2 + S(1)/2,), + (S(3)/2,), -x**2/2)/gamma(k/2) + hyper((k/2,), (S(1)/2,), -x**2/2) + + # Tests the moment generating function + assert moment_generating_function(X)(x) == sqrt(2)*x*gamma(k/2 + S(1)/2)*hyper((k/2 + S(1)/2,), + (S(3)/2,), x**2/2)/gamma(k/2) + hyper((k/2,), (S(1)/2,), x**2/2) + + k = Symbol("k", integer=True, positive=False) + raises(ValueError, lambda: Chi('x', k)) + + k = Symbol("k", integer=False, positive=True) + raises(ValueError, lambda: Chi('x', k)) + +def test_chi_noncentral(): + k = Symbol("k", integer=True) + l = Symbol("l") + + X = ChiNoncentral("x", k, l) + assert density(X)(x) == (x**k*l*(x*l)**(-k/2)* + exp(-x**2/2 - l**2/2)*besseli(k/2 - 1, x*l)) + + k = Symbol("k", integer=True, positive=False) + raises(ValueError, lambda: ChiNoncentral('x', k, l)) + + k = Symbol("k", integer=True, positive=True) + l = Symbol("l", nonpositive=True) + raises(ValueError, lambda: ChiNoncentral('x', k, l)) + + k = Symbol("k", integer=False) + l = Symbol("l", positive=True) + raises(ValueError, lambda: ChiNoncentral('x', k, l)) + + +def test_chi_squared(): + k = Symbol("k", integer=True) + X = ChiSquared('x', k) + + # Tests the characteristic function + assert characteristic_function(X)(x) == ((-2*I*x + 1)**(-k/2)) + + assert density(X)(x) == 2**(-k/2)*x**(k/2 - 1)*exp(-x/2)/gamma(k/2) + assert cdf(X)(x) == Piecewise((lowergamma(k/2, x/2)/gamma(k/2), x >= 0), (0, True)) + assert E(X) == k + assert variance(X) == 2*k + + X = ChiSquared('x', 15) + assert cdf(X)(3) == -14873*sqrt(6)*exp(Rational(-3, 2))/(5005*sqrt(pi)) + erf(sqrt(6)/2) + + k = Symbol("k", integer=True, positive=False) + raises(ValueError, lambda: ChiSquared('x', k)) + + k = Symbol("k", integer=False, positive=True) + raises(ValueError, lambda: ChiSquared('x', k)) + + +def test_dagum(): + p = Symbol("p", positive=True) + b = Symbol("b", positive=True) + a = Symbol("a", positive=True) + + X = Dagum('x', p, a, b) + assert density(X)(x) == a*p*(x/b)**(a*p)*((x/b)**a + 1)**(-p - 1)/x + assert cdf(X)(x) == Piecewise(((1 + (x/b)**(-a))**(-p), x >= 0), + (0, True)) + + p = Symbol("p", nonpositive=True) + raises(ValueError, lambda: Dagum('x', p, a, b)) + + p = Symbol("p", positive=True) + b = Symbol("b", nonpositive=True) + raises(ValueError, lambda: Dagum('x', p, a, b)) + + b = Symbol("b", positive=True) + a = Symbol("a", nonpositive=True) + raises(ValueError, lambda: Dagum('x', p, a, b)) + X = Dagum('x', 1, 1, 1) + assert median(X) == FiniteSet(1) + +def test_davis(): + b = Symbol("b", positive=True) + n = Symbol("n", positive=True) + mu = Symbol("mu", positive=True) + + X = Davis('x', b, n, mu) + dividend = b**n*(x - mu)**(-1-n) + divisor = (exp(b/(x-mu))-1)*(gamma(n)*zeta(n)) + assert density(X)(x) == dividend/divisor + + +def test_erlang(): + k = Symbol("k", integer=True, positive=True) + l = Symbol("l", positive=True) + + X = Erlang("x", k, l) + assert density(X)(x) == x**(k - 1)*l**k*exp(-x*l)/gamma(k) + assert cdf(X)(x) == Piecewise((lowergamma(k, l*x)/gamma(k), x > 0), + (0, True)) + + +def test_exgaussian(): + m, z = symbols("m, z") + s, l = symbols("s, l", positive=True) + X = ExGaussian("x", m, s, l) + + assert density(X)(z) == l*exp(l*(l*s**2 + 2*m - 2*z)/2) *\ + erfc(sqrt(2)*(l*s**2 + m - z)/(2*s))/2 + + # Note: actual_output simplifies to expected_output. + # Ideally cdf(X)(z) would return expected_output + # expected_output = (erf(sqrt(2)*(l*s**2 + m - z)/(2*s)) - 1)*exp(l*(l*s**2 + 2*m - 2*z)/2)/2 - erf(sqrt(2)*(m - z)/(2*s))/2 + S.Half + u = l*(z - m) + v = l*s + GaussianCDF1 = cdf(Normal('x', 0, v))(u) + GaussianCDF2 = cdf(Normal('x', v**2, v))(u) + actual_output = GaussianCDF1 - exp(-u + (v**2/2) + log(GaussianCDF2)) + assert cdf(X)(z) == actual_output + # assert simplify(actual_output) == expected_output + + assert variance(X).expand() == s**2 + l**(-2) + + assert skewness(X).expand() == 2/(l**3*s**2*sqrt(s**2 + l**(-2)) + l * + sqrt(s**2 + l**(-2))) + + +@slow +def test_exponential(): + rate = Symbol('lambda', positive=True) + X = Exponential('x', rate) + p = Symbol("p", positive=True, real=True) + + assert E(X) == 1/rate + assert variance(X) == 1/rate**2 + assert skewness(X) == 2 + assert skewness(X) == smoment(X, 3) + assert kurtosis(X) == 9 + assert kurtosis(X) == smoment(X, 4) + assert smoment(2*X, 4) == smoment(X, 4) + assert moment(X, 3) == 3*2*1/rate**3 + assert P(X > 0) is S.One + assert P(X > 1) == exp(-rate) + assert P(X > 10) == exp(-10*rate) + assert quantile(X)(p) == -log(1-p)/rate + + assert where(X <= 1).set == Interval(0, 1) + Y = Exponential('y', 1) + assert median(Y) == FiniteSet(log(2)) + #Test issue 9970 + z = Dummy('z') + assert P(X > z) == exp(-z*rate) + assert P(X < z) == 0 + #Test issue 10076 (Distribution with interval(0,oo)) + x = Symbol('x') + _z = Dummy('_z') + b = SingleContinuousPSpace(x, ExponentialDistribution(2)) + + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + expected1 = Integral(2*exp(-2*_z), (_z, 3, oo)) + assert b.probability(x > 3, evaluate=False).rewrite(Integral).dummy_eq(expected1) + + expected2 = Integral(2*exp(-2*_z), (_z, 0, 4)) + assert b.probability(x < 4, evaluate=False).rewrite(Integral).dummy_eq(expected2) + Y = Exponential('y', 2*rate) + assert coskewness(X, X, X) == skewness(X) + assert coskewness(X, Y + rate*X, Y + 2*rate*X) == \ + 4/(sqrt(1 + 1/(4*rate**2))*sqrt(4 + 1/(4*rate**2))) + assert coskewness(X + 2*Y, Y + X, Y + 2*X, X > 3) == \ + sqrt(170)*Rational(9, 85) + +def test_exponential_power(): + mu = Symbol('mu') + z = Symbol('z') + alpha = Symbol('alpha', positive=True) + beta = Symbol('beta', positive=True) + + X = ExponentialPower('x', mu, alpha, beta) + + assert density(X)(z) == beta*exp(-(Abs(mu - z)/alpha) + ** beta)/(2*alpha*gamma(1/beta)) + assert cdf(X)(z) == S.Half + lowergamma(1/beta, + (Abs(mu - z)/alpha)**beta)*sign(-mu + z)/\ + (2*gamma(1/beta)) + + +def test_f_distribution(): + d1 = Symbol("d1", positive=True) + d2 = Symbol("d2", positive=True) + + X = FDistribution("x", d1, d2) + + assert density(X)(x) == (d2**(d2/2)*sqrt((d1*x)**d1*(d1*x + d2)**(-d1 - d2)) + /(x*beta(d1/2, d2/2))) + + raises(NotImplementedError, lambda: moment_generating_function(X)) + d1 = Symbol("d1", nonpositive=True) + raises(ValueError, lambda: FDistribution('x', d1, d1)) + + d1 = Symbol("d1", positive=True, integer=False) + raises(ValueError, lambda: FDistribution('x', d1, d1)) + + d1 = Symbol("d1", positive=True) + d2 = Symbol("d2", nonpositive=True) + raises(ValueError, lambda: FDistribution('x', d1, d2)) + + d2 = Symbol("d2", positive=True, integer=False) + raises(ValueError, lambda: FDistribution('x', d1, d2)) + + +def test_fisher_z(): + d1 = Symbol("d1", positive=True) + d2 = Symbol("d2", positive=True) + + X = FisherZ("x", d1, d2) + assert density(X)(x) == (2*d1**(d1/2)*d2**(d2/2)*(d1*exp(2*x) + d2) + **(-d1/2 - d2/2)*exp(d1*x)/beta(d1/2, d2/2)) + +def test_frechet(): + a = Symbol("a", positive=True) + s = Symbol("s", positive=True) + m = Symbol("m", real=True) + + X = Frechet("x", a, s=s, m=m) + assert density(X)(x) == a*((x - m)/s)**(-a - 1)*exp(-((x - m)/s)**(-a))/s + assert cdf(X)(x) == Piecewise((exp(-((-m + x)/s)**(-a)), m <= x), (0, True)) + +@slow +def test_gamma(): + k = Symbol("k", positive=True) + theta = Symbol("theta", positive=True) + + X = Gamma('x', k, theta) + + # Tests characteristic function + assert characteristic_function(X)(x) == ((-I*theta*x + 1)**(-k)) + + assert density(X)(x) == x**(k - 1)*theta**(-k)*exp(-x/theta)/gamma(k) + assert cdf(X, meijerg=True)(z) == Piecewise( + (-k*lowergamma(k, 0)/gamma(k + 1) + + k*lowergamma(k, z/theta)/gamma(k + 1), z >= 0), + (0, True)) + + # assert simplify(variance(X)) == k*theta**2 # handled numerically below + assert E(X) == moment(X, 1) + + k, theta = symbols('k theta', positive=True) + X = Gamma('x', k, theta) + assert E(X) == k*theta + assert variance(X) == k*theta**2 + assert skewness(X).expand() == 2/sqrt(k) + assert kurtosis(X).expand() == 3 + 6/k + + Y = Gamma('y', 2*k, 3*theta) + assert coskewness(X, theta*X + Y, k*X + Y).simplify() == \ + 2*531441**(-k)*sqrt(k)*theta*(3*3**(12*k) - 2*531441**k) \ + /(sqrt(k**2 + 18)*sqrt(theta**2 + 18)) + +def test_gamma_inverse(): + a = Symbol("a", positive=True) + b = Symbol("b", positive=True) + X = GammaInverse("x", a, b) + assert density(X)(x) == x**(-a - 1)*b**a*exp(-b/x)/gamma(a) + assert cdf(X)(x) == Piecewise((uppergamma(a, b/x)/gamma(a), x > 0), (0, True)) + assert characteristic_function(X)(x) == 2 * (-I*b*x)**(a/2) \ + * besselk(a, 2*sqrt(b)*sqrt(-I*x))/gamma(a) + raises(NotImplementedError, lambda: moment_generating_function(X)) + +def test_gompertz(): + b = Symbol("b", positive=True) + eta = Symbol("eta", positive=True) + + X = Gompertz("x", b, eta) + + assert density(X)(x) == b*eta*exp(eta)*exp(b*x)*exp(-eta*exp(b*x)) + assert cdf(X)(x) == 1 - exp(eta)*exp(-eta*exp(b*x)) + assert diff(cdf(X)(x), x) == density(X)(x) + + +def test_gumbel(): + beta = Symbol("beta", positive=True) + mu = Symbol("mu") + x = Symbol("x") + y = Symbol("y") + X = Gumbel("x", beta, mu) + Y = Gumbel("y", beta, mu, minimum=True) + assert density(X)(x).expand() == \ + exp(mu/beta)*exp(-x/beta)*exp(-exp(mu/beta)*exp(-x/beta))/beta + assert density(Y)(y).expand() == \ + exp(-mu/beta)*exp(y/beta)*exp(-exp(-mu/beta)*exp(y/beta))/beta + assert cdf(X)(x).expand() == \ + exp(-exp(mu/beta)*exp(-x/beta)) + assert characteristic_function(X)(x) == exp(I*mu*x)*gamma(-I*beta*x + 1) + +def test_kumaraswamy(): + a = Symbol("a", positive=True) + b = Symbol("b", positive=True) + + X = Kumaraswamy("x", a, b) + assert density(X)(x) == x**(a - 1)*a*b*(-x**a + 1)**(b - 1) + assert cdf(X)(x) == Piecewise((0, x < 0), + (-(-x**a + 1)**b + 1, x <= 1), + (1, True)) + + +def test_laplace(): + mu = Symbol("mu") + b = Symbol("b", positive=True) + + X = Laplace('x', mu, b) + + #Tests characteristic_function + assert characteristic_function(X)(x) == (exp(I*mu*x)/(b**2*x**2 + 1)) + + assert density(X)(x) == exp(-Abs(x - mu)/b)/(2*b) + assert cdf(X)(x) == Piecewise((exp((-mu + x)/b)/2, mu > x), + (-exp((mu - x)/b)/2 + 1, True)) + X = Laplace('x', [1, 2], [[1, 0], [0, 1]]) + assert isinstance(pspace(X).distribution, MultivariateLaplaceDistribution) + +def test_levy(): + mu = Symbol("mu", real=True) + c = Symbol("c", positive=True) + + X = Levy('x', mu, c) + assert X.pspace.domain.set == Interval(mu, oo) + assert density(X)(x) == sqrt(c/(2*pi))*exp(-c/(2*(x - mu)))/((x - mu)**(S.One + S.Half)) + assert cdf(X)(x) == erfc(sqrt(c/(2*(x - mu)))) + + raises(NotImplementedError, lambda: moment_generating_function(X)) + mu = Symbol("mu", real=False) + raises(ValueError, lambda: Levy('x',mu,c)) + + c = Symbol("c", nonpositive=True) + raises(ValueError, lambda: Levy('x',mu,c)) + + mu = Symbol("mu", real=True) + raises(ValueError, lambda: Levy('x',mu,c)) + +def test_logcauchy(): + mu = Symbol("mu", positive=True) + sigma = Symbol("sigma", positive=True) + + X = LogCauchy("x", mu, sigma) + + assert density(X)(x) == sigma/(x*pi*(sigma**2 + (-mu + log(x))**2)) + assert cdf(X)(x) == atan((log(x) - mu)/sigma)/pi + S.Half + + +def test_logistic(): + mu = Symbol("mu", real=True) + s = Symbol("s", positive=True) + p = Symbol("p", positive=True) + + X = Logistic('x', mu, s) + + #Tests characteristics_function + assert characteristic_function(X)(x) == \ + (Piecewise((pi*s*x*exp(I*mu*x)/sinh(pi*s*x), Ne(x, 0)), (1, True))) + + assert density(X)(x) == exp((-x + mu)/s)/(s*(exp((-x + mu)/s) + 1)**2) + assert cdf(X)(x) == 1/(exp((mu - x)/s) + 1) + assert quantile(X)(p) == mu - s*log(-S.One + 1/p) + +def test_loglogistic(): + a, b = symbols('a b') + assert LogLogistic('x', a, b) + + a = Symbol('a', negative=True) + b = Symbol('b', positive=True) + raises(ValueError, lambda: LogLogistic('x', a, b)) + + a = Symbol('a', positive=True) + b = Symbol('b', negative=True) + raises(ValueError, lambda: LogLogistic('x', a, b)) + + a, b, z, p = symbols('a b z p', positive=True) + X = LogLogistic('x', a, b) + assert density(X)(z) == b*(z/a)**(b - 1)/(a*((z/a)**b + 1)**2) + assert cdf(X)(z) == 1/(1 + (z/a)**(-b)) + assert quantile(X)(p) == a*(p/(1 - p))**(1/b) + + # Expectation + assert E(X) == Piecewise((S.NaN, b <= 1), (pi*a/(b*sin(pi/b)), True)) + b = symbols('b', prime=True) # b > 1 + X = LogLogistic('x', a, b) + assert E(X) == pi*a/(b*sin(pi/b)) + X = LogLogistic('x', 1, 2) + assert median(X) == FiniteSet(1) + +def test_logitnormal(): + mu = Symbol('mu', real=True) + s = Symbol('s', positive=True) + X = LogitNormal('x', mu, s) + x = Symbol('x') + + assert density(X)(x) == sqrt(2)*exp(-(-mu + log(x/(1 - x)))**2/(2*s**2))/(2*sqrt(pi)*s*x*(1 - x)) + assert cdf(X)(x) == erf(sqrt(2)*(-mu + log(x/(1 - x)))/(2*s))/2 + S(1)/2 + +def test_lognormal(): + mean = Symbol('mu', real=True) + std = Symbol('sigma', positive=True) + X = LogNormal('x', mean, std) + # The sympy integrator can't do this too well + #assert E(X) == exp(mean+std**2/2) + #assert variance(X) == (exp(std**2)-1) * exp(2*mean + std**2) + + # The sympy integrator can't do this too well + #assert E(X) == + raises(NotImplementedError, lambda: moment_generating_function(X)) + mu = Symbol("mu", real=True) + sigma = Symbol("sigma", positive=True) + + X = LogNormal('x', mu, sigma) + assert density(X)(x) == (sqrt(2)*exp(-(-mu + log(x))**2 + /(2*sigma**2))/(2*x*sqrt(pi)*sigma)) + # Tests cdf + assert cdf(X)(x) == Piecewise( + (erf(sqrt(2)*(-mu + log(x))/(2*sigma))/2 + + S(1)/2, x > 0), (0, True)) + + X = LogNormal('x', 0, 1) # Mean 0, standard deviation 1 + assert density(X)(x) == sqrt(2)*exp(-log(x)**2/2)/(2*x*sqrt(pi)) + + +def test_Lomax(): + a, l = symbols('a, l', negative=True) + raises(ValueError, lambda: Lomax('X', a, l)) + a, l = symbols('a, l', real=False) + raises(ValueError, lambda: Lomax('X', a, l)) + + a, l = symbols('a, l', positive=True) + X = Lomax('X', a, l) + assert X.pspace.domain.set == Interval(0, oo) + assert density(X)(x) == a*(1 + x/l)**(-a - 1)/l + assert cdf(X)(x) == Piecewise((1 - (1 + x/l)**(-a), x >= 0), (0, True)) + a = 3 + X = Lomax('X', a, l) + assert E(X) == l/2 + assert median(X) == FiniteSet(l*(-1 + 2**Rational(1, 3))) + assert variance(X) == 3*l**2/4 + + +def test_maxwell(): + a = Symbol("a", positive=True) + + X = Maxwell('x', a) + + assert density(X)(x) == (sqrt(2)*x**2*exp(-x**2/(2*a**2))/ + (sqrt(pi)*a**3)) + assert E(X) == 2*sqrt(2)*a/sqrt(pi) + assert variance(X) == -8*a**2/pi + 3*a**2 + assert cdf(X)(x) == erf(sqrt(2)*x/(2*a)) - sqrt(2)*x*exp(-x**2/(2*a**2))/(sqrt(pi)*a) + assert diff(cdf(X)(x), x) == density(X)(x) + + +@slow +def test_Moyal(): + mu = Symbol('mu',real=False) + sigma = Symbol('sigma', positive=True) + raises(ValueError, lambda: Moyal('M',mu, sigma)) + + mu = Symbol('mu', real=True) + sigma = Symbol('sigma', negative=True) + raises(ValueError, lambda: Moyal('M',mu, sigma)) + + sigma = Symbol('sigma', positive=True) + M = Moyal('M', mu, sigma) + assert density(M)(z) == sqrt(2)*exp(-exp((mu - z)/sigma)/2 + - (-mu + z)/(2*sigma))/(2*sqrt(pi)*sigma) + assert cdf(M)(z).simplify() == 1 - erf(sqrt(2)*exp((mu - z)/(2*sigma))/2) + assert characteristic_function(M)(z) == 2**(-I*sigma*z)*exp(I*mu*z) \ + *gamma(-I*sigma*z + Rational(1, 2))/sqrt(pi) + assert E(M) == mu + EulerGamma*sigma + sigma*log(2) + assert moment_generating_function(M)(z) == 2**(-sigma*z)*exp(mu*z) \ + *gamma(-sigma*z + Rational(1, 2))/sqrt(pi) + + +def test_nakagami(): + mu = Symbol("mu", positive=True) + omega = Symbol("omega", positive=True) + + X = Nakagami('x', mu, omega) + assert density(X)(x) == (2*x**(2*mu - 1)*mu**mu*omega**(-mu) + *exp(-x**2*mu/omega)/gamma(mu)) + assert simplify(E(X)) == (sqrt(mu)*sqrt(omega) + *gamma(mu + S.Half)/gamma(mu + 1)) + assert simplify(variance(X)) == ( + omega - omega*gamma(mu + S.Half)**2/(gamma(mu)*gamma(mu + 1))) + assert cdf(X)(x) == Piecewise( + (lowergamma(mu, mu*x**2/omega)/gamma(mu), x > 0), + (0, True)) + X = Nakagami('x', 1, 1) + assert median(X) == FiniteSet(sqrt(log(2))) + +def test_gaussian_inverse(): + # test for symbolic parameters + a, b = symbols('a b') + assert GaussianInverse('x', a, b) + + # Inverse Gaussian distribution is also known as Wald distribution + # `GaussianInverse` can also be referred by the name `Wald` + a, b, z = symbols('a b z') + X = Wald('x', a, b) + assert density(X)(z) == sqrt(2)*sqrt(b/z**3)*exp(-b*(-a + z)**2/(2*a**2*z))/(2*sqrt(pi)) + + a, b = symbols('a b', positive=True) + z = Symbol('z', positive=True) + + X = GaussianInverse('x', a, b) + assert density(X)(z) == sqrt(2)*sqrt(b)*sqrt(z**(-3))*exp(-b*(-a + z)**2/(2*a**2*z))/(2*sqrt(pi)) + assert E(X) == a + assert variance(X).expand() == a**3/b + assert cdf(X)(z) == (S.Half - erf(sqrt(2)*sqrt(b)*(1 + z/a)/(2*sqrt(z)))/2)*exp(2*b/a) +\ + erf(sqrt(2)*sqrt(b)*(-1 + z/a)/(2*sqrt(z)))/2 + S.Half + + a = symbols('a', nonpositive=True) + raises(ValueError, lambda: GaussianInverse('x', a, b)) + + a = symbols('a', positive=True) + b = symbols('b', nonpositive=True) + raises(ValueError, lambda: GaussianInverse('x', a, b)) + +def test_pareto(): + xm, beta = symbols('xm beta', positive=True) + alpha = beta + 5 + X = Pareto('x', xm, alpha) + + dens = density(X) + + #Tests cdf function + assert cdf(X)(x) == \ + Piecewise((-x**(-beta - 5)*xm**(beta + 5) + 1, x >= xm), (0, True)) + + #Tests characteristic_function + assert characteristic_function(X)(x) == \ + ((-I*x*xm)**(beta + 5)*(beta + 5)*uppergamma(-beta - 5, -I*x*xm)) + + assert dens(x) == x**(-(alpha + 1))*xm**(alpha)*(alpha) + + assert simplify(E(X)) == alpha*xm/(alpha-1) + + # computation of taylor series for MGF still too slow + #assert simplify(variance(X)) == xm**2*alpha / ((alpha-1)**2*(alpha-2)) + + +def test_pareto_numeric(): + xm, beta = 3, 2 + alpha = beta + 5 + X = Pareto('x', xm, alpha) + + assert E(X) == alpha*xm/S(alpha - 1) + assert variance(X) == xm**2*alpha / S((alpha - 1)**2*(alpha - 2)) + assert median(X) == FiniteSet(3*2**Rational(1, 7)) + # Skewness tests too slow. Try shortcutting function? + + +def test_PowerFunction(): + alpha = Symbol("alpha", nonpositive=True) + a, b = symbols('a, b', real=True) + raises (ValueError, lambda: PowerFunction('x', alpha, a, b)) + + a, b = symbols('a, b', real=False) + raises (ValueError, lambda: PowerFunction('x', alpha, a, b)) + + alpha = Symbol("alpha", positive=True) + a, b = symbols('a, b', real=True) + raises (ValueError, lambda: PowerFunction('x', alpha, 5, 2)) + + X = PowerFunction('X', 2, a, b) + assert density(X)(z) == (-2*a + 2*z)/(-a + b)**2 + assert cdf(X)(z) == Piecewise((a**2/(a**2 - 2*a*b + b**2) - + 2*a*z/(a**2 - 2*a*b + b**2) + z**2/(a**2 - 2*a*b + b**2), a <= z), (0, True)) + + X = PowerFunction('X', 2, 0, 1) + assert density(X)(z) == 2*z + assert cdf(X)(z) == Piecewise((z**2, z >= 0), (0,True)) + assert E(X) == Rational(2,3) + assert P(X < 0) == 0 + assert P(X < 1) == 1 + assert median(X) == FiniteSet(1/sqrt(2)) + +def test_raised_cosine(): + mu = Symbol("mu", real=True) + s = Symbol("s", positive=True) + + X = RaisedCosine("x", mu, s) + + assert pspace(X).domain.set == Interval(mu - s, mu + s) + #Tests characteristics_function + assert characteristic_function(X)(x) == \ + Piecewise((exp(-I*pi*mu/s)/2, Eq(x, -pi/s)), (exp(I*pi*mu/s)/2, Eq(x, pi/s)), (pi**2*exp(I*mu*x)*sin(s*x)/(s*x*(-s**2*x**2 + pi**2)), True)) + + assert density(X)(x) == (Piecewise(((cos(pi*(x - mu)/s) + 1)/(2*s), + And(x <= mu + s, mu - s <= x)), (0, True))) + + +def test_rayleigh(): + sigma = Symbol("sigma", positive=True) + + X = Rayleigh('x', sigma) + + #Tests characteristic_function + assert characteristic_function(X)(x) == (-sqrt(2)*sqrt(pi)*sigma*x*(erfi(sqrt(2)*sigma*x/2) - I)*exp(-sigma**2*x**2/2)/2 + 1) + + assert density(X)(x) == x*exp(-x**2/(2*sigma**2))/sigma**2 + assert E(X) == sqrt(2)*sqrt(pi)*sigma/2 + assert variance(X) == -pi*sigma**2/2 + 2*sigma**2 + assert cdf(X)(x) == 1 - exp(-x**2/(2*sigma**2)) + assert diff(cdf(X)(x), x) == density(X)(x) + +def test_reciprocal(): + a = Symbol("a", real=True) + b = Symbol("b", real=True) + + X = Reciprocal('x', a, b) + assert density(X)(x) == 1/(x*(-log(a) + log(b))) + assert cdf(X)(x) == Piecewise((log(a)/(log(a) - log(b)) - log(x)/(log(a) - log(b)), a <= x), (0, True)) + X = Reciprocal('x', 5, 30) + + assert E(X) == 25/(log(30) - log(5)) + assert P(X < 4) == S.Zero + assert P(X < 20) == log(20) / (log(30) - log(5)) - log(5) / (log(30) - log(5)) + assert cdf(X)(10) == log(10) / (log(30) - log(5)) - log(5) / (log(30) - log(5)) + + a = symbols('a', nonpositive=True) + raises(ValueError, lambda: Reciprocal('x', a, b)) + + a = symbols('a', positive=True) + b = symbols('b', positive=True) + raises(ValueError, lambda: Reciprocal('x', a + b, a)) + +def test_shiftedgompertz(): + b = Symbol("b", positive=True) + eta = Symbol("eta", positive=True) + X = ShiftedGompertz("x", b, eta) + assert density(X)(x) == b*(eta*(1 - exp(-b*x)) + 1)*exp(-b*x)*exp(-eta*exp(-b*x)) + + +def test_studentt(): + nu = Symbol("nu", positive=True) + + X = StudentT('x', nu) + assert density(X)(x) == (1 + x**2/nu)**(-nu/2 - S.Half)/(sqrt(nu)*beta(S.Half, nu/2)) + assert cdf(X)(x) == S.Half + x*gamma(nu/2 + S.Half)*hyper((S.Half, nu/2 + S.Half), + (Rational(3, 2),), -x**2/nu)/(sqrt(pi)*sqrt(nu)*gamma(nu/2)) + raises(NotImplementedError, lambda: moment_generating_function(X)) + +def test_trapezoidal(): + a = Symbol("a", real=True) + b = Symbol("b", real=True) + c = Symbol("c", real=True) + d = Symbol("d", real=True) + + X = Trapezoidal('x', a, b, c, d) + assert density(X)(x) == Piecewise(((-2*a + 2*x)/((-a + b)*(-a - b + c + d)), (a <= x) & (x < b)), + (2/(-a - b + c + d), (b <= x) & (x < c)), + ((2*d - 2*x)/((-c + d)*(-a - b + c + d)), (c <= x) & (x <= d)), + (0, True)) + + X = Trapezoidal('x', 0, 1, 2, 3) + assert E(X) == Rational(3, 2) + assert variance(X) == Rational(5, 12) + assert P(X < 2) == Rational(3, 4) + assert median(X) == FiniteSet(Rational(3, 2)) + +def test_triangular(): + a = Symbol("a") + b = Symbol("b") + c = Symbol("c") + + X = Triangular('x', a, b, c) + assert pspace(X).domain.set == Interval(a, b) + assert str(density(X)(x)) == ("Piecewise(((-2*a + 2*x)/((-a + b)*(-a + c)), (a <= x) & (c > x)), " + "(2/(-a + b), Eq(c, x)), ((2*b - 2*x)/((-a + b)*(b - c)), (b >= x) & (c < x)), (0, True))") + + #Tests moment_generating_function + assert moment_generating_function(X)(x).expand() == \ + ((-2*(-a + b)*exp(c*x) + 2*(-a + c)*exp(b*x) + 2*(b - c)*exp(a*x))/(x**2*(-a + b)*(-a + c)*(b - c))).expand() + assert str(characteristic_function(X)(x)) == \ + '(2*(-a + b)*exp(I*c*x) - 2*(-a + c)*exp(I*b*x) - 2*(b - c)*exp(I*a*x))/(x**2*(-a + b)*(-a + c)*(b - c))' + +def test_quadratic_u(): + a = Symbol("a", real=True) + b = Symbol("b", real=True) + + X = QuadraticU("x", a, b) + Y = QuadraticU("x", 1, 2) + + assert pspace(X).domain.set == Interval(a, b) + # Tests _moment_generating_function + assert moment_generating_function(Y)(1) == -15*exp(2) + 27*exp(1) + assert moment_generating_function(Y)(2) == -9*exp(4)/2 + 21*exp(2)/2 + + assert characteristic_function(Y)(1) == 3*I*(-1 + 4*I)*exp(I*exp(2*I)) + assert density(X)(x) == (Piecewise((12*(x - a/2 - b/2)**2/(-a + b)**3, + And(x <= b, a <= x)), (0, True))) + + +def test_uniform(): + l = Symbol('l', real=True) + w = Symbol('w', positive=True) + X = Uniform('x', l, l + w) + + assert E(X) == l + w/2 + assert variance(X).expand() == w**2/12 + + # With numbers all is well + X = Uniform('x', 3, 5) + assert P(X < 3) == 0 and P(X > 5) == 0 + assert P(X < 4) == P(X > 4) == S.Half + assert median(X) == FiniteSet(4) + + z = Symbol('z') + p = density(X)(z) + assert p.subs(z, 3.7) == S.Half + assert p.subs(z, -1) == 0 + assert p.subs(z, 6) == 0 + + c = cdf(X) + assert c(2) == 0 and c(3) == 0 + assert c(Rational(7, 2)) == Rational(1, 4) + assert c(5) == 1 and c(6) == 1 + + +@XFAIL +@slow +def test_uniform_P(): + """ This stopped working because SingleContinuousPSpace.compute_density no + longer calls integrate on a DiracDelta but rather just solves directly. + integrate used to call UniformDistribution.expectation which special-cased + subsed out the Min and Max terms that Uniform produces + + I decided to regress on this class for general cleanliness (and I suspect + speed) of the algorithm. + """ + l = Symbol('l', real=True) + w = Symbol('w', positive=True) + X = Uniform('x', l, l + w) + assert P(X < l) == 0 and P(X > l + w) == 0 + + +def test_uniformsum(): + n = Symbol("n", integer=True) + _k = Dummy("k") + x = Symbol("x") + + X = UniformSum('x', n) + res = Sum((-1)**_k*(-_k + x)**(n - 1)*binomial(n, _k), (_k, 0, floor(x)))/factorial(n - 1) + assert density(X)(x).dummy_eq(res) + + #Tests set functions + assert X.pspace.domain.set == Interval(0, n) + + #Tests the characteristic_function + assert characteristic_function(X)(x) == (-I*(exp(I*x) - 1)/x)**n + + #Tests the moment_generating_function + assert moment_generating_function(X)(x) == ((exp(x) - 1)/x)**n + + +def test_von_mises(): + mu = Symbol("mu") + k = Symbol("k", positive=True) + + X = VonMises("x", mu, k) + assert density(X)(x) == exp(k*cos(x - mu))/(2*pi*besseli(0, k)) + + +def test_weibull(): + a, b = symbols('a b', positive=True) + # FIXME: simplify(E(X)) seems to hang without extended_positive=True + # On a Linux machine this had a rapid memory leak... + # a, b = symbols('a b', positive=True) + X = Weibull('x', a, b) + + assert E(X).expand() == a * gamma(1 + 1/b) + assert variance(X).expand() == (a**2 * gamma(1 + 2/b) - E(X)**2).expand() + assert simplify(skewness(X)) == (2*gamma(1 + 1/b)**3 - 3*gamma(1 + 1/b)*gamma(1 + 2/b) + gamma(1 + 3/b))/(-gamma(1 + 1/b)**2 + gamma(1 + 2/b))**Rational(3, 2) + assert simplify(kurtosis(X)) == (-3*gamma(1 + 1/b)**4 +\ + 6*gamma(1 + 1/b)**2*gamma(1 + 2/b) - 4*gamma(1 + 1/b)*gamma(1 + 3/b) + gamma(1 + 4/b))/(gamma(1 + 1/b)**2 - gamma(1 + 2/b))**2 + +def test_weibull_numeric(): + # Test for integers and rationals + a = 1 + bvals = [S.Half, 1, Rational(3, 2), 5] + for b in bvals: + X = Weibull('x', a, b) + assert simplify(E(X)) == expand_func(a * gamma(1 + 1/S(b))) + assert simplify(variance(X)) == simplify( + a**2 * gamma(1 + 2/S(b)) - E(X)**2) + # Not testing Skew... it's slow with int/frac values > 3/2 + + +def test_wignersemicircle(): + R = Symbol("R", positive=True) + + X = WignerSemicircle('x', R) + assert pspace(X).domain.set == Interval(-R, R) + assert density(X)(x) == 2*sqrt(-x**2 + R**2)/(pi*R**2) + assert E(X) == 0 + + + #Tests ChiNoncentralDistribution + assert characteristic_function(X)(x) == \ + Piecewise((2*besselj(1, R*x)/(R*x), Ne(x, 0)), (1, True)) + + +def test_input_value_assertions(): + a, b = symbols('a b') + p, q = symbols('p q', positive=True) + m, n = symbols('m n', positive=False, real=True) + + raises(ValueError, lambda: Normal('x', 3, 0)) + raises(ValueError, lambda: Normal('x', m, n)) + Normal('X', a, p) # No error raised + raises(ValueError, lambda: Exponential('x', m)) + Exponential('Ex', p) # No error raised + for fn in [Pareto, Weibull, Beta, Gamma]: + raises(ValueError, lambda: fn('x', m, p)) + raises(ValueError, lambda: fn('x', p, n)) + fn('x', p, q) # No error raised + + +def test_unevaluated(): + X = Normal('x', 0, 1) + k = Dummy('k') + expr1 = Integral(sqrt(2)*k*exp(-k**2/2)/(2*sqrt(pi)), (k, -oo, oo)) + expr2 = Integral(sqrt(2)*exp(-k**2/2)/(2*sqrt(pi)), (k, 0, oo)) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert E(X, evaluate=False).rewrite(Integral).dummy_eq(expr1) + assert E(X + 1, evaluate=False).rewrite(Integral).dummy_eq(expr1 + 1) + assert P(X > 0, evaluate=False).rewrite(Integral).dummy_eq(expr2) + + assert P(X > 0, X**2 < 1) == S.Half + + +def test_probability_unevaluated(): + T = Normal('T', 30, 3) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert type(P(T > 33, evaluate=False)) == Probability + + +def test_density_unevaluated(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 2) + assert isinstance(density(X+Y, evaluate=False)(z), Integral) + + +def test_NormalDistribution(): + nd = NormalDistribution(0, 1) + x = Symbol('x') + assert nd.cdf(x) == erf(sqrt(2)*x/2)/2 + S.Half + assert nd.expectation(1, x) == 1 + assert nd.expectation(x, x) == 0 + assert nd.expectation(x**2, x) == 1 + #Test issue 10076 + a = SingleContinuousPSpace(x, NormalDistribution(2, 4)) + _z = Dummy('_z') + + expected1 = Integral(sqrt(2)*exp(-(_z - 2)**2/32)/(8*sqrt(pi)),(_z, -oo, 1)) + assert a.probability(x < 1, evaluate=False).dummy_eq(expected1) is True + + expected2 = Integral(sqrt(2)*exp(-(_z - 2)**2/32)/(8*sqrt(pi)),(_z, 1, oo)) + assert a.probability(x > 1, evaluate=False).dummy_eq(expected2) is True + + b = SingleContinuousPSpace(x, NormalDistribution(1, 9)) + + expected3 = Integral(sqrt(2)*exp(-(_z - 1)**2/162)/(18*sqrt(pi)),(_z, 6, oo)) + assert b.probability(x > 6, evaluate=False).dummy_eq(expected3) is True + + expected4 = Integral(sqrt(2)*exp(-(_z - 1)**2/162)/(18*sqrt(pi)),(_z, -oo, 6)) + assert b.probability(x < 6, evaluate=False).dummy_eq(expected4) is True + + +def test_random_parameters(): + mu = Normal('mu', 2, 3) + meas = Normal('T', mu, 1) + assert density(meas, evaluate=False)(z) + assert isinstance(pspace(meas), CompoundPSpace) + X = Normal('x', [1, 2], [[1, 0], [0, 1]]) + assert isinstance(pspace(X).distribution, MultivariateNormalDistribution) + assert density(meas)(z).simplify() == sqrt(5)*exp(-z**2/20 + z/5 - S(1)/5)/(10*sqrt(pi)) + + +def test_random_parameters_given(): + mu = Normal('mu', 2, 3) + meas = Normal('T', mu, 1) + assert given(meas, Eq(mu, 5)) == Normal('T', 5, 1) + + +def test_conjugate_priors(): + mu = Normal('mu', 2, 3) + x = Normal('x', mu, 1) + assert isinstance(simplify(density(mu, Eq(x, y), evaluate=False)(z)), + Mul) + + +def test_difficult_univariate(): + """ Since using solve in place of deltaintegrate we're able to perform + substantially more complex density computations on single continuous random + variables """ + x = Normal('x', 0, 1) + assert density(x**3) + assert density(exp(x**2)) + assert density(log(x)) + + +def test_issue_10003(): + X = Exponential('x', 3) + G = Gamma('g', 1, 2) + assert P(X < -1) is S.Zero + assert P(G < -1) is S.Zero + + +def test_precomputed_cdf(): + x = symbols("x", real=True) + mu = symbols("mu", real=True) + sigma, xm, alpha = symbols("sigma xm alpha", positive=True) + n = symbols("n", integer=True, positive=True) + distribs = [ + Normal("X", mu, sigma), + Pareto("P", xm, alpha), + ChiSquared("C", n), + Exponential("E", sigma), + # LogNormal("L", mu, sigma), + ] + for X in distribs: + compdiff = cdf(X)(x) - simplify(X.pspace.density.compute_cdf()(x)) + compdiff = simplify(compdiff.rewrite(erfc)) + assert compdiff == 0 + + +@slow +def test_precomputed_characteristic_functions(): + import mpmath + + def test_cf(dist, support_lower_limit, support_upper_limit): + pdf = density(dist) + t = Symbol('t') + + # first function is the hardcoded CF of the distribution + cf1 = lambdify([t], characteristic_function(dist)(t), 'mpmath') + + # second function is the Fourier transform of the density function + f = lambdify([x, t], pdf(x)*exp(I*x*t), 'mpmath') + cf2 = lambda t: mpmath.quad(lambda x: f(x, t), [support_lower_limit, support_upper_limit], maxdegree=10) + + # compare the two functions at various points + for test_point in [2, 5, 8, 11]: + n1 = cf1(test_point) + n2 = cf2(test_point) + + assert abs(re(n1) - re(n2)) < 1e-12 + assert abs(im(n1) - im(n2)) < 1e-12 + + test_cf(Beta('b', 1, 2), 0, 1) + test_cf(Chi('c', 3), 0, mpmath.inf) + test_cf(ChiSquared('c', 2), 0, mpmath.inf) + test_cf(Exponential('e', 6), 0, mpmath.inf) + test_cf(Logistic('l', 1, 2), -mpmath.inf, mpmath.inf) + test_cf(Normal('n', -1, 5), -mpmath.inf, mpmath.inf) + test_cf(RaisedCosine('r', 3, 1), 2, 4) + test_cf(Rayleigh('r', 0.5), 0, mpmath.inf) + test_cf(Uniform('u', -1, 1), -1, 1) + test_cf(WignerSemicircle('w', 3), -3, 3) + + +def test_long_precomputed_cdf(): + x = symbols("x", real=True) + distribs = [ + Arcsin("A", -5, 9), + Dagum("D", 4, 10, 3), + Erlang("E", 14, 5), + Frechet("F", 2, 6, -3), + Gamma("G", 2, 7), + GammaInverse("GI", 3, 5), + Kumaraswamy("K", 6, 8), + Laplace("LA", -5, 4), + Logistic("L", -6, 7), + Nakagami("N", 2, 7), + StudentT("S", 4) + ] + for distr in distribs: + for _ in range(5): + assert tn(diff(cdf(distr)(x), x), density(distr)(x), x, a=0, b=0, c=1, d=0) + + US = UniformSum("US", 5) + pdf01 = density(US)(x).subs(floor(x), 0).doit() # pdf on (0, 1) + cdf01 = cdf(US, evaluate=False)(x).subs(floor(x), 0).doit() # cdf on (0, 1) + assert tn(diff(cdf01, x), pdf01, x, a=0, b=0, c=1, d=0) + + +def test_issue_13324(): + X = Uniform('X', 0, 1) + assert E(X, X > S.Half) == Rational(3, 4) + assert E(X, X > 0) == S.Half + +def test_issue_20756(): + X = Uniform('X', -1, +1) + Y = Uniform('Y', -1, +1) + assert E(X * Y) == S.Zero + assert E(X * ((Y + 1) - 1)) == S.Zero + assert E(Y * (X*(X + 1) - X*X)) == S.Zero + +def test_FiniteSet_prob(): + E = Exponential('E', 3) + N = Normal('N', 5, 7) + assert P(Eq(E, 1)) is S.Zero + assert P(Eq(N, 2)) is S.Zero + assert P(Eq(N, x)) is S.Zero + +def test_prob_neq(): + E = Exponential('E', 4) + X = ChiSquared('X', 4) + assert P(Ne(E, 2)) == 1 + assert P(Ne(X, 4)) == 1 + assert P(Ne(X, 4)) == 1 + assert P(Ne(X, 5)) == 1 + assert P(Ne(E, x)) == 1 + +def test_union(): + N = Normal('N', 3, 2) + assert simplify(P(N**2 - N > 2)) == \ + -erf(sqrt(2))/2 - erfc(sqrt(2)/4)/2 + Rational(3, 2) + assert simplify(P(N**2 - 4 > 0)) == \ + -erf(5*sqrt(2)/4)/2 - erfc(sqrt(2)/4)/2 + Rational(3, 2) + +def test_Or(): + N = Normal('N', 0, 1) + assert simplify(P(Or(N > 2, N < 1))) == \ + -erf(sqrt(2))/2 - erfc(sqrt(2)/2)/2 + Rational(3, 2) + assert P(Or(N < 0, N < 1)) == P(N < 1) + assert P(Or(N > 0, N < 0)) == 1 + + +def test_conditional_eq(): + E = Exponential('E', 1) + assert P(Eq(E, 1), Eq(E, 1)) == 1 + assert P(Eq(E, 1), Eq(E, 2)) == 0 + assert P(E > 1, Eq(E, 2)) == 1 + assert P(E < 1, Eq(E, 2)) == 0 + +def test_ContinuousDistributionHandmade(): + x = Symbol('x') + z = Dummy('z') + dens = Lambda(x, Piecewise((S.Half, (0<=x)&(x<1)), (0, (x>=1)&(x<2)), + (S.Half, (x>=2)&(x<3)), (0, True))) + dens = ContinuousDistributionHandmade(dens, set=Interval(0, 3)) + space = SingleContinuousPSpace(z, dens) + assert dens.pdf == Lambda(x, Piecewise((S(1)/2, (x >= 0) & (x < 1)), + (0, (x >= 1) & (x < 2)), (S(1)/2, (x >= 2) & (x < 3)), (0, True))) + assert median(space.value) == Interval(1, 2) + assert E(space.value) == Rational(3, 2) + assert variance(space.value) == Rational(13, 12) + + +def test_issue_16318(): + # test compute_expectation function of the SingleContinuousDomain + N = SingleContinuousDomain(x, Interval(0, 1)) + raises(ValueError, lambda: SingleContinuousDomain.compute_expectation(N, x+1, {x, y})) + +def test_compute_density(): + X = Normal('X', 0, Symbol("sigma")**2) + raises(ValueError, lambda: density(X**5 + X)) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_discrete_rv.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_discrete_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..39650a1a08e42840aaf8b06c1eb6e60c92a57f23 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_discrete_rv.py @@ -0,0 +1,312 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besseli +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.zeta_functions import zeta +from sympy.sets.sets import FiniteSet +from sympy.simplify.simplify import simplify +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Eq, Ne +from sympy.functions.elementary.exponential import exp +from sympy.logic.boolalg import Or +from sympy.sets.fancysets import Range +from sympy.stats import (P, E, variance, density, characteristic_function, + where, moment_generating_function, skewness, cdf, + kurtosis, coskewness) +from sympy.stats.drv_types import (PoissonDistribution, GeometricDistribution, + FlorySchulz, Poisson, Geometric, Hermite, Logarithmic, + NegativeBinomial, Skellam, YuleSimon, Zeta, + DiscreteRV) +from sympy.testing.pytest import slow, nocache_fail, raises, skip +from sympy.stats.symbolic_probability import Expectation +from sympy.functions.combinatorial.factorials import FallingFactorial + +x = Symbol('x') + + +def test_PoissonDistribution(): + l = 3 + p = PoissonDistribution(l) + assert abs(p.cdf(10).evalf() - 1) < .001 + assert abs(p.cdf(10.4).evalf() - 1) < .001 + assert p.expectation(x, x) == l + assert p.expectation(x**2, x) - p.expectation(x, x)**2 == l + + +def test_Poisson(): + l = 3 + x = Poisson('x', l) + assert E(x) == l + assert E(2*x) == 2*l + assert variance(x) == l + assert density(x) == PoissonDistribution(l) + assert isinstance(E(x, evaluate=False), Expectation) + assert isinstance(E(2*x, evaluate=False), Expectation) + # issue 8248 + assert x.pspace.compute_expectation(1) == 1 + # issue 27344 + try: + import numpy as np + except ImportError: + skip("numpy not installed") + y = Poisson('y', np.float64(4.72544290380919e-11)) + assert E(y) == 4.72544290380919e-11 + y = Poisson('y', np.float64(4.725442903809197e-11)) + assert E(y) == 4.725442903809197e-11 + l2 = 5 + z = Poisson('z', l2) + assert E(z) == l2 + assert E(FallingFactorial(z, 3)) == l2**3 + assert E(z**2) == l2 + l2**2 + + +def test_FlorySchulz(): + a = Symbol("a") + z = Symbol("z") + x = FlorySchulz('x', a) + assert E(x) == (2 - a)/a + assert (variance(x) - 2*(1 - a)/a**2).simplify() == S(0) + assert density(x)(z) == a**2*z*(1 - a)**(z - 1) + + +@slow +def test_GeometricDistribution(): + p = S.One / 5 + d = GeometricDistribution(p) + assert d.expectation(x, x) == 1/p + assert d.expectation(x**2, x) - d.expectation(x, x)**2 == (1-p)/p**2 + assert abs(d.cdf(20000).evalf() - 1) < .001 + assert abs(d.cdf(20000.8).evalf() - 1) < .001 + G = Geometric('G', p=S(1)/4) + assert cdf(G)(S(7)/2) == P(G <= S(7)/2) + + X = Geometric('X', Rational(1, 5)) + Y = Geometric('Y', Rational(3, 10)) + assert coskewness(X, X + Y, X + 2*Y).simplify() == sqrt(230)*Rational(81, 1150) + + +def test_Hermite(): + a1 = Symbol("a1", positive=True) + a2 = Symbol("a2", negative=True) + raises(ValueError, lambda: Hermite("H", a1, a2)) + + a1 = Symbol("a1", negative=True) + a2 = Symbol("a2", positive=True) + raises(ValueError, lambda: Hermite("H", a1, a2)) + + a1 = Symbol("a1", positive=True) + x = Symbol("x") + H = Hermite("H", a1, a2) + assert moment_generating_function(H)(x) == exp(a1*(exp(x) - 1) + + a2*(exp(2*x) - 1)) + assert characteristic_function(H)(x) == exp(a1*(exp(I*x) - 1) + + a2*(exp(2*I*x) - 1)) + assert E(H) == a1 + 2*a2 + + H = Hermite("H", a1=5, a2=4) + assert density(H)(2) == 33*exp(-9)/2 + assert E(H) == 13 + assert variance(H) == 21 + assert kurtosis(H) == Rational(464,147) + assert skewness(H) == 37*sqrt(21)/441 + +def test_Logarithmic(): + p = S.Half + x = Logarithmic('x', p) + assert E(x) == -p / ((1 - p) * log(1 - p)) + assert variance(x) == -1/log(2)**2 + 2/log(2) + assert E(2*x**2 + 3*x + 4) == 4 + 7 / log(2) + assert isinstance(E(x, evaluate=False), Expectation) + + +@nocache_fail +def test_negative_binomial(): + r = 5 + p = S.One / 3 + x = NegativeBinomial('x', r, p) + assert E(x) == r * (1 - p) / p + # This hangs when run with the cache disabled: + assert variance(x) == r * (1 - p) / p**2 + assert E(x**5 + 2*x + 3) == E(x**5) + 2*E(x) + 3 == Rational(796473, 1) + assert isinstance(E(x, evaluate=False), Expectation) + + +def test_skellam(): + mu1 = Symbol('mu1') + mu2 = Symbol('mu2') + z = Symbol('z') + X = Skellam('x', mu1, mu2) + + assert density(X)(z) == (mu1/mu2)**(z/2) * \ + exp(-mu1 - mu2)*besseli(z, 2*sqrt(mu1*mu2)) + assert skewness(X).expand() == mu1/(mu1*sqrt(mu1 + mu2) + mu2 * + sqrt(mu1 + mu2)) - mu2/(mu1*sqrt(mu1 + mu2) + mu2*sqrt(mu1 + mu2)) + assert variance(X).expand() == mu1 + mu2 + assert E(X) == mu1 - mu2 + assert characteristic_function(X)(z) == exp( + mu1*exp(I*z) - mu1 - mu2 + mu2*exp(-I*z)) + assert moment_generating_function(X)(z) == exp( + mu1*exp(z) - mu1 - mu2 + mu2*exp(-z)) + + +def test_yule_simon(): + from sympy.core.singleton import S + rho = S(3) + x = YuleSimon('x', rho) + assert simplify(E(x)) == rho / (rho - 1) + assert simplify(variance(x)) == rho**2 / ((rho - 1)**2 * (rho - 2)) + assert isinstance(E(x, evaluate=False), Expectation) + # To test the cdf function + assert cdf(x)(x) == Piecewise((-beta(floor(x), 4)*floor(x) + 1, x >= 1), (0, True)) + + +def test_zeta(): + s = S(5) + x = Zeta('x', s) + assert E(x) == zeta(s-1) / zeta(s) + assert simplify(variance(x)) == ( + zeta(s) * zeta(s-2) - zeta(s-1)**2) / zeta(s)**2 + + +def test_discrete_probability(): + X = Geometric('X', Rational(1, 5)) + Y = Poisson('Y', 4) + G = Geometric('e', x) + assert P(Eq(X, 3)) == Rational(16, 125) + assert P(X < 3) == Rational(9, 25) + assert P(X > 3) == Rational(64, 125) + assert P(X >= 3) == Rational(16, 25) + assert P(X <= 3) == Rational(61, 125) + assert P(Ne(X, 3)) == Rational(109, 125) + assert P(Eq(Y, 3)) == 32*exp(-4)/3 + assert P(Y < 3) == 13*exp(-4) + assert P(Y > 3).equals(32*(Rational(-71, 32) + 3*exp(4)/32)*exp(-4)/3) + assert P(Y >= 3).equals(32*(Rational(-39, 32) + 3*exp(4)/32)*exp(-4)/3) + assert P(Y <= 3) == 71*exp(-4)/3 + assert P(Ne(Y, 3)).equals( + 13*exp(-4) + 32*(Rational(-71, 32) + 3*exp(4)/32)*exp(-4)/3) + assert P(X < S.Infinity) is S.One + assert P(X > S.Infinity) is S.Zero + assert P(G < 3) == x*(2-x) + assert P(Eq(G, 3)) == x*(-x + 1)**2 + + +def test_DiscreteRV(): + p = S(1)/2 + x = Symbol('x', integer=True, positive=True) + pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution + D = DiscreteRV(x, pdf, set=S.Naturals, check=True) + assert E(D) == E(Geometric('G', S(1)/2)) == 2 + assert P(D > 3) == S(1)/8 + assert D.pspace.domain.set == S.Naturals + raises(ValueError, lambda: DiscreteRV(x, x, FiniteSet(*range(4)), check=True)) + + # purposeful invalid pmf but it should not raise since check=False + # see test_drv_types.test_ContinuousRV for explanation + X = DiscreteRV(x, 1/x, S.Naturals) + assert P(X < 2) == 1 + assert E(X) == oo + +def test_precomputed_characteristic_functions(): + import mpmath + + def test_cf(dist, support_lower_limit, support_upper_limit): + pdf = density(dist) + t = S('t') + x = S('x') + + # first function is the hardcoded CF of the distribution + cf1 = lambdify([t], characteristic_function(dist)(t), 'mpmath') + + # second function is the Fourier transform of the density function + f = lambdify([x, t], pdf(x)*exp(I*x*t), 'mpmath') + cf2 = lambda t: mpmath.nsum(lambda x: f(x, t), [ + support_lower_limit, support_upper_limit], maxdegree=10) + + # compare the two functions at various points + for test_point in [2, 5, 8, 11]: + n1 = cf1(test_point) + n2 = cf2(test_point) + + assert abs(re(n1) - re(n2)) < 1e-12 + assert abs(im(n1) - im(n2)) < 1e-12 + + test_cf(Geometric('g', Rational(1, 3)), 1, mpmath.inf) + test_cf(Logarithmic('l', Rational(1, 5)), 1, mpmath.inf) + test_cf(NegativeBinomial('n', 5, Rational(1, 7)), 0, mpmath.inf) + test_cf(Poisson('p', 5), 0, mpmath.inf) + test_cf(YuleSimon('y', 5), 1, mpmath.inf) + test_cf(Zeta('z', 5), 1, mpmath.inf) + + +def test_moment_generating_functions(): + t = S('t') + + geometric_mgf = moment_generating_function(Geometric('g', S.Half))(t) + assert geometric_mgf.diff(t).subs(t, 0) == 2 + + logarithmic_mgf = moment_generating_function(Logarithmic('l', S.Half))(t) + assert logarithmic_mgf.diff(t).subs(t, 0) == 1/log(2) + + negative_binomial_mgf = moment_generating_function( + NegativeBinomial('n', 5, Rational(1, 3)))(t) + assert negative_binomial_mgf.diff(t).subs(t, 0) == Rational(10, 1) + + poisson_mgf = moment_generating_function(Poisson('p', 5))(t) + assert poisson_mgf.diff(t).subs(t, 0) == 5 + + skellam_mgf = moment_generating_function(Skellam('s', 1, 1))(t) + assert skellam_mgf.diff(t).subs( + t, 2) == (-exp(-2) + exp(2))*exp(-2 + exp(-2) + exp(2)) + + yule_simon_mgf = moment_generating_function(YuleSimon('y', 3))(t) + assert simplify(yule_simon_mgf.diff(t).subs(t, 0)) == Rational(3, 2) + + zeta_mgf = moment_generating_function(Zeta('z', 5))(t) + assert zeta_mgf.diff(t).subs(t, 0) == pi**4/(90*zeta(5)) + + +def test_Or(): + X = Geometric('X', S.Half) + assert P(Or(X < 3, X > 4)) == Rational(13, 16) + assert P(Or(X > 2, X > 1)) == P(X > 1) + assert P(Or(X >= 3, X < 3)) == 1 + + +def test_where(): + X = Geometric('X', Rational(1, 5)) + Y = Poisson('Y', 4) + assert where(X**2 > 4).set == Range(3, S.Infinity, 1) + assert where(X**2 >= 4).set == Range(2, S.Infinity, 1) + assert where(Y**2 < 9).set == Range(0, 3, 1) + assert where(Y**2 <= 9).set == Range(0, 4, 1) + + +def test_conditional(): + X = Geometric('X', Rational(2, 3)) + Y = Poisson('Y', 3) + assert P(X > 2, X > 3) == 1 + assert P(X > 3, X > 2) == Rational(1, 3) + assert P(Y > 2, Y < 2) == 0 + assert P(Eq(Y, 3), Y >= 0) == 9*exp(-3)/2 + assert P(Eq(Y, 3), Eq(Y, 2)) == 0 + assert P(X < 2, Eq(X, 2)) == 0 + assert P(X > 2, Eq(X, 3)) == 1 + + +def test_product_spaces(): + X1 = Geometric('X1', S.Half) + X2 = Geometric('X2', Rational(1, 3)) + assert str(P(X1 + X2 < 3).rewrite(Sum)) == ( + "Sum(Piecewise((1/(4*2**n), n >= -1), (0, True)), (n, -oo, -1))/3") + assert str(P(X1 + X2 > 3).rewrite(Sum)) == ( + 'Sum(Piecewise((2**(X2 - n - 2)*(2/3)**(X2 - 1)/6, ' + 'X2 - n <= 2), (0, True)), (X2, 1, oo), (n, 1, oo))') + assert P(Eq(X1 + X2, 3)) == Rational(1, 12) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_error_prop.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_error_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..483fb4c36e202d744faeb355606ff9803a516873 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_error_prop.py @@ -0,0 +1,60 @@ +from sympy.core.function import Function +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.stats.error_prop import variance_prop +from sympy.stats.symbolic_probability import (RandomSymbol, Variance, + Covariance) + + +def test_variance_prop(): + x, y, z = symbols('x y z') + phi, t = consts = symbols('phi t') + a = RandomSymbol(x) + var_x = Variance(a) + var_y = Variance(RandomSymbol(y)) + var_z = Variance(RandomSymbol(z)) + f = Function('f')(x) + cases = { + x + y: var_x + var_y, + a + y: var_x + var_y, + x + y + z: var_x + var_y + var_z, + 2*x: 4*var_x, + x*y: var_x*y**2 + var_y*x**2, + 1/x: var_x/x**4, + x/y: (var_x*y**2 + var_y*x**2)/y**4, + exp(x): var_x*exp(2*x), + exp(2*x): 4*var_x*exp(4*x), + exp(-x*t): t**2*var_x*exp(-2*t*x), + f: Variance(f), + } + for inp, out in cases.items(): + obs = variance_prop(inp, consts=consts) + assert out == obs + +def test_variance_prop_with_covar(): + x, y, z = symbols('x y z') + phi, t = consts = symbols('phi t') + a = RandomSymbol(x) + var_x = Variance(a) + b = RandomSymbol(y) + var_y = Variance(b) + c = RandomSymbol(z) + var_z = Variance(c) + covar_x_y = Covariance(a, b) + covar_x_z = Covariance(a, c) + covar_y_z = Covariance(b, c) + cases = { + x + y: var_x + var_y + 2*covar_x_y, + a + y: var_x + var_y + 2*covar_x_y, + x + y + z: var_x + var_y + var_z + \ + 2*covar_x_y + 2*covar_x_z + 2*covar_y_z, + 2*x: 4*var_x, + x*y: var_x*y**2 + var_y*x**2 + 2*covar_x_y/(x*y), + 1/x: var_x/x**4, + exp(x): var_x*exp(2*x), + exp(2*x): 4*var_x*exp(4*x), + exp(-x*t): t**2*var_x*exp(-2*t*x), + } + for inp, out in cases.items(): + obs = variance_prop(inp, consts=consts, include_covar=True) + assert out == obs diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_finite_rv.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_finite_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..93bf0211a26ecc32d7f18c7e2d8236859857e445 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_finite_rv.py @@ -0,0 +1,509 @@ +from sympy.concrete.summations import Sum +from sympy.core.containers import (Dict, Tuple) +from sympy.core.function import Function +from sympy.core.numbers import (I, Rational, nan) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.combinatorial.numbers import harmonic +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import cos +from sympy.functions.special.beta_functions import beta +from sympy.logic.boolalg import (And, Or) +from sympy.polys.polytools import cancel +from sympy.sets.sets import FiniteSet +from sympy.simplify.simplify import simplify +from sympy.matrices import Matrix +from sympy.stats import (DiscreteUniform, Die, Bernoulli, Coin, Binomial, BetaBinomial, + Hypergeometric, Rademacher, IdealSoliton, RobustSoliton, P, E, variance, + covariance, skewness, density, where, FiniteRV, pspace, cdf, + correlation, moment, cmoment, smoment, characteristic_function, + moment_generating_function, quantile, kurtosis, median, coskewness) +from sympy.stats.frv_types import DieDistribution, BinomialDistribution, \ + HypergeometricDistribution +from sympy.stats.rv import Density +from sympy.testing.pytest import raises + + +def BayesTest(A, B): + assert P(A, B) == P(And(A, B)) / P(B) + assert P(A, B) == P(B, A) * P(A) / P(B) + + +def test_discreteuniform(): + # Symbolic + a, b, c, t = symbols('a b c t') + X = DiscreteUniform('X', [a, b, c]) + + assert E(X) == (a + b + c)/3 + assert simplify(variance(X) + - ((a**2 + b**2 + c**2)/3 - (a/3 + b/3 + c/3)**2)) == 0 + assert P(Eq(X, a)) == P(Eq(X, b)) == P(Eq(X, c)) == S('1/3') + + Y = DiscreteUniform('Y', range(-5, 5)) + + # Numeric + assert E(Y) == S('-1/2') + assert variance(Y) == S('33/4') + assert median(Y) == FiniteSet(-1, 0) + + for x in range(-5, 5): + assert P(Eq(Y, x)) == S('1/10') + assert P(Y <= x) == S(x + 6)/10 + assert P(Y >= x) == S(5 - x)/10 + + assert dict(density(Die('D', 6)).items()) == \ + dict(density(DiscreteUniform('U', range(1, 7))).items()) + + assert characteristic_function(X)(t) == exp(I*a*t)/3 + exp(I*b*t)/3 + exp(I*c*t)/3 + assert moment_generating_function(X)(t) == exp(a*t)/3 + exp(b*t)/3 + exp(c*t)/3 + # issue 18611 + raises(ValueError, lambda: DiscreteUniform('Z', [a, a, a, b, b, c])) + +def test_dice(): + # TODO: Make iid method! + X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) + a, b, t, p = symbols('a b t p') + + assert E(X) == 3 + S.Half + assert variance(X) == Rational(35, 12) + assert E(X + Y) == 7 + assert E(X + X) == 7 + assert E(a*X + b) == a*E(X) + b + assert variance(X + Y) == variance(X) + variance(Y) == cmoment(X + Y, 2) + assert variance(X + X) == 4 * variance(X) == cmoment(X + X, 2) + assert cmoment(X, 0) == 1 + assert cmoment(4*X, 3) == 64*cmoment(X, 3) + assert covariance(X, Y) is S.Zero + assert covariance(X, X + Y) == variance(X) + assert density(Eq(cos(X*S.Pi), 1))[True] == S.Half + assert correlation(X, Y) == 0 + assert correlation(X, Y) == correlation(Y, X) + assert smoment(X + Y, 3) == skewness(X + Y) + assert smoment(X + Y, 4) == kurtosis(X + Y) + assert smoment(X, 0) == 1 + assert P(X > 3) == S.Half + assert P(2*X > 6) == S.Half + assert P(X > Y) == Rational(5, 12) + assert P(Eq(X, Y)) == P(Eq(X, 1)) + + assert E(X, X > 3) == 5 == moment(X, 1, 0, X > 3) + assert E(X, Y > 3) == E(X) == moment(X, 1, 0, Y > 3) + assert E(X + Y, Eq(X, Y)) == E(2*X) + assert moment(X, 0) == 1 + assert moment(5*X, 2) == 25*moment(X, 2) + assert quantile(X)(p) == Piecewise((nan, (p > 1) | (p < 0)),\ + (S.One, p <= Rational(1, 6)), (S(2), p <= Rational(1, 3)), (S(3), p <= S.Half),\ + (S(4), p <= Rational(2, 3)), (S(5), p <= Rational(5, 6)), (S(6), p <= 1)) + + assert P(X > 3, X > 3) is S.One + assert P(X > Y, Eq(Y, 6)) is S.Zero + assert P(Eq(X + Y, 12)) == Rational(1, 36) + assert P(Eq(X + Y, 12), Eq(X, 6)) == Rational(1, 6) + + assert density(X + Y) == density(Y + Z) != density(X + X) + d = density(2*X + Y**Z) + assert d[S(22)] == Rational(1, 108) and d[S(4100)] == Rational(1, 216) and S(3130) not in d + + assert pspace(X).domain.as_boolean() == Or( + *[Eq(X.symbol, i) for i in [1, 2, 3, 4, 5, 6]]) + + assert where(X > 3).set == FiniteSet(4, 5, 6) + + assert characteristic_function(X)(t) == exp(6*I*t)/6 + exp(5*I*t)/6 + exp(4*I*t)/6 + exp(3*I*t)/6 + exp(2*I*t)/6 + exp(I*t)/6 + assert moment_generating_function(X)(t) == exp(6*t)/6 + exp(5*t)/6 + exp(4*t)/6 + exp(3*t)/6 + exp(2*t)/6 + exp(t)/6 + assert median(X) == FiniteSet(3, 4) + D = Die('D', 7) + assert median(D) == FiniteSet(4) + # Bayes test for die + BayesTest(X > 3, X + Y < 5) + BayesTest(Eq(X - Y, Z), Z > Y) + BayesTest(X > 3, X > 2) + + # arg test for die + raises(ValueError, lambda: Die('X', -1)) # issue 8105: negative sides. + raises(ValueError, lambda: Die('X', 0)) + raises(ValueError, lambda: Die('X', 1.5)) # issue 8103: non integer sides. + + # symbolic test for die + n, k = symbols('n, k', positive=True) + D = Die('D', n) + dens = density(D).dict + assert dens == Density(DieDistribution(n)) + assert set(dens.subs(n, 4).doit().keys()) == {1, 2, 3, 4} + assert set(dens.subs(n, 4).doit().values()) == {Rational(1, 4)} + k = Dummy('k', integer=True) + assert E(D).dummy_eq( + Sum(Piecewise((k/n, k <= n), (0, True)), (k, 1, n))) + assert variance(D).subs(n, 6).doit() == Rational(35, 12) + + ki = Dummy('ki') + cumuf = cdf(D)(k) + assert cumuf.dummy_eq( + Sum(Piecewise((1/n, (ki >= 1) & (ki <= n)), (0, True)), (ki, 1, k))) + assert cumuf.subs({n: 6, k: 2}).doit() == Rational(1, 3) + + t = Dummy('t') + cf = characteristic_function(D)(t) + assert cf.dummy_eq( + Sum(Piecewise((exp(ki*I*t)/n, (ki >= 1) & (ki <= n)), (0, True)), (ki, 1, n))) + assert cf.subs(n, 3).doit() == exp(3*I*t)/3 + exp(2*I*t)/3 + exp(I*t)/3 + mgf = moment_generating_function(D)(t) + assert mgf.dummy_eq( + Sum(Piecewise((exp(ki*t)/n, (ki >= 1) & (ki <= n)), (0, True)), (ki, 1, n))) + assert mgf.subs(n, 3).doit() == exp(3*t)/3 + exp(2*t)/3 + exp(t)/3 + +def test_given(): + X = Die('X', 6) + assert density(X, X > 5) == {S(6): S.One} + assert where(X > 2, X > 5).as_boolean() == Eq(X.symbol, 6) + + +def test_domains(): + X, Y = Die('x', 6), Die('y', 6) + x, y = X.symbol, Y.symbol + # Domains + d = where(X > Y) + assert d.condition == (x > y) + d = where(And(X > Y, Y > 3)) + assert d.as_boolean() == Or(And(Eq(x, 5), Eq(y, 4)), And(Eq(x, 6), + Eq(y, 5)), And(Eq(x, 6), Eq(y, 4))) + assert len(d.elements) == 3 + + assert len(pspace(X + Y).domain.elements) == 36 + + Z = Die('x', 4) + + raises(ValueError, lambda: P(X > Z)) # Two domains with same internal symbol + + assert pspace(X + Y).domain.set == FiniteSet(1, 2, 3, 4, 5, 6)**2 + + assert where(X > 3).set == FiniteSet(4, 5, 6) + assert X.pspace.domain.dict == FiniteSet( + *[Dict({X.symbol: i}) for i in range(1, 7)]) + + assert where(X > Y).dict == FiniteSet(*[Dict({X.symbol: i, Y.symbol: j}) + for i in range(1, 7) for j in range(1, 7) if i > j]) + +def test_bernoulli(): + p, a, b, t = symbols('p a b t') + X = Bernoulli('B', p, a, b) + + assert E(X) == a*p + b*(-p + 1) + assert density(X)[a] == p + assert density(X)[b] == 1 - p + assert characteristic_function(X)(t) == p * exp(I * a * t) + (-p + 1) * exp(I * b * t) + assert moment_generating_function(X)(t) == p * exp(a * t) + (-p + 1) * exp(b * t) + + X = Bernoulli('B', p, 1, 0) + z = Symbol("z") + + assert E(X) == p + assert simplify(variance(X)) == p*(1 - p) + assert E(a*X + b) == a*E(X) + b + assert simplify(variance(a*X + b)) == simplify(a**2 * variance(X)) + assert quantile(X)(z) == Piecewise((nan, (z > 1) | (z < 0)), (0, z <= 1 - p), (1, z <= 1)) + Y = Bernoulli('Y', Rational(1, 2)) + assert median(Y) == FiniteSet(0, 1) + Z = Bernoulli('Z', Rational(2, 3)) + assert median(Z) == FiniteSet(1) + raises(ValueError, lambda: Bernoulli('B', 1.5)) + raises(ValueError, lambda: Bernoulli('B', -0.5)) + + #issue 8248 + assert X.pspace.compute_expectation(1) == 1 + + p = Rational(1, 5) + X = Binomial('X', 5, p) + Y = Binomial('Y', 7, 2*p) + Z = Binomial('Z', 9, 3*p) + assert coskewness(Y + Z, X + Y, X + Z).simplify() == 0 + assert coskewness(Y + 2*X + Z, X + 2*Y + Z, X + 2*Z + Y).simplify() == \ + sqrt(1529)*Rational(12, 16819) + assert coskewness(Y + 2*X + Z, X + 2*Y + Z, X + 2*Z + Y, X < 2).simplify() \ + == -sqrt(357451121)*Rational(2812, 4646864573) + +def test_cdf(): + D = Die('D', 6) + o = S.One + + assert cdf( + D) == sympify({1: o/6, 2: o/3, 3: o/2, 4: 2*o/3, 5: 5*o/6, 6: o}) + + +def test_coins(): + C, D = Coin('C'), Coin('D') + H, T = symbols('H, T') + assert P(Eq(C, D)) == S.Half + assert density(Tuple(C, D)) == {(H, H): Rational(1, 4), (H, T): Rational(1, 4), + (T, H): Rational(1, 4), (T, T): Rational(1, 4)} + assert dict(density(C).items()) == {H: S.Half, T: S.Half} + + F = Coin('F', Rational(1, 10)) + assert P(Eq(F, H)) == Rational(1, 10) + + d = pspace(C).domain + + assert d.as_boolean() == Or(Eq(C.symbol, H), Eq(C.symbol, T)) + + raises(ValueError, lambda: P(C > D)) # Can't intelligently compare H to T + +def test_binomial_verify_parameters(): + raises(ValueError, lambda: Binomial('b', .2, .5)) + raises(ValueError, lambda: Binomial('b', 3, 1.5)) + +def test_binomial_numeric(): + nvals = range(5) + pvals = [0, Rational(1, 4), S.Half, Rational(3, 4), 1] + + for n in nvals: + for p in pvals: + X = Binomial('X', n, p) + assert E(X) == n*p + assert variance(X) == n*p*(1 - p) + if n > 0 and 0 < p < 1: + assert skewness(X) == (1 - 2*p)/sqrt(n*p*(1 - p)) + assert kurtosis(X) == 3 + (1 - 6*p*(1 - p))/(n*p*(1 - p)) + for k in range(n + 1): + assert P(Eq(X, k)) == binomial(n, k)*p**k*(1 - p)**(n - k) + +def test_binomial_quantile(): + X = Binomial('X', 50, S.Half) + assert quantile(X)(0.95) == S(31) + assert median(X) == FiniteSet(25) + + X = Binomial('X', 5, S.Half) + p = Symbol("p", positive=True) + assert quantile(X)(p) == Piecewise((nan, p > S.One), (S.Zero, p <= Rational(1, 32)),\ + (S.One, p <= Rational(3, 16)), (S(2), p <= S.Half), (S(3), p <= Rational(13, 16)),\ + (S(4), p <= Rational(31, 32)), (S(5), p <= S.One)) + assert median(X) == FiniteSet(2, 3) + + +def test_binomial_symbolic(): + n = 2 + p = symbols('p', positive=True) + X = Binomial('X', n, p) + t = Symbol('t') + + assert simplify(E(X)) == n*p == simplify(moment(X, 1)) + assert simplify(variance(X)) == n*p*(1 - p) == simplify(cmoment(X, 2)) + assert cancel(skewness(X) - (1 - 2*p)/sqrt(n*p*(1 - p))) == 0 + assert cancel((kurtosis(X)) - (3 + (1 - 6*p*(1 - p))/(n*p*(1 - p)))) == 0 + assert characteristic_function(X)(t) == p ** 2 * exp(2 * I * t) + 2 * p * (-p + 1) * exp(I * t) + (-p + 1) ** 2 + assert moment_generating_function(X)(t) == p ** 2 * exp(2 * t) + 2 * p * (-p + 1) * exp(t) + (-p + 1) ** 2 + + # Test ability to change success/failure winnings + H, T = symbols('H T') + Y = Binomial('Y', n, p, succ=H, fail=T) + assert simplify(E(Y) - (n*(H*p + T*(1 - p)))) == 0 + + # test symbolic dimensions + n = symbols('n') + B = Binomial('B', n, p) + raises(NotImplementedError, lambda: P(B > 2)) + assert density(B).dict == Density(BinomialDistribution(n, p, 1, 0)) + assert set(density(B).dict.subs(n, 4).doit().keys()) == \ + {S.Zero, S.One, S(2), S(3), S(4)} + assert set(density(B).dict.subs(n, 4).doit().values()) == \ + {(1 - p)**4, 4*p*(1 - p)**3, 6*p**2*(1 - p)**2, 4*p**3*(1 - p), p**4} + k = Dummy('k', integer=True) + assert E(B > 2).dummy_eq( + Sum(Piecewise((k*p**k*(1 - p)**(-k + n)*binomial(n, k), (k >= 0) + & (k <= n) & (k > 2)), (0, True)), (k, 0, n))) + +def test_beta_binomial(): + # verify parameters + raises(ValueError, lambda: BetaBinomial('b', .2, 1, 2)) + raises(ValueError, lambda: BetaBinomial('b', 2, -1, 2)) + raises(ValueError, lambda: BetaBinomial('b', 2, 1, -2)) + assert BetaBinomial('b', 2, 1, 1) + + # test numeric values + nvals = range(1,5) + alphavals = [Rational(1, 4), S.Half, Rational(3, 4), 1, 10] + betavals = [Rational(1, 4), S.Half, Rational(3, 4), 1, 10] + + for n in nvals: + for a in alphavals: + for b in betavals: + X = BetaBinomial('X', n, a, b) + assert E(X) == moment(X, 1) + assert variance(X) == cmoment(X, 2) + + # test symbolic + n, a, b = symbols('a b n') + assert BetaBinomial('x', n, a, b) + n = 2 # Because we're using for loops, can't do symbolic n + a, b = symbols('a b', positive=True) + X = BetaBinomial('X', n, a, b) + t = Symbol('t') + + assert E(X).expand() == moment(X, 1).expand() + assert variance(X).expand() == cmoment(X, 2).expand() + assert skewness(X) == smoment(X, 3) + assert characteristic_function(X)(t) == exp(2*I*t)*beta(a + 2, b)/beta(a, b) +\ + 2*exp(I*t)*beta(a + 1, b + 1)/beta(a, b) + beta(a, b + 2)/beta(a, b) + assert moment_generating_function(X)(t) == exp(2*t)*beta(a + 2, b)/beta(a, b) +\ + 2*exp(t)*beta(a + 1, b + 1)/beta(a, b) + beta(a, b + 2)/beta(a, b) + +def test_hypergeometric_numeric(): + for N in range(1, 5): + for m in range(0, N + 1): + for n in range(1, N + 1): + X = Hypergeometric('X', N, m, n) + N, m, n = map(sympify, (N, m, n)) + assert sum(density(X).values()) == 1 + assert E(X) == n * m / N + if N > 1: + assert variance(X) == n*(m/N)*(N - m)/N*(N - n)/(N - 1) + # Only test for skewness when defined + if N > 2 and 0 < m < N and n < N: + assert skewness(X) == simplify((N - 2*m)*sqrt(N - 1)*(N - 2*n) + / (sqrt(n*m*(N - m)*(N - n))*(N - 2))) + +def test_hypergeometric_symbolic(): + N, m, n = symbols('N, m, n') + H = Hypergeometric('H', N, m, n) + dens = density(H).dict + expec = E(H > 2) + assert dens == Density(HypergeometricDistribution(N, m, n)) + assert dens.subs(N, 5).doit() == Density(HypergeometricDistribution(5, m, n)) + assert set(dens.subs({N: 3, m: 2, n: 1}).doit().keys()) == {S.Zero, S.One} + assert set(dens.subs({N: 3, m: 2, n: 1}).doit().values()) == {Rational(1, 3), Rational(2, 3)} + k = Dummy('k', integer=True) + assert expec.dummy_eq( + Sum(Piecewise((k*binomial(m, k)*binomial(N - m, -k + n) + /binomial(N, n), k > 2), (0, True)), (k, 0, n))) + +def test_rademacher(): + X = Rademacher('X') + t = Symbol('t') + + assert E(X) == 0 + assert variance(X) == 1 + assert density(X)[-1] == S.Half + assert density(X)[1] == S.Half + assert characteristic_function(X)(t) == exp(I*t)/2 + exp(-I*t)/2 + assert moment_generating_function(X)(t) == exp(t) / 2 + exp(-t) / 2 + +def test_ideal_soliton(): + raises(ValueError, lambda : IdealSoliton('sol', -12)) + raises(ValueError, lambda : IdealSoliton('sol', 13.2)) + raises(ValueError, lambda : IdealSoliton('sol', 0)) + f = Function('f') + raises(ValueError, lambda : density(IdealSoliton('sol', 10)).pmf(f)) + + k = Symbol('k', integer=True, positive=True) + x = Symbol('x', integer=True, positive=True) + t = Symbol('t') + sol = IdealSoliton('sol', k) + assert density(sol).low == S.One + assert density(sol).high == k + assert density(sol).dict == Density(density(sol)) + assert density(sol).pmf(x) == Piecewise((1/k, Eq(x, 1)), (1/(x*(x - 1)), k >= x), (0, True)) + + k_vals = [5, 20, 50, 100, 1000] + for i in k_vals: + assert E(sol.subs(k, i)) == harmonic(i) == moment(sol.subs(k, i), 1) + assert variance(sol.subs(k, i)) == (i - 1) + harmonic(i) - harmonic(i)**2 == cmoment(sol.subs(k, i),2) + assert skewness(sol.subs(k, i)) == smoment(sol.subs(k, i), 3) + assert kurtosis(sol.subs(k, i)) == smoment(sol.subs(k, i), 4) + + assert exp(I*t)/10 + Sum(exp(I*t*x)/(x*x - x), (x, 2, k)).subs(k, 10).doit() == characteristic_function(sol.subs(k, 10))(t) + assert exp(t)/10 + Sum(exp(t*x)/(x*x - x), (x, 2, k)).subs(k, 10).doit() == moment_generating_function(sol.subs(k, 10))(t) + +def test_robust_soliton(): + raises(ValueError, lambda : RobustSoliton('robSol', -12, 0.1, 0.02)) + raises(ValueError, lambda : RobustSoliton('robSol', 13, 1.89, 0.1)) + raises(ValueError, lambda : RobustSoliton('robSol', 15, 0.6, -2.31)) + f = Function('f') + raises(ValueError, lambda : density(RobustSoliton('robSol', 15, 0.6, 0.1)).pmf(f)) + + k = Symbol('k', integer=True, positive=True) + delta = Symbol('delta', positive=True) + c = Symbol('c', positive=True) + robSol = RobustSoliton('robSol', k, delta, c) + assert density(robSol).low == 1 + assert density(robSol).high == k + + k_vals = [10, 20, 50] + delta_vals = [0.2, 0.4, 0.6] + c_vals = [0.01, 0.03, 0.05] + for x in k_vals: + for y in delta_vals: + for z in c_vals: + assert E(robSol.subs({k: x, delta: y, c: z})) == moment(robSol.subs({k: x, delta: y, c: z}), 1) + assert variance(robSol.subs({k: x, delta: y, c: z})) == cmoment(robSol.subs({k: x, delta: y, c: z}), 2) + assert skewness(robSol.subs({k: x, delta: y, c: z})) == smoment(robSol.subs({k: x, delta: y, c: z}), 3) + assert kurtosis(robSol.subs({k: x, delta: y, c: z})) == smoment(robSol.subs({k: x, delta: y, c: z}), 4) + +def test_FiniteRV(): + F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}, check=True) + p = Symbol("p", positive=True) + + assert dict(density(F).items()) == {S.One: S.Half, S(2): Rational(1, 4), S(3): Rational(1, 4)} + assert P(F >= 2) == S.Half + assert quantile(F)(p) == Piecewise((nan, p > S.One), (S.One, p <= S.Half),\ + (S(2), p <= Rational(3, 4)),(S(3), True)) + + assert pspace(F).domain.as_boolean() == Or( + *[Eq(F.symbol, i) for i in [1, 2, 3]]) + + assert F.pspace.domain.set == FiniteSet(1, 2, 3) + raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: S.Half, 3: S.Half}, check=True)) + raises(ValueError, lambda: FiniteRV('F', {1: S.Half, 2: Rational(-1, 2), 3: S.One}, check=True)) + raises(ValueError, lambda: FiniteRV('F', {1: S.One, 2: Rational(3, 2), 3: S.Zero,\ + 4: Rational(-1, 2), 5: Rational(-3, 4), 6: Rational(-1, 4)}, check=True)) + + # purposeful invalid pmf but it should not raise since check=False + # see test_drv_types.test_ContinuousRV for explanation + X = FiniteRV('X', {1: 1, 2: 2}) + assert E(X) == 5 + assert P(X <= 2) + P(X > 2) != 1 + +def test_density_call(): + from sympy.abc import p + x = Bernoulli('x', p) + d = density(x) + assert d(0) == 1 - p + assert d(S.Zero) == 1 - p + assert d(5) == 0 + + assert 0 in d + assert 5 not in d + assert d(S.Zero) == d[S.Zero] + + +def test_DieDistribution(): + from sympy.abc import x + X = DieDistribution(6) + assert X.pmf(S.Half) is S.Zero + assert X.pmf(x).subs({x: 1}).doit() == Rational(1, 6) + assert X.pmf(x).subs({x: 7}).doit() == 0 + assert X.pmf(x).subs({x: -1}).doit() == 0 + assert X.pmf(x).subs({x: Rational(1, 3)}).doit() == 0 + raises(ValueError, lambda: X.pmf(Matrix([0, 0]))) + raises(ValueError, lambda: X.pmf(x**2 - 1)) + +def test_FinitePSpace(): + X = Die('X', 6) + space = pspace(X) + assert space.density == DieDistribution(6) + +def test_symbolic_conditions(): + B = Bernoulli('B', Rational(1, 4)) + D = Die('D', 4) + b, n = symbols('b, n') + Y = P(Eq(B, b)) + Z = E(D > n) + assert Y == \ + Piecewise((Rational(1, 4), Eq(b, 1)), (0, True)) + \ + Piecewise((Rational(3, 4), Eq(b, 0)), (0, True)) + assert Z == \ + Piecewise((Rational(1, 4), n < 1), (0, True)) + Piecewise((S.Half, n < 2), (0, True)) + \ + Piecewise((Rational(3, 4), n < 3), (0, True)) + Piecewise((S.One, n < 4), (0, True)) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_joint_rv.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_joint_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..057fc313dfbb31826b07fd1315205d22b86a7f96 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_joint_rv.py @@ -0,0 +1,436 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial) +from sympy.functions.elementary.complexes import polar_lift +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.bessel import besselk +from sympy.functions.special.gamma_functions import gamma +from sympy.matrices.dense import eye +from sympy.matrices.expressions.determinant import Determinant +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Interval, ProductSet) +from sympy.simplify.simplify import simplify +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.core.numbers import comp +from sympy.integrals.integrals import integrate +from sympy.matrices import Matrix, MatrixSymbol +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.stats import density, median, marginal_distribution, Normal, Laplace, E, sample +from sympy.stats.joint_rv_types import (JointRV, MultivariateNormalDistribution, + JointDistributionHandmade, MultivariateT, NormalGamma, + GeneralizedMultivariateLogGammaOmega as GMVLGO, MultivariateBeta, + GeneralizedMultivariateLogGamma as GMVLG, MultivariateEwens, + Multinomial, NegativeMultinomial, MultivariateNormal, + MultivariateLaplace) +from sympy.testing.pytest import raises, XFAIL, skip, slow +from sympy.external import import_module + +from sympy.abc import x, y + + + +def test_Normal(): + m = Normal('A', [1, 2], [[1, 0], [0, 1]]) + A = MultivariateNormal('A', [1, 2], [[1, 0], [0, 1]]) + assert m == A + assert density(m)(1, 2) == 1/(2*pi) + assert m.pspace.distribution.set == ProductSet(S.Reals, S.Reals) + raises (ValueError, lambda:m[2]) + n = Normal('B', [1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + p = Normal('C', Matrix([1, 2]), Matrix([[1, 0], [0, 1]])) + assert density(m)(x, y) == density(p)(x, y) + assert marginal_distribution(n, 0, 1)(1, 2) == 1/(2*pi) + raises(ValueError, lambda: marginal_distribution(m)) + assert integrate(density(m)(x, y), (x, -oo, oo), (y, -oo, oo)).evalf() == 1.0 + N = Normal('N', [1, 2], [[x, 0], [0, y]]) + assert density(N)(0, 0) == exp(-((4*x + y)/(2*x*y)))/(2*pi*sqrt(x*y)) + + raises (ValueError, lambda: Normal('M', [1, 2], [[1, 1], [1, -1]])) + # symbolic + n = symbols('n', integer=True, positive=True) + mu = MatrixSymbol('mu', n, 1) + sigma = MatrixSymbol('sigma', n, n) + X = Normal('X', mu, sigma) + assert density(X) == MultivariateNormalDistribution(mu, sigma) + raises (NotImplementedError, lambda: median(m)) + # Below tests should work after issue #17267 is resolved + # assert E(X) == mu + # assert variance(X) == sigma + + # test symbolic multivariate normal densities + n = 3 + + Sg = MatrixSymbol('Sg', n, n) + mu = MatrixSymbol('mu', n, 1) + obs = MatrixSymbol('obs', n, 1) + + X = MultivariateNormal('X', mu, Sg) + density_X = density(X) + + eval_a = density_X(obs).subs({Sg: eye(3), + mu: Matrix([0, 0, 0]), obs: Matrix([0, 0, 0])}).doit() + eval_b = density_X(0, 0, 0).subs({Sg: eye(3), mu: Matrix([0, 0, 0])}).doit() + + assert eval_a == sqrt(2)/(4*pi**Rational(3/2)) + assert eval_b == sqrt(2)/(4*pi**Rational(3/2)) + + n = symbols('n', integer=True, positive=True) + + Sg = MatrixSymbol('Sg', n, n) + mu = MatrixSymbol('mu', n, 1) + obs = MatrixSymbol('obs', n, 1) + + X = MultivariateNormal('X', mu, Sg) + density_X_at_obs = density(X)(obs) + + expected_density = MatrixElement( + exp((S(1)/2) * (mu.T - obs.T) * Sg**(-1) * (-mu + obs)) / \ + sqrt((2*pi)**n * Determinant(Sg)), 0, 0) + + assert density_X_at_obs == expected_density + + +def test_MultivariateTDist(): + t1 = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2) + assert(density(t1))(1, 1) == 1/(8*pi) + assert t1.pspace.distribution.set == ProductSet(S.Reals, S.Reals) + assert integrate(density(t1)(x, y), (x, -oo, oo), \ + (y, -oo, oo)).evalf() == 1.0 + raises(ValueError, lambda: MultivariateT('T', [1, 2], [[1, 1], [1, -1]], 1)) + t2 = MultivariateT('t2', [1, 2], [[x, 0], [0, y]], 1) + assert density(t2)(1, 2) == 1/(2*pi*sqrt(x*y)) + + +def test_multivariate_laplace(): + raises(ValueError, lambda: Laplace('T', [1, 2], [[1, 2], [2, 1]])) + L = Laplace('L', [1, 0], [[1, 0], [0, 1]]) + L2 = MultivariateLaplace('L2', [1, 0], [[1, 0], [0, 1]]) + assert density(L)(2, 3) == exp(2)*besselk(0, sqrt(39))/pi + L1 = Laplace('L1', [1, 2], [[x, 0], [0, y]]) + assert density(L1)(0, 1) == \ + exp(2/y)*besselk(0, sqrt((2 + 4/y + 1/x)/y))/(pi*sqrt(x*y)) + assert L.pspace.distribution.set == ProductSet(S.Reals, S.Reals) + assert L.pspace.distribution == L2.pspace.distribution + + +def test_NormalGamma(): + ng = NormalGamma('G', 1, 2, 3, 4) + assert density(ng)(1, 1) == 32*exp(-4)/sqrt(pi) + assert ng.pspace.distribution.set == ProductSet(S.Reals, Interval(0, oo)) + raises(ValueError, lambda:NormalGamma('G', 1, 2, 3, -1)) + assert marginal_distribution(ng, 0)(1) == \ + 3*sqrt(10)*gamma(Rational(7, 4))/(10*sqrt(pi)*gamma(Rational(5, 4))) + assert marginal_distribution(ng, y)(1) == exp(Rational(-1, 4))/128 + assert marginal_distribution(ng,[0,1])(x) == x**2*exp(-x/4)/128 + + +def test_GeneralizedMultivariateLogGammaDistribution(): + h = S.Half + omega = Matrix([[1, h, h, h], + [h, 1, h, h], + [h, h, 1, h], + [h, h, h, 1]]) + v, l, mu = (4, [1, 2, 3, 4], [1, 2, 3, 4]) + y_1, y_2, y_3, y_4 = symbols('y_1:5', real=True) + delta = symbols('d', positive=True) + G = GMVLGO('G', omega, v, l, mu) + Gd = GMVLG('Gd', delta, v, l, mu) + dend = ("d**4*Sum(4*24**(-n - 4)*(1 - d)**n*exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 " + "+ 4*y_4) - exp(y_1) - exp(2*y_2)/2 - exp(3*y_3)/3 - exp(4*y_4)/4)/" + "(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))") + assert str(density(Gd)(y_1, y_2, y_3, y_4)) == dend + den = ("5*2**(2/3)*5**(1/3)*Sum(4*24**(-n - 4)*(-2**(2/3)*5**(1/3)/4 + 1)**n*" + "exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 + 4*y_4) - exp(y_1) - exp(2*y_2)/2 - " + "exp(3*y_3)/3 - exp(4*y_4)/4)/(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))/64") + assert str(density(G)(y_1, y_2, y_3, y_4)) == den + marg = ("5*2**(2/3)*5**(1/3)*exp(4*y_1)*exp(-exp(y_1))*Integral(exp(-exp(4*G[3])" + "/4)*exp(16*G[3])*Integral(exp(-exp(3*G[2])/3)*exp(12*G[2])*Integral(exp(" + "-exp(2*G[1])/2)*exp(8*G[1])*Sum((-1/4)**n*(-4 + 2**(2/3)*5**(1/3" + "))**n*exp(n*y_1)*exp(2*n*G[1])*exp(3*n*G[2])*exp(4*n*G[3])/(24**n*gamma(n + 1)" + "*gamma(n + 4)**3), (n, 0, oo)), (G[1], -oo, oo)), (G[2], -oo, oo)), (G[3]" + ", -oo, oo))/5308416") + assert str(marginal_distribution(G, G[0])(y_1)) == marg + omega_f1 = Matrix([[1, h, h]]) + omega_f2 = Matrix([[1, h, h, h], + [h, 1, 2, h], + [h, h, 1, h], + [h, h, h, 1]]) + omega_f3 = Matrix([[6, h, h, h], + [h, 1, 2, h], + [h, h, 1, h], + [h, h, h, 1]]) + v_f = symbols("v_f", positive=False, real=True) + l_f = [1, 2, v_f, 4] + m_f = [v_f, 2, 3, 4] + omega_f4 = Matrix([[1, h, h, h, h], + [h, 1, h, h, h], + [h, h, 1, h, h], + [h, h, h, 1, h], + [h, h, h, h, 1]]) + l_f1 = [1, 2, 3, 4, 5] + omega_f5 = Matrix([[1]]) + mu_f5 = l_f5 = [1] + + raises(ValueError, lambda: GMVLGO('G', omega_f1, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega_f2, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega_f3, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v_f, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v, l_f, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v, l, m_f)) + raises(ValueError, lambda: GMVLGO('G', omega_f4, v, l, mu)) + raises(ValueError, lambda: GMVLGO('G', omega, v, l_f1, mu)) + raises(ValueError, lambda: GMVLGO('G', omega_f5, v, l_f5, mu_f5)) + raises(ValueError, lambda: GMVLG('G', Rational(3, 2), v, l, mu)) + + +def test_MultivariateBeta(): + a1, a2 = symbols('a1, a2', positive=True) + a1_f, a2_f = symbols('a1, a2', positive=False, real=True) + mb = MultivariateBeta('B', [a1, a2]) + mb_c = MultivariateBeta('C', a1, a2) + assert density(mb)(1, 2) == S(2)**(a2 - 1)*gamma(a1 + a2)/\ + (gamma(a1)*gamma(a2)) + assert marginal_distribution(mb_c, 0)(3) == S(3)**(a1 - 1)*gamma(a1 + a2)/\ + (a2*gamma(a1)*gamma(a2)) + raises(ValueError, lambda: MultivariateBeta('b1', [a1_f, a2])) + raises(ValueError, lambda: MultivariateBeta('b2', [a1, a2_f])) + raises(ValueError, lambda: MultivariateBeta('b3', [0, 0])) + raises(ValueError, lambda: MultivariateBeta('b4', [a1_f, a2_f])) + assert mb.pspace.distribution.set == ProductSet(Interval(0, 1), Interval(0, 1)) + + +def test_MultivariateEwens(): + n, theta, i = symbols('n theta i', positive=True) + + # tests for integer dimensions + theta_f = symbols('t_f', negative=True) + a = symbols('a_1:4', positive = True, integer = True) + ed = MultivariateEwens('E', 3, theta) + assert density(ed)(a[0], a[1], a[2]) == Piecewise((6*2**(-a[1])*3**(-a[2])* + theta**a[0]*theta**a[1]*theta**a[2]/ + (theta*(theta + 1)*(theta + 2)* + factorial(a[0])*factorial(a[1])* + factorial(a[2])), Eq(a[0] + 2*a[1] + + 3*a[2], 3)), (0, True)) + assert marginal_distribution(ed, ed[1])(a[1]) == Piecewise((6*2**(-a[1])* + theta**a[1]/((theta + 1)* + (theta + 2)*factorial(a[1])), + Eq(2*a[1] + 1, 3)), (0, True)) + raises(ValueError, lambda: MultivariateEwens('e1', 5, theta_f)) + assert ed.pspace.distribution.set == ProductSet(Range(0, 4, 1), + Range(0, 2, 1), Range(0, 2, 1)) + + # tests for symbolic dimensions + eds = MultivariateEwens('E', n, theta) + a = IndexedBase('a') + j, k = symbols('j, k') + den = Piecewise((factorial(n)*Product(theta**a[j]*(j + 1)**(-a[j])/ + factorial(a[j]), (j, 0, n - 1))/RisingFactorial(theta, n), + Eq(n, Sum((k + 1)*a[k], (k, 0, n - 1)))), (0, True)) + assert density(eds)(a).dummy_eq(den) + + +def test_Multinomial(): + n, x1, x2, x3, x4 = symbols('n, x1, x2, x3, x4', nonnegative=True, integer=True) + p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True) + p1_f, n_f = symbols('p1_f, n_f', negative=True) + M = Multinomial('M', n, [p1, p2, p3, p4]) + C = Multinomial('C', 3, p1, p2, p3) + f = factorial + assert density(M)(x1, x2, x3, x4) == Piecewise((p1**x1*p2**x2*p3**x3*p4**x4* + f(n)/(f(x1)*f(x2)*f(x3)*f(x4)), + Eq(n, x1 + x2 + x3 + x4)), (0, True)) + assert marginal_distribution(C, C[0])(x1).subs(x1, 1) ==\ + 3*p1*p2**2 +\ + 6*p1*p2*p3 +\ + 3*p1*p3**2 + raises(ValueError, lambda: Multinomial('b1', 5, [p1, p2, p3, p1_f])) + raises(ValueError, lambda: Multinomial('b2', n_f, [p1, p2, p3, p4])) + raises(ValueError, lambda: Multinomial('b3', n, 0.5, 0.4, 0.3, 0.1)) + + +def test_NegativeMultinomial(): + k0, x1, x2, x3, x4 = symbols('k0, x1, x2, x3, x4', nonnegative=True, integer=True) + p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True) + p1_f = symbols('p1_f', negative=True) + N = NegativeMultinomial('N', 4, [p1, p2, p3, p4]) + C = NegativeMultinomial('C', 4, 0.1, 0.2, 0.3) + g = gamma + f = factorial + assert simplify(density(N)(x1, x2, x3, x4) - + p1**x1*p2**x2*p3**x3*p4**x4*(-p1 - p2 - p3 - p4 + 1)**4*g(x1 + x2 + + x3 + x4 + 4)/(6*f(x1)*f(x2)*f(x3)*f(x4))) is S.Zero + assert comp(marginal_distribution(C, C[0])(1).evalf(), 0.33, .01) + raises(ValueError, lambda: NegativeMultinomial('b1', 5, [p1, p2, p3, p1_f])) + raises(ValueError, lambda: NegativeMultinomial('b2', k0, 0.5, 0.4, 0.3, 0.4)) + assert N.pspace.distribution.set == ProductSet(Range(0, oo, 1), + Range(0, oo, 1), Range(0, oo, 1), Range(0, oo, 1)) + + +@slow +def test_JointPSpace_marginal_distribution(): + T = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2) + got = marginal_distribution(T, T[1])(x) + ans = sqrt(2)*(x**2/2 + 1)/(4*polar_lift(x**2/2 + 1)**(S(5)/2)) + assert got == ans, got + assert integrate(marginal_distribution(T, 1)(x), (x, -oo, oo)) == 1 + + t = MultivariateT('T', [0, 0, 0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 3) + assert comp(marginal_distribution(t, 0)(1).evalf(), 0.2, .01) + + +def test_JointRV(): + x1, x2 = (Indexed('x', i) for i in (1, 2)) + pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi) + X = JointRV('x', pdf) + assert density(X)(1, 2) == exp(-2)/(2*pi) + assert isinstance(X.pspace.distribution, JointDistributionHandmade) + assert marginal_distribution(X, 0)(2) == sqrt(2)*exp(Rational(-1, 2))/(2*sqrt(pi)) + + +def test_expectation(): + m = Normal('A', [x, y], [[1, 0], [0, 1]]) + assert simplify(E(m[1])) == y + + +@XFAIL +def test_joint_vector_expectation(): + m = Normal('A', [x, y], [[1, 0], [0, 1]]) + assert E(m) == (x, y) + + +def test_sample_numpy(): + distribs_numpy = [ + MultivariateNormal("M", [3, 4], [[2, 1], [1, 2]]), + MultivariateBeta("B", [0.4, 5, 15, 50, 203]), + Multinomial("N", 50, [0.3, 0.2, 0.1, 0.25, 0.15]) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert tuple(sam) in X.pspace.distribution.set + N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1) + raises(NotImplementedError, lambda: sample(N_c, library='numpy')) + + +def test_sample_scipy(): + distribs_scipy = [ + MultivariateNormal("M", [0, 0], [[0.1, 0.025], [0.025, 0.1]]), + MultivariateBeta("B", [0.4, 5, 15]), + Multinomial("N", 8, [0.3, 0.2, 0.1, 0.4]) + ] + + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + samps2 = sample(X, size=(2, 2)) + for sam in samps: + assert tuple(sam) in X.pspace.distribution.set + for i in range(2): + for j in range(2): + assert tuple(samps2[i][j]) in X.pspace.distribution.set + N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1) + raises(NotImplementedError, lambda: sample(N_c)) + + +def test_sample_pymc(): + distribs_pymc = [ + MultivariateNormal("M", [5, 2], [[1, 0], [0, 1]]), + MultivariateBeta("B", [0.4, 5, 15]), + Multinomial("N", 4, [0.3, 0.2, 0.1, 0.4]) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert tuple(sam.flatten()) in X.pspace.distribution.set + N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1) + raises(NotImplementedError, lambda: sample(N_c, library='pymc')) + + +def test_sample_seed(): + x1, x2 = (Indexed('x', i) for i in (1, 2)) + pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi) + X = JointRV('x', pdf) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert all(s1 != s2) + except NotImplementedError: + continue + +# +# XXX: This fails for pymc. Previously the test appeared to pass but that is +# just because the library argument was not passed so the test always used +# scipy. +# +def test_issue_21057(): + m = Normal("x", [0, 0], [[0, 0], [0, 0]]) + n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]]) + p = Normal("x", [0, 0], [[0, 0], [0, 1]]) + assert m == n + libraries = ('scipy', 'numpy') # , 'pymc') # <-- pymc fails + for library in libraries: + try: + imported_lib = import_module(library) + if imported_lib: + s1 = sample(m, size=8, library=library) + s2 = sample(n, size=8, library=library) + s3 = sample(p, size=8, library=library) + assert tuple(s1.flatten()) == tuple(s2.flatten()) + for s in s3: + assert tuple(s.flatten()) in p.pspace.distribution.set + except NotImplementedError: + continue + + +# +# When this passes the pymc part can be uncommented in test_issue_21057 above +# and this can be deleted. +# +@XFAIL +def test_issue_21057_pymc(): + m = Normal("x", [0, 0], [[0, 0], [0, 0]]) + n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]]) + p = Normal("x", [0, 0], [[0, 0], [0, 1]]) + assert m == n + libraries = ('pymc',) + for library in libraries: + try: + imported_lib = import_module(library) + if imported_lib: + s1 = sample(m, size=8, library=library) + s2 = sample(n, size=8, library=library) + s3 = sample(p, size=8, library=library) + assert tuple(s1.flatten()) == tuple(s2.flatten()) + for s in s3: + assert tuple(s.flatten()) in p.pspace.distribution.set + except NotImplementedError: + continue diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_matrix_distributions.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_matrix_distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a2dcdd853793d9f77e1a88adf63158ed68e3ba --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_matrix_distributions.py @@ -0,0 +1,186 @@ +from sympy.concrete.products import Product +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.gamma_functions import gamma +from sympy.matrices import Determinant, Matrix, Trace, MatrixSymbol, MatrixSet +from sympy.stats import density, sample +from sympy.stats.matrix_distributions import (MatrixGammaDistribution, + MatrixGamma, MatrixPSpace, Wishart, MatrixNormal, MatrixStudentT) +from sympy.testing.pytest import raises, skip +from sympy.external import import_module + + +def test_MatrixPSpace(): + M = MatrixGammaDistribution(1, 2, [[2, 1], [1, 2]]) + MP = MatrixPSpace('M', M, 2, 2) + assert MP.distribution == M + raises(ValueError, lambda: MatrixPSpace('M', M, 1.2, 2)) + +def test_MatrixGamma(): + M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]]) + assert M.pspace.distribution.set == MatrixSet(2, 2, S.Reals) + assert isinstance(density(M), MatrixGammaDistribution) + X = MatrixSymbol('X', 2, 2) + num = exp(Trace(Matrix([[-S(1)/2, 0], [0, -S(1)/2]])*X)) + assert density(M)(X).doit() == num/(4*pi*sqrt(Determinant(X))) + assert density(M)([[2, 1], [1, 2]]).doit() == sqrt(3)*exp(-2)/(12*pi) + X = MatrixSymbol('X', 1, 2) + Y = MatrixSymbol('Y', 1, 2) + assert density(M)([X, Y]).doit() == exp(-X[0, 0]/2 - Y[0, 1]/2)/(4*pi*sqrt( + X[0, 0]*Y[0, 1] - X[0, 1]*Y[0, 0])) + # symbolic + a, b = symbols('a b', positive=True) + d = symbols('d', positive=True, integer=True) + Y = MatrixSymbol('Y', d, d) + Z = MatrixSymbol('Z', 2, 2) + SM = MatrixSymbol('SM', d, d) + M2 = MatrixGamma('M2', a, b, SM) + M3 = MatrixGamma('M3', 2, 3, [[2, 1], [1, 2]]) + k = Dummy('k') + exprd = pi**(-d*(d - 1)/4)*b**(-a*d)*exp(Trace((-1/b)*SM**(-1)*Y) + )*Determinant(SM)**(-a)*Determinant(Y)**(a - d/2 - S(1)/2)/Product( + gamma(-k/2 + a + S(1)/2), (k, 1, d)) + assert density(M2)(Y).dummy_eq(exprd) + raises(NotImplementedError, lambda: density(M3 + M)(Z)) + raises(ValueError, lambda: density(M)(1)) + raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixGamma('M', -1, -2, [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [2, 1]])) + raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [0]])) + +def test_Wishart(): + W = Wishart('W', 5, [[1, 0], [0, 1]]) + assert W.pspace.distribution.set == MatrixSet(2, 2, S.Reals) + X = MatrixSymbol('X', 2, 2) + term1 = exp(Trace(Matrix([[-S(1)/2, 0], [0, -S(1)/2]])*X)) + assert density(W)(X).doit() == term1 * Determinant(X)/(24*pi) + assert density(W)([[2, 1], [1, 2]]).doit() == exp(-2)/(8*pi) + n = symbols('n', positive=True) + d = symbols('d', positive=True, integer=True) + Y = MatrixSymbol('Y', d, d) + SM = MatrixSymbol('SM', d, d) + W = Wishart('W', n, SM) + k = Dummy('k') + exprd = 2**(-d*n/2)*pi**(-d*(d - 1)/4)*exp(Trace(-(S(1)/2)*SM**(-1)*Y) + )*Determinant(SM)**(-n/2)*Determinant(Y)**( + -d/2 + n/2 - S(1)/2)/Product(gamma(-k/2 + n/2 + S(1)/2), (k, 1, d)) + assert density(W)(Y).dummy_eq(exprd) + raises(ValueError, lambda: density(W)(1)) + raises(ValueError, lambda: Wishart('W', -1, [[1, 0], [0, 1]])) + raises(ValueError, lambda: Wishart('W', -1, [[1, 0], [2, 1]])) + raises(ValueError, lambda: Wishart('W', 2, [[1, 0], [0]])) + +def test_MatrixNormal(): + M = MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]) + assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals) + X = MatrixSymbol('X', 1, 2) + term1 = exp(-Trace(Matrix([[ S(2)/3, -S(1)/3], [-S(1)/3, S(2)/3]])*( + Matrix([[-5], [-6]]) + X.T)*Matrix([[S(1)/4]])*(Matrix([[-5, -6]]) + X))/2) + assert density(M)(X).doit() == (sqrt(3)) * term1/(24*pi) + assert density(M)([[7, 8]]).doit() == sqrt(3)*exp(-S(1)/3)/(24*pi) + d, n = symbols('d n', positive=True, integer=True) + SM2 = MatrixSymbol('SM2', d, d) + SM1 = MatrixSymbol('SM1', n, n) + LM = MatrixSymbol('LM', n, d) + Y = MatrixSymbol('Y', n, d) + M = MatrixNormal('M', LM, SM1, SM2) + exprd = (2*pi)**(-d*n/2)*exp(-Trace(SM2**(-1)*(-LM.T + Y.T)*SM1**(-1)*(-LM + Y) + )/2)*Determinant(SM1)**(-d/2)*Determinant(SM2)**(-n/2) + assert density(M)(Y).doit() == exprd + raises(ValueError, lambda: density(M)(1)) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0]])) + raises(ValueError, lambda: MatrixNormal('M', [[1, 2]], [[1, 0], [0, 1]], [[1, 0]])) + raises(ValueError, lambda: MatrixNormal('M', [[1, 2]], [1], [[1, 0]])) + +def test_MatrixStudentT(): + M = MatrixStudentT('M', 2, [[5, 6]], [[2, 1], [1, 2]], [4]) + assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals) + X = MatrixSymbol('X', 1, 2) + D = pi ** (-1.0) * Determinant(Matrix([[4]])) ** (-1.0) * Determinant(Matrix([[2, 1], [1, 2]])) \ + ** (-0.5) / Determinant(Matrix([[S(1) / 4]]) * (Matrix([[-5, -6]]) + X) + * Matrix([[S(2) / 3, -S(1) / 3], [-S(1) / 3, S(2) / 3]]) * ( + Matrix([[-5], [-6]]) + X.T) + Matrix([[1]])) ** 2 + assert density(M)(X) == D + + v = symbols('v', positive=True) + n, p = 1, 2 + Omega = MatrixSymbol('Omega', p, p) + Sigma = MatrixSymbol('Sigma', n, n) + Location = MatrixSymbol('Location', n, p) + Y = MatrixSymbol('Y', n, p) + M = MatrixStudentT('M', v, Location, Omega, Sigma) + + exprd = gamma(v/2 + 1)*Determinant(Matrix([[1]]) + Sigma**(-1)*(-Location + Y)*Omega**(-1)*(-Location.T + Y.T))**(-v/2 - 1) / \ + (pi*gamma(v/2)*sqrt(Determinant(Omega))*Determinant(Sigma)) + + assert density(M)(Y) == exprd + raises(ValueError, lambda: density(M)(1)) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2, 1]], [[1], [2]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [[1, 2]], [[1, 0], [0, 1]], [[1, 0]])) + raises(ValueError, lambda: MatrixStudentT('M', 1, [[1, 2]], [1], [[1, 0]])) + raises(ValueError, lambda: MatrixStudentT('M', -1, [1, 2], [[1, 0], [0, 1]], [4])) + +def test_sample_scipy(): + distribs_scipy = [ + MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]), + Wishart('W', 5, [[1, 0], [0, 1]]) + ] + + size = 5 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + for sam in samps: + assert Matrix(sam) in X.pspace.distribution.set + M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]]) + raises(NotImplementedError, lambda: sample(M, size=3)) + +def test_sample_pymc(): + distribs_pymc = [ + MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]), + Wishart('W', 7, [[2, 1], [1, 2]]) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert Matrix(sam) in X.pspace.distribution.set + M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]]) + raises(NotImplementedError, lambda: sample(M, size=3)) + +def test_sample_seed(): + X = MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + for i in range(10): + assert (s0[i] == s1[i]).all() + assert (s1[i] != s2[i]).all() + + except NotImplementedError: + continue diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_mix.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..4334d9b144a5ddaad938f195f0276e0e8993aa35 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_mix.py @@ -0,0 +1,82 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, oo, pi) +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Ne) +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import DiracDelta +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.simplify.simplify import simplify +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.functions.elementary.piecewise import ExprCondPair +from sympy.stats import (Poisson, Beta, Exponential, P, + Multinomial, MultivariateBeta) +from sympy.stats.crv_types import Normal +from sympy.stats.drv_types import PoissonDistribution +from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution +from sympy.stats.joint_rv import MarginalDistribution +from sympy.stats.rv import pspace, density +from sympy.testing.pytest import ignore_warnings + +def test_density(): + x = Symbol('x') + l = Symbol('l', positive=True) + rate = Beta(l, 2, 3) + X = Poisson(x, rate) + assert isinstance(pspace(X), CompoundPSpace) + assert density(X, Eq(rate, rate.symbol)) == PoissonDistribution(l) + N1 = Normal('N1', 0, 1) + N2 = Normal('N2', N1, 2) + assert density(N2)(0).doit() == sqrt(10)/(10*sqrt(pi)) + assert simplify(density(N2, Eq(N1, 1))(x)) == \ + sqrt(2)*exp(-(x - 1)**2/8)/(4*sqrt(pi)) + assert simplify(density(N2)(x)) == sqrt(10)*exp(-x**2/10)/(10*sqrt(pi)) + +def test_MarginalDistribution(): + a1, p1, p2 = symbols('a1 p1 p2', positive=True) + C = Multinomial('C', 2, p1, p2) + B = MultivariateBeta('B', a1, C[0]) + MGR = MarginalDistribution(B, (C[0],)) + mgrc = Mul(Symbol('B'), Piecewise(ExprCondPair(Mul(Integer(2), + Pow(Symbol('p1', positive=True), Indexed(IndexedBase(Symbol('C')), + Integer(0))), Pow(Symbol('p2', positive=True), + Indexed(IndexedBase(Symbol('C')), Integer(1))), + Pow(factorial(Indexed(IndexedBase(Symbol('C')), Integer(0))), Integer(-1)), + Pow(factorial(Indexed(IndexedBase(Symbol('C')), Integer(1))), Integer(-1))), + Eq(Add(Indexed(IndexedBase(Symbol('C')), Integer(0)), + Indexed(IndexedBase(Symbol('C')), Integer(1))), Integer(2))), + ExprCondPair(Integer(0), True)), Pow(gamma(Symbol('a1', positive=True)), + Integer(-1)), gamma(Add(Symbol('a1', positive=True), + Indexed(IndexedBase(Symbol('C')), Integer(0)))), + Pow(gamma(Indexed(IndexedBase(Symbol('C')), Integer(0))), Integer(-1)), + Pow(Indexed(IndexedBase(Symbol('B')), Integer(0)), + Add(Symbol('a1', positive=True), Integer(-1))), + Pow(Indexed(IndexedBase(Symbol('B')), Integer(1)), + Add(Indexed(IndexedBase(Symbol('C')), Integer(0)), Integer(-1)))) + assert MGR(C) == mgrc + +def test_compound_distribution(): + Y = Poisson('Y', 1) + Z = Poisson('Z', Y) + assert isinstance(pspace(Z), CompoundPSpace) + assert isinstance(pspace(Z).distribution, CompoundDistribution) + assert Z.pspace.distribution.pdf(1).doit() == exp(-2)*exp(exp(-1)) + +def test_mix_expression(): + Y, E = Poisson('Y', 1), Exponential('E', 1) + k = Dummy('k') + expr1 = Integral(Sum(exp(-1)*Integral(exp(-k)*DiracDelta(k - 2), (k, 0, oo) + )/factorial(k), (k, 0, oo)), (k, -oo, 0)) + expr2 = Integral(Sum(exp(-1)*Integral(exp(-k)*DiracDelta(k - 2), (k, 0, oo) + )/factorial(k), (k, 0, oo)), (k, 0, oo)) + assert P(Eq(Y + E, 1)) == 0 + assert P(Ne(Y + E, 2)) == 1 + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert P(E + Y < 2, evaluate=False).rewrite(Integral).dummy_eq(expr1) + assert P(E + Y > 2, evaluate=False).rewrite(Integral).dummy_eq(expr2) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_random_matrix.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_random_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..ba570a16bc42620d53bce19be71e7d125965ede1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_random_matrix.py @@ -0,0 +1,135 @@ +from sympy.concrete.products import Product +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.trace import Trace +from sympy.tensor.indexed import IndexedBase +from sympy.stats import (GaussianUnitaryEnsemble as GUE, density, + GaussianOrthogonalEnsemble as GOE, + GaussianSymplecticEnsemble as GSE, + joint_eigen_distribution, + CircularUnitaryEnsemble as CUE, + CircularOrthogonalEnsemble as COE, + CircularSymplecticEnsemble as CSE, + JointEigenDistribution, + level_spacing_distribution, + Normal, Beta) +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.rv import RandomMatrixSymbol +from sympy.stats.random_matrix_models import GaussianEnsemble, RandomMatrixPSpace +from sympy.testing.pytest import raises + +def test_GaussianEnsemble(): + G = GaussianEnsemble('G', 3) + assert density(G) == G.pspace.model + raises(ValueError, lambda: GaussianEnsemble('G', 3.5)) + +def test_GaussianUnitaryEnsemble(): + H = RandomMatrixSymbol('H', 3, 3) + G = GUE('U', 3) + assert density(G)(H) == sqrt(2)*exp(-3*Trace(H**2)/2)/(4*pi**Rational(9, 2)) + i, j = (Dummy('i', integer=True, positive=True), + Dummy('j', integer=True, positive=True)) + l = IndexedBase('l') + assert joint_eigen_distribution(G).dummy_eq( + Lambda((l[1], l[2], l[3]), + 27*sqrt(6)*exp(-3*(l[1]**2)/2 - 3*(l[2]**2)/2 - 3*(l[3]**2)/2)* + Product(Abs(l[i] - l[j])**2, (j, i + 1, 3), (i, 1, 2))/(16*pi**Rational(3, 2)))) + s = Dummy('s') + assert level_spacing_distribution(G).dummy_eq(Lambda(s, 32*s**2*exp(-4*s**2/pi)/pi**2)) + + +def test_GaussianOrthogonalEnsemble(): + H = RandomMatrixSymbol('H', 3, 3) + _H = MatrixSymbol('_H', 3, 3) + G = GOE('O', 3) + assert density(G)(H) == exp(-3*Trace(H**2)/4)/Integral(exp(-3*Trace(_H**2)/4), _H) + i, j = (Dummy('i', integer=True, positive=True), + Dummy('j', integer=True, positive=True)) + l = IndexedBase('l') + assert joint_eigen_distribution(G).dummy_eq( + Lambda((l[1], l[2], l[3]), + 9*sqrt(2)*exp(-3*l[1]**2/2 - 3*l[2]**2/2 - 3*l[3]**2/2)* + Product(Abs(l[i] - l[j]), (j, i + 1, 3), (i, 1, 2))/(32*pi))) + s = Dummy('s') + assert level_spacing_distribution(G).dummy_eq(Lambda(s, s*pi*exp(-s**2*pi/4)/2)) + +def test_GaussianSymplecticEnsemble(): + H = RandomMatrixSymbol('H', 3, 3) + _H = MatrixSymbol('_H', 3, 3) + G = GSE('O', 3) + assert density(G)(H) == exp(-3*Trace(H**2))/Integral(exp(-3*Trace(_H**2)), _H) + i, j = (Dummy('i', integer=True, positive=True), + Dummy('j', integer=True, positive=True)) + l = IndexedBase('l') + assert joint_eigen_distribution(G).dummy_eq( + Lambda((l[1], l[2], l[3]), + 162*sqrt(3)*exp(-3*l[1]**2/2 - 3*l[2]**2/2 - 3*l[3]**2/2)* + Product(Abs(l[i] - l[j])**4, (j, i + 1, 3), (i, 1, 2))/(5*pi**Rational(3, 2)))) + s = Dummy('s') + assert level_spacing_distribution(G).dummy_eq(Lambda(s, S(262144)*s**4*exp(-64*s**2/(9*pi))/(729*pi**3))) + +def test_CircularUnitaryEnsemble(): + CU = CUE('U', 3) + j, k = (Dummy('j', integer=True, positive=True), + Dummy('k', integer=True, positive=True)) + t = IndexedBase('t') + assert joint_eigen_distribution(CU).dummy_eq( + Lambda((t[1], t[2], t[3]), + Product(Abs(exp(I*t[j]) - exp(I*t[k]))**2, + (j, k + 1, 3), (k, 1, 2))/(48*pi**3)) + ) + +def test_CircularOrthogonalEnsemble(): + CO = COE('U', 3) + j, k = (Dummy('j', integer=True, positive=True), + Dummy('k', integer=True, positive=True)) + t = IndexedBase('t') + assert joint_eigen_distribution(CO).dummy_eq( + Lambda((t[1], t[2], t[3]), + Product(Abs(exp(I*t[j]) - exp(I*t[k])), + (j, k + 1, 3), (k, 1, 2))/(48*pi**2)) + ) + +def test_CircularSymplecticEnsemble(): + CS = CSE('U', 3) + j, k = (Dummy('j', integer=True, positive=True), + Dummy('k', integer=True, positive=True)) + t = IndexedBase('t') + assert joint_eigen_distribution(CS).dummy_eq( + Lambda((t[1], t[2], t[3]), + Product(Abs(exp(I*t[j]) - exp(I*t[k]))**4, + (j, k + 1, 3), (k, 1, 2))/(720*pi**3)) + ) + +def test_JointEigenDistribution(): + A = Matrix([[Normal('A00', 0, 1), Normal('A01', 1, 1)], + [Beta('A10', 1, 1), Beta('A11', 1, 1)]]) + assert JointEigenDistribution(A) == \ + JointDistributionHandmade(-sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2)/2 + + A[0, 0]/2 + A[1, 1]/2, sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2)/2 + A[0, 0]/2 + A[1, 1]/2) + raises(ValueError, lambda: JointEigenDistribution(Matrix([[1, 0], [2, 1]]))) + +def test_issue_19841(): + G1 = GUE('U', 2) + G2 = G1.xreplace({2: 2}) + assert G1.args == G2.args + + X = MatrixSymbol('X', 2, 2) + G = GSE('U', 2) + h_pspace = RandomMatrixPSpace('P', model=density(G)) + H = RandomMatrixSymbol('H', 2, 2, pspace=h_pspace) + H2 = RandomMatrixSymbol('H', 2, 2, pspace=None) + assert H.doit() == H + + assert (2*H).xreplace({H: X}) == 2*X + assert (2*H).xreplace({H2: X}) == 2*H + assert (2*H2).xreplace({H: X}) == 2*H2 + assert (2*H2).xreplace({H2: X}) == 2*X diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_rv.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..185756300556f2fe70b76c402113ec2bb2501ef4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_rv.py @@ -0,0 +1,441 @@ +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Lambda +from sympy.core.numbers import (Rational, nan, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (FallingFactorial, binomial) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import DiracDelta +from sympy.integrals.integrals import integrate +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import Matrix +from sympy.sets.sets import Interval +from sympy.tensor.indexed import Indexed +from sympy.stats import (Die, Normal, Exponential, FiniteRV, P, E, H, variance, + density, given, independent, dependent, where, pspace, GaussianUnitaryEnsemble, + random_symbols, sample, Geometric, factorial_moment, Binomial, Hypergeometric, + DiscreteUniform, Poisson, characteristic_function, moment_generating_function, + BernoulliProcess, Variance, Expectation, Probability, Covariance, covariance, cmoment, + moment, median) +from sympy.stats.rv import (IndependentProductPSpace, rs_swap, Density, NamedArgsMixin, + RandomSymbol, sample_iter, PSpace, is_random, RandomIndexedSymbol, RandomMatrixSymbol) +from sympy.testing.pytest import raises, skip, XFAIL, warns_deprecated_sympy +from sympy.external import import_module +from sympy.core.numbers import comp +from sympy.stats.frv_types import BernoulliDistribution +from sympy.core.symbol import Dummy +from sympy.functions.elementary.piecewise import Piecewise + +def test_where(): + X, Y = Die('X'), Die('Y') + Z = Normal('Z', 0, 1) + + assert where(Z**2 <= 1).set == Interval(-1, 1) + assert where(Z**2 <= 1).as_boolean() == Interval(-1, 1).as_relational(Z.symbol) + assert where(And(X > Y, Y > 4)).as_boolean() == And( + Eq(X.symbol, 6), Eq(Y.symbol, 5)) + + assert len(where(X < 3).set) == 2 + assert 1 in where(X < 3).set + + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + assert where(And(X**2 <= 1, X >= 0)).set == Interval(0, 1) + XX = given(X, And(X**2 <= 1, X >= 0)) + assert XX.pspace.domain.set == Interval(0, 1) + assert XX.pspace.domain.as_boolean() == \ + And(0 <= X.symbol, X.symbol**2 <= 1, -oo < X.symbol, X.symbol < oo) + + with raises(TypeError): + XX = given(X, X + 3) + + +def test_random_symbols(): + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + + assert set(random_symbols(2*X + 1)) == {X} + assert set(random_symbols(2*X + Y)) == {X, Y} + assert set(random_symbols(2*X + Y.symbol)) == {X} + assert set(random_symbols(2)) == set() + + +def test_characteristic_function(): + # Imports I from sympy + from sympy.core.numbers import I + X = Normal('X',0,1) + Y = DiscreteUniform('Y', [1,2,7]) + Z = Poisson('Z', 2) + t = symbols('_t') + P = Lambda(t, exp(-t**2/2)) + Q = Lambda(t, exp(7*t*I)/3 + exp(2*t*I)/3 + exp(t*I)/3) + R = Lambda(t, exp(2 * exp(t*I) - 2)) + + + assert characteristic_function(X).dummy_eq(P) + assert characteristic_function(Y).dummy_eq(Q) + assert characteristic_function(Z).dummy_eq(R) + + +def test_moment_generating_function(): + + X = Normal('X',0,1) + Y = DiscreteUniform('Y', [1,2,7]) + Z = Poisson('Z', 2) + t = symbols('_t') + P = Lambda(t, exp(t**2/2)) + Q = Lambda(t, (exp(7*t)/3 + exp(2*t)/3 + exp(t)/3)) + R = Lambda(t, exp(2 * exp(t) - 2)) + + + assert moment_generating_function(X).dummy_eq(P) + assert moment_generating_function(Y).dummy_eq(Q) + assert moment_generating_function(Z).dummy_eq(R) + +def test_sample_iter(): + + X = Normal('X',0,1) + Y = DiscreteUniform('Y', [1, 2, 7]) + Z = Poisson('Z', 2) + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + expr = X**2 + 3 + iterator = sample_iter(expr) + + expr2 = Y**2 + 5*Y + 4 + iterator2 = sample_iter(expr2) + + expr3 = Z**3 + 4 + iterator3 = sample_iter(expr3) + + def is_iterator(obj): + if ( + hasattr(obj, '__iter__') and + (hasattr(obj, 'next') or + hasattr(obj, '__next__')) and + callable(obj.__iter__) and + obj.__iter__() is obj + ): + return True + else: + return False + assert is_iterator(iterator) + assert is_iterator(iterator2) + assert is_iterator(iterator3) + +def test_pspace(): + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + x = Symbol('x') + + raises(ValueError, lambda: pspace(5 + 3)) + raises(ValueError, lambda: pspace(x < 1)) + assert pspace(X) == X.pspace + assert pspace(2*X + 1) == X.pspace + assert pspace(2*X + Y) == IndependentProductPSpace(Y.pspace, X.pspace) + +def test_rs_swap(): + X = Normal('x', 0, 1) + Y = Exponential('y', 1) + + XX = Normal('x', 0, 2) + YY = Normal('y', 0, 3) + + expr = 2*X + Y + assert expr.subs(rs_swap((X, Y), (YY, XX))) == 2*XX + YY + + +def test_RandomSymbol(): + + X = Normal('x', 0, 1) + Y = Normal('x', 0, 2) + assert X.symbol == Y.symbol + assert X != Y + + assert X.name == X.symbol.name + + X = Normal('lambda', 0, 1) # make sure we can use protected terms + X = Normal('Lambda', 0, 1) # make sure we can use SymPy terms + + +def test_RandomSymbol_diff(): + X = Normal('x', 0, 1) + assert (2*X).diff(X) + + +def test_random_symbol_no_pspace(): + x = RandomSymbol(Symbol('x')) + assert x.pspace == PSpace() + +def test_overlap(): + X = Normal('x', 0, 1) + Y = Normal('x', 0, 2) + + raises(ValueError, lambda: P(X > Y)) + + +def test_IndependentProductPSpace(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + px = X.pspace + py = Y.pspace + assert pspace(X + Y) == IndependentProductPSpace(px, py) + assert pspace(X + Y) == IndependentProductPSpace(py, px) + + +def test_E(): + assert E(5) == 5 + + +def test_H(): + X = Normal('X', 0, 1) + D = Die('D', sides = 4) + G = Geometric('G', 0.5) + assert H(X, X > 0) == -log(2)/2 + S.Half + log(pi)/2 + assert H(D, D > 2) == log(2) + assert comp(H(G).evalf().round(2), 1.39) + + +def test_Sample(): + X = Die('X', 6) + Y = Normal('Y', 0, 1) + z = Symbol('z', integer=True) + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(X) in [1, 2, 3, 4, 5, 6] + assert isinstance(sample(X + Y), float) + + assert P(X + Y > 0, Y < 0, numsamples=10).is_number + assert E(X + Y, numsamples=10).is_number + assert E(X**2 + Y, numsamples=10).is_number + assert E((X + Y)**2, numsamples=10).is_number + assert variance(X + Y, numsamples=10).is_number + + raises(TypeError, lambda: P(Y > z, numsamples=5)) + + assert P(sin(Y) <= 1, numsamples=10) == 1.0 + assert P(sin(Y) <= 1, cos(Y) < 1, numsamples=10) == 1.0 + + assert all(i in range(1, 7) for i in density(X, numsamples=10)) + assert all(i in range(4, 7) for i in density(X, X>3, numsamples=10)) + + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests') + #Test Issue #21563: Output of sample must be a float or array + assert isinstance(sample(X), (numpy.int32, numpy.int64)) + assert isinstance(sample(Y), numpy.float64) + assert isinstance(sample(X, size=2), numpy.ndarray) + + with warns_deprecated_sympy(): + sample(X, numsamples=2) + +@XFAIL +def test_samplingE(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + Y = Normal('Y', 0, 1) + z = Symbol('z', integer=True) + assert E(Sum(1/z**Y, (z, 1, oo)), Y > 2, numsamples=3).is_number + + +def test_given(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + A = given(X, True) + B = given(X, Y > 2) + + assert X == A == B + + +def test_factorial_moment(): + X = Poisson('X', 2) + Y = Binomial('Y', 2, S.Half) + Z = Hypergeometric('Z', 4, 2, 2) + assert factorial_moment(X, 2) == 4 + assert factorial_moment(Y, 2) == S.Half + assert factorial_moment(Z, 2) == Rational(1, 3) + + x, y, z, l = symbols('x y z l') + Y = Binomial('Y', 2, y) + Z = Hypergeometric('Z', 10, 2, 3) + assert factorial_moment(Y, l) == y**2*FallingFactorial( + 2, l) + 2*y*(1 - y)*FallingFactorial(1, l) + (1 - y)**2*\ + FallingFactorial(0, l) + assert factorial_moment(Z, l) == 7*FallingFactorial(0, l)/\ + 15 + 7*FallingFactorial(1, l)/15 + FallingFactorial(2, l)/15 + + +def test_dependence(): + X, Y = Die('X'), Die('Y') + assert independent(X, 2*Y) + assert not dependent(X, 2*Y) + + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + assert independent(X, Y) + assert dependent(X, 2*X) + + # Create a dependency + XX, YY = given(Tuple(X, Y), Eq(X + Y, 3)) + assert dependent(XX, YY) + +def test_dependent_finite(): + X, Y = Die('X'), Die('Y') + # Dependence testing requires symbolic conditions which currently break + # finite random variables + assert dependent(X, Y + X) + + XX, YY = given(Tuple(X, Y), X + Y > 5) # Create a dependency + assert dependent(XX, YY) + + +def test_normality(): + X, Y = Normal('X', 0, 1), Normal('Y', 0, 1) + x = Symbol('x', real=True) + z = Symbol('z', real=True) + dens = density(X - Y, Eq(X + Y, z)) + + assert integrate(dens(x), (x, -oo, oo)) == 1 + + +def test_Density(): + X = Die('X', 6) + d = Density(X) + assert d.doit() == density(X) + +def test_NamedArgsMixin(): + class Foo(Basic, NamedArgsMixin): + _argnames = 'foo', 'bar' + + a = Foo(S(1), S(2)) + + assert a.foo == 1 + assert a.bar == 2 + + raises(AttributeError, lambda: a.baz) + + class Bar(Basic, NamedArgsMixin): + pass + + raises(AttributeError, lambda: Bar(S(1), S(2)).foo) + +def test_density_constant(): + assert density(3)(2) == 0 + assert density(3)(3) == DiracDelta(0) + +def test_cmoment_constant(): + assert variance(3) == 0 + assert cmoment(3, 3) == 0 + assert cmoment(3, 4) == 0 + x = Symbol('x') + assert variance(x) == 0 + assert cmoment(x, 15) == 0 + assert cmoment(x, 0) == 1 + +def test_moment_constant(): + assert moment(3, 0) == 1 + assert moment(3, 1) == 3 + assert moment(3, 2) == 9 + x = Symbol('x') + assert moment(x, 2) == x**2 + +def test_median_constant(): + assert median(3) == 3 + x = Symbol('x') + assert median(x) == x + +def test_real(): + x = Normal('x', 0, 1) + assert x.is_real + + +def test_issue_10052(): + X = Exponential('X', 3) + assert P(X < oo) == 1 + assert P(X > oo) == 0 + assert P(X < 2, X > oo) == 0 + assert P(X < oo, X > oo) == 0 + assert P(X < oo, X > 2) == 1 + assert P(X < 3, X == 2) == 0 + raises(ValueError, lambda: P(1)) + raises(ValueError, lambda: P(X < 1, 2)) + +def test_issue_11934(): + density = {0: .5, 1: .5} + X = FiniteRV('X', density) + assert E(X) == 0.5 + assert P( X>= 2) == 0 + +def test_issue_8129(): + X = Exponential('X', 4) + assert P(X >= X) == 1 + assert P(X > X) == 0 + assert P(X > X+1) == 0 + +def test_issue_12237(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + U = P(X > 0, X) + V = P(Y < 0, X) + W = P(X + Y > 0, X) + assert W == P(X + Y > 0, X) + assert U == BernoulliDistribution(S.Half, S.Zero, S.One) + assert V == S.Half + +def test_is_random(): + X = Normal('X', 0, 1) + Y = Normal('Y', 0, 1) + a, b = symbols('a, b') + G = GaussianUnitaryEnsemble('U', 2) + B = BernoulliProcess('B', 0.9) + assert not is_random(a) + assert not is_random(a + b) + assert not is_random(a * b) + assert not is_random(Matrix([a**2, b**2])) + assert is_random(X) + assert is_random(X**2 + Y) + assert is_random(Y + b**2) + assert is_random(Y > 5) + assert is_random(B[3] < 1) + assert is_random(G) + assert is_random(X * Y * B[1]) + assert is_random(Matrix([[X, B[2]], [G, Y]])) + assert is_random(Eq(X, 4)) + +def test_issue_12283(): + x = symbols('x') + X = RandomSymbol(x) + Y = RandomSymbol('Y') + Z = RandomMatrixSymbol('Z', 2, 1) + W = RandomMatrixSymbol('W', 2, 1) + RI = RandomIndexedSymbol(Indexed('RI', 3)) + assert pspace(Z) == PSpace() + assert pspace(RI) == PSpace() + assert pspace(X) == PSpace() + assert E(X) == Expectation(X) + assert P(Y > 3) == Probability(Y > 3) + assert variance(X) == Variance(X) + assert variance(RI) == Variance(RI) + assert covariance(X, Y) == Covariance(X, Y) + assert covariance(W, Z) == Covariance(W, Z) + +def test_issue_6810(): + X = Die('X', 6) + Y = Normal('Y', 0, 1) + assert P(Eq(X, 2)) == S(1)/6 + assert P(Eq(Y, 0)) == 0 + assert P(Or(X > 2, X < 3)) == 1 + assert P(And(X > 3, X > 2)) == S(1)/2 + +def test_issue_20286(): + n, p = symbols('n p') + B = Binomial('B', n, p) + k = Dummy('k', integer = True) + eq = Sum(Piecewise((-p**k*(1 - p)**(-k + n)*log(p**k*(1 - p)**(-k + n)*binomial(n, k))*binomial(n, k), (k >= 0) & (k <= n)), (nan, True)), (k, 0, n)) + assert eq.dummy_eq(H(B)) diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_stochastic_process.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_stochastic_process.py new file mode 100644 index 0000000000000000000000000000000000000000..3e42ffc8632240b1d85a774467a057e9857c567c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_stochastic_process.py @@ -0,0 +1,763 @@ +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.function import Lambda +from sympy.core.numbers import (Float, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import (gamma, lowergamma) +from sympy.logic.boolalg import (And, Not) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import ImmutableMatrix +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.stats import (DiscreteMarkovChain, P, TransitionMatrixOf, E, + StochasticStateSpaceOf, variance, ContinuousMarkovChain, + BernoulliProcess, PoissonProcess, WienerProcess, + GammaProcess, sample_stochastic_process) +from sympy.stats.joint_rv import JointDistribution +from sympy.stats.joint_rv_types import JointDistributionHandmade +from sympy.stats.rv import RandomIndexedSymbol +from sympy.stats.symbolic_probability import Probability, Expectation +from sympy.testing.pytest import (raises, skip, ignore_warnings, + warns_deprecated_sympy) +from sympy.external import import_module +from sympy.stats.frv_types import BernoulliDistribution +from sympy.stats.drv_types import PoissonDistribution +from sympy.stats.crv_types import NormalDistribution, GammaDistribution +from sympy.core.symbol import Str + + +def test_DiscreteMarkovChain(): + + # pass only the name + X = DiscreteMarkovChain("X") + assert isinstance(X.state_space, Range) + assert X.index_set == S.Naturals0 + assert isinstance(X.transition_probabilities, MatrixSymbol) + t = symbols('t', positive=True, integer=True) + assert isinstance(X[t], RandomIndexedSymbol) + assert E(X[0]) == Expectation(X[0]) + raises(TypeError, lambda: DiscreteMarkovChain(1)) + raises(NotImplementedError, lambda: X(t)) + raises(NotImplementedError, lambda: X.communication_classes()) + raises(NotImplementedError, lambda: X.canonical_form()) + raises(NotImplementedError, lambda: X.decompose()) + + nz = Symbol('n', integer=True) + TZ = MatrixSymbol('M', nz, nz) + SZ = Range(nz) + YZ = DiscreteMarkovChain('Y', SZ, TZ) + assert P(Eq(YZ[2], 1), Eq(YZ[1], 0)) == TZ[0, 1] + + raises(ValueError, lambda: sample_stochastic_process(t)) + raises(ValueError, lambda: next(sample_stochastic_process(X))) + # pass name and state_space + # any hashable object should be a valid state + # states should be valid as a tuple/set/list/Tuple/Range + sym, rainy, cloudy, sunny = symbols('a Rainy Cloudy Sunny', real=True) + state_spaces = [(1, 2, 3), [Str('Hello'), sym, DiscreteMarkovChain("Y", (1,2,3))], + Tuple(S(1), exp(sym), Str('World'), sympify=False), Range(-1, 5, 2), + [rainy, cloudy, sunny]] + chains = [DiscreteMarkovChain("Y", state_space) for state_space in state_spaces] + + for i, Y in enumerate(chains): + assert isinstance(Y.transition_probabilities, MatrixSymbol) + assert Y.state_space == state_spaces[i] or Y.state_space == FiniteSet(*state_spaces[i]) + assert Y.number_of_states == 3 + + with ignore_warnings(UserWarning): # TODO: Restore tests once warnings are removed + assert P(Eq(Y[2], 1), Eq(Y[0], 2), evaluate=False) == Probability(Eq(Y[2], 1), Eq(Y[0], 2)) + assert E(Y[0]) == Expectation(Y[0]) + + raises(ValueError, lambda: next(sample_stochastic_process(Y))) + + raises(TypeError, lambda: DiscreteMarkovChain("Y", {1: 1})) + Y = DiscreteMarkovChain("Y", Range(1, t, 2)) + assert Y.number_of_states == ceiling((t-1)/2) + + # pass name and transition_probabilities + chains = [DiscreteMarkovChain("Y", trans_probs=Matrix([])), + DiscreteMarkovChain("Y", trans_probs=Matrix([[0, 1], [1, 0]])), + DiscreteMarkovChain("Y", trans_probs=Matrix([[pi, 1-pi], [sym, 1-sym]]))] + for Z in chains: + assert Z.number_of_states == Z.transition_probabilities.shape[0] + assert isinstance(Z.transition_probabilities, ImmutableMatrix) + + # pass name, state_space and transition_probabilities + T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + TS = MatrixSymbol('T', 3, 3) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + YS = DiscreteMarkovChain("Y", ['One', 'Two', 3], TS) + assert Y.joint_distribution(1, Y[2], 3) == JointDistribution(Y[1], Y[2], Y[3]) + raises(ValueError, lambda: Y.joint_distribution(Y[1].symbol, Y[2].symbol)) + assert P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) == Float(0.36, 2) + assert (P(Eq(YS[3], 2), Eq(YS[1], 1)) - + (TS[0, 2]*TS[1, 0] + TS[1, 1]*TS[1, 2] + TS[1, 2]*TS[2, 2])).simplify() == 0 + assert P(Eq(YS[1], 1), Eq(YS[2], 2)) == Probability(Eq(YS[1], 1)) + assert P(Eq(YS[3], 3), Eq(YS[1], 1)) == TS[0, 2]*TS[1, 0] + TS[1, 1]*TS[1, 2] + TS[1, 2]*TS[2, 2] + TO = Matrix([[0.25, 0.75, 0],[0, 0.25, 0.75],[0.75, 0, 0.25]]) + assert P(Eq(Y[3], 2), Eq(Y[1], 1) & TransitionMatrixOf(Y, TO)).round(3) == Float(0.375, 3) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert E(Y[3], evaluate=False) == Expectation(Y[3]) + assert E(Y[3], Eq(Y[2], 1)).round(2) == Float(1.1, 3) + TSO = MatrixSymbol('T', 4, 4) + raises(ValueError, lambda: str(P(Eq(YS[3], 2), Eq(YS[1], 1) & TransitionMatrixOf(YS, TSO)))) + raises(TypeError, lambda: DiscreteMarkovChain("Z", [0, 1, 2], symbols('M'))) + raises(ValueError, lambda: DiscreteMarkovChain("Z", [0, 1, 2], MatrixSymbol('T', 3, 4))) + raises(ValueError, lambda: E(Y[3], Eq(Y[2], 6))) + raises(ValueError, lambda: E(Y[2], Eq(Y[3], 1))) + + + # extended tests for probability queries + TO1 = Matrix([[Rational(1, 4), Rational(3, 4), 0],[Rational(1, 3), Rational(1, 3), Rational(1, 3)],[0, Rational(1, 4), Rational(3, 4)]]) + assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), + Eq(Probability(Eq(Y[0], 0)), Rational(1, 4)) & TransitionMatrixOf(Y, TO1)) == Rational(1, 16) + assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), TransitionMatrixOf(Y, TO1)) == \ + Probability(Eq(Y[0], 0))/4 + assert P(Lt(X[1], 2) & Gt(X[1], 0), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [0, 1, 2]) & TransitionMatrixOf(X, TO1)) == Rational(1, 4) + assert P(Lt(X[1], 2) & Gt(X[1], 0), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [S(0), '0', 1]) & TransitionMatrixOf(X, TO1)) == Rational(1, 4) + assert P(Ne(X[1], 2) & Ne(X[1], 1), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [0, 1, 2]) & TransitionMatrixOf(X, TO1)) is S.Zero + assert P(Ne(X[1], 2) & Ne(X[1], 1), Eq(X[0], 2) & + StochasticStateSpaceOf(X, [S(0), '0', 1]) & TransitionMatrixOf(X, TO1)) is S.Zero + assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), Eq(Y[1], 1)) == 0.1*Probability(Eq(Y[0], 0)) + + # testing properties of Markov chain + TO2 = Matrix([[S.One, 0, 0],[Rational(1, 3), Rational(1, 3), Rational(1, 3)],[0, Rational(1, 4), Rational(3, 4)]]) + TO3 = Matrix([[Rational(1, 4), Rational(3, 4), 0],[Rational(1, 3), Rational(1, 3), Rational(1, 3)], [0, Rational(1, 4), Rational(3, 4)]]) + Y2 = DiscreteMarkovChain('Y', trans_probs=TO2) + Y3 = DiscreteMarkovChain('Y', trans_probs=TO3) + assert Y3.fundamental_matrix() == ImmutableMatrix([[176, 81, -132], [36, 141, -52], [-44, -39, 208]])/125 + assert Y2.is_absorbing_chain() == True + assert Y3.is_absorbing_chain() == False + assert Y2.canonical_form() == ([0, 1, 2], TO2) + assert Y3.canonical_form() == ([0, 1, 2], TO3) + assert Y2.decompose() == ([0, 1, 2], TO2[0:1, 0:1], TO2[1:3, 0:1], TO2[1:3, 1:3]) + assert Y3.decompose() == ([0, 1, 2], TO3, Matrix(0, 3, []), Matrix(0, 0, [])) + TO4 = Matrix([[Rational(1, 5), Rational(2, 5), Rational(2, 5)], [Rational(1, 10), S.Half, Rational(2, 5)], [Rational(3, 5), Rational(3, 10), Rational(1, 10)]]) + Y4 = DiscreteMarkovChain('Y', trans_probs=TO4) + w = ImmutableMatrix([[Rational(11, 39), Rational(16, 39), Rational(4, 13)]]) + assert Y4.limiting_distribution == w + assert Y4.is_regular() == True + assert Y4.is_ergodic() == True + TS1 = MatrixSymbol('T', 3, 3) + Y5 = DiscreteMarkovChain('Y', trans_probs=TS1) + assert Y5.limiting_distribution(w, TO4).doit() == True + assert Y5.stationary_distribution(condition_set=True).subs(TS1, TO4).contains(w).doit() == S.true + TO6 = Matrix([[S.One, 0, 0, 0, 0],[S.Half, 0, S.Half, 0, 0],[0, S.Half, 0, S.Half, 0], [0, 0, S.Half, 0, S.Half], [0, 0, 0, 0, 1]]) + Y6 = DiscreteMarkovChain('Y', trans_probs=TO6) + assert Y6.fundamental_matrix() == ImmutableMatrix([[Rational(3, 2), S.One, S.Half], [S.One, S(2), S.One], [S.Half, S.One, Rational(3, 2)]]) + assert Y6.absorbing_probabilities() == ImmutableMatrix([[Rational(3, 4), Rational(1, 4)], [S.Half, S.Half], [Rational(1, 4), Rational(3, 4)]]) + with warns_deprecated_sympy(): + Y6.absorbing_probabilites() + TO7 = Matrix([[Rational(1, 2), Rational(1, 4), Rational(1, 4)], [Rational(1, 2), 0, Rational(1, 2)], [Rational(1, 4), Rational(1, 4), Rational(1, 2)]]) + Y7 = DiscreteMarkovChain('Y', trans_probs=TO7) + assert Y7.is_absorbing_chain() == False + assert Y7.fundamental_matrix() == ImmutableMatrix([[Rational(86, 75), Rational(1, 25), Rational(-14, 75)], + [Rational(2, 25), Rational(21, 25), Rational(2, 25)], + [Rational(-14, 75), Rational(1, 25), Rational(86, 75)]]) + + # test for zero-sized matrix functionality + X = DiscreteMarkovChain('X', trans_probs=Matrix([])) + assert X.number_of_states == 0 + assert X.stationary_distribution() == Matrix([[]]) + assert X.communication_classes() == [] + assert X.canonical_form() == ([], Matrix([])) + assert X.decompose() == ([], Matrix([]), Matrix([]), Matrix([])) + assert X.is_regular() == False + assert X.is_ergodic() == False + + # test communication_class + # see https://drive.google.com/drive/folders/1HbxLlwwn2b3U8Lj7eb_ASIUb5vYaNIjg?usp=sharing + # tutorial 2.pdf + TO7 = Matrix([[0, 5, 5, 0, 0], + [0, 0, 0, 10, 0], + [5, 0, 5, 0, 0], + [0, 10, 0, 0, 0], + [0, 3, 0, 3, 4]])/10 + Y7 = DiscreteMarkovChain('Y', trans_probs=TO7) + tuples = Y7.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([1, 3], [0, 2], [4]) + assert recurrence == (True, False, False) + assert periods == (2, 1, 1) + + TO8 = Matrix([[0, 0, 0, 10, 0, 0], + [5, 0, 5, 0, 0, 0], + [0, 4, 0, 0, 0, 6], + [10, 0, 0, 0, 0, 0], + [0, 10, 0, 0, 0, 0], + [0, 0, 0, 5, 5, 0]])/10 + Y8 = DiscreteMarkovChain('Y', trans_probs=TO8) + tuples = Y8.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([0, 3], [1, 2, 5, 4]) + assert recurrence == (True, False) + assert periods == (2, 2) + + TO9 = Matrix([[2, 0, 0, 3, 0, 0, 3, 2, 0, 0], + [0, 10, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 2, 2, 0, 0, 0, 0, 0, 3, 3], + [0, 0, 0, 3, 0, 0, 6, 1, 0, 0], + [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 10, 0, 0, 0, 0], + [4, 0, 0, 5, 0, 0, 1, 0, 0, 0], + [2, 0, 0, 4, 0, 0, 2, 2, 0, 0], + [3, 0, 1, 0, 0, 0, 0, 0, 4, 2], + [0, 0, 4, 0, 0, 0, 0, 0, 3, 3]])/10 + Y9 = DiscreteMarkovChain('Y', trans_probs=TO9) + tuples = Y9.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([0, 3, 6, 7], [1], [2, 8, 9], [5], [4]) + assert recurrence == (True, True, False, True, False) + assert periods == (1, 1, 1, 1, 1) + + # test canonical form + # see https://web.archive.org/web/20201230182007/https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf + # example 11.13 + T = Matrix([[1, 0, 0, 0, 0], + [S(1) / 2, 0, S(1) / 2, 0, 0], + [0, S(1) / 2, 0, S(1) / 2, 0], + [0, 0, S(1) / 2, 0, S(1) / 2], + [0, 0, 0, 0, S(1)]]) + DW = DiscreteMarkovChain('DW', [0, 1, 2, 3, 4], T) + states, A, B, C = DW.decompose() + assert states == [0, 4, 1, 2, 3] + assert A == Matrix([[1, 0], [0, 1]]) + assert B == Matrix([[S(1)/2, 0], [0, 0], [0, S(1)/2]]) + assert C == Matrix([[0, S(1)/2, 0], [S(1)/2, 0, S(1)/2], [0, S(1)/2, 0]]) + states, new_matrix = DW.canonical_form() + assert states == [0, 4, 1, 2, 3] + assert new_matrix == Matrix([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [S(1)/2, 0, 0, S(1)/2, 0], + [0, 0, S(1)/2, 0, S(1)/2], + [0, S(1)/2, 0, S(1)/2, 0]]) + + # test regular and ergodic + # https://web.archive.org/web/20201230182007/https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf + T = Matrix([[0, 4, 0, 0, 0], + [1, 0, 3, 0, 0], + [0, 2, 0, 2, 0], + [0, 0, 3, 0, 1], + [0, 0, 0, 4, 0]])/4 + X = DiscreteMarkovChain('X', trans_probs=T) + assert not X.is_regular() + assert X.is_ergodic() + T = Matrix([[0, 1], [1, 0]]) + X = DiscreteMarkovChain('X', trans_probs=T) + assert not X.is_regular() + assert X.is_ergodic() + # http://www.math.wisc.edu/~valko/courses/331/MC2.pdf + T = Matrix([[2, 1, 1], + [2, 0, 2], + [1, 1, 2]])/4 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_regular() + assert X.is_ergodic() + # https://docs.ufpr.br/~lucambio/CE222/1S2014/Kemeny-Snell1976.pdf + T = Matrix([[1, 1], [1, 1]])/2 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_regular() + assert X.is_ergodic() + + # test is_absorbing_chain + T = Matrix([[0, 1, 0], + [1, 0, 0], + [0, 0, 1]]) + X = DiscreteMarkovChain('X', trans_probs=T) + assert not X.is_absorbing_chain() + # https://en.wikipedia.org/wiki/Absorbing_Markov_chain + T = Matrix([[1, 1, 0, 0], + [0, 1, 1, 0], + [1, 0, 0, 1], + [0, 0, 0, 2]])/2 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_absorbing_chain() + T = Matrix([[2, 0, 0, 0, 0], + [1, 0, 1, 0, 0], + [0, 1, 0, 1, 0], + [0, 0, 1, 0, 1], + [0, 0, 0, 0, 2]])/2 + X = DiscreteMarkovChain('X', trans_probs=T) + assert X.is_absorbing_chain() + + # test custom state space + Y10 = DiscreteMarkovChain('Y', [1, 2, 3], TO2) + tuples = Y10.communication_classes() + classes, recurrence, periods = list(zip(*tuples)) + assert classes == ([1], [2, 3]) + assert recurrence == (True, False) + assert periods == (1, 1) + assert Y10.canonical_form() == ([1, 2, 3], TO2) + assert Y10.decompose() == ([1, 2, 3], TO2[0:1, 0:1], TO2[1:3, 0:1], TO2[1:3, 1:3]) + + # testing miscellaneous queries + T = Matrix([[S.Half, Rational(1, 4), Rational(1, 4)], + [Rational(1, 3), 0, Rational(2, 3)], + [S.Half, S.Half, 0]]) + X = DiscreteMarkovChain('X', [0, 1, 2], T) + assert P(Eq(X[1], 2) & Eq(X[2], 1) & Eq(X[3], 0), + Eq(P(Eq(X[1], 0)), Rational(1, 4)) & Eq(P(Eq(X[1], 1)), Rational(1, 4))) == Rational(1, 12) + assert P(Eq(X[2], 1) | Eq(X[2], 2), Eq(X[1], 1)) == Rational(2, 3) + assert P(Eq(X[2], 1) & Eq(X[2], 2), Eq(X[1], 1)) is S.Zero + assert P(Ne(X[2], 2), Eq(X[1], 1)) == Rational(1, 3) + assert E(X[1]**2, Eq(X[0], 1)) == Rational(8, 3) + assert variance(X[1], Eq(X[0], 1)) == Rational(8, 9) + raises(ValueError, lambda: E(X[1], Eq(X[2], 1))) + raises(ValueError, lambda: DiscreteMarkovChain('X', [0, 1], T)) + + # testing miscellaneous queries with different state space + X = DiscreteMarkovChain('X', ['A', 'B', 'C'], T) + assert P(Eq(X[1], 2) & Eq(X[2], 1) & Eq(X[3], 0), + Eq(P(Eq(X[1], 0)), Rational(1, 4)) & Eq(P(Eq(X[1], 1)), Rational(1, 4))) == Rational(1, 12) + assert P(Eq(X[2], 1) | Eq(X[2], 2), Eq(X[1], 1)) == Rational(2, 3) + assert P(Eq(X[2], 1) & Eq(X[2], 2), Eq(X[1], 1)) is S.Zero + assert P(Ne(X[2], 2), Eq(X[1], 1)) == Rational(1, 3) + a = X.state_space.args[0] + c = X.state_space.args[2] + assert (E(X[1] ** 2, Eq(X[0], 1)) - (a**2/3 + 2*c**2/3)).simplify() == 0 + assert (variance(X[1], Eq(X[0], 1)) - (2*(-a/3 + c/3)**2/3 + (2*a/3 - 2*c/3)**2/3)).simplify() == 0 + raises(ValueError, lambda: E(X[1], Eq(X[2], 1))) + + #testing queries with multiple RandomIndexedSymbols + T = Matrix([[Rational(5, 10), Rational(3, 10), Rational(2, 10)], [Rational(2, 10), Rational(7, 10), Rational(1, 10)], [Rational(3, 10), Rational(3, 10), Rational(4, 10)]]) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + assert P(Eq(Y[7], Y[5]), Eq(Y[2], 0)).round(5) == Float(0.44428, 5) + assert P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2) == Float(0.36, 2) + assert P(Le(Y[5], Y[10]), Eq(Y[4], 2)).round(6) == Float(0.583120, 6) + assert Float(P(Eq(Y[10], Y[5]), Eq(Y[4], 1)), 14) == Float(1 - P(Ne(Y[10], Y[5]), Eq(Y[4], 1)), 14) + assert Float(P(Gt(Y[8], Y[9]), Eq(Y[3], 2)), 14) == Float(1 - P(Le(Y[8], Y[9]), Eq(Y[3], 2)), 14) + assert Float(P(Lt(Y[1], Y[4]), Eq(Y[0], 0)), 14) == Float(1 - P(Ge(Y[1], Y[4]), Eq(Y[0], 0)), 14) + assert P(Eq(Y[5], Y[10]), Eq(Y[2], 1)) == P(Eq(Y[10], Y[5]), Eq(Y[2], 1)) + assert P(Gt(Y[1], Y[2]), Eq(Y[0], 1)) == P(Lt(Y[2], Y[1]), Eq(Y[0], 1)) + assert P(Ge(Y[7], Y[6]), Eq(Y[4], 1)) == P(Le(Y[6], Y[7]), Eq(Y[4], 1)) + + #test symbolic queries + a, b, c, d = symbols('a b c d') + T = Matrix([[Rational(1, 10), Rational(4, 10), Rational(5, 10)], [Rational(3, 10), Rational(4, 10), Rational(3, 10)], [Rational(7, 10), Rational(2, 10), Rational(1, 10)]]) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + query = P(Eq(Y[a], b), Eq(Y[c], d)) + assert query.subs({a:10, b:2, c:5, d:1}).evalf().round(4) == P(Eq(Y[10], 2), Eq(Y[5], 1)).round(4) + assert query.subs({a:15, b:0, c:10, d:1}).evalf().round(4) == P(Eq(Y[15], 0), Eq(Y[10], 1)).round(4) + query_gt = P(Gt(Y[a], b), Eq(Y[c], d)) + query_le = P(Le(Y[a], b), Eq(Y[c], d)) + assert query_gt.subs({a:5, b:2, c:1, d:0}).evalf() + query_le.subs({a:5, b:2, c:1, d:0}).evalf() == 1.0 + query_ge = P(Ge(Y[a], b), Eq(Y[c], d)) + query_lt = P(Lt(Y[a], b), Eq(Y[c], d)) + assert query_ge.subs({a:4, b:1, c:0, d:2}).evalf() + query_lt.subs({a:4, b:1, c:0, d:2}).evalf() == 1.0 + + #test issue 20078 + assert (2*Y[1] + 3*Y[1]).simplify() == 5*Y[1] + assert (2*Y[1] - 3*Y[1]).simplify() == -Y[1] + assert (2*(0.25*Y[1])).simplify() == 0.5*Y[1] + assert ((2*Y[1]) * (0.25*Y[1])).simplify() == 0.5*Y[1]**2 + assert (Y[1]**2 + Y[1]**3).simplify() == (Y[1] + 1)*Y[1]**2 + +def test_sample_stochastic_process(): + if not import_module('scipy'): + skip('SciPy Not installed. Skip sampling tests') + import random + random.seed(0) + numpy = import_module('numpy') + if numpy: + numpy.random.seed(0) # scipy uses numpy to sample so to set its seed + T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]]) + Y = DiscreteMarkovChain("Y", [0, 1, 2], T) + for samps in range(10): + assert next(sample_stochastic_process(Y)) in Y.state_space + Z = DiscreteMarkovChain("Z", ['1', 1, 0], T) + for samps in range(10): + assert next(sample_stochastic_process(Z)) in Z.state_space + + T = Matrix([[S.Half, Rational(1, 4), Rational(1, 4)], + [Rational(1, 3), 0, Rational(2, 3)], + [S.Half, S.Half, 0]]) + X = DiscreteMarkovChain('X', [0, 1, 2], T) + for samps in range(10): + assert next(sample_stochastic_process(X)) in X.state_space + W = DiscreteMarkovChain('W', [1, pi, oo], T) + for samps in range(10): + assert next(sample_stochastic_process(W)) in W.state_space + + +def test_ContinuousMarkovChain(): + T1 = Matrix([[S(-2), S(2), S.Zero], + [S.Zero, S.NegativeOne, S.One], + [Rational(3, 2), Rational(3, 2), S(-3)]]) + C1 = ContinuousMarkovChain('C', [0, 1, 2], T1) + assert C1.limiting_distribution() == ImmutableMatrix([[Rational(3, 19), Rational(12, 19), Rational(4, 19)]]) + + T2 = Matrix([[-S.One, S.One, S.Zero], [S.One, -S.One, S.Zero], [S.Zero, S.One, -S.One]]) + C2 = ContinuousMarkovChain('C', [0, 1, 2], T2) + A, t = C2.generator_matrix, symbols('t', positive=True) + assert C2.transition_probabilities(A)(t) == Matrix([[S.Half + exp(-2*t)/2, S.Half - exp(-2*t)/2, 0], + [S.Half - exp(-2*t)/2, S.Half + exp(-2*t)/2, 0], + [S.Half - exp(-t) + exp(-2*t)/2, S.Half - exp(-2*t)/2, exp(-t)]]) + with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed + assert P(Eq(C2(1), 1), Eq(C2(0), 1), evaluate=False) == Probability(Eq(C2(1), 1), Eq(C2(0), 1)) + assert P(Eq(C2(1), 1), Eq(C2(0), 1)) == exp(-2)/2 + S.Half + assert P(Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 1), + Eq(P(Eq(C2(1), 0)), S.Half)) == (Rational(1, 4) - exp(-2)/4)*(exp(-2)/2 + S.Half) + assert P(Not(Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 2)) | + (Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 2)), + Eq(P(Eq(C2(1), 0)), Rational(1, 4)) & Eq(P(Eq(C2(1), 1)), Rational(1, 4))) is S.One + assert E(C2(Rational(3, 2)), Eq(C2(0), 2)) == -exp(-3)/2 + 2*exp(Rational(-3, 2)) + S.Half + assert variance(C2(Rational(3, 2)), Eq(C2(0), 1)) == ((S.Half - exp(-3)/2)**2*(exp(-3)/2 + S.Half) + + (Rational(-1, 2) - exp(-3)/2)**2*(S.Half - exp(-3)/2)) + raises(KeyError, lambda: P(Eq(C2(1), 0), Eq(P(Eq(C2(1), 1)), S.Half))) + assert P(Eq(C2(1), 0), Eq(P(Eq(C2(5), 1)), S.Half)) == Probability(Eq(C2(1), 0)) + TS1 = MatrixSymbol('G', 3, 3) + CS1 = ContinuousMarkovChain('C', [0, 1, 2], TS1) + A = CS1.generator_matrix + assert CS1.transition_probabilities(A)(t) == exp(t*A) + + C3 = ContinuousMarkovChain('C', [Symbol('0'), Symbol('1'), Symbol('2')], T2) + assert P(Eq(C3(1), 1), Eq(C3(0), 1)) == exp(-2)/2 + S.Half + assert P(Eq(C3(1), Symbol('1')), Eq(C3(0), Symbol('1'))) == exp(-2)/2 + S.Half + + #test probability queries + G = Matrix([[-S(1), Rational(1, 10), Rational(9, 10)], [Rational(2, 5), -S(1), Rational(3, 5)], [Rational(1, 2), Rational(1, 2), -S(1)]]) + C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G) + assert P(Eq(C(7.385), C(3.19)), Eq(C(0.862), 0)).round(5) == Float(0.35469, 5) + assert P(Gt(C(98.715), C(19.807)), Eq(C(11.314), 2)).round(5) == Float(0.32452, 5) + assert P(Le(C(5.9), C(10.112)), Eq(C(4), 1)).round(6) == Float(0.675214, 6) + assert Float(P(Eq(C(7.32), C(2.91)), Eq(C(2.63), 1)), 14) == Float(1 - P(Ne(C(7.32), C(2.91)), Eq(C(2.63), 1)), 14) + assert Float(P(Gt(C(3.36), C(1.101)), Eq(C(0.8), 2)), 14) == Float(1 - P(Le(C(3.36), C(1.101)), Eq(C(0.8), 2)), 14) + assert Float(P(Lt(C(4.9), C(2.79)), Eq(C(1.61), 0)), 14) == Float(1 - P(Ge(C(4.9), C(2.79)), Eq(C(1.61), 0)), 14) + assert P(Eq(C(5.243), C(10.912)), Eq(C(2.174), 1)) == P(Eq(C(10.912), C(5.243)), Eq(C(2.174), 1)) + assert P(Gt(C(2.344), C(9.9)), Eq(C(1.102), 1)) == P(Lt(C(9.9), C(2.344)), Eq(C(1.102), 1)) + assert P(Ge(C(7.87), C(1.008)), Eq(C(0.153), 1)) == P(Le(C(1.008), C(7.87)), Eq(C(0.153), 1)) + + #test symbolic queries + a, b, c, d = symbols('a b c d') + query = P(Eq(C(a), b), Eq(C(c), d)) + assert query.subs({a:3.65, b:2, c:1.78, d:1}).evalf().round(10) == P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10) + query_gt = P(Gt(C(a), b), Eq(C(c), d)) + query_le = P(Le(C(a), b), Eq(C(c), d)) + assert query_gt.subs({a:13.2, b:0, c:3.29, d:2}).evalf() + query_le.subs({a:13.2, b:0, c:3.29, d:2}).evalf() == 1.0 + query_ge = P(Ge(C(a), b), Eq(C(c), d)) + query_lt = P(Lt(C(a), b), Eq(C(c), d)) + assert query_ge.subs({a:7.43, b:1, c:1.45, d:0}).evalf() + query_lt.subs({a:7.43, b:1, c:1.45, d:0}).evalf() == 1.0 + + #test issue 20078 + assert (2*C(1) + 3*C(1)).simplify() == 5*C(1) + assert (2*C(1) - 3*C(1)).simplify() == -C(1) + assert (2*(0.25*C(1))).simplify() == 0.5*C(1) + assert (2*C(1) * 0.25*C(1)).simplify() == 0.5*C(1)**2 + assert (C(1)**2 + C(1)**3).simplify() == (C(1) + 1)*C(1)**2 + +def test_BernoulliProcess(): + + B = BernoulliProcess("B", p=0.6, success=1, failure=0) + assert B.state_space == FiniteSet(0, 1) + assert B.index_set == S.Naturals0 + assert B.success == 1 + assert B.failure == 0 + + X = BernoulliProcess("X", p=Rational(1,3), success='H', failure='T') + assert X.state_space == FiniteSet('H', 'T') + H, T = symbols("H,T") + assert E(X[1]+X[2]*X[3]) == H**2/9 + 4*H*T/9 + H/3 + 4*T**2/9 + 2*T/3 + + t, x = symbols('t, x', positive=True, integer=True) + assert isinstance(B[t], RandomIndexedSymbol) + + raises(ValueError, lambda: BernoulliProcess("X", p=1.1, success=1, failure=0)) + raises(NotImplementedError, lambda: B(t)) + + raises(IndexError, lambda: B[-3]) + assert B.joint_distribution(B[3], B[9]) == JointDistributionHandmade(Lambda((B[3], B[9]), + Piecewise((0.6, Eq(B[3], 1)), (0.4, Eq(B[3], 0)), (0, True)) + *Piecewise((0.6, Eq(B[9], 1)), (0.4, Eq(B[9], 0)), (0, True)))) + + assert B.joint_distribution(2, B[4]) == JointDistributionHandmade(Lambda((B[2], B[4]), + Piecewise((0.6, Eq(B[2], 1)), (0.4, Eq(B[2], 0)), (0, True)) + *Piecewise((0.6, Eq(B[4], 1)), (0.4, Eq(B[4], 0)), (0, True)))) + + # Test for the sum distribution of Bernoulli Process RVs + Y = B[1] + B[2] + B[3] + assert P(Eq(Y, 0)).round(2) == Float(0.06, 1) + assert P(Eq(Y, 2)).round(2) == Float(0.43, 2) + assert P(Eq(Y, 4)).round(2) == 0 + assert P(Gt(Y, 1)).round(2) == Float(0.65, 2) + # Test for independency of each Random Indexed variable + assert P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0) & Eq(B[4], 1)).round(2) == Float(0.06, 1) + + assert E(2 * B[1] + B[2]).round(2) == Float(1.80, 3) + assert E(2 * B[1] + B[2] + 5).round(2) == Float(6.80, 3) + assert E(B[2] * B[4] + B[10]).round(2) == Float(0.96, 2) + assert E(B[2] > 0, Eq(B[1],1) & Eq(B[2],1)).round(2) == Float(0.60,2) + assert E(B[1]) == 0.6 + assert P(B[1] > 0).round(2) == Float(0.60, 2) + assert P(B[1] < 1).round(2) == Float(0.40, 2) + assert P(B[1] > 0, B[2] <= 1).round(2) == Float(0.60, 2) + assert P(B[12] * B[5] > 0).round(2) == Float(0.36, 2) + assert P(B[12] * B[5] > 0, B[4] < 1).round(2) == Float(0.36, 2) + assert P(Eq(B[2], 1), B[2] > 0) == 1.0 + assert P(Eq(B[5], 3)) == 0 + assert P(Eq(B[1], 1), B[1] < 0) == 0 + assert P(B[2] > 0, Eq(B[2], 1)) == 1 + assert P(B[2] < 0, Eq(B[2], 1)) == 0 + assert P(B[2] > 0, B[2]==7) == 0 + assert P(B[5] > 0, B[5]) == BernoulliDistribution(0.6, 0, 1) + raises(ValueError, lambda: P(3)) + raises(ValueError, lambda: P(B[3] > 0, 3)) + + # test issue 19456 + expr = Sum(B[t], (t, 0, 4)) + expr2 = Sum(B[t], (t, 1, 3)) + expr3 = Sum(B[t]**2, (t, 1, 3)) + assert expr.doit() == B[0] + B[1] + B[2] + B[3] + B[4] + assert expr2.doit() == Y + assert expr3.doit() == B[1]**2 + B[2]**2 + B[3]**2 + assert B[2*t].free_symbols == {B[2*t], t} + assert B[4].free_symbols == {B[4]} + assert B[x*t].free_symbols == {B[x*t], x, t} + + #test issue 20078 + assert (2*B[t] + 3*B[t]).simplify() == 5*B[t] + assert (2*B[t] - 3*B[t]).simplify() == -B[t] + assert (2*(0.25*B[t])).simplify() == 0.5*B[t] + assert (2*B[t] * 0.25*B[t]).simplify() == 0.5*B[t]**2 + assert (B[t]**2 + B[t]**3).simplify() == (B[t] + 1)*B[t]**2 + +def test_PoissonProcess(): + X = PoissonProcess("X", 3) + assert X.state_space == S.Naturals0 + assert X.index_set == Interval(0, oo) + assert X.lamda == 3 + + t, d, x, y = symbols('t d x y', positive=True) + assert isinstance(X(t), RandomIndexedSymbol) + assert X.distribution(t) == PoissonDistribution(3*t) + with warns_deprecated_sympy(): + X.distribution(X(t)) + raises(ValueError, lambda: PoissonProcess("X", -1)) + raises(NotImplementedError, lambda: X[t]) + raises(IndexError, lambda: X(-5)) + + assert X.joint_distribution(X(2), X(3)) == JointDistributionHandmade(Lambda((X(2), X(3)), + 6**X(2)*9**X(3)*exp(-15)/(factorial(X(2))*factorial(X(3))))) + + assert X.joint_distribution(4, 6) == JointDistributionHandmade(Lambda((X(4), X(6)), + 12**X(4)*18**X(6)*exp(-30)/(factorial(X(4))*factorial(X(6))))) + + assert P(X(t) < 1) == exp(-3*t) + assert P(Eq(X(t), 0), Contains(t, Interval.Lopen(3, 5))) == exp(-6) # exp(-2*lamda) + res = P(Eq(X(t), 1), Contains(t, Interval.Lopen(3, 4))) + assert res == 3*exp(-3) + + # Equivalent to P(Eq(X(t), 1))**4 because of non-overlapping intervals + assert P(Eq(X(t), 1) & Eq(X(d), 1) & Eq(X(x), 1) & Eq(X(y), 1), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2)) & Contains(x, Interval.Lopen(2, 3)) + & Contains(y, Interval.Lopen(3, 4))) == res**4 + + # Return Probability because of overlapping intervals + assert P(Eq(X(t), 2) & Eq(X(d), 3), Contains(t, Interval.Lopen(0, 2)) + & Contains(d, Interval.Ropen(2, 4))) == \ + Probability(Eq(X(d), 3) & Eq(X(t), 2), Contains(t, Interval.Lopen(0, 2)) + & Contains(d, Interval.Ropen(2, 4))) + + raises(ValueError, lambda: P(Eq(X(t), 2) & Eq(X(d), 3), + Contains(t, Interval.Lopen(0, 4)) & Contains(d, Interval.Lopen(3, oo)))) # no bound on d + assert P(Eq(X(3), 2)) == 81*exp(-9)/2 + assert P(Eq(X(t), 2), Contains(t, Interval.Lopen(0, 5))) == 225*exp(-15)/2 + + # Check that probability works correctly by adding it to 1 + res1 = P(X(t) <= 3, Contains(t, Interval.Lopen(0, 5))) + res2 = P(X(t) > 3, Contains(t, Interval.Lopen(0, 5))) + assert res1 == 691*exp(-15) + assert (res1 + res2).simplify() == 1 + + # Check Not and Or + assert P(Not(Eq(X(t), 2) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & \ + Contains(d, Interval.Lopen(7, 8))).simplify() == -18*exp(-6) + 234*exp(-9) + 1 + assert P(Eq(X(t), 2) | Ne(X(t), 4), Contains(t, Interval.Ropen(2, 4))) == 1 - 36*exp(-6) + raises(ValueError, lambda: P(X(t) > 2, X(t) + X(d))) + assert E(X(t)) == 3*t # property of the distribution at a given timestamp + assert E(X(t)**2 + X(d)*2 + X(y)**3, Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2)) & Contains(y, Interval.Ropen(3, 4))) == 75 + assert E(X(t)**2, Contains(t, Interval.Lopen(0, 1))) == 12 + assert E(x*(X(t) + X(d))*(X(t)**2+X(d)**2), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Ropen(1, 2))) == \ + Expectation(x*(X(d) + X(t))*(X(d)**2 + X(t)**2), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Ropen(1, 2))) + + # Value Error because of infinite time bound + raises(ValueError, lambda: E(X(t)**3, Contains(t, Interval.Lopen(1, oo)))) + + # Equivalent to E(X(t)**2) - E(X(d)**2) == E(X(1)**2) - E(X(1)**2) == 0 + assert E((X(t) + X(d))*(X(t) - X(d)), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2))) == 0 + assert E(X(2) + x*E(X(5))) == 15*x + 6 + assert E(x*X(1) + y) == 3*x + y + assert P(Eq(X(1), 2) & Eq(X(t), 3), Contains(t, Interval.Lopen(1, 2))) == 81*exp(-6)/4 + Y = PoissonProcess("Y", 6) + Z = X + Y + assert Z.lamda == X.lamda + Y.lamda == 9 + raises(ValueError, lambda: X + 5) # should be added be only PoissonProcess instance + N, M = Z.split(4, 5) + assert N.lamda == 4 + assert M.lamda == 5 + raises(ValueError, lambda: Z.split(3, 2)) # 2+3 != 9 + + raises(ValueError, lambda :P(Eq(X(t), 0), Contains(t, Interval.Lopen(1, 3)) & Eq(X(1), 0))) + # check if it handles queries with two random variables in one args + res1 = P(Eq(N(3), N(5))) + assert res1 == P(Eq(N(t), 0), Contains(t, Interval(3, 5))) + res2 = P(N(3) > N(1)) + assert res2 == P((N(t) > 0), Contains(t, Interval(1, 3))) + assert P(N(3) < N(1)) == 0 # condition is not possible + res3 = P(N(3) <= N(1)) # holds only for Eq(N(3), N(1)) + assert res3 == P(Eq(N(t), 0), Contains(t, Interval(1, 3))) + + # tests from https://www.probabilitycourse.com/chapter11/11_1_2_basic_concepts_of_the_poisson_process.php + X = PoissonProcess('X', 10) # 11.1 + assert P(Eq(X(S(1)/3), 3) & Eq(X(1), 10)) == exp(-10)*Rational(8000000000, 11160261) + assert P(Eq(X(1), 1), Eq(X(S(1)/3), 3)) == 0 + assert P(Eq(X(1), 10), Eq(X(S(1)/3), 3)) == P(Eq(X(S(2)/3), 7)) + + X = PoissonProcess('X', 2) # 11.2 + assert P(X(S(1)/2) < 1) == exp(-1) + assert P(X(3) < 1, Eq(X(1), 0)) == exp(-4) + assert P(Eq(X(4), 3), Eq(X(2), 3)) == exp(-4) + + X = PoissonProcess('X', 3) + assert P(Eq(X(2), 5) & Eq(X(1), 2)) == Rational(81, 4)*exp(-6) + + # check few properties + assert P(X(2) <= 3, X(1)>=1) == 3*P(Eq(X(1), 0)) + 2*P(Eq(X(1), 1)) + P(Eq(X(1), 2)) + assert P(X(2) <= 3, X(1) > 1) == 2*P(Eq(X(1), 0)) + 1*P(Eq(X(1), 1)) + assert P(Eq(X(2), 5) & Eq(X(1), 2)) == P(Eq(X(1), 3))*P(Eq(X(1), 2)) + assert P(Eq(X(3), 4), Eq(X(1), 3)) == P(Eq(X(2), 1)) + + #test issue 20078 + assert (2*X(t) + 3*X(t)).simplify() == 5*X(t) + assert (2*X(t) - 3*X(t)).simplify() == -X(t) + assert (2*(0.25*X(t))).simplify() == 0.5*X(t) + assert (2*X(t) * 0.25*X(t)).simplify() == 0.5*X(t)**2 + assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1)*X(t)**2 + +def test_WienerProcess(): + X = WienerProcess("X") + assert X.state_space == S.Reals + assert X.index_set == Interval(0, oo) + + t, d, x, y = symbols('t d x y', positive=True) + assert isinstance(X(t), RandomIndexedSymbol) + assert X.distribution(t) == NormalDistribution(0, sqrt(t)) + with warns_deprecated_sympy(): + X.distribution(X(t)) + raises(ValueError, lambda: PoissonProcess("X", -1)) + raises(NotImplementedError, lambda: X[t]) + raises(IndexError, lambda: X(-2)) + + assert X.joint_distribution(X(2), X(3)) == JointDistributionHandmade( + Lambda((X(2), X(3)), sqrt(6)*exp(-X(2)**2/4)*exp(-X(3)**2/6)/(12*pi))) + assert X.joint_distribution(4, 6) == JointDistributionHandmade( + Lambda((X(4), X(6)), sqrt(6)*exp(-X(4)**2/8)*exp(-X(6)**2/12)/(24*pi))) + + assert P(X(t) < 3).simplify() == erf(3*sqrt(2)/(2*sqrt(t)))/2 + S(1)/2 + assert P(X(t) > 2, Contains(t, Interval.Lopen(3, 7))).simplify() == S(1)/2 -\ + erf(sqrt(2)/2)/2 + + # Equivalent to P(X(1)>1)**4 + assert P((X(t) > 4) & (X(d) > 3) & (X(x) > 2) & (X(y) > 1), + Contains(t, Interval.Lopen(0, 1)) & Contains(d, Interval.Lopen(1, 2)) + & Contains(x, Interval.Lopen(2, 3)) & Contains(y, Interval.Lopen(3, 4))).simplify() ==\ + (1 - erf(sqrt(2)/2))*(1 - erf(sqrt(2)))*(1 - erf(3*sqrt(2)/2))*(1 - erf(2*sqrt(2)))/16 + + # Contains an overlapping interval so, return Probability + assert P((X(t)< 2) & (X(d)> 3), Contains(t, Interval.Lopen(0, 2)) + & Contains(d, Interval.Ropen(2, 4))) == Probability((X(d) > 3) & (X(t) < 2), + Contains(d, Interval.Ropen(2, 4)) & Contains(t, Interval.Lopen(0, 2))) + + assert str(P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & + Contains(d, Interval.Lopen(7, 8))).simplify()) == \ + '-(1 - erf(3*sqrt(2)/2))*(2 - erfc(5/2))/4 + 1' + # Distribution has mean 0 at each timestamp + assert E(X(t)) == 0 + assert E(x*(X(t) + X(d))*(X(t)**2+X(d)**2), Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Ropen(1, 2))) == Expectation(x*(X(d) + X(t))*(X(d)**2 + X(t)**2), + Contains(d, Interval.Ropen(1, 2)) & Contains(t, Interval.Lopen(0, 1))) + assert E(X(t) + x*E(X(3))) == 0 + + #test issue 20078 + assert (2*X(t) + 3*X(t)).simplify() == 5*X(t) + assert (2*X(t) - 3*X(t)).simplify() == -X(t) + assert (2*(0.25*X(t))).simplify() == 0.5*X(t) + assert (2*X(t) * 0.25*X(t)).simplify() == 0.5*X(t)**2 + assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1)*X(t)**2 + + +def test_GammaProcess_symbolic(): + t, d, x, y, g, l = symbols('t d x y g l', positive=True) + X = GammaProcess("X", l, g) + + raises(NotImplementedError, lambda: X[t]) + raises(IndexError, lambda: X(-1)) + assert isinstance(X(t), RandomIndexedSymbol) + assert X.state_space == Interval(0, oo) + assert X.distribution(t) == GammaDistribution(g*t, 1/l) + with warns_deprecated_sympy(): + X.distribution(X(t)) + assert X.joint_distribution(5, X(3)) == JointDistributionHandmade(Lambda( + (X(5), X(3)), l**(8*g)*exp(-l*X(3))*exp(-l*X(5))*X(3)**(3*g - 1)*X(5)**(5*g + - 1)/(gamma(3*g)*gamma(5*g)))) + # property of the gamma process at any given timestamp + assert E(X(t)) == g*t/l + assert variance(X(t)).simplify() == g*t/l**2 + + # Equivalent to E(2*X(1)) + E(X(1)**2) + E(X(1)**3), where E(X(1)) == g/l + assert E(X(t)**2 + X(d)*2 + X(y)**3, Contains(t, Interval.Lopen(0, 1)) + & Contains(d, Interval.Lopen(1, 2)) & Contains(y, Interval.Ropen(3, 4))) == \ + 2*g/l + (g**2 + g)/l**2 + (g**3 + 3*g**2 + 2*g)/l**3 + + assert P(X(t) > 3, Contains(t, Interval.Lopen(3, 4))).simplify() == \ + 1 - lowergamma(g, 3*l)/gamma(g) # equivalent to P(X(1)>3) + + + #test issue 20078 + assert (2*X(t) + 3*X(t)).simplify() == 5*X(t) + assert (2*X(t) - 3*X(t)).simplify() == -X(t) + assert (2*(0.25*X(t))).simplify() == 0.5*X(t) + assert (2*X(t) * 0.25*X(t)).simplify() == 0.5*X(t)**2 + assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1)*X(t)**2 +def test_GammaProcess_numeric(): + t, d, x, y = symbols('t d x y', positive=True) + X = GammaProcess("X", 1, 2) + assert X.state_space == Interval(0, oo) + assert X.index_set == Interval(0, oo) + assert X.lamda == 1 + assert X.gamma == 2 + + raises(ValueError, lambda: GammaProcess("X", -1, 2)) + raises(ValueError, lambda: GammaProcess("X", 0, -2)) + raises(ValueError, lambda: GammaProcess("X", -1, -2)) + + # all are independent because of non-overlapping intervals + assert P((X(t) > 4) & (X(d) > 3) & (X(x) > 2) & (X(y) > 1), Contains(t, + Interval.Lopen(0, 1)) & Contains(d, Interval.Lopen(1, 2)) & Contains(x, + Interval.Lopen(2, 3)) & Contains(y, Interval.Lopen(3, 4))).simplify() == \ + 120*exp(-10) + + # Check working with Not and Or + assert P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) & + Contains(d, Interval.Lopen(7, 8))).simplify() == -4*exp(-3) + 472*exp(-8)/3 + 1 + assert P((X(t) > 2) | (X(t) < 4), Contains(t, Interval.Ropen(1, 4))).simplify() == \ + -643*exp(-4)/15 + 109*exp(-2)/15 + 1 + + assert E(X(t)) == 2*t # E(X(t)) == gamma*t/l + assert E(X(2) + x*E(X(5))) == 10*x + 4 diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_symbolic_multivariate.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_symbolic_multivariate.py new file mode 100644 index 0000000000000000000000000000000000000000..79979e20a6f10d2a2ddfe85ce4c8df145e98c3fd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_symbolic_multivariate.py @@ -0,0 +1,172 @@ +from sympy.stats import Expectation, Normal, Variance, Covariance +from sympy.testing.pytest import raises +from sympy.core.symbol import symbols +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.stats.rv import RandomMatrixSymbol +from sympy.stats.symbolic_multivariate_probability import (ExpectationMatrix, + VarianceMatrix, CrossCovarianceMatrix) + +j, k = symbols("j,k") + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) + +A2 = MatrixSymbol("A2", 2, 2) +B2 = MatrixSymbol("B2", 2, 2) + +X = RandomMatrixSymbol("X", k, 1) +Y = RandomMatrixSymbol("Y", k, 1) +Z = RandomMatrixSymbol("Z", k, 1) +W = RandomMatrixSymbol("W", k, 1) + +R = RandomMatrixSymbol("R", k, k) + +X2 = RandomMatrixSymbol("X2", 2, 1) + +normal = Normal("normal", 0, 1) + +m1 = Matrix([ + [1, j*Normal("normal2", 2, 1)], + [normal, 0] +]) + +def test_multivariate_expectation(): + expr = Expectation(a) + assert expr == Expectation(a) == ExpectationMatrix(a) + assert expr.expand() == a + + expr = Expectation(X) + assert expr == Expectation(X) == ExpectationMatrix(X) + assert expr.shape == (k, 1) + assert expr.rows == k + assert expr.cols == 1 + assert isinstance(expr, ExpectationMatrix) + + expr = Expectation(A*X + b) + assert expr == ExpectationMatrix(A*X + b) + assert expr.expand() == A*ExpectationMatrix(X) + b + assert isinstance(expr, ExpectationMatrix) + assert expr.shape == (k, 1) + + expr = Expectation(m1*X2) + assert expr.expand() == expr + + expr = Expectation(A2*m1*B2*X2) + assert expr.args[0].args == (A2, m1, B2, X2) + assert expr.expand() == A2*ExpectationMatrix(m1*B2*X2) + + expr = Expectation((X + Y)*(X - Y).T) + assert expr.expand() == ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) +\ + ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T) + + expr = Expectation(A*X + B*Y) + assert expr.expand() == A*ExpectationMatrix(X) + B*ExpectationMatrix(Y) + + assert Expectation(m1).doit() == Matrix([[1, 2*j], [0, 0]]) + + x1 = Matrix([ + [Normal('N11', 11, 1), Normal('N12', 12, 1)], + [Normal('N21', 21, 1), Normal('N22', 22, 1)] + ]) + x2 = Matrix([ + [Normal('M11', 1, 1), Normal('M12', 2, 1)], + [Normal('M21', 3, 1), Normal('M22', 4, 1)] + ]) + + assert Expectation(Expectation(x1 + x2)).doit(deep=False) == ExpectationMatrix(x1 + x2) + assert Expectation(Expectation(x1 + x2)).doit() == Matrix([[12, 14], [24, 26]]) + + +def test_multivariate_variance(): + raises(ShapeError, lambda: Variance(A)) + + expr = Variance(a) + assert expr == Variance(a) == VarianceMatrix(a) + assert expr.expand() == ZeroMatrix(k, k) + expr = Variance(a.T) + assert expr == Variance(a.T) == VarianceMatrix(a.T) + assert expr.expand() == ZeroMatrix(k, k) + + expr = Variance(X) + assert expr == Variance(X) == VarianceMatrix(X) + assert expr.shape == (k, k) + assert expr.rows == k + assert expr.cols == k + assert isinstance(expr, VarianceMatrix) + + expr = Variance(A*X) + assert expr == VarianceMatrix(A*X) + assert expr.expand() == A*VarianceMatrix(X)*A.T + assert isinstance(expr, VarianceMatrix) + assert expr.shape == (k, k) + + expr = Variance(A*B*X) + assert expr.expand() == A*B*VarianceMatrix(X)*B.T*A.T + + expr = Variance(m1*X2) + assert expr.expand() == expr + + expr = Variance(A2*m1*B2*X2) + assert expr.args[0].args == (A2, m1, B2, X2) + assert expr.expand() == expr + + expr = Variance(A*X + B*Y) + assert expr.expand() == 2*A*CrossCovarianceMatrix(X, Y)*B.T +\ + A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T + +def test_multivariate_crosscovariance(): + raises(ShapeError, lambda: Covariance(X, Y.T)) + raises(ShapeError, lambda: Covariance(X, A)) + + + expr = Covariance(a.T, b.T) + assert expr.shape == (1, 1) + assert expr.expand() == ZeroMatrix(1, 1) + + expr = Covariance(a, b) + assert expr == Covariance(a, b) == CrossCovarianceMatrix(a, b) + assert expr.expand() == ZeroMatrix(k, k) + assert expr.shape == (k, k) + assert expr.rows == k + assert expr.cols == k + assert isinstance(expr, CrossCovarianceMatrix) + + expr = Covariance(A*X + a, b) + assert expr.expand() == ZeroMatrix(k, k) + + expr = Covariance(X, Y) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == expr + + expr = Covariance(X, X) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == VarianceMatrix(X) + + expr = Covariance(X + Y, Z) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z) + + expr = Covariance(A*X, Y) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == A*CrossCovarianceMatrix(X, Y) + + expr = Covariance(X, B*Y) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == CrossCovarianceMatrix(X, Y)*B.T + + expr = Covariance(A*X + a, B.T*Y + b) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == A*CrossCovarianceMatrix(X, Y)*B + + expr = Covariance(A*X + B*Y + a, C.T*Z + D.T*W + b) + assert isinstance(expr, CrossCovarianceMatrix) + assert expr.expand() == A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C \ + + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C diff --git a/phivenv/Lib/site-packages/sympy/stats/tests/test_symbolic_probability.py b/phivenv/Lib/site-packages/sympy/stats/tests/test_symbolic_probability.py new file mode 100644 index 0000000000000000000000000000000000000000..edac942ac081c0d44cafd31761b77bc577b6a3fd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/stats/tests/test_symbolic_probability.py @@ -0,0 +1,175 @@ +from sympy.concrete.summations import Sum +from sympy.core.mul import Mul +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.core.expr import unchanged +from sympy.stats import (Normal, Poisson, variance, Covariance, Variance, + Probability, Expectation, Moment, CentralMoment) +from sympy.stats.rv import probability, expectation + + +def test_literal_probability(): + X = Normal('X', 2, 3) + Y = Normal('Y', 3, 4) + Z = Poisson('Z', 4) + W = Poisson('W', 3) + x = symbols('x', real=True) + y, w, z = symbols('y, w, z') + + assert Probability(X > 0).evaluate_integral() == probability(X > 0) + assert Probability(X > x).evaluate_integral() == probability(X > x) + assert Probability(X > 0).rewrite(Integral).doit() == probability(X > 0) + assert Probability(X > x).rewrite(Integral).doit() == probability(X > x) + + assert Expectation(X).evaluate_integral() == expectation(X) + assert Expectation(X).rewrite(Integral).doit() == expectation(X) + assert Expectation(X**2).evaluate_integral() == expectation(X**2) + assert Expectation(x*X).args == (x*X,) + assert Expectation(x*X).expand() == x*Expectation(X) + assert Expectation(2*X + 3*Y + z*X*Y).expand() == 2*Expectation(X) + 3*Expectation(Y) + z*Expectation(X*Y) + assert Expectation(2*X + 3*Y + z*X*Y).args == (2*X + 3*Y + z*X*Y,) + assert Expectation(sin(X)) == Expectation(sin(X)).expand() + assert Expectation(2*x*sin(X)*Y + y*X**2 + z*X*Y).expand() == 2*x*Expectation(sin(X)*Y) \ + + y*Expectation(X**2) + z*Expectation(X*Y) + assert Expectation(X + Y).expand() == Expectation(X) + Expectation(Y) + assert Expectation((X + Y)*(X - Y)).expand() == Expectation(X**2) - Expectation(Y**2) + assert Expectation((X + Y)*(X - Y)).expand().doit() == -12 + assert Expectation(X + Y, evaluate=True).doit() == 5 + assert Expectation(X + Expectation(Y)).doit() == 5 + assert Expectation(X + Expectation(Y)).doit(deep=False) == 2 + Expectation(Expectation(Y)) + assert Expectation(X + Expectation(Y + Expectation(2*X))).doit(deep=False) == 2 \ + + Expectation(Expectation(Y + Expectation(2*X))) + assert Expectation(X + Expectation(Y + Expectation(2*X))).doit() == 9 + assert Expectation(Expectation(2*X)).doit() == 4 + assert Expectation(Expectation(2*X)).doit(deep=False) == Expectation(2*X) + assert Expectation(4*Expectation(2*X)).doit(deep=False) == 4*Expectation(2*X) + assert Expectation((X + Y)**3).expand() == 3*Expectation(X*Y**2) +\ + 3*Expectation(X**2*Y) + Expectation(X**3) + Expectation(Y**3) + assert Expectation((X - Y)**3).expand() == 3*Expectation(X*Y**2) -\ + 3*Expectation(X**2*Y) + Expectation(X**3) - Expectation(Y**3) + assert Expectation((X - Y)**2).expand() == -2*Expectation(X*Y) +\ + Expectation(X**2) + Expectation(Y**2) + + assert Variance(w).args == (w,) + assert Variance(w).expand() == 0 + assert Variance(X).evaluate_integral() == Variance(X).rewrite(Integral).doit() == variance(X) + assert Variance(X + z).args == (X + z,) + assert Variance(X + z).expand() == Variance(X) + assert Variance(X*Y).args == (Mul(X, Y),) + assert type(Variance(X*Y)) == Variance + assert Variance(z*X).expand() == z**2*Variance(X) + assert Variance(X + Y).expand() == Variance(X) + Variance(Y) + 2*Covariance(X, Y) + assert Variance(X + Y + Z + W).expand() == (Variance(X) + Variance(Y) + Variance(Z) + Variance(W) + + 2 * Covariance(X, Y) + 2 * Covariance(X, Z) + 2 * Covariance(X, W) + + 2 * Covariance(Y, Z) + 2 * Covariance(Y, W) + 2 * Covariance(W, Z)) + assert Variance(X**2).evaluate_integral() == variance(X**2) + assert unchanged(Variance, X**2) + assert Variance(x*X**2).expand() == x**2*Variance(X**2) + assert Variance(sin(X)).args == (sin(X),) + assert Variance(sin(X)).expand() == Variance(sin(X)) + assert Variance(x*sin(X)).expand() == x**2*Variance(sin(X)) + + assert Covariance(w, z).args == (w, z) + assert Covariance(w, z).expand() == 0 + assert Covariance(X, w).expand() == 0 + assert Covariance(w, X).expand() == 0 + assert Covariance(X, Y).args == (X, Y) + assert type(Covariance(X, Y)) == Covariance + assert Covariance(z*X + 3, Y).expand() == z*Covariance(X, Y) + assert Covariance(X, X).args == (X, X) + assert Covariance(X, X).expand() == Variance(X) + assert Covariance(z*X + 3, w*Y + 4).expand() == w*z*Covariance(X,Y) + assert Covariance(X, Y) == Covariance(Y, X) + assert Covariance(X + Y, Z + W).expand() == Covariance(W, X) + Covariance(W, Y) + Covariance(X, Z) + Covariance(Y, Z) + assert Covariance(x*X + y*Y, z*Z + w*W).expand() == (x*w*Covariance(W, X) + w*y*Covariance(W, Y) + + x*z*Covariance(X, Z) + y*z*Covariance(Y, Z)) + assert Covariance(x*X**2 + y*sin(Y), z*Y*Z**2 + w*W).expand() == (w*x*Covariance(W, X**2) + w*y*Covariance(sin(Y), W) + + x*z*Covariance(Y*Z**2, X**2) + y*z*Covariance(Y*Z**2, sin(Y))) + assert Covariance(X, X**2).expand() == Covariance(X, X**2) + assert Covariance(X, sin(X)).expand() == Covariance(sin(X), X) + assert Covariance(X**2, sin(X)*Y).expand() == Covariance(sin(X)*Y, X**2) + assert Covariance(w, X).evaluate_integral() == 0 + + +def test_probability_rewrite(): + X = Normal('X', 2, 3) + Y = Normal('Y', 3, 4) + Z = Poisson('Z', 4) + W = Poisson('W', 3) + x, y, w, z = symbols('x, y, w, z') + + assert Variance(w).rewrite(Expectation) == 0 + assert Variance(X).rewrite(Expectation) == Expectation(X ** 2) - Expectation(X) ** 2 + assert Variance(X, condition=Y).rewrite(Expectation) == Expectation(X ** 2, Y) - Expectation(X, Y) ** 2 + assert Variance(X, Y) != Expectation(X**2) - Expectation(X)**2 + assert Variance(X + z).rewrite(Expectation) == Expectation((X + z) ** 2) - Expectation(X + z) ** 2 + assert Variance(X * Y).rewrite(Expectation) == Expectation(X ** 2 * Y ** 2) - Expectation(X * Y) ** 2 + + assert Covariance(w, X).rewrite(Expectation) == -w*Expectation(X) + Expectation(w*X) + assert Covariance(X, Y).rewrite(Expectation) == Expectation(X*Y) - Expectation(X)*Expectation(Y) + assert Covariance(X, Y, condition=W).rewrite(Expectation) == Expectation(X * Y, W) - Expectation(X, W) * Expectation(Y, W) + + w, x, z = symbols("W, x, z") + px = Probability(Eq(X, x)) + pz = Probability(Eq(Z, z)) + + assert Expectation(X).rewrite(Probability) == Integral(x*px, (x, -oo, oo)) + assert Expectation(Z).rewrite(Probability) == Sum(z*pz, (z, 0, oo)) + assert Variance(X).rewrite(Probability) == Integral(x**2*px, (x, -oo, oo)) - Integral(x*px, (x, -oo, oo))**2 + assert Variance(Z).rewrite(Probability) == Sum(z**2*pz, (z, 0, oo)) - Sum(z*pz, (z, 0, oo))**2 + assert Covariance(w, X).rewrite(Probability) == \ + -w*Integral(x*Probability(Eq(X, x)), (x, -oo, oo)) + Integral(w*x*Probability(Eq(X, x)), (x, -oo, oo)) + + # To test rewrite as sum function + assert Variance(X).rewrite(Sum) == Variance(X).rewrite(Integral) + assert Expectation(X).rewrite(Sum) == Expectation(X).rewrite(Integral) + + assert Covariance(w, X).rewrite(Sum) == 0 + + assert Covariance(w, X).rewrite(Integral) == 0 + + assert Variance(X, condition=Y).rewrite(Probability) == Integral(x**2*Probability(Eq(X, x), Y), (x, -oo, oo)) - \ + Integral(x*Probability(Eq(X, x), Y), (x, -oo, oo))**2 + + +def test_symbolic_Moment(): + mu = symbols('mu', real=True) + sigma = symbols('sigma', positive=True) + x = symbols('x') + X = Normal('X', mu, sigma) + M = Moment(X, 4, 2) + assert M.rewrite(Expectation) == Expectation((X - 2)**4) + assert M.rewrite(Probability) == Integral((x - 2)**4*Probability(Eq(X, x)), + (x, -oo, oo)) + k = Dummy('k') + expri = Integral(sqrt(2)*(k - 2)**4*exp(-(k - \ + mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (k, -oo, oo)) + assert M.rewrite(Integral).dummy_eq(expri) + assert M.doit() == (mu**4 - 8*mu**3 + 6*mu**2*sigma**2 + \ + 24*mu**2 - 24*mu*sigma**2 - 32*mu + 3*sigma**4 + 24*sigma**2 + 16) + M = Moment(2, 5) + assert M.doit() == 2**5 + + +def test_symbolic_CentralMoment(): + mu = symbols('mu', real=True) + sigma = symbols('sigma', positive=True) + x = symbols('x') + X = Normal('X', mu, sigma) + CM = CentralMoment(X, 6) + assert CM.rewrite(Expectation) == Expectation((X - Expectation(X))**6) + assert CM.rewrite(Probability) == Integral((x - Integral(x*Probability(True), + (x, -oo, oo)))**6*Probability(Eq(X, x)), (x, -oo, oo)) + k = Dummy('k') + expri = Integral(sqrt(2)*(k - Integral(sqrt(2)*k*exp(-(k - \ + mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (k, -oo, oo)))**6*exp(-(k - \ + mu)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), (k, -oo, oo)) + assert CM.rewrite(Integral).dummy_eq(expri) + assert CM.doit().simplify() == 15*sigma**6 + CM = Moment(5, 5) + assert CM.doit() == 5**5 diff --git a/phivenv/Lib/site-packages/sympy/strategies/__init__.py b/phivenv/Lib/site-packages/sympy/strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4c5aa8afe6fd818e136ec0797b7429e2da76cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/__init__.py @@ -0,0 +1,50 @@ +""" Rewrite Rules + +DISCLAIMER: This module is experimental. The interface is subject to change. + +A rule is a function that transforms one expression into another + + Rule :: Expr -> Expr + +A strategy is a function that says how a rule should be applied to a syntax +tree. In general strategies take rules and produce a new rule + + Strategy :: [Rules], Other-stuff -> Rule + +This allows developers to separate a mathematical transformation from the +algorithmic details of applying that transformation. The goal is to separate +the work of mathematical programming from algorithmic programming. + +Submodules + +strategies.rl - some fundamental rules +strategies.core - generic non-SymPy specific strategies +strategies.traverse - strategies that traverse a SymPy tree +strategies.tools - some conglomerate strategies that do depend on SymPy +""" + +from . import rl +from . import traverse +from .rl import rm_id, unpack, flatten, sort, glom, distribute, rebuild +from .util import new +from .core import ( + condition, debug, chain, null_safe, do_one, exhaust, minimize, tryit) +from .tools import canon, typed +from . import branch + +__all__ = [ + 'rl', + + 'traverse', + + 'rm_id', 'unpack', 'flatten', 'sort', 'glom', 'distribute', 'rebuild', + + 'new', + + 'condition', 'debug', 'chain', 'null_safe', 'do_one', 'exhaust', + 'minimize', 'tryit', + + 'canon', 'typed', + + 'branch', +] diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..162a91a30a6a5ab1ab4a8635b4d0758a1c43c58a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be5367332dd3da6e4d82bc225da036faae694afb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/rl.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/rl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a0609c3fc9ececc6505e2ea1c2ff8b33a2932e4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/rl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/tools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4123cf9d83abe8ab23547629dcfbc1d7e3911e82 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/tools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/traverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/traverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49004c5ce39599dd6706af76cf1e031e6d05864c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/traverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/tree.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/tree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15416a3c54bdd5f6ad4a22f9fe6d7df45c8e23b1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/tree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/__pycache__/util.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f760a465b05a5686000128c7e49d8ee31ee2a6a2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/__pycache__/util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/__init__.py b/phivenv/Lib/site-packages/sympy/strategies/branch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fec5afe84a58f3d887a8c762692a3673a2b6d4c8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/__init__.py @@ -0,0 +1,14 @@ +from . import traverse +from .core import ( + condition, debug, multiplex, exhaust, notempty, + chain, onaction, sfilter, yieldify, do_one, identity) +from .tools import canon + +__all__ = [ + 'traverse', + + 'condition', 'debug', 'multiplex', 'exhaust', 'notempty', 'chain', + 'onaction', 'sfilter', 'yieldify', 'do_one', 'identity', + + 'canon', +] diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f5b0762f9a89c64da58589ef4d5d1deabf19e70 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e1acb85b9ce68f40ff38c9868b9ef17d202aa1a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/tools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9587718336860ef367b003bae8e05c96a3cca5e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/tools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/traverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/traverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1dab0549be356d0b6fd0277c3fd608c05a45e29 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/__pycache__/traverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/core.py b/phivenv/Lib/site-packages/sympy/strategies/branch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..2dabaef69b60d994799f71414699223f84e1809b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/core.py @@ -0,0 +1,116 @@ +""" Generic SymPy-Independent Strategies """ + + +def identity(x): + yield x + + +def exhaust(brule): + """ Apply a branching rule repeatedly until it has no effect """ + def exhaust_brl(expr): + seen = {expr} + for nexpr in brule(expr): + if nexpr not in seen: + seen.add(nexpr) + yield from exhaust_brl(nexpr) + if seen == {expr}: + yield expr + return exhaust_brl + + +def onaction(brule, fn): + def onaction_brl(expr): + for result in brule(expr): + if result != expr: + fn(brule, expr, result) + yield result + return onaction_brl + + +def debug(brule, file=None): + """ Print the input and output expressions at each rule application """ + if not file: + from sys import stdout + file = stdout + + def write(brl, expr, result): + file.write("Rule: %s\n" % brl.__name__) + file.write("In: %s\nOut: %s\n\n" % (expr, result)) + + return onaction(brule, write) + + +def multiplex(*brules): + """ Multiplex many branching rules into one """ + def multiplex_brl(expr): + seen = set() + for brl in brules: + for nexpr in brl(expr): + if nexpr not in seen: + seen.add(nexpr) + yield nexpr + return multiplex_brl + + +def condition(cond, brule): + """ Only apply branching rule if condition is true """ + def conditioned_brl(expr): + if cond(expr): + yield from brule(expr) + else: + pass + return conditioned_brl + + +def sfilter(pred, brule): + """ Yield only those results which satisfy the predicate """ + def filtered_brl(expr): + yield from filter(pred, brule(expr)) + return filtered_brl + + +def notempty(brule): + def notempty_brl(expr): + yielded = False + for nexpr in brule(expr): + yielded = True + yield nexpr + if not yielded: + yield expr + return notempty_brl + + +def do_one(*brules): + """ Execute one of the branching rules """ + def do_one_brl(expr): + yielded = False + for brl in brules: + for nexpr in brl(expr): + yielded = True + yield nexpr + if yielded: + return + return do_one_brl + + +def chain(*brules): + """ + Compose a sequence of brules so that they apply to the expr sequentially + """ + def chain_brl(expr): + if not brules: + yield expr + return + + head, tail = brules[0], brules[1:] + for nexpr in head(expr): + yield from chain(*tail)(nexpr) + + return chain_brl + + +def yieldify(rl): + """ Turn a rule into a branching rule """ + def brl(expr): + yield rl(expr) + return brl diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__init__.py b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de515204fab3e5477b1da7bfb3554773aa5bd56e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..019ea9b3a7b70bc52951281e16fd4f14eaf0bfe4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_tools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..420c99f10e012f78dedfb675c5d85a9b24c13e3b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_tools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_traverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_traverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbda38de775cb1e046f813ce220ef151d64993c2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/__pycache__/test_traverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_core.py b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..ac620b0afb6dbadc4d97b29ddbb341cd920b6588 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_core.py @@ -0,0 +1,117 @@ +from sympy.strategies.branch.core import ( + exhaust, debug, multiplex, condition, notempty, chain, onaction, sfilter, + yieldify, do_one, identity) + + +def posdec(x): + if x > 0: + yield x - 1 + else: + yield x + + +def branch5(x): + if 0 < x < 5: + yield x - 1 + elif 5 < x < 10: + yield x + 1 + elif x == 5: + yield x + 1 + yield x - 1 + else: + yield x + + +def even(x): + return x % 2 == 0 + + +def inc(x): + yield x + 1 + + +def one_to_n(n): + yield from range(n) + + +def test_exhaust(): + brl = exhaust(branch5) + assert set(brl(3)) == {0} + assert set(brl(7)) == {10} + assert set(brl(5)) == {0, 10} + + +def test_debug(): + from io import StringIO + file = StringIO() + rl = debug(posdec, file) + list(rl(5)) + log = file.getvalue() + file.close() + + assert posdec.__name__ in log + assert '5' in log + assert '4' in log + + +def test_multiplex(): + brl = multiplex(posdec, branch5) + assert set(brl(3)) == {2} + assert set(brl(7)) == {6, 8} + assert set(brl(5)) == {4, 6} + + +def test_condition(): + brl = condition(even, branch5) + assert set(brl(4)) == set(branch5(4)) + assert set(brl(5)) == set() + + +def test_sfilter(): + brl = sfilter(even, one_to_n) + assert set(brl(10)) == {0, 2, 4, 6, 8} + + +def test_notempty(): + def ident_if_even(x): + if even(x): + yield x + + brl = notempty(ident_if_even) + assert set(brl(4)) == {4} + assert set(brl(5)) == {5} + + +def test_chain(): + assert list(chain()(2)) == [2] # identity + assert list(chain(inc, inc)(2)) == [4] + assert list(chain(branch5, inc)(4)) == [4] + assert set(chain(branch5, inc)(5)) == {5, 7} + assert list(chain(inc, branch5)(5)) == [7] + + +def test_onaction(): + L = [] + + def record(fn, input, output): + L.append((input, output)) + + list(onaction(inc, record)(2)) + assert L == [(2, 3)] + + list(onaction(identity, record)(2)) + assert L == [(2, 3)] + + +def test_yieldify(): + yinc = yieldify(lambda x: x + 1) + assert list(yinc(3)) == [4] + + +def test_do_one(): + def bad(expr): + raise ValueError + + assert list(do_one(inc)(3)) == [4] + assert list(do_one(inc, bad)(3)) == [4] + assert list(do_one(inc, posdec)(3)) == [4] diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_tools.py b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..c2bd224030c337f0a000d94f6e7e65f3b8bd118f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_tools.py @@ -0,0 +1,42 @@ +from sympy.strategies.branch.tools import canon +from sympy.core.basic import Basic +from sympy.core.numbers import Integer +from sympy.core.singleton import S + + +def posdec(x): + if isinstance(x, Integer) and x > 0: + yield x - 1 + else: + yield x + + +def branch5(x): + if isinstance(x, Integer): + if 0 < x < 5: + yield x - 1 + elif 5 < x < 10: + yield x + 1 + elif x == 5: + yield x + 1 + yield x - 1 + else: + yield x + + +def test_zero_ints(): + expr = Basic(S(2), Basic(S(5), S(3)), S(8)) + expected = {Basic(S(0), Basic(S(0), S(0)), S(0))} + + brl = canon(posdec) + assert set(brl(expr)) == expected + + +def test_split5(): + expr = Basic(S(2), Basic(S(5), S(3)), S(8)) + expected = { + Basic(S(0), Basic(S(0), S(0)), S(10)), + Basic(S(0), Basic(S(10), S(0)), S(10))} + + brl = canon(branch5) + assert set(brl(expr)) == expected diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_traverse.py b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..e051928210981223004de28b8c617d0438e11ac6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/tests/test_traverse.py @@ -0,0 +1,53 @@ +from sympy.core.basic import Basic +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.strategies.branch.traverse import top_down, sall +from sympy.strategies.branch.core import do_one, identity + + +def inc(x): + if isinstance(x, Integer): + yield x + 1 + + +def test_top_down_easy(): + expr = Basic(S(1), S(2)) + expected = Basic(S(2), S(3)) + brl = top_down(inc) + + assert set(brl(expr)) == {expected} + + +def test_top_down_big_tree(): + expr = Basic(S(1), Basic(S(2)), Basic(S(3), Basic(S(4)), S(5))) + expected = Basic(S(2), Basic(S(3)), Basic(S(4), Basic(S(5)), S(6))) + brl = top_down(inc) + + assert set(brl(expr)) == {expected} + + +def test_top_down_harder_function(): + def split5(x): + if x == 5: + yield x - 1 + yield x + 1 + + expr = Basic(Basic(S(5), S(6)), S(1)) + expected = {Basic(Basic(S(4), S(6)), S(1)), Basic(Basic(S(6), S(6)), S(1))} + brl = top_down(split5) + + assert set(brl(expr)) == expected + + +def test_sall(): + expr = Basic(S(1), S(2)) + expected = Basic(S(2), S(3)) + brl = sall(inc) + + assert list(brl(expr)) == [expected] + + expr = Basic(S(1), S(2), Basic(S(3), S(4))) + expected = Basic(S(2), S(3), Basic(S(3), S(4))) + brl = sall(do_one(inc, identity)) + + assert list(brl(expr)) == [expected] diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/tools.py b/phivenv/Lib/site-packages/sympy/strategies/branch/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c9097323a7962080ae4497ead976818e386518 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/tools.py @@ -0,0 +1,12 @@ +from .core import exhaust, multiplex +from .traverse import top_down + + +def canon(*rules): + """ Strategy for canonicalization + + Apply each branching rule in a top-down fashion through the tree. + Multiplex through all branching rule traversals + Keep doing this until there is no change. + """ + return exhaust(multiplex(*map(top_down, rules))) diff --git a/phivenv/Lib/site-packages/sympy/strategies/branch/traverse.py b/phivenv/Lib/site-packages/sympy/strategies/branch/traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..28b7098dbda401fc0f0b6d27988d8c37e2f231ae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/branch/traverse.py @@ -0,0 +1,25 @@ +""" Branching Strategies to Traverse a Tree """ +from itertools import product +from sympy.strategies.util import basic_fns +from .core import chain, identity, do_one + + +def top_down(brule, fns=basic_fns): + """ Apply a rule down a tree running it on the top nodes first """ + return chain(do_one(brule, identity), + lambda expr: sall(top_down(brule, fns), fns)(expr)) + + +def sall(brule, fns=basic_fns): + """ Strategic all - apply rule to args """ + op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf')) + + def all_rl(expr): + if leaf(expr): + yield expr + else: + myop = op(expr) + argss = product(*map(brule, children(expr))) + for args in argss: + yield new(myop, *args) + return all_rl diff --git a/phivenv/Lib/site-packages/sympy/strategies/core.py b/phivenv/Lib/site-packages/sympy/strategies/core.py new file mode 100644 index 0000000000000000000000000000000000000000..75b75cb5f2e0693eea98a7b1c9b3e7f036ec26f6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/core.py @@ -0,0 +1,151 @@ +""" Generic SymPy-Independent Strategies """ +from __future__ import annotations +from collections.abc import Callable, Mapping +from typing import TypeVar +from sys import stdout + + +_S = TypeVar('_S') +_T = TypeVar('_T') + + +def identity(x: _T) -> _T: + return x + + +def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]: + """ Apply a rule repeatedly until it has no effect """ + def exhaustive_rl(expr: _T) -> _T: + new, old = rule(expr), expr + while new != old: + new, old = rule(new), new + return new + return exhaustive_rl + + +def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]: + """Memoized version of a rule + + Notes + ===== + + This cache can grow infinitely, so it is not recommended to use this + than ``functools.lru_cache`` unless you need very heavy computation. + """ + cache: dict[_S, _T] = {} + + def memoized_rl(expr: _S) -> _T: + if expr in cache: + return cache[expr] + else: + result = rule(expr) + cache[expr] = result + return result + return memoized_rl + + +def condition( + cond: Callable[[_T], bool], rule: Callable[[_T], _T] +) -> Callable[[_T], _T]: + """ Only apply rule if condition is true """ + def conditioned_rl(expr: _T) -> _T: + if cond(expr): + return rule(expr) + return expr + return conditioned_rl + + +def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]: + """ + Compose a sequence of rules so that they apply to the expr sequentially + """ + def chain_rl(expr: _T) -> _T: + for rule in rules: + expr = rule(expr) + return expr + return chain_rl + + +def debug(rule, file=None): + """ Print out before and after expressions each time rule is used """ + if file is None: + file = stdout + + def debug_rl(*args, **kwargs): + expr = args[0] + result = rule(*args, **kwargs) + if result != expr: + file.write("Rule: %s\n" % rule.__name__) + file.write("In: %s\nOut: %s\n\n" % (expr, result)) + return result + return debug_rl + + +def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]: + """ Return original expr if rule returns None """ + def null_safe_rl(expr: _T) -> _T: + result = rule(expr) + if result is None: + return expr + return result + return null_safe_rl + + +def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]: + """ Return original expr if rule raises exception """ + def try_rl(expr: _T) -> _T: + try: + return rule(expr) + except exception: + return expr + return try_rl + + +def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]: + """ Try each of the rules until one works. Then stop. """ + def do_one_rl(expr: _T) -> _T: + for rl in rules: + result = rl(expr) + if result != expr: + return result + return expr + return do_one_rl + + +def switch( + key: Callable[[_S], _T], + ruledict: Mapping[_T, Callable[[_S], _S]] +) -> Callable[[_S], _S]: + """ Select a rule based on the result of key called on the function """ + def switch_rl(expr: _S) -> _S: + rl = ruledict.get(key(expr), identity) + return rl(expr) + return switch_rl + + +# XXX Untyped default argument for minimize function +# where python requires SupportsRichComparison type +def _identity(x): + return x + + +def minimize( + *rules: Callable[[_S], _T], + objective=_identity +) -> Callable[[_S], _T]: + """ Select result of rules that minimizes objective + + >>> from sympy.strategies import minimize + >>> inc = lambda x: x + 1 + >>> dec = lambda x: x - 1 + >>> rl = minimize(inc, dec) + >>> rl(4) + 3 + + >>> rl = minimize(inc, dec, objective=lambda x: -x) # maximize + >>> rl(4) + 5 + """ + def minrule(expr: _S) -> _T: + return min([rule(expr) for rule in rules], key=objective) + return minrule diff --git a/phivenv/Lib/site-packages/sympy/strategies/rl.py b/phivenv/Lib/site-packages/sympy/strategies/rl.py new file mode 100644 index 0000000000000000000000000000000000000000..e84ee90582fafcf36fd4205a58b05b650875a9a5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/rl.py @@ -0,0 +1,176 @@ +""" Generic Rules for SymPy + +This file assumes knowledge of Basic and little else. +""" +from sympy.utilities.iterables import sift +from .util import new + + +# Functions that create rules +def rm_id(isid, new=new): + """ Create a rule to remove identities. + + isid - fn :: x -> Bool --- whether or not this element is an identity. + + Examples + ======== + + >>> from sympy.strategies import rm_id + >>> from sympy import Basic, S + >>> remove_zeros = rm_id(lambda x: x==0) + >>> remove_zeros(Basic(S(1), S(0), S(2))) + Basic(1, 2) + >>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one + Basic(0) + + See Also: + unpack + """ + def ident_remove(expr): + """ Remove identities """ + ids = list(map(isid, expr.args)) + if sum(ids) == 0: # No identities. Common case + return expr + elif sum(ids) != len(ids): # there is at least one non-identity + return new(expr.__class__, + *[arg for arg, x in zip(expr.args, ids) if not x]) + else: + return new(expr.__class__, expr.args[0]) + + return ident_remove + + +def glom(key, count, combine): + """ Create a rule to conglomerate identical args. + + Examples + ======== + + >>> from sympy.strategies import glom + >>> from sympy import Add + >>> from sympy.abc import x + + >>> key = lambda x: x.as_coeff_Mul()[1] + >>> count = lambda x: x.as_coeff_Mul()[0] + >>> combine = lambda cnt, arg: cnt * arg + >>> rl = glom(key, count, combine) + + >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False)) + 3*x + 5 + + Wait, how are key, count and combine supposed to work? + + >>> key(2*x) + x + >>> count(2*x) + 2 + >>> combine(2, x) + 2*x + """ + def conglomerate(expr): + """ Conglomerate together identical args x + x -> 2x """ + groups = sift(expr.args, key) + counts = {k: sum(map(count, args)) for k, args in groups.items()} + newargs = [combine(cnt, mat) for mat, cnt in counts.items()] + if set(newargs) != set(expr.args): + return new(type(expr), *newargs) + else: + return expr + + return conglomerate + + +def sort(key, new=new): + """ Create a rule to sort by a key function. + + Examples + ======== + + >>> from sympy.strategies import sort + >>> from sympy import Basic, S + >>> sort_rl = sort(str) + >>> sort_rl(Basic(S(3), S(1), S(2))) + Basic(1, 2, 3) + """ + + def sort_rl(expr): + return new(expr.__class__, *sorted(expr.args, key=key)) + return sort_rl + + +def distribute(A, B): + """ Turns an A containing Bs into a B of As + + where A, B are container types + + >>> from sympy.strategies import distribute + >>> from sympy import Add, Mul, symbols + >>> x, y = symbols('x,y') + >>> dist = distribute(Mul, Add) + >>> expr = Mul(2, x+y, evaluate=False) + >>> expr + 2*(x + y) + >>> dist(expr) + 2*x + 2*y + """ + + def distribute_rl(expr): + for i, arg in enumerate(expr.args): + if isinstance(arg, B): + first, b, tail = expr.args[:i], expr.args[i], expr.args[i + 1:] + return B(*[A(*(first + (arg,) + tail)) for arg in b.args]) + return expr + return distribute_rl + + +def subs(a, b): + """ Replace expressions exactly """ + def subs_rl(expr): + if expr == a: + return b + else: + return expr + return subs_rl + + +# Functions that are rules +def unpack(expr): + """ Rule to unpack singleton args + + >>> from sympy.strategies import unpack + >>> from sympy import Basic, S + >>> unpack(Basic(S(2))) + 2 + """ + if len(expr.args) == 1: + return expr.args[0] + else: + return expr + + +def flatten(expr, new=new): + """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)) """ + cls = expr.__class__ + args = [] + for arg in expr.args: + if arg.__class__ == cls: + args.extend(arg.args) + else: + args.append(arg) + return new(expr.__class__, *args) + + +def rebuild(expr): + """ Rebuild a SymPy tree. + + Explanation + =========== + + This function recursively calls constructors in the expression tree. + This forces canonicalization and removes ugliness introduced by the use of + Basic.__new__ + """ + if expr.is_Atom: + return expr + else: + return expr.func(*list(map(rebuild, expr.args))) diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__init__.py b/phivenv/Lib/site-packages/sympy/strategies/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80172b05bac4b22601386f6d7e619bee2a1a79af Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a9c919c47cbd79970cbbf9e027ea8e1a959f94f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_rl.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_rl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b696e8ee0f1f104402d9512a334dce2a40b4dd72 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_rl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_tools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ac09cf6c5b07f8a435e79629b08c736d09481c0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_tools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_traverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_traverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fa5d3194f05d9c7af90c0513c36dc10b97b78d6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_traverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_tree.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_tree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f693a5d83326befadf9132d3fb218d88b19daa4e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/strategies/tests/__pycache__/test_tree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/test_core.py b/phivenv/Lib/site-packages/sympy/strategies/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..1e19bcab1940c476a8996b7ba92e7645a6230034 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tests/test_core.py @@ -0,0 +1,118 @@ +from __future__ import annotations +from sympy.core.singleton import S +from sympy.core.basic import Basic +from sympy.strategies.core import ( + null_safe, exhaust, memoize, condition, + chain, tryit, do_one, debug, switch, minimize) +from io import StringIO + + +def posdec(x: int) -> int: + if x > 0: + return x - 1 + return x + + +def inc(x: int) -> int: + return x + 1 + + +def dec(x: int) -> int: + return x - 1 + + +def test_null_safe(): + def rl(expr: int) -> int | None: + if expr == 1: + return 2 + return None + + safe_rl = null_safe(rl) + assert rl(1) == safe_rl(1) + assert rl(3) is None + assert safe_rl(3) == 3 + + +def test_exhaust(): + sink = exhaust(posdec) + assert sink(5) == 0 + assert sink(10) == 0 + + +def test_memoize(): + rl = memoize(posdec) + assert rl(5) == posdec(5) + assert rl(5) == posdec(5) + assert rl(-2) == posdec(-2) + + +def test_condition(): + rl = condition(lambda x: x % 2 == 0, posdec) + assert rl(5) == 5 + assert rl(4) == 3 + + +def test_chain(): + rl = chain(posdec, posdec) + assert rl(5) == 3 + assert rl(1) == 0 + + +def test_tryit(): + def rl(expr: Basic) -> Basic: + assert False + + safe_rl = tryit(rl, AssertionError) + assert safe_rl(S(1)) == S(1) + + +def test_do_one(): + rl = do_one(posdec, posdec) + assert rl(5) == 4 + + def rl1(x: int) -> int: + if x == 1: + return 2 + return x + + def rl2(x: int) -> int: + if x == 2: + return 3 + return x + + rule = do_one(rl1, rl2) + assert rule(1) == 2 + assert rule(rule(1)) == 3 + + +def test_debug(): + file = StringIO() + rl = debug(posdec, file) + rl(5) + log = file.getvalue() + file.close() + + assert posdec.__name__ in log + assert '5' in log + assert '4' in log + + +def test_switch(): + def key(x: int) -> int: + return x % 3 + + rl = switch(key, {0: inc, 1: dec}) + assert rl(3) == 4 + assert rl(4) == 3 + assert rl(5) == 5 + + +def test_minimize(): + def key(x: int) -> int: + return -x + + rl = minimize(inc, dec) + assert rl(4) == 3 + + rl = minimize(inc, dec, objective=key) + assert rl(4) == 5 diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/test_rl.py b/phivenv/Lib/site-packages/sympy/strategies/tests/test_rl.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfa90ad4d970b21396e0e1b6427b5a5c68fe381 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tests/test_rl.py @@ -0,0 +1,78 @@ +from sympy.core.singleton import S +from sympy.strategies.rl import ( + rm_id, glom, flatten, unpack, sort, distribute, subs, rebuild) +from sympy.core.basic import Basic +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.symbol import symbols +from sympy.abc import x + + +def test_rm_id(): + rmzeros = rm_id(lambda x: x == 0) + assert rmzeros(Basic(S(0), S(1))) == Basic(S(1)) + assert rmzeros(Basic(S(0), S(0))) == Basic(S(0)) + assert rmzeros(Basic(S(2), S(1))) == Basic(S(2), S(1)) + + +def test_glom(): + def key(x): + return x.as_coeff_Mul()[1] + + def count(x): + return x.as_coeff_Mul()[0] + + def newargs(cnt, arg): + return cnt * arg + + rl = glom(key, count, newargs) + + result = rl(Add(x, -x, 3 * x, 2, 3, evaluate=False)) + expected = Add(3 * x, 5) + assert set(result.args) == set(expected.args) + + +def test_flatten(): + assert flatten(Basic(S(1), S(2), Basic(S(3), S(4)))) == \ + Basic(S(1), S(2), S(3), S(4)) + + +def test_unpack(): + assert unpack(Basic(S(2))) == 2 + assert unpack(Basic(S(2), S(3))) == Basic(S(2), S(3)) + + +def test_sort(): + assert sort(str)(Basic(S(3), S(1), S(2))) == Basic(S(1), S(2), S(3)) + + +def test_distribute(): + class T1(Basic): + pass + + class T2(Basic): + pass + + distribute_t12 = distribute(T1, T2) + assert distribute_t12(T1(S(1), S(2), T2(S(3), S(4)), S(5))) == \ + T2(T1(S(1), S(2), S(3), S(5)), T1(S(1), S(2), S(4), S(5))) + assert distribute_t12(T1(S(1), S(2), S(3))) == T1(S(1), S(2), S(3)) + + +def test_distribute_add_mul(): + x, y = symbols('x, y') + expr = Mul(2, Add(x, y), evaluate=False) + expected = Add(Mul(2, x), Mul(2, y)) + distribute_mul = distribute(Mul, Add) + assert distribute_mul(expr) == expected + + +def test_subs(): + rl = subs(1, 2) + assert rl(1) == 2 + assert rl(3) == 3 + + +def test_rebuild(): + expr = Basic.__new__(Add, S(1), S(2)) + assert rebuild(expr) == 3 diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/test_tools.py b/phivenv/Lib/site-packages/sympy/strategies/tests/test_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..89774aeb92ead5781966e4f48ad32dc63e1bf7e2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tests/test_tools.py @@ -0,0 +1,32 @@ +from sympy.strategies.tools import subs, typed +from sympy.strategies.rl import rm_id +from sympy.core.basic import Basic +from sympy.core.singleton import S + + +def test_subs(): + from sympy.core.symbol import symbols + a, b, c, d, e, f = symbols('a,b,c,d,e,f') + mapping = {a: d, d: a, Basic(e): Basic(f)} + expr = Basic(a, Basic(b, c), Basic(d, Basic(e))) + result = Basic(d, Basic(b, c), Basic(a, Basic(f))) + assert subs(mapping)(expr) == result + + +def test_subs_empty(): + assert subs({})(Basic(S(1), S(2))) == Basic(S(1), S(2)) + + +def test_typed(): + class A(Basic): + pass + + class B(Basic): + pass + + rmzeros = rm_id(lambda x: x == S(0)) + rmones = rm_id(lambda x: x == S(1)) + remove_something = typed({A: rmzeros, B: rmones}) + + assert remove_something(A(S(0), S(1))) == A(S(1)) + assert remove_something(B(S(0), S(1))) == B(S(0)) diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/test_traverse.py b/phivenv/Lib/site-packages/sympy/strategies/tests/test_traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2409616a8b4f750af8baea149b2ea52c56be9d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tests/test_traverse.py @@ -0,0 +1,84 @@ +from sympy.strategies.traverse import ( + top_down, bottom_up, sall, top_down_once, bottom_up_once, basic_fns) +from sympy.strategies.rl import rebuild +from sympy.strategies.util import expr_fns +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.core.symbol import Str, Symbol +from sympy.abc import x, y, z + + +def zero_symbols(expression): + return S.Zero if isinstance(expression, Symbol) else expression + + +def test_sall(): + zero_onelevel = sall(zero_symbols) + + assert zero_onelevel(Basic(x, y, Basic(x, z))) == \ + Basic(S(0), S(0), Basic(x, z)) + + +def test_bottom_up(): + _test_global_traversal(bottom_up) + _test_stop_on_non_basics(bottom_up) + + +def test_top_down(): + _test_global_traversal(top_down) + _test_stop_on_non_basics(top_down) + + +def _test_global_traversal(trav): + zero_all_symbols = trav(zero_symbols) + + assert zero_all_symbols(Basic(x, y, Basic(x, z))) == \ + Basic(S(0), S(0), Basic(S(0), S(0))) + + +def _test_stop_on_non_basics(trav): + def add_one_if_can(expr): + try: + return expr + 1 + except TypeError: + return expr + + expr = Basic(S(1), Str('a'), Basic(S(2), Str('b'))) + expected = Basic(S(2), Str('a'), Basic(S(3), Str('b'))) + rl = trav(add_one_if_can) + + assert rl(expr) == expected + + +class Basic2(Basic): + pass + + +def rl(x): + if x.args and not isinstance(x.args[0], Integer): + return Basic2(*x.args) + return x + + +def test_top_down_once(): + top_rl = top_down_once(rl) + + assert top_rl(Basic(S(1.0), S(2.0), Basic(S(3), S(4)))) == \ + Basic2(S(1.0), S(2.0), Basic(S(3), S(4))) + + +def test_bottom_up_once(): + bottom_rl = bottom_up_once(rl) + + assert bottom_rl(Basic(S(1), S(2), Basic(S(3.0), S(4.0)))) == \ + Basic(S(1), S(2), Basic2(S(3.0), S(4.0))) + + +def test_expr_fns(): + expr = x + y**3 + e = bottom_up(lambda v: v + 1, expr_fns)(expr) + b = bottom_up(lambda v: Basic.__new__(Add, v, S(1)), basic_fns)(expr) + + assert rebuild(b) == e diff --git a/phivenv/Lib/site-packages/sympy/strategies/tests/test_tree.py b/phivenv/Lib/site-packages/sympy/strategies/tests/test_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cdde747fe3ab90c8fd181701194403bc526067 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tests/test_tree.py @@ -0,0 +1,92 @@ +from sympy.strategies.tree import treeapply, greedy, allresults, brute +from functools import partial, reduce + + +def inc(x): + return x + 1 + + +def dec(x): + return x - 1 + + +def double(x): + return 2 * x + + +def square(x): + return x**2 + + +def add(*args): + return sum(args) + + +def mul(*args): + return reduce(lambda a, b: a * b, args, 1) + + +def test_treeapply(): + tree = ([3, 3], [4, 1], 2) + assert treeapply(tree, {list: min, tuple: max}) == 3 + assert treeapply(tree, {list: add, tuple: mul}) == 60 + + +def test_treeapply_leaf(): + assert treeapply(3, {}, leaf=lambda x: x**2) == 9 + tree = ([3, 3], [4, 1], 2) + treep1 = ([4, 4], [5, 2], 3) + assert treeapply(tree, {list: min, tuple: max}, leaf=lambda x: x + 1) == \ + treeapply(treep1, {list: min, tuple: max}) + + +def test_treeapply_strategies(): + from sympy.strategies import chain, minimize + join = {list: chain, tuple: minimize} + + assert treeapply(inc, join) == inc + assert treeapply((inc, dec), join)(5) == minimize(inc, dec)(5) + assert treeapply([inc, dec], join)(5) == chain(inc, dec)(5) + tree = (inc, [dec, double]) # either inc or dec-then-double + assert treeapply(tree, join)(5) == 6 + assert treeapply(tree, join)(1) == 0 + + maximize = partial(minimize, objective=lambda x: -x) + join = {list: chain, tuple: maximize} + fn = treeapply(tree, join) + assert fn(4) == 6 # highest value comes from the dec then double + assert fn(1) == 2 # highest value comes from the inc + + +def test_greedy(): + tree = [inc, (dec, double)] # either inc or dec-then-double + + fn = greedy(tree, objective=lambda x: -x) + assert fn(4) == 6 # highest value comes from the dec then double + assert fn(1) == 2 # highest value comes from the inc + + tree = [inc, dec, [inc, dec, [(inc, inc), (dec, dec)]]] + lowest = greedy(tree) + assert lowest(10) == 8 + + highest = greedy(tree, objective=lambda x: -x) + assert highest(10) == 12 + + +def test_allresults(): + # square = lambda x: x**2 + + assert set(allresults(inc)(3)) == {inc(3)} + assert set(allresults([inc, dec])(3)) == {2, 4} + assert set(allresults((inc, dec))(3)) == {3} + assert set(allresults([inc, (dec, double)])(4)) == {5, 6} + + +def test_brute(): + tree = ([inc, dec], square) + fn = brute(tree, lambda x: -x) + + assert fn(2) == (2 + 1)**2 + assert fn(-2) == (-2 - 1)**2 + + assert brute(inc)(1) == 2 diff --git a/phivenv/Lib/site-packages/sympy/strategies/tools.py b/phivenv/Lib/site-packages/sympy/strategies/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a94c16db57206d7c83c8a5e13930c4cffdde47 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tools.py @@ -0,0 +1,53 @@ +from . import rl +from .core import do_one, exhaust, switch +from .traverse import top_down + + +def subs(d, **kwargs): + """ Full simultaneous exact substitution. + + Examples + ======== + + >>> from sympy.strategies.tools import subs + >>> from sympy import Basic, S + >>> mapping = {S(1): S(4), S(4): S(1), Basic(S(5)): Basic(S(6), S(7))} + >>> expr = Basic(S(1), Basic(S(2), S(3)), Basic(S(4), Basic(S(5)))) + >>> subs(mapping)(expr) + Basic(4, Basic(2, 3), Basic(1, Basic(6, 7))) + """ + if d: + return top_down(do_one(*map(rl.subs, *zip(*d.items()))), **kwargs) + else: + return lambda x: x + + +def canon(*rules, **kwargs): + """ Strategy for canonicalization. + + Explanation + =========== + + Apply each rule in a bottom_up fashion through the tree. + Do each one in turn. + Keep doing this until there is no change. + """ + return exhaust(top_down(exhaust(do_one(*rules)), **kwargs)) + + +def typed(ruletypes): + """ Apply rules based on the expression type + + inputs: + ruletypes -- a dict mapping {Type: rule} + + Examples + ======== + + >>> from sympy.strategies import rm_id, typed + >>> from sympy import Add, Mul + >>> rm_zeros = rm_id(lambda x: x==0) + >>> rm_ones = rm_id(lambda x: x==1) + >>> remove_idents = typed({Add: rm_zeros, Mul: rm_ones}) + """ + return switch(type, ruletypes) diff --git a/phivenv/Lib/site-packages/sympy/strategies/traverse.py b/phivenv/Lib/site-packages/sympy/strategies/traverse.py new file mode 100644 index 0000000000000000000000000000000000000000..869361f443742b5b7346c9c970f103b955e8473e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/traverse.py @@ -0,0 +1,37 @@ +"""Strategies to Traverse a Tree.""" +from sympy.strategies.util import basic_fns +from sympy.strategies.core import chain, do_one + + +def top_down(rule, fns=basic_fns): + """Apply a rule down a tree running it on the top nodes first.""" + return chain(rule, lambda expr: sall(top_down(rule, fns), fns)(expr)) + + +def bottom_up(rule, fns=basic_fns): + """Apply a rule down a tree running it on the bottom nodes first.""" + return chain(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule) + + +def top_down_once(rule, fns=basic_fns): + """Apply a rule down a tree - stop on success.""" + return do_one(rule, lambda expr: sall(top_down(rule, fns), fns)(expr)) + + +def bottom_up_once(rule, fns=basic_fns): + """Apply a rule up a tree - stop on success.""" + return do_one(lambda expr: sall(bottom_up(rule, fns), fns)(expr), rule) + + +def sall(rule, fns=basic_fns): + """Strategic all - apply rule to args.""" + op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf')) + + def all_rl(expr): + if leaf(expr): + return expr + else: + args = map(rule, children(expr)) + return new(op(expr), *args) + + return all_rl diff --git a/phivenv/Lib/site-packages/sympy/strategies/tree.py b/phivenv/Lib/site-packages/sympy/strategies/tree.py new file mode 100644 index 0000000000000000000000000000000000000000..c2006fde4fc5d09f3d38baae4d7335b4cbd971b7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/tree.py @@ -0,0 +1,139 @@ +from functools import partial +from sympy.strategies import chain, minimize +from sympy.strategies.core import identity +import sympy.strategies.branch as branch +from sympy.strategies.branch import yieldify + + +def treeapply(tree, join, leaf=identity): + """ Apply functions onto recursive containers (tree). + + Explanation + =========== + + join - a dictionary mapping container types to functions + e.g. ``{list: minimize, tuple: chain}`` + + Keys are containers/iterables. Values are functions [a] -> a. + + Examples + ======== + + >>> from sympy.strategies.tree import treeapply + >>> tree = [(3, 2), (4, 1)] + >>> treeapply(tree, {list: max, tuple: min}) + 2 + + >>> add = lambda *args: sum(args) + >>> def mul(*args): + ... total = 1 + ... for arg in args: + ... total *= arg + ... return total + >>> treeapply(tree, {list: mul, tuple: add}) + 25 + """ + for typ in join: + if isinstance(tree, typ): + return join[typ](*map(partial(treeapply, join=join, leaf=leaf), + tree)) + return leaf(tree) + + +def greedy(tree, objective=identity, **kwargs): + """ Execute a strategic tree. Select alternatives greedily + + Trees + ----- + + Nodes in a tree can be either + + function - a leaf + list - a selection among operations + tuple - a sequence of chained operations + + Textual examples + ---------------- + + Text: Run f, then run g, e.g. ``lambda x: g(f(x))`` + Code: ``(f, g)`` + + Text: Run either f or g, whichever minimizes the objective + Code: ``[f, g]`` + + Textx: Run either f or g, whichever is better, then run h + Code: ``([f, g], h)`` + + Text: Either expand then simplify or try factor then foosimp. Finally print + Code: ``([(expand, simplify), (factor, foosimp)], print)`` + + Objective + --------- + + "Better" is determined by the objective keyword. This function makes + choices to minimize the objective. It defaults to the identity. + + Examples + ======== + + >>> from sympy.strategies.tree import greedy + >>> inc = lambda x: x + 1 + >>> dec = lambda x: x - 1 + >>> double = lambda x: 2*x + + >>> tree = [inc, (dec, double)] # either inc or dec-then-double + >>> fn = greedy(tree) + >>> fn(4) # lowest value comes from the inc + 5 + >>> fn(1) # lowest value comes from dec then double + 0 + + This function selects between options in a tuple. The result is chosen + that minimizes the objective function. + + >>> fn = greedy(tree, objective=lambda x: -x) # maximize + >>> fn(4) # highest value comes from the dec then double + 6 + >>> fn(1) # highest value comes from the inc + 2 + + Greediness + ---------- + + This is a greedy algorithm. In the example: + + ([a, b], c) # do either a or b, then do c + + the choice between running ``a`` or ``b`` is made without foresight to c + """ + optimize = partial(minimize, objective=objective) + return treeapply(tree, {list: optimize, tuple: chain}, **kwargs) + + +def allresults(tree, leaf=yieldify): + """ Execute a strategic tree. Return all possibilities. + + Returns a lazy iterator of all possible results + + Exhaustiveness + -------------- + + This is an exhaustive algorithm. In the example + + ([a, b], [c, d]) + + All of the results from + + (a, c), (b, c), (a, d), (b, d) + + are returned. This can lead to combinatorial blowup. + + See sympy.strategies.greedy for details on input + """ + return treeapply(tree, {list: branch.multiplex, tuple: branch.chain}, + leaf=leaf) + + +def brute(tree, objective=identity, **kwargs): + return lambda expr: min(tuple(allresults(tree, **kwargs)(expr)), + key=objective) diff --git a/phivenv/Lib/site-packages/sympy/strategies/util.py b/phivenv/Lib/site-packages/sympy/strategies/util.py new file mode 100644 index 0000000000000000000000000000000000000000..13aab5f6a49650c5ded9cd913c23c6682f18d40a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/strategies/util.py @@ -0,0 +1,17 @@ +from sympy.core.basic import Basic + +new = Basic.__new__ + + +def assoc(d, k, v): + d = d.copy() + d[k] = v + return d + + +basic_fns = {'op': type, + 'new': Basic.__new__, + 'leaf': lambda x: not isinstance(x, Basic) or x.is_Atom, + 'children': lambda x: x.args} + +expr_fns = assoc(basic_fns, 'new', lambda op, *args: op(*args)) diff --git a/phivenv/Lib/site-packages/sympy/tensor/__init__.py b/phivenv/Lib/site-packages/sympy/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a832614b1d48e26bf01e16f040f34dd412e8e32b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/__init__.py @@ -0,0 +1,23 @@ +"""A module to manipulate symbolic objects with indices including tensors + +""" +from .indexed import IndexedBase, Idx, Indexed +from .index_methods import get_contraction_structure, get_indices +from .functions import shape +from .array import (MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray, NDimArray, tensorproduct, + tensorcontraction, tensordiagonal, derive_by_array, permutedims, Array, + DenseNDimArray, SparseNDimArray,) + +__all__ = [ + 'IndexedBase', 'Idx', 'Indexed', + + 'get_contraction_structure', 'get_indices', + + 'shape', + + 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', + 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'NDimArray', + 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', 'permutedims', + 'Array', 'DenseNDimArray', 'SparseNDimArray', +] diff --git a/phivenv/Lib/site-packages/sympy/tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cbc3d7002c2dc913f44857740dd75dace60ac31 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..096d1facbdead8267afd65e270de44e235f48efd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/__pycache__/index_methods.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/index_methods.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6985afcab921c0b3a8d9d1aa031f48dc5336e09b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/index_methods.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/__pycache__/indexed.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/indexed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..815955a774d94b793dc59baa5d0869c9ae3e5167 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/indexed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/__pycache__/toperators.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/toperators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b51d7a251e675ffba58c88c9147c4b58706a6e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/__pycache__/toperators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__init__.py b/phivenv/Lib/site-packages/sympy/tensor/array/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca2eb4c6c58cb113517b6e41737e9d97abbb84e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/__init__.py @@ -0,0 +1,271 @@ +r""" +N-dim array module for SymPy. + +Four classes are provided to handle N-dim arrays, given by the combinations +dense/sparse (i.e. whether to store all elements or only the non-zero ones in +memory) and mutable/immutable (immutable classes are SymPy objects, but cannot +change after they have been created). + +Examples +======== + +The following examples show the usage of ``Array``. This is an abbreviation for +``ImmutableDenseNDimArray``, that is an immutable and dense N-dim array, the +other classes are analogous. For mutable classes it is also possible to change +element values after the object has been constructed. + +Array construction can detect the shape of nested lists and tuples: + +>>> from sympy import Array +>>> a1 = Array([[1, 2], [3, 4], [5, 6]]) +>>> a1 +[[1, 2], [3, 4], [5, 6]] +>>> a1.shape +(3, 2) +>>> a1.rank() +2 +>>> from sympy.abc import x, y, z +>>> a2 = Array([[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]]) +>>> a2 +[[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]] +>>> a2.shape +(2, 2, 2) +>>> a2.rank() +3 + +Otherwise one could pass a 1-dim array followed by a shape tuple: + +>>> m1 = Array(range(12), (3, 4)) +>>> m1 +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] +>>> m2 = Array(range(12), (3, 2, 2)) +>>> m2 +[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]] +>>> m2[1,1,1] +7 +>>> m2.reshape(4, 3) +[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + +Slice support: + +>>> m2[:, 1, 1] +[3, 7, 11] + +Elementwise derivative: + +>>> from sympy.abc import x, y, z +>>> m3 = Array([x**3, x*y, z]) +>>> m3.diff(x) +[3*x**2, y, 0] +>>> m3.diff(z) +[0, 0, 1] + +Multiplication with other SymPy expressions is applied elementwisely: + +>>> (1+x)*m3 +[x**3*(x + 1), x*y*(x + 1), z*(x + 1)] + +To apply a function to each element of the N-dim array, use ``applyfunc``: + +>>> m3.applyfunc(lambda x: x/2) +[x**3/2, x*y/2, z/2] + +N-dim arrays can be converted to nested lists by the ``tolist()`` method: + +>>> m2.tolist() +[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]] +>>> isinstance(m2.tolist(), list) +True + +If the rank is 2, it is possible to convert them to matrices with ``tomatrix()``: + +>>> m1.tomatrix() +Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + +Products and contractions +------------------------- + +Tensor product between arrays `A_{i_1,\ldots,i_n}` and `B_{j_1,\ldots,j_m}` +creates the combined array `P = A \otimes B` defined as + +`P_{i_1,\ldots,i_n,j_1,\ldots,j_m} := A_{i_1,\ldots,i_n}\cdot B_{j_1,\ldots,j_m}.` + +It is available through ``tensorproduct(...)``: + +>>> from sympy import Array, tensorproduct +>>> from sympy.abc import x,y,z,t +>>> A = Array([x, y, z, t]) +>>> B = Array([1, 2, 3, 4]) +>>> tensorproduct(A, B) +[[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]] + +In case you don't want to evaluate the tensor product immediately, you can use +``ArrayTensorProduct``, which creates an unevaluated tensor product expression: + +>>> from sympy.tensor.array.expressions import ArrayTensorProduct +>>> ArrayTensorProduct(A, B) +ArrayTensorProduct([x, y, z, t], [1, 2, 3, 4]) + +Calling ``.as_explicit()`` on ``ArrayTensorProduct`` is equivalent to just calling +``tensorproduct(...)``: + +>>> ArrayTensorProduct(A, B).as_explicit() +[[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]] + +Tensor product between a rank-1 array and a matrix creates a rank-3 array: + +>>> from sympy import eye +>>> p1 = tensorproduct(A, eye(4)) +>>> p1 +[[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]], [[y, 0, 0, 0], [0, y, 0, 0], [0, 0, y, 0], [0, 0, 0, y]], [[z, 0, 0, 0], [0, z, 0, 0], [0, 0, z, 0], [0, 0, 0, z]], [[t, 0, 0, 0], [0, t, 0, 0], [0, 0, t, 0], [0, 0, 0, t]]] + +Now, to get back `A_0 \otimes \mathbf{1}` one can access `p_{0,m,n}` by slicing: + +>>> p1[0,:,:] +[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]] + +Tensor contraction sums over the specified axes, for example contracting +positions `a` and `b` means + +`A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies \sum_k A_{i_1,\ldots,k,\ldots,k,\ldots,i_n}` + +Remember that Python indexing is zero starting, to contract the a-th and b-th +axes it is therefore necessary to specify `a-1` and `b-1` + +>>> from sympy import tensorcontraction +>>> C = Array([[x, y], [z, t]]) + +The matrix trace is equivalent to the contraction of a rank-2 array: + +`A_{m,n} \implies \sum_k A_{k,k}` + +>>> tensorcontraction(C, (0, 1)) +t + x + +To create an expression representing a tensor contraction that does not get +evaluated immediately, use ``ArrayContraction``, which is equivalent to +``tensorcontraction(...)`` if it is followed by ``.as_explicit()``: + +>>> from sympy.tensor.array.expressions import ArrayContraction +>>> ArrayContraction(C, (0, 1)) +ArrayContraction([[x, y], [z, t]], (0, 1)) +>>> ArrayContraction(C, (0, 1)).as_explicit() +t + x + +Matrix product is equivalent to a tensor product of two rank-2 arrays, followed +by a contraction of the 2nd and 3rd axes (in Python indexing axes number 1, 2). + +`A_{m,n}\cdot B_{i,j} \implies \sum_k A_{m, k}\cdot B_{k, j}` + +>>> D = Array([[2, 1], [0, -1]]) +>>> tensorcontraction(tensorproduct(C, D), (1, 2)) +[[2*x, x - y], [2*z, -t + z]] + +One may verify that the matrix product is equivalent: + +>>> from sympy import Matrix +>>> Matrix([[x, y], [z, t]])*Matrix([[2, 1], [0, -1]]) +Matrix([ +[2*x, x - y], +[2*z, -t + z]]) + +or equivalently + +>>> C.tomatrix()*D.tomatrix() +Matrix([ +[2*x, x - y], +[2*z, -t + z]]) + +Diagonal operator +----------------- + +The ``tensordiagonal`` function acts in a similar manner as ``tensorcontraction``, +but the joined indices are not summed over, for example diagonalizing +positions `a` and `b` means + +`A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies A_{i_1,\ldots,k,\ldots,k,\ldots,i_n} +\implies \tilde{A}_{i_1,\ldots,i_{a-1},i_{a+1},\ldots,i_{b-1},i_{b+1},\ldots,i_n,k}` + +where `\tilde{A}` is the array equivalent to the diagonal of `A` at positions +`a` and `b` moved to the last index slot. + +Compare the difference between contraction and diagonal operators: + +>>> from sympy import tensordiagonal +>>> from sympy.abc import a, b, c, d +>>> m = Matrix([[a, b], [c, d]]) +>>> tensorcontraction(m, [0, 1]) +a + d +>>> tensordiagonal(m, [0, 1]) +[a, d] + +In short, no summation occurs with ``tensordiagonal``. + + +Derivatives by array +-------------------- + +The usual derivative operation may be extended to support derivation with +respect to arrays, provided that all elements in the that array are symbols or +expressions suitable for derivations. + +The definition of a derivative by an array is as follows: given the array +`A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}` +the derivative of arrays will return a new array `B` defined by + +`B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}` + +The function ``derive_by_array`` performs such an operation: + +>>> from sympy import derive_by_array +>>> from sympy.abc import x, y, z, t +>>> from sympy import sin, exp + +With scalars, it behaves exactly as the ordinary derivative: + +>>> derive_by_array(sin(x*y), x) +y*cos(x*y) + +Scalar derived by an array basis: + +>>> derive_by_array(sin(x*y), [x, y, z]) +[y*cos(x*y), x*cos(x*y), 0] + +Deriving array by an array basis: `B^{nm} := \frac{\partial A^m}{\partial x^n}` + +>>> basis = [x, y, z] +>>> ax = derive_by_array([exp(x), sin(y*z), t], basis) +>>> ax +[[exp(x), 0, 0], [0, z*cos(y*z), 0], [0, y*cos(y*z), 0]] + +Contraction of the resulting array: `\sum_m \frac{\partial A^m}{\partial x^m}` + +>>> tensorcontraction(ax, (0, 1)) +z*cos(y*z) + exp(x) + +""" + +from .dense_ndim_array import MutableDenseNDimArray, ImmutableDenseNDimArray, DenseNDimArray +from .sparse_ndim_array import MutableSparseNDimArray, ImmutableSparseNDimArray, SparseNDimArray +from .ndim_array import NDimArray, ArrayKind +from .arrayop import tensorproduct, tensorcontraction, tensordiagonal, derive_by_array, permutedims +from .array_comprehension import ArrayComprehension, ArrayComprehensionMap + +Array = ImmutableDenseNDimArray + +__all__ = [ + 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', 'DenseNDimArray', + + 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'SparseNDimArray', + + 'NDimArray', 'ArrayKind', + + 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', + + 'permutedims', 'ArrayComprehension', 'ArrayComprehensionMap', + + 'Array', +] diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f91b9c17ed2266464be30eff950ff2c499b629f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2913a16c099073ad3676565ae559fa35e673033 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/array_comprehension.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e4e2997cdaace015d93d8cb56742bdc24c27177 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/array_derivatives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e5d3d9e831916415197d4e86b2d76674c215c5d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/arrayop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7499a08281a7f2706e24de388727e2618f5b58de Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/dense_ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d72555c76c1f26e03984c36dcd93dfd5f667e2e9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/mutable_ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a94379d20f8e3cc8f9fe11020c9c14d3f00d9deb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d1be4db6b1b51e900954f54c831cbfe86ff9da7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/__pycache__/sparse_ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/array_comprehension.py b/phivenv/Lib/site-packages/sympy/tensor/array/array_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..95702f499f3e40597fd0144929138ac1329962ee --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/array_comprehension.py @@ -0,0 +1,399 @@ +import functools, itertools +from sympy.core.sympify import _sympify, sympify +from sympy.core.expr import Expr +from sympy.core import Basic, Tuple +from sympy.tensor.array import ImmutableDenseNDimArray +from sympy.core.symbol import Symbol +from sympy.core.numbers import Integer + + +class ArrayComprehension(Basic): + """ + Generate a list comprehension. + + Explanation + =========== + + If there is a symbolic dimension, for example, say [i for i in range(1, N)] where + N is a Symbol, then the expression will not be expanded to an array. Otherwise, + calling the doit() function will launch the expansion. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a + ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.doit() + [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]] + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k)) + >>> b.doit() + ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k)) + """ + def __new__(cls, function, *symbols, **assumptions): + if any(len(l) != 3 or None for l in symbols): + raise ValueError('ArrayComprehension requires values lower and upper bound' + ' for the expression') + arglist = [sympify(function)] + arglist.extend(cls._check_limits_validity(function, symbols)) + obj = Basic.__new__(cls, *arglist, **assumptions) + obj._limits = obj._args[1:] + obj._shape = cls._calculate_shape_from_limits(obj._limits) + obj._rank = len(obj._shape) + obj._loop_size = cls._calculate_loop_size(obj._shape) + return obj + + @property + def function(self): + """The function applied across limits. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.function + 10*i + j + """ + return self._args[0] + + @property + def limits(self): + """ + The list of limits that will be applied while expanding the array. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.limits + ((i, 1, 4), (j, 1, 3)) + """ + return self._limits + + @property + def free_symbols(self): + """ + The set of the free_symbols in the array. + Variables appeared in the bounds are supposed to be excluded + from the free symbol set. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.free_symbols + set() + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.free_symbols + {k} + """ + expr_free_sym = self.function.free_symbols + for var, inf, sup in self._limits: + expr_free_sym.discard(var) + curr_free_syms = inf.free_symbols.union(sup.free_symbols) + expr_free_sym = expr_free_sym.union(curr_free_syms) + return expr_free_sym + + @property + def variables(self): + """The tuples of the variables in the limits. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.variables + [i, j] + """ + return [l[0] for l in self._limits] + + @property + def bound_symbols(self): + """The list of dummy variables. + + Note + ==== + + Note that all variables are dummy variables since a limit without + lower bound or upper bound is not accepted. + """ + return [l[0] for l in self._limits if len(l) != 1] + + @property + def shape(self): + """ + The shape of the expanded array, which may have symbols. + + Note + ==== + + Both the lower and the upper bounds are included while + calculating the shape. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.shape + (4, 3) + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.shape + (4, k + 3) + """ + return self._shape + + @property + def is_shape_numeric(self): + """ + Test if the array is shape-numeric which means there is no symbolic + dimension. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.is_shape_numeric + True + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.is_shape_numeric + False + """ + for _, inf, sup in self._limits: + if Basic(inf, sup).atoms(Symbol): + return False + return True + + def rank(self): + """The rank of the expanded array. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.rank() + 2 + """ + return self._rank + + def __len__(self): + """ + The length of the expanded array which means the number + of elements in the array. + + Raises + ====== + + ValueError : When the length of the array is symbolic + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> len(a) + 12 + """ + if self._loop_size.free_symbols: + raise ValueError('Symbolic length is not supported') + return self._loop_size + + @classmethod + def _check_limits_validity(cls, function, limits): + #limits = sympify(limits) + new_limits = [] + for var, inf, sup in limits: + var = _sympify(var) + inf = _sympify(inf) + #since this is stored as an argument, it should be + #a Tuple + if isinstance(sup, list): + sup = Tuple(*sup) + else: + sup = _sympify(sup) + new_limits.append(Tuple(var, inf, sup)) + if any((not isinstance(i, Expr)) or i.atoms(Symbol, Integer) != i.atoms() + for i in [inf, sup]): + raise TypeError('Bounds should be an Expression(combination of Integer and Symbol)') + if (inf > sup) == True: + raise ValueError('Lower bound should be inferior to upper bound') + if var in inf.free_symbols or var in sup.free_symbols: + raise ValueError('Variable should not be part of its bounds') + return new_limits + + @classmethod + def _calculate_shape_from_limits(cls, limits): + return tuple([sup - inf + 1 for _, inf, sup in limits]) + + @classmethod + def _calculate_loop_size(cls, shape): + if not shape: + return 0 + loop_size = 1 + for l in shape: + loop_size = loop_size * l + + return loop_size + + def doit(self, **hints): + if not self.is_shape_numeric: + return self + + return self._expand_array() + + def _expand_array(self): + res = [] + for values in itertools.product(*[range(inf, sup+1) + for var, inf, sup + in self._limits]): + res.append(self._get_element(values)) + + return ImmutableDenseNDimArray(res, self.shape) + + def _get_element(self, values): + temp = self.function + for var, val in zip(self.variables, values): + temp = temp.subs(var, val) + return temp + + def tolist(self): + """Transform the expanded array to a list. + + Raises + ====== + + ValueError : When there is a symbolic dimension + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.tolist() + [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]] + """ + if self.is_shape_numeric: + return self._expand_array().tolist() + + raise ValueError("A symbolic array cannot be expanded to a list") + + def tomatrix(self): + """Transform the expanded array to a matrix. + + Raises + ====== + + ValueError : When there is a symbolic dimension + ValueError : When the rank of the expanded array is not equal to 2 + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.tomatrix() + Matrix([ + [11, 12, 13], + [21, 22, 23], + [31, 32, 33], + [41, 42, 43]]) + """ + from sympy.matrices import Matrix + + if not self.is_shape_numeric: + raise ValueError("A symbolic array cannot be expanded to a matrix") + if self._rank != 2: + raise ValueError('Dimensions must be of size of 2') + + return Matrix(self._expand_array().tomatrix()) + + +def isLambda(v): + LAMBDA = lambda: 0 + return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__ + +class ArrayComprehensionMap(ArrayComprehension): + ''' + A subclass of ArrayComprehension dedicated to map external function lambda. + + Notes + ===== + + Only the lambda function is considered. + At most one argument in lambda function is accepted in order to avoid ambiguity + in value assignment. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehensionMap + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4)) + >>> a.doit() + [1, 1, 1, 1] + >>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4)) + >>> b.doit() + [2, 3, 4, 5] + + ''' + def __new__(cls, function, *symbols, **assumptions): + if any(len(l) != 3 or None for l in symbols): + raise ValueError('ArrayComprehension requires values lower and upper bound' + ' for the expression') + + if not isLambda(function): + raise ValueError('Data type not supported') + + arglist = cls._check_limits_validity(function, symbols) + obj = Basic.__new__(cls, *arglist, **assumptions) + obj._limits = obj._args + obj._shape = cls._calculate_shape_from_limits(obj._limits) + obj._rank = len(obj._shape) + obj._loop_size = cls._calculate_loop_size(obj._shape) + obj._lambda = function + return obj + + @property + def func(self): + class _(ArrayComprehensionMap): + def __new__(cls, *args, **kwargs): + return ArrayComprehensionMap(self._lambda, *args, **kwargs) + return _ + + def _get_element(self, values): + temp = self._lambda + if self._lambda.__code__.co_argcount == 0: + temp = temp() + elif self._lambda.__code__.co_argcount == 1: + temp = temp(functools.reduce(lambda a, b: a*b, values)) + return temp diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/array_derivatives.py b/phivenv/Lib/site-packages/sympy/tensor/array/array_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..a38db6caefe256a8c7e1f3415b78351b3787fee9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/array_derivatives.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from sympy.core.expr import Expr +from sympy.core.function import Derivative +from sympy.core.numbers import Integer +from sympy.matrices.matrixbase import MatrixBase +from .ndim_array import NDimArray +from .arrayop import derive_by_array +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.matrices.expressions.matexpr import _matrix_derivative + + +class ArrayDerivative(Derivative): + + is_scalar = False + + def __new__(cls, expr, *variables, **kwargs): + obj = super().__new__(cls, expr, *variables, **kwargs) + if isinstance(obj, ArrayDerivative): + obj._shape = obj._get_shape() + return obj + + def _get_shape(self): + shape = () + for v, count in self.variable_count: + if hasattr(v, "shape"): + for i in range(count): + shape += v.shape + if hasattr(self.expr, "shape"): + shape += self.expr.shape + return shape + + @property + def shape(self): + return self._shape + + @classmethod + def _get_zero_with_shape_like(cls, expr): + if isinstance(expr, (MatrixBase, NDimArray)): + return expr.zeros(*expr.shape) + elif isinstance(expr, MatrixExpr): + return ZeroMatrix(*expr.shape) + else: + raise RuntimeError("Unable to determine shape of array-derivative.") + + @staticmethod + def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixBase) -> Expr: + return v.applyfunc(lambda x: expr.diff(x)) + + @staticmethod + def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr: + if expr.has(v): + return _matrix_derivative(expr, v) + else: + return ZeroMatrix(*v.shape) + + @staticmethod + def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr: + return v.applyfunc(lambda x: expr.diff(x)) + + @staticmethod + def _call_derive_matrix_by_scalar(expr: MatrixBase, v: Expr) -> Expr: + return _matrix_derivative(expr, v) + + @staticmethod + def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr: + return expr._eval_derivative(v) + + @staticmethod + def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr: + return expr.applyfunc(lambda x: x.diff(v)) + + @staticmethod + def _call_derive_default(expr: Expr, v: Expr) -> Expr | None: + if expr.has(v): + return _matrix_derivative(expr, v) + else: + return None + + @classmethod + def _dispatch_eval_derivative_n_times(cls, expr, v, count): + # Evaluate the derivative `n` times. If + # `_eval_derivative_n_times` is not overridden by the current + # object, the default in `Basic` will call a loop over + # `_eval_derivative`: + + if not isinstance(count, (int, Integer)) or ((count <= 0) == True): + return None + + # TODO: this could be done with multiple-dispatching: + if expr.is_scalar: + if isinstance(v, MatrixBase): + result = cls._call_derive_scalar_by_matrix(expr, v) + elif isinstance(v, MatrixExpr): + result = cls._call_derive_scalar_by_matexpr(expr, v) + elif isinstance(v, NDimArray): + result = cls._call_derive_scalar_by_array(expr, v) + elif v.is_scalar: + # scalar by scalar has a special + return super()._dispatch_eval_derivative_n_times(expr, v, count) + else: + return None + elif v.is_scalar: + if isinstance(expr, MatrixBase): + result = cls._call_derive_matrix_by_scalar(expr, v) + elif isinstance(expr, MatrixExpr): + result = cls._call_derive_matexpr_by_scalar(expr, v) + elif isinstance(expr, NDimArray): + result = cls._call_derive_array_by_scalar(expr, v) + else: + return None + else: + # Both `expr` and `v` are some array/matrix type: + if isinstance(expr, MatrixBase) or isinstance(v, MatrixBase): + result = derive_by_array(expr, v) + elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr): + result = cls._call_derive_default(expr, v) + elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr): + # if one expression is a symbolic matrix expression while the other isn't, don't evaluate: + return None + else: + result = derive_by_array(expr, v) + if result is None: + return None + if count == 1: + return result + else: + return cls._dispatch_eval_derivative_n_times(result, v, count - 1) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/arrayop.py b/phivenv/Lib/site-packages/sympy/tensor/array/arrayop.py new file mode 100644 index 0000000000000000000000000000000000000000..a81e6b381a8a93f0cd585278a4be0259b06406dd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/arrayop.py @@ -0,0 +1,528 @@ +import itertools +from collections.abc import Iterable + +from sympy.core._print_helpers import Printable +from sympy.core.containers import Tuple +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.core.sympify import _sympify + +from sympy.tensor.array.ndim_array import NDimArray +from sympy.tensor.array.dense_ndim_array import DenseNDimArray, ImmutableDenseNDimArray +from sympy.tensor.array.sparse_ndim_array import SparseNDimArray + + +def _arrayfy(a): + from sympy.matrices import MatrixBase + + if isinstance(a, NDimArray): + return a + if isinstance(a, (MatrixBase, list, tuple, Tuple)): + return ImmutableDenseNDimArray(a) + return a + + +def tensorproduct(*args): + """ + Tensor product among scalars or array-like objects. + + The equivalent operator for array expressions is ``ArrayTensorProduct``, + which can be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, Array + >>> from sympy.abc import x, y, z, t + >>> A = Array([[1, 2], [3, 4]]) + >>> B = Array([x, y]) + >>> tensorproduct(A, B) + [[[x, y], [2*x, 2*y]], [[3*x, 3*y], [4*x, 4*y]]] + >>> tensorproduct(A, x) + [[x, 2*x], [3*x, 4*x]] + >>> tensorproduct(A, B, B) + [[[[x**2, x*y], [x*y, y**2]], [[2*x**2, 2*x*y], [2*x*y, 2*y**2]]], [[[3*x**2, 3*x*y], [3*x*y, 3*y**2]], [[4*x**2, 4*x*y], [4*x*y, 4*y**2]]]] + + Applying this function on two matrices will result in a rank 4 array. + + >>> from sympy import Matrix, eye + >>> m = Matrix([[x, y], [z, t]]) + >>> p = tensorproduct(eye(3), m) + >>> p + [[[[x, y], [z, t]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[x, y], [z, t]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[x, y], [z, t]]]] + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayTensorProduct + + """ + from sympy.tensor.array import SparseNDimArray, ImmutableSparseNDimArray + + if len(args) == 0: + return S.One + if len(args) == 1: + return _arrayfy(args[0]) + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.matrices.expressions.matexpr import MatrixSymbol + if any(isinstance(arg, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)) for arg in args): + return ArrayTensorProduct(*args) + if len(args) > 2: + return tensorproduct(tensorproduct(args[0], args[1]), *args[2:]) + + # length of args is 2: + a, b = map(_arrayfy, args) + + if not isinstance(a, NDimArray) or not isinstance(b, NDimArray): + return a*b + + if isinstance(a, SparseNDimArray) and isinstance(b, SparseNDimArray): + lp = len(b) + new_array = {k1*lp + k2: v1*v2 for k1, v1 in a._sparse_array.items() for k2, v2 in b._sparse_array.items()} + return ImmutableSparseNDimArray(new_array, a.shape + b.shape) + + product_list = [i*j for i in Flatten(a) for j in Flatten(b)] + return ImmutableDenseNDimArray(product_list, a.shape + b.shape) + + +def _util_contraction_diagonal(array, *contraction_or_diagonal_axes): + array = _arrayfy(array) + + # Verify contraction_axes: + taken_dims = set() + for axes_group in contraction_or_diagonal_axes: + if not isinstance(axes_group, Iterable): + raise ValueError("collections of contraction/diagonal axes expected") + + dim = array.shape[axes_group[0]] + + for d in axes_group: + if d in taken_dims: + raise ValueError("dimension specified more than once") + if dim != array.shape[d]: + raise ValueError("cannot contract or diagonalize between axes of different dimension") + taken_dims.add(d) + + rank = array.rank() + + remaining_shape = [dim for i, dim in enumerate(array.shape) if i not in taken_dims] + cum_shape = [0]*rank + _cumul = 1 + for i in range(rank): + cum_shape[rank - i - 1] = _cumul + _cumul *= int(array.shape[rank - i - 1]) + + # DEFINITION: by absolute position it is meant the position along the one + # dimensional array containing all the tensor components. + + # Possible future work on this module: move computation of absolute + # positions to a class method. + + # Determine absolute positions of the uncontracted indices: + remaining_indices = [[cum_shape[i]*j for j in range(array.shape[i])] + for i in range(rank) if i not in taken_dims] + + # Determine absolute positions of the contracted indices: + summed_deltas = [] + for axes_group in contraction_or_diagonal_axes: + lidx = [] + for js in range(array.shape[axes_group[0]]): + lidx.append(sum(cum_shape[ig] * js for ig in axes_group)) + summed_deltas.append(lidx) + + return array, remaining_indices, remaining_shape, summed_deltas + + +def tensorcontraction(array, *contraction_axes): + """ + Contraction of an array-like object on the specified axes. + + The equivalent operator for array expressions is ``ArrayContraction``, + which can be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy import Array, tensorcontraction + >>> from sympy import Matrix, eye + >>> tensorcontraction(eye(3), (0, 1)) + 3 + >>> A = Array(range(18), (3, 2, 3)) + >>> A + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]] + >>> tensorcontraction(A, (0, 2)) + [21, 30] + + Matrix multiplication may be emulated with a proper combination of + ``tensorcontraction`` and ``tensorproduct`` + + >>> from sympy import tensorproduct + >>> from sympy.abc import a,b,c,d,e,f,g,h + >>> m1 = Matrix([[a, b], [c, d]]) + >>> m2 = Matrix([[e, f], [g, h]]) + >>> p = tensorproduct(m1, m2) + >>> p + [[[[a*e, a*f], [a*g, a*h]], [[b*e, b*f], [b*g, b*h]]], [[[c*e, c*f], [c*g, c*h]], [[d*e, d*f], [d*g, d*h]]]] + >>> tensorcontraction(p, (1, 2)) + [[a*e + b*g, a*f + b*h], [c*e + d*g, c*f + d*h]] + >>> m1*m2 + Matrix([ + [a*e + b*g, a*f + b*h], + [c*e + d*g, c*f + d*h]]) + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayContraction + + """ + from sympy.tensor.array.expressions.array_expressions import _array_contraction + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.matrices.expressions.matexpr import MatrixSymbol + if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _array_contraction(array, *contraction_axes) + + array, remaining_indices, remaining_shape, summed_deltas = _util_contraction_diagonal(array, *contraction_axes) + + # Compute the contracted array: + # + # 1. external for loops on all uncontracted indices. + # Uncontracted indices are determined by the combinatorial product of + # the absolute positions of the remaining indices. + # 2. internal loop on all contracted indices. + # It sums the values of the absolute contracted index and the absolute + # uncontracted index for the external loop. + contracted_array = [] + for icontrib in itertools.product(*remaining_indices): + index_base_position = sum(icontrib) + isum = S.Zero + for sum_to_index in itertools.product(*summed_deltas): + idx = array._get_tuple_index(index_base_position + sum(sum_to_index)) + isum += array[idx] + + contracted_array.append(isum) + + if len(remaining_indices) == 0: + assert len(contracted_array) == 1 + return contracted_array[0] + + return type(array)(contracted_array, remaining_shape) + + +def tensordiagonal(array, *diagonal_axes): + """ + Diagonalization of an array-like object on the specified axes. + + This is equivalent to multiplying the expression by Kronecker deltas + uniting the axes. + + The diagonal indices are put at the end of the axes. + + The equivalent operator for array expressions is ``ArrayDiagonal``, which + can be used to keep the expression unevaluated. + + Examples + ======== + + ``tensordiagonal`` acting on a 2-dimensional array by axes 0 and 1 is + equivalent to the diagonal of the matrix: + + >>> from sympy import Array, tensordiagonal + >>> from sympy import Matrix, eye + >>> tensordiagonal(eye(3), (0, 1)) + [1, 1, 1] + + >>> from sympy.abc import a,b,c,d + >>> m1 = Matrix([[a, b], [c, d]]) + >>> tensordiagonal(m1, [0, 1]) + [a, d] + + In case of higher dimensional arrays, the diagonalized out dimensions + are appended removed and appended as a single dimension at the end: + + >>> A = Array(range(18), (3, 2, 3)) + >>> A + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]] + >>> tensordiagonal(A, (0, 2)) + [[0, 7, 14], [3, 10, 17]] + >>> from sympy import permutedims + >>> tensordiagonal(A, (0, 2)) == permutedims(Array([A[0, :, 0], A[1, :, 1], A[2, :, 2]]), [1, 0]) + True + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayDiagonal + + """ + if any(len(i) <= 1 for i in diagonal_axes): + raise ValueError("need at least two axes to diagonalize") + + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, _array_diagonal + from sympy.matrices.expressions.matexpr import MatrixSymbol + if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _array_diagonal(array, *diagonal_axes) + + ArrayDiagonal._validate(array, *diagonal_axes) + + array, remaining_indices, remaining_shape, diagonal_deltas = _util_contraction_diagonal(array, *diagonal_axes) + + # Compute the diagonalized array: + # + # 1. external for loops on all undiagonalized indices. + # Undiagonalized indices are determined by the combinatorial product of + # the absolute positions of the remaining indices. + # 2. internal loop on all diagonal indices. + # It appends the values of the absolute diagonalized index and the absolute + # undiagonalized index for the external loop. + diagonalized_array = [] + diagonal_shape = [len(i) for i in diagonal_deltas] + for icontrib in itertools.product(*remaining_indices): + index_base_position = sum(icontrib) + isum = [] + for sum_to_index in itertools.product(*diagonal_deltas): + idx = array._get_tuple_index(index_base_position + sum(sum_to_index)) + isum.append(array[idx]) + + isum = type(array)(isum).reshape(*diagonal_shape) + diagonalized_array.append(isum) + + return type(array)(diagonalized_array, remaining_shape + diagonal_shape) + + +def derive_by_array(expr, dx): + r""" + Derivative by arrays. Supports both arrays and scalars. + + The equivalent operator for array expressions is ``array_derive``. + + Explanation + =========== + + Given the array `A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}` + this function will return a new array `B` defined by + + `B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}` + + Examples + ======== + + >>> from sympy import derive_by_array + >>> from sympy.abc import x, y, z, t + >>> from sympy import cos + >>> derive_by_array(cos(x*t), x) + -t*sin(t*x) + >>> derive_by_array(cos(x*t), [x, y, z, t]) + [-t*sin(t*x), 0, 0, -x*sin(t*x)] + >>> derive_by_array([x, y**2*z], [[x, y], [z, t]]) + [[[1, 0], [0, 2*y*z]], [[0, y**2], [0, 0]]] + + """ + from sympy.matrices import MatrixBase + from sympy.tensor.array import SparseNDimArray + array_types = (Iterable, MatrixBase, NDimArray) + + if isinstance(dx, array_types): + dx = ImmutableDenseNDimArray(dx) + for i in dx: + if not i._diff_wrt: + raise ValueError("cannot derive by this array") + + if isinstance(expr, array_types): + if isinstance(expr, NDimArray): + expr = expr.as_immutable() + else: + expr = ImmutableDenseNDimArray(expr) + + if isinstance(dx, array_types): + if isinstance(expr, SparseNDimArray): + lp = len(expr) + new_array = {k + i*lp: v + for i, x in enumerate(Flatten(dx)) + for k, v in expr.diff(x)._sparse_array.items()} + else: + new_array = [[y.diff(x) for y in Flatten(expr)] for x in Flatten(dx)] + return type(expr)(new_array, dx.shape + expr.shape) + else: + return expr.diff(dx) + else: + expr = _sympify(expr) + if isinstance(dx, array_types): + return ImmutableDenseNDimArray([expr.diff(i) for i in Flatten(dx)], dx.shape) + else: + dx = _sympify(dx) + return diff(expr, dx) + + +def permutedims(expr, perm=None, index_order_old=None, index_order_new=None): + """ + Permutes the indices of an array. + + Parameter specifies the permutation of the indices. + + The equivalent operator for array expressions is ``PermuteDims``, which can + be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy.abc import x, y, z, t + >>> from sympy import sin + >>> from sympy import Array, permutedims + >>> a = Array([[x, y, z], [t, sin(x), 0]]) + >>> a + [[x, y, z], [t, sin(x), 0]] + >>> permutedims(a, (1, 0)) + [[x, t], [y, sin(x)], [z, 0]] + + If the array is of second order, ``transpose`` can be used: + + >>> from sympy import transpose + >>> transpose(a) + [[x, t], [y, sin(x)], [z, 0]] + + Examples on higher dimensions: + + >>> b = Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + >>> permutedims(b, (2, 1, 0)) + [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] + >>> permutedims(b, (1, 2, 0)) + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + An alternative way to specify the same permutations as in the previous + lines involves passing the *old* and *new* indices, either as a list or as + a string: + + >>> permutedims(b, index_order_old="cba", index_order_new="abc") + [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] + >>> permutedims(b, index_order_old="cab", index_order_new="abc") + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + ``Permutation`` objects are also allowed: + + >>> from sympy.combinatorics import Permutation + >>> permutedims(b, Permutation([1, 2, 0])) + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.PermuteDims + + """ + from sympy.tensor.array import SparseNDimArray + + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import _permute_dims + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.tensor.array.expressions import PermuteDims + from sympy.tensor.array.expressions.array_expressions import get_rank + perm = PermuteDims._get_permutation_from_arguments(perm, index_order_old, index_order_new, get_rank(expr)) + if isinstance(expr, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _permute_dims(expr, perm) + + if not isinstance(expr, NDimArray): + expr = ImmutableDenseNDimArray(expr) + + from sympy.combinatorics import Permutation + if not isinstance(perm, Permutation): + perm = Permutation(list(perm)) + + if perm.size != expr.rank(): + raise ValueError("wrong permutation size") + + # Get the inverse permutation: + iperm = ~perm + new_shape = perm(expr.shape) + + if isinstance(expr, SparseNDimArray): + return type(expr)({tuple(perm(expr._get_tuple_index(k))): v + for k, v in expr._sparse_array.items()}, new_shape) + + indices_span = perm([range(i) for i in expr.shape]) + + new_array = [None]*len(expr) + for i, idx in enumerate(itertools.product(*indices_span)): + t = iperm(idx) + new_array[i] = expr[t] + + return type(expr)(new_array, new_shape) + + +class Flatten(Printable): + """ + Flatten an iterable object to a list in a lazy-evaluation way. + + Notes + ===== + + This class is an iterator with which the memory cost can be economised. + Optimisation has been considered to ameliorate the performance for some + specific data types like DenseNDimArray and SparseNDimArray. + + Examples + ======== + + >>> from sympy.tensor.array.arrayop import Flatten + >>> from sympy.tensor.array import Array + >>> A = Array(range(6)).reshape(2, 3) + >>> Flatten(A) + Flatten([[0, 1, 2], [3, 4, 5]]) + >>> [i for i in Flatten(A)] + [0, 1, 2, 3, 4, 5] + """ + def __init__(self, iterable): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import NDimArray + + if not isinstance(iterable, (Iterable, MatrixBase)): + raise NotImplementedError("Data type not yet supported") + + if isinstance(iterable, list): + iterable = NDimArray(iterable) + + self._iter = iterable + self._idx = 0 + + def __iter__(self): + return self + + def __next__(self): + from sympy.matrices.matrixbase import MatrixBase + + if len(self._iter) > self._idx: + if isinstance(self._iter, DenseNDimArray): + result = self._iter._array[self._idx] + + elif isinstance(self._iter, SparseNDimArray): + if self._idx in self._iter._sparse_array: + result = self._iter._sparse_array[self._idx] + else: + result = 0 + + elif isinstance(self._iter, MatrixBase): + result = self._iter[self._idx] + + elif hasattr(self._iter, '__next__'): + result = next(self._iter) + + else: + result = self._iter[self._idx] + + else: + raise StopIteration + + self._idx += 1 + return result + + def next(self): + return self.__next__() + + def _sympystr(self, printer): + return type(self).__name__ + '(' + printer._print(self._iter) + ')' diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/dense_ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/dense_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..576e452c55d8d374ca1f72c553f3a64de7227d43 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/dense_ndim_array.py @@ -0,0 +1,206 @@ +import functools +from typing import List + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.tensor.array.mutable_ndim_array import MutableNDimArray +from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind +from sympy.utilities.iterables import flatten + + +class DenseNDimArray(NDimArray): + + _array: List[Basic] + + def __new__(self, *args, **kwargs): + return ImmutableDenseNDimArray(*args, **kwargs) + + @property + def kind(self) -> ArrayKind: + return ArrayKind._union(self._array) + + def __getitem__(self, index): + """ + Allows to get items from N-dim array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2)) + >>> a + [[0, 1], [2, 3]] + >>> a[0, 0] + 0 + >>> a[1, 1] + 3 + >>> a[0] + [0, 1] + >>> a[1] + [2, 3] + + + Symbolic index: + + >>> from sympy.abc import i, j + >>> a[i, j] + [[0, 1], [2, 3]][i, j] + + Replace `i` and `j` to get element `(1, 1)`: + + >>> a[i, j].subs({i: 1, j: 1}) + 3 + + """ + syindex = self._check_symbolic_index(index) + if syindex is not None: + return syindex + + index = self._check_index_for_getitem(index) + + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + sl_factors, eindices = self._get_slice_data_for_array_access(index) + array = [self._array[self._parse_index(i)] for i in eindices] + nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)] + return type(self)(array, nshape) + else: + index = self._parse_index(index) + return self._array[index] + + @classmethod + def zeros(cls, *shape): + list_length = functools.reduce(lambda x, y: x*y, shape, S.One) + return cls._new(([0]*list_length,), shape) + + def tomatrix(self): + """ + Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3)) + >>> b = a.tomatrix() + >>> b + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + + """ + from sympy.matrices import Matrix + + if self.rank() != 2: + raise ValueError('Dimensions must be of size of 2') + + return Matrix(self.shape[0], self.shape[1], self._array) + + def reshape(self, *newshape): + """ + Returns MutableDenseNDimArray instance with new shape. Elements number + must be suitable to new shape. The only argument of method sets + new shape. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3)) + >>> a.shape + (2, 3) + >>> a + [[1, 2, 3], [4, 5, 6]] + >>> b = a.reshape(3, 2) + >>> b.shape + (3, 2) + >>> b + [[1, 2], [3, 4], [5, 6]] + + """ + new_total_size = functools.reduce(lambda x,y: x*y, newshape) + if new_total_size != self._loop_size: + raise ValueError('Expecting reshape size to %d but got prod(%s) = %d' % ( + self._loop_size, str(newshape), new_total_size)) + + # there is no `.func` as this class does not subtype `Basic`: + return type(self)(self._array, newshape) + + +class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore + def __new__(cls, iterable, shape=None, **kwargs): + return cls._new(iterable, shape, **kwargs) + + @classmethod + def _new(cls, iterable, shape, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + shape = Tuple(*map(_sympify, shape)) + cls._check_special_bounds(flat_list, shape) + flat_list = flatten(flat_list) + flat_list = Tuple(*flat_list) + self = Basic.__new__(cls, flat_list, shape, **kwargs) + self._shape = shape + self._array = list(flat_list) + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape, 1) + return self + + def __setitem__(self, index, value): + raise TypeError('immutable N-dim array') + + def as_mutable(self): + return MutableDenseNDimArray(self) + + def _eval_simplify(self, **kwargs): + from sympy.simplify.simplify import simplify + return self.applyfunc(simplify) + +class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray): + + def __new__(cls, iterable=None, shape=None, **kwargs): + return cls._new(iterable, shape, **kwargs) + + @classmethod + def _new(cls, iterable, shape, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + flat_list = flatten(flat_list) + self = object.__new__(cls) + self._shape = shape + self._array = list(flat_list) + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + return self + + def __setitem__(self, index, value): + """Allows to set items to MutableDenseNDimArray. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 2) + >>> a[0,0] = 1 + >>> a[1,1] = 1 + >>> a + [[1, 0], [0, 1]] + + """ + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value) + for i in eindices: + other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None] + self._array[self._parse_index(i)] = value[other_i] + else: + index = self._parse_index(index) + self._setter_iterable_check(value) + value = _sympify(value) + self._array[index] = value + + def as_immutable(self): + return ImmutableDenseNDimArray(self) + + @property + def free_symbols(self): + return {i for j in self._array for i in j.free_symbols} diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__init__.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1658241782cdf0e38a30c43a6d67f9811297f4c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__init__.py @@ -0,0 +1,178 @@ +r""" +Array expressions are expressions representing N-dimensional arrays, without +evaluating them. These expressions represent in a certain way abstract syntax +trees of operations on N-dimensional arrays. + +Every N-dimensional array operator has a corresponding array expression object. + +Table of correspondences: + +=============================== ============================= + Array operator Array expression operator +=============================== ============================= + tensorproduct ArrayTensorProduct + tensorcontraction ArrayContraction + tensordiagonal ArrayDiagonal + permutedims PermuteDims +=============================== ============================= + +Examples +======== + +``ArraySymbol`` objects are the N-dimensional equivalent of ``MatrixSymbol`` +objects in the matrix module: + +>>> from sympy.tensor.array.expressions import ArraySymbol +>>> from sympy.abc import i, j, k +>>> A = ArraySymbol("A", (3, 2, 4)) +>>> A.shape +(3, 2, 4) +>>> A[i, j, k] +A[i, j, k] +>>> A.as_explicit() +[[[A[0, 0, 0], A[0, 0, 1], A[0, 0, 2], A[0, 0, 3]], + [A[0, 1, 0], A[0, 1, 1], A[0, 1, 2], A[0, 1, 3]]], + [[A[1, 0, 0], A[1, 0, 1], A[1, 0, 2], A[1, 0, 3]], + [A[1, 1, 0], A[1, 1, 1], A[1, 1, 2], A[1, 1, 3]]], + [[A[2, 0, 0], A[2, 0, 1], A[2, 0, 2], A[2, 0, 3]], + [A[2, 1, 0], A[2, 1, 1], A[2, 1, 2], A[2, 1, 3]]]] + +Component-explicit arrays can be added inside array expressions: + +>>> from sympy import Array +>>> from sympy import tensorproduct +>>> from sympy.tensor.array.expressions import ArrayTensorProduct +>>> a = Array([1, 2, 3]) +>>> b = Array([i, j, k]) +>>> expr = ArrayTensorProduct(a, b, b) +>>> expr +ArrayTensorProduct([1, 2, 3], [i, j, k], [i, j, k]) +>>> expr.as_explicit() == tensorproduct(a, b, b) +True + +Constructing array expressions from index-explicit forms +-------------------------------------------------------- + +Array expressions are index-implicit. This means they do not use any indices to +represent array operations. The function ``convert_indexed_to_array( ... )`` +may be used to convert index-explicit expressions to array expressions. +It takes as input two parameters: the index-explicit expression and the order +of the indices: + +>>> from sympy.tensor.array.expressions import convert_indexed_to_array +>>> from sympy import Sum +>>> A = ArraySymbol("A", (3, 3)) +>>> B = ArraySymbol("B", (3, 3)) +>>> convert_indexed_to_array(A[i, j], [i, j]) +A +>>> convert_indexed_to_array(A[i, j], [j, i]) +PermuteDims(A, (0 1)) +>>> convert_indexed_to_array(A[i, j] + B[j, i], [i, j]) +ArrayAdd(A, PermuteDims(B, (0 1))) +>>> convert_indexed_to_array(Sum(A[i, j]*B[j, k], (j, 0, 2)), [i, k]) +ArrayContraction(ArrayTensorProduct(A, B), (1, 2)) + +The diagonal of a matrix in the array expression form: + +>>> convert_indexed_to_array(A[i, i], [i]) +ArrayDiagonal(A, (0, 1)) + +The trace of a matrix in the array expression form: + +>>> convert_indexed_to_array(Sum(A[i, i], (i, 0, 2)), [i]) +ArrayContraction(A, (0, 1)) + +Compatibility with matrices +--------------------------- + +Array expressions can be mixed with objects from the matrix module: + +>>> from sympy import MatrixSymbol +>>> from sympy.tensor.array.expressions import ArrayContraction +>>> M = MatrixSymbol("M", 3, 3) +>>> N = MatrixSymbol("N", 3, 3) + +Express the matrix product in the array expression form: + +>>> from sympy.tensor.array.expressions import convert_matrix_to_array +>>> expr = convert_matrix_to_array(M*N) +>>> expr +ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + +The expression can be converted back to matrix form: + +>>> from sympy.tensor.array.expressions import convert_array_to_matrix +>>> convert_array_to_matrix(expr) +M*N + +Add a second contraction on the remaining axes in order to get the trace of `M \cdot N`: + +>>> expr_tr = ArrayContraction(expr, (0, 1)) +>>> expr_tr +ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (0, 1)) + +Flatten the expression by calling ``.doit()`` and remove the nested array contraction operations: + +>>> expr_tr.doit() +ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2)) + +Get the explicit form of the array expression: + +>>> expr.as_explicit() +[[M[0, 0]*N[0, 0] + M[0, 1]*N[1, 0] + M[0, 2]*N[2, 0], M[0, 0]*N[0, 1] + M[0, 1]*N[1, 1] + M[0, 2]*N[2, 1], M[0, 0]*N[0, 2] + M[0, 1]*N[1, 2] + M[0, 2]*N[2, 2]], + [M[1, 0]*N[0, 0] + M[1, 1]*N[1, 0] + M[1, 2]*N[2, 0], M[1, 0]*N[0, 1] + M[1, 1]*N[1, 1] + M[1, 2]*N[2, 1], M[1, 0]*N[0, 2] + M[1, 1]*N[1, 2] + M[1, 2]*N[2, 2]], + [M[2, 0]*N[0, 0] + M[2, 1]*N[1, 0] + M[2, 2]*N[2, 0], M[2, 0]*N[0, 1] + M[2, 1]*N[1, 1] + M[2, 2]*N[2, 1], M[2, 0]*N[0, 2] + M[2, 1]*N[1, 2] + M[2, 2]*N[2, 2]]] + +Express the trace of a matrix: + +>>> from sympy import Trace +>>> convert_matrix_to_array(Trace(M)) +ArrayContraction(M, (0, 1)) +>>> convert_matrix_to_array(Trace(M*N)) +ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2)) + +Express the transposition of a matrix (will be expressed as a permutation of the axes: + +>>> convert_matrix_to_array(M.T) +PermuteDims(M, (0 1)) + +Compute the derivative array expressions: + +>>> from sympy.tensor.array.expressions import array_derive +>>> d = array_derive(M, M) +>>> d +PermuteDims(ArrayTensorProduct(I, I), (3)(1 2)) + +Verify that the derivative corresponds to the form computed with explicit matrices: + +>>> d.as_explicit() +[[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]] +>>> Me = M.as_explicit() +>>> Me.diff(Me) +[[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]] + +""" + +__all__ = [ + "ArraySymbol", "ArrayElement", "ZeroArray", "OneArray", + "ArrayTensorProduct", + "ArrayContraction", + "ArrayDiagonal", + "PermuteDims", + "ArrayAdd", + "ArrayElementwiseApplyFunc", + "Reshape", + "convert_array_to_matrix", + "convert_matrix_to_array", + "convert_array_to_indexed", + "convert_indexed_to_array", + "array_derive", +] + +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, PermuteDims, ArrayDiagonal, \ + ArrayContraction, Reshape, ArraySymbol, ArrayElement, ZeroArray, OneArray, ArrayElementwiseApplyFunc +from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive +from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27ec0b75219b1ea3acd7b32116037e04a0144dac Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/array_expressions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/array_expressions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..579f4ee04736e01ab5ab6501784f15c67c13536e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/array_expressions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac253d08707745cde049e717a7366fe8d673f1b8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/arrayexpr_derivatives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad8f4e958faee409dd226655c3837834f7cfe878 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_indexed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f5f54f6fd68ebb470608ceae2642681ae67bf0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_array_to_matrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1c9a6630ba67008ad3b32bc5adf4d02067fdbac Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_indexed_to_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f475bbf5b22df89e2d47be78340c0a9f7fe66767 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/conv_matrix_to_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50023ea73cb4ee17a313320d72201db3ac3c134d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_indexed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bfd27e253ec35e7cd49251b9c68a589e74e3871 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_array_to_matrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2694a697b805f216591f03821c3c00afc1e3ea9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_indexed_to_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c142f4e590ce97cc5fe95c7757d0195f148cb643 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/from_matrix_to_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6128e74e803e1844e252a238fe9ca3d34d4b255f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/array_expressions.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/array_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..f062e3de4c24987d62ba0b3a19fe474fb4687940 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/array_expressions.py @@ -0,0 +1,1969 @@ +from __future__ import annotations +import collections.abc +import operator +from collections import defaultdict, Counter +from functools import reduce +import itertools +from itertools import accumulate + +import typing + +from sympy.core.numbers import Integer +from sympy.core.relational import Equality +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import (Dummy, Symbol) +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.diagonal import diagonalize_vector +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.tensor.array.ndim_array import NDimArray +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.utils import _apply_recursively_over_nested_lists, _sort_contraction_indices, \ + _get_mapping_from_subranks, _build_push_indices_up_func_transformation, _get_contraction_links, \ + _build_push_indices_down_func_transformation +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import _af_invert +from sympy.core.sympify import _sympify + + +class _ArrayExpr(Expr): + shape: tuple[Expr, ...] + + def __getitem__(self, item): + if not isinstance(item, collections.abc.Iterable): + item = (item,) + ArrayElement._check_shape(self, item) + return self._get(item) + + def _get(self, item): + return _get_array_element_or_slice(self, item) + + +class ArraySymbol(_ArrayExpr): + """ + Symbol representing an array expression + """ + + _iterable = False + + def __new__(cls, symbol, shape: typing.Iterable) -> "ArraySymbol": + if isinstance(symbol, str): + symbol = Symbol(symbol) + # symbol = _sympify(symbol) + shape = Tuple(*map(_sympify, shape)) + obj = Expr.__new__(cls, symbol, shape) + return obj + + @property + def name(self): + return self._args[0] + + @property + def shape(self): + return self._args[1] + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("cannot express explicit array with symbolic shape") + data = [self[i] for i in itertools.product(*[range(j) for j in self.shape])] + return ImmutableDenseNDimArray(data).reshape(*self.shape) + + +class ArrayElement(Expr): + """ + An element of an array. + """ + + _diff_wrt = True + is_symbol = True + is_commutative = True + + def __new__(cls, name, indices): + if isinstance(name, str): + name = Symbol(name) + name = _sympify(name) + if not isinstance(indices, collections.abc.Iterable): + indices = (indices,) + indices = _sympify(tuple(indices)) + cls._check_shape(name, indices) + obj = Expr.__new__(cls, name, indices) + return obj + + @classmethod + def _check_shape(cls, name, indices): + indices = tuple(indices) + if hasattr(name, "shape"): + index_error = IndexError("number of indices does not match shape of the array") + if len(indices) != len(name.shape): + raise index_error + if any((i >= s) == True for i, s in zip(indices, name.shape)): + raise ValueError("shape is out of bounds") + if any((i < 0) == True for i in indices): + raise ValueError("shape contains negative values") + + @property + def name(self): + return self._args[0] + + @property + def indices(self): + return self._args[1] + + def _eval_derivative(self, s): + if not isinstance(s, ArrayElement): + return S.Zero + + if s == self: + return S.One + + if s.name != self.name: + return S.Zero + + return Mul.fromiter(KroneckerDelta(i, j) for i, j in zip(self.indices, s.indices)) + + +class ZeroArray(_ArrayExpr): + """ + Symbolic array of zeros. Equivalent to ``ZeroMatrix`` for matrices. + """ + + def __new__(cls, *shape): + if len(shape) == 0: + return S.Zero + shape = map(_sympify, shape) + obj = Expr.__new__(cls, *shape) + return obj + + @property + def shape(self): + return self._args + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("Cannot return explicit form for symbolic shape.") + return ImmutableDenseNDimArray.zeros(*self.shape) + + def _get(self, item): + return S.Zero + + +class OneArray(_ArrayExpr): + """ + Symbolic array of ones. + """ + + def __new__(cls, *shape): + if len(shape) == 0: + return S.One + shape = map(_sympify, shape) + obj = Expr.__new__(cls, *shape) + return obj + + @property + def shape(self): + return self._args + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("Cannot return explicit form for symbolic shape.") + return ImmutableDenseNDimArray([S.One for i in range(reduce(operator.mul, self.shape))]).reshape(*self.shape) + + def _get(self, item): + return S.One + + +class _CodegenArrayAbstract(Basic): + + @property + def subranks(self): + """ + Returns the ranks of the objects in the uppermost tensor product inside + the current object. In case no tensor products are contained, return + the atomic ranks. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> N = MatrixSymbol("N", 3, 3) + >>> P = MatrixSymbol("P", 3, 3) + + Important: do not confuse the rank of the matrix with the rank of an array. + + >>> tp = tensorproduct(M, N, P) + >>> tp.subranks + [2, 2, 2] + + >>> co = tensorcontraction(tp, (1, 2), (3, 4)) + >>> co.subranks + [2, 2, 2] + """ + return self._subranks[:] + + def subrank(self): + """ + The sum of ``subranks``. + """ + return sum(self.subranks) + + @property + def shape(self): + return self._shape + + def doit(self, **hints): + deep = hints.get("deep", True) + if deep: + return self.func(*[arg.doit(**hints) for arg in self.args])._canonicalize() + else: + return self._canonicalize() + +class ArrayTensorProduct(_CodegenArrayAbstract): + r""" + Class to represent the tensor product of array-like objects. + """ + + def __new__(cls, *args, **kwargs): + args = [_sympify(arg) for arg in args] + + canonicalize = kwargs.pop("canonicalize", False) + + ranks = [get_rank(arg) for arg in args] + + obj = Basic.__new__(cls, *args) + obj._subranks = ranks + shapes = [get_shape(i) for i in args] + + if any(i is None for i in shapes): + obj._shape = None + else: + obj._shape = tuple(j for i in shapes for j in i) + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + args = self.args + args = self._flatten(args) + + ranks = [get_rank(arg) for arg in args] + + # Check if there are nested permutation and lift them up: + permutation_cycles = [] + for i, arg in enumerate(args): + if not isinstance(arg, PermuteDims): + continue + permutation_cycles.extend([[k + sum(ranks[:i]) for k in j] for j in arg.permutation.cyclic_form]) + args[i] = arg.expr + if permutation_cycles: + return _permute_dims(_array_tensor_product(*args), Permutation(sum(ranks)-1)*Permutation(permutation_cycles)) + + if len(args) == 1: + return args[0] + + # If any object is a ZeroArray, return a ZeroArray: + if any(isinstance(arg, (ZeroArray, ZeroMatrix)) for arg in args): + shapes = reduce(operator.add, [get_shape(i) for i in args], ()) + return ZeroArray(*shapes) + + # If there are contraction objects inside, transform the whole + # expression into `ArrayContraction`: + contractions = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayContraction)} + if contractions: + ranks = [_get_subrank(arg) if isinstance(arg, ArrayContraction) else get_rank(arg) for arg in args] + cumulative_ranks = list(accumulate([0] + ranks))[:-1] + tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayContraction) else arg for arg in args]) + contraction_indices = [tuple(cumulative_ranks[i] + k for k in j) for i, arg in contractions.items() for j in arg.contraction_indices] + return _array_contraction(tp, *contraction_indices) + + diagonals = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayDiagonal)} + if diagonals: + inverse_permutation = [] + last_perm = [] + ranks = [get_rank(arg) for arg in args] + cumulative_ranks = list(accumulate([0] + ranks))[:-1] + for i, arg in enumerate(args): + if isinstance(arg, ArrayDiagonal): + i1 = get_rank(arg) - len(arg.diagonal_indices) + i2 = len(arg.diagonal_indices) + inverse_permutation.extend([cumulative_ranks[i] + j for j in range(i1)]) + last_perm.extend([cumulative_ranks[i] + j for j in range(i1, i1 + i2)]) + else: + inverse_permutation.extend([cumulative_ranks[i] + j for j in range(get_rank(arg))]) + inverse_permutation.extend(last_perm) + tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayDiagonal) else arg for arg in args]) + ranks2 = [_get_subrank(arg) if isinstance(arg, ArrayDiagonal) else get_rank(arg) for arg in args] + cumulative_ranks2 = list(accumulate([0] + ranks2))[:-1] + diagonal_indices = [tuple(cumulative_ranks2[i] + k for k in j) for i, arg in diagonals.items() for j in arg.diagonal_indices] + return _permute_dims(_array_diagonal(tp, *diagonal_indices), _af_invert(inverse_permutation)) + + return self.func(*args, canonicalize=False) + + @classmethod + def _flatten(cls, args): + args = [i for arg in args for i in (arg.args if isinstance(arg, cls) else [arg])] + return args + + def as_explicit(self): + return tensorproduct(*[arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) + + +class ArrayAdd(_CodegenArrayAbstract): + r""" + Class for elementwise array additions. + """ + + def __new__(cls, *args, **kwargs): + args = [_sympify(arg) for arg in args] + ranks = [get_rank(arg) for arg in args] + ranks = list(set(ranks)) + if len(ranks) != 1: + raise ValueError("summing arrays of different ranks") + shapes = [arg.shape for arg in args] + if len({i for i in shapes if i is not None}) > 1: + raise ValueError("mismatching shapes in addition") + + canonicalize = kwargs.pop("canonicalize", False) + + obj = Basic.__new__(cls, *args) + obj._subranks = ranks + if any(i is None for i in shapes): + obj._shape = None + else: + obj._shape = shapes[0] + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + args = self.args + + # Flatten: + args = self._flatten_args(args) + + shapes = [get_shape(arg) for arg in args] + args = [arg for arg in args if not isinstance(arg, (ZeroArray, ZeroMatrix))] + if len(args) == 0: + if any(i for i in shapes if i is None): + raise NotImplementedError("cannot handle addition of ZeroMatrix/ZeroArray and undefined shape object") + return ZeroArray(*shapes[0]) + elif len(args) == 1: + return args[0] + return self.func(*args, canonicalize=False) + + @classmethod + def _flatten_args(cls, args): + new_args = [] + for arg in args: + if isinstance(arg, ArrayAdd): + new_args.extend(arg.args) + else: + new_args.append(arg) + return new_args + + def as_explicit(self): + return reduce( + operator.add, + [arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) + + +class PermuteDims(_CodegenArrayAbstract): + r""" + Class to represent permutation of axes of arrays. + + Examples + ======== + + >>> from sympy.tensor.array import permutedims + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> cg = permutedims(M, [1, 0]) + + The object ``cg`` represents the transposition of ``M``, as the permutation + ``[1, 0]`` will act on its indices by switching them: + + `M_{ij} \Rightarrow M_{ji}` + + This is evident when transforming back to matrix form: + + >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + >>> convert_array_to_matrix(cg) + M.T + + >>> N = MatrixSymbol("N", 3, 2) + >>> cg = permutedims(N, [1, 0]) + >>> cg.shape + (2, 3) + + There are optional parameters that can be used as alternative to the permutation: + + >>> from sympy.tensor.array.expressions import ArraySymbol, PermuteDims + >>> M = ArraySymbol("M", (1, 2, 3, 4, 5)) + >>> expr = PermuteDims(M, index_order_old="ijklm", index_order_new="kijml") + >>> expr + PermuteDims(M, (0 2 1)(3 4)) + >>> expr.shape + (3, 1, 2, 5, 4) + + Permutations of tensor products are simplified in order to achieve a + standard form: + + >>> from sympy.tensor.array import tensorproduct + >>> M = MatrixSymbol("M", 4, 5) + >>> tp = tensorproduct(M, N) + >>> tp.shape + (4, 5, 3, 2) + >>> perm1 = permutedims(tp, [2, 3, 1, 0]) + + The args ``(M, N)`` have been sorted and the permutation has been + simplified, the expression is equivalent: + + >>> perm1.expr.args + (N, M) + >>> perm1.shape + (3, 2, 5, 4) + >>> perm1.permutation + (2 3) + + The permutation in its array form has been simplified from + ``[2, 3, 1, 0]`` to ``[0, 1, 3, 2]``, as the arguments of the tensor + product `M` and `N` have been switched: + + >>> perm1.permutation.array_form + [0, 1, 3, 2] + + We can nest a second permutation: + + >>> perm2 = permutedims(perm1, [1, 0, 2, 3]) + >>> perm2.shape + (2, 3, 5, 4) + >>> perm2.permutation.array_form + [1, 0, 3, 2] + """ + + def __new__(cls, expr, permutation=None, index_order_old=None, index_order_new=None, **kwargs): + from sympy.combinatorics import Permutation + expr = _sympify(expr) + expr_rank = get_rank(expr) + permutation = cls._get_permutation_from_arguments(permutation, index_order_old, index_order_new, expr_rank) + permutation = Permutation(permutation) + permutation_size = permutation.size + if permutation_size != expr_rank: + raise ValueError("Permutation size must be the length of the shape of expr") + + canonicalize = kwargs.pop("canonicalize", False) + + obj = Basic.__new__(cls, expr, permutation) + obj._subranks = [get_rank(expr)] + shape = get_shape(expr) + if shape is None: + obj._shape = None + else: + obj._shape = tuple(shape[permutation(i)] for i in range(len(shape))) + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + permutation = self.permutation + if isinstance(expr, PermuteDims): + subexpr = expr.expr + subperm = expr.permutation + permutation = permutation * subperm + expr = subexpr + if isinstance(expr, ArrayContraction): + expr, permutation = self._PermuteDims_denestarg_ArrayContraction(expr, permutation) + if isinstance(expr, ArrayTensorProduct): + expr, permutation = self._PermuteDims_denestarg_ArrayTensorProduct(expr, permutation) + if isinstance(expr, (ZeroArray, ZeroMatrix)): + return ZeroArray(*[expr.shape[i] for i in permutation.array_form]) + plist = permutation.array_form + if plist == sorted(plist): + return expr + return self.func(expr, permutation, canonicalize=False) + + @property + def expr(self): + return self.args[0] + + @property + def permutation(self): + return self.args[1] + + @classmethod + def _PermuteDims_denestarg_ArrayTensorProduct(cls, expr, permutation): + # Get the permutation in its image-form: + perm_image_form = _af_invert(permutation.array_form) + args = list(expr.args) + # Starting index global position for every arg: + cumul = list(accumulate([0] + expr.subranks)) + # Split `perm_image_form` into a list of list corresponding to the indices + # of every argument: + perm_image_form_in_components = [perm_image_form[cumul[i]:cumul[i+1]] for i in range(len(args))] + # Create an index, target-position-key array: + ps = [(i, sorted(comp)) for i, comp in enumerate(perm_image_form_in_components)] + # Sort the array according to the target-position-key: + # In this way, we define a canonical way to sort the arguments according + # to the permutation. + ps.sort(key=lambda x: x[1]) + # Read the inverse-permutation (i.e. image-form) of the args: + perm_args_image_form = [i[0] for i in ps] + # Apply the args-permutation to the `args`: + args_sorted = [args[i] for i in perm_args_image_form] + # Apply the args-permutation to the array-form of the permutation of the axes (of `expr`): + perm_image_form_sorted_args = [perm_image_form_in_components[i] for i in perm_args_image_form] + new_permutation = Permutation(_af_invert([j for i in perm_image_form_sorted_args for j in i])) + return _array_tensor_product(*args_sorted), new_permutation + + @classmethod + def _PermuteDims_denestarg_ArrayContraction(cls, expr, permutation): + if not isinstance(expr, ArrayContraction): + return expr, permutation + if not isinstance(expr.expr, ArrayTensorProduct): + return expr, permutation + args = expr.expr.args + subranks = [get_rank(arg) for arg in expr.expr.args] + + contraction_indices = expr.contraction_indices + contraction_indices_flat = [j for i in contraction_indices for j in i] + cumul = list(accumulate([0] + subranks)) + + # Spread the permutation in its array form across the args in the corresponding + # tensor-product arguments with free indices: + permutation_array_blocks_up = [] + image_form = _af_invert(permutation.array_form) + counter = 0 + for i in range(len(subranks)): + current = [] + for j in range(cumul[i], cumul[i+1]): + if j in contraction_indices_flat: + continue + current.append(image_form[counter]) + counter += 1 + permutation_array_blocks_up.append(current) + + # Get the map of axis repositioning for every argument of tensor-product: + index_blocks = [list(range(cumul[i], cumul[i+1])) for i, e in enumerate(expr.subranks)] + index_blocks_up = expr._push_indices_up(expr.contraction_indices, index_blocks) + inverse_permutation = permutation**(-1) + index_blocks_up_permuted = [[inverse_permutation(j) for j in i if j is not None] for i in index_blocks_up] + + # Sorting key is a list of tuple, first element is the index of `args`, second element of + # the tuple is the sorting key to sort `args` of the tensor product: + sorting_keys = list(enumerate(index_blocks_up_permuted)) + sorting_keys.sort(key=lambda x: x[1]) + + # Now we can get the permutation acting on the args in its image-form: + new_perm_image_form = [i[0] for i in sorting_keys] + # Apply the args-level permutation to various elements: + new_index_blocks = [index_blocks[i] for i in new_perm_image_form] + new_index_perm_array_form = _af_invert([j for i in new_index_blocks for j in i]) + new_args = [args[i] for i in new_perm_image_form] + new_contraction_indices = [tuple(new_index_perm_array_form[j] for j in i) for i in contraction_indices] + new_expr = _array_contraction(_array_tensor_product(*new_args), *new_contraction_indices) + new_permutation = Permutation(_af_invert([j for i in [permutation_array_blocks_up[k] for k in new_perm_image_form] for j in i])) + return new_expr, new_permutation + + @classmethod + def _check_permutation_mapping(cls, expr, permutation): + subranks = expr.subranks + index2arg = [i for i, arg in enumerate(expr.args) for j in range(expr.subranks[i])] + permuted_indices = [permutation(i) for i in range(expr.subrank())] + new_args = list(expr.args) + arg_candidate_index = index2arg[permuted_indices[0]] + current_indices = [] + new_permutation = [] + inserted_arg_cand_indices = set() + for i, idx in enumerate(permuted_indices): + if index2arg[idx] != arg_candidate_index: + new_permutation.extend(current_indices) + current_indices = [] + arg_candidate_index = index2arg[idx] + current_indices.append(idx) + arg_candidate_rank = subranks[arg_candidate_index] + if len(current_indices) == arg_candidate_rank: + new_permutation.extend(sorted(current_indices)) + local_current_indices = [j - min(current_indices) for j in current_indices] + i1 = index2arg[i] + new_args[i1] = _permute_dims(new_args[i1], Permutation(local_current_indices)) + inserted_arg_cand_indices.add(arg_candidate_index) + current_indices = [] + new_permutation.extend(current_indices) + + # TODO: swap args positions in order to simplify the expression: + # TODO: this should be in a function + args_positions = list(range(len(new_args))) + # Get possible shifts: + maps = {} + cumulative_subranks = [0] + list(accumulate(subranks)) + for i in range(len(subranks)): + s = {index2arg[new_permutation[j]] for j in range(cumulative_subranks[i], cumulative_subranks[i+1])} + if len(s) != 1: + continue + elem = next(iter(s)) + if i != elem: + maps[i] = elem + + # Find cycles in the map: + lines = [] + current_line = [] + while maps: + if len(current_line) == 0: + k, v = maps.popitem() + current_line.append(k) + else: + k = current_line[-1] + if k not in maps: + current_line = [] + continue + v = maps.pop(k) + if v in current_line: + lines.append(current_line) + current_line = [] + continue + current_line.append(v) + for line in lines: + for i, e in enumerate(line): + args_positions[line[(i + 1) % len(line)]] = e + + # TODO: function in order to permute the args: + permutation_blocks = [[new_permutation[cumulative_subranks[i] + j] for j in range(e)] for i, e in enumerate(subranks)] + new_args = [new_args[i] for i in args_positions] + new_permutation_blocks = [permutation_blocks[i] for i in args_positions] + new_permutation2 = [j for i in new_permutation_blocks for j in i] + return _array_tensor_product(*new_args), Permutation(new_permutation2) # **(-1) + + @classmethod + def _check_if_there_are_closed_cycles(cls, expr, permutation): + args = list(expr.args) + subranks = expr.subranks + cyclic_form = permutation.cyclic_form + cumulative_subranks = [0] + list(accumulate(subranks)) + cyclic_min = [min(i) for i in cyclic_form] + cyclic_max = [max(i) for i in cyclic_form] + cyclic_keep = [] + for i, cycle in enumerate(cyclic_form): + flag = True + for j in range(len(cumulative_subranks) - 1): + if cyclic_min[i] >= cumulative_subranks[j] and cyclic_max[i] < cumulative_subranks[j+1]: + # Found a sinkable cycle. + args[j] = _permute_dims(args[j], Permutation([[k - cumulative_subranks[j] for k in cycle]])) + flag = False + break + if flag: + cyclic_keep.append(cycle) + return _array_tensor_product(*args), Permutation(cyclic_keep, size=permutation.size) + + def nest_permutation(self): + r""" + DEPRECATED. + """ + ret = self._nest_permutation(self.expr, self.permutation) + if ret is None: + return self + return ret + + @classmethod + def _nest_permutation(cls, expr, permutation): + if isinstance(expr, ArrayTensorProduct): + return _permute_dims(*cls._check_if_there_are_closed_cycles(expr, permutation)) + elif isinstance(expr, ArrayContraction): + # Invert tree hierarchy: put the contraction above. + cycles = permutation.cyclic_form + newcycles = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *cycles) + newpermutation = Permutation(newcycles) + new_contr_indices = [tuple(newpermutation(j) for j in i) for i in expr.contraction_indices] + return _array_contraction(PermuteDims(expr.expr, newpermutation), *new_contr_indices) + elif isinstance(expr, ArrayAdd): + return _array_add(*[PermuteDims(arg, permutation) for arg in expr.args]) + return None + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return permutedims(expr, self.permutation) + + @classmethod + def _get_permutation_from_arguments(cls, permutation, index_order_old, index_order_new, dim): + if permutation is None: + if index_order_new is None or index_order_old is None: + raise ValueError("Permutation not defined") + return PermuteDims._get_permutation_from_index_orders(index_order_old, index_order_new, dim) + else: + if index_order_new is not None: + raise ValueError("index_order_new cannot be defined with permutation") + if index_order_old is not None: + raise ValueError("index_order_old cannot be defined with permutation") + return permutation + + @classmethod + def _get_permutation_from_index_orders(cls, index_order_old, index_order_new, dim): + if len(set(index_order_new)) != dim: + raise ValueError("wrong number of indices in index_order_new") + if len(set(index_order_old)) != dim: + raise ValueError("wrong number of indices in index_order_old") + if len(set.symmetric_difference(set(index_order_new), set(index_order_old))) > 0: + raise ValueError("index_order_new and index_order_old must have the same indices") + permutation = [index_order_old.index(i) for i in index_order_new] + return permutation + + +class ArrayDiagonal(_CodegenArrayAbstract): + r""" + Class to represent the diagonal operator. + + Explanation + =========== + + In a 2-dimensional array it returns the diagonal, this looks like the + operation: + + `A_{ij} \rightarrow A_{ii}` + + The diagonal over axes 1 and 2 (the second and third) of the tensor product + of two 2-dimensional arrays `A \otimes B` is + + `\Big[ A_{ab} B_{cd} \Big]_{abcd} \rightarrow \Big[ A_{ai} B_{id} \Big]_{adi}` + + In this last example the array expression has been reduced from + 4-dimensional to 3-dimensional. Notice that no contraction has occurred, + rather there is a new index `i` for the diagonal, contraction would have + reduced the array to 2 dimensions. + + Notice that the diagonalized out dimensions are added as new dimensions at + the end of the indices. + """ + + def __new__(cls, expr, *diagonal_indices, **kwargs): + expr = _sympify(expr) + diagonal_indices = [Tuple(*sorted(i)) for i in diagonal_indices] + canonicalize = kwargs.get("canonicalize", False) + + shape = get_shape(expr) + if shape is not None: + cls._validate(expr, *diagonal_indices, **kwargs) + # Get new shape: + positions, shape = cls._get_positions_shape(shape, diagonal_indices) + else: + positions = None + if len(diagonal_indices) == 0: + return expr + obj = Basic.__new__(cls, expr, *diagonal_indices) + obj._positions = positions + obj._subranks = _get_subranks(expr) + obj._shape = shape + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + diagonal_indices = self.diagonal_indices + trivial_diags = [i for i in diagonal_indices if len(i) == 1] + if len(trivial_diags) > 0: + trivial_pos = {e[0]: i for i, e in enumerate(diagonal_indices) if len(e) == 1} + diag_pos = {e: i for i, e in enumerate(diagonal_indices) if len(e) > 1} + diagonal_indices_short = [i for i in diagonal_indices if len(i) > 1] + rank1 = get_rank(self) + rank2 = len(diagonal_indices) + rank3 = rank1 - rank2 + inv_permutation = [] + counter1 = 0 + indices_down = ArrayDiagonal._push_indices_down(diagonal_indices_short, list(range(rank1)), get_rank(expr)) + for i in indices_down: + if i in trivial_pos: + inv_permutation.append(rank3 + trivial_pos[i]) + elif isinstance(i, (Integer, int)): + inv_permutation.append(counter1) + counter1 += 1 + else: + inv_permutation.append(rank3 + diag_pos[i]) + permutation = _af_invert(inv_permutation) + if len(diagonal_indices_short) > 0: + return _permute_dims(_array_diagonal(expr, *diagonal_indices_short), permutation) + else: + return _permute_dims(expr, permutation) + if isinstance(expr, ArrayAdd): + return self._ArrayDiagonal_denest_ArrayAdd(expr, *diagonal_indices) + if isinstance(expr, ArrayDiagonal): + return self._ArrayDiagonal_denest_ArrayDiagonal(expr, *diagonal_indices) + if isinstance(expr, PermuteDims): + return self._ArrayDiagonal_denest_PermuteDims(expr, *diagonal_indices) + if isinstance(expr, (ZeroArray, ZeroMatrix)): + positions, shape = self._get_positions_shape(expr.shape, diagonal_indices) + return ZeroArray(*shape) + return self.func(expr, *diagonal_indices, canonicalize=False) + + @staticmethod + def _validate(expr, *diagonal_indices, **kwargs): + # Check that no diagonalization happens on indices with mismatched + # dimensions: + shape = get_shape(expr) + for i in diagonal_indices: + if any(j >= len(shape) for j in i): + raise ValueError("index is larger than expression shape") + if len({shape[j] for j in i}) != 1: + raise ValueError("diagonalizing indices of different dimensions") + if not kwargs.get("allow_trivial_diags", False) and len(i) <= 1: + raise ValueError("need at least two axes to diagonalize") + if len(set(i)) != len(i): + raise ValueError("axis index cannot be repeated") + + @staticmethod + def _remove_trivial_dimensions(shape, *diagonal_indices): + return [tuple(j for j in i) for i in diagonal_indices if shape[i[0]] != 1] + + @property + def expr(self): + return self.args[0] + + @property + def diagonal_indices(self): + return self.args[1:] + + @staticmethod + def _flatten(expr, *outer_diagonal_indices): + inner_diagonal_indices = expr.diagonal_indices + all_inner = [j for i in inner_diagonal_indices for j in i] + all_inner.sort() + # TODO: add API for total rank and cumulative rank: + total_rank = _get_subrank(expr) + inner_rank = len(all_inner) + outer_rank = total_rank - inner_rank + shifts = [0 for i in range(outer_rank)] + counter = 0 + pointer = 0 + for i in range(outer_rank): + while pointer < inner_rank and counter >= all_inner[pointer]: + counter += 1 + pointer += 1 + shifts[i] += pointer + counter += 1 + outer_diagonal_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_diagonal_indices) + diagonal_indices = inner_diagonal_indices + outer_diagonal_indices + return _array_diagonal(expr.expr, *diagonal_indices) + + @classmethod + def _ArrayDiagonal_denest_ArrayAdd(cls, expr, *diagonal_indices): + return _array_add(*[_array_diagonal(arg, *diagonal_indices) for arg in expr.args]) + + @classmethod + def _ArrayDiagonal_denest_ArrayDiagonal(cls, expr, *diagonal_indices): + return cls._flatten(expr, *diagonal_indices) + + @classmethod + def _ArrayDiagonal_denest_PermuteDims(cls, expr: PermuteDims, *diagonal_indices): + back_diagonal_indices = [[expr.permutation(j) for j in i] for i in diagonal_indices] + nondiag = [i for i in range(get_rank(expr)) if not any(i in j for j in diagonal_indices)] + back_nondiag = [expr.permutation(i) for i in nondiag] + remap = {e: i for i, e in enumerate(sorted(back_nondiag))} + new_permutation1 = [remap[i] for i in back_nondiag] + shift = len(new_permutation1) + diag_block_perm = [i + shift for i in range(len(back_diagonal_indices))] + new_permutation = new_permutation1 + diag_block_perm + return _permute_dims( + _array_diagonal( + expr.expr, + *back_diagonal_indices + ), + new_permutation + ) + + def _push_indices_down_nonstatic(self, indices): + transform = lambda x: self._positions[x] if x < len(self._positions) else None + return _apply_recursively_over_nested_lists(transform, indices) + + def _push_indices_up_nonstatic(self, indices): + + def transform(x): + for i, e in enumerate(self._positions): + if (isinstance(e, int) and x == e) or (isinstance(e, tuple) and x in e): + return i + + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_down(cls, diagonal_indices, indices, rank): + positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) + transform = lambda x: positions[x] if x < len(positions) else None + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_up(cls, diagonal_indices, indices, rank): + positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) + + def transform(x): + for i, e in enumerate(positions): + if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and (x in e)): + return i + + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _get_positions_shape(cls, shape, diagonal_indices): + data1 = tuple((i, shp) for i, shp in enumerate(shape) if not any(i in j for j in diagonal_indices)) + pos1, shp1 = zip(*data1) if data1 else ((), ()) + data2 = tuple((i, shape[i[0]]) for i in diagonal_indices) + pos2, shp2 = zip(*data2) if data2 else ((), ()) + positions = pos1 + pos2 + shape = shp1 + shp2 + return positions, shape + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return tensordiagonal(expr, *self.diagonal_indices) + + +class ArrayElementwiseApplyFunc(_CodegenArrayAbstract): + + def __new__(cls, function, element): + + if not isinstance(function, Lambda): + d = Dummy('d') + function = Lambda(d, function(d)) + + obj = _CodegenArrayAbstract.__new__(cls, function, element) + obj._subranks = _get_subranks(element) + return obj + + @property + def function(self): + return self.args[0] + + @property + def expr(self): + return self.args[1] + + @property + def shape(self): + return self.expr.shape + + def _get_function_fdiff(self): + d = Dummy("d") + function = self.function(d) + fdiff = function.diff(d) + if isinstance(fdiff, Function): + fdiff = type(fdiff) + else: + fdiff = Lambda(d, fdiff) + return fdiff + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return expr.applyfunc(self.function) + + +class ArrayContraction(_CodegenArrayAbstract): + r""" + This class is meant to represent contractions of arrays in a form easily + processable by the code printers. + """ + + def __new__(cls, expr, *contraction_indices, **kwargs): + contraction_indices = _sort_contraction_indices(contraction_indices) + expr = _sympify(expr) + + canonicalize = kwargs.get("canonicalize", False) + + obj = Basic.__new__(cls, expr, *contraction_indices) + obj._subranks = _get_subranks(expr) + obj._mapping = _get_mapping_from_subranks(obj._subranks) + + free_indices_to_position = {i: i for i in range(sum(obj._subranks)) if all(i not in cind for cind in contraction_indices)} + obj._free_indices_to_position = free_indices_to_position + + shape = get_shape(expr) + cls._validate(expr, *contraction_indices) + if shape: + shape = tuple(shp for i, shp in enumerate(shape) if not any(i in j for j in contraction_indices)) + obj._shape = shape + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + contraction_indices = self.contraction_indices + + if len(contraction_indices) == 0: + return expr + + if isinstance(expr, ArrayContraction): + return self._ArrayContraction_denest_ArrayContraction(expr, *contraction_indices) + + if isinstance(expr, (ZeroArray, ZeroMatrix)): + return self._ArrayContraction_denest_ZeroArray(expr, *contraction_indices) + + if isinstance(expr, PermuteDims): + return self._ArrayContraction_denest_PermuteDims(expr, *contraction_indices) + + if isinstance(expr, ArrayTensorProduct): + expr, contraction_indices = self._sort_fully_contracted_args(expr, contraction_indices) + expr, contraction_indices = self._lower_contraction_to_addends(expr, contraction_indices) + if len(contraction_indices) == 0: + return expr + + if isinstance(expr, ArrayDiagonal): + return self._ArrayContraction_denest_ArrayDiagonal(expr, *contraction_indices) + + if isinstance(expr, ArrayAdd): + return self._ArrayContraction_denest_ArrayAdd(expr, *contraction_indices) + + # Check single index contractions on 1-dimensional axes: + contraction_indices = [i for i in contraction_indices if len(i) > 1 or get_shape(expr)[i[0]] != 1] + if len(contraction_indices) == 0: + return expr + + return self.func(expr, *contraction_indices, canonicalize=False) + + def __mul__(self, other): + if other == 1: + return self + else: + raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") + + def __rmul__(self, other): + if other == 1: + return self + else: + raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") + + @staticmethod + def _validate(expr, *contraction_indices): + shape = get_shape(expr) + if shape is None: + return + + # Check that no contraction happens when the shape is mismatched: + for i in contraction_indices: + if len({shape[j] for j in i if shape[j] != -1}) != 1: + raise ValueError("contracting indices of different dimensions") + + @classmethod + def _push_indices_down(cls, contraction_indices, indices): + flattened_contraction_indices = [j for i in contraction_indices for j in i] + flattened_contraction_indices.sort() + transform = _build_push_indices_down_func_transformation(flattened_contraction_indices) + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_up(cls, contraction_indices, indices): + flattened_contraction_indices = [j for i in contraction_indices for j in i] + flattened_contraction_indices.sort() + transform = _build_push_indices_up_func_transformation(flattened_contraction_indices) + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _lower_contraction_to_addends(cls, expr, contraction_indices): + if isinstance(expr, ArrayAdd): + raise NotImplementedError() + if not isinstance(expr, ArrayTensorProduct): + return expr, contraction_indices + subranks = expr.subranks + cumranks = list(accumulate([0] + subranks)) + contraction_indices_remaining = [] + contraction_indices_args = [[] for i in expr.args] + backshift = set() + for contraction_group in contraction_indices: + for j in range(len(expr.args)): + if not isinstance(expr.args[j], ArrayAdd): + continue + if all(cumranks[j] <= k < cumranks[j+1] for k in contraction_group): + contraction_indices_args[j].append([k - cumranks[j] for k in contraction_group]) + backshift.update(contraction_group) + break + else: + contraction_indices_remaining.append(contraction_group) + if len(contraction_indices_remaining) == len(contraction_indices): + return expr, contraction_indices + total_rank = get_rank(expr) + shifts = list(accumulate([1 if i in backshift else 0 for i in range(total_rank)])) + contraction_indices_remaining = [Tuple.fromiter(j - shifts[j] for j in i) for i in contraction_indices_remaining] + ret = _array_tensor_product(*[ + _array_contraction(arg, *contr) for arg, contr in zip(expr.args, contraction_indices_args) + ]) + return ret, contraction_indices_remaining + + def split_multiple_contractions(self): + """ + Recognize multiple contractions and attempt at rewriting them as paired-contractions. + + This allows some contractions involving more than two indices to be + rewritten as multiple contractions involving two indices, thus allowing + the expression to be rewritten as a matrix multiplication line. + + Examples: + + * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C` + + Care for: + - matrix being diagonalized (i.e. `A_ii`) + - vectors being diagonalized (i.e. `a_i0`) + + Multiple contractions can be split into matrix multiplications if + not more than two arguments are non-diagonals or non-vectors. + Vectors get diagonalized while diagonal matrices remain diagonal. + The non-diagonal matrices can be at the beginning or at the end + of the final matrix multiplication line. + """ + + editor = _EditArrayContraction(self) + + contraction_indices = self.contraction_indices + + onearray_insert = [] + + for indl, links in enumerate(contraction_indices): + if len(links) <= 2: + continue + + # Check multiple contractions: + # + # Examples: + # + # * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C \otimes OneArray(1)` with permutation (1 2) + # + # Care for: + # - matrix being diagonalized (i.e. `A_ii`) + # - vectors being diagonalized (i.e. `a_i0`) + + # Multiple contractions can be split into matrix multiplications if + # not more than three arguments are non-diagonals or non-vectors. + # + # Vectors get diagonalized while diagonal matrices remain diagonal. + # The non-diagonal matrices can be at the beginning or at the end + # of the final matrix multiplication line. + + positions = editor.get_mapping_for_index(indl) + + # Also consider the case of diagonal matrices being contracted: + current_dimension = self.expr.shape[links[0]] + + not_vectors = [] + vectors = [] + for arg_ind, rel_ind in positions: + arg = editor.args_with_ind[arg_ind] + mat = arg.element + abs_arg_start, abs_arg_end = editor.get_absolute_range(arg) + other_arg_pos = 1-rel_ind + other_arg_abs = abs_arg_start + other_arg_pos + if ((1 not in mat.shape) or + ((current_dimension == 1) is True and mat.shape != (1, 1)) or + any(other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl) + ): + not_vectors.append((arg, rel_ind)) + else: + vectors.append((arg, rel_ind)) + if len(not_vectors) > 2: + # If more than two arguments in the multiple contraction are + # non-vectors and non-diagonal matrices, we cannot find a way + # to split this contraction into a matrix multiplication line: + continue + # Three cases to handle: + # - zero non-vectors + # - one non-vector + # - two non-vectors + for v, rel_ind in vectors: + v.element = diagonalize_vector(v.element) + vectors_to_loop = not_vectors[:1] + vectors + not_vectors[1:] + first_not_vector, rel_ind = vectors_to_loop[0] + new_index = first_not_vector.indices[rel_ind] + + for v, rel_ind in vectors_to_loop[1:-1]: + v.indices[rel_ind] = new_index + new_index = editor.get_new_contraction_index() + assert v.indices.index(None) == 1 - rel_ind + v.indices[v.indices.index(None)] = new_index + onearray_insert.append(v) + + last_vec, rel_ind = vectors_to_loop[-1] + last_vec.indices[rel_ind] = new_index + + for v in onearray_insert: + editor.insert_after(v, _ArgE(OneArray(1), [None])) + + return editor.to_array_contraction() + + def flatten_contraction_of_diagonal(self): + if not isinstance(self.expr, ArrayDiagonal): + return self + contraction_down = self.expr._push_indices_down(self.expr.diagonal_indices, self.contraction_indices) + new_contraction_indices = [] + diagonal_indices = self.expr.diagonal_indices[:] + for i in contraction_down: + contraction_group = list(i) + for j in i: + diagonal_with = [k for k in diagonal_indices if j in k] + contraction_group.extend([l for k in diagonal_with for l in k]) + diagonal_indices = [k for k in diagonal_indices if k not in diagonal_with] + new_contraction_indices.append(sorted(set(contraction_group))) + + new_contraction_indices = ArrayDiagonal._push_indices_up(diagonal_indices, new_contraction_indices) + return _array_contraction( + _array_diagonal( + self.expr.expr, + *diagonal_indices + ), + *new_contraction_indices + ) + + @staticmethod + def _get_free_indices_to_position_map(free_indices, contraction_indices): + free_indices_to_position = {} + flattened_contraction_indices = [j for i in contraction_indices for j in i] + counter = 0 + for ind in free_indices: + while counter in flattened_contraction_indices: + counter += 1 + free_indices_to_position[ind] = counter + counter += 1 + return free_indices_to_position + + @staticmethod + def _get_index_shifts(expr): + """ + Get the mapping of indices at the positions before the contraction + occurs. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> N = MatrixSymbol("N", 3, 3) + >>> cg = tensorcontraction(tensorproduct(M, N), [1, 2]) + >>> cg._get_index_shifts(cg) + [0, 2] + + Indeed, ``cg`` after the contraction has two dimensions, 0 and 1. They + need to be shifted by 0 and 2 to get the corresponding positions before + the contraction (that is, 0 and 3). + """ + inner_contraction_indices = expr.contraction_indices + all_inner = [j for i in inner_contraction_indices for j in i] + all_inner.sort() + # TODO: add API for total rank and cumulative rank: + total_rank = _get_subrank(expr) + inner_rank = len(all_inner) + outer_rank = total_rank - inner_rank + shifts = [0 for i in range(outer_rank)] + counter = 0 + pointer = 0 + for i in range(outer_rank): + while pointer < inner_rank and counter >= all_inner[pointer]: + counter += 1 + pointer += 1 + shifts[i] += pointer + counter += 1 + return shifts + + @staticmethod + def _convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices): + shifts = ArrayContraction._get_index_shifts(expr) + outer_contraction_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_contraction_indices) + return outer_contraction_indices + + @staticmethod + def _flatten(expr, *outer_contraction_indices): + inner_contraction_indices = expr.contraction_indices + outer_contraction_indices = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices) + contraction_indices = inner_contraction_indices + outer_contraction_indices + return _array_contraction(expr.expr, *contraction_indices) + + @classmethod + def _ArrayContraction_denest_ArrayContraction(cls, expr, *contraction_indices): + return cls._flatten(expr, *contraction_indices) + + @classmethod + def _ArrayContraction_denest_ZeroArray(cls, expr, *contraction_indices): + contraction_indices_flat = [j for i in contraction_indices for j in i] + shape = [e for i, e in enumerate(expr.shape) if i not in contraction_indices_flat] + return ZeroArray(*shape) + + @classmethod + def _ArrayContraction_denest_ArrayAdd(cls, expr, *contraction_indices): + return _array_add(*[_array_contraction(i, *contraction_indices) for i in expr.args]) + + @classmethod + def _ArrayContraction_denest_PermuteDims(cls, expr, *contraction_indices): + permutation = expr.permutation + plist = permutation.array_form + new_contraction_indices = [tuple(permutation(j) for j in i) for i in contraction_indices] + new_plist = [i for i in plist if not any(i in j for j in new_contraction_indices)] + new_plist = cls._push_indices_up(new_contraction_indices, new_plist) + return _permute_dims( + _array_contraction(expr.expr, *new_contraction_indices), + Permutation(new_plist) + ) + + @classmethod + def _ArrayContraction_denest_ArrayDiagonal(cls, expr: 'ArrayDiagonal', *contraction_indices): + diagonal_indices = list(expr.diagonal_indices) + down_contraction_indices = expr._push_indices_down(expr.diagonal_indices, contraction_indices, get_rank(expr.expr)) + # Flatten diagonally contracted indices: + down_contraction_indices = [[k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])] for i in down_contraction_indices] + new_contraction_indices = [] + for contr_indgrp in down_contraction_indices: + ind = contr_indgrp[:] + for j, diag_indgrp in enumerate(diagonal_indices): + if diag_indgrp is None: + continue + if any(i in diag_indgrp for i in contr_indgrp): + ind.extend(diag_indgrp) + diagonal_indices[j] = None + new_contraction_indices.append(sorted(set(ind))) + + new_diagonal_indices_down = [i for i in diagonal_indices if i is not None] + new_diagonal_indices = ArrayContraction._push_indices_up(new_contraction_indices, new_diagonal_indices_down) + return _array_diagonal( + _array_contraction(expr.expr, *new_contraction_indices), + *new_diagonal_indices + ) + + @classmethod + def _sort_fully_contracted_args(cls, expr, contraction_indices): + if expr.shape is None: + return expr, contraction_indices + cumul = list(accumulate([0] + expr.subranks)) + index_blocks = [list(range(cumul[i], cumul[i+1])) for i in range(len(expr.args))] + contraction_indices_flat = {j for i in contraction_indices for j in i} + fully_contracted = [all(j in contraction_indices_flat for j in range(cumul[i], cumul[i+1])) for i, arg in enumerate(expr.args)] + new_pos = sorted(range(len(expr.args)), key=lambda x: (0, default_sort_key(expr.args[x])) if fully_contracted[x] else (1,)) + new_args = [expr.args[i] for i in new_pos] + new_index_blocks_flat = [j for i in new_pos for j in index_blocks[i]] + index_permutation_array_form = _af_invert(new_index_blocks_flat) + new_contraction_indices = [tuple(index_permutation_array_form[j] for j in i) for i in contraction_indices] + new_contraction_indices = _sort_contraction_indices(new_contraction_indices) + return _array_tensor_product(*new_args), new_contraction_indices + + def _get_contraction_tuples(self): + r""" + Return tuples containing the argument index and position within the + argument of the index position. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + + >>> cg = tensorcontraction(tensorproduct(A, B), (1, 2)) + >>> cg._get_contraction_tuples() + [[(0, 1), (1, 0)]] + + Notes + ===== + + Here the contraction pair `(1, 2)` meaning that the 2nd and 3rd indices + of the tensor product `A\otimes B` are contracted, has been transformed + into `(0, 1)` and `(1, 0)`, identifying the same indices in a different + notation. `(0, 1)` is the second index (1) of the first argument (i.e. + 0 or `A`). `(1, 0)` is the first index (i.e. 0) of the second + argument (i.e. 1 or `B`). + """ + mapping = self._mapping + return [[mapping[j] for j in i] for i in self.contraction_indices] + + @staticmethod + def _contraction_tuples_to_contraction_indices(expr, contraction_tuples): + # TODO: check that `expr` has `.subranks`: + ranks = expr.subranks + cumulative_ranks = [0] + list(accumulate(ranks)) + return [tuple(cumulative_ranks[j]+k for j, k in i) for i in contraction_tuples] + + @property + def free_indices(self): + return self._free_indices[:] + + @property + def free_indices_to_position(self): + return dict(self._free_indices_to_position) + + @property + def expr(self): + return self.args[0] + + @property + def contraction_indices(self): + return self.args[1:] + + def _contraction_indices_to_components(self): + expr = self.expr + if not isinstance(expr, ArrayTensorProduct): + raise NotImplementedError("only for contractions of tensor products") + ranks = expr.subranks + mapping = {} + counter = 0 + for i, rank in enumerate(ranks): + for j in range(rank): + mapping[counter] = (i, j) + counter += 1 + return mapping + + def sort_args_by_name(self): + """ + Sort arguments in the tensor product so that their order is lexicographical. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + >>> cg = convert_matrix_to_array(C*D*A*B) + >>> cg + ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5)) + >>> cg.sort_args_by_name() + ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7)) + """ + expr = self.expr + if not isinstance(expr, ArrayTensorProduct): + return self + args = expr.args + sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1])) + pos_sorted, args_sorted = zip(*sorted_data) + reordering_map = {i: pos_sorted.index(i) for i, arg in enumerate(args)} + contraction_tuples = self._get_contraction_tuples() + contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples] + c_tp = _array_tensor_product(*args_sorted) + new_contr_indices = self._contraction_tuples_to_contraction_indices( + c_tp, + contraction_tuples + ) + return _array_contraction(c_tp, *new_contr_indices) + + def _get_contraction_links(self): + r""" + Returns a dictionary of links between arguments in the tensor product + being contracted. + + See the example for an explanation of the values. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + Matrix multiplications are pairwise contractions between neighboring + matrices: + + `A_{ij} B_{jk} C_{kl} D_{lm}` + + >>> cg = convert_matrix_to_array(A*B*C*D) + >>> cg + ArrayContraction(ArrayTensorProduct(B, C, A, D), (0, 5), (1, 2), (3, 6)) + + >>> cg._get_contraction_links() + {0: {0: (2, 1), 1: (1, 0)}, 1: {0: (0, 1), 1: (3, 0)}, 2: {1: (0, 0)}, 3: {0: (1, 1)}} + + This dictionary is interpreted as follows: argument in position 0 (i.e. + matrix `A`) has its second index (i.e. 1) contracted to `(1, 0)`, that + is argument in position 1 (matrix `B`) on the first index slot of `B`, + this is the contraction provided by the index `j` from `A`. + + The argument in position 1 (that is, matrix `B`) has two contractions, + the ones provided by the indices `j` and `k`, respectively the first + and second indices (0 and 1 in the sub-dict). The link `(0, 1)` and + `(2, 0)` respectively. `(0, 1)` is the index slot 1 (the 2nd) of + argument in position 0 (that is, `A_{\ldot j}`), and so on. + """ + args, dlinks = _get_contraction_links([self], self.subranks, *self.contraction_indices) + return dlinks + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return tensorcontraction(expr, *self.contraction_indices) + + +class Reshape(_CodegenArrayAbstract): + """ + Reshape the dimensions of an array expression. + + Examples + ======== + + >>> from sympy.tensor.array.expressions import ArraySymbol, Reshape + >>> A = ArraySymbol("A", (6,)) + >>> A.shape + (6,) + >>> Reshape(A, (3, 2)).shape + (3, 2) + + Check the component-explicit forms: + + >>> A.as_explicit() + [A[0], A[1], A[2], A[3], A[4], A[5]] + >>> Reshape(A, (3, 2)).as_explicit() + [[A[0], A[1]], [A[2], A[3]], [A[4], A[5]]] + + """ + + def __new__(cls, expr, shape): + expr = _sympify(expr) + if not isinstance(shape, Tuple): + shape = Tuple(*shape) + if Equality(Mul.fromiter(expr.shape), Mul.fromiter(shape)) == False: + raise ValueError("shape mismatch") + obj = Expr.__new__(cls, expr, shape) + obj._shape = tuple(shape) + obj._expr = expr + return obj + + @property + def shape(self): + return self._shape + + @property + def expr(self): + return self._expr + + def doit(self, *args, **kwargs): + if kwargs.get("deep", True): + expr = self.expr.doit(*args, **kwargs) + else: + expr = self.expr + if isinstance(expr, (MatrixBase, NDimArray)): + return expr.reshape(*self.shape) + return Reshape(expr, self.shape) + + def as_explicit(self): + ee = self.expr + if hasattr(ee, "as_explicit"): + ee = ee.as_explicit() + if isinstance(ee, MatrixBase): + from sympy import Array + ee = Array(ee) + elif isinstance(ee, MatrixExpr): + return self + return ee.reshape(*self.shape) + + +class _ArgE: + """ + The ``_ArgE`` object contains references to the array expression + (``.element``) and a list containing the information about index + contractions (``.indices``). + + Index contractions are numbered and contracted indices show the number of + the contraction. Uncontracted indices have ``None`` value. + + For example: + ``_ArgE(M, [None, 3])`` + This object means that expression ``M`` is part of an array contraction + and has two indices, the first is not contracted (value ``None``), + the second index is contracted to the 4th (i.e. number ``3``) group of the + array contraction object. + """ + indices: list[int | None] + + def __init__(self, element, indices: list[int | None] | None = None): + self.element = element + if indices is None: + self.indices = [None for i in range(get_rank(element))] + else: + self.indices = indices + + def __str__(self): + return "_ArgE(%s, %s)" % (self.element, self.indices) + + __repr__ = __str__ + + +class _IndPos: + """ + Index position, requiring two integers in the constructor: + + - arg: the position of the argument in the tensor product, + - rel: the relative position of the index inside the argument. + """ + def __init__(self, arg: int, rel: int): + self.arg = arg + self.rel = rel + + def __str__(self): + return "_IndPos(%i, %i)" % (self.arg, self.rel) + + __repr__ = __str__ + + def __iter__(self): + yield from [self.arg, self.rel] + + +class _EditArrayContraction: + """ + Utility class to help manipulate array contraction objects. + + This class takes as input an ``ArrayContraction`` object and turns it into + an editable object. + + The field ``args_with_ind`` of this class is a list of ``_ArgE`` objects + which can be used to easily edit the contraction structure of the + expression. + + Once editing is finished, the ``ArrayContraction`` object may be recreated + by calling the ``.to_array_contraction()`` method. + """ + + def __init__(self, base_array: typing.Union[ArrayContraction, ArrayDiagonal, ArrayTensorProduct]): + + expr: Basic + diagonalized: tuple[tuple[int, ...], ...] + contraction_indices: list[tuple[int]] + if isinstance(base_array, ArrayContraction): + mapping = _get_mapping_from_subranks(base_array.subranks) + expr = base_array.expr + contraction_indices = base_array.contraction_indices + diagonalized = () + elif isinstance(base_array, ArrayDiagonal): + + if isinstance(base_array.expr, ArrayContraction): + mapping = _get_mapping_from_subranks(base_array.expr.subranks) + expr = base_array.expr.expr + diagonalized = ArrayContraction._push_indices_down(base_array.expr.contraction_indices, base_array.diagonal_indices) + contraction_indices = base_array.expr.contraction_indices + elif isinstance(base_array.expr, ArrayTensorProduct): + mapping = {} + expr = base_array.expr + diagonalized = base_array.diagonal_indices + contraction_indices = [] + else: + mapping = {} + expr = base_array.expr + diagonalized = base_array.diagonal_indices + contraction_indices = [] + + elif isinstance(base_array, ArrayTensorProduct): + expr = base_array + contraction_indices = [] + diagonalized = () + else: + raise NotImplementedError() + + if isinstance(expr, ArrayTensorProduct): + args = list(expr.args) + else: + args = [expr] + + args_with_ind: list[_ArgE] = [_ArgE(arg) for arg in args] + for i, contraction_tuple in enumerate(contraction_indices): + for j in contraction_tuple: + arg_pos, rel_pos = mapping[j] + args_with_ind[arg_pos].indices[rel_pos] = i + self.args_with_ind: list[_ArgE] = args_with_ind + self.number_of_contraction_indices: int = len(contraction_indices) + self._track_permutation: list[list[int]] | None = None + + mapping = _get_mapping_from_subranks(base_array.subranks) + + # Trick: add diagonalized indices as negative indices into the editor object: + for i, e in enumerate(diagonalized): + for j in e: + arg_pos, rel_pos = mapping[j] + self.args_with_ind[arg_pos].indices[rel_pos] = -1 - i + + def insert_after(self, arg: _ArgE, new_arg: _ArgE): + pos = self.args_with_ind.index(arg) + self.args_with_ind.insert(pos + 1, new_arg) + + def get_new_contraction_index(self): + self.number_of_contraction_indices += 1 + return self.number_of_contraction_indices - 1 + + def refresh_indices(self): + updates = {} + for arg_with_ind in self.args_with_ind: + updates.update({i: -1 for i in arg_with_ind.indices if i is not None}) + for i, e in enumerate(sorted(updates)): + updates[e] = i + self.number_of_contraction_indices = len(updates) + for arg_with_ind in self.args_with_ind: + arg_with_ind.indices = [updates.get(i, None) for i in arg_with_ind.indices] + + def merge_scalars(self): + scalars = [] + for arg_with_ind in self.args_with_ind: + if len(arg_with_ind.indices) == 0: + scalars.append(arg_with_ind) + for i in scalars: + self.args_with_ind.remove(i) + scalar = Mul.fromiter([i.element for i in scalars]) + if len(self.args_with_ind) == 0: + self.args_with_ind.append(_ArgE(scalar)) + else: + from sympy.tensor.array.expressions.from_array_to_matrix import _a2m_tensor_product + self.args_with_ind[0].element = _a2m_tensor_product(scalar, self.args_with_ind[0].element) + + def to_array_contraction(self): + + # Count the ranks of the arguments: + counter = 0 + # Create a collector for the new diagonal indices: + diag_indices = defaultdict(list) + + count_index_freq = Counter() + for arg_with_ind in self.args_with_ind: + count_index_freq.update(Counter(arg_with_ind.indices)) + + free_index_count = count_index_freq[None] + + # Construct the inverse permutation: + inv_perm1 = [] + inv_perm2 = [] + # Keep track of which diagonal indices have already been processed: + done = set() + + # Counter for the diagonal indices: + counter4 = 0 + + for arg_with_ind in self.args_with_ind: + # If some diagonalization axes have been removed, they should be + # permuted in order to keep the permutation. + # Add permutation here + counter2 = 0 # counter for the indices + for i in arg_with_ind.indices: + if i is None: + inv_perm1.append(counter4) + counter2 += 1 + counter4 += 1 + continue + if i >= 0: + continue + # Reconstruct the diagonal indices: + diag_indices[-1 - i].append(counter + counter2) + if count_index_freq[i] == 1 and i not in done: + inv_perm1.append(free_index_count - 1 - i) + done.add(i) + elif i not in done: + inv_perm2.append(free_index_count - 1 - i) + done.add(i) + counter2 += 1 + # Remove negative indices to restore a proper editor object: + arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices] + counter += len([i for i in arg_with_ind.indices if i is None or i < 0]) + + inverse_permutation = inv_perm1 + inv_perm2 + permutation = _af_invert(inverse_permutation) + + # Get the diagonal indices after the detection of HadamardProduct in the expression: + diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1] + + self.merge_scalars() + self.refresh_indices() + args = [arg.element for arg in self.args_with_ind] + contraction_indices = self.get_contraction_indices() + expr = _array_contraction(_array_tensor_product(*args), *contraction_indices) + expr2 = _array_diagonal(expr, *diag_indices_filtered) + if self._track_permutation is not None: + permutation2 = _af_invert([j for i in self._track_permutation for j in i]) + expr2 = _permute_dims(expr2, permutation2) + + expr3 = _permute_dims(expr2, permutation) + return expr3 + + def get_contraction_indices(self) -> list[list[int]]: + contraction_indices: list[list[int]] = [[] for i in range(self.number_of_contraction_indices)] + current_position: int = 0 + for arg_with_ind in self.args_with_ind: + for j in arg_with_ind.indices: + if j is not None: + contraction_indices[j].append(current_position) + current_position += 1 + return contraction_indices + + def get_mapping_for_index(self, ind) -> list[_IndPos]: + if ind >= self.number_of_contraction_indices: + raise ValueError("index value exceeding the index range") + positions: list[_IndPos] = [] + for i, arg_with_ind in enumerate(self.args_with_ind): + for j, arg_ind in enumerate(arg_with_ind.indices): + if ind == arg_ind: + positions.append(_IndPos(i, j)) + return positions + + def get_contraction_indices_to_ind_rel_pos(self) -> list[list[_IndPos]]: + contraction_indices: list[list[_IndPos]] = [[] for i in range(self.number_of_contraction_indices)] + for i, arg_with_ind in enumerate(self.args_with_ind): + for j, ind in enumerate(arg_with_ind.indices): + if ind is not None: + contraction_indices[ind].append(_IndPos(i, j)) + return contraction_indices + + def count_args_with_index(self, index: int) -> int: + """ + Count the number of arguments that have the given index. + """ + counter: int = 0 + for arg_with_ind in self.args_with_ind: + if index in arg_with_ind.indices: + counter += 1 + return counter + + def get_args_with_index(self, index: int) -> list[_ArgE]: + """ + Get a list of arguments having the given index. + """ + ret: list[_ArgE] = [i for i in self.args_with_ind if index in i.indices] + return ret + + @property + def number_of_diagonal_indices(self): + data = set() + for arg in self.args_with_ind: + data.update({i for i in arg.indices if i is not None and i < 0}) + return len(data) + + def track_permutation_start(self): + permutation = [] + perm_diag = [] + counter = 0 + counter2 = -1 + for arg_with_ind in self.args_with_ind: + perm = [] + for i in arg_with_ind.indices: + if i is not None: + if i < 0: + perm_diag.append(counter2) + counter2 -= 1 + continue + perm.append(counter) + counter += 1 + permutation.append(perm) + max_ind = max(max(i) if i else -1 for i in permutation) if permutation else -1 + perm_diag = [max_ind - i for i in perm_diag] + self._track_permutation = permutation + [perm_diag] + + def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE): + index_destination = self.args_with_ind.index(destination) + index_element = self.args_with_ind.index(from_element) + self._track_permutation[index_destination].extend(self._track_permutation[index_element]) # type: ignore + self._track_permutation.pop(index_element) # type: ignore + + def get_absolute_free_range(self, arg: _ArgE) -> typing.Tuple[int, int]: + """ + Return the range of the free indices of the arg as absolute positions + among all free indices. + """ + counter = 0 + for arg_with_ind in self.args_with_ind: + number_free_indices = len([i for i in arg_with_ind.indices if i is None]) + if arg_with_ind == arg: + return counter, counter + number_free_indices + counter += number_free_indices + raise IndexError("argument not found") + + def get_absolute_range(self, arg: _ArgE) -> typing.Tuple[int, int]: + """ + Return the absolute range of indices for arg, disregarding dummy + indices. + """ + counter = 0 + for arg_with_ind in self.args_with_ind: + number_indices = len(arg_with_ind.indices) + if arg_with_ind == arg: + return counter, counter + number_indices + counter += number_indices + raise IndexError("argument not found") + + +def get_rank(expr): + if isinstance(expr, (MatrixExpr, MatrixElement)): + return 2 + if isinstance(expr, _CodegenArrayAbstract): + return len(expr.shape) + if isinstance(expr, NDimArray): + return expr.rank() + if isinstance(expr, Indexed): + return expr.rank + if isinstance(expr, IndexedBase): + shape = expr.shape + if shape is None: + return -1 + else: + return len(shape) + if hasattr(expr, "shape"): + return len(expr.shape) + return 0 + + +def _get_subrank(expr): + if isinstance(expr, _CodegenArrayAbstract): + return expr.subrank() + return get_rank(expr) + + +def _get_subranks(expr): + if isinstance(expr, _CodegenArrayAbstract): + return expr.subranks + else: + return [get_rank(expr)] + + +def get_shape(expr): + if hasattr(expr, "shape"): + return expr.shape + return () + + +def nest_permutation(expr): + if isinstance(expr, PermuteDims): + return expr.nest_permutation() + else: + return expr + + +def _array_tensor_product(*args, **kwargs): + return ArrayTensorProduct(*args, canonicalize=True, **kwargs) + + +def _array_contraction(expr, *contraction_indices, **kwargs): + return ArrayContraction(expr, *contraction_indices, canonicalize=True, **kwargs) + + +def _array_diagonal(expr, *diagonal_indices, **kwargs): + return ArrayDiagonal(expr, *diagonal_indices, canonicalize=True, **kwargs) + + +def _permute_dims(expr, permutation, **kwargs): + return PermuteDims(expr, permutation, canonicalize=True, **kwargs) + + +def _array_add(*args, **kwargs): + return ArrayAdd(*args, canonicalize=True, **kwargs) + + +def _get_array_element_or_slice(expr, indices): + return ArrayElement(expr, indices) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..ab44a6fbf715ac7f2b8c287dcc84a49289f2dd76 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py @@ -0,0 +1,194 @@ +import operator +from functools import reduce, singledispatch + +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.matrices.expressions.hadamard import HadamardProduct +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol) +from sympy.matrices.expressions.special import Identity, OneMatrix +from sympy.matrices.expressions.transpose import Transpose +from sympy.combinatorics.permutations import _af_invert +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.tensor.array.expressions.array_expressions import ( + _ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd, + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank, + get_shape, ArrayContraction, _array_tensor_product, _array_contraction, + _array_diagonal, _array_add, _permute_dims, Reshape) +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + + +@singledispatch +def array_derive(expr, x): + """ + Derivatives (gradients) for array expressions. + """ + raise NotImplementedError(f"not implemented for type {type(expr)}") + + +@array_derive.register(Expr) +def _(expr: Expr, x: _ArrayExpr): + return ZeroArray(*x.shape) + + +@array_derive.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct, x: Expr): + args = expr.args + addend_list = [] + for i, arg in enumerate(expr.args): + darg = array_derive(arg, x) + if darg == 0: + continue + args_prev = args[:i] + args_succ = args[i+1:] + shape_prev = reduce(operator.add, map(get_shape, args_prev), ()) + shape_succ = reduce(operator.add, map(get_shape, args_succ), ()) + addend = _array_tensor_product(*args_prev, darg, *args_succ) + tot1 = len(get_shape(x)) + tot2 = tot1 + len(shape_prev) + tot3 = tot2 + len(get_shape(arg)) + tot4 = tot3 + len(shape_succ) + perm = list(range(tot1, tot2)) + \ + list(range(tot1)) + list(range(tot2, tot3)) + \ + list(range(tot3, tot4)) + addend = _permute_dims(addend, _af_invert(perm)) + addend_list.append(addend) + if len(addend_list) == 1: + return addend_list[0] + elif len(addend_list) == 0: + return S.Zero + else: + return _array_add(*addend_list) + + +@array_derive.register(ArraySymbol) +def _(expr: ArraySymbol, x: _ArrayExpr): + if expr == x: + return _permute_dims( + ArrayTensorProduct.fromiter(Identity(i) for i in expr.shape), + [2*i for i in range(len(expr.shape))] + [2*i+1 for i in range(len(expr.shape))] + ) + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(MatrixSymbol) +def _(expr: MatrixSymbol, x: _ArrayExpr): + m, n = expr.shape + if expr == x: + return _permute_dims( + _array_tensor_product(Identity(m), Identity(n)), + [0, 2, 1, 3] + ) + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(Identity) +def _(expr: Identity, x: _ArrayExpr): + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(OneMatrix) +def _(expr: OneMatrix, x: _ArrayExpr): + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(Transpose) +def _(expr: Transpose, x: Expr): + # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni + # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn) + fd = array_derive(expr.arg, x) + return _permute_dims(fd, [0, 1, 3, 2]) + + +@array_derive.register(Inverse) +def _(expr: Inverse, x: Expr): + mat = expr.I + dexpr = array_derive(mat, x) + tp = _array_tensor_product(-expr, dexpr, expr) + mp = _array_contraction(tp, (1, 4), (5, 6)) + pp = _permute_dims(mp, [1, 2, 0, 3]) + return pp + + +@array_derive.register(ElementwiseApplyFunction) +def _(expr: ElementwiseApplyFunction, x: Expr): + assert get_rank(expr) == 2 + assert get_rank(x) == 2 + fdiff = expr._get_function_fdiff() + dexpr = array_derive(expr.expr, x) + tp = _array_tensor_product( + ElementwiseApplyFunction(fdiff, expr.expr), + dexpr + ) + td = _array_diagonal( + tp, (0, 4), (1, 5) + ) + return td + + +@array_derive.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc, x: Expr): + fdiff = expr._get_function_fdiff() + subexpr = expr.expr + dsubexpr = array_derive(subexpr, x) + tp = _array_tensor_product( + dsubexpr, + ArrayElementwiseApplyFunc(fdiff, subexpr) + ) + b = get_rank(x) + c = get_rank(expr) + diag_indices = [(b + i, b + c + i) for i in range(c)] + return _array_diagonal(tp, *diag_indices) + + +@array_derive.register(MatrixExpr) +def _(expr: MatrixExpr, x: Expr): + cg = convert_matrix_to_array(expr) + return array_derive(cg, x) + + +@array_derive.register(HadamardProduct) +def _(expr: HadamardProduct, x: Expr): + raise NotImplementedError() + + +@array_derive.register(ArrayContraction) +def _(expr: ArrayContraction, x: Expr): + fd = array_derive(expr.expr, x) + rank_x = len(get_shape(x)) + contraction_indices = expr.contraction_indices + new_contraction_indices = [tuple(j + rank_x for j in i) for i in contraction_indices] + return _array_contraction(fd, *new_contraction_indices) + + +@array_derive.register(ArrayDiagonal) +def _(expr: ArrayDiagonal, x: Expr): + dsubexpr = array_derive(expr.expr, x) + rank_x = len(get_shape(x)) + diag_indices = [[j + rank_x for j in i] for i in expr.diagonal_indices] + return _array_diagonal(dsubexpr, *diag_indices) + + +@array_derive.register(ArrayAdd) +def _(expr: ArrayAdd, x: Expr): + return _array_add(*[array_derive(arg, x) for arg in expr.args]) + + +@array_derive.register(PermuteDims) +def _(expr: PermuteDims, x: Expr): + de = array_derive(expr.expr, x) + perm = [0, 1] + [i + 2 for i in expr.permutation.array_form] + return _permute_dims(de, perm) + + +@array_derive.register(Reshape) +def _(expr: Reshape, x: Expr): + de = array_derive(expr.expr, x) + return Reshape(de, get_shape(x) + expr.shape) + + +def matrix_derive(expr, x): + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + ce = convert_matrix_to_array(expr) + dce = array_derive(ce, x) + return convert_array_to_matrix(dce).doit() diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..1929c3401e131cca0a83080131ead9198b37bcbb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py @@ -0,0 +1,12 @@ +from sympy.tensor.array.expressions import from_array_to_indexed +from sympy.utilities.decorator import deprecated + + +_conv_to_from_decorator = deprecated( + "module has been renamed by replacing 'conv_' with 'from_' in its name", + deprecated_since_version="1.11", + active_deprecations_target="deprecated-conv-array-expr-module-names", +) + + +convert_array_to_indexed = _conv_to_from_decorator(from_array_to_indexed.convert_array_to_indexed) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..2708e74aaa98d6ee38eae46d97d4483a546e0776 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py @@ -0,0 +1,6 @@ +from sympy.tensor.array.expressions import from_array_to_matrix +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_array_to_matrix = _conv_to_from_decorator(from_array_to_matrix.convert_array_to_matrix) +_array2matrix = _conv_to_from_decorator(from_array_to_matrix._array2matrix) +_remove_trivial_dims = _conv_to_from_decorator(from_array_to_matrix._remove_trivial_dims) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..6058b31f20778834ea23a01553d594b7965eb6bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py @@ -0,0 +1,4 @@ +from sympy.tensor.array.expressions import from_indexed_to_array +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_indexed_to_array = _conv_to_from_decorator(from_indexed_to_array.convert_indexed_to_array) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..46469df60703c237527c0b2834235309640afe7c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py @@ -0,0 +1,4 @@ +from sympy.tensor.array.expressions import from_matrix_to_array +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_matrix_to_array = _conv_to_from_decorator(from_matrix_to_array.convert_matrix_to_array) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb86e7cfbe31ebfe7c9649803d9cb5e34b98276 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py @@ -0,0 +1,84 @@ +import collections.abc +import operator +from itertools import accumulate + +from sympy import Mul, Sum, Dummy, Add +from sympy.tensor.array.expressions import PermuteDims, ArrayAdd, ArrayElementwiseApplyFunc, Reshape +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, get_rank, ArrayContraction, \ + ArrayDiagonal, get_shape, _get_array_element_or_slice, _ArrayExpr +from sympy.tensor.array.expressions.utils import _apply_permutation_to_list + + +def convert_array_to_indexed(expr, indices): + return _ConvertArrayToIndexed().do_convert(expr, indices) + + +class _ConvertArrayToIndexed: + + def __init__(self): + self.count_dummies = 0 + + def do_convert(self, expr, indices): + if isinstance(expr, ArrayTensorProduct): + cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args])) + indices_grp = [indices[cumul[i]:cumul[i+1]] for i in range(len(expr.args))] + return Mul.fromiter(self.do_convert(arg, ind) for arg, ind in zip(expr.args, indices_grp)) + if isinstance(expr, ArrayContraction): + new_indices = [None for i in range(get_rank(expr.expr))] + limits = [] + bottom_shape = get_shape(expr.expr) + for contraction_index_grp in expr.contraction_indices: + d = Dummy(f"d{self.count_dummies}") + self.count_dummies += 1 + dim = bottom_shape[contraction_index_grp[0]] + limits.append((d, 0, dim-1)) + for i in contraction_index_grp: + new_indices[i] = d + j = 0 + for i in range(len(new_indices)): + if new_indices[i] is None: + new_indices[i] = indices[j] + j += 1 + newexpr = self.do_convert(expr.expr, new_indices) + return Sum(newexpr, *limits) + if isinstance(expr, ArrayDiagonal): + new_indices = [None for i in range(get_rank(expr.expr))] + ind_pos = expr._push_indices_down(expr.diagonal_indices, list(range(len(indices))), get_rank(expr)) + for i, index in zip(ind_pos, indices): + if isinstance(i, collections.abc.Iterable): + for j in i: + new_indices[j] = index + else: + new_indices[i] = index + newexpr = self.do_convert(expr.expr, new_indices) + return newexpr + if isinstance(expr, PermuteDims): + permuted_indices = _apply_permutation_to_list(expr.permutation, indices) + return self.do_convert(expr.expr, permuted_indices) + if isinstance(expr, ArrayAdd): + return Add.fromiter(self.do_convert(arg, indices) for arg in expr.args) + if isinstance(expr, _ArrayExpr): + return expr.__getitem__(tuple(indices)) + if isinstance(expr, ArrayElementwiseApplyFunc): + return expr.function(self.do_convert(expr.expr, indices)) + if isinstance(expr, Reshape): + shape_up = expr.shape + shape_down = get_shape(expr.expr) + cumul = list(accumulate([1] + list(reversed(shape_up)), operator.mul)) + one_index = Add.fromiter(i*s for i, s in zip(reversed(indices), cumul)) + dest_indices = [None for _ in shape_down] + c = 1 + for i, e in enumerate(reversed(shape_down)): + if c == 1: + if i == len(shape_down) - 1: + dest_indices[i] = one_index + else: + dest_indices[i] = one_index % e + elif i == len(shape_down) - 1: + dest_indices[i] = one_index // c + else: + dest_indices[i] = one_index // c % e + c *= e + dest_indices.reverse() + return self.do_convert(expr.expr, dest_indices) + return _get_array_element_or_slice(expr, indices) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..debfdd7eb5c4533996b3d72b55d679be3daf3afe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py @@ -0,0 +1,1004 @@ +from __future__ import annotations +import itertools +from collections import defaultdict +from typing import FrozenSet +from functools import singledispatch +from itertools import accumulate + +from sympy import MatMul, Basic, Wild, KroneckerProduct +from sympy.assumptions.ask import (Q, ask) +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.hadamard import hadamard_product, HadamardPower +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import (Identity, ZeroMatrix, OneMatrix) +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.combinatorics.permutations import _af_invert, Permutation +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \ + ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \ + ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE, \ + ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, _array_add, _permute_dims +from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks + + +def _get_candidate_for_matmul_from_contraction(scan_indices: list[int | None], remaining_args: list[_ArgE]) -> tuple[_ArgE | None, bool, int]: + + scan_indices_int: list[int] = [i for i in scan_indices if i is not None] + if len(scan_indices_int) == 0: + return None, False, -1 + + transpose: bool = False + candidate: _ArgE | None = None + candidate_index: int = -1 + for arg_with_ind2 in remaining_args: + if not isinstance(arg_with_ind2.element, MatrixExpr): + continue + for index in scan_indices_int: + if candidate_index != -1 and candidate_index != index: + # A candidate index has already been selected, check + # repetitions only for that index: + continue + if index in arg_with_ind2.indices: + if set(arg_with_ind2.indices) == {index}: + # Index repeated twice in arg_with_ind2 + candidate = None + break + if candidate is None: + candidate = arg_with_ind2 + candidate_index = index + transpose = (index == arg_with_ind2.indices[1]) + else: + # Index repeated more than twice, break + candidate = None + break + return candidate, transpose, candidate_index + + +def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool): + other = candidate.element + other_index: int | None + if transpose2: + other = Transpose(other) + other_index = candidate.indices[0] + else: + other_index = candidate.indices[1] + new_element = (Transpose(arg_with_ind.element) if transpose1 else arg_with_ind.element) * other + editor.args_with_ind.remove(candidate) + new_arge = _ArgE(new_element) + return new_arge, other_index + + +def _support_function_tp1_recognize(contraction_indices, args): + if len(contraction_indices) == 0: + return _a2m_tensor_product(*args) + + ac = _array_contraction(_array_tensor_product(*args), *contraction_indices) + editor = _EditArrayContraction(ac) + editor.track_permutation_start() + + while True: + flag_stop = True + for i, arg_with_ind in enumerate(editor.args_with_ind): + if not isinstance(arg_with_ind.element, MatrixExpr): + continue + + first_index = arg_with_ind.indices[0] + second_index = arg_with_ind.indices[1] + + first_frequency = editor.count_args_with_index(first_index) + second_frequency = editor.count_args_with_index(second_index) + + if first_index is not None and first_frequency == 1 and first_index == second_index: + flag_stop = False + arg_with_ind.element = Trace(arg_with_ind.element)._normalize() + arg_with_ind.indices = [] + break + + scan_indices = [] + if first_frequency == 2: + scan_indices.append(first_index) + if second_frequency == 2: + scan_indices.append(second_index) + + candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:]) + if candidate is not None: + flag_stop = False + editor.track_permutation_merge(arg_with_ind, candidate) + transpose1 = found_index == first_index + new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose) + if found_index == first_index: + new_arge.indices = [second_index, other_index] + else: + new_arge.indices = [first_index, other_index] + set_indices = set(new_arge.indices) + if len(set_indices) == 1 and set_indices != {None}: + # This is a trace: + new_arge.element = Trace(new_arge.element)._normalize() + new_arge.indices = [] + editor.args_with_ind[i] = new_arge + # TODO: is this break necessary? + break + + if flag_stop: + break + + editor.refresh_indices() + return editor.to_array_contraction() + + +def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct): + # If there are matrices of trivial shape in the tensor product (i.e. shape + # (1, 1)), try to check if there is a suitable non-trivial MatMul where the + # expression can be inserted. + + # For example, if "a" has shape (1, 1) and "b" has shape (k, 1), the + # expressions "_array_tensor_product(a, b*b.T)" can be rewritten as + # "b*a*b.T" + + trivial_matrices = [] + pos: int | None = None # must be initialized else causes UnboundLocalError + first: MatrixExpr | None = None # may cause UnboundLocalError if not initialized + second: MatrixExpr | None = None # may cause UnboundLocalError if not initialized + removed: list[int] = [] + counter: int = 0 + args: list[Basic | None] = list(expr.args) + for i, arg in enumerate(expr.args): + if isinstance(arg, MatrixExpr): + if arg.shape == (1, 1): + trivial_matrices.append(arg) + args[i] = None + removed.extend([counter, counter+1]) + elif pos is None and isinstance(arg, MatMul): + margs = arg.args + for j, e in enumerate(margs): + if isinstance(e, MatrixExpr) and e.shape[1] == 1: + pos = i + first = MatMul.fromiter(margs[:j+1]) + second = MatMul.fromiter(margs[j+1:]) + break + counter += get_rank(arg) + if pos is None: + return expr, [] + args[pos] = (first*MatMul.fromiter(i for i in trivial_matrices)*second).doit() + return _array_tensor_product(*[i for i in args if i is not None]), removed + + +def _find_trivial_kronecker_products_broadcast(expr: ArrayTensorProduct): + newargs: list[Basic] = [] + removed = [] + count_dims = 0 + for arg in expr.args: + count_dims += get_rank(arg) + shape = get_shape(arg) + current_range = [count_dims-i for i in range(len(shape), 0, -1)] + if (shape == (1, 1) and len(newargs) > 0 and 1 not in get_shape(newargs[-1]) and + isinstance(newargs[-1], MatrixExpr) and isinstance(arg, MatrixExpr)): + # KroneckerProduct object allows the trick of broadcasting: + newargs[-1] = KroneckerProduct(newargs[-1], arg) + removed.extend(current_range) + elif 1 not in shape and len(newargs) > 0 and get_shape(newargs[-1]) == (1, 1): + # Broadcast: + newargs[-1] = KroneckerProduct(newargs[-1], arg) + prev_range = [i for i in range(min(current_range)) if i not in removed] + removed.extend(prev_range[-2:]) + else: + newargs.append(arg) + return _array_tensor_product(*newargs), removed + + +@singledispatch +def _array2matrix(expr): + return expr + + +@_array2matrix.register(ZeroArray) +def _(expr: ZeroArray): + if get_rank(expr) == 2: + return ZeroMatrix(*expr.shape) + else: + return expr + + +@_array2matrix.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct): + return _a2m_tensor_product(*[_array2matrix(arg) for arg in expr.args]) + + +@_array2matrix.register(ArrayContraction) +def _(expr: ArrayContraction): + expr = expr.flatten_contraction_of_diagonal() + expr = identify_removable_identity_matrices(expr) + expr = expr.split_multiple_contractions() + expr = identify_hadamard_products(expr) + if not isinstance(expr, ArrayContraction): + return _array2matrix(expr) + subexpr = expr.expr + contraction_indices: tuple[tuple[int]] = expr.contraction_indices + if contraction_indices == ((0,), (1,)) or ( + contraction_indices == ((0,),) and subexpr.shape[1] == 1 + ) or ( + contraction_indices == ((1,),) and subexpr.shape[0] == 1 + ): + shape = subexpr.shape + subexpr = _array2matrix(subexpr) + if isinstance(subexpr, MatrixExpr): + return OneMatrix(1, shape[0])*subexpr*OneMatrix(shape[1], 1) + if isinstance(subexpr, ArrayTensorProduct): + newexpr = _array_contraction(_array2matrix(subexpr), *contraction_indices) + contraction_indices = newexpr.contraction_indices + if any(i > 2 for i in newexpr.subranks): + addends = _array_add(*[_a2m_tensor_product(*j) for j in itertools.product(*[i.args if isinstance(i, + ArrayAdd) else [i] for i in expr.expr.args])]) + newexpr = _array_contraction(addends, *contraction_indices) + if isinstance(newexpr, ArrayAdd): + ret = _array2matrix(newexpr) + return ret + assert isinstance(newexpr, ArrayContraction) + ret = _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args)) + return ret + elif not isinstance(subexpr, _CodegenArrayAbstract): + ret = _array2matrix(subexpr) + if isinstance(ret, MatrixExpr): + assert expr.contraction_indices == ((0, 1),) + return _a2m_trace(ret) + else: + return _array_contraction(ret, *expr.contraction_indices) + + +@_array2matrix.register(ArrayDiagonal) +def _(expr: ArrayDiagonal): + pexpr = _array_diagonal(_array2matrix(expr.expr), *expr.diagonal_indices) + pexpr = identify_hadamard_products(pexpr) + if isinstance(pexpr, ArrayDiagonal): + pexpr = _array_diag2contr_diagmatrix(pexpr) + if expr == pexpr: + return expr + return _array2matrix(pexpr) + + +@_array2matrix.register(PermuteDims) +def _(expr: PermuteDims): + if expr.permutation.array_form == [1, 0]: + return _a2m_transpose(_array2matrix(expr.expr)) + elif isinstance(expr.expr, ArrayTensorProduct): + ranks = expr.expr.subranks + inv_permutation = expr.permutation**(-1) + newrange = [inv_permutation(i) for i in range(sum(ranks))] + newpos = [] + counter = 0 + for rank in ranks: + newpos.append(newrange[counter:counter+rank]) + counter += rank + newargs = [] + newperm = [] + scalars = [] + for pos, arg in zip(newpos, expr.expr.args): + if len(pos) == 0: + scalars.append(_array2matrix(arg)) + elif pos == sorted(pos): + newargs.append((_array2matrix(arg), pos[0])) + newperm.extend(pos) + elif len(pos) == 2: + newargs.append((_a2m_transpose(_array2matrix(arg)), pos[0])) + newperm.extend(reversed(pos)) + else: + raise NotImplementedError() + newargs = [i[0] for i in newargs] + return _permute_dims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm)) + elif isinstance(expr.expr, ArrayContraction): + mat_mul_lines = _array2matrix(expr.expr) + if not isinstance(mat_mul_lines, ArrayTensorProduct): + return _permute_dims(mat_mul_lines, expr.permutation) + # TODO: this assumes that all arguments are matrices, it may not be the case: + permutation = Permutation(2*len(mat_mul_lines.args)-1)*expr.permutation + permuted = [permutation(i) for i in range(2*len(mat_mul_lines.args))] + args_array = [None for i in mat_mul_lines.args] + for i in range(len(mat_mul_lines.args)): + p1 = permuted[2*i] + p2 = permuted[2*i+1] + if p1 // 2 != p2 // 2: + return _permute_dims(mat_mul_lines, permutation) + if p1 > p2: + args_array[i] = _a2m_transpose(mat_mul_lines.args[p1 // 2]) + else: + args_array[i] = mat_mul_lines.args[p1 // 2] + return _a2m_tensor_product(*args_array) + else: + return expr + + +@_array2matrix.register(ArrayAdd) +def _(expr: ArrayAdd): + addends = [_array2matrix(arg) for arg in expr.args] + return _a2m_add(*addends) + + +@_array2matrix.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc): + subexpr = _array2matrix(expr.expr) + if isinstance(subexpr, MatrixExpr): + if subexpr.shape != (1, 1): + d = expr.function.bound_symbols[0] + w = Wild("w", exclude=[d]) + p = Wild("p", exclude=[d]) + m = expr.function.expr.match(w*d**p) + if m is not None: + return m[w]*HadamardPower(subexpr, m[p]) + return ElementwiseApplyFunction(expr.function, subexpr) + else: + return ArrayElementwiseApplyFunc(expr.function, subexpr) + + +@_array2matrix.register(ArrayElement) +def _(expr: ArrayElement): + ret = _array2matrix(expr.name) + if isinstance(ret, MatrixExpr): + return MatrixElement(ret, *expr.indices) + return ArrayElement(ret, expr.indices) + + +@singledispatch +def _remove_trivial_dims(expr): + return expr, [] + + +@_remove_trivial_dims.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct): + # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`. + # The matrix expression has to be equivalent to the tensor product of the + # matrices, with trivial dimensions (i.e. dim=1) dropped. + # That is, add contractions over trivial dimensions: + + removed = [] + newargs = [] + cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args])) + pending = None + prev_i = None + for i, arg in enumerate(expr.args): + current_range = list(range(cumul[i], cumul[i+1])) + if isinstance(arg, OneArray): + removed.extend(current_range) + continue + if not isinstance(arg, (MatrixExpr, MatrixBase)): + rarg, rem = _remove_trivial_dims(arg) + removed.extend(rem) + newargs.append(rarg) + continue + elif getattr(arg, "is_Identity", False) and arg.shape == (1, 1): + if arg.shape == (1, 1): + # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1. + removed.extend(current_range) + continue + elif arg.shape == (1, 1): + arg, _ = _remove_trivial_dims(arg) + # Matrix is equivalent to scalar: + if len(newargs) == 0: + newargs.append(arg) + elif 1 in get_shape(newargs[-1]): + if newargs[-1].shape[1] == 1: + newargs[-1] = newargs[-1]*arg + else: + newargs[-1] = arg*newargs[-1] + removed.extend(current_range) + else: + newargs.append(arg) + elif 1 in arg.shape: + k = [i for i in arg.shape if i != 1][0] + if pending is None: + pending = k + prev_i = i + newargs.append(arg) + elif pending == k: + prev = newargs[-1] + if prev.shape[0] == 1: + d1 = cumul[prev_i] # type: ignore + prev = _a2m_transpose(prev) + else: + d1 = cumul[prev_i] + 1 # type: ignore + if arg.shape[1] == 1: + d2 = cumul[i] + 1 + arg = _a2m_transpose(arg) + else: + d2 = cumul[i] + newargs[-1] = prev*arg + pending = None + removed.extend([d1, d2]) + else: + newargs.append(arg) + pending = k + prev_i = i + else: + newargs.append(arg) + pending = None + newexpr, newremoved = _a2m_tensor_product(*newargs), sorted(removed) + if isinstance(newexpr, ArrayTensorProduct): + newexpr, newremoved2 = _find_trivial_matrices_rewrite(newexpr) + newremoved = _combine_removed(-1, newremoved, newremoved2) + if isinstance(newexpr, ArrayTensorProduct): + newexpr, newremoved2 = _find_trivial_kronecker_products_broadcast(newexpr) + newremoved = _combine_removed(-1, newremoved, newremoved2) + return newexpr, newremoved + + +@_remove_trivial_dims.register(ArrayAdd) +def _(expr: ArrayAdd): + rec = [_remove_trivial_dims(arg) for arg in expr.args] + newargs, removed = zip(*rec) + if len({get_shape(i) for i in newargs}) > 1: + return expr, [] + if len(removed) == 0: + return expr, removed + removed1 = removed[0] + return _a2m_add(*newargs), removed1 + + +@_remove_trivial_dims.register(PermuteDims) +def _(expr: PermuteDims): + subexpr, subremoved = _remove_trivial_dims(expr.expr) + p = expr.permutation.array_form + pinv = _af_invert(expr.permutation.array_form) + shift = list(accumulate([1 if i in subremoved else 0 for i in range(len(p))])) + premoved = [pinv[i] for i in subremoved] + p2 = [e - shift[e] for e in p if e not in subremoved] + # TODO: check if subremoved should be permuted as well... + newexpr = _permute_dims(subexpr, p2) + premoved = sorted(premoved) + if newexpr != expr: + newexpr, removed2 = _remove_trivial_dims(_array2matrix(newexpr)) + premoved = _combine_removed(-1, premoved, removed2) + return newexpr, premoved + + +@_remove_trivial_dims.register(ArrayContraction) +def _(expr: ArrayContraction): + new_expr, removed0 = _array_contraction_to_diagonal_multiple_identity(expr) + if new_expr != expr: + new_expr2, removed1 = _remove_trivial_dims(_array2matrix(new_expr)) + removed = _combine_removed(-1, removed0, removed1) + return new_expr2, removed + rank1 = get_rank(expr) + expr, removed1 = remove_identity_matrices(expr) + if not isinstance(expr, ArrayContraction): + expr2, removed2 = _remove_trivial_dims(expr) + return expr2, _combine_removed(rank1, removed1, removed2) + newexpr, removed2 = _remove_trivial_dims(expr.expr) + shifts = list(accumulate([1 if i in removed2 else 0 for i in range(get_rank(expr.expr))])) + new_contraction_indices = [tuple(j for j in i if j not in removed2) for i in expr.contraction_indices] + # Remove possible empty tuples "()": + new_contraction_indices = [i for i in new_contraction_indices if len(i) > 0] + contraction_indices_flat = [j for i in expr.contraction_indices for j in i] + removed2 = [i for i in removed2 if i not in contraction_indices_flat] + new_contraction_indices = [tuple(j - shifts[j] for j in i) for i in new_contraction_indices] + # Shift removed2: + removed2 = ArrayContraction._push_indices_up(expr.contraction_indices, removed2) + removed = _combine_removed(rank1, removed1, removed2) + return _array_contraction(newexpr, *new_contraction_indices), list(removed) + + +def _remove_diagonalized_identity_matrices(expr: ArrayDiagonal): + assert isinstance(expr, ArrayDiagonal) + editor = _EditArrayContraction(expr) + mapping = {i: {j for j in editor.args_with_ind if i in j.indices} for i in range(-1, -1-editor.number_of_diagonal_indices, -1)} + removed = [] + counter: int = 0 + for i, arg_with_ind in enumerate(editor.args_with_ind): + counter += len(arg_with_ind.indices) + if isinstance(arg_with_ind.element, Identity): + if None in arg_with_ind.indices and any(i is not None and (i < 0) == True for i in arg_with_ind.indices): + diag_ind = [j for j in arg_with_ind.indices if j is not None][0] + other = [j for j in mapping[diag_ind] if j != arg_with_ind][0] + if not isinstance(other.element, MatrixExpr): + continue + if 1 not in other.element.shape: + continue + if None not in other.indices: + continue + editor.args_with_ind[i].element = None + none_index = other.indices.index(None) + other.element = DiagMatrix(other.element) + other_range = editor.get_absolute_range(other) + removed.extend([other_range[0] + none_index]) + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, get_rank(expr.expr)) + return editor.to_array_contraction(), removed + + +@_remove_trivial_dims.register(ArrayDiagonal) +def _(expr: ArrayDiagonal): + newexpr, removed = _remove_trivial_dims(expr.expr) + shifts = list(accumulate([0] + [1 if i in removed else 0 for i in range(get_rank(expr.expr))])) + new_diag_indices_map = {i: tuple(j for j in i if j not in removed) for i in expr.diagonal_indices} + for old_diag_tuple, new_diag_tuple in new_diag_indices_map.items(): + if len(new_diag_tuple) == 1: + removed = [i for i in removed if i not in old_diag_tuple] + new_diag_indices = [tuple(j - shifts[j] for j in i) for i in new_diag_indices_map.values()] + rank = get_rank(expr.expr) + removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, rank) + removed = sorted(set(removed)) + # If there are single axes to diagonalize remaining, it means that their + # corresponding dimension has been removed, they no longer need diagonalization: + new_diag_indices = [i for i in new_diag_indices if len(i) > 0] + if len(new_diag_indices) > 0: + newexpr2 = _array_diagonal(newexpr, *new_diag_indices, allow_trivial_diags=True) + else: + newexpr2 = newexpr + if isinstance(newexpr2, ArrayDiagonal): + newexpr3, removed2 = _remove_diagonalized_identity_matrices(newexpr2) + removed = _combine_removed(-1, removed, removed2) + return newexpr3, removed + else: + return newexpr2, removed + + +@_remove_trivial_dims.register(ElementwiseApplyFunction) +def _(expr: ElementwiseApplyFunction): + subexpr, removed = _remove_trivial_dims(expr.expr) + if subexpr.shape == (1, 1): + # TODO: move this to ElementwiseApplyFunction + return expr.function(subexpr), removed + [0, 1] + return ElementwiseApplyFunction(expr.function, subexpr), [] + + +@_remove_trivial_dims.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc): + subexpr, removed = _remove_trivial_dims(expr.expr) + return ArrayElementwiseApplyFunc(expr.function, subexpr), removed + + +def convert_array_to_matrix(expr): + r""" + Recognize matrix expressions in codegen objects. + + If more than one matrix multiplication line have been detected, return a + list with the matrix expressions. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + >>> from sympy.tensor.array import tensorcontraction, tensorproduct + >>> from sympy import MatrixSymbol, Sum + >>> from sympy.abc import i, j, k, l, N + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B + >>> cg = convert_indexed_to_array(expr, first_indices=[k]) + >>> convert_array_to_matrix(cg) + B.T*A.T + + Transposition is detected: + + >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A.T*B + >>> cg = convert_indexed_to_array(expr, first_indices=[k]) + >>> convert_array_to_matrix(cg) + B.T*A + + Detect the trace: + + >>> expr = Sum(A[i, i], (i, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + Trace(A) + + Recognize some more complex traces: + + >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + Trace(A*B) + + More complicated expressions: + + >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B.T*A.T + + Expressions constructed from matrix expressions do not contain literal + indices, the positions of free indices are returned instead: + + >>> expr = A*B + >>> cg = convert_matrix_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B + + If more than one line of matrix multiplications is detected, return + separate matrix multiplication factors embedded in a tensor product object: + + >>> cg = tensorcontraction(tensorproduct(A, B, C, D), (1, 2), (5, 6)) + >>> convert_array_to_matrix(cg) + ArrayTensorProduct(A*B, C*D) + + The two lines have free indices at axes 0, 3 and 4, 7, respectively. + """ + rec = _array2matrix(expr) + rec, removed = _remove_trivial_dims(rec) + return rec + + +def _array_diag2contr_diagmatrix(expr: ArrayDiagonal): + if isinstance(expr.expr, ArrayTensorProduct): + args = list(expr.expr.args) + diag_indices = list(expr.diagonal_indices) + mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args]) + tuple_links = [[mapping[j] for j in i] for i in diag_indices] + contr_indices = [] + total_rank = get_rank(expr) + replaced = [False for arg in args] + for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)): + if len(abs_pos) != 2: + continue + (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos + arg1 = args[pos1_outer] + arg2 = args[pos2_outer] + if get_rank(arg1) != 2 or get_rank(arg2) != 2: + if replaced[pos1_outer]: + diag_indices[i] = None + if replaced[pos2_outer]: + diag_indices[i] = None + continue + pos1_in2 = 1 - pos1_inner + pos2_in2 = 1 - pos2_inner + if arg1.shape[pos1_in2] == 1: + if arg1.shape[pos1_inner] != 1: + darg1 = DiagMatrix(arg1) + else: + darg1 = arg1 + args.append(darg1) + contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner))) + total_rank += 1 + diag_indices[i] = None + args[pos1_outer] = OneArray(arg1.shape[pos1_in2]) + replaced[pos1_outer] = True + elif arg2.shape[pos2_in2] == 1: + if arg2.shape[pos2_inner] != 1: + darg2 = DiagMatrix(arg2) + else: + darg2 = arg2 + args.append(darg2) + contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner))) + total_rank += 1 + diag_indices[i] = None + args[pos2_outer] = OneArray(arg2.shape[pos2_in2]) + replaced[pos2_outer] = True + diag_indices_new = [i for i in diag_indices if i is not None] + cumul = list(accumulate([0] + [get_rank(arg) for arg in args])) + contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices] + tc = _array_contraction( + _array_tensor_product(*args), *contr_indices2 + ) + td = _array_diagonal(tc, *diag_indices_new) + return td + return expr + + +def _a2m_mul(*args): + if not any(isinstance(i, _CodegenArrayAbstract) for i in args): + from sympy.matrices.expressions.matmul import MatMul + return MatMul(*args).doit() + else: + return _array_contraction( + _array_tensor_product(*args), + *[(2*i-1, 2*i) for i in range(1, len(args))] + ) + + +def _a2m_tensor_product(*args): + scalars = [] + arrays = [] + for arg in args: + if isinstance(arg, (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)): + arrays.append(arg) + else: + scalars.append(arg) + scalar = Mul.fromiter(scalars) + if len(arrays) == 0: + return scalar + if scalar != 1: + if isinstance(arrays[0], _CodegenArrayAbstract): + arrays = [scalar] + arrays + else: + arrays[0] *= scalar + return _array_tensor_product(*arrays) + + +def _a2m_add(*args): + if not any(isinstance(i, _CodegenArrayAbstract) for i in args): + from sympy.matrices.expressions.matadd import MatAdd + return MatAdd(*args).doit() + else: + return _array_add(*args) + + +def _a2m_trace(arg): + if isinstance(arg, _CodegenArrayAbstract): + return _array_contraction(arg, (0, 1)) + else: + from sympy.matrices.expressions.trace import Trace + return Trace(arg) + + +def _a2m_transpose(arg): + if isinstance(arg, _CodegenArrayAbstract): + return _permute_dims(arg, [1, 0]) + else: + from sympy.matrices.expressions.transpose import Transpose + return Transpose(arg).doit() + + +def identify_hadamard_products(expr: ArrayContraction | ArrayDiagonal): + + editor: _EditArrayContraction = _EditArrayContraction(expr) + + map_contr_to_args: dict[FrozenSet, list[_ArgE]] = defaultdict(list) + map_ind_to_inds: dict[int | None, int] = defaultdict(int) + for arg_with_ind in editor.args_with_ind: + for ind in arg_with_ind.indices: + map_ind_to_inds[ind] += 1 + if None in arg_with_ind.indices: + continue + map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind) + + k: FrozenSet[int] + v: list[_ArgE] + for k, v in map_contr_to_args.items(): + make_trace: bool = False + if len(k) == 1 and next(iter(k)) >= 0 and sum(next(iter(k)) in i for i in map_contr_to_args) == 1: + # This is a trace: the arguments are fully contracted with only one + # index, and the index isn't used anywhere else: + make_trace = True + first_element = S.One + elif len(k) != 2: + # Hadamard product only defined for matrices: + continue + if len(v) == 1: + # Hadamard product with a single argument makes no sense: + continue + for ind in k: + if map_ind_to_inds[ind] <= 2: + # There is no other contraction, skip: + continue + + def check_transpose(x): + x = [i if i >= 0 else -1-i for i in x] + return x == sorted(x) + + # Check if expression is a trace: + if all(map_ind_to_inds[j] == len(v) and j >= 0 for j in k) and all(j >= 0 for j in k): + # This is a trace + make_trace = True + first_element = v[0].element + if not check_transpose(v[0].indices): + first_element = first_element.T # type: ignore + hadamard_factors = v[1:] + else: + hadamard_factors = v + + # This is a Hadamard product: + + hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors]) + hp_indices = v[0].indices + if not check_transpose(hadamard_factors[0].indices): + hp_indices = list(reversed(hp_indices)) + if make_trace: + hp = Trace(first_element*hp.T)._normalize() + hp_indices = [] + editor.insert_after(v[0], _ArgE(hp, hp_indices)) + for i in v: + editor.args_with_ind.remove(i) + + return editor.to_array_contraction() + + +def identify_removable_identity_matrices(expr): + editor = _EditArrayContraction(expr) + + flag = True + while flag: + flag = False + for arg_with_ind in editor.args_with_ind: + if isinstance(arg_with_ind.element, Identity): + k = arg_with_ind.element.shape[0] + # Candidate for removal: + if arg_with_ind.indices == [None, None]: + # Free identity matrix, will be cleared by _remove_trivial_dims: + continue + elif None in arg_with_ind.indices: + ind = [j for j in arg_with_ind.indices if j is not None][0] + counted = editor.count_args_with_index(ind) + if counted == 1: + # Identity matrix contracted only on one index with itself, + # transform to a OneArray(k) element: + editor.insert_after(arg_with_ind, OneArray(k)) + editor.args_with_ind.remove(arg_with_ind) + flag = True + break + elif counted > 2: + # Case counted = 2 is a matrix multiplication by identity matrix, skip it. + # Case counted > 2 is a multiple contraction, + # this is a case where the contraction becomes a diagonalization if the + # identity matrix is dropped. + continue + elif arg_with_ind.indices[0] == arg_with_ind.indices[1]: + ind = arg_with_ind.indices[0] + counted = editor.count_args_with_index(ind) + if counted > 1: + editor.args_with_ind.remove(arg_with_ind) + flag = True + break + else: + # This is a trace, skip it as it will be recognized somewhere else: + pass + elif ask(Q.diagonal(arg_with_ind.element)): + if arg_with_ind.indices == [None, None]: + continue + elif None in arg_with_ind.indices: + pass + elif arg_with_ind.indices[0] == arg_with_ind.indices[1]: + ind = arg_with_ind.indices[0] + counted = editor.count_args_with_index(ind) + if counted == 3: + # A_ai B_bi D_ii ==> A_ai D_ij B_bj + ind_new = editor.get_new_contraction_index() + other_args = [j for j in editor.args_with_ind if j != arg_with_ind] + other_args[1].indices = [ind_new if j == ind else j for j in other_args[1].indices] + arg_with_ind.indices = [ind, ind_new] + flag = True + break + + return editor.to_array_contraction() + + +def remove_identity_matrices(expr: ArrayContraction): + editor = _EditArrayContraction(expr) + removed: list[int] = [] + + permutation_map = {} + + free_indices = list(accumulate([0] + [sum(i is None for i in arg.indices) for arg in editor.args_with_ind])) + free_map = dict(zip(editor.args_with_ind, free_indices[:-1])) + + update_pairs = {} + + for ind in range(editor.number_of_contraction_indices): + args = editor.get_args_with_index(ind) + identity_matrices = [i for i in args if isinstance(i.element, Identity)] + number_identity_matrices = len(identity_matrices) + # If the contraction involves a non-identity matrix and multiple identity matrices: + if number_identity_matrices != len(args) - 1 or number_identity_matrices == 0: + continue + # Get the non-identity element: + non_identity = [i for i in args if not isinstance(i.element, Identity)][0] + # Check that all identity matrices have at least one free index + # (otherwise they would be contractions to some other elements) + if any(None not in i.indices for i in identity_matrices): + continue + # Mark the identity matrices for removal: + for i in identity_matrices: + i.element = None + removed.extend(range(free_map[i], free_map[i] + len([j for j in i.indices if j is None]))) + last_removed = removed.pop(-1) + update_pairs[last_removed, ind] = non_identity.indices[:] + # Remove the indices from the non-identity matrix, as the contraction + # no longer exists: + non_identity.indices = [None if i == ind else i for i in non_identity.indices] + + removed.sort() + + shifts = list(accumulate([1 if i in removed else 0 for i in range(get_rank(expr))])) + for (last_removed, ind), non_identity_indices in update_pairs.items(): + pos = [free_map[non_identity] + i for i, e in enumerate(non_identity_indices) if e == ind] + assert len(pos) == 1 + for j in pos: + permutation_map[j] = last_removed + + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + ret_expr = editor.to_array_contraction() + permutation = [] + counter = 0 + counter2 = 0 + for j in range(get_rank(expr)): + if j in removed: + continue + if counter2 in permutation_map: + target = permutation_map[counter2] + permutation.append(target - shifts[target]) + counter2 += 1 + else: + while counter in permutation_map.values(): + counter += 1 + permutation.append(counter) + counter += 1 + counter2 += 1 + ret_expr2 = _permute_dims(ret_expr, _af_invert(permutation)) + return ret_expr2, removed + + +def _combine_removed(dim: int, removed1: list[int], removed2: list[int]) -> list[int]: + # Concatenate two axis removal operations as performed by + # _remove_trivial_dims, + removed1 = sorted(removed1) + removed2 = sorted(removed2) + i = 0 + j = 0 + removed = [] + while True: + if j >= len(removed2): + while i < len(removed1): + removed.append(removed1[i]) + i += 1 + break + elif i < len(removed1) and removed1[i] <= i + removed2[j]: + removed.append(removed1[i]) + i += 1 + else: + removed.append(i + removed2[j]) + j += 1 + return removed + + +def _array_contraction_to_diagonal_multiple_identity(expr: ArrayContraction): + editor = _EditArrayContraction(expr) + editor.track_permutation_start() + removed: list[int] = [] + diag_index_counter: int = 0 + for i in range(editor.number_of_contraction_indices): + identities = [] + args = [] + for j, arg in enumerate(editor.args_with_ind): + if i not in arg.indices: + continue + if isinstance(arg.element, Identity): + identities.append(arg) + else: + args.append(arg) + if len(identities) == 0: + continue + if len(args) + len(identities) < 3: + continue + new_diag_ind = -1 - diag_index_counter + diag_index_counter += 1 + # Variable "flag" to control whether to skip this contraction set: + flag: bool = True + for i1, id1 in enumerate(identities): + if None not in id1.indices: + flag = True + break + free_pos = list(range(*editor.get_absolute_free_range(id1)))[0] + editor._track_permutation[-1].append(free_pos) # type: ignore + id1.element = None + flag = False + break + if flag: + continue + for arg in identities[:i1] + identities[i1+1:]: + arg.element = None + removed.extend(range(*editor.get_absolute_free_range(arg))) + for arg in args: + arg.indices = [new_diag_ind if j == i else j for j in arg.indices] + for j, e in enumerate(editor.args_with_ind): + if e.element is None: + editor._track_permutation[j] = None # type: ignore + editor._track_permutation = [i for i in editor._track_permutation if i is not None] # type: ignore + # Renumber permutation array form in order to deal with deleted positions: + remap = {e: i for i, e in enumerate(sorted({k for j in editor._track_permutation for k in j}))} + editor._track_permutation = [[remap[j] for j in i] for i in editor._track_permutation] + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + new_expr = editor.to_array_contraction() + return new_expr, removed diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..c219a205c4305bd7070e5117978146224521c58c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py @@ -0,0 +1,257 @@ +from collections import defaultdict + +from sympy import Function +from sympy.combinatorics.permutations import _af_invert +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.combinatorics import Permutation +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, \ + get_shape, ArrayElement, _array_tensor_product, _array_diagonal, _array_contraction, _array_add, \ + _permute_dims, OneArray, ArrayAdd +from sympy.tensor.array.expressions.utils import _get_argindex, _get_diagonal_indices + + +def convert_indexed_to_array(expr, first_indices=None): + r""" + Parse indexed expression into a form useful for code generation. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + >>> from sympy import MatrixSymbol, Sum, symbols + + >>> i, j, k, d = symbols("i j k d") + >>> M = MatrixSymbol("M", d, d) + >>> N = MatrixSymbol("N", d, d) + + Recognize the trace in summation form: + + >>> expr = Sum(M[i, i], (i, 0, d-1)) + >>> convert_indexed_to_array(expr) + ArrayContraction(M, (0, 1)) + + Recognize the extraction of the diagonal by using the same index `i` on + both axes of the matrix: + + >>> expr = M[i, i] + >>> convert_indexed_to_array(expr) + ArrayDiagonal(M, (0, 1)) + + This function can help perform the transformation expressed in two + different mathematical notations as: + + `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}` + + Recognize the matrix multiplication in summation form: + + >>> expr = Sum(M[i, j]*N[j, k], (j, 0, d-1)) + >>> convert_indexed_to_array(expr) + ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + + Specify that ``k`` has to be the starting index: + + >>> convert_indexed_to_array(expr, first_indices=[k]) + ArrayContraction(ArrayTensorProduct(N, M), (0, 3)) + """ + + result, indices = _convert_indexed_to_array(expr) + + if any(isinstance(i, (int, Integer)) for i in indices): + result = ArrayElement(result, indices) + indices = [] + + if not first_indices: + return result + + def _check_is_in(elem, indices): + if elem in indices: + return True + if any(elem in i for i in indices if isinstance(i, frozenset)): + return True + return False + + repl = {j: i for i in indices if isinstance(i, frozenset) for j in i} + first_indices = [repl.get(i, i) for i in first_indices] + for i in first_indices: + if not _check_is_in(i, indices): + first_indices.remove(i) + first_indices.extend([i for i in indices if not _check_is_in(i, first_indices)]) + + def _get_pos(elem, indices): + if elem in indices: + return indices.index(elem) + for i, e in enumerate(indices): + if not isinstance(e, frozenset): + continue + if elem in e: + return i + raise ValueError("not found") + + permutation = _af_invert([_get_pos(i, first_indices) for i in indices]) + if isinstance(result, ArrayAdd): + return _array_add(*[_permute_dims(arg, permutation) for arg in result.args]) + else: + return _permute_dims(result, permutation) + + +def _convert_indexed_to_array(expr): + if isinstance(expr, Sum): + function = expr.function + summation_indices = expr.variables + subexpr, subindices = _convert_indexed_to_array(function) + subindicessets = {j: i for i in subindices if isinstance(i, frozenset) for j in i} + summation_indices = sorted({subindicessets.get(i, i) for i in summation_indices}, key=default_sort_key) + # TODO: check that Kronecker delta is only contracted to one other element: + kronecker_indices = set() + if isinstance(function, Mul): + for arg in function.args: + if not isinstance(arg, KroneckerDelta): + continue + arg_indices = sorted(set(arg.indices), key=default_sort_key) + if len(arg_indices) == 2: + kronecker_indices.update(arg_indices) + kronecker_indices = sorted(kronecker_indices, key=default_sort_key) + # Check dimensional consistency: + shape = get_shape(subexpr) + if shape: + for ind, istart, iend in expr.limits: + i = _get_argindex(subindices, ind) + if istart != 0 or iend+1 != shape[i]: + raise ValueError("summation index and array dimension mismatch: %s" % ind) + contraction_indices = [] + subindices = list(subindices) + if isinstance(subexpr, ArrayDiagonal): + diagonal_indices = list(subexpr.diagonal_indices) + dindices = subindices[-len(diagonal_indices):] + subindices = subindices[:-len(diagonal_indices)] + for index in summation_indices: + if index in dindices: + position = dindices.index(index) + contraction_indices.append(diagonal_indices[position]) + diagonal_indices[position] = None + diagonal_indices = [i for i in diagonal_indices if i is not None] + for i, ind in enumerate(subindices): + if ind in summation_indices: + pass + if diagonal_indices: + subexpr = _array_diagonal(subexpr.expr, *diagonal_indices) + else: + subexpr = subexpr.expr + + axes_contraction = defaultdict(list) + for i, ind in enumerate(subindices): + include = all(j not in kronecker_indices for j in ind) if isinstance(ind, frozenset) else ind not in kronecker_indices + if ind in summation_indices and include: + axes_contraction[ind].append(i) + subindices[i] = None + for k, v in axes_contraction.items(): + if any(i in kronecker_indices for i in k) if isinstance(k, frozenset) else k in kronecker_indices: + continue + contraction_indices.append(tuple(v)) + free_indices = [i for i in subindices if i is not None] + indices_ret = list(free_indices) + indices_ret.sort(key=lambda x: free_indices.index(x)) + return _array_contraction( + subexpr, + *contraction_indices, + free_indices=free_indices + ), tuple(indices_ret) + if isinstance(expr, Mul): + args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args]) + # Check if there are KroneckerDelta objects: + kronecker_delta_repl = {} + for arg in args: + if not isinstance(arg, KroneckerDelta): + continue + # Diagonalize two indices: + i, j = arg.indices + kindices = set(arg.indices) + if i in kronecker_delta_repl: + kindices.update(kronecker_delta_repl[i]) + if j in kronecker_delta_repl: + kindices.update(kronecker_delta_repl[j]) + kindices = frozenset(kindices) + for index in kindices: + kronecker_delta_repl[index] = kindices + # Remove KroneckerDelta objects, their relations should be handled by + # ArrayDiagonal: + newargs = [] + newindices = [] + for arg, loc_indices in zip(args, indices): + if isinstance(arg, KroneckerDelta): + continue + newargs.append(arg) + newindices.append(loc_indices) + flattened_indices = [kronecker_delta_repl.get(j, j) for i in newindices for j in i] + diagonal_indices, ret_indices = _get_diagonal_indices(flattened_indices) + tp = _array_tensor_product(*newargs) + if diagonal_indices: + return _array_diagonal(tp, *diagonal_indices), ret_indices + else: + return tp, ret_indices + if isinstance(expr, MatrixElement): + indices = expr.args[1:] + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.args[0], *diagonal_indices), ret_indices + else: + return expr.args[0], ret_indices + if isinstance(expr, ArrayElement): + indices = expr.indices + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.name, *diagonal_indices), ret_indices + else: + return expr.name, ret_indices + if isinstance(expr, Indexed): + indices = expr.indices + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.base, *diagonal_indices), ret_indices + else: + return expr.args[0], ret_indices + if isinstance(expr, IndexedBase): + raise NotImplementedError + if isinstance(expr, KroneckerDelta): + return expr, expr.indices + if isinstance(expr, Add): + args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args]) + args = list(args) + # Check if all indices are compatible. Otherwise expand the dimensions: + index0 = [] + shape0 = [] + for arg, arg_indices in zip(args, indices): + arg_indices_set = set(arg_indices) + arg_indices_missing = arg_indices_set.difference(index0) + index0.extend([i for i in arg_indices if i in arg_indices_missing]) + arg_shape = get_shape(arg) + shape0.extend([arg_shape[i] for i, e in enumerate(arg_indices) if e in arg_indices_missing]) + for i, (arg, arg_indices) in enumerate(zip(args, indices)): + if len(arg_indices) < len(index0): + missing_indices_pos = [i for i, e in enumerate(index0) if e not in arg_indices] + missing_shape = [shape0[i] for i in missing_indices_pos] + arg_indices = tuple(index0[j] for j in missing_indices_pos) + arg_indices + args[i] = _array_tensor_product(OneArray(*missing_shape), args[i]) + permutation = Permutation([arg_indices.index(j) for j in index0]) + # Perform index permutations: + args[i] = _permute_dims(args[i], permutation) + return _array_add(*args), tuple(index0) + if isinstance(expr, Pow): + subexpr, subindices = _convert_indexed_to_array(expr.base) + if isinstance(expr.exp, (int, Integer)): + diags = zip(*[(2*i, 2*i + 1) for i in range(expr.exp)]) + arr = _array_diagonal(_array_tensor_product(*[subexpr for i in range(expr.exp)]), *diags) + return arr, subindices + if isinstance(expr, Function): + subexpr, subindices = _convert_indexed_to_array(expr.args[0]) + return ArrayElementwiseApplyFunc(type(expr), subexpr), subindices + return expr, () diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..8f66961727f6338318d65876a7768802773e4f2d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py @@ -0,0 +1,87 @@ +from sympy import KroneckerProduct +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct) +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.tensor.array.expressions.array_expressions import \ + ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \ + _array_diagonal, _array_add, _permute_dims, Reshape + + +def convert_matrix_to_array(expr: Basic) -> Basic: + if isinstance(expr, MatMul): + args_nonmat = [] + args = [] + for arg in expr.args: + if isinstance(arg, MatrixExpr): + args.append(arg) + else: + args_nonmat.append(convert_matrix_to_array(arg)) + contractions = [(2*i+1, 2*i+2) for i in range(len(args)-1)] + scalar = _array_tensor_product(*args_nonmat) if args_nonmat else S.One + if scalar == 1: + tprod = _array_tensor_product( + *[convert_matrix_to_array(arg) for arg in args]) + else: + tprod = _array_tensor_product( + scalar, + *[convert_matrix_to_array(arg) for arg in args]) + return _array_contraction( + tprod, + *contractions + ) + elif isinstance(expr, MatAdd): + return _array_add( + *[convert_matrix_to_array(arg) for arg in expr.args] + ) + elif isinstance(expr, Transpose): + return _permute_dims( + convert_matrix_to_array(expr.args[0]), [1, 0] + ) + elif isinstance(expr, Trace): + inner_expr: MatrixExpr = convert_matrix_to_array(expr.arg) # type: ignore + return _array_contraction(inner_expr, (0, len(inner_expr.shape) - 1)) + elif isinstance(expr, Mul): + return _array_tensor_product(*[convert_matrix_to_array(i) for i in expr.args]) + elif isinstance(expr, Pow): + base = convert_matrix_to_array(expr.base) + if (expr.exp > 0) == True: + return _array_tensor_product(*[base for i in range(expr.exp)]) + else: + return expr + elif isinstance(expr, MatPow): + base = convert_matrix_to_array(expr.base) + if expr.exp.is_Integer != True: + b = symbols("b", cls=Dummy) + return ArrayElementwiseApplyFunc(Lambda(b, b**expr.exp), convert_matrix_to_array(base)) + elif (expr.exp > 0) == True: + return convert_matrix_to_array(MatMul.fromiter(base for i in range(expr.exp))) + else: + return expr + elif isinstance(expr, HadamardProduct): + tp = _array_tensor_product(*[convert_matrix_to_array(arg) for arg in expr.args]) + diag = [[2*i for i in range(len(expr.args))], [2*i+1 for i in range(len(expr.args))]] + return _array_diagonal(tp, *diag) + elif isinstance(expr, HadamardPower): + base, exp = expr.args + if isinstance(exp, Integer) and exp > 0: + return convert_matrix_to_array(HadamardProduct.fromiter(base for i in range(exp))) + else: + d = Dummy("d") + return ArrayElementwiseApplyFunc(Lambda(d, d**exp), base) + elif isinstance(expr, KroneckerProduct): + kp_args = [convert_matrix_to_array(arg) for arg in expr.args] + permutation = [2*i for i in range(len(kp_args))] + [2*i + 1 for i in range(len(kp_args))] + return Reshape(_permute_dims(_array_tensor_product(*kp_args), permutation), expr.shape) + else: + return expr diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__init__.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34ff8b8a4b18e4483bdeb138b92bb04bbe1928ec Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d98bc7c7b44b7cc3d9e3b2d92153f217ca2878b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_array_expressions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82fc7e7a01443cd9e33f496c4a5a2edcf0ff1c14 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_arrayexpr_derivatives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793b99c7e4c9a812bc7111a166008214e0178b18 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_as_explicit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b45594a5f87af0283ad405033b17a476a0cdeb2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_indexed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0381d8b8f7fb3485f0120f69439ec32eef0d108 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_array_to_matrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2122995499421ba43cc735f139b280f3ee31b17 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_indexed_to_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2486c37e7070df2adc9c853ddd903b341c0bbf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_convert_matrix_to_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9010736a3b03e26c5e7b140858d21b96e3977506 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/__pycache__/test_deprecated_conv_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..63fb79ab7ced7bff5ecb55b1764f43e29f98609d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py @@ -0,0 +1,808 @@ +import random + +from sympy import tensordiagonal, eye, KroneckerDelta, Array +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.combinatorics import Permutation +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, ArrayElement, \ + PermuteDims, ArrayContraction, ArrayTensorProduct, ArrayDiagonal, \ + ArrayAdd, nest_permutation, ArrayElementwiseApplyFunc, _EditArrayContraction, _ArgE, _array_tensor_product, \ + _array_contraction, _array_diagonal, _array_add, _permute_dims, Reshape +from sympy.testing.pytest import raises + +i, j, k, l, m, n = symbols("i j k l m n") + + +M = ArraySymbol("M", (k, k)) +N = ArraySymbol("N", (k, k)) +P = ArraySymbol("P", (k, k)) +Q = ArraySymbol("Q", (k, k)) + +A = ArraySymbol("A", (k, k)) +B = ArraySymbol("B", (k, k)) +C = ArraySymbol("C", (k, k)) +D = ArraySymbol("D", (k, k)) + +X = ArraySymbol("X", (k, k)) +Y = ArraySymbol("Y", (k, k)) + +a = ArraySymbol("a", (k, 1)) +b = ArraySymbol("b", (k, 1)) +c = ArraySymbol("c", (k, 1)) +d = ArraySymbol("d", (k, 1)) + + +def test_array_symbol_and_element(): + A = ArraySymbol("A", (2,)) + A0 = ArrayElement(A, (0,)) + A1 = ArrayElement(A, (1,)) + assert A[0] == A0 + assert A[1] != A0 + assert A.as_explicit() == ImmutableDenseNDimArray([A0, A1]) + + A2 = tensorproduct(A, A) + assert A2.shape == (2, 2) + # TODO: not yet supported: + # assert A2.as_explicit() == Array([[A[0]*A[0], A[1]*A[0]], [A[0]*A[1], A[1]*A[1]]]) + A3 = tensorcontraction(A2, (0, 1)) + assert A3.shape == () + # TODO: not yet supported: + # assert A3.as_explicit() == Array([]) + + A = ArraySymbol("A", (2, 3, 4)) + Ae = A.as_explicit() + assert Ae == ImmutableDenseNDimArray( + [[[ArrayElement(A, (i, j, k)) for k in range(4)] for j in range(3)] for i in range(2)]) + + p = _permute_dims(A, Permutation(0, 2, 1)) + assert isinstance(p, PermuteDims) + + A = ArraySymbol("A", (2,)) + raises(IndexError, lambda: A[()]) + raises(IndexError, lambda: A[0, 1]) + raises(ValueError, lambda: A[-1]) + raises(ValueError, lambda: A[2]) + + O = OneArray(3, 4) + Z = ZeroArray(m, n) + + raises(IndexError, lambda: O[()]) + raises(IndexError, lambda: O[1, 2, 3]) + raises(ValueError, lambda: O[3, 0]) + raises(ValueError, lambda: O[0, 4]) + + assert O[1, 2] == 1 + assert Z[1, 2] == 0 + + +def test_zero_array(): + assert ZeroArray() == 0 + assert ZeroArray().is_Integer + + za = ZeroArray(3, 2, 4) + assert za.shape == (3, 2, 4) + za_e = za.as_explicit() + assert za_e.shape == (3, 2, 4) + + m, n, k = symbols("m n k") + za = ZeroArray(m, n, k, 2) + assert za.shape == (m, n, k, 2) + raises(ValueError, lambda: za.as_explicit()) + + +def test_one_array(): + assert OneArray() == 1 + assert OneArray().is_Integer + + oa = OneArray(3, 2, 4) + assert oa.shape == (3, 2, 4) + oa_e = oa.as_explicit() + assert oa_e.shape == (3, 2, 4) + + m, n, k = symbols("m n k") + oa = OneArray(m, n, k, 2) + assert oa.shape == (m, n, k, 2) + raises(ValueError, lambda: oa.as_explicit()) + + +def test_arrayexpr_contraction_construction(): + + cg = _array_contraction(A) + assert cg == A + + cg = _array_contraction(_array_tensor_product(A, B), (1, 0)) + assert cg == _array_contraction(_array_tensor_product(A, B), (0, 1)) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 1)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 0), (0, 1)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 1)] + + cg = _array_contraction(_array_tensor_product(M, N), (1, 2)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 1), (1, 0)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(1, 2)] + + cg = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 0), (1, 1)], [(0, 1), (2, 0)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 3), (1, 4)] + + # Test removal of trivial contraction: + assert _array_contraction(a, (1,)) == a + assert _array_contraction( + _array_tensor_product(a, b), (0, 2), (1,), (3,)) == _array_contraction( + _array_tensor_product(a, b), (0, 2)) + + +def test_arrayexpr_array_flatten(): + + # Flatten nested ArrayTensorProduct objects: + expr1 = _array_tensor_product(M, N) + expr2 = _array_tensor_product(P, Q) + expr = _array_tensor_product(expr1, expr2) + assert expr == _array_tensor_product(M, N, P, Q) + assert expr.args == (M, N, P, Q) + + # Flatten mixed ArrayTensorProduct and ArrayContraction objects: + cg1 = _array_contraction(expr1, (1, 2)) + cg2 = _array_contraction(expr2, (0, 3)) + + expr = _array_tensor_product(cg1, cg2) + assert expr == _array_contraction(_array_tensor_product(M, N, P, Q), (1, 2), (4, 7)) + + expr = _array_tensor_product(M, cg1) + assert expr == _array_contraction(_array_tensor_product(M, M, N), (3, 4)) + + # Flatten nested ArrayContraction objects: + cgnested = _array_contraction(cg1, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2)) + + cgnested = _array_contraction(_array_tensor_product(cg1, cg2), (0, 3)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 6), (1, 2), (4, 7)) + + cg3 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4)) + cgnested = _array_contraction(cg3, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 5), (1, 3), (2, 4)) + + cgnested = _array_contraction(cg3, (0, 3), (1, 2)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 7), (1, 3), (2, 4), (5, 6)) + + cg4 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7)) + cgnested = _array_contraction(cg4, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7)) + + cgnested = _array_contraction(cg4, (0, 1), (2, 3)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7), (4, 6)) + + cg = _array_diagonal(cg4) + assert cg == cg4 + assert isinstance(cg, type(cg4)) + + # Flatten nested ArrayDiagonal objects: + cg1 = _array_diagonal(expr1, (1, 2)) + cg2 = _array_diagonal(expr2, (0, 3)) + cg3 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4)) + cg4 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7)) + + cgnested = _array_diagonal(cg1, (0, 1)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N), (1, 2), (0, 3)) + + cgnested = _array_diagonal(cg3, (1, 2)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4), (5, 6)) + + cgnested = _array_diagonal(cg4, (1, 2)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7), (2, 4)) + + cg = _array_add(M, N) + cg2 = _array_add(cg, P) + assert isinstance(cg2, ArrayAdd) + assert cg2.args == (M, N, P) + assert cg2.shape == (k, k) + + expr = _array_tensor_product(_array_diagonal(X, (0, 1)), _array_diagonal(A, (0, 1))) + assert expr == _array_diagonal(_array_tensor_product(X, A), (0, 1), (2, 3)) + + expr1 = _array_diagonal(_array_tensor_product(X, A), (1, 2)) + expr2 = _array_tensor_product(expr1, a) + assert expr2 == _permute_dims(_array_diagonal(_array_tensor_product(X, A, a), (1, 2)), [0, 1, 4, 2, 3]) + + expr1 = _array_contraction(_array_tensor_product(X, A), (1, 2)) + expr2 = _array_tensor_product(expr1, a) + assert isinstance(expr2, ArrayContraction) + assert isinstance(expr2.expr, ArrayTensorProduct) + + cg = _array_tensor_product(_array_diagonal(_array_tensor_product(A, X, Y), (0, 3), (1, 5)), a, b) + assert cg == _permute_dims(_array_diagonal(_array_tensor_product(A, X, Y, a, b), (0, 3), (1, 5)), [0, 1, 6, 7, 2, 3, 4, 5]) + + +def test_arrayexpr_array_diagonal(): + cg = _array_diagonal(M, (1, 0)) + assert cg == _array_diagonal(M, (0, 1)) + + cg = _array_diagonal(_array_tensor_product(M, N, P), (4, 1), (2, 0)) + assert cg == _array_diagonal(_array_tensor_product(M, N, P), (1, 4), (0, 2)) + + cg = _array_diagonal(_array_tensor_product(M, N), (1, 2), (3,), allow_trivial_diags=True) + assert cg == _permute_dims(_array_diagonal(_array_tensor_product(M, N), (1, 2)), [0, 2, 1]) + + Ax = ArraySymbol("Ax", shape=(1, 2, 3, 4, 3, 5, 6, 2, 7)) + cg = _array_diagonal(Ax, (1, 7), (3,), (2, 4), (6,), allow_trivial_diags=True) + assert cg == _permute_dims(_array_diagonal(Ax, (1, 7), (2, 4)), [0, 2, 4, 5, 1, 6, 3]) + + cg = _array_diagonal(M, (0,), allow_trivial_diags=True) + assert cg == _permute_dims(M, [1, 0]) + + raises(ValueError, lambda: _array_diagonal(M, (0, 0))) + + +def test_arrayexpr_array_shape(): + expr = _array_tensor_product(M, N, P, Q) + assert expr.shape == (k, k, k, k, k, k, k, k) + Z = MatrixSymbol("Z", m, n) + expr = _array_tensor_product(M, Z) + assert expr.shape == (k, k, m, n) + expr2 = _array_contraction(expr, (0, 1)) + assert expr2.shape == (m, n) + expr2 = _array_diagonal(expr, (0, 1)) + assert expr2.shape == (m, n, k) + exprp = _permute_dims(expr, [2, 1, 3, 0]) + assert exprp.shape == (m, k, n, k) + expr3 = _array_tensor_product(N, Z) + expr2 = _array_add(expr, expr3) + assert expr2.shape == (k, k, m, n) + + # Contraction along axes with discordant dimensions: + raises(ValueError, lambda: _array_contraction(expr, (1, 2))) + # Also diagonal needs the same dimensions: + raises(ValueError, lambda: _array_diagonal(expr, (1, 2))) + # Diagonal requires at least to axes to compute the diagonal: + raises(ValueError, lambda: _array_diagonal(expr, (1,))) + + +def test_arrayexpr_permutedims_sink(): + + cg = _permute_dims(_array_tensor_product(M, N), [0, 1, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(M, _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_tensor_product(M, N), [3, 2, 1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(N, [1, 0]), _permute_dims(M, [1, 0])) + + cg = _permute_dims(_array_contraction(_array_tensor_product(M, N), (1, 2)), [1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N), [[0, 3]]), (1, 2)) + + cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_contraction(_array_tensor_product(M, N, P), (1, 2), (3, 4)), [1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N, P), [[0, 5]]), (1, 2), (3, 4)) + + +def test_arrayexpr_push_indices_up_and_down(): + + indices = list(range(12)) + + contr_diag_indices = [(0, 6), (2, 8)] + assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (1, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14, 15) + assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (None, 0, None, 1, 2, 3, None, 4, None, 5, 6, 7) + + assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (1, 3, 4, 5, 7, 9, (0, 6), (2, 8), None, None, None, None) + assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (6, 0, 7, 1, 2, 3, 6, 4, 7, 5, None, None) + + contr_diag_indices = [(1, 2), (7, 8)] + assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (0, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15) + assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (0, None, None, 1, 2, 3, 4, None, None, 5, 6, 7) + + assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (0, 3, 4, 5, 6, 9, (1, 2), (7, 8), None, None, None, None) + assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (0, 6, 6, 1, 2, 3, 4, 7, 7, 5, None, None) + + +def test_arrayexpr_split_multiple_contractions(): + a = MatrixSymbol("a", k, 1) + b = MatrixSymbol("b", k, 1) + A = MatrixSymbol("A", k, k) + B = MatrixSymbol("B", k, k) + C = MatrixSymbol("C", k, k) + X = MatrixSymbol("X", k, k) + + cg = _array_contraction(_array_tensor_product(A.T, a, b, b.T, (A*X*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9)) + expected = _array_contraction(_array_tensor_product(A.T, DiagMatrix(a), OneArray(1), b, b.T, (A*X*b).applyfunc(cos)), (1, 3), (2, 9), (6, 7, 10)) + assert cg.split_multiple_contractions().dummy_eq(expected) + + # Check no overlap of lines: + + cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8), (3, 7)) + assert cg.split_multiple_contractions() == cg + + cg = _array_contraction(_array_tensor_product(a, b, A), (0, 2, 4), (1, 3)) + assert cg.split_multiple_contractions() == cg + + +def test_arrayexpr_nested_permutations(): + + cg = _permute_dims(_permute_dims(M, (1, 0)), (1, 0)) + assert cg == M + + times = 3 + plist1 = [list(range(6)) for i in range(times)] + plist2 = [list(range(6)) for i in range(times)] + + for i in range(times): + random.shuffle(plist1[i]) + random.shuffle(plist2[i]) + + plist1.append([2, 5, 4, 1, 0, 3]) + plist2.append([3, 5, 0, 4, 1, 2]) + + plist1.append([2, 5, 4, 0, 3, 1]) + plist2.append([3, 0, 5, 1, 2, 4]) + + plist1.append([5, 4, 2, 0, 3, 1]) + plist2.append([4, 5, 0, 2, 3, 1]) + + Me = M.subs(k, 3).as_explicit() + Ne = N.subs(k, 3).as_explicit() + Pe = P.subs(k, 3).as_explicit() + cge = tensorproduct(Me, Ne, Pe) + + for permutation_array1, permutation_array2 in zip(plist1, plist2): + p1 = Permutation(permutation_array1) + p2 = Permutation(permutation_array2) + + cg = _permute_dims( + _permute_dims( + _array_tensor_product(M, N, P), + p1), + p2 + ) + result = _permute_dims( + _array_tensor_product(M, N, P), + p2*p1 + ) + assert cg == result + + # Check that `permutedims` behaves the same way with explicit-component arrays: + result1 = _permute_dims(_permute_dims(cge, p1), p2) + result2 = _permute_dims(cge, p2*p1) + assert result1 == result2 + + +def test_arrayexpr_contraction_permutation_mix(): + + Me = M.subs(k, 3).as_explicit() + Ne = N.subs(k, 3).as_explicit() + + cg1 = _array_contraction(PermuteDims(_array_tensor_product(M, N), Permutation([0, 2, 1, 3])), (2, 3)) + cg2 = _array_contraction(_array_tensor_product(M, N), (1, 3)) + assert cg1 == cg2 + cge1 = tensorcontraction(permutedims(tensorproduct(Me, Ne), Permutation([0, 2, 1, 3])), (2, 3)) + cge2 = tensorcontraction(tensorproduct(Me, Ne), (1, 3)) + assert cge1 == cge2 + + cg1 = _permute_dims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2])) + cg2 = _array_tensor_product(M, _permute_dims(N, Permutation([1, 0]))) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])), + (1, 2), (3, 5) + ) + cg2 = _array_contraction( + _array_tensor_product(M, N, P, _permute_dims(Q, Permutation([1, 0]))), + (1, 5), (2, 3) + ) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 2, 7, 5, 3])), + (0, 1), (2, 6), (3, 7) + ) + cg2 = _permute_dims( + _array_contraction( + _array_tensor_product(M, P, Q, N), + (0, 1), (2, 3), (4, 7)), + [1, 0] + ) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 7, 2, 5, 3])), + (0, 1), (2, 6), (3, 7) + ) + cg2 = _permute_dims( + _array_contraction( + _array_tensor_product(_permute_dims(M, [1, 0]), N, P, Q), + (0, 1), (3, 6), (4, 5) + ), + Permutation([1, 0]) + ) + assert cg1 == cg2 + + +def test_arrayexpr_permute_tensor_product(): + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 1, 0, 5, 4, 6, 7])) + cg2 = _array_tensor_product(N, _permute_dims(M, [1, 0]), + _permute_dims(P, [1, 0]), Q) + assert cg1 == cg2 + + # TODO: reverse operation starting with `PermuteDims` and getting down to `bb`... + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 5, 0, 1, 6, 7])) + cg2 = _array_tensor_product(N, P, M, Q) + assert cg1 == cg2 + + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 6, 5, 7, 0, 1])) + assert cg1.expr == _array_tensor_product(N, P, Q, M) + assert cg1.permutation == Permutation([0, 1, 2, 4, 3, 5, 6, 7]) + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(N, Q, Q, M), + [2, 1, 5, 4, 0, 3, 6, 7]), + [1, 2, 6]) + cg2 = _permute_dims(_array_contraction(_array_tensor_product(Q, Q, N, M), (3, 5, 6)), [0, 2, 3, 1, 4]) + assert cg1 == cg2 + + cg1 = _array_contraction( + _array_contraction( + _array_contraction( + _array_contraction( + _permute_dims( + _array_tensor_product(N, Q, Q, M), + [2, 1, 5, 4, 0, 3, 6, 7]), + [1, 2, 6]), + [1, 3, 4]), + [1]), + [0]) + cg2 = _array_contraction(_array_tensor_product(M, N, Q, Q), (0, 3, 5), (1, 4, 7), (2,), (6,)) + assert cg1 == cg2 + + +def test_arrayexpr_canonicalize_diagonal__permute_dims(): + tp = _array_tensor_product(M, Q, N, P) + expr = _array_diagonal( + _permute_dims(tp, [0, 1, 2, 4, 7, 6, 3, 5]), (2, 4, 5), (6, 7), + (0, 3)) + result = _array_diagonal(tp, (2, 6, 7), (3, 5), (0, 4)) + assert expr == result + + tp = _array_tensor_product(M, N, P, Q) + expr = _array_diagonal(_permute_dims(tp, [0, 5, 2, 4, 1, 6, 3, 7]), (1, 2, 6), (3, 4)) + result = _array_diagonal(_array_tensor_product(M, P, N, Q), (3, 4, 5), (1, 2)) + assert expr == result + + +def test_arrayexpr_canonicalize_diagonal_contraction(): + tp = _array_tensor_product(M, N, P, Q) + expr = _array_contraction(_array_diagonal(tp, (1, 3, 4)), (0, 3)) + result = _array_diagonal(_array_contraction(_array_tensor_product(M, N, P, Q), (0, 6)), (0, 2, 3)) + assert expr == result + + expr = _array_contraction(_array_diagonal(tp, (0, 1, 2, 3, 7)), (1, 2, 3)) + result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 2, 3, 5, 6, 7)) + assert expr == result + + expr = _array_contraction(_array_diagonal(tp, (0, 2, 6, 7)), (1, 2, 3)) + result = _array_diagonal(_array_contraction(tp, (3, 4, 5)), (0, 2, 3, 4)) + assert expr == result + + td = _array_diagonal(_array_tensor_product(M, N, P, Q), (0, 3)) + expr = _array_contraction(td, (2, 1), (0, 4, 6, 5, 3)) + result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 3, 5, 6, 7), (2, 4)) + assert expr == result + + +def test_arrayexpr_array_wrong_permutation_size(): + cg = _array_tensor_product(M, N) + raises(ValueError, lambda: _permute_dims(cg, [1, 0])) + raises(ValueError, lambda: _permute_dims(cg, [1, 0, 2, 3, 5, 4])) + + +def test_arrayexpr_nested_array_elementwise_add(): + cg = _array_contraction(_array_add( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ), (1, 2)) + result = _array_add( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)) + ) + assert cg == result + + cg = _array_diagonal(_array_add( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ), (1, 2)) + result = _array_add( + _array_diagonal(_array_tensor_product(M, N), (1, 2)), + _array_diagonal(_array_tensor_product(N, M), (1, 2)) + ) + assert cg == result + + +def test_arrayexpr_array_expr_zero_array(): + za1 = ZeroArray(k, l, m, n) + zm1 = ZeroMatrix(m, n) + + za2 = ZeroArray(k, m, m, n) + zm2 = ZeroMatrix(m, m) + zm3 = ZeroMatrix(k, k) + + assert _array_tensor_product(M, N, za1) == ZeroArray(k, k, k, k, k, l, m, n) + assert _array_tensor_product(M, N, zm1) == ZeroArray(k, k, k, k, m, n) + + assert _array_contraction(za1, (3,)) == ZeroArray(k, l, m) + assert _array_contraction(zm1, (1,)) == ZeroArray(m) + assert _array_contraction(za2, (1, 2)) == ZeroArray(k, n) + assert _array_contraction(zm2, (0, 1)) == 0 + + assert _array_diagonal(za2, (1, 2)) == ZeroArray(k, n, m) + assert _array_diagonal(zm2, (0, 1)) == ZeroArray(m) + + assert _permute_dims(za1, [2, 1, 3, 0]) == ZeroArray(m, l, n, k) + assert _permute_dims(zm1, [1, 0]) == ZeroArray(n, m) + + assert _array_add(za1) == za1 + assert _array_add(zm1) == ZeroArray(m, n) + tp1 = _array_tensor_product(MatrixSymbol("A", k, l), MatrixSymbol("B", m, n)) + assert _array_add(tp1, za1) == tp1 + tp2 = _array_tensor_product(MatrixSymbol("C", k, l), MatrixSymbol("D", m, n)) + assert _array_add(tp1, za1, tp2) == _array_add(tp1, tp2) + assert _array_add(M, zm3) == M + assert _array_add(M, N, zm3) == _array_add(M, N) + + +def test_arrayexpr_array_expr_applyfunc(): + + A = ArraySymbol("A", (3, k, 2)) + aaf = ArrayElementwiseApplyFunc(sin, A) + assert aaf.shape == (3, k, 2) + + +def test_edit_array_contraction(): + cg = _array_contraction(_array_tensor_product(A, B, C, D), (1, 2, 5)) + ecg = _EditArrayContraction(cg) + assert ecg.to_array_contraction() == cg + + ecg.args_with_ind[1], ecg.args_with_ind[2] = ecg.args_with_ind[2], ecg.args_with_ind[1] + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, B, D), (1, 3, 4)) + + ci = ecg.get_new_contraction_index() + new_arg = _ArgE(X) + new_arg.indices = [ci, ci] + ecg.args_with_ind.insert(2, new_arg) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, X, B, D), (1, 3, 6), (4, 5)) + + assert ecg.get_contraction_indices() == [[1, 3, 6], [4, 5]] + assert [[tuple(j) for j in i] for i in ecg.get_contraction_indices_to_ind_rel_pos()] == [[(0, 1), (1, 1), (3, 0)], [(2, 0), (2, 1)]] + assert [list(i) for i in ecg.get_mapping_for_index(0)] == [[0, 1], [1, 1], [3, 0]] + assert [list(i) for i in ecg.get_mapping_for_index(1)] == [[2, 0], [2, 1]] + raises(ValueError, lambda: ecg.get_mapping_for_index(2)) + + ecg.args_with_ind.pop(1) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 4), (2, 3)) + + ecg.args_with_ind[0].indices[1] = ecg.args_with_ind[1].indices[0] + ecg.args_with_ind[1].indices[1] = ecg.args_with_ind[2].indices[0] + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 2), (3, 4)) + + ecg.insert_after(ecg.args_with_ind[1], _ArgE(C)) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, C, B, D), (1, 2), (3, 6)) + + +def test_array_expressions_no_canonicalization(): + + tp = _array_tensor_product(M, N, P) + + # ArrayTensorProduct: + + expr = ArrayTensorProduct(tp, N) + assert str(expr) == "ArrayTensorProduct(ArrayTensorProduct(M, N, P), N)" + assert expr.doit() == ArrayTensorProduct(M, N, P, N) + + expr = ArrayTensorProduct(ArrayContraction(M, (0, 1)), N) + assert str(expr) == "ArrayTensorProduct(ArrayContraction(M, (0, 1)), N)" + assert expr.doit() == ArrayContraction(ArrayTensorProduct(M, N), (0, 1)) + + expr = ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N) + assert str(expr) == "ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N)" + assert expr.doit() == PermuteDims(ArrayDiagonal(ArrayTensorProduct(M, N), (0, 1)), [2, 0, 1]) + + expr = ArrayTensorProduct(PermuteDims(M, [1, 0]), N) + assert str(expr) == "ArrayTensorProduct(PermuteDims(M, (0 1)), N)" + assert expr.doit() == PermuteDims(ArrayTensorProduct(M, N), [1, 0, 2, 3]) + + # ArrayContraction: + + expr = ArrayContraction(_array_contraction(tp, (0, 2)), (0, 1)) + assert isinstance(expr, ArrayContraction) + assert isinstance(expr.expr, ArrayContraction) + assert str(expr) == "ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))" + assert expr.doit() == ArrayContraction(tp, (0, 2), (1, 3)) + + expr = ArrayContraction(ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1)), (0, 1)) + assert expr.doit() == ArrayContraction(tp, (0, 1), (2, 3), (4, 5)) + # assert expr._canonicalize() == ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1), (2, 3)) + + expr = ArrayContraction(ArrayDiagonal(tp, (0, 1)), (0, 1)) + assert str(expr) == "ArrayContraction(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))" + assert expr.doit() == ArrayDiagonal(ArrayContraction(ArrayTensorProduct(N, M, P), (0, 1)), (0, 1)) + + expr = ArrayContraction(PermuteDims(M, [1, 0]), (0, 1)) + assert str(expr) == "ArrayContraction(PermuteDims(M, (0 1)), (0, 1))" + assert expr.doit() == ArrayContraction(M, (0, 1)) + + # ArrayDiagonal: + + expr = ArrayDiagonal(ArrayDiagonal(tp, (0, 2)), (0, 1)) + assert str(expr) == "ArrayDiagonal(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))" + assert expr.doit() == ArrayDiagonal(tp, (0, 2), (1, 3)) + + expr = ArrayDiagonal(ArrayDiagonal(ArrayDiagonal(tp, (0, 1)), (0, 1)), (0, 1)) + assert expr.doit() == ArrayDiagonal(tp, (0, 1), (2, 3), (4, 5)) + assert expr._canonicalize() == expr.doit() + + expr = ArrayDiagonal(ArrayContraction(tp, (0, 1)), (0, 1)) + assert str(expr) == "ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))" + assert expr.doit() == expr + + expr = ArrayDiagonal(PermuteDims(M, [1, 0]), (0, 1)) + assert str(expr) == "ArrayDiagonal(PermuteDims(M, (0 1)), (0, 1))" + assert expr.doit() == ArrayDiagonal(M, (0, 1)) + + # ArrayAdd: + + expr = ArrayAdd(M) + assert isinstance(expr, ArrayAdd) + assert expr.doit() == M + + expr = ArrayAdd(ArrayAdd(M, N), P) + assert str(expr) == "ArrayAdd(ArrayAdd(M, N), P)" + assert expr.doit() == ArrayAdd(M, N, P) + + expr = ArrayAdd(M, ArrayAdd(N, ArrayAdd(P, M))) + assert expr.doit() == ArrayAdd(M, N, P, M) + assert expr._canonicalize() == ArrayAdd(M, N, ArrayAdd(P, M)) + + expr = ArrayAdd(M, ZeroArray(k, k), N) + assert str(expr) == "ArrayAdd(M, ZeroArray(k, k), N)" + assert expr.doit() == ArrayAdd(M, N) + + # PermuteDims: + + expr = PermuteDims(PermuteDims(M, [1, 0]), [1, 0]) + assert str(expr) == "PermuteDims(PermuteDims(M, (0 1)), (0 1))" + assert expr.doit() == M + + expr = PermuteDims(PermuteDims(PermuteDims(M, [1, 0]), [1, 0]), [1, 0]) + assert expr.doit() == PermuteDims(M, [1, 0]) + assert expr._canonicalize() == expr.doit() + + # Reshape + + expr = Reshape(A, (k**2,)) + assert expr.shape == (k**2,) + assert isinstance(expr, Reshape) + + +def test_array_expr_construction_with_functions(): + + tp = tensorproduct(M, N) + assert tp == ArrayTensorProduct(M, N) + + expr = tensorproduct(A, eye(2)) + assert expr == ArrayTensorProduct(A, eye(2)) + + # Contraction: + + expr = tensorcontraction(M, (0, 1)) + assert expr == ArrayContraction(M, (0, 1)) + + expr = tensorcontraction(tp, (1, 2)) + assert expr == ArrayContraction(tp, (1, 2)) + + expr = tensorcontraction(tensorcontraction(tp, (1, 2)), (0, 1)) + assert expr == ArrayContraction(tp, (0, 3), (1, 2)) + + # Diagonalization: + + expr = tensordiagonal(M, (0, 1)) + assert expr == ArrayDiagonal(M, (0, 1)) + + expr = tensordiagonal(tensordiagonal(tp, (0, 1)), (0, 1)) + assert expr == ArrayDiagonal(tp, (0, 1), (2, 3)) + + # Permutation of dimensions: + + expr = permutedims(M, [1, 0]) + assert expr == PermuteDims(M, [1, 0]) + + expr = permutedims(PermuteDims(tp, [1, 0, 2, 3]), [0, 1, 3, 2]) + assert expr == PermuteDims(tp, [1, 0, 3, 2]) + + expr = PermuteDims(tp, index_order_new=["a", "b", "c", "d"], index_order_old=["d", "c", "b", "a"]) + assert expr == PermuteDims(tp, [3, 2, 1, 0]) + + arr = Array(range(32)).reshape(2, 2, 2, 2, 2) + expr = PermuteDims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c']) + assert expr == PermuteDims(arr, [2, 0, 4, 3, 1]) + assert expr.as_explicit() == permutedims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c']) + + +def test_array_element_expressions(): + # Check commutative property: + assert M[0, 0]*N[0, 0] == N[0, 0]*M[0, 0] + + # Check derivatives: + assert M[0, 0].diff(M[0, 0]) == 1 + assert M[0, 0].diff(M[1, 0]) == 0 + assert M[0, 0].diff(N[0, 0]) == 0 + assert M[0, 1].diff(M[i, j]) == KroneckerDelta(i, 0)*KroneckerDelta(j, 1) + assert M[0, 1].diff(N[i, j]) == 0 + + K4 = ArraySymbol("K4", shape=(k, k, k, k)) + + assert K4[i, j, k, l].diff(K4[1, 2, 3, 4]) == ( + KroneckerDelta(i, 1)*KroneckerDelta(j, 2)*KroneckerDelta(k, 3)*KroneckerDelta(l, 4) + ) + + +def test_array_expr_reshape(): + + A = MatrixSymbol("A", 2, 2) + B = ArraySymbol("B", (2, 2, 2)) + C = Array([1, 2, 3, 4]) + + expr = Reshape(A, (4,)) + assert expr.expr == A + assert expr.shape == (4,) + assert expr.as_explicit() == Array([A[0, 0], A[0, 1], A[1, 0], A[1, 1]]) + + expr = Reshape(B, (2, 4)) + assert expr.expr == B + assert expr.shape == (2, 4) + ee = expr.as_explicit() + assert isinstance(ee, ImmutableDenseNDimArray) + assert ee.shape == (2, 4) + assert ee == Array([[B[0, 0, 0], B[0, 0, 1], B[0, 1, 0], B[0, 1, 1]], [B[1, 0, 0], B[1, 0, 1], B[1, 1, 0], B[1, 1, 1]]]) + + expr = Reshape(A, (k, 2)) + assert expr.shape == (k, 2) + + raises(ValueError, lambda: Reshape(A, (2, 3))) + raises(ValueError, lambda: Reshape(A, (3,))) + + expr = Reshape(C, (2, 2)) + assert expr.expr == C + assert expr.shape == (2, 2) + assert expr.doit() == Array([[1, 2], [3, 4]]) + + +def test_array_expr_as_explicit_with_explicit_component_arrays(): + # Test if .as_explicit() works with explicit-component arrays + # nested in array expressions: + from sympy.abc import x, y, z, t + A = Array([[x, y], [z, t]]) + assert ArrayTensorProduct(A, A).as_explicit() == tensorproduct(A, A) + assert ArrayDiagonal(A, (0, 1)).as_explicit() == tensordiagonal(A, (0, 1)) + assert ArrayContraction(A, (0, 1)).as_explicit() == tensorcontraction(A, (0, 1)) + assert ArrayAdd(A, A).as_explicit() == A + A + assert ArrayElementwiseApplyFunc(sin, A).as_explicit() == A.applyfunc(sin) + assert PermuteDims(A, [1, 0]).as_explicit() == permutedims(A, [1, 0]) + assert Reshape(A, [4]).as_explicit() == A.reshape(4) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0fcf63f2607b23feb38758e4f0994de4f0384b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py @@ -0,0 +1,78 @@ +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayTensorProduct, \ + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, ArrayContraction, _permute_dims, Reshape +from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive + +k = symbols("k") + +I = Identity(k) +X = MatrixSymbol("X", k, k) +x = MatrixSymbol("x", k, 1) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +A1 = ArraySymbol("A", (3, 2, k)) + + +def test_arrayexpr_derivatives1(): + + res = array_derive(X, X) + assert res == PermuteDims(ArrayTensorProduct(I, I), [0, 2, 1, 3]) + + cg = ArrayTensorProduct(A, X, B) + res = array_derive(cg, X) + assert res == _permute_dims( + ArrayTensorProduct(I, A, I, B), + [0, 4, 2, 3, 1, 5, 6, 7]) + + cg = ArrayContraction(X, (0, 1)) + res = array_derive(cg, X) + assert res == ArrayContraction(ArrayTensorProduct(I, I), (1, 3)) + + cg = ArrayDiagonal(X, (0, 1)) + res = array_derive(cg, X) + assert res == ArrayDiagonal(ArrayTensorProduct(I, I), (1, 3)) + + cg = ElementwiseApplyFunction(sin, X) + res = array_derive(cg, X) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + ElementwiseApplyFunction(cos, X), + I, + I + ), (0, 3), (1, 5))) + + cg = ArrayElementwiseApplyFunc(sin, X) + res = array_derive(cg, X) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + I, + I, + ArrayElementwiseApplyFunc(cos, X) + ), (1, 4), (3, 5))) + + res = array_derive(A1, A1) + assert res == PermuteDims( + ArrayTensorProduct(Identity(3), Identity(2), Identity(k)), + [0, 2, 4, 1, 3, 5] + ) + + cg = ArrayElementwiseApplyFunc(sin, A1) + res = array_derive(cg, A1) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + Identity(3), Identity(2), Identity(k), + ArrayElementwiseApplyFunc(cos, A1) + ), (1, 6), (3, 7), (5, 8) + )) + + cg = Reshape(A, (k**2,)) + res = array_derive(cg, A) + assert res == Reshape(PermuteDims(ArrayTensorProduct(I, I), [0, 2, 1, 3]), (k, k, k**2)) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py new file mode 100644 index 0000000000000000000000000000000000000000..30cc61b1ee651ca032e165cd67926fa33c71354f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py @@ -0,0 +1,63 @@ +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, \ + ArrayTensorProduct, PermuteDims, ArrayDiagonal, ArrayContraction, ArrayAdd +from sympy.testing.pytest import raises + + +def test_array_as_explicit_call(): + + assert ZeroArray(3, 2, 4).as_explicit() == ImmutableDenseNDimArray.zeros(3, 2, 4) + assert OneArray(3, 2, 4).as_explicit() == ImmutableDenseNDimArray([1 for i in range(3*2*4)]).reshape(3, 2, 4) + + k = Symbol("k") + X = ArraySymbol("X", (k, 3, 2)) + raises(ValueError, lambda: X.as_explicit()) + raises(ValueError, lambda: ZeroArray(k, 2, 3).as_explicit()) + raises(ValueError, lambda: OneArray(2, k, 2).as_explicit()) + + A = ArraySymbol("A", (3, 3)) + B = ArraySymbol("B", (3, 3)) + + texpr = tensorproduct(A, B) + assert isinstance(texpr, ArrayTensorProduct) + assert texpr.as_explicit() == tensorproduct(A.as_explicit(), B.as_explicit()) + + texpr = tensorcontraction(A, (0, 1)) + assert isinstance(texpr, ArrayContraction) + assert texpr.as_explicit() == A[0, 0] + A[1, 1] + A[2, 2] + + texpr = tensordiagonal(A, (0, 1)) + assert isinstance(texpr, ArrayDiagonal) + assert texpr.as_explicit() == ImmutableDenseNDimArray([A[0, 0], A[1, 1], A[2, 2]]) + + texpr = permutedims(A, [1, 0]) + assert isinstance(texpr, PermuteDims) + assert texpr.as_explicit() == permutedims(A.as_explicit(), [1, 0]) + + +def test_array_as_explicit_matrix_symbol(): + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + texpr = tensorproduct(A, B) + assert isinstance(texpr, ArrayTensorProduct) + assert texpr.as_explicit() == tensorproduct(A.as_explicit(), B.as_explicit()) + + texpr = tensorcontraction(A, (0, 1)) + assert isinstance(texpr, ArrayContraction) + assert texpr.as_explicit() == A[0, 0] + A[1, 1] + A[2, 2] + + texpr = tensordiagonal(A, (0, 1)) + assert isinstance(texpr, ArrayDiagonal) + assert texpr.as_explicit() == ImmutableDenseNDimArray([A[0, 0], A[1, 1], A[2, 2]]) + + texpr = permutedims(A, [1, 0]) + assert isinstance(texpr, PermuteDims) + assert texpr.as_explicit() == permutedims(A.as_explicit(), [1, 0]) + + expr = ArrayAdd(ArrayTensorProduct(A, B), ArrayTensorProduct(B, A)) + assert expr.as_explicit() == expr.args[0].as_explicit() + expr.args[1].as_explicit() diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b713fbec94ab7808c5a8a778b3313402d9d0c7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py @@ -0,0 +1,61 @@ +from sympy import Sum, Dummy, sin +from sympy.tensor.array.expressions import ArraySymbol, ArrayTensorProduct, ArrayContraction, PermuteDims, \ + ArrayDiagonal, ArrayAdd, OneArray, ZeroArray, convert_indexed_to_array, ArrayElementwiseApplyFunc, Reshape +from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed + +from sympy.abc import i, j, k, l, m, n, o + + +def test_convert_array_to_indexed_main(): + A = ArraySymbol("A", (3, 3, 3)) + B = ArraySymbol("B", (3, 3)) + C = ArraySymbol("C", (3, 3)) + + d_ = Dummy("d_") + + assert convert_array_to_indexed(A, [i, j, k]) == A[i, j, k] + + expr = ArrayTensorProduct(A, B, C) + conv = convert_array_to_indexed(expr, [i,j,k,l,m,n,o]) + assert conv == A[i,j,k]*B[l,m]*C[n,o] + assert convert_indexed_to_array(conv, [i,j,k,l,m,n,o]) == expr + + expr = ArrayContraction(A, (0, 2)) + assert convert_array_to_indexed(expr, [i]).dummy_eq(Sum(A[d_, i, d_], (d_, 0, 2))) + + expr = ArrayDiagonal(A, (0, 2)) + assert convert_array_to_indexed(expr, [i, j]) == A[j, i, j] + + expr = PermuteDims(A, [1, 2, 0]) + conv = convert_array_to_indexed(expr, [i, j, k]) + assert conv == A[k, i, j] + assert convert_indexed_to_array(conv, [i, j, k]) == expr + + expr = ArrayAdd(B, C, PermuteDims(C, [1, 0])) + conv = convert_array_to_indexed(expr, [i, j]) + assert conv == B[i, j] + C[i, j] + C[j, i] + assert convert_indexed_to_array(conv, [i, j]) == expr + + expr = ArrayElementwiseApplyFunc(sin, A) + conv = convert_array_to_indexed(expr, [i, j, k]) + assert conv == sin(A[i, j, k]) + assert convert_indexed_to_array(conv, [i, j, k]).dummy_eq(expr) + + assert convert_array_to_indexed(OneArray(3, 3), [i, j]) == 1 + assert convert_array_to_indexed(ZeroArray(3, 3), [i, j]) == 0 + + expr = Reshape(A, (27,)) + assert convert_array_to_indexed(expr, [i]) == A[i // 9, i // 3 % 3, i % 3] + + X = ArraySymbol("X", (2, 3, 4, 5, 6)) + expr = Reshape(X, (2*3*4*5*6,)) + assert convert_array_to_indexed(expr, [i]) == X[i // 360, i // 120 % 3, i // 30 % 4, i // 6 % 5, i % 6] + + expr = Reshape(X, (4, 9, 2, 2, 5)) + one_index = 180*i + 20*j + 10*k + 5*l + m + expected = X[one_index // (3*4*5*6), one_index // (4*5*6) % 3, one_index // (5*6) % 4, one_index // 6 % 5, one_index % 6] + assert convert_array_to_indexed(expr, [i, j, k, l, m]) == expected + + X = ArraySymbol("X", (2*3*5,)) + expr = Reshape(X, (2, 3, 5)) + assert convert_array_to_indexed(expr, [i, j, k]) == X[15*i + 5*j + k] diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..26839d5e7cec0554948c6b726482f9d8ca250b1c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py @@ -0,0 +1,689 @@ +from sympy import Lambda, S, Dummy, KroneckerProduct +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.expressions.hadamard import HadamardProduct, HadamardPower +from sympy.matrices.expressions.special import (Identity, OneMatrix, ZeroMatrix) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.tensor.array.expressions.from_array_to_matrix import _support_function_tp1_recognize, \ + _array_diag2contr_diagmatrix, convert_array_to_matrix, _remove_trivial_dims, _array2matrix, \ + _combine_removed, identify_removable_identity_matrices, _array_contraction_to_diagonal_multiple_identity +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.combinatorics import Permutation +from sympy.matrices.expressions.diagonal import DiagMatrix, DiagonalMatrix +from sympy.matrices import Trace, MatMul, Transpose +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, \ + ArrayElement, ArraySymbol, ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \ + _array_diagonal, _permute_dims, PermuteDims, ArrayAdd, ArrayDiagonal, ArrayContraction, ArrayTensorProduct +from sympy.testing.pytest import raises + + +i, j, k, l, m, n = symbols("i j k l m n") + +I = Identity(k) +I1 = Identity(1) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +X = MatrixSymbol("X", k, k) +Y = MatrixSymbol("Y", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + +x = MatrixSymbol("x", k, 1) +y = MatrixSymbol("y", k, 1) + + +def test_arrayexpr_convert_array_to_matrix(): + + cg = _array_contraction(_array_tensor_product(M), (0, 1)) + assert convert_array_to_matrix(cg) == Trace(M) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 1), (2, 3)) + assert convert_array_to_matrix(cg) == Trace(M) * Trace(N) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2)) + assert convert_array_to_matrix(cg) == Trace(M * N) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 2), (1, 3)) + assert convert_array_to_matrix(cg) == Trace(M * N.T) + + cg = convert_matrix_to_array(M * N * P) + assert convert_array_to_matrix(cg) == M * N * P + + cg = convert_matrix_to_array(M * N.T * P) + assert convert_array_to_matrix(cg) == M * N.T * P + + cg = _array_contraction(_array_tensor_product(M,N,P,Q), (1, 2), (5, 6)) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * N, P * Q) + + cg = _array_contraction(_array_tensor_product(-2, M, N), (1, 2)) + assert convert_array_to_matrix(cg) == -2 * M * N + + a = MatrixSymbol("a", k, 1) + b = MatrixSymbol("b", k, 1) + c = MatrixSymbol("c", k, 1) + cg = PermuteDims( + _array_contraction( + _array_tensor_product( + a, + ArrayAdd( + _array_tensor_product(b, c), + _array_tensor_product(c, b), + ) + ), (2, 4)), [0, 1, 3, 2]) + assert convert_array_to_matrix(cg) == a * (b.T * c + c.T * b) + + za = ZeroArray(m, n) + assert convert_array_to_matrix(za) == ZeroMatrix(m, n) + + cg = _array_tensor_product(3, M) + assert convert_array_to_matrix(cg) == 3 * M + + # Partial conversion to matrix multiplication: + expr = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 4, 6)) + assert convert_array_to_matrix(expr) == _array_contraction(_array_tensor_product(M.T*N, P, Q), (0, 2, 4)) + + x = MatrixSymbol("x", k, 1) + cg = PermuteDims( + _array_contraction(_array_tensor_product(OneArray(1), x, OneArray(1), DiagMatrix(Identity(1))), + (0, 5)), Permutation(1, 2, 3)) + assert convert_array_to_matrix(cg) == x + + expr = ArrayAdd(M, PermuteDims(M, [1, 0])) + assert convert_array_to_matrix(expr) == M + Transpose(M) + + +def test_arrayexpr_convert_array_to_matrix2(): + cg = _array_contraction(_array_tensor_product(M, N), (1, 3)) + assert convert_array_to_matrix(cg) == M * N.T + + cg = PermuteDims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_tensor_product(M, PermuteDims(N, Permutation([1, 0]))) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_contraction( + PermuteDims( + _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])), + (1, 2), (3, 5) + ) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * P.T * Trace(N), Q.T) + + cg = _array_contraction( + _array_tensor_product(M, N, P, PermuteDims(Q, Permutation([1, 0]))), + (1, 5), (2, 3) + ) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * P.T * Trace(N), Q.T) + + cg = _array_tensor_product(M, PermuteDims(N, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_tensor_product(PermuteDims(M, [1, 0]), PermuteDims(N, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M.T, N.T) + + cg = _array_tensor_product(PermuteDims(N, [1, 0]), PermuteDims(M, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(N.T, M.T) + + cg = _array_contraction(M, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, k)*M*OneMatrix(k, 1) + + cg = _array_contraction(x, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, k)*x + + Xm = MatrixSymbol("Xm", m, n) + cg = _array_contraction(Xm, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, m)*Xm*OneMatrix(n, 1) + + +def test_arrayexpr_convert_array_to_diagonalized_vector(): + + # Check matrix recognition over trivial dimensions: + + cg = _array_tensor_product(a, b) + assert convert_array_to_matrix(cg) == a * b.T + + cg = _array_tensor_product(I1, a, b) + assert convert_array_to_matrix(cg) == a * b.T + + # Recognize trace inside a tensor product: + + cg = _array_contraction(_array_tensor_product(A, B, C), (0, 3), (1, 2)) + assert convert_array_to_matrix(cg) == Trace(A * B) * C + + # Transform diagonal operator to contraction: + + cg = _array_diagonal(_array_tensor_product(A, a), (1, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(A, OneArray(1), DiagMatrix(a)), (1, 3)) + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(a, b), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _permute_dims( + _array_contraction(_array_tensor_product(DiagMatrix(a), OneArray(1), b), (0, 3)), [1, 2, 0] + ) + assert convert_array_to_matrix(cg) == b.T * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(A, a), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(A, OneArray(1), DiagMatrix(a)), (0, 3)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(I, x, I1), (0, 2), (3, 5)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(I, OneArray(1), I1, DiagMatrix(x)), (0, 5)) + assert convert_array_to_matrix(cg) == DiagMatrix(x) + + cg = _array_diagonal(_array_tensor_product(I, x, A, B), (1, 2), (5, 6)) + assert _array_diag2contr_diagmatrix(cg) == _array_diagonal(_array_contraction(_array_tensor_product(I, OneArray(1), A, B, DiagMatrix(x)), (1, 7)), (5, 6)) + # TODO: this is returning a wrong result: + # convert_array_to_matrix(cg) + + cg = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3, 5)) + assert convert_array_to_matrix(cg) == a*b.T + + cg = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(OneArray(1), a, b, I1), (2, 6)) + assert convert_array_to_matrix(cg) == a*b.T + + cg = _array_diagonal(_array_tensor_product(x, I1), (1, 2)) + assert isinstance(cg, ArrayDiagonal) + assert cg.diagonal_indices == ((1, 2),) + assert convert_array_to_matrix(cg) == x + + cg = _array_diagonal(_array_tensor_product(x, I), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(OneArray(1), I, DiagMatrix(x)), (1, 3)) + assert convert_array_to_matrix(cg).doit() == DiagMatrix(x) + + raises(ValueError, lambda: _array_diagonal(x, (1,))) + + # Ignore identity matrices with contractions: + + cg = _array_contraction(_array_tensor_product(I, A, I, I), (0, 2), (1, 3), (5, 7)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg) == Trace(A) * I + + cg = _array_contraction(_array_tensor_product(Trace(A) * I, I, I), (1, 5), (3, 4)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg).doit() == Trace(A) * I + + # Add DiagMatrix when required: + + cg = _array_contraction(_array_tensor_product(A, a), (1, 2)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg) == A * a + + cg = _array_contraction(_array_tensor_product(A, a, B), (1, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), B), (1, 2), (3, 5)) + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(A, a, B), (0, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), B), (0, 2), (3, 5)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(A, a, b, a.T, B), (0, 2, 4, 7, 9)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), + DiagMatrix(b), OneArray(1), DiagMatrix(a), OneArray(1), B), + (0, 2), (3, 5), (6, 9), (8, 12)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T + + cg = _array_contraction(_array_tensor_product(I1, I1, I1), (1, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(I1, I1, OneArray(1), I1), (1, 2), (3, 5)) + assert convert_array_to_matrix(cg) == 1 + + cg = _array_contraction(_array_tensor_product(I, I, I, I, A), (1, 2, 8), (5, 6, 9)) + assert convert_array_to_matrix(cg.split_multiple_contractions()).doit() == A + + cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8)) + expected = _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), C, DiagMatrix(a), OneArray(1), B), (1, 3), (2, 5), (6, 7), (8, 10)) + assert cg.split_multiple_contractions() == expected + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * C * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(a, I1, b, I1, (a.T*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9)) + expected = _array_contraction(_array_tensor_product(a, I1, OneArray(1), b, I1, OneArray(1), (a.T*b).applyfunc(cos)), + (1, 3), (2, 10), (6, 8), (7, 11)) + assert cg.split_multiple_contractions().dummy_eq(expected) + assert convert_array_to_matrix(cg).doit().dummy_eq(MatMul(a, (a.T * b).applyfunc(cos), b.T)) + + +def test_arrayexpr_convert_array_contraction_tp_additions(): + a = ArrayAdd( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ) + tp = _array_tensor_product(P, a, Q) + expr = _array_contraction(tp, (3, 4)) + expected = _array_tensor_product( + P, + ArrayAdd( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)), + ), + Q + ) + assert expr == expected + assert convert_array_to_matrix(expr) == _array_tensor_product(P, M * N + N * M, Q) + + expr = _array_contraction(tp, (1, 2), (3, 4), (5, 6)) + result = _array_contraction( + _array_tensor_product( + P, + ArrayAdd( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)), + ), + Q + ), (1, 2), (3, 4)) + assert expr == result + assert convert_array_to_matrix(expr) == P * (M * N + N * M) * Q + + +def test_arrayexpr_convert_array_to_implicit_matmul(): + # Trivial dimensions are suppressed, so the result can be expressed in matrix form: + + cg = _array_tensor_product(a, b) + assert convert_array_to_matrix(cg) == a * b.T + + cg = _array_tensor_product(a, b, I) + assert convert_array_to_matrix(cg) == _array_tensor_product(a*b.T, I) + + cg = _array_tensor_product(I, a, b) + assert convert_array_to_matrix(cg) == _array_tensor_product(I, a*b.T) + + cg = _array_tensor_product(a, I, b) + assert convert_array_to_matrix(cg) == _array_tensor_product(a, I, b) + + cg = _array_contraction(_array_tensor_product(I, I), (1, 2)) + assert convert_array_to_matrix(cg) == I + + cg = PermuteDims(_array_tensor_product(I, Identity(1)), [0, 2, 1, 3]) + assert convert_array_to_matrix(cg) == I + + +def test_arrayexpr_convert_array_to_matrix_remove_trivial_dims(): + + # Tensor Product: + assert _remove_trivial_dims(_array_tensor_product(a, b)) == (a * b.T, [1, 3]) + assert _remove_trivial_dims(_array_tensor_product(a.T, b)) == (a * b.T, [0, 3]) + assert _remove_trivial_dims(_array_tensor_product(a, b.T)) == (a * b.T, [1, 2]) + assert _remove_trivial_dims(_array_tensor_product(a.T, b.T)) == (a * b.T, [0, 2]) + + assert _remove_trivial_dims(_array_tensor_product(I, a.T, b.T)) == (_array_tensor_product(I, a * b.T), [2, 4]) + assert _remove_trivial_dims(_array_tensor_product(a.T, I, b.T)) == (_array_tensor_product(a.T, I, b.T), []) + + assert _remove_trivial_dims(_array_tensor_product(a, I)) == (_array_tensor_product(a, I), []) + assert _remove_trivial_dims(_array_tensor_product(I, a)) == (_array_tensor_product(I, a), []) + + assert _remove_trivial_dims(_array_tensor_product(a.T, b.T, c, d)) == ( + _array_tensor_product(a * b.T, c * d.T), [0, 2, 5, 7]) + assert _remove_trivial_dims(_array_tensor_product(a.T, I, b.T, c, d, I)) == ( + _array_tensor_product(a.T, I, b*c.T, d, I), [4, 7]) + + # Addition: + + cg = ArrayAdd(_array_tensor_product(a, b), _array_tensor_product(c, d)) + assert _remove_trivial_dims(cg) == (a * b.T + c * d.T, [1, 3]) + + # Permute Dims: + + cg = PermuteDims(_array_tensor_product(a, b), Permutation(3)(1, 2)) + assert _remove_trivial_dims(cg) == (a * b.T, [2, 3]) + + cg = PermuteDims(_array_tensor_product(a, I, b), Permutation(5)(1, 2, 3, 4)) + assert _remove_trivial_dims(cg) == (cg, []) + + cg = PermuteDims(_array_tensor_product(I, b, a), Permutation(5)(1, 2, 4, 5, 3)) + assert _remove_trivial_dims(cg) == (PermuteDims(_array_tensor_product(I, b * a.T), [0, 2, 3, 1]), [4, 5]) + + # Diagonal: + + cg = _array_diagonal(_array_tensor_product(M, a), (1, 2)) + assert _remove_trivial_dims(cg) == (cg, []) + + # Contraction: + + cg = _array_contraction(_array_tensor_product(M, a), (1, 2)) + assert _remove_trivial_dims(cg) == (cg, []) + + # A few more cases to test the removal and shift of nested removed axes + # with array contractions and array diagonals: + tp = _array_tensor_product( + OneMatrix(1, 1), + M, + x, + OneMatrix(1, 1), + Identity(1), + ) + + expr = _array_contraction(tp, (1, 8)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 5, 6, 7] + + expr = _array_contraction(tp, (1, 8), (3, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 3, 4, 5] + + expr = _array_diagonal(tp, (1, 8)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 5, 6, 7, 8] + + expr = _array_diagonal(tp, (1, 8), (3, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 3, 4, 5, 6] + + expr = _array_diagonal(_array_contraction(_array_tensor_product(A, x, I, I1), (1, 2, 5)), (1, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [2, 3] + + cg = _array_diagonal(_array_tensor_product(PermuteDims(_array_tensor_product(x, I1), Permutation(1, 2, 3)), (x.T*x).applyfunc(sqrt)), (2, 4), (3, 5)) + rexpr, removed = _remove_trivial_dims(cg) + assert removed == [1, 2] + + # Contractions with identity matrices need to be followed by a permutation + # in order + cg = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8)) + ret, removed = _remove_trivial_dims(cg) + assert ret == PermuteDims(_array_tensor_product(A, B, C, M), [0, 2, 3, 4, 5, 6, 7, 1]) + assert removed == [] + + cg = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8), (3, 4)) + ret, removed = _remove_trivial_dims(cg) + assert ret == PermuteDims(_array_contraction(_array_tensor_product(A, B, C, M), (3, 4)), [0, 2, 3, 4, 5, 1]) + assert removed == [] + + # Trivial matrices are sometimes inserted into MatMul expressions: + + cg = _array_tensor_product(b*b.T, a.T*a) + ret, removed = _remove_trivial_dims(cg) + assert ret == b*a.T*a*b.T + assert removed == [2, 3] + + Xs = ArraySymbol("X", (3, 2, k)) + cg = _array_tensor_product(M, Xs, b.T*c, a*a.T, b*b.T, c.T*d) + ret, removed = _remove_trivial_dims(cg) + assert ret == _array_tensor_product(M, Xs, a*b.T*c*c.T*d*a.T, b*b.T) + assert removed == [5, 6, 11, 12] + + cg = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5)) + assert _remove_trivial_dims(cg) == (PermuteDims(_array_diagonal(_array_tensor_product(I, x), (1, 2)), Permutation(1, 2)), [1]) + + expr = _array_diagonal(_array_tensor_product(x, I, y), (0, 2)) + assert _remove_trivial_dims(expr) == (PermuteDims(_array_tensor_product(DiagMatrix(x), y), [1, 2, 3, 0]), [0]) + + expr = _array_diagonal(_array_tensor_product(x, I, y), (0, 2), (3, 4)) + assert _remove_trivial_dims(expr) == (expr, []) + + +def test_arrayexpr_convert_array_to_matrix_diag2contraction_diagmatrix(): + cg = _array_diagonal(_array_tensor_product(M, a), (1, 2)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction(_array_tensor_product(M, OneArray(1), DiagMatrix(a)), (1, 3)) + + raises(ValueError, lambda: _array_diagonal(_array_tensor_product(a, M), (1, 2))) + + cg = _array_diagonal(_array_tensor_product(a.T, M), (1, 2)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction(_array_tensor_product(OneArray(1), M, DiagMatrix(a.T)), (1, 4)) + + cg = _array_diagonal(_array_tensor_product(a.T, M, N, b.T), (1, 2), (4, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a.T), DiagMatrix(b.T)), (1, 7), (3, 9)) + + cg = _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 2), (4, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a), DiagMatrix(b.T)), (1, 6), (3, 9)) + + cg = _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 4), (3, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a), DiagMatrix(b.T)), (3, 6), (2, 9)) + + I1 = Identity(1) + x = MatrixSymbol("x", k, 1) + A = MatrixSymbol("A", k, k) + cg = _array_diagonal(_array_tensor_product(x, A.T, I1), (0, 2)) + assert _array_diag2contr_diagmatrix(cg).shape == cg.shape + assert _array2matrix(cg).shape == cg.shape + + +def test_arrayexpr_convert_array_to_matrix_support_function(): + + assert _support_function_tp1_recognize([], [2 * k]) == 2 * k + + assert _support_function_tp1_recognize([(1, 2)], [A, 2 * k, B, 3]) == 6 * k * A * B + + assert _support_function_tp1_recognize([(0, 3), (1, 2)], [A, B]) == Trace(A * B) + + assert _support_function_tp1_recognize([(1, 2)], [A, B]) == A * B + assert _support_function_tp1_recognize([(0, 2)], [A, B]) == A.T * B + assert _support_function_tp1_recognize([(1, 3)], [A, B]) == A * B.T + assert _support_function_tp1_recognize([(0, 3)], [A, B]) == A.T * B.T + + assert _support_function_tp1_recognize([(1, 2), (5, 6)], [A, B, C, D]) == _array_tensor_product(A * B, C * D) + assert _support_function_tp1_recognize([(1, 4), (3, 6)], [A, B, C, D]) == PermuteDims( + _array_tensor_product(A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(0, 3), (1, 4)], [A, B, C]) == B * A * C + + assert _support_function_tp1_recognize([(9, 10), (1, 2), (5, 6), (3, 4), (7, 8)], + [X, Y, A, B, C, D]) == X * Y * A * B * C * D + + assert _support_function_tp1_recognize([(9, 10), (1, 2), (5, 6), (3, 4)], + [X, Y, A, B, C, D]) == _array_tensor_product(X * Y * A * B, C * D) + + assert _support_function_tp1_recognize([(1, 7), (3, 8), (4, 11)], [X, Y, A, B, C, D]) == PermuteDims( + _array_tensor_product(X * B.T, Y * C, A.T * D.T), [0, 2, 4, 1, 3, 5] + ) + + assert _support_function_tp1_recognize([(0, 1), (3, 6), (5, 8)], [X, A, B, C, D]) == PermuteDims( + _array_tensor_product(Trace(X) * A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(1, 2), (3, 4), (5, 6), (7, 8)], [A, A, B, C, D]) == A ** 2 * B * C * D + assert _support_function_tp1_recognize([(1, 2), (3, 4), (5, 6), (7, 8)], [X, A, B, C, D]) == X * A * B * C * D + + assert _support_function_tp1_recognize([(1, 6), (3, 8), (5, 10)], [X, Y, A, B, C, D]) == PermuteDims( + _array_tensor_product(X * B, Y * C, A * D), [0, 2, 4, 1, 3, 5] + ) + + assert _support_function_tp1_recognize([(1, 4), (3, 6)], [A, B, C, D]) == PermuteDims( + _array_tensor_product(A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(0, 4), (1, 7), (2, 5), (3, 8)], [X, A, B, C, D]) == C*X.T*B*A*D + + assert _support_function_tp1_recognize([(0, 4), (1, 7), (2, 5), (3, 8)], [X, A, B, C, D]) == C*X.T*B*A*D + + +def test_convert_array_to_hadamard_products(): + + expr = HadamardProduct(M, N) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = HadamardProduct(M, N)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = Q*HadamardProduct(M, N)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = Q*HadamardProduct(M, N.T)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = HadamardProduct(M, N)*HadamardProduct(Q, P) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert expr == ret + + expr = P.T*HadamardProduct(M, N)*HadamardProduct(Q, P) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert expr == ret + + # ArrayDiagonal should be converted + cg = _array_diagonal(_array_tensor_product(M, N, Q), (1, 3), (0, 2, 4)) + ret = convert_array_to_matrix(cg) + expected = PermuteDims(_array_diagonal(_array_tensor_product(HadamardProduct(M.T, N.T), Q), (1, 2)), [1, 0, 2]) + assert expected == ret + + # Special case that should return the same expression: + cg = _array_diagonal(_array_tensor_product(HadamardProduct(M, N), Q), (0, 2)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + # Hadamard products with traces: + + expr = Trace(HadamardProduct(M, N)) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M.T, N.T)) + + expr = Trace(A*HadamardProduct(M, N)) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M, N)*A) + + expr = Trace(HadamardProduct(A, M)*N) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M.T, N)*A) + + # These should not be converted into Hadamard products: + + cg = _array_diagonal(_array_tensor_product(M, N), (0, 1, 2, 3)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + cg = _array_diagonal(_array_tensor_product(A), (0, 1)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + cg = _array_diagonal(_array_tensor_product(M, N, P), (0, 2, 4), (1, 3, 5)) + assert convert_array_to_matrix(cg) == HadamardProduct(M, N, P) + + cg = _array_diagonal(_array_tensor_product(M, N, P), (0, 3, 4), (1, 2, 5)) + assert convert_array_to_matrix(cg) == HadamardProduct(M, P, N.T) + + cg = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5)) + assert convert_array_to_matrix(cg) == DiagMatrix(x) + + +def test_identify_removable_identity_matrices(): + + D = DiagonalMatrix(MatrixSymbol("D", k, k)) + + cg = _array_contraction(_array_tensor_product(A, B, I), (1, 2, 4, 5)) + expected = _array_contraction(_array_tensor_product(A, B), (1, 2)) + assert identify_removable_identity_matrices(cg) == expected + + cg = _array_contraction(_array_tensor_product(A, B, C, I), (1, 3, 5, 6, 7)) + expected = _array_contraction(_array_tensor_product(A, B, C), (1, 3, 5)) + assert identify_removable_identity_matrices(cg) == expected + + # Tests with diagonal matrices: + + cg = _array_contraction(_array_tensor_product(A, B, D), (1, 2, 4, 5)) + ret = identify_removable_identity_matrices(cg) + expected = _array_contraction(_array_tensor_product(A, B, D), (1, 4), (2, 5)) + assert ret == expected + + cg = _array_contraction(_array_tensor_product(A, B, D, M, N), (1, 2, 4, 5, 6, 8)) + ret = identify_removable_identity_matrices(cg) + assert ret == cg + + +def test_combine_removed(): + + assert _combine_removed(6, [0, 1, 2], [0, 1, 2]) == [0, 1, 2, 3, 4, 5] + assert _combine_removed(8, [2, 5], [1, 3, 4]) == [1, 2, 4, 5, 6] + assert _combine_removed(8, [7], []) == [7] + + +def test_array_contraction_to_diagonal_multiple_identities(): + + expr = _array_contraction(_array_tensor_product(A, B, I, C), (1, 2, 4), (5, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + assert convert_array_to_matrix(expr) == _array_contraction(_array_tensor_product(A, B, C), (1, 2, 4)) + + expr = _array_contraction(_array_tensor_product(A, I, I), (1, 2, 4)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (A, [2]) + assert convert_array_to_matrix(expr) == A + + expr = _array_contraction(_array_tensor_product(A, I, I, B), (1, 2, 4), (3, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + + expr = _array_contraction(_array_tensor_product(A, I, I, B), (1, 2, 3, 4, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + + +def test_convert_array_element_to_matrix(): + + expr = ArrayElement(M, (i, j)) + assert convert_array_to_matrix(expr) == MatrixElement(M, i, j) + + expr = ArrayElement(_array_contraction(_array_tensor_product(M, N), (1, 3)), (i, j)) + assert convert_array_to_matrix(expr) == MatrixElement(M*N.T, i, j) + + expr = ArrayElement(_array_tensor_product(M, N), (i, j, m, n)) + assert convert_array_to_matrix(expr) == expr + + +def test_convert_array_elementwise_function_to_matrix(): + + d = Dummy("d") + + expr = ArrayElementwiseApplyFunc(Lambda(d, sin(d)), x.T*y) + assert convert_array_to_matrix(expr) == sin(x.T*y) + + expr = ArrayElementwiseApplyFunc(Lambda(d, d**2), x.T*y) + assert convert_array_to_matrix(expr) == (x.T*y)**2 + + expr = ArrayElementwiseApplyFunc(Lambda(d, sin(d)), x) + assert convert_array_to_matrix(expr).dummy_eq(x.applyfunc(sin)) + + expr = ArrayElementwiseApplyFunc(Lambda(d, 1 / (2 * sqrt(d))), x) + assert convert_array_to_matrix(expr) == S.Half * HadamardPower(x, -S.Half) + + +def test_array2matrix(): + # See issue https://github.com/sympy/sympy/pull/22877 + expr = PermuteDims(ArrayContraction(ArrayTensorProduct(x, I, I1, x), (0, 3), (1, 7)), Permutation(2, 3)) + expected = PermuteDims(ArrayTensorProduct(x*x.T, I1), Permutation(3)(1, 2)) + assert _array2matrix(expr) == expected + + +def test_recognize_broadcasting(): + expr = ArrayTensorProduct(x.T*x, A) + assert _remove_trivial_dims(expr) == (KroneckerProduct(x.T*x, A), [0, 1]) + + expr = ArrayTensorProduct(A, x.T*x) + assert _remove_trivial_dims(expr) == (KroneckerProduct(A, x.T*x), [2, 3]) + + expr = ArrayTensorProduct(A, B, x.T*x, C) + assert _remove_trivial_dims(expr) == (ArrayTensorProduct(A, KroneckerProduct(B, x.T*x), C), [4, 5]) + + # Always prefer matrix multiplication to Kronecker product, if possible: + expr = ArrayTensorProduct(a, b, x.T*x) + assert _remove_trivial_dims(expr) == (a*x.T*x*b.T, [1, 3, 4, 5]) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..258062eadeca041ae3c864dabeefd5165f1cef11 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py @@ -0,0 +1,205 @@ +from sympy import tanh +from sympy.concrete.summations import Sum +from sympy.core.symbol import symbols +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc +from sympy.tensor.indexed import IndexedBase +from sympy.combinatorics import Permutation +from sympy.tensor.array.expressions.array_expressions import ArrayContraction, ArrayTensorProduct, \ + ArrayDiagonal, ArrayAdd, PermuteDims, ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, \ + _array_add, _permute_dims, ArraySymbol, OneArray +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array, _convert_indexed_to_array +from sympy.testing.pytest import raises + + +A, B = symbols("A B", cls=IndexedBase) +i, j, k, l, m, n = symbols("i j k l m n") +d0, d1, d2, d3 = symbols("d0:4") + +I = Identity(k) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +def test_arrayexpr_convert_index_to_array_support_function(): + expr = M[i, j] + assert _convert_indexed_to_array(expr) == (M, (i, j)) + expr = M[i, j]*N[k, l] + assert _convert_indexed_to_array(expr) == (ArrayTensorProduct(M, N), (i, j, k, l)) + expr = M[i, j]*N[j, k] + assert _convert_indexed_to_array(expr) == (ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)), (i, k, j)) + expr = Sum(M[i, j]*N[j, k], (j, 0, k-1)) + assert _convert_indexed_to_array(expr) == (ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (i, k)) + expr = M[i, j] + N[i, j] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, N), (i, j)) + expr = M[i, j] + N[j, i] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, PermuteDims(N, Permutation([1, 0]))), (i, j)) + expr = M[i, j] + M[j, i] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, PermuteDims(M, Permutation([1, 0]))), (i, j)) + expr = (M*N*P)[i, j] + assert _convert_indexed_to_array(expr) == (_array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)), (i, j)) + expr = expr.function # Disregard summation in previous expression + ret1, ret2 = _convert_indexed_to_array(expr) + assert ret1 == ArrayDiagonal(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)) + assert str(ret2) == "(i, j, _i_1, _i_2)" + expr = KroneckerDelta(i, j)*M[i, k] + assert _convert_indexed_to_array(expr) == (M, ({i, j}, k)) + expr = KroneckerDelta(i, j)*KroneckerDelta(j, k)*M[i, l] + assert _convert_indexed_to_array(expr) == (M, ({i, j, k}, l)) + expr = KroneckerDelta(j, k)*(M[i, j]*N[k, l] + N[i, j]*M[k, l]) + assert _convert_indexed_to_array(expr) == (_array_diagonal(_array_add( + ArrayTensorProduct(M, N), + _permute_dims(ArrayTensorProduct(M, N), Permutation(0, 2)(1, 3)) + ), (1, 2)), (i, l, frozenset({j, k}))) + expr = KroneckerDelta(j, m)*KroneckerDelta(m, k)*(M[i, j]*N[k, l] + N[i, j]*M[k, l]) + assert _convert_indexed_to_array(expr) == (_array_diagonal(_array_add( + ArrayTensorProduct(M, N), + _permute_dims(ArrayTensorProduct(M, N), Permutation(0, 2)(1, 3)) + ), (1, 2)), (i, l, frozenset({j, m, k}))) + expr = KroneckerDelta(i, j)*KroneckerDelta(j, k)*KroneckerDelta(k,m)*M[i, 0]*KroneckerDelta(m, n) + assert _convert_indexed_to_array(expr) == (M, ({i, j, k, m, n}, 0)) + expr = M[i, i] + assert _convert_indexed_to_array(expr) == (ArrayDiagonal(M, (0, 1)), (i,)) + + +def test_arrayexpr_convert_indexed_to_array_expression(): + + s = Sum(A[i]*B[i], (i, 0, 3)) + cg = convert_indexed_to_array(s) + assert cg == ArrayContraction(ArrayTensorProduct(A, B), (0, 1)) + + expr = M*N + result = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + elem = expr[i, j] + assert convert_indexed_to_array(elem) == result + + expr = M*N*M + elem = expr[i, j] + result = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5)) + cg = convert_indexed_to_array(elem) + assert cg == result + + cg = convert_indexed_to_array((M * N * P)[i, j]) + assert cg == _array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)) + + cg = convert_indexed_to_array((M * N.T * P)[i, j]) + assert cg == _array_contraction(ArrayTensorProduct(M, N, P), (1, 3), (2, 4)) + + expr = -2*M*N + elem = expr[i, j] + cg = convert_indexed_to_array(elem) + assert cg == ArrayContraction(ArrayTensorProduct(-2, M, N), (1, 2)) + + +def test_arrayexpr_convert_array_element_to_array_expression(): + A = ArraySymbol("A", (k,)) + B = ArraySymbol("B", (k,)) + + s = Sum(A[i]*B[i], (i, 0, k-1)) + cg = convert_indexed_to_array(s) + assert cg == ArrayContraction(ArrayTensorProduct(A, B), (0, 1)) + + s = A[i]*B[i] + cg = convert_indexed_to_array(s) + assert cg == ArrayDiagonal(ArrayTensorProduct(A, B), (0, 1)) + + s = A[i]*B[j] + cg = convert_indexed_to_array(s, [i, j]) + assert cg == ArrayTensorProduct(A, B) + cg = convert_indexed_to_array(s, [j, i]) + assert cg == ArrayTensorProduct(B, A) + + s = tanh(A[i]*B[j]) + cg = convert_indexed_to_array(s, [i, j]) + assert cg.dummy_eq(ArrayElementwiseApplyFunc(tanh, ArrayTensorProduct(A, B))) + + +def test_arrayexpr_convert_indexed_to_array_and_back_to_matrix(): + + expr = a.T*b + elem = expr[0, 0] + cg = convert_indexed_to_array(elem) + assert cg == ArrayElement(ArrayContraction(ArrayTensorProduct(a, b), (0, 2)), [0, 0]) + + expr = M[i,j] + N[i,j] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M + N + + expr = M[i,j] + N[j,i] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M + N.T + + expr = M[i,j]*N[k,l] + N[i,j]*M[k,l] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == ArrayAdd( + ArrayTensorProduct(M, N), + ArrayTensorProduct(N, M)) + + expr = (M*N*P)[i, j] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * N * P + + expr = Sum(M[i,j]*(N*P)[j,m], (j, 0, k-1)) + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * N * P + + expr = Sum((P[j, m] + P[m, j])*(M[i,j]*N[m,n] + N[i,j]*M[m,n]), (j, 0, k-1), (m, 0, k-1)) + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * P * N + M * P.T * N + N * P * M + N * P.T * M + + +def test_arrayexpr_convert_indexed_to_array_out_of_bounds(): + + expr = Sum(M[i, i], (i, 0, 4)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, i], (i, 0, k)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, i], (i, 1, k-1)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + + expr = Sum(M[i, j]*N[j,m], (j, 0, 4)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, j]*N[j,m], (j, 0, k)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, j]*N[j,m], (j, 1, k-1)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + + +def test_arrayexpr_convert_indexed_to_array_broadcast(): + A = ArraySymbol("A", (3, 3)) + B = ArraySymbol("B", (3, 3)) + + expr = A[i, j] + B[k, l] + O2 = OneArray(3, 3) + expected = ArrayAdd(ArrayTensorProduct(A, O2), ArrayTensorProduct(O2, B)) + assert convert_indexed_to_array(expr) == expected + assert convert_indexed_to_array(expr, [i, j, k, l]) == expected + assert convert_indexed_to_array(expr, [l, k, i, j]) == ArrayAdd(PermuteDims(ArrayTensorProduct(O2, A), [1, 0, 2, 3]), PermuteDims(ArrayTensorProduct(B, O2), [1, 0, 2, 3])) + + expr = A[i, j] + B[j, k] + O1 = OneArray(3) + assert convert_indexed_to_array(expr, [i, j, k]) == ArrayAdd(ArrayTensorProduct(A, O1), ArrayTensorProduct(O1, B)) + + C = ArraySymbol("C", (d0, d1)) + D = ArraySymbol("D", (d3, d1)) + + expr = C[i, j] + D[k, j] + assert convert_indexed_to_array(expr, [i, j, k]) == ArrayAdd(ArrayTensorProduct(C, OneArray(d3)), PermuteDims(ArrayTensorProduct(OneArray(d0), D), [0, 2, 1])) + + X = ArraySymbol("X", (5, 3)) + + expr = X[i, n] - X[j, n] + assert convert_indexed_to_array(expr, [i, j, n]) == ArrayAdd(ArrayTensorProduct(-1, OneArray(5), X), PermuteDims(ArrayTensorProduct(X, OneArray(5)), [0, 2, 1])) + + raises(ValueError, lambda: convert_indexed_to_array(C[i, j] + D[i, j])) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..142585882588df6aa0e4648d9d8881ea755f42a0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py @@ -0,0 +1,128 @@ +from sympy import Lambda, KroneckerProduct +from sympy.core.symbol import symbols, Dummy +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.special import Identity +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction, \ + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, _array_contraction, _array_tensor_product, Reshape +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +i, j, k, l, m, n = symbols("i j k l m n") + +I = Identity(k) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +X = MatrixSymbol("X", k, k) +Y = MatrixSymbol("Y", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +def test_arrayexpr_convert_matrix_to_array(): + + expr = M*N + result = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + assert convert_matrix_to_array(expr) == result + + expr = M*N*M + result = _array_contraction(ArrayTensorProduct(M, N, M), (1, 2), (3, 4)) + assert convert_matrix_to_array(expr) == result + + expr = Transpose(M) + assert convert_matrix_to_array(expr) == PermuteDims(M, [1, 0]) + + expr = M*Transpose(N) + assert convert_matrix_to_array(expr) == _array_contraction(_array_tensor_product(M, PermuteDims(N, [1, 0])), (1, 2)) + + expr = 3*M*N + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = 3*M + N*M.T*M + 4*k*N + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = Inverse(M)*N + rexpr = convert_array_to_matrix(convert_matrix_to_array(expr)) + assert expr == rexpr + + expr = M**2 + rexpr = convert_array_to_matrix(convert_matrix_to_array(expr)) + assert expr == rexpr + + expr = M*(2*N + 3*M) + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = Trace(M) + result = ArrayContraction(M, (0, 1)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(M) + result = ArrayContraction(ArrayTensorProduct(3, M), (0, 1)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(Trace(M) * M) + result = ArrayContraction(ArrayTensorProduct(3, M, M), (0, 1), (2, 3)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(M)**2 + result = ArrayContraction(ArrayTensorProduct(3, M, M), (0, 1), (2, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardProduct(M, N) + result = ArrayDiagonal(ArrayTensorProduct(M, N), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardProduct(M*N, N*M) + result = ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, N, M), (1, 2), (5, 6)), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M, 2) + result = ArrayDiagonal(ArrayTensorProduct(M, M), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M*N, 2) + result = ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, M, N), (1, 2), (5, 6)), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M, n) + d0 = Dummy("d0") + result = ArrayElementwiseApplyFunc(Lambda(d0, d0**n), M) + assert convert_matrix_to_array(expr).dummy_eq(result) + + expr = M**2 + assert isinstance(expr, MatPow) + assert convert_matrix_to_array(expr) == ArrayContraction(ArrayTensorProduct(M, M), (1, 2)) + + expr = a.T*b + cg = convert_matrix_to_array(expr) + assert cg == ArrayContraction(ArrayTensorProduct(a, b), (0, 2)) + + expr = KroneckerProduct(A, B) + cg = convert_matrix_to_array(expr) + assert cg == Reshape(PermuteDims(ArrayTensorProduct(A, B), [0, 2, 1, 3]), (k**2, k**2)) + + expr = KroneckerProduct(A, B, C, D) + cg = convert_matrix_to_array(expr) + assert cg == Reshape(PermuteDims(ArrayTensorProduct(A, B, C, D), [0, 2, 4, 6, 1, 3, 5, 7]), (k**4, k**4)) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b41b6105410a308e7774fce760b235497d0303bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py @@ -0,0 +1,22 @@ +from sympy import MatrixSymbol, symbols, Sum +from sympy.tensor.array.expressions import conv_array_to_indexed, from_array_to_indexed, ArrayTensorProduct, \ + ArrayContraction, conv_array_to_matrix, from_array_to_matrix, conv_matrix_to_array, from_matrix_to_array, \ + conv_indexed_to_array, from_indexed_to_array +from sympy.testing.pytest import warns +from sympy.utilities.exceptions import SymPyDeprecationWarning + + +def test_deprecated_conv_module_results(): + + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + i, j, d = symbols("i j d") + + x = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + y = Sum(M[i, d]*N[d, j], (d, 0, 2)) + + with warns(SymPyDeprecationWarning, test_stacklevel=False): + assert conv_array_to_indexed.convert_array_to_indexed(x, [i, j]).dummy_eq(from_array_to_indexed.convert_array_to_indexed(x, [i, j])) + assert conv_array_to_matrix.convert_array_to_matrix(x) == from_array_to_matrix.convert_array_to_matrix(x) + assert conv_matrix_to_array.convert_matrix_to_array(M*N) == from_matrix_to_array.convert_matrix_to_array(M*N) + assert conv_indexed_to_array.convert_indexed_to_array(y) == from_indexed_to_array.convert_indexed_to_array(y) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/expressions/utils.py b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e55c0e6ed47cdc9ff1c24cc92f006998aeb86822 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/expressions/utils.py @@ -0,0 +1,123 @@ +import bisect +from collections import defaultdict + +from sympy.combinatorics import Permutation +from sympy.core.containers import Tuple +from sympy.core.numbers import Integer + + +def _get_mapping_from_subranks(subranks): + mapping = {} + counter = 0 + for i, rank in enumerate(subranks): + for j in range(rank): + mapping[counter] = (i, j) + counter += 1 + return mapping + + +def _get_contraction_links(args, subranks, *contraction_indices): + mapping = _get_mapping_from_subranks(subranks) + contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices] + dlinks = defaultdict(dict) + for links in contraction_tuples: + if len(links) == 2: + (arg1, pos1), (arg2, pos2) = links + dlinks[arg1][pos1] = (arg2, pos2) + dlinks[arg2][pos2] = (arg1, pos1) + continue + + return args, dict(dlinks) + + +def _sort_contraction_indices(pairing_indices): + pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices] + pairing_indices.sort(key=lambda x: min(x)) + return pairing_indices + + +def _get_diagonal_indices(flattened_indices): + axes_contraction = defaultdict(list) + for i, ind in enumerate(flattened_indices): + if isinstance(ind, (int, Integer)): + # If the indices is a number, there can be no diagonal operation: + continue + axes_contraction[ind].append(i) + axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1} + # Put the diagonalized indices at the end: + ret_indices = [i for i in flattened_indices if i not in axes_contraction] + diag_indices = list(axes_contraction) + diag_indices.sort(key=lambda x: flattened_indices.index(x)) + diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices] + ret_indices += diag_indices + ret_indices = tuple(ret_indices) + return diagonal_indices, ret_indices + + +def _get_argindex(subindices, ind): + for i, sind in enumerate(subindices): + if ind == sind: + return i + if isinstance(sind, (set, frozenset)) and ind in sind: + return i + raise IndexError("%s not found in %s" % (ind, subindices)) + + +def _apply_recursively_over_nested_lists(func, arr): + if isinstance(arr, (tuple, list, Tuple)): + return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr) + elif isinstance(arr, Tuple): + return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr) + else: + return func(arr) + + +def _build_push_indices_up_func_transformation(flattened_contraction_indices): + shifts = {0: 0} + i = 0 + cumulative = 0 + while i < len(flattened_contraction_indices): + j = 1 + while i+j < len(flattened_contraction_indices): + if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]: + break + j += 1 + cumulative += j + shifts[flattened_contraction_indices[i]] = cumulative + i += j + shift_keys = sorted(shifts.keys()) + + def func(idx): + return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]] + + def transform(j): + if j in flattened_contraction_indices: + return None + else: + return j - func(j) + + return transform + + +def _build_push_indices_down_func_transformation(flattened_contraction_indices): + N = flattened_contraction_indices[-1]+2 + + shifts = [i for i in range(N) if i not in flattened_contraction_indices] + + def transform(j): + if j < len(shifts): + return shifts[j] + else: + return j + shifts[-1] - len(shifts) + 1 + + return transform + + +def _apply_permutation_to_list(perm: Permutation, target_list: list): + """ + Permute a list according to the given permutation. + """ + new_list = [None for i in range(perm.size)] + for i, e in enumerate(target_list): + new_list[perm(i)] = e + return new_list diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/mutable_ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/mutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..e1eaaf7241bc3b4a48234178d18da3aa5736e189 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/mutable_ndim_array.py @@ -0,0 +1,13 @@ +from sympy.tensor.array.ndim_array import NDimArray + + +class MutableNDimArray(NDimArray): + + def as_immutable(self): + raise NotImplementedError("abstract method") + + def as_mutable(self): + return self + + def _sympy_(self): + return self.as_immutable() diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9a857b8cfd9ee46646c46f274636d6b9962b6e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/ndim_array.py @@ -0,0 +1,601 @@ +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import Expr +from sympy.core.kind import Kind, NumberKind, UndefinedKind +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.printing.defaults import Printable + +import itertools +from collections.abc import Iterable + + +class ArrayKind(Kind): + """ + Kind for N-dimensional array in SymPy. + + This kind represents the multidimensional array that algebraic + operations are defined. Basic class for this kind is ``NDimArray``, + but any expression representing the array can have this. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is :obj:NumberKind ``, + which means that the array contains only numbers. + + Examples + ======== + + Any instance of array class has ``ArrayKind``. + + >>> from sympy import NDimArray + >>> NDimArray([1,2,3]).kind + ArrayKind(NumberKind) + + Although expressions representing an array may be not instance of + array class, it will have ``ArrayKind`` as well. + + >>> from sympy import Integral + >>> from sympy.tensor.array import NDimArray + >>> from sympy.abc import x + >>> intA = Integral(NDimArray([1,2,3]), x) + >>> isinstance(intA, NDimArray) + False + >>> intA.kind + ArrayKind(NumberKind) + + Use ``isinstance()`` to check for ``ArrayKind` without specifying + the element kind. Use ``is`` with specifying the element kind. + + >>> from sympy.tensor.array import ArrayKind + >>> from sympy.core import NumberKind + >>> boolA = NDimArray([True, False]) + >>> isinstance(boolA.kind, ArrayKind) + True + >>> boolA.kind is ArrayKind(NumberKind) + False + + See Also + ======== + + shape : Function to return the shape of objects with ``MatrixKind``. + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "ArrayKind(%s)" % self.element_kind + + @classmethod + def _union(cls, kinds) -> 'ArrayKind': + elem_kinds = {e.kind for e in kinds} + if len(elem_kinds) == 1: + elemkind, = elem_kinds + else: + elemkind = UndefinedKind + return ArrayKind(elemkind) + + +class NDimArray(Printable): + """N-dimensional array. + + Examples + ======== + + Create an N-dim array of zeros: + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 3, 4) + >>> a + [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + + Create an N-dim array from a list; + + >>> a = MutableDenseNDimArray([[2, 3], [4, 5]]) + >>> a + [[2, 3], [4, 5]] + + >>> b = MutableDenseNDimArray([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]) + >>> b + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + + Create an N-dim array from a flat list with dimension shape: + + >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3)) + >>> a + [[1, 2, 3], [4, 5, 6]] + + Create an N-dim array from a matrix: + + >>> from sympy import Matrix + >>> a = Matrix([[1,2],[3,4]]) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> b = MutableDenseNDimArray(a) + >>> b + [[1, 2], [3, 4]] + + Arithmetic operations on N-dim arrays + + >>> a = MutableDenseNDimArray([1, 1, 1, 1], (2, 2)) + >>> b = MutableDenseNDimArray([4, 4, 4, 4], (2, 2)) + >>> c = a + b + >>> c + [[5, 5], [5, 5]] + >>> a - b + [[-3, -3], [-3, -3]] + + """ + + _diff_wrt = True + is_scalar = False + + def __new__(cls, iterable, shape=None, **kwargs): + from sympy.tensor.array import ImmutableDenseNDimArray + return ImmutableDenseNDimArray(iterable, shape, **kwargs) + + def __getitem__(self, index): + raise NotImplementedError("A subclass of NDimArray should implement __getitem__") + + def _parse_index(self, index): + if isinstance(index, (SYMPY_INTS, Integer)): + if index >= self._loop_size: + raise ValueError("Only a tuple index is accepted") + return index + + if self._loop_size == 0: + raise ValueError("Index not valid with an empty array") + + if len(index) != self._rank: + raise ValueError('Wrong number of array axes') + + real_index = 0 + # check if input index can exist in current indexing + for i in range(self._rank): + if (index[i] >= self.shape[i]) or (index[i] < -self.shape[i]): + raise ValueError('Index ' + str(index) + ' out of border') + if index[i] < 0: + real_index += 1 + real_index = real_index*self.shape[i] + index[i] + + return real_index + + def _get_tuple_index(self, integer_index): + index = [] + for sh in reversed(self.shape): + index.append(integer_index % sh) + integer_index //= sh + index.reverse() + return tuple(index) + + def _check_symbolic_index(self, index): + # Check if any index is symbolic: + tuple_index = (index if isinstance(index, tuple) else (index,)) + if any((isinstance(i, Expr) and (not i.is_number)) for i in tuple_index): + for i, nth_dim in zip(tuple_index, self.shape): + if ((i < 0) == True) or ((i >= nth_dim) == True): + raise ValueError("index out of range") + from sympy.tensor import Indexed + return Indexed(self, *tuple_index) + return None + + def _setter_iterable_check(self, value): + from sympy.matrices.matrixbase import MatrixBase + if isinstance(value, (Iterable, MatrixBase, NDimArray)): + raise NotImplementedError + + @classmethod + def _scan_iterable_shape(cls, iterable): + def f(pointer): + if not isinstance(pointer, Iterable): + return [pointer], () + + if len(pointer) == 0: + return [], (0,) + + result = [] + elems, shapes = zip(*[f(i) for i in pointer]) + if len(set(shapes)) != 1: + raise ValueError("could not determine shape unambiguously") + for i in elems: + result.extend(i) + return result, (len(shapes),)+shapes[0] + + return f(iterable) + + @classmethod + def _handle_ndarray_creation_inputs(cls, iterable=None, shape=None, **kwargs): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + + if shape is None: + if iterable is None: + shape = () + iterable = () + # Construction of a sparse array from a sparse array + elif isinstance(iterable, SparseNDimArray): + return iterable._shape, iterable._sparse_array + + # Construct N-dim array from another N-dim array: + elif isinstance(iterable, NDimArray): + shape = iterable.shape + + # Construct N-dim array from an iterable (numpy arrays included): + elif isinstance(iterable, Iterable): + iterable, shape = cls._scan_iterable_shape(iterable) + + # Construct N-dim array from a Matrix: + elif isinstance(iterable, MatrixBase): + shape = iterable.shape + + else: + shape = () + iterable = (iterable,) + + if isinstance(iterable, (Dict, dict)) and shape is not None: + new_dict = iterable.copy() + for k in new_dict: + if isinstance(k, (tuple, Tuple)): + new_key = 0 + for i, idx in enumerate(k): + new_key = new_key * shape[i] + idx + iterable[new_key] = iterable[k] + del iterable[k] + + if isinstance(shape, (SYMPY_INTS, Integer)): + shape = (shape,) + + if not all(isinstance(dim, (SYMPY_INTS, Integer)) for dim in shape): + raise TypeError("Shape should contain integers only.") + + return tuple(shape), iterable + + def __len__(self): + """Overload common function len(). Returns number of elements in array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3, 3) + >>> a + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + >>> len(a) + 9 + + """ + return self._loop_size + + @property + def shape(self): + """ + Returns array shape (dimension). + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3, 3) + >>> a.shape + (3, 3) + + """ + return self._shape + + def rank(self): + """ + Returns rank of array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3,4,5,6,3) + >>> a.rank() + 5 + + """ + return self._rank + + def diff(self, *args, **kwargs): + """ + Calculate the derivative of each element in the array. + + Examples + ======== + + >>> from sympy import ImmutableDenseNDimArray + >>> from sympy.abc import x, y + >>> M = ImmutableDenseNDimArray([[x, y], [1, x*y]]) + >>> M.diff(x) + [[1, 0], [0, y]] + + """ + from sympy.tensor.array.array_derivatives import ArrayDerivative + kwargs.setdefault('evaluate', True) + return ArrayDerivative(self.as_immutable(), *args, **kwargs) + + def _eval_derivative(self, base): + # Types are (base: scalar, self: array) + return self.applyfunc(lambda x: base.diff(x)) + + def _eval_derivative_n_times(self, s, n): + return Basic._eval_derivative_n_times(self, s, n) + + def applyfunc(self, f): + """Apply a function to each element of the N-dim array. + + Examples + ======== + + >>> from sympy import ImmutableDenseNDimArray + >>> m = ImmutableDenseNDimArray([i*2+j for i in range(2) for j in range(2)], (2, 2)) + >>> m + [[0, 1], [2, 3]] + >>> m.applyfunc(lambda i: 2*i) + [[0, 2], [4, 6]] + """ + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(self, SparseNDimArray) and f(S.Zero) == 0: + return type(self)({k: f(v) for k, v in self._sparse_array.items() if f(v) != 0}, self.shape) + + return type(self)(map(f, Flatten(self)), self.shape) + + def _sympystr(self, printer): + def f(sh, shape_left, i, j): + if len(shape_left) == 1: + return "["+", ".join([printer._print(self[self._get_tuple_index(e)]) for e in range(i, j)])+"]" + + sh //= shape_left[0] + return "[" + ", ".join([f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh) for e in range(shape_left[0])]) + "]" # + "\n"*len(shape_left) + + if self.rank() == 0: + return printer._print(self[()]) + if 0 in self.shape: + return f"{self.__class__.__name__}([], {self.shape})" + return f(self._loop_size, self.shape, 0, self._loop_size) + + def tolist(self): + """ + Converting MutableDenseNDimArray to one-dim list + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + >>> a + [[1, 2], [3, 4]] + >>> b = a.tolist() + >>> b + [[1, 2], [3, 4]] + """ + + def f(sh, shape_left, i, j): + if len(shape_left) == 1: + return [self[self._get_tuple_index(e)] for e in range(i, j)] + result = [] + sh //= shape_left[0] + for e in range(shape_left[0]): + result.append(f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh)) + return result + + return f(self._loop_size, self.shape, 0, self._loop_size) + + def __add__(self, other): + from sympy.tensor.array.arrayop import Flatten + + if not isinstance(other, NDimArray): + return NotImplemented + + if self.shape != other.shape: + raise ValueError("array shape mismatch") + result_list = [i+j for i,j in zip(Flatten(self), Flatten(other))] + + return type(self)(result_list, self.shape) + + def __sub__(self, other): + from sympy.tensor.array.arrayop import Flatten + + if not isinstance(other, NDimArray): + return NotImplemented + + if self.shape != other.shape: + raise ValueError("array shape mismatch") + result_list = [i-j for i,j in zip(Flatten(self), Flatten(other))] + + return type(self)(result_list, self.shape) + + def __mul__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected, use tensorproduct(...) for tensorial product") + + other = sympify(other) + if isinstance(self, SparseNDimArray): + if other.is_zero: + return type(self)({}, self.shape) + return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [i*other for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __rmul__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected, use tensorproduct(...) for tensorial product") + + other = sympify(other) + if isinstance(self, SparseNDimArray): + if other.is_zero: + return type(self)({}, self.shape) + return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [other*i for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __truediv__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected") + + other = sympify(other) + if isinstance(self, SparseNDimArray) and other != S.Zero: + return type(self)({k: v/other for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [i/other for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __rtruediv__(self, other): + raise NotImplementedError('unsupported operation on NDimArray') + + def __neg__(self): + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(self, SparseNDimArray): + return type(self)({k: -v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [-i for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __iter__(self): + def iterator(): + if self._shape: + for i in range(self._shape[0]): + yield self[i] + else: + yield self[()] + + return iterator() + + def __eq__(self, other): + """ + NDimArray instances can be compared to each other. + Instances equal if they have same shape and data. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 3) + >>> b = MutableDenseNDimArray.zeros(2, 3) + >>> a == b + True + >>> c = a.reshape(3, 2) + >>> c == b + False + >>> a[0,0] = 1 + >>> b[0,0] = 2 + >>> a == b + False + """ + from sympy.tensor.array import SparseNDimArray + if not isinstance(other, NDimArray): + return False + + if not self.shape == other.shape: + return False + + if isinstance(self, SparseNDimArray) and isinstance(other, SparseNDimArray): + return dict(self._sparse_array) == dict(other._sparse_array) + + return list(self) == list(other) + + def __ne__(self, other): + return not self == other + + def _eval_transpose(self): + if self.rank() != 2: + raise ValueError("array rank not 2") + from .arrayop import permutedims + return permutedims(self, (1, 0)) + + def transpose(self): + return self._eval_transpose() + + def _eval_conjugate(self): + from sympy.tensor.array.arrayop import Flatten + + return self.func([i.conjugate() for i in Flatten(self)], self.shape) + + def conjugate(self): + return self._eval_conjugate() + + def _eval_adjoint(self): + return self.transpose().conjugate() + + def adjoint(self): + return self._eval_adjoint() + + def _slice_expand(self, s, dim): + if not isinstance(s, slice): + return (s,) + start, stop, step = s.indices(dim) + return [start + i*step for i in range((stop-start)//step)] + + def _get_slice_data_for_array_access(self, index): + sl_factors = [self._slice_expand(i, dim) for (i, dim) in zip(index, self.shape)] + eindices = itertools.product(*sl_factors) + return sl_factors, eindices + + def _get_slice_data_for_array_assignment(self, index, value): + if not isinstance(value, NDimArray): + value = type(self)(value) + sl_factors, eindices = self._get_slice_data_for_array_access(index) + slice_offsets = [min(i) if isinstance(i, list) else None for i in sl_factors] + # TODO: add checks for dimensions for `value`? + return value, eindices, slice_offsets + + @classmethod + def _check_special_bounds(cls, flat_list, shape): + if shape == () and len(flat_list) != 1: + raise ValueError("arrays without shape need one scalar value") + if shape == (0,) and len(flat_list) > 0: + raise ValueError("if array shape is (0,) there cannot be elements") + + def _check_index_for_getitem(self, index): + if isinstance(index, (SYMPY_INTS, Integer, slice)): + index = (index,) + + if len(index) < self.rank(): + index = tuple(index) + \ + tuple(slice(None) for i in range(len(index), self.rank())) + + if len(index) > self.rank(): + raise ValueError('Dimension of index greater than rank of array') + + return index + + +class ImmutableNDimArray(NDimArray, Basic): + _op_priority = 11.0 + + def __hash__(self): + return Basic.__hash__(self) + + def as_immutable(self): + return self + + def as_mutable(self): + raise NotImplementedError("abstract method") diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/sparse_ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/sparse_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..f11aa95be8ec9d10a9104d48fb28f406fe43845e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/sparse_ndim_array.py @@ -0,0 +1,196 @@ +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.tensor.array.mutable_ndim_array import MutableNDimArray +from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray +from sympy.utilities.iterables import flatten + +import functools + +class SparseNDimArray(NDimArray): + + def __new__(self, *args, **kwargs): + return ImmutableSparseNDimArray(*args, **kwargs) + + def __getitem__(self, index): + """ + Get an element from a sparse N-dim array. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray(range(4), (2, 2)) + >>> a + [[0, 1], [2, 3]] + >>> a[0, 0] + 0 + >>> a[1, 1] + 3 + >>> a[0] + [0, 1] + >>> a[1] + [2, 3] + + Symbolic indexing: + + >>> from sympy.abc import i, j + >>> a[i, j] + [[0, 1], [2, 3]][i, j] + + Replace `i` and `j` to get element `(0, 0)`: + + >>> a[i, j].subs({i: 0, j: 0}) + 0 + + """ + syindex = self._check_symbolic_index(index) + if syindex is not None: + return syindex + + index = self._check_index_for_getitem(index) + + # `index` is a tuple with one or more slices: + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + sl_factors, eindices = self._get_slice_data_for_array_access(index) + array = [self._sparse_array.get(self._parse_index(i), S.Zero) for i in eindices] + nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)] + return type(self)(array, nshape) + else: + index = self._parse_index(index) + return self._sparse_array.get(index, S.Zero) + + @classmethod + def zeros(cls, *shape): + """ + Return a sparse N-dim array of zeros. + """ + return cls({}, shape) + + def tomatrix(self): + """ + Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray([1 for i in range(9)], (3, 3)) + >>> b = a.tomatrix() + >>> b + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + """ + from sympy.matrices import SparseMatrix + if self.rank() != 2: + raise ValueError('Dimensions must be of size of 2') + + mat_sparse = {} + for key, value in self._sparse_array.items(): + mat_sparse[self._get_tuple_index(key)] = value + + return SparseMatrix(self.shape[0], self.shape[1], mat_sparse) + + def reshape(self, *newshape): + new_total_size = functools.reduce(lambda x,y: x*y, newshape) + if new_total_size != self._loop_size: + raise ValueError("Invalid reshape parameters " + newshape) + + return type(self)(self._sparse_array, newshape) + +class ImmutableSparseNDimArray(SparseNDimArray, ImmutableNDimArray): # type: ignore + + def __new__(cls, iterable=None, shape=None, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + shape = Tuple(*map(_sympify, shape)) + cls._check_special_bounds(flat_list, shape) + loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + + # Sparse array: + if isinstance(flat_list, (dict, Dict)): + sparse_array = Dict(flat_list) + else: + sparse_array = {} + for i, el in enumerate(flatten(flat_list)): + if el != 0: + sparse_array[i] = _sympify(el) + + sparse_array = Dict(sparse_array) + + self = Basic.__new__(cls, sparse_array, shape, **kwargs) + self._shape = shape + self._rank = len(shape) + self._loop_size = loop_size + self._sparse_array = sparse_array + + return self + + def __setitem__(self, index, value): + raise TypeError("immutable N-dim array") + + def as_mutable(self): + return MutableSparseNDimArray(self) + + +class MutableSparseNDimArray(MutableNDimArray, SparseNDimArray): + + def __new__(cls, iterable=None, shape=None, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + self = object.__new__(cls) + self._shape = shape + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + + # Sparse array: + if isinstance(flat_list, (dict, Dict)): + self._sparse_array = dict(flat_list) + return self + + self._sparse_array = {} + + for i, el in enumerate(flatten(flat_list)): + if el != 0: + self._sparse_array[i] = _sympify(el) + + return self + + def __setitem__(self, index, value): + """Allows to set items to MutableDenseNDimArray. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray.zeros(2, 2) + >>> a[0, 0] = 1 + >>> a[1, 1] = 1 + >>> a + [[1, 0], [0, 1]] + """ + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value) + for i in eindices: + other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None] + other_value = value[other_i] + complete_index = self._parse_index(i) + if other_value != 0: + self._sparse_array[complete_index] = other_value + elif complete_index in self._sparse_array: + self._sparse_array.pop(complete_index) + else: + index = self._parse_index(index) + value = _sympify(value) + if value == 0 and index in self._sparse_array: + self._sparse_array.pop(index) + else: + self._sparse_array[index] = value + + def as_immutable(self): + return ImmutableSparseNDimArray(self) + + @property + def free_symbols(self): + return {i for j in self._sparse_array.values() for i in j.free_symbols} diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__init__.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc878c2369e86c492af8b4848a9e3fccc7819937 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9115c49bfb8abc220027bf0cb24d1db6f428f9d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_array_comprehension.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65e6766ef8c4a94ec3605731eca136e3858a93c9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_array_derivatives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4c11cbcbacd179b375287d44523331918bb4342 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_arrayop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..542b1d476f26b0c9b7bac501ae0e834c8942d292 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_immutable_ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..390e844681d4e37cf8051b00fe418df3602bd7ae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_mutable_ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b940108efca0de6d313f4d99319fa231d07d971 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479e4354db595a48e52a083681a268f9eae48c02 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/array/tests/__pycache__/test_ndim_array_conversions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_array_comprehension.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_array_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..510e068f287fa04419712e5e9a16a314e522a62d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_array_comprehension.py @@ -0,0 +1,78 @@ +from sympy.tensor.array.array_comprehension import ArrayComprehension, ArrayComprehensionMap +from sympy.tensor.array import ImmutableDenseNDimArray +from sympy.abc import i, j, k, l +from sympy.testing.pytest import raises +from sympy.matrices import Matrix + + +def test_array_comprehension(): + a = ArrayComprehension(i*j, (i, 1, 3), (j, 2, 4)) + b = ArrayComprehension(i, (i, 1, j+1)) + c = ArrayComprehension(i+j+k+l, (i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5)) + d = ArrayComprehension(k, (i, 1, 5)) + e = ArrayComprehension(i, (j, k+1, k+5)) + assert a.doit().tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]] + assert a.shape == (3, 3) + assert a.is_shape_numeric == True + assert a.tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]] + assert a.tomatrix() == Matrix([ + [2, 3, 4], + [4, 6, 8], + [6, 9, 12]]) + assert len(a) == 9 + assert isinstance(b.doit(), ArrayComprehension) + assert isinstance(a.doit(), ImmutableDenseNDimArray) + assert b.subs(j, 3) == ArrayComprehension(i, (i, 1, 4)) + assert b.free_symbols == {j} + assert b.shape == (j + 1,) + assert b.rank() == 1 + assert b.is_shape_numeric == False + assert c.free_symbols == set() + assert c.function == i + j + k + l + assert c.limits == ((i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5)) + assert c.doit().tolist() == [[[[4, 5, 6, 7, 8], [5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11]], + [[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]], + [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]]], + [[[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]], + [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]], + [[7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13], [10, 11, 12, 13, 14]]]] + assert c.free_symbols == set() + assert c.variables == [i, j, k, l] + assert c.bound_symbols == [i, j, k, l] + assert d.doit().tolist() == [k, k, k, k, k] + assert len(e) == 5 + raises(TypeError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, [1, 3, 2]))) + raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, 1))) + raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+1))) + raises(ValueError, lambda: len(ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+4)))) + raises(TypeError, lambda: ArrayComprehension(i*j, (i, 0, i + 1.5), (j, 0, 2))) + raises(ValueError, lambda: b.tolist()) + raises(ValueError, lambda: b.tomatrix()) + raises(ValueError, lambda: c.tomatrix()) + +def test_arraycomprehensionmap(): + a = ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)) + assert a.doit().tolist() == [2, 3, 4, 5, 6] + assert a.shape == (5,) + assert a.is_shape_numeric + assert a.tolist() == [2, 3, 4, 5, 6] + assert len(a) == 5 + assert isinstance(a.doit(), ImmutableDenseNDimArray) + expr = ArrayComprehensionMap(lambda i: i+1, (i, 1, k)) + assert expr.doit() == expr + assert expr.subs(k, 4) == ArrayComprehensionMap(lambda i: i+1, (i, 1, 4)) + assert expr.subs(k, 4).doit() == ImmutableDenseNDimArray([2, 3, 4, 5]) + b = ArrayComprehensionMap(lambda i: i+1, (i, 1, 2), (i, 1, 3), (i, 1, 4), (i, 1, 5)) + assert b.doit().tolist() == [[[[2, 3, 4, 5, 6], [3, 5, 7, 9, 11], [4, 7, 10, 13, 16], [5, 9, 13, 17, 21]], + [[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]], + [[4, 7, 10, 13, 16], [7, 13, 19, 25, 31], [10, 19, 28, 37, 46], [13, 25, 37, 49, 61]]], + [[[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]], + [[5, 9, 13, 17, 21], [9, 17, 25, 33, 41], [13, 25, 37, 49, 61], [17, 33, 49, 65, 81]], + [[7, 13, 19, 25, 31], [13, 25, 37, 49, 61], [19, 37, 55, 73, 91], [25, 49, 73, 97, 121]]]] + + # tests about lambda expression + assert ArrayComprehensionMap(lambda: 3, (i, 1, 5)).doit().tolist() == [3, 3, 3, 3, 3] + assert ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)).doit().tolist() == [2, 3, 4, 5, 6] + raises(ValueError, lambda: ArrayComprehensionMap(i*j, (i, 1, 3), (j, 2, 4))) + a = ArrayComprehensionMap(lambda i, j: i+j, (i, 1, 5)) + raises(ValueError, lambda: a.doit()) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_array_derivatives.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_array_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6c777c55a9170704f309bf74387d140bf2ec32 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_array_derivatives.py @@ -0,0 +1,52 @@ +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array.ndim_array import NDimArray +from sympy.matrices.matrixbase import MatrixBase +from sympy.tensor.array.array_derivatives import ArrayDerivative + +x, y, z, t = symbols("x y z t") + +m = Matrix([[x, y], [z, t]]) + +M = MatrixSymbol("M", 3, 2) +N = MatrixSymbol("N", 4, 3) + + +def test_array_derivative_construction(): + + d = ArrayDerivative(x, m, evaluate=False) + assert d.shape == (2, 2) + expr = d.doit() + assert isinstance(expr, MatrixBase) + assert expr.shape == (2, 2) + + d = ArrayDerivative(m, m, evaluate=False) + assert d.shape == (2, 2, 2, 2) + expr = d.doit() + assert isinstance(expr, NDimArray) + assert expr.shape == (2, 2, 2, 2) + + d = ArrayDerivative(m, x, evaluate=False) + assert d.shape == (2, 2) + expr = d.doit() + assert isinstance(expr, MatrixBase) + assert expr.shape == (2, 2) + + d = ArrayDerivative(M, N, evaluate=False) + assert d.shape == (4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, ArrayDerivative) + assert expr.shape == (4, 3, 3, 2) + + d = ArrayDerivative(M, (N, 2), evaluate=False) + assert d.shape == (4, 3, 4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, ArrayDerivative) + assert expr.shape == (4, 3, 4, 3, 3, 2) + + d = ArrayDerivative(M.as_explicit(), (N.as_explicit(), 2), evaluate=False) + assert d.doit().shape == (4, 3, 4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, NDimArray) + assert expr.shape == (4, 3, 4, 3, 3, 2) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_arrayop.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_arrayop.py new file mode 100644 index 0000000000000000000000000000000000000000..de56e81e0064f1e303a7a58e41932d15f2d0b41e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_arrayop.py @@ -0,0 +1,361 @@ +import itertools +import random + +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import _af_invert +from sympy.testing.pytest import raises + +from sympy.core.function import diff +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import (adjoint, conjugate, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.tensor.array import Array, ImmutableDenseNDimArray, ImmutableSparseNDimArray, MutableSparseNDimArray + +from sympy.tensor.array.arrayop import tensorproduct, tensorcontraction, derive_by_array, permutedims, Flatten, \ + tensordiagonal + + +def test_import_NDimArray(): + from sympy.tensor.array import NDimArray + del NDimArray + + +def test_tensorproduct(): + x,y,z,t = symbols('x y z t') + from sympy.abc import a,b,c,d + assert tensorproduct() == 1 + assert tensorproduct([x]) == Array([x]) + assert tensorproduct([x], [y]) == Array([[x*y]]) + assert tensorproduct([x], [y], [z]) == Array([[[x*y*z]]]) + assert tensorproduct([x], [y], [z], [t]) == Array([[[[x*y*z*t]]]]) + + assert tensorproduct(x) == x + assert tensorproduct(x, y) == x*y + assert tensorproduct(x, y, z) == x*y*z + assert tensorproduct(x, y, z, t) == x*y*z*t + + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + A = ArrayType([x, y]) + B = ArrayType([1, 2, 3]) + C = ArrayType([a, b, c, d]) + + assert tensorproduct(A, B, C) == ArrayType([[[a*x, b*x, c*x, d*x], [2*a*x, 2*b*x, 2*c*x, 2*d*x], [3*a*x, 3*b*x, 3*c*x, 3*d*x]], + [[a*y, b*y, c*y, d*y], [2*a*y, 2*b*y, 2*c*y, 2*d*y], [3*a*y, 3*b*y, 3*c*y, 3*d*y]]]) + + assert tensorproduct([x, y], [1, 2, 3]) == tensorproduct(A, B) + + assert tensorproduct(A, 2) == ArrayType([2*x, 2*y]) + assert tensorproduct(A, [2]) == ArrayType([[2*x], [2*y]]) + assert tensorproduct([2], A) == ArrayType([[2*x, 2*y]]) + assert tensorproduct(a, A) == ArrayType([a*x, a*y]) + assert tensorproduct(a, A, B) == ArrayType([[a*x, 2*a*x, 3*a*x], [a*y, 2*a*y, 3*a*y]]) + assert tensorproduct(A, B, a) == ArrayType([[a*x, 2*a*x, 3*a*x], [a*y, 2*a*y, 3*a*y]]) + assert tensorproduct(B, a, A) == ArrayType([[a*x, a*y], [2*a*x, 2*a*y], [3*a*x, 3*a*y]]) + + # tests for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + a = SparseArrayType({1:2, 3:4},(1000, 2000)) + b = SparseArrayType({1:2, 3:4},(1000, 2000)) + assert tensorproduct(a, b) == ImmutableSparseNDimArray({2000001: 4, 2000003: 8, 6000001: 8, 6000003: 16}, (1000, 2000, 1000, 2000)) + + +def test_tensorcontraction(): + from sympy.abc import a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x + B = Array(range(18), (2, 3, 3)) + assert tensorcontraction(B, (1, 2)) == Array([12, 39]) + C1 = Array([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x], (2, 3, 2, 2)) + + assert tensorcontraction(C1, (0, 2)) == Array([[a + o, b + p], [e + s, f + t], [i + w, j + x]]) + assert tensorcontraction(C1, (0, 2, 3)) == Array([a + p, e + t, i + x]) + assert tensorcontraction(C1, (2, 3)) == Array([[a + d, e + h, i + l], [m + p, q + t, u + x]]) + + +def test_derivative_by_array(): + from sympy.abc import i, j, t, x, y, z + + bexpr = x*y**2*exp(z)*log(t) + sexpr = sin(bexpr) + cexpr = cos(bexpr) + + a = Array([sexpr]) + + assert derive_by_array(sexpr, t) == x*y**2*exp(z)*cos(x*y**2*exp(z)*log(t))/t + assert derive_by_array(sexpr, [x, y, z]) == Array([bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr, bexpr*cexpr]) + assert derive_by_array(a, [x, y, z]) == Array([[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr], [bexpr*cexpr]]) + + assert derive_by_array(sexpr, [[x, y], [z, t]]) == Array([[bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr], [bexpr*cexpr, bexpr/log(t)/t*cexpr]]) + assert derive_by_array(a, [[x, y], [z, t]]) == Array([[[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr]], [[bexpr*cexpr], [bexpr/log(t)/t*cexpr]]]) + assert derive_by_array([[x, y], [z, t]], [x, y]) == Array([[[1, 0], [0, 0]], [[0, 1], [0, 0]]]) + assert derive_by_array([[x, y], [z, t]], [[x, y], [z, t]]) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], + [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + assert diff(sexpr, t) == x*y**2*exp(z)*cos(x*y**2*exp(z)*log(t))/t + assert diff(sexpr, Array([x, y, z])) == Array([bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr, bexpr*cexpr]) + assert diff(a, Array([x, y, z])) == Array([[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr], [bexpr*cexpr]]) + + assert diff(sexpr, Array([[x, y], [z, t]])) == Array([[bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr], [bexpr*cexpr, bexpr/log(t)/t*cexpr]]) + assert diff(a, Array([[x, y], [z, t]])) == Array([[[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr]], [[bexpr*cexpr], [bexpr/log(t)/t*cexpr]]]) + assert diff(Array([[x, y], [z, t]]), Array([x, y])) == Array([[[1, 0], [0, 0]], [[0, 1], [0, 0]]]) + assert diff(Array([[x, y], [z, t]]), Array([[x, y], [z, t]])) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], + [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # test for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + b = MutableSparseNDimArray({0:i, 1:j}, (10000, 20000)) + assert derive_by_array(b, i) == ImmutableSparseNDimArray({0: 1}, (10000, 20000)) + assert derive_by_array(b, (i, j)) == ImmutableSparseNDimArray({0: 1, 200000001: 1}, (2, 10000, 20000)) + + #https://github.com/sympy/sympy/issues/20655 + U = Array([x, y, z]) + E = 2 + assert derive_by_array(E, U) == ImmutableDenseNDimArray([0, 0, 0]) + + +def test_issue_emerged_while_discussing_10972(): + ua = Array([-1,0]) + Fa = Array([[0, 1], [-1, 0]]) + po = tensorproduct(Fa, ua, Fa, ua) + assert tensorcontraction(po, (1, 2), (4, 5)) == Array([[0, 0], [0, 1]]) + + sa = symbols('a0:144') + po = Array(sa, [2, 2, 3, 3, 2, 2]) + assert tensorcontraction(po, (0, 1), (2, 3), (4, 5)) == sa[0] + sa[108] + sa[111] + sa[124] + sa[127] + sa[140] + sa[143] + sa[16] + sa[19] + sa[3] + sa[32] + sa[35] + assert tensorcontraction(po, (0, 1, 4, 5), (2, 3)) == sa[0] + sa[111] + sa[127] + sa[143] + sa[16] + sa[32] + assert tensorcontraction(po, (0, 1), (4, 5)) == Array([[sa[0] + sa[108] + sa[111] + sa[3], sa[112] + sa[115] + sa[4] + sa[7], + sa[11] + sa[116] + sa[119] + sa[8]], [sa[12] + sa[120] + sa[123] + sa[15], + sa[124] + sa[127] + sa[16] + sa[19], sa[128] + sa[131] + sa[20] + sa[23]], + [sa[132] + sa[135] + sa[24] + sa[27], sa[136] + sa[139] + sa[28] + sa[31], + sa[140] + sa[143] + sa[32] + sa[35]]]) + assert tensorcontraction(po, (0, 1), (2, 3)) == Array([[sa[0] + sa[108] + sa[124] + sa[140] + sa[16] + sa[32], sa[1] + sa[109] + sa[125] + sa[141] + sa[17] + sa[33]], + [sa[110] + sa[126] + sa[142] + sa[18] + sa[2] + sa[34], sa[111] + sa[127] + sa[143] + sa[19] + sa[3] + sa[35]]]) + + +def test_array_permutedims(): + sa = symbols('a0:144') + + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + m1 = ArrayType(sa[:6], (2, 3)) + assert permutedims(m1, (1, 0)) == transpose(m1) + assert m1.tomatrix().T == permutedims(m1, (1, 0)).tomatrix() + + assert m1.tomatrix().T == transpose(m1).tomatrix() + assert m1.tomatrix().C == conjugate(m1).tomatrix() + assert m1.tomatrix().H == adjoint(m1).tomatrix() + + assert m1.tomatrix().T == m1.transpose().tomatrix() + assert m1.tomatrix().C == m1.conjugate().tomatrix() + assert m1.tomatrix().H == m1.adjoint().tomatrix() + + raises(ValueError, lambda: permutedims(m1, (0,))) + raises(ValueError, lambda: permutedims(m1, (0, 0))) + raises(ValueError, lambda: permutedims(m1, (1, 2, 0))) + + # Some tests with random arrays: + dims = 6 + shape = [random.randint(1,5) for i in range(dims)] + elems = [random.random() for i in range(tensorproduct(*shape))] + ra = ArrayType(elems, shape) + perm = list(range(dims)) + # Randomize the permutation: + random.shuffle(perm) + # Test inverse permutation: + assert permutedims(permutedims(ra, perm), _af_invert(perm)) == ra + # Test that permuted shape corresponds to action by `Permutation`: + assert permutedims(ra, perm).shape == tuple(Permutation(perm)(shape)) + + z = ArrayType.zeros(4,5,6,7) + + assert permutedims(z, (2, 3, 1, 0)).shape == (6, 7, 5, 4) + assert permutedims(z, [2, 3, 1, 0]).shape == (6, 7, 5, 4) + assert permutedims(z, Permutation([2, 3, 1, 0])).shape == (6, 7, 5, 4) + + po = ArrayType(sa, [2, 2, 3, 3, 2, 2]) + + raises(ValueError, lambda: permutedims(po, (1, 1))) + raises(ValueError, lambda: po.transpose()) + raises(ValueError, lambda: po.adjoint()) + + assert permutedims(po, reversed(range(po.rank()))) == ArrayType( + [[[[[[sa[0], sa[72]], [sa[36], sa[108]]], [[sa[12], sa[84]], [sa[48], sa[120]]], [[sa[24], + sa[96]], [sa[60], sa[132]]]], + [[[sa[4], sa[76]], [sa[40], sa[112]]], [[sa[16], + sa[88]], [sa[52], sa[124]]], + [[sa[28], sa[100]], [sa[64], sa[136]]]], + [[[sa[8], + sa[80]], [sa[44], sa[116]]], [[sa[20], sa[92]], [sa[56], sa[128]]], [[sa[32], + sa[104]], [sa[68], sa[140]]]]], + [[[[sa[2], sa[74]], [sa[38], sa[110]]], [[sa[14], + sa[86]], [sa[50], sa[122]]], [[sa[26], sa[98]], [sa[62], sa[134]]]], + [[[sa[6], + sa[78]], [sa[42], sa[114]]], [[sa[18], sa[90]], [sa[54], sa[126]]], [[sa[30], + sa[102]], [sa[66], sa[138]]]], + [[[sa[10], sa[82]], [sa[46], sa[118]]], [[sa[22], + sa[94]], [sa[58], sa[130]]], + [[sa[34], sa[106]], [sa[70], sa[142]]]]]], + [[[[[sa[1], + sa[73]], [sa[37], sa[109]]], [[sa[13], sa[85]], [sa[49], sa[121]]], [[sa[25], + sa[97]], [sa[61], sa[133]]]], + [[[sa[5], sa[77]], [sa[41], sa[113]]], [[sa[17], + sa[89]], [sa[53], sa[125]]], + [[sa[29], sa[101]], [sa[65], sa[137]]]], + [[[sa[9], + sa[81]], [sa[45], sa[117]]], [[sa[21], sa[93]], [sa[57], sa[129]]], [[sa[33], + sa[105]], [sa[69], sa[141]]]]], + [[[[sa[3], sa[75]], [sa[39], sa[111]]], [[sa[15], + sa[87]], [sa[51], sa[123]]], [[sa[27], sa[99]], [sa[63], sa[135]]]], + [[[sa[7], + sa[79]], [sa[43], sa[115]]], [[sa[19], sa[91]], [sa[55], sa[127]]], [[sa[31], + sa[103]], [sa[67], sa[139]]]], + [[[sa[11], sa[83]], [sa[47], sa[119]]], [[sa[23], + sa[95]], [sa[59], sa[131]]], + [[sa[35], sa[107]], [sa[71], sa[143]]]]]]]) + + assert permutedims(po, (1, 0, 2, 3, 4, 5)) == ArrayType( + [[[[[[sa[0], sa[1]], [sa[2], sa[3]]], [[sa[4], sa[5]], [sa[6], sa[7]]], [[sa[8], sa[9]], [sa[10], + sa[11]]]], + [[[sa[12], sa[13]], [sa[14], sa[15]]], [[sa[16], sa[17]], [sa[18], + sa[19]]], [[sa[20], sa[21]], [sa[22], sa[23]]]], + [[[sa[24], sa[25]], [sa[26], + sa[27]]], [[sa[28], sa[29]], [sa[30], sa[31]]], [[sa[32], sa[33]], [sa[34], + sa[35]]]]], + [[[[sa[72], sa[73]], [sa[74], sa[75]]], [[sa[76], sa[77]], [sa[78], + sa[79]]], [[sa[80], sa[81]], [sa[82], sa[83]]]], + [[[sa[84], sa[85]], [sa[86], + sa[87]]], [[sa[88], sa[89]], [sa[90], sa[91]]], [[sa[92], sa[93]], [sa[94], + sa[95]]]], + [[[sa[96], sa[97]], [sa[98], sa[99]]], [[sa[100], sa[101]], [sa[102], + sa[103]]], + [[sa[104], sa[105]], [sa[106], sa[107]]]]]], [[[[[sa[36], sa[37]], [sa[38], + sa[39]]], + [[sa[40], sa[41]], [sa[42], sa[43]]], + [[sa[44], sa[45]], [sa[46], + sa[47]]]], + [[[sa[48], sa[49]], [sa[50], sa[51]]], + [[sa[52], sa[53]], [sa[54], + sa[55]]], + [[sa[56], sa[57]], [sa[58], sa[59]]]], + [[[sa[60], sa[61]], [sa[62], + sa[63]]], + [[sa[64], sa[65]], [sa[66], sa[67]]], + [[sa[68], sa[69]], [sa[70], + sa[71]]]]], [ + [[[sa[108], sa[109]], [sa[110], sa[111]]], + [[sa[112], sa[113]], [sa[114], + sa[115]]], + [[sa[116], sa[117]], [sa[118], sa[119]]]], + [[[sa[120], sa[121]], [sa[122], + sa[123]]], + [[sa[124], sa[125]], [sa[126], sa[127]]], + [[sa[128], sa[129]], [sa[130], + sa[131]]]], + [[[sa[132], sa[133]], [sa[134], sa[135]]], + [[sa[136], sa[137]], [sa[138], + sa[139]]], + [[sa[140], sa[141]], [sa[142], sa[143]]]]]]]) + + assert permutedims(po, (0, 2, 1, 4, 3, 5)) == ArrayType( + [[[[[[sa[0], sa[1]], [sa[4], sa[5]], [sa[8], sa[9]]], [[sa[2], sa[3]], [sa[6], sa[7]], [sa[10], + sa[11]]]], + [[[sa[36], sa[37]], [sa[40], sa[41]], [sa[44], sa[45]]], [[sa[38], + sa[39]], [sa[42], sa[43]], [sa[46], sa[47]]]]], + [[[[sa[12], sa[13]], [sa[16], + sa[17]], [sa[20], sa[21]]], [[sa[14], sa[15]], [sa[18], sa[19]], [sa[22], + sa[23]]]], + [[[sa[48], sa[49]], [sa[52], sa[53]], [sa[56], sa[57]]], [[sa[50], + sa[51]], [sa[54], sa[55]], [sa[58], sa[59]]]]], + [[[[sa[24], sa[25]], [sa[28], + sa[29]], [sa[32], sa[33]]], [[sa[26], sa[27]], [sa[30], sa[31]], [sa[34], + sa[35]]]], + [[[sa[60], sa[61]], [sa[64], sa[65]], [sa[68], sa[69]]], [[sa[62], + sa[63]], [sa[66], sa[67]], [sa[70], sa[71]]]]]], + [[[[[sa[72], sa[73]], [sa[76], + sa[77]], [sa[80], sa[81]]], [[sa[74], sa[75]], [sa[78], sa[79]], [sa[82], + sa[83]]]], + [[[sa[108], sa[109]], [sa[112], sa[113]], [sa[116], sa[117]]], [[sa[110], + sa[111]], [sa[114], sa[115]], + [sa[118], sa[119]]]]], + [[[[sa[84], sa[85]], [sa[88], + sa[89]], [sa[92], sa[93]]], [[sa[86], sa[87]], [sa[90], sa[91]], [sa[94], + sa[95]]]], + [[[sa[120], sa[121]], [sa[124], sa[125]], [sa[128], sa[129]]], [[sa[122], + sa[123]], [sa[126], sa[127]], + [sa[130], sa[131]]]]], + [[[[sa[96], sa[97]], [sa[100], + sa[101]], [sa[104], sa[105]]], [[sa[98], sa[99]], [sa[102], sa[103]], [sa[106], + sa[107]]]], + [[[sa[132], sa[133]], [sa[136], sa[137]], [sa[140], sa[141]]], [[sa[134], + sa[135]], [sa[138], sa[139]], + [sa[142], sa[143]]]]]]]) + + po2 = po.reshape(4, 9, 2, 2) + assert po2 == ArrayType([[[[sa[0], sa[1]], [sa[2], sa[3]]], [[sa[4], sa[5]], [sa[6], sa[7]]], [[sa[8], sa[9]], [sa[10], sa[11]]], [[sa[12], sa[13]], [sa[14], sa[15]]], [[sa[16], sa[17]], [sa[18], sa[19]]], [[sa[20], sa[21]], [sa[22], sa[23]]], [[sa[24], sa[25]], [sa[26], sa[27]]], [[sa[28], sa[29]], [sa[30], sa[31]]], [[sa[32], sa[33]], [sa[34], sa[35]]]], [[[sa[36], sa[37]], [sa[38], sa[39]]], [[sa[40], sa[41]], [sa[42], sa[43]]], [[sa[44], sa[45]], [sa[46], sa[47]]], [[sa[48], sa[49]], [sa[50], sa[51]]], [[sa[52], sa[53]], [sa[54], sa[55]]], [[sa[56], sa[57]], [sa[58], sa[59]]], [[sa[60], sa[61]], [sa[62], sa[63]]], [[sa[64], sa[65]], [sa[66], sa[67]]], [[sa[68], sa[69]], [sa[70], sa[71]]]], [[[sa[72], sa[73]], [sa[74], sa[75]]], [[sa[76], sa[77]], [sa[78], sa[79]]], [[sa[80], sa[81]], [sa[82], sa[83]]], [[sa[84], sa[85]], [sa[86], sa[87]]], [[sa[88], sa[89]], [sa[90], sa[91]]], [[sa[92], sa[93]], [sa[94], sa[95]]], [[sa[96], sa[97]], [sa[98], sa[99]]], [[sa[100], sa[101]], [sa[102], sa[103]]], [[sa[104], sa[105]], [sa[106], sa[107]]]], [[[sa[108], sa[109]], [sa[110], sa[111]]], [[sa[112], sa[113]], [sa[114], sa[115]]], [[sa[116], sa[117]], [sa[118], sa[119]]], [[sa[120], sa[121]], [sa[122], sa[123]]], [[sa[124], sa[125]], [sa[126], sa[127]]], [[sa[128], sa[129]], [sa[130], sa[131]]], [[sa[132], sa[133]], [sa[134], sa[135]]], [[sa[136], sa[137]], [sa[138], sa[139]]], [[sa[140], sa[141]], [sa[142], sa[143]]]]]) + + assert permutedims(po2, (3, 2, 0, 1)) == ArrayType([[[[sa[0], sa[4], sa[8], sa[12], sa[16], sa[20], sa[24], sa[28], sa[32]], [sa[36], sa[40], sa[44], sa[48], sa[52], sa[56], sa[60], sa[64], sa[68]], [sa[72], sa[76], sa[80], sa[84], sa[88], sa[92], sa[96], sa[100], sa[104]], [sa[108], sa[112], sa[116], sa[120], sa[124], sa[128], sa[132], sa[136], sa[140]]], [[sa[2], sa[6], sa[10], sa[14], sa[18], sa[22], sa[26], sa[30], sa[34]], [sa[38], sa[42], sa[46], sa[50], sa[54], sa[58], sa[62], sa[66], sa[70]], [sa[74], sa[78], sa[82], sa[86], sa[90], sa[94], sa[98], sa[102], sa[106]], [sa[110], sa[114], sa[118], sa[122], sa[126], sa[130], sa[134], sa[138], sa[142]]]], [[[sa[1], sa[5], sa[9], sa[13], sa[17], sa[21], sa[25], sa[29], sa[33]], [sa[37], sa[41], sa[45], sa[49], sa[53], sa[57], sa[61], sa[65], sa[69]], [sa[73], sa[77], sa[81], sa[85], sa[89], sa[93], sa[97], sa[101], sa[105]], [sa[109], sa[113], sa[117], sa[121], sa[125], sa[129], sa[133], sa[137], sa[141]]], [[sa[3], sa[7], sa[11], sa[15], sa[19], sa[23], sa[27], sa[31], sa[35]], [sa[39], sa[43], sa[47], sa[51], sa[55], sa[59], sa[63], sa[67], sa[71]], [sa[75], sa[79], sa[83], sa[87], sa[91], sa[95], sa[99], sa[103], sa[107]], [sa[111], sa[115], sa[119], sa[123], sa[127], sa[131], sa[135], sa[139], sa[143]]]]]) + + # test for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + A = SparseArrayType({1:1, 10000:2}, (10000, 20000, 10000)) + assert permutedims(A, (0, 1, 2)) == A + assert permutedims(A, (1, 0, 2)) == SparseArrayType({1: 1, 100000000: 2}, (20000, 10000, 10000)) + B = SparseArrayType({1:1, 20000:2}, (10000, 20000)) + assert B.transpose() == SparseArrayType({10000: 1, 1: 2}, (20000, 10000)) + + +def test_permutedims_with_indices(): + A = Array(range(32)).reshape(2, 2, 2, 2, 2) + indices_new = list("abcde") + indices_old = list("ebdac") + new_A = permutedims(A, index_order_new=indices_new, index_order_old=indices_old) + for a, b, c, d, e in itertools.product(range(2), range(2), range(2), range(2), range(2)): + assert new_A[a, b, c, d, e] == A[e, b, d, a, c] + indices_old = list("cabed") + new_A = permutedims(A, index_order_new=indices_new, index_order_old=indices_old) + for a, b, c, d, e in itertools.product(range(2), range(2), range(2), range(2), range(2)): + assert new_A[a, b, c, d, e] == A[c, a, b, e, d] + raises(ValueError, lambda: permutedims(A, index_order_old=list("aacde"), index_order_new=list("abcde"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abcde"), index_order_new=list("abcce"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abcde"), index_order_new=list("abce"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abce"), index_order_new=list("abce"))) + raises(ValueError, lambda: permutedims(A, [2, 1, 0, 3, 4], index_order_old=list("abcde"))) + raises(ValueError, lambda: permutedims(A, [2, 1, 0, 3, 4], index_order_new=list("abcde"))) + + +def test_flatten(): + from sympy.matrices.dense import Matrix + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray, Matrix]: + A = ArrayType(range(24)).reshape(4, 6) + assert list(Flatten(A)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + + for i, v in enumerate(Flatten(A)): + assert i == v + + +def test_tensordiagonal(): + from sympy.matrices.dense import eye + expr = Array(range(9)).reshape(3, 3) + raises(ValueError, lambda: tensordiagonal(expr, [0], [1])) + raises(ValueError, lambda: tensordiagonal(expr, [0, 0])) + assert tensordiagonal(eye(3), [0, 1]) == Array([1, 1, 1]) + assert tensordiagonal(expr, [0, 1]) == Array([0, 4, 8]) + x, y, z = symbols("x y z") + expr2 = tensorproduct([x, y, z], expr) + assert tensordiagonal(expr2, [1, 2]) == Array([[0, 4*x, 8*x], [0, 4*y, 8*y], [0, 4*z, 8*z]]) + assert tensordiagonal(expr2, [0, 1]) == Array([[0, 3*y, 6*z], [x, 4*y, 7*z], [2*x, 5*y, 8*z]]) + assert tensordiagonal(expr2, [0, 1, 2]) == Array([0, 4*y, 8*z]) + # assert tensordiagonal(expr2, [0]) == permutedims(expr2, [1, 2, 0]) + # assert tensordiagonal(expr2, [1]) == permutedims(expr2, [0, 2, 1]) + # assert tensordiagonal(expr2, [2]) == expr2 + # assert tensordiagonal(expr2, [1], [2]) == expr2 + # assert tensordiagonal(expr2, [0], [1]) == permutedims(expr2, [2, 0, 1]) + + a, b, c, X, Y, Z = symbols("a b c X Y Z") + expr3 = tensorproduct([x, y, z], [1, 2, 3], [a, b, c], [X, Y, Z]) + assert tensordiagonal(expr3, [0, 1, 2, 3]) == Array([x*a*X, 2*y*b*Y, 3*z*c*Z]) + assert tensordiagonal(expr3, [0, 1], [2, 3]) == tensorproduct([x, 2*y, 3*z], [a*X, b*Y, c*Z]) + + # assert tensordiagonal(expr3, [0], [1, 2], [3]) == tensorproduct([x, y, z], [a, 2*b, 3*c], [X, Y, Z]) + assert tensordiagonal(tensordiagonal(expr3, [2, 3]), [0, 1]) == tensorproduct([a*X, b*Y, c*Z], [x, 2*y, 3*z]) + + raises(ValueError, lambda: tensordiagonal([[1, 2, 3], [4, 5, 6]], [0, 1])) + raises(ValueError, lambda: tensordiagonal(expr3.reshape(3, 3, 9), [1, 2])) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..c6bed4b605c424284b4752592b03b13a9178aac8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py @@ -0,0 +1,452 @@ +from copy import copy + +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.core.containers import Dict +from sympy.core.function import diff +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.matrices import SparseMatrix +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices import Matrix +from sympy.tensor.array.sparse_ndim_array import ImmutableSparseNDimArray +from sympy.testing.pytest import raises + + +def test_ndim_array_initiation(): + arr_with_no_elements = ImmutableDenseNDimArray([], shape=(0,)) + assert len(arr_with_no_elements) == 0 + assert arr_with_no_elements.rank() == 1 + + raises(ValueError, lambda: ImmutableDenseNDimArray([0], shape=(0,))) + raises(ValueError, lambda: ImmutableDenseNDimArray([1, 2, 3], shape=(0,))) + raises(ValueError, lambda: ImmutableDenseNDimArray([], shape=())) + + raises(ValueError, lambda: ImmutableSparseNDimArray([0], shape=(0,))) + raises(ValueError, lambda: ImmutableSparseNDimArray([1, 2, 3], shape=(0,))) + raises(ValueError, lambda: ImmutableSparseNDimArray([], shape=())) + + arr_with_one_element = ImmutableDenseNDimArray([23]) + assert len(arr_with_one_element) == 1 + assert arr_with_one_element[0] == 23 + assert arr_with_one_element[:] == ImmutableDenseNDimArray([23]) + assert arr_with_one_element.rank() == 1 + + arr_with_symbol_element = ImmutableDenseNDimArray([Symbol('x')]) + assert len(arr_with_symbol_element) == 1 + assert arr_with_symbol_element[0] == Symbol('x') + assert arr_with_symbol_element[:] == ImmutableDenseNDimArray([Symbol('x')]) + assert arr_with_symbol_element.rank() == 1 + + number5 = 5 + vector = ImmutableDenseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector.rank() == 1 + + vector = ImmutableSparseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector._sparse_array == Dict() + assert vector.rank() == 1 + + n_dim_array = ImmutableDenseNDimArray(range(3**4), (3, 3, 3, 3,)) + assert len(n_dim_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == (3, 3, 3, 3) + assert n_dim_array.rank() == 4 + + array_shape = (3, 3, 3, 3) + sparse_array = ImmutableSparseNDimArray.zeros(*array_shape) + assert len(sparse_array._sparse_array) == 0 + assert len(sparse_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == array_shape + assert n_dim_array.rank() == 4 + + one_dim_array = ImmutableDenseNDimArray([2, 3, 1]) + assert len(one_dim_array) == 3 + assert one_dim_array.shape == (3,) + assert one_dim_array.rank() == 1 + assert one_dim_array.tolist() == [2, 3, 1] + + shape = (3, 3) + array_with_many_args = ImmutableSparseNDimArray.zeros(*shape) + assert len(array_with_many_args) == 3 * 3 + assert array_with_many_args.shape == shape + assert array_with_many_args[0, 0] == 0 + assert array_with_many_args.rank() == 2 + + shape = (int(3), int(3)) + array_with_long_shape = ImmutableSparseNDimArray.zeros(*shape) + assert len(array_with_long_shape) == 3 * 3 + assert array_with_long_shape.shape == shape + assert array_with_long_shape[int(0), int(0)] == 0 + assert array_with_long_shape.rank() == 2 + + vector_with_long_shape = ImmutableDenseNDimArray(range(5), int(5)) + assert len(vector_with_long_shape) == 5 + assert vector_with_long_shape.shape == (int(5),) + assert vector_with_long_shape.rank() == 1 + raises(ValueError, lambda: vector_with_long_shape[int(5)]) + + from sympy.abc import x + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + rank_zero_array = ArrayType(x) + assert len(rank_zero_array) == 1 + assert rank_zero_array.shape == () + assert rank_zero_array.rank() == 0 + assert rank_zero_array[()] == x + raises(ValueError, lambda: rank_zero_array[0]) + + +def test_reshape(): + array = ImmutableDenseNDimArray(range(50), 50) + assert array.shape == (50,) + assert array.rank() == 1 + + array = array.reshape(5, 5, 2) + assert array.shape == (5, 5, 2) + assert array.rank() == 3 + assert len(array) == 50 + + +def test_getitem(): + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + array = ArrayType(range(24)).reshape(2, 3, 4) + assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + assert array[0, 0] == ArrayType([0, 1, 2, 3]) + value = 0 + for i in range(2): + for j in range(3): + for k in range(4): + assert array[i, j, k] == value + value += 1 + + raises(ValueError, lambda: array[3, 4, 5]) + raises(ValueError, lambda: array[3, 4, 5, 6]) + raises(ValueError, lambda: array[3, 4, 5, 3:4]) + + +def test_iterator(): + array = ImmutableDenseNDimArray(range(4), (2, 2)) + assert array[0] == ImmutableDenseNDimArray([0, 1]) + assert array[1] == ImmutableDenseNDimArray([2, 3]) + + array = array.reshape(4) + j = 0 + for i in array: + assert i == j + j += 1 + + +def test_sparse(): + sparse_array = ImmutableSparseNDimArray([0, 0, 0, 1], (2, 2)) + assert len(sparse_array) == 2 * 2 + # dictionary where all data is, only non-zero entries are actually stored: + assert len(sparse_array._sparse_array) == 1 + + assert sparse_array.tolist() == [[0, 0], [0, 1]] + + for i, j in zip(sparse_array, [[0, 0], [0, 1]]): + assert i == ImmutableSparseNDimArray(j) + + def sparse_assignment(): + sparse_array[0, 0] = 123 + + assert len(sparse_array._sparse_array) == 1 + raises(TypeError, sparse_assignment) + assert len(sparse_array._sparse_array) == 1 + assert sparse_array[0, 0] == 0 + assert sparse_array/0 == ImmutableSparseNDimArray([[S.NaN, S.NaN], [S.NaN, S.ComplexInfinity]], (2, 2)) + + # test for large scale sparse array + # equality test + assert ImmutableSparseNDimArray.zeros(100000, 200000) == ImmutableSparseNDimArray.zeros(100000, 200000) + + # __mul__ and __rmul__ + a = ImmutableSparseNDimArray({200001: 1}, (100000, 200000)) + assert a * 3 == ImmutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert 3 * a == ImmutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert a * 0 == ImmutableSparseNDimArray({}, (100000, 200000)) + assert 0 * a == ImmutableSparseNDimArray({}, (100000, 200000)) + + # __truediv__ + assert a/3 == ImmutableSparseNDimArray({200001: Rational(1, 3)}, (100000, 200000)) + + # __neg__ + assert -a == ImmutableSparseNDimArray({200001: -1}, (100000, 200000)) + + +def test_calculation(): + + a = ImmutableDenseNDimArray([1]*9, (3, 3)) + b = ImmutableDenseNDimArray([9]*9, (3, 3)) + + c = a + b + for i in c: + assert i == ImmutableDenseNDimArray([10, 10, 10]) + + assert c == ImmutableDenseNDimArray([10]*9, (3, 3)) + assert c == ImmutableSparseNDimArray([10]*9, (3, 3)) + + c = b - a + for i in c: + assert i == ImmutableDenseNDimArray([8, 8, 8]) + + assert c == ImmutableDenseNDimArray([8]*9, (3, 3)) + assert c == ImmutableSparseNDimArray([8]*9, (3, 3)) + + +def test_ndim_array_converting(): + dense_array = ImmutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + alist = dense_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = dense_array.tomatrix() + assert (isinstance(matrix, Matrix)) + + for i in range(len(dense_array)): + assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == dense_array.shape + + assert ImmutableDenseNDimArray(matrix) == dense_array + assert ImmutableDenseNDimArray(matrix.as_immutable()) == dense_array + assert ImmutableDenseNDimArray(matrix.as_mutable()) == dense_array + + sparse_array = ImmutableSparseNDimArray([1, 2, 3, 4], (2, 2)) + alist = sparse_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = sparse_array.tomatrix() + assert(isinstance(matrix, SparseMatrix)) + + for i in range(len(sparse_array)): + assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == sparse_array.shape + + assert ImmutableSparseNDimArray(matrix) == sparse_array + assert ImmutableSparseNDimArray(matrix.as_immutable()) == sparse_array + assert ImmutableSparseNDimArray(matrix.as_mutable()) == sparse_array + + +def test_converting_functions(): + arr_list = [1, 2, 3, 4] + arr_matrix = Matrix(((1, 2), (3, 4))) + + # list + arr_ndim_array = ImmutableDenseNDimArray(arr_list, (2, 2)) + assert (isinstance(arr_ndim_array, ImmutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + + # Matrix + arr_ndim_array = ImmutableDenseNDimArray(arr_matrix) + assert (isinstance(arr_ndim_array, ImmutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + assert arr_matrix.shape == arr_ndim_array.shape + + +def test_equality(): + first_list = [1, 2, 3, 4] + second_list = [1, 2, 3, 4] + third_list = [4, 3, 2, 1] + assert first_list == second_list + assert first_list != third_list + + first_ndim_array = ImmutableDenseNDimArray(first_list, (2, 2)) + second_ndim_array = ImmutableDenseNDimArray(second_list, (2, 2)) + fourth_ndim_array = ImmutableDenseNDimArray(first_list, (2, 2)) + + assert first_ndim_array == second_ndim_array + + def assignment_attempt(a): + a[0, 0] = 0 + + raises(TypeError, lambda: assignment_attempt(second_ndim_array)) + assert first_ndim_array == second_ndim_array + assert first_ndim_array == fourth_ndim_array + + +def test_arithmetic(): + a = ImmutableDenseNDimArray([3 for i in range(9)], (3, 3)) + b = ImmutableDenseNDimArray([7 for i in range(9)], (3, 3)) + + c1 = a + b + c2 = b + a + assert c1 == c2 + + d1 = a - b + d2 = b - a + assert d1 == d2 * (-1) + + e1 = a * 5 + e2 = 5 * a + e3 = copy(a) + e3 *= 5 + assert e1 == e2 == e3 + + f1 = a / 5 + f2 = copy(a) + f2 /= 5 + assert f1 == f2 + assert f1[0, 0] == f1[0, 1] == f1[0, 2] == f1[1, 0] == f1[1, 1] == \ + f1[1, 2] == f1[2, 0] == f1[2, 1] == f1[2, 2] == Rational(3, 5) + + assert type(a) == type(b) == type(c1) == type(c2) == type(d1) == type(d2) \ + == type(e1) == type(e2) == type(e3) == type(f1) + + z0 = -a + assert z0 == ImmutableDenseNDimArray([-3 for i in range(9)], (3, 3)) + + +def test_higher_dimenions(): + m3 = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert m3.tolist() == [[[10, 11, 12, 13], + [14, 15, 16, 17], + [18, 19, 20, 21]], + + [[22, 23, 24, 25], + [26, 27, 28, 29], + [30, 31, 32, 33]]] + + assert m3._get_tuple_index(0) == (0, 0, 0) + assert m3._get_tuple_index(1) == (0, 0, 1) + assert m3._get_tuple_index(4) == (0, 1, 0) + assert m3._get_tuple_index(12) == (1, 0, 0) + + assert str(m3) == '[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]' + + m3_rebuilt = ImmutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]) + assert m3 == m3_rebuilt + + m3_other = ImmutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]], (2, 3, 4)) + + assert m3 == m3_other + + +def test_rebuild_immutable_arrays(): + sparr = ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + densarr = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert sparr == sparr.func(*sparr.args) + assert densarr == densarr.func(*densarr.args) + + +def test_slices(): + md = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert md[:] == ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + assert md[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert md[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert md[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert md[:, :, :] == md + + sd = ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd == ImmutableSparseNDimArray(md) + + assert sd[:] == ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert sd[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert sd[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert sd[:, :, :] == sd + + +def test_diff_and_applyfunc(): + from sympy.abc import x, y, z + md = ImmutableDenseNDimArray([[x, y], [x*z, x*y*z]]) + assert md.diff(x) == ImmutableDenseNDimArray([[1, 0], [z, y*z]]) + assert diff(md, x) == ImmutableDenseNDimArray([[1, 0], [z, y*z]]) + + sd = ImmutableSparseNDimArray(md) + assert sd == ImmutableSparseNDimArray([x, y, x*z, x*y*z], (2, 2)) + assert sd.diff(x) == ImmutableSparseNDimArray([[1, 0], [z, y*z]]) + assert diff(sd, x) == ImmutableSparseNDimArray([[1, 0], [z, y*z]]) + + mdn = md.applyfunc(lambda x: x*3) + assert mdn == ImmutableDenseNDimArray([[3*x, 3*y], [3*x*z, 3*x*y*z]]) + assert md != mdn + + sdn = sd.applyfunc(lambda x: x/2) + assert sdn == ImmutableSparseNDimArray([[x/2, y/2], [x*z/2, x*y*z/2]]) + assert sd != sdn + + sdp = sd.applyfunc(lambda x: x+1) + assert sdp == ImmutableSparseNDimArray([[x + 1, y + 1], [x*z + 1, x*y*z + 1]]) + assert sd != sdp + + +def test_op_priority(): + from sympy.abc import x + md = ImmutableDenseNDimArray([1, 2, 3]) + e1 = (1+x)*md + e2 = md*(1+x) + assert e1 == ImmutableDenseNDimArray([1+x, 2+2*x, 3+3*x]) + assert e1 == e2 + + sd = ImmutableSparseNDimArray([1, 2, 3]) + e3 = (1+x)*sd + e4 = sd*(1+x) + assert e3 == ImmutableDenseNDimArray([1+x, 2+2*x, 3+3*x]) + assert e3 == e4 + + +def test_symbolic_indexing(): + x, y, z, w = symbols("x y z w") + M = ImmutableDenseNDimArray([[x, y], [z, w]]) + i, j = symbols("i, j") + Mij = M[i, j] + assert isinstance(Mij, Indexed) + Ms = ImmutableSparseNDimArray([[2, 3*x], [4, 5]]) + msij = Ms[i, j] + assert isinstance(msij, Indexed) + for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]: + assert Mij.subs({i: oi, j: oj}) == M[oi, oj] + assert msij.subs({i: oi, j: oj}) == Ms[oi, oj] + A = IndexedBase("A", (0, 2)) + assert A[0, 0].subs(A, M) == x + assert A[i, j].subs(A, M) == M[i, j] + assert M[i, j].subs(M, A) == A[i, j] + + assert isinstance(M[3 * i - 2, j], Indexed) + assert M[3 * i - 2, j].subs({i: 1, j: 0}) == M[1, 0] + assert isinstance(M[i, 0], Indexed) + assert M[i, 0].subs(i, 0) == M[0, 0] + assert M[0, i].subs(i, 1) == M[0, 1] + + assert M[i, j].diff(x) == ImmutableDenseNDimArray([[1, 0], [0, 0]])[i, j] + assert Ms[i, j].diff(x) == ImmutableSparseNDimArray([[0, 3], [0, 0]])[i, j] + + Mo = ImmutableDenseNDimArray([1, 2, 3]) + assert Mo[i].subs(i, 1) == 2 + Mos = ImmutableSparseNDimArray([1, 2, 3]) + assert Mos[i].subs(i, 1) == 2 + + raises(ValueError, lambda: M[i, 2]) + raises(ValueError, lambda: M[i, -1]) + raises(ValueError, lambda: M[2, i]) + raises(ValueError, lambda: M[-1, i]) + + raises(ValueError, lambda: Ms[i, 2]) + raises(ValueError, lambda: Ms[i, -1]) + raises(ValueError, lambda: Ms[2, i]) + raises(ValueError, lambda: Ms[-1, i]) + + +def test_issue_12665(): + # Testing Python 3 hash of immutable arrays: + arr = ImmutableDenseNDimArray([1, 2, 3]) + # This should NOT raise an exception: + hash(arr) + + +def test_zeros_without_shape(): + arr = ImmutableDenseNDimArray.zeros() + assert arr == ImmutableDenseNDimArray(0) + +def test_issue_21870(): + a0 = ImmutableDenseNDimArray(0) + assert a0.rank() == 0 + a1 = ImmutableDenseNDimArray(a0) + assert a1.rank() == 0 diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..9a232f399bbc0639d326217975fb0a12e645a984 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py @@ -0,0 +1,374 @@ +from copy import copy + +from sympy.tensor.array.dense_ndim_array import MutableDenseNDimArray +from sympy.core.function import diff +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices import SparseMatrix +from sympy.matrices import Matrix +from sympy.tensor.array.sparse_ndim_array import MutableSparseNDimArray +from sympy.testing.pytest import raises + + +def test_ndim_array_initiation(): + arr_with_one_element = MutableDenseNDimArray([23]) + assert len(arr_with_one_element) == 1 + assert arr_with_one_element[0] == 23 + assert arr_with_one_element.rank() == 1 + raises(ValueError, lambda: arr_with_one_element[1]) + + arr_with_symbol_element = MutableDenseNDimArray([Symbol('x')]) + assert len(arr_with_symbol_element) == 1 + assert arr_with_symbol_element[0] == Symbol('x') + assert arr_with_symbol_element.rank() == 1 + + number5 = 5 + vector = MutableDenseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector.rank() == 1 + raises(ValueError, lambda: arr_with_one_element[5]) + + vector = MutableSparseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector._sparse_array == {} + assert vector.rank() == 1 + + n_dim_array = MutableDenseNDimArray(range(3**4), (3, 3, 3, 3,)) + assert len(n_dim_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == (3, 3, 3, 3) + assert n_dim_array.rank() == 4 + raises(ValueError, lambda: n_dim_array[0, 0, 0, 3]) + raises(ValueError, lambda: n_dim_array[3, 0, 0, 0]) + raises(ValueError, lambda: n_dim_array[3**4]) + + array_shape = (3, 3, 3, 3) + sparse_array = MutableSparseNDimArray.zeros(*array_shape) + assert len(sparse_array._sparse_array) == 0 + assert len(sparse_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == array_shape + assert n_dim_array.rank() == 4 + + one_dim_array = MutableDenseNDimArray([2, 3, 1]) + assert len(one_dim_array) == 3 + assert one_dim_array.shape == (3,) + assert one_dim_array.rank() == 1 + assert one_dim_array.tolist() == [2, 3, 1] + + shape = (3, 3) + array_with_many_args = MutableSparseNDimArray.zeros(*shape) + assert len(array_with_many_args) == 3 * 3 + assert array_with_many_args.shape == shape + assert array_with_many_args[0, 0] == 0 + assert array_with_many_args.rank() == 2 + + shape = (int(3), int(3)) + array_with_long_shape = MutableSparseNDimArray.zeros(*shape) + assert len(array_with_long_shape) == 3 * 3 + assert array_with_long_shape.shape == shape + assert array_with_long_shape[int(0), int(0)] == 0 + assert array_with_long_shape.rank() == 2 + + vector_with_long_shape = MutableDenseNDimArray(range(5), int(5)) + assert len(vector_with_long_shape) == 5 + assert vector_with_long_shape.shape == (int(5),) + assert vector_with_long_shape.rank() == 1 + raises(ValueError, lambda: vector_with_long_shape[int(5)]) + + from sympy.abc import x + for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: + rank_zero_array = ArrayType(x) + assert len(rank_zero_array) == 1 + assert rank_zero_array.shape == () + assert rank_zero_array.rank() == 0 + assert rank_zero_array[()] == x + raises(ValueError, lambda: rank_zero_array[0]) + +def test_sympify(): + from sympy.abc import x, y, z, t + arr = MutableDenseNDimArray([[x, y], [1, z*t]]) + arr_other = sympify(arr) + assert arr_other.shape == (2, 2) + assert arr_other == arr + + +def test_reshape(): + array = MutableDenseNDimArray(range(50), 50) + assert array.shape == (50,) + assert array.rank() == 1 + + array = array.reshape(5, 5, 2) + assert array.shape == (5, 5, 2) + assert array.rank() == 3 + assert len(array) == 50 + + +def test_iterator(): + array = MutableDenseNDimArray(range(4), (2, 2)) + assert array[0] == MutableDenseNDimArray([0, 1]) + assert array[1] == MutableDenseNDimArray([2, 3]) + + array = array.reshape(4) + j = 0 + for i in array: + assert i == j + j += 1 + + +def test_getitem(): + for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: + array = ArrayType(range(24)).reshape(2, 3, 4) + assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + assert array[0, 0] == ArrayType([0, 1, 2, 3]) + value = 0 + for i in range(2): + for j in range(3): + for k in range(4): + assert array[i, j, k] == value + value += 1 + + raises(ValueError, lambda: array[3, 4, 5]) + raises(ValueError, lambda: array[3, 4, 5, 6]) + raises(ValueError, lambda: array[3, 4, 5, 3:4]) + + +def test_sparse(): + sparse_array = MutableSparseNDimArray([0, 0, 0, 1], (2, 2)) + assert len(sparse_array) == 2 * 2 + # dictionary where all data is, only non-zero entries are actually stored: + assert len(sparse_array._sparse_array) == 1 + + assert sparse_array.tolist() == [[0, 0], [0, 1]] + + for i, j in zip(sparse_array, [[0, 0], [0, 1]]): + assert i == MutableSparseNDimArray(j) + + sparse_array[0, 0] = 123 + assert len(sparse_array._sparse_array) == 2 + assert sparse_array[0, 0] == 123 + assert sparse_array/0 == MutableSparseNDimArray([[S.ComplexInfinity, S.NaN], [S.NaN, S.ComplexInfinity]], (2, 2)) + + # when element in sparse array become zero it will disappear from + # dictionary + sparse_array[0, 0] = 0 + assert len(sparse_array._sparse_array) == 1 + sparse_array[1, 1] = 0 + assert len(sparse_array._sparse_array) == 0 + assert sparse_array[0, 0] == 0 + + # test for large scale sparse array + # equality test + a = MutableSparseNDimArray.zeros(100000, 200000) + b = MutableSparseNDimArray.zeros(100000, 200000) + assert a == b + a[1, 1] = 1 + b[1, 1] = 2 + assert a != b + + # __mul__ and __rmul__ + assert a * 3 == MutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert 3 * a == MutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert a * 0 == MutableSparseNDimArray({}, (100000, 200000)) + assert 0 * a == MutableSparseNDimArray({}, (100000, 200000)) + + # __truediv__ + assert a/3 == MutableSparseNDimArray({200001: Rational(1, 3)}, (100000, 200000)) + + # __neg__ + assert -a == MutableSparseNDimArray({200001: -1}, (100000, 200000)) + + +def test_calculation(): + + a = MutableDenseNDimArray([1]*9, (3, 3)) + b = MutableDenseNDimArray([9]*9, (3, 3)) + + c = a + b + for i in c: + assert i == MutableDenseNDimArray([10, 10, 10]) + + assert c == MutableDenseNDimArray([10]*9, (3, 3)) + assert c == MutableSparseNDimArray([10]*9, (3, 3)) + + c = b - a + for i in c: + assert i == MutableSparseNDimArray([8, 8, 8]) + + assert c == MutableDenseNDimArray([8]*9, (3, 3)) + assert c == MutableSparseNDimArray([8]*9, (3, 3)) + + +def test_ndim_array_converting(): + dense_array = MutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + alist = dense_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = dense_array.tomatrix() + assert (isinstance(matrix, Matrix)) + + for i in range(len(dense_array)): + assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == dense_array.shape + + assert MutableDenseNDimArray(matrix) == dense_array + assert MutableDenseNDimArray(matrix.as_immutable()) == dense_array + assert MutableDenseNDimArray(matrix.as_mutable()) == dense_array + + sparse_array = MutableSparseNDimArray([1, 2, 3, 4], (2, 2)) + alist = sparse_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = sparse_array.tomatrix() + assert(isinstance(matrix, SparseMatrix)) + + for i in range(len(sparse_array)): + assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == sparse_array.shape + + assert MutableSparseNDimArray(matrix) == sparse_array + assert MutableSparseNDimArray(matrix.as_immutable()) == sparse_array + assert MutableSparseNDimArray(matrix.as_mutable()) == sparse_array + + +def test_converting_functions(): + arr_list = [1, 2, 3, 4] + arr_matrix = Matrix(((1, 2), (3, 4))) + + # list + arr_ndim_array = MutableDenseNDimArray(arr_list, (2, 2)) + assert (isinstance(arr_ndim_array, MutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + + # Matrix + arr_ndim_array = MutableDenseNDimArray(arr_matrix) + assert (isinstance(arr_ndim_array, MutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + assert arr_matrix.shape == arr_ndim_array.shape + + +def test_equality(): + first_list = [1, 2, 3, 4] + second_list = [1, 2, 3, 4] + third_list = [4, 3, 2, 1] + assert first_list == second_list + assert first_list != third_list + + first_ndim_array = MutableDenseNDimArray(first_list, (2, 2)) + second_ndim_array = MutableDenseNDimArray(second_list, (2, 2)) + third_ndim_array = MutableDenseNDimArray(third_list, (2, 2)) + fourth_ndim_array = MutableDenseNDimArray(first_list, (2, 2)) + + assert first_ndim_array == second_ndim_array + second_ndim_array[0, 0] = 0 + assert first_ndim_array != second_ndim_array + assert first_ndim_array != third_ndim_array + assert first_ndim_array == fourth_ndim_array + + +def test_arithmetic(): + a = MutableDenseNDimArray([3 for i in range(9)], (3, 3)) + b = MutableDenseNDimArray([7 for i in range(9)], (3, 3)) + + c1 = a + b + c2 = b + a + assert c1 == c2 + + d1 = a - b + d2 = b - a + assert d1 == d2 * (-1) + + e1 = a * 5 + e2 = 5 * a + e3 = copy(a) + e3 *= 5 + assert e1 == e2 == e3 + + f1 = a / 5 + f2 = copy(a) + f2 /= 5 + assert f1 == f2 + assert f1[0, 0] == f1[0, 1] == f1[0, 2] == f1[1, 0] == f1[1, 1] == \ + f1[1, 2] == f1[2, 0] == f1[2, 1] == f1[2, 2] == Rational(3, 5) + + assert type(a) == type(b) == type(c1) == type(c2) == type(d1) == type(d2) \ + == type(e1) == type(e2) == type(e3) == type(f1) + + z0 = -a + assert z0 == MutableDenseNDimArray([-3 for i in range(9)], (3, 3)) + + +def test_higher_dimenions(): + m3 = MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert m3.tolist() == [[[10, 11, 12, 13], + [14, 15, 16, 17], + [18, 19, 20, 21]], + + [[22, 23, 24, 25], + [26, 27, 28, 29], + [30, 31, 32, 33]]] + + assert m3._get_tuple_index(0) == (0, 0, 0) + assert m3._get_tuple_index(1) == (0, 0, 1) + assert m3._get_tuple_index(4) == (0, 1, 0) + assert m3._get_tuple_index(12) == (1, 0, 0) + + assert str(m3) == '[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]' + + m3_rebuilt = MutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]) + assert m3 == m3_rebuilt + + m3_other = MutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]], (2, 3, 4)) + + assert m3 == m3_other + + +def test_slices(): + md = MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert md[:] == MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + assert md[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert md[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert md[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert md[:, :, :] == md + + sd = MutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd == MutableSparseNDimArray(md) + + assert sd[:] == MutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert sd[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert sd[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert sd[:, :, :] == sd + + +def test_slices_assign(): + a = MutableDenseNDimArray(range(12), shape=(4, 3)) + b = MutableSparseNDimArray(range(12), shape=(4, 3)) + + for i in [a, b]: + assert i.tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[0, :] = [2, 2, 2] + assert i.tolist() == [[2, 2, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[0, 1:] = [8, 8] + assert i.tolist() == [[2, 8, 8], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[1:3, 1] = [20, 44] + assert i.tolist() == [[2, 8, 8], [3, 20, 5], [6, 44, 8], [9, 10, 11]] + + +def test_diff(): + from sympy.abc import x, y, z + md = MutableDenseNDimArray([[x, y], [x*z, x*y*z]]) + assert md.diff(x) == MutableDenseNDimArray([[1, 0], [z, y*z]]) + assert diff(md, x) == MutableDenseNDimArray([[1, 0], [z, y*z]]) + + sd = MutableSparseNDimArray(md) + assert sd == MutableSparseNDimArray([x, y, x*z, x*y*z], (2, 2)) + assert sd.diff(x) == MutableSparseNDimArray([[1, 0], [z, y*z]]) + assert diff(sd, x) == MutableSparseNDimArray([[1, 0], [z, y*z]]) diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_ndim_array.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff9b032631c01272c00478e4cdf0dcbc6997990 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_ndim_array.py @@ -0,0 +1,73 @@ +from sympy.testing.pytest import raises +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.matrices.dense import Matrix +from sympy.simplify import simplify +from sympy.tensor.array import Array +from sympy.tensor.array.dense_ndim_array import ( + ImmutableDenseNDimArray, MutableDenseNDimArray) +from sympy.tensor.array.sparse_ndim_array import ( + ImmutableSparseNDimArray, MutableSparseNDimArray) + +from sympy.abc import x, y + +mutable_array_types = [ + MutableDenseNDimArray, + MutableSparseNDimArray +] + +array_types = [ + ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableDenseNDimArray, + MutableSparseNDimArray +] + + +def test_array_negative_indices(): + for ArrayType in array_types: + test_array = ArrayType([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + assert test_array[:, -1] == Array([5, 10]) + assert test_array[:, -2] == Array([4, 9]) + assert test_array[:, -3] == Array([3, 8]) + assert test_array[:, -4] == Array([2, 7]) + assert test_array[:, -5] == Array([1, 6]) + assert test_array[:, 0] == Array([1, 6]) + assert test_array[:, 1] == Array([2, 7]) + assert test_array[:, 2] == Array([3, 8]) + assert test_array[:, 3] == Array([4, 9]) + assert test_array[:, 4] == Array([5, 10]) + + raises(ValueError, lambda: test_array[:, -6]) + raises(ValueError, lambda: test_array[-3, :]) + + assert test_array[-1, -1] == 10 + + +def test_issue_18361(): + A = Array([sin(2 * x) - 2 * sin(x) * cos(x)]) + B = Array([sin(x)**2 + cos(x)**2, 0]) + C = Array([(x + x**2)/(x*sin(y)**2 + x*cos(y)**2), 2*sin(x)*cos(x)]) + assert simplify(A) == Array([0]) + assert simplify(B) == Array([1, 0]) + assert simplify(C) == Array([x + 1, sin(2*x)]) + + +def test_issue_20222(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: A - B) + + +def test_issue_17851(): + for array_type in array_types: + A = array_type([]) + assert isinstance(A, array_type) + assert A.shape == (0,) + assert list(A) == [] + + +def test_issue_and_18715(): + for array_type in mutable_array_types: + A = array_type([0, 1, 2]) + A[0] += 5 + assert A[0] == 5 diff --git a/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f43260ccc636ac461ba0c06dbfcf3fe3a8d5338d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py @@ -0,0 +1,22 @@ +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, MutableDenseNDimArray, MutableSparseNDimArray) +from sympy.abc import x, y, z + + +def test_NDim_array_conv(): + MD = MutableDenseNDimArray([x, y, z]) + MS = MutableSparseNDimArray([x, y, z]) + ID = ImmutableDenseNDimArray([x, y, z]) + IS = ImmutableSparseNDimArray([x, y, z]) + + assert MD.as_immutable() == ID + assert MD.as_mutable() == MD + + assert MS.as_immutable() == IS + assert MS.as_mutable() == MS + + assert ID.as_immutable() == ID + assert ID.as_mutable() == MD + + assert IS.as_immutable() == IS + assert IS.as_mutable() == MS diff --git a/phivenv/Lib/site-packages/sympy/tensor/functions.py b/phivenv/Lib/site-packages/sympy/tensor/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..f14599d69152db1713f21c9dd785683901c5eeb9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/functions.py @@ -0,0 +1,154 @@ +from collections.abc import Iterable +from functools import singledispatch + +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.core.parameters import global_parameters + + +class TensorProduct(Expr): + """ + Generic class for tensor products. + """ + is_number = False + + def __new__(cls, *args, **kwargs): + from sympy.tensor.array import NDimArray, tensorproduct, Array + from sympy.matrices.expressions.matexpr import MatrixExpr + from sympy.matrices.matrixbase import MatrixBase + from sympy.strategies import flatten + + args = [sympify(arg) for arg in args] + evaluate = kwargs.get("evaluate", global_parameters.evaluate) + + if not evaluate: + obj = Expr.__new__(cls, *args) + return obj + + arrays = [] + other = [] + scalar = S.One + for arg in args: + if isinstance(arg, (Iterable, MatrixBase, NDimArray)): + arrays.append(Array(arg)) + elif isinstance(arg, (MatrixExpr,)): + other.append(arg) + else: + scalar *= arg + + coeff = scalar*tensorproduct(*arrays) + if len(other) == 0: + return coeff + if coeff != 1: + newargs = [coeff] + other + else: + newargs = other + obj = Expr.__new__(cls, *newargs, **kwargs) + return flatten(obj) + + def rank(self): + return len(self.shape) + + def _get_args_shapes(self): + from sympy.tensor.array import Array + return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args] + + @property + def shape(self): + shape_list = self._get_args_shapes() + return sum(shape_list, ()) + + def __getitem__(self, index): + index = iter(index) + return Mul.fromiter( + arg.__getitem__(tuple(next(index) for i in shp)) + for arg, shp in zip(self.args, self._get_args_shapes()) + ) + + +@singledispatch +def shape(expr): + """ + Return the shape of the *expr* as a tuple. *expr* should represent + suitable object such as matrix or array. + + Parameters + ========== + + expr : SymPy object having ``MatrixKind`` or ``ArrayKind``. + + Raises + ====== + + NoShapeError : Raised when object with wrong kind is passed. + + Examples + ======== + + This function returns the shape of any object representing matrix or array. + + >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral + >>> from sympy.abc import x + >>> A = Array([1, 2]) + >>> shape(A) + (2,) + >>> shape(Integral(A, x)) + (2,) + >>> M = ImmutableDenseMatrix([1, 2]) + >>> shape(M) + (2, 1) + >>> shape(Integral(M, x)) + (2, 1) + + You can support new type by dispatching. + + >>> from sympy import Expr + >>> class NewExpr(Expr): + ... pass + >>> @shape.register(NewExpr) + ... def _(expr): + ... return shape(expr.args[0]) + >>> shape(NewExpr(M)) + (2, 1) + + If unsuitable expression is passed, ``NoShapeError()`` will be raised. + + >>> shape(Integral(x, x)) + Traceback (most recent call last): + ... + sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x) + + Notes + ===== + + Array-like classes (such as ``Matrix`` or ``NDimArray``) has ``shape`` + property which returns its shape, but it cannot be used for non-array + classes containing array. This function returns the shape of any + registered object representing array. + + """ + if hasattr(expr, "shape"): + return expr.shape + raise NoShapeError( + "%s does not have shape, or its type is not registered to shape()." % expr) + + +class NoShapeError(Exception): + """ + Raised when ``shape()`` is called on non-array object. + + This error can be imported from ``sympy.tensor.functions``. + + Examples + ======== + + >>> from sympy import shape + >>> from sympy.abc import x + >>> shape(x) + Traceback (most recent call last): + ... + sympy.tensor.functions.NoShapeError: shape() called on non-array object: x + """ + pass diff --git a/phivenv/Lib/site-packages/sympy/tensor/index_methods.py b/phivenv/Lib/site-packages/sympy/tensor/index_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..12f707b60b4ad0bcadc35a222d9abe0cc5e033fc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/index_methods.py @@ -0,0 +1,469 @@ +"""Module with functions operating on IndexedBase, Indexed and Idx objects + + - Check shape conformance + - Determine indices in resulting expression + + etc. + + Methods in this module could be implemented by calling methods on Expr + objects instead. When things stabilize this could be a useful + refactoring. +""" + +from functools import reduce + +from sympy.core.function import Function +from sympy.functions import exp, Piecewise +from sympy.tensor.indexed import Idx, Indexed +from sympy.utilities import sift + +from collections import OrderedDict + +class IndexConformanceException(Exception): + pass + +def _unique_and_repeated(inds): + """ + Returns the unique and repeated indices. Also note, from the examples given below + that the order of indices is maintained as given in the input. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _unique_and_repeated + >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0]) + ([2, 1, 4], [3, 0]) + """ + uniq = OrderedDict() + for i in inds: + if i in uniq: + uniq[i] = 0 + else: + uniq[i] = 1 + return sift(uniq, lambda x: uniq[x], binary=True) + +def _remove_repeated(inds): + """ + Removes repeated objects from sequences + + Returns a set of the unique objects and a tuple of all that have been + removed. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _remove_repeated + >>> l1 = [1, 2, 3, 2] + >>> _remove_repeated(l1) + ({1, 3}, (2,)) + + """ + u, r = _unique_and_repeated(inds) + return set(u), tuple(r) + + +def _get_indices_Mul(expr, return_dummies=False): + """Determine the outer indices of a Mul object. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Mul + >>> from sympy.tensor.indexed import IndexedBase, Idx + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> x = IndexedBase('x') + >>> y = IndexedBase('y') + >>> _get_indices_Mul(x[i, k]*y[j, k]) + ({i, j}, {}) + >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True) + ({i, j}, {}, (k,)) + + """ + + inds = list(map(get_indices, expr.args)) + inds, syms = list(zip(*inds)) + + inds = list(map(list, inds)) + inds = list(reduce(lambda x, y: x + y, inds)) + inds, dummies = _remove_repeated(inds) + + symmetry = {} + for s in syms: + for pair in s: + if pair in symmetry: + symmetry[pair] *= s[pair] + else: + symmetry[pair] = s[pair] + + if return_dummies: + return inds, symmetry, dummies + else: + return inds, symmetry + + +def _get_indices_Pow(expr): + """Determine outer indices of a power or an exponential. + + A power is considered a universal function, so that the indices of a Pow is + just the collection of indices present in the expression. This may be + viewed as a bit inconsistent in the special case: + + x[i]**2 = x[i]*x[i] (1) + + The above expression could have been interpreted as the contraction of x[i] + with itself, but we choose instead to interpret it as a function + + lambda y: y**2 + + applied to each element of x (a universal function in numpy terms). In + order to allow an interpretation of (1) as a contraction, we need + contravariant and covariant Idx subclasses. (FIXME: this is not yet + implemented) + + Expressions in the base or exponent are subject to contraction as usual, + but an index that is present in the exponent, will not be considered + contractable with its own base. Note however, that indices in the same + exponent can be contracted with each other. + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Pow + >>> from sympy import Pow, exp, IndexedBase, Idx + >>> A = IndexedBase('A') + >>> x = IndexedBase('x') + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> _get_indices_Pow(exp(A[i, j]*x[j])) + ({i}, {}) + >>> _get_indices_Pow(Pow(x[i], x[i])) + ({i}, {}) + >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i])) + ({i}, {}) + + """ + base, exp = expr.as_base_exp() + binds, bsyms = get_indices(base) + einds, esyms = get_indices(exp) + + inds = binds | einds + + # FIXME: symmetries from power needs to check special cases, else nothing + symmetries = {} + + return inds, symmetries + + +def _get_indices_Add(expr): + """Determine outer indices of an Add object. + + In a sum, each term must have the same set of outer indices. A valid + expression could be + + x(i)*y(j) - x(j)*y(i) + + But we do not allow expressions like: + + x(i)*y(j) - z(j)*z(j) + + FIXME: Add support for Numpy broadcasting + + Examples + ======== + + >>> from sympy.tensor.index_methods import _get_indices_Add + >>> from sympy.tensor.indexed import IndexedBase, Idx + >>> i, j, k = map(Idx, ['i', 'j', 'k']) + >>> x = IndexedBase('x') + >>> y = IndexedBase('y') + >>> _get_indices_Add(x[i] + x[k]*y[i, k]) + ({i}, {}) + + """ + + inds = list(map(get_indices, expr.args)) + inds, syms = list(zip(*inds)) + + # allow broadcast of scalars + non_scalars = [x for x in inds if x != set()] + if not non_scalars: + return set(), {} + + if not all(x == non_scalars[0] for x in non_scalars[1:]): + raise IndexConformanceException("Indices are not consistent: %s" % expr) + if not reduce(lambda x, y: x != y or y, syms): + symmetries = syms[0] + else: + # FIXME: search for symmetries + symmetries = {} + + return non_scalars[0], symmetries + + +def get_indices(expr): + """Determine the outer indices of expression ``expr`` + + By *outer* we mean indices that are not summation indices. Returns a set + and a dict. The set contains outer indices and the dict contains + information about index symmetries. + + Examples + ======== + + >>> from sympy.tensor.index_methods import get_indices + >>> from sympy import symbols + >>> from sympy.tensor import IndexedBase + >>> x, y, A = map(IndexedBase, ['x', 'y', 'A']) + >>> i, j, a, z = symbols('i j a z', integer=True) + + The indices of the total expression is determined, Repeated indices imply a + summation, for instance the trace of a matrix A: + + >>> get_indices(A[i, i]) + (set(), {}) + + In the case of many terms, the terms are required to have identical + outer indices. Else an IndexConformanceException is raised. + + >>> get_indices(x[i] + A[i, j]*y[j]) + ({i}, {}) + + :Exceptions: + + An IndexConformanceException means that the terms ar not compatible, e.g. + + >>> get_indices(x[i] + y[j]) #doctest: +SKIP + (...) + IndexConformanceException: Indices are not consistent: x(i) + y(j) + + .. warning:: + The concept of *outer* indices applies recursively, starting on the deepest + level. This implies that dummies inside parenthesis are assumed to be + summed first, so that the following expression is handled gracefully: + + >>> get_indices((x[i] + A[i, j]*y[j])*x[j]) + ({i, j}, {}) + + This is correct and may appear convenient, but you need to be careful + with this as SymPy will happily .expand() the product, if requested. The + resulting expression would mix the outer ``j`` with the dummies inside + the parenthesis, which makes it a different expression. To be on the + safe side, it is best to avoid such ambiguities by using unique indices + for all contractions that should be held separate. + + """ + # We call ourself recursively to determine indices of sub expressions. + + # break recursion + if isinstance(expr, Indexed): + c = expr.indices + inds, dummies = _remove_repeated(c) + return inds, {} + elif expr is None: + return set(), {} + elif isinstance(expr, Idx): + return {expr}, {} + elif expr.is_Atom: + return set(), {} + + + # recurse via specialized functions + else: + if expr.is_Mul: + return _get_indices_Mul(expr) + elif expr.is_Add: + return _get_indices_Add(expr) + elif expr.is_Pow or isinstance(expr, exp): + return _get_indices_Pow(expr) + + elif isinstance(expr, Piecewise): + # FIXME: No support for Piecewise yet + return set(), {} + elif isinstance(expr, Function): + # Support ufunc like behaviour by returning indices from arguments. + # Functions do not interpret repeated indices across arguments + # as summation + ind0 = set() + for arg in expr.args: + ind, sym = get_indices(arg) + ind0 |= ind + return ind0, sym + + # this test is expensive, so it should be at the end + elif not expr.has(Indexed): + return set(), {} + raise NotImplementedError( + "FIXME: No specialized handling of type %s" % type(expr)) + + +def get_contraction_structure(expr): + """Determine dummy indices of ``expr`` and describe its structure + + By *dummy* we mean indices that are summation indices. + + The structure of the expression is determined and described as follows: + + 1) A conforming summation of Indexed objects is described with a dict where + the keys are summation indices and the corresponding values are sets + containing all terms for which the summation applies. All Add objects + in the SymPy expression tree are described like this. + + 2) For all nodes in the SymPy expression tree that are *not* of type Add, the + following applies: + + If a node discovers contractions in one of its arguments, the node + itself will be stored as a key in the dict. For that key, the + corresponding value is a list of dicts, each of which is the result of a + recursive call to get_contraction_structure(). The list contains only + dicts for the non-trivial deeper contractions, omitting dicts with None + as the one and only key. + + .. Note:: The presence of expressions among the dictionary keys indicates + multiple levels of index contractions. A nested dict displays nested + contractions and may itself contain dicts from a deeper level. In + practical calculations the summation in the deepest nested level must be + calculated first so that the outer expression can access the resulting + indexed object. + + Examples + ======== + + >>> from sympy.tensor.index_methods import get_contraction_structure + >>> from sympy import default_sort_key + >>> from sympy.tensor import IndexedBase, Idx + >>> x, y, A = map(IndexedBase, ['x', 'y', 'A']) + >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l']) + >>> get_contraction_structure(x[i]*y[i] + A[j, j]) + {(i,): {x[i]*y[i]}, (j,): {A[j, j]}} + >>> get_contraction_structure(x[i]*y[j]) + {None: {x[i]*y[j]}} + + A multiplication of contracted factors results in nested dicts representing + the internal contractions. + + >>> d = get_contraction_structure(x[i, i]*y[j, j]) + >>> sorted(d.keys(), key=default_sort_key) + [None, x[i, i]*y[j, j]] + + In this case, the product has no contractions: + + >>> d[None] + {x[i, i]*y[j, j]} + + Factors are contracted "first": + + >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key) + [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}] + + A parenthesized Add object is also returned as a nested dictionary. The + term containing the parenthesis is a Mul with a contraction among the + arguments, so it will be found as a key in the result. It stores the + dictionary resulting from a recursive call on the Add expression. + + >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j])) + >>> sorted(d.keys(), key=default_sort_key) + [(A[i, j]*x[j] + y[i])*x[i], (i,)] + >>> d[(i,)] + {(A[i, j]*x[j] + y[i])*x[i]} + >>> d[x[i]*(A[i, j]*x[j] + y[i])] + [{None: {y[i]}, (j,): {A[i, j]*x[j]}}] + + Powers with contractions in either base or exponent will also be found as + keys in the dictionary, mapping to a list of results from recursive calls: + + >>> d = get_contraction_structure(A[j, j]**A[i, i]) + >>> d[None] + {A[j, j]**A[i, i]} + >>> nested_contractions = d[A[j, j]**A[i, i]] + >>> nested_contractions[0] + {(j,): {A[j, j]}} + >>> nested_contractions[1] + {(i,): {A[i, i]}} + + The description of the contraction structure may appear complicated when + represented with a string in the above examples, but it is easy to iterate + over: + + >>> from sympy import Expr + >>> for key in d: + ... if isinstance(key, Expr): + ... continue + ... for term in d[key]: + ... if term in d: + ... # treat deepest contraction first + ... pass + ... # treat outermost contactions here + + """ + + # We call ourself recursively to inspect sub expressions. + + if isinstance(expr, Indexed): + junk, key = _remove_repeated(expr.indices) + return {key or None: {expr}} + elif expr.is_Atom: + return {None: {expr}} + elif expr.is_Mul: + junk, junk, key = _get_indices_Mul(expr, return_dummies=True) + result = {key or None: {expr}} + # recurse on every factor + nested = [] + for fac in expr.args: + facd = get_contraction_structure(fac) + if not (None in facd and len(facd) == 1): + nested.append(facd) + if nested: + result[expr] = nested + return result + elif expr.is_Pow or isinstance(expr, exp): + # recurse in base and exp separately. If either has internal + # contractions we must include ourselves as a key in the returned dict + b, e = expr.as_base_exp() + dbase = get_contraction_structure(b) + dexp = get_contraction_structure(e) + + dicts = [] + for d in dbase, dexp: + if not (None in d and len(d) == 1): + dicts.append(d) + result = {None: {expr}} + if dicts: + result[expr] = dicts + return result + elif expr.is_Add: + # Note: we just collect all terms with identical summation indices, We + # do nothing to identify equivalent terms here, as this would require + # substitutions or pattern matching in expressions of unknown + # complexity. + result = {} + for term in expr.args: + # recurse on every term + d = get_contraction_structure(term) + for key in d: + if key in result: + result[key] |= d[key] + else: + result[key] = d[key] + return result + + elif isinstance(expr, Piecewise): + # FIXME: No support for Piecewise yet + return {None: expr} + elif isinstance(expr, Function): + # Collect non-trivial contraction structures in each argument + # We do not report repeated indices in separate arguments as a + # contraction + deeplist = [] + for arg in expr.args: + deep = get_contraction_structure(arg) + if not (None in deep and len(deep) == 1): + deeplist.append(deep) + d = {None: {expr}} + if deeplist: + d[expr] = deeplist + return d + + # this test is expensive, so it should be at the end + elif not expr.has(Indexed): + return {None: {expr}} + raise NotImplementedError( + "FIXME: No specialized handling of type %s" % type(expr)) diff --git a/phivenv/Lib/site-packages/sympy/tensor/indexed.py b/phivenv/Lib/site-packages/sympy/tensor/indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..feddad21e52bbab2e1243beafdb11f30b2eded4d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/indexed.py @@ -0,0 +1,793 @@ +r"""Module that defines indexed objects. + +The classes ``IndexedBase``, ``Indexed``, and ``Idx`` represent a +matrix element ``M[i, j]`` as in the following diagram:: + + 1) The Indexed class represents the entire indexed object. + | + ___|___ + ' ' + M[i, j] + / \__\______ + | | + | | + | 2) The Idx class represents indices; each Idx can + | optionally contain information about its range. + | + 3) IndexedBase represents the 'stem' of an indexed object, here `M`. + The stem used by itself is usually taken to represent the entire + array. + +There can be any number of indices on an Indexed object. No +transformation properties are implemented in these Base objects, but +implicit contraction of repeated indices is supported. + +Note that the support for complicated (i.e. non-atomic) integer +expressions as indices is limited. (This should be improved in +future releases.) + +Examples +======== + +To express the above matrix element example you would write: + +>>> from sympy import symbols, IndexedBase, Idx +>>> M = IndexedBase('M') +>>> i, j = symbols('i j', cls=Idx) +>>> M[i, j] +M[i, j] + +Repeated indices in a product implies a summation, so to express a +matrix-vector product in terms of Indexed objects: + +>>> x = IndexedBase('x') +>>> M[i, j]*x[j] +M[i, j]*x[j] + +If the indexed objects will be converted to component based arrays, e.g. +with the code printers or the autowrap framework, you also need to provide +(symbolic or numerical) dimensions. This can be done by passing an +optional shape parameter to IndexedBase upon construction: + +>>> dim1, dim2 = symbols('dim1 dim2', integer=True) +>>> A = IndexedBase('A', shape=(dim1, 2*dim1, dim2)) +>>> A.shape +(dim1, 2*dim1, dim2) +>>> A[i, j, 3].shape +(dim1, 2*dim1, dim2) + +If an IndexedBase object has no shape information, it is assumed that the +array is as large as the ranges of its indices: + +>>> n, m = symbols('n m', integer=True) +>>> i = Idx('i', m) +>>> j = Idx('j', n) +>>> M[i, j].shape +(m, n) +>>> M[i, j].ranges +[(0, m - 1), (0, n - 1)] + +The above can be compared with the following: + +>>> A[i, 2, j].shape +(dim1, 2*dim1, dim2) +>>> A[i, 2, j].ranges +[(0, m - 1), None, (0, n - 1)] + +To analyze the structure of indexed expressions, you can use the methods +get_indices() and get_contraction_structure(): + +>>> from sympy.tensor import get_indices, get_contraction_structure +>>> get_indices(A[i, j, j]) +({i}, {}) +>>> get_contraction_structure(A[i, j, j]) +{(j,): {A[i, j, j]}} + +See the appropriate docstrings for a detailed explanation of the output. +""" + +# TODO: (some ideas for improvement) +# +# o test and guarantee numpy compatibility +# - implement full support for broadcasting +# - strided arrays +# +# o more functions to analyze indexed expressions +# - identify standard constructs, e.g matrix-vector product in a subexpression +# +# o functions to generate component based arrays (numpy and sympy.Matrix) +# - generate a single array directly from Indexed +# - convert simple sub-expressions +# +# o sophisticated indexing (possibly in subclasses to preserve simplicity) +# - Idx with range smaller than dimension of Indexed +# - Idx with stepsize != 1 +# - Idx with step determined by function call +from collections.abc import Iterable + +from sympy.core.numbers import Number +from sympy.core.assumptions import StdFactKB +from sympy.core import Expr, Tuple, sympify, S +from sympy.core.symbol import _filter_assumptions, Symbol +from sympy.core.logic import fuzzy_bool, fuzzy_not +from sympy.core.sympify import _sympify +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.multipledispatch import dispatch +from sympy.utilities.iterables import is_sequence, NotIterable +from sympy.utilities.misc import filldedent + + +class IndexException(Exception): + pass + + +class Indexed(Expr): + """Represents a mathematical object with indices. + + >>> from sympy import Indexed, IndexedBase, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j) + A[i, j] + + It is recommended that ``Indexed`` objects be created by indexing ``IndexedBase``: + ``IndexedBase('A')[i, j]`` instead of ``Indexed(IndexedBase('A'), i, j)``. + + >>> A = IndexedBase('A') + >>> a_ij = A[i, j] # Prefer this, + >>> b_ij = Indexed(A, i, j) # over this. + >>> a_ij == b_ij + True + + """ + is_Indexed = True + is_symbol = True + is_Atom = True + + def __new__(cls, base, *args, **kw_args): + from sympy.tensor.array.ndim_array import NDimArray + from sympy.matrices.matrixbase import MatrixBase + + if not args: + raise IndexException("Indexed needs at least one index.") + if isinstance(base, (str, Symbol)): + base = IndexedBase(base) + elif not hasattr(base, '__getitem__') and not isinstance(base, IndexedBase): + raise TypeError(filldedent(""" + The base can only be replaced with a string, Symbol, + IndexedBase or an object with a method for getting + items (i.e. an object with a `__getitem__` method). + """)) + args = list(map(sympify, args)) + if isinstance(base, (NDimArray, Iterable, Tuple, MatrixBase)) and all(i.is_number for i in args): + if len(args) == 1: + return base[args[0]] + else: + return base[args] + + base = _sympify(base) + + obj = Expr.__new__(cls, base, *args, **kw_args) + + IndexedBase._set_assumptions(obj, base.assumptions0) + + return obj + + def _hashable_content(self): + return super()._hashable_content() + tuple(sorted(self.assumptions0.items())) + + @property + def name(self): + return str(self) + + @property + def _diff_wrt(self): + """Allow derivatives with respect to an ``Indexed`` object.""" + return True + + def _eval_derivative(self, wrt): + from sympy.tensor.array.ndim_array import NDimArray + + if isinstance(wrt, Indexed) and wrt.base == self.base: + if len(self.indices) != len(wrt.indices): + msg = "Different # of indices: d({!s})/d({!s})".format(self, + wrt) + raise IndexException(msg) + result = S.One + for index1, index2 in zip(self.indices, wrt.indices): + result *= KroneckerDelta(index1, index2) + return result + elif isinstance(self.base, NDimArray): + from sympy.tensor.array import derive_by_array + return Indexed(derive_by_array(self.base, wrt), *self.args[1:]) + else: + if Tuple(self.indices).has(wrt): + return S.NaN + return S.Zero + + @property + def assumptions0(self): + return {k: v for k, v in self._assumptions.items() if v is not None} + + @property + def base(self): + """Returns the ``IndexedBase`` of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, IndexedBase, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j).base + A + >>> B = IndexedBase('B') + >>> B == B[i, j].base + True + + """ + return self.args[0] + + @property + def indices(self): + """ + Returns the indices of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, Idx, symbols + >>> i, j = symbols('i j', cls=Idx) + >>> Indexed('A', i, j).indices + (i, j) + + """ + return self.args[1:] + + @property + def rank(self): + """ + Returns the rank of the ``Indexed`` object. + + Examples + ======== + + >>> from sympy import Indexed, Idx, symbols + >>> i, j, k, l, m = symbols('i:m', cls=Idx) + >>> Indexed('A', i, j).rank + 2 + >>> q = Indexed('A', i, j, k, l, m) + >>> q.rank + 5 + >>> q.rank == len(q.indices) + True + + """ + return len(self.args) - 1 + + @property + def shape(self): + """Returns a list with dimensions of each index. + + Dimensions is a property of the array, not of the indices. Still, if + the ``IndexedBase`` does not define a shape attribute, it is assumed + that the ranges of the indices correspond to the shape of the array. + + >>> from sympy import IndexedBase, Idx, symbols + >>> n, m = symbols('n m', integer=True) + >>> i = Idx('i', m) + >>> j = Idx('j', m) + >>> A = IndexedBase('A', shape=(n, n)) + >>> B = IndexedBase('B') + >>> A[i, j].shape + (n, n) + >>> B[i, j].shape + (m, m) + """ + + if self.base.shape: + return self.base.shape + sizes = [] + for i in self.indices: + upper = getattr(i, 'upper', None) + lower = getattr(i, 'lower', None) + if None in (upper, lower): + raise IndexException(filldedent(""" + Range is not defined for all indices in: %s""" % self)) + try: + size = upper - lower + 1 + except TypeError: + raise IndexException(filldedent(""" + Shape cannot be inferred from Idx with + undefined range: %s""" % self)) + sizes.append(size) + return Tuple(*sizes) + + @property + def ranges(self): + """Returns a list of tuples with lower and upper range of each index. + + If an index does not define the data members upper and lower, the + corresponding slot in the list contains ``None`` instead of a tuple. + + Examples + ======== + + >>> from sympy import Indexed,Idx, symbols + >>> Indexed('A', Idx('i', 2), Idx('j', 4), Idx('k', 8)).ranges + [(0, 1), (0, 3), (0, 7)] + >>> Indexed('A', Idx('i', 3), Idx('j', 3), Idx('k', 3)).ranges + [(0, 2), (0, 2), (0, 2)] + >>> x, y, z = symbols('x y z', integer=True) + >>> Indexed('A', x, y, z).ranges + [None, None, None] + + """ + ranges = [] + sentinel = object() + for i in self.indices: + upper = getattr(i, 'upper', sentinel) + lower = getattr(i, 'lower', sentinel) + if sentinel not in (upper, lower): + ranges.append((lower, upper)) + else: + ranges.append(None) + return ranges + + def _sympystr(self, p): + indices = list(map(p.doprint, self.indices)) + return "%s[%s]" % (p.doprint(self.base), ", ".join(indices)) + + @property + def free_symbols(self): + base_free_symbols = self.base.free_symbols + indices_free_symbols = { + fs for i in self.indices for fs in i.free_symbols} + if base_free_symbols: + return {self} | base_free_symbols | indices_free_symbols + else: + return indices_free_symbols + + @property + def expr_free_symbols(self): + from sympy.utilities.exceptions import sympy_deprecation_warning + sympy_deprecation_warning(""" + The expr_free_symbols property is deprecated. Use free_symbols to get + the free symbols of an expression. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-expr-free-symbols") + + return {self} + + +class IndexedBase(Expr, NotIterable): + """Represent the base or stem of an indexed object + + The IndexedBase class represent an array that contains elements. The main purpose + of this class is to allow the convenient creation of objects of the Indexed + class. The __getitem__ method of IndexedBase returns an instance of + Indexed. Alone, without indices, the IndexedBase class can be used as a + notation for e.g. matrix equations, resembling what you could do with the + Symbol class. But, the IndexedBase class adds functionality that is not + available for Symbol instances: + + - An IndexedBase object can optionally store shape information. This can + be used in to check array conformance and conditions for numpy + broadcasting. (TODO) + - An IndexedBase object implements syntactic sugar that allows easy symbolic + representation of array operations, using implicit summation of + repeated indices. + - The IndexedBase object symbolizes a mathematical structure equivalent + to arrays, and is recognized as such for code generation and automatic + compilation and wrapping. + + >>> from sympy.tensor import IndexedBase, Idx + >>> from sympy import symbols + >>> A = IndexedBase('A'); A + A + >>> type(A) + + + When an IndexedBase object receives indices, it returns an array with named + axes, represented by an Indexed object: + + >>> i, j = symbols('i j', integer=True) + >>> A[i, j, 2] + A[i, j, 2] + >>> type(A[i, j, 2]) + + + The IndexedBase constructor takes an optional shape argument. If given, + it overrides any shape information in the indices. (But not the index + ranges!) + + >>> m, n, o, p = symbols('m n o p', integer=True) + >>> i = Idx('i', m) + >>> j = Idx('j', n) + >>> A[i, j].shape + (m, n) + >>> B = IndexedBase('B', shape=(o, p)) + >>> B[i, j].shape + (o, p) + + Assumptions can be specified with keyword arguments the same way as for Symbol: + + >>> A_real = IndexedBase('A', real=True) + >>> A_real.is_real + True + >>> A != A_real + True + + Assumptions can also be inherited if a Symbol is used to initialize the IndexedBase: + + >>> I = symbols('I', integer=True) + >>> C_inherit = IndexedBase(I) + >>> C_explicit = IndexedBase('I', integer=True) + >>> C_inherit == C_explicit + True + """ + is_symbol = True + is_Atom = True + + @staticmethod + def _set_assumptions(obj, assumptions): + """Set assumptions on obj, making sure to apply consistent values.""" + tmp_asm_copy = assumptions.copy() + is_commutative = fuzzy_bool(assumptions.get('commutative', True)) + assumptions['commutative'] = is_commutative + obj._assumptions = StdFactKB(assumptions) + obj._assumptions._generator = tmp_asm_copy # Issue #8873 + + def __new__(cls, label, shape=None, *, offset=S.Zero, strides=None, **kw_args): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array.ndim_array import NDimArray + + assumptions, kw_args = _filter_assumptions(kw_args) + if isinstance(label, str): + label = Symbol(label, **assumptions) + elif isinstance(label, Symbol): + assumptions = label._merge(assumptions) + elif isinstance(label, (MatrixBase, NDimArray)): + return label + elif isinstance(label, Iterable): + return _sympify(label) + else: + label = _sympify(label) + + if is_sequence(shape): + shape = Tuple(*shape) + elif shape is not None: + shape = Tuple(shape) + + if shape is not None: + obj = Expr.__new__(cls, label, shape) + else: + obj = Expr.__new__(cls, label) + obj._shape = shape + obj._offset = offset + obj._strides = strides + obj._name = str(label) + + IndexedBase._set_assumptions(obj, assumptions) + return obj + + @property + def name(self): + return self._name + + def _hashable_content(self): + return super()._hashable_content() + tuple(sorted(self.assumptions0.items())) + + @property + def assumptions0(self): + return {k: v for k, v in self._assumptions.items() if v is not None} + + def __getitem__(self, indices, **kw_args): + if is_sequence(indices): + # Special case needed because M[*my_tuple] is a syntax error. + if self.shape and len(self.shape) != len(indices): + raise IndexException("Rank mismatch.") + return Indexed(self, *indices, **kw_args) + else: + if self.shape and len(self.shape) != 1: + raise IndexException("Rank mismatch.") + return Indexed(self, indices, **kw_args) + + @property + def shape(self): + """Returns the shape of the ``IndexedBase`` object. + + Examples + ======== + + >>> from sympy import IndexedBase, Idx + >>> from sympy.abc import x, y + >>> IndexedBase('A', shape=(x, y)).shape + (x, y) + + Note: If the shape of the ``IndexedBase`` is specified, it will override + any shape information given by the indices. + + >>> A = IndexedBase('A', shape=(x, y)) + >>> B = IndexedBase('B') + >>> i = Idx('i', 2) + >>> j = Idx('j', 1) + >>> A[i, j].shape + (x, y) + >>> B[i, j].shape + (2, 1) + + """ + return self._shape + + @property + def strides(self): + """Returns the strided scheme for the ``IndexedBase`` object. + + Normally this is a tuple denoting the number of + steps to take in the respective dimension when traversing + an array. For code generation purposes strides='C' and + strides='F' can also be used. + + strides='C' would mean that code printer would unroll + in row-major order and 'F' means unroll in column major + order. + + """ + + return self._strides + + @property + def offset(self): + """Returns the offset for the ``IndexedBase`` object. + + This is the value added to the resulting index when the + 2D Indexed object is unrolled to a 1D form. Used in code + generation. + + Examples + ========== + >>> from sympy.printing import ccode + >>> from sympy.tensor import IndexedBase, Idx + >>> from sympy import symbols + >>> l, m, n, o = symbols('l m n o', integer=True) + >>> A = IndexedBase('A', strides=(l, m, n), offset=o) + >>> i, j, k = map(Idx, 'ijk') + >>> ccode(A[i, j, k]) + 'A[l*i + m*j + n*k + o]' + + """ + return self._offset + + @property + def label(self): + """Returns the label of the ``IndexedBase`` object. + + Examples + ======== + + >>> from sympy import IndexedBase + >>> from sympy.abc import x, y + >>> IndexedBase('A', shape=(x, y)).label + A + + """ + return self.args[0] + + def _sympystr(self, p): + return p.doprint(self.label) + + +class Idx(Expr): + """Represents an integer index as an ``Integer`` or integer expression. + + There are a number of ways to create an ``Idx`` object. The constructor + takes two arguments: + + ``label`` + An integer or a symbol that labels the index. + ``range`` + Optionally you can specify a range as either + + * ``Symbol`` or integer: This is interpreted as a dimension. Lower and + upper bounds are set to ``0`` and ``range - 1``, respectively. + * ``tuple``: The two elements are interpreted as the lower and upper + bounds of the range, respectively. + + Note: bounds of the range are assumed to be either integer or infinite (oo + and -oo are allowed to specify an unbounded range). If ``n`` is given as a + bound, then ``n.is_integer`` must not return false. + + For convenience, if the label is given as a string it is automatically + converted to an integer symbol. (Note: this conversion is not done for + range or dimension arguments.) + + Examples + ======== + + >>> from sympy import Idx, symbols, oo + >>> n, i, L, U = symbols('n i L U', integer=True) + + If a string is given for the label an integer ``Symbol`` is created and the + bounds are both ``None``: + + >>> idx = Idx('qwerty'); idx + qwerty + >>> idx.lower, idx.upper + (None, None) + + Both upper and lower bounds can be specified: + + >>> idx = Idx(i, (L, U)); idx + i + >>> idx.lower, idx.upper + (L, U) + + When only a single bound is given it is interpreted as the dimension + and the lower bound defaults to 0: + + >>> idx = Idx(i, n); idx.lower, idx.upper + (0, n - 1) + >>> idx = Idx(i, 4); idx.lower, idx.upper + (0, 3) + >>> idx = Idx(i, oo); idx.lower, idx.upper + (0, oo) + + """ + + is_integer = True + is_finite = True + is_real = True + is_symbol = True + is_Atom = True + _diff_wrt = True + + def __new__(cls, label, range=None, **kw_args): + + if isinstance(label, str): + label = Symbol(label, integer=True) + label, range = list(map(sympify, (label, range))) + + if label.is_Number: + if not label.is_integer: + raise TypeError("Index is not an integer number.") + return label + + if not label.is_integer: + raise TypeError("Idx object requires an integer label.") + + elif is_sequence(range): + if len(range) != 2: + raise ValueError(filldedent(""" + Idx range tuple must have length 2, but got %s""" % len(range))) + for bound in range: + if (bound.is_integer is False and bound is not S.Infinity + and bound is not S.NegativeInfinity): + raise TypeError("Idx object requires integer bounds.") + args = label, Tuple(*range) + elif isinstance(range, Expr): + if range is not S.Infinity and fuzzy_not(range.is_integer): + raise TypeError("Idx object requires an integer dimension.") + args = label, Tuple(0, range - 1) + elif range: + raise TypeError(filldedent(""" + The range must be an ordered iterable or + integer SymPy expression.""")) + else: + args = label, + + obj = Expr.__new__(cls, *args, **kw_args) + obj._assumptions["finite"] = True + obj._assumptions["real"] = True + return obj + + @property + def label(self): + """Returns the label (Integer or integer expression) of the Idx object. + + Examples + ======== + + >>> from sympy import Idx, Symbol + >>> x = Symbol('x', integer=True) + >>> Idx(x).label + x + >>> j = Symbol('j', integer=True) + >>> Idx(j).label + j + >>> Idx(j + 1).label + j + 1 + + """ + return self.args[0] + + @property + def lower(self): + """Returns the lower bound of the ``Idx``. + + Examples + ======== + + >>> from sympy import Idx + >>> Idx('j', 2).lower + 0 + >>> Idx('j', 5).lower + 0 + >>> Idx('j').lower is None + True + + """ + try: + return self.args[1][0] + except IndexError: + return + + @property + def upper(self): + """Returns the upper bound of the ``Idx``. + + Examples + ======== + + >>> from sympy import Idx + >>> Idx('j', 2).upper + 1 + >>> Idx('j', 5).upper + 4 + >>> Idx('j').upper is None + True + + """ + try: + return self.args[1][1] + except IndexError: + return + + def _sympystr(self, p): + return p.doprint(self.label) + + @property + def name(self): + return self.label.name if self.label.is_Symbol else str(self.label) + + @property + def free_symbols(self): + return {self} + + +@dispatch(Idx, Idx) +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = rhs if rhs.upper is None else rhs.upper + other_lower = rhs if rhs.lower is None else rhs.lower + + if lhs.lower is not None and (lhs.lower >= other_upper) == True: + return True + if lhs.upper is not None and (lhs.upper < other_lower) == True: + return False + return None + + +@dispatch(Idx, Number) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = rhs + other_lower = rhs + + if lhs.lower is not None and (lhs.lower >= other_upper) == True: + return True + if lhs.upper is not None and (lhs.upper < other_lower) == True: + return False + return None + + +@dispatch(Number, Idx) # type:ignore +def _eval_is_ge(lhs, rhs): # noqa:F811 + + other_upper = lhs + other_lower = lhs + + if rhs.upper is not None and (rhs.upper <= other_lower) == True: + return True + if rhs.lower is not None and (rhs.lower > other_upper) == True: + return False + return None diff --git a/phivenv/Lib/site-packages/sympy/tensor/tensor.py b/phivenv/Lib/site-packages/sympy/tensor/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..579e7c7a86c2a1f18ab889af32ce0053a729ff5f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/tensor.py @@ -0,0 +1,5265 @@ +""" +This module defines tensors with abstract index notation. + +The abstract index notation has been first formalized by Penrose. + +Tensor indices are formal objects, with a tensor type; there is no +notion of index range, it is only possible to assign the dimension, +used to trace the Kronecker delta; the dimension can be a Symbol. + +The Einstein summation convention is used. +The covariant indices are indicated with a minus sign in front of the index. + +For instance the tensor ``t = p(a)*A(b,c)*q(-c)`` has the index ``c`` +contracted. + +A tensor expression ``t`` can be called; called with its +indices in sorted order it is equal to itself: +in the above example ``t(a, b) == t``; +one can call ``t`` with different indices; ``t(c, d) == p(c)*A(d,a)*q(-a)``. + +The contracted indices are dummy indices, internally they have no name, +the indices being represented by a graph-like structure. + +Tensors are put in canonical form using ``canon_bp``, which uses +the Butler-Portugal algorithm for canonicalization using the monoterm +symmetries of the tensors. + +If there is a (anti)symmetric metric, the indices can be raised and +lowered when the tensor is put in canonical form. +""" + +from __future__ import annotations +from typing import Any +from functools import reduce +from math import prod + +from abc import abstractmethod, ABC +from collections import defaultdict +import operator +import itertools + +from sympy.core.numbers import (Integer, Rational) +from sympy.combinatorics import Permutation +from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, \ + bsgs_direct_product, canonicalize, riemann_bsgs +from sympy.core import Basic, Expr, sympify, Add, Mul, S +from sympy.core.cache import clear_cache +from sympy.core.containers import Tuple, Dict +from sympy.core.function import WildFunction +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol, symbols, Wild +from sympy.core.sympify import CantSympify, _sympify +from sympy.core.operations import AssocOp +from sympy.external.gmpy import SYMPY_INTS +from sympy.matrices import eye +from sympy.utilities.exceptions import (sympy_deprecation_warning, + SymPyDeprecationWarning, + ignore_warnings) +from sympy.utilities.decorator import memoize_property, deprecated +from sympy.utilities.iterables import sift + + +def deprecate_data(): + sympy_deprecation_warning( + """ + The data attribute of TensorIndexType is deprecated. Use The + replace_with_arrays() method instead. + """, + deprecated_since_version="1.4", + active_deprecations_target="deprecated-tensorindextype-attrs", + stacklevel=4, + ) + +def deprecate_fun_eval(): + sympy_deprecation_warning( + """ + The Tensor.fun_eval() method is deprecated. Use + Tensor.substitute_indices() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensor-fun-eval", + stacklevel=4, + ) + + +def deprecate_call(): + sympy_deprecation_warning( + """ + Calling a tensor like Tensor(*indices) is deprecated. Use + Tensor.substitute_indices() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensor-fun-eval", + stacklevel=4, + ) + + +class _IndexStructure(CantSympify): + """ + This class handles the indices (free and dummy ones). It contains the + algorithms to manage the dummy indices replacements and contractions of + free indices under multiplications of tensor expressions, as well as stuff + related to canonicalization sorting, getting the permutation of the + expression and so on. It also includes tools to get the ``TensorIndex`` + objects corresponding to the given index structure. + """ + + def __init__(self, free, dum, index_types, indices, canon_bp=False): + self.free = free + self.dum = dum + self.index_types = index_types + self.indices = indices + self._ext_rank = len(self.free) + 2*len(self.dum) + self.dum.sort(key=lambda x: x[0]) + + @staticmethod + def from_indices(*indices): + """ + Create a new ``_IndexStructure`` object from a list of ``indices``. + + Explanation + =========== + + ``indices`` ``TensorIndex`` objects, the indices. Contractions are + detected upon construction. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, _IndexStructure + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz) + >>> _IndexStructure.from_indices(m0, m1, -m1, m3) + _IndexStructure([(m0, 0), (m3, 3)], [(1, 2)], [Lorentz, Lorentz, Lorentz, Lorentz]) + """ + + free, dum = _IndexStructure._free_dum_from_indices(*indices) + index_types = [i.tensor_index_type for i in indices] + indices = _IndexStructure._replace_dummy_names(indices, free, dum) + return _IndexStructure(free, dum, index_types, indices) + + @staticmethod + def from_components_free_dum(components, free, dum): + index_types = [] + for component in components: + index_types.extend(component.index_types) + indices = _IndexStructure.generate_indices_from_free_dum_index_types(free, dum, index_types) + return _IndexStructure(free, dum, index_types, indices) + + @staticmethod + def _free_dum_from_indices(*indices): + """ + Convert ``indices`` into ``free``, ``dum`` for single component tensor. + + Explanation + =========== + + ``free`` list of tuples ``(index, pos, 0)``, + where ``pos`` is the position of index in + the list of indices formed by the component tensors + + ``dum`` list of tuples ``(pos_contr, pos_cov, 0, 0)`` + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, \ + _IndexStructure + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz) + >>> _IndexStructure._free_dum_from_indices(m0, m1, -m1, m3) + ([(m0, 0), (m3, 3)], [(1, 2)]) + """ + n = len(indices) + if n == 1: + return [(indices[0], 0)], [] + + # find the positions of the free indices and of the dummy indices + free = [True]*len(indices) + index_dict = {} + dum = [] + for i, index in enumerate(indices): + name = index.name + typ = index.tensor_index_type + contr = index.is_up + if (name, typ) in index_dict: + # found a pair of dummy indices + is_contr, pos = index_dict[(name, typ)] + # check consistency and update free + if is_contr: + if contr: + raise ValueError('two equal contravariant indices in slots %d and %d' %(pos, i)) + else: + free[pos] = False + free[i] = False + else: + if contr: + free[pos] = False + free[i] = False + else: + raise ValueError('two equal covariant indices in slots %d and %d' %(pos, i)) + if contr: + dum.append((i, pos)) + else: + dum.append((pos, i)) + else: + index_dict[(name, typ)] = index.is_up, i + + free = [(index, i) for i, index in enumerate(indices) if free[i]] + free.sort() + return free, dum + + def get_indices(self): + """ + Get a list of indices, creating new tensor indices to complete dummy indices. + """ + return self.indices[:] + + @staticmethod + def generate_indices_from_free_dum_index_types(free, dum, index_types): + indices = [None]*(len(free)+2*len(dum)) + for idx, pos in free: + indices[pos] = idx + + generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free) + for pos1, pos2 in dum: + typ1 = index_types[pos1] + indname = generate_dummy_name(typ1) + indices[pos1] = TensorIndex(indname, typ1, True) + indices[pos2] = TensorIndex(indname, typ1, False) + + return _IndexStructure._replace_dummy_names(indices, free, dum) + + @staticmethod + def _get_generator_for_dummy_indices(free): + cdt = defaultdict(int) + # if the free indices have names with dummy_name, start with an + # index higher than those for the dummy indices + # to avoid name collisions + for indx, ipos in free: + if indx.name.split('_')[0] == indx.tensor_index_type.dummy_name: + cdt[indx.tensor_index_type] = max(cdt[indx.tensor_index_type], int(indx.name.split('_')[1]) + 1) + + def dummy_name_gen(tensor_index_type): + nd = str(cdt[tensor_index_type]) + cdt[tensor_index_type] += 1 + return tensor_index_type.dummy_name + '_' + nd + + return dummy_name_gen + + @staticmethod + def _replace_dummy_names(indices, free, dum): + dum.sort(key=lambda x: x[0]) + new_indices = list(indices) + assert len(indices) == len(free) + 2*len(dum) + generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free) + for ipos1, ipos2 in dum: + typ1 = new_indices[ipos1].tensor_index_type + indname = generate_dummy_name(typ1) + new_indices[ipos1] = TensorIndex(indname, typ1, True) + new_indices[ipos2] = TensorIndex(indname, typ1, False) + return new_indices + + def get_free_indices(self) -> list[TensorIndex]: + """ + Get a list of free indices. + """ + # get sorted indices according to their position: + free = sorted(self.free, key=lambda x: x[1]) + return [i[0] for i in free] + + def __str__(self): + return "_IndexStructure({}, {}, {})".format(self.free, self.dum, self.index_types) + + def __repr__(self): + return self.__str__() + + def _get_sorted_free_indices_for_canon(self): + sorted_free = self.free[:] + sorted_free.sort(key=lambda x: x[0]) + return sorted_free + + def _get_sorted_dum_indices_for_canon(self): + return sorted(self.dum, key=lambda x: x[0]) + + def _get_lexicographically_sorted_index_types(self): + permutation = self.indices_canon_args()[0] + index_types = [None]*self._ext_rank + for i, it in enumerate(self.index_types): + index_types[permutation(i)] = it + return index_types + + def _get_lexicographically_sorted_indices(self): + permutation = self.indices_canon_args()[0] + indices = [None]*self._ext_rank + for i, it in enumerate(self.indices): + indices[permutation(i)] = it + return indices + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns a ``_IndexStructure`` instance corresponding to the permutation ``g``. + + Explanation + =========== + + ``g`` permutation corresponding to the tensor in the representation + used in canonicalization + + ``is_canon_bp`` if True, then ``g`` is the permutation + corresponding to the canonical form of the tensor + """ + sorted_free = [i[0] for i in self._get_sorted_free_indices_for_canon()] + lex_index_types = self._get_lexicographically_sorted_index_types() + lex_indices = self._get_lexicographically_sorted_indices() + nfree = len(sorted_free) + rank = self._ext_rank + dum = [[None]*2 for i in range((rank - nfree)//2)] + free = [] + + index_types = [None]*rank + indices = [None]*rank + for i in range(rank): + gi = g[i] + index_types[i] = lex_index_types[gi] + indices[i] = lex_indices[gi] + if gi < nfree: + ind = sorted_free[gi] + assert index_types[i] == sorted_free[gi].tensor_index_type + free.append((ind, i)) + else: + j = gi - nfree + idum, cov = divmod(j, 2) + if cov: + dum[idum][1] = i + else: + dum[idum][0] = i + dum = [tuple(x) for x in dum] + + return _IndexStructure(free, dum, index_types, indices) + + def indices_canon_args(self): + """ + Returns ``(g, dummies, msym, v)``, the entries of ``canonicalize`` + + See ``canonicalize`` in ``tensor_can.py`` in combinatorics module. + """ + # to be called after sorted_components + from sympy.combinatorics.permutations import _af_new + n = self._ext_rank + g = [None]*n + [n, n+1] + + # Converts the symmetry of the metric into msym from .canonicalize() + # method in the combinatorics module + def metric_symmetry_to_msym(metric): + if metric is None: + return None + sym = metric.symmetry + if sym == TensorSymmetry.fully_symmetric(2): + return 0 + if sym == TensorSymmetry.fully_symmetric(-2): + return 1 + return None + + # ordered indices: first the free indices, ordered by types + # then the dummy indices, ordered by types and contravariant before + # covariant + # g[position in tensor] = position in ordered indices + for i, (indx, ipos) in enumerate(self._get_sorted_free_indices_for_canon()): + g[ipos] = i + pos = len(self.free) + j = len(self.free) + dummies = [] + prev = None + a = [] + msym = [] + for ipos1, ipos2 in self._get_sorted_dum_indices_for_canon(): + g[ipos1] = j + g[ipos2] = j + 1 + j += 2 + typ = self.index_types[ipos1] + if typ != prev: + if a: + dummies.append(a) + a = [pos, pos + 1] + prev = typ + msym.append(metric_symmetry_to_msym(typ.metric)) + else: + a.extend([pos, pos + 1]) + pos += 2 + if a: + dummies.append(a) + + return _af_new(g), dummies, msym + + +def components_canon_args(components): + numtyp = [] + prev = None + for t in components: + if t == prev: + numtyp[-1][1] += 1 + else: + prev = t + numtyp.append([prev, 1]) + v = [] + for h, n in numtyp: + if h.comm in (0, 1): + comm = h.comm + else: + comm = TensorManager.get_comm(h.comm, h.comm) + v.append((h.symmetry.base, h.symmetry.generators, n, comm)) + return v + + +class _TensorDataLazyEvaluator(CantSympify): + """ + EXPERIMENTAL: do not rely on this class, it may change without deprecation + warnings in future versions of SymPy. + + Explanation + =========== + + This object contains the logic to associate components data to a tensor + expression. Components data are set via the ``.data`` property of tensor + expressions, is stored inside this class as a mapping between the tensor + expression and the ``ndarray``. + + Computations are executed lazily: whereas the tensor expressions can have + contractions, tensor products, and additions, components data are not + computed until they are accessed by reading the ``.data`` property + associated to the tensor expression. + """ + _substitutions_dict: dict[Any, Any] = {} + _substitutions_dict_tensmul: dict[Any, Any] = {} + + def __getitem__(self, key): + dat = self._get(key) + if dat is None: + return None + + from .array import NDimArray + if not isinstance(dat, NDimArray): + return dat + + if dat.rank() == 0: + return dat[()] + elif dat.rank() == 1 and len(dat) == 1: + return dat[0] + return dat + + def _get(self, key): + """ + Retrieve ``data`` associated with ``key``. + + Explanation + =========== + + This algorithm looks into ``self._substitutions_dict`` for all + ``TensorHead`` in the ``TensExpr`` (or just ``TensorHead`` if key is a + TensorHead instance). It reconstructs the components data that the + tensor expression should have by performing on components data the + operations that correspond to the abstract tensor operations applied. + + Metric tensor is handled in a different manner: it is pre-computed in + ``self._substitutions_dict_tensmul``. + """ + if key in self._substitutions_dict: + return self._substitutions_dict[key] + + if isinstance(key, TensorHead): + return None + + if isinstance(key, Tensor): + # special case to handle metrics. Metric tensors cannot be + # constructed through contraction by the metric, their + # components show if they are a matrix or its inverse. + signature = tuple([i.is_up for i in key.get_indices()]) + srch = (key.component,) + signature + if srch in self._substitutions_dict_tensmul: + return self._substitutions_dict_tensmul[srch] + array_list = [self.data_from_tensor(key)] + return self.data_contract_dum(array_list, key.dum, key.ext_rank) + + if isinstance(key, TensMul): + tensmul_args = key.args + if len(tensmul_args) == 1 and len(tensmul_args[0].components) == 1: + # special case to handle metrics. Metric tensors cannot be + # constructed through contraction by the metric, their + # components show if they are a matrix or its inverse. + signature = tuple([i.is_up for i in tensmul_args[0].get_indices()]) + srch = (tensmul_args[0].components[0],) + signature + if srch in self._substitutions_dict_tensmul: + return self._substitutions_dict_tensmul[srch] + #data_list = [self.data_from_tensor(i) for i in tensmul_args if isinstance(i, TensExpr)] + data_list = [self.data_from_tensor(i) if isinstance(i, Tensor) else i.data for i in tensmul_args if isinstance(i, TensExpr)] + coeff = prod([i for i in tensmul_args if not isinstance(i, TensExpr)]) + if all(i is None for i in data_list): + return None + if any(i is None for i in data_list): + raise ValueError("Mixing tensors with associated components "\ + "data with tensors without components data") + data_result = self.data_contract_dum(data_list, key.dum, key.ext_rank) + return coeff*data_result + + if isinstance(key, TensAdd): + data_list = [] + free_args_list = [] + for arg in key.args: + if isinstance(arg, TensExpr): + data_list.append(arg.data) + free_args_list.append([x[0] for x in arg.free]) + else: + data_list.append(arg) + free_args_list.append([]) + if all(i is None for i in data_list): + return None + if any(i is None for i in data_list): + raise ValueError("Mixing tensors with associated components "\ + "data with tensors without components data") + + sum_list = [] + from .array import permutedims + for data, free_args in zip(data_list, free_args_list): + if len(free_args) < 2: + sum_list.append(data) + else: + free_args_pos = {y: x for x, y in enumerate(free_args)} + axes = [free_args_pos[arg] for arg in key.free_args] + sum_list.append(permutedims(data, axes)) + return reduce(lambda x, y: x+y, sum_list) + + return None + + @staticmethod + def data_contract_dum(ndarray_list, dum, ext_rank): + from .array import tensorproduct, tensorcontraction, MutableDenseNDimArray + arrays = list(map(MutableDenseNDimArray, ndarray_list)) + prodarr = tensorproduct(*arrays) + return tensorcontraction(prodarr, *dum) + + def data_tensorhead_from_tensmul(self, data, tensmul, tensorhead): + """ + This method is used when assigning components data to a ``TensMul`` + object, it converts components data to a fully contravariant ndarray, + which is then stored according to the ``TensorHead`` key. + """ + if data is None: + return None + + return self._correct_signature_from_indices( + data, + tensmul.get_indices(), + tensmul.free, + tensmul.dum, + True) + + def data_from_tensor(self, tensor): + """ + This method corrects the components data to the right signature + (covariant/contravariant) using the metric associated with each + ``TensorIndexType``. + """ + tensorhead = tensor.component + + if tensorhead.data is None: + return None + + return self._correct_signature_from_indices( + tensorhead.data, + tensor.get_indices(), + tensor.free, + tensor.dum) + + def _assign_data_to_tensor_expr(self, key, data): + if isinstance(key, TensAdd): + raise ValueError('cannot assign data to TensAdd') + # here it is assumed that `key` is a `TensMul` instance. + if len(key.components) != 1: + raise ValueError('cannot assign data to TensMul with multiple components') + tensorhead = key.components[0] + newdata = self.data_tensorhead_from_tensmul(data, key, tensorhead) + return tensorhead, newdata + + def _check_permutations_on_data(self, tens, data): + from .array import permutedims + from .array.arrayop import Flatten + + if isinstance(tens, TensorHead): + rank = tens.rank + generators = tens.symmetry.generators + elif isinstance(tens, Tensor): + rank = tens.rank + generators = tens.components[0].symmetry.generators + elif isinstance(tens, TensorIndexType): + rank = tens.metric.rank + generators = tens.metric.symmetry.generators + + # Every generator is a permutation, check that by permuting the array + # by that permutation, the array will be the same, except for a + # possible sign change if the permutation admits it. + for gener in generators: + sign_change = +1 if (gener(rank) == rank) else -1 + data_swapped = data + last_data = data + permute_axes = list(map(gener, range(rank))) + # the order of a permutation is the number of times to get the + # identity by applying that permutation. + for i in range(gener.order()-1): + data_swapped = permutedims(data_swapped, permute_axes) + # if any value in the difference array is non-zero, raise an error: + if any(Flatten(last_data - sign_change*data_swapped)): + raise ValueError("Component data symmetry structure error") + last_data = data_swapped + + def __setitem__(self, key, value): + """ + Set the components data of a tensor object/expression. + + Explanation + =========== + + Components data are transformed to the all-contravariant form and stored + with the corresponding ``TensorHead`` object. If a ``TensorHead`` object + cannot be uniquely identified, it will raise an error. + """ + data = _TensorDataLazyEvaluator.parse_data(value) + self._check_permutations_on_data(key, data) + + # TensorHead and TensorIndexType can be assigned data directly, while + # TensMul must first convert data to a fully contravariant form, and + # assign it to its corresponding TensorHead single component. + if not isinstance(key, (TensorHead, TensorIndexType)): + key, data = self._assign_data_to_tensor_expr(key, data) + + if isinstance(key, TensorHead): + for dim, indextype in zip(data.shape, key.index_types): + if indextype.data is None: + raise ValueError("index type {} has no components data"\ + " associated (needed to raise/lower index)".format(indextype)) + if not indextype.dim.is_number: + continue + if dim != indextype.dim: + raise ValueError("wrong dimension of ndarray") + self._substitutions_dict[key] = data + + def __delitem__(self, key): + del self._substitutions_dict[key] + + def __contains__(self, key): + return key in self._substitutions_dict + + def add_metric_data(self, metric, data): + """ + Assign data to the ``metric`` tensor. The metric tensor behaves in an + anomalous way when raising and lowering indices. + + Explanation + =========== + + A fully covariant metric is the inverse transpose of the fully + contravariant metric (it is meant matrix inverse). If the metric is + symmetric, the transpose is not necessary and mixed + covariant/contravariant metrics are Kronecker deltas. + """ + # hard assignment, data should not be added to `TensorHead` for metric: + # the problem with `TensorHead` is that the metric is anomalous, i.e. + # raising and lowering the index means considering the metric or its + # inverse, this is not the case for other tensors. + self._substitutions_dict_tensmul[metric, True, True] = data + inverse_transpose = self.inverse_transpose_matrix(data) + # in symmetric spaces, the transpose is the same as the original matrix, + # the full covariant metric tensor is the inverse transpose, so this + # code will be able to handle non-symmetric metrics. + self._substitutions_dict_tensmul[metric, False, False] = inverse_transpose + # now mixed cases, these are identical to the unit matrix if the metric + # is symmetric. + m = data.tomatrix() + invt = inverse_transpose.tomatrix() + self._substitutions_dict_tensmul[metric, True, False] = m * invt + self._substitutions_dict_tensmul[metric, False, True] = invt * m + + @staticmethod + def _flip_index_by_metric(data, metric, pos): + from .array import tensorproduct, tensorcontraction + + mdim = metric.rank() + ddim = data.rank() + + if pos == 0: + data = tensorcontraction( + tensorproduct( + metric, + data + ), + (1, mdim+pos) + ) + else: + data = tensorcontraction( + tensorproduct( + data, + metric + ), + (pos, ddim) + ) + return data + + @staticmethod + def inverse_matrix(ndarray): + m = ndarray.tomatrix().inv() + return _TensorDataLazyEvaluator.parse_data(m) + + @staticmethod + def inverse_transpose_matrix(ndarray): + m = ndarray.tomatrix().inv().T + return _TensorDataLazyEvaluator.parse_data(m) + + @staticmethod + def _correct_signature_from_indices(data, indices, free, dum, inverse=False): + """ + Utility function to correct the values inside the components data + ndarray according to whether indices are covariant or contravariant. + + It uses the metric matrix to lower values of covariant indices. + """ + # change the ndarray values according covariantness/contravariantness of the indices + # use the metric + for i, indx in enumerate(indices): + if not indx.is_up and not inverse: + data = _TensorDataLazyEvaluator._flip_index_by_metric(data, indx.tensor_index_type.data, i) + elif not indx.is_up and inverse: + data = _TensorDataLazyEvaluator._flip_index_by_metric( + data, + _TensorDataLazyEvaluator.inverse_matrix(indx.tensor_index_type.data), + i + ) + return data + + @staticmethod + def _sort_data_axes(old, new): + from .array import permutedims + + new_data = old.data.copy() + + old_free = [i[0] for i in old.free] + new_free = [i[0] for i in new.free] + + for i in range(len(new_free)): + for j in range(i, len(old_free)): + if old_free[j] == new_free[i]: + old_free[i], old_free[j] = old_free[j], old_free[i] + new_data = permutedims(new_data, (i, j)) + break + return new_data + + @staticmethod + def add_rearrange_tensmul_parts(new_tensmul, old_tensmul): + def sorted_compo(): + return _TensorDataLazyEvaluator._sort_data_axes(old_tensmul, new_tensmul) + + _TensorDataLazyEvaluator._substitutions_dict[new_tensmul] = sorted_compo() + + @staticmethod + def parse_data(data): + """ + Transform ``data`` to array. The parameter ``data`` may + contain data in various formats, e.g. nested lists, SymPy ``Matrix``, + and so on. + + Examples + ======== + + >>> from sympy.tensor.tensor import _TensorDataLazyEvaluator + >>> _TensorDataLazyEvaluator.parse_data([1, 3, -6, 12]) + [1, 3, -6, 12] + + >>> _TensorDataLazyEvaluator.parse_data([[1, 2], [4, 7]]) + [[1, 2], [4, 7]] + """ + from .array import MutableDenseNDimArray + + if not isinstance(data, MutableDenseNDimArray): + if len(data) == 2 and hasattr(data[0], '__call__'): + data = MutableDenseNDimArray(data[0], data[1]) + else: + data = MutableDenseNDimArray(data) + return data + +_tensor_data_substitution_dict = _TensorDataLazyEvaluator() + + +class _TensorManager: + """ + Class to manage tensor properties. + + Notes + ===== + + Tensors belong to tensor commutation groups; each group has a label + ``comm``; there are predefined labels: + + ``0`` tensors commuting with any other tensor + + ``1`` tensors anticommuting among themselves + + ``2`` tensors not commuting, apart with those with ``comm=0`` + + Other groups can be defined using ``set_comm``; tensors in those + groups commute with those with ``comm=0``; by default they + do not commute with any other group. + """ + def __init__(self): + self._comm_init() + + def _comm_init(self): + self._comm = [{} for i in range(3)] + for i in range(3): + self._comm[0][i] = 0 + self._comm[i][0] = 0 + self._comm[1][1] = 1 + self._comm[2][1] = None + self._comm[1][2] = None + self._comm_symbols2i = {0:0, 1:1, 2:2} + self._comm_i2symbol = {0:0, 1:1, 2:2} + + @property + def comm(self): + return self._comm + + def comm_symbols2i(self, i): + """ + Get the commutation group number corresponding to ``i``. + + ``i`` can be a symbol or a number or a string. + + If ``i`` is not already defined its commutation group number + is set. + """ + if i not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[n][0] = 0 + self._comm[0][n] = 0 + self._comm_symbols2i[i] = n + self._comm_i2symbol[n] = i + return n + return self._comm_symbols2i[i] + + def comm_i2symbol(self, i): + """ + Returns the symbol corresponding to the commutation group number. + """ + return self._comm_i2symbol[i] + + def set_comm(self, i, j, c): + """ + Set the commutation parameter ``c`` for commutation groups ``i, j``. + + Parameters + ========== + + i, j : symbols representing commutation groups + + c : group commutation number + + Notes + ===== + + ``i, j`` can be symbols, strings or numbers, + apart from ``0, 1`` and ``2`` which are reserved respectively + for commuting, anticommuting tensors and tensors not commuting + with any other group apart with the commuting tensors. + For the remaining cases, use this method to set the commutation rules; + by default ``c=None``. + + The group commutation number ``c`` is assigned in correspondence + to the group commutation symbols; it can be + + 0 commuting + + 1 anticommuting + + None no commutation property + + Examples + ======== + + ``G`` and ``GH`` do not commute with themselves and commute with + each other; A is commuting. + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorManager, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz') + >>> i0,i1,i2,i3,i4 = tensor_indices('i0:5', Lorentz) + >>> A = TensorHead('A', [Lorentz]) + >>> G = TensorHead('G', [Lorentz], TensorSymmetry.no_symmetry(1), 'Gcomm') + >>> GH = TensorHead('GH', [Lorentz], TensorSymmetry.no_symmetry(1), 'GHcomm') + >>> TensorManager.set_comm('Gcomm', 'GHcomm', 0) + >>> (GH(i1)*G(i0)).canon_bp() + G(i0)*GH(i1) + >>> (G(i1)*G(i0)).canon_bp() + G(i1)*G(i0) + >>> (G(i1)*A(i0)).canon_bp() + A(i0)*G(i1) + """ + if c not in (0, 1, None): + raise ValueError('`c` can assume only the values 0, 1 or None') + + i = sympify(i) + j = sympify(j) + + if i not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[n][0] = 0 + self._comm[0][n] = 0 + self._comm_symbols2i[i] = n + self._comm_i2symbol[n] = i + if j not in self._comm_symbols2i: + n = len(self._comm) + self._comm.append({}) + self._comm[0][n] = 0 + self._comm[n][0] = 0 + self._comm_symbols2i[j] = n + self._comm_i2symbol[n] = j + ni = self._comm_symbols2i[i] + nj = self._comm_symbols2i[j] + self._comm[ni][nj] = c + self._comm[nj][ni] = c + + """ + Cached sympy functions (e.g. expand) may have cached the results of + expressions involving tensors, but those results may not be valid after + changing the commutation properties. To stay on the safe side, we clear + the cache of all functions. + """ + clear_cache() + + def set_comms(self, *args): + """ + Set the commutation group numbers ``c`` for symbols ``i, j``. + + Parameters + ========== + + args : sequence of ``(i, j, c)`` + """ + for i, j, c in args: + self.set_comm(i, j, c) + + def get_comm(self, i, j): + """ + Return the commutation parameter for commutation group numbers ``i, j`` + + see ``_TensorManager.set_comm`` + """ + return self._comm[i].get(j, 0 if i == 0 or j == 0 else None) + + def clear(self): + """ + Clear the TensorManager. + """ + self._comm_init() + + +TensorManager = _TensorManager() + + +class TensorIndexType(Basic): + """ + A TensorIndexType is characterized by its name and its metric. + + Parameters + ========== + + name : name of the tensor type + dummy_name : name of the head of dummy indices + dim : dimension, it can be a symbol or an integer or ``None`` + eps_dim : dimension of the epsilon tensor + metric_symmetry : integer that denotes metric symmetry or ``None`` for no metric + metric_name : string with the name of the metric tensor + + Attributes + ========== + + ``metric`` : the metric tensor + ``delta`` : ``Kronecker delta`` + ``epsilon`` : the ``Levi-Civita epsilon`` tensor + ``data`` : (deprecated) a property to add ``ndarray`` values, to work in a specified basis. + + Notes + ===== + + The possible values of the ``metric_symmetry`` parameter are: + + ``1`` : metric tensor is fully symmetric + ``0`` : metric tensor possesses no index symmetry + ``-1`` : metric tensor is fully antisymmetric + ``None``: there is no metric tensor (metric equals to ``None``) + + The metric is assumed to be symmetric by default. It can also be set + to a custom tensor by the ``.set_metric()`` method. + + If there is a metric the metric is used to raise and lower indices. + + In the case of non-symmetric metric, the following raising and + lowering conventions will be adopted: + + ``psi(a) = g(a, b)*psi(-b); chi(-a) = chi(b)*g(-b, -a)`` + + From these it is easy to find: + + ``g(-a, b) = delta(-a, b)`` + + where ``delta(-a, b) = delta(b, -a)`` is the ``Kronecker delta`` + (see ``TensorIndex`` for the conventions on indices). + For antisymmetric metrics there is also the following equality: + + ``g(a, -b) = -delta(a, -b)`` + + If there is no metric it is not possible to raise or lower indices; + e.g. the index of the defining representation of ``SU(N)`` + is 'covariant' and the conjugate representation is + 'contravariant'; for ``N > 2`` they are linearly independent. + + ``eps_dim`` is by default equal to ``dim``, if the latter is an integer; + else it can be assigned (for use in naive dimensional regularization); + if ``eps_dim`` is not an integer ``epsilon`` is ``None``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> Lorentz.metric + metric(Lorentz,Lorentz) + """ + + def __new__(cls, name, dummy_name=None, dim=None, eps_dim=None, + metric_symmetry=1, metric_name='metric', **kwargs): + if 'dummy_fmt' in kwargs: + dummy_fmt = kwargs['dummy_fmt'] + sympy_deprecation_warning( + f""" + The dummy_fmt keyword to TensorIndexType is deprecated. Use + dummy_name={dummy_fmt} instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-dummy-fmt", + ) + dummy_name = dummy_fmt + + if isinstance(name, str): + name = Symbol(name) + + if dummy_name is None: + dummy_name = str(name)[0] + if isinstance(dummy_name, str): + dummy_name = Symbol(dummy_name) + + if dim is None: + dim = Symbol("dim_" + dummy_name.name) + else: + dim = sympify(dim) + + if eps_dim is None: + eps_dim = dim + else: + eps_dim = sympify(eps_dim) + + metric_symmetry = sympify(metric_symmetry) + + if isinstance(metric_name, str): + metric_name = Symbol(metric_name) + + if 'metric' in kwargs: + SymPyDeprecationWarning( + """ + The 'metric' keyword argument to TensorIndexType is + deprecated. Use the 'metric_symmetry' keyword argument or the + TensorIndexType.set_metric() method instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-metric", + ) + metric = kwargs.get('metric') + if metric is not None: + if metric in (True, False, 0, 1): + metric_name = 'metric' + #metric_antisym = metric + else: + metric_name = metric.name + #metric_antisym = metric.antisym + + if metric: + metric_symmetry = -1 + else: + metric_symmetry = 1 + + obj = Basic.__new__(cls, name, dummy_name, dim, eps_dim, + metric_symmetry, metric_name) + + obj._autogenerated = [] + return obj + + @property + def name(self): + return self.args[0].name + + @property + def dummy_name(self): + return self.args[1].name + + @property + def dim(self): + return self.args[2] + + @property + def eps_dim(self): + return self.args[3] + + @memoize_property + def metric(self): + metric_symmetry = self.args[4] + metric_name = self.args[5] + if metric_symmetry is None: + return None + + if metric_symmetry == 0: + symmetry = TensorSymmetry.no_symmetry(2) + elif metric_symmetry == 1: + symmetry = TensorSymmetry.fully_symmetric(2) + elif metric_symmetry == -1: + symmetry = TensorSymmetry.fully_symmetric(-2) + + return TensorHead(metric_name, [self]*2, symmetry) + + @memoize_property + def delta(self): + return TensorHead('KD', [self]*2, TensorSymmetry.fully_symmetric(2)) + + @memoize_property + def epsilon(self): + if not isinstance(self.eps_dim, (SYMPY_INTS, Integer)): + return None + symmetry = TensorSymmetry.fully_symmetric(-self.eps_dim) + return TensorHead('Eps', [self]*self.eps_dim, symmetry) + + def set_metric(self, tensor): + self._metric = tensor + + def __lt__(self, other): + return self.name < other.name + + def __str__(self): + return self.name + + __repr__ = __str__ + + # Everything below this line is deprecated + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + # This assignment is a bit controversial, should metric components be assigned + # to the metric only or also to the TensorIndexType object? The advantage here + # is the ability to assign a 1D array and transform it to a 2D diagonal array. + from .array import MutableDenseNDimArray + + data = _TensorDataLazyEvaluator.parse_data(data) + if data.rank() > 2: + raise ValueError("data have to be of rank 1 (diagonal metric) or 2.") + if data.rank() == 1: + if self.dim.is_number: + nda_dim = data.shape[0] + if nda_dim != self.dim: + raise ValueError("Dimension mismatch") + + dim = data.shape[0] + newndarray = MutableDenseNDimArray.zeros(dim, dim) + for i, val in enumerate(data): + newndarray[i, i] = val + data = newndarray + dim1, dim2 = data.shape + if dim1 != dim2: + raise ValueError("Non-square matrix tensor.") + if self.dim.is_number: + if self.dim != dim1: + raise ValueError("Dimension mismatch") + _tensor_data_substitution_dict[self] = data + _tensor_data_substitution_dict.add_metric_data(self.metric, data) + with ignore_warnings(SymPyDeprecationWarning): + delta = self.get_kronecker_delta() + i1 = TensorIndex('i1', self) + i2 = TensorIndex('i2', self) + with ignore_warnings(SymPyDeprecationWarning): + delta(i1, -i2).data = _TensorDataLazyEvaluator.parse_data(eye(dim1)) + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + if self.metric in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self.metric] + + @deprecated( + """ + The TensorIndexType.get_kronecker_delta() method is deprecated. Use + the TensorIndexType.delta attribute instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-methods", + ) + def get_kronecker_delta(self): + sym2 = TensorSymmetry(get_symmetric_group_sgs(2)) + delta = TensorHead('KD', [self]*2, sym2) + return delta + + @deprecated( + """ + The TensorIndexType.get_epsilon() method is deprecated. Use + the TensorIndexType.epsilon attribute instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorindextype-methods", + ) + def get_epsilon(self): + if not isinstance(self._eps_dim, (SYMPY_INTS, Integer)): + return None + sym = TensorSymmetry(get_symmetric_group_sgs(self._eps_dim, 1)) + epsilon = TensorHead('Eps', [self]*self._eps_dim, sym) + return epsilon + + def _components_data_full_destroy(self): + """ + EXPERIMENTAL: do not rely on this API method. + + This destroys components data associated to the ``TensorIndexType``, if + any, specifically: + + * metric tensor data + * Kronecker tensor data + """ + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def delete_tensmul_data(key): + if key in _tensor_data_substitution_dict._substitutions_dict_tensmul: + del _tensor_data_substitution_dict._substitutions_dict_tensmul[key] + + # delete metric data: + delete_tensmul_data((self.metric, True, True)) + delete_tensmul_data((self.metric, True, False)) + delete_tensmul_data((self.metric, False, True)) + delete_tensmul_data((self.metric, False, False)) + + # delete delta tensor data: + delta = self.get_kronecker_delta() + if delta in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[delta] + + +class TensorIndex(Basic): + """ + Represents a tensor index + + Parameters + ========== + + name : name of the index, or ``True`` if you want it to be automatically assigned + tensor_index_type : ``TensorIndexType`` of the index + is_up : flag for contravariant index (is_up=True by default) + + Attributes + ========== + + ``name`` + ``tensor_index_type`` + ``is_up`` + + Notes + ===== + + Tensor indices are contracted with the Einstein summation convention. + + An index can be in contravariant or in covariant form; in the latter + case it is represented prepending a ``-`` to the index name. Adding + ``-`` to a covariant (is_up=False) index makes it contravariant. + + Dummy indices have a name with head given by + ``tensor_inde_type.dummy_name`` with underscore and a number. + + Similar to ``symbols`` multiple contravariant indices can be created + at once using ``tensor_indices(s, typ)``, where ``s`` is a string + of names. + + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorIndex, TensorHead, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> mu = TensorIndex('mu', Lorentz, is_up=False) + >>> nu, rho = tensor_indices('nu, rho', Lorentz) + >>> A = TensorHead('A', [Lorentz, Lorentz]) + >>> A(mu, nu) + A(-mu, nu) + >>> A(-mu, -rho) + A(mu, -rho) + >>> A(mu, -mu) + A(-L_0, L_0) + """ + def __new__(cls, name, tensor_index_type, is_up=True): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + elif name is True: + name = "_i{}".format(len(tensor_index_type._autogenerated)) + name_symbol = Symbol(name) + tensor_index_type._autogenerated.append(name_symbol) + else: + raise ValueError("invalid name") + + is_up = sympify(is_up) + return Basic.__new__(cls, name_symbol, tensor_index_type, is_up) + + @property + def name(self): + return self.args[0].name + + @property + def tensor_index_type(self): + return self.args[1] + + @property + def is_up(self): + return self.args[2] + + def _print(self): + s = self.name + if not self.is_up: + s = '-%s' % s + return s + + def __lt__(self, other): + return ((self.tensor_index_type, self.name) < + (other.tensor_index_type, other.name)) + + def __neg__(self): + t1 = TensorIndex(self.name, self.tensor_index_type, + (not self.is_up)) + return t1 + + +def tensor_indices(s, typ): + """ + Returns list of tensor indices given their names and their types. + + Parameters + ========== + + s : string of comma separated names of indices + + typ : ``TensorIndexType`` of the indices + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz) + """ + if isinstance(s, str): + a = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + + tilist = [TensorIndex(i, typ) for i in a] + if len(tilist) == 1: + return tilist[0] + return tilist + + +class TensorSymmetry(Basic): + """ + Monoterm symmetry of a tensor (i.e. any symmetric or anti-symmetric + index permutation). For the relevant terminology see ``tensor_can.py`` + section of the combinatorics module. + + Parameters + ========== + + bsgs : tuple ``(base, sgs)`` BSGS of the symmetry of the tensor + + Attributes + ========== + + ``base`` : base of the BSGS + ``generators`` : generators of the BSGS + ``rank`` : rank of the tensor + + Notes + ===== + + A tensor can have an arbitrary monoterm symmetry provided by its BSGS. + Multiterm symmetries, like the cyclic symmetry of the Riemann tensor + (i.e., Bianchi identity), are not covered. See combinatorics module for + information on how to generate BSGS for a general index permutation group. + Simple symmetries can be generated using built-in methods. + + See Also + ======== + + sympy.combinatorics.tensor_can.get_symmetric_group_sgs + + Examples + ======== + + Define a symmetric tensor of rank 2 + + >>> from sympy.tensor.tensor import TensorIndexType, TensorSymmetry, get_symmetric_group_sgs, TensorHead + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> sym = TensorSymmetry(get_symmetric_group_sgs(2)) + >>> T = TensorHead('T', [Lorentz]*2, sym) + + Note, that the same can also be done using built-in TensorSymmetry methods + + >>> sym2 = TensorSymmetry.fully_symmetric(2) + >>> sym == sym2 + True + """ + def __new__(cls, *args, **kw_args): + if len(args) == 1: + base, generators = args[0] + elif len(args) == 2: + base, generators = args + else: + raise TypeError("bsgs required, either two separate parameters or one tuple") + + if not isinstance(base, Tuple): + base = Tuple(*base) + if not isinstance(generators, Tuple): + generators = Tuple(*generators) + + return Basic.__new__(cls, base, generators, **kw_args) + + @property + def base(self): + return self.args[0] + + @property + def generators(self): + return self.args[1] + + @property + def rank(self): + return self.generators[0].size - 2 + + @classmethod + def fully_symmetric(cls, rank): + """ + Returns a fully symmetric (antisymmetric if ``rank``<0) + TensorSymmetry object for ``abs(rank)`` indices. + """ + if rank > 0: + bsgs = get_symmetric_group_sgs(rank, False) + elif rank < 0: + bsgs = get_symmetric_group_sgs(-rank, True) + elif rank == 0: + bsgs = ([], [Permutation(1)]) + return TensorSymmetry(bsgs) + + @classmethod + def direct_product(cls, *args): + """ + Returns a TensorSymmetry object that is being a direct product of + fully (anti-)symmetric index permutation groups. + + Notes + ===== + + Some examples for different values of ``(*args)``: + ``(1)`` vector, equivalent to ``TensorSymmetry.fully_symmetric(1)`` + ``(2)`` tensor with 2 symmetric indices, equivalent to ``.fully_symmetric(2)`` + ``(-2)`` tensor with 2 antisymmetric indices, equivalent to ``.fully_symmetric(-2)`` + ``(2, -2)`` tensor with the first 2 indices commuting and the last 2 anticommuting + ``(1, 1, 1)`` tensor with 3 indices without any symmetry + """ + base, sgs = [], [Permutation(1)] + for arg in args: + if arg > 0: + bsgs2 = get_symmetric_group_sgs(arg, False) + elif arg < 0: + bsgs2 = get_symmetric_group_sgs(-arg, True) + else: + continue + base, sgs = bsgs_direct_product(base, sgs, *bsgs2) + + return TensorSymmetry(base, sgs) + + @classmethod + def riemann(cls): + """ + Returns a monotorem symmetry of the Riemann tensor + """ + return TensorSymmetry(riemann_bsgs) + + @classmethod + def no_symmetry(cls, rank): + """ + TensorSymmetry object for ``rank`` indices with no symmetry + """ + return TensorSymmetry([], [Permutation(rank+1)]) + + +@deprecated( + """ + The tensorsymmetry() function is deprecated. Use the TensorSymmetry + constructor instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorsymmetry", +) +def tensorsymmetry(*args): + """ + Returns a ``TensorSymmetry`` object. This method is deprecated, use + ``TensorSymmetry.direct_product()`` or ``.riemann()`` instead. + + Explanation + =========== + + One can represent a tensor with any monoterm slot symmetry group + using a BSGS. + + ``args`` can be a BSGS + ``args[0]`` base + ``args[1]`` sgs + + Usually tensors are in (direct products of) representations + of the symmetric group; + ``args`` can be a list of lists representing the shapes of Young tableaux + + Notes + ===== + + For instance: + ``[[1]]`` vector + ``[[1]*n]`` symmetric tensor of rank ``n`` + ``[[n]]`` antisymmetric tensor of rank ``n`` + ``[[2, 2]]`` monoterm slot symmetry of the Riemann tensor + ``[[1],[1]]`` vector*vector + ``[[2],[1],[1]`` (antisymmetric tensor)*vector*vector + + Notice that with the shape ``[2, 2]`` we associate only the monoterm + symmetries of the Riemann tensor; this is an abuse of notation, + since the shape ``[2, 2]`` corresponds usually to the irreducible + representation characterized by the monoterm symmetries and by the + cyclic symmetry. + """ + from sympy.combinatorics import Permutation + + def tableau2bsgs(a): + if len(a) == 1: + # antisymmetric vector + n = a[0] + bsgs = get_symmetric_group_sgs(n, 1) + else: + if all(x == 1 for x in a): + # symmetric vector + n = len(a) + bsgs = get_symmetric_group_sgs(n) + elif a == [2, 2]: + bsgs = riemann_bsgs + else: + raise NotImplementedError + return bsgs + + if not args: + return TensorSymmetry(Tuple(), Tuple(Permutation(1))) + + if len(args) == 2 and isinstance(args[1][0], Permutation): + return TensorSymmetry(args) + base, sgs = tableau2bsgs(args[0]) + for a in args[1:]: + basex, sgsx = tableau2bsgs(a) + base, sgs = bsgs_direct_product(base, sgs, basex, sgsx) + return TensorSymmetry(Tuple(base, sgs)) + +@deprecated( + "TensorType is deprecated. Use tensor_heads() instead.", + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensortype", +) +class TensorType(Basic): + """ + Class of tensor types. Deprecated, use tensor_heads() instead. + + Parameters + ========== + + index_types : list of ``TensorIndexType`` of the tensor indices + symmetry : ``TensorSymmetry`` of the tensor + + Attributes + ========== + + ``index_types`` + ``symmetry`` + ``types`` : list of ``TensorIndexType`` without repetitions + """ + is_commutative = False + + def __new__(cls, index_types, symmetry, **kw_args): + assert symmetry.rank == len(index_types) + obj = Basic.__new__(cls, Tuple(*index_types), symmetry, **kw_args) + return obj + + @property + def index_types(self): + return self.args[0] + + @property + def symmetry(self): + return self.args[1] + + @property + def types(self): + return sorted(set(self.index_types), key=lambda x: x.name) + + def __str__(self): + return 'TensorType(%s)' % ([str(x) for x in self.index_types]) + + def __call__(self, s, comm=0): + """ + Return a TensorHead object or a list of TensorHead objects. + + Parameters + ========== + + s : name or string of names. + + comm : Commutation group. + + see ``_TensorManager.set_comm`` + """ + if isinstance(s, str): + names = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + if len(names) == 1: + return TensorHead(names[0], self.index_types, self.symmetry, comm) + else: + return [TensorHead(name, self.index_types, self.symmetry, comm) for name in names] + + +@deprecated( + """ + The tensorhead() function is deprecated. Use tensor_heads() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-tensorhead", +) +def tensorhead(name, typ, sym=None, comm=0): + """ + Function generating tensorhead(s). This method is deprecated, + use TensorHead constructor or tensor_heads() instead. + + Parameters + ========== + + name : name or sequence of names (as in ``symbols``) + + typ : index types + + sym : same as ``*args`` in ``tensorsymmetry`` + + comm : commutation group number + see ``_TensorManager.set_comm`` + """ + if sym is None: + sym = [[1] for i in range(len(typ))] + with ignore_warnings(SymPyDeprecationWarning): + sym = tensorsymmetry(*sym) + return TensorHead(name, typ, sym, comm) + + +class TensorHead(Basic): + """ + Tensor head of the tensor. + + Parameters + ========== + + name : name of the tensor + index_types : list of TensorIndexType + symmetry : TensorSymmetry of the tensor + comm : commutation group number + + Attributes + ========== + + ``name`` + ``index_types`` + ``rank`` : total number of indices + ``symmetry`` + ``comm`` : commutation group + + Notes + ===== + + Similar to ``symbols`` multiple TensorHeads can be created using + ``tensorhead(s, typ, sym=None, comm=0)`` function, where ``s`` + is the string of names and ``sym`` is the monoterm tensor symmetry + (see ``tensorsymmetry``). + + A ``TensorHead`` belongs to a commutation group, defined by a + symbol on number ``comm`` (see ``_TensorManager.set_comm``); + tensors in a commutation group have the same commutation properties; + by default ``comm`` is ``0``, the group of the commuting tensors. + + Examples + ======== + + Define a fully antisymmetric tensor of rank 2: + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> asym2 = TensorSymmetry.fully_symmetric(-2) + >>> A = TensorHead('A', [Lorentz, Lorentz], asym2) + + Examples with ndarray values, the components data assigned to the + ``TensorHead`` object are assumed to be in a fully-contravariant + representation. In case it is necessary to assign components data which + represents the values of a non-fully covariant tensor, see the other + examples. + + >>> from sympy.tensor.tensor import tensor_indices + >>> from sympy import diag + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i0, i1 = tensor_indices('i0:2', Lorentz) + + Specify a replacement dictionary to keep track of the arrays to use for + replacements in the tensorial expression. The ``TensorIndexType`` is + associated to the metric used for contractions (in fully covariant form): + + >>> repl = {Lorentz: diag(1, -1, -1, -1)} + + Let's see some examples of working with components with the electromagnetic + tensor: + + >>> from sympy import symbols + >>> Ex, Ey, Ez, Bx, By, Bz = symbols('E_x E_y E_z B_x B_y B_z') + >>> c = symbols('c', positive=True) + + Let's define `F`, an antisymmetric tensor: + + >>> F = TensorHead('F', [Lorentz, Lorentz], asym2) + + Let's update the dictionary to contain the matrix to use in the + replacements: + + >>> repl.update({F(-i0, -i1): [ + ... [0, Ex/c, Ey/c, Ez/c], + ... [-Ex/c, 0, -Bz, By], + ... [-Ey/c, Bz, 0, -Bx], + ... [-Ez/c, -By, Bx, 0]]}) + + Now it is possible to retrieve the contravariant form of the Electromagnetic + tensor: + + >>> F(i0, i1).replace_with_arrays(repl, [i0, i1]) + [[0, -E_x/c, -E_y/c, -E_z/c], [E_x/c, 0, -B_z, B_y], [E_y/c, B_z, 0, -B_x], [E_z/c, -B_y, B_x, 0]] + + and the mixed contravariant-covariant form: + + >>> F(i0, -i1).replace_with_arrays(repl, [i0, -i1]) + [[0, E_x/c, E_y/c, E_z/c], [E_x/c, 0, B_z, -B_y], [E_y/c, -B_z, 0, B_x], [E_z/c, B_y, -B_x, 0]] + + Energy-momentum of a particle may be represented as: + + >>> from sympy import symbols + >>> P = TensorHead('P', [Lorentz], TensorSymmetry.no_symmetry(1)) + >>> E, px, py, pz = symbols('E p_x p_y p_z', positive=True) + >>> repl.update({P(i0): [E, px, py, pz]}) + + The contravariant and covariant components are, respectively: + + >>> P(i0).replace_with_arrays(repl, [i0]) + [E, p_x, p_y, p_z] + >>> P(-i0).replace_with_arrays(repl, [-i0]) + [E, -p_x, -p_y, -p_z] + + The contraction of a 1-index tensor by itself: + + >>> expr = P(i0)*P(-i0) + >>> expr.replace_with_arrays(repl, []) + E**2 - p_x**2 - p_y**2 - p_z**2 + """ + is_commutative = False + + def __new__(cls, name, index_types, symmetry=None, comm=0): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + else: + raise ValueError("invalid name") + + if symmetry is None: + symmetry = TensorSymmetry.no_symmetry(len(index_types)) + else: + assert symmetry.rank == len(index_types) + + obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry, sympify(comm)) + return obj + + @property + def name(self): + return self.args[0].name + + @property + def index_types(self): + return list(self.args[1]) + + @property + def symmetry(self): + return self.args[2] + + @property + def comm(self): + return TensorManager.comm_symbols2i(self.args[3]) + + @property + def rank(self): + return len(self.index_types) + + def __lt__(self, other): + return (self.name, self.index_types) < (other.name, other.index_types) + + def commutes_with(self, other): + """ + Returns ``0`` if ``self`` and ``other`` commute, ``1`` if they anticommute. + + Returns ``None`` if ``self`` and ``other`` neither commute nor anticommute. + """ + r = TensorManager.get_comm(self.comm, other.comm) + return r + + def _print(self): + return '%s(%s)' %(self.name, ','.join([str(x) for x in self.index_types])) + + def __call__(self, *indices, **kw_args): + """ + Returns a tensor with indices. + + Explanation + =========== + + There is a special behavior in case of indices denoted by ``True``, + they are considered auto-matrix indices, their slots are automatically + filled, and confer to the tensor the behavior of a matrix or vector + upon multiplication with another tensor containing auto-matrix indices + of the same ``TensorIndexType``. This means indices get summed over the + same way as in matrix multiplication. For matrix behavior, define two + auto-matrix indices, for vector behavior define just one. + + Indices can also be strings, in which case the attribute + ``index_types`` is used to convert them to proper ``TensorIndex``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, TensorHead + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b = tensor_indices('a,b', Lorentz) + >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + >>> t = A(a, -b) + >>> t + A(a, -b) + + """ + + updated_indices = [] + for idx, typ in zip(indices, self.index_types): + if isinstance(idx, str): + idx = idx.strip().replace(" ", "") + if idx.startswith('-'): + updated_indices.append(TensorIndex(idx[1:], typ, + is_up=False)) + else: + updated_indices.append(TensorIndex(idx, typ)) + else: + updated_indices.append(idx) + + updated_indices += indices[len(updated_indices):] + + tensor = Tensor(self, updated_indices, **kw_args) + return tensor.doit() + + # Everything below this line is deprecated + + def __pow__(self, other): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No power on abstract tensors.") + from .array import tensorproduct, tensorcontraction + metrics = [_.data for _ in self.index_types] + + marray = self.data + marraydim = marray.rank() + for metric in metrics: + marray = tensorproduct(marray, metric, marray) + marray = tensorcontraction(marray, (0, marraydim), (marraydim+1, marraydim+2)) + + return marray ** (other * S.Half) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data.__iter__() + + def _components_data_full_destroy(self): + """ + EXPERIMENTAL: do not rely on this API method. + + Destroy components data associated to the ``TensorHead`` object, this + checks for attached components data, and destroys components data too. + """ + # do not garbage collect Kronecker tensor (it should be done by + # ``TensorIndexType`` garbage collection) + deprecate_data() + if self.name == "KD": + return + + # the data attached to a tensor must be deleted only by the TensorHead + # destructor. If the TensorHead is deleted, it means that there are no + # more instances of that tensor anywhere. + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + +def tensor_heads(s, index_types, symmetry=None, comm=0): + """ + Returns a sequence of TensorHeads from a string `s` + """ + if isinstance(s, str): + names = [x.name for x in symbols(s, seq=True)] + else: + raise ValueError('expecting a string') + + thlist = [TensorHead(name, index_types, symmetry, comm) for name in names] + if len(thlist) == 1: + return thlist[0] + return thlist + + +class TensExpr(Expr, ABC): + """ + Abstract base class for tensor expressions + + Notes + ===== + + A tensor expression is an expression formed by tensors; + currently the sums of tensors are distributed. + + A ``TensExpr`` can be a ``TensAdd`` or a ``TensMul``. + + ``TensMul`` objects are formed by products of component tensors, + and include a coefficient, which is a SymPy expression. + + + In the internal representation contracted indices are represented + by ``(ipos1, ipos2, icomp1, icomp2)``, where ``icomp1`` is the position + of the component tensor with contravariant index, ``ipos1`` is the + slot which the index occupies in that component tensor. + + Contracted indices are therefore nameless in the internal representation. + """ + + _op_priority = 12.0 + is_commutative = False + + def __neg__(self): + return self*S.NegativeOne + + def __abs__(self): + raise NotImplementedError + + def __add__(self, other): + return TensAdd(self, other).doit(deep=False) + + def __radd__(self, other): + return TensAdd(other, self).doit(deep=False) + + def __sub__(self, other): + return TensAdd(self, -other).doit(deep=False) + + def __rsub__(self, other): + return TensAdd(other, -self).doit(deep=False) + + def __mul__(self, other): + """ + Multiply two tensors using Einstein summation convention. + + Explanation + =========== + + If the two tensors have an index in common, one contravariant + and the other covariant, in their product the indices are summed + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t1 = p(m0) + >>> t2 = q(-m0) + >>> t1*t2 + p(L_0)*q(-L_0) + """ + return TensMul(self, other).doit(deep=False) + + def __rmul__(self, other): + return TensMul(other, self).doit(deep=False) + + def __truediv__(self, other): + other = _sympify(other) + if isinstance(other, TensExpr): + raise ValueError('cannot divide by a tensor') + return TensMul(self, S.One/other).doit(deep=False) + + def __rtruediv__(self, other): + raise ValueError('cannot divide by a tensor') + + def __pow__(self, other): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No power without ndarray data.") + from .array import tensorproduct, tensorcontraction + free = self.free + marray = self.data + mdim = marray.rank() + for metric in free: + marray = tensorcontraction( + tensorproduct( + marray, + metric[0].tensor_index_type.data, + marray), + (0, mdim), (mdim+1, mdim+2) + ) + return marray ** (other * S.Half) + + def __rpow__(self, other): + raise NotImplementedError + + @property + @abstractmethod + def nocoeff(self): + raise NotImplementedError("abstract method") + + @property + @abstractmethod + def coeff(self): + raise NotImplementedError("abstract method") + + @abstractmethod + def get_indices(self): + raise NotImplementedError("abstract method") + + @abstractmethod + def get_free_indices(self) -> list[TensorIndex]: + raise NotImplementedError("abstract method") + + @abstractmethod + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + raise NotImplementedError("abstract method") + + def fun_eval(self, *index_tuples): + deprecate_fun_eval() + return self.substitute_indices(*index_tuples) + + def get_matrix(self): + """ + DEPRECATED: do not use. + + Returns ndarray components data as a matrix, if components data are + available and ndarray dimension does not exceed 2. + """ + from sympy.matrices.dense import Matrix + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if 0 < self.rank <= 2: + rows = self.data.shape[0] + columns = self.data.shape[1] if self.rank == 2 else 1 + if self.rank == 2: + mat_list = [] * rows + for i in range(rows): + mat_list.append([]) + for j in range(columns): + mat_list[i].append(self[i, j]) + else: + mat_list = [None] * rows + for i in range(rows): + mat_list[i] = self[i] + return Matrix(mat_list) + else: + raise NotImplementedError( + "missing multidimensional reduction to matrix.") + + @staticmethod + def _get_indices_permutation(indices1, indices2): + return [indices1.index(i) for i in indices2] + + def _get_free_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_free_indices_set()) + return indset + + def _get_dummy_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_dummy_indices_set()) + return indset + + def _get_indices_set(self): + indset = set() + for arg in self.args: + if isinstance(arg, TensExpr): + indset.update(arg._get_indices_set()) + return indset + + @property + def _iterate_dummy_indices(self): + dummy_set = self._get_dummy_indices_set() + + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + if expr in dummy_set: + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @property + def _iterate_free_indices(self): + free_set = self._get_free_indices_set() + + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + if expr in free_set: + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @property + def _iterate_indices(self): + def recursor(expr, pos): + if isinstance(expr, TensorIndex): + yield (expr, pos) + elif isinstance(expr, (Tuple, TensExpr)): + for p, arg in enumerate(expr.args): + yield from recursor(arg, pos+(p,)) + + return recursor(self, ()) + + @staticmethod + def _contract_and_permute_with_metric(metric, array, pos, dim): + # TODO: add possibility of metric after (spinors) + from .array import tensorcontraction, tensorproduct, permutedims + + array = tensorcontraction(tensorproduct(metric, array), (1, 2+pos)) + permu = list(range(dim)) + permu[0], permu[pos] = permu[pos], permu[0] + return permutedims(array, permu) + + @staticmethod + def _match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict): + from .array import permutedims + + index_types1 = [i.tensor_index_type for i in free_ind1] + + # Check if variance of indices needs to be fixed: + pos2up = [] + pos2down = [] + free2remaining = free_ind2[:] + for pos1, index1 in enumerate(free_ind1): + if index1 in free2remaining: + pos2 = free2remaining.index(index1) + free2remaining[pos2] = None + continue + if -index1 in free2remaining: + pos2 = free2remaining.index(-index1) + free2remaining[pos2] = None + free_ind2[pos2] = index1 + if index1.is_up: + pos2up.append(pos2) + else: + pos2down.append(pos2) + else: + index2 = free2remaining[pos1] + if index2 is None: + raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2)) + free2remaining[pos1] = None + free_ind2[pos1] = index1 + if index1.is_up ^ index2.is_up: + if index1.is_up: + pos2up.append(pos1) + else: + pos2down.append(pos1) + + if len(set(free_ind1) & set(free_ind2)) < len(free_ind1): + raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2)) + + # Raise indices: + for pos in pos2up: + index_type_pos = index_types1[pos] + if index_type_pos not in replacement_dict: + raise ValueError("No metric provided to lower index") + metric = replacement_dict[index_type_pos] + metric_inverse = _TensorDataLazyEvaluator.inverse_matrix(metric) + array = TensExpr._contract_and_permute_with_metric(metric_inverse, array, pos, len(free_ind1)) + # Lower indices: + for pos in pos2down: + index_type_pos = index_types1[pos] + if index_type_pos not in replacement_dict: + raise ValueError("No metric provided to lower index") + metric = replacement_dict[index_type_pos] + array = TensExpr._contract_and_permute_with_metric(metric, array, pos, len(free_ind1)) + + if free_ind1: + permutation = TensExpr._get_indices_permutation(free_ind2, free_ind1) + array = permutedims(array, permutation) + + if hasattr(array, "rank") and array.rank() == 0: + array = array[()] + + return free_ind2, array + + def replace_with_arrays(self, replacement_dict, indices=None): + """ + Replace the tensorial expressions with arrays. The final array will + correspond to the N-dimensional array with indices arranged according + to ``indices``. + + Parameters + ========== + + replacement_dict + dictionary containing the replacement rules for tensors. + indices + the index order with respect to which the array is read. The + original index order will be used if no value is passed. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices + >>> from sympy.tensor.tensor import TensorHead + >>> from sympy import symbols, diag + + >>> L = TensorIndexType("L") + >>> i, j = tensor_indices("i j", L) + >>> A = TensorHead("A", [L]) + >>> A(i).replace_with_arrays({A(i): [1, 2]}, [i]) + [1, 2] + + Since 'indices' is optional, we can also call replace_with_arrays by + this way if no specific index order is needed: + + >>> A(i).replace_with_arrays({A(i): [1, 2]}) + [1, 2] + + >>> expr = A(i)*A(j) + >>> expr.replace_with_arrays({A(i): [1, 2]}) + [[1, 2], [2, 4]] + + For contractions, specify the metric of the ``TensorIndexType``, which + in this case is ``L``, in its covariant form: + + >>> expr = A(i)*A(-i) + >>> expr.replace_with_arrays({A(i): [1, 2], L: diag(1, -1)}) + -3 + + Symmetrization of an array: + + >>> H = TensorHead("H", [L, L]) + >>> a, b, c, d = symbols("a b c d") + >>> expr = H(i, j)/2 + H(j, i)/2 + >>> expr.replace_with_arrays({H(i, j): [[a, b], [c, d]]}) + [[a, b/2 + c/2], [b/2 + c/2, d]] + + Anti-symmetrization of an array: + + >>> expr = H(i, j)/2 - H(j, i)/2 + >>> repl = {H(i, j): [[a, b], [c, d]]} + >>> expr.replace_with_arrays(repl) + [[0, b/2 - c/2], [-b/2 + c/2, 0]] + + The same expression can be read as the transpose by inverting ``i`` and + ``j``: + + >>> expr.replace_with_arrays(repl, [j, i]) + [[0, -b/2 + c/2], [b/2 - c/2, 0]] + """ + from .array import Array + + indices = indices or [] + remap = {k.args[0] if k.is_up else -k.args[0]: k for k in self.get_free_indices()} + for i, index in enumerate(indices): + if isinstance(index, (Symbol, Mul)): + if index in remap: + indices[i] = remap[index] + else: + indices[i] = -remap[-index] + + replacement_dict = {tensor: Array(array) for tensor, array in replacement_dict.items()} + + # Check dimensions of replaced arrays: + for tensor, array in replacement_dict.items(): + if isinstance(tensor, TensorIndexType): + expected_shape = [tensor.dim for i in range(2)] + else: + expected_shape = [index_type.dim for index_type in tensor.index_types] + if len(expected_shape) != array.rank() or (not all(dim1 == dim2 if + dim1.is_number else True for dim1, dim2 in zip(expected_shape, + array.shape))): + raise ValueError("shapes for tensor %s expected to be %s, "\ + "replacement array shape is %s" % (tensor, expected_shape, + array.shape)) + + ret_indices, array = self._extract_data(replacement_dict) + + last_indices, array = self._match_indices_with_other_tensor(array, indices, ret_indices, replacement_dict) + return array + + def _check_add_Sum(self, expr, index_symbols): + from sympy.concrete.summations import Sum + indices = self.get_indices() + dum = self.dum + sum_indices = [ (index_symbols[i], 0, + indices[i].tensor_index_type.dim-1) for i, j in dum] + if sum_indices: + expr = Sum(expr, *sum_indices) + return expr + + def _expand_partial_derivative(self): + # simply delegate the _expand_partial_derivative() to + # its arguments to expand a possibly found PartialDerivative + return self.func(*[ + a._expand_partial_derivative() + if isinstance(a, TensExpr) else a + for a in self.args]) + + def _matches_simple(self, expr, repl_dict=None, old=False): + """ + Matches assuming there are no wild objects in self. + """ + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if not isinstance(expr, TensExpr): + if len(self.get_free_indices()) > 0: + #self has indices, but expr does not. + return None + elif set(self.get_free_indices()) != set(expr.get_free_indices()): + #If there are no wilds and the free indices are not the same, they cannot match. + return None + + if canon_bp(self - expr) == S.Zero: + return repl_dict + else: + return None + + +class TensAdd(TensExpr, AssocOp): + """ + Sum of tensors. + + Parameters + ========== + + free_args : list of the free indices + + Attributes + ========== + + ``args`` : tuple of addends + ``rank`` : rank of the tensor + ``free_args`` : list of the free indices in sorted order + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_heads, tensor_indices + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b = tensor_indices('a,b', Lorentz) + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(a) + q(a); t + p(a) + q(a) + + Examples with components data added to the tensor expression: + + >>> from sympy import symbols, diag + >>> x, y, z, t = symbols("x y z t") + >>> repl = {} + >>> repl[Lorentz] = diag(1, -1, -1, -1) + >>> repl[p(a)] = [1, 2, 3, 4] + >>> repl[q(a)] = [x, y, z, t] + + The following are: 2**2 - 3**2 - 2**2 - 7**2 ==> -58 + + >>> expr = p(a) + q(a) + >>> expr.replace_with_arrays(repl, [a]) + [x + 1, y + 2, z + 3, t + 4] + """ + + def __new__(cls, *args, **kw_args): + args = [_sympify(x) for x in args if x] + args = TensAdd._tensAdd_flatten(args) + args.sort(key=default_sort_key) + if not args: + return S.Zero + if len(args) == 1: + return args[0] + + return Basic.__new__(cls, *args, **kw_args) + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + def get_free_indices(self) -> list[TensorIndex]: + return self.free_indices + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + newargs = [arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args] + return self.func(*newargs) + + @memoize_property + def rank(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].rank + else: + return 0 + + @memoize_property + def free_args(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].free_args + else: + return [] + + @memoize_property + def free_indices(self): + if isinstance(self.args[0], TensExpr): + return self.args[0].get_free_indices() + else: + return set() + + def doit(self, **hints) -> Expr: + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args # type: ignore + + # if any of the args are zero (after doit), drop them. Otherwise, _tensAdd_check will complain about non-matching indices, even though the TensAdd is correctly formed. + args = [arg for arg in args if arg != S.Zero] + + if len(args) == 0: + return S.Zero + elif len(args) == 1: + return args[0] + + # now check that all addends have the same indices: + TensAdd._tensAdd_check(args) + + # Collect terms appearing more than once, differing by their coefficients: + args = TensAdd._tensAdd_collect_terms(args) + + # collect canonicalized terms + def sort_key(t): + if not isinstance(t, TensExpr): + return [], [], [] + if hasattr(t, "_index_structure") and hasattr(t, "components"): + x = get_index_structure(t) + return t.components, x.free, x.dum + return [], [], [] + args.sort(key=sort_key) + + if not args: + return S.Zero + # it there is only a component tensor return it + if len(args) == 1: + return args[0] + + obj = self.func(*args) + return obj + + @staticmethod + def _tensAdd_flatten(args): + # flatten TensAdd, coerce terms which are not tensors to tensors + a = [] + for x in args: + if isinstance(x, (Add, TensAdd)): + a.extend(list(x.args)) + else: + a.append(x) + args = [x for x in a if x.coeff] + return args + + @staticmethod + def _tensAdd_check(args): + # check that all addends have the same free indices + + def get_indices_set(x: Expr) -> set[TensorIndex]: + if isinstance(x, TensExpr): + return set(x.get_free_indices()) + return set() + + indices0 = get_indices_set(args[0]) + list_indices = [get_indices_set(arg) for arg in args[1:]] + if not all(x == indices0 for x in list_indices): + raise ValueError('all tensors must have the same indices') + + @staticmethod + def _tensAdd_collect_terms(args): + # collect TensMul terms differing at most by their coefficient + terms_dict = defaultdict(list) + scalars = S.Zero + if isinstance(args[0], TensExpr): + free_indices = set(args[0].get_free_indices()) + else: + free_indices = set() + + for arg in args: + if not isinstance(arg, TensExpr): + if free_indices != set(): + raise ValueError("wrong valence") + scalars += arg + continue + if free_indices != set(arg.get_free_indices()): + raise ValueError("wrong valence") + # TODO: what is the part which is not a coeff? + # needs an implementation similar to .as_coeff_Mul() + terms_dict[arg.nocoeff].append(arg.coeff) + + new_args = [TensMul(Add(*coeff), t).doit(deep=False) for t, coeff in terms_dict.items() if Add(*coeff) != 0] + if isinstance(scalars, Add): + new_args = list(scalars.args) + new_args + elif scalars != 0: + new_args = [scalars] + new_args + return new_args + + def get_indices(self): + indices = [] + for arg in self.args: + indices.extend([i for i in get_indices(arg) if i not in indices]) + return indices + + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + index_tuples = list(zip(free_args, indices)) + a = [x.func(*x.substitute_indices(*index_tuples).args) for x in self.args] + res = TensAdd(*a).doit(deep=False) + return res + + def canon_bp(self): + """ + Canonicalize using the Butler-Portugal algorithm for canonicalization + under monoterm symmetries. + """ + expr = self.expand() + if isinstance(expr, self.func): + args = [canon_bp(x) for x in expr.args] + res = TensAdd(*args).doit(deep=False) + return res + else: + return canon_bp(expr) + + def equals(self, other): + other = _sympify(other) + if isinstance(other, TensMul) and other.coeff == 0: + return all(x.coeff == 0 for x in self.args) + if isinstance(other, TensExpr): + if self.rank != other.rank: + return False + if isinstance(other, TensAdd): + if set(self.args) != set(other.args): + return False + else: + return True + t = self - other + if not isinstance(t, TensExpr): + return t == 0 + else: + if isinstance(t, TensMul): + return t.coeff == 0 + else: + return all(x.coeff == 0 for x in t.args) + + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def contract_delta(self, delta): + args = [x.contract_delta(delta) if isinstance(x, TensExpr) else x for x in self.args] + t = TensAdd(*args).doit(deep=False) + return canon_bp(t) + + def contract_metric(self, g): + """ + Raise or lower indices with the metric ``g``. + + Parameters + ========== + + g : metric + + contract_all : if True, eliminate all ``g`` which are contracted + + Notes + ===== + + see the ``TensorIndexType`` docstring for the contraction conventions + """ + + args = [contract_metric(x, g) for x in self.args] + t = TensAdd(*args).doit(deep=False) + return canon_bp(t) + + def substitute_indices(self, *index_tuples): + new_args = [] + for arg in self.args: + if isinstance(arg, TensExpr): + arg = arg.substitute_indices(*index_tuples) + new_args.append(arg) + return TensAdd(*new_args).doit(deep=False) + + def _print(self): + a = [] + args = self.args + for x in args: + a.append(str(x)) + s = ' + '.join(a) + s = s.replace('+ -', '- ') + return s + + def _extract_data(self, replacement_dict): + from sympy.tensor.array import Array, permutedims + args_indices, arrays = zip(*[ + arg._extract_data(replacement_dict) if + isinstance(arg, TensExpr) else ([], arg) for arg in self.args + ]) + arrays = [Array(i) for i in arrays] + ref_indices = args_indices[0] + for i in range(1, len(args_indices)): + indices = args_indices[i] + array = arrays[i] + permutation = TensMul._get_indices_permutation(indices, ref_indices) + arrays[i] = permutedims(array, permutation) + return ref_indices, sum(arrays, Array.zeros(*array.shape)) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self.expand()] + + @data.setter + def data(self, data): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + + def __iter__(self): + deprecate_data() + if not self.data: + raise ValueError("No iteration on abstract tensors") + return self.data.flatten().__iter__() + + def _eval_rewrite_as_Indexed(self, *args, **kwargs): + return Add.fromiter(args) + + def _eval_partial_derivative(self, s): + # Evaluation like Add + list_addends = [] + for a in self.args: + if isinstance(a, TensExpr): + list_addends.append(a._eval_partial_derivative(s)) + # do not call diff if s is no symbol + elif s._diff_wrt: + list_addends.append(a._eval_derivative(s)) + + return self.func(*list_addends).doit(deep=False) + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if not isinstance(expr, TensAdd): + return None + + if len(_get_wilds(self)) == 0: + return self._matches_simple(expr, repl_dict, old) + + def siftkey(arg): + wildatoms = _get_wilds(arg) + wildatom_types = sift(wildatoms, type) + if len(wildatoms) == 0: + return "nonwild" + elif WildTensor in wildatom_types.keys(): + for w in wildatom_types["WildTensor"]: + if len(w.get_indices()) == 0: + return "indexless_wildtensor" + return "wildtensor" + else: + return "otherwild" + + query_sifted = sift(self.args, siftkey) + expr_sifted = sift(expr.args, siftkey) + + #First try to match the terms without WildTensors + matched_e_tensors = [] #Used to make sure that the same tensor in expr is not matched with more than one tensor in self. + for q_tensor in query_sifted["nonwild"]: + matched_this_q = False + for e_tensor in expr_sifted["nonwild"]: + if e_tensor in matched_e_tensors: + continue + + m = q_tensor.matches(e_tensor, repl_dict=repl_dict, old=old) + if m is None: + continue + else: + matched_this_q = True + repl_dict.update(m) + matched_e_tensors.append(e_tensor) + break + + if not matched_this_q: + return None + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["otherwild"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w in repl_dict.keys(): + repl_dict[w] += m.pop(w) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["wildtensor"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w.component in repl_dict.keys(): + repl_dict[w.component] += m.pop(w.component) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + for w in query_sifted["indexless_wildtensor"]: + for e in remaining_e_tensors: + m = w.matches(e) + if m is not None: + matched_e_tensors.append(e) + if w.component in repl_dict.keys(): + repl_dict[w.component] += m.pop(w.component) + repl_dict.update(m) + + remaining_e_tensors = [t for t in expr_sifted["nonwild"] if t not in matched_e_tensors] + if len(remaining_e_tensors) > 0: + return None + else: + return repl_dict + + +class Tensor(TensExpr): + """ + Base tensor class, i.e. this represents a tensor, the single unit to be + put into an expression. + + Explanation + =========== + + This object is usually created from a ``TensorHead``, by attaching indices + to it. Indices preceded by a minus sign are considered contravariant, + otherwise covariant. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead + >>> Lorentz = TensorIndexType("Lorentz", dummy_name="L") + >>> mu, nu = tensor_indices('mu nu', Lorentz) + >>> A = TensorHead("A", [Lorentz, Lorentz]) + >>> A(mu, -nu) + A(mu, -nu) + >>> A(mu, -mu) + A(L_0, -L_0) + + It is also possible to use symbols instead of inidices (appropriate indices + are then generated automatically). + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> A(x, mu) + A(x, mu) + >>> A(x, -x) + A(L_0, -L_0) + + """ + + is_commutative = False + + _index_structure: _IndexStructure + args: tuple[TensorHead, Tuple] + + def __new__(cls, tensor_head, indices, *, is_canon_bp=False, **kw_args): + indices = cls._parse_indices(tensor_head, indices) + obj = Basic.__new__(cls, tensor_head, Tuple(*indices), **kw_args) + obj._index_structure = _IndexStructure.from_indices(*indices) + obj._free = obj._index_structure.free[:] + obj._dum = obj._index_structure.dum[:] + obj._ext_rank = obj._index_structure._ext_rank + obj._coeff = S.One + obj._nocoeff = obj + obj._component = tensor_head + obj._components = [tensor_head] + if tensor_head.rank != len(indices): + raise ValueError("wrong number of indices") + obj.is_canon_bp = is_canon_bp + obj._index_map = Tensor._build_index_map(indices, obj._index_structure) + return obj + + @property + def free(self): + return self._free + + @property + def dum(self): + return self._dum + + @property + def ext_rank(self): + return self._ext_rank + + @property + def coeff(self): + return self._coeff + + @property + def nocoeff(self): + return self._nocoeff + + @property + def component(self): + return self._component + + @property + def components(self): + return self._components + + @property + def head(self): + return self.args[0] + + @property + def indices(self): + return self.args[1] + + @property + def free_indices(self): + return set(self._index_structure.get_free_indices()) + + @property + def index_types(self): + return self.head.index_types + + @property + def rank(self): + return len(self.free_indices) + + @staticmethod + def _build_index_map(indices, index_structure): + index_map = {} + for idx in indices: + index_map[idx] = (indices.index(idx),) + return index_map + + def doit(self, **hints): + args, indices, free, dum = TensMul._tensMul_contract_indices([self]) + return args[0] + + @staticmethod + def _parse_indices(tensor_head, indices): + if not isinstance(indices, (tuple, list, Tuple)): + raise TypeError("indices should be an array, got %s" % type(indices)) + indices = list(indices) + for i, index in enumerate(indices): + if isinstance(index, Symbol): + indices[i] = TensorIndex(index, tensor_head.index_types[i], True) + elif isinstance(index, Mul): + c, e = index.as_coeff_Mul() + if c == -1 and isinstance(e, Symbol): + indices[i] = TensorIndex(e, tensor_head.index_types[i], False) + else: + raise ValueError("index not understood: %s" % index) + elif not isinstance(index, TensorIndex): + raise TypeError("wrong type for index: %s is %s" % (index, type(index))) + return indices + + def _set_new_index_structure(self, im, is_canon_bp=False): + indices = im.get_indices() + return self._set_indices(*indices, is_canon_bp=is_canon_bp) + + def _set_indices(self, *indices, is_canon_bp=False, **kw_args): + if len(indices) != self.ext_rank: + raise ValueError("indices length mismatch") + return self.func(self.args[0], indices, is_canon_bp=is_canon_bp).doit() + + def _get_free_indices_set(self): + return {i[0] for i in self._index_structure.free} + + def _get_dummy_indices_set(self): + dummy_pos = set(itertools.chain(*self._index_structure.dum)) + return {idx for i, idx in enumerate(self.args[1]) if i in dummy_pos} + + def _get_indices_set(self): + return set(self.args[1].args) + + @property + def free_in_args(self): + return [(ind, pos, 0) for ind, pos in self.free] + + @property + def dum_in_args(self): + return [(p1, p2, 0, 0) for p1, p2 in self.dum] + + @property + def free_args(self): + return sorted([x[0] for x in self.free]) + + def commutes_with(self, other): + """ + :param other: + :return: + 0 commute + 1 anticommute + None neither commute nor anticommute + """ + if not isinstance(other, TensExpr): + return 0 + elif isinstance(other, Tensor): + return self.component.commutes_with(other.component) + return NotImplementedError + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g``. + + For further details, see the method in ``TIDS`` with the same name. + """ + return perm2tensor(self, g, is_canon_bp) + + def canon_bp(self): + if self.is_canon_bp: + return self + expr = self.expand() + g, dummies, msym = expr._index_structure.indices_canon_args() + v = components_canon_args([expr.component]) + can = canonicalize(g, dummies, msym, *v) + if can == 0: + return S.Zero + tensor = self.perm2tensor(can, True) + return tensor + + def split(self): + return [self] + + def sorted_components(self): + return self + + def get_indices(self) -> list[TensorIndex]: + """ + Get a list of indices, corresponding to those of the tensor. + """ + return list(self.args[1]) + + def get_free_indices(self) -> list[TensorIndex]: + """ + Get a list of free indices, corresponding to those of the tensor. + """ + return self._index_structure.get_free_indices() + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + # TODO: this could be optimized by only swapping the indices + # instead of visiting the whole expression tree: + return self.xreplace(repl) + + def as_base_exp(self): + return self, S.One + + def substitute_indices(self, *index_tuples): + """ + Return a tensor with free indices substituted according to ``index_tuples``. + + ``index_types`` list of tuples ``(old_index, new_index)``. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz) + >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + >>> t = A(i, k)*B(-k, -j); t + A(i, L_0)*B(-L_0, -j) + >>> t.substitute_indices((i, k),(-j, l)) + A(k, L_0)*B(-L_0, l) + """ + indices = [] + for index in self.indices: + for ind_old, ind_new in index_tuples: + if (index.name == ind_old.name and index.tensor_index_type == + ind_old.tensor_index_type): + if index.is_up == ind_old.is_up: + indices.append(ind_new) + else: + indices.append(-ind_new) + break + else: + indices.append(index) + return self.head(*indices) + + def _get_symmetrized_forms(self): + """ + Return a list giving all possible permutations of self that are allowed by its symmetries. + """ + comp = self.component + gens = comp.symmetry.generators + rank = comp.rank + + old_perms = None + new_perms = {self} + while new_perms != old_perms: + old_perms = new_perms.copy() + for tens in old_perms: + for gen in gens: + inds = tens.get_indices() + per = [gen.apply(i) for i in range(0,rank)] + sign = (-1)**(gen.apply(rank) - rank) + ind_map = dict(zip(inds, [inds[i] for i in per])) + new_perms.add( sign * tens._replace_indices(ind_map) ) + + return new_perms + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #simple checks + if self == expr: + return repl_dict + if not isinstance(expr, Tensor): + return None + if self.head != expr.head: + return None + + #Now consider all index symmetries of expr, and see if any of them allow a match. + for new_expr in expr._get_symmetrized_forms(): + m = self._matches(new_expr, repl_dict, old=old) + if m is not None: + repl_dict.update(m) + return repl_dict + + return None + + def _matches(self, expr, repl_dict=None, old=False): + """ + This does not account for index symmetries of expr + """ + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #simple checks + if self == expr: + return repl_dict + if not isinstance(expr, Tensor): + return None + if self.head != expr.head: + return None + + s_indices = self.get_indices() + e_indices = expr.get_indices() + + if len(s_indices) != len(e_indices): + return None + + for i in range(len(s_indices)): + s_ind = s_indices[i] + m = s_ind.matches(e_indices[i]) + if m is None: + return None + elif -s_ind in repl_dict.keys() and -repl_dict[-s_ind] != m[s_ind]: + return None + else: + repl_dict.update(m) + + return repl_dict + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + t = self.substitute_indices(*list(zip(free_args, indices))) + + # object is rebuilt in order to make sure that all contracted indices + # get recognized as dummies, but only if there are contracted indices. + if len({i if i.is_up else -i for i in indices}) != len(indices): + return t.func(*t.args) + return t + + # TODO: put this into TensExpr? + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data.__iter__() + + # TODO: put this into TensExpr? + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def _extract_data(self, replacement_dict): + from .array import Array + for k, v in replacement_dict.items(): + if isinstance(k, Tensor) and k.args[0] == self.args[0]: + other = k + array = v + break + else: + raise ValueError("%s not found in %s" % (self, replacement_dict)) + + # TODO: inefficient, this should be done at root level only: + replacement_dict = {k: Array(v) for k, v in replacement_dict.items()} + array = Array(array) + + dum1 = self.dum + dum2 = other.dum + + if len(dum2) > 0: + for pair in dum2: + # allow `dum2` if the contained values are also in `dum1`. + if pair not in dum1: + raise NotImplementedError("%s with contractions is not implemented" % other) + # Remove elements in `dum2` from `dum1`: + dum1 = [pair for pair in dum1 if pair not in dum2] + if len(dum1) > 0: + indices1 = self.get_indices() + indices2 = other.get_indices() + repl = {} + for p1, p2 in dum1: + repl[indices2[p2]] = -indices2[p1] + for pos in (p1, p2): + if indices1[pos].is_up ^ indices2[pos].is_up: + metric = replacement_dict[indices1[pos].tensor_index_type] + if indices1[pos].is_up: + metric = _TensorDataLazyEvaluator.inverse_matrix(metric) + array = self._contract_and_permute_with_metric(metric, array, pos, len(indices2)) + other = other.xreplace(repl).doit() + array = _TensorDataLazyEvaluator.data_contract_dum([array], dum1, len(indices2)) + + free_ind1 = self.get_free_indices() + free_ind2 = other.get_free_indices() + + return self._match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return _tensor_data_substitution_dict[self] + + @data.setter + def data(self, data): + deprecate_data() + # TODO: check data compatibility with properties of tensor. + with ignore_warnings(SymPyDeprecationWarning): + _tensor_data_substitution_dict[self] = data + + @data.deleter + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self] + if self.metric in _tensor_data_substitution_dict: + del _tensor_data_substitution_dict[self.metric] + + def _print(self): + indices = [str(ind) for ind in self.indices] + component = self.component + if component.rank > 0: + return ('%s(%s)' % (component.name, ', '.join(indices))) + else: + return ('%s' % component.name) + + def equals(self, other): + if other == 0: + return self.coeff == 0 + other = _sympify(other) + if not isinstance(other, TensExpr): + assert not self.components + return S.One == other + + def _get_compar_comp(self): + t = self.canon_bp() + r = (t.coeff, tuple(t.components), \ + tuple(sorted(t.free)), tuple(sorted(t.dum))) + return r + + return _get_compar_comp(self) == _get_compar_comp(other) + + def contract_metric(self, g): + # if metric is not the same, ignore this step: + if self.component != g: + return self + # in case there are free components, do not perform anything: + if len(self.free) != 0: + return self + + #antisym = g.index_types[0].metric_antisym + if g.symmetry == TensorSymmetry.fully_symmetric(-2): + antisym = 1 + elif g.symmetry == TensorSymmetry.fully_symmetric(2): + antisym = 0 + elif g.symmetry == TensorSymmetry.no_symmetry(2): + antisym = None + else: + raise NotImplementedError + sign = S.One + typ = g.index_types[0] + + if not antisym: + # g(i, -i) + sign = sign*typ.dim + else: + # g(i, -i) + sign = sign*typ.dim + + dp0, dp1 = self.dum[0] + if dp0 < dp1: + # g(i, -i) = -D with antisymmetric metric + sign = -sign + + return sign + + def contract_delta(self, metric): + return self.contract_metric(metric) + + def _eval_rewrite_as_Indexed(self, tens, indices, **kwargs): + from sympy.tensor.indexed import Indexed + # TODO: replace .args[0] with .name: + index_symbols = [i.args[0] for i in self.get_indices()] + expr = Indexed(tens.args[0], *index_symbols) + return self._check_add_Sum(expr, index_symbols) + + def _eval_partial_derivative(self, s: Tensor) -> Expr: + + if not isinstance(s, Tensor): + return S.Zero + else: + + # @a_i/@a_k = delta_i^k + # @a_i/@a^k = g_ij delta^j_k + # @a^i/@a^k = delta^i_k + # @a^i/@a_k = g^ij delta_j^k + # TODO: if there is no metric present, the derivative should be zero? + + if self.head != s.head: + return S.Zero + + # if heads are the same, provide delta and/or metric products + # for every free index pair in the appropriate tensor + # assumed that the free indices are in proper order + # A contravariante index in the derivative becomes covariant + # after performing the derivative and vice versa + + kronecker_delta_list = [1] + + # not guarantee a correct index order + + for (count, (iself, iother)) in enumerate(zip(self.get_free_indices(), s.get_free_indices())): + if iself.tensor_index_type != iother.tensor_index_type: + raise ValueError("index types not compatible") + else: + tensor_index_type = iself.tensor_index_type + tensor_metric = tensor_index_type.metric + dummy = TensorIndex("d_" + str(count), tensor_index_type, + is_up=iself.is_up) + if iself.is_up == iother.is_up: + kroneckerdelta = tensor_index_type.delta(iself, -iother) + else: + kroneckerdelta = ( + TensMul(tensor_metric(iself, dummy), + tensor_index_type.delta(-dummy, -iother)) + ) + kronecker_delta_list.append(kroneckerdelta) + return TensMul.fromiter(kronecker_delta_list).doit(deep=False) + # doit necessary to rename dummy indices accordingly + + +class TensMul(TensExpr, AssocOp): + """ + Product of tensors. + + Parameters + ========== + + coeff : SymPy coefficient of the tensor + args + + Attributes + ========== + + ``components`` : list of ``TensorHead`` of the component tensors + ``types`` : list of nonrepeated ``TensorIndexType`` + ``free`` : list of ``(ind, ipos, icomp)``, see Notes + ``dum`` : list of ``(ipos1, ipos2, icomp1, icomp2)``, see Notes + ``ext_rank`` : rank of the tensor counting the dummy indices + ``rank`` : rank of the tensor + ``coeff`` : SymPy coefficient of the tensor + ``free_args`` : list of the free indices in sorted order + ``is_canon_bp`` : ``True`` if the tensor in in canonical form + + Notes + ===== + + ``args[0]`` list of ``TensorHead`` of the component tensors. + + ``args[1]`` list of ``(ind, ipos, icomp)`` + where ``ind`` is a free index, ``ipos`` is the slot position + of ``ind`` in the ``icomp``-th component tensor. + + ``args[2]`` list of tuples representing dummy indices. + ``(ipos1, ipos2, icomp1, icomp2)`` indicates that the contravariant + dummy index is the ``ipos1``-th slot position in the ``icomp1``-th + component tensor; the corresponding covariant index is + in the ``ipos2`` slot position in the ``icomp2``-th component tensor. + + """ + identity = S.One + + _index_structure: _IndexStructure + + def __new__(cls, *args, **kw_args): + is_canon_bp = kw_args.get('is_canon_bp', False) + args = list(map(_sympify, args)) + + """ + If the internal dummy indices in one arg conflict with the free indices + of the remaining args, we need to rename those internal dummy indices. + """ + free = [get_free_indices(arg) for arg in args] + free = set(itertools.chain(*free)) #flatten free + newargs = [] + for arg in args: + dum_this = set(get_dummy_indices(arg)) + dum_other = [get_dummy_indices(a) for a in newargs] + dum_other = set(itertools.chain(*dum_other)) #flatten dum_other + free_this = set(get_free_indices(arg)) + if len(dum_this.intersection(free)) > 0: + exclude = free_this.union(free, dum_other) + newarg = TensMul._dedupe_indices(arg, exclude) + else: + newarg = arg + newargs.append(newarg) + + args = newargs + + # Flatten: + args = [i for arg in args for i in (arg.args if isinstance(arg, (TensMul, Mul)) else [arg])] + + args, indices, free, dum = TensMul._tensMul_contract_indices(args, replace_indices=False) + + # Data for indices: + index_types = [i.tensor_index_type for i in indices] + index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp) + + obj = TensExpr.__new__(cls, *args) + obj._indices = indices + obj._index_types = index_types.copy() + obj._index_structure = index_structure + obj._free = index_structure.free[:] + obj._dum = index_structure.dum[:] + obj._free_indices = {x[0] for x in obj.free} + obj._rank = len(obj.free) + obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum) + obj._coeff = S.One + obj._is_canon_bp = is_canon_bp + return obj + + index_types = property(lambda self: self._index_types) + free = property(lambda self: self._free) + dum = property(lambda self: self._dum) + free_indices = property(lambda self: self._free_indices) + rank = property(lambda self: self._rank) + ext_rank = property(lambda self: self._ext_rank) + + @staticmethod + def _indices_to_free_dum(args_indices): + free2pos1 = {} + free2pos2 = {} + dummy_data = [] + indices = [] + + # Notation for positions (to better understand the code): + # `pos1`: position in the `args`. + # `pos2`: position in the indices. + + # Example: + # A(i, j)*B(k, m, n)*C(p) + # `pos1` of `n` is 1 because it's in `B` (second `args` of TensMul). + # `pos2` of `n` is 4 because it's the fifth overall index. + + # Counter for the index position wrt the whole expression: + pos2 = 0 + + for pos1, arg_indices in enumerate(args_indices): + + for index in arg_indices: + if not isinstance(index, TensorIndex): + raise TypeError("expected TensorIndex") + if -index in free2pos1: + # Dummy index detected: + other_pos1 = free2pos1.pop(-index) + other_pos2 = free2pos2.pop(-index) + if index.is_up: + dummy_data.append((index, pos1, other_pos1, pos2, other_pos2)) + else: + dummy_data.append((-index, other_pos1, pos1, other_pos2, pos2)) + indices.append(index) + elif index in free2pos1: + raise ValueError("Repeated index: %s" % index) + else: + free2pos1[index] = pos1 + free2pos2[index] = pos2 + indices.append(index) + pos2 += 1 + + free = list(free2pos2.items()) + free_names = [i.name for i in free2pos2.keys()] + + dummy_data.sort(key=lambda x: x[3]) + return indices, free, free_names, dummy_data + + @staticmethod + def _dummy_data_to_dum(dummy_data): + return [(p2a, p2b) for (i, p1a, p1b, p2a, p2b) in dummy_data] + + @staticmethod + def _tensMul_contract_indices(args, replace_indices=True): + replacements = [{} for _ in args] + + #_index_order = all(_has_index_order(arg) for arg in args) + + args_indices = [get_indices(arg) for arg in args] + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + + cdt = defaultdict(int) + + def dummy_name_gen(tensor_index_type): + nd = str(cdt[tensor_index_type]) + cdt[tensor_index_type] += 1 + return tensor_index_type.dummy_name + '_' + nd + + if replace_indices: + for old_index, pos1cov, pos1contra, pos2cov, pos2contra in dummy_data: + index_type = old_index.tensor_index_type + while True: + dummy_name = dummy_name_gen(index_type) + if dummy_name not in free_names: + break + dummy = old_index.func(dummy_name, index_type, *old_index.args[2:]) + replacements[pos1cov][old_index] = dummy + replacements[pos1contra][-old_index] = -dummy + indices[pos2cov] = dummy + indices[pos2contra] = -dummy + args = [ + arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg + for arg, repl in zip(args, replacements)] + + """ + The order of indices might've changed due to the replacements (e.g. if one of the args is a TensAdd, replacing an index can change the sort order of the terms, thus changing the order of indices returned by its get_indices() method). + To stay on the safe side, we calculate these quantities again. + """ + args_indices = [get_indices(arg) for arg in args] + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + + dum = TensMul._dummy_data_to_dum(dummy_data) + return args, indices, free, dum + + @staticmethod + def _get_components_from_args(args): + """ + Get a list of ``Tensor`` objects having the same ``TIDS`` if multiplied + by one another. + """ + components = [] + for arg in args: + if not isinstance(arg, TensExpr): + continue + if isinstance(arg, TensAdd): + continue + components.extend(arg.components) + return components + + @staticmethod + def _rebuild_tensors_list(args, index_structure): + indices = index_structure.get_indices() + #tensors = [None for i in components] # pre-allocate list + ind_pos = 0 + for i, arg in enumerate(args): + if not isinstance(arg, TensExpr): + continue + prev_pos = ind_pos + ind_pos += arg.ext_rank + args[i] = Tensor(arg.component, indices[prev_pos:ind_pos]) + + def doit(self, **hints): + is_canon_bp = self._is_canon_bp + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + + """ + There may now be conflicts between dummy indices of different args + (each arg's doit method does not have any information about which + dummy indices are already used in the other args), so we + deduplicate them. + """ + rule = dict(zip(self.args, args)) + rule = self._dedupe_indices_in_rule(rule) + args = [rule[a] for a in self.args] + + else: + args = self.args + + args = [arg for arg in args if arg != self.identity] + + # Extract non-tensor coefficients: + coeff = reduce(lambda a, b: a*b, [arg for arg in args if not isinstance(arg, TensExpr)], S.One) + args = [arg for arg in args if isinstance(arg, TensExpr)] + + if len(args) == 0: + return coeff + + if coeff != self.identity: + args = [coeff] + args + if coeff == 0: + return S.Zero + + if len(args) == 1: + return args[0] + + args, indices, free, dum = TensMul._tensMul_contract_indices(args) + + # Data for indices: + index_types = [i.tensor_index_type for i in indices] + index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp) + + obj = self.func(*args) + obj._index_types = index_types + obj._index_structure = index_structure + obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum) + obj._coeff = coeff + obj._is_canon_bp = is_canon_bp + return obj + + # TODO: this method should be private + # TODO: should this method be renamed _from_components_free_dum ? + @staticmethod + def from_data(coeff, components, free, dum, **kw_args): + return TensMul(coeff, *TensMul._get_tensors_from_components_free_dum(components, free, dum), **kw_args).doit(deep=False) + + @staticmethod + def _get_tensors_from_components_free_dum(components, free, dum): + """ + Get a list of ``Tensor`` objects by distributing ``free`` and ``dum`` indices on the ``components``. + """ + index_structure = _IndexStructure.from_components_free_dum(components, free, dum) + indices = index_structure.get_indices() + tensors = [None for i in components] # pre-allocate list + + # distribute indices on components to build a list of tensors: + ind_pos = 0 + for i, component in enumerate(components): + prev_pos = ind_pos + ind_pos += component.rank + tensors[i] = Tensor(component, indices[prev_pos:ind_pos]) + return tensors + + def _get_free_indices_set(self): + return {i[0] for i in self.free} + + def _get_dummy_indices_set(self): + dummy_pos = set(itertools.chain(*self.dum)) + return {idx for i, idx in enumerate(self._index_structure.get_indices()) if i in dummy_pos} + + def _get_position_offset_for_indices(self): + arg_offset = [None for i in range(self.ext_rank)] + counter = 0 + for arg in self.args: + if not isinstance(arg, TensExpr): + continue + for j in range(arg.ext_rank): + arg_offset[j + counter] = counter + counter += arg.ext_rank + return arg_offset + + @property + def free_args(self): + return sorted([x[0] for x in self.free]) + + @property + def components(self): + return self._get_components_from_args(self.args) + + @property + def free_in_args(self): + arg_offset = self._get_position_offset_for_indices() + argpos = self._get_indices_to_args_pos() + return [(ind, pos-arg_offset[pos], argpos[pos]) for (ind, pos) in self.free] + + @property + def coeff(self): + # return Mul.fromiter([c for c in self.args if not isinstance(c, TensExpr)]) + return self._coeff + + @property + def nocoeff(self): + return self.func(*self.args, 1/self.coeff).doit(deep=False) + + @property + def dum_in_args(self): + arg_offset = self._get_position_offset_for_indices() + argpos = self._get_indices_to_args_pos() + return [(p1-arg_offset[p1], p2-arg_offset[p2], argpos[p1], argpos[p2]) for p1, p2 in self.dum] + + def equals(self, other): + if other == 0: + return self.coeff == 0 + other = _sympify(other) + if not isinstance(other, TensExpr): + assert not self.components + return self.coeff == other + + return self.canon_bp() == other.canon_bp() + + def get_indices(self): + """ + Returns the list of indices of the tensor. + + Explanation + =========== + + The indices are listed in the order in which they appear in the + component tensors. + The dummy indices are given a name which does not collide with + the names of the free indices. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m1)*g(m0,m2) + >>> t.get_indices() + [m1, m0, m2] + >>> t2 = p(m1)*g(-m1, m2) + >>> t2.get_indices() + [L_0, -L_0, m2] + """ + return self._indices + + def get_free_indices(self) -> list[TensorIndex]: + """ + Returns the list of free indices of the tensor. + + Explanation + =========== + + The indices are listed in the order in which they appear in the + component tensors. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m1)*g(m0,m2) + >>> t.get_free_indices() + [m1, m0, m2] + >>> t2 = p(m1)*g(-m1, m2) + >>> t2.get_free_indices() + [m2] + """ + return self._index_structure.get_free_indices() + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + return self.func(*[arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args]) + + def split(self): + """ + Returns a list of tensors, whose product is ``self``. + + Explanation + =========== + + Dummy indices contracted among different tensor components + become free indices with the same name as the one used to + represent the dummy indices. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz) + >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + >>> t = A(a,b)*B(-b,c) + >>> t + A(a, L_0)*B(-L_0, c) + >>> t.split() + [A(a, L_0), B(-L_0, c)] + """ + if self.args == (): + return [self] + splitp = [] + res = 1 + for arg in self.args: + if isinstance(arg, Tensor): + splitp.append(res*arg) + res = 1 + else: + res *= arg + return splitp + + def _eval_expand_mul(self, **hints): + args1 = [arg.args if isinstance(arg, (Add, TensAdd)) else (arg,) for arg in self.args] + return TensAdd(*[ + TensMul(*i).doit(deep=False) for i in itertools.product(*args1)] + ) + + def __neg__(self): + return TensMul(S.NegativeOne, self, is_canon_bp=self._is_canon_bp).doit(deep=False) + + def __getitem__(self, item): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + return self.data[item] + + def _get_args_for_traditional_printer(self): + args = list(self.args) + if self.coeff.could_extract_minus_sign(): + # expressions like "-A(a)" + sign = "-" + if args[0] == S.NegativeOne: + args = args[1:] + else: + args[0] = -args[0] + else: + sign = "" + return sign, args + + def _sort_args_for_sorted_components(self): + """ + Returns the ``args`` sorted according to the components commutation + properties. + + Explanation + =========== + + The sorting is done taking into account the commutation group + of the component tensors. + """ + cv = [arg for arg in self.args if isinstance(arg, TensExpr)] + sign = 1 + n = len(cv) - 1 + for i in range(n): + for j in range(n, i, -1): + c = cv[j-1].commutes_with(cv[j]) + # if `c` is `None`, it does neither commute nor anticommute, skip: + if c not in (0, 1): + continue + typ1 = sorted(set(cv[j-1].component.index_types), key=lambda x: x.name) + typ2 = sorted(set(cv[j].component.index_types), key=lambda x: x.name) + if (typ1, cv[j-1].component.name) > (typ2, cv[j].component.name): + cv[j-1], cv[j] = cv[j], cv[j-1] + # if `c` is 1, the anticommute, so change sign: + if c: + sign = -sign + + coeff = sign * self.coeff + if coeff != 1: + return [coeff] + cv + return cv + + def sorted_components(self): + """ + Returns a tensor product with sorted components. + """ + return TensMul(*self._sort_args_for_sorted_components()).doit(deep=False) + + def perm2tensor(self, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g`` + + For further details, see the method in ``TIDS`` with the same name. + """ + return perm2tensor(self, g, is_canon_bp=is_canon_bp) + + def canon_bp(self): + """ + Canonicalize using the Butler-Portugal algorithm for canonicalization + under monoterm symmetries. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + >>> t = A(m0,-m1)*A(m1,-m0) + >>> t.canon_bp() + -A(L_0, L_1)*A(-L_0, -L_1) + >>> t = A(m0,-m1)*A(m1,-m2)*A(m2,-m0) + >>> t.canon_bp() + 0 + """ + if self._is_canon_bp: + return self + expr = self.expand() + if isinstance(expr, TensAdd): + return expr.canon_bp() + if not expr.components: + return expr + expr = expr.doit(deep=False) #make sure self.coeff is populated correctly + t = expr.sorted_components() + g, dummies, msym = t._index_structure.indices_canon_args() + v = components_canon_args(t.components) + can = canonicalize(g, dummies, msym, *v) + if can == 0: + return S.Zero + tmul = t.perm2tensor(can, True) + return tmul + + def contract_delta(self, delta): + t = self.contract_metric(delta) + return t + + def _get_indices_to_args_pos(self): + """ + Get a dict mapping the index position to TensMul's argument number. + """ + pos_map = {} + pos_counter = 0 + for arg_i, arg in enumerate(self.args): + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + for i in range(arg.ext_rank): + pos_map[pos_counter] = arg_i + pos_counter += 1 + return pos_map + + def contract_metric(self, g): + """ + Raise or lower indices with the metric ``g``. + + Parameters + ========== + + g : metric + + Notes + ===== + + See the ``TensorIndexType`` docstring for the contraction conventions. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz) + >>> g = Lorentz.metric + >>> p, q = tensor_heads('p,q', [Lorentz]) + >>> t = p(m0)*q(m1)*g(-m0, -m1) + >>> t.canon_bp() + metric(L_0, L_1)*p(-L_0)*q(-L_1) + >>> t.contract_metric(g).canon_bp() + p(L_0)*q(-L_0) + """ + expr = self.expand().doit(deep=False) + if self != expr: + expr = canon_bp(expr) + return contract_metric(expr, g) + pos_map = self._get_indices_to_args_pos() + args = list(self.args) + + #antisym = g.index_types[0].metric_antisym + if g.symmetry == TensorSymmetry.fully_symmetric(-2): + antisym = 1 + elif g.symmetry == TensorSymmetry.fully_symmetric(2): + antisym = 0 + elif g.symmetry == TensorSymmetry.no_symmetry(2): + antisym = None + else: + raise NotImplementedError + + # list of positions of the metric ``g`` inside ``args`` + gpos = [i for i, x in enumerate(self.args) if isinstance(x, Tensor) and x.component == g] + if not gpos: + return self + + # Sign is either 1 or -1, to correct the sign after metric contraction + # (for spinor indices). + sign = 1 + dum = self.dum[:] + free = self.free[:] + elim = set() + for gposx in gpos: + if gposx in elim: + continue + free1 = [x for x in free if pos_map[x[1]] == gposx] + dum1 = [x for x in dum if pos_map[x[0]] == gposx or pos_map[x[1]] == gposx] + if not dum1: + continue + elim.add(gposx) + # subs with the multiplication neutral element, that is, remove it: + args[gposx] = 1 + if len(dum1) == 2: + if not antisym: + dum10, dum11 = dum1 + if pos_map[dum10[1]] == gposx: + # the index with pos p0 contravariant + p0 = dum10[0] + else: + # the index with pos p0 is covariant + p0 = dum10[1] + if pos_map[dum11[1]] == gposx: + # the index with pos p1 is contravariant + p1 = dum11[0] + else: + # the index with pos p1 is covariant + p1 = dum11[1] + + dum.append((p0, p1)) + else: + dum10, dum11 = dum1 + # change the sign to bring the indices of the metric to contravariant + # form; change the sign if dum10 has the metric index in position 0 + if pos_map[dum10[1]] == gposx: + # the index with pos p0 is contravariant + p0 = dum10[0] + if dum10[1] == 1: + sign = -sign + else: + # the index with pos p0 is covariant + p0 = dum10[1] + if dum10[0] == 0: + sign = -sign + if pos_map[dum11[1]] == gposx: + # the index with pos p1 is contravariant + p1 = dum11[0] + sign = -sign + else: + # the index with pos p1 is covariant + p1 = dum11[1] + + dum.append((p0, p1)) + + elif len(dum1) == 1: + if not antisym: + dp0, dp1 = dum1[0] + if pos_map[dp0] == pos_map[dp1]: + # g(i, -i) + typ = g.index_types[0] + sign = sign*typ.dim + + else: + # g(i0, i1)*p(-i1) + if pos_map[dp0] == gposx: + p1 = dp1 + else: + p1 = dp0 + + ind, p = free1[0] + free.append((ind, p1)) + else: + dp0, dp1 = dum1[0] + if pos_map[dp0] == pos_map[dp1]: + # g(i, -i) + typ = g.index_types[0] + sign = sign*typ.dim + + if dp0 < dp1: + # g(i, -i) = -D with antisymmetric metric + sign = -sign + else: + # g(i0, i1)*p(-i1) + if pos_map[dp0] == gposx: + p1 = dp1 + if dp0 == 0: + sign = -sign + else: + p1 = dp0 + ind, p = free1[0] + free.append((ind, p1)) + dum = [x for x in dum if x not in dum1] + free = [x for x in free if x not in free1] + + # shift positions: + shift = 0 + shifts = [0]*len(args) + for i in range(len(args)): + if i in elim: + shift += 2 + continue + shifts[i] = shift + free = [(ind, p - shifts[pos_map[p]]) for (ind, p) in free if pos_map[p] not in elim] + dum = [(p0 - shifts[pos_map[p0]], p1 - shifts[pos_map[p1]]) for p0, p1 in dum if pos_map[p0] not in elim and pos_map[p1] not in elim] + + res = ( sign*TensMul(*args) ).doit(deep=False) + if not isinstance(res, TensExpr): + return res + im = _IndexStructure.from_components_free_dum(res.components, free, dum) + return res._set_new_index_structure(im) + + def _set_new_index_structure(self, im, is_canon_bp=False): + indices = im.get_indices() + return self._set_indices(*indices, is_canon_bp=is_canon_bp) + + def _set_indices(self, *indices, is_canon_bp=False, **kw_args): + if len(indices) != self.ext_rank: + raise ValueError("indices length mismatch") + args = list(self.args) + pos = 0 + for i, arg in enumerate(args): + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + ext_rank = arg.ext_rank + args[i] = arg._set_indices(*indices[pos:pos+ext_rank]) + pos += ext_rank + return TensMul(*args, is_canon_bp=is_canon_bp).doit(deep=False) + + @staticmethod + def _index_replacement_for_contract_metric(args, free, dum): + for arg in args: + if not isinstance(arg, TensExpr): + continue + assert isinstance(arg, Tensor) + + def substitute_indices(self, *index_tuples): + new_args = [] + for arg in self.args: + if isinstance(arg, TensExpr): + arg = arg.substitute_indices(*index_tuples) + new_args.append(arg) + return TensMul(*new_args).doit(deep=False) + + def __call__(self, *indices): + deprecate_call() + free_args = self.free_args + indices = list(indices) + if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]: + raise ValueError('incompatible types') + if indices == free_args: + return self + t = self.substitute_indices(*list(zip(free_args, indices))) + + # object is rebuilt in order to make sure that all contracted indices + # get recognized as dummies, but only if there are contracted indices. + if len({i if i.is_up else -i for i in indices}) != len(indices): + return t.func(*t.args) + return t + + def _extract_data(self, replacement_dict): + args_indices, arrays = zip(*[arg._extract_data(replacement_dict) for arg in self.args if isinstance(arg, TensExpr)]) + coeff = reduce(operator.mul, [a for a in self.args if not isinstance(a, TensExpr)], S.One) + indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices) + dum = TensMul._dummy_data_to_dum(dummy_data) + ext_rank = self.ext_rank + free.sort(key=lambda x: x[1]) + free_indices = [i[0] for i in free] + return free_indices, coeff*_TensorDataLazyEvaluator.data_contract_dum(arrays, dum, ext_rank) + + @property + def data(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + dat = _tensor_data_substitution_dict[self.expand()] + return dat + + @data.setter + def data(self, data): + deprecate_data() + raise ValueError("Not possible to set component data to a tensor expression") + + @data.deleter + def data(self): + deprecate_data() + raise ValueError("Not possible to delete component data to a tensor expression") + + def __iter__(self): + deprecate_data() + with ignore_warnings(SymPyDeprecationWarning): + if self.data is None: + raise ValueError("No iteration on abstract tensors") + return self.data.__iter__() + + @staticmethod + def _dedupe_indices(new, exclude): + """ + exclude: set + new: TensExpr + + If ``new`` has any dummy indices that are in ``exclude``, return a version + of new with those indices replaced. If no replacements are needed, + return None + + """ + exclude = set(exclude) + dums_new = set(get_dummy_indices(new)) + free_new = set(get_free_indices(new)) + + conflicts = dums_new.intersection(exclude) + if len(conflicts) == 0: + return None + + """ + ``exclude_for_gen`` is to be passed to ``_IndexStructure._get_generator_for_dummy_indices()``. + Since the latter does not use the index position for anything, we just + set it as ``None`` here. + """ + exclude.update(dums_new) + exclude.update(free_new) + exclude_for_gen = [(i, None) for i in exclude] + gen = _IndexStructure._get_generator_for_dummy_indices(exclude_for_gen) + repl = {} + for d in conflicts: + if -d in repl.keys(): + continue + newname = gen(d.tensor_index_type) + new_d = d.func(newname, *d.args[1:]) + repl[d] = new_d + repl[-d] = -new_d + + if len(repl) == 0: + return None + + new_renamed = new._replace_indices(repl) + return new_renamed + + def _dedupe_indices_in_rule(self, rule): + """ + rule: dict + + This applies TensMul._dedupe_indices on all values of rule. + + """ + index_rules = {k:v for k,v in rule.items() if isinstance(k, TensorIndex)} + other_rules = {k:v for k,v in rule.items() if k not in index_rules.keys()} + exclude = set(self.get_indices()) + + newrule = {} + newrule.update(index_rules) + exclude.update(index_rules.keys()) + exclude.update(index_rules.values()) + for old, new in other_rules.items(): + new_renamed = TensMul._dedupe_indices(new, exclude) + if old == new or new_renamed is None: + newrule[old] = new + else: + newrule[old] = new_renamed + exclude.update(get_indices(new_renamed)) + return newrule + + def _eval_subs(self, old, new): + """ + If new is an index which is already present in self as a dummy, the dummies in self should be renamed. + """ + + if not isinstance(new, TensorIndex): + return None + + exclude = {new} + self_renamed = self._dedupe_indices(self, exclude) + if self_renamed is None: + return None + else: + return self_renamed._subs(old, new).doit(deep=False) + + def _eval_rewrite_as_Indexed(self, *args, **kwargs): + from sympy.concrete.summations import Sum + index_symbols = [i.args[0] for i in self.get_indices()] + args = [arg.args[0] if isinstance(arg, Sum) else arg for arg in args] + expr = Mul.fromiter(args) + return self._check_add_Sum(expr, index_symbols) + + def _eval_partial_derivative(self, s): + # Evaluation like Mul + terms = [] + for i, arg in enumerate(self.args): + # checking whether some tensor instance is differentiated + # or some other thing is necessary, but ugly + if isinstance(arg, TensExpr): + d = arg._eval_partial_derivative(s) + else: + # do not call diff is s is no symbol + if s._diff_wrt: + d = arg._eval_derivative(s) + else: + d = S.Zero + if d: + terms.append(TensMul.fromiter(self.args[:i] + (d,) + self.args[i + 1:]).doit(deep=False)) + return TensAdd.fromiter(terms).doit(deep=False) + + + def _matches_commutative(self, expr, repl_dict=None, old=False): + """ + Match assuming all tensors commute. But note that we are not assuming anything about their symmetry under index permutations. + """ + #Take care of the various possible types for expr. + if not isinstance(expr, TensMul): + if isinstance(expr, (TensExpr, Expr)): + expr = TensMul(expr) + else: + return None + + #The code that follows assumes expr is a TensMul + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + #Make sure that none of the dummy indices in self, expr conflict with the values already present in repl_dict. This may happen due to automatic index relabelling when rem_query and rem_expr are formed later on in this function (it calls itself recursively). + indices = [k for k in repl_dict.values() if isinstance(k ,TensorIndex)] + + def dedupe(expr): + renamed = TensMul._dedupe_indices(expr, indices) + if renamed is not None: + return renamed + else: + return expr + + self = dedupe(self) + expr = dedupe(expr) + + #Find the non-tensor part of expr. This need not be the same as expr.coeff when expr.doit() has not been called. + expr_coeff = reduce(lambda a, b: a*b, [arg for arg in expr.args if not isinstance(arg, TensExpr)], S.One) + + # handle simple patterns + if self == expr: + return repl_dict + + if len(_get_wilds(self)) == 0: + return self._matches_simple(expr, repl_dict, old) + + def siftkey(arg): + if isinstance(arg, WildTensor): + return "WildTensor" + elif isinstance(arg, (Tensor, TensExpr)): + return "Tensor" + else: + return "coeff" + + query_sifted = sift(self.args, siftkey) + expr_sifted = sift(expr.args, siftkey) + + #Sanity checks + if "coeff" in query_sifted.keys(): + if TensMul(*query_sifted["coeff"]).doit(deep=False) != self.coeff: + raise NotImplementedError(f"Found something that we do not know to handle: {query_sifted['coeff']}") + if "coeff" in expr_sifted.keys(): + if TensMul(*expr_sifted["coeff"]).doit(deep=False) != expr_coeff: + raise NotImplementedError(f"Found something that we do not know to handle: {expr_sifted['coeff']}") + + query_tens_heads = {tuple(getattr(x, "components", [])) for x in query_sifted["Tensor"]} #We use getattr because, e.g. TensAdd does not have the 'components' attribute. + expr_tens_heads = {tuple(getattr(x, "components", [])) for x in expr_sifted["Tensor"]} + if not query_tens_heads.issubset(expr_tens_heads): + #Some tensorheads in self are not present in the expr + return None + + #Try to match all non-wild tensors of self with tensors that compose expr + if len(query_sifted["Tensor"]) > 0: + q_tensor = query_sifted["Tensor"][0] + """ + We need to iterate over all possible symmetrized forms of q_tensor since the matches given by some of them may map dummy indices to free indices; the information about which indices are dummy/free will only be available later, when we are doing rem_q.matches(rem_e) + """ + for q_tens in q_tensor._get_symmetrized_forms(): + for e in expr_sifted["Tensor"]: + if isinstance(q_tens, TensMul): + #q_tensor got a minus sign due to this permutation. + sign = -1 + else: + sign = 1 + + """ + _matches is used here since we are already iterating over index permutations of q_tensor. Also note that the sign is removed from q_tensor, and will later be put into rem_q. + """ + m = (sign*q_tens)._matches(e) + if m is None: + continue + + rem_query = self.func(sign, *[a for a in self.args if a != q_tensor]).doit(deep=False) + rem_expr = expr.func(*[a for a in expr.args if a != e]).doit(deep=False) + tmp_repl = {} + tmp_repl.update(repl_dict) + tmp_repl.update(m) + rem_m = rem_query.matches(rem_expr, repl_dict=tmp_repl) + if rem_m is not None: + #Check that contracted indices are not mapped to different indices. + internally_consistent = True + for k in rem_m.keys(): + if isinstance(k,TensorIndex): + if -k in rem_m.keys() and rem_m[-k] != -rem_m[k]: + internally_consistent = False + break + if internally_consistent: + repl_dict.update(rem_m) + return repl_dict + + return None + + #Try to match WildTensor instances which have indices + matched_e_tensors = [] + remaining_e_tensors = expr_sifted["Tensor"] + indexless_wilds, wilds = sift(query_sifted["WildTensor"], lambda x: len(x.get_free_indices()) == 0, binary=True) + + for w in wilds: + free_this_wild = set(w.get_free_indices()) + tensors_to_try = [] + for t in remaining_e_tensors: + free = t.get_free_indices() + shares_indices_with_wild = True + for i in free: + if all(j.matches(i) is None for j in free_this_wild): + #The index i matches none of the indices in free_this_wild + shares_indices_with_wild = False + if shares_indices_with_wild: + tensors_to_try.append(t) + + m = w.matches(TensMul(*tensors_to_try).doit(deep=False) ) + if m is None: + return None + else: + for tens in tensors_to_try: + matched_e_tensors.append(tens) + repl_dict.update(m) + + #Try to match indexless WildTensor instances + remaining_e_tensors = [t for t in expr_sifted["Tensor"] if t not in matched_e_tensors] + if len(indexless_wilds) > 0: + #If there are any remaining tensors, match them with the indexless WildTensor + m = indexless_wilds[0].matches( TensMul(1,*remaining_e_tensors).doit(deep=False) ) + if m is None: + return None + else: + repl_dict.update(m) + elif len(remaining_e_tensors) > 0: + return None + + #Try to match the non-tensorial coefficient + m = self.coeff.matches(expr_coeff, old=old) + if m is None: + return None + else: + repl_dict.update(m) + + return repl_dict + + def matches(self, expr, repl_dict=None, old=False): + expr = sympify(expr) + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + commute = all(arg.component.comm == 0 for arg in expr.args if isinstance(arg, Tensor)) + if commute: + return self._matches_commutative(expr, repl_dict, old) + else: + raise NotImplementedError("Tensor matching not implemented for non-commuting tensors") + +class TensorElement(TensExpr): + """ + Tensor with evaluated components. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry + >>> from sympy import symbols + >>> L = TensorIndexType("L") + >>> i, j, k = symbols("i j k") + >>> A = TensorHead("A", [L, L], TensorSymmetry.fully_symmetric(2)) + >>> A(i, j).get_free_indices() + [i, j] + + If we want to set component ``i`` to a specific value, use the + ``TensorElement`` class: + + >>> from sympy.tensor.tensor import TensorElement + >>> te = TensorElement(A(i, j), {i: 2}) + + As index ``i`` has been accessed (``{i: 2}`` is the evaluation of its 3rd + element), the free indices will only contain ``j``: + + >>> te.get_free_indices() + [j] + """ + + def __new__(cls, expr, index_map): + if not isinstance(expr, Tensor): + # remap + if not isinstance(expr, TensExpr): + raise TypeError("%s is not a tensor expression" % expr) + return expr.func(*[TensorElement(arg, index_map) for arg in expr.args]) + expr_free_indices = expr.get_free_indices() + name_translation = {i.args[0]: i for i in expr_free_indices} + index_map = {name_translation.get(index, index): value for index, value in index_map.items()} + index_map = {index: value for index, value in index_map.items() if index in expr_free_indices} + if len(index_map) == 0: + return expr + free_indices = [i for i in expr_free_indices if i not in index_map.keys()] + index_map = Dict(index_map) + obj = TensExpr.__new__(cls, expr, index_map) + obj._free_indices = free_indices + return obj + + @property + def free(self): + return [(index, i) for i, index in enumerate(self.get_free_indices())] + + @property + def dum(self): + # TODO: inherit dummies from expr + return [] + + @property + def expr(self): + return self._args[0] + + @property + def index_map(self): + return self._args[1] + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + def get_free_indices(self): + return self._free_indices + + def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr: + # TODO: can be improved: + return self.xreplace(repl) + + def get_indices(self): + return self.get_free_indices() + + def _extract_data(self, replacement_dict): + ret_indices, array = self.expr._extract_data(replacement_dict) + index_map = self.index_map + slice_tuple = tuple(index_map.get(i, slice(None)) for i in ret_indices) + ret_indices = [i for i in ret_indices if i not in index_map] + array = array.__getitem__(slice_tuple) + return ret_indices, array + + +class WildTensorHead(TensorHead): + """ + A wild object that is used to create ``WildTensor`` instances + + Explanation + =========== + + Examples + ======== + >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + + A WildTensorHead can be created without specifying a ``TensorIndexType`` + + >>> W = WildTensorHead("W") + + Calling it with a ``TensorIndex`` creates a ``WildTensor`` instance. + + >>> type(W(p)) + + + The ``TensorIndexType`` is automatically detected from the index that is passed + + >>> W(p).component + W(R3) + + Calling it with no indices returns an object that can match tensors with any number of indices. + + >>> K = TensorHead('K', [R3]) + >>> Q = TensorHead('Q', [R3, R3]) + >>> W().matches(K(p)) + {W: K(p)} + >>> W().matches(Q(p,q)) + {W: Q(p, q)} + + If you want to ignore the order of indices while matching, pass ``unordered_indices=True``. + + >>> U = WildTensorHead("U", unordered_indices=True) + >>> W(p,q).matches(Q(q,p)) + >>> U(p,q).matches(Q(q,p)) + {U(R3,R3): _WildTensExpr(Q(q, p))} + + Parameters + ========== + name : name of the tensor + unordered_indices : whether the order of the indices matters for matching + (default: False) + + See also + ======== + ``WildTensor`` + ``TensorHead`` + + """ + def __new__(cls, name, index_types=None, symmetry=None, comm=0, unordered_indices=False): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + else: + raise ValueError("invalid name") + + if index_types is None: + index_types = [] + + if symmetry is None: + symmetry = TensorSymmetry.no_symmetry(len(index_types)) + else: + assert symmetry.rank == len(index_types) + + if symmetry != TensorSymmetry.no_symmetry(len(index_types)): + raise NotImplementedError("Wild matching based on symmetry is not implemented.") + + obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), sympify(symmetry), sympify(comm), sympify(unordered_indices)) + + return obj + + @property + def unordered_indices(self): + return self.args[4] + + def __call__(self, *indices, **kwargs): + tensor = WildTensor(self, indices, **kwargs) + return tensor.doit() + + +class WildTensor(Tensor): + """ + A wild object which matches ``Tensor`` instances + + Explanation + =========== + This is instantiated by attaching indices to a ``WildTensorHead`` instance. + + Examples + ======== + >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType + >>> W = WildTensorHead("W") + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + >>> K = TensorHead('K', [R3]) + >>> Q = TensorHead('Q', [R3, R3]) + + Matching also takes the indices into account + >>> W(p).matches(K(p)) + {W(R3): _WildTensExpr(K(p))} + >>> W(p).matches(K(q)) + >>> W(p).matches(K(-p)) + + If you want to match objects with any number of indices, just use a ``WildTensor`` with no indices. + >>> W().matches(K(p)) + {W: K(p)} + >>> W().matches(Q(p,q)) + {W: Q(p, q)} + + See Also + ======== + ``WildTensorHead`` + ``Tensor`` + + """ + def __new__(cls, tensor_head, indices, **kw_args): + is_canon_bp = kw_args.pop("is_canon_bp", False) + + if tensor_head.func == TensorHead: + """ + If someone tried to call WildTensor by supplying a TensorHead (not a WildTensorHead), return a normal tensor instead. This is helpful when using subs on an expression to replace occurrences of a WildTensorHead with a TensorHead. + """ + return Tensor(tensor_head, indices, is_canon_bp=is_canon_bp, **kw_args) + elif tensor_head.func == _WildTensExpr: + return tensor_head(*indices) + + indices = cls._parse_indices(tensor_head, indices) + index_types = [ind.tensor_index_type for ind in indices] + tensor_head = tensor_head.func( + tensor_head.name, + index_types, + symmetry=None, + comm=tensor_head.comm, + unordered_indices=tensor_head.unordered_indices, + ) + + obj = Basic.__new__(cls, tensor_head, Tuple(*indices)) + obj.name = tensor_head.name + obj._index_structure = _IndexStructure.from_indices(*indices) + obj._free = obj._index_structure.free[:] + obj._dum = obj._index_structure.dum[:] + obj._ext_rank = obj._index_structure._ext_rank + obj._coeff = S.One + obj._nocoeff = obj + obj._component = tensor_head + obj._components = [tensor_head] + if tensor_head.rank != len(indices): + raise ValueError("wrong number of indices") + obj.is_canon_bp = is_canon_bp + obj._index_map = obj._build_index_map(indices, obj._index_structure) + + return obj + + + def matches(self, expr, repl_dict=None, old=False): + if not isinstance(expr, TensExpr) and expr != S(1): + return None + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + if len(self.indices) > 0: + if not hasattr(expr, "get_free_indices"): + return None + expr_indices = expr.get_free_indices() + if len(expr_indices) != len(self.indices): + return None + if self._component.unordered_indices: + m = self._match_indices_ignoring_order(expr) + if m is None: + return None + else: + repl_dict.update(m) + else: + for i in range(len(expr_indices)): + m = self.indices[i].matches(expr_indices[i]) + if m is None: + return None + else: + repl_dict.update(m) + + repl_dict[self.component] = _WildTensExpr(expr) + else: + #If no indices were passed to the WildTensor, it may match tensors with any number of indices. + repl_dict[self] = expr + + return repl_dict + + def _match_indices_ignoring_order(self, expr, repl_dict=None, old=False): + """ + Helper method for matches. Checks if the indices of self and expr + match disregarding index ordering. + """ + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + def siftkey(ind): + if isinstance(ind, WildTensorIndex): + if ind.ignore_updown: + return "wild, updown" + else: + return "wild" + else: + return "nonwild" + + indices_sifted = sift(self.indices, siftkey) + + matched_indices = [] + expr_indices_remaining = expr.get_indices() + for ind in indices_sifted["nonwild"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + if e_ind in matched_indices: + continue + m = ind.matches(e_ind) + if m is not None: + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices] + for ind in indices_sifted["wild"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + m = ind.matches(e_ind) + if m is not None: + if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]: + return None + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices] + for ind in indices_sifted["wild, updown"]: + matched_this_ind = False + for e_ind in expr_indices_remaining: + m = ind.matches(e_ind) + if m is not None: + if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]: + return None + matched_this_ind = True + repl_dict.update(m) + matched_indices.append(e_ind) + break + if not matched_this_ind: + return None + + if len(matched_indices) < len(self.indices): + return None + else: + return repl_dict + +class WildTensorIndex(TensorIndex): + """ + A wild object that matches TensorIndex instances. + + Examples + ======== + >>> from sympy.tensor.tensor import TensorIndex, TensorIndexType, WildTensorIndex + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex("p", R3) + + By default, covariant indices only match with covariant indices (and + similarly for contravariant) + + >>> q = WildTensorIndex("q", R3) + >>> (q).matches(p) + {q: p} + >>> (q).matches(-p) + + If you want matching to ignore whether the index is co/contra-variant, set + ignore_updown=True + + >>> r = WildTensorIndex("r", R3, ignore_updown=True) + >>> (r).matches(-p) + {r: -p} + >>> (r).matches(p) + {r: p} + + Parameters + ========== + name : name of the index (string), or ``True`` if you want it to be + automatically assigned + tensor_index_type : ``TensorIndexType`` of the index + is_up : flag for contravariant index (is_up=True by default) + ignore_updown : bool, Whether this should match both co- and contra-variant + indices (default:False) + """ + def __new__(cls, name, tensor_index_type, is_up=True, ignore_updown=False): + if isinstance(name, str): + name_symbol = Symbol(name) + elif isinstance(name, Symbol): + name_symbol = name + elif name is True: + name = "_i{}".format(len(tensor_index_type._autogenerated)) + name_symbol = Symbol(name) + tensor_index_type._autogenerated.append(name_symbol) + else: + raise ValueError("invalid name") + + is_up = sympify(is_up) + ignore_updown = sympify(ignore_updown) + return Basic.__new__(cls, name_symbol, tensor_index_type, is_up, ignore_updown) + + @property + def ignore_updown(self): + return self.args[3] + + def __neg__(self): + t1 = WildTensorIndex(self.name, self.tensor_index_type, + (not self.is_up), self.ignore_updown) + return t1 + + def matches(self, expr, repl_dict=None, old=False): + if not isinstance(expr, TensorIndex): + return None + if self.tensor_index_type != expr.tensor_index_type: + return None + if not self.ignore_updown: + if self.is_up != expr.is_up: + return None + + if repl_dict is None: + repl_dict = {} + else: + repl_dict = repl_dict.copy() + + repl_dict[self] = expr + return repl_dict + + +class _WildTensExpr(Basic): + """ + INTERNAL USE ONLY + + This is an object that helps with replacement of WildTensors in expressions. + When this object is set as the tensor_head of a WildTensor, it replaces the + WildTensor by a TensExpr (passed when initializing this object). + + Examples + ======== + >>> from sympy.tensor.tensor import WildTensorHead, TensorIndex, TensorHead, TensorIndexType + >>> W = WildTensorHead("W") + >>> R3 = TensorIndexType('R3', dim=3) + >>> p = TensorIndex('p', R3) + >>> q = TensorIndex('q', R3) + >>> K = TensorHead('K', [R3]) + >>> print( ( K(p) ).replace( W(p), W(q)*W(-q)*W(p) ) ) + K(R_0)*K(-R_0)*K(p) + + """ + def __init__(self, expr): + if not isinstance(expr, TensExpr): + raise TypeError("_WildTensExpr expects a TensExpr as argument") + self.expr = expr + + def __call__(self, *indices): + return self.expr._replace_indices(dict(zip(self.expr.get_free_indices(), indices))) + + def __neg__(self): + return self.func(self.expr*S.NegativeOne) + + def __abs__(self): + raise NotImplementedError + + def __add__(self, other): + if other.func != self.func: + raise TypeError(f"Cannot add {self.func} to {other.func}") + return self.func(self.expr+other.expr) + + def __radd__(self, other): + if other.func != self.func: + raise TypeError(f"Cannot add {self.func} to {other.func}") + return self.func(other.expr+self.expr) + + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return other + (-self) + + def __mul__(self, other): + raise NotImplementedError + + def __rmul__(self, other): + raise NotImplementedError + + def __truediv__(self, other): + raise NotImplementedError + + def __rtruediv__(self, other): + raise NotImplementedError + + def __pow__(self, other): + raise NotImplementedError + + def __rpow__(self, other): + raise NotImplementedError + + +def canon_bp(p): + """ + Butler-Portugal canonicalization. See ``tensor_can.py`` from the + combinatorics module for the details. + """ + if isinstance(p, TensExpr): + return p.canon_bp() + return p + + +def tensor_mul(*a): + """ + product of tensors + """ + if not a: + return TensMul.from_data(S.One, [], [], []) + t = a[0] + for tx in a[1:]: + t = t*tx + return t + + +def riemann_cyclic_replace(t_r): + """ + replace Riemann tensor with an equivalent expression + + ``R(m,n,p,q) -> 2/3*R(m,n,p,q) - 1/3*R(m,q,n,p) + 1/3*R(m,p,n,q)`` + + """ + free = sorted(t_r.free, key=lambda x: x[1]) + m, n, p, q = [x[0] for x in free] + t0 = t_r*Rational(2, 3) + t1 = -t_r.substitute_indices((m,m),(n,q),(p,n),(q,p))*Rational(1, 3) + t2 = t_r.substitute_indices((m,m),(n,p),(p,n),(q,q))*Rational(1, 3) + t3 = t0 + t1 + t2 + return t3 + +def riemann_cyclic(t2): + """ + Replace each Riemann tensor with an equivalent expression + satisfying the cyclic identity. + + This trick is discussed in the reference guide to Cadabra. + + Examples + ======== + + >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, riemann_cyclic, TensorSymmetry + >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L') + >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz) + >>> R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann()) + >>> t = R(i,j,k,l)*(R(-i,-j,-k,-l) - 2*R(-i,-k,-j,-l)) + >>> riemann_cyclic(t) + 0 + """ + t2 = t2.expand() + if isinstance(t2, (TensMul, Tensor)): + args = [t2] + else: + args = t2.args + a1 = [x.split() for x in args] + a2 = [[riemann_cyclic_replace(tx) for tx in y] for y in a1] + a3 = [tensor_mul(*v) for v in a2] + t3 = TensAdd(*a3).doit(deep=False) + if not t3: + return t3 + else: + return canon_bp(t3) + + +def get_lines(ex, index_type): + """ + Returns ``(lines, traces, rest)`` for an index type, + where ``lines`` is the list of list of positions of a matrix line, + ``traces`` is the list of list of traced matrix lines, + ``rest`` is the rest of the elements of the tensor. + """ + def _join_lines(a): + i = 0 + while i < len(a): + x = a[i] + xend = x[-1] + xstart = x[0] + hit = True + while hit: + hit = False + for j in range(i + 1, len(a)): + if j >= len(a): + break + if a[j][0] == xend: + hit = True + x.extend(a[j][1:]) + xend = x[-1] + a.pop(j) + continue + if a[j][0] == xstart: + hit = True + a[i] = reversed(a[j][1:]) + x + x = a[i] + xstart = a[i][0] + a.pop(j) + continue + if a[j][-1] == xend: + hit = True + x.extend(reversed(a[j][:-1])) + xend = x[-1] + a.pop(j) + continue + if a[j][-1] == xstart: + hit = True + a[i] = a[j][:-1] + x + x = a[i] + xstart = x[0] + a.pop(j) + continue + i += 1 + return a + + arguments = ex.args + dt = {} + for c in ex.args: + if not isinstance(c, TensExpr): + continue + if c in dt: + continue + index_types = c.index_types + a = [] + for i in range(len(index_types)): + if index_types[i] is index_type: + a.append(i) + if len(a) > 2: + raise ValueError('at most two indices of type %s allowed' % index_type) + if len(a) == 2: + dt[c] = a + #dum = ex.dum + lines = [] + traces = [] + traces1 = [] + #indices_to_args_pos = ex._get_indices_to_args_pos() + # TODO: add a dum_to_components_map ? + for p0, p1, c0, c1 in ex.dum_in_args: + if arguments[c0] not in dt: + continue + if c0 == c1: + traces.append([c0]) + continue + ta0 = dt[arguments[c0]] + ta1 = dt[arguments[c1]] + if p0 not in ta0: + continue + if ta0.index(p0) == ta1.index(p1): + # case gamma(i,s0,-s1) in c0, gamma(j,-s0,s2) in c1; + # to deal with this case one could add to the position + # a flag for transposition; + # one could write [(c0, False), (c1, True)] + raise NotImplementedError + # if p0 == ta0[1] then G in pos c0 is mult on the right by G in c1 + # if p0 == ta0[0] then G in pos c1 is mult on the right by G in c0 + ta0 = dt[arguments[c0]] + b0, b1 = (c0, c1) if p0 == ta0[1] else (c1, c0) + lines1 = lines.copy() + for line in lines: + if line[-1] == b0: + if line[0] == b1: + n = line.index(min(line)) + traces1.append(line) + traces.append(line[n:] + line[:n]) + else: + line.append(b1) + break + elif line[0] == b1: + line.insert(0, b0) + break + else: + lines1.append([b0, b1]) + + lines = [x for x in lines1 if x not in traces1] + lines = _join_lines(lines) + rest = [] + for line in lines: + for y in line: + rest.append(y) + for line in traces: + for y in line: + rest.append(y) + rest = [x for x in range(len(arguments)) if x not in rest] + + return lines, traces, rest + + +def get_free_indices(t): + if not isinstance(t, TensExpr): + return () + return t.get_free_indices() + + +def get_indices(t): + if not isinstance(t, TensExpr): + return () + return t.get_indices() + +def get_dummy_indices(t): + if not isinstance(t, TensExpr): + return () + inds = t.get_indices() + free = t.get_free_indices() + return [i for i in inds if i not in free] + +def get_index_structure(t): + if isinstance(t, TensExpr): + return t._index_structure + return _IndexStructure([], [], [], []) + + +def get_coeff(t): + if isinstance(t, Tensor): + return S.One + if isinstance(t, TensMul): + return t.coeff + if isinstance(t, TensExpr): + raise ValueError("no coefficient associated to this tensor expression") + return t + +def contract_metric(t, g): + if isinstance(t, TensExpr): + return t.contract_metric(g) + return t + +def perm2tensor(t, g, is_canon_bp=False): + """ + Returns the tensor corresponding to the permutation ``g`` + + For further details, see the method in ``TIDS`` with the same name. + """ + if not isinstance(t, TensExpr): + return t + elif isinstance(t, (Tensor, TensMul)): + nim = get_index_structure(t).perm2tensor(g, is_canon_bp=is_canon_bp) + res = t._set_new_index_structure(nim, is_canon_bp=is_canon_bp) + if g[-1] != len(g) - 1: + return -res + + return res + raise NotImplementedError() + + +def substitute_indices(t, *index_tuples): + if not isinstance(t, TensExpr): + return t + return t.substitute_indices(*index_tuples) + + +def _get_wilds(expr): + return list(expr.atoms(Wild, WildFunction, WildTensor, WildTensorIndex, WildTensorHead)) + + +def get_postprocessor(cls): + def _postprocessor(expr): + tens_class = {Mul: TensMul, Add: TensAdd}[cls] + if any(isinstance(a, TensExpr) for a in expr.args): + return tens_class(*expr.args) + else: + return expr + + return _postprocessor + +Basic._constructor_postprocessor_mapping[TensExpr] = { + "Mul": [get_postprocessor(Mul)], +} diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__init__.py b/phivenv/Lib/site-packages/sympy/tensor/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..882a1e06dffea04e900fb81b42145afa0280b2a1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..079aaaa1e63fd232838340628e54793acf086d8a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_index_methods.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_index_methods.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d394fe6bad1c54eae772880f64c6ed7dad87360 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_index_methods.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_indexed.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_indexed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67072f918c7754015b4e21f71a87f960e78054ea Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_indexed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_printing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_printing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f7b4b3e10bcd37cb4c2d91a0d5cc054976117d8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_printing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d29da7c0df21372d9d40fc57043c2142dcacf22 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor_element.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor_element.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e3ee9bcff6b4e6d038b4541840b32e22aa0885 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor_element.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d7a78d915cbdae35799e9d4c1b20ac4568e1118 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/tensor/tests/__pycache__/test_tensor_operators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/test_functions.py b/phivenv/Lib/site-packages/sympy/tensor/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..ae40865d1bddffaa976dc3d94ae1ef1b6c97ca35 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/tests/test_functions.py @@ -0,0 +1,57 @@ +from sympy.tensor.functions import TensorProduct +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array import Array +from sympy.abc import x, y, z +from sympy.abc import i, j, k, l + + +A = MatrixSymbol("A", 3, 3) +B = MatrixSymbol("B", 3, 3) +C = MatrixSymbol("C", 3, 3) + + +def test_TensorProduct_construction(): + assert TensorProduct(3, 4) == 12 + assert isinstance(TensorProduct(A, A), TensorProduct) + + expr = TensorProduct(TensorProduct(x, y), z) + assert expr == x*y*z + + expr = TensorProduct(TensorProduct(A, B), C) + assert expr == TensorProduct(A, B, C) + + expr = TensorProduct(Matrix.eye(2), Array([[0, -1], [1, 0]])) + assert expr == Array([ + [ + [[0, -1], [1, 0]], + [[0, 0], [0, 0]] + ], + [ + [[0, 0], [0, 0]], + [[0, -1], [1, 0]] + ] + ]) + + +def test_TensorProduct_shape(): + + expr = TensorProduct(3, 4, evaluate=False) + assert expr.shape == () + assert expr.rank() == 0 + + expr = TensorProduct(Array([1, 2]), Array([x, y]), evaluate=False) + assert expr.shape == (2, 2) + assert expr.rank() == 2 + expr = TensorProduct(expr, expr, evaluate=False) + assert expr.shape == (2, 2, 2, 2) + assert expr.rank() == 4 + + expr = TensorProduct(Matrix.eye(2), Array([[0, -1], [1, 0]]), evaluate=False) + assert expr.shape == (2, 2, 2, 2) + assert expr.rank() == 4 + + +def test_TensorProduct_getitem(): + expr = TensorProduct(A, B) + assert expr[i, j, k, l] == A[i, j]*B[k, l] diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/test_index_methods.py b/phivenv/Lib/site-packages/sympy/tensor/tests/test_index_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..df20f7e7c1ab392321e8350b95dd07c5639c1865 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/tests/test_index_methods.py @@ -0,0 +1,227 @@ +from sympy.core import symbols, S, Pow, Function +from sympy.functions import exp +from sympy.testing.pytest import raises +from sympy.tensor.indexed import Idx, IndexedBase +from sympy.tensor.index_methods import IndexConformanceException + +from sympy.tensor.index_methods import (get_contraction_structure, get_indices) + + +def test_trivial_indices(): + x, y = symbols('x y') + assert get_indices(x) == (set(), {}) + assert get_indices(x*y) == (set(), {}) + assert get_indices(x + y) == (set(), {}) + assert get_indices(x**y) == (set(), {}) + + +def test_get_indices_Indexed(): + x = IndexedBase('x') + i, j = Idx('i'), Idx('j') + assert get_indices(x[i, j]) == ({i, j}, {}) + assert get_indices(x[j, i]) == ({j, i}, {}) + + +def test_get_indices_Idx(): + f = Function('f') + i, j = Idx('i'), Idx('j') + assert get_indices(f(i)*j) == ({i, j}, {}) + assert get_indices(f(j, i)) == ({j, i}, {}) + assert get_indices(f(i)*i) == (set(), {}) + + +def test_get_indices_mul(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_indices(x[j]*y[i]) == ({i, j}, {}) + assert get_indices(x[i]*y[j]) == ({i, j}, {}) + + +def test_get_indices_exceptions(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + raises(IndexConformanceException, lambda: get_indices(x[i] + y[j])) + + +def test_scalar_broadcast(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_indices(x[i] + y[i, i]) == ({i}, {}) + assert get_indices(x[i] + y[j, j]) == ({i}, {}) + + +def test_get_indices_add(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + assert get_indices(x[i] + 2*y[i]) == ({i}, {}) + assert get_indices(y[i] + 2*A[i, j]*x[j]) == ({i}, {}) + assert get_indices(y[i] + 2*(x[i] + A[i, j]*x[j])) == ({i}, {}) + assert get_indices(y[i] + x[i]*(A[j, j] + 1)) == ({i}, {}) + assert get_indices( + y[i] + x[i]*x[j]*(y[j] + A[j, k]*x[k])) == ({i}, {}) + + +def test_get_indices_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + assert get_indices(Pow(x[i], y[j])) == ({i, j}, {}) + assert get_indices(Pow(x[i, k], y[j, k])) == ({i, j, k}, {}) + assert get_indices(Pow(A[i, k], y[k] + A[k, j]*x[j])) == ({i, k}, {}) + assert get_indices(Pow(2, x[i])) == get_indices(exp(x[i])) + + # test of a design decision, this may change: + assert get_indices(Pow(x[i], 2)) == ({i}, {}) + + +def test_get_contraction_structure_basic(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_contraction_structure(x[i]*y[j]) == {None: {x[i]*y[j]}} + assert get_contraction_structure(x[i] + y[j]) == {None: {x[i], y[j]}} + assert get_contraction_structure(x[i]*y[i]) == {(i,): {x[i]*y[i]}} + assert get_contraction_structure( + 1 + x[i]*y[i]) == {None: {S.One}, (i,): {x[i]*y[i]}} + assert get_contraction_structure(x[i]**y[i]) == {None: {x[i]**y[i]}} + + +def test_get_contraction_structure_complex(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + expr1 = y[i] + A[i, j]*x[j] + d1 = {None: {y[i]}, (j,): {A[i, j]*x[j]}} + assert get_contraction_structure(expr1) == d1 + expr2 = expr1*A[k, i] + x[k] + d2 = {None: {x[k]}, (i,): {expr1*A[k, i]}, expr1*A[k, i]: [d1]} + assert get_contraction_structure(expr2) == d2 + + +def test_contraction_structure_simple_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + ii_jj = x[i, i]**y[j, j] + assert get_contraction_structure(ii_jj) == { + None: {ii_jj}, + ii_jj: [ + {(i,): {x[i, i]}}, + {(j,): {y[j, j]}} + ] + } + + ii_jk = x[i, i]**y[j, k] + assert get_contraction_structure(ii_jk) == { + None: {x[i, i]**y[j, k]}, + x[i, i]**y[j, k]: [ + {(i,): {x[i, i]}} + ] + } + + +def test_contraction_structure_Mul_and_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + + i_ji = x[i]**(y[j]*x[i]) + assert get_contraction_structure(i_ji) == {None: {i_ji}} + ij_i = (x[i]*y[j])**(y[i]) + assert get_contraction_structure(ij_i) == {None: {ij_i}} + j_ij_i = x[j]*(x[i]*y[j])**(y[i]) + assert get_contraction_structure(j_ij_i) == {(j,): {j_ij_i}} + j_i_ji = x[j]*x[i]**(y[j]*x[i]) + assert get_contraction_structure(j_i_ji) == {(j,): {j_i_ji}} + ij_exp_kki = x[i]*y[j]*exp(y[i]*y[k, k]) + result = get_contraction_structure(ij_exp_kki) + expected = { + (i,): {ij_exp_kki}, + ij_exp_kki: [{ + None: {exp(y[i]*y[k, k])}, + exp(y[i]*y[k, k]): [{ + None: {y[i]*y[k, k]}, + y[i]*y[k, k]: [{(k,): {y[k, k]}}] + }]} + ] + } + assert result == expected + + +def test_contraction_structure_Add_in_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + s_ii_jj_s = (1 + x[i, i])**(1 + y[j, j]) + expected = { + None: {s_ii_jj_s}, + s_ii_jj_s: [ + {None: {S.One}, (i,): {x[i, i]}}, + {None: {S.One}, (j,): {y[j, j]}} + ] + } + result = get_contraction_structure(s_ii_jj_s) + assert result == expected + + s_ii_jk_s = (1 + x[i, i]) ** (1 + y[j, k]) + expected_2 = { + None: {(x[i, i] + 1)**(y[j, k] + 1)}, + s_ii_jk_s: [ + {None: {S.One}, (i,): {x[i, i]}} + ] + } + result_2 = get_contraction_structure(s_ii_jk_s) + assert result_2 == expected_2 + + +def test_contraction_structure_Pow_in_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i, j, k = Idx('i'), Idx('j'), Idx('k') + ii_jj_kk = x[i, i]**y[j, j]**z[k, k] + expected = { + None: {ii_jj_kk}, + ii_jj_kk: [ + {(i,): {x[i, i]}}, + { + None: {y[j, j]**z[k, k]}, + y[j, j]**z[k, k]: [ + {(j,): {y[j, j]}}, + {(k,): {z[k, k]}} + ] + } + ] + } + assert get_contraction_structure(ii_jj_kk) == expected + + +def test_ufunc_support(): + f = Function('f') + g = Function('g') + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + a = symbols('a') + + assert get_indices(f(x[i])) == ({i}, {}) + assert get_indices(f(x[i], y[j])) == ({i, j}, {}) + assert get_indices(f(y[i])*g(x[i])) == (set(), {}) + assert get_indices(f(a, x[i])) == ({i}, {}) + assert get_indices(f(a, y[i], x[j])*g(x[i])) == ({j}, {}) + assert get_indices(g(f(x[i]))) == ({i}, {}) + + assert get_contraction_structure(f(x[i])) == {None: {f(x[i])}} + assert get_contraction_structure( + f(y[i])*g(x[i])) == {(i,): {f(y[i])*g(x[i])}} + assert get_contraction_structure( + f(y[i])*g(f(x[i]))) == {(i,): {f(y[i])*g(f(x[i]))}} + assert get_contraction_structure( + f(x[j], y[i])*g(x[i])) == {(i,): {f(x[j], y[i])*g(x[i])}} diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/test_indexed.py b/phivenv/Lib/site-packages/sympy/tensor/tests/test_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..689ec932c8fcefe0a24de289dd2ffd6820c63f19 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/tests/test_indexed.py @@ -0,0 +1,511 @@ +from sympy.core import symbols, Symbol, Tuple, oo, Dummy +from sympy.tensor.indexed import IndexException +from sympy.testing.pytest import raises +from sympy.utilities.iterables import iterable + +# import test: +from sympy.concrete.summations import Sum +from sympy.core.function import Function, Subs, Derivative +from sympy.core.relational import (StrictLessThan, GreaterThan, + StrictGreaterThan, LessThan) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.series.order import Order +from sympy.sets.fancysets import Range +from sympy.tensor.indexed import IndexedBase, Idx, Indexed + + +def test_Idx_construction(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i) != Idx(i, 1) + assert Idx(i, a) == Idx(i, (0, a - 1)) + assert Idx(i, oo) == Idx(i, (0, oo)) + + x = symbols('x', integer=False) + raises(TypeError, lambda: Idx(x)) + raises(TypeError, lambda: Idx(0.5)) + raises(TypeError, lambda: Idx(i, x)) + raises(TypeError, lambda: Idx(i, 0.5)) + raises(TypeError, lambda: Idx(i, (x, 5))) + raises(TypeError, lambda: Idx(i, (2, x))) + raises(TypeError, lambda: Idx(i, (2, 3.5))) + + +def test_Idx_properties(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i).is_integer + assert Idx(i).name == 'i' + assert Idx(i + 2).name == 'i + 2' + assert Idx('foo').name == 'foo' + + +def test_Idx_bounds(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i).lower is None + assert Idx(i).upper is None + assert Idx(i, a).lower == 0 + assert Idx(i, a).upper == a - 1 + assert Idx(i, 5).lower == 0 + assert Idx(i, 5).upper == 4 + assert Idx(i, oo).lower == 0 + assert Idx(i, oo).upper is oo + assert Idx(i, (a, b)).lower == a + assert Idx(i, (a, b)).upper == b + assert Idx(i, (1, 5)).lower == 1 + assert Idx(i, (1, 5)).upper == 5 + assert Idx(i, (-oo, oo)).lower is -oo + assert Idx(i, (-oo, oo)).upper is oo + + +def test_Idx_fixed_bounds(): + i, a, b, x = symbols('i a b x', integer=True) + assert Idx(x).lower is None + assert Idx(x).upper is None + assert Idx(x, a).lower == 0 + assert Idx(x, a).upper == a - 1 + assert Idx(x, 5).lower == 0 + assert Idx(x, 5).upper == 4 + assert Idx(x, oo).lower == 0 + assert Idx(x, oo).upper is oo + assert Idx(x, (a, b)).lower == a + assert Idx(x, (a, b)).upper == b + assert Idx(x, (1, 5)).lower == 1 + assert Idx(x, (1, 5)).upper == 5 + assert Idx(x, (-oo, oo)).lower is -oo + assert Idx(x, (-oo, oo)).upper is oo + + +def test_Idx_inequalities(): + i14 = Idx("i14", (1, 4)) + i79 = Idx("i79", (7, 9)) + i46 = Idx("i46", (4, 6)) + i35 = Idx("i35", (3, 5)) + + assert i14 <= 5 + assert i14 < 5 + assert not (i14 >= 5) + assert not (i14 > 5) + + assert 5 >= i14 + assert 5 > i14 + assert not (5 <= i14) + assert not (5 < i14) + + assert LessThan(i14, 5) + assert StrictLessThan(i14, 5) + assert not GreaterThan(i14, 5) + assert not StrictGreaterThan(i14, 5) + + assert i14 <= 4 + assert isinstance(i14 < 4, StrictLessThan) + assert isinstance(i14 >= 4, GreaterThan) + assert not (i14 > 4) + + assert isinstance(i14 <= 1, LessThan) + assert not (i14 < 1) + assert i14 >= 1 + assert isinstance(i14 > 1, StrictGreaterThan) + + assert not (i14 <= 0) + assert not (i14 < 0) + assert i14 >= 0 + assert i14 > 0 + + from sympy.abc import x + + assert isinstance(i14 < x, StrictLessThan) + assert isinstance(i14 > x, StrictGreaterThan) + assert isinstance(i14 <= x, LessThan) + assert isinstance(i14 >= x, GreaterThan) + + assert i14 < i79 + assert i14 <= i79 + assert not (i14 > i79) + assert not (i14 >= i79) + + assert i14 <= i46 + assert isinstance(i14 < i46, StrictLessThan) + assert isinstance(i14 >= i46, GreaterThan) + assert not (i14 > i46) + + assert isinstance(i14 < i35, StrictLessThan) + assert isinstance(i14 > i35, StrictGreaterThan) + assert isinstance(i14 <= i35, LessThan) + assert isinstance(i14 >= i35, GreaterThan) + + iNone1 = Idx("iNone1") + iNone2 = Idx("iNone2") + + assert isinstance(iNone1 < iNone2, StrictLessThan) + assert isinstance(iNone1 > iNone2, StrictGreaterThan) + assert isinstance(iNone1 <= iNone2, LessThan) + assert isinstance(iNone1 >= iNone2, GreaterThan) + + +def test_Idx_inequalities_current_fails(): + i14 = Idx("i14", (1, 4)) + + assert S(5) >= i14 + assert S(5) > i14 + assert not (S(5) <= i14) + assert not (S(5) < i14) + + +def test_Idx_func_args(): + i, a, b = symbols('i a b', integer=True) + ii = Idx(i) + assert ii.func(*ii.args) == ii + ii = Idx(i, a) + assert ii.func(*ii.args) == ii + ii = Idx(i, (a, b)) + assert ii.func(*ii.args) == ii + + +def test_Idx_subs(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i, a).subs(a, b) == Idx(i, b) + assert Idx(i, a).subs(i, b) == Idx(b, a) + + assert Idx(i).subs(i, 2) == Idx(2) + assert Idx(i, a).subs(a, 2) == Idx(i, 2) + assert Idx(i, (a, b)).subs(i, 2) == Idx(2, (a, b)) + + +def test_IndexedBase_sugar(): + i, j = symbols('i j', integer=True) + a = symbols('a') + A1 = Indexed(a, i, j) + A2 = IndexedBase(a) + assert A1 == A2[i, j] + assert A1 == A2[(i, j)] + assert A1 == A2[[i, j]] + assert A1 == A2[Tuple(i, j)] + assert all(a.is_Integer for a in A2[1, 0].args[1:]) + + +def test_IndexedBase_subs(): + i = symbols('i', integer=True) + a, b = symbols('a b') + A = IndexedBase(a) + B = IndexedBase(b) + assert A[i] == B[i].subs(b, a) + C = {1: 2} + assert C[1] == A[1].subs(A, C) + + +def test_IndexedBase_shape(): + i, j, m, n = symbols('i j m n', integer=True) + a = IndexedBase('a', shape=(m, m)) + b = IndexedBase('a', shape=(m, n)) + assert b.shape == Tuple(m, n) + assert a[i, j] != b[i, j] + assert a[i, j] == b[i, j].subs(n, m) + assert b.func(*b.args) == b + assert b[i, j].func(*b[i, j].args) == b[i, j] + raises(IndexException, lambda: b[i]) + raises(IndexException, lambda: b[i, i, j]) + F = IndexedBase("F", shape=m) + assert F.shape == Tuple(m) + assert F[i].subs(i, j) == F[j] + raises(IndexException, lambda: F[i, j]) + + +def test_IndexedBase_assumptions(): + i = Symbol('i', integer=True) + a = Symbol('a') + A = IndexedBase(a, positive=True) + for c in (A, A[i]): + assert c.is_real + assert c.is_complex + assert not c.is_imaginary + assert c.is_nonnegative + assert c.is_nonzero + assert c.is_commutative + assert log(exp(c)) == c + + assert A != IndexedBase(a) + assert A == IndexedBase(a, positive=True, real=True) + assert A[i] != Indexed(a, i) + + +def test_IndexedBase_assumptions_inheritance(): + I = Symbol('I', integer=True) + I_inherit = IndexedBase(I) + I_explicit = IndexedBase('I', integer=True) + + assert I_inherit.is_integer + assert I_explicit.is_integer + assert I_inherit.label.is_integer + assert I_explicit.label.is_integer + assert I_inherit == I_explicit + + +def test_issue_17652(): + """Regression test issue #17652. + + IndexedBase.label should not upcast subclasses of Symbol + """ + class SubClass(Symbol): + pass + + x = SubClass('X') + assert type(x) == SubClass + base = IndexedBase(x) + assert type(x) == SubClass + assert type(base.label) == SubClass + + +def test_Indexed_constructor(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, j) + assert A == Indexed(Symbol('A'), i, j) + assert A == Indexed(IndexedBase('A'), i, j) + raises(TypeError, lambda: Indexed(A, i, j)) + raises(IndexException, lambda: Indexed("A")) + assert A.free_symbols == {A, A.base.label, i, j} + + +def test_Indexed_func_args(): + i, j = symbols('i j', integer=True) + a = symbols('a') + A = Indexed(a, i, j) + assert A == A.func(*A.args) + + +def test_Indexed_subs(): + i, j, k = symbols('i j k', integer=True) + a, b = symbols('a b') + A = IndexedBase(a) + B = IndexedBase(b) + assert A[i, j] == B[i, j].subs(b, a) + assert A[i, j] == A[i, k].subs(k, j) + + +def test_Indexed_properties(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, j) + assert A.name == 'A[i, j]' + assert A.rank == 2 + assert A.indices == (i, j) + assert A.base == IndexedBase('A') + assert A.ranges == [None, None] + raises(IndexException, lambda: A.shape) + + n, m = symbols('n m', integer=True) + assert Indexed('A', Idx( + i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)] + assert Indexed('A', Idx(i, m), Idx(j, n)).shape == Tuple(m, n) + raises(IndexException, lambda: Indexed("A", Idx(i, m), Idx(j)).shape) + + +def test_Indexed_shape_precedence(): + i, j = symbols('i j', integer=True) + o, p = symbols('o p', integer=True) + n, m = symbols('n m', integer=True) + a = IndexedBase('a', shape=(o, p)) + assert a.shape == Tuple(o, p) + assert Indexed( + a, Idx(i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)] + assert Indexed(a, Idx(i, m), Idx(j, n)).shape == Tuple(o, p) + assert Indexed( + a, Idx(i, m), Idx(j)).ranges == [Tuple(0, m - 1), (None, None)] + assert Indexed(a, Idx(i, m), Idx(j)).shape == Tuple(o, p) + + +def test_complex_indices(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, i + j) + assert A.rank == 2 + assert A.indices == (i, i + j) + + +def test_not_interable(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, i + j) + assert not iterable(A) + + +def test_Indexed_coeff(): + N = Symbol('N', integer=True) + len_y = N + i = Idx('i', len_y-1) + y = IndexedBase('y', shape=(len_y,)) + a = (1/y[i+1]*y[i]).coeff(y[i]) + b = (y[i]/y[i+1]).coeff(y[i]) + assert a == b + + +def test_differentiation(): + from sympy.functions.special.tensor_functions import KroneckerDelta + i, j, k, l = symbols('i j k l', cls=Idx) + a = symbols('a') + m, n = symbols("m, n", integer=True, finite=True) + assert m.is_real + h, L = symbols('h L', cls=IndexedBase) + hi, hj = h[i], h[j] + + expr = hi + assert expr.diff(hj) == KroneckerDelta(i, j) + assert expr.diff(hi) == KroneckerDelta(i, i) + + expr = S(2) * hi + assert expr.diff(hj) == S(2) * KroneckerDelta(i, j) + assert expr.diff(hi) == S(2) * KroneckerDelta(i, i) + assert expr.diff(a) is S.Zero + + assert Sum(expr, (i, -oo, oo)).diff(hj) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo)) + assert Sum(expr.diff(hj), (i, -oo, oo)) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo)) + assert Sum(expr, (i, -oo, oo)).diff(hj).doit() == 2 + + assert Sum(expr.diff(hi), (i, -oo, oo)).doit() == Sum(2, (i, -oo, oo)).doit() + assert Sum(expr, (i, -oo, oo)).diff(hi).doit() is oo + + expr = a * hj * hj / S(2) + assert expr.diff(hi) == a * h[j] * KroneckerDelta(i, j) + assert expr.diff(a) == hj * hj / S(2) + assert expr.diff(a, 2) is S.Zero + + assert Sum(expr, (i, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo)) + assert Sum(expr.diff(hi), (i, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo)) + assert Sum(expr, (i, -oo, oo)).diff(hi).doit() == a*h[j] + + assert Sum(expr, (j, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo)) + assert Sum(expr.diff(hi), (j, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo)) + assert Sum(expr, (j, -oo, oo)).diff(hi).doit() == a*h[i] + + expr = a * sin(hj * hj) + assert expr.diff(hi) == 2*a*cos(hj * hj) * hj * KroneckerDelta(i, j) + assert expr.diff(hj) == 2*a*cos(hj * hj) * hj + + expr = a * L[i, j] * h[j] + assert expr.diff(hi) == a*L[i, j]*KroneckerDelta(i, j) + assert expr.diff(hj) == a*L[i, j] + assert expr.diff(L[i, j]) == a*h[j] + assert expr.diff(L[k, l]) == a*KroneckerDelta(i, k)*KroneckerDelta(j, l)*h[j] + assert expr.diff(L[i, l]) == a*KroneckerDelta(j, l)*h[j] + + assert Sum(expr, (j, -oo, oo)).diff(L[k, l]) == Sum(a * KroneckerDelta(i, k) * KroneckerDelta(j, l) * h[j], (j, -oo, oo)) + assert Sum(expr, (j, -oo, oo)).diff(L[k, l]).doit() == a * KroneckerDelta(i, k) * h[l] + + assert h[m].diff(h[m]) == 1 + assert h[m].diff(h[n]) == KroneckerDelta(m, n) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (m, -oo, oo)) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]).doit() == a + assert Sum(a*h[m], (n, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (n, -oo, oo)) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[m]).doit() == oo*a + + +def test_indexed_series(): + A = IndexedBase("A") + i = symbols("i", integer=True) + assert sin(A[i]).series(A[i]) == A[i] - A[i]**3/6 + A[i]**5/120 + Order(A[i]**6, A[i]) + + +def test_indexed_is_constant(): + A = IndexedBase("A") + i, j, k = symbols("i,j,k") + assert not A[i].is_constant() + assert A[i].is_constant(j) + assert not A[1+2*i, k].is_constant() + assert not A[1+2*i, k].is_constant(i) + assert A[1+2*i, k].is_constant(j) + assert not A[1+2*i, k].is_constant(k) + + +def test_issue_12533(): + d = IndexedBase('d') + assert IndexedBase(range(5)) == Range(0, 5, 1) + assert d[0].subs(Symbol("d"), range(5)) == 0 + assert d[0].subs(d, range(5)) == 0 + assert d[1].subs(d, range(5)) == 1 + assert Indexed(Range(5), 2) == 2 + + +def test_issue_12780(): + n = symbols("n") + i = Idx("i", (0, n)) + raises(TypeError, lambda: i.subs(n, 1.5)) + + +def test_issue_18604(): + m = symbols("m") + assert Idx("i", m).name == 'i' + assert Idx("i", m).lower == 0 + assert Idx("i", m).upper == m - 1 + m = symbols("m", real=False) + raises(TypeError, lambda: Idx("i", m)) + +def test_Subs_with_Indexed(): + A = IndexedBase("A") + i, j, k = symbols("i,j,k") + x, y, z = symbols("x,y,z") + f = Function("f") + + assert Subs(A[i], A[i], A[j]).diff(A[j]) == 1 + assert Subs(A[i], A[i], x).diff(A[i]) == 0 + assert Subs(A[i], A[i], x).diff(A[j]) == 0 + assert Subs(A[i], A[i], x).diff(x) == 1 + assert Subs(A[i], A[i], x).diff(y) == 0 + assert Subs(A[i], A[i], A[j]).diff(A[k]) == KroneckerDelta(j, k) + assert Subs(x, x, A[i]).diff(A[j]) == KroneckerDelta(i, j) + assert Subs(f(A[i]), A[i], x).diff(A[j]) == 0 + assert Subs(f(A[i]), A[i], A[k]).diff(A[j]) == Derivative(f(A[k]), A[k])*KroneckerDelta(j, k) + assert Subs(x, x, A[i]**2).diff(A[j]) == 2*KroneckerDelta(i, j)*A[i] + assert Subs(A[i], A[i], A[j]**2).diff(A[k]) == 2*KroneckerDelta(j, k)*A[j] + + assert Subs(A[i]*x, x, A[i]).diff(A[i]) == 2*A[i] + assert Subs(A[i]*x, x, A[i]).diff(A[j]) == 2*A[i]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[j]).diff(A[i]) == A[j] + A[i]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[j]).diff(A[j]) == A[i] + A[j]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[i]).diff(A[k]) == 2*A[i]*KroneckerDelta(i, k) + assert Subs(A[i]*x, x, A[j]).diff(A[k]) == KroneckerDelta(i, k)*A[j] + KroneckerDelta(j, k)*A[i] + + assert Subs(A[i]*x, A[i], x).diff(A[i]) == 0 + assert Subs(A[i]*x, A[i], x).diff(A[j]) == 0 + assert Subs(A[i]*x, A[j], x).diff(A[i]) == x + assert Subs(A[i]*x, A[j], x).diff(A[j]) == x*KroneckerDelta(i, j) + assert Subs(A[i]*x, A[i], x).diff(A[k]) == 0 + assert Subs(A[i]*x, A[j], x).diff(A[k]) == x*KroneckerDelta(i, k) + + +def test_complicated_derivative_with_Indexed(): + x, y = symbols("x,y", cls=IndexedBase) + sigma = symbols("sigma") + i, j, k = symbols("i,j,k") + m0,m1,m2,m3,m4,m5 = symbols("m0:6") + f = Function("f") + + expr = f((x[i] - y[i])**2/sigma) + _xi_1 = symbols("xi_1", cls=Dummy) + assert expr.diff(x[m0]).dummy_eq( + (x[i] - y[i])*KroneckerDelta(i, m0)*\ + 2*Subs( + Derivative(f(_xi_1), _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma + ) + assert expr.diff(x[m0]).diff(x[m1]).dummy_eq( + 2*KroneckerDelta(i, m0)*\ + KroneckerDelta(i, m1)*Subs( + Derivative(f(_xi_1), _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma + \ + 4*(x[i] - y[i])**2*KroneckerDelta(i, m0)*KroneckerDelta(i, m1)*\ + Subs( + Derivative(f(_xi_1), _xi_1, _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma**2 + ) + + +def test_IndexedBase_commutative(): + t = IndexedBase('t', commutative=False) + u = IndexedBase('u', commutative=False) + v = IndexedBase('v') + assert t[0]*v[0] == v[0]*t[0] + assert t[0]*u[0] != u[0]*t[0] diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/test_printing.py b/phivenv/Lib/site-packages/sympy/tensor/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3cf7f0591a7012c93354ab7b8d7e010def38bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/tests/test_printing.py @@ -0,0 +1,13 @@ +from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead +from sympy import I + +def test_printing_TensMul(): + R3 = TensorIndexType('R3', dim=3) + p, q = tensor_indices("p q", R3) + K = TensorHead("K", [R3]) + + assert repr(2*K(p)) == "2*K(p)" + assert repr(-K(p)) == "-K(p)" + assert repr(-2*K(p)*K(q)) == "-2*K(p)*K(q)" + assert repr(-I*K(p)) == "-I*K(p)" + assert repr(I*K(p)) == "I*K(p)" diff --git a/phivenv/Lib/site-packages/sympy/tensor/tests/test_tensor.py b/phivenv/Lib/site-packages/sympy/tensor/tests/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..3113f5be9bcd32224f3525b5d831b6d7476c39e3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/tensor/tests/test_tensor.py @@ -0,0 +1,2218 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import expand +from sympy.core.numbers import Integer +from sympy.matrices.dense import (Matrix, eye) +from sympy.tensor.indexed import Indexed +from sympy.combinatorics import Permutation +from sympy.core import S, Rational, Symbol, Basic, Add, Wild, Function +from sympy.core.containers import Tuple +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals import integrate +from sympy.tensor.array import Array +from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, \ + get_symmetric_group_sgs, TensorIndex, tensor_mul, TensAdd, \ + riemann_cyclic_replace, riemann_cyclic, TensMul, tensor_heads, \ + TensorManager, TensExpr, TensorHead, canon_bp, \ + tensorhead, tensorsymmetry, TensorType, substitute_indices, \ + WildTensorIndex, WildTensorHead, _WildTensExpr +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.matrices import diag + +def _is_equal(arg1, arg2): + if isinstance(arg1, TensExpr): + return arg1.equals(arg2) + elif isinstance(arg2, TensExpr): + return arg2.equals(arg1) + return arg1 == arg2 + + +#################### Tests from tensor_can.py ####################### +def test_canonicalize_no_slot_sym(): + # A_d0 * B^d0; T_c = A^d0*B_d0 + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, b, d0, d1 = tensor_indices('a,b,d0,d1', Lorentz) + A, B = tensor_heads('A,B', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(-d0)*B(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*B(-L_0)' + + # A^a * B^b; T_c = T + t = A(a)*B(b) + tc = t.canon_bp() + assert tc == t + # B^b * A^a + t1 = B(b)*A(a) + tc = t1.canon_bp() + assert str(tc) == 'A(a)*B(b)' + + # A symmetric + # A^{b}_{d0}*A^{d0, a}; T_c = A^{a d0}*A{b}_{d0} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(b, -d0)*A(d0, a) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0)*A(b, -L_0)' + + # A^{d1}_{d0}*B^d0*C_d1 + # T_c = A^{d0 d1}*B_d0*C_d1 + B, C = tensor_heads('B,C', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(d1, -d0)*B(d0)*C(-d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_0)*C(-L_1)' + + # A without symmetry + # A^{d1}_{d0}*B^d0*C_d1 ord=[d0,-d0,d1,-d1]; g = [2,1,0,3,4,5] + # T_c = A^{d0 d1}*B_d1*C_d0; can = [0,2,3,1,4,5] + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*B(d0)*C(-d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_1)*C(-L_0)' + + # A, B without symmetry + # A^{d1}_{d0}*B_{d1}^{d0} + # T_c = A^{d0 d1}*B_{d0 d1} + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*B(-d1, d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_0, -L_1)' + # A_{d0}^{d1}*B_{d1}^{d0} + # T_c = A^{d0 d1}*B_{d1 d0} + t = A(-d0, d1)*B(-d1, d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_1, -L_0)' + + # A, B, C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} + # T_c=A^{d0 d1}*B_{a d1}*C_{d0 b} + C = TensorHead('C', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_1)*C(-L_0, -b)' + + # A symmetric, B and C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} + # T_c = A^{d0 d1}*B_{a d0}*C_{d1 b} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_0)*C(-L_1, -b)' + + # A and C symmetric, B without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # T_c = A^{d0 d1}*B_{a d0}*C_{b d1} + C = TensorHead('C', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_0)*C(-b, -L_1)' + +def test_canonicalize_no_dummies(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, b, c, d = tensor_indices('a, b, c, d', Lorentz) + + # A commuting + # A^c A^b A^a + # T_c = A^a A^b A^c + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(c)*A(b)*A(a) + tc = t.canon_bp() + assert str(tc) == 'A(a)*A(b)*A(c)' + + # A anticommuting + # A^c A^b A^a + # T_c = -A^a A^b A^c + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1), 1) + t = A(c)*A(b)*A(a) + tc = t.canon_bp() + assert str(tc) == '-A(a)*A(b)*A(c)' + + # A commuting and symmetric + # A^{b,d}*A^{c,a} + # T_c = A^{a c}*A^{b d} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(b, d)*A(c, a) + tc = t.canon_bp() + assert str(tc) == 'A(a, c)*A(b, d)' + + # A anticommuting and symmetric + # A^{b,d}*A^{c,a} + # T_c = -A^{a c}*A^{b d} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2), 1) + t = A(b, d)*A(c, a) + tc = t.canon_bp() + assert str(tc) == '-A(a, c)*A(b, d)' + + # A^{c,a}*A^{b,d} + # T_c = A^{a c}*A^{b d} + t = A(c, a)*A(b, d) + tc = t.canon_bp() + assert str(tc) == 'A(a, c)*A(b, d)' + +def test_tensorhead_construction_without_symmetry(): + L = TensorIndexType('Lorentz') + A1 = TensorHead('A', [L, L]) + A2 = TensorHead('A', [L, L], TensorSymmetry.no_symmetry(2)) + assert A1 == A2 + A3 = TensorHead('A', [L, L], TensorSymmetry.fully_symmetric(2)) # Symmetric + assert A1 != A3 + +def test_no_metric_symmetry(): + # no metric symmetry; A no symmetry + # A^d1_d0 * A^d0_d1 + # T_c = A^d0_d1 * A^d1_d0 + Lorentz = TensorIndexType('Lorentz', dummy_name='L', metric_symmetry=0) + d0, d1, d2, d3 = tensor_indices('d:4', Lorentz) + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*A(d0, -d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_0)' + + # A^d1_d2 * A^d0_d3 * A^d2_d1 * A^d3_d0 + # T_c = A^d0_d1 * A^d1_d0 * A^d2_d3 * A^d3_d2 + t = A(d1, -d2)*A(d0, -d3)*A(d2, -d1)*A(d3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_0)*A(L_2, -L_3)*A(L_3, -L_2)' + + # A^d0_d2 * A^d1_d3 * A^d3_d0 * A^d2_d1 + # T_c = A^d0_d1 * A^d1_d2 * A^d2_d3 * A^d3_d0 + t = A(d0, -d1)*A(d1, -d2)*A(d2, -d3)*A(d3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_2)*A(L_2, -L_3)*A(L_3, -L_0)' + +def test_canonicalize1(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Lorentz) + + # A_d0*A^d0; ord = [d0,-d0] + # T_c = A^d0*A_d0 + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(-d0)*A(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*A(-L_0)' + + # A commuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0 + # T_c = A^d0*A_d0*A^d1*A_d1*A^d2*A_d2 + t = A(-d0)*A(-d1)*A(-d2)*A(d2)*A(d1)*A(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*A(-L_0)*A(L_1)*A(-L_1)*A(L_2)*A(-L_2)' + + # A anticommuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0 + # T_c 0 + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1), 1) + t = A(-d0)*A(-d1)*A(-d2)*A(d2)*A(d1)*A(d0) + tc = t.canon_bp() + assert tc == 0 + + # A commuting symmetric + # A^{d0 b}*A^a_d1*A^d1_d0 + # T_c = A^{a d0}*A^{b d1}*A_{d0 d1} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d0, b)*A(a, -d1)*A(d1, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0)*A(b, L_1)*A(-L_0, -L_1)' + + # A, B commuting symmetric + # A^{d0 b}*A^d1_d0*B^a_d1 + # T_c = A^{b d0}*A_d0^d1*B^a_d1 + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d0, b)*A(d1, -d0)*B(a, -d1) + tc = t.canon_bp() + assert str(tc) == 'A(b, L_0)*A(-L_0, L_1)*B(a, -L_1)' + + # A commuting symmetric + # A^{d1 d0 b}*A^{a}_{d1 d0}; ord=[a,b, d0,-d0,d1,-d1] + # T_c = A^{a d0 d1}*A^{b}_{d0 d1} + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3)) + t = A(d1, d0, b)*A(a, -d1, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0, L_1)*A(b, -L_0, -L_1)' + + # A^{d3 d0 d2}*A^a0_{d1 d2}*A^d1_d3^a1*A^{a2 a3}_d0 + # T_c = A^{a0 d0 d1}*A^a1_d0^d2*A^{a2 a3 d3}*A_{d1 d2 d3} + t = A(d3, d0, d2)*A(a0, -d1, -d2)*A(d1, -d3, a1)*A(a2, a3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a0, L_0, L_1)*A(a1, -L_0, L_2)*A(a2, a3, L_3)*A(-L_1, -L_2, -L_3)' + + # A commuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # in this esxample and in the next three, + # renaming dummy indices and using symmetry of A, + # T = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + # can = 0 + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3)) + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert tc == 0 + + # A anticommuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1, L_2)*A(-L_0, -L_1, L_3)*B(-L_2, -L_3)' + + # A anticommuting symmetric, B antisymmetric commuting, antisymmetric metric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = -A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + Spinor = TensorIndexType('Spinor', dummy_name='S', metric_symmetry=-1) + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Spinor) + A = TensorHead('A', [Spinor]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Spinor]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == '-A(S_0, S_1, S_2)*A(-S_0, -S_1, S_3)*B(-S_2, -S_3)' + + # A anticommuting symmetric, B antisymmetric anticommuting, + # no metric symmetry + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + Mat = TensorIndexType('Mat', metric_symmetry=0, dummy_name='M') + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Mat) + A = TensorHead('A', [Mat]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Mat]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == 'A(M_0, M_1, M_2)*A(-M_0, -M_1, -M_3)*B(-M_2, M_3)' + + # Gamma anticommuting + # Gamma_{mu nu} * gamma^rho * Gamma^{nu mu alpha} + # T_c = -Gamma^{mu nu} * gamma^rho * Gamma_{alpha mu nu} + alpha, beta, gamma, mu, nu, rho = \ + tensor_indices('alpha,beta,gamma,mu,nu,rho', Lorentz) + Gamma = TensorHead('Gamma', [Lorentz], + TensorSymmetry.fully_symmetric(1), 2) + Gamma2 = TensorHead('Gamma', [Lorentz]*2, + TensorSymmetry.fully_symmetric(-2), 2) + Gamma3 = TensorHead('Gamma', [Lorentz]*3, + TensorSymmetry.fully_symmetric(-3), 2) + t = Gamma2(-mu, -nu)*Gamma(rho)*Gamma3(nu, mu, alpha) + tc = t.canon_bp() + assert str(tc) == '-Gamma(L_0, L_1)*Gamma(rho)*Gamma(alpha, -L_0, -L_1)' + + # Gamma_{mu nu} * Gamma^{gamma beta} * gamma_rho * Gamma^{nu mu alpha} + # T_c = Gamma^{mu nu} * Gamma^{beta gamma} * gamma_rho * Gamma^alpha_{mu nu} + t = Gamma2(mu, nu)*Gamma2(beta, gamma)*Gamma(-rho)*Gamma3(alpha, -mu, -nu) + tc = t.canon_bp() + assert str(tc) == 'Gamma(L_0, L_1)*Gamma(beta, gamma)*Gamma(-rho)*Gamma(alpha, -L_0, -L_1)' + + # f^a_{b,c} antisymmetric in b,c; A_mu^a no symmetry + # f^c_{d a} * f_{c e b} * A_mu^d * A_nu^a * A^{nu e} * A^{mu b} + # g = [8,11,5, 9,13,7, 1,10, 3,4, 2,12, 0,6, 14,15] + # T_c = -f^{a b c} * f_a^{d e} * A^mu_b * A_{mu d} * A^nu_c * A_{nu e} + Flavor = TensorIndexType('Flavor', dummy_name='F') + a, b, c, d, e, ff = tensor_indices('a,b,c,d,e,f', Flavor) + mu, nu = tensor_indices('mu,nu', Lorentz) + f = TensorHead('f', [Flavor]*3, TensorSymmetry.direct_product(1, -2)) + A = TensorHead('A', [Lorentz, Flavor], TensorSymmetry.no_symmetry(2)) + t = f(c, -d, -a)*f(-c, -e, -b)*A(-mu, d)*A(-nu, a)*A(nu, e)*A(mu, b) + tc = t.canon_bp() + assert str(tc) == '-f(F_0, F_1, F_2)*f(-F_0, F_3, F_4)*A(L_0, -F_1)*A(-L_0, -F_3)*A(L_1, -F_2)*A(-L_1, -F_4)' + + +def test_bug_correction_tensor_indices(): + # to make sure that tensor_indices does not return a list if creating + # only one index: + A = TensorIndexType("A") + i = tensor_indices('i', A) + assert not isinstance(i, (tuple, list)) + assert isinstance(i, TensorIndex) + + +def test_riemann_invariants(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11 = \ + tensor_indices('d0:12', Lorentz) + # R^{d0 d1}_{d1 d0}; ord = [d0,-d0,d1,-d1] + # T_c = -R^{d0 d1}_{d0 d1} + R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann()) + t = R(d0, d1, -d1, -d0) + tc = t.canon_bp() + assert str(tc) == '-R(L_0, L_1, -L_0, -L_1)' + + # R_d11^d1_d0^d5 * R^{d6 d4 d0}_d5 * R_{d7 d2 d8 d9} * + # R_{d10 d3 d6 d4} * R^{d2 d7 d11}_d1 * R^{d8 d9 d3 d10} + # can = [0,2,4,6, 1,3,8,10, 5,7,12,14, 9,11,16,18, 13,15,20,22, + # 17,19,21>> from sympy.tensor.tensor import TensorIndexType, TensorHead + >>> from sympy.tensor.toperators import PartialDerivative + >>> from sympy import symbols + >>> L = TensorIndexType("L") + >>> A = TensorHead("A", [L]) + >>> B = TensorHead("B", [L]) + >>> i, j, k = symbols("i j k") + + >>> expr = PartialDerivative(A(i), A(j)) + >>> expr + PartialDerivative(A(i), A(j)) + + The ``PartialDerivative`` object behaves like a tensorial expression: + + >>> expr.get_indices() + [i, -j] + + Notice that the deriving variables have opposite valence than the + printed one: ``A(j)`` is printed as covariant, but the index of the + derivative is actually contravariant, i.e. ``-j``. + + Indices can be contracted: + + >>> expr = PartialDerivative(A(i), A(i)) + >>> expr + PartialDerivative(A(L_0), A(L_0)) + >>> expr.get_indices() + [L_0, -L_0] + + The method ``.get_indices()`` always returns all indices (even the + contracted ones). If only uncontracted indices are needed, call + ``.get_free_indices()``: + + >>> expr.get_free_indices() + [] + + Nested partial derivatives are flattened: + + >>> expr = PartialDerivative(PartialDerivative(A(i), A(j)), A(k)) + >>> expr + PartialDerivative(A(i), A(j), A(k)) + >>> expr.get_indices() + [i, -j, -k] + + Replace a derivative with array values: + + >>> from sympy.abc import x, y + >>> from sympy import sin, log + >>> compA = [sin(x), log(x)*y**3] + >>> compB = [x, y] + >>> expr = PartialDerivative(A(i), B(j)) + >>> expr.replace_with_arrays({A(i): compA, B(i): compB}) + [[cos(x), 0], [y**3/x, 3*y**2*log(x)]] + + The returned array is indexed by `(i, -j)`. + + Be careful that other SymPy modules put the indices of the deriving + variables before the indices of the derivand in the derivative result. + For example: + + >>> expr.get_free_indices() + [i, -j] + + >>> from sympy import Matrix, Array + >>> Matrix(compA).diff(Matrix(compB)).reshape(2, 2) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + >>> Array(compA).diff(Array(compB)) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + + These are the transpose of the result of ``PartialDerivative``, + as the matrix and the array modules put the index `-j` before `i` in the + derivative result. An array read with index order `(-j, i)` is indeed the + transpose of the same array read with index order `(i, -j)`. By specifying + the index order to ``.replace_with_arrays`` one can get a compatible + expression: + + >>> expr.replace_with_arrays({A(i): compA, B(i): compB}, [-j, i]) + [[cos(x), y**3/x], [0, 3*y**2*log(x)]] + """ + + def __new__(cls, expr, *variables): + + # Flatten: + if isinstance(expr, PartialDerivative): + variables = expr.variables + variables + expr = expr.expr + + args, indices, free, dum = cls._contract_indices_for_derivative( + S(expr), variables) + + obj = TensExpr.__new__(cls, *args) + + obj._indices = indices + obj._free = free + obj._dum = dum + return obj + + @property + def coeff(self): + return S.One + + @property + def nocoeff(self): + return self + + @classmethod + def _contract_indices_for_derivative(cls, expr, variables): + variables_opposite_valence = [] + + for i in variables: + if isinstance(i, Tensor): + i_free_indices = i.get_free_indices() + variables_opposite_valence.append( + i.xreplace({k: -k for k in i_free_indices})) + elif isinstance(i, Symbol): + variables_opposite_valence.append(i) + + args, indices, free, dum = TensMul._tensMul_contract_indices( + [expr] + variables_opposite_valence, replace_indices=True) + + for i in range(1, len(args)): + args_i = args[i] + if isinstance(args_i, Tensor): + i_indices = args[i].get_free_indices() + args[i] = args[i].xreplace({k: -k for k in i_indices}) + + return args, indices, free, dum + + def doit(self, **hints): + args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables) + + obj = self.func(*args) + obj._indices = indices + obj._free = free + obj._dum = dum + + return obj + + def _expand_partial_derivative(self): + args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables) + + obj = self.func(*args) + obj._indices = indices + obj._free = free + obj._dum = dum + + result = obj + + if not args[0].free_symbols: + return S.Zero + elif isinstance(obj.expr, TensAdd): + # take care of sums of multi PDs + result = obj.expr.func(*[ + self.func(a, *obj.variables)._expand_partial_derivative() + for a in result.expr.args]) + elif isinstance(obj.expr, TensMul): + # take care of products of multi PDs + if len(obj.variables) == 1: + # derivative with respect to single variable + terms = [] + mulargs = list(obj.expr.args) + for ind in range(len(mulargs)): + if not isinstance(sympify(mulargs[ind]), Number): + # a number coefficient is not considered for + # expansion of PartialDerivative + d = self.func(mulargs[ind], *obj.variables)._expand_partial_derivative() + terms.append(TensMul(*(mulargs[:ind] + + [d] + + mulargs[(ind + 1):]))) + result = TensAdd.fromiter(terms) + else: + # derivative with respect to multiple variables + # decompose: + # partial(expr, (u, v)) + # = partial(partial(expr, u).doit(), v).doit() + result = obj.expr # init with expr + for v in obj.variables: + result = self.func(result, v)._expand_partial_derivative() + # then throw PD on it + + return result + + def _perform_derivative(self): + result = self.expr + for v in self.variables: + if isinstance(result, TensExpr): + result = result._eval_partial_derivative(v) + else: + if v._diff_wrt: + result = result._eval_derivative(v) + else: + result = S.Zero + return result + + def get_indices(self): + return self._indices + + def get_free_indices(self): + free = sorted(self._free, key=lambda x: x[1]) + return [i[0] for i in free] + + def _replace_indices(self, repl): + expr = self.expr.xreplace(repl) + mirrored = {-k: -v for k, v in repl.items()} + variables = [i.xreplace(mirrored) for i in self.variables] + return self.func(expr, *variables) + + @property + def expr(self): + return self.args[0] + + @property + def variables(self): + return self.args[1:] + + def _extract_data(self, replacement_dict): + from .array import derive_by_array, tensorcontraction + indices, array = self.expr._extract_data(replacement_dict) + for variable in self.variables: + var_indices, var_array = variable._extract_data(replacement_dict) + var_indices = [-i for i in var_indices] + coeff_array, var_array = zip(*[i.as_coeff_Mul() for i in var_array]) + dim_before = len(array.shape) + array = derive_by_array(array, var_array) + dim_after = len(array.shape) + dim_increase = dim_after - dim_before + array = permutedims(array, [i + dim_increase for i in range(dim_before)] + list(range(dim_increase))) + array = array.as_mutable() + varindex = var_indices[0] + # Remove coefficients of base vector: + coeff_index = [0] + [slice(None) for i in range(len(indices))] + for i, coeff in enumerate(coeff_array): + coeff_index[0] = i + array[tuple(coeff_index)] /= coeff + if -varindex in indices: + pos = indices.index(-varindex) + array = tensorcontraction(array, (0, pos+1)) + indices.pop(pos) + else: + indices.append(varindex) + return indices, array diff --git a/phivenv/Lib/site-packages/sympy/testing/__init__.py b/phivenv/Lib/site-packages/sympy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa66402955e7d5d2b27cccc6627b42d33f4cd855 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/__init__.py @@ -0,0 +1,10 @@ +"""This module contains code for running the tests in SymPy.""" + + +from .runtests import doctest +from .runtests_pytest import test + + +__all__ = [ + 'test', 'doctest', +] diff --git a/phivenv/Lib/site-packages/sympy/testing/matrices.py b/phivenv/Lib/site-packages/sympy/testing/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..236a384366df7f69d0d92f43f7e007e95c12388c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/matrices.py @@ -0,0 +1,8 @@ +def allclose(A, B, rtol=1e-05, atol=1e-08): + if len(A) != len(B): + return False + + for x, y in zip(A, B): + if abs(x-y) > atol + rtol * max(abs(x), abs(y)): + return False + return True diff --git a/phivenv/Lib/site-packages/sympy/testing/pytest.py b/phivenv/Lib/site-packages/sympy/testing/pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..498515a2d3c0a167a5f7067753736898cfa64799 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/pytest.py @@ -0,0 +1,392 @@ +"""py.test hacks to support XFAIL/XPASS""" + +import platform +import sys +import re +import functools +import os +import contextlib +import warnings +import inspect +import pathlib +from typing import Any, Callable + +from sympy.utilities.exceptions import SymPyDeprecationWarning +# Imported here for backwards compatibility. Note: do not import this from +# here in library code (importing sympy.pytest in library code will break the +# pytest integration). +from sympy.utilities.exceptions import ignore_warnings # noqa:F401 + +ON_CI = os.getenv('CI', None) == "true" + +try: + import pytest + USE_PYTEST = getattr(sys, '_running_pytest', False) +except ImportError: + USE_PYTEST = False + +IS_WASM: bool = sys.platform == 'emscripten' or platform.machine() in ["wasm32", "wasm64"] + +raises: Callable[[Any, Any], Any] +XFAIL: Callable[[Any], Any] +skip: Callable[[Any], Any] +SKIP: Callable[[Any], Any] +slow: Callable[[Any], Any] +tooslow: Callable[[Any], Any] +nocache_fail: Callable[[Any], Any] + + +if USE_PYTEST: + raises = pytest.raises + skip = pytest.skip + XFAIL = pytest.mark.xfail + SKIP = pytest.mark.skip + slow = pytest.mark.slow + tooslow = pytest.mark.tooslow + nocache_fail = pytest.mark.nocache_fail + from _pytest.outcomes import Failed + +else: + # Not using pytest so define the things that would have been imported from + # there. + + # _pytest._code.code.ExceptionInfo + class ExceptionInfo: + def __init__(self, value): + self.value = value + + def __repr__(self): + return "".format(self.value) + + + def raises(expectedException, code=None): + """ + Tests that ``code`` raises the exception ``expectedException``. + + ``code`` may be a callable, such as a lambda expression or function + name. + + If ``code`` is not given or None, ``raises`` will return a context + manager for use in ``with`` statements; the code to execute then + comes from the scope of the ``with``. + + ``raises()`` does nothing if the callable raises the expected exception, + otherwise it raises an AssertionError. + + Examples + ======== + + >>> from sympy.testing.pytest import raises + + >>> raises(ZeroDivisionError, lambda: 1/0) + + >>> raises(ZeroDivisionError, lambda: 1/2) + Traceback (most recent call last): + ... + Failed: DID NOT RAISE + + >>> with raises(ZeroDivisionError): + ... n = 1/0 + >>> with raises(ZeroDivisionError): + ... n = 1/2 + Traceback (most recent call last): + ... + Failed: DID NOT RAISE + + Note that you cannot test multiple statements via + ``with raises``: + + >>> with raises(ZeroDivisionError): + ... n = 1/0 # will execute and raise, aborting the ``with`` + ... n = 9999/0 # never executed + + This is just what ``with`` is supposed to do: abort the + contained statement sequence at the first exception and let + the context manager deal with the exception. + + To test multiple statements, you'll need a separate ``with`` + for each: + + >>> with raises(ZeroDivisionError): + ... n = 1/0 # will execute and raise + >>> with raises(ZeroDivisionError): + ... n = 9999/0 # will also execute and raise + + """ + if code is None: + return RaisesContext(expectedException) + elif callable(code): + try: + code() + except expectedException as e: + return ExceptionInfo(e) + raise Failed("DID NOT RAISE") + elif isinstance(code, str): + raise TypeError( + '\'raises(xxx, "code")\' has been phased out; ' + 'change \'raises(xxx, "expression")\' ' + 'to \'raises(xxx, lambda: expression)\', ' + '\'raises(xxx, "statement")\' ' + 'to \'with raises(xxx): statement\'') + else: + raise TypeError( + 'raises() expects a callable for the 2nd argument.') + + class RaisesContext: + def __init__(self, expectedException): + self.expectedException = expectedException + + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + raise Failed("DID NOT RAISE") + return issubclass(exc_type, self.expectedException) + + class XFail(Exception): + pass + + class XPass(Exception): + pass + + class Skipped(Exception): + pass + + class Failed(Exception): # type: ignore + pass + + def XFAIL(func): + def wrapper(): + try: + func() + except Exception as e: + message = str(e) + if message != "Timeout": + raise XFail(func.__name__) + else: + raise Skipped("Timeout") + raise XPass(func.__name__) + + wrapper = functools.update_wrapper(wrapper, func) + return wrapper + + def skip(str): + raise Skipped(str) + + def SKIP(reason): + """Similar to ``skip()``, but this is a decorator. """ + def wrapper(func): + def func_wrapper(): + raise Skipped(reason) + + func_wrapper = functools.update_wrapper(func_wrapper, func) + return func_wrapper + + return wrapper + + def slow(func): + func._slow = True + + def func_wrapper(): + func() + + func_wrapper = functools.update_wrapper(func_wrapper, func) + func_wrapper.__wrapped__ = func + return func_wrapper + + def tooslow(func): + func._slow = True + func._tooslow = True + + def func_wrapper(): + skip("Too slow") + + func_wrapper = functools.update_wrapper(func_wrapper, func) + func_wrapper.__wrapped__ = func + return func_wrapper + + def nocache_fail(func): + "Dummy decorator for marking tests that fail when cache is disabled" + return func + +@contextlib.contextmanager +def warns(warningcls, *, match='', test_stacklevel=True): + ''' + Like raises but tests that warnings are emitted. + + >>> from sympy.testing.pytest import warns + >>> import warnings + + >>> with warns(UserWarning): + ... warnings.warn('deprecated', UserWarning, stacklevel=2) + + >>> with warns(UserWarning): + ... pass + Traceback (most recent call last): + ... + Failed: DID NOT WARN. No warnings of type UserWarning\ + was emitted. The list of emitted warnings is: []. + + ``test_stacklevel`` makes it check that the ``stacklevel`` parameter to + ``warn()`` is set so that the warning shows the user line of code (the + code under the warns() context manager). Set this to False if this is + ambiguous or if the context manager does not test the direct user code + that emits the warning. + + If the warning is a ``SymPyDeprecationWarning``, this additionally tests + that the ``active_deprecations_target`` is a real target in the + ``active-deprecations.md`` file. + + ''' + # Absorbs all warnings in warnrec + with warnings.catch_warnings(record=True) as warnrec: + # Any warning other than the one we are looking for is an error + warnings.simplefilter("error") + warnings.filterwarnings("always", category=warningcls) + # Now run the test + yield warnrec + + # Raise if expected warning not found + if not any(issubclass(w.category, warningcls) for w in warnrec): + msg = ('Failed: DID NOT WARN.' + ' No warnings of type %s was emitted.' + ' The list of emitted warnings is: %s.' + ) % (warningcls, [w.message for w in warnrec]) + raise Failed(msg) + + # We don't include the match in the filter above because it would then + # fall to the error filter, so we instead manually check that it matches + # here + for w in warnrec: + # Should always be true due to the filters above + assert issubclass(w.category, warningcls) + if not re.compile(match, re.IGNORECASE).match(str(w.message)): + raise Failed(f"Failed: WRONG MESSAGE. A warning with of the correct category ({warningcls.__name__}) was issued, but it did not match the given match regex ({match!r})") + + if test_stacklevel: + for f in inspect.stack(): + thisfile = f.filename + file = os.path.split(thisfile)[1] + if file.startswith('test_'): + break + elif file == 'doctest.py': + # skip the stacklevel testing in the doctests of this + # function + return + else: + raise RuntimeError("Could not find the file for the given warning to test the stacklevel") + for w in warnrec: + if w.filename != thisfile: + msg = f'''\ +Failed: Warning has the wrong stacklevel. The warning stacklevel needs to be +set so that the line of code shown in the warning message is user code that +calls the deprecated code (the current stacklevel is showing code from +{w.filename} (line {w.lineno}), expected {thisfile})'''.replace('\n', ' ') + raise Failed(msg) + + if warningcls == SymPyDeprecationWarning: + this_file = pathlib.Path(__file__) + active_deprecations_file = (this_file.parent.parent.parent / 'doc' / + 'src' / 'explanation' / + 'active-deprecations.md') + if not active_deprecations_file.exists(): + # We can only test that the active_deprecations_target works if we are + # in the git repo. + return + targets = [] + for w in warnrec: + targets.append(w.message.active_deprecations_target) + text = pathlib.Path(active_deprecations_file).read_text(encoding="utf-8") + for target in targets: + if f'({target})=' not in text: + raise Failed(f"The active deprecations target {target!r} does not appear to be a valid target in the active-deprecations.md file ({active_deprecations_file}).") + +def _both_exp_pow(func): + """ + Decorator used to run the test twice: the first time `e^x` is represented + as ``Pow(E, x)``, the second time as ``exp(x)`` (exponential object is not + a power). + + This is a temporary trick helping to manage the elimination of the class + ``exp`` in favor of a replacement by ``Pow(E, ...)``. + """ + from sympy.core.parameters import _exp_is_pow + + def func_wrap(): + with _exp_is_pow(True): + func() + with _exp_is_pow(False): + func() + + wrapper = functools.update_wrapper(func_wrap, func) + return wrapper + + +@contextlib.contextmanager +def warns_deprecated_sympy(): + ''' + Shorthand for ``warns(SymPyDeprecationWarning)`` + + This is the recommended way to test that ``SymPyDeprecationWarning`` is + emitted for deprecated features in SymPy. To test for other warnings use + ``warns``. To suppress warnings without asserting that they are emitted + use ``ignore_warnings``. + + .. note:: + + ``warns_deprecated_sympy()`` is only intended for internal use in the + SymPy test suite to test that a deprecation warning triggers properly. + All other code in the SymPy codebase, including documentation examples, + should not use deprecated behavior. + + If you are a user of SymPy and you want to disable + SymPyDeprecationWarnings, use ``warnings`` filters (see + :ref:`silencing-sympy-deprecation-warnings`). + + >>> from sympy.testing.pytest import warns_deprecated_sympy + >>> from sympy.utilities.exceptions import sympy_deprecation_warning + >>> with warns_deprecated_sympy(): + ... sympy_deprecation_warning("Don't use", + ... deprecated_since_version="1.0", + ... active_deprecations_target="active-deprecations") + + >>> with warns_deprecated_sympy(): + ... pass + Traceback (most recent call last): + ... + Failed: DID NOT WARN. No warnings of type \ + SymPyDeprecationWarning was emitted. The list of emitted warnings is: []. + + .. note:: + + Sometimes the stacklevel test will fail because the same warning is + emitted multiple times. In this case, you can use + :func:`sympy.utilities.exceptions.ignore_warnings` in the code to + prevent the ``SymPyDeprecationWarning`` from being emitted again + recursively. In rare cases it is impossible to have a consistent + ``stacklevel`` for deprecation warnings because different ways of + calling a function will produce different call stacks.. In those cases, + use ``warns(SymPyDeprecationWarning)`` instead. + + See Also + ======== + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.decorator.deprecated + + ''' + with warns(SymPyDeprecationWarning): + yield + + +def skip_under_pyodide(message): + """Decorator to skip a test if running under Pyodide/WASM.""" + def decorator(test_func): + @functools.wraps(test_func) + def test_wrapper(): + if IS_WASM: + skip(message) + return test_func() + return test_wrapper + return decorator diff --git a/phivenv/Lib/site-packages/sympy/testing/quality_unicode.py b/phivenv/Lib/site-packages/sympy/testing/quality_unicode.py new file mode 100644 index 0000000000000000000000000000000000000000..d43623ff5112610e377347f50c6a40a15810644b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/quality_unicode.py @@ -0,0 +1,102 @@ +import re +import fnmatch + + +message_unicode_B = \ + "File contains a unicode character : %s, line %s. " \ + "But not in the whitelist. " \ + "Add the file to the whitelist in " + __file__ +message_unicode_D = \ + "File does not contain a unicode character : %s." \ + "but is in the whitelist. " \ + "Remove the file from the whitelist in " + __file__ + + +encoding_header_re = re.compile( + r'^[ \t\f]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)') + +# Whitelist pattern for files which can have unicode. +unicode_whitelist = [ + # Author names can include non-ASCII characters + r'*/bin/authors_update.py', + r'*/bin/mailmap_check.py', + + # These files have functions and test functions for unicode input and + # output. + r'*/sympy/testing/tests/test_code_quality.py', + r'*/sympy/physics/vector/tests/test_printing.py', + r'*/physics/quantum/tests/test_printing.py', + r'*/sympy/vector/tests/test_printing.py', + r'*/sympy/parsing/tests/test_sympy_parser.py', + r'*/sympy/printing/pretty/stringpict.py', + r'*/sympy/printing/pretty/tests/test_pretty.py', + r'*/sympy/printing/tests/test_conventions.py', + r'*/sympy/printing/tests/test_preview.py', + r'*/liealgebras/type_g.py', + r'*/liealgebras/weyl_group.py', + r'*/liealgebras/tests/test_type_G.py', + + # wigner.py and polarization.py have unicode doctests. These probably + # don't need to be there but some of the examples that are there are + # pretty ugly without use_unicode (matrices need to be wrapped across + # multiple lines etc) + r'*/sympy/physics/wigner.py', + r'*/sympy/physics/optics/polarization.py', + + # joint.py uses some unicode for variable names in the docstrings + r'*/sympy/physics/mechanics/joint.py', + + # lll method has unicode in docstring references and author name + r'*/sympy/polys/matrices/domainmatrix.py', + r'*/sympy/matrices/repmatrix.py', + + # Explanation of symbols uses greek letters + r'*/sympy/core/symbol.py', +] + +unicode_strict_whitelist = [ + r'*/sympy/parsing/latex/_antlr/__init__.py', + # test_mathematica.py uses some unicode for testing Greek characters are working #24055 + r'*/sympy/parsing/tests/test_mathematica.py', +] + + +def _test_this_file_encoding( + fname, test_file, + unicode_whitelist=unicode_whitelist, + unicode_strict_whitelist=unicode_strict_whitelist): + """Test helper function for unicode test + + The test may have to operate on filewise manner, so it had moved + to a separate process. + """ + has_unicode = False + + is_in_whitelist = False + is_in_strict_whitelist = False + for patt in unicode_whitelist: + if fnmatch.fnmatch(fname, patt): + is_in_whitelist = True + break + for patt in unicode_strict_whitelist: + if fnmatch.fnmatch(fname, patt): + is_in_strict_whitelist = True + is_in_whitelist = True + break + + if is_in_whitelist: + for idx, line in enumerate(test_file): + try: + line.encode(encoding='ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + has_unicode = True + + if not has_unicode and not is_in_strict_whitelist: + assert False, message_unicode_D % fname + + else: + for idx, line in enumerate(test_file): + try: + line.encode(encoding='ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + assert False, message_unicode_B % (fname, idx + 1) diff --git a/phivenv/Lib/site-packages/sympy/testing/randtest.py b/phivenv/Lib/site-packages/sympy/testing/randtest.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce2c8c031eec1c886532daba32c96d83e9cf85c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/randtest.py @@ -0,0 +1,19 @@ +""" +.. deprecated:: 1.10 + + ``sympy.testing.randtest`` functions have been moved to + :mod:`sympy.core.random`. + +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.testing.randtest submodule is deprecated. Use sympy.core.random instead.", + deprecated_since_version="1.10", + active_deprecations_target="deprecated-sympy-testing-randtest") + +from sympy.core.random import ( # noqa:F401 + random_complex_number, + verify_numerically, + test_derivative_numerically, + _randrange, + _randint) diff --git a/phivenv/Lib/site-packages/sympy/testing/runtests.py b/phivenv/Lib/site-packages/sympy/testing/runtests.py new file mode 100644 index 0000000000000000000000000000000000000000..e2650e4e6dabaaa07fc25c76cce3d9d28723b0d5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/runtests.py @@ -0,0 +1,2409 @@ +""" +This is our testing framework. + +Goals: + +* it should be compatible with py.test and operate very similarly + (or identically) +* does not require any external dependencies +* preferably all the functionality should be in this file only +* no magic, just import the test file and execute the test functions, that's it +* portable + +""" + +import os +import sys +import platform +import inspect +import traceback +import pdb +import re +import linecache +import time +from fnmatch import fnmatch +from timeit import default_timer as clock +import doctest as pdoctest # avoid clashing with our doctest() function +from doctest import DocTestFinder, DocTestRunner +import random +import subprocess +import shutil +import signal +import stat +import tempfile +import warnings +from contextlib import contextmanager +from inspect import unwrap +from pathlib import Path + +from sympy.core.cache import clear_cache +from sympy.external import import_module +from sympy.external.gmpy import GROUND_TYPES + +IS_WINDOWS = (os.name == 'nt') +ON_CI = os.getenv('CI', None) + +# empirically generated list of the proportion of time spent running +# an even split of tests. This should periodically be regenerated. +# A list of [.6, .1, .3] would mean that if the tests are evenly split +# into '1/3', '2/3', '3/3', the first split would take 60% of the time, +# the second 10% and the third 30%. These lists are normalized to sum +# to 1, so [60, 10, 30] has the same behavior as [6, 1, 3] or [.6, .1, .3]. +# +# This list can be generated with the code: +# from time import time +# import sympy +# import os +# os.environ["CI"] = 'true' # Mock CI to get more correct densities +# delays, num_splits = [], 30 +# for i in range(1, num_splits + 1): +# tic = time() +# sympy.test(split='{}/{}'.format(i, num_splits), time_balance=False) # Add slow=True for slow tests +# delays.append(time() - tic) +# tot = sum(delays) +# print([round(x / tot, 4) for x in delays]) +SPLIT_DENSITY = [ + 0.0059, 0.0027, 0.0068, 0.0011, 0.0006, + 0.0058, 0.0047, 0.0046, 0.004, 0.0257, + 0.0017, 0.0026, 0.004, 0.0032, 0.0016, + 0.0015, 0.0004, 0.0011, 0.0016, 0.0014, + 0.0077, 0.0137, 0.0217, 0.0074, 0.0043, + 0.0067, 0.0236, 0.0004, 0.1189, 0.0142, + 0.0234, 0.0003, 0.0003, 0.0047, 0.0006, + 0.0013, 0.0004, 0.0008, 0.0007, 0.0006, + 0.0139, 0.0013, 0.0007, 0.0051, 0.002, + 0.0004, 0.0005, 0.0213, 0.0048, 0.0016, + 0.0012, 0.0014, 0.0024, 0.0015, 0.0004, + 0.0005, 0.0007, 0.011, 0.0062, 0.0015, + 0.0021, 0.0049, 0.0006, 0.0006, 0.0011, + 0.0006, 0.0019, 0.003, 0.0044, 0.0054, + 0.0057, 0.0049, 0.0016, 0.0006, 0.0009, + 0.0006, 0.0012, 0.0006, 0.0149, 0.0532, + 0.0076, 0.0041, 0.0024, 0.0135, 0.0081, + 0.2209, 0.0459, 0.0438, 0.0488, 0.0137, + 0.002, 0.0003, 0.0008, 0.0039, 0.0024, + 0.0005, 0.0004, 0.003, 0.056, 0.0026] +SPLIT_DENSITY_SLOW = [0.0086, 0.0004, 0.0568, 0.0003, 0.0032, 0.0005, 0.0004, 0.0013, 0.0016, 0.0648, 0.0198, 0.1285, 0.098, 0.0005, 0.0064, 0.0003, 0.0004, 0.0026, 0.0007, 0.0051, 0.0089, 0.0024, 0.0033, 0.0057, 0.0005, 0.0003, 0.001, 0.0045, 0.0091, 0.0006, 0.0005, 0.0321, 0.0059, 0.1105, 0.216, 0.1489, 0.0004, 0.0003, 0.0006, 0.0483] + +class Skipped(Exception): + pass + +class TimeOutError(Exception): + pass + +class DependencyError(Exception): + pass + + +def _indent(s, indent=4): + """ + Add the given number of space characters to the beginning of + every non-blank line in ``s``, and return the result. + If the string ``s`` is Unicode, it is encoded using the stdout + encoding and the ``backslashreplace`` error handler. + """ + # This regexp matches the start of non-blank lines: + return re.sub('(?m)^(?!$)', indent*' ', s) + + +pdoctest._indent = _indent # type: ignore + +# override reporter to maintain windows and python3 + + +def _report_failure(self, out, test, example, got): + """ + Report that the given example failed. + """ + s = self._checker.output_difference(example, got, self.optionflags) + s = s.encode('raw_unicode_escape').decode('utf8', 'ignore') + out(self._failure_header(test, example) + s) + + +if IS_WINDOWS: + DocTestRunner.report_failure = _report_failure # type: ignore + + +def convert_to_native_paths(lst): + """ + Converts a list of '/' separated paths into a list of + native (os.sep separated) paths and converts to lowercase + if the system is case insensitive. + """ + newlst = [] + for rv in lst: + rv = os.path.join(*rv.split("/")) + # on windows the slash after the colon is dropped + if sys.platform == "win32": + pos = rv.find(':') + if pos != -1: + if rv[pos + 1] != '\\': + rv = rv[:pos + 1] + '\\' + rv[pos + 1:] + newlst.append(os.path.normcase(rv)) + return newlst + + +def get_sympy_dir(): + """ + Returns the root SymPy directory and set the global value + indicating whether the system is case sensitive or not. + """ + this_file = os.path.abspath(__file__) + sympy_dir = os.path.join(os.path.dirname(this_file), "..", "..") + sympy_dir = os.path.normpath(sympy_dir) + return os.path.normcase(sympy_dir) + + +def setup_pprint(disable_line_wrap=True): + from sympy.interactive.printing import init_printing + from sympy.printing.pretty.pretty import pprint_use_unicode + import sympy.interactive.printing as interactive_printing + from sympy.printing.pretty import stringpict + + # Prevent init_printing() in doctests from affecting other doctests + interactive_printing.NO_GLOBAL = True + + # force pprint to be in ascii mode in doctests + use_unicode_prev = pprint_use_unicode(False) + + # disable line wrapping for pprint() outputs + wrap_line_prev = stringpict._GLOBAL_WRAP_LINE + if disable_line_wrap: + stringpict._GLOBAL_WRAP_LINE = False + + # hook our nice, hash-stable strprinter + init_printing(pretty_print=False) + + return use_unicode_prev, wrap_line_prev + + +@contextmanager +def raise_on_deprecated(): + """Context manager to make DeprecationWarning raise an error + + This is to catch SymPyDeprecationWarning from library code while running + tests and doctests. It is important to use this context manager around + each individual test/doctest in case some tests modify the warning + filters. + """ + with warnings.catch_warnings(): + warnings.filterwarnings('error', '.*', DeprecationWarning, module='sympy.*') + yield + + +def run_in_subprocess_with_hash_randomization( + function, function_args=(), + function_kwargs=None, command=sys.executable, + module='sympy.testing.runtests', force=False): + """ + Run a function in a Python subprocess with hash randomization enabled. + + If hash randomization is not supported by the version of Python given, it + returns False. Otherwise, it returns the exit value of the command. The + function is passed to sys.exit(), so the return value of the function will + be the return value. + + The environment variable PYTHONHASHSEED is used to seed Python's hash + randomization. If it is set, this function will return False, because + starting a new subprocess is unnecessary in that case. If it is not set, + one is set at random, and the tests are run. Note that if this + environment variable is set when Python starts, hash randomization is + automatically enabled. To force a subprocess to be created even if + PYTHONHASHSEED is set, pass ``force=True``. This flag will not force a + subprocess in Python versions that do not support hash randomization (see + below), because those versions of Python do not support the ``-R`` flag. + + ``function`` should be a string name of a function that is importable from + the module ``module``, like "_test". The default for ``module`` is + "sympy.testing.runtests". ``function_args`` and ``function_kwargs`` + should be a repr-able tuple and dict, respectively. The default Python + command is sys.executable, which is the currently running Python command. + + This function is necessary because the seed for hash randomization must be + set by the environment variable before Python starts. Hence, in order to + use a predetermined seed for tests, we must start Python in a separate + subprocess. + + Hash randomization was added in the minor Python versions 2.6.8, 2.7.3, + 3.1.5, and 3.2.3, and is enabled by default in all Python versions after + and including 3.3.0. + + Examples + ======== + + >>> from sympy.testing.runtests import ( + ... run_in_subprocess_with_hash_randomization) + >>> # run the core tests in verbose mode + >>> run_in_subprocess_with_hash_randomization("_test", + ... function_args=("core",), + ... function_kwargs={'verbose': True}) # doctest: +SKIP + # Will return 0 if sys.executable supports hash randomization and tests + # pass, 1 if they fail, and False if it does not support hash + # randomization. + + """ + cwd = get_sympy_dir() + # Note, we must return False everywhere, not None, as subprocess.call will + # sometimes return None. + + # First check if the Python version supports hash randomization + # If it does not have this support, it won't recognize the -R flag + p = subprocess.Popen([command, "-RV"], stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, cwd=cwd) + p.communicate() + if p.returncode != 0: + return False + + hash_seed = os.getenv("PYTHONHASHSEED") + if not hash_seed: + os.environ["PYTHONHASHSEED"] = str(random.randrange(2**32)) + else: + if not force: + return False + + function_kwargs = function_kwargs or {} + + # Now run the command + commandstring = ("import sys; from %s import %s;sys.exit(%s(*%s, **%s))" % + (module, function, function, repr(function_args), + repr(function_kwargs))) + + try: + p = subprocess.Popen([command, "-R", "-c", commandstring], cwd=cwd) + p.communicate() + except KeyboardInterrupt: + p.wait() + finally: + # Put the environment variable back, so that it reads correctly for + # the current Python process. + if hash_seed is None: + del os.environ["PYTHONHASHSEED"] + else: + os.environ["PYTHONHASHSEED"] = hash_seed + return p.returncode + + +def run_all_tests(test_args=(), test_kwargs=None, + doctest_args=(), doctest_kwargs=None, + examples_args=(), examples_kwargs=None): + """ + Run all tests. + + Right now, this runs the regular tests (bin/test), the doctests + (bin/doctest), and the examples (examples/all.py). + + This is what ``setup.py test`` uses. + + You can pass arguments and keyword arguments to the test functions that + support them (for now, test, doctest, and the examples). See the + docstrings of those functions for a description of the available options. + + For example, to run the solvers tests with colors turned off: + + >>> from sympy.testing.runtests import run_all_tests + >>> run_all_tests(test_args=("solvers",), + ... test_kwargs={"colors:False"}) # doctest: +SKIP + + """ + tests_successful = True + + test_kwargs = test_kwargs or {} + doctest_kwargs = doctest_kwargs or {} + examples_kwargs = examples_kwargs or {'quiet': True} + + try: + # Regular tests + if not test(*test_args, **test_kwargs): + # some regular test fails, so set the tests_successful + # flag to false and continue running the doctests + tests_successful = False + + # Doctests + print() + if not doctest(*doctest_args, **doctest_kwargs): + tests_successful = False + + # Examples + print() + sys.path.append("examples") # examples/all.py + from all import run_examples # type: ignore + if not run_examples(*examples_args, **examples_kwargs): + tests_successful = False + + if tests_successful: + return + else: + # Return nonzero exit code + sys.exit(1) + except KeyboardInterrupt: + print() + print("DO *NOT* COMMIT!") + sys.exit(1) + + +def test(*paths, subprocess=True, rerun=0, **kwargs): + """ + Run tests in the specified test_*.py files. + + Tests in a particular test_*.py file are run if any of the given strings + in ``paths`` matches a part of the test file's path. If ``paths=[]``, + tests in all test_*.py files are run. + + Notes: + + - If sort=False, tests are run in random order (not default). + - Paths can be entered in native system format or in unix, + forward-slash format. + - Files that are on the blacklist can be tested by providing + their path; they are only excluded if no paths are given. + + **Explanation of test results** + + ====== =============================================================== + Output Meaning + ====== =============================================================== + . passed + F failed + X XPassed (expected to fail but passed) + f XFAILed (expected to fail and indeed failed) + s skipped + w slow + T timeout (e.g., when ``--timeout`` is used) + K KeyboardInterrupt (when running the slow tests with ``--slow``, + you can interrupt one of them without killing the test runner) + ====== =============================================================== + + + Colors have no additional meaning and are used just to facilitate + interpreting the output. + + Examples + ======== + + >>> import sympy + + Run all tests: + + >>> sympy.test() # doctest: +SKIP + + Run one file: + + >>> sympy.test("sympy/core/tests/test_basic.py") # doctest: +SKIP + >>> sympy.test("_basic") # doctest: +SKIP + + Run all tests in sympy/functions/ and some particular file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... "sympy/functions") # doctest: +SKIP + + Run all tests in sympy/core and sympy/utilities: + + >>> sympy.test("/core", "/util") # doctest: +SKIP + + Run specific test from a file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... kw="test_equality") # doctest: +SKIP + + Run specific test from any file: + + >>> sympy.test(kw="subs") # doctest: +SKIP + + Run the tests with verbose mode on: + + >>> sympy.test(verbose=True) # doctest: +SKIP + + Do not sort the test output: + + >>> sympy.test(sort=False) # doctest: +SKIP + + Turn on post-mortem pdb: + + >>> sympy.test(pdb=True) # doctest: +SKIP + + Turn off colors: + + >>> sympy.test(colors=False) # doctest: +SKIP + + Force colors, even when the output is not to a terminal (this is useful, + e.g., if you are piping to ``less -r`` and you still want colors) + + >>> sympy.test(force_colors=False) # doctest: +SKIP + + The traceback verboseness can be set to "short" or "no" (default is + "short") + + >>> sympy.test(tb='no') # doctest: +SKIP + + The ``split`` option can be passed to split the test run into parts. The + split currently only splits the test files, though this may change in the + future. ``split`` should be a string of the form 'a/b', which will run + part ``a`` of ``b``. For instance, to run the first half of the test suite: + + >>> sympy.test(split='1/2') # doctest: +SKIP + + The ``time_balance`` option can be passed in conjunction with ``split``. + If ``time_balance=True`` (the default for ``sympy.test``), SymPy will attempt + to split the tests such that each split takes equal time. This heuristic + for balancing is based on pre-recorded test data. + + >>> sympy.test(split='1/2', time_balance=True) # doctest: +SKIP + + You can disable running the tests in a separate subprocess using + ``subprocess=False``. This is done to support seeding hash randomization, + which is enabled by default in the Python versions where it is supported. + If subprocess=False, hash randomization is enabled/disabled according to + whether it has been enabled or not in the calling Python process. + However, even if it is enabled, the seed cannot be printed unless it is + called from a new Python process. + + Hash randomization was added in the minor Python versions 2.6.8, 2.7.3, + 3.1.5, and 3.2.3, and is enabled by default in all Python versions after + and including 3.3.0. + + If hash randomization is not supported ``subprocess=False`` is used + automatically. + + >>> sympy.test(subprocess=False) # doctest: +SKIP + + To set the hash randomization seed, set the environment variable + ``PYTHONHASHSEED`` before running the tests. This can be done from within + Python using + + >>> import os + >>> os.environ['PYTHONHASHSEED'] = '42' # doctest: +SKIP + + Or from the command line using + + $ PYTHONHASHSEED=42 ./bin/test + + If the seed is not set, a random seed will be chosen. + + Note that to reproduce the same hash values, you must use both the same seed + as well as the same architecture (32-bit vs. 64-bit). + + """ + # count up from 0, do not print 0 + print_counter = lambda i : (print("rerun %d" % (rerun-i)) + if rerun-i else None) + + if subprocess: + # loop backwards so last i is 0 + for i in range(rerun, -1, -1): + print_counter(i) + ret = run_in_subprocess_with_hash_randomization("_test", + function_args=paths, function_kwargs=kwargs) + if ret is False: + break + val = not bool(ret) + # exit on the first failure or if done + if not val or i == 0: + return val + + # rerun even if hash randomization is not supported + for i in range(rerun, -1, -1): + print_counter(i) + val = not bool(_test(*paths, **kwargs)) + if not val or i == 0: + return val + + +def _test(*paths, + verbose=False, tb="short", kw=None, pdb=False, colors=True, + force_colors=False, sort=True, seed=None, timeout=False, + fail_on_timeout=False, slow=False, enhance_asserts=False, split=None, + time_balance=True, blacklist=(), + fast_threshold=None, slow_threshold=None): + """ + Internal function that actually runs the tests. + + All keyword arguments from ``test()`` are passed to this function except for + ``subprocess``. + + Returns 0 if tests passed and 1 if they failed. See the docstring of + ``test()`` for more information. + """ + kw = kw or () + # ensure that kw is a tuple + if isinstance(kw, str): + kw = (kw,) + post_mortem = pdb + if seed is None: + seed = random.randrange(100000000) + if ON_CI and timeout is False: + timeout = 595 + fail_on_timeout = True + if ON_CI: + blacklist = list(blacklist) + ['sympy/plotting/pygletplot/tests'] + blacklist = convert_to_native_paths(blacklist) + r = PyTestReporter(verbose=verbose, tb=tb, colors=colors, + force_colors=force_colors, split=split) + # This won't strictly run the test for the corresponding file, but it is + # good enough for copying and pasting the failing test. + _paths = [] + for path in paths: + if '::' in path: + path, _kw = path.split('::', 1) + kw += (_kw,) + _paths.append(path) + paths = _paths + + t = SymPyTests(r, kw, post_mortem, seed, + fast_threshold=fast_threshold, + slow_threshold=slow_threshold) + + test_files = t.get_test_files('sympy') + + not_blacklisted = [f for f in test_files + if not any(b in f for b in blacklist)] + + if len(paths) == 0: + matched = not_blacklisted + else: + paths = convert_to_native_paths(paths) + matched = [] + for f in not_blacklisted: + basename = os.path.basename(f) + for p in paths: + if p in f or fnmatch(basename, p): + matched.append(f) + break + + density = None + if time_balance: + if slow: + density = SPLIT_DENSITY_SLOW + else: + density = SPLIT_DENSITY + + if split: + matched = split_list(matched, split, density=density) + + t._testfiles.extend(matched) + + return int(not t.test(sort=sort, timeout=timeout, slow=slow, + enhance_asserts=enhance_asserts, fail_on_timeout=fail_on_timeout)) + + +def doctest(*paths, subprocess=True, rerun=0, **kwargs): + r""" + Runs doctests in all \*.py files in the SymPy directory which match + any of the given strings in ``paths`` or all tests if paths=[]. + + Notes: + + - Paths can be entered in native system format or in unix, + forward-slash format. + - Files that are on the blacklist can be tested by providing + their path; they are only excluded if no paths are given. + + Examples + ======== + + >>> import sympy + + Run all tests: + + >>> sympy.doctest() # doctest: +SKIP + + Run one file: + + >>> sympy.doctest("sympy/core/basic.py") # doctest: +SKIP + >>> sympy.doctest("polynomial.rst") # doctest: +SKIP + + Run all tests in sympy/functions/ and some particular file: + + >>> sympy.doctest("/functions", "basic.py") # doctest: +SKIP + + Run any file having polynomial in its name, doc/src/modules/polynomial.rst, + sympy/functions/special/polynomials.py, and sympy/polys/polynomial.py: + + >>> sympy.doctest("polynomial") # doctest: +SKIP + + The ``split`` option can be passed to split the test run into parts. The + split currently only splits the test files, though this may change in the + future. ``split`` should be a string of the form 'a/b', which will run + part ``a`` of ``b``. Note that the regular doctests and the Sphinx + doctests are split independently. For instance, to run the first half of + the test suite: + + >>> sympy.doctest(split='1/2') # doctest: +SKIP + + The ``subprocess`` and ``verbose`` options are the same as with the function + ``test()`` (see the docstring of that function for more information) except + that ``verbose`` may also be set equal to ``2`` in order to print + individual doctest lines, as they are being tested. + """ + # count up from 0, do not print 0 + print_counter = lambda i : (print("rerun %d" % (rerun-i)) + if rerun-i else None) + + if subprocess: + # loop backwards so last i is 0 + for i in range(rerun, -1, -1): + print_counter(i) + ret = run_in_subprocess_with_hash_randomization("_doctest", + function_args=paths, function_kwargs=kwargs) + if ret is False: + break + val = not bool(ret) + # exit on the first failure or if done + if not val or i == 0: + return val + + # rerun even if hash randomization is not supported + for i in range(rerun, -1, -1): + print_counter(i) + val = not bool(_doctest(*paths, **kwargs)) + if not val or i == 0: + return val + + +def _get_doctest_blacklist(): + '''Get the default blacklist for the doctests''' + blacklist = [] + + blacklist.extend([ + "doc/src/modules/plotting.rst", # generates live plots + "doc/src/modules/physics/mechanics/autolev_parser.rst", + "sympy/codegen/array_utils.py", # raises deprecation warning + "sympy/core/compatibility.py", # backwards compatibility shim, importing it triggers a deprecation warning + "sympy/core/trace.py", # backwards compatibility shim, importing it triggers a deprecation warning + "sympy/galgebra.py", # no longer part of SymPy + "sympy/parsing/autolev/_antlr/autolevlexer.py", # generated code + "sympy/parsing/autolev/_antlr/autolevlistener.py", # generated code + "sympy/parsing/autolev/_antlr/autolevparser.py", # generated code + "sympy/parsing/latex/_antlr/latexlexer.py", # generated code + "sympy/parsing/latex/_antlr/latexparser.py", # generated code + "sympy/plotting/pygletplot/__init__.py", # crashes on some systems + "sympy/plotting/pygletplot/plot.py", # crashes on some systems + "sympy/printing/ccode.py", # backwards compatibility shim, importing it breaks the codegen doctests + "sympy/printing/cxxcode.py", # backwards compatibility shim, importing it breaks the codegen doctests + "sympy/printing/fcode.py", # backwards compatibility shim, importing it breaks the codegen doctests + "sympy/testing/randtest.py", # backwards compatibility shim, importing it triggers a deprecation warning + "sympy/this.py", # prints text + ]) + # autolev parser tests + num = 12 + for i in range (1, num+1): + blacklist.append("sympy/parsing/autolev/test-examples/ruletest" + str(i) + ".py") + blacklist.extend(["sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py", + "sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py", + "sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py", + "sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py"]) + + if import_module('numpy') is None: + blacklist.extend([ + "sympy/plotting/experimental_lambdify.py", + "sympy/plotting/plot_implicit.py", + "examples/advanced/autowrap_integrators.py", + "examples/advanced/autowrap_ufuncify.py", + "examples/intermediate/sample.py", + "examples/intermediate/mplot2d.py", + "examples/intermediate/mplot3d.py", + "doc/src/modules/numeric-computation.rst", + "doc/src/explanation/best-practices.md", + "doc/src/tutorials/physics/biomechanics/biomechanical-model-example.rst", + "doc/src/tutorials/physics/biomechanics/biomechanics.rst", + ]) + else: + if import_module('matplotlib') is None: + blacklist.extend([ + "examples/intermediate/mplot2d.py", + "examples/intermediate/mplot3d.py" + ]) + else: + # Use a non-windowed backend, so that the tests work on CI + import matplotlib + matplotlib.use('Agg') + + if ON_CI or import_module('pyglet') is None: + blacklist.extend(["sympy/plotting/pygletplot"]) + + if import_module('aesara') is None: + blacklist.extend([ + "sympy/printing/aesaracode.py", + "doc/src/modules/numeric-computation.rst", + ]) + + if import_module('cupy') is None: + blacklist.extend([ + "doc/src/modules/numeric-computation.rst", + ]) + + if import_module('jax') is None: + blacklist.extend([ + "doc/src/modules/numeric-computation.rst", + ]) + + if import_module('antlr4') is None: + blacklist.extend([ + "sympy/parsing/autolev/__init__.py", + "sympy/parsing/latex/_parse_latex_antlr.py", + ]) + + if import_module('lfortran') is None: + #throws ImportError when lfortran not installed + blacklist.extend([ + "sympy/parsing/sym_expr.py", + ]) + + if import_module("scipy") is None: + # throws ModuleNotFoundError when scipy not installed + blacklist.extend([ + "doc/src/guides/solving/solve-numerically.md", + "doc/src/guides/solving/solve-ode.md", + ]) + + if import_module("numpy") is None: + # throws ModuleNotFoundError when numpy not installed + blacklist.extend([ + "doc/src/guides/solving/solve-ode.md", + "doc/src/guides/solving/solve-numerically.md", + ]) + + # disabled because of doctest failures in asmeurer's bot + blacklist.extend([ + "sympy/utilities/autowrap.py", + "examples/advanced/autowrap_integrators.py", + "examples/advanced/autowrap_ufuncify.py" + ]) + + blacklist.extend([ + "sympy/conftest.py", # Depends on pytest + ]) + + # These are deprecated stubs to be removed: + blacklist.extend([ + "sympy/utilities/tmpfiles.py", + "sympy/utilities/pytest.py", + "sympy/utilities/runtests.py", + "sympy/utilities/quality_unicode.py", + "sympy/utilities/randtest.py", + ]) + + blacklist = convert_to_native_paths(blacklist) + return blacklist + + +def _doctest(*paths, **kwargs): + """ + Internal function that actually runs the doctests. + + All keyword arguments from ``doctest()`` are passed to this function + except for ``subprocess``. + + Returns 0 if tests passed and 1 if they failed. See the docstrings of + ``doctest()`` and ``test()`` for more information. + """ + from sympy.printing.pretty.pretty import pprint_use_unicode + from sympy.printing.pretty import stringpict + + normal = kwargs.get("normal", False) + verbose = kwargs.get("verbose", False) + colors = kwargs.get("colors", True) + force_colors = kwargs.get("force_colors", False) + blacklist = kwargs.get("blacklist", []) + split = kwargs.get('split', None) + + blacklist.extend(_get_doctest_blacklist()) + + # Use a non-windowed backend, so that the tests work on CI + if import_module('matplotlib') is not None: + import matplotlib + matplotlib.use('Agg') + + # Disable warnings for external modules + import sympy.external + sympy.external.importtools.WARN_OLD_VERSION = False + sympy.external.importtools.WARN_NOT_INSTALLED = False + + # Disable showing up of plots + from sympy.plotting.plot import unset_show + unset_show() + + r = PyTestReporter(verbose, split=split, colors=colors,\ + force_colors=force_colors) + t = SymPyDocTests(r, normal) + + test_files = t.get_test_files('sympy') + test_files.extend(t.get_test_files('examples', init_only=False)) + + not_blacklisted = [f for f in test_files + if not any(b in f for b in blacklist)] + if len(paths) == 0: + matched = not_blacklisted + else: + # take only what was requested...but not blacklisted items + # and allow for partial match anywhere or fnmatch of name + paths = convert_to_native_paths(paths) + matched = [] + for f in not_blacklisted: + basename = os.path.basename(f) + for p in paths: + if p in f or fnmatch(basename, p): + matched.append(f) + break + + matched.sort() + + if split: + matched = split_list(matched, split) + + t._testfiles.extend(matched) + + # run the tests and record the result for this *py portion of the tests + if t._testfiles: + failed = not t.test() + else: + failed = False + + # N.B. + # -------------------------------------------------------------------- + # Here we test *.rst and *.md files at or below doc/src. Code from these + # must be self supporting in terms of imports since there is no importing + # of necessary modules by doctest.testfile. If you try to pass *.py files + # through this they might fail because they will lack the needed imports + # and smarter parsing that can be done with source code. + # + test_files_rst = t.get_test_files('doc/src', '*.rst', init_only=False) + test_files_md = t.get_test_files('doc/src', '*.md', init_only=False) + test_files = test_files_rst + test_files_md + test_files.sort() + + not_blacklisted = [f for f in test_files + if not any(b in f for b in blacklist)] + + if len(paths) == 0: + matched = not_blacklisted + else: + # Take only what was requested as long as it's not on the blacklist. + # Paths were already made native in *py tests so don't repeat here. + # There's no chance of having a *py file slip through since we + # only have *rst files in test_files. + matched = [] + for f in not_blacklisted: + basename = os.path.basename(f) + for p in paths: + if p in f or fnmatch(basename, p): + matched.append(f) + break + + if split: + matched = split_list(matched, split) + + first_report = True + for rst_file in matched: + if not os.path.isfile(rst_file): + continue + old_displayhook = sys.displayhook + try: + use_unicode_prev, wrap_line_prev = setup_pprint() + out = sympytestfile( + rst_file, module_relative=False, encoding='utf-8', + optionflags=pdoctest.ELLIPSIS | pdoctest.NORMALIZE_WHITESPACE | + pdoctest.IGNORE_EXCEPTION_DETAIL) + finally: + # make sure we return to the original displayhook in case some + # doctest has changed that + sys.displayhook = old_displayhook + # The NO_GLOBAL flag overrides the no_global flag to init_printing + # if True + import sympy.interactive.printing as interactive_printing + interactive_printing.NO_GLOBAL = False + pprint_use_unicode(use_unicode_prev) + stringpict._GLOBAL_WRAP_LINE = wrap_line_prev + + rstfailed, tested = out + if tested: + failed = rstfailed or failed + if first_report: + first_report = False + msg = 'rst/md doctests start' + if not t._testfiles: + r.start(msg=msg) + else: + r.write_center(msg) + print() + # use as the id, everything past the first 'sympy' + file_id = rst_file[rst_file.find('sympy') + len('sympy') + 1:] + print(file_id, end=" ") + # get at least the name out so it is know who is being tested + wid = r.terminal_width - len(file_id) - 1 # update width + test_file = '[%s]' % (tested) + report = '[%s]' % (rstfailed or 'OK') + print(''.join( + [test_file, ' '*(wid - len(test_file) - len(report)), report]) + ) + + # the doctests for *py will have printed this message already if there was + # a failure, so now only print it if there was intervening reporting by + # testing the *rst as evidenced by first_report no longer being True. + if not first_report and failed: + print() + print("DO *NOT* COMMIT!") + + return int(failed) + +sp = re.compile(r'([0-9]+)/([1-9][0-9]*)') + +def split_list(l, split, density=None): + """ + Splits a list into part a of b + + split should be a string of the form 'a/b'. For instance, '1/3' would give + the split one of three. + + If the length of the list is not divisible by the number of splits, the + last split will have more items. + + `density` may be specified as a list. If specified, + tests will be balanced so that each split has as equal-as-possible + amount of mass according to `density`. + + >>> from sympy.testing.runtests import split_list + >>> a = list(range(10)) + >>> split_list(a, '1/3') + [0, 1, 2] + >>> split_list(a, '2/3') + [3, 4, 5] + >>> split_list(a, '3/3') + [6, 7, 8, 9] + """ + m = sp.match(split) + if not m: + raise ValueError("split must be a string of the form a/b where a and b are ints") + i, t = map(int, m.groups()) + + if not density: + return l[(i - 1)*len(l)//t : i*len(l)//t] + + # normalize density + tot = sum(density) + density = [x / tot for x in density] + + def density_inv(x): + """Interpolate the inverse to the cumulative + distribution function given by density""" + if x <= 0: + return 0 + if x >= sum(density): + return 1 + + # find the first time the cumulative sum surpasses x + # and linearly interpolate + cumm = 0 + for i, d in enumerate(density): + cumm += d + if cumm >= x: + break + frac = (d - (cumm - x)) / d + return (i + frac) / len(density) + + lower_frac = density_inv((i - 1) / t) + higher_frac = density_inv(i / t) + return l[int(lower_frac*len(l)) : int(higher_frac*len(l))] + +from collections import namedtuple +SymPyTestResults = namedtuple('SymPyTestResults', 'failed attempted') + +def sympytestfile(filename, module_relative=True, name=None, package=None, + globs=None, verbose=None, report=True, optionflags=0, + extraglobs=None, raise_on_error=False, + parser=pdoctest.DocTestParser(), encoding=None): + + """ + Test examples in the given file. Return (#failures, #tests). + + Optional keyword arg ``module_relative`` specifies how filenames + should be interpreted: + + - If ``module_relative`` is True (the default), then ``filename`` + specifies a module-relative path. By default, this path is + relative to the calling module's directory; but if the + ``package`` argument is specified, then it is relative to that + package. To ensure os-independence, ``filename`` should use + "/" characters to separate path segments, and should not + be an absolute path (i.e., it may not begin with "/"). + + - If ``module_relative`` is False, then ``filename`` specifies an + os-specific path. The path may be absolute or relative (to + the current working directory). + + Optional keyword arg ``name`` gives the name of the test; by default + use the file's basename. + + Optional keyword argument ``package`` is a Python package or the + name of a Python package whose directory should be used as the + base directory for a module relative filename. If no package is + specified, then the calling module's directory is used as the base + directory for module relative filenames. It is an error to + specify ``package`` if ``module_relative`` is False. + + Optional keyword arg ``globs`` gives a dict to be used as the globals + when executing examples; by default, use {}. A copy of this dict + is actually used for each docstring, so that each docstring's + examples start with a clean slate. + + Optional keyword arg ``extraglobs`` gives a dictionary that should be + merged into the globals that are used to execute examples. By + default, no extra globals are used. + + Optional keyword arg ``verbose`` prints lots of stuff if true, prints + only failures if false; by default, it's true iff "-v" is in sys.argv. + + Optional keyword arg ``report`` prints a summary at the end when true, + else prints nothing at the end. In verbose mode, the summary is + detailed, else very brief (in fact, empty if all tests passed). + + Optional keyword arg ``optionflags`` or's together module constants, + and defaults to 0. Possible values (see the docs for details): + + - DONT_ACCEPT_TRUE_FOR_1 + - DONT_ACCEPT_BLANKLINE + - NORMALIZE_WHITESPACE + - ELLIPSIS + - SKIP + - IGNORE_EXCEPTION_DETAIL + - REPORT_UDIFF + - REPORT_CDIFF + - REPORT_NDIFF + - REPORT_ONLY_FIRST_FAILURE + + Optional keyword arg ``raise_on_error`` raises an exception on the + first unexpected exception or failure. This allows failures to be + post-mortem debugged. + + Optional keyword arg ``parser`` specifies a DocTestParser (or + subclass) that should be used to extract tests from the files. + + Optional keyword arg ``encoding`` specifies an encoding that should + be used to convert the file to unicode. + + Advanced tomfoolery: testmod runs methods of a local instance of + class doctest.Tester, then merges the results into (or creates) + global Tester instance doctest.master. Methods of doctest.master + can be called directly too, if you want to do something unusual. + Passing report=0 to testmod is especially useful then, to delay + displaying a summary. Invoke doctest.master.summarize(verbose) + when you're done fiddling. + """ + if package and not module_relative: + raise ValueError("Package may only be specified for module-" + "relative paths.") + + # Relativize the path + text, filename = pdoctest._load_testfile( + filename, package, module_relative, encoding) + + # If no name was given, then use the file's name. + if name is None: + name = os.path.basename(filename) + + # Assemble the globals. + if globs is None: + globs = {} + else: + globs = globs.copy() + if extraglobs is not None: + globs.update(extraglobs) + if '__name__' not in globs: + globs['__name__'] = '__main__' + + if raise_on_error: + runner = pdoctest.DebugRunner(verbose=verbose, optionflags=optionflags) + else: + runner = SymPyDocTestRunner(verbose=verbose, optionflags=optionflags) + runner._checker = SymPyOutputChecker() + + # Read the file, convert it to a test, and run it. + test = parser.get_doctest(text, globs, name, filename, 0) + runner.run(test) + + if report: + runner.summarize() + + if pdoctest.master is None: + pdoctest.master = runner + else: + pdoctest.master.merge(runner) + + return SymPyTestResults(runner.failures, runner.tries) + + +class SymPyTests: + + def __init__(self, reporter, kw="", post_mortem=False, + seed=None, fast_threshold=None, slow_threshold=None): + self._post_mortem = post_mortem + self._kw = kw + self._count = 0 + self._root_dir = get_sympy_dir() + self._reporter = reporter + self._reporter.root_dir(self._root_dir) + self._testfiles = [] + self._seed = seed if seed is not None else random.random() + + # Defaults in seconds, from human / UX design limits + # http://www.nngroup.com/articles/response-times-3-important-limits/ + # + # These defaults are *NOT* set in stone as we are measuring different + # things, so others feel free to come up with a better yardstick :) + if fast_threshold: + self._fast_threshold = float(fast_threshold) + else: + self._fast_threshold = 8 + if slow_threshold: + self._slow_threshold = float(slow_threshold) + else: + self._slow_threshold = 10 + + def test(self, sort=False, timeout=False, slow=False, + enhance_asserts=False, fail_on_timeout=False): + """ + Runs the tests returning True if all tests pass, otherwise False. + + If sort=False run tests in random order. + """ + if sort: + self._testfiles.sort() + elif slow: + pass + else: + random.seed(self._seed) + random.shuffle(self._testfiles) + self._reporter.start(self._seed) + for f in self._testfiles: + try: + self.test_file(f, sort, timeout, slow, + enhance_asserts, fail_on_timeout) + except KeyboardInterrupt: + print(" interrupted by user") + self._reporter.finish() + raise + return self._reporter.finish() + + def _enhance_asserts(self, source): + from ast import (NodeTransformer, Compare, Name, Store, Load, Tuple, + Assign, BinOp, Str, Mod, Assert, parse, fix_missing_locations) + + ops = {"Eq": '==', "NotEq": '!=', "Lt": '<', "LtE": '<=', + "Gt": '>', "GtE": '>=', "Is": 'is', "IsNot": 'is not', + "In": 'in', "NotIn": 'not in'} + + class Transform(NodeTransformer): + def visit_Assert(self, stmt): + if isinstance(stmt.test, Compare): + compare = stmt.test + values = [compare.left] + compare.comparators + names = [ "_%s" % i for i, _ in enumerate(values) ] + names_store = [ Name(n, Store()) for n in names ] + names_load = [ Name(n, Load()) for n in names ] + target = Tuple(names_store, Store()) + value = Tuple(values, Load()) + assign = Assign([target], value) + new_compare = Compare(names_load[0], compare.ops, names_load[1:]) + msg_format = "\n%s " + "\n%s ".join([ ops[op.__class__.__name__] for op in compare.ops ]) + "\n%s" + msg = BinOp(Str(msg_format), Mod(), Tuple(names_load, Load())) + test = Assert(new_compare, msg, lineno=stmt.lineno, col_offset=stmt.col_offset) + return [assign, test] + else: + return stmt + + tree = parse(source) + new_tree = Transform().visit(tree) + return fix_missing_locations(new_tree) + + def test_file(self, filename, sort=True, timeout=False, slow=False, + enhance_asserts=False, fail_on_timeout=False): + reporter = self._reporter + funcs = [] + try: + gl = {'__file__': filename} + try: + open_file = lambda: open(filename, encoding="utf8") + + with open_file() as f: + source = f.read() + if self._kw: + for l in source.splitlines(): + if l.lstrip().startswith('def '): + if any(l.lower().find(k.lower()) != -1 for k in self._kw): + break + else: + return + + if enhance_asserts: + try: + source = self._enhance_asserts(source) + except ImportError: + pass + + code = compile(source, filename, "exec", flags=0, dont_inherit=True) + exec(code, gl) + except (SystemExit, KeyboardInterrupt): + raise + except ImportError: + reporter.import_error(filename, sys.exc_info()) + return + except Exception: + reporter.test_exception(sys.exc_info()) + + clear_cache() + self._count += 1 + random.seed(self._seed) + disabled = gl.get("disabled", False) + if not disabled: + # we need to filter only those functions that begin with 'test_' + # We have to be careful about decorated functions. As long as + # the decorator uses functools.wraps, we can detect it. + funcs = [] + for f in gl: + if (f.startswith("test_") and (inspect.isfunction(gl[f]) + or inspect.ismethod(gl[f]))): + func = gl[f] + # Handle multiple decorators + while hasattr(func, '__wrapped__'): + func = func.__wrapped__ + + if inspect.getsourcefile(func) == filename: + funcs.append(gl[f]) + if slow: + funcs = [f for f in funcs if getattr(f, '_slow', False)] + # Sorting of XFAILed functions isn't fixed yet :-( + funcs.sort(key=lambda x: inspect.getsourcelines(x)[1]) + i = 0 + while i < len(funcs): + if inspect.isgeneratorfunction(funcs[i]): + # some tests can be generators, that return the actual + # test functions. We unpack it below: + f = funcs.pop(i) + for fg in f(): + func = fg[0] + args = fg[1:] + fgw = lambda: func(*args) + funcs.insert(i, fgw) + i += 1 + else: + i += 1 + # drop functions that are not selected with the keyword expression: + funcs = [x for x in funcs if self.matches(x)] + + if not funcs: + return + except Exception: + reporter.entering_filename(filename, len(funcs)) + raise + + reporter.entering_filename(filename, len(funcs)) + if not sort: + random.shuffle(funcs) + + for f in funcs: + start = time.time() + reporter.entering_test(f) + try: + if getattr(f, '_slow', False) and not slow: + raise Skipped("Slow") + with raise_on_deprecated(): + if timeout: + self._timeout(f, timeout, fail_on_timeout) + else: + random.seed(self._seed) + f() + except KeyboardInterrupt: + if getattr(f, '_slow', False): + reporter.test_skip("KeyboardInterrupt") + else: + raise + except Exception: + if timeout: + signal.alarm(0) # Disable the alarm. It could not be handled before. + t, v, tr = sys.exc_info() + if t is AssertionError: + reporter.test_fail((t, v, tr)) + if self._post_mortem: + pdb.post_mortem(tr) + elif t.__name__ == "Skipped": + reporter.test_skip(v) + elif t.__name__ == "XFail": + reporter.test_xfail() + elif t.__name__ == "XPass": + reporter.test_xpass(v) + else: + reporter.test_exception((t, v, tr)) + if self._post_mortem: + pdb.post_mortem(tr) + else: + reporter.test_pass() + taken = time.time() - start + if taken > self._slow_threshold: + filename = os.path.relpath(filename, reporter._root_dir) + reporter.slow_test_functions.append( + (filename + "::" + f.__name__, taken)) + if getattr(f, '_slow', False) and slow: + if taken < self._fast_threshold: + filename = os.path.relpath(filename, reporter._root_dir) + reporter.fast_test_functions.append( + (filename + "::" + f.__name__, taken)) + reporter.leaving_filename() + + def _timeout(self, function, timeout, fail_on_timeout): + def callback(x, y): + signal.alarm(0) + if fail_on_timeout: + raise TimeOutError("Timed out after %d seconds" % timeout) + else: + raise Skipped("Timeout") + signal.signal(signal.SIGALRM, callback) + signal.alarm(timeout) # Set an alarm with a given timeout + function() + signal.alarm(0) # Disable the alarm + + def matches(self, x): + """ + Does the keyword expression self._kw match "x"? Returns True/False. + + Always returns True if self._kw is "". + """ + if not self._kw: + return True + for kw in self._kw: + if x.__name__.lower().find(kw.lower()) != -1: + return True + return False + + def get_test_files(self, dir, pat='test_*.py'): + """ + Returns the list of test_*.py (default) files at or below directory + ``dir`` relative to the SymPy home directory. + """ + dir = os.path.join(self._root_dir, convert_to_native_paths([dir])[0]) + + g = [] + for path, folders, files in os.walk(dir): + g.extend([os.path.join(path, f) for f in files if fnmatch(f, pat)]) + + return sorted([os.path.normcase(gi) for gi in g]) + + +class SymPyDocTests: + + def __init__(self, reporter, normal): + self._count = 0 + self._root_dir = get_sympy_dir() + self._reporter = reporter + self._reporter.root_dir(self._root_dir) + self._normal = normal + + self._testfiles = [] + + def test(self): + """ + Runs the tests and returns True if all tests pass, otherwise False. + """ + self._reporter.start() + for f in self._testfiles: + try: + self.test_file(f) + except KeyboardInterrupt: + print(" interrupted by user") + self._reporter.finish() + raise + return self._reporter.finish() + + def test_file(self, filename): + clear_cache() + + from io import StringIO + import sympy.interactive.printing as interactive_printing + from sympy.printing.pretty.pretty import pprint_use_unicode + from sympy.printing.pretty import stringpict + + rel_name = filename[len(self._root_dir) + 1:] + dirname, file = os.path.split(filename) + module = rel_name.replace(os.sep, '.')[:-3] + + if rel_name.startswith("examples"): + # Examples files do not have __init__.py files, + # So we have to temporarily extend sys.path to import them + sys.path.insert(0, dirname) + module = file[:-3] # remove ".py" + try: + module = pdoctest._normalize_module(module) + tests = SymPyDocTestFinder().find(module) + except (SystemExit, KeyboardInterrupt): + raise + except ImportError: + self._reporter.import_error(filename, sys.exc_info()) + return + finally: + if rel_name.startswith("examples"): + del sys.path[0] + + tests = [test for test in tests if len(test.examples) > 0] + # By default tests are sorted by alphabetical order by function name. + # We sort by line number so one can edit the file sequentially from + # bottom to top. However, if there are decorated functions, their line + # numbers will be too large and for now one must just search for these + # by text and function name. + tests.sort(key=lambda x: -x.lineno) + + if not tests: + return + self._reporter.entering_filename(filename, len(tests)) + for test in tests: + assert len(test.examples) != 0 + + if self._reporter._verbose: + self._reporter.write("\n{} ".format(test.name)) + + # check if there are external dependencies which need to be met + if '_doctest_depends_on' in test.globs: + try: + self._check_dependencies(**test.globs['_doctest_depends_on']) + except DependencyError as e: + self._reporter.test_skip(v=str(e)) + continue + + runner = SymPyDocTestRunner(verbose=self._reporter._verbose==2, + optionflags=pdoctest.ELLIPSIS | + pdoctest.NORMALIZE_WHITESPACE | + pdoctest.IGNORE_EXCEPTION_DETAIL) + runner._checker = SymPyOutputChecker() + old = sys.stdout + new = old if self._reporter._verbose==2 else StringIO() + sys.stdout = new + # If the testing is normal, the doctests get importing magic to + # provide the global namespace. If not normal (the default) then + # then must run on their own; all imports must be explicit within + # a function's docstring. Once imported that import will be + # available to the rest of the tests in a given function's + # docstring (unless clear_globs=True below). + if not self._normal: + test.globs = {} + # if this is uncommented then all the test would get is what + # comes by default with a "from sympy import *" + #exec('from sympy import *') in test.globs + old_displayhook = sys.displayhook + use_unicode_prev, wrap_line_prev = setup_pprint() + + try: + f, t = runner.run(test, + out=new.write, clear_globs=False) + except KeyboardInterrupt: + raise + finally: + sys.stdout = old + if f > 0: + self._reporter.doctest_fail(test.name, new.getvalue()) + else: + self._reporter.test_pass() + sys.displayhook = old_displayhook + interactive_printing.NO_GLOBAL = False + pprint_use_unicode(use_unicode_prev) + stringpict._GLOBAL_WRAP_LINE = wrap_line_prev + + self._reporter.leaving_filename() + + def get_test_files(self, dir, pat='*.py', init_only=True): + r""" + Returns the list of \*.py files (default) from which docstrings + will be tested which are at or below directory ``dir``. By default, + only those that have an __init__.py in their parent directory + and do not start with ``test_`` will be included. + """ + def importable(x): + """ + Checks if given pathname x is an importable module by checking for + __init__.py file. + + Returns True/False. + + Currently we only test if the __init__.py file exists in the + directory with the file "x" (in theory we should also test all the + parent dirs). + """ + init_py = os.path.join(os.path.dirname(x), "__init__.py") + return os.path.exists(init_py) + + dir = os.path.join(self._root_dir, convert_to_native_paths([dir])[0]) + + g = [] + for path, folders, files in os.walk(dir): + g.extend([os.path.join(path, f) for f in files + if not f.startswith('test_') and fnmatch(f, pat)]) + if init_only: + # skip files that are not importable (i.e. missing __init__.py) + g = [x for x in g if importable(x)] + + return [os.path.normcase(gi) for gi in g] + + def _check_dependencies(self, + executables=(), + modules=(), + disable_viewers=(), + python_version=(3, 5), + ground_types=None): + """ + Checks if the dependencies for the test are installed. + + Raises ``DependencyError`` it at least one dependency is not installed. + """ + + for executable in executables: + if not shutil.which(executable): + raise DependencyError("Could not find %s" % executable) + + for module in modules: + if module == 'matplotlib': + matplotlib = import_module( + 'matplotlib', + import_kwargs={'fromlist': + ['pyplot', 'cm', 'collections']}, + min_module_version='1.0.0', catch=(RuntimeError,)) + if matplotlib is None: + raise DependencyError("Could not import matplotlib") + else: + if not import_module(module): + raise DependencyError("Could not import %s" % module) + + if disable_viewers: + tempdir = tempfile.mkdtemp() + os.environ['PATH'] = '%s:%s' % (tempdir, os.environ['PATH']) + + vw = ('#!/usr/bin/env python3\n' + 'import sys\n' + 'if len(sys.argv) <= 1:\n' + ' exit("wrong number of args")\n') + + for viewer in disable_viewers: + Path(os.path.join(tempdir, viewer)).write_text(vw) + + # make the file executable + os.chmod(os.path.join(tempdir, viewer), + stat.S_IREAD | stat.S_IWRITE | stat.S_IXUSR) + + if python_version: + if sys.version_info < python_version: + raise DependencyError("Requires Python >= " + '.'.join(map(str, python_version))) + + if ground_types is not None: + if GROUND_TYPES not in ground_types: + raise DependencyError("Requires ground_types in " + str(ground_types)) + + if 'pyglet' in modules: + # monkey-patch pyglet s.t. it does not open a window during + # doctesting + import pyglet + class DummyWindow: + def __init__(self, *args, **kwargs): + self.has_exit = True + self.width = 600 + self.height = 400 + + def set_vsync(self, x): + pass + + def switch_to(self): + pass + + def push_handlers(self, x): + pass + + def close(self): + pass + + pyglet.window.Window = DummyWindow + + +class SymPyDocTestFinder(DocTestFinder): + """ + A class used to extract the DocTests that are relevant to a given + object, from its docstring and the docstrings of its contained + objects. Doctests can currently be extracted from the following + object types: modules, functions, classes, methods, staticmethods, + classmethods, and properties. + + Modified from doctest's version to look harder for code that + appears comes from a different module. For example, the @vectorize + decorator makes it look like functions come from multidimensional.py + even though their code exists elsewhere. + """ + + def _find(self, tests, obj, name, module, source_lines, globs, seen): + """ + Find tests for the given object and any contained objects, and + add them to ``tests``. + """ + if self._verbose: + print('Finding tests in %s' % name) + + # If we've already processed this object, then ignore it. + if id(obj) in seen: + return + seen[id(obj)] = 1 + + # Make sure we don't run doctests for classes outside of sympy, such + # as in numpy or scipy. + if inspect.isclass(obj): + if obj.__module__.split('.')[0] != 'sympy': + return + + # Find a test for this object, and add it to the list of tests. + test = self._get_test(obj, name, module, globs, source_lines) + if test is not None: + tests.append(test) + + if not self._recurse: + return + + # Look for tests in a module's contained objects. + if inspect.ismodule(obj): + for rawname, val in obj.__dict__.items(): + # Recurse to functions & classes. + if inspect.isfunction(val) or inspect.isclass(val): + # Make sure we don't run doctests functions or classes + # from different modules + if val.__module__ != module.__name__: + continue + + assert self._from_module(module, val), \ + "%s is not in module %s (rawname %s)" % (val, module, rawname) + + try: + valname = '%s.%s' % (name, rawname) + self._find(tests, val, valname, module, + source_lines, globs, seen) + except KeyboardInterrupt: + raise + + # Look for tests in a module's __test__ dictionary. + for valname, val in getattr(obj, '__test__', {}).items(): + if not isinstance(valname, str): + raise ValueError("SymPyDocTestFinder.find: __test__ keys " + "must be strings: %r" % + (type(valname),)) + if not (inspect.isfunction(val) or inspect.isclass(val) or + inspect.ismethod(val) or inspect.ismodule(val) or + isinstance(val, str)): + raise ValueError("SymPyDocTestFinder.find: __test__ values " + "must be strings, functions, methods, " + "classes, or modules: %r" % + (type(val),)) + valname = '%s.__test__.%s' % (name, valname) + self._find(tests, val, valname, module, source_lines, + globs, seen) + + + # Look for tests in a class's contained objects. + if inspect.isclass(obj): + for valname, val in obj.__dict__.items(): + # Special handling for staticmethod/classmethod. + if isinstance(val, staticmethod): + val = getattr(obj, valname) + if isinstance(val, classmethod): + val = getattr(obj, valname).__func__ + + + # Recurse to methods, properties, and nested classes. + if ((inspect.isfunction(unwrap(val)) or + inspect.isclass(val) or + isinstance(val, property)) and + self._from_module(module, val)): + # Make sure we don't run doctests functions or classes + # from different modules + if isinstance(val, property): + if hasattr(val.fget, '__module__'): + if val.fget.__module__ != module.__name__: + continue + else: + if val.__module__ != module.__name__: + continue + + assert self._from_module(module, val), \ + "%s is not in module %s (valname %s)" % ( + val, module, valname) + + valname = '%s.%s' % (name, valname) + self._find(tests, val, valname, module, source_lines, + globs, seen) + + def _get_test(self, obj, name, module, globs, source_lines): + """ + Return a DocTest for the given object, if it defines a docstring; + otherwise, return None. + """ + + lineno = None + + # Extract the object's docstring. If it does not have one, + # then return None (no test for this object). + if isinstance(obj, str): + # obj is a string in the case for objects in the polys package. + # Note that source_lines is a binary string (compiled polys + # modules), which can't be handled by _find_lineno so determine + # the line number here. + + docstring = obj + + matches = re.findall(r"line \d+", name) + assert len(matches) == 1, \ + "string '%s' does not contain lineno " % name + + # NOTE: this is not the exact linenumber but its better than no + # lineno ;) + lineno = int(matches[0][5:]) + + else: + docstring = getattr(obj, '__doc__', '') + if docstring is None: + docstring = '' + if not isinstance(docstring, str): + docstring = str(docstring) + + # Don't bother if the docstring is empty. + if self._exclude_empty and not docstring: + return None + + # check that properties have a docstring because _find_lineno + # assumes it + if isinstance(obj, property): + if obj.fget.__doc__ is None: + return None + + # Find the docstring's location in the file. + if lineno is None: + obj = unwrap(obj) + # handling of properties is not implemented in _find_lineno so do + # it here + if hasattr(obj, 'func_closure') and obj.func_closure is not None: + tobj = obj.func_closure[0].cell_contents + elif isinstance(obj, property): + tobj = obj.fget + else: + tobj = obj + lineno = self._find_lineno(tobj, source_lines) + + if lineno is None: + return None + + # Return a DocTest for this object. + if module is None: + filename = None + else: + filename = getattr(module, '__file__', module.__name__) + if filename[-4:] in (".pyc", ".pyo"): + filename = filename[:-1] + + globs['_doctest_depends_on'] = getattr(obj, '_doctest_depends_on', {}) + + return self._parser.get_doctest(docstring, globs, name, + filename, lineno) + + +class SymPyDocTestRunner(DocTestRunner): + """ + A class used to run DocTest test cases, and accumulate statistics. + The ``run`` method is used to process a single DocTest case. It + returns a tuple ``(f, t)``, where ``t`` is the number of test cases + tried, and ``f`` is the number of test cases that failed. + + Modified from the doctest version to not reset the sys.displayhook (see + issue 5140). + + See the docstring of the original DocTestRunner for more information. + """ + + def run(self, test, compileflags=None, out=None, clear_globs=True): + """ + Run the examples in ``test``, and display the results using the + writer function ``out``. + + The examples are run in the namespace ``test.globs``. If + ``clear_globs`` is true (the default), then this namespace will + be cleared after the test runs, to help with garbage + collection. If you would like to examine the namespace after + the test completes, then use ``clear_globs=False``. + + ``compileflags`` gives the set of flags that should be used by + the Python compiler when running the examples. If not + specified, then it will default to the set of future-import + flags that apply to ``globs``. + + The output of each example is checked using + ``SymPyDocTestRunner.check_output``, and the results are + formatted by the ``SymPyDocTestRunner.report_*`` methods. + """ + self.test = test + + # Remove ``` from the end of example, which may appear in Markdown + # files + for example in test.examples: + example.want = example.want.replace('```\n', '') + example.exc_msg = example.exc_msg and example.exc_msg.replace('```\n', '') + + + if compileflags is None: + compileflags = pdoctest._extract_future_flags(test.globs) + + save_stdout = sys.stdout + if out is None: + out = save_stdout.write + sys.stdout = self._fakeout + + # Patch pdb.set_trace to restore sys.stdout during interactive + # debugging (so it's not still redirected to self._fakeout). + # Note that the interactive output will go to *our* + # save_stdout, even if that's not the real sys.stdout; this + # allows us to write test cases for the set_trace behavior. + save_set_trace = pdb.set_trace + self.debugger = pdoctest._OutputRedirectingPdb(save_stdout) + self.debugger.reset() + pdb.set_trace = self.debugger.set_trace + + # Patch linecache.getlines, so we can see the example's source + # when we're inside the debugger. + self.save_linecache_getlines = pdoctest.linecache.getlines + linecache.getlines = self.__patched_linecache_getlines + + # Fail for deprecation warnings + with raise_on_deprecated(): + try: + return self.__run(test, compileflags, out) + finally: + sys.stdout = save_stdout + pdb.set_trace = save_set_trace + linecache.getlines = self.save_linecache_getlines + if clear_globs: + test.globs.clear() + + +# We have to override the name mangled methods. +monkeypatched_methods = [ + 'patched_linecache_getlines', + 'run', + 'record_outcome' +] +for method in monkeypatched_methods: + oldname = '_DocTestRunner__' + method + newname = '_SymPyDocTestRunner__' + method + setattr(SymPyDocTestRunner, newname, getattr(DocTestRunner, oldname)) + + +class SymPyOutputChecker(pdoctest.OutputChecker): + """ + Compared to the OutputChecker from the stdlib our OutputChecker class + supports numerical comparison of floats occurring in the output of the + doctest examples + """ + + def __init__(self): + # NOTE OutputChecker is an old-style class with no __init__ method, + # so we can't call the base class version of __init__ here + + got_floats = r'(\d+\.\d*|\.\d+)' + + # floats in the 'want' string may contain ellipses + want_floats = got_floats + r'(\.{3})?' + + front_sep = r'\s|\+|\-|\*|,' + back_sep = front_sep + r'|j|e' + + fbeg = r'^%s(?=%s|$)' % (got_floats, back_sep) + fmidend = r'(?<=%s)%s(?=%s|$)' % (front_sep, got_floats, back_sep) + self.num_got_rgx = re.compile(r'(%s|%s)' %(fbeg, fmidend)) + + fbeg = r'^%s(?=%s|$)' % (want_floats, back_sep) + fmidend = r'(?<=%s)%s(?=%s|$)' % (front_sep, want_floats, back_sep) + self.num_want_rgx = re.compile(r'(%s|%s)' %(fbeg, fmidend)) + + def check_output(self, want, got, optionflags): + """ + Return True iff the actual output from an example (`got`) + matches the expected output (`want`). These strings are + always considered to match if they are identical; but + depending on what option flags the test runner is using, + several non-exact match types are also possible. See the + documentation for `TestRunner` for more information about + option flags. + """ + # Handle the common case first, for efficiency: + # if they're string-identical, always return true. + if got == want: + return True + + # TODO parse integers as well ? + # Parse floats and compare them. If some of the parsed floats contain + # ellipses, skip the comparison. + matches = self.num_got_rgx.finditer(got) + numbers_got = [match.group(1) for match in matches] # list of strs + matches = self.num_want_rgx.finditer(want) + numbers_want = [match.group(1) for match in matches] # list of strs + if len(numbers_got) != len(numbers_want): + return False + + if len(numbers_got) > 0: + nw_ = [] + for ng, nw in zip(numbers_got, numbers_want): + if '...' in nw: + nw_.append(ng) + continue + else: + nw_.append(nw) + + if abs(float(ng)-float(nw)) > 1e-5: + return False + + got = self.num_got_rgx.sub(r'%s', got) + got = got % tuple(nw_) + + # can be used as a special sequence to signify a + # blank line, unless the DONT_ACCEPT_BLANKLINE flag is used. + if not (optionflags & pdoctest.DONT_ACCEPT_BLANKLINE): + # Replace in want with a blank line. + want = re.sub(r'(?m)^%s\s*?$' % re.escape(pdoctest.BLANKLINE_MARKER), + '', want) + # If a line in got contains only spaces, then remove the + # spaces. + got = re.sub(r'(?m)^\s*?$', '', got) + if got == want: + return True + + # This flag causes doctest to ignore any differences in the + # contents of whitespace strings. Note that this can be used + # in conjunction with the ELLIPSIS flag. + if optionflags & pdoctest.NORMALIZE_WHITESPACE: + got = ' '.join(got.split()) + want = ' '.join(want.split()) + if got == want: + return True + + # The ELLIPSIS flag says to let the sequence "..." in `want` + # match any substring in `got`. + if optionflags & pdoctest.ELLIPSIS: + if pdoctest._ellipsis_match(want, got): + return True + + # We didn't find any match; return false. + return False + + +class Reporter: + """ + Parent class for all reporters. + """ + pass + + +class PyTestReporter(Reporter): + """ + Py.test like reporter. Should produce output identical to py.test. + """ + + def __init__(self, verbose=False, tb="short", colors=True, + force_colors=False, split=None): + self._verbose = verbose + self._tb_style = tb + self._colors = colors + self._force_colors = force_colors + self._xfailed = 0 + self._xpassed = [] + self._failed = [] + self._failed_doctest = [] + self._passed = 0 + self._skipped = 0 + self._exceptions = [] + self._terminal_width = None + self._default_width = 80 + self._split = split + self._active_file = '' + self._active_f = None + + # TODO: Should these be protected? + self.slow_test_functions = [] + self.fast_test_functions = [] + + # this tracks the x-position of the cursor (useful for positioning + # things on the screen), without the need for any readline library: + self._write_pos = 0 + self._line_wrap = False + + def root_dir(self, dir): + self._root_dir = dir + + @property + def terminal_width(self): + if self._terminal_width is not None: + return self._terminal_width + + def findout_terminal_width(): + if sys.platform == "win32": + # Windows support is based on: + # + # http://code.activestate.com/recipes/ + # 440694-determine-size-of-console-window-on-windows/ + + from ctypes import windll, create_string_buffer + + h = windll.kernel32.GetStdHandle(-12) + csbi = create_string_buffer(22) + res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi) + + if res: + import struct + (_, _, _, _, _, left, _, right, _, _, _) = \ + struct.unpack("hhhhHhhhhhh", csbi.raw) + return right - left + else: + return self._default_width + + if hasattr(sys.stdout, 'isatty') and not sys.stdout.isatty(): + return self._default_width # leave PIPEs alone + + try: + process = subprocess.Popen(['stty', '-a'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + stdout = stdout.decode("utf-8") + except OSError: + pass + else: + # We support the following output formats from stty: + # + # 1) Linux -> columns 80 + # 2) OS X -> 80 columns + # 3) Solaris -> columns = 80 + + re_linux = r"columns\s+(?P\d+);" + re_osx = r"(?P\d+)\s*columns;" + re_solaris = r"columns\s+=\s+(?P\d+);" + + for regex in (re_linux, re_osx, re_solaris): + match = re.search(regex, stdout) + + if match is not None: + columns = match.group('columns') + + try: + width = int(columns) + except ValueError: + pass + if width != 0: + return width + + return self._default_width + + width = findout_terminal_width() + self._terminal_width = width + + return width + + def write(self, text, color="", align="left", width=None, + force_colors=False): + """ + Prints a text on the screen. + + It uses sys.stdout.write(), so no readline library is necessary. + + Parameters + ========== + + color : choose from the colors below, "" means default color + align : "left"/"right", "left" is a normal print, "right" is aligned on + the right-hand side of the screen, filled with spaces if + necessary + width : the screen width + + """ + color_templates = ( + ("Black", "0;30"), + ("Red", "0;31"), + ("Green", "0;32"), + ("Brown", "0;33"), + ("Blue", "0;34"), + ("Purple", "0;35"), + ("Cyan", "0;36"), + ("LightGray", "0;37"), + ("DarkGray", "1;30"), + ("LightRed", "1;31"), + ("LightGreen", "1;32"), + ("Yellow", "1;33"), + ("LightBlue", "1;34"), + ("LightPurple", "1;35"), + ("LightCyan", "1;36"), + ("White", "1;37"), + ) + + colors = {} + + for name, value in color_templates: + colors[name] = value + c_normal = '\033[0m' + c_color = '\033[%sm' + + if width is None: + width = self.terminal_width + + if align == "right": + if self._write_pos + len(text) > width: + # we don't fit on the current line, create a new line + self.write("\n") + self.write(" "*(width - self._write_pos - len(text))) + + if not self._force_colors and hasattr(sys.stdout, 'isatty') and not \ + sys.stdout.isatty(): + # the stdout is not a terminal, this for example happens if the + # output is piped to less, e.g. "bin/test | less". In this case, + # the terminal control sequences would be printed verbatim, so + # don't use any colors. + color = "" + elif sys.platform == "win32": + # Windows consoles don't support ANSI escape sequences + color = "" + elif not self._colors: + color = "" + + if self._line_wrap: + if text[0] != "\n": + sys.stdout.write("\n") + + # Avoid UnicodeEncodeError when printing out test failures + if IS_WINDOWS: + text = text.encode('raw_unicode_escape').decode('utf8', 'ignore') + elif not sys.stdout.encoding.lower().startswith('utf'): + text = text.encode(sys.stdout.encoding, 'backslashreplace' + ).decode(sys.stdout.encoding) + + if color == "": + sys.stdout.write(text) + else: + sys.stdout.write("%s%s%s" % + (c_color % colors[color], text, c_normal)) + sys.stdout.flush() + l = text.rfind("\n") + if l == -1: + self._write_pos += len(text) + else: + self._write_pos = len(text) - l - 1 + self._line_wrap = self._write_pos >= width + self._write_pos %= width + + def write_center(self, text, delim="="): + width = self.terminal_width + if text != "": + text = " %s " % text + idx = (width - len(text)) // 2 + t = delim*idx + text + delim*(width - idx - len(text)) + self.write(t + "\n") + + def write_exception(self, e, val, tb): + # remove the first item, as that is always runtests.py + tb = tb.tb_next + t = traceback.format_exception(e, val, tb) + self.write("".join(t)) + + def start(self, seed=None, msg="test process starts"): + self.write_center(msg) + executable = sys.executable + v = tuple(sys.version_info) + python_version = "%s.%s.%s-%s-%s" % v + implementation = platform.python_implementation() + if implementation == 'PyPy': + implementation += " %s.%s.%s-%s-%s" % sys.pypy_version_info + self.write("executable: %s (%s) [%s]\n" % + (executable, python_version, implementation)) + from sympy.utilities.misc import ARCH + self.write("architecture: %s\n" % ARCH) + from sympy.core.cache import USE_CACHE + self.write("cache: %s\n" % USE_CACHE) + version = '' + if GROUND_TYPES =='gmpy': + import gmpy2 as gmpy + version = gmpy.version() + self.write("ground types: %s %s\n" % (GROUND_TYPES, version)) + numpy = import_module('numpy') + self.write("numpy: %s\n" % (None if not numpy else numpy.__version__)) + if seed is not None: + self.write("random seed: %d\n" % seed) + from sympy.utilities.misc import HASH_RANDOMIZATION + self.write("hash randomization: ") + hash_seed = os.getenv("PYTHONHASHSEED") or '0' + if HASH_RANDOMIZATION and (hash_seed == "random" or int(hash_seed)): + self.write("on (PYTHONHASHSEED=%s)\n" % hash_seed) + else: + self.write("off\n") + if self._split: + self.write("split: %s\n" % self._split) + self.write('\n') + self._t_start = clock() + + def finish(self): + self._t_end = clock() + self.write("\n") + global text, linelen + text = "tests finished: %d passed, " % self._passed + linelen = len(text) + + def add_text(mytext): + global text, linelen + """Break new text if too long.""" + if linelen + len(mytext) > self.terminal_width: + text += '\n' + linelen = 0 + text += mytext + linelen += len(mytext) + + if len(self._failed) > 0: + add_text("%d failed, " % len(self._failed)) + if len(self._failed_doctest) > 0: + add_text("%d failed, " % len(self._failed_doctest)) + if self._skipped > 0: + add_text("%d skipped, " % self._skipped) + if self._xfailed > 0: + add_text("%d expected to fail, " % self._xfailed) + if len(self._xpassed) > 0: + add_text("%d expected to fail but passed, " % len(self._xpassed)) + if len(self._exceptions) > 0: + add_text("%d exceptions, " % len(self._exceptions)) + add_text("in %.2f seconds" % (self._t_end - self._t_start)) + + if self.slow_test_functions: + self.write_center('slowest tests', '_') + sorted_slow = sorted(self.slow_test_functions, key=lambda r: r[1]) + for slow_func_name, taken in sorted_slow: + print('%s - Took %.3f seconds' % (slow_func_name, taken)) + + if self.fast_test_functions: + self.write_center('unexpectedly fast tests', '_') + sorted_fast = sorted(self.fast_test_functions, + key=lambda r: r[1]) + for fast_func_name, taken in sorted_fast: + print('%s - Took %.3f seconds' % (fast_func_name, taken)) + + if len(self._xpassed) > 0: + self.write_center("xpassed tests", "_") + for e in self._xpassed: + self.write("%s: %s\n" % (e[0], e[1])) + self.write("\n") + + if self._tb_style != "no" and len(self._exceptions) > 0: + for e in self._exceptions: + filename, f, (t, val, tb) = e + self.write_center("", "_") + if f is None: + s = "%s" % filename + else: + s = "%s:%s" % (filename, f.__name__) + self.write_center(s, "_") + self.write_exception(t, val, tb) + self.write("\n") + + if self._tb_style != "no" and len(self._failed) > 0: + for e in self._failed: + filename, f, (t, val, tb) = e + self.write_center("", "_") + self.write_center("%s::%s" % (filename, f.__name__), "_") + self.write_exception(t, val, tb) + self.write("\n") + + if self._tb_style != "no" and len(self._failed_doctest) > 0: + for e in self._failed_doctest: + filename, msg = e + self.write_center("", "_") + self.write_center("%s" % filename, "_") + self.write(msg) + self.write("\n") + + self.write_center(text) + ok = len(self._failed) == 0 and len(self._exceptions) == 0 and \ + len(self._failed_doctest) == 0 + if not ok: + self.write("DO *NOT* COMMIT!\n") + return ok + + def entering_filename(self, filename, n): + rel_name = filename[len(self._root_dir) + 1:] + self._active_file = rel_name + self._active_file_error = False + self.write(rel_name) + self.write("[%d] " % n) + + def leaving_filename(self): + self.write(" ") + if self._active_file_error: + self.write("[FAIL]", "Red", align="right") + else: + self.write("[OK]", "Green", align="right") + self.write("\n") + if self._verbose: + self.write("\n") + + def entering_test(self, f): + self._active_f = f + if self._verbose: + self.write("\n" + f.__name__ + " ") + + def test_xfail(self): + self._xfailed += 1 + self.write("f", "Green") + + def test_xpass(self, v): + message = str(v) + self._xpassed.append((self._active_file, message)) + self.write("X", "Green") + + def test_fail(self, exc_info): + self._failed.append((self._active_file, self._active_f, exc_info)) + self.write("F", "Red") + self._active_file_error = True + + def doctest_fail(self, name, error_msg): + # the first line contains "******", remove it: + error_msg = "\n".join(error_msg.split("\n")[1:]) + self._failed_doctest.append((name, error_msg)) + self.write("F", "Red") + self._active_file_error = True + + def test_pass(self, char="."): + self._passed += 1 + if self._verbose: + self.write("ok", "Green") + else: + self.write(char, "Green") + + def test_skip(self, v=None): + char = "s" + self._skipped += 1 + if v is not None: + message = str(v) + if message == "KeyboardInterrupt": + char = "K" + elif message == "Timeout": + char = "T" + elif message == "Slow": + char = "w" + if self._verbose: + if v is not None: + self.write(message + ' ', "Blue") + else: + self.write(" - ", "Blue") + self.write(char, "Blue") + + def test_exception(self, exc_info): + self._exceptions.append((self._active_file, self._active_f, exc_info)) + if exc_info[0] is TimeOutError: + self.write("T", "Red") + else: + self.write("E", "Red") + self._active_file_error = True + + def import_error(self, filename, exc_info): + self._exceptions.append((filename, None, exc_info)) + rel_name = filename[len(self._root_dir) + 1:] + self.write(rel_name) + self.write("[?] Failed to import", "Red") + self.write(" ") + self.write("[FAIL]", "Red", align="right") + self.write("\n") diff --git a/phivenv/Lib/site-packages/sympy/testing/runtests_pytest.py b/phivenv/Lib/site-packages/sympy/testing/runtests_pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..635f27864ca86571128e6c9a055199dfbde1ed63 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/runtests_pytest.py @@ -0,0 +1,461 @@ +"""Backwards compatible functions for running tests from SymPy using pytest. + +SymPy historically had its own testing framework that aimed to: +- be compatible with pytest; +- operate similarly (or identically) to pytest; +- not require any external dependencies; +- have all the functionality in one file only; +- have no magic, just import the test file and execute the test functions; and +- be portable. + +To reduce the maintenance burden of developing an independent testing framework +and to leverage the benefits of existing Python testing infrastructure, SymPy +now uses pytest (and various of its plugins) to run the test suite. + +To maintain backwards compatibility with the legacy testing interface of SymPy, +which implemented functions that allowed users to run the tests on their +installed version of SymPy, the functions in this module are implemented to +match the existing API while thinly wrapping pytest. + +These two key functions are `test` and `doctest`. + +""" + +import functools +import importlib.util +import os +import pathlib +import re +from fnmatch import fnmatch +from typing import List, Optional, Tuple + +try: + import pytest +except ImportError: + + class NoPytestError(Exception): + """Raise when an internal test helper function is called with pytest.""" + + class pytest: # type: ignore + """Shadow to support pytest features when pytest can't be imported.""" + + @staticmethod + def main(*args, **kwargs): + msg = 'pytest must be installed to run tests via this function' + raise NoPytestError(msg) + +from sympy.testing.runtests import test as test_sympy + + +TESTPATHS_DEFAULT = ( + pathlib.Path('sympy'), + pathlib.Path('doc', 'src'), +) +BLACKLIST_DEFAULT = ( + 'sympy/integrals/rubi/rubi_tests/tests', +) + + +class PytestPluginManager: + """Module names for pytest plugins used by SymPy.""" + PYTEST: str = 'pytest' + RANDOMLY: str = 'pytest_randomly' + SPLIT: str = 'pytest_split' + TIMEOUT: str = 'pytest_timeout' + XDIST: str = 'xdist' + + @functools.cached_property + def has_pytest(self) -> bool: + return bool(importlib.util.find_spec(self.PYTEST)) + + @functools.cached_property + def has_randomly(self) -> bool: + return bool(importlib.util.find_spec(self.RANDOMLY)) + + @functools.cached_property + def has_split(self) -> bool: + return bool(importlib.util.find_spec(self.SPLIT)) + + @functools.cached_property + def has_timeout(self) -> bool: + return bool(importlib.util.find_spec(self.TIMEOUT)) + + @functools.cached_property + def has_xdist(self) -> bool: + return bool(importlib.util.find_spec(self.XDIST)) + + +split_pattern = re.compile(r'([1-9][0-9]*)/([1-9][0-9]*)') + + +@functools.lru_cache +def sympy_dir() -> pathlib.Path: + """Returns the root SymPy directory.""" + return pathlib.Path(__file__).parents[2] + + +def update_args_with_paths( + paths: List[str], + keywords: Optional[Tuple[str]], + args: List[str], +) -> List[str]: + """Appends valid paths and flags to the args `list` passed to `pytest.main`. + + The are three different types of "path" that a user may pass to the `paths` + positional arguments, all of which need to be handled slightly differently: + + 1. Nothing is passed + The paths to the `testpaths` defined in `pytest.ini` need to be appended + to the arguments list. + 2. Full, valid paths are passed + These paths need to be validated but can then be directly appended to + the arguments list. + 3. Partial paths are passed. + The `testpaths` defined in `pytest.ini` need to be recursed and any + matches be appended to the arguments list. + + """ + + def find_paths_matching_partial(partial_paths): + partial_path_file_patterns = [] + for partial_path in partial_paths: + if len(partial_path) >= 4: + has_test_prefix = partial_path[:4] == 'test' + has_py_suffix = partial_path[-3:] == '.py' + elif len(partial_path) >= 3: + has_test_prefix = False + has_py_suffix = partial_path[-3:] == '.py' + else: + has_test_prefix = False + has_py_suffix = False + if has_test_prefix and has_py_suffix: + partial_path_file_patterns.append(partial_path) + elif has_test_prefix: + partial_path_file_patterns.append(f'{partial_path}*.py') + elif has_py_suffix: + partial_path_file_patterns.append(f'test*{partial_path}') + else: + partial_path_file_patterns.append(f'test*{partial_path}*.py') + matches = [] + for testpath in valid_testpaths_default: + for path, dirs, files in os.walk(testpath, topdown=True): + zipped = zip(partial_paths, partial_path_file_patterns) + for (partial_path, partial_path_file) in zipped: + if fnmatch(path, f'*{partial_path}*'): + matches.append(str(pathlib.Path(path))) + dirs[:] = [] + else: + for file in files: + if fnmatch(file, partial_path_file): + matches.append(str(pathlib.Path(path, file))) + return matches + + def is_tests_file(filepath: str) -> bool: + path = pathlib.Path(filepath) + if not path.is_file(): + return False + if not path.parts[-1].startswith('test_'): + return False + if not path.suffix == '.py': + return False + return True + + def find_tests_matching_keywords(keywords, filepath): + matches = [] + source = pathlib.Path(filepath).read_text(encoding='utf-8') + for line in source.splitlines(): + if line.lstrip().startswith('def '): + for kw in keywords: + if line.lower().find(kw.lower()) != -1: + test_name = line.split(' ')[1].split('(')[0] + full_test_path = filepath + '::' + test_name + matches.append(full_test_path) + return matches + + valid_testpaths_default = [] + for testpath in TESTPATHS_DEFAULT: + absolute_testpath = pathlib.Path(sympy_dir(), testpath) + if absolute_testpath.exists(): + valid_testpaths_default.append(str(absolute_testpath)) + + candidate_paths = [] + if paths: + full_paths = [] + partial_paths = [] + for path in paths: + if pathlib.Path(path).exists(): + full_paths.append(str(pathlib.Path(sympy_dir(), path))) + else: + partial_paths.append(path) + matched_paths = find_paths_matching_partial(partial_paths) + candidate_paths.extend(full_paths) + candidate_paths.extend(matched_paths) + else: + candidate_paths.extend(valid_testpaths_default) + + if keywords is not None and keywords != (): + matches = [] + for path in candidate_paths: + if is_tests_file(path): + test_matches = find_tests_matching_keywords(keywords, path) + matches.extend(test_matches) + else: + for root, dirnames, filenames in os.walk(path): + for filename in filenames: + absolute_filepath = str(pathlib.Path(root, filename)) + if is_tests_file(absolute_filepath): + test_matches = find_tests_matching_keywords( + keywords, + absolute_filepath, + ) + matches.extend(test_matches) + args.extend(matches) + else: + args.extend(candidate_paths) + + return args + + +def make_absolute_path(partial_path: str) -> str: + """Convert a partial path to an absolute path. + + A path such a `sympy/core` might be needed. However, absolute paths should + be used in the arguments to pytest in all cases as it avoids errors that + arise from nonexistent paths. + + This function assumes that partial_paths will be passed in such that they + begin with the explicit `sympy` directory, i.e. `sympy/...`. + + """ + + def is_valid_partial_path(partial_path: str) -> bool: + """Assumption that partial paths are defined from the `sympy` root.""" + return pathlib.Path(partial_path).parts[0] == 'sympy' + + if not is_valid_partial_path(partial_path): + msg = ( + f'Partial path {dir(partial_path)} is invalid, partial paths are ' + f'expected to be defined with the `sympy` directory as the root.' + ) + raise ValueError(msg) + + absolute_path = str(pathlib.Path(sympy_dir(), partial_path)) + return absolute_path + + +def test(*paths, subprocess=True, rerun=0, **kwargs): + """Interface to run tests via pytest compatible with SymPy's test runner. + + Explanation + =========== + + Note that a `pytest.ExitCode`, which is an `enum`, is returned. This is + different to the legacy SymPy test runner which would return a `bool`. If + all tests successfully pass the `pytest.ExitCode.OK` with value `0` is + returned, whereas the legacy SymPy test runner would return `True`. In any + other scenario, a non-zero `enum` value is returned, whereas the legacy + SymPy test runner would return `False`. Users need to, therefore, be careful + if treating the pytest exit codes as booleans because + `bool(pytest.ExitCode.OK)` evaluates to `False`, the opposite of legacy + behaviour. + + Examples + ======== + + >>> import sympy # doctest: +SKIP + + Run one file: + + >>> sympy.test('sympy/core/tests/test_basic.py') # doctest: +SKIP + >>> sympy.test('_basic') # doctest: +SKIP + + Run all tests in sympy/functions/ and some particular file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... "sympy/functions") # doctest: +SKIP + + Run all tests in sympy/core and sympy/utilities: + + >>> sympy.test("/core", "/util") # doctest: +SKIP + + Run specific test from a file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... kw="test_equality") # doctest: +SKIP + + Run specific test from any file: + + >>> sympy.test(kw="subs") # doctest: +SKIP + + Run the tests using the legacy SymPy runner: + + >>> sympy.test(use_sympy_runner=True) # doctest: +SKIP + + Note that this option is slated for deprecation in the near future and is + only currently provided to ensure users have an alternative option while the + pytest-based runner receives real-world testing. + + Parameters + ========== + paths : first n positional arguments of strings + Paths, both partial and absolute, describing which subset(s) of the test + suite are to be run. + subprocess : bool, default is True + Legacy option, is currently ignored. + rerun : int, default is 0 + Legacy option, is ignored. + use_sympy_runner : bool or None, default is None + Temporary option to invoke the legacy SymPy test runner instead of + `pytest.main`. Will be removed in the near future. + verbose : bool, default is False + Sets the verbosity of the pytest output. Using `True` will add the + `--verbose` option to the pytest call. + tb : str, 'auto', 'long', 'short', 'line', 'native', or 'no' + Sets the traceback print mode of pytest using the `--tb` option. + kw : str + Only run tests which match the given substring expression. An expression + is a Python evaluatable expression where all names are substring-matched + against test names and their parent classes. Example: -k 'test_method or + test_other' matches all test functions and classes whose name contains + 'test_method' or 'test_other', while -k 'not test_method' matches those + that don't contain 'test_method' in their names. -k 'not test_method and + not test_other' will eliminate the matches. Additionally keywords are + matched to classes and functions containing extra names in their + 'extra_keyword_matches' set, as well as functions which have names + assigned directly to them. The matching is case-insensitive. + pdb : bool, default is False + Start the interactive Python debugger on errors or `KeyboardInterrupt`. + colors : bool, default is True + Color terminal output. + force_colors : bool, default is False + Legacy option, is ignored. + sort : bool, default is True + Run the tests in sorted order. pytest uses a sorted test order by + default. Requires pytest-randomly. + seed : int + Seed to use for random number generation. Requires pytest-randomly. + timeout : int, default is 0 + Timeout in seconds before dumping the stacks. 0 means no timeout. + Requires pytest-timeout. + fail_on_timeout : bool, default is False + Legacy option, is currently ignored. + slow : bool, default is False + Run the subset of tests marked as `slow`. + enhance_asserts : bool, default is False + Legacy option, is currently ignored. + split : string in form `/` or None, default is None + Used to split the tests up. As an example, if `split='2/3' is used then + only the middle third of tests are run. Requires pytest-split. + time_balance : bool, default is True + Legacy option, is currently ignored. + blacklist : iterable of test paths as strings, default is BLACKLIST_DEFAULT + Blacklisted test paths are ignored using the `--ignore` option. Paths + may be partial or absolute. If partial then they are matched against + all paths in the pytest tests path. + parallel : bool, default is False + Parallelize the test running using pytest-xdist. If `True` then pytest + will automatically detect the number of CPU cores available and use them + all. Requires pytest-xdist. + store_durations : bool, False + Store test durations into the file `.test_durations`. The is used by + `pytest-split` to help determine more even splits when more than one + test group is being used. Requires pytest-split. + + """ + # NOTE: to be removed alongside SymPy test runner + if kwargs.get('use_sympy_runner', False): + kwargs.pop('parallel', False) + kwargs.pop('store_durations', False) + kwargs.pop('use_sympy_runner', True) + if kwargs.get('slow') is None: + kwargs['slow'] = False + return test_sympy(*paths, subprocess=True, rerun=0, **kwargs) + + pytest_plugin_manager = PytestPluginManager() + if not pytest_plugin_manager.has_pytest: + pytest.main() + + args = [] + + if kwargs.get('verbose', False): + args.append('--verbose') + + if tb := kwargs.get('tb'): + args.extend(['--tb', tb]) + + if kwargs.get('pdb'): + args.append('--pdb') + + if not kwargs.get('colors', True): + args.extend(['--color', 'no']) + + if seed := kwargs.get('seed'): + if not pytest_plugin_manager.has_randomly: + msg = '`pytest-randomly` plugin required to control random seed.' + raise ModuleNotFoundError(msg) + args.extend(['--randomly-seed', str(seed)]) + + if kwargs.get('sort', True) and pytest_plugin_manager.has_randomly: + args.append('--randomly-dont-reorganize') + elif not kwargs.get('sort', True) and not pytest_plugin_manager.has_randomly: + msg = '`pytest-randomly` plugin required to randomize test order.' + raise ModuleNotFoundError(msg) + + if timeout := kwargs.get('timeout', None): + if not pytest_plugin_manager.has_timeout: + msg = '`pytest-timeout` plugin required to apply timeout to tests.' + raise ModuleNotFoundError(msg) + args.extend(['--timeout', str(int(timeout))]) + + # Skip slow tests by default and always skip tooslow tests + if kwargs.get('slow', False): + args.extend(['-m', 'slow and not tooslow']) + else: + args.extend(['-m', 'not slow and not tooslow']) + + if (split := kwargs.get('split')) is not None: + if not pytest_plugin_manager.has_split: + msg = '`pytest-split` plugin required to run tests as groups.' + raise ModuleNotFoundError(msg) + match = split_pattern.match(split) + if not match: + msg = ('split must be a string of the form a/b where a and b are ' + 'positive nonzero ints') + raise ValueError(msg) + group, splits = map(str, match.groups()) + args.extend(['--group', group, '--splits', splits]) + if group > splits: + msg = (f'cannot have a group number {group} with only {splits} ' + 'splits') + raise ValueError(msg) + + if blacklist := kwargs.get('blacklist', BLACKLIST_DEFAULT): + for path in blacklist: + args.extend(['--ignore', make_absolute_path(path)]) + + if kwargs.get('parallel', False): + if not pytest_plugin_manager.has_xdist: + msg = '`pytest-xdist` plugin required to run tests in parallel.' + raise ModuleNotFoundError(msg) + args.extend(['-n', 'auto']) + + if kwargs.get('store_durations', False): + if not pytest_plugin_manager.has_split: + msg = '`pytest-split` plugin required to store test durations.' + raise ModuleNotFoundError(msg) + args.append('--store-durations') + + if (keywords := kwargs.get('kw')) is not None: + keywords = tuple(str(kw) for kw in keywords) + else: + keywords = () + + args = update_args_with_paths(paths, keywords, args) + exit_code = pytest.main(args) + return exit_code + + +def doctest(): + """Interface to run doctests via pytest compatible with SymPy's test runner. + """ + raise NotImplementedError diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/diagnose_imports.py b/phivenv/Lib/site-packages/sympy/testing/tests/diagnose_imports.py new file mode 100644 index 0000000000000000000000000000000000000000..a31b66c66690c082800ae36eee37dad6927e0b37 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tests/diagnose_imports.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python + +""" +Import diagnostics. Run bin/diagnose_imports.py --help for details. +""" + +from __future__ import annotations + +if __name__ == "__main__": + + import sys + import inspect + import builtins + + import optparse + + from os.path import abspath, dirname, join, normpath + this_file = abspath(__file__) + sympy_dir = join(dirname(this_file), '..', '..', '..') + sympy_dir = normpath(sympy_dir) + sys.path.insert(0, sympy_dir) + + option_parser = optparse.OptionParser( + usage= + "Usage: %prog option [options]\n" + "\n" + "Import analysis for imports between SymPy modules.") + option_group = optparse.OptionGroup( + option_parser, + 'Analysis options', + 'Options that define what to do. Exactly one of these must be given.') + option_group.add_option( + '--problems', + help= + 'Print all import problems, that is: ' + 'If an import pulls in a package instead of a module ' + '(e.g. sympy.core instead of sympy.core.add); ' # see ##PACKAGE## + 'if it imports a symbol that is already present; ' # see ##DUPLICATE## + 'if it imports a symbol ' + 'from somewhere other than the defining module.', # see ##ORIGIN## + action='count') + option_group.add_option( + '--origins', + help= + 'For each imported symbol in each module, ' + 'print the module that defined it. ' + '(This is useful for import refactoring.)', + action='count') + option_parser.add_option_group(option_group) + option_group = optparse.OptionGroup( + option_parser, + 'Sort options', + 'These options define the sort order for output lines. ' + 'At most one of these options is allowed. ' + 'Unsorted output will reflect the order in which imports happened.') + option_group.add_option( + '--by-importer', + help='Sort output lines by name of importing module.', + action='count') + option_group.add_option( + '--by-origin', + help='Sort output lines by name of imported module.', + action='count') + option_parser.add_option_group(option_group) + (options, args) = option_parser.parse_args() + if args: + option_parser.error( + 'Unexpected arguments %s (try %s --help)' % (args, sys.argv[0])) + if options.problems > 1: + option_parser.error('--problems must not be given more than once.') + if options.origins > 1: + option_parser.error('--origins must not be given more than once.') + if options.by_importer > 1: + option_parser.error('--by-importer must not be given more than once.') + if options.by_origin > 1: + option_parser.error('--by-origin must not be given more than once.') + options.problems = options.problems == 1 + options.origins = options.origins == 1 + options.by_importer = options.by_importer == 1 + options.by_origin = options.by_origin == 1 + if not options.problems and not options.origins: + option_parser.error( + 'At least one of --problems and --origins is required') + if options.problems and options.origins: + option_parser.error( + 'At most one of --problems and --origins is allowed') + if options.by_importer and options.by_origin: + option_parser.error( + 'At most one of --by-importer and --by-origin is allowed') + options.by_process = not options.by_importer and not options.by_origin + + builtin_import = builtins.__import__ + + class Definition: + """Information about a symbol's definition.""" + def __init__(self, name, value, definer): + self.name = name + self.value = value + self.definer = definer + def __hash__(self): + return hash(self.name) + def __eq__(self, other): + return self.name == other.name and self.value == other.value + def __ne__(self, other): + return not (self == other) + def __repr__(self): + return 'Definition(%s, ..., %s)' % ( + repr(self.name), repr(self.definer)) + + # Maps each function/variable to name of module to define it + symbol_definers: dict[Definition, str] = {} + + def in_module(a, b): + """Is a the same module as or a submodule of b?""" + return a == b or a != None and b != None and a.startswith(b + '.') + + def relevant(module): + """Is module relevant for import checking? + + Only imports between relevant modules will be checked.""" + return in_module(module, 'sympy') + + sorted_messages = [] + + def msg(msg, *args): + if options.by_process: + print(msg % args) + else: + sorted_messages.append(msg % args) + + def tracking_import(module, globals=globals(), locals=[], fromlist=None, level=-1): + """__import__ wrapper - does not change imports at all, but tracks them. + + Default order is implemented by doing output directly. + All other orders are implemented by collecting output information into + a sorted list that will be emitted after all imports are processed. + + Indirect imports can only occur after the requested symbol has been + imported directly (because the indirect import would not have a module + to pick the symbol up from). + So this code detects indirect imports by checking whether the symbol in + question was already imported. + + Keeps the semantics of __import__ unchanged.""" + caller_frame = inspect.getframeinfo(sys._getframe(1)) + importer_filename = caller_frame.filename + importer_module = globals['__name__'] + if importer_filename == caller_frame.filename: + importer_reference = '%s line %s' % ( + importer_filename, str(caller_frame.lineno)) + else: + importer_reference = importer_filename + result = builtin_import(module, globals, locals, fromlist, level) + importee_module = result.__name__ + # We're only interested if importer and importee are in SymPy + if relevant(importer_module) and relevant(importee_module): + for symbol in result.__dict__.iterkeys(): + definition = Definition( + symbol, result.__dict__[symbol], importer_module) + if definition not in symbol_definers: + symbol_definers[definition] = importee_module + if hasattr(result, '__path__'): + ##PACKAGE## + # The existence of __path__ is documented in the tutorial on modules. + # Python 3.3 documents this in http://docs.python.org/3.3/reference/import.html + if options.by_origin: + msg('Error: %s (a package) is imported by %s', + module, importer_reference) + else: + msg('Error: %s contains package import %s', + importer_reference, module) + if fromlist != None: + symbol_list = fromlist + if '*' in symbol_list: + if (importer_filename.endswith(("__init__.py", "__init__.pyc", "__init__.pyo"))): + # We do not check starred imports inside __init__ + # That's the normal "please copy over its imports to my namespace" + symbol_list = [] + else: + symbol_list = result.__dict__.iterkeys() + for symbol in symbol_list: + if symbol not in result.__dict__: + if options.by_origin: + msg('Error: %s.%s is not defined (yet), but %s tries to import it', + importee_module, symbol, importer_reference) + else: + msg('Error: %s tries to import %s.%s, which did not define it (yet)', + importer_reference, importee_module, symbol) + else: + definition = Definition( + symbol, result.__dict__[symbol], importer_module) + symbol_definer = symbol_definers[definition] + if symbol_definer == importee_module: + ##DUPLICATE## + if options.by_origin: + msg('Error: %s.%s is imported again into %s', + importee_module, symbol, importer_reference) + else: + msg('Error: %s imports %s.%s again', + importer_reference, importee_module, symbol) + else: + ##ORIGIN## + if options.by_origin: + msg('Error: %s.%s is imported by %s, which should import %s.%s instead', + importee_module, symbol, importer_reference, symbol_definer, symbol) + else: + msg('Error: %s imports %s.%s but should import %s.%s instead', + importer_reference, importee_module, symbol, symbol_definer, symbol) + return result + + builtins.__import__ = tracking_import + __import__('sympy') + + sorted_messages.sort() + for message in sorted_messages: + print(message) diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/test_code_quality.py b/phivenv/Lib/site-packages/sympy/testing/tests/test_code_quality.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9f363f0b9a802553f8643186b7d858d1ad0694 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tests/test_code_quality.py @@ -0,0 +1,510 @@ +# coding=utf-8 +from os import walk, sep, pardir +from os.path import split, join, abspath, exists, isfile +from glob import glob +import re +import random +import ast + +from sympy.testing.pytest import raises +from sympy.testing.quality_unicode import _test_this_file_encoding + +# System path separator (usually slash or backslash) to be +# used with excluded files, e.g. +# exclude = set([ +# "%(sep)smpmath%(sep)s" % sepd, +# ]) +sepd = {"sep": sep} + +# path and sympy_path +SYMPY_PATH = abspath(join(split(__file__)[0], pardir, pardir)) # go to sympy/ +assert exists(SYMPY_PATH) + +TOP_PATH = abspath(join(SYMPY_PATH, pardir)) +BIN_PATH = join(TOP_PATH, "bin") +EXAMPLES_PATH = join(TOP_PATH, "examples") + +# Error messages +message_space = "File contains trailing whitespace: %s, line %s." +message_implicit = "File contains an implicit import: %s, line %s." +message_tabs = "File contains tabs instead of spaces: %s, line %s." +message_carriage = "File contains carriage returns at end of line: %s, line %s" +message_str_raise = "File contains string exception: %s, line %s" +message_gen_raise = "File contains generic exception: %s, line %s" +message_old_raise = "File contains old-style raise statement: %s, line %s, \"%s\"" +message_eof = "File does not end with a newline: %s, line %s" +message_multi_eof = "File ends with more than 1 newline: %s, line %s" +message_test_suite_def = "Function should start with 'test_' or '_': %s, line %s" +message_duplicate_test = "This is a duplicate test function: %s, line %s" +message_self_assignments = "File contains assignments to self/cls: %s, line %s." +message_func_is = "File contains '.func is': %s, line %s." +message_bare_expr = "File contains bare expression: %s, line %s." + +implicit_test_re = re.compile(r'^\s*(>>> )?(\.\.\. )?from .* import .*\*') +str_raise_re = re.compile( + r'^\s*(>>> )?(\.\.\. )?raise(\s+(\'|\")|\s*(\(\s*)+(\'|\"))') +gen_raise_re = re.compile( + r'^\s*(>>> )?(\.\.\. )?raise(\s+Exception|\s*(\(\s*)+Exception)') +old_raise_re = re.compile(r'^\s*(>>> )?(\.\.\. )?raise((\s*\(\s*)|\s+)\w+\s*,') +test_suite_def_re = re.compile(r'^def\s+(?!(_|test))[^(]*\(\s*\)\s*:$') +test_ok_def_re = re.compile(r'^def\s+test_.*:$') +test_file_re = re.compile(r'.*[/\\]test_.*\.py$') +func_is_re = re.compile(r'\.\s*func\s+is') + + +def tab_in_leading(s): + """Returns True if there are tabs in the leading whitespace of a line, + including the whitespace of docstring code samples.""" + n = len(s) - len(s.lstrip()) + if not s[n:n + 3] in ['...', '>>>']: + check = s[:n] + else: + smore = s[n + 3:] + check = s[:n] + smore[:len(smore) - len(smore.lstrip())] + return not (check.expandtabs() == check) + + +def find_self_assignments(s): + """Returns a list of "bad" assignments: if there are instances + of assigning to the first argument of the class method (except + for staticmethod's). + """ + t = [n for n in ast.parse(s).body if isinstance(n, ast.ClassDef)] + + bad = [] + for c in t: + for n in c.body: + if not isinstance(n, ast.FunctionDef): + continue + if any(d.id == 'staticmethod' + for d in n.decorator_list if isinstance(d, ast.Name)): + continue + if n.name == '__new__': + continue + if not n.args.args: + continue + first_arg = n.args.args[0].arg + + for m in ast.walk(n): + if isinstance(m, ast.Assign): + for a in m.targets: + if isinstance(a, ast.Name) and a.id == first_arg: + bad.append(m) + elif (isinstance(a, ast.Tuple) and + any(q.id == first_arg for q in a.elts + if isinstance(q, ast.Name))): + bad.append(m) + + return bad + + +def check_directory_tree(base_path, file_check, exclusions=set(), pattern="*.py"): + """ + Checks all files in the directory tree (with base_path as starting point) + with the file_check function provided, skipping files that contain + any of the strings in the set provided by exclusions. + """ + if not base_path: + return + for root, dirs, files in walk(base_path): + check_files(glob(join(root, pattern)), file_check, exclusions) + + +def check_files(files, file_check, exclusions=set(), pattern=None): + """ + Checks all files with the file_check function provided, skipping files + that contain any of the strings in the set provided by exclusions. + """ + if not files: + return + for fname in files: + if not exists(fname) or not isfile(fname): + continue + if any(ex in fname for ex in exclusions): + continue + if pattern is None or re.match(pattern, fname): + file_check(fname) + + +class _Visit(ast.NodeVisitor): + """return the line number corresponding to the + line on which a bare expression appears if it is a binary op + or a comparison that is not in a with block. + + EXAMPLES + ======== + + >>> import ast + >>> class _Visit(ast.NodeVisitor): + ... def visit_Expr(self, node): + ... if isinstance(node.value, (ast.BinOp, ast.Compare)): + ... print(node.lineno) + ... def visit_With(self, node): + ... pass # no checking there + ... + >>> code='''x = 1 # line 1 + ... for i in range(3): + ... x == 2 # <-- 3 + ... if x == 2: + ... x == 3 # <-- 5 + ... x + 1 # <-- 6 + ... x = 1 + ... if x == 1: + ... print(1) + ... while x != 1: + ... x == 1 # <-- 11 + ... with raises(TypeError): + ... c == 1 + ... raise TypeError + ... assert x == 1 + ... ''' + >>> _Visit().visit(ast.parse(code)) + 3 + 5 + 6 + 11 + """ + def visit_Expr(self, node): + if isinstance(node.value, (ast.BinOp, ast.Compare)): + assert None, message_bare_expr % ('', node.lineno) + def visit_With(self, node): + pass + + +BareExpr = _Visit() + + +def line_with_bare_expr(code): + """return None or else 0-based line number of code on which + a bare expression appeared. + """ + tree = ast.parse(code) + try: + BareExpr.visit(tree) + except AssertionError as msg: + assert msg.args + msg = msg.args[0] + assert msg.startswith(message_bare_expr.split(':', 1)[0]) + return int(msg.rsplit(' ', 1)[1].rstrip('.')) # the line number + + +def test_files(): + """ + This test tests all files in SymPy and checks that: + o no lines contains a trailing whitespace + o no lines end with \r\n + o no line uses tabs instead of spaces + o that the file ends with a single newline + o there are no general or string exceptions + o there are no old style raise statements + o name of arg-less test suite functions start with _ or test_ + o no duplicate function names that start with test_ + o no assignments to self variable in class methods + o no lines contain ".func is" except in the test suite + o there is no do-nothing expression like `a == b` or `x + 1` + """ + + def test(fname): + with open(fname, encoding="utf8") as test_file: + test_this_file(fname, test_file) + with open(fname, encoding='utf8') as test_file: + _test_this_file_encoding(fname, test_file) + + def test_this_file(fname, test_file): + idx = None + code = test_file.read() + test_file.seek(0) # restore reader to head + py = fname if sep not in fname else fname.rsplit(sep, 1)[-1] + if py.startswith('test_'): + idx = line_with_bare_expr(code) + if idx is not None: + assert False, message_bare_expr % (fname, idx + 1) + + line = None # to flag the case where there were no lines in file + tests = 0 + test_set = set() + for idx, line in enumerate(test_file): + if test_file_re.match(fname): + if test_suite_def_re.match(line): + assert False, message_test_suite_def % (fname, idx + 1) + if test_ok_def_re.match(line): + tests += 1 + test_set.add(line[3:].split('(')[0].strip()) + if len(test_set) != tests: + assert False, message_duplicate_test % (fname, idx + 1) + if line.endswith((" \n", "\t\n")): + assert False, message_space % (fname, idx + 1) + if line.endswith("\r\n"): + assert False, message_carriage % (fname, idx + 1) + if tab_in_leading(line): + assert False, message_tabs % (fname, idx + 1) + if str_raise_re.search(line): + assert False, message_str_raise % (fname, idx + 1) + if gen_raise_re.search(line): + assert False, message_gen_raise % (fname, idx + 1) + if (implicit_test_re.search(line) and + not list(filter(lambda ex: ex in fname, import_exclude))): + assert False, message_implicit % (fname, idx + 1) + if func_is_re.search(line) and not test_file_re.search(fname): + assert False, message_func_is % (fname, idx + 1) + + result = old_raise_re.search(line) + + if result is not None: + assert False, message_old_raise % ( + fname, idx + 1, result.group(2)) + + if line is not None: + if line == '\n' and idx > 0: + assert False, message_multi_eof % (fname, idx + 1) + elif not line.endswith('\n'): + # eof newline check + assert False, message_eof % (fname, idx + 1) + + + # Files to test at top level + top_level_files = [join(TOP_PATH, file) for file in [ + "isympy.py", + "build.py", + "setup.py", + ]] + # Files to exclude from all tests + exclude = { + "%(sep)ssympy%(sep)sparsing%(sep)sautolev%(sep)s_antlr%(sep)sautolevparser.py" % sepd, + "%(sep)ssympy%(sep)sparsing%(sep)sautolev%(sep)s_antlr%(sep)sautolevlexer.py" % sepd, + "%(sep)ssympy%(sep)sparsing%(sep)sautolev%(sep)s_antlr%(sep)sautolevlistener.py" % sepd, + "%(sep)ssympy%(sep)sparsing%(sep)slatex%(sep)s_antlr%(sep)slatexparser.py" % sepd, + "%(sep)ssympy%(sep)sparsing%(sep)slatex%(sep)s_antlr%(sep)slatexlexer.py" % sepd, + } + # Files to exclude from the implicit import test + import_exclude = { + # glob imports are allowed in top-level __init__.py: + "%(sep)ssympy%(sep)s__init__.py" % sepd, + # these __init__.py should be fixed: + # XXX: not really, they use useful import pattern (DRY) + "%(sep)svector%(sep)s__init__.py" % sepd, + "%(sep)smechanics%(sep)s__init__.py" % sepd, + "%(sep)squantum%(sep)s__init__.py" % sepd, + "%(sep)spolys%(sep)s__init__.py" % sepd, + "%(sep)spolys%(sep)sdomains%(sep)s__init__.py" % sepd, + # interactive SymPy executes ``from sympy import *``: + "%(sep)sinteractive%(sep)ssession.py" % sepd, + # isympy.py executes ``from sympy import *``: + "%(sep)sisympy.py" % sepd, + # these two are import timing tests: + "%(sep)sbin%(sep)ssympy_time.py" % sepd, + "%(sep)sbin%(sep)ssympy_time_cache.py" % sepd, + # Taken from Python stdlib: + "%(sep)sparsing%(sep)ssympy_tokenize.py" % sepd, + # this one should be fixed: + "%(sep)splotting%(sep)spygletplot%(sep)s" % sepd, + # False positive in the docstring + "%(sep)sbin%(sep)stest_external_imports.py" % sepd, + "%(sep)sbin%(sep)stest_submodule_imports.py" % sepd, + # These are deprecated stubs that can be removed at some point: + "%(sep)sutilities%(sep)sruntests.py" % sepd, + "%(sep)sutilities%(sep)spytest.py" % sepd, + "%(sep)sutilities%(sep)srandtest.py" % sepd, + "%(sep)sutilities%(sep)stmpfiles.py" % sepd, + "%(sep)sutilities%(sep)squality_unicode.py" % sepd, + } + check_files(top_level_files, test) + check_directory_tree(BIN_PATH, test, {"~", ".pyc", ".sh"}, "*") + check_directory_tree(SYMPY_PATH, test, exclude) + check_directory_tree(EXAMPLES_PATH, test, exclude) + + +def _with_space(c): + # return c with a random amount of leading space + return random.randint(0, 10)*' ' + c + + +def test_raise_statement_regular_expression(): + candidates_ok = [ + "some text # raise Exception, 'text'", + "raise ValueError('text') # raise Exception, 'text'", + "raise ValueError('text')", + "raise ValueError", + "raise ValueError('text')", + "raise ValueError('text') #,", + # Talking about an exception in a docstring + ''''"""This function will raise ValueError, except when it doesn't"""''', + "raise (ValueError('text')", + ] + str_candidates_fail = [ + "raise 'exception'", + "raise 'Exception'", + 'raise "exception"', + 'raise "Exception"', + "raise 'ValueError'", + ] + gen_candidates_fail = [ + "raise Exception('text') # raise Exception, 'text'", + "raise Exception('text')", + "raise Exception", + "raise Exception('text')", + "raise Exception('text') #,", + "raise Exception, 'text'", + "raise Exception, 'text' # raise Exception('text')", + "raise Exception, 'text' # raise Exception, 'text'", + ">>> raise Exception, 'text'", + ">>> raise Exception, 'text' # raise Exception('text')", + ">>> raise Exception, 'text' # raise Exception, 'text'", + ] + old_candidates_fail = [ + "raise Exception, 'text'", + "raise Exception, 'text' # raise Exception('text')", + "raise Exception, 'text' # raise Exception, 'text'", + ">>> raise Exception, 'text'", + ">>> raise Exception, 'text' # raise Exception('text')", + ">>> raise Exception, 'text' # raise Exception, 'text'", + "raise ValueError, 'text'", + "raise ValueError, 'text' # raise Exception('text')", + "raise ValueError, 'text' # raise Exception, 'text'", + ">>> raise ValueError, 'text'", + ">>> raise ValueError, 'text' # raise Exception('text')", + ">>> raise ValueError, 'text' # raise Exception, 'text'", + "raise(ValueError,", + "raise (ValueError,", + "raise( ValueError,", + "raise ( ValueError,", + "raise(ValueError ,", + "raise (ValueError ,", + "raise( ValueError ,", + "raise ( ValueError ,", + ] + + for c in candidates_ok: + assert str_raise_re.search(_with_space(c)) is None, c + assert gen_raise_re.search(_with_space(c)) is None, c + assert old_raise_re.search(_with_space(c)) is None, c + for c in str_candidates_fail: + assert str_raise_re.search(_with_space(c)) is not None, c + for c in gen_candidates_fail: + assert gen_raise_re.search(_with_space(c)) is not None, c + for c in old_candidates_fail: + assert old_raise_re.search(_with_space(c)) is not None, c + + +def test_implicit_imports_regular_expression(): + candidates_ok = [ + "from sympy import something", + ">>> from sympy import something", + "from sympy.somewhere import something", + ">>> from sympy.somewhere import something", + "import sympy", + ">>> import sympy", + "import sympy.something.something", + "... import sympy", + "... import sympy.something.something", + "... from sympy import something", + "... from sympy.somewhere import something", + ">> from sympy import *", # To allow 'fake' docstrings + "# from sympy import *", + "some text # from sympy import *", + ] + candidates_fail = [ + "from sympy import *", + ">>> from sympy import *", + "from sympy.somewhere import *", + ">>> from sympy.somewhere import *", + "... from sympy import *", + "... from sympy.somewhere import *", + ] + for c in candidates_ok: + assert implicit_test_re.search(_with_space(c)) is None, c + for c in candidates_fail: + assert implicit_test_re.search(_with_space(c)) is not None, c + + +def test_test_suite_defs(): + candidates_ok = [ + " def foo():\n", + "def foo(arg):\n", + "def _foo():\n", + "def test_foo():\n", + ] + candidates_fail = [ + "def foo():\n", + "def foo() :\n", + "def foo( ):\n", + "def foo():\n", + ] + for c in candidates_ok: + assert test_suite_def_re.search(c) is None, c + for c in candidates_fail: + assert test_suite_def_re.search(c) is not None, c + + +def test_test_duplicate_defs(): + candidates_ok = [ + "def foo():\ndef foo():\n", + "def test():\ndef test_():\n", + "def test_():\ndef test__():\n", + ] + candidates_fail = [ + "def test_():\ndef test_ ():\n", + "def test_1():\ndef test_1():\n", + ] + ok = (None, 'check') + def check(file): + tests = 0 + test_set = set() + for idx, line in enumerate(file.splitlines()): + if test_ok_def_re.match(line): + tests += 1 + test_set.add(line[3:].split('(')[0].strip()) + if len(test_set) != tests: + return False, message_duplicate_test % ('check', idx + 1) + return None, 'check' + for c in candidates_ok: + assert check(c) == ok + for c in candidates_fail: + assert check(c) != ok + + +def test_find_self_assignments(): + candidates_ok = [ + "class A(object):\n def foo(self, arg): arg = self\n", + "class A(object):\n def foo(self, arg): self.prop = arg\n", + "class A(object):\n def foo(self, arg): obj, obj2 = arg, self\n", + "class A(object):\n @classmethod\n def bar(cls, arg): arg = cls\n", + "class A(object):\n def foo(var, arg): arg = var\n", + ] + candidates_fail = [ + "class A(object):\n def foo(self, arg): self = arg\n", + "class A(object):\n def foo(self, arg): obj, self = arg, arg\n", + "class A(object):\n def foo(self, arg):\n if arg: self = arg", + "class A(object):\n @classmethod\n def foo(cls, arg): cls = arg\n", + "class A(object):\n def foo(var, arg): var = arg\n", + ] + + for c in candidates_ok: + assert find_self_assignments(c) == [] + for c in candidates_fail: + assert find_self_assignments(c) != [] + + +def test_test_unicode_encoding(): + unicode_whitelist = ['foo'] + unicode_strict_whitelist = ['bar'] + + fname = 'abc' + test_file = ['α'] + raises(AssertionError, lambda: _test_this_file_encoding( + fname, test_file, unicode_whitelist, unicode_strict_whitelist)) + + fname = 'abc' + test_file = ['abc'] + _test_this_file_encoding( + fname, test_file, unicode_whitelist, unicode_strict_whitelist) + + fname = 'foo' + test_file = ['abc'] + raises(AssertionError, lambda: _test_this_file_encoding( + fname, test_file, unicode_whitelist, unicode_strict_whitelist)) + + fname = 'bar' + test_file = ['abc'] + _test_this_file_encoding( + fname, test_file, unicode_whitelist, unicode_strict_whitelist) diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/test_deprecated.py b/phivenv/Lib/site-packages/sympy/testing/tests/test_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..696933d96d6232ea869da1002ec9ebee5309724d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tests/test_deprecated.py @@ -0,0 +1,5 @@ +from sympy.testing.pytest import warns_deprecated_sympy + +def test_deprecated_testing_randtest(): + with warns_deprecated_sympy(): + import sympy.testing.randtest # noqa:F401 diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/test_module_imports.py b/phivenv/Lib/site-packages/sympy/testing/tests/test_module_imports.py new file mode 100644 index 0000000000000000000000000000000000000000..d16dbaa98156c287c18b46ff07c0ede5d26e069a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tests/test_module_imports.py @@ -0,0 +1,42 @@ +""" +Checks that SymPy does not contain indirect imports. + +An indirect import is importing a symbol from a module that itself imported the +symbol from elsewhere. Such a constellation makes it harder to diagnose +inter-module dependencies and import order problems, and is therefore strongly +discouraged. + +(Indirect imports from end-user code is fine and in fact a best practice.) + +Implementation note: Forcing Python into actually unloading already-imported +submodules is a tricky and partly undocumented process. To avoid these issues, +the actual diagnostic code is in bin/diagnose_imports, which is run as a +separate, pristine Python process. +""" + +import subprocess +import sys +from os.path import abspath, dirname, join, normpath +import inspect + +from sympy.testing.pytest import XFAIL + +@XFAIL +def test_module_imports_are_direct(): + my_filename = abspath(inspect.getfile(inspect.currentframe())) + my_dirname = dirname(my_filename) + diagnose_imports_filename = join(my_dirname, 'diagnose_imports.py') + diagnose_imports_filename = normpath(diagnose_imports_filename) + + process = subprocess.Popen( + [ + sys.executable, + normpath(diagnose_imports_filename), + '--problems', + '--by-importer' + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1) + output, _ = process.communicate() + assert output == '', "There are import problems:\n" + output.decode() diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/test_pytest.py b/phivenv/Lib/site-packages/sympy/testing/tests/test_pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..7080a4ed4602707a3efb4d9025e8e9cbe23a5ef8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tests/test_pytest.py @@ -0,0 +1,211 @@ +import warnings + +from sympy.testing.pytest import (raises, warns, ignore_warnings, + warns_deprecated_sympy, Failed) +from sympy.utilities.exceptions import sympy_deprecation_warning + + + +# Test callables + + +def test_expected_exception_is_silent_callable(): + def f(): + raise ValueError() + raises(ValueError, f) + + +# Under pytest raises will raise Failed rather than AssertionError +def test_lack_of_exception_triggers_AssertionError_callable(): + try: + raises(Exception, lambda: 1 + 1) + assert False + except Failed as e: + assert "DID NOT RAISE" in str(e) + + +def test_unexpected_exception_is_passed_through_callable(): + def f(): + raise ValueError("some error message") + try: + raises(TypeError, f) + assert False + except ValueError as e: + assert str(e) == "some error message" + +# Test with statement + +def test_expected_exception_is_silent_with(): + with raises(ValueError): + raise ValueError() + + +def test_lack_of_exception_triggers_AssertionError_with(): + try: + with raises(Exception): + 1 + 1 + assert False + except Failed as e: + assert "DID NOT RAISE" in str(e) + + +def test_unexpected_exception_is_passed_through_with(): + try: + with raises(TypeError): + raise ValueError("some error message") + assert False + except ValueError as e: + assert str(e) == "some error message" + +# Now we can use raises() instead of try/catch +# to test that a specific exception class is raised + + +def test_second_argument_should_be_callable_or_string(): + raises(TypeError, lambda: raises("irrelevant", 42)) + + +def test_warns_catches_warning(): + with warnings.catch_warnings(record=True) as w: + with warns(UserWarning): + warnings.warn('this is the warning message') + assert len(w) == 0 + + +def test_warns_raises_without_warning(): + with raises(Failed): + with warns(UserWarning): + pass + + +def test_warns_hides_other_warnings(): + with raises(RuntimeWarning): + with warns(UserWarning): + warnings.warn('this is the warning message', UserWarning) + warnings.warn('this is the other message', RuntimeWarning) + + +def test_warns_continues_after_warning(): + with warnings.catch_warnings(record=True) as w: + finished = False + with warns(UserWarning): + warnings.warn('this is the warning message') + finished = True + assert finished + assert len(w) == 0 + + +def test_warns_many_warnings(): + with warns(UserWarning): + warnings.warn('this is the warning message', UserWarning) + warnings.warn('this is the other warning message', UserWarning) + + +def test_warns_match_matching(): + with warnings.catch_warnings(record=True) as w: + with warns(UserWarning, match='this is the warning message'): + warnings.warn('this is the warning message', UserWarning) + assert len(w) == 0 + + +def test_warns_match_non_matching(): + with warnings.catch_warnings(record=True) as w: + with raises(Failed): + with warns(UserWarning, match='this is the warning message'): + warnings.warn('this is not the expected warning message', UserWarning) + assert len(w) == 0 + +def _warn_sympy_deprecation(stacklevel=3): + sympy_deprecation_warning( + "feature", + active_deprecations_target="active-deprecations", + deprecated_since_version="0.0.0", + stacklevel=stacklevel, + ) + +def test_warns_deprecated_sympy_catches_warning(): + with warnings.catch_warnings(record=True) as w: + with warns_deprecated_sympy(): + _warn_sympy_deprecation() + assert len(w) == 0 + + +def test_warns_deprecated_sympy_raises_without_warning(): + with raises(Failed): + with warns_deprecated_sympy(): + pass + +def test_warns_deprecated_sympy_wrong_stacklevel(): + with raises(Failed): + with warns_deprecated_sympy(): + _warn_sympy_deprecation(stacklevel=1) + +def test_warns_deprecated_sympy_doesnt_hide_other_warnings(): + # Unlike pytest's deprecated_call, we should not hide other warnings. + with raises(RuntimeWarning): + with warns_deprecated_sympy(): + _warn_sympy_deprecation() + warnings.warn('this is the other message', RuntimeWarning) + + +def test_warns_deprecated_sympy_continues_after_warning(): + with warnings.catch_warnings(record=True) as w: + finished = False + with warns_deprecated_sympy(): + _warn_sympy_deprecation() + finished = True + assert finished + assert len(w) == 0 + +def test_ignore_ignores_warning(): + with warnings.catch_warnings(record=True) as w: + with ignore_warnings(UserWarning): + warnings.warn('this is the warning message') + assert len(w) == 0 + + +def test_ignore_does_not_raise_without_warning(): + with warnings.catch_warnings(record=True) as w: + with ignore_warnings(UserWarning): + pass + assert len(w) == 0 + + +def test_ignore_allows_other_warnings(): + with warnings.catch_warnings(record=True) as w: + # This is needed when pytest is run as -Werror + # the setting is reverted at the end of the catch_Warnings block. + warnings.simplefilter("always") + with ignore_warnings(UserWarning): + warnings.warn('this is the warning message', UserWarning) + warnings.warn('this is the other message', RuntimeWarning) + assert len(w) == 1 + assert isinstance(w[0].message, RuntimeWarning) + assert str(w[0].message) == 'this is the other message' + + +def test_ignore_continues_after_warning(): + with warnings.catch_warnings(record=True) as w: + finished = False + with ignore_warnings(UserWarning): + warnings.warn('this is the warning message') + finished = True + assert finished + assert len(w) == 0 + + +def test_ignore_many_warnings(): + with warnings.catch_warnings(record=True) as w: + # This is needed when pytest is run as -Werror + # the setting is reverted at the end of the catch_Warnings block. + warnings.simplefilter("always") + with ignore_warnings(UserWarning): + warnings.warn('this is the warning message', UserWarning) + warnings.warn('this is the other message', RuntimeWarning) + warnings.warn('this is the warning message', UserWarning) + warnings.warn('this is the other message', RuntimeWarning) + warnings.warn('this is the other message', RuntimeWarning) + assert len(w) == 3 + for wi in w: + assert isinstance(wi.message, RuntimeWarning) + assert str(wi.message) == 'this is the other message' diff --git a/phivenv/Lib/site-packages/sympy/testing/tests/test_runtests_pytest.py b/phivenv/Lib/site-packages/sympy/testing/tests/test_runtests_pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..cd56d831f3618c9dbb8a1dffe63d95a0befc6adf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tests/test_runtests_pytest.py @@ -0,0 +1,171 @@ +import pathlib +from typing import List + +import pytest + +from sympy.testing.runtests_pytest import ( + make_absolute_path, + sympy_dir, + update_args_with_paths, +) + + +class TestMakeAbsolutePath: + + @staticmethod + @pytest.mark.parametrize( + 'partial_path', ['sympy', 'sympy/core', 'sympy/nonexistant_directory'], + ) + def test_valid_partial_path(partial_path: str): + """Paths that start with `sympy` are valid.""" + _ = make_absolute_path(partial_path) + + @staticmethod + @pytest.mark.parametrize( + 'partial_path', ['not_sympy', 'also/not/sympy'], + ) + def test_invalid_partial_path_raises_value_error(partial_path: str): + """A `ValueError` is raises on paths that don't start with `sympy`.""" + with pytest.raises(ValueError): + _ = make_absolute_path(partial_path) + + +class TestUpdateArgsWithPaths: + + @staticmethod + def test_no_paths(): + """If no paths are passed, only `sympy` and `doc/src` are appended. + + `sympy` and `doc/src` are the `testpaths` stated in `pytest.ini`. They + need to be manually added as if any path-related arguments are passed + to `pytest.main` then the settings in `pytest.ini` may be ignored. + + """ + paths = [] + args = update_args_with_paths(paths=paths, keywords=None, args=[]) + expected = [ + str(pathlib.Path(sympy_dir(), 'sympy')), + str(pathlib.Path(sympy_dir(), 'doc/src')), + ] + assert args == expected + + @staticmethod + @pytest.mark.parametrize( + 'path', + ['sympy/core/tests/test_basic.py', '_basic'] + ) + def test_one_file(path: str): + """Single files/paths, full or partial, are matched correctly.""" + args = update_args_with_paths(paths=[path], keywords=None, args=[]) + expected = [ + str(pathlib.Path(sympy_dir(), 'sympy/core/tests/test_basic.py')), + ] + assert args == expected + + @staticmethod + def test_partial_path_from_root(): + """Partial paths from the root directly are matched correctly.""" + args = update_args_with_paths(paths=['sympy/functions'], keywords=None, args=[]) + expected = [str(pathlib.Path(sympy_dir(), 'sympy/functions'))] + assert args == expected + + @staticmethod + def test_multiple_paths_from_root(): + """Multiple paths, partial or full, are matched correctly.""" + paths = ['sympy/core/tests/test_basic.py', 'sympy/functions'] + args = update_args_with_paths(paths=paths, keywords=None, args=[]) + expected = [ + str(pathlib.Path(sympy_dir(), 'sympy/core/tests/test_basic.py')), + str(pathlib.Path(sympy_dir(), 'sympy/functions')), + ] + assert args == expected + + @staticmethod + @pytest.mark.parametrize( + 'paths, expected_paths', + [ + ( + ['/core', '/util'], + [ + 'doc/src/modules/utilities', + 'doc/src/reference/public/utilities', + 'sympy/core', + 'sympy/logic/utilities', + 'sympy/utilities', + ] + ), + ] + ) + def test_multiple_paths_from_non_root(paths: List[str], expected_paths: List[str]): + """Multiple partial paths are matched correctly.""" + args = update_args_with_paths(paths=paths, keywords=None, args=[]) + assert len(args) == len(expected_paths) + for arg, expected in zip(sorted(args), expected_paths): + assert expected in arg + + @staticmethod + @pytest.mark.parametrize( + 'paths', + [ + + [], + ['sympy/physics'], + ['sympy/physics/mechanics'], + ['sympy/physics/mechanics/tests'], + ['sympy/physics/mechanics/tests/test_kane3.py'], + ] + ) + def test_string_as_keyword(paths: List[str]): + """String keywords are matched correctly.""" + keywords = ('bicycle', ) + args = update_args_with_paths(paths=paths, keywords=keywords, args=[]) + expected_args = ['sympy/physics/mechanics/tests/test_kane3.py::test_bicycle'] + assert len(args) == len(expected_args) + for arg, expected in zip(sorted(args), expected_args): + assert expected in arg + + @staticmethod + @pytest.mark.parametrize( + 'paths', + [ + + [], + ['sympy/core'], + ['sympy/core/tests'], + ['sympy/core/tests/test_sympify.py'], + ] + ) + def test_integer_as_keyword(paths: List[str]): + """Integer keywords are matched correctly.""" + keywords = ('3538', ) + args = update_args_with_paths(paths=paths, keywords=keywords, args=[]) + expected_args = ['sympy/core/tests/test_sympify.py::test_issue_3538'] + assert len(args) == len(expected_args) + for arg, expected in zip(sorted(args), expected_args): + assert expected in arg + + @staticmethod + def test_multiple_keywords(): + """Multiple keywords are matched correctly.""" + keywords = ('bicycle', '3538') + args = update_args_with_paths(paths=[], keywords=keywords, args=[]) + expected_args = [ + 'sympy/core/tests/test_sympify.py::test_issue_3538', + 'sympy/physics/mechanics/tests/test_kane3.py::test_bicycle', + ] + assert len(args) == len(expected_args) + for arg, expected in zip(sorted(args), expected_args): + assert expected in arg + + @staticmethod + def test_keyword_match_in_multiple_files(): + """Keywords are matched across multiple files.""" + keywords = ('1130', ) + args = update_args_with_paths(paths=[], keywords=keywords, args=[]) + expected_args = [ + 'sympy/integrals/tests/test_heurisch.py::test_heurisch_symbolic_coeffs_1130', + 'sympy/utilities/tests/test_lambdify.py::test_python_div_zero_issue_11306', + ] + assert len(args) == len(expected_args) + for arg, expected in zip(sorted(args), expected_args): + assert expected in arg diff --git a/phivenv/Lib/site-packages/sympy/testing/tmpfiles.py b/phivenv/Lib/site-packages/sympy/testing/tmpfiles.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c69cb58aa11f77679855f3df21f03a10d3b2b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/testing/tmpfiles.py @@ -0,0 +1,46 @@ +""" +This module adds context manager for temporary files generated by the tests. +""" + +import shutil +import os + + +class TmpFileManager: + """ + A class to track record of every temporary files created by the tests. + """ + tmp_files = set('') + tmp_folders = set('') + + @classmethod + def tmp_file(cls, name=''): + cls.tmp_files.add(name) + return name + + @classmethod + def tmp_folder(cls, name=''): + cls.tmp_folders.add(name) + return name + + @classmethod + def cleanup(cls): + while cls.tmp_files: + file = cls.tmp_files.pop() + if os.path.isfile(file): + os.remove(file) + while cls.tmp_folders: + folder = cls.tmp_folders.pop() + shutil.rmtree(folder) + +def cleanup_tmp_files(test_func): + """ + A decorator to help test codes remove temporary files after the tests. + """ + def wrapper_function(): + try: + test_func() + finally: + TmpFileManager.cleanup() + + return wrapper_function